diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD
index 6a9a25fc1..186909509 100644
--- a/mediapipe/tasks/cc/vision/object_detector/BUILD
+++ b/mediapipe/tasks/cc/vision/object_detector/BUILD
@@ -33,6 +33,7 @@ cc_library(
         "//mediapipe/framework/api2:port",
         "//mediapipe/framework/formats:detection_cc_proto",
         "//mediapipe/framework/formats:image",
+        "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/framework/formats:tensor",
         "//mediapipe/tasks/cc:common",
         "//mediapipe/tasks/cc/components:image_preprocessing",
@@ -66,6 +67,7 @@ cc_library(
         "//mediapipe/framework/api2:builder",
         "//mediapipe/framework/formats:detection_cc_proto",
         "//mediapipe/framework/formats:image",
+        "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/tasks/cc:common",
         "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
         "//mediapipe/tasks/cc/core:base_options",
diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc
index 8b7473d48..5aecc2d2f 100644
--- a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc
+++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc
@@ -28,6 +28,7 @@ limitations under the License.
 #include "mediapipe/framework/api2/builder.h"
 #include "mediapipe/framework/formats/detection.pb.h"
 #include "mediapipe/framework/formats/image.h"
+#include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/tasks/cc/common.h"
 #include "mediapipe/tasks/cc/core/base_options.h"
 #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
@@ -48,6 +49,8 @@ constexpr char kDetectionsTag[] = "DETECTIONS";
 constexpr char kImageInStreamName[] = "image_in";
 constexpr char kImageOutStreamName[] = "image_out";
 constexpr char kImageTag[] = "IMAGE";
+constexpr char kNormRectName[] = "norm_rect_in";
+constexpr char kNormRectTag[] = "NORM_RECT";
 constexpr char kSubgraphTypeName[] =
     "mediapipe.tasks.vision.ObjectDetectorGraph";
 constexpr int kMicroSecondsPerMilliSecond = 1000;
@@ -55,6 +58,31 @@ constexpr int kMicroSecondsPerMilliSecond = 1000;
 using ObjectDetectorOptionsProto =
     object_detector::proto::ObjectDetectorOptions;
 
+// Returns a NormalizedRect filling the whole image. If input is present, its
+// rotation is set in the returned NormalizedRect and a check is performed to
+// make sure no region-of-interest was provided. Otherwise, rotation is set to
+// 0.
+absl::StatusOr<NormalizedRect> FillNormalizedRect(
+    std::optional<NormalizedRect> normalized_rect) {
+  NormalizedRect result;
+  if (normalized_rect.has_value()) {
+    result = *normalized_rect;
+  }
+  bool has_coordinates = result.has_x_center() || result.has_y_center() ||
+                         result.has_width() || result.has_height();
+  if (has_coordinates) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        "ObjectDetector does not support region-of-interest.",
+        MediaPipeTasksStatus::kInvalidArgumentError);
+  }
+  result.set_x_center(0.5);
+  result.set_y_center(0.5);
+  result.set_width(1);
+  result.set_height(1);
+  return result;
+}
+
 // Creates a MediaPipe graph config that contains a subgraph node of
 // "mediapipe.tasks.vision.ObjectDetectorGraph". If the task is running in the
 // live stream mode, a "FlowLimiterCalculator" will be added to limit the
@@ -64,6 +92,7 @@ CalculatorGraphConfig CreateGraphConfig(
     bool enable_flow_limiting) {
   api2::builder::Graph graph;
   graph.In(kImageTag).SetName(kImageInStreamName);
+  graph.In(kNormRectTag).SetName(kNormRectName);
   auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
   task_subgraph.GetOptions<ObjectDetectorOptionsProto>().Swap(
       options_proto.get());
@@ -76,6 +105,7 @@ CalculatorGraphConfig CreateGraphConfig(
                                                  {kImageTag}, kDetectionsTag);
   }
   graph.In(kImageTag) >> task_subgraph.In(kImageTag);
+  graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
   return graph.GetConfig();
 }
 
@@ -139,46 +169,64 @@ absl::StatusOr<std::unique_ptr<ObjectDetector>> ObjectDetector::Create(
 }
 
 absl::StatusOr<std::vector<Detection>> ObjectDetector::Detect(
-    mediapipe::Image image) {
+    mediapipe::Image image,
+    std::optional<mediapipe::NormalizedRect> image_processing_options) {
   if (image.UsesGpu()) {
     return CreateStatusWithPayload(
         absl::StatusCode::kInvalidArgument,
         absl::StrCat("GPU input images are currently not supported."),
         MediaPipeTasksStatus::kRunnerUnexpectedInputError);
   }
-  ASSIGN_OR_RETURN(auto output_packets,
-                   ProcessImageData({{kImageInStreamName,
-                                      MakePacket<Image>(std::move(image))}}));
+  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
+                   FillNormalizedRect(image_processing_options));
+  ASSIGN_OR_RETURN(
+      auto output_packets,
+      ProcessImageData(
+          {{kImageInStreamName, MakePacket<Image>(std::move(image))},
+           {kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}}));
   return output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>();
 }
 
 absl::StatusOr<std::vector<Detection>> ObjectDetector::DetectForVideo(
-    mediapipe::Image image, int64 timestamp_ms) {
+    mediapipe::Image image, int64 timestamp_ms,
+    std::optional<mediapipe::NormalizedRect> image_processing_options) {
   if (image.UsesGpu()) {
     return CreateStatusWithPayload(
         absl::StatusCode::kInvalidArgument,
         absl::StrCat("GPU input images are currently not supported."),
         MediaPipeTasksStatus::kRunnerUnexpectedInputError);
   }
+  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
+                   FillNormalizedRect(image_processing_options));
   ASSIGN_OR_RETURN(
       auto output_packets,
       ProcessVideoData(
           {{kImageInStreamName,
             MakePacket<Image>(std::move(image))
+                .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
+           {kNormRectName,
+            MakePacket<NormalizedRect>(std::move(norm_rect))
                 .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
   return output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>();
 }
 
-absl::Status ObjectDetector::DetectAsync(Image image, int64 timestamp_ms) {
+absl::Status ObjectDetector::DetectAsync(
+    Image image, int64 timestamp_ms,
+    std::optional<mediapipe::NormalizedRect> image_processing_options) {
   if (image.UsesGpu()) {
     return CreateStatusWithPayload(
         absl::StatusCode::kInvalidArgument,
         absl::StrCat("GPU input images are currently not supported."),
         MediaPipeTasksStatus::kRunnerUnexpectedInputError);
   }
+  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
+                   FillNormalizedRect(image_processing_options));
   return SendLiveStreamData(
       {{kImageInStreamName,
         MakePacket<Image>(std::move(image))
+            .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
+       {kNormRectName,
+        MakePacket<NormalizedRect>(std::move(norm_rect))
            .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
 }
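To make the new API surface concrete, here is a minimal calling sketch against the updated Detect() above. This is illustrative only: the model path is a placeholder, and error handling is condensed. Passing nothing (or std::nullopt) keeps the previous behavior, since FillNormalizedRect then builds a full-image rect with rotation 0.

```cpp
#include <cmath>
#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/vision/object_detector/object_detector.h"

absl::Status DetectWithRotation(mediapipe::Image image) {
  using ::mediapipe::tasks::vision::ObjectDetector;
  using ::mediapipe::tasks::vision::ObjectDetectorOptions;

  auto options = std::make_unique<ObjectDetectorOptions>();
  options->base_options.model_asset_path = "model.tflite";  // placeholder
  auto detector_or = ObjectDetector::Create(std::move(options));
  if (!detector_or.ok()) return detector_or.status();

  // Rotation is expressed in radians; M_PI / 2 is 90° anti-clockwise. The
  // coordinate fields must stay unset, otherwise FillNormalizedRect returns
  // kInvalidArgument.
  mediapipe::NormalizedRect image_processing_options;
  image_processing_options.set_rotation(M_PI / 2.0);

  auto detections_or =
      (*detector_or)->Detect(std::move(image), image_processing_options);
  if (!detections_or.ok()) return detections_or.status();
  return (*detector_or)->Close();
}
```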
diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.h b/mediapipe/tasks/cc/vision/object_detector/object_detector.h
index 0fa1b087b..2e5ed7b8d 100644
--- a/mediapipe/tasks/cc/vision/object_detector/object_detector.h
+++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.h
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include <memory>
+#include <optional>
 #include <string>
 #include <utility>
 #include <vector>
@@ -26,6 +27,7 @@ limitations under the License.
 #include "absl/status/statusor.h"
 #include "mediapipe/framework/formats/detection.pb.h"
 #include "mediapipe/framework/formats/image.h"
+#include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/tasks/cc/core/base_options.h"
 #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
 #include "mediapipe/tasks/cc/vision/core/running_mode.h"
@@ -151,6 +153,13 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
   // TODO: Describes how the input image will be preprocessed
   // after the yuv support is implemented.
   //
+  // The optional 'image_processing_options' parameter can be used to specify
+  // the rotation to apply to the image before performing detection, by
+  // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
+  // anti-clockwise rotation). Note that specifying a region-of-interest using
+  // the 'x_center', 'y_center', 'width' and 'height' fields is NOT supported
+  // and will result in an invalid argument error being returned.
+  //
   // For CPU images, the returned bounding boxes are expressed in the
   // unrotated input frame of reference coordinates system, i.e. in `[0,
   // image_width) x [0, image_height)`, which are the dimensions of the
@@ -158,7 +167,9 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
   // TODO: Describes the output bounding boxes for gpu input
   // images after enabling the gpu support in MediaPipe Tasks.
   absl::StatusOr<std::vector<mediapipe::Detection>> Detect(
-      mediapipe::Image image);
+      mediapipe::Image image,
+      std::optional<mediapipe::NormalizedRect> image_processing_options =
+          std::nullopt);
 
   // Performs object detection on the provided video frame.
   // Only use this method when the ObjectDetector is created with the video
@@ -168,12 +179,21 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
   // provide the video frame's timestamp (in milliseconds). The input timestamps
   // must be monotonically increasing.
   //
+  // The optional 'image_processing_options' parameter can be used to specify
+  // the rotation to apply to the image before performing detection, by
+  // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
+  // anti-clockwise rotation). Note that specifying a region-of-interest using
+  // the 'x_center', 'y_center', 'width' and 'height' fields is NOT supported
+  // and will result in an invalid argument error being returned.
+  //
   // For CPU images, the returned bounding boxes are expressed in the
   // unrotated input frame of reference coordinates system, i.e. in `[0,
   // image_width) x [0, image_height)`, which are the dimensions of the
   // underlying image data.
   absl::StatusOr<std::vector<mediapipe::Detection>> DetectForVideo(
-      mediapipe::Image image, int64 timestamp_ms);
+      mediapipe::Image image, int64 timestamp_ms,
+      std::optional<mediapipe::NormalizedRect> image_processing_options =
+          std::nullopt);
 
   // Sends live image data to perform object detection, and the results will be
   // available via the "result_callback" provided in the ObjectDetectorOptions.
@@ -185,7 +205,14 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
   // sent to the object detector. The input timestamps must be monotonically
   // increasing.
   //
-  // The "result_callback" prvoides
+  // The optional 'image_processing_options' parameter can be used to specify
+  // the rotation to apply to the image before performing detection, by
+  // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
+  // anti-clockwise rotation). Note that specifying a region-of-interest using
+  // the 'x_center', 'y_center', 'width' and 'height' fields is NOT supported
+  // and will result in an invalid argument error being returned.
+  //
+  // The "result_callback" provides
   //   - A vector of detections, each has a bounding box that is expressed in
   //     the unrotated input frame of reference coordinates system, i.e. in `[0,
   //     image_width) x [0, image_height)`, which are the dimensions of the
@@ -195,7 +222,9 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
   //     longer be valid when the callback returns. To access the image data
   //     outside of the callback, callers need to make a copy of the image.
   //   - The input timestamp in milliseconds.
-  absl::Status DetectAsync(mediapipe::Image image, int64 timestamp_ms);
+  absl::Status DetectAsync(mediapipe::Image image, int64 timestamp_ms,
+                           std::optional<mediapipe::NormalizedRect>
+                               image_processing_options = std::nullopt);
 
   // Shuts down the ObjectDetector when all works are done.
   absl::Status Close() { return runner_->Close(); }
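A similar sketch for the video path documented above, assuming a detector created in video mode; GetFrame() and num_frames are hypothetical placeholders for a real frame source, not part of this change.

```cpp
#include <cmath>
#include <cstdint>

#include "absl/status/status.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/vision/object_detector/object_detector.h"

absl::Status DetectOnVideo(mediapipe::tasks::vision::ObjectDetector& detector,
                           int num_frames) {
  mediapipe::NormalizedRect image_processing_options;
  image_processing_options.set_rotation(M_PI);  // upside-down input, radians

  for (int i = 0; i < num_frames; ++i) {
    mediapipe::Image frame = GetFrame(i);  // hypothetical frame source
    int64_t timestamp_ms = i * 33;  // ~30 fps, monotonically increasing
    auto detections_or = detector.DetectForVideo(frame, timestamp_ms,
                                                 image_processing_options);
    if (!detections_or.ok()) return detections_or.status();
  }
  return absl::OkStatus();
}
```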
diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc
index b0533e469..a2fb373cb 100644
--- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc
+++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc
@@ -28,6 +28,7 @@ limitations under the License.
 #include "mediapipe/framework/calculator.pb.h"
 #include "mediapipe/framework/formats/detection.pb.h"
 #include "mediapipe/framework/formats/image.h"
+#include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/framework/formats/tensor.h"
 #include "mediapipe/tasks/cc/common.h"
 #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
@@ -87,6 +88,7 @@ constexpr char kImageSizeTag[] = "IMAGE_SIZE";
 constexpr char kImageTag[] = "IMAGE";
 constexpr char kIndicesTag[] = "INDICES";
 constexpr char kMatrixTag[] = "MATRIX";
+constexpr char kNormRectTag[] = "NORM_RECT";
 constexpr char kPixelDetectionsTag[] = "PIXEL_DETECTIONS";
 constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX";
 constexpr char kScoresTag[] = "SCORES";
@@ -457,6 +459,10 @@ void ConfigureTensorsToDetectionsCalculator(
 // Inputs:
 //   IMAGE - Image
 //     Image to perform detection on.
+//   NORM_RECT - NormalizedRect @Optional
+//     Describes image rotation and region of image to perform detection
+//     on.
+//     @Optional: rect covering the whole image is used if not specified.
 //
 // Outputs:
 //   DETECTIONS - std::vector<Detection>
@@ -494,9 +500,10 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
     Graph graph;
     ASSIGN_OR_RETURN(
         auto output_streams,
-        BuildObjectDetectionTask(sc->Options<ObjectDetectorOptionsProto>(),
-                                 *model_resources,
-                                 graph[Input<Image>(kImageTag)], graph));
+        BuildObjectDetectionTask(
+            sc->Options<ObjectDetectorOptionsProto>(), *model_resources,
+            graph[Input<Image>(kImageTag)],
+            graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
     output_streams.detections >>
         graph[Output<std::vector<Detection>>(kDetectionsTag)];
     output_streams.image >> graph[Output<Image>(kImageTag)];
@@ -519,7 +526,7 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
   absl::StatusOr<ObjectDetectionOutputStreams> BuildObjectDetectionTask(
       const ObjectDetectorOptionsProto& task_options,
       const core::ModelResources& model_resources, Source<Image> image_in,
-      Graph& graph) {
+      Source<NormalizedRect> norm_rect_in, Graph& graph) {
     MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
     // Checks that the model has 4 outputs.
     auto& model = *model_resources.GetTfLiteModel();
@@ -559,6 +566,7 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
                              &preprocessing
                                   .GetOptions<ImagePreprocessingOptions>()));
     image_in >> preprocessing.In(kImageTag);
+    norm_rect_in >> preprocessing.In(kNormRectTag);
 
     // Adds inference subgraph and connects its input stream to the output
     // tensors produced by the ImageToTensorCalculator.
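For anyone wiring the subgraph directly, a standalone builder sketch mirroring CreateGraphConfig in object_detector.cc (stream names here are arbitrary):

```cpp
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator.pb.h"

mediapipe::CalculatorGraphConfig BuildDetectorConfig() {
  mediapipe::api2::builder::Graph graph;
  graph.In("IMAGE").SetName("image_in");
  graph.In("NORM_RECT").SetName("norm_rect_in");
  auto& detector = graph.AddNode("mediapipe.tasks.vision.ObjectDetectorGraph");
  // Connect both graph inputs to the subgraph; NORM_RECT is optional.
  graph.In("IMAGE") >> detector.In("IMAGE");
  graph.In("NORM_RECT") >> detector.In("NORM_RECT");
  detector.Out("DETECTIONS").SetName("detections_out") >>
      graph.Out("DETECTIONS");
  return graph.GetConfig();
}
```

Since the input is declared with Input<NormalizedRect>::Optional, a graph that never feeds NORM_RECT still runs, with the full-image rect used as the documented fallback.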
diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc
index bcc4c95ee..8a9251152 100644
--- a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc
+++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "mediapipe/tasks/cc/vision/object_detector/object_detector.h"
 
+#include <cmath>
 #include <memory>
 #include <string>
 #include <utility>
@@ -30,6 +31,7 @@ limitations under the License.
 #include "mediapipe/framework/deps/file_path.h"
 #include "mediapipe/framework/formats/image.h"
 #include "mediapipe/framework/formats/location_data.pb.h"
+#include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/framework/port/gmock.h"
 #include "mediapipe/framework/port/gtest.h"
 #include "mediapipe/framework/port/parse_text_proto.h"
@@ -519,6 +521,54 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
   ExpectApproximatelyEqual(results, {full_expected_results[3]});
 }
 
+TEST_F(ImageModeTest, SucceedsWithRotation) {
+  MP_ASSERT_OK_AND_ASSIGN(
+      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
+                                                "cats_and_dogs_rotated.jpg")));
+  auto options = std::make_unique<ObjectDetectorOptions>();
+  options->max_results = 1;
+  options->category_allowlist.push_back("cat");
+  options->base_options.model_asset_path =
+      JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
+                          ObjectDetector::Create(std::move(options)));
+  NormalizedRect image_processing_options;
+  image_processing_options.set_rotation(M_PI / 2.0);
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto results, object_detector->Detect(image, image_processing_options));
+  MP_ASSERT_OK(object_detector->Close());
+  ExpectApproximatelyEqual(
+      results, {ParseTextProtoOrDie<Detection>(R"pb(
+        label: "cat"
+        score: 0.7109375
+        location_data {
+          format: BOUNDING_BOX
+          bounding_box { xmin: 0 ymin: 622 width: 436 height: 276 }
+        })pb")});
+}
+
+TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
+  MP_ASSERT_OK_AND_ASSIGN(Image image,
+                          DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
+                                                       "cats_and_dogs.jpg")));
+  auto options = std::make_unique<ObjectDetectorOptions>();
+  options->max_results = 1;
+  options->base_options.model_asset_path =
+      JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
+                          ObjectDetector::Create(std::move(options)));
+  NormalizedRect image_processing_options;
+  image_processing_options.set_x_center(0.5);
+  image_processing_options.set_y_center(0.5);
+  image_processing_options.set_width(1.0);
+  image_processing_options.set_height(1.0);
+
+  auto results = object_detector->Detect(image, image_processing_options);
+  EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(results.status().message(),
+              HasSubstr("ObjectDetector does not support region-of-interest"));
+}
+
 class VideoModeTest : public tflite_shims::testing::Test {};
 
 TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
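The tests pass the rotation in radians via M_PI; since capture APIs usually report rotation in degrees, callers may want a small conversion helper along these lines (hypothetical, not part of this change):

```cpp
#include <cmath>

#include "mediapipe/framework/formats/rect.pb.h"

// Hypothetical convenience helper: converts an anti-clockwise rotation in
// degrees to the radians expected by NormalizedRect::rotation.
mediapipe::NormalizedRect RotationOptionsFromDegrees(float degrees) {
  mediapipe::NormalizedRect options;
  options.set_rotation(degrees * M_PI / 180.0);
  return options;
}

// Usage: detector->Detect(image, RotationOptionsFromDegrees(90));
```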
diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD
index 8b205cc49..764b93c91 100644
--- a/mediapipe/tasks/testdata/vision/BUILD
+++ b/mediapipe/tasks/testdata/vision/BUILD
@@ -30,6 +30,7 @@ mediapipe_files(srcs = [
     "cat_mask.jpg",
     "cats_and_dogs.jpg",
     "cats_and_dogs_no_resizing.jpg",
+    "cats_and_dogs_rotated.jpg",
     "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite",
     "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
     "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
@@ -79,6 +80,7 @@ filegroup(
         "cat_mask.jpg",
         "cats_and_dogs.jpg",
         "cats_and_dogs_no_resizing.jpg",
+        "cats_and_dogs_rotated.jpg",
         "hand_landmark_full.tflite",
         "hand_landmark_lite.tflite",
         "left_hands.jpg",
diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl
index 24fb15446..8f4b70d38 100644
--- a/third_party/external_files.bzl
+++ b/third_party/external_files.bzl
@@ -46,6 +46,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=1661875663693976"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_BUILD_orig",
+        sha256 = "650df617b3e125e0890f1b8c936cc64c9d975707f57e616b6430fc667ce315d4",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD.orig?generation=1665609930388174"],
+    )
+
     http_file(
         name = "com_google_mediapipe_burger_crop_jpg",
         sha256 = "8f58de573f0bf59a49c3d86cfabb9ad4061481f574aa049177e8da3963dddc50",
         urls = ["https://storage.googleapis.com/mediapipe-assets/burger_crop.jpg?generation=1664184735043119"],
     )
 
@@ -88,6 +94,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs_no_resizing.jpg?generation=1661875687251296"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_cats_and_dogs_rotated_jpg",
+        sha256 = "5384926d16ddd8802555ae3108deedefb42a2ea78d99e5ad0933c5e11f43244a",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs_rotated.jpg?generation=1665609933260747"],
+    )
+
     http_file(
         name = "com_google_mediapipe_classification_tensor_float_meta_json",
         sha256 = "1d10b1c9c87eabac330651136804074ddc134779e94a73cf783207c3aa2a5619",
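For reference, the new asset flows from external_files.bzl through mediapipe_files() into the filegroup above; the consuming test then only needs a data dependency along these lines (a sketch, assuming the filegroup shown above is named "test_images"):

```python
# BUILD sketch (target names are illustrative): the test pulls the new image
# in via a data dependency on the testdata filegroup.
cc_test(
    name = "object_detector_test",
    srcs = ["object_detector_test.cc"],
    data = ["//mediapipe/tasks/testdata/vision:test_images"],
)
```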