segmentation smoothing stream utility function.
PiperOrigin-RevId: 569283980
This commit is contained in:
		
							parent
							
								
									636cf99a3e
								
							
						
					
					
						commit
						f78f24f576
					
				| 
						 | 
				
			
			@ -325,6 +325,33 @@ cc_test(
 | 
			
		|||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "segmentation_smoothing",
 | 
			
		||||
    srcs = ["segmentation_smoothing.cc"],
 | 
			
		||||
    hdrs = ["segmentation_smoothing.h"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/calculators/image:segmentation_smoothing_calculator",
 | 
			
		||||
        "//mediapipe/calculators/image:segmentation_smoothing_calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/framework/api2:builder",
 | 
			
		||||
        "//mediapipe/framework/formats:image",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_test(
 | 
			
		||||
    name = "segmentation_smoothing_test",
 | 
			
		||||
    srcs = ["segmentation_smoothing_test.cc"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":segmentation_smoothing",
 | 
			
		||||
        "//mediapipe/framework:calculator_framework",
 | 
			
		||||
        "//mediapipe/framework/api2:builder",
 | 
			
		||||
        "//mediapipe/framework/formats:image",
 | 
			
		||||
        "//mediapipe/framework/port:gtest",
 | 
			
		||||
        "//mediapipe/framework/port:gtest_main",
 | 
			
		||||
        "//mediapipe/framework/port:parse_text_proto",
 | 
			
		||||
        "//mediapipe/framework/port:status_matchers",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "split",
 | 
			
		||||
    hdrs = ["split.h"],
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										24
									
								
								mediapipe/framework/api2/stream/segmentation_smoothing.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								mediapipe/framework/api2/stream/segmentation_smoothing.cc
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,24 @@
 | 
			
		|||
#include "mediapipe/framework/api2/stream/segmentation_smoothing.h"
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/calculators/image/segmentation_smoothing_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/framework/api2/builder.h"
 | 
			
		||||
#include "mediapipe/framework/formats/image.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe::api2::builder {
 | 
			
		||||
 | 
			
		||||
Stream<Image> SmoothSegmentationMask(Stream<Image> mask,
 | 
			
		||||
                                     Stream<Image> previous_mask,
 | 
			
		||||
                                     float combine_with_previous_ratio,
 | 
			
		||||
                                     Graph& graph) {
 | 
			
		||||
  auto& smoothing_node = graph.AddNode("SegmentationSmoothingCalculator");
 | 
			
		||||
  auto& smoothing_node_opts =
 | 
			
		||||
      smoothing_node
 | 
			
		||||
          .GetOptions<mediapipe::SegmentationSmoothingCalculatorOptions>();
 | 
			
		||||
  smoothing_node_opts.set_combine_with_previous_ratio(
 | 
			
		||||
      combine_with_previous_ratio);
 | 
			
		||||
  mask.ConnectTo(smoothing_node.In("MASK"));
 | 
			
		||||
  previous_mask.ConnectTo(smoothing_node.In("MASK_PREVIOUS"));
 | 
			
		||||
  return smoothing_node.Out("MASK_SMOOTHED").Cast<Image>();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe::api2::builder
 | 
			
		||||
							
								
								
									
										19
									
								
								mediapipe/framework/api2/stream/segmentation_smoothing.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								mediapipe/framework/api2/stream/segmentation_smoothing.h
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,19 @@
 | 
			
		|||
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_SEGMENTATION_SMOOTHING_H_
 | 
			
		||||
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_SEGMENTATION_SMOOTHING_H_
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/framework/api2/builder.h"
 | 
			
		||||
#include "mediapipe/framework/formats/image.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe::api2::builder {
 | 
			
		||||
 | 
			
		||||
// Updates @graph to smooth @mask by mixing @mask and @previous_mask based on an
 | 
			
		||||
// uncertantity probability estimate calculated per each @mask pixel multiplied
 | 
			
		||||
// by @combine_with_previous_ratio.
 | 
			
		||||
Stream<Image> SmoothSegmentationMask(Stream<Image> mask,
 | 
			
		||||
                                     Stream<Image> previous_mask,
 | 
			
		||||
                                     float combine_with_previous_ratio,
 | 
			
		||||
                                     Graph& graph);
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe::api2::builder
 | 
			
		||||
 | 
			
		||||
#endif  // MEDIAPIPE_FRAMEWORK_API2_STREAM_SEGMENTATION_SMOOTHING_H_
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,43 @@
 | 
			
		|||
#include "mediapipe/framework/api2/stream/segmentation_smoothing.h"
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/framework/api2/builder.h"
 | 
			
		||||
#include "mediapipe/framework/calculator_framework.h"
 | 
			
		||||
#include "mediapipe/framework/formats/image.h"
 | 
			
		||||
#include "mediapipe/framework/port/gmock.h"
 | 
			
		||||
#include "mediapipe/framework/port/gtest.h"
 | 
			
		||||
#include "mediapipe/framework/port/parse_text_proto.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe::api2::builder {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
using ::mediapipe::Image;
 | 
			
		||||
 | 
			
		||||
TEST(SegmentationSmoothing, VerifyConfig) {
 | 
			
		||||
  Graph graph;
 | 
			
		||||
 | 
			
		||||
  Stream<Image> mask = graph.In("MASK").Cast<Image>();
 | 
			
		||||
  Stream<Image> prev_mask = graph.In("PREV_MASK").Cast<Image>();
 | 
			
		||||
  Stream<Image> smoothed_mask = SmoothSegmentationMask(
 | 
			
		||||
      mask, prev_mask, /*combine_with_previous_ratio=*/0.1f, graph);
 | 
			
		||||
  smoothed_mask.SetName("smoothed_mask");
 | 
			
		||||
 | 
			
		||||
  EXPECT_THAT(graph.GetConfig(),
 | 
			
		||||
              EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
			
		||||
                node {
 | 
			
		||||
                  calculator: "SegmentationSmoothingCalculator"
 | 
			
		||||
                  input_stream: "MASK:__stream_0"
 | 
			
		||||
                  input_stream: "MASK_PREVIOUS:__stream_1"
 | 
			
		||||
                  output_stream: "MASK_SMOOTHED:smoothed_mask"
 | 
			
		||||
                  options {
 | 
			
		||||
                    [mediapipe.SegmentationSmoothingCalculatorOptions.ext] {
 | 
			
		||||
                      combine_with_previous_ratio: 0.1
 | 
			
		||||
                    }
 | 
			
		||||
                  }
 | 
			
		||||
                }
 | 
			
		||||
                input_stream: "MASK:__stream_0"
 | 
			
		||||
                input_stream: "PREV_MASK:__stream_1"
 | 
			
		||||
              )pb")));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
}  // namespace mediapipe::api2::builder
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user