No public description

PiperOrigin-RevId: 569274219
This commit is contained in:
MediaPipe Team 2023-09-28 13:28:46 -07:00 committed by Copybara-Service
parent 33d6143a1a
commit e169849041
4 changed files with 226 additions and 0 deletions

View File

@ -1616,3 +1616,34 @@ cc_test(
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )
cc_library(
name = "pass_through_or_empty_detection_vector_calculator",
srcs = ["pass_through_or_empty_detection_vector_calculator.cc"],
hdrs = ["pass_through_or_empty_detection_vector_calculator.h"],
deps = [
"//mediapipe/framework:calculator_context",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:detection_cc_proto",
"@com_google_absl//absl/status",
],
alwayslink = 1,
)
cc_test(
name = "pass_through_or_empty_detection_vector_calculator_test",
srcs = ["pass_through_or_empty_detection_vector_calculator_test.cc"],
tags = ["desktop_only_test"],
deps = [
":pass_through_or_empty_detection_vector_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status_matchers",
"@com_google_absl//absl/status",
],
)

View File

@ -0,0 +1,27 @@
#include "mediapipe/calculators/util/pass_through_or_empty_detection_vector_calculator.h"
#include <vector>
#include "absl/status/status.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/formats/detection.pb.h"
namespace mediapipe {
class PassThroughOrEmptyDetectionVectorCalculatorImpl
: public mediapipe::api2::NodeImpl<
PassThroughOrEmptyDetectionVectorCalculator> {
public:
absl::Status Process(CalculatorContext* cc) override {
if (kInputVector(cc).IsEmpty()) {
kOutputVector(cc).Send(std::vector<mediapipe::Detection>{});
return absl::OkStatus();
}
kOutputVector(cc).Send(kInputVector(cc));
return absl::OkStatus();
}
};
MEDIAPIPE_NODE_IMPLEMENTATION(PassThroughOrEmptyDetectionVectorCalculatorImpl);
} // namespace mediapipe

View File

@ -0,0 +1,55 @@
#ifndef MEDIAPIPE_CALCULATORS_UTILS_PASS_THROUGH_OR_EMPTY_DETECTION_VECTOR_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_UTILS_PASS_THROUGH_OR_EMPTY_DETECTION_VECTOR_CALCULATOR_H_
#include <vector>
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/detection.pb.h"
namespace mediapipe {
// Calculator to pass through input vector of detections if packet is not empty,
// otherwise - outputing a new empty vector. So, instead of empty packet you get
// a packet containing empty vector.
//
// Example:
// node {
// calculator: "PassThroughOrEmptyDetectionVectorCalculator"
// input_stream: "TICK:tick"
// input_stream: "VECTOR:input_detections"
// output_stream: "VECTOR:output_detections"
// }
class PassThroughOrEmptyDetectionVectorCalculator
: public mediapipe::api2::NodeIntf {
public:
static constexpr mediapipe::api2::Input<std::vector<mediapipe::Detection>>
kInputVector{"VECTOR"};
static constexpr mediapipe::api2::Input<mediapipe::api2::AnyType> kTick{
"TICK"};
static constexpr mediapipe::api2::Output<std::vector<mediapipe::Detection>>
kOutputVector{"VECTOR"};
MEDIAPIPE_NODE_INTERFACE(
::mediapipe::PassThroughOrEmptyDetectionVectorCalculator, kInputVector,
kTick, kOutputVector);
};
template <typename TickT>
api2::builder::Stream<std::vector<mediapipe::Detection>>
PassThroughOrEmptyDetectionVector(
api2::builder::Stream<std::vector<mediapipe::Detection>> detections,
api2::builder::Stream<TickT> tick, mediapipe::api2::builder::Graph& graph) {
auto& node =
graph.AddNode("mediapipe.PassThroughOrEmptyDetectionVectorCalculator");
detections.ConnectTo(
node[PassThroughOrEmptyDetectionVectorCalculator::kInputVector]);
tick.ConnectTo(node[PassThroughOrEmptyDetectionVectorCalculator::kTick]);
return node[PassThroughOrEmptyDetectionVectorCalculator::kOutputVector];
}
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_UTILS_PASS_THROUGH_OR_EMPTY_DETECTION_VECTOR_CALCULATOR_H_

View File

@ -0,0 +1,113 @@
#include "mediapipe/calculators/util/pass_through_or_empty_detection_vector_calculator.h"
#include <string>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
CalculatorGraphConfig GetGraphConfig() {
mediapipe::api2::builder::Graph graph;
mediapipe::api2::builder::Stream<std::string> tick =
graph.In("TICK").SetName("tick").Cast<std::string>();
mediapipe::api2::builder::Stream<std::vector<mediapipe::Detection>>
detections = graph.In("DETECTIONS")
.SetName("input_detections")
.Cast<std::vector<mediapipe::Detection>>();
mediapipe::api2::builder::Stream<std::vector<mediapipe::Detection>>
output_detections =
PassThroughOrEmptyDetectionVector(detections, tick, graph);
output_detections.SetName("output_detections");
return graph.GetConfig();
}
absl::Status SendTick(CalculatorGraph& graph, int at) {
return graph.AddPacketToInputStream(
"tick",
mediapipe::MakePacket<std::string>("tick").At(mediapipe::Timestamp(at)));
}
absl::Status SendDetections(CalculatorGraph& graph,
std::vector<mediapipe::Detection> detections,
int at) {
return graph.AddPacketToInputStream(
"input_detections",
mediapipe::MakePacket<std::vector<mediapipe::Detection>>(
std::move(detections))
.At(mediapipe::Timestamp(at)));
}
TEST(PassThroughOrEmptyDetectionVectorCalculatorTest, PassThrough) {
CalculatorGraphConfig graph_config = GetGraphConfig();
std::vector<Packet> output_packets;
tool::AddVectorSink("output_detections", &graph_config, &output_packets);
CalculatorGraph calculator_graph(graph_config);
MP_ASSERT_OK(calculator_graph.StartRun({}));
// Sending empty vector.
MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/1));
MP_ASSERT_OK(SendDetections(calculator_graph,
std::vector<mediapipe::Detection>{},
/*at=*/1));
MP_ASSERT_OK(calculator_graph.WaitUntilIdle());
ASSERT_EQ(output_packets.size(), 1);
EXPECT_TRUE(
output_packets[0].Get<std::vector<mediapipe::Detection>>().empty());
// Sending non empty vector.
output_packets.clear();
mediapipe::Detection detection;
detection.set_detection_id(1000);
MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/2));
MP_ASSERT_OK(SendDetections(calculator_graph, {detection}, /*at=*/2));
MP_ASSERT_OK(calculator_graph.WaitUntilIdle());
ASSERT_EQ(output_packets.size(), 1);
}
TEST(PassThroughOrEmptyDetectionVectorCalculatorTest, OrEmptyVector) {
CalculatorGraphConfig graph_config = GetGraphConfig();
std::vector<Packet> output_packets;
tool::AddVectorSink("output_detections", &graph_config, &output_packets);
CalculatorGraph calculator_graph(graph_config);
MP_ASSERT_OK(calculator_graph.StartRun({}));
mediapipe::Detection detection;
detection.set_detection_id(1000);
MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/1));
MP_ASSERT_OK(SendDetections(calculator_graph, {detection}, /*at=*/1));
MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/2));
MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/3));
MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/4));
// This should trigger trigger calculator at 2, 3, 4 as detections are not
// expected.
MP_ASSERT_OK(SendDetections(calculator_graph,
std::vector<mediapipe::Detection>{},
/*at=*/5));
MP_ASSERT_OK(calculator_graph.WaitUntilIdle());
ASSERT_EQ(output_packets.size(), 4);
for (int i = 1; i < output_packets.size(); ++i) {
EXPECT_TRUE(
output_packets[i].Get<std::vector<mediapipe::Detection>>().empty());
}
}
} // namespace
} // namespace mediapipe