Internal change

PiperOrigin-RevId: 508173086
This commit is contained in:
MediaPipe Team 2023-02-08 13:49:13 -08:00 committed by Copybara-Service
parent 6c4ebd2d93
commit 6ea2d579e1
6 changed files with 101 additions and 4 deletions

View File

@ -178,6 +178,7 @@ cc_library(
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
], ],

View File

@ -17,6 +17,7 @@
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/api2/node.h"
@ -76,6 +77,7 @@ class TensorsToEmbeddingsCalculator : public Node {
bool l2_normalize_; bool l2_normalize_;
bool quantize_; bool quantize_;
std::vector<std::string> head_names_; std::vector<std::string> head_names_;
absl::flat_hash_set<std::string> ignored_head_names_;
void FillFloatEmbedding(const Tensor& tensor, Embedding* embedding); void FillFloatEmbedding(const Tensor& tensor, Embedding* embedding);
void FillQuantizedEmbedding(const Tensor& tensor, Embedding* embedding); void FillQuantizedEmbedding(const Tensor& tensor, Embedding* embedding);
@ -89,6 +91,9 @@ absl::Status TensorsToEmbeddingsCalculator::Open(CalculatorContext* cc) {
head_names_.assign(options.head_names().begin(), head_names_.assign(options.head_names().begin(),
options.head_names().end()); options.head_names().end());
} }
for (const absl::string_view head_name : options.ignored_head_names()) {
ignored_head_names_.insert(std::string(head_name));
}
return absl::OkStatus(); return absl::OkStatus();
} }
@ -102,6 +107,9 @@ absl::Status TensorsToEmbeddingsCalculator::Process(CalculatorContext* cc) {
head_names_.size(), tensors.size())); head_names_.size(), tensors.size()));
} }
for (int i = 0; i < tensors.size(); ++i) { for (int i = 0; i < tensors.size(); ++i) {
if (!head_names_.empty() && ignored_head_names_.contains(head_names_[i])) {
continue;
}
const auto& tensor = tensors[i]; const auto& tensor = tensors[i];
RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32); RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32);
auto* embedding = result.add_embeddings(); auto* embedding = result.add_embeddings();

View File

@ -32,4 +32,8 @@ message TensorsToEmbeddingsCalculatorOptions {
// The embedder head names. // The embedder head names.
repeated string head_names = 2; repeated string head_names = 2;
// Names of output heads that should not be included in the embeddings.
// Only applies if `head_names` is non-empty.
repeated string ignored_head_names = 3;
} }

View File

@ -129,6 +129,60 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) {
})pb"))); })pb")));
} }
// Verifies that a head listed in `ignored_head_names` is dropped from the
// output, and that the remaining head keeps its original position-based
// head_index (here "bar" stays at head_index 1 even though "foo" at index 0
// was filtered out).
TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNameIgnored) {
  CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
    calculator: "TensorsToEmbeddingsCalculator"
    input_stream: "TENSORS:tensors"
    output_stream: "EMBEDDINGS:embeddings"
    options {
      [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
        embedder_options { l2_normalize: false quantize: false }
        head_names: "foo"
        head_names: "bar"
        ignored_head_names: "foo"
      }
    }
  )pb"));
  // Two input tensors, one per declared head; the first ("foo") should be
  // skipped by the calculator.
  BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
  MP_ASSERT_OK(runner.Run());

  const EmbeddingResult& result =
      runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
  // Only the "bar" head survives; values are passed through unmodified since
  // neither normalization nor quantization is enabled.
  EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
                          R"pb(
                            embeddings {
                              float_embedding { values: -0.2 values: -0.3 }
                              head_index: 1
                              head_name: "bar"
                            })pb")));
}
// Verifies the edge case where every declared head is ignored: the calculator
// should still succeed and emit an empty EmbeddingResult rather than fail or
// drop the output packet.
TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithBothHeadsIgnored) {
  CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
    calculator: "TensorsToEmbeddingsCalculator"
    input_stream: "TENSORS:tensors"
    output_stream: "EMBEDDINGS:embeddings"
    options {
      [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
        embedder_options { l2_normalize: false quantize: false }
        head_names: "foo"
        head_names: "bar"
        ignored_head_names: "foo"
        ignored_head_names: "bar"
      }
    }
  )pb"));
  // Both input tensors correspond to ignored heads.
  BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
  MP_ASSERT_OK(runner.Run());

  const EmbeddingResult& result =
      runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
  // Expect an EmbeddingResult with no embeddings at all.
  EXPECT_THAT(result,
              EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(R"pb()pb")));
}
TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) { TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator" calculator: "TensorsToEmbeddingsCalculator"

View File

@ -144,6 +144,7 @@ cc_library(
"//mediapipe/calculators/tensor:regex_preprocessor_calculator", "//mediapipe/calculators/tensor:regex_preprocessor_calculator",
"//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto", "//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto",
"//mediapipe/calculators/tensor:text_to_tensor_calculator", "//mediapipe/calculators/tensor:text_to_tensor_calculator",
"//mediapipe/calculators/tensor:universal_sentence_encoder_preprocessor_calculator",
"//mediapipe/framework:subgraph", "//mediapipe/framework:subgraph",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",

View File

@ -33,6 +33,7 @@ limitations under the License.
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h" #include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
@ -156,7 +157,8 @@ class PostprocessingTest : public tflite_shims::testing::Test {
protected: protected:
absl::StatusOr<OutputStreamPoller> BuildGraph( absl::StatusOr<OutputStreamPoller> BuildGraph(
absl::string_view model_name, const proto::EmbedderOptions& options, absl::string_view model_name, const proto::EmbedderOptions& options,
bool connect_timestamps = false) { bool connect_timestamps = false,
const std::vector<absl::string_view>& ignored_head_names = {}) {
ASSIGN_OR_RETURN(auto model_resources, ASSIGN_OR_RETURN(auto model_resources,
CreateModelResourcesForModel(model_name)); CreateModelResourcesForModel(model_name));
@ -164,10 +166,15 @@ class PostprocessingTest : public tflite_shims::testing::Test {
auto& postprocessing = graph.AddNode( auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.processors." "mediapipe.tasks.components.processors."
"EmbeddingPostprocessingGraph"); "EmbeddingPostprocessingGraph");
MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessingGraph( auto* postprocessing_options =
*model_resources, options,
&postprocessing &postprocessing
.GetOptions<proto::EmbeddingPostprocessingGraphOptions>())); .GetOptions<proto::EmbeddingPostprocessingGraphOptions>();
for (const absl::string_view head_name : ignored_head_names) {
postprocessing_options->mutable_tensors_to_embeddings_options()
->add_ignored_head_names(std::string(head_name));
}
MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessingGraph(
*model_resources, options, postprocessing_options));
graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >> graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
postprocessing.In(kTensorsTag); postprocessing.In(kTensorsTag);
if (connect_timestamps) { if (connect_timestamps) {
@ -274,6 +281,28 @@ TEST_F(PostprocessingTest, SucceedsWithoutAggregation) {
} }
} }
// Verifies end-to-end that `ignored_head_names` plumbed through the
// EmbeddingPostprocessingGraph filters out the model's sole output head
// (named "feature" — presumably the MobileNetV3 embedder's head name; confirm
// against the model metadata), yielding a result with a timestamp but zero
// embeddings.
TEST_F(PostprocessingTest, SucceedsWithFilter) {
  // Build graph.
  proto::EmbedderOptions options;
  MP_ASSERT_OK_AND_ASSIGN(
      auto poller,
      BuildGraph(kMobileNetV3Embedder, options, /*connect_timestamps=*/false,
                 /*ignored_head_names=*/{"feature"}));
  // Build input tensor: zeros except a single 1.0 in the first element.
  std::vector<float> tensor(kMobileNetV3EmbedderEmbeddingSize, 0);
  tensor[0] = 1.0;

  // Send tensor and get results.
  AddTensor(tensor, Tensor::ElementType::kFloat32);
  MP_ASSERT_OK(Run());
  MP_ASSERT_OK_AND_ASSIGN(auto results, GetResult<EmbeddingResult>(poller));

  // Validate results: the timestamp is still produced, but the single
  // (ignored) head contributes no embedding.
  EXPECT_TRUE(results.has_timestamp_ms());
  EXPECT_EQ(results.timestamp_ms(), 0);
  EXPECT_EQ(results.embeddings_size(), 0);
}
TEST_F(PostprocessingTest, SucceedsWithAggregation) { TEST_F(PostprocessingTest, SucceedsWithAggregation) {
// Build graph. // Build graph.
proto::EmbedderOptions options; proto::EmbedderOptions options;