diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index fc977c0b5..a430ae7b8 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -37,6 +37,7 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:status", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/vision/core:base_vision_task_api", diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index a251a0ffc..74d8047de 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" @@ -217,13 +218,16 @@ absl::StatusOr> ImageSegmenter::Create( absl::StatusOr ImageSegmenter::Segment( mediapipe::Image image, std::optional image_processing_options) { - return Segment(image, image.width(), image.height(), - std::move(image_processing_options)); + return Segment(image, { + /*output_width=*/image.width(), + /*output_height=*/image.height(), + std::move(image_processing_options), + }); } absl::StatusOr ImageSegmenter::Segment( - mediapipe::Image image, int output_width, int output_height, - std::optional image_processing_options) { + mediapipe::Image image, SegmentationOptions segmentation_options) { + MP_RETURN_IF_ERROR(ValidateSegmentationOptions(segmentation_options)); if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, @@ -231,8 +235,9 @@ absl::StatusOr ImageSegmenter::Segment( MediaPipeTasksStatus::kRunnerUnexpectedInputError); } ASSIGN_OR_RETURN(NormalizedRect norm_rect, - ConvertToNormalizedRect(image_processing_options, image, - /*roi_allowed=*/false)); + ConvertToNormalizedRect( + segmentation_options.image_processing_options, image, + /*roi_allowed=*/false)); ASSIGN_OR_RETURN( auto output_packets, ProcessImageData( @@ -241,7 +246,8 @@ absl::StatusOr ImageSegmenter::Segment( MakePacket(std::move(norm_rect))}, {kOutputSizeStreamName, MakePacket>( - std::make_pair(output_width, output_height))}})); + std::make_pair(segmentation_options.output_width, + segmentation_options.output_height))}})); std::optional> confidence_masks; if (output_confidence_masks_) { confidence_masks = @@ -259,14 +265,18 @@ absl::StatusOr ImageSegmenter::Segment( absl::StatusOr ImageSegmenter::SegmentForVideo( mediapipe::Image image, int64_t timestamp_ms, std::optional image_processing_options) { - return SegmentForVideo(image, image.width(), image.height(), timestamp_ms, - image_processing_options); + return SegmentForVideo(image, timestamp_ms, + { + /*output_width=*/image.width(), + /*output_height=*/image.height(), + std::move(image_processing_options), + }); } absl::StatusOr ImageSegmenter::SegmentForVideo( - mediapipe::Image image, int output_width, int output_height, - int64_t timestamp_ms, - std::optional image_processing_options) { + mediapipe::Image image, int64_t timestamp_ms, + SegmentationOptions segmentation_options) { + MP_RETURN_IF_ERROR(ValidateSegmentationOptions(segmentation_options)); if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, @@ -274,8 +284,9 @@ absl::StatusOr ImageSegmenter::SegmentForVideo( MediaPipeTasksStatus::kRunnerUnexpectedInputError); } ASSIGN_OR_RETURN(NormalizedRect norm_rect, - ConvertToNormalizedRect(image_processing_options, image, - /*roi_allowed=*/false)); + ConvertToNormalizedRect( + segmentation_options.image_processing_options, image, + /*roi_allowed=*/false)); ASSIGN_OR_RETURN( auto output_packets, ProcessVideoData( @@ -287,7 +298,8 @@ absl::StatusOr ImageSegmenter::SegmentForVideo( .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, {kOutputSizeStreamName, MakePacket>( - std::make_pair(output_width, output_height)) + std::make_pair(segmentation_options.output_width, + segmentation_options.output_height)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); std::optional> confidence_masks; if (output_confidence_masks_) { @@ -306,13 +318,18 @@ absl::StatusOr ImageSegmenter::SegmentForVideo( absl::Status ImageSegmenter::SegmentAsync( Image image, int64_t timestamp_ms, std::optional image_processing_options) { - return SegmentAsync(image, image.width(), image.height(), timestamp_ms, - image_processing_options); + return SegmentAsync(image, timestamp_ms, + { + /*output_width=*/image.width(), + /*output_height=*/image.height(), + std::move(image_processing_options), + }); } absl::Status ImageSegmenter::SegmentAsync( - Image image, int output_width, int output_height, int64_t timestamp_ms, - std::optional image_processing_options) { + Image image, int64_t timestamp_ms, + SegmentationOptions segmentation_options) { + MP_RETURN_IF_ERROR(ValidateSegmentationOptions(segmentation_options)); if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, @@ -320,8 +337,9 @@ absl::Status ImageSegmenter::SegmentAsync( MediaPipeTasksStatus::kRunnerUnexpectedInputError); } ASSIGN_OR_RETURN(NormalizedRect norm_rect, - ConvertToNormalizedRect(image_processing_options, image, - /*roi_allowed=*/false)); + ConvertToNormalizedRect( + segmentation_options.image_processing_options, image, + /*roi_allowed=*/false)); return SendLiveStreamData( {{kImageInStreamName, MakePacket(std::move(image)) @@ -331,7 +349,8 @@ absl::Status ImageSegmenter::SegmentAsync( .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, {kOutputSizeStreamName, MakePacket>( - std::make_pair(output_width, output_height)) + std::make_pair(segmentation_options.output_width, + segmentation_options.output_height)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); } diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index 237603497..82bb3a3a6 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -67,6 +67,22 @@ struct ImageSegmenterOptions { result_callback = nullptr; }; +// Options for configuring runtime behavior of ImageSegmenter. +struct SegmentationOptions { + // The width of the output segmentation masks. + int output_width; + + // The height of the output segmentation masks. + int output_height; + + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing segmentation, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + std::optional image_processing_options; +}; + // Performs segmentation on images. // // The API expects a TFLite model with mandatory TFLite Model Metadata. @@ -119,18 +135,8 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // running mode. // // The image can be of any size with format RGB or RGBA. - // - // The output width and height specify the size of the resulted mask. - // - // The optional 'image_processing_options' parameter can be used to specify - // the rotation to apply to the image before performing segmentation, by - // setting its 'rotation_degrees' field. Note that specifying a - // region-of-interest using the 'region_of_interest' field is NOT supported - // and will result in an invalid argument error being returned. absl::StatusOr Segment( - mediapipe::Image image, int output_width, int output_height, - std::optional image_processing_options = - std::nullopt); + mediapipe::Image image, SegmentationOptions segmentation_options); // Performs image segmentation on the provided video frame. // Only use this method when the ImageSegmenter is created with the video @@ -159,19 +165,9 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // The image can be of any size with format RGB or RGBA. It's required to // provide the video frame's timestamp (in milliseconds). The input timestamps // must be monotonically increasing. - // - // The output width and height specify the size of the resulted mask. - // - // The optional 'image_processing_options' parameter can be used - // to specify the rotation to apply to the image before performing - // segmentation, by setting its 'rotation_degrees' field. Note that specifying - // a region-of-interest using the 'region_of_interest' field is NOT supported - // and will result in an invalid argument error being returned. absl::StatusOr SegmentForVideo( - mediapipe::Image image, int output_width, int output_height, - int64_t timestamp_ms, - std::optional image_processing_options = - std::nullopt); + mediapipe::Image image, int64_t timestamp_ms, + SegmentationOptions segmentation_options); // Sends live image data to perform image segmentation, and the results will // be available via the "result_callback" provided in the @@ -191,7 +187,7 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // region-of-interest using the 'region_of_interest' field is NOT supported // and will result in an invalid argument error being returned. // - // The "result_callback" prvoides + // The "result_callback" provides // - An ImageSegmenterResult. // - The const reference to the corresponding input image that the image // segmentation runs on. Note that the const reference to the image will @@ -212,25 +208,15 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // sent to the image segmenter. The input timestamps must be monotonically // increasing. // - // The output width and height specify the size of the resulted mask. - // - // The optional 'image_processing_options' parameter can be used to specify - // the rotation to apply to the image before performing segmentation, by - // setting its 'rotation_degrees' field. Note that specifying a - // region-of-interest using the 'region_of_interest' field is NOT supported - // and will result in an invalid argument error being returned. - // - // The "result_callback" prvoides + // The "result_callback" provides // - An ImageSegmenterResult. // - The const reference to the corresponding input image that the image // segmentation runs on. Note that the const reference to the image will // no longer be valid when the callback returns. To access the image data // outside of the callback, callers need to make a copy of the image. // - The input timestamp in milliseconds. - absl::Status SegmentAsync(mediapipe::Image image, int output_width, - int output_height, int64_t timestamp_ms, - std::optional - image_processing_options = std::nullopt); + absl::Status SegmentAsync(mediapipe::Image image, int64_t timestamp_ms, + SegmentationOptions segmentation_options); // Shuts down the ImageSegmenter when all works are done. absl::Status Close() { return runner_->Close(); } @@ -248,6 +234,14 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { std::vector labels_; bool output_confidence_masks_; bool output_category_mask_; + + absl::Status ValidateSegmentationOptions(const SegmentationOptions& options) { + if (options.output_width <= 0 || options.output_height <= 0) { + return absl::InvalidArgumentError( + "Both output_width and output_height must be larger than 0."); + } + return absl::OkStatus(); + } }; } // namespace image_segmenter