Move stream API loopback to third_party.
PiperOrigin-RevId: 559037020
This commit is contained in:
parent
9bc8b3bb4f
commit
edb0a64d0e
14
mediapipe/framework/api2/stream/BUILD
Normal file
14
mediapipe/framework/api2/stream/BUILD
Normal file
|
@ -0,0 +1,14 @@
|
|||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
cc_library(
|
||||
name = "loopback",
|
||||
hdrs = ["loopback.h"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:previous_loopback_calculator",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:port",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
55
mediapipe/framework/api2/stream/loopback.h
Normal file
55
mediapipe/framework/api2/stream/loopback.h
Normal file
|
@ -0,0 +1,55 @@
|
|||
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_LOOPBACK_H_
|
||||
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_LOOPBACK_H_
|
||||
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
|
||||
namespace mediapipe::api2::builder {
|
||||
|
||||
// Returns a pair of two values:
|
||||
// - A stream with loopback data. Such stream, for each new packet in @tick
|
||||
// stream, provides a packet previously calculated within the graph.
|
||||
// - A function to define/set loopback data producing stream.
|
||||
// NOTE:
|
||||
// * function must be called and only once, otherwise graph validation will
|
||||
// fail.
|
||||
// * calling function after graph is destroyed results in undefined behavior
|
||||
//
|
||||
// The function wraps `PreviousLoopbackCalculator` into a convenience function
|
||||
// and allows graph input to be processed together with some previous output.
|
||||
//
|
||||
// -------
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ```
|
||||
//
|
||||
// Graph graph;
|
||||
// Stream<...> tick = ...; // E.g. main input can surve as a tick.
|
||||
// auto [prev_data, set_loopback_fn] = GetLoopbackData<int>(tick, graph);
|
||||
// ...
|
||||
// Stream<int> data = ...;
|
||||
// set_loopback_fn(data);
|
||||
//
|
||||
// ```
|
||||
template <class DataT, class TickT>
|
||||
std::pair<Stream<DataT>, std::function<void(Stream<DataT>)>> GetLoopbackData(
|
||||
Stream<TickT> tick, mediapipe::api2::builder::Graph& graph) {
|
||||
auto& prev = graph.AddNode("PreviousLoopbackCalculator");
|
||||
tick.ConnectTo(prev.In("MAIN"));
|
||||
return {prev.Out("PREV_LOOP").template Cast<DataT>(),
|
||||
[prev_ptr = &prev](Stream<DataT> data) {
|
||||
// TODO: input stream info must be specified, but
|
||||
// builder api doesn't support it at the moment. As a workaround,
|
||||
// input stream info is added by GraphBuilder as a graph building
|
||||
// post processing step.
|
||||
data.ConnectTo(prev_ptr->In("LOOP"));
|
||||
}};
|
||||
}
|
||||
|
||||
} // namespace mediapipe::api2::builder
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_LOOPBACK_H_
|
55
mediapipe/framework/api2/stream/loopback_test.cc
Normal file
55
mediapipe/framework/api2/stream/loopback_test.cc
Normal file
|
@ -0,0 +1,55 @@
|
|||
#include "mediapipe/framework/api2/stream/loopback.h"
|
||||
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
|
||||
namespace mediapipe::api2::builder {
|
||||
namespace {
|
||||
|
||||
class TestDataProducer : public NodeIntf {
|
||||
public:
|
||||
static constexpr Input<float> kLoopbackData{"LOOPBACK_DATA"};
|
||||
static constexpr Output<float> kProducedData{"PRODUCED_DATA"};
|
||||
MEDIAPIPE_NODE_INTERFACE(TestDataProducer, kLoopbackData, kProducedData);
|
||||
};
|
||||
|
||||
TEST(LoopbackTest, GetLoopbackData) {
|
||||
Graph graph;
|
||||
|
||||
Stream<int> tick = graph.In("TICK").Cast<int>();
|
||||
|
||||
auto [data, set_loopback_data_fn] = GetLoopbackData<float>(tick, graph);
|
||||
|
||||
auto& producer = graph.AddNode<TestDataProducer>();
|
||||
data.ConnectTo(producer[TestDataProducer::kLoopbackData]);
|
||||
Stream<float> data_to_loopback(producer[TestDataProducer::kProducedData]);
|
||||
|
||||
set_loopback_data_fn(data_to_loopback);
|
||||
|
||||
// PreviousLoopbackCalculator configuration is incorrect here and should be
|
||||
// updated when corresponding b/175887687 is fixed.
|
||||
// Use mediapipe::aimatter::GraphBuilder to fix back edges in the graph.
|
||||
EXPECT_THAT(graph.GetConfig(),
|
||||
testing::EqualsProto(
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "PreviousLoopbackCalculator"
|
||||
input_stream: "LOOP:__stream_2"
|
||||
input_stream: "MAIN:__stream_0"
|
||||
output_stream: "PREV_LOOP:__stream_1"
|
||||
}
|
||||
node {
|
||||
calculator: "TestDataProducer"
|
||||
input_stream: "LOOPBACK_DATA:__stream_1"
|
||||
output_stream: "PRODUCED_DATA:__stream_2"
|
||||
}
|
||||
input_stream: "TICK:__stream_0"
|
||||
)pb")));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe::api2::builder
|
Loading…
Reference in New Issue
Block a user