Internal change

PiperOrigin-RevId: 522631851
This commit is contained in:
MediaPipe Team 2023-04-07 10:43:23 -07:00 committed by Copybara-Service
parent e3185e3df0
commit a1ce19f68e
5 changed files with 184 additions and 116 deletions

View File

@ -24,21 +24,26 @@ cc_library(
hdrs = ["interactive_segmenter.h"],
deps = [
":interactive_segmenter_graph",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/containers:keypoint",
"//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
"//mediapipe/tasks/cc/vision/core:image_processing_options",
"//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_result",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
"//mediapipe/util:color_cc_proto",
"//mediapipe/util:render_data_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
@ -61,9 +66,12 @@ cc_library(
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
"//mediapipe/util:color_cc_proto",
"//mediapipe/util:graph_builder_utils",
"//mediapipe/util:label_map_cc_proto",
"//mediapipe/util:render_data_cc_proto",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
] + select({
"//mediapipe/gpu:disable_gpu": [],

View File

@ -15,16 +15,24 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h"
#include <memory>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
#include "mediapipe/util/color.pb.h"
@ -36,23 +44,26 @@ namespace vision {
namespace interactive_segmenter {
namespace {
constexpr char kSegmentationStreamName[] = "segmented_mask_out";
constexpr char kConfidenceMasksStreamName[] = "confidence_masks";
constexpr char kCategoryMaskStreamName[] = "category_mask";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kRoiStreamName[] = "roi_in";
constexpr char kNormRectStreamName[] = "norm_rect_in";
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kImageTag[] = "IMAGE";
constexpr char kRoiTag[] = "ROI";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"};
constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"};
constexpr absl::string_view kImageTag{"IMAGE"};
constexpr absl::string_view kRoiTag{"ROI"};
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
constexpr absl::string_view kSubgraphTypeName{
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};
using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::Image;
using ::mediapipe::NormalizedRect;
using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterResult;
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
image_segmenter::proto::ImageSegmenterGraphOptions;
@ -60,7 +71,8 @@ using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
// Creates a MediaPipe graph config that only contains a single subgraph node of
// "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph".
CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<ImageSegmenterGraphOptionsProto> options) {
std::unique_ptr<ImageSegmenterGraphOptionsProto> options,
bool output_confidence_masks, bool output_category_mask) {
api2::builder::Graph graph;
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap(
@ -68,8 +80,15 @@ CalculatorGraphConfig CreateGraphConfig(
graph.In(kImageTag).SetName(kImageInStreamName);
graph.In(kRoiTag).SetName(kRoiStreamName);
graph.In(kNormRectTag).SetName(kNormRectStreamName);
task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
graph.Out(kGroupedSegmentationTag);
if (output_confidence_masks) {
task_subgraph.Out(kConfidenceMasksTag)
.SetName(kConfidenceMasksStreamName) >>
graph.Out(kConfidenceMasksTag);
}
if (output_category_mask) {
task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
graph.Out(kCategoryMaskTag);
}
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
graph.Out(kImageTag);
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
@ -86,16 +105,6 @@ ConvertImageSegmenterOptionsToProto(InteractiveSegmenterOptions* options) {
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
options_proto->mutable_base_options()->Swap(base_options_proto.get());
switch (options->output_type) {
case InteractiveSegmenterOptions::OutputType::CATEGORY_MASK:
options_proto->mutable_segmenter_options()->set_output_type(
SegmenterOptions::CATEGORY_MASK);
break;
case InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK:
options_proto->mutable_segmenter_options()->set_output_type(
SegmenterOptions::CONFIDENCE_MASK);
break;
}
return options_proto;
}
@ -104,10 +113,10 @@ ConvertImageSegmenterOptionsToProto(InteractiveSegmenterOptions* options) {
absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
RenderData result;
switch (roi.format) {
case RegionOfInterest::UNSPECIFIED:
case RegionOfInterest::Format::kUnspecified:
return absl::InvalidArgumentError(
"RegionOfInterest format not specified");
case RegionOfInterest::KEYPOINT:
case RegionOfInterest::Format::kKeyPoint:
RET_CHECK(roi.keypoint.has_value());
auto* annotation = result.add_render_annotations();
annotation->mutable_color()->set_r(255);
@ -125,15 +134,29 @@ absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
absl::StatusOr<std::unique_ptr<InteractiveSegmenter>>
InteractiveSegmenter::Create(
std::unique_ptr<InteractiveSegmenterOptions> options) {
auto options_proto = ConvertImageSegmenterOptionsToProto(options.get());
return core::VisionTaskApiFactory::Create<InteractiveSegmenter,
ImageSegmenterGraphOptionsProto>(
CreateGraphConfig(std::move(options_proto)),
std::move(options->base_options.op_resolver), core::RunningMode::IMAGE,
/*packets_callback=*/nullptr);
if (!options->output_confidence_masks && !options->output_category_mask) {
return absl::InvalidArgumentError(
"At least one of `output_confidence_masks` and `output_category_mask` "
"must be set.");
}
std::unique_ptr<ImageSegmenterGraphOptionsProto> options_proto =
ConvertImageSegmenterOptionsToProto(options.get());
ASSIGN_OR_RETURN(
std::unique_ptr<InteractiveSegmenter> segmenter,
(core::VisionTaskApiFactory::Create<InteractiveSegmenter,
ImageSegmenterGraphOptionsProto>(
CreateGraphConfig(std::move(options_proto),
options->output_confidence_masks,
options->output_category_mask),
std::move(options->base_options.op_resolver),
core::RunningMode::IMAGE,
/*packets_callback=*/nullptr)));
segmenter->output_category_mask_ = options->output_category_mask;
segmenter->output_confidence_masks_ = options->output_confidence_masks;
return segmenter;
}
absl::StatusOr<std::vector<Image>> InteractiveSegmenter::Segment(
absl::StatusOr<ImageSegmenterResult> InteractiveSegmenter::Segment(
mediapipe::Image image, const RegionOfInterest& roi,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
@ -154,7 +177,16 @@ absl::StatusOr<std::vector<Image>> InteractiveSegmenter::Segment(
mediapipe::MakePacket<RenderData>(std::move(roi_as_render_data))},
{kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))}}));
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
std::optional<std::vector<Image>> confidence_masks;
if (output_confidence_masks_) {
confidence_masks =
output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>();
}
std::optional<Image> category_mask;
if (output_category_mask_) {
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
}
return {{confidence_masks, category_mask}};
}
} // namespace interactive_segmenter

View File

@ -21,12 +21,14 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"
namespace mediapipe {
namespace tasks {
@ -39,30 +41,24 @@ struct InteractiveSegmenterOptions {
// file with metadata, accelerator options, op resolver, etc.
tasks::core::BaseOptions base_options;
// The output type of segmentation results.
enum OutputType {
// Gives a single output mask where each pixel represents the class which
// the pixel in the original image was predicted to belong to.
CATEGORY_MASK = 0,
// Gives a list of output masks where, for each mask, each pixel represents
// the prediction confidence, usually in the [0, 1] range.
CONFIDENCE_MASK = 1,
};
// Whether to output confidence masks.
bool output_confidence_masks = true;
OutputType output_type = OutputType::CATEGORY_MASK;
// Whether to output category mask.
bool output_category_mask = false;
};
// The Region-Of-Interest (ROI) to interact with.
struct RegionOfInterest {
enum Format {
UNSPECIFIED = 0, // Format not specified.
KEYPOINT = 1, // Using keypoint to represent ROI.
enum class Format {
kUnspecified = 0, // Format not specified.
kKeyPoint = 1, // Using keypoint to represent ROI.
};
// Specifies the format used to specify the region-of-interest. Note that
// using `UNSPECIFIED` is invalid and will lead to an `InvalidArgument` status
// being returned.
Format format = Format::UNSPECIFIED;
Format format = Format::kUnspecified;
// Represents the ROI in keypoint format, this should be non-nullopt if
// `format` is `KEYPOINT`.
@ -84,13 +80,11 @@ struct RegionOfInterest {
// - RGB inputs is supported (`channels` is required to be 3).
// - if type is kTfLiteFloat32, NormalizationOptions are required to be
// attached to the metadata for input normalization.
// Output tensors:
// (kTfLiteUInt8/kTfLiteFloat32)
// - list of segmented masks.
// - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1.
// - if `output_type` is CONFIDENCE_MASK, float32 Image list of size
// `channels`.
// - batch is always 1
// Output ImageSegmenterResult:
// Provides optional confidence masks if `output_confidence_masks` is set
// true, and an optional category mask if `output_category_mask` is set
// true. At least one of `output_confidence_masks` and `output_category_mask`
// must be set to true.
class InteractiveSegmenter : tasks::vision::core::BaseVisionTaskApi {
public:
using BaseVisionTaskApi::BaseVisionTaskApi;
@ -114,18 +108,17 @@ class InteractiveSegmenter : tasks::vision::core::BaseVisionTaskApi {
// setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
//
// If the output_type is CATEGORY_MASK, the returned vector of images is
// per-category segmented image mask.
// If the output_type is CONFIDENCE_MASK, the returned vector of images
// contains only one confidence image mask.
absl::StatusOr<std::vector<mediapipe::Image>> Segment(
absl::StatusOr<image_segmenter::ImageSegmenterResult> Segment(
mediapipe::Image image, const RegionOfInterest& roi,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Shuts down the InteractiveSegmenter when all works are done.
absl::Status Close() { return runner_->Close(); }
private:
bool output_confidence_masks_;
bool output_category_mask_;
};
} // namespace interactive_segmenter

View File

@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
@ -23,8 +26,9 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/graph_builder_utils.h"
#include "mediapipe/util/render_data.pb.h"
namespace mediapipe {
@ -42,16 +46,18 @@ using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
constexpr char kSegmentationTag[] = "SEGMENTATION";
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageCpuTag[] = "IMAGE_CPU";
constexpr char kImageGpuTag[] = "IMAGE_GPU";
constexpr char kAlphaTag[] = "ALPHA";
constexpr char kAlphaGpuTag[] = "ALPHA_GPU";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kRoiTag[] = "ROI";
constexpr char kVideoTag[] = "VIDEO";
constexpr absl::string_view kSegmentationTag{"SEGMENTATION"};
constexpr absl::string_view kGroupedSegmentationTag{"GROUPED_SEGMENTATION"};
constexpr absl::string_view kConfidenceMaskTag{"CONFIDENCE_MASK"};
constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"};
constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"};
constexpr absl::string_view kImageTag{"IMAGE"};
constexpr absl::string_view kImageCpuTag{"IMAGE_CPU"};
constexpr absl::string_view kImageGpuTag{"IMAGE_GPU"};
constexpr absl::string_view kAlphaTag{"ALPHA"};
constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"};
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
constexpr absl::string_view kRoiTag{"ROI"};
// Updates the graph to return `roi` stream which has same dimension as
// `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is
@ -87,11 +93,10 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
} // namespace
// An "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"
// performs semantic segmentation given user's region-of-interest. Two kinds of
// outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION. Users can
// retrieve segmented mask of only particular category/channel from
// SEGMENTATION, and users can also get all segmented masks from
// GROUPED_SEGMENTATION.
// performs semantic segmentation given the user's region-of-interest. The graph
// can output optional confidence masks if CONFIDENCE_MASKS is connected, and an
// optional category mask if CATEGORY_MASK is connected. At least one of
// CONFIDENCE_MASK, CONFIDENCE_MASKS and CATEGORY_MASK must be connected.
// - Accepts CPU input images and outputs segmented masks on CPU.
//
// Inputs:
@ -106,11 +111,13 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
// @Optional: rect covering the whole image is used if not specified.
//
// Outputs:
// SEGMENTATION - mediapipe::Image @Multiple
// Segmented masks for individual category. Segmented mask of single
// CONFIDENCE_MASK - mediapipe::Image @Multiple
// Confidence masks for individual category. Confidence mask of single
// category can be accessed by index based output stream.
// GROUPED_SEGMENTATION - std::vector<mediapipe::Image>
// The output segmented masks grouped in a vector.
// CONFIDENCE_MASKS - std::vector<mediapipe::Image> @Optional
// The output confidence masks grouped in a vector.
// CATEGORY_MASK - mediapipe::Image @Optional
// Optional Category mask.
// IMAGE - mediapipe::Image
// The image that image segmenter runs on.
//
@ -129,9 +136,6 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
// file_name: "/path/to/model.tflite"
// }
// }
// segmenter_options {
// output_type: CONFIDENCE_MASK
// }
// }
// }
// }
@ -176,10 +180,26 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph {
image_with_set_alpha >> image_segmenter.In(kImageTag);
norm_rect >> image_segmenter.In(kNormRectTag);
image_segmenter.Out(kSegmentationTag) >>
graph[Output<Image>(kSegmentationTag)];
image_segmenter.Out(kGroupedSegmentationTag) >>
graph[Output<std::vector<Image>>(kGroupedSegmentationTag)];
// TODO: remove deprecated output type support.
if (task_options.segmenter_options().has_output_type()) {
image_segmenter.Out(kSegmentationTag) >>
graph[Output<Image>(kSegmentationTag)];
image_segmenter.Out(kGroupedSegmentationTag) >>
graph[Output<std::vector<Image>>(kGroupedSegmentationTag)];
} else {
if (HasOutput(sc->OriginalNode(), kConfidenceMaskTag)) {
image_segmenter.Out(kConfidenceMaskTag) >>
graph[Output<Image>(kConfidenceMaskTag)];
}
if (HasOutput(sc->OriginalNode(), kConfidenceMasksTag)) {
image_segmenter.Out(kConfidenceMasksTag) >>
graph[Output<Image>(kConfidenceMasksTag)];
}
if (HasOutput(sc->OriginalNode(), kCategoryMaskTag)) {
image_segmenter.Out(kCategoryMaskTag) >>
graph[Output<Image>(kCategoryMaskTag)];
}
}
image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)];
return graph.GetConfig();

View File

@ -17,8 +17,11 @@ limitations under the License.
#include <memory>
#include <string>
#include <utility>
#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
@ -39,6 +42,7 @@ limitations under the License.
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "testing/base/public/gmock.h"
namespace mediapipe {
namespace tasks {
@ -53,13 +57,16 @@ using ::mediapipe::tasks::components::containers::RectF;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr;
using ::testing::Optional;
using ::testing::SizeIs;
using ::testing::status::StatusIs;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kPtmModel[] = "ptm_512_hdt_ptm_woid.tflite";
constexpr char kCatsAndDogsJpg[] = "cats_and_dogs.jpg";
constexpr absl::string_view kTestDataDirectory{
"/mediapipe/tasks/testdata/vision/"};
constexpr absl::string_view kPtmModel{"ptm_512_hdt_ptm_woid.tflite"};
constexpr absl::string_view kCatsAndDogsJpg{"cats_and_dogs.jpg"};
// Golden mask for the dogs in cats_and_dogs.jpg.
constexpr char kCatsAndDogsMaskDog1[] = "cats_and_dogs_mask_dog1.png";
constexpr char kCatsAndDogsMaskDog2[] = "cats_and_dogs_mask_dog2.png";
constexpr absl::string_view kCatsAndDogsMaskDog1{"cats_and_dogs_mask_dog1.png"};
constexpr absl::string_view kCatsAndDogsMaskDog2{"cats_and_dogs_mask_dog2.png"};
constexpr float kGoldenMaskSimilarity = 0.97;
@ -135,35 +142,45 @@ TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
JoinPath("./", kTestDataDirectory, kPtmModel);
options->base_options.op_resolver =
absl::make_unique<DeepLabOpResolverMissingOps>();
auto segmenter_or = InteractiveSegmenter::Create(std::move(options));
auto segmenter = InteractiveSegmenter::Create(std::move(options));
// TODO: Make MediaPipe InferenceCalculator report the detailed
// interpreter errors (e.g., "Encountered unresolved custom op").
EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInternal);
EXPECT_EQ(segmenter.status().code(), absl::StatusCode::kInternal);
EXPECT_THAT(
segmenter_or.status().message(),
segmenter.status().message(),
testing::HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk"));
}
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
absl::StatusOr<std::unique_ptr<InteractiveSegmenter>> segmenter_or =
absl::StatusOr<std::unique_ptr<InteractiveSegmenter>> segmenter =
InteractiveSegmenter::Create(
std::make_unique<InteractiveSegmenterOptions>());
EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_EQ(segmenter.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(
segmenter_or.status().message(),
segmenter.status().message(),
HasSubstr("ExternalFile must specify at least one of 'file_content', "
"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
EXPECT_THAT(segmenter_or.status().GetPayload(kMediaPipeTasksPayload),
EXPECT_THAT(segmenter.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kRunnerInitializationError))));
}
TEST_F(CreateFromOptionsTest, FailsWithNeitherOutputSet) {
auto options = std::make_unique<InteractiveSegmenterOptions>();
options->output_category_mask = false;
options->output_confidence_masks = false;
EXPECT_THAT(InteractiveSegmenter::Create(std::move(options)),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("At least one of")));
}
struct InteractiveSegmenterTestParams {
std::string test_name;
RegionOfInterest::Format format;
NormalizedKeypoint roi;
std::string golden_mask_file;
absl::string_view golden_mask_file;
float similarity_threshold;
};
@ -181,16 +198,18 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel);
options->output_type = InteractiveSegmenterOptions::OutputType::CATEGORY_MASK;
options->output_confidence_masks = false;
options->output_category_mask = true;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
InteractiveSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto category_masks,
MP_ASSERT_OK_AND_ASSIGN(auto result,
segmenter->Segment(image, interaction_roi));
EXPECT_EQ(category_masks.size(), 1);
EXPECT_TRUE(result.category_mask.has_value());
EXPECT_FALSE(result.confidence_masks.has_value());
cv::Mat actual_mask = mediapipe::formats::MatView(
category_masks[0].GetImageFrameSharedPtr().get());
result.category_mask->GetImageFrameSharedPtr().get());
cv::Mat expected_mask =
cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file),
@ -211,14 +230,13 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel);
options->output_type =
InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
InteractiveSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks,
MP_ASSERT_OK_AND_ASSIGN(auto result,
segmenter->Segment(image, interaction_roi));
EXPECT_EQ(confidence_masks.size(), 2);
EXPECT_FALSE(result.category_mask.has_value());
EXPECT_THAT(result.confidence_masks, Optional(SizeIs(2)));
cv::Mat expected_mask =
cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file),
@ -227,7 +245,7 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
cv::Mat actual_mask = mediapipe::formats::MatView(
confidence_masks[1].GetImageFrameSharedPtr().get());
result.confidence_masks->at(1).GetImageFrameSharedPtr().get());
EXPECT_THAT(actual_mask, SimilarToFloatMask(expected_mask_float,
params.similarity_threshold));
}
@ -235,9 +253,9 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
INSTANTIATE_TEST_SUITE_P(
SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi,
::testing::ValuesIn<InteractiveSegmenterTestParams>(
{{"PointToDog1", RegionOfInterest::KEYPOINT,
{{"PointToDog1", RegionOfInterest::Format::kKeyPoint,
NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
{"PointToDog2", RegionOfInterest::KEYPOINT,
{"PointToDog2", RegionOfInterest::Format::kKeyPoint,
NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
kGoldenMaskSimilarity}}),
[](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
@ -252,22 +270,21 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
RegionOfInterest interaction_roi;
interaction_roi.format = RegionOfInterest::KEYPOINT;
interaction_roi.format = RegionOfInterest::Format::kKeyPoint;
interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66};
auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel);
options->output_type =
InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
InteractiveSegmenter::Create(std::move(options)));
ImageProcessingOptions image_processing_options;
image_processing_options.rotation_degrees = -90;
MP_ASSERT_OK_AND_ASSIGN(
auto confidence_masks,
auto result,
segmenter->Segment(image, interaction_roi, image_processing_options));
EXPECT_EQ(confidence_masks.size(), 2);
EXPECT_FALSE(result.category_mask.has_value());
EXPECT_EQ(result.confidence_masks->size(), 2);
}
TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
@ -275,13 +292,11 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
RegionOfInterest interaction_roi;
interaction_roi.format = RegionOfInterest::KEYPOINT;
interaction_roi.format = RegionOfInterest::Format::kKeyPoint;
interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66};
auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel);
options->output_type =
InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
InteractiveSegmenter::Create(std::move(options)));