ImageSegmenterGraph now sets the activation type from model metadata; remove the activation config from the C++ ImageSegmenterOptions.
PiperOrigin-RevId: 516893115
This commit is contained in:
		
							parent
							
								
									a323825134
								
							
						
					
					
						commit
						59962bed27
					
				| 
						 | 
				
			
			@ -80,6 +80,7 @@ cc_library(
 | 
			
		|||
        "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
 | 
			
		||||
        "//mediapipe/tasks/metadata:image_segmenter_metadata_schema_cc",
 | 
			
		||||
        "//mediapipe/tasks/metadata:metadata_schema_cc",
 | 
			
		||||
        "//mediapipe/util:label_map_cc_proto",
 | 
			
		||||
        "//mediapipe/util:label_map_util",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -101,20 +101,6 @@ ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) {
 | 
			
		|||
          SegmenterOptions::CONFIDENCE_MASK);
 | 
			
		||||
      break;
 | 
			
		||||
  }
 | 
			
		||||
  switch (options->activation) {
 | 
			
		||||
    case ImageSegmenterOptions::Activation::NONE:
 | 
			
		||||
      options_proto->mutable_segmenter_options()->set_activation(
 | 
			
		||||
          SegmenterOptions::NONE);
 | 
			
		||||
      break;
 | 
			
		||||
    case ImageSegmenterOptions::Activation::SIGMOID:
 | 
			
		||||
      options_proto->mutable_segmenter_options()->set_activation(
 | 
			
		||||
          SegmenterOptions::SIGMOID);
 | 
			
		||||
      break;
 | 
			
		||||
    case ImageSegmenterOptions::Activation::SOFTMAX:
 | 
			
		||||
      options_proto->mutable_segmenter_options()->set_activation(
 | 
			
		||||
          SegmenterOptions::SOFTMAX);
 | 
			
		||||
      break;
 | 
			
		||||
  }
 | 
			
		||||
  return options_proto;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -64,15 +64,6 @@ struct ImageSegmenterOptions {
 | 
			
		|||
 | 
			
		||||
  OutputType output_type = OutputType::CATEGORY_MASK;
 | 
			
		||||
 | 
			
		||||
  // The activation function used on the raw segmentation model output.
 | 
			
		||||
  enum Activation {
 | 
			
		||||
    NONE = 0,     // No activation function is used.
 | 
			
		||||
    SIGMOID = 1,  // Assumes 1-channel input tensor.
 | 
			
		||||
    SOFTMAX = 2,  // Assumes multi-channel input tensor.
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  Activation activation = Activation::NONE;
 | 
			
		||||
 | 
			
		||||
  // The user-defined result callback for processing live stream data.
 | 
			
		||||
  // The result callback should only be specified when the running mode is set
 | 
			
		||||
  // to RunningMode::LIVE_STREAM.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -40,6 +40,7 @@ limitations under the License.
 | 
			
		|||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
 | 
			
		||||
#include "mediapipe/tasks/metadata/image_segmenter_metadata_schema_generated.h"
 | 
			
		||||
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
 | 
			
		||||
#include "mediapipe/util/label_map.pb.h"
 | 
			
		||||
#include "mediapipe/util/label_map_util.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -74,6 +75,7 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU";
 | 
			
		|||
constexpr char kNormRectTag[] = "NORM_RECT";
 | 
			
		||||
constexpr char kTensorsTag[] = "TENSORS";
 | 
			
		||||
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
 | 
			
		||||
constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA";
 | 
			
		||||
 | 
			
		||||
// Struct holding the different output streams produced by the image segmenter
 | 
			
		||||
// subgraph.
 | 
			
		||||
| 
						 | 
				
			
			@ -130,7 +132,49 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
 | 
			
		|||
    const ImageSegmenterGraphOptions& segmenter_option,
 | 
			
		||||
    const core::ModelResources& model_resources,
 | 
			
		||||
    TensorsToSegmentationCalculatorOptions* options) {
 | 
			
		||||
  *options->mutable_segmenter_options() = segmenter_option.segmenter_options();
 | 
			
		||||
  // Set the default activation function to NONE.
 | 
			
		||||
  options->mutable_segmenter_options()->set_output_type(
 | 
			
		||||
      segmenter_option.segmenter_options().output_type());
 | 
			
		||||
  options->mutable_segmenter_options()->set_activation(SegmenterOptions::NONE);
 | 
			
		||||
  // Find the custom metadata of ImageSegmenterOptions type in model metadata.
 | 
			
		||||
  const auto* metadata_extractor = model_resources.GetMetadataExtractor();
 | 
			
		||||
  bool found_activation_in_metadata = false;
 | 
			
		||||
  if (metadata_extractor->GetCustomMetadataList() != nullptr &&
 | 
			
		||||
      metadata_extractor->GetCustomMetadataList()->size() > 0) {
 | 
			
		||||
    for (const auto& custom_metadata :
 | 
			
		||||
         *metadata_extractor->GetCustomMetadataList()) {
 | 
			
		||||
      if (custom_metadata->name()->str() == kSegmentationMetadataName) {
 | 
			
		||||
        found_activation_in_metadata = true;
 | 
			
		||||
        auto activation_fb =
 | 
			
		||||
            GetImageSegmenterOptions(custom_metadata->data()->data())
 | 
			
		||||
                ->activation();
 | 
			
		||||
        switch (activation_fb) {
 | 
			
		||||
          case Activation_NONE:
 | 
			
		||||
            options->mutable_segmenter_options()->set_activation(
 | 
			
		||||
                SegmenterOptions::NONE);
 | 
			
		||||
            break;
 | 
			
		||||
          case Activation_SIGMOID:
 | 
			
		||||
            options->mutable_segmenter_options()->set_activation(
 | 
			
		||||
                SegmenterOptions::SIGMOID);
 | 
			
		||||
            break;
 | 
			
		||||
          case Activation_SOFTMAX:
 | 
			
		||||
            options->mutable_segmenter_options()->set_activation(
 | 
			
		||||
                SegmenterOptions::SOFTMAX);
 | 
			
		||||
            break;
 | 
			
		||||
          default:
 | 
			
		||||
            return CreateStatusWithPayload(
 | 
			
		||||
                absl::StatusCode::kInvalidArgument,
 | 
			
		||||
                "Invalid activation type found in CustomMetadata of "
 | 
			
		||||
                "ImageSegmenterOptions type.");
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  if (!found_activation_in_metadata) {
 | 
			
		||||
    LOG(WARNING)
 | 
			
		||||
        << "No activation type is found in model metadata. Use NONE for "
 | 
			
		||||
           "ImageSegmenterGraph.";
 | 
			
		||||
  }
 | 
			
		||||
  const tflite::Model& model = *model_resources.GetTfLiteModel();
 | 
			
		||||
  if (model.subgraphs()->size() != 1) {
 | 
			
		||||
    return CreateStatusWithPayload(
 | 
			
		||||
| 
						 | 
				
			
			@ -146,8 +190,6 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
 | 
			
		|||
        MediaPipeTasksStatus::kInvalidArgumentError);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const ModelMetadataExtractor* metadata_extractor =
 | 
			
		||||
      model_resources.GetMetadataExtractor();
 | 
			
		||||
  ASSIGN_OR_RETURN(
 | 
			
		||||
      *options->mutable_label_items(),
 | 
			
		||||
      GetLabelItemsIfAny(*metadata_extractor,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -304,7 +304,6 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
 | 
			
		|||
  options->base_options.model_asset_path =
 | 
			
		||||
      JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
 | 
			
		||||
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
 | 
			
		||||
  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
 | 
			
		||||
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
 | 
			
		||||
                          ImageSegmenter::Create(std::move(options)));
 | 
			
		||||
| 
						 | 
				
			
			@ -333,7 +332,6 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
 | 
			
		|||
  options->base_options.model_asset_path =
 | 
			
		||||
      JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
 | 
			
		||||
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
 | 
			
		||||
  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
 | 
			
		||||
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
 | 
			
		||||
                          ImageSegmenter::Create(std::move(options)));
 | 
			
		||||
| 
						 | 
				
			
			@ -364,7 +362,6 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
 | 
			
		|||
  options->base_options.model_asset_path =
 | 
			
		||||
      JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
 | 
			
		||||
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
 | 
			
		||||
  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
 | 
			
		||||
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
 | 
			
		||||
                          ImageSegmenter::Create(std::move(options)));
 | 
			
		||||
| 
						 | 
				
			
			@ -388,7 +385,6 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
 | 
			
		|||
  options->base_options.model_asset_path =
 | 
			
		||||
      JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata);
 | 
			
		||||
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
 | 
			
		||||
  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
 | 
			
		||||
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
 | 
			
		||||
                          ImageSegmenter::Create(std::move(options)));
 | 
			
		||||
| 
						 | 
				
			
			@ -416,7 +412,6 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
 | 
			
		|||
  options->base_options.model_asset_path =
 | 
			
		||||
      JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata);
 | 
			
		||||
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
 | 
			
		||||
  options->activation = ImageSegmenterOptions::Activation::NONE;
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
 | 
			
		||||
                          ImageSegmenter::Create(std::move(options)));
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
 | 
			
		||||
| 
						 | 
				
			
			@ -442,7 +437,6 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) {
 | 
			
		|||
  options->base_options.model_asset_path =
 | 
			
		||||
      JoinPath("./", kTestDataDirectory, kSelfieSegmentation);
 | 
			
		||||
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
 | 
			
		||||
  options->activation = ImageSegmenterOptions::Activation::NONE;
 | 
			
		||||
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
 | 
			
		||||
                          ImageSegmenter::Create(std::move(options)));
 | 
			
		||||
| 
						 | 
				
			
			@ -470,7 +464,6 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationCategoryMask) {
 | 
			
		|||
  options->base_options.model_asset_path =
 | 
			
		||||
      JoinPath("./", kTestDataDirectory, kSelfieSegmentation);
 | 
			
		||||
  options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
 | 
			
		||||
  options->activation = ImageSegmenterOptions::Activation::NONE;
 | 
			
		||||
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
 | 
			
		||||
                          ImageSegmenter::Create(std::move(options)));
 | 
			
		||||
| 
						 | 
				
			
			@ -495,7 +488,6 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) {
 | 
			
		|||
  options->base_options.model_asset_path =
 | 
			
		||||
      JoinPath("./", kTestDataDirectory, kSelfieSegmentationLandscape);
 | 
			
		||||
  options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
 | 
			
		||||
  options->activation = ImageSegmenterOptions::Activation::NONE;
 | 
			
		||||
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
 | 
			
		||||
                          ImageSegmenter::Create(std::move(options)));
 | 
			
		||||
| 
						 | 
				
			
			@ -521,7 +513,6 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) {
 | 
			
		|||
  options->base_options.model_asset_path =
 | 
			
		||||
      JoinPath("./", kTestDataDirectory, kHairSegmentationWithMetadata);
 | 
			
		||||
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
 | 
			
		||||
  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
 | 
			
		||||
                          ImageSegmenter::Create(std::move(options)));
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -641,9 +641,6 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
			
		|||
            SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK);
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      // TODO: remove this once activation is handled in metadata and graph level.
 | 
			
		||||
      segmenterOptionsBuilder.setActivation(
 | 
			
		||||
          SegmenterOptionsProto.SegmenterOptions.Activation.SOFTMAX);
 | 
			
		||||
      taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder);
 | 
			
		||||
      return CalculatorOptions.newBuilder()
 | 
			
		||||
          .setExtension(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user