Add support for input image rotation in ImageClassifier.

PiperOrigin-RevId: 480676070
This commit is contained in:
MediaPipe Team 2022-10-12 11:33:11 -07:00 committed by Copybara-Service
parent 51a7606083
commit ae4b2ae577
5 changed files with 194 additions and 53 deletions

View File

@ -59,14 +59,24 @@ constexpr int kMicroSecondsPerMilliSecond = 1000;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::PacketMap; using ::mediapipe::tasks::core::PacketMap;
// Builds a NormalizedRect covering the entire image. // Returns a NormalizedRect covering the full image if input is not present.
NormalizedRect BuildFullImageNormRect() { // Otherwise, makes sure the x_center, y_center, width and height are set in
NormalizedRect norm_rect; // case only a rotation was provided in the input.
norm_rect.set_x_center(0.5); NormalizedRect FillNormalizedRect(
norm_rect.set_y_center(0.5); std::optional<NormalizedRect> normalized_rect) {
norm_rect.set_width(1); NormalizedRect result;
norm_rect.set_height(1); if (normalized_rect.has_value()) {
return norm_rect; result = *normalized_rect;
}
bool has_coordinates = result.has_x_center() || result.has_y_center() ||
result.has_width() || result.has_height();
if (!has_coordinates) {
result.set_x_center(0.5);
result.set_y_center(0.5);
result.set_width(1);
result.set_height(1);
}
return result;
} }
// Creates a MediaPipe graph config that contains a subgraph node of // Creates a MediaPipe graph config that contains a subgraph node of
@ -154,15 +164,14 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create(
} }
absl::StatusOr<ClassificationResult> ImageClassifier::Classify( absl::StatusOr<ClassificationResult> ImageClassifier::Classify(
Image image, std::optional<NormalizedRect> roi) { Image image, std::optional<NormalizedRect> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
"GPU input images are currently not supported.", "GPU input images are currently not supported.",
MediaPipeTasksStatus::kRunnerUnexpectedInputError); MediaPipeTasksStatus::kRunnerUnexpectedInputError);
} }
NormalizedRect norm_rect = NormalizedRect norm_rect = FillNormalizedRect(image_processing_options);
roi.has_value() ? roi.value() : BuildFullImageNormRect();
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_packets, auto output_packets,
ProcessImageData( ProcessImageData(
@ -173,15 +182,15 @@ absl::StatusOr<ClassificationResult> ImageClassifier::Classify(
} }
absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo( absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo(
Image image, int64 timestamp_ms, std::optional<NormalizedRect> roi) { Image image, int64 timestamp_ms,
std::optional<NormalizedRect> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
"GPU input images are currently not supported.", "GPU input images are currently not supported.",
MediaPipeTasksStatus::kRunnerUnexpectedInputError); MediaPipeTasksStatus::kRunnerUnexpectedInputError);
} }
NormalizedRect norm_rect = NormalizedRect norm_rect = FillNormalizedRect(image_processing_options);
roi.has_value() ? roi.value() : BuildFullImageNormRect();
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_packets, auto output_packets,
ProcessVideoData( ProcessVideoData(
@ -195,16 +204,16 @@ absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo(
.Get<ClassificationResult>(); .Get<ClassificationResult>();
} }
absl::Status ImageClassifier::ClassifyAsync(Image image, int64 timestamp_ms, absl::Status ImageClassifier::ClassifyAsync(
std::optional<NormalizedRect> roi) { Image image, int64 timestamp_ms,
std::optional<NormalizedRect> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
"GPU input images are currently not supported.", "GPU input images are currently not supported.",
MediaPipeTasksStatus::kRunnerUnexpectedInputError); MediaPipeTasksStatus::kRunnerUnexpectedInputError);
} }
NormalizedRect norm_rect = NormalizedRect norm_rect = FillNormalizedRect(image_processing_options);
roi.has_value() ? roi.value() : BuildFullImageNormRect();
return SendLiveStreamData( return SendLiveStreamData(
{{kImageInStreamName, {{kImageInStreamName,
MakePacket<Image>(std::move(image)) MakePacket<Image>(std::move(image))

View File

@ -105,9 +105,18 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
static absl::StatusOr<std::unique_ptr<ImageClassifier>> Create( static absl::StatusOr<std::unique_ptr<ImageClassifier>> Create(
std::unique_ptr<ImageClassifierOptions> options); std::unique_ptr<ImageClassifierOptions> options);
// Performs image classification on the provided single image. Classification // Performs image classification on the provided single image.
// is performed on the region of interest specified by the `roi` argument if //
// provided, or on the entire image otherwise. // The optional 'image_processing_options' parameter can be used to specify:
// - the rotation to apply to the image before performing classification, by
// setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
// anti-clockwise rotation).
// and/or
// - the region-of-interest on which to perform classification, by setting its
// 'x_center', 'y_center', 'width' and 'height' fields. If none of these is
// set, they will automatically be set to cover the full image.
// If both are specified, the crop around the region-of-interest is extracted
// first, then the specified rotation is applied to the crop.
// //
// Only use this method when the ImageClassifier is created with the image // Only use this method when the ImageClassifier is created with the image
// running mode. // running mode.
@ -117,11 +126,21 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// YUVToImageCalculator is integrated. // YUVToImageCalculator is integrated.
absl::StatusOr<components::containers::proto::ClassificationResult> Classify( absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
mediapipe::Image image, mediapipe::Image image,
std::optional<mediapipe::NormalizedRect> roi = std::nullopt); std::optional<mediapipe::NormalizedRect> image_processing_options =
std::nullopt);
// Performs image classification on the provided video frame. Classification // Performs image classification on the provided video frame.
// is performed on the region of interested specified by the `roi` argument if //
// provided, or on the entire image otherwise. // The optional 'image_processing_options' parameter can be used to specify:
// - the rotation to apply to the image before performing classification, by
// setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
// anti-clockwise rotation).
// and/or
// - the region-of-interest on which to perform classification, by setting its
// 'x_center', 'y_center', 'width' and 'height' fields. If none of these is
// set, they will automatically be set to cover the full image.
// If both are specified, the crop around the region-of-interest is extracted
// first, then the specified rotation is applied to the crop.
// //
// Only use this method when the ImageClassifier is created with the video // Only use this method when the ImageClassifier is created with the video
// running mode. // running mode.
@ -131,12 +150,22 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// must be monotonically increasing. // must be monotonically increasing.
absl::StatusOr<components::containers::proto::ClassificationResult> absl::StatusOr<components::containers::proto::ClassificationResult>
ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms, ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> roi = std::nullopt); std::optional<mediapipe::NormalizedRect>
image_processing_options = std::nullopt);
// Sends live image data to image classification, and the results will be // Sends live image data to image classification, and the results will be
// available via the "result_callback" provided in the ImageClassifierOptions. // available via the "result_callback" provided in the ImageClassifierOptions.
// Classification is performed on the region of interested specified by the //
// `roi` argument if provided, or on the entire image otherwise. // The optional 'image_processing_options' parameter can be used to specify:
// - the rotation to apply to the image before performing classification, by
// setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
// anti-clockwise rotation).
// and/or
// - the region-of-interest on which to perform classification, by setting its
// 'x_center', 'y_center', 'width' and 'height' fields. If none of these is
// set, they will automatically be set to cover the full image.
// If both are specified, the crop around the region-of-interest is extracted
// first, then the specified rotation is applied to the crop.
// //
// Only use this method when the ImageClassifier is created with the live // Only use this method when the ImageClassifier is created with the live
// stream running mode. // stream running mode.
@ -153,9 +182,9 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// longer be valid when the callback returns. To access the image data // longer be valid when the callback returns. To access the image data
// outside of the callback, callers need to make a copy of the image. // outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds. // - The input timestamp in milliseconds.
absl::Status ClassifyAsync( absl::Status ClassifyAsync(mediapipe::Image image, int64 timestamp_ms,
mediapipe::Image image, int64 timestamp_ms, std::optional<mediapipe::NormalizedRect>
std::optional<mediapipe::NormalizedRect> roi = std::nullopt); image_processing_options = std::nullopt);
// TODO: add Classify() variants taking a region of interest as // TODO: add Classify() variants taking a region of interest as
// additional argument. // additional argument.

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/image_classifier/image_classifier.h" #include "mediapipe/tasks/cc/vision/image_classifier/image_classifier.h"
#include <cmath>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <string> #include <string>
@ -546,18 +547,102 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
options->classifier_options.max_results = 1; options->classifier_options.max_results = 1;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options))); ImageClassifier::Create(std::move(options)));
// NormalizedRect around the soccer ball. // Crop around the soccer ball.
NormalizedRect roi; NormalizedRect image_processing_options;
roi.set_x_center(0.532); image_processing_options.set_x_center(0.532);
roi.set_y_center(0.521); image_processing_options.set_y_center(0.521);
roi.set_width(0.164); image_processing_options.set_width(0.164);
roi.set_height(0.427); image_processing_options.set_height(0.427);
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image, roi)); MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
image, image_processing_options));
ExpectApproximatelyEqual(results, GenerateSoccerBallResults(0)); ExpectApproximatelyEqual(results, GenerateSoccerBallResults(0));
} }
TEST_F(ImageModeTest, SucceedsWithRotation) {
MP_ASSERT_OK_AND_ASSIGN(Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
"burger_rotated.jpg")));
auto options = std::make_unique<ImageClassifierOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata);
options->classifier_options.max_results = 3;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options)));
// Specify a 90° anti-clockwise rotation.
NormalizedRect image_processing_options;
image_processing_options.set_rotation(M_PI / 2.0);
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
image, image_processing_options));
// Results differ slightly from the non-rotated image, but that's expected
// as models are very sensitive to the slightest numerical differences
// introduced by the rotation and JPG encoding.
ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
R"pb(classifications {
entries {
categories {
index: 934
score: 0.6371766
category_name: "cheeseburger"
}
categories {
index: 963
score: 0.049443405
category_name: "meat loaf"
}
categories {
index: 925
score: 0.047918003
category_name: "guacamole"
}
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
}
TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
MP_ASSERT_OK_AND_ASSIGN(
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
"multi_objects_rotated.jpg")));
auto options = std::make_unique<ImageClassifierOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata);
options->classifier_options.max_results = 1;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options)));
// Crop around the chair, with 90° anti-clockwise rotation.
NormalizedRect image_processing_options;
image_processing_options.set_x_center(0.2821);
image_processing_options.set_y_center(0.2406);
image_processing_options.set_width(0.5642);
image_processing_options.set_height(0.1286);
image_processing_options.set_rotation(M_PI / 2.0);
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
image, image_processing_options));
ExpectApproximatelyEqual(results,
ParseTextProtoOrDie<ClassificationResult>(
R"pb(classifications {
entries {
categories {
index: 560
score: 0.6800408
category_name: "folding chair"
}
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
}
class VideoModeTest : public tflite_shims::testing::Test {}; class VideoModeTest : public tflite_shims::testing::Test {};
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
@ -646,16 +731,17 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) {
options->classifier_options.max_results = 1; options->classifier_options.max_results = 1;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options))); ImageClassifier::Create(std::move(options)));
// NormalizedRect around the soccer ball. // Crop around the soccer ball.
NormalizedRect roi; NormalizedRect image_processing_options;
roi.set_x_center(0.532); image_processing_options.set_x_center(0.532);
roi.set_y_center(0.521); image_processing_options.set_y_center(0.521);
roi.set_width(0.164); image_processing_options.set_width(0.164);
roi.set_height(0.427); image_processing_options.set_height(0.427);
for (int i = 0; i < iterations; ++i) { for (int i = 0; i < iterations; ++i) {
MP_ASSERT_OK_AND_ASSIGN(auto results, MP_ASSERT_OK_AND_ASSIGN(
image_classifier->ClassifyForVideo(image, i, roi)); auto results,
image_classifier->ClassifyForVideo(image, i, image_processing_options));
ExpectApproximatelyEqual(results, GenerateSoccerBallResults(i)); ExpectApproximatelyEqual(results, GenerateSoccerBallResults(i));
} }
MP_ASSERT_OK(image_classifier->Close()); MP_ASSERT_OK(image_classifier->Close());
@ -790,15 +876,16 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
}; };
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options))); ImageClassifier::Create(std::move(options)));
// NormalizedRect around the soccer ball. // Crop around the soccer ball.
NormalizedRect roi; NormalizedRect image_processing_options;
roi.set_x_center(0.532); image_processing_options.set_x_center(0.532);
roi.set_y_center(0.521); image_processing_options.set_y_center(0.521);
roi.set_width(0.164); image_processing_options.set_width(0.164);
roi.set_height(0.427); image_processing_options.set_height(0.427);
for (int i = 0; i < iterations; ++i) { for (int i = 0; i < iterations; ++i) {
MP_ASSERT_OK(image_classifier->ClassifyAsync(image, i, roi)); MP_ASSERT_OK(
image_classifier->ClassifyAsync(image, i, image_processing_options));
} }
MP_ASSERT_OK(image_classifier->Close()); MP_ASSERT_OK(image_classifier->Close());

View File

@ -25,6 +25,7 @@ package(
mediapipe_files(srcs = [ mediapipe_files(srcs = [
"burger.jpg", "burger.jpg",
"burger_crop.jpg", "burger_crop.jpg",
"burger_rotated.jpg",
"cat.jpg", "cat.jpg",
"cat_mask.jpg", "cat_mask.jpg",
"cats_and_dogs.jpg", "cats_and_dogs.jpg",
@ -46,6 +47,7 @@ mediapipe_files(srcs = [
"mobilenet_v3_small_100_224_embedder.tflite", "mobilenet_v3_small_100_224_embedder.tflite",
"mozart_square.jpg", "mozart_square.jpg",
"multi_objects.jpg", "multi_objects.jpg",
"multi_objects_rotated.jpg",
"palm_detection_full.tflite", "palm_detection_full.tflite",
"pointing_up.jpg", "pointing_up.jpg",
"right_hands.jpg", "right_hands.jpg",
@ -72,6 +74,7 @@ filegroup(
srcs = [ srcs = [
"burger.jpg", "burger.jpg",
"burger_crop.jpg", "burger_crop.jpg",
"burger_rotated.jpg",
"cat.jpg", "cat.jpg",
"cat_mask.jpg", "cat_mask.jpg",
"cats_and_dogs.jpg", "cats_and_dogs.jpg",
@ -81,6 +84,7 @@ filegroup(
"left_hands.jpg", "left_hands.jpg",
"mozart_square.jpg", "mozart_square.jpg",
"multi_objects.jpg", "multi_objects.jpg",
"multi_objects_rotated.jpg",
"pointing_up.jpg", "pointing_up.jpg",
"right_hands.jpg", "right_hands.jpg",
"segmentation_golden_rotation0.png", "segmentation_golden_rotation0.png",

View File

@ -58,6 +58,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/burger.jpg?generation=1661875667922678"], urls = ["https://storage.googleapis.com/mediapipe-assets/burger.jpg?generation=1661875667922678"],
) )
http_file(
name = "com_google_mediapipe_burger_rotated_jpg",
sha256 = "b7bb5e59ef778f3ce6b3e616c511908a53d513b83a56aae58b7453e14b0a4b2a",
urls = ["https://storage.googleapis.com/mediapipe-assets/burger_rotated.jpg?generation=1665065843774448"],
)
http_file( http_file(
name = "com_google_mediapipe_cat_jpg", name = "com_google_mediapipe_cat_jpg",
sha256 = "2533197401eebe9410ea4d063f86c43fbd2666f3e8165a38aca155c0d09c21be", sha256 = "2533197401eebe9410ea4d063f86c43fbd2666f3e8165a38aca155c0d09c21be",
@ -436,6 +442,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/multi_objects.jpg?generation=1663251779213308"], urls = ["https://storage.googleapis.com/mediapipe-assets/multi_objects.jpg?generation=1663251779213308"],
) )
http_file(
name = "com_google_mediapipe_multi_objects_rotated_jpg",
sha256 = "175f6c572ffbab6554e382fd5056d09720eef931ccc4ed79481bdc47a8443911",
urls = ["https://storage.googleapis.com/mediapipe-assets/multi_objects_rotated.jpg?generation=1665065847969523"],
)
http_file( http_file(
name = "com_google_mediapipe_object_detection_3d_camera_tflite", name = "com_google_mediapipe_object_detection_3d_camera_tflite",
sha256 = "f66e92e81ed3f4698f74d565a7668e016e2288ea92fb42938e33b778bd1e110d", sha256 = "f66e92e81ed3f4698f74d565a7668e016e2288ea92fb42938e33b778bd1e110d",