From 36f78f6e4a9ac0bc5718918ffac6fe46c5a30964 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 18 Sep 2023 14:39:52 -0700 Subject: [PATCH] threshold stream utility function. PiperOrigin-RevId: 566417914 --- mediapipe/framework/api2/stream/BUILD | 25 ++++++++++++ mediapipe/framework/api2/stream/threshold.cc | 17 ++++++++ mediapipe/framework/api2/stream/threshold.h | 13 +++++++ .../framework/api2/stream/threshold_test.cc | 39 +++++++++++++++++++ 4 files changed, 94 insertions(+) create mode 100644 mediapipe/framework/api2/stream/threshold.cc create mode 100644 mediapipe/framework/api2/stream/threshold.h create mode 100644 mediapipe/framework/api2/stream/threshold_test.cc diff --git a/mediapipe/framework/api2/stream/BUILD b/mediapipe/framework/api2/stream/BUILD index 091656b2e..8c6e9bd18 100644 --- a/mediapipe/framework/api2/stream/BUILD +++ b/mediapipe/framework/api2/stream/BUILD @@ -167,3 +167,28 @@ cc_test( "//mediapipe/framework/port:parse_text_proto", ], ) + +cc_library( + name = "threshold", + srcs = ["threshold.cc"], + hdrs = ["threshold.h"], + deps = [ + "//mediapipe/calculators/util:thresholding_calculator", + "//mediapipe/calculators/util:thresholding_calculator_cc_proto", + "//mediapipe/framework/api2:builder", + ], +) + +cc_test( + name = "threshold_test", + srcs = ["threshold_test.cc"], + deps = [ + ":threshold", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status_matchers", + ], +) diff --git a/mediapipe/framework/api2/stream/threshold.cc b/mediapipe/framework/api2/stream/threshold.cc new file mode 100644 index 000000000..48912b0f9 --- /dev/null +++ b/mediapipe/framework/api2/stream/threshold.cc @@ -0,0 +1,17 @@ +#include "mediapipe/framework/api2/stream/threshold.h" + +#include "mediapipe/calculators/util/thresholding_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" + +namespace mediapipe::api2::builder { + +Stream IsOverThreshold(Stream value, double threshold, + mediapipe::api2::builder::Graph& graph) { + auto& node = graph.AddNode("ThresholdingCalculator"); + auto& node_opts = node.GetOptions(); + node_opts.set_threshold(threshold); + value.ConnectTo(node.In("FLOAT")); + return node.Out("FLAG").Cast(); +} + +} // namespace mediapipe::api2::builder diff --git a/mediapipe/framework/api2/stream/threshold.h b/mediapipe/framework/api2/stream/threshold.h new file mode 100644 index 000000000..8bb3bf2d7 --- /dev/null +++ b/mediapipe/framework/api2/stream/threshold.h @@ -0,0 +1,13 @@ +#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_THRESHOLD_H_ +#define MEDIAPIPE_FRAMEWORK_API2_STREAM_THRESHOLD_H_ + +#include "mediapipe/framework/api2/builder.h" + +namespace mediapipe::api2::builder { + +Stream IsOverThreshold(Stream value, double threshold, + Graph& graph); + +} // namespace mediapipe::api2::builder + +#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_THRESHOLD_H_ diff --git a/mediapipe/framework/api2/stream/threshold_test.cc b/mediapipe/framework/api2/stream/threshold_test.cc new file mode 100644 index 000000000..df531a67f --- /dev/null +++ b/mediapipe/framework/api2/stream/threshold_test.cc @@ -0,0 +1,39 @@ +#include "mediapipe/framework/api2/stream/threshold.h" + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe::api2::builder { +namespace { + +TEST(ThresholdTest, IsOverThresholdTest) { + mediapipe::api2::builder::Graph graph; + + Stream score = graph.In("SCORE").Cast(); + Stream flag = IsOverThreshold(score, /*threshold=*/0.5f, graph); + flag.SetName("flag"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "ThresholdingCalculator" + input_stream: "FLOAT:__stream_0" + output_stream: "FLAG:flag" + options { + [mediapipe.ThresholdingCalculatorOptions.ext] { threshold: 0.5 } + } + } + input_stream: "SCORE:__stream_0" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +} // namespace +} // namespace mediapipe::api2::builder