diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 1123204ce..69833a5f6 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -80,6 +80,7 @@ cc_library( "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", + "//mediapipe/tasks/metadata:image_segmenter_metadata_schema_cc", "//mediapipe/tasks/metadata:metadata_schema_cc", "//mediapipe/util:label_map_cc_proto", "//mediapipe/util:label_map_util", diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index 9769b47d5..c12fe7f7e 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -101,20 +101,6 @@ ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) { SegmenterOptions::CONFIDENCE_MASK); break; } - switch (options->activation) { - case ImageSegmenterOptions::Activation::NONE: - options_proto->mutable_segmenter_options()->set_activation( - SegmenterOptions::NONE); - break; - case ImageSegmenterOptions::Activation::SIGMOID: - options_proto->mutable_segmenter_options()->set_activation( - SegmenterOptions::SIGMOID); - break; - case ImageSegmenterOptions::Activation::SOFTMAX: - options_proto->mutable_segmenter_options()->set_activation( - SegmenterOptions::SOFTMAX); - break; - } return options_proto; } diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index c757296e4..076a5016c 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -64,15 +64,6 @@ struct ImageSegmenterOptions { OutputType output_type = OutputType::CATEGORY_MASK; - // The activation function used on the raw segmentation model output. - enum Activation { - NONE = 0, // No activation function is used. - SIGMOID = 1, // Assumes 1-channel input tensor. - SOFTMAX = 2, // Assumes multi-channel input tensor. - }; - - Activation activation = Activation::NONE; - // The user-defined result callback for processing live stream data. // The result callback should only be specified when the running mode is set // to RunningMode::LIVE_STREAM. diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index 6a7e08626..fe6265b73 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -40,6 +40,7 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" +#include "mediapipe/tasks/metadata/image_segmenter_metadata_schema_generated.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map_util.h" @@ -74,6 +75,7 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU"; constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kTensorsTag[] = "TENSORS"; constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; +constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA"; // Struct holding the different output streams produced by the image segmenter // subgraph. @@ -130,7 +132,49 @@ absl::Status ConfigureTensorsToSegmentationCalculator( const ImageSegmenterGraphOptions& segmenter_option, const core::ModelResources& model_resources, TensorsToSegmentationCalculatorOptions* options) { - *options->mutable_segmenter_options() = segmenter_option.segmenter_options(); + // Set default activation function NONE + options->mutable_segmenter_options()->set_output_type( + segmenter_option.segmenter_options().output_type()); + options->mutable_segmenter_options()->set_activation(SegmenterOptions::NONE); + // Find the custom metadata of ImageSegmenterOptions type in model metadata. + const auto* metadata_extractor = model_resources.GetMetadataExtractor(); + bool found_activation_in_metadata = false; + if (metadata_extractor->GetCustomMetadataList() != nullptr && + metadata_extractor->GetCustomMetadataList()->size() > 0) { + for (const auto& custom_metadata : + *metadata_extractor->GetCustomMetadataList()) { + if (custom_metadata->name()->str() == kSegmentationMetadataName) { + found_activation_in_metadata = true; + auto activation_fb = + GetImageSegmenterOptions(custom_metadata->data()->data()) + ->activation(); + switch (activation_fb) { + case Activation_NONE: + options->mutable_segmenter_options()->set_activation( + SegmenterOptions::NONE); + break; + case Activation_SIGMOID: + options->mutable_segmenter_options()->set_activation( + SegmenterOptions::SIGMOID); + break; + case Activation_SOFTMAX: + options->mutable_segmenter_options()->set_activation( + SegmenterOptions::SOFTMAX); + break; + default: + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Invalid activation type found in CustomMetadata of " + "ImageSegmenterOptions type."); + } + } + } + } + if (!found_activation_in_metadata) { + LOG(WARNING) + << "No activation type is found in model metadata. Use NONE for " + "ImageSegmenterGraph."; + } const tflite::Model& model = *model_resources.GetTfLiteModel(); if (model.subgraphs()->size() != 1) { return CreateStatusWithPayload( @@ -146,8 +190,6 @@ absl::Status ConfigureTensorsToSegmentationCalculator( MediaPipeTasksStatus::kInvalidArgumentError); } - const ModelMetadataExtractor* metadata_extractor = - model_resources.GetMetadataExtractor(); ASSIGN_OR_RETURN( *options->mutable_label_items(), GetLabelItemsIfAny(*metadata_extractor, diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index d063ca87a..1d75a3fb7 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -304,7 +304,6 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; - options->activation = ImageSegmenterOptions::Activation::SOFTMAX; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -333,7 +332,6 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; - options->activation = ImageSegmenterOptions::Activation::SOFTMAX; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -364,7 +362,6 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; - options->activation = ImageSegmenterOptions::Activation::SOFTMAX; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -388,7 +385,6 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; - options->activation = ImageSegmenterOptions::Activation::SOFTMAX; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -416,7 +412,6 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; - options->activation = ImageSegmenterOptions::Activation::NONE; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); @@ -442,7 +437,6 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfieSegmentation); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; - options->activation = ImageSegmenterOptions::Activation::NONE; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -470,7 +464,6 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationCategoryMask) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfieSegmentation); options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; - options->activation = ImageSegmenterOptions::Activation::NONE; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -495,7 +488,6 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfieSegmentationLandscape); options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; - options->activation = ImageSegmenterOptions::Activation::NONE; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -521,7 +513,6 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kHairSegmentationWithMetadata); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; - options->activation = ImageSegmenterOptions::Activation::SOFTMAX; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java index 299423003..931740c8e 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java @@ -641,9 +641,6 @@ public final class ImageSegmenter extends BaseVisionTaskApi { SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK); } - // TODO: remove this once activation is handled in metadata and grpah level. - segmenterOptionsBuilder.setActivation( - SegmenterOptionsProto.SegmenterOptions.Activation.SOFTMAX); taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension(