Add Interactive Segmenter MediaPipe Task

PiperOrigin-RevId: 516954589
This commit is contained in:
MediaPipe Team 2023-03-15 16:02:48 -07:00 committed by Copybara-Service
parent 43082482f8
commit 61bcddc671
5 changed files with 834 additions and 0 deletions

View File

@ -0,0 +1,76 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# Docs for MediaPipe Tasks Interactive Segmenter
# TODO: add doc link.
# C++ API target for the Interactive Segmenter task: compiles
# interactive_segmenter.cc/.h and depends on the :interactive_segmenter_graph
# target defined below.
cc_library(
name = "interactive_segmenter",
srcs = ["interactive_segmenter.cc"],
hdrs = ["interactive_segmenter.h"],
deps = [
":interactive_segmenter_graph",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components/containers:keypoint",
"//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
"//mediapipe/tasks/cc/vision/core:image_processing_options",
"//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
"//mediapipe/util:color_cc_proto",
"//mediapipe/util:render_data_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
)
# Graph target for the Interactive Segmenter task: compiles
# interactive_segmenter_graph.cc together with the calculators and the
# image_segmenter_graph target it builds upon.
cc_library(
name = "interactive_segmenter_graph",
srcs = ["interactive_segmenter_graph.cc"],
deps = [
"@com_google_absl//absl/strings",
"//mediapipe/calculators/image:set_alpha_calculator",
"//mediapipe/calculators/util:annotation_overlay_calculator",
"//mediapipe/calculators/util:flat_color_image_calculator",
"//mediapipe/calculators/util:flat_color_image_calculator_cc_proto",
"//mediapipe/calculators/util:from_image_calculator",
"//mediapipe/calculators/util:to_image_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
"//mediapipe/util:color_cc_proto",
"//mediapipe/util:label_map_cc_proto",
"//mediapipe/util:render_data_cc_proto",
# The GPU buffer <-> ImageFrame conversion calculators are only added when the
# //mediapipe/gpu:disable_gpu condition does not match.
] + select({
"//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [
"//mediapipe/gpu:gpu_buffer_to_image_frame_calculator",
"//mediapipe/gpu:image_frame_to_gpu_buffer_calculator",
],
}),
# The graph is made available through the REGISTER_MEDIAPIPE_GRAPH registration
# in the source file rather than through directly referenced symbols, so the
# object file must always be linked into dependents.
alwayslink = 1,
)

View File

@ -0,0 +1,163 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h"
#include <utility>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace interactive_segmenter {
namespace {
// Names given to the streams on the boundary of the graph built by
// CreateGraphConfig. The input names are the keys used to feed packets in
// Segment(), and kSegmentationStreamName is the key read back from the
// resulting output packets.
constexpr char kSegmentationStreamName[] = "segmented_mask_out";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kRoiStreamName[] = "roi_in";
constexpr char kNormRectStreamName[] = "norm_rect_in";
// Tags of the graph inputs/outputs and of the task subgraph ports they are
// connected to.
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kImageTag[] = "IMAGE";
constexpr char kRoiTag[] = "ROI";
constexpr char kNormRectTag[] = "NORM_RECT";
// Type name of the single subgraph node added by CreateGraphConfig.
constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::Image;
using ::mediapipe::NormalizedRect;
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
// The interactive segmenter reuses the image segmenter's graph options proto.
using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
image_segmenter::proto::ImageSegmenterGraphOptions;
// Builds a MediaPipe graph config whose only node is a subgraph of type
// "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph".
// The contents of `options` are swapped into the node's options. The IMAGE,
// ROI and NORM_RECT graph inputs and the GROUPED_SEGMENTATION and IMAGE graph
// outputs are bound to the stream names declared at the top of this file.
CalculatorGraphConfig CreateGraphConfig(
    std::unique_ptr<ImageSegmenterGraphOptionsProto> options) {
  api2::builder::Graph graph;
  auto& segmenter_node = graph.AddNode(kSubgraphTypeName);
  segmenter_node.GetOptions<ImageSegmenterGraphOptionsProto>().Swap(
      options.get());

  // Graph inputs: name each stream, then feed it into the subgraph node.
  graph.In(kImageTag).SetName(kImageInStreamName);
  graph.In(kImageTag) >> segmenter_node.In(kImageTag);

  graph.In(kRoiTag).SetName(kRoiStreamName);
  graph.In(kRoiTag) >> segmenter_node.In(kRoiTag);

  graph.In(kNormRectTag).SetName(kNormRectStreamName);
  graph.In(kNormRectTag) >> segmenter_node.In(kNormRectTag);

  // Graph outputs: name each subgraph output stream and expose it.
  segmenter_node.Out(kGroupedSegmentationTag)
          .SetName(kSegmentationStreamName) >>
      graph.Out(kGroupedSegmentationTag);
  segmenter_node.Out(kImageTag).SetName(kImageOutStreamName) >>
      graph.Out(kImageTag);

  return graph.GetConfig();
}
// Translates the user-facing InteractiveSegmenterOptions struct into the
// ImageSegmenterGraphOptions proto used to configure the task graph. The base
// options are converted with ConvertBaseOptionsToProto, and `output_type` is
// mapped onto the SegmenterOptions output type of the same name.
std::unique_ptr<ImageSegmenterGraphOptionsProto>
ConvertImageSegmenterOptionsToProto(InteractiveSegmenterOptions* options) {
  auto graph_options = std::make_unique<ImageSegmenterGraphOptionsProto>();

  tasks::core::proto::BaseOptions converted_base_options =
      tasks::core::ConvertBaseOptionsToProto(&(options->base_options));
  graph_options->mutable_base_options()->Swap(&converted_base_options);

  // The segmenter options are only touched for a recognized output type.
  if (options->output_type ==
      InteractiveSegmenterOptions::OutputType::CATEGORY_MASK) {
    graph_options->mutable_segmenter_options()->set_output_type(
        SegmenterOptions::CATEGORY_MASK);
  } else if (options->output_type ==
             InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK) {
    graph_options->mutable_segmenter_options()->set_output_type(
        SegmenterOptions::CONFIDENCE_MASK);
  }
  return graph_options;
}
// Turns the user-facing RegionOfInterest struct into the RenderData proto fed
// to the subgraph's ROI input.
//
// Returns an InvalidArgument error when the format is UNSPECIFIED, fails the
// RET_CHECK when the format is KEYPOINT but no keypoint is set, and returns an
// Unimplemented error for any other format value.
absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
  if (roi.format == RegionOfInterest::UNSPECIFIED) {
    return absl::InvalidArgumentError(
        "RegionOfInterest format not specified");
  }
  if (roi.format == RegionOfInterest::KEYPOINT) {
    RET_CHECK(roi.keypoint.has_value());
    RenderData render_data;
    auto* keypoint_annotation = render_data.add_render_annotations();
    // Only the red channel of the annotation color is set.
    keypoint_annotation->mutable_color()->set_r(255);
    auto* normalized_point = keypoint_annotation->mutable_point();
    normalized_point->set_normalized(true);
    normalized_point->set_x(roi.keypoint->x);
    normalized_point->set_y(roi.keypoint->y);
    return render_data;
  }
  return absl::UnimplementedError("Unrecognized format");
}
} // namespace
// Creates an InteractiveSegmenter running in IMAGE mode from `options`.
//
// The options are converted into the graph options proto, and the op resolver
// held in `options->base_options` is moved into the created task.
//
// Returns an InvalidArgument error if `options` is null; otherwise returns
// whatever VisionTaskApiFactory::Create returns.
absl::StatusOr<std::unique_ptr<InteractiveSegmenter>>
InteractiveSegmenter::Create(
    std::unique_ptr<InteractiveSegmenterOptions> options) {
  // Guard against a null pointer, which would otherwise be dereferenced below.
  if (options == nullptr) {
    return absl::InvalidArgumentError(
        "InteractiveSegmenterOptions must not be null.");
  }
  auto options_proto = ConvertImageSegmenterOptionsToProto(options.get());
  return core::VisionTaskApiFactory::Create<InteractiveSegmenter,
                                            ImageSegmenterGraphOptionsProto>(
      CreateGraphConfig(std::move(options_proto)),
      std::move(options->base_options.op_resolver), core::RunningMode::IMAGE,
      /*packets_callback=*/nullptr);
}
// Runs interactive segmentation on `image`, guided by `roi`.
//
// Returns an InvalidArgument error for GPU-backed images, and propagates any
// error from converting `image_processing_options` (converted with
// roi_allowed=false) or `roi`, or from running the graph. On success, returns
// the vector of masks read from the grouped segmentation output stream.
absl::StatusOr<std::vector<Image>> InteractiveSegmenter::Segment(
    mediapipe::Image image, const RegionOfInterest& roi,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {
    // The message is a plain literal; no string concatenation is needed.
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  ASSIGN_OR_RETURN(
      NormalizedRect norm_rect,
      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
  ASSIGN_OR_RETURN(RenderData roi_as_render_data, ConvertRoiToRenderData(roi));
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessImageData(
          {{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))},
           {kRoiStreamName,
            mediapipe::MakePacket<RenderData>(std::move(roi_as_render_data))},
           {kNormRectStreamName,
            MakePacket<NormalizedRect>(std::move(norm_rect))}}));
  return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
}
} // namespace interactive_segmenter
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,136 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_INTERACTIVE_SEGMENTER_INTERACTIVE_SEGMENTER_H_
#define MEDIAPIPE_TASKS_CC_VISION_INTERACTIVE_SEGMENTER_INTERACTIVE_SEGMENTER_H_
#include <memory>
#include <optional>
#include <utility>
#include <vector>
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace interactive_segmenter {
// The options for configuring a MediaPipe interactive segmenter task.
struct InteractiveSegmenterOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model
// file with metadata, accelerator options, op resolver, etc.
tasks::core::BaseOptions base_options;
// The output type of segmentation results.
enum OutputType {
// Gives a single output mask where each pixel represents the class which
// the pixel in the original image was predicted to belong to.
CATEGORY_MASK = 0,
// Gives a list of output masks where, for each mask, each pixel represents
// the prediction confidence, usually in the [0, 1] range.
CONFIDENCE_MASK = 1,
};
// The requested output type. Defaults to CATEGORY_MASK.
OutputType output_type = OutputType::CATEGORY_MASK;
};
// The Region-Of-Interest (ROI) to interact with.
struct RegionOfInterest {
// The representations a region-of-interest can be expressed in.
enum Format {
UNSPECIFIED = 0, // Format not specified.
KEYPOINT = 1, // Using keypoint to represent ROI.
};
// Specifies the format used to specify the region-of-interest. Note that
// using `UNSPECIFIED` is invalid and will lead to an `InvalidArgument` status
// being returned. Because `UNSPECIFIED` is the default, callers have to set
// this field explicitly.
Format format = Format::UNSPECIFIED;
// Represents the ROI in keypoint format. This should be non-nullopt if
// `format` is `KEYPOINT`. The x and y values are passed on as normalized
// point coordinates.
std::optional<components::containers::NormalizedKeypoint> keypoint;
};
// Performs interactive segmentation on images.
//
// Users can represent user interaction through `RegionOfInterest`, which gives
// a hint to InteractiveSegmenter to perform segmentation focusing on the given
// region of interest.
//
// The API expects a TFLite model with mandatory TFLite Model Metadata.
//
// Input tensor:
// (kTfLiteUInt8/kTfLiteFloat32)
// - image input of size `[batch x height x width x channels]`.
// - batch inference is not supported (`batch` is required to be 1).
// - RGB inputs are supported (`channels` is required to be 3).
// - if type is kTfLiteFloat32, NormalizationOptions are required to be
// attached to the metadata for input normalization.
// Output tensors:
// (kTfLiteUInt8/kTfLiteFloat32)
// - list of segmented masks.
// - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1.
// - if `output_type` is CONFIDENCE_MASK, float32 Image list of size
// `channels`.
// - batch is always 1
//
// NOTE(review): the base class is inherited without an access specifier, which
// for a `class` means private inheritance — confirm this is intended.
class InteractiveSegmenter : tasks::vision::core::BaseVisionTaskApi {
public:
// Reuses the constructors of BaseVisionTaskApi.
using BaseVisionTaskApi::BaseVisionTaskApi;
// Creates an InteractiveSegmenter from the provided options. A non-default
// OpResolver can be specified in the BaseOptions of
// InteractiveSegmenterOptions, to support custom Ops of the segmentation
// model.
static absl::StatusOr<std::unique_ptr<InteractiveSegmenter>> Create(
std::unique_ptr<InteractiveSegmenterOptions> options);
// Performs image segmentation on the provided single image.
//
// The image can be of any size with format RGB. GPU-backed images are
// currently not supported and result in an invalid argument error.
//
// The `roi` parameter is used to represent user's region of interest for
// segmentation.
//
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing segmentation, by
// setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
//
// If the output_type is CATEGORY_MASK, the returned vector of images
// contains only one category image mask.
// If the output_type is CONFIDENCE_MASK, the returned vector of images
// contains one confidence image mask per channel.
absl::StatusOr<std::vector<mediapipe::Image>> Segment(
mediapipe::Image image, const RegionOfInterest& roi,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Shuts down the InteractiveSegmenter when all the work is done.
absl::Status Close() { return runner_->Close(); }
};
} // namespace interactive_segmenter
} // namespace vision
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_INTERACTIVE_SEGMENTER_INTERACTIVE_SEGMENTER_H_

View File

@ -0,0 +1,198 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/strings/string_view.h"
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/render_data.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace interactive_segmenter {
namespace {
using image_segmenter::proto::ImageSegmenterGraphOptions;
using ::mediapipe::Image;
using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
// Tags of the subgraph's own inputs and outputs.
constexpr char kSegmentationTag[] = "SEGMENTATION";
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kImageTag[] = "IMAGE";
// Tags used on the calculators inside the subgraph; which one applies depends
// on whether the GPU or the CPU path is selected.
constexpr char kImageCpuTag[] = "IMAGE_CPU";
constexpr char kImageGpuTag[] = "IMAGE_GPU";
constexpr char kAlphaTag[] = "ALPHA";
constexpr char kAlphaGpuTag[] = "ALPHA_GPU";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kRoiTag[] = "ROI";
// NOTE(review): kVideoTag is not referenced anywhere in the visible part of
// this file.
constexpr char kVideoTag[] = "VIDEO";
// Adds to `graph` the nodes that draw `roi` onto a blank canvas with the same
// dimensions as `image`, and returns the rendered stream. When `use_gpu` is
// true the returned `Source` is in GpuBuffer format, otherwise in ImageFrame
// format.
Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
                    Graph& graph) {
  // TODO: Replace with efficient implementation.
  // FromImageCalculator is read through IMAGE_GPU / IMAGE_CPU, whereas
  // AnnotationOverlayCalculator is wired through IMAGE_GPU / IMAGE.
  const absl::string_view from_image_out_tag =
      use_gpu ? kImageGpuTag : kImageCpuTag;
  const char* const overlay_tag = use_gpu ? kImageGpuTag : kImageTag;

  // A blank canvas with the same size as the input image. SetAlphaCalculator
  // only takes the 1st channel, so only the red component is set.
  auto& canvas_node = graph.AddNode("FlatColorImageCalculator");
  canvas_node.GetOptions<FlatColorImageCalculatorOptions>()
      .mutable_color()
      ->set_r(0);
  image >> canvas_node.In(kImageTag)[0];

  // Converts the canvas from Image to the CPU or GPU representation.
  auto& canvas_converter = graph.AddNode("FromImageCalculator");
  canvas_node.Out(kImageTag)[0] >> canvas_converter.In(kImageTag);

  // Renders the region of interest on top of the canvas.
  auto& overlay_node = graph.AddNode("AnnotationOverlayCalculator");
  canvas_converter.Out(from_image_out_tag) >> overlay_node.In(overlay_tag);
  roi >> overlay_node.In(0);

  auto rendered_roi = overlay_node.Out(overlay_tag);
  return rendered_roi;
}
} // namespace
// A "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"
// performs semantic segmentation given user's region-of-interest. Two kinds of
// outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION. Users can
// retrieve segmented mask of only particular category/channel from
// SEGMENTATION, and users can also get all segmented masks from
// GROUPED_SEGMENTATION.
// - Accepts CPU input images and outputs segmented masks on CPU.
//
// Inputs:
// IMAGE - Image
// Image to perform segmentation on.
// ROI - RenderData proto
// Region of interest based on user interaction. Currently only support
// Point format, and Color has to be (255, 255, 255).
// NOTE(review): the C++ API in interactive_segmenter.cc only sets the red
// channel of the annotation color to 255 — confirm which one is required.
// NORM_RECT - NormalizedRect @Optional
// Describes image rotation and region of image to perform segmentation
// on.
// @Optional: rect covering the whole image is used if not specified.
//
// Outputs:
// SEGMENTATION - mediapipe::Image @Multiple
// Segmented masks for individual category. Segmented mask of single
// category can be accessed by index based output stream.
// GROUPED_SEGMENTATION - std::vector<mediapipe::Image>
// The output segmented masks grouped in a vector.
// IMAGE - mediapipe::Image
// The image that image segmenter runs on.
//
// Example:
// node {
// calculator:
// "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"
// input_stream: "IMAGE:image"
// input_stream: "ROI:region_of_interest"
// output_stream: "SEGMENTATION:segmented_masks"
// options {
// [mediapipe.tasks.vision.image_segmenter.proto.ImageSegmenterGraphOptions.ext]
// {
// base_options {
// model_asset {
// file_name: "/path/to/model.tflite"
// }
// }
// segmenter_options {
// output_type: CONFIDENCE_MASK
// }
// }
// }
// }
class InteractiveSegmenterGraph : public core::ModelTaskGraph {
public:
// Builds the subgraph config: the ROI is rendered and written into the alpha
// channel of the input image, and the result is fed to an ImageSegmenterGraph
// node that receives a copy of this graph's options.
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
mediapipe::SubgraphContext* sc) override {
Graph graph;
const auto& task_options = sc->Options<ImageSegmenterGraphOptions>();
// Whether the GPU path is used, as derived from the acceleration settings of
// the base options by DetermineImagePreprocessingGpuBackend.
bool use_gpu =
components::processors::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration());
Source<Image> image = graph[Input<Image>(kImageTag)];
Source<RenderData> roi = graph[Input<RenderData>(kRoiTag)];
Source<NormalizedRect> norm_rect =
graph[Input<NormalizedRect>(kNormRectTag)];
// FromImageCalculator/ToImageCalculator are wired through IMAGE_GPU or
// IMAGE_CPU, while SetAlphaCalculator is wired through IMAGE_GPU or IMAGE and
// ALPHA_GPU or ALPHA.
const absl::string_view image_tag_with_suffix =
use_gpu ? kImageGpuTag : kImageCpuTag;
const absl::string_view alpha_tag_with_suffix =
use_gpu ? kAlphaGpuTag : kAlphaTag;
// Converts the input Image into its CPU or GPU representation.
auto& from_mp_image = graph.AddNode("FromImageCalculator");
image >> from_mp_image.In(kImageTag);
auto image_in_cpu_or_gpu = from_mp_image.Out(image_tag_with_suffix);
// Renders the ROI on a blank canvas with the same size as the image.
auto alpha_in_cpu_or_gpu = RoiToAlpha(image, roi, use_gpu, graph);
// Combines the image with the rendered ROI as its alpha channel.
auto& set_alpha = graph.AddNode("SetAlphaCalculator");
image_in_cpu_or_gpu >> set_alpha.In(use_gpu ? kImageGpuTag : kImageTag);
alpha_in_cpu_or_gpu >> set_alpha.In(alpha_tag_with_suffix);
auto image_in_cpu_or_gpu_with_set_alpha =
set_alpha.Out(use_gpu ? kImageGpuTag : kImageTag);
// Converts the combined image back into an Image.
auto& to_mp_image = graph.AddNode("ToImageCalculator");
image_in_cpu_or_gpu_with_set_alpha >> to_mp_image.In(image_tag_with_suffix);
auto image_with_set_alpha = to_mp_image.Out(kImageTag);
// Runs the image segmenter subgraph on the combined image, using the same
// options as this graph.
auto& image_segmenter = graph.AddNode(
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph");
image_segmenter.GetOptions<ImageSegmenterGraphOptions>() = task_options;
image_with_set_alpha >> image_segmenter.In(kImageTag);
norm_rect >> image_segmenter.In(kNormRectTag);
// Forwards the image segmenter outputs as the outputs of this graph.
image_segmenter.Out(kSegmentationTag) >>
graph[Output<Image>(kSegmentationTag)];
image_segmenter.Out(kGroupedSegmentationTag) >>
graph[Output<std::vector<Image>>(kGroupedSegmentationTag)];
image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)];
return graph.GetConfig();
}
};
// REGISTER_MEDIAPIPE_GRAPH argument has to fit on one line to work properly.
// clang-format off
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::interactive_segmenter::InteractiveSegmenterGraph);
// clang-format on
} // namespace interactive_segmenter
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,261 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h"
#include <cstdint>
#include <memory>
#include "absl/flags/flag.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/mutable_op_resolver.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace interactive_segmenter {
namespace {
using ::mediapipe::Image;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::RectF;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr;
using ::testing::Optional;
// Test data directory and file names; the tests join them with a leading "./".
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kPtmModel[] = "ptm_512_hdt_ptm_woid.tflite";
constexpr char kCatsAndDogsJpg[] = "cats_and_dogs.jpg";
// NOTE(review): not referenced by any test in the visible part of this file.
constexpr float kGoldenMaskSimilarity = 0.98;
// Magnification factor used when creating the golden category masks to make
// them more human-friendly. Each pixel in the golden masks has its value
// multiplied by this factor, i.e. a value of 10 means class index 1, a value of
// 20 means class index 2, etc.
// NOTE(review): not referenced by any test in the visible part of this file.
constexpr int kGoldenMaskMagnificationFactor = 10;
// Round-trips `mask` through CV_8UC1 (scaled by 255) and back to CV_32FC1
// (scaled by 1/255). The expected outputs are stored as CV_8UC1, so applying
// the same quantization to the actual output makes the comparison fair.
cv::Mat PostProcessResultMask(const cv::Mat& mask) {
  cv::Mat quantized;
  mask.convertTo(quantized, CV_8UC1, 255);
  cv::Mat restored;
  quantized.convertTo(restored, CV_32FC1, 1 / 255.f);
  return restored;
}
// Returns the sum of all elements of `m`, accumulated over all its channels.
double CalculateSum(const cv::Mat& m) {
  const cv::Scalar per_channel_sums = cv::sum(m);
  const int channel_count = m.channels();
  double total = 0.0;
  int channel = 0;
  while (channel < channel_count) {
    total += per_channel_sums.val[channel];
    ++channel;
  }
  return total;
}
// Computes a soft intersection-over-union of `m1` and `m2`:
// sum(m1 * m2) / (sum(m1 * m1) + sum(m2 * m2) - sum(m1 * m2)), using
// element-wise products. Returns 0 when the denominator is not positive.
double CalculateSoftIOU(const cv::Mat& m1, const cv::Mat& m2) {
  cv::Mat elementwise_product;
  cv::multiply(m1, m2, elementwise_product);
  const double overlap = CalculateSum(elementwise_product);
  const double m1_squared_sum = CalculateSum(m1.mul(m1));
  const double m2_squared_sum = CalculateSum(m2.mul(m2));
  const double combined = m1_squared_sum + m2_squared_sum - overlap;
  if (combined > 0.0) {
    return overlap / combined;
  }
  return 0.0;
}
// Matches a float mask that has the same dimensions as `expected_mask` and
// whose soft IOU with it exceeds `similarity_threshold`. The actual mask is
// first round-tripped through PostProcessResultMask so that it is quantized
// the same way as the stored expected mask before being compared.
MATCHER_P2(SimilarToFloatMask, expected_mask, similarity_threshold, "") {
  const cv::Mat actual_mask = PostProcessResultMask(arg);
  return actual_mask.rows == expected_mask.rows &&
         actual_mask.cols == expected_mask.cols &&
         CalculateSoftIOU(actual_mask, expected_mask) > similarity_threshold;
}
// Matches a uint8 mask that has the same dimensions as `expected_mask` and for
// which the fraction of pixels satisfying
// `actual * magnification_factor == expected` is at least
// `similarity_threshold`.
MATCHER_P3(SimilarToUint8Mask, expected_mask, similarity_threshold,
           magnification_factor, "") {
  const bool same_shape =
      arg.rows == expected_mask.rows && arg.cols == expected_mask.cols;
  if (!same_shape) {
    return false;
  }
  const int total_pixels = expected_mask.rows * expected_mask.cols;
  int matching_pixels = 0;
  for (int idx = 0; idx < total_pixels; ++idx) {
    if (arg.data[idx] * magnification_factor == expected_mask.data[idx]) {
      ++matching_pixels;
    }
  }
  const float matching_ratio =
      static_cast<float>(matching_pixels) / total_pixels;
  return matching_ratio >= similarity_threshold;
}
// Fixture for the tests that exercise InteractiveSegmenter::Create.
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
// Op resolver that registers only the builtin ADD op. It is used by
// FailsWithSelectiveOpResolverMissingOps, which expects segmenter creation to
// fail with it.
class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver {
public:
DeepLabOpResolverMissingOps() {
AddBuiltin(::tflite::BuiltinOperator_ADD,
::tflite::ops::builtin::Register_ADD());
}
// Copy construction is disabled.
DeepLabOpResolverMissingOps(const DeepLabOpResolverMissingOps& r) = delete;
};
// Creating a segmenter with an op resolver that only provides the ADD op is
// expected to fail with an internal error coming from interpreter creation.
TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
  auto options = std::make_unique<InteractiveSegmenterOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kPtmModel);
  // std::make_unique matches the rest of this file and only needs <memory>,
  // which is already included.
  options->base_options.op_resolver =
      std::make_unique<DeepLabOpResolverMissingOps>();
  auto segmenter_or = InteractiveSegmenter::Create(std::move(options));
  // TODO: Make MediaPipe InferenceCalculator report the detailed
  // interpreter errors (e.g., "Encountered unresolved custom op").
  EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInternal);
  EXPECT_THAT(segmenter_or.status().message(),
              HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk"));
}
// Creating a segmenter from default options (no model specified) is expected
// to fail with an InvalidArgument error that carries the runner
// initialization error payload.
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
  auto default_options = std::make_unique<InteractiveSegmenterOptions>();
  absl::StatusOr<std::unique_ptr<InteractiveSegmenter>> segmenter_or =
      InteractiveSegmenter::Create(std::move(default_options));

  EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(
      segmenter_or.status().message(),
      HasSubstr("ExternalFile must specify at least one of 'file_content', "
                "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
  EXPECT_THAT(segmenter_or.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInitializationError))));
}
class ImageModeTest : public tflite_shims::testing::Test {};
TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
RegionOfInterest interaction_roi;
interaction_roi.format = RegionOfInterest::KEYPOINT;
interaction_roi.keypoint =
components::containers::NormalizedKeypoint{0.25, 0.9};
auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel);
options->output_type = InteractiveSegmenterOptions::OutputType::CATEGORY_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
InteractiveSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto category_masks,
segmenter->Segment(image, interaction_roi));
EXPECT_EQ(category_masks.size(), 1);
}
// Segmenting with CONFIDENCE_MASK output is expected to yield two masks for
// the test model.
TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
  const auto image_path = JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg);
  MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(image_path));

  // Keypoint-based region of interest.
  RegionOfInterest keypoint_roi;
  keypoint_roi.keypoint =
      components::containers::NormalizedKeypoint{0.25, 0.9};
  keypoint_roi.format = RegionOfInterest::KEYPOINT;

  auto segmenter_options = std::make_unique<InteractiveSegmenterOptions>();
  segmenter_options->output_type =
      InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
  segmenter_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kPtmModel);

  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<InteractiveSegmenter> segmenter,
      InteractiveSegmenter::Create(std::move(segmenter_options)));
  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks,
                          segmenter->Segment(image, keypoint_roi));
  EXPECT_EQ(confidence_masks.size(), 2);
}
// TODO: fix this unit test after image segmenter handled post
// processing correctly with rotated image.
//
// Segments with CONFIDENCE_MASK output while requesting a -90 degree rotation
// through the image processing options; two masks are expected.
TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
  const auto image_path = JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg);
  MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(image_path));

  // Keypoint-based region of interest.
  RegionOfInterest keypoint_roi;
  keypoint_roi.keypoint =
      components::containers::NormalizedKeypoint{0.25, 0.9};
  keypoint_roi.format = RegionOfInterest::KEYPOINT;

  auto segmenter_options = std::make_unique<InteractiveSegmenterOptions>();
  segmenter_options->output_type =
      InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
  segmenter_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kPtmModel);

  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<InteractiveSegmenter> segmenter,
      InteractiveSegmenter::Create(std::move(segmenter_options)));

  ImageProcessingOptions rotation_options;
  rotation_options.rotation_degrees = -90;
  MP_ASSERT_OK_AND_ASSIGN(
      auto confidence_masks,
      segmenter->Segment(image, keypoint_roi, rotation_options));
  EXPECT_EQ(confidence_masks.size(), 2);
}
// Passing a region-of-interest through the image processing options is
// expected to be rejected with an InvalidArgument error that carries the image
// processing invalid argument payload.
TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
  const auto image_path = JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg);
  MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(image_path));

  // Keypoint-based region of interest for the interaction itself.
  RegionOfInterest keypoint_roi;
  keypoint_roi.keypoint =
      components::containers::NormalizedKeypoint{0.25, 0.9};
  keypoint_roi.format = RegionOfInterest::KEYPOINT;

  auto segmenter_options = std::make_unique<InteractiveSegmenterOptions>();
  segmenter_options->output_type =
      InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
  segmenter_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kPtmModel);

  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<InteractiveSegmenter> segmenter,
      InteractiveSegmenter::Create(std::move(segmenter_options)));

  // Image processing options that carry a crop rectangle.
  RectF crop_rect{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
  ImageProcessingOptions options_with_crop{crop_rect, /*rotation_degrees=*/0};

  auto results = segmenter->Segment(image, keypoint_roi, options_with_crop);
  EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(results.status().message(),
              HasSubstr("This task doesn't support region-of-interest"));
  EXPECT_THAT(
      results.status().GetPayload(kMediaPipeTasksPayload),
      Optional(absl::Cord(absl::StrCat(
          MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
}
} // namespace
} // namespace interactive_segmenter
} // namespace vision
} // namespace tasks
} // namespace mediapipe