Add timestamp aggregation to EmbeddingPostprocessingGraph.
PiperOrigin-RevId: 487463848
This commit is contained in:
		
							parent
							
								
									f11c757629
								
							
						
					
					
						commit
						0ac604d507
					
				|  | @ -200,3 +200,38 @@ cc_test( | |||
|         "@com_google_absl//absl/status", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| cc_library( | ||||
|     name = "embedding_aggregation_calculator", | ||||
|     srcs = ["embedding_aggregation_calculator.cc"], | ||||
|     deps = [ | ||||
|         "//mediapipe/framework:calculator_framework", | ||||
|         "//mediapipe/framework/api2:node", | ||||
|         "//mediapipe/framework/api2:packet", | ||||
|         "//mediapipe/framework/api2:port", | ||||
|         "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", | ||||
|         "@com_google_absl//absl/status", | ||||
|     ], | ||||
|     alwayslink = 1, | ||||
| ) | ||||
| 
 | ||||
| cc_test( | ||||
|     name = "embedding_aggregation_calculator_test", | ||||
|     srcs = ["embedding_aggregation_calculator_test.cc"], | ||||
|     deps = [ | ||||
|         ":embedding_aggregation_calculator", | ||||
|         "//mediapipe/framework:calculator_framework", | ||||
|         "//mediapipe/framework:output_stream_poller", | ||||
|         "//mediapipe/framework:packet", | ||||
|         "//mediapipe/framework:timestamp", | ||||
|         "//mediapipe/framework/api2:builder", | ||||
|         "//mediapipe/framework/api2:port", | ||||
|         "//mediapipe/framework/port:gtest_main", | ||||
|         "//mediapipe/framework/port:parse_text_proto", | ||||
|         "//mediapipe/framework/port:status", | ||||
|         "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", | ||||
|         "@com_google_absl//absl/status", | ||||
|         "@com_google_absl//absl/status:statusor", | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", | ||||
|     ], | ||||
| ) | ||||
|  |  | |||
|  | @ -0,0 +1,132 @@ | |||
| // Copyright 2022 The MediaPipe Authors.
 | ||||
| //
 | ||||
| // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||
| // you may not use this file except in compliance with the License.
 | ||||
| // You may obtain a copy of the License at
 | ||||
| //
 | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| //
 | ||||
| // Unless required by applicable law or agreed to in writing, software
 | ||||
| // distributed under the License is distributed on an "AS IS" BASIS,
 | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||||
| // See the License for the specific language governing permissions and
 | ||||
| // limitations under the License.
 | ||||
| 
 | ||||
| #include <unordered_map> | ||||
| #include <vector> | ||||
| 
 | ||||
| #include "absl/status/status.h" | ||||
| #include "mediapipe/framework/api2/node.h" | ||||
| #include "mediapipe/framework/api2/packet.h" | ||||
| #include "mediapipe/framework/api2/port.h" | ||||
| #include "mediapipe/framework/calculator_framework.h" | ||||
| #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace api2 { | ||||
| 
 | ||||
| using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; | ||||
| 
 | ||||
| // Aggregates EmbeddingResult packets into a vector of timestamped
 | ||||
| // EmbeddingResult. Acts as a pass-through if no timestamp aggregation is
 | ||||
| // needed.
 | ||||
| //
 | ||||
| // Inputs:
 | ||||
| //   EMBEDDINGS: EmbeddingResult
 | ||||
| //     The EmbeddingResult packets to aggregate.
 | ||||
| //   TIMESTAMPS: std::vector<Timestamp> @Optional.
 | ||||
| //     The collection of timestamps that this calculator should aggregate. This
 | ||||
| //     stream is optional: if provided then the TIMESTAMPED_EMBEDDINGS output
 | ||||
| //     will contain the aggregated results. Otherwise as no timestamp
 | ||||
| //     aggregation is required the EMBEDDINGS output is used to pass the inputs
 | ||||
| //     EmbeddingResults unchanged.
 | ||||
| //
 | ||||
| // Outputs:
 | ||||
| //   EMBEDDINGS: EmbeddingResult @Optional
 | ||||
| //     The input EmbeddingResult, unchanged. Must be connected if the TIMESTAMPS
 | ||||
| //     input is not connected, as it signals that timestamp aggregation is not
 | ||||
| //     required.
 | ||||
| //  TIMESTAMPED_EMBEDDINGS: std::vector<EmbeddingResult> @Optional
 | ||||
| //     The embedding results aggregated by timestamp. Must be connected if the
 | ||||
| //     TIMESTAMPS input is connected as it signals that timestamp aggregation is
 | ||||
| //     required.
 | ||||
| //
 | ||||
| // Example without timestamp aggregation (pass-through):
 | ||||
| // node {
 | ||||
| //   calculator: "EmbeddingAggregationCalculator"
 | ||||
| //   input_stream: "EMBEDDINGS:embeddings_in"
 | ||||
| //   output_stream: "EMBEDDINGS:embeddings_out"
 | ||||
| // }
 | ||||
| //
 | ||||
| // Example with timestamp aggregation:
 | ||||
| // node {
 | ||||
| //   calculator: "EmbeddingAggregationCalculator"
 | ||||
| //   input_stream: "EMBEDDINGS:embeddings_in"
 | ||||
| //   input_stream: "TIMESTAMPS:timestamps_in"
 | ||||
| //   output_stream: "TIMESTAMPED_EMBEDDINGS:timestamped_embeddings_out"
 | ||||
| // }
 | ||||
| class EmbeddingAggregationCalculator : public Node { | ||||
|  public: | ||||
|   static constexpr Input<EmbeddingResult> kEmbeddingsIn{"EMBEDDINGS"}; | ||||
|   static constexpr Input<std::vector<Timestamp>>::Optional kTimestampsIn{ | ||||
|       "TIMESTAMPS"}; | ||||
|   static constexpr Output<EmbeddingResult>::Optional kEmbeddingsOut{ | ||||
|       "EMBEDDINGS"}; | ||||
|   static constexpr Output<std::vector<EmbeddingResult>>::Optional | ||||
|       kTimestampedEmbeddingsOut{"TIMESTAMPED_EMBEDDINGS"}; | ||||
|   MEDIAPIPE_NODE_CONTRACT(kEmbeddingsIn, kTimestampsIn, kEmbeddingsOut, | ||||
|                           kTimestampedEmbeddingsOut); | ||||
| 
 | ||||
|   static absl::Status UpdateContract(CalculatorContract* cc); | ||||
|   absl::Status Open(CalculatorContext* cc); | ||||
|   absl::Status Process(CalculatorContext* cc); | ||||
| 
 | ||||
|  private: | ||||
|   bool time_aggregation_enabled_; | ||||
|   std::unordered_map<int64, EmbeddingResult> cached_embeddings_; | ||||
| }; | ||||
| 
 | ||||
| absl::Status EmbeddingAggregationCalculator::UpdateContract( | ||||
|     CalculatorContract* cc) { | ||||
|   if (kTimestampsIn(cc).IsConnected()) { | ||||
|     RET_CHECK(kTimestampedEmbeddingsOut(cc).IsConnected()); | ||||
|   } else { | ||||
|     RET_CHECK(kEmbeddingsOut(cc).IsConnected()); | ||||
|   } | ||||
|   return absl::OkStatus(); | ||||
| } | ||||
| 
 | ||||
| absl::Status EmbeddingAggregationCalculator::Open(CalculatorContext* cc) { | ||||
|   time_aggregation_enabled_ = kTimestampsIn(cc).IsConnected(); | ||||
|   return absl::OkStatus(); | ||||
| } | ||||
| 
 | ||||
| absl::Status EmbeddingAggregationCalculator::Process(CalculatorContext* cc) { | ||||
|   if (time_aggregation_enabled_) { | ||||
|     cached_embeddings_[cc->InputTimestamp().Value()] = | ||||
|         std::move(*kEmbeddingsIn(cc)); | ||||
|     if (kTimestampsIn(cc).IsEmpty()) { | ||||
|       return absl::OkStatus(); | ||||
|     } | ||||
|     auto timestamps = kTimestampsIn(cc).Get(); | ||||
|     std::vector<EmbeddingResult> results; | ||||
|     results.reserve(timestamps.size()); | ||||
|     for (const auto& timestamp : timestamps) { | ||||
|       auto& result = cached_embeddings_[timestamp.Value()]; | ||||
|       result.set_timestamp_ms((timestamp.Value() - timestamps[0].Value()) / | ||||
|                               1000); | ||||
|       results.push_back(std::move(result)); | ||||
|       cached_embeddings_.erase(timestamp.Value()); | ||||
|     } | ||||
|     kTimestampedEmbeddingsOut(cc).Send(std::move(results)); | ||||
|   } else { | ||||
|     kEmbeddingsOut(cc).Send(kEmbeddingsIn(cc)); | ||||
|   } | ||||
|   RET_CHECK(cached_embeddings_.empty()); | ||||
|   return absl::OkStatus(); | ||||
| } | ||||
| 
 | ||||
| MEDIAPIPE_REGISTER_NODE(EmbeddingAggregationCalculator); | ||||
| 
 | ||||
| }  // namespace api2
 | ||||
| }  // namespace mediapipe
 | ||||
|  | @ -0,0 +1,158 @@ | |||
| /* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include <memory> | ||||
| #include <optional> | ||||
| #include <vector> | ||||
| 
 | ||||
| #include "absl/status/status.h" | ||||
| #include "absl/status/statusor.h" | ||||
| #include "mediapipe/framework/api2/builder.h" | ||||
| #include "mediapipe/framework/api2/port.h" | ||||
| #include "mediapipe/framework/calculator_framework.h" | ||||
| #include "mediapipe/framework/output_stream_poller.h" | ||||
| #include "mediapipe/framework/packet.h" | ||||
| #include "mediapipe/framework/port/gmock.h" | ||||
| #include "mediapipe/framework/port/gtest.h" | ||||
| #include "mediapipe/framework/port/parse_text_proto.h" | ||||
| #include "mediapipe/framework/port/status_macros.h" | ||||
| #include "mediapipe/framework/port/status_matchers.h" | ||||
| #include "mediapipe/framework/timestamp.h" | ||||
| #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace { | ||||
| 
 | ||||
| using ::mediapipe::ParseTextProtoOrDie; | ||||
| using ::mediapipe::api2::Input; | ||||
| using ::mediapipe::api2::Output; | ||||
| using ::mediapipe::api2::builder::Graph; | ||||
| using ::mediapipe::api2::builder::Source; | ||||
| using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; | ||||
| using ::testing::Pointwise; | ||||
| 
 | ||||
| constexpr char kEmbeddingsTag[] = "EMBEDDINGS"; | ||||
| constexpr char kEmbeddingsInName[] = "embeddings_in"; | ||||
| constexpr char kEmbeddingsOutName[] = "embeddings_out"; | ||||
| constexpr char kTimestampsTag[] = "TIMESTAMPS"; | ||||
| constexpr char kTimestampsName[] = "timestamps_in"; | ||||
| constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS"; | ||||
| constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings_out"; | ||||
| 
 | ||||
| class EmbeddingAggregationCalculatorTest : public tflite_shims::testing::Test { | ||||
|  protected: | ||||
|   absl::StatusOr<OutputStreamPoller> BuildGraph(bool connect_timestamps) { | ||||
|     Graph graph; | ||||
|     auto& calculator = graph.AddNode("EmbeddingAggregationCalculator"); | ||||
|     graph[Input<EmbeddingResult>(kEmbeddingsTag)].SetName(kEmbeddingsInName) >> | ||||
|         calculator.In(kEmbeddingsTag); | ||||
|     if (connect_timestamps) { | ||||
|       graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName( | ||||
|           kTimestampsName) >> | ||||
|           calculator.In(kTimestampsTag); | ||||
|       calculator.Out(kTimestampedEmbeddingsTag) | ||||
|               .SetName(kTimestampedEmbeddingsName) >> | ||||
|           graph[Output<std::vector<EmbeddingResult>>( | ||||
|               kTimestampedEmbeddingsTag)]; | ||||
|     } else { | ||||
|       calculator.Out(kEmbeddingsTag).SetName(kEmbeddingsOutName) >> | ||||
|           graph[Output<EmbeddingResult>(kEmbeddingsTag)]; | ||||
|     } | ||||
| 
 | ||||
|     MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig())); | ||||
|     if (connect_timestamps) { | ||||
|       ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller( | ||||
|                                         kTimestampedEmbeddingsName)); | ||||
|       MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{})); | ||||
|       return poller; | ||||
|     } | ||||
|     ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller( | ||||
|                                       kEmbeddingsOutName)); | ||||
|     MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{})); | ||||
|     return poller; | ||||
|   } | ||||
| 
 | ||||
|   absl::Status Send( | ||||
|       const EmbeddingResult& embeddings, int timestamp = 0, | ||||
|       std::optional<std::vector<int>> aggregation_timestamps = std::nullopt) { | ||||
|     MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( | ||||
|         kEmbeddingsInName, MakePacket<EmbeddingResult>(std::move(embeddings)) | ||||
|                                .At(Timestamp(timestamp)))); | ||||
|     if (aggregation_timestamps.has_value()) { | ||||
|       auto packet = std::make_unique<std::vector<Timestamp>>(); | ||||
|       for (const auto& timestamp : *aggregation_timestamps) { | ||||
|         packet->emplace_back(Timestamp(timestamp)); | ||||
|       } | ||||
|       MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( | ||||
|           kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp)))); | ||||
|     } | ||||
|     return absl::OkStatus(); | ||||
|   } | ||||
| 
 | ||||
|   template <typename T> | ||||
|   absl::StatusOr<T> GetResult(OutputStreamPoller& poller) { | ||||
|     MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle()); | ||||
|     MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams()); | ||||
| 
 | ||||
|     Packet packet; | ||||
|     if (!poller.Next(&packet)) { | ||||
|       return absl::InternalError("Unable to get output packet"); | ||||
|     } | ||||
|     auto result = packet.Get<T>(); | ||||
|     MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone()); | ||||
|     return result; | ||||
|   } | ||||
| 
 | ||||
|  private: | ||||
|   CalculatorGraph calculator_graph_; | ||||
| }; | ||||
| 
 | ||||
| TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithoutTimestamps) { | ||||
|   EmbeddingResult embedding = ParseTextProtoOrDie<EmbeddingResult>( | ||||
|       R"pb(embeddings { head_index: 0 })pb"); | ||||
| 
 | ||||
|   MP_ASSERT_OK_AND_ASSIGN(auto poller, | ||||
|                           BuildGraph(/*connect_timestamps=*/false)); | ||||
|   MP_ASSERT_OK(Send(embedding)); | ||||
|   MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult<EmbeddingResult>(poller)); | ||||
| 
 | ||||
|   EXPECT_THAT(result, EqualsProto(embedding)); | ||||
| } | ||||
| 
 | ||||
| TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithTimestamps) { | ||||
|   MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true)); | ||||
|   MP_ASSERT_OK(Send(ParseTextProtoOrDie<EmbeddingResult>(R"pb(embeddings { | ||||
|                                                                 head_index: 0 | ||||
|                                                               })pb"))); | ||||
|   MP_ASSERT_OK(Send( | ||||
|       ParseTextProtoOrDie<EmbeddingResult>( | ||||
|           R"pb(embeddings { head_index: 1 })pb"), | ||||
|       /*timestamp=*/1000, | ||||
|       /*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000}))); | ||||
|   MP_ASSERT_OK_AND_ASSIGN(auto results, | ||||
|                           GetResult<std::vector<EmbeddingResult>>(poller)); | ||||
| 
 | ||||
|   EXPECT_THAT(results, | ||||
|               Pointwise(EqualsProto(), {ParseTextProtoOrDie<EmbeddingResult>( | ||||
|                                             R"pb(embeddings { head_index: 0 } | ||||
|                                                  timestamp_ms: 0)pb"), | ||||
|                                         ParseTextProtoOrDie<EmbeddingResult>( | ||||
|                                             R"pb(embeddings { head_index: 1 } | ||||
|                                                  timestamp_ms: 1)pb")})); | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| }  // namespace mediapipe
 | ||||
|  | @ -82,6 +82,7 @@ cc_library( | |||
|         "//mediapipe/framework/formats:tensor", | ||||
|         "//mediapipe/framework/tool:options_map", | ||||
|         "//mediapipe/tasks/cc:common", | ||||
|         "//mediapipe/tasks/cc/components/calculators:embedding_aggregation_calculator", | ||||
|         "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", | ||||
|         "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto", | ||||
|         "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", | ||||
|  |  | |||
|  | @ -56,6 +56,14 @@ using TensorsSource = | |||
| 
 | ||||
| constexpr char kTensorsTag[] = "TENSORS"; | ||||
| constexpr char kEmbeddingsTag[] = "EMBEDDINGS"; | ||||
| constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS"; | ||||
| constexpr char kTimestampsTag[] = "TIMESTAMPS"; | ||||
| 
 | ||||
| // Struct holding the different output streams produced by the graph.
 | ||||
| struct EmbeddingPostprocessingOutputStreams { | ||||
|   Source<EmbeddingResult> embeddings; | ||||
|   Source<std::vector<EmbeddingResult>> timestamped_embeddings; | ||||
| }; | ||||
| 
 | ||||
| // Identifies whether or not the model has quantized outputs, and performs
 | ||||
| // sanity checks.
 | ||||
|  | @ -168,27 +176,39 @@ absl::Status ConfigureEmbeddingPostprocessing( | |||
| //   TENSORS - std::vector<Tensor>
 | ||||
| //     The output tensors of an InferenceCalculator, to convert into
 | ||||
| //     EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
 | ||||
| //   TIMESTAMPS - std::vector<Timestamp> @Optional
 | ||||
| //     The collection of the timestamps that this calculator should aggregate.
 | ||||
| //     This stream is optional: if provided then the TIMESTAMPED_EMBEDDINGS
 | ||||
| //     output is used for results. Otherwise as no timestamp aggregation is
 | ||||
| //     required the EMBEDDINGS output is used for results.
 | ||||
| //
 | ||||
| // Outputs:
 | ||||
| //   EMBEDDING_RESULT - EmbeddingResult
 | ||||
| //     The output EmbeddingResult.
 | ||||
| //   EMBEDDINGS - EmbeddingResult @Optional
 | ||||
| //     The embedding results aggregated by head. Must be connected if the
 | ||||
| //     TIMESTAMPS input is not connected, as it signals that timestamp
 | ||||
| //     aggregation is not required.
 | ||||
| //   TIMESTAMPED_EMBEDDINGS - std::vector<EmbeddingResult> @Optional
 | ||||
| //     The embedding result aggregated by timestamp, then by head. Must be
 | ||||
| //     connected if the TIMESTAMPS input is connected, as it signals that
 | ||||
| //     timestamp aggregation is required.
 | ||||
| //
 | ||||
| // The recommended way of using this graph is through the GraphBuilder API using
 | ||||
| // the 'ConfigureEmbeddingPostprocessing()' function. See header file for more
 | ||||
| // details.
 | ||||
| //
 | ||||
| // TODO: add support for additional optional "TIMESTAMPS" input for
 | ||||
| // embeddings aggregation.
 | ||||
| class EmbeddingPostprocessingGraph : public mediapipe::Subgraph { | ||||
|  public: | ||||
|   absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig( | ||||
|       mediapipe::SubgraphContext* sc) override { | ||||
|     Graph graph; | ||||
|     ASSIGN_OR_RETURN( | ||||
|         auto embedding_result_out, | ||||
|         auto output_streams, | ||||
|         BuildEmbeddingPostprocessing( | ||||
|             sc->Options<proto::EmbeddingPostprocessingGraphOptions>(), | ||||
|             graph[Input<std::vector<Tensor>>(kTensorsTag)], graph)); | ||||
|     embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingsTag)]; | ||||
|             graph[Input<std::vector<Tensor>>(kTensorsTag)], | ||||
|             graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph)); | ||||
|     output_streams.embeddings >> graph[Output<EmbeddingResult>(kEmbeddingsTag)]; | ||||
|     output_streams.timestamped_embeddings >> | ||||
|         graph[Output<std::vector<EmbeddingResult>>(kTimestampedEmbeddingsTag)]; | ||||
|     return graph.GetConfig(); | ||||
|   } | ||||
| 
 | ||||
|  | @ -200,10 +220,14 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph { | |||
|   //
 | ||||
|   // options: the on-device EmbeddingPostprocessingGraphOptions
 | ||||
|   // tensors_in: (std::vector<mediapipe::Tensor>) tensors to postprocess.
 | ||||
|   // timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
 | ||||
|   //   timestamps that should be used to aggregate embedding results.
 | ||||
|   // graph: the mediapipe builder::Graph instance to be updated.
 | ||||
|   absl::StatusOr<Source<EmbeddingResult>> BuildEmbeddingPostprocessing( | ||||
|   absl::StatusOr<EmbeddingPostprocessingOutputStreams> | ||||
|   BuildEmbeddingPostprocessing( | ||||
|       const proto::EmbeddingPostprocessingGraphOptions options, | ||||
|       Source<std::vector<Tensor>> tensors_in, Graph& graph) { | ||||
|       Source<std::vector<Tensor>> tensors_in, | ||||
|       Source<std::vector<Timestamp>> timestamps_in, Graph& graph) { | ||||
|     // If output tensors are quantized, they must be dequantized first.
 | ||||
|     TensorsSource dequantized_tensors(&tensors_in); | ||||
|     if (options.has_quantized_outputs()) { | ||||
|  | @ -220,7 +244,20 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph { | |||
|         .GetOptions<mediapipe::TensorsToEmbeddingsCalculatorOptions>() | ||||
|         .CopyFrom(options.tensors_to_embeddings_options()); | ||||
|     dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag); | ||||
|     return tensors_to_embeddings_node[Output<EmbeddingResult>(kEmbeddingsTag)]; | ||||
| 
 | ||||
|     // Adds EmbeddingAggregationCalculator.
 | ||||
|     GenericNode& aggregation_node = | ||||
|         graph.AddNode("EmbeddingAggregationCalculator"); | ||||
|     tensors_to_embeddings_node[Output<EmbeddingResult>(kEmbeddingsTag)] >> | ||||
|         aggregation_node.In(kEmbeddingsTag); | ||||
|     timestamps_in >> aggregation_node.In(kTimestampsTag); | ||||
| 
 | ||||
|     // Connects outputs.
 | ||||
|     return EmbeddingPostprocessingOutputStreams{ | ||||
|         /*embeddings=*/aggregation_node[Output<EmbeddingResult>( | ||||
|             kEmbeddingsTag)], | ||||
|         /*timestamped_embeddings=*/aggregation_node | ||||
|             [Output<std::vector<EmbeddingResult>>(kTimestampedEmbeddingsTag)]}; | ||||
|   } | ||||
| }; | ||||
| REGISTER_MEDIAPIPE_GRAPH( | ||||
|  |  | |||
|  | @ -44,12 +44,20 @@ namespace processors { | |||
| //   TENSORS - std::vector<Tensor>
 | ||||
| //     The output tensors of an InferenceCalculator, to convert into
 | ||||
| //     EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
 | ||||
| //   TIMESTAMPS - std::vector<Timestamp> @Optional
 | ||||
| //     The collection of the timestamps that this calculator should aggregate.
 | ||||
| //     This stream is optional: if provided then the TIMESTAMPED_EMBEDDINGS
 | ||||
| //     output is used for results. Otherwise as no timestamp aggregation is
 | ||||
| //     required the EMBEDDINGS output is used for results.
 | ||||
| // Outputs:
 | ||||
| //   EMBEDDINGS - EmbeddingResult
 | ||||
| //     The output EmbeddingResult.
 | ||||
| //
 | ||||
| // TODO: add support for additional optional "TIMESTAMPS" input for
 | ||||
| // embeddings aggregation.
 | ||||
| //   EMBEDDINGS - EmbeddingResult @Optional
 | ||||
| //     The embedding results aggregated by head. Must be connected if the
 | ||||
| //     TIMESTAMPS input is not connected, as it signals that timestamp
 | ||||
| //     aggregation is not required.
 | ||||
| //   TIMESTAMPED_EMBEDDINGS - std::vector<EmbeddingResult> @Optional
 | ||||
| //     The embedding result aggregated by timestamp, then by head. Must be
 | ||||
| //     connected if the TIMESTAMPS input is connected, as it signals that
 | ||||
| //     timestamp aggregation is required.
 | ||||
| absl::Status ConfigureEmbeddingPostprocessing( | ||||
|     const tasks::core::ModelResources& model_resources, | ||||
|     const proto::EmbedderOptions& embedder_options, | ||||
|  |  | |||
|  | @ -20,11 +20,20 @@ limitations under the License. | |||
| #include "absl/flags/flag.h" | ||||
| #include "absl/status/statusor.h" | ||||
| #include "absl/strings/string_view.h" | ||||
| #include "mediapipe/framework/api2/builder.h" | ||||
| #include "mediapipe/framework/api2/port.h" | ||||
| #include "mediapipe/framework/calculator_framework.h" | ||||
| #include "mediapipe/framework/calculator_runner.h" | ||||
| #include "mediapipe/framework/deps/file_path.h" | ||||
| #include "mediapipe/framework/formats/tensor.h" | ||||
| #include "mediapipe/framework/graph_runner.h" | ||||
| #include "mediapipe/framework/output_stream_poller.h" | ||||
| #include "mediapipe/framework/port/gmock.h" | ||||
| #include "mediapipe/framework/port/gtest.h" | ||||
| #include "mediapipe/framework/port/parse_text_proto.h" | ||||
| #include "mediapipe/framework/port/status_matchers.h" | ||||
| #include "mediapipe/framework/timestamp.h" | ||||
| #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" | ||||
| #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h" | ||||
| #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" | ||||
| #include "mediapipe/tasks/cc/core/model_resources.h" | ||||
|  | @ -37,7 +46,12 @@ namespace components { | |||
| namespace processors { | ||||
| namespace { | ||||
| 
 | ||||
| using ::mediapipe::api2::Input; | ||||
| using ::mediapipe::api2::Output; | ||||
| using ::mediapipe::api2::builder::Graph; | ||||
| using ::mediapipe::api2::builder::Source; | ||||
| using ::mediapipe::file::JoinPath; | ||||
| using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; | ||||
| using ::mediapipe::tasks::core::ModelResources; | ||||
| 
 | ||||
| constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/"; | ||||
|  | @ -51,6 +65,16 @@ constexpr char kQuantizedImageClassifierWithoutMetadata[] = | |||
|     "vision/mobilenet_v1_0.25_192_quantized_1_default_1.tflite"; | ||||
| 
 | ||||
| constexpr char kTestModelResourcesTag[] = "test_model_resources"; | ||||
| constexpr int kMobileNetV3EmbedderEmbeddingSize = 1024; | ||||
| 
 | ||||
| constexpr char kTensorsTag[] = "TENSORS"; | ||||
| constexpr char kTensorsName[] = "tensors"; | ||||
| constexpr char kTimestampsTag[] = "TIMESTAMPS"; | ||||
| constexpr char kTimestampsName[] = "timestamps"; | ||||
| constexpr char kEmbeddingsTag[] = "EMBEDDINGS"; | ||||
| constexpr char kEmbeddingsName[] = "embeddings"; | ||||
| constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS"; | ||||
| constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings"; | ||||
| 
 | ||||
| // Helper function to get ModelResources.
 | ||||
| absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel( | ||||
|  | @ -128,8 +152,171 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) { | |||
|                    has_quantized_outputs: false)pb"))); | ||||
| } | ||||
| 
 | ||||
| // TODO: add E2E Postprocessing tests once timestamp aggregation is
 | ||||
| // supported.
 | ||||
| class PostprocessingTest : public tflite_shims::testing::Test { | ||||
|  protected: | ||||
|   absl::StatusOr<OutputStreamPoller> BuildGraph( | ||||
|       absl::string_view model_name, const proto::EmbedderOptions& options, | ||||
|       bool connect_timestamps = false) { | ||||
|     ASSIGN_OR_RETURN(auto model_resources, | ||||
|                      CreateModelResourcesForModel(model_name)); | ||||
| 
 | ||||
|     Graph graph; | ||||
|     auto& postprocessing = graph.AddNode( | ||||
|         "mediapipe.tasks.components.processors." | ||||
|         "EmbeddingPostprocessingGraph"); | ||||
|     MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessing( | ||||
|         *model_resources, options, | ||||
|         &postprocessing | ||||
|              .GetOptions<proto::EmbeddingPostprocessingGraphOptions>())); | ||||
|     graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >> | ||||
|         postprocessing.In(kTensorsTag); | ||||
|     if (connect_timestamps) { | ||||
|       graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName( | ||||
|           kTimestampsName) >> | ||||
|           postprocessing.In(kTimestampsTag); | ||||
|       postprocessing.Out(kTimestampedEmbeddingsTag) | ||||
|               .SetName(kTimestampedEmbeddingsName) >> | ||||
|           graph[Output<std::vector<EmbeddingResult>>( | ||||
|               kTimestampedEmbeddingsTag)]; | ||||
|     } else { | ||||
|       postprocessing.Out(kEmbeddingsTag).SetName(kEmbeddingsName) >> | ||||
|           graph[Output<EmbeddingResult>(kEmbeddingsTag)]; | ||||
|     } | ||||
| 
 | ||||
|     MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig())); | ||||
|     if (connect_timestamps) { | ||||
|       ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller( | ||||
|                                         kTimestampedEmbeddingsName)); | ||||
|       MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{})); | ||||
|       return poller; | ||||
|     } | ||||
|     ASSIGN_OR_RETURN(auto poller, | ||||
|                      calculator_graph_.AddOutputStreamPoller(kEmbeddingsName)); | ||||
|     MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{})); | ||||
|     return poller; | ||||
|   } | ||||
| 
 | ||||
|   template <typename T> | ||||
|   void AddTensor( | ||||
|       const std::vector<T>& tensor, const Tensor::ElementType& element_type, | ||||
|       const Tensor::QuantizationParameters& quantization_parameters = {}) { | ||||
|     tensors_->emplace_back(element_type, | ||||
|                            Tensor::Shape{1, static_cast<int>(tensor.size())}, | ||||
|                            quantization_parameters); | ||||
|     auto view = tensors_->back().GetCpuWriteView(); | ||||
|     T* buffer = view.buffer<T>(); | ||||
|     std::copy(tensor.begin(), tensor.end(), buffer); | ||||
|   } | ||||
| 
 | ||||
|   absl::Status Run( | ||||
|       std::optional<std::vector<int>> aggregation_timestamps = std::nullopt, | ||||
|       int timestamp = 0) { | ||||
|     MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( | ||||
|         kTensorsName, Adopt(tensors_.release()).At(Timestamp(timestamp)))); | ||||
|     // Reset tensors for future calls.
 | ||||
|     tensors_ = absl::make_unique<std::vector<Tensor>>(); | ||||
|     if (aggregation_timestamps.has_value()) { | ||||
|       auto packet = absl::make_unique<std::vector<Timestamp>>(); | ||||
|       for (const auto& timestamp : *aggregation_timestamps) { | ||||
|         packet->emplace_back(Timestamp(timestamp)); | ||||
|       } | ||||
|       MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( | ||||
|           kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp)))); | ||||
|     } | ||||
|     return absl::OkStatus(); | ||||
|   } | ||||
| 
 | ||||
|   template <typename T> | ||||
|   absl::StatusOr<T> GetResult(OutputStreamPoller& poller) { | ||||
|     MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle()); | ||||
|     MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams()); | ||||
| 
 | ||||
|     Packet packet; | ||||
|     if (!poller.Next(&packet)) { | ||||
|       return absl::InternalError("Unable to get output packet"); | ||||
|     } | ||||
|     auto result = packet.Get<T>(); | ||||
|     MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone()); | ||||
|     return result; | ||||
|   } | ||||
| 
 | ||||
|  private: | ||||
|   CalculatorGraph calculator_graph_; | ||||
|   std::unique_ptr<std::vector<Tensor>> tensors_ = | ||||
|       absl::make_unique<std::vector<Tensor>>(); | ||||
| }; | ||||
| 
 | ||||
| TEST_F(PostprocessingTest, SucceedsWithoutTimestamps) { | ||||
|   // Build graph.
 | ||||
|   proto::EmbedderOptions options; | ||||
|   MP_ASSERT_OK_AND_ASSIGN(auto poller, | ||||
|                           BuildGraph(kMobileNetV3Embedder, options)); | ||||
|   // Build input tensor.
 | ||||
|   std::vector<float> tensor(kMobileNetV3EmbedderEmbeddingSize, 0); | ||||
|   tensor[0] = 1.0; | ||||
| 
 | ||||
|   // Send tensor and get results.
 | ||||
|   AddTensor(tensor, Tensor::ElementType::kFloat32); | ||||
|   MP_ASSERT_OK(Run()); | ||||
|   MP_ASSERT_OK_AND_ASSIGN(auto results, GetResult<EmbeddingResult>(poller)); | ||||
| 
 | ||||
|   // Validate results.
 | ||||
|   EXPECT_FALSE(results.has_timestamp_ms()); | ||||
|   EXPECT_EQ(results.embeddings_size(), 1); | ||||
|   EXPECT_EQ(results.embeddings(0).head_index(), 0); | ||||
|   EXPECT_EQ(results.embeddings(0).head_name(), "feature"); | ||||
|   EXPECT_EQ(results.embeddings(0).float_embedding().values_size(), | ||||
|             kMobileNetV3EmbedderEmbeddingSize); | ||||
|   EXPECT_FLOAT_EQ(results.embeddings(0).float_embedding().values(0), 1.0); | ||||
|   for (int i = 1; i < kMobileNetV3EmbedderEmbeddingSize; ++i) { | ||||
|     EXPECT_FLOAT_EQ(results.embeddings(0).float_embedding().values(i), 0.0); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| TEST_F(PostprocessingTest, SucceedsWithTimestamps) { | ||||
|   // Build graph.
 | ||||
|   proto::EmbedderOptions options; | ||||
|   MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(kMobileNetV3Embedder, options, | ||||
|                                                   /*connect_timestamps=*/true)); | ||||
|   // Build input tensors.
 | ||||
|   std::vector<float> tensor_0(kMobileNetV3EmbedderEmbeddingSize, 0); | ||||
|   tensor_0[0] = 1.0; | ||||
|   std::vector<float> tensor_1(kMobileNetV3EmbedderEmbeddingSize, 0); | ||||
|   tensor_1[0] = 2.0; | ||||
| 
 | ||||
|   // Send tensors and get results.
 | ||||
|   AddTensor(tensor_0, Tensor::ElementType::kFloat32); | ||||
|   MP_ASSERT_OK(Run()); | ||||
|   AddTensor(tensor_1, Tensor::ElementType::kFloat32); | ||||
|   MP_ASSERT_OK(Run( | ||||
|       /*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000}), | ||||
|       /*timestamp=*/1000)); | ||||
|   MP_ASSERT_OK_AND_ASSIGN(auto results, | ||||
|                           GetResult<std::vector<EmbeddingResult>>(poller)); | ||||
| 
 | ||||
|   // Validate results.
 | ||||
|   EXPECT_EQ(results.size(), 2); | ||||
|   // First timestamp.
 | ||||
|   EXPECT_EQ(results[0].timestamp_ms(), 0); | ||||
|   EXPECT_EQ(results[0].embeddings(0).head_index(), 0); | ||||
|   EXPECT_EQ(results[0].embeddings(0).head_name(), "feature"); | ||||
|   EXPECT_EQ(results[0].embeddings(0).float_embedding().values_size(), | ||||
|             kMobileNetV3EmbedderEmbeddingSize); | ||||
|   EXPECT_FLOAT_EQ(results[0].embeddings(0).float_embedding().values(0), 1.0); | ||||
|   for (int i = 1; i < kMobileNetV3EmbedderEmbeddingSize; ++i) { | ||||
|     EXPECT_FLOAT_EQ(results[0].embeddings(0).float_embedding().values(i), 0.0); | ||||
|   } | ||||
|   // Second timestamp.
 | ||||
|   EXPECT_EQ(results[1].timestamp_ms(), 1); | ||||
|   EXPECT_EQ(results[1].embeddings(0).head_index(), 0); | ||||
|   EXPECT_EQ(results[1].embeddings(0).head_name(), "feature"); | ||||
|   EXPECT_EQ(results[1].embeddings(0).float_embedding().values_size(), | ||||
|             kMobileNetV3EmbedderEmbeddingSize); | ||||
|   EXPECT_FLOAT_EQ(results[1].embeddings(0).float_embedding().values(0), 2.0); | ||||
|   for (int i = 1; i < kMobileNetV3EmbedderEmbeddingSize; ++i) { | ||||
|     EXPECT_FLOAT_EQ(results[1].embeddings(0).float_embedding().values(i), 0.0); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| }  // namespace processors
 | ||||
|  |  | |||
|  | @ -32,7 +32,4 @@ message EmbeddingPostprocessingGraphOptions { | |||
| 
 | ||||
|   // Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32). | ||||
|   optional bool has_quantized_outputs = 2; | ||||
| 
 | ||||
|   // TODO: add options to control whether timestamp aggregation | ||||
|   // should be used or not. | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user