Internal change

PiperOrigin-RevId: 487881149
This commit is contained in:
MediaPipe Team 2022-11-11 11:52:11 -08:00 committed by Copybara-Service
parent a83d87e157
commit 340d7651af
3 changed files with 65 additions and 2 deletions

View File

@ -613,6 +613,7 @@ cc_library(
deps = [
":tensorflow_session",
":tensorflow_session_from_saved_model_generator_cc_proto",
"@com_google_absl//absl/status",
"//mediapipe/framework:packet_generator",
"//mediapipe/framework:packet_type",
"//mediapipe/framework/tool:status_util",

View File

@ -14,6 +14,8 @@
#include <algorithm>
#include "absl/status/status.h"
#if !defined(__ANDROID__)
#include "mediapipe/framework/port/file_helpers.h"
#endif
@ -38,6 +40,8 @@ constexpr char kSessionTag[] = "SESSION";
static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH";
static constexpr char kStringSignatureName[] = "STRING_SIGNATURE_NAME";
// Given the path to a directory containing multiple tensorflow saved models
// in subdirectories, replaces path with the alphabetically last subdirectory.
absl::Status GetLatestDirectory(std::string* path) {
@ -104,6 +108,10 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
if (input_side_packets->HasTag(kStringSavedModelPath)) {
input_side_packets->Tag(kStringSavedModelPath).Set<std::string>();
}
// Set Signature_def.
if (input_side_packets->HasTag(kStringSignatureName)) {
input_side_packets->Tag(kStringSignatureName).Set<std::string>();
}
// A TensorFlow model loaded and ready for use along with tensor
output_side_packets->Tag(kSessionTag).Set<TensorFlowSession>();
return absl::OkStatus();
@ -146,9 +154,19 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
auto session = absl::make_unique<TensorFlowSession>();
session->session = std::move(saved_model->session);
RET_CHECK(!options.signature_name().empty());
// Use input side packet to overwrite signature name in options.
std::string signature_name =
input_side_packets.HasTag(kStringSignatureName)
? input_side_packets.Tag(kStringSignatureName).Get<std::string>()
: options.signature_name();
RET_CHECK(!signature_name.empty());
const auto& signature_def_map = saved_model->meta_graph_def.signature_def();
const auto& signature_def = signature_def_map.at(options.signature_name());
if (signature_def_map.find(signature_name) == signature_def_map.end()) {
return absl::NotFoundError(absl::StrFormat(
"Signature name '%s' does not exist in the loaded signature def",
signature_name));
}
const auto& signature_def = signature_def_map.at(signature_name);
for (const auto& input_signature : signature_def.inputs()) {
session->tag_to_tensor_map[MaybeConvertSignatureToTag(
input_signature.first, options)] = input_signature.second.name();

View File

@ -30,11 +30,13 @@
namespace mediapipe {
using ::testing::status::StatusIs;
namespace {
namespace tf = ::tensorflow;
constexpr char kStringSavedModelPathTag[] = "STRING_SAVED_MODEL_PATH";
constexpr char kStringSignatureNameTag[] = "STRING_SIGNATURE_NAME";
constexpr char kSessionTag[] = "SESSION";
std::string GetSavedModelDir() {
@ -124,6 +126,48 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
ASSERT_NE(session.session, nullptr);
}
TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
CreateSessionFromSidePacketWithCorrectSignatureName) {
generator_options_->clear_saved_model_path();
PacketSet input_side_packets(
tool::CreateTagMap({"STRING_SAVED_MODEL_PATH:saved_model_dir",
"STRING_SIGNATURE_NAME:signature_name"})
.value());
input_side_packets.Tag(kStringSavedModelPathTag) =
Adopt(new std::string(GetSavedModelDir()));
input_side_packets.Tag(kStringSignatureNameTag) =
Adopt(new std::string("serving_default"));
PacketSet output_side_packets(
tool::CreateTagMap({"SESSION:session"}).value());
absl::Status run_status = tool::RunGenerateAndValidateTypes(
"TensorFlowSessionFromSavedModelGenerator", extendable_options_,
input_side_packets, &output_side_packets);
MP_EXPECT_OK(run_status) << run_status.message();
const TensorFlowSession& session =
output_side_packets.Tag(kSessionTag).Get<TensorFlowSession>();
// Session must be set.
ASSERT_NE(session.session, nullptr);
}
TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
CreateSessionFromSidePacketWithWrongSignatureName) {
generator_options_->clear_saved_model_path();
PacketSet input_side_packets(
tool::CreateTagMap({"STRING_SAVED_MODEL_PATH:saved_model_dir",
"STRING_SIGNATURE_NAME:signature_name"})
.value());
input_side_packets.Tag(kStringSavedModelPathTag) =
Adopt(new std::string(GetSavedModelDir()));
input_side_packets.Tag(kStringSignatureNameTag) =
Adopt(new std::string("wrong_signature_name"));
PacketSet output_side_packets(
tool::CreateTagMap({"SESSION:session"}).value());
absl::Status run_status = tool::RunGenerateAndValidateTypes(
"TensorFlowSessionFromSavedModelGenerator", extendable_options_,
input_side_packets, &output_side_packets);
EXPECT_THAT(run_status, StatusIs(absl::StatusCode::kNotFound));
}
// Integration test. Verifies that TensorFlowInferenceCalculator correctly
// consumes the Packet emitted by this factory.
TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,