segmentation smoothing stream utility function.
PiperOrigin-RevId: 569283980
This commit is contained in:
		
							parent
							
								
									636cf99a3e
								
							
						
					
					
						commit
						f78f24f576
					
				| 
						 | 
					@ -325,6 +325,33 @@ cc_test(
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cc_library(
 | 
				
			||||||
 | 
					    name = "segmentation_smoothing",
 | 
				
			||||||
 | 
					    srcs = ["segmentation_smoothing.cc"],
 | 
				
			||||||
 | 
					    hdrs = ["segmentation_smoothing.h"],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        "//mediapipe/calculators/image:segmentation_smoothing_calculator",
 | 
				
			||||||
 | 
					        "//mediapipe/calculators/image:segmentation_smoothing_calculator_cc_proto",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/api2:builder",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/formats:image",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cc_test(
 | 
				
			||||||
 | 
					    name = "segmentation_smoothing_test",
 | 
				
			||||||
 | 
					    srcs = ["segmentation_smoothing_test.cc"],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":segmentation_smoothing",
 | 
				
			||||||
 | 
					        "//mediapipe/framework:calculator_framework",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/api2:builder",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/formats:image",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/port:gtest",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/port:gtest_main",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/port:parse_text_proto",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/port:status_matchers",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cc_library(
 | 
					cc_library(
 | 
				
			||||||
    name = "split",
 | 
					    name = "split",
 | 
				
			||||||
    hdrs = ["split.h"],
 | 
					    hdrs = ["split.h"],
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										24
									
								
								mediapipe/framework/api2/stream/segmentation_smoothing.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								mediapipe/framework/api2/stream/segmentation_smoothing.cc
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,24 @@
 | 
				
			||||||
 | 
					#include "mediapipe/framework/api2/stream/segmentation_smoothing.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mediapipe/calculators/image/segmentation_smoothing_calculator.pb.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/api2/builder.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/image.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace mediapipe::api2::builder {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Stream<Image> SmoothSegmentationMask(Stream<Image> mask,
 | 
				
			||||||
 | 
					                                     Stream<Image> previous_mask,
 | 
				
			||||||
 | 
					                                     float combine_with_previous_ratio,
 | 
				
			||||||
 | 
					                                     Graph& graph) {
 | 
				
			||||||
 | 
					  auto& smoothing_node = graph.AddNode("SegmentationSmoothingCalculator");
 | 
				
			||||||
 | 
					  auto& smoothing_node_opts =
 | 
				
			||||||
 | 
					      smoothing_node
 | 
				
			||||||
 | 
					          .GetOptions<mediapipe::SegmentationSmoothingCalculatorOptions>();
 | 
				
			||||||
 | 
					  smoothing_node_opts.set_combine_with_previous_ratio(
 | 
				
			||||||
 | 
					      combine_with_previous_ratio);
 | 
				
			||||||
 | 
					  mask.ConnectTo(smoothing_node.In("MASK"));
 | 
				
			||||||
 | 
					  previous_mask.ConnectTo(smoothing_node.In("MASK_PREVIOUS"));
 | 
				
			||||||
 | 
					  return smoothing_node.Out("MASK_SMOOTHED").Cast<Image>();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace mediapipe::api2::builder
 | 
				
			||||||
							
								
								
									
										19
									
								
								mediapipe/framework/api2/stream/segmentation_smoothing.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								mediapipe/framework/api2/stream/segmentation_smoothing.h
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,19 @@
 | 
				
			||||||
 | 
					#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_SEGMENTATION_SMOOTHING_H_
 | 
				
			||||||
 | 
					#define MEDIAPIPE_FRAMEWORK_API2_STREAM_SEGMENTATION_SMOOTHING_H_
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mediapipe/framework/api2/builder.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/image.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace mediapipe::api2::builder {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Updates @graph to smooth @mask by mixing @mask and @previous_mask based on an
 | 
				
			||||||
 | 
					// uncertantity probability estimate calculated per each @mask pixel multiplied
 | 
				
			||||||
 | 
					// by @combine_with_previous_ratio.
 | 
				
			||||||
 | 
					Stream<Image> SmoothSegmentationMask(Stream<Image> mask,
 | 
				
			||||||
 | 
					                                     Stream<Image> previous_mask,
 | 
				
			||||||
 | 
					                                     float combine_with_previous_ratio,
 | 
				
			||||||
 | 
					                                     Graph& graph);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace mediapipe::api2::builder
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#endif  // MEDIAPIPE_FRAMEWORK_API2_STREAM_SEGMENTATION_SMOOTHING_H_
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,43 @@
 | 
				
			||||||
 | 
					#include "mediapipe/framework/api2/stream/segmentation_smoothing.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mediapipe/framework/api2/builder.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/calculator_framework.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/image.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/port/gmock.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/port/gtest.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/port/parse_text_proto.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace mediapipe::api2::builder {
 | 
				
			||||||
 | 
					namespace {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using ::mediapipe::Image;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(SegmentationSmoothing, VerifyConfig) {
 | 
				
			||||||
 | 
					  Graph graph;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Stream<Image> mask = graph.In("MASK").Cast<Image>();
 | 
				
			||||||
 | 
					  Stream<Image> prev_mask = graph.In("PREV_MASK").Cast<Image>();
 | 
				
			||||||
 | 
					  Stream<Image> smoothed_mask = SmoothSegmentationMask(
 | 
				
			||||||
 | 
					      mask, prev_mask, /*combine_with_previous_ratio=*/0.1f, graph);
 | 
				
			||||||
 | 
					  smoothed_mask.SetName("smoothed_mask");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT_THAT(graph.GetConfig(),
 | 
				
			||||||
 | 
					              EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
				
			||||||
 | 
					                node {
 | 
				
			||||||
 | 
					                  calculator: "SegmentationSmoothingCalculator"
 | 
				
			||||||
 | 
					                  input_stream: "MASK:__stream_0"
 | 
				
			||||||
 | 
					                  input_stream: "MASK_PREVIOUS:__stream_1"
 | 
				
			||||||
 | 
					                  output_stream: "MASK_SMOOTHED:smoothed_mask"
 | 
				
			||||||
 | 
					                  options {
 | 
				
			||||||
 | 
					                    [mediapipe.SegmentationSmoothingCalculatorOptions.ext] {
 | 
				
			||||||
 | 
					                      combine_with_previous_ratio: 0.1
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                  }
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					                input_stream: "MASK:__stream_0"
 | 
				
			||||||
 | 
					                input_stream: "PREV_MASK:__stream_1"
 | 
				
			||||||
 | 
					              )pb")));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace
 | 
				
			||||||
 | 
					}  // namespace mediapipe::api2::builder
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user