From ae4b2ae577e47fe503f252cc79041819da885224 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 12 Oct 2022 11:33:11 -0700 Subject: [PATCH] Add support for input image rotation in ImageClassifier. PiperOrigin-RevId: 480676070 --- .../image_classifier/image_classifier.cc | 45 +++--- .../image_classifier/image_classifier.h | 55 ++++++-- .../image_classifier/image_classifier_test.cc | 131 +++++++++++++++--- mediapipe/tasks/testdata/vision/BUILD | 4 + third_party/external_files.bzl | 12 ++ 5 files changed, 194 insertions(+), 53 deletions(-) diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc index 0338b2ee2..f3dcdd07d 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc @@ -59,14 +59,24 @@ constexpr int kMicroSecondsPerMilliSecond = 1000; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::PacketMap; -// Builds a NormalizedRect covering the entire image. -NormalizedRect BuildFullImageNormRect() { - NormalizedRect norm_rect; - norm_rect.set_x_center(0.5); - norm_rect.set_y_center(0.5); - norm_rect.set_width(1); - norm_rect.set_height(1); - return norm_rect; +// Returns a NormalizedRect covering the full image if input is not present. +// Otherwise, makes sure the x_center, y_center, width and height are set in +// case only a rotation was provided in the input. +NormalizedRect FillNormalizedRect( + std::optional normalized_rect) { + NormalizedRect result; + if (normalized_rect.has_value()) { + result = *normalized_rect; + } + bool has_coordinates = result.has_x_center() || result.has_y_center() || + result.has_width() || result.has_height(); + if (!has_coordinates) { + result.set_x_center(0.5); + result.set_y_center(0.5); + result.set_width(1); + result.set_height(1); + } + return result; } // Creates a MediaPipe graph config that contains a subgraph node of @@ -154,15 +164,14 @@ absl::StatusOr> ImageClassifier::Create( } absl::StatusOr ImageClassifier::Classify( - Image image, std::optional roi) { + Image image, std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = - roi.has_value() ? roi.value() : BuildFullImageNormRect(); + NormalizedRect norm_rect = FillNormalizedRect(image_processing_options); ASSIGN_OR_RETURN( auto output_packets, ProcessImageData( @@ -173,15 +182,15 @@ absl::StatusOr ImageClassifier::Classify( } absl::StatusOr ImageClassifier::ClassifyForVideo( - Image image, int64 timestamp_ms, std::optional roi) { + Image image, int64 timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = - roi.has_value() ? roi.value() : BuildFullImageNormRect(); + NormalizedRect norm_rect = FillNormalizedRect(image_processing_options); ASSIGN_OR_RETURN( auto output_packets, ProcessVideoData( @@ -195,16 +204,16 @@ absl::StatusOr ImageClassifier::ClassifyForVideo( .Get(); } -absl::Status ImageClassifier::ClassifyAsync(Image image, int64 timestamp_ms, - std::optional roi) { +absl::Status ImageClassifier::ClassifyAsync( + Image image, int64 timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = - roi.has_value() ? roi.value() : BuildFullImageNormRect(); + NormalizedRect norm_rect = FillNormalizedRect(image_processing_options); return SendLiveStreamData( {{kImageInStreamName, MakePacket(std::move(image)) diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h index 24f36017a..5dff06cc7 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h @@ -105,9 +105,18 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { static absl::StatusOr> Create( std::unique_ptr options); - // Performs image classification on the provided single image. Classification - // is performed on the region of interest specified by the `roi` argument if - // provided, or on the entire image otherwise. + // Performs image classification on the provided single image. + // + // The optional 'image_processing_options' parameter can be used to specify: + // - the rotation to apply to the image before performing classification, by + // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° + // anti-clockwise rotation). + // and/or + // - the region-of-interest on which to perform classification, by setting its + // 'x_center', 'y_center', 'width' and 'height' fields. If none of these is + // set, they will automatically be set to cover the full image. + // If both are specified, the crop around the region-of-interest is extracted + // first, then the specified rotation is applied to the crop. // // Only use this method when the ImageClassifier is created with the image // running mode. @@ -117,11 +126,21 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // YUVToImageCalculator is integrated. absl::StatusOr Classify( mediapipe::Image image, - std::optional roi = std::nullopt); + std::optional image_processing_options = + std::nullopt); - // Performs image classification on the provided video frame. Classification - // is performed on the region of interested specified by the `roi` argument if - // provided, or on the entire image otherwise. + // Performs image classification on the provided video frame. + // + // The optional 'image_processing_options' parameter can be used to specify: + // - the rotation to apply to the image before performing classification, by + // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° + // anti-clockwise rotation). + // and/or + // - the region-of-interest on which to perform classification, by setting its + // 'x_center', 'y_center', 'width' and 'height' fields. If none of these is + // set, they will automatically be set to cover the full image. + // If both are specified, the crop around the region-of-interest is extracted + // first, then the specified rotation is applied to the crop. // // Only use this method when the ImageClassifier is created with the video // running mode. @@ -131,12 +150,22 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // must be monotonically increasing. absl::StatusOr ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms, - std::optional roi = std::nullopt); + std::optional + image_processing_options = std::nullopt); // Sends live image data to image classification, and the results will be // available via the "result_callback" provided in the ImageClassifierOptions. - // Classification is performed on the region of interested specified by the - // `roi` argument if provided, or on the entire image otherwise. + // + // The optional 'image_processing_options' parameter can be used to specify: + // - the rotation to apply to the image before performing classification, by + // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° + // anti-clockwise rotation). + // and/or + // - the region-of-interest on which to perform classification, by setting its + // 'x_center', 'y_center', 'width' and 'height' fields. If none of these is + // set, they will automatically be set to cover the full image. + // If both are specified, the crop around the region-of-interest is extracted + // first, then the specified rotation is applied to the crop. // // Only use this method when the ImageClassifier is created with the live // stream running mode. @@ -153,9 +182,9 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // longer be valid when the callback returns. To access the image data // outside of the callback, callers need to make a copy of the image. // - The input timestamp in milliseconds. - absl::Status ClassifyAsync( - mediapipe::Image image, int64 timestamp_ms, - std::optional roi = std::nullopt); + absl::Status ClassifyAsync(mediapipe::Image image, int64 timestamp_ms, + std::optional + image_processing_options = std::nullopt); // TODO: add Classify() variants taking a region of interest as // additional argument. diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc index dcb2fddfc..55830e520 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/image_classifier/image_classifier.h" +#include #include #include #include @@ -546,18 +547,102 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { options->classifier_options.max_results = 1; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); - // NormalizedRect around the soccer ball. - NormalizedRect roi; - roi.set_x_center(0.532); - roi.set_y_center(0.521); - roi.set_width(0.164); - roi.set_height(0.427); + // Crop around the soccer ball. + NormalizedRect image_processing_options; + image_processing_options.set_x_center(0.532); + image_processing_options.set_y_center(0.521); + image_processing_options.set_width(0.164); + image_processing_options.set_height(0.427); - MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image, roi)); + MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( + image, image_processing_options)); ExpectApproximatelyEqual(results, GenerateSoccerBallResults(0)); } +TEST_F(ImageModeTest, SucceedsWithRotation) { + MP_ASSERT_OK_AND_ASSIGN(Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "burger_rotated.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + options->classifier_options.max_results = 3; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + + // Specify a 90° anti-clockwise rotation. + NormalizedRect image_processing_options; + image_processing_options.set_rotation(M_PI / 2.0); + + MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( + image, image_processing_options)); + + // Results differ slightly from the non-rotated image, but that's expected + // as models are very sensitive to the slightest numerical differences + // introduced by the rotation and JPG encoding. + ExpectApproximatelyEqual(results, ParseTextProtoOrDie( + R"pb(classifications { + entries { + categories { + index: 934 + score: 0.6371766 + category_name: "cheeseburger" + } + categories { + index: 963 + score: 0.049443405 + category_name: "meat loaf" + } + categories { + index: 925 + score: 0.047918003 + category_name: "guacamole" + } + timestamp_ms: 0 + } + head_index: 0 + head_name: "probability" + })pb")); +} + +TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "multi_objects_rotated.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + options->classifier_options.max_results = 1; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + // Crop around the chair, with 90° anti-clockwise rotation. + NormalizedRect image_processing_options; + image_processing_options.set_x_center(0.2821); + image_processing_options.set_y_center(0.2406); + image_processing_options.set_width(0.5642); + image_processing_options.set_height(0.1286); + image_processing_options.set_rotation(M_PI / 2.0); + + MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( + image, image_processing_options)); + + ExpectApproximatelyEqual(results, + ParseTextProtoOrDie( + R"pb(classifications { + entries { + categories { + index: 560 + score: 0.6800408 + category_name: "folding chair" + } + timestamp_ms: 0 + } + head_index: 0 + head_name: "probability" + })pb")); +} + class VideoModeTest : public tflite_shims::testing::Test {}; TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { @@ -646,16 +731,17 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) { options->classifier_options.max_results = 1; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); - // NormalizedRect around the soccer ball. - NormalizedRect roi; - roi.set_x_center(0.532); - roi.set_y_center(0.521); - roi.set_width(0.164); - roi.set_height(0.427); + // Crop around the soccer ball. + NormalizedRect image_processing_options; + image_processing_options.set_x_center(0.532); + image_processing_options.set_y_center(0.521); + image_processing_options.set_width(0.164); + image_processing_options.set_height(0.427); for (int i = 0; i < iterations; ++i) { - MP_ASSERT_OK_AND_ASSIGN(auto results, - image_classifier->ClassifyForVideo(image, i, roi)); + MP_ASSERT_OK_AND_ASSIGN( + auto results, + image_classifier->ClassifyForVideo(image, i, image_processing_options)); ExpectApproximatelyEqual(results, GenerateSoccerBallResults(i)); } MP_ASSERT_OK(image_classifier->Close()); @@ -790,15 +876,16 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) { }; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); - // NormalizedRect around the soccer ball. - NormalizedRect roi; - roi.set_x_center(0.532); - roi.set_y_center(0.521); - roi.set_width(0.164); - roi.set_height(0.427); + // Crop around the soccer ball. + NormalizedRect image_processing_options; + image_processing_options.set_x_center(0.532); + image_processing_options.set_y_center(0.521); + image_processing_options.set_width(0.164); + image_processing_options.set_height(0.427); for (int i = 0; i < iterations; ++i) { - MP_ASSERT_OK(image_classifier->ClassifyAsync(image, i, roi)); + MP_ASSERT_OK( + image_classifier->ClassifyAsync(image, i, image_processing_options)); } MP_ASSERT_OK(image_classifier->Close()); diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index 290b29016..8b205cc49 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -25,6 +25,7 @@ package( mediapipe_files(srcs = [ "burger.jpg", "burger_crop.jpg", + "burger_rotated.jpg", "cat.jpg", "cat_mask.jpg", "cats_and_dogs.jpg", @@ -46,6 +47,7 @@ mediapipe_files(srcs = [ "mobilenet_v3_small_100_224_embedder.tflite", "mozart_square.jpg", "multi_objects.jpg", + "multi_objects_rotated.jpg", "palm_detection_full.tflite", "pointing_up.jpg", "right_hands.jpg", @@ -72,6 +74,7 @@ filegroup( srcs = [ "burger.jpg", "burger_crop.jpg", + "burger_rotated.jpg", "cat.jpg", "cat_mask.jpg", "cats_and_dogs.jpg", @@ -81,6 +84,7 @@ filegroup( "left_hands.jpg", "mozart_square.jpg", "multi_objects.jpg", + "multi_objects_rotated.jpg", "pointing_up.jpg", "right_hands.jpg", "segmentation_golden_rotation0.png", diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 254692856..24fb15446 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -58,6 +58,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/burger.jpg?generation=1661875667922678"], ) + http_file( + name = "com_google_mediapipe_burger_rotated_jpg", + sha256 = "b7bb5e59ef778f3ce6b3e616c511908a53d513b83a56aae58b7453e14b0a4b2a", + urls = ["https://storage.googleapis.com/mediapipe-assets/burger_rotated.jpg?generation=1665065843774448"], + ) + http_file( name = "com_google_mediapipe_cat_jpg", sha256 = "2533197401eebe9410ea4d063f86c43fbd2666f3e8165a38aca155c0d09c21be", @@ -436,6 +442,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/multi_objects.jpg?generation=1663251779213308"], ) + http_file( + name = "com_google_mediapipe_multi_objects_rotated_jpg", + sha256 = "175f6c572ffbab6554e382fd5056d09720eef931ccc4ed79481bdc47a8443911", + urls = ["https://storage.googleapis.com/mediapipe-assets/multi_objects_rotated.jpg?generation=1665065847969523"], + ) + http_file( name = "com_google_mediapipe_object_detection_3d_camera_tflite", sha256 = "f66e92e81ed3f4698f74d565a7668e016e2288ea92fb42938e33b778bd1e110d",