Add Interactive Segmenter MediaPipe Task
PiperOrigin-RevId: 516954589
This commit is contained in:
parent
43082482f8
commit
61bcddc671
76
mediapipe/tasks/cc/vision/interactive_segmenter/BUILD
Normal file
76
mediapipe/tasks/cc/vision/interactive_segmenter/BUILD
Normal file
|
@ -0,0 +1,76 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
# Docs for Mediapipe Tasks Interactive Segmenter
|
||||
# TODO: add doc link.
|
||||
cc_library(
|
||||
name = "interactive_segmenter",
|
||||
srcs = ["interactive_segmenter.cc"],
|
||||
hdrs = ["interactive_segmenter.h"],
|
||||
deps = [
|
||||
":interactive_segmenter_graph",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers:keypoint",
|
||||
"//mediapipe/tasks/cc/core:base_options",
|
||||
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
|
||||
"//mediapipe/tasks/cc/vision/core:image_processing_options",
|
||||
"//mediapipe/tasks/cc/vision/core:running_mode",
|
||||
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
|
||||
"//mediapipe/util:color_cc_proto",
|
||||
"//mediapipe/util:render_data_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "interactive_segmenter_graph",
|
||||
srcs = ["interactive_segmenter_graph.cc"],
|
||||
deps = [
|
||||
"@com_google_absl//absl/strings",
|
||||
"//mediapipe/calculators/image:set_alpha_calculator",
|
||||
"//mediapipe/calculators/util:annotation_overlay_calculator",
|
||||
"//mediapipe/calculators/util:flat_color_image_calculator",
|
||||
"//mediapipe/calculators/util:flat_color_image_calculator_cc_proto",
|
||||
"//mediapipe/calculators/util:from_image_calculator",
|
||||
"//mediapipe/calculators/util:to_image_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
|
||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
|
||||
"//mediapipe/util:color_cc_proto",
|
||||
"//mediapipe/util:label_map_cc_proto",
|
||||
"//mediapipe/util:render_data_cc_proto",
|
||||
] + select({
|
||||
"//mediapipe/gpu:disable_gpu": [],
|
||||
"//conditions:default": [
|
||||
"//mediapipe/gpu:gpu_buffer_to_image_frame_calculator",
|
||||
"//mediapipe/gpu:image_frame_to_gpu_buffer_calculator",
|
||||
],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
)
|
|
@ -0,0 +1,163 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
|
||||
#include "mediapipe/util/color.pb.h"
|
||||
#include "mediapipe/util/render_data.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace interactive_segmenter {
|
||||
namespace {
|
||||
|
||||
constexpr char kSegmentationStreamName[] = "segmented_mask_out";
|
||||
constexpr char kImageInStreamName[] = "image_in";
|
||||
constexpr char kImageOutStreamName[] = "image_out";
|
||||
constexpr char kRoiStreamName[] = "roi_in";
|
||||
constexpr char kNormRectStreamName[] = "norm_rect_in";
|
||||
|
||||
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
|
||||
constexpr char kImageTag[] = "IMAGE";
|
||||
constexpr char kRoiTag[] = "ROI";
|
||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||
|
||||
constexpr char kSubgraphTypeName[] =
|
||||
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
|
||||
|
||||
using ::mediapipe::CalculatorGraphConfig;
|
||||
using ::mediapipe::Image;
|
||||
using ::mediapipe::NormalizedRect;
|
||||
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
|
||||
using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
|
||||
image_segmenter::proto::ImageSegmenterGraphOptions;
|
||||
|
||||
// Creates a MediaPipe graph config that only contains a single subgraph node of
|
||||
// "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph".
|
||||
CalculatorGraphConfig CreateGraphConfig(
|
||||
std::unique_ptr<ImageSegmenterGraphOptionsProto> options) {
|
||||
api2::builder::Graph graph;
|
||||
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
|
||||
task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap(
|
||||
options.get());
|
||||
graph.In(kImageTag).SetName(kImageInStreamName);
|
||||
graph.In(kRoiTag).SetName(kRoiStreamName);
|
||||
graph.In(kNormRectTag).SetName(kNormRectStreamName);
|
||||
task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
|
||||
graph.Out(kGroupedSegmentationTag);
|
||||
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
|
||||
graph.Out(kImageTag);
|
||||
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
|
||||
graph.In(kRoiTag) >> task_subgraph.In(kRoiTag);
|
||||
graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
// Converts the user-facing InteractiveSegmenterOptions struct to the internal
|
||||
// ImageSegmenterOptions proto.
|
||||
std::unique_ptr<ImageSegmenterGraphOptionsProto>
|
||||
ConvertImageSegmenterOptionsToProto(InteractiveSegmenterOptions* options) {
|
||||
auto options_proto = std::make_unique<ImageSegmenterGraphOptionsProto>();
|
||||
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
||||
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
||||
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
||||
switch (options->output_type) {
|
||||
case InteractiveSegmenterOptions::OutputType::CATEGORY_MASK:
|
||||
options_proto->mutable_segmenter_options()->set_output_type(
|
||||
SegmenterOptions::CATEGORY_MASK);
|
||||
break;
|
||||
case InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK:
|
||||
options_proto->mutable_segmenter_options()->set_output_type(
|
||||
SegmenterOptions::CONFIDENCE_MASK);
|
||||
break;
|
||||
}
|
||||
return options_proto;
|
||||
}
|
||||
|
||||
// Converts the user-facing RegionOfInterest struct to the RenderData proto that
|
||||
// is used in subgraph.
|
||||
absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
|
||||
RenderData result;
|
||||
switch (roi.format) {
|
||||
case RegionOfInterest::UNSPECIFIED:
|
||||
return absl::InvalidArgumentError(
|
||||
"RegionOfInterest format not specified");
|
||||
case RegionOfInterest::KEYPOINT:
|
||||
RET_CHECK(roi.keypoint.has_value());
|
||||
auto* annotation = result.add_render_annotations();
|
||||
annotation->mutable_color()->set_r(255);
|
||||
auto* point = annotation->mutable_point();
|
||||
point->set_normalized(true);
|
||||
point->set_x(roi.keypoint->x);
|
||||
point->set_y(roi.keypoint->y);
|
||||
return result;
|
||||
}
|
||||
return absl::UnimplementedError("Unrecognized format");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
absl::StatusOr<std::unique_ptr<InteractiveSegmenter>>
|
||||
InteractiveSegmenter::Create(
|
||||
std::unique_ptr<InteractiveSegmenterOptions> options) {
|
||||
auto options_proto = ConvertImageSegmenterOptionsToProto(options.get());
|
||||
return core::VisionTaskApiFactory::Create<InteractiveSegmenter,
|
||||
ImageSegmenterGraphOptionsProto>(
|
||||
CreateGraphConfig(std::move(options_proto)),
|
||||
std::move(options->base_options.op_resolver), core::RunningMode::IMAGE,
|
||||
/*packets_callback=*/nullptr);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::vector<Image>> InteractiveSegmenter::Segment(
|
||||
mediapipe::Image image, const RegionOfInterest& roi,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||
if (image.UsesGpu()) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::StrCat("GPU input images are currently not supported."),
|
||||
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
|
||||
}
|
||||
ASSIGN_OR_RETURN(
|
||||
NormalizedRect norm_rect,
|
||||
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
|
||||
ASSIGN_OR_RETURN(RenderData roi_as_render_data, ConvertRoiToRenderData(roi));
|
||||
ASSIGN_OR_RETURN(
|
||||
auto output_packets,
|
||||
ProcessImageData(
|
||||
{{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))},
|
||||
{kRoiStreamName,
|
||||
mediapipe::MakePacket<RenderData>(std::move(roi_as_render_data))},
|
||||
{kNormRectStreamName,
|
||||
MakePacket<NormalizedRect>(std::move(norm_rect))}}));
|
||||
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
|
||||
}
|
||||
|
||||
} // namespace interactive_segmenter
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,136 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_VISION_INTERACTIVE_SEGMENTER_INTERACTIVE_SEGMENTER_H_
|
||||
#define MEDIAPIPE_TASKS_CC_VISION_INTERACTIVE_SEGMENTER_INTERACTIVE_SEGMENTER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace interactive_segmenter {
|
||||
|
||||
// The options for configuring a mediapipe interactive segmenter task.
|
||||
struct InteractiveSegmenterOptions {
|
||||
// Base options for configuring MediaPipe Tasks, such as specifying the model
|
||||
// file with metadata, accelerator options, op resolver, etc.
|
||||
tasks::core::BaseOptions base_options;
|
||||
|
||||
// The output type of segmentation results.
|
||||
enum OutputType {
|
||||
// Gives a single output mask where each pixel represents the class which
|
||||
// the pixel in the original image was predicted to belong to.
|
||||
CATEGORY_MASK = 0,
|
||||
// Gives a list of output masks where, for each mask, each pixel represents
|
||||
// the prediction confidence, usually in the [0, 1] range.
|
||||
CONFIDENCE_MASK = 1,
|
||||
};
|
||||
|
||||
OutputType output_type = OutputType::CATEGORY_MASK;
|
||||
};
|
||||
|
||||
// The Region-Of-Interest (ROI) to interact with.
|
||||
struct RegionOfInterest {
|
||||
enum Format {
|
||||
UNSPECIFIED = 0, // Format not specified.
|
||||
KEYPOINT = 1, // Using keypoint to represent ROI.
|
||||
};
|
||||
|
||||
// Specifies the format used to specify the region-of-interest. Note that
|
||||
// using `UNSPECIFIED` is invalid and will lead to an `InvalidArgument` status
|
||||
// being returned.
|
||||
Format format = Format::UNSPECIFIED;
|
||||
|
||||
// Represents the ROI in keypoint format, this should be non-nullopt if
|
||||
// `format` is `KEYPOINT`.
|
||||
std::optional<components::containers::NormalizedKeypoint> keypoint;
|
||||
};
|
||||
|
||||
// Performs interactive segmentation on images.
|
||||
//
|
||||
// Users can represent user interaction through `RegionOfInterest`, which gives
|
||||
// a hint to InteractiveSegmenter to perform segmentation focusing on the given
|
||||
// region of interest.
|
||||
//
|
||||
// The API expects a TFLite model with mandatory TFLite Model Metadata.
|
||||
//
|
||||
// Input tensor:
|
||||
// (kTfLiteUInt8/kTfLiteFloat32)
|
||||
// - image input of size `[batch x height x width x channels]`.
|
||||
// - batch inference is not supported (`batch` is required to be 1).
|
||||
// - RGB inputs is supported (`channels` is required to be 3).
|
||||
// - if type is kTfLiteFloat32, NormalizationOptions are required to be
|
||||
// attached to the metadata for input normalization.
|
||||
// Output tensors:
|
||||
// (kTfLiteUInt8/kTfLiteFloat32)
|
||||
// - list of segmented masks.
|
||||
// - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1.
|
||||
// - if `output_type` is CONFIDENCE_MASK, float32 Image list of size
|
||||
// `channels`.
|
||||
// - batch is always 1
|
||||
class InteractiveSegmenter : tasks::vision::core::BaseVisionTaskApi {
|
||||
public:
|
||||
using BaseVisionTaskApi::BaseVisionTaskApi;
|
||||
|
||||
// Creates an InteractiveSegmenter from the provided options. A non-default
|
||||
// OpResolver can be specified in the BaseOptions of
|
||||
// InteractiveSegmenterOptions, to support custom Ops of the segmentation
|
||||
// model.
|
||||
static absl::StatusOr<std::unique_ptr<InteractiveSegmenter>> Create(
|
||||
std::unique_ptr<InteractiveSegmenterOptions> options);
|
||||
|
||||
// Performs image segmentation on the provided single image.
|
||||
//
|
||||
// The image can be of any size with format RGB.
|
||||
//
|
||||
// The `roi` parameter is used to represent user's region of interest for
|
||||
// segmentation.
|
||||
//
|
||||
// The optional 'image_processing_options' parameter can be used to specify
|
||||
// the rotation to apply to the image before performing segmentation, by
|
||||
// setting its 'rotation_degrees' field. Note that specifying a
|
||||
// region-of-interest using the 'region_of_interest' field is NOT supported
|
||||
// and will result in an invalid argument error being returned.
|
||||
//
|
||||
// If the output_type is CATEGORY_MASK, the returned vector of images is
|
||||
// per-category segmented image mask.
|
||||
// If the output_type is CONFIDENCE_MASK, the returned vector of images
|
||||
// contains only one confidence image mask.
|
||||
absl::StatusOr<std::vector<mediapipe::Image>> Segment(
|
||||
mediapipe::Image image, const RegionOfInterest& roi,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||
std::nullopt);
|
||||
|
||||
// Shuts down the InteractiveSegmenter when all works are done.
|
||||
absl::Status Close() { return runner_->Close(); }
|
||||
};
|
||||
|
||||
} // namespace interactive_segmenter
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_VISION_INTERACTIVE_SEGMENTER_INTERACTIVE_SEGMENTER_H_
|
|
@ -0,0 +1,198 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
|
||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||
#include "mediapipe/util/color.pb.h"
|
||||
#include "mediapipe/util/label_map.pb.h"
|
||||
#include "mediapipe/util/render_data.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace interactive_segmenter {
|
||||
|
||||
namespace {
|
||||
|
||||
using image_segmenter::proto::ImageSegmenterGraphOptions;
|
||||
using ::mediapipe::Image;
|
||||
using ::mediapipe::NormalizedRect;
|
||||
using ::mediapipe::api2::Input;
|
||||
using ::mediapipe::api2::Output;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
|
||||
constexpr char kSegmentationTag[] = "SEGMENTATION";
|
||||
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
|
||||
constexpr char kImageTag[] = "IMAGE";
|
||||
constexpr char kImageCpuTag[] = "IMAGE_CPU";
|
||||
constexpr char kImageGpuTag[] = "IMAGE_GPU";
|
||||
constexpr char kAlphaTag[] = "ALPHA";
|
||||
constexpr char kAlphaGpuTag[] = "ALPHA_GPU";
|
||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||
constexpr char kRoiTag[] = "ROI";
|
||||
constexpr char kVideoTag[] = "VIDEO";
|
||||
|
||||
// Updates the graph to return `roi` stream which has same dimension as
|
||||
// `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is
|
||||
// in GpuBuffer format, otherwise using ImageFrame.
|
||||
Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
|
||||
Graph& graph) {
|
||||
// TODO: Replace with efficient implementation.
|
||||
const absl::string_view image_tag_with_suffix =
|
||||
use_gpu ? kImageGpuTag : kImageCpuTag;
|
||||
|
||||
// Generates a blank canvas with same size as input image.
|
||||
auto& flat_color = graph.AddNode("FlatColorImageCalculator");
|
||||
auto& flat_color_options =
|
||||
flat_color.GetOptions<FlatColorImageCalculatorOptions>();
|
||||
// SetAlphaCalculator only takes 1st channel.
|
||||
flat_color_options.mutable_color()->set_r(0);
|
||||
image >> flat_color.In(kImageTag)[0];
|
||||
auto blank_canvas = flat_color.Out(kImageTag)[0];
|
||||
|
||||
auto& from_mp_image = graph.AddNode("FromImageCalculator");
|
||||
blank_canvas >> from_mp_image.In(kImageTag);
|
||||
auto blank_canvas_in_cpu_or_gpu = from_mp_image.Out(image_tag_with_suffix);
|
||||
|
||||
auto& roi_to_alpha = graph.AddNode("AnnotationOverlayCalculator");
|
||||
blank_canvas_in_cpu_or_gpu >>
|
||||
roi_to_alpha.In(use_gpu ? kImageGpuTag : kImageTag);
|
||||
roi >> roi_to_alpha.In(0);
|
||||
auto alpha = roi_to_alpha.Out(use_gpu ? kImageGpuTag : kImageTag);
|
||||
|
||||
return alpha;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// An "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"
|
||||
// performs semantic segmentation given user's region-of-interest. Two kinds of
|
||||
// outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION. Users can
|
||||
// retrieve segmented mask of only particular category/channel from
|
||||
// SEGMENTATION, and users can also get all segmented masks from
|
||||
// GROUPED_SEGMENTATION.
|
||||
// - Accepts CPU input images and outputs segmented masks on CPU.
|
||||
//
|
||||
// Inputs:
|
||||
// IMAGE - Image
|
||||
// Image to perform segmentation on.
|
||||
// ROI - RenderData proto
|
||||
// Region of interest based on user interaction. Currently only support
|
||||
// Point format, and Color has to be (255, 255, 255).
|
||||
// NORM_RECT - NormalizedRect @Optional
|
||||
// Describes image rotation and region of image to perform detection
|
||||
// on.
|
||||
// @Optional: rect covering the whole image is used if not specified.
|
||||
//
|
||||
// Outputs:
|
||||
// SEGMENTATION - mediapipe::Image @Multiple
|
||||
// Segmented masks for individual category. Segmented mask of single
|
||||
// category can be accessed by index based output stream.
|
||||
// GROUPED_SEGMENTATION - std::vector<mediapipe::Image>
|
||||
// The output segmented masks grouped in a vector.
|
||||
// IMAGE - mediapipe::Image
|
||||
// The image that image segmenter runs on.
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator:
|
||||
// "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"
|
||||
// input_stream: "IMAGE:image"
|
||||
// input_stream: "ROI:region_of_interest"
|
||||
// output_stream: "SEGMENTATION:segmented_masks"
|
||||
// options {
|
||||
// [mediapipe.tasks.vision.image_segmenter.proto.ImageSegmenterGraphOptions.ext]
|
||||
// {
|
||||
// base_options {
|
||||
// model_asset {
|
||||
// file_name: "/path/to/model.tflite"
|
||||
// }
|
||||
// }
|
||||
// segmenter_options {
|
||||
// output_type: CONFIDENCE_MASK
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class InteractiveSegmenterGraph : public core::ModelTaskGraph {
|
||||
public:
|
||||
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
|
||||
mediapipe::SubgraphContext* sc) override {
|
||||
Graph graph;
|
||||
const auto& task_options = sc->Options<ImageSegmenterGraphOptions>();
|
||||
bool use_gpu =
|
||||
components::processors::DetermineImagePreprocessingGpuBackend(
|
||||
task_options.base_options().acceleration());
|
||||
|
||||
Source<Image> image = graph[Input<Image>(kImageTag)];
|
||||
Source<RenderData> roi = graph[Input<RenderData>(kRoiTag)];
|
||||
Source<NormalizedRect> norm_rect =
|
||||
graph[Input<NormalizedRect>(kNormRectTag)];
|
||||
const absl::string_view image_tag_with_suffix =
|
||||
use_gpu ? kImageGpuTag : kImageCpuTag;
|
||||
const absl::string_view alpha_tag_with_suffix =
|
||||
use_gpu ? kAlphaGpuTag : kAlphaTag;
|
||||
|
||||
auto& from_mp_image = graph.AddNode("FromImageCalculator");
|
||||
image >> from_mp_image.In(kImageTag);
|
||||
auto image_in_cpu_or_gpu = from_mp_image.Out(image_tag_with_suffix);
|
||||
|
||||
auto alpha_in_cpu_or_gpu = RoiToAlpha(image, roi, use_gpu, graph);
|
||||
|
||||
auto& set_alpha = graph.AddNode("SetAlphaCalculator");
|
||||
image_in_cpu_or_gpu >> set_alpha.In(use_gpu ? kImageGpuTag : kImageTag);
|
||||
alpha_in_cpu_or_gpu >> set_alpha.In(alpha_tag_with_suffix);
|
||||
auto image_in_cpu_or_gpu_with_set_alpha =
|
||||
set_alpha.Out(use_gpu ? kImageGpuTag : kImageTag);
|
||||
|
||||
auto& to_mp_image = graph.AddNode("ToImageCalculator");
|
||||
image_in_cpu_or_gpu_with_set_alpha >> to_mp_image.In(image_tag_with_suffix);
|
||||
auto image_with_set_alpha = to_mp_image.Out(kImageTag);
|
||||
|
||||
auto& image_segmenter = graph.AddNode(
|
||||
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph");
|
||||
image_segmenter.GetOptions<ImageSegmenterGraphOptions>() = task_options;
|
||||
image_with_set_alpha >> image_segmenter.In(kImageTag);
|
||||
norm_rect >> image_segmenter.In(kNormRectTag);
|
||||
|
||||
image_segmenter.Out(kSegmentationTag) >>
|
||||
graph[Output<Image>(kSegmentationTag)];
|
||||
image_segmenter.Out(kGroupedSegmentationTag) >>
|
||||
graph[Output<std::vector<Image>>(kGroupedSegmentationTag)];
|
||||
image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)];
|
||||
|
||||
return graph.GetConfig();
|
||||
}
|
||||
};
|
||||
|
||||
// REGISTER_MEDIAPIPE_GRAPH argument has to fit on one line to work properly.
|
||||
// clang-format off
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::vision::interactive_segmenter::InteractiveSegmenterGraph);
|
||||
// clang-format on
|
||||
|
||||
} // namespace interactive_segmenter
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,261 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/opencv_core_inc.h"
|
||||
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace interactive_segmenter {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::Image;
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::mediapipe::tasks::components::containers::RectF;
|
||||
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::Optional;
|
||||
|
||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
|
||||
constexpr char kPtmModel[] = "ptm_512_hdt_ptm_woid.tflite";
|
||||
constexpr char kCatsAndDogsJpg[] = "cats_and_dogs.jpg";
|
||||
|
||||
constexpr float kGoldenMaskSimilarity = 0.98;
|
||||
|
||||
// Magnification factor used when creating the golden category masks to make
|
||||
// them more human-friendly. Each pixel in the golden masks has its value
|
||||
// multiplied by this factor, i.e. a value of 10 means class index 1, a value of
|
||||
// 20 means class index 2, etc.
|
||||
constexpr int kGoldenMaskMagnificationFactor = 10;
|
||||
|
||||
// Intentionally converting output into CV_8UC1 and then again into CV_32FC1
|
||||
// as expected outputs are stored in CV_8UC1, so this conversion allows to do
|
||||
// fair comparison.
|
||||
cv::Mat PostProcessResultMask(const cv::Mat& mask) {
|
||||
cv::Mat mask_float;
|
||||
mask.convertTo(mask_float, CV_8UC1, 255);
|
||||
mask_float.convertTo(mask_float, CV_32FC1, 1 / 255.f);
|
||||
return mask_float;
|
||||
}
|
||||
|
||||
double CalculateSum(const cv::Mat& m) {
|
||||
double sum = 0.0;
|
||||
cv::Scalar s = cv::sum(m);
|
||||
for (int i = 0; i < m.channels(); ++i) {
|
||||
sum += s.val[i];
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
double CalculateSoftIOU(const cv::Mat& m1, const cv::Mat& m2) {
|
||||
cv::Mat intersection;
|
||||
cv::multiply(m1, m2, intersection);
|
||||
double intersection_value = CalculateSum(intersection);
|
||||
double union_value =
|
||||
CalculateSum(m1.mul(m1)) + CalculateSum(m2.mul(m2)) - intersection_value;
|
||||
return union_value > 0.0 ? intersection_value / union_value : 0.0;
|
||||
}
|
||||
|
||||
MATCHER_P2(SimilarToFloatMask, expected_mask, similarity_threshold, "") {
|
||||
cv::Mat actual_mask = PostProcessResultMask(arg);
|
||||
return arg.rows == expected_mask.rows && arg.cols == expected_mask.cols &&
|
||||
CalculateSoftIOU(arg, expected_mask) > similarity_threshold;
|
||||
}
|
||||
|
||||
MATCHER_P3(SimilarToUint8Mask, expected_mask, similarity_threshold,
|
||||
magnification_factor, "") {
|
||||
if (arg.rows != expected_mask.rows || arg.cols != expected_mask.cols) {
|
||||
return false;
|
||||
}
|
||||
int consistent_pixels = 0;
|
||||
const int num_pixels = expected_mask.rows * expected_mask.cols;
|
||||
for (int i = 0; i < num_pixels; ++i) {
|
||||
consistent_pixels +=
|
||||
(arg.data[i] * magnification_factor == expected_mask.data[i]);
|
||||
}
|
||||
return static_cast<float>(consistent_pixels) / num_pixels >=
|
||||
similarity_threshold;
|
||||
}
|
||||
|
||||
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
|
||||
|
||||
class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver {
|
||||
public:
|
||||
DeepLabOpResolverMissingOps() {
|
||||
AddBuiltin(::tflite::BuiltinOperator_ADD,
|
||||
::tflite::ops::builtin::Register_ADD());
|
||||
}
|
||||
|
||||
DeepLabOpResolverMissingOps(const DeepLabOpResolverMissingOps& r) = delete;
|
||||
};
|
||||
|
||||
TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
|
||||
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||
options->base_options.op_resolver =
|
||||
absl::make_unique<DeepLabOpResolverMissingOps>();
|
||||
auto segmenter_or = InteractiveSegmenter::Create(std::move(options));
|
||||
// TODO: Make MediaPipe InferenceCalculator report the detailed
|
||||
// interpreter errors (e.g., "Encountered unresolved custom op").
|
||||
EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInternal);
|
||||
EXPECT_THAT(
|
||||
segmenter_or.status().message(),
|
||||
testing::HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk"));
|
||||
}
|
||||
|
||||
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
|
||||
absl::StatusOr<std::unique_ptr<InteractiveSegmenter>> segmenter_or =
|
||||
InteractiveSegmenter::Create(
|
||||
std::make_unique<InteractiveSegmenterOptions>());
|
||||
|
||||
EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInvalidArgument);
|
||||
EXPECT_THAT(
|
||||
segmenter_or.status().message(),
|
||||
HasSubstr("ExternalFile must specify at least one of 'file_content', "
|
||||
"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
|
||||
EXPECT_THAT(segmenter_or.status().GetPayload(kMediaPipeTasksPayload),
|
||||
Optional(absl::Cord(absl::StrCat(
|
||||
MediaPipeTasksStatus::kRunnerInitializationError))));
|
||||
}
|
||||
|
||||
class ImageModeTest : public tflite_shims::testing::Test {};
|
||||
|
||||
TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
Image image,
|
||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
|
||||
RegionOfInterest interaction_roi;
|
||||
interaction_roi.format = RegionOfInterest::KEYPOINT;
|
||||
interaction_roi.keypoint =
|
||||
components::containers::NormalizedKeypoint{0.25, 0.9};
|
||||
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||
options->output_type = InteractiveSegmenterOptions::OutputType::CATEGORY_MASK;
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
|
||||
InteractiveSegmenter::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto category_masks,
|
||||
segmenter->Segment(image, interaction_roi));
|
||||
EXPECT_EQ(category_masks.size(), 1);
|
||||
}
|
||||
|
||||
TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
Image image,
|
||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
|
||||
RegionOfInterest interaction_roi;
|
||||
interaction_roi.format = RegionOfInterest::KEYPOINT;
|
||||
interaction_roi.keypoint =
|
||||
components::containers::NormalizedKeypoint{0.25, 0.9};
|
||||
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||
options->output_type =
|
||||
InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
|
||||
InteractiveSegmenter::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks,
|
||||
segmenter->Segment(image, interaction_roi));
|
||||
EXPECT_EQ(confidence_masks.size(), 2);
|
||||
}
|
||||
|
||||
// TODO: fix this unit test after image segmenter handled post
|
||||
// processing correctly with rotated image.
|
||||
TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
Image image,
|
||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
|
||||
RegionOfInterest interaction_roi;
|
||||
interaction_roi.format = RegionOfInterest::KEYPOINT;
|
||||
interaction_roi.keypoint =
|
||||
components::containers::NormalizedKeypoint{0.25, 0.9};
|
||||
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||
options->output_type =
|
||||
InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
|
||||
InteractiveSegmenter::Create(std::move(options)));
|
||||
ImageProcessingOptions image_processing_options;
|
||||
image_processing_options.rotation_degrees = -90;
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto confidence_masks,
|
||||
segmenter->Segment(image, interaction_roi, image_processing_options));
|
||||
EXPECT_EQ(confidence_masks.size(), 2);
|
||||
}
|
||||
|
||||
TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
Image image,
|
||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
|
||||
RegionOfInterest interaction_roi;
|
||||
interaction_roi.format = RegionOfInterest::KEYPOINT;
|
||||
interaction_roi.keypoint =
|
||||
components::containers::NormalizedKeypoint{0.25, 0.9};
|
||||
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||
options->output_type =
|
||||
InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
|
||||
InteractiveSegmenter::Create(std::move(options)));
|
||||
RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
|
||||
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
||||
|
||||
auto results =
|
||||
segmenter->Segment(image, interaction_roi, image_processing_options);
|
||||
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
|
||||
EXPECT_THAT(results.status().message(),
|
||||
HasSubstr("This task doesn't support region-of-interest"));
|
||||
EXPECT_THAT(
|
||||
results.status().GetPayload(kMediaPipeTasksPayload),
|
||||
Optional(absl::Cord(absl::StrCat(
|
||||
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace interactive_segmenter
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
Loading…
Reference in New Issue
Block a user