ImageSegmenterGraph now sets the activation type from the model metadata, and the activation config is removed from the C++ ImageSegmenterOptions.

PiperOrigin-RevId: 516893115
This commit is contained in:
MediaPipe Team 2023-03-15 12:08:28 -07:00 committed by Copybara-Service
parent a323825134
commit 59962bed27
6 changed files with 46 additions and 38 deletions

View File

@ -80,6 +80,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"//mediapipe/tasks/metadata:image_segmenter_metadata_schema_cc",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"//mediapipe/util:label_map_cc_proto",
"//mediapipe/util:label_map_util",

View File

@ -101,20 +101,6 @@ ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) {
SegmenterOptions::CONFIDENCE_MASK);
break;
}
switch (options->activation) {
case ImageSegmenterOptions::Activation::NONE:
options_proto->mutable_segmenter_options()->set_activation(
SegmenterOptions::NONE);
break;
case ImageSegmenterOptions::Activation::SIGMOID:
options_proto->mutable_segmenter_options()->set_activation(
SegmenterOptions::SIGMOID);
break;
case ImageSegmenterOptions::Activation::SOFTMAX:
options_proto->mutable_segmenter_options()->set_activation(
SegmenterOptions::SOFTMAX);
break;
}
return options_proto;
}

View File

@ -64,15 +64,6 @@ struct ImageSegmenterOptions {
OutputType output_type = OutputType::CATEGORY_MASK;
// The activation function used on the raw segmentation model output.
enum Activation {
NONE = 0, // No activation function is used.
SIGMOID = 1, // Assumes 1-channel input tensor.
SOFTMAX = 2, // Assumes multi-channel input tensor.
};
Activation activation = Activation::NONE;
// The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM.

View File

@ -40,6 +40,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
#include "mediapipe/tasks/metadata/image_segmenter_metadata_schema_generated.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
#include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/label_map_util.h"
@ -74,6 +75,7 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA";
// Struct holding the different output streams produced by the image segmenter
// subgraph.
@ -130,7 +132,49 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
const ImageSegmenterGraphOptions& segmenter_option,
const core::ModelResources& model_resources,
TensorsToSegmentationCalculatorOptions* options) {
*options->mutable_segmenter_options() = segmenter_option.segmenter_options();
// Copy the output type from the graph options and default the activation
// function to NONE.
options->mutable_segmenter_options()->set_output_type(
segmenter_option.segmenter_options().output_type());
options->mutable_segmenter_options()->set_activation(SegmenterOptions::NONE);
// Find the custom metadata of ImageSegmenterOptions type in model metadata.
const auto* metadata_extractor = model_resources.GetMetadataExtractor();
bool found_activation_in_metadata = false;
if (metadata_extractor->GetCustomMetadataList() != nullptr &&
metadata_extractor->GetCustomMetadataList()->size() > 0) {
for (const auto& custom_metadata :
*metadata_extractor->GetCustomMetadataList()) {
if (custom_metadata->name()->str() == kSegmentationMetadataName) {
found_activation_in_metadata = true;
auto activation_fb =
GetImageSegmenterOptions(custom_metadata->data()->data())
->activation();
switch (activation_fb) {
case Activation_NONE:
options->mutable_segmenter_options()->set_activation(
SegmenterOptions::NONE);
break;
case Activation_SIGMOID:
options->mutable_segmenter_options()->set_activation(
SegmenterOptions::SIGMOID);
break;
case Activation_SOFTMAX:
options->mutable_segmenter_options()->set_activation(
SegmenterOptions::SOFTMAX);
break;
default:
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"Invalid activation type found in CustomMetadata of "
"ImageSegmenterOptions type.");
}
}
}
}
if (!found_activation_in_metadata) {
LOG(WARNING)
<< "No activation type is found in model metadata. Use NONE for "
"ImageSegmenterGraph.";
}
const tflite::Model& model = *model_resources.GetTfLiteModel();
if (model.subgraphs()->size() != 1) {
return CreateStatusWithPayload(
@ -146,8 +190,6 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
MediaPipeTasksStatus::kInvalidArgumentError);
}
const ModelMetadataExtractor* metadata_extractor =
model_resources.GetMetadataExtractor();
ASSIGN_OR_RETURN(
*options->mutable_label_items(),
GetLabelItemsIfAny(*metadata_extractor,

View File

@ -304,7 +304,6 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
@ -333,7 +332,6 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
@ -364,7 +362,6 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
@ -388,7 +385,6 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
@ -416,7 +412,6 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
options->activation = ImageSegmenterOptions::Activation::NONE;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
@ -442,7 +437,6 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) {
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfieSegmentation);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
options->activation = ImageSegmenterOptions::Activation::NONE;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
@ -470,7 +464,6 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationCategoryMask) {
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfieSegmentation);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
options->activation = ImageSegmenterOptions::Activation::NONE;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
@ -495,7 +488,6 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) {
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfieSegmentationLandscape);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
options->activation = ImageSegmenterOptions::Activation::NONE;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
@ -521,7 +513,6 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) {
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kHairSegmentationWithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));

View File

@ -641,9 +641,6 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK);
}
// TODO: remove this once activation is handled at the metadata and graph level.
segmenterOptionsBuilder.setActivation(
SegmenterOptionsProto.SegmenterOptions.Activation.SOFTMAX);
taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder);
return CalculatorOptions.newBuilder()
.setExtension(