Internal change
PiperOrigin-RevId: 487881149
This commit is contained in:
parent
a83d87e157
commit
340d7651af
|
@ -613,6 +613,7 @@ cc_library(
|
|||
deps = [
|
||||
":tensorflow_session",
|
||||
":tensorflow_session_from_saved_model_generator_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"//mediapipe/framework:packet_generator",
|
||||
"//mediapipe/framework:packet_type",
|
||||
"//mediapipe/framework/tool:status_util",
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
|
||||
#include <algorithm>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
|
||||
#if !defined(__ANDROID__)
|
||||
#include "mediapipe/framework/port/file_helpers.h"
|
||||
#endif
|
||||
|
@ -38,6 +40,8 @@ constexpr char kSessionTag[] = "SESSION";
|
|||
|
||||
static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH";
|
||||
|
||||
static constexpr char kStringSignatureName[] = "STRING_SIGNATURE_NAME";
|
||||
|
||||
// Given the path to a directory containing multiple tensorflow saved models
|
||||
// in subdirectories, replaces path with the alphabetically last subdirectory.
|
||||
absl::Status GetLatestDirectory(std::string* path) {
|
||||
|
@ -104,6 +108,10 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
|
|||
if (input_side_packets->HasTag(kStringSavedModelPath)) {
|
||||
input_side_packets->Tag(kStringSavedModelPath).Set<std::string>();
|
||||
}
|
||||
// Set Signature_def.
|
||||
if (input_side_packets->HasTag(kStringSignatureName)) {
|
||||
input_side_packets->Tag(kStringSignatureName).Set<std::string>();
|
||||
}
|
||||
// A TensorFlow model loaded and ready for use along with tensor
|
||||
output_side_packets->Tag(kSessionTag).Set<TensorFlowSession>();
|
||||
return absl::OkStatus();
|
||||
|
@ -146,9 +154,19 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
|
|||
auto session = absl::make_unique<TensorFlowSession>();
|
||||
session->session = std::move(saved_model->session);
|
||||
|
||||
RET_CHECK(!options.signature_name().empty());
|
||||
// Use input side packet to overwrite signature name in options.
|
||||
std::string signature_name =
|
||||
input_side_packets.HasTag(kStringSignatureName)
|
||||
? input_side_packets.Tag(kStringSignatureName).Get<std::string>()
|
||||
: options.signature_name();
|
||||
RET_CHECK(!signature_name.empty());
|
||||
const auto& signature_def_map = saved_model->meta_graph_def.signature_def();
|
||||
const auto& signature_def = signature_def_map.at(options.signature_name());
|
||||
if (signature_def_map.find(signature_name) == signature_def_map.end()) {
|
||||
return absl::NotFoundError(absl::StrFormat(
|
||||
"Signature name '%s' does not exist in the loaded signature def",
|
||||
signature_name));
|
||||
}
|
||||
const auto& signature_def = signature_def_map.at(signature_name);
|
||||
for (const auto& input_signature : signature_def.inputs()) {
|
||||
session->tag_to_tensor_map[MaybeConvertSignatureToTag(
|
||||
input_signature.first, options)] = input_signature.second.name();
|
||||
|
|
|
@ -30,11 +30,13 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
using ::testing::status::StatusIs;
|
||||
namespace {
|
||||
|
||||
namespace tf = ::tensorflow;
|
||||
|
||||
constexpr char kStringSavedModelPathTag[] = "STRING_SAVED_MODEL_PATH";
|
||||
constexpr char kStringSignatureNameTag[] = "STRING_SIGNATURE_NAME";
|
||||
constexpr char kSessionTag[] = "SESSION";
|
||||
|
||||
std::string GetSavedModelDir() {
|
||||
|
@ -124,6 +126,48 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
|
|||
ASSERT_NE(session.session, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
|
||||
CreateSessionFromSidePacketWithCorrectSignatureName) {
|
||||
generator_options_->clear_saved_model_path();
|
||||
PacketSet input_side_packets(
|
||||
tool::CreateTagMap({"STRING_SAVED_MODEL_PATH:saved_model_dir",
|
||||
"STRING_SIGNATURE_NAME:signature_name"})
|
||||
.value());
|
||||
input_side_packets.Tag(kStringSavedModelPathTag) =
|
||||
Adopt(new std::string(GetSavedModelDir()));
|
||||
input_side_packets.Tag(kStringSignatureNameTag) =
|
||||
Adopt(new std::string("serving_default"));
|
||||
PacketSet output_side_packets(
|
||||
tool::CreateTagMap({"SESSION:session"}).value());
|
||||
absl::Status run_status = tool::RunGenerateAndValidateTypes(
|
||||
"TensorFlowSessionFromSavedModelGenerator", extendable_options_,
|
||||
input_side_packets, &output_side_packets);
|
||||
MP_EXPECT_OK(run_status) << run_status.message();
|
||||
const TensorFlowSession& session =
|
||||
output_side_packets.Tag(kSessionTag).Get<TensorFlowSession>();
|
||||
// Session must be set.
|
||||
ASSERT_NE(session.session, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
|
||||
CreateSessionFromSidePacketWithWrongSignatureName) {
|
||||
generator_options_->clear_saved_model_path();
|
||||
PacketSet input_side_packets(
|
||||
tool::CreateTagMap({"STRING_SAVED_MODEL_PATH:saved_model_dir",
|
||||
"STRING_SIGNATURE_NAME:signature_name"})
|
||||
.value());
|
||||
input_side_packets.Tag(kStringSavedModelPathTag) =
|
||||
Adopt(new std::string(GetSavedModelDir()));
|
||||
input_side_packets.Tag(kStringSignatureNameTag) =
|
||||
Adopt(new std::string("wrong_signature_name"));
|
||||
PacketSet output_side_packets(
|
||||
tool::CreateTagMap({"SESSION:session"}).value());
|
||||
absl::Status run_status = tool::RunGenerateAndValidateTypes(
|
||||
"TensorFlowSessionFromSavedModelGenerator", extendable_options_,
|
||||
input_side_packets, &output_side_packets);
|
||||
EXPECT_THAT(run_status, StatusIs(absl::StatusCode::kNotFound));
|
||||
}
|
||||
|
||||
// Integration test. Verifies that TensorFlowInferenceCalculator correctly
|
||||
// consumes the Packet emitted by this factory.
|
||||
TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
|
||||
|
|
Loading…
Reference in New Issue
Block a user