Add support for rotation in ObjectDetector C++ API

PiperOrigin-RevId: 481167472
Parent: 0ebe6ccf59
Commit: 6f3e8381ed
mediapipe/tasks/cc/vision/object_detector/BUILD

@@ -33,6 +33,7 @@ cc_library(
         "//mediapipe/framework/api2:port",
         "//mediapipe/framework/formats:detection_cc_proto",
         "//mediapipe/framework/formats:image",
+        "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/framework/formats:tensor",
         "//mediapipe/tasks/cc:common",
         "//mediapipe/tasks/cc/components:image_preprocessing",
@@ -66,6 +67,7 @@ cc_library(
         "//mediapipe/framework/api2:builder",
         "//mediapipe/framework/formats:detection_cc_proto",
         "//mediapipe/framework/formats:image",
+        "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/tasks/cc:common",
         "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
         "//mediapipe/tasks/cc/core:base_options",

mediapipe/tasks/cc/vision/object_detector/object_detector.cc

@@ -28,6 +28,7 @@ limitations under the License.
 #include "mediapipe/framework/api2/builder.h"
 #include "mediapipe/framework/formats/detection.pb.h"
 #include "mediapipe/framework/formats/image.h"
+#include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/tasks/cc/common.h"
 #include "mediapipe/tasks/cc/core/base_options.h"
 #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
@@ -48,6 +49,8 @@ constexpr char kDetectionsTag[] = "DETECTIONS";
 constexpr char kImageInStreamName[] = "image_in";
 constexpr char kImageOutStreamName[] = "image_out";
 constexpr char kImageTag[] = "IMAGE";
+constexpr char kNormRectName[] = "norm_rect_in";
+constexpr char kNormRectTag[] = "NORM_RECT";
 constexpr char kSubgraphTypeName[] =
     "mediapipe.tasks.vision.ObjectDetectorGraph";
 constexpr int kMicroSecondsPerMilliSecond = 1000;
@@ -55,6 +58,31 @@ constexpr int kMicroSecondsPerMilliSecond = 1000;
 using ObjectDetectorOptionsProto =
     object_detector::proto::ObjectDetectorOptions;
 
+// Returns a NormalizedRect filling the whole image. If input is present, its
+// rotation is set in the returned NormalizedRect and a check is performed to
+// make sure no region-of-interest was provided. Otherwise, rotation is set to
+// 0.
+absl::StatusOr<NormalizedRect> FillNormalizedRect(
+    std::optional<NormalizedRect> normalized_rect) {
+  NormalizedRect result;
+  if (normalized_rect.has_value()) {
+    result = *normalized_rect;
+  }
+  bool has_coordinates = result.has_x_center() || result.has_y_center() ||
+                         result.has_width() || result.has_height();
+  if (has_coordinates) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        "ObjectDetector does not support region-of-interest.",
+        MediaPipeTasksStatus::kInvalidArgumentError);
+  }
+  result.set_x_center(0.5);
+  result.set_y_center(0.5);
+  result.set_width(1);
+  result.set_height(1);
+  return result;
+}
+
 // Creates a MediaPipe graph config that contains a subgraph node of
 // "mediapipe.tasks.vision.ObjectDetectorGraph". If the task is running in the
 // live stream mode, a "FlowLimiterCalculator" will be added to limit the
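The helper above is consumed by all three Detect* entry points. As a caller-side illustration (a minimal sketch, not part of this commit; the helper name is hypothetical), the only supported way to populate the new argument is to set the rotation field and leave the region-of-interest fields unset:

#include <cmath>

#include "mediapipe/framework/formats/rect.pb.h"

// Builds the rotation-only NormalizedRect that FillNormalizedRect accepts.
// Leaving x_center/y_center/width/height unset is what keeps the
// region-of-interest check from returning kInvalidArgument.
mediapipe::NormalizedRect MakeRotationOnlyRect(float rotation_radians) {
  mediapipe::NormalizedRect rect;
  rect.set_rotation(rotation_radians);
  return rect;
}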
@@ -64,6 +92,7 @@ CalculatorGraphConfig CreateGraphConfig(
     bool enable_flow_limiting) {
   api2::builder::Graph graph;
   graph.In(kImageTag).SetName(kImageInStreamName);
+  graph.In(kNormRectTag).SetName(kNormRectName);
   auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
   task_subgraph.GetOptions<ObjectDetectorOptionsProto>().Swap(
       options_proto.get());
@@ -76,6 +105,7 @@ CalculatorGraphConfig CreateGraphConfig(
                        {kImageTag}, kDetectionsTag);
   }
   graph.In(kImageTag) >> task_subgraph.In(kImageTag);
+  graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
   return graph.GetConfig();
 }
@@ -139,46 +169,64 @@ absl::StatusOr<std::unique_ptr<ObjectDetector>> ObjectDetector::Create(
 }
 
 absl::StatusOr<std::vector<Detection>> ObjectDetector::Detect(
-    mediapipe::Image image) {
+    mediapipe::Image image,
+    std::optional<mediapipe::NormalizedRect> image_processing_options) {
   if (image.UsesGpu()) {
     return CreateStatusWithPayload(
         absl::StatusCode::kInvalidArgument,
         absl::StrCat("GPU input images are currently not supported."),
         MediaPipeTasksStatus::kRunnerUnexpectedInputError);
   }
-  ASSIGN_OR_RETURN(auto output_packets,
-                   ProcessImageData({{kImageInStreamName,
-                                      MakePacket<Image>(std::move(image))}}));
+  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
+                   FillNormalizedRect(image_processing_options));
+  ASSIGN_OR_RETURN(
+      auto output_packets,
+      ProcessImageData(
+          {{kImageInStreamName, MakePacket<Image>(std::move(image))},
+           {kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}}));
   return output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>();
 }
 
 absl::StatusOr<std::vector<Detection>> ObjectDetector::DetectForVideo(
-    mediapipe::Image image, int64 timestamp_ms) {
+    mediapipe::Image image, int64 timestamp_ms,
+    std::optional<mediapipe::NormalizedRect> image_processing_options) {
   if (image.UsesGpu()) {
     return CreateStatusWithPayload(
         absl::StatusCode::kInvalidArgument,
         absl::StrCat("GPU input images are currently not supported."),
         MediaPipeTasksStatus::kRunnerUnexpectedInputError);
   }
+  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
+                   FillNormalizedRect(image_processing_options));
   ASSIGN_OR_RETURN(
       auto output_packets,
       ProcessVideoData(
           {{kImageInStreamName,
             MakePacket<Image>(std::move(image))
+                .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
+           {kNormRectName,
+            MakePacket<NormalizedRect>(std::move(norm_rect))
                 .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
   return output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>();
 }
 
-absl::Status ObjectDetector::DetectAsync(Image image, int64 timestamp_ms) {
+absl::Status ObjectDetector::DetectAsync(
+    Image image, int64 timestamp_ms,
+    std::optional<mediapipe::NormalizedRect> image_processing_options) {
   if (image.UsesGpu()) {
     return CreateStatusWithPayload(
         absl::StatusCode::kInvalidArgument,
         absl::StrCat("GPU input images are currently not supported."),
         MediaPipeTasksStatus::kRunnerUnexpectedInputError);
   }
+  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
+                   FillNormalizedRect(image_processing_options));
   return SendLiveStreamData(
       {{kImageInStreamName,
         MakePacket<Image>(std::move(image))
+            .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
+       {kNormRectName,
+        MakePacket<NormalizedRect>(std::move(norm_rect))
             .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
 }
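Taken together, the three entry points now accept the same optional argument. A minimal caller-side sketch of the image-mode path (not part of the diff; it mirrors the SucceedsWithRotation test added further below, and the model path is a placeholder):

#include <cmath>
#include <memory>
#include <utility>
#include <vector>

#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/vision/object_detector/object_detector.h"

absl::StatusOr<std::vector<mediapipe::Detection>> DetectWithRotation(
    mediapipe::Image image) {
  auto options =
      std::make_unique<mediapipe::tasks::vision::ObjectDetectorOptions>();
  options->base_options.model_asset_path = "detector.tflite";  // placeholder
  auto detector =
      mediapipe::tasks::vision::ObjectDetector::Create(std::move(options));
  if (!detector.ok()) return detector.status();

  // Rotate the input 90° anti-clockwise before detection. All other
  // NormalizedRect fields stay unset: a region-of-interest is rejected.
  mediapipe::NormalizedRect image_processing_options;
  image_processing_options.set_rotation(M_PI / 2.0);

  auto results =
      (*detector)->Detect(std::move(image), image_processing_options);
  // Close() shuts the underlying MediaPipe graph down.
  if (auto status = (*detector)->Close(); !status.ok()) return status;
  return results;
}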

mediapipe/tasks/cc/vision/object_detector/object_detector.h

@@ -18,6 +18,7 @@ limitations under the License.
 
 #include <functional>
 #include <memory>
+#include <optional>
 #include <string>
 #include <utility>
 #include <vector>
@@ -26,6 +27,7 @@ limitations under the License.
 #include "absl/status/statusor.h"
 #include "mediapipe/framework/formats/detection.pb.h"
 #include "mediapipe/framework/formats/image.h"
+#include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/tasks/cc/core/base_options.h"
 #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
 #include "mediapipe/tasks/cc/vision/core/running_mode.h"
@@ -151,6 +153,13 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
   // TODO: Describes how the input image will be preprocessed
   // after the yuv support is implemented.
   //
+  // The optional 'image_processing_options' parameter can be used to specify
+  // the rotation to apply to the image before performing detection, by
+  // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
+  // anti-clockwise rotation). Note that specifying a region-of-interest using
+  // the 'x_center', 'y_center', 'width' and 'height' fields is NOT supported
+  // and will result in an invalid argument error being returned.
+  //
   // For CPU images, the returned bounding boxes are expressed in the
   // unrotated input frame of reference coordinates system, i.e. in `[0,
   // image_width) x [0, image_height)`, which are the dimensions of the
@@ -158,7 +167,9 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
   // TODO: Describes the output bounding boxes for gpu input
   // images after enabling the gpu support in MediaPipe Tasks.
   absl::StatusOr<std::vector<mediapipe::Detection>> Detect(
-      mediapipe::Image image);
+      mediapipe::Image image,
+      std::optional<mediapipe::NormalizedRect> image_processing_options =
+          std::nullopt);
 
   // Performs object detection on the provided video frame.
   // Only use this method when the ObjectDetector is created with the video
@@ -168,12 +179,21 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
   // provide the video frame's timestamp (in milliseconds). The input timestamps
   // must be monotonically increasing.
   //
+  // The optional 'image_processing_options' parameter can be used to specify
+  // the rotation to apply to the image before performing detection, by
+  // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
+  // anti-clockwise rotation). Note that specifying a region-of-interest using
+  // the 'x_center', 'y_center', 'width' and 'height' fields is NOT supported
+  // and will result in an invalid argument error being returned.
+  //
   // For CPU images, the returned bounding boxes are expressed in the
   // unrotated input frame of reference coordinates system, i.e. in `[0,
   // image_width) x [0, image_height)`, which are the dimensions of the
   // underlying image data.
   absl::StatusOr<std::vector<mediapipe::Detection>> DetectForVideo(
-      mediapipe::Image image, int64 timestamp_ms);
+      mediapipe::Image image, int64 timestamp_ms,
+      std::optional<mediapipe::NormalizedRect> image_processing_options =
+          std::nullopt);
 
   // Sends live image data to perform object detection, and the results will be
   // available via the "result_callback" provided in the ObjectDetectorOptions.
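The same rotation argument applies per frame in video mode. A sketch under stated assumptions (not from the commit: the detector is presumed created with RunningMode::VIDEO, and a fixed 33 ms frame interval stands in for real timestamps):

#include <cmath>
#include <vector>

#include "absl/status/status.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/vision/object_detector/object_detector.h"

absl::Status DetectRotatedFrames(
    mediapipe::tasks::vision::ObjectDetector& detector,
    const std::vector<mediapipe::Image>& frames) {
  mediapipe::NormalizedRect rotation;
  rotation.set_rotation(M_PI / 2.0);  // applied to every frame
  for (size_t i = 0; i < frames.size(); ++i) {
    // Timestamps must be monotonically increasing; ~30 fps here.
    auto results =
        detector.DetectForVideo(frames[i], /*timestamp_ms=*/i * 33, rotation);
    if (!results.ok()) return results.status();
  }
  return absl::OkStatus();
}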
@@ -185,7 +205,14 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
   // sent to the object detector. The input timestamps must be monotonically
   // increasing.
   //
-  // The "result_callback" prvoides
+  // The optional 'image_processing_options' parameter can be used to specify
+  // the rotation to apply to the image before performing detection, by
+  // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
+  // anti-clockwise rotation). Note that specifying a region-of-interest using
+  // the 'x_center', 'y_center', 'width' and 'height' fields is NOT supported
+  // and will result in an invalid argument error being returned.
+  //
+  // The "result_callback" provides
   //  - A vector of detections, each has a bounding box that is expressed in
   //    the unrotated input frame of reference coordinates system, i.e. in `[0,
   //    image_width) x [0, image_height)`, which are the dimensions of the
@@ -195,7 +222,9 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
   //    longer be valid when the callback returns. To access the image data
   //    outside of the callback, callers need to make a copy of the image.
   //  - The input timestamp in milliseconds.
-  absl::Status DetectAsync(mediapipe::Image image, int64 timestamp_ms);
+  absl::Status DetectAsync(mediapipe::Image image, int64 timestamp_ms,
+                           std::optional<mediapipe::NormalizedRect>
+                               image_processing_options = std::nullopt);
 
   // Shuts down the ObjectDetector when all works are done.
   absl::Status Close() { return runner_->Close(); }
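In live-stream mode the rotation rides along with each DetectAsync call, while detections come back through the "result_callback" registered in ObjectDetectorOptions. A hedged sketch (the surrounding setup, including RunningMode::LIVE_STREAM and the callback registration, is assumed rather than shown):

#include <cmath>
#include <cstdint>
#include <utility>

#include "absl/status/status.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/vision/object_detector/object_detector.h"

absl::Status StreamRotatedFrame(
    mediapipe::tasks::vision::ObjectDetector& detector, mediapipe::Image frame,
    int64_t timestamp_ms) {
  mediapipe::NormalizedRect rotation;
  rotation.set_rotation(M_PI / 2.0);
  // This call only enqueues the frame; detections are delivered
  // asynchronously via the options' result_callback.
  return detector.DetectAsync(std::move(frame), timestamp_ms, rotation);
}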

mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc

@@ -28,6 +28,7 @@ limitations under the License.
 #include "mediapipe/framework/calculator.pb.h"
 #include "mediapipe/framework/formats/detection.pb.h"
 #include "mediapipe/framework/formats/image.h"
+#include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/framework/formats/tensor.h"
 #include "mediapipe/tasks/cc/common.h"
 #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
@@ -87,6 +88,7 @@ constexpr char kImageSizeTag[] = "IMAGE_SIZE";
 constexpr char kImageTag[] = "IMAGE";
 constexpr char kIndicesTag[] = "INDICES";
 constexpr char kMatrixTag[] = "MATRIX";
+constexpr char kNormRectTag[] = "NORM_RECT";
 constexpr char kPixelDetectionsTag[] = "PIXEL_DETECTIONS";
 constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX";
 constexpr char kScoresTag[] = "SCORES";
@@ -457,6 +459,10 @@ void ConfigureTensorsToDetectionsCalculator(
 // Inputs:
 //   IMAGE - Image
 //     Image to perform detection on.
+//   NORM_RECT - NormalizedRect @Optional
+//     Describes image rotation and region of image to perform detection
+//     on.
+//     @Optional: rect covering the whole image is used if not specified.
 //
 // Outputs:
 //   DETECTIONS - std::vector<Detection>
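Downstream, the new optional input can be wired like any other graph-level stream. A sketch of hand-building a graph around this subgraph with the api2 builder (illustrative only, not from the commit; populating the subgraph's ObjectDetectorOptions is omitted):

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator.pb.h"

mediapipe::CalculatorGraphConfig BuildDetectorGraphConfig() {
  mediapipe::api2::builder::Graph graph;
  auto& detector = graph.AddNode("mediapipe.tasks.vision.ObjectDetectorGraph");
  graph.In("IMAGE") >> detector.In("IMAGE");
  // Optional edge: if omitted, the subgraph falls back to a rect covering
  // the whole image, per the @Optional note above.
  graph.In("NORM_RECT") >> detector.In("NORM_RECT");
  detector.Out("DETECTIONS") >> graph.Out("DETECTIONS");
  detector.Out("IMAGE") >> graph.Out("IMAGE");
  return graph.GetConfig();
}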
@@ -494,9 +500,10 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
     Graph graph;
     ASSIGN_OR_RETURN(
         auto output_streams,
-        BuildObjectDetectionTask(sc->Options<ObjectDetectorOptionsProto>(),
-                                 *model_resources,
-                                 graph[Input<Image>(kImageTag)], graph));
+        BuildObjectDetectionTask(
+            sc->Options<ObjectDetectorOptionsProto>(), *model_resources,
+            graph[Input<Image>(kImageTag)],
+            graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
     output_streams.detections >>
         graph[Output<std::vector<Detection>>(kDetectionsTag)];
     output_streams.image >> graph[Output<Image>(kImageTag)];
@@ -519,7 +526,7 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
   absl::StatusOr<ObjectDetectionOutputStreams> BuildObjectDetectionTask(
       const ObjectDetectorOptionsProto& task_options,
       const core::ModelResources& model_resources, Source<Image> image_in,
-      Graph& graph) {
+      Source<NormalizedRect> norm_rect_in, Graph& graph) {
     MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
     // Checks that the model has 4 outputs.
     auto& model = *model_resources.GetTfLiteModel();
@@ -559,6 +566,7 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
         &preprocessing
              .GetOptions<tasks::components::ImagePreprocessingOptions>()));
     image_in >> preprocessing.In(kImageTag);
+    norm_rect_in >> preprocessing.In(kNormRectTag);
 
     // Adds inference subgraph and connects its input stream to the output
     // tensors produced by the ImageToTensorCalculator.

mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc

@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "mediapipe/tasks/cc/vision/object_detector/object_detector.h"
 
+#include <cmath>
 #include <functional>
 #include <memory>
 #include <string>
@@ -30,6 +31,7 @@ limitations under the License.
 #include "mediapipe/framework/deps/file_path.h"
 #include "mediapipe/framework/formats/image.h"
 #include "mediapipe/framework/formats/location_data.pb.h"
+#include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/framework/port/gmock.h"
 #include "mediapipe/framework/port/gtest.h"
 #include "mediapipe/framework/port/parse_text_proto.h"
@@ -519,6 +521,54 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
   ExpectApproximatelyEqual(results, {full_expected_results[3]});
 }
 
+TEST_F(ImageModeTest, SucceedsWithRotation) {
+  MP_ASSERT_OK_AND_ASSIGN(
+      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
+                                                "cats_and_dogs_rotated.jpg")));
+  auto options = std::make_unique<ObjectDetectorOptions>();
+  options->max_results = 1;
+  options->category_allowlist.push_back("cat");
+  options->base_options.model_asset_path =
+      JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
+                          ObjectDetector::Create(std::move(options)));
+  NormalizedRect image_processing_options;
+  image_processing_options.set_rotation(M_PI / 2.0);
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto results, object_detector->Detect(image, image_processing_options));
+  MP_ASSERT_OK(object_detector->Close());
+  ExpectApproximatelyEqual(
+      results, {ParseTextProtoOrDie<Detection>(R"pb(
+        label: "cat"
+        score: 0.7109375
+        location_data {
+          format: BOUNDING_BOX
+          bounding_box { xmin: 0 ymin: 622 width: 436 height: 276 }
+        })pb")});
+}
+
+TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
+  MP_ASSERT_OK_AND_ASSIGN(Image image,
+                          DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
+                                                       "cats_and_dogs.jpg")));
+  auto options = std::make_unique<ObjectDetectorOptions>();
+  options->max_results = 1;
+  options->base_options.model_asset_path =
+      JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
+                          ObjectDetector::Create(std::move(options)));
+  NormalizedRect image_processing_options;
+  image_processing_options.set_x_center(0.5);
+  image_processing_options.set_y_center(0.5);
+  image_processing_options.set_width(1.0);
+  image_processing_options.set_height(1.0);
+
+  auto results = object_detector->Detect(image, image_processing_options);
+  EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(results.status().message(),
+              HasSubstr("ObjectDetector does not support region-of-interest"));
+}
+
 class VideoModeTest : public tflite_shims::testing::Test {};
 
 TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {

mediapipe/tasks/testdata/vision/BUILD
@@ -30,6 +30,7 @@ mediapipe_files(srcs = [
     "cat_mask.jpg",
     "cats_and_dogs.jpg",
     "cats_and_dogs_no_resizing.jpg",
+    "cats_and_dogs_rotated.jpg",
     "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite",
     "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
     "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
@@ -79,6 +80,7 @@ filegroup(
         "cat_mask.jpg",
         "cats_and_dogs.jpg",
         "cats_and_dogs_no_resizing.jpg",
+        "cats_and_dogs_rotated.jpg",
         "hand_landmark_full.tflite",
        "hand_landmark_lite.tflite",
         "left_hands.jpg",

third_party/external_files.bzl
@@ -46,6 +46,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=1661875663693976"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_BUILD_orig",
+        sha256 = "650df617b3e125e0890f1b8c936cc64c9d975707f57e616b6430fc667ce315d4",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD.orig?generation=1665609930388174"],
+    )
+
     http_file(
         name = "com_google_mediapipe_burger_crop_jpg",
         sha256 = "8f58de573f0bf59a49c3d86cfabb9ad4061481f574aa049177e8da3963dddc50",
@@ -88,6 +94,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs_no_resizing.jpg?generation=1661875687251296"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_cats_and_dogs_rotated_jpg",
+        sha256 = "5384926d16ddd8802555ae3108deedefb42a2ea78d99e5ad0933c5e11f43244a",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs_rotated.jpg?generation=1665609933260747"],
+    )
+
     http_file(
         name = "com_google_mediapipe_classification_tensor_float_meta_json",
         sha256 = "1d10b1c9c87eabac330651136804074ddc134779e94a73cf783207c3aa2a5619",