Internal change
PiperOrigin-RevId: 508173086
parent 6c4ebd2d93, commit 6ea2d579e1
@@ -178,6 +178,7 @@ cc_library(
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
         "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings:str_format",
     ],
@@ -17,6 +17,7 @@
 #include <algorithm>
 #include <vector>
 
+#include "absl/container/flat_hash_set.h"
 #include "absl/status/status.h"
 #include "absl/strings/str_format.h"
 #include "mediapipe/framework/api2/node.h"
@@ -76,6 +77,7 @@ class TensorsToEmbeddingsCalculator : public Node {
   bool l2_normalize_;
   bool quantize_;
   std::vector<std::string> head_names_;
+  absl::flat_hash_set<std::string> ignored_head_names_;
 
   void FillFloatEmbedding(const Tensor& tensor, Embedding* embedding);
   void FillQuantizedEmbedding(const Tensor& tensor, Embedding* embedding);
@@ -89,6 +91,9 @@ absl::Status TensorsToEmbeddingsCalculator::Open(CalculatorContext* cc) {
     head_names_.assign(options.head_names().begin(),
                        options.head_names().end());
   }
+  for (const absl::string_view head_name : options.ignored_head_names()) {
+    ignored_head_names_.insert(std::string(head_name));
+  }
   return absl::OkStatus();
 }
 
@@ -102,6 +107,9 @@ absl::Status TensorsToEmbeddingsCalculator::Process(CalculatorContext* cc) {
                             head_names_.size(), tensors.size()));
   }
   for (int i = 0; i < tensors.size(); ++i) {
+    if (!head_names_.empty() && ignored_head_names_.contains(head_names_[i])) {
+      continue;
+    }
     const auto& tensor = tensors[i];
     RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32);
     auto* embedding = result.add_embeddings();
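Taken together, the Open() and Process() changes above mean a head is dropped from the output only when head_names is populated and that head's name appears in ignored_head_names, and the surviving heads keep their original indices. A minimal standalone C++ sketch of that behavior; the head names and values here are illustrative and not taken from the calculator itself:

#include <iostream>
#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"

// Illustrative sketch only: mirrors the skip logic added to Process() above,
// with plain vectors standing in for the Tensor inputs and the
// EmbeddingResult proto.
int main() {
  const std::vector<std::string> head_names = {"foo", "bar"};
  const absl::flat_hash_set<std::string> ignored_head_names = {"foo"};
  const std::vector<std::vector<float>> tensors = {{0.1f, 0.2f},
                                                   {-0.2f, -0.3f}};

  for (int i = 0; i < tensors.size(); ++i) {
    // Filtering only applies when head names were provided in the options.
    if (!head_names.empty() && ignored_head_names.contains(head_names[i])) {
      continue;
    }
    // The original index is preserved, so the remaining head still reports
    // index 1, as the SucceedsWithHeadNameIgnored test below expects.
    std::cout << "head_index=" << i << " head_name=" << head_names[i] << "\n";
  }
  return 0;
}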
@@ -32,4 +32,8 @@ message TensorsToEmbeddingsCalculatorOptions {
 
   // The embedder head names.
   repeated string head_names = 2;
+
+  // Names of output heads that should not be included in the embeddings.
+  // Only applies if `head_names` is non-empty.
+  repeated string ignored_head_names = 3;
 }
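Besides the graph text-proto form used in the tests below, the new field can be set programmatically on the generated options message. A sketch assuming the generated header that this change includes in the postprocessing test; MakeOptions() is a hypothetical helper, not part of the change:

#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"

mediapipe::TensorsToEmbeddingsCalculatorOptions MakeOptions() {
  mediapipe::TensorsToEmbeddingsCalculatorOptions options;
  // Head names for the model's output tensors, in tensor order.
  options.add_head_names("foo");
  options.add_head_names("bar");
  // Drop the "foo" head from the emitted EmbeddingResult. Per the proto
  // comment above, this only takes effect because head_names is non-empty.
  options.add_ignored_head_names("foo");
  return options;
}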
@@ -129,6 +129,60 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) {
                           })pb")));
 }
 
+TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNameIgnored) {
+  CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
+    calculator: "TensorsToEmbeddingsCalculator"
+    input_stream: "TENSORS:tensors"
+    output_stream: "EMBEDDINGS:embeddings"
+    options {
+      [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
+        embedder_options { l2_normalize: false quantize: false }
+        head_names: "foo"
+        head_names: "bar"
+        ignored_head_names: "foo"
+      }
+    }
+  )pb"));
+
+  BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
+  MP_ASSERT_OK(runner.Run());
+
+  const EmbeddingResult& result =
+      runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
+  EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
+                          R"pb(
+                            embeddings {
+                              float_embedding { values: -0.2 values: -0.3 }
+                              head_index: 1
+                              head_name: "bar"
+                            })pb")));
+}
+
+TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithBothHeadsIgnored) {
+  CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
+    calculator: "TensorsToEmbeddingsCalculator"
+    input_stream: "TENSORS:tensors"
+    output_stream: "EMBEDDINGS:embeddings"
+    options {
+      [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
+        embedder_options { l2_normalize: false quantize: false }
+        head_names: "foo"
+        head_names: "bar"
+        ignored_head_names: "foo"
+        ignored_head_names: "bar"
+      }
+    }
+  )pb"));
+
+  BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
+  MP_ASSERT_OK(runner.Run());
+
+  const EmbeddingResult& result =
+      runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
+  EXPECT_THAT(result,
+              EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(R"pb()pb")));
+}
+
 TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) {
   CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
     calculator: "TensorsToEmbeddingsCalculator"
@@ -144,6 +144,7 @@ cc_library(
         "//mediapipe/calculators/tensor:regex_preprocessor_calculator",
         "//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto",
         "//mediapipe/calculators/tensor:text_to_tensor_calculator",
+        "//mediapipe/calculators/tensor:universal_sentence_encoder_preprocessor_calculator",
         "//mediapipe/framework:subgraph",
         "//mediapipe/framework/api2:builder",
         "//mediapipe/framework/api2:port",
@@ -33,6 +33,7 @@ limitations under the License.
 #include "mediapipe/framework/port/parse_text_proto.h"
 #include "mediapipe/framework/port/status_matchers.h"
 #include "mediapipe/framework/timestamp.h"
+#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
 #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
 #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
 #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
@@ -156,7 +157,8 @@ class PostprocessingTest : public tflite_shims::testing::Test {
  protected:
   absl::StatusOr<OutputStreamPoller> BuildGraph(
       absl::string_view model_name, const proto::EmbedderOptions& options,
-      bool connect_timestamps = false) {
+      bool connect_timestamps = false,
+      const std::vector<absl::string_view>& ignored_head_names = {}) {
     ASSIGN_OR_RETURN(auto model_resources,
                      CreateModelResourcesForModel(model_name));
 
@@ -164,10 +166,15 @@ class PostprocessingTest : public tflite_shims::testing::Test {
     auto& postprocessing = graph.AddNode(
         "mediapipe.tasks.components.processors."
         "EmbeddingPostprocessingGraph");
-    MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessingGraph(
-        *model_resources, options,
-        &postprocessing
-             .GetOptions<proto::EmbeddingPostprocessingGraphOptions>()));
+    auto* postprocessing_options =
+        &postprocessing
+             .GetOptions<proto::EmbeddingPostprocessingGraphOptions>();
+    for (const absl::string_view head_name : ignored_head_names) {
+      postprocessing_options->mutable_tensors_to_embeddings_options()
+          ->add_ignored_head_names(std::string(head_name));
+    }
+    MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessingGraph(
+        *model_resources, options, postprocessing_options));
     graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
         postprocessing.In(kTensorsTag);
     if (connect_timestamps) {
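Any client that builds an EmbeddingPostprocessingGraph can apply the same wiring outside of this test helper. A sketch under the assumption that the graph options proto lives in the mediapipe.tasks.components.processors.proto package (matching the proto:: alias used above); AddIgnoredHeads() is a hypothetical helper, not part of the change:

#include <string>
#include <vector>

#include "absl/strings/string_view.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"

// Sketch only: fills ignored_head_names on the nested
// TensorsToEmbeddingsCalculatorOptions, mirroring the loop added to
// BuildGraph() above. `options` would be the
// EmbeddingPostprocessingGraphOptions taken from the graph node before
// ConfigureEmbeddingPostprocessingGraph() is called.
void AddIgnoredHeads(
    const std::vector<absl::string_view>& ignored_head_names,
    mediapipe::tasks::components::processors::proto::
        EmbeddingPostprocessingGraphOptions* options) {
  for (const absl::string_view head_name : ignored_head_names) {
    options->mutable_tensors_to_embeddings_options()
        ->add_ignored_head_names(std::string(head_name));
  }
}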
@@ -274,6 +281,28 @@ TEST_F(PostprocessingTest, SucceedsWithoutAggregation) {
   }
 }
 
+TEST_F(PostprocessingTest, SucceedsWithFilter) {
+  // Build graph.
+  proto::EmbedderOptions options;
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto poller,
+      BuildGraph(kMobileNetV3Embedder, options, /*connect_timestamps=*/false,
+                 /*ignored_head_names=*/{"feature"}));
+  // Build input tensor.
+  std::vector<float> tensor(kMobileNetV3EmbedderEmbeddingSize, 0);
+  tensor[0] = 1.0;
+
+  // Send tensor and get results.
+  AddTensor(tensor, Tensor::ElementType::kFloat32);
+  MP_ASSERT_OK(Run());
+  MP_ASSERT_OK_AND_ASSIGN(auto results, GetResult<EmbeddingResult>(poller));
+
+  // Validate results.
+  EXPECT_TRUE(results.has_timestamp_ms());
+  EXPECT_EQ(results.timestamp_ms(), 0);
+  EXPECT_EQ(results.embeddings_size(), 0);
+}
+
 TEST_F(PostprocessingTest, SucceedsWithAggregation) {
   // Build graph.
   proto::EmbedderOptions options;