No public description
PiperOrigin-RevId: 569274219
This commit is contained in:
parent
33d6143a1a
commit
e169849041
|
@ -1616,3 +1616,34 @@ cc_test(
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "pass_through_or_empty_detection_vector_calculator",
|
||||||
|
srcs = ["pass_through_or_empty_detection_vector_calculator.cc"],
|
||||||
|
hdrs = ["pass_through_or_empty_detection_vector_calculator.h"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_context",
|
||||||
|
"//mediapipe/framework/api2:builder",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/api2:packet",
|
||||||
|
"//mediapipe/framework/api2:port",
|
||||||
|
"//mediapipe/framework/formats:detection_cc_proto",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "pass_through_or_empty_detection_vector_calculator_test",
|
||||||
|
srcs = ["pass_through_or_empty_detection_vector_calculator_test.cc"],
|
||||||
|
tags = ["desktop_only_test"],
|
||||||
|
deps = [
|
||||||
|
":pass_through_or_empty_detection_vector_calculator",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:builder",
|
||||||
|
"//mediapipe/framework/formats:detection_cc_proto",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:status_matchers",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
#include "mediapipe/calculators/util/pass_through_or_empty_detection_vector_calculator.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/calculator_context.h"
|
||||||
|
#include "mediapipe/framework/formats/detection.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
class PassThroughOrEmptyDetectionVectorCalculatorImpl
|
||||||
|
: public mediapipe::api2::NodeImpl<
|
||||||
|
PassThroughOrEmptyDetectionVectorCalculator> {
|
||||||
|
public:
|
||||||
|
absl::Status Process(CalculatorContext* cc) override {
|
||||||
|
if (kInputVector(cc).IsEmpty()) {
|
||||||
|
kOutputVector(cc).Send(std::vector<mediapipe::Detection>{});
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
kOutputVector(cc).Send(kInputVector(cc));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
MEDIAPIPE_NODE_IMPLEMENTATION(PassThroughOrEmptyDetectionVectorCalculatorImpl);
|
||||||
|
|
||||||
|
} // namespace mediapipe
|
|
@ -0,0 +1,55 @@
|
||||||
|
#ifndef MEDIAPIPE_CALCULATORS_UTILS_PASS_THROUGH_OR_EMPTY_DETECTION_VECTOR_CALCULATOR_H_
|
||||||
|
#define MEDIAPIPE_CALCULATORS_UTILS_PASS_THROUGH_OR_EMPTY_DETECTION_VECTOR_CALCULATOR_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/api2/packet.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/formats/detection.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
// Calculator to pass through input vector of detections if packet is not empty,
|
||||||
|
// otherwise - outputing a new empty vector. So, instead of empty packet you get
|
||||||
|
// a packet containing empty vector.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// node {
|
||||||
|
// calculator: "PassThroughOrEmptyDetectionVectorCalculator"
|
||||||
|
// input_stream: "TICK:tick"
|
||||||
|
// input_stream: "VECTOR:input_detections"
|
||||||
|
// output_stream: "VECTOR:output_detections"
|
||||||
|
// }
|
||||||
|
class PassThroughOrEmptyDetectionVectorCalculator
|
||||||
|
: public mediapipe::api2::NodeIntf {
|
||||||
|
public:
|
||||||
|
static constexpr mediapipe::api2::Input<std::vector<mediapipe::Detection>>
|
||||||
|
kInputVector{"VECTOR"};
|
||||||
|
static constexpr mediapipe::api2::Input<mediapipe::api2::AnyType> kTick{
|
||||||
|
"TICK"};
|
||||||
|
static constexpr mediapipe::api2::Output<std::vector<mediapipe::Detection>>
|
||||||
|
kOutputVector{"VECTOR"};
|
||||||
|
|
||||||
|
MEDIAPIPE_NODE_INTERFACE(
|
||||||
|
::mediapipe::PassThroughOrEmptyDetectionVectorCalculator, kInputVector,
|
||||||
|
kTick, kOutputVector);
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename TickT>
|
||||||
|
api2::builder::Stream<std::vector<mediapipe::Detection>>
|
||||||
|
PassThroughOrEmptyDetectionVector(
|
||||||
|
api2::builder::Stream<std::vector<mediapipe::Detection>> detections,
|
||||||
|
api2::builder::Stream<TickT> tick, mediapipe::api2::builder::Graph& graph) {
|
||||||
|
auto& node =
|
||||||
|
graph.AddNode("mediapipe.PassThroughOrEmptyDetectionVectorCalculator");
|
||||||
|
detections.ConnectTo(
|
||||||
|
node[PassThroughOrEmptyDetectionVectorCalculator::kInputVector]);
|
||||||
|
tick.ConnectTo(node[PassThroughOrEmptyDetectionVectorCalculator::kTick]);
|
||||||
|
return node[PassThroughOrEmptyDetectionVectorCalculator::kOutputVector];
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_CALCULATORS_UTILS_PASS_THROUGH_OR_EMPTY_DETECTION_VECTOR_CALCULATOR_H_
|
|
@ -0,0 +1,113 @@
|
||||||
|
#include "mediapipe/calculators/util/pass_through_or_empty_detection_vector_calculator.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/detection.pb.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
CalculatorGraphConfig GetGraphConfig() {
|
||||||
|
mediapipe::api2::builder::Graph graph;
|
||||||
|
mediapipe::api2::builder::Stream<std::string> tick =
|
||||||
|
graph.In("TICK").SetName("tick").Cast<std::string>();
|
||||||
|
mediapipe::api2::builder::Stream<std::vector<mediapipe::Detection>>
|
||||||
|
detections = graph.In("DETECTIONS")
|
||||||
|
.SetName("input_detections")
|
||||||
|
.Cast<std::vector<mediapipe::Detection>>();
|
||||||
|
|
||||||
|
mediapipe::api2::builder::Stream<std::vector<mediapipe::Detection>>
|
||||||
|
output_detections =
|
||||||
|
PassThroughOrEmptyDetectionVector(detections, tick, graph);
|
||||||
|
output_detections.SetName("output_detections");
|
||||||
|
|
||||||
|
return graph.GetConfig();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status SendTick(CalculatorGraph& graph, int at) {
|
||||||
|
return graph.AddPacketToInputStream(
|
||||||
|
"tick",
|
||||||
|
mediapipe::MakePacket<std::string>("tick").At(mediapipe::Timestamp(at)));
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status SendDetections(CalculatorGraph& graph,
|
||||||
|
std::vector<mediapipe::Detection> detections,
|
||||||
|
int at) {
|
||||||
|
return graph.AddPacketToInputStream(
|
||||||
|
"input_detections",
|
||||||
|
mediapipe::MakePacket<std::vector<mediapipe::Detection>>(
|
||||||
|
std::move(detections))
|
||||||
|
.At(mediapipe::Timestamp(at)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(PassThroughOrEmptyDetectionVectorCalculatorTest, PassThrough) {
|
||||||
|
CalculatorGraphConfig graph_config = GetGraphConfig();
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output_detections", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph calculator_graph(graph_config);
|
||||||
|
MP_ASSERT_OK(calculator_graph.StartRun({}));
|
||||||
|
|
||||||
|
// Sending empty vector.
|
||||||
|
MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/1));
|
||||||
|
MP_ASSERT_OK(SendDetections(calculator_graph,
|
||||||
|
std::vector<mediapipe::Detection>{},
|
||||||
|
/*at=*/1));
|
||||||
|
MP_ASSERT_OK(calculator_graph.WaitUntilIdle());
|
||||||
|
|
||||||
|
ASSERT_EQ(output_packets.size(), 1);
|
||||||
|
EXPECT_TRUE(
|
||||||
|
output_packets[0].Get<std::vector<mediapipe::Detection>>().empty());
|
||||||
|
|
||||||
|
// Sending non empty vector.
|
||||||
|
output_packets.clear();
|
||||||
|
mediapipe::Detection detection;
|
||||||
|
detection.set_detection_id(1000);
|
||||||
|
|
||||||
|
MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/2));
|
||||||
|
MP_ASSERT_OK(SendDetections(calculator_graph, {detection}, /*at=*/2));
|
||||||
|
MP_ASSERT_OK(calculator_graph.WaitUntilIdle());
|
||||||
|
|
||||||
|
ASSERT_EQ(output_packets.size(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(PassThroughOrEmptyDetectionVectorCalculatorTest, OrEmptyVector) {
|
||||||
|
CalculatorGraphConfig graph_config = GetGraphConfig();
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output_detections", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph calculator_graph(graph_config);
|
||||||
|
MP_ASSERT_OK(calculator_graph.StartRun({}));
|
||||||
|
|
||||||
|
mediapipe::Detection detection;
|
||||||
|
detection.set_detection_id(1000);
|
||||||
|
MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/1));
|
||||||
|
MP_ASSERT_OK(SendDetections(calculator_graph, {detection}, /*at=*/1));
|
||||||
|
MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/2));
|
||||||
|
MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/3));
|
||||||
|
MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/4));
|
||||||
|
// This should trigger trigger calculator at 2, 3, 4 as detections are not
|
||||||
|
// expected.
|
||||||
|
MP_ASSERT_OK(SendDetections(calculator_graph,
|
||||||
|
std::vector<mediapipe::Detection>{},
|
||||||
|
/*at=*/5));
|
||||||
|
MP_ASSERT_OK(calculator_graph.WaitUntilIdle());
|
||||||
|
|
||||||
|
ASSERT_EQ(output_packets.size(), 4);
|
||||||
|
|
||||||
|
for (int i = 1; i < output_packets.size(); ++i) {
|
||||||
|
EXPECT_TRUE(
|
||||||
|
output_packets[i].Get<std::vector<mediapipe::Detection>>().empty());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe
|
Loading…
Reference in New Issue
Block a user