C++ Image segmenter add output size parameters.

PiperOrigin-RevId: 550995124
This commit is contained in:
MediaPipe Team 2023-07-25 14:20:15 -07:00 committed by Copybara-Service
parent bd7888cc0c
commit 1f6851c577
5 changed files with 148 additions and 13 deletions

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h" #include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"
#include <optional> #include <optional>
#include <utility>
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/builder.h"
@ -41,6 +42,8 @@ constexpr char kConfidenceMasksTag[] = "CONFIDENCE_MASKS";
constexpr char kConfidenceMasksStreamName[] = "confidence_masks"; constexpr char kConfidenceMasksStreamName[] = "confidence_masks";
constexpr char kCategoryMaskTag[] = "CATEGORY_MASK"; constexpr char kCategoryMaskTag[] = "CATEGORY_MASK";
constexpr char kCategoryMaskStreamName[] = "category_mask"; constexpr char kCategoryMaskStreamName[] = "category_mask";
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
constexpr char kOutputSizeStreamName[] = "output_size";
constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageOutStreamName[] = "image_out";
constexpr char kImageTag[] = "IMAGE"; constexpr char kImageTag[] = "IMAGE";
@ -70,6 +73,7 @@ CalculatorGraphConfig CreateGraphConfig(
options.get()); options.get());
graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kImageTag).SetName(kImageInStreamName);
graph.In(kNormRectTag).SetName(kNormRectStreamName); graph.In(kNormRectTag).SetName(kNormRectStreamName);
graph.In(kOutputSizeTag).SetName(kOutputSizeStreamName);
if (output_confidence_masks) { if (output_confidence_masks) {
task_subgraph.Out(kConfidenceMasksTag) task_subgraph.Out(kConfidenceMasksTag)
.SetName(kConfidenceMasksStreamName) >> .SetName(kConfidenceMasksStreamName) >>
@ -85,10 +89,12 @@ CalculatorGraphConfig CreateGraphConfig(
graph.Out(kImageTag); graph.Out(kImageTag);
if (enable_flow_limiting) { if (enable_flow_limiting) {
return tasks::core::AddFlowLimiterCalculator( return tasks::core::AddFlowLimiterCalculator(
graph, task_subgraph, {kImageTag, kNormRectTag}, kConfidenceMasksTag); graph, task_subgraph, {kImageTag, kNormRectTag, kOutputSizeTag},
kConfidenceMasksTag);
} }
graph.In(kImageTag) >> task_subgraph.In(kImageTag); graph.In(kImageTag) >> task_subgraph.In(kImageTag);
graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag); graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
graph.In(kOutputSizeTag) >> task_subgraph.In(kOutputSizeTag);
return graph.GetConfig(); return graph.GetConfig();
} }
@ -211,6 +217,13 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment( absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
mediapipe::Image image, mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
return Segment(image, image.width(), image.height(),
std::move(image_processing_options));
}
absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
mediapipe::Image image, int output_width, int output_height,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
@ -225,7 +238,10 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
ProcessImageData( ProcessImageData(
{{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))}, {{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))},
{kNormRectStreamName, {kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))}})); MakePacket<NormalizedRect>(std::move(norm_rect))},
{kOutputSizeStreamName,
MakePacket<std::pair<int, int>>(
std::make_pair(output_width, output_height))}}));
std::optional<std::vector<Image>> confidence_masks; std::optional<std::vector<Image>> confidence_masks;
if (output_confidence_masks_) { if (output_confidence_masks_) {
confidence_masks = confidence_masks =
@ -243,6 +259,14 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo( absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
mediapipe::Image image, int64_t timestamp_ms, mediapipe::Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
return SegmentForVideo(image, image.width(), image.height(), timestamp_ms,
image_processing_options);
}
absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
mediapipe::Image image, int output_width, int output_height,
int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
@ -260,6 +284,10 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
{kNormRectStreamName, {kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect)) MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
{kOutputSizeStreamName,
MakePacket<std::pair<int, int>>(
std::make_pair(output_width, output_height))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
std::optional<std::vector<Image>> confidence_masks; std::optional<std::vector<Image>> confidence_masks;
if (output_confidence_masks_) { if (output_confidence_masks_) {
@ -278,6 +306,13 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
absl::Status ImageSegmenter::SegmentAsync( absl::Status ImageSegmenter::SegmentAsync(
Image image, int64_t timestamp_ms, Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
return SegmentAsync(image, image.width(), image.height(), timestamp_ms,
image_processing_options);
}
absl::Status ImageSegmenter::SegmentAsync(
Image image, int output_width, int output_height, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
@ -293,6 +328,10 @@ absl::Status ImageSegmenter::SegmentAsync(
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
{kNormRectStreamName, {kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect)) MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
{kOutputSizeStreamName,
MakePacket<std::pair<int, int>>(
std::make_pair(output_width, output_height))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
} }

View File

@ -102,17 +102,36 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// //
// The image can be of any size with format RGB or RGBA. // The image can be of any size with format RGB or RGBA.
// //
// The output size is the same as the input image size.
//
// The optional 'image_processing_options' parameter can be used to specify // The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing segmentation, by // the rotation to apply to the image before performing segmentation, by
// setting its 'rotation_degrees' field. Note that specifying a // setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported // region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned. // and will result in an invalid argument error being returned.
absl::StatusOr<ImageSegmenterResult> Segment( absl::StatusOr<ImageSegmenterResult> Segment(
mediapipe::Image image, mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);
// Performs image segmentation on the provided single image.
// Only use this method when the ImageSegmenter is created with the image
// running mode.
//
// The image can be of any size with format RGB or RGBA.
//
// The output width and height specify the size of the resulted mask.
//
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing segmentation, by
// setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
absl::StatusOr<ImageSegmenterResult> Segment(
mediapipe::Image image, int output_width, int output_height,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Performs image segmentation on the provided video frame. // Performs image segmentation on the provided video frame.
// Only use this method when the ImageSegmenter is created with the video // Only use this method when the ImageSegmenter is created with the video
// running mode. // running mode.
@ -121,16 +140,39 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// provide the video frame's timestamp (in milliseconds). The input timestamps // provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing. // must be monotonically increasing.
// //
// The optional 'image_processing_options' parameter can be used to specify // The output size is the same as the input image size.
// the rotation to apply to the image before performing segmentation, by //
// setting its 'rotation_degrees' field. Note that specifying a // The optional 'image_processing_options' parameter can be used
// region-of-interest using the 'region_of_interest' field is NOT supported // to specify the rotation to apply to the image before performing
// segmentation, by setting its 'rotation_degrees' field. Note that specifying
// a region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned. // and will result in an invalid argument error being returned.
absl::StatusOr<ImageSegmenterResult> SegmentForVideo( absl::StatusOr<ImageSegmenterResult> SegmentForVideo(
mediapipe::Image image, int64_t timestamp_ms, mediapipe::Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);
// Performs image segmentation on the provided video frame.
// Only use this method when the ImageSegmenter is created with the video
// running mode.
//
// The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing.
//
// The output width and height specify the size of the resulted mask.
//
// The optional 'image_processing_options' parameter can be used
// to specify the rotation to apply to the image before performing
// segmentation, by setting its 'rotation_degrees' field. Note that specifying
// a region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
absl::StatusOr<ImageSegmenterResult> SegmentForVideo(
mediapipe::Image image, int output_width, int output_height,
int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Sends live image data to perform image segmentation, and the results will // Sends live image data to perform image segmentation, and the results will
// be available via the "result_callback" provided in the // be available via the "result_callback" provided in the
// ImageSegmenterOptions. Only use this method when the ImageSegmenter is // ImageSegmenterOptions. Only use this method when the ImageSegmenter is
@ -141,6 +183,8 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// sent to the image segmenter. The input timestamps must be monotonically // sent to the image segmenter. The input timestamps must be monotonically
// increasing. // increasing.
// //
// The output size is the same as the input image size.
//
// The optional 'image_processing_options' parameter can be used to specify // The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing segmentation, by // the rotation to apply to the image before performing segmentation, by
// setting its 'rotation_degrees' field. Note that specifying a // setting its 'rotation_degrees' field. Note that specifying a
@ -158,6 +202,36 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
std::optional<core::ImageProcessingOptions> std::optional<core::ImageProcessingOptions>
image_processing_options = std::nullopt); image_processing_options = std::nullopt);
// Sends live image data to perform image segmentation, and the results will
// be available via the "result_callback" provided in the
// ImageSegmenterOptions. Only use this method when the ImageSegmenter is
// created with the live stream running mode.
//
// The image can be of any size with format RGB or RGBA. It's required to
// provide a timestamp (in milliseconds) to indicate when the input image is
// sent to the image segmenter. The input timestamps must be monotonically
// increasing.
//
// The output width and height specify the size of the resulted mask.
//
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing segmentation, by
// setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
//
// The "result_callback" prvoides
// - An ImageSegmenterResult.
// - The const reference to the corresponding input image that the image
// segmentation runs on. Note that the const reference to the image will
// no longer be valid when the callback returns. To access the image data
// outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds.
absl::Status SegmentAsync(mediapipe::Image image, int output_width,
int output_height, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions>
image_processing_options = std::nullopt);
// Shuts down the ImageSegmenter when all works are done. // Shuts down the ImageSegmenter when all works are done.
absl::Status Close() { return runner_->Close(); } absl::Status Close() { return runner_->Close(); }

View File

@ -82,6 +82,7 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU";
constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
constexpr char kSizeTag[] = "SIZE";
constexpr char kQualityScoresTag[] = "QUALITY_SCORES"; constexpr char kQualityScoresTag[] = "QUALITY_SCORES";
constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA"; constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA";
@ -356,6 +357,9 @@ absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
// Describes image rotation and region of image to perform detection // Describes image rotation and region of image to perform detection
// on. // on.
// @Optional: rect covering the whole image is used if not specified. // @Optional: rect covering the whole image is used if not specified.
// OUTPUT_SIZE - std::pair<int, int> @Optional
// The output size of the mask, in width and height. If not specified, the
// output size of the input image is used.
// //
// Outputs: // Outputs:
// CONFIDENCE_MASK - mediapipe::Image @Multiple // CONFIDENCE_MASK - mediapipe::Image @Multiple
@ -400,11 +404,16 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
if (!options.segmenter_options().has_output_type()) { if (!options.segmenter_options().has_output_type()) {
MP_RETURN_IF_ERROR(SanityCheck(sc)); MP_RETURN_IF_ERROR(SanityCheck(sc));
} }
std::optional<Source<std::pair<int, int>>> output_size;
if (HasInput(sc->OriginalNode(), kOutputSizeTag)) {
output_size = graph.In(kOutputSizeTag).Cast<std::pair<int, int>>();
}
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_streams, auto output_streams,
BuildSegmentationTask( BuildSegmentationTask(
options, *model_resources, graph[Input<Image>(kImageTag)], options, *model_resources, graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph)); graph[Input<NormalizedRect>::Optional(kNormRectTag)], output_size,
graph));
// TODO: remove deprecated output type support. // TODO: remove deprecated output type support.
if (options.segmenter_options().has_output_type()) { if (options.segmenter_options().has_output_type()) {
@ -469,7 +478,8 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask( absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask(
const ImageSegmenterGraphOptions& task_options, const ImageSegmenterGraphOptions& task_options,
const core::ModelResources& model_resources, Source<Image> image_in, const core::ModelResources& model_resources, Source<Image> image_in,
Source<NormalizedRect> norm_rect_in, Graph& graph) { Source<NormalizedRect> norm_rect_in,
std::optional<Source<std::pair<int, int>>> output_size, Graph& graph) {
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
// Adds preprocessing calculators and connects them to the graph input image // Adds preprocessing calculators and connects them to the graph input image
@ -514,10 +524,14 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
image_and_tensors.tensors >> inference.In(kTensorsTag); image_and_tensors.tensors >> inference.In(kTensorsTag);
inference.Out(kTensorsTag) >> tensor_to_images.In(kTensorsTag); inference.Out(kTensorsTag) >> tensor_to_images.In(kTensorsTag);
if (output_size.has_value()) {
*output_size >> tensor_to_images.In(kOutputSizeTag);
} else {
// Adds image property calculator for output size. // Adds image property calculator for output size.
auto& image_properties = graph.AddNode("ImagePropertiesCalculator"); auto& image_properties = graph.AddNode("ImagePropertiesCalculator");
image_in >> image_properties.In("IMAGE"); image_in >> image_properties.In(kImageTag);
image_properties.Out("SIZE") >> tensor_to_images.In(kOutputSizeTag); image_properties.Out(kSizeTag) >> tensor_to_images.In(kOutputSizeTag);
}
// Exports multiple segmented masks. // Exports multiple segmented masks.
// TODO: remove deprecated output type support. // TODO: remove deprecated output type support.

View File

@ -57,6 +57,7 @@ mediapipe_files(srcs = [
"hand_landmarker.task", "hand_landmarker.task",
"left_hands.jpg", "left_hands.jpg",
"left_hands_rotated.jpg", "left_hands_rotated.jpg",
"leopard_bg_removal_result_512x512.png",
"mobilenet_v1_0.25_192_quantized_1_default_1.tflite", "mobilenet_v1_0.25_192_quantized_1_default_1.tflite",
"mobilenet_v1_0.25_224_1_default_1.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite",
"mobilenet_v1_0.25_224_1_metadata_1.tflite", "mobilenet_v1_0.25_224_1_metadata_1.tflite",
@ -136,6 +137,7 @@ filegroup(
"hand_landmark_lite.tflite", "hand_landmark_lite.tflite",
"left_hands.jpg", "left_hands.jpg",
"left_hands_rotated.jpg", "left_hands_rotated.jpg",
"leopard_bg_removal_result_512x512.png",
"mozart_square.jpg", "mozart_square.jpg",
"multi_objects.jpg", "multi_objects.jpg",
"multi_objects_rotated.jpg", "multi_objects_rotated.jpg",

View File

@ -646,6 +646,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/left_hands_rotated.jpg?generation=1666037068103465"], urls = ["https://storage.googleapis.com/mediapipe-assets/left_hands_rotated.jpg?generation=1666037068103465"],
) )
http_file(
name = "com_google_mediapipe_leopard_bg_removal_result_512x512_png",
sha256 = "30be22e89fdd1d7b985294498ec67509b0caa1ca941fe291fa25f43a3873e4dd",
urls = ["https://storage.googleapis.com/mediapipe-assets/leopard_bg_removal_result_512x512.png?generation=1690239134617707"],
)
http_file( http_file(
name = "com_google_mediapipe_leopard_bg_removal_result_png", name = "com_google_mediapipe_leopard_bg_removal_result_png",
sha256 = "afd33f2058fd58d189cda86ec931647741a6139970c9bcbc637cdd151ec657c5", sha256 = "afd33f2058fd58d189cda86ec931647741a6139970c9bcbc637cdd151ec657c5",