Internal change

PiperOrigin-RevId: 520717805
This commit is contained in:
MediaPipe Team 2023-03-30 12:50:55 -07:00 committed by Copybara-Service
parent 99ba7dd787
commit d43579fe3e
8 changed files with 1523 additions and 2 deletions

View File

@ -0,0 +1,61 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = [
"//mediapipe/tasks:internal",
])
licenses(["notice"])
cc_library(
name = "pose_landmarks_detector_graph",
srcs = ["pose_landmarks_detector_graph.cc"],
deps = [
"//mediapipe/calculators/core:begin_loop_calculator",
"//mediapipe/calculators/core:end_loop_calculator",
"//mediapipe/calculators/core:gate_calculator",
"//mediapipe/calculators/core:split_proto_list_calculator",
"//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
"//mediapipe/calculators/tensor:inference_calculator",
"//mediapipe/calculators/tensor:tensors_to_floats_calculator",
"//mediapipe/calculators/tensor:tensors_to_landmarks_calculator",
"//mediapipe/calculators/tensor:tensors_to_landmarks_calculator_cc_proto",
"//mediapipe/calculators/tensor:tensors_to_segmentation_calculator",
"//mediapipe/calculators/tensor:tensors_to_segmentation_calculator_cc_proto",
"//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/calculators/util:landmarks_to_detection_calculator",
"//mediapipe/calculators/util:rect_transformation_calculator",
"//mediapipe/calculators/util:refine_landmarks_from_heatmap_calculator",
"//mediapipe/calculators/util:refine_landmarks_from_heatmap_calculator_cc_proto",
"//mediapipe/calculators/util:thresholding_calculator",
"//mediapipe/calculators/util:thresholding_calculator_cc_proto",
"//mediapipe/calculators/util:visibility_copy_calculator",
"//mediapipe/calculators/util:visibility_copy_calculator_cc_proto",
"//mediapipe/framework:subgraph",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/gpu:gpu_origin_cc_proto",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"@com_google_absl//absl/status:statusor",
],
)

View File

@ -0,0 +1,641 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/status/statusor.h"
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_landmarks_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.pb.h"
#include "mediapipe/calculators/util/thresholding_calculator.pb.h"
#include "mediapipe/calculators/util/visibility_copy_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/subgraph.h"
#include "mediapipe/gpu/gpu_origin.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace pose_landmarker {
using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::vision::pose_landmarker::proto::
PoseLandmarksDetectorGraphOptions;
constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS";
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kAuxLandmarksTag[] = "AUXILIARY_LANDMARKS";
constexpr char kPoseRectNextFrameTag[] = "POSE_RECT_NEXT_FRAME";
constexpr char kPoseRectsNextFrameTag[] = "POSE_RECTS_NEXT_FRAME";
constexpr char kPresenceTag[] = "PRESENCE";
constexpr char kPresenceScoreTag[] = "PRESENCE_SCORE";
constexpr char kSegmentationMaskTag[] = "SEGMENTATION_MASK";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kLandmarksToTag[] = "LANDMARKS_TO";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kFloatTag[] = "FLOAT";
constexpr char kFlagTag[] = "FLAG";
constexpr char kMaskTag[] = "MASK";
constexpr char kDetectionTag[] = "DETECTION";
constexpr char kNormLandmarksFromTag[] = "NORM_LANDMARKS_FROM";
constexpr char kBatchEndTag[] = "BATCH_END";
constexpr char kItemTag[] = "ITEM";
constexpr char kIterableTag[] = "ITERABLE";
constexpr int kModelOutputTensorSplitNum = 5;
constexpr int kLandmarksNum = 39;
constexpr float kLandmarksNormalizeZ = 0.4;
struct SinglePoseLandmarkerOutputs {
Source<NormalizedLandmarkList> pose_landmarks;
Source<LandmarkList> world_pose_landmarks;
Source<NormalizedLandmarkList> auxiliary_pose_landmarks;
Source<NormalizedRect> pose_rect_next_frame;
Source<bool> pose_presence;
Source<float> pose_presence_score;
Source<Image> segmentation_mask;
};
struct PoseLandmarkerOutputs {
Source<std::vector<NormalizedLandmarkList>> landmark_lists;
Source<std::vector<LandmarkList>> world_landmark_lists;
Source<std::vector<NormalizedLandmarkList>> auxiliary_landmark_lists;
Source<std::vector<NormalizedRect>> pose_rects_next_frame;
Source<std::vector<bool>> presences;
Source<std::vector<float>> presence_scores;
Source<std::vector<Image>> segmentation_masks;
};
absl::Status SanityCheckOptions(
const PoseLandmarksDetectorGraphOptions& options) {
if (options.min_detection_confidence() < 0 ||
options.min_detection_confidence() > 1) {
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
"Invalid `min_detection_confidence` option: "
"value must be in the range [0.0, 1.0]",
MediaPipeTasksStatus::kInvalidArgumentError);
}
return absl::OkStatus();
}
// Split pose landmark detection model output tensor into five parts,
// representing landmarks, presence scores, segmentation, heatmap, and world
// landmarks respectively.
void ConfigureSplitTensorVectorCalculator(
mediapipe::SplitVectorCalculatorOptions* options) {
for (int i = 0; i < kModelOutputTensorSplitNum; ++i) {
auto* range = options->add_ranges();
range->set_begin(i);
range->set_end(i + 1);
}
}
void ConfigureTensorsToLandmarksCalculator(
const ImageTensorSpecs& input_image_tensor_spec, bool normalize,
bool sigmoid_activation,
mediapipe::TensorsToLandmarksCalculatorOptions* options) {
options->set_num_landmarks(kLandmarksNum);
options->set_input_image_height(input_image_tensor_spec.image_height);
options->set_input_image_width(input_image_tensor_spec.image_width);
if (normalize) {
options->set_normalize_z(kLandmarksNormalizeZ);
}
if (sigmoid_activation) {
options->set_visibility_activation(
mediapipe::TensorsToLandmarksCalculatorOptions_Activation_SIGMOID);
options->set_presence_activation(
mediapipe::TensorsToLandmarksCalculatorOptions_Activation_SIGMOID);
}
}
void ConfigureTensorsToSegmentationCalculator(
mediapipe::TensorsToSegmentationCalculatorOptions* options) {
options->set_activation(
mediapipe::TensorsToSegmentationCalculatorOptions_Activation_SIGMOID);
options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT);
}
void ConfigureRefineLandmarksFromHeatmapCalculator(
mediapipe::RefineLandmarksFromHeatmapCalculatorOptions* options) {
// Derived from
// mediapipe/modules/pose_landmark/tensors_to_pose_landmarks_and_segmentation.pbtxt.
options->set_kernel_size(7);
}
void ConfigureSplitNormalizedLandmarkListCalculator(
mediapipe::SplitVectorCalculatorOptions* options) {
// Derived from
// mediapipe/modules/pose_landmark/tensors_to_pose_landmarks_and_segmentation.pbtxt
auto* range = options->add_ranges();
range->set_begin(0);
range->set_end(33);
auto* range_2 = options->add_ranges();
range_2->set_begin(33);
range_2->set_end(35);
}
void ConfigureSplitLandmarkListCalculator(
mediapipe::SplitVectorCalculatorOptions* options) {
// Derived from
// mediapipe/modules/pose_landmark/tensors_to_pose_landmarks_and_segmentation.pbtxt
auto* range = options->add_ranges();
range->set_begin(0);
range->set_end(33);
}
void ConfigureVisibilityCopyCalculator(
mediapipe::VisibilityCopyCalculatorOptions* options) {
// Derived from
// mediapipe/modules/pose_landmark/tensors_to_pose_landmarks_and_segmentation.pbtxt
options->set_copy_visibility(true);
options->set_copy_presence(true);
}
// A "mediapipe.tasks.vision.pose_landmarker.SinglePoseLandmarksDetectorGraph"
// performs pose landmarks detection.
// - Accepts CPU input images and outputs Landmark on CPU.
//
// Inputs:
// IMAGE - Image
// Image to perform detection on.
// NORM_RECT - NormalizedRect @Optional
// Rect enclosing the RoI to perform detection on. If not set, the detection
// RoI is the whole image.
//
//
// Outputs:
// LANDMARKS: - NormalizedLandmarkList
// Detected pose landmarks.
// WORLD_LANDMARKS - LandmarkList
// Detected pose landmarks in world coordinates.
// AUXILIARY_LANDMARKS - NormalizedLandmarkList
// Detected pose auxiliary landmarks.
// POSE_RECT_NEXT_FRAME - NormalizedRect
// The predicted Rect enclosing the pose RoI for landmark detection on the
// next frame.
// PRESENCE - bool
// Boolean value indicates whether the pose is present.
// PRESENCE_SCORE - float
// Float value indicates the probability that the pose is present.
// SEGMENTATION_MASK - Image
// Segmentation mask for pose.
//
// Example:
// node {
// calculator:
// "mediapipe.tasks.vision.pose_landmarker.SingleposeLandmarksDetectorGraph"
// input_stream: "IMAGE:input_image"
// input_stream: "POSE_RECT:pose_rect"
// output_stream: "LANDMARKS:pose_landmarks"
// output_stream: "WORLD_LANDMARKS:world_pose_landmarks"
// output_stream: "AUXILIARY_LANDMARKS:auxiliary_landmarks"
// output_stream: "POSE_RECT_NEXT_FRAME:pose_rect_next_frame"
// output_stream: "PRESENCE:pose_presence"
// output_stream: "PRESENCE_SCORE:pose_presence_score"
// output_stream: "SEGMENTATION_MASK:segmentation_mask"
// options {
// [mediapipe.tasks.vision.pose_landmarker.proto.poseLandmarksDetectorGraphOptions.ext]
// {
// base_options {
// model_asset {
// file_name: "pose_landmark_lite.tflite"
// }
// }
// min_detection_confidence: 0.5
// }
// }
// }
class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
ASSIGN_OR_RETURN(
const auto* model_resources,
CreateModelResources<PoseLandmarksDetectorGraphOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(
auto pose_landmark_detection_outs,
BuildSinglePoseLandmarksDetectorGraph(
sc->Options<PoseLandmarksDetectorGraphOptions>(), *model_resources,
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
pose_landmark_detection_outs.pose_landmarks >>
graph[Output<NormalizedLandmarkList>(kLandmarksTag)];
pose_landmark_detection_outs.world_pose_landmarks >>
graph[Output<LandmarkList>(kWorldLandmarksTag)];
pose_landmark_detection_outs.auxiliary_pose_landmarks >>
graph[Output<NormalizedLandmarkList>(kAuxLandmarksTag)];
pose_landmark_detection_outs.pose_rect_next_frame >>
graph[Output<NormalizedRect>(kPoseRectNextFrameTag)];
pose_landmark_detection_outs.pose_presence >>
graph[Output<bool>(kPresenceTag)];
pose_landmark_detection_outs.pose_presence_score >>
graph[Output<float>(kPresenceScoreTag)];
pose_landmark_detection_outs.segmentation_mask >>
graph[Output<Image>(kSegmentationMaskTag)];
return graph.GetConfig();
}
private:
absl::StatusOr<SinglePoseLandmarkerOutputs>
BuildSinglePoseLandmarksDetectorGraph(
const PoseLandmarksDetectorGraphOptions& subgraph_options,
const ModelResources& model_resources, Source<Image> image_in,
Source<NormalizedRect> pose_rect, Graph& graph) {
MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options));
auto& preprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.ImagePreprocessingGraph");
bool use_gpu =
components::processors::DetermineImagePreprocessingGpuBackend(
subgraph_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu,
&preprocessing.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>()));
image_in >> preprocessing.In(kImageTag);
pose_rect >> preprocessing.In(kNormRectTag);
auto image_size = preprocessing[Output<std::pair<int, int>>(kImageSizeTag)];
ASSIGN_OR_RETURN(auto image_tensor_specs,
BuildInputImageTensorSpecs(model_resources));
auto& inference = AddInference(
model_resources, subgraph_options.base_options().acceleration(), graph);
preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag);
// Split model output tensors to multiple streams.
auto& split_tensors_vector = graph.AddNode("SplitTensorVectorCalculator");
ConfigureSplitTensorVectorCalculator(
&split_tensors_vector
.GetOptions<mediapipe::SplitVectorCalculatorOptions>());
inference.Out(kTensorsTag) >> split_tensors_vector.In("");
auto landmark_tensors = split_tensors_vector.Out(0);
auto pose_flag_tensors = split_tensors_vector.Out(1);
auto segmentation_tensors = split_tensors_vector.Out(2);
auto heatmap_tensors = split_tensors_vector.Out(3);
auto world_landmark_tensors = split_tensors_vector.Out(4);
// Converts the pose-flag tensor into a float that represents the confidence
// score of pose presence.
auto& tensors_to_pose_presence = graph.AddNode("TensorsToFloatsCalculator");
pose_flag_tensors >> tensors_to_pose_presence.In(kTensorsTag);
auto pose_presence_score =
tensors_to_pose_presence[Output<float>(kFloatTag)];
// Applies a threshold to the confidence score to determine whether a
// pose is present.
auto& pose_presence_thresholding = graph.AddNode("ThresholdingCalculator");
pose_presence_thresholding
.GetOptions<mediapipe::ThresholdingCalculatorOptions>()
.set_threshold(subgraph_options.min_detection_confidence());
pose_presence_score >> pose_presence_thresholding.In(kFloatTag);
auto pose_presence = pose_presence_thresholding[Output<bool>(kFlagTag)];
// GateCalculator for tensors.
auto& tensors_gate = graph.AddNode("GateCalculator");
landmark_tensors >> tensors_gate.In("")[0];
segmentation_tensors >> tensors_gate.In("")[1];
heatmap_tensors >> tensors_gate.In("")[2];
world_landmark_tensors >> tensors_gate.In("")[3];
pose_presence >> tensors_gate.In("ALLOW");
auto ensured_landmarks_tensors = tensors_gate.Out(0);
auto ensured_segmentation_tensors = tensors_gate.Out(1);
auto ensured_heatmap_tensors = tensors_gate.Out(2);
auto ensured_world_landmark_tensors = tensors_gate.Out(3);
// Decodes the landmark tensors into a list of landmarks, where the landmark
// coordinates are normalized by the size of the input image to the model.
auto& tensors_to_landmarks = graph.AddNode("TensorsToLandmarksCalculator");
ConfigureTensorsToLandmarksCalculator(
image_tensor_specs, /* normalize = */ false,
/*sigmoid_activation= */ true,
&tensors_to_landmarks
.GetOptions<mediapipe::TensorsToLandmarksCalculatorOptions>());
ensured_landmarks_tensors >> tensors_to_landmarks.In(kTensorsTag);
auto landmarks =
tensors_to_landmarks[Output<NormalizedLandmarkList>(kNormLandmarksTag)];
// Decodes the segmentation tensor into a mask image with pixel values in
// [0, 1] (1 for person and 0 for background).
auto& tensors_to_segmentation =
graph.AddNode("TensorsToSegmentationCalculator");
ConfigureTensorsToSegmentationCalculator(
&tensors_to_segmentation
.GetOptions<mediapipe::TensorsToSegmentationCalculatorOptions>());
ensured_segmentation_tensors >> tensors_to_segmentation.In(kTensorsTag);
auto segmentation_mask = tensors_to_segmentation[Output<Image>(kMaskTag)];
// Refines landmarks with the heatmap tensor.
auto& refine_landmarks_from_heatmap =
graph.AddNode("RefineLandmarksFromHeatmapCalculator");
ConfigureRefineLandmarksFromHeatmapCalculator(
&refine_landmarks_from_heatmap.GetOptions<
mediapipe::RefineLandmarksFromHeatmapCalculatorOptions>());
ensured_heatmap_tensors >> refine_landmarks_from_heatmap.In(kTensorsTag);
landmarks >> refine_landmarks_from_heatmap.In(kNormLandmarksTag);
auto landmarks_from_heatmap =
refine_landmarks_from_heatmap[Output<NormalizedLandmarkList>(
kNormLandmarksTag)];
// Splits the landmarks into two sets: the actual pose landmarks and the
// auxiliary landmarks.
auto& split_normalized_landmark_list =
graph.AddNode("SplitNormalizedLandmarkListCalculator");
ConfigureSplitNormalizedLandmarkListCalculator(
&split_normalized_landmark_list
.GetOptions<mediapipe::SplitVectorCalculatorOptions>());
landmarks_from_heatmap >> split_normalized_landmark_list.In("");
auto normalized_landmarks = split_normalized_landmark_list.Out("")[0]
.Cast<NormalizedLandmarkList>();
auto normalized_auxiliary_landmarks =
split_normalized_landmark_list.Out("")[1]
.Cast<NormalizedLandmarkList>();
// Decodes the world-landmark tensors into a vector of world landmarks.
auto& tensors_to_world_landmarks =
graph.AddNode("TensorsToLandmarksCalculator");
ConfigureTensorsToLandmarksCalculator(
image_tensor_specs, /* normalize = */ false,
/* sigmoid_activation= */ false,
&tensors_to_world_landmarks
.GetOptions<mediapipe::TensorsToLandmarksCalculatorOptions>());
ensured_world_landmark_tensors >>
tensors_to_world_landmarks.In(kTensorsTag);
auto world_landmarks =
tensors_to_world_landmarks[Output<LandmarkList>(kLandmarksTag)];
// Keeps only the actual world landmarks.
auto& split_landmark_list = graph.AddNode("SplitLandmarkListCalculator");
ConfigureSplitLandmarkListCalculator(
&split_landmark_list
.GetOptions<mediapipe::SplitVectorCalculatorOptions>());
world_landmarks >> split_landmark_list.In("");
auto split_landmarks = split_landmark_list.Out(0);
// Reuses the visibility and presence field in pose landmarks for the world
// landmarks.
auto& visibility_copy = graph.AddNode("VisibilityCopyCalculator");
ConfigureVisibilityCopyCalculator(
&visibility_copy
.GetOptions<mediapipe::VisibilityCopyCalculatorOptions>());
split_landmarks >> visibility_copy.In(kLandmarksToTag);
normalized_landmarks >> visibility_copy.In(kNormLandmarksFromTag);
auto world_landmarks_with_visibility =
visibility_copy[Output<LandmarkList>(kLandmarksToTag)];
// Landmarks to Detections.
auto& landmarks_to_detection =
graph.AddNode("LandmarksToDetectionCalculator");
landmarks >> landmarks_to_detection.In(kNormLandmarksTag);
auto detection = landmarks_to_detection.Out(kDetectionTag);
// Detections to Rects.
auto& detection_to_rects = graph.AddNode("DetectionsToRectsCalculator");
image_size >> detection_to_rects.In(kImageSizeTag);
detection >> detection_to_rects.In(kDetectionTag);
auto norm_rect = detection_to_rects.Out(kNormRectTag);
// Expands the pose rectangle so that in the next video frame it's likely to
// still contain the pose even with some motion.
auto& pose_rect_transformation =
graph.AddNode("RectTransformationCalculator");
image_size >> pose_rect_transformation.In(kImageSizeTag);
norm_rect >> pose_rect_transformation.In(kNormRectTag);
auto pose_rect_next_frame =
pose_rect_transformation[Output<NormalizedRect>("")];
return {{
/* pose_landmarks= */ normalized_landmarks,
/* world_pose_landmarks= */ world_landmarks_with_visibility,
/* auxiliary_pose_landmarks= */ normalized_auxiliary_landmarks,
/* pose_rect_next_frame= */ pose_rect_next_frame,
/* pose_presence= */ pose_presence,
/* pose_presence_score= */ pose_presence_score,
/* segmentation_mask= */ segmentation_mask,
}};
}
};
// clang-format off
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::pose_landmarker::SinglePoseLandmarksDetectorGraph); // NOLINT
// clang-format on
// A "mediapipe.tasks.vision.pose_landmarker.MultiplePoseLandmarksDetectorGraph"
// performs multi pose landmark detection.
// - Accepts CPU input image and a vector of pose rect RoIs to detect the
// multiple poses landmarks enclosed by the RoIs. Output vectors of
// pose landmarks related results, where each element in the vectors
// corresponds to the result of the same pose.
//
// Inputs:
// IMAGE - Image
// Image to perform detection on.
// NORM_RECT - std::vector<NormalizedRect>
// A vector of multiple pose rects enclosing the pose RoI to perform
// landmarks detection on.
//
//
// Outputs:
// LANDMARKS: - std::vector<NormalizedLandmarkList>
// Vector of detected pose landmarks.
// WORLD_LANDMARKS - std::vector<LandmarkList>
// Vector of detected pose landmarks in world coordinates.
// AUXILIARY_LANDMARKS - std::vector<NormalizedLandmarkList>
// Vector of detected pose auxiliary landmarks.
// POSE_RECT_NEXT_FRAME - std::vector<NormalizedRect>
// Vector of the predicted rects enclosing the same pose RoI for landmark
// detection on the next frame.
// PRESENCE - std::vector<bool>
// Vector of boolean value indicates whether the pose is present.
// PRESENCE_SCORE - std::vector<float>
// Vector of float value indicates the probability that the pose is present.
// SEGMENTATION_MASK - std::vector<Image>
// Vector of segmentation masks.
//
// Example:
// node {
// calculator:
// "mediapipe.tasks.vision.pose_landmarker.MultiplePoseLandmarksDetectorGraph"
// input_stream: "IMAGE:input_image"
// input_stream: "POSE_RECT:pose_rect"
// output_stream: "LANDMARKS:pose_landmarks"
// output_stream: "WORLD_LANDMARKS:world_pose_landmarks"
// output_stream: "AUXILIARY_LANDMARKS:auxiliary_landmarks"
// output_stream: "POSE_RECT_NEXT_FRAME:pose_rect_next_frame"
// output_stream: "PRESENCE:pose_presence"
// output_stream: "PRESENCE_SCORE:pose_presence_score"
// output_stream: "SEGMENTATION_MASK:segmentation_mask"
// options {
// [mediapipe.tasks.vision.pose_landmarker.proto.PoseLandmarksDetectorGraphOptions.ext]
// {
// base_options {
// model_asset {
// file_name: "pose_landmark_lite.tflite"
// }
// }
// min_detection_confidence: 0.5
// }
// }
// }
class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
Graph graph;
ASSIGN_OR_RETURN(
auto pose_landmark_detection_outputs,
BuildPoseLandmarksDetectorGraph(
sc->Options<PoseLandmarksDetectorGraphOptions>(),
graph[Input<Image>(kImageTag)],
graph[Input<std::vector<NormalizedRect>>(kNormRectTag)], graph));
pose_landmark_detection_outputs.landmark_lists >>
graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
pose_landmark_detection_outputs.world_landmark_lists >>
graph[Output<std::vector<LandmarkList>>(kWorldLandmarksTag)];
pose_landmark_detection_outputs.auxiliary_landmark_lists >>
graph[Output<std::vector<NormalizedLandmarkList>>(kAuxLandmarksTag)];
pose_landmark_detection_outputs.pose_rects_next_frame >>
graph[Output<std::vector<NormalizedRect>>(kPoseRectsNextFrameTag)];
pose_landmark_detection_outputs.presences >>
graph[Output<std::vector<bool>>(kPresenceTag)];
pose_landmark_detection_outputs.presence_scores >>
graph[Output<std::vector<float>>(kPresenceScoreTag)];
pose_landmark_detection_outputs.segmentation_masks >>
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
return graph.GetConfig();
}
private:
absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarksDetectorGraph(
const PoseLandmarksDetectorGraphOptions& subgraph_options,
Source<Image> image_in,
Source<std::vector<NormalizedRect>> multi_pose_rects, Graph& graph) {
auto& begin_loop_multi_pose_rects =
graph.AddNode("BeginLoopNormalizedRectCalculator");
image_in >> begin_loop_multi_pose_rects.In("CLONE");
multi_pose_rects >> begin_loop_multi_pose_rects.In("ITERABLE");
auto batch_end = begin_loop_multi_pose_rects.Out("BATCH_END");
auto image = begin_loop_multi_pose_rects.Out("CLONE");
auto pose_rect = begin_loop_multi_pose_rects.Out("ITEM");
auto& pose_landmark_subgraph = graph.AddNode(
"mediapipe.tasks.vision.pose_landmarker."
"SinglePoseLandmarksDetectorGraph");
pose_landmark_subgraph.GetOptions<PoseLandmarksDetectorGraphOptions>()
.CopyFrom(subgraph_options);
image >> pose_landmark_subgraph.In(kImageTag);
pose_rect >> pose_landmark_subgraph.In(kNormRectTag);
auto landmarks = pose_landmark_subgraph.Out(kLandmarksTag);
auto world_landmarks = pose_landmark_subgraph.Out(kWorldLandmarksTag);
auto auxiliary_landmarks = pose_landmark_subgraph.Out(kAuxLandmarksTag);
auto pose_rect_next_frame =
pose_landmark_subgraph.Out(kPoseRectNextFrameTag);
auto presence = pose_landmark_subgraph.Out(kPresenceTag);
auto presence_score = pose_landmark_subgraph.Out(kPresenceScoreTag);
auto segmentation_mask = pose_landmark_subgraph.Out(kSegmentationMaskTag);
auto& end_loop_landmarks =
graph.AddNode("EndLoopNormalizedLandmarkListVectorCalculator");
batch_end >> end_loop_landmarks.In(kBatchEndTag);
landmarks >> end_loop_landmarks.In(kItemTag);
auto landmark_lists =
end_loop_landmarks[Output<std::vector<NormalizedLandmarkList>>(
kIterableTag)];
auto& end_loop_world_landmarks =
graph.AddNode("EndLoopLandmarkListVectorCalculator");
batch_end >> end_loop_world_landmarks.In(kBatchEndTag);
world_landmarks >> end_loop_world_landmarks.In(kItemTag);
auto world_landmark_lists =
end_loop_world_landmarks[Output<std::vector<LandmarkList>>(
kIterableTag)];
auto& end_loop_auxiliary_landmarks =
graph.AddNode("EndLoopNormalizedLandmarkListVectorCalculator");
batch_end >> end_loop_auxiliary_landmarks.In(kBatchEndTag);
auxiliary_landmarks >> end_loop_auxiliary_landmarks.In(kItemTag);
auto auxiliary_landmark_lists = end_loop_auxiliary_landmarks
[Output<std::vector<NormalizedLandmarkList>>(kIterableTag)];
auto& end_loop_rects_next_frame =
graph.AddNode("EndLoopNormalizedRectCalculator");
batch_end >> end_loop_rects_next_frame.In(kBatchEndTag);
pose_rect_next_frame >> end_loop_rects_next_frame.In(kItemTag);
auto pose_rects_next_frame =
end_loop_rects_next_frame[Output<std::vector<NormalizedRect>>(
kIterableTag)];
auto& end_loop_presence = graph.AddNode("EndLoopBooleanCalculator");
batch_end >> end_loop_presence.In(kBatchEndTag);
presence >> end_loop_presence.In(kItemTag);
auto presences = end_loop_presence[Output<std::vector<bool>>(kIterableTag)];
auto& end_loop_presence_score = graph.AddNode("EndLoopFloatCalculator");
batch_end >> end_loop_presence_score.In(kBatchEndTag);
presence_score >> end_loop_presence_score.In(kItemTag);
auto presence_scores =
end_loop_presence_score[Output<std::vector<float>>(kIterableTag)];
auto& end_loop_segmentation_mask = graph.AddNode("EndLoopImageCalculator");
batch_end >> end_loop_segmentation_mask.In(kBatchEndTag);
segmentation_mask >> end_loop_segmentation_mask.In(kItemTag);
auto segmentation_masks =
end_loop_segmentation_mask[Output<std::vector<Image>>(kIterableTag)];
return {{
/* landmark_lists= */ landmark_lists,
/* world_landmark_lists= */ world_landmark_lists,
/* auxiliary_landmark_lists= */ auxiliary_landmark_lists,
/* pose_rects_next_frame= */ pose_rects_next_frame,
/* presences= */ presences,
/* presence_scores= */ presence_scores,
/* segmentation_masks= */ segmentation_masks,
}};
}
};
// clang-format off
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::pose_landmarker::MultiplePoseLandmarksDetectorGraph); // NOLINT
// clang-format on
} // namespace pose_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,366 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace pose_landmarker {
namespace {
using ::file::Defaults;
using ::file::GetTextProto;
using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::core::TaskRunner;
using ::mediapipe::tasks::vision::DecodeImageFromFile;
using ::mediapipe::tasks::vision::pose_landmarker::proto::
PoseLandmarksDetectorGraphOptions;
using ::testing::ElementsAreArray;
using ::testing::EqualsProto;
using ::testing::Pointwise;
using ::testing::TestParamInfo;
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kPoseLandmarkerLiteModel[] = "pose_landmark_lite.tflite";
constexpr char kPoseImage[] = "pose.jpg";
constexpr char kBurgerImage[] = "burger.jpg";
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageName[] = "image_in";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kPoseRectName[] = "pose_rect_in";
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kLandmarksName[] = "landmarks";
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kWorldLandmarksName[] = "world_landmarks";
constexpr char kAuxLandmarksTag[] = "AUXILIARY_LANDMARKS";
constexpr char kAuxLandmarksName[] = "auxiliary_landmarks";
constexpr char kPoseRectNextFrameTag[] = "POSE_RECT_NEXT_FRAME";
constexpr char kPoseRectNextFrameName[] = "pose_rect_next_frame";
constexpr char kPoseRectsNextFrameTag[] = "POSE_RECTS_NEXT_FRAME";
constexpr char kPoseRectsNextFrameName[] = "pose_rects_next_frame";
constexpr char kPresenceTag[] = "PRESENCE";
constexpr char kPresenceName[] = "presence";
constexpr char kPresenceScoreTag[] = "PRESENCE_SCORE";
constexpr char kPresenceScoreName[] = "presence_score";
constexpr char kSegmentationMaskTag[] = "SEGMENTATION_MASK";
constexpr char kSegmentationMaskName[] = "segmentation_mask";
// Expected pose landmarks positions, in text proto format.
constexpr char kExpectedPoseLandmarksFilename[] =
"expected_pose_landmarks.prototxt";
constexpr float kLiteModelFractionDiff = 0.05; // percentage
constexpr float kAbsMargin = 0.03;
// Helper function to create a Single Pose Landmark TaskRunner.
absl::StatusOr<std::unique_ptr<TaskRunner>> CreateSinglePoseTaskRunner(
absl::string_view model_name) {
Graph graph;
auto& pose_landmark_detection = graph.AddNode(
"mediapipe.tasks.vision.pose_landmarker."
"SinglePoseLandmarksDetectorGraph");
auto options = std::make_unique<PoseLandmarksDetectorGraphOptions>();
options->mutable_base_options()->mutable_model_asset()->set_file_name(
JoinPath("./", kTestDataDirectory, model_name));
pose_landmark_detection.GetOptions<PoseLandmarksDetectorGraphOptions>().Swap(
options.get());
graph[Input<Image>(kImageTag)].SetName(kImageName) >>
pose_landmark_detection.In(kImageTag);
graph[Input<NormalizedRect>(kNormRectTag)].SetName(kPoseRectName) >>
pose_landmark_detection.In(kNormRectTag);
pose_landmark_detection.Out(kLandmarksTag).SetName(kLandmarksName) >>
graph[Output<NormalizedLandmarkList>(kLandmarksTag)];
pose_landmark_detection.Out(kWorldLandmarksTag)
.SetName(kWorldLandmarksName) >>
graph[Output<LandmarkList>(kWorldLandmarksTag)];
pose_landmark_detection.Out(kAuxLandmarksTag).SetName(kAuxLandmarksName) >>
graph[Output<LandmarkList>(kAuxLandmarksTag)];
pose_landmark_detection.Out(kPresenceTag).SetName(kPresenceName) >>
graph[Output<bool>(kPresenceTag)];
pose_landmark_detection.Out(kPresenceScoreTag).SetName(kPresenceScoreName) >>
graph[Output<float>(kPresenceScoreTag)];
pose_landmark_detection.Out(kSegmentationMaskTag)
.SetName(kSegmentationMaskName) >>
graph[Output<Image>(kSegmentationMaskTag)];
pose_landmark_detection.Out(kPoseRectNextFrameTag)
.SetName(kPoseRectNextFrameName) >>
graph[Output<NormalizedRect>(kPoseRectNextFrameTag)];
return TaskRunner::Create(
graph.GetConfig(),
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
}
// Helper function to create a Multi Pose Landmark TaskRunner.
absl::StatusOr<std::unique_ptr<TaskRunner>> CreateMultiPoseTaskRunner(
absl::string_view model_name) {
Graph graph;
auto& multi_pose_landmark_detection = graph.AddNode(
"mediapipe.tasks.vision.pose_landmarker."
"MultiplePoseLandmarksDetectorGraph");
auto options = std::make_unique<PoseLandmarksDetectorGraphOptions>();
options->mutable_base_options()->mutable_model_asset()->set_file_name(
JoinPath("./", kTestDataDirectory, model_name));
multi_pose_landmark_detection.GetOptions<PoseLandmarksDetectorGraphOptions>()
.Swap(options.get());
graph[Input<Image>(kImageTag)].SetName(kImageName) >>
multi_pose_landmark_detection.In(kImageTag);
graph[Input<std::vector<NormalizedRect>>(kNormRectTag)].SetName(
kPoseRectName) >>
multi_pose_landmark_detection.In(kNormRectTag);
multi_pose_landmark_detection.Out(kLandmarksTag).SetName(kLandmarksName) >>
graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
multi_pose_landmark_detection.Out(kWorldLandmarksTag)
.SetName(kWorldLandmarksName) >>
graph[Output<std::vector<LandmarkList>>(kWorldLandmarksTag)];
multi_pose_landmark_detection.Out(kAuxLandmarksTag)
.SetName(kAuxLandmarksName) >>
graph[Output<std::vector<NormalizedLandmarkList>>(kAuxLandmarksTag)];
multi_pose_landmark_detection.Out(kPresenceTag).SetName(kPresenceName) >>
graph[Output<std::vector<bool>>(kPresenceTag)];
multi_pose_landmark_detection.Out(kPresenceScoreTag)
.SetName(kPresenceScoreName) >>
graph[Output<std::vector<float>>(kPresenceScoreTag)];
multi_pose_landmark_detection.Out(kSegmentationMaskTag)
.SetName(kSegmentationMaskName) >>
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
multi_pose_landmark_detection.Out(kPoseRectsNextFrameTag)
.SetName(kPoseRectsNextFrameName) >>
graph[Output<std::vector<NormalizedRect>>(kPoseRectsNextFrameTag)];
return TaskRunner::Create(
graph.GetConfig(),
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
}
NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) {
NormalizedLandmarkList expected_landmark_list;
MP_EXPECT_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, filename),
&expected_landmark_list, Defaults()));
return expected_landmark_list;
}
// Struct holding the parameters for parameterized PoseLandmarkerTest
// class.
struct SinglePoseTestParams {
// The name of this test, for convenience when displaying test results.
std::string test_name;
// The filename of the model to test.
std::string input_model_name;
// The filename of the test image.
std::string test_image_name;
// RoI on image to detect pose.
NormalizedRect pose_rect;
// Expected pose presence value.
bool expected_presence;
// The expected output landmarks positions in pixels coornidates.
std::optional<NormalizedLandmarkList> expected_landmarks;
// The expected segmentation mask.
Image expected_segmentation_mask;
// The max value difference between expected_positions and detected positions.
float landmarks_diff_threshold;
};
struct MultiPoseTestParams {
// The name of this test, for convenience when displaying test results.
std::string test_name;
// The filename of the model to test.
std::string input_model_name;
// The filename of the test image.
std::string test_image_name;
// RoIs on image to detect poses.
std::vector<NormalizedRect> pose_rects;
// Expected pose presence values.
std::vector<bool> expected_presences;
// The expected output landmarks positions in pixels coornidates.
std::vector<NormalizedLandmarkList> expected_landmark_lists;
// The expected segmentation_mask Image.
std::vector<Image> expected_segmentation_masks;
// The max value difference between expected_positions and detected positions.
float landmarks_diff_threshold;
};
// Helper function to construct NormalizeRect proto.
NormalizedRect MakePoseRect(float x_center, float y_center, float width,
float height, float rotation) {
NormalizedRect pose_rect;
pose_rect.set_x_center(x_center);
pose_rect.set_y_center(y_center);
pose_rect.set_width(width);
pose_rect.set_height(height);
pose_rect.set_rotation(rotation);
return pose_rect;
}
class PoseLandmarkerTest : public testing::TestWithParam<SinglePoseTestParams> {
};
TEST_P(PoseLandmarkerTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
GetParam().test_image_name)));
MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateSinglePoseTaskRunner(
GetParam().input_model_name));
auto output_packets = task_runner->Process(
{{kImageName, MakePacket<Image>(std::move(image))},
{kPoseRectName,
MakePacket<NormalizedRect>(std::move(GetParam().pose_rect))}});
MP_ASSERT_OK(output_packets);
const bool presence = (*output_packets)[kPresenceName].Get<bool>();
ASSERT_EQ(presence, GetParam().expected_presence);
if (presence) {
const NormalizedLandmarkList landmarks =
(*output_packets)[kLandmarksName].Get<NormalizedLandmarkList>();
if (GetParam().expected_landmarks.has_value()) {
const NormalizedLandmarkList& expected_landmarks =
GetParam().expected_landmarks.value();
EXPECT_THAT(
landmarks,
Approximately(Partially(EqualsProto(expected_landmarks)),
/*margin=*/kAbsMargin,
/*fraction=*/GetParam().landmarks_diff_threshold));
}
}
}
class MultiPoseLandmarkerTest
: public testing::TestWithParam<MultiPoseTestParams> {};
TEST_P(MultiPoseLandmarkerTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
GetParam().test_image_name)));
MP_ASSERT_OK_AND_ASSIGN(
auto task_runner, CreateMultiPoseTaskRunner(GetParam().input_model_name));
auto output_packets = task_runner->Process(
{{kImageName, MakePacket<Image>(std::move(image))},
{kPoseRectName, MakePacket<std::vector<NormalizedRect>>(
std::move(GetParam().pose_rects))}});
MP_ASSERT_OK(output_packets);
const std::vector<bool>& presences =
(*output_packets)[kPresenceName].Get<std::vector<bool>>();
const std::vector<NormalizedLandmarkList>& landmark_lists =
(*output_packets)[kLandmarksName]
.Get<std::vector<NormalizedLandmarkList>>();
EXPECT_THAT(presences, ElementsAreArray(GetParam().expected_presences));
EXPECT_THAT(
landmark_lists,
Pointwise(Approximately(Partially(EqualsProto()),
/*margin=*/kAbsMargin,
/*fraction=*/GetParam().landmarks_diff_threshold),
GetParam().expected_landmark_lists));
}
// TODO: Add additional tests for MP Tasks Pose Graphs.
INSTANTIATE_TEST_SUITE_P(
PoseLandmarkerTest, PoseLandmarkerTest,
Values(
SinglePoseTestParams{
.test_name = "PoseLandmarkerLiteModel",
.input_model_name = kPoseLandmarkerLiteModel,
.test_image_name = kPoseImage,
.pose_rect = MakePoseRect(0.25, 0.5, 0.5, 1.0, 0),
.expected_presence = true,
.expected_landmarks =
GetExpectedLandmarkList(kExpectedPoseLandmarksFilename),
.landmarks_diff_threshold = kLiteModelFractionDiff},
SinglePoseTestParams{
.test_name = "PoseLandmarkerLiteModelNoPose",
.input_model_name = kPoseLandmarkerLiteModel,
.test_image_name = kBurgerImage,
.pose_rect = MakePoseRect(0.25, 0.5, 0.5, 1.0, 0),
.expected_presence = false,
.expected_landmarks = std::nullopt,
.landmarks_diff_threshold = kLiteModelFractionDiff}),
[](const TestParamInfo<PoseLandmarkerTest::ParamType>& info) {
return info.param.test_name;
});
INSTANTIATE_TEST_SUITE_P(
MultiPoseLandmarkerTest, MultiPoseLandmarkerTest,
Values(MultiPoseTestParams{
.test_name = "MultiPoseLandmarkerLiteModel",
.input_model_name = kPoseLandmarkerLiteModel,
.test_image_name = kPoseImage,
.pose_rects = {MakePoseRect(0.25, 0.5, 0.5, 1.0, 0)},
.expected_presences = {true},
.expected_landmark_lists = {GetExpectedLandmarkList(
kExpectedPoseLandmarksFilename)},
.landmarks_diff_threshold = kLiteModelFractionDiff,
}),
[](const TestParamInfo<MultiPoseLandmarkerTest::ParamType>& info) {
return info.param.test_name;
});
} // namespace
} // namespace pose_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,31 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
package(default_visibility = [
"//mediapipe/tasks:internal",
])
licenses(["notice"])
mediapipe_proto_library(
name = "pose_landmarks_detector_graph_options_proto",
srcs = ["pose_landmarks_detector_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
],
)

View File

@ -0,0 +1,38 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe.tasks.vision.pose_landmarker.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.poselandmarker.proto";
option java_outer_classname = "PoseLandmarksDetectorGraphOptionsProto";
message PoseLandmarksDetectorGraphOptions {
extend mediapipe.CalculatorOptions {
optional PoseLandmarksDetectorGraphOptions ext = 518928384;
}
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
// model file with metadata, accelerator options, etc.
optional core.proto.BaseOptions base_options = 1;
// Minimum confidence value ([0.0, 1.0]) for pose presence score to be
// considered successfully detecting a pose in the image.
optional float min_detection_confidence = 2 [default = 0.5];
}

View File

@ -77,6 +77,7 @@ mediapipe_files(srcs = [
"portrait_selfie_segmentation_landscape_expected_category_mask.jpg", "portrait_selfie_segmentation_landscape_expected_category_mask.jpg",
"pose.jpg", "pose.jpg",
"pose_detection.tflite", "pose_detection.tflite",
"pose_landmark_lite.tflite",
"right_hands.jpg", "right_hands.jpg",
"right_hands_rotated.jpg", "right_hands_rotated.jpg",
"segmentation_golden_rotation0.png", "segmentation_golden_rotation0.png",
@ -185,6 +186,7 @@ filegroup(
"mobilenet_v3_small_100_224_embedder.tflite", "mobilenet_v3_small_100_224_embedder.tflite",
"palm_detection_full.tflite", "palm_detection_full.tflite",
"pose_detection.tflite", "pose_detection.tflite",
"pose_landmark_lite.tflite",
"selfie_segm_128_128_3.tflite", "selfie_segm_128_128_3.tflite",
"selfie_segm_144_256_3.tflite", "selfie_segm_144_256_3.tflite",
"selfie_segmentation.tflite", "selfie_segmentation.tflite",
@ -199,6 +201,7 @@ filegroup(
"expected_left_down_hand_rotated_landmarks.prototxt", "expected_left_down_hand_rotated_landmarks.prototxt",
"expected_left_up_hand_landmarks.prototxt", "expected_left_up_hand_landmarks.prototxt",
"expected_left_up_hand_rotated_landmarks.prototxt", "expected_left_up_hand_rotated_landmarks.prototxt",
"expected_pose_landmarks.prototxt",
"expected_right_down_hand_landmarks.prototxt", "expected_right_down_hand_landmarks.prototxt",
"expected_right_up_hand_landmarks.prototxt", "expected_right_up_hand_landmarks.prototxt",
"face_geometry_expected_out.pbtxt", "face_geometry_expected_out.pbtxt",

View File

@ -0,0 +1,231 @@
landmark {
x: 0.5033445
y: 0.43453366
z: 0.14119335
visibility: 0.9981027
presence: 0.9992255
}
landmark {
x: 0.49945074
y: 0.42642897
z: 0.1261509
visibility: 0.99845314
presence: 0.9988668
}
landmark {
x: 0.49838358
y: 0.42594656
z: 0.12615599
visibility: 0.9983102
presence: 0.9988244
}
landmark {
x: 0.49738216
y: 0.42554975
z: 0.12622544
visibility: 0.99852186
presence: 0.99873894
}
landmark {
x: 0.5031697
y: 0.42662272
z: 0.13347684
visibility: 0.9978422
presence: 0.9989411
}
landmark {
x: 0.5052138
y: 0.42640454
z: 0.13340297
visibility: 0.9973635
presence: 0.99888366
}
landmark {
x: 0.50699204
y: 0.42618936
z: 0.13343208
visibility: 0.9977912
presence: 0.9987791
}
landmark {
x: 0.50158876
y: 0.42372653
z: 0.075136274
visibility: 0.9976786
presence: 0.99851876
}
landmark {
x: 0.51374334
y: 0.42491674
z: 0.10752227
visibility: 0.99754214
presence: 0.99863297
}
landmark {
x: 0.5059119
y: 0.43917143
z: 0.12961307
visibility: 0.9958307
presence: 0.9977532
}
landmark {
x: 0.5109223
y: 0.43983036
z: 0.13892516
visibility: 0.99557614
presence: 0.9976047
}
landmark {
x: 0.5023406
y: 0.45419085
z: 0.045139182
visibility: 0.99948055
presence: 0.99782556
}
landmark {
x: 0.5593891
y: 0.44672203
z: 0.09991482
visibility: 0.99833965
presence: 0.9977375
}
landmark {
x: 0.49746004
y: 0.4821021
z: 0.037247833
visibility: 0.80266625
presence: 0.9972971
}
landmark {
x: 0.57370883
y: 0.47189793
z: 0.13776684
visibility: 0.5926873
presence: 0.9988753
}
landmark {
x: 0.5110358
y: 0.44240227
z: 0.040228914
visibility: 0.73293996
presence: 0.9983026
}
landmark {
x: 0.55913407
y: 0.45325255
z: 0.15570506
visibility: 0.6228974
presence: 0.998798
}
landmark {
x: 0.5122394
y: 0.42851248
z: 0.022895059
visibility: 0.67662305
presence: 0.9976555
}
landmark {
x: 0.5534328
y: 0.45476273
z: 0.13998204
visibility: 0.5879434
presence: 0.99810314
}
landmark {
x: 0.5166826
y: 0.42873073
z: 0.013023325
visibility: 0.66849846
presence: 0.9973978
}
landmark {
x: 0.54080456
y: 0.45333704
z: 0.13224734
visibility: 0.58104885
presence: 0.99810755
}
landmark {
x: 0.5170599
y: 0.43338987
z: 0.03576146
visibility: 0.65676475
presence: 0.99781895
}
landmark {
x: 0.550342
y: 0.4529801
z: 0.15014008
visibility: 0.57411015
presence: 0.9985139
}
landmark {
x: 0.5236847
y: 0.5765062
z: -0.03329975
visibility: 0.9998833
presence: 0.99935263
}
landmark {
x: 0.55076087
y: 0.57722
z: 0.033408616
visibility: 0.99978465
presence: 0.9994066
}
landmark {
x: 0.56554604
y: 0.68362844
z: -0.1572319
visibility: 0.8825664
presence: 0.99819404
}
landmark {
x: 0.6127089
y: 0.6976384
z: 0.034268316
visibility: 0.6231058
presence: 0.9986853
}
landmark {
x: 0.63440466
y: 0.85055786
z: -0.21429145
visibility: 0.93542594
presence: 0.99297804
}
landmark {
x: 0.68133575
y: 0.8420663
z: 0.024820065
visibility: 0.8150173
presence: 0.994477
}
landmark {
x: 0.65322053
y: 0.8746961
z: -0.22213714
visibility: 0.89239687
presence: 0.986766
}
landmark {
x: 0.69881815
y: 0.862424
z: 0.018177645
visibility: 0.7516772
presence: 0.98905355
}
landmark {
x: 0.62477547
y: 0.9208759
z: -0.31597853
visibility: 0.92574716
presence: 0.9792296
}
landmark {
x: 0.6807446
y: 0.8746925
z: -0.059982482
visibility: 0.8126682
presence: 0.98336893
}

View File

@ -70,6 +70,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=166187566369397616783235763936531678737479599640"], urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=166187566369397616783235763936531678737479599640"],
) )
http_file(
name = "com_google_mediapipe_BUILD_orig",
sha256 = "d86b98b82e00dd87cd46bd1429bf5eaa007b500c1a24d9316b73309f2e6c8df8",
urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD.orig?generation=1679955080207504"],
)
http_file( http_file(
name = "com_google_mediapipe_burger_crop_jpg", name = "com_google_mediapipe_burger_crop_jpg",
sha256 = "8f58de573f0bf59a49c3d86cfabb9ad4061481f574aa049177e8da3963dddc50", sha256 = "8f58de573f0bf59a49c3d86cfabb9ad4061481f574aa049177e8da3963dddc50",
@ -298,6 +304,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_rotated_landmarks.prototxt?generation=1666629476401757"], urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_rotated_landmarks.prototxt?generation=1666629476401757"],
) )
http_file(
name = "com_google_mediapipe_expected_pose_landmarks_prototxt",
sha256 = "75dfd2825fc23f51e3906f3a0a050caa8ae9f502cc358af1e7c9fda7ea89c9a5",
urls = ["https://storage.googleapis.com/mediapipe-assets/expected_pose_landmarks.prototxt?generation=1679955083449778"],
)
http_file( http_file(
name = "com_google_mediapipe_expected_right_down_hand_landmarks_prototxt", name = "com_google_mediapipe_expected_right_down_hand_landmarks_prototxt",
sha256 = "f281b745175aaa7f458def6cf4c89521fb56302dd61a05642b3b4a4f237ffaa3", sha256 = "f281b745175aaa7f458def6cf4c89521fb56302dd61a05642b3b4a4f237ffaa3",
@ -942,8 +954,8 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_pose_landmark_lite_tflite", name = "com_google_mediapipe_pose_landmark_lite_tflite",
sha256 = "f17bfbecadb61c3be1baa8b8d851cc6619c870a87167b32848ad20db306b9d61", sha256 = "13628a7d1c1a0f601ae7202c71ec8edc3ac42db9d15f116c494ff24d1afabdd7",
urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmark_lite.tflite?generation=1661875901231143"], urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmark_lite.tflite?generation=1679955090327685"],
) )
http_file( http_file(
@ -1216,6 +1228,18 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/victory_landmarks.pbtxt?generation=1666999366036622"], urls = ["https://storage.googleapis.com/mediapipe-assets/victory_landmarks.pbtxt?generation=1666999366036622"],
) )
http_file(
name = "com_google_mediapipe_vit_multiclass_256x256-2022_10_14-xenoformer_f32_tflite",
sha256 = "768b7dd613c5b9f263b289cdbe1b9bc716f65e92c51e7ae57fae01f97f9658cd",
urls = ["https://storage.googleapis.com/mediapipe-assets/vit_multiclass_256x256-2022_10_14-xenoformer.f32.tflite?generation=1679955093986149"],
)
http_file(
name = "com_google_mediapipe_vit_multiclass_512x512-2022_12_02_f32_tflite",
sha256 = "7ac7f0a037cd451b9be8eb25da86339aaba54fa821a0bd44e18768866ed0205a",
urls = ["https://storage.googleapis.com/mediapipe-assets/vit_multiclass_512x512-2022_12_02.f32.tflite?generation=1679955096856280"],
)
http_file( http_file(
name = "com_google_mediapipe_vocab_for_regex_tokenizer_txt", name = "com_google_mediapipe_vocab_for_regex_tokenizer_txt",
sha256 = "b1134b10927a53ce4224bbc30ccf075c9969c94ebf40c368966d1dcf445ca923", sha256 = "b1134b10927a53ce4224bbc30ccf075c9969c94ebf40c368966d1dcf445ca923",
@ -1234,6 +1258,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/vocab_with_index.txt?generation=1661875977280658"], urls = ["https://storage.googleapis.com/mediapipe-assets/vocab_with_index.txt?generation=1661875977280658"],
) )
http_file(
name = "com_google_mediapipe_w_avg_npy",
sha256 = "a044e35609986d18a972532f2980e939832b5b7d559659959d11ecc752a58bbe",
urls = ["https://storage.googleapis.com/mediapipe-assets/w_avg.npy?generation=1679955100435717"],
)
http_file( http_file(
name = "com_google_mediapipe_yamnet_audio_classifier_with_metadata_tflite", name = "com_google_mediapipe_yamnet_audio_classifier_with_metadata_tflite",
sha256 = "10c95ea3eb9a7bb4cb8bddf6feb023250381008177ac162ce169694d05c317de", sha256 = "10c95ea3eb9a7bb4cb8bddf6feb023250381008177ac162ce169694d05c317de",
@ -1246,6 +1276,60 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/yamnet_embedding_metadata.tflite?generation=1668295071595506"], urls = ["https://storage.googleapis.com/mediapipe-assets/yamnet_embedding_metadata.tflite?generation=1668295071595506"],
) )
http_file(
name = "com_google_mediapipe_decoder_fingerprint_pb",
sha256 = "0bf6239c4855d78edb60f3349b46cdb2c6f83def64f1b31589b6e298e5cbec3c",
urls = ["https://storage.googleapis.com/mediapipe-assets/decoder/fingerprint.pb?generation=1679955102906559"],
)
http_file(
name = "com_google_mediapipe_decoder_keras_metadata_pb",
sha256 = "1631ee698455aea52d4467fe6118800718a86ec49c29f4f3c904785b72f425ff",
urls = ["https://storage.googleapis.com/mediapipe-assets/decoder/keras_metadata.pb?generation=1679955105294959"],
)
http_file(
name = "com_google_mediapipe_decoder_saved_model_pb",
sha256 = "b424d30c63548e93390b2944b9bd9dc29773a56197bb462d3bd9e7a0bd1270ff",
urls = ["https://storage.googleapis.com/mediapipe-assets/decoder/saved_model.pb?generation=1679955107808916"],
)
http_file(
name = "com_google_mediapipe_discriminator_fingerprint_pb",
sha256 = "1fa6201d253c9218f7054138b9ce273266ce431e00cbce2d74d557f6b97223fd",
urls = ["https://storage.googleapis.com/mediapipe-assets/discriminator/fingerprint.pb?generation=1679955110094297"],
)
http_file(
name = "com_google_mediapipe_discriminator_keras_metadata_pb",
sha256 = "59a8601790d615dd37ec24e788743ce737e9999ce6ea6593fcf1ee43f674987f",
urls = ["https://storage.googleapis.com/mediapipe-assets/discriminator/keras_metadata.pb?generation=1679955112486389"],
)
http_file(
name = "com_google_mediapipe_discriminator_saved_model_pb",
sha256 = "280d6097a9b3d4c3756e028c597fe3d3c76eb14f76f24d49a22ed7b6df1e3878",
urls = ["https://storage.googleapis.com/mediapipe-assets/discriminator/saved_model.pb?generation=1679955114873053"],
)
http_file(
name = "com_google_mediapipe_encoder_fingerprint_pb",
sha256 = "06cb4319f8178edf447a7a2442e89303a14a48cc4fc5ae27354eac2ba11ae120",
urls = ["https://storage.googleapis.com/mediapipe-assets/encoder/fingerprint.pb?generation=1679955117213209"],
)
http_file(
name = "com_google_mediapipe_encoder_keras_metadata_pb",
sha256 = "8b1429ee95c130fad0c78077b2b544cd03c9e288658aae93e81df4959b84009e",
urls = ["https://storage.googleapis.com/mediapipe-assets/encoder/keras_metadata.pb?generation=1679955119546778"],
)
http_file(
name = "com_google_mediapipe_encoder_saved_model_pb",
sha256 = "ce48392c71485ecd9b142b46e54442581a299df5560102337038b76a62e02a09",
urls = ["https://storage.googleapis.com/mediapipe-assets/encoder/saved_model.pb?generation=1679955122069362"],
)
http_file( http_file(
name = "com_google_mediapipe_gesture_embedder_keras_metadata_pb", name = "com_google_mediapipe_gesture_embedder_keras_metadata_pb",
sha256 = "c76b856101e2284293a5e5963b7c445e407a0b3e56ec63eb78f64d883e51e3aa", sha256 = "c76b856101e2284293a5e5963b7c445e407a0b3e56ec63eb78f64d883e51e3aa",
@ -1258,6 +1342,24 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/saved_model.pb?generation=1668550484904822"], urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/saved_model.pb?generation=1668550484904822"],
) )
http_file(
name = "com_google_mediapipe_mapping_fingerprint_pb",
sha256 = "6320890f1b9a57e5f4e50e3b56d96fd39d815aa2de51dd1c9b635aa6107d982b",
urls = ["https://storage.googleapis.com/mediapipe-assets/mapping/fingerprint.pb?generation=1679955124430234"],
)
http_file(
name = "com_google_mediapipe_mapping_keras_metadata_pb",
sha256 = "22582a2ec1d4883b52f50e628c1a2d69a2610b38d72a48a0bd9939c26be304f6",
urls = ["https://storage.googleapis.com/mediapipe-assets/mapping/keras_metadata.pb?generation=1679955126858694"],
)
http_file(
name = "com_google_mediapipe_mapping_saved_model_pb",
sha256 = "6a79de45d00f49110304bf0a6746bc717c45f77824cad22690f700f2fbdc1470",
urls = ["https://storage.googleapis.com/mediapipe-assets/mapping/saved_model.pb?generation=1679955129259768"],
)
http_file( http_file(
name = "com_google_mediapipe_mobilebert_tiny_keras_metadata_pb", name = "com_google_mediapipe_mobilebert_tiny_keras_metadata_pb",
sha256 = "cef8131a414c602b9d4742ac57f4f90bc5d8a42baec36b65deece884e2d0cf0f", sha256 = "cef8131a414c602b9d4742ac57f4f90bc5d8a42baec36b65deece884e2d0cf0f",
@ -1306,6 +1408,42 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/saved_model.pb?generation=1661875999264354"], urls = ["https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/saved_model.pb?generation=1661875999264354"],
) )
http_file(
name = "com_google_mediapipe_decoder_variables_variables_data-00000-of-00001",
sha256 = "d720ddf354036f17fa210951f9ebfb009453b244913a493f494f1441cfc2eca3",
urls = ["https://storage.googleapis.com/mediapipe-assets/decoder/variables/variables.data-00000-of-00001?generation=1679955132326947"],
)
http_file(
name = "com_google_mediapipe_decoder_variables_variables_index",
sha256 = "245f69af6e53fb8b163059fe9936f57b68a7844e15d696393fcddf94c771dfcc",
urls = ["https://storage.googleapis.com/mediapipe-assets/decoder/variables/variables.index?generation=1679955134518344"],
)
http_file(
name = "com_google_mediapipe_discriminator_variables_variables_data-00000-of-00001",
sha256 = "50b00e1898a573588fb0d5d24d74346d99b7153b5d79441d0350c2c6ca89fb02",
urls = ["https://storage.googleapis.com/mediapipe-assets/discriminator/variables/variables.data-00000-of-00001?generation=1679955138489595"],
)
http_file(
name = "com_google_mediapipe_discriminator_variables_variables_index",
sha256 = "e5cb4be6442a5741504ce7da9487445637ad89b1f4b6a993bb9e762c7bd5621d",
urls = ["https://storage.googleapis.com/mediapipe-assets/discriminator/variables/variables.index?generation=1679955140891136"],
)
http_file(
name = "com_google_mediapipe_encoder_variables_variables_data-00000-of-00001",
sha256 = "09bcd1e2f1c6261bd1842af2da95651d54c9b4b9343eb9b8f0004a97f9bc84bf",
urls = ["https://storage.googleapis.com/mediapipe-assets/encoder/variables/variables.data-00000-of-00001?generation=1679955144875765"],
)
http_file(
name = "com_google_mediapipe_encoder_variables_variables_index",
sha256 = "964f5ac6ced7b19f76b7856d9dad47594a5b2fa89c52840f82996b809372aec9",
urls = ["https://storage.googleapis.com/mediapipe-assets/encoder/variables/variables.index?generation=1679955147123313"],
)
http_file( http_file(
name = "com_google_mediapipe_gesture_embedder_variables_variables_data-00000-of-00001", name = "com_google_mediapipe_gesture_embedder_variables_variables_data-00000-of-00001",
sha256 = "c156c9654c9ffb1091bb9f06c71080bd1e428586276d3f39c33fbab27fe0522d", sha256 = "c156c9654c9ffb1091bb9f06c71080bd1e428586276d3f39c33fbab27fe0522d",
@ -1318,6 +1456,18 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668550490691823"], urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668550490691823"],
) )
http_file(
name = "com_google_mediapipe_mapping_variables_variables_data-00000-of-00001",
sha256 = "4187055e7f69fcc913ee2b11151a56149dda3017c75621d1e160596bde874c07",
urls = ["https://storage.googleapis.com/mediapipe-assets/mapping/variables/variables.data-00000-of-00001?generation=1679955149680908"],
)
http_file(
name = "com_google_mediapipe_mapping_variables_variables_index",
sha256 = "a04fcae7083715613f93ac89943f5fe1f5ba2e6efb9efd14eee7314f25502e4a",
urls = ["https://storage.googleapis.com/mediapipe-assets/mapping/variables/variables.index?generation=1679955152034297"],
)
http_file( http_file(
name = "com_google_mediapipe_mobilebert_tiny_assets_vocab_txt", name = "com_google_mediapipe_mobilebert_tiny_assets_vocab_txt",
sha256 = "07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3", sha256 = "07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3",