Internal change.

PiperOrigin-RevId: 516594221

Parent: 8a41a5e44d
Commit: 854ab25ee9

mediapipe/tasks/cc/vision/pose_detector/BUILD (new file, 53 lines)
@@ -0,0 +1,53 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_visibility = [
    "//mediapipe/tasks:internal",
])

licenses(["notice"])

cc_library(
    name = "pose_detector_graph",
    srcs = ["pose_detector_graph.cc"],
    deps = [
        "//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto",
        "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
        "//mediapipe/calculators/tensor:inference_calculator",
        "//mediapipe/calculators/tensor:tensors_to_detections_calculator",
        "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto",
        "//mediapipe/calculators/tflite:ssd_anchors_calculator",
        "//mediapipe/calculators/tflite:ssd_anchors_calculator_cc_proto",
        "//mediapipe/calculators/util:detection_projection_calculator",
        "//mediapipe/calculators/util:detection_transformation_calculator",
        "//mediapipe/calculators/util:detections_to_rects_calculator",
        "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto",
        "//mediapipe/calculators/util:non_max_suppression_calculator",
        "//mediapipe/calculators/util:non_max_suppression_calculator_cc_proto",
        "//mediapipe/calculators/util:rect_transformation_calculator",
        "//mediapipe/calculators/util:rect_transformation_calculator_cc_proto",
        "//mediapipe/framework:calculator_cc_proto",
        "//mediapipe/framework:subgraph",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/formats:detection_cc_proto",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
        "//mediapipe/tasks/cc/core:model_task_graph",
        "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto",
        "@com_google_absl//absl/status:statusor",
    ],
    alwayslink = 1,
)

mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph.cc (new file, 354 lines)
@@ -0,0 +1,354 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "absl/status/statusor.h"
#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h"
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
#include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h"
#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h"
#include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h"
#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/subgraph.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace pose_detector {

using ::mediapipe::NormalizedRect;
using ::mediapipe::Tensor;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::vision::pose_detector::proto::
    PoseDetectorGraphOptions;

namespace {
constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kAnchorsTag[] = "ANCHORS";
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kNormRectsTag[] = "NORM_RECTS";
constexpr char kPixelDetectionsTag[] = "PIXEL_DETECTIONS";
constexpr char kPoseRectsTag[] = "POSE_RECTS";
constexpr char kExpandedPoseRectsTag[] = "EXPANDED_POSE_RECTS";
constexpr char kMatrixTag[] = "MATRIX";
constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX";

struct PoseDetectionOuts {
  Source<std::vector<Detection>> pose_detections;
  Source<std::vector<NormalizedRect>> pose_rects;
  Source<std::vector<NormalizedRect>> expanded_pose_rects;
  Source<Image> image;
};

// TODO: Configure detection-related calculators in the pose
// detector with model metadata.
void ConfigureSsdAnchorsCalculator(
    mediapipe::SsdAnchorsCalculatorOptions* options) {
  // Derived from
  // mediapipe/modules/pose_detection/pose_detection_gpu.pbtxt
  options->set_num_layers(5);
  options->set_min_scale(0.1484375);
  options->set_max_scale(0.75);
  options->set_input_size_height(224);
  options->set_input_size_width(224);
  options->set_anchor_offset_x(0.5);
  options->set_anchor_offset_y(0.5);
  options->add_strides(8);
  options->add_strides(16);
  options->add_strides(32);
  options->add_strides(32);
  options->add_strides(32);
  options->add_aspect_ratios(1.0);
  options->set_fixed_anchor_size(true);
}
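
// [Illustrative aside, not part of this commit] The anchor scales implied by
// the options above follow the usual SSD rule of linearly interpolating
// between min_scale and max_scale across layers. A minimal sketch of that
// assumed rule (the helper below is hypothetical; <vector> is already
// included above):
std::vector<float> IllustrativeAnchorScales(float min_scale, float max_scale,
                                            int num_layers) {
  std::vector<float> scales;
  for (int i = 0; i < num_layers; ++i) {
    // Evenly spaced scales; the options above yield approximately
    // {0.148, 0.299, 0.449, 0.600, 0.750}.
    const float t =
        num_layers > 1 ? static_cast<float>(i) / (num_layers - 1) : 0.0f;
    scales.push_back(min_scale + (max_scale - min_scale) * t);
  }
  return scales;
}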

// TODO: Configure detection-related calculators in the pose
// detector with model metadata.
void ConfigureTensorsToDetectionsCalculator(
    const PoseDetectorGraphOptions& tasks_options,
    mediapipe::TensorsToDetectionsCalculatorOptions* options) {
  // Derived from
  // mediapipe/modules/pose_detection/pose_detection_gpu.pbtxt
  options->set_num_classes(1);
  options->set_num_boxes(2254);
  options->set_num_coords(12);
  options->set_box_coord_offset(0);
  options->set_keypoint_coord_offset(4);
  options->set_num_keypoints(4);
  options->set_num_values_per_keypoint(2);
  options->set_sigmoid_score(true);
  options->set_score_clipping_thresh(100.0);
  options->set_reverse_output_order(true);
  options->set_min_score_thresh(tasks_options.min_detection_confidence());
  options->set_x_scale(224.0);
  options->set_y_scale(224.0);
  options->set_w_scale(224.0);
  options->set_h_scale(224.0);
}
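
// [Illustrative aside, not part of this commit] With reverse_output_order set
// and all scales at 224.0 (the model's input size), the calculator's
// anchor-based decoding reduces to the standard SSD rule sketched below.
// Struct and function names are hypothetical; `raw` points at one box's
// coordinates in [x, y, w, h] order.
struct IllustrativeAnchor {
  float x_center, y_center, w, h;
};

IllustrativeAnchor IllustrativeDecodeBox(const float* raw,
                                         const IllustrativeAnchor& anchor) {
  IllustrativeAnchor box;
  box.x_center = raw[0] / 224.0f * anchor.w + anchor.x_center;  // x_scale
  box.y_center = raw[1] / 224.0f * anchor.h + anchor.y_center;  // y_scale
  box.w = raw[2] / 224.0f * anchor.w;                           // w_scale
  box.h = raw[3] / 224.0f * anchor.h;                           // h_scale
  return box;
}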

void ConfigureNonMaxSuppressionCalculator(
    const PoseDetectorGraphOptions& tasks_options,
    mediapipe::NonMaxSuppressionCalculatorOptions* options) {
  options->set_min_suppression_threshold(
      tasks_options.min_suppression_threshold());
  options->set_overlap_type(
      mediapipe::NonMaxSuppressionCalculatorOptions::INTERSECTION_OVER_UNION);
  options->set_algorithm(
      mediapipe::NonMaxSuppressionCalculatorOptions::WEIGHTED);
}
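
// [Illustrative aside, not part of this commit] The WEIGHTED algorithm does
// not simply drop overlapping detections: each cluster of detections whose
// IoU with the cluster's top-scoring detection exceeds the threshold is
// blended into one detection, averaged with the scores as weights. A rough
// standalone sketch of that idea (hypothetical types; boxes are [x, y, w, h]
// with a top-left origin; assumes <algorithm> for std::sort/min/max):
struct IllustrativeDet {
  float x, y, w, h, score;
};

float IllustrativeIoU(const IllustrativeDet& a, const IllustrativeDet& b) {
  const float ix = std::min(a.x + a.w, b.x + b.w) - std::max(a.x, b.x);
  const float iy = std::min(a.y + a.h, b.y + b.h) - std::max(a.y, b.y);
  if (ix <= 0.0f || iy <= 0.0f) return 0.0f;
  const float inter = ix * iy;
  return inter / (a.w * a.h + b.w * b.h - inter);
}

std::vector<IllustrativeDet> IllustrativeWeightedNms(
    std::vector<IllustrativeDet> dets, float min_suppression_threshold) {
  std::sort(dets.begin(), dets.end(),
            [](const IllustrativeDet& a, const IllustrativeDet& b) {
              return a.score > b.score;
            });
  std::vector<IllustrativeDet> out;
  std::vector<bool> used(dets.size(), false);
  for (size_t i = 0; i < dets.size(); ++i) {
    if (used[i]) continue;
    IllustrativeDet blended{0.0f, 0.0f, 0.0f, 0.0f, dets[i].score};
    float total_score = 0.0f;
    for (size_t j = i; j < dets.size(); ++j) {
      // IoU(i, i) == 1, so the cluster always absorbs its own top detection.
      if (used[j] ||
          IllustrativeIoU(dets[i], dets[j]) < min_suppression_threshold) {
        continue;
      }
      used[j] = true;
      blended.x += dets[j].x * dets[j].score;
      blended.y += dets[j].y * dets[j].score;
      blended.w += dets[j].w * dets[j].score;
      blended.h += dets[j].h * dets[j].score;
      total_score += dets[j].score;
    }
    blended.x /= total_score;
    blended.y /= total_score;
    blended.w /= total_score;
    blended.h /= total_score;
    out.push_back(blended);
  }
  return out;
}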

// TODO: Configure detection-related calculators in the pose
// detector with model metadata.
void ConfigureDetectionsToRectsCalculator(
    mediapipe::DetectionsToRectsCalculatorOptions* options) {
  options->set_rotation_vector_start_keypoint_index(0);
  options->set_rotation_vector_end_keypoint_index(2);
  options->set_rotation_vector_target_angle(90);
  options->set_output_zero_rect_for_empty_detections(true);
}
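
// [Illustrative aside, not part of this commit] With the options above, the
// rect's rotation is derived from detection keypoints 0 and 2 so that the
// vector between them ends up at the 90-degree target angle (pointing "up"
// in image space, since y grows downward). A sketch of the assumed formula
// (hypothetical helper; keypoints are normalized coordinates; production
// code would also normalize the result to (-pi, pi]; assumes <cmath>):
float IllustrativeRectRotation(float x0, float y0, float x1, float y1,
                               int image_width, int image_height) {
  constexpr float kTargetAngle = 1.5707964f;  // 90 degrees in radians.
  return kTargetAngle - std::atan2(-(y1 - y0) * image_height,
                                   (x1 - x0) * image_width);
}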

// TODO: Configure detection-related calculators in the pose
// detector with model metadata.
void ConfigureRectTransformationCalculator(
    mediapipe::RectTransformationCalculatorOptions* options) {
  options->set_scale_x(2.6);
  options->set_scale_y(2.6);
  options->set_shift_y(-0.5);
  options->set_square_long(true);
}
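
// [Illustrative aside, not part of this commit] Assuming the conventional
// shift -> square_long -> scale order, the options above turn a tight pose
// rect into an expanded RoI roughly as follows (hypothetical helper; fields
// are normalized; rotation is ignored for brevity; assumes <algorithm> for
// std::max):
struct IllustrativeNormRect {
  float x_center, y_center, width, height;
};

IllustrativeNormRect IllustrativeExpandPoseRect(IllustrativeNormRect rect,
                                                int image_width,
                                                int image_height) {
  rect.y_center += -0.5f * rect.height;  // shift_y = -0.5 moves the rect up.
  // square_long: grow the short side to match the long side, measured in
  // pixels so the rect is square on screen rather than in normalized units.
  const float long_side =
      std::max(rect.width * image_width, rect.height * image_height);
  rect.width = long_side / image_width;
  rect.height = long_side / image_height;
  rect.width *= 2.6f;   // scale_x
  rect.height *= 2.6f;  // scale_y
  return rect;
}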

}  // namespace

// A "mediapipe.tasks.vision.pose_detector.PoseDetectorGraph" performs pose
// detection.
//
// Inputs:
//   IMAGE - Image
//     Image to perform detection on.
//   NORM_RECT - NormalizedRect @Optional
//     Describes image rotation and region of image to perform detection on. If
//     not provided, the whole image is used for pose detection.
//
// Outputs:
//   DETECTIONS - std::vector<Detection>
//     Detected poses, at most the `num_poses` specified in the options.
//   POSE_RECTS - std::vector<NormalizedRect>
//     Detected pose bounding boxes in normalized coordinates.
//   EXPANDED_POSE_RECTS - std::vector<NormalizedRect>
//     Expanded pose bounding boxes in normalized coordinates, so that the
//     bounding boxes likely contain the whole pose. This is usually used as
//     the RoI for pose landmarks detection to run on.
//   IMAGE - Image
//     The input image that the pose detector runs on, with the pixel data
//     stored on the target storage (CPU vs GPU).
// All returned coordinates are in the unrotated and uncropped input image
// coordinates system.
//
// Example:
// node {
//   calculator: "mediapipe.tasks.vision.pose_detector.PoseDetectorGraph"
//   input_stream: "IMAGE:image"
//   input_stream: "NORM_RECT:norm_rect"
//   output_stream: "DETECTIONS:pose_detections"
//   output_stream: "POSE_RECTS:pose_rects"
//   output_stream: "EXPANDED_POSE_RECTS:expanded_pose_rects"
//   output_stream: "IMAGE:image_out"
//   options {
//     [mediapipe.tasks.vision.pose_detector.proto.PoseDetectorGraphOptions.ext]
//     {
//       base_options {
//         model_asset {
//           file_name: "pose_detection.tflite"
//         }
//       }
//       min_detection_confidence: 0.5
//       num_poses: 2
//     }
//   }
// }
class PoseDetectorGraph : public core::ModelTaskGraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      SubgraphContext* sc) override {
    ASSIGN_OR_RETURN(const auto* model_resources,
                     CreateModelResources<PoseDetectorGraphOptions>(sc));
    Graph graph;
    ASSIGN_OR_RETURN(auto outs,
                     BuildPoseDetectionSubgraph(
                         sc->Options<PoseDetectorGraphOptions>(),
                         *model_resources, graph[Input<Image>(kImageTag)],
                         graph[Input<NormalizedRect>(kNormRectTag)], graph));

    outs.pose_detections >>
        graph.Out(kDetectionsTag).Cast<std::vector<Detection>>();
    outs.pose_rects >>
        graph.Out(kPoseRectsTag).Cast<std::vector<NormalizedRect>>();
    outs.expanded_pose_rects >>
        graph.Out(kExpandedPoseRectsTag).Cast<std::vector<NormalizedRect>>();
    outs.image >> graph.Out(kImageTag).Cast<Image>();

    return graph.GetConfig();
  }

 private:
  absl::StatusOr<PoseDetectionOuts> BuildPoseDetectionSubgraph(
      const PoseDetectorGraphOptions& subgraph_options,
      const core::ModelResources& model_resources, Source<Image> image_in,
      Source<NormalizedRect> norm_rect_in, Graph& graph) {
    // Image preprocessing subgraph to convert image to tensor for the tflite
    // model.
    auto& preprocessing = graph.AddNode(
        "mediapipe.tasks.components.processors.ImagePreprocessingGraph");
    bool use_gpu =
        components::processors::DetermineImagePreprocessingGpuBackend(
            subgraph_options.base_options().acceleration());
    MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
        model_resources, use_gpu,
        &preprocessing.GetOptions<
            components::processors::proto::ImagePreprocessingGraphOptions>()));
    auto& image_to_tensor_options =
        *preprocessing
             .GetOptions<components::processors::proto::
                             ImagePreprocessingGraphOptions>()
             .mutable_image_to_tensor_options();
    image_to_tensor_options.set_keep_aspect_ratio(true);
    image_to_tensor_options.set_border_mode(
        mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO);
    image_in >> preprocessing.In(kImageTag);
    norm_rect_in >> preprocessing.In(kNormRectTag);
    auto preprocessed_tensors = preprocessing.Out(kTensorsTag);
    auto matrix = preprocessing.Out(kMatrixTag);
    auto image_size = preprocessing.Out(kImageSizeTag);

    // Pose detection model inference.
    auto& inference = AddInference(
        model_resources, subgraph_options.base_options().acceleration(), graph);
    preprocessed_tensors >> inference.In(kTensorsTag);
    auto model_output_tensors =
        inference.Out(kTensorsTag).Cast<std::vector<Tensor>>();

    // Generates a single side packet containing a vector of SSD anchors.
    auto& ssd_anchor = graph.AddNode("SsdAnchorsCalculator");
    ConfigureSsdAnchorsCalculator(
        &ssd_anchor.GetOptions<mediapipe::SsdAnchorsCalculatorOptions>());
    auto anchors = ssd_anchor.SideOut("");

    // Converts output tensors to Detections.
    auto& tensors_to_detections =
        graph.AddNode("TensorsToDetectionsCalculator");
    ConfigureTensorsToDetectionsCalculator(
        subgraph_options,
        &tensors_to_detections
             .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>());
    model_output_tensors >> tensors_to_detections.In(kTensorsTag);
    anchors >> tensors_to_detections.SideIn(kAnchorsTag);
    auto detections = tensors_to_detections.Out(kDetectionsTag);

    // Non maximum suppression removes redundant pose detections.
    auto& non_maximum_suppression =
        graph.AddNode("NonMaxSuppressionCalculator");
    ConfigureNonMaxSuppressionCalculator(
        subgraph_options,
        &non_maximum_suppression
             .GetOptions<mediapipe::NonMaxSuppressionCalculatorOptions>());
    detections >> non_maximum_suppression.In("");
    auto nms_detections = non_maximum_suppression.Out("");

    // Projects detections back into the input image coordinates system.
    auto& detection_projection = graph.AddNode("DetectionProjectionCalculator");
    nms_detections >> detection_projection.In(kDetectionsTag);
    matrix >> detection_projection.In(kProjectionMatrixTag);
    Source<std::vector<Detection>> pose_detections =
        detection_projection.Out(kDetectionsTag).Cast<std::vector<Detection>>();

    if (subgraph_options.has_num_poses()) {
      // Clip pose detections to the maximum number of poses.
      auto& clip_detection_vector_size =
          graph.AddNode("ClipDetectionVectorSizeCalculator");
      clip_detection_vector_size
          .GetOptions<mediapipe::ClipVectorSizeCalculatorOptions>()
          .set_max_vec_size(subgraph_options.num_poses());
      pose_detections >> clip_detection_vector_size.In("");
      pose_detections =
          clip_detection_vector_size.Out("").Cast<std::vector<Detection>>();
    }

    // Converts results of pose detection into a rectangle (normalized by image
    // size) that encloses the pose and is rotated according to the rotation
    // keypoints configured in the DetectionsToRectsCalculator options above.
    auto& detections_to_rects = graph.AddNode("DetectionsToRectsCalculator");
    ConfigureDetectionsToRectsCalculator(
        &detections_to_rects
             .GetOptions<mediapipe::DetectionsToRectsCalculatorOptions>());
    image_size >> detections_to_rects.In(kImageSizeTag);
    pose_detections >> detections_to_rects.In(kDetectionsTag);
    auto pose_rects = detections_to_rects.Out(kNormRectsTag)
                          .Cast<std::vector<NormalizedRect>>();

    // Expands and shifts the rectangle that contains the pose so that it's
    // likely to cover the entire pose.
    auto& rect_transformation = graph.AddNode("RectTransformationCalculator");
    ConfigureRectTransformationCalculator(
        &rect_transformation
             .GetOptions<mediapipe::RectTransformationCalculatorOptions>());
    pose_rects >> rect_transformation.In(kNormRectsTag);
    image_size >> rect_transformation.In(kImageSizeTag);
    auto expanded_pose_rects =
        rect_transformation.Out("").Cast<std::vector<NormalizedRect>>();

    // Calculator to convert relative detection bounding boxes to pixel
    // detection bounding boxes.
    auto& detection_transformation =
        graph.AddNode("DetectionTransformationCalculator");
    detection_projection.Out(kDetectionsTag) >>
        detection_transformation.In(kDetectionsTag);
    preprocessing.Out(kImageSizeTag) >>
        detection_transformation.In(kImageSizeTag);
    auto pose_pixel_detections =
        detection_transformation.Out(kPixelDetectionsTag)
            .Cast<std::vector<Detection>>();

    return PoseDetectionOuts{
        /* pose_detections= */ pose_pixel_detections,
        /* pose_rects= */ pose_rects,
        /* expanded_pose_rects= */ expanded_pose_rects,
        /* image= */ preprocessing.Out(kImageTag).Cast<Image>()};
  }
};

REGISTER_MEDIAPIPE_GRAPH(
    ::mediapipe::tasks::vision::pose_detector::PoseDetectorGraph);

}  // namespace pose_detector
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe

mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph_test.cc (new file, 165 lines)
@@ -0,0 +1,165 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace pose_detector {
namespace {

using ::file::Defaults;
using ::file::GetTextProto;
using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::core::TaskRunner;
using ::mediapipe::tasks::vision::pose_detector::proto::
    PoseDetectorGraphOptions;
using ::testing::EqualsProto;
using ::testing::Pointwise;
using ::testing::TestParamInfo;
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;

constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kPoseDetectionModel[] = "pose_detection.tflite";
constexpr char kPortraitImage[] = "pose.jpg";
constexpr char kPoseExpectedDetection[] = "pose_expected_detection.pbtxt";

constexpr char kImageTag[] = "IMAGE";
constexpr char kImageName[] = "image";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kNormRectName[] = "norm_rect";
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kDetectionsName[] = "detections";

constexpr float kPoseDetectionMaxDiff = 0.01;

// Helper function to create a TaskRunner.
absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
    absl::string_view model_name) {
  Graph graph;

  auto& pose_detector_graph =
      graph.AddNode("mediapipe.tasks.vision.pose_detector.PoseDetectorGraph");

  auto options = std::make_unique<PoseDetectorGraphOptions>();
  options->mutable_base_options()->mutable_model_asset()->set_file_name(
      JoinPath("./", kTestDataDirectory, model_name));
  options->set_min_detection_confidence(0.6);
  options->set_min_suppression_threshold(0.3);
  pose_detector_graph.GetOptions<PoseDetectorGraphOptions>().Swap(
      options.get());

  graph[Input<Image>(kImageTag)].SetName(kImageName) >>
      pose_detector_graph.In(kImageTag);
  graph[Input<NormalizedRect>(kNormRectTag)].SetName(kNormRectName) >>
      pose_detector_graph.In(kNormRectTag);

  pose_detector_graph.Out(kDetectionsTag).SetName(kDetectionsName) >>
      graph[Output<std::vector<Detection>>(kDetectionsTag)];

  return TaskRunner::Create(
      graph.GetConfig(), std::make_unique<core::MediaPipeBuiltinOpResolver>());
}

Detection GetExpectedPoseDetectionResult(absl::string_view file_name) {
  Detection detection;
  CHECK_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, file_name),
                        &detection, Defaults()))
      << "Expected pose detection result does not exist.";
  return detection;
}

struct TestParams {
  // The name of this test, for convenience when displaying test results.
  std::string test_name;
  // The filename of the pose detection model.
  std::string pose_detection_model_name;
  // The filename of the test image.
  std::string test_image_name;
  // Expected pose detection results.
  std::vector<Detection> expected_result;
};

class PoseDetectorGraphTest : public testing::TestWithParam<TestParams> {};

TEST_P(PoseDetectorGraphTest, Succeed) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                GetParam().test_image_name)));
  NormalizedRect input_norm_rect;
  input_norm_rect.set_x_center(0.5);
  input_norm_rect.set_y_center(0.5);
  input_norm_rect.set_width(1.0);
  input_norm_rect.set_height(1.0);
  MP_ASSERT_OK_AND_ASSIGN(
      auto task_runner, CreateTaskRunner(GetParam().pose_detection_model_name));
  auto output_packets = task_runner->Process(
      {{kImageName, MakePacket<Image>(std::move(image))},
       {kNormRectName,
        MakePacket<NormalizedRect>(std::move(input_norm_rect))}});
  MP_ASSERT_OK(output_packets);
  const std::vector<Detection>& pose_detections =
      (*output_packets)[kDetectionsName].Get<std::vector<Detection>>();
  EXPECT_THAT(pose_detections, Pointwise(Approximately(Partially(EqualsProto()),
                                                       kPoseDetectionMaxDiff),
                                         GetParam().expected_result));
}

INSTANTIATE_TEST_SUITE_P(
    PoseDetectorGraphTest, PoseDetectorGraphTest,
    Values(TestParams{.test_name = "DetectPose",
                      .pose_detection_model_name = kPoseDetectionModel,
                      .test_image_name = kPortraitImage,
                      .expected_result = {GetExpectedPoseDetectionResult(
                          kPoseExpectedDetection)}}),
    [](const TestParamInfo<PoseDetectorGraphTest::ParamType>& info) {
      return info.param.test_name;
    });

}  // namespace
}  // namespace pose_detector
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe

mediapipe/tasks/cc/vision/pose_detector/proto/BUILD (new file, 31 lines)
@@ -0,0 +1,31 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")

package(default_visibility = [
    "//mediapipe/tasks:internal",
])

licenses(["notice"])

mediapipe_proto_library(
    name = "pose_detector_graph_options_proto",
    srcs = ["pose_detector_graph_options.proto"],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/tasks/cc/core/proto:base_options_proto",
    ],
)

mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.proto (new file, 45 lines)
@@ -0,0 +1,45 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto2";

package mediapipe.tasks.vision.pose_detector.proto;

import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";

option java_package = "com.google.mediapipe.tasks.vision.posedetector.proto";
option java_outer_classname = "PoseDetectorGraphOptionsProto";

message PoseDetectorGraphOptions {
  extend mediapipe.CalculatorOptions {
    optional PoseDetectorGraphOptions ext = 514774813;
  }
  // Base options for configuring the Task library, such as specifying the
  // TfLite model file with metadata, accelerator options, etc.
  optional core.proto.BaseOptions base_options = 1;

  // Minimum confidence value ([0.0, 1.0]) for a confidence score to be
  // considered a successful pose detection in the image.
  optional float min_detection_confidence = 2 [default = 0.5];

  // IoU threshold ([0.0, 1.0]) for non-maximum suppression, above which
  // detections are considered duplicates.
  optional float min_suppression_threshold = 3 [default = 0.5];

  // Maximum number of poses to detect in the image.
  optional int32 num_poses = 4;
}

mediapipe/tasks/testdata/vision/BUILD (vendored, 5 lines changed)
@@ -70,6 +70,8 @@ mediapipe_files(srcs = [
"portrait.jpg",
|
||||
"portrait_hair_expected_mask.jpg",
|
||||
"portrait_rotated.jpg",
|
||||
"pose.jpg",
|
||||
"pose_detection.tflite",
|
||||
"right_hands.jpg",
|
||||
"right_hands_rotated.jpg",
|
||||
"segmentation_golden_rotation0.png",
|
||||
|
@ -127,6 +129,7 @@ filegroup(
|
|||
"portrait.jpg",
|
||||
"portrait_hair_expected_mask.jpg",
|
||||
"portrait_rotated.jpg",
|
||||
"pose.jpg",
|
||||
"right_hands.jpg",
|
||||
"right_hands_rotated.jpg",
|
||||
"segmentation_golden_rotation0.png",
|
||||
|
@ -171,6 +174,7 @@ filegroup(
|
|||
"mobilenet_v2_1.0_224.tflite",
|
||||
"mobilenet_v3_small_100_224_embedder.tflite",
|
||||
"palm_detection_full.tflite",
|
||||
"pose_detection.tflite",
|
||||
"selfie_segm_128_128_3.tflite",
|
||||
"selfie_segm_144_256_3.tflite",
|
||||
"selfie_segmentation.tflite",
|
||||
|
@ -199,6 +203,7 @@ filegroup(
|
|||
"portrait_expected_face_landmarks.pbtxt",
|
||||
"portrait_expected_face_landmarks_with_attention.pbtxt",
|
||||
"portrait_rotated_expected_detection.pbtxt",
|
||||
"pose_expected_detection.pbtxt",
|
||||
"thumb_up_landmarks.pbtxt",
|
||||
"thumb_up_rotated_landmarks.pbtxt",
|
||||
"victory_landmarks.pbtxt",
|
||||
|
|

mediapipe/tasks/testdata/vision/pose_expected_detection.pbtxt (vendored, new file, 27 lines)
@@ -0,0 +1,27 @@
# proto-file: mediapipe/framework/formats/detection.proto
# proto-message: Detection
location_data {
  format: BOUNDING_BOX
  bounding_box {
    xmin: 397
    ymin: 198
    width: 199
    height: 199
  }
  relative_keypoints {
    x: 0.4879558
    y: 0.7013345
  }
  relative_keypoints {
    x: 0.48453212
    y: 0.32265592
  }
  relative_keypoints {
    x: 0.4992165
    y: 0.4854874
  }
  relative_keypoints {
    x: 0.50227845
    y: 0.159788
  }
}

third_party/external_files.bzl (vendored, 26 lines changed)
@@ -72,8 +72,8 @@ def external_files():
     http_file(
         name = "com_google_mediapipe_BUILD_orig",
-        sha256 = "64d5343a6a5f9be06db0a5074a2260f9ae63a989fe01702832cd215680dc19c1",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD.orig?generation=1678323576393653"],
+        sha256 = "d86b98b82e00dd87cd46bd1429bf5eaa007b500c1a24d9316b73309f2e6c8df8",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD.orig?generation=1678737479599640"],
     )

     http_file(

@@ -823,7 +823,7 @@ def external_files():
     http_file(
         name = "com_google_mediapipe_portrait_expected_face_geometry_with_attention_pbtxt",
         sha256 = "7ed1eed98e61e0a10811bb611c895d87c8023f398a36db01b6d9ba2e1ab09e16",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_geometry_with_attention.pbtxt?generation=1678505004840652"],
+        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_geometry_with_attention.pbtxt?generation=1678737486927530"],
     )

     http_file(

@@ -864,8 +864,20 @@ def external_files():

     http_file(
         name = "com_google_mediapipe_pose_detection_tflite",
-        sha256 = "a63c614bef30d35947f13be361820b1e4e3bec9cfeebf4d11216a18373108e85",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/pose_detection.tflite?generation=1661875889147923"],
+        sha256 = "9ba9dd3d42efaaba86b4ff0122b06f29c4122e756b329d89dca1e297fd8f866c",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/pose_detection.tflite?generation=1678737489600422"],
     )

+    http_file(
+        name = "com_google_mediapipe_pose_expected_detection_pbtxt",
+        sha256 = "e0d40e98dd5320a780a642c336d0c8720243ac5bcc0e39c4061ad970a503ae24",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/pose_expected_detection.pbtxt?generation=1678737492211540"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_pose_jpg",
+        sha256 = "c8a830ed683c0276d713dd5aeda28f415f10cd6291972084a40d0d8b934ed62b",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/pose.jpg?generation=1678737494661975"],
+    )
+
     http_file(

@@ -1224,8 +1236,8 @@ def external_files():

     http_file(
         name = "com_google_mediapipe_object_detection_saved_model_README_md",
-        sha256 = "fe163cf12fbd017738a2fd360c03d223e964ba6404ac75c635f5918784e9c34d",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md?generation=1661875995856372"],
+        sha256 = "acc23dee09f69210717ac060035c844ba902e8271486f1086f29fb156c236690",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md?generation=1678737498915254"],
     )

     http_file(