internal change

PiperOrigin-RevId: 493742399
This commit is contained in:
MediaPipe Team 2022-12-07 16:37:08 -08:00 committed by Copybara-Service
parent a59f0a9924
commit a0efcb47f2
18 changed files with 377 additions and 151 deletions

View File

@ -18,6 +18,7 @@ licenses(["notice"])
# Rect (integer) and RectF (floating-point) rectangle containers, together with
# the ToRect/ToRectF conversion helpers implemented in rect.cc.
cc_library(
name = "rect",
srcs = ["rect.cc"],
hdrs = ["rect.h"],
)
@ -41,6 +42,18 @@ cc_library(
],
)
# Detection/DetectionResult container structs and the helpers that build them
# from mediapipe::Detection protos. The proto deps are needed because the
# sources include detection.pb.h and location_data.pb.h.
cc_library(
name = "detection_result",
srcs = ["detection_result.cc"],
hdrs = ["detection_result.h"],
deps = [
":category",
":rect",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:location_data_cc_proto",
],
)
cc_library(
name = "embedding_result",
srcs = ["embedding_result.cc"],

View File

@ -0,0 +1,73 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
#include <strings.h>
#include <optional>
#include <string>
#include <vector>
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
namespace mediapipe::tasks::components::containers {
// Category index used when the Detection proto carries no `label_id` entry for
// a given score.
constexpr int kDefaultCategoryIndex = -1;
// Converts a single Detection proto into the Detection container struct.
//
// One Category is produced per `score` entry of the proto. The matching
// `label_id`, `label` and `display_name` entries are used when present;
// otherwise the index falls back to kDefaultCategoryIndex and the names to the
// empty string. When the proto has a `location_data.bounding_box`, it is
// converted from (xmin, ymin, width, height) to left/top/right/bottom edges;
// otherwise the returned bounding box is all zeros.
//
// This provides the definition for the `ConvertToDetection` function declared
// in detection_result.h.
Detection ConvertToDetection(const mediapipe::Detection& detection_proto) {
  // Value-initialize so that `bounding_box` never holds indeterminate values
  // when the proto carries no bounding box.
  Detection detection{};
  detection.categories.reserve(detection_proto.score_size());
  for (int idx = 0; idx < detection_proto.score_size(); ++idx) {
    detection.categories.push_back(
        {/* index= */ detection_proto.label_id_size() > idx
             ? detection_proto.label_id(idx)
             : kDefaultCategoryIndex,
         /* score= */ detection_proto.score(idx),
         /* category_name= */ detection_proto.label_size() > idx
             ? detection_proto.label(idx)
             : "",
         /* display_name= */ detection_proto.display_name_size() > idx
             ? detection_proto.display_name(idx)
             : ""});
  }
  if (detection_proto.location_data().has_bounding_box()) {
    // Bind by reference: there is no need to copy the proto message.
    const mediapipe::LocationData::BoundingBox& bounding_box_proto =
        detection_proto.location_data().bounding_box();
    detection.bounding_box.left = bounding_box_proto.xmin();
    detection.bounding_box.top = bounding_box_proto.ymin();
    detection.bounding_box.right =
        bounding_box_proto.xmin() + bounding_box_proto.width();
    detection.bounding_box.bottom =
        bounding_box_proto.ymin() + bounding_box_proto.height();
  }
  return detection;
}

// Single-proto overload kept under its original name so that existing callers
// keep compiling; it simply forwards to ConvertToDetection().
Detection ConvertToDetectionResult(
    const mediapipe::Detection& detection_proto) {
  return ConvertToDetection(detection_proto);
}
// Builds a DetectionResult from a list of Detection protos: every proto is
// converted, in order, into one entry of `DetectionResult::detections`.
DetectionResult ConvertToDetectionResult(
    std::vector<mediapipe::Detection> detections_proto) {
  DetectionResult result;
  auto& converted = result.detections;
  // Size is known up front, so allocate once.
  converted.reserve(detections_proto.size());
  for (auto it = detections_proto.cbegin(); it != detections_proto.cend();
       ++it) {
    converted.emplace_back(ConvertToDetectionResult(*it));
  }
  return result;
}
} // namespace mediapipe::tasks::components::containers

View File

@ -0,0 +1,52 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
#include <optional>
#include <string>
#include <vector>
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
namespace mediapipe::tasks::components::containers {
// Detection for a single bounding box.
struct Detection {
  // A vector of detected categories. Empty by default.
  std::vector<Category> categories;
  // The bounding box location. Value-initialized (all edges zero) so that a
  // default-constructed Detection never exposes indeterminate coordinates.
  Rect bounding_box = {};
};
// Detection results of a model.
struct DetectionResult {
// A vector of Detections. When built by ConvertToDetectionResult(), entries
// appear in the same order as the Detection protos they were converted from.
std::vector<Detection> detections;
};
// Utility function to convert from Detection proto to Detection struct.
//
// `detection_proto` is only read; the converted Detection is returned by
// value.
Detection ConvertToDetection(const mediapipe::Detection& detection_proto);
// Utility function to convert from list of Detection proto to DetectionResult
// struct.
//
// The list is taken by value. Each proto yields one entry of
// `DetectionResult::detections`, in the same order as the input list.
DetectionResult ConvertToDetectionResult(
std::vector<mediapipe::Detection> detections_proto);
} // namespace mediapipe::tasks::components::containers
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_

View File

@ -0,0 +1,34 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/containers/rect.h"
namespace mediapipe::tasks::components::containers {
// Converts an integer Rect into a RectF: horizontal edges (left, right) are
// divided by `image_width` and vertical edges (top, bottom) by `image_height`.
// Note the argument order: height comes before width.
RectF ToRectF(const Rect& rect, int image_height, int image_width) {
  const float width = static_cast<float>(image_width);
  const float height = static_cast<float>(image_height);

  RectF normalized;
  normalized.left = static_cast<float>(rect.left) / width;
  normalized.top = static_cast<float>(rect.top) / height;
  normalized.right = static_cast<float>(rect.right) / width;
  normalized.bottom = static_cast<float>(rect.bottom) / height;
  return normalized;
}
// Converts a RectF into an integer Rect: horizontal edges (left, right) are
// multiplied by `image_width` and vertical edges (top, bottom) by
// `image_height`; each product is then cast to int, which discards the
// fractional part. Note the argument order: height comes before width.
Rect ToRect(const RectF& rect, int image_height, int image_width) {
  const float width = static_cast<float>(image_width);
  const float height = static_cast<float>(image_height);

  Rect scaled;
  scaled.left = static_cast<int>(rect.left * width);
  scaled.top = static_cast<int>(rect.top * height);
  scaled.right = static_cast<int>(rect.right * width);
  scaled.bottom = static_cast<int>(rect.bottom * height);
  return scaled;
}
} // namespace mediapipe::tasks::components::containers

View File

@ -16,20 +16,47 @@ limitations under the License.
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
#include <cstdlib>
namespace mediapipe::tasks::components::containers {
// Per-coordinate tolerance used by the RectF equality operator below: two
// coordinates are treated as equal when they differ by less than this value.
constexpr float kRectFTolerance = 1e-4;
// Defines a rectangle with integer coordinates, used e.g. as part of detection
// results.
//
// ToRectF() / ToRect() convert between this struct and the floating-point
// RectF by dividing / multiplying by the image dimensions.
//
// The members have no default initializers, so a default-initialized Rect
// (`Rect r;`) holds indeterminate values; use `Rect{}` to get all zeros.
struct Rect {
int left;
int top;
int right;
int bottom;
};
// Two Rects are equal when all four edges match exactly.
inline bool operator==(const Rect& lhs, const Rect& rhs) {
  if (lhs.left != rhs.left || lhs.top != rhs.top) {
    return false;
  }
  return lhs.right == rhs.right && lhs.bottom == rhs.bottom;
}
// The coordinates are normalized wrt the image dimensions, i.e. generally in
// [0,1] but they may exceed these bounds if describing a region overlapping the
// image. The origin is on the top-left corner of the image.
struct Rect {
struct RectF {
// Horizontal position of the left edge, relative to the image width.
float left;
// Vertical position of the top edge, relative to the image height.
float top;
// Horizontal position of the right edge, relative to the image width.
float right;
// Vertical position of the bottom edge, relative to the image height.
float bottom;
};
// Approximate equality for RectF: every coordinate pair must differ by
// strictly less than kRectFTolerance.
//
// The check is spelled out with plain comparisons rather than an unqualified
// `abs(...)`: with only <cstdlib> included, unqualified `abs` can resolve to
// the C function `int abs(int)`, which truncates the float difference to an
// integer and makes any two rectangles whose coordinates differ by less than 1
// compare equal.
inline bool operator==(const RectF& lhs, const RectF& rhs) {
  // True when |a - b| < kRectFTolerance (false if either value is NaN).
  const auto is_close = [](float a, float b) {
    const float diff = a - b;
    return diff < kRectFTolerance && -diff < kRectFTolerance;
  };
  return is_close(lhs.left, rhs.left) && is_close(lhs.top, rhs.top) &&
         is_close(lhs.right, rhs.right) && is_close(lhs.bottom, rhs.bottom);
}
// Converts `rect` to a RectF by dividing left/right by `image_width` and
// top/bottom by `image_height`. Note that `image_height` precedes
// `image_width` in the argument list.
RectF ToRectF(const Rect& rect, int image_height, int image_width);
// Converts `rect` to a Rect by multiplying left/right by `image_width` and
// top/bottom by `image_height`, then casting each product to int (the
// fractional part is discarded). Note that `image_height` precedes
// `image_width` in the argument list.
Rect ToRect(const RectF& rect, int image_height, int image_width);
} // namespace mediapipe::tasks::components::containers
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_

View File

@ -129,13 +129,13 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi {
if (roi.left >= roi.right || roi.top >= roi.bottom) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"Expected Rect with left < right and top < bottom.",
"Expected RectF with left < right and top < bottom.",
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
}
if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"Expected Rect values to be in [0,1].",
"Expected RectF values to be in [0,1].",
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
}
normalized_rect.set_x_center((roi.left + roi.right) / 2.0);

View File

@ -35,7 +35,8 @@ struct ImageProcessingOptions {
// the full image is used.
//
// Coordinates must be in [0,1] with 'left' < 'right' and 'top' < 'bottom'.
std::optional<components::containers::Rect> region_of_interest = std::nullopt;
std::optional<components::containers::RectF> region_of_interest =
std::nullopt;
// The rotation to apply to the image (or cropped region-of-interest), in
// degrees clockwise.

View File

@ -44,7 +44,7 @@ namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::components::containers::RectF;
using ::mediapipe::tasks::vision::utils::CalculateIOU;
using ::mediapipe::tasks::vision::utils::DuplicatesFinder;
@ -126,7 +126,7 @@ absl::StatusOr<float> HandBaselineDistance(
return distance;
}
Rect CalculateBound(const NormalizedLandmarkList& list) {
RectF CalculateBound(const NormalizedLandmarkList& list) {
constexpr float kMinInitialValue = std::numeric_limits<float>::max();
constexpr float kMaxInitialValue = std::numeric_limits<float>::lowest();
@ -144,10 +144,10 @@ Rect CalculateBound(const NormalizedLandmarkList& list) {
}
// Populate normalized non rotated face bounding box
return Rect{/*left=*/bounding_box_left,
/*top=*/bounding_box_top,
/*right=*/bounding_box_right,
/*bottom=*/bounding_box_bottom};
return RectF{/*left=*/bounding_box_left,
/*top=*/bounding_box_top,
/*right=*/bounding_box_right,
/*bottom=*/bounding_box_bottom};
}
// Uses IoU and distance of some corresponding hand landmarks to detect
@ -172,7 +172,7 @@ class HandDuplicatesFinder : public DuplicatesFinder {
const int num = multi_landmarks.size();
std::vector<float> baseline_distances;
baseline_distances.reserve(num);
std::vector<Rect> bounds;
std::vector<RectF> bounds;
bounds.reserve(num);
for (const NormalizedLandmarkList& list : multi_landmarks) {
ASSIGN_OR_RETURN(const float baseline_distance,

View File

@ -50,7 +50,7 @@ namespace {
using ::file::Defaults;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::components::containers::RectF;
using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::EqualsProto;
@ -188,7 +188,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
options->running_mode = core::RunningMode::IMAGE;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
HandLandmarker::Create(std::move(options)));
Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
auto results = hand_landmarker->Detect(image, image_processing_options);

View File

@ -52,7 +52,7 @@ namespace {
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Category;
using ::mediapipe::tasks::components::containers::Classifications;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::components::containers::RectF;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr;
using ::testing::Optional;
@ -472,7 +472,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options)));
// Region-of-interest around the soccer ball.
Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
@ -526,7 +526,8 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options)));
// Region-of-interest around the chair, with 90° anti-clockwise rotation.
Rect roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702, /*bottom=*/0.3049};
RectF roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702,
/*bottom=*/0.3049};
ImageProcessingOptions image_processing_options{roi,
/*rotation_degrees=*/-90};
@ -554,13 +555,13 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) {
ImageClassifier::Create(std::move(options)));
// Invalid: left > right.
Rect roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1};
RectF roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1};
ImageProcessingOptions image_processing_options{roi,
/*rotation_degrees=*/0};
auto results = image_classifier->Classify(image, image_processing_options);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(),
HasSubstr("Expected Rect with left < right and top < bottom"));
HasSubstr("Expected RectF with left < right and top < bottom"));
EXPECT_THAT(
results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
@ -573,7 +574,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) {
results = image_classifier->Classify(image, image_processing_options);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(),
HasSubstr("Expected Rect with left < right and top < bottom"));
HasSubstr("Expected RectF with left < right and top < bottom"));
EXPECT_THAT(
results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
@ -586,7 +587,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) {
results = image_classifier->Classify(image, image_processing_options);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(),
HasSubstr("Expected Rect values to be in [0,1]"));
HasSubstr("Expected RectF values to be in [0,1]"));
EXPECT_THAT(
results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
@ -695,7 +696,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) {
ImageClassifier::Create(std::move(options)));
// Crop around the soccer ball.
// Region-of-interest around the soccer ball.
Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
for (int i = 0; i < iterations; ++i) {
@ -837,7 +838,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options)));
// Crop around the soccer ball.
Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
for (int i = 0; i < iterations; ++i) {

View File

@ -41,7 +41,7 @@ namespace image_embedder {
namespace {
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::components::containers::RectF;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr;
using ::testing::Optional;
@ -320,7 +320,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
Image crop, DecodeImageFromFile(
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
// Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1};
RectF roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1};
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
// Extract both embeddings.
@ -388,7 +388,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
"burger_rotated.jpg")));
// Region-of-interest corresponding to burger_crop.jpg.
Rect roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333};
RectF roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333};
ImageProcessingOptions image_processing_options{roi,
/*rotation_degrees=*/-90};

View File

@ -47,7 +47,7 @@ namespace {
using ::mediapipe::Image;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::components::containers::RectF;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr;
using ::testing::Optional;
@ -299,7 +299,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
auto results = segmenter->Segment(image, image_processing_options);

View File

@ -33,6 +33,7 @@ cc_library(
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
"//mediapipe/tasks/cc/components/containers:detection_result",
"//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
@ -56,6 +57,7 @@ constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.vision.ObjectDetectorGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000;
using ::mediapipe::tasks::components::containers::ConvertToDetectionResult;
using ObjectDetectorOptionsProto =
object_detector::proto::ObjectDetectorOptions;
@ -129,7 +131,8 @@ absl::StatusOr<std::unique_ptr<ObjectDetector>> ObjectDetector::Create(
Packet detections_packet =
status_or_packets.value()[kDetectionsOutStreamName];
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
result_callback(detections_packet.Get<std::vector<Detection>>(),
result_callback(ConvertToDetectionResult(
detections_packet.Get<std::vector<Detection>>()),
image_packet.Get<Image>(),
detections_packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond);
@ -144,7 +147,7 @@ absl::StatusOr<std::unique_ptr<ObjectDetector>> ObjectDetector::Create(
std::move(packets_callback));
}
absl::StatusOr<std::vector<Detection>> ObjectDetector::Detect(
absl::StatusOr<ObjectDetectorResult> ObjectDetector::Detect(
mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
@ -161,10 +164,11 @@ absl::StatusOr<std::vector<Detection>> ObjectDetector::Detect(
ProcessImageData(
{{kImageInStreamName, MakePacket<Image>(std::move(image))},
{kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}}));
return output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>();
return ConvertToDetectionResult(
output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>());
}
absl::StatusOr<std::vector<Detection>> ObjectDetector::DetectForVideo(
absl::StatusOr<ObjectDetectorResult> ObjectDetector::DetectForVideo(
mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
@ -185,7 +189,8 @@ absl::StatusOr<std::vector<Detection>> ObjectDetector::DetectForVideo(
{kNormRectName,
MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
return output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>();
return ConvertToDetectionResult(
output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>());
}
absl::Status ObjectDetector::DetectAsync(

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
@ -36,6 +37,10 @@ namespace mediapipe {
namespace tasks {
namespace vision {
// Alias the shared DetectionResult struct as the result type.
using ObjectDetectorResult =
::mediapipe::tasks::components::containers::DetectionResult;
// The options for configuring a mediapipe object detector task.
struct ObjectDetectorOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
@ -79,8 +84,7 @@ struct ObjectDetectorOptions {
// The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM.
std::function<void(absl::StatusOr<std::vector<mediapipe::Detection>>,
const Image&, int64)>
std::function<void(absl::StatusOr<ObjectDetectorResult>, const Image&, int64)>
result_callback = nullptr;
};
@ -165,7 +169,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
// underlying image data.
// TODO: Describes the output bounding boxes for gpu input
// images after enabling the gpu support in MediaPipe Tasks.
absl::StatusOr<std::vector<mediapipe::Detection>> Detect(
absl::StatusOr<ObjectDetectorResult> Detect(
mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
@ -188,7 +192,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
// unrotated input frame of reference coordinates system, i.e. in `[0,
// image_width) x [0, image_height)`, which are the dimensions of the
// underlying image data.
absl::StatusOr<std::vector<mediapipe::Detection>> DetectForVideo(
absl::StatusOr<ObjectDetectorResult> DetectForVideo(
mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
@ -65,10 +66,14 @@ namespace vision {
namespace {
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::components::containers::ConvertToDetectionResult;
using ::mediapipe::tasks::components::containers::Detection;
using ::mediapipe::tasks::components::containers::DetectionResult;
using ::mediapipe::tasks::components::containers::RectF;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr;
using ::testing::Optional;
using DetectionProto = mediapipe::Detection;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kMobileSsdWithMetadata[] =
@ -83,47 +88,45 @@ constexpr char kEfficientDetWithMetadata[] =
// Checks that the two provided `Detection` proto vectors are equal, with a
// tolerance on floating-point scores to account for numerical instabilities.
// If the proto definition changes, please also change this function.
void ExpectApproximatelyEqual(const std::vector<Detection>& actual,
const std::vector<Detection>& expected) {
void ExpectApproximatelyEqual(const ObjectDetectorResult& actual,
const ObjectDetectorResult& expected) {
const float kPrecision = 1e-6;
EXPECT_EQ(actual.size(), expected.size());
for (int i = 0; i < actual.size(); ++i) {
const Detection& a = actual[i];
const Detection& b = expected[i];
EXPECT_THAT(a.location_data().bounding_box(),
EqualsProto(b.location_data().bounding_box()));
EXPECT_EQ(a.label_size(), 1);
EXPECT_EQ(b.label_size(), 1);
EXPECT_EQ(a.label(0), b.label(0));
EXPECT_EQ(a.score_size(), 1);
EXPECT_EQ(b.score_size(), 1);
EXPECT_NEAR(a.score(0), b.score(0), kPrecision);
EXPECT_EQ(actual.detections.size(), expected.detections.size());
for (int i = 0; i < actual.detections.size(); ++i) {
const Detection& a = actual.detections[i];
const Detection& b = expected.detections[i];
EXPECT_EQ(a.bounding_box, b.bounding_box);
EXPECT_EQ(a.categories.size(), 1);
EXPECT_EQ(b.categories.size(), 1);
EXPECT_EQ(a.categories[0].category_name, b.categories[0].category_name);
EXPECT_NEAR(a.categories[0].score, b.categories[0].score, kPrecision);
}
}
std::vector<Detection> GenerateMobileSsdNoImageResizingFullExpectedResults() {
return {ParseTextProtoOrDie<Detection>(R"pb(
std::vector<DetectionProto>
GenerateMobileSsdNoImageResizingFullExpectedResults() {
return {ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat"
score: 0.6328125
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 }
})pb"),
ParseTextProtoOrDie<Detection>(R"pb(
ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat"
score: 0.59765625
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 151 ymin: 78 width: 104 height: 223 }
})pb"),
ParseTextProtoOrDie<Detection>(R"pb(
ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat"
score: 0.5
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 65 ymin: 199 width: 41 height: 101 }
})pb"),
ParseTextProtoOrDie<Detection>(R"pb(
ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "dog"
score: 0.48828125
location_data {
@ -263,8 +266,8 @@ TEST_F(CreateFromOptionsTest, FailsWithIllegalCallbackInImageOrVideoMode) {
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
options->running_mode = running_mode;
options->result_callback =
[](absl::StatusOr<std::vector<Detection>> detections,
const Image& image, int64 timestamp_ms) {};
[](absl::StatusOr<ObjectDetectorResult> detections, const Image& image,
int64 timestamp_ms) {};
absl::StatusOr<std::unique_ptr<ObjectDetector>> object_detector =
ObjectDetector::Create(std::move(options));
EXPECT_EQ(object_detector.status().code(),
@ -340,34 +343,36 @@ TEST_F(ImageModeTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
MP_ASSERT_OK(object_detector->Close());
ExpectApproximatelyEqual(
results, {ParseTextProtoOrDie<Detection>(R"pb(
label: "cat"
score: 0.69921875
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 }
})pb"),
ParseTextProtoOrDie<Detection>(R"pb(
label: "cat"
score: 0.64453125
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 }
})pb"),
ParseTextProtoOrDie<Detection>(R"pb(
label: "cat"
score: 0.51171875
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 }
})pb"),
ParseTextProtoOrDie<Detection>(R"pb(
label: "cat"
score: 0.48828125
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 }
})pb")});
results,
ConvertToDetectionResult(
{ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat"
score: 0.69921875
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 }
})pb"),
ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat"
score: 0.64453125
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 }
})pb"),
ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat"
score: 0.51171875
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 }
})pb"),
ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat"
score: 0.48828125
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 }
})pb")}));
}
TEST_F(ImageModeTest, SucceedsEfficientDetModel) {
@ -383,34 +388,36 @@ TEST_F(ImageModeTest, SucceedsEfficientDetModel) {
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
MP_ASSERT_OK(object_detector->Close());
ExpectApproximatelyEqual(
results, {ParseTextProtoOrDie<Detection>(R"pb(
label: "cat"
score: 0.7578125
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 858 ymin: 408 width: 225 height: 187 }
})pb"),
ParseTextProtoOrDie<Detection>(R"pb(
label: "cat"
score: 0.72265625
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 67 ymin: 401 width: 399 height: 192 }
})pb"),
ParseTextProtoOrDie<Detection>(R"pb(
label: "cat"
score: 0.6289063
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 368 ymin: 210 width: 272 height: 385 }
})pb"),
ParseTextProtoOrDie<Detection>(R"pb(
label: "cat"
score: 0.5859375
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 601 ymin: 166 width: 298 height: 437 }
})pb")});
results,
ConvertToDetectionResult(
{ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat"
score: 0.7578125
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 858 ymin: 408 width: 225 height: 187 }
})pb"),
ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat"
score: 0.72265625
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 67 ymin: 401 width: 399 height: 192 }
})pb"),
ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat"
score: 0.6289063
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 368 ymin: 210 width: 272 height: 385 }
})pb"),
ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat"
score: 0.5859375
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 601 ymin: 166 width: 298 height: 437 }
})pb")}));
}
TEST_F(ImageModeTest, SucceedsWithoutImageResizing) {
@ -426,7 +433,8 @@ TEST_F(ImageModeTest, SucceedsWithoutImageResizing) {
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
MP_ASSERT_OK(object_detector->Close());
ExpectApproximatelyEqual(
results, GenerateMobileSsdNoImageResizingFullExpectedResults());
results, ConvertToDetectionResult(
GenerateMobileSsdNoImageResizingFullExpectedResults()));
}
TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
@ -442,13 +450,14 @@ TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
MP_ASSERT_OK(object_detector->Close());
ExpectApproximatelyEqual(
results, {ParseTextProtoOrDie<Detection>(R"pb(
results,
ConvertToDetectionResult({ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat"
score: 0.6531269142
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 }
})pb")});
})pb")}));
}
TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
@ -463,11 +472,13 @@ TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
ObjectDetector::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
MP_ASSERT_OK(object_detector->Close());
std::vector<Detection> full_expected_results =
std::vector<DetectionProto> full_expected_results =
GenerateMobileSsdNoImageResizingFullExpectedResults();
ExpectApproximatelyEqual(results,
{full_expected_results[0], full_expected_results[1],
full_expected_results[2]});
ExpectApproximatelyEqual(
results, ConvertToDetectionResult({full_expected_results[0],
full_expected_results[1],
full_expected_results[2]}));
}
TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
@ -482,10 +493,11 @@ TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
ObjectDetector::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
MP_ASSERT_OK(object_detector->Close());
std::vector<Detection> full_expected_results =
std::vector<DetectionProto> full_expected_results =
GenerateMobileSsdNoImageResizingFullExpectedResults();
ExpectApproximatelyEqual(
results, {full_expected_results[0], full_expected_results[1]});
results, ConvertToDetectionResult(
{full_expected_results[0], full_expected_results[1]}));
}
TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {
@ -501,9 +513,10 @@ TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {
ObjectDetector::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
MP_ASSERT_OK(object_detector->Close());
std::vector<Detection> full_expected_results =
std::vector<DetectionProto> full_expected_results =
GenerateMobileSsdNoImageResizingFullExpectedResults();
ExpectApproximatelyEqual(results, {full_expected_results[3]});
ExpectApproximatelyEqual(
results, ConvertToDetectionResult({full_expected_results[3]}));
}
TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
@ -519,9 +532,10 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
ObjectDetector::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
MP_ASSERT_OK(object_detector->Close());
std::vector<Detection> full_expected_results =
std::vector<DetectionProto> full_expected_results =
GenerateMobileSsdNoImageResizingFullExpectedResults();
ExpectApproximatelyEqual(results, {full_expected_results[3]});
ExpectApproximatelyEqual(
results, ConvertToDetectionResult({full_expected_results[3]}));
}
TEST_F(ImageModeTest, SucceedsWithRotation) {
@ -541,13 +555,14 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
auto results, object_detector->Detect(image, image_processing_options));
MP_ASSERT_OK(object_detector->Close());
ExpectApproximatelyEqual(
results, {ParseTextProtoOrDie<Detection>(R"pb(
results,
ConvertToDetectionResult({ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat"
score: 0.7109375
location_data {
format: BOUNDING_BOX
bounding_box { xmin: 0 ymin: 622 width: 436 height: 276 }
})pb")});
})pb")}));
}
TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
@ -560,7 +575,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
ObjectDetector::Create(std::move(options)));
Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
auto results = object_detector->Detect(image, image_processing_options);
@ -619,10 +634,11 @@ TEST_F(VideoModeTest, Succeeds) {
for (int i = 0; i < iterations; ++i) {
MP_ASSERT_OK_AND_ASSIGN(auto results,
object_detector->DetectForVideo(image, i));
std::vector<Detection> full_expected_results =
std::vector<DetectionProto> full_expected_results =
GenerateMobileSsdNoImageResizingFullExpectedResults();
ExpectApproximatelyEqual(
results, {full_expected_results[0], full_expected_results[1]});
results, ConvertToDetectionResult(
{full_expected_results[0], full_expected_results[1]}));
}
MP_ASSERT_OK(object_detector->Close());
}
@ -637,9 +653,8 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback =
[](absl::StatusOr<std::vector<Detection>> detections, const Image& image,
int64 timestamp_ms) {};
options->result_callback = [](absl::StatusOr<ObjectDetectorResult> detections,
const Image& image, int64 timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
ObjectDetector::Create(std::move(options)));
@ -669,9 +684,8 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
options->running_mode = core::RunningMode::LIVE_STREAM;
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
options->result_callback =
[](absl::StatusOr<std::vector<Detection>> detections, const Image& image,
int64 timestamp_ms) {};
options->result_callback = [](absl::StatusOr<ObjectDetectorResult> detections,
const Image& image, int64 timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
ObjectDetector::Create(std::move(options)));
MP_ASSERT_OK(object_detector->DetectAsync(image, 1));
@ -695,14 +709,14 @@ TEST_F(LiveStreamModeTest, Succeeds) {
auto options = std::make_unique<ObjectDetectorOptions>();
options->max_results = 2;
options->running_mode = core::RunningMode::LIVE_STREAM;
std::vector<std::vector<Detection>> detection_results;
std::vector<ObjectDetectorResult> detection_results;
std::vector<std::pair<int, int>> image_sizes;
std::vector<int64> timestamps;
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
options->result_callback =
[&detection_results, &image_sizes, &timestamps](
absl::StatusOr<std::vector<Detection>> detections, const Image& image,
absl::StatusOr<ObjectDetectorResult> detections, const Image& image,
int64 timestamp_ms) {
MP_ASSERT_OK(detections.status());
detection_results.push_back(std::move(detections).value());
@ -719,11 +733,12 @@ TEST_F(LiveStreamModeTest, Succeeds) {
// number of iterations.
ASSERT_LE(detection_results.size(), iterations);
ASSERT_GT(detection_results.size(), 0);
std::vector<Detection> full_expected_results =
std::vector<DetectionProto> full_expected_results =
GenerateMobileSsdNoImageResizingFullExpectedResults();
for (const auto& detection_result : detection_results) {
ExpectApproximatelyEqual(
detection_result, {full_expected_results[0], full_expected_results[1]});
detection_result, ConvertToDetectionResult({full_expected_results[0],
full_expected_results[1]}));
}
for (const auto& image_size : image_sizes) {
EXPECT_EQ(image_size.first, image.width());

View File

@ -22,13 +22,13 @@ limitations under the License.
namespace mediapipe::tasks::vision::utils {
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::components::containers::RectF;
float CalculateArea(const Rect& rect) {
float CalculateArea(const RectF& rect) {
return (rect.right - rect.left) * (rect.bottom - rect.top);
}
float CalculateIntersectionArea(const Rect& a, const Rect& b) {
float CalculateIntersectionArea(const RectF& a, const RectF& b) {
const float intersection_left = std::max<float>(a.left, b.left);
const float intersection_top = std::max<float>(a.top, b.top);
const float intersection_right = std::min<float>(a.right, b.right);
@ -38,7 +38,7 @@ float CalculateIntersectionArea(const Rect& a, const Rect& b) {
std::max<float>(intersection_right - intersection_left, 0.0);
}
float CalculateIOU(const Rect& a, const Rect& b) {
float CalculateIOU(const RectF& a, const RectF& b) {
const float area_a = CalculateArea(a);
const float area_b = CalculateArea(b);
if (area_a <= 0 || area_b <= 0) return 0.0;

View File

@ -27,15 +27,15 @@ limitations under the License.
namespace mediapipe::tasks::vision::utils {
// Calculates intersection over union for two bounds.
float CalculateIOU(const components::containers::Rect& a,
const components::containers::Rect& b);
float CalculateIOU(const components::containers::RectF& a,
const components::containers::RectF& b);
// Calculates area for face bound
float CalculateArea(const components::containers::Rect& rect);
float CalculateArea(const components::containers::RectF& rect);
// Calculates the intersection area of two face bounds
float CalculateIntersectionArea(const components::containers::Rect& a,
const components::containers::Rect& b);
float CalculateIntersectionArea(const components::containers::RectF& a,
const components::containers::RectF& b);
} // namespace mediapipe::tasks::vision::utils
#endif // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_