Image segmenter output both confidence masks and category mask optionally.
PiperOrigin-RevId: 522227345
This commit is contained in:
		
							parent
							
								
									7fe87936e5
								
							
						
					
					
						commit
						d5def9e24d
					
				|  | @ -32,6 +32,7 @@ limitations under the License. | ||||||
| #include "mediapipe/framework/formats/image.h" | #include "mediapipe/framework/formats/image.h" | ||||||
| #include "mediapipe/framework/formats/image_frame_opencv.h" | #include "mediapipe/framework/formats/image_frame_opencv.h" | ||||||
| #include "mediapipe/framework/formats/tensor.h" | #include "mediapipe/framework/formats/tensor.h" | ||||||
|  | #include "mediapipe/framework/port/canonical_errors.h" | ||||||
| #include "mediapipe/framework/port/opencv_core_inc.h" | #include "mediapipe/framework/port/opencv_core_inc.h" | ||||||
| #include "mediapipe/framework/port/opencv_imgproc_inc.h" | #include "mediapipe/framework/port/opencv_imgproc_inc.h" | ||||||
| #include "mediapipe/framework/port/status_macros.h" | #include "mediapipe/framework/port/status_macros.h" | ||||||
|  | @ -210,8 +211,9 @@ std::vector<Image> ProcessForConfidenceMaskCpu(const Shape& input_shape, | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| 
 | 
 | ||||||
| // Converts Tensors from a vector of Tensor to Segmentation masks. The
 | // Converts Tensors from a vector of Tensor to Segmentation masks. The
 | ||||||
| // calculator always output confidence masks, and an optional category mask if
 | // calculator can output optional confidence masks if CONFIDENCE_MASK is
 | ||||||
| // CATEGORY_MASK is connected.
 | // connected, and an optional category mask if CATEGORY_MASK is connected. At
 | ||||||
|  | // least one of CONFIDENCE_MASK and CATEGORY_MASK must be connected.
 | ||||||
| //
 | //
 | ||||||
| // Performs optional resizing to OUTPUT_SIZE dimension if provided,
 | // Performs optional resizing to OUTPUT_SIZE dimension if provided,
 | ||||||
| // otherwise the segmented masks is the same size as input tensor.
 | // otherwise the segmented masks is the same size as input tensor.
 | ||||||
|  | @ -296,6 +298,13 @@ absl::Status TensorsToSegmentationCalculator::Open( | ||||||
|                  SegmenterOptions::UNSPECIFIED) |                  SegmenterOptions::UNSPECIFIED) | ||||||
|         << "Must specify output_type as one of " |         << "Must specify output_type as one of " | ||||||
|            "[CONFIDENCE_MASK|CATEGORY_MASK]."; |            "[CONFIDENCE_MASK|CATEGORY_MASK]."; | ||||||
|  |   } else { | ||||||
|  |     if (!cc->Outputs().HasTag("CONFIDENCE_MASK") && | ||||||
|  |         !cc->Outputs().HasTag("CATEGORY_MASK")) { | ||||||
|  |       return absl::InvalidArgumentError( | ||||||
|  |           "At least one of CONFIDENCE_MASK and CATEGORY_MASK must be " | ||||||
|  |           "connected."); | ||||||
|  |     } | ||||||
|   } |   } | ||||||
| #ifdef __EMSCRIPTEN__ | #ifdef __EMSCRIPTEN__ | ||||||
|   MP_RETURN_IF_ERROR(postprocessor_.Initialize(cc, options_)); |   MP_RETURN_IF_ERROR(postprocessor_.Initialize(cc, options_)); | ||||||
|  | @ -366,8 +375,9 @@ absl::Status TensorsToSegmentationCalculator::Process( | ||||||
|     return absl::OkStatus(); |     return absl::OkStatus(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   std::vector<Image> confidence_masks = |   if (cc->Outputs().HasTag("CONFIDENCE_MASK")) { | ||||||
|       ProcessForConfidenceMaskCpu(input_shape, |     std::vector<Image> confidence_masks = ProcessForConfidenceMaskCpu( | ||||||
|  |         input_shape, | ||||||
|         {/* height= */ output_height, |         {/* height= */ output_height, | ||||||
|          /* width= */ output_width, |          /* width= */ output_width, | ||||||
|          /* channels= */ input_shape.channels}, |          /* channels= */ input_shape.channels}, | ||||||
|  | @ -375,6 +385,7 @@ absl::Status TensorsToSegmentationCalculator::Process( | ||||||
|     for (int i = 0; i < confidence_masks.size(); ++i) { |     for (int i = 0; i < confidence_masks.size(); ++i) { | ||||||
|       kConfidenceMaskOut(cc)[i].Send(std::move(confidence_masks[i])); |       kConfidenceMaskOut(cc)[i].Send(std::move(confidence_masks[i])); | ||||||
|     } |     } | ||||||
|  |   } | ||||||
|   if (cc->Outputs().HasTag("CATEGORY_MASK")) { |   if (cc->Outputs().HasTag("CATEGORY_MASK")) { | ||||||
|     kCategoryMaskOut(cc).Send(ProcessForCategoryMaskCpu( |     kCategoryMaskOut(cc).Send(ProcessForCategoryMaskCpu( | ||||||
|         input_shape, |         input_shape, | ||||||
|  |  | ||||||
|  | @ -60,15 +60,19 @@ using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: | ||||||
| // "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph".
 | // "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph".
 | ||||||
| CalculatorGraphConfig CreateGraphConfig( | CalculatorGraphConfig CreateGraphConfig( | ||||||
|     std::unique_ptr<ImageSegmenterGraphOptionsProto> options, |     std::unique_ptr<ImageSegmenterGraphOptionsProto> options, | ||||||
|     bool output_category_mask, bool enable_flow_limiting) { |     bool output_confidence_masks, bool output_category_mask, | ||||||
|  |     bool enable_flow_limiting) { | ||||||
|   api2::builder::Graph graph; |   api2::builder::Graph graph; | ||||||
|   auto& task_subgraph = graph.AddNode(kSubgraphTypeName); |   auto& task_subgraph = graph.AddNode(kSubgraphTypeName); | ||||||
|   task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap( |   task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap( | ||||||
|       options.get()); |       options.get()); | ||||||
|   graph.In(kImageTag).SetName(kImageInStreamName); |   graph.In(kImageTag).SetName(kImageInStreamName); | ||||||
|   graph.In(kNormRectTag).SetName(kNormRectStreamName); |   graph.In(kNormRectTag).SetName(kNormRectStreamName); | ||||||
|   task_subgraph.Out(kConfidenceMasksTag).SetName(kConfidenceMasksStreamName) >> |   if (output_confidence_masks) { | ||||||
|  |     task_subgraph.Out(kConfidenceMasksTag) | ||||||
|  |             .SetName(kConfidenceMasksStreamName) >> | ||||||
|         graph.Out(kConfidenceMasksTag); |         graph.Out(kConfidenceMasksTag); | ||||||
|  |   } | ||||||
|   if (output_category_mask) { |   if (output_category_mask) { | ||||||
|     task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >> |     task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >> | ||||||
|         graph.Out(kCategoryMaskTag); |         graph.Out(kCategoryMaskTag); | ||||||
|  | @ -135,11 +139,17 @@ absl::StatusOr<std::vector<std::string>> GetLabelsFromGraphConfig( | ||||||
| 
 | 
 | ||||||
| absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create( | absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create( | ||||||
|     std::unique_ptr<ImageSegmenterOptions> options) { |     std::unique_ptr<ImageSegmenterOptions> options) { | ||||||
|  |   if (!options->output_confidence_masks && !options->output_category_mask) { | ||||||
|  |     return absl::InvalidArgumentError( | ||||||
|  |         "At least one of `output_confidence_masks` and `output_category_mask` " | ||||||
|  |         "must be set."); | ||||||
|  |   } | ||||||
|   auto options_proto = ConvertImageSegmenterOptionsToProto(options.get()); |   auto options_proto = ConvertImageSegmenterOptionsToProto(options.get()); | ||||||
|   tasks::core::PacketsCallback packets_callback = nullptr; |   tasks::core::PacketsCallback packets_callback = nullptr; | ||||||
|   if (options->result_callback) { |   if (options->result_callback) { | ||||||
|     auto result_callback = options->result_callback; |     auto result_callback = options->result_callback; | ||||||
|     bool output_category_mask = options->output_category_mask; |     bool output_category_mask = options->output_category_mask; | ||||||
|  |     bool output_confidence_masks = options->output_confidence_masks; | ||||||
|     packets_callback = |     packets_callback = | ||||||
|         [=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) { |         [=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) { | ||||||
|           if (!status_or_packets.ok()) { |           if (!status_or_packets.ok()) { | ||||||
|  | @ -151,8 +161,12 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create( | ||||||
|           if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { |           if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { | ||||||
|             return; |             return; | ||||||
|           } |           } | ||||||
|           Packet confidence_masks = |           std::optional<std::vector<Image>> confidence_masks; | ||||||
|               status_or_packets.value()[kConfidenceMasksStreamName]; |           if (output_confidence_masks) { | ||||||
|  |             confidence_masks = | ||||||
|  |                 status_or_packets.value()[kConfidenceMasksStreamName] | ||||||
|  |                     .Get<std::vector<Image>>(); | ||||||
|  |           } | ||||||
|           std::optional<Image> category_mask; |           std::optional<Image> category_mask; | ||||||
|           if (output_category_mask) { |           if (output_category_mask) { | ||||||
|             category_mask = |             category_mask = | ||||||
|  | @ -160,23 +174,24 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create( | ||||||
|           } |           } | ||||||
|           Packet image_packet = status_or_packets.value()[kImageOutStreamName]; |           Packet image_packet = status_or_packets.value()[kImageOutStreamName]; | ||||||
|           result_callback( |           result_callback( | ||||||
|               {{confidence_masks.Get<std::vector<Image>>(), category_mask}}, |               {{confidence_masks, category_mask}}, image_packet.Get<Image>(), | ||||||
|               image_packet.Get<Image>(), |               image_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); | ||||||
|               confidence_masks.Timestamp().Value() / |  | ||||||
|                   kMicroSecondsPerMilliSecond); |  | ||||||
|         }; |         }; | ||||||
|   } |   } | ||||||
|   auto image_segmenter = |   auto image_segmenter = | ||||||
|       core::VisionTaskApiFactory::Create<ImageSegmenter, |       core::VisionTaskApiFactory::Create<ImageSegmenter, | ||||||
|                                          ImageSegmenterGraphOptionsProto>( |                                          ImageSegmenterGraphOptionsProto>( | ||||||
|           CreateGraphConfig( |           CreateGraphConfig( | ||||||
|               std::move(options_proto), options->output_category_mask, |               std::move(options_proto), options->output_confidence_masks, | ||||||
|  |               options->output_category_mask, | ||||||
|               options->running_mode == core::RunningMode::LIVE_STREAM), |               options->running_mode == core::RunningMode::LIVE_STREAM), | ||||||
|           std::move(options->base_options.op_resolver), options->running_mode, |           std::move(options->base_options.op_resolver), options->running_mode, | ||||||
|           std::move(packets_callback)); |           std::move(packets_callback)); | ||||||
|   if (!image_segmenter.ok()) { |   if (!image_segmenter.ok()) { | ||||||
|     return image_segmenter.status(); |     return image_segmenter.status(); | ||||||
|   } |   } | ||||||
|  |   image_segmenter.value()->output_confidence_masks_ = | ||||||
|  |       options->output_confidence_masks; | ||||||
|   image_segmenter.value()->output_category_mask_ = |   image_segmenter.value()->output_category_mask_ = | ||||||
|       options->output_category_mask; |       options->output_category_mask; | ||||||
|   ASSIGN_OR_RETURN( |   ASSIGN_OR_RETURN( | ||||||
|  | @ -203,8 +218,11 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment( | ||||||
|           {{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))}, |           {{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))}, | ||||||
|            {kNormRectStreamName, |            {kNormRectStreamName, | ||||||
|             MakePacket<NormalizedRect>(std::move(norm_rect))}})); |             MakePacket<NormalizedRect>(std::move(norm_rect))}})); | ||||||
|   std::vector<Image> confidence_masks = |   std::optional<std::vector<Image>> confidence_masks; | ||||||
|  |   if (output_confidence_masks_) { | ||||||
|  |     confidence_masks = | ||||||
|         output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>(); |         output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>(); | ||||||
|  |   } | ||||||
|   std::optional<Image> category_mask; |   std::optional<Image> category_mask; | ||||||
|   if (output_category_mask_) { |   if (output_category_mask_) { | ||||||
|     category_mask = output_packets[kCategoryMaskStreamName].Get<Image>(); |     category_mask = output_packets[kCategoryMaskStreamName].Get<Image>(); | ||||||
|  | @ -233,8 +251,11 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo( | ||||||
|            {kNormRectStreamName, |            {kNormRectStreamName, | ||||||
|             MakePacket<NormalizedRect>(std::move(norm_rect)) |             MakePacket<NormalizedRect>(std::move(norm_rect)) | ||||||
|                 .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); |                 .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); | ||||||
|   std::vector<Image> confidence_masks = |   std::optional<std::vector<Image>> confidence_masks; | ||||||
|  |   if (output_confidence_masks_) { | ||||||
|  |     confidence_masks = | ||||||
|         output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>(); |         output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>(); | ||||||
|  |   } | ||||||
|   std::optional<Image> category_mask; |   std::optional<Image> category_mask; | ||||||
|   if (output_category_mask_) { |   if (output_category_mask_) { | ||||||
|     category_mask = output_packets[kCategoryMaskStreamName].Get<Image>(); |     category_mask = output_packets[kCategoryMaskStreamName].Get<Image>(); | ||||||
|  |  | ||||||
|  | @ -53,6 +53,9 @@ struct ImageSegmenterOptions { | ||||||
|   // Metadata, if any. Defaults to English.
 |   // Metadata, if any. Defaults to English.
 | ||||||
|   std::string display_names_locale = "en"; |   std::string display_names_locale = "en"; | ||||||
| 
 | 
 | ||||||
|  |   // Whether to output confidence masks.
 | ||||||
|  |   bool output_confidence_masks = true; | ||||||
|  | 
 | ||||||
|   // Whether to output category mask.
 |   // Whether to output category mask.
 | ||||||
|   bool output_category_mask = false; |   bool output_category_mask = false; | ||||||
| 
 | 
 | ||||||
|  | @ -77,8 +80,10 @@ struct ImageSegmenterOptions { | ||||||
| //    - if type is kTfLiteFloat32, NormalizationOptions are required to be
 | //    - if type is kTfLiteFloat32, NormalizationOptions are required to be
 | ||||||
| //      attached to the metadata for input normalization.
 | //      attached to the metadata for input normalization.
 | ||||||
| // Output ImageSegmenterResult:
 | // Output ImageSegmenterResult:
 | ||||||
| //    Provides confidence masks and an optional category mask if
 | //    Provides optional confidence masks if `output_confidence_masks` is set
 | ||||||
| //    `output_category_mask` is set true.
 | //    true,  and an optional category mask if `output_category_mask` is set
 | ||||||
|  | //    true. At least one of `output_confidence_masks` and `output_category_mask`
 | ||||||
|  | //    must be set to true.
 | ||||||
| // An example of such model can be found at:
 | // An example of such model can be found at:
 | ||||||
| // https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2
 | // https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2
 | ||||||
| class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { | class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { | ||||||
|  | @ -167,6 +172,7 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { | ||||||
| 
 | 
 | ||||||
|  private: |  private: | ||||||
|   std::vector<std::string> labels_; |   std::vector<std::string> labels_; | ||||||
|  |   bool output_confidence_masks_; | ||||||
|   bool output_category_mask_; |   bool output_category_mask_; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -326,8 +326,10 @@ absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors( | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // An "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph" performs
 | // An "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph" performs
 | ||||||
| // semantic segmentation. The graph always output confidence masks, and an
 | // semantic segmentation. The graph can output optional confidence masks if
 | ||||||
| // optional category mask if CATEGORY_MASK is connected.
 | // CONFIDENCE_MASKS is connected, and an optional category mask if CATEGORY_MASK
 | ||||||
|  | // is connected. At least one of CONFIDENCE_MASK, CONFIDENCE_MASKS and
 | ||||||
|  | // CATEGORY_MASK must be connected.
 | ||||||
| //
 | //
 | ||||||
| //  Two kinds of outputs for confidence mask are provided: CONFIDENCE_MASK and
 | //  Two kinds of outputs for confidence mask are provided: CONFIDENCE_MASK and
 | ||||||
| //  CONFIDENCE_MASKS. Users can retrieve segmented mask of only particular
 | //  CONFIDENCE_MASKS. Users can retrieve segmented mask of only particular
 | ||||||
|  | @ -347,7 +349,7 @@ absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors( | ||||||
| //   CONFIDENCE_MASK - mediapipe::Image @Multiple
 | //   CONFIDENCE_MASK - mediapipe::Image @Multiple
 | ||||||
| //     Confidence masks for individual category. Confidence mask of single
 | //     Confidence masks for individual category. Confidence mask of single
 | ||||||
| //     category can be accessed by index based output stream.
 | //     category can be accessed by index based output stream.
 | ||||||
| //   CONFIDENCE_MASKS - std::vector<mediapipe::Image>
 | //   CONFIDENCE_MASKS - std::vector<mediapipe::Image> @Optional
 | ||||||
| //     The output confidence masks grouped in a vector.
 | //     The output confidence masks grouped in a vector.
 | ||||||
| //   CATEGORY_MASK - mediapipe::Image @Optional
 | //   CATEGORY_MASK - mediapipe::Image @Optional
 | ||||||
| //     Optional Category mask.
 | //     Optional Category mask.
 | ||||||
|  | @ -356,7 +358,7 @@ absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors( | ||||||
| //
 | //
 | ||||||
| // Example:
 | // Example:
 | ||||||
| // node {
 | // node {
 | ||||||
| //   calculator: "mediapipe.tasks.vision.ImageSegmenterGraph"
 | //   calculator: "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"
 | ||||||
| //   input_stream: "IMAGE:image"
 | //   input_stream: "IMAGE:image"
 | ||||||
| //   output_stream: "SEGMENTATION:segmented_masks"
 | //   output_stream: "SEGMENTATION:segmented_masks"
 | ||||||
| //   options {
 | //   options {
 | ||||||
|  | @ -382,17 +384,20 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { | ||||||
|                      CreateModelResources<ImageSegmenterGraphOptions>(sc)); |                      CreateModelResources<ImageSegmenterGraphOptions>(sc)); | ||||||
|     Graph graph; |     Graph graph; | ||||||
|     const auto& options = sc->Options<ImageSegmenterGraphOptions>(); |     const auto& options = sc->Options<ImageSegmenterGraphOptions>(); | ||||||
|  |     // TODO: remove deprecated output type support.
 | ||||||
|  |     if (!options.segmenter_options().has_output_type()) { | ||||||
|  |       MP_RETURN_IF_ERROR(SanityCheck(sc)); | ||||||
|  |     } | ||||||
|     ASSIGN_OR_RETURN( |     ASSIGN_OR_RETURN( | ||||||
|         auto output_streams, |         auto output_streams, | ||||||
|         BuildSegmentationTask( |         BuildSegmentationTask( | ||||||
|             options, *model_resources, graph[Input<Image>(kImageTag)], |             options, *model_resources, graph[Input<Image>(kImageTag)], | ||||||
|             graph[Input<NormalizedRect>::Optional(kNormRectTag)], |             graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph)); | ||||||
|             HasOutput(sc->OriginalNode(), kCategoryMaskTag), graph)); |  | ||||||
| 
 | 
 | ||||||
|     auto& merge_images_to_vector = |  | ||||||
|         graph.AddNode("MergeImagesToVectorCalculator"); |  | ||||||
|     // TODO: remove deprecated output type support.
 |     // TODO: remove deprecated output type support.
 | ||||||
|     if (options.segmenter_options().has_output_type()) { |     if (options.segmenter_options().has_output_type()) { | ||||||
|  |       auto& merge_images_to_vector = | ||||||
|  |           graph.AddNode("MergeImagesToVectorCalculator"); | ||||||
|       for (int i = 0; i < output_streams.segmented_masks->size(); ++i) { |       for (int i = 0; i < output_streams.segmented_masks->size(); ++i) { | ||||||
|         output_streams.segmented_masks->at(i) >> |         output_streams.segmented_masks->at(i) >> | ||||||
|             merge_images_to_vector[Input<Image>::Multiple("")][i]; |             merge_images_to_vector[Input<Image>::Multiple("")][i]; | ||||||
|  | @ -402,6 +407,9 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { | ||||||
|       merge_images_to_vector.Out("") >> |       merge_images_to_vector.Out("") >> | ||||||
|           graph[Output<std::vector<Image>>(kGroupedSegmentationTag)]; |           graph[Output<std::vector<Image>>(kGroupedSegmentationTag)]; | ||||||
|     } else { |     } else { | ||||||
|  |       if (output_streams.confidence_masks) { | ||||||
|  |         auto& merge_images_to_vector = | ||||||
|  |             graph.AddNode("MergeImagesToVectorCalculator"); | ||||||
|         for (int i = 0; i < output_streams.confidence_masks->size(); ++i) { |         for (int i = 0; i < output_streams.confidence_masks->size(); ++i) { | ||||||
|           output_streams.confidence_masks->at(i) >> |           output_streams.confidence_masks->at(i) >> | ||||||
|               merge_images_to_vector[Input<Image>::Multiple("")][i]; |               merge_images_to_vector[Input<Image>::Multiple("")][i]; | ||||||
|  | @ -409,7 +417,8 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { | ||||||
|               graph[Output<Image>::Multiple(kConfidenceMaskTag)][i]; |               graph[Output<Image>::Multiple(kConfidenceMaskTag)][i]; | ||||||
|         } |         } | ||||||
|         merge_images_to_vector.Out("") >> |         merge_images_to_vector.Out("") >> | ||||||
|           graph[Output<std::vector<Image>>(kConfidenceMasksTag)]; |             graph[Output<std::vector<Image>>::Optional(kConfidenceMasksTag)]; | ||||||
|  |       } | ||||||
|       if (output_streams.category_mask) { |       if (output_streams.category_mask) { | ||||||
|         *output_streams.category_mask >> graph[Output<Image>(kCategoryMaskTag)]; |         *output_streams.category_mask >> graph[Output<Image>(kCategoryMaskTag)]; | ||||||
|       } |       } | ||||||
|  | @ -419,6 +428,19 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  private: |  private: | ||||||
|  |   absl::Status SanityCheck(mediapipe::SubgraphContext* sc) { | ||||||
|  |     const auto& node = sc->OriginalNode(); | ||||||
|  |     output_confidence_masks_ = HasOutput(node, kConfidenceMaskTag) || | ||||||
|  |                                HasOutput(node, kConfidenceMasksTag); | ||||||
|  |     output_category_mask_ = HasOutput(node, kCategoryMaskTag); | ||||||
|  |     if (!output_confidence_masks_ && !output_category_mask_) { | ||||||
|  |       return absl::InvalidArgumentError( | ||||||
|  |           "At least one of CONFIDENCE_MASK, CONFIDENCE_MASKS and CATEGORY_MASK " | ||||||
|  |           "must be connected."); | ||||||
|  |     } | ||||||
|  |     return absl::OkStatus(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   // Adds a mediapipe image segmentation task pipeline graph into the provided
 |   // Adds a mediapipe image segmentation task pipeline graph into the provided
 | ||||||
|   // builder::Graph instance. The segmentation pipeline takes images
 |   // builder::Graph instance. The segmentation pipeline takes images
 | ||||||
|   // (mediapipe::Image) as the input and returns segmented image mask as output.
 |   // (mediapipe::Image) as the input and returns segmented image mask as output.
 | ||||||
|  | @ -431,8 +453,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { | ||||||
|   absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask( |   absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask( | ||||||
|       const ImageSegmenterGraphOptions& task_options, |       const ImageSegmenterGraphOptions& task_options, | ||||||
|       const core::ModelResources& model_resources, Source<Image> image_in, |       const core::ModelResources& model_resources, Source<Image> image_in, | ||||||
|       Source<NormalizedRect> norm_rect_in, bool output_category_mask, |       Source<NormalizedRect> norm_rect_in, Graph& graph) { | ||||||
|       Graph& graph) { |  | ||||||
|     MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); |     MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); | ||||||
| 
 | 
 | ||||||
|     // Adds preprocessing calculators and connects them to the graph input image
 |     // Adds preprocessing calculators and connects them to the graph input image
 | ||||||
|  | @ -485,26 +506,32 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { | ||||||
|                                    /*category_mask=*/std::nullopt, |                                    /*category_mask=*/std::nullopt, | ||||||
|                                    /*image=*/image_and_tensors.image}; |                                    /*image=*/image_and_tensors.image}; | ||||||
|     } else { |     } else { | ||||||
|  |       std::optional<std::vector<Source<Image>>> confidence_masks; | ||||||
|  |       if (output_confidence_masks_) { | ||||||
|         ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor, |         ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor, | ||||||
|                          GetOutputTensor(model_resources)); |                          GetOutputTensor(model_resources)); | ||||||
|         int segmentation_streams_num = *output_tensor->shape()->rbegin(); |         int segmentation_streams_num = *output_tensor->shape()->rbegin(); | ||||||
|       std::vector<Source<Image>> confidence_masks; |         confidence_masks = std::vector<Source<Image>>(); | ||||||
|       confidence_masks.reserve(segmentation_streams_num); |         confidence_masks->reserve(segmentation_streams_num); | ||||||
|         for (int i = 0; i < segmentation_streams_num; ++i) { |         for (int i = 0; i < segmentation_streams_num; ++i) { | ||||||
|         confidence_masks.push_back(Source<Image>( |           confidence_masks->push_back(Source<Image>( | ||||||
|             tensor_to_images[Output<Image>::Multiple(kConfidenceMaskTag)][i])); |               tensor_to_images[Output<Image>::Multiple(kConfidenceMaskTag)] | ||||||
|  |                               [i])); | ||||||
|         } |         } | ||||||
|       return ImageSegmenterOutputs{ |       } | ||||||
|           /*segmented_masks=*/std::nullopt, |       std::optional<Source<Image>> category_mask; | ||||||
|  |       if (output_category_mask_) { | ||||||
|  |         category_mask = tensor_to_images[Output<Image>(kCategoryMaskTag)]; | ||||||
|  |       } | ||||||
|  |       return ImageSegmenterOutputs{/*segmented_masks=*/std::nullopt, | ||||||
|                                    /*confidence_masks=*/confidence_masks, |                                    /*confidence_masks=*/confidence_masks, | ||||||
|           /*category_mask=*/ |                                    /*category_mask=*/category_mask, | ||||||
|           output_category_mask |  | ||||||
|               ? std::make_optional( |  | ||||||
|                     tensor_to_images[Output<Image>(kCategoryMaskTag)]) |  | ||||||
|               : std::nullopt, |  | ||||||
|                                    /*image=*/image_and_tensors.image}; |                                    /*image=*/image_and_tensors.image}; | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  | 
 | ||||||
|  |   bool output_confidence_masks_ = false; | ||||||
|  |   bool output_category_mask_ = false; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| REGISTER_MEDIAPIPE_GRAPH( | REGISTER_MEDIAPIPE_GRAPH( | ||||||
|  |  | ||||||
|  | @ -29,7 +29,7 @@ namespace image_segmenter { | ||||||
| struct ImageSegmenterResult { | struct ImageSegmenterResult { | ||||||
|   // Multiple masks of float image in VEC32F1 format where, for each mask, each
 |   // Multiple masks of float image in VEC32F1 format where, for each mask, each
 | ||||||
|   // pixel represents the prediction confidence, usually in the [0, 1] range.
 |   // pixel represents the prediction confidence, usually in the [0, 1] range.
 | ||||||
|   std::vector<Image> confidence_masks; |   std::optional<std::vector<Image>> confidence_masks; | ||||||
|   // A category mask of uint8 image in GRAY8 format where each pixel represents
 |   // A category mask of uint8 image in GRAY8 format where each pixel represents
 | ||||||
|   // the class which the pixel in the original image was predicted to belong to.
 |   // the class which the pixel in the original image was predicted to belong to.
 | ||||||
|   std::optional<Image> category_mask; |   std::optional<Image> category_mask; | ||||||
|  |  | ||||||
|  | @ -278,6 +278,7 @@ TEST_F(ImageModeTest, SucceedsWithCategoryMask) { | ||||||
|   auto options = std::make_unique<ImageSegmenterOptions>(); |   auto options = std::make_unique<ImageSegmenterOptions>(); | ||||||
|   options->base_options.model_asset_path = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); |       JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); | ||||||
|  |   options->output_confidence_masks = false; | ||||||
|   options->output_category_mask = true; |   options->output_category_mask = true; | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, |   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, | ||||||
|                           ImageSegmenter::Create(std::move(options))); |                           ImageSegmenter::Create(std::move(options))); | ||||||
|  | @ -306,7 +307,7 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, |   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, | ||||||
|                           ImageSegmenter::Create(std::move(options))); |                           ImageSegmenter::Create(std::move(options))); | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); |   MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); | ||||||
|   EXPECT_EQ(result.confidence_masks.size(), 21); |   EXPECT_EQ(result.confidence_masks->size(), 21); | ||||||
| 
 | 
 | ||||||
|   cv::Mat expected_mask = cv::imread( |   cv::Mat expected_mask = cv::imread( | ||||||
|       JoinPath("./", kTestDataDirectory, "cat_mask.jpg"), cv::IMREAD_GRAYSCALE); |       JoinPath("./", kTestDataDirectory, "cat_mask.jpg"), cv::IMREAD_GRAYSCALE); | ||||||
|  | @ -315,7 +316,7 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { | ||||||
| 
 | 
 | ||||||
|   // Cat category index 8.
 |   // Cat category index 8.
 | ||||||
|   cv::Mat cat_mask = mediapipe::formats::MatView( |   cv::Mat cat_mask = mediapipe::formats::MatView( | ||||||
|       result.confidence_masks[8].GetImageFrameSharedPtr().get()); |       result.confidence_masks->at(8).GetImageFrameSharedPtr().get()); | ||||||
|   EXPECT_THAT(cat_mask, |   EXPECT_THAT(cat_mask, | ||||||
|               SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); |               SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); | ||||||
| } | } | ||||||
|  | @ -336,7 +337,7 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { | ||||||
|   image_processing_options.rotation_degrees = -90; |   image_processing_options.rotation_degrees = -90; | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(auto result, |   MP_ASSERT_OK_AND_ASSIGN(auto result, | ||||||
|                           segmenter->Segment(image, image_processing_options)); |                           segmenter->Segment(image, image_processing_options)); | ||||||
|   EXPECT_EQ(result.confidence_masks.size(), 21); |   EXPECT_EQ(result.confidence_masks->size(), 21); | ||||||
| 
 | 
 | ||||||
|   cv::Mat expected_mask = |   cv::Mat expected_mask = | ||||||
|       cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"), |       cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"), | ||||||
|  | @ -346,7 +347,7 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { | ||||||
| 
 | 
 | ||||||
|   // Cat category index 8.
 |   // Cat category index 8.
 | ||||||
|   cv::Mat cat_mask = mediapipe::formats::MatView( |   cv::Mat cat_mask = mediapipe::formats::MatView( | ||||||
|       result.confidence_masks[8].GetImageFrameSharedPtr().get()); |       result.confidence_masks->at(8).GetImageFrameSharedPtr().get()); | ||||||
|   EXPECT_THAT(cat_mask, |   EXPECT_THAT(cat_mask, | ||||||
|               SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); |               SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); | ||||||
| } | } | ||||||
|  | @ -384,7 +385,7 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, |   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, | ||||||
|                           ImageSegmenter::Create(std::move(options))); |                           ImageSegmenter::Create(std::move(options))); | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); |   MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); | ||||||
|   EXPECT_EQ(result.confidence_masks.size(), 2); |   EXPECT_EQ(result.confidence_masks->size(), 2); | ||||||
| 
 | 
 | ||||||
|   cv::Mat expected_mask = |   cv::Mat expected_mask = | ||||||
|       cv::imread(JoinPath("./", kTestDataDirectory, |       cv::imread(JoinPath("./", kTestDataDirectory, | ||||||
|  | @ -395,7 +396,7 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { | ||||||
| 
 | 
 | ||||||
|   // Selfie category index 1.
 |   // Selfie category index 1.
 | ||||||
|   cv::Mat selfie_mask = mediapipe::formats::MatView( |   cv::Mat selfie_mask = mediapipe::formats::MatView( | ||||||
|       result.confidence_masks[1].GetImageFrameSharedPtr().get()); |       result.confidence_masks->at(1).GetImageFrameSharedPtr().get()); | ||||||
|   EXPECT_THAT(selfie_mask, |   EXPECT_THAT(selfie_mask, | ||||||
|               SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); |               SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); | ||||||
| } | } | ||||||
|  | @ -409,7 +410,7 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, |   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, | ||||||
|                           ImageSegmenter::Create(std::move(options))); |                           ImageSegmenter::Create(std::move(options))); | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); |   MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); | ||||||
|   EXPECT_EQ(result.confidence_masks.size(), 1); |   EXPECT_EQ(result.confidence_masks->size(), 1); | ||||||
| 
 | 
 | ||||||
|   cv::Mat expected_mask = |   cv::Mat expected_mask = | ||||||
|       cv::imread(JoinPath("./", kTestDataDirectory, |       cv::imread(JoinPath("./", kTestDataDirectory, | ||||||
|  | @ -419,7 +420,7 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { | ||||||
|   expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); |   expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); | ||||||
| 
 | 
 | ||||||
|   cv::Mat selfie_mask = mediapipe::formats::MatView( |   cv::Mat selfie_mask = mediapipe::formats::MatView( | ||||||
|       result.confidence_masks[0].GetImageFrameSharedPtr().get()); |       result.confidence_masks->at(0).GetImageFrameSharedPtr().get()); | ||||||
|   EXPECT_THAT(selfie_mask, |   EXPECT_THAT(selfie_mask, | ||||||
|               SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); |               SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); | ||||||
| } | } | ||||||
|  | @ -434,7 +435,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) { | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, |   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, | ||||||
|                           ImageSegmenter::Create(std::move(options))); |                           ImageSegmenter::Create(std::move(options))); | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); |   MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); | ||||||
|   EXPECT_EQ(result.confidence_masks.size(), 1); |   EXPECT_EQ(result.confidence_masks->size(), 1); | ||||||
|   MP_ASSERT_OK(segmenter->Close()); |   MP_ASSERT_OK(segmenter->Close()); | ||||||
| 
 | 
 | ||||||
|   cv::Mat expected_mask = cv::imread( |   cv::Mat expected_mask = cv::imread( | ||||||
|  | @ -445,7 +446,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) { | ||||||
|   expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); |   expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); | ||||||
| 
 | 
 | ||||||
|   cv::Mat selfie_mask = mediapipe::formats::MatView( |   cv::Mat selfie_mask = mediapipe::formats::MatView( | ||||||
|       result.confidence_masks[0].GetImageFrameSharedPtr().get()); |       result.confidence_masks->at(0).GetImageFrameSharedPtr().get()); | ||||||
|   EXPECT_THAT(selfie_mask, |   EXPECT_THAT(selfie_mask, | ||||||
|               SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); |               SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); | ||||||
| } | } | ||||||
|  | @ -506,10 +507,10 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) { | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, |   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, | ||||||
|                           ImageSegmenter::Create(std::move(options))); |                           ImageSegmenter::Create(std::move(options))); | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); |   MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); | ||||||
|   EXPECT_EQ(result.confidence_masks.size(), 2); |   EXPECT_EQ(result.confidence_masks->size(), 2); | ||||||
| 
 | 
 | ||||||
|   cv::Mat hair_mask = mediapipe::formats::MatView( |   cv::Mat hair_mask = mediapipe::formats::MatView( | ||||||
|       result.confidence_masks[1].GetImageFrameSharedPtr().get()); |       result.confidence_masks->at(1).GetImageFrameSharedPtr().get()); | ||||||
|   MP_ASSERT_OK(segmenter->Close()); |   MP_ASSERT_OK(segmenter->Close()); | ||||||
|   cv::Mat expected_mask = cv::imread( |   cv::Mat expected_mask = cv::imread( | ||||||
|       JoinPath("./", kTestDataDirectory, "portrait_hair_expected_mask.jpg"), |       JoinPath("./", kTestDataDirectory, "portrait_hair_expected_mask.jpg"), | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user