diff --git a/mediapipe/tasks/cc/components/calculators/BUILD b/mediapipe/tasks/cc/components/calculators/BUILD index 16931811c..bb49bdb9d 100644 --- a/mediapipe/tasks/cc/components/calculators/BUILD +++ b/mediapipe/tasks/cc/components/calculators/BUILD @@ -178,6 +178,7 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", ], diff --git a/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.cc b/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.cc index 3ea9bcca4..77b14339f 100644 --- a/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.cc @@ -17,6 +17,7 @@ #include #include +#include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "mediapipe/framework/api2/node.h" @@ -76,6 +77,7 @@ class TensorsToEmbeddingsCalculator : public Node { bool l2_normalize_; bool quantize_; std::vector head_names_; + absl::flat_hash_set ignored_head_names_; void FillFloatEmbedding(const Tensor& tensor, Embedding* embedding); void FillQuantizedEmbedding(const Tensor& tensor, Embedding* embedding); @@ -89,6 +91,9 @@ absl::Status TensorsToEmbeddingsCalculator::Open(CalculatorContext* cc) { head_names_.assign(options.head_names().begin(), options.head_names().end()); } + for (const absl::string_view head_name : options.ignored_head_names()) { + ignored_head_names_.insert(std::string(head_name)); + } return absl::OkStatus(); } @@ -102,6 +107,9 @@ absl::Status TensorsToEmbeddingsCalculator::Process(CalculatorContext* cc) { head_names_.size(), tensors.size())); } for (int i = 0; i < tensors.size(); ++i) { + if (!head_names_.empty() && ignored_head_names_.contains(head_names_[i])) { + continue; + } const auto& tensor = tensors[i]; RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32); auto* embedding = result.add_embeddings(); diff --git a/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto b/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto index 745052afa..fd87383b4 100644 --- a/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto +++ b/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto @@ -32,4 +32,8 @@ message TensorsToEmbeddingsCalculatorOptions { // The embedder head names. repeated string head_names = 2; + + // Names of output heads that should not be included in the embeddings. + // Only applies if `head_names` is non-empty. + repeated string ignored_head_names = 3; } diff --git a/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator_test.cc b/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator_test.cc index b79cf4863..94c95cf0e 100644 --- a/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator_test.cc +++ b/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator_test.cc @@ -129,6 +129,60 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) { })pb"))); } +TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNameIgnored) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "TensorsToEmbeddingsCalculator" + input_stream: "TENSORS:tensors" + output_stream: "EMBEDDINGS:embeddings" + options { + [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { + embedder_options { l2_normalize: false quantize: false } + head_names: "foo" + head_names: "bar" + ignored_head_names: "foo" + } + } + )pb")); + + BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); + MP_ASSERT_OK(runner.Run()); + + const EmbeddingResult& result = + runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get(); + EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie( + R"pb( + embeddings { + float_embedding { values: -0.2 values: -0.3 } + head_index: 1 + head_name: "bar" + })pb"))); +} + +TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithBothHeadsIgnored) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "TensorsToEmbeddingsCalculator" + input_stream: "TENSORS:tensors" + output_stream: "EMBEDDINGS:embeddings" + options { + [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { + embedder_options { l2_normalize: false quantize: false } + head_names: "foo" + head_names: "bar" + ignored_head_names: "foo" + ignored_head_names: "bar" + } + } + )pb")); + + BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); + MP_ASSERT_OK(runner.Run()); + + const EmbeddingResult& result = + runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get(); + EXPECT_THAT(result, + EqualsProto(ParseTextProtoOrDie(R"pb()pb"))); +} + TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) { CalculatorRunner runner(ParseTextProtoOrDie(R"pb( calculator: "TensorsToEmbeddingsCalculator" diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index 10bc0726a..dfa18e806 100644 --- a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -144,6 +144,7 @@ cc_library( "//mediapipe/calculators/tensor:regex_preprocessor_calculator", "//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto", "//mediapipe/calculators/tensor:text_to_tensor_calculator", + "//mediapipe/calculators/tensor:universal_sentence_encoder_preprocessor_calculator", "//mediapipe/framework:subgraph", "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc index 809268a63..768508446 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/timestamp.h" +#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" @@ -156,7 +157,8 @@ class PostprocessingTest : public tflite_shims::testing::Test { protected: absl::StatusOr BuildGraph( absl::string_view model_name, const proto::EmbedderOptions& options, - bool connect_timestamps = false) { + bool connect_timestamps = false, + const std::vector& ignored_head_names = {}) { ASSIGN_OR_RETURN(auto model_resources, CreateModelResourcesForModel(model_name)); @@ -164,10 +166,15 @@ class PostprocessingTest : public tflite_shims::testing::Test { auto& postprocessing = graph.AddNode( "mediapipe.tasks.components.processors." "EmbeddingPostprocessingGraph"); - MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessingGraph( - *model_resources, options, + auto* postprocessing_options = &postprocessing - .GetOptions())); + .GetOptions(); + for (const absl::string_view head_name : ignored_head_names) { + postprocessing_options->mutable_tensors_to_embeddings_options() + ->add_ignored_head_names(std::string(head_name)); + } + MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessingGraph( + *model_resources, options, postprocessing_options)); graph[Input>(kTensorsTag)].SetName(kTensorsName) >> postprocessing.In(kTensorsTag); if (connect_timestamps) { @@ -274,6 +281,28 @@ TEST_F(PostprocessingTest, SucceedsWithoutAggregation) { } } +TEST_F(PostprocessingTest, SucceedsWithFilter) { + // Build graph. + proto::EmbedderOptions options; + MP_ASSERT_OK_AND_ASSIGN( + auto poller, + BuildGraph(kMobileNetV3Embedder, options, /*connect_timestamps=*/false, + /*ignored_head_names=*/{"feature"})); + // Build input tensor. + std::vector tensor(kMobileNetV3EmbedderEmbeddingSize, 0); + tensor[0] = 1.0; + + // Send tensor and get results. + AddTensor(tensor, Tensor::ElementType::kFloat32); + MP_ASSERT_OK(Run()); + MP_ASSERT_OK_AND_ASSIGN(auto results, GetResult(poller)); + + // Validate results. + EXPECT_TRUE(results.has_timestamp_ms()); + EXPECT_EQ(results.timestamp_ms(), 0); + EXPECT_EQ(results.embeddings_size(), 0); +} + TEST_F(PostprocessingTest, SucceedsWithAggregation) { // Build graph. proto::EmbedderOptions options;