Internal change

PiperOrigin-RevId: 522631851
This commit is contained in:
MediaPipe Team 2023-04-07 10:43:23 -07:00 committed by Copybara-Service
parent e3185e3df0
commit a1ce19f68e
5 changed files with 184 additions and 116 deletions

View File

@ -24,21 +24,26 @@ cc_library(
hdrs = ["interactive_segmenter.h"], hdrs = ["interactive_segmenter.h"],
deps = [ deps = [
":interactive_segmenter_graph", ":interactive_segmenter_graph",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/containers:keypoint", "//mediapipe/tasks/cc/components/containers:keypoint",
"//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/vision/core:base_vision_task_api", "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
"//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:image_processing_options",
"//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_result",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
"//mediapipe/util:color_cc_proto", "//mediapipe/util:color_cc_proto",
"//mediapipe/util:render_data_cc_proto", "//mediapipe/util:render_data_cc_proto",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
], ],
) )
@ -61,9 +66,12 @@ cc_library(
"//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
"//mediapipe/util:color_cc_proto", "//mediapipe/util:color_cc_proto",
"//mediapipe/util:graph_builder_utils",
"//mediapipe/util:label_map_cc_proto", "//mediapipe/util:label_map_cc_proto",
"//mediapipe/util:render_data_cc_proto", "//mediapipe/util:render_data_cc_proto",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
] + select({ ] + select({
"//mediapipe/gpu:disable_gpu": [], "//mediapipe/gpu:disable_gpu": [],

View File

@ -15,16 +15,24 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h" #include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h"
#include <memory>
#include <utility> #include <utility>
#include <vector>
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
#include "mediapipe/util/color.pb.h" #include "mediapipe/util/color.pb.h"
@ -36,23 +44,26 @@ namespace vision {
namespace interactive_segmenter { namespace interactive_segmenter {
namespace { namespace {
constexpr char kSegmentationStreamName[] = "segmented_mask_out"; constexpr char kConfidenceMasksStreamName[] = "confidence_masks";
constexpr char kCategoryMaskStreamName[] = "category_mask";
constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageOutStreamName[] = "image_out";
constexpr char kRoiStreamName[] = "roi_in"; constexpr char kRoiStreamName[] = "roi_in";
constexpr char kNormRectStreamName[] = "norm_rect_in"; constexpr char kNormRectStreamName[] = "norm_rect_in";
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"};
constexpr char kImageTag[] = "IMAGE"; constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"};
constexpr char kRoiTag[] = "ROI"; constexpr absl::string_view kImageTag{"IMAGE"};
constexpr char kNormRectTag[] = "NORM_RECT"; constexpr absl::string_view kRoiTag{"ROI"};
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
constexpr char kSubgraphTypeName[] = constexpr absl::string_view kSubgraphTypeName{
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"; "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};
using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::Image; using ::mediapipe::Image;
using ::mediapipe::NormalizedRect; using ::mediapipe::NormalizedRect;
using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterResult;
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
image_segmenter::proto::ImageSegmenterGraphOptions; image_segmenter::proto::ImageSegmenterGraphOptions;
@ -60,7 +71,8 @@ using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
// Creates a MediaPipe graph config that only contains a single subgraph node of // Creates a MediaPipe graph config that only contains a single subgraph node of
// "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph". // "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph".
CalculatorGraphConfig CreateGraphConfig( CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<ImageSegmenterGraphOptionsProto> options) { std::unique_ptr<ImageSegmenterGraphOptionsProto> options,
bool output_confidence_masks, bool output_category_mask) {
api2::builder::Graph graph; api2::builder::Graph graph;
auto& task_subgraph = graph.AddNode(kSubgraphTypeName); auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap( task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap(
@ -68,8 +80,15 @@ CalculatorGraphConfig CreateGraphConfig(
graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kImageTag).SetName(kImageInStreamName);
graph.In(kRoiTag).SetName(kRoiStreamName); graph.In(kRoiTag).SetName(kRoiStreamName);
graph.In(kNormRectTag).SetName(kNormRectStreamName); graph.In(kNormRectTag).SetName(kNormRectStreamName);
task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >> if (output_confidence_masks) {
graph.Out(kGroupedSegmentationTag); task_subgraph.Out(kConfidenceMasksTag)
.SetName(kConfidenceMasksStreamName) >>
graph.Out(kConfidenceMasksTag);
}
if (output_category_mask) {
task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
graph.Out(kCategoryMaskTag);
}
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
graph.Out(kImageTag); graph.Out(kImageTag);
graph.In(kImageTag) >> task_subgraph.In(kImageTag); graph.In(kImageTag) >> task_subgraph.In(kImageTag);
@ -86,16 +105,6 @@ ConvertImageSegmenterOptionsToProto(InteractiveSegmenterOptions* options) {
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>( auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
options_proto->mutable_base_options()->Swap(base_options_proto.get()); options_proto->mutable_base_options()->Swap(base_options_proto.get());
switch (options->output_type) {
case InteractiveSegmenterOptions::OutputType::CATEGORY_MASK:
options_proto->mutable_segmenter_options()->set_output_type(
SegmenterOptions::CATEGORY_MASK);
break;
case InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK:
options_proto->mutable_segmenter_options()->set_output_type(
SegmenterOptions::CONFIDENCE_MASK);
break;
}
return options_proto; return options_proto;
} }
@ -104,10 +113,10 @@ ConvertImageSegmenterOptionsToProto(InteractiveSegmenterOptions* options) {
absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) { absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
RenderData result; RenderData result;
switch (roi.format) { switch (roi.format) {
case RegionOfInterest::UNSPECIFIED: case RegionOfInterest::Format::kUnspecified:
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
"RegionOfInterest format not specified"); "RegionOfInterest format not specified");
case RegionOfInterest::KEYPOINT: case RegionOfInterest::Format::kKeyPoint:
RET_CHECK(roi.keypoint.has_value()); RET_CHECK(roi.keypoint.has_value());
auto* annotation = result.add_render_annotations(); auto* annotation = result.add_render_annotations();
annotation->mutable_color()->set_r(255); annotation->mutable_color()->set_r(255);
@ -125,15 +134,29 @@ absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
absl::StatusOr<std::unique_ptr<InteractiveSegmenter>> absl::StatusOr<std::unique_ptr<InteractiveSegmenter>>
InteractiveSegmenter::Create( InteractiveSegmenter::Create(
std::unique_ptr<InteractiveSegmenterOptions> options) { std::unique_ptr<InteractiveSegmenterOptions> options) {
auto options_proto = ConvertImageSegmenterOptionsToProto(options.get()); if (!options->output_confidence_masks && !options->output_category_mask) {
return core::VisionTaskApiFactory::Create<InteractiveSegmenter, return absl::InvalidArgumentError(
"At least one of `output_confidence_masks` and `output_category_mask` "
"must be set.");
}
std::unique_ptr<ImageSegmenterGraphOptionsProto> options_proto =
ConvertImageSegmenterOptionsToProto(options.get());
ASSIGN_OR_RETURN(
std::unique_ptr<InteractiveSegmenter> segmenter,
(core::VisionTaskApiFactory::Create<InteractiveSegmenter,
ImageSegmenterGraphOptionsProto>( ImageSegmenterGraphOptionsProto>(
CreateGraphConfig(std::move(options_proto)), CreateGraphConfig(std::move(options_proto),
std::move(options->base_options.op_resolver), core::RunningMode::IMAGE, options->output_confidence_masks,
/*packets_callback=*/nullptr); options->output_category_mask),
std::move(options->base_options.op_resolver),
core::RunningMode::IMAGE,
/*packets_callback=*/nullptr)));
segmenter->output_category_mask_ = options->output_category_mask;
segmenter->output_confidence_masks_ = options->output_confidence_masks;
return segmenter;
} }
absl::StatusOr<std::vector<Image>> InteractiveSegmenter::Segment( absl::StatusOr<ImageSegmenterResult> InteractiveSegmenter::Segment(
mediapipe::Image image, const RegionOfInterest& roi, mediapipe::Image image, const RegionOfInterest& roi,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
@ -154,7 +177,16 @@ absl::StatusOr<std::vector<Image>> InteractiveSegmenter::Segment(
mediapipe::MakePacket<RenderData>(std::move(roi_as_render_data))}, mediapipe::MakePacket<RenderData>(std::move(roi_as_render_data))},
{kNormRectStreamName, {kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))}})); MakePacket<NormalizedRect>(std::move(norm_rect))}}));
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>(); std::optional<std::vector<Image>> confidence_masks;
if (output_confidence_masks_) {
confidence_masks =
output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>();
}
std::optional<Image> category_mask;
if (output_category_mask_) {
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
}
return {{confidence_masks, category_mask}};
} }
} // namespace interactive_segmenter } // namespace interactive_segmenter

View File

@ -21,12 +21,14 @@ limitations under the License.
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/components/containers/keypoint.h" #include "mediapipe/tasks/cc/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
@ -39,30 +41,24 @@ struct InteractiveSegmenterOptions {
// file with metadata, accelerator options, op resolver, etc. // file with metadata, accelerator options, op resolver, etc.
tasks::core::BaseOptions base_options; tasks::core::BaseOptions base_options;
// The output type of segmentation results. // Whether to output confidence masks.
enum OutputType { bool output_confidence_masks = true;
// Gives a single output mask where each pixel represents the class which
// the pixel in the original image was predicted to belong to.
CATEGORY_MASK = 0,
// Gives a list of output masks where, for each mask, each pixel represents
// the prediction confidence, usually in the [0, 1] range.
CONFIDENCE_MASK = 1,
};
OutputType output_type = OutputType::CATEGORY_MASK; // Whether to output category mask.
bool output_category_mask = false;
}; };
// The Region-Of-Interest (ROI) to interact with. // The Region-Of-Interest (ROI) to interact with.
struct RegionOfInterest { struct RegionOfInterest {
enum Format { enum class Format {
UNSPECIFIED = 0, // Format not specified. kUnspecified = 0, // Format not specified.
KEYPOINT = 1, // Using keypoint to represent ROI. kKeyPoint = 1, // Using keypoint to represent ROI.
}; };
// Specifies the format used to specify the region-of-interest. Note that // Specifies the format used to specify the region-of-interest. Note that
// using `UNSPECIFIED` is invalid and will lead to an `InvalidArgument` status // using `UNSPECIFIED` is invalid and will lead to an `InvalidArgument` status
// being returned. // being returned.
Format format = Format::UNSPECIFIED; Format format = Format::kUnspecified;
// Represents the ROI in keypoint format, this should be non-nullopt if // Represents the ROI in keypoint format, this should be non-nullopt if
// `format` is `KEYPOINT`. // `format` is `KEYPOINT`.
@ -84,13 +80,11 @@ struct RegionOfInterest {
// - RGB inputs is supported (`channels` is required to be 3). // - RGB inputs is supported (`channels` is required to be 3).
// - if type is kTfLiteFloat32, NormalizationOptions are required to be // - if type is kTfLiteFloat32, NormalizationOptions are required to be
// attached to the metadata for input normalization. // attached to the metadata for input normalization.
// Output tensors: // Output ImageSegmenterResult:
// (kTfLiteUInt8/kTfLiteFloat32) // Provides optional confidence masks if `output_confidence_masks` is set
// - list of segmented masks. // true, and an optional category mask if `output_category_mask` is set
// - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. // true. At least one of `output_confidence_masks` and `output_category_mask`
// - if `output_type` is CONFIDENCE_MASK, float32 Image list of size // must be set to true.
// `channels`.
// - batch is always 1
class InteractiveSegmenter : tasks::vision::core::BaseVisionTaskApi { class InteractiveSegmenter : tasks::vision::core::BaseVisionTaskApi {
public: public:
using BaseVisionTaskApi::BaseVisionTaskApi; using BaseVisionTaskApi::BaseVisionTaskApi;
@ -114,18 +108,17 @@ class InteractiveSegmenter : tasks::vision::core::BaseVisionTaskApi {
// setting its 'rotation_degrees' field. Note that specifying a // setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported // region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned. // and will result in an invalid argument error being returned.
// absl::StatusOr<image_segmenter::ImageSegmenterResult> Segment(
// If the output_type is CATEGORY_MASK, the returned vector of images is
// per-category segmented image mask.
// If the output_type is CONFIDENCE_MASK, the returned vector of images
// contains only one confidence image mask.
absl::StatusOr<std::vector<mediapipe::Image>> Segment(
mediapipe::Image image, const RegionOfInterest& roi, mediapipe::Image image, const RegionOfInterest& roi,
std::optional<core::ImageProcessingOptions> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);
// Shuts down the InteractiveSegmenter when all works are done. // Shuts down the InteractiveSegmenter when all works are done.
absl::Status Close() { return runner_->Close(); } absl::Status Close() { return runner_->Close(); }
private:
bool output_confidence_masks_;
bool output_category_mask_;
}; };
} // namespace interactive_segmenter } // namespace interactive_segmenter

View File

@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <vector>
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h" #include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/builder.h"
@ -23,8 +26,9 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
#include "mediapipe/util/color.pb.h" #include "mediapipe/util/color.pb.h"
#include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/graph_builder_utils.h"
#include "mediapipe/util/render_data.pb.h" #include "mediapipe/util/render_data.pb.h"
namespace mediapipe { namespace mediapipe {
@ -42,16 +46,18 @@ using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
constexpr char kSegmentationTag[] = "SEGMENTATION"; constexpr absl::string_view kSegmentationTag{"SEGMENTATION"};
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; constexpr absl::string_view kGroupedSegmentationTag{"GROUPED_SEGMENTATION"};
constexpr char kImageTag[] = "IMAGE"; constexpr absl::string_view kConfidenceMaskTag{"CONFIDENCE_MASK"};
constexpr char kImageCpuTag[] = "IMAGE_CPU"; constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"};
constexpr char kImageGpuTag[] = "IMAGE_GPU"; constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"};
constexpr char kAlphaTag[] = "ALPHA"; constexpr absl::string_view kImageTag{"IMAGE"};
constexpr char kAlphaGpuTag[] = "ALPHA_GPU"; constexpr absl::string_view kImageCpuTag{"IMAGE_CPU"};
constexpr char kNormRectTag[] = "NORM_RECT"; constexpr absl::string_view kImageGpuTag{"IMAGE_GPU"};
constexpr char kRoiTag[] = "ROI"; constexpr absl::string_view kAlphaTag{"ALPHA"};
constexpr char kVideoTag[] = "VIDEO"; constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"};
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
constexpr absl::string_view kRoiTag{"ROI"};
// Updates the graph to return `roi` stream which has same dimension as // Updates the graph to return `roi` stream which has same dimension as
// `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is // `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is
@ -87,11 +93,10 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
} // namespace } // namespace
// An "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph" // An "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"
// performs semantic segmentation given user's region-of-interest. Two kinds of // performs semantic segmentation given the user's region-of-interest. The graph
// outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION. Users can // can output optional confidence masks if CONFIDENCE_MASKS is connected, and an
// retrieve segmented mask of only particular category/channel from // optional category mask if CATEGORY_MASK is connected. At least one of
// SEGMENTATION, and users can also get all segmented masks from // CONFIDENCE_MASK, CONFIDENCE_MASKS and CATEGORY_MASK must be connected.
// GROUPED_SEGMENTATION.
// - Accepts CPU input images and outputs segmented masks on CPU. // - Accepts CPU input images and outputs segmented masks on CPU.
// //
// Inputs: // Inputs:
@ -106,11 +111,13 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
// @Optional: rect covering the whole image is used if not specified. // @Optional: rect covering the whole image is used if not specified.
// //
// Outputs: // Outputs:
// SEGMENTATION - mediapipe::Image @Multiple // CONFIDENCE_MASK - mediapipe::Image @Multiple
// Segmented masks for individual category. Segmented mask of single // Confidence masks for individual category. Confidence mask of single
// category can be accessed by index based output stream. // category can be accessed by index based output stream.
// GROUPED_SEGMENTATION - std::vector<mediapipe::Image> // CONFIDENCE_MASKS - std::vector<mediapipe::Image> @Optional
// The output segmented masks grouped in a vector. // The output confidence masks grouped in a vector.
// CATEGORY_MASK - mediapipe::Image @Optional
// Optional Category mask.
// IMAGE - mediapipe::Image // IMAGE - mediapipe::Image
// The image that image segmenter runs on. // The image that image segmenter runs on.
// //
@ -129,9 +136,6 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
// file_name: "/path/to/model.tflite" // file_name: "/path/to/model.tflite"
// } // }
// } // }
// segmenter_options {
// output_type: CONFIDENCE_MASK
// }
// } // }
// } // }
// } // }
@ -176,10 +180,26 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph {
image_with_set_alpha >> image_segmenter.In(kImageTag); image_with_set_alpha >> image_segmenter.In(kImageTag);
norm_rect >> image_segmenter.In(kNormRectTag); norm_rect >> image_segmenter.In(kNormRectTag);
// TODO: remove deprecated output type support.
if (task_options.segmenter_options().has_output_type()) {
image_segmenter.Out(kSegmentationTag) >> image_segmenter.Out(kSegmentationTag) >>
graph[Output<Image>(kSegmentationTag)]; graph[Output<Image>(kSegmentationTag)];
image_segmenter.Out(kGroupedSegmentationTag) >> image_segmenter.Out(kGroupedSegmentationTag) >>
graph[Output<std::vector<Image>>(kGroupedSegmentationTag)]; graph[Output<std::vector<Image>>(kGroupedSegmentationTag)];
} else {
if (HasOutput(sc->OriginalNode(), kConfidenceMaskTag)) {
image_segmenter.Out(kConfidenceMaskTag) >>
graph[Output<Image>(kConfidenceMaskTag)];
}
if (HasOutput(sc->OriginalNode(), kConfidenceMasksTag)) {
image_segmenter.Out(kConfidenceMasksTag) >>
graph[Output<Image>(kConfidenceMasksTag)];
}
if (HasOutput(sc->OriginalNode(), kCategoryMaskTag)) {
image_segmenter.Out(kCategoryMaskTag) >>
graph[Output<Image>(kCategoryMaskTag)];
}
}
image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)]; image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)];
return graph.GetConfig(); return graph.GetConfig();

View File

@ -17,8 +17,11 @@ limitations under the License.
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility>
#include "absl/flags/flag.h" #include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
@ -39,6 +42,7 @@ limitations under the License.
#include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h" #include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/mutable_op_resolver.h" #include "tensorflow/lite/mutable_op_resolver.h"
#include "testing/base/public/gmock.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
@ -53,13 +57,16 @@ using ::mediapipe::tasks::components::containers::RectF;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::Optional; using ::testing::Optional;
using ::testing::SizeIs;
using ::testing::status::StatusIs;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; constexpr absl::string_view kTestDataDirectory{
constexpr char kPtmModel[] = "ptm_512_hdt_ptm_woid.tflite"; "/mediapipe/tasks/testdata/vision/"};
constexpr char kCatsAndDogsJpg[] = "cats_and_dogs.jpg"; constexpr absl::string_view kPtmModel{"ptm_512_hdt_ptm_woid.tflite"};
constexpr absl::string_view kCatsAndDogsJpg{"cats_and_dogs.jpg"};
// Golden mask for the dogs in cats_and_dogs.jpg. // Golden mask for the dogs in cats_and_dogs.jpg.
constexpr char kCatsAndDogsMaskDog1[] = "cats_and_dogs_mask_dog1.png"; constexpr absl::string_view kCatsAndDogsMaskDog1{"cats_and_dogs_mask_dog1.png"};
constexpr char kCatsAndDogsMaskDog2[] = "cats_and_dogs_mask_dog2.png"; constexpr absl::string_view kCatsAndDogsMaskDog2{"cats_and_dogs_mask_dog2.png"};
constexpr float kGoldenMaskSimilarity = 0.97; constexpr float kGoldenMaskSimilarity = 0.97;
@ -135,35 +142,45 @@ TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
JoinPath("./", kTestDataDirectory, kPtmModel); JoinPath("./", kTestDataDirectory, kPtmModel);
options->base_options.op_resolver = options->base_options.op_resolver =
absl::make_unique<DeepLabOpResolverMissingOps>(); absl::make_unique<DeepLabOpResolverMissingOps>();
auto segmenter_or = InteractiveSegmenter::Create(std::move(options)); auto segmenter = InteractiveSegmenter::Create(std::move(options));
// TODO: Make MediaPipe InferenceCalculator report the detailed // TODO: Make MediaPipe InferenceCalculator report the detailed
// interpreter errors (e.g., "Encountered unresolved custom op"). // interpreter errors (e.g., "Encountered unresolved custom op").
EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInternal); EXPECT_EQ(segmenter.status().code(), absl::StatusCode::kInternal);
EXPECT_THAT( EXPECT_THAT(
segmenter_or.status().message(), segmenter.status().message(),
testing::HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk")); testing::HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk"));
} }
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
absl::StatusOr<std::unique_ptr<InteractiveSegmenter>> segmenter_or = absl::StatusOr<std::unique_ptr<InteractiveSegmenter>> segmenter =
InteractiveSegmenter::Create( InteractiveSegmenter::Create(
std::make_unique<InteractiveSegmenterOptions>()); std::make_unique<InteractiveSegmenterOptions>());
EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_EQ(segmenter.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT( EXPECT_THAT(
segmenter_or.status().message(), segmenter.status().message(),
HasSubstr("ExternalFile must specify at least one of 'file_content', " HasSubstr("ExternalFile must specify at least one of 'file_content', "
"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.")); "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
EXPECT_THAT(segmenter_or.status().GetPayload(kMediaPipeTasksPayload), EXPECT_THAT(segmenter.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat( Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kRunnerInitializationError)))); MediaPipeTasksStatus::kRunnerInitializationError))));
} }
TEST_F(CreateFromOptionsTest, FailsWithNeitherOutputSet) {
auto options = std::make_unique<InteractiveSegmenterOptions>();
options->output_category_mask = false;
options->output_confidence_masks = false;
EXPECT_THAT(InteractiveSegmenter::Create(std::move(options)),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("At least one of")));
}
struct InteractiveSegmenterTestParams { struct InteractiveSegmenterTestParams {
std::string test_name; std::string test_name;
RegionOfInterest::Format format; RegionOfInterest::Format format;
NormalizedKeypoint roi; NormalizedKeypoint roi;
std::string golden_mask_file; absl::string_view golden_mask_file;
float similarity_threshold; float similarity_threshold;
}; };
@ -181,16 +198,18 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
auto options = std::make_unique<InteractiveSegmenterOptions>(); auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel); JoinPath("./", kTestDataDirectory, kPtmModel);
options->output_type = InteractiveSegmenterOptions::OutputType::CATEGORY_MASK; options->output_confidence_masks = false;
options->output_category_mask = true;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
InteractiveSegmenter::Create(std::move(options))); InteractiveSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto category_masks, MP_ASSERT_OK_AND_ASSIGN(auto result,
segmenter->Segment(image, interaction_roi)); segmenter->Segment(image, interaction_roi));
EXPECT_EQ(category_masks.size(), 1); EXPECT_TRUE(result.category_mask.has_value());
EXPECT_FALSE(result.confidence_masks.has_value());
cv::Mat actual_mask = mediapipe::formats::MatView( cv::Mat actual_mask = mediapipe::formats::MatView(
category_masks[0].GetImageFrameSharedPtr().get()); result.category_mask->GetImageFrameSharedPtr().get());
cv::Mat expected_mask = cv::Mat expected_mask =
cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file), cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file),
@ -211,14 +230,13 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
auto options = std::make_unique<InteractiveSegmenterOptions>(); auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel); JoinPath("./", kTestDataDirectory, kPtmModel);
options->output_type =
InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
InteractiveSegmenter::Create(std::move(options))); InteractiveSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, MP_ASSERT_OK_AND_ASSIGN(auto result,
segmenter->Segment(image, interaction_roi)); segmenter->Segment(image, interaction_roi));
EXPECT_EQ(confidence_masks.size(), 2); EXPECT_FALSE(result.category_mask.has_value());
EXPECT_THAT(result.confidence_masks, Optional(SizeIs(2)));
cv::Mat expected_mask = cv::Mat expected_mask =
cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file), cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file),
@ -227,7 +245,7 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
cv::Mat actual_mask = mediapipe::formats::MatView( cv::Mat actual_mask = mediapipe::formats::MatView(
confidence_masks[1].GetImageFrameSharedPtr().get()); result.confidence_masks->at(1).GetImageFrameSharedPtr().get());
EXPECT_THAT(actual_mask, SimilarToFloatMask(expected_mask_float, EXPECT_THAT(actual_mask, SimilarToFloatMask(expected_mask_float,
params.similarity_threshold)); params.similarity_threshold));
} }
@ -235,9 +253,9 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi, SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi,
::testing::ValuesIn<InteractiveSegmenterTestParams>( ::testing::ValuesIn<InteractiveSegmenterTestParams>(
{{"PointToDog1", RegionOfInterest::KEYPOINT, {{"PointToDog1", RegionOfInterest::Format::kKeyPoint,
NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f}, NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
{"PointToDog2", RegionOfInterest::KEYPOINT, {"PointToDog2", RegionOfInterest::Format::kKeyPoint,
NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2, NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
kGoldenMaskSimilarity}}), kGoldenMaskSimilarity}}),
[](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>& [](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
@ -252,22 +270,21 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
Image image, Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
RegionOfInterest interaction_roi; RegionOfInterest interaction_roi;
interaction_roi.format = RegionOfInterest::KEYPOINT; interaction_roi.format = RegionOfInterest::Format::kKeyPoint;
interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66}; interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66};
auto options = std::make_unique<InteractiveSegmenterOptions>(); auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel); JoinPath("./", kTestDataDirectory, kPtmModel);
options->output_type =
InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
InteractiveSegmenter::Create(std::move(options))); InteractiveSegmenter::Create(std::move(options)));
ImageProcessingOptions image_processing_options; ImageProcessingOptions image_processing_options;
image_processing_options.rotation_degrees = -90; image_processing_options.rotation_degrees = -90;
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto confidence_masks, auto result,
segmenter->Segment(image, interaction_roi, image_processing_options)); segmenter->Segment(image, interaction_roi, image_processing_options));
EXPECT_EQ(confidence_masks.size(), 2); EXPECT_FALSE(result.category_mask.has_value());
EXPECT_EQ(result.confidence_masks->size(), 2);
} }
TEST_F(ImageModeTest, FailsWithRegionOfInterest) { TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
@ -275,13 +292,11 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
Image image, Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
RegionOfInterest interaction_roi; RegionOfInterest interaction_roi;
interaction_roi.format = RegionOfInterest::KEYPOINT; interaction_roi.format = RegionOfInterest::Format::kKeyPoint;
interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66}; interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66};
auto options = std::make_unique<InteractiveSegmenterOptions>(); auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel); JoinPath("./", kTestDataDirectory, kPtmModel);
options->output_type =
InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
InteractiveSegmenter::Create(std::move(options))); InteractiveSegmenter::Create(std::move(options)));