Update EmbeddingResult format and dependent tasks.
PiperOrigin-RevId: 486186491
This commit is contained in:
parent
66e591d4bc
commit
5e1a2fcdbb
|
@ -61,41 +61,6 @@ cc_library(
|
|||
|
||||
# TODO: Enable this test
|
||||
|
||||
cc_library(
|
||||
name = "embedder_options",
|
||||
srcs = ["embedder_options.cc"],
|
||||
hdrs = ["embedder_options.h"],
|
||||
deps = ["//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "embedding_postprocessing_graph",
|
||||
srcs = ["embedding_postprocessing_graph.cc"],
|
||||
hdrs = ["embedding_postprocessing_graph.h"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/tool:options_map",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
|
||||
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/proto:embedding_postprocessing_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# TODO: Investigate rewriting the build rule to only link
|
||||
# the Bert Preprocessor if it's needed.
|
||||
cc_library(
|
||||
|
|
|
@ -163,7 +163,7 @@ mediapipe_proto_library(
|
|||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/proto:embedder_options_proto",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -178,7 +178,7 @@ cc_library(
|
|||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
],
|
||||
|
|
|
@ -26,14 +26,14 @@
|
|||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry;
|
||||
using ::mediapipe::tasks::components::containers::proto::Embedding;
|
||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||
|
||||
// Computes the inverse L2 norm of the provided array of values. Returns 1.0 in
|
||||
|
@ -66,7 +66,7 @@ float GetInverseL2Norm(const float* values, int size) {
|
|||
class TensorsToEmbeddingsCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"};
|
||||
static constexpr Output<EmbeddingResult> kEmbeddingsOut{"EMBEDDING_RESULT"};
|
||||
static constexpr Output<EmbeddingResult> kEmbeddingsOut{"EMBEDDINGS"};
|
||||
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kEmbeddingsOut);
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
|
@ -77,8 +77,8 @@ class TensorsToEmbeddingsCalculator : public Node {
|
|||
bool quantize_;
|
||||
std::vector<std::string> head_names_;
|
||||
|
||||
void FillFloatEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry);
|
||||
void FillQuantizedEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry);
|
||||
void FillFloatEmbedding(const Tensor& tensor, Embedding* embedding);
|
||||
void FillQuantizedEmbedding(const Tensor& tensor, Embedding* embedding);
|
||||
};
|
||||
|
||||
absl::Status TensorsToEmbeddingsCalculator::Open(CalculatorContext* cc) {
|
||||
|
@ -104,42 +104,42 @@ absl::Status TensorsToEmbeddingsCalculator::Process(CalculatorContext* cc) {
|
|||
for (int i = 0; i < tensors.size(); ++i) {
|
||||
const auto& tensor = tensors[i];
|
||||
RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32);
|
||||
auto* embeddings = result.add_embeddings();
|
||||
embeddings->set_head_index(i);
|
||||
auto* embedding = result.add_embeddings();
|
||||
embedding->set_head_index(i);
|
||||
if (!head_names_.empty()) {
|
||||
embeddings->set_head_name(head_names_[i]);
|
||||
embedding->set_head_name(head_names_[i]);
|
||||
}
|
||||
if (quantize_) {
|
||||
FillQuantizedEmbeddingEntry(tensor, embeddings->add_entries());
|
||||
FillQuantizedEmbedding(tensor, embedding);
|
||||
} else {
|
||||
FillFloatEmbeddingEntry(tensor, embeddings->add_entries());
|
||||
FillFloatEmbedding(tensor, embedding);
|
||||
}
|
||||
}
|
||||
kEmbeddingsOut(cc).Send(result);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void TensorsToEmbeddingsCalculator::FillFloatEmbeddingEntry(
|
||||
const Tensor& tensor, EmbeddingEntry* entry) {
|
||||
void TensorsToEmbeddingsCalculator::FillFloatEmbedding(const Tensor& tensor,
|
||||
Embedding* embedding) {
|
||||
int size = tensor.shape().num_elements();
|
||||
auto tensor_view = tensor.GetCpuReadView();
|
||||
const float* tensor_buffer = tensor_view.buffer<float>();
|
||||
float inv_l2_norm =
|
||||
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
|
||||
auto* float_embedding = entry->mutable_float_embedding();
|
||||
auto* float_embedding = embedding->mutable_float_embedding();
|
||||
for (int i = 0; i < size; ++i) {
|
||||
float_embedding->add_values(tensor_buffer[i] * inv_l2_norm);
|
||||
}
|
||||
}
|
||||
|
||||
void TensorsToEmbeddingsCalculator::FillQuantizedEmbeddingEntry(
|
||||
const Tensor& tensor, EmbeddingEntry* entry) {
|
||||
void TensorsToEmbeddingsCalculator::FillQuantizedEmbedding(
|
||||
const Tensor& tensor, Embedding* embedding) {
|
||||
int size = tensor.shape().num_elements();
|
||||
auto tensor_view = tensor.GetCpuReadView();
|
||||
const float* tensor_buffer = tensor_view.buffer<float>();
|
||||
float inv_l2_norm =
|
||||
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
|
||||
auto* values = entry->mutable_quantized_embedding()->mutable_values();
|
||||
auto* values = embedding->mutable_quantized_embedding()->mutable_values();
|
||||
values->resize(size);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
// Normalize.
|
||||
|
|
|
@ -18,7 +18,7 @@ syntax = "proto2";
|
|||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/tasks/cc/components/proto/embedder_options.proto";
|
||||
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
|
||||
|
||||
message TensorsToEmbeddingsCalculatorOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
|
@ -27,8 +27,8 @@ message TensorsToEmbeddingsCalculatorOptions {
|
|||
|
||||
// The embedder options defining whether to L2-normalize or scalar-quantize
|
||||
// the outputs.
|
||||
optional mediapipe.tasks.components.proto.EmbedderOptions embedder_options =
|
||||
1;
|
||||
optional mediapipe.tasks.components.processors.proto.EmbedderOptions
|
||||
embedder_options = 1;
|
||||
|
||||
// The embedder head names.
|
||||
repeated string head_names = 2;
|
||||
|
|
|
@ -55,7 +55,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, FailsWithInvalidHeadNamesNumber) {
|
|||
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||
calculator: "TensorsToEmbeddingsCalculator"
|
||||
input_stream: "TENSORS:tensors"
|
||||
output_stream: "EMBEDDING_RESULT:embeddings"
|
||||
output_stream: "EMBEDDINGS:embeddings"
|
||||
options {
|
||||
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { head_names: "foo" }
|
||||
}
|
||||
|
@ -73,7 +73,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithoutHeadNames) {
|
|||
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||
calculator: "TensorsToEmbeddingsCalculator"
|
||||
input_stream: "TENSORS:tensors"
|
||||
output_stream: "EMBEDDING_RESULT:embeddings"
|
||||
output_stream: "EMBEDDINGS:embeddings"
|
||||
options {
|
||||
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
|
||||
embedder_options { l2_normalize: false quantize: false }
|
||||
|
@ -84,28 +84,24 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithoutHeadNames) {
|
|||
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const EmbeddingResult& result = runner.Outputs()
|
||||
.Get("EMBEDDING_RESULT", 0)
|
||||
.packets[0]
|
||||
.Get<EmbeddingResult>();
|
||||
EXPECT_THAT(
|
||||
result,
|
||||
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
||||
R"pb(embeddings {
|
||||
entries { float_embedding { values: 0.1 values: 0.2 } }
|
||||
head_index: 0
|
||||
}
|
||||
embeddings {
|
||||
entries { float_embedding { values: -0.2 values: -0.3 } }
|
||||
head_index: 1
|
||||
})pb")));
|
||||
const EmbeddingResult& result =
|
||||
runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
|
||||
EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
||||
R"pb(embeddings {
|
||||
float_embedding { values: 0.1 values: 0.2 }
|
||||
head_index: 0
|
||||
}
|
||||
embeddings {
|
||||
float_embedding { values: -0.2 values: -0.3 }
|
||||
head_index: 1
|
||||
})pb")));
|
||||
}
|
||||
|
||||
TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) {
|
||||
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||
calculator: "TensorsToEmbeddingsCalculator"
|
||||
input_stream: "TENSORS:tensors"
|
||||
output_stream: "EMBEDDING_RESULT:embeddings"
|
||||
output_stream: "EMBEDDINGS:embeddings"
|
||||
options {
|
||||
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
|
||||
embedder_options { l2_normalize: false quantize: false }
|
||||
|
@ -118,30 +114,26 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) {
|
|||
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const EmbeddingResult& result = runner.Outputs()
|
||||
.Get("EMBEDDING_RESULT", 0)
|
||||
.packets[0]
|
||||
.Get<EmbeddingResult>();
|
||||
EXPECT_THAT(
|
||||
result,
|
||||
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
||||
R"pb(embeddings {
|
||||
entries { float_embedding { values: 0.1 values: 0.2 } }
|
||||
head_index: 0
|
||||
head_name: "foo"
|
||||
}
|
||||
embeddings {
|
||||
entries { float_embedding { values: -0.2 values: -0.3 } }
|
||||
head_index: 1
|
||||
head_name: "bar"
|
||||
})pb")));
|
||||
const EmbeddingResult& result =
|
||||
runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
|
||||
EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
||||
R"pb(embeddings {
|
||||
float_embedding { values: 0.1 values: 0.2 }
|
||||
head_index: 0
|
||||
head_name: "foo"
|
||||
}
|
||||
embeddings {
|
||||
float_embedding { values: -0.2 values: -0.3 }
|
||||
head_index: 1
|
||||
head_name: "bar"
|
||||
})pb")));
|
||||
}
|
||||
|
||||
TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) {
|
||||
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||
calculator: "TensorsToEmbeddingsCalculator"
|
||||
input_stream: "TENSORS:tensors"
|
||||
output_stream: "EMBEDDING_RESULT:embeddings"
|
||||
output_stream: "EMBEDDINGS:embeddings"
|
||||
options {
|
||||
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
|
||||
embedder_options { l2_normalize: true quantize: false }
|
||||
|
@ -152,23 +144,17 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) {
|
|||
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const EmbeddingResult& result = runner.Outputs()
|
||||
.Get("EMBEDDING_RESULT", 0)
|
||||
.packets[0]
|
||||
.Get<EmbeddingResult>();
|
||||
const EmbeddingResult& result =
|
||||
runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
|
||||
EXPECT_THAT(
|
||||
result,
|
||||
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
||||
R"pb(embeddings {
|
||||
entries {
|
||||
float_embedding { values: 0.44721356 values: 0.8944271 }
|
||||
}
|
||||
float_embedding { values: 0.44721356 values: 0.8944271 }
|
||||
head_index: 0
|
||||
}
|
||||
embeddings {
|
||||
entries {
|
||||
float_embedding { values: -0.5547002 values: -0.8320503 }
|
||||
}
|
||||
float_embedding { values: -0.5547002 values: -0.8320503 }
|
||||
head_index: 1
|
||||
})pb")));
|
||||
}
|
||||
|
@ -177,7 +163,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithQuantization) {
|
|||
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||
calculator: "TensorsToEmbeddingsCalculator"
|
||||
input_stream: "TENSORS:tensors"
|
||||
output_stream: "EMBEDDING_RESULT:embeddings"
|
||||
output_stream: "EMBEDDINGS:embeddings"
|
||||
options {
|
||||
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
|
||||
embedder_options { l2_normalize: false quantize: true }
|
||||
|
@ -188,22 +174,16 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithQuantization) {
|
|||
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const EmbeddingResult& result = runner.Outputs()
|
||||
.Get("EMBEDDING_RESULT", 0)
|
||||
.packets[0]
|
||||
.Get<EmbeddingResult>();
|
||||
const EmbeddingResult& result =
|
||||
runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
|
||||
EXPECT_THAT(result,
|
||||
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
||||
R"pb(embeddings {
|
||||
entries {
|
||||
quantized_embedding { values: "\x0d\x1a" } # 13,26
|
||||
}
|
||||
quantized_embedding { values: "\x0d\x1a" } # 13,26
|
||||
head_index: 0
|
||||
}
|
||||
embeddings {
|
||||
entries {
|
||||
quantized_embedding { values: "\xe6\xda" } # -26,-38
|
||||
}
|
||||
quantized_embedding { values: "\xe6\xda" } # -26,-38
|
||||
head_index: 1
|
||||
})pb")));
|
||||
}
|
||||
|
@ -213,7 +193,7 @@ TEST(TensorsToEmbeddingsCalculatorTest,
|
|||
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||
calculator: "TensorsToEmbeddingsCalculator"
|
||||
input_stream: "TENSORS:tensors"
|
||||
output_stream: "EMBEDDING_RESULT:embeddings"
|
||||
output_stream: "EMBEDDINGS:embeddings"
|
||||
options {
|
||||
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
|
||||
embedder_options { l2_normalize: true quantize: true }
|
||||
|
@ -224,25 +204,18 @@ TEST(TensorsToEmbeddingsCalculatorTest,
|
|||
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const EmbeddingResult& result = runner.Outputs()
|
||||
.Get("EMBEDDING_RESULT", 0)
|
||||
.packets[0]
|
||||
.Get<EmbeddingResult>();
|
||||
EXPECT_THAT(
|
||||
result,
|
||||
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
||||
R"pb(embeddings {
|
||||
entries {
|
||||
quantized_embedding { values: "\x39\x72" } # 57,114
|
||||
}
|
||||
head_index: 0
|
||||
}
|
||||
embeddings {
|
||||
entries {
|
||||
quantized_embedding { values: "\xb9\x95" } # -71,-107
|
||||
}
|
||||
head_index: 1
|
||||
})pb")));
|
||||
const EmbeddingResult& result =
|
||||
runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
|
||||
EXPECT_THAT(result,
|
||||
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
|
||||
R"pb(embeddings {
|
||||
quantized_embedding { values: "\x39\x72" } # 57,114
|
||||
head_index: 0
|
||||
}
|
||||
embeddings {
|
||||
quantized_embedding { values: "\xb9\x95" } # -71,-107
|
||||
head_index: 1
|
||||
})pb")));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
|
@ -49,3 +49,12 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
],
|
||||
)
|
||||
|
||||
# Plain C++ Embedding / EmbeddingResult structs together with the helpers that
# convert the corresponding protos into them.
cc_library(
    name = "embedding_result",
    srcs = ["embedding_result.cc"],
    hdrs = ["embedding_result.h"],
    deps = ["//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto"],
)
|
||||
|
|
57
mediapipe/tasks/cc/components/containers/embedding_result.cc
Normal file
57
mediapipe/tasks/cc/components/containers/embedding_result.cc
Normal file
|
@ -0,0 +1,57 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
|
||||
#include <iterator>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
|
||||
namespace mediapipe::tasks::components::containers {
|
||||
|
||||
// Converts an Embedding proto into the plain C++ Embedding struct.
//
// Exactly one payload is copied: the float values when `float_embedding` is
// set on the proto, otherwise the quantized bytes (which are empty if the
// proto carries no embedding at all). `head_index` is always copied, while
// `head_name` is only populated when the proto has one.
Embedding ConvertToEmbedding(const proto::Embedding& proto) {
  Embedding embedding;
  if (proto.has_float_embedding()) {
    // `proto` is const, so its elements can only ever be copied; wrapping the
    // iterators in std::make_move_iterator would not move anything.
    const auto& values = proto.float_embedding().values();
    embedding.float_embedding.assign(values.begin(), values.end());
  } else {
    // The quantized payload is a `bytes` field: copy the string in one go.
    embedding.quantized_embedding = proto.quantized_embedding().values();
  }
  embedding.head_index = proto.head_index();
  if (proto.has_head_name()) {
    embedding.head_name = proto.head_name();
  }
  return embedding;
}
|
||||
|
||||
// Converts an EmbeddingResult proto into the plain C++ EmbeddingResult struct.
//
// Every per-head embedding is converted, in order, through
// ConvertToEmbedding(). The timestamp is left unset unless the proto
// explicitly carries one.
EmbeddingResult ConvertToEmbeddingResult(const proto::EmbeddingResult& proto) {
  EmbeddingResult converted;
  const auto& heads = proto.embeddings();
  // Allocate once up front: the number of heads is known.
  converted.embeddings.reserve(proto.embeddings_size());
  for (auto head = heads.begin(); head != heads.end(); ++head) {
    converted.embeddings.push_back(ConvertToEmbedding(*head));
  }
  if (!proto.has_timestamp_ms()) {
    return converted;
  }
  converted.timestamp_ms = proto.timestamp_ms();
  return converted;
}
|
||||
|
||||
} // namespace mediapipe::tasks::components::containers
|
72
mediapipe/tasks/cc/components/containers/embedding_result.h
Normal file
72
mediapipe/tasks/cc/components/containers/embedding_result.h
Normal file
|
@ -0,0 +1,72 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
|
||||
namespace mediapipe::tasks::components::containers {
|
||||
|
||||
// Embedding result for a given embedder head.
//
// One and only one of the two 'float_embedding' and 'quantized_embedding' will
// contain data, based on whether or not the embedder was configured to perform
// scalar quantization.
struct Embedding {
  // Floating-point embedding. Empty if the embedder was configured to perform
  // scalar-quantization.
  std::vector<float> float_embedding;
  // Scalar-quantized embedding. Empty if the embedder was not configured to
  // perform scalar quantization.
  std::string quantized_embedding;
  // The index of the embedder head (i.e. output tensor) this embedding comes
  // from. This is useful for multi-head models.
  //
  // Zero-initialized so that a default-constructed Embedding never exposes an
  // indeterminate value.
  int head_index = 0;
  // The optional name of the embedder head, as provided in the TFLite Model
  // Metadata [1] if present. This is useful for multi-head models.
  //
  // [1]: https://www.tensorflow.org/lite/convert/metadata
  std::optional<std::string> head_name = std::nullopt;
};
|
||||
|
||||
// Defines embedding results of a model.
|
||||
struct EmbeddingResult {
|
||||
// The embedding results for each head of the model.
|
||||
std::vector<Embedding> embeddings;
|
||||
// The optional timestamp (in milliseconds) of the start of the chunk of data
|
||||
// corresponding to these results.
|
||||
//
|
||||
// This is only used for embedding extraction on time series (e.g. audio
|
||||
// embedding). In these use cases, the amount of data to process might
|
||||
// exceed the maximum size that the model can process: to solve this, the
|
||||
// input data is split into multiple chunks starting at different timestamps.
|
||||
std::optional<int64_t> timestamp_ms = std::nullopt;
|
||||
};
|
||||
|
||||
// Utility function to convert from Embedding proto to Embedding struct.
|
||||
Embedding ConvertToEmbedding(const proto::Embedding& proto);
|
||||
|
||||
// Utility function to convert from EmbeddingResult proto to EmbeddingResult
|
||||
// struct.
|
||||
EmbeddingResult ConvertToEmbeddingResult(const proto::EmbeddingResult& proto);
|
||||
|
||||
} // namespace mediapipe::tasks::components::containers
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
|
|
@ -30,30 +30,31 @@ message QuantizedEmbedding {
|
|||
optional bytes values = 1;
|
||||
}
|
||||
|
||||
// Floating-point or scalar-quantized embedding with an optional timestamp.
|
||||
message EmbeddingEntry {
|
||||
// The actual embedding, either floating-point or scalar-quantized.
|
||||
// Embedding result for a given embedder head.
|
||||
message Embedding {
|
||||
// The actual embedding, either floating-point or quantized.
|
||||
oneof embedding {
|
||||
FloatEmbedding float_embedding = 1;
|
||||
QuantizedEmbedding quantized_embedding = 2;
|
||||
}
|
||||
// The optional timestamp (in milliseconds) associated to the embedding entry.
|
||||
// This is useful for time series use cases, e.g. audio embedding.
|
||||
optional int64 timestamp_ms = 3;
|
||||
}
|
||||
|
||||
// Embeddings for a given embedder head.
|
||||
message Embeddings {
|
||||
repeated EmbeddingEntry entries = 1;
|
||||
// The index of the embedder head that produced this embedding. This is useful
|
||||
// for multi-head models.
|
||||
optional int32 head_index = 2;
|
||||
optional int32 head_index = 3;
|
||||
// The name of the embedder head, which is the corresponding tensor metadata
|
||||
// name (if any). This is useful for multi-head models.
|
||||
optional string head_name = 3;
|
||||
optional string head_name = 4;
|
||||
}
|
||||
|
||||
// Contains one set of results per embedder head.
|
||||
// Embedding results for a given embedder model.
|
||||
message EmbeddingResult {
|
||||
repeated Embeddings embeddings = 1;
|
||||
// The embedding results for each model head, i.e. one for each output tensor.
|
||||
repeated Embedding embeddings = 1;
|
||||
// The optional timestamp (in milliseconds) of the start of the chunk of data
|
||||
// corresponding to these results.
|
||||
//
|
||||
// This is only used for embedding extraction on time series (e.g. audio
|
||||
// embedding). In these use cases, the amount of data to process might
|
||||
// exceed the maximum size that the model can process: to solve this, the
|
||||
// input data is split into multiple chunks starting at different timestamps.
|
||||
optional int64 timestamp_ms = 2;
|
||||
}
|
||||
|
|
|
@ -62,3 +62,38 @@ cc_library(
|
|||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# C++ EmbedderOptions struct and its conversion to the EmbedderOptions proto.
cc_library(
    name = "embedder_options",
    srcs = ["embedder_options.cc"],
    hdrs = ["embedder_options.h"],
    deps = [
        "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
    ],
)
|
||||
|
||||
# Subgraph turning model output tensors into an EmbeddingResult.
# NOTE(review): alwayslink presumably keeps the REGISTER_MEDIAPIPE_GRAPH
# registration in the .cc from being dropped by the linker - confirm.
cc_library(
    name = "embedding_postprocessing_graph",
    srcs = ["embedding_postprocessing_graph.cc"],
    hdrs = ["embedding_postprocessing_graph.h"],
    deps = [
        # MediaPipe framework and calculators.
        "//mediapipe/calculators/tensor:tensors_dequantization_calculator",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/tool:options_map",
        # MediaPipe Tasks components.
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
        "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
        "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
        "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
        "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
        "//mediapipe/tasks/cc/components/utils:source_or_node_output",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/metadata:metadata_extractor",
        # Third-party.
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:str_format",
        "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
    ],
    alwayslink = 1,
)
|
||||
|
|
|
@ -13,22 +13,24 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/components/embedder_options.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
|
||||
|
||||
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
namespace processors {
|
||||
|
||||
tasks::components::proto::EmbedderOptions ConvertEmbedderOptionsToProto(
|
||||
proto::EmbedderOptions ConvertEmbedderOptionsToProto(
|
||||
EmbedderOptions* embedder_options) {
|
||||
tasks::components::proto::EmbedderOptions options_proto;
|
||||
proto::EmbedderOptions options_proto;
|
||||
options_proto.set_l2_normalize(embedder_options->l2_normalize);
|
||||
options_proto.set_quantize(embedder_options->quantize);
|
||||
return options_proto;
|
||||
}
|
||||
|
||||
} // namespace processors
|
||||
} // namespace components
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_
|
||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
|
||||
|
||||
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
namespace processors {
|
||||
|
||||
// Embedder options for MediaPipe C++ embedding extraction tasks.
|
||||
struct EmbedderOptions {
|
||||
|
@ -37,11 +38,12 @@ struct EmbedderOptions {
|
|||
bool quantize;
|
||||
};
|
||||
|
||||
tasks::components::proto::EmbedderOptions ConvertEmbedderOptionsToProto(
|
||||
proto::EmbedderOptions ConvertEmbedderOptionsToProto(
|
||||
EmbedderOptions* embedder_options);
|
||||
|
||||
} // namespace processors
|
||||
} // namespace components
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
@ -29,8 +29,8 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||
|
@ -39,6 +39,7 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
namespace processors {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -49,13 +50,12 @@ using ::mediapipe::api2::builder::GenericNode;
|
|||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||
using ::mediapipe::tasks::components::proto::EmbedderOptions;
|
||||
using ::mediapipe::tasks::core::ModelResources;
|
||||
using TensorsSource =
|
||||
::mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;
|
||||
|
||||
constexpr char kTensorsTag[] = "TENSORS";
|
||||
constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT";
|
||||
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
|
||||
|
||||
// Identifies whether or not the model has quantized outputs, and performs
|
||||
// sanity checks.
|
||||
|
@ -144,7 +144,7 @@ absl::StatusOr<std::vector<std::string>> GetHeadNames(
|
|||
|
||||
absl::Status ConfigureEmbeddingPostprocessing(
|
||||
const ModelResources& model_resources,
|
||||
const EmbedderOptions& embedder_options,
|
||||
const proto::EmbedderOptions& embedder_options,
|
||||
proto::EmbeddingPostprocessingGraphOptions* options) {
|
||||
ASSIGN_OR_RETURN(bool has_quantized_outputs,
|
||||
HasQuantizedOutputs(model_resources));
|
||||
|
@ -188,7 +188,7 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
|
|||
BuildEmbeddingPostprocessing(
|
||||
sc->Options<proto::EmbeddingPostprocessingGraphOptions>(),
|
||||
graph[Input<std::vector<Tensor>>(kTensorsTag)], graph));
|
||||
embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingResultTag)];
|
||||
embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
|
@ -220,13 +220,13 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
|
|||
.GetOptions<mediapipe::TensorsToEmbeddingsCalculatorOptions>()
|
||||
.CopyFrom(options.tensors_to_embeddings_options());
|
||||
dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag);
|
||||
return tensors_to_embeddings_node[Output<EmbeddingResult>(
|
||||
kEmbeddingResultTag)];
|
||||
return tensors_to_embeddings_node[Output<EmbeddingResult>(kEmbeddingsTag)];
|
||||
}
|
||||
};
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::components::EmbeddingPostprocessingGraph);
|
||||
::mediapipe::tasks::components::processors::EmbeddingPostprocessingGraph);
|
||||
|
||||
} // namespace processors
|
||||
} // namespace components
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -13,17 +13,18 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_
|
||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
namespace processors {
|
||||
|
||||
// Configures an EmbeddingPostprocessingGraph using the provided model resources
|
||||
// and EmbedderOptions.
|
||||
|
@ -44,18 +45,19 @@ namespace components {
|
|||
// The output tensors of an InferenceCalculator, to convert into
|
||||
// EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
|
||||
// Outputs:
|
||||
// EMBEDDING_RESULT - EmbeddingResult
|
||||
// EMBEDDINGS - EmbeddingResult
|
||||
// The output EmbeddingResult.
|
||||
//
|
||||
// TODO: add support for additional optional "TIMESTAMPS" input for
|
||||
// embeddings aggregation.
|
||||
absl::Status ConfigureEmbeddingPostprocessing(
|
||||
const tasks::core::ModelResources& model_resources,
|
||||
const tasks::components::proto::EmbedderOptions& embedder_options,
|
||||
const proto::EmbedderOptions& embedder_options,
|
||||
proto::EmbeddingPostprocessingGraphOptions* options);
|
||||
|
||||
} // namespace processors
|
||||
} // namespace components
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
|
@ -25,8 +25,8 @@ limitations under the License.
|
|||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
|
@ -34,12 +34,10 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
namespace processors {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::mediapipe::tasks::components::proto::EmbedderOptions;
|
||||
using ::mediapipe::tasks::components::proto::
|
||||
EmbeddingPostprocessingGraphOptions;
|
||||
using ::mediapipe::tasks::core::ModelResources;
|
||||
|
||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
|
||||
|
@ -69,68 +67,72 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
||||
EmbedderOptions options_in;
|
||||
proto::EmbedderOptions options_in;
|
||||
options_in.set_l2_normalize(true);
|
||||
|
||||
EmbeddingPostprocessingGraphOptions options_out;
|
||||
proto::EmbeddingPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
|
||||
&options_out));
|
||||
|
||||
EXPECT_THAT(
|
||||
options_out,
|
||||
EqualsProto(ParseTextProtoOrDie<EmbeddingPostprocessingGraphOptions>(
|
||||
R"pb(tensors_to_embeddings_options {
|
||||
embedder_options { l2_normalize: true }
|
||||
head_names: "probability"
|
||||
}
|
||||
has_quantized_outputs: true)pb")));
|
||||
EqualsProto(
|
||||
ParseTextProtoOrDie<proto::EmbeddingPostprocessingGraphOptions>(
|
||||
R"pb(tensors_to_embeddings_options {
|
||||
embedder_options { l2_normalize: true }
|
||||
head_names: "probability"
|
||||
}
|
||||
has_quantized_outputs: true)pb")));
|
||||
}
|
||||
|
||||
TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
|
||||
EmbedderOptions options_in;
|
||||
proto::EmbedderOptions options_in;
|
||||
options_in.set_quantize(true);
|
||||
|
||||
EmbeddingPostprocessingGraphOptions options_out;
|
||||
proto::EmbeddingPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
|
||||
&options_out));
|
||||
|
||||
EXPECT_THAT(
|
||||
options_out,
|
||||
EqualsProto(ParseTextProtoOrDie<EmbeddingPostprocessingGraphOptions>(
|
||||
R"pb(tensors_to_embeddings_options {
|
||||
embedder_options { quantize: true }
|
||||
}
|
||||
has_quantized_outputs: true)pb")));
|
||||
EqualsProto(
|
||||
ParseTextProtoOrDie<proto::EmbeddingPostprocessingGraphOptions>(
|
||||
R"pb(tensors_to_embeddings_options {
|
||||
embedder_options { quantize: true }
|
||||
}
|
||||
has_quantized_outputs: true)pb")));
|
||||
}
|
||||
|
||||
TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
|
||||
CreateModelResourcesForModel(kMobileNetV3Embedder));
|
||||
EmbedderOptions options_in;
|
||||
proto::EmbedderOptions options_in;
|
||||
options_in.set_quantize(true);
|
||||
options_in.set_l2_normalize(true);
|
||||
|
||||
EmbeddingPostprocessingGraphOptions options_out;
|
||||
proto::EmbeddingPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
|
||||
&options_out));
|
||||
|
||||
EXPECT_THAT(
|
||||
options_out,
|
||||
EqualsProto(ParseTextProtoOrDie<EmbeddingPostprocessingGraphOptions>(
|
||||
R"pb(tensors_to_embeddings_options {
|
||||
embedder_options { quantize: true l2_normalize: true }
|
||||
head_names: "feature"
|
||||
}
|
||||
has_quantized_outputs: false)pb")));
|
||||
EqualsProto(
|
||||
ParseTextProtoOrDie<proto::EmbeddingPostprocessingGraphOptions>(
|
||||
R"pb(tensors_to_embeddings_options {
|
||||
embedder_options { quantize: true l2_normalize: true }
|
||||
head_names: "feature"
|
||||
}
|
||||
has_quantized_outputs: false)pb")));
|
||||
}
|
||||
|
||||
// TODO: add E2E Postprocessing tests once timestamp aggregation is
|
||||
// supported.
|
||||
|
||||
} // namespace
|
||||
} // namespace processors
|
||||
} // namespace components
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -34,3 +34,18 @@ mediapipe_proto_library(
|
|||
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "embedder_options_proto",
|
||||
srcs = ["embedder_options.proto"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "embedding_postprocessing_graph_options_proto",
|
||||
srcs = ["embedding_postprocessing_graph_options.proto"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -15,7 +15,10 @@ limitations under the License.
|
|||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe.tasks.components.proto;
|
||||
package mediapipe.tasks.components.processors.proto;
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks.components.processors.proto";
|
||||
option java_outer_classname = "EmbedderOptionsProto";
|
||||
|
||||
// Shared options used by all embedding extraction tasks.
|
||||
message EmbedderOptions {
|
|
@ -15,7 +15,7 @@ limitations under the License.
|
|||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe.tasks.components.proto;
|
||||
package mediapipe.tasks.components.processors.proto;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto";
|
|
@ -23,21 +23,6 @@ mediapipe_proto_library(
|
|||
srcs = ["segmenter_options.proto"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "embedder_options_proto",
|
||||
srcs = ["embedder_options.proto"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "embedding_postprocessing_graph_options_proto",
|
||||
srcs = ["embedding_postprocessing_graph_options.proto"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "text_preprocessing_graph_options_proto",
|
||||
srcs = ["text_preprocessing_graph_options.proto"],
|
||||
|
|
|
@ -26,7 +26,7 @@ cc_library(
|
|||
hdrs = ["cosine_similarity.h"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers:embedding_result",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
|
@ -39,7 +39,7 @@ cc_test(
|
|||
deps = [
|
||||
":cosine_similarity",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers:embedding_result",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ limitations under the License.
|
|||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -30,7 +30,7 @@ namespace utils {
|
|||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry;
|
||||
using ::mediapipe::tasks::components::containers::Embedding;
|
||||
|
||||
template <typename T>
|
||||
absl::StatusOr<double> ComputeCosineSimilarity(const T& u, const T& v,
|
||||
|
@ -66,39 +66,35 @@ absl::StatusOr<double> ComputeCosineSimilarity(const T& u, const T& v,
|
|||
// an L2-norm of 0.
|
||||
//
|
||||
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
|
||||
absl::StatusOr<double> CosineSimilarity(const EmbeddingEntry& u,
|
||||
const EmbeddingEntry& v) {
|
||||
if (u.has_float_embedding() && v.has_float_embedding()) {
|
||||
if (u.float_embedding().values().size() !=
|
||||
v.float_embedding().values().size()) {
|
||||
absl::StatusOr<double> CosineSimilarity(const Embedding& u,
|
||||
const Embedding& v) {
|
||||
if (!u.float_embedding.empty() && !v.float_embedding.empty()) {
|
||||
if (u.float_embedding.size() != v.float_embedding.size()) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::StrFormat("Cannot compute cosine similarity between embeddings "
|
||||
"of different sizes (%d vs. %d)",
|
||||
u.float_embedding().values().size(),
|
||||
v.float_embedding().values().size()),
|
||||
u.float_embedding.size(), v.float_embedding.size()),
|
||||
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||
}
|
||||
return ComputeCosineSimilarity(u.float_embedding().values().data(),
|
||||
v.float_embedding().values().data(),
|
||||
u.float_embedding().values().size());
|
||||
return ComputeCosineSimilarity(u.float_embedding.data(),
|
||||
v.float_embedding.data(),
|
||||
u.float_embedding.size());
|
||||
}
|
||||
if (u.has_quantized_embedding() && v.has_quantized_embedding()) {
|
||||
if (u.quantized_embedding().values().size() !=
|
||||
v.quantized_embedding().values().size()) {
|
||||
if (!u.quantized_embedding.empty() && !v.quantized_embedding.empty()) {
|
||||
if (u.quantized_embedding.size() != v.quantized_embedding.size()) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::StrFormat("Cannot compute cosine similarity between embeddings "
|
||||
"of different sizes (%d vs. %d)",
|
||||
u.quantized_embedding().values().size(),
|
||||
v.quantized_embedding().values().size()),
|
||||
u.quantized_embedding.size(),
|
||||
v.quantized_embedding.size()),
|
||||
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||
}
|
||||
return ComputeCosineSimilarity(reinterpret_cast<const int8_t*>(
|
||||
u.quantized_embedding().values().data()),
|
||||
reinterpret_cast<const int8_t*>(
|
||||
v.quantized_embedding().values().data()),
|
||||
u.quantized_embedding().values().size());
|
||||
return ComputeCosineSimilarity(
|
||||
reinterpret_cast<const int8_t*>(u.quantized_embedding.data()),
|
||||
reinterpret_cast<const int8_t*>(v.quantized_embedding.data()),
|
||||
u.quantized_embedding.size());
|
||||
}
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
|
|
|
@ -17,22 +17,20 @@ limitations under the License.
|
|||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_COSINE_SIMILARITY_H_
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
namespace utils {
|
||||
|
||||
// Utility function to compute cosine similarity [1] between two embedding
|
||||
// entries. May return an InvalidArgumentError if e.g. the feature vectors are
|
||||
// of different types (quantized vs. float), have different sizes, or have a
|
||||
// an L2-norm of 0.
|
||||
// Utility function to compute cosine similarity [1] between two embeddings. May
|
||||
// return an InvalidArgumentError if e.g. the embeddings are of different types
|
||||
// (quantized vs. float), have different sizes, or have an L2-norm of 0.
|
||||
//
|
||||
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
|
||||
absl::StatusOr<double> CosineSimilarity(
|
||||
const containers::proto::EmbeddingEntry& u,
|
||||
const containers::proto::EmbeddingEntry& v);
|
||||
absl::StatusOr<double> CosineSimilarity(const containers::Embedding& u,
|
||||
const containers::Embedding& v);
|
||||
|
||||
} // namespace utils
|
||||
} // namespace components
|
||||
|
|
|
@ -22,7 +22,7 @@ limitations under the License.
|
|||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -30,29 +30,27 @@ namespace components {
|
|||
namespace utils {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry;
|
||||
using ::mediapipe::tasks::components::containers::Embedding;
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
// Helper function to generate float EmbeddingEntry.
|
||||
EmbeddingEntry BuildFloatEntry(std::vector<float> values) {
|
||||
EmbeddingEntry entry;
|
||||
for (const float value : values) {
|
||||
entry.mutable_float_embedding()->add_values(value);
|
||||
}
|
||||
return entry;
|
||||
// Helper function to generate float Embedding.
|
||||
Embedding BuildFloatEmbedding(std::vector<float> values) {
|
||||
Embedding embedding;
|
||||
embedding.float_embedding = values;
|
||||
return embedding;
|
||||
}
|
||||
|
||||
// Helper function to generate quantized EmbeddingEntry.
|
||||
EmbeddingEntry BuildQuantizedEntry(std::vector<int8_t> values) {
|
||||
EmbeddingEntry entry;
|
||||
entry.mutable_quantized_embedding()->set_values(
|
||||
reinterpret_cast<uint8_t*>(values.data()), values.size());
|
||||
return entry;
|
||||
// Helper function to generate quantized Embedding.
|
||||
Embedding BuildQuantizedEmbedding(std::vector<int8_t> values) {
|
||||
Embedding embedding;
|
||||
uint8_t* data = reinterpret_cast<uint8_t*>(values.data());
|
||||
embedding.quantized_embedding = {data, data + values.size()};
|
||||
return embedding;
|
||||
}
|
||||
|
||||
TEST(CosineSimilarity, FailsWithQuantizedAndFloatEmbeddings) {
|
||||
auto u = BuildFloatEntry({0.1, 0.2});
|
||||
auto v = BuildQuantizedEntry({0, 1});
|
||||
auto u = BuildFloatEmbedding({0.1, 0.2});
|
||||
auto v = BuildQuantizedEmbedding({0, 1});
|
||||
|
||||
auto status = CosineSimilarity(u, v);
|
||||
|
||||
|
@ -63,8 +61,8 @@ TEST(CosineSimilarity, FailsWithQuantizedAndFloatEmbeddings) {
|
|||
}
|
||||
|
||||
TEST(CosineSimilarity, FailsWithZeroNorm) {
|
||||
auto u = BuildFloatEntry({0.1, 0.2});
|
||||
auto v = BuildFloatEntry({0.0, 0.0});
|
||||
auto u = BuildFloatEmbedding({0.1, 0.2});
|
||||
auto v = BuildFloatEmbedding({0.0, 0.0});
|
||||
|
||||
auto status = CosineSimilarity(u, v);
|
||||
|
||||
|
@ -75,8 +73,8 @@ TEST(CosineSimilarity, FailsWithZeroNorm) {
|
|||
}
|
||||
|
||||
TEST(CosineSimilarity, FailsWithDifferentSizes) {
|
||||
auto u = BuildFloatEntry({0.1, 0.2});
|
||||
auto v = BuildFloatEntry({0.1, 0.2, 0.3});
|
||||
auto u = BuildFloatEmbedding({0.1, 0.2});
|
||||
auto v = BuildFloatEmbedding({0.1, 0.2, 0.3});
|
||||
|
||||
auto status = CosineSimilarity(u, v);
|
||||
|
||||
|
@ -87,8 +85,8 @@ TEST(CosineSimilarity, FailsWithDifferentSizes) {
|
|||
}
|
||||
|
||||
TEST(CosineSimilarity, SucceedsWithFloatEntries) {
|
||||
auto u = BuildFloatEntry({1.0, 0.0, 0.0, 0.0});
|
||||
auto v = BuildFloatEntry({0.5, 0.5, 0.5, 0.5});
|
||||
auto u = BuildFloatEmbedding({1.0, 0.0, 0.0, 0.0});
|
||||
auto v = BuildFloatEmbedding({0.5, 0.5, 0.5, 0.5});
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v));
|
||||
|
||||
|
@ -96,8 +94,8 @@ TEST(CosineSimilarity, SucceedsWithFloatEntries) {
|
|||
}
|
||||
|
||||
TEST(CosineSimilarity, SucceedsWithQuantizedEntries) {
|
||||
auto u = BuildQuantizedEntry({127, 0, 0, 0});
|
||||
auto v = BuildQuantizedEntry({-128, 0, 0, 0});
|
||||
auto u = BuildQuantizedEmbedding({127, 0, 0, 0});
|
||||
auto v = BuildQuantizedEmbedding({-128, 0, 0, 0});
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v));
|
||||
|
||||
|
|
|
@ -26,12 +26,12 @@ cc_library(
|
|||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/tasks/cc/components:embedding_postprocessing_graph",
|
||||
"//mediapipe/tasks/cc/components:image_preprocessing",
|
||||
"//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/proto:embedding_postprocessing_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
|
@ -49,9 +49,10 @@ cc_library(
|
|||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/tool:options_map",
|
||||
"//mediapipe/tasks/cc/components:embedder_options",
|
||||
"//mediapipe/tasks/cc/components/containers:embedding_result",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors:embedder_options",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/utils:cosine_similarity",
|
||||
"//mediapipe/tasks/cc/core:base_options",
|
||||
"//mediapipe/tasks/cc/core:task_runner",
|
||||
|
|
|
@ -21,9 +21,10 @@ limitations under the License.
|
|||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/tool/options_map.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/embedder_options.h"
|
||||
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||
|
@ -41,8 +42,8 @@ namespace image_embedder {
|
|||
|
||||
namespace {
|
||||
|
||||
constexpr char kEmbeddingResultStreamName[] = "embedding_result_out";
|
||||
constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT";
|
||||
constexpr char kEmbeddingsStreamName[] = "embeddings_out";
|
||||
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
|
||||
constexpr char kImageInStreamName[] = "image_in";
|
||||
constexpr char kImageOutStreamName[] = "image_out";
|
||||
constexpr char kImageTag[] = "IMAGE";
|
||||
|
@ -53,7 +54,7 @@ constexpr char kGraphTypeName[] =
|
|||
"mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph";
|
||||
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||
|
||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry;
|
||||
using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult;
|
||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||
using ::mediapipe::tasks::core::PacketMap;
|
||||
using ::mediapipe::tasks::vision::image_embedder::proto::
|
||||
|
@ -71,13 +72,13 @@ CalculatorGraphConfig CreateGraphConfig(
|
|||
graph.In(kNormRectTag).SetName(kNormRectStreamName);
|
||||
auto& task_graph = graph.AddNode(kGraphTypeName);
|
||||
task_graph.GetOptions<ImageEmbedderGraphOptions>().Swap(options_proto.get());
|
||||
task_graph.Out(kEmbeddingResultTag).SetName(kEmbeddingResultStreamName) >>
|
||||
graph.Out(kEmbeddingResultTag);
|
||||
task_graph.Out(kEmbeddingsTag).SetName(kEmbeddingsStreamName) >>
|
||||
graph.Out(kEmbeddingsTag);
|
||||
task_graph.Out(kImageTag).SetName(kImageOutStreamName) >>
|
||||
graph.Out(kImageTag);
|
||||
if (enable_flow_limiting) {
|
||||
return tasks::core::AddFlowLimiterCalculator(
|
||||
graph, task_graph, {kImageTag, kNormRectTag}, kEmbeddingResultTag);
|
||||
graph, task_graph, {kImageTag, kNormRectTag}, kEmbeddingsTag);
|
||||
}
|
||||
graph.In(kImageTag) >> task_graph.In(kImageTag);
|
||||
graph.In(kNormRectTag) >> task_graph.In(kNormRectTag);
|
||||
|
@ -95,8 +96,8 @@ std::unique_ptr<ImageEmbedderGraphOptions> ConvertImageEmbedderOptionsToProto(
|
|||
options_proto->mutable_base_options()->set_use_stream_mode(
|
||||
options->running_mode != core::RunningMode::IMAGE);
|
||||
auto embedder_options_proto =
|
||||
std::make_unique<tasks::components::proto::EmbedderOptions>(
|
||||
components::ConvertEmbedderOptionsToProto(
|
||||
std::make_unique<components::processors::proto::EmbedderOptions>(
|
||||
components::processors::ConvertEmbedderOptionsToProto(
|
||||
&(options->embedder_options)));
|
||||
options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get());
|
||||
return options_proto;
|
||||
|
@ -121,9 +122,10 @@ absl::StatusOr<std::unique_ptr<ImageEmbedder>> ImageEmbedder::Create(
|
|||
return;
|
||||
}
|
||||
Packet embedding_result_packet =
|
||||
status_or_packets.value()[kEmbeddingResultStreamName];
|
||||
status_or_packets.value()[kEmbeddingsStreamName];
|
||||
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
|
||||
result_callback(embedding_result_packet.Get<EmbeddingResult>(),
|
||||
result_callback(ConvertToEmbeddingResult(
|
||||
embedding_result_packet.Get<EmbeddingResult>()),
|
||||
image_packet.Get<Image>(),
|
||||
embedding_result_packet.Timestamp().Value() /
|
||||
kMicroSecondsPerMilliSecond);
|
||||
|
@ -138,7 +140,7 @@ absl::StatusOr<std::unique_ptr<ImageEmbedder>> ImageEmbedder::Create(
|
|||
std::move(packets_callback));
|
||||
}
|
||||
|
||||
absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
|
||||
absl::StatusOr<ImageEmbedderResult> ImageEmbedder::Embed(
|
||||
Image image,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||
if (image.UsesGpu()) {
|
||||
|
@ -155,10 +157,11 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
|
|||
{{kImageInStreamName, MakePacket<Image>(std::move(image))},
|
||||
{kNormRectStreamName,
|
||||
MakePacket<NormalizedRect>(std::move(norm_rect))}}));
|
||||
return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>();
|
||||
return ConvertToEmbeddingResult(
|
||||
output_packets[kEmbeddingsStreamName].Get<EmbeddingResult>());
|
||||
}
|
||||
|
||||
absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
|
||||
absl::StatusOr<ImageEmbedderResult> ImageEmbedder::EmbedForVideo(
|
||||
Image image, int64 timestamp_ms,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||
if (image.UsesGpu()) {
|
||||
|
@ -178,7 +181,8 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
|
|||
{kNormRectStreamName,
|
||||
MakePacket<NormalizedRect>(std::move(norm_rect))
|
||||
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
|
||||
return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>();
|
||||
return ConvertToEmbeddingResult(
|
||||
output_packets[kEmbeddingsStreamName].Get<EmbeddingResult>());
|
||||
}
|
||||
|
||||
absl::Status ImageEmbedder::EmbedAsync(
|
||||
|
@ -202,7 +206,8 @@ absl::Status ImageEmbedder::EmbedAsync(
|
|||
}
|
||||
|
||||
absl::StatusOr<double> ImageEmbedder::CosineSimilarity(
|
||||
const EmbeddingEntry& u, const EmbeddingEntry& v) {
|
||||
const components::containers::Embedding& u,
|
||||
const components::containers::Embedding& v) {
|
||||
return components::utils::CosineSimilarity(u, v);
|
||||
}
|
||||
|
||||
|
|
|
@ -21,8 +21,8 @@ limitations under the License.
|
|||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/embedder_options.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
|
@ -33,6 +33,10 @@ namespace tasks {
|
|||
namespace vision {
|
||||
namespace image_embedder {
|
||||
|
||||
// Alias the shared EmbeddingResult struct as result type.
|
||||
using ImageEmbedderResult =
|
||||
::mediapipe::tasks::components::containers::EmbeddingResult;
|
||||
|
||||
// The options for configuring a MediaPipe image embedder task.
|
||||
struct ImageEmbedderOptions {
|
||||
// Base options for configuring MediaPipe Tasks, such as specifying the model
|
||||
|
@ -50,14 +54,12 @@ struct ImageEmbedderOptions {
|
|||
|
||||
// Options for configuring the embedder behavior, such as L2-normalization or
|
||||
// scalar-quantization.
|
||||
components::EmbedderOptions embedder_options;
|
||||
components::processors::EmbedderOptions embedder_options;
|
||||
|
||||
// The user-defined result callback for processing live stream data.
|
||||
// The result callback should only be specified when the running mode is set
|
||||
// to RunningMode::LIVE_STREAM.
|
||||
std::function<void(
|
||||
absl::StatusOr<components::containers::proto::EmbeddingResult>,
|
||||
const Image&, int64)>
|
||||
std::function<void(absl::StatusOr<ImageEmbedderResult>, const Image&, int64)>
|
||||
result_callback = nullptr;
|
||||
};
|
||||
|
||||
|
@ -104,7 +106,7 @@ class ImageEmbedder : core::BaseVisionTaskApi {
|
|||
// running mode.
|
||||
//
|
||||
// The image can be of any size with format RGB or RGBA.
|
||||
absl::StatusOr<components::containers::proto::EmbeddingResult> Embed(
|
||||
absl::StatusOr<ImageEmbedderResult> Embed(
|
||||
mediapipe::Image image,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||
std::nullopt);
|
||||
|
@ -127,7 +129,7 @@ class ImageEmbedder : core::BaseVisionTaskApi {
|
|||
// The image can be of any size with format RGB or RGBA. It's required to
|
||||
// provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||
// must be monotonically increasing.
|
||||
absl::StatusOr<components::containers::proto::EmbeddingResult> EmbedForVideo(
|
||||
absl::StatusOr<ImageEmbedderResult> EmbedForVideo(
|
||||
mediapipe::Image image, int64 timestamp_ms,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||
std::nullopt);
|
||||
|
@ -168,15 +170,15 @@ class ImageEmbedder : core::BaseVisionTaskApi {
|
|||
// Shuts down the ImageEmbedder when all work is done.
|
||||
absl::Status Close() { return runner_->Close(); }
|
||||
|
||||
// Utility function to compute cosine similarity [1] between two embedding
|
||||
// entries. May return an InvalidArgumentError if e.g. the feature vectors are
|
||||
// of different types (quantized vs. float), have different sizes, or have
|
||||
// an L2-norm of 0.
|
||||
// Utility function to compute cosine similarity [1] between two embeddings.
|
||||
// May return an InvalidArgumentError if e.g. the embeddings are of different
|
||||
// types (quantized vs. float), have different sizes, or have an L2-norm of
|
||||
// 0.
|
||||
//
|
||||
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
|
||||
static absl::StatusOr<double> CosineSimilarity(
|
||||
const components::containers::proto::EmbeddingEntry& u,
|
||||
const components::containers::proto::EmbeddingEntry& v);
|
||||
const components::containers::Embedding& u,
|
||||
const components::containers::Embedding& v);
|
||||
};
|
||||
|
||||
} // namespace image_embedder
|
||||
|
|
|
@ -20,10 +20,10 @@ limitations under the License.
|
|||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h"
|
||||
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
|
||||
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h"
|
||||
|
||||
|
@ -40,10 +40,8 @@ using ::mediapipe::api2::builder::GenericNode;
|
|||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||
using ::mediapipe::tasks::components::proto::
|
||||
EmbeddingPostprocessingGraphOptions;
|
||||
|
||||
constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT";
|
||||
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
|
||||
constexpr char kImageTag[] = "IMAGE";
|
||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||
constexpr char kTensorsTag[] = "TENSORS";
|
||||
|
@ -67,7 +65,7 @@ struct ImageEmbedderOutputStreams {
|
|||
// Describes region of image to perform embedding extraction on.
|
||||
// @Optional: rect covering the whole image is used if not specified.
|
||||
// Outputs:
|
||||
// EMBEDDING_RESULT - EmbeddingResult
|
||||
// EMBEDDINGS - EmbeddingResult
|
||||
// The embedding result.
|
||||
// IMAGE - Image
|
||||
// The image that embedding extraction runs on.
|
||||
|
@ -76,7 +74,7 @@ struct ImageEmbedderOutputStreams {
|
|||
// node {
|
||||
// calculator: "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"
|
||||
// input_stream: "IMAGE:image_in"
|
||||
// output_stream: "EMBEDDING_RESULT:embedding_result_out"
|
||||
// output_stream: "EMBEDDINGS:embedding_result_out"
|
||||
// output_stream: "IMAGE:image_out"
|
||||
// options {
|
||||
// [mediapipe.tasks.vision.image_embedder.proto.ImageEmbedderOptions.ext]
|
||||
|
@ -107,7 +105,7 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
|
|||
graph[Input<Image>(kImageTag)],
|
||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
||||
output_streams.embedding_result >>
|
||||
graph[Output<EmbeddingResult>(kEmbeddingResultTag)];
|
||||
graph[Output<EmbeddingResult>(kEmbeddingsTag)];
|
||||
output_streams.image >> graph[Output<Image>(kImageTag)];
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
@ -152,16 +150,17 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
|
|||
// Adds postprocessing calculators and connects its input stream to the
|
||||
// inference results.
|
||||
auto& postprocessing = graph.AddNode(
|
||||
"mediapipe.tasks.components.EmbeddingPostprocessingGraph");
|
||||
MP_RETURN_IF_ERROR(components::ConfigureEmbeddingPostprocessing(
|
||||
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
|
||||
MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
|
||||
model_resources, task_options.embedder_options(),
|
||||
&postprocessing.GetOptions<EmbeddingPostprocessingGraphOptions>()));
|
||||
&postprocessing.GetOptions<components::processors::proto::
|
||||
EmbeddingPostprocessingGraphOptions>()));
|
||||
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
||||
|
||||
// Outputs the embedding results.
|
||||
return ImageEmbedderOutputStreams{
|
||||
/*embedding_result=*/postprocessing[Output<EmbeddingResult>(
|
||||
kEmbeddingResultTag)],
|
||||
kEmbeddingsTag)],
|
||||
/*image=*/preprocessing[Output<Image>(kImageTag)]};
|
||||
}
|
||||
};
|
||||
|
|
|
@ -26,7 +26,7 @@ limitations under the License.
|
|||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
|
@ -42,7 +42,6 @@ namespace {
|
|||
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::mediapipe::tasks::components::containers::Rect;
|
||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::Optional;
|
||||
|
@ -54,18 +53,14 @@ constexpr double kSimilarityTolerancy = 1e-6;
|
|||
|
||||
// Utility function to check the sizes, head_index and head_names of a result
|
||||
// produced by kMobileNetV3Embedder.
|
||||
void CheckMobileNetV3Result(const EmbeddingResult& result, bool quantized) {
|
||||
EXPECT_EQ(result.embeddings().size(), 1);
|
||||
EXPECT_EQ(result.embeddings(0).head_index(), 0);
|
||||
EXPECT_EQ(result.embeddings(0).head_name(), "feature");
|
||||
EXPECT_EQ(result.embeddings(0).entries().size(), 1);
|
||||
void CheckMobileNetV3Result(const ImageEmbedderResult& result, bool quantized) {
|
||||
EXPECT_EQ(result.embeddings.size(), 1);
|
||||
EXPECT_EQ(result.embeddings[0].head_index, 0);
|
||||
EXPECT_EQ(result.embeddings[0].head_name, "feature");
|
||||
if (quantized) {
|
||||
EXPECT_EQ(
|
||||
result.embeddings(0).entries(0).quantized_embedding().values().size(),
|
||||
1024);
|
||||
EXPECT_EQ(result.embeddings[0].quantized_embedding.size(), 1024);
|
||||
} else {
|
||||
EXPECT_EQ(result.embeddings(0).entries(0).float_embedding().values().size(),
|
||||
1024);
|
||||
EXPECT_EQ(result.embeddings[0].float_embedding.size(), 1024);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -154,7 +149,7 @@ TEST_F(CreateTest, FailsWithIllegalCallbackInImageOrVideoMode) {
|
|||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
|
||||
options->running_mode = running_mode;
|
||||
options->result_callback = [](absl::StatusOr<EmbeddingResult>,
|
||||
options->result_callback = [](absl::StatusOr<ImageEmbedderResult>,
|
||||
const Image& image, int64 timestamp_ms) {};
|
||||
|
||||
auto image_embedder = ImageEmbedder::Create(std::move(options));
|
||||
|
@ -231,19 +226,18 @@ TEST_F(ImageModeTest, SucceedsWithoutL2Normalization) {
|
|||
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
|
||||
|
||||
// Extract both embeddings.
|
||||
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
|
||||
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
|
||||
image_embedder->Embed(image));
|
||||
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
|
||||
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
|
||||
image_embedder->Embed(crop));
|
||||
|
||||
// Check results.
|
||||
CheckMobileNetV3Result(image_result, false);
|
||||
CheckMobileNetV3Result(crop_result, false);
|
||||
// CheckCosineSimilarity.
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
double similarity,
|
||||
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
|
||||
crop_result.embeddings(0).entries(0)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
|
||||
image_result.embeddings[0],
|
||||
crop_result.embeddings[0]));
|
||||
double expected_similarity = 0.925519;
|
||||
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||
}
|
||||
|
@ -264,19 +258,18 @@ TEST_F(ImageModeTest, SucceedsWithL2Normalization) {
|
|||
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
|
||||
|
||||
// Extract both embeddings.
|
||||
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
|
||||
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
|
||||
image_embedder->Embed(image));
|
||||
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
|
||||
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
|
||||
image_embedder->Embed(crop));
|
||||
|
||||
// Check results.
|
||||
CheckMobileNetV3Result(image_result, false);
|
||||
CheckMobileNetV3Result(crop_result, false);
|
||||
// CheckCosineSimilarity.
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
double similarity,
|
||||
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
|
||||
crop_result.embeddings(0).entries(0)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
|
||||
image_result.embeddings[0],
|
||||
crop_result.embeddings[0]));
|
||||
double expected_similarity = 0.925519;
|
||||
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||
}
|
||||
|
@ -297,19 +290,18 @@ TEST_F(ImageModeTest, SucceedsWithQuantization) {
|
|||
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
|
||||
|
||||
// Extract both embeddings.
|
||||
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
|
||||
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
|
||||
image_embedder->Embed(image));
|
||||
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
|
||||
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
|
||||
image_embedder->Embed(crop));
|
||||
|
||||
// Check results.
|
||||
CheckMobileNetV3Result(image_result, true);
|
||||
CheckMobileNetV3Result(crop_result, true);
|
||||
// CheckCosineSimilarity.
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
double similarity,
|
||||
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
|
||||
crop_result.embeddings(0).entries(0)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
|
||||
image_result.embeddings[0],
|
||||
crop_result.embeddings[0]));
|
||||
double expected_similarity = 0.926791;
|
||||
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||
}
|
||||
|
@ -333,19 +325,18 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
|
|||
|
||||
// Extract both embeddings.
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
const EmbeddingResult& image_result,
|
||||
const ImageEmbedderResult& image_result,
|
||||
image_embedder->Embed(image, image_processing_options));
|
||||
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
|
||||
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
|
||||
image_embedder->Embed(crop));
|
||||
|
||||
// Check results.
|
||||
CheckMobileNetV3Result(image_result, false);
|
||||
CheckMobileNetV3Result(crop_result, false);
|
||||
// CheckCosineSimilarity.
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
double similarity,
|
||||
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
|
||||
crop_result.embeddings(0).entries(0)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
|
||||
image_result.embeddings[0],
|
||||
crop_result.embeddings[0]));
|
||||
double expected_similarity = 0.999931;
|
||||
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||
}
|
||||
|
@ -367,20 +358,19 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
|
|||
image_processing_options.rotation_degrees = -90;
|
||||
|
||||
// Extract both embeddings.
|
||||
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
|
||||
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
|
||||
image_embedder->Embed(image));
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
const EmbeddingResult& rotated_result,
|
||||
const ImageEmbedderResult& rotated_result,
|
||||
image_embedder->Embed(rotated, image_processing_options));
|
||||
|
||||
// Check results.
|
||||
CheckMobileNetV3Result(image_result, false);
|
||||
CheckMobileNetV3Result(rotated_result, false);
|
||||
// CheckCosineSimilarity.
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
double similarity,
|
||||
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
|
||||
rotated_result.embeddings(0).entries(0)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
|
||||
image_result.embeddings[0],
|
||||
rotated_result.embeddings[0]));
|
||||
double expected_similarity = 0.572265;
|
||||
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||
}
|
||||
|
@ -403,20 +393,19 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
|
|||
/*rotation_degrees=*/-90};
|
||||
|
||||
// Extract both embeddings.
|
||||
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
|
||||
MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
|
||||
image_embedder->Embed(crop));
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
const EmbeddingResult& rotated_result,
|
||||
const ImageEmbedderResult& rotated_result,
|
||||
image_embedder->Embed(rotated, image_processing_options));
|
||||
|
||||
// Check results.
|
||||
CheckMobileNetV3Result(crop_result, false);
|
||||
CheckMobileNetV3Result(rotated_result, false);
|
||||
// CheckCosineSimilarity.
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
double similarity,
|
||||
ImageEmbedder::CosineSimilarity(crop_result.embeddings(0).entries(0),
|
||||
rotated_result.embeddings(0).entries(0)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
|
||||
crop_result.embeddings[0],
|
||||
rotated_result.embeddings[0]));
|
||||
double expected_similarity = 0.62838;
|
||||
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||
}
|
||||
|
@ -487,16 +476,16 @@ TEST_F(VideoModeTest, Succeeds) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
|
||||
ImageEmbedder::Create(std::move(options)));
|
||||
|
||||
EmbeddingResult previous_results;
|
||||
ImageEmbedderResult previous_results;
|
||||
for (int i = 0; i < iterations; ++i) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto results,
|
||||
image_embedder->EmbedForVideo(image, i));
|
||||
CheckMobileNetV3Result(results, false);
|
||||
if (i > 0) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(double similarity,
|
||||
ImageEmbedder::CosineSimilarity(
|
||||
results.embeddings(0).entries(0),
|
||||
previous_results.embeddings(0).entries(0)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
double similarity,
|
||||
ImageEmbedder::CosineSimilarity(results.embeddings[0],
|
||||
previous_results.embeddings[0]));
|
||||
double expected_similarity = 1.000000;
|
||||
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||
}
|
||||
|
@ -515,7 +504,7 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
|||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
|
||||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||
options->result_callback = [](absl::StatusOr<EmbeddingResult>,
|
||||
options->result_callback = [](absl::StatusOr<ImageEmbedderResult>,
|
||||
const Image& image, int64 timestamp_ms) {};
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
|
||||
ImageEmbedder::Create(std::move(options)));
|
||||
|
@ -546,7 +535,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
|
|||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
|
||||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||
options->result_callback = [](absl::StatusOr<EmbeddingResult>,
|
||||
options->result_callback = [](absl::StatusOr<ImageEmbedderResult>,
|
||||
const Image& image, int64 timestamp_ms) {};
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
|
||||
ImageEmbedder::Create(std::move(options)));
|
||||
|
@ -564,7 +553,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
|
|||
}
|
||||
|
||||
struct LiveStreamModeResults {
|
||||
EmbeddingResult embedding_result;
|
||||
ImageEmbedderResult embedding_result;
|
||||
std::pair<int, int> image_size;
|
||||
int64 timestamp_ms;
|
||||
};
|
||||
|
@ -580,7 +569,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
|
|||
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
|
||||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||
options->result_callback =
|
||||
[&results](absl::StatusOr<EmbeddingResult> embedding_result,
|
||||
[&results](absl::StatusOr<ImageEmbedderResult> embedding_result,
|
||||
const Image& image, int64 timestamp_ms) {
|
||||
MP_ASSERT_OK(embedding_result.status());
|
||||
results.push_back(
|
||||
|
@ -612,8 +601,8 @@ TEST_F(LiveStreamModeTest, Succeeds) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
double similarity,
|
||||
ImageEmbedder::CosineSimilarity(
|
||||
result.embedding_result.embeddings(0).entries(0),
|
||||
results[i - 1].embedding_result.embeddings(0).entries(0)));
|
||||
result.embedding_result.embeddings[0],
|
||||
results[i - 1].embedding_result.embeddings[0]));
|
||||
double expected_similarity = 1.000000;
|
||||
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ mediapipe_proto_library(
|
|||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/proto:embedder_options_proto",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -18,7 +18,7 @@ syntax = "proto2";
|
|||
package mediapipe.tasks.vision.image_embedder.proto;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/tasks/cc/components/proto/embedder_options.proto";
|
||||
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
message ImageEmbedderGraphOptions {
|
||||
|
@ -31,5 +31,5 @@ message ImageEmbedderGraphOptions {
|
|||
|
||||
// Options for configuring the embedder behavior, such as normalization or
|
||||
// quantization.
|
||||
optional components.proto.EmbedderOptions embedder_options = 2;
|
||||
optional components.processors.proto.EmbedderOptions embedder_options = 2;
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
|||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",
|
||||
|
|
Loading…
Reference in New Issue
Block a user