Update EmbeddingResult format and dependent tasks.

PiperOrigin-RevId: 486186491
MediaPipe Team authored 2022-11-04 11:18:27 -07:00, committed by Copybara-Service
parent 66e591d4bc
commit 5e1a2fcdbb
31 changed files with 499 additions and 387 deletions


@ -61,41 +61,6 @@ cc_library(
# TODO: Enable this test # TODO: Enable this test
cc_library(
name = "embedder_options",
srcs = ["embedder_options.cc"],
hdrs = ["embedder_options.h"],
deps = ["//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto"],
)
cc_library(
name = "embedding_postprocessing_graph",
srcs = ["embedding_postprocessing_graph.cc"],
hdrs = ["embedding_postprocessing_graph.h"],
deps = [
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/tool:options_map",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
)
# TODO: Investigate rewriting the build rule to only link # TODO: Investigate rewriting the build rule to only link
# the Bert Preprocessor if it's needed. # the Bert Preprocessor if it's needed.
cc_library( cc_library(


@ -163,7 +163,7 @@ mediapipe_proto_library(
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_proto", "//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto",
], ],
) )
@ -178,7 +178,7 @@ cc_library(
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
], ],


@ -26,14 +26,14 @@
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace api2 {
namespace { namespace {
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry; using ::mediapipe::tasks::components::containers::proto::Embedding;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
// Computes the inverse L2 norm of the provided array of values. Returns 1.0 in // Computes the inverse L2 norm of the provided array of values. Returns 1.0 in
@ -66,7 +66,7 @@ float GetInverseL2Norm(const float* values, int size) {
class TensorsToEmbeddingsCalculator : public Node { class TensorsToEmbeddingsCalculator : public Node {
public: public:
static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"}; static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"};
static constexpr Output<EmbeddingResult> kEmbeddingsOut{"EMBEDDING_RESULT"}; static constexpr Output<EmbeddingResult> kEmbeddingsOut{"EMBEDDINGS"};
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kEmbeddingsOut); MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kEmbeddingsOut);
absl::Status Open(CalculatorContext* cc) override; absl::Status Open(CalculatorContext* cc) override;
@ -77,8 +77,8 @@ class TensorsToEmbeddingsCalculator : public Node {
bool quantize_; bool quantize_;
std::vector<std::string> head_names_; std::vector<std::string> head_names_;
void FillFloatEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry); void FillFloatEmbedding(const Tensor& tensor, Embedding* embedding);
void FillQuantizedEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry); void FillQuantizedEmbedding(const Tensor& tensor, Embedding* embedding);
}; };
absl::Status TensorsToEmbeddingsCalculator::Open(CalculatorContext* cc) { absl::Status TensorsToEmbeddingsCalculator::Open(CalculatorContext* cc) {
@ -104,42 +104,42 @@ absl::Status TensorsToEmbeddingsCalculator::Process(CalculatorContext* cc) {
for (int i = 0; i < tensors.size(); ++i) { for (int i = 0; i < tensors.size(); ++i) {
const auto& tensor = tensors[i]; const auto& tensor = tensors[i];
RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32); RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32);
auto* embeddings = result.add_embeddings(); auto* embedding = result.add_embeddings();
embeddings->set_head_index(i); embedding->set_head_index(i);
if (!head_names_.empty()) { if (!head_names_.empty()) {
embeddings->set_head_name(head_names_[i]); embedding->set_head_name(head_names_[i]);
} }
if (quantize_) { if (quantize_) {
FillQuantizedEmbeddingEntry(tensor, embeddings->add_entries()); FillQuantizedEmbedding(tensor, embedding);
} else { } else {
FillFloatEmbeddingEntry(tensor, embeddings->add_entries()); FillFloatEmbedding(tensor, embedding);
} }
} }
kEmbeddingsOut(cc).Send(result); kEmbeddingsOut(cc).Send(result);
return absl::OkStatus(); return absl::OkStatus();
} }
void TensorsToEmbeddingsCalculator::FillFloatEmbeddingEntry( void TensorsToEmbeddingsCalculator::FillFloatEmbedding(const Tensor& tensor,
const Tensor& tensor, EmbeddingEntry* entry) { Embedding* embedding) {
int size = tensor.shape().num_elements(); int size = tensor.shape().num_elements();
auto tensor_view = tensor.GetCpuReadView(); auto tensor_view = tensor.GetCpuReadView();
const float* tensor_buffer = tensor_view.buffer<float>(); const float* tensor_buffer = tensor_view.buffer<float>();
float inv_l2_norm = float inv_l2_norm =
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f; l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
auto* float_embedding = entry->mutable_float_embedding(); auto* float_embedding = embedding->mutable_float_embedding();
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
float_embedding->add_values(tensor_buffer[i] * inv_l2_norm); float_embedding->add_values(tensor_buffer[i] * inv_l2_norm);
} }
} }
void TensorsToEmbeddingsCalculator::FillQuantizedEmbeddingEntry( void TensorsToEmbeddingsCalculator::FillQuantizedEmbedding(
const Tensor& tensor, EmbeddingEntry* entry) { const Tensor& tensor, Embedding* embedding) {
int size = tensor.shape().num_elements(); int size = tensor.shape().num_elements();
auto tensor_view = tensor.GetCpuReadView(); auto tensor_view = tensor.GetCpuReadView();
const float* tensor_buffer = tensor_view.buffer<float>(); const float* tensor_buffer = tensor_view.buffer<float>();
float inv_l2_norm = float inv_l2_norm =
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f; l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
auto* values = entry->mutable_quantized_embedding()->mutable_values(); auto* values = embedding->mutable_quantized_embedding()->mutable_values();
values->resize(size); values->resize(size);
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
// Normalize. // Normalize.
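
Note: the hunk above is cut off by the viewer right after the "// Normalize." comment. As a reading aid only, here is a minimal, hypothetical C++ sketch of the arithmetic implied by the calculator and its tests below: optional inverse-L2-norm scaling followed by scalar quantization. The function name and the scale factor of 128 are assumptions inferred from the expected test values (0.1 -> 13, 0.2 -> 26), not code from this commit.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical stand-alone mirror of the calculator's math.
std::vector<int8_t> QuantizeSketch(std::vector<float> values, bool l2_normalize) {
  if (l2_normalize) {
    float norm = 0.f;
    for (float v : values) norm += v * v;
    norm = std::sqrt(norm);
    if (norm > 0.f) {
      // Same effect as multiplying by the inverse L2 norm.
      for (float& v : values) v /= norm;
    }
  }
  std::vector<int8_t> quantized;
  quantized.reserve(values.size());
  for (float v : values) {
    // Assumed scale of 128, consistent with 0.1 -> 13 (0x0d) and 0.2 -> 26 (0x1a).
    quantized.push_back(static_cast<int8_t>(
        std::clamp<long>(std::lround(v * 128.f), -128, 127)));
  }
  return quantized;
}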


@ -18,7 +18,7 @@ syntax = "proto2";
package mediapipe; package mediapipe;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/embedder_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
message TensorsToEmbeddingsCalculatorOptions { message TensorsToEmbeddingsCalculatorOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
@ -27,8 +27,8 @@ message TensorsToEmbeddingsCalculatorOptions {
// The embedder options defining whether to L2-normalize or scalar-quantize // The embedder options defining whether to L2-normalize or scalar-quantize
// the outputs. // the outputs.
optional mediapipe.tasks.components.proto.EmbedderOptions embedder_options = optional mediapipe.tasks.components.processors.proto.EmbedderOptions
1; embedder_options = 1;
// The embedder head names. // The embedder head names.
repeated string head_names = 2; repeated string head_names = 2;


@ -55,7 +55,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, FailsWithInvalidHeadNamesNumber) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator" calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings" output_stream: "EMBEDDINGS:embeddings"
options { options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { head_names: "foo" } [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { head_names: "foo" }
} }
@ -73,7 +73,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithoutHeadNames) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator" calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings" output_stream: "EMBEDDINGS:embeddings"
options { options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: false quantize: false } embedder_options { l2_normalize: false quantize: false }
@ -84,28 +84,24 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithoutHeadNames) {
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
-  const EmbeddingResult& result = runner.Outputs()
-                                      .Get("EMBEDDING_RESULT", 0)
-                                      .packets[0]
-                                      .Get<EmbeddingResult>();
-  EXPECT_THAT(
-      result,
-      EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
-          R"pb(embeddings {
-                 entries { float_embedding { values: 0.1 values: 0.2 } }
-                 head_index: 0
-               }
-               embeddings {
-                 entries { float_embedding { values: -0.2 values: -0.3 } }
-                 head_index: 1
-               })pb")));
+  const EmbeddingResult& result =
+      runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
+  EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
+                          R"pb(embeddings {
+                                 float_embedding { values: 0.1 values: 0.2 }
+                                 head_index: 0
+                               }
+                               embeddings {
+                                 float_embedding { values: -0.2 values: -0.3 }
+                                 head_index: 1
+                               })pb")));
} }
TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) { TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator" calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings" output_stream: "EMBEDDINGS:embeddings"
options { options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: false quantize: false } embedder_options { l2_normalize: false quantize: false }
@ -118,30 +114,26 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) {
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
-  const EmbeddingResult& result = runner.Outputs()
-                                      .Get("EMBEDDING_RESULT", 0)
-                                      .packets[0]
-                                      .Get<EmbeddingResult>();
-  EXPECT_THAT(
-      result,
-      EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
-          R"pb(embeddings {
-                 entries { float_embedding { values: 0.1 values: 0.2 } }
-                 head_index: 0
-                 head_name: "foo"
-               }
-               embeddings {
-                 entries { float_embedding { values: -0.2 values: -0.3 } }
-                 head_index: 1
-                 head_name: "bar"
-               })pb")));
+  const EmbeddingResult& result =
+      runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
+  EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
+                          R"pb(embeddings {
+                                 float_embedding { values: 0.1 values: 0.2 }
+                                 head_index: 0
+                                 head_name: "foo"
+                               }
+                               embeddings {
+                                 float_embedding { values: -0.2 values: -0.3 }
+                                 head_index: 1
+                                 head_name: "bar"
+                               })pb")));
} }
TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) { TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator" calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings" output_stream: "EMBEDDINGS:embeddings"
options { options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: true quantize: false } embedder_options { l2_normalize: true quantize: false }
@ -152,23 +144,17 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) {
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
-  const EmbeddingResult& result = runner.Outputs()
-                                      .Get("EMBEDDING_RESULT", 0)
-                                      .packets[0]
-                                      .Get<EmbeddingResult>();
+  const EmbeddingResult& result =
+      runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
   EXPECT_THAT(
       result,
       EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
           R"pb(embeddings {
-                 entries {
-                   float_embedding { values: 0.44721356 values: 0.8944271 }
-                 }
+                 float_embedding { values: 0.44721356 values: 0.8944271 }
                  head_index: 0
                }
                embeddings {
-                 entries {
-                   float_embedding { values: -0.5547002 values: -0.8320503 }
-                 }
+                 float_embedding { values: -0.5547002 values: -0.8320503 }
                  head_index: 1
                })pb")));
} }
@ -177,7 +163,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithQuantization) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator" calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings" output_stream: "EMBEDDINGS:embeddings"
options { options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: false quantize: true } embedder_options { l2_normalize: false quantize: true }
@ -188,22 +174,16 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithQuantization) {
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
-  const EmbeddingResult& result = runner.Outputs()
-                                      .Get("EMBEDDING_RESULT", 0)
-                                      .packets[0]
-                                      .Get<EmbeddingResult>();
+  const EmbeddingResult& result =
+      runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
   EXPECT_THAT(result,
               EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
                   R"pb(embeddings {
-                         entries {
-                           quantized_embedding { values: "\x0d\x1a" } # 13,26
-                         }
+                         quantized_embedding { values: "\x0d\x1a" } # 13,26
                          head_index: 0
                        }
                        embeddings {
-                         entries {
-                           quantized_embedding { values: "\xe6\xda" } # -26,-38
-                         }
+                         quantized_embedding { values: "\xe6\xda" } # -26,-38
                          head_index: 1
                        })pb")));
} }
@ -213,7 +193,7 @@ TEST(TensorsToEmbeddingsCalculatorTest,
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator" calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings" output_stream: "EMBEDDINGS:embeddings"
options { options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: true quantize: true } embedder_options { l2_normalize: true quantize: true }
@ -224,25 +204,18 @@ TEST(TensorsToEmbeddingsCalculatorTest,
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
-  const EmbeddingResult& result = runner.Outputs()
-                                      .Get("EMBEDDING_RESULT", 0)
-                                      .packets[0]
-                                      .Get<EmbeddingResult>();
-  EXPECT_THAT(
-      result,
-      EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
-          R"pb(embeddings {
-                 entries {
-                   quantized_embedding { values: "\x39\x72" } # 57,114
-                 }
-                 head_index: 0
-               }
-               embeddings {
-                 entries {
-                   quantized_embedding { values: "\xb9\x95" } # -71,-107
-                 }
-                 head_index: 1
-               })pb")));
+  const EmbeddingResult& result =
+      runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
+  EXPECT_THAT(result,
+              EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
+                  R"pb(embeddings {
+                         quantized_embedding { values: "\x39\x72" } # 57,114
+                         head_index: 0
+                       }
+                       embeddings {
+                         quantized_embedding { values: "\xb9\x95" } # -71,-107
+                         head_index: 1
+                       })pb")));
} }
} // namespace } // namespace


@ -49,3 +49,12 @@ cc_library(
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
], ],
) )
cc_library(
name = "embedding_result",
srcs = ["embedding_result.cc"],
hdrs = ["embedding_result.h"],
deps = [
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
],
)


@ -0,0 +1,57 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include <iterator>
#include <optional>
#include <string>
#include <vector>
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
namespace mediapipe::tasks::components::containers {
Embedding ConvertToEmbedding(const proto::Embedding& proto) {
Embedding embedding;
if (proto.has_float_embedding()) {
embedding.float_embedding = {
std::make_move_iterator(proto.float_embedding().values().begin()),
std::make_move_iterator(proto.float_embedding().values().end())};
} else {
embedding.quantized_embedding = {
std::make_move_iterator(proto.quantized_embedding().values().begin()),
std::make_move_iterator(proto.quantized_embedding().values().end())};
}
embedding.head_index = proto.head_index();
if (proto.has_head_name()) {
embedding.head_name = proto.head_name();
}
return embedding;
}
EmbeddingResult ConvertToEmbeddingResult(const proto::EmbeddingResult& proto) {
EmbeddingResult embedding_result;
embedding_result.embeddings.reserve(proto.embeddings_size());
for (const auto& embedding : proto.embeddings()) {
embedding_result.embeddings.push_back(ConvertToEmbedding(embedding));
}
if (proto.has_timestamp_ms()) {
embedding_result.timestamp_ms = proto.timestamp_ms();
}
return embedding_result;
}
} // namespace mediapipe::tasks::components::containers


@ -0,0 +1,72 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
#include <optional>
#include <string>
#include <vector>
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
namespace mediapipe::tasks::components::containers {
// Embedding result for a given embedder head.
//
// One and only one of the two 'float_embedding' and 'quantized_embedding' will
// contain data, based on whether or not the embedder was configured to perform
// scalar quantization.
struct Embedding {
// Floating-point embedding. Empty if the embedder was configured to perform
// scalar-quantization.
std::vector<float> float_embedding;
// Scalar-quantized embedding. Empty if the embedder was not configured to
// perform scalar quantization.
std::string quantized_embedding;
// The index of the embedder head (i.e. output tensor) this embedding comes
// from. This is useful for multi-head models.
int head_index;
// The optional name of the embedder head, as provided in the TFLite Model
// Metadata [1] if present. This is useful for multi-head models.
//
// [1]: https://www.tensorflow.org/lite/convert/metadata
std::optional<std::string> head_name = std::nullopt;
};
// Defines embedding results of a model.
struct EmbeddingResult {
// The embedding results for each head of the model.
std::vector<Embedding> embeddings;
// The optional timestamp (in milliseconds) of the start of the chunk of data
// corresponding to these results.
//
// This is only used for embedding extraction on time series (e.g. audio
// embedding). In these use cases, the amount of data to process might
// exceed the maximum size that the model can process: to solve this, the
// input data is split into multiple chunks starting at different timestamps.
std::optional<int64_t> timestamp_ms = std::nullopt;
};
// Utility function to convert from Embedding proto to Embedding struct.
Embedding ConvertToEmbedding(const proto::Embedding& proto);
// Utility function to convert from EmbeddingResult proto to EmbeddingResult
// struct.
EmbeddingResult ConvertToEmbeddingResult(const proto::EmbeddingResult& proto);
} // namespace mediapipe::tasks::components::containers
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
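
A hedged usage sketch of the new container layer (the function and its printed output are illustrative, not part of this commit):

#include <iostream>

#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"

namespace containers = ::mediapipe::tasks::components::containers;

// Converts a proto result (e.g. produced by a task graph) into the plain C++
// structs declared above and prints a one-line summary per embedder head.
void PrintEmbeddingSummary(const containers::proto::EmbeddingResult& proto_result) {
  containers::EmbeddingResult result =
      containers::ConvertToEmbeddingResult(proto_result);
  for (const containers::Embedding& embedding : result.embeddings) {
    std::cout << "head " << embedding.head_index << " ("
              << embedding.head_name.value_or("<unnamed>") << "): "
              << embedding.float_embedding.size() << " float values, "
              << embedding.quantized_embedding.size() << " quantized bytes\n";
  }
}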


@ -30,30 +30,31 @@ message QuantizedEmbedding {
   optional bytes values = 1;
 }
 
-// Floating-point or scalar-quantized embedding with an optional timestamp.
-message EmbeddingEntry {
-  // The actual embedding, either floating-point or scalar-quantized.
+// Embedding result for a given embedder head.
+message Embedding {
+  // The actual embedding, either floating-point or quantized.
   oneof embedding {
     FloatEmbedding float_embedding = 1;
     QuantizedEmbedding quantized_embedding = 2;
   }
-  // The optional timestamp (in milliseconds) associated to the embedding entry.
-  // This is useful for time series use cases, e.g. audio embedding.
-  optional int64 timestamp_ms = 3;
-}
-
-// Embeddings for a given embedder head.
-message Embeddings {
-  repeated EmbeddingEntry entries = 1;
 
   // The index of the embedder head that produced this embedding. This is useful
   // for multi-head models.
-  optional int32 head_index = 2;
+  optional int32 head_index = 3;
 
   // The name of the embedder head, which is the corresponding tensor metadata
   // name (if any). This is useful for multi-head models.
-  optional string head_name = 3;
+  optional string head_name = 4;
 }
 
-// Contains one set of results per embedder head.
+// Embedding results for a given embedder model.
 message EmbeddingResult {
-  repeated Embeddings embeddings = 1;
+  // The embedding results for each model head, i.e. one for each output tensor.
+  repeated Embedding embeddings = 1;
+
+  // The optional timestamp (in milliseconds) of the start of the chunk of data
+  // corresponding to these results.
+  //
+  // This is only used for embedding extraction on time series (e.g. audio
+  // embedding). In these use cases, the amount of data to process might
+  // exceed the maximum size that the model can process: to solve this, the
+  // input data is split into multiple chunks starting at different timestamps.
+  optional int64 timestamp_ms = 2;
 }
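
Concretely, a result under the new schema reads like the following text proto (values are illustrative); under the old schema each embedding would instead have been nested inside an "entries" field of an Embeddings message:

embeddings {
  float_embedding { values: 0.1 values: 0.2 }
  head_index: 0
  head_name: "feature"
}
embeddings {
  float_embedding { values: -0.2 values: -0.3 }
  head_index: 1
}
timestamp_ms: 0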


@ -62,3 +62,38 @@ cc_library(
], ],
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "embedder_options",
srcs = ["embedder_options.cc"],
hdrs = ["embedder_options.h"],
deps = ["//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto"],
)
cc_library(
name = "embedding_postprocessing_graph",
srcs = ["embedding_postprocessing_graph.cc"],
hdrs = ["embedding_postprocessing_graph.h"],
deps = [
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/tool:options_map",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
)


@ -13,22 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mediapipe/tasks/cc/components/embedder_options.h" #include "mediapipe/tasks/cc/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
tasks::components::proto::EmbedderOptions ConvertEmbedderOptionsToProto( proto::EmbedderOptions ConvertEmbedderOptionsToProto(
EmbedderOptions* embedder_options) { EmbedderOptions* embedder_options) {
tasks::components::proto::EmbedderOptions options_proto; proto::EmbedderOptions options_proto;
options_proto.set_l2_normalize(embedder_options->l2_normalize); options_proto.set_l2_normalize(embedder_options->l2_normalize);
options_proto.set_quantize(embedder_options->quantize); options_proto.set_quantize(embedder_options->quantize);
return options_proto; return options_proto;
} }
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe
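
A short caller-side sketch of the relocated API (the helper function is hypothetical; only the EmbedderOptions struct and ConvertEmbedderOptionsToProto come from this commit):

#include "mediapipe/tasks/cc/components/processors/embedder_options.h"

namespace processors = ::mediapipe::tasks::components::processors;

// Builds the options proto, which now lives under components/processors/proto.
processors::proto::EmbedderOptions MakeEmbedderOptionsProto() {
  processors::EmbedderOptions options;
  options.l2_normalize = true;
  options.quantize = false;
  return processors::ConvertEmbedderOptionsToProto(&options);
}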


@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_ #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
// Embedder options for MediaPipe C++ embedding extraction tasks. // Embedder options for MediaPipe C++ embedding extraction tasks.
struct EmbedderOptions { struct EmbedderOptions {
@ -37,11 +38,12 @@ struct EmbedderOptions {
bool quantize; bool quantize;
}; };
tasks::components::proto::EmbedderOptions ConvertEmbedderOptionsToProto( proto::EmbedderOptions ConvertEmbedderOptionsToProto(
EmbedderOptions* embedder_options); EmbedderOptions* embedder_options);
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_ #endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include <string> #include <string>
#include <vector> #include <vector>
@ -29,8 +29,8 @@ limitations under the License.
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" #include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
@ -39,6 +39,7 @@ limitations under the License.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
namespace { namespace {
@ -49,13 +50,12 @@ using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::components::proto::EmbedderOptions;
using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::ModelResources;
using TensorsSource = using TensorsSource =
::mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>; ::mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT"; constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
// Identifies whether or not the model has quantized outputs, and performs // Identifies whether or not the model has quantized outputs, and performs
// sanity checks. // sanity checks.
@ -144,7 +144,7 @@ absl::StatusOr<std::vector<std::string>> GetHeadNames(
absl::Status ConfigureEmbeddingPostprocessing( absl::Status ConfigureEmbeddingPostprocessing(
const ModelResources& model_resources, const ModelResources& model_resources,
const EmbedderOptions& embedder_options, const proto::EmbedderOptions& embedder_options,
proto::EmbeddingPostprocessingGraphOptions* options) { proto::EmbeddingPostprocessingGraphOptions* options) {
ASSIGN_OR_RETURN(bool has_quantized_outputs, ASSIGN_OR_RETURN(bool has_quantized_outputs,
HasQuantizedOutputs(model_resources)); HasQuantizedOutputs(model_resources));
@ -188,7 +188,7 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
BuildEmbeddingPostprocessing( BuildEmbeddingPostprocessing(
sc->Options<proto::EmbeddingPostprocessingGraphOptions>(), sc->Options<proto::EmbeddingPostprocessingGraphOptions>(),
graph[Input<std::vector<Tensor>>(kTensorsTag)], graph)); graph[Input<std::vector<Tensor>>(kTensorsTag)], graph));
embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingResultTag)]; embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
return graph.GetConfig(); return graph.GetConfig();
} }
@ -220,13 +220,13 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
.GetOptions<mediapipe::TensorsToEmbeddingsCalculatorOptions>() .GetOptions<mediapipe::TensorsToEmbeddingsCalculatorOptions>()
.CopyFrom(options.tensors_to_embeddings_options()); .CopyFrom(options.tensors_to_embeddings_options());
dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag); dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag);
return tensors_to_embeddings_node[Output<EmbeddingResult>( return tensors_to_embeddings_node[Output<EmbeddingResult>(kEmbeddingsTag)];
kEmbeddingResultTag)];
} }
}; };
REGISTER_MEDIAPIPE_GRAPH( REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::components::EmbeddingPostprocessingGraph); ::mediapipe::tasks::components::processors::EmbeddingPostprocessingGraph);
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe


@ -13,17 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_ #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_
#include "absl/status/status.h" #include "absl/status/status.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
// Configures an EmbeddingPostprocessingGraph using the provided model resources // Configures an EmbeddingPostprocessingGraph using the provided model resources
// and EmbedderOptions. // and EmbedderOptions.
@ -44,18 +45,19 @@ namespace components {
// The output tensors of an InferenceCalculator, to convert into // The output tensors of an InferenceCalculator, to convert into
// EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8. // EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
// Outputs: // Outputs:
// EMBEDDING_RESULT - EmbeddingResult // EMBEDDINGS - EmbeddingResult
// The output EmbeddingResult. // The output EmbeddingResult.
// //
// TODO: add support for additional optional "TIMESTAMPS" input for // TODO: add support for additional optional "TIMESTAMPS" input for
// embeddings aggregation. // embeddings aggregation.
absl::Status ConfigureEmbeddingPostprocessing( absl::Status ConfigureEmbeddingPostprocessing(
const tasks::core::ModelResources& model_resources, const tasks::core::ModelResources& model_resources,
const tasks::components::proto::EmbedderOptions& embedder_options, const proto::EmbedderOptions& embedder_options,
proto::EmbeddingPostprocessingGraphOptions* options); proto::EmbeddingPostprocessingGraphOptions* options);
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_ #endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_
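
For reference, a hedged sketch of calling the relocated ConfigureEmbeddingPostprocessing from client code (the wrapper function is hypothetical; it mirrors the updated test further below):

#include "absl/status/status.h"
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_resources.h"

namespace processors = ::mediapipe::tasks::components::processors;

// Fills *options_out for an already-loaded model, requesting L2-normalized outputs.
absl::Status ConfigureForModel(
    const mediapipe::tasks::core::ModelResources& model_resources,
    processors::proto::EmbeddingPostprocessingGraphOptions* options_out) {
  processors::proto::EmbedderOptions embedder_options;
  embedder_options.set_l2_normalize(true);
  return processors::ConfigureEmbeddingPostprocessing(
      model_resources, embedder_options, options_out);
}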


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include <memory> #include <memory>
@ -25,8 +25,8 @@ limitations under the License.
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
@ -34,12 +34,10 @@ limitations under the License.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
namespace { namespace {
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::proto::EmbedderOptions;
using ::mediapipe::tasks::components::proto::
EmbeddingPostprocessingGraphOptions;
using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::ModelResources;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/"; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
@ -69,68 +67,72 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
EmbedderOptions options_in; proto::EmbedderOptions options_in;
options_in.set_l2_normalize(true); options_in.set_l2_normalize(true);
EmbeddingPostprocessingGraphOptions options_out; proto::EmbeddingPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
&options_out)); &options_out));
EXPECT_THAT( EXPECT_THAT(
options_out, options_out,
EqualsProto(ParseTextProtoOrDie<EmbeddingPostprocessingGraphOptions>( EqualsProto(
R"pb(tensors_to_embeddings_options { ParseTextProtoOrDie<proto::EmbeddingPostprocessingGraphOptions>(
embedder_options { l2_normalize: true } R"pb(tensors_to_embeddings_options {
head_names: "probability" embedder_options { l2_normalize: true }
} head_names: "probability"
has_quantized_outputs: true)pb"))); }
has_quantized_outputs: true)pb")));
} }
TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) { TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
EmbedderOptions options_in; proto::EmbedderOptions options_in;
options_in.set_quantize(true); options_in.set_quantize(true);
EmbeddingPostprocessingGraphOptions options_out; proto::EmbeddingPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
&options_out)); &options_out));
EXPECT_THAT( EXPECT_THAT(
options_out, options_out,
EqualsProto(ParseTextProtoOrDie<EmbeddingPostprocessingGraphOptions>( EqualsProto(
R"pb(tensors_to_embeddings_options { ParseTextProtoOrDie<proto::EmbeddingPostprocessingGraphOptions>(
embedder_options { quantize: true } R"pb(tensors_to_embeddings_options {
} embedder_options { quantize: true }
has_quantized_outputs: true)pb"))); }
has_quantized_outputs: true)pb")));
} }
TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) { TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
MP_ASSERT_OK_AND_ASSIGN(auto model_resources, MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
CreateModelResourcesForModel(kMobileNetV3Embedder)); CreateModelResourcesForModel(kMobileNetV3Embedder));
EmbedderOptions options_in; proto::EmbedderOptions options_in;
options_in.set_quantize(true); options_in.set_quantize(true);
options_in.set_l2_normalize(true); options_in.set_l2_normalize(true);
EmbeddingPostprocessingGraphOptions options_out; proto::EmbeddingPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
&options_out)); &options_out));
EXPECT_THAT( EXPECT_THAT(
options_out, options_out,
EqualsProto(ParseTextProtoOrDie<EmbeddingPostprocessingGraphOptions>( EqualsProto(
R"pb(tensors_to_embeddings_options { ParseTextProtoOrDie<proto::EmbeddingPostprocessingGraphOptions>(
embedder_options { quantize: true l2_normalize: true } R"pb(tensors_to_embeddings_options {
head_names: "feature" embedder_options { quantize: true l2_normalize: true }
} head_names: "feature"
has_quantized_outputs: false)pb"))); }
has_quantized_outputs: false)pb")));
} }
// TODO: add E2E Postprocessing tests once timestamp aggregation is // TODO: add E2E Postprocessing tests once timestamp aggregation is
// supported. // supported.
} // namespace } // namespace
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe


@ -34,3 +34,18 @@ mediapipe_proto_library(
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto", "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
], ],
) )
mediapipe_proto_library(
name = "embedder_options_proto",
srcs = ["embedder_options.proto"],
)
mediapipe_proto_library(
name = "embedding_postprocessing_graph_options_proto",
srcs = ["embedding_postprocessing_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto",
],
)


@ -15,7 +15,10 @@ limitations under the License.
syntax = "proto2"; syntax = "proto2";
package mediapipe.tasks.components.proto; package mediapipe.tasks.components.processors.proto;
option java_package = "com.google.mediapipe.tasks.components.processors.proto";
option java_outer_classname = "EmbedderOptionsProto";
// Shared options used by all embedding extraction tasks. // Shared options used by all embedding extraction tasks.
message EmbedderOptions { message EmbedderOptions {


@ -15,7 +15,7 @@ limitations under the License.
syntax = "proto2"; syntax = "proto2";
package mediapipe.tasks.components.proto; package mediapipe.tasks.components.processors.proto;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto"; import "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto";


@ -23,21 +23,6 @@ mediapipe_proto_library(
srcs = ["segmenter_options.proto"], srcs = ["segmenter_options.proto"],
) )
mediapipe_proto_library(
name = "embedder_options_proto",
srcs = ["embedder_options.proto"],
)
mediapipe_proto_library(
name = "embedding_postprocessing_graph_options_proto",
srcs = ["embedding_postprocessing_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto",
],
)
mediapipe_proto_library( mediapipe_proto_library(
name = "text_preprocessing_graph_options_proto", name = "text_preprocessing_graph_options_proto",
srcs = ["text_preprocessing_graph_options.proto"], srcs = ["text_preprocessing_graph_options.proto"],


@ -26,7 +26,7 @@ cc_library(
hdrs = ["cosine_similarity.h"], hdrs = ["cosine_similarity.h"],
deps = [ deps = [
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/containers:embedding_result",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
@ -39,7 +39,7 @@ cc_test(
deps = [ deps = [
":cosine_similarity", ":cosine_similarity",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/containers:embedding_result",
], ],
) )


@ -21,7 +21,7 @@ limitations under the License.
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/embedding_result.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
@ -30,7 +30,7 @@ namespace utils {
namespace { namespace {
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry; using ::mediapipe::tasks::components::containers::Embedding;
template <typename T> template <typename T>
absl::StatusOr<double> ComputeCosineSimilarity(const T& u, const T& v, absl::StatusOr<double> ComputeCosineSimilarity(const T& u, const T& v,
@ -66,39 +66,35 @@ absl::StatusOr<double> ComputeCosineSimilarity(const T& u, const T& v,
 // an L2-norm of 0.
 //
 // [1]: https://en.wikipedia.org/wiki/Cosine_similarity
-absl::StatusOr<double> CosineSimilarity(const EmbeddingEntry& u,
-                                        const EmbeddingEntry& v) {
-  if (u.has_float_embedding() && v.has_float_embedding()) {
-    if (u.float_embedding().values().size() !=
-        v.float_embedding().values().size()) {
+absl::StatusOr<double> CosineSimilarity(const Embedding& u,
+                                        const Embedding& v) {
+  if (!u.float_embedding.empty() && !v.float_embedding.empty()) {
+    if (u.float_embedding.size() != v.float_embedding.size()) {
       return CreateStatusWithPayload(
           absl::StatusCode::kInvalidArgument,
           absl::StrFormat("Cannot compute cosine similarity between embeddings "
                           "of different sizes (%d vs. %d)",
-                          u.float_embedding().values().size(),
-                          v.float_embedding().values().size()),
+                          u.float_embedding.size(), v.float_embedding.size()),
           MediaPipeTasksStatus::kInvalidArgumentError);
     }
-    return ComputeCosineSimilarity(u.float_embedding().values().data(),
-                                   v.float_embedding().values().data(),
-                                   u.float_embedding().values().size());
+    return ComputeCosineSimilarity(u.float_embedding.data(),
+                                   v.float_embedding.data(),
+                                   u.float_embedding.size());
   }
-  if (u.has_quantized_embedding() && v.has_quantized_embedding()) {
-    if (u.quantized_embedding().values().size() !=
-        v.quantized_embedding().values().size()) {
+  if (!u.quantized_embedding.empty() && !v.quantized_embedding.empty()) {
+    if (u.quantized_embedding.size() != v.quantized_embedding.size()) {
       return CreateStatusWithPayload(
           absl::StatusCode::kInvalidArgument,
           absl::StrFormat("Cannot compute cosine similarity between embeddings "
                           "of different sizes (%d vs. %d)",
-                          u.quantized_embedding().values().size(),
-                          v.quantized_embedding().values().size()),
+                          u.quantized_embedding.size(),
+                          v.quantized_embedding.size()),
           MediaPipeTasksStatus::kInvalidArgumentError);
     }
-    return ComputeCosineSimilarity(reinterpret_cast<const int8_t*>(
-                                       u.quantized_embedding().values().data()),
-                                   reinterpret_cast<const int8_t*>(
-                                       v.quantized_embedding().values().data()),
-                                   u.quantized_embedding().values().size());
+    return ComputeCosineSimilarity(
+        reinterpret_cast<const int8_t*>(u.quantized_embedding.data()),
+        reinterpret_cast<const int8_t*>(v.quantized_embedding.data()),
+        u.quantized_embedding.size());
   }
   return CreateStatusWithPayload(
       absl::StatusCode::kInvalidArgument,


@ -17,22 +17,20 @@ limitations under the License.
#define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_COSINE_SIMILARITY_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_COSINE_SIMILARITY_H_
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/embedding_result.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace utils { namespace utils {
-// Utility function to compute cosine similarity [1] between two embedding
-// entries. May return an InvalidArgumentError if e.g. the feature vectors are
-// of different types (quantized vs. float), have different sizes, or have a
-// an L2-norm of 0.
+// Utility function to compute cosine similarity [1] between two embeddings. May
+// return an InvalidArgumentError if e.g. the embeddings are of different types
+// (quantized vs. float), have different sizes, or have an L2-norm of 0.
 //
 // [1]: https://en.wikipedia.org/wiki/Cosine_similarity
-absl::StatusOr<double> CosineSimilarity(
-    const containers::proto::EmbeddingEntry& u,
-    const containers::proto::EmbeddingEntry& v);
+absl::StatusOr<double> CosineSimilarity(const containers::Embedding& u,
+                                        const containers::Embedding& v);
} // namespace utils } // namespace utils
} // namespace components } // namespace components
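
A hedged usage sketch against the new signature (the example function and values are illustrative; for float vectors the result is the dot product divided by the product of the L2 norms, here 0.5):

#include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h"

namespace containers = ::mediapipe::tasks::components::containers;
namespace utils = ::mediapipe::tasks::components::utils;

// Compares two small float embeddings built with the new struct.
absl::StatusOr<double> ExampleSimilarity() {
  containers::Embedding u;
  u.float_embedding = {1.0f, 0.0f, 0.0f, 0.0f};
  containers::Embedding v;
  v.float_embedding = {0.5f, 0.5f, 0.5f, 0.5f};
  return utils::CosineSimilarity(u, v);
}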


@ -22,7 +22,7 @@ limitations under the License.
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/embedding_result.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
@ -30,29 +30,27 @@ namespace components {
namespace utils { namespace utils {
 namespace {
 
-using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry;
+using ::mediapipe::tasks::components::containers::Embedding;
 using ::testing::HasSubstr;
 
-// Helper function to generate float EmbeddingEntry.
-EmbeddingEntry BuildFloatEntry(std::vector<float> values) {
-  EmbeddingEntry entry;
-  for (const float value : values) {
-    entry.mutable_float_embedding()->add_values(value);
-  }
-  return entry;
+// Helper function to generate float Embedding.
+Embedding BuildFloatEmbedding(std::vector<float> values) {
+  Embedding embedding;
+  embedding.float_embedding = values;
+  return embedding;
 }
 
-// Helper function to generate quantized EmbeddingEntry.
-EmbeddingEntry BuildQuantizedEntry(std::vector<int8_t> values) {
-  EmbeddingEntry entry;
-  entry.mutable_quantized_embedding()->set_values(
-      reinterpret_cast<uint8_t*>(values.data()), values.size());
-  return entry;
+// Helper function to generate quantized Embedding.
+Embedding BuildQuantizedEmbedding(std::vector<int8_t> values) {
+  Embedding embedding;
+  uint8_t* data = reinterpret_cast<uint8_t*>(values.data());
+  embedding.quantized_embedding = {data, data + values.size()};
+  return embedding;
 }
TEST(CosineSimilarity, FailsWithQuantizedAndFloatEmbeddings) { TEST(CosineSimilarity, FailsWithQuantizedAndFloatEmbeddings) {
auto u = BuildFloatEntry({0.1, 0.2}); auto u = BuildFloatEmbedding({0.1, 0.2});
auto v = BuildQuantizedEntry({0, 1}); auto v = BuildQuantizedEmbedding({0, 1});
auto status = CosineSimilarity(u, v); auto status = CosineSimilarity(u, v);
@ -63,8 +61,8 @@ TEST(CosineSimilarity, FailsWithQuantizedAndFloatEmbeddings) {
} }
TEST(CosineSimilarity, FailsWithZeroNorm) { TEST(CosineSimilarity, FailsWithZeroNorm) {
auto u = BuildFloatEntry({0.1, 0.2}); auto u = BuildFloatEmbedding({0.1, 0.2});
auto v = BuildFloatEntry({0.0, 0.0}); auto v = BuildFloatEmbedding({0.0, 0.0});
auto status = CosineSimilarity(u, v); auto status = CosineSimilarity(u, v);
@ -75,8 +73,8 @@ TEST(CosineSimilarity, FailsWithZeroNorm) {
} }
TEST(CosineSimilarity, FailsWithDifferentSizes) { TEST(CosineSimilarity, FailsWithDifferentSizes) {
auto u = BuildFloatEntry({0.1, 0.2}); auto u = BuildFloatEmbedding({0.1, 0.2});
auto v = BuildFloatEntry({0.1, 0.2, 0.3}); auto v = BuildFloatEmbedding({0.1, 0.2, 0.3});
auto status = CosineSimilarity(u, v); auto status = CosineSimilarity(u, v);
@ -87,8 +85,8 @@ TEST(CosineSimilarity, FailsWithDifferentSizes) {
} }
TEST(CosineSimilarity, SucceedsWithFloatEntries) { TEST(CosineSimilarity, SucceedsWithFloatEntries) {
auto u = BuildFloatEntry({1.0, 0.0, 0.0, 0.0}); auto u = BuildFloatEmbedding({1.0, 0.0, 0.0, 0.0});
auto v = BuildFloatEntry({0.5, 0.5, 0.5, 0.5}); auto v = BuildFloatEmbedding({0.5, 0.5, 0.5, 0.5});
MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v)); MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v));
@ -96,8 +94,8 @@ TEST(CosineSimilarity, SucceedsWithFloatEntries) {
} }
TEST(CosineSimilarity, SucceedsWithQuantizedEntries) { TEST(CosineSimilarity, SucceedsWithQuantizedEntries) {
auto u = BuildQuantizedEntry({127, 0, 0, 0}); auto u = BuildQuantizedEmbedding({127, 0, 0, 0});
auto v = BuildQuantizedEntry({-128, 0, 0, 0}); auto v = BuildQuantizedEmbedding({-128, 0, 0, 0});
MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v)); MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v));

View File

@ -26,12 +26,12 @@ cc_library(
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components:embedding_postprocessing_graph",
"//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedding_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
@ -49,9 +49,10 @@ cc_library(
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/tool:options_map", "//mediapipe/framework/tool:options_map",
"//mediapipe/tasks/cc/components:embedder_options", "//mediapipe/tasks/cc/components/containers:embedding_result",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto", "//mediapipe/tasks/cc/components/processors:embedder_options",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:cosine_similarity", "//mediapipe/tasks/cc/components/utils:cosine_similarity",
"//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/cc/core:task_runner",

View File

@ -21,9 +21,10 @@ limitations under the License.
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/tool/options_map.h" #include "mediapipe/framework/tool/options_map.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/embedder_options.h" #include "mediapipe/tasks/cc/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h" #include "mediapipe/tasks/cc/components/utils/cosine_similarity.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
@ -41,8 +42,8 @@ namespace image_embedder {
namespace { namespace {
constexpr char kEmbeddingResultStreamName[] = "embedding_result_out"; constexpr char kEmbeddingsStreamName[] = "embeddings_out";
constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT"; constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageOutStreamName[] = "image_out";
constexpr char kImageTag[] = "IMAGE"; constexpr char kImageTag[] = "IMAGE";
@ -53,7 +54,7 @@ constexpr char kGraphTypeName[] =
"mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"; "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000; constexpr int kMicroSecondsPerMilliSecond = 1000;
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry; using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::core::PacketMap; using ::mediapipe::tasks::core::PacketMap;
using ::mediapipe::tasks::vision::image_embedder::proto:: using ::mediapipe::tasks::vision::image_embedder::proto::
@ -71,13 +72,13 @@ CalculatorGraphConfig CreateGraphConfig(
graph.In(kNormRectTag).SetName(kNormRectStreamName); graph.In(kNormRectTag).SetName(kNormRectStreamName);
auto& task_graph = graph.AddNode(kGraphTypeName); auto& task_graph = graph.AddNode(kGraphTypeName);
task_graph.GetOptions<ImageEmbedderGraphOptions>().Swap(options_proto.get()); task_graph.GetOptions<ImageEmbedderGraphOptions>().Swap(options_proto.get());
task_graph.Out(kEmbeddingResultTag).SetName(kEmbeddingResultStreamName) >> task_graph.Out(kEmbeddingsTag).SetName(kEmbeddingsStreamName) >>
graph.Out(kEmbeddingResultTag); graph.Out(kEmbeddingsTag);
task_graph.Out(kImageTag).SetName(kImageOutStreamName) >> task_graph.Out(kImageTag).SetName(kImageOutStreamName) >>
graph.Out(kImageTag); graph.Out(kImageTag);
if (enable_flow_limiting) { if (enable_flow_limiting) {
return tasks::core::AddFlowLimiterCalculator( return tasks::core::AddFlowLimiterCalculator(
graph, task_graph, {kImageTag, kNormRectTag}, kEmbeddingResultTag); graph, task_graph, {kImageTag, kNormRectTag}, kEmbeddingsTag);
} }
graph.In(kImageTag) >> task_graph.In(kImageTag); graph.In(kImageTag) >> task_graph.In(kImageTag);
graph.In(kNormRectTag) >> task_graph.In(kNormRectTag); graph.In(kNormRectTag) >> task_graph.In(kNormRectTag);
@ -95,8 +96,8 @@ std::unique_ptr<ImageEmbedderGraphOptions> ConvertImageEmbedderOptionsToProto(
options_proto->mutable_base_options()->set_use_stream_mode( options_proto->mutable_base_options()->set_use_stream_mode(
options->running_mode != core::RunningMode::IMAGE); options->running_mode != core::RunningMode::IMAGE);
auto embedder_options_proto = auto embedder_options_proto =
std::make_unique<tasks::components::proto::EmbedderOptions>( std::make_unique<components::processors::proto::EmbedderOptions>(
components::ConvertEmbedderOptionsToProto( components::processors::ConvertEmbedderOptionsToProto(
&(options->embedder_options))); &(options->embedder_options)));
options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get()); options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get());
return options_proto; return options_proto;
@ -121,9 +122,10 @@ absl::StatusOr<std::unique_ptr<ImageEmbedder>> ImageEmbedder::Create(
return; return;
} }
Packet embedding_result_packet = Packet embedding_result_packet =
status_or_packets.value()[kEmbeddingResultStreamName]; status_or_packets.value()[kEmbeddingsStreamName];
Packet image_packet = status_or_packets.value()[kImageOutStreamName]; Packet image_packet = status_or_packets.value()[kImageOutStreamName];
result_callback(embedding_result_packet.Get<EmbeddingResult>(), result_callback(ConvertToEmbeddingResult(
embedding_result_packet.Get<EmbeddingResult>()),
image_packet.Get<Image>(), image_packet.Get<Image>(),
embedding_result_packet.Timestamp().Value() / embedding_result_packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond); kMicroSecondsPerMilliSecond);
@ -138,7 +140,7 @@ absl::StatusOr<std::unique_ptr<ImageEmbedder>> ImageEmbedder::Create(
std::move(packets_callback)); std::move(packets_callback));
} }
absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed( absl::StatusOr<ImageEmbedderResult> ImageEmbedder::Embed(
Image image, Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
@ -155,10 +157,11 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
{{kImageInStreamName, MakePacket<Image>(std::move(image))}, {{kImageInStreamName, MakePacket<Image>(std::move(image))},
{kNormRectStreamName, {kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))}})); MakePacket<NormalizedRect>(std::move(norm_rect))}}));
return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>(); return ConvertToEmbeddingResult(
output_packets[kEmbeddingsStreamName].Get<EmbeddingResult>());
} }
absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo( absl::StatusOr<ImageEmbedderResult> ImageEmbedder::EmbedForVideo(
Image image, int64 timestamp_ms, Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
@ -178,7 +181,8 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
{kNormRectStreamName, {kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect)) MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>(); return ConvertToEmbeddingResult(
output_packets[kEmbeddingsStreamName].Get<EmbeddingResult>());
} }
absl::Status ImageEmbedder::EmbedAsync( absl::Status ImageEmbedder::EmbedAsync(
@ -202,7 +206,8 @@ absl::Status ImageEmbedder::EmbedAsync(
} }
absl::StatusOr<double> ImageEmbedder::CosineSimilarity( absl::StatusOr<double> ImageEmbedder::CosineSimilarity(
const EmbeddingEntry& u, const EmbeddingEntry& v) { const components::containers::Embedding& u,
const components::containers::Embedding& v) {
return components::utils::CosineSimilarity(u, v); return components::utils::CosineSimilarity(u, v);
} }

View File

@ -21,8 +21,8 @@ limitations under the License.
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/components/embedder_options.h" #include "mediapipe/tasks/cc/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
@ -33,6 +33,10 @@ namespace tasks {
namespace vision { namespace vision {
namespace image_embedder { namespace image_embedder {
// Alias the shared EmbeddingResult struct as result type.
using ImageEmbedderResult =
::mediapipe::tasks::components::containers::EmbeddingResult;
// The options for configuring a MediaPipe image embedder task. // The options for configuring a MediaPipe image embedder task.
struct ImageEmbedderOptions { struct ImageEmbedderOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model // Base options for configuring MediaPipe Tasks, such as specifying the model
@ -50,14 +54,12 @@ struct ImageEmbedderOptions {
// Options for configuring the embedder behavior, such as L2-normalization or // Options for configuring the embedder behavior, such as L2-normalization or
// scalar-quantization. // scalar-quantization.
components::EmbedderOptions embedder_options; components::processors::EmbedderOptions embedder_options;
// The user-defined result callback for processing live stream data. // The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set // The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM. // to RunningMode::LIVE_STREAM.
std::function<void( std::function<void(absl::StatusOr<ImageEmbedderResult>, const Image&, int64)>
absl::StatusOr<components::containers::proto::EmbeddingResult>,
const Image&, int64)>
result_callback = nullptr; result_callback = nullptr;
}; };
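As a hedged illustration (not part of this change), the options above might be populated as follows; the model path is hypothetical, and the l2_normalize/quantize fields are assumed from the L2-normalization and scalar-quantization description of EmbedderOptions.

// Sketch only: assumes the declarations above are in scope.
auto options = std::make_unique<ImageEmbedderOptions>();
// Hypothetical model path.
options->base_options.model_asset_path = "/path/to/embedder.tflite";
// Assumed EmbedderOptions fields, per the description above.
options->embedder_options.l2_normalize = true;
options->embedder_options.quantize = false;
// The default running mode is IMAGE; result_callback is only set for
// RunningMode::LIVE_STREAM.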
@ -104,7 +106,7 @@ class ImageEmbedder : core::BaseVisionTaskApi {
// running mode. // running mode.
// //
// The image can be of any size with format RGB or RGBA. // The image can be of any size with format RGB or RGBA.
absl::StatusOr<components::containers::proto::EmbeddingResult> Embed( absl::StatusOr<ImageEmbedderResult> Embed(
mediapipe::Image image, mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);
@ -127,7 +129,7 @@ class ImageEmbedder : core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA. It's required to // The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps // provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing. // must be monotonically increasing.
absl::StatusOr<components::containers::proto::EmbeddingResult> EmbedForVideo( absl::StatusOr<ImageEmbedderResult> EmbedForVideo(
mediapipe::Image image, int64 timestamp_ms, mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);
@ -168,15 +170,15 @@ class ImageEmbedder : core::BaseVisionTaskApi {
// Shuts down the ImageEmbedder when all works are done. // Shuts down the ImageEmbedder when all works are done.
absl::Status Close() { return runner_->Close(); } absl::Status Close() { return runner_->Close(); }
// Utility function to compute cosine similarity [1] between two embedding // Utility function to compute cosine similarity [1] between two embeddings.
// entries. May return an InvalidArgumentError if e.g. the feature vectors are // May return an InvalidArgumentError if e.g. the embeddings are of different
// of different types (quantized vs. float), have different sizes, or have a // types (quantized vs. float), have different sizes, or have an L2-norm of
// an L2-norm of 0. // 0.
// //
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity // [1]: https://en.wikipedia.org/wiki/Cosine_similarity
static absl::StatusOr<double> CosineSimilarity( static absl::StatusOr<double> CosineSimilarity(
const components::containers::proto::EmbeddingEntry& u, const components::containers::Embedding& u,
const components::containers::proto::EmbeddingEntry& v); const components::containers::Embedding& v);
}; };
} // namespace image_embedder } // namespace image_embedder
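Putting the updated API together, a minimal image-mode sketch (assumptions: the image_embedder namespace and includes from the header above are in scope, the model path is hypothetical, and the two mediapipe::Image inputs are already decoded).

absl::StatusOr<double> EmbedAndCompare(mediapipe::Image image_a,
                                       mediapipe::Image image_b) {
  auto options = std::make_unique<ImageEmbedderOptions>();
  options->base_options.model_asset_path = "embedder.tflite";  // hypothetical
  auto embedder_or = ImageEmbedder::Create(std::move(options));
  if (!embedder_or.ok()) return embedder_or.status();
  std::unique_ptr<ImageEmbedder> embedder = std::move(*embedder_or);

  auto result_a = embedder->Embed(image_a);
  if (!result_a.ok()) return result_a.status();
  auto result_b = embedder->Embed(image_b);
  if (!result_b.ok()) return result_b.status();

  // Each ImageEmbedderResult holds one Embedding per model head.
  auto similarity = ImageEmbedder::CosineSimilarity(result_a->embeddings[0],
                                                    result_b->embeddings[0]);
  if (absl::Status close_status = embedder->Close(); !close_status.ok()) {
    return close_status;
  }
  return similarity;
}

Note that the static CosineSimilarity now takes the transparent Embedding struct rather than the EmbeddingEntry proto.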

View File

@ -20,10 +20,10 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h"
@ -40,10 +40,8 @@ using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::components::proto::
EmbeddingPostprocessingGraphOptions;
constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT"; constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kImageTag[] = "IMAGE"; constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
@ -67,7 +65,7 @@ struct ImageEmbedderOutputStreams {
// Describes region of image to perform embedding extraction on. // Describes region of image to perform embedding extraction on.
// @Optional: rect covering the whole image is used if not specified. // @Optional: rect covering the whole image is used if not specified.
// Outputs: // Outputs:
// EMBEDDING_RESULT - EmbeddingResult // EMBEDDINGS - EmbeddingResult
// The embedding result. // The embedding result.
// IMAGE - Image // IMAGE - Image
// The image that embedding extraction runs on. // The image that embedding extraction runs on.
@ -76,7 +74,7 @@ struct ImageEmbedderOutputStreams {
// node { // node {
// calculator: "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph" // calculator: "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"
// input_stream: "IMAGE:image_in" // input_stream: "IMAGE:image_in"
// output_stream: "EMBEDDING_RESULT:embedding_result_out" // output_stream: "EMBEDDINGS:embedding_result_out"
// output_stream: "IMAGE:image_out" // output_stream: "IMAGE:image_out"
// options { // options {
// [mediapipe.tasks.vision.image_embedder.proto.ImageEmbedderOptions.ext] // [mediapipe.tasks.vision.image_embedder.proto.ImageEmbedderOptions.ext]
@ -107,7 +105,7 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
graph[Input<Image>(kImageTag)], graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph)); graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
output_streams.embedding_result >> output_streams.embedding_result >>
graph[Output<EmbeddingResult>(kEmbeddingResultTag)]; graph[Output<EmbeddingResult>(kEmbeddingsTag)];
output_streams.image >> graph[Output<Image>(kImageTag)]; output_streams.image >> graph[Output<Image>(kImageTag)];
return graph.GetConfig(); return graph.GetConfig();
} }
@ -152,16 +150,17 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
// Adds postprocessing calculators and connects its input stream to the // Adds postprocessing calculators and connects its input stream to the
// inference results. // inference results.
auto& postprocessing = graph.AddNode( auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.EmbeddingPostprocessingGraph"); "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
MP_RETURN_IF_ERROR(components::ConfigureEmbeddingPostprocessing( MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
model_resources, task_options.embedder_options(), model_resources, task_options.embedder_options(),
&postprocessing.GetOptions<EmbeddingPostprocessingGraphOptions>())); &postprocessing.GetOptions<components::processors::proto::
EmbeddingPostprocessingGraphOptions>()));
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Outputs the embedding results. // Outputs the embedding results.
return ImageEmbedderOutputStreams{ return ImageEmbedderOutputStreams{
/*embedding_result=*/postprocessing[Output<EmbeddingResult>( /*embedding_result=*/postprocessing[Output<EmbeddingResult>(
kEmbeddingResultTag)], kEmbeddingsTag)],
/*image=*/preprocessing[Output<Image>(kImageTag)]}; /*image=*/preprocessing[Output<Image>(kImageTag)]};
} }
}; };

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
@ -42,7 +42,6 @@ namespace {
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect; using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::Optional; using ::testing::Optional;
@ -54,18 +53,14 @@ constexpr double kSimilarityTolerancy = 1e-6;
// Utility function to check the sizes, head_index and head_names of a result // Utility function to check the sizes, head_index and head_names of a result
// produced by kMobileNetV3Embedder. // produced by kMobileNetV3Embedder.
void CheckMobileNetV3Result(const EmbeddingResult& result, bool quantized) { void CheckMobileNetV3Result(const ImageEmbedderResult& result, bool quantized) {
EXPECT_EQ(result.embeddings().size(), 1); EXPECT_EQ(result.embeddings.size(), 1);
EXPECT_EQ(result.embeddings(0).head_index(), 0); EXPECT_EQ(result.embeddings[0].head_index, 0);
EXPECT_EQ(result.embeddings(0).head_name(), "feature"); EXPECT_EQ(result.embeddings[0].head_name, "feature");
EXPECT_EQ(result.embeddings(0).entries().size(), 1);
if (quantized) { if (quantized) {
EXPECT_EQ( EXPECT_EQ(result.embeddings[0].quantized_embedding.size(), 1024);
result.embeddings(0).entries(0).quantized_embedding().values().size(),
1024);
} else { } else {
EXPECT_EQ(result.embeddings(0).entries(0).float_embedding().values().size(), EXPECT_EQ(result.embeddings[0].float_embedding.size(), 1024);
1024);
} }
} }
@ -154,7 +149,7 @@ TEST_F(CreateTest, FailsWithIllegalCallbackInImageOrVideoMode) {
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
options->running_mode = running_mode; options->running_mode = running_mode;
options->result_callback = [](absl::StatusOr<EmbeddingResult>, options->result_callback = [](absl::StatusOr<ImageEmbedderResult>,
const Image& image, int64 timestamp_ms) {}; const Image& image, int64 timestamp_ms) {};
auto image_embedder = ImageEmbedder::Create(std::move(options)); auto image_embedder = ImageEmbedder::Create(std::move(options));
@ -231,19 +226,18 @@ TEST_F(ImageModeTest, SucceedsWithoutL2Normalization) {
JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
// Extract both embeddings. // Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
image_embedder->Embed(image)); image_embedder->Embed(image));
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
image_embedder->Embed(crop)); image_embedder->Embed(crop));
// Check results. // Check results.
CheckMobileNetV3Result(image_result, false); CheckMobileNetV3Result(image_result, false);
CheckMobileNetV3Result(crop_result, false); CheckMobileNetV3Result(crop_result, false);
// CheckCosineSimilarity. // CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
double similarity, image_result.embeddings[0],
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), crop_result.embeddings[0]));
crop_result.embeddings(0).entries(0)));
double expected_similarity = 0.925519; double expected_similarity = 0.925519;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
@ -264,19 +258,18 @@ TEST_F(ImageModeTest, SucceedsWithL2Normalization) {
JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
// Extract both embeddings. // Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
image_embedder->Embed(image)); image_embedder->Embed(image));
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
image_embedder->Embed(crop)); image_embedder->Embed(crop));
// Check results. // Check results.
CheckMobileNetV3Result(image_result, false); CheckMobileNetV3Result(image_result, false);
CheckMobileNetV3Result(crop_result, false); CheckMobileNetV3Result(crop_result, false);
// CheckCosineSimilarity. // CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
double similarity, image_result.embeddings[0],
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), crop_result.embeddings[0]));
crop_result.embeddings(0).entries(0)));
double expected_similarity = 0.925519; double expected_similarity = 0.925519;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
@ -297,19 +290,18 @@ TEST_F(ImageModeTest, SucceedsWithQuantization) {
JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
// Extract both embeddings. // Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
image_embedder->Embed(image)); image_embedder->Embed(image));
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
image_embedder->Embed(crop)); image_embedder->Embed(crop));
// Check results. // Check results.
CheckMobileNetV3Result(image_result, true); CheckMobileNetV3Result(image_result, true);
CheckMobileNetV3Result(crop_result, true); CheckMobileNetV3Result(crop_result, true);
// CheckCosineSimilarity. // CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
double similarity, image_result.embeddings[0],
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), crop_result.embeddings[0]));
crop_result.embeddings(0).entries(0)));
double expected_similarity = 0.926791; double expected_similarity = 0.926791;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
@ -333,19 +325,18 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
// Extract both embeddings. // Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
const EmbeddingResult& image_result, const ImageEmbedderResult& image_result,
image_embedder->Embed(image, image_processing_options)); image_embedder->Embed(image, image_processing_options));
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
image_embedder->Embed(crop)); image_embedder->Embed(crop));
// Check results. // Check results.
CheckMobileNetV3Result(image_result, false); CheckMobileNetV3Result(image_result, false);
CheckMobileNetV3Result(crop_result, false); CheckMobileNetV3Result(crop_result, false);
// CheckCosineSimilarity. // CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
double similarity, image_result.embeddings[0],
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), crop_result.embeddings[0]));
crop_result.embeddings(0).entries(0)));
double expected_similarity = 0.999931; double expected_similarity = 0.999931;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
@ -367,20 +358,19 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
image_processing_options.rotation_degrees = -90; image_processing_options.rotation_degrees = -90;
// Extract both embeddings. // Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
image_embedder->Embed(image)); image_embedder->Embed(image));
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
const EmbeddingResult& rotated_result, const ImageEmbedderResult& rotated_result,
image_embedder->Embed(rotated, image_processing_options)); image_embedder->Embed(rotated, image_processing_options));
// Check results. // Check results.
CheckMobileNetV3Result(image_result, false); CheckMobileNetV3Result(image_result, false);
CheckMobileNetV3Result(rotated_result, false); CheckMobileNetV3Result(rotated_result, false);
// CheckCosineSimilarity. // CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
double similarity, image_result.embeddings[0],
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), rotated_result.embeddings[0]));
rotated_result.embeddings(0).entries(0)));
double expected_similarity = 0.572265; double expected_similarity = 0.572265;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
@ -403,20 +393,19 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
/*rotation_degrees=*/-90}; /*rotation_degrees=*/-90};
// Extract both embeddings. // Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
image_embedder->Embed(crop)); image_embedder->Embed(crop));
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
const EmbeddingResult& rotated_result, const ImageEmbedderResult& rotated_result,
image_embedder->Embed(rotated, image_processing_options)); image_embedder->Embed(rotated, image_processing_options));
// Check results. // Check results.
CheckMobileNetV3Result(crop_result, false); CheckMobileNetV3Result(crop_result, false);
CheckMobileNetV3Result(rotated_result, false); CheckMobileNetV3Result(rotated_result, false);
// CheckCosineSimilarity. // CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
double similarity, crop_result.embeddings[0],
ImageEmbedder::CosineSimilarity(crop_result.embeddings(0).entries(0), rotated_result.embeddings[0]));
rotated_result.embeddings(0).entries(0)));
double expected_similarity = 0.62838; double expected_similarity = 0.62838;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
@ -487,16 +476,16 @@ TEST_F(VideoModeTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
ImageEmbedder::Create(std::move(options))); ImageEmbedder::Create(std::move(options)));
EmbeddingResult previous_results; ImageEmbedderResult previous_results;
for (int i = 0; i < iterations; ++i) { for (int i = 0; i < iterations; ++i) {
MP_ASSERT_OK_AND_ASSIGN(auto results, MP_ASSERT_OK_AND_ASSIGN(auto results,
image_embedder->EmbedForVideo(image, i)); image_embedder->EmbedForVideo(image, i));
CheckMobileNetV3Result(results, false); CheckMobileNetV3Result(results, false);
if (i > 0) { if (i > 0) {
MP_ASSERT_OK_AND_ASSIGN(double similarity, MP_ASSERT_OK_AND_ASSIGN(
ImageEmbedder::CosineSimilarity( double similarity,
results.embeddings(0).entries(0), ImageEmbedder::CosineSimilarity(results.embeddings[0],
previous_results.embeddings(0).entries(0))); previous_results.embeddings[0]));
double expected_similarity = 1.000000; double expected_similarity = 1.000000;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
@ -515,7 +504,7 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
options->running_mode = core::RunningMode::LIVE_STREAM; options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback = [](absl::StatusOr<EmbeddingResult>, options->result_callback = [](absl::StatusOr<ImageEmbedderResult>,
const Image& image, int64 timestamp_ms) {}; const Image& image, int64 timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
ImageEmbedder::Create(std::move(options))); ImageEmbedder::Create(std::move(options)));
@ -546,7 +535,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
options->running_mode = core::RunningMode::LIVE_STREAM; options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback = [](absl::StatusOr<EmbeddingResult>, options->result_callback = [](absl::StatusOr<ImageEmbedderResult>,
const Image& image, int64 timestamp_ms) {}; const Image& image, int64 timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
ImageEmbedder::Create(std::move(options))); ImageEmbedder::Create(std::move(options)));
@ -564,7 +553,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
} }
struct LiveStreamModeResults { struct LiveStreamModeResults {
EmbeddingResult embedding_result; ImageEmbedderResult embedding_result;
std::pair<int, int> image_size; std::pair<int, int> image_size;
int64 timestamp_ms; int64 timestamp_ms;
}; };
@ -580,7 +569,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
options->running_mode = core::RunningMode::LIVE_STREAM; options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback = options->result_callback =
[&results](absl::StatusOr<EmbeddingResult> embedding_result, [&results](absl::StatusOr<ImageEmbedderResult> embedding_result,
const Image& image, int64 timestamp_ms) { const Image& image, int64 timestamp_ms) {
MP_ASSERT_OK(embedding_result.status()); MP_ASSERT_OK(embedding_result.status());
results.push_back( results.push_back(
@ -612,8 +601,8 @@ TEST_F(LiveStreamModeTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
double similarity, double similarity,
ImageEmbedder::CosineSimilarity( ImageEmbedder::CosineSimilarity(
result.embedding_result.embeddings(0).entries(0), result.embedding_result.embeddings[0],
results[i - 1].embedding_result.embeddings(0).entries(0))); results[i - 1].embedding_result.embeddings[0]));
double expected_similarity = 1.000000; double expected_similarity = 1.000000;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
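For completeness, a hedged live-stream sketch (not taken from this change): EmbedAsync is assumed to take the image plus a millisecond timestamp, mirroring EmbedForVideo, and DecodeImageFromFile is the image_utils helper these tests rely on; the model path is hypothetical.

auto options = std::make_unique<ImageEmbedderOptions>();
options->base_options.model_asset_path = "embedder.tflite";  // hypothetical
options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback = [](absl::StatusOr<ImageEmbedderResult> result,
                              const Image& image, int64 timestamp_ms) {
  if (!result.ok()) return;
  // Embeddings are read exactly as in image mode.
  const auto& embedding = result->embeddings[0];
  (void)embedding;
};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder,
                        ImageEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
    Image frame, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                               "burger_crop.jpg")));
// Timestamps must be monotonically increasing.
for (int64 t = 0; t < 3; ++t) {
  // Assumed (image, timestamp_ms) signature.
  MP_ASSERT_OK(embedder->EmbedAsync(frame, t));
}
MP_ASSERT_OK(embedder->Close());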

View File

@ -24,7 +24,7 @@ mediapipe_proto_library(
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_proto", "//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto",
], ],
) )

View File

@ -18,7 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.vision.image_embedder.proto; package mediapipe.tasks.vision.image_embedder.proto;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/embedder_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto";
message ImageEmbedderGraphOptions { message ImageEmbedderGraphOptions {
@ -31,5 +31,5 @@ message ImageEmbedderGraphOptions {
// Options for configuring the embedder behavior, such as normalization or // Options for configuring the embedder behavior, such as normalization or
// quantization. // quantization.
optional components.proto.EmbedderOptions embedder_options = 2; optional components.processors.proto.EmbedderOptions embedder_options = 2;
} }

View File

@ -23,6 +23,7 @@ _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite", "//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite", "//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",