Internal change
PiperOrigin-RevId: 520717805
parent 99ba7dd787
commit d43579fe3e
mediapipe/tasks/cc/vision/pose_landmarker/BUILD (new file, 61 lines)
@@ -0,0 +1,61 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_visibility = [
    "//mediapipe/tasks:internal",
])

licenses(["notice"])

cc_library(
    name = "pose_landmarks_detector_graph",
    srcs = ["pose_landmarks_detector_graph.cc"],
    deps = [
        "//mediapipe/calculators/core:begin_loop_calculator",
        "//mediapipe/calculators/core:end_loop_calculator",
        "//mediapipe/calculators/core:gate_calculator",
        "//mediapipe/calculators/core:split_proto_list_calculator",
        "//mediapipe/calculators/core:split_vector_calculator",
        "//mediapipe/calculators/core:split_vector_calculator_cc_proto",
        "//mediapipe/calculators/tensor:inference_calculator",
        "//mediapipe/calculators/tensor:tensors_to_floats_calculator",
        "//mediapipe/calculators/tensor:tensors_to_landmarks_calculator",
        "//mediapipe/calculators/tensor:tensors_to_landmarks_calculator_cc_proto",
        "//mediapipe/calculators/tensor:tensors_to_segmentation_calculator",
        "//mediapipe/calculators/tensor:tensors_to_segmentation_calculator_cc_proto",
        "//mediapipe/calculators/util:detections_to_rects_calculator",
        "//mediapipe/calculators/util:landmarks_to_detection_calculator",
        "//mediapipe/calculators/util:rect_transformation_calculator",
        "//mediapipe/calculators/util:refine_landmarks_from_heatmap_calculator",
        "//mediapipe/calculators/util:refine_landmarks_from_heatmap_calculator_cc_proto",
        "//mediapipe/calculators/util:thresholding_calculator",
        "//mediapipe/calculators/util:thresholding_calculator_cc_proto",
        "//mediapipe/calculators/util:visibility_copy_calculator",
        "//mediapipe/calculators/util:visibility_copy_calculator_cc_proto",
        "//mediapipe/framework:subgraph",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/gpu:gpu_origin_cc_proto",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/core:model_task_graph",
        "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
        "@com_google_absl//absl/status:statusor",
    ],
)
mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph.cc (new file, 641 lines)
@@ -0,0 +1,641 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/status/statusor.h"
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_landmarks_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.pb.h"
#include "mediapipe/calculators/util/thresholding_calculator.pb.h"
#include "mediapipe/calculators/util/visibility_copy_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/subgraph.h"
#include "mediapipe/gpu/gpu_origin.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace pose_landmarker {

using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::vision::pose_landmarker::proto::
    PoseLandmarksDetectorGraphOptions;

constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS";
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kAuxLandmarksTag[] = "AUXILIARY_LANDMARKS";
constexpr char kPoseRectNextFrameTag[] = "POSE_RECT_NEXT_FRAME";
constexpr char kPoseRectsNextFrameTag[] = "POSE_RECTS_NEXT_FRAME";
constexpr char kPresenceTag[] = "PRESENCE";
constexpr char kPresenceScoreTag[] = "PRESENCE_SCORE";
constexpr char kSegmentationMaskTag[] = "SEGMENTATION_MASK";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kLandmarksToTag[] = "LANDMARKS_TO";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kFloatTag[] = "FLOAT";
constexpr char kFlagTag[] = "FLAG";
constexpr char kMaskTag[] = "MASK";
constexpr char kDetectionTag[] = "DETECTION";
constexpr char kNormLandmarksFromTag[] = "NORM_LANDMARKS_FROM";
constexpr char kBatchEndTag[] = "BATCH_END";
constexpr char kItemTag[] = "ITEM";
constexpr char kIterableTag[] = "ITERABLE";

constexpr int kModelOutputTensorSplitNum = 5;
constexpr int kLandmarksNum = 39;
constexpr float kLandmarksNormalizeZ = 0.4;

struct SinglePoseLandmarkerOutputs {
  Source<NormalizedLandmarkList> pose_landmarks;
  Source<LandmarkList> world_pose_landmarks;
  Source<NormalizedLandmarkList> auxiliary_pose_landmarks;
  Source<NormalizedRect> pose_rect_next_frame;
  Source<bool> pose_presence;
  Source<float> pose_presence_score;
  Source<Image> segmentation_mask;
};

struct PoseLandmarkerOutputs {
  Source<std::vector<NormalizedLandmarkList>> landmark_lists;
  Source<std::vector<LandmarkList>> world_landmark_lists;
  Source<std::vector<NormalizedLandmarkList>> auxiliary_landmark_lists;
  Source<std::vector<NormalizedRect>> pose_rects_next_frame;
  Source<std::vector<bool>> presences;
  Source<std::vector<float>> presence_scores;
  Source<std::vector<Image>> segmentation_masks;
};

absl::Status SanityCheckOptions(
    const PoseLandmarksDetectorGraphOptions& options) {
  if (options.min_detection_confidence() < 0 ||
      options.min_detection_confidence() > 1) {
    return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
                                   "Invalid `min_detection_confidence` option: "
                                   "value must be in the range [0.0, 1.0]",
                                   MediaPipeTasksStatus::kInvalidArgumentError);
  }
  return absl::OkStatus();
}

// Split pose landmark detection model output tensor into five parts,
// representing landmarks, presence scores, segmentation, heatmap, and world
// landmarks respectively.
void ConfigureSplitTensorVectorCalculator(
    mediapipe::SplitVectorCalculatorOptions* options) {
  for (int i = 0; i < kModelOutputTensorSplitNum; ++i) {
    auto* range = options->add_ranges();
    range->set_begin(i);
    range->set_end(i + 1);
  }
}

void ConfigureTensorsToLandmarksCalculator(
    const ImageTensorSpecs& input_image_tensor_spec, bool normalize,
    bool sigmoid_activation,
    mediapipe::TensorsToLandmarksCalculatorOptions* options) {
  options->set_num_landmarks(kLandmarksNum);
  options->set_input_image_height(input_image_tensor_spec.image_height);
  options->set_input_image_width(input_image_tensor_spec.image_width);

  if (normalize) {
    options->set_normalize_z(kLandmarksNormalizeZ);
  }

  if (sigmoid_activation) {
    options->set_visibility_activation(
        mediapipe::TensorsToLandmarksCalculatorOptions_Activation_SIGMOID);
    options->set_presence_activation(
        mediapipe::TensorsToLandmarksCalculatorOptions_Activation_SIGMOID);
  }
}

void ConfigureTensorsToSegmentationCalculator(
    mediapipe::TensorsToSegmentationCalculatorOptions* options) {
  options->set_activation(
      mediapipe::TensorsToSegmentationCalculatorOptions_Activation_SIGMOID);
  options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT);
}

void ConfigureRefineLandmarksFromHeatmapCalculator(
    mediapipe::RefineLandmarksFromHeatmapCalculatorOptions* options) {
  // Derived from
  // mediapipe/modules/pose_landmark/tensors_to_pose_landmarks_and_segmentation.pbtxt.
  options->set_kernel_size(7);
}

void ConfigureSplitNormalizedLandmarkListCalculator(
    mediapipe::SplitVectorCalculatorOptions* options) {
  // Derived from
  // mediapipe/modules/pose_landmark/tensors_to_pose_landmarks_and_segmentation.pbtxt
  auto* range = options->add_ranges();
  range->set_begin(0);
  range->set_end(33);
  auto* range_2 = options->add_ranges();
  range_2->set_begin(33);
  range_2->set_end(35);
}

void ConfigureSplitLandmarkListCalculator(
    mediapipe::SplitVectorCalculatorOptions* options) {
  // Derived from
  // mediapipe/modules/pose_landmark/tensors_to_pose_landmarks_and_segmentation.pbtxt
  auto* range = options->add_ranges();
  range->set_begin(0);
  range->set_end(33);
}

void ConfigureVisibilityCopyCalculator(
    mediapipe::VisibilityCopyCalculatorOptions* options) {
  // Derived from
  // mediapipe/modules/pose_landmark/tensors_to_pose_landmarks_and_segmentation.pbtxt
  options->set_copy_visibility(true);
  options->set_copy_presence(true);
}

// A "mediapipe.tasks.vision.pose_landmarker.SinglePoseLandmarksDetectorGraph"
|
||||
// performs pose landmarks detection.
|
||||
// - Accepts CPU input images and outputs Landmark on CPU.
|
||||
//
|
||||
// Inputs:
|
||||
// IMAGE - Image
|
||||
// Image to perform detection on.
|
||||
// NORM_RECT - NormalizedRect @Optional
|
||||
// Rect enclosing the RoI to perform detection on. If not set, the detection
|
||||
// RoI is the whole image.
|
||||
//
|
||||
//
|
||||
// Outputs:
|
||||
// LANDMARKS: - NormalizedLandmarkList
|
||||
// Detected pose landmarks.
|
||||
// WORLD_LANDMARKS - LandmarkList
|
||||
// Detected pose landmarks in world coordinates.
|
||||
// AUXILIARY_LANDMARKS - NormalizedLandmarkList
|
||||
// Detected pose auxiliary landmarks.
|
||||
// POSE_RECT_NEXT_FRAME - NormalizedRect
|
||||
// The predicted Rect enclosing the pose RoI for landmark detection on the
|
||||
// next frame.
|
||||
// PRESENCE - bool
|
||||
// Boolean value indicates whether the pose is present.
|
||||
// PRESENCE_SCORE - float
|
||||
// Float value indicates the probability that the pose is present.
|
||||
// SEGMENTATION_MASK - Image
|
||||
// Segmentation mask for pose.
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator:
|
||||
// "mediapipe.tasks.vision.pose_landmarker.SingleposeLandmarksDetectorGraph"
|
||||
// input_stream: "IMAGE:input_image"
|
||||
// input_stream: "POSE_RECT:pose_rect"
|
||||
// output_stream: "LANDMARKS:pose_landmarks"
|
||||
// output_stream: "WORLD_LANDMARKS:world_pose_landmarks"
|
||||
// output_stream: "AUXILIARY_LANDMARKS:auxiliary_landmarks"
|
||||
// output_stream: "POSE_RECT_NEXT_FRAME:pose_rect_next_frame"
|
||||
// output_stream: "PRESENCE:pose_presence"
|
||||
// output_stream: "PRESENCE_SCORE:pose_presence_score"
|
||||
// output_stream: "SEGMENTATION_MASK:segmentation_mask"
|
||||
// options {
|
||||
// [mediapipe.tasks.vision.pose_landmarker.proto.poseLandmarksDetectorGraphOptions.ext]
|
||||
// {
|
||||
// base_options {
|
||||
// model_asset {
|
||||
// file_name: "pose_landmark_lite.tflite"
|
||||
// }
|
||||
// }
|
||||
// min_detection_confidence: 0.5
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      SubgraphContext* sc) override {
    ASSIGN_OR_RETURN(
        const auto* model_resources,
        CreateModelResources<PoseLandmarksDetectorGraphOptions>(sc));
    Graph graph;
    ASSIGN_OR_RETURN(
        auto pose_landmark_detection_outs,
        BuildSinglePoseLandmarksDetectorGraph(
            sc->Options<PoseLandmarksDetectorGraphOptions>(), *model_resources,
            graph[Input<Image>(kImageTag)],
            graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
    pose_landmark_detection_outs.pose_landmarks >>
        graph[Output<NormalizedLandmarkList>(kLandmarksTag)];
    pose_landmark_detection_outs.world_pose_landmarks >>
        graph[Output<LandmarkList>(kWorldLandmarksTag)];
    pose_landmark_detection_outs.auxiliary_pose_landmarks >>
        graph[Output<NormalizedLandmarkList>(kAuxLandmarksTag)];
    pose_landmark_detection_outs.pose_rect_next_frame >>
        graph[Output<NormalizedRect>(kPoseRectNextFrameTag)];
    pose_landmark_detection_outs.pose_presence >>
        graph[Output<bool>(kPresenceTag)];
    pose_landmark_detection_outs.pose_presence_score >>
        graph[Output<float>(kPresenceScoreTag)];
    pose_landmark_detection_outs.segmentation_mask >>
        graph[Output<Image>(kSegmentationMaskTag)];

    return graph.GetConfig();
  }

 private:
  absl::StatusOr<SinglePoseLandmarkerOutputs>
  BuildSinglePoseLandmarksDetectorGraph(
      const PoseLandmarksDetectorGraphOptions& subgraph_options,
      const ModelResources& model_resources, Source<Image> image_in,
      Source<NormalizedRect> pose_rect, Graph& graph) {
    MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options));

    auto& preprocessing = graph.AddNode(
        "mediapipe.tasks.components.processors.ImagePreprocessingGraph");
    bool use_gpu =
        components::processors::DetermineImagePreprocessingGpuBackend(
            subgraph_options.base_options().acceleration());
    MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
        model_resources, use_gpu,
        &preprocessing.GetOptions<tasks::components::processors::proto::
                                      ImagePreprocessingGraphOptions>()));
    image_in >> preprocessing.In(kImageTag);
    pose_rect >> preprocessing.In(kNormRectTag);
    auto image_size = preprocessing[Output<std::pair<int, int>>(kImageSizeTag)];

    ASSIGN_OR_RETURN(auto image_tensor_specs,
                     BuildInputImageTensorSpecs(model_resources));

    auto& inference = AddInference(
        model_resources, subgraph_options.base_options().acceleration(), graph);
    preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag);

    // Split model output tensors to multiple streams.
    auto& split_tensors_vector = graph.AddNode("SplitTensorVectorCalculator");
    ConfigureSplitTensorVectorCalculator(
        &split_tensors_vector
             .GetOptions<mediapipe::SplitVectorCalculatorOptions>());
    inference.Out(kTensorsTag) >> split_tensors_vector.In("");
    auto landmark_tensors = split_tensors_vector.Out(0);
    auto pose_flag_tensors = split_tensors_vector.Out(1);
    auto segmentation_tensors = split_tensors_vector.Out(2);
    auto heatmap_tensors = split_tensors_vector.Out(3);
    auto world_landmark_tensors = split_tensors_vector.Out(4);

    // Converts the pose-flag tensor into a float that represents the
    // confidence score of pose presence.
    auto& tensors_to_pose_presence = graph.AddNode("TensorsToFloatsCalculator");
    pose_flag_tensors >> tensors_to_pose_presence.In(kTensorsTag);
    auto pose_presence_score =
        tensors_to_pose_presence[Output<float>(kFloatTag)];

    // Applies a threshold to the confidence score to determine whether a
    // pose is present.
    auto& pose_presence_thresholding = graph.AddNode("ThresholdingCalculator");
    pose_presence_thresholding
        .GetOptions<mediapipe::ThresholdingCalculatorOptions>()
        .set_threshold(subgraph_options.min_detection_confidence());
    pose_presence_score >> pose_presence_thresholding.In(kFloatTag);
    auto pose_presence = pose_presence_thresholding[Output<bool>(kFlagTag)];

    // GateCalculator for tensors.
    auto& tensors_gate = graph.AddNode("GateCalculator");
    landmark_tensors >> tensors_gate.In("")[0];
    segmentation_tensors >> tensors_gate.In("")[1];
    heatmap_tensors >> tensors_gate.In("")[2];
    world_landmark_tensors >> tensors_gate.In("")[3];
    pose_presence >> tensors_gate.In("ALLOW");
    auto ensured_landmarks_tensors = tensors_gate.Out(0);
    auto ensured_segmentation_tensors = tensors_gate.Out(1);
    auto ensured_heatmap_tensors = tensors_gate.Out(2);
    auto ensured_world_landmark_tensors = tensors_gate.Out(3);

    // Decodes the landmark tensors into a list of landmarks, where the
    // landmark coordinates are normalized by the size of the input image to
    // the model.
    auto& tensors_to_landmarks = graph.AddNode("TensorsToLandmarksCalculator");
    ConfigureTensorsToLandmarksCalculator(
        image_tensor_specs, /*normalize=*/false,
        /*sigmoid_activation=*/true,
        &tensors_to_landmarks
             .GetOptions<mediapipe::TensorsToLandmarksCalculatorOptions>());
    ensured_landmarks_tensors >> tensors_to_landmarks.In(kTensorsTag);

    auto landmarks =
        tensors_to_landmarks[Output<NormalizedLandmarkList>(kNormLandmarksTag)];

    // Decodes the segmentation tensor into a mask image with pixel values in
    // [0, 1] (1 for person and 0 for background).
    auto& tensors_to_segmentation =
        graph.AddNode("TensorsToSegmentationCalculator");
    ConfigureTensorsToSegmentationCalculator(
        &tensors_to_segmentation
             .GetOptions<mediapipe::TensorsToSegmentationCalculatorOptions>());
    ensured_segmentation_tensors >> tensors_to_segmentation.In(kTensorsTag);
    auto segmentation_mask = tensors_to_segmentation[Output<Image>(kMaskTag)];

    // Refines landmarks with the heatmap tensor.
    auto& refine_landmarks_from_heatmap =
        graph.AddNode("RefineLandmarksFromHeatmapCalculator");
    ConfigureRefineLandmarksFromHeatmapCalculator(
        &refine_landmarks_from_heatmap.GetOptions<
            mediapipe::RefineLandmarksFromHeatmapCalculatorOptions>());
    ensured_heatmap_tensors >> refine_landmarks_from_heatmap.In(kTensorsTag);
    landmarks >> refine_landmarks_from_heatmap.In(kNormLandmarksTag);
    auto landmarks_from_heatmap =
        refine_landmarks_from_heatmap[Output<NormalizedLandmarkList>(
            kNormLandmarksTag)];

    // Splits the landmarks into two sets: the actual pose landmarks and the
    // auxiliary landmarks.
    auto& split_normalized_landmark_list =
        graph.AddNode("SplitNormalizedLandmarkListCalculator");
    ConfigureSplitNormalizedLandmarkListCalculator(
        &split_normalized_landmark_list
             .GetOptions<mediapipe::SplitVectorCalculatorOptions>());
    landmarks_from_heatmap >> split_normalized_landmark_list.In("");
    auto normalized_landmarks = split_normalized_landmark_list.Out("")[0]
                                    .Cast<NormalizedLandmarkList>();
    auto normalized_auxiliary_landmarks =
        split_normalized_landmark_list.Out("")[1]
            .Cast<NormalizedLandmarkList>();

    // Decodes the world-landmark tensors into a vector of world landmarks.
    auto& tensors_to_world_landmarks =
        graph.AddNode("TensorsToLandmarksCalculator");
    ConfigureTensorsToLandmarksCalculator(
        image_tensor_specs, /*normalize=*/false,
        /*sigmoid_activation=*/false,
        &tensors_to_world_landmarks
             .GetOptions<mediapipe::TensorsToLandmarksCalculatorOptions>());
    ensured_world_landmark_tensors >>
        tensors_to_world_landmarks.In(kTensorsTag);
    auto world_landmarks =
        tensors_to_world_landmarks[Output<LandmarkList>(kLandmarksTag)];

    // Keeps only the actual world landmarks.
    auto& split_landmark_list = graph.AddNode("SplitLandmarkListCalculator");
    ConfigureSplitLandmarkListCalculator(
        &split_landmark_list
             .GetOptions<mediapipe::SplitVectorCalculatorOptions>());
    world_landmarks >> split_landmark_list.In("");
    auto split_landmarks = split_landmark_list.Out(0);

    // Reuses the visibility and presence field in pose landmarks for the
    // world landmarks.
    auto& visibility_copy = graph.AddNode("VisibilityCopyCalculator");
    ConfigureVisibilityCopyCalculator(
        &visibility_copy
             .GetOptions<mediapipe::VisibilityCopyCalculatorOptions>());
    split_landmarks >> visibility_copy.In(kLandmarksToTag);
    normalized_landmarks >> visibility_copy.In(kNormLandmarksFromTag);
    auto world_landmarks_with_visibility =
        visibility_copy[Output<LandmarkList>(kLandmarksToTag)];

    // Landmarks to Detections.
    auto& landmarks_to_detection =
        graph.AddNode("LandmarksToDetectionCalculator");
    landmarks >> landmarks_to_detection.In(kNormLandmarksTag);
    auto detection = landmarks_to_detection.Out(kDetectionTag);

    // Detections to Rects.
    auto& detection_to_rects = graph.AddNode("DetectionsToRectsCalculator");
    image_size >> detection_to_rects.In(kImageSizeTag);
    detection >> detection_to_rects.In(kDetectionTag);
    auto norm_rect = detection_to_rects.Out(kNormRectTag);

    // Expands the pose rectangle so that in the next video frame it's likely
    // to still contain the pose even with some motion.
    auto& pose_rect_transformation =
        graph.AddNode("RectTransformationCalculator");
    image_size >> pose_rect_transformation.In(kImageSizeTag);
    norm_rect >> pose_rect_transformation.In(kNormRectTag);
    auto pose_rect_next_frame =
        pose_rect_transformation[Output<NormalizedRect>("")];

    return {{
        /* pose_landmarks= */ normalized_landmarks,
        /* world_pose_landmarks= */ world_landmarks_with_visibility,
        /* auxiliary_pose_landmarks= */ normalized_auxiliary_landmarks,
        /* pose_rect_next_frame= */ pose_rect_next_frame,
        /* pose_presence= */ pose_presence,
        /* pose_presence_score= */ pose_presence_score,
        /* segmentation_mask= */ segmentation_mask,
    }};
  }
};

// clang-format off
REGISTER_MEDIAPIPE_GRAPH(
    ::mediapipe::tasks::vision::pose_landmarker::SinglePoseLandmarksDetectorGraph);  // NOLINT
// clang-format on

// A "mediapipe.tasks.vision.pose_landmarker.MultiplePoseLandmarksDetectorGraph"
|
||||
// performs multi pose landmark detection.
|
||||
// - Accepts CPU input image and a vector of pose rect RoIs to detect the
|
||||
// multiple poses landmarks enclosed by the RoIs. Output vectors of
|
||||
// pose landmarks related results, where each element in the vectors
|
||||
// corresponds to the result of the same pose.
|
||||
//
|
||||
// Inputs:
|
||||
// IMAGE - Image
|
||||
// Image to perform detection on.
|
||||
// NORM_RECT - std::vector<NormalizedRect>
|
||||
// A vector of multiple pose rects enclosing the pose RoI to perform
|
||||
// landmarks detection on.
|
||||
//
|
||||
//
|
||||
// Outputs:
|
||||
// LANDMARKS: - std::vector<NormalizedLandmarkList>
|
||||
// Vector of detected pose landmarks.
|
||||
// WORLD_LANDMARKS - std::vector<LandmarkList>
|
||||
// Vector of detected pose landmarks in world coordinates.
|
||||
// AUXILIARY_LANDMARKS - std::vector<NormalizedLandmarkList>
|
||||
// Vector of detected pose auxiliary landmarks.
|
||||
// POSE_RECT_NEXT_FRAME - std::vector<NormalizedRect>
|
||||
// Vector of the predicted rects enclosing the same pose RoI for landmark
|
||||
// detection on the next frame.
|
||||
// PRESENCE - std::vector<bool>
|
||||
// Vector of boolean value indicates whether the pose is present.
|
||||
// PRESENCE_SCORE - std::vector<float>
|
||||
// Vector of float value indicates the probability that the pose is present.
|
||||
// SEGMENTATION_MASK - std::vector<Image>
|
||||
// Vector of segmentation masks.
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator:
|
||||
// "mediapipe.tasks.vision.pose_landmarker.MultiplePoseLandmarksDetectorGraph"
|
||||
// input_stream: "IMAGE:input_image"
|
||||
// input_stream: "POSE_RECT:pose_rect"
|
||||
// output_stream: "LANDMARKS:pose_landmarks"
|
||||
// output_stream: "WORLD_LANDMARKS:world_pose_landmarks"
|
||||
// output_stream: "AUXILIARY_LANDMARKS:auxiliary_landmarks"
|
||||
// output_stream: "POSE_RECT_NEXT_FRAME:pose_rect_next_frame"
|
||||
// output_stream: "PRESENCE:pose_presence"
|
||||
// output_stream: "PRESENCE_SCORE:pose_presence_score"
|
||||
// output_stream: "SEGMENTATION_MASK:segmentation_mask"
|
||||
// options {
|
||||
// [mediapipe.tasks.vision.pose_landmarker.proto.PoseLandmarksDetectorGraphOptions.ext]
|
||||
// {
|
||||
// base_options {
|
||||
// model_asset {
|
||||
// file_name: "pose_landmark_lite.tflite"
|
||||
// }
|
||||
// }
|
||||
// min_detection_confidence: 0.5
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      SubgraphContext* sc) override {
    Graph graph;
    ASSIGN_OR_RETURN(
        auto pose_landmark_detection_outputs,
        BuildPoseLandmarksDetectorGraph(
            sc->Options<PoseLandmarksDetectorGraphOptions>(),
            graph[Input<Image>(kImageTag)],
            graph[Input<std::vector<NormalizedRect>>(kNormRectTag)], graph));
    pose_landmark_detection_outputs.landmark_lists >>
        graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
    pose_landmark_detection_outputs.world_landmark_lists >>
        graph[Output<std::vector<LandmarkList>>(kWorldLandmarksTag)];
    pose_landmark_detection_outputs.auxiliary_landmark_lists >>
        graph[Output<std::vector<NormalizedLandmarkList>>(kAuxLandmarksTag)];
    pose_landmark_detection_outputs.pose_rects_next_frame >>
        graph[Output<std::vector<NormalizedRect>>(kPoseRectsNextFrameTag)];
    pose_landmark_detection_outputs.presences >>
        graph[Output<std::vector<bool>>(kPresenceTag)];
    pose_landmark_detection_outputs.presence_scores >>
        graph[Output<std::vector<float>>(kPresenceScoreTag)];
    pose_landmark_detection_outputs.segmentation_masks >>
        graph[Output<std::vector<Image>>(kSegmentationMaskTag)];

    return graph.GetConfig();
  }

 private:
  absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarksDetectorGraph(
      const PoseLandmarksDetectorGraphOptions& subgraph_options,
      Source<Image> image_in,
      Source<std::vector<NormalizedRect>> multi_pose_rects, Graph& graph) {
    auto& begin_loop_multi_pose_rects =
        graph.AddNode("BeginLoopNormalizedRectCalculator");
    image_in >> begin_loop_multi_pose_rects.In("CLONE");
    multi_pose_rects >> begin_loop_multi_pose_rects.In("ITERABLE");
    auto batch_end = begin_loop_multi_pose_rects.Out("BATCH_END");
    auto image = begin_loop_multi_pose_rects.Out("CLONE");
    auto pose_rect = begin_loop_multi_pose_rects.Out("ITEM");

    auto& pose_landmark_subgraph = graph.AddNode(
        "mediapipe.tasks.vision.pose_landmarker."
        "SinglePoseLandmarksDetectorGraph");
    pose_landmark_subgraph.GetOptions<PoseLandmarksDetectorGraphOptions>()
        .CopyFrom(subgraph_options);
    image >> pose_landmark_subgraph.In(kImageTag);
    pose_rect >> pose_landmark_subgraph.In(kNormRectTag);
    auto landmarks = pose_landmark_subgraph.Out(kLandmarksTag);
    auto world_landmarks = pose_landmark_subgraph.Out(kWorldLandmarksTag);
    auto auxiliary_landmarks = pose_landmark_subgraph.Out(kAuxLandmarksTag);
    auto pose_rect_next_frame =
        pose_landmark_subgraph.Out(kPoseRectNextFrameTag);
    auto presence = pose_landmark_subgraph.Out(kPresenceTag);
    auto presence_score = pose_landmark_subgraph.Out(kPresenceScoreTag);
    auto segmentation_mask = pose_landmark_subgraph.Out(kSegmentationMaskTag);

    auto& end_loop_landmarks =
        graph.AddNode("EndLoopNormalizedLandmarkListVectorCalculator");
    batch_end >> end_loop_landmarks.In(kBatchEndTag);
    landmarks >> end_loop_landmarks.In(kItemTag);
    auto landmark_lists =
        end_loop_landmarks[Output<std::vector<NormalizedLandmarkList>>(
            kIterableTag)];

    auto& end_loop_world_landmarks =
        graph.AddNode("EndLoopLandmarkListVectorCalculator");
    batch_end >> end_loop_world_landmarks.In(kBatchEndTag);
    world_landmarks >> end_loop_world_landmarks.In(kItemTag);
    auto world_landmark_lists =
        end_loop_world_landmarks[Output<std::vector<LandmarkList>>(
            kIterableTag)];

    auto& end_loop_auxiliary_landmarks =
        graph.AddNode("EndLoopNormalizedLandmarkListVectorCalculator");
    batch_end >> end_loop_auxiliary_landmarks.In(kBatchEndTag);
    auxiliary_landmarks >> end_loop_auxiliary_landmarks.In(kItemTag);
    auto auxiliary_landmark_lists = end_loop_auxiliary_landmarks
        [Output<std::vector<NormalizedLandmarkList>>(kIterableTag)];

    auto& end_loop_rects_next_frame =
        graph.AddNode("EndLoopNormalizedRectCalculator");
    batch_end >> end_loop_rects_next_frame.In(kBatchEndTag);
    pose_rect_next_frame >> end_loop_rects_next_frame.In(kItemTag);
    auto pose_rects_next_frame =
        end_loop_rects_next_frame[Output<std::vector<NormalizedRect>>(
            kIterableTag)];

    auto& end_loop_presence = graph.AddNode("EndLoopBooleanCalculator");
    batch_end >> end_loop_presence.In(kBatchEndTag);
    presence >> end_loop_presence.In(kItemTag);
    auto presences = end_loop_presence[Output<std::vector<bool>>(kIterableTag)];

    auto& end_loop_presence_score = graph.AddNode("EndLoopFloatCalculator");
    batch_end >> end_loop_presence_score.In(kBatchEndTag);
    presence_score >> end_loop_presence_score.In(kItemTag);
    auto presence_scores =
        end_loop_presence_score[Output<std::vector<float>>(kIterableTag)];

    auto& end_loop_segmentation_mask = graph.AddNode("EndLoopImageCalculator");
    batch_end >> end_loop_segmentation_mask.In(kBatchEndTag);
    segmentation_mask >> end_loop_segmentation_mask.In(kItemTag);
    auto segmentation_masks =
        end_loop_segmentation_mask[Output<std::vector<Image>>(kIterableTag)];

    return {{
        /* landmark_lists= */ landmark_lists,
        /* world_landmark_lists= */ world_landmark_lists,
        /* auxiliary_landmark_lists= */ auxiliary_landmark_lists,
        /* pose_rects_next_frame= */ pose_rects_next_frame,
        /* presences= */ presences,
        /* presence_scores= */ presence_scores,
        /* segmentation_masks= */ segmentation_masks,
    }};
  }
};

// clang-format off
REGISTER_MEDIAPIPE_GRAPH(
    ::mediapipe::tasks::vision::pose_landmarker::MultiplePoseLandmarksDetectorGraph);  // NOLINT
// clang-format on

}  // namespace pose_landmarker
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe
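Editor's note: the sketch below is illustrative only and is not part of this commit. Under the assumption that the subgraph and options defined above are available, it shows how a client graph might wire up the MultiplePoseLandmarksDetectorGraph with the api2 builder; the stream names ("image_in", "pose_rects_in"), the model path, and the variable names are placeholders. The test file that follows exercises the same wiring through TaskRunner.

// Illustrative sketch (hypothetical client code, not part of the commit).
// Assumes the aliases used in the file above: Graph, Input, Output, Image,
// NormalizedRect, NormalizedLandmarkList.
mediapipe::api2::builder::Graph graph;
auto& detector = graph.AddNode(
    "mediapipe.tasks.vision.pose_landmarker.MultiplePoseLandmarksDetectorGraph");
auto& opts = detector.GetOptions<
    pose_landmarker::proto::PoseLandmarksDetectorGraphOptions>();
opts.mutable_base_options()->mutable_model_asset()->set_file_name(
    "pose_landmark_lite.tflite");  // Placeholder model path.
opts.set_min_detection_confidence(0.5);

// IMAGE and NORM_RECT are the input tags declared by the subgraph; the
// multi-pose variant takes a vector of per-pose RoIs.
graph[Input<Image>("IMAGE")].SetName("image_in") >> detector.In("IMAGE");
graph[Input<std::vector<NormalizedRect>>("NORM_RECT")].SetName(
    "pose_rects_in") >> detector.In("NORM_RECT");

// Expose the per-pose outputs of interest.
detector.Out("LANDMARKS").SetName("pose_landmarks") >>
    graph[Output<std::vector<NormalizedLandmarkList>>("LANDMARKS")];
detector.Out("PRESENCE").SetName("pose_presence") >>
    graph[Output<std::vector<bool>>("PRESENCE")];

CalculatorGraphConfig config = graph.GetConfig();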
@@ -0,0 +1,366 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <iostream>
#include <memory>
#include <string>
#include <utility>

#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace pose_landmarker {
namespace {

using ::file::Defaults;
using ::file::GetTextProto;
using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::core::TaskRunner;
using ::mediapipe::tasks::vision::DecodeImageFromFile;
using ::mediapipe::tasks::vision::pose_landmarker::proto::
    PoseLandmarksDetectorGraphOptions;
using ::testing::ElementsAreArray;
using ::testing::EqualsProto;
using ::testing::Pointwise;
using ::testing::TestParamInfo;
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;

constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kPoseLandmarkerLiteModel[] = "pose_landmark_lite.tflite";
constexpr char kPoseImage[] = "pose.jpg";
constexpr char kBurgerImage[] = "burger.jpg";

constexpr char kImageTag[] = "IMAGE";
constexpr char kImageName[] = "image_in";
constexpr char kNormRectTag[] = "NORM_RECT";

constexpr char kPoseRectName[] = "pose_rect_in";

constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kLandmarksName[] = "landmarks";
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kWorldLandmarksName[] = "world_landmarks";
constexpr char kAuxLandmarksTag[] = "AUXILIARY_LANDMARKS";
constexpr char kAuxLandmarksName[] = "auxiliary_landmarks";
constexpr char kPoseRectNextFrameTag[] = "POSE_RECT_NEXT_FRAME";
constexpr char kPoseRectNextFrameName[] = "pose_rect_next_frame";
constexpr char kPoseRectsNextFrameTag[] = "POSE_RECTS_NEXT_FRAME";
constexpr char kPoseRectsNextFrameName[] = "pose_rects_next_frame";
constexpr char kPresenceTag[] = "PRESENCE";
constexpr char kPresenceName[] = "presence";
constexpr char kPresenceScoreTag[] = "PRESENCE_SCORE";
constexpr char kPresenceScoreName[] = "presence_score";
constexpr char kSegmentationMaskTag[] = "SEGMENTATION_MASK";
constexpr char kSegmentationMaskName[] = "segmentation_mask";

// Expected pose landmarks positions, in text proto format.
constexpr char kExpectedPoseLandmarksFilename[] =
    "expected_pose_landmarks.prototxt";

constexpr float kLiteModelFractionDiff = 0.05;  // percentage
constexpr float kAbsMargin = 0.03;

// Helper function to create a Single Pose Landmark TaskRunner.
absl::StatusOr<std::unique_ptr<TaskRunner>> CreateSinglePoseTaskRunner(
    absl::string_view model_name) {
  Graph graph;

  auto& pose_landmark_detection = graph.AddNode(
      "mediapipe.tasks.vision.pose_landmarker."
      "SinglePoseLandmarksDetectorGraph");

  auto options = std::make_unique<PoseLandmarksDetectorGraphOptions>();
  options->mutable_base_options()->mutable_model_asset()->set_file_name(
      JoinPath("./", kTestDataDirectory, model_name));
  pose_landmark_detection.GetOptions<PoseLandmarksDetectorGraphOptions>().Swap(
      options.get());

  graph[Input<Image>(kImageTag)].SetName(kImageName) >>
      pose_landmark_detection.In(kImageTag);
  graph[Input<NormalizedRect>(kNormRectTag)].SetName(kPoseRectName) >>
      pose_landmark_detection.In(kNormRectTag);

  pose_landmark_detection.Out(kLandmarksTag).SetName(kLandmarksName) >>
      graph[Output<NormalizedLandmarkList>(kLandmarksTag)];
  pose_landmark_detection.Out(kWorldLandmarksTag)
          .SetName(kWorldLandmarksName) >>
      graph[Output<LandmarkList>(kWorldLandmarksTag)];
  pose_landmark_detection.Out(kAuxLandmarksTag).SetName(kAuxLandmarksName) >>
      graph[Output<LandmarkList>(kAuxLandmarksTag)];
  pose_landmark_detection.Out(kPresenceTag).SetName(kPresenceName) >>
      graph[Output<bool>(kPresenceTag)];
  pose_landmark_detection.Out(kPresenceScoreTag).SetName(kPresenceScoreName) >>
      graph[Output<float>(kPresenceScoreTag)];
  pose_landmark_detection.Out(kSegmentationMaskTag)
          .SetName(kSegmentationMaskName) >>
      graph[Output<Image>(kSegmentationMaskTag)];
  pose_landmark_detection.Out(kPoseRectNextFrameTag)
          .SetName(kPoseRectNextFrameName) >>
      graph[Output<NormalizedRect>(kPoseRectNextFrameTag)];

  return TaskRunner::Create(
      graph.GetConfig(),
      absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
}

// Helper function to create a Multi Pose Landmark TaskRunner.
absl::StatusOr<std::unique_ptr<TaskRunner>> CreateMultiPoseTaskRunner(
    absl::string_view model_name) {
  Graph graph;

  auto& multi_pose_landmark_detection = graph.AddNode(
      "mediapipe.tasks.vision.pose_landmarker."
      "MultiplePoseLandmarksDetectorGraph");

  auto options = std::make_unique<PoseLandmarksDetectorGraphOptions>();
  options->mutable_base_options()->mutable_model_asset()->set_file_name(
      JoinPath("./", kTestDataDirectory, model_name));
  multi_pose_landmark_detection.GetOptions<PoseLandmarksDetectorGraphOptions>()
      .Swap(options.get());

  graph[Input<Image>(kImageTag)].SetName(kImageName) >>
      multi_pose_landmark_detection.In(kImageTag);
  graph[Input<std::vector<NormalizedRect>>(kNormRectTag)].SetName(
      kPoseRectName) >>
      multi_pose_landmark_detection.In(kNormRectTag);

  multi_pose_landmark_detection.Out(kLandmarksTag).SetName(kLandmarksName) >>
      graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
  multi_pose_landmark_detection.Out(kWorldLandmarksTag)
          .SetName(kWorldLandmarksName) >>
      graph[Output<std::vector<LandmarkList>>(kWorldLandmarksTag)];
  multi_pose_landmark_detection.Out(kAuxLandmarksTag)
          .SetName(kAuxLandmarksName) >>
      graph[Output<std::vector<NormalizedLandmarkList>>(kAuxLandmarksTag)];
  multi_pose_landmark_detection.Out(kPresenceTag).SetName(kPresenceName) >>
      graph[Output<std::vector<bool>>(kPresenceTag)];
  multi_pose_landmark_detection.Out(kPresenceScoreTag)
          .SetName(kPresenceScoreName) >>
      graph[Output<std::vector<float>>(kPresenceScoreTag)];
  multi_pose_landmark_detection.Out(kSegmentationMaskTag)
          .SetName(kSegmentationMaskName) >>
      graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
  multi_pose_landmark_detection.Out(kPoseRectsNextFrameTag)
          .SetName(kPoseRectsNextFrameName) >>
      graph[Output<std::vector<NormalizedRect>>(kPoseRectsNextFrameTag)];

  return TaskRunner::Create(
      graph.GetConfig(),
      absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
}

NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) {
  NormalizedLandmarkList expected_landmark_list;
  MP_EXPECT_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, filename),
                            &expected_landmark_list, Defaults()));
  return expected_landmark_list;
}

// Struct holding the parameters for the parameterized PoseLandmarkerTest
// class.
struct SinglePoseTestParams {
  // The name of this test, for convenience when displaying test results.
  std::string test_name;
  // The filename of the model to test.
  std::string input_model_name;
  // The filename of the test image.
  std::string test_image_name;
  // RoI on the image in which to detect the pose.
  NormalizedRect pose_rect;
  // Expected pose presence value.
  bool expected_presence;
  // The expected output landmark positions, in pixel coordinates.
  std::optional<NormalizedLandmarkList> expected_landmarks;
  // The expected segmentation mask.
  Image expected_segmentation_mask;
  // The maximum value difference between expected and detected positions.
  float landmarks_diff_threshold;
};

struct MultiPoseTestParams {
  // The name of this test, for convenience when displaying test results.
  std::string test_name;
  // The filename of the model to test.
  std::string input_model_name;
  // The filename of the test image.
  std::string test_image_name;
  // RoIs on the image in which to detect poses.
  std::vector<NormalizedRect> pose_rects;
  // Expected pose presence values.
  std::vector<bool> expected_presences;
  // The expected output landmark positions, in pixel coordinates.
  std::vector<NormalizedLandmarkList> expected_landmark_lists;
  // The expected segmentation mask Images.
  std::vector<Image> expected_segmentation_masks;
  // The maximum value difference between expected and detected positions.
  float landmarks_diff_threshold;
};

// Helper function to construct a NormalizedRect proto.
NormalizedRect MakePoseRect(float x_center, float y_center, float width,
                            float height, float rotation) {
  NormalizedRect pose_rect;
  pose_rect.set_x_center(x_center);
  pose_rect.set_y_center(y_center);
  pose_rect.set_width(width);
  pose_rect.set_height(height);
  pose_rect.set_rotation(rotation);
  return pose_rect;
}

class PoseLandmarkerTest
    : public testing::TestWithParam<SinglePoseTestParams> {};

TEST_P(PoseLandmarkerTest, Succeeds) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                GetParam().test_image_name)));
  MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateSinglePoseTaskRunner(
                                                GetParam().input_model_name));

  auto output_packets = task_runner->Process(
      {{kImageName, MakePacket<Image>(std::move(image))},
       {kPoseRectName,
        MakePacket<NormalizedRect>(std::move(GetParam().pose_rect))}});
  MP_ASSERT_OK(output_packets);

  const bool presence = (*output_packets)[kPresenceName].Get<bool>();
  ASSERT_EQ(presence, GetParam().expected_presence);

  if (presence) {
    const NormalizedLandmarkList landmarks =
        (*output_packets)[kLandmarksName].Get<NormalizedLandmarkList>();

    if (GetParam().expected_landmarks.has_value()) {
      const NormalizedLandmarkList& expected_landmarks =
          GetParam().expected_landmarks.value();

      EXPECT_THAT(
          landmarks,
          Approximately(Partially(EqualsProto(expected_landmarks)),
                        /*margin=*/kAbsMargin,
                        /*fraction=*/GetParam().landmarks_diff_threshold));
    }
  }
}

class MultiPoseLandmarkerTest
    : public testing::TestWithParam<MultiPoseTestParams> {};

TEST_P(MultiPoseLandmarkerTest, Succeeds) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                GetParam().test_image_name)));
  MP_ASSERT_OK_AND_ASSIGN(
      auto task_runner, CreateMultiPoseTaskRunner(GetParam().input_model_name));

  auto output_packets = task_runner->Process(
      {{kImageName, MakePacket<Image>(std::move(image))},
       {kPoseRectName, MakePacket<std::vector<NormalizedRect>>(
                           std::move(GetParam().pose_rects))}});
  MP_ASSERT_OK(output_packets);

  const std::vector<bool>& presences =
      (*output_packets)[kPresenceName].Get<std::vector<bool>>();
  const std::vector<NormalizedLandmarkList>& landmark_lists =
      (*output_packets)[kLandmarksName]
          .Get<std::vector<NormalizedLandmarkList>>();

  EXPECT_THAT(presences, ElementsAreArray(GetParam().expected_presences));

  EXPECT_THAT(
      landmark_lists,
      Pointwise(Approximately(Partially(EqualsProto()),
                              /*margin=*/kAbsMargin,
                              /*fraction=*/GetParam().landmarks_diff_threshold),
                GetParam().expected_landmark_lists));
}

// TODO: Add additional tests for MP Tasks Pose Graphs.
INSTANTIATE_TEST_SUITE_P(
    PoseLandmarkerTest, PoseLandmarkerTest,
    Values(
        SinglePoseTestParams{
            .test_name = "PoseLandmarkerLiteModel",
            .input_model_name = kPoseLandmarkerLiteModel,
            .test_image_name = kPoseImage,
            .pose_rect = MakePoseRect(0.25, 0.5, 0.5, 1.0, 0),
            .expected_presence = true,
            .expected_landmarks =
                GetExpectedLandmarkList(kExpectedPoseLandmarksFilename),
            .landmarks_diff_threshold = kLiteModelFractionDiff},
        SinglePoseTestParams{
            .test_name = "PoseLandmarkerLiteModelNoPose",
            .input_model_name = kPoseLandmarkerLiteModel,
            .test_image_name = kBurgerImage,
            .pose_rect = MakePoseRect(0.25, 0.5, 0.5, 1.0, 0),
            .expected_presence = false,
            .expected_landmarks = std::nullopt,
            .landmarks_diff_threshold = kLiteModelFractionDiff}),
    [](const TestParamInfo<PoseLandmarkerTest::ParamType>& info) {
      return info.param.test_name;
    });

INSTANTIATE_TEST_SUITE_P(
    MultiPoseLandmarkerTest, MultiPoseLandmarkerTest,
    Values(MultiPoseTestParams{
        .test_name = "MultiPoseLandmarkerLiteModel",
        .input_model_name = kPoseLandmarkerLiteModel,
        .test_image_name = kPoseImage,
        .pose_rects = {MakePoseRect(0.25, 0.5, 0.5, 1.0, 0)},
        .expected_presences = {true},
        .expected_landmark_lists = {GetExpectedLandmarkList(
            kExpectedPoseLandmarksFilename)},
        .landmarks_diff_threshold = kLiteModelFractionDiff,
    }),
    [](const TestParamInfo<MultiPoseLandmarkerTest::ParamType>& info) {
      return info.param.test_name;
    });

}  // namespace
}  // namespace pose_landmarker
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe
mediapipe/tasks/cc/vision/pose_landmarker/proto/BUILD (new file, 31 lines)
@@ -0,0 +1,31 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")

package(default_visibility = [
    "//mediapipe/tasks:internal",
])

licenses(["notice"])

mediapipe_proto_library(
    name = "pose_landmarks_detector_graph_options_proto",
    srcs = ["pose_landmarks_detector_graph_options.proto"],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/tasks/cc/core/proto:base_options_proto",
    ],
)
mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.proto (new file, 38 lines)
@@ -0,0 +1,38 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto2";

package mediapipe.tasks.vision.pose_landmarker.proto;

import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";

option java_package = "com.google.mediapipe.tasks.vision.poselandmarker.proto";
option java_outer_classname = "PoseLandmarksDetectorGraphOptionsProto";

message PoseLandmarksDetectorGraphOptions {
  extend mediapipe.CalculatorOptions {
    optional PoseLandmarksDetectorGraphOptions ext = 518928384;
  }
  // Base options for configuring MediaPipe Tasks, such as specifying the
  // TfLite model file with metadata, accelerator options, etc.
  optional core.proto.BaseOptions base_options = 1;

  // Minimum confidence value ([0.0, 1.0]) for pose presence score to be
  // considered successfully detecting a pose in the image.
  optional float min_detection_confidence = 2 [default = 0.5];
}
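Editor's note: below is an illustrative sketch (not part of this commit) of populating these options from C++ before handing them to a graph node, mirroring the test helpers earlier in this commit; the model path and the node variable name (pose_landmark_detection) are placeholders.

// Illustrative sketch (hypothetical caller code).
auto options = std::make_unique<PoseLandmarksDetectorGraphOptions>();
// Point base_options at the TFLite pose landmark model asset (placeholder path).
options->mutable_base_options()->mutable_model_asset()->set_file_name(
    "mediapipe/tasks/testdata/vision/pose_landmark_lite.tflite");
// Presence scores below this threshold are reported as "no pose present".
options->set_min_detection_confidence(0.5);
// Attach the options to the subgraph node, as the test helpers above do.
pose_landmark_detection.GetOptions<PoseLandmarksDetectorGraphOptions>().Swap(
    options.get());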
mediapipe/tasks/testdata/vision/BUILD (vendored, 3 lines changed)
@@ -77,6 +77,7 @@ mediapipe_files(srcs = [
"portrait_selfie_segmentation_landscape_expected_category_mask.jpg",
|
||||
"pose.jpg",
|
||||
"pose_detection.tflite",
|
||||
"pose_landmark_lite.tflite",
|
||||
"right_hands.jpg",
|
||||
"right_hands_rotated.jpg",
|
||||
"segmentation_golden_rotation0.png",
|
||||
|
@ -185,6 +186,7 @@ filegroup(
|
|||
"mobilenet_v3_small_100_224_embedder.tflite",
|
||||
"palm_detection_full.tflite",
|
||||
"pose_detection.tflite",
|
||||
"pose_landmark_lite.tflite",
|
||||
"selfie_segm_128_128_3.tflite",
|
||||
"selfie_segm_144_256_3.tflite",
|
||||
"selfie_segmentation.tflite",
|
||||
|
@ -199,6 +201,7 @@ filegroup(
|
|||
"expected_left_down_hand_rotated_landmarks.prototxt",
|
||||
"expected_left_up_hand_landmarks.prototxt",
|
||||
"expected_left_up_hand_rotated_landmarks.prototxt",
|
||||
"expected_pose_landmarks.prototxt",
|
||||
"expected_right_down_hand_landmarks.prototxt",
|
||||
"expected_right_up_hand_landmarks.prototxt",
|
||||
"face_geometry_expected_out.pbtxt",
|
||||
|
|
mediapipe/tasks/testdata/vision/expected_pose_landmarks.prototxt (new file, vendored, 231 lines)
@@ -0,0 +1,231 @@
landmark {
  x: 0.5033445
  y: 0.43453366
  z: 0.14119335
  visibility: 0.9981027
  presence: 0.9992255
}
landmark {
  x: 0.49945074
  y: 0.42642897
  z: 0.1261509
  visibility: 0.99845314
  presence: 0.9988668
}
landmark {
  x: 0.49838358
  y: 0.42594656
  z: 0.12615599
  visibility: 0.9983102
  presence: 0.9988244
}
landmark {
  x: 0.49738216
  y: 0.42554975
  z: 0.12622544
  visibility: 0.99852186
  presence: 0.99873894
}
landmark {
  x: 0.5031697
  y: 0.42662272
  z: 0.13347684
  visibility: 0.9978422
  presence: 0.9989411
}
landmark {
  x: 0.5052138
  y: 0.42640454
  z: 0.13340297
  visibility: 0.9973635
  presence: 0.99888366
}
landmark {
  x: 0.50699204
  y: 0.42618936
  z: 0.13343208
  visibility: 0.9977912
  presence: 0.9987791
}
landmark {
  x: 0.50158876
  y: 0.42372653
  z: 0.075136274
  visibility: 0.9976786
  presence: 0.99851876
}
landmark {
  x: 0.51374334
  y: 0.42491674
  z: 0.10752227
  visibility: 0.99754214
  presence: 0.99863297
}
landmark {
  x: 0.5059119
  y: 0.43917143
  z: 0.12961307
  visibility: 0.9958307
  presence: 0.9977532
}
landmark {
  x: 0.5109223
  y: 0.43983036
  z: 0.13892516
  visibility: 0.99557614
  presence: 0.9976047
}
landmark {
  x: 0.5023406
  y: 0.45419085
  z: 0.045139182
  visibility: 0.99948055
  presence: 0.99782556
}
landmark {
  x: 0.5593891
  y: 0.44672203
  z: 0.09991482
  visibility: 0.99833965
  presence: 0.9977375
}
landmark {
  x: 0.49746004
  y: 0.4821021
  z: 0.037247833
  visibility: 0.80266625
  presence: 0.9972971
}
landmark {
  x: 0.57370883
  y: 0.47189793
  z: 0.13776684
  visibility: 0.5926873
  presence: 0.9988753
}
landmark {
  x: 0.5110358
  y: 0.44240227
  z: 0.040228914
  visibility: 0.73293996
  presence: 0.9983026
}
landmark {
  x: 0.55913407
  y: 0.45325255
  z: 0.15570506
  visibility: 0.6228974
  presence: 0.998798
}
landmark {
  x: 0.5122394
  y: 0.42851248
  z: 0.022895059
  visibility: 0.67662305
  presence: 0.9976555
}
landmark {
  x: 0.5534328
  y: 0.45476273
  z: 0.13998204
  visibility: 0.5879434
  presence: 0.99810314
}
landmark {
  x: 0.5166826
  y: 0.42873073
  z: 0.013023325
  visibility: 0.66849846
  presence: 0.9973978
}
landmark {
  x: 0.54080456
  y: 0.45333704
  z: 0.13224734
  visibility: 0.58104885
  presence: 0.99810755
}
landmark {
  x: 0.5170599
  y: 0.43338987
  z: 0.03576146
  visibility: 0.65676475
  presence: 0.99781895
}
landmark {
  x: 0.550342
  y: 0.4529801
  z: 0.15014008
  visibility: 0.57411015
  presence: 0.9985139
}
landmark {
  x: 0.5236847
  y: 0.5765062
  z: -0.03329975
  visibility: 0.9998833
  presence: 0.99935263
}
landmark {
  x: 0.55076087
  y: 0.57722
  z: 0.033408616
  visibility: 0.99978465
  presence: 0.9994066
}
landmark {
  x: 0.56554604
  y: 0.68362844
  z: -0.1572319
  visibility: 0.8825664
  presence: 0.99819404
}
landmark {
  x: 0.6127089
  y: 0.6976384
  z: 0.034268316
  visibility: 0.6231058
  presence: 0.9986853
}
landmark {
  x: 0.63440466
  y: 0.85055786
  z: -0.21429145
  visibility: 0.93542594
  presence: 0.99297804
}
landmark {
  x: 0.68133575
  y: 0.8420663
  z: 0.024820065
  visibility: 0.8150173
  presence: 0.994477
}
landmark {
  x: 0.65322053
  y: 0.8746961
  z: -0.22213714
  visibility: 0.89239687
  presence: 0.986766
}
landmark {
  x: 0.69881815
  y: 0.862424
  z: 0.018177645
  visibility: 0.7516772
  presence: 0.98905355
}
landmark {
  x: 0.62477547
  y: 0.9208759
  z: -0.31597853
  visibility: 0.92574716
  presence: 0.9792296
}
landmark {
  x: 0.6807446
  y: 0.8746925
  z: -0.059982482
  visibility: 0.8126682
  presence: 0.98336893
}
154
third_party/external_files.bzl
vendored
@ -70,6 +70,12 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=166187566369397616783235763936531678737479599640"],
    )

    http_file(
        name = "com_google_mediapipe_BUILD_orig",
        sha256 = "d86b98b82e00dd87cd46bd1429bf5eaa007b500c1a24d9316b73309f2e6c8df8",
        urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD.orig?generation=1679955080207504"],
    )

    http_file(
        name = "com_google_mediapipe_burger_crop_jpg",
        sha256 = "8f58de573f0bf59a49c3d86cfabb9ad4061481f574aa049177e8da3963dddc50",
@ -298,6 +304,12 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_rotated_landmarks.prototxt?generation=1666629476401757"],
    )

    http_file(
        name = "com_google_mediapipe_expected_pose_landmarks_prototxt",
        sha256 = "75dfd2825fc23f51e3906f3a0a050caa8ae9f502cc358af1e7c9fda7ea89c9a5",
        urls = ["https://storage.googleapis.com/mediapipe-assets/expected_pose_landmarks.prototxt?generation=1679955083449778"],
    )

    http_file(
        name = "com_google_mediapipe_expected_right_down_hand_landmarks_prototxt",
        sha256 = "f281b745175aaa7f458def6cf4c89521fb56302dd61a05642b3b4a4f237ffaa3",
@ -942,8 +954,8 @@ def external_files():

    http_file(
        name = "com_google_mediapipe_pose_landmark_lite_tflite",
        sha256 = "f17bfbecadb61c3be1baa8b8d851cc6619c870a87167b32848ad20db306b9d61",
        urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmark_lite.tflite?generation=1661875901231143"],
        sha256 = "13628a7d1c1a0f601ae7202c71ec8edc3ac42db9d15f116c494ff24d1afabdd7",
        urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmark_lite.tflite?generation=1679955090327685"],
    )

    http_file(
@ -1216,6 +1228,18 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/victory_landmarks.pbtxt?generation=1666999366036622"],
    )

    http_file(
        name = "com_google_mediapipe_vit_multiclass_256x256-2022_10_14-xenoformer_f32_tflite",
        sha256 = "768b7dd613c5b9f263b289cdbe1b9bc716f65e92c51e7ae57fae01f97f9658cd",
        urls = ["https://storage.googleapis.com/mediapipe-assets/vit_multiclass_256x256-2022_10_14-xenoformer.f32.tflite?generation=1679955093986149"],
    )

    http_file(
        name = "com_google_mediapipe_vit_multiclass_512x512-2022_12_02_f32_tflite",
        sha256 = "7ac7f0a037cd451b9be8eb25da86339aaba54fa821a0bd44e18768866ed0205a",
        urls = ["https://storage.googleapis.com/mediapipe-assets/vit_multiclass_512x512-2022_12_02.f32.tflite?generation=1679955096856280"],
    )

    http_file(
        name = "com_google_mediapipe_vocab_for_regex_tokenizer_txt",
        sha256 = "b1134b10927a53ce4224bbc30ccf075c9969c94ebf40c368966d1dcf445ca923",
@ -1234,6 +1258,12 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/vocab_with_index.txt?generation=1661875977280658"],
    )

    http_file(
        name = "com_google_mediapipe_w_avg_npy",
        sha256 = "a044e35609986d18a972532f2980e939832b5b7d559659959d11ecc752a58bbe",
        urls = ["https://storage.googleapis.com/mediapipe-assets/w_avg.npy?generation=1679955100435717"],
    )

    http_file(
        name = "com_google_mediapipe_yamnet_audio_classifier_with_metadata_tflite",
        sha256 = "10c95ea3eb9a7bb4cb8bddf6feb023250381008177ac162ce169694d05c317de",
@ -1246,6 +1276,60 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/yamnet_embedding_metadata.tflite?generation=1668295071595506"],
    )

    http_file(
        name = "com_google_mediapipe_decoder_fingerprint_pb",
        sha256 = "0bf6239c4855d78edb60f3349b46cdb2c6f83def64f1b31589b6e298e5cbec3c",
        urls = ["https://storage.googleapis.com/mediapipe-assets/decoder/fingerprint.pb?generation=1679955102906559"],
    )

    http_file(
        name = "com_google_mediapipe_decoder_keras_metadata_pb",
        sha256 = "1631ee698455aea52d4467fe6118800718a86ec49c29f4f3c904785b72f425ff",
        urls = ["https://storage.googleapis.com/mediapipe-assets/decoder/keras_metadata.pb?generation=1679955105294959"],
    )

    http_file(
        name = "com_google_mediapipe_decoder_saved_model_pb",
        sha256 = "b424d30c63548e93390b2944b9bd9dc29773a56197bb462d3bd9e7a0bd1270ff",
        urls = ["https://storage.googleapis.com/mediapipe-assets/decoder/saved_model.pb?generation=1679955107808916"],
    )

    http_file(
        name = "com_google_mediapipe_discriminator_fingerprint_pb",
        sha256 = "1fa6201d253c9218f7054138b9ce273266ce431e00cbce2d74d557f6b97223fd",
        urls = ["https://storage.googleapis.com/mediapipe-assets/discriminator/fingerprint.pb?generation=1679955110094297"],
    )

    http_file(
        name = "com_google_mediapipe_discriminator_keras_metadata_pb",
        sha256 = "59a8601790d615dd37ec24e788743ce737e9999ce6ea6593fcf1ee43f674987f",
        urls = ["https://storage.googleapis.com/mediapipe-assets/discriminator/keras_metadata.pb?generation=1679955112486389"],
    )

    http_file(
        name = "com_google_mediapipe_discriminator_saved_model_pb",
        sha256 = "280d6097a9b3d4c3756e028c597fe3d3c76eb14f76f24d49a22ed7b6df1e3878",
        urls = ["https://storage.googleapis.com/mediapipe-assets/discriminator/saved_model.pb?generation=1679955114873053"],
    )

    http_file(
        name = "com_google_mediapipe_encoder_fingerprint_pb",
        sha256 = "06cb4319f8178edf447a7a2442e89303a14a48cc4fc5ae27354eac2ba11ae120",
        urls = ["https://storage.googleapis.com/mediapipe-assets/encoder/fingerprint.pb?generation=1679955117213209"],
    )

    http_file(
        name = "com_google_mediapipe_encoder_keras_metadata_pb",
        sha256 = "8b1429ee95c130fad0c78077b2b544cd03c9e288658aae93e81df4959b84009e",
        urls = ["https://storage.googleapis.com/mediapipe-assets/encoder/keras_metadata.pb?generation=1679955119546778"],
    )

    http_file(
        name = "com_google_mediapipe_encoder_saved_model_pb",
        sha256 = "ce48392c71485ecd9b142b46e54442581a299df5560102337038b76a62e02a09",
        urls = ["https://storage.googleapis.com/mediapipe-assets/encoder/saved_model.pb?generation=1679955122069362"],
    )

    http_file(
        name = "com_google_mediapipe_gesture_embedder_keras_metadata_pb",
        sha256 = "c76b856101e2284293a5e5963b7c445e407a0b3e56ec63eb78f64d883e51e3aa",
@ -1258,6 +1342,24 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/saved_model.pb?generation=1668550484904822"],
    )

    http_file(
        name = "com_google_mediapipe_mapping_fingerprint_pb",
        sha256 = "6320890f1b9a57e5f4e50e3b56d96fd39d815aa2de51dd1c9b635aa6107d982b",
        urls = ["https://storage.googleapis.com/mediapipe-assets/mapping/fingerprint.pb?generation=1679955124430234"],
    )

    http_file(
        name = "com_google_mediapipe_mapping_keras_metadata_pb",
        sha256 = "22582a2ec1d4883b52f50e628c1a2d69a2610b38d72a48a0bd9939c26be304f6",
        urls = ["https://storage.googleapis.com/mediapipe-assets/mapping/keras_metadata.pb?generation=1679955126858694"],
    )

    http_file(
        name = "com_google_mediapipe_mapping_saved_model_pb",
        sha256 = "6a79de45d00f49110304bf0a6746bc717c45f77824cad22690f700f2fbdc1470",
        urls = ["https://storage.googleapis.com/mediapipe-assets/mapping/saved_model.pb?generation=1679955129259768"],
    )

    http_file(
        name = "com_google_mediapipe_mobilebert_tiny_keras_metadata_pb",
        sha256 = "cef8131a414c602b9d4742ac57f4f90bc5d8a42baec36b65deece884e2d0cf0f",
@ -1306,6 +1408,42 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/saved_model.pb?generation=1661875999264354"],
    )

    http_file(
        name = "com_google_mediapipe_decoder_variables_variables_data-00000-of-00001",
        sha256 = "d720ddf354036f17fa210951f9ebfb009453b244913a493f494f1441cfc2eca3",
        urls = ["https://storage.googleapis.com/mediapipe-assets/decoder/variables/variables.data-00000-of-00001?generation=1679955132326947"],
    )

    http_file(
        name = "com_google_mediapipe_decoder_variables_variables_index",
        sha256 = "245f69af6e53fb8b163059fe9936f57b68a7844e15d696393fcddf94c771dfcc",
        urls = ["https://storage.googleapis.com/mediapipe-assets/decoder/variables/variables.index?generation=1679955134518344"],
    )

    http_file(
        name = "com_google_mediapipe_discriminator_variables_variables_data-00000-of-00001",
        sha256 = "50b00e1898a573588fb0d5d24d74346d99b7153b5d79441d0350c2c6ca89fb02",
        urls = ["https://storage.googleapis.com/mediapipe-assets/discriminator/variables/variables.data-00000-of-00001?generation=1679955138489595"],
    )

    http_file(
        name = "com_google_mediapipe_discriminator_variables_variables_index",
        sha256 = "e5cb4be6442a5741504ce7da9487445637ad89b1f4b6a993bb9e762c7bd5621d",
        urls = ["https://storage.googleapis.com/mediapipe-assets/discriminator/variables/variables.index?generation=1679955140891136"],
    )

    http_file(
        name = "com_google_mediapipe_encoder_variables_variables_data-00000-of-00001",
        sha256 = "09bcd1e2f1c6261bd1842af2da95651d54c9b4b9343eb9b8f0004a97f9bc84bf",
        urls = ["https://storage.googleapis.com/mediapipe-assets/encoder/variables/variables.data-00000-of-00001?generation=1679955144875765"],
    )

    http_file(
        name = "com_google_mediapipe_encoder_variables_variables_index",
        sha256 = "964f5ac6ced7b19f76b7856d9dad47594a5b2fa89c52840f82996b809372aec9",
        urls = ["https://storage.googleapis.com/mediapipe-assets/encoder/variables/variables.index?generation=1679955147123313"],
    )

    http_file(
        name = "com_google_mediapipe_gesture_embedder_variables_variables_data-00000-of-00001",
        sha256 = "c156c9654c9ffb1091bb9f06c71080bd1e428586276d3f39c33fbab27fe0522d",
@ -1318,6 +1456,18 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668550490691823"],
    )

    http_file(
        name = "com_google_mediapipe_mapping_variables_variables_data-00000-of-00001",
        sha256 = "4187055e7f69fcc913ee2b11151a56149dda3017c75621d1e160596bde874c07",
        urls = ["https://storage.googleapis.com/mediapipe-assets/mapping/variables/variables.data-00000-of-00001?generation=1679955149680908"],
    )

    http_file(
        name = "com_google_mediapipe_mapping_variables_variables_index",
        sha256 = "a04fcae7083715613f93ac89943f5fe1f5ba2e6efb9efd14eee7314f25502e4a",
        urls = ["https://storage.googleapis.com/mediapipe-assets/mapping/variables/variables.index?generation=1679955152034297"],
    )

    http_file(
        name = "com_google_mediapipe_mobilebert_tiny_assets_vocab_txt",
        sha256 = "07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3",