Update EmbeddingResult format and dependent tasks.

PiperOrigin-RevId: 486186491
This commit is contained in:
MediaPipe Team 2022-11-04 11:18:27 -07:00 committed by Copybara-Service
parent 66e591d4bc
commit 5e1a2fcdbb
31 changed files with 499 additions and 387 deletions

View File

@ -61,41 +61,6 @@ cc_library(
# TODO: Enable this test
cc_library(
name = "embedder_options",
srcs = ["embedder_options.cc"],
hdrs = ["embedder_options.h"],
deps = ["//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto"],
)
cc_library(
name = "embedding_postprocessing_graph",
srcs = ["embedding_postprocessing_graph.cc"],
hdrs = ["embedding_postprocessing_graph.h"],
deps = [
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/tool:options_map",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
)
# TODO: Investigate rewriting the build rule to only link
# the Bert Preprocessor if it's needed.
cc_library(

View File

@ -163,7 +163,7 @@ mediapipe_proto_library(
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_proto",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto",
],
)
@ -178,7 +178,7 @@ cc_library(
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
],

View File

@ -26,14 +26,14 @@
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
namespace mediapipe {
namespace api2 {
namespace {
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry;
using ::mediapipe::tasks::components::containers::proto::Embedding;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
// Computes the inverse L2 norm of the provided array of values. Returns 1.0 in
@ -66,7 +66,7 @@ float GetInverseL2Norm(const float* values, int size) {
class TensorsToEmbeddingsCalculator : public Node {
public:
static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"};
static constexpr Output<EmbeddingResult> kEmbeddingsOut{"EMBEDDING_RESULT"};
static constexpr Output<EmbeddingResult> kEmbeddingsOut{"EMBEDDINGS"};
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kEmbeddingsOut);
absl::Status Open(CalculatorContext* cc) override;
@ -77,8 +77,8 @@ class TensorsToEmbeddingsCalculator : public Node {
bool quantize_;
std::vector<std::string> head_names_;
void FillFloatEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry);
void FillQuantizedEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry);
void FillFloatEmbedding(const Tensor& tensor, Embedding* embedding);
void FillQuantizedEmbedding(const Tensor& tensor, Embedding* embedding);
};
absl::Status TensorsToEmbeddingsCalculator::Open(CalculatorContext* cc) {
@ -104,42 +104,42 @@ absl::Status TensorsToEmbeddingsCalculator::Process(CalculatorContext* cc) {
for (int i = 0; i < tensors.size(); ++i) {
const auto& tensor = tensors[i];
RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32);
auto* embeddings = result.add_embeddings();
embeddings->set_head_index(i);
auto* embedding = result.add_embeddings();
embedding->set_head_index(i);
if (!head_names_.empty()) {
embeddings->set_head_name(head_names_[i]);
embedding->set_head_name(head_names_[i]);
}
if (quantize_) {
FillQuantizedEmbeddingEntry(tensor, embeddings->add_entries());
FillQuantizedEmbedding(tensor, embedding);
} else {
FillFloatEmbeddingEntry(tensor, embeddings->add_entries());
FillFloatEmbedding(tensor, embedding);
}
}
kEmbeddingsOut(cc).Send(result);
return absl::OkStatus();
}
void TensorsToEmbeddingsCalculator::FillFloatEmbeddingEntry(
const Tensor& tensor, EmbeddingEntry* entry) {
void TensorsToEmbeddingsCalculator::FillFloatEmbedding(const Tensor& tensor,
Embedding* embedding) {
int size = tensor.shape().num_elements();
auto tensor_view = tensor.GetCpuReadView();
const float* tensor_buffer = tensor_view.buffer<float>();
float inv_l2_norm =
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
auto* float_embedding = entry->mutable_float_embedding();
auto* float_embedding = embedding->mutable_float_embedding();
for (int i = 0; i < size; ++i) {
float_embedding->add_values(tensor_buffer[i] * inv_l2_norm);
}
}
void TensorsToEmbeddingsCalculator::FillQuantizedEmbeddingEntry(
const Tensor& tensor, EmbeddingEntry* entry) {
void TensorsToEmbeddingsCalculator::FillQuantizedEmbedding(
const Tensor& tensor, Embedding* embedding) {
int size = tensor.shape().num_elements();
auto tensor_view = tensor.GetCpuReadView();
const float* tensor_buffer = tensor_view.buffer<float>();
float inv_l2_norm =
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
auto* values = entry->mutable_quantized_embedding()->mutable_values();
auto* values = embedding->mutable_quantized_embedding()->mutable_values();
values->resize(size);
for (int i = 0; i < size; ++i) {
// Normalize.

View File

@ -18,7 +18,7 @@ syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/embedder_options.proto";
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
message TensorsToEmbeddingsCalculatorOptions {
extend mediapipe.CalculatorOptions {
@ -27,8 +27,8 @@ message TensorsToEmbeddingsCalculatorOptions {
// The embedder options defining whether to L2-normalize or scalar-quantize
// the outputs.
optional mediapipe.tasks.components.proto.EmbedderOptions embedder_options =
1;
optional mediapipe.tasks.components.processors.proto.EmbedderOptions
embedder_options = 1;
// The embedder head names.
repeated string head_names = 2;

View File

@ -55,7 +55,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, FailsWithInvalidHeadNamesNumber) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings"
output_stream: "EMBEDDINGS:embeddings"
options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { head_names: "foo" }
}
@ -73,7 +73,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithoutHeadNames) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings"
output_stream: "EMBEDDINGS:embeddings"
options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: false quantize: false }
@ -84,28 +84,24 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithoutHeadNames) {
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run());
const EmbeddingResult& result = runner.Outputs()
.Get("EMBEDDING_RESULT", 0)
.packets[0]
.Get<EmbeddingResult>();
EXPECT_THAT(
result,
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
R"pb(embeddings {
entries { float_embedding { values: 0.1 values: 0.2 } }
head_index: 0
}
embeddings {
entries { float_embedding { values: -0.2 values: -0.3 } }
head_index: 1
})pb")));
const EmbeddingResult& result =
runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
R"pb(embeddings {
float_embedding { values: 0.1 values: 0.2 }
head_index: 0
}
embeddings {
float_embedding { values: -0.2 values: -0.3 }
head_index: 1
})pb")));
}
TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings"
output_stream: "EMBEDDINGS:embeddings"
options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: false quantize: false }
@ -118,30 +114,26 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) {
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run());
const EmbeddingResult& result = runner.Outputs()
.Get("EMBEDDING_RESULT", 0)
.packets[0]
.Get<EmbeddingResult>();
EXPECT_THAT(
result,
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
R"pb(embeddings {
entries { float_embedding { values: 0.1 values: 0.2 } }
head_index: 0
head_name: "foo"
}
embeddings {
entries { float_embedding { values: -0.2 values: -0.3 } }
head_index: 1
head_name: "bar"
})pb")));
const EmbeddingResult& result =
runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
R"pb(embeddings {
float_embedding { values: 0.1 values: 0.2 }
head_index: 0
head_name: "foo"
}
embeddings {
float_embedding { values: -0.2 values: -0.3 }
head_index: 1
head_name: "bar"
})pb")));
}
TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings"
output_stream: "EMBEDDINGS:embeddings"
options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: true quantize: false }
@ -152,23 +144,17 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) {
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run());
const EmbeddingResult& result = runner.Outputs()
.Get("EMBEDDING_RESULT", 0)
.packets[0]
.Get<EmbeddingResult>();
const EmbeddingResult& result =
runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
EXPECT_THAT(
result,
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
R"pb(embeddings {
entries {
float_embedding { values: 0.44721356 values: 0.8944271 }
}
float_embedding { values: 0.44721356 values: 0.8944271 }
head_index: 0
}
embeddings {
entries {
float_embedding { values: -0.5547002 values: -0.8320503 }
}
float_embedding { values: -0.5547002 values: -0.8320503 }
head_index: 1
})pb")));
}
@ -177,7 +163,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithQuantization) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings"
output_stream: "EMBEDDINGS:embeddings"
options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: false quantize: true }
@ -188,22 +174,16 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithQuantization) {
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run());
const EmbeddingResult& result = runner.Outputs()
.Get("EMBEDDING_RESULT", 0)
.packets[0]
.Get<EmbeddingResult>();
const EmbeddingResult& result =
runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
EXPECT_THAT(result,
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
R"pb(embeddings {
entries {
quantized_embedding { values: "\x0d\x1a" } # 13,26
}
quantized_embedding { values: "\x0d\x1a" } # 13,26
head_index: 0
}
embeddings {
entries {
quantized_embedding { values: "\xe6\xda" } # -26,-38
}
quantized_embedding { values: "\xe6\xda" } # -26,-38
head_index: 1
})pb")));
}
@ -213,7 +193,7 @@ TEST(TensorsToEmbeddingsCalculatorTest,
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings"
output_stream: "EMBEDDINGS:embeddings"
options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: true quantize: true }
@ -224,25 +204,18 @@ TEST(TensorsToEmbeddingsCalculatorTest,
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run());
const EmbeddingResult& result = runner.Outputs()
.Get("EMBEDDING_RESULT", 0)
.packets[0]
.Get<EmbeddingResult>();
EXPECT_THAT(
result,
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
R"pb(embeddings {
entries {
quantized_embedding { values: "\x39\x72" } # 57,114
}
head_index: 0
}
embeddings {
entries {
quantized_embedding { values: "\xb9\x95" } # -71,-107
}
head_index: 1
})pb")));
const EmbeddingResult& result =
runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
EXPECT_THAT(result,
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
R"pb(embeddings {
quantized_embedding { values: "\x39\x72" } # 57,114
head_index: 0
}
embeddings {
quantized_embedding { values: "\xb9\x95" } # -71,-107
head_index: 1
})pb")));
}
} // namespace

View File

@ -49,3 +49,12 @@ cc_library(
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
],
)
# C++ struct counterparts of the Embedding / EmbeddingResult protos, together
# with the helpers that convert from the protos to these structs.
cc_library(
name = "embedding_result",
srcs = ["embedding_result.cc"],
hdrs = ["embedding_result.h"],
deps = [
# Proto definitions consumed by the conversion helpers.
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
],
)

View File

@ -0,0 +1,57 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include <iterator>
#include <optional>
#include <string>
#include <vector>
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
namespace mediapipe::tasks::components::containers {
// Converts an Embedding proto into its C++ struct counterpart.
//
// The floating-point values are copied when the proto carries a float
// embedding; otherwise the quantized bytes are copied. The head index is always
// copied, and the head name only when it is set in the proto.
Embedding ConvertToEmbedding(const proto::Embedding& proto) {
  Embedding embedding;
  if (proto.has_float_embedding()) {
    // `proto` is const, so its values can only be copied: plain iterators are
    // used rather than move iterators, which would degrade to copies anyway.
    const auto& values = proto.float_embedding().values();
    embedding.float_embedding.assign(values.begin(), values.end());
  } else {
    embedding.quantized_embedding = proto.quantized_embedding().values();
  }
  embedding.head_index = proto.head_index();
  if (proto.has_head_name()) {
    embedding.head_name = proto.head_name();
  }
  return embedding;
}
// Converts an EmbeddingResult proto into its C++ struct counterpart: every
// embedding of the proto is converted in order, and the timestamp is copied
// only when the proto has one.
EmbeddingResult ConvertToEmbeddingResult(const proto::EmbeddingResult& proto) {
  EmbeddingResult result;
  const int num_embeddings = proto.embeddings_size();
  result.embeddings.reserve(num_embeddings);
  for (int i = 0; i < num_embeddings; ++i) {
    result.embeddings.push_back(ConvertToEmbedding(proto.embeddings(i)));
  }
  if (proto.has_timestamp_ms()) {
    result.timestamp_ms = proto.timestamp_ms();
  }
  return result;
}
} // namespace mediapipe::tasks::components::containers

View File

@ -0,0 +1,72 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
#include <optional>
#include <string>
#include <vector>
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
namespace mediapipe::tasks::components::containers {
// Embedding result for a given embedder head.
//
// Exactly one of 'float_embedding' and 'quantized_embedding' contains data,
// depending on whether or not the embedder was configured to perform scalar
// quantization.
struct Embedding {
  // Floating-point embedding. Empty if the embedder was configured to perform
  // scalar quantization.
  std::vector<float> float_embedding;
  // Scalar-quantized embedding. Empty if the embedder was not configured to
  // perform scalar quantization.
  std::string quantized_embedding;
  // The index of the embedder head (i.e. output tensor) this embedding comes
  // from. This is useful for multi-head models.
  //
  // Explicitly defaulted so that a default-initialized Embedding never exposes
  // an indeterminate value.
  int head_index = 0;
  // The optional name of the embedder head, as provided in the TFLite Model
  // Metadata [1] if present. This is useful for multi-head models.
  //
  // [1]: https://www.tensorflow.org/lite/convert/metadata
  std::optional<std::string> head_name = std::nullopt;
};
// Embedding results produced by a model.
struct EmbeddingResult {
  // One embedding result per head of the model.
  std::vector<Embedding> embeddings;
  // Timestamp (in milliseconds) of the start of the chunk of data these
  // results correspond to. Unset unless explicitly provided.
  //
  // Only relevant for embedding extraction on time series (e.g. audio
  // embedding), where the data to process may exceed the maximum size the
  // model can handle: the input is then split into multiple chunks starting
  // at different timestamps.
  std::optional<int64_t> timestamp_ms;
};
// Utility function to convert from Embedding proto to Embedding struct.
//
// Copies the float or quantized embedding values, the head index and, when it
// is set in the proto, the head name.
Embedding ConvertToEmbedding(const proto::Embedding& proto);
// Utility function to convert from EmbeddingResult proto to EmbeddingResult
// struct.
//
// Converts every embedding of the proto in order and copies the timestamp when
// it is set in the proto.
EmbeddingResult ConvertToEmbeddingResult(const proto::EmbeddingResult& proto);
} // namespace mediapipe::tasks::components::containers
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_

View File

@ -30,30 +30,31 @@ message QuantizedEmbedding {
optional bytes values = 1;
}
// Floating-point or scalar-quantized embedding with an optional timestamp.
message EmbeddingEntry {
// The actual embedding, either floating-point or scalar-quantized.
// Embedding result for a given embedder head.
message Embedding {
// The actual embedding, either floating-point or quantized.
oneof embedding {
FloatEmbedding float_embedding = 1;
QuantizedEmbedding quantized_embedding = 2;
}
// The optional timestamp (in milliseconds) associated to the embedding entry.
// This is useful for time series use cases, e.g. audio embedding.
optional int64 timestamp_ms = 3;
}
// Embeddings for a given embedder head.
message Embeddings {
repeated EmbeddingEntry entries = 1;
// The index of the embedder head that produced this embedding. This is useful
// for multi-head models.
optional int32 head_index = 2;
optional int32 head_index = 3;
// The name of the embedder head, which is the corresponding tensor metadata
// name (if any). This is useful for multi-head models.
optional string head_name = 3;
optional string head_name = 4;
}
// Contains one set of results per embedder head.
// Embedding results for a given embedder model.
message EmbeddingResult {
repeated Embeddings embeddings = 1;
// The embedding results for each model head, i.e. one for each output tensor.
repeated Embedding embeddings = 1;
// The optional timestamp (in milliseconds) of the start of the chunk of data
// corresponding to these results.
//
// This is only used for embedding extraction on time series (e.g. audio
// embedding). In these use cases, the amount of data to process might
// exceed the maximum size that the model can process: to solve this, the
// input data is split into multiple chunks starting at different timestamps.
optional int64 timestamp_ms = 2;
}

View File

@ -62,3 +62,38 @@ cc_library(
],
alwayslink = 1,
)
# C++ EmbedderOptions struct and its conversion to the EmbedderOptions proto.
cc_library(
name = "embedder_options",
srcs = ["embedder_options.cc"],
hdrs = ["embedder_options.h"],
deps = ["//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto"],
)
# Subgraph converting embedder output tensors into EmbeddingResult protos, plus
# the ConfigureEmbeddingPostprocessing helper used to set up its options.
cc_library(
name = "embedding_postprocessing_graph",
srcs = ["embedding_postprocessing_graph.cc"],
hdrs = ["embedding_postprocessing_graph.h"],
deps = [
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/tool:options_map",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
],
# The .cc file registers the graph through REGISTER_MEDIAPIPE_GRAPH;
# presumably alwayslink keeps that registration from being dropped by the
# linker when nothing references the library's symbols directly. TODO confirm.
alwayslink = 1,
)

View File

@ -13,22 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/embedder_options.h"
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {
tasks::components::proto::EmbedderOptions ConvertEmbedderOptionsToProto(
proto::EmbedderOptions ConvertEmbedderOptionsToProto(
EmbedderOptions* embedder_options) {
tasks::components::proto::EmbedderOptions options_proto;
proto::EmbedderOptions options_proto;
options_proto.set_l2_normalize(embedder_options->l2_normalize);
options_proto.set_quantize(embedder_options->quantize);
return options_proto;
}
} // namespace processors
} // namespace components
} // namespace tasks
} // namespace mediapipe

View File

@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {
// Embedder options for MediaPipe C++ embedding extraction tasks.
struct EmbedderOptions {
@ -37,11 +38,12 @@ struct EmbedderOptions {
bool quantize;
};
tasks::components::proto::EmbedderOptions ConvertEmbedderOptionsToProto(
proto::EmbedderOptions ConvertEmbedderOptionsToProto(
EmbedderOptions* embedder_options);
} // namespace processors
} // namespace components
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include <string>
#include <vector>
@ -29,8 +29,8 @@ limitations under the License.
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
@ -39,6 +39,7 @@ limitations under the License.
namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {
namespace {
@ -49,13 +50,12 @@ using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::components::proto::EmbedderOptions;
using ::mediapipe::tasks::core::ModelResources;
using TensorsSource =
::mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT";
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
// Identifies whether or not the model has quantized outputs, and performs
// sanity checks.
@ -144,7 +144,7 @@ absl::StatusOr<std::vector<std::string>> GetHeadNames(
absl::Status ConfigureEmbeddingPostprocessing(
const ModelResources& model_resources,
const EmbedderOptions& embedder_options,
const proto::EmbedderOptions& embedder_options,
proto::EmbeddingPostprocessingGraphOptions* options) {
ASSIGN_OR_RETURN(bool has_quantized_outputs,
HasQuantizedOutputs(model_resources));
@ -188,7 +188,7 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
BuildEmbeddingPostprocessing(
sc->Options<proto::EmbeddingPostprocessingGraphOptions>(),
graph[Input<std::vector<Tensor>>(kTensorsTag)], graph));
embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingResultTag)];
embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
return graph.GetConfig();
}
@ -220,13 +220,13 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
.GetOptions<mediapipe::TensorsToEmbeddingsCalculatorOptions>()
.CopyFrom(options.tensors_to_embeddings_options());
dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag);
return tensors_to_embeddings_node[Output<EmbeddingResult>(
kEmbeddingResultTag)];
return tensors_to_embeddings_node[Output<EmbeddingResult>(kEmbeddingsTag)];
}
};
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::components::EmbeddingPostprocessingGraph);
::mediapipe::tasks::components::processors::EmbeddingPostprocessingGraph);
} // namespace processors
} // namespace components
} // namespace tasks
} // namespace mediapipe

View File

@ -13,17 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_
#include "absl/status/status.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {
// Configures an EmbeddingPostprocessingGraph using the provided model resources
// and EmbedderOptions.
@ -44,18 +45,19 @@ namespace components {
// The output tensors of an InferenceCalculator, to convert into
// EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
// Outputs:
// EMBEDDING_RESULT - EmbeddingResult
// EMBEDDINGS - EmbeddingResult
// The output EmbeddingResult.
//
// TODO: add support for additional optional "TIMESTAMPS" input for
// embeddings aggregation.
absl::Status ConfigureEmbeddingPostprocessing(
const tasks::core::ModelResources& model_resources,
const tasks::components::proto::EmbedderOptions& embedder_options,
const proto::EmbedderOptions& embedder_options,
proto::EmbeddingPostprocessingGraphOptions* options);
} // namespace processors
} // namespace components
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include <memory>
@ -25,8 +25,8 @@ limitations under the License.
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
@ -34,12 +34,10 @@ limitations under the License.
namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {
namespace {
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::proto::EmbedderOptions;
using ::mediapipe::tasks::components::proto::
EmbeddingPostprocessingGraphOptions;
using ::mediapipe::tasks::core::ModelResources;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
@ -69,68 +67,72 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
EmbedderOptions options_in;
proto::EmbedderOptions options_in;
options_in.set_l2_normalize(true);
EmbeddingPostprocessingGraphOptions options_out;
proto::EmbeddingPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
&options_out));
EXPECT_THAT(
options_out,
EqualsProto(ParseTextProtoOrDie<EmbeddingPostprocessingGraphOptions>(
R"pb(tensors_to_embeddings_options {
embedder_options { l2_normalize: true }
head_names: "probability"
}
has_quantized_outputs: true)pb")));
EqualsProto(
ParseTextProtoOrDie<proto::EmbeddingPostprocessingGraphOptions>(
R"pb(tensors_to_embeddings_options {
embedder_options { l2_normalize: true }
head_names: "probability"
}
has_quantized_outputs: true)pb")));
}
// Configures embedding postprocessing for the
// kQuantizedImageClassifierWithoutMetadata model with `quantize` enabled, and
// expects the resulting options to carry the embedder options through, list no
// head names, and report quantized outputs.
TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
EmbedderOptions options_in;
proto::EmbedderOptions options_in;
options_in.set_quantize(true);
EmbeddingPostprocessingGraphOptions options_out;
proto::EmbeddingPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
&options_out));
EXPECT_THAT(
options_out,
EqualsProto(ParseTextProtoOrDie<EmbeddingPostprocessingGraphOptions>(
R"pb(tensors_to_embeddings_options {
embedder_options { quantize: true }
}
has_quantized_outputs: true)pb")));
EqualsProto(
ParseTextProtoOrDie<proto::EmbeddingPostprocessingGraphOptions>(
R"pb(tensors_to_embeddings_options {
embedder_options { quantize: true }
}
has_quantized_outputs: true)pb")));
}
// Configures embedding postprocessing for the kMobileNetV3Embedder model with
// both `quantize` and `l2_normalize` enabled, and expects the resulting options
// to carry both embedder options through, list the "feature" head name, and
// report non-quantized outputs.
TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
CreateModelResourcesForModel(kMobileNetV3Embedder));
EmbedderOptions options_in;
proto::EmbedderOptions options_in;
options_in.set_quantize(true);
options_in.set_l2_normalize(true);
EmbeddingPostprocessingGraphOptions options_out;
proto::EmbeddingPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
&options_out));
EXPECT_THAT(
options_out,
EqualsProto(ParseTextProtoOrDie<EmbeddingPostprocessingGraphOptions>(
R"pb(tensors_to_embeddings_options {
embedder_options { quantize: true l2_normalize: true }
head_names: "feature"
}
has_quantized_outputs: false)pb")));
EqualsProto(
ParseTextProtoOrDie<proto::EmbeddingPostprocessingGraphOptions>(
R"pb(tensors_to_embeddings_options {
embedder_options { quantize: true l2_normalize: true }
head_names: "feature"
}
has_quantized_outputs: false)pb")));
}
// TODO: add E2E Postprocessing tests once timestamp aggregation is
// supported.
} // namespace
} // namespace processors
} // namespace components
} // namespace tasks
} // namespace mediapipe

View File

@ -34,3 +34,18 @@ mediapipe_proto_library(
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
],
)
# Proto library generated from embedder_options.proto; it declares no deps.
mediapipe_proto_library(
name = "embedder_options_proto",
srcs = ["embedder_options.proto"],
)
# Proto library generated from embedding_postprocessing_graph_options.proto.
# Depends on the framework calculator protos and on the
# tensors_to_embeddings_calculator proto.
mediapipe_proto_library(
name = "embedding_postprocessing_graph_options_proto",
srcs = ["embedding_postprocessing_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto",
],
)

View File

@ -15,7 +15,10 @@ limitations under the License.
syntax = "proto2";
package mediapipe.tasks.components.proto;
package mediapipe.tasks.components.processors.proto;
option java_package = "com.google.mediapipe.tasks.components.processors.proto";
option java_outer_classname = "EmbedderOptionsProto";
// Shared options used by all embedding extraction tasks.
message EmbedderOptions {

View File

@ -15,7 +15,7 @@ limitations under the License.
syntax = "proto2";
package mediapipe.tasks.components.proto;
package mediapipe.tasks.components.processors.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto";

View File

@ -23,21 +23,6 @@ mediapipe_proto_library(
srcs = ["segmenter_options.proto"],
)
mediapipe_proto_library(
name = "embedder_options_proto",
srcs = ["embedder_options.proto"],
)
mediapipe_proto_library(
name = "embedding_postprocessing_graph_options_proto",
srcs = ["embedding_postprocessing_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto",
],
)
mediapipe_proto_library(
name = "text_preprocessing_graph_options_proto",
srcs = ["text_preprocessing_graph_options.proto"],

View File

@ -26,7 +26,7 @@ cc_library(
hdrs = ["cosine_similarity.h"],
deps = [
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/containers:embedding_result",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
@ -39,7 +39,7 @@ cc_test(
deps = [
":cosine_similarity",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/containers:embedding_result",
],
)

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
namespace mediapipe {
namespace tasks {
@ -30,7 +30,7 @@ namespace utils {
namespace {
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry;
using ::mediapipe::tasks::components::containers::Embedding;
template <typename T>
absl::StatusOr<double> ComputeCosineSimilarity(const T& u, const T& v,
@ -66,39 +66,35 @@ absl::StatusOr<double> ComputeCosineSimilarity(const T& u, const T& v,
// an L2-norm of 0.
//
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
absl::StatusOr<double> CosineSimilarity(const EmbeddingEntry& u,
const EmbeddingEntry& v) {
if (u.has_float_embedding() && v.has_float_embedding()) {
if (u.float_embedding().values().size() !=
v.float_embedding().values().size()) {
absl::StatusOr<double> CosineSimilarity(const Embedding& u,
const Embedding& v) {
if (!u.float_embedding.empty() && !v.float_embedding.empty()) {
if (u.float_embedding.size() != v.float_embedding.size()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat("Cannot compute cosine similarity between embeddings "
"of different sizes (%d vs. %d)",
u.float_embedding().values().size(),
v.float_embedding().values().size()),
u.float_embedding.size(), v.float_embedding.size()),
MediaPipeTasksStatus::kInvalidArgumentError);
}
return ComputeCosineSimilarity(u.float_embedding().values().data(),
v.float_embedding().values().data(),
u.float_embedding().values().size());
return ComputeCosineSimilarity(u.float_embedding.data(),
v.float_embedding.data(),
u.float_embedding.size());
}
if (u.has_quantized_embedding() && v.has_quantized_embedding()) {
if (u.quantized_embedding().values().size() !=
v.quantized_embedding().values().size()) {
if (!u.quantized_embedding.empty() && !v.quantized_embedding.empty()) {
if (u.quantized_embedding.size() != v.quantized_embedding.size()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat("Cannot compute cosine similarity between embeddings "
"of different sizes (%d vs. %d)",
u.quantized_embedding().values().size(),
v.quantized_embedding().values().size()),
u.quantized_embedding.size(),
v.quantized_embedding.size()),
MediaPipeTasksStatus::kInvalidArgumentError);
}
return ComputeCosineSimilarity(reinterpret_cast<const int8_t*>(
u.quantized_embedding().values().data()),
reinterpret_cast<const int8_t*>(
v.quantized_embedding().values().data()),
u.quantized_embedding().values().size());
return ComputeCosineSimilarity(
reinterpret_cast<const int8_t*>(u.quantized_embedding.data()),
reinterpret_cast<const int8_t*>(v.quantized_embedding.data()),
u.quantized_embedding.size());
}
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,

View File

@ -17,22 +17,20 @@ limitations under the License.
#define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_COSINE_SIMILARITY_H_
#include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
namespace mediapipe {
namespace tasks {
namespace components {
namespace utils {
// Utility function to compute cosine similarity [1] between two embedding
// entries. May return an InvalidArgumentError if e.g. the feature vectors are
// of different types (quantized vs. float), have different sizes, or have
// an L2-norm of 0.
// Utility function to compute cosine similarity [1] between two embeddings. May
// return an InvalidArgumentError if e.g. the embeddings are of different types
// (quantized vs. float), have different sizes, or have an L2-norm of 0.
//
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
absl::StatusOr<double> CosineSimilarity(
const containers::proto::EmbeddingEntry& u,
const containers::proto::EmbeddingEntry& v);
absl::StatusOr<double> CosineSimilarity(const containers::Embedding& u,
const containers::Embedding& v);
} // namespace utils
} // namespace components

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
namespace mediapipe {
namespace tasks {
@ -30,29 +30,27 @@ namespace components {
namespace utils {
namespace {
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry;
using ::mediapipe::tasks::components::containers::Embedding;
using ::testing::HasSubstr;
// Helper function to generate float EmbeddingEntry.
EmbeddingEntry BuildFloatEntry(std::vector<float> values) {
EmbeddingEntry entry;
for (const float value : values) {
entry.mutable_float_embedding()->add_values(value);
}
return entry;
// Test helper: returns an Embedding whose float_embedding is set from `values`.
Embedding BuildFloatEmbedding(std::vector<float> values) {
  Embedding result;
  result.float_embedding = values;
  return result;
}
// Helper function to generate quantized EmbeddingEntry.
EmbeddingEntry BuildQuantizedEntry(std::vector<int8_t> values) {
EmbeddingEntry entry;
entry.mutable_quantized_embedding()->set_values(
reinterpret_cast<uint8_t*>(values.data()), values.size());
return entry;
// Test helper: returns an Embedding whose quantized_embedding is filled from
// the bytes of `values`, viewed through a uint8_t pointer.
Embedding BuildQuantizedEmbedding(std::vector<int8_t> values) {
  // The signed input bytes are reinterpreted as unsigned; the bit patterns
  // are left untouched.
  uint8_t* first = reinterpret_cast<uint8_t*>(values.data());
  uint8_t* last = first + values.size();
  Embedding result;
  result.quantized_embedding = {first, last};
  return result;
}
TEST(CosineSimilarity, FailsWithQuantizedAndFloatEmbeddings) {
auto u = BuildFloatEntry({0.1, 0.2});
auto v = BuildQuantizedEntry({0, 1});
auto u = BuildFloatEmbedding({0.1, 0.2});
auto v = BuildQuantizedEmbedding({0, 1});
auto status = CosineSimilarity(u, v);
@ -63,8 +61,8 @@ TEST(CosineSimilarity, FailsWithQuantizedAndFloatEmbeddings) {
}
TEST(CosineSimilarity, FailsWithZeroNorm) {
auto u = BuildFloatEntry({0.1, 0.2});
auto v = BuildFloatEntry({0.0, 0.0});
auto u = BuildFloatEmbedding({0.1, 0.2});
auto v = BuildFloatEmbedding({0.0, 0.0});
auto status = CosineSimilarity(u, v);
@ -75,8 +73,8 @@ TEST(CosineSimilarity, FailsWithZeroNorm) {
}
TEST(CosineSimilarity, FailsWithDifferentSizes) {
auto u = BuildFloatEntry({0.1, 0.2});
auto v = BuildFloatEntry({0.1, 0.2, 0.3});
auto u = BuildFloatEmbedding({0.1, 0.2});
auto v = BuildFloatEmbedding({0.1, 0.2, 0.3});
auto status = CosineSimilarity(u, v);
@ -87,8 +85,8 @@ TEST(CosineSimilarity, FailsWithDifferentSizes) {
}
TEST(CosineSimilarity, SucceedsWithFloatEntries) {
auto u = BuildFloatEntry({1.0, 0.0, 0.0, 0.0});
auto v = BuildFloatEntry({0.5, 0.5, 0.5, 0.5});
auto u = BuildFloatEmbedding({1.0, 0.0, 0.0, 0.0});
auto v = BuildFloatEmbedding({0.5, 0.5, 0.5, 0.5});
MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v));
@ -96,8 +94,8 @@ TEST(CosineSimilarity, SucceedsWithFloatEntries) {
}
TEST(CosineSimilarity, SucceedsWithQuantizedEntries) {
auto u = BuildQuantizedEntry({127, 0, 0, 0});
auto v = BuildQuantizedEntry({-128, 0, 0, 0});
auto u = BuildQuantizedEmbedding({127, 0, 0, 0});
auto v = BuildQuantizedEmbedding({-128, 0, 0, 0});
MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v));

View File

@ -26,12 +26,12 @@ cc_library(
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components:embedding_postprocessing_graph",
"//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto",
"@com_google_absl//absl/status:statusor",
@ -49,9 +49,10 @@ cc_library(
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/tool:options_map",
"//mediapipe/tasks/cc/components:embedder_options",
"//mediapipe/tasks/cc/components/containers:embedding_result",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto",
"//mediapipe/tasks/cc/components/processors:embedder_options",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:cosine_similarity",
"//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:task_runner",

View File

@ -21,9 +21,10 @@ limitations under the License.
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/tool/options_map.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/embedder_options.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
@ -41,8 +42,8 @@ namespace image_embedder {
namespace {
constexpr char kEmbeddingResultStreamName[] = "embedding_result_out";
constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT";
constexpr char kEmbeddingsStreamName[] = "embeddings_out";
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kImageTag[] = "IMAGE";
@ -53,7 +54,7 @@ constexpr char kGraphTypeName[] =
"mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000;
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry;
using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::core::PacketMap;
using ::mediapipe::tasks::vision::image_embedder::proto::
@ -71,13 +72,13 @@ CalculatorGraphConfig CreateGraphConfig(
graph.In(kNormRectTag).SetName(kNormRectStreamName);
auto& task_graph = graph.AddNode(kGraphTypeName);
task_graph.GetOptions<ImageEmbedderGraphOptions>().Swap(options_proto.get());
task_graph.Out(kEmbeddingResultTag).SetName(kEmbeddingResultStreamName) >>
graph.Out(kEmbeddingResultTag);
task_graph.Out(kEmbeddingsTag).SetName(kEmbeddingsStreamName) >>
graph.Out(kEmbeddingsTag);
task_graph.Out(kImageTag).SetName(kImageOutStreamName) >>
graph.Out(kImageTag);
if (enable_flow_limiting) {
return tasks::core::AddFlowLimiterCalculator(
graph, task_graph, {kImageTag, kNormRectTag}, kEmbeddingResultTag);
graph, task_graph, {kImageTag, kNormRectTag}, kEmbeddingsTag);
}
graph.In(kImageTag) >> task_graph.In(kImageTag);
graph.In(kNormRectTag) >> task_graph.In(kNormRectTag);
@ -95,8 +96,8 @@ std::unique_ptr<ImageEmbedderGraphOptions> ConvertImageEmbedderOptionsToProto(
options_proto->mutable_base_options()->set_use_stream_mode(
options->running_mode != core::RunningMode::IMAGE);
auto embedder_options_proto =
std::make_unique<tasks::components::proto::EmbedderOptions>(
components::ConvertEmbedderOptionsToProto(
std::make_unique<components::processors::proto::EmbedderOptions>(
components::processors::ConvertEmbedderOptionsToProto(
&(options->embedder_options)));
options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get());
return options_proto;
@ -121,9 +122,10 @@ absl::StatusOr<std::unique_ptr<ImageEmbedder>> ImageEmbedder::Create(
return;
}
Packet embedding_result_packet =
status_or_packets.value()[kEmbeddingResultStreamName];
status_or_packets.value()[kEmbeddingsStreamName];
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
result_callback(embedding_result_packet.Get<EmbeddingResult>(),
result_callback(ConvertToEmbeddingResult(
embedding_result_packet.Get<EmbeddingResult>()),
image_packet.Get<Image>(),
embedding_result_packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond);
@ -138,7 +140,7 @@ absl::StatusOr<std::unique_ptr<ImageEmbedder>> ImageEmbedder::Create(
std::move(packets_callback));
}
absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
absl::StatusOr<ImageEmbedderResult> ImageEmbedder::Embed(
Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
@ -155,10 +157,11 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
{{kImageInStreamName, MakePacket<Image>(std::move(image))},
{kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))}}));
return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>();
return ConvertToEmbeddingResult(
output_packets[kEmbeddingsStreamName].Get<EmbeddingResult>());
}
absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
absl::StatusOr<ImageEmbedderResult> ImageEmbedder::EmbedForVideo(
Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
@ -178,7 +181,8 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
{kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>();
return ConvertToEmbeddingResult(
output_packets[kEmbeddingsStreamName].Get<EmbeddingResult>());
}
absl::Status ImageEmbedder::EmbedAsync(
@ -202,7 +206,8 @@ absl::Status ImageEmbedder::EmbedAsync(
}
// Computes the cosine similarity between `u` and `v` by forwarding both
// arguments unchanged to components::utils::CosineSimilarity and returning its
// result.
absl::StatusOr<double> ImageEmbedder::CosineSimilarity(
const EmbeddingEntry& u, const EmbeddingEntry& v) {
const components::containers::Embedding& u,
const components::containers::Embedding& v) {
return components::utils::CosineSimilarity(u, v);
}

View File

@ -21,8 +21,8 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/embedder_options.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
@ -33,6 +33,10 @@ namespace tasks {
namespace vision {
namespace image_embedder {
// Alias the shared EmbeddingResult struct as result type.
using ImageEmbedderResult =
::mediapipe::tasks::components::containers::EmbeddingResult;
// The options for configuring a MediaPipe image embedder task.
struct ImageEmbedderOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model
@ -50,14 +54,12 @@ struct ImageEmbedderOptions {
// Options for configuring the embedder behavior, such as L2-normalization or
// scalar-quantization.
components::EmbedderOptions embedder_options;
components::processors::EmbedderOptions embedder_options;
// The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM.
std::function<void(
absl::StatusOr<components::containers::proto::EmbeddingResult>,
const Image&, int64)>
std::function<void(absl::StatusOr<ImageEmbedderResult>, const Image&, int64)>
result_callback = nullptr;
};
@ -104,7 +106,7 @@ class ImageEmbedder : core::BaseVisionTaskApi {
// running mode.
//
// The image can be of any size with format RGB or RGBA.
absl::StatusOr<components::containers::proto::EmbeddingResult> Embed(
absl::StatusOr<ImageEmbedderResult> Embed(
mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
@ -127,7 +129,7 @@ class ImageEmbedder : core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing.
absl::StatusOr<components::containers::proto::EmbeddingResult> EmbedForVideo(
absl::StatusOr<ImageEmbedderResult> EmbedForVideo(
mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
@ -168,15 +170,15 @@ class ImageEmbedder : core::BaseVisionTaskApi {
// Shuts down the ImageEmbedder when all works are done.
absl::Status Close() { return runner_->Close(); }
// Utility function to compute cosine similarity [1] between two embedding
// entries. May return an InvalidArgumentError if e.g. the feature vectors are
// of different types (quantized vs. float), have different sizes, or have
// an L2-norm of 0.
// Utility function to compute cosine similarity [1] between two embeddings.
// May return an InvalidArgumentError if e.g. the embeddings are of different
// types (quantized vs. float), have different sizes, or have an L2-norm of
// 0.
//
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
static absl::StatusOr<double> CosineSimilarity(
const components::containers::proto::EmbeddingEntry& u,
const components::containers::proto::EmbeddingEntry& v);
const components::containers::Embedding& u,
const components::containers::Embedding& v);
};
} // namespace image_embedder

View File

@ -20,10 +20,10 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h"
@ -40,10 +40,8 @@ using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::components::proto::
EmbeddingPostprocessingGraphOptions;
constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT";
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kTensorsTag[] = "TENSORS";
@ -67,7 +65,7 @@ struct ImageEmbedderOutputStreams {
// Describes region of image to perform embedding extraction on.
// @Optional: rect covering the whole image is used if not specified.
// Outputs:
// EMBEDDING_RESULT - EmbeddingResult
// EMBEDDINGS - EmbeddingResult
// The embedding result.
// IMAGE - Image
// The image that embedding extraction runs on.
@ -76,7 +74,7 @@ struct ImageEmbedderOutputStreams {
// node {
// calculator: "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"
// input_stream: "IMAGE:image_in"
// output_stream: "EMBEDDING_RESULT:embedding_result_out"
// output_stream: "EMBEDDINGS:embedding_result_out"
// output_stream: "IMAGE:image_out"
// options {
// [mediapipe.tasks.vision.image_embedder.proto.ImageEmbedderOptions.ext]
@ -107,7 +105,7 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
output_streams.embedding_result >>
graph[Output<EmbeddingResult>(kEmbeddingResultTag)];
graph[Output<EmbeddingResult>(kEmbeddingsTag)];
output_streams.image >> graph[Output<Image>(kImageTag)];
return graph.GetConfig();
}
@ -152,16 +150,17 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
// Adds postprocessing calculators and connects its input stream to the
// inference results.
auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.EmbeddingPostprocessingGraph");
MP_RETURN_IF_ERROR(components::ConfigureEmbeddingPostprocessing(
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
model_resources, task_options.embedder_options(),
&postprocessing.GetOptions<EmbeddingPostprocessingGraphOptions>()));
&postprocessing.GetOptions<components::processors::proto::
EmbeddingPostprocessingGraphOptions>()));
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Outputs the embedding results.
return ImageEmbedderOutputStreams{
/*embedding_result=*/postprocessing[Output<EmbeddingResult>(
kEmbeddingResultTag)],
kEmbeddingsTag)],
/*image=*/preprocessing[Output<Image>(kImageTag)]};
}
};

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/api/op_resolver.h"
@ -42,7 +42,6 @@ namespace {
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr;
using ::testing::Optional;
@ -54,18 +53,14 @@ constexpr double kSimilarityTolerancy = 1e-6;
// Utility function to check the sizes, head_index and head_names of a result
// produced by kMobileNetV3Embedder.
//
// Expects exactly one embedding, with head index 0 and head name "feature".
// When `quantized` is true the quantized embedding must hold 1024 values;
// otherwise the float embedding must hold 1024 values.
void CheckMobileNetV3Result(const EmbeddingResult& result, bool quantized) {
EXPECT_EQ(result.embeddings().size(), 1);
EXPECT_EQ(result.embeddings(0).head_index(), 0);
EXPECT_EQ(result.embeddings(0).head_name(), "feature");
EXPECT_EQ(result.embeddings(0).entries().size(), 1);
void CheckMobileNetV3Result(const ImageEmbedderResult& result, bool quantized) {
EXPECT_EQ(result.embeddings.size(), 1);
EXPECT_EQ(result.embeddings[0].head_index, 0);
EXPECT_EQ(result.embeddings[0].head_name, "feature");
if (quantized) {
EXPECT_EQ(
result.embeddings(0).entries(0).quantized_embedding().values().size(),
1024);
EXPECT_EQ(result.embeddings[0].quantized_embedding.size(), 1024);
} else {
EXPECT_EQ(result.embeddings(0).entries(0).float_embedding().values().size(),
1024);
EXPECT_EQ(result.embeddings[0].float_embedding.size(), 1024);
}
}
@ -154,7 +149,7 @@ TEST_F(CreateTest, FailsWithIllegalCallbackInImageOrVideoMode) {
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
options->running_mode = running_mode;
options->result_callback = [](absl::StatusOr<EmbeddingResult>,
options->result_callback = [](absl::StatusOr<ImageEmbedderResult>,
const Image& image, int64 timestamp_ms) {};
auto image_embedder = ImageEmbedder::Create(std::move(options));
@ -231,19 +226,18 @@ TEST_F(ImageModeTest, SucceedsWithoutL2Normalization) {
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
// Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
image_embedder->Embed(image));
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
image_embedder->Embed(crop));
// Check results.
CheckMobileNetV3Result(image_result, false);
CheckMobileNetV3Result(crop_result, false);
// CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity,
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
crop_result.embeddings(0).entries(0)));
MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
image_result.embeddings[0],
crop_result.embeddings[0]));
double expected_similarity = 0.925519;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}
@ -264,19 +258,18 @@ TEST_F(ImageModeTest, SucceedsWithL2Normalization) {
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
// Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
image_embedder->Embed(image));
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
image_embedder->Embed(crop));
// Check results.
CheckMobileNetV3Result(image_result, false);
CheckMobileNetV3Result(crop_result, false);
// CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity,
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
crop_result.embeddings(0).entries(0)));
MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
image_result.embeddings[0],
crop_result.embeddings[0]));
double expected_similarity = 0.925519;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}
@ -297,19 +290,18 @@ TEST_F(ImageModeTest, SucceedsWithQuantization) {
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
// Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
image_embedder->Embed(image));
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
image_embedder->Embed(crop));
// Check results.
CheckMobileNetV3Result(image_result, true);
CheckMobileNetV3Result(crop_result, true);
// CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity,
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
crop_result.embeddings(0).entries(0)));
MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
image_result.embeddings[0],
crop_result.embeddings[0]));
double expected_similarity = 0.926791;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}
@ -333,19 +325,18 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
// Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(
const EmbeddingResult& image_result,
const ImageEmbedderResult& image_result,
image_embedder->Embed(image, image_processing_options));
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
image_embedder->Embed(crop));
// Check results.
CheckMobileNetV3Result(image_result, false);
CheckMobileNetV3Result(crop_result, false);
// CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity,
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
crop_result.embeddings(0).entries(0)));
MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
image_result.embeddings[0],
crop_result.embeddings[0]));
double expected_similarity = 0.999931;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}
@ -367,20 +358,19 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
image_processing_options.rotation_degrees = -90;
// Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
image_embedder->Embed(image));
MP_ASSERT_OK_AND_ASSIGN(
const EmbeddingResult& rotated_result,
const ImageEmbedderResult& rotated_result,
image_embedder->Embed(rotated, image_processing_options));
// Check results.
CheckMobileNetV3Result(image_result, false);
CheckMobileNetV3Result(rotated_result, false);
// CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity,
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
rotated_result.embeddings(0).entries(0)));
MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
image_result.embeddings[0],
rotated_result.embeddings[0]));
double expected_similarity = 0.572265;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}
@ -403,20 +393,19 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
/*rotation_degrees=*/-90};
// Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
image_embedder->Embed(crop));
MP_ASSERT_OK_AND_ASSIGN(
const EmbeddingResult& rotated_result,
const ImageEmbedderResult& rotated_result,
image_embedder->Embed(rotated, image_processing_options));
// Check results.
CheckMobileNetV3Result(crop_result, false);
CheckMobileNetV3Result(rotated_result, false);
// CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity,
ImageEmbedder::CosineSimilarity(crop_result.embeddings(0).entries(0),
rotated_result.embeddings(0).entries(0)));
MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
crop_result.embeddings[0],
rotated_result.embeddings[0]));
double expected_similarity = 0.62838;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}
@ -487,16 +476,16 @@ TEST_F(VideoModeTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
ImageEmbedder::Create(std::move(options)));
EmbeddingResult previous_results;
ImageEmbedderResult previous_results;
for (int i = 0; i < iterations; ++i) {
MP_ASSERT_OK_AND_ASSIGN(auto results,
image_embedder->EmbedForVideo(image, i));
CheckMobileNetV3Result(results, false);
if (i > 0) {
MP_ASSERT_OK_AND_ASSIGN(double similarity,
ImageEmbedder::CosineSimilarity(
results.embeddings(0).entries(0),
previous_results.embeddings(0).entries(0)));
MP_ASSERT_OK_AND_ASSIGN(
double similarity,
ImageEmbedder::CosineSimilarity(results.embeddings[0],
previous_results.embeddings[0]));
double expected_similarity = 1.000000;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}
@ -515,7 +504,7 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback = [](absl::StatusOr<EmbeddingResult>,
options->result_callback = [](absl::StatusOr<ImageEmbedderResult>,
const Image& image, int64 timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
ImageEmbedder::Create(std::move(options)));
@ -546,7 +535,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback = [](absl::StatusOr<EmbeddingResult>,
options->result_callback = [](absl::StatusOr<ImageEmbedderResult>,
const Image& image, int64 timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
ImageEmbedder::Create(std::move(options)));
@ -564,7 +553,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
}
struct LiveStreamModeResults {
EmbeddingResult embedding_result;
ImageEmbedderResult embedding_result;
std::pair<int, int> image_size;
int64 timestamp_ms;
};
@ -580,7 +569,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback =
[&results](absl::StatusOr<EmbeddingResult> embedding_result,
[&results](absl::StatusOr<ImageEmbedderResult> embedding_result,
const Image& image, int64 timestamp_ms) {
MP_ASSERT_OK(embedding_result.status());
results.push_back(
@ -612,8 +601,8 @@ TEST_F(LiveStreamModeTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(
double similarity,
ImageEmbedder::CosineSimilarity(
result.embedding_result.embeddings(0).entries(0),
results[i - 1].embedding_result.embeddings(0).entries(0)));
result.embedding_result.embeddings[0],
results[i - 1].embedding_result.embeddings[0]));
double expected_similarity = 1.000000;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}

View File

@ -24,7 +24,7 @@ mediapipe_proto_library(
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_proto",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
],
)

View File

@ -18,7 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.vision.image_embedder.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/embedder_options.proto";
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
message ImageEmbedderGraphOptions {
@ -31,5 +31,5 @@ message ImageEmbedderGraphOptions {
// Options for configuring the embedder behavior, such as normalization or
// quantization.
optional components.proto.EmbedderOptions embedder_options = 2;
optional components.processors.proto.EmbedderOptions embedder_options = 2;
}

View File

@ -23,6 +23,7 @@ _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",