Add support for input image rotation in ImageClassifier.
PiperOrigin-RevId: 480676070
This commit is contained in:
parent
51a7606083
commit
ae4b2ae577
|
@ -59,14 +59,24 @@ constexpr int kMicroSecondsPerMilliSecond = 1000;
|
|||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::mediapipe::tasks::core::PacketMap;
|
||||
|
||||
// Builds a NormalizedRect covering the entire image.
|
||||
NormalizedRect BuildFullImageNormRect() {
|
||||
NormalizedRect norm_rect;
|
||||
norm_rect.set_x_center(0.5);
|
||||
norm_rect.set_y_center(0.5);
|
||||
norm_rect.set_width(1);
|
||||
norm_rect.set_height(1);
|
||||
return norm_rect;
|
||||
// Returns a NormalizedRect covering the full image if input is not present.
|
||||
// Otherwise, makes sure the x_center, y_center, width and height are set in
|
||||
// case only a rotation was provided in the input.
|
||||
NormalizedRect FillNormalizedRect(
|
||||
std::optional<NormalizedRect> normalized_rect) {
|
||||
NormalizedRect result;
|
||||
if (normalized_rect.has_value()) {
|
||||
result = *normalized_rect;
|
||||
}
|
||||
bool has_coordinates = result.has_x_center() || result.has_y_center() ||
|
||||
result.has_width() || result.has_height();
|
||||
if (!has_coordinates) {
|
||||
result.set_x_center(0.5);
|
||||
result.set_y_center(0.5);
|
||||
result.set_width(1);
|
||||
result.set_height(1);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Creates a MediaPipe graph config that contains a subgraph node of
|
||||
|
@ -154,15 +164,14 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create(
|
|||
}
|
||||
|
||||
absl::StatusOr<ClassificationResult> ImageClassifier::Classify(
|
||||
Image image, std::optional<NormalizedRect> roi) {
|
||||
Image image, std::optional<NormalizedRect> image_processing_options) {
|
||||
if (image.UsesGpu()) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
"GPU input images are currently not supported.",
|
||||
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
|
||||
}
|
||||
NormalizedRect norm_rect =
|
||||
roi.has_value() ? roi.value() : BuildFullImageNormRect();
|
||||
NormalizedRect norm_rect = FillNormalizedRect(image_processing_options);
|
||||
ASSIGN_OR_RETURN(
|
||||
auto output_packets,
|
||||
ProcessImageData(
|
||||
|
@ -173,15 +182,15 @@ absl::StatusOr<ClassificationResult> ImageClassifier::Classify(
|
|||
}
|
||||
|
||||
absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo(
|
||||
Image image, int64 timestamp_ms, std::optional<NormalizedRect> roi) {
|
||||
Image image, int64 timestamp_ms,
|
||||
std::optional<NormalizedRect> image_processing_options) {
|
||||
if (image.UsesGpu()) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
"GPU input images are currently not supported.",
|
||||
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
|
||||
}
|
||||
NormalizedRect norm_rect =
|
||||
roi.has_value() ? roi.value() : BuildFullImageNormRect();
|
||||
NormalizedRect norm_rect = FillNormalizedRect(image_processing_options);
|
||||
ASSIGN_OR_RETURN(
|
||||
auto output_packets,
|
||||
ProcessVideoData(
|
||||
|
@ -195,16 +204,16 @@ absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo(
|
|||
.Get<ClassificationResult>();
|
||||
}
|
||||
|
||||
absl::Status ImageClassifier::ClassifyAsync(Image image, int64 timestamp_ms,
|
||||
std::optional<NormalizedRect> roi) {
|
||||
absl::Status ImageClassifier::ClassifyAsync(
|
||||
Image image, int64 timestamp_ms,
|
||||
std::optional<NormalizedRect> image_processing_options) {
|
||||
if (image.UsesGpu()) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
"GPU input images are currently not supported.",
|
||||
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
|
||||
}
|
||||
NormalizedRect norm_rect =
|
||||
roi.has_value() ? roi.value() : BuildFullImageNormRect();
|
||||
NormalizedRect norm_rect = FillNormalizedRect(image_processing_options);
|
||||
return SendLiveStreamData(
|
||||
{{kImageInStreamName,
|
||||
MakePacket<Image>(std::move(image))
|
||||
|
|
|
@ -105,9 +105,18 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
|
|||
static absl::StatusOr<std::unique_ptr<ImageClassifier>> Create(
|
||||
std::unique_ptr<ImageClassifierOptions> options);
|
||||
|
||||
// Performs image classification on the provided single image. Classification
|
||||
// is performed on the region of interest specified by the `roi` argument if
|
||||
// provided, or on the entire image otherwise.
|
||||
// Performs image classification on the provided single image.
|
||||
//
|
||||
// The optional 'image_processing_options' parameter can be used to specify:
|
||||
// - the rotation to apply to the image before performing classification, by
|
||||
// setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
|
||||
// anti-clockwise rotation).
|
||||
// and/or
|
||||
// - the region-of-interest on which to perform classification, by setting its
|
||||
// 'x_center', 'y_center', 'width' and 'height' fields. If none of these is
|
||||
// set, they will automatically be set to cover the full image.
|
||||
// If both are specified, the crop around the region-of-interest is extracted
|
||||
// first, then the specified rotation is applied to the crop.
|
||||
//
|
||||
// Only use this method when the ImageClassifier is created with the image
|
||||
// running mode.
|
||||
|
@ -117,11 +126,21 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
|
|||
// YUVToImageCalculator is integrated.
|
||||
absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
|
||||
mediapipe::Image image,
|
||||
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
|
||||
std::optional<mediapipe::NormalizedRect> image_processing_options =
|
||||
std::nullopt);
|
||||
|
||||
// Performs image classification on the provided video frame. Classification
|
||||
// is performed on the region of interest specified by the `roi` argument if
|
||||
// provided, or on the entire image otherwise.
|
||||
// Performs image classification on the provided video frame.
|
||||
//
|
||||
// The optional 'image_processing_options' parameter can be used to specify:
|
||||
// - the rotation to apply to the image before performing classification, by
|
||||
// setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
|
||||
// anti-clockwise rotation).
|
||||
// and/or
|
||||
// - the region-of-interest on which to perform classification, by setting its
|
||||
// 'x_center', 'y_center', 'width' and 'height' fields. If none of these is
|
||||
// set, they will automatically be set to cover the full image.
|
||||
// If both are specified, the crop around the region-of-interest is extracted
|
||||
// first, then the specified rotation is applied to the crop.
|
||||
//
|
||||
// Only use this method when the ImageClassifier is created with the video
|
||||
// running mode.
|
||||
|
@ -131,12 +150,22 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
|
|||
// must be monotonically increasing.
|
||||
absl::StatusOr<components::containers::proto::ClassificationResult>
|
||||
ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms,
|
||||
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
|
||||
std::optional<mediapipe::NormalizedRect>
|
||||
image_processing_options = std::nullopt);
|
||||
|
||||
// Sends live image data to image classification, and the results will be
|
||||
// available via the "result_callback" provided in the ImageClassifierOptions.
|
||||
// Classification is performed on the region of interest specified by the
|
||||
// `roi` argument if provided, or on the entire image otherwise.
|
||||
//
|
||||
// The optional 'image_processing_options' parameter can be used to specify:
|
||||
// - the rotation to apply to the image before performing classification, by
|
||||
// setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
|
||||
// anti-clockwise rotation).
|
||||
// and/or
|
||||
// - the region-of-interest on which to perform classification, by setting its
|
||||
// 'x_center', 'y_center', 'width' and 'height' fields. If none of these is
|
||||
// set, they will automatically be set to cover the full image.
|
||||
// If both are specified, the crop around the region-of-interest is extracted
|
||||
// first, then the specified rotation is applied to the crop.
|
||||
//
|
||||
// Only use this method when the ImageClassifier is created with the live
|
||||
// stream running mode.
|
||||
|
@ -153,9 +182,9 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
|
|||
// longer be valid when the callback returns. To access the image data
|
||||
// outside of the callback, callers need to make a copy of the image.
|
||||
// - The input timestamp in milliseconds.
|
||||
absl::Status ClassifyAsync(
|
||||
mediapipe::Image image, int64 timestamp_ms,
|
||||
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
|
||||
absl::Status ClassifyAsync(mediapipe::Image image, int64 timestamp_ms,
|
||||
std::optional<mediapipe::NormalizedRect>
|
||||
image_processing_options = std::nullopt);
|
||||
|
||||
// TODO: add Classify() variants taking a region of interest as
|
||||
// additional argument.
|
||||
|
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
|
||||
#include "mediapipe/tasks/cc/vision/image_classifier/image_classifier.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
@ -546,18 +547,102 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
|
|||
options->classifier_options.max_results = 1;
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
|
||||
ImageClassifier::Create(std::move(options)));
|
||||
// NormalizedRect around the soccer ball.
|
||||
NormalizedRect roi;
|
||||
roi.set_x_center(0.532);
|
||||
roi.set_y_center(0.521);
|
||||
roi.set_width(0.164);
|
||||
roi.set_height(0.427);
|
||||
// Crop around the soccer ball.
|
||||
NormalizedRect image_processing_options;
|
||||
image_processing_options.set_x_center(0.532);
|
||||
image_processing_options.set_y_center(0.521);
|
||||
image_processing_options.set_width(0.164);
|
||||
image_processing_options.set_height(0.427);
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image, roi));
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
|
||||
image, image_processing_options));
|
||||
|
||||
ExpectApproximatelyEqual(results, GenerateSoccerBallResults(0));
|
||||
}
|
||||
|
||||
// Checks that, in image mode, Classify() accepts a NormalizedRect carrying
// only a rotation: the `burger_rotated.jpg` test image is classified with a
// 90° anti-clockwise rotation requested, and the top-3 categories are compared
// against the expected result.
TEST_F(ImageModeTest, SucceedsWithRotation) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                "burger_rotated.jpg")));

  auto classifier_options = std::make_unique<ImageClassifierOptions>();
  classifier_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata);
  classifier_options->classifier_options.max_results = 3;
  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<ImageClassifier> classifier,
      ImageClassifier::Create(std::move(classifier_options)));

  // Rotation only (90° anti-clockwise); no region-of-interest fields are set.
  NormalizedRect rotation_only;
  rotation_only.set_rotation(M_PI / 2.0);

  MP_ASSERT_OK_AND_ASSIGN(auto actual,
                          classifier->Classify(image, rotation_only));

  // The scores are not exactly those of the non-rotated image. This is
  // expected: models are very sensitive to the slightest numerical differences
  // introduced by the rotation and JPG encoding.
  const auto expected = ParseTextProtoOrDie<ClassificationResult>(
      R"pb(classifications {
             entries {
               categories {
                 index: 934
                 score: 0.6371766
                 category_name: "cheeseburger"
               }
               categories {
                 index: 963
                 score: 0.049443405
                 category_name: "meat loaf"
               }
               categories {
                 index: 925
                 score: 0.047918003
                 category_name: "guacamole"
               }
               timestamp_ms: 0
             }
             head_index: 0
             head_name: "probability"
           })pb");
  ExpectApproximatelyEqual(actual, expected);
}
|
||||
|
||||
// Checks that, in image mode, Classify() accepts a NormalizedRect that
// specifies both a region of interest and a rotation: a crop of the
// `multi_objects_rotated.jpg` test image is classified with a 90°
// anti-clockwise rotation requested, and the single top category is compared
// against the expected result.
TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                "multi_objects_rotated.jpg")));

  auto classifier_options = std::make_unique<ImageClassifierOptions>();
  classifier_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata);
  classifier_options->classifier_options.max_results = 1;
  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<ImageClassifier> classifier,
      ImageClassifier::Create(std::move(classifier_options)));

  // Region of interest around the chair, combined with a 90° anti-clockwise
  // rotation.
  NormalizedRect crop_and_rotation;
  crop_and_rotation.set_x_center(0.2821);
  crop_and_rotation.set_y_center(0.2406);
  crop_and_rotation.set_width(0.5642);
  crop_and_rotation.set_height(0.1286);
  crop_and_rotation.set_rotation(M_PI / 2.0);

  MP_ASSERT_OK_AND_ASSIGN(auto actual,
                          classifier->Classify(image, crop_and_rotation));

  const auto expected = ParseTextProtoOrDie<ClassificationResult>(
      R"pb(classifications {
             entries {
               categories {
                 index: 560
                 score: 0.6800408
                 category_name: "folding chair"
               }
               timestamp_ms: 0
             }
             head_index: 0
             head_name: "probability"
           })pb");
  ExpectApproximatelyEqual(actual, expected);
}
|
||||
|
||||
class VideoModeTest : public tflite_shims::testing::Test {};
|
||||
|
||||
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
||||
|
@ -646,16 +731,17 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) {
|
|||
options->classifier_options.max_results = 1;
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
|
||||
ImageClassifier::Create(std::move(options)));
|
||||
// NormalizedRect around the soccer ball.
|
||||
NormalizedRect roi;
|
||||
roi.set_x_center(0.532);
|
||||
roi.set_y_center(0.521);
|
||||
roi.set_width(0.164);
|
||||
roi.set_height(0.427);
|
||||
// Crop around the soccer ball.
|
||||
NormalizedRect image_processing_options;
|
||||
image_processing_options.set_x_center(0.532);
|
||||
image_processing_options.set_y_center(0.521);
|
||||
image_processing_options.set_width(0.164);
|
||||
image_processing_options.set_height(0.427);
|
||||
|
||||
for (int i = 0; i < iterations; ++i) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto results,
|
||||
image_classifier->ClassifyForVideo(image, i, roi));
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto results,
|
||||
image_classifier->ClassifyForVideo(image, i, image_processing_options));
|
||||
ExpectApproximatelyEqual(results, GenerateSoccerBallResults(i));
|
||||
}
|
||||
MP_ASSERT_OK(image_classifier->Close());
|
||||
|
@ -790,15 +876,16 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
|
|||
};
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
|
||||
ImageClassifier::Create(std::move(options)));
|
||||
// NormalizedRect around the soccer ball.
|
||||
NormalizedRect roi;
|
||||
roi.set_x_center(0.532);
|
||||
roi.set_y_center(0.521);
|
||||
roi.set_width(0.164);
|
||||
roi.set_height(0.427);
|
||||
// Crop around the soccer ball.
|
||||
NormalizedRect image_processing_options;
|
||||
image_processing_options.set_x_center(0.532);
|
||||
image_processing_options.set_y_center(0.521);
|
||||
image_processing_options.set_width(0.164);
|
||||
image_processing_options.set_height(0.427);
|
||||
|
||||
for (int i = 0; i < iterations; ++i) {
|
||||
MP_ASSERT_OK(image_classifier->ClassifyAsync(image, i, roi));
|
||||
MP_ASSERT_OK(
|
||||
image_classifier->ClassifyAsync(image, i, image_processing_options));
|
||||
}
|
||||
MP_ASSERT_OK(image_classifier->Close());
|
||||
|
||||
|
|
4
mediapipe/tasks/testdata/vision/BUILD
vendored
4
mediapipe/tasks/testdata/vision/BUILD
vendored
|
@ -25,6 +25,7 @@ package(
|
|||
mediapipe_files(srcs = [
|
||||
"burger.jpg",
|
||||
"burger_crop.jpg",
|
||||
"burger_rotated.jpg",
|
||||
"cat.jpg",
|
||||
"cat_mask.jpg",
|
||||
"cats_and_dogs.jpg",
|
||||
|
@ -46,6 +47,7 @@ mediapipe_files(srcs = [
|
|||
"mobilenet_v3_small_100_224_embedder.tflite",
|
||||
"mozart_square.jpg",
|
||||
"multi_objects.jpg",
|
||||
"multi_objects_rotated.jpg",
|
||||
"palm_detection_full.tflite",
|
||||
"pointing_up.jpg",
|
||||
"right_hands.jpg",
|
||||
|
@ -72,6 +74,7 @@ filegroup(
|
|||
srcs = [
|
||||
"burger.jpg",
|
||||
"burger_crop.jpg",
|
||||
"burger_rotated.jpg",
|
||||
"cat.jpg",
|
||||
"cat_mask.jpg",
|
||||
"cats_and_dogs.jpg",
|
||||
|
@ -81,6 +84,7 @@ filegroup(
|
|||
"left_hands.jpg",
|
||||
"mozart_square.jpg",
|
||||
"multi_objects.jpg",
|
||||
"multi_objects_rotated.jpg",
|
||||
"pointing_up.jpg",
|
||||
"right_hands.jpg",
|
||||
"segmentation_golden_rotation0.png",
|
||||
|
|
12
third_party/external_files.bzl
vendored
12
third_party/external_files.bzl
vendored
|
@ -58,6 +58,12 @@ def external_files():
|
|||
urls = ["https://storage.googleapis.com/mediapipe-assets/burger.jpg?generation=1661875667922678"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_burger_rotated_jpg",
|
||||
sha256 = "b7bb5e59ef778f3ce6b3e616c511908a53d513b83a56aae58b7453e14b0a4b2a",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/burger_rotated.jpg?generation=1665065843774448"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_cat_jpg",
|
||||
sha256 = "2533197401eebe9410ea4d063f86c43fbd2666f3e8165a38aca155c0d09c21be",
|
||||
|
@ -436,6 +442,12 @@ def external_files():
|
|||
urls = ["https://storage.googleapis.com/mediapipe-assets/multi_objects.jpg?generation=1663251779213308"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_multi_objects_rotated_jpg",
|
||||
sha256 = "175f6c572ffbab6554e382fd5056d09720eef931ccc4ed79481bdc47a8443911",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/multi_objects_rotated.jpg?generation=1665065847969523"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_object_detection_3d_camera_tflite",
|
||||
sha256 = "f66e92e81ed3f4698f74d565a7668e016e2288ea92fb42938e33b778bd1e110d",
|
||||
|
|
Loading…
Reference in New Issue
Block a user