diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 4186cbea2..ac69d969f 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -1616,3 +1616,34 @@ cc_test( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "pass_through_or_empty_detection_vector_calculator", + srcs = ["pass_through_or_empty_detection_vector_calculator.cc"], + hdrs = ["pass_through_or_empty_detection_vector_calculator.h"], + deps = [ + "//mediapipe/framework:calculator_context", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:packet", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:detection_cc_proto", + "@com_google_absl//absl/status", + ], + alwayslink = 1, +) + +cc_test( + name = "pass_through_or_empty_detection_vector_calculator_test", + srcs = ["pass_through_or_empty_detection_vector_calculator_test.cc"], + tags = ["desktop_only_test"], + deps = [ + ":pass_through_or_empty_detection_vector_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:status_matchers", + "@com_google_absl//absl/status", + ], +) diff --git a/mediapipe/calculators/util/pass_through_or_empty_detection_vector_calculator.cc b/mediapipe/calculators/util/pass_through_or_empty_detection_vector_calculator.cc new file mode 100644 index 000000000..811c5ae7e --- /dev/null +++ b/mediapipe/calculators/util/pass_through_or_empty_detection_vector_calculator.cc @@ -0,0 +1,27 @@ +#include "mediapipe/calculators/util/pass_through_or_empty_detection_vector_calculator.h" + +#include + +#include "absl/status/status.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_context.h" +#include "mediapipe/framework/formats/detection.pb.h" + +namespace mediapipe { + +class PassThroughOrEmptyDetectionVectorCalculatorImpl + : public mediapipe::api2::NodeImpl< + PassThroughOrEmptyDetectionVectorCalculator> { + public: + absl::Status Process(CalculatorContext* cc) override { + if (kInputVector(cc).IsEmpty()) { + kOutputVector(cc).Send(std::vector{}); + return absl::OkStatus(); + } + kOutputVector(cc).Send(kInputVector(cc)); + return absl::OkStatus(); + } +}; +MEDIAPIPE_NODE_IMPLEMENTATION(PassThroughOrEmptyDetectionVectorCalculatorImpl); + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/pass_through_or_empty_detection_vector_calculator.h b/mediapipe/calculators/util/pass_through_or_empty_detection_vector_calculator.h new file mode 100644 index 000000000..ef8c5c5ca --- /dev/null +++ b/mediapipe/calculators/util/pass_through_or_empty_detection_vector_calculator.h @@ -0,0 +1,55 @@ +#ifndef MEDIAPIPE_CALCULATORS_UTILS_PASS_THROUGH_OR_EMPTY_DETECTION_VECTOR_CALCULATOR_H_ +#define MEDIAPIPE_CALCULATORS_UTILS_PASS_THROUGH_OR_EMPTY_DETECTION_VECTOR_CALCULATOR_H_ + +#include + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/packet.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/detection.pb.h" + +namespace mediapipe { + +// Calculator to pass through input vector of detections if packet is not empty, +// otherwise - outputing a new empty vector. So, instead of empty packet you get +// a packet containing empty vector. +// +// Example: +// node { +// calculator: "PassThroughOrEmptyDetectionVectorCalculator" +// input_stream: "TICK:tick" +// input_stream: "VECTOR:input_detections" +// output_stream: "VECTOR:output_detections" +// } +class PassThroughOrEmptyDetectionVectorCalculator + : public mediapipe::api2::NodeIntf { + public: + static constexpr mediapipe::api2::Input> + kInputVector{"VECTOR"}; + static constexpr mediapipe::api2::Input kTick{ + "TICK"}; + static constexpr mediapipe::api2::Output> + kOutputVector{"VECTOR"}; + + MEDIAPIPE_NODE_INTERFACE( + ::mediapipe::PassThroughOrEmptyDetectionVectorCalculator, kInputVector, + kTick, kOutputVector); +}; + +template +api2::builder::Stream> +PassThroughOrEmptyDetectionVector( + api2::builder::Stream> detections, + api2::builder::Stream tick, mediapipe::api2::builder::Graph& graph) { + auto& node = + graph.AddNode("mediapipe.PassThroughOrEmptyDetectionVectorCalculator"); + detections.ConnectTo( + node[PassThroughOrEmptyDetectionVectorCalculator::kInputVector]); + tick.ConnectTo(node[PassThroughOrEmptyDetectionVectorCalculator::kTick]); + return node[PassThroughOrEmptyDetectionVectorCalculator::kOutputVector]; +} + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_UTILS_PASS_THROUGH_OR_EMPTY_DETECTION_VECTOR_CALCULATOR_H_ diff --git a/mediapipe/calculators/util/pass_through_or_empty_detection_vector_calculator_test.cc b/mediapipe/calculators/util/pass_through_or_empty_detection_vector_calculator_test.cc new file mode 100644 index 000000000..a586f509b --- /dev/null +++ b/mediapipe/calculators/util/pass_through_or_empty_detection_vector_calculator_test.cc @@ -0,0 +1,113 @@ +#include "mediapipe/calculators/util/pass_through_or_empty_detection_vector_calculator.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +CalculatorGraphConfig GetGraphConfig() { + mediapipe::api2::builder::Graph graph; + mediapipe::api2::builder::Stream tick = + graph.In("TICK").SetName("tick").Cast(); + mediapipe::api2::builder::Stream> + detections = graph.In("DETECTIONS") + .SetName("input_detections") + .Cast>(); + + mediapipe::api2::builder::Stream> + output_detections = + PassThroughOrEmptyDetectionVector(detections, tick, graph); + output_detections.SetName("output_detections"); + + return graph.GetConfig(); +} + +absl::Status SendTick(CalculatorGraph& graph, int at) { + return graph.AddPacketToInputStream( + "tick", + mediapipe::MakePacket("tick").At(mediapipe::Timestamp(at))); +} + +absl::Status SendDetections(CalculatorGraph& graph, + std::vector detections, + int at) { + return graph.AddPacketToInputStream( + "input_detections", + mediapipe::MakePacket>( + std::move(detections)) + .At(mediapipe::Timestamp(at))); +} + +TEST(PassThroughOrEmptyDetectionVectorCalculatorTest, PassThrough) { + CalculatorGraphConfig graph_config = GetGraphConfig(); + std::vector output_packets; + tool::AddVectorSink("output_detections", &graph_config, &output_packets); + + CalculatorGraph calculator_graph(graph_config); + MP_ASSERT_OK(calculator_graph.StartRun({})); + + // Sending empty vector. + MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/1)); + MP_ASSERT_OK(SendDetections(calculator_graph, + std::vector{}, + /*at=*/1)); + MP_ASSERT_OK(calculator_graph.WaitUntilIdle()); + + ASSERT_EQ(output_packets.size(), 1); + EXPECT_TRUE( + output_packets[0].Get>().empty()); + + // Sending non empty vector. + output_packets.clear(); + mediapipe::Detection detection; + detection.set_detection_id(1000); + + MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/2)); + MP_ASSERT_OK(SendDetections(calculator_graph, {detection}, /*at=*/2)); + MP_ASSERT_OK(calculator_graph.WaitUntilIdle()); + + ASSERT_EQ(output_packets.size(), 1); +} + +TEST(PassThroughOrEmptyDetectionVectorCalculatorTest, OrEmptyVector) { + CalculatorGraphConfig graph_config = GetGraphConfig(); + std::vector output_packets; + tool::AddVectorSink("output_detections", &graph_config, &output_packets); + + CalculatorGraph calculator_graph(graph_config); + MP_ASSERT_OK(calculator_graph.StartRun({})); + + mediapipe::Detection detection; + detection.set_detection_id(1000); + MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/1)); + MP_ASSERT_OK(SendDetections(calculator_graph, {detection}, /*at=*/1)); + MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/2)); + MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/3)); + MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/4)); + // This should trigger trigger calculator at 2, 3, 4 as detections are not + // expected. + MP_ASSERT_OK(SendDetections(calculator_graph, + std::vector{}, + /*at=*/5)); + MP_ASSERT_OK(calculator_graph.WaitUntilIdle()); + + ASSERT_EQ(output_packets.size(), 4); + + for (int i = 1; i < output_packets.size(); ++i) { + EXPECT_TRUE( + output_packets[i].Get>().empty()); + } +} + +} // namespace +} // namespace mediapipe