Add timestamp aggregation to EmbeddingPostprocessingGraph.
PiperOrigin-RevId: 487463848
This commit is contained in:
parent
f11c757629
commit
0ac604d507
|
@ -200,3 +200,38 @@ cc_test(
|
|||
"@com_google_absl//absl/status",
|
||||
],
|
||||
)
|
||||
|
||||
# Library target for the EmbeddingAggregationCalculator, built from
# embedding_aggregation_calculator.cc.
cc_library(
|
||||
name = "embedding_aggregation_calculator",
|
||||
srcs = ["embedding_aggregation_calculator.cc"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
],
|
||||
# NOTE(review): presumably required so the calculator's static registration
# (MEDIAPIPE_REGISTER_NODE in the .cc file) is linked in even though no symbol
# of this library is referenced directly - confirm.
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Unit test for the EmbeddingAggregationCalculator; depends on the calculator
# library above so that the calculator is available to the test graph.
cc_test(
|
||||
name = "embedding_aggregation_calculator_test",
|
||||
srcs = ["embedding_aggregation_calculator_test.cc"],
|
||||
deps = [
|
||||
":embedding_aggregation_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:output_stream_poller",
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/framework:timestamp",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,132 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||
|
||||
// Aggregates EmbeddingResult packets into a vector of timestamped
|
||||
// EmbeddingResult. Acts as a pass-through if no timestamp aggregation is
|
||||
// needed.
|
||||
//
|
||||
// Inputs:
|
||||
// EMBEDDINGS: EmbeddingResult
|
||||
// The EmbeddingResult packets to aggregate.
|
||||
// TIMESTAMPS: std::vector<Timestamp> @Optional.
|
||||
// The collection of timestamps that this calculator should aggregate. This
|
||||
// stream is optional: if provided then the TIMESTAMPED_EMBEDDINGS output
|
||||
// will contain the aggregated results. Otherwise as no timestamp
|
||||
// aggregation is required the EMBEDDINGS output is used to pass the inputs
|
||||
// EmbeddingResults unchanged.
|
||||
//
|
||||
// Outputs:
|
||||
// EMBEDDINGS: EmbeddingResult @Optional
|
||||
// The input EmbeddingResult, unchanged. Must be connected if the TIMESTAMPS
|
||||
// input is not connected, as it signals that timestamp aggregation is not
|
||||
// required.
|
||||
// TIMESTAMPED_EMBEDDINGS: std::vector<EmbeddingResult> @Optional
|
||||
// The embedding results aggregated by timestamp. Must be connected if the
|
||||
// TIMESTAMPS input is connected as it signals that timestamp aggregation is
|
||||
// required.
|
||||
//
|
||||
// Example without timestamp aggregation (pass-through):
|
||||
// node {
|
||||
// calculator: "EmbeddingAggregationCalculator"
|
||||
// input_stream: "EMBEDDINGS:embeddings_in"
|
||||
// output_stream: "EMBEDDINGS:embeddings_out"
|
||||
// }
|
||||
//
|
||||
// Example with timestamp aggregation:
|
||||
// node {
|
||||
// calculator: "EmbeddingAggregationCalculator"
|
||||
// input_stream: "EMBEDDINGS:embeddings_in"
|
||||
// input_stream: "TIMESTAMPS:timestamps_in"
|
||||
// output_stream: "TIMESTAMPED_EMBEDDINGS:timestamped_embeddings_out"
|
||||
// }
|
||||
class EmbeddingAggregationCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<EmbeddingResult> kEmbeddingsIn{"EMBEDDINGS"};
|
||||
static constexpr Input<std::vector<Timestamp>>::Optional kTimestampsIn{
|
||||
"TIMESTAMPS"};
|
||||
static constexpr Output<EmbeddingResult>::Optional kEmbeddingsOut{
|
||||
"EMBEDDINGS"};
|
||||
static constexpr Output<std::vector<EmbeddingResult>>::Optional
|
||||
kTimestampedEmbeddingsOut{"TIMESTAMPED_EMBEDDINGS"};
|
||||
MEDIAPIPE_NODE_CONTRACT(kEmbeddingsIn, kTimestampsIn, kEmbeddingsOut,
|
||||
kTimestampedEmbeddingsOut);
|
||||
|
||||
static absl::Status UpdateContract(CalculatorContract* cc);
|
||||
absl::Status Open(CalculatorContext* cc);
|
||||
absl::Status Process(CalculatorContext* cc);
|
||||
|
||||
private:
|
||||
bool time_aggregation_enabled_;
|
||||
std::unordered_map<int64, EmbeddingResult> cached_embeddings_;
|
||||
};
|
||||
|
||||
// Validates the stream wiring: exactly which output is mandatory depends on
// whether the optional TIMESTAMPS input is connected.
absl::Status EmbeddingAggregationCalculator::UpdateContract(
    CalculatorContract* cc) {
  const bool aggregation_requested = kTimestampsIn(cc).IsConnected();
  if (!aggregation_requested) {
    // Pass-through mode: results leave through the EMBEDDINGS output.
    RET_CHECK(kEmbeddingsOut(cc).IsConnected());
    return absl::OkStatus();
  }
  // Aggregation mode: results leave through TIMESTAMPED_EMBEDDINGS.
  RET_CHECK(kTimestampedEmbeddingsOut(cc).IsConnected());
  return absl::OkStatus();
}
|
||||
|
||||
// Latches the operating mode for the whole run: aggregation is enabled if and
// only if the TIMESTAMPS input stream is connected.
absl::Status EmbeddingAggregationCalculator::Open(CalculatorContext* cc) {
  const bool timestamps_connected = kTimestampsIn(cc).IsConnected();
  time_aggregation_enabled_ = timestamps_connected;
  return absl::OkStatus();
}
|
||||
|
||||
// Handles the packets available at the current input timestamp.
//
// Pass-through mode (TIMESTAMPS not connected): forwards the EMBEDDINGS input
// packet to the EMBEDDINGS output.
//
// Aggregation mode (TIMESTAMPS connected): stores the incoming EmbeddingResult
// in the cache under its input timestamp. If a TIMESTAMPS packet is present,
// emits one EmbeddingResult per listed timestamp, in the listed order, with
// `timestamp_ms` set relative to the first listed timestamp, and removes the
// emitted entries from the cache.
//
// Returns an error (RET_CHECK) if the cache is not empty once output has been
// produced, i.e. if a result was received that the TIMESTAMPS packet did not
// ask for.
absl::Status EmbeddingAggregationCalculator::Process(CalculatorContext* cc) {
  if (!time_aggregation_enabled_) {
    kEmbeddingsOut(cc).Send(kEmbeddingsIn(cc));
    RET_CHECK(cached_embeddings_.empty());
    return absl::OkStatus();
  }

  // The input payload is const, so it can only be copied into the cache.
  // NOTE(review): the emptiness check covers a TIMESTAMPS packet arriving at a
  // timestamp that carries no EMBEDDINGS packet; dereferencing an empty input
  // would otherwise abort - confirm against the node's input policy.
  if (!kEmbeddingsIn(cc).IsEmpty()) {
    cached_embeddings_[cc->InputTimestamp().Value()] = *kEmbeddingsIn(cc);
  }
  if (kTimestampsIn(cc).IsEmpty()) {
    // Nothing to emit yet: keep accumulating.
    return absl::OkStatus();
  }

  // Bind by reference: the packet owns the vector, no copy is needed.
  const std::vector<Timestamp>& timestamps = kTimestampsIn(cc).Get();
  std::vector<EmbeddingResult> results;
  results.reserve(timestamps.size());
  for (const Timestamp& timestamp : timestamps) {
    // A listed timestamp with no cached result yields an otherwise empty
    // EmbeddingResult that only carries `timestamp_ms`.
    EmbeddingResult result;
    const auto cached = cached_embeddings_.find(timestamp.Value());
    if (cached != cached_embeddings_.end()) {
      result = std::move(cached->second);
      cached_embeddings_.erase(cached);
    }
    // Offset from the first requested timestamp, divided by 1000 to obtain
    // milliseconds (Timestamp values are presumably in microseconds).
    result.set_timestamp_ms((timestamp.Value() - timestamps[0].Value()) /
                            1000);
    results.push_back(std::move(result));
  }
  kTimestampedEmbeddingsOut(cc).Send(std::move(results));

  RET_CHECK(cached_embeddings_.empty());
  return absl::OkStatus();
}
|
||||
|
||||
MEDIAPIPE_REGISTER_NODE(EmbeddingAggregationCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,158 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/output_stream_poller.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::ParseTextProtoOrDie;
|
||||
using ::mediapipe::api2::Input;
|
||||
using ::mediapipe::api2::Output;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||
using ::testing::Pointwise;
|
||||
|
||||
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
|
||||
constexpr char kEmbeddingsInName[] = "embeddings_in";
|
||||
constexpr char kEmbeddingsOutName[] = "embeddings_out";
|
||||
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
||||
constexpr char kTimestampsName[] = "timestamps_in";
|
||||
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
|
||||
constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings_out";
|
||||
|
||||
class EmbeddingAggregationCalculatorTest : public tflite_shims::testing::Test {
|
||||
protected:
|
||||
absl::StatusOr<OutputStreamPoller> BuildGraph(bool connect_timestamps) {
|
||||
Graph graph;
|
||||
auto& calculator = graph.AddNode("EmbeddingAggregationCalculator");
|
||||
graph[Input<EmbeddingResult>(kEmbeddingsTag)].SetName(kEmbeddingsInName) >>
|
||||
calculator.In(kEmbeddingsTag);
|
||||
if (connect_timestamps) {
|
||||
graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
|
||||
kTimestampsName) >>
|
||||
calculator.In(kTimestampsTag);
|
||||
calculator.Out(kTimestampedEmbeddingsTag)
|
||||
.SetName(kTimestampedEmbeddingsName) >>
|
||||
graph[Output<std::vector<EmbeddingResult>>(
|
||||
kTimestampedEmbeddingsTag)];
|
||||
} else {
|
||||
calculator.Out(kEmbeddingsTag).SetName(kEmbeddingsOutName) >>
|
||||
graph[Output<EmbeddingResult>(kEmbeddingsTag)];
|
||||
}
|
||||
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
|
||||
if (connect_timestamps) {
|
||||
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
|
||||
kTimestampedEmbeddingsName));
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
|
||||
return poller;
|
||||
}
|
||||
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
|
||||
kEmbeddingsOutName));
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
|
||||
return poller;
|
||||
}
|
||||
|
||||
absl::Status Send(
|
||||
const EmbeddingResult& embeddings, int timestamp = 0,
|
||||
std::optional<std::vector<int>> aggregation_timestamps = std::nullopt) {
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
||||
kEmbeddingsInName, MakePacket<EmbeddingResult>(std::move(embeddings))
|
||||
.At(Timestamp(timestamp))));
|
||||
if (aggregation_timestamps.has_value()) {
|
||||
auto packet = std::make_unique<std::vector<Timestamp>>();
|
||||
for (const auto& timestamp : *aggregation_timestamps) {
|
||||
packet->emplace_back(Timestamp(timestamp));
|
||||
}
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
||||
kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
|
||||
|
||||
Packet packet;
|
||||
if (!poller.Next(&packet)) {
|
||||
return absl::InternalError("Unable to get output packet");
|
||||
}
|
||||
auto result = packet.Get<T>();
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
CalculatorGraph calculator_graph_;
|
||||
};
|
||||
|
||||
// Pass-through mode: with TIMESTAMPS left unconnected, the calculator must
// output the EmbeddingResult it received, unchanged.
TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithoutTimestamps) {
  MP_ASSERT_OK_AND_ASSIGN(auto poller,
                          BuildGraph(/*connect_timestamps=*/false));

  const EmbeddingResult input = ParseTextProtoOrDie<EmbeddingResult>(
      R"pb(embeddings { head_index: 0 })pb");
  MP_ASSERT_OK(Send(input));

  MP_ASSERT_OK_AND_ASSIGN(auto output, GetResult<EmbeddingResult>(poller));
  EXPECT_THAT(output, EqualsProto(input));
}
|
||||
|
||||
// Aggregation mode: two results are sent at timestamps 0 and 1000, the second
// one together with the list {0, 1000}. The output must contain both results,
// in order, stamped with timestamp_ms 0 and 1.
TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithTimestamps) {
  MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true));

  // First result: default timestamp (0), no aggregation request yet.
  MP_ASSERT_OK(Send(ParseTextProtoOrDie<EmbeddingResult>(R"pb(embeddings {
                                                                head_index: 0
                                                              })pb")));
  // Second result: also carries the timestamps to aggregate.
  const std::optional<std::vector<int>> timestamps_to_aggregate =
      std::vector<int>{0, 1000};
  MP_ASSERT_OK(Send(ParseTextProtoOrDie<EmbeddingResult>(
                        R"pb(embeddings { head_index: 1 })pb"),
                    /*timestamp=*/1000,
                    /*aggregation_timestamps=*/timestamps_to_aggregate));

  MP_ASSERT_OK_AND_ASSIGN(auto aggregated,
                          GetResult<std::vector<EmbeddingResult>>(poller));

  const std::vector<EmbeddingResult> expected = {
      ParseTextProtoOrDie<EmbeddingResult>(
          R"pb(embeddings { head_index: 0 }
               timestamp_ms: 0)pb"),
      ParseTextProtoOrDie<EmbeddingResult>(
          R"pb(embeddings { head_index: 1 }
               timestamp_ms: 1)pb")};
  EXPECT_THAT(aggregated, Pointwise(EqualsProto(), expected));
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
|
@ -82,6 +82,7 @@ cc_library(
|
|||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/tool:options_map",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/components/calculators:embedding_aggregation_calculator",
|
||||
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
|
||||
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||
|
|
|
@ -56,6 +56,14 @@ using TensorsSource =
|
|||
|
||||
constexpr char kTensorsTag[] = "TENSORS";
|
||||
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
|
||||
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
|
||||
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
||||
|
||||
// Struct holding the different output streams produced by the graph.
|
||||
struct EmbeddingPostprocessingOutputStreams {
|
||||
// Stream of EmbeddingResult packets (the EMBEDDINGS output).
Source<EmbeddingResult> embeddings;
|
||||
// Stream of std::vector<EmbeddingResult> packets (the TIMESTAMPED_EMBEDDINGS
// output).
Source<std::vector<EmbeddingResult>> timestamped_embeddings;
|
||||
};
|
||||
|
||||
// Identifies whether or not the model has quantized outputs, and performs
|
||||
// sanity checks.
|
||||
|
@ -168,27 +176,39 @@ absl::Status ConfigureEmbeddingPostprocessing(
|
|||
// TENSORS - std::vector<Tensor>
|
||||
// The output tensors of an InferenceCalculator, to convert into
|
||||
// EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
|
||||
// TIMESTAMPS - std::vector<Timestamp> @Optional
|
||||
// The collection of the timestamps that this calculator should aggregate.
|
||||
// This stream is optional: if provided then the TIMESTAMPED_EMBEDDINGS
|
||||
// output is used for results. Otherwise as no timestamp aggregation is
|
||||
// required the EMBEDDINGS output is used for results.
|
||||
//
|
||||
// Outputs:
|
||||
// EMBEDDING_RESULT - EmbeddingResult
|
||||
// The output EmbeddingResult.
|
||||
// EMBEDDINGS - EmbeddingResult @Optional
|
||||
// The embedding results aggregated by head. Must be connected if the
|
||||
// TIMESTAMPS input is not connected, as it signals that timestamp
|
||||
// aggregation is not required.
|
||||
// TIMESTAMPED_EMBEDDINGS - std::vector<EmbeddingResult> @Optional
|
||||
// The embedding result aggregated by timestamp, then by head. Must be
|
||||
// connected if the TIMESTAMPS input is connected, as it signals that
|
||||
// timestamp aggregation is required.
|
||||
//
|
||||
// The recommended way of using this graph is through the GraphBuilder API using
|
||||
// the 'ConfigureEmbeddingPostprocessing()' function. See header file for more
|
||||
// details.
|
||||
//
|
||||
// TODO: add support for additional optional "TIMESTAMPS" input for
|
||||
// embeddings aggregation.
|
||||
class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
|
||||
public:
|
||||
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
|
||||
mediapipe::SubgraphContext* sc) override {
|
||||
Graph graph;
|
||||
ASSIGN_OR_RETURN(
|
||||
auto embedding_result_out,
|
||||
auto output_streams,
|
||||
BuildEmbeddingPostprocessing(
|
||||
sc->Options<proto::EmbeddingPostprocessingGraphOptions>(),
|
||||
graph[Input<std::vector<Tensor>>(kTensorsTag)], graph));
|
||||
embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
|
||||
graph[Input<std::vector<Tensor>>(kTensorsTag)],
|
||||
graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
|
||||
output_streams.embeddings >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
|
||||
output_streams.timestamped_embeddings >>
|
||||
graph[Output<std::vector<EmbeddingResult>>(kTimestampedEmbeddingsTag)];
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
|
@ -200,10 +220,14 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
|
|||
//
|
||||
// options: the on-device EmbeddingPostprocessingGraphOptions
|
||||
// tensors_in: (std::vector<mediapipe::Tensor>) tensors to postprocess.
|
||||
// timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
|
||||
// timestamps that should be used to aggregate embedding results.
|
||||
// graph: the mediapipe builder::Graph instance to be updated.
|
||||
absl::StatusOr<Source<EmbeddingResult>> BuildEmbeddingPostprocessing(
|
||||
absl::StatusOr<EmbeddingPostprocessingOutputStreams>
|
||||
BuildEmbeddingPostprocessing(
|
||||
const proto::EmbeddingPostprocessingGraphOptions options,
|
||||
Source<std::vector<Tensor>> tensors_in, Graph& graph) {
|
||||
Source<std::vector<Tensor>> tensors_in,
|
||||
Source<std::vector<Timestamp>> timestamps_in, Graph& graph) {
|
||||
// If output tensors are quantized, they must be dequantized first.
|
||||
TensorsSource dequantized_tensors(&tensors_in);
|
||||
if (options.has_quantized_outputs()) {
|
||||
|
@ -220,7 +244,20 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
|
|||
.GetOptions<mediapipe::TensorsToEmbeddingsCalculatorOptions>()
|
||||
.CopyFrom(options.tensors_to_embeddings_options());
|
||||
dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag);
|
||||
return tensors_to_embeddings_node[Output<EmbeddingResult>(kEmbeddingsTag)];
|
||||
|
||||
// Adds EmbeddingAggregationCalculator.
|
||||
GenericNode& aggregation_node =
|
||||
graph.AddNode("EmbeddingAggregationCalculator");
|
||||
tensors_to_embeddings_node[Output<EmbeddingResult>(kEmbeddingsTag)] >>
|
||||
aggregation_node.In(kEmbeddingsTag);
|
||||
timestamps_in >> aggregation_node.In(kTimestampsTag);
|
||||
|
||||
// Connects outputs.
|
||||
return EmbeddingPostprocessingOutputStreams{
|
||||
/*embeddings=*/aggregation_node[Output<EmbeddingResult>(
|
||||
kEmbeddingsTag)],
|
||||
/*timestamped_embeddings=*/aggregation_node
|
||||
[Output<std::vector<EmbeddingResult>>(kTimestampedEmbeddingsTag)]};
|
||||
}
|
||||
};
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
|
|
|
@ -44,12 +44,20 @@ namespace processors {
|
|||
// TENSORS - std::vector<Tensor>
|
||||
// The output tensors of an InferenceCalculator, to convert into
|
||||
// EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
|
||||
// TIMESTAMPS - std::vector<Timestamp> @Optional
|
||||
// The collection of the timestamps that this calculator should aggregate.
|
||||
// This stream is optional: if provided then the TIMESTAMPED_EMBEDDINGS
|
||||
// output is used for results. Otherwise as no timestamp aggregation is
|
||||
// required the EMBEDDINGS output is used for results.
|
||||
// Outputs:
|
||||
// EMBEDDINGS - EmbeddingResult
|
||||
// The output EmbeddingResult.
|
||||
//
|
||||
// TODO: add support for additional optional "TIMESTAMPS" input for
|
||||
// embeddings aggregation.
|
||||
// EMBEDDINGS - EmbeddingResult @Optional
|
||||
// The embedding results aggregated by head. Must be connected if the
|
||||
// TIMESTAMPS input is not connected, as it signals that timestamp
|
||||
// aggregation is not required.
|
||||
// TIMESTAMPED_EMBEDDINGS - std::vector<EmbeddingResult> @Optional
|
||||
// The embedding result aggregated by timestamp, then by head. Must be
|
||||
// connected if the TIMESTAMPS input is connected, as it signals that
|
||||
// timestamp aggregation is required.
|
||||
absl::Status ConfigureEmbeddingPostprocessing(
|
||||
const tasks::core::ModelResources& model_resources,
|
||||
const proto::EmbedderOptions& embedder_options,
|
||||
|
|
|
@ -20,11 +20,20 @@ limitations under the License.
|
|||
#include "absl/flags/flag.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/graph_runner.h"
|
||||
#include "mediapipe/framework/output_stream_poller.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
|
@ -37,7 +46,12 @@ namespace components {
|
|||
namespace processors {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::api2::Input;
|
||||
using ::mediapipe::api2::Output;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||
using ::mediapipe::tasks::core::ModelResources;
|
||||
|
||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
|
||||
|
@ -51,6 +65,16 @@ constexpr char kQuantizedImageClassifierWithoutMetadata[] =
|
|||
"vision/mobilenet_v1_0.25_192_quantized_1_default_1.tflite";
|
||||
|
||||
constexpr char kTestModelResourcesTag[] = "test_model_resources";
|
||||
constexpr int kMobileNetV3EmbedderEmbeddingSize = 1024;
|
||||
|
||||
constexpr char kTensorsTag[] = "TENSORS";
|
||||
constexpr char kTensorsName[] = "tensors";
|
||||
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
||||
constexpr char kTimestampsName[] = "timestamps";
|
||||
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
|
||||
constexpr char kEmbeddingsName[] = "embeddings";
|
||||
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
|
||||
constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings";
|
||||
|
||||
// Helper function to get ModelResources.
|
||||
absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
|
||||
|
@ -128,8 +152,171 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
|
|||
has_quantized_outputs: false)pb")));
|
||||
}
|
||||
|
||||
// TODO: add E2E Postprocessing tests once timestamp aggregation is
|
||||
// supported.
|
||||
class PostprocessingTest : public tflite_shims::testing::Test {
|
||||
protected:
|
||||
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
||||
absl::string_view model_name, const proto::EmbedderOptions& options,
|
||||
bool connect_timestamps = false) {
|
||||
ASSIGN_OR_RETURN(auto model_resources,
|
||||
CreateModelResourcesForModel(model_name));
|
||||
|
||||
Graph graph;
|
||||
auto& postprocessing = graph.AddNode(
|
||||
"mediapipe.tasks.components.processors."
|
||||
"EmbeddingPostprocessingGraph");
|
||||
MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessing(
|
||||
*model_resources, options,
|
||||
&postprocessing
|
||||
.GetOptions<proto::EmbeddingPostprocessingGraphOptions>()));
|
||||
graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
|
||||
postprocessing.In(kTensorsTag);
|
||||
if (connect_timestamps) {
|
||||
graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
|
||||
kTimestampsName) >>
|
||||
postprocessing.In(kTimestampsTag);
|
||||
postprocessing.Out(kTimestampedEmbeddingsTag)
|
||||
.SetName(kTimestampedEmbeddingsName) >>
|
||||
graph[Output<std::vector<EmbeddingResult>>(
|
||||
kTimestampedEmbeddingsTag)];
|
||||
} else {
|
||||
postprocessing.Out(kEmbeddingsTag).SetName(kEmbeddingsName) >>
|
||||
graph[Output<EmbeddingResult>(kEmbeddingsTag)];
|
||||
}
|
||||
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
|
||||
if (connect_timestamps) {
|
||||
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
|
||||
kTimestampedEmbeddingsName));
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
|
||||
return poller;
|
||||
}
|
||||
ASSIGN_OR_RETURN(auto poller,
|
||||
calculator_graph_.AddOutputStreamPoller(kEmbeddingsName));
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
|
||||
return poller;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void AddTensor(
|
||||
const std::vector<T>& tensor, const Tensor::ElementType& element_type,
|
||||
const Tensor::QuantizationParameters& quantization_parameters = {}) {
|
||||
tensors_->emplace_back(element_type,
|
||||
Tensor::Shape{1, static_cast<int>(tensor.size())},
|
||||
quantization_parameters);
|
||||
auto view = tensors_->back().GetCpuWriteView();
|
||||
T* buffer = view.buffer<T>();
|
||||
std::copy(tensor.begin(), tensor.end(), buffer);
|
||||
}
|
||||
|
||||
absl::Status Run(
|
||||
std::optional<std::vector<int>> aggregation_timestamps = std::nullopt,
|
||||
int timestamp = 0) {
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
||||
kTensorsName, Adopt(tensors_.release()).At(Timestamp(timestamp))));
|
||||
// Reset tensors for future calls.
|
||||
tensors_ = absl::make_unique<std::vector<Tensor>>();
|
||||
if (aggregation_timestamps.has_value()) {
|
||||
auto packet = absl::make_unique<std::vector<Timestamp>>();
|
||||
for (const auto& timestamp : *aggregation_timestamps) {
|
||||
packet->emplace_back(Timestamp(timestamp));
|
||||
}
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
||||
kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
|
||||
|
||||
Packet packet;
|
||||
if (!poller.Next(&packet)) {
|
||||
return absl::InternalError("Unable to get output packet");
|
||||
}
|
||||
auto result = packet.Get<T>();
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
CalculatorGraph calculator_graph_;
|
||||
std::unique_ptr<std::vector<Tensor>> tensors_ =
|
||||
absl::make_unique<std::vector<Tensor>>();
|
||||
};
|
||||
|
||||
// Without TIMESTAMPS connected, a single float tensor must come out as one
// EmbeddingResult with a single "feature" head and no timestamp.
TEST_F(PostprocessingTest, SucceedsWithoutTimestamps) {
  // Build graph.
  proto::EmbedderOptions options;
  MP_ASSERT_OK_AND_ASSIGN(auto poller,
                          BuildGraph(kMobileNetV3Embedder, options));
  // Build input tensor.
  std::vector<float> tensor(kMobileNetV3EmbedderEmbeddingSize, 0);
  tensor[0] = 1.0;

  // Send tensor and get results.
  AddTensor(tensor, Tensor::ElementType::kFloat32);
  MP_ASSERT_OK(Run());
  MP_ASSERT_OK_AND_ASSIGN(auto results, GetResult<EmbeddingResult>(poller));

  // Validate results. Sizes are checked with ASSERT_* so that a mismatch stops
  // the test before the indexed accesses below.
  EXPECT_FALSE(results.has_timestamp_ms());
  ASSERT_EQ(results.embeddings_size(), 1);
  const auto& embedding = results.embeddings(0);
  EXPECT_EQ(embedding.head_index(), 0);
  EXPECT_EQ(embedding.head_name(), "feature");
  ASSERT_EQ(embedding.float_embedding().values_size(),
            kMobileNetV3EmbedderEmbeddingSize);
  EXPECT_FLOAT_EQ(embedding.float_embedding().values(0), 1.0);
  for (int i = 1; i < kMobileNetV3EmbedderEmbeddingSize; ++i) {
    EXPECT_FLOAT_EQ(embedding.float_embedding().values(i), 0.0);
  }
}
|
||||
|
||||
// With TIMESTAMPS connected, two tensors sent at timestamps 0 and 1000 (the
// second together with the list {0, 1000}) must come out as two
// EmbeddingResults, in order, stamped with timestamp_ms 0 and 1.
TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
  // Build graph.
  proto::EmbedderOptions options;
  MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(kMobileNetV3Embedder, options,
                                                  /*connect_timestamps=*/true));
  // Build input tensors.
  std::vector<float> tensor_0(kMobileNetV3EmbedderEmbeddingSize, 0);
  tensor_0[0] = 1.0;
  std::vector<float> tensor_1(kMobileNetV3EmbedderEmbeddingSize, 0);
  tensor_1[0] = 2.0;

  // Send tensors and get results.
  AddTensor(tensor_0, Tensor::ElementType::kFloat32);
  MP_ASSERT_OK(Run());
  AddTensor(tensor_1, Tensor::ElementType::kFloat32);
  MP_ASSERT_OK(Run(
      /*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000}),
      /*timestamp=*/1000));
  MP_ASSERT_OK_AND_ASSIGN(auto results,
                          GetResult<std::vector<EmbeddingResult>>(poller));

  // Validate results. Sizes are checked with ASSERT_* so that a mismatch stops
  // the test before the indexed accesses below.
  ASSERT_EQ(results.size(), 2u);
  const int expected_timestamps_ms[] = {0, 1};
  const float expected_first_values[] = {1.0f, 2.0f};
  for (int r = 0; r < 2; ++r) {
    const EmbeddingResult& result = results[r];
    EXPECT_EQ(result.timestamp_ms(), expected_timestamps_ms[r])
        << "result index " << r;
    ASSERT_EQ(result.embeddings_size(), 1) << "result index " << r;
    const auto& embedding = result.embeddings(0);
    EXPECT_EQ(embedding.head_index(), 0);
    EXPECT_EQ(embedding.head_name(), "feature");
    ASSERT_EQ(embedding.float_embedding().values_size(),
              kMobileNetV3EmbedderEmbeddingSize)
        << "result index " << r;
    EXPECT_FLOAT_EQ(embedding.float_embedding().values(0),
                    expected_first_values[r]);
    for (int i = 1; i < kMobileNetV3EmbedderEmbeddingSize; ++i) {
      EXPECT_FLOAT_EQ(embedding.float_embedding().values(i), 0.0);
    }
  }
}
|
||||
|
||||
} // namespace
|
||||
} // namespace processors
|
||||
|
|
|
@ -32,7 +32,4 @@ message EmbeddingPostprocessingGraphOptions {
|
|||
|
||||
// Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32).
|
||||
optional bool has_quantized_outputs = 2;
|
||||
|
||||
// TODO: add options to control whether timestamp aggregation
|
||||
// should be used or not.
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user