diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index 99faa1064..a251a0ffc 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -16,6 +16,7 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h" #include +#include #include "absl/strings/str_format.h" #include "mediapipe/framework/api2/builder.h" @@ -41,6 +42,8 @@ constexpr char kConfidenceMasksTag[] = "CONFIDENCE_MASKS"; constexpr char kConfidenceMasksStreamName[] = "confidence_masks"; constexpr char kCategoryMaskTag[] = "CATEGORY_MASK"; constexpr char kCategoryMaskStreamName[] = "category_mask"; +constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; +constexpr char kOutputSizeStreamName[] = "output_size"; constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageTag[] = "IMAGE"; @@ -70,6 +73,7 @@ CalculatorGraphConfig CreateGraphConfig( options.get()); graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kNormRectTag).SetName(kNormRectStreamName); + graph.In(kOutputSizeTag).SetName(kOutputSizeStreamName); if (output_confidence_masks) { task_subgraph.Out(kConfidenceMasksTag) .SetName(kConfidenceMasksStreamName) >> @@ -85,10 +89,12 @@ CalculatorGraphConfig CreateGraphConfig( graph.Out(kImageTag); if (enable_flow_limiting) { return tasks::core::AddFlowLimiterCalculator( - graph, task_subgraph, {kImageTag, kNormRectTag}, kConfidenceMasksTag); + graph, task_subgraph, {kImageTag, kNormRectTag, kOutputSizeTag}, + kConfidenceMasksTag); } graph.In(kImageTag) >> task_subgraph.In(kImageTag); graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag); + graph.In(kOutputSizeTag) >> task_subgraph.In(kOutputSizeTag); return graph.GetConfig(); } @@ -211,6 +217,13 @@ absl::StatusOr> ImageSegmenter::Create( absl::StatusOr 
ImageSegmenter::Segment( mediapipe::Image image, std::optional image_processing_options) { + return Segment(image, image.width(), image.height(), + std::move(image_processing_options)); +} + +absl::StatusOr ImageSegmenter::Segment( + mediapipe::Image image, int output_width, int output_height, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, @@ -225,7 +238,10 @@ absl::StatusOr ImageSegmenter::Segment( ProcessImageData( {{kImageInStreamName, mediapipe::MakePacket(std::move(image))}, {kNormRectStreamName, - MakePacket(std::move(norm_rect))}})); + MakePacket(std::move(norm_rect))}, + {kOutputSizeStreamName, + MakePacket>( + std::make_pair(output_width, output_height))}})); std::optional> confidence_masks; if (output_confidence_masks_) { confidence_masks = @@ -243,6 +259,14 @@ absl::StatusOr ImageSegmenter::Segment( absl::StatusOr ImageSegmenter::SegmentForVideo( mediapipe::Image image, int64_t timestamp_ms, std::optional image_processing_options) { + return SegmentForVideo(image, image.width(), image.height(), timestamp_ms, + image_processing_options); +} + +absl::StatusOr ImageSegmenter::SegmentForVideo( + mediapipe::Image image, int output_width, int output_height, + int64_t timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, @@ -260,6 +284,10 @@ absl::StatusOr ImageSegmenter::SegmentForVideo( .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, {kNormRectStreamName, MakePacket(std::move(norm_rect)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kOutputSizeStreamName, + MakePacket>( + std::make_pair(output_width, output_height)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); std::optional> confidence_masks; if (output_confidence_masks_) { @@ -278,6 +306,13 @@ absl::StatusOr ImageSegmenter::SegmentForVideo( absl::Status 
ImageSegmenter::SegmentAsync( Image image, int64_t timestamp_ms, std::optional image_processing_options) { + return SegmentAsync(image, image.width(), image.height(), timestamp_ms, + image_processing_options); +} + +absl::Status ImageSegmenter::SegmentAsync( + Image image, int output_width, int output_height, int64_t timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, @@ -293,6 +328,10 @@ absl::Status ImageSegmenter::SegmentAsync( .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, {kNormRectStreamName, MakePacket(std::move(norm_rect)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kOutputSizeStreamName, + MakePacket>( + std::make_pair(output_width, output_height)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); } diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index 0546cef3a..237603497 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -102,17 +102,36 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // // The image can be of any size with format RGB or RGBA. // + // The output size is the same as the input image size. + // // The optional 'image_processing_options' parameter can be used to specify // the rotation to apply to the image before performing segmentation, by // setting its 'rotation_degrees' field. Note that specifying a // region-of-interest using the 'region_of_interest' field is NOT supported // and will result in an invalid argument error being returned. - absl::StatusOr Segment( mediapipe::Image image, std::optional image_processing_options = std::nullopt); + // Performs image segmentation on the provided single image. + // Only use this method when the ImageSegmenter is created with the image + // running mode. 
+ // + // The image can be of any size with format RGB or RGBA. + // + // The output width and height specify the size of the resulting mask. + // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing segmentation, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + absl::StatusOr Segment( + mediapipe::Image image, int output_width, int output_height, + std::optional image_processing_options = + std::nullopt); + // Performs image segmentation on the provided video frame. // Only use this method when the ImageSegmenter is created with the video // running mode. @@ -121,16 +140,39 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // provide the video frame's timestamp (in milliseconds). The input timestamps // must be monotonically increasing. // - // The optional 'image_processing_options' parameter can be used to specify - // the rotation to apply to the image before performing segmentation, by - // setting its 'rotation_degrees' field. Note that specifying a - // region-of-interest using the 'region_of_interest' field is NOT supported + // The output size is the same as the input image size. + // + // The optional 'image_processing_options' parameter can be used + // to specify the rotation to apply to the image before performing + // segmentation, by setting its 'rotation_degrees' field. Note that specifying + // a region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. absl::StatusOr SegmentForVideo( mediapipe::Image image, int64_t timestamp_ms, std::optional image_processing_options = std::nullopt); + // Performs image segmentation on the provided video frame. 
+ // Only use this method when the ImageSegmenter is created with the video + // running mode. + // + // The image can be of any size with format RGB or RGBA. It's required to + // provide the video frame's timestamp (in milliseconds). The input timestamps + // must be monotonically increasing. + // + // The output width and height specify the size of the resulting mask. + // + // The optional 'image_processing_options' parameter can be used + // to specify the rotation to apply to the image before performing + // segmentation, by setting its 'rotation_degrees' field. Note that specifying + // a region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + absl::StatusOr SegmentForVideo( + mediapipe::Image image, int output_width, int output_height, + int64_t timestamp_ms, + std::optional image_processing_options = + std::nullopt); + // Sends live image data to perform image segmentation, and the results will // be available via the "result_callback" provided in the // ImageSegmenterOptions. Only use this method when the ImageSegmenter is @@ -141,6 +183,8 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // sent to the image segmenter. The input timestamps must be monotonically // increasing. // + // The output size is the same as the input image size. + // // The optional 'image_processing_options' parameter can be used to specify // the rotation to apply to the image before performing segmentation, by // setting its 'rotation_degrees' field. Note that specifying a @@ -158,6 +202,36 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { std::optional image_processing_options = std::nullopt); + // Sends live image data to perform image segmentation, and the results will + // be available via the "result_callback" provided in the + // ImageSegmenterOptions. Only use this method when the ImageSegmenter is + // created with the live stream running mode. 
+ // + // The image can be of any size with format RGB or RGBA. It's required to + // provide a timestamp (in milliseconds) to indicate when the input image is + // sent to the image segmenter. The input timestamps must be monotonically + // increasing. + // + // The output width and height specify the size of the resulting mask. + // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing segmentation, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + // + // The "result_callback" provides + // - An ImageSegmenterResult. + // - The const reference to the corresponding input image that the image + // segmentation runs on. Note that the const reference to the image will + // no longer be valid when the callback returns. To access the image data + // outside of the callback, callers need to make a copy of the image. + // - The input timestamp in milliseconds. + absl::Status SegmentAsync(mediapipe::Image image, int output_width, + int output_height, int64_t timestamp_ms, + std::optional + image_processing_options = std::nullopt); + // Shuts down the ImageSegmenter when all works are done. 
absl::Status Close() { return runner_->Close(); } diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index 0ae47ffd1..e80da0123 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -82,6 +82,7 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU"; constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kTensorsTag[] = "TENSORS"; constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; +constexpr char kSizeTag[] = "SIZE"; constexpr char kQualityScoresTag[] = "QUALITY_SCORES"; constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA"; @@ -356,6 +357,9 @@ absl::StatusOr ConvertImageToTensors( // Describes image rotation and region of image to perform detection // on. // @Optional: rect covering the whole image is used if not specified. +// OUTPUT_SIZE - std::pair @Optional +// The output size of the mask, in width and height. If not specified, the +// size of the input image is used. // // Outputs: // CONFIDENCE_MASK - mediapipe::Image @Multiple @@ -400,11 +404,16 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { if (!options.segmenter_options().has_output_type()) { MP_RETURN_IF_ERROR(SanityCheck(sc)); } + std::optional>> output_size; + if (HasInput(sc->OriginalNode(), kOutputSizeTag)) { + output_size = graph.In(kOutputSizeTag).Cast>(); + } ASSIGN_OR_RETURN( auto output_streams, BuildSegmentationTask( options, *model_resources, graph[Input(kImageTag)], - graph[Input::Optional(kNormRectTag)], output_size, + graph)); // TODO: remove deprecated output type support. 
if (options.segmenter_options().has_output_type()) { @@ -469,7 +478,8 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { absl::StatusOr BuildSegmentationTask( const ImageSegmenterGraphOptions& task_options, const core::ModelResources& model_resources, Source image_in, - Source norm_rect_in, Graph& graph) { + Source norm_rect_in, + std::optional>> output_size, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); // Adds preprocessing calculators and connects them to the graph input image @@ -514,10 +524,14 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { image_and_tensors.tensors >> inference.In(kTensorsTag); inference.Out(kTensorsTag) >> tensor_to_images.In(kTensorsTag); - // Adds image property calculator for output size. - auto& image_properties = graph.AddNode("ImagePropertiesCalculator"); - image_in >> image_properties.In("IMAGE"); - image_properties.Out("SIZE") >> tensor_to_images.In(kOutputSizeTag); + if (output_size.has_value()) { + *output_size >> tensor_to_images.In(kOutputSizeTag); + } else { + // Adds image property calculator for output size. + auto& image_properties = graph.AddNode("ImagePropertiesCalculator"); + image_in >> image_properties.In(kImageTag); + image_properties.Out(kSizeTag) >> tensor_to_images.In(kOutputSizeTag); + } // Exports multiple segmented masks. // TODO: remove deprecated output type support. 
diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index 4fde58e02..c6d81a394 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -57,6 +57,7 @@ mediapipe_files(srcs = [ "hand_landmarker.task", "left_hands.jpg", "left_hands_rotated.jpg", + "leopard_bg_removal_result_512x512.png", "mobilenet_v1_0.25_192_quantized_1_default_1.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite", "mobilenet_v1_0.25_224_1_metadata_1.tflite", @@ -136,6 +137,7 @@ filegroup( "hand_landmark_lite.tflite", "left_hands.jpg", "left_hands_rotated.jpg", + "leopard_bg_removal_result_512x512.png", "mozart_square.jpg", "multi_objects.jpg", "multi_objects_rotated.jpg", diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 4b51d9de0..f9a29309f 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -646,6 +646,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/left_hands_rotated.jpg?generation=1666037068103465"], ) + http_file( + name = "com_google_mediapipe_leopard_bg_removal_result_512x512_png", + sha256 = "30be22e89fdd1d7b985294498ec67509b0caa1ca941fe291fa25f43a3873e4dd", + urls = ["https://storage.googleapis.com/mediapipe-assets/leopard_bg_removal_result_512x512.png?generation=1690239134617707"], + ) + http_file( name = "com_google_mediapipe_leopard_bg_removal_result_png", sha256 = "afd33f2058fd58d189cda86ec931647741a6139970c9bcbc637cdd151ec657c5",