segmentation smoothing stream utility function.

PiperOrigin-RevId: 569283980
This commit is contained in:
MediaPipe Team 2023-09-28 14:00:41 -07:00 committed by Copybara-Service
parent 636cf99a3e
commit f78f24f576
4 changed files with 113 additions and 0 deletions

View File

@ -325,6 +325,33 @@ cc_test(
],
)
cc_library(
name = "segmentation_smoothing",
srcs = ["segmentation_smoothing.cc"],
hdrs = ["segmentation_smoothing.h"],
deps = [
"//mediapipe/calculators/image:segmentation_smoothing_calculator",
"//mediapipe/calculators/image:segmentation_smoothing_calculator_cc_proto",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image",
],
)
cc_test(
name = "segmentation_smoothing_test",
srcs = ["segmentation_smoothing_test.cc"],
deps = [
":segmentation_smoothing",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status_matchers",
],
)
cc_library(
name = "split",
hdrs = ["split.h"],

View File

@ -0,0 +1,24 @@
#include "mediapipe/framework/api2/stream/segmentation_smoothing.h"
#include "mediapipe/calculators/image/segmentation_smoothing_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/image.h"
namespace mediapipe::api2::builder {
Stream<Image> SmoothSegmentationMask(Stream<Image> mask,
Stream<Image> previous_mask,
float combine_with_previous_ratio,
Graph& graph) {
auto& smoothing_node = graph.AddNode("SegmentationSmoothingCalculator");
auto& smoothing_node_opts =
smoothing_node
.GetOptions<mediapipe::SegmentationSmoothingCalculatorOptions>();
smoothing_node_opts.set_combine_with_previous_ratio(
combine_with_previous_ratio);
mask.ConnectTo(smoothing_node.In("MASK"));
previous_mask.ConnectTo(smoothing_node.In("MASK_PREVIOUS"));
return smoothing_node.Out("MASK_SMOOTHED").Cast<Image>();
}
} // namespace mediapipe::api2::builder

View File

@ -0,0 +1,19 @@
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_SEGMENTATION_SMOOTHING_H_
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_SEGMENTATION_SMOOTHING_H_
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/image.h"
namespace mediapipe::api2::builder {
// Updates @graph to smooth @mask by mixing @mask and @previous_mask based on an
// uncertantity probability estimate calculated per each @mask pixel multiplied
// by @combine_with_previous_ratio.
Stream<Image> SmoothSegmentationMask(Stream<Image> mask,
Stream<Image> previous_mask,
float combine_with_previous_ratio,
Graph& graph);
} // namespace mediapipe::api2::builder
#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_SEGMENTATION_SMOOTHING_H_

View File

@ -0,0 +1,43 @@
#include "mediapipe/framework/api2/stream/segmentation_smoothing.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
namespace mediapipe::api2::builder {
namespace {
using ::mediapipe::Image;
TEST(SegmentationSmoothing, VerifyConfig) {
Graph graph;
Stream<Image> mask = graph.In("MASK").Cast<Image>();
Stream<Image> prev_mask = graph.In("PREV_MASK").Cast<Image>();
Stream<Image> smoothed_mask = SmoothSegmentationMask(
mask, prev_mask, /*combine_with_previous_ratio=*/0.1f, graph);
smoothed_mask.SetName("smoothed_mask");
EXPECT_THAT(graph.GetConfig(),
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "SegmentationSmoothingCalculator"
input_stream: "MASK:__stream_0"
input_stream: "MASK_PREVIOUS:__stream_1"
output_stream: "MASK_SMOOTHED:smoothed_mask"
options {
[mediapipe.SegmentationSmoothingCalculatorOptions.ext] {
combine_with_previous_ratio: 0.1
}
}
}
input_stream: "MASK:__stream_0"
input_stream: "PREV_MASK:__stream_1"
)pb")));
}
} // namespace
} // namespace mediapipe::api2::builder