Add Interactive Segmenter MediaPipe Task
PiperOrigin-RevId: 516954589
This commit is contained in:
parent
43082482f8
commit
61bcddc671
76
mediapipe/tasks/cc/vision/interactive_segmenter/BUILD
Normal file
76
mediapipe/tasks/cc/vision/interactive_segmenter/BUILD
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
# Docs for Mediapipe Tasks Interactive Segmenter
# TODO: add doc link.
# C++ task API target: compiles interactive_segmenter.cc / interactive_segmenter.h
# and depends on the graph target defined in this package.
cc_library(
    name = "interactive_segmenter",
    srcs = ["interactive_segmenter.cc"],
    hdrs = ["interactive_segmenter.h"],
    deps = [
        ":interactive_segmenter_graph",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/tasks/cc/components/containers:keypoint",
        "//mediapipe/tasks/cc/core:base_options",
        "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
        "//mediapipe/tasks/cc/vision/core:image_processing_options",
        "//mediapipe/tasks/cc/vision/core:running_mode",
        "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
        "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
        "//mediapipe/util:color_cc_proto",
        "//mediapipe/util:render_data_cc_proto",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
)
|
||||||
|
|
||||||
|
# Subgraph target: compiles interactive_segmenter_graph.cc, which registers
# the InteractiveSegmenterGraph via REGISTER_MEDIAPIPE_GRAPH.
cc_library(
    name = "interactive_segmenter_graph",
    srcs = ["interactive_segmenter_graph.cc"],
    deps = [
        "//mediapipe/calculators/image:set_alpha_calculator",
        "//mediapipe/calculators/util:annotation_overlay_calculator",
        "//mediapipe/calculators/util:flat_color_image_calculator",
        "//mediapipe/calculators/util:flat_color_image_calculator_cc_proto",
        "//mediapipe/calculators/util:from_image_calculator",
        "//mediapipe/calculators/util:to_image_calculator",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
        "//mediapipe/tasks/cc/core:model_task_graph",
        "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
        "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
        "//mediapipe/util:color_cc_proto",
        "//mediapipe/util:label_map_cc_proto",
        "//mediapipe/util:render_data_cc_proto",
        "@com_google_absl//absl/strings",
    ] + select({
        # The GPU <-> ImageFrame conversion calculators are only added when
        # GPU support is not disabled.
        "//mediapipe/gpu:disable_gpu": [],
        "//conditions:default": [
            "//mediapipe/gpu:gpu_buffer_to_image_frame_calculator",
            "//mediapipe/gpu:image_frame_to_gpu_buffer_calculator",
        ],
    }),
    # The graph is only reachable through its static registration, so the
    # object file must be linked even when no symbol is referenced directly.
    alwayslink = 1,
)
|
|
@ -0,0 +1,163 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h"
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/formats/image.h"
|
||||||
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
|
||||||
|
#include "mediapipe/util/color.pb.h"
|
||||||
|
#include "mediapipe/util/render_data.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace vision {
|
||||||
|
namespace interactive_segmenter {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Names assigned to the input/output streams of the task graph built by
// CreateGraphConfig; Segment() uses them to feed and read packets.
constexpr char kSegmentationStreamName[] = "segmented_mask_out";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kRoiStreamName[] = "roi_in";
constexpr char kNormRectStreamName[] = "norm_rect_in";

// Tags used on the graph boundary and on the subgraph node.
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kImageTag[] = "IMAGE";
constexpr char kRoiTag[] = "ROI";
constexpr char kNormRectTag[] = "NORM_RECT";

// Type name of the single subgraph node added by CreateGraphConfig.
constexpr char kSubgraphTypeName[] =
    "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
|
||||||
|
|
||||||
|
using ::mediapipe::CalculatorGraphConfig;
|
||||||
|
using ::mediapipe::Image;
|
||||||
|
using ::mediapipe::NormalizedRect;
|
||||||
|
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
|
||||||
|
using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
|
||||||
|
image_segmenter::proto::ImageSegmenterGraphOptions;
|
||||||
|
|
||||||
|
// Creates a MediaPipe graph config that only contains a single subgraph node of
|
||||||
|
// "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph".
|
||||||
|
CalculatorGraphConfig CreateGraphConfig(
|
||||||
|
std::unique_ptr<ImageSegmenterGraphOptionsProto> options) {
|
||||||
|
api2::builder::Graph graph;
|
||||||
|
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
|
||||||
|
task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap(
|
||||||
|
options.get());
|
||||||
|
graph.In(kImageTag).SetName(kImageInStreamName);
|
||||||
|
graph.In(kRoiTag).SetName(kRoiStreamName);
|
||||||
|
graph.In(kNormRectTag).SetName(kNormRectStreamName);
|
||||||
|
task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
|
||||||
|
graph.Out(kGroupedSegmentationTag);
|
||||||
|
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
|
||||||
|
graph.Out(kImageTag);
|
||||||
|
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
|
||||||
|
graph.In(kRoiTag) >> task_subgraph.In(kRoiTag);
|
||||||
|
graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
|
||||||
|
return graph.GetConfig();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Converts the user-facing InteractiveSegmenterOptions struct to the internal
|
||||||
|
// ImageSegmenterOptions proto.
|
||||||
|
std::unique_ptr<ImageSegmenterGraphOptionsProto>
|
||||||
|
ConvertImageSegmenterOptionsToProto(InteractiveSegmenterOptions* options) {
|
||||||
|
auto options_proto = std::make_unique<ImageSegmenterGraphOptionsProto>();
|
||||||
|
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
||||||
|
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
||||||
|
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
||||||
|
switch (options->output_type) {
|
||||||
|
case InteractiveSegmenterOptions::OutputType::CATEGORY_MASK:
|
||||||
|
options_proto->mutable_segmenter_options()->set_output_type(
|
||||||
|
SegmenterOptions::CATEGORY_MASK);
|
||||||
|
break;
|
||||||
|
case InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK:
|
||||||
|
options_proto->mutable_segmenter_options()->set_output_type(
|
||||||
|
SegmenterOptions::CONFIDENCE_MASK);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return options_proto;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Converts the user-facing RegionOfInterest struct to the RenderData proto that
|
||||||
|
// is used in subgraph.
|
||||||
|
absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
|
||||||
|
RenderData result;
|
||||||
|
switch (roi.format) {
|
||||||
|
case RegionOfInterest::UNSPECIFIED:
|
||||||
|
return absl::InvalidArgumentError(
|
||||||
|
"RegionOfInterest format not specified");
|
||||||
|
case RegionOfInterest::KEYPOINT:
|
||||||
|
RET_CHECK(roi.keypoint.has_value());
|
||||||
|
auto* annotation = result.add_render_annotations();
|
||||||
|
annotation->mutable_color()->set_r(255);
|
||||||
|
auto* point = annotation->mutable_point();
|
||||||
|
point->set_normalized(true);
|
||||||
|
point->set_x(roi.keypoint->x);
|
||||||
|
point->set_y(roi.keypoint->y);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
return absl::UnimplementedError("Unrecognized format");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Creates an InteractiveSegmenter from `options`.
//
// The options are converted to the graph options proto, wrapped into a
// single-node task graph, and handed to the vision task factory in IMAGE
// running mode without a packets callback. Returns an InvalidArgument status
// when `options` is null.
absl::StatusOr<std::unique_ptr<InteractiveSegmenter>>
InteractiveSegmenter::Create(
    std::unique_ptr<InteractiveSegmenterOptions> options) {
  // Guard against a null pointer before it is dereferenced below.
  if (options == nullptr) {
    return absl::InvalidArgumentError(
        "InteractiveSegmenterOptions must not be null.");
  }
  auto options_proto = ConvertImageSegmenterOptionsToProto(options.get());
  return core::VisionTaskApiFactory::Create<InteractiveSegmenter,
                                            ImageSegmenterGraphOptionsProto>(
      CreateGraphConfig(std::move(options_proto)),
      std::move(options->base_options.op_resolver), core::RunningMode::IMAGE,
      /*packets_callback=*/nullptr);
}
|
||||||
|
|
||||||
|
// Runs interactive segmentation on `image`, guided by `roi`.
//
// GPU-backed images are rejected with an InvalidArgument status. Errors from
// converting `image_processing_options` (a region-of-interest is not allowed
// there) or `roi`, and from running the graph, are propagated to the caller.
// On success, returns the masks read from the segmentation output stream.
absl::StatusOr<std::vector<Image>> InteractiveSegmenter::Segment(
    mediapipe::Image image, const RegionOfInterest& roi,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {
    // The message is a plain literal, so it is passed directly instead of
    // going through absl::StrCat.
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  ASSIGN_OR_RETURN(
      NormalizedRect norm_rect,
      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
  ASSIGN_OR_RETURN(RenderData roi_as_render_data, ConvertRoiToRenderData(roi));
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessImageData(
          {{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))},
           {kRoiStreamName,
            mediapipe::MakePacket<RenderData>(std::move(roi_as_render_data))},
           {kNormRectStreamName,
            mediapipe::MakePacket<NormalizedRect>(std::move(norm_rect))}}));
  return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
}
|
||||||
|
|
||||||
|
} // namespace interactive_segmenter
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
|
@ -0,0 +1,136 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_VISION_INTERACTIVE_SEGMENTER_INTERACTIVE_SEGMENTER_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_VISION_INTERACTIVE_SEGMENTER_INTERACTIVE_SEGMENTER_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <optional>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "mediapipe/framework/formats/image.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace vision {
|
||||||
|
namespace interactive_segmenter {
|
||||||
|
|
||||||
|
// The options for configuring a mediapipe interactive segmenter task.
struct InteractiveSegmenterOptions {
  // Base options for configuring MediaPipe Tasks, such as specifying the model
  // file with metadata, accelerator options, op resolver, etc.
  tasks::core::BaseOptions base_options;

  // The output type of segmentation results.
  enum OutputType {
    // Gives a single output mask where each pixel represents the class which
    // the pixel in the original image was predicted to belong to.
    CATEGORY_MASK = 0,
    // Gives a list of output masks where, for each mask, each pixel represents
    // the prediction confidence, usually in the [0, 1] range.
    CONFIDENCE_MASK = 1,
  };

  // Selects which kind of mask the task produces; defaults to CATEGORY_MASK.
  OutputType output_type = OutputType::CATEGORY_MASK;
};
|
||||||
|
|
||||||
|
// The Region-Of-Interest (ROI) to interact with.
struct RegionOfInterest {
  // The supported ways of describing a region of interest.
  enum Format {
    UNSPECIFIED = 0,  // Format not specified.
    KEYPOINT = 1,     // Using keypoint to represent ROI.
  };

  // Specifies the format used to specify the region-of-interest. Note that
  // using `UNSPECIFIED` is invalid and will lead to an `InvalidArgument` status
  // being returned.
  Format format = Format::UNSPECIFIED;

  // Represents the ROI in keypoint format, this should be non-nullopt if
  // `format` is `KEYPOINT`.
  std::optional<components::containers::NormalizedKeypoint> keypoint;
};
|
||||||
|
|
||||||
|
// Performs interactive segmentation on images.
|
||||||
|
//
|
||||||
|
// Users can represent user interaction through `RegionOfInterest`, which gives
|
||||||
|
// a hint to InteractiveSegmenter to perform segmentation focusing on the given
|
||||||
|
// region of interest.
|
||||||
|
//
|
||||||
|
// The API expects a TFLite model with mandatory TFLite Model Metadata.
|
||||||
|
//
|
||||||
|
// Input tensor:
|
||||||
|
// (kTfLiteUInt8/kTfLiteFloat32)
|
||||||
|
// - image input of size `[batch x height x width x channels]`.
|
||||||
|
// - batch inference is not supported (`batch` is required to be 1).
|
||||||
|
// - RGB inputs is supported (`channels` is required to be 3).
|
||||||
|
// - if type is kTfLiteFloat32, NormalizationOptions are required to be
|
||||||
|
// attached to the metadata for input normalization.
|
||||||
|
// Output tensors:
|
||||||
|
// (kTfLiteUInt8/kTfLiteFloat32)
|
||||||
|
// - list of segmented masks.
|
||||||
|
// - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1.
|
||||||
|
// - if `output_type` is CONFIDENCE_MASK, float32 Image list of size
|
||||||
|
// `channels`.
|
||||||
|
// - batch is always 1
|
||||||
|
class InteractiveSegmenter : tasks::vision::core::BaseVisionTaskApi {
|
||||||
|
public:
|
||||||
|
using BaseVisionTaskApi::BaseVisionTaskApi;
|
||||||
|
|
||||||
|
// Creates an InteractiveSegmenter from the provided options. A non-default
|
||||||
|
// OpResolver can be specified in the BaseOptions of
|
||||||
|
// InteractiveSegmenterOptions, to support custom Ops of the segmentation
|
||||||
|
// model.
|
||||||
|
static absl::StatusOr<std::unique_ptr<InteractiveSegmenter>> Create(
|
||||||
|
std::unique_ptr<InteractiveSegmenterOptions> options);
|
||||||
|
|
||||||
|
// Performs image segmentation on the provided single image.
|
||||||
|
//
|
||||||
|
// The image can be of any size with format RGB.
|
||||||
|
//
|
||||||
|
// The `roi` parameter is used to represent user's region of interest for
|
||||||
|
// segmentation.
|
||||||
|
//
|
||||||
|
// The optional 'image_processing_options' parameter can be used to specify
|
||||||
|
// the rotation to apply to the image before performing segmentation, by
|
||||||
|
// setting its 'rotation_degrees' field. Note that specifying a
|
||||||
|
// region-of-interest using the 'region_of_interest' field is NOT supported
|
||||||
|
// and will result in an invalid argument error being returned.
|
||||||
|
//
|
||||||
|
// If the output_type is CATEGORY_MASK, the returned vector of images is
|
||||||
|
// per-category segmented image mask.
|
||||||
|
// If the output_type is CONFIDENCE_MASK, the returned vector of images
|
||||||
|
// contains only one confidence image mask.
|
||||||
|
absl::StatusOr<std::vector<mediapipe::Image>> Segment(
|
||||||
|
mediapipe::Image image, const RegionOfInterest& roi,
|
||||||
|
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||||
|
std::nullopt);
|
||||||
|
|
||||||
|
// Shuts down the InteractiveSegmenter when all works are done.
|
||||||
|
absl::Status Close() { return runner_->Close(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace interactive_segmenter
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_VISION_INTERACTIVE_SEGMENTER_INTERACTIVE_SEGMENTER_H_
|
|
@ -0,0 +1,198 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/image.h"
|
||||||
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||||
|
#include "mediapipe/util/color.pb.h"
|
||||||
|
#include "mediapipe/util/label_map.pb.h"
|
||||||
|
#include "mediapipe/util/render_data.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace vision {
|
||||||
|
namespace interactive_segmenter {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using image_segmenter::proto::ImageSegmenterGraphOptions;
|
||||||
|
using ::mediapipe::Image;
|
||||||
|
using ::mediapipe::NormalizedRect;
|
||||||
|
using ::mediapipe::api2::Input;
|
||||||
|
using ::mediapipe::api2::Output;
|
||||||
|
using ::mediapipe::api2::builder::Graph;
|
||||||
|
using ::mediapipe::api2::builder::Source;
|
||||||
|
|
||||||
|
// Stream tags used by the graph boundary and by the calculators wired together
// in this file.
constexpr char kSegmentationTag[] = "SEGMENTATION";
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kImageTag[] = "IMAGE";
// CPU/GPU variants of the image tag; the one to use is picked from `use_gpu`.
constexpr char kImageCpuTag[] = "IMAGE_CPU";
constexpr char kImageGpuTag[] = "IMAGE_GPU";
// CPU/GPU variants of the alpha input tag of SetAlphaCalculator.
constexpr char kAlphaTag[] = "ALPHA";
constexpr char kAlphaGpuTag[] = "ALPHA_GPU";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kRoiTag[] = "ROI";
// NOTE(review): not referenced in the visible part of this file — confirm
// whether it is still needed.
constexpr char kVideoTag[] = "VIDEO";
|
||||||
|
|
||||||
|
// Updates the graph to return `roi` stream which has same dimension as
|
||||||
|
// `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is
|
||||||
|
// in GpuBuffer format, otherwise using ImageFrame.
|
||||||
|
Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
|
||||||
|
Graph& graph) {
|
||||||
|
// TODO: Replace with efficient implementation.
|
||||||
|
const absl::string_view image_tag_with_suffix =
|
||||||
|
use_gpu ? kImageGpuTag : kImageCpuTag;
|
||||||
|
|
||||||
|
// Generates a blank canvas with same size as input image.
|
||||||
|
auto& flat_color = graph.AddNode("FlatColorImageCalculator");
|
||||||
|
auto& flat_color_options =
|
||||||
|
flat_color.GetOptions<FlatColorImageCalculatorOptions>();
|
||||||
|
// SetAlphaCalculator only takes 1st channel.
|
||||||
|
flat_color_options.mutable_color()->set_r(0);
|
||||||
|
image >> flat_color.In(kImageTag)[0];
|
||||||
|
auto blank_canvas = flat_color.Out(kImageTag)[0];
|
||||||
|
|
||||||
|
auto& from_mp_image = graph.AddNode("FromImageCalculator");
|
||||||
|
blank_canvas >> from_mp_image.In(kImageTag);
|
||||||
|
auto blank_canvas_in_cpu_or_gpu = from_mp_image.Out(image_tag_with_suffix);
|
||||||
|
|
||||||
|
auto& roi_to_alpha = graph.AddNode("AnnotationOverlayCalculator");
|
||||||
|
blank_canvas_in_cpu_or_gpu >>
|
||||||
|
roi_to_alpha.In(use_gpu ? kImageGpuTag : kImageTag);
|
||||||
|
roi >> roi_to_alpha.In(0);
|
||||||
|
auto alpha = roi_to_alpha.Out(use_gpu ? kImageGpuTag : kImageTag);
|
||||||
|
|
||||||
|
return alpha;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// An "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"
|
||||||
|
// performs semantic segmentation given user's region-of-interest. Two kinds of
|
||||||
|
// outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION. Users can
|
||||||
|
// retrieve segmented mask of only particular category/channel from
|
||||||
|
// SEGMENTATION, and users can also get all segmented masks from
|
||||||
|
// GROUPED_SEGMENTATION.
|
||||||
|
// - Accepts CPU input images and outputs segmented masks on CPU.
|
||||||
|
//
|
||||||
|
// Inputs:
|
||||||
|
// IMAGE - Image
|
||||||
|
// Image to perform segmentation on.
|
||||||
|
// ROI - RenderData proto
|
||||||
|
// Region of interest based on user interaction. Currently only support
|
||||||
|
// Point format, and Color has to be (255, 255, 255).
|
||||||
|
// NORM_RECT - NormalizedRect @Optional
|
||||||
|
// Describes image rotation and region of image to perform detection
|
||||||
|
// on.
|
||||||
|
// @Optional: rect covering the whole image is used if not specified.
|
||||||
|
//
|
||||||
|
// Outputs:
|
||||||
|
// SEGMENTATION - mediapipe::Image @Multiple
|
||||||
|
// Segmented masks for individual category. Segmented mask of single
|
||||||
|
// category can be accessed by index based output stream.
|
||||||
|
// GROUPED_SEGMENTATION - std::vector<mediapipe::Image>
|
||||||
|
// The output segmented masks grouped in a vector.
|
||||||
|
// IMAGE - mediapipe::Image
|
||||||
|
// The image that image segmenter runs on.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// node {
|
||||||
|
// calculator:
|
||||||
|
// "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"
|
||||||
|
// input_stream: "IMAGE:image"
|
||||||
|
// input_stream: "ROI:region_of_interest"
|
||||||
|
// output_stream: "SEGMENTATION:segmented_masks"
|
||||||
|
// options {
|
||||||
|
// [mediapipe.tasks.vision.image_segmenter.proto.ImageSegmenterGraphOptions.ext]
|
||||||
|
// {
|
||||||
|
// base_options {
|
||||||
|
// model_asset {
|
||||||
|
// file_name: "/path/to/model.tflite"
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// segmenter_options {
|
||||||
|
// output_type: CONFIDENCE_MASK
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
class InteractiveSegmenterGraph : public core::ModelTaskGraph {
|
||||||
|
public:
|
||||||
|
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
|
||||||
|
mediapipe::SubgraphContext* sc) override {
|
||||||
|
Graph graph;
|
||||||
|
const auto& task_options = sc->Options<ImageSegmenterGraphOptions>();
|
||||||
|
bool use_gpu =
|
||||||
|
components::processors::DetermineImagePreprocessingGpuBackend(
|
||||||
|
task_options.base_options().acceleration());
|
||||||
|
|
||||||
|
Source<Image> image = graph[Input<Image>(kImageTag)];
|
||||||
|
Source<RenderData> roi = graph[Input<RenderData>(kRoiTag)];
|
||||||
|
Source<NormalizedRect> norm_rect =
|
||||||
|
graph[Input<NormalizedRect>(kNormRectTag)];
|
||||||
|
const absl::string_view image_tag_with_suffix =
|
||||||
|
use_gpu ? kImageGpuTag : kImageCpuTag;
|
||||||
|
const absl::string_view alpha_tag_with_suffix =
|
||||||
|
use_gpu ? kAlphaGpuTag : kAlphaTag;
|
||||||
|
|
||||||
|
auto& from_mp_image = graph.AddNode("FromImageCalculator");
|
||||||
|
image >> from_mp_image.In(kImageTag);
|
||||||
|
auto image_in_cpu_or_gpu = from_mp_image.Out(image_tag_with_suffix);
|
||||||
|
|
||||||
|
auto alpha_in_cpu_or_gpu = RoiToAlpha(image, roi, use_gpu, graph);
|
||||||
|
|
||||||
|
auto& set_alpha = graph.AddNode("SetAlphaCalculator");
|
||||||
|
image_in_cpu_or_gpu >> set_alpha.In(use_gpu ? kImageGpuTag : kImageTag);
|
||||||
|
alpha_in_cpu_or_gpu >> set_alpha.In(alpha_tag_with_suffix);
|
||||||
|
auto image_in_cpu_or_gpu_with_set_alpha =
|
||||||
|
set_alpha.Out(use_gpu ? kImageGpuTag : kImageTag);
|
||||||
|
|
||||||
|
auto& to_mp_image = graph.AddNode("ToImageCalculator");
|
||||||
|
image_in_cpu_or_gpu_with_set_alpha >> to_mp_image.In(image_tag_with_suffix);
|
||||||
|
auto image_with_set_alpha = to_mp_image.Out(kImageTag);
|
||||||
|
|
||||||
|
auto& image_segmenter = graph.AddNode(
|
||||||
|
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph");
|
||||||
|
image_segmenter.GetOptions<ImageSegmenterGraphOptions>() = task_options;
|
||||||
|
image_with_set_alpha >> image_segmenter.In(kImageTag);
|
||||||
|
norm_rect >> image_segmenter.In(kNormRectTag);
|
||||||
|
|
||||||
|
image_segmenter.Out(kSegmentationTag) >>
|
||||||
|
graph[Output<Image>(kSegmentationTag)];
|
||||||
|
image_segmenter.Out(kGroupedSegmentationTag) >>
|
||||||
|
graph[Output<std::vector<Image>>(kGroupedSegmentationTag)];
|
||||||
|
image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)];
|
||||||
|
|
||||||
|
return graph.GetConfig();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// REGISTER_MEDIAPIPE_GRAPH argument has to fit on one line to work properly.
|
||||||
|
// clang-format off
|
||||||
|
REGISTER_MEDIAPIPE_GRAPH(
|
||||||
|
::mediapipe::tasks::vision::interactive_segmenter::InteractiveSegmenterGraph);
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
} // namespace interactive_segmenter
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
|
@ -0,0 +1,261 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "absl/flags/flag.h"
|
||||||
|
#include "mediapipe/framework/deps/file_path.h"
|
||||||
|
#include "mediapipe/framework/formats/image.h"
|
||||||
|
#include "mediapipe/framework/formats/image_frame.h"
|
||||||
|
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/opencv_core_inc.h"
|
||||||
|
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
|
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||||
|
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||||
|
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace vision {
|
||||||
|
namespace interactive_segmenter {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::Image;
|
||||||
|
using ::mediapipe::file::JoinPath;
|
||||||
|
using ::mediapipe::tasks::components::containers::RectF;
|
||||||
|
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
||||||
|
using ::testing::HasSubstr;
|
||||||
|
using ::testing::Optional;
|
||||||
|
|
||||||
|
// Directory (joined onto "./" by the tests) that holds the test assets.
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
// File name of the segmentation model loaded by the tests.
constexpr char kPtmModel[] = "ptm_512_hdt_ptm_woid.tflite";
// File name of the input image decoded by the tests.
constexpr char kCatsAndDogsJpg[] = "cats_and_dogs.jpg";

// Similarity threshold intended for golden-mask comparisons.
// NOTE(review): not referenced by any test in this part of the file.
constexpr float kGoldenMaskSimilarity = 0.98f;

// Scale applied to class indices when the golden category masks were written,
// so that the masks are easier for a human to inspect: a stored pixel value of
// 10 stands for class index 1, 20 for class index 2, and so on.
constexpr int kGoldenMaskMagnificationFactor = 10;
|
||||||
|
|
||||||
|
// Intentionally converting output into CV_8UC1 and then again into CV_32FC1
|
||||||
|
// as expected outputs are stored in CV_8UC1, so this conversion allows to do
|
||||||
|
// fair comparison.
|
||||||
|
cv::Mat PostProcessResultMask(const cv::Mat& mask) {
|
||||||
|
cv::Mat mask_float;
|
||||||
|
mask.convertTo(mask_float, CV_8UC1, 255);
|
||||||
|
mask_float.convertTo(mask_float, CV_32FC1, 1 / 255.f);
|
||||||
|
return mask_float;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the sum of every element of `m`, accumulated over all of its
// channels.
double CalculateSum(const cv::Mat& m) {
  // cv::sum yields one partial sum per channel.
  const cv::Scalar per_channel = cv::sum(m);
  const int channel_count = m.channels();

  double total = 0.0;
  int channel = 0;
  while (channel < channel_count) {
    total += per_channel.val[channel];
    ++channel;
  }
  return total;
}
|
||||||
|
|
||||||
|
// Computes a soft intersection-over-union score between two masks:
//
//   sum(m1 * m2) / (sum(m1 * m1) + sum(m2 * m2) - sum(m1 * m2))
//
// where all products are element-wise. Returns 0.0 when the denominator is
// not strictly positive.
double CalculateSoftIOU(const cv::Mat& m1, const cv::Mat& m2) {
  cv::Mat elementwise_product;
  cv::multiply(m1, m2, elementwise_product);
  const double overlap = CalculateSum(elementwise_product);

  const double m1_squared_sum = CalculateSum(m1.mul(m1));
  const double m2_squared_sum = CalculateSum(m2.mul(m2));
  const double combined = m1_squared_sum + m2_squared_sum - overlap;

  if (combined > 0.0) {
    return overlap / combined;
  }
  return 0.0;
}
|
||||||
|
|
||||||
|
// Matches a float mask (`arg`) whose dimensions equal those of
// `expected_mask` and whose soft IOU against `expected_mask` is strictly
// greater than `similarity_threshold`.
//
// The actual mask is passed through PostProcessResultMask before the
// comparison. Previously the post-processed mask was computed but never used,
// so the raw `arg` was compared instead.
// NOTE(review): assumes `expected_mask` is a CV_32FC1 mask in the same value
// range as the post-processed output; confirm against callers.
MATCHER_P2(SimilarToFloatMask, expected_mask, similarity_threshold, "") {
  // Reject on shape mismatch before doing any per-pixel work.
  if (arg.rows != expected_mask.rows || arg.cols != expected_mask.cols) {
    return false;
  }
  const cv::Mat actual_mask = PostProcessResultMask(arg);
  return CalculateSoftIOU(actual_mask, expected_mask) > similarity_threshold;
}
|
||||||
|
|
||||||
|
// Matches a mask (`arg`) that has the same dimensions as `expected_mask` and
// for which the fraction of positions i satisfying
//   arg.data[i] * magnification_factor == expected_mask.data[i]
// is at least `similarity_threshold`.
// NOTE(review): the raw `data` buffers are indexed over rows * cols entries,
// which presumably expects single-channel, continuous 8-bit masks; confirm
// against callers.
MATCHER_P3(SimilarToUint8Mask, expected_mask, similarity_threshold,
           magnification_factor, "") {
  const bool same_shape =
      arg.rows == expected_mask.rows && arg.cols == expected_mask.cols;
  if (!same_shape) {
    return false;
  }

  const int total_pixels = expected_mask.rows * expected_mask.cols;
  int matching_pixels = 0;
  for (int idx = 0; idx < total_pixels; ++idx) {
    if (arg.data[idx] * magnification_factor == expected_mask.data[idx]) {
      ++matching_pixels;
    }
  }

  const float matching_ratio =
      static_cast<float>(matching_pixels) / total_pixels;
  return matching_ratio >= similarity_threshold;
}
|
||||||
|
|
||||||
|
// Fixture for the tests that exercise InteractiveSegmenter::Create; it adds no
// state or behavior of its own on top of the base test class.
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
|
||||||
|
|
||||||
|
class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver {
|
||||||
|
public:
|
||||||
|
DeepLabOpResolverMissingOps() {
|
||||||
|
AddBuiltin(::tflite::BuiltinOperator_ADD,
|
||||||
|
::tflite::ops::builtin::Register_ADD());
|
||||||
|
}
|
||||||
|
|
||||||
|
DeepLabOpResolverMissingOps(const DeepLabOpResolverMissingOps& r) = delete;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Creating a segmenter with an op resolver that only registers ADD must fail
// with an internal error coming from interpreter construction.
TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
  auto options = std::make_unique<InteractiveSegmenterOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kPtmModel);
  // std::make_unique matches the rest of this file and is covered by the
  // <memory> include already present.
  options->base_options.op_resolver =
      std::make_unique<DeepLabOpResolverMissingOps>();

  auto segmenter_or = InteractiveSegmenter::Create(std::move(options));

  // TODO: Make MediaPipe InferenceCalculator report the detailed
  // interpreter errors (e.g., "Encountered unresolved custom op").
  EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInternal);
  EXPECT_THAT(segmenter_or.status().message(),
              HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk"));
}
|
||||||
|
|
||||||
|
// Creating a segmenter from default-constructed options (no model specified)
// must fail with kInvalidArgument and a runner-initialization payload.
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
  auto empty_options = std::make_unique<InteractiveSegmenterOptions>();
  absl::StatusOr<std::unique_ptr<InteractiveSegmenter>> segmenter_or =
      InteractiveSegmenter::Create(std::move(empty_options));

  const absl::Status& creation_status = segmenter_or.status();
  EXPECT_EQ(creation_status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(
      creation_status.message(),
      HasSubstr("ExternalFile must specify at least one of 'file_content', "
                "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
  EXPECT_THAT(creation_status.GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInitializationError))));
}
|
||||||
|
|
||||||
|
// Fixture for the tests that call InteractiveSegmenter::Segment on a decoded
// image; it adds no state or behavior of its own on top of the base test
// class.
class ImageModeTest : public tflite_shims::testing::Test {};
|
||||||
|
|
||||||
|
// Segmenting with a keypoint region of interest and CATEGORY_MASK output must
// succeed and yield exactly one mask.
TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
  const auto image_path = JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg);
  MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(image_path));

  // The user interaction is a single normalized keypoint.
  RegionOfInterest interaction_roi;
  interaction_roi.format = RegionOfInterest::KEYPOINT;
  interaction_roi.keypoint =
      components::containers::NormalizedKeypoint{0.25, 0.9};

  auto segmenter_options = std::make_unique<InteractiveSegmenterOptions>();
  segmenter_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kPtmModel);
  segmenter_options->output_type =
      InteractiveSegmenterOptions::OutputType::CATEGORY_MASK;

  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<InteractiveSegmenter> segmenter,
      InteractiveSegmenter::Create(std::move(segmenter_options)));
  MP_ASSERT_OK_AND_ASSIGN(auto category_masks,
                          segmenter->Segment(image, interaction_roi));
  EXPECT_EQ(category_masks.size(), 1);
}
|
||||||
|
|
||||||
|
// Segmenting with a keypoint region of interest and CONFIDENCE_MASK output
// must succeed and yield exactly two masks.
TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
  const auto image_path = JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg);
  MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(image_path));

  // The user interaction is a single normalized keypoint.
  RegionOfInterest interaction_roi;
  interaction_roi.format = RegionOfInterest::KEYPOINT;
  interaction_roi.keypoint =
      components::containers::NormalizedKeypoint{0.25, 0.9};

  auto segmenter_options = std::make_unique<InteractiveSegmenterOptions>();
  segmenter_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kPtmModel);
  segmenter_options->output_type =
      InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;

  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<InteractiveSegmenter> segmenter,
      InteractiveSegmenter::Create(std::move(segmenter_options)));
  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks,
                          segmenter->Segment(image, interaction_roi));
  EXPECT_EQ(confidence_masks.size(), 2);
}
|
||||||
|
|
||||||
|
// TODO: fix this unit test after image segmenter handled post
|
||||||
|
// processing correctly with rotated image.
|
||||||
|
TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
Image image,
|
||||||
|
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
|
||||||
|
RegionOfInterest interaction_roi;
|
||||||
|
interaction_roi.format = RegionOfInterest::KEYPOINT;
|
||||||
|
interaction_roi.keypoint =
|
||||||
|
components::containers::NormalizedKeypoint{0.25, 0.9};
|
||||||
|
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||||
|
options->output_type =
|
||||||
|
InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
|
||||||
|
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
|
||||||
|
InteractiveSegmenter::Create(std::move(options)));
|
||||||
|
ImageProcessingOptions image_processing_options;
|
||||||
|
image_processing_options.rotation_degrees = -90;
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto confidence_masks,
|
||||||
|
segmenter->Segment(image, interaction_roi, image_processing_options));
|
||||||
|
EXPECT_EQ(confidence_masks.size(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Passing ImageProcessingOptions that carry a region-of-interest rectangle
// must be rejected with kInvalidArgument and an image-processing payload.
TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
  const auto image_path = JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg);
  MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(image_path));

  // The user interaction is a single normalized keypoint.
  RegionOfInterest interaction_roi;
  interaction_roi.format = RegionOfInterest::KEYPOINT;
  interaction_roi.keypoint =
      components::containers::NormalizedKeypoint{0.25, 0.9};

  auto segmenter_options = std::make_unique<InteractiveSegmenterOptions>();
  segmenter_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kPtmModel);
  segmenter_options->output_type =
      InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;

  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<InteractiveSegmenter> segmenter,
      InteractiveSegmenter::Create(std::move(segmenter_options)));

  // The image-processing rectangle is what should trigger the failure.
  RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
  ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};

  auto results =
      segmenter->Segment(image, interaction_roi, image_processing_options);
  const absl::Status& segment_status = results.status();
  EXPECT_EQ(segment_status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(segment_status.message(),
              HasSubstr("This task doesn't support region-of-interest"));
  EXPECT_THAT(
      segment_status.GetPayload(kMediaPipeTasksPayload),
      Optional(absl::Cord(absl::StrCat(
          MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace interactive_segmenter
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
Loading…
Reference in New Issue
Block a user