mediapipe/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc
MediaPipe Team 09662749ea Support scribble input for Interactive Segmenter
PiperOrigin-RevId: 529156049
2023-05-03 12:01:10 -07:00

211 lines
8.6 KiB
C++

/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h"
#include <memory>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace interactive_segmenter {
namespace {
constexpr char kConfidenceMasksStreamName[] = "confidence_masks";
constexpr char kCategoryMaskStreamName[] = "category_mask";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kRoiStreamName[] = "roi_in";
constexpr char kNormRectStreamName[] = "norm_rect_in";
constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"};
constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"};
constexpr absl::string_view kImageTag{"IMAGE"};
constexpr absl::string_view kRoiTag{"ROI"};
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
constexpr absl::string_view kSubgraphTypeName{
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};
using components::containers::NormalizedKeypoint;
using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::Image;
using ::mediapipe::NormalizedRect;
using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterResult;
using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
image_segmenter::proto::ImageSegmenterGraphOptions;
// Creates a MediaPipe graph config that only contains a single subgraph node of
// "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph".
CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<ImageSegmenterGraphOptionsProto> options,
bool output_confidence_masks, bool output_category_mask) {
api2::builder::Graph graph;
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap(
options.get());
graph.In(kImageTag).SetName(kImageInStreamName);
graph.In(kRoiTag).SetName(kRoiStreamName);
graph.In(kNormRectTag).SetName(kNormRectStreamName);
if (output_confidence_masks) {
task_subgraph.Out(kConfidenceMasksTag)
.SetName(kConfidenceMasksStreamName) >>
graph.Out(kConfidenceMasksTag);
}
if (output_category_mask) {
task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
graph.Out(kCategoryMaskTag);
}
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
graph.Out(kImageTag);
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
graph.In(kRoiTag) >> task_subgraph.In(kRoiTag);
graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
return graph.GetConfig();
}
// Converts the user-facing InteractiveSegmenterOptions struct to the internal
// ImageSegmenterOptions proto.
std::unique_ptr<ImageSegmenterGraphOptionsProto>
ConvertImageSegmenterOptionsToProto(InteractiveSegmenterOptions* options) {
auto options_proto = std::make_unique<ImageSegmenterGraphOptionsProto>();
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
options_proto->mutable_base_options()->Swap(base_options_proto.get());
return options_proto;
}
// Converts the user-facing RegionOfInterest struct to the RenderData proto that
// is used in subgraph.
absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
RenderData result;
switch (roi.format) {
case RegionOfInterest::Format::kUnspecified:
return absl::InvalidArgumentError(
"RegionOfInterest format not specified");
case RegionOfInterest::Format::kKeyPoint: {
RET_CHECK(roi.keypoint.has_value());
auto* annotation = result.add_render_annotations();
annotation->mutable_color()->set_r(255);
auto* point = annotation->mutable_point();
point->set_normalized(true);
point->set_x(roi.keypoint->x);
point->set_y(roi.keypoint->y);
return result;
}
case RegionOfInterest::Format::kScribble: {
RET_CHECK(roi.scribble.has_value());
auto* annotation = result.add_render_annotations();
annotation->mutable_color()->set_r(255);
for (const NormalizedKeypoint& keypoint : *(roi.scribble)) {
auto* point = annotation->mutable_scribble()->add_point();
point->set_normalized(true);
point->set_x(keypoint.x);
point->set_y(keypoint.y);
}
return result;
}
}
return absl::UnimplementedError("Unrecognized format");
}
} // namespace
absl::StatusOr<std::unique_ptr<InteractiveSegmenter>>
InteractiveSegmenter::Create(
std::unique_ptr<InteractiveSegmenterOptions> options) {
if (!options->output_confidence_masks && !options->output_category_mask) {
return absl::InvalidArgumentError(
"At least one of `output_confidence_masks` and `output_category_mask` "
"must be set.");
}
std::unique_ptr<ImageSegmenterGraphOptionsProto> options_proto =
ConvertImageSegmenterOptionsToProto(options.get());
ASSIGN_OR_RETURN(
std::unique_ptr<InteractiveSegmenter> segmenter,
(core::VisionTaskApiFactory::Create<InteractiveSegmenter,
ImageSegmenterGraphOptionsProto>(
CreateGraphConfig(std::move(options_proto),
options->output_confidence_masks,
options->output_category_mask),
std::move(options->base_options.op_resolver),
core::RunningMode::IMAGE,
/*packets_callback=*/nullptr)));
segmenter->output_category_mask_ = options->output_category_mask;
segmenter->output_confidence_masks_ = options->output_confidence_masks;
return segmenter;
}
absl::StatusOr<ImageSegmenterResult> InteractiveSegmenter::Segment(
mediapipe::Image image, const RegionOfInterest& roi,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
ASSIGN_OR_RETURN(NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, image,
/*roi_allowed=*/false));
ASSIGN_OR_RETURN(RenderData roi_as_render_data, ConvertRoiToRenderData(roi));
ASSIGN_OR_RETURN(
auto output_packets,
ProcessImageData(
{{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))},
{kRoiStreamName,
mediapipe::MakePacket<RenderData>(std::move(roi_as_render_data))},
{kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))}}));
std::optional<std::vector<Image>> confidence_masks;
if (output_confidence_masks_) {
confidence_masks =
output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>();
}
std::optional<Image> category_mask;
if (output_category_mask_) {
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
}
return {{confidence_masks, category_mask}};
}
} // namespace interactive_segmenter
} // namespace vision
} // namespace tasks
} // namespace mediapipe