Internal change
PiperOrigin-RevId: 508173086
parent 6c4ebd2d93
commit 6ea2d579e1
@@ -178,6 +178,7 @@ cc_library(
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
         "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings:str_format",
     ],

@@ -17,6 +17,7 @@
 #include <algorithm>
 #include <vector>
 
+#include "absl/container/flat_hash_set.h"
 #include "absl/status/status.h"
 #include "absl/strings/str_format.h"
 #include "mediapipe/framework/api2/node.h"
@@ -76,6 +77,7 @@ class TensorsToEmbeddingsCalculator : public Node {
   bool l2_normalize_;
   bool quantize_;
   std::vector<std::string> head_names_;
+  absl::flat_hash_set<std::string> ignored_head_names_;
 
   void FillFloatEmbedding(const Tensor& tensor, Embedding* embedding);
   void FillQuantizedEmbedding(const Tensor& tensor, Embedding* embedding);
@@ -89,6 +91,9 @@ absl::Status TensorsToEmbeddingsCalculator::Open(CalculatorContext* cc) {
     head_names_.assign(options.head_names().begin(),
                        options.head_names().end());
   }
+  for (const absl::string_view head_name : options.ignored_head_names()) {
+    ignored_head_names_.insert(std::string(head_name));
+  }
   return absl::OkStatus();
 }
 
@@ -102,6 +107,9 @@ absl::Status TensorsToEmbeddingsCalculator::Process(CalculatorContext* cc) {
                          head_names_.size(), tensors.size()));
   }
   for (int i = 0; i < tensors.size(); ++i) {
+    if (!head_names_.empty() && ignored_head_names_.contains(head_names_[i])) {
+      continue;
+    }
     const auto& tensor = tensors[i];
     RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32);
     auto* embedding = result.add_embeddings();

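As context for the Process() change above, the filtering is simply a name lookup in an absl::flat_hash_set before an embedding is emitted for a head. A minimal standalone sketch of that pattern follows (not MediaPipe code; the head names are illustrative only):

#include <iostream>
#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"

int main() {
  // Mirrors head_names_ / ignored_head_names_ from the calculator above.
  const std::vector<std::string> head_names = {"foo", "bar"};
  const absl::flat_hash_set<std::string> ignored_head_names = {"foo"};
  for (int i = 0; i < static_cast<int>(head_names.size()); ++i) {
    if (ignored_head_names.contains(head_names[i])) {
      continue;  // Ignored head: no embedding is produced for it.
    }
    std::cout << "emit embedding for head " << i << " (" << head_names[i] << ")\n";
  }
  return 0;
}

The calculator additionally guards on head_names_ being non-empty, since the filter is keyed by head name and has no effect when no names were configured.
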
@@ -32,4 +32,8 @@ message TensorsToEmbeddingsCalculatorOptions {
 
   // The embedder head names.
   repeated string head_names = 2;
+
+  // Names of output heads that should not be included in the embeddings.
+  // Only applies if `head_names` is non-empty.
+  repeated string ignored_head_names = 3;
 }

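For illustration, the new field can be populated through the standard generated C++ proto API. This is a hedged sketch, not code from the commit; it assumes the generated header listed in the test diff below and the usual accessors generated for repeated string fields:

#include <iostream>

#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"

int main() {
  mediapipe::TensorsToEmbeddingsCalculatorOptions options;
  options.add_head_names("foo");
  options.add_head_names("bar");
  // Heads named here are skipped entirely; per the proto comment, the filter
  // only applies when head_names is non-empty.
  options.add_ignored_head_names("foo");
  std::cout << "ignored heads: " << options.ignored_head_names_size() << "\n";
  return 0;
}
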
@@ -129,6 +129,60 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) {
                       })pb")));
 }
 
+TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNameIgnored) {
+  CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
+    calculator: "TensorsToEmbeddingsCalculator"
+    input_stream: "TENSORS:tensors"
+    output_stream: "EMBEDDINGS:embeddings"
+    options {
+      [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
+        embedder_options { l2_normalize: false quantize: false }
+        head_names: "foo"
+        head_names: "bar"
+        ignored_head_names: "foo"
+      }
+    }
+  )pb"));
+
+  BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
+  MP_ASSERT_OK(runner.Run());
+
+  const EmbeddingResult& result =
+      runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
+  EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
+                          R"pb(
+                            embeddings {
+                              float_embedding { values: -0.2 values: -0.3 }
+                              head_index: 1
+                              head_name: "bar"
+                            })pb")));
+}
+
+TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithBothHeadsIgnored) {
+  CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
+    calculator: "TensorsToEmbeddingsCalculator"
+    input_stream: "TENSORS:tensors"
+    output_stream: "EMBEDDINGS:embeddings"
+    options {
+      [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
+        embedder_options { l2_normalize: false quantize: false }
+        head_names: "foo"
+        head_names: "bar"
+        ignored_head_names: "foo"
+        ignored_head_names: "bar"
+      }
+    }
+  )pb"));
+
+  BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
+  MP_ASSERT_OK(runner.Run());
+
+  const EmbeddingResult& result =
+      runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
+  EXPECT_THAT(result,
+              EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(R"pb()pb")));
+}
+
 TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) {
   CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
     calculator: "TensorsToEmbeddingsCalculator"

@@ -144,6 +144,7 @@ cc_library(
         "//mediapipe/calculators/tensor:regex_preprocessor_calculator",
         "//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto",
         "//mediapipe/calculators/tensor:text_to_tensor_calculator",
         "//mediapipe/calculators/tensor:universal_sentence_encoder_preprocessor_calculator",
         "//mediapipe/framework:subgraph",
         "//mediapipe/framework/api2:builder",
         "//mediapipe/framework/api2:port",

@@ -33,6 +33,7 @@ limitations under the License.
 #include "mediapipe/framework/port/parse_text_proto.h"
 #include "mediapipe/framework/port/status_matchers.h"
 #include "mediapipe/framework/timestamp.h"
+#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
 #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
 #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
 #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
@@ -156,7 +157,8 @@ class PostprocessingTest : public tflite_shims::testing::Test {
  protected:
   absl::StatusOr<OutputStreamPoller> BuildGraph(
       absl::string_view model_name, const proto::EmbedderOptions& options,
-      bool connect_timestamps = false) {
+      bool connect_timestamps = false,
+      const std::vector<absl::string_view>& ignored_head_names = {}) {
     ASSIGN_OR_RETURN(auto model_resources,
                      CreateModelResourcesForModel(model_name));
 
@@ -164,10 +166,15 @@ class PostprocessingTest : public tflite_shims::testing::Test {
     auto& postprocessing = graph.AddNode(
         "mediapipe.tasks.components.processors."
         "EmbeddingPostprocessingGraph");
-    MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessingGraph(
-        *model_resources, options,
+    auto* postprocessing_options =
         &postprocessing
-            .GetOptions<proto::EmbeddingPostprocessingGraphOptions>()));
+            .GetOptions<proto::EmbeddingPostprocessingGraphOptions>();
+    for (const absl::string_view head_name : ignored_head_names) {
+      postprocessing_options->mutable_tensors_to_embeddings_options()
+          ->add_ignored_head_names(std::string(head_name));
+    }
+    MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessingGraph(
+        *model_resources, options, postprocessing_options));
     graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
         postprocessing.In(kTensorsTag);
     if (connect_timestamps) {
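The helper change above follows a simple pattern: fetch the node's options proto first, copy the caller-supplied string_views into its ignored_head_names field, then pass the pre-populated options to ConfigureEmbeddingPostprocessingGraph. A small self-contained sketch of just the string_view-to-repeated-field step, with a plain struct standing in for the generated options message (hypothetical, for illustration only):

#include <iostream>
#include <string>
#include <vector>

#include "absl/strings/string_view.h"

// Stand-in for the generated options message; not a real MediaPipe type.
struct FakeTensorsToEmbeddingsOptions {
  std::vector<std::string> ignored_head_names;
};

int main() {
  FakeTensorsToEmbeddingsOptions options;
  const std::vector<absl::string_view> ignored_head_names = {"feature"};
  for (const absl::string_view head_name : ignored_head_names) {
    // Equivalent to add_ignored_head_names(std::string(head_name)) above.
    options.ignored_head_names.push_back(std::string(head_name));
  }
  std::cout << "ignored heads: " << options.ignored_head_names.size() << "\n";
  return 0;
}
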
@@ -274,6 +281,28 @@ TEST_F(PostprocessingTest, SucceedsWithoutAggregation) {
   }
 }
 
+TEST_F(PostprocessingTest, SucceedsWithFilter) {
+  // Build graph.
+  proto::EmbedderOptions options;
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto poller,
+      BuildGraph(kMobileNetV3Embedder, options, /*connect_timestamps=*/false,
+                 /*ignored_head_names=*/{"feature"}));
+  // Build input tensor.
+  std::vector<float> tensor(kMobileNetV3EmbedderEmbeddingSize, 0);
+  tensor[0] = 1.0;
+
+  // Send tensor and get results.
+  AddTensor(tensor, Tensor::ElementType::kFloat32);
+  MP_ASSERT_OK(Run());
+  MP_ASSERT_OK_AND_ASSIGN(auto results, GetResult<EmbeddingResult>(poller));
+
+  // Validate results.
+  EXPECT_TRUE(results.has_timestamp_ms());
+  EXPECT_EQ(results.timestamp_ms(), 0);
+  EXPECT_EQ(results.embeddings_size(), 0);
+}
+
 TEST_F(PostprocessingTest, SucceedsWithAggregation) {
   // Build graph.
   proto::EmbedderOptions options;