Image segmenter optionally outputs confidence masks and a category mask.

PiperOrigin-RevId: 522227345
This commit is contained in:
MediaPipe Team 2023-04-05 20:30:05 -07:00 committed by Copybara-Service
parent 7fe87936e5
commit d5def9e24d
6 changed files with 139 additions and 73 deletions

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/framework/port/status_macros.h"
@ -210,8 +211,9 @@ std::vector<Image> ProcessForConfidenceMaskCpu(const Shape& input_shape,
} // namespace } // namespace
// Converts Tensors from a vector of Tensor to Segmentation masks. The // Converts Tensors from a vector of Tensor to Segmentation masks. The
// calculator always output confidence masks, and an optional category mask if // calculator can output optional confidence masks if CONFIDENCE_MASK is
// CATEGORY_MASK is connected. // connected, and an optional category mask if CATEGORY_MASK is connected. At
// least one of CONFIDENCE_MASK and CATEGORY_MASK must be connected.
// //
// Performs optional resizing to OUTPUT_SIZE dimension if provided, // Performs optional resizing to OUTPUT_SIZE dimension if provided,
// otherwise the segmented masks is the same size as input tensor. // otherwise the segmented masks is the same size as input tensor.
@ -296,6 +298,13 @@ absl::Status TensorsToSegmentationCalculator::Open(
SegmenterOptions::UNSPECIFIED) SegmenterOptions::UNSPECIFIED)
<< "Must specify output_type as one of " << "Must specify output_type as one of "
"[CONFIDENCE_MASK|CATEGORY_MASK]."; "[CONFIDENCE_MASK|CATEGORY_MASK].";
} else {
if (!cc->Outputs().HasTag("CONFIDENCE_MASK") &&
!cc->Outputs().HasTag("CATEGORY_MASK")) {
return absl::InvalidArgumentError(
"At least one of CONFIDENCE_MASK and CATEGORY_MASK must be "
"connected.");
}
} }
#ifdef __EMSCRIPTEN__ #ifdef __EMSCRIPTEN__
MP_RETURN_IF_ERROR(postprocessor_.Initialize(cc, options_)); MP_RETURN_IF_ERROR(postprocessor_.Initialize(cc, options_));
@ -366,14 +375,16 @@ absl::Status TensorsToSegmentationCalculator::Process(
return absl::OkStatus(); return absl::OkStatus();
} }
std::vector<Image> confidence_masks = if (cc->Outputs().HasTag("CONFIDENCE_MASK")) {
ProcessForConfidenceMaskCpu(input_shape, std::vector<Image> confidence_masks = ProcessForConfidenceMaskCpu(
{/* height= */ output_height, input_shape,
/* width= */ output_width, {/* height= */ output_height,
/* channels= */ input_shape.channels}, /* width= */ output_width,
options_.segmenter_options(), tensors_buffer); /* channels= */ input_shape.channels},
for (int i = 0; i < confidence_masks.size(); ++i) { options_.segmenter_options(), tensors_buffer);
kConfidenceMaskOut(cc)[i].Send(std::move(confidence_masks[i])); for (int i = 0; i < confidence_masks.size(); ++i) {
kConfidenceMaskOut(cc)[i].Send(std::move(confidence_masks[i]));
}
} }
if (cc->Outputs().HasTag("CATEGORY_MASK")) { if (cc->Outputs().HasTag("CATEGORY_MASK")) {
kCategoryMaskOut(cc).Send(ProcessForCategoryMaskCpu( kCategoryMaskOut(cc).Send(ProcessForCategoryMaskCpu(

View File

@ -60,15 +60,19 @@ using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
// "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph". // "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph".
CalculatorGraphConfig CreateGraphConfig( CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<ImageSegmenterGraphOptionsProto> options, std::unique_ptr<ImageSegmenterGraphOptionsProto> options,
bool output_category_mask, bool enable_flow_limiting) { bool output_confidence_masks, bool output_category_mask,
bool enable_flow_limiting) {
api2::builder::Graph graph; api2::builder::Graph graph;
auto& task_subgraph = graph.AddNode(kSubgraphTypeName); auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap( task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap(
options.get()); options.get());
graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kImageTag).SetName(kImageInStreamName);
graph.In(kNormRectTag).SetName(kNormRectStreamName); graph.In(kNormRectTag).SetName(kNormRectStreamName);
task_subgraph.Out(kConfidenceMasksTag).SetName(kConfidenceMasksStreamName) >> if (output_confidence_masks) {
graph.Out(kConfidenceMasksTag); task_subgraph.Out(kConfidenceMasksTag)
.SetName(kConfidenceMasksStreamName) >>
graph.Out(kConfidenceMasksTag);
}
if (output_category_mask) { if (output_category_mask) {
task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >> task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
graph.Out(kCategoryMaskTag); graph.Out(kCategoryMaskTag);
@ -135,11 +139,17 @@ absl::StatusOr<std::vector<std::string>> GetLabelsFromGraphConfig(
absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create( absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
std::unique_ptr<ImageSegmenterOptions> options) { std::unique_ptr<ImageSegmenterOptions> options) {
if (!options->output_confidence_masks && !options->output_category_mask) {
return absl::InvalidArgumentError(
"At least one of `output_confidence_masks` and `output_category_mask` "
"must be set.");
}
auto options_proto = ConvertImageSegmenterOptionsToProto(options.get()); auto options_proto = ConvertImageSegmenterOptionsToProto(options.get());
tasks::core::PacketsCallback packets_callback = nullptr; tasks::core::PacketsCallback packets_callback = nullptr;
if (options->result_callback) { if (options->result_callback) {
auto result_callback = options->result_callback; auto result_callback = options->result_callback;
bool output_category_mask = options->output_category_mask; bool output_category_mask = options->output_category_mask;
bool output_confidence_masks = options->output_confidence_masks;
packets_callback = packets_callback =
[=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) { [=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
if (!status_or_packets.ok()) { if (!status_or_packets.ok()) {
@ -151,8 +161,12 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
return; return;
} }
Packet confidence_masks = std::optional<std::vector<Image>> confidence_masks;
status_or_packets.value()[kConfidenceMasksStreamName]; if (output_confidence_masks) {
confidence_masks =
status_or_packets.value()[kConfidenceMasksStreamName]
.Get<std::vector<Image>>();
}
std::optional<Image> category_mask; std::optional<Image> category_mask;
if (output_category_mask) { if (output_category_mask) {
category_mask = category_mask =
@ -160,23 +174,24 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
} }
Packet image_packet = status_or_packets.value()[kImageOutStreamName]; Packet image_packet = status_or_packets.value()[kImageOutStreamName];
result_callback( result_callback(
{{confidence_masks.Get<std::vector<Image>>(), category_mask}}, {{confidence_masks, category_mask}}, image_packet.Get<Image>(),
image_packet.Get<Image>(), image_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
confidence_masks.Timestamp().Value() /
kMicroSecondsPerMilliSecond);
}; };
} }
auto image_segmenter = auto image_segmenter =
core::VisionTaskApiFactory::Create<ImageSegmenter, core::VisionTaskApiFactory::Create<ImageSegmenter,
ImageSegmenterGraphOptionsProto>( ImageSegmenterGraphOptionsProto>(
CreateGraphConfig( CreateGraphConfig(
std::move(options_proto), options->output_category_mask, std::move(options_proto), options->output_confidence_masks,
options->output_category_mask,
options->running_mode == core::RunningMode::LIVE_STREAM), options->running_mode == core::RunningMode::LIVE_STREAM),
std::move(options->base_options.op_resolver), options->running_mode, std::move(options->base_options.op_resolver), options->running_mode,
std::move(packets_callback)); std::move(packets_callback));
if (!image_segmenter.ok()) { if (!image_segmenter.ok()) {
return image_segmenter.status(); return image_segmenter.status();
} }
image_segmenter.value()->output_confidence_masks_ =
options->output_confidence_masks;
image_segmenter.value()->output_category_mask_ = image_segmenter.value()->output_category_mask_ =
options->output_category_mask; options->output_category_mask;
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
@ -203,8 +218,11 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
{{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))}, {{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))},
{kNormRectStreamName, {kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))}})); MakePacket<NormalizedRect>(std::move(norm_rect))}}));
std::vector<Image> confidence_masks = std::optional<std::vector<Image>> confidence_masks;
output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>(); if (output_confidence_masks_) {
confidence_masks =
output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>();
}
std::optional<Image> category_mask; std::optional<Image> category_mask;
if (output_category_mask_) { if (output_category_mask_) {
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>(); category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
@ -233,8 +251,11 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
{kNormRectStreamName, {kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect)) MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
std::vector<Image> confidence_masks = std::optional<std::vector<Image>> confidence_masks;
output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>(); if (output_confidence_masks_) {
confidence_masks =
output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>();
}
std::optional<Image> category_mask; std::optional<Image> category_mask;
if (output_category_mask_) { if (output_category_mask_) {
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>(); category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();

View File

@ -53,6 +53,9 @@ struct ImageSegmenterOptions {
// Metadata, if any. Defaults to English. // Metadata, if any. Defaults to English.
std::string display_names_locale = "en"; std::string display_names_locale = "en";
// Whether to output confidence masks.
bool output_confidence_masks = true;
// Whether to output category mask. // Whether to output category mask.
bool output_category_mask = false; bool output_category_mask = false;
@ -77,8 +80,10 @@ struct ImageSegmenterOptions {
// - if type is kTfLiteFloat32, NormalizationOptions are required to be // - if type is kTfLiteFloat32, NormalizationOptions are required to be
// attached to the metadata for input normalization. // attached to the metadata for input normalization.
// Output ImageSegmenterResult: // Output ImageSegmenterResult:
// Provides confidence masks and an optional category mask if // Provides optional confidence masks if `output_confidence_masks` is set
// `output_category_mask` is set true. // true, and an optional category mask if `output_category_mask` is set
// true. At least one of `output_confidence_masks` and `output_category_mask`
// must be set to true.
// An example of such model can be found at: // An example of such model can be found at:
// https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2 // https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2
class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
@ -167,6 +172,7 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
private: private:
std::vector<std::string> labels_; std::vector<std::string> labels_;
bool output_confidence_masks_;
bool output_category_mask_; bool output_category_mask_;
}; };

View File

@ -326,8 +326,10 @@ absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
} }
// An "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph" performs // An "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph" performs
// semantic segmentation. The graph always output confidence masks, and an // semantic segmentation. The graph can output optional confidence masks if
// optional category mask if CATEGORY_MASK is connected. // CONFIDENCE_MASKS is connected, and an optional category mask if CATEGORY_MASK
// is connected. At least one of CONFIDENCE_MASK, CONFIDENCE_MASKS and
// CATEGORY_MASK must be connected.
// //
// Two kinds of outputs for confidence mask are provided: CONFIDENCE_MASK and // Two kinds of outputs for confidence mask are provided: CONFIDENCE_MASK and
// CONFIDENCE_MASKS. Users can retrieve segmented mask of only particular // CONFIDENCE_MASKS. Users can retrieve segmented mask of only particular
@ -347,7 +349,7 @@ absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
// CONFIDENCE_MASK - mediapipe::Image @Multiple // CONFIDENCE_MASK - mediapipe::Image @Multiple
// Confidence masks for individual category. Confidence mask of single // Confidence masks for individual category. Confidence mask of single
// category can be accessed by index based output stream. // category can be accessed by index based output stream.
// CONFIDENCE_MASKS - std::vector<mediapipe::Image> // CONFIDENCE_MASKS - std::vector<mediapipe::Image> @Optional
// The output confidence masks grouped in a vector. // The output confidence masks grouped in a vector.
// CATEGORY_MASK - mediapipe::Image @Optional // CATEGORY_MASK - mediapipe::Image @Optional
// Optional Category mask. // Optional Category mask.
@ -356,7 +358,7 @@ absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
// //
// Example: // Example:
// node { // node {
// calculator: "mediapipe.tasks.vision.ImageSegmenterGraph" // calculator: "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"
// input_stream: "IMAGE:image" // input_stream: "IMAGE:image"
// output_stream: "SEGMENTATION:segmented_masks" // output_stream: "SEGMENTATION:segmented_masks"
// options { // options {
@ -382,17 +384,20 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
CreateModelResources<ImageSegmenterGraphOptions>(sc)); CreateModelResources<ImageSegmenterGraphOptions>(sc));
Graph graph; Graph graph;
const auto& options = sc->Options<ImageSegmenterGraphOptions>(); const auto& options = sc->Options<ImageSegmenterGraphOptions>();
// TODO: remove deprecated output type support.
if (!options.segmenter_options().has_output_type()) {
MP_RETURN_IF_ERROR(SanityCheck(sc));
}
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_streams, auto output_streams,
BuildSegmentationTask( BuildSegmentationTask(
options, *model_resources, graph[Input<Image>(kImageTag)], options, *model_resources, graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
HasOutput(sc->OriginalNode(), kCategoryMaskTag), graph));
auto& merge_images_to_vector =
graph.AddNode("MergeImagesToVectorCalculator");
// TODO: remove deprecated output type support. // TODO: remove deprecated output type support.
if (options.segmenter_options().has_output_type()) { if (options.segmenter_options().has_output_type()) {
auto& merge_images_to_vector =
graph.AddNode("MergeImagesToVectorCalculator");
for (int i = 0; i < output_streams.segmented_masks->size(); ++i) { for (int i = 0; i < output_streams.segmented_masks->size(); ++i) {
output_streams.segmented_masks->at(i) >> output_streams.segmented_masks->at(i) >>
merge_images_to_vector[Input<Image>::Multiple("")][i]; merge_images_to_vector[Input<Image>::Multiple("")][i];
@ -402,14 +407,18 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
merge_images_to_vector.Out("") >> merge_images_to_vector.Out("") >>
graph[Output<std::vector<Image>>(kGroupedSegmentationTag)]; graph[Output<std::vector<Image>>(kGroupedSegmentationTag)];
} else { } else {
for (int i = 0; i < output_streams.confidence_masks->size(); ++i) { if (output_streams.confidence_masks) {
output_streams.confidence_masks->at(i) >> auto& merge_images_to_vector =
merge_images_to_vector[Input<Image>::Multiple("")][i]; graph.AddNode("MergeImagesToVectorCalculator");
output_streams.confidence_masks->at(i) >> for (int i = 0; i < output_streams.confidence_masks->size(); ++i) {
graph[Output<Image>::Multiple(kConfidenceMaskTag)][i]; output_streams.confidence_masks->at(i) >>
merge_images_to_vector[Input<Image>::Multiple("")][i];
output_streams.confidence_masks->at(i) >>
graph[Output<Image>::Multiple(kConfidenceMaskTag)][i];
}
merge_images_to_vector.Out("") >>
graph[Output<std::vector<Image>>::Optional(kConfidenceMasksTag)];
} }
merge_images_to_vector.Out("") >>
graph[Output<std::vector<Image>>(kConfidenceMasksTag)];
if (output_streams.category_mask) { if (output_streams.category_mask) {
*output_streams.category_mask >> graph[Output<Image>(kCategoryMaskTag)]; *output_streams.category_mask >> graph[Output<Image>(kCategoryMaskTag)];
} }
@ -419,6 +428,19 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
} }
private: private:
absl::Status SanityCheck(mediapipe::SubgraphContext* sc) {
const auto& node = sc->OriginalNode();
output_confidence_masks_ = HasOutput(node, kConfidenceMaskTag) ||
HasOutput(node, kConfidenceMasksTag);
output_category_mask_ = HasOutput(node, kCategoryMaskTag);
if (!output_confidence_masks_ && !output_category_mask_) {
return absl::InvalidArgumentError(
"At least one of CONFIDENCE_MASK, CONFIDENCE_MASKS and CATEGORY_MASK "
"must be connected.");
}
return absl::OkStatus();
}
// Adds a mediapipe image segmentation task pipeline graph into the provided // Adds a mediapipe image segmentation task pipeline graph into the provided
// builder::Graph instance. The segmentation pipeline takes images // builder::Graph instance. The segmentation pipeline takes images
// (mediapipe::Image) as the input and returns segmented image mask as output. // (mediapipe::Image) as the input and returns segmented image mask as output.
@ -431,8 +453,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask( absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask(
const ImageSegmenterGraphOptions& task_options, const ImageSegmenterGraphOptions& task_options,
const core::ModelResources& model_resources, Source<Image> image_in, const core::ModelResources& model_resources, Source<Image> image_in,
Source<NormalizedRect> norm_rect_in, bool output_category_mask, Source<NormalizedRect> norm_rect_in, Graph& graph) {
Graph& graph) {
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
// Adds preprocessing calculators and connects them to the graph input image // Adds preprocessing calculators and connects them to the graph input image
@ -485,26 +506,32 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
/*category_mask=*/std::nullopt, /*category_mask=*/std::nullopt,
/*image=*/image_and_tensors.image}; /*image=*/image_and_tensors.image};
} else { } else {
ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor, std::optional<std::vector<Source<Image>>> confidence_masks;
GetOutputTensor(model_resources)); if (output_confidence_masks_) {
int segmentation_streams_num = *output_tensor->shape()->rbegin(); ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor,
std::vector<Source<Image>> confidence_masks; GetOutputTensor(model_resources));
confidence_masks.reserve(segmentation_streams_num); int segmentation_streams_num = *output_tensor->shape()->rbegin();
for (int i = 0; i < segmentation_streams_num; ++i) { confidence_masks = std::vector<Source<Image>>();
confidence_masks.push_back(Source<Image>( confidence_masks->reserve(segmentation_streams_num);
tensor_to_images[Output<Image>::Multiple(kConfidenceMaskTag)][i])); for (int i = 0; i < segmentation_streams_num; ++i) {
confidence_masks->push_back(Source<Image>(
tensor_to_images[Output<Image>::Multiple(kConfidenceMaskTag)]
[i]));
}
} }
return ImageSegmenterOutputs{ std::optional<Source<Image>> category_mask;
/*segmented_masks=*/std::nullopt, if (output_category_mask_) {
/*confidence_masks=*/confidence_masks, category_mask = tensor_to_images[Output<Image>(kCategoryMaskTag)];
/*category_mask=*/ }
output_category_mask return ImageSegmenterOutputs{/*segmented_masks=*/std::nullopt,
? std::make_optional( /*confidence_masks=*/confidence_masks,
tensor_to_images[Output<Image>(kCategoryMaskTag)]) /*category_mask=*/category_mask,
: std::nullopt, /*image=*/image_and_tensors.image};
/*image=*/image_and_tensors.image};
} }
} }
bool output_confidence_masks_ = false;
bool output_category_mask_ = false;
}; };
REGISTER_MEDIAPIPE_GRAPH( REGISTER_MEDIAPIPE_GRAPH(

View File

@ -29,7 +29,7 @@ namespace image_segmenter {
struct ImageSegmenterResult { struct ImageSegmenterResult {
// Multiple masks of float image in VEC32F1 format where, for each mask, each // Multiple masks of float image in VEC32F1 format where, for each mask, each
// pixel represents the prediction confidence, usually in the [0, 1] range. // pixel represents the prediction confidence, usually in the [0, 1] range.
std::vector<Image> confidence_masks; std::optional<std::vector<Image>> confidence_masks;
// A category mask of uint8 image in GRAY8 format where each pixel represents // A category mask of uint8 image in GRAY8 format where each pixel represents
// the class which the pixel in the original image was predicted to belong to. // the class which the pixel in the original image was predicted to belong to.
std::optional<Image> category_mask; std::optional<Image> category_mask;

View File

@ -278,6 +278,7 @@ TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
auto options = std::make_unique<ImageSegmenterOptions>(); auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_confidence_masks = false;
options->output_category_mask = true; options->output_category_mask = true;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
@ -306,7 +307,7 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
EXPECT_EQ(result.confidence_masks.size(), 21); EXPECT_EQ(result.confidence_masks->size(), 21);
cv::Mat expected_mask = cv::imread( cv::Mat expected_mask = cv::imread(
JoinPath("./", kTestDataDirectory, "cat_mask.jpg"), cv::IMREAD_GRAYSCALE); JoinPath("./", kTestDataDirectory, "cat_mask.jpg"), cv::IMREAD_GRAYSCALE);
@ -315,7 +316,7 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
// Cat category index 8. // Cat category index 8.
cv::Mat cat_mask = mediapipe::formats::MatView( cv::Mat cat_mask = mediapipe::formats::MatView(
result.confidence_masks[8].GetImageFrameSharedPtr().get()); result.confidence_masks->at(8).GetImageFrameSharedPtr().get());
EXPECT_THAT(cat_mask, EXPECT_THAT(cat_mask,
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
} }
@ -336,7 +337,7 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
image_processing_options.rotation_degrees = -90; image_processing_options.rotation_degrees = -90;
MP_ASSERT_OK_AND_ASSIGN(auto result, MP_ASSERT_OK_AND_ASSIGN(auto result,
segmenter->Segment(image, image_processing_options)); segmenter->Segment(image, image_processing_options));
EXPECT_EQ(result.confidence_masks.size(), 21); EXPECT_EQ(result.confidence_masks->size(), 21);
cv::Mat expected_mask = cv::Mat expected_mask =
cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"), cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"),
@ -346,7 +347,7 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
// Cat category index 8. // Cat category index 8.
cv::Mat cat_mask = mediapipe::formats::MatView( cv::Mat cat_mask = mediapipe::formats::MatView(
result.confidence_masks[8].GetImageFrameSharedPtr().get()); result.confidence_masks->at(8).GetImageFrameSharedPtr().get());
EXPECT_THAT(cat_mask, EXPECT_THAT(cat_mask,
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
} }
@ -384,7 +385,7 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
EXPECT_EQ(result.confidence_masks.size(), 2); EXPECT_EQ(result.confidence_masks->size(), 2);
cv::Mat expected_mask = cv::Mat expected_mask =
cv::imread(JoinPath("./", kTestDataDirectory, cv::imread(JoinPath("./", kTestDataDirectory,
@ -395,7 +396,7 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
// Selfie category index 1. // Selfie category index 1.
cv::Mat selfie_mask = mediapipe::formats::MatView( cv::Mat selfie_mask = mediapipe::formats::MatView(
result.confidence_masks[1].GetImageFrameSharedPtr().get()); result.confidence_masks->at(1).GetImageFrameSharedPtr().get());
EXPECT_THAT(selfie_mask, EXPECT_THAT(selfie_mask,
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
} }
@ -409,7 +410,7 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
EXPECT_EQ(result.confidence_masks.size(), 1); EXPECT_EQ(result.confidence_masks->size(), 1);
cv::Mat expected_mask = cv::Mat expected_mask =
cv::imread(JoinPath("./", kTestDataDirectory, cv::imread(JoinPath("./", kTestDataDirectory,
@ -419,7 +420,7 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
cv::Mat selfie_mask = mediapipe::formats::MatView( cv::Mat selfie_mask = mediapipe::formats::MatView(
result.confidence_masks[0].GetImageFrameSharedPtr().get()); result.confidence_masks->at(0).GetImageFrameSharedPtr().get());
EXPECT_THAT(selfie_mask, EXPECT_THAT(selfie_mask,
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
} }
@ -434,7 +435,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
EXPECT_EQ(result.confidence_masks.size(), 1); EXPECT_EQ(result.confidence_masks->size(), 1);
MP_ASSERT_OK(segmenter->Close()); MP_ASSERT_OK(segmenter->Close());
cv::Mat expected_mask = cv::imread( cv::Mat expected_mask = cv::imread(
@ -445,7 +446,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) {
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
cv::Mat selfie_mask = mediapipe::formats::MatView( cv::Mat selfie_mask = mediapipe::formats::MatView(
result.confidence_masks[0].GetImageFrameSharedPtr().get()); result.confidence_masks->at(0).GetImageFrameSharedPtr().get());
EXPECT_THAT(selfie_mask, EXPECT_THAT(selfie_mask,
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
} }
@ -506,10 +507,10 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
EXPECT_EQ(result.confidence_masks.size(), 2); EXPECT_EQ(result.confidence_masks->size(), 2);
cv::Mat hair_mask = mediapipe::formats::MatView( cv::Mat hair_mask = mediapipe::formats::MatView(
result.confidence_masks[1].GetImageFrameSharedPtr().get()); result.confidence_masks->at(1).GetImageFrameSharedPtr().get());
MP_ASSERT_OK(segmenter->Close()); MP_ASSERT_OK(segmenter->Close());
cv::Mat expected_mask = cv::imread( cv::Mat expected_mask = cv::imread(
JoinPath("./", kTestDataDirectory, "portrait_hair_expected_mask.jpg"), JoinPath("./", kTestDataDirectory, "portrait_hair_expected_mask.jpg"),