Internal change
PiperOrigin-RevId: 522631851
This commit is contained in:
parent
e3185e3df0
commit
a1ce19f68e
|
@ -24,21 +24,26 @@ cc_library(
|
||||||
hdrs = ["interactive_segmenter.h"],
|
hdrs = ["interactive_segmenter.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":interactive_segmenter_graph",
|
":interactive_segmenter_graph",
|
||||||
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
"//mediapipe/framework/api2:builder",
|
"//mediapipe/framework/api2:builder",
|
||||||
"//mediapipe/framework/formats:image",
|
"//mediapipe/framework/formats:image",
|
||||||
"//mediapipe/framework/formats:rect_cc_proto",
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc:common",
|
||||||
"//mediapipe/tasks/cc/components/containers:keypoint",
|
"//mediapipe/tasks/cc/components/containers:keypoint",
|
||||||
"//mediapipe/tasks/cc/core:base_options",
|
"//mediapipe/tasks/cc/core:base_options",
|
||||||
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
|
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
|
||||||
"//mediapipe/tasks/cc/vision/core:image_processing_options",
|
"//mediapipe/tasks/cc/vision/core:image_processing_options",
|
||||||
"//mediapipe/tasks/cc/vision/core:running_mode",
|
"//mediapipe/tasks/cc/vision/core:running_mode",
|
||||||
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
||||||
|
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_result",
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
|
||||||
"//mediapipe/util:color_cc_proto",
|
"//mediapipe/util:color_cc_proto",
|
||||||
"//mediapipe/util:render_data_cc_proto",
|
"//mediapipe/util:render_data_cc_proto",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/types:optional",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -61,9 +66,12 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
|
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
|
||||||
"//mediapipe/util:color_cc_proto",
|
"//mediapipe/util:color_cc_proto",
|
||||||
|
"//mediapipe/util:graph_builder_utils",
|
||||||
"//mediapipe/util:label_map_cc_proto",
|
"//mediapipe/util:label_map_cc_proto",
|
||||||
"//mediapipe/util:render_data_cc_proto",
|
"//mediapipe/util:render_data_cc_proto",
|
||||||
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
] + select({
|
] + select({
|
||||||
"//mediapipe/gpu:disable_gpu": [],
|
"//mediapipe/gpu:disable_gpu": [],
|
||||||
|
|
|
@ -15,16 +15,24 @@ limitations under the License.
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h"
|
#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "absl/types/optional.h"
|
||||||
#include "mediapipe/framework/api2/builder.h"
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
|
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"
|
||||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
|
||||||
#include "mediapipe/util/color.pb.h"
|
#include "mediapipe/util/color.pb.h"
|
||||||
|
@ -36,23 +44,26 @@ namespace vision {
|
||||||
namespace interactive_segmenter {
|
namespace interactive_segmenter {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
constexpr char kSegmentationStreamName[] = "segmented_mask_out";
|
constexpr char kConfidenceMasksStreamName[] = "confidence_masks";
|
||||||
|
constexpr char kCategoryMaskStreamName[] = "category_mask";
|
||||||
constexpr char kImageInStreamName[] = "image_in";
|
constexpr char kImageInStreamName[] = "image_in";
|
||||||
constexpr char kImageOutStreamName[] = "image_out";
|
constexpr char kImageOutStreamName[] = "image_out";
|
||||||
constexpr char kRoiStreamName[] = "roi_in";
|
constexpr char kRoiStreamName[] = "roi_in";
|
||||||
constexpr char kNormRectStreamName[] = "norm_rect_in";
|
constexpr char kNormRectStreamName[] = "norm_rect_in";
|
||||||
|
|
||||||
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
|
constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"};
|
||||||
constexpr char kImageTag[] = "IMAGE";
|
constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"};
|
||||||
constexpr char kRoiTag[] = "ROI";
|
constexpr absl::string_view kImageTag{"IMAGE"};
|
||||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
constexpr absl::string_view kRoiTag{"ROI"};
|
||||||
|
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
|
||||||
|
|
||||||
constexpr char kSubgraphTypeName[] =
|
constexpr absl::string_view kSubgraphTypeName{
|
||||||
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
|
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};
|
||||||
|
|
||||||
using ::mediapipe::CalculatorGraphConfig;
|
using ::mediapipe::CalculatorGraphConfig;
|
||||||
using ::mediapipe::Image;
|
using ::mediapipe::Image;
|
||||||
using ::mediapipe::NormalizedRect;
|
using ::mediapipe::NormalizedRect;
|
||||||
|
using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterResult;
|
||||||
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
|
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
|
||||||
using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
|
using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
|
||||||
image_segmenter::proto::ImageSegmenterGraphOptions;
|
image_segmenter::proto::ImageSegmenterGraphOptions;
|
||||||
|
@ -60,7 +71,8 @@ using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
|
||||||
// Creates a MediaPipe graph config that only contains a single subgraph node of
|
// Creates a MediaPipe graph config that only contains a single subgraph node of
|
||||||
// "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph".
|
// "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph".
|
||||||
CalculatorGraphConfig CreateGraphConfig(
|
CalculatorGraphConfig CreateGraphConfig(
|
||||||
std::unique_ptr<ImageSegmenterGraphOptionsProto> options) {
|
std::unique_ptr<ImageSegmenterGraphOptionsProto> options,
|
||||||
|
bool output_confidence_masks, bool output_category_mask) {
|
||||||
api2::builder::Graph graph;
|
api2::builder::Graph graph;
|
||||||
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
|
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
|
||||||
task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap(
|
task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap(
|
||||||
|
@ -68,8 +80,15 @@ CalculatorGraphConfig CreateGraphConfig(
|
||||||
graph.In(kImageTag).SetName(kImageInStreamName);
|
graph.In(kImageTag).SetName(kImageInStreamName);
|
||||||
graph.In(kRoiTag).SetName(kRoiStreamName);
|
graph.In(kRoiTag).SetName(kRoiStreamName);
|
||||||
graph.In(kNormRectTag).SetName(kNormRectStreamName);
|
graph.In(kNormRectTag).SetName(kNormRectStreamName);
|
||||||
task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
|
if (output_confidence_masks) {
|
||||||
graph.Out(kGroupedSegmentationTag);
|
task_subgraph.Out(kConfidenceMasksTag)
|
||||||
|
.SetName(kConfidenceMasksStreamName) >>
|
||||||
|
graph.Out(kConfidenceMasksTag);
|
||||||
|
}
|
||||||
|
if (output_category_mask) {
|
||||||
|
task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
|
||||||
|
graph.Out(kCategoryMaskTag);
|
||||||
|
}
|
||||||
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
|
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
|
||||||
graph.Out(kImageTag);
|
graph.Out(kImageTag);
|
||||||
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
|
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
|
||||||
|
@ -86,16 +105,6 @@ ConvertImageSegmenterOptionsToProto(InteractiveSegmenterOptions* options) {
|
||||||
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
||||||
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
||||||
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
||||||
switch (options->output_type) {
|
|
||||||
case InteractiveSegmenterOptions::OutputType::CATEGORY_MASK:
|
|
||||||
options_proto->mutable_segmenter_options()->set_output_type(
|
|
||||||
SegmenterOptions::CATEGORY_MASK);
|
|
||||||
break;
|
|
||||||
case InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK:
|
|
||||||
options_proto->mutable_segmenter_options()->set_output_type(
|
|
||||||
SegmenterOptions::CONFIDENCE_MASK);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
return options_proto;
|
return options_proto;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -104,10 +113,10 @@ ConvertImageSegmenterOptionsToProto(InteractiveSegmenterOptions* options) {
|
||||||
absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
|
absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
|
||||||
RenderData result;
|
RenderData result;
|
||||||
switch (roi.format) {
|
switch (roi.format) {
|
||||||
case RegionOfInterest::UNSPECIFIED:
|
case RegionOfInterest::Format::kUnspecified:
|
||||||
return absl::InvalidArgumentError(
|
return absl::InvalidArgumentError(
|
||||||
"RegionOfInterest format not specified");
|
"RegionOfInterest format not specified");
|
||||||
case RegionOfInterest::KEYPOINT:
|
case RegionOfInterest::Format::kKeyPoint:
|
||||||
RET_CHECK(roi.keypoint.has_value());
|
RET_CHECK(roi.keypoint.has_value());
|
||||||
auto* annotation = result.add_render_annotations();
|
auto* annotation = result.add_render_annotations();
|
||||||
annotation->mutable_color()->set_r(255);
|
annotation->mutable_color()->set_r(255);
|
||||||
|
@ -125,15 +134,29 @@ absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
|
||||||
absl::StatusOr<std::unique_ptr<InteractiveSegmenter>>
|
absl::StatusOr<std::unique_ptr<InteractiveSegmenter>>
|
||||||
InteractiveSegmenter::Create(
|
InteractiveSegmenter::Create(
|
||||||
std::unique_ptr<InteractiveSegmenterOptions> options) {
|
std::unique_ptr<InteractiveSegmenterOptions> options) {
|
||||||
auto options_proto = ConvertImageSegmenterOptionsToProto(options.get());
|
if (!options->output_confidence_masks && !options->output_category_mask) {
|
||||||
return core::VisionTaskApiFactory::Create<InteractiveSegmenter,
|
return absl::InvalidArgumentError(
|
||||||
|
"At least one of `output_confidence_masks` and `output_category_mask` "
|
||||||
|
"must be set.");
|
||||||
|
}
|
||||||
|
std::unique_ptr<ImageSegmenterGraphOptionsProto> options_proto =
|
||||||
|
ConvertImageSegmenterOptionsToProto(options.get());
|
||||||
|
ASSIGN_OR_RETURN(
|
||||||
|
std::unique_ptr<InteractiveSegmenter> segmenter,
|
||||||
|
(core::VisionTaskApiFactory::Create<InteractiveSegmenter,
|
||||||
ImageSegmenterGraphOptionsProto>(
|
ImageSegmenterGraphOptionsProto>(
|
||||||
CreateGraphConfig(std::move(options_proto)),
|
CreateGraphConfig(std::move(options_proto),
|
||||||
std::move(options->base_options.op_resolver), core::RunningMode::IMAGE,
|
options->output_confidence_masks,
|
||||||
/*packets_callback=*/nullptr);
|
options->output_category_mask),
|
||||||
|
std::move(options->base_options.op_resolver),
|
||||||
|
core::RunningMode::IMAGE,
|
||||||
|
/*packets_callback=*/nullptr)));
|
||||||
|
segmenter->output_category_mask_ = options->output_category_mask;
|
||||||
|
segmenter->output_confidence_masks_ = options->output_confidence_masks;
|
||||||
|
return segmenter;
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<std::vector<Image>> InteractiveSegmenter::Segment(
|
absl::StatusOr<ImageSegmenterResult> InteractiveSegmenter::Segment(
|
||||||
mediapipe::Image image, const RegionOfInterest& roi,
|
mediapipe::Image image, const RegionOfInterest& roi,
|
||||||
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||||
if (image.UsesGpu()) {
|
if (image.UsesGpu()) {
|
||||||
|
@ -154,7 +177,16 @@ absl::StatusOr<std::vector<Image>> InteractiveSegmenter::Segment(
|
||||||
mediapipe::MakePacket<RenderData>(std::move(roi_as_render_data))},
|
mediapipe::MakePacket<RenderData>(std::move(roi_as_render_data))},
|
||||||
{kNormRectStreamName,
|
{kNormRectStreamName,
|
||||||
MakePacket<NormalizedRect>(std::move(norm_rect))}}));
|
MakePacket<NormalizedRect>(std::move(norm_rect))}}));
|
||||||
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
|
std::optional<std::vector<Image>> confidence_masks;
|
||||||
|
if (output_confidence_masks_) {
|
||||||
|
confidence_masks =
|
||||||
|
output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>();
|
||||||
|
}
|
||||||
|
std::optional<Image> category_mask;
|
||||||
|
if (output_category_mask_) {
|
||||||
|
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
|
||||||
|
}
|
||||||
|
return {{confidence_masks, category_mask}};
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace interactive_segmenter
|
} // namespace interactive_segmenter
|
||||||
|
|
|
@ -21,12 +21,14 @@ limitations under the License.
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
|
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
|
||||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -39,30 +41,24 @@ struct InteractiveSegmenterOptions {
|
||||||
// file with metadata, accelerator options, op resolver, etc.
|
// file with metadata, accelerator options, op resolver, etc.
|
||||||
tasks::core::BaseOptions base_options;
|
tasks::core::BaseOptions base_options;
|
||||||
|
|
||||||
// The output type of segmentation results.
|
// Whether to output confidence masks.
|
||||||
enum OutputType {
|
bool output_confidence_masks = true;
|
||||||
// Gives a single output mask where each pixel represents the class which
|
|
||||||
// the pixel in the original image was predicted to belong to.
|
|
||||||
CATEGORY_MASK = 0,
|
|
||||||
// Gives a list of output masks where, for each mask, each pixel represents
|
|
||||||
// the prediction confidence, usually in the [0, 1] range.
|
|
||||||
CONFIDENCE_MASK = 1,
|
|
||||||
};
|
|
||||||
|
|
||||||
OutputType output_type = OutputType::CATEGORY_MASK;
|
// Whether to output category mask.
|
||||||
|
bool output_category_mask = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
// The Region-Of-Interest (ROI) to interact with.
|
// The Region-Of-Interest (ROI) to interact with.
|
||||||
struct RegionOfInterest {
|
struct RegionOfInterest {
|
||||||
enum Format {
|
enum class Format {
|
||||||
UNSPECIFIED = 0, // Format not specified.
|
kUnspecified = 0, // Format not specified.
|
||||||
KEYPOINT = 1, // Using keypoint to represent ROI.
|
kKeyPoint = 1, // Using keypoint to represent ROI.
|
||||||
};
|
};
|
||||||
|
|
||||||
// Specifies the format used to specify the region-of-interest. Note that
|
// Specifies the format used to specify the region-of-interest. Note that
|
||||||
// using `UNSPECIFIED` is invalid and will lead to an `InvalidArgument` status
|
// using `kUnspecified` is invalid and will lead to an `InvalidArgument` status
|
||||||
// being returned.
|
// being returned.
|
||||||
Format format = Format::UNSPECIFIED;
|
Format format = Format::kUnspecified;
|
||||||
|
|
||||||
// Represents the ROI in keypoint format, this should be non-nullopt if
|
// Represents the ROI in keypoint format, this should be non-nullopt if
|
||||||
// `format` is `KEYPOINT`.
|
// `format` is `kKeyPoint`.
|
||||||
|
@ -84,13 +80,11 @@ struct RegionOfInterest {
|
||||||
// - RGB inputs is supported (`channels` is required to be 3).
|
// - RGB inputs are supported (`channels` is required to be 3).
|
||||||
// - if type is kTfLiteFloat32, NormalizationOptions are required to be
|
// - if type is kTfLiteFloat32, NormalizationOptions are required to be
|
||||||
// attached to the metadata for input normalization.
|
// attached to the metadata for input normalization.
|
||||||
// Output tensors:
|
// Output ImageSegmenterResult:
|
||||||
// (kTfLiteUInt8/kTfLiteFloat32)
|
// Provides optional confidence masks if `output_confidence_masks` is set
|
||||||
// - list of segmented masks.
|
// true, and an optional category mask if `output_category_mask` is set
|
||||||
// - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1.
|
// true. At least one of `output_confidence_masks` and `output_category_mask`
|
||||||
// - if `output_type` is CONFIDENCE_MASK, float32 Image list of size
|
// must be set to true.
|
||||||
// `channels`.
|
|
||||||
// - batch is always 1
|
|
||||||
class InteractiveSegmenter : tasks::vision::core::BaseVisionTaskApi {
|
class InteractiveSegmenter : tasks::vision::core::BaseVisionTaskApi {
|
||||||
public:
|
public:
|
||||||
using BaseVisionTaskApi::BaseVisionTaskApi;
|
using BaseVisionTaskApi::BaseVisionTaskApi;
|
||||||
|
@ -114,18 +108,17 @@ class InteractiveSegmenter : tasks::vision::core::BaseVisionTaskApi {
|
||||||
// setting its 'rotation_degrees' field. Note that specifying a
|
// setting its 'rotation_degrees' field. Note that specifying a
|
||||||
// region-of-interest using the 'region_of_interest' field is NOT supported
|
// region-of-interest using the 'region_of_interest' field is NOT supported
|
||||||
// and will result in an invalid argument error being returned.
|
// and will result in an invalid argument error being returned.
|
||||||
//
|
absl::StatusOr<image_segmenter::ImageSegmenterResult> Segment(
|
||||||
// If the output_type is CATEGORY_MASK, the returned vector of images is
|
|
||||||
// per-category segmented image mask.
|
|
||||||
// If the output_type is CONFIDENCE_MASK, the returned vector of images
|
|
||||||
// contains only one confidence image mask.
|
|
||||||
absl::StatusOr<std::vector<mediapipe::Image>> Segment(
|
|
||||||
mediapipe::Image image, const RegionOfInterest& roi,
|
mediapipe::Image image, const RegionOfInterest& roi,
|
||||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||||
std::nullopt);
|
std::nullopt);
|
||||||
|
|
||||||
// Shuts down the InteractiveSegmenter when all works are done.
|
// Shuts down the InteractiveSegmenter when all work is done.
|
||||||
absl::Status Close() { return runner_->Close(); }
|
absl::Status Close() { return runner_->Close(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool output_confidence_masks_;
|
||||||
|
bool output_category_mask_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace interactive_segmenter
|
} // namespace interactive_segmenter
|
||||||
|
|
|
@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
|
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
|
||||||
#include "mediapipe/framework/api2/builder.h"
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
@ -23,8 +26,9 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
|
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
|
||||||
#include "mediapipe/util/color.pb.h"
|
#include "mediapipe/util/color.pb.h"
|
||||||
#include "mediapipe/util/label_map.pb.h"
|
#include "mediapipe/util/graph_builder_utils.h"
|
||||||
#include "mediapipe/util/render_data.pb.h"
|
#include "mediapipe/util/render_data.pb.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
@ -42,16 +46,18 @@ using ::mediapipe::api2::Output;
|
||||||
using ::mediapipe::api2::builder::Graph;
|
using ::mediapipe::api2::builder::Graph;
|
||||||
using ::mediapipe::api2::builder::Source;
|
using ::mediapipe::api2::builder::Source;
|
||||||
|
|
||||||
constexpr char kSegmentationTag[] = "SEGMENTATION";
|
constexpr absl::string_view kSegmentationTag{"SEGMENTATION"};
|
||||||
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
|
constexpr absl::string_view kGroupedSegmentationTag{"GROUPED_SEGMENTATION"};
|
||||||
constexpr char kImageTag[] = "IMAGE";
|
constexpr absl::string_view kConfidenceMaskTag{"CONFIDENCE_MASK"};
|
||||||
constexpr char kImageCpuTag[] = "IMAGE_CPU";
|
constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"};
|
||||||
constexpr char kImageGpuTag[] = "IMAGE_GPU";
|
constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"};
|
||||||
constexpr char kAlphaTag[] = "ALPHA";
|
constexpr absl::string_view kImageTag{"IMAGE"};
|
||||||
constexpr char kAlphaGpuTag[] = "ALPHA_GPU";
|
constexpr absl::string_view kImageCpuTag{"IMAGE_CPU"};
|
||||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
constexpr absl::string_view kImageGpuTag{"IMAGE_GPU"};
|
||||||
constexpr char kRoiTag[] = "ROI";
|
constexpr absl::string_view kAlphaTag{"ALPHA"};
|
||||||
constexpr char kVideoTag[] = "VIDEO";
|
constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"};
|
||||||
|
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
|
||||||
|
constexpr absl::string_view kRoiTag{"ROI"};
|
||||||
|
|
||||||
// Updates the graph to return `roi` stream which has same dimension as
|
// Updates the graph to return `roi` stream which has same dimension as
|
||||||
// `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is
|
// `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is
|
||||||
|
@ -87,11 +93,10 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// An "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"
|
// A "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"
|
||||||
// performs semantic segmentation given user's region-of-interest. Two kinds of
|
// performs semantic segmentation given the user's region-of-interest. The graph
|
||||||
// outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION. Users can
|
// can output optional confidence masks if CONFIDENCE_MASKS is connected, and an
|
||||||
// retrieve segmented mask of only particular category/channel from
|
// optional category mask if CATEGORY_MASK is connected. At least one of
|
||||||
// SEGMENTATION, and users can also get all segmented masks from
|
// CONFIDENCE_MASK, CONFIDENCE_MASKS and CATEGORY_MASK must be connected.
|
||||||
// GROUPED_SEGMENTATION.
|
|
||||||
// - Accepts CPU input images and outputs segmented masks on CPU.
|
// - Accepts CPU input images and outputs segmented masks on CPU.
|
||||||
//
|
//
|
||||||
// Inputs:
|
// Inputs:
|
||||||
|
@ -106,11 +111,13 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
|
||||||
// @Optional: rect covering the whole image is used if not specified.
|
// @Optional: rect covering the whole image is used if not specified.
|
||||||
//
|
//
|
||||||
// Outputs:
|
// Outputs:
|
||||||
// SEGMENTATION - mediapipe::Image @Multiple
|
// CONFIDENCE_MASK - mediapipe::Image @Multiple
|
||||||
// Segmented masks for individual category. Segmented mask of single
|
// Confidence masks for individual category. Confidence mask of single
|
||||||
// category can be accessed by index based output stream.
|
// category can be accessed by index based output stream.
|
||||||
// GROUPED_SEGMENTATION - std::vector<mediapipe::Image>
|
// CONFIDENCE_MASKS - std::vector<mediapipe::Image> @Optional
|
||||||
// The output segmented masks grouped in a vector.
|
// The output confidence masks grouped in a vector.
|
||||||
|
// CATEGORY_MASK - mediapipe::Image @Optional
|
||||||
|
// Optional Category mask.
|
||||||
// IMAGE - mediapipe::Image
|
// IMAGE - mediapipe::Image
|
||||||
// The image that image segmenter runs on.
|
// The image that image segmenter runs on.
|
||||||
//
|
//
|
||||||
|
@ -129,9 +136,6 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
|
||||||
// file_name: "/path/to/model.tflite"
|
// file_name: "/path/to/model.tflite"
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
// segmenter_options {
|
|
||||||
// output_type: CONFIDENCE_MASK
|
|
||||||
// }
|
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
|
@ -176,10 +180,26 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph {
|
||||||
image_with_set_alpha >> image_segmenter.In(kImageTag);
|
image_with_set_alpha >> image_segmenter.In(kImageTag);
|
||||||
norm_rect >> image_segmenter.In(kNormRectTag);
|
norm_rect >> image_segmenter.In(kNormRectTag);
|
||||||
|
|
||||||
|
// TODO: remove deprecated output type support.
|
||||||
|
if (task_options.segmenter_options().has_output_type()) {
|
||||||
image_segmenter.Out(kSegmentationTag) >>
|
image_segmenter.Out(kSegmentationTag) >>
|
||||||
graph[Output<Image>(kSegmentationTag)];
|
graph[Output<Image>(kSegmentationTag)];
|
||||||
image_segmenter.Out(kGroupedSegmentationTag) >>
|
image_segmenter.Out(kGroupedSegmentationTag) >>
|
||||||
graph[Output<std::vector<Image>>(kGroupedSegmentationTag)];
|
graph[Output<std::vector<Image>>(kGroupedSegmentationTag)];
|
||||||
|
} else {
|
||||||
|
if (HasOutput(sc->OriginalNode(), kConfidenceMaskTag)) {
|
||||||
|
image_segmenter.Out(kConfidenceMaskTag) >>
|
||||||
|
graph[Output<Image>(kConfidenceMaskTag)];
|
||||||
|
}
|
||||||
|
// CONFIDENCE_MASKS carries all confidence masks grouped in one vector, so the
// graph output port is declared as std::vector<Image> (not a single Image),
// matching the documented output type of this graph.
if (HasOutput(sc->OriginalNode(), kConfidenceMasksTag)) {
|
||||||
|
image_segmenter.Out(kConfidenceMasksTag) >>
|
||||||
|
graph[Output<std::vector<Image>>(kConfidenceMasksTag)];
|
||||||
|
}
|
||||||
|
if (HasOutput(sc->OriginalNode(), kCategoryMaskTag)) {
|
||||||
|
image_segmenter.Out(kCategoryMaskTag) >>
|
||||||
|
graph[Output<Image>(kCategoryMaskTag)];
|
||||||
|
}
|
||||||
|
}
|
||||||
image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)];
|
image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)];
|
||||||
|
|
||||||
return graph.GetConfig();
|
return graph.GetConfig();
|
||||||
|
|
|
@ -17,8 +17,11 @@ limitations under the License.
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "absl/flags/flag.h"
|
#include "absl/flags/flag.h"
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
#include "mediapipe/framework/deps/file_path.h"
|
#include "mediapipe/framework/deps/file_path.h"
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
#include "mediapipe/framework/formats/image_frame.h"
|
#include "mediapipe/framework/formats/image_frame.h"
|
||||||
|
@ -39,6 +42,7 @@ limitations under the License.
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||||
|
#include "testing/base/public/gmock.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -53,13 +57,16 @@ using ::mediapipe::tasks::components::containers::RectF;
|
||||||
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
using ::testing::Optional;
|
using ::testing::Optional;
|
||||||
|
using ::testing::SizeIs;
|
||||||
|
using ::testing::status::StatusIs;
|
||||||
|
|
||||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
|
constexpr absl::string_view kTestDataDirectory{
|
||||||
constexpr char kPtmModel[] = "ptm_512_hdt_ptm_woid.tflite";
|
"/mediapipe/tasks/testdata/vision/"};
|
||||||
constexpr char kCatsAndDogsJpg[] = "cats_and_dogs.jpg";
|
constexpr absl::string_view kPtmModel{"ptm_512_hdt_ptm_woid.tflite"};
|
||||||
|
constexpr absl::string_view kCatsAndDogsJpg{"cats_and_dogs.jpg"};
|
||||||
// Golden mask for the dogs in cats_and_dogs.jpg.
|
// Golden mask for the dogs in cats_and_dogs.jpg.
|
||||||
constexpr char kCatsAndDogsMaskDog1[] = "cats_and_dogs_mask_dog1.png";
|
constexpr absl::string_view kCatsAndDogsMaskDog1{"cats_and_dogs_mask_dog1.png"};
|
||||||
constexpr char kCatsAndDogsMaskDog2[] = "cats_and_dogs_mask_dog2.png";
|
constexpr absl::string_view kCatsAndDogsMaskDog2{"cats_and_dogs_mask_dog2.png"};
|
||||||
|
|
||||||
constexpr float kGoldenMaskSimilarity = 0.97;
|
constexpr float kGoldenMaskSimilarity = 0.97;
|
||||||
|
|
||||||
|
@ -135,35 +142,45 @@ TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
|
||||||
JoinPath("./", kTestDataDirectory, kPtmModel);
|
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||||
options->base_options.op_resolver =
|
options->base_options.op_resolver =
|
||||||
absl::make_unique<DeepLabOpResolverMissingOps>();
|
absl::make_unique<DeepLabOpResolverMissingOps>();
|
||||||
auto segmenter_or = InteractiveSegmenter::Create(std::move(options));
|
auto segmenter = InteractiveSegmenter::Create(std::move(options));
|
||||||
// TODO: Make MediaPipe InferenceCalculator report the detailed
|
// TODO: Make MediaPipe InferenceCalculator report the detailed
|
||||||
// interpreter errors (e.g., "Encountered unresolved custom op").
|
// interpreter errors (e.g., "Encountered unresolved custom op").
|
||||||
EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInternal);
|
EXPECT_EQ(segmenter.status().code(), absl::StatusCode::kInternal);
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
segmenter_or.status().message(),
|
segmenter.status().message(),
|
||||||
testing::HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk"));
|
testing::HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
|
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
|
||||||
absl::StatusOr<std::unique_ptr<InteractiveSegmenter>> segmenter_or =
|
absl::StatusOr<std::unique_ptr<InteractiveSegmenter>> segmenter =
|
||||||
InteractiveSegmenter::Create(
|
InteractiveSegmenter::Create(
|
||||||
std::make_unique<InteractiveSegmenterOptions>());
|
std::make_unique<InteractiveSegmenterOptions>());
|
||||||
|
|
||||||
EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInvalidArgument);
|
EXPECT_EQ(segmenter.status().code(), absl::StatusCode::kInvalidArgument);
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
segmenter_or.status().message(),
|
segmenter.status().message(),
|
||||||
HasSubstr("ExternalFile must specify at least one of 'file_content', "
|
HasSubstr("ExternalFile must specify at least one of 'file_content', "
|
||||||
"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
|
"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
|
||||||
EXPECT_THAT(segmenter_or.status().GetPayload(kMediaPipeTasksPayload),
|
EXPECT_THAT(segmenter.status().GetPayload(kMediaPipeTasksPayload),
|
||||||
Optional(absl::Cord(absl::StrCat(
|
Optional(absl::Cord(absl::StrCat(
|
||||||
MediaPipeTasksStatus::kRunnerInitializationError))));
|
MediaPipeTasksStatus::kRunnerInitializationError))));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(CreateFromOptionsTest, FailsWithNeitherOutputSet) {
|
||||||
|
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||||
|
options->output_category_mask = false;
|
||||||
|
options->output_confidence_masks = false;
|
||||||
|
|
||||||
|
EXPECT_THAT(InteractiveSegmenter::Create(std::move(options)),
|
||||||
|
StatusIs(absl::StatusCode::kInvalidArgument,
|
||||||
|
HasSubstr("At least one of")));
|
||||||
|
}
|
||||||
|
|
||||||
struct InteractiveSegmenterTestParams {
|
struct InteractiveSegmenterTestParams {
|
||||||
std::string test_name;
|
std::string test_name;
|
||||||
RegionOfInterest::Format format;
|
RegionOfInterest::Format format;
|
||||||
NormalizedKeypoint roi;
|
NormalizedKeypoint roi;
|
||||||
std::string golden_mask_file;
|
absl::string_view golden_mask_file;
|
||||||
float similarity_threshold;
|
float similarity_threshold;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -181,16 +198,18 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
|
||||||
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||||
options->base_options.model_asset_path =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kPtmModel);
|
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||||
options->output_type = InteractiveSegmenterOptions::OutputType::CATEGORY_MASK;
|
options->output_confidence_masks = false;
|
||||||
|
options->output_category_mask = true;
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
|
||||||
InteractiveSegmenter::Create(std::move(options)));
|
InteractiveSegmenter::Create(std::move(options)));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto category_masks,
|
MP_ASSERT_OK_AND_ASSIGN(auto result,
|
||||||
segmenter->Segment(image, interaction_roi));
|
segmenter->Segment(image, interaction_roi));
|
||||||
EXPECT_EQ(category_masks.size(), 1);
|
EXPECT_TRUE(result.category_mask.has_value());
|
||||||
|
EXPECT_FALSE(result.confidence_masks.has_value());
|
||||||
|
|
||||||
cv::Mat actual_mask = mediapipe::formats::MatView(
|
cv::Mat actual_mask = mediapipe::formats::MatView(
|
||||||
category_masks[0].GetImageFrameSharedPtr().get());
|
result.category_mask->GetImageFrameSharedPtr().get());
|
||||||
|
|
||||||
cv::Mat expected_mask =
|
cv::Mat expected_mask =
|
||||||
cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file),
|
cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file),
|
||||||
|
@ -211,14 +230,13 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
|
||||||
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||||
options->base_options.model_asset_path =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kPtmModel);
|
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||||
options->output_type =
|
|
||||||
InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
|
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
|
||||||
InteractiveSegmenter::Create(std::move(options)));
|
InteractiveSegmenter::Create(std::move(options)));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks,
|
MP_ASSERT_OK_AND_ASSIGN(auto result,
|
||||||
segmenter->Segment(image, interaction_roi));
|
segmenter->Segment(image, interaction_roi));
|
||||||
EXPECT_EQ(confidence_masks.size(), 2);
|
EXPECT_FALSE(result.category_mask.has_value());
|
||||||
|
EXPECT_THAT(result.confidence_masks, Optional(SizeIs(2)));
|
||||||
|
|
||||||
cv::Mat expected_mask =
|
cv::Mat expected_mask =
|
||||||
cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file),
|
cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file),
|
||||||
|
@ -227,7 +245,7 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
|
||||||
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
|
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
|
||||||
|
|
||||||
cv::Mat actual_mask = mediapipe::formats::MatView(
|
cv::Mat actual_mask = mediapipe::formats::MatView(
|
||||||
confidence_masks[1].GetImageFrameSharedPtr().get());
|
result.confidence_masks->at(1).GetImageFrameSharedPtr().get());
|
||||||
EXPECT_THAT(actual_mask, SimilarToFloatMask(expected_mask_float,
|
EXPECT_THAT(actual_mask, SimilarToFloatMask(expected_mask_float,
|
||||||
params.similarity_threshold));
|
params.similarity_threshold));
|
||||||
}
|
}
|
||||||
|
@ -235,9 +253,9 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi,
|
SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi,
|
||||||
::testing::ValuesIn<InteractiveSegmenterTestParams>(
|
::testing::ValuesIn<InteractiveSegmenterTestParams>(
|
||||||
{{"PointToDog1", RegionOfInterest::KEYPOINT,
|
{{"PointToDog1", RegionOfInterest::Format::kKeyPoint,
|
||||||
NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
|
NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
|
||||||
{"PointToDog2", RegionOfInterest::KEYPOINT,
|
{"PointToDog2", RegionOfInterest::Format::kKeyPoint,
|
||||||
NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
|
NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
|
||||||
kGoldenMaskSimilarity}}),
|
kGoldenMaskSimilarity}}),
|
||||||
[](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
|
[](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
|
||||||
|
@ -252,22 +270,21 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
|
||||||
Image image,
|
Image image,
|
||||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
|
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
|
||||||
RegionOfInterest interaction_roi;
|
RegionOfInterest interaction_roi;
|
||||||
interaction_roi.format = RegionOfInterest::KEYPOINT;
|
interaction_roi.format = RegionOfInterest::Format::kKeyPoint;
|
||||||
interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66};
|
interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66};
|
||||||
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||||
options->base_options.model_asset_path =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kPtmModel);
|
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||||
options->output_type =
|
|
||||||
InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
|
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
|
||||||
InteractiveSegmenter::Create(std::move(options)));
|
InteractiveSegmenter::Create(std::move(options)));
|
||||||
ImageProcessingOptions image_processing_options;
|
ImageProcessingOptions image_processing_options;
|
||||||
image_processing_options.rotation_degrees = -90;
|
image_processing_options.rotation_degrees = -90;
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto confidence_masks,
|
auto result,
|
||||||
segmenter->Segment(image, interaction_roi, image_processing_options));
|
segmenter->Segment(image, interaction_roi, image_processing_options));
|
||||||
EXPECT_EQ(confidence_masks.size(), 2);
|
EXPECT_FALSE(result.category_mask.has_value());
|
||||||
|
EXPECT_EQ(result.confidence_masks->size(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
|
TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
|
||||||
|
@ -275,13 +292,11 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
|
||||||
Image image,
|
Image image,
|
||||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
|
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
|
||||||
RegionOfInterest interaction_roi;
|
RegionOfInterest interaction_roi;
|
||||||
interaction_roi.format = RegionOfInterest::KEYPOINT;
|
interaction_roi.format = RegionOfInterest::Format::kKeyPoint;
|
||||||
interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66};
|
interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66};
|
||||||
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||||
options->base_options.model_asset_path =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kPtmModel);
|
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||||
options->output_type =
|
|
||||||
InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
|
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
|
||||||
InteractiveSegmenter::Create(std::move(options)));
|
InteractiveSegmenter::Create(std::move(options)));
|
||||||
|
|
Loading…
Reference in New Issue
Block a user