From c248525eeb17da110346ec76a9865de5a20d4c4d Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 18 May 2023 11:37:13 -0700 Subject: [PATCH] internal update PiperOrigin-RevId: 533197055 --- .../tensors_to_segmentation_calculator.cc | 32 ++++++++++++++--- .../vision/image_segmenter/image_segmenter.cc | 18 ++++++++-- .../image_segmenter/image_segmenter_graph.cc | 36 ++++++++++++------- .../image_segmenter/image_segmenter_result.h | 4 +++ .../interactive_segmenter.cc | 8 ++++- .../interactive_segmenter_graph.cc | 3 ++ .../vision/imagesegmenter/ImageSegmenter.java | 14 ++++++-- .../imagesegmenter/ImageSegmenterResult.java | 15 ++++++-- .../InteractiveSegmenter.java | 13 +++++++ 9 files changed, 119 insertions(+), 24 deletions(-) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc index 660dc59b7..f77855587 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc @@ -291,8 +291,11 @@ class TensorsToSegmentationCalculator : public Node { static constexpr Output::Multiple kConfidenceMaskOut{ "CONFIDENCE_MASK"}; static constexpr Output::Optional kCategoryMaskOut{"CATEGORY_MASK"}; + static constexpr Output>::Optional kQualityScoresOut{ + "QUALITY_SCORES"}; MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut, - kConfidenceMaskOut, kCategoryMaskOut); + kConfidenceMaskOut, kCategoryMaskOut, + kQualityScoresOut); static absl::Status UpdateContract(CalculatorContract* cc); @@ -345,12 +348,33 @@ absl::Status TensorsToSegmentationCalculator::Open( absl::Status TensorsToSegmentationCalculator::Process( mediapipe::CalculatorContext* cc) { - RET_CHECK_EQ(kTensorsIn(cc).Get().size(), 1) - << "Expect a vector of single Tensor."; - const auto& input_tensor = kTensorsIn(cc).Get()[0]; + const auto& input_tensors = kTensorsIn(cc).Get(); + if (input_tensors.size() != 1 && input_tensors.size() != 2) { + return absl::InvalidArgumentError( + "Expect input tensor vector of size 1 or 2."); + } + const auto& input_tensor = *input_tensors.rbegin(); ASSIGN_OR_RETURN(const Shape input_shape, GetImageLikeTensorShape(input_tensor)); + // TODO: should use tensor signature to get the correct output + // tensor. + if (input_tensors.size() == 2) { + const auto& quality_tensor = input_tensors[0]; + const float* quality_score_buffer = + quality_tensor.GetCpuReadView().buffer(); + const std::vector quality_scores( + quality_score_buffer, + quality_score_buffer + + (quality_tensor.bytes() / quality_tensor.element_size())); + kQualityScoresOut(cc).Send(quality_scores); + } else { + // If the input_tensors don't contain quality scores, send the default + // quality scores as 1. + const std::vector quality_scores(input_shape.channels, 1.0f); + kQualityScoresOut(cc).Send(quality_scores); + } + // Category mask does not require activation function. if (options_.segmenter_options().output_type() == SegmenterOptions::CONFIDENCE_MASK && diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index a67843258..99faa1064 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -46,6 +46,8 @@ constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageTag[] = "IMAGE"; constexpr char kNormRectStreamName[] = "norm_rect_in"; constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kQualityScoresStreamName[] = "quality_scores"; +constexpr char kQualityScoresTag[] = "QUALITY_SCORES"; constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; @@ -77,6 +79,8 @@ CalculatorGraphConfig CreateGraphConfig( task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >> graph.Out(kCategoryMaskTag); } + task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >> + graph.Out(kQualityScoresTag); task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); if (enable_flow_limiting) { @@ -172,9 +176,13 @@ absl::StatusOr> ImageSegmenter::Create( category_mask = status_or_packets.value()[kCategoryMaskStreamName].Get(); } + const std::vector& quality_scores = + status_or_packets.value()[kQualityScoresStreamName] + .Get>(); Packet image_packet = status_or_packets.value()[kImageOutStreamName]; result_callback( - {{confidence_masks, category_mask}}, image_packet.Get(), + {{confidence_masks, category_mask, quality_scores}}, + image_packet.Get(), image_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); }; } @@ -227,7 +235,9 @@ absl::StatusOr ImageSegmenter::Segment( if (output_category_mask_) { category_mask = output_packets[kCategoryMaskStreamName].Get(); } - return {{confidence_masks, category_mask}}; + const std::vector& quality_scores = + output_packets[kQualityScoresStreamName].Get>(); + return {{confidence_masks, category_mask, quality_scores}}; } absl::StatusOr ImageSegmenter::SegmentForVideo( @@ -260,7 +270,9 @@ absl::StatusOr ImageSegmenter::SegmentForVideo( if (output_category_mask_) { category_mask = output_packets[kCategoryMaskStreamName].Get(); } - return {{confidence_masks, category_mask}}; + const std::vector& quality_scores = + output_packets[kQualityScoresStreamName].Get>(); + return {{confidence_masks, category_mask, quality_scores}}; } absl::Status ImageSegmenter::SegmentAsync( diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index 6ecfa3685..0ae47ffd1 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include @@ -81,6 +82,7 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU"; constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kTensorsTag[] = "TENSORS"; constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; +constexpr char kQualityScoresTag[] = "QUALITY_SCORES"; constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA"; // Struct holding the different output streams produced by the image segmenter @@ -90,6 +92,7 @@ struct ImageSegmenterOutputs { std::optional>> confidence_masks; std::optional> category_mask; // The same as the input image, mainly used for live stream mode. + std::optional>> quality_scores; Source image; }; @@ -191,19 +194,12 @@ absl::Status ConfigureTensorsToSegmentationCalculator( "Segmentation tflite models are assumed to have a single subgraph.", MediaPipeTasksStatus::kInvalidArgumentError); } - const auto* primary_subgraph = (*model.subgraphs())[0]; - if (primary_subgraph->outputs()->size() != 1) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "Segmentation tflite models are assumed to have a single output.", - MediaPipeTasksStatus::kInvalidArgumentError); - } - ASSIGN_OR_RETURN( *options->mutable_label_items(), - GetLabelItemsIfAny(*metadata_extractor, - *metadata_extractor->GetOutputTensorMetadata()->Get(0), - segmenter_option.display_names_locale())); + GetLabelItemsIfAny( + *metadata_extractor, + **metadata_extractor->GetOutputTensorMetadata()->crbegin(), + segmenter_option.display_names_locale())); return absl::OkStatus(); } @@ -213,10 +209,16 @@ absl::StatusOr GetOutputTensor( const tflite::Model& model = *model_resources.GetTfLiteModel(); const auto* primary_subgraph = (*model.subgraphs())[0]; const auto* output_tensor = - (*primary_subgraph->tensors())[(*primary_subgraph->outputs())[0]]; + (*primary_subgraph->tensors())[*(*primary_subgraph->outputs()).rbegin()]; return output_tensor; } +uint32_t GetOutputTensorsSize(const core::ModelResources& model_resources) { + const tflite::Model& model = *model_resources.GetTfLiteModel(); + const auto* primary_subgraph = (*model.subgraphs())[0]; + return primary_subgraph->outputs()->size(); +} + // Get the input tensor from the tflite model of given model resources. absl::StatusOr GetInputTensor( const core::ModelResources& model_resources) { @@ -433,6 +435,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { *output_streams.category_mask >> graph[Output(kCategoryMaskTag)]; } } + if (output_streams.quality_scores) { + *output_streams.quality_scores >> + graph[Output>::Optional(kQualityScoresTag)]; + } output_streams.image >> graph[Output(kImageTag)]; return graph.GetConfig(); } @@ -530,9 +536,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { tensor_to_images[Output::Multiple(kSegmentationTag)][i])); } } + auto quality_scores = + tensor_to_images[Output>(kQualityScoresTag)]; return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks, /*confidence_masks=*/std::nullopt, /*category_mask=*/std::nullopt, + /*quality_scores=*/quality_scores, /*image=*/image_and_tensors.image}; } else { std::optional>> confidence_masks; @@ -552,9 +561,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { if (output_category_mask_) { category_mask = tensor_to_images[Output(kCategoryMaskTag)]; } + auto quality_scores = + tensor_to_images[Output>(kQualityScoresTag)]; return ImageSegmenterOutputs{/*segmented_masks=*/std::nullopt, /*confidence_masks=*/confidence_masks, /*category_mask=*/category_mask, + /*quality_scores=*/quality_scores, /*image=*/image_and_tensors.image}; } } diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h index 1e7968ebd..a203718f4 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h @@ -33,6 +33,10 @@ struct ImageSegmenterResult { // A category mask of uint8 image in GRAY8 format where each pixel represents // the class which the pixel in the original image was predicted to belong to. std::optional category_mask; + // The quality scores of the result masks, in the range of [0, 1]. Default to + // `1` if the model doesn't output quality scores. Each element corresponds to + // the score of the category in the model outputs. + std::vector quality_scores; }; } // namespace image_segmenter diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc index c0d89c87d..38bbf3baf 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc @@ -51,12 +51,14 @@ constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageOutStreamName[] = "image_out"; constexpr char kRoiStreamName[] = "roi_in"; constexpr char kNormRectStreamName[] = "norm_rect_in"; +constexpr char kQualityScoresStreamName[] = "quality_scores"; constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"}; constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"}; constexpr absl::string_view kImageTag{"IMAGE"}; constexpr absl::string_view kRoiTag{"ROI"}; constexpr absl::string_view kNormRectTag{"NORM_RECT"}; +constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"}; constexpr absl::string_view kSubgraphTypeName{ "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"}; @@ -91,6 +93,8 @@ CalculatorGraphConfig CreateGraphConfig( task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >> graph.Out(kCategoryMaskTag); } + task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >> + graph.Out(kQualityScoresTag); task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); graph.In(kImageTag) >> task_subgraph.In(kImageTag); @@ -201,7 +205,9 @@ absl::StatusOr InteractiveSegmenter::Segment( if (output_category_mask_) { category_mask = output_packets[kCategoryMaskStreamName].Get(); } - return {{confidence_masks, category_mask}}; + const std::vector& quality_scores = + output_packets[kQualityScoresStreamName].Get>(); + return {{confidence_masks, category_mask, quality_scores}}; } } // namespace interactive_segmenter diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc index a765997d8..5bb3e8ece 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc @@ -58,6 +58,7 @@ constexpr absl::string_view kAlphaTag{"ALPHA"}; constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"}; constexpr absl::string_view kNormRectTag{"NORM_RECT"}; constexpr absl::string_view kRoiTag{"ROI"}; +constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"}; // Updates the graph to return `roi` stream which has same dimension as // `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is @@ -200,6 +201,8 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph { graph[Output(kCategoryMaskTag)]; } } + image_segmenter.Out(kQualityScoresTag) >> + graph[Output>::Optional(kQualityScoresTag)]; image_segmenter.Out(kImageTag) >> graph[Output(kImageTag)]; return graph.GetConfig(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java index 3d6df3022..f977c0159 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java @@ -115,6 +115,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi { segmenterOptions.outputCategoryMask() ? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask") : -1; + final int qualityScoresOutStreamIndex = + getStreamIndex.apply(outputStreams, "QUALITY_SCORES:quality_scores"); final int imageOutStreamIndex = getStreamIndex.apply(outputStreams, "IMAGE:image_out"); // TODO: Consolidate OutputHandler and TaskRunner. @@ -128,6 +130,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi { return ImageSegmenterResult.create( Optional.empty(), Optional.empty(), + new ArrayList<>(), packets.get(imageOutStreamIndex).getTimestamp()); } boolean copyImage = !segmenterOptions.resultListener().isPresent(); @@ -182,9 +185,16 @@ public final class ImageSegmenter extends BaseVisionTaskApi { new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA); categoryMask = Optional.of(builder.build()); } + float[] qualityScores = + PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex)); + List qualityScoresList = new ArrayList<>(qualityScores.length); + for (float score : qualityScores) { + qualityScoresList.add(score); + } return ImageSegmenterResult.create( confidenceMasks, categoryMask, + qualityScoresList, BaseVisionTaskApi.generateResultTimestampMs( segmenterOptions.runningMode(), packets.get(imageOutStreamIndex))); } @@ -592,8 +602,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi { public abstract Builder setOutputCategoryMask(boolean value); /** - * Sets an optional {@link ResultListener} to receive the segmentation results when the graph - * pipeline is done processing an image. + * /** Sets an optional {@link ResultListener} to receive the segmentation results when the + * graph pipeline is done processing an image. */ public abstract Builder setResultListener( ResultListener value); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java index cbc5211cc..e4ac85c2f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java @@ -34,19 +34,30 @@ public abstract class ImageSegmenterResult implements TaskResult { * @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a * category mask, where each pixel represents the class which the pixel in the original image * was predicted to belong to. + * @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Default to + * `1` if the model doesn't output quality scores. Each element corresponds to the score of + * the category in the model outputs. * @param timestampMs a timestamp for this result. */ // TODO: consolidate output formats across platforms. public static ImageSegmenterResult create( - Optional> confidenceMasks, Optional categoryMask, long timestampMs) { + Optional> confidenceMasks, + Optional categoryMask, + List qualityScores, + long timestampMs) { return new AutoValue_ImageSegmenterResult( - confidenceMasks.map(Collections::unmodifiableList), categoryMask, timestampMs); + confidenceMasks.map(Collections::unmodifiableList), + categoryMask, + Collections.unmodifiableList(qualityScores), + timestampMs); } public abstract Optional> confidenceMasks(); public abstract Optional categoryMask(); + public abstract List qualityScores(); + @Override public abstract long timestampMs(); } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java index e9ff1f2b5..fe0ce0c3f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java @@ -127,6 +127,10 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { outputStreams.add("CATEGORY_MASK:category_mask"); } final int categoryMaskOutStreamIndex = outputStreams.size() - 1; + + outputStreams.add("QUALITY_SCORES:quality_scores"); + final int qualityScoresOutStreamIndex = outputStreams.size() - 1; + outputStreams.add("IMAGE:image_out"); // TODO: add test for stream indices. final int imageOutStreamIndex = outputStreams.size() - 1; @@ -142,6 +146,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { return ImageSegmenterResult.create( Optional.empty(), Optional.empty(), + new ArrayList<>(), packets.get(imageOutStreamIndex).getTimestamp()); } // If resultListener is not provided, the resulted MPImage is deep copied from @@ -199,9 +204,17 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { categoryMask = Optional.of(builder.build()); } + float[] qualityScores = + PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex)); + List qualityScoresList = new ArrayList<>(qualityScores.length); + for (float score : qualityScores) { + qualityScoresList.add(score); + } + return ImageSegmenterResult.create( confidenceMasks, categoryMask, + qualityScoresList, BaseVisionTaskApi.generateResultTimestampMs( RunningMode.IMAGE, packets.get(imageOutStreamIndex))); }