diff --git a/mediapipe/tasks/cc/components/calculators/BUILD b/mediapipe/tasks/cc/components/calculators/BUILD index 061875272..1f726a018 100644 --- a/mediapipe/tasks/cc/components/calculators/BUILD +++ b/mediapipe/tasks/cc/components/calculators/BUILD @@ -200,3 +200,38 @@ cc_test( "@com_google_absl//absl/status", ], ) + +cc_library( + name = "embedding_aggregation_calculator", + srcs = ["embedding_aggregation_calculator.cc"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:packet", + "//mediapipe/framework/api2:port", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "@com_google_absl//absl/status", + ], + alwayslink = 1, +) + +cc_test( + name = "embedding_aggregation_calculator_test", + srcs = ["embedding_aggregation_calculator_test.cc"], + deps = [ + ":embedding_aggregation_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:output_stream_poller", + "//mediapipe/framework:packet", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + ], +) diff --git a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc new file mode 100644 index 000000000..bae926b76 --- /dev/null +++ b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc @@ -0,0 +1,132 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/status/status.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/packet.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" + +namespace mediapipe { +namespace api2 { + +using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; + +// Aggregates EmbeddingResult packets into a vector of timestamped +// EmbeddingResult. Acts as a pass-through if no timestamp aggregation is +// needed. +// +// Inputs: +// EMBEDDINGS: EmbeddingResult +// The EmbeddingResult packets to aggregate. +// TIMESTAMPS: std::vector @Optional. +// The collection of timestamps that this calculator should aggregate. This +// stream is optional: if provided then the TIMESTAMPED_EMBEDDINGS output +// will contain the aggregated results. Otherwise as no timestamp +// aggregation is required the EMBEDDINGS output is used to pass the inputs +// EmbeddingResults unchanged. +// +// Outputs: +// EMBEDDINGS: EmbeddingResult @Optional +// The input EmbeddingResult, unchanged. Must be connected if the TIMESTAMPS +// input is not connected, as it signals that timestamp aggregation is not +// required. +// TIMESTAMPED_EMBEDDINGS: std::vector @Optional +// The embedding results aggregated by timestamp. Must be connected if the +// TIMESTAMPS input is connected as it signals that timestamp aggregation is +// required. +// +// Example without timestamp aggregation (pass-through): +// node { +// calculator: "EmbeddingAggregationCalculator" +// input_stream: "EMBEDDINGS:embeddings_in" +// output_stream: "EMBEDDINGS:embeddings_out" +// } +// +// Example with timestamp aggregation: +// node { +// calculator: "EmbeddingAggregationCalculator" +// input_stream: "EMBEDDINGS:embeddings_in" +// input_stream: "TIMESTAMPS:timestamps_in" +// output_stream: "TIMESTAMPED_EMBEDDINGS:timestamped_embeddings_out" +// } +class EmbeddingAggregationCalculator : public Node { + public: + static constexpr Input kEmbeddingsIn{"EMBEDDINGS"}; + static constexpr Input>::Optional kTimestampsIn{ + "TIMESTAMPS"}; + static constexpr Output::Optional kEmbeddingsOut{ + "EMBEDDINGS"}; + static constexpr Output>::Optional + kTimestampedEmbeddingsOut{"TIMESTAMPED_EMBEDDINGS"}; + MEDIAPIPE_NODE_CONTRACT(kEmbeddingsIn, kTimestampsIn, kEmbeddingsOut, + kTimestampedEmbeddingsOut); + + static absl::Status UpdateContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc); + absl::Status Process(CalculatorContext* cc); + + private: + bool time_aggregation_enabled_; + std::unordered_map cached_embeddings_; +}; + +absl::Status EmbeddingAggregationCalculator::UpdateContract( + CalculatorContract* cc) { + if (kTimestampsIn(cc).IsConnected()) { + RET_CHECK(kTimestampedEmbeddingsOut(cc).IsConnected()); + } else { + RET_CHECK(kEmbeddingsOut(cc).IsConnected()); + } + return absl::OkStatus(); +} + +absl::Status EmbeddingAggregationCalculator::Open(CalculatorContext* cc) { + time_aggregation_enabled_ = kTimestampsIn(cc).IsConnected(); + return absl::OkStatus(); +} + +absl::Status EmbeddingAggregationCalculator::Process(CalculatorContext* cc) { + if (time_aggregation_enabled_) { + cached_embeddings_[cc->InputTimestamp().Value()] = + std::move(*kEmbeddingsIn(cc)); + if (kTimestampsIn(cc).IsEmpty()) { + return absl::OkStatus(); + } + auto timestamps = kTimestampsIn(cc).Get(); + std::vector results; + results.reserve(timestamps.size()); + for (const auto& timestamp : timestamps) { + auto& result = cached_embeddings_[timestamp.Value()]; + result.set_timestamp_ms((timestamp.Value() - timestamps[0].Value()) / + 1000); + results.push_back(std::move(result)); + cached_embeddings_.erase(timestamp.Value()); + } + kTimestampedEmbeddingsOut(cc).Send(std::move(results)); + } else { + kEmbeddingsOut(cc).Send(kEmbeddingsIn(cc)); + } + RET_CHECK(cached_embeddings_.empty()); + return absl::OkStatus(); +} + +MEDIAPIPE_REGISTER_NODE(EmbeddingAggregationCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc new file mode 100644 index 000000000..ebb4d8880 --- /dev/null +++ b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc @@ -0,0 +1,158 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/output_stream_poller.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe { +namespace { + +using ::mediapipe::ParseTextProtoOrDie; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; +using ::testing::Pointwise; + +constexpr char kEmbeddingsTag[] = "EMBEDDINGS"; +constexpr char kEmbeddingsInName[] = "embeddings_in"; +constexpr char kEmbeddingsOutName[] = "embeddings_out"; +constexpr char kTimestampsTag[] = "TIMESTAMPS"; +constexpr char kTimestampsName[] = "timestamps_in"; +constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS"; +constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings_out"; + +class EmbeddingAggregationCalculatorTest : public tflite_shims::testing::Test { + protected: + absl::StatusOr BuildGraph(bool connect_timestamps) { + Graph graph; + auto& calculator = graph.AddNode("EmbeddingAggregationCalculator"); + graph[Input(kEmbeddingsTag)].SetName(kEmbeddingsInName) >> + calculator.In(kEmbeddingsTag); + if (connect_timestamps) { + graph[Input>(kTimestampsTag)].SetName( + kTimestampsName) >> + calculator.In(kTimestampsTag); + calculator.Out(kTimestampedEmbeddingsTag) + .SetName(kTimestampedEmbeddingsName) >> + graph[Output>( + kTimestampedEmbeddingsTag)]; + } else { + calculator.Out(kEmbeddingsTag).SetName(kEmbeddingsOutName) >> + graph[Output(kEmbeddingsTag)]; + } + + MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig())); + if (connect_timestamps) { + ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller( + kTimestampedEmbeddingsName)); + MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{})); + return poller; + } + ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller( + kEmbeddingsOutName)); + MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{})); + return poller; + } + + absl::Status Send( + const EmbeddingResult& embeddings, int timestamp = 0, + std::optional> aggregation_timestamps = std::nullopt) { + MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( + kEmbeddingsInName, MakePacket(std::move(embeddings)) + .At(Timestamp(timestamp)))); + if (aggregation_timestamps.has_value()) { + auto packet = std::make_unique>(); + for (const auto& timestamp : *aggregation_timestamps) { + packet->emplace_back(Timestamp(timestamp)); + } + MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( + kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp)))); + } + return absl::OkStatus(); + } + + template + absl::StatusOr GetResult(OutputStreamPoller& poller) { + MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle()); + MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams()); + + Packet packet; + if (!poller.Next(&packet)) { + return absl::InternalError("Unable to get output packet"); + } + auto result = packet.Get(); + MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone()); + return result; + } + + private: + CalculatorGraph calculator_graph_; +}; + +TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithoutTimestamps) { + EmbeddingResult embedding = ParseTextProtoOrDie( + R"pb(embeddings { head_index: 0 })pb"); + + MP_ASSERT_OK_AND_ASSIGN(auto poller, + BuildGraph(/*connect_timestamps=*/false)); + MP_ASSERT_OK(Send(embedding)); + MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult(poller)); + + EXPECT_THAT(result, EqualsProto(embedding)); +} + +TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithTimestamps) { + MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true)); + MP_ASSERT_OK(Send(ParseTextProtoOrDie(R"pb(embeddings { + head_index: 0 + })pb"))); + MP_ASSERT_OK(Send( + ParseTextProtoOrDie( + R"pb(embeddings { head_index: 1 })pb"), + /*timestamp=*/1000, + /*aggregation_timestamps=*/std::optional>({0, 1000}))); + MP_ASSERT_OK_AND_ASSIGN(auto results, + GetResult>(poller)); + + EXPECT_THAT(results, + Pointwise(EqualsProto(), {ParseTextProtoOrDie( + R"pb(embeddings { head_index: 0 } + timestamp_ms: 0)pb"), + ParseTextProtoOrDie( + R"pb(embeddings { head_index: 1 } + timestamp_ms: 1)pb")})); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index 12af55ed9..7845a3dae 100644 --- a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -82,6 +82,7 @@ cc_library( "//mediapipe/framework/formats:tensor", "//mediapipe/framework/tool:options_map", "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/calculators:embedding_aggregation_calculator", "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc index 3a3884689..880aec5d7 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc @@ -56,6 +56,14 @@ using TensorsSource = constexpr char kTensorsTag[] = "TENSORS"; constexpr char kEmbeddingsTag[] = "EMBEDDINGS"; +constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS"; +constexpr char kTimestampsTag[] = "TIMESTAMPS"; + +// Struct holding the different output streams produced by the graph. +struct EmbeddingPostprocessingOutputStreams { + Source embeddings; + Source> timestamped_embeddings; +}; // Identifies whether or not the model has quantized outputs, and performs // sanity checks. @@ -168,27 +176,39 @@ absl::Status ConfigureEmbeddingPostprocessing( // TENSORS - std::vector // The output tensors of an InferenceCalculator, to convert into // EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8. +// TIMESTAMPS - std::vector @Optional +// The collection of the timestamps that this calculator should aggregate. +// This stream is optional: if provided then the TIMESTAMPED_EMBEDDINGS +// output is used for results. Otherwise as no timestamp aggregation is +// required the EMBEDDINGS output is used for results. +// // Outputs: -// EMBEDDING_RESULT - EmbeddingResult -// The output EmbeddingResult. +// EMBEDDINGS - EmbeddingResult @Optional +// The embedding results aggregated by head. Must be connected if the +// TIMESTAMPS input is not connected, as it signals that timestamp +// aggregation is not required. +// TIMESTAMPED_EMBEDDINGS - std::vector @Optional +// The embedding result aggregated by timestamp, then by head. Must be +// connected if the TIMESTAMPS input is connected, as it signals that +// timestamp aggregation is required. // // The recommended way of using this graph is through the GraphBuilder API using // the 'ConfigureEmbeddingPostprocessing()' function. See header file for more // details. -// -// TODO: add support for additional optional "TIMESTAMPS" input for -// embeddings aggregation. class EmbeddingPostprocessingGraph : public mediapipe::Subgraph { public: absl::StatusOr GetConfig( mediapipe::SubgraphContext* sc) override { Graph graph; ASSIGN_OR_RETURN( - auto embedding_result_out, + auto output_streams, BuildEmbeddingPostprocessing( sc->Options(), - graph[Input>(kTensorsTag)], graph)); - embedding_result_out >> graph[Output(kEmbeddingsTag)]; + graph[Input>(kTensorsTag)], + graph[Input>(kTimestampsTag)], graph)); + output_streams.embeddings >> graph[Output(kEmbeddingsTag)]; + output_streams.timestamped_embeddings >> + graph[Output>(kTimestampedEmbeddingsTag)]; return graph.GetConfig(); } @@ -200,10 +220,14 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph { // // options: the on-device EmbeddingPostprocessingGraphOptions // tensors_in: (std::vector) tensors to postprocess. + // timestamps_in: (std::vector) optional collection of + // timestamps that should be used to aggregate embedding results. // graph: the mediapipe builder::Graph instance to be updated. - absl::StatusOr> BuildEmbeddingPostprocessing( + absl::StatusOr + BuildEmbeddingPostprocessing( const proto::EmbeddingPostprocessingGraphOptions options, - Source> tensors_in, Graph& graph) { + Source> tensors_in, + Source> timestamps_in, Graph& graph) { // If output tensors are quantized, they must be dequantized first. TensorsSource dequantized_tensors(&tensors_in); if (options.has_quantized_outputs()) { @@ -220,7 +244,20 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph { .GetOptions() .CopyFrom(options.tensors_to_embeddings_options()); dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag); - return tensors_to_embeddings_node[Output(kEmbeddingsTag)]; + + // Adds EmbeddingAggregationCalculator. + GenericNode& aggregation_node = + graph.AddNode("EmbeddingAggregationCalculator"); + tensors_to_embeddings_node[Output(kEmbeddingsTag)] >> + aggregation_node.In(kEmbeddingsTag); + timestamps_in >> aggregation_node.In(kTimestampsTag); + + // Connects outputs. + return EmbeddingPostprocessingOutputStreams{ + /*embeddings=*/aggregation_node[Output( + kEmbeddingsTag)], + /*timestamped_embeddings=*/aggregation_node + [Output>(kTimestampedEmbeddingsTag)]}; } }; REGISTER_MEDIAPIPE_GRAPH( diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h index 5e8f2c084..58606ed80 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h @@ -44,12 +44,20 @@ namespace processors { // TENSORS - std::vector // The output tensors of an InferenceCalculator, to convert into // EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8. +// TIMESTAMPS - std::vector @Optional +// The collection of the timestamps that this calculator should aggregate. +// This stream is optional: if provided then the TIMESTAMPED_EMBEDDINGS +// output is used for results. Otherwise as no timestamp aggregation is +// required the EMBEDDINGS output is used for results. // Outputs: -// EMBEDDINGS - EmbeddingResult -// The output EmbeddingResult. -// -// TODO: add support for additional optional "TIMESTAMPS" input for -// embeddings aggregation. +// EMBEDDINGS - EmbeddingResult @Optional +// The embedding results aggregated by head. Must be connected if the +// TIMESTAMPS input is not connected, as it signals that timestamp +// aggregation is not required. +// TIMESTAMPED_EMBEDDINGS - std::vector @Optional +// The embedding result aggregated by timestamp, then by head. Must be +// connected if the TIMESTAMPS input is connected, as it signals that +// timestamp aggregation is required. absl::Status ConfigureEmbeddingPostprocessing( const tasks::core::ModelResources& model_resources, const proto::EmbedderOptions& embedder_options, diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc index 62fab8f7e..84d84d648 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc @@ -20,11 +20,20 @@ limitations under the License. #include "absl/flags/flag.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/graph_runner.h" +#include "mediapipe/framework/output_stream_poller.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" @@ -37,7 +46,12 @@ namespace components { namespace processors { namespace { +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::core::ModelResources; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/"; @@ -51,6 +65,16 @@ constexpr char kQuantizedImageClassifierWithoutMetadata[] = "vision/mobilenet_v1_0.25_192_quantized_1_default_1.tflite"; constexpr char kTestModelResourcesTag[] = "test_model_resources"; +constexpr int kMobileNetV3EmbedderEmbeddingSize = 1024; + +constexpr char kTensorsTag[] = "TENSORS"; +constexpr char kTensorsName[] = "tensors"; +constexpr char kTimestampsTag[] = "TIMESTAMPS"; +constexpr char kTimestampsName[] = "timestamps"; +constexpr char kEmbeddingsTag[] = "EMBEDDINGS"; +constexpr char kEmbeddingsName[] = "embeddings"; +constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS"; +constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings"; // Helper function to get ModelResources. absl::StatusOr> CreateModelResourcesForModel( @@ -128,8 +152,171 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) { has_quantized_outputs: false)pb"))); } -// TODO: add E2E Postprocessing tests once timestamp aggregation is -// supported. +class PostprocessingTest : public tflite_shims::testing::Test { + protected: + absl::StatusOr BuildGraph( + absl::string_view model_name, const proto::EmbedderOptions& options, + bool connect_timestamps = false) { + ASSIGN_OR_RETURN(auto model_resources, + CreateModelResourcesForModel(model_name)); + + Graph graph; + auto& postprocessing = graph.AddNode( + "mediapipe.tasks.components.processors." + "EmbeddingPostprocessingGraph"); + MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessing( + *model_resources, options, + &postprocessing + .GetOptions())); + graph[Input>(kTensorsTag)].SetName(kTensorsName) >> + postprocessing.In(kTensorsTag); + if (connect_timestamps) { + graph[Input>(kTimestampsTag)].SetName( + kTimestampsName) >> + postprocessing.In(kTimestampsTag); + postprocessing.Out(kTimestampedEmbeddingsTag) + .SetName(kTimestampedEmbeddingsName) >> + graph[Output>( + kTimestampedEmbeddingsTag)]; + } else { + postprocessing.Out(kEmbeddingsTag).SetName(kEmbeddingsName) >> + graph[Output(kEmbeddingsTag)]; + } + + MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig())); + if (connect_timestamps) { + ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller( + kTimestampedEmbeddingsName)); + MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{})); + return poller; + } + ASSIGN_OR_RETURN(auto poller, + calculator_graph_.AddOutputStreamPoller(kEmbeddingsName)); + MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{})); + return poller; + } + + template + void AddTensor( + const std::vector& tensor, const Tensor::ElementType& element_type, + const Tensor::QuantizationParameters& quantization_parameters = {}) { + tensors_->emplace_back(element_type, + Tensor::Shape{1, static_cast(tensor.size())}, + quantization_parameters); + auto view = tensors_->back().GetCpuWriteView(); + T* buffer = view.buffer(); + std::copy(tensor.begin(), tensor.end(), buffer); + } + + absl::Status Run( + std::optional> aggregation_timestamps = std::nullopt, + int timestamp = 0) { + MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( + kTensorsName, Adopt(tensors_.release()).At(Timestamp(timestamp)))); + // Reset tensors for future calls. + tensors_ = absl::make_unique>(); + if (aggregation_timestamps.has_value()) { + auto packet = absl::make_unique>(); + for (const auto& timestamp : *aggregation_timestamps) { + packet->emplace_back(Timestamp(timestamp)); + } + MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( + kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp)))); + } + return absl::OkStatus(); + } + + template + absl::StatusOr GetResult(OutputStreamPoller& poller) { + MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle()); + MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams()); + + Packet packet; + if (!poller.Next(&packet)) { + return absl::InternalError("Unable to get output packet"); + } + auto result = packet.Get(); + MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone()); + return result; + } + + private: + CalculatorGraph calculator_graph_; + std::unique_ptr> tensors_ = + absl::make_unique>(); +}; + +TEST_F(PostprocessingTest, SucceedsWithoutTimestamps) { + // Build graph. + proto::EmbedderOptions options; + MP_ASSERT_OK_AND_ASSIGN(auto poller, + BuildGraph(kMobileNetV3Embedder, options)); + // Build input tensor. + std::vector tensor(kMobileNetV3EmbedderEmbeddingSize, 0); + tensor[0] = 1.0; + + // Send tensor and get results. + AddTensor(tensor, Tensor::ElementType::kFloat32); + MP_ASSERT_OK(Run()); + MP_ASSERT_OK_AND_ASSIGN(auto results, GetResult(poller)); + + // Validate results. + EXPECT_FALSE(results.has_timestamp_ms()); + EXPECT_EQ(results.embeddings_size(), 1); + EXPECT_EQ(results.embeddings(0).head_index(), 0); + EXPECT_EQ(results.embeddings(0).head_name(), "feature"); + EXPECT_EQ(results.embeddings(0).float_embedding().values_size(), + kMobileNetV3EmbedderEmbeddingSize); + EXPECT_FLOAT_EQ(results.embeddings(0).float_embedding().values(0), 1.0); + for (int i = 1; i < kMobileNetV3EmbedderEmbeddingSize; ++i) { + EXPECT_FLOAT_EQ(results.embeddings(0).float_embedding().values(i), 0.0); + } +} + +TEST_F(PostprocessingTest, SucceedsWithTimestamps) { + // Build graph. + proto::EmbedderOptions options; + MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(kMobileNetV3Embedder, options, + /*connect_timestamps=*/true)); + // Build input tensors. + std::vector tensor_0(kMobileNetV3EmbedderEmbeddingSize, 0); + tensor_0[0] = 1.0; + std::vector tensor_1(kMobileNetV3EmbedderEmbeddingSize, 0); + tensor_1[0] = 2.0; + + // Send tensors and get results. + AddTensor(tensor_0, Tensor::ElementType::kFloat32); + MP_ASSERT_OK(Run()); + AddTensor(tensor_1, Tensor::ElementType::kFloat32); + MP_ASSERT_OK(Run( + /*aggregation_timestamps=*/std::optional>({0, 1000}), + /*timestamp=*/1000)); + MP_ASSERT_OK_AND_ASSIGN(auto results, + GetResult>(poller)); + + // Validate results. + EXPECT_EQ(results.size(), 2); + // First timestamp. + EXPECT_EQ(results[0].timestamp_ms(), 0); + EXPECT_EQ(results[0].embeddings(0).head_index(), 0); + EXPECT_EQ(results[0].embeddings(0).head_name(), "feature"); + EXPECT_EQ(results[0].embeddings(0).float_embedding().values_size(), + kMobileNetV3EmbedderEmbeddingSize); + EXPECT_FLOAT_EQ(results[0].embeddings(0).float_embedding().values(0), 1.0); + for (int i = 1; i < kMobileNetV3EmbedderEmbeddingSize; ++i) { + EXPECT_FLOAT_EQ(results[0].embeddings(0).float_embedding().values(i), 0.0); + } + // Second timestamp. + EXPECT_EQ(results[1].timestamp_ms(), 1); + EXPECT_EQ(results[1].embeddings(0).head_index(), 0); + EXPECT_EQ(results[1].embeddings(0).head_name(), "feature"); + EXPECT_EQ(results[1].embeddings(0).float_embedding().values_size(), + kMobileNetV3EmbedderEmbeddingSize); + EXPECT_FLOAT_EQ(results[1].embeddings(0).float_embedding().values(0), 2.0); + for (int i = 1; i < kMobileNetV3EmbedderEmbeddingSize; ++i) { + EXPECT_FLOAT_EQ(results[1].embeddings(0).float_embedding().values(i), 0.0); + } +} } // namespace } // namespace processors diff --git a/mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.proto b/mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.proto index f8dbf59f0..3a50818f6 100644 --- a/mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.proto @@ -32,7 +32,4 @@ message EmbeddingPostprocessingGraphOptions { // Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32). optional bool has_quantized_outputs = 2; - - // TODO: add options to control whether timestamp aggregation - // should be used or not. }