Add pose landmarker C++ API.

PiperOrigin-RevId: 523795237
This commit is contained in:
MediaPipe Team 2023-04-12 13:48:02 -07:00 committed by Copybara-Service
parent c7aecb42ff
commit 27c38f00ec
10 changed files with 1473 additions and 7 deletions

View File

@ -18,6 +18,37 @@ package(default_visibility = [
licenses(["notice"])
# Public C++ API target for the pose landmarker task. Exposes the
# PoseLandmarker class (pose_landmarker.h), which wraps the
# pose_landmarker_graph subgraph behind image/video/live-stream entry points.
cc_library(
    name = "pose_landmarker",
    srcs = ["pose_landmarker.cc"],
    hdrs = ["pose_landmarker.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":pose_landmarker_graph",
        ":pose_landmarker_result",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/core:base_options",
        "//mediapipe/tasks/cc/core:base_task_api",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/core:task_runner",
        "//mediapipe/tasks/cc/core:utils",
        "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
        "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
        "//mediapipe/tasks/cc/vision/core:image_processing_options",
        "//mediapipe/tasks/cc/vision/core:running_mode",
        "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
        "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
        "@com_google_absl//absl/status:statusor",
    ],
)
cc_library(
name = "pose_landmarks_detector_graph",
srcs = ["pose_landmarks_detector_graph.cc"],
@ -110,3 +141,15 @@ cc_library(
],
alwayslink = 1,
)
# Task-level result struct for the pose landmarker, plus the proto-to-struct
# conversion helper (ConvertToPoseLandmarkerResult).
cc_library(
    name = "pose_landmarker_result",
    srcs = ["pose_landmarker_result.cc"],
    hdrs = ["pose_landmarker_result.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/tasks/cc/components/containers:landmark",
    ],
)

View File

@ -0,0 +1,307 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.h"
#include <memory>
#include <vector>
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/base_task_api.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace pose_landmarker {
namespace {
using PoseLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision::
pose_landmarker::proto::PoseLandmarkerGraphOptions;
using ::mediapipe::NormalizedRect;
constexpr char kPoseLandmarkerGraphTypeName[] =
"mediapipe.tasks.vision.pose_landmarker.PoseLandmarkerGraph";
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kNormRectStreamName[] = "norm_rect_in";
constexpr char kSegmentationMaskTag[] = "SEGMENTATION_MASK";
constexpr char kSegmentationMaskStreamName[] = "segmentation_mask";
constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS";
constexpr char kNormLandmarksStreamName[] = "norm_landmarks";
constexpr char kPoseWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kPoseWorldLandmarksStreamName[] = "world_landmarks";
constexpr char kPoseAuxiliaryLandmarksTag[] = "AUXILIARY_LANDMARKS";
constexpr char kPoseAuxiliaryLandmarksStreamName[] = "auxiliary_landmarks";
constexpr int kMicroSecondsPerMilliSecond = 1000;
// Creates a MediaPipe graph config that contains a subgraph node of
// "mediapipe.tasks.vision.pose_landmarker.PoseLandmarkerGraph". If the task is
// running in the live stream mode, a "FlowLimiterCalculator" will be added to
// limit the number of frames in flight.
CalculatorGraphConfig CreateGraphConfig(
    std::unique_ptr<PoseLandmarkerGraphOptionsProto> options,
    bool enable_flow_limiting) {
  api2::builder::Graph graph;
  auto& subgraph = graph.AddNode(kPoseLandmarkerGraphTypeName);
  // Swap (rather than copy) the caller-built options into the subgraph node.
  subgraph.GetOptions<PoseLandmarkerGraphOptionsProto>().Swap(options.get());
  // Name the graph-level input streams; they are connected to the subgraph
  // either directly (below) or through the flow limiter.
  graph.In(kImageTag).SetName(kImageInStreamName);
  graph.In(kNormRectTag).SetName(kNormRectStreamName);
  // Expose every subgraph output as a graph output under a stable stream name.
  subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >>
      graph.Out(kSegmentationMaskTag);
  subgraph.Out(kNormLandmarksTag).SetName(kNormLandmarksStreamName) >>
      graph.Out(kNormLandmarksTag);
  subgraph.Out(kPoseWorldLandmarksTag).SetName(kPoseWorldLandmarksStreamName) >>
      graph.Out(kPoseWorldLandmarksTag);
  subgraph.Out(kPoseAuxiliaryLandmarksTag)
          .SetName(kPoseAuxiliaryLandmarksStreamName) >>
      graph.Out(kPoseAuxiliaryLandmarksTag);
  subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
  if (enable_flow_limiting) {
    // Live-stream mode: insert a FlowLimiterCalculator in front of the
    // subgraph, throttled on the NORM_LANDMARKS output.
    return tasks::core::AddFlowLimiterCalculator(
        graph, subgraph, {kImageTag, kNormRectTag}, kNormLandmarksTag);
  }
  // Image/video modes: wire inputs straight into the subgraph.
  graph.In(kImageTag) >> subgraph.In(kImageTag);
  graph.In(kNormRectTag) >> subgraph.In(kNormRectTag);
  return graph.GetConfig();
}
// Converts the user-facing PoseLandmarkerOptions struct to the internal
// PoseLandmarkerGraphOptions proto.
std::unique_ptr<PoseLandmarkerGraphOptionsProto>
ConvertPoseLandmarkerGraphOptionsProto(PoseLandmarkerOptions* options) {
  auto graph_options = std::make_unique<PoseLandmarkerGraphOptionsProto>();
  // Base options (model bundle, acceleration, op resolver, ...).
  *graph_options->mutable_base_options() =
      tasks::core::ConvertBaseOptionsToProto(&(options->base_options));
  // Stream mode is on for both VIDEO and LIVE_STREAM running modes.
  graph_options->mutable_base_options()->set_use_stream_mode(
      options->running_mode != core::RunningMode::IMAGE);
  // Configure pose detector options.
  auto* detector_options =
      graph_options->mutable_pose_detector_graph_options();
  detector_options->set_num_poses(options->num_poses);
  detector_options->set_min_detection_confidence(
      options->min_pose_detection_confidence);
  // Configure pose landmark detector options.
  graph_options->set_min_tracking_confidence(options->min_tracking_confidence);
  auto* landmarks_detector_options =
      graph_options->mutable_pose_landmarks_detector_graph_options();
  landmarks_detector_options->set_min_detection_confidence(
      options->min_pose_presence_confidence);
  return graph_options;
}
} // namespace
// Builds the graph config from the user options, installs a packets callback
// for the live stream mode, and constructs the task through the vision task
// API factory.
absl::StatusOr<std::unique_ptr<PoseLandmarker>> PoseLandmarker::Create(
    std::unique_ptr<PoseLandmarkerOptions> options) {
  auto options_proto = ConvertPoseLandmarkerGraphOptionsProto(options.get());
  tasks::core::PacketsCallback packets_callback = nullptr;
  if (options->result_callback) {
    // Capture by copy: `options` is moved from below, so the lambda must not
    // reference it.
    auto result_callback = options->result_callback;
    bool output_segmentation_masks = options->output_segmentation_masks;
    packets_callback = [=](absl::StatusOr<tasks::core::PacketMap>
                               status_or_packets) {
      if (!status_or_packets.ok()) {
        // Forward the error with a placeholder image and an unset timestamp.
        Image image;
        result_callback(status_or_packets.status(), image,
                        Timestamp::Unset().Value());
        return;
      }
      if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
        return;
      }
      Packet image_packet = status_or_packets.value()[kImageOutStreamName];
      if (status_or_packets.value()[kNormLandmarksStreamName].IsEmpty()) {
        // No pose detected in this frame: report an empty result stamped with
        // the (empty) landmarks packet's timestamp, converted to ms.
        Packet empty_packet =
            status_or_packets.value()[kNormLandmarksStreamName];
        result_callback(
            {PoseLandmarkerResult()}, image_packet.Get<Image>(),
            empty_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
        return;
      }
      Packet segmentation_mask_packet =
          status_or_packets.value()[kSegmentationMaskStreamName];
      Packet pose_landmarks_packet =
          status_or_packets.value()[kNormLandmarksStreamName];
      Packet pose_world_landmarks_packet =
          status_or_packets.value()[kPoseWorldLandmarksStreamName];
      Packet pose_auxiliary_landmarks_packet =
          status_or_packets.value()[kPoseAuxiliaryLandmarksStreamName];
      // Segmentation masks are only read when requested in the options.
      std::optional<std::vector<Image>> segmentation_mask = std::nullopt;
      if (output_segmentation_masks) {
        segmentation_mask = segmentation_mask_packet.Get<std::vector<Image>>();
      }
      result_callback(
          ConvertToPoseLandmarkerResult(
              /* segmentation_mask= */ segmentation_mask,
              /* pose_landmarks= */
              pose_landmarks_packet.Get<std::vector<NormalizedLandmarkList>>(),
              /* pose_world_landmarks= */
              pose_world_landmarks_packet.Get<std::vector<LandmarkList>>(),
              pose_auxiliary_landmarks_packet
                  .Get<std::vector<NormalizedLandmarkList>>()),
          image_packet.Get<Image>(),
          // Graph timestamps are in microseconds; the API reports ms.
          pose_landmarks_packet.Timestamp().Value() /
              kMicroSecondsPerMilliSecond);
    };
  }
  ASSIGN_OR_RETURN(
      std::unique_ptr<PoseLandmarker> pose_landmarker,
      (core::VisionTaskApiFactory::Create<PoseLandmarker,
                                          PoseLandmarkerGraphOptionsProto>(
          CreateGraphConfig(
              std::move(options_proto),
              options->running_mode == core::RunningMode::LIVE_STREAM),
          std::move(options->base_options.op_resolver), options->running_mode,
          std::move(packets_callback))));
  // Remember whether Detect* calls should read the SEGMENTATION_MASK stream.
  pose_landmarker->output_segmentation_masks_ =
      options->output_segmentation_masks;
  return pose_landmarker;
}
// Synchronous, single-image detection (IMAGE running mode).
absl::StatusOr<PoseLandmarkerResult> PoseLandmarker::Detect(
    mediapipe::Image image,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  // Only CPU-backed images are supported.
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  // Rotation only; region-of-interest is rejected (roi_allowed=false).
  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
                   ConvertToNormalizedRect(image_processing_options, image,
                                           /*roi_allowed=*/false));
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessImageData(
          {{kImageInStreamName, MakePacket<Image>(std::move(image))},
           {kNormRectStreamName,
            MakePacket<NormalizedRect>(std::move(norm_rect))}}));
  // An empty landmarks stream means no pose was detected.
  if (output_packets[kNormLandmarksStreamName].IsEmpty()) {
    return {PoseLandmarkerResult()};
  }
  std::optional<std::vector<Image>> segmentation_mask = std::nullopt;
  if (output_segmentation_masks_) {
    segmentation_mask =
        output_packets[kSegmentationMaskStreamName].Get<std::vector<Image>>();
  }
  return ConvertToPoseLandmarkerResult(
      /* segmentation_mask= */
      segmentation_mask,
      /* pose_landmarks= */
      output_packets[kNormLandmarksStreamName]
          .Get<std::vector<mediapipe::NormalizedLandmarkList>>(),
      /* pose_world_landmarks= */
      output_packets[kPoseWorldLandmarksStreamName]
          .Get<std::vector<mediapipe::LandmarkList>>(),
      /* pose_auxiliary_landmarks= */
      output_packets[kPoseAuxiliaryLandmarksStreamName]
          .Get<std::vector<mediapipe::NormalizedLandmarkList>>());
}
// Synchronous, per-frame detection (VIDEO running mode). `timestamp_ms` must
// be monotonically increasing across calls.
absl::StatusOr<PoseLandmarkerResult> PoseLandmarker::DetectForVideo(
    mediapipe::Image image, int64_t timestamp_ms,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  // Only CPU-backed images are supported. Plain string literal rather than
  // absl::StrCat around a single literal: consistent with Detect(), and
  // str_cat.h is neither included here nor listed in the BUILD deps.
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  // Rotation only; region-of-interest is rejected (roi_allowed=false).
  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
                   ConvertToNormalizedRect(image_processing_options, image,
                                           /*roi_allowed=*/false));
  // Graph timestamps are in microseconds.
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessVideoData(
          {{kImageInStreamName,
            MakePacket<Image>(std::move(image))
                .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
           {kNormRectStreamName,
            MakePacket<NormalizedRect>(std::move(norm_rect))
                .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
  // An empty landmarks stream means no pose was detected in this frame.
  if (output_packets[kNormLandmarksStreamName].IsEmpty()) {
    return {PoseLandmarkerResult()};
  }
  std::optional<std::vector<Image>> segmentation_mask = std::nullopt;
  if (output_segmentation_masks_) {
    segmentation_mask =
        output_packets[kSegmentationMaskStreamName].Get<std::vector<Image>>();
  }
  return ConvertToPoseLandmarkerResult(
      /* segmentation_mask= */
      segmentation_mask,
      /* pose_landmarks= */
      output_packets[kNormLandmarksStreamName]
          .Get<std::vector<mediapipe::NormalizedLandmarkList>>(),
      /* pose_world_landmarks= */
      output_packets[kPoseWorldLandmarksStreamName]
          .Get<std::vector<mediapipe::LandmarkList>>(),
      /* pose_auxiliary_landmarks= */
      output_packets[kPoseAuxiliaryLandmarksStreamName]
          .Get<std::vector<mediapipe::NormalizedLandmarkList>>());
}
// Asynchronous detection (LIVE_STREAM running mode); results are delivered to
// the result_callback supplied in PoseLandmarkerOptions. `timestamp_ms` must
// be monotonically increasing across calls.
absl::Status PoseLandmarker::DetectAsync(
    mediapipe::Image image, int64_t timestamp_ms,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  // Only CPU-backed images are supported. Plain string literal rather than
  // absl::StrCat around a single literal: consistent with Detect(), and
  // str_cat.h is neither included here nor listed in the BUILD deps.
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  // Rotation only; region-of-interest is rejected (roi_allowed=false).
  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
                   ConvertToNormalizedRect(image_processing_options, image,
                                           /*roi_allowed=*/false));
  // Graph timestamps are in microseconds.
  return SendLiveStreamData(
      {{kImageInStreamName,
        MakePacket<Image>(std::move(image))
            .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
       {kNormRectStreamName,
        MakePacket<NormalizedRect>(std::move(norm_rect))
            .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
}
} // namespace pose_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,193 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_POSE_LANDMARKER_POSE_LANDMARKER_H_
#define MEDIAPIPE_TASKS_CC_VISION_POSE_LANDMARKER_POSE_LANDMARKER_H_
#include <memory>
#include <optional>
#include <vector>
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace pose_landmarker {
// User-facing options for creating a PoseLandmarker task.
struct PoseLandmarkerOptions {
  // Base options for configuring MediaPipe Tasks library, such as specifying
  // the TfLite model bundle file with metadata, accelerator options, op
  // resolver, etc.
  tasks::core::BaseOptions base_options;
  // The running mode of the task. Default to the image mode.
  // PoseLandmarker has three running modes:
  // 1) The image mode for detecting pose landmarks on single image inputs.
  // 2) The video mode for detecting pose landmarks on the decoded frames of a
  //    video.
  // 3) The live stream mode for detecting pose landmarks on the live stream of
  //    input data, such as from camera. In this mode, the "result_callback"
  //    below must be specified to receive the detection results asynchronously.
  core::RunningMode running_mode = core::RunningMode::IMAGE;
  // The maximum number of poses that can be detected by the PoseLandmarker.
  int num_poses = 1;
  // The minimum confidence score for the pose detection to be considered
  // successful.
  float min_pose_detection_confidence = 0.5;
  // The minimum confidence score of pose presence score in the pose landmark
  // detection.
  float min_pose_presence_confidence = 0.5;
  // The minimum confidence score for the pose tracking to be considered
  // successful.
  float min_tracking_confidence = 0.5;
  // The user-defined result callback for processing live stream data.
  // The result callback should only be specified when the running mode is set
  // to RunningMode::LIVE_STREAM. Arguments: the detection result (or error),
  // the input image, and the input timestamp in milliseconds.
  // Uses standard int64_t (matching the int64_t timestamps in
  // pose_landmarker.cc) instead of the non-standard `int64` alias.
  std::function<void(absl::StatusOr<PoseLandmarkerResult>, const Image&,
                     int64_t)>
      result_callback = nullptr;
  // Whether to output segmentation masks.
  bool output_segmentation_masks = false;
};
// Performs pose landmarks detection on the given image.
//
// This API expects a pre-trained pose landmarker model asset bundle.
//
// Inputs:
// Image
// - The image that pose landmarks detection runs on.
// std::optional<ImageProcessingOptions>
// - If provided, can be used to specify the rotation to apply to the image
// before performing pose landmarks detection, by setting its 'rotation'
// field in radians (e.g. 'M_PI / 2' for a 90° anti-clockwise rotation).
// Note that specifying a region-of-interest using the 'x_center',
// 'y_center', 'width' and 'height' fields is NOT supported and will
// result in an invalid argument error being returned.
// Outputs:
// PoseLandmarkerResult
// - The pose landmarks detection results.
class PoseLandmarker : tasks::vision::core::BaseVisionTaskApi {
 public:
  using BaseVisionTaskApi::BaseVisionTaskApi;
  // Creates a PoseLandmarker from a PoseLandmarkerOptions to process image data
  // or streaming data. Pose landmarker can be created with one of the following
  // three running modes:
  // 1) Image mode for detecting pose landmarks on single image inputs. Users
  //    provide mediapipe::Image to the `Detect` method, and will receive the
  //    detected pose landmarks results as the return value.
  // 2) Video mode for detecting pose landmarks on the decoded frames of a
  //    video. Users call `DetectForVideo` method, and will receive the detected
  //    pose landmarks results as the return value.
  // 3) Live stream mode for detecting pose landmarks on the live stream of the
  //    input data, such as from camera. Users call `DetectAsync` to push the
  //    image data into the PoseLandmarker, the detected results along with the
  //    input timestamp and the image that pose landmarker runs on will be
  //    available in the result callback when the pose landmarker finishes the
  //    work.
  static absl::StatusOr<std::unique_ptr<PoseLandmarker>> Create(
      std::unique_ptr<PoseLandmarkerOptions> options);
  // Performs pose landmarks detection on the given image.
  // Only use this method when the PoseLandmarker is created with the image
  // running mode.
  //
  // The optional 'image_processing_options' parameter can be used to specify
  // the rotation to apply to the image before performing detection, by setting
  // its 'rotation_degrees' field. Note that specifying a region-of-interest
  // using the 'region_of_interest' field is NOT supported and will result in an
  // invalid argument error being returned.
  //
  // The image can be of any size with format RGB or RGBA.
  // TODO: Describes how the input image will be preprocessed
  // after the yuv support is implemented.
  absl::StatusOr<PoseLandmarkerResult> Detect(
      Image image,
      std::optional<core::ImageProcessingOptions> image_processing_options =
          std::nullopt);
  // Performs pose landmarks detection on the provided video frame.
  // Only use this method when the PoseLandmarker is created with the video
  // running mode.
  //
  // The optional 'image_processing_options' parameter can be used to specify
  // the rotation to apply to the image before performing detection, by setting
  // its 'rotation_degrees' field. Note that specifying a region-of-interest
  // using the 'region_of_interest' field is NOT supported and will result in an
  // invalid argument error being returned.
  //
  // The image can be of any size with format RGB or RGBA. It's required to
  // provide the video frame's timestamp (in milliseconds). The input timestamps
  // must be monotonically increasing.
  // Timestamps use standard int64_t, matching the definitions in
  // pose_landmarker.cc (the original declaration used the non-standard
  // `int64` alias).
  absl::StatusOr<PoseLandmarkerResult> DetectForVideo(
      Image image, int64_t timestamp_ms,
      std::optional<core::ImageProcessingOptions> image_processing_options =
          std::nullopt);
  // Sends live image data to perform pose landmarks detection, and the results
  // will be available via the "result_callback" provided in the
  // PoseLandmarkerOptions. Only use this method when the PoseLandmarker
  // is created with the live stream running mode.
  //
  // The image can be of any size with format RGB or RGBA. It's required to
  // provide a timestamp (in milliseconds) to indicate when the input image is
  // sent to the pose landmarker. The input timestamps must be monotonically
  // increasing.
  //
  // The optional 'image_processing_options' parameter can be used to specify
  // the rotation to apply to the image before performing detection, by setting
  // its 'rotation_degrees' field. Note that specifying a region-of-interest
  // using the 'region_of_interest' field is NOT supported and will result in an
  // invalid argument error being returned.
  //
  // The "result_callback" provides
  //   - A vector of PoseLandmarkerResult, each is the detected results
  //     for a input frame.
  //   - The const reference to the corresponding input image that the pose
  //     landmarker runs on. Note that the const reference to the image will no
  //     longer be valid when the callback returns. To access the image data
  //     outside of the callback, callers need to make a copy of the image.
  //   - The input timestamp in milliseconds.
  absl::Status DetectAsync(Image image, int64_t timestamp_ms,
                           std::optional<core::ImageProcessingOptions>
                               image_processing_options = std::nullopt);
  // Shuts down the PoseLandmarker when all works are done.
  absl::Status Close() { return runner_->Close(); }

 private:
  // Whether Detect* methods should read the SEGMENTATION_MASK output stream;
  // set by Create() from PoseLandmarkerOptions::output_segmentation_masks.
  bool output_segmentation_masks_;
};
} // namespace pose_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_POSE_LANDMARKER_POSE_LANDMARKER_H_

View File

@ -0,0 +1,56 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.h"
#include <algorithm>
#include <optional>
#include <utility>
#include <vector>

#include "mediapipe/tasks/cc/components/containers/landmark.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace pose_landmarker {
// Converts raw MediaPipe landmark protos (one list per detected pose) into the
// task-level PoseLandmarkerResult. `segmentation_masks` is optional and is
// moved into the result.
PoseLandmarkerResult ConvertToPoseLandmarkerResult(
    std::optional<std::vector<mediapipe::Image>> segmentation_masks,
    const std::vector<mediapipe::NormalizedLandmarkList>& pose_landmarks_proto,
    const std::vector<mediapipe::LandmarkList>& pose_world_landmarks_proto,
    const std::vector<mediapipe::NormalizedLandmarkList>&
        pose_auxiliary_landmarks_proto) {
  PoseLandmarkerResult result;
  // `segmentation_masks` is taken by value: move it instead of copying the
  // (potentially large) vector of images.
  result.segmentation_masks = std::move(segmentation_masks);
  result.pose_landmarks.resize(pose_landmarks_proto.size());
  result.pose_world_landmarks.resize(pose_world_landmarks_proto.size());
  result.pose_auxiliary_landmarks.resize(pose_auxiliary_landmarks_proto.size());
  // Per-pose proto -> container conversions.
  std::transform(pose_landmarks_proto.begin(), pose_landmarks_proto.end(),
                 result.pose_landmarks.begin(),
                 components::containers::ConvertToNormalizedLandmarks);
  std::transform(pose_world_landmarks_proto.begin(),
                 pose_world_landmarks_proto.end(),
                 result.pose_world_landmarks.begin(),
                 components::containers::ConvertToLandmarks);
  std::transform(pose_auxiliary_landmarks_proto.begin(),
                 pose_auxiliary_landmarks_proto.end(),
                 result.pose_auxiliary_landmarks.begin(),
                 components::containers::ConvertToNormalizedLandmarks);
  return result;
}
} // namespace pose_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,57 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Fixed: the include guard was copy-pasted from hand_landmarker_result.h
// (HAND_LANDMARKER_*), which collides with that header's guard and would make
// one of the two definitions silently disappear when both are included.
// Also adds the missing <optional> include and drops a stray commented-out
// duplicate include.
#ifndef MEDIAPIPE_TASKS_CC_VISION_POSE_LANDMARKER_POSE_LANDMARKER_RESULT_H_
#define MEDIAPIPE_TASKS_CC_VISION_POSE_LANDMARKER_POSE_LANDMARKER_RESULT_H_

#include <optional>
#include <vector>

#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/tasks/cc/components/containers/landmark.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace pose_landmarker {

// The pose landmarks detection result from PoseLandmarker, where each vector
// element represents a single pose detected in the image.
struct PoseLandmarkerResult {
  // Segmentation masks for pose.
  std::optional<std::vector<Image>> segmentation_masks;
  // Detected pose landmarks in normalized image coordinates.
  std::vector<components::containers::NormalizedLandmarks> pose_landmarks;
  // Detected pose landmarks in world coordinates.
  std::vector<components::containers::Landmarks> pose_world_landmarks;
  // Detected auxiliary landmarks, used for deriving ROI for next frame.
  std::vector<components::containers::NormalizedLandmarks>
      pose_auxiliary_landmarks;
};

// Converts raw MediaPipe landmark protos (one list per detected pose) into
// the task-level PoseLandmarkerResult.
PoseLandmarkerResult ConvertToPoseLandmarkerResult(
    std::optional<std::vector<mediapipe::Image>> segmentation_mask,
    const std::vector<mediapipe::NormalizedLandmarkList>& pose_landmarks_proto,
    const std::vector<mediapipe::LandmarkList>& pose_world_landmarks_proto,
    const std::vector<mediapipe::NormalizedLandmarkList>&
        pose_auxiliary_landmarks_proto);

}  // namespace pose_landmarker
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe

#endif  // MEDIAPIPE_TASKS_CC_VISION_POSE_LANDMARKER_POSE_LANDMARKER_RESULT_H_

View File

@ -0,0 +1,98 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.h"
#include <optional>
#include <vector>
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/landmark.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace pose_landmarker {
// Verifies that ConvertToPoseLandmarkerResult faithfully copies one pose's
// normalized, world, and auxiliary landmarks from their proto representations
// into the task-level result containers.
TEST(ConvertFromProto, Succeeds) {
  Image segmentation_mask;
  // One normalized landmark (image coordinates).
  mediapipe::NormalizedLandmarkList normalized_landmark_list_proto;
  mediapipe::NormalizedLandmark& normalized_landmark_proto =
      *normalized_landmark_list_proto.add_landmark();
  normalized_landmark_proto.set_x(0.1);
  normalized_landmark_proto.set_y(0.2);
  normalized_landmark_proto.set_z(0.3);
  // One world landmark (metric coordinates).
  mediapipe::LandmarkList world_landmark_list_proto;
  mediapipe::Landmark& landmark_proto =
      *world_landmark_list_proto.add_landmark();
  landmark_proto.set_x(3.1);
  landmark_proto.set_y(5.2);
  landmark_proto.set_z(4.3);
  // One auxiliary landmark (used for next-frame ROI derivation).
  mediapipe::NormalizedLandmarkList auxiliary_landmark_list_proto;
  mediapipe::NormalizedLandmark& auxiliary_landmark_proto =
      *auxiliary_landmark_list_proto.add_landmark();
  auxiliary_landmark_proto.set_x(0.5);
  auxiliary_landmark_proto.set_y(0.5);
  auxiliary_landmark_proto.set_z(0.5);
  // Wrap each proto in a single-pose vector, mirroring graph output.
  std::vector<Image> segmentation_masks_lists = {segmentation_mask};
  std::vector<mediapipe::NormalizedLandmarkList> normalized_landmarks_lists = {
      normalized_landmark_list_proto};
  std::vector<mediapipe::LandmarkList> world_landmarks_lists = {
      world_landmark_list_proto};
  std::vector<mediapipe::NormalizedLandmarkList> auxiliary_landmarks_lists = {
      auxiliary_landmark_list_proto};
  PoseLandmarkerResult pose_landmarker_result = ConvertToPoseLandmarkerResult(
      segmentation_masks_lists, normalized_landmarks_lists,
      world_landmarks_lists, auxiliary_landmarks_lists);
  // The trailing std::nullopt fields are visibility/presence/name, unset here.
  EXPECT_EQ(pose_landmarker_result.pose_landmarks.size(), 1);
  EXPECT_EQ(pose_landmarker_result.pose_landmarks[0].landmarks.size(), 1);
  EXPECT_THAT(pose_landmarker_result.pose_landmarks[0].landmarks[0],
              testing::FieldsAre(testing::FloatEq(0.1), testing::FloatEq(0.2),
                                 testing::FloatEq(0.3), std::nullopt,
                                 std::nullopt, std::nullopt));
  EXPECT_EQ(pose_landmarker_result.pose_world_landmarks.size(), 1);
  EXPECT_EQ(pose_landmarker_result.pose_world_landmarks[0].landmarks.size(), 1);
  EXPECT_THAT(pose_landmarker_result.pose_world_landmarks[0].landmarks[0],
              testing::FieldsAre(testing::FloatEq(3.1), testing::FloatEq(5.2),
                                 testing::FloatEq(4.3), std::nullopt,
                                 std::nullopt, std::nullopt));
  EXPECT_EQ(pose_landmarker_result.pose_auxiliary_landmarks.size(), 1);
  EXPECT_EQ(pose_landmarker_result.pose_auxiliary_landmarks[0].landmarks.size(),
            1);
  EXPECT_THAT(pose_landmarker_result.pose_auxiliary_landmarks[0].landmarks[0],
              testing::FieldsAre(testing::FloatEq(0.5), testing::FloatEq(0.5),
                                 testing::FloatEq(0.5), std::nullopt,
                                 std::nullopt, std::nullopt));
}
} // namespace pose_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,470 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.h"
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/landmark.h"
#include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "util/tuple/dump_vars.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace pose_landmarker {
namespace {
using ::file::Defaults;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::ConvertToLandmarks;
using ::mediapipe::tasks::components::containers::ConvertToNormalizedLandmarks;
using ::mediapipe::tasks::components::containers::RectF;
// Fixed: the alias was missing the `components` namespace segment; the proto
// lives in mediapipe::tasks::components::containers::proto (see the
// landmarks_detection_result.pb.h include above).
using ::mediapipe::tasks::components::containers::proto::
    LandmarksDetectionResult;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr;
using ::testing::Optional;
using ::testing::TestParamInfo;
using ::testing::TestWithParam;
using ::testing::Values;
// Locations of the test assets and golden files.
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kPoseLandmarkerBundleAsset[] = "pose_landmarker.task";
constexpr char kPoseLandmarksFilename[] = "pose_landmarks.pbtxt";
constexpr char kPoseImage[] = "pose.jpg";
constexpr char kBurgerImage[] = "burger.jpg";
// Absolute x/y tolerances used when comparing landmarks against goldens.
constexpr float kLandmarksAbsMargin = 0.03;
constexpr float kLandmarksOnVideoAbsMargin = 0.03;
// Loads a LandmarksDetectionResult golden proto from a pbtxt file in the test
// data directory. The z coordinate of every landmark is stripped because it is
// not used in correctness testing; in video or live-stream mode the z values
// vary a lot during tracking from frame to frame.
LandmarksDetectionResult GetLandmarksDetectionResult(
    absl::string_view landmarks_file_name) {
  LandmarksDetectionResult detection_result;
  MP_EXPECT_OK(GetTextProto(
      file::JoinPath("./", kTestDataDirectory, landmarks_file_name),
      &detection_result, Defaults()));
  // Only x/y are compared against goldens, so drop z from each landmark.
  for (auto& landmark :
       *detection_result.mutable_landmarks()->mutable_landmark()) {
    landmark.clear_z();
  }
  return detection_result;
}
// Builds the expected PoseLandmarkerResult from a list of golden pbtxt files,
// one file per detected pose. Normalized (image-space) landmarks and world
// landmarks are populated; auxiliary landmarks are left empty.
PoseLandmarkerResult GetExpectedPoseLandmarkerResult(
    const std::vector<absl::string_view>& landmarks_file_names) {
  PoseLandmarkerResult expected;
  for (const auto& landmarks_file_name : landmarks_file_names) {
    const LandmarksDetectionResult detection_result =
        GetLandmarksDetectionResult(landmarks_file_name);
    expected.pose_landmarks.push_back(
        ConvertToNormalizedLandmarks(detection_result.landmarks()));
    expected.pose_world_landmarks.push_back(
        ConvertToLandmarks(detection_result.world_landmarks()));
  }
  return expected;
}
// Matches a list of per-pose landmark lists against `expected_landmarks`,
// comparing only x/y coordinates with absolute tolerance `toleration`.
// Logs a short diagnostic via LOG(INFO) before failing.
MATCHER_P2(LandmarksMatches, expected_landmarks, toleration, "") {
  // Guard against out-of-bounds access into `expected_landmarks` when the
  // number of detected poses differs from the expectation. The original code
  // indexed expected_landmarks[i] without this check.
  if (arg.size() != expected_landmarks.size()) {
    LOG(INFO) << "sizes not equal";
    return false;
  }
  for (int i = 0; i < arg.size(); i++) {
    // Check the landmark counts once per pose, not once per landmark as
    // before (the old check sat inside the inner loop).
    if (arg[i].landmarks.size() != expected_landmarks[i].landmarks.size()) {
      LOG(INFO) << "sizes not equal";
      return false;
    }
    for (int j = 0; j < arg[i].landmarks.size(); j++) {
      if (std::abs(arg[i].landmarks[j].x -
                   expected_landmarks[i].landmarks[j].x) > toleration ||
          std::abs(arg[i].landmarks[j].y -
                   expected_landmarks[i].landmarks[j].y) > toleration) {
        LOG(INFO) << DUMP_VARS(arg[i].landmarks[j].x,
                               expected_landmarks[i].landmarks[j].x);
        LOG(INFO) << DUMP_VARS(arg[i].landmarks[j].y,
                               expected_landmarks[i].landmarks[j].y);
        return false;
      }
    }
  }
  return true;
}
// Asserts that `actual_results` agrees with `expected_results`: the same
// number of poses and, for each landmark, x/y coordinates within `margin`.
// Only the normalized pose landmarks are compared.
void ExpectPoseLandmarkerResultsCorrect(
    const PoseLandmarkerResult& actual_results,
    const PoseLandmarkerResult& expected_results, float margin) {
  const auto& actual = actual_results.pose_landmarks;
  const auto& expected = expected_results.pose_landmarks;
  ASSERT_EQ(actual.size(), expected.size());
  // Nothing further to compare when no poses are expected or detected.
  if (actual.empty()) {
    return;
  }
  ASSERT_GE(actual.size(), 1);
  EXPECT_THAT(actual, LandmarksMatches(expected, margin));
}
} // namespace
// Parameter bundle for the image / video / live-stream parameterized test
// suites below.
struct TestParams {
  // The name of this test, for convenience when displaying test results.
  std::string test_name;
  // The filename of test image.
  std::string test_image_name;
  // The filename of test model.
  std::string test_model_file;
  // The rotation to apply to the test image before processing, in degrees
  // clockwise.
  int rotation;
  // Expected results from the pose landmarker model output.
  PoseLandmarkerResult expected_results;
};
// Parameterized fixture for tests that run the landmarker in IMAGE mode.
class ImageModeTest : public testing::TestWithParam<TestParams> {};
// Verifies that the video- and live-stream-only methods are rejected with
// kInvalidArgument when the landmarker was created in image mode.
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kPoseImage)));
  auto landmarker_options = std::make_unique<PoseLandmarkerOptions>();
  landmarker_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kPoseLandmarkerBundleAsset);
  landmarker_options->running_mode = core::RunningMode::IMAGE;
  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<PoseLandmarker> pose_landmarker,
      PoseLandmarker::Create(std::move(landmarker_options)));
  // DetectForVideo is only available in video mode.
  auto result = pose_landmarker->DetectForVideo(image, 0);
  EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(result.status().message(),
              HasSubstr("not initialized with the video mode"));
  EXPECT_THAT(result.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
  // DetectAsync is only available in live stream mode.
  result = pose_landmarker->DetectAsync(image, 0);
  EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(result.status().message(),
              HasSubstr("not initialized with the live stream mode"));
  EXPECT_THAT(result.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
  MP_ASSERT_OK(pose_landmarker->Close());
}
// Verifies that specifying a region of interest is rejected: the pose
// landmarker task does not support region-of-interest processing.
TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kPoseImage)));
  auto options = std::make_unique<PoseLandmarkerOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kPoseLandmarkerBundleAsset);
  options->running_mode = core::RunningMode::IMAGE;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PoseLandmarker> pose_landmarker,
                          PoseLandmarker::Create(std::move(options)));
  RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
  ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
  auto results = pose_landmarker->Detect(image, image_processing_options);
  EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(results.status().message(),
              HasSubstr("This task doesn't support region-of-interest"));
  EXPECT_THAT(
      results.status().GetPayload(kMediaPipeTasksPayload),
      Optional(absl::Cord(absl::StrCat(
          MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
  // Release task resources, consistent with the other tests in this file
  // (this test previously leaked the runner by never calling Close()).
  MP_ASSERT_OK(pose_landmarker->Close());
}
// Runs the landmarker in image mode on each TestParams entry and checks the
// detected landmarks against the golden results.
TEST_P(ImageModeTest, Succeeds) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                GetParam().test_image_name)));
  auto options = std::make_unique<PoseLandmarkerOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, GetParam().test_model_file);
  options->running_mode = core::RunningMode::IMAGE;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PoseLandmarker> pose_landmarker,
                          PoseLandmarker::Create(std::move(options)));
  PoseLandmarkerResult detection_result;
  if (GetParam().rotation == 0) {
    MP_ASSERT_OK_AND_ASSIGN(detection_result, pose_landmarker->Detect(image));
  } else {
    // Apply the requested clockwise rotation before detection.
    ImageProcessingOptions processing_options;
    processing_options.rotation_degrees = GetParam().rotation;
    MP_ASSERT_OK_AND_ASSIGN(
        detection_result, pose_landmarker->Detect(image, processing_options));
  }
  ExpectPoseLandmarkerResultsCorrect(
      detection_result, GetParam().expected_results, kLandmarksAbsMargin);
  MP_ASSERT_OK(pose_landmarker->Close());
}
// Instantiates the image-mode suite with one pose case and one no-pose case.
// NOTE(review): the "PoseGestureTest" prefix looks copied from the gesture
// recognizer tests — consider renaming to "PoseLandmarkerTest".
INSTANTIATE_TEST_SUITE_P(
    PoseGestureTest, ImageModeTest,
    Values(TestParams{
               /* test_name= */ "Pose",
               /* test_image_name= */ kPoseImage,
               /* test_model_file= */ kPoseLandmarkerBundleAsset,
               /* rotation= */ 0,
               /* expected_results = */
               GetExpectedPoseLandmarkerResult({kPoseLandmarksFilename}),
           },
           TestParams{
               /* test_name= */ "NoPoses",
               /* test_image_name= */ kBurgerImage,
               /* test_model_file= */ kPoseLandmarkerBundleAsset,
               /* rotation= */ 0,
               /* expected_results = */
               {{}, {}, {}},
           }),
    [](const TestParamInfo<ImageModeTest::ParamType>& info) {
      return info.param.test_name;
    });
// Parameterized fixture for tests that run the landmarker in VIDEO mode.
class VideoModeTest : public testing::TestWithParam<TestParams> {};
// Verifies that the image- and live-stream-only methods are rejected with
// kInvalidArgument when the landmarker was created in video mode.
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kPoseImage)));
  auto landmarker_options = std::make_unique<PoseLandmarkerOptions>();
  landmarker_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kPoseLandmarkerBundleAsset);
  landmarker_options->running_mode = core::RunningMode::VIDEO;
  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<PoseLandmarker> pose_landmarker,
      PoseLandmarker::Create(std::move(landmarker_options)));
  // Detect is only available in image mode.
  auto result = pose_landmarker->Detect(image);
  EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(result.status().message(),
              HasSubstr("not initialized with the image mode"));
  EXPECT_THAT(result.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
  // DetectAsync is only available in live stream mode.
  result = pose_landmarker->DetectAsync(image, 0);
  EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(result.status().message(),
              HasSubstr("not initialized with the live stream mode"));
  EXPECT_THAT(result.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
  MP_ASSERT_OK(pose_landmarker->Close());
}
// Runs the landmarker in video mode on 100 frames of the same image and
// checks each frame's landmarks against the golden results.
TEST_P(VideoModeTest, Succeeds) {
  const int iterations = 100;
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                GetParam().test_image_name)));
  auto options = std::make_unique<PoseLandmarkerOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, GetParam().test_model_file);
  options->running_mode = core::RunningMode::VIDEO;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PoseLandmarker> pose_landmarker,
                          PoseLandmarker::Create(std::move(options)));
  const auto expected_results = GetParam().expected_results;
  // The loop index doubles as the (strictly increasing) frame timestamp.
  for (int i = 0; i < iterations; ++i) {
    PoseLandmarkerResult pose_landmarker_results;
    if (GetParam().rotation != 0) {
      ImageProcessingOptions image_processing_options;
      image_processing_options.rotation_degrees = GetParam().rotation;
      MP_ASSERT_OK_AND_ASSIGN(
          pose_landmarker_results,
          pose_landmarker->DetectForVideo(image, i, image_processing_options));
    } else {
      MP_ASSERT_OK_AND_ASSIGN(pose_landmarker_results,
                              pose_landmarker->DetectForVideo(image, i));
    }
    // Removed a leftover debug statement (`LOG(INFO) << i;`) that spammed
    // 100 log lines per test case.
    ExpectPoseLandmarkerResultsCorrect(
        pose_landmarker_results, expected_results, kLandmarksOnVideoAbsMargin);
  }
  MP_ASSERT_OK(pose_landmarker->Close());
}
// TODO Add additional tests for MP Tasks Pose Graphs
// TODO Investigate PoseLandmarker performance in VideoMode.
// Instantiates the video-mode suite with one pose case and one no-pose case.
INSTANTIATE_TEST_SUITE_P(
    PoseGestureTest, VideoModeTest,
    Values(TestParams{
               /* test_name= */ "Pose",
               /* test_image_name= */ kPoseImage,
               /* test_model_file= */ kPoseLandmarkerBundleAsset,
               /* rotation= */ 0,
               /* expected_results = */
               GetExpectedPoseLandmarkerResult({kPoseLandmarksFilename}),
           },
           TestParams{
               /* test_name= */ "NoPoses",
               /* test_image_name= */ kBurgerImage,
               /* test_model_file= */ kPoseLandmarkerBundleAsset,
               /* rotation= */ 0,
               /* expected_results = */
               {{}, {}, {}},
           }),
    // Name the lambda's parameter via this suite's own fixture (the original
    // copy-pasted ImageModeTest::ParamType; the underlying TestParams type is
    // identical, so behavior is unchanged).
    [](const TestParamInfo<VideoModeTest::ParamType>& info) {
      return info.param.test_name;
    });
// Parameterized fixture for tests that run the landmarker in LIVE_STREAM mode.
class LiveStreamModeTest : public testing::TestWithParam<TestParams> {};
// Verifies that the image- and video-only methods are rejected with
// kInvalidArgument when the landmarker was created in live stream mode.
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kPoseImage)));
  auto landmarker_options = std::make_unique<PoseLandmarkerOptions>();
  landmarker_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kPoseLandmarkerBundleAsset);
  landmarker_options->running_mode = core::RunningMode::LIVE_STREAM;
  // Live stream mode requires a result callback; a no-op suffices here.
  landmarker_options->result_callback =
      [](absl::StatusOr<PoseLandmarkerResult> results, const Image& image,
         int64_t timestamp_ms) {};
  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<PoseLandmarker> pose_landmarker,
      PoseLandmarker::Create(std::move(landmarker_options)));
  // Detect is only available in image mode.
  auto result = pose_landmarker->Detect(image);
  EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(result.status().message(),
              HasSubstr("not initialized with the image mode"));
  EXPECT_THAT(result.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
  // DetectForVideo is only available in video mode.
  result = pose_landmarker->DetectForVideo(image, 0);
  EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(result.status().message(),
              HasSubstr("not initialized with the video mode"));
  EXPECT_THAT(result.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
  MP_ASSERT_OK(pose_landmarker->Close());
}
// Runs the landmarker in live stream mode, feeding the same image at
// monotonically increasing timestamps, then verifies every callback output:
// landmark correctness, reported image size, and timestamp ordering.
TEST_P(LiveStreamModeTest, Succeeds) {
  const int iterations = 100;
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                GetParam().test_image_name)));
  auto options = std::make_unique<PoseLandmarkerOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, GetParam().test_model_file);
  options->running_mode = core::RunningMode::LIVE_STREAM;
  std::vector<PoseLandmarkerResult> pose_landmarker_results;
  std::vector<std::pair<int, int>> image_sizes;
  std::vector<int64_t> timestamps;
  // The callback records every successful result together with the input
  // image size and timestamp for verification after Close().
  options->result_callback = [&pose_landmarker_results, &image_sizes,
                              &timestamps](
                                 absl::StatusOr<PoseLandmarkerResult> results,
                                 const Image& image, int64_t timestamp_ms) {
    MP_ASSERT_OK(results.status());
    pose_landmarker_results.push_back(std::move(results.value()));
    image_sizes.push_back({image.width(), image.height()});
    timestamps.push_back(timestamp_ms);
  };
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PoseLandmarker> pose_landmarker,
                          PoseLandmarker::Create(std::move(options)));
  for (int i = 0; i < iterations; ++i) {
    // Removed an unused local `PoseLandmarkerResult pose_landmarker_results;`
    // that shadowed the results vector above; DetectAsync returns its output
    // through the callback, not a return value.
    if (GetParam().rotation != 0) {
      ImageProcessingOptions image_processing_options;
      image_processing_options.rotation_degrees = GetParam().rotation;
      MP_ASSERT_OK(
          pose_landmarker->DetectAsync(image, i, image_processing_options));
    } else {
      MP_ASSERT_OK(pose_landmarker->DetectAsync(image, i));
    }
  }
  MP_ASSERT_OK(pose_landmarker->Close());
  // Due to the flow limiter, the total of outputs will be smaller than the
  // number of iterations.
  ASSERT_LE(pose_landmarker_results.size(), iterations);
  ASSERT_GT(pose_landmarker_results.size(), 0);
  const auto expected_results = GetParam().expected_results;
  for (int i = 0; i < pose_landmarker_results.size(); ++i) {
    ExpectPoseLandmarkerResultsCorrect(pose_landmarker_results[i],
                                       expected_results,
                                       kLandmarksOnVideoAbsMargin);
  }
  // Every callback must report the original input image dimensions.
  for (const auto& image_size : image_sizes) {
    EXPECT_EQ(image_size.first, image.width());
    EXPECT_EQ(image_size.second, image.height());
  }
  // Timestamps must be strictly increasing.
  int64_t timestamp_ms = -1;
  for (const auto& timestamp : timestamps) {
    EXPECT_GT(timestamp, timestamp_ms);
    timestamp_ms = timestamp;
  }
}
// TODO Add additional tests for MP Tasks Pose Graphs
// TODO Investigate PoseLandmarker performance in LiveStreamMode.
// Instantiates the live-stream suite with one pose case and one no-pose case.
INSTANTIATE_TEST_SUITE_P(
    PoseGestureTest, LiveStreamModeTest,
    Values(TestParams{
               /* test_name= */ "Pose",
               /* test_image_name= */ kPoseImage,
               /* test_model_file= */ kPoseLandmarkerBundleAsset,
               /* rotation= */ 0,
               /* expected_results = */
               GetExpectedPoseLandmarkerResult({kPoseLandmarksFilename}),
           },
           TestParams{
               /* test_name= */ "NoPoses",
               /* test_image_name= */ kBurgerImage,
               /* test_model_file= */ kPoseLandmarkerBundleAsset,
               /* rotation= */ 0,
               /* expected_results = */
               {{}, {}, {}},
           }),
    // Name the lambda's parameter via this suite's own fixture (the original
    // copy-pasted ImageModeTest::ParamType; the underlying TestParams type is
    // identical, so behavior is unchanged).
    [](const TestParamInfo<LiveStreamModeTest::ParamType>& info) {
      return info.param.test_name;
    });
} // namespace pose_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -220,6 +220,7 @@ filegroup(
"portrait_rotated_expected_detection.pbtxt",
"pose_expected_detection.pbtxt",
"pose_expected_expanded_rect.pbtxt",
"pose_landmarks.pbtxt",
"thumb_up_landmarks.pbtxt",
"thumb_up_rotated_landmarks.pbtxt",
"victory_landmarks.pbtxt",

View File

@ -0,0 +1,235 @@
# proto-file: mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.proto
# proto-message: LandmarksDetectionResult
landmarks {
landmark {
x: 0.46507242
y: 0.41610304
z: -0.14999297
visibility: 0.99998844
presence: 0.99998903
}
landmark {
x: 0.47030905
y: 0.40309227
z: -0.14747797
visibility: 0.9999664
presence: 0.99995553
}
landmark {
x: 0.47342995
y: 0.40156996
z: -0.14745711
visibility: 0.9999546
presence: 0.9999534
}
landmark {
x: 0.47663194
y: 0.39997125
z: -0.14746517
visibility: 0.99994385
presence: 0.99994576
}
landmark {
x: 0.46315864
y: 0.4074311
z: -0.12709364
visibility: 0.99994826
presence: 0.99996126
}
landmark {
x: 0.46167478
y: 0.4089059
z: -0.12713416
visibility: 0.9999267
presence: 0.99996483
}
landmark {
x: 0.45983645
y: 0.41042596
z: -0.12708154
visibility: 0.9999249
presence: 0.9999635
}
landmark {
x: 0.48985234
y: 0.4053322
z: -0.10685073
visibility: 0.99990225
presence: 0.9999591
}
landmark {
x: 0.46542352
y: 0.41843837
z: -0.011494645
visibility: 0.999894
presence: 0.9999794
}
landmark {
x: 0.4757115
y: 0.42860925
z: -0.13509856
visibility: 0.9999609
presence: 0.9999964
}
landmark {
x: 0.46715724
y: 0.4335927
z: -0.10763592
visibility: 0.9999547
presence: 0.99999654
}
landmark {
x: 0.5383015
y: 0.47349828
z: -0.09904624
visibility: 0.99999046
presence: 0.99997306
}
landmark {
x: 0.45313644
y: 0.48936796
z: 0.06387678
visibility: 0.9999919
presence: 0.99999344
}
landmark {
x: 0.6247076
y: 0.47583634
z: -0.16247825
visibility: 0.9971179
presence: 0.9997496
}
landmark {
x: 0.36646777
y: 0.48342204
z: 0.08566214
visibility: 0.98212147
presence: 0.9999796
}
landmark {
x: 0.6950672
y: 0.47276276
z: -0.23432226
visibility: 0.99190086
presence: 0.99751616
}
landmark {
x: 0.29679844
y: 0.46896482
z: -0.020334292
visibility: 0.9826943
presence: 0.99966264
}
landmark {
x: 0.71771
y: 0.466713
z: -0.25376678
visibility: 0.97681385
presence: 0.9940614
}
landmark {
x: 0.2762509
y: 0.46507546
z: -0.030049918
visibility: 0.9577505
presence: 0.9986953
}
landmark {
x: 0.71468514
y: 0.46545807
z: -0.27873537
visibility: 0.9773756
presence: 0.99447954
}
landmark {
x: 0.27898294
y: 0.45952708
z: -0.061881505
visibility: 0.9598176
presence: 0.9987778
}
landmark {
x: 0.70616376
y: 0.47313935
z: -0.2473996
visibility: 0.97479546
presence: 0.9961171
}
landmark {
x: 0.2866556
y: 0.4661593
z: -0.039750997
visibility: 0.96074265
presence: 0.9992084
}
landmark {
x: 0.5199628
y: 0.71132404
z: -0.05568369
visibility: 0.9997075
presence: 0.9997811
}
landmark {
x: 0.46290728
y: 0.7031081
z: 0.05579845
visibility: 0.99977857
presence: 0.99986553
}
landmark {
x: 0.60217524
y: 0.81984437
z: -0.14303333
visibility: 0.99511355
presence: 0.9993449
}
landmark {
x: 0.36314535
y: 0.7332305
z: -0.063036785
visibility: 0.9963553
presence: 0.9997341
}
landmark {
x: 0.70064104
y: 0.93365943
z: -0.031219501
visibility: 0.9909759
presence: 0.9935846
}
landmark {
x: 0.35170174
y: 0.9031892
z: 0.0022715325
visibility: 0.9929086
presence: 0.99870515
}
landmark {
x: 0.70923454
y: 0.954628
z: -0.024844522
visibility: 0.93651295
presence: 0.99121124
}
landmark {
x: 0.36733973
y: 0.93519753
z: 0.003870454
visibility: 0.96783006
presence: 0.99829334
}
landmark {
x: 0.72855467
y: 0.96812284
z: -0.12364431
visibility: 0.9639738
presence: 0.98293763
}
landmark {
x: 0.30121577
y: 0.94473803
z: -0.08448716
visibility: 0.97545445
presence: 0.99386966
}
}

View File

@ -307,13 +307,13 @@ def external_files():
http_file(
name = "com_google_mediapipe_expected_pose_landmarks_prototxt",
sha256 = "eed8dfa169b0abee60cde01496599b0bc75d91a82594a1bdf59be2f76f45d7f5",
urls = ["https://storage.googleapis.com/mediapipe-assets/expected_pose_landmarks.prototxt?generation=1681243892338529"],
urls = ["https://storage.googleapis.com/mediapipe-assets/expected_pose_landmarks.prototxt?generation=1681244232522990"],
)
http_file(
name = "com_google_mediapipe_expected_pose_landmarks_prototxt_orig",
sha256 = "c230e0933e6cb4af69ec21314f3f9930fe13e7bb4bf1dbdb74427e4138c24c1e",
urls = ["https://storage.googleapis.com/mediapipe-assets/expected_pose_landmarks.prototxt.orig?generation=1681243894793518"],
urls = ["https://storage.googleapis.com/mediapipe-assets/expected_pose_landmarks.prototxt.orig?generation=1681244235071100"],
)
http_file(
@ -823,7 +823,7 @@ def external_files():
http_file(
name = "com_google_mediapipe_ocr_text_jpg",
sha256 = "88052e93aa910330433741f5cef140f8f9ec463230a332aef7038b5457b06482",
urls = ["https://storage.googleapis.com/mediapipe-assets/ocr_text.jpg?generation=1681243901191762"],
urls = ["https://storage.googleapis.com/mediapipe-assets/ocr_text.jpg?generation=1681244241009078"],
)
http_file(
@ -943,13 +943,13 @@ def external_files():
http_file(
name = "com_google_mediapipe_pose_expected_detection_pbtxt",
sha256 = "16866c8dd4fbee60f6972630d73baed219b45824c055c7fbc7dc9a91c4b182cc",
urls = ["https://storage.googleapis.com/mediapipe-assets/pose_expected_detection.pbtxt?generation=1681243904247231"],
urls = ["https://storage.googleapis.com/mediapipe-assets/pose_expected_detection.pbtxt?generation=1681244244235448"],
)
http_file(
name = "com_google_mediapipe_pose_expected_expanded_rect_pbtxt",
sha256 = "b0a41d25ed115757606dfc034e9d320a93a52616d92d745150b6a886ddc5a88a",
urls = ["https://storage.googleapis.com/mediapipe-assets/pose_expected_expanded_rect.pbtxt?generation=1681243906782078"],
urls = ["https://storage.googleapis.com/mediapipe-assets/pose_expected_expanded_rect.pbtxt?generation=1681244246786802"],
)
http_file(
@ -961,7 +961,7 @@ def external_files():
http_file(
name = "com_google_mediapipe_pose_landmarker_task",
sha256 = "fb9cc326c88fc2a4d9a6d355c28520d5deacfbaa375b56243b0141b546080596",
urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmarker.task?generation=1681243909774681"],
urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmarker.task?generation=1681244249587900"],
)
http_file(
@ -979,7 +979,13 @@ def external_files():
http_file(
name = "com_google_mediapipe_pose_landmark_lite_tflite",
sha256 = "1150dc68a713b80660b90ef46ce4e85c1c781bb88b6e3512cc64e6a685ba5588",
urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmark_lite.tflite?generation=1681243912778143"],
urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmark_lite.tflite?generation=1681244252454799"],
)
http_file(
name = "com_google_mediapipe_pose_landmarks_pbtxt",
sha256 = "305a71fbff83e270a5dbd81fb7cf65203f56e0b1caba8ea42edc16c6e8a2ba18",
urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmarks.pbtxt?generation=1681244254964356"],
)
http_file(