Update ImageSegmenterGraph to always output confidence masks and optionally output a category mask

PiperOrigin-RevId: 521679910
This commit is contained in:
MediaPipe Team 2023-04-04 00:23:03 -07:00 committed by Copybara-Service
parent c31a4681e5
commit 367ccbfdf3
9 changed files with 370 additions and 257 deletions

View File

@ -16,6 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"]) licenses(["notice"])
cc_library(
name = "image_segmenter_result",
hdrs = ["image_segmenter_result.h"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework/formats:image"],
)
# Docs for Mediapipe Tasks Image Segmenter # Docs for Mediapipe Tasks Image Segmenter
# https://developers.google.com/mediapipe/solutions/vision/image_segmenter # https://developers.google.com/mediapipe/solutions/vision/image_segmenter
cc_library( cc_library(
@ -25,6 +32,7 @@ cc_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":image_segmenter_graph", ":image_segmenter_graph",
":image_segmenter_result",
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
@ -82,6 +90,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"//mediapipe/tasks/metadata:image_segmenter_metadata_schema_cc", "//mediapipe/tasks/metadata:image_segmenter_metadata_schema_cc",
"//mediapipe/tasks/metadata:metadata_schema_cc", "//mediapipe/tasks/metadata:metadata_schema_cc",
"//mediapipe/util:graph_builder_utils",
"//mediapipe/util:label_map_cc_proto", "//mediapipe/util:label_map_cc_proto",
"//mediapipe/util:label_map_util", "//mediapipe/util:label_map_util",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",

View File

@ -80,7 +80,7 @@ void Sigmoid(absl::Span<const float> values,
[](float value) { return 1. / (1 + std::exp(-value)); }); [](float value) { return 1. / (1 + std::exp(-value)); });
} }
std::vector<Image> ProcessForCategoryMaskCpu(const Shape& input_shape, Image ProcessForCategoryMaskCpu(const Shape& input_shape,
const Shape& output_shape, const Shape& output_shape,
const SegmenterOptions& options, const SegmenterOptions& options,
const float* tensors_buffer) { const float* tensors_buffer) {
@ -135,7 +135,7 @@ std::vector<Image> ProcessForCategoryMaskCpu(const Shape& input_shape,
pixel = maximum_category_idx; pixel = maximum_category_idx;
} }
}); });
return {category_mask}; return category_mask;
} }
std::vector<Image> ProcessForConfidenceMaskCpu(const Shape& input_shape, std::vector<Image> ProcessForConfidenceMaskCpu(const Shape& input_shape,
@ -209,7 +209,9 @@ std::vector<Image> ProcessForConfidenceMaskCpu(const Shape& input_shape,
} // namespace } // namespace
// Converts Tensors from a vector of Tensor to Segmentation. // Converts Tensors from a vector of Tensor to Segmentation masks. The
// calculator always outputs confidence masks, and an optional category mask if
// CATEGORY_MASK is connected.
// //
// Performs optional resizing to OUTPUT_SIZE dimension if provided, // Performs optional resizing to OUTPUT_SIZE dimension if provided,
// otherwise the segmented masks are the same size as the input tensor. // otherwise the segmented masks are the same size as the input tensor.
@ -221,7 +223,12 @@ std::vector<Image> ProcessForConfidenceMaskCpu(const Shape& input_shape,
// the size to resize masks to. // the size to resize masks to.
// //
// Output: // Output:
// Segmentation: Segmentation proto. // CONFIDENCE_MASK @Multiple: Multiple masks of float image where, for each
// mask, each pixel represents the prediction confidence, usually in the [0,
// 1] range.
// CATEGORY_MASK @Optional: A category mask of uint8 image where each pixel
// represents the class which the pixel in the original image was predicted to
// belong to.
// //
// Options: // Options:
// See tensors_to_segmentation_calculator.proto // See tensors_to_segmentation_calculator.proto
@ -231,13 +238,13 @@ std::vector<Image> ProcessForConfidenceMaskCpu(const Shape& input_shape,
// calculator: "TensorsToSegmentationCalculator" // calculator: "TensorsToSegmentationCalculator"
// input_stream: "TENSORS:tensors" // input_stream: "TENSORS:tensors"
// input_stream: "OUTPUT_SIZE:size" // input_stream: "OUTPUT_SIZE:size"
// output_stream: "SEGMENTATION:0:segmentation" // output_stream: "CONFIDENCE_MASK:0:confidence_mask"
// output_stream: "SEGMENTATION:1:segmentation" // output_stream: "CONFIDENCE_MASK:1:confidence_mask"
// output_stream: "CATEGORY_MASK:category_mask"
// options { // options {
// [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { // [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
// segmenter_options { // segmenter_options {
// activation: SOFTMAX // activation: SOFTMAX
// output_type: CONFIDENCE_MASK
// } // }
// } // }
// } // }
@ -248,7 +255,11 @@ class TensorsToSegmentationCalculator : public Node {
static constexpr Input<std::pair<int, int>>::Optional kOutputSizeIn{ static constexpr Input<std::pair<int, int>>::Optional kOutputSizeIn{
"OUTPUT_SIZE"}; "OUTPUT_SIZE"};
static constexpr Output<Image>::Multiple kSegmentationOut{"SEGMENTATION"}; static constexpr Output<Image>::Multiple kSegmentationOut{"SEGMENTATION"};
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut); static constexpr Output<Image>::Multiple kConfidenceMaskOut{
"CONFIDENCE_MASK"};
static constexpr Output<Image>::Optional kCategoryMaskOut{"CATEGORY_MASK"};
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut,
kConfidenceMaskOut, kCategoryMaskOut);
static absl::Status UpdateContract(CalculatorContract* cc); static absl::Status UpdateContract(CalculatorContract* cc);
@ -279,9 +290,13 @@ absl::Status TensorsToSegmentationCalculator::UpdateContract(
absl::Status TensorsToSegmentationCalculator::Open( absl::Status TensorsToSegmentationCalculator::Open(
mediapipe::CalculatorContext* cc) { mediapipe::CalculatorContext* cc) {
options_ = cc->Options<TensorsToSegmentationCalculatorOptions>(); options_ = cc->Options<TensorsToSegmentationCalculatorOptions>();
// TODO: remove deprecated output type support.
if (options_.segmenter_options().has_output_type()) {
RET_CHECK_NE(options_.segmenter_options().output_type(), RET_CHECK_NE(options_.segmenter_options().output_type(),
SegmenterOptions::UNSPECIFIED) SegmenterOptions::UNSPECIFIED)
<< "Must specify output_type as one of [CONFIDENCE_MASK|CATEGORY_MASK]."; << "Must specify output_type as one of "
"[CONFIDENCE_MASK|CATEGORY_MASK].";
}
#ifdef __EMSCRIPTEN__ #ifdef __EMSCRIPTEN__
MP_RETURN_IF_ERROR(postprocessor_.Initialize(cc, options_)); MP_RETURN_IF_ERROR(postprocessor_.Initialize(cc, options_));
#endif // __EMSCRIPTEN__ #endif // __EMSCRIPTEN__
@ -309,6 +324,10 @@ absl::Status TensorsToSegmentationCalculator::Process(
if (cc->Inputs().HasTag("OUTPUT_SIZE")) { if (cc->Inputs().HasTag("OUTPUT_SIZE")) {
std::tie(output_width, output_height) = kOutputSizeIn(cc).Get(); std::tie(output_width, output_height) = kOutputSizeIn(cc).Get();
} }
// Use GPU postprocessing on web when Tensor is there already and has <= 12
// categories.
#ifdef __EMSCRIPTEN__
Shape output_shape = { Shape output_shape = {
/* height= */ output_height, /* height= */ output_height,
/* width= */ output_width, /* width= */ output_width,
@ -316,10 +335,6 @@ absl::Status TensorsToSegmentationCalculator::Process(
SegmenterOptions::CATEGORY_MASK SegmenterOptions::CATEGORY_MASK
? 1 ? 1
: input_shape.channels}; : input_shape.channels};
// Use GPU postprocessing on web when Tensor is there already and has <= 12
// categories.
#ifdef __EMSCRIPTEN__
if (input_tensor.ready_as_opengl_texture_2d() && input_shape.channels <= 12) { if (input_tensor.ready_as_opengl_texture_2d() && input_shape.channels <= 12) {
std::vector<std::unique_ptr<Image>> segmented_masks = std::vector<std::unique_ptr<Image>> segmented_masks =
postprocessor_.GetSegmentationResultGpu(input_shape, output_shape, postprocessor_.GetSegmentationResultGpu(input_shape, output_shape,
@ -332,22 +347,53 @@ absl::Status TensorsToSegmentationCalculator::Process(
#endif // __EMSCRIPTEN__ #endif // __EMSCRIPTEN__
// Otherwise, use CPU postprocessing. // Otherwise, use CPU postprocessing.
const float* tensors_buffer = input_tensor.GetCpuReadView().buffer<float>();
// TODO: remove deprecated output type support.
if (options_.segmenter_options().has_output_type()) {
std::vector<Image> segmented_masks = GetSegmentationResultCpu( std::vector<Image> segmented_masks = GetSegmentationResultCpu(
input_shape, output_shape, input_tensor.GetCpuReadView().buffer<float>()); input_shape,
{/* height= */ output_height,
/* width= */ output_width,
/* channels= */ options_.segmenter_options().output_type() ==
SegmenterOptions::CATEGORY_MASK
? 1
: input_shape.channels},
input_tensor.GetCpuReadView().buffer<float>());
for (int i = 0; i < segmented_masks.size(); ++i) { for (int i = 0; i < segmented_masks.size(); ++i) {
kSegmentationOut(cc)[i].Send(std::move(segmented_masks[i])); kSegmentationOut(cc)[i].Send(std::move(segmented_masks[i]));
} }
return absl::OkStatus(); return absl::OkStatus();
} }
std::vector<Image> confidence_masks =
ProcessForConfidenceMaskCpu(input_shape,
{/* height= */ output_height,
/* width= */ output_width,
/* channels= */ input_shape.channels},
options_.segmenter_options(), tensors_buffer);
for (int i = 0; i < confidence_masks.size(); ++i) {
kConfidenceMaskOut(cc)[i].Send(std::move(confidence_masks[i]));
}
if (cc->Outputs().HasTag("CATEGORY_MASK")) {
kCategoryMaskOut(cc).Send(ProcessForCategoryMaskCpu(
input_shape,
{/* height= */ output_height,
/* width= */ output_width,
/* channels= */ 1},
options_.segmenter_options(), tensors_buffer));
}
return absl::OkStatus();
}
std::vector<Image> TensorsToSegmentationCalculator::GetSegmentationResultCpu( std::vector<Image> TensorsToSegmentationCalculator::GetSegmentationResultCpu(
const Shape& input_shape, const Shape& output_shape, const Shape& input_shape, const Shape& output_shape,
const float* tensors_buffer) { const float* tensors_buffer) {
if (options_.segmenter_options().output_type() == if (options_.segmenter_options().output_type() ==
SegmenterOptions::CATEGORY_MASK) { SegmenterOptions::CATEGORY_MASK) {
return ProcessForCategoryMaskCpu(input_shape, output_shape, return {ProcessForCategoryMaskCpu(input_shape, output_shape,
options_.segmenter_options(), options_.segmenter_options(),
tensors_buffer); tensors_buffer)};
} else { } else {
return ProcessForConfidenceMaskCpu(input_shape, output_shape, return ProcessForConfidenceMaskCpu(input_shape, output_shape,
options_.segmenter_options(), options_.segmenter_options(),

View File

@ -79,8 +79,9 @@ void PushTensorsToRunner(int tensor_height, int tensor_width,
std::vector<Packet> GetPackets(const CalculatorRunner& runner) { std::vector<Packet> GetPackets(const CalculatorRunner& runner) {
std::vector<Packet> mask_packets; std::vector<Packet> mask_packets;
for (int i = 0; i < runner.Outputs().NumEntries(); ++i) { for (int i = 0; i < runner.Outputs().NumEntries(); ++i) {
EXPECT_EQ(runner.Outputs().Get("SEGMENTATION", i).packets.size(), 1); EXPECT_EQ(runner.Outputs().Get("CONFIDENCE_MASK", i).packets.size(), 1);
mask_packets.push_back(runner.Outputs().Get("SEGMENTATION", i).packets[0]); mask_packets.push_back(
runner.Outputs().Get("CONFIDENCE_MASK", i).packets[0]);
} }
return mask_packets; return mask_packets;
} }
@ -118,13 +119,10 @@ TEST(TensorsToSegmentationCalculatorTest, FailsInvalidTensorDimensionOne) {
R"pb( R"pb(
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "SEGMENTATION:segmentation" output_stream: "CONFIDENCE_MASK:segmentation"
options { options {
[mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
segmenter_options { segmenter_options { activation: SOFTMAX }
activation: SOFTMAX
output_type: CONFIDENCE_MASK
}
} }
} }
)pb")); )pb"));
@ -145,13 +143,10 @@ TEST(TensorsToSegmentationCalculatorTest, FailsInvalidTensorDimensionFive) {
R"pb( R"pb(
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "SEGMENTATION:segmentation" output_stream: "CONFIDENCE_MASK:segmentation"
options { options {
[mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
segmenter_options { segmenter_options { activation: SOFTMAX }
activation: SOFTMAX
output_type: CONFIDENCE_MASK
}
} }
} }
)pb")); )pb"));
@ -173,16 +168,13 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithSoftmax) {
R"pb( R"pb(
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "SEGMENTATION:0:segmented_mask_0" output_stream: "CONFIDENCE_MASK:0:segmented_mask_0"
output_stream: "SEGMENTATION:1:segmented_mask_1" output_stream: "CONFIDENCE_MASK:1:segmented_mask_1"
output_stream: "SEGMENTATION:2:segmented_mask_2" output_stream: "CONFIDENCE_MASK:2:segmented_mask_2"
output_stream: "SEGMENTATION:3:segmented_mask_3" output_stream: "CONFIDENCE_MASK:3:segmented_mask_3"
options { options {
[mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
segmenter_options { segmenter_options { activation: SOFTMAX }
activation: SOFTMAX
output_type: CONFIDENCE_MASK
}
} }
} }
)pb")); )pb"));
@ -218,16 +210,13 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithNone) {
R"pb( R"pb(
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "SEGMENTATION:0:segmented_mask_0" output_stream: "CONFIDENCE_MASK:0:segmented_mask_0"
output_stream: "SEGMENTATION:1:segmented_mask_1" output_stream: "CONFIDENCE_MASK:1:segmented_mask_1"
output_stream: "SEGMENTATION:2:segmented_mask_2" output_stream: "CONFIDENCE_MASK:2:segmented_mask_2"
output_stream: "SEGMENTATION:3:segmented_mask_3" output_stream: "CONFIDENCE_MASK:3:segmented_mask_3"
options { options {
[mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
segmenter_options { segmenter_options { activation: NONE }
activation: NONE
output_type: CONFIDENCE_MASK
}
} }
} }
)pb")); )pb"));
@ -259,16 +248,13 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithSigmoid) {
R"pb( R"pb(
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "SEGMENTATION:0:segmented_mask_0" output_stream: "CONFIDENCE_MASK:0:segmented_mask_0"
output_stream: "SEGMENTATION:1:segmented_mask_1" output_stream: "CONFIDENCE_MASK:1:segmented_mask_1"
output_stream: "SEGMENTATION:2:segmented_mask_2" output_stream: "CONFIDENCE_MASK:2:segmented_mask_2"
output_stream: "SEGMENTATION:3:segmented_mask_3" output_stream: "CONFIDENCE_MASK:3:segmented_mask_3"
options { options {
[mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
segmenter_options { segmenter_options { activation: SIGMOID }
activation: SIGMOID
output_type: CONFIDENCE_MASK
}
} }
} }
)pb")); )pb"));
@ -301,13 +287,14 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMask) {
R"pb( R"pb(
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "SEGMENTATION:segmentation" output_stream: "CONFIDENCE_MASK:0:segmented_mask_0"
output_stream: "CONFIDENCE_MASK:1:segmented_mask_1"
output_stream: "CONFIDENCE_MASK:2:segmented_mask_2"
output_stream: "CONFIDENCE_MASK:3:segmented_mask_3"
output_stream: "CATEGORY_MASK:segmentation"
options { options {
[mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
segmenter_options { segmenter_options { activation: NONE }
activation: NONE
output_type: CATEGORY_MASK
}
} }
} }
)pb")); )pb"));
@ -318,11 +305,11 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMask) {
tensor_height, tensor_width, tensor_height, tensor_width,
std::vector<float>(kTestValues.begin(), kTestValues.end()), &runner); std::vector<float>(kTestValues.begin(), kTestValues.end()), &runner);
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
ASSERT_EQ(runner.Outputs().NumEntries(), 1); ASSERT_EQ(runner.Outputs().NumEntries(), 5);
// Largest element index is 3. // Largest element index is 3.
const int expected_index = 3; const int expected_index = 3;
const std::vector<int> buffer_indices = {0}; const std::vector<int> buffer_indices = {0};
std::vector<Packet> packets = GetPackets(runner); std::vector<Packet> packets = runner.Outputs().Tag("CATEGORY_MASK").packets;
EXPECT_THAT(packets, testing::ElementsAre( EXPECT_THAT(packets, testing::ElementsAre(
Uint8ImagePacket(tensor_height, tensor_width, Uint8ImagePacket(tensor_height, tensor_width,
expected_index, buffer_indices))); expected_index, buffer_indices)));
@ -335,13 +322,14 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMaskResize) {
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
input_stream: "OUTPUT_SIZE:size" input_stream: "OUTPUT_SIZE:size"
output_stream: "SEGMENTATION:segmentation" output_stream: "CONFIDENCE_MASK:0:segmented_mask_0"
output_stream: "CONFIDENCE_MASK:1:segmented_mask_1"
output_stream: "CONFIDENCE_MASK:2:segmented_mask_2"
output_stream: "CONFIDENCE_MASK:3:segmented_mask_3"
output_stream: "CATEGORY_MASK:segmentation"
options { options {
[mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
segmenter_options { segmenter_options { activation: NONE }
activation: NONE
output_type: CATEGORY_MASK
}
} }
} }
)pb")); )pb"));
@ -367,7 +355,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMaskResize) {
const std::vector<int> buffer_indices = { const std::vector<int> buffer_indices = {
0 * output_width + 0, 0 * output_width + 1, 1 * output_width + 0, 0 * output_width + 0, 0 * output_width + 1, 1 * output_width + 0,
1 * output_width + 1}; 1 * output_width + 1};
std::vector<Packet> packets = GetPackets(runner); std::vector<Packet> packets = runner.Outputs().Tag("CATEGORY_MASK").packets;
EXPECT_THAT(packets, testing::ElementsAre( EXPECT_THAT(packets, testing::ElementsAre(
Uint8ImagePacket(output_height, output_width, Uint8ImagePacket(output_height, output_width,
expected_index, buffer_indices))); expected_index, buffer_indices)));

View File

@ -37,8 +37,10 @@ namespace vision {
namespace image_segmenter { namespace image_segmenter {
namespace { namespace {
constexpr char kSegmentationStreamName[] = "segmented_mask_out"; constexpr char kConfidenceMasksTag[] = "CONFIDENCE_MASKS";
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; constexpr char kConfidenceMasksStreamName[] = "confidence_masks";
constexpr char kCategoryMaskTag[] = "CATEGORY_MASK";
constexpr char kCategoryMaskStreamName[] = "category_mask";
constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageOutStreamName[] = "image_out";
constexpr char kImageTag[] = "IMAGE"; constexpr char kImageTag[] = "IMAGE";
@ -51,7 +53,6 @@ constexpr int kMicroSecondsPerMilliSecond = 1000;
using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::Image; using ::mediapipe::Image;
using ::mediapipe::NormalizedRect; using ::mediapipe::NormalizedRect;
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
image_segmenter::proto::ImageSegmenterGraphOptions; image_segmenter::proto::ImageSegmenterGraphOptions;
@ -59,21 +60,24 @@ using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
// "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph". // "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph".
CalculatorGraphConfig CreateGraphConfig( CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<ImageSegmenterGraphOptionsProto> options, std::unique_ptr<ImageSegmenterGraphOptionsProto> options,
bool enable_flow_limiting) { bool output_category_mask, bool enable_flow_limiting) {
api2::builder::Graph graph; api2::builder::Graph graph;
auto& task_subgraph = graph.AddNode(kSubgraphTypeName); auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap( task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap(
options.get()); options.get());
graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kImageTag).SetName(kImageInStreamName);
graph.In(kNormRectTag).SetName(kNormRectStreamName); graph.In(kNormRectTag).SetName(kNormRectStreamName);
task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >> task_subgraph.Out(kConfidenceMasksTag).SetName(kConfidenceMasksStreamName) >>
graph.Out(kGroupedSegmentationTag); graph.Out(kConfidenceMasksTag);
if (output_category_mask) {
task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
graph.Out(kCategoryMaskTag);
}
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
graph.Out(kImageTag); graph.Out(kImageTag);
if (enable_flow_limiting) { if (enable_flow_limiting) {
return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph, return tasks::core::AddFlowLimiterCalculator(
{kImageTag, kNormRectTag}, graph, task_subgraph, {kImageTag, kNormRectTag}, kConfidenceMasksTag);
kGroupedSegmentationTag);
} }
graph.In(kImageTag) >> task_subgraph.In(kImageTag); graph.In(kImageTag) >> task_subgraph.In(kImageTag);
graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag); graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
@ -91,16 +95,6 @@ ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) {
options_proto->mutable_base_options()->set_use_stream_mode( options_proto->mutable_base_options()->set_use_stream_mode(
options->running_mode != core::RunningMode::IMAGE); options->running_mode != core::RunningMode::IMAGE);
options_proto->set_display_names_locale(options->display_names_locale); options_proto->set_display_names_locale(options->display_names_locale);
switch (options->output_type) {
case ImageSegmenterOptions::OutputType::CATEGORY_MASK:
options_proto->mutable_segmenter_options()->set_output_type(
SegmenterOptions::CATEGORY_MASK);
break;
case ImageSegmenterOptions::OutputType::CONFIDENCE_MASK:
options_proto->mutable_segmenter_options()->set_output_type(
SegmenterOptions::CONFIDENCE_MASK);
break;
}
return options_proto; return options_proto;
} }
@ -145,6 +139,7 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
tasks::core::PacketsCallback packets_callback = nullptr; tasks::core::PacketsCallback packets_callback = nullptr;
if (options->result_callback) { if (options->result_callback) {
auto result_callback = options->result_callback; auto result_callback = options->result_callback;
bool output_category_mask = options->output_category_mask;
packets_callback = packets_callback =
[=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) { [=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
if (!status_or_packets.ok()) { if (!status_or_packets.ok()) {
@ -156,34 +151,41 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
return; return;
} }
Packet segmented_masks = Packet confidence_masks =
status_or_packets.value()[kSegmentationStreamName]; status_or_packets.value()[kConfidenceMasksStreamName];
std::optional<Image> category_mask;
if (output_category_mask) {
category_mask =
status_or_packets.value()[kCategoryMaskStreamName].Get<Image>();
}
Packet image_packet = status_or_packets.value()[kImageOutStreamName]; Packet image_packet = status_or_packets.value()[kImageOutStreamName];
result_callback(segmented_masks.Get<std::vector<Image>>(), result_callback(
{{confidence_masks.Get<std::vector<Image>>(), category_mask}},
image_packet.Get<Image>(), image_packet.Get<Image>(),
segmented_masks.Timestamp().Value() / confidence_masks.Timestamp().Value() /
kMicroSecondsPerMilliSecond); kMicroSecondsPerMilliSecond);
}; };
} }
auto image_segmenter = auto image_segmenter =
core::VisionTaskApiFactory::Create<ImageSegmenter, core::VisionTaskApiFactory::Create<ImageSegmenter,
ImageSegmenterGraphOptionsProto>( ImageSegmenterGraphOptionsProto>(
CreateGraphConfig( CreateGraphConfig(
std::move(options_proto), std::move(options_proto), options->output_category_mask,
options->running_mode == core::RunningMode::LIVE_STREAM), options->running_mode == core::RunningMode::LIVE_STREAM),
std::move(options->base_options.op_resolver), options->running_mode, std::move(options->base_options.op_resolver), options->running_mode,
std::move(packets_callback)); std::move(packets_callback));
if (!image_segmenter.ok()) { if (!image_segmenter.ok()) {
return image_segmenter.status(); return image_segmenter.status();
} }
image_segmenter.value()->output_category_mask_ =
options->output_category_mask;
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
(*image_segmenter)->labels_, (*image_segmenter)->labels_,
GetLabelsFromGraphConfig((*image_segmenter)->runner_->GetGraphConfig())); GetLabelsFromGraphConfig((*image_segmenter)->runner_->GetGraphConfig()));
return image_segmenter; return image_segmenter;
} }
absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment( absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
mediapipe::Image image, mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
@ -201,11 +203,17 @@ absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(
{{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))}, {{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))},
{kNormRectStreamName, {kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))}})); MakePacket<NormalizedRect>(std::move(norm_rect))}}));
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>(); std::vector<Image> confidence_masks =
output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>();
std::optional<Image> category_mask;
if (output_category_mask_) {
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
}
return {{confidence_masks, category_mask}};
} }
absl::StatusOr<std::vector<Image>> ImageSegmenter::SegmentForVideo( absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
mediapipe::Image image, int64 timestamp_ms, mediapipe::Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
@ -225,11 +233,17 @@ absl::StatusOr<std::vector<Image>> ImageSegmenter::SegmentForVideo(
{kNormRectStreamName, {kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect)) MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>(); std::vector<Image> confidence_masks =
output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>();
std::optional<Image> category_mask;
if (output_category_mask_) {
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
}
return {{confidence_masks, category_mask}};
} }
absl::Status ImageSegmenter::SegmentAsync( absl::Status ImageSegmenter::SegmentAsync(
Image image, int64 timestamp_ms, Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"
#include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/register.h"
namespace mediapipe { namespace mediapipe {
@ -52,23 +53,14 @@ struct ImageSegmenterOptions {
// Metadata, if any. Defaults to English. // Metadata, if any. Defaults to English.
std::string display_names_locale = "en"; std::string display_names_locale = "en";
// The output type of segmentation results. // Whether to output a category mask.
enum OutputType { bool output_category_mask = false;
// Gives a single output mask where each pixel represents the class which
// the pixel in the original image was predicted to belong to.
CATEGORY_MASK = 0,
// Gives a list of output masks where, for each mask, each pixel represents
// the prediction confidence, usually in the [0, 1] range.
CONFIDENCE_MASK = 1,
};
OutputType output_type = OutputType::CATEGORY_MASK;
// The user-defined result callback for processing live stream data. // The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set // The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM. // to RunningMode::LIVE_STREAM.
std::function<void(absl::StatusOr<std::vector<mediapipe::Image>>, std::function<void(absl::StatusOr<ImageSegmenterResult>, const Image&,
const Image&, int64)> int64_t)>
result_callback = nullptr; result_callback = nullptr;
}; };
@ -84,13 +76,9 @@ struct ImageSegmenterOptions {
// 1 or 3). // 1 or 3).
// - if type is kTfLiteFloat32, NormalizationOptions are required to be // - if type is kTfLiteFloat32, NormalizationOptions are required to be
// attached to the metadata for input normalization. // attached to the metadata for input normalization.
// Output tensors: // Output ImageSegmenterResult:
// (kTfLiteUInt8/kTfLiteFloat32) // Provides confidence masks and an optional category mask if
// - list of segmented masks. // `output_category_mask` is set to true.
// - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1.
// - if `output_type` is CONFIDENCE_MASK, float32 Image list of size
// `channels`.
// - batch is always 1
// An example of such model can be found at: // An example of such model can be found at:
// https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2 // https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2
class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
@ -114,12 +102,8 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// setting its 'rotation_degrees' field. Note that specifying a // setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported // region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned. // and will result in an invalid argument error being returned.
//
// If the output_type is CATEGORY_MASK, the returned vector of images is absl::StatusOr<ImageSegmenterResult> Segment(
// per-category segmented image mask.
// If the output_type is CONFIDENCE_MASK, the returned vector of images
// contains only one confidence image mask.
absl::StatusOr<std::vector<mediapipe::Image>> Segment(
mediapipe::Image image, mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);
@ -137,13 +121,8 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// setting its 'rotation_degrees' field. Note that specifying a // setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported // region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned. // and will result in an invalid argument error being returned.
// absl::StatusOr<ImageSegmenterResult> SegmentForVideo(
// If the output_type is CATEGORY_MASK, the returned vector of images is mediapipe::Image image, int64_t timestamp_ms,
// per-category segmented image mask.
// If the output_type is CONFIDENCE_MASK, the returned vector of images
// contains only one confidence image mask.
absl::StatusOr<std::vector<mediapipe::Image>> SegmentForVideo(
mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);
@ -164,17 +143,13 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// and will result in an invalid argument error being returned. // and will result in an invalid argument error being returned.
// //
// The "result_callback" provides // The "result_callback" provides
// - A vector of segmented image masks. // - An ImageSegmenterResult.
// If the output_type is CATEGORY_MASK, the returned vector of images is
// per-category segmented image mask.
// If the output_type is CONFIDENCE_MASK, the returned vector of images
// contains only one confidence image mask.
// - The const reference to the corresponding input image that the image // - The const reference to the corresponding input image that the image
// segmentation runs on. Note that the const reference to the image will // segmentation runs on. Note that the const reference to the image will
// no longer be valid when the callback returns. To access the image data // no longer be valid when the callback returns. To access the image data
// outside of the callback, callers need to make a copy of the image. // outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds. // - The input timestamp in milliseconds.
absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms, absl::Status SegmentAsync(mediapipe::Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> std::optional<core::ImageProcessingOptions>
image_processing_options = std::nullopt); image_processing_options = std::nullopt);
@ -182,9 +157,9 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
absl::Status Close() { return runner_->Close(); } absl::Status Close() { return runner_->Close(); }
// Get the category label list that the ImageSegmenter can recognize. For // Get the category label list that the ImageSegmenter can recognize. For
// CATEGORY_MASK type, the index in the category mask corresponds to the // CATEGORY_MASK, the index in the category mask corresponds to the category
// category in the label list. For CONFIDENCE_MASK type, the output mask list // in the label list. For CONFIDENCE_MASK, the output mask list at index
// at index corresponds to the category in the label list. // corresponds to the category in the label list.
// //
// If there is no labelmap provided in the model file, empty label list is // If there is no labelmap provided in the model file, empty label list is
// returned. // returned.
@ -192,6 +167,7 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
private: private:
std::vector<std::string> labels_; std::vector<std::string> labels_;
bool output_category_mask_;
}; };
} // namespace image_segmenter } // namespace image_segmenter

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <memory> #include <memory>
#include <optional>
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
@ -42,6 +43,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
#include "mediapipe/tasks/metadata/image_segmenter_metadata_schema_generated.h" #include "mediapipe/tasks/metadata/image_segmenter_metadata_schema_generated.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h"
#include "mediapipe/util/graph_builder_utils.h"
#include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/label_map_util.h" #include "mediapipe/util/label_map_util.h"
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h"
@ -65,10 +67,13 @@ using ::mediapipe::tasks::vision::image_segmenter::proto::
ImageSegmenterGraphOptions; ImageSegmenterGraphOptions;
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
using ::tflite::TensorMetadata; using ::tflite::TensorMetadata;
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>; using LabelItems = mediapipe::proto_ns::Map<int64_t, ::mediapipe::LabelMapItem>;
constexpr char kSegmentationTag[] = "SEGMENTATION"; constexpr char kSegmentationTag[] = "SEGMENTATION";
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kConfidenceMaskTag[] = "CONFIDENCE_MASK";
constexpr char kConfidenceMasksTag[] = "CONFIDENCE_MASKS";
constexpr char kCategoryMaskTag[] = "CATEGORY_MASK";
constexpr char kImageTag[] = "IMAGE"; constexpr char kImageTag[] = "IMAGE";
constexpr char kImageCpuTag[] = "IMAGE_CPU"; constexpr char kImageCpuTag[] = "IMAGE_CPU";
constexpr char kImageGpuTag[] = "IMAGE_GPU"; constexpr char kImageGpuTag[] = "IMAGE_GPU";
@ -80,7 +85,9 @@ constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA";
// Struct holding the different output streams produced by the image segmenter // Struct holding the different output streams produced by the image segmenter
// subgraph. // subgraph.
struct ImageSegmenterOutputs { struct ImageSegmenterOutputs {
std::vector<Source<Image>> segmented_masks; std::optional<std::vector<Source<Image>>> segmented_masks;
std::optional<std::vector<Source<Image>>> confidence_masks;
std::optional<Source<Image>> category_mask;
// The same as the input image, mainly used for live stream mode. // The same as the input image, mainly used for live stream mode.
Source<Image> image; Source<Image> image;
}; };
@ -95,7 +102,9 @@ struct ImageAndTensorsOnDevice {
} // namespace } // namespace
absl::Status SanityCheckOptions(const ImageSegmenterGraphOptions& options) { absl::Status SanityCheckOptions(const ImageSegmenterGraphOptions& options) {
if (options.segmenter_options().output_type() == // TODO: remove deprecated output type support.
if (options.segmenter_options().has_output_type() &&
options.segmenter_options().output_type() ==
SegmenterOptions::UNSPECIFIED) { SegmenterOptions::UNSPECIFIED) {
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
"`output_type` must not be UNSPECIFIED", "`output_type` must not be UNSPECIFIED",
@ -133,9 +142,8 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
const core::ModelResources& model_resources, const core::ModelResources& model_resources,
TensorsToSegmentationCalculatorOptions* options) { TensorsToSegmentationCalculatorOptions* options) {
// Set default activation function NONE // Set default activation function NONE
options->mutable_segmenter_options()->set_output_type( options->mutable_segmenter_options()->CopyFrom(
segmenter_option.segmenter_options().output_type()); segmenter_option.segmenter_options());
options->mutable_segmenter_options()->set_activation(SegmenterOptions::NONE);
// Find the custom metadata of ImageSegmenterOptions type in model metadata. // Find the custom metadata of ImageSegmenterOptions type in model metadata.
const auto* metadata_extractor = model_resources.GetMetadataExtractor(); const auto* metadata_extractor = model_resources.GetMetadataExtractor();
bool found_activation_in_metadata = false; bool found_activation_in_metadata = false;
@ -317,12 +325,14 @@ absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
} }
} }
// An "mediapipe.tasks.vision.ImageSegmenterGraph" performs semantic // An "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph" performs
// segmentation. // semantic segmentation. The graph always outputs confidence masks, and an
// Two kinds of outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION. // optional category mask if CATEGORY_MASK is connected.
// Users can retrieve segmented mask of only particular category/channel from //
// SEGMENTATION, and users can also get all segmented masks from // Two kinds of outputs for confidence mask are provided: CONFIDENCE_MASK and
// GROUPED_SEGMENTATION. // CONFIDENCE_MASKS. Users can retrieve segmented mask of only particular
// category/channel from CONFIDENCE_MASK, and users can also get all segmented
// confidence masks from CONFIDENCE_MASKS.
// - Accepts CPU input images and outputs segmented masks on CPU. // - Accepts CPU input images and outputs segmented masks on CPU.
// //
// Inputs: // Inputs:
@ -334,11 +344,13 @@ absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
// @Optional: rect covering the whole image is used if not specified. // @Optional: rect covering the whole image is used if not specified.
// //
// Outputs: // Outputs:
// SEGMENTATION - mediapipe::Image @Multiple // CONFIDENCE_MASK - mediapipe::Image @Multiple
// Segmented masks for individual category. Segmented mask of single // Confidence masks for individual category. Confidence mask of single
// category can be accessed by index based output stream. // category can be accessed by index based output stream.
// GROUPED_SEGMENTATION - std::vector<mediapipe::Image> // CONFIDENCE_MASKS - std::vector<mediapipe::Image>
// The output segmented masks grouped in a vector. // The output confidence masks grouped in a vector.
// CATEGORY_MASK - mediapipe::Image @Optional
// Optional Category mask.
// IMAGE - mediapipe::Image // IMAGE - mediapipe::Image
// The image that image segmenter runs on. // The image that image segmenter runs on.
// //
@ -369,23 +381,39 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
ASSIGN_OR_RETURN(const auto* model_resources, ASSIGN_OR_RETURN(const auto* model_resources,
CreateModelResources<ImageSegmenterGraphOptions>(sc)); CreateModelResources<ImageSegmenterGraphOptions>(sc));
Graph graph; Graph graph;
const auto& options = sc->Options<ImageSegmenterGraphOptions>();
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_streams, auto output_streams,
BuildSegmentationTask( BuildSegmentationTask(
sc->Options<ImageSegmenterGraphOptions>(), *model_resources, options, *model_resources, graph[Input<Image>(kImageTag)],
graph[Input<Image>(kImageTag)], graph[Input<NormalizedRect>::Optional(kNormRectTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph)); HasOutput(sc->OriginalNode(), kCategoryMaskTag), graph));
auto& merge_images_to_vector = auto& merge_images_to_vector =
graph.AddNode("MergeImagesToVectorCalculator"); graph.AddNode("MergeImagesToVectorCalculator");
for (int i = 0; i < output_streams.segmented_masks.size(); ++i) { // TODO: remove deprecated output type support.
output_streams.segmented_masks[i] >> if (options.segmenter_options().has_output_type()) {
for (int i = 0; i < output_streams.segmented_masks->size(); ++i) {
output_streams.segmented_masks->at(i) >>
merge_images_to_vector[Input<Image>::Multiple("")][i]; merge_images_to_vector[Input<Image>::Multiple("")][i];
output_streams.segmented_masks[i] >> output_streams.segmented_masks->at(i) >>
graph[Output<Image>::Multiple(kSegmentationTag)][i]; graph[Output<Image>::Multiple(kSegmentationTag)][i];
} }
merge_images_to_vector.Out("") >> merge_images_to_vector.Out("") >>
graph[Output<std::vector<Image>>(kGroupedSegmentationTag)]; graph[Output<std::vector<Image>>(kGroupedSegmentationTag)];
} else {
for (int i = 0; i < output_streams.confidence_masks->size(); ++i) {
output_streams.confidence_masks->at(i) >>
merge_images_to_vector[Input<Image>::Multiple("")][i];
output_streams.confidence_masks->at(i) >>
graph[Output<Image>::Multiple(kConfidenceMaskTag)][i];
}
merge_images_to_vector.Out("") >>
graph[Output<std::vector<Image>>(kConfidenceMasksTag)];
if (output_streams.category_mask) {
*output_streams.category_mask >> graph[Output<Image>(kCategoryMaskTag)];
}
}
output_streams.image >> graph[Output<Image>(kImageTag)]; output_streams.image >> graph[Output<Image>(kImageTag)];
return graph.GetConfig(); return graph.GetConfig();
} }
@ -403,7 +431,8 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask( absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask(
const ImageSegmenterGraphOptions& task_options, const ImageSegmenterGraphOptions& task_options,
const core::ModelResources& model_resources, Source<Image> image_in, const core::ModelResources& model_resources, Source<Image> image_in,
Source<NormalizedRect> norm_rect_in, Graph& graph) { Source<NormalizedRect> norm_rect_in, bool output_category_mask,
Graph& graph) {
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
// Adds preprocessing calculators and connects them to the graph input image // Adds preprocessing calculators and connects them to the graph input image
@ -435,6 +464,8 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
image_properties.Out("SIZE") >> tensor_to_images.In(kOutputSizeTag); image_properties.Out("SIZE") >> tensor_to_images.In(kOutputSizeTag);
// Exports multiple segmented masks. // Exports multiple segmented masks.
// TODO: remove deprecated output type support.
if (task_options.segmenter_options().has_output_type()) {
std::vector<Source<Image>> segmented_masks; std::vector<Source<Image>> segmented_masks;
if (task_options.segmenter_options().output_type() == if (task_options.segmenter_options().output_type() ==
SegmenterOptions::CATEGORY_MASK) { SegmenterOptions::CATEGORY_MASK) {
@ -450,7 +481,29 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
} }
} }
return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks, return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks,
/*confidence_masks=*/std::nullopt,
/*category_mask=*/std::nullopt,
/*image=*/image_and_tensors.image}; /*image=*/image_and_tensors.image};
} else {
ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor,
GetOutputTensor(model_resources));
int segmentation_streams_num = *output_tensor->shape()->rbegin();
std::vector<Source<Image>> confidence_masks;
confidence_masks.reserve(segmentation_streams_num);
for (int i = 0; i < segmentation_streams_num; ++i) {
confidence_masks.push_back(Source<Image>(
tensor_to_images[Output<Image>::Multiple(kConfidenceMaskTag)][i]));
}
return ImageSegmenterOutputs{
/*segmented_masks=*/std::nullopt,
/*confidence_masks=*/confidence_masks,
/*category_mask=*/
output_category_mask
? std::make_optional(
tensor_to_images[Output<Image>(kCategoryMaskTag)])
: std::nullopt,
/*image=*/image_and_tensors.image};
}
} }
}; };

View File

@ -0,0 +1,43 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_RESULT_H_
#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_RESULT_H_
#include <optional>
#include <vector>

#include "mediapipe/framework/formats/image.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace image_segmenter {
// The output result of ImageSegmenter.
//
// Bundles the confidence masks with an optional category mask; per the
// ImageSegmenterOptions documentation, the category mask is provided only when
// `output_category_mask` is set to true.
struct ImageSegmenterResult {
// Multiple masks of float image in VEC32F1 format where, for each mask, each
// pixel represents the prediction confidence, usually in the [0, 1] range.
// Per the ImageSegmenter label-list documentation, the mask at a given index
// corresponds to the category at the same index in the label list.
std::vector<Image> confidence_masks;
// A category mask of uint8 image in GRAY8 format where each pixel represents
// the class which the pixel in the original image was predicted to belong to.
// Empty (std::nullopt) when no category mask was produced, so callers should
// check has_value() before dereferencing.
std::optional<Image> category_mask;
};
} // namespace image_segmenter
} // namespace vision
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_RESULT_H_

View File

@ -36,6 +36,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
@ -256,7 +257,6 @@ TEST(GetLabelsTest, SucceedsWithLabelsInModel) {
auto options = std::make_unique<ImageSegmenterOptions>(); auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
@ -278,15 +278,14 @@ TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
auto options = std::make_unique<ImageSegmenterOptions>(); auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; options->output_category_mask = true;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto category_masks, segmenter->Segment(image)); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
EXPECT_EQ(category_masks.size(), 1); EXPECT_TRUE(result.category_mask.has_value());
cv::Mat actual_mask = mediapipe::formats::MatView( cv::Mat actual_mask = mediapipe::formats::MatView(
category_masks[0].GetImageFrameSharedPtr().get()); result.category_mask->GetImageFrameSharedPtr().get());
cv::Mat expected_mask = cv::imread( cv::Mat expected_mask = cv::imread(
JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"), JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"),
@ -303,12 +302,11 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
auto options = std::make_unique<ImageSegmenterOptions>(); auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
EXPECT_EQ(confidence_masks.size(), 21); EXPECT_EQ(result.confidence_masks.size(), 21);
cv::Mat expected_mask = cv::imread( cv::Mat expected_mask = cv::imread(
JoinPath("./", kTestDataDirectory, "cat_mask.jpg"), cv::IMREAD_GRAYSCALE); JoinPath("./", kTestDataDirectory, "cat_mask.jpg"), cv::IMREAD_GRAYSCALE);
@ -317,7 +315,7 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
// Cat category index 8. // Cat category index 8.
cv::Mat cat_mask = mediapipe::formats::MatView( cv::Mat cat_mask = mediapipe::formats::MatView(
confidence_masks[8].GetImageFrameSharedPtr().get()); result.confidence_masks[8].GetImageFrameSharedPtr().get());
EXPECT_THAT(cat_mask, EXPECT_THAT(cat_mask,
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
} }
@ -331,15 +329,14 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
auto options = std::make_unique<ImageSegmenterOptions>(); auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
ImageProcessingOptions image_processing_options; ImageProcessingOptions image_processing_options;
image_processing_options.rotation_degrees = -90; image_processing_options.rotation_degrees = -90;
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, MP_ASSERT_OK_AND_ASSIGN(auto result,
segmenter->Segment(image, image_processing_options)); segmenter->Segment(image, image_processing_options));
EXPECT_EQ(confidence_masks.size(), 21); EXPECT_EQ(result.confidence_masks.size(), 21);
cv::Mat expected_mask = cv::Mat expected_mask =
cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"), cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"),
@ -349,7 +346,7 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
// Cat category index 8. // Cat category index 8.
cv::Mat cat_mask = mediapipe::formats::MatView( cv::Mat cat_mask = mediapipe::formats::MatView(
confidence_masks[8].GetImageFrameSharedPtr().get()); result.confidence_masks[8].GetImageFrameSharedPtr().get());
EXPECT_THAT(cat_mask, EXPECT_THAT(cat_mask,
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
} }
@ -361,7 +358,6 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
auto options = std::make_unique<ImageSegmenterOptions>(); auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
@ -384,12 +380,11 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
auto options = std::make_unique<ImageSegmenterOptions>(); auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata); JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
EXPECT_EQ(confidence_masks.size(), 2); EXPECT_EQ(result.confidence_masks.size(), 2);
cv::Mat expected_mask = cv::Mat expected_mask =
cv::imread(JoinPath("./", kTestDataDirectory, cv::imread(JoinPath("./", kTestDataDirectory,
@ -400,7 +395,7 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
// Selfie category index 1. // Selfie category index 1.
cv::Mat selfie_mask = mediapipe::formats::MatView( cv::Mat selfie_mask = mediapipe::formats::MatView(
confidence_masks[1].GetImageFrameSharedPtr().get()); result.confidence_masks[1].GetImageFrameSharedPtr().get());
EXPECT_THAT(selfie_mask, EXPECT_THAT(selfie_mask,
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
} }
@ -411,11 +406,10 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
auto options = std::make_unique<ImageSegmenterOptions>(); auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata); JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
EXPECT_EQ(confidence_masks.size(), 1); EXPECT_EQ(result.confidence_masks.size(), 1);
cv::Mat expected_mask = cv::Mat expected_mask =
cv::imread(JoinPath("./", kTestDataDirectory, cv::imread(JoinPath("./", kTestDataDirectory,
@ -425,7 +419,7 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
cv::Mat selfie_mask = mediapipe::formats::MatView( cv::Mat selfie_mask = mediapipe::formats::MatView(
confidence_masks[0].GetImageFrameSharedPtr().get()); result.confidence_masks[0].GetImageFrameSharedPtr().get());
EXPECT_THAT(selfie_mask, EXPECT_THAT(selfie_mask,
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
} }
@ -436,12 +430,11 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) {
auto options = std::make_unique<ImageSegmenterOptions>(); auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfieSegmentation); JoinPath("./", kTestDataDirectory, kSelfieSegmentation);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
EXPECT_EQ(confidence_masks.size(), 1); EXPECT_EQ(result.confidence_masks.size(), 1);
MP_ASSERT_OK(segmenter->Close()); MP_ASSERT_OK(segmenter->Close());
cv::Mat expected_mask = cv::imread( cv::Mat expected_mask = cv::imread(
@ -452,7 +445,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) {
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
cv::Mat selfie_mask = mediapipe::formats::MatView( cv::Mat selfie_mask = mediapipe::formats::MatView(
confidence_masks[0].GetImageFrameSharedPtr().get()); result.confidence_masks[0].GetImageFrameSharedPtr().get());
EXPECT_THAT(selfie_mask, EXPECT_THAT(selfie_mask,
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
} }
@ -463,16 +456,15 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationCategoryMask) {
auto options = std::make_unique<ImageSegmenterOptions>(); auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfieSegmentation); JoinPath("./", kTestDataDirectory, kSelfieSegmentation);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; options->output_category_mask = true;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto category_mask, segmenter->Segment(image)); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
EXPECT_EQ(category_mask.size(), 1); EXPECT_TRUE(result.category_mask.has_value());
MP_ASSERT_OK(segmenter->Close()); MP_ASSERT_OK(segmenter->Close());
cv::Mat selfie_mask = mediapipe::formats::MatView( cv::Mat selfie_mask = mediapipe::formats::MatView(
category_mask[0].GetImageFrameSharedPtr().get()); result.category_mask->GetImageFrameSharedPtr().get());
cv::Mat expected_mask = cv::imread( cv::Mat expected_mask = cv::imread(
JoinPath("./", kTestDataDirectory, JoinPath("./", kTestDataDirectory,
"portrait_selfie_segmentation_expected_category_mask.jpg"), "portrait_selfie_segmentation_expected_category_mask.jpg"),
@ -487,16 +479,15 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) {
auto options = std::make_unique<ImageSegmenterOptions>(); auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfieSegmentationLandscape); JoinPath("./", kTestDataDirectory, kSelfieSegmentationLandscape);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; options->output_category_mask = true;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto category_mask, segmenter->Segment(image)); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
EXPECT_EQ(category_mask.size(), 1); EXPECT_TRUE(result.category_mask.has_value());
MP_ASSERT_OK(segmenter->Close()); MP_ASSERT_OK(segmenter->Close());
cv::Mat selfie_mask = mediapipe::formats::MatView( cv::Mat selfie_mask = mediapipe::formats::MatView(
category_mask[0].GetImageFrameSharedPtr().get()); result.category_mask->GetImageFrameSharedPtr().get());
cv::Mat expected_mask = cv::imread( cv::Mat expected_mask = cv::imread(
JoinPath( JoinPath(
"./", kTestDataDirectory, "./", kTestDataDirectory,
@ -512,14 +503,13 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) {
auto options = std::make_unique<ImageSegmenterOptions>(); auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kHairSegmentationWithMetadata); JoinPath("./", kTestDataDirectory, kHairSegmentationWithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
EXPECT_EQ(confidence_masks.size(), 2); EXPECT_EQ(result.confidence_masks.size(), 2);
cv::Mat hair_mask = mediapipe::formats::MatView( cv::Mat hair_mask = mediapipe::formats::MatView(
confidence_masks[1].GetImageFrameSharedPtr().get()); result.confidence_masks[1].GetImageFrameSharedPtr().get());
MP_ASSERT_OK(segmenter->Close()); MP_ASSERT_OK(segmenter->Close());
cv::Mat expected_mask = cv::imread( cv::Mat expected_mask = cv::imread(
JoinPath("./", kTestDataDirectory, "portrait_hair_expected_mask.jpg"), JoinPath("./", kTestDataDirectory, "portrait_hair_expected_mask.jpg"),
@ -540,7 +530,6 @@ TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
auto options = std::make_unique<ImageSegmenterOptions>(); auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
options->running_mode = core::RunningMode::VIDEO; options->running_mode = core::RunningMode::VIDEO;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
@ -572,7 +561,7 @@ TEST_F(VideoModeTest, Succeeds) {
auto options = std::make_unique<ImageSegmenterOptions>(); auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; options->output_category_mask = true;
options->running_mode = core::RunningMode::VIDEO; options->running_mode = core::RunningMode::VIDEO;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
@ -580,11 +569,10 @@ TEST_F(VideoModeTest, Succeeds) {
JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"), JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"),
cv::IMREAD_GRAYSCALE); cv::IMREAD_GRAYSCALE);
for (int i = 0; i < iterations; ++i) { for (int i = 0; i < iterations; ++i) {
MP_ASSERT_OK_AND_ASSIGN(auto category_masks, MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->SegmentForVideo(image, i));
segmenter->SegmentForVideo(image, i)); EXPECT_TRUE(result.category_mask.has_value());
EXPECT_EQ(category_masks.size(), 1);
cv::Mat actual_mask = mediapipe::formats::MatView( cv::Mat actual_mask = mediapipe::formats::MatView(
category_masks[0].GetImageFrameSharedPtr().get()); result.category_mask->GetImageFrameSharedPtr().get());
EXPECT_THAT(actual_mask, EXPECT_THAT(actual_mask,
SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity,
kGoldenMaskMagnificationFactor)); kGoldenMaskMagnificationFactor));
@ -601,11 +589,10 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
auto options = std::make_unique<ImageSegmenterOptions>(); auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
options->running_mode = core::RunningMode::LIVE_STREAM; options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback = options->result_callback =
[](absl::StatusOr<std::vector<Image>> segmented_masks, const Image& image, [](absl::StatusOr<ImageSegmenterResult> segmented_masks,
int64 timestamp_ms) {}; const Image& image, int64_t timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
@ -634,11 +621,9 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
auto options = std::make_unique<ImageSegmenterOptions>(); auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
options->running_mode = core::RunningMode::LIVE_STREAM; options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback = options->result_callback = [](absl::StatusOr<ImageSegmenterResult> result,
[](absl::StatusOr<std::vector<Image>> segmented_masks, const Image& image, const Image& image, int64_t timestamp_ms) {};
int64 timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK(segmenter->SegmentAsync(image, 1)); MP_ASSERT_OK(segmenter->SegmentAsync(image, 1));
@ -660,20 +645,20 @@ TEST_F(LiveStreamModeTest, Succeeds) {
Image image, Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
"segmentation_input_rotation0.jpg"))); "segmentation_input_rotation0.jpg")));
std::vector<std::vector<Image>> segmented_masks_results; std::vector<Image> segmented_masks_results;
std::vector<std::pair<int, int>> image_sizes; std::vector<std::pair<int, int>> image_sizes;
std::vector<int64> timestamps; std::vector<int64_t> timestamps;
auto options = std::make_unique<ImageSegmenterOptions>(); auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; options->output_category_mask = true;
options->running_mode = core::RunningMode::LIVE_STREAM; options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback = options->result_callback = [&segmented_masks_results, &image_sizes,
[&segmented_masks_results, &image_sizes, &timestamps]( &timestamps](
absl::StatusOr<std::vector<Image>> segmented_masks, absl::StatusOr<ImageSegmenterResult> result,
const Image& image, int64 timestamp_ms) { const Image& image, int64_t timestamp_ms) {
MP_ASSERT_OK(segmented_masks.status()); MP_ASSERT_OK(result.status());
segmented_masks_results.push_back(std::move(segmented_masks).value()); segmented_masks_results.push_back(std::move(*result->category_mask));
image_sizes.push_back({image.width(), image.height()}); image_sizes.push_back({image.width(), image.height()});
timestamps.push_back(timestamp_ms); timestamps.push_back(timestamp_ms);
}; };
@ -690,10 +675,9 @@ TEST_F(LiveStreamModeTest, Succeeds) {
cv::Mat expected_mask = cv::imread( cv::Mat expected_mask = cv::imread(
JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"), JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"),
cv::IMREAD_GRAYSCALE); cv::IMREAD_GRAYSCALE);
for (const auto& segmented_masks : segmented_masks_results) { for (const auto& category_mask : segmented_masks_results) {
EXPECT_EQ(segmented_masks.size(), 1);
cv::Mat actual_mask = mediapipe::formats::MatView( cv::Mat actual_mask = mediapipe::formats::MatView(
segmented_masks[0].GetImageFrameSharedPtr().get()); category_mask.GetImageFrameSharedPtr().get());
EXPECT_THAT(actual_mask, EXPECT_THAT(actual_mask,
SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity,
kGoldenMaskMagnificationFactor)); kGoldenMaskMagnificationFactor));
@ -702,7 +686,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
EXPECT_EQ(image_size.first, image.width()); EXPECT_EQ(image_size.first, image.width());
EXPECT_EQ(image_size.second, image.height()); EXPECT_EQ(image_size.second, image.height());
} }
int64 timestamp_ms = -1; int64_t timestamp_ms = -1;
for (const auto& timestamp : timestamps) { for (const auto& timestamp : timestamps) {
EXPECT_GT(timestamp, timestamp_ms); EXPECT_GT(timestamp, timestamp_ms);
timestamp_ms = timestamp; timestamp_ms = timestamp;

View File

@ -33,7 +33,7 @@ message SegmenterOptions {
CONFIDENCE_MASK = 2; CONFIDENCE_MASK = 2;
} }
// Optional output mask type. // Optional output mask type.
optional OutputType output_type = 1 [default = CATEGORY_MASK]; optional OutputType output_type = 1 [deprecated = true];
// Supported activation functions for filtering. // Supported activation functions for filtering.
enum Activation { enum Activation {