No public description
PiperOrigin-RevId: 569274219
This commit is contained in:
parent
33d6143a1a
commit
e169849041
|
@ -1616,3 +1616,34 @@ cc_test(
|
|||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "pass_through_or_empty_detection_vector_calculator",
|
||||
srcs = ["pass_through_or_empty_detection_vector_calculator.cc"],
|
||||
hdrs = ["pass_through_or_empty_detection_vector_calculator.h"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_context",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "pass_through_or_empty_detection_vector_calculator_test",
|
||||
srcs = ["pass_through_or_empty_detection_vector_calculator_test.cc"],
|
||||
tags = ["desktop_only_test"],
|
||||
deps = [
|
||||
":pass_through_or_empty_detection_vector_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:status_matchers",
|
||||
"@com_google_absl//absl/status",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
#include "mediapipe/calculators/util/pass_through_or_empty_detection_vector_calculator.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_context.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
class PassThroughOrEmptyDetectionVectorCalculatorImpl
|
||||
: public mediapipe::api2::NodeImpl<
|
||||
PassThroughOrEmptyDetectionVectorCalculator> {
|
||||
public:
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
if (kInputVector(cc).IsEmpty()) {
|
||||
kOutputVector(cc).Send(std::vector<mediapipe::Detection>{});
|
||||
return absl::OkStatus();
|
||||
}
|
||||
kOutputVector(cc).Send(kInputVector(cc));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
MEDIAPIPE_NODE_IMPLEMENTATION(PassThroughOrEmptyDetectionVectorCalculatorImpl);
|
||||
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,55 @@
|
|||
#ifndef MEDIAPIPE_CALCULATORS_UTILS_PASS_THROUGH_OR_EMPTY_DETECTION_VECTOR_CALCULATOR_H_
|
||||
#define MEDIAPIPE_CALCULATORS_UTILS_PASS_THROUGH_OR_EMPTY_DETECTION_VECTOR_CALCULATOR_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Calculator to pass through input vector of detections if packet is not empty,
|
||||
// otherwise - outputing a new empty vector. So, instead of empty packet you get
|
||||
// a packet containing empty vector.
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "PassThroughOrEmptyDetectionVectorCalculator"
|
||||
// input_stream: "TICK:tick"
|
||||
// input_stream: "VECTOR:input_detections"
|
||||
// output_stream: "VECTOR:output_detections"
|
||||
// }
|
||||
class PassThroughOrEmptyDetectionVectorCalculator
|
||||
: public mediapipe::api2::NodeIntf {
|
||||
public:
|
||||
static constexpr mediapipe::api2::Input<std::vector<mediapipe::Detection>>
|
||||
kInputVector{"VECTOR"};
|
||||
static constexpr mediapipe::api2::Input<mediapipe::api2::AnyType> kTick{
|
||||
"TICK"};
|
||||
static constexpr mediapipe::api2::Output<std::vector<mediapipe::Detection>>
|
||||
kOutputVector{"VECTOR"};
|
||||
|
||||
MEDIAPIPE_NODE_INTERFACE(
|
||||
::mediapipe::PassThroughOrEmptyDetectionVectorCalculator, kInputVector,
|
||||
kTick, kOutputVector);
|
||||
};
|
||||
|
||||
template <typename TickT>
|
||||
api2::builder::Stream<std::vector<mediapipe::Detection>>
|
||||
PassThroughOrEmptyDetectionVector(
|
||||
api2::builder::Stream<std::vector<mediapipe::Detection>> detections,
|
||||
api2::builder::Stream<TickT> tick, mediapipe::api2::builder::Graph& graph) {
|
||||
auto& node =
|
||||
graph.AddNode("mediapipe.PassThroughOrEmptyDetectionVectorCalculator");
|
||||
detections.ConnectTo(
|
||||
node[PassThroughOrEmptyDetectionVectorCalculator::kInputVector]);
|
||||
tick.ConnectTo(node[PassThroughOrEmptyDetectionVectorCalculator::kTick]);
|
||||
return node[PassThroughOrEmptyDetectionVectorCalculator::kOutputVector];
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_CALCULATORS_UTILS_PASS_THROUGH_OR_EMPTY_DETECTION_VECTOR_CALCULATOR_H_
|
|
@ -0,0 +1,113 @@
|
|||
#include "mediapipe/calculators/util/pass_through_or_empty_detection_vector_calculator.h"
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
CalculatorGraphConfig GetGraphConfig() {
|
||||
mediapipe::api2::builder::Graph graph;
|
||||
mediapipe::api2::builder::Stream<std::string> tick =
|
||||
graph.In("TICK").SetName("tick").Cast<std::string>();
|
||||
mediapipe::api2::builder::Stream<std::vector<mediapipe::Detection>>
|
||||
detections = graph.In("DETECTIONS")
|
||||
.SetName("input_detections")
|
||||
.Cast<std::vector<mediapipe::Detection>>();
|
||||
|
||||
mediapipe::api2::builder::Stream<std::vector<mediapipe::Detection>>
|
||||
output_detections =
|
||||
PassThroughOrEmptyDetectionVector(detections, tick, graph);
|
||||
output_detections.SetName("output_detections");
|
||||
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
absl::Status SendTick(CalculatorGraph& graph, int at) {
|
||||
return graph.AddPacketToInputStream(
|
||||
"tick",
|
||||
mediapipe::MakePacket<std::string>("tick").At(mediapipe::Timestamp(at)));
|
||||
}
|
||||
|
||||
absl::Status SendDetections(CalculatorGraph& graph,
|
||||
std::vector<mediapipe::Detection> detections,
|
||||
int at) {
|
||||
return graph.AddPacketToInputStream(
|
||||
"input_detections",
|
||||
mediapipe::MakePacket<std::vector<mediapipe::Detection>>(
|
||||
std::move(detections))
|
||||
.At(mediapipe::Timestamp(at)));
|
||||
}
|
||||
|
||||
TEST(PassThroughOrEmptyDetectionVectorCalculatorTest, PassThrough) {
|
||||
CalculatorGraphConfig graph_config = GetGraphConfig();
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("output_detections", &graph_config, &output_packets);
|
||||
|
||||
CalculatorGraph calculator_graph(graph_config);
|
||||
MP_ASSERT_OK(calculator_graph.StartRun({}));
|
||||
|
||||
// Sending empty vector.
|
||||
MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/1));
|
||||
MP_ASSERT_OK(SendDetections(calculator_graph,
|
||||
std::vector<mediapipe::Detection>{},
|
||||
/*at=*/1));
|
||||
MP_ASSERT_OK(calculator_graph.WaitUntilIdle());
|
||||
|
||||
ASSERT_EQ(output_packets.size(), 1);
|
||||
EXPECT_TRUE(
|
||||
output_packets[0].Get<std::vector<mediapipe::Detection>>().empty());
|
||||
|
||||
// Sending non empty vector.
|
||||
output_packets.clear();
|
||||
mediapipe::Detection detection;
|
||||
detection.set_detection_id(1000);
|
||||
|
||||
MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/2));
|
||||
MP_ASSERT_OK(SendDetections(calculator_graph, {detection}, /*at=*/2));
|
||||
MP_ASSERT_OK(calculator_graph.WaitUntilIdle());
|
||||
|
||||
ASSERT_EQ(output_packets.size(), 1);
|
||||
}
|
||||
|
||||
TEST(PassThroughOrEmptyDetectionVectorCalculatorTest, OrEmptyVector) {
|
||||
CalculatorGraphConfig graph_config = GetGraphConfig();
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("output_detections", &graph_config, &output_packets);
|
||||
|
||||
CalculatorGraph calculator_graph(graph_config);
|
||||
MP_ASSERT_OK(calculator_graph.StartRun({}));
|
||||
|
||||
mediapipe::Detection detection;
|
||||
detection.set_detection_id(1000);
|
||||
MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/1));
|
||||
MP_ASSERT_OK(SendDetections(calculator_graph, {detection}, /*at=*/1));
|
||||
MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/2));
|
||||
MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/3));
|
||||
MP_ASSERT_OK(SendTick(calculator_graph, /*at=*/4));
|
||||
// This should trigger trigger calculator at 2, 3, 4 as detections are not
|
||||
// expected.
|
||||
MP_ASSERT_OK(SendDetections(calculator_graph,
|
||||
std::vector<mediapipe::Detection>{},
|
||||
/*at=*/5));
|
||||
MP_ASSERT_OK(calculator_graph.WaitUntilIdle());
|
||||
|
||||
ASSERT_EQ(output_packets.size(), 4);
|
||||
|
||||
for (int i = 1; i < output_packets.size(); ++i) {
|
||||
EXPECT_TRUE(
|
||||
output_packets[i].Get<std::vector<mediapipe::Detection>>().empty());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
Loading…
Reference in New Issue
Block a user