From 854ab25ee9590fcf8e4df07964e7cc725b970ca0 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 14 Mar 2023 12:04:51 -0700 Subject: [PATCH] Internal change. PiperOrigin-RevId: 516594221 --- mediapipe/tasks/cc/vision/pose_detector/BUILD | 53 +++ .../pose_detector/pose_detector_graph.cc | 354 ++++++++++++++++++ .../pose_detector/pose_detector_graph_test.cc | 165 ++++++++ .../tasks/cc/vision/pose_detector/proto/BUILD | 31 ++ .../proto/pose_detector_graph_options.proto | 45 +++ mediapipe/tasks/testdata/vision/BUILD | 5 + .../vision/pose_expected_detection.pbtxt | 27 ++ third_party/external_files.bzl | 26 +- 8 files changed, 699 insertions(+), 7 deletions(-) create mode 100644 mediapipe/tasks/cc/vision/pose_detector/BUILD create mode 100644 mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph.cc create mode 100644 mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph_test.cc create mode 100644 mediapipe/tasks/cc/vision/pose_detector/proto/BUILD create mode 100644 mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.proto create mode 100644 mediapipe/tasks/testdata/vision/pose_expected_detection.pbtxt diff --git a/mediapipe/tasks/cc/vision/pose_detector/BUILD b/mediapipe/tasks/cc/vision/pose_detector/BUILD new file mode 100644 index 000000000..1e361afbe --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_detector/BUILD @@ -0,0 +1,53 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = [ + "//mediapipe/tasks:internal", +]) + +licenses(["notice"]) + +cc_library( + name = "pose_detector_graph", + srcs = ["pose_detector_graph.cc"], + deps = [ + "//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto", + "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:tensors_to_detections_calculator", + "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto", + "//mediapipe/calculators/tflite:ssd_anchors_calculator", + "//mediapipe/calculators/tflite:ssd_anchors_calculator_cc_proto", + "//mediapipe/calculators/util:detection_projection_calculator", + "//mediapipe/calculators/util:detection_transformation_calculator", + "//mediapipe/calculators/util:detections_to_rects_calculator", + "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto", + "//mediapipe/calculators/util:non_max_suppression_calculator", + "//mediapipe/calculators/util:non_max_suppression_calculator_cc_proto", + "//mediapipe/calculators/util:rect_transformation_calculator", + "//mediapipe/calculators/util:rect_transformation_calculator_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:subgraph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph.cc b/mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph.cc new file mode 100644 index 000000000..7c8958b3c --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph.cc @@ -0,0 +1,354 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/status/statusor.h" +#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h" +#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" +#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h" +#include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h" +#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h" +#include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h" +#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/subgraph.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace pose_detector { + +using ::mediapipe::NormalizedRect; +using ::mediapipe::Tensor; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::vision::pose_detector::proto:: + PoseDetectorGraphOptions; + +namespace { +constexpr char kImageTag[] = "IMAGE"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kTensorsTag[] = "TENSORS"; +constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +constexpr char kAnchorsTag[] = "ANCHORS"; +constexpr char kDetectionsTag[] = "DETECTIONS"; +constexpr char kNormRectsTag[] = "NORM_RECTS"; +constexpr char kPixelDetectionsTag[] = "PIXEL_DETECTIONS"; +constexpr char kPoseRectsTag[] = "POSE_RECTS"; +constexpr char kExpandedPoseRectsTag[] = "EXPANDED_POSE_RECTS"; +constexpr char kMatrixTag[] = "MATRIX"; +constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX"; + +struct PoseDetectionOuts { + Source> pose_detections; + Source> pose_rects; + Source> expanded_pose_rects; + Source image; +}; + +// TODO: Configuration detection related calculators in pose +// detector with model metadata. +void ConfigureSsdAnchorsCalculator( + mediapipe::SsdAnchorsCalculatorOptions* options) { + // Dervied from + // mediapipe/modules/pose_detection/pose_detection_gpu.pbtxt + options->set_num_layers(5); + options->set_min_scale(0.1484375); + options->set_max_scale(0.75); + options->set_input_size_height(224); + options->set_input_size_width(224); + options->set_anchor_offset_x(0.5); + options->set_anchor_offset_y(0.5); + options->add_strides(8); + options->add_strides(16); + options->add_strides(32); + options->add_strides(32); + options->add_strides(32); + options->add_aspect_ratios(1.0); + options->set_fixed_anchor_size(true); +} + +// TODO: Configuration detection related calculators in pose +// detector with model metadata. +void ConfigureTensorsToDetectionsCalculator( + const PoseDetectorGraphOptions& tasks_options, + mediapipe::TensorsToDetectionsCalculatorOptions* options) { + // Dervied from + // mediapipe/modules/pose_detection/pose_detection_gpu.pbtxt + options->set_num_classes(1); + options->set_num_boxes(2254); + options->set_num_coords(12); + options->set_box_coord_offset(0); + options->set_keypoint_coord_offset(4); + options->set_num_keypoints(4); + options->set_num_values_per_keypoint(2); + options->set_sigmoid_score(true); + options->set_score_clipping_thresh(100.0); + options->set_reverse_output_order(true); + options->set_min_score_thresh(tasks_options.min_detection_confidence()); + options->set_x_scale(224.0); + options->set_y_scale(224.0); + options->set_w_scale(224.0); + options->set_h_scale(224.0); +} + +void ConfigureNonMaxSuppressionCalculator( + const PoseDetectorGraphOptions& tasks_options, + mediapipe::NonMaxSuppressionCalculatorOptions* options) { + options->set_min_suppression_threshold( + tasks_options.min_suppression_threshold()); + options->set_overlap_type( + mediapipe::NonMaxSuppressionCalculatorOptions::INTERSECTION_OVER_UNION); + options->set_algorithm( + mediapipe::NonMaxSuppressionCalculatorOptions::WEIGHTED); +} + +// TODO: Configuration detection related calculators in pose +// detector with model metadata. +void ConfigureDetectionsToRectsCalculator( + mediapipe::DetectionsToRectsCalculatorOptions* options) { + options->set_rotation_vector_start_keypoint_index(0); + options->set_rotation_vector_end_keypoint_index(2); + options->set_rotation_vector_target_angle(90); + options->set_output_zero_rect_for_empty_detections(true); +} + +// TODO: Configuration detection related calculators in pose +// detector with model metadata. +void ConfigureRectTransformationCalculator( + mediapipe::RectTransformationCalculatorOptions* options) { + options->set_scale_x(2.6); + options->set_scale_y(2.6); + options->set_shift_y(-0.5); + options->set_square_long(true); +} + +} // namespace + +// A "mediapipe.tasks.vision.pose_detector.PoseDetectorGraph" performs pose +// detection. +// +// Inputs: +// IMAGE - Image +// Image to perform detection on. +// NORM_RECT - NormalizedRect @Optional +// Describes image rotation and region of image to perform detection on. If +// not provided, whole image is used for pose detection. +// +// Outputs: +// DETECTIONS - std::vector +// Detected pose with maximum `num_poses` specified in options. +// POSE_RECTS - std::vector +// Detected pose bounding boxes in normalized coordinates. +// EXPANDED_POSE_RECTS - std::vector +// Expanded pose bounding boxes in normalized coordinates so that bounding +// boxes likely contain the whole pose. This is usually used as RoI for pose +// landmarks detection to run on. +// IMAGE - Image +// The input image that the pose detector runs on and has the pixel data +// stored on the target storage (CPU vs GPU). +// All returned coordinates are in the unrotated and uncropped input image +// coordinates system. +// +// Example: +// node { +// calculator: "mediapipe.tasks.vision.pose_detector.PoseDetectorGraph" +// input_stream: "IMAGE:image" +// input_stream: "NORM_RECT:norm_rect" +// output_stream: "DETECTIONS:palm_detections" +// output_stream: "POSE_RECTS:pose_rects" +// output_stream: "EXPANDED_POSE_RECTS:expanded_pose_rects" +// output_stream: "IMAGE:image_out" +// options { +// [mediapipe.tasks.vision.pose_detector.proto.PoseDetectorGraphOptions.ext] +// { +// base_options { +// model_asset { +// file_name: "pose_detection.tflite" +// } +// } +// min_detection_confidence: 0.5 +// num_poses: 2 +// } +// } +// } +class PoseDetectorGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + ASSIGN_OR_RETURN(const auto* model_resources, + CreateModelResources(sc)); + Graph graph; + ASSIGN_OR_RETURN(auto outs, + BuildPoseDetectionSubgraph( + sc->Options(), + *model_resources, graph[Input(kImageTag)], + graph[Input(kNormRectTag)], graph)); + + outs.pose_detections >> + graph.Out(kDetectionsTag).Cast>(); + outs.pose_rects >> + graph.Out(kPoseRectsTag).Cast>(); + outs.expanded_pose_rects >> + graph.Out(kExpandedPoseRectsTag).Cast>(); + outs.image >> graph.Out(kImageTag).Cast(); + + return graph.GetConfig(); + } + + private: + absl::StatusOr BuildPoseDetectionSubgraph( + const PoseDetectorGraphOptions& subgraph_options, + const core::ModelResources& model_resources, Source image_in, + Source norm_rect_in, Graph& graph) { + // Image preprocessing subgraph to convert image to tensor for the tflite + // model. + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + subgraph_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( + model_resources, use_gpu, + &preprocessing.GetOptions< + components::processors::proto::ImagePreprocessingGraphOptions>())); + auto& image_to_tensor_options = + *preprocessing + .GetOptions() + .mutable_image_to_tensor_options(); + image_to_tensor_options.set_keep_aspect_ratio(true); + image_to_tensor_options.set_border_mode( + mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO); + image_in >> preprocessing.In(kImageTag); + norm_rect_in >> preprocessing.In(kNormRectTag); + auto preprocessed_tensors = preprocessing.Out(kTensorsTag); + auto matrix = preprocessing.Out(kMatrixTag); + auto image_size = preprocessing.Out(kImageSizeTag); + + // Pose detection model inferece. + auto& inference = AddInference( + model_resources, subgraph_options.base_options().acceleration(), graph); + preprocessed_tensors >> inference.In(kTensorsTag); + auto model_output_tensors = + inference.Out(kTensorsTag).Cast>(); + + // Generates a single side packet containing a vector of SSD anchors. + auto& ssd_anchor = graph.AddNode("SsdAnchorsCalculator"); + ConfigureSsdAnchorsCalculator( + &ssd_anchor.GetOptions()); + auto anchors = ssd_anchor.SideOut(""); + + // Converts output tensors to Detections. + auto& tensors_to_detections = + graph.AddNode("TensorsToDetectionsCalculator"); + ConfigureTensorsToDetectionsCalculator( + subgraph_options, + &tensors_to_detections + .GetOptions()); + model_output_tensors >> tensors_to_detections.In(kTensorsTag); + anchors >> tensors_to_detections.SideIn(kAnchorsTag); + auto detections = tensors_to_detections.Out(kDetectionsTag); + + // Non maximum suppression removes redundant face detections. + auto& non_maximum_suppression = + graph.AddNode("NonMaxSuppressionCalculator"); + ConfigureNonMaxSuppressionCalculator( + subgraph_options, + &non_maximum_suppression + .GetOptions()); + detections >> non_maximum_suppression.In(""); + auto nms_detections = non_maximum_suppression.Out(""); + + // Projects detections back into the input image coordinates system. + auto& detection_projection = graph.AddNode("DetectionProjectionCalculator"); + nms_detections >> detection_projection.In(kDetectionsTag); + matrix >> detection_projection.In(kProjectionMatrixTag); + Source> pose_detections = + detection_projection.Out(kDetectionsTag).Cast>(); + + if (subgraph_options.has_num_poses()) { + // Clip face detections to maximum number of poses. + auto& clip_detection_vector_size = + graph.AddNode("ClipDetectionVectorSizeCalculator"); + clip_detection_vector_size + .GetOptions() + .set_max_vec_size(subgraph_options.num_poses()); + pose_detections >> clip_detection_vector_size.In(""); + pose_detections = + clip_detection_vector_size.Out("").Cast>(); + } + + // Converts results of pose detection into a rectangle (normalized by image + // size) that encloses the face and is rotated such that the line connecting + // left eye and right eye is aligned with the X-axis of the rectangle. + auto& detections_to_rects = graph.AddNode("DetectionsToRectsCalculator"); + ConfigureDetectionsToRectsCalculator( + &detections_to_rects + .GetOptions()); + image_size >> detections_to_rects.In(kImageSizeTag); + pose_detections >> detections_to_rects.In(kDetectionsTag); + auto pose_rects = detections_to_rects.Out(kNormRectsTag) + .Cast>(); + + // Expands and shifts the rectangle that contains the pose so that it's + // likely to cover the entire pose. + auto& rect_transformation = graph.AddNode("RectTransformationCalculator"); + ConfigureRectTransformationCalculator( + &rect_transformation + .GetOptions()); + pose_rects >> rect_transformation.In(kNormRectsTag); + image_size >> rect_transformation.In(kImageSizeTag); + auto expanded_pose_rects = + rect_transformation.Out("").Cast>(); + + // Calculator to convert relative detection bounding boxes to pixel + // detection bounding boxes. + auto& detection_transformation = + graph.AddNode("DetectionTransformationCalculator"); + detection_projection.Out(kDetectionsTag) >> + detection_transformation.In(kDetectionsTag); + preprocessing.Out(kImageSizeTag) >> + detection_transformation.In(kImageSizeTag); + auto pose_pixel_detections = + detection_transformation.Out(kPixelDetectionsTag) + .Cast>(); + + return PoseDetectionOuts{ + /* pose_detections= */ pose_pixel_detections, + /* pose_rects= */ pose_rects, + /* expanded_pose_rects= */ expanded_pose_rects, + /* image= */ preprocessing.Out(kImageTag).Cast()}; + } +}; + +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::pose_detector::PoseDetectorGraph); + +} // namespace pose_detector +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph_test.cc b/mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph_test.cc new file mode 100644 index 000000000..5706bd9d7 --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph_test.cc @@ -0,0 +1,165 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "mediapipe/calculators/tensor/inference_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace pose_detector { +namespace { + +using ::file::Defaults; +using ::file::GetTextProto; +using ::mediapipe::NormalizedRect; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::core::TaskRunner; +using ::mediapipe::tasks::vision::pose_detector::proto:: + PoseDetectorGraphOptions; +using ::testing::EqualsProto; +using ::testing::Pointwise; +using ::testing::TestParamInfo; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::proto::Approximately; +using ::testing::proto::Partially; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kPoseDetectionModel[] = "pose_detection.tflite"; +constexpr char kPortraitImage[] = "pose.jpg"; +constexpr char kPoseExpectedDetection[] = "pose_expected_detection.pbtxt"; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageName[] = "image"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kNormRectName[] = "norm_rect"; +constexpr char kDetectionsTag[] = "DETECTIONS"; +constexpr char kDetectionsName[] = "detections"; + +constexpr float kPoseDetectionMaxDiff = 0.01; + +// Helper function to create a TaskRunner. +absl::StatusOr> CreateTaskRunner( + absl::string_view model_name) { + Graph graph; + + auto& pose_detector_graph = + graph.AddNode("mediapipe.tasks.vision.pose_detector.PoseDetectorGraph"); + + auto options = std::make_unique(); + options->mutable_base_options()->mutable_model_asset()->set_file_name( + JoinPath("./", kTestDataDirectory, model_name)); + options->set_min_detection_confidence(0.6); + options->set_min_suppression_threshold(0.3); + pose_detector_graph.GetOptions().Swap( + options.get()); + + graph[Input(kImageTag)].SetName(kImageName) >> + pose_detector_graph.In(kImageTag); + graph[Input(kNormRectTag)].SetName(kNormRectName) >> + pose_detector_graph.In(kNormRectTag); + + pose_detector_graph.Out(kDetectionsTag).SetName(kDetectionsName) >> + graph[Output>(kDetectionsTag)]; + + return TaskRunner::Create( + graph.GetConfig(), std::make_unique()); +} + +Detection GetExpectedPoseDetectionResult(absl::string_view file_name) { + Detection detection; + CHECK_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, file_name), + &detection, Defaults())) + << "Expected pose detection result does not exist."; + return detection; +} + +struct TestParams { + // The name of this test, for convenience when displaying test results. + std::string test_name; + // The filename of pose landmark detection model. + std::string pose_detection_model_name; + // The filename of test image. + std::string test_image_name; + // Expected pose detection results. + std::vector expected_result; +}; + +class PoseDetectorGraphTest : public testing::TestWithParam {}; + +TEST_P(PoseDetectorGraphTest, Succeed) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + GetParam().test_image_name))); + NormalizedRect input_norm_rect; + input_norm_rect.set_x_center(0.5); + input_norm_rect.set_y_center(0.5); + input_norm_rect.set_width(1.0); + input_norm_rect.set_height(1.0); + MP_ASSERT_OK_AND_ASSIGN( + auto task_runner, CreateTaskRunner(GetParam().pose_detection_model_name)); + auto output_packets = task_runner->Process( + {{kImageName, MakePacket(std::move(image))}, + {kNormRectName, + MakePacket(std::move(input_norm_rect))}}); + MP_ASSERT_OK(output_packets); + const std::vector& pose_detections = + (*output_packets)[kDetectionsName].Get>(); + EXPECT_THAT(pose_detections, Pointwise(Approximately(Partially(EqualsProto()), + kPoseDetectionMaxDiff), + GetParam().expected_result)); +} + +INSTANTIATE_TEST_SUITE_P( + PoseDetectorGraphTest, PoseDetectorGraphTest, + Values(TestParams{.test_name = "DetectPose", + .pose_detection_model_name = kPoseDetectionModel, + .test_image_name = kPortraitImage, + .expected_result = {GetExpectedPoseDetectionResult( + kPoseExpectedDetection)}}), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +} // namespace +} // namespace pose_detector +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/pose_detector/proto/BUILD b/mediapipe/tasks/cc/vision/pose_detector/proto/BUILD new file mode 100644 index 000000000..287ed0183 --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_detector/proto/BUILD @@ -0,0 +1,31 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +package(default_visibility = [ + "//mediapipe/tasks:internal", +]) + +licenses(["notice"]) + +mediapipe_proto_library( + name = "pose_detector_graph_options_proto", + srcs = ["pose_detector_graph_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) diff --git a/mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.proto b/mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.proto new file mode 100644 index 000000000..693f95262 --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.proto @@ -0,0 +1,45 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.vision.pose_detector.proto; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; +import "mediapipe/tasks/cc/core/proto/base_options.proto"; + +option java_package = "com.google.mediapipe.tasks.vision.posedetector.proto"; +option java_outer_classname = "PoseDetectorGraphOptionsProto"; + +message PoseDetectorGraphOptions { + extend mediapipe.CalculatorOptions { + optional PoseDetectorGraphOptions ext = 514774813; + } + // Base options for configuring Task library, such as specifying the TfLite + // model file with metadata, accelerator options, etc. + optional core.proto.BaseOptions base_options = 1; + + // Minimum confidence value ([0.0, 1.0]) for confidence score to be considered + // successfully detecting a pose in the image. + optional float min_detection_confidence = 2 [default = 0.5]; + + // IoU threshold ([0,0, 1.0]) for non-maximu-suppression to be considered + // duplicate detections. + optional float min_suppression_threshold = 3 [default = 0.5]; + + // Maximum number of poses to detect in the image. + optional int32 num_poses = 4; +} diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index 888da6ea1..087d0ea75 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -70,6 +70,8 @@ mediapipe_files(srcs = [ "portrait.jpg", "portrait_hair_expected_mask.jpg", "portrait_rotated.jpg", + "pose.jpg", + "pose_detection.tflite", "right_hands.jpg", "right_hands_rotated.jpg", "segmentation_golden_rotation0.png", @@ -127,6 +129,7 @@ filegroup( "portrait.jpg", "portrait_hair_expected_mask.jpg", "portrait_rotated.jpg", + "pose.jpg", "right_hands.jpg", "right_hands_rotated.jpg", "segmentation_golden_rotation0.png", @@ -171,6 +174,7 @@ filegroup( "mobilenet_v2_1.0_224.tflite", "mobilenet_v3_small_100_224_embedder.tflite", "palm_detection_full.tflite", + "pose_detection.tflite", "selfie_segm_128_128_3.tflite", "selfie_segm_144_256_3.tflite", "selfie_segmentation.tflite", @@ -199,6 +203,7 @@ filegroup( "portrait_expected_face_landmarks.pbtxt", "portrait_expected_face_landmarks_with_attention.pbtxt", "portrait_rotated_expected_detection.pbtxt", + "pose_expected_detection.pbtxt", "thumb_up_landmarks.pbtxt", "thumb_up_rotated_landmarks.pbtxt", "victory_landmarks.pbtxt", diff --git a/mediapipe/tasks/testdata/vision/pose_expected_detection.pbtxt b/mediapipe/tasks/testdata/vision/pose_expected_detection.pbtxt new file mode 100644 index 000000000..411a374d4 --- /dev/null +++ b/mediapipe/tasks/testdata/vision/pose_expected_detection.pbtxt @@ -0,0 +1,27 @@ +# proto-file: mediapipe/framework/formats/detection.proto +# proto-message: Detection +location_data { + format: BOUNDING_BOX + bounding_box { + xmin: 397 + ymin: 198 + width: 199 + height: 199 + } + relative_keypoints { + x: 0.4879558 + y: 0.7013345 + } + relative_keypoints { + x: 0.48453212 + y: 0.32265592 + } + relative_keypoints { + x: 0.4992165 + y: 0.4854874 + } + relative_keypoints { + x: 0.50227845 + y: 0.159788 + } +} diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 91bd7c756..f6c2e8ce1 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -72,8 +72,8 @@ def external_files(): http_file( name = "com_google_mediapipe_BUILD_orig", - sha256 = "64d5343a6a5f9be06db0a5074a2260f9ae63a989fe01702832cd215680dc19c1", - urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD.orig?generation=1678323576393653"], + sha256 = "d86b98b82e00dd87cd46bd1429bf5eaa007b500c1a24d9316b73309f2e6c8df8", + urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD.orig?generation=1678737479599640"], ) http_file( @@ -823,7 +823,7 @@ def external_files(): http_file( name = "com_google_mediapipe_portrait_expected_face_geometry_with_attention_pbtxt", sha256 = "7ed1eed98e61e0a10811bb611c895d87c8023f398a36db01b6d9ba2e1ab09e16", - urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_geometry_with_attention.pbtxt?generation=1678505004840652"], + urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_geometry_with_attention.pbtxt?generation=1678737486927530"], ) http_file( @@ -864,8 +864,20 @@ def external_files(): http_file( name = "com_google_mediapipe_pose_detection_tflite", - sha256 = "a63c614bef30d35947f13be361820b1e4e3bec9cfeebf4d11216a18373108e85", - urls = ["https://storage.googleapis.com/mediapipe-assets/pose_detection.tflite?generation=1661875889147923"], + sha256 = "9ba9dd3d42efaaba86b4ff0122b06f29c4122e756b329d89dca1e297fd8f866c", + urls = ["https://storage.googleapis.com/mediapipe-assets/pose_detection.tflite?generation=1678737489600422"], + ) + + http_file( + name = "com_google_mediapipe_pose_expected_detection_pbtxt", + sha256 = "e0d40e98dd5320a780a642c336d0c8720243ac5bcc0e39c4061ad970a503ae24", + urls = ["https://storage.googleapis.com/mediapipe-assets/pose_expected_detection.pbtxt?generation=1678737492211540"], + ) + + http_file( + name = "com_google_mediapipe_pose_jpg", + sha256 = "c8a830ed683c0276d713dd5aeda28f415f10cd6291972084a40d0d8b934ed62b", + urls = ["https://storage.googleapis.com/mediapipe-assets/pose.jpg?generation=1678737494661975"], ) http_file( @@ -1224,8 +1236,8 @@ def external_files(): http_file( name = "com_google_mediapipe_object_detection_saved_model_README_md", - sha256 = "fe163cf12fbd017738a2fd360c03d223e964ba6404ac75c635f5918784e9c34d", - urls = ["https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md?generation=1661875995856372"], + sha256 = "acc23dee09f69210717ac060035c844ba902e8271486f1086f29fb156c236690", + urls = ["https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md?generation=1678737498915254"], ) http_file(