threshold stream utility function.
PiperOrigin-RevId: 566417914
This commit is contained in:
		
							parent
							
								
									58a7790081
								
							
						
					
					
						commit
						36f78f6e4a
					
				| 
						 | 
				
			
			@ -167,3 +167,28 @@ cc_test(
 | 
			
		|||
        "//mediapipe/framework/port:parse_text_proto",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "threshold",
 | 
			
		||||
    srcs = ["threshold.cc"],
 | 
			
		||||
    hdrs = ["threshold.h"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/calculators/util:thresholding_calculator",
 | 
			
		||||
        "//mediapipe/calculators/util:thresholding_calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/framework/api2:builder",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_test(
 | 
			
		||||
    name = "threshold_test",
 | 
			
		||||
    srcs = ["threshold_test.cc"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":threshold",
 | 
			
		||||
        "//mediapipe/framework:calculator_framework",
 | 
			
		||||
        "//mediapipe/framework/api2:builder",
 | 
			
		||||
        "//mediapipe/framework/port:gtest",
 | 
			
		||||
        "//mediapipe/framework/port:gtest_main",
 | 
			
		||||
        "//mediapipe/framework/port:parse_text_proto",
 | 
			
		||||
        "//mediapipe/framework/port:status_matchers",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										17
									
								
								mediapipe/framework/api2/stream/threshold.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								mediapipe/framework/api2/stream/threshold.cc
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,17 @@
 | 
			
		|||
#include "mediapipe/framework/api2/stream/threshold.h"
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/calculators/util/thresholding_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/framework/api2/builder.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe::api2::builder {
 | 
			
		||||
 | 
			
		||||
Stream<bool> IsOverThreshold(Stream<float> value, double threshold,
 | 
			
		||||
                             mediapipe::api2::builder::Graph& graph) {
 | 
			
		||||
  auto& node = graph.AddNode("ThresholdingCalculator");
 | 
			
		||||
  auto& node_opts = node.GetOptions<mediapipe::ThresholdingCalculatorOptions>();
 | 
			
		||||
  node_opts.set_threshold(threshold);
 | 
			
		||||
  value.ConnectTo(node.In("FLOAT"));
 | 
			
		||||
  return node.Out("FLAG").Cast<bool>();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe::api2::builder
 | 
			
		||||
							
								
								
									
										13
									
								
								mediapipe/framework/api2/stream/threshold.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								mediapipe/framework/api2/stream/threshold.h
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,13 @@
 | 
			
		|||
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_THRESHOLD_H_
 | 
			
		||||
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_THRESHOLD_H_
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/framework/api2/builder.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe::api2::builder {
 | 
			
		||||
 | 
			
		||||
Stream<bool> IsOverThreshold(Stream<float> value, double threshold,
 | 
			
		||||
                             Graph& graph);
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe::api2::builder
 | 
			
		||||
 | 
			
		||||
#endif  // MEDIAPIPE_FRAMEWORK_API2_STREAM_THRESHOLD_H_
 | 
			
		||||
							
								
								
									
										39
									
								
								mediapipe/framework/api2/stream/threshold_test.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								mediapipe/framework/api2/stream/threshold_test.cc
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,39 @@
 | 
			
		|||
#include "mediapipe/framework/api2/stream/threshold.h"
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/framework/api2/builder.h"
 | 
			
		||||
#include "mediapipe/framework/calculator_framework.h"
 | 
			
		||||
#include "mediapipe/framework/port/gmock.h"
 | 
			
		||||
#include "mediapipe/framework/port/gtest.h"
 | 
			
		||||
#include "mediapipe/framework/port/parse_text_proto.h"
 | 
			
		||||
#include "mediapipe/framework/port/status_matchers.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe::api2::builder {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
TEST(ThresholdTest, IsOverThresholdTest) {
 | 
			
		||||
  mediapipe::api2::builder::Graph graph;
 | 
			
		||||
 | 
			
		||||
  Stream<float> score = graph.In("SCORE").Cast<float>();
 | 
			
		||||
  Stream<bool> flag = IsOverThreshold(score, /*threshold=*/0.5f, graph);
 | 
			
		||||
  flag.SetName("flag");
 | 
			
		||||
 | 
			
		||||
  EXPECT_THAT(
 | 
			
		||||
      graph.GetConfig(),
 | 
			
		||||
      EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
			
		||||
        node {
 | 
			
		||||
          calculator: "ThresholdingCalculator"
 | 
			
		||||
          input_stream: "FLOAT:__stream_0"
 | 
			
		||||
          output_stream: "FLAG:flag"
 | 
			
		||||
          options {
 | 
			
		||||
            [mediapipe.ThresholdingCalculatorOptions.ext] { threshold: 0.5 }
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
        input_stream: "SCORE:__stream_0"
 | 
			
		||||
      )pb")));
 | 
			
		||||
 | 
			
		||||
  CalculatorGraph calcualtor_graph;
 | 
			
		||||
  MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
}  // namespace mediapipe::api2::builder
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user