internal update
PiperOrigin-RevId: 533197055
This commit is contained in:
		
							parent
							
								
									a1755044ea
								
							
						
					
					
						commit
						c248525eeb
					
				| 
						 | 
					@ -291,8 +291,11 @@ class TensorsToSegmentationCalculator : public Node {
 | 
				
			||||||
  static constexpr Output<Image>::Multiple kConfidenceMaskOut{
 | 
					  static constexpr Output<Image>::Multiple kConfidenceMaskOut{
 | 
				
			||||||
      "CONFIDENCE_MASK"};
 | 
					      "CONFIDENCE_MASK"};
 | 
				
			||||||
  static constexpr Output<Image>::Optional kCategoryMaskOut{"CATEGORY_MASK"};
 | 
					  static constexpr Output<Image>::Optional kCategoryMaskOut{"CATEGORY_MASK"};
 | 
				
			||||||
 | 
					  static constexpr Output<std::vector<float>>::Optional kQualityScoresOut{
 | 
				
			||||||
 | 
					      "QUALITY_SCORES"};
 | 
				
			||||||
  MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut,
 | 
					  MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut,
 | 
				
			||||||
                          kConfidenceMaskOut, kCategoryMaskOut);
 | 
					                          kConfidenceMaskOut, kCategoryMaskOut,
 | 
				
			||||||
 | 
					                          kQualityScoresOut);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  static absl::Status UpdateContract(CalculatorContract* cc);
 | 
					  static absl::Status UpdateContract(CalculatorContract* cc);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -345,12 +348,33 @@ absl::Status TensorsToSegmentationCalculator::Open(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
absl::Status TensorsToSegmentationCalculator::Process(
 | 
					absl::Status TensorsToSegmentationCalculator::Process(
 | 
				
			||||||
    mediapipe::CalculatorContext* cc) {
 | 
					    mediapipe::CalculatorContext* cc) {
 | 
				
			||||||
  RET_CHECK_EQ(kTensorsIn(cc).Get().size(), 1)
 | 
					  const auto& input_tensors = kTensorsIn(cc).Get();
 | 
				
			||||||
      << "Expect a vector of single Tensor.";
 | 
					  if (input_tensors.size() != 1 && input_tensors.size() != 2) {
 | 
				
			||||||
  const auto& input_tensor = kTensorsIn(cc).Get()[0];
 | 
					    return absl::InvalidArgumentError(
 | 
				
			||||||
 | 
					        "Expect input tensor vector of size 1 or 2.");
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  const auto& input_tensor = *input_tensors.rbegin();
 | 
				
			||||||
  ASSIGN_OR_RETURN(const Shape input_shape,
 | 
					  ASSIGN_OR_RETURN(const Shape input_shape,
 | 
				
			||||||
                   GetImageLikeTensorShape(input_tensor));
 | 
					                   GetImageLikeTensorShape(input_tensor));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // TODO: should use tensor signature to get the correct output
 | 
				
			||||||
 | 
					  // tensor.
 | 
				
			||||||
 | 
					  if (input_tensors.size() == 2) {
 | 
				
			||||||
 | 
					    const auto& quality_tensor = input_tensors[0];
 | 
				
			||||||
 | 
					    const float* quality_score_buffer =
 | 
				
			||||||
 | 
					        quality_tensor.GetCpuReadView().buffer<float>();
 | 
				
			||||||
 | 
					    const std::vector<float> quality_scores(
 | 
				
			||||||
 | 
					        quality_score_buffer,
 | 
				
			||||||
 | 
					        quality_score_buffer +
 | 
				
			||||||
 | 
					            (quality_tensor.bytes() / quality_tensor.element_size()));
 | 
				
			||||||
 | 
					    kQualityScoresOut(cc).Send(quality_scores);
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    // If the input_tensors don't contain quality scores, send the default
 | 
				
			||||||
 | 
					    // quality scores as 1.
 | 
				
			||||||
 | 
					    const std::vector<float> quality_scores(input_shape.channels, 1.0f);
 | 
				
			||||||
 | 
					    kQualityScoresOut(cc).Send(quality_scores);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Category mask does not require activation function.
 | 
					  // Category mask does not require activation function.
 | 
				
			||||||
  if (options_.segmenter_options().output_type() ==
 | 
					  if (options_.segmenter_options().output_type() ==
 | 
				
			||||||
          SegmenterOptions::CONFIDENCE_MASK &&
 | 
					          SegmenterOptions::CONFIDENCE_MASK &&
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -46,6 +46,8 @@ constexpr char kImageOutStreamName[] = "image_out";
 | 
				
			||||||
constexpr char kImageTag[] = "IMAGE";
 | 
					constexpr char kImageTag[] = "IMAGE";
 | 
				
			||||||
constexpr char kNormRectStreamName[] = "norm_rect_in";
 | 
					constexpr char kNormRectStreamName[] = "norm_rect_in";
 | 
				
			||||||
constexpr char kNormRectTag[] = "NORM_RECT";
 | 
					constexpr char kNormRectTag[] = "NORM_RECT";
 | 
				
			||||||
 | 
					constexpr char kQualityScoresStreamName[] = "quality_scores";
 | 
				
			||||||
 | 
					constexpr char kQualityScoresTag[] = "QUALITY_SCORES";
 | 
				
			||||||
constexpr char kSubgraphTypeName[] =
 | 
					constexpr char kSubgraphTypeName[] =
 | 
				
			||||||
    "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
 | 
					    "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
 | 
				
			||||||
constexpr int kMicroSecondsPerMilliSecond = 1000;
 | 
					constexpr int kMicroSecondsPerMilliSecond = 1000;
 | 
				
			||||||
| 
						 | 
					@ -77,6 +79,8 @@ CalculatorGraphConfig CreateGraphConfig(
 | 
				
			||||||
    task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
 | 
					    task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
 | 
				
			||||||
        graph.Out(kCategoryMaskTag);
 | 
					        graph.Out(kCategoryMaskTag);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					  task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >>
 | 
				
			||||||
 | 
					      graph.Out(kQualityScoresTag);
 | 
				
			||||||
  task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
 | 
					  task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
 | 
				
			||||||
      graph.Out(kImageTag);
 | 
					      graph.Out(kImageTag);
 | 
				
			||||||
  if (enable_flow_limiting) {
 | 
					  if (enable_flow_limiting) {
 | 
				
			||||||
| 
						 | 
					@ -172,9 +176,13 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
 | 
				
			||||||
            category_mask =
 | 
					            category_mask =
 | 
				
			||||||
                status_or_packets.value()[kCategoryMaskStreamName].Get<Image>();
 | 
					                status_or_packets.value()[kCategoryMaskStreamName].Get<Image>();
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
 | 
					          const std::vector<float>& quality_scores =
 | 
				
			||||||
 | 
					              status_or_packets.value()[kQualityScoresStreamName]
 | 
				
			||||||
 | 
					                  .Get<std::vector<float>>();
 | 
				
			||||||
          Packet image_packet = status_or_packets.value()[kImageOutStreamName];
 | 
					          Packet image_packet = status_or_packets.value()[kImageOutStreamName];
 | 
				
			||||||
          result_callback(
 | 
					          result_callback(
 | 
				
			||||||
              {{confidence_masks, category_mask}}, image_packet.Get<Image>(),
 | 
					              {{confidence_masks, category_mask, quality_scores}},
 | 
				
			||||||
 | 
					              image_packet.Get<Image>(),
 | 
				
			||||||
              image_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
 | 
					              image_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
| 
						 | 
					@ -227,7 +235,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
 | 
				
			||||||
  if (output_category_mask_) {
 | 
					  if (output_category_mask_) {
 | 
				
			||||||
    category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
 | 
					    category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  return {{confidence_masks, category_mask}};
 | 
					  const std::vector<float>& quality_scores =
 | 
				
			||||||
 | 
					      output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
 | 
				
			||||||
 | 
					  return {{confidence_masks, category_mask, quality_scores}};
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
 | 
					absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
 | 
				
			||||||
| 
						 | 
					@ -260,7 +270,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
 | 
				
			||||||
  if (output_category_mask_) {
 | 
					  if (output_category_mask_) {
 | 
				
			||||||
    category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
 | 
					    category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  return {{confidence_masks, category_mask}};
 | 
					  const std::vector<float>& quality_scores =
 | 
				
			||||||
 | 
					      output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
 | 
				
			||||||
 | 
					  return {{confidence_masks, category_mask, quality_scores}};
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
absl::Status ImageSegmenter::SegmentAsync(
 | 
					absl::Status ImageSegmenter::SegmentAsync(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 | 
				
			||||||
limitations under the License.
 | 
					limitations under the License.
 | 
				
			||||||
==============================================================================*/
 | 
					==============================================================================*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <cstdint>
 | 
				
			||||||
#include <memory>
 | 
					#include <memory>
 | 
				
			||||||
#include <optional>
 | 
					#include <optional>
 | 
				
			||||||
#include <type_traits>
 | 
					#include <type_traits>
 | 
				
			||||||
| 
						 | 
					@ -81,6 +82,7 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU";
 | 
				
			||||||
constexpr char kNormRectTag[] = "NORM_RECT";
 | 
					constexpr char kNormRectTag[] = "NORM_RECT";
 | 
				
			||||||
constexpr char kTensorsTag[] = "TENSORS";
 | 
					constexpr char kTensorsTag[] = "TENSORS";
 | 
				
			||||||
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
 | 
					constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
 | 
				
			||||||
 | 
					constexpr char kQualityScoresTag[] = "QUALITY_SCORES";
 | 
				
			||||||
constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA";
 | 
					constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Struct holding the different output streams produced by the image segmenter
 | 
					// Struct holding the different output streams produced by the image segmenter
 | 
				
			||||||
| 
						 | 
					@ -90,6 +92,7 @@ struct ImageSegmenterOutputs {
 | 
				
			||||||
  std::optional<std::vector<Source<Image>>> confidence_masks;
 | 
					  std::optional<std::vector<Source<Image>>> confidence_masks;
 | 
				
			||||||
  std::optional<Source<Image>> category_mask;
 | 
					  std::optional<Source<Image>> category_mask;
 | 
				
			||||||
  // The same as the input image, mainly used for live stream mode.
 | 
					  // The same as the input image, mainly used for live stream mode.
 | 
				
			||||||
 | 
					  std::optional<Source<std::vector<float>>> quality_scores;
 | 
				
			||||||
  Source<Image> image;
 | 
					  Source<Image> image;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -191,19 +194,12 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
 | 
				
			||||||
        "Segmentation tflite models are assumed to have a single subgraph.",
 | 
					        "Segmentation tflite models are assumed to have a single subgraph.",
 | 
				
			||||||
        MediaPipeTasksStatus::kInvalidArgumentError);
 | 
					        MediaPipeTasksStatus::kInvalidArgumentError);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  const auto* primary_subgraph = (*model.subgraphs())[0];
 | 
					 | 
				
			||||||
  if (primary_subgraph->outputs()->size() != 1) {
 | 
					 | 
				
			||||||
    return CreateStatusWithPayload(
 | 
					 | 
				
			||||||
        absl::StatusCode::kInvalidArgument,
 | 
					 | 
				
			||||||
        "Segmentation tflite models are assumed to have a single output.",
 | 
					 | 
				
			||||||
        MediaPipeTasksStatus::kInvalidArgumentError);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  ASSIGN_OR_RETURN(
 | 
					  ASSIGN_OR_RETURN(
 | 
				
			||||||
      *options->mutable_label_items(),
 | 
					      *options->mutable_label_items(),
 | 
				
			||||||
      GetLabelItemsIfAny(*metadata_extractor,
 | 
					      GetLabelItemsIfAny(
 | 
				
			||||||
                         *metadata_extractor->GetOutputTensorMetadata()->Get(0),
 | 
					          *metadata_extractor,
 | 
				
			||||||
                         segmenter_option.display_names_locale()));
 | 
					          **metadata_extractor->GetOutputTensorMetadata()->crbegin(),
 | 
				
			||||||
 | 
					          segmenter_option.display_names_locale()));
 | 
				
			||||||
  return absl::OkStatus();
 | 
					  return absl::OkStatus();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -213,10 +209,16 @@ absl::StatusOr<const tflite::Tensor*> GetOutputTensor(
 | 
				
			||||||
  const tflite::Model& model = *model_resources.GetTfLiteModel();
 | 
					  const tflite::Model& model = *model_resources.GetTfLiteModel();
 | 
				
			||||||
  const auto* primary_subgraph = (*model.subgraphs())[0];
 | 
					  const auto* primary_subgraph = (*model.subgraphs())[0];
 | 
				
			||||||
  const auto* output_tensor =
 | 
					  const auto* output_tensor =
 | 
				
			||||||
      (*primary_subgraph->tensors())[(*primary_subgraph->outputs())[0]];
 | 
					      (*primary_subgraph->tensors())[*(*primary_subgraph->outputs()).rbegin()];
 | 
				
			||||||
  return output_tensor;
 | 
					  return output_tensor;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					uint32_t GetOutputTensorsSize(const core::ModelResources& model_resources) {
 | 
				
			||||||
 | 
					  const tflite::Model& model = *model_resources.GetTfLiteModel();
 | 
				
			||||||
 | 
					  const auto* primary_subgraph = (*model.subgraphs())[0];
 | 
				
			||||||
 | 
					  return primary_subgraph->outputs()->size();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Get the input tensor from the tflite model of given model resources.
 | 
					// Get the input tensor from the tflite model of given model resources.
 | 
				
			||||||
absl::StatusOr<const tflite::Tensor*> GetInputTensor(
 | 
					absl::StatusOr<const tflite::Tensor*> GetInputTensor(
 | 
				
			||||||
    const core::ModelResources& model_resources) {
 | 
					    const core::ModelResources& model_resources) {
 | 
				
			||||||
| 
						 | 
					@ -433,6 +435,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
 | 
				
			||||||
        *output_streams.category_mask >> graph[Output<Image>(kCategoryMaskTag)];
 | 
					        *output_streams.category_mask >> graph[Output<Image>(kCategoryMaskTag)];
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					    if (output_streams.quality_scores) {
 | 
				
			||||||
 | 
					      *output_streams.quality_scores >>
 | 
				
			||||||
 | 
					          graph[Output<std::vector<float>>::Optional(kQualityScoresTag)];
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
    output_streams.image >> graph[Output<Image>(kImageTag)];
 | 
					    output_streams.image >> graph[Output<Image>(kImageTag)];
 | 
				
			||||||
    return graph.GetConfig();
 | 
					    return graph.GetConfig();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
| 
						 | 
					@ -530,9 +536,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
 | 
				
			||||||
              tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));
 | 
					              tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
 | 
					      auto quality_scores =
 | 
				
			||||||
 | 
					          tensor_to_images[Output<std::vector<float>>(kQualityScoresTag)];
 | 
				
			||||||
      return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks,
 | 
					      return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks,
 | 
				
			||||||
                                   /*confidence_masks=*/std::nullopt,
 | 
					                                   /*confidence_masks=*/std::nullopt,
 | 
				
			||||||
                                   /*category_mask=*/std::nullopt,
 | 
					                                   /*category_mask=*/std::nullopt,
 | 
				
			||||||
 | 
					                                   /*quality_scores=*/quality_scores,
 | 
				
			||||||
                                   /*image=*/image_and_tensors.image};
 | 
					                                   /*image=*/image_and_tensors.image};
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
      std::optional<std::vector<Source<Image>>> confidence_masks;
 | 
					      std::optional<std::vector<Source<Image>>> confidence_masks;
 | 
				
			||||||
| 
						 | 
					@ -552,9 +561,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
 | 
				
			||||||
      if (output_category_mask_) {
 | 
					      if (output_category_mask_) {
 | 
				
			||||||
        category_mask = tensor_to_images[Output<Image>(kCategoryMaskTag)];
 | 
					        category_mask = tensor_to_images[Output<Image>(kCategoryMaskTag)];
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
 | 
					      auto quality_scores =
 | 
				
			||||||
 | 
					          tensor_to_images[Output<std::vector<float>>(kQualityScoresTag)];
 | 
				
			||||||
      return ImageSegmenterOutputs{/*segmented_masks=*/std::nullopt,
 | 
					      return ImageSegmenterOutputs{/*segmented_masks=*/std::nullopt,
 | 
				
			||||||
                                   /*confidence_masks=*/confidence_masks,
 | 
					                                   /*confidence_masks=*/confidence_masks,
 | 
				
			||||||
                                   /*category_mask=*/category_mask,
 | 
					                                   /*category_mask=*/category_mask,
 | 
				
			||||||
 | 
					                                   /*quality_scores=*/quality_scores,
 | 
				
			||||||
                                   /*image=*/image_and_tensors.image};
 | 
					                                   /*image=*/image_and_tensors.image};
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -33,6 +33,10 @@ struct ImageSegmenterResult {
 | 
				
			||||||
  // A category mask of uint8 image in GRAY8 format where each pixel represents
 | 
					  // A category mask of uint8 image in GRAY8 format where each pixel represents
 | 
				
			||||||
  // the class which the pixel in the original image was predicted to belong to.
 | 
					  // the class which the pixel in the original image was predicted to belong to.
 | 
				
			||||||
  std::optional<Image> category_mask;
 | 
					  std::optional<Image> category_mask;
 | 
				
			||||||
 | 
					  // The quality scores of the result masks, in the range of [0, 1]. Default to
 | 
				
			||||||
 | 
					  // `1` if the model doesn't output quality scores. Each element corresponds to
 | 
				
			||||||
 | 
					  // the score of the category in the model outputs.
 | 
				
			||||||
 | 
					  std::vector<float> quality_scores;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace image_segmenter
 | 
					}  // namespace image_segmenter
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -51,12 +51,14 @@ constexpr char kImageInStreamName[] = "image_in";
 | 
				
			||||||
constexpr char kImageOutStreamName[] = "image_out";
 | 
					constexpr char kImageOutStreamName[] = "image_out";
 | 
				
			||||||
constexpr char kRoiStreamName[] = "roi_in";
 | 
					constexpr char kRoiStreamName[] = "roi_in";
 | 
				
			||||||
constexpr char kNormRectStreamName[] = "norm_rect_in";
 | 
					constexpr char kNormRectStreamName[] = "norm_rect_in";
 | 
				
			||||||
 | 
					constexpr char kQualityScoresStreamName[] = "quality_scores";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"};
 | 
					constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"};
 | 
				
			||||||
constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"};
 | 
					constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"};
 | 
				
			||||||
constexpr absl::string_view kImageTag{"IMAGE"};
 | 
					constexpr absl::string_view kImageTag{"IMAGE"};
 | 
				
			||||||
constexpr absl::string_view kRoiTag{"ROI"};
 | 
					constexpr absl::string_view kRoiTag{"ROI"};
 | 
				
			||||||
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
 | 
					constexpr absl::string_view kNormRectTag{"NORM_RECT"};
 | 
				
			||||||
 | 
					constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
constexpr absl::string_view kSubgraphTypeName{
 | 
					constexpr absl::string_view kSubgraphTypeName{
 | 
				
			||||||
    "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};
 | 
					    "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};
 | 
				
			||||||
| 
						 | 
					@ -91,6 +93,8 @@ CalculatorGraphConfig CreateGraphConfig(
 | 
				
			||||||
    task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
 | 
					    task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
 | 
				
			||||||
        graph.Out(kCategoryMaskTag);
 | 
					        graph.Out(kCategoryMaskTag);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					  task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >>
 | 
				
			||||||
 | 
					      graph.Out(kQualityScoresTag);
 | 
				
			||||||
  task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
 | 
					  task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
 | 
				
			||||||
      graph.Out(kImageTag);
 | 
					      graph.Out(kImageTag);
 | 
				
			||||||
  graph.In(kImageTag) >> task_subgraph.In(kImageTag);
 | 
					  graph.In(kImageTag) >> task_subgraph.In(kImageTag);
 | 
				
			||||||
| 
						 | 
					@ -201,7 +205,9 @@ absl::StatusOr<ImageSegmenterResult> InteractiveSegmenter::Segment(
 | 
				
			||||||
  if (output_category_mask_) {
 | 
					  if (output_category_mask_) {
 | 
				
			||||||
    category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
 | 
					    category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  return {{confidence_masks, category_mask}};
 | 
					  const std::vector<float>& quality_scores =
 | 
				
			||||||
 | 
					      output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
 | 
				
			||||||
 | 
					  return {{confidence_masks, category_mask, quality_scores}};
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace interactive_segmenter
 | 
					}  // namespace interactive_segmenter
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -58,6 +58,7 @@ constexpr absl::string_view kAlphaTag{"ALPHA"};
 | 
				
			||||||
constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"};
 | 
					constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"};
 | 
				
			||||||
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
 | 
					constexpr absl::string_view kNormRectTag{"NORM_RECT"};
 | 
				
			||||||
constexpr absl::string_view kRoiTag{"ROI"};
 | 
					constexpr absl::string_view kRoiTag{"ROI"};
 | 
				
			||||||
 | 
					constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Updates the graph to return `roi` stream which has same dimension as
 | 
					// Updates the graph to return `roi` stream which has same dimension as
 | 
				
			||||||
// `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is
 | 
					// `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is
 | 
				
			||||||
| 
						 | 
					@ -200,6 +201,8 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph {
 | 
				
			||||||
            graph[Output<Image>(kCategoryMaskTag)];
 | 
					            graph[Output<Image>(kCategoryMaskTag)];
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					    image_segmenter.Out(kQualityScoresTag) >>
 | 
				
			||||||
 | 
					        graph[Output<std::vector<float>>::Optional(kQualityScoresTag)];
 | 
				
			||||||
    image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)];
 | 
					    image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return graph.GetConfig();
 | 
					    return graph.GetConfig();
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -115,6 +115,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
        segmenterOptions.outputCategoryMask()
 | 
					        segmenterOptions.outputCategoryMask()
 | 
				
			||||||
            ? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask")
 | 
					            ? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask")
 | 
				
			||||||
            : -1;
 | 
					            : -1;
 | 
				
			||||||
 | 
					    final int qualityScoresOutStreamIndex =
 | 
				
			||||||
 | 
					        getStreamIndex.apply(outputStreams, "QUALITY_SCORES:quality_scores");
 | 
				
			||||||
    final int imageOutStreamIndex = getStreamIndex.apply(outputStreams, "IMAGE:image_out");
 | 
					    final int imageOutStreamIndex = getStreamIndex.apply(outputStreams, "IMAGE:image_out");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // TODO: Consolidate OutputHandler and TaskRunner.
 | 
					    // TODO: Consolidate OutputHandler and TaskRunner.
 | 
				
			||||||
| 
						 | 
					@ -128,6 +130,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
              return ImageSegmenterResult.create(
 | 
					              return ImageSegmenterResult.create(
 | 
				
			||||||
                  Optional.empty(),
 | 
					                  Optional.empty(),
 | 
				
			||||||
                  Optional.empty(),
 | 
					                  Optional.empty(),
 | 
				
			||||||
 | 
					                  new ArrayList<>(),
 | 
				
			||||||
                  packets.get(imageOutStreamIndex).getTimestamp());
 | 
					                  packets.get(imageOutStreamIndex).getTimestamp());
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            boolean copyImage = !segmenterOptions.resultListener().isPresent();
 | 
					            boolean copyImage = !segmenterOptions.resultListener().isPresent();
 | 
				
			||||||
| 
						 | 
					@ -182,9 +185,16 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
                  new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA);
 | 
					                  new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA);
 | 
				
			||||||
              categoryMask = Optional.of(builder.build());
 | 
					              categoryMask = Optional.of(builder.build());
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					            float[] qualityScores =
 | 
				
			||||||
 | 
					                PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex));
 | 
				
			||||||
 | 
					            List<Float> qualityScoresList = new ArrayList<>(qualityScores.length);
 | 
				
			||||||
 | 
					            for (float score : qualityScores) {
 | 
				
			||||||
 | 
					              qualityScoresList.add(score);
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
            return ImageSegmenterResult.create(
 | 
					            return ImageSegmenterResult.create(
 | 
				
			||||||
                confidenceMasks,
 | 
					                confidenceMasks,
 | 
				
			||||||
                categoryMask,
 | 
					                categoryMask,
 | 
				
			||||||
 | 
					                qualityScoresList,
 | 
				
			||||||
                BaseVisionTaskApi.generateResultTimestampMs(
 | 
					                BaseVisionTaskApi.generateResultTimestampMs(
 | 
				
			||||||
                    segmenterOptions.runningMode(), packets.get(imageOutStreamIndex)));
 | 
					                    segmenterOptions.runningMode(), packets.get(imageOutStreamIndex)));
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
| 
						 | 
					@ -592,8 +602,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
      public abstract Builder setOutputCategoryMask(boolean value);
 | 
					      public abstract Builder setOutputCategoryMask(boolean value);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      /**
 | 
					      /**
 | 
				
			||||||
       * Sets an optional {@link ResultListener} to receive the segmentation results when the graph
 | 
					       * /** Sets an optional {@link ResultListener} to receive the segmentation results when the
 | 
				
			||||||
       * pipeline is done processing an image.
 | 
					       * graph pipeline is done processing an image.
 | 
				
			||||||
       */
 | 
					       */
 | 
				
			||||||
      public abstract Builder setResultListener(
 | 
					      public abstract Builder setResultListener(
 | 
				
			||||||
          ResultListener<ImageSegmenterResult, MPImage> value);
 | 
					          ResultListener<ImageSegmenterResult, MPImage> value);
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -34,19 +34,30 @@ public abstract class ImageSegmenterResult implements TaskResult {
 | 
				
			||||||
   * @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
 | 
					   * @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
 | 
				
			||||||
   *     category mask, where each pixel represents the class which the pixel in the original image
 | 
					   *     category mask, where each pixel represents the class which the pixel in the original image
 | 
				
			||||||
   *     was predicted to belong to.
 | 
					   *     was predicted to belong to.
 | 
				
			||||||
 | 
					   * @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Default to
 | 
				
			||||||
 | 
					   *     `1` if the model doesn't output quality scores. Each element corresponds to the score of
 | 
				
			||||||
 | 
					   *     the category in the model outputs.
 | 
				
			||||||
   * @param timestampMs a timestamp for this result.
 | 
					   * @param timestampMs a timestamp for this result.
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  // TODO: consolidate output formats across platforms.
 | 
					  // TODO: consolidate output formats across platforms.
 | 
				
			||||||
  public static ImageSegmenterResult create(
 | 
					  public static ImageSegmenterResult create(
 | 
				
			||||||
      Optional<List<MPImage>> confidenceMasks, Optional<MPImage> categoryMask, long timestampMs) {
 | 
					      Optional<List<MPImage>> confidenceMasks,
 | 
				
			||||||
 | 
					      Optional<MPImage> categoryMask,
 | 
				
			||||||
 | 
					      List<Float> qualityScores,
 | 
				
			||||||
 | 
					      long timestampMs) {
 | 
				
			||||||
    return new AutoValue_ImageSegmenterResult(
 | 
					    return new AutoValue_ImageSegmenterResult(
 | 
				
			||||||
        confidenceMasks.map(Collections::unmodifiableList), categoryMask, timestampMs);
 | 
					        confidenceMasks.map(Collections::unmodifiableList),
 | 
				
			||||||
 | 
					        categoryMask,
 | 
				
			||||||
 | 
					        Collections.unmodifiableList(qualityScores),
 | 
				
			||||||
 | 
					        timestampMs);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  public abstract Optional<List<MPImage>> confidenceMasks();
 | 
					  public abstract Optional<List<MPImage>> confidenceMasks();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  public abstract Optional<MPImage> categoryMask();
 | 
					  public abstract Optional<MPImage> categoryMask();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  public abstract List<Float> qualityScores();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @Override
 | 
					  @Override
 | 
				
			||||||
  public abstract long timestampMs();
 | 
					  public abstract long timestampMs();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -127,6 +127,10 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
      outputStreams.add("CATEGORY_MASK:category_mask");
 | 
					      outputStreams.add("CATEGORY_MASK:category_mask");
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
 | 
					    final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    outputStreams.add("QUALITY_SCORES:quality_scores");
 | 
				
			||||||
 | 
					    final int qualityScoresOutStreamIndex = outputStreams.size() - 1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    outputStreams.add("IMAGE:image_out");
 | 
					    outputStreams.add("IMAGE:image_out");
 | 
				
			||||||
    // TODO: add test for stream indices.
 | 
					    // TODO: add test for stream indices.
 | 
				
			||||||
    final int imageOutStreamIndex = outputStreams.size() - 1;
 | 
					    final int imageOutStreamIndex = outputStreams.size() - 1;
 | 
				
			||||||
| 
						 | 
					@ -142,6 +146,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
              return ImageSegmenterResult.create(
 | 
					              return ImageSegmenterResult.create(
 | 
				
			||||||
                  Optional.empty(),
 | 
					                  Optional.empty(),
 | 
				
			||||||
                  Optional.empty(),
 | 
					                  Optional.empty(),
 | 
				
			||||||
 | 
					                  new ArrayList<>(),
 | 
				
			||||||
                  packets.get(imageOutStreamIndex).getTimestamp());
 | 
					                  packets.get(imageOutStreamIndex).getTimestamp());
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            // If resultListener is not provided, the resulted MPImage is deep copied from
 | 
					            // If resultListener is not provided, the resulted MPImage is deep copied from
 | 
				
			||||||
| 
						 | 
					@ -199,9 +204,17 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
              categoryMask = Optional.of(builder.build());
 | 
					              categoryMask = Optional.of(builder.build());
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            float[] qualityScores =
 | 
				
			||||||
 | 
					                PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex));
 | 
				
			||||||
 | 
					            List<Float> qualityScoresList = new ArrayList<>(qualityScores.length);
 | 
				
			||||||
 | 
					            for (float score : qualityScores) {
 | 
				
			||||||
 | 
					              qualityScoresList.add(score);
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            return ImageSegmenterResult.create(
 | 
					            return ImageSegmenterResult.create(
 | 
				
			||||||
                confidenceMasks,
 | 
					                confidenceMasks,
 | 
				
			||||||
                categoryMask,
 | 
					                categoryMask,
 | 
				
			||||||
 | 
					                qualityScoresList,
 | 
				
			||||||
                BaseVisionTaskApi.generateResultTimestampMs(
 | 
					                BaseVisionTaskApi.generateResultTimestampMs(
 | 
				
			||||||
                    RunningMode.IMAGE, packets.get(imageOutStreamIndex)));
 | 
					                    RunningMode.IMAGE, packets.get(imageOutStreamIndex)));
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user