From a0efcb47f23666f84448d82fcede6dab9fdfbf55 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 7 Dec 2022 16:37:08 -0800 Subject: [PATCH] internal change PiperOrigin-RevId: 493742399 --- .../tasks/cc/components/containers/BUILD | 13 + .../components/containers/detection_result.cc | 73 ++++++ .../components/containers/detection_result.h | 52 ++++ .../tasks/cc/components/containers/rect.cc | 34 +++ .../tasks/cc/components/containers/rect.h | 29 ++- .../cc/vision/core/base_vision_task_api.h | 4 +- .../cc/vision/core/image_processing_options.h | 3 +- ...hand_landmarks_deduplication_calculator.cc | 14 +- .../hand_landmarker/hand_landmarker_test.cc | 4 +- .../image_classifier/image_classifier_test.cc | 19 +- .../image_embedder/image_embedder_test.cc | 6 +- .../image_segmenter/image_segmenter_test.cc | 4 +- .../tasks/cc/vision/object_detector/BUILD | 1 + .../vision/object_detector/object_detector.cc | 15 +- .../vision/object_detector/object_detector.h | 12 +- .../object_detector/object_detector_test.cc | 227 ++++++++++-------- .../tasks/cc/vision/utils/landmarks_utils.cc | 8 +- .../tasks/cc/vision/utils/landmarks_utils.h | 10 +- 18 files changed, 377 insertions(+), 151 deletions(-) create mode 100644 mediapipe/tasks/cc/components/containers/detection_result.cc create mode 100644 mediapipe/tasks/cc/components/containers/detection_result.h create mode 100644 mediapipe/tasks/cc/components/containers/rect.cc diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index 35d3f4785..0750a1482 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -18,6 +18,7 @@ licenses(["notice"]) cc_library( name = "rect", + srcs = ["rect.cc"], hdrs = ["rect.h"], ) @@ -41,6 +42,18 @@ cc_library( ], ) +cc_library( + name = "detection_result", + srcs = ["detection_result.cc"], + hdrs = ["detection_result.h"], + deps = [ + ":category", + ":rect", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", + ], +) + cc_library( name = "embedding_result", srcs = ["embedding_result.cc"], diff --git a/mediapipe/tasks/cc/components/containers/detection_result.cc b/mediapipe/tasks/cc/components/containers/detection_result.cc new file mode 100644 index 000000000..43c8ca0f5 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/detection_result.cc @@ -0,0 +1,73 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/components/containers/detection_result.h" + +#include + +#include +#include +#include + +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location_data.pb.h" +#include "mediapipe/tasks/cc/components/containers/category.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" + +namespace mediapipe::tasks::components::containers { + +constexpr int kDefaultCategoryIndex = -1; + +Detection ConvertToDetectionResult( + const mediapipe::Detection& detection_proto) { + Detection detection; + for (int idx = 0; idx < detection_proto.score_size(); ++idx) { + detection.categories.push_back( + {/* index= */ detection_proto.label_id_size() > idx + ? detection_proto.label_id(idx) + : kDefaultCategoryIndex, + /* score= */ detection_proto.score(idx), + /* category_name */ detection_proto.label_size() > idx + ? detection_proto.label(idx) + : "", + /* display_name */ detection_proto.display_name_size() > idx + ? detection_proto.display_name(idx) + : ""}); + } + Rect bounding_box; + if (detection_proto.location_data().has_bounding_box()) { + mediapipe::LocationData::BoundingBox bounding_box_proto = + detection_proto.location_data().bounding_box(); + bounding_box.left = bounding_box_proto.xmin(); + bounding_box.top = bounding_box_proto.ymin(); + bounding_box.right = bounding_box_proto.xmin() + bounding_box_proto.width(); + bounding_box.bottom = + bounding_box_proto.ymin() + bounding_box_proto.height(); + } + detection.bounding_box = bounding_box; + return detection; +} + +DetectionResult ConvertToDetectionResult( + std::vector detections_proto) { + DetectionResult detection_result; + detection_result.detections.reserve(detections_proto.size()); + for (const auto& detection_proto : detections_proto) { + detection_result.detections.push_back( + ConvertToDetectionResult(detection_proto)); + } + return detection_result; +} +} // namespace mediapipe::tasks::components::containers diff --git a/mediapipe/tasks/cc/components/containers/detection_result.h b/mediapipe/tasks/cc/components/containers/detection_result.h new file mode 100644 index 000000000..546f324d6 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/detection_result.h @@ -0,0 +1,52 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_ + +#include +#include +#include + +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/tasks/cc/components/containers/category.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" + +namespace mediapipe::tasks::components::containers { + +// Detection for a single bounding box. +struct Detection { + // A vector of detected categories. + std::vector categories; + // The bounding box location. + Rect bounding_box; +}; + +// Detection results of a model. +struct DetectionResult { + // A vector of Detections. + std::vector detections; +}; + +// Utility function to convert from Detection proto to Detection struct. +Detection ConvertToDetection(const mediapipe::Detection& detection_proto); + +// Utility function to convert from list of Detection proto to DetectionResult +// struct. +DetectionResult ConvertToDetectionResult( + std::vector detections_proto); + +} // namespace mediapipe::tasks::components::containers +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_ diff --git a/mediapipe/tasks/cc/components/containers/rect.cc b/mediapipe/tasks/cc/components/containers/rect.cc new file mode 100644 index 000000000..4a94832a6 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/rect.cc @@ -0,0 +1,34 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/components/containers/rect.h" + +namespace mediapipe::tasks::components::containers { + +RectF ToRectF(const Rect& rect, int image_height, int image_width) { + return RectF{static_cast(rect.left) / image_width, + static_cast(rect.top) / image_height, + static_cast(rect.right) / image_width, + static_cast(rect.bottom) / image_height}; +} + +Rect ToRect(const RectF& rect, int image_height, int image_width) { + return Rect{static_cast(rect.left * image_width), + static_cast(rect.top * image_height), + static_cast(rect.right * image_width), + static_cast(rect.bottom * image_height)}; +} + +} // namespace mediapipe::tasks::components::containers diff --git a/mediapipe/tasks/cc/components/containers/rect.h b/mediapipe/tasks/cc/components/containers/rect.h index 3f5432cf2..551d91588 100644 --- a/mediapipe/tasks/cc/components/containers/rect.h +++ b/mediapipe/tasks/cc/components/containers/rect.h @@ -16,20 +16,47 @@ limitations under the License. #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ +#include + namespace mediapipe::tasks::components::containers { +constexpr float kRectFTolerance = 1e-4; + // Defines a rectangle, used e.g. as part of detection results or as input // region-of-interest. // +struct Rect { + int left; + int top; + int right; + int bottom; +}; + +inline bool operator==(const Rect& lhs, const Rect& rhs) { + return lhs.left == rhs.left && lhs.top == rhs.top && lhs.right == rhs.right && + lhs.bottom == rhs.bottom; +} + // The coordinates are normalized wrt the image dimensions, i.e. generally in // [0,1] but they may exceed these bounds if describing a region overlapping the // image. The origin is on the top-left corner of the image. -struct Rect { +struct RectF { float left; float top; float right; float bottom; }; +inline bool operator==(const RectF& lhs, const RectF& rhs) { + return abs(lhs.left - rhs.left) < kRectFTolerance && + abs(lhs.top - rhs.top) < kRectFTolerance && + abs(lhs.right - rhs.right) < kRectFTolerance && + abs(lhs.bottom - rhs.bottom) < kRectFTolerance; +} + +RectF ToRectF(const Rect& rect, int image_height, int image_width); + +Rect ToRect(const RectF& rect, int image_height, int image_width); + } // namespace mediapipe::tasks::components::containers #endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ diff --git a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h index c3c0a0261..a86b2cca8 100644 --- a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h +++ b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h @@ -129,13 +129,13 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi { if (roi.left >= roi.right || roi.top >= roi.bottom) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, - "Expected Rect with left < right and top < bottom.", + "Expected RectF with left < right and top < bottom.", MediaPipeTasksStatus::kImageProcessingInvalidArgumentError); } if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, - "Expected Rect values to be in [0,1].", + "Expected RectF values to be in [0,1].", MediaPipeTasksStatus::kImageProcessingInvalidArgumentError); } normalized_rect.set_x_center((roi.left + roi.right) / 2.0); diff --git a/mediapipe/tasks/cc/vision/core/image_processing_options.h b/mediapipe/tasks/cc/vision/core/image_processing_options.h index 7e764c1fe..1983272fc 100644 --- a/mediapipe/tasks/cc/vision/core/image_processing_options.h +++ b/mediapipe/tasks/cc/vision/core/image_processing_options.h @@ -35,7 +35,8 @@ struct ImageProcessingOptions { // the full image is used. // // Coordinates must be in [0,1] with 'left' < 'right' and 'top' < bottom. - std::optional region_of_interest = std::nullopt; + std::optional region_of_interest = + std::nullopt; // The rotation to apply to the image (or cropped region-of-interest), in // degrees clockwise. diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc index 564184c64..266ce223f 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc @@ -44,7 +44,7 @@ namespace { using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::utils::CalculateIOU; using ::mediapipe::tasks::vision::utils::DuplicatesFinder; @@ -126,7 +126,7 @@ absl::StatusOr HandBaselineDistance( return distance; } -Rect CalculateBound(const NormalizedLandmarkList& list) { +RectF CalculateBound(const NormalizedLandmarkList& list) { constexpr float kMinInitialValue = std::numeric_limits::max(); constexpr float kMaxInitialValue = std::numeric_limits::lowest(); @@ -144,10 +144,10 @@ Rect CalculateBound(const NormalizedLandmarkList& list) { } // Populate normalized non rotated face bounding box - return Rect{/*left=*/bounding_box_left, - /*top=*/bounding_box_top, - /*right=*/bounding_box_right, - /*bottom=*/bounding_box_bottom}; + return RectF{/*left=*/bounding_box_left, + /*top=*/bounding_box_top, + /*right=*/bounding_box_right, + /*bottom=*/bounding_box_bottom}; } // Uses IoU and distance of some corresponding hand landmarks to detect @@ -172,7 +172,7 @@ class HandDuplicatesFinder : public DuplicatesFinder { const int num = multi_landmarks.size(); std::vector baseline_distances; baseline_distances.reserve(num); - std::vector bounds; + std::vector bounds; bounds.reserve(num); for (const NormalizedLandmarkList& list : multi_landmarks) { ASSIGN_OR_RETURN(const float baseline_distance, diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc index fa49a4c1f..94d1b1c12 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc @@ -50,7 +50,7 @@ namespace { using ::file::Defaults; using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::EqualsProto; @@ -188,7 +188,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { options->running_mode = core::RunningMode::IMAGE; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr hand_landmarker, HandLandmarker::Create(std::move(options))); - Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; auto results = hand_landmarker->Detect(image, image_processing_options); diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc index 1144e9032..7aa2a148c 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc @@ -52,7 +52,7 @@ namespace { using ::mediapipe::file::JoinPath; using ::mediapipe::tasks::components::containers::Category; using ::mediapipe::tasks::components::containers::Classifications; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -472,7 +472,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); // Region-of-interest around the soccer ball. - Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; + RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( @@ -526,7 +526,8 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); // Region-of-interest around the chair, with 90° anti-clockwise rotation. - Rect roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702, /*bottom=*/0.3049}; + RectF roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702, + /*bottom=*/0.3049}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/-90}; @@ -554,13 +555,13 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) { ImageClassifier::Create(std::move(options))); // Invalid: left > right. - Rect roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1}; + RectF roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; auto results = image_classifier->Classify(image, image_processing_options); EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(results.status().message(), - HasSubstr("Expected Rect with left < right and top < bottom")); + HasSubstr("Expected RectF with left < right and top < bottom")); EXPECT_THAT( results.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( @@ -573,7 +574,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) { results = image_classifier->Classify(image, image_processing_options); EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(results.status().message(), - HasSubstr("Expected Rect with left < right and top < bottom")); + HasSubstr("Expected RectF with left < right and top < bottom")); EXPECT_THAT( results.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( @@ -586,7 +587,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) { results = image_classifier->Classify(image, image_processing_options); EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(results.status().message(), - HasSubstr("Expected Rect values to be in [0,1]")); + HasSubstr("Expected RectF values to be in [0,1]")); EXPECT_THAT( results.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( @@ -695,7 +696,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) { ImageClassifier::Create(std::move(options))); // Crop around the soccer ball. // Region-of-interest around the soccer ball. - Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; + RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; for (int i = 0; i < iterations; ++i) { @@ -837,7 +838,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); // Crop around the soccer ball. - Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; + RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; for (int i = 0; i < iterations; ++i) { diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc index 6098a9a70..dd602bef5 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc @@ -41,7 +41,7 @@ namespace image_embedder { namespace { using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -320,7 +320,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { Image crop, DecodeImageFromFile( JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); // Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg". - Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1}; + RectF roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; // Extract both embeddings. @@ -388,7 +388,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger_rotated.jpg"))); // Region-of-interest corresponding to burger_crop.jpg. - Rect roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333}; + RectF roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/-90}; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index d5ea088a1..f9618c1b1 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -47,7 +47,7 @@ namespace { using ::mediapipe::Image; using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -299,7 +299,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; auto results = segmenter->Segment(image, image_processing_options); diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index 77373303a..5269796ae 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -33,6 +33,7 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", + "//mediapipe/tasks/cc/components/containers:detection_result", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc index dd19237ff..e0222dd70 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc @@ -30,6 +30,7 @@ limitations under the License. #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -56,6 +57,7 @@ constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.ObjectDetectorGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::tasks::components::containers::ConvertToDetectionResult; using ObjectDetectorOptionsProto = object_detector::proto::ObjectDetectorOptions; @@ -129,7 +131,8 @@ absl::StatusOr> ObjectDetector::Create( Packet detections_packet = status_or_packets.value()[kDetectionsOutStreamName]; Packet image_packet = status_or_packets.value()[kImageOutStreamName]; - result_callback(detections_packet.Get>(), + result_callback(ConvertToDetectionResult( + detections_packet.Get>()), image_packet.Get(), detections_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); @@ -144,7 +147,7 @@ absl::StatusOr> ObjectDetector::Create( std::move(packets_callback)); } -absl::StatusOr> ObjectDetector::Detect( +absl::StatusOr ObjectDetector::Detect( mediapipe::Image image, std::optional image_processing_options) { if (image.UsesGpu()) { @@ -161,10 +164,11 @@ absl::StatusOr> ObjectDetector::Detect( ProcessImageData( {{kImageInStreamName, MakePacket(std::move(image))}, {kNormRectName, MakePacket(std::move(norm_rect))}})); - return output_packets[kDetectionsOutStreamName].Get>(); + return ConvertToDetectionResult( + output_packets[kDetectionsOutStreamName].Get>()); } -absl::StatusOr> ObjectDetector::DetectForVideo( +absl::StatusOr ObjectDetector::DetectForVideo( mediapipe::Image image, int64 timestamp_ms, std::optional image_processing_options) { if (image.UsesGpu()) { @@ -185,7 +189,8 @@ absl::StatusOr> ObjectDetector::DetectForVideo( {kNormRectName, MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); - return output_packets[kDetectionsOutStreamName].Get>(); + return ConvertToDetectionResult( + output_packets[kDetectionsOutStreamName].Get>()); } absl::Status ObjectDetector::DetectAsync( diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.h b/mediapipe/tasks/cc/vision/object_detector/object_detector.h index 44ce68ed9..249a2ebf5 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.h +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/image.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" @@ -36,6 +37,10 @@ namespace mediapipe { namespace tasks { namespace vision { +// Alias the shared DetectionResult struct as result typo. +using ObjectDetectorResult = + ::mediapipe::tasks::components::containers::DetectionResult; + // The options for configuring a mediapipe object detector task. struct ObjectDetectorOptions { // Base options for configuring MediaPipe Tasks, such as specifying the TfLite @@ -79,8 +84,7 @@ struct ObjectDetectorOptions { // The user-defined result callback for processing live stream data. // The result callback should only be specified when the running mode is set // to RunningMode::LIVE_STREAM. - std::function>, - const Image&, int64)> + std::function, const Image&, int64)> result_callback = nullptr; }; @@ -165,7 +169,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // underlying image data. // TODO: Describes the output bounding boxes for gpu input // images after enabling the gpu support in MediaPipe Tasks. - absl::StatusOr> Detect( + absl::StatusOr Detect( mediapipe::Image image, std::optional image_processing_options = std::nullopt); @@ -188,7 +192,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // unrotated input frame of reference coordinates system, i.e. in `[0, // image_width) x [0, image_height)`, which are the dimensions of the // underlying image data. - absl::StatusOr> DetectForVideo( + absl::StatusOr DetectForVideo( mediapipe::Image image, int64 timestamp_ms, std::optional image_processing_options = std::nullopt); diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc index 1747685dd..798e3f238 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc @@ -35,6 +35,7 @@ limitations under the License. #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" #include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" @@ -65,10 +66,14 @@ namespace vision { namespace { using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::ConvertToDetectionResult; +using ::mediapipe::tasks::components::containers::Detection; +using ::mediapipe::tasks::components::containers::DetectionResult; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; +using DetectionProto = mediapipe::Detection; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; constexpr char kMobileSsdWithMetadata[] = @@ -83,47 +88,45 @@ constexpr char kEfficientDetWithMetadata[] = // Checks that the two provided `Detection` proto vectors are equal, with a // tolerancy on floating-point scores to account for numerical instabilities. // If the proto definition changes, please also change this function. -void ExpectApproximatelyEqual(const std::vector& actual, - const std::vector& expected) { +void ExpectApproximatelyEqual(const ObjectDetectorResult& actual, + const ObjectDetectorResult& expected) { const float kPrecision = 1e-6; - EXPECT_EQ(actual.size(), expected.size()); - for (int i = 0; i < actual.size(); ++i) { - const Detection& a = actual[i]; - const Detection& b = expected[i]; - EXPECT_THAT(a.location_data().bounding_box(), - EqualsProto(b.location_data().bounding_box())); - EXPECT_EQ(a.label_size(), 1); - EXPECT_EQ(b.label_size(), 1); - EXPECT_EQ(a.label(0), b.label(0)); - EXPECT_EQ(a.score_size(), 1); - EXPECT_EQ(b.score_size(), 1); - EXPECT_NEAR(a.score(0), b.score(0), kPrecision); + EXPECT_EQ(actual.detections.size(), expected.detections.size()); + for (int i = 0; i < actual.detections.size(); ++i) { + const Detection& a = actual.detections[i]; + const Detection& b = expected.detections[i]; + EXPECT_EQ(a.bounding_box, b.bounding_box); + EXPECT_EQ(a.categories.size(), 1); + EXPECT_EQ(b.categories.size(), 1); + EXPECT_EQ(a.categories[0].category_name, b.categories[0].category_name); + EXPECT_NEAR(a.categories[0].score, b.categories[0].score, kPrecision); } } -std::vector GenerateMobileSsdNoImageResizingFullExpectedResults() { - return {ParseTextProtoOrDie(R"pb( +std::vector +GenerateMobileSsdNoImageResizingFullExpectedResults() { + return {ParseTextProtoOrDie(R"pb( label: "cat" score: 0.6328125 location_data { format: BOUNDING_BOX bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 } })pb"), - ParseTextProtoOrDie(R"pb( + ParseTextProtoOrDie(R"pb( label: "cat" score: 0.59765625 location_data { format: BOUNDING_BOX bounding_box { xmin: 151 ymin: 78 width: 104 height: 223 } })pb"), - ParseTextProtoOrDie(R"pb( + ParseTextProtoOrDie(R"pb( label: "cat" score: 0.5 location_data { format: BOUNDING_BOX bounding_box { xmin: 65 ymin: 199 width: 41 height: 101 } })pb"), - ParseTextProtoOrDie(R"pb( + ParseTextProtoOrDie(R"pb( label: "dog" score: 0.48828125 location_data { @@ -263,8 +266,8 @@ TEST_F(CreateFromOptionsTest, FailsWithIllegalCallbackInImageOrVideoMode) { JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->running_mode = running_mode; options->result_callback = - [](absl::StatusOr> detections, - const Image& image, int64 timestamp_ms) {}; + [](absl::StatusOr detections, const Image& image, + int64 timestamp_ms) {}; absl::StatusOr> object_detector = ObjectDetector::Create(std::move(options)); EXPECT_EQ(object_detector.status().code(), @@ -340,34 +343,36 @@ TEST_F(ImageModeTest, Succeeds) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, {ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.69921875 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.64453125 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.51171875 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.48828125 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 } - })pb")}); + results, + ConvertToDetectionResult( + {ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.69921875 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.64453125 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.51171875 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.48828125 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 } + })pb")})); } TEST_F(ImageModeTest, SucceedsEfficientDetModel) { @@ -383,34 +388,36 @@ TEST_F(ImageModeTest, SucceedsEfficientDetModel) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, {ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.7578125 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 858 ymin: 408 width: 225 height: 187 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.72265625 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 67 ymin: 401 width: 399 height: 192 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.6289063 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 368 ymin: 210 width: 272 height: 385 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.5859375 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 601 ymin: 166 width: 298 height: 437 } - })pb")}); + results, + ConvertToDetectionResult( + {ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.7578125 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 858 ymin: 408 width: 225 height: 187 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.72265625 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 67 ymin: 401 width: 399 height: 192 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.6289063 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 368 ymin: 210 width: 272 height: 385 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.5859375 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 601 ymin: 166 width: 298 height: 437 } + })pb")})); } TEST_F(ImageModeTest, SucceedsWithoutImageResizing) { @@ -426,7 +433,8 @@ TEST_F(ImageModeTest, SucceedsWithoutImageResizing) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, GenerateMobileSsdNoImageResizingFullExpectedResults()); + results, ConvertToDetectionResult( + GenerateMobileSsdNoImageResizingFullExpectedResults())); } TEST_F(ImageModeTest, SucceedsWithScoreCalibration) { @@ -442,13 +450,14 @@ TEST_F(ImageModeTest, SucceedsWithScoreCalibration) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, {ParseTextProtoOrDie(R"pb( + results, + ConvertToDetectionResult({ParseTextProtoOrDie(R"pb( label: "cat" score: 0.6531269142 location_data { format: BOUNDING_BOX bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 } - })pb")}); + })pb")})); } TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) { @@ -463,11 +472,13 @@ TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) { ObjectDetector::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); - ExpectApproximatelyEqual(results, - {full_expected_results[0], full_expected_results[1], - full_expected_results[2]}); + + ExpectApproximatelyEqual( + results, ConvertToDetectionResult({full_expected_results[0], + full_expected_results[1], + full_expected_results[2]})); } TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) { @@ -482,10 +493,11 @@ TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) { ObjectDetector::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); ExpectApproximatelyEqual( - results, {full_expected_results[0], full_expected_results[1]}); + results, ConvertToDetectionResult( + {full_expected_results[0], full_expected_results[1]})); } TEST_F(ImageModeTest, SucceedsWithAllowlistOption) { @@ -501,9 +513,10 @@ TEST_F(ImageModeTest, SucceedsWithAllowlistOption) { ObjectDetector::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); - ExpectApproximatelyEqual(results, {full_expected_results[3]}); + ExpectApproximatelyEqual( + results, ConvertToDetectionResult({full_expected_results[3]})); } TEST_F(ImageModeTest, SucceedsWithDenylistOption) { @@ -519,9 +532,10 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) { ObjectDetector::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); - ExpectApproximatelyEqual(results, {full_expected_results[3]}); + ExpectApproximatelyEqual( + results, ConvertToDetectionResult({full_expected_results[3]})); } TEST_F(ImageModeTest, SucceedsWithRotation) { @@ -541,13 +555,14 @@ TEST_F(ImageModeTest, SucceedsWithRotation) { auto results, object_detector->Detect(image, image_processing_options)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, {ParseTextProtoOrDie(R"pb( + results, + ConvertToDetectionResult({ParseTextProtoOrDie(R"pb( label: "cat" score: 0.7109375 location_data { format: BOUNDING_BOX bounding_box { xmin: 0 ymin: 622 width: 436 height: 276 } - })pb")}); + })pb")})); } TEST_F(ImageModeTest, FailsWithRegionOfInterest) { @@ -560,7 +575,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); - Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; auto results = object_detector->Detect(image, image_processing_options); @@ -619,10 +634,11 @@ TEST_F(VideoModeTest, Succeeds) { for (int i = 0; i < iterations; ++i) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->DetectForVideo(image, i)); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); ExpectApproximatelyEqual( - results, {full_expected_results[0], full_expected_results[1]}); + results, ConvertToDetectionResult( + {full_expected_results[0], full_expected_results[1]})); } MP_ASSERT_OK(object_detector->Close()); } @@ -637,9 +653,8 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->running_mode = core::RunningMode::LIVE_STREAM; - options->result_callback = - [](absl::StatusOr> detections, const Image& image, - int64 timestamp_ms) {}; + options->result_callback = [](absl::StatusOr detections, + const Image& image, int64 timestamp_ms) {}; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); @@ -669,9 +684,8 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) { options->running_mode = core::RunningMode::LIVE_STREAM; options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); - options->result_callback = - [](absl::StatusOr> detections, const Image& image, - int64 timestamp_ms) {}; + options->result_callback = [](absl::StatusOr detections, + const Image& image, int64 timestamp_ms) {}; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); MP_ASSERT_OK(object_detector->DetectAsync(image, 1)); @@ -695,14 +709,14 @@ TEST_F(LiveStreamModeTest, Succeeds) { auto options = std::make_unique(); options->max_results = 2; options->running_mode = core::RunningMode::LIVE_STREAM; - std::vector> detection_results; + std::vector detection_results; std::vector> image_sizes; std::vector timestamps; options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->result_callback = [&detection_results, &image_sizes, ×tamps]( - absl::StatusOr> detections, const Image& image, + absl::StatusOr detections, const Image& image, int64 timestamp_ms) { MP_ASSERT_OK(detections.status()); detection_results.push_back(std::move(detections).value()); @@ -719,11 +733,12 @@ TEST_F(LiveStreamModeTest, Succeeds) { // number of iterations. ASSERT_LE(detection_results.size(), iterations); ASSERT_GT(detection_results.size(), 0); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); for (const auto& detection_result : detection_results) { ExpectApproximatelyEqual( - detection_result, {full_expected_results[0], full_expected_results[1]}); + detection_result, ConvertToDetectionResult({full_expected_results[0], + full_expected_results[1]})); } for (const auto& image_size : image_sizes) { EXPECT_EQ(image_size.first, image.width()); diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc b/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc index 2ce9e2454..fe4e63824 100644 --- a/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc +++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc @@ -22,13 +22,13 @@ limitations under the License. namespace mediapipe::tasks::vision::utils { -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; -float CalculateArea(const Rect& rect) { +float CalculateArea(const RectF& rect) { return (rect.right - rect.left) * (rect.bottom - rect.top); } -float CalculateIntersectionArea(const Rect& a, const Rect& b) { +float CalculateIntersectionArea(const RectF& a, const RectF& b) { const float intersection_left = std::max(a.left, b.left); const float intersection_top = std::max(a.top, b.top); const float intersection_right = std::min(a.right, b.right); @@ -38,7 +38,7 @@ float CalculateIntersectionArea(const Rect& a, const Rect& b) { std::max(intersection_right - intersection_left, 0.0); } -float CalculateIOU(const Rect& a, const Rect& b) { +float CalculateIOU(const RectF& a, const RectF& b) { const float area_a = CalculateArea(a); const float area_b = CalculateArea(b); if (area_a <= 0 || area_b <= 0) return 0.0; diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils.h b/mediapipe/tasks/cc/vision/utils/landmarks_utils.h index 73114d2ef..4d1fac62f 100644 --- a/mediapipe/tasks/cc/vision/utils/landmarks_utils.h +++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils.h @@ -27,15 +27,15 @@ limitations under the License. namespace mediapipe::tasks::vision::utils { // Calculates intersection over union for two bounds. -float CalculateIOU(const components::containers::Rect& a, - const components::containers::Rect& b); +float CalculateIOU(const components::containers::RectF& a, + const components::containers::RectF& b); // Calculates area for face bound -float CalculateArea(const components::containers::Rect& rect); +float CalculateArea(const components::containers::RectF& rect); // Calucates intersection area of two face bounds -float CalculateIntersectionArea(const components::containers::Rect& a, - const components::containers::Rect& b); +float CalculateIntersectionArea(const components::containers::RectF& a, + const components::containers::RectF& b); } // namespace mediapipe::tasks::vision::utils #endif // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_