internal change
PiperOrigin-RevId: 493742399
This commit is contained in:
parent
a59f0a9924
commit
a0efcb47f2
|
@ -18,6 +18,7 @@ licenses(["notice"])
|
|||
|
||||
cc_library(
|
||||
name = "rect",
|
||||
srcs = ["rect.cc"],
|
||||
hdrs = ["rect.h"],
|
||||
)
|
||||
|
||||
|
@ -41,6 +42,18 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "detection_result",
|
||||
srcs = ["detection_result.cc"],
|
||||
hdrs = ["detection_result.h"],
|
||||
deps = [
|
||||
":category",
|
||||
":rect",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/formats:location_data_cc_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "embedding_result",
|
||||
srcs = ["embedding_result.cc"],
|
||||
|
|
73
mediapipe/tasks/cc/components/containers/detection_result.cc
Normal file
73
mediapipe/tasks/cc/components/containers/detection_result.cc
Normal file
|
@ -0,0 +1,73 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
|
||||
|
||||
#include <strings.h>
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/location_data.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||
|
||||
namespace mediapipe::tasks::components::containers {
|
||||
|
||||
constexpr int kDefaultCategoryIndex = -1;
|
||||
|
||||
Detection ConvertToDetectionResult(
|
||||
const mediapipe::Detection& detection_proto) {
|
||||
Detection detection;
|
||||
for (int idx = 0; idx < detection_proto.score_size(); ++idx) {
|
||||
detection.categories.push_back(
|
||||
{/* index= */ detection_proto.label_id_size() > idx
|
||||
? detection_proto.label_id(idx)
|
||||
: kDefaultCategoryIndex,
|
||||
/* score= */ detection_proto.score(idx),
|
||||
/* category_name */ detection_proto.label_size() > idx
|
||||
? detection_proto.label(idx)
|
||||
: "",
|
||||
/* display_name */ detection_proto.display_name_size() > idx
|
||||
? detection_proto.display_name(idx)
|
||||
: ""});
|
||||
}
|
||||
Rect bounding_box;
|
||||
if (detection_proto.location_data().has_bounding_box()) {
|
||||
mediapipe::LocationData::BoundingBox bounding_box_proto =
|
||||
detection_proto.location_data().bounding_box();
|
||||
bounding_box.left = bounding_box_proto.xmin();
|
||||
bounding_box.top = bounding_box_proto.ymin();
|
||||
bounding_box.right = bounding_box_proto.xmin() + bounding_box_proto.width();
|
||||
bounding_box.bottom =
|
||||
bounding_box_proto.ymin() + bounding_box_proto.height();
|
||||
}
|
||||
detection.bounding_box = bounding_box;
|
||||
return detection;
|
||||
}
|
||||
|
||||
DetectionResult ConvertToDetectionResult(
|
||||
std::vector<mediapipe::Detection> detections_proto) {
|
||||
DetectionResult detection_result;
|
||||
detection_result.detections.reserve(detections_proto.size());
|
||||
for (const auto& detection_proto : detections_proto) {
|
||||
detection_result.detections.push_back(
|
||||
ConvertToDetectionResult(detection_proto));
|
||||
}
|
||||
return detection_result;
|
||||
}
|
||||
} // namespace mediapipe::tasks::components::containers
|
52
mediapipe/tasks/cc/components/containers/detection_result.h
Normal file
52
mediapipe/tasks/cc/components/containers/detection_result.h
Normal file
|
@ -0,0 +1,52 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||
|
||||
namespace mediapipe::tasks::components::containers {
|
||||
|
||||
// Detection for a single bounding box.
|
||||
struct Detection {
|
||||
// A vector of detected categories.
|
||||
std::vector<Category> categories;
|
||||
// The bounding box location.
|
||||
Rect bounding_box;
|
||||
};
|
||||
|
||||
// Detection results of a model.
|
||||
struct DetectionResult {
|
||||
// A vector of Detections.
|
||||
std::vector<Detection> detections;
|
||||
};
|
||||
|
||||
// Utility function to convert from Detection proto to Detection struct.
|
||||
Detection ConvertToDetection(const mediapipe::Detection& detection_proto);
|
||||
|
||||
// Utility function to convert from list of Detection proto to DetectionResult
|
||||
// struct.
|
||||
DetectionResult ConvertToDetectionResult(
|
||||
std::vector<mediapipe::Detection> detections_proto);
|
||||
|
||||
} // namespace mediapipe::tasks::components::containers
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
|
34
mediapipe/tasks/cc/components/containers/rect.cc
Normal file
34
mediapipe/tasks/cc/components/containers/rect.cc
Normal file
|
@ -0,0 +1,34 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||
|
||||
namespace mediapipe::tasks::components::containers {
|
||||
|
||||
RectF ToRectF(const Rect& rect, int image_height, int image_width) {
|
||||
return RectF{static_cast<float>(rect.left) / image_width,
|
||||
static_cast<float>(rect.top) / image_height,
|
||||
static_cast<float>(rect.right) / image_width,
|
||||
static_cast<float>(rect.bottom) / image_height};
|
||||
}
|
||||
|
||||
Rect ToRect(const RectF& rect, int image_height, int image_width) {
|
||||
return Rect{static_cast<int>(rect.left * image_width),
|
||||
static_cast<int>(rect.top * image_height),
|
||||
static_cast<int>(rect.right * image_width),
|
||||
static_cast<int>(rect.bottom * image_height)};
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::components::containers
|
|
@ -16,20 +16,47 @@ limitations under the License.
|
|||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
namespace mediapipe::tasks::components::containers {
|
||||
|
||||
constexpr float kRectFTolerance = 1e-4;
|
||||
|
||||
// Defines a rectangle, used e.g. as part of detection results or as input
|
||||
// region-of-interest.
|
||||
//
|
||||
struct Rect {
|
||||
int left;
|
||||
int top;
|
||||
int right;
|
||||
int bottom;
|
||||
};
|
||||
|
||||
inline bool operator==(const Rect& lhs, const Rect& rhs) {
|
||||
return lhs.left == rhs.left && lhs.top == rhs.top && lhs.right == rhs.right &&
|
||||
lhs.bottom == rhs.bottom;
|
||||
}
|
||||
|
||||
// The coordinates are normalized wrt the image dimensions, i.e. generally in
|
||||
// [0,1] but they may exceed these bounds if describing a region overlapping the
|
||||
// image. The origin is on the top-left corner of the image.
|
||||
struct Rect {
|
||||
struct RectF {
|
||||
float left;
|
||||
float top;
|
||||
float right;
|
||||
float bottom;
|
||||
};
|
||||
|
||||
inline bool operator==(const RectF& lhs, const RectF& rhs) {
|
||||
return abs(lhs.left - rhs.left) < kRectFTolerance &&
|
||||
abs(lhs.top - rhs.top) < kRectFTolerance &&
|
||||
abs(lhs.right - rhs.right) < kRectFTolerance &&
|
||||
abs(lhs.bottom - rhs.bottom) < kRectFTolerance;
|
||||
}
|
||||
|
||||
RectF ToRectF(const Rect& rect, int image_height, int image_width);
|
||||
|
||||
Rect ToRect(const RectF& rect, int image_height, int image_width);
|
||||
|
||||
} // namespace mediapipe::tasks::components::containers
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
|
||||
|
|
|
@ -129,13 +129,13 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi {
|
|||
if (roi.left >= roi.right || roi.top >= roi.bottom) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
"Expected Rect with left < right and top < bottom.",
|
||||
"Expected RectF with left < right and top < bottom.",
|
||||
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
|
||||
}
|
||||
if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
"Expected Rect values to be in [0,1].",
|
||||
"Expected RectF values to be in [0,1].",
|
||||
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
|
||||
}
|
||||
normalized_rect.set_x_center((roi.left + roi.right) / 2.0);
|
||||
|
|
|
@ -35,7 +35,8 @@ struct ImageProcessingOptions {
|
|||
// the full image is used.
|
||||
//
|
||||
// Coordinates must be in [0,1] with 'left' < 'right' and 'top' < bottom.
|
||||
std::optional<components::containers::Rect> region_of_interest = std::nullopt;
|
||||
std::optional<components::containers::RectF> region_of_interest =
|
||||
std::nullopt;
|
||||
|
||||
// The rotation to apply to the image (or cropped region-of-interest), in
|
||||
// degrees clockwise.
|
||||
|
|
|
@ -44,7 +44,7 @@ namespace {
|
|||
using ::mediapipe::api2::Input;
|
||||
using ::mediapipe::api2::Output;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::containers::Rect;
|
||||
using ::mediapipe::tasks::components::containers::RectF;
|
||||
using ::mediapipe::tasks::vision::utils::CalculateIOU;
|
||||
using ::mediapipe::tasks::vision::utils::DuplicatesFinder;
|
||||
|
||||
|
@ -126,7 +126,7 @@ absl::StatusOr<float> HandBaselineDistance(
|
|||
return distance;
|
||||
}
|
||||
|
||||
Rect CalculateBound(const NormalizedLandmarkList& list) {
|
||||
RectF CalculateBound(const NormalizedLandmarkList& list) {
|
||||
constexpr float kMinInitialValue = std::numeric_limits<float>::max();
|
||||
constexpr float kMaxInitialValue = std::numeric_limits<float>::lowest();
|
||||
|
||||
|
@ -144,7 +144,7 @@ Rect CalculateBound(const NormalizedLandmarkList& list) {
|
|||
}
|
||||
|
||||
// Populate normalized non rotated face bounding box
|
||||
return Rect{/*left=*/bounding_box_left,
|
||||
return RectF{/*left=*/bounding_box_left,
|
||||
/*top=*/bounding_box_top,
|
||||
/*right=*/bounding_box_right,
|
||||
/*bottom=*/bounding_box_bottom};
|
||||
|
@ -172,7 +172,7 @@ class HandDuplicatesFinder : public DuplicatesFinder {
|
|||
const int num = multi_landmarks.size();
|
||||
std::vector<float> baseline_distances;
|
||||
baseline_distances.reserve(num);
|
||||
std::vector<Rect> bounds;
|
||||
std::vector<RectF> bounds;
|
||||
bounds.reserve(num);
|
||||
for (const NormalizedLandmarkList& list : multi_landmarks) {
|
||||
ASSIGN_OR_RETURN(const float baseline_distance,
|
||||
|
|
|
@ -50,7 +50,7 @@ namespace {
|
|||
|
||||
using ::file::Defaults;
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::mediapipe::tasks::components::containers::Rect;
|
||||
using ::mediapipe::tasks::components::containers::RectF;
|
||||
using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
|
||||
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
||||
using ::testing::EqualsProto;
|
||||
|
@ -188,7 +188,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
|
|||
options->running_mode = core::RunningMode::IMAGE;
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
|
||||
HandLandmarker::Create(std::move(options)));
|
||||
Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
|
||||
RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
|
||||
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
||||
|
||||
auto results = hand_landmarker->Detect(image, image_processing_options);
|
||||
|
|
|
@ -52,7 +52,7 @@ namespace {
|
|||
using ::mediapipe::file::JoinPath;
|
||||
using ::mediapipe::tasks::components::containers::Category;
|
||||
using ::mediapipe::tasks::components::containers::Classifications;
|
||||
using ::mediapipe::tasks::components::containers::Rect;
|
||||
using ::mediapipe::tasks::components::containers::RectF;
|
||||
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::Optional;
|
||||
|
@ -472,7 +472,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
|
||||
ImageClassifier::Create(std::move(options)));
|
||||
// Region-of-interest around the soccer ball.
|
||||
Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
|
||||
RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
|
||||
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
|
||||
|
@ -526,7 +526,8 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
|
||||
ImageClassifier::Create(std::move(options)));
|
||||
// Region-of-interest around the chair, with 90° anti-clockwise rotation.
|
||||
Rect roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702, /*bottom=*/0.3049};
|
||||
RectF roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702,
|
||||
/*bottom=*/0.3049};
|
||||
ImageProcessingOptions image_processing_options{roi,
|
||||
/*rotation_degrees=*/-90};
|
||||
|
||||
|
@ -554,13 +555,13 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) {
|
|||
ImageClassifier::Create(std::move(options)));
|
||||
|
||||
// Invalid: left > right.
|
||||
Rect roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1};
|
||||
RectF roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1};
|
||||
ImageProcessingOptions image_processing_options{roi,
|
||||
/*rotation_degrees=*/0};
|
||||
auto results = image_classifier->Classify(image, image_processing_options);
|
||||
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
|
||||
EXPECT_THAT(results.status().message(),
|
||||
HasSubstr("Expected Rect with left < right and top < bottom"));
|
||||
HasSubstr("Expected RectF with left < right and top < bottom"));
|
||||
EXPECT_THAT(
|
||||
results.status().GetPayload(kMediaPipeTasksPayload),
|
||||
Optional(absl::Cord(absl::StrCat(
|
||||
|
@ -573,7 +574,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) {
|
|||
results = image_classifier->Classify(image, image_processing_options);
|
||||
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
|
||||
EXPECT_THAT(results.status().message(),
|
||||
HasSubstr("Expected Rect with left < right and top < bottom"));
|
||||
HasSubstr("Expected RectF with left < right and top < bottom"));
|
||||
EXPECT_THAT(
|
||||
results.status().GetPayload(kMediaPipeTasksPayload),
|
||||
Optional(absl::Cord(absl::StrCat(
|
||||
|
@ -586,7 +587,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) {
|
|||
results = image_classifier->Classify(image, image_processing_options);
|
||||
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
|
||||
EXPECT_THAT(results.status().message(),
|
||||
HasSubstr("Expected Rect values to be in [0,1]"));
|
||||
HasSubstr("Expected RectF values to be in [0,1]"));
|
||||
EXPECT_THAT(
|
||||
results.status().GetPayload(kMediaPipeTasksPayload),
|
||||
Optional(absl::Cord(absl::StrCat(
|
||||
|
@ -695,7 +696,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) {
|
|||
ImageClassifier::Create(std::move(options)));
|
||||
// Crop around the soccer ball.
|
||||
// Region-of-interest around the soccer ball.
|
||||
Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
|
||||
RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
|
||||
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
||||
|
||||
for (int i = 0; i < iterations; ++i) {
|
||||
|
@ -837,7 +838,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
|
||||
ImageClassifier::Create(std::move(options)));
|
||||
// Crop around the soccer ball.
|
||||
Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
|
||||
RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
|
||||
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
||||
|
||||
for (int i = 0; i < iterations; ++i) {
|
||||
|
|
|
@ -41,7 +41,7 @@ namespace image_embedder {
|
|||
namespace {
|
||||
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::mediapipe::tasks::components::containers::Rect;
|
||||
using ::mediapipe::tasks::components::containers::RectF;
|
||||
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::Optional;
|
||||
|
@ -320,7 +320,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
|
|||
Image crop, DecodeImageFromFile(
|
||||
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
|
||||
// Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
|
||||
Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1};
|
||||
RectF roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1};
|
||||
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
||||
|
||||
// Extract both embeddings.
|
||||
|
@ -388,7 +388,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
|
|||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
|
||||
"burger_rotated.jpg")));
|
||||
// Region-of-interest corresponding to burger_crop.jpg.
|
||||
Rect roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333};
|
||||
RectF roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333};
|
||||
ImageProcessingOptions image_processing_options{roi,
|
||||
/*rotation_degrees=*/-90};
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ namespace {
|
|||
|
||||
using ::mediapipe::Image;
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::mediapipe::tasks::components::containers::Rect;
|
||||
using ::mediapipe::tasks::components::containers::RectF;
|
||||
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::Optional;
|
||||
|
@ -299,7 +299,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
|
|||
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||
ImageSegmenter::Create(std::move(options)));
|
||||
Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
|
||||
RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
|
||||
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
||||
|
||||
auto results = segmenter->Segment(image, image_processing_options);
|
||||
|
|
|
@ -33,6 +33,7 @@ cc_library(
|
|||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
|
||||
"//mediapipe/tasks/cc/components/containers:detection_result",
|
||||
"//mediapipe/tasks/cc/core:base_options",
|
||||
"//mediapipe/tasks/cc/core:utils",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
|
||||
|
|
|
@ -30,6 +30,7 @@ limitations under the License.
|
|||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||
|
@ -56,6 +57,7 @@ constexpr char kSubgraphTypeName[] =
|
|||
"mediapipe.tasks.vision.ObjectDetectorGraph";
|
||||
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||
|
||||
using ::mediapipe::tasks::components::containers::ConvertToDetectionResult;
|
||||
using ObjectDetectorOptionsProto =
|
||||
object_detector::proto::ObjectDetectorOptions;
|
||||
|
||||
|
@ -129,7 +131,8 @@ absl::StatusOr<std::unique_ptr<ObjectDetector>> ObjectDetector::Create(
|
|||
Packet detections_packet =
|
||||
status_or_packets.value()[kDetectionsOutStreamName];
|
||||
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
|
||||
result_callback(detections_packet.Get<std::vector<Detection>>(),
|
||||
result_callback(ConvertToDetectionResult(
|
||||
detections_packet.Get<std::vector<Detection>>()),
|
||||
image_packet.Get<Image>(),
|
||||
detections_packet.Timestamp().Value() /
|
||||
kMicroSecondsPerMilliSecond);
|
||||
|
@ -144,7 +147,7 @@ absl::StatusOr<std::unique_ptr<ObjectDetector>> ObjectDetector::Create(
|
|||
std::move(packets_callback));
|
||||
}
|
||||
|
||||
absl::StatusOr<std::vector<Detection>> ObjectDetector::Detect(
|
||||
absl::StatusOr<ObjectDetectorResult> ObjectDetector::Detect(
|
||||
mediapipe::Image image,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||
if (image.UsesGpu()) {
|
||||
|
@ -161,10 +164,11 @@ absl::StatusOr<std::vector<Detection>> ObjectDetector::Detect(
|
|||
ProcessImageData(
|
||||
{{kImageInStreamName, MakePacket<Image>(std::move(image))},
|
||||
{kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}}));
|
||||
return output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>();
|
||||
return ConvertToDetectionResult(
|
||||
output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>());
|
||||
}
|
||||
|
||||
absl::StatusOr<std::vector<Detection>> ObjectDetector::DetectForVideo(
|
||||
absl::StatusOr<ObjectDetectorResult> ObjectDetector::DetectForVideo(
|
||||
mediapipe::Image image, int64 timestamp_ms,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||
if (image.UsesGpu()) {
|
||||
|
@ -185,7 +189,8 @@ absl::StatusOr<std::vector<Detection>> ObjectDetector::DetectForVideo(
|
|||
{kNormRectName,
|
||||
MakePacket<NormalizedRect>(std::move(norm_rect))
|
||||
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
|
||||
return output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>();
|
||||
return ConvertToDetectionResult(
|
||||
output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>());
|
||||
}
|
||||
|
||||
absl::Status ObjectDetector::DetectAsync(
|
||||
|
|
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
|
@ -36,6 +37,10 @@ namespace mediapipe {
|
|||
namespace tasks {
|
||||
namespace vision {
|
||||
|
||||
// Alias the shared DetectionResult struct as result typo.
|
||||
using ObjectDetectorResult =
|
||||
::mediapipe::tasks::components::containers::DetectionResult;
|
||||
|
||||
// The options for configuring a mediapipe object detector task.
|
||||
struct ObjectDetectorOptions {
|
||||
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
||||
|
@ -79,8 +84,7 @@ struct ObjectDetectorOptions {
|
|||
// The user-defined result callback for processing live stream data.
|
||||
// The result callback should only be specified when the running mode is set
|
||||
// to RunningMode::LIVE_STREAM.
|
||||
std::function<void(absl::StatusOr<std::vector<mediapipe::Detection>>,
|
||||
const Image&, int64)>
|
||||
std::function<void(absl::StatusOr<ObjectDetectorResult>, const Image&, int64)>
|
||||
result_callback = nullptr;
|
||||
};
|
||||
|
||||
|
@ -165,7 +169,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
|
|||
// underlying image data.
|
||||
// TODO: Describes the output bounding boxes for gpu input
|
||||
// images after enabling the gpu support in MediaPipe Tasks.
|
||||
absl::StatusOr<std::vector<mediapipe::Detection>> Detect(
|
||||
absl::StatusOr<ObjectDetectorResult> Detect(
|
||||
mediapipe::Image image,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||
std::nullopt);
|
||||
|
@ -188,7 +192,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
|
|||
// unrotated input frame of reference coordinates system, i.e. in `[0,
|
||||
// image_width) x [0, image_height)`, which are the dimensions of the
|
||||
// underlying image data.
|
||||
absl::StatusOr<std::vector<mediapipe::Detection>> DetectForVideo(
|
||||
absl::StatusOr<ObjectDetectorResult> DetectForVideo(
|
||||
mediapipe::Image image, int64 timestamp_ms,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||
std::nullopt);
|
||||
|
|
|
@ -35,6 +35,7 @@ limitations under the License.
|
|||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
|
@ -65,10 +66,14 @@ namespace vision {
|
|||
namespace {
|
||||
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::mediapipe::tasks::components::containers::Rect;
|
||||
using ::mediapipe::tasks::components::containers::ConvertToDetectionResult;
|
||||
using ::mediapipe::tasks::components::containers::Detection;
|
||||
using ::mediapipe::tasks::components::containers::DetectionResult;
|
||||
using ::mediapipe::tasks::components::containers::RectF;
|
||||
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::Optional;
|
||||
using DetectionProto = mediapipe::Detection;
|
||||
|
||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
|
||||
constexpr char kMobileSsdWithMetadata[] =
|
||||
|
@ -83,47 +88,45 @@ constexpr char kEfficientDetWithMetadata[] =
|
|||
// Checks that the two provided `Detection` proto vectors are equal, with a
|
||||
// tolerancy on floating-point scores to account for numerical instabilities.
|
||||
// If the proto definition changes, please also change this function.
|
||||
void ExpectApproximatelyEqual(const std::vector<Detection>& actual,
|
||||
const std::vector<Detection>& expected) {
|
||||
void ExpectApproximatelyEqual(const ObjectDetectorResult& actual,
|
||||
const ObjectDetectorResult& expected) {
|
||||
const float kPrecision = 1e-6;
|
||||
EXPECT_EQ(actual.size(), expected.size());
|
||||
for (int i = 0; i < actual.size(); ++i) {
|
||||
const Detection& a = actual[i];
|
||||
const Detection& b = expected[i];
|
||||
EXPECT_THAT(a.location_data().bounding_box(),
|
||||
EqualsProto(b.location_data().bounding_box()));
|
||||
EXPECT_EQ(a.label_size(), 1);
|
||||
EXPECT_EQ(b.label_size(), 1);
|
||||
EXPECT_EQ(a.label(0), b.label(0));
|
||||
EXPECT_EQ(a.score_size(), 1);
|
||||
EXPECT_EQ(b.score_size(), 1);
|
||||
EXPECT_NEAR(a.score(0), b.score(0), kPrecision);
|
||||
EXPECT_EQ(actual.detections.size(), expected.detections.size());
|
||||
for (int i = 0; i < actual.detections.size(); ++i) {
|
||||
const Detection& a = actual.detections[i];
|
||||
const Detection& b = expected.detections[i];
|
||||
EXPECT_EQ(a.bounding_box, b.bounding_box);
|
||||
EXPECT_EQ(a.categories.size(), 1);
|
||||
EXPECT_EQ(b.categories.size(), 1);
|
||||
EXPECT_EQ(a.categories[0].category_name, b.categories[0].category_name);
|
||||
EXPECT_NEAR(a.categories[0].score, b.categories[0].score, kPrecision);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Detection> GenerateMobileSsdNoImageResizingFullExpectedResults() {
|
||||
return {ParseTextProtoOrDie<Detection>(R"pb(
|
||||
std::vector<DetectionProto>
|
||||
GenerateMobileSsdNoImageResizingFullExpectedResults() {
|
||||
return {ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "cat"
|
||||
score: 0.6328125
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 }
|
||||
})pb"),
|
||||
ParseTextProtoOrDie<Detection>(R"pb(
|
||||
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "cat"
|
||||
score: 0.59765625
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 151 ymin: 78 width: 104 height: 223 }
|
||||
})pb"),
|
||||
ParseTextProtoOrDie<Detection>(R"pb(
|
||||
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "cat"
|
||||
score: 0.5
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 65 ymin: 199 width: 41 height: 101 }
|
||||
})pb"),
|
||||
ParseTextProtoOrDie<Detection>(R"pb(
|
||||
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "dog"
|
||||
score: 0.48828125
|
||||
location_data {
|
||||
|
@ -263,8 +266,8 @@ TEST_F(CreateFromOptionsTest, FailsWithIllegalCallbackInImageOrVideoMode) {
|
|||
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
|
||||
options->running_mode = running_mode;
|
||||
options->result_callback =
|
||||
[](absl::StatusOr<std::vector<Detection>> detections,
|
||||
const Image& image, int64 timestamp_ms) {};
|
||||
[](absl::StatusOr<ObjectDetectorResult> detections, const Image& image,
|
||||
int64 timestamp_ms) {};
|
||||
absl::StatusOr<std::unique_ptr<ObjectDetector>> object_detector =
|
||||
ObjectDetector::Create(std::move(options));
|
||||
EXPECT_EQ(object_detector.status().code(),
|
||||
|
@ -340,34 +343,36 @@ TEST_F(ImageModeTest, Succeeds) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
||||
MP_ASSERT_OK(object_detector->Close());
|
||||
ExpectApproximatelyEqual(
|
||||
results, {ParseTextProtoOrDie<Detection>(R"pb(
|
||||
results,
|
||||
ConvertToDetectionResult(
|
||||
{ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "cat"
|
||||
score: 0.69921875
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 }
|
||||
})pb"),
|
||||
ParseTextProtoOrDie<Detection>(R"pb(
|
||||
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "cat"
|
||||
score: 0.64453125
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 }
|
||||
})pb"),
|
||||
ParseTextProtoOrDie<Detection>(R"pb(
|
||||
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "cat"
|
||||
score: 0.51171875
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 }
|
||||
})pb"),
|
||||
ParseTextProtoOrDie<Detection>(R"pb(
|
||||
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "cat"
|
||||
score: 0.48828125
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 }
|
||||
})pb")});
|
||||
})pb")}));
|
||||
}
|
||||
|
||||
TEST_F(ImageModeTest, SucceedsEfficientDetModel) {
|
||||
|
@ -383,34 +388,36 @@ TEST_F(ImageModeTest, SucceedsEfficientDetModel) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
||||
MP_ASSERT_OK(object_detector->Close());
|
||||
ExpectApproximatelyEqual(
|
||||
results, {ParseTextProtoOrDie<Detection>(R"pb(
|
||||
results,
|
||||
ConvertToDetectionResult(
|
||||
{ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "cat"
|
||||
score: 0.7578125
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 858 ymin: 408 width: 225 height: 187 }
|
||||
})pb"),
|
||||
ParseTextProtoOrDie<Detection>(R"pb(
|
||||
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "cat"
|
||||
score: 0.72265625
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 67 ymin: 401 width: 399 height: 192 }
|
||||
})pb"),
|
||||
ParseTextProtoOrDie<Detection>(R"pb(
|
||||
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "cat"
|
||||
score: 0.6289063
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 368 ymin: 210 width: 272 height: 385 }
|
||||
})pb"),
|
||||
ParseTextProtoOrDie<Detection>(R"pb(
|
||||
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "cat"
|
||||
score: 0.5859375
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 601 ymin: 166 width: 298 height: 437 }
|
||||
})pb")});
|
||||
})pb")}));
|
||||
}
|
||||
|
||||
TEST_F(ImageModeTest, SucceedsWithoutImageResizing) {
|
||||
|
@ -426,7 +433,8 @@ TEST_F(ImageModeTest, SucceedsWithoutImageResizing) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
||||
MP_ASSERT_OK(object_detector->Close());
|
||||
ExpectApproximatelyEqual(
|
||||
results, GenerateMobileSsdNoImageResizingFullExpectedResults());
|
||||
results, ConvertToDetectionResult(
|
||||
GenerateMobileSsdNoImageResizingFullExpectedResults()));
|
||||
}
|
||||
|
||||
TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
|
||||
|
@ -442,13 +450,14 @@ TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
||||
MP_ASSERT_OK(object_detector->Close());
|
||||
ExpectApproximatelyEqual(
|
||||
results, {ParseTextProtoOrDie<Detection>(R"pb(
|
||||
results,
|
||||
ConvertToDetectionResult({ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "cat"
|
||||
score: 0.6531269142
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 }
|
||||
})pb")});
|
||||
})pb")}));
|
||||
}
|
||||
|
||||
TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
|
||||
|
@ -463,11 +472,13 @@ TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
|
|||
ObjectDetector::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
||||
MP_ASSERT_OK(object_detector->Close());
|
||||
std::vector<Detection> full_expected_results =
|
||||
std::vector<DetectionProto> full_expected_results =
|
||||
GenerateMobileSsdNoImageResizingFullExpectedResults();
|
||||
ExpectApproximatelyEqual(results,
|
||||
{full_expected_results[0], full_expected_results[1],
|
||||
full_expected_results[2]});
|
||||
|
||||
ExpectApproximatelyEqual(
|
||||
results, ConvertToDetectionResult({full_expected_results[0],
|
||||
full_expected_results[1],
|
||||
full_expected_results[2]}));
|
||||
}
|
||||
|
||||
TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
|
||||
|
@ -482,10 +493,11 @@ TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
|
|||
ObjectDetector::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
||||
MP_ASSERT_OK(object_detector->Close());
|
||||
std::vector<Detection> full_expected_results =
|
||||
std::vector<DetectionProto> full_expected_results =
|
||||
GenerateMobileSsdNoImageResizingFullExpectedResults();
|
||||
ExpectApproximatelyEqual(
|
||||
results, {full_expected_results[0], full_expected_results[1]});
|
||||
results, ConvertToDetectionResult(
|
||||
{full_expected_results[0], full_expected_results[1]}));
|
||||
}
|
||||
|
||||
TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {
|
||||
|
@ -501,9 +513,10 @@ TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {
|
|||
ObjectDetector::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
||||
MP_ASSERT_OK(object_detector->Close());
|
||||
std::vector<Detection> full_expected_results =
|
||||
std::vector<DetectionProto> full_expected_results =
|
||||
GenerateMobileSsdNoImageResizingFullExpectedResults();
|
||||
ExpectApproximatelyEqual(results, {full_expected_results[3]});
|
||||
ExpectApproximatelyEqual(
|
||||
results, ConvertToDetectionResult({full_expected_results[3]}));
|
||||
}
|
||||
|
||||
TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
|
||||
|
@ -519,9 +532,10 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
|
|||
ObjectDetector::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
||||
MP_ASSERT_OK(object_detector->Close());
|
||||
std::vector<Detection> full_expected_results =
|
||||
std::vector<DetectionProto> full_expected_results =
|
||||
GenerateMobileSsdNoImageResizingFullExpectedResults();
|
||||
ExpectApproximatelyEqual(results, {full_expected_results[3]});
|
||||
ExpectApproximatelyEqual(
|
||||
results, ConvertToDetectionResult({full_expected_results[3]}));
|
||||
}
|
||||
|
||||
TEST_F(ImageModeTest, SucceedsWithRotation) {
|
||||
|
@ -541,13 +555,14 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
|
|||
auto results, object_detector->Detect(image, image_processing_options));
|
||||
MP_ASSERT_OK(object_detector->Close());
|
||||
ExpectApproximatelyEqual(
|
||||
results, {ParseTextProtoOrDie<Detection>(R"pb(
|
||||
results,
|
||||
ConvertToDetectionResult({ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "cat"
|
||||
score: 0.7109375
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 0 ymin: 622 width: 436 height: 276 }
|
||||
})pb")});
|
||||
})pb")}));
|
||||
}
|
||||
|
||||
TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
|
||||
|
@ -560,7 +575,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
|
|||
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
|
||||
ObjectDetector::Create(std::move(options)));
|
||||
Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
|
||||
RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
|
||||
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
||||
|
||||
auto results = object_detector->Detect(image, image_processing_options);
|
||||
|
@ -619,10 +634,11 @@ TEST_F(VideoModeTest, Succeeds) {
|
|||
for (int i = 0; i < iterations; ++i) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto results,
|
||||
object_detector->DetectForVideo(image, i));
|
||||
std::vector<Detection> full_expected_results =
|
||||
std::vector<DetectionProto> full_expected_results =
|
||||
GenerateMobileSsdNoImageResizingFullExpectedResults();
|
||||
ExpectApproximatelyEqual(
|
||||
results, {full_expected_results[0], full_expected_results[1]});
|
||||
results, ConvertToDetectionResult(
|
||||
{full_expected_results[0], full_expected_results[1]}));
|
||||
}
|
||||
MP_ASSERT_OK(object_detector->Close());
|
||||
}
|
||||
|
@ -637,9 +653,8 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
|||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
|
||||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||
options->result_callback =
|
||||
[](absl::StatusOr<std::vector<Detection>> detections, const Image& image,
|
||||
int64 timestamp_ms) {};
|
||||
options->result_callback = [](absl::StatusOr<ObjectDetectorResult> detections,
|
||||
const Image& image, int64 timestamp_ms) {};
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
|
||||
ObjectDetector::Create(std::move(options)));
|
||||
|
@ -669,9 +684,8 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
|
|||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
|
||||
options->result_callback =
|
||||
[](absl::StatusOr<std::vector<Detection>> detections, const Image& image,
|
||||
int64 timestamp_ms) {};
|
||||
options->result_callback = [](absl::StatusOr<ObjectDetectorResult> detections,
|
||||
const Image& image, int64 timestamp_ms) {};
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
|
||||
ObjectDetector::Create(std::move(options)));
|
||||
MP_ASSERT_OK(object_detector->DetectAsync(image, 1));
|
||||
|
@ -695,14 +709,14 @@ TEST_F(LiveStreamModeTest, Succeeds) {
|
|||
auto options = std::make_unique<ObjectDetectorOptions>();
|
||||
options->max_results = 2;
|
||||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||
std::vector<std::vector<Detection>> detection_results;
|
||||
std::vector<ObjectDetectorResult> detection_results;
|
||||
std::vector<std::pair<int, int>> image_sizes;
|
||||
std::vector<int64> timestamps;
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
|
||||
options->result_callback =
|
||||
[&detection_results, &image_sizes, ×tamps](
|
||||
absl::StatusOr<std::vector<Detection>> detections, const Image& image,
|
||||
absl::StatusOr<ObjectDetectorResult> detections, const Image& image,
|
||||
int64 timestamp_ms) {
|
||||
MP_ASSERT_OK(detections.status());
|
||||
detection_results.push_back(std::move(detections).value());
|
||||
|
@ -719,11 +733,12 @@ TEST_F(LiveStreamModeTest, Succeeds) {
|
|||
// number of iterations.
|
||||
ASSERT_LE(detection_results.size(), iterations);
|
||||
ASSERT_GT(detection_results.size(), 0);
|
||||
std::vector<Detection> full_expected_results =
|
||||
std::vector<DetectionProto> full_expected_results =
|
||||
GenerateMobileSsdNoImageResizingFullExpectedResults();
|
||||
for (const auto& detection_result : detection_results) {
|
||||
ExpectApproximatelyEqual(
|
||||
detection_result, {full_expected_results[0], full_expected_results[1]});
|
||||
detection_result, ConvertToDetectionResult({full_expected_results[0],
|
||||
full_expected_results[1]}));
|
||||
}
|
||||
for (const auto& image_size : image_sizes) {
|
||||
EXPECT_EQ(image_size.first, image.width());
|
||||
|
|
|
@ -22,13 +22,13 @@ limitations under the License.
|
|||
|
||||
namespace mediapipe::tasks::vision::utils {
|
||||
|
||||
using ::mediapipe::tasks::components::containers::Rect;
|
||||
using ::mediapipe::tasks::components::containers::RectF;
|
||||
|
||||
float CalculateArea(const Rect& rect) {
|
||||
float CalculateArea(const RectF& rect) {
|
||||
return (rect.right - rect.left) * (rect.bottom - rect.top);
|
||||
}
|
||||
|
||||
float CalculateIntersectionArea(const Rect& a, const Rect& b) {
|
||||
float CalculateIntersectionArea(const RectF& a, const RectF& b) {
|
||||
const float intersection_left = std::max<float>(a.left, b.left);
|
||||
const float intersection_top = std::max<float>(a.top, b.top);
|
||||
const float intersection_right = std::min<float>(a.right, b.right);
|
||||
|
@ -38,7 +38,7 @@ float CalculateIntersectionArea(const Rect& a, const Rect& b) {
|
|||
std::max<float>(intersection_right - intersection_left, 0.0);
|
||||
}
|
||||
|
||||
float CalculateIOU(const Rect& a, const Rect& b) {
|
||||
float CalculateIOU(const RectF& a, const RectF& b) {
|
||||
const float area_a = CalculateArea(a);
|
||||
const float area_b = CalculateArea(b);
|
||||
if (area_a <= 0 || area_b <= 0) return 0.0;
|
||||
|
|
|
@ -27,15 +27,15 @@ limitations under the License.
|
|||
namespace mediapipe::tasks::vision::utils {
|
||||
|
||||
// Calculates intersection over union for two bounds.
|
||||
float CalculateIOU(const components::containers::Rect& a,
|
||||
const components::containers::Rect& b);
|
||||
float CalculateIOU(const components::containers::RectF& a,
|
||||
const components::containers::RectF& b);
|
||||
|
||||
// Calculates area for face bound
|
||||
float CalculateArea(const components::containers::Rect& rect);
|
||||
float CalculateArea(const components::containers::RectF& rect);
|
||||
|
||||
// Calucates intersection area of two face bounds
|
||||
float CalculateIntersectionArea(const components::containers::Rect& a,
|
||||
const components::containers::Rect& b);
|
||||
float CalculateIntersectionArea(const components::containers::RectF& a,
|
||||
const components::containers::RectF& b);
|
||||
} // namespace mediapipe::tasks::vision::utils
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_
|
||||
|
|
Loading…
Reference in New Issue
Block a user