Add support for rotation in ObjectDetector C++ API

PiperOrigin-RevId: 481167472
This commit is contained in:
MediaPipe Team 2022-10-14 09:46:16 -07:00 committed by Copybara-Service
parent 0ebe6ccf59
commit 6f3e8381ed
7 changed files with 165 additions and 14 deletions

View File

@ -33,6 +33,7 @@ cc_library(
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components:image_preprocessing",
@ -66,6 +67,7 @@ cc_library(
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
"//mediapipe/tasks/cc/core:base_options",

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
@ -48,6 +49,8 @@ constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectName[] = "norm_rect_in";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.vision.ObjectDetectorGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000;
@ -55,6 +58,31 @@ constexpr int kMicroSecondsPerMilliSecond = 1000;
using ObjectDetectorOptionsProto =
object_detector::proto::ObjectDetectorOptions;
// Builds a NormalizedRect covering the entire image. If the caller supplied a
// rect, its rotation is carried over into the result; supplying any
// region-of-interest coordinates is rejected, since ObjectDetector does not
// support ROI. With no input rect, rotation defaults to 0.
absl::StatusOr<NormalizedRect> FillNormalizedRect(
    std::optional<NormalizedRect> normalized_rect) {
  NormalizedRect full_image_rect = normalized_rect.value_or(NormalizedRect());
  // Reject any caller-provided region-of-interest: only rotation is honored.
  const bool roi_provided =
      full_image_rect.has_x_center() || full_image_rect.has_y_center() ||
      full_image_rect.has_width() || full_image_rect.has_height();
  if (roi_provided) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "ObjectDetector does not support region-of-interest.",
        MediaPipeTasksStatus::kInvalidArgumentError);
  }
  // Center the rect and expand it to span the whole image.
  full_image_rect.set_x_center(0.5);
  full_image_rect.set_y_center(0.5);
  full_image_rect.set_width(1);
  full_image_rect.set_height(1);
  return full_image_rect;
}
// Creates a MediaPipe graph config that contains a subgraph node of
// "mediapipe.tasks.vision.ObjectDetectorGraph". If the task is running in the
// live stream mode, a "FlowLimiterCalculator" will be added to limit the
@ -64,6 +92,7 @@ CalculatorGraphConfig CreateGraphConfig(
bool enable_flow_limiting) {
api2::builder::Graph graph;
graph.In(kImageTag).SetName(kImageInStreamName);
graph.In(kNormRectTag).SetName(kNormRectName);
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
task_subgraph.GetOptions<ObjectDetectorOptionsProto>().Swap(
options_proto.get());
@ -76,6 +105,7 @@ CalculatorGraphConfig CreateGraphConfig(
{kImageTag}, kDetectionsTag);
}
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
return graph.GetConfig();
}
@ -139,46 +169,64 @@ absl::StatusOr<std::unique_ptr<ObjectDetector>> ObjectDetector::Create(
}
absl::StatusOr<std::vector<Detection>> ObjectDetector::Detect(
mediapipe::Image image) {
mediapipe::Image image,
std::optional<mediapipe::NormalizedRect> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
ASSIGN_OR_RETURN(auto output_packets,
ProcessImageData({{kImageInStreamName,
MakePacket<Image>(std::move(image))}}));
ASSIGN_OR_RETURN(NormalizedRect norm_rect,
FillNormalizedRect(image_processing_options));
ASSIGN_OR_RETURN(
auto output_packets,
ProcessImageData(
{{kImageInStreamName, MakePacket<Image>(std::move(image))},
{kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}}));
return output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>();
}
absl::StatusOr<std::vector<Detection>> ObjectDetector::DetectForVideo(
mediapipe::Image image, int64 timestamp_ms) {
mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
ASSIGN_OR_RETURN(NormalizedRect norm_rect,
FillNormalizedRect(image_processing_options));
ASSIGN_OR_RETURN(
auto output_packets,
ProcessVideoData(
{{kImageInStreamName,
MakePacket<Image>(std::move(image))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
{kNormRectName,
MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
return output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>();
}
absl::Status ObjectDetector::DetectAsync(Image image, int64 timestamp_ms) {
absl::Status ObjectDetector::DetectAsync(
Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
ASSIGN_OR_RETURN(NormalizedRect norm_rect,
FillNormalizedRect(image_processing_options));
return SendLiveStreamData(
{{kImageInStreamName,
MakePacket<Image>(std::move(image))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
{kNormRectName,
MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
}

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
@ -26,6 +27,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
@ -151,6 +153,13 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
// TODO: Describes how the input image will be preprocessed
// after the yuv support is implemented.
//
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing detection, by
// setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
// anti-clockwise rotation). Note that specifying a region-of-interest using
// the 'x_center', 'y_center', 'width' and 'height' fields is NOT supported
// and will result in an invalid argument error being returned.
//
// For CPU images, the returned bounding boxes are expressed in the
// unrotated input frame of reference coordinates system, i.e. in `[0,
// image_width) x [0, image_height)`, which are the dimensions of the
@ -158,7 +167,9 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
// TODO: Describes the output bounding boxes for gpu input
// images after enabling the gpu support in MediaPipe Tasks.
absl::StatusOr<std::vector<mediapipe::Detection>> Detect(
mediapipe::Image image);
mediapipe::Image image,
std::optional<mediapipe::NormalizedRect> image_processing_options =
std::nullopt);
// Performs object detection on the provided video frame.
// Only use this method when the ObjectDetector is created with the video
@ -168,12 +179,21 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
// provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing.
//
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing detection, by
// setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
// anti-clockwise rotation). Note that specifying a region-of-interest using
// the 'x_center', 'y_center', 'width' and 'height' fields is NOT supported
// and will result in an invalid argument error being returned.
//
// For CPU images, the returned bounding boxes are expressed in the
// unrotated input frame of reference coordinates system, i.e. in `[0,
// image_width) x [0, image_height)`, which are the dimensions of the
// underlying image data.
absl::StatusOr<std::vector<mediapipe::Detection>> DetectForVideo(
mediapipe::Image image, int64 timestamp_ms);
mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> image_processing_options =
std::nullopt);
// Sends live image data to perform object detection, and the results will be
// available via the "result_callback" provided in the ObjectDetectorOptions.
@ -185,7 +205,14 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
// sent to the object detector. The input timestamps must be monotonically
// increasing.
//
// The "result_callback" prvoides
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing detection, by
// setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
// anti-clockwise rotation). Note that specifying a region-of-interest using
// the 'x_center', 'y_center', 'width' and 'height' fields is NOT supported
// and will result in an invalid argument error being returned.
//
// The "result_callback" provides
// - A vector of detections, each has a bounding box that is expressed in
// the unrotated input frame of reference coordinates system, i.e. in `[0,
// image_width) x [0, image_height)`, which are the dimensions of the
@ -195,7 +222,9 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
// longer be valid when the callback returns. To access the image data
// outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds.
absl::Status DetectAsync(mediapipe::Image image, int64 timestamp_ms);
absl::Status DetectAsync(mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect>
image_processing_options = std::nullopt);
// Shuts down the ObjectDetector when all works are done.
absl::Status Close() { return runner_->Close(); }

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
@ -87,6 +88,7 @@ constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kImageTag[] = "IMAGE";
constexpr char kIndicesTag[] = "INDICES";
constexpr char kMatrixTag[] = "MATRIX";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kPixelDetectionsTag[] = "PIXEL_DETECTIONS";
constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX";
constexpr char kScoresTag[] = "SCORES";
@ -457,6 +459,10 @@ void ConfigureTensorsToDetectionsCalculator(
// Inputs:
// IMAGE - Image
// Image to perform detection on.
// NORM_RECT - NormalizedRect @Optional
// Describes image rotation and region of image to perform detection
// on.
// @Optional: rect covering the whole image is used if not specified.
//
// Outputs:
// DETECTIONS - std::vector<Detection>
@ -494,9 +500,10 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
Graph graph;
ASSIGN_OR_RETURN(
auto output_streams,
BuildObjectDetectionTask(sc->Options<ObjectDetectorOptionsProto>(),
*model_resources,
graph[Input<Image>(kImageTag)], graph));
BuildObjectDetectionTask(
sc->Options<ObjectDetectorOptionsProto>(), *model_resources,
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
output_streams.detections >>
graph[Output<std::vector<Detection>>(kDetectionsTag)];
output_streams.image >> graph[Output<Image>(kImageTag)];
@ -519,7 +526,7 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
absl::StatusOr<ObjectDetectionOutputStreams> BuildObjectDetectionTask(
const ObjectDetectorOptionsProto& task_options,
const core::ModelResources& model_resources, Source<Image> image_in,
Graph& graph) {
Source<NormalizedRect> norm_rect_in, Graph& graph) {
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
// Checks that the model has 4 outputs.
auto& model = *model_resources.GetTfLiteModel();
@ -559,6 +566,7 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
&preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
image_in >> preprocessing.In(kImageTag);
norm_rect_in >> preprocessing.In(kNormRectTag);
// Adds inference subgraph and connects its input stream to the output
// tensors produced by the ImageToTensorCalculator.

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/object_detector/object_detector.h"
#include <cmath>
#include <functional>
#include <memory>
#include <string>
@ -30,6 +31,7 @@ limitations under the License.
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
@ -519,6 +521,54 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
ExpectApproximatelyEqual(results, {full_expected_results[3]});
}
// Verifies that Detect() honors the 'rotation' field of the optional
// NormalizedRect argument: a 90° anti-clockwise rotation (M_PI / 2 radians)
// applied to a rotated test image should yield detections expressed in the
// unrotated input frame of reference.
TEST_F(ImageModeTest, SucceedsWithRotation) {
MP_ASSERT_OK_AND_ASSIGN(
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
"cats_and_dogs_rotated.jpg")));
auto options = std::make_unique<ObjectDetectorOptions>();
// Restrict output to the single best "cat" detection to keep the golden
// expectation below deterministic.
options->max_results = 1;
options->category_allowlist.push_back("cat");
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
ObjectDetector::Create(std::move(options)));
// Only 'rotation' is set; coordinate fields are left unset since ObjectDetector
// rejects region-of-interest (see FailsWithRegionOfInterest).
NormalizedRect image_processing_options;
image_processing_options.set_rotation(M_PI / 2.0);
MP_ASSERT_OK_AND_ASSIGN(
auto results, object_detector->Detect(image, image_processing_options));
MP_ASSERT_OK(object_detector->Close());
// Golden result: bounding box is in the unrotated image's coordinate system.
ExpectApproximatelyEqual(
results, {ParseTextProtoOrDie<Detection>(R"pb(
label: "cat"
score: 0.7109375
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 0 ymin: 622 width: 436 height: 276 }
})pb")});
}
// Verifies that passing a NormalizedRect with region-of-interest coordinates
// (x_center/y_center/width/height) makes Detect() fail with
// kInvalidArgument, since ObjectDetector only supports the 'rotation' field.
TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
MP_ASSERT_OK_AND_ASSIGN(Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
"cats_and_dogs.jpg")));
auto options = std::make_unique<ObjectDetectorOptions>();
options->max_results = 1;
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
ObjectDetector::Create(std::move(options)));
// Even a full-image ROI (centered, width/height 1.0) must be rejected: the
// API distinguishes "fields unset" from "fields explicitly set".
NormalizedRect image_processing_options;
image_processing_options.set_x_center(0.5);
image_processing_options.set_y_center(0.5);
image_processing_options.set_width(1.0);
image_processing_options.set_height(1.0);
auto results = object_detector->Detect(image, image_processing_options);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(),
HasSubstr("ObjectDetector does not support region-of-interest"));
}
class VideoModeTest : public tflite_shims::testing::Test {};
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {

View File

@ -30,6 +30,7 @@ mediapipe_files(srcs = [
"cat_mask.jpg",
"cats_and_dogs.jpg",
"cats_and_dogs_no_resizing.jpg",
"cats_and_dogs_rotated.jpg",
"coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite",
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
@ -79,6 +80,7 @@ filegroup(
"cat_mask.jpg",
"cats_and_dogs.jpg",
"cats_and_dogs_no_resizing.jpg",
"cats_and_dogs_rotated.jpg",
"hand_landmark_full.tflite",
"hand_landmark_lite.tflite",
"left_hands.jpg",

View File

@ -46,6 +46,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=1661875663693976"],
)
http_file(
name = "com_google_mediapipe_BUILD_orig",
sha256 = "650df617b3e125e0890f1b8c936cc64c9d975707f57e616b6430fc667ce315d4",
urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD.orig?generation=1665609930388174"],
)
http_file(
name = "com_google_mediapipe_burger_crop_jpg",
sha256 = "8f58de573f0bf59a49c3d86cfabb9ad4061481f574aa049177e8da3963dddc50",
@ -88,6 +94,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs_no_resizing.jpg?generation=1661875687251296"],
)
http_file(
name = "com_google_mediapipe_cats_and_dogs_rotated_jpg",
sha256 = "5384926d16ddd8802555ae3108deedefb42a2ea78d99e5ad0933c5e11f43244a",
urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs_rotated.jpg?generation=1665609933260747"],
)
http_file(
name = "com_google_mediapipe_classification_tensor_float_meta_json",
sha256 = "1d10b1c9c87eabac330651136804074ddc134779e94a73cf783207c3aa2a5619",