internal change

PiperOrigin-RevId: 493742399
This commit is contained in:
MediaPipe Team 2022-12-07 16:37:08 -08:00 committed by Copybara-Service
parent a59f0a9924
commit a0efcb47f2
18 changed files with 377 additions and 151 deletions

View File

@ -18,6 +18,7 @@ licenses(["notice"])
cc_library( cc_library(
name = "rect", name = "rect",
srcs = ["rect.cc"],
hdrs = ["rect.h"], hdrs = ["rect.h"],
) )
@ -41,6 +42,18 @@ cc_library(
], ],
) )
# C++ container structs for detection results, together with the converters
# that build them from the Detection proto (hence the proto deps below).
cc_library(
name = "detection_result",
srcs = ["detection_result.cc"],
hdrs = ["detection_result.h"],
deps = [
":category",
":rect",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:location_data_cc_proto",
],
)
cc_library( cc_library(
name = "embedding_result", name = "embedding_result",
srcs = ["embedding_result.cc"], srcs = ["embedding_result.cc"],

View File

@ -0,0 +1,73 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
#include <strings.h>
#include <optional>
#include <string>
#include <vector>
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
namespace mediapipe::tasks::components::containers {
// Category index reported for a score that has no matching label id.
constexpr int kDefaultCategoryIndex = -1;

// Converts a single Detection proto into the Detection container struct.
//
// One Category is produced per score in the proto. The label id, label and
// display name stored at the same position are used when present; otherwise
// the index falls back to kDefaultCategoryIndex and the names to "".
//
// The bounding box is read from the proto's location data as
// (left, top, right, bottom) = (xmin, ymin, xmin + width, ymin + height). It is
// all zeros when the proto carries no bounding box.
//
// This is the definition of the function declared in detection_result.h.
Detection ConvertToDetection(const mediapipe::Detection& detection_proto) {
  Detection detection;
  const int num_scores = detection_proto.score_size();
  detection.categories.reserve(num_scores);
  for (int idx = 0; idx < num_scores; ++idx) {
    detection.categories.push_back(
        {/* index= */ detection_proto.label_id_size() > idx
             ? detection_proto.label_id(idx)
             : kDefaultCategoryIndex,
         /* score= */ detection_proto.score(idx),
         /* category_name= */ detection_proto.label_size() > idx
             ? detection_proto.label(idx)
             : "",
         /* display_name= */ detection_proto.display_name_size() > idx
             ? detection_proto.display_name(idx)
             : ""});
  }

  // Value-initialized so that the result never exposes indeterminate ints when
  // the proto has no bounding box.
  Rect bounding_box{};
  if (detection_proto.location_data().has_bounding_box()) {
    // Bind by reference: the sub-message lives as long as `detection_proto`.
    const mediapipe::LocationData::BoundingBox& bounding_box_proto =
        detection_proto.location_data().bounding_box();
    bounding_box.left = bounding_box_proto.xmin();
    bounding_box.top = bounding_box_proto.ymin();
    bounding_box.right = bounding_box_proto.xmin() + bounding_box_proto.width();
    bounding_box.bottom =
        bounding_box_proto.ymin() + bounding_box_proto.height();
  }
  detection.bounding_box = bounding_box;
  return detection;
}

// Single-proto overload kept under its previous name so that any existing
// caller keeps working; it simply forwards to ConvertToDetection().
Detection ConvertToDetectionResult(
    const mediapipe::Detection& detection_proto) {
  return ConvertToDetection(detection_proto);
}

// Converts a list of Detection protos into a DetectionResult, producing one
// Detection per input proto, in input order.
DetectionResult ConvertToDetectionResult(
    std::vector<mediapipe::Detection> detections_proto) {
  DetectionResult detection_result;
  detection_result.detections.reserve(detections_proto.size());
  for (const auto& detection_proto : detections_proto) {
    detection_result.detections.push_back(ConvertToDetection(detection_proto));
  }
  return detection_result;
}
} // namespace mediapipe::tasks::components::containers

View File

@ -0,0 +1,52 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
#include <optional>
#include <string>
#include <vector>
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
namespace mediapipe::tasks::components::containers {
// Detection for a single bounding box.
struct Detection {
  // A vector of detected categories.
  std::vector<Category> categories;
  // The bounding box location. Value-initialized so that a default-constructed
  // Detection has a well-defined (all-zero) box instead of indeterminate
  // values.
  Rect bounding_box = {};
};
// Detection results of a model: the collection of all Detections produced for
// one input.
struct DetectionResult {
// The individual detections; empty by default.
std::vector<Detection> detections;
};
// Utility function to convert a single Detection proto into the Detection
// container struct.
Detection ConvertToDetection(const mediapipe::Detection& detection_proto);
// Utility function to convert a list of Detection protos into a
// DetectionResult struct. The result holds one Detection per input proto, in
// input order. The vector is taken by value.
DetectionResult ConvertToDetectionResult(
std::vector<mediapipe::Detection> detections_proto);
} // namespace mediapipe::tasks::components::containers
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_

View File

@ -0,0 +1,34 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/containers/rect.h"
namespace mediapipe::tasks::components::containers {
// Builds a RectF from `rect` by dividing the horizontal edges (left, right) by
// `image_width` and the vertical edges (top, bottom) by `image_height`.
// Note the argument order: height first, then width.
RectF ToRectF(const Rect& rect, int image_height, int image_width) {
  RectF normalized;
  normalized.left = static_cast<float>(rect.left) / image_width;
  normalized.top = static_cast<float>(rect.top) / image_height;
  normalized.right = static_cast<float>(rect.right) / image_width;
  normalized.bottom = static_cast<float>(rect.bottom) / image_height;
  return normalized;
}
// Builds a Rect from `rect` by multiplying the horizontal edges (left, right)
// by `image_width` and the vertical edges (top, bottom) by `image_height`.
// Each product is converted with static_cast<int>, i.e. truncated toward zero.
// Note the argument order: height first, then width.
Rect ToRect(const RectF& rect, int image_height, int image_width) {
  Rect scaled;
  scaled.left = static_cast<int>(rect.left * image_width);
  scaled.top = static_cast<int>(rect.top * image_height);
  scaled.right = static_cast<int>(rect.right * image_width);
  scaled.bottom = static_cast<int>(rect.bottom * image_height);
  return scaled;
}
} // namespace mediapipe::tasks::components::containers

View File

@ -16,20 +16,47 @@ limitations under the License.
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
#include <cstdlib>
namespace mediapipe::tasks::components::containers { namespace mediapipe::tasks::components::containers {
// Per-coordinate absolute tolerance used when comparing two RectF for equality.
constexpr float kRectFTolerance = 1e-4;
// Defines a rectangle, used e.g. as part of detection results or as input // Defines a rectangle, used e.g. as part of detection results or as input
// region-of-interest. // region-of-interest.
// //
// Rectangle described by four integer edges.
// NOTE(review): presumably pixel coordinates, since ToRectF() divides them by
// the image dimensions - confirm against callers.
struct Rect {
  int left;
  int top;
  int right;
  int bottom;
};

// Two Rects compare equal only when all four edges match exactly.
inline bool operator==(const Rect& lhs, const Rect& rhs) {
  if (lhs.left != rhs.left || lhs.top != rhs.top) {
    return false;
  }
  return lhs.right == rhs.right && lhs.bottom == rhs.bottom;
}
// The coordinates are normalized wrt the image dimensions, i.e. generally in // The coordinates are normalized wrt the image dimensions, i.e. generally in
// [0,1] but they may exceed these bounds if describing a region overlapping the // [0,1] but they may exceed these bounds if describing a region overlapping the
// image. The origin is on the top-left corner of the image. // image. The origin is on the top-left corner of the image.
struct Rect { struct RectF {
float left; float left;
float top; float top;
float right; float right;
float bottom; float bottom;
}; };
// Approximate equality for RectF: true when every edge of `lhs` differs from
// the matching edge of `rhs` by less than kRectFTolerance.
//
// std::abs is qualified on purpose. With only <cstdlib> included, an
// unqualified abs() may resolve to the C function int abs(int), which would
// truncate the float difference to an integer and make rectangles that differ
// by less than 1.0 compare equal.
inline bool operator==(const RectF& lhs, const RectF& rhs) {
  return std::abs(lhs.left - rhs.left) < kRectFTolerance &&
         std::abs(lhs.top - rhs.top) < kRectFTolerance &&
         std::abs(lhs.right - rhs.right) < kRectFTolerance &&
         std::abs(lhs.bottom - rhs.bottom) < kRectFTolerance;
}
// Converts an integer Rect to a RectF by dividing left/right by `image_width`
// and top/bottom by `image_height`. Note that height precedes width in the
// argument list.
RectF ToRectF(const Rect& rect, int image_height, int image_width);
// Converts a RectF to an integer Rect by multiplying left/right by
// `image_width` and top/bottom by `image_height`; results are truncated toward
// zero. Note that height precedes width in the argument list.
Rect ToRect(const RectF& rect, int image_height, int image_width);
} // namespace mediapipe::tasks::components::containers } // namespace mediapipe::tasks::components::containers
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ #endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_

View File

@ -129,13 +129,13 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi {
if (roi.left >= roi.right || roi.top >= roi.bottom) { if (roi.left >= roi.right || roi.top >= roi.bottom) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
"Expected Rect with left < right and top < bottom.", "Expected RectF with left < right and top < bottom.",
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError); MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
} }
if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) { if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
"Expected Rect values to be in [0,1].", "Expected RectF values to be in [0,1].",
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError); MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
} }
normalized_rect.set_x_center((roi.left + roi.right) / 2.0); normalized_rect.set_x_center((roi.left + roi.right) / 2.0);

View File

@ -35,7 +35,8 @@ struct ImageProcessingOptions {
// the full image is used. // the full image is used.
// //
// Coordinates must be in [0,1] with 'left' < 'right' and 'top' < bottom. // Coordinates must be in [0,1] with 'left' < 'right' and 'top' < bottom.
std::optional<components::containers::Rect> region_of_interest = std::nullopt; std::optional<components::containers::RectF> region_of_interest =
std::nullopt;
// The rotation to apply to the image (or cropped region-of-interest), in // The rotation to apply to the image (or cropped region-of-interest), in
// degrees clockwise. // degrees clockwise.

View File

@ -44,7 +44,7 @@ namespace {
using ::mediapipe::api2::Input; using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output; using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::Rect; using ::mediapipe::tasks::components::containers::RectF;
using ::mediapipe::tasks::vision::utils::CalculateIOU; using ::mediapipe::tasks::vision::utils::CalculateIOU;
using ::mediapipe::tasks::vision::utils::DuplicatesFinder; using ::mediapipe::tasks::vision::utils::DuplicatesFinder;
@ -126,7 +126,7 @@ absl::StatusOr<float> HandBaselineDistance(
return distance; return distance;
} }
Rect CalculateBound(const NormalizedLandmarkList& list) { RectF CalculateBound(const NormalizedLandmarkList& list) {
constexpr float kMinInitialValue = std::numeric_limits<float>::max(); constexpr float kMinInitialValue = std::numeric_limits<float>::max();
constexpr float kMaxInitialValue = std::numeric_limits<float>::lowest(); constexpr float kMaxInitialValue = std::numeric_limits<float>::lowest();
@ -144,10 +144,10 @@ Rect CalculateBound(const NormalizedLandmarkList& list) {
} }
// Populate normalized non rotated face bounding box // Populate normalized non rotated face bounding box
return Rect{/*left=*/bounding_box_left, return RectF{/*left=*/bounding_box_left,
/*top=*/bounding_box_top, /*top=*/bounding_box_top,
/*right=*/bounding_box_right, /*right=*/bounding_box_right,
/*bottom=*/bounding_box_bottom}; /*bottom=*/bounding_box_bottom};
} }
// Uses IoU and distance of some corresponding hand landmarks to detect // Uses IoU and distance of some corresponding hand landmarks to detect
@ -172,7 +172,7 @@ class HandDuplicatesFinder : public DuplicatesFinder {
const int num = multi_landmarks.size(); const int num = multi_landmarks.size();
std::vector<float> baseline_distances; std::vector<float> baseline_distances;
baseline_distances.reserve(num); baseline_distances.reserve(num);
std::vector<Rect> bounds; std::vector<RectF> bounds;
bounds.reserve(num); bounds.reserve(num);
for (const NormalizedLandmarkList& list : multi_landmarks) { for (const NormalizedLandmarkList& list : multi_landmarks) {
ASSIGN_OR_RETURN(const float baseline_distance, ASSIGN_OR_RETURN(const float baseline_distance,

View File

@ -50,7 +50,7 @@ namespace {
using ::file::Defaults; using ::file::Defaults;
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect; using ::mediapipe::tasks::components::containers::RectF;
using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult; using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::EqualsProto; using ::testing::EqualsProto;
@ -188,7 +188,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
options->running_mode = core::RunningMode::IMAGE; options->running_mode = core::RunningMode::IMAGE;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
HandLandmarker::Create(std::move(options))); HandLandmarker::Create(std::move(options)));
Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
auto results = hand_landmarker->Detect(image, image_processing_options); auto results = hand_landmarker->Detect(image, image_processing_options);

View File

@ -52,7 +52,7 @@ namespace {
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Category; using ::mediapipe::tasks::components::containers::Category;
using ::mediapipe::tasks::components::containers::Classifications; using ::mediapipe::tasks::components::containers::Classifications;
using ::mediapipe::tasks::components::containers::Rect; using ::mediapipe::tasks::components::containers::RectF;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::Optional; using ::testing::Optional;
@ -472,7 +472,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options))); ImageClassifier::Create(std::move(options)));
// Region-of-interest around the soccer ball. // Region-of-interest around the soccer ball.
Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
@ -526,7 +526,8 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options))); ImageClassifier::Create(std::move(options)));
// Region-of-interest around the chair, with 90° anti-clockwise rotation. // Region-of-interest around the chair, with 90° anti-clockwise rotation.
Rect roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702, /*bottom=*/0.3049}; RectF roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702,
/*bottom=*/0.3049};
ImageProcessingOptions image_processing_options{roi, ImageProcessingOptions image_processing_options{roi,
/*rotation_degrees=*/-90}; /*rotation_degrees=*/-90};
@ -554,13 +555,13 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) {
ImageClassifier::Create(std::move(options))); ImageClassifier::Create(std::move(options)));
// Invalid: left > right. // Invalid: left > right.
Rect roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1}; RectF roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1};
ImageProcessingOptions image_processing_options{roi, ImageProcessingOptions image_processing_options{roi,
/*rotation_degrees=*/0}; /*rotation_degrees=*/0};
auto results = image_classifier->Classify(image, image_processing_options); auto results = image_classifier->Classify(image, image_processing_options);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(), EXPECT_THAT(results.status().message(),
HasSubstr("Expected Rect with left < right and top < bottom")); HasSubstr("Expected RectF with left < right and top < bottom"));
EXPECT_THAT( EXPECT_THAT(
results.status().GetPayload(kMediaPipeTasksPayload), results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat( Optional(absl::Cord(absl::StrCat(
@ -573,7 +574,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) {
results = image_classifier->Classify(image, image_processing_options); results = image_classifier->Classify(image, image_processing_options);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(), EXPECT_THAT(results.status().message(),
HasSubstr("Expected Rect with left < right and top < bottom")); HasSubstr("Expected RectF with left < right and top < bottom"));
EXPECT_THAT( EXPECT_THAT(
results.status().GetPayload(kMediaPipeTasksPayload), results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat( Optional(absl::Cord(absl::StrCat(
@ -586,7 +587,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) {
results = image_classifier->Classify(image, image_processing_options); results = image_classifier->Classify(image, image_processing_options);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(), EXPECT_THAT(results.status().message(),
HasSubstr("Expected Rect values to be in [0,1]")); HasSubstr("Expected RectF values to be in [0,1]"));
EXPECT_THAT( EXPECT_THAT(
results.status().GetPayload(kMediaPipeTasksPayload), results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat( Optional(absl::Cord(absl::StrCat(
@ -695,7 +696,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) {
ImageClassifier::Create(std::move(options))); ImageClassifier::Create(std::move(options)));
// Crop around the soccer ball. // Crop around the soccer ball.
// Region-of-interest around the soccer ball. // Region-of-interest around the soccer ball.
Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
for (int i = 0; i < iterations; ++i) { for (int i = 0; i < iterations; ++i) {
@ -837,7 +838,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options))); ImageClassifier::Create(std::move(options)));
// Crop around the soccer ball. // Crop around the soccer ball.
Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
for (int i = 0; i < iterations; ++i) { for (int i = 0; i < iterations; ++i) {

View File

@ -41,7 +41,7 @@ namespace image_embedder {
namespace { namespace {
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect; using ::mediapipe::tasks::components::containers::RectF;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::Optional; using ::testing::Optional;
@ -320,7 +320,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
Image crop, DecodeImageFromFile( Image crop, DecodeImageFromFile(
JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
// Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg". // Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1}; RectF roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1};
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
// Extract both embeddings. // Extract both embeddings.
@ -388,7 +388,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
"burger_rotated.jpg"))); "burger_rotated.jpg")));
// Region-of-interest corresponding to burger_crop.jpg. // Region-of-interest corresponding to burger_crop.jpg.
Rect roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333}; RectF roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333};
ImageProcessingOptions image_processing_options{roi, ImageProcessingOptions image_processing_options{roi,
/*rotation_degrees=*/-90}; /*rotation_degrees=*/-90};

View File

@ -47,7 +47,7 @@ namespace {
using ::mediapipe::Image; using ::mediapipe::Image;
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect; using ::mediapipe::tasks::components::containers::RectF;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::Optional; using ::testing::Optional;
@ -299,7 +299,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
auto results = segmenter->Segment(image, image_processing_options); auto results = segmenter->Segment(image, image_processing_options);

View File

@ -33,6 +33,7 @@ cc_library(
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
"//mediapipe/tasks/cc/components/containers:detection_result",
"//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto",

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
@ -56,6 +57,7 @@ constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.vision.ObjectDetectorGraph"; "mediapipe.tasks.vision.ObjectDetectorGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000; constexpr int kMicroSecondsPerMilliSecond = 1000;
using ::mediapipe::tasks::components::containers::ConvertToDetectionResult;
using ObjectDetectorOptionsProto = using ObjectDetectorOptionsProto =
object_detector::proto::ObjectDetectorOptions; object_detector::proto::ObjectDetectorOptions;
@ -129,7 +131,8 @@ absl::StatusOr<std::unique_ptr<ObjectDetector>> ObjectDetector::Create(
Packet detections_packet = Packet detections_packet =
status_or_packets.value()[kDetectionsOutStreamName]; status_or_packets.value()[kDetectionsOutStreamName];
Packet image_packet = status_or_packets.value()[kImageOutStreamName]; Packet image_packet = status_or_packets.value()[kImageOutStreamName];
result_callback(detections_packet.Get<std::vector<Detection>>(), result_callback(ConvertToDetectionResult(
detections_packet.Get<std::vector<Detection>>()),
image_packet.Get<Image>(), image_packet.Get<Image>(),
detections_packet.Timestamp().Value() / detections_packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond); kMicroSecondsPerMilliSecond);
@ -144,7 +147,7 @@ absl::StatusOr<std::unique_ptr<ObjectDetector>> ObjectDetector::Create(
std::move(packets_callback)); std::move(packets_callback));
} }
absl::StatusOr<std::vector<Detection>> ObjectDetector::Detect( absl::StatusOr<ObjectDetectorResult> ObjectDetector::Detect(
mediapipe::Image image, mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
@ -161,10 +164,11 @@ absl::StatusOr<std::vector<Detection>> ObjectDetector::Detect(
ProcessImageData( ProcessImageData(
{{kImageInStreamName, MakePacket<Image>(std::move(image))}, {{kImageInStreamName, MakePacket<Image>(std::move(image))},
{kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}})); {kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}}));
return output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>(); return ConvertToDetectionResult(
output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>());
} }
absl::StatusOr<std::vector<Detection>> ObjectDetector::DetectForVideo( absl::StatusOr<ObjectDetectorResult> ObjectDetector::DetectForVideo(
mediapipe::Image image, int64 timestamp_ms, mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
@ -185,7 +189,8 @@ absl::StatusOr<std::vector<Detection>> ObjectDetector::DetectForVideo(
{kNormRectName, {kNormRectName,
MakePacket<NormalizedRect>(std::move(norm_rect)) MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
return output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>(); return ConvertToDetectionResult(
output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>());
} }
absl::Status ObjectDetector::DetectAsync( absl::Status ObjectDetector::DetectAsync(

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
@ -36,6 +37,10 @@ namespace mediapipe {
namespace tasks { namespace tasks {
namespace vision { namespace vision {
// Alias the shared DetectionResult struct as result type.
using ObjectDetectorResult =
::mediapipe::tasks::components::containers::DetectionResult;
// The options for configuring a mediapipe object detector task. // The options for configuring a mediapipe object detector task.
struct ObjectDetectorOptions { struct ObjectDetectorOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite // Base options for configuring MediaPipe Tasks, such as specifying the TfLite
@ -79,8 +84,7 @@ struct ObjectDetectorOptions {
// The user-defined result callback for processing live stream data. // The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set // The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM. // to RunningMode::LIVE_STREAM.
std::function<void(absl::StatusOr<std::vector<mediapipe::Detection>>, std::function<void(absl::StatusOr<ObjectDetectorResult>, const Image&, int64)>
const Image&, int64)>
result_callback = nullptr; result_callback = nullptr;
}; };
@ -165,7 +169,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
// underlying image data. // underlying image data.
// TODO: Describes the output bounding boxes for gpu input // TODO: Describes the output bounding boxes for gpu input
// images after enabling the gpu support in MediaPipe Tasks. // images after enabling the gpu support in MediaPipe Tasks.
absl::StatusOr<std::vector<mediapipe::Detection>> Detect( absl::StatusOr<ObjectDetectorResult> Detect(
mediapipe::Image image, mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);
@ -188,7 +192,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
// unrotated input frame of reference coordinates system, i.e. in `[0, // unrotated input frame of reference coordinates system, i.e. in `[0,
// image_width) x [0, image_height)`, which are the dimensions of the // image_width) x [0, image_height)`, which are the dimensions of the
// underlying image data. // underlying image data.
absl::StatusOr<std::vector<mediapipe::Detection>> DetectForVideo( absl::StatusOr<ObjectDetectorResult> DetectForVideo(
mediapipe::Image image, int64 timestamp_ms, mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
#include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
@ -65,10 +66,14 @@ namespace vision {
namespace { namespace {
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect; using ::mediapipe::tasks::components::containers::ConvertToDetectionResult;
using ::mediapipe::tasks::components::containers::Detection;
using ::mediapipe::tasks::components::containers::DetectionResult;
using ::mediapipe::tasks::components::containers::RectF;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::Optional; using ::testing::Optional;
using DetectionProto = mediapipe::Detection;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kMobileSsdWithMetadata[] = constexpr char kMobileSsdWithMetadata[] =
@ -83,47 +88,45 @@ constexpr char kEfficientDetWithMetadata[] =
// Checks that the two provided `Detection` proto vectors are equal, with a // Checks that the two provided `Detection` proto vectors are equal, with a
// tolerance on floating-point scores to account for numerical instabilities. // tolerance on floating-point scores to account for numerical instabilities.
// If the proto definition changes, please also change this function. // If the proto definition changes, please also change this function.
void ExpectApproximatelyEqual(const std::vector<Detection>& actual, void ExpectApproximatelyEqual(const ObjectDetectorResult& actual,
const std::vector<Detection>& expected) { const ObjectDetectorResult& expected) {
const float kPrecision = 1e-6; const float kPrecision = 1e-6;
EXPECT_EQ(actual.size(), expected.size()); EXPECT_EQ(actual.detections.size(), expected.detections.size());
for (int i = 0; i < actual.size(); ++i) { for (int i = 0; i < actual.detections.size(); ++i) {
const Detection& a = actual[i]; const Detection& a = actual.detections[i];
const Detection& b = expected[i]; const Detection& b = expected.detections[i];
EXPECT_THAT(a.location_data().bounding_box(), EXPECT_EQ(a.bounding_box, b.bounding_box);
EqualsProto(b.location_data().bounding_box())); EXPECT_EQ(a.categories.size(), 1);
EXPECT_EQ(a.label_size(), 1); EXPECT_EQ(b.categories.size(), 1);
EXPECT_EQ(b.label_size(), 1); EXPECT_EQ(a.categories[0].category_name, b.categories[0].category_name);
EXPECT_EQ(a.label(0), b.label(0)); EXPECT_NEAR(a.categories[0].score, b.categories[0].score, kPrecision);
EXPECT_EQ(a.score_size(), 1);
EXPECT_EQ(b.score_size(), 1);
EXPECT_NEAR(a.score(0), b.score(0), kPrecision);
} }
} }
std::vector<Detection> GenerateMobileSsdNoImageResizingFullExpectedResults() { std::vector<DetectionProto>
return {ParseTextProtoOrDie<Detection>(R"pb( GenerateMobileSsdNoImageResizingFullExpectedResults() {
return {ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat" label: "cat"
score: 0.6328125 score: 0.6328125
location_data { location_data {
format: BOUNDING_BOX format: BOUNDING_BOX
bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 } bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 }
})pb"), })pb"),
ParseTextProtoOrDie<Detection>(R"pb( ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat" label: "cat"
score: 0.59765625 score: 0.59765625
location_data { location_data {
format: BOUNDING_BOX format: BOUNDING_BOX
bounding_box { xmin: 151 ymin: 78 width: 104 height: 223 } bounding_box { xmin: 151 ymin: 78 width: 104 height: 223 }
})pb"), })pb"),
ParseTextProtoOrDie<Detection>(R"pb( ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat" label: "cat"
score: 0.5 score: 0.5
location_data { location_data {
format: BOUNDING_BOX format: BOUNDING_BOX
bounding_box { xmin: 65 ymin: 199 width: 41 height: 101 } bounding_box { xmin: 65 ymin: 199 width: 41 height: 101 }
})pb"), })pb"),
ParseTextProtoOrDie<Detection>(R"pb( ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "dog" label: "dog"
score: 0.48828125 score: 0.48828125
location_data { location_data {
@ -263,8 +266,8 @@ TEST_F(CreateFromOptionsTest, FailsWithIllegalCallbackInImageOrVideoMode) {
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
options->running_mode = running_mode; options->running_mode = running_mode;
options->result_callback = options->result_callback =
[](absl::StatusOr<std::vector<Detection>> detections, [](absl::StatusOr<ObjectDetectorResult> detections, const Image& image,
const Image& image, int64 timestamp_ms) {}; int64 timestamp_ms) {};
absl::StatusOr<std::unique_ptr<ObjectDetector>> object_detector = absl::StatusOr<std::unique_ptr<ObjectDetector>> object_detector =
ObjectDetector::Create(std::move(options)); ObjectDetector::Create(std::move(options));
EXPECT_EQ(object_detector.status().code(), EXPECT_EQ(object_detector.status().code(),
@ -340,34 +343,36 @@ TEST_F(ImageModeTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
MP_ASSERT_OK(object_detector->Close()); MP_ASSERT_OK(object_detector->Close());
ExpectApproximatelyEqual( ExpectApproximatelyEqual(
results, {ParseTextProtoOrDie<Detection>(R"pb( results,
label: "cat" ConvertToDetectionResult(
score: 0.69921875 {ParseTextProtoOrDie<DetectionProto>(R"pb(
location_data { label: "cat"
format: BOUNDING_BOX score: 0.69921875
bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 } location_data {
})pb"), format: BOUNDING_BOX
ParseTextProtoOrDie<Detection>(R"pb( bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 }
label: "cat" })pb"),
score: 0.64453125 ParseTextProtoOrDie<DetectionProto>(R"pb(
location_data { label: "cat"
format: BOUNDING_BOX score: 0.64453125
bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 } location_data {
})pb"), format: BOUNDING_BOX
ParseTextProtoOrDie<Detection>(R"pb( bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 }
label: "cat" })pb"),
score: 0.51171875 ParseTextProtoOrDie<DetectionProto>(R"pb(
location_data { label: "cat"
format: BOUNDING_BOX score: 0.51171875
bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 } location_data {
})pb"), format: BOUNDING_BOX
ParseTextProtoOrDie<Detection>(R"pb( bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 }
label: "cat" })pb"),
score: 0.48828125 ParseTextProtoOrDie<DetectionProto>(R"pb(
location_data { label: "cat"
format: BOUNDING_BOX score: 0.48828125
bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 } location_data {
})pb")}); format: BOUNDING_BOX
bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 }
})pb")}));
} }
TEST_F(ImageModeTest, SucceedsEfficientDetModel) { TEST_F(ImageModeTest, SucceedsEfficientDetModel) {
@ -383,34 +388,36 @@ TEST_F(ImageModeTest, SucceedsEfficientDetModel) {
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
MP_ASSERT_OK(object_detector->Close()); MP_ASSERT_OK(object_detector->Close());
ExpectApproximatelyEqual( ExpectApproximatelyEqual(
results, {ParseTextProtoOrDie<Detection>(R"pb( results,
label: "cat" ConvertToDetectionResult(
score: 0.7578125 {ParseTextProtoOrDie<DetectionProto>(R"pb(
location_data { label: "cat"
format: BOUNDING_BOX score: 0.7578125
bounding_box { xmin: 858 ymin: 408 width: 225 height: 187 } location_data {
})pb"), format: BOUNDING_BOX
ParseTextProtoOrDie<Detection>(R"pb( bounding_box { xmin: 858 ymin: 408 width: 225 height: 187 }
label: "cat" })pb"),
score: 0.72265625 ParseTextProtoOrDie<DetectionProto>(R"pb(
location_data { label: "cat"
format: BOUNDING_BOX score: 0.72265625
bounding_box { xmin: 67 ymin: 401 width: 399 height: 192 } location_data {
})pb"), format: BOUNDING_BOX
ParseTextProtoOrDie<Detection>(R"pb( bounding_box { xmin: 67 ymin: 401 width: 399 height: 192 }
label: "cat" })pb"),
score: 0.6289063 ParseTextProtoOrDie<DetectionProto>(R"pb(
location_data { label: "cat"
format: BOUNDING_BOX score: 0.6289063
bounding_box { xmin: 368 ymin: 210 width: 272 height: 385 } location_data {
})pb"), format: BOUNDING_BOX
ParseTextProtoOrDie<Detection>(R"pb( bounding_box { xmin: 368 ymin: 210 width: 272 height: 385 }
label: "cat" })pb"),
score: 0.5859375 ParseTextProtoOrDie<DetectionProto>(R"pb(
location_data { label: "cat"
format: BOUNDING_BOX score: 0.5859375
bounding_box { xmin: 601 ymin: 166 width: 298 height: 437 } location_data {
})pb")}); format: BOUNDING_BOX
bounding_box { xmin: 601 ymin: 166 width: 298 height: 437 }
})pb")}));
} }
TEST_F(ImageModeTest, SucceedsWithoutImageResizing) { TEST_F(ImageModeTest, SucceedsWithoutImageResizing) {
@ -426,7 +433,8 @@ TEST_F(ImageModeTest, SucceedsWithoutImageResizing) {
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
MP_ASSERT_OK(object_detector->Close()); MP_ASSERT_OK(object_detector->Close());
ExpectApproximatelyEqual( ExpectApproximatelyEqual(
results, GenerateMobileSsdNoImageResizingFullExpectedResults()); results, ConvertToDetectionResult(
GenerateMobileSsdNoImageResizingFullExpectedResults()));
} }
TEST_F(ImageModeTest, SucceedsWithScoreCalibration) { TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
@ -442,13 +450,14 @@ TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
MP_ASSERT_OK(object_detector->Close()); MP_ASSERT_OK(object_detector->Close());
ExpectApproximatelyEqual( ExpectApproximatelyEqual(
results, {ParseTextProtoOrDie<Detection>(R"pb( results,
ConvertToDetectionResult({ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat" label: "cat"
score: 0.6531269142 score: 0.6531269142
location_data { location_data {
format: BOUNDING_BOX format: BOUNDING_BOX
bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 } bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 }
})pb")}); })pb")}));
} }
TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) { TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
@ -463,11 +472,13 @@ TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
ObjectDetector::Create(std::move(options))); ObjectDetector::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
MP_ASSERT_OK(object_detector->Close()); MP_ASSERT_OK(object_detector->Close());
std::vector<Detection> full_expected_results = std::vector<DetectionProto> full_expected_results =
GenerateMobileSsdNoImageResizingFullExpectedResults(); GenerateMobileSsdNoImageResizingFullExpectedResults();
ExpectApproximatelyEqual(results,
{full_expected_results[0], full_expected_results[1], ExpectApproximatelyEqual(
full_expected_results[2]}); results, ConvertToDetectionResult({full_expected_results[0],
full_expected_results[1],
full_expected_results[2]}));
} }
TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) { TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
@ -482,10 +493,11 @@ TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
ObjectDetector::Create(std::move(options))); ObjectDetector::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
MP_ASSERT_OK(object_detector->Close()); MP_ASSERT_OK(object_detector->Close());
std::vector<Detection> full_expected_results = std::vector<DetectionProto> full_expected_results =
GenerateMobileSsdNoImageResizingFullExpectedResults(); GenerateMobileSsdNoImageResizingFullExpectedResults();
ExpectApproximatelyEqual( ExpectApproximatelyEqual(
results, {full_expected_results[0], full_expected_results[1]}); results, ConvertToDetectionResult(
{full_expected_results[0], full_expected_results[1]}));
} }
TEST_F(ImageModeTest, SucceedsWithAllowlistOption) { TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {
@ -501,9 +513,10 @@ TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {
ObjectDetector::Create(std::move(options))); ObjectDetector::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
MP_ASSERT_OK(object_detector->Close()); MP_ASSERT_OK(object_detector->Close());
std::vector<Detection> full_expected_results = std::vector<DetectionProto> full_expected_results =
GenerateMobileSsdNoImageResizingFullExpectedResults(); GenerateMobileSsdNoImageResizingFullExpectedResults();
ExpectApproximatelyEqual(results, {full_expected_results[3]}); ExpectApproximatelyEqual(
results, ConvertToDetectionResult({full_expected_results[3]}));
} }
TEST_F(ImageModeTest, SucceedsWithDenylistOption) { TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
@ -519,9 +532,10 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
ObjectDetector::Create(std::move(options))); ObjectDetector::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
MP_ASSERT_OK(object_detector->Close()); MP_ASSERT_OK(object_detector->Close());
std::vector<Detection> full_expected_results = std::vector<DetectionProto> full_expected_results =
GenerateMobileSsdNoImageResizingFullExpectedResults(); GenerateMobileSsdNoImageResizingFullExpectedResults();
ExpectApproximatelyEqual(results, {full_expected_results[3]}); ExpectApproximatelyEqual(
results, ConvertToDetectionResult({full_expected_results[3]}));
} }
TEST_F(ImageModeTest, SucceedsWithRotation) { TEST_F(ImageModeTest, SucceedsWithRotation) {
@ -541,13 +555,14 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
auto results, object_detector->Detect(image, image_processing_options)); auto results, object_detector->Detect(image, image_processing_options));
MP_ASSERT_OK(object_detector->Close()); MP_ASSERT_OK(object_detector->Close());
ExpectApproximatelyEqual( ExpectApproximatelyEqual(
results, {ParseTextProtoOrDie<Detection>(R"pb( results,
ConvertToDetectionResult({ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat" label: "cat"
score: 0.7109375 score: 0.7109375
location_data { location_data {
format: BOUNDING_BOX format: BOUNDING_BOX
bounding_box { xmin: 0 ymin: 622 width: 436 height: 276 } bounding_box { xmin: 0 ymin: 622 width: 436 height: 276 }
})pb")}); })pb")}));
} }
TEST_F(ImageModeTest, FailsWithRegionOfInterest) { TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
@ -560,7 +575,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
ObjectDetector::Create(std::move(options))); ObjectDetector::Create(std::move(options)));
Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
auto results = object_detector->Detect(image, image_processing_options); auto results = object_detector->Detect(image, image_processing_options);
@ -619,10 +634,11 @@ TEST_F(VideoModeTest, Succeeds) {
for (int i = 0; i < iterations; ++i) { for (int i = 0; i < iterations; ++i) {
MP_ASSERT_OK_AND_ASSIGN(auto results, MP_ASSERT_OK_AND_ASSIGN(auto results,
object_detector->DetectForVideo(image, i)); object_detector->DetectForVideo(image, i));
std::vector<Detection> full_expected_results = std::vector<DetectionProto> full_expected_results =
GenerateMobileSsdNoImageResizingFullExpectedResults(); GenerateMobileSsdNoImageResizingFullExpectedResults();
ExpectApproximatelyEqual( ExpectApproximatelyEqual(
results, {full_expected_results[0], full_expected_results[1]}); results, ConvertToDetectionResult(
{full_expected_results[0], full_expected_results[1]}));
} }
MP_ASSERT_OK(object_detector->Close()); MP_ASSERT_OK(object_detector->Close());
} }
@ -637,9 +653,8 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
options->running_mode = core::RunningMode::LIVE_STREAM; options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback = options->result_callback = [](absl::StatusOr<ObjectDetectorResult> detections,
[](absl::StatusOr<std::vector<Detection>> detections, const Image& image, const Image& image, int64 timestamp_ms) {};
int64 timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
ObjectDetector::Create(std::move(options))); ObjectDetector::Create(std::move(options)));
@ -669,9 +684,8 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
options->running_mode = core::RunningMode::LIVE_STREAM; options->running_mode = core::RunningMode::LIVE_STREAM;
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
options->result_callback = options->result_callback = [](absl::StatusOr<ObjectDetectorResult> detections,
[](absl::StatusOr<std::vector<Detection>> detections, const Image& image, const Image& image, int64 timestamp_ms) {};
int64 timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
ObjectDetector::Create(std::move(options))); ObjectDetector::Create(std::move(options)));
MP_ASSERT_OK(object_detector->DetectAsync(image, 1)); MP_ASSERT_OK(object_detector->DetectAsync(image, 1));
@ -695,14 +709,14 @@ TEST_F(LiveStreamModeTest, Succeeds) {
auto options = std::make_unique<ObjectDetectorOptions>(); auto options = std::make_unique<ObjectDetectorOptions>();
options->max_results = 2; options->max_results = 2;
options->running_mode = core::RunningMode::LIVE_STREAM; options->running_mode = core::RunningMode::LIVE_STREAM;
std::vector<std::vector<Detection>> detection_results; std::vector<ObjectDetectorResult> detection_results;
std::vector<std::pair<int, int>> image_sizes; std::vector<std::pair<int, int>> image_sizes;
std::vector<int64> timestamps; std::vector<int64> timestamps;
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
options->result_callback = options->result_callback =
[&detection_results, &image_sizes, &timestamps]( [&detection_results, &image_sizes, &timestamps](
absl::StatusOr<std::vector<Detection>> detections, const Image& image, absl::StatusOr<ObjectDetectorResult> detections, const Image& image,
int64 timestamp_ms) { int64 timestamp_ms) {
MP_ASSERT_OK(detections.status()); MP_ASSERT_OK(detections.status());
detection_results.push_back(std::move(detections).value()); detection_results.push_back(std::move(detections).value());
@ -719,11 +733,12 @@ TEST_F(LiveStreamModeTest, Succeeds) {
// number of iterations. // number of iterations.
ASSERT_LE(detection_results.size(), iterations); ASSERT_LE(detection_results.size(), iterations);
ASSERT_GT(detection_results.size(), 0); ASSERT_GT(detection_results.size(), 0);
std::vector<Detection> full_expected_results = std::vector<DetectionProto> full_expected_results =
GenerateMobileSsdNoImageResizingFullExpectedResults(); GenerateMobileSsdNoImageResizingFullExpectedResults();
for (const auto& detection_result : detection_results) { for (const auto& detection_result : detection_results) {
ExpectApproximatelyEqual( ExpectApproximatelyEqual(
detection_result, {full_expected_results[0], full_expected_results[1]}); detection_result, ConvertToDetectionResult({full_expected_results[0],
full_expected_results[1]}));
} }
for (const auto& image_size : image_sizes) { for (const auto& image_size : image_sizes) {
EXPECT_EQ(image_size.first, image.width()); EXPECT_EQ(image_size.first, image.width());

View File

@ -22,13 +22,13 @@ limitations under the License.
namespace mediapipe::tasks::vision::utils { namespace mediapipe::tasks::vision::utils {
using ::mediapipe::tasks::components::containers::Rect; using ::mediapipe::tasks::components::containers::RectF;
float CalculateArea(const Rect& rect) { float CalculateArea(const RectF& rect) {
return (rect.right - rect.left) * (rect.bottom - rect.top); return (rect.right - rect.left) * (rect.bottom - rect.top);
} }
float CalculateIntersectionArea(const Rect& a, const Rect& b) { float CalculateIntersectionArea(const RectF& a, const RectF& b) {
const float intersection_left = std::max<float>(a.left, b.left); const float intersection_left = std::max<float>(a.left, b.left);
const float intersection_top = std::max<float>(a.top, b.top); const float intersection_top = std::max<float>(a.top, b.top);
const float intersection_right = std::min<float>(a.right, b.right); const float intersection_right = std::min<float>(a.right, b.right);
@ -38,7 +38,7 @@ float CalculateIntersectionArea(const Rect& a, const Rect& b) {
std::max<float>(intersection_right - intersection_left, 0.0); std::max<float>(intersection_right - intersection_left, 0.0);
} }
float CalculateIOU(const Rect& a, const Rect& b) { float CalculateIOU(const RectF& a, const RectF& b) {
const float area_a = CalculateArea(a); const float area_a = CalculateArea(a);
const float area_b = CalculateArea(b); const float area_b = CalculateArea(b);
if (area_a <= 0 || area_b <= 0) return 0.0; if (area_a <= 0 || area_b <= 0) return 0.0;

View File

@ -27,15 +27,15 @@ limitations under the License.
namespace mediapipe::tasks::vision::utils { namespace mediapipe::tasks::vision::utils {
// Calculates intersection over union for two bounds. // Calculates intersection over union for two bounds.
float CalculateIOU(const components::containers::Rect& a, float CalculateIOU(const components::containers::RectF& a,
const components::containers::Rect& b); const components::containers::RectF& b);
// Calculates area for face bound // Calculates area for face bound
float CalculateArea(const components::containers::Rect& rect); float CalculateArea(const components::containers::RectF& rect);
// Calculates intersection area of two face bounds // Calculates intersection area of two face bounds
float CalculateIntersectionArea(const components::containers::Rect& a, float CalculateIntersectionArea(const components::containers::RectF& a,
const components::containers::Rect& b); const components::containers::RectF& b);
} // namespace mediapipe::tasks::vision::utils } // namespace mediapipe::tasks::vision::utils
#endif // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_ #endif // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_