Add timestamp aggregation to EmbeddingPostprocessingGraph.
PiperOrigin-RevId: 487463848
This commit is contained in:
parent
f11c757629
commit
0ac604d507
|
@ -200,3 +200,38 @@ cc_test(
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "embedding_aggregation_calculator",
|
||||||
|
srcs = ["embedding_aggregation_calculator.cc"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/api2:packet",
|
||||||
|
"//mediapipe/framework/api2:port",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "embedding_aggregation_calculator_test",
|
||||||
|
srcs = ["embedding_aggregation_calculator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":embedding_aggregation_calculator",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:output_stream_poller",
|
||||||
|
"//mediapipe/framework:packet",
|
||||||
|
"//mediapipe/framework:timestamp",
|
||||||
|
"//mediapipe/framework/api2:builder",
|
||||||
|
"//mediapipe/framework/api2:port",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"//mediapipe/framework/port:status",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,132 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/api2/packet.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace api2 {
|
||||||
|
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||||
|
|
||||||
|
// Aggregates EmbeddingResult packets into a vector of timestamped
|
||||||
|
// EmbeddingResult. Acts as a pass-through if no timestamp aggregation is
|
||||||
|
// needed.
|
||||||
|
//
|
||||||
|
// Inputs:
|
||||||
|
// EMBEDDINGS: EmbeddingResult
|
||||||
|
// The EmbeddingResult packets to aggregate.
|
||||||
|
// TIMESTAMPS: std::vector<Timestamp> @Optional.
|
||||||
|
// The collection of timestamps that this calculator should aggregate. This
|
||||||
|
// stream is optional: if provided then the TIMESTAMPED_EMBEDDINGS output
|
||||||
|
// will contain the aggregated results. Otherwise as no timestamp
|
||||||
|
// aggregation is required the EMBEDDINGS output is used to pass the inputs
|
||||||
|
// EmbeddingResults unchanged.
|
||||||
|
//
|
||||||
|
// Outputs:
|
||||||
|
// EMBEDDINGS: EmbeddingResult @Optional
|
||||||
|
// The input EmbeddingResult, unchanged. Must be connected if the TIMESTAMPS
|
||||||
|
// input is not connected, as it signals that timestamp aggregation is not
|
||||||
|
// required.
|
||||||
|
// TIMESTAMPED_EMBEDDINGS: std::vector<EmbeddingResult> @Optional
|
||||||
|
// The embedding results aggregated by timestamp. Must be connected if the
|
||||||
|
// TIMESTAMPS input is connected as it signals that timestamp aggregation is
|
||||||
|
// required.
|
||||||
|
//
|
||||||
|
// Example without timestamp aggregation (pass-through):
|
||||||
|
// node {
|
||||||
|
// calculator: "EmbeddingAggregationCalculator"
|
||||||
|
// input_stream: "EMBEDDINGS:embeddings_in"
|
||||||
|
// output_stream: "EMBEDDINGS:embeddings_out"
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Example with timestamp aggregation:
|
||||||
|
// node {
|
||||||
|
// calculator: "EmbeddingAggregationCalculator"
|
||||||
|
// input_stream: "EMBEDDINGS:embeddings_in"
|
||||||
|
// input_stream: "TIMESTAMPS:timestamps_in"
|
||||||
|
// output_stream: "TIMESTAMPED_EMBEDDINGS:timestamped_embeddings_out"
|
||||||
|
// }
|
||||||
|
class EmbeddingAggregationCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Input<EmbeddingResult> kEmbeddingsIn{"EMBEDDINGS"};
|
||||||
|
static constexpr Input<std::vector<Timestamp>>::Optional kTimestampsIn{
|
||||||
|
"TIMESTAMPS"};
|
||||||
|
static constexpr Output<EmbeddingResult>::Optional kEmbeddingsOut{
|
||||||
|
"EMBEDDINGS"};
|
||||||
|
static constexpr Output<std::vector<EmbeddingResult>>::Optional
|
||||||
|
kTimestampedEmbeddingsOut{"TIMESTAMPED_EMBEDDINGS"};
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kEmbeddingsIn, kTimestampsIn, kEmbeddingsOut,
|
||||||
|
kTimestampedEmbeddingsOut);
|
||||||
|
|
||||||
|
static absl::Status UpdateContract(CalculatorContract* cc);
|
||||||
|
absl::Status Open(CalculatorContext* cc);
|
||||||
|
absl::Status Process(CalculatorContext* cc);
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool time_aggregation_enabled_;
|
||||||
|
std::unordered_map<int64, EmbeddingResult> cached_embeddings_;
|
||||||
|
};
|
||||||
|
|
||||||
|
absl::Status EmbeddingAggregationCalculator::UpdateContract(
|
||||||
|
CalculatorContract* cc) {
|
||||||
|
if (kTimestampsIn(cc).IsConnected()) {
|
||||||
|
RET_CHECK(kTimestampedEmbeddingsOut(cc).IsConnected());
|
||||||
|
} else {
|
||||||
|
RET_CHECK(kEmbeddingsOut(cc).IsConnected());
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status EmbeddingAggregationCalculator::Open(CalculatorContext* cc) {
|
||||||
|
time_aggregation_enabled_ = kTimestampsIn(cc).IsConnected();
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status EmbeddingAggregationCalculator::Process(CalculatorContext* cc) {
|
||||||
|
if (time_aggregation_enabled_) {
|
||||||
|
cached_embeddings_[cc->InputTimestamp().Value()] =
|
||||||
|
std::move(*kEmbeddingsIn(cc));
|
||||||
|
if (kTimestampsIn(cc).IsEmpty()) {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
auto timestamps = kTimestampsIn(cc).Get();
|
||||||
|
std::vector<EmbeddingResult> results;
|
||||||
|
results.reserve(timestamps.size());
|
||||||
|
for (const auto& timestamp : timestamps) {
|
||||||
|
auto& result = cached_embeddings_[timestamp.Value()];
|
||||||
|
result.set_timestamp_ms((timestamp.Value() - timestamps[0].Value()) /
|
||||||
|
1000);
|
||||||
|
results.push_back(std::move(result));
|
||||||
|
cached_embeddings_.erase(timestamp.Value());
|
||||||
|
}
|
||||||
|
kTimestampedEmbeddingsOut(cc).Send(std::move(results));
|
||||||
|
} else {
|
||||||
|
kEmbeddingsOut(cc).Send(kEmbeddingsIn(cc));
|
||||||
|
}
|
||||||
|
RET_CHECK(cached_embeddings_.empty());
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
MEDIAPIPE_REGISTER_NODE(EmbeddingAggregationCalculator);
|
||||||
|
|
||||||
|
} // namespace api2
|
||||||
|
} // namespace mediapipe
|
|
@ -0,0 +1,158 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <optional>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/output_stream_poller.h"
|
||||||
|
#include "mediapipe/framework/packet.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status_macros.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
#include "mediapipe/framework/timestamp.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||||
|
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::ParseTextProtoOrDie;
|
||||||
|
using ::mediapipe::api2::Input;
|
||||||
|
using ::mediapipe::api2::Output;
|
||||||
|
using ::mediapipe::api2::builder::Graph;
|
||||||
|
using ::mediapipe::api2::builder::Source;
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||||
|
using ::testing::Pointwise;
|
||||||
|
|
||||||
|
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
|
||||||
|
constexpr char kEmbeddingsInName[] = "embeddings_in";
|
||||||
|
constexpr char kEmbeddingsOutName[] = "embeddings_out";
|
||||||
|
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
||||||
|
constexpr char kTimestampsName[] = "timestamps_in";
|
||||||
|
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
|
||||||
|
constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings_out";
|
||||||
|
|
||||||
|
class EmbeddingAggregationCalculatorTest : public tflite_shims::testing::Test {
|
||||||
|
protected:
|
||||||
|
absl::StatusOr<OutputStreamPoller> BuildGraph(bool connect_timestamps) {
|
||||||
|
Graph graph;
|
||||||
|
auto& calculator = graph.AddNode("EmbeddingAggregationCalculator");
|
||||||
|
graph[Input<EmbeddingResult>(kEmbeddingsTag)].SetName(kEmbeddingsInName) >>
|
||||||
|
calculator.In(kEmbeddingsTag);
|
||||||
|
if (connect_timestamps) {
|
||||||
|
graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
|
||||||
|
kTimestampsName) >>
|
||||||
|
calculator.In(kTimestampsTag);
|
||||||
|
calculator.Out(kTimestampedEmbeddingsTag)
|
||||||
|
.SetName(kTimestampedEmbeddingsName) >>
|
||||||
|
graph[Output<std::vector<EmbeddingResult>>(
|
||||||
|
kTimestampedEmbeddingsTag)];
|
||||||
|
} else {
|
||||||
|
calculator.Out(kEmbeddingsTag).SetName(kEmbeddingsOutName) >>
|
||||||
|
graph[Output<EmbeddingResult>(kEmbeddingsTag)];
|
||||||
|
}
|
||||||
|
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
|
||||||
|
if (connect_timestamps) {
|
||||||
|
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
|
||||||
|
kTimestampedEmbeddingsName));
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
|
||||||
|
return poller;
|
||||||
|
}
|
||||||
|
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
|
||||||
|
kEmbeddingsOutName));
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
|
||||||
|
return poller;
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Send(
|
||||||
|
const EmbeddingResult& embeddings, int timestamp = 0,
|
||||||
|
std::optional<std::vector<int>> aggregation_timestamps = std::nullopt) {
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
||||||
|
kEmbeddingsInName, MakePacket<EmbeddingResult>(std::move(embeddings))
|
||||||
|
.At(Timestamp(timestamp))));
|
||||||
|
if (aggregation_timestamps.has_value()) {
|
||||||
|
auto packet = std::make_unique<std::vector<Timestamp>>();
|
||||||
|
for (const auto& timestamp : *aggregation_timestamps) {
|
||||||
|
packet->emplace_back(Timestamp(timestamp));
|
||||||
|
}
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
||||||
|
kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
|
||||||
|
|
||||||
|
Packet packet;
|
||||||
|
if (!poller.Next(&packet)) {
|
||||||
|
return absl::InternalError("Unable to get output packet");
|
||||||
|
}
|
||||||
|
auto result = packet.Get<T>();
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
CalculatorGraph calculator_graph_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithoutTimestamps) {
|
||||||
|
EmbeddingResult embedding = ParseTextProtoOrDie<EmbeddingResult>(
|
||||||
|
R"pb(embeddings { head_index: 0 })pb");
|
||||||
|
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto poller,
|
||||||
|
BuildGraph(/*connect_timestamps=*/false));
|
||||||
|
MP_ASSERT_OK(Send(embedding));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult<EmbeddingResult>(poller));
|
||||||
|
|
||||||
|
EXPECT_THAT(result, EqualsProto(embedding));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithTimestamps) {
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true));
|
||||||
|
MP_ASSERT_OK(Send(ParseTextProtoOrDie<EmbeddingResult>(R"pb(embeddings {
|
||||||
|
head_index: 0
|
||||||
|
})pb")));
|
||||||
|
MP_ASSERT_OK(Send(
|
||||||
|
ParseTextProtoOrDie<EmbeddingResult>(
|
||||||
|
R"pb(embeddings { head_index: 1 })pb"),
|
||||||
|
/*timestamp=*/1000,
|
||||||
|
/*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000})));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto results,
|
||||||
|
GetResult<std::vector<EmbeddingResult>>(poller));
|
||||||
|
|
||||||
|
EXPECT_THAT(results,
|
||||||
|
Pointwise(EqualsProto(), {ParseTextProtoOrDie<EmbeddingResult>(
|
||||||
|
R"pb(embeddings { head_index: 0 }
|
||||||
|
timestamp_ms: 0)pb"),
|
||||||
|
ParseTextProtoOrDie<EmbeddingResult>(
|
||||||
|
R"pb(embeddings { head_index: 1 }
|
||||||
|
timestamp_ms: 1)pb")}));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe
|
|
@ -82,6 +82,7 @@ cc_library(
|
||||||
"//mediapipe/framework/formats:tensor",
|
"//mediapipe/framework/formats:tensor",
|
||||||
"//mediapipe/framework/tool:options_map",
|
"//mediapipe/framework/tool:options_map",
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
|
"//mediapipe/tasks/cc/components/calculators:embedding_aggregation_calculator",
|
||||||
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
|
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
|
||||||
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
|
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||||
|
|
|
@ -56,6 +56,14 @@ using TensorsSource =
|
||||||
|
|
||||||
constexpr char kTensorsTag[] = "TENSORS";
|
constexpr char kTensorsTag[] = "TENSORS";
|
||||||
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
|
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
|
||||||
|
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
|
||||||
|
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
||||||
|
|
||||||
|
// Struct holding the different output streams produced by the graph.
|
||||||
|
struct EmbeddingPostprocessingOutputStreams {
|
||||||
|
Source<EmbeddingResult> embeddings;
|
||||||
|
Source<std::vector<EmbeddingResult>> timestamped_embeddings;
|
||||||
|
};
|
||||||
|
|
||||||
// Identifies whether or not the model has quantized outputs, and performs
|
// Identifies whether or not the model has quantized outputs, and performs
|
||||||
// sanity checks.
|
// sanity checks.
|
||||||
|
@ -168,27 +176,39 @@ absl::Status ConfigureEmbeddingPostprocessing(
|
||||||
// TENSORS - std::vector<Tensor>
|
// TENSORS - std::vector<Tensor>
|
||||||
// The output tensors of an InferenceCalculator, to convert into
|
// The output tensors of an InferenceCalculator, to convert into
|
||||||
// EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
|
// EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
|
||||||
|
// TIMESTAMPS - std::vector<Timestamp> @Optional
|
||||||
|
// The collection of the timestamps that this calculator should aggregate.
|
||||||
|
// This stream is optional: if provided then the TIMESTAMPED_EMBEDDINGS
|
||||||
|
// output is used for results. Otherwise as no timestamp aggregation is
|
||||||
|
// required the EMBEDDINGS output is used for results.
|
||||||
|
//
|
||||||
// Outputs:
|
// Outputs:
|
||||||
// EMBEDDING_RESULT - EmbeddingResult
|
// EMBEDDINGS - EmbeddingResult @Optional
|
||||||
// The output EmbeddingResult.
|
// The embedding results aggregated by head. Must be connected if the
|
||||||
|
// TIMESTAMPS input is not connected, as it signals that timestamp
|
||||||
|
// aggregation is not required.
|
||||||
|
// TIMESTAMPED_EMBEDDINGS - std::vector<EmbeddingResult> @Optional
|
||||||
|
// The embedding result aggregated by timestamp, then by head. Must be
|
||||||
|
// connected if the TIMESTAMPS input is connected, as it signals that
|
||||||
|
// timestamp aggregation is required.
|
||||||
//
|
//
|
||||||
// The recommended way of using this graph is through the GraphBuilder API using
|
// The recommended way of using this graph is through the GraphBuilder API using
|
||||||
// the 'ConfigureEmbeddingPostprocessing()' function. See header file for more
|
// the 'ConfigureEmbeddingPostprocessing()' function. See header file for more
|
||||||
// details.
|
// details.
|
||||||
//
|
|
||||||
// TODO: add support for additional optional "TIMESTAMPS" input for
|
|
||||||
// embeddings aggregation.
|
|
||||||
class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
|
class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
|
||||||
public:
|
public:
|
||||||
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
|
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
|
||||||
mediapipe::SubgraphContext* sc) override {
|
mediapipe::SubgraphContext* sc) override {
|
||||||
Graph graph;
|
Graph graph;
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
auto embedding_result_out,
|
auto output_streams,
|
||||||
BuildEmbeddingPostprocessing(
|
BuildEmbeddingPostprocessing(
|
||||||
sc->Options<proto::EmbeddingPostprocessingGraphOptions>(),
|
sc->Options<proto::EmbeddingPostprocessingGraphOptions>(),
|
||||||
graph[Input<std::vector<Tensor>>(kTensorsTag)], graph));
|
graph[Input<std::vector<Tensor>>(kTensorsTag)],
|
||||||
embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
|
graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
|
||||||
|
output_streams.embeddings >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
|
||||||
|
output_streams.timestamped_embeddings >>
|
||||||
|
graph[Output<std::vector<EmbeddingResult>>(kTimestampedEmbeddingsTag)];
|
||||||
return graph.GetConfig();
|
return graph.GetConfig();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -200,10 +220,14 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
|
||||||
//
|
//
|
||||||
// options: the on-device EmbeddingPostprocessingGraphOptions
|
// options: the on-device EmbeddingPostprocessingGraphOptions
|
||||||
// tensors_in: (std::vector<mediapipe::Tensor>) tensors to postprocess.
|
// tensors_in: (std::vector<mediapipe::Tensor>) tensors to postprocess.
|
||||||
|
// timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
|
||||||
|
// timestamps that should be used to aggregate embedding results.
|
||||||
// graph: the mediapipe builder::Graph instance to be updated.
|
// graph: the mediapipe builder::Graph instance to be updated.
|
||||||
absl::StatusOr<Source<EmbeddingResult>> BuildEmbeddingPostprocessing(
|
absl::StatusOr<EmbeddingPostprocessingOutputStreams>
|
||||||
|
BuildEmbeddingPostprocessing(
|
||||||
const proto::EmbeddingPostprocessingGraphOptions options,
|
const proto::EmbeddingPostprocessingGraphOptions options,
|
||||||
Source<std::vector<Tensor>> tensors_in, Graph& graph) {
|
Source<std::vector<Tensor>> tensors_in,
|
||||||
|
Source<std::vector<Timestamp>> timestamps_in, Graph& graph) {
|
||||||
// If output tensors are quantized, they must be dequantized first.
|
// If output tensors are quantized, they must be dequantized first.
|
||||||
TensorsSource dequantized_tensors(&tensors_in);
|
TensorsSource dequantized_tensors(&tensors_in);
|
||||||
if (options.has_quantized_outputs()) {
|
if (options.has_quantized_outputs()) {
|
||||||
|
@ -220,7 +244,20 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
|
||||||
.GetOptions<mediapipe::TensorsToEmbeddingsCalculatorOptions>()
|
.GetOptions<mediapipe::TensorsToEmbeddingsCalculatorOptions>()
|
||||||
.CopyFrom(options.tensors_to_embeddings_options());
|
.CopyFrom(options.tensors_to_embeddings_options());
|
||||||
dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag);
|
dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag);
|
||||||
return tensors_to_embeddings_node[Output<EmbeddingResult>(kEmbeddingsTag)];
|
|
||||||
|
// Adds EmbeddingAggregationCalculator.
|
||||||
|
GenericNode& aggregation_node =
|
||||||
|
graph.AddNode("EmbeddingAggregationCalculator");
|
||||||
|
tensors_to_embeddings_node[Output<EmbeddingResult>(kEmbeddingsTag)] >>
|
||||||
|
aggregation_node.In(kEmbeddingsTag);
|
||||||
|
timestamps_in >> aggregation_node.In(kTimestampsTag);
|
||||||
|
|
||||||
|
// Connects outputs.
|
||||||
|
return EmbeddingPostprocessingOutputStreams{
|
||||||
|
/*embeddings=*/aggregation_node[Output<EmbeddingResult>(
|
||||||
|
kEmbeddingsTag)],
|
||||||
|
/*timestamped_embeddings=*/aggregation_node
|
||||||
|
[Output<std::vector<EmbeddingResult>>(kTimestampedEmbeddingsTag)]};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
REGISTER_MEDIAPIPE_GRAPH(
|
REGISTER_MEDIAPIPE_GRAPH(
|
||||||
|
|
|
@ -44,12 +44,20 @@ namespace processors {
|
||||||
// TENSORS - std::vector<Tensor>
|
// TENSORS - std::vector<Tensor>
|
||||||
// The output tensors of an InferenceCalculator, to convert into
|
// The output tensors of an InferenceCalculator, to convert into
|
||||||
// EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
|
// EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
|
||||||
|
// TIMESTAMPS - std::vector<Timestamp> @Optional
|
||||||
|
// The collection of the timestamps that this calculator should aggregate.
|
||||||
|
// This stream is optional: if provided then the TIMESTAMPED_EMBEDDINGS
|
||||||
|
// output is used for results. Otherwise as no timestamp aggregation is
|
||||||
|
// required the EMBEDDINGS output is used for results.
|
||||||
// Outputs:
|
// Outputs:
|
||||||
// EMBEDDINGS - EmbeddingResult
|
// EMBEDDINGS - EmbeddingResult @Optional
|
||||||
// The output EmbeddingResult.
|
// The embedding results aggregated by head. Must be connected if the
|
||||||
//
|
// TIMESTAMPS input is not connected, as it signals that timestamp
|
||||||
// TODO: add support for additional optional "TIMESTAMPS" input for
|
// aggregation is not required.
|
||||||
// embeddings aggregation.
|
// TIMESTAMPED_EMBEDDINGS - std::vector<EmbeddingResult> @Optional
|
||||||
|
// The embedding result aggregated by timestamp, then by head. Must be
|
||||||
|
// connected if the TIMESTAMPS input is connected, as it signals that
|
||||||
|
// timestamp aggregation is required.
|
||||||
absl::Status ConfigureEmbeddingPostprocessing(
|
absl::Status ConfigureEmbeddingPostprocessing(
|
||||||
const tasks::core::ModelResources& model_resources,
|
const tasks::core::ModelResources& model_resources,
|
||||||
const proto::EmbedderOptions& embedder_options,
|
const proto::EmbedderOptions& embedder_options,
|
||||||
|
|
|
@ -20,11 +20,20 @@ limitations under the License.
|
||||||
#include "absl/flags/flag.h"
|
#include "absl/flags/flag.h"
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
#include "mediapipe/framework/deps/file_path.h"
|
#include "mediapipe/framework/deps/file_path.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
#include "mediapipe/framework/graph_runner.h"
|
||||||
|
#include "mediapipe/framework/output_stream_poller.h"
|
||||||
#include "mediapipe/framework/port/gmock.h"
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
#include "mediapipe/framework/timestamp.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
|
@ -37,7 +46,12 @@ namespace components {
|
||||||
namespace processors {
|
namespace processors {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::api2::Input;
|
||||||
|
using ::mediapipe::api2::Output;
|
||||||
|
using ::mediapipe::api2::builder::Graph;
|
||||||
|
using ::mediapipe::api2::builder::Source;
|
||||||
using ::mediapipe::file::JoinPath;
|
using ::mediapipe::file::JoinPath;
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||||
using ::mediapipe::tasks::core::ModelResources;
|
using ::mediapipe::tasks::core::ModelResources;
|
||||||
|
|
||||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
|
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
|
||||||
|
@ -51,6 +65,16 @@ constexpr char kQuantizedImageClassifierWithoutMetadata[] =
|
||||||
"vision/mobilenet_v1_0.25_192_quantized_1_default_1.tflite";
|
"vision/mobilenet_v1_0.25_192_quantized_1_default_1.tflite";
|
||||||
|
|
||||||
constexpr char kTestModelResourcesTag[] = "test_model_resources";
|
constexpr char kTestModelResourcesTag[] = "test_model_resources";
|
||||||
|
constexpr int kMobileNetV3EmbedderEmbeddingSize = 1024;
|
||||||
|
|
||||||
|
constexpr char kTensorsTag[] = "TENSORS";
|
||||||
|
constexpr char kTensorsName[] = "tensors";
|
||||||
|
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
||||||
|
constexpr char kTimestampsName[] = "timestamps";
|
||||||
|
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
|
||||||
|
constexpr char kEmbeddingsName[] = "embeddings";
|
||||||
|
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
|
||||||
|
constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings";
|
||||||
|
|
||||||
// Helper function to get ModelResources.
|
// Helper function to get ModelResources.
|
||||||
absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
|
absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
|
||||||
|
@ -128,8 +152,171 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
|
||||||
has_quantized_outputs: false)pb")));
|
has_quantized_outputs: false)pb")));
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: add E2E Postprocessing tests once timestamp aggregation is
|
class PostprocessingTest : public tflite_shims::testing::Test {
|
||||||
// supported.
|
protected:
|
||||||
|
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
||||||
|
absl::string_view model_name, const proto::EmbedderOptions& options,
|
||||||
|
bool connect_timestamps = false) {
|
||||||
|
ASSIGN_OR_RETURN(auto model_resources,
|
||||||
|
CreateModelResourcesForModel(model_name));
|
||||||
|
|
||||||
|
Graph graph;
|
||||||
|
auto& postprocessing = graph.AddNode(
|
||||||
|
"mediapipe.tasks.components.processors."
|
||||||
|
"EmbeddingPostprocessingGraph");
|
||||||
|
MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessing(
|
||||||
|
*model_resources, options,
|
||||||
|
&postprocessing
|
||||||
|
.GetOptions<proto::EmbeddingPostprocessingGraphOptions>()));
|
||||||
|
graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
|
||||||
|
postprocessing.In(kTensorsTag);
|
||||||
|
if (connect_timestamps) {
|
||||||
|
graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
|
||||||
|
kTimestampsName) >>
|
||||||
|
postprocessing.In(kTimestampsTag);
|
||||||
|
postprocessing.Out(kTimestampedEmbeddingsTag)
|
||||||
|
.SetName(kTimestampedEmbeddingsName) >>
|
||||||
|
graph[Output<std::vector<EmbeddingResult>>(
|
||||||
|
kTimestampedEmbeddingsTag)];
|
||||||
|
} else {
|
||||||
|
postprocessing.Out(kEmbeddingsTag).SetName(kEmbeddingsName) >>
|
||||||
|
graph[Output<EmbeddingResult>(kEmbeddingsTag)];
|
||||||
|
}
|
||||||
|
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
|
||||||
|
if (connect_timestamps) {
|
||||||
|
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
|
||||||
|
kTimestampedEmbeddingsName));
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
|
||||||
|
return poller;
|
||||||
|
}
|
||||||
|
ASSIGN_OR_RETURN(auto poller,
|
||||||
|
calculator_graph_.AddOutputStreamPoller(kEmbeddingsName));
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
|
||||||
|
return poller;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void AddTensor(
|
||||||
|
const std::vector<T>& tensor, const Tensor::ElementType& element_type,
|
||||||
|
const Tensor::QuantizationParameters& quantization_parameters = {}) {
|
||||||
|
tensors_->emplace_back(element_type,
|
||||||
|
Tensor::Shape{1, static_cast<int>(tensor.size())},
|
||||||
|
quantization_parameters);
|
||||||
|
auto view = tensors_->back().GetCpuWriteView();
|
||||||
|
T* buffer = view.buffer<T>();
|
||||||
|
std::copy(tensor.begin(), tensor.end(), buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Run(
|
||||||
|
std::optional<std::vector<int>> aggregation_timestamps = std::nullopt,
|
||||||
|
int timestamp = 0) {
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
||||||
|
kTensorsName, Adopt(tensors_.release()).At(Timestamp(timestamp))));
|
||||||
|
// Reset tensors for future calls.
|
||||||
|
tensors_ = absl::make_unique<std::vector<Tensor>>();
|
||||||
|
if (aggregation_timestamps.has_value()) {
|
||||||
|
auto packet = absl::make_unique<std::vector<Timestamp>>();
|
||||||
|
for (const auto& timestamp : *aggregation_timestamps) {
|
||||||
|
packet->emplace_back(Timestamp(timestamp));
|
||||||
|
}
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
||||||
|
kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
|
||||||
|
|
||||||
|
Packet packet;
|
||||||
|
if (!poller.Next(&packet)) {
|
||||||
|
return absl::InternalError("Unable to get output packet");
|
||||||
|
}
|
||||||
|
auto result = packet.Get<T>();
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
CalculatorGraph calculator_graph_;
|
||||||
|
std::unique_ptr<std::vector<Tensor>> tensors_ =
|
||||||
|
absl::make_unique<std::vector<Tensor>>();
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(PostprocessingTest, SucceedsWithoutTimestamps) {
|
||||||
|
// Build graph.
|
||||||
|
proto::EmbedderOptions options;
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto poller,
|
||||||
|
BuildGraph(kMobileNetV3Embedder, options));
|
||||||
|
// Build input tensor.
|
||||||
|
std::vector<float> tensor(kMobileNetV3EmbedderEmbeddingSize, 0);
|
||||||
|
tensor[0] = 1.0;
|
||||||
|
|
||||||
|
// Send tensor and get results.
|
||||||
|
AddTensor(tensor, Tensor::ElementType::kFloat32);
|
||||||
|
MP_ASSERT_OK(Run());
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto results, GetResult<EmbeddingResult>(poller));
|
||||||
|
|
||||||
|
// Validate results.
|
||||||
|
EXPECT_FALSE(results.has_timestamp_ms());
|
||||||
|
EXPECT_EQ(results.embeddings_size(), 1);
|
||||||
|
EXPECT_EQ(results.embeddings(0).head_index(), 0);
|
||||||
|
EXPECT_EQ(results.embeddings(0).head_name(), "feature");
|
||||||
|
EXPECT_EQ(results.embeddings(0).float_embedding().values_size(),
|
||||||
|
kMobileNetV3EmbedderEmbeddingSize);
|
||||||
|
EXPECT_FLOAT_EQ(results.embeddings(0).float_embedding().values(0), 1.0);
|
||||||
|
for (int i = 1; i < kMobileNetV3EmbedderEmbeddingSize; ++i) {
|
||||||
|
EXPECT_FLOAT_EQ(results.embeddings(0).float_embedding().values(i), 0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
|
||||||
|
// Build graph.
|
||||||
|
proto::EmbedderOptions options;
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(kMobileNetV3Embedder, options,
|
||||||
|
/*connect_timestamps=*/true));
|
||||||
|
// Build input tensors.
|
||||||
|
std::vector<float> tensor_0(kMobileNetV3EmbedderEmbeddingSize, 0);
|
||||||
|
tensor_0[0] = 1.0;
|
||||||
|
std::vector<float> tensor_1(kMobileNetV3EmbedderEmbeddingSize, 0);
|
||||||
|
tensor_1[0] = 2.0;
|
||||||
|
|
||||||
|
// Send tensors and get results.
|
||||||
|
AddTensor(tensor_0, Tensor::ElementType::kFloat32);
|
||||||
|
MP_ASSERT_OK(Run());
|
||||||
|
AddTensor(tensor_1, Tensor::ElementType::kFloat32);
|
||||||
|
MP_ASSERT_OK(Run(
|
||||||
|
/*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000}),
|
||||||
|
/*timestamp=*/1000));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto results,
|
||||||
|
GetResult<std::vector<EmbeddingResult>>(poller));
|
||||||
|
|
||||||
|
// Validate results.
|
||||||
|
EXPECT_EQ(results.size(), 2);
|
||||||
|
// First timestamp.
|
||||||
|
EXPECT_EQ(results[0].timestamp_ms(), 0);
|
||||||
|
EXPECT_EQ(results[0].embeddings(0).head_index(), 0);
|
||||||
|
EXPECT_EQ(results[0].embeddings(0).head_name(), "feature");
|
||||||
|
EXPECT_EQ(results[0].embeddings(0).float_embedding().values_size(),
|
||||||
|
kMobileNetV3EmbedderEmbeddingSize);
|
||||||
|
EXPECT_FLOAT_EQ(results[0].embeddings(0).float_embedding().values(0), 1.0);
|
||||||
|
for (int i = 1; i < kMobileNetV3EmbedderEmbeddingSize; ++i) {
|
||||||
|
EXPECT_FLOAT_EQ(results[0].embeddings(0).float_embedding().values(i), 0.0);
|
||||||
|
}
|
||||||
|
// Second timestamp.
|
||||||
|
EXPECT_EQ(results[1].timestamp_ms(), 1);
|
||||||
|
EXPECT_EQ(results[1].embeddings(0).head_index(), 0);
|
||||||
|
EXPECT_EQ(results[1].embeddings(0).head_name(), "feature");
|
||||||
|
EXPECT_EQ(results[1].embeddings(0).float_embedding().values_size(),
|
||||||
|
kMobileNetV3EmbedderEmbeddingSize);
|
||||||
|
EXPECT_FLOAT_EQ(results[1].embeddings(0).float_embedding().values(0), 2.0);
|
||||||
|
for (int i = 1; i < kMobileNetV3EmbedderEmbeddingSize; ++i) {
|
||||||
|
EXPECT_FLOAT_EQ(results[1].embeddings(0).float_embedding().values(i), 0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace processors
|
} // namespace processors
|
||||||
|
|
|
@ -32,7 +32,4 @@ message EmbeddingPostprocessingGraphOptions {
|
||||||
|
|
||||||
// Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32).
|
// Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32).
|
||||||
optional bool has_quantized_outputs = 2;
|
optional bool has_quantized_outputs = 2;
|
||||||
|
|
||||||
// TODO: add options to control whether timestamp aggregation
|
|
||||||
// should be used or not.
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user