Add timestamp aggregation to EmbeddingPostprocessingGraph.

PiperOrigin-RevId: 487463848
This commit is contained in:
MediaPipe Team 2022-11-10 01:19:00 -08:00 committed by Copybara-Service
parent f11c757629
commit 0ac604d507
8 changed files with 576 additions and 21 deletions

View File

@ -200,3 +200,38 @@ cc_test(
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
], ],
) )
# Library target for EmbeddingAggregationCalculator, built from
# embedding_aggregation_calculator.cc.
cc_library(
name = "embedding_aggregation_calculator",
srcs = ["embedding_aggregation_calculator.cc"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"@com_google_absl//absl/status",
],
# NOTE(review): presumably needed so the calculator's registration (done via
# MEDIAPIPE_REGISTER_NODE in the .cc file) is kept even though nothing
# references the library's symbols directly - confirm.
alwayslink = 1,
)
# Unit test for EmbeddingAggregationCalculator. It depends on the calculator
# library itself so that the node can be looked up by name when the test builds
# its graph.
cc_test(
name = "embedding_aggregation_calculator_test",
srcs = ["embedding_aggregation_calculator_test.cc"],
deps = [
":embedding_aggregation_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:output_stream_poller",
"//mediapipe/framework:packet",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
],
)

View File

@ -0,0 +1,132 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <unordered_map>
#include <vector>
#include "absl/status/status.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
namespace mediapipe {
namespace api2 {
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
// Aggregates EmbeddingResult packets into a vector of timestamped
// EmbeddingResult. Acts as a pass-through if no timestamp aggregation is
// needed.
//
// Inputs:
// EMBEDDINGS: EmbeddingResult
// The EmbeddingResult packets to aggregate.
// TIMESTAMPS: std::vector<Timestamp> @Optional.
// The collection of timestamps that this calculator should aggregate. This
// stream is optional: if provided then the TIMESTAMPED_EMBEDDINGS output
// will contain the aggregated results. Otherwise, as no timestamp
// aggregation is required, the EMBEDDINGS output is used to pass the input
// EmbeddingResults through unchanged.
//
// Outputs:
// EMBEDDINGS: EmbeddingResult @Optional
// The input EmbeddingResult, unchanged. Must be connected if the TIMESTAMPS
// input is not connected, as it signals that timestamp aggregation is not
// required.
// TIMESTAMPED_EMBEDDINGS: std::vector<EmbeddingResult> @Optional
// The embedding results aggregated by timestamp. Must be connected if the
// TIMESTAMPS input is connected as it signals that timestamp aggregation is
// required.
//
// Example without timestamp aggregation (pass-through):
// node {
// calculator: "EmbeddingAggregationCalculator"
// input_stream: "EMBEDDINGS:embeddings_in"
// output_stream: "EMBEDDINGS:embeddings_out"
// }
//
// Example with timestamp aggregation:
// node {
// calculator: "EmbeddingAggregationCalculator"
// input_stream: "EMBEDDINGS:embeddings_in"
// input_stream: "TIMESTAMPS:timestamps_in"
// output_stream: "TIMESTAMPED_EMBEDDINGS:timestamped_embeddings_out"
// }
class EmbeddingAggregationCalculator : public Node {
 public:
  // Required input: one EmbeddingResult per input timestamp.
  static constexpr Input<EmbeddingResult> kEmbeddingsIn{"EMBEDDINGS"};
  // Optional input: the list of timestamps whose cached results must be
  // emitted together. Connecting it switches the node to aggregation mode.
  static constexpr Input<std::vector<Timestamp>>::Optional kTimestampsIn{
      "TIMESTAMPS"};
  // Optional output used in pass-through mode (TIMESTAMPS not connected).
  static constexpr Output<EmbeddingResult>::Optional kEmbeddingsOut{
      "EMBEDDINGS"};
  // Optional output used in aggregation mode (TIMESTAMPS connected).
  static constexpr Output<std::vector<EmbeddingResult>>::Optional
      kTimestampedEmbeddingsOut{"TIMESTAMPED_EMBEDDINGS"};
  MEDIAPIPE_NODE_CONTRACT(kEmbeddingsIn, kTimestampsIn, kEmbeddingsOut,
                          kTimestampedEmbeddingsOut);

  // Checks that the output matching the selected mode is connected.
  static absl::Status UpdateContract(CalculatorContract* cc);
  // Records whether timestamp aggregation is enabled for this run.
  absl::Status Open(CalculatorContext* cc) override;
  // Forwards or caches/aggregates the incoming EmbeddingResult.
  absl::Status Process(CalculatorContext* cc) override;

 private:
  // True when the TIMESTAMPS input is connected. Assigned in Open(); the
  // initializer only guarantees a defined value before that point.
  bool time_aggregation_enabled_ = false;
  // Results received so far in aggregation mode, keyed by the raw value of the
  // input timestamp at which they arrived.
  std::unordered_map<int64, EmbeddingResult> cached_embeddings_;
};
// Validates the node wiring: aggregation mode (TIMESTAMPS connected) requires
// the TIMESTAMPED_EMBEDDINGS output, pass-through mode requires the EMBEDDINGS
// output.
absl::Status EmbeddingAggregationCalculator::UpdateContract(
    CalculatorContract* cc) {
  const bool aggregation_requested = kTimestampsIn(cc).IsConnected();
  if (!aggregation_requested) {
    RET_CHECK(kEmbeddingsOut(cc).IsConnected());
    return absl::OkStatus();
  }
  RET_CHECK(kTimestampedEmbeddingsOut(cc).IsConnected());
  return absl::OkStatus();
}
// Latches the operating mode once per run: aggregation is enabled exactly when
// the TIMESTAMPS input stream is connected.
absl::Status EmbeddingAggregationCalculator::Open(CalculatorContext* cc) {
  const bool timestamps_connected = kTimestampsIn(cc).IsConnected();
  time_aggregation_enabled_ = timestamps_connected;
  return absl::OkStatus();
}
// Handles the packets available at the current input timestamp.
//
// Pass-through mode (TIMESTAMPS not connected): the EMBEDDINGS input packet is
// forwarded as-is on the EMBEDDINGS output.
//
// Aggregation mode (TIMESTAMPS connected): the incoming EmbeddingResult is
// cached under the raw value of the current input timestamp. If a TIMESTAMPS
// packet is also present, the cached results for the listed timestamps are
// removed from the cache and emitted, in the listed order, as one vector on
// TIMESTAMPED_EMBEDDINGS. Each emitted result has `timestamp_ms` set to the
// difference between its timestamp value and the first listed timestamp value,
// divided by 1000. A listed timestamp with no cached result yields a
// default-constructed EmbeddingResult (same as the previous behavior).
//
// Returns an error if, after emitting, results are still left in the cache, or
// if the EMBEDDINGS input is empty in aggregation mode.
absl::Status EmbeddingAggregationCalculator::Process(CalculatorContext* cc) {
  if (time_aggregation_enabled_) {
    // Dereferencing an empty input is not recoverable, so report it as a
    // status instead. This can happen if a TIMESTAMPS packet arrives at a
    // timestamp that carries no EMBEDDINGS packet.
    RET_CHECK(!kEmbeddingsIn(cc).IsEmpty())
        << "EMBEDDINGS input is empty at the current timestamp.";
    // The input packet is immutable, so this is necessarily a copy.
    cached_embeddings_[cc->InputTimestamp().Value()] = *kEmbeddingsIn(cc);
    if (kTimestampsIn(cc).IsEmpty()) {
      // Nothing to emit yet; keep accumulating.
      return absl::OkStatus();
    }
    // Bind by reference: the list is only read, no need to copy it.
    const std::vector<Timestamp>& timestamps = kTimestampsIn(cc).Get();
    std::vector<EmbeddingResult> results;
    results.reserve(timestamps.size());
    for (const Timestamp& timestamp : timestamps) {
      const auto key = timestamp.Value();
      EmbeddingResult result;
      // Single lookup; the entry is moved out and erased through the iterator.
      const auto cached = cached_embeddings_.find(key);
      if (cached != cached_embeddings_.end()) {
        result = std::move(cached->second);
        cached_embeddings_.erase(cached);
      }
      result.set_timestamp_ms((key - timestamps[0].Value()) / 1000);
      results.push_back(std::move(result));
    }
    kTimestampedEmbeddingsOut(cc).Send(std::move(results));
  } else {
    kEmbeddingsOut(cc).Send(kEmbeddingsIn(cc));
  }
  // Every cached result must have been consumed by the aggregation above.
  RET_CHECK(cached_embeddings_.empty());
  return absl::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(EmbeddingAggregationCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,158 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <optional>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/output_stream_poller.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe {
namespace {
using ::mediapipe::ParseTextProtoOrDie;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::testing::Pointwise;
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kEmbeddingsInName[] = "embeddings_in";
constexpr char kEmbeddingsOutName[] = "embeddings_out";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
constexpr char kTimestampsName[] = "timestamps_in";
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings_out";
// Fixture that runs a single EmbeddingAggregationCalculator inside a
// CalculatorGraph and provides helpers to feed it and to read its output.
class EmbeddingAggregationCalculatorTest : public tflite_shims::testing::Test {
 protected:
  // Builds and starts a graph around the calculator. When `connect_timestamps`
  // is true the TIMESTAMPS input and TIMESTAMPED_EMBEDDINGS output are wired;
  // otherwise only the EMBEDDINGS output is. Returns a poller attached to
  // whichever output was wired.
  absl::StatusOr<OutputStreamPoller> BuildGraph(bool connect_timestamps) {
    Graph graph;
    auto& calculator = graph.AddNode("EmbeddingAggregationCalculator");
    graph[Input<EmbeddingResult>(kEmbeddingsTag)].SetName(kEmbeddingsInName) >>
        calculator.In(kEmbeddingsTag);
    if (!connect_timestamps) {
      calculator.Out(kEmbeddingsTag).SetName(kEmbeddingsOutName) >>
          graph[Output<EmbeddingResult>(kEmbeddingsTag)];
    } else {
      graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
          kTimestampsName) >>
          calculator.In(kTimestampsTag);
      calculator.Out(kTimestampedEmbeddingsTag)
              .SetName(kTimestampedEmbeddingsName) >>
          graph[Output<std::vector<EmbeddingResult>>(
              kTimestampedEmbeddingsTag)];
    }
    MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));

    // Poll the one output stream that was connected above.
    const char* const polled_stream =
        connect_timestamps ? kTimestampedEmbeddingsName : kEmbeddingsOutName;
    ASSIGN_OR_RETURN(auto poller,
                     calculator_graph_.AddOutputStreamPoller(polled_stream));
    MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
    return poller;
  }

  // Sends `embeddings` on the EMBEDDINGS input at `timestamp`. If
  // `aggregation_timestamps` is set, also sends it (converted to Timestamps)
  // on the TIMESTAMPS input at the same `timestamp`.
  absl::Status Send(
      const EmbeddingResult& embeddings, int timestamp = 0,
      std::optional<std::vector<int>> aggregation_timestamps = std::nullopt) {
    const Timestamp packet_time(timestamp);
    MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
        kEmbeddingsInName,
        MakePacket<EmbeddingResult>(embeddings).At(packet_time)));
    if (!aggregation_timestamps.has_value()) {
      return absl::OkStatus();
    }
    auto to_aggregate = std::make_unique<std::vector<Timestamp>>();
    for (const int value : *aggregation_timestamps) {
      to_aggregate->emplace_back(Timestamp(value));
    }
    MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
        kTimestampsName, Adopt(to_aggregate.release()).At(packet_time)));
    return absl::OkStatus();
  }

  // Waits for the graph to go idle, closes its inputs, and returns a copy of
  // the next packet payload from `poller`, then waits for the run to finish.
  template <typename T>
  absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
    MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
    Packet output_packet;
    const bool got_packet = poller.Next(&output_packet);
    if (!got_packet) {
      return absl::InternalError("Unable to get output packet");
    }
    auto payload = output_packet.Get<T>();
    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
    return payload;
  }

 private:
  CalculatorGraph calculator_graph_;
};
// Pass-through mode: with TIMESTAMPS unconnected, the EmbeddingResult sent in
// must come out unchanged on the EMBEDDINGS output.
TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithoutTimestamps) {
  MP_ASSERT_OK_AND_ASSIGN(auto poller,
                          BuildGraph(/*connect_timestamps=*/false));
  const EmbeddingResult input = ParseTextProtoOrDie<EmbeddingResult>(
      R"pb(embeddings { head_index: 0 })pb");

  MP_ASSERT_OK(Send(input));

  MP_ASSERT_OK_AND_ASSIGN(auto output, GetResult<EmbeddingResult>(poller));
  EXPECT_THAT(output, EqualsProto(input));
}
// Aggregation mode: two results are sent at timestamps 0 and 1000; the second
// one carries the aggregation request {0, 1000}. Both must be emitted together
// with timestamp_ms values 0 and 1 respectively.
TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithTimestamps) {
  MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true));

  const EmbeddingResult first_input =
      ParseTextProtoOrDie<EmbeddingResult>(R"pb(embeddings {
head_index: 0
})pb");
  const EmbeddingResult second_input = ParseTextProtoOrDie<EmbeddingResult>(
      R"pb(embeddings { head_index: 1 })pb");

  // First result only gets cached (default timestamp 0, no TIMESTAMPS packet).
  MP_ASSERT_OK(Send(first_input));
  // Second result triggers the aggregation of both timestamps.
  MP_ASSERT_OK(Send(
      second_input,
      /*timestamp=*/1000,
      /*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000})));

  MP_ASSERT_OK_AND_ASSIGN(auto results,
                          GetResult<std::vector<EmbeddingResult>>(poller));
  EXPECT_THAT(results,
              Pointwise(EqualsProto(), {ParseTextProtoOrDie<EmbeddingResult>(
                                            R"pb(embeddings { head_index: 0 }
timestamp_ms: 0)pb"),
                                        ParseTextProtoOrDie<EmbeddingResult>(
                                            R"pb(embeddings { head_index: 1 }
timestamp_ms: 1)pb")}));
}
} // namespace
} // namespace mediapipe

View File

@ -82,6 +82,7 @@ cc_library(
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/framework/tool:options_map", "//mediapipe/framework/tool:options_map",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/calculators:embedding_aggregation_calculator",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto", "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",

View File

@ -56,6 +56,14 @@ using TensorsSource =
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
constexpr char kEmbeddingsTag[] = "EMBEDDINGS"; constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
// Struct holding the different output streams produced by the graph.
struct EmbeddingPostprocessingOutputStreams {
// Stream of single EmbeddingResult packets (filled from the aggregation
// node's EMBEDDINGS output).
Source<EmbeddingResult> embeddings;
// Stream of EmbeddingResult vectors (filled from the aggregation node's
// TIMESTAMPED_EMBEDDINGS output).
Source<std::vector<EmbeddingResult>> timestamped_embeddings;
};
// Identifies whether or not the model has quantized outputs, and performs // Identifies whether or not the model has quantized outputs, and performs
// sanity checks. // sanity checks.
@ -168,27 +176,39 @@ absl::Status ConfigureEmbeddingPostprocessing(
// TENSORS - std::vector<Tensor> // TENSORS - std::vector<Tensor>
// The output tensors of an InferenceCalculator, to convert into // The output tensors of an InferenceCalculator, to convert into
// EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8. // EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
// TIMESTAMPS - std::vector<Timestamp> @Optional
// The collection of the timestamps that this calculator should aggregate.
// This stream is optional: if provided then the TIMESTAMPED_EMBEDDINGS
// output is used for results. Otherwise as no timestamp aggregation is
// required the EMBEDDINGS output is used for results.
//
// Outputs: // Outputs:
// EMBEDDING_RESULT - EmbeddingResult // EMBEDDINGS - EmbeddingResult @Optional
// The output EmbeddingResult. // The embedding results aggregated by head. Must be connected if the
// TIMESTAMPS input is not connected, as it signals that timestamp
// aggregation is not required.
// TIMESTAMPED_EMBEDDINGS - std::vector<EmbeddingResult> @Optional
// The embedding result aggregated by timestamp, then by head. Must be
// connected if the TIMESTAMPS input is connected, as it signals that
// timestamp aggregation is required.
// //
// The recommended way of using this graph is through the GraphBuilder API using // The recommended way of using this graph is through the GraphBuilder API using
// the 'ConfigureEmbeddingPostprocessing()' function. See header file for more // the 'ConfigureEmbeddingPostprocessing()' function. See header file for more
// details. // details.
//
// TODO: add support for additional optional "TIMESTAMPS" input for
// embeddings aggregation.
class EmbeddingPostprocessingGraph : public mediapipe::Subgraph { class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
public: public:
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig( absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
mediapipe::SubgraphContext* sc) override { mediapipe::SubgraphContext* sc) override {
Graph graph; Graph graph;
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto embedding_result_out, auto output_streams,
BuildEmbeddingPostprocessing( BuildEmbeddingPostprocessing(
sc->Options<proto::EmbeddingPostprocessingGraphOptions>(), sc->Options<proto::EmbeddingPostprocessingGraphOptions>(),
graph[Input<std::vector<Tensor>>(kTensorsTag)], graph)); graph[Input<std::vector<Tensor>>(kTensorsTag)],
embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingsTag)]; graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
output_streams.embeddings >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
output_streams.timestamped_embeddings >>
graph[Output<std::vector<EmbeddingResult>>(kTimestampedEmbeddingsTag)];
return graph.GetConfig(); return graph.GetConfig();
} }
@ -200,10 +220,14 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
// //
// options: the on-device EmbeddingPostprocessingGraphOptions // options: the on-device EmbeddingPostprocessingGraphOptions
// tensors_in: (std::vector<mediapipe::Tensor>) tensors to postprocess. // tensors_in: (std::vector<mediapipe::Tensor>) tensors to postprocess.
// timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
// timestamps that should be used to aggregate embedding results.
// graph: the mediapipe builder::Graph instance to be updated. // graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<Source<EmbeddingResult>> BuildEmbeddingPostprocessing( absl::StatusOr<EmbeddingPostprocessingOutputStreams>
BuildEmbeddingPostprocessing(
const proto::EmbeddingPostprocessingGraphOptions options, const proto::EmbeddingPostprocessingGraphOptions options,
Source<std::vector<Tensor>> tensors_in, Graph& graph) { Source<std::vector<Tensor>> tensors_in,
Source<std::vector<Timestamp>> timestamps_in, Graph& graph) {
// If output tensors are quantized, they must be dequantized first. // If output tensors are quantized, they must be dequantized first.
TensorsSource dequantized_tensors(&tensors_in); TensorsSource dequantized_tensors(&tensors_in);
if (options.has_quantized_outputs()) { if (options.has_quantized_outputs()) {
@ -220,7 +244,20 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
.GetOptions<mediapipe::TensorsToEmbeddingsCalculatorOptions>() .GetOptions<mediapipe::TensorsToEmbeddingsCalculatorOptions>()
.CopyFrom(options.tensors_to_embeddings_options()); .CopyFrom(options.tensors_to_embeddings_options());
dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag); dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag);
return tensors_to_embeddings_node[Output<EmbeddingResult>(kEmbeddingsTag)];
// Adds EmbeddingAggregationCalculator.
GenericNode& aggregation_node =
graph.AddNode("EmbeddingAggregationCalculator");
tensors_to_embeddings_node[Output<EmbeddingResult>(kEmbeddingsTag)] >>
aggregation_node.In(kEmbeddingsTag);
timestamps_in >> aggregation_node.In(kTimestampsTag);
// Connects outputs.
return EmbeddingPostprocessingOutputStreams{
/*embeddings=*/aggregation_node[Output<EmbeddingResult>(
kEmbeddingsTag)],
/*timestamped_embeddings=*/aggregation_node
[Output<std::vector<EmbeddingResult>>(kTimestampedEmbeddingsTag)]};
} }
}; };
REGISTER_MEDIAPIPE_GRAPH( REGISTER_MEDIAPIPE_GRAPH(

View File

@ -44,12 +44,20 @@ namespace processors {
// TENSORS - std::vector<Tensor> // TENSORS - std::vector<Tensor>
// The output tensors of an InferenceCalculator, to convert into // The output tensors of an InferenceCalculator, to convert into
// EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8. // EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
// TIMESTAMPS - std::vector<Timestamp> @Optional
// The collection of the timestamps that this calculator should aggregate.
// This stream is optional: if provided then the TIMESTAMPED_EMBEDDINGS
// output is used for results. Otherwise as no timestamp aggregation is
// required the EMBEDDINGS output is used for results.
// Outputs: // Outputs:
// EMBEDDINGS - EmbeddingResult // EMBEDDINGS - EmbeddingResult @Optional
// The output EmbeddingResult. // The embedding results aggregated by head. Must be connected if the
// // TIMESTAMPS input is not connected, as it signals that timestamp
// TODO: add support for additional optional "TIMESTAMPS" input for // aggregation is not required.
// embeddings aggregation. // TIMESTAMPED_EMBEDDINGS - std::vector<EmbeddingResult> @Optional
// The embedding result aggregated by timestamp, then by head. Must be
// connected if the TIMESTAMPS input is connected, as it signals that
// timestamp aggregation is required.
absl::Status ConfigureEmbeddingPostprocessing( absl::Status ConfigureEmbeddingPostprocessing(
const tasks::core::ModelResources& model_resources, const tasks::core::ModelResources& model_resources,
const proto::EmbedderOptions& embedder_options, const proto::EmbedderOptions& embedder_options,

View File

@ -20,11 +20,20 @@ limitations under the License.
#include "absl/flags/flag.h" #include "absl/flags/flag.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/graph_runner.h"
#include "mediapipe/framework/output_stream_poller.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
@ -37,7 +46,12 @@ namespace components {
namespace processors { namespace processors {
namespace { namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::ModelResources;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/"; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
@ -51,6 +65,16 @@ constexpr char kQuantizedImageClassifierWithoutMetadata[] =
"vision/mobilenet_v1_0.25_192_quantized_1_default_1.tflite"; "vision/mobilenet_v1_0.25_192_quantized_1_default_1.tflite";
constexpr char kTestModelResourcesTag[] = "test_model_resources"; constexpr char kTestModelResourcesTag[] = "test_model_resources";
constexpr int kMobileNetV3EmbedderEmbeddingSize = 1024;
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTensorsName[] = "tensors";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
constexpr char kTimestampsName[] = "timestamps";
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kEmbeddingsName[] = "embeddings";
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings";
// Helper function to get ModelResources. // Helper function to get ModelResources.
absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel( absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
@ -128,8 +152,171 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
has_quantized_outputs: false)pb"))); has_quantized_outputs: false)pb")));
} }
// TODO: add E2E Postprocessing tests once timestamp aggregation is class PostprocessingTest : public tflite_shims::testing::Test {
// supported. protected:
// Builds and starts a graph made of a single EmbeddingPostprocessingGraph
// configured for `model_name` with `options`. When `connect_timestamps` is
// true the TIMESTAMPS input and TIMESTAMPED_EMBEDDINGS output are wired;
// otherwise only the EMBEDDINGS output is. Returns a poller attached to
// whichever output was wired.
absl::StatusOr<OutputStreamPoller> BuildGraph(
    absl::string_view model_name, const proto::EmbedderOptions& options,
    bool connect_timestamps = false) {
  ASSIGN_OR_RETURN(auto model_resources,
                   CreateModelResourcesForModel(model_name));

  Graph graph;
  auto& postprocessing = graph.AddNode(
      "mediapipe.tasks.components.processors."
      "EmbeddingPostprocessingGraph");
  auto* postprocessing_options =
      &postprocessing.GetOptions<proto::EmbeddingPostprocessingGraphOptions>();
  MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessing(
      *model_resources, options, postprocessing_options));

  graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
      postprocessing.In(kTensorsTag);
  if (!connect_timestamps) {
    postprocessing.Out(kEmbeddingsTag).SetName(kEmbeddingsName) >>
        graph[Output<EmbeddingResult>(kEmbeddingsTag)];
  } else {
    graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
        kTimestampsName) >>
        postprocessing.In(kTimestampsTag);
    postprocessing.Out(kTimestampedEmbeddingsTag)
            .SetName(kTimestampedEmbeddingsName) >>
        graph[Output<std::vector<EmbeddingResult>>(
            kTimestampedEmbeddingsTag)];
  }
  MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));

  // Poll the one output stream that was connected above.
  const char* const polled_stream =
      connect_timestamps ? kTimestampedEmbeddingsName : kEmbeddingsName;
  ASSIGN_OR_RETURN(auto poller,
                   calculator_graph_.AddOutputStreamPoller(polled_stream));
  MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
  return poller;
}
// Appends a tensor of shape {1, tensor.size()} with the given element type
// and quantization parameters to the pending input tensors, and fills it with
// the values of `tensor`.
template <typename T>
void AddTensor(
    const std::vector<T>& tensor, const Tensor::ElementType& element_type,
    const Tensor::QuantizationParameters& quantization_parameters = {}) {
  const int num_values = static_cast<int>(tensor.size());
  tensors_->emplace_back(element_type, Tensor::Shape{1, num_values},
                         quantization_parameters);
  auto write_view = tensors_->back().GetCpuWriteView();
  std::copy(tensor.begin(), tensor.end(), write_view.buffer<T>());
}
// Sends the pending tensors on the TENSORS input at `timestamp` and resets
// the pending list. If `aggregation_timestamps` is set, also sends it
// (converted to Timestamps) on the TIMESTAMPS input at the same `timestamp`.
absl::Status Run(
    std::optional<std::vector<int>> aggregation_timestamps = std::nullopt,
    int timestamp = 0) {
  const Timestamp packet_time(timestamp);
  MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
      kTensorsName, Adopt(tensors_.release()).At(packet_time)));
  // Ownership was handed to the packet; start a fresh list for later calls.
  tensors_ = absl::make_unique<std::vector<Tensor>>();

  if (!aggregation_timestamps.has_value()) {
    return absl::OkStatus();
  }
  auto to_aggregate = absl::make_unique<std::vector<Timestamp>>();
  for (const int value : *aggregation_timestamps) {
    to_aggregate->emplace_back(Timestamp(value));
  }
  MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
      kTimestampsName, Adopt(to_aggregate.release()).At(packet_time)));
  return absl::OkStatus();
}
// Waits for the graph to go idle, closes its inputs, and returns a copy of the
// next packet payload from `poller`, then waits for the run to finish.
template <typename T>
absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
  MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
  MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
  Packet output_packet;
  const bool got_packet = poller.Next(&output_packet);
  if (!got_packet) {
    return absl::InternalError("Unable to get output packet");
  }
  auto payload = output_packet.Get<T>();
  MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
  return payload;
}
private:
CalculatorGraph calculator_graph_;
std::unique_ptr<std::vector<Tensor>> tensors_ =
absl::make_unique<std::vector<Tensor>>();
};
// Without TIMESTAMPS connected, a single float tensor must come back as one
// EmbeddingResult with no timestamp and the same values.
TEST_F(PostprocessingTest, SucceedsWithoutTimestamps) {
  // Graph setup.
  proto::EmbedderOptions options;
  MP_ASSERT_OK_AND_ASSIGN(auto poller,
                          BuildGraph(kMobileNetV3Embedder, options));

  // Input: a one-hot-like vector with 1.0 at index 0.
  std::vector<float> input_values(kMobileNetV3EmbedderEmbeddingSize, 0);
  input_values[0] = 1.0;

  // Feed the tensor and fetch the result.
  AddTensor(input_values, Tensor::ElementType::kFloat32);
  MP_ASSERT_OK(Run());
  MP_ASSERT_OK_AND_ASSIGN(auto results, GetResult<EmbeddingResult>(poller));

  // Checks.
  EXPECT_FALSE(results.has_timestamp_ms());
  EXPECT_EQ(results.embeddings_size(), 1);
  const auto& embedding = results.embeddings(0);
  EXPECT_EQ(embedding.head_index(), 0);
  EXPECT_EQ(embedding.head_name(), "feature");
  EXPECT_EQ(embedding.float_embedding().values_size(),
            kMobileNetV3EmbedderEmbeddingSize);
  EXPECT_FLOAT_EQ(embedding.float_embedding().values(0), 1.0);
  for (int index = 1; index < kMobileNetV3EmbedderEmbeddingSize; ++index) {
    EXPECT_FLOAT_EQ(embedding.float_embedding().values(index), 0.0);
  }
}
// With TIMESTAMPS connected, two tensors sent at timestamps 0 and 1000 must
// come back together as two EmbeddingResults with timestamp_ms 0 and 1.
TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
  // Graph setup.
  proto::EmbedderOptions options;
  MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(kMobileNetV3Embedder, options,
                                                  /*connect_timestamps=*/true));

  // Inputs: all zeros except index 0, which holds 1.0 and 2.0 respectively.
  std::vector<float> first_values(kMobileNetV3EmbedderEmbeddingSize, 0);
  first_values[0] = 1.0;
  std::vector<float> second_values(kMobileNetV3EmbedderEmbeddingSize, 0);
  second_values[0] = 2.0;

  // The first tensor is only cached; the second one requests the aggregation.
  AddTensor(first_values, Tensor::ElementType::kFloat32);
  MP_ASSERT_OK(Run());
  AddTensor(second_values, Tensor::ElementType::kFloat32);
  MP_ASSERT_OK(Run(
      /*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000}),
      /*timestamp=*/1000));
  MP_ASSERT_OK_AND_ASSIGN(auto results,
                          GetResult<std::vector<EmbeddingResult>>(poller));

  // Checks, applied to both results in order.
  struct Expectation {
    int timestamp_ms;
    float first_value;
  };
  const Expectation expectations[] = {{0, 1.0f}, {1, 2.0f}};
  EXPECT_EQ(results.size(), 2);
  for (int r = 0; r < 2; ++r) {
    const auto& result = results[r];
    EXPECT_EQ(result.timestamp_ms(), expectations[r].timestamp_ms);
    const auto& embedding = result.embeddings(0);
    EXPECT_EQ(embedding.head_index(), 0);
    EXPECT_EQ(embedding.head_name(), "feature");
    EXPECT_EQ(embedding.float_embedding().values_size(),
              kMobileNetV3EmbedderEmbeddingSize);
    EXPECT_FLOAT_EQ(embedding.float_embedding().values(0),
                    expectations[r].first_value);
    for (int index = 1; index < kMobileNetV3EmbedderEmbeddingSize; ++index) {
      EXPECT_FLOAT_EQ(embedding.float_embedding().values(index), 0.0);
    }
  }
}
} // namespace } // namespace
} // namespace processors } // namespace processors

View File

@ -32,7 +32,4 @@ message EmbeddingPostprocessingGraphOptions {
// Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32). // Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32).
optional bool has_quantized_outputs = 2; optional bool has_quantized_outputs = 2;
// TODO: add options to control whether timestamp aggregation
// should be used or not.
} }