update ImageSegmenterGraph to always output confidence mask and optionally output category mask

PiperOrigin-RevId: 521679910
This commit is contained in:
MediaPipe Team 2023-04-04 00:23:03 -07:00 committed by Copybara-Service
parent c31a4681e5
commit 367ccbfdf3
9 changed files with 370 additions and 257 deletions

View File

@ -16,6 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "image_segmenter_result",
hdrs = ["image_segmenter_result.h"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework/formats:image"],
)
# Docs for Mediapipe Tasks Image Segmenter
# https://developers.google.com/mediapipe/solutions/vision/image_segmenter
cc_library(
@ -25,6 +32,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":image_segmenter_graph",
":image_segmenter_result",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image",
@ -82,6 +90,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"//mediapipe/tasks/metadata:image_segmenter_metadata_schema_cc",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"//mediapipe/util:graph_builder_utils",
"//mediapipe/util:label_map_cc_proto",
"//mediapipe/util:label_map_util",
"@com_google_absl//absl/status",

View File

@ -80,7 +80,7 @@ void Sigmoid(absl::Span<const float> values,
[](float value) { return 1. / (1 + std::exp(-value)); });
}
std::vector<Image> ProcessForCategoryMaskCpu(const Shape& input_shape,
Image ProcessForCategoryMaskCpu(const Shape& input_shape,
const Shape& output_shape,
const SegmenterOptions& options,
const float* tensors_buffer) {
@ -135,7 +135,7 @@ std::vector<Image> ProcessForCategoryMaskCpu(const Shape& input_shape,
pixel = maximum_category_idx;
}
});
return {category_mask};
return category_mask;
}
std::vector<Image> ProcessForConfidenceMaskCpu(const Shape& input_shape,
@ -209,7 +209,9 @@ std::vector<Image> ProcessForConfidenceMaskCpu(const Shape& input_shape,
} // namespace
// Converts Tensors from a vector of Tensor to Segmentation.
// Converts Tensors from a vector of Tensor to Segmentation masks. The
// calculator always output confidence masks, and an optional category mask if
// CATEGORY_MASK is connected.
//
// Performs optional resizing to OUTPUT_SIZE dimension if provided,
// otherwise the segmented masks is the same size as input tensor.
@ -221,7 +223,12 @@ std::vector<Image> ProcessForConfidenceMaskCpu(const Shape& input_shape,
// the size to resize masks to.
//
// Output:
// Segmentation: Segmentation proto.
// CONFIDENCE_MASK @Multiple: Multiple masks of float image where, for each
// mask, each pixel represents the prediction confidence, usually in the [0,
// 1] range.
// CATEGORY_MASK @Optional: A category mask of uint8 image where each pixel
// represents the class which the pixel in the original image was predicted to
// belong to.
//
// Options:
// See tensors_to_segmentation_calculator.proto
@ -231,13 +238,13 @@ std::vector<Image> ProcessForConfidenceMaskCpu(const Shape& input_shape,
// calculator: "TensorsToSegmentationCalculator"
// input_stream: "TENSORS:tensors"
// input_stream: "OUTPUT_SIZE:size"
// output_stream: "SEGMENTATION:0:segmentation"
// output_stream: "SEGMENTATION:1:segmentation"
// output_stream: "CONFIDENCE_MASK:0:confidence_mask"
// output_stream: "CONFIDENCE_MASK:1:confidence_mask"
// output_stream: "CATEGORY_MASK:category_mask"
// options {
// [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
// segmenter_options {
// activation: SOFTMAX
// output_type: CONFIDENCE_MASK
// }
// }
// }
@ -248,7 +255,11 @@ class TensorsToSegmentationCalculator : public Node {
static constexpr Input<std::pair<int, int>>::Optional kOutputSizeIn{
"OUTPUT_SIZE"};
static constexpr Output<Image>::Multiple kSegmentationOut{"SEGMENTATION"};
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut);
static constexpr Output<Image>::Multiple kConfidenceMaskOut{
"CONFIDENCE_MASK"};
static constexpr Output<Image>::Optional kCategoryMaskOut{"CATEGORY_MASK"};
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut,
kConfidenceMaskOut, kCategoryMaskOut);
static absl::Status UpdateContract(CalculatorContract* cc);
@ -279,9 +290,13 @@ absl::Status TensorsToSegmentationCalculator::UpdateContract(
absl::Status TensorsToSegmentationCalculator::Open(
mediapipe::CalculatorContext* cc) {
options_ = cc->Options<TensorsToSegmentationCalculatorOptions>();
// TODO: remove deprecated output type support.
if (options_.segmenter_options().has_output_type()) {
RET_CHECK_NE(options_.segmenter_options().output_type(),
SegmenterOptions::UNSPECIFIED)
<< "Must specify output_type as one of [CONFIDENCE_MASK|CATEGORY_MASK].";
<< "Must specify output_type as one of "
"[CONFIDENCE_MASK|CATEGORY_MASK].";
}
#ifdef __EMSCRIPTEN__
MP_RETURN_IF_ERROR(postprocessor_.Initialize(cc, options_));
#endif // __EMSCRIPTEN__
@ -309,6 +324,10 @@ absl::Status TensorsToSegmentationCalculator::Process(
if (cc->Inputs().HasTag("OUTPUT_SIZE")) {
std::tie(output_width, output_height) = kOutputSizeIn(cc).Get();
}
// Use GPU postprocessing on web when Tensor is there already and has <= 12
// categories.
#ifdef __EMSCRIPTEN__
Shape output_shape = {
/* height= */ output_height,
/* width= */ output_width,
@ -316,10 +335,6 @@ absl::Status TensorsToSegmentationCalculator::Process(
SegmenterOptions::CATEGORY_MASK
? 1
: input_shape.channels};
// Use GPU postprocessing on web when Tensor is there already and has <= 12
// categories.
#ifdef __EMSCRIPTEN__
if (input_tensor.ready_as_opengl_texture_2d() && input_shape.channels <= 12) {
std::vector<std::unique_ptr<Image>> segmented_masks =
postprocessor_.GetSegmentationResultGpu(input_shape, output_shape,
@ -332,12 +347,43 @@ absl::Status TensorsToSegmentationCalculator::Process(
#endif // __EMSCRIPTEN__
// Otherwise, use CPU postprocessing.
const float* tensors_buffer = input_tensor.GetCpuReadView().buffer<float>();
// TODO: remove deprecated output type support.
if (options_.segmenter_options().has_output_type()) {
std::vector<Image> segmented_masks = GetSegmentationResultCpu(
input_shape, output_shape, input_tensor.GetCpuReadView().buffer<float>());
input_shape,
{/* height= */ output_height,
/* width= */ output_width,
/* channels= */ options_.segmenter_options().output_type() ==
SegmenterOptions::CATEGORY_MASK
? 1
: input_shape.channels},
input_tensor.GetCpuReadView().buffer<float>());
for (int i = 0; i < segmented_masks.size(); ++i) {
kSegmentationOut(cc)[i].Send(std::move(segmented_masks[i]));
}
return absl::OkStatus();
}
std::vector<Image> confidence_masks =
ProcessForConfidenceMaskCpu(input_shape,
{/* height= */ output_height,
/* width= */ output_width,
/* channels= */ input_shape.channels},
options_.segmenter_options(), tensors_buffer);
for (int i = 0; i < confidence_masks.size(); ++i) {
kConfidenceMaskOut(cc)[i].Send(std::move(confidence_masks[i]));
}
if (cc->Outputs().HasTag("CATEGORY_MASK")) {
kCategoryMaskOut(cc).Send(ProcessForCategoryMaskCpu(
input_shape,
{/* height= */ output_height,
/* width= */ output_width,
/* channels= */ 1},
options_.segmenter_options(), tensors_buffer));
}
return absl::OkStatus();
}
std::vector<Image> TensorsToSegmentationCalculator::GetSegmentationResultCpu(
@ -345,9 +391,9 @@ std::vector<Image> TensorsToSegmentationCalculator::GetSegmentationResultCpu(
const float* tensors_buffer) {
if (options_.segmenter_options().output_type() ==
SegmenterOptions::CATEGORY_MASK) {
return ProcessForCategoryMaskCpu(input_shape, output_shape,
return {ProcessForCategoryMaskCpu(input_shape, output_shape,
options_.segmenter_options(),
tensors_buffer);
tensors_buffer)};
} else {
return ProcessForConfidenceMaskCpu(input_shape, output_shape,
options_.segmenter_options(),

View File

@ -79,8 +79,9 @@ void PushTensorsToRunner(int tensor_height, int tensor_width,
std::vector<Packet> GetPackets(const CalculatorRunner& runner) {
std::vector<Packet> mask_packets;
for (int i = 0; i < runner.Outputs().NumEntries(); ++i) {
EXPECT_EQ(runner.Outputs().Get("SEGMENTATION", i).packets.size(), 1);
mask_packets.push_back(runner.Outputs().Get("SEGMENTATION", i).packets[0]);
EXPECT_EQ(runner.Outputs().Get("CONFIDENCE_MASK", i).packets.size(), 1);
mask_packets.push_back(
runner.Outputs().Get("CONFIDENCE_MASK", i).packets[0]);
}
return mask_packets;
}
@ -118,13 +119,10 @@ TEST(TensorsToSegmentationCalculatorTest, FailsInvalidTensorDimensionOne) {
R"pb(
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
input_stream: "TENSORS:tensors"
output_stream: "SEGMENTATION:segmentation"
output_stream: "CONFIDENCE_MASK:segmentation"
options {
[mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
segmenter_options {
activation: SOFTMAX
output_type: CONFIDENCE_MASK
}
segmenter_options { activation: SOFTMAX }
}
}
)pb"));
@ -145,13 +143,10 @@ TEST(TensorsToSegmentationCalculatorTest, FailsInvalidTensorDimensionFive) {
R"pb(
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
input_stream: "TENSORS:tensors"
output_stream: "SEGMENTATION:segmentation"
output_stream: "CONFIDENCE_MASK:segmentation"
options {
[mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
segmenter_options {
activation: SOFTMAX
output_type: CONFIDENCE_MASK
}
segmenter_options { activation: SOFTMAX }
}
}
)pb"));
@ -173,16 +168,13 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithSoftmax) {
R"pb(
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
input_stream: "TENSORS:tensors"
output_stream: "SEGMENTATION:0:segmented_mask_0"
output_stream: "SEGMENTATION:1:segmented_mask_1"
output_stream: "SEGMENTATION:2:segmented_mask_2"
output_stream: "SEGMENTATION:3:segmented_mask_3"
output_stream: "CONFIDENCE_MASK:0:segmented_mask_0"
output_stream: "CONFIDENCE_MASK:1:segmented_mask_1"
output_stream: "CONFIDENCE_MASK:2:segmented_mask_2"
output_stream: "CONFIDENCE_MASK:3:segmented_mask_3"
options {
[mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
segmenter_options {
activation: SOFTMAX
output_type: CONFIDENCE_MASK
}
segmenter_options { activation: SOFTMAX }
}
}
)pb"));
@ -218,16 +210,13 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithNone) {
R"pb(
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
input_stream: "TENSORS:tensors"
output_stream: "SEGMENTATION:0:segmented_mask_0"
output_stream: "SEGMENTATION:1:segmented_mask_1"
output_stream: "SEGMENTATION:2:segmented_mask_2"
output_stream: "SEGMENTATION:3:segmented_mask_3"
output_stream: "CONFIDENCE_MASK:0:segmented_mask_0"
output_stream: "CONFIDENCE_MASK:1:segmented_mask_1"
output_stream: "CONFIDENCE_MASK:2:segmented_mask_2"
output_stream: "CONFIDENCE_MASK:3:segmented_mask_3"
options {
[mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
segmenter_options {
activation: NONE
output_type: CONFIDENCE_MASK
}
segmenter_options { activation: NONE }
}
}
)pb"));
@ -259,16 +248,13 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithSigmoid) {
R"pb(
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
input_stream: "TENSORS:tensors"
output_stream: "SEGMENTATION:0:segmented_mask_0"
output_stream: "SEGMENTATION:1:segmented_mask_1"
output_stream: "SEGMENTATION:2:segmented_mask_2"
output_stream: "SEGMENTATION:3:segmented_mask_3"
output_stream: "CONFIDENCE_MASK:0:segmented_mask_0"
output_stream: "CONFIDENCE_MASK:1:segmented_mask_1"
output_stream: "CONFIDENCE_MASK:2:segmented_mask_2"
output_stream: "CONFIDENCE_MASK:3:segmented_mask_3"
options {
[mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
segmenter_options {
activation: SIGMOID
output_type: CONFIDENCE_MASK
}
segmenter_options { activation: SIGMOID }
}
}
)pb"));
@ -301,13 +287,14 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMask) {
R"pb(
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
input_stream: "TENSORS:tensors"
output_stream: "SEGMENTATION:segmentation"
output_stream: "CONFIDENCE_MASK:0:segmented_mask_0"
output_stream: "CONFIDENCE_MASK:1:segmented_mask_1"
output_stream: "CONFIDENCE_MASK:2:segmented_mask_2"
output_stream: "CONFIDENCE_MASK:3:segmented_mask_3"
output_stream: "CATEGORY_MASK:segmentation"
options {
[mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
segmenter_options {
activation: NONE
output_type: CATEGORY_MASK
}
segmenter_options { activation: NONE }
}
}
)pb"));
@ -318,11 +305,11 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMask) {
tensor_height, tensor_width,
std::vector<float>(kTestValues.begin(), kTestValues.end()), &runner);
MP_ASSERT_OK(runner.Run());
ASSERT_EQ(runner.Outputs().NumEntries(), 1);
ASSERT_EQ(runner.Outputs().NumEntries(), 5);
// Largest element index is 3.
const int expected_index = 3;
const std::vector<int> buffer_indices = {0};
std::vector<Packet> packets = GetPackets(runner);
std::vector<Packet> packets = runner.Outputs().Tag("CATEGORY_MASK").packets;
EXPECT_THAT(packets, testing::ElementsAre(
Uint8ImagePacket(tensor_height, tensor_width,
expected_index, buffer_indices)));
@ -335,13 +322,14 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMaskResize) {
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
input_stream: "TENSORS:tensors"
input_stream: "OUTPUT_SIZE:size"
output_stream: "SEGMENTATION:segmentation"
output_stream: "CONFIDENCE_MASK:0:segmented_mask_0"
output_stream: "CONFIDENCE_MASK:1:segmented_mask_1"
output_stream: "CONFIDENCE_MASK:2:segmented_mask_2"
output_stream: "CONFIDENCE_MASK:3:segmented_mask_3"
output_stream: "CATEGORY_MASK:segmentation"
options {
[mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
segmenter_options {
activation: NONE
output_type: CATEGORY_MASK
}
segmenter_options { activation: NONE }
}
}
)pb"));
@ -367,7 +355,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMaskResize) {
const std::vector<int> buffer_indices = {
0 * output_width + 0, 0 * output_width + 1, 1 * output_width + 0,
1 * output_width + 1};
std::vector<Packet> packets = GetPackets(runner);
std::vector<Packet> packets = runner.Outputs().Tag("CATEGORY_MASK").packets;
EXPECT_THAT(packets, testing::ElementsAre(
Uint8ImagePacket(output_height, output_width,
expected_index, buffer_indices)));

View File

@ -37,8 +37,10 @@ namespace vision {
namespace image_segmenter {
namespace {
constexpr char kSegmentationStreamName[] = "segmented_mask_out";
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kConfidenceMasksTag[] = "CONFIDENCE_MASKS";
constexpr char kConfidenceMasksStreamName[] = "confidence_masks";
constexpr char kCategoryMaskTag[] = "CATEGORY_MASK";
constexpr char kCategoryMaskStreamName[] = "category_mask";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kImageTag[] = "IMAGE";
@ -51,7 +53,6 @@ constexpr int kMicroSecondsPerMilliSecond = 1000;
using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::Image;
using ::mediapipe::NormalizedRect;
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
image_segmenter::proto::ImageSegmenterGraphOptions;
@ -59,21 +60,24 @@ using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
// "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph".
CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<ImageSegmenterGraphOptionsProto> options,
bool enable_flow_limiting) {
bool output_category_mask, bool enable_flow_limiting) {
api2::builder::Graph graph;
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap(
options.get());
graph.In(kImageTag).SetName(kImageInStreamName);
graph.In(kNormRectTag).SetName(kNormRectStreamName);
task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
graph.Out(kGroupedSegmentationTag);
task_subgraph.Out(kConfidenceMasksTag).SetName(kConfidenceMasksStreamName) >>
graph.Out(kConfidenceMasksTag);
if (output_category_mask) {
task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
graph.Out(kCategoryMaskTag);
}
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
graph.Out(kImageTag);
if (enable_flow_limiting) {
return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph,
{kImageTag, kNormRectTag},
kGroupedSegmentationTag);
return tasks::core::AddFlowLimiterCalculator(
graph, task_subgraph, {kImageTag, kNormRectTag}, kConfidenceMasksTag);
}
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
@ -91,16 +95,6 @@ ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) {
options_proto->mutable_base_options()->set_use_stream_mode(
options->running_mode != core::RunningMode::IMAGE);
options_proto->set_display_names_locale(options->display_names_locale);
switch (options->output_type) {
case ImageSegmenterOptions::OutputType::CATEGORY_MASK:
options_proto->mutable_segmenter_options()->set_output_type(
SegmenterOptions::CATEGORY_MASK);
break;
case ImageSegmenterOptions::OutputType::CONFIDENCE_MASK:
options_proto->mutable_segmenter_options()->set_output_type(
SegmenterOptions::CONFIDENCE_MASK);
break;
}
return options_proto;
}
@ -145,6 +139,7 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
tasks::core::PacketsCallback packets_callback = nullptr;
if (options->result_callback) {
auto result_callback = options->result_callback;
bool output_category_mask = options->output_category_mask;
packets_callback =
[=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
if (!status_or_packets.ok()) {
@ -156,34 +151,41 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
return;
}
Packet segmented_masks =
status_or_packets.value()[kSegmentationStreamName];
Packet confidence_masks =
status_or_packets.value()[kConfidenceMasksStreamName];
std::optional<Image> category_mask;
if (output_category_mask) {
category_mask =
status_or_packets.value()[kCategoryMaskStreamName].Get<Image>();
}
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
result_callback(segmented_masks.Get<std::vector<Image>>(),
result_callback(
{{confidence_masks.Get<std::vector<Image>>(), category_mask}},
image_packet.Get<Image>(),
segmented_masks.Timestamp().Value() /
confidence_masks.Timestamp().Value() /
kMicroSecondsPerMilliSecond);
};
}
auto image_segmenter =
core::VisionTaskApiFactory::Create<ImageSegmenter,
ImageSegmenterGraphOptionsProto>(
CreateGraphConfig(
std::move(options_proto),
std::move(options_proto), options->output_category_mask,
options->running_mode == core::RunningMode::LIVE_STREAM),
std::move(options->base_options.op_resolver), options->running_mode,
std::move(packets_callback));
if (!image_segmenter.ok()) {
return image_segmenter.status();
}
image_segmenter.value()->output_category_mask_ =
options->output_category_mask;
ASSIGN_OR_RETURN(
(*image_segmenter)->labels_,
GetLabelsFromGraphConfig((*image_segmenter)->runner_->GetGraphConfig()));
return image_segmenter;
}
absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(
absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
@ -201,11 +203,17 @@ absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(
{{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))},
{kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))}}));
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
std::vector<Image> confidence_masks =
output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>();
std::optional<Image> category_mask;
if (output_category_mask_) {
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
}
return {{confidence_masks, category_mask}};
}
absl::StatusOr<std::vector<Image>> ImageSegmenter::SegmentForVideo(
mediapipe::Image image, int64 timestamp_ms,
absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
mediapipe::Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
@ -225,11 +233,17 @@ absl::StatusOr<std::vector<Image>> ImageSegmenter::SegmentForVideo(
{kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
std::vector<Image> confidence_masks =
output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>();
std::optional<Image> category_mask;
if (output_category_mask_) {
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
}
return {{confidence_masks, category_mask}};
}
absl::Status ImageSegmenter::SegmentAsync(
Image image, int64 timestamp_ms,
Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"
#include "tensorflow/lite/kernels/register.h"
namespace mediapipe {
@ -52,23 +53,14 @@ struct ImageSegmenterOptions {
// Metadata, if any. Defaults to English.
std::string display_names_locale = "en";
// The output type of segmentation results.
enum OutputType {
// Gives a single output mask where each pixel represents the class which
// the pixel in the original image was predicted to belong to.
CATEGORY_MASK = 0,
// Gives a list of output masks where, for each mask, each pixel represents
// the prediction confidence, usually in the [0, 1] range.
CONFIDENCE_MASK = 1,
};
OutputType output_type = OutputType::CATEGORY_MASK;
// Whether to output category mask.
bool output_category_mask = false;
// The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM.
std::function<void(absl::StatusOr<std::vector<mediapipe::Image>>,
const Image&, int64)>
std::function<void(absl::StatusOr<ImageSegmenterResult>, const Image&,
int64_t)>
result_callback = nullptr;
};
@ -84,13 +76,9 @@ struct ImageSegmenterOptions {
// 1 or 3).
// - if type is kTfLiteFloat32, NormalizationOptions are required to be
// attached to the metadata for input normalization.
// Output tensors:
// (kTfLiteUInt8/kTfLiteFloat32)
// - list of segmented masks.
// - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1.
// - if `output_type` is CONFIDENCE_MASK, float32 Image list of size
// `channels`.
// - batch is always 1
// Output ImageSegmenterResult:
// Provides confidence masks and an optional category mask if
// `output_category_mask` is set true.
// An example of such model can be found at:
// https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2
class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
@ -114,12 +102,8 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
//
// If the output_type is CATEGORY_MASK, the returned vector of images is
// per-category segmented image mask.
// If the output_type is CONFIDENCE_MASK, the returned vector of images
// contains only one confidence image mask.
absl::StatusOr<std::vector<mediapipe::Image>> Segment(
absl::StatusOr<ImageSegmenterResult> Segment(
mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
@ -137,13 +121,8 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
//
// If the output_type is CATEGORY_MASK, the returned vector of images is
// per-category segmented image mask.
// If the output_type is CONFIDENCE_MASK, the returned vector of images
// contains only one confidence image mask.
absl::StatusOr<std::vector<mediapipe::Image>> SegmentForVideo(
mediapipe::Image image, int64 timestamp_ms,
absl::StatusOr<ImageSegmenterResult> SegmentForVideo(
mediapipe::Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
@ -164,17 +143,13 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// and will result in an invalid argument error being returned.
//
// The "result_callback" prvoides
// - A vector of segmented image masks.
// If the output_type is CATEGORY_MASK, the returned vector of images is
// per-category segmented image mask.
// If the output_type is CONFIDENCE_MASK, the returned vector of images
// contains only one confidence image mask.
// - An ImageSegmenterResult.
// - The const reference to the corresponding input image that the image
// segmentation runs on. Note that the const reference to the image will
// no longer be valid when the callback returns. To access the image data
// outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds.
absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms,
absl::Status SegmentAsync(mediapipe::Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions>
image_processing_options = std::nullopt);
@ -182,9 +157,9 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
absl::Status Close() { return runner_->Close(); }
// Get the category label list of the ImageSegmenter can recognize. For
// CATEGORY_MASK type, the index in the category mask corresponds to the
// category in the label list. For CONFIDENCE_MASK type, the output mask list
// at index corresponds to the category in the label list.
// CATEGORY_MASK, the index in the category mask corresponds to the category
// in the label list. For CONFIDENCE_MASK, the output mask list at index
// corresponds to the category in the label list.
//
// If there is no labelmap provided in the model file, empty label list is
// returned.
@ -192,6 +167,7 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
private:
std::vector<std::string> labels_;
bool output_category_mask_;
};
} // namespace image_segmenter

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <memory>
#include <optional>
#include <type_traits>
#include <vector>
@ -42,6 +43,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
#include "mediapipe/tasks/metadata/image_segmenter_metadata_schema_generated.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
#include "mediapipe/util/graph_builder_utils.h"
#include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/label_map_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
@ -65,10 +67,13 @@ using ::mediapipe::tasks::vision::image_segmenter::proto::
ImageSegmenterGraphOptions;
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
using ::tflite::TensorMetadata;
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
using LabelItems = mediapipe::proto_ns::Map<int64_t, ::mediapipe::LabelMapItem>;
constexpr char kSegmentationTag[] = "SEGMENTATION";
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kConfidenceMaskTag[] = "CONFIDENCE_MASK";
constexpr char kConfidenceMasksTag[] = "CONFIDENCE_MASKS";
constexpr char kCategoryMaskTag[] = "CATEGORY_MASK";
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageCpuTag[] = "IMAGE_CPU";
constexpr char kImageGpuTag[] = "IMAGE_GPU";
@ -80,7 +85,9 @@ constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA";
// Struct holding the different output streams produced by the image segmenter
// subgraph.
struct ImageSegmenterOutputs {
std::vector<Source<Image>> segmented_masks;
std::optional<std::vector<Source<Image>>> segmented_masks;
std::optional<std::vector<Source<Image>>> confidence_masks;
std::optional<Source<Image>> category_mask;
// The same as the input image, mainly used for live stream mode.
Source<Image> image;
};
@ -95,7 +102,9 @@ struct ImageAndTensorsOnDevice {
} // namespace
absl::Status SanityCheckOptions(const ImageSegmenterGraphOptions& options) {
if (options.segmenter_options().output_type() ==
// TODO: remove deprecated output type support.
if (options.segmenter_options().has_output_type() &&
options.segmenter_options().output_type() ==
SegmenterOptions::UNSPECIFIED) {
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
"`output_type` must not be UNSPECIFIED",
@ -133,9 +142,8 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
const core::ModelResources& model_resources,
TensorsToSegmentationCalculatorOptions* options) {
// Set default activation function NONE
options->mutable_segmenter_options()->set_output_type(
segmenter_option.segmenter_options().output_type());
options->mutable_segmenter_options()->set_activation(SegmenterOptions::NONE);
options->mutable_segmenter_options()->CopyFrom(
segmenter_option.segmenter_options());
// Find the custom metadata of ImageSegmenterOptions type in model metadata.
const auto* metadata_extractor = model_resources.GetMetadataExtractor();
bool found_activation_in_metadata = false;
@ -317,12 +325,14 @@ absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
}
}
// An "mediapipe.tasks.vision.ImageSegmenterGraph" performs semantic
// segmentation.
// Two kinds of outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION.
// Users can retrieve segmented mask of only particular category/channel from
// SEGMENTATION, and users can also get all segmented masks from
// GROUPED_SEGMENTATION.
// An "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph" performs
// semantic segmentation. The graph always output confidence masks, and an
// optional category mask if CATEGORY_MASK is connected.
//
// Two kinds of outputs for confidence mask are provided: CONFIDENCE_MASK and
// CONFIDENCE_MASKS. Users can retrieve segmented mask of only particular
// category/channel from CONFIDENCE_MASK, and users can also get all segmented
// confidence masks from CONFIDENCE_MASKS.
// - Accepts CPU input images and outputs segmented masks on CPU.
//
// Inputs:
@ -334,11 +344,13 @@ absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
// @Optional: rect covering the whole image is used if not specified.
//
// Outputs:
// SEGMENTATION - mediapipe::Image @Multiple
// Segmented masks for individual category. Segmented mask of single
// CONFIDENCE_MASK - mediapipe::Image @Multiple
// Confidence masks for individual category. Confidence mask of single
// category can be accessed by index based output stream.
// GROUPED_SEGMENTATION - std::vector<mediapipe::Image>
// The output segmented masks grouped in a vector.
// CONFIDENCE_MASKS - std::vector<mediapipe::Image>
// The output confidence masks grouped in a vector.
// CATEGORY_MASK - mediapipe::Image @Optional
// Optional Category mask.
// IMAGE - mediapipe::Image
// The image that image segmenter runs on.
//
@ -369,23 +381,39 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
ASSIGN_OR_RETURN(const auto* model_resources,
CreateModelResources<ImageSegmenterGraphOptions>(sc));
Graph graph;
const auto& options = sc->Options<ImageSegmenterGraphOptions>();
ASSIGN_OR_RETURN(
auto output_streams,
BuildSegmentationTask(
sc->Options<ImageSegmenterGraphOptions>(), *model_resources,
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
options, *model_resources, graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)],
HasOutput(sc->OriginalNode(), kCategoryMaskTag), graph));
auto& merge_images_to_vector =
graph.AddNode("MergeImagesToVectorCalculator");
for (int i = 0; i < output_streams.segmented_masks.size(); ++i) {
output_streams.segmented_masks[i] >>
// TODO: remove deprecated output type support.
if (options.segmenter_options().has_output_type()) {
for (int i = 0; i < output_streams.segmented_masks->size(); ++i) {
output_streams.segmented_masks->at(i) >>
merge_images_to_vector[Input<Image>::Multiple("")][i];
output_streams.segmented_masks[i] >>
output_streams.segmented_masks->at(i) >>
graph[Output<Image>::Multiple(kSegmentationTag)][i];
}
merge_images_to_vector.Out("") >>
graph[Output<std::vector<Image>>(kGroupedSegmentationTag)];
} else {
for (int i = 0; i < output_streams.confidence_masks->size(); ++i) {
output_streams.confidence_masks->at(i) >>
merge_images_to_vector[Input<Image>::Multiple("")][i];
output_streams.confidence_masks->at(i) >>
graph[Output<Image>::Multiple(kConfidenceMaskTag)][i];
}
merge_images_to_vector.Out("") >>
graph[Output<std::vector<Image>>(kConfidenceMasksTag)];
if (output_streams.category_mask) {
*output_streams.category_mask >> graph[Output<Image>(kCategoryMaskTag)];
}
}
output_streams.image >> graph[Output<Image>(kImageTag)];
return graph.GetConfig();
}
@ -403,7 +431,8 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask(
const ImageSegmenterGraphOptions& task_options,
const core::ModelResources& model_resources, Source<Image> image_in,
Source<NormalizedRect> norm_rect_in, Graph& graph) {
Source<NormalizedRect> norm_rect_in, bool output_category_mask,
Graph& graph) {
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
// Adds preprocessing calculators and connects them to the graph input image
@ -435,6 +464,8 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
image_properties.Out("SIZE") >> tensor_to_images.In(kOutputSizeTag);
// Exports multiple segmented masks.
// TODO: remove deprecated output type support.
if (task_options.segmenter_options().has_output_type()) {
std::vector<Source<Image>> segmented_masks;
if (task_options.segmenter_options().output_type() ==
SegmenterOptions::CATEGORY_MASK) {
@ -450,7 +481,29 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
}
}
return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks,
/*confidence_masks=*/std::nullopt,
/*category_mask=*/std::nullopt,
/*image=*/image_and_tensors.image};
} else {
ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor,
GetOutputTensor(model_resources));
int segmentation_streams_num = *output_tensor->shape()->rbegin();
std::vector<Source<Image>> confidence_masks;
confidence_masks.reserve(segmentation_streams_num);
for (int i = 0; i < segmentation_streams_num; ++i) {
confidence_masks.push_back(Source<Image>(
tensor_to_images[Output<Image>::Multiple(kConfidenceMaskTag)][i]));
}
return ImageSegmenterOutputs{
/*segmented_masks=*/std::nullopt,
/*confidence_masks=*/confidence_masks,
/*category_mask=*/
output_category_mask
? std::make_optional(
tensor_to_images[Output<Image>(kCategoryMaskTag)])
: std::nullopt,
/*image=*/image_and_tensors.image};
}
}
};

View File

@ -0,0 +1,43 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_RESULT_H_
#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_RESULT_H_

#include <optional>
// <vector> was missing: the struct below declares a std::vector member, so
// this header must include it directly (include-what-you-use) rather than
// rely on a transitive include from <optional> or the Image header.
#include <vector>

#include "mediapipe/framework/formats/image.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace image_segmenter {

// The output result of ImageSegmenter.
struct ImageSegmenterResult {
  // Multiple masks of float image in VEC32F1 format where, for each mask, each
  // pixel represents the prediction confidence, usually in the [0, 1] range.
  std::vector<Image> confidence_masks;
  // A category mask of uint8 image in GRAY8 format where each pixel represents
  // the class which the pixel in the original image was predicted to belong
  // to. Only populated when the segmenter is configured to output a category
  // mask.
  std::optional<Image> category_mask;
};

}  // namespace image_segmenter
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe

#endif  // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_RESULT_H_

View File

@ -36,6 +36,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
@ -256,7 +257,6 @@ TEST(GetLabelsTest, SucceedsWithLabelsInModel) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
@ -278,15 +278,14 @@ TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
options->output_category_mask = true;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto category_masks, segmenter->Segment(image));
EXPECT_EQ(category_masks.size(), 1);
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
EXPECT_TRUE(result.category_mask.has_value());
cv::Mat actual_mask = mediapipe::formats::MatView(
category_masks[0].GetImageFrameSharedPtr().get());
result.category_mask->GetImageFrameSharedPtr().get());
cv::Mat expected_mask = cv::imread(
JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"),
@ -303,12 +302,11 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
EXPECT_EQ(confidence_masks.size(), 21);
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
EXPECT_EQ(result.confidence_masks.size(), 21);
cv::Mat expected_mask = cv::imread(
JoinPath("./", kTestDataDirectory, "cat_mask.jpg"), cv::IMREAD_GRAYSCALE);
@ -317,7 +315,7 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
// Cat category index 8.
cv::Mat cat_mask = mediapipe::formats::MatView(
confidence_masks[8].GetImageFrameSharedPtr().get());
result.confidence_masks[8].GetImageFrameSharedPtr().get());
EXPECT_THAT(cat_mask,
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
}
@ -331,15 +329,14 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
ImageProcessingOptions image_processing_options;
image_processing_options.rotation_degrees = -90;
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks,
MP_ASSERT_OK_AND_ASSIGN(auto result,
segmenter->Segment(image, image_processing_options));
EXPECT_EQ(confidence_masks.size(), 21);
EXPECT_EQ(result.confidence_masks.size(), 21);
cv::Mat expected_mask =
cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"),
@ -349,7 +346,7 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
// Cat category index 8.
cv::Mat cat_mask = mediapipe::formats::MatView(
confidence_masks[8].GetImageFrameSharedPtr().get());
result.confidence_masks[8].GetImageFrameSharedPtr().get());
EXPECT_THAT(cat_mask,
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
}
@ -361,7 +358,6 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
@ -384,12 +380,11 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
EXPECT_EQ(confidence_masks.size(), 2);
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
EXPECT_EQ(result.confidence_masks.size(), 2);
cv::Mat expected_mask =
cv::imread(JoinPath("./", kTestDataDirectory,
@ -400,7 +395,7 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
// Selfie category index 1.
cv::Mat selfie_mask = mediapipe::formats::MatView(
confidence_masks[1].GetImageFrameSharedPtr().get());
result.confidence_masks[1].GetImageFrameSharedPtr().get());
EXPECT_THAT(selfie_mask,
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
}
@ -411,11 +406,10 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
EXPECT_EQ(confidence_masks.size(), 1);
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
EXPECT_EQ(result.confidence_masks.size(), 1);
cv::Mat expected_mask =
cv::imread(JoinPath("./", kTestDataDirectory,
@ -425,7 +419,7 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
cv::Mat selfie_mask = mediapipe::formats::MatView(
confidence_masks[0].GetImageFrameSharedPtr().get());
result.confidence_masks[0].GetImageFrameSharedPtr().get());
EXPECT_THAT(selfie_mask,
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
}
@ -436,12 +430,11 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfieSegmentation);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
EXPECT_EQ(confidence_masks.size(), 1);
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
EXPECT_EQ(result.confidence_masks.size(), 1);
MP_ASSERT_OK(segmenter->Close());
cv::Mat expected_mask = cv::imread(
@ -452,7 +445,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) {
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
cv::Mat selfie_mask = mediapipe::formats::MatView(
confidence_masks[0].GetImageFrameSharedPtr().get());
result.confidence_masks[0].GetImageFrameSharedPtr().get());
EXPECT_THAT(selfie_mask,
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
}
@ -463,16 +456,15 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationCategoryMask) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfieSegmentation);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
options->output_category_mask = true;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto category_mask, segmenter->Segment(image));
EXPECT_EQ(category_mask.size(), 1);
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
EXPECT_TRUE(result.category_mask.has_value());
MP_ASSERT_OK(segmenter->Close());
cv::Mat selfie_mask = mediapipe::formats::MatView(
category_mask[0].GetImageFrameSharedPtr().get());
result.category_mask->GetImageFrameSharedPtr().get());
cv::Mat expected_mask = cv::imread(
JoinPath("./", kTestDataDirectory,
"portrait_selfie_segmentation_expected_category_mask.jpg"),
@ -487,16 +479,15 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfieSegmentationLandscape);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
options->output_category_mask = true;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto category_mask, segmenter->Segment(image));
EXPECT_EQ(category_mask.size(), 1);
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
EXPECT_TRUE(result.category_mask.has_value());
MP_ASSERT_OK(segmenter->Close());
cv::Mat selfie_mask = mediapipe::formats::MatView(
category_mask[0].GetImageFrameSharedPtr().get());
result.category_mask->GetImageFrameSharedPtr().get());
cv::Mat expected_mask = cv::imread(
JoinPath(
"./", kTestDataDirectory,
@ -512,14 +503,13 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kHairSegmentationWithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
EXPECT_EQ(confidence_masks.size(), 2);
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
EXPECT_EQ(result.confidence_masks.size(), 2);
cv::Mat hair_mask = mediapipe::formats::MatView(
confidence_masks[1].GetImageFrameSharedPtr().get());
result.confidence_masks[1].GetImageFrameSharedPtr().get());
MP_ASSERT_OK(segmenter->Close());
cv::Mat expected_mask = cv::imread(
JoinPath("./", kTestDataDirectory, "portrait_hair_expected_mask.jpg"),
@ -540,7 +530,6 @@ TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
options->running_mode = core::RunningMode::VIDEO;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
@ -572,7 +561,7 @@ TEST_F(VideoModeTest, Succeeds) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
options->output_category_mask = true;
options->running_mode = core::RunningMode::VIDEO;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
@ -580,11 +569,10 @@ TEST_F(VideoModeTest, Succeeds) {
JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"),
cv::IMREAD_GRAYSCALE);
for (int i = 0; i < iterations; ++i) {
MP_ASSERT_OK_AND_ASSIGN(auto category_masks,
segmenter->SegmentForVideo(image, i));
EXPECT_EQ(category_masks.size(), 1);
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->SegmentForVideo(image, i));
EXPECT_TRUE(result.category_mask.has_value());
cv::Mat actual_mask = mediapipe::formats::MatView(
category_masks[0].GetImageFrameSharedPtr().get());
result.category_mask->GetImageFrameSharedPtr().get());
EXPECT_THAT(actual_mask,
SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity,
kGoldenMaskMagnificationFactor));
@ -601,11 +589,10 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback =
[](absl::StatusOr<std::vector<Image>> segmented_masks, const Image& image,
int64 timestamp_ms) {};
[](absl::StatusOr<ImageSegmenterResult> segmented_masks,
const Image& image, int64_t timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
@ -634,11 +621,9 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback =
[](absl::StatusOr<std::vector<Image>> segmented_masks, const Image& image,
int64 timestamp_ms) {};
options->result_callback = [](absl::StatusOr<ImageSegmenterResult> result,
const Image& image, int64_t timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK(segmenter->SegmentAsync(image, 1));
@ -660,20 +645,20 @@ TEST_F(LiveStreamModeTest, Succeeds) {
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
"segmentation_input_rotation0.jpg")));
std::vector<std::vector<Image>> segmented_masks_results;
std::vector<Image> segmented_masks_results;
std::vector<std::pair<int, int>> image_sizes;
std::vector<int64> timestamps;
std::vector<int64_t> timestamps;
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
options->output_category_mask = true;
options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback =
[&segmented_masks_results, &image_sizes, &timestamps](
absl::StatusOr<std::vector<Image>> segmented_masks,
const Image& image, int64 timestamp_ms) {
MP_ASSERT_OK(segmented_masks.status());
segmented_masks_results.push_back(std::move(segmented_masks).value());
options->result_callback = [&segmented_masks_results, &image_sizes,
&timestamps](
absl::StatusOr<ImageSegmenterResult> result,
const Image& image, int64_t timestamp_ms) {
MP_ASSERT_OK(result.status());
segmented_masks_results.push_back(std::move(*result->category_mask));
image_sizes.push_back({image.width(), image.height()});
timestamps.push_back(timestamp_ms);
};
@ -690,10 +675,9 @@ TEST_F(LiveStreamModeTest, Succeeds) {
cv::Mat expected_mask = cv::imread(
JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"),
cv::IMREAD_GRAYSCALE);
for (const auto& segmented_masks : segmented_masks_results) {
EXPECT_EQ(segmented_masks.size(), 1);
for (const auto& category_mask : segmented_masks_results) {
cv::Mat actual_mask = mediapipe::formats::MatView(
segmented_masks[0].GetImageFrameSharedPtr().get());
category_mask.GetImageFrameSharedPtr().get());
EXPECT_THAT(actual_mask,
SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity,
kGoldenMaskMagnificationFactor));
@ -702,7 +686,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
EXPECT_EQ(image_size.first, image.width());
EXPECT_EQ(image_size.second, image.height());
}
int64 timestamp_ms = -1;
int64_t timestamp_ms = -1;
for (const auto& timestamp : timestamps) {
EXPECT_GT(timestamp, timestamp_ms);
timestamp_ms = timestamp;

View File

@ -33,7 +33,7 @@ message SegmenterOptions {
CONFIDENCE_MASK = 2;
}
// Optional output mask type.
optional OutputType output_type = 1 [default = CATEGORY_MASK];
optional OutputType output_type = 1 [deprecated = true];
// Supported activation functions for filtering.
enum Activation {