Update EmbeddingResult format and dependent tasks.
PiperOrigin-RevId: 486186491
This commit is contained in:
parent
66e591d4bc
commit
5e1a2fcdbb
|
@ -61,41 +61,6 @@ cc_library(
|
||||||
|
|
||||||
# TODO: Enable this test
|
# TODO: Enable this test
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "embedder_options",
|
|
||||||
srcs = ["embedder_options.cc"],
|
|
||||||
hdrs = ["embedder_options.h"],
|
|
||||||
deps = ["//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto"],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "embedding_postprocessing_graph",
|
|
||||||
srcs = ["embedding_postprocessing_graph.cc"],
|
|
||||||
hdrs = ["embedding_postprocessing_graph.h"],
|
|
||||||
deps = [
|
|
||||||
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
|
|
||||||
"//mediapipe/framework:calculator_framework",
|
|
||||||
"//mediapipe/framework/api2:builder",
|
|
||||||
"//mediapipe/framework/api2:port",
|
|
||||||
"//mediapipe/framework/formats:tensor",
|
|
||||||
"//mediapipe/framework/tool:options_map",
|
|
||||||
"//mediapipe/tasks/cc:common",
|
|
||||||
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
|
|
||||||
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
|
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
|
||||||
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto",
|
|
||||||
"//mediapipe/tasks/cc/components/proto:embedding_postprocessing_graph_options_cc_proto",
|
|
||||||
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
|
|
||||||
"//mediapipe/tasks/cc/core:model_resources",
|
|
||||||
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
|
||||||
"@com_google_absl//absl/status",
|
|
||||||
"@com_google_absl//absl/status:statusor",
|
|
||||||
"@com_google_absl//absl/strings:str_format",
|
|
||||||
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
|
|
||||||
],
|
|
||||||
alwayslink = 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: Investigate rewriting the build rule to only link
|
# TODO: Investigate rewriting the build rule to only link
|
||||||
# the Bert Preprocessor if it's needed.
|
# the Bert Preprocessor if it's needed.
|
||||||
cc_library(
|
cc_library(
|
||||||
|
|
|
@ -163,7 +163,7 @@ mediapipe_proto_library(
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework:calculator_options_proto",
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
"//mediapipe/framework:calculator_proto",
|
"//mediapipe/framework:calculator_proto",
|
||||||
"//mediapipe/tasks/cc/components/proto:embedder_options_proto",
|
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -178,7 +178,7 @@ cc_library(
|
||||||
"//mediapipe/framework/formats:tensor",
|
"//mediapipe/framework/formats:tensor",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto",
|
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
],
|
],
|
||||||
|
|
|
@ -26,14 +26,14 @@
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
|
#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace api2 {
|
namespace api2 {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry;
|
using ::mediapipe::tasks::components::containers::proto::Embedding;
|
||||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||||
|
|
||||||
// Computes the inverse L2 norm of the provided array of values. Returns 1.0 in
|
// Computes the inverse L2 norm of the provided array of values. Returns 1.0 in
|
||||||
|
@ -66,7 +66,7 @@ float GetInverseL2Norm(const float* values, int size) {
|
||||||
class TensorsToEmbeddingsCalculator : public Node {
|
class TensorsToEmbeddingsCalculator : public Node {
|
||||||
public:
|
public:
|
||||||
static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"};
|
static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"};
|
||||||
static constexpr Output<EmbeddingResult> kEmbeddingsOut{"EMBEDDING_RESULT"};
|
static constexpr Output<EmbeddingResult> kEmbeddingsOut{"EMBEDDINGS"};
|
||||||
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kEmbeddingsOut);
|
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kEmbeddingsOut);
|
||||||
|
|
||||||
absl::Status Open(CalculatorContext* cc) override;
|
absl::Status Open(CalculatorContext* cc) override;
|
||||||
|
@ -77,8 +77,8 @@ class TensorsToEmbeddingsCalculator : public Node {
|
||||||
bool quantize_;
|
bool quantize_;
|
||||||
std::vector<std::string> head_names_;
|
std::vector<std::string> head_names_;
|
||||||
|
|
||||||
void FillFloatEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry);
|
void FillFloatEmbedding(const Tensor& tensor, Embedding* embedding);
|
||||||
void FillQuantizedEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry);
|
void FillQuantizedEmbedding(const Tensor& tensor, Embedding* embedding);
|
||||||
};
|
};
|
||||||
|
|
||||||
absl::Status TensorsToEmbeddingsCalculator::Open(CalculatorContext* cc) {
|
absl::Status TensorsToEmbeddingsCalculator::Open(CalculatorContext* cc) {
|
||||||
|
@ -104,42 +104,42 @@ absl::Status TensorsToEmbeddingsCalculator::Process(CalculatorContext* cc) {
|
||||||
for (int i = 0; i < tensors.size(); ++i) {
|
for (int i = 0; i < tensors.size(); ++i) {
|
||||||
const auto& tensor = tensors[i];
|
const auto& tensor = tensors[i];
|
||||||
RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32);
|
RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32);
|
||||||
auto* embeddings = result.add_embeddings();
|
auto* embedding = result.add_embeddings();
|
||||||
embeddings->set_head_index(i);
|
embedding->set_head_index(i);
|
||||||
if (!head_names_.empty()) {
|
if (!head_names_.empty()) {
|
||||||
embeddings->set_head_name(head_names_[i]);
|
embedding->set_head_name(head_names_[i]);
|
||||||
}
|
}
|
||||||
if (quantize_) {
|
if (quantize_) {
|
||||||
FillQuantizedEmbeddingEntry(tensor, embeddings->add_entries());
|
FillQuantizedEmbedding(tensor, embedding);
|
||||||
} else {
|
} else {
|
||||||
FillFloatEmbeddingEntry(tensor, embeddings->add_entries());
|
FillFloatEmbedding(tensor, embedding);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
kEmbeddingsOut(cc).Send(result);
|
kEmbeddingsOut(cc).Send(result);
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
void TensorsToEmbeddingsCalculator::FillFloatEmbeddingEntry(
|
void TensorsToEmbeddingsCalculator::FillFloatEmbedding(const Tensor& tensor,
|
||||||
const Tensor& tensor, EmbeddingEntry* entry) {
|
Embedding* embedding) {
|
||||||
int size = tensor.shape().num_elements();
|
int size = tensor.shape().num_elements();
|
||||||
auto tensor_view = tensor.GetCpuReadView();
|
auto tensor_view = tensor.GetCpuReadView();
|
||||||
const float* tensor_buffer = tensor_view.buffer<float>();
|
const float* tensor_buffer = tensor_view.buffer<float>();
|
||||||
float inv_l2_norm =
|
float inv_l2_norm =
|
||||||
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
|
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
|
||||||
auto* float_embedding = entry->mutable_float_embedding();
|
auto* float_embedding = embedding->mutable_float_embedding();
|
||||||
for (int i = 0; i < size; ++i) {
|
for (int i = 0; i < size; ++i) {
|
||||||
float_embedding->add_values(tensor_buffer[i] * inv_l2_norm);
|
float_embedding->add_values(tensor_buffer[i] * inv_l2_norm);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TensorsToEmbeddingsCalculator::FillQuantizedEmbeddingEntry(
|
void TensorsToEmbeddingsCalculator::FillQuantizedEmbedding(
|
||||||
const Tensor& tensor, EmbeddingEntry* entry) {
|
const Tensor& tensor, Embedding* embedding) {
|
||||||
int size = tensor.shape().num_elements();
|
int size = tensor.shape().num_elements();
|
||||||
auto tensor_view = tensor.GetCpuReadView();
|
auto tensor_view = tensor.GetCpuReadView();
|
||||||
const float* tensor_buffer = tensor_view.buffer<float>();
|
const float* tensor_buffer = tensor_view.buffer<float>();
|
||||||
float inv_l2_norm =
|
float inv_l2_norm =
|
||||||
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
|
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
|
||||||
auto* values = entry->mutable_quantized_embedding()->mutable_values();
|
auto* values = embedding->mutable_quantized_embedding()->mutable_values();
|
||||||
values->resize(size);
|
values->resize(size);
|
||||||
for (int i = 0; i < size; ++i) {
|
for (int i = 0; i < size; ++i) {
|
||||||
// Normalize.
|
// Normalize.
|
||||||
|
|
|
@ -18,7 +18,7 @@ syntax = "proto2";
|
||||||
package mediapipe;
|
package mediapipe;
|
||||||
|
|
||||||
import "mediapipe/framework/calculator.proto";
|
import "mediapipe/framework/calculator.proto";
|
||||||
import "mediapipe/tasks/cc/components/proto/embedder_options.proto";
|
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
|
||||||
|
|
||||||
message TensorsToEmbeddingsCalculatorOptions {
|
message TensorsToEmbeddingsCalculatorOptions {
|
||||||
extend mediapipe.CalculatorOptions {
|
extend mediapipe.CalculatorOptions {
|
||||||
|
@ -27,8 +27,8 @@ message TensorsToEmbeddingsCalculatorOptions {
|
||||||
|
|
||||||
// The embedder options defining whether to L2-normalize or scalar-quantize
|
// The embedder options defining whether to L2-normalize or scalar-quantize
|
||||||
// the outputs.
|
// the outputs.
|
||||||
optional mediapipe.tasks.components.proto.EmbedderOptions embedder_options =
|
optional mediapipe.tasks.components.processors.proto.EmbedderOptions
|
||||||
1;
|
embedder_options = 1;
|
||||||
|
|
||||||
// The embedder head names.
|
// The embedder head names.
|
||||||
repeated string head_names = 2;
|
repeated string head_names = 2;
|
||||||
|
|
|
@ -55,7 +55,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, FailsWithInvalidHeadNamesNumber) {
|
||||||
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
calculator: "TensorsToEmbeddingsCalculator"
|
calculator: "TensorsToEmbeddingsCalculator"
|
||||||
input_stream: "TENSORS:tensors"
|
input_stream: "TENSORS:tensors"
|
||||||
output_stream: "EMBEDDING_RESULT:embeddings"
|
output_stream: "EMBEDDINGS:embeddings"
|
||||||
options {
|
options {
|
||||||
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { head_names: "foo" }
|
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { head_names: "foo" }
|
||||||
}
|
}
|
||||||
|
@ -73,7 +73,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithoutHeadNames) {
|
||||||
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
calculator: "TensorsToEmbeddingsCalculator"
|
calculator: "TensorsToEmbeddingsCalculator"
|
||||||
input_stream: "TENSORS:tensors"
|
input_stream: "TENSORS:tensors"
|
||||||
output_stream: "EMBEDDING_RESULT:embeddings"
|
output_stream: "EMBEDDINGS:embeddings"
|
||||||
options {
|
options {
|
||||||
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
|
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
|
||||||
embedder_options { l2_normalize: false quantize: false }
|
embedder_options { l2_normalize: false quantize: false }
|
||||||
|
@ -84,19 +84,15 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithoutHeadNames) {
|
||||||
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
|
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
const EmbeddingResult& result = runner.Outputs()
|
const EmbeddingResult& result =
|
||||||
.Get("EMBEDDING_RESULT", 0)
|
runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
|
||||||
.packets[0]
|
EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
||||||
.Get<EmbeddingResult>();
|
|
||||||
EXPECT_THAT(
|
|
||||||
result,
|
|
||||||
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
|
||||||
R"pb(embeddings {
|
R"pb(embeddings {
|
||||||
entries { float_embedding { values: 0.1 values: 0.2 } }
|
float_embedding { values: 0.1 values: 0.2 }
|
||||||
head_index: 0
|
head_index: 0
|
||||||
}
|
}
|
||||||
embeddings {
|
embeddings {
|
||||||
entries { float_embedding { values: -0.2 values: -0.3 } }
|
float_embedding { values: -0.2 values: -0.3 }
|
||||||
head_index: 1
|
head_index: 1
|
||||||
})pb")));
|
})pb")));
|
||||||
}
|
}
|
||||||
|
@ -105,7 +101,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) {
|
||||||
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
calculator: "TensorsToEmbeddingsCalculator"
|
calculator: "TensorsToEmbeddingsCalculator"
|
||||||
input_stream: "TENSORS:tensors"
|
input_stream: "TENSORS:tensors"
|
||||||
output_stream: "EMBEDDING_RESULT:embeddings"
|
output_stream: "EMBEDDINGS:embeddings"
|
||||||
options {
|
options {
|
||||||
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
|
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
|
||||||
embedder_options { l2_normalize: false quantize: false }
|
embedder_options { l2_normalize: false quantize: false }
|
||||||
|
@ -118,20 +114,16 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) {
|
||||||
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
|
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
const EmbeddingResult& result = runner.Outputs()
|
const EmbeddingResult& result =
|
||||||
.Get("EMBEDDING_RESULT", 0)
|
runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
|
||||||
.packets[0]
|
EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
||||||
.Get<EmbeddingResult>();
|
|
||||||
EXPECT_THAT(
|
|
||||||
result,
|
|
||||||
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
|
||||||
R"pb(embeddings {
|
R"pb(embeddings {
|
||||||
entries { float_embedding { values: 0.1 values: 0.2 } }
|
float_embedding { values: 0.1 values: 0.2 }
|
||||||
head_index: 0
|
head_index: 0
|
||||||
head_name: "foo"
|
head_name: "foo"
|
||||||
}
|
}
|
||||||
embeddings {
|
embeddings {
|
||||||
entries { float_embedding { values: -0.2 values: -0.3 } }
|
float_embedding { values: -0.2 values: -0.3 }
|
||||||
head_index: 1
|
head_index: 1
|
||||||
head_name: "bar"
|
head_name: "bar"
|
||||||
})pb")));
|
})pb")));
|
||||||
|
@ -141,7 +133,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) {
|
||||||
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
calculator: "TensorsToEmbeddingsCalculator"
|
calculator: "TensorsToEmbeddingsCalculator"
|
||||||
input_stream: "TENSORS:tensors"
|
input_stream: "TENSORS:tensors"
|
||||||
output_stream: "EMBEDDING_RESULT:embeddings"
|
output_stream: "EMBEDDINGS:embeddings"
|
||||||
options {
|
options {
|
||||||
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
|
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
|
||||||
embedder_options { l2_normalize: true quantize: false }
|
embedder_options { l2_normalize: true quantize: false }
|
||||||
|
@ -152,23 +144,17 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) {
|
||||||
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
|
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
const EmbeddingResult& result = runner.Outputs()
|
const EmbeddingResult& result =
|
||||||
.Get("EMBEDDING_RESULT", 0)
|
runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
|
||||||
.packets[0]
|
|
||||||
.Get<EmbeddingResult>();
|
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
result,
|
result,
|
||||||
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
||||||
R"pb(embeddings {
|
R"pb(embeddings {
|
||||||
entries {
|
|
||||||
float_embedding { values: 0.44721356 values: 0.8944271 }
|
float_embedding { values: 0.44721356 values: 0.8944271 }
|
||||||
}
|
|
||||||
head_index: 0
|
head_index: 0
|
||||||
}
|
}
|
||||||
embeddings {
|
embeddings {
|
||||||
entries {
|
|
||||||
float_embedding { values: -0.5547002 values: -0.8320503 }
|
float_embedding { values: -0.5547002 values: -0.8320503 }
|
||||||
}
|
|
||||||
head_index: 1
|
head_index: 1
|
||||||
})pb")));
|
})pb")));
|
||||||
}
|
}
|
||||||
|
@ -177,7 +163,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithQuantization) {
|
||||||
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
calculator: "TensorsToEmbeddingsCalculator"
|
calculator: "TensorsToEmbeddingsCalculator"
|
||||||
input_stream: "TENSORS:tensors"
|
input_stream: "TENSORS:tensors"
|
||||||
output_stream: "EMBEDDING_RESULT:embeddings"
|
output_stream: "EMBEDDINGS:embeddings"
|
||||||
options {
|
options {
|
||||||
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
|
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
|
||||||
embedder_options { l2_normalize: false quantize: true }
|
embedder_options { l2_normalize: false quantize: true }
|
||||||
|
@ -188,22 +174,16 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithQuantization) {
|
||||||
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
|
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
const EmbeddingResult& result = runner.Outputs()
|
const EmbeddingResult& result =
|
||||||
.Get("EMBEDDING_RESULT", 0)
|
runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
|
||||||
.packets[0]
|
|
||||||
.Get<EmbeddingResult>();
|
|
||||||
EXPECT_THAT(result,
|
EXPECT_THAT(result,
|
||||||
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
||||||
R"pb(embeddings {
|
R"pb(embeddings {
|
||||||
entries {
|
|
||||||
quantized_embedding { values: "\x0d\x1a" } # 13,26
|
quantized_embedding { values: "\x0d\x1a" } # 13,26
|
||||||
}
|
|
||||||
head_index: 0
|
head_index: 0
|
||||||
}
|
}
|
||||||
embeddings {
|
embeddings {
|
||||||
entries {
|
|
||||||
quantized_embedding { values: "\xe6\xda" } # -26,-38
|
quantized_embedding { values: "\xe6\xda" } # -26,-38
|
||||||
}
|
|
||||||
head_index: 1
|
head_index: 1
|
||||||
})pb")));
|
})pb")));
|
||||||
}
|
}
|
||||||
|
@ -213,7 +193,7 @@ TEST(TensorsToEmbeddingsCalculatorTest,
|
||||||
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
calculator: "TensorsToEmbeddingsCalculator"
|
calculator: "TensorsToEmbeddingsCalculator"
|
||||||
input_stream: "TENSORS:tensors"
|
input_stream: "TENSORS:tensors"
|
||||||
output_stream: "EMBEDDING_RESULT:embeddings"
|
output_stream: "EMBEDDINGS:embeddings"
|
||||||
options {
|
options {
|
||||||
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
|
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
|
||||||
embedder_options { l2_normalize: true quantize: true }
|
embedder_options { l2_normalize: true quantize: true }
|
||||||
|
@ -224,23 +204,16 @@ TEST(TensorsToEmbeddingsCalculatorTest,
|
||||||
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
|
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
const EmbeddingResult& result = runner.Outputs()
|
const EmbeddingResult& result =
|
||||||
.Get("EMBEDDING_RESULT", 0)
|
runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
|
||||||
.packets[0]
|
EXPECT_THAT(result,
|
||||||
.Get<EmbeddingResult>();
|
|
||||||
EXPECT_THAT(
|
|
||||||
result,
|
|
||||||
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
||||||
R"pb(embeddings {
|
R"pb(embeddings {
|
||||||
entries {
|
|
||||||
quantized_embedding { values: "\x39\x72" } # 57,114
|
quantized_embedding { values: "\x39\x72" } # 57,114
|
||||||
}
|
|
||||||
head_index: 0
|
head_index: 0
|
||||||
}
|
}
|
||||||
embeddings {
|
embeddings {
|
||||||
entries {
|
|
||||||
quantized_embedding { values: "\xb9\x95" } # -71,-107
|
quantized_embedding { values: "\xb9\x95" } # -71,-107
|
||||||
}
|
|
||||||
head_index: 1
|
head_index: 1
|
||||||
})pb")));
|
})pb")));
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,3 +49,12 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "embedding_result",
|
||||||
|
srcs = ["embedding_result.cc"],
|
||||||
|
hdrs = ["embedding_result.h"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
57
mediapipe/tasks/cc/components/containers/embedding_result.cc
Normal file
57
mediapipe/tasks/cc/components/containers/embedding_result.cc
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||||
|
|
||||||
|
#include <iterator>
|
||||||
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::components::containers {
|
||||||
|
|
||||||
|
Embedding ConvertToEmbedding(const proto::Embedding& proto) {
|
||||||
|
Embedding embedding;
|
||||||
|
if (proto.has_float_embedding()) {
|
||||||
|
embedding.float_embedding = {
|
||||||
|
std::make_move_iterator(proto.float_embedding().values().begin()),
|
||||||
|
std::make_move_iterator(proto.float_embedding().values().end())};
|
||||||
|
} else {
|
||||||
|
embedding.quantized_embedding = {
|
||||||
|
std::make_move_iterator(proto.quantized_embedding().values().begin()),
|
||||||
|
std::make_move_iterator(proto.quantized_embedding().values().end())};
|
||||||
|
}
|
||||||
|
embedding.head_index = proto.head_index();
|
||||||
|
if (proto.has_head_name()) {
|
||||||
|
embedding.head_name = proto.head_name();
|
||||||
|
}
|
||||||
|
return embedding;
|
||||||
|
}
|
||||||
|
|
||||||
|
EmbeddingResult ConvertToEmbeddingResult(const proto::EmbeddingResult& proto) {
|
||||||
|
EmbeddingResult embedding_result;
|
||||||
|
embedding_result.embeddings.reserve(proto.embeddings_size());
|
||||||
|
for (const auto& embedding : proto.embeddings()) {
|
||||||
|
embedding_result.embeddings.push_back(ConvertToEmbedding(embedding));
|
||||||
|
}
|
||||||
|
if (proto.has_timestamp_ms()) {
|
||||||
|
embedding_result.timestamp_ms = proto.timestamp_ms();
|
||||||
|
}
|
||||||
|
return embedding_result;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::components::containers
|
72
mediapipe/tasks/cc/components/containers/embedding_result.h
Normal file
72
mediapipe/tasks/cc/components/containers/embedding_result.h
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::components::containers {
|
||||||
|
|
||||||
|
// Embedding result for a given embedder head.
|
||||||
|
//
|
||||||
|
// One and only one of the two 'float_embedding' and 'quantized_embedding' will
|
||||||
|
// contain data, based on whether or not the embedder was configured to perform
|
||||||
|
// scalar quantization.
|
||||||
|
struct Embedding {
|
||||||
|
// Floating-point embedding. Empty if the embedder was configured to perform
|
||||||
|
// scalar-quantization.
|
||||||
|
std::vector<float> float_embedding;
|
||||||
|
// Scalar-quantized embedding. Empty if the embedder was not configured to
|
||||||
|
// perform scalar quantization.
|
||||||
|
std::string quantized_embedding;
|
||||||
|
// The index of the embedder head (i.e. output tensor) this embedding comes
|
||||||
|
// from. This is useful for multi-head models.
|
||||||
|
int head_index;
|
||||||
|
// The optional name of the embedder head, as provided in the TFLite Model
|
||||||
|
// Metadata [1] if present. This is useful for multi-head models.
|
||||||
|
//
|
||||||
|
// [1]: https://www.tensorflow.org/lite/convert/metadata
|
||||||
|
std::optional<std::string> head_name = std::nullopt;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Defines embedding results of a model.
|
||||||
|
struct EmbeddingResult {
|
||||||
|
// The embedding results for each head of the model.
|
||||||
|
std::vector<Embedding> embeddings;
|
||||||
|
// The optional timestamp (in milliseconds) of the start of the chunk of data
|
||||||
|
// corresponding to these results.
|
||||||
|
//
|
||||||
|
// This is only used for embedding extraction on time series (e.g. audio
|
||||||
|
// embedding). In these use cases, the amount of data to process might
|
||||||
|
// exceed the maximum size that the model can process: to solve this, the
|
||||||
|
// input data is split into multiple chunks starting at different timestamps.
|
||||||
|
std::optional<int64_t> timestamp_ms = std::nullopt;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Utility function to convert from Embedding proto to Embedding struct.
|
||||||
|
Embedding ConvertToEmbedding(const proto::Embedding& proto);
|
||||||
|
|
||||||
|
// Utility function to convert from EmbeddingResult proto to EmbeddingResult
|
||||||
|
// struct.
|
||||||
|
EmbeddingResult ConvertToEmbeddingResult(const proto::EmbeddingResult& proto);
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::components::containers
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
|
|
@ -30,30 +30,31 @@ message QuantizedEmbedding {
|
||||||
optional bytes values = 1;
|
optional bytes values = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Floating-point or scalar-quantized embedding with an optional timestamp.
|
// Embedding result for a given embedder head.
|
||||||
message EmbeddingEntry {
|
message Embedding {
|
||||||
// The actual embedding, either floating-point or scalar-quantized.
|
// The actual embedding, either floating-point or quantized.
|
||||||
oneof embedding {
|
oneof embedding {
|
||||||
FloatEmbedding float_embedding = 1;
|
FloatEmbedding float_embedding = 1;
|
||||||
QuantizedEmbedding quantized_embedding = 2;
|
QuantizedEmbedding quantized_embedding = 2;
|
||||||
}
|
}
|
||||||
// The optional timestamp (in milliseconds) associated to the embedding entry.
|
|
||||||
// This is useful for time series use cases, e.g. audio embedding.
|
|
||||||
optional int64 timestamp_ms = 3;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Embeddings for a given embedder head.
|
|
||||||
message Embeddings {
|
|
||||||
repeated EmbeddingEntry entries = 1;
|
|
||||||
// The index of the embedder head that produced this embedding. This is useful
|
// The index of the embedder head that produced this embedding. This is useful
|
||||||
// for multi-head models.
|
// for multi-head models.
|
||||||
optional int32 head_index = 2;
|
optional int32 head_index = 3;
|
||||||
// The name of the embedder head, which is the corresponding tensor metadata
|
// The name of the embedder head, which is the corresponding tensor metadata
|
||||||
// name (if any). This is useful for multi-head models.
|
// name (if any). This is useful for multi-head models.
|
||||||
optional string head_name = 3;
|
optional string head_name = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Contains one set of results per embedder head.
|
// Embedding results for a given embedder model.
|
||||||
message EmbeddingResult {
|
message EmbeddingResult {
|
||||||
repeated Embeddings embeddings = 1;
|
// The embedding results for each model head, i.e. one for each output tensor.
|
||||||
|
repeated Embedding embeddings = 1;
|
||||||
|
// The optional timestamp (in milliseconds) of the start of the chunk of data
|
||||||
|
// corresponding to these results.
|
||||||
|
//
|
||||||
|
// This is only used for embedding extraction on time series (e.g. audio
|
||||||
|
// embedding). In these use cases, the amount of data to process might
|
||||||
|
// exceed the maximum size that the model can process: to solve this, the
|
||||||
|
// input data is split into multiple chunks starting at different timestamps.
|
||||||
|
optional int64 timestamp_ms = 2;
|
||||||
}
|
}
|
||||||
|
|
|
@ -62,3 +62,38 @@ cc_library(
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "embedder_options",
|
||||||
|
srcs = ["embedder_options.cc"],
|
||||||
|
hdrs = ["embedder_options.h"],
|
||||||
|
deps = ["//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "embedding_postprocessing_graph",
|
||||||
|
srcs = ["embedding_postprocessing_graph.cc"],
|
||||||
|
hdrs = ["embedding_postprocessing_graph.h"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:builder",
|
||||||
|
"//mediapipe/framework/api2:port",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
"//mediapipe/framework/tool:options_map",
|
||||||
|
"//mediapipe/tasks/cc:common",
|
||||||
|
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
|
||||||
|
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
|
||||||
|
"//mediapipe/tasks/cc/core:model_resources",
|
||||||
|
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
|
@ -13,22 +13,24 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/components/embedder_options.h"
|
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace components {
|
namespace components {
|
||||||
|
namespace processors {
|
||||||
|
|
||||||
tasks::components::proto::EmbedderOptions ConvertEmbedderOptionsToProto(
|
proto::EmbedderOptions ConvertEmbedderOptionsToProto(
|
||||||
EmbedderOptions* embedder_options) {
|
EmbedderOptions* embedder_options) {
|
||||||
tasks::components::proto::EmbedderOptions options_proto;
|
proto::EmbedderOptions options_proto;
|
||||||
options_proto.set_l2_normalize(embedder_options->l2_normalize);
|
options_proto.set_l2_normalize(embedder_options->l2_normalize);
|
||||||
options_proto.set_quantize(embedder_options->quantize);
|
options_proto.set_quantize(embedder_options->quantize);
|
||||||
return options_proto;
|
return options_proto;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace processors
|
||||||
} // namespace components
|
} // namespace components
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
|
@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_
|
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
|
||||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_
|
#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace components {
|
namespace components {
|
||||||
|
namespace processors {
|
||||||
|
|
||||||
// Embedder options for MediaPipe C++ embedding extraction tasks.
|
// Embedder options for MediaPipe C++ embedding extraction tasks.
|
||||||
struct EmbedderOptions {
|
struct EmbedderOptions {
|
||||||
|
@ -37,11 +38,12 @@ struct EmbedderOptions {
|
||||||
bool quantize;
|
bool quantize;
|
||||||
};
|
};
|
||||||
|
|
||||||
tasks::components::proto::EmbedderOptions ConvertEmbedderOptionsToProto(
|
proto::EmbedderOptions ConvertEmbedderOptionsToProto(
|
||||||
EmbedderOptions* embedder_options);
|
EmbedderOptions* embedder_options);
|
||||||
|
|
||||||
|
} // namespace processors
|
||||||
} // namespace components
|
} // namespace components
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
||||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_
|
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h"
|
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -29,8 +29,8 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
|
#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
|
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||||
|
@ -39,6 +39,7 @@ limitations under the License.
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace components {
|
namespace components {
|
||||||
|
namespace processors {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
@ -49,13 +50,12 @@ using ::mediapipe::api2::builder::GenericNode;
|
||||||
using ::mediapipe::api2::builder::Graph;
|
using ::mediapipe::api2::builder::Graph;
|
||||||
using ::mediapipe::api2::builder::Source;
|
using ::mediapipe::api2::builder::Source;
|
||||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||||
using ::mediapipe::tasks::components::proto::EmbedderOptions;
|
|
||||||
using ::mediapipe::tasks::core::ModelResources;
|
using ::mediapipe::tasks::core::ModelResources;
|
||||||
using TensorsSource =
|
using TensorsSource =
|
||||||
::mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;
|
::mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;
|
||||||
|
|
||||||
constexpr char kTensorsTag[] = "TENSORS";
|
constexpr char kTensorsTag[] = "TENSORS";
|
||||||
constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT";
|
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
|
||||||
|
|
||||||
// Identifies whether or not the model has quantized outputs, and performs
|
// Identifies whether or not the model has quantized outputs, and performs
|
||||||
// sanity checks.
|
// sanity checks.
|
||||||
|
@ -144,7 +144,7 @@ absl::StatusOr<std::vector<std::string>> GetHeadNames(
|
||||||
|
|
||||||
absl::Status ConfigureEmbeddingPostprocessing(
|
absl::Status ConfigureEmbeddingPostprocessing(
|
||||||
const ModelResources& model_resources,
|
const ModelResources& model_resources,
|
||||||
const EmbedderOptions& embedder_options,
|
const proto::EmbedderOptions& embedder_options,
|
||||||
proto::EmbeddingPostprocessingGraphOptions* options) {
|
proto::EmbeddingPostprocessingGraphOptions* options) {
|
||||||
ASSIGN_OR_RETURN(bool has_quantized_outputs,
|
ASSIGN_OR_RETURN(bool has_quantized_outputs,
|
||||||
HasQuantizedOutputs(model_resources));
|
HasQuantizedOutputs(model_resources));
|
||||||
|
@ -188,7 +188,7 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
|
||||||
BuildEmbeddingPostprocessing(
|
BuildEmbeddingPostprocessing(
|
||||||
sc->Options<proto::EmbeddingPostprocessingGraphOptions>(),
|
sc->Options<proto::EmbeddingPostprocessingGraphOptions>(),
|
||||||
graph[Input<std::vector<Tensor>>(kTensorsTag)], graph));
|
graph[Input<std::vector<Tensor>>(kTensorsTag)], graph));
|
||||||
embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingResultTag)];
|
embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
|
||||||
return graph.GetConfig();
|
return graph.GetConfig();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -220,13 +220,13 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
|
||||||
.GetOptions<mediapipe::TensorsToEmbeddingsCalculatorOptions>()
|
.GetOptions<mediapipe::TensorsToEmbeddingsCalculatorOptions>()
|
||||||
.CopyFrom(options.tensors_to_embeddings_options());
|
.CopyFrom(options.tensors_to_embeddings_options());
|
||||||
dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag);
|
dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag);
|
||||||
return tensors_to_embeddings_node[Output<EmbeddingResult>(
|
return tensors_to_embeddings_node[Output<EmbeddingResult>(kEmbeddingsTag)];
|
||||||
kEmbeddingResultTag)];
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
REGISTER_MEDIAPIPE_GRAPH(
|
REGISTER_MEDIAPIPE_GRAPH(
|
||||||
::mediapipe::tasks::components::EmbeddingPostprocessingGraph);
|
::mediapipe::tasks::components::processors::EmbeddingPostprocessingGraph);
|
||||||
|
|
||||||
|
} // namespace processors
|
||||||
} // namespace components
|
} // namespace components
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
|
@ -13,17 +13,18 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_
|
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_
|
||||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_
|
#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_
|
||||||
|
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace components {
|
namespace components {
|
||||||
|
namespace processors {
|
||||||
|
|
||||||
// Configures an EmbeddingPostprocessingGraph using the provided model resources
|
// Configures an EmbeddingPostprocessingGraph using the provided model resources
|
||||||
// and EmbedderOptions.
|
// and EmbedderOptions.
|
||||||
|
@ -44,18 +45,19 @@ namespace components {
|
||||||
// The output tensors of an InferenceCalculator, to convert into
|
// The output tensors of an InferenceCalculator, to convert into
|
||||||
// EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
|
// EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
|
||||||
// Outputs:
|
// Outputs:
|
||||||
// EMBEDDING_RESULT - EmbeddingResult
|
// EMBEDDINGS - EmbeddingResult
|
||||||
// The output EmbeddingResult.
|
// The output EmbeddingResult.
|
||||||
//
|
//
|
||||||
// TODO: add support for additional optional "TIMESTAMPS" input for
|
// TODO: add support for additional optional "TIMESTAMPS" input for
|
||||||
// embeddings aggregation.
|
// embeddings aggregation.
|
||||||
absl::Status ConfigureEmbeddingPostprocessing(
|
absl::Status ConfigureEmbeddingPostprocessing(
|
||||||
const tasks::core::ModelResources& model_resources,
|
const tasks::core::ModelResources& model_resources,
|
||||||
const tasks::components::proto::EmbedderOptions& embedder_options,
|
const proto::EmbedderOptions& embedder_options,
|
||||||
proto::EmbeddingPostprocessingGraphOptions* options);
|
proto::EmbeddingPostprocessingGraphOptions* options);
|
||||||
|
|
||||||
|
} // namespace processors
|
||||||
} // namespace components
|
} // namespace components
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
||||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_
|
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h"
|
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
@ -25,8 +25,8 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||||
|
@ -34,12 +34,10 @@ limitations under the License.
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace components {
|
namespace components {
|
||||||
|
namespace processors {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::mediapipe::file::JoinPath;
|
using ::mediapipe::file::JoinPath;
|
||||||
using ::mediapipe::tasks::components::proto::EmbedderOptions;
|
|
||||||
using ::mediapipe::tasks::components::proto::
|
|
||||||
EmbeddingPostprocessingGraphOptions;
|
|
||||||
using ::mediapipe::tasks::core::ModelResources;
|
using ::mediapipe::tasks::core::ModelResources;
|
||||||
|
|
||||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
|
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
|
||||||
|
@ -69,16 +67,17 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto model_resources,
|
auto model_resources,
|
||||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
||||||
EmbedderOptions options_in;
|
proto::EmbedderOptions options_in;
|
||||||
options_in.set_l2_normalize(true);
|
options_in.set_l2_normalize(true);
|
||||||
|
|
||||||
EmbeddingPostprocessingGraphOptions options_out;
|
proto::EmbeddingPostprocessingGraphOptions options_out;
|
||||||
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
|
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
|
||||||
&options_out));
|
&options_out));
|
||||||
|
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
options_out,
|
options_out,
|
||||||
EqualsProto(ParseTextProtoOrDie<EmbeddingPostprocessingGraphOptions>(
|
EqualsProto(
|
||||||
|
ParseTextProtoOrDie<proto::EmbeddingPostprocessingGraphOptions>(
|
||||||
R"pb(tensors_to_embeddings_options {
|
R"pb(tensors_to_embeddings_options {
|
||||||
embedder_options { l2_normalize: true }
|
embedder_options { l2_normalize: true }
|
||||||
head_names: "probability"
|
head_names: "probability"
|
||||||
|
@ -90,16 +89,17 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto model_resources,
|
auto model_resources,
|
||||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
|
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
|
||||||
EmbedderOptions options_in;
|
proto::EmbedderOptions options_in;
|
||||||
options_in.set_quantize(true);
|
options_in.set_quantize(true);
|
||||||
|
|
||||||
EmbeddingPostprocessingGraphOptions options_out;
|
proto::EmbeddingPostprocessingGraphOptions options_out;
|
||||||
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
|
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
|
||||||
&options_out));
|
&options_out));
|
||||||
|
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
options_out,
|
options_out,
|
||||||
EqualsProto(ParseTextProtoOrDie<EmbeddingPostprocessingGraphOptions>(
|
EqualsProto(
|
||||||
|
ParseTextProtoOrDie<proto::EmbeddingPostprocessingGraphOptions>(
|
||||||
R"pb(tensors_to_embeddings_options {
|
R"pb(tensors_to_embeddings_options {
|
||||||
embedder_options { quantize: true }
|
embedder_options { quantize: true }
|
||||||
}
|
}
|
||||||
|
@ -109,17 +109,18 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) {
|
||||||
TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
|
TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
|
MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
|
||||||
CreateModelResourcesForModel(kMobileNetV3Embedder));
|
CreateModelResourcesForModel(kMobileNetV3Embedder));
|
||||||
EmbedderOptions options_in;
|
proto::EmbedderOptions options_in;
|
||||||
options_in.set_quantize(true);
|
options_in.set_quantize(true);
|
||||||
options_in.set_l2_normalize(true);
|
options_in.set_l2_normalize(true);
|
||||||
|
|
||||||
EmbeddingPostprocessingGraphOptions options_out;
|
proto::EmbeddingPostprocessingGraphOptions options_out;
|
||||||
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
|
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
|
||||||
&options_out));
|
&options_out));
|
||||||
|
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
options_out,
|
options_out,
|
||||||
EqualsProto(ParseTextProtoOrDie<EmbeddingPostprocessingGraphOptions>(
|
EqualsProto(
|
||||||
|
ParseTextProtoOrDie<proto::EmbeddingPostprocessingGraphOptions>(
|
||||||
R"pb(tensors_to_embeddings_options {
|
R"pb(tensors_to_embeddings_options {
|
||||||
embedder_options { quantize: true l2_normalize: true }
|
embedder_options { quantize: true l2_normalize: true }
|
||||||
head_names: "feature"
|
head_names: "feature"
|
||||||
|
@ -131,6 +132,7 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
|
||||||
// supported.
|
// supported.
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
} // namespace processors
|
||||||
} // namespace components
|
} // namespace components
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
|
@ -34,3 +34,18 @@ mediapipe_proto_library(
|
||||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
|
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "embedder_options_proto",
|
||||||
|
srcs = ["embedder_options.proto"],
|
||||||
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "embedding_postprocessing_graph_options_proto",
|
||||||
|
srcs = ["embedding_postprocessing_graph_options.proto"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
"//mediapipe/framework:calculator_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -15,7 +15,10 @@ limitations under the License.
|
||||||
|
|
||||||
syntax = "proto2";
|
syntax = "proto2";
|
||||||
|
|
||||||
package mediapipe.tasks.components.proto;
|
package mediapipe.tasks.components.processors.proto;
|
||||||
|
|
||||||
|
option java_package = "com.google.mediapipe.tasks.components.processors.proto";
|
||||||
|
option java_outer_classname = "EmbedderOptionsProto";
|
||||||
|
|
||||||
// Shared options used by all embedding extraction tasks.
|
// Shared options used by all embedding extraction tasks.
|
||||||
message EmbedderOptions {
|
message EmbedderOptions {
|
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||||
|
|
||||||
syntax = "proto2";
|
syntax = "proto2";
|
||||||
|
|
||||||
package mediapipe.tasks.components.proto;
|
package mediapipe.tasks.components.processors.proto;
|
||||||
|
|
||||||
import "mediapipe/framework/calculator.proto";
|
import "mediapipe/framework/calculator.proto";
|
||||||
import "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto";
|
import "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto";
|
|
@ -23,21 +23,6 @@ mediapipe_proto_library(
|
||||||
srcs = ["segmenter_options.proto"],
|
srcs = ["segmenter_options.proto"],
|
||||||
)
|
)
|
||||||
|
|
||||||
mediapipe_proto_library(
|
|
||||||
name = "embedder_options_proto",
|
|
||||||
srcs = ["embedder_options.proto"],
|
|
||||||
)
|
|
||||||
|
|
||||||
mediapipe_proto_library(
|
|
||||||
name = "embedding_postprocessing_graph_options_proto",
|
|
||||||
srcs = ["embedding_postprocessing_graph_options.proto"],
|
|
||||||
deps = [
|
|
||||||
"//mediapipe/framework:calculator_options_proto",
|
|
||||||
"//mediapipe/framework:calculator_proto",
|
|
||||||
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
mediapipe_proto_library(
|
mediapipe_proto_library(
|
||||||
name = "text_preprocessing_graph_options_proto",
|
name = "text_preprocessing_graph_options_proto",
|
||||||
srcs = ["text_preprocessing_graph_options.proto"],
|
srcs = ["text_preprocessing_graph_options.proto"],
|
||||||
|
|
|
@ -26,7 +26,7 @@ cc_library(
|
||||||
hdrs = ["cosine_similarity.h"],
|
hdrs = ["cosine_similarity.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
"//mediapipe/tasks/cc/components/containers:embedding_result",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
@ -39,7 +39,7 @@ cc_test(
|
||||||
deps = [
|
deps = [
|
||||||
":cosine_similarity",
|
":cosine_similarity",
|
||||||
"//mediapipe/framework/port:gtest_main",
|
"//mediapipe/framework/port:gtest_main",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
"//mediapipe/tasks/cc/components/containers:embedding_result",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -30,7 +30,7 @@ namespace utils {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry;
|
using ::mediapipe::tasks::components::containers::Embedding;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
absl::StatusOr<double> ComputeCosineSimilarity(const T& u, const T& v,
|
absl::StatusOr<double> ComputeCosineSimilarity(const T& u, const T& v,
|
||||||
|
@ -66,39 +66,35 @@ absl::StatusOr<double> ComputeCosineSimilarity(const T& u, const T& v,
|
||||||
// an L2-norm of 0.
|
// an L2-norm of 0.
|
||||||
//
|
//
|
||||||
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
|
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
|
||||||
absl::StatusOr<double> CosineSimilarity(const EmbeddingEntry& u,
|
absl::StatusOr<double> CosineSimilarity(const Embedding& u,
|
||||||
const EmbeddingEntry& v) {
|
const Embedding& v) {
|
||||||
if (u.has_float_embedding() && v.has_float_embedding()) {
|
if (!u.float_embedding.empty() && !v.float_embedding.empty()) {
|
||||||
if (u.float_embedding().values().size() !=
|
if (u.float_embedding.size() != v.float_embedding.size()) {
|
||||||
v.float_embedding().values().size()) {
|
|
||||||
return CreateStatusWithPayload(
|
return CreateStatusWithPayload(
|
||||||
absl::StatusCode::kInvalidArgument,
|
absl::StatusCode::kInvalidArgument,
|
||||||
absl::StrFormat("Cannot compute cosine similarity between embeddings "
|
absl::StrFormat("Cannot compute cosine similarity between embeddings "
|
||||||
"of different sizes (%d vs. %d)",
|
"of different sizes (%d vs. %d)",
|
||||||
u.float_embedding().values().size(),
|
u.float_embedding.size(), v.float_embedding.size()),
|
||||||
v.float_embedding().values().size()),
|
|
||||||
MediaPipeTasksStatus::kInvalidArgumentError);
|
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||||
}
|
}
|
||||||
return ComputeCosineSimilarity(u.float_embedding().values().data(),
|
return ComputeCosineSimilarity(u.float_embedding.data(),
|
||||||
v.float_embedding().values().data(),
|
v.float_embedding.data(),
|
||||||
u.float_embedding().values().size());
|
u.float_embedding.size());
|
||||||
}
|
}
|
||||||
if (u.has_quantized_embedding() && v.has_quantized_embedding()) {
|
if (!u.quantized_embedding.empty() && !v.quantized_embedding.empty()) {
|
||||||
if (u.quantized_embedding().values().size() !=
|
if (u.quantized_embedding.size() != v.quantized_embedding.size()) {
|
||||||
v.quantized_embedding().values().size()) {
|
|
||||||
return CreateStatusWithPayload(
|
return CreateStatusWithPayload(
|
||||||
absl::StatusCode::kInvalidArgument,
|
absl::StatusCode::kInvalidArgument,
|
||||||
absl::StrFormat("Cannot compute cosine similarity between embeddings "
|
absl::StrFormat("Cannot compute cosine similarity between embeddings "
|
||||||
"of different sizes (%d vs. %d)",
|
"of different sizes (%d vs. %d)",
|
||||||
u.quantized_embedding().values().size(),
|
u.quantized_embedding.size(),
|
||||||
v.quantized_embedding().values().size()),
|
v.quantized_embedding.size()),
|
||||||
MediaPipeTasksStatus::kInvalidArgumentError);
|
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||||
}
|
}
|
||||||
return ComputeCosineSimilarity(reinterpret_cast<const int8_t*>(
|
return ComputeCosineSimilarity(
|
||||||
u.quantized_embedding().values().data()),
|
reinterpret_cast<const int8_t*>(u.quantized_embedding.data()),
|
||||||
reinterpret_cast<const int8_t*>(
|
reinterpret_cast<const int8_t*>(v.quantized_embedding.data()),
|
||||||
v.quantized_embedding().values().data()),
|
u.quantized_embedding.size());
|
||||||
u.quantized_embedding().values().size());
|
|
||||||
}
|
}
|
||||||
return CreateStatusWithPayload(
|
return CreateStatusWithPayload(
|
||||||
absl::StatusCode::kInvalidArgument,
|
absl::StatusCode::kInvalidArgument,
|
||||||
|
|
|
@ -17,22 +17,20 @@ limitations under the License.
|
||||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_COSINE_SIMILARITY_H_
|
#define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_COSINE_SIMILARITY_H_
|
||||||
|
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace components {
|
namespace components {
|
||||||
namespace utils {
|
namespace utils {
|
||||||
|
|
||||||
// Utility function to compute cosine similarity [1] between two embedding
|
// Utility function to compute cosine similarity [1] between two embeddings. May
|
||||||
// entries. May return an InvalidArgumentError if e.g. the feature vectors are
|
// return an InvalidArgumentError if e.g. the embeddings are of different types
|
||||||
// of different types (quantized vs. float), have different sizes, or have a
|
// (quantized vs. float), have different sizes, or have a an L2-norm of 0.
|
||||||
// an L2-norm of 0.
|
|
||||||
//
|
//
|
||||||
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
|
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
|
||||||
absl::StatusOr<double> CosineSimilarity(
|
absl::StatusOr<double> CosineSimilarity(const containers::Embedding& u,
|
||||||
const containers::proto::EmbeddingEntry& u,
|
const containers::Embedding& v);
|
||||||
const containers::proto::EmbeddingEntry& v);
|
|
||||||
|
|
||||||
} // namespace utils
|
} // namespace utils
|
||||||
} // namespace components
|
} // namespace components
|
||||||
|
|
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/gmock.h"
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -30,29 +30,27 @@ namespace components {
|
||||||
namespace utils {
|
namespace utils {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry;
|
using ::mediapipe::tasks::components::containers::Embedding;
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
|
|
||||||
// Helper function to generate float EmbeddingEntry.
|
// Helper function to generate float Embedding.
|
||||||
EmbeddingEntry BuildFloatEntry(std::vector<float> values) {
|
Embedding BuildFloatEmbedding(std::vector<float> values) {
|
||||||
EmbeddingEntry entry;
|
Embedding embedding;
|
||||||
for (const float value : values) {
|
embedding.float_embedding = values;
|
||||||
entry.mutable_float_embedding()->add_values(value);
|
return embedding;
|
||||||
}
|
|
||||||
return entry;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to generate quantized EmbeddingEntry.
|
// Helper function to generate quantized Embedding.
|
||||||
EmbeddingEntry BuildQuantizedEntry(std::vector<int8_t> values) {
|
Embedding BuildQuantizedEmbedding(std::vector<int8_t> values) {
|
||||||
EmbeddingEntry entry;
|
Embedding embedding;
|
||||||
entry.mutable_quantized_embedding()->set_values(
|
uint8_t* data = reinterpret_cast<uint8_t*>(values.data());
|
||||||
reinterpret_cast<uint8_t*>(values.data()), values.size());
|
embedding.quantized_embedding = {data, data + values.size()};
|
||||||
return entry;
|
return embedding;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CosineSimilarity, FailsWithQuantizedAndFloatEmbeddings) {
|
TEST(CosineSimilarity, FailsWithQuantizedAndFloatEmbeddings) {
|
||||||
auto u = BuildFloatEntry({0.1, 0.2});
|
auto u = BuildFloatEmbedding({0.1, 0.2});
|
||||||
auto v = BuildQuantizedEntry({0, 1});
|
auto v = BuildQuantizedEmbedding({0, 1});
|
||||||
|
|
||||||
auto status = CosineSimilarity(u, v);
|
auto status = CosineSimilarity(u, v);
|
||||||
|
|
||||||
|
@ -63,8 +61,8 @@ TEST(CosineSimilarity, FailsWithQuantizedAndFloatEmbeddings) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CosineSimilarity, FailsWithZeroNorm) {
|
TEST(CosineSimilarity, FailsWithZeroNorm) {
|
||||||
auto u = BuildFloatEntry({0.1, 0.2});
|
auto u = BuildFloatEmbedding({0.1, 0.2});
|
||||||
auto v = BuildFloatEntry({0.0, 0.0});
|
auto v = BuildFloatEmbedding({0.0, 0.0});
|
||||||
|
|
||||||
auto status = CosineSimilarity(u, v);
|
auto status = CosineSimilarity(u, v);
|
||||||
|
|
||||||
|
@ -75,8 +73,8 @@ TEST(CosineSimilarity, FailsWithZeroNorm) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CosineSimilarity, FailsWithDifferentSizes) {
|
TEST(CosineSimilarity, FailsWithDifferentSizes) {
|
||||||
auto u = BuildFloatEntry({0.1, 0.2});
|
auto u = BuildFloatEmbedding({0.1, 0.2});
|
||||||
auto v = BuildFloatEntry({0.1, 0.2, 0.3});
|
auto v = BuildFloatEmbedding({0.1, 0.2, 0.3});
|
||||||
|
|
||||||
auto status = CosineSimilarity(u, v);
|
auto status = CosineSimilarity(u, v);
|
||||||
|
|
||||||
|
@ -87,8 +85,8 @@ TEST(CosineSimilarity, FailsWithDifferentSizes) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CosineSimilarity, SucceedsWithFloatEntries) {
|
TEST(CosineSimilarity, SucceedsWithFloatEntries) {
|
||||||
auto u = BuildFloatEntry({1.0, 0.0, 0.0, 0.0});
|
auto u = BuildFloatEmbedding({1.0, 0.0, 0.0, 0.0});
|
||||||
auto v = BuildFloatEntry({0.5, 0.5, 0.5, 0.5});
|
auto v = BuildFloatEmbedding({0.5, 0.5, 0.5, 0.5});
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v));
|
MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v));
|
||||||
|
|
||||||
|
@ -96,8 +94,8 @@ TEST(CosineSimilarity, SucceedsWithFloatEntries) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CosineSimilarity, SucceedsWithQuantizedEntries) {
|
TEST(CosineSimilarity, SucceedsWithQuantizedEntries) {
|
||||||
auto u = BuildQuantizedEntry({127, 0, 0, 0});
|
auto u = BuildQuantizedEmbedding({127, 0, 0, 0});
|
||||||
auto v = BuildQuantizedEntry({-128, 0, 0, 0});
|
auto v = BuildQuantizedEmbedding({-128, 0, 0, 0});
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v));
|
MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v));
|
||||||
|
|
||||||
|
|
|
@ -26,12 +26,12 @@ cc_library(
|
||||||
"//mediapipe/framework/api2:port",
|
"//mediapipe/framework/api2:port",
|
||||||
"//mediapipe/framework/formats:image",
|
"//mediapipe/framework/formats:image",
|
||||||
"//mediapipe/framework/formats:rect_cc_proto",
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components:embedding_postprocessing_graph",
|
|
||||||
"//mediapipe/tasks/cc/components:image_preprocessing",
|
"//mediapipe/tasks/cc/components:image_preprocessing",
|
||||||
"//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto",
|
"//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
|
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components/proto:embedding_postprocessing_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph",
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||||
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
@ -49,9 +49,10 @@ cc_library(
|
||||||
"//mediapipe/framework/formats:image",
|
"//mediapipe/framework/formats:image",
|
||||||
"//mediapipe/framework/formats:rect_cc_proto",
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
"//mediapipe/framework/tool:options_map",
|
"//mediapipe/framework/tool:options_map",
|
||||||
"//mediapipe/tasks/cc/components:embedder_options",
|
"//mediapipe/tasks/cc/components/containers:embedding_result",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto",
|
"//mediapipe/tasks/cc/components/processors:embedder_options",
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components/utils:cosine_similarity",
|
"//mediapipe/tasks/cc/components/utils:cosine_similarity",
|
||||||
"//mediapipe/tasks/cc/core:base_options",
|
"//mediapipe/tasks/cc/core:base_options",
|
||||||
"//mediapipe/tasks/cc/core:task_runner",
|
"//mediapipe/tasks/cc/core:task_runner",
|
||||||
|
|
|
@ -21,9 +21,10 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/api2/builder.h"
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
#include "mediapipe/framework/tool/options_map.h"
|
#include "mediapipe/framework/tool/options_map.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/embedder_options.h"
|
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
|
||||||
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h"
|
#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h"
|
||||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||||
|
@ -41,8 +42,8 @@ namespace image_embedder {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
constexpr char kEmbeddingResultStreamName[] = "embedding_result_out";
|
constexpr char kEmbeddingsStreamName[] = "embeddings_out";
|
||||||
constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT";
|
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
|
||||||
constexpr char kImageInStreamName[] = "image_in";
|
constexpr char kImageInStreamName[] = "image_in";
|
||||||
constexpr char kImageOutStreamName[] = "image_out";
|
constexpr char kImageOutStreamName[] = "image_out";
|
||||||
constexpr char kImageTag[] = "IMAGE";
|
constexpr char kImageTag[] = "IMAGE";
|
||||||
|
@ -53,7 +54,7 @@ constexpr char kGraphTypeName[] =
|
||||||
"mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph";
|
"mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph";
|
||||||
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||||
|
|
||||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry;
|
using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult;
|
||||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||||
using ::mediapipe::tasks::core::PacketMap;
|
using ::mediapipe::tasks::core::PacketMap;
|
||||||
using ::mediapipe::tasks::vision::image_embedder::proto::
|
using ::mediapipe::tasks::vision::image_embedder::proto::
|
||||||
|
@ -71,13 +72,13 @@ CalculatorGraphConfig CreateGraphConfig(
|
||||||
graph.In(kNormRectTag).SetName(kNormRectStreamName);
|
graph.In(kNormRectTag).SetName(kNormRectStreamName);
|
||||||
auto& task_graph = graph.AddNode(kGraphTypeName);
|
auto& task_graph = graph.AddNode(kGraphTypeName);
|
||||||
task_graph.GetOptions<ImageEmbedderGraphOptions>().Swap(options_proto.get());
|
task_graph.GetOptions<ImageEmbedderGraphOptions>().Swap(options_proto.get());
|
||||||
task_graph.Out(kEmbeddingResultTag).SetName(kEmbeddingResultStreamName) >>
|
task_graph.Out(kEmbeddingsTag).SetName(kEmbeddingsStreamName) >>
|
||||||
graph.Out(kEmbeddingResultTag);
|
graph.Out(kEmbeddingsTag);
|
||||||
task_graph.Out(kImageTag).SetName(kImageOutStreamName) >>
|
task_graph.Out(kImageTag).SetName(kImageOutStreamName) >>
|
||||||
graph.Out(kImageTag);
|
graph.Out(kImageTag);
|
||||||
if (enable_flow_limiting) {
|
if (enable_flow_limiting) {
|
||||||
return tasks::core::AddFlowLimiterCalculator(
|
return tasks::core::AddFlowLimiterCalculator(
|
||||||
graph, task_graph, {kImageTag, kNormRectTag}, kEmbeddingResultTag);
|
graph, task_graph, {kImageTag, kNormRectTag}, kEmbeddingsTag);
|
||||||
}
|
}
|
||||||
graph.In(kImageTag) >> task_graph.In(kImageTag);
|
graph.In(kImageTag) >> task_graph.In(kImageTag);
|
||||||
graph.In(kNormRectTag) >> task_graph.In(kNormRectTag);
|
graph.In(kNormRectTag) >> task_graph.In(kNormRectTag);
|
||||||
|
@ -95,8 +96,8 @@ std::unique_ptr<ImageEmbedderGraphOptions> ConvertImageEmbedderOptionsToProto(
|
||||||
options_proto->mutable_base_options()->set_use_stream_mode(
|
options_proto->mutable_base_options()->set_use_stream_mode(
|
||||||
options->running_mode != core::RunningMode::IMAGE);
|
options->running_mode != core::RunningMode::IMAGE);
|
||||||
auto embedder_options_proto =
|
auto embedder_options_proto =
|
||||||
std::make_unique<tasks::components::proto::EmbedderOptions>(
|
std::make_unique<components::processors::proto::EmbedderOptions>(
|
||||||
components::ConvertEmbedderOptionsToProto(
|
components::processors::ConvertEmbedderOptionsToProto(
|
||||||
&(options->embedder_options)));
|
&(options->embedder_options)));
|
||||||
options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get());
|
options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get());
|
||||||
return options_proto;
|
return options_proto;
|
||||||
|
@ -121,9 +122,10 @@ absl::StatusOr<std::unique_ptr<ImageEmbedder>> ImageEmbedder::Create(
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
Packet embedding_result_packet =
|
Packet embedding_result_packet =
|
||||||
status_or_packets.value()[kEmbeddingResultStreamName];
|
status_or_packets.value()[kEmbeddingsStreamName];
|
||||||
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
|
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
|
||||||
result_callback(embedding_result_packet.Get<EmbeddingResult>(),
|
result_callback(ConvertToEmbeddingResult(
|
||||||
|
embedding_result_packet.Get<EmbeddingResult>()),
|
||||||
image_packet.Get<Image>(),
|
image_packet.Get<Image>(),
|
||||||
embedding_result_packet.Timestamp().Value() /
|
embedding_result_packet.Timestamp().Value() /
|
||||||
kMicroSecondsPerMilliSecond);
|
kMicroSecondsPerMilliSecond);
|
||||||
|
@ -138,7 +140,7 @@ absl::StatusOr<std::unique_ptr<ImageEmbedder>> ImageEmbedder::Create(
|
||||||
std::move(packets_callback));
|
std::move(packets_callback));
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
|
absl::StatusOr<ImageEmbedderResult> ImageEmbedder::Embed(
|
||||||
Image image,
|
Image image,
|
||||||
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||||
if (image.UsesGpu()) {
|
if (image.UsesGpu()) {
|
||||||
|
@ -155,10 +157,11 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
|
||||||
{{kImageInStreamName, MakePacket<Image>(std::move(image))},
|
{{kImageInStreamName, MakePacket<Image>(std::move(image))},
|
||||||
{kNormRectStreamName,
|
{kNormRectStreamName,
|
||||||
MakePacket<NormalizedRect>(std::move(norm_rect))}}));
|
MakePacket<NormalizedRect>(std::move(norm_rect))}}));
|
||||||
return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>();
|
return ConvertToEmbeddingResult(
|
||||||
|
output_packets[kEmbeddingsStreamName].Get<EmbeddingResult>());
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
|
absl::StatusOr<ImageEmbedderResult> ImageEmbedder::EmbedForVideo(
|
||||||
Image image, int64 timestamp_ms,
|
Image image, int64 timestamp_ms,
|
||||||
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||||
if (image.UsesGpu()) {
|
if (image.UsesGpu()) {
|
||||||
|
@ -178,7 +181,8 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
|
||||||
{kNormRectStreamName,
|
{kNormRectStreamName,
|
||||||
MakePacket<NormalizedRect>(std::move(norm_rect))
|
MakePacket<NormalizedRect>(std::move(norm_rect))
|
||||||
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
|
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
|
||||||
return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>();
|
return ConvertToEmbeddingResult(
|
||||||
|
output_packets[kEmbeddingsStreamName].Get<EmbeddingResult>());
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status ImageEmbedder::EmbedAsync(
|
absl::Status ImageEmbedder::EmbedAsync(
|
||||||
|
@ -202,7 +206,8 @@ absl::Status ImageEmbedder::EmbedAsync(
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<double> ImageEmbedder::CosineSimilarity(
|
absl::StatusOr<double> ImageEmbedder::CosineSimilarity(
|
||||||
const EmbeddingEntry& u, const EmbeddingEntry& v) {
|
const components::containers::Embedding& u,
|
||||||
|
const components::containers::Embedding& v) {
|
||||||
return components::utils::CosineSimilarity(u, v);
|
return components::utils::CosineSimilarity(u, v);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,8 +21,8 @@ limitations under the License.
|
||||||
|
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||||
#include "mediapipe/tasks/cc/components/embedder_options.h"
|
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
|
||||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||||
|
@ -33,6 +33,10 @@ namespace tasks {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
namespace image_embedder {
|
namespace image_embedder {
|
||||||
|
|
||||||
|
// Alias the shared EmbeddingResult struct as result typo.
|
||||||
|
using ImageEmbedderResult =
|
||||||
|
::mediapipe::tasks::components::containers::EmbeddingResult;
|
||||||
|
|
||||||
// The options for configuring a MediaPipe image embedder task.
|
// The options for configuring a MediaPipe image embedder task.
|
||||||
struct ImageEmbedderOptions {
|
struct ImageEmbedderOptions {
|
||||||
// Base options for configuring MediaPipe Tasks, such as specifying the model
|
// Base options for configuring MediaPipe Tasks, such as specifying the model
|
||||||
|
@ -50,14 +54,12 @@ struct ImageEmbedderOptions {
|
||||||
|
|
||||||
// Options for configuring the embedder behavior, such as L2-normalization or
|
// Options for configuring the embedder behavior, such as L2-normalization or
|
||||||
// scalar-quantization.
|
// scalar-quantization.
|
||||||
components::EmbedderOptions embedder_options;
|
components::processors::EmbedderOptions embedder_options;
|
||||||
|
|
||||||
// The user-defined result callback for processing live stream data.
|
// The user-defined result callback for processing live stream data.
|
||||||
// The result callback should only be specified when the running mode is set
|
// The result callback should only be specified when the running mode is set
|
||||||
// to RunningMode::LIVE_STREAM.
|
// to RunningMode::LIVE_STREAM.
|
||||||
std::function<void(
|
std::function<void(absl::StatusOr<ImageEmbedderResult>, const Image&, int64)>
|
||||||
absl::StatusOr<components::containers::proto::EmbeddingResult>,
|
|
||||||
const Image&, int64)>
|
|
||||||
result_callback = nullptr;
|
result_callback = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -104,7 +106,7 @@ class ImageEmbedder : core::BaseVisionTaskApi {
|
||||||
// running mode.
|
// running mode.
|
||||||
//
|
//
|
||||||
// The image can be of any size with format RGB or RGBA.
|
// The image can be of any size with format RGB or RGBA.
|
||||||
absl::StatusOr<components::containers::proto::EmbeddingResult> Embed(
|
absl::StatusOr<ImageEmbedderResult> Embed(
|
||||||
mediapipe::Image image,
|
mediapipe::Image image,
|
||||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||||
std::nullopt);
|
std::nullopt);
|
||||||
|
@ -127,7 +129,7 @@ class ImageEmbedder : core::BaseVisionTaskApi {
|
||||||
// The image can be of any size with format RGB or RGBA. It's required to
|
// The image can be of any size with format RGB or RGBA. It's required to
|
||||||
// provide the video frame's timestamp (in milliseconds). The input timestamps
|
// provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||||
// must be monotonically increasing.
|
// must be monotonically increasing.
|
||||||
absl::StatusOr<components::containers::proto::EmbeddingResult> EmbedForVideo(
|
absl::StatusOr<ImageEmbedderResult> EmbedForVideo(
|
||||||
mediapipe::Image image, int64 timestamp_ms,
|
mediapipe::Image image, int64 timestamp_ms,
|
||||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||||
std::nullopt);
|
std::nullopt);
|
||||||
|
@ -168,15 +170,15 @@ class ImageEmbedder : core::BaseVisionTaskApi {
|
||||||
// Shuts down the ImageEmbedder when all works are done.
|
// Shuts down the ImageEmbedder when all works are done.
|
||||||
absl::Status Close() { return runner_->Close(); }
|
absl::Status Close() { return runner_->Close(); }
|
||||||
|
|
||||||
// Utility function to compute cosine similarity [1] between two embedding
|
// Utility function to compute cosine similarity [1] between two embeddings.
|
||||||
// entries. May return an InvalidArgumentError if e.g. the feature vectors are
|
// May return an InvalidArgumentError if e.g. the embeddings are of different
|
||||||
// of different types (quantized vs. float), have different sizes, or have a
|
// types (quantized vs. float), have different sizes, or have a an L2-norm of
|
||||||
// an L2-norm of 0.
|
// 0.
|
||||||
//
|
//
|
||||||
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
|
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
|
||||||
static absl::StatusOr<double> CosineSimilarity(
|
static absl::StatusOr<double> CosineSimilarity(
|
||||||
const components::containers::proto::EmbeddingEntry& u,
|
const components::containers::Embedding& u,
|
||||||
const components::containers::proto::EmbeddingEntry& v);
|
const components::containers::Embedding& v);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace image_embedder
|
} // namespace image_embedder
|
||||||
|
|
|
@ -20,10 +20,10 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h"
|
|
||||||
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
|
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
|
||||||
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
|
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||||
#include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h"
|
||||||
|
|
||||||
|
@ -40,10 +40,8 @@ using ::mediapipe::api2::builder::GenericNode;
|
||||||
using ::mediapipe::api2::builder::Graph;
|
using ::mediapipe::api2::builder::Graph;
|
||||||
using ::mediapipe::api2::builder::Source;
|
using ::mediapipe::api2::builder::Source;
|
||||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||||
using ::mediapipe::tasks::components::proto::
|
|
||||||
EmbeddingPostprocessingGraphOptions;
|
|
||||||
|
|
||||||
constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT";
|
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
|
||||||
constexpr char kImageTag[] = "IMAGE";
|
constexpr char kImageTag[] = "IMAGE";
|
||||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||||
constexpr char kTensorsTag[] = "TENSORS";
|
constexpr char kTensorsTag[] = "TENSORS";
|
||||||
|
@ -67,7 +65,7 @@ struct ImageEmbedderOutputStreams {
|
||||||
// Describes region of image to perform embedding extraction on.
|
// Describes region of image to perform embedding extraction on.
|
||||||
// @Optional: rect covering the whole image is used if not specified.
|
// @Optional: rect covering the whole image is used if not specified.
|
||||||
// Outputs:
|
// Outputs:
|
||||||
// EMBEDDING_RESULT - EmbeddingResult
|
// EMBEDDINGS - EmbeddingResult
|
||||||
// The embedding result.
|
// The embedding result.
|
||||||
// IMAGE - Image
|
// IMAGE - Image
|
||||||
// The image that embedding extraction runs on.
|
// The image that embedding extraction runs on.
|
||||||
|
@ -76,7 +74,7 @@ struct ImageEmbedderOutputStreams {
|
||||||
// node {
|
// node {
|
||||||
// calculator: "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"
|
// calculator: "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"
|
||||||
// input_stream: "IMAGE:image_in"
|
// input_stream: "IMAGE:image_in"
|
||||||
// output_stream: "EMBEDDING_RESULT:embedding_result_out"
|
// output_stream: "EMBEDDINGS:embedding_result_out"
|
||||||
// output_stream: "IMAGE:image_out"
|
// output_stream: "IMAGE:image_out"
|
||||||
// options {
|
// options {
|
||||||
// [mediapipe.tasks.vision.image_embedder.proto.ImageEmbedderOptions.ext]
|
// [mediapipe.tasks.vision.image_embedder.proto.ImageEmbedderOptions.ext]
|
||||||
|
@ -107,7 +105,7 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
|
||||||
graph[Input<Image>(kImageTag)],
|
graph[Input<Image>(kImageTag)],
|
||||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
||||||
output_streams.embedding_result >>
|
output_streams.embedding_result >>
|
||||||
graph[Output<EmbeddingResult>(kEmbeddingResultTag)];
|
graph[Output<EmbeddingResult>(kEmbeddingsTag)];
|
||||||
output_streams.image >> graph[Output<Image>(kImageTag)];
|
output_streams.image >> graph[Output<Image>(kImageTag)];
|
||||||
return graph.GetConfig();
|
return graph.GetConfig();
|
||||||
}
|
}
|
||||||
|
@ -152,16 +150,17 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
|
||||||
// Adds postprocessing calculators and connects its input stream to the
|
// Adds postprocessing calculators and connects its input stream to the
|
||||||
// inference results.
|
// inference results.
|
||||||
auto& postprocessing = graph.AddNode(
|
auto& postprocessing = graph.AddNode(
|
||||||
"mediapipe.tasks.components.EmbeddingPostprocessingGraph");
|
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
|
||||||
MP_RETURN_IF_ERROR(components::ConfigureEmbeddingPostprocessing(
|
MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
|
||||||
model_resources, task_options.embedder_options(),
|
model_resources, task_options.embedder_options(),
|
||||||
&postprocessing.GetOptions<EmbeddingPostprocessingGraphOptions>()));
|
&postprocessing.GetOptions<components::processors::proto::
|
||||||
|
EmbeddingPostprocessingGraphOptions>()));
|
||||||
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
||||||
|
|
||||||
// Outputs the embedding results.
|
// Outputs the embedding results.
|
||||||
return ImageEmbedderOutputStreams{
|
return ImageEmbedderOutputStreams{
|
||||||
/*embedding_result=*/postprocessing[Output<EmbeddingResult>(
|
/*embedding_result=*/postprocessing[Output<EmbeddingResult>(
|
||||||
kEmbeddingResultTag)],
|
kEmbeddingsTag)],
|
||||||
/*image=*/preprocessing[Output<Image>(kImageTag)]};
|
/*image=*/preprocessing[Output<Image>(kImageTag)]};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/gmock.h"
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
|
@ -42,7 +42,6 @@ namespace {
|
||||||
|
|
||||||
using ::mediapipe::file::JoinPath;
|
using ::mediapipe::file::JoinPath;
|
||||||
using ::mediapipe::tasks::components::containers::Rect;
|
using ::mediapipe::tasks::components::containers::Rect;
|
||||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
|
||||||
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
using ::testing::Optional;
|
using ::testing::Optional;
|
||||||
|
@ -54,18 +53,14 @@ constexpr double kSimilarityTolerancy = 1e-6;
|
||||||
|
|
||||||
// Utility function to check the sizes, head_index and head_names of a result
|
// Utility function to check the sizes, head_index and head_names of a result
|
||||||
// procuded by kMobileNetV3Embedder.
|
// procuded by kMobileNetV3Embedder.
|
||||||
void CheckMobileNetV3Result(const EmbeddingResult& result, bool quantized) {
|
void CheckMobileNetV3Result(const ImageEmbedderResult& result, bool quantized) {
|
||||||
EXPECT_EQ(result.embeddings().size(), 1);
|
EXPECT_EQ(result.embeddings.size(), 1);
|
||||||
EXPECT_EQ(result.embeddings(0).head_index(), 0);
|
EXPECT_EQ(result.embeddings[0].head_index, 0);
|
||||||
EXPECT_EQ(result.embeddings(0).head_name(), "feature");
|
EXPECT_EQ(result.embeddings[0].head_name, "feature");
|
||||||
EXPECT_EQ(result.embeddings(0).entries().size(), 1);
|
|
||||||
if (quantized) {
|
if (quantized) {
|
||||||
EXPECT_EQ(
|
EXPECT_EQ(result.embeddings[0].quantized_embedding.size(), 1024);
|
||||||
result.embeddings(0).entries(0).quantized_embedding().values().size(),
|
|
||||||
1024);
|
|
||||||
} else {
|
} else {
|
||||||
EXPECT_EQ(result.embeddings(0).entries(0).float_embedding().values().size(),
|
EXPECT_EQ(result.embeddings[0].float_embedding.size(), 1024);
|
||||||
1024);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -154,7 +149,7 @@ TEST_F(CreateTest, FailsWithIllegalCallbackInImageOrVideoMode) {
|
||||||
options->base_options.model_asset_path =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
|
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
|
||||||
options->running_mode = running_mode;
|
options->running_mode = running_mode;
|
||||||
options->result_callback = [](absl::StatusOr<EmbeddingResult>,
|
options->result_callback = [](absl::StatusOr<ImageEmbedderResult>,
|
||||||
const Image& image, int64 timestamp_ms) {};
|
const Image& image, int64 timestamp_ms) {};
|
||||||
|
|
||||||
auto image_embedder = ImageEmbedder::Create(std::move(options));
|
auto image_embedder = ImageEmbedder::Create(std::move(options));
|
||||||
|
@ -231,19 +226,18 @@ TEST_F(ImageModeTest, SucceedsWithoutL2Normalization) {
|
||||||
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
|
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
|
||||||
|
|
||||||
// Extract both embeddings.
|
// Extract both embeddings.
|
||||||
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
|
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
|
||||||
image_embedder->Embed(image));
|
image_embedder->Embed(image));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
|
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
|
||||||
image_embedder->Embed(crop));
|
image_embedder->Embed(crop));
|
||||||
|
|
||||||
// Check results.
|
// Check results.
|
||||||
CheckMobileNetV3Result(image_result, false);
|
CheckMobileNetV3Result(image_result, false);
|
||||||
CheckMobileNetV3Result(crop_result, false);
|
CheckMobileNetV3Result(crop_result, false);
|
||||||
// CheckCosineSimilarity.
|
// CheckCosineSimilarity.
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
|
||||||
double similarity,
|
image_result.embeddings[0],
|
||||||
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
|
crop_result.embeddings[0]));
|
||||||
crop_result.embeddings(0).entries(0)));
|
|
||||||
double expected_similarity = 0.925519;
|
double expected_similarity = 0.925519;
|
||||||
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||||
}
|
}
|
||||||
|
@ -264,19 +258,18 @@ TEST_F(ImageModeTest, SucceedsWithL2Normalization) {
|
||||||
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
|
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
|
||||||
|
|
||||||
// Extract both embeddings.
|
// Extract both embeddings.
|
||||||
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
|
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
|
||||||
image_embedder->Embed(image));
|
image_embedder->Embed(image));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
|
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
|
||||||
image_embedder->Embed(crop));
|
image_embedder->Embed(crop));
|
||||||
|
|
||||||
// Check results.
|
// Check results.
|
||||||
CheckMobileNetV3Result(image_result, false);
|
CheckMobileNetV3Result(image_result, false);
|
||||||
CheckMobileNetV3Result(crop_result, false);
|
CheckMobileNetV3Result(crop_result, false);
|
||||||
// CheckCosineSimilarity.
|
// CheckCosineSimilarity.
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
|
||||||
double similarity,
|
image_result.embeddings[0],
|
||||||
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
|
crop_result.embeddings[0]));
|
||||||
crop_result.embeddings(0).entries(0)));
|
|
||||||
double expected_similarity = 0.925519;
|
double expected_similarity = 0.925519;
|
||||||
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||||
}
|
}
|
||||||
|
@ -297,19 +290,18 @@ TEST_F(ImageModeTest, SucceedsWithQuantization) {
|
||||||
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
|
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
|
||||||
|
|
||||||
// Extract both embeddings.
|
// Extract both embeddings.
|
||||||
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
|
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
|
||||||
image_embedder->Embed(image));
|
image_embedder->Embed(image));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
|
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
|
||||||
image_embedder->Embed(crop));
|
image_embedder->Embed(crop));
|
||||||
|
|
||||||
// Check results.
|
// Check results.
|
||||||
CheckMobileNetV3Result(image_result, true);
|
CheckMobileNetV3Result(image_result, true);
|
||||||
CheckMobileNetV3Result(crop_result, true);
|
CheckMobileNetV3Result(crop_result, true);
|
||||||
// CheckCosineSimilarity.
|
// CheckCosineSimilarity.
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
|
||||||
double similarity,
|
image_result.embeddings[0],
|
||||||
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
|
crop_result.embeddings[0]));
|
||||||
crop_result.embeddings(0).entries(0)));
|
|
||||||
double expected_similarity = 0.926791;
|
double expected_similarity = 0.926791;
|
||||||
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||||
}
|
}
|
||||||
|
@ -333,19 +325,18 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
|
||||||
|
|
||||||
// Extract both embeddings.
|
// Extract both embeddings.
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
const EmbeddingResult& image_result,
|
const ImageEmbedderResult& image_result,
|
||||||
image_embedder->Embed(image, image_processing_options));
|
image_embedder->Embed(image, image_processing_options));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
|
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
|
||||||
image_embedder->Embed(crop));
|
image_embedder->Embed(crop));
|
||||||
|
|
||||||
// Check results.
|
// Check results.
|
||||||
CheckMobileNetV3Result(image_result, false);
|
CheckMobileNetV3Result(image_result, false);
|
||||||
CheckMobileNetV3Result(crop_result, false);
|
CheckMobileNetV3Result(crop_result, false);
|
||||||
// CheckCosineSimilarity.
|
// CheckCosineSimilarity.
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
|
||||||
double similarity,
|
image_result.embeddings[0],
|
||||||
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
|
crop_result.embeddings[0]));
|
||||||
crop_result.embeddings(0).entries(0)));
|
|
||||||
double expected_similarity = 0.999931;
|
double expected_similarity = 0.999931;
|
||||||
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||||
}
|
}
|
||||||
|
@ -367,20 +358,19 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
|
||||||
image_processing_options.rotation_degrees = -90;
|
image_processing_options.rotation_degrees = -90;
|
||||||
|
|
||||||
// Extract both embeddings.
|
// Extract both embeddings.
|
||||||
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
|
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
|
||||||
image_embedder->Embed(image));
|
image_embedder->Embed(image));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
const EmbeddingResult& rotated_result,
|
const ImageEmbedderResult& rotated_result,
|
||||||
image_embedder->Embed(rotated, image_processing_options));
|
image_embedder->Embed(rotated, image_processing_options));
|
||||||
|
|
||||||
// Check results.
|
// Check results.
|
||||||
CheckMobileNetV3Result(image_result, false);
|
CheckMobileNetV3Result(image_result, false);
|
||||||
CheckMobileNetV3Result(rotated_result, false);
|
CheckMobileNetV3Result(rotated_result, false);
|
||||||
// CheckCosineSimilarity.
|
// CheckCosineSimilarity.
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
|
||||||
double similarity,
|
image_result.embeddings[0],
|
||||||
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
|
rotated_result.embeddings[0]));
|
||||||
rotated_result.embeddings(0).entries(0)));
|
|
||||||
double expected_similarity = 0.572265;
|
double expected_similarity = 0.572265;
|
||||||
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||||
}
|
}
|
||||||
|
@ -403,20 +393,19 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
|
||||||
/*rotation_degrees=*/-90};
|
/*rotation_degrees=*/-90};
|
||||||
|
|
||||||
// Extract both embeddings.
|
// Extract both embeddings.
|
||||||
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
|
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
|
||||||
image_embedder->Embed(crop));
|
image_embedder->Embed(crop));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
const EmbeddingResult& rotated_result,
|
const ImageEmbedderResult& rotated_result,
|
||||||
image_embedder->Embed(rotated, image_processing_options));
|
image_embedder->Embed(rotated, image_processing_options));
|
||||||
|
|
||||||
// Check results.
|
// Check results.
|
||||||
CheckMobileNetV3Result(crop_result, false);
|
CheckMobileNetV3Result(crop_result, false);
|
||||||
CheckMobileNetV3Result(rotated_result, false);
|
CheckMobileNetV3Result(rotated_result, false);
|
||||||
// CheckCosineSimilarity.
|
// CheckCosineSimilarity.
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
|
||||||
double similarity,
|
crop_result.embeddings[0],
|
||||||
ImageEmbedder::CosineSimilarity(crop_result.embeddings(0).entries(0),
|
rotated_result.embeddings[0]));
|
||||||
rotated_result.embeddings(0).entries(0)));
|
|
||||||
double expected_similarity = 0.62838;
|
double expected_similarity = 0.62838;
|
||||||
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||||
}
|
}
|
||||||
|
@ -487,16 +476,16 @@ TEST_F(VideoModeTest, Succeeds) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
|
||||||
ImageEmbedder::Create(std::move(options)));
|
ImageEmbedder::Create(std::move(options)));
|
||||||
|
|
||||||
EmbeddingResult previous_results;
|
ImageEmbedderResult previous_results;
|
||||||
for (int i = 0; i < iterations; ++i) {
|
for (int i = 0; i < iterations; ++i) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results,
|
MP_ASSERT_OK_AND_ASSIGN(auto results,
|
||||||
image_embedder->EmbedForVideo(image, i));
|
image_embedder->EmbedForVideo(image, i));
|
||||||
CheckMobileNetV3Result(results, false);
|
CheckMobileNetV3Result(results, false);
|
||||||
if (i > 0) {
|
if (i > 0) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(double similarity,
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
ImageEmbedder::CosineSimilarity(
|
double similarity,
|
||||||
results.embeddings(0).entries(0),
|
ImageEmbedder::CosineSimilarity(results.embeddings[0],
|
||||||
previous_results.embeddings(0).entries(0)));
|
previous_results.embeddings[0]));
|
||||||
double expected_similarity = 1.000000;
|
double expected_similarity = 1.000000;
|
||||||
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||||
}
|
}
|
||||||
|
@ -515,7 +504,7 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
||||||
options->base_options.model_asset_path =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
|
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
|
||||||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||||
options->result_callback = [](absl::StatusOr<EmbeddingResult>,
|
options->result_callback = [](absl::StatusOr<ImageEmbedderResult>,
|
||||||
const Image& image, int64 timestamp_ms) {};
|
const Image& image, int64 timestamp_ms) {};
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
|
||||||
ImageEmbedder::Create(std::move(options)));
|
ImageEmbedder::Create(std::move(options)));
|
||||||
|
@ -546,7 +535,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
|
||||||
options->base_options.model_asset_path =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
|
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
|
||||||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||||
options->result_callback = [](absl::StatusOr<EmbeddingResult>,
|
options->result_callback = [](absl::StatusOr<ImageEmbedderResult>,
|
||||||
const Image& image, int64 timestamp_ms) {};
|
const Image& image, int64 timestamp_ms) {};
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
|
||||||
ImageEmbedder::Create(std::move(options)));
|
ImageEmbedder::Create(std::move(options)));
|
||||||
|
@ -564,7 +553,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct LiveStreamModeResults {
|
struct LiveStreamModeResults {
|
||||||
EmbeddingResult embedding_result;
|
ImageEmbedderResult embedding_result;
|
||||||
std::pair<int, int> image_size;
|
std::pair<int, int> image_size;
|
||||||
int64 timestamp_ms;
|
int64 timestamp_ms;
|
||||||
};
|
};
|
||||||
|
@ -580,7 +569,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
|
||||||
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
|
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
|
||||||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||||
options->result_callback =
|
options->result_callback =
|
||||||
[&results](absl::StatusOr<EmbeddingResult> embedding_result,
|
[&results](absl::StatusOr<ImageEmbedderResult> embedding_result,
|
||||||
const Image& image, int64 timestamp_ms) {
|
const Image& image, int64 timestamp_ms) {
|
||||||
MP_ASSERT_OK(embedding_result.status());
|
MP_ASSERT_OK(embedding_result.status());
|
||||||
results.push_back(
|
results.push_back(
|
||||||
|
@ -612,8 +601,8 @@ TEST_F(LiveStreamModeTest, Succeeds) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
double similarity,
|
double similarity,
|
||||||
ImageEmbedder::CosineSimilarity(
|
ImageEmbedder::CosineSimilarity(
|
||||||
result.embedding_result.embeddings(0).entries(0),
|
result.embedding_result.embeddings[0],
|
||||||
results[i - 1].embedding_result.embeddings(0).entries(0)));
|
results[i - 1].embedding_result.embeddings[0]));
|
||||||
double expected_similarity = 1.000000;
|
double expected_similarity = 1.000000;
|
||||||
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,7 +24,7 @@ mediapipe_proto_library(
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework:calculator_options_proto",
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
"//mediapipe/framework:calculator_proto",
|
"//mediapipe/framework:calculator_proto",
|
||||||
"//mediapipe/tasks/cc/components/proto:embedder_options_proto",
|
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,7 +18,7 @@ syntax = "proto2";
|
||||||
package mediapipe.tasks.vision.image_embedder.proto;
|
package mediapipe.tasks.vision.image_embedder.proto;
|
||||||
|
|
||||||
import "mediapipe/framework/calculator.proto";
|
import "mediapipe/framework/calculator.proto";
|
||||||
import "mediapipe/tasks/cc/components/proto/embedder_options.proto";
|
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
|
||||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||||
|
|
||||||
message ImageEmbedderGraphOptions {
|
message ImageEmbedderGraphOptions {
|
||||||
|
@ -31,5 +31,5 @@ message ImageEmbedderGraphOptions {
|
||||||
|
|
||||||
// Options for configuring the embedder behavior, such as normalization or
|
// Options for configuring the embedder behavior, such as normalization or
|
||||||
// quantization.
|
// quantization.
|
||||||
optional components.proto.EmbedderOptions embedder_options = 2;
|
optional components.processors.proto.EmbedderOptions embedder_options = 2;
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite",
|
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
|
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite",
|
"//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",
|
"//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user