threshold stream utility function.
PiperOrigin-RevId: 566417914
This commit is contained in:
parent
58a7790081
commit
36f78f6e4a
|
@ -167,3 +167,28 @@ cc_test(
|
||||||
"//mediapipe/framework/port:parse_text_proto",
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "threshold",
|
||||||
|
srcs = ["threshold.cc"],
|
||||||
|
hdrs = ["threshold.h"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/calculators/util:thresholding_calculator",
|
||||||
|
"//mediapipe/calculators/util:thresholding_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework/api2:builder",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "threshold_test",
|
||||||
|
srcs = ["threshold_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":threshold",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:builder",
|
||||||
|
"//mediapipe/framework/port:gtest",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"//mediapipe/framework/port:status_matchers",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
17
mediapipe/framework/api2/stream/threshold.cc
Normal file
17
mediapipe/framework/api2/stream/threshold.cc
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
#include "mediapipe/framework/api2/stream/threshold.h"
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/util/thresholding_calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
|
||||||
|
namespace mediapipe::api2::builder {
|
||||||
|
|
||||||
|
Stream<bool> IsOverThreshold(Stream<float> value, double threshold,
|
||||||
|
mediapipe::api2::builder::Graph& graph) {
|
||||||
|
auto& node = graph.AddNode("ThresholdingCalculator");
|
||||||
|
auto& node_opts = node.GetOptions<mediapipe::ThresholdingCalculatorOptions>();
|
||||||
|
node_opts.set_threshold(threshold);
|
||||||
|
value.ConnectTo(node.In("FLOAT"));
|
||||||
|
return node.Out("FLAG").Cast<bool>();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::api2::builder
|
13
mediapipe/framework/api2/stream/threshold.h
Normal file
13
mediapipe/framework/api2/stream/threshold.h
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_THRESHOLD_H_
|
||||||
|
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_THRESHOLD_H_
|
||||||
|
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
|
||||||
|
namespace mediapipe::api2::builder {
|
||||||
|
|
||||||
|
Stream<bool> IsOverThreshold(Stream<float> value, double threshold,
|
||||||
|
Graph& graph);
|
||||||
|
|
||||||
|
} // namespace mediapipe::api2::builder
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_THRESHOLD_H_
|
39
mediapipe/framework/api2/stream/threshold_test.cc
Normal file
39
mediapipe/framework/api2/stream/threshold_test.cc
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
#include "mediapipe/framework/api2/stream/threshold.h"
|
||||||
|
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
|
||||||
|
namespace mediapipe::api2::builder {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
TEST(ThresholdTest, IsOverThresholdTest) {
|
||||||
|
mediapipe::api2::builder::Graph graph;
|
||||||
|
|
||||||
|
Stream<float> score = graph.In("SCORE").Cast<float>();
|
||||||
|
Stream<bool> flag = IsOverThreshold(score, /*threshold=*/0.5f, graph);
|
||||||
|
flag.SetName("flag");
|
||||||
|
|
||||||
|
EXPECT_THAT(
|
||||||
|
graph.GetConfig(),
|
||||||
|
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
node {
|
||||||
|
calculator: "ThresholdingCalculator"
|
||||||
|
input_stream: "FLOAT:__stream_0"
|
||||||
|
output_stream: "FLAG:flag"
|
||||||
|
options {
|
||||||
|
[mediapipe.ThresholdingCalculatorOptions.ext] { threshold: 0.5 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input_stream: "SCORE:__stream_0"
|
||||||
|
)pb")));
|
||||||
|
|
||||||
|
CalculatorGraph calcualtor_graph;
|
||||||
|
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe::api2::builder
|
Loading…
Reference in New Issue
Block a user