internal change
PiperOrigin-RevId: 493742399
This commit is contained in:
parent
a59f0a9924
commit
a0efcb47f2
|
@ -18,6 +18,7 @@ licenses(["notice"])
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "rect",
|
name = "rect",
|
||||||
|
srcs = ["rect.cc"],
|
||||||
hdrs = ["rect.h"],
|
hdrs = ["rect.h"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -41,6 +42,18 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "detection_result",
|
||||||
|
srcs = ["detection_result.cc"],
|
||||||
|
hdrs = ["detection_result.h"],
|
||||||
|
deps = [
|
||||||
|
":category",
|
||||||
|
":rect",
|
||||||
|
"//mediapipe/framework/formats:detection_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:location_data_cc_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "embedding_result",
|
name = "embedding_result",
|
||||||
srcs = ["embedding_result.cc"],
|
srcs = ["embedding_result.cc"],
|
||||||
|
|
73
mediapipe/tasks/cc/components/containers/detection_result.cc
Normal file
73
mediapipe/tasks/cc/components/containers/detection_result.cc
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
|
||||||
|
|
||||||
|
#include <strings.h>
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/formats/detection.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/location_data.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::components::containers {
|
||||||
|
|
||||||
|
constexpr int kDefaultCategoryIndex = -1;
|
||||||
|
|
||||||
|
Detection ConvertToDetectionResult(
|
||||||
|
const mediapipe::Detection& detection_proto) {
|
||||||
|
Detection detection;
|
||||||
|
for (int idx = 0; idx < detection_proto.score_size(); ++idx) {
|
||||||
|
detection.categories.push_back(
|
||||||
|
{/* index= */ detection_proto.label_id_size() > idx
|
||||||
|
? detection_proto.label_id(idx)
|
||||||
|
: kDefaultCategoryIndex,
|
||||||
|
/* score= */ detection_proto.score(idx),
|
||||||
|
/* category_name */ detection_proto.label_size() > idx
|
||||||
|
? detection_proto.label(idx)
|
||||||
|
: "",
|
||||||
|
/* display_name */ detection_proto.display_name_size() > idx
|
||||||
|
? detection_proto.display_name(idx)
|
||||||
|
: ""});
|
||||||
|
}
|
||||||
|
Rect bounding_box;
|
||||||
|
if (detection_proto.location_data().has_bounding_box()) {
|
||||||
|
mediapipe::LocationData::BoundingBox bounding_box_proto =
|
||||||
|
detection_proto.location_data().bounding_box();
|
||||||
|
bounding_box.left = bounding_box_proto.xmin();
|
||||||
|
bounding_box.top = bounding_box_proto.ymin();
|
||||||
|
bounding_box.right = bounding_box_proto.xmin() + bounding_box_proto.width();
|
||||||
|
bounding_box.bottom =
|
||||||
|
bounding_box_proto.ymin() + bounding_box_proto.height();
|
||||||
|
}
|
||||||
|
detection.bounding_box = bounding_box;
|
||||||
|
return detection;
|
||||||
|
}
|
||||||
|
|
||||||
|
DetectionResult ConvertToDetectionResult(
|
||||||
|
std::vector<mediapipe::Detection> detections_proto) {
|
||||||
|
DetectionResult detection_result;
|
||||||
|
detection_result.detections.reserve(detections_proto.size());
|
||||||
|
for (const auto& detection_proto : detections_proto) {
|
||||||
|
detection_result.detections.push_back(
|
||||||
|
ConvertToDetectionResult(detection_proto));
|
||||||
|
}
|
||||||
|
return detection_result;
|
||||||
|
}
|
||||||
|
} // namespace mediapipe::tasks::components::containers
|
52
mediapipe/tasks/cc/components/containers/detection_result.h
Normal file
52
mediapipe/tasks/cc/components/containers/detection_result.h
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/formats/detection.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::components::containers {
|
||||||
|
|
||||||
|
// Detection for a single bounding box.
|
||||||
|
struct Detection {
|
||||||
|
// A vector of detected categories.
|
||||||
|
std::vector<Category> categories;
|
||||||
|
// The bounding box location.
|
||||||
|
Rect bounding_box;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Detection results of a model.
|
||||||
|
struct DetectionResult {
|
||||||
|
// A vector of Detections.
|
||||||
|
std::vector<Detection> detections;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Utility function to convert from Detection proto to Detection struct.
|
||||||
|
Detection ConvertToDetection(const mediapipe::Detection& detection_proto);
|
||||||
|
|
||||||
|
// Utility function to convert from list of Detection proto to DetectionResult
|
||||||
|
// struct.
|
||||||
|
DetectionResult ConvertToDetectionResult(
|
||||||
|
std::vector<mediapipe::Detection> detections_proto);
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::components::containers
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
|
34
mediapipe/tasks/cc/components/containers/rect.cc
Normal file
34
mediapipe/tasks/cc/components/containers/rect.cc
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::components::containers {
|
||||||
|
|
||||||
|
RectF ToRectF(const Rect& rect, int image_height, int image_width) {
|
||||||
|
return RectF{static_cast<float>(rect.left) / image_width,
|
||||||
|
static_cast<float>(rect.top) / image_height,
|
||||||
|
static_cast<float>(rect.right) / image_width,
|
||||||
|
static_cast<float>(rect.bottom) / image_height};
|
||||||
|
}
|
||||||
|
|
||||||
|
Rect ToRect(const RectF& rect, int image_height, int image_width) {
|
||||||
|
return Rect{static_cast<int>(rect.left * image_width),
|
||||||
|
static_cast<int>(rect.top * image_height),
|
||||||
|
static_cast<int>(rect.right * image_width),
|
||||||
|
static_cast<int>(rect.bottom * image_height)};
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::components::containers
|
|
@ -16,20 +16,47 @@ limitations under the License.
|
||||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
|
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
|
||||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
|
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
namespace mediapipe::tasks::components::containers {
|
namespace mediapipe::tasks::components::containers {
|
||||||
|
|
||||||
|
constexpr float kRectFTolerance = 1e-4;
|
||||||
|
|
||||||
// Defines a rectangle, used e.g. as part of detection results or as input
|
// Defines a rectangle, used e.g. as part of detection results or as input
|
||||||
// region-of-interest.
|
// region-of-interest.
|
||||||
//
|
//
|
||||||
|
struct Rect {
|
||||||
|
int left;
|
||||||
|
int top;
|
||||||
|
int right;
|
||||||
|
int bottom;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline bool operator==(const Rect& lhs, const Rect& rhs) {
|
||||||
|
return lhs.left == rhs.left && lhs.top == rhs.top && lhs.right == rhs.right &&
|
||||||
|
lhs.bottom == rhs.bottom;
|
||||||
|
}
|
||||||
|
|
||||||
// The coordinates are normalized wrt the image dimensions, i.e. generally in
|
// The coordinates are normalized wrt the image dimensions, i.e. generally in
|
||||||
// [0,1] but they may exceed these bounds if describing a region overlapping the
|
// [0,1] but they may exceed these bounds if describing a region overlapping the
|
||||||
// image. The origin is on the top-left corner of the image.
|
// image. The origin is on the top-left corner of the image.
|
||||||
struct Rect {
|
struct RectF {
|
||||||
float left;
|
float left;
|
||||||
float top;
|
float top;
|
||||||
float right;
|
float right;
|
||||||
float bottom;
|
float bottom;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
inline bool operator==(const RectF& lhs, const RectF& rhs) {
|
||||||
|
return abs(lhs.left - rhs.left) < kRectFTolerance &&
|
||||||
|
abs(lhs.top - rhs.top) < kRectFTolerance &&
|
||||||
|
abs(lhs.right - rhs.right) < kRectFTolerance &&
|
||||||
|
abs(lhs.bottom - rhs.bottom) < kRectFTolerance;
|
||||||
|
}
|
||||||
|
|
||||||
|
RectF ToRectF(const Rect& rect, int image_height, int image_width);
|
||||||
|
|
||||||
|
Rect ToRect(const RectF& rect, int image_height, int image_width);
|
||||||
|
|
||||||
} // namespace mediapipe::tasks::components::containers
|
} // namespace mediapipe::tasks::components::containers
|
||||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
|
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
|
||||||
|
|
|
@ -129,13 +129,13 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi {
|
||||||
if (roi.left >= roi.right || roi.top >= roi.bottom) {
|
if (roi.left >= roi.right || roi.top >= roi.bottom) {
|
||||||
return CreateStatusWithPayload(
|
return CreateStatusWithPayload(
|
||||||
absl::StatusCode::kInvalidArgument,
|
absl::StatusCode::kInvalidArgument,
|
||||||
"Expected Rect with left < right and top < bottom.",
|
"Expected RectF with left < right and top < bottom.",
|
||||||
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
|
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
|
||||||
}
|
}
|
||||||
if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) {
|
if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) {
|
||||||
return CreateStatusWithPayload(
|
return CreateStatusWithPayload(
|
||||||
absl::StatusCode::kInvalidArgument,
|
absl::StatusCode::kInvalidArgument,
|
||||||
"Expected Rect values to be in [0,1].",
|
"Expected RectF values to be in [0,1].",
|
||||||
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
|
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
|
||||||
}
|
}
|
||||||
normalized_rect.set_x_center((roi.left + roi.right) / 2.0);
|
normalized_rect.set_x_center((roi.left + roi.right) / 2.0);
|
||||||
|
|
|
@ -35,7 +35,8 @@ struct ImageProcessingOptions {
|
||||||
// the full image is used.
|
// the full image is used.
|
||||||
//
|
//
|
||||||
// Coordinates must be in [0,1] with 'left' < 'right' and 'top' < bottom.
|
// Coordinates must be in [0,1] with 'left' < 'right' and 'top' < bottom.
|
||||||
std::optional<components::containers::Rect> region_of_interest = std::nullopt;
|
std::optional<components::containers::RectF> region_of_interest =
|
||||||
|
std::nullopt;
|
||||||
|
|
||||||
// The rotation to apply to the image (or cropped region-of-interest), in
|
// The rotation to apply to the image (or cropped region-of-interest), in
|
||||||
// degrees clockwise.
|
// degrees clockwise.
|
||||||
|
|
|
@ -44,7 +44,7 @@ namespace {
|
||||||
using ::mediapipe::api2::Input;
|
using ::mediapipe::api2::Input;
|
||||||
using ::mediapipe::api2::Output;
|
using ::mediapipe::api2::Output;
|
||||||
using ::mediapipe::api2::builder::Source;
|
using ::mediapipe::api2::builder::Source;
|
||||||
using ::mediapipe::tasks::components::containers::Rect;
|
using ::mediapipe::tasks::components::containers::RectF;
|
||||||
using ::mediapipe::tasks::vision::utils::CalculateIOU;
|
using ::mediapipe::tasks::vision::utils::CalculateIOU;
|
||||||
using ::mediapipe::tasks::vision::utils::DuplicatesFinder;
|
using ::mediapipe::tasks::vision::utils::DuplicatesFinder;
|
||||||
|
|
||||||
|
@ -126,7 +126,7 @@ absl::StatusOr<float> HandBaselineDistance(
|
||||||
return distance;
|
return distance;
|
||||||
}
|
}
|
||||||
|
|
||||||
Rect CalculateBound(const NormalizedLandmarkList& list) {
|
RectF CalculateBound(const NormalizedLandmarkList& list) {
|
||||||
constexpr float kMinInitialValue = std::numeric_limits<float>::max();
|
constexpr float kMinInitialValue = std::numeric_limits<float>::max();
|
||||||
constexpr float kMaxInitialValue = std::numeric_limits<float>::lowest();
|
constexpr float kMaxInitialValue = std::numeric_limits<float>::lowest();
|
||||||
|
|
||||||
|
@ -144,7 +144,7 @@ Rect CalculateBound(const NormalizedLandmarkList& list) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Populate normalized non rotated face bounding box
|
// Populate normalized non rotated face bounding box
|
||||||
return Rect{/*left=*/bounding_box_left,
|
return RectF{/*left=*/bounding_box_left,
|
||||||
/*top=*/bounding_box_top,
|
/*top=*/bounding_box_top,
|
||||||
/*right=*/bounding_box_right,
|
/*right=*/bounding_box_right,
|
||||||
/*bottom=*/bounding_box_bottom};
|
/*bottom=*/bounding_box_bottom};
|
||||||
|
@ -172,7 +172,7 @@ class HandDuplicatesFinder : public DuplicatesFinder {
|
||||||
const int num = multi_landmarks.size();
|
const int num = multi_landmarks.size();
|
||||||
std::vector<float> baseline_distances;
|
std::vector<float> baseline_distances;
|
||||||
baseline_distances.reserve(num);
|
baseline_distances.reserve(num);
|
||||||
std::vector<Rect> bounds;
|
std::vector<RectF> bounds;
|
||||||
bounds.reserve(num);
|
bounds.reserve(num);
|
||||||
for (const NormalizedLandmarkList& list : multi_landmarks) {
|
for (const NormalizedLandmarkList& list : multi_landmarks) {
|
||||||
ASSIGN_OR_RETURN(const float baseline_distance,
|
ASSIGN_OR_RETURN(const float baseline_distance,
|
||||||
|
|
|
@ -50,7 +50,7 @@ namespace {
|
||||||
|
|
||||||
using ::file::Defaults;
|
using ::file::Defaults;
|
||||||
using ::mediapipe::file::JoinPath;
|
using ::mediapipe::file::JoinPath;
|
||||||
using ::mediapipe::tasks::components::containers::Rect;
|
using ::mediapipe::tasks::components::containers::RectF;
|
||||||
using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
|
using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
|
||||||
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
||||||
using ::testing::EqualsProto;
|
using ::testing::EqualsProto;
|
||||||
|
@ -188,7 +188,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
|
||||||
options->running_mode = core::RunningMode::IMAGE;
|
options->running_mode = core::RunningMode::IMAGE;
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
|
||||||
HandLandmarker::Create(std::move(options)));
|
HandLandmarker::Create(std::move(options)));
|
||||||
Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
|
RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
|
||||||
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
||||||
|
|
||||||
auto results = hand_landmarker->Detect(image, image_processing_options);
|
auto results = hand_landmarker->Detect(image, image_processing_options);
|
||||||
|
|
|
@ -52,7 +52,7 @@ namespace {
|
||||||
using ::mediapipe::file::JoinPath;
|
using ::mediapipe::file::JoinPath;
|
||||||
using ::mediapipe::tasks::components::containers::Category;
|
using ::mediapipe::tasks::components::containers::Category;
|
||||||
using ::mediapipe::tasks::components::containers::Classifications;
|
using ::mediapipe::tasks::components::containers::Classifications;
|
||||||
using ::mediapipe::tasks::components::containers::Rect;
|
using ::mediapipe::tasks::components::containers::RectF;
|
||||||
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
using ::testing::Optional;
|
using ::testing::Optional;
|
||||||
|
@ -472,7 +472,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
|
||||||
ImageClassifier::Create(std::move(options)));
|
ImageClassifier::Create(std::move(options)));
|
||||||
// Region-of-interest around the soccer ball.
|
// Region-of-interest around the soccer ball.
|
||||||
Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
|
RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
|
||||||
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
|
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
|
||||||
|
@ -526,7 +526,8 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
|
||||||
ImageClassifier::Create(std::move(options)));
|
ImageClassifier::Create(std::move(options)));
|
||||||
// Region-of-interest around the chair, with 90° anti-clockwise rotation.
|
// Region-of-interest around the chair, with 90° anti-clockwise rotation.
|
||||||
Rect roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702, /*bottom=*/0.3049};
|
RectF roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702,
|
||||||
|
/*bottom=*/0.3049};
|
||||||
ImageProcessingOptions image_processing_options{roi,
|
ImageProcessingOptions image_processing_options{roi,
|
||||||
/*rotation_degrees=*/-90};
|
/*rotation_degrees=*/-90};
|
||||||
|
|
||||||
|
@ -554,13 +555,13 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) {
|
||||||
ImageClassifier::Create(std::move(options)));
|
ImageClassifier::Create(std::move(options)));
|
||||||
|
|
||||||
// Invalid: left > right.
|
// Invalid: left > right.
|
||||||
Rect roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1};
|
RectF roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1};
|
||||||
ImageProcessingOptions image_processing_options{roi,
|
ImageProcessingOptions image_processing_options{roi,
|
||||||
/*rotation_degrees=*/0};
|
/*rotation_degrees=*/0};
|
||||||
auto results = image_classifier->Classify(image, image_processing_options);
|
auto results = image_classifier->Classify(image, image_processing_options);
|
||||||
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
|
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
|
||||||
EXPECT_THAT(results.status().message(),
|
EXPECT_THAT(results.status().message(),
|
||||||
HasSubstr("Expected Rect with left < right and top < bottom"));
|
HasSubstr("Expected RectF with left < right and top < bottom"));
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
results.status().GetPayload(kMediaPipeTasksPayload),
|
results.status().GetPayload(kMediaPipeTasksPayload),
|
||||||
Optional(absl::Cord(absl::StrCat(
|
Optional(absl::Cord(absl::StrCat(
|
||||||
|
@ -573,7 +574,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) {
|
||||||
results = image_classifier->Classify(image, image_processing_options);
|
results = image_classifier->Classify(image, image_processing_options);
|
||||||
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
|
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
|
||||||
EXPECT_THAT(results.status().message(),
|
EXPECT_THAT(results.status().message(),
|
||||||
HasSubstr("Expected Rect with left < right and top < bottom"));
|
HasSubstr("Expected RectF with left < right and top < bottom"));
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
results.status().GetPayload(kMediaPipeTasksPayload),
|
results.status().GetPayload(kMediaPipeTasksPayload),
|
||||||
Optional(absl::Cord(absl::StrCat(
|
Optional(absl::Cord(absl::StrCat(
|
||||||
|
@ -586,7 +587,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) {
|
||||||
results = image_classifier->Classify(image, image_processing_options);
|
results = image_classifier->Classify(image, image_processing_options);
|
||||||
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
|
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
|
||||||
EXPECT_THAT(results.status().message(),
|
EXPECT_THAT(results.status().message(),
|
||||||
HasSubstr("Expected Rect values to be in [0,1]"));
|
HasSubstr("Expected RectF values to be in [0,1]"));
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
results.status().GetPayload(kMediaPipeTasksPayload),
|
results.status().GetPayload(kMediaPipeTasksPayload),
|
||||||
Optional(absl::Cord(absl::StrCat(
|
Optional(absl::Cord(absl::StrCat(
|
||||||
|
@ -695,7 +696,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) {
|
||||||
ImageClassifier::Create(std::move(options)));
|
ImageClassifier::Create(std::move(options)));
|
||||||
// Crop around the soccer ball.
|
// Crop around the soccer ball.
|
||||||
// Region-of-interest around the soccer ball.
|
// Region-of-interest around the soccer ball.
|
||||||
Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
|
RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
|
||||||
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
||||||
|
|
||||||
for (int i = 0; i < iterations; ++i) {
|
for (int i = 0; i < iterations; ++i) {
|
||||||
|
@ -837,7 +838,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
|
||||||
ImageClassifier::Create(std::move(options)));
|
ImageClassifier::Create(std::move(options)));
|
||||||
// Crop around the soccer ball.
|
// Crop around the soccer ball.
|
||||||
Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
|
RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
|
||||||
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
||||||
|
|
||||||
for (int i = 0; i < iterations; ++i) {
|
for (int i = 0; i < iterations; ++i) {
|
||||||
|
|
|
@ -41,7 +41,7 @@ namespace image_embedder {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::mediapipe::file::JoinPath;
|
using ::mediapipe::file::JoinPath;
|
||||||
using ::mediapipe::tasks::components::containers::Rect;
|
using ::mediapipe::tasks::components::containers::RectF;
|
||||||
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
using ::testing::Optional;
|
using ::testing::Optional;
|
||||||
|
@ -320,7 +320,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
|
||||||
Image crop, DecodeImageFromFile(
|
Image crop, DecodeImageFromFile(
|
||||||
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
|
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
|
||||||
// Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
|
// Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
|
||||||
Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1};
|
RectF roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1};
|
||||||
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
||||||
|
|
||||||
// Extract both embeddings.
|
// Extract both embeddings.
|
||||||
|
@ -388,7 +388,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
|
||||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
|
DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
|
||||||
"burger_rotated.jpg")));
|
"burger_rotated.jpg")));
|
||||||
// Region-of-interest corresponding to burger_crop.jpg.
|
// Region-of-interest corresponding to burger_crop.jpg.
|
||||||
Rect roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333};
|
RectF roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333};
|
||||||
ImageProcessingOptions image_processing_options{roi,
|
ImageProcessingOptions image_processing_options{roi,
|
||||||
/*rotation_degrees=*/-90};
|
/*rotation_degrees=*/-90};
|
||||||
|
|
||||||
|
|
|
@ -47,7 +47,7 @@ namespace {
|
||||||
|
|
||||||
using ::mediapipe::Image;
|
using ::mediapipe::Image;
|
||||||
using ::mediapipe::file::JoinPath;
|
using ::mediapipe::file::JoinPath;
|
||||||
using ::mediapipe::tasks::components::containers::Rect;
|
using ::mediapipe::tasks::components::containers::RectF;
|
||||||
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
using ::testing::Optional;
|
using ::testing::Optional;
|
||||||
|
@ -299,7 +299,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||||
ImageSegmenter::Create(std::move(options)));
|
ImageSegmenter::Create(std::move(options)));
|
||||||
Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
|
RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
|
||||||
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
||||||
|
|
||||||
auto results = segmenter->Segment(image, image_processing_options);
|
auto results = segmenter->Segment(image, image_processing_options);
|
||||||
|
|
|
@ -33,6 +33,7 @@ cc_library(
|
||||||
"//mediapipe/framework/formats:rect_cc_proto",
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
|
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
|
||||||
|
"//mediapipe/tasks/cc/components/containers:detection_result",
|
||||||
"//mediapipe/tasks/cc/core:base_options",
|
"//mediapipe/tasks/cc/core:base_options",
|
||||||
"//mediapipe/tasks/cc/core:utils",
|
"//mediapipe/tasks/cc/core:utils",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
|
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
|
||||||
|
|
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
|
||||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||||
|
@ -56,6 +57,7 @@ constexpr char kSubgraphTypeName[] =
|
||||||
"mediapipe.tasks.vision.ObjectDetectorGraph";
|
"mediapipe.tasks.vision.ObjectDetectorGraph";
|
||||||
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||||
|
|
||||||
|
using ::mediapipe::tasks::components::containers::ConvertToDetectionResult;
|
||||||
using ObjectDetectorOptionsProto =
|
using ObjectDetectorOptionsProto =
|
||||||
object_detector::proto::ObjectDetectorOptions;
|
object_detector::proto::ObjectDetectorOptions;
|
||||||
|
|
||||||
|
@ -129,7 +131,8 @@ absl::StatusOr<std::unique_ptr<ObjectDetector>> ObjectDetector::Create(
|
||||||
Packet detections_packet =
|
Packet detections_packet =
|
||||||
status_or_packets.value()[kDetectionsOutStreamName];
|
status_or_packets.value()[kDetectionsOutStreamName];
|
||||||
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
|
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
|
||||||
result_callback(detections_packet.Get<std::vector<Detection>>(),
|
result_callback(ConvertToDetectionResult(
|
||||||
|
detections_packet.Get<std::vector<Detection>>()),
|
||||||
image_packet.Get<Image>(),
|
image_packet.Get<Image>(),
|
||||||
detections_packet.Timestamp().Value() /
|
detections_packet.Timestamp().Value() /
|
||||||
kMicroSecondsPerMilliSecond);
|
kMicroSecondsPerMilliSecond);
|
||||||
|
@ -144,7 +147,7 @@ absl::StatusOr<std::unique_ptr<ObjectDetector>> ObjectDetector::Create(
|
||||||
std::move(packets_callback));
|
std::move(packets_callback));
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<std::vector<Detection>> ObjectDetector::Detect(
|
absl::StatusOr<ObjectDetectorResult> ObjectDetector::Detect(
|
||||||
mediapipe::Image image,
|
mediapipe::Image image,
|
||||||
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||||
if (image.UsesGpu()) {
|
if (image.UsesGpu()) {
|
||||||
|
@ -161,10 +164,11 @@ absl::StatusOr<std::vector<Detection>> ObjectDetector::Detect(
|
||||||
ProcessImageData(
|
ProcessImageData(
|
||||||
{{kImageInStreamName, MakePacket<Image>(std::move(image))},
|
{{kImageInStreamName, MakePacket<Image>(std::move(image))},
|
||||||
{kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}}));
|
{kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}}));
|
||||||
return output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>();
|
return ConvertToDetectionResult(
|
||||||
|
output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>());
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<std::vector<Detection>> ObjectDetector::DetectForVideo(
|
absl::StatusOr<ObjectDetectorResult> ObjectDetector::DetectForVideo(
|
||||||
mediapipe::Image image, int64 timestamp_ms,
|
mediapipe::Image image, int64 timestamp_ms,
|
||||||
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||||
if (image.UsesGpu()) {
|
if (image.UsesGpu()) {
|
||||||
|
@ -185,7 +189,8 @@ absl::StatusOr<std::vector<Detection>> ObjectDetector::DetectForVideo(
|
||||||
{kNormRectName,
|
{kNormRectName,
|
||||||
MakePacket<NormalizedRect>(std::move(norm_rect))
|
MakePacket<NormalizedRect>(std::move(norm_rect))
|
||||||
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
|
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
|
||||||
return output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>();
|
return ConvertToDetectionResult(
|
||||||
|
output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>());
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status ObjectDetector::DetectAsync(
|
absl::Status ObjectDetector::DetectAsync(
|
||||||
|
|
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "mediapipe/framework/formats/detection.pb.h"
|
#include "mediapipe/framework/formats/detection.pb.h"
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
|
||||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||||
|
@ -36,6 +37,10 @@ namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
|
|
||||||
|
// Alias the shared DetectionResult struct as result typo.
|
||||||
|
using ObjectDetectorResult =
|
||||||
|
::mediapipe::tasks::components::containers::DetectionResult;
|
||||||
|
|
||||||
// The options for configuring a mediapipe object detector task.
|
// The options for configuring a mediapipe object detector task.
|
||||||
struct ObjectDetectorOptions {
|
struct ObjectDetectorOptions {
|
||||||
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
||||||
|
@ -79,8 +84,7 @@ struct ObjectDetectorOptions {
|
||||||
// The user-defined result callback for processing live stream data.
|
// The user-defined result callback for processing live stream data.
|
||||||
// The result callback should only be specified when the running mode is set
|
// The result callback should only be specified when the running mode is set
|
||||||
// to RunningMode::LIVE_STREAM.
|
// to RunningMode::LIVE_STREAM.
|
||||||
std::function<void(absl::StatusOr<std::vector<mediapipe::Detection>>,
|
std::function<void(absl::StatusOr<ObjectDetectorResult>, const Image&, int64)>
|
||||||
const Image&, int64)>
|
|
||||||
result_callback = nullptr;
|
result_callback = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -165,7 +169,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
|
||||||
// underlying image data.
|
// underlying image data.
|
||||||
// TODO: Describes the output bounding boxes for gpu input
|
// TODO: Describes the output bounding boxes for gpu input
|
||||||
// images after enabling the gpu support in MediaPipe Tasks.
|
// images after enabling the gpu support in MediaPipe Tasks.
|
||||||
absl::StatusOr<std::vector<mediapipe::Detection>> Detect(
|
absl::StatusOr<ObjectDetectorResult> Detect(
|
||||||
mediapipe::Image image,
|
mediapipe::Image image,
|
||||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||||
std::nullopt);
|
std::nullopt);
|
||||||
|
@ -188,7 +192,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
|
||||||
// unrotated input frame of reference coordinates system, i.e. in `[0,
|
// unrotated input frame of reference coordinates system, i.e. in `[0,
|
||||||
// image_width) x [0, image_height)`, which are the dimensions of the
|
// image_width) x [0, image_height)`, which are the dimensions of the
|
||||||
// underlying image data.
|
// underlying image data.
|
||||||
absl::StatusOr<std::vector<mediapipe::Detection>> DetectForVideo(
|
absl::StatusOr<ObjectDetectorResult> DetectForVideo(
|
||||||
mediapipe::Image image, int64 timestamp_ms,
|
mediapipe::Image image, int64 timestamp_ms,
|
||||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||||
std::nullopt);
|
std::nullopt);
|
||||||
|
|
|
@ -35,6 +35,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||||
|
@ -65,10 +66,14 @@ namespace vision {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::mediapipe::file::JoinPath;
|
using ::mediapipe::file::JoinPath;
|
||||||
using ::mediapipe::tasks::components::containers::Rect;
|
using ::mediapipe::tasks::components::containers::ConvertToDetectionResult;
|
||||||
|
using ::mediapipe::tasks::components::containers::Detection;
|
||||||
|
using ::mediapipe::tasks::components::containers::DetectionResult;
|
||||||
|
using ::mediapipe::tasks::components::containers::RectF;
|
||||||
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
using ::testing::Optional;
|
using ::testing::Optional;
|
||||||
|
using DetectionProto = mediapipe::Detection;
|
||||||
|
|
||||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
|
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
|
||||||
constexpr char kMobileSsdWithMetadata[] =
|
constexpr char kMobileSsdWithMetadata[] =
|
||||||
|
@ -83,47 +88,45 @@ constexpr char kEfficientDetWithMetadata[] =
|
||||||
// Checks that the two provided `Detection` proto vectors are equal, with a
|
// Checks that the two provided `Detection` proto vectors are equal, with a
|
||||||
// tolerancy on floating-point scores to account for numerical instabilities.
|
// tolerancy on floating-point scores to account for numerical instabilities.
|
||||||
// If the proto definition changes, please also change this function.
|
// If the proto definition changes, please also change this function.
|
||||||
void ExpectApproximatelyEqual(const std::vector<Detection>& actual,
|
void ExpectApproximatelyEqual(const ObjectDetectorResult& actual,
|
||||||
const std::vector<Detection>& expected) {
|
const ObjectDetectorResult& expected) {
|
||||||
const float kPrecision = 1e-6;
|
const float kPrecision = 1e-6;
|
||||||
EXPECT_EQ(actual.size(), expected.size());
|
EXPECT_EQ(actual.detections.size(), expected.detections.size());
|
||||||
for (int i = 0; i < actual.size(); ++i) {
|
for (int i = 0; i < actual.detections.size(); ++i) {
|
||||||
const Detection& a = actual[i];
|
const Detection& a = actual.detections[i];
|
||||||
const Detection& b = expected[i];
|
const Detection& b = expected.detections[i];
|
||||||
EXPECT_THAT(a.location_data().bounding_box(),
|
EXPECT_EQ(a.bounding_box, b.bounding_box);
|
||||||
EqualsProto(b.location_data().bounding_box()));
|
EXPECT_EQ(a.categories.size(), 1);
|
||||||
EXPECT_EQ(a.label_size(), 1);
|
EXPECT_EQ(b.categories.size(), 1);
|
||||||
EXPECT_EQ(b.label_size(), 1);
|
EXPECT_EQ(a.categories[0].category_name, b.categories[0].category_name);
|
||||||
EXPECT_EQ(a.label(0), b.label(0));
|
EXPECT_NEAR(a.categories[0].score, b.categories[0].score, kPrecision);
|
||||||
EXPECT_EQ(a.score_size(), 1);
|
|
||||||
EXPECT_EQ(b.score_size(), 1);
|
|
||||||
EXPECT_NEAR(a.score(0), b.score(0), kPrecision);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Detection> GenerateMobileSsdNoImageResizingFullExpectedResults() {
|
std::vector<DetectionProto>
|
||||||
return {ParseTextProtoOrDie<Detection>(R"pb(
|
GenerateMobileSsdNoImageResizingFullExpectedResults() {
|
||||||
|
return {ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||||
label: "cat"
|
label: "cat"
|
||||||
score: 0.6328125
|
score: 0.6328125
|
||||||
location_data {
|
location_data {
|
||||||
format: BOUNDING_BOX
|
format: BOUNDING_BOX
|
||||||
bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 }
|
bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 }
|
||||||
})pb"),
|
})pb"),
|
||||||
ParseTextProtoOrDie<Detection>(R"pb(
|
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||||
label: "cat"
|
label: "cat"
|
||||||
score: 0.59765625
|
score: 0.59765625
|
||||||
location_data {
|
location_data {
|
||||||
format: BOUNDING_BOX
|
format: BOUNDING_BOX
|
||||||
bounding_box { xmin: 151 ymin: 78 width: 104 height: 223 }
|
bounding_box { xmin: 151 ymin: 78 width: 104 height: 223 }
|
||||||
})pb"),
|
})pb"),
|
||||||
ParseTextProtoOrDie<Detection>(R"pb(
|
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||||
label: "cat"
|
label: "cat"
|
||||||
score: 0.5
|
score: 0.5
|
||||||
location_data {
|
location_data {
|
||||||
format: BOUNDING_BOX
|
format: BOUNDING_BOX
|
||||||
bounding_box { xmin: 65 ymin: 199 width: 41 height: 101 }
|
bounding_box { xmin: 65 ymin: 199 width: 41 height: 101 }
|
||||||
})pb"),
|
})pb"),
|
||||||
ParseTextProtoOrDie<Detection>(R"pb(
|
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||||
label: "dog"
|
label: "dog"
|
||||||
score: 0.48828125
|
score: 0.48828125
|
||||||
location_data {
|
location_data {
|
||||||
|
@ -263,8 +266,8 @@ TEST_F(CreateFromOptionsTest, FailsWithIllegalCallbackInImageOrVideoMode) {
|
||||||
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
|
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
|
||||||
options->running_mode = running_mode;
|
options->running_mode = running_mode;
|
||||||
options->result_callback =
|
options->result_callback =
|
||||||
[](absl::StatusOr<std::vector<Detection>> detections,
|
[](absl::StatusOr<ObjectDetectorResult> detections, const Image& image,
|
||||||
const Image& image, int64 timestamp_ms) {};
|
int64 timestamp_ms) {};
|
||||||
absl::StatusOr<std::unique_ptr<ObjectDetector>> object_detector =
|
absl::StatusOr<std::unique_ptr<ObjectDetector>> object_detector =
|
||||||
ObjectDetector::Create(std::move(options));
|
ObjectDetector::Create(std::move(options));
|
||||||
EXPECT_EQ(object_detector.status().code(),
|
EXPECT_EQ(object_detector.status().code(),
|
||||||
|
@ -340,34 +343,36 @@ TEST_F(ImageModeTest, Succeeds) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
||||||
MP_ASSERT_OK(object_detector->Close());
|
MP_ASSERT_OK(object_detector->Close());
|
||||||
ExpectApproximatelyEqual(
|
ExpectApproximatelyEqual(
|
||||||
results, {ParseTextProtoOrDie<Detection>(R"pb(
|
results,
|
||||||
|
ConvertToDetectionResult(
|
||||||
|
{ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||||
label: "cat"
|
label: "cat"
|
||||||
score: 0.69921875
|
score: 0.69921875
|
||||||
location_data {
|
location_data {
|
||||||
format: BOUNDING_BOX
|
format: BOUNDING_BOX
|
||||||
bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 }
|
bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 }
|
||||||
})pb"),
|
})pb"),
|
||||||
ParseTextProtoOrDie<Detection>(R"pb(
|
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||||
label: "cat"
|
label: "cat"
|
||||||
score: 0.64453125
|
score: 0.64453125
|
||||||
location_data {
|
location_data {
|
||||||
format: BOUNDING_BOX
|
format: BOUNDING_BOX
|
||||||
bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 }
|
bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 }
|
||||||
})pb"),
|
})pb"),
|
||||||
ParseTextProtoOrDie<Detection>(R"pb(
|
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||||
label: "cat"
|
label: "cat"
|
||||||
score: 0.51171875
|
score: 0.51171875
|
||||||
location_data {
|
location_data {
|
||||||
format: BOUNDING_BOX
|
format: BOUNDING_BOX
|
||||||
bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 }
|
bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 }
|
||||||
})pb"),
|
})pb"),
|
||||||
ParseTextProtoOrDie<Detection>(R"pb(
|
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||||
label: "cat"
|
label: "cat"
|
||||||
score: 0.48828125
|
score: 0.48828125
|
||||||
location_data {
|
location_data {
|
||||||
format: BOUNDING_BOX
|
format: BOUNDING_BOX
|
||||||
bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 }
|
bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 }
|
||||||
})pb")});
|
})pb")}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsEfficientDetModel) {
|
TEST_F(ImageModeTest, SucceedsEfficientDetModel) {
|
||||||
|
@ -383,34 +388,36 @@ TEST_F(ImageModeTest, SucceedsEfficientDetModel) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
||||||
MP_ASSERT_OK(object_detector->Close());
|
MP_ASSERT_OK(object_detector->Close());
|
||||||
ExpectApproximatelyEqual(
|
ExpectApproximatelyEqual(
|
||||||
results, {ParseTextProtoOrDie<Detection>(R"pb(
|
results,
|
||||||
|
ConvertToDetectionResult(
|
||||||
|
{ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||||
label: "cat"
|
label: "cat"
|
||||||
score: 0.7578125
|
score: 0.7578125
|
||||||
location_data {
|
location_data {
|
||||||
format: BOUNDING_BOX
|
format: BOUNDING_BOX
|
||||||
bounding_box { xmin: 858 ymin: 408 width: 225 height: 187 }
|
bounding_box { xmin: 858 ymin: 408 width: 225 height: 187 }
|
||||||
})pb"),
|
})pb"),
|
||||||
ParseTextProtoOrDie<Detection>(R"pb(
|
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||||
label: "cat"
|
label: "cat"
|
||||||
score: 0.72265625
|
score: 0.72265625
|
||||||
location_data {
|
location_data {
|
||||||
format: BOUNDING_BOX
|
format: BOUNDING_BOX
|
||||||
bounding_box { xmin: 67 ymin: 401 width: 399 height: 192 }
|
bounding_box { xmin: 67 ymin: 401 width: 399 height: 192 }
|
||||||
})pb"),
|
})pb"),
|
||||||
ParseTextProtoOrDie<Detection>(R"pb(
|
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||||
label: "cat"
|
label: "cat"
|
||||||
score: 0.6289063
|
score: 0.6289063
|
||||||
location_data {
|
location_data {
|
||||||
format: BOUNDING_BOX
|
format: BOUNDING_BOX
|
||||||
bounding_box { xmin: 368 ymin: 210 width: 272 height: 385 }
|
bounding_box { xmin: 368 ymin: 210 width: 272 height: 385 }
|
||||||
})pb"),
|
})pb"),
|
||||||
ParseTextProtoOrDie<Detection>(R"pb(
|
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||||
label: "cat"
|
label: "cat"
|
||||||
score: 0.5859375
|
score: 0.5859375
|
||||||
location_data {
|
location_data {
|
||||||
format: BOUNDING_BOX
|
format: BOUNDING_BOX
|
||||||
bounding_box { xmin: 601 ymin: 166 width: 298 height: 437 }
|
bounding_box { xmin: 601 ymin: 166 width: 298 height: 437 }
|
||||||
})pb")});
|
})pb")}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsWithoutImageResizing) {
|
TEST_F(ImageModeTest, SucceedsWithoutImageResizing) {
|
||||||
|
@ -426,7 +433,8 @@ TEST_F(ImageModeTest, SucceedsWithoutImageResizing) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
||||||
MP_ASSERT_OK(object_detector->Close());
|
MP_ASSERT_OK(object_detector->Close());
|
||||||
ExpectApproximatelyEqual(
|
ExpectApproximatelyEqual(
|
||||||
results, GenerateMobileSsdNoImageResizingFullExpectedResults());
|
results, ConvertToDetectionResult(
|
||||||
|
GenerateMobileSsdNoImageResizingFullExpectedResults()));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
|
TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
|
||||||
|
@ -442,13 +450,14 @@ TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
||||||
MP_ASSERT_OK(object_detector->Close());
|
MP_ASSERT_OK(object_detector->Close());
|
||||||
ExpectApproximatelyEqual(
|
ExpectApproximatelyEqual(
|
||||||
results, {ParseTextProtoOrDie<Detection>(R"pb(
|
results,
|
||||||
|
ConvertToDetectionResult({ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||||
label: "cat"
|
label: "cat"
|
||||||
score: 0.6531269142
|
score: 0.6531269142
|
||||||
location_data {
|
location_data {
|
||||||
format: BOUNDING_BOX
|
format: BOUNDING_BOX
|
||||||
bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 }
|
bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 }
|
||||||
})pb")});
|
})pb")}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
|
TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
|
||||||
|
@ -463,11 +472,13 @@ TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
|
||||||
ObjectDetector::Create(std::move(options)));
|
ObjectDetector::Create(std::move(options)));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
||||||
MP_ASSERT_OK(object_detector->Close());
|
MP_ASSERT_OK(object_detector->Close());
|
||||||
std::vector<Detection> full_expected_results =
|
std::vector<DetectionProto> full_expected_results =
|
||||||
GenerateMobileSsdNoImageResizingFullExpectedResults();
|
GenerateMobileSsdNoImageResizingFullExpectedResults();
|
||||||
ExpectApproximatelyEqual(results,
|
|
||||||
{full_expected_results[0], full_expected_results[1],
|
ExpectApproximatelyEqual(
|
||||||
full_expected_results[2]});
|
results, ConvertToDetectionResult({full_expected_results[0],
|
||||||
|
full_expected_results[1],
|
||||||
|
full_expected_results[2]}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
|
TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
|
||||||
|
@ -482,10 +493,11 @@ TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
|
||||||
ObjectDetector::Create(std::move(options)));
|
ObjectDetector::Create(std::move(options)));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
||||||
MP_ASSERT_OK(object_detector->Close());
|
MP_ASSERT_OK(object_detector->Close());
|
||||||
std::vector<Detection> full_expected_results =
|
std::vector<DetectionProto> full_expected_results =
|
||||||
GenerateMobileSsdNoImageResizingFullExpectedResults();
|
GenerateMobileSsdNoImageResizingFullExpectedResults();
|
||||||
ExpectApproximatelyEqual(
|
ExpectApproximatelyEqual(
|
||||||
results, {full_expected_results[0], full_expected_results[1]});
|
results, ConvertToDetectionResult(
|
||||||
|
{full_expected_results[0], full_expected_results[1]}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {
|
TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {
|
||||||
|
@ -501,9 +513,10 @@ TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {
|
||||||
ObjectDetector::Create(std::move(options)));
|
ObjectDetector::Create(std::move(options)));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
||||||
MP_ASSERT_OK(object_detector->Close());
|
MP_ASSERT_OK(object_detector->Close());
|
||||||
std::vector<Detection> full_expected_results =
|
std::vector<DetectionProto> full_expected_results =
|
||||||
GenerateMobileSsdNoImageResizingFullExpectedResults();
|
GenerateMobileSsdNoImageResizingFullExpectedResults();
|
||||||
ExpectApproximatelyEqual(results, {full_expected_results[3]});
|
ExpectApproximatelyEqual(
|
||||||
|
results, ConvertToDetectionResult({full_expected_results[3]}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
|
TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
|
||||||
|
@ -519,9 +532,10 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
|
||||||
ObjectDetector::Create(std::move(options)));
|
ObjectDetector::Create(std::move(options)));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
||||||
MP_ASSERT_OK(object_detector->Close());
|
MP_ASSERT_OK(object_detector->Close());
|
||||||
std::vector<Detection> full_expected_results =
|
std::vector<DetectionProto> full_expected_results =
|
||||||
GenerateMobileSsdNoImageResizingFullExpectedResults();
|
GenerateMobileSsdNoImageResizingFullExpectedResults();
|
||||||
ExpectApproximatelyEqual(results, {full_expected_results[3]});
|
ExpectApproximatelyEqual(
|
||||||
|
results, ConvertToDetectionResult({full_expected_results[3]}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsWithRotation) {
|
TEST_F(ImageModeTest, SucceedsWithRotation) {
|
||||||
|
@ -541,13 +555,14 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
|
||||||
auto results, object_detector->Detect(image, image_processing_options));
|
auto results, object_detector->Detect(image, image_processing_options));
|
||||||
MP_ASSERT_OK(object_detector->Close());
|
MP_ASSERT_OK(object_detector->Close());
|
||||||
ExpectApproximatelyEqual(
|
ExpectApproximatelyEqual(
|
||||||
results, {ParseTextProtoOrDie<Detection>(R"pb(
|
results,
|
||||||
|
ConvertToDetectionResult({ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||||
label: "cat"
|
label: "cat"
|
||||||
score: 0.7109375
|
score: 0.7109375
|
||||||
location_data {
|
location_data {
|
||||||
format: BOUNDING_BOX
|
format: BOUNDING_BOX
|
||||||
bounding_box { xmin: 0 ymin: 622 width: 436 height: 276 }
|
bounding_box { xmin: 0 ymin: 622 width: 436 height: 276 }
|
||||||
})pb")});
|
})pb")}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
|
TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
|
||||||
|
@ -560,7 +575,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
|
||||||
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
|
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
|
||||||
ObjectDetector::Create(std::move(options)));
|
ObjectDetector::Create(std::move(options)));
|
||||||
Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
|
RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
|
||||||
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
||||||
|
|
||||||
auto results = object_detector->Detect(image, image_processing_options);
|
auto results = object_detector->Detect(image, image_processing_options);
|
||||||
|
@ -619,10 +634,11 @@ TEST_F(VideoModeTest, Succeeds) {
|
||||||
for (int i = 0; i < iterations; ++i) {
|
for (int i = 0; i < iterations; ++i) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results,
|
MP_ASSERT_OK_AND_ASSIGN(auto results,
|
||||||
object_detector->DetectForVideo(image, i));
|
object_detector->DetectForVideo(image, i));
|
||||||
std::vector<Detection> full_expected_results =
|
std::vector<DetectionProto> full_expected_results =
|
||||||
GenerateMobileSsdNoImageResizingFullExpectedResults();
|
GenerateMobileSsdNoImageResizingFullExpectedResults();
|
||||||
ExpectApproximatelyEqual(
|
ExpectApproximatelyEqual(
|
||||||
results, {full_expected_results[0], full_expected_results[1]});
|
results, ConvertToDetectionResult(
|
||||||
|
{full_expected_results[0], full_expected_results[1]}));
|
||||||
}
|
}
|
||||||
MP_ASSERT_OK(object_detector->Close());
|
MP_ASSERT_OK(object_detector->Close());
|
||||||
}
|
}
|
||||||
|
@ -637,9 +653,8 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
||||||
options->base_options.model_asset_path =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
|
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
|
||||||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||||
options->result_callback =
|
options->result_callback = [](absl::StatusOr<ObjectDetectorResult> detections,
|
||||||
[](absl::StatusOr<std::vector<Detection>> detections, const Image& image,
|
const Image& image, int64 timestamp_ms) {};
|
||||||
int64 timestamp_ms) {};
|
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
|
||||||
ObjectDetector::Create(std::move(options)));
|
ObjectDetector::Create(std::move(options)));
|
||||||
|
@ -669,9 +684,8 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
|
||||||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||||
options->base_options.model_asset_path =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
|
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
|
||||||
options->result_callback =
|
options->result_callback = [](absl::StatusOr<ObjectDetectorResult> detections,
|
||||||
[](absl::StatusOr<std::vector<Detection>> detections, const Image& image,
|
const Image& image, int64 timestamp_ms) {};
|
||||||
int64 timestamp_ms) {};
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
|
||||||
ObjectDetector::Create(std::move(options)));
|
ObjectDetector::Create(std::move(options)));
|
||||||
MP_ASSERT_OK(object_detector->DetectAsync(image, 1));
|
MP_ASSERT_OK(object_detector->DetectAsync(image, 1));
|
||||||
|
@ -695,14 +709,14 @@ TEST_F(LiveStreamModeTest, Succeeds) {
|
||||||
auto options = std::make_unique<ObjectDetectorOptions>();
|
auto options = std::make_unique<ObjectDetectorOptions>();
|
||||||
options->max_results = 2;
|
options->max_results = 2;
|
||||||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||||
std::vector<std::vector<Detection>> detection_results;
|
std::vector<ObjectDetectorResult> detection_results;
|
||||||
std::vector<std::pair<int, int>> image_sizes;
|
std::vector<std::pair<int, int>> image_sizes;
|
||||||
std::vector<int64> timestamps;
|
std::vector<int64> timestamps;
|
||||||
options->base_options.model_asset_path =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
|
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
|
||||||
options->result_callback =
|
options->result_callback =
|
||||||
[&detection_results, &image_sizes, ×tamps](
|
[&detection_results, &image_sizes, ×tamps](
|
||||||
absl::StatusOr<std::vector<Detection>> detections, const Image& image,
|
absl::StatusOr<ObjectDetectorResult> detections, const Image& image,
|
||||||
int64 timestamp_ms) {
|
int64 timestamp_ms) {
|
||||||
MP_ASSERT_OK(detections.status());
|
MP_ASSERT_OK(detections.status());
|
||||||
detection_results.push_back(std::move(detections).value());
|
detection_results.push_back(std::move(detections).value());
|
||||||
|
@ -719,11 +733,12 @@ TEST_F(LiveStreamModeTest, Succeeds) {
|
||||||
// number of iterations.
|
// number of iterations.
|
||||||
ASSERT_LE(detection_results.size(), iterations);
|
ASSERT_LE(detection_results.size(), iterations);
|
||||||
ASSERT_GT(detection_results.size(), 0);
|
ASSERT_GT(detection_results.size(), 0);
|
||||||
std::vector<Detection> full_expected_results =
|
std::vector<DetectionProto> full_expected_results =
|
||||||
GenerateMobileSsdNoImageResizingFullExpectedResults();
|
GenerateMobileSsdNoImageResizingFullExpectedResults();
|
||||||
for (const auto& detection_result : detection_results) {
|
for (const auto& detection_result : detection_results) {
|
||||||
ExpectApproximatelyEqual(
|
ExpectApproximatelyEqual(
|
||||||
detection_result, {full_expected_results[0], full_expected_results[1]});
|
detection_result, ConvertToDetectionResult({full_expected_results[0],
|
||||||
|
full_expected_results[1]}));
|
||||||
}
|
}
|
||||||
for (const auto& image_size : image_sizes) {
|
for (const auto& image_size : image_sizes) {
|
||||||
EXPECT_EQ(image_size.first, image.width());
|
EXPECT_EQ(image_size.first, image.width());
|
||||||
|
|
|
@ -22,13 +22,13 @@ limitations under the License.
|
||||||
|
|
||||||
namespace mediapipe::tasks::vision::utils {
|
namespace mediapipe::tasks::vision::utils {
|
||||||
|
|
||||||
using ::mediapipe::tasks::components::containers::Rect;
|
using ::mediapipe::tasks::components::containers::RectF;
|
||||||
|
|
||||||
float CalculateArea(const Rect& rect) {
|
float CalculateArea(const RectF& rect) {
|
||||||
return (rect.right - rect.left) * (rect.bottom - rect.top);
|
return (rect.right - rect.left) * (rect.bottom - rect.top);
|
||||||
}
|
}
|
||||||
|
|
||||||
float CalculateIntersectionArea(const Rect& a, const Rect& b) {
|
float CalculateIntersectionArea(const RectF& a, const RectF& b) {
|
||||||
const float intersection_left = std::max<float>(a.left, b.left);
|
const float intersection_left = std::max<float>(a.left, b.left);
|
||||||
const float intersection_top = std::max<float>(a.top, b.top);
|
const float intersection_top = std::max<float>(a.top, b.top);
|
||||||
const float intersection_right = std::min<float>(a.right, b.right);
|
const float intersection_right = std::min<float>(a.right, b.right);
|
||||||
|
@ -38,7 +38,7 @@ float CalculateIntersectionArea(const Rect& a, const Rect& b) {
|
||||||
std::max<float>(intersection_right - intersection_left, 0.0);
|
std::max<float>(intersection_right - intersection_left, 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
float CalculateIOU(const Rect& a, const Rect& b) {
|
float CalculateIOU(const RectF& a, const RectF& b) {
|
||||||
const float area_a = CalculateArea(a);
|
const float area_a = CalculateArea(a);
|
||||||
const float area_b = CalculateArea(b);
|
const float area_b = CalculateArea(b);
|
||||||
if (area_a <= 0 || area_b <= 0) return 0.0;
|
if (area_a <= 0 || area_b <= 0) return 0.0;
|
||||||
|
|
|
@ -27,15 +27,15 @@ limitations under the License.
|
||||||
namespace mediapipe::tasks::vision::utils {
|
namespace mediapipe::tasks::vision::utils {
|
||||||
|
|
||||||
// Calculates intersection over union for two bounds.
|
// Calculates intersection over union for two bounds.
|
||||||
float CalculateIOU(const components::containers::Rect& a,
|
float CalculateIOU(const components::containers::RectF& a,
|
||||||
const components::containers::Rect& b);
|
const components::containers::RectF& b);
|
||||||
|
|
||||||
// Calculates area for face bound
|
// Calculates area for face bound
|
||||||
float CalculateArea(const components::containers::Rect& rect);
|
float CalculateArea(const components::containers::RectF& rect);
|
||||||
|
|
||||||
// Calucates intersection area of two face bounds
|
// Calucates intersection area of two face bounds
|
||||||
float CalculateIntersectionArea(const components::containers::Rect& a,
|
float CalculateIntersectionArea(const components::containers::RectF& a,
|
||||||
const components::containers::Rect& b);
|
const components::containers::RectF& b);
|
||||||
} // namespace mediapipe::tasks::vision::utils
|
} // namespace mediapipe::tasks::vision::utils
|
||||||
|
|
||||||
#endif // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_
|
#endif // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_
|
||||||
|
|
Loading…
Reference in New Issue
Block a user