No public description

PiperOrigin-RevId: 559275983
This commit is contained in:
MediaPipe Team 2023-08-22 18:01:46 -07:00 committed by Copybara-Service
parent 1dfdeb6ebb
commit 90781669cb
3 changed files with 73 additions and 59 deletions

View File

@ -37,6 +37,7 @@ cc_library(
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/vision/core:base_vision_task_api", "//mediapipe/tasks/cc/vision/core:base_vision_task_api",

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
@ -217,13 +218,16 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment( absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
mediapipe::Image image, mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
return Segment(image, image.width(), image.height(), return Segment(image, {
std::move(image_processing_options)); /*output_width=*/image.width(),
/*output_height=*/image.height(),
std::move(image_processing_options),
});
} }
absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment( absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
mediapipe::Image image, int output_width, int output_height, mediapipe::Image image, SegmentationOptions segmentation_options) {
std::optional<core::ImageProcessingOptions> image_processing_options) { MP_RETURN_IF_ERROR(ValidateSegmentationOptions(segmentation_options));
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
@ -231,8 +235,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
MediaPipeTasksStatus::kRunnerUnexpectedInputError); MediaPipeTasksStatus::kRunnerUnexpectedInputError);
} }
ASSIGN_OR_RETURN(NormalizedRect norm_rect, ASSIGN_OR_RETURN(NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, image, ConvertToNormalizedRect(
/*roi_allowed=*/false)); segmentation_options.image_processing_options, image,
/*roi_allowed=*/false));
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_packets, auto output_packets,
ProcessImageData( ProcessImageData(
@ -241,7 +246,8 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
MakePacket<NormalizedRect>(std::move(norm_rect))}, MakePacket<NormalizedRect>(std::move(norm_rect))},
{kOutputSizeStreamName, {kOutputSizeStreamName,
MakePacket<std::pair<int, int>>( MakePacket<std::pair<int, int>>(
std::make_pair(output_width, output_height))}})); std::make_pair(segmentation_options.output_width,
segmentation_options.output_height))}}));
std::optional<std::vector<Image>> confidence_masks; std::optional<std::vector<Image>> confidence_masks;
if (output_confidence_masks_) { if (output_confidence_masks_) {
confidence_masks = confidence_masks =
@ -259,14 +265,18 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo( absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
mediapipe::Image image, int64_t timestamp_ms, mediapipe::Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
return SegmentForVideo(image, image.width(), image.height(), timestamp_ms, return SegmentForVideo(image, timestamp_ms,
image_processing_options); {
/*output_width=*/image.width(),
/*output_height=*/image.height(),
std::move(image_processing_options),
});
} }
absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo( absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
mediapipe::Image image, int output_width, int output_height, mediapipe::Image image, int64_t timestamp_ms,
int64_t timestamp_ms, SegmentationOptions segmentation_options) {
std::optional<core::ImageProcessingOptions> image_processing_options) { MP_RETURN_IF_ERROR(ValidateSegmentationOptions(segmentation_options));
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
@ -274,8 +284,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
MediaPipeTasksStatus::kRunnerUnexpectedInputError); MediaPipeTasksStatus::kRunnerUnexpectedInputError);
} }
ASSIGN_OR_RETURN(NormalizedRect norm_rect, ASSIGN_OR_RETURN(NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, image, ConvertToNormalizedRect(
/*roi_allowed=*/false)); segmentation_options.image_processing_options, image,
/*roi_allowed=*/false));
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_packets, auto output_packets,
ProcessVideoData( ProcessVideoData(
@ -287,7 +298,8 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
{kOutputSizeStreamName, {kOutputSizeStreamName,
MakePacket<std::pair<int, int>>( MakePacket<std::pair<int, int>>(
std::make_pair(output_width, output_height)) std::make_pair(segmentation_options.output_width,
segmentation_options.output_height))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
std::optional<std::vector<Image>> confidence_masks; std::optional<std::vector<Image>> confidence_masks;
if (output_confidence_masks_) { if (output_confidence_masks_) {
@ -306,13 +318,18 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
absl::Status ImageSegmenter::SegmentAsync( absl::Status ImageSegmenter::SegmentAsync(
Image image, int64_t timestamp_ms, Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
return SegmentAsync(image, image.width(), image.height(), timestamp_ms, return SegmentAsync(image, timestamp_ms,
image_processing_options); {
/*output_width=*/image.width(),
/*output_height=*/image.height(),
std::move(image_processing_options),
});
} }
absl::Status ImageSegmenter::SegmentAsync( absl::Status ImageSegmenter::SegmentAsync(
Image image, int output_width, int output_height, int64_t timestamp_ms, Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) { SegmentationOptions segmentation_options) {
MP_RETURN_IF_ERROR(ValidateSegmentationOptions(segmentation_options));
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
@ -320,8 +337,9 @@ absl::Status ImageSegmenter::SegmentAsync(
MediaPipeTasksStatus::kRunnerUnexpectedInputError); MediaPipeTasksStatus::kRunnerUnexpectedInputError);
} }
ASSIGN_OR_RETURN(NormalizedRect norm_rect, ASSIGN_OR_RETURN(NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, image, ConvertToNormalizedRect(
/*roi_allowed=*/false)); segmentation_options.image_processing_options, image,
/*roi_allowed=*/false));
return SendLiveStreamData( return SendLiveStreamData(
{{kImageInStreamName, {{kImageInStreamName,
MakePacket<Image>(std::move(image)) MakePacket<Image>(std::move(image))
@ -331,7 +349,8 @@ absl::Status ImageSegmenter::SegmentAsync(
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
{kOutputSizeStreamName, {kOutputSizeStreamName,
MakePacket<std::pair<int, int>>( MakePacket<std::pair<int, int>>(
std::make_pair(output_width, output_height)) std::make_pair(segmentation_options.output_width,
segmentation_options.output_height))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
} }

View File

@ -67,6 +67,22 @@ struct ImageSegmenterOptions {
result_callback = nullptr; result_callback = nullptr;
}; };
// Options for configuring runtime behavior of ImageSegmenter.
//
// This is a plain aggregate: it can be brace-initialized positionally as
// {output_width, output_height, image_processing_options}.
struct SegmentationOptions {
  // The width of the output segmentation masks. Must be larger than 0.
  // Defaults to 0 so that a default-initialized instance holds a well-defined
  // (and deliberately invalid) value rather than an indeterminate one; the
  // segmentation calls reject it with an invalid argument error.
  int output_width = 0;

  // The height of the output segmentation masks. Must be larger than 0.
  // Defaults to 0 for the same reason as 'output_width'.
  int output_height = 0;

  // The optional 'image_processing_options' parameter can be used to specify
  // the rotation to apply to the image before performing segmentation, by
  // setting its 'rotation_degrees' field. Note that specifying a
  // region-of-interest using the 'region_of_interest' field is NOT supported
  // and will result in an invalid argument error being returned.
  std::optional<core::ImageProcessingOptions> image_processing_options;
};
// Performs segmentation on images. // Performs segmentation on images.
// //
// The API expects a TFLite model with mandatory TFLite Model Metadata. // The API expects a TFLite model with mandatory TFLite Model Metadata.
@ -119,18 +135,8 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// running mode. // running mode.
// //
// The image can be of any size with format RGB or RGBA. // The image can be of any size with format RGB or RGBA.
//
// The output width and height specify the size of the resulted mask.
//
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing segmentation, by
// setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
absl::StatusOr<ImageSegmenterResult> Segment( absl::StatusOr<ImageSegmenterResult> Segment(
mediapipe::Image image, int output_width, int output_height, mediapipe::Image image, SegmentationOptions segmentation_options);
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Performs image segmentation on the provided video frame. // Performs image segmentation on the provided video frame.
// Only use this method when the ImageSegmenter is created with the video // Only use this method when the ImageSegmenter is created with the video
@ -159,19 +165,9 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA. It's required to // The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps // provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing. // must be monotonically increasing.
//
// The output width and height specify the size of the resulted mask.
//
// The optional 'image_processing_options' parameter can be used
// to specify the rotation to apply to the image before performing
// segmentation, by setting its 'rotation_degrees' field. Note that specifying
// a region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
absl::StatusOr<ImageSegmenterResult> SegmentForVideo( absl::StatusOr<ImageSegmenterResult> SegmentForVideo(
mediapipe::Image image, int output_width, int output_height, mediapipe::Image image, int64_t timestamp_ms,
int64_t timestamp_ms, SegmentationOptions segmentation_options);
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Sends live image data to perform image segmentation, and the results will // Sends live image data to perform image segmentation, and the results will
// be available via the "result_callback" provided in the // be available via the "result_callback" provided in the
@ -191,7 +187,7 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// region-of-interest using the 'region_of_interest' field is NOT supported // region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned. // and will result in an invalid argument error being returned.
// //
// The "result_callback" prvoides // The "result_callback" provides
// - An ImageSegmenterResult. // - An ImageSegmenterResult.
// - The const reference to the corresponding input image that the image // - The const reference to the corresponding input image that the image
// segmentation runs on. Note that the const reference to the image will // segmentation runs on. Note that the const reference to the image will
@ -212,25 +208,15 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// sent to the image segmenter. The input timestamps must be monotonically // sent to the image segmenter. The input timestamps must be monotonically
// increasing. // increasing.
// //
// The output width and height specify the size of the resulted mask. // The "result_callback" provides
//
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing segmentation, by
// setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
//
// The "result_callback" prvoides
// - An ImageSegmenterResult. // - An ImageSegmenterResult.
// - The const reference to the corresponding input image that the image // - The const reference to the corresponding input image that the image
// segmentation runs on. Note that the const reference to the image will // segmentation runs on. Note that the const reference to the image will
// no longer be valid when the callback returns. To access the image data // no longer be valid when the callback returns. To access the image data
// outside of the callback, callers need to make a copy of the image. // outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds. // - The input timestamp in milliseconds.
absl::Status SegmentAsync(mediapipe::Image image, int output_width, absl::Status SegmentAsync(mediapipe::Image image, int64_t timestamp_ms,
int output_height, int64_t timestamp_ms, SegmentationOptions segmentation_options);
std::optional<core::ImageProcessingOptions>
image_processing_options = std::nullopt);
// Shuts down the ImageSegmenter when all works are done. // Shuts down the ImageSegmenter when all works are done.
absl::Status Close() { return runner_->Close(); } absl::Status Close() { return runner_->Close(); }
@ -248,6 +234,14 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
std::vector<std::string> labels_; std::vector<std::string> labels_;
bool output_confidence_masks_; bool output_confidence_masks_;
bool output_category_mask_; bool output_category_mask_;
// Checks the requested output mask size in 'options'.
//
// Returns OK when both 'output_width' and 'output_height' are strictly
// positive, and an invalid argument error otherwise.
absl::Status ValidateSegmentationOptions(const SegmentationOptions& options) {
  const bool has_positive_size =
      options.output_width > 0 && options.output_height > 0;
  if (has_positive_size) {
    return absl::OkStatus();
  }
  return absl::InvalidArgumentError(
      "Both output_width and output_height must be larger than 0.");
}
}; };
} // namespace image_segmenter } // namespace image_segmenter