From 340d7651af8caca795220b81124b2a3e557f4784 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 11 Nov 2022 11:52:11 -0800 Subject: [PATCH] Internal change PiperOrigin-RevId: 487881149 --- mediapipe/calculators/tensorflow/BUILD | 1 + ...flow_session_from_saved_model_generator.cc | 22 +++++++++- ...session_from_saved_model_generator_test.cc | 44 +++++++++++++++++++ 3 files changed, 65 insertions(+), 2 deletions(-) diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index 4037d89ce..d0dfc12ab 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -613,6 +613,7 @@ cc_library( deps = [ ":tensorflow_session", ":tensorflow_session_from_saved_model_generator_cc_proto", + "@com_google_absl//absl/status", "//mediapipe/framework:packet_generator", "//mediapipe/framework:packet_type", "//mediapipe/framework/tool:status_util", diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc index 97c675920..d5236f1cc 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc @@ -14,6 +14,8 @@ #include +#include "absl/status/status.h" + #if !defined(__ANDROID__) #include "mediapipe/framework/port/file_helpers.h" #endif @@ -38,6 +40,8 @@ constexpr char kSessionTag[] = "SESSION"; static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH"; +static constexpr char kStringSignatureName[] = "STRING_SIGNATURE_NAME"; + // Given the path to a directory containing multiple tensorflow saved models // in subdirectories, replaces path with the alphabetically last subdirectory. absl::Status GetLatestDirectory(std::string* path) { @@ -104,6 +108,10 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator { if (input_side_packets->HasTag(kStringSavedModelPath)) { input_side_packets->Tag(kStringSavedModelPath).Set(); } + // Set Signature_def. + if (input_side_packets->HasTag(kStringSignatureName)) { + input_side_packets->Tag(kStringSignatureName).Set(); + } // A TensorFlow model loaded and ready for use along with tensor output_side_packets->Tag(kSessionTag).Set(); return absl::OkStatus(); @@ -146,9 +154,19 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator { auto session = absl::make_unique(); session->session = std::move(saved_model->session); - RET_CHECK(!options.signature_name().empty()); + // Use input side packet to overwrite signature name in options. + std::string signature_name = + input_side_packets.HasTag(kStringSignatureName) + ? input_side_packets.Tag(kStringSignatureName).Get() + : options.signature_name(); + RET_CHECK(!signature_name.empty()); const auto& signature_def_map = saved_model->meta_graph_def.signature_def(); - const auto& signature_def = signature_def_map.at(options.signature_name()); + if (signature_def_map.find(signature_name) == signature_def_map.end()) { + return absl::NotFoundError(absl::StrFormat( + "Signature name '%s' does not exist in the loaded signature def", + signature_name)); + } + const auto& signature_def = signature_def_map.at(signature_name); for (const auto& input_signature : signature_def.inputs()) { session->tag_to_tensor_map[MaybeConvertSignatureToTag( input_signature.first, options)] = input_signature.second.name(); diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc index 5c6de3e86..c002b1bde 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc @@ -30,11 +30,13 @@ namespace mediapipe { +using ::testing::status::StatusIs; namespace { namespace tf = ::tensorflow; constexpr char kStringSavedModelPathTag[] = "STRING_SAVED_MODEL_PATH"; +constexpr char kStringSignatureNameTag[] = "STRING_SIGNATURE_NAME"; constexpr char kSessionTag[] = "SESSION"; std::string GetSavedModelDir() { @@ -124,6 +126,48 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, ASSERT_NE(session.session, nullptr); } +TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, + CreateSessionFromSidePacketWithCorrectSignatureName) { + generator_options_->clear_saved_model_path(); + PacketSet input_side_packets( + tool::CreateTagMap({"STRING_SAVED_MODEL_PATH:saved_model_dir", + "STRING_SIGNATURE_NAME:signature_name"}) + .value()); + input_side_packets.Tag(kStringSavedModelPathTag) = + Adopt(new std::string(GetSavedModelDir())); + input_side_packets.Tag(kStringSignatureNameTag) = + Adopt(new std::string("serving_default")); + PacketSet output_side_packets( + tool::CreateTagMap({"SESSION:session"}).value()); + absl::Status run_status = tool::RunGenerateAndValidateTypes( + "TensorFlowSessionFromSavedModelGenerator", extendable_options_, + input_side_packets, &output_side_packets); + MP_EXPECT_OK(run_status) << run_status.message(); + const TensorFlowSession& session = + output_side_packets.Tag(kSessionTag).Get(); + // Session must be set. + ASSERT_NE(session.session, nullptr); +} + +TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, + CreateSessionFromSidePacketWithWrongSignatureName) { + generator_options_->clear_saved_model_path(); + PacketSet input_side_packets( + tool::CreateTagMap({"STRING_SAVED_MODEL_PATH:saved_model_dir", + "STRING_SIGNATURE_NAME:signature_name"}) + .value()); + input_side_packets.Tag(kStringSavedModelPathTag) = + Adopt(new std::string(GetSavedModelDir())); + input_side_packets.Tag(kStringSignatureNameTag) = + Adopt(new std::string("wrong_signature_name")); + PacketSet output_side_packets( + tool::CreateTagMap({"SESSION:session"}).value()); + absl::Status run_status = tool::RunGenerateAndValidateTypes( + "TensorFlowSessionFromSavedModelGenerator", extendable_options_, + input_side_packets, &output_side_packets); + EXPECT_THAT(run_status, StatusIs(absl::StatusCode::kNotFound)); +} + // Integration test. Verifies that TensorFlowInferenceCalculator correctly // consumes the Packet emitted by this factory. TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,