Internal change
PiperOrigin-RevId: 487881149
This commit is contained in:
parent
a83d87e157
commit
340d7651af
|
@ -613,6 +613,7 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":tensorflow_session",
|
":tensorflow_session",
|
||||||
":tensorflow_session_from_saved_model_generator_cc_proto",
|
":tensorflow_session_from_saved_model_generator_cc_proto",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
"//mediapipe/framework:packet_generator",
|
"//mediapipe/framework:packet_generator",
|
||||||
"//mediapipe/framework:packet_type",
|
"//mediapipe/framework:packet_type",
|
||||||
"//mediapipe/framework/tool:status_util",
|
"//mediapipe/framework/tool:status_util",
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
|
||||||
#if !defined(__ANDROID__)
|
#if !defined(__ANDROID__)
|
||||||
#include "mediapipe/framework/port/file_helpers.h"
|
#include "mediapipe/framework/port/file_helpers.h"
|
||||||
#endif
|
#endif
|
||||||
|
@ -38,6 +40,8 @@ constexpr char kSessionTag[] = "SESSION";
|
||||||
|
|
||||||
static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH";
|
static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH";
|
||||||
|
|
||||||
|
static constexpr char kStringSignatureName[] = "STRING_SIGNATURE_NAME";
|
||||||
|
|
||||||
// Given the path to a directory containing multiple tensorflow saved models
|
// Given the path to a directory containing multiple tensorflow saved models
|
||||||
// in subdirectories, replaces path with the alphabetically last subdirectory.
|
// in subdirectories, replaces path with the alphabetically last subdirectory.
|
||||||
absl::Status GetLatestDirectory(std::string* path) {
|
absl::Status GetLatestDirectory(std::string* path) {
|
||||||
|
@ -104,6 +108,10 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
|
||||||
if (input_side_packets->HasTag(kStringSavedModelPath)) {
|
if (input_side_packets->HasTag(kStringSavedModelPath)) {
|
||||||
input_side_packets->Tag(kStringSavedModelPath).Set<std::string>();
|
input_side_packets->Tag(kStringSavedModelPath).Set<std::string>();
|
||||||
}
|
}
|
||||||
|
// Set Signature_def.
|
||||||
|
if (input_side_packets->HasTag(kStringSignatureName)) {
|
||||||
|
input_side_packets->Tag(kStringSignatureName).Set<std::string>();
|
||||||
|
}
|
||||||
// A TensorFlow model loaded and ready for use along with tensor
|
// A TensorFlow model loaded and ready for use along with tensor
|
||||||
output_side_packets->Tag(kSessionTag).Set<TensorFlowSession>();
|
output_side_packets->Tag(kSessionTag).Set<TensorFlowSession>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -146,9 +154,19 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
|
||||||
auto session = absl::make_unique<TensorFlowSession>();
|
auto session = absl::make_unique<TensorFlowSession>();
|
||||||
session->session = std::move(saved_model->session);
|
session->session = std::move(saved_model->session);
|
||||||
|
|
||||||
RET_CHECK(!options.signature_name().empty());
|
// Use input side packet to overwrite signature name in options.
|
||||||
|
std::string signature_name =
|
||||||
|
input_side_packets.HasTag(kStringSignatureName)
|
||||||
|
? input_side_packets.Tag(kStringSignatureName).Get<std::string>()
|
||||||
|
: options.signature_name();
|
||||||
|
RET_CHECK(!signature_name.empty());
|
||||||
const auto& signature_def_map = saved_model->meta_graph_def.signature_def();
|
const auto& signature_def_map = saved_model->meta_graph_def.signature_def();
|
||||||
const auto& signature_def = signature_def_map.at(options.signature_name());
|
if (signature_def_map.find(signature_name) == signature_def_map.end()) {
|
||||||
|
return absl::NotFoundError(absl::StrFormat(
|
||||||
|
"Signature name '%s' does not exist in the loaded signature def",
|
||||||
|
signature_name));
|
||||||
|
}
|
||||||
|
const auto& signature_def = signature_def_map.at(signature_name);
|
||||||
for (const auto& input_signature : signature_def.inputs()) {
|
for (const auto& input_signature : signature_def.inputs()) {
|
||||||
session->tag_to_tensor_map[MaybeConvertSignatureToTag(
|
session->tag_to_tensor_map[MaybeConvertSignatureToTag(
|
||||||
input_signature.first, options)] = input_signature.second.name();
|
input_signature.first, options)] = input_signature.second.name();
|
||||||
|
|
|
@ -30,11 +30,13 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
using ::testing::status::StatusIs;
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
namespace tf = ::tensorflow;
|
namespace tf = ::tensorflow;
|
||||||
|
|
||||||
constexpr char kStringSavedModelPathTag[] = "STRING_SAVED_MODEL_PATH";
|
constexpr char kStringSavedModelPathTag[] = "STRING_SAVED_MODEL_PATH";
|
||||||
|
constexpr char kStringSignatureNameTag[] = "STRING_SIGNATURE_NAME";
|
||||||
constexpr char kSessionTag[] = "SESSION";
|
constexpr char kSessionTag[] = "SESSION";
|
||||||
|
|
||||||
std::string GetSavedModelDir() {
|
std::string GetSavedModelDir() {
|
||||||
|
@ -124,6 +126,48 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
|
||||||
ASSERT_NE(session.session, nullptr);
|
ASSERT_NE(session.session, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
|
||||||
|
CreateSessionFromSidePacketWithCorrectSignatureName) {
|
||||||
|
generator_options_->clear_saved_model_path();
|
||||||
|
PacketSet input_side_packets(
|
||||||
|
tool::CreateTagMap({"STRING_SAVED_MODEL_PATH:saved_model_dir",
|
||||||
|
"STRING_SIGNATURE_NAME:signature_name"})
|
||||||
|
.value());
|
||||||
|
input_side_packets.Tag(kStringSavedModelPathTag) =
|
||||||
|
Adopt(new std::string(GetSavedModelDir()));
|
||||||
|
input_side_packets.Tag(kStringSignatureNameTag) =
|
||||||
|
Adopt(new std::string("serving_default"));
|
||||||
|
PacketSet output_side_packets(
|
||||||
|
tool::CreateTagMap({"SESSION:session"}).value());
|
||||||
|
absl::Status run_status = tool::RunGenerateAndValidateTypes(
|
||||||
|
"TensorFlowSessionFromSavedModelGenerator", extendable_options_,
|
||||||
|
input_side_packets, &output_side_packets);
|
||||||
|
MP_EXPECT_OK(run_status) << run_status.message();
|
||||||
|
const TensorFlowSession& session =
|
||||||
|
output_side_packets.Tag(kSessionTag).Get<TensorFlowSession>();
|
||||||
|
// Session must be set.
|
||||||
|
ASSERT_NE(session.session, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
|
||||||
|
CreateSessionFromSidePacketWithWrongSignatureName) {
|
||||||
|
generator_options_->clear_saved_model_path();
|
||||||
|
PacketSet input_side_packets(
|
||||||
|
tool::CreateTagMap({"STRING_SAVED_MODEL_PATH:saved_model_dir",
|
||||||
|
"STRING_SIGNATURE_NAME:signature_name"})
|
||||||
|
.value());
|
||||||
|
input_side_packets.Tag(kStringSavedModelPathTag) =
|
||||||
|
Adopt(new std::string(GetSavedModelDir()));
|
||||||
|
input_side_packets.Tag(kStringSignatureNameTag) =
|
||||||
|
Adopt(new std::string("wrong_signature_name"));
|
||||||
|
PacketSet output_side_packets(
|
||||||
|
tool::CreateTagMap({"SESSION:session"}).value());
|
||||||
|
absl::Status run_status = tool::RunGenerateAndValidateTypes(
|
||||||
|
"TensorFlowSessionFromSavedModelGenerator", extendable_options_,
|
||||||
|
input_side_packets, &output_side_packets);
|
||||||
|
EXPECT_THAT(run_status, StatusIs(absl::StatusCode::kNotFound));
|
||||||
|
}
|
||||||
|
|
||||||
// Integration test. Verifies that TensorFlowInferenceCalculator correctly
|
// Integration test. Verifies that TensorFlowInferenceCalculator correctly
|
||||||
// consumes the Packet emitted by this factory.
|
// consumes the Packet emitted by this factory.
|
||||||
TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
|
TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user