diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/BUILD b/mediapipe/tasks/cc/vision/pose_landmarker/BUILD index f763d09eb..1b1b818ce 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/pose_landmarker/BUILD @@ -18,6 +18,37 @@ package(default_visibility = [ licenses(["notice"]) +cc_library( + name = "pose_landmarker", + srcs = ["pose_landmarker.cc"], + hdrs = ["pose_landmarker.h"], + visibility = ["//visibility:public"], + deps = [ + ":pose_landmarker_graph", + ":pose_landmarker_result", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], +) + cc_library( name = "pose_landmarks_detector_graph", srcs = ["pose_landmarks_detector_graph.cc"], @@ -110,3 +141,15 @@ cc_library( ], alwayslink = 1, ) + +cc_library( + name = "pose_landmarker_result", + srcs = ["pose_landmarker_result.cc"], + hdrs = ["pose_landmarker_result.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/tasks/cc/components/containers:landmark", + ], +) diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.cc new file mode 100644 index 000000000..4c734c423 --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.cc @@ -0,0 +1,307 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.h" + +#include +#include + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/core/base_task_api.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" +#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace pose_landmarker { + +namespace { + +using PoseLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision:: + pose_landmarker::proto::PoseLandmarkerGraphOptions; + +using ::mediapipe::NormalizedRect; + +constexpr char kPoseLandmarkerGraphTypeName[] = + "mediapipe.tasks.vision.pose_landmarker.PoseLandmarkerGraph"; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageInStreamName[] = "image_in"; +constexpr char kImageOutStreamName[] = "image_out"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kNormRectStreamName[] = "norm_rect_in"; +constexpr char kSegmentationMaskTag[] = "SEGMENTATION_MASK"; +constexpr char kSegmentationMaskStreamName[] = "segmentation_mask"; +constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS"; +constexpr char kNormLandmarksStreamName[] = "norm_landmarks"; +constexpr char kPoseWorldLandmarksTag[] = "WORLD_LANDMARKS"; +constexpr char kPoseWorldLandmarksStreamName[] = "world_landmarks"; +constexpr char kPoseAuxiliaryLandmarksTag[] = "AUXILIARY_LANDMARKS"; +constexpr char kPoseAuxiliaryLandmarksStreamName[] = "auxiliary_landmarks"; +constexpr int kMicroSecondsPerMilliSecond = 1000; + +// Creates a MediaPipe graph config that contains a subgraph node of +// "mediapipe.tasks.vision.pose_ladnamrker.PoseLandmarkerGraph". If the task is +// running in the live stream mode, a "FlowLimiterCalculator" will be added to +// limit the number of frames in flight. +CalculatorGraphConfig CreateGraphConfig( + std::unique_ptr options, + bool enable_flow_limiting) { + api2::builder::Graph graph; + auto& subgraph = graph.AddNode(kPoseLandmarkerGraphTypeName); + subgraph.GetOptions().Swap(options.get()); + graph.In(kImageTag).SetName(kImageInStreamName); + graph.In(kNormRectTag).SetName(kNormRectStreamName); + subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >> + graph.Out(kSegmentationMaskTag); + subgraph.Out(kNormLandmarksTag).SetName(kNormLandmarksStreamName) >> + graph.Out(kNormLandmarksTag); + subgraph.Out(kPoseWorldLandmarksTag).SetName(kPoseWorldLandmarksStreamName) >> + graph.Out(kPoseWorldLandmarksTag); + subgraph.Out(kPoseAuxiliaryLandmarksTag) + .SetName(kPoseAuxiliaryLandmarksStreamName) >> + graph.Out(kPoseAuxiliaryLandmarksTag); + subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); + if (enable_flow_limiting) { + return tasks::core::AddFlowLimiterCalculator( + graph, subgraph, {kImageTag, kNormRectTag}, kNormLandmarksTag); + } + graph.In(kImageTag) >> subgraph.In(kImageTag); + graph.In(kNormRectTag) >> subgraph.In(kNormRectTag); + return graph.GetConfig(); +} + +// Converts the user-facing PoseLandmarkerOptions struct to the internal +// PoseLandmarkerGraphOptions proto. +std::unique_ptr +ConvertPoseLandmarkerGraphOptionsProto(PoseLandmarkerOptions* options) { + auto options_proto = std::make_unique(); + auto base_options_proto = std::make_unique( + tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); + options_proto->mutable_base_options()->Swap(base_options_proto.get()); + options_proto->mutable_base_options()->set_use_stream_mode( + options->running_mode != core::RunningMode::IMAGE); + + // Configure pose detector options. + auto* pose_detector_graph_options = + options_proto->mutable_pose_detector_graph_options(); + pose_detector_graph_options->set_num_poses(options->num_poses); + pose_detector_graph_options->set_min_detection_confidence( + options->min_pose_detection_confidence); + + // Configure pose landmark detector options. + options_proto->set_min_tracking_confidence(options->min_tracking_confidence); + auto* pose_landmarks_detector_graph_options = + options_proto->mutable_pose_landmarks_detector_graph_options(); + pose_landmarks_detector_graph_options->set_min_detection_confidence( + options->min_pose_presence_confidence); + + return options_proto; +} + +} // namespace + +absl::StatusOr> PoseLandmarker::Create( + std::unique_ptr options) { + auto options_proto = ConvertPoseLandmarkerGraphOptionsProto(options.get()); + tasks::core::PacketsCallback packets_callback = nullptr; + if (options->result_callback) { + auto result_callback = options->result_callback; + bool output_segmentation_masks = options->output_segmentation_masks; + packets_callback = [=](absl::StatusOr + status_or_packets) { + if (!status_or_packets.ok()) { + Image image; + result_callback(status_or_packets.status(), image, + Timestamp::Unset().Value()); + return; + } + if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { + return; + } + Packet image_packet = status_or_packets.value()[kImageOutStreamName]; + if (status_or_packets.value()[kNormLandmarksStreamName].IsEmpty()) { + Packet empty_packet = + status_or_packets.value()[kNormLandmarksStreamName]; + result_callback( + {PoseLandmarkerResult()}, image_packet.Get(), + empty_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); + return; + } + Packet segmentation_mask_packet = + status_or_packets.value()[kSegmentationMaskStreamName]; + Packet pose_landmarks_packet = + status_or_packets.value()[kNormLandmarksStreamName]; + Packet pose_world_landmarks_packet = + status_or_packets.value()[kPoseWorldLandmarksStreamName]; + Packet pose_auxiliary_landmarks_packet = + status_or_packets.value()[kPoseAuxiliaryLandmarksStreamName]; + std::optional> segmentation_mask = std::nullopt; + if (output_segmentation_masks) { + segmentation_mask = segmentation_mask_packet.Get>(); + } + result_callback( + ConvertToPoseLandmarkerResult( + /* segmentation_mask= */ segmentation_mask, + /* pose_landmarks= */ + pose_landmarks_packet.Get>(), + /* pose_world_landmarks= */ + pose_world_landmarks_packet.Get>(), + pose_auxiliary_landmarks_packet + .Get>()), + image_packet.Get(), + pose_landmarks_packet.Timestamp().Value() / + kMicroSecondsPerMilliSecond); + }; + } + ASSIGN_OR_RETURN( + std::unique_ptr pose_landmarker, + (core::VisionTaskApiFactory::Create( + CreateGraphConfig( + std::move(options_proto), + options->running_mode == core::RunningMode::LIVE_STREAM), + std::move(options->base_options.op_resolver), options->running_mode, + std::move(packets_callback)))); + + pose_landmarker->output_segmentation_masks_ = + options->output_segmentation_masks; + + return pose_landmarker; +} + +absl::StatusOr PoseLandmarker::Detect( + mediapipe::Image image, + std::optional image_processing_options) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "GPU input images are currently not supported.", + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, image, + /*roi_allowed=*/false)); + ASSIGN_OR_RETURN( + auto output_packets, + ProcessImageData( + {{kImageInStreamName, MakePacket(std::move(image))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect))}})); + if (output_packets[kNormLandmarksStreamName].IsEmpty()) { + return {PoseLandmarkerResult()}; + } + std::optional> segmentation_mask = std::nullopt; + if (output_segmentation_masks_) { + segmentation_mask = + output_packets[kSegmentationMaskStreamName].Get>(); + } + return ConvertToPoseLandmarkerResult( + /* segmentation_mask= */ + segmentation_mask, + /* pose_landmarks= */ + output_packets[kNormLandmarksStreamName] + .Get>(), + /* pose_world_landmarks */ + output_packets[kPoseWorldLandmarksStreamName] + .Get>(), + /*pose_auxiliary_landmarks= */ + output_packets[kPoseAuxiliaryLandmarksStreamName] + .Get>()); +} + +absl::StatusOr PoseLandmarker::DetectForVideo( + mediapipe::Image image, int64_t timestamp_ms, + std::optional image_processing_options) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrCat("GPU input images are currently not supported."), + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, image, + /*roi_allowed=*/false)); + ASSIGN_OR_RETURN( + auto output_packets, + ProcessVideoData( + {{kImageInStreamName, + MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); + if (output_packets[kNormLandmarksStreamName].IsEmpty()) { + return {PoseLandmarkerResult()}; + } + std::optional> segmentation_mask = std::nullopt; + if (output_segmentation_masks_) { + segmentation_mask = + output_packets[kSegmentationMaskStreamName].Get>(); + } + return ConvertToPoseLandmarkerResult( + /* segmentation_mask= */ + segmentation_mask, + /* pose_landmarks= */ + output_packets[kNormLandmarksStreamName] + .Get>(), + /* pose_world_landmarks */ + output_packets[kPoseWorldLandmarksStreamName] + .Get>(), + /* pose_auxiliary_landmarks= */ + output_packets[kPoseAuxiliaryLandmarksStreamName] + .Get>()); +} + +absl::Status PoseLandmarker::DetectAsync( + mediapipe::Image image, int64_t timestamp_ms, + std::optional image_processing_options) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrCat("GPU input images are currently not supported."), + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, image, + /*roi_allowed=*/false)); + return SendLiveStreamData( + {{kImageInStreamName, + MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); +} + +} // namespace pose_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.h b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.h new file mode 100644 index 000000000..058ab0b1e --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.h @@ -0,0 +1,193 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_POSE_LANDMARKER_POSE_LANDMARKER_H_ +#define MEDIAPIPE_TASKS_CC_VISION_POSE_LANDMARKER_POSE_LANDMARKER_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/core/running_mode.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace pose_landmarker { + +struct PoseLandmarkerOptions { + // Base options for configuring MediaPipe Tasks library, such as specifying + // the TfLite model bundle file with metadata, accelerator options, op + // resolver, etc. + tasks::core::BaseOptions base_options; + + // The running mode of the task. Default to the image mode. + // PoseLandmarker has three running modes: + // 1) The image mode for detecting pose landmarks on single image inputs. + // 2) The video mode for detecting pose landmarks on the decoded frames of a + // video. + // 3) The live stream mode for detecting pose landmarks on the live stream of + // input data, such as from camera. In this mode, the "result_callback" + // below must be specified to receive the detection results asynchronously. + core::RunningMode running_mode = core::RunningMode::IMAGE; + + // The maximum number of poses can be detected by the PoseLandmarker. + int num_poses = 1; + + // The minimum confidence score for the pose detection to be considered + // successful. + float min_pose_detection_confidence = 0.5; + + // The minimum confidence score of pose presence score in the pose landmark + // detection. + float min_pose_presence_confidence = 0.5; + + // The minimum confidence score for the pose tracking to be considered + // successful. + float min_tracking_confidence = 0.5; + + // The user-defined result callback for processing live stream data. + // The result callback should only be specified when the running mode is set + // to RunningMode::LIVE_STREAM. + std::function, const Image&, int64)> + result_callback = nullptr; + + // Whether to output segmentation masks. + bool output_segmentation_masks = false; +}; + +// Performs pose landmarks detection on the given image. +// +// This API expects a pre-trained pose landmarker model asset bundle. +// +// Inputs: +// Image +// - The image that pose landmarks detection runs on. +// std::optional +// - If provided, can be used to specify the rotation to apply to the image +// before performing pose landmarks detection, by setting its 'rotation' +// field in radians (e.g. 'M_PI / 2' for a 90° anti-clockwise rotation). +// Note that specifying a region-of-interest using the 'x_center', +// 'y_center', 'width' and 'height' fields is NOT supported and will +// result in an invalid argument error being returned. +// Outputs: +// PoseLandmarkerResult +// - The pose landmarks detection results. +class PoseLandmarker : tasks::vision::core::BaseVisionTaskApi { + public: + using BaseVisionTaskApi::BaseVisionTaskApi; + + // Creates a PoseLandmarker from a PoseLandmarkerOptions to process image data + // or streaming data. Pose landmarker can be created with one of the following + // three running modes: + // 1) Image mode for detecting pose landmarks on single image inputs. Users + // provide mediapipe::Image to the `Detect` method, and will receive the + // detected pose landmarks results as the return value. + // 2) Video mode for detecting pose landmarks on the decoded frames of a + // video. Users call `DetectForVideo` method, and will receive the detected + // pose landmarks results as the return value. + // 3) Live stream mode for detecting pose landmarks on the live stream of the + // input data, such as from camera. Users call `DetectAsync` to push the + // image data into the PoseLandmarker, the detected results along with the + // input timestamp and the image that pose landmarker runs on will be + // available in the result callback when the pose landmarker finishes the + // work. + static absl::StatusOr> Create( + std::unique_ptr options); + + // Performs pose landmarks detection on the given image. + // Only use this method when the PoseLandmarker is created with the image + // running mode. + // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing detection, by setting + // its 'rotation_degrees' field. Note that specifying a region-of-interest + // using the 'region_of_interest' field is NOT supported and will result in an + // invalid argument error being returned. + // + // The image can be of any size with format RGB or RGBA. + // TODO: Describes how the input image will be preprocessed + // after the yuv support is implemented. + absl::StatusOr Detect( + Image image, + std::optional image_processing_options = + std::nullopt); + + // Performs pose landmarks detection on the provided video frame. + // Only use this method when the PoseLandmarker is created with the video + // running mode. + // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing detection, by setting + // its 'rotation_degrees' field. Note that specifying a region-of-interest + // using the 'region_of_interest' field is NOT supported and will result in an + // invalid argument error being returned. + // + // The image can be of any size with format RGB or RGBA. It's required to + // provide the video frame's timestamp (in milliseconds). The input timestamps + // must be monotonically increasing. + absl::StatusOr DetectForVideo( + Image image, int64 timestamp_ms, + std::optional image_processing_options = + std::nullopt); + + // Sends live image data to perform pose landmarks detection, and the results + // will be available via the "result_callback" provided in the + // PoseLandmarkerOptions. Only use this method when the PoseLandmarker + // is created with the live stream running mode. + // + // The image can be of any size with format RGB or RGBA. It's required to + // provide a timestamp (in milliseconds) to indicate when the input image is + // sent to the pose landmarker. The input timestamps must be monotonically + // increasing. + // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing detection, by setting + // its 'rotation_degrees' field. Note that specifying a region-of-interest + // using the 'region_of_interest' field is NOT supported and will result in an + // invalid argument error being returned. + // + // The "result_callback" provides + // - A vector of PoseLandmarkerResult, each is the detected results + // for a input frame. + // - The const reference to the corresponding input image that the pose + // landmarker runs on. Note that the const reference to the image will no + // longer be valid when the callback returns. To access the image data + // outside of the callback, callers need to make a copy of the image. + // - The input timestamp in milliseconds. + absl::Status DetectAsync(Image image, int64 timestamp_ms, + std::optional + image_processing_options = std::nullopt); + + // Shuts down the PoseLandmarker when all works are done. + absl::Status Close() { return runner_->Close(); } + + private: + bool output_segmentation_masks_; +}; + +} // namespace pose_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_POSE_LANDMARKER_POSE_LANDMARKER_H_ diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.cc new file mode 100644 index 000000000..6222bbd68 --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.cc @@ -0,0 +1,56 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.h" + +#include + +#include "mediapipe/tasks/cc/components/containers/landmark.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace pose_landmarker { + +PoseLandmarkerResult ConvertToPoseLandmarkerResult( + std::optional> segmentation_masks, + const std::vector& pose_landmarks_proto, + const std::vector& pose_world_landmarks_proto, + const std::vector& + pose_auxiliary_landmarks_proto) { + PoseLandmarkerResult result; + result.segmentation_masks = segmentation_masks; + + result.pose_landmarks.resize(pose_landmarks_proto.size()); + result.pose_world_landmarks.resize(pose_world_landmarks_proto.size()); + result.pose_auxiliary_landmarks.resize(pose_auxiliary_landmarks_proto.size()); + std::transform(pose_landmarks_proto.begin(), pose_landmarks_proto.end(), + result.pose_landmarks.begin(), + components::containers::ConvertToNormalizedLandmarks); + std::transform(pose_world_landmarks_proto.begin(), + pose_world_landmarks_proto.end(), + result.pose_world_landmarks.begin(), + components::containers::ConvertToLandmarks); + std::transform(pose_auxiliary_landmarks_proto.begin(), + pose_auxiliary_landmarks_proto.end(), + result.pose_auxiliary_landmarks.begin(), + components::containers::ConvertToNormalizedLandmarks); + return result; +} + +} // namespace pose_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.h b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.h new file mode 100644 index 000000000..07adb87f5 --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.h @@ -0,0 +1,57 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_RESULT_H_ +#define MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_RESULT_H_ + +#include + +// #include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/tasks/cc/components/containers/landmark.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace pose_landmarker { + +// The pose landmarks detection result from PoseLandmarker, where each vector +// element represents a single pose detected in the image. +struct PoseLandmarkerResult { + // Segmentation masks for pose. + std::optional> segmentation_masks; + // Detected pose landmarks in normalized image coordinates. + std::vector pose_landmarks; + // Detected pose landmarks in world coordinates. + std::vector pose_world_landmarks; + // Detected auxiliary landmarks, used for deriving ROI for next frame. + std::vector + pose_auxiliary_landmarks; +}; + +PoseLandmarkerResult ConvertToPoseLandmarkerResult( + std::optional> segmentation_mask, + const std::vector& pose_landmarks_proto, + const std::vector& pose_world_landmarks_proto, + const std::vector& + pose_auxiliary_landmarks_proto); + +} // namespace pose_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_RESULT_H_ diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result_test.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result_test.cc new file mode 100644 index 000000000..10e0d61a3 --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result_test.cc @@ -0,0 +1,98 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.h" + +#include +#include + +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" +#include "mediapipe/tasks/cc/components/containers/landmark.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace pose_landmarker { + +TEST(ConvertFromProto, Succeeds) { + Image segmentation_mask; + + mediapipe::NormalizedLandmarkList normalized_landmark_list_proto; + mediapipe::NormalizedLandmark& normalized_landmark_proto = + *normalized_landmark_list_proto.add_landmark(); + normalized_landmark_proto.set_x(0.1); + normalized_landmark_proto.set_y(0.2); + normalized_landmark_proto.set_z(0.3); + + mediapipe::LandmarkList world_landmark_list_proto; + mediapipe::Landmark& landmark_proto = + *world_landmark_list_proto.add_landmark(); + landmark_proto.set_x(3.1); + landmark_proto.set_y(5.2); + landmark_proto.set_z(4.3); + + mediapipe::NormalizedLandmarkList auxiliary_landmark_list_proto; + mediapipe::NormalizedLandmark& auxiliary_landmark_proto = + *auxiliary_landmark_list_proto.add_landmark(); + auxiliary_landmark_proto.set_x(0.5); + auxiliary_landmark_proto.set_y(0.5); + auxiliary_landmark_proto.set_z(0.5); + + std::vector segmentation_masks_lists = {segmentation_mask}; + + std::vector normalized_landmarks_lists = { + normalized_landmark_list_proto}; + + std::vector world_landmarks_lists = { + world_landmark_list_proto}; + + std::vector auxiliary_landmarks_lists = { + auxiliary_landmark_list_proto}; + + PoseLandmarkerResult pose_landmarker_result = ConvertToPoseLandmarkerResult( + segmentation_masks_lists, normalized_landmarks_lists, + world_landmarks_lists, auxiliary_landmarks_lists); + + EXPECT_EQ(pose_landmarker_result.pose_landmarks.size(), 1); + EXPECT_EQ(pose_landmarker_result.pose_landmarks[0].landmarks.size(), 1); + EXPECT_THAT(pose_landmarker_result.pose_landmarks[0].landmarks[0], + testing::FieldsAre(testing::FloatEq(0.1), testing::FloatEq(0.2), + testing::FloatEq(0.3), std::nullopt, + std::nullopt, std::nullopt)); + + EXPECT_EQ(pose_landmarker_result.pose_world_landmarks.size(), 1); + EXPECT_EQ(pose_landmarker_result.pose_world_landmarks[0].landmarks.size(), 1); + EXPECT_THAT(pose_landmarker_result.pose_world_landmarks[0].landmarks[0], + testing::FieldsAre(testing::FloatEq(3.1), testing::FloatEq(5.2), + testing::FloatEq(4.3), std::nullopt, + std::nullopt, std::nullopt)); + + EXPECT_EQ(pose_landmarker_result.pose_auxiliary_landmarks.size(), 1); + EXPECT_EQ(pose_landmarker_result.pose_auxiliary_landmarks[0].landmarks.size(), + 1); + EXPECT_THAT(pose_landmarker_result.pose_auxiliary_landmarks[0].landmarks[0], + testing::FieldsAre(testing::FloatEq(0.5), testing::FloatEq(0.5), + testing::FloatEq(0.5), std::nullopt, + std::nullopt, std::nullopt)); +} + +} // namespace pose_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_test.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_test.cc new file mode 100644 index 000000000..062d0746d --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_test.cc @@ -0,0 +1,470 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.h" + +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/landmark.h" +#include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "util/tuple/dump_vars.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace pose_landmarker { + +namespace { + +using ::file::Defaults; +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::ConvertToLandmarks; +using ::mediapipe::tasks::components::containers::ConvertToNormalizedLandmarks; +using ::mediapipe::tasks::components::containers::RectF; +using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult; +using ::mediapipe::tasks::vision::core::ImageProcessingOptions; +using ::testing::HasSubstr; +using ::testing::Optional; +using ::testing::TestParamInfo; +using ::testing::TestWithParam; +using ::testing::Values; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kPoseLandmarkerBundleAsset[] = "pose_landmarker.task"; +constexpr char kPoseLandmarksFilename[] = "pose_landmarks.pbtxt"; + +constexpr char kPoseImage[] = "pose.jpg"; +constexpr char kBurgerImage[] = "burger.jpg"; + +constexpr float kLandmarksAbsMargin = 0.03; +constexpr float kLandmarksOnVideoAbsMargin = 0.03; + +LandmarksDetectionResult GetLandmarksDetectionResult( + absl::string_view landmarks_file_name) { + LandmarksDetectionResult result; + MP_EXPECT_OK(GetTextProto( + file::JoinPath("./", kTestDataDirectory, landmarks_file_name), &result, + Defaults())); + // Remove z position of landmarks, because they are not used in correctness + // testing. For video or live stream mode, the z positions varies a lot during + // tracking from frame to frame. + for (int i = 0; i < result.landmarks().landmark().size(); i++) { + auto& landmark = *result.mutable_landmarks()->mutable_landmark(i); + landmark.clear_z(); + } + return result; +} + +PoseLandmarkerResult GetExpectedPoseLandmarkerResult( + const std::vector& landmarks_file_names) { + PoseLandmarkerResult expected_results; + for (const auto& file_name : landmarks_file_names) { + const auto landmarks_detection_result = + GetLandmarksDetectionResult(file_name); + expected_results.pose_landmarks.push_back( + ConvertToNormalizedLandmarks(landmarks_detection_result.landmarks())); + expected_results.pose_world_landmarks.push_back( + ConvertToLandmarks(landmarks_detection_result.world_landmarks())); + } + return expected_results; +} + +MATCHER_P2(LandmarksMatches, expected_landmarks, toleration, "") { + for (int i = 0; i < arg.size(); i++) { + for (int j = 0; j < arg[i].landmarks.size(); j++) { + if (arg[i].landmarks.size() != expected_landmarks[i].landmarks.size()) { + LOG(INFO) << "sizes not equal"; + return false; + } + if (std::abs(arg[i].landmarks[j].x - + expected_landmarks[i].landmarks[j].x) > toleration || + std::abs(arg[i].landmarks[j].y - + expected_landmarks[i].landmarks[j].y) > toleration) { + LOG(INFO) << DUMP_VARS(arg[i].landmarks[j].x, + expected_landmarks[i].landmarks[j].x); + LOG(INFO) << DUMP_VARS(arg[i].landmarks[j].y, + expected_landmarks[i].landmarks[j].y); + return false; + } + } + } + return true; +} + +void ExpectPoseLandmarkerResultsCorrect( + const PoseLandmarkerResult& actual_results, + const PoseLandmarkerResult& expected_results, float margin) { + const auto& actual_landmarks = actual_results.pose_landmarks; + + const auto& expected_landmarks = expected_results.pose_landmarks; + + ASSERT_EQ(actual_landmarks.size(), expected_landmarks.size()); + + if (actual_landmarks.empty()) { + return; + } + + ASSERT_GE(actual_landmarks.size(), 1); + EXPECT_THAT(actual_landmarks, LandmarksMatches(expected_landmarks, margin)); +} + +} // namespace + +struct TestParams { + // The name of this test, for convenience when displaying test results. + std::string test_name; + // The filename of test image. + std::string test_image_name; + // The filename of test model. + std::string test_model_file; + // The rotation to apply to the test image before processing, in degrees + // clockwise. + int rotation; + // Expected results from the pose landmarker model output. + PoseLandmarkerResult expected_results; +}; + +class ImageModeTest : public testing::TestWithParam {}; + +TEST_F(ImageModeTest, FailsWithCallingWrongMethod) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kPoseImage))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kPoseLandmarkerBundleAsset); + options->running_mode = core::RunningMode::IMAGE; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr pose_landmarker, + PoseLandmarker::Create(std::move(options))); + auto results = pose_landmarker->DetectForVideo(image, 0); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the video mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + + results = pose_landmarker->DetectAsync(image, 0); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the live stream mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + MP_ASSERT_OK(pose_landmarker->Close()); +} + +TEST_F(ImageModeTest, FailsWithRegionOfInterest) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kPoseImage))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kPoseLandmarkerBundleAsset); + options->running_mode = core::RunningMode::IMAGE; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr pose_landmarker, + PoseLandmarker::Create(std::move(options))); + RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; + + auto results = pose_landmarker->Detect(image, image_processing_options); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("This task doesn't support region-of-interest")); + EXPECT_THAT( + results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); +} + +TEST_P(ImageModeTest, Succeeds) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + GetParam().test_image_name))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, GetParam().test_model_file); + options->running_mode = core::RunningMode::IMAGE; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr pose_landmarker, + PoseLandmarker::Create(std::move(options))); + PoseLandmarkerResult pose_landmarker_results; + if (GetParam().rotation != 0) { + ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = GetParam().rotation; + MP_ASSERT_OK_AND_ASSIGN( + pose_landmarker_results, + pose_landmarker->Detect(image, image_processing_options)); + } else { + MP_ASSERT_OK_AND_ASSIGN(pose_landmarker_results, + pose_landmarker->Detect(image)); + } + ExpectPoseLandmarkerResultsCorrect(pose_landmarker_results, + GetParam().expected_results, + kLandmarksAbsMargin); + MP_ASSERT_OK(pose_landmarker->Close()); +} + +INSTANTIATE_TEST_SUITE_P( + PoseGestureTest, ImageModeTest, + Values(TestParams{ + /* test_name= */ "Pose", + /* test_image_name= */ kPoseImage, + /* test_model_file= */ kPoseLandmarkerBundleAsset, + /* rotation= */ 0, + /* expected_results = */ + GetExpectedPoseLandmarkerResult({kPoseLandmarksFilename}), + }, + TestParams{ + /* test_name= */ "NoPoses", + /* test_image_name= */ kBurgerImage, + /* test_model_file= */ kPoseLandmarkerBundleAsset, + /* rotation= */ 0, + /* expected_results = */ + {{}, {}, {}}, + }), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +class VideoModeTest : public testing::TestWithParam {}; + +TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kPoseImage))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kPoseLandmarkerBundleAsset); + options->running_mode = core::RunningMode::VIDEO; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr pose_landmarker, + PoseLandmarker::Create(std::move(options))); + auto results = pose_landmarker->Detect(image); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the image mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + + results = pose_landmarker->DetectAsync(image, 0); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the live stream mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + MP_ASSERT_OK(pose_landmarker->Close()); +} + +TEST_P(VideoModeTest, Succeeds) { + const int iterations = 100; + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + GetParam().test_image_name))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, GetParam().test_model_file); + options->running_mode = core::RunningMode::VIDEO; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr pose_landmarker, + PoseLandmarker::Create(std::move(options))); + const auto expected_results = GetParam().expected_results; + for (int i = 0; i < iterations; ++i) { + PoseLandmarkerResult pose_landmarker_results; + if (GetParam().rotation != 0) { + ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = GetParam().rotation; + MP_ASSERT_OK_AND_ASSIGN( + pose_landmarker_results, + pose_landmarker->DetectForVideo(image, i, image_processing_options)); + } else { + MP_ASSERT_OK_AND_ASSIGN(pose_landmarker_results, + pose_landmarker->DetectForVideo(image, i)); + } + LOG(INFO) << i; + ExpectPoseLandmarkerResultsCorrect( + pose_landmarker_results, expected_results, kLandmarksOnVideoAbsMargin); + } + MP_ASSERT_OK(pose_landmarker->Close()); +} + +// TODO Add additional tests for MP Tasks Pose Graphs +// TODO Investigate PoseLandmarker performance in VideoMode. + +INSTANTIATE_TEST_SUITE_P( + PoseGestureTest, VideoModeTest, + Values(TestParams{ + /* test_name= */ "Pose", + /* test_image_name= */ kPoseImage, + /* test_model_file= */ kPoseLandmarkerBundleAsset, + /* rotation= */ 0, + /* expected_results = */ + GetExpectedPoseLandmarkerResult({kPoseLandmarksFilename}), + }, + TestParams{ + /* test_name= */ "NoPoses", + /* test_image_name= */ kBurgerImage, + /* test_model_file= */ kPoseLandmarkerBundleAsset, + /* rotation= */ 0, + /* expected_results = */ + {{}, {}, {}}, + }), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +class LiveStreamModeTest : public testing::TestWithParam {}; + +TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kPoseImage))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kPoseLandmarkerBundleAsset); + options->running_mode = core::RunningMode::LIVE_STREAM; + options->result_callback = [](absl::StatusOr results, + const Image& image, int64_t timestamp_ms) {}; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr pose_landmarker, + PoseLandmarker::Create(std::move(options))); + auto results = pose_landmarker->Detect(image); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the image mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + + results = pose_landmarker->DetectForVideo(image, 0); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the video mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + MP_ASSERT_OK(pose_landmarker->Close()); +} + +TEST_P(LiveStreamModeTest, Succeeds) { + const int iterations = 100; + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + GetParam().test_image_name))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, GetParam().test_model_file); + options->running_mode = core::RunningMode::LIVE_STREAM; + std::vector pose_landmarker_results; + std::vector> image_sizes; + std::vector timestamps; + options->result_callback = [&pose_landmarker_results, &image_sizes, + ×tamps]( + absl::StatusOr results, + const Image& image, int64_t timestamp_ms) { + MP_ASSERT_OK(results.status()); + pose_landmarker_results.push_back(std::move(results.value())); + image_sizes.push_back({image.width(), image.height()}); + timestamps.push_back(timestamp_ms); + }; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr pose_landmarker, + PoseLandmarker::Create(std::move(options))); + for (int i = 0; i < iterations; ++i) { + PoseLandmarkerResult pose_landmarker_results; + if (GetParam().rotation != 0) { + ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = GetParam().rotation; + MP_ASSERT_OK( + pose_landmarker->DetectAsync(image, i, image_processing_options)); + } else { + MP_ASSERT_OK(pose_landmarker->DetectAsync(image, i)); + } + } + MP_ASSERT_OK(pose_landmarker->Close()); + // Due to the flow limiter, the total of outputs will be smaller than the + // number of iterations. + ASSERT_LE(pose_landmarker_results.size(), iterations); + ASSERT_GT(pose_landmarker_results.size(), 0); + + const auto expected_results = GetParam().expected_results; + for (int i = 0; i < pose_landmarker_results.size(); ++i) { + ExpectPoseLandmarkerResultsCorrect(pose_landmarker_results[i], + expected_results, + kLandmarksOnVideoAbsMargin); + } + for (const auto& image_size : image_sizes) { + EXPECT_EQ(image_size.first, image.width()); + EXPECT_EQ(image_size.second, image.height()); + } + int64_t timestamp_ms = -1; + for (const auto& timestamp : timestamps) { + EXPECT_GT(timestamp, timestamp_ms); + timestamp_ms = timestamp; + } +} + +// TODO Add additional tests for MP Tasks Pose Graphs +// Investigate PoseLandmarker performance in LiveStreamMode. + +INSTANTIATE_TEST_SUITE_P( + PoseGestureTest, LiveStreamModeTest, + Values(TestParams{ + /* test_name= */ "Pose", + /* test_image_name= */ kPoseImage, + /* test_model_file= */ kPoseLandmarkerBundleAsset, + /* rotation= */ 0, + /* expected_results = */ + GetExpectedPoseLandmarkerResult({kPoseLandmarksFilename}), + }, + TestParams{ + /* test_name= */ "NoPoses", + /* test_image_name= */ kBurgerImage, + /* test_model_file= */ kPoseLandmarkerBundleAsset, + /* rotation= */ 0, + /* expected_results = */ + {{}, {}, {}}, + }), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +} // namespace pose_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index 0de0c255c..d89a60625 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -220,6 +220,7 @@ filegroup( "portrait_rotated_expected_detection.pbtxt", "pose_expected_detection.pbtxt", "pose_expected_expanded_rect.pbtxt", + "pose_landmarks.pbtxt", "thumb_up_landmarks.pbtxt", "thumb_up_rotated_landmarks.pbtxt", "victory_landmarks.pbtxt", diff --git a/mediapipe/tasks/testdata/vision/pose_landmarks.pbtxt b/mediapipe/tasks/testdata/vision/pose_landmarks.pbtxt new file mode 100644 index 000000000..d6137c6ed --- /dev/null +++ b/mediapipe/tasks/testdata/vision/pose_landmarks.pbtxt @@ -0,0 +1,235 @@ +# proto-file: mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.proto +# proto-message: LandmarksDetectionResult +landmarks { + landmark { + x: 0.46507242 + y: 0.41610304 + z: -0.14999297 + visibility: 0.99998844 + presence: 0.99998903 + } + landmark { + x: 0.47030905 + y: 0.40309227 + z: -0.14747797 + visibility: 0.9999664 + presence: 0.99995553 + } + landmark { + x: 0.47342995 + y: 0.40156996 + z: -0.14745711 + visibility: 0.9999546 + presence: 0.9999534 + } + landmark { + x: 0.47663194 + y: 0.39997125 + z: -0.14746517 + visibility: 0.99994385 + presence: 0.99994576 + } + landmark { + x: 0.46315864 + y: 0.4074311 + z: -0.12709364 + visibility: 0.99994826 + presence: 0.99996126 + } + landmark { + x: 0.46167478 + y: 0.4089059 + z: -0.12713416 + visibility: 0.9999267 + presence: 0.99996483 + } + landmark { + x: 0.45983645 + y: 0.41042596 + z: -0.12708154 + visibility: 0.9999249 + presence: 0.9999635 + } + landmark { + x: 0.48985234 + y: 0.4053322 + z: -0.10685073 + visibility: 0.99990225 + presence: 0.9999591 + } + landmark { + x: 0.46542352 + y: 0.41843837 + z: -0.011494645 + visibility: 0.999894 + presence: 0.9999794 + } + landmark { + x: 0.4757115 + y: 0.42860925 + z: -0.13509856 + visibility: 0.9999609 + presence: 0.9999964 + } + landmark { + x: 0.46715724 + y: 0.4335927 + z: -0.10763592 + visibility: 0.9999547 + presence: 0.99999654 + } + landmark { + x: 0.5383015 + y: 0.47349828 + z: -0.09904624 + visibility: 0.99999046 + presence: 0.99997306 + } + landmark { + x: 0.45313644 + y: 0.48936796 + z: 0.06387678 + visibility: 0.9999919 + presence: 0.99999344 + } + landmark { + x: 0.6247076 + y: 0.47583634 + z: -0.16247825 + visibility: 0.9971179 + presence: 0.9997496 + } + landmark { + x: 0.36646777 + y: 0.48342204 + z: 0.08566214 + visibility: 0.98212147 + presence: 0.9999796 + } + landmark { + x: 0.6950672 + y: 0.47276276 + z: -0.23432226 + visibility: 0.99190086 + presence: 0.99751616 + } + landmark { + x: 0.29679844 + y: 0.46896482 + z: -0.020334292 + visibility: 0.9826943 + presence: 0.99966264 + } + landmark { + x: 0.71771 + y: 0.466713 + z: -0.25376678 + visibility: 0.97681385 + presence: 0.9940614 + } + landmark { + x: 0.2762509 + y: 0.46507546 + z: -0.030049918 + visibility: 0.9577505 + presence: 0.9986953 + } + landmark { + x: 0.71468514 + y: 0.46545807 + z: -0.27873537 + visibility: 0.9773756 + presence: 0.99447954 + } + landmark { + x: 0.27898294 + y: 0.45952708 + z: -0.061881505 + visibility: 0.9598176 + presence: 0.9987778 + } + landmark { + x: 0.70616376 + y: 0.47313935 + z: -0.2473996 + visibility: 0.97479546 + presence: 0.9961171 + } + landmark { + x: 0.2866556 + y: 0.4661593 + z: -0.039750997 + visibility: 0.96074265 + presence: 0.9992084 + } + landmark { + x: 0.5199628 + y: 0.71132404 + z: -0.05568369 + visibility: 0.9997075 + presence: 0.9997811 + } + landmark { + x: 0.46290728 + y: 0.7031081 + z: 0.05579845 + visibility: 0.99977857 + presence: 0.99986553 + } + landmark { + x: 0.60217524 + y: 0.81984437 + z: -0.14303333 + visibility: 0.99511355 + presence: 0.9993449 + } + landmark { + x: 0.36314535 + y: 0.7332305 + z: -0.063036785 + visibility: 0.9963553 + presence: 0.9997341 + } + landmark { + x: 0.70064104 + y: 0.93365943 + z: -0.031219501 + visibility: 0.9909759 + presence: 0.9935846 + } + landmark { + x: 0.35170174 + y: 0.9031892 + z: 0.0022715325 + visibility: 0.9929086 + presence: 0.99870515 + } + landmark { + x: 0.70923454 + y: 0.954628 + z: -0.024844522 + visibility: 0.93651295 + presence: 0.99121124 + } + landmark { + x: 0.36733973 + y: 0.93519753 + z: 0.003870454 + visibility: 0.96783006 + presence: 0.99829334 + } + landmark { + x: 0.72855467 + y: 0.96812284 + z: -0.12364431 + visibility: 0.9639738 + presence: 0.98293763 + } + landmark { + x: 0.30121577 + y: 0.94473803 + z: -0.08448716 + visibility: 0.97545445 + presence: 0.99386966 + } +} diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 0d6abdf1e..8da42125c 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -307,13 +307,13 @@ def external_files(): http_file( name = "com_google_mediapipe_expected_pose_landmarks_prototxt", sha256 = "eed8dfa169b0abee60cde01496599b0bc75d91a82594a1bdf59be2f76f45d7f5", - urls = ["https://storage.googleapis.com/mediapipe-assets/expected_pose_landmarks.prototxt?generation=1681243892338529"], + urls = ["https://storage.googleapis.com/mediapipe-assets/expected_pose_landmarks.prototxt?generation=1681244232522990"], ) http_file( name = "com_google_mediapipe_expected_pose_landmarks_prototxt_orig", sha256 = "c230e0933e6cb4af69ec21314f3f9930fe13e7bb4bf1dbdb74427e4138c24c1e", - urls = ["https://storage.googleapis.com/mediapipe-assets/expected_pose_landmarks.prototxt.orig?generation=1681243894793518"], + urls = ["https://storage.googleapis.com/mediapipe-assets/expected_pose_landmarks.prototxt.orig?generation=1681244235071100"], ) http_file( @@ -823,7 +823,7 @@ def external_files(): http_file( name = "com_google_mediapipe_ocr_text_jpg", sha256 = "88052e93aa910330433741f5cef140f8f9ec463230a332aef7038b5457b06482", - urls = ["https://storage.googleapis.com/mediapipe-assets/ocr_text.jpg?generation=1681243901191762"], + urls = ["https://storage.googleapis.com/mediapipe-assets/ocr_text.jpg?generation=1681244241009078"], ) http_file( @@ -943,13 +943,13 @@ def external_files(): http_file( name = "com_google_mediapipe_pose_expected_detection_pbtxt", sha256 = "16866c8dd4fbee60f6972630d73baed219b45824c055c7fbc7dc9a91c4b182cc", - urls = ["https://storage.googleapis.com/mediapipe-assets/pose_expected_detection.pbtxt?generation=1681243904247231"], + urls = ["https://storage.googleapis.com/mediapipe-assets/pose_expected_detection.pbtxt?generation=1681244244235448"], ) http_file( name = "com_google_mediapipe_pose_expected_expanded_rect_pbtxt", sha256 = "b0a41d25ed115757606dfc034e9d320a93a52616d92d745150b6a886ddc5a88a", - urls = ["https://storage.googleapis.com/mediapipe-assets/pose_expected_expanded_rect.pbtxt?generation=1681243906782078"], + urls = ["https://storage.googleapis.com/mediapipe-assets/pose_expected_expanded_rect.pbtxt?generation=1681244246786802"], ) http_file( @@ -961,7 +961,7 @@ def external_files(): http_file( name = "com_google_mediapipe_pose_landmarker_task", sha256 = "fb9cc326c88fc2a4d9a6d355c28520d5deacfbaa375b56243b0141b546080596", - urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmarker.task?generation=1681243909774681"], + urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmarker.task?generation=1681244249587900"], ) http_file( @@ -979,7 +979,13 @@ def external_files(): http_file( name = "com_google_mediapipe_pose_landmark_lite_tflite", sha256 = "1150dc68a713b80660b90ef46ce4e85c1c781bb88b6e3512cc64e6a685ba5588", - urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmark_lite.tflite?generation=1681243912778143"], + urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmark_lite.tflite?generation=1681244252454799"], + ) + + http_file( + name = "com_google_mediapipe_pose_landmarks_pbtxt", + sha256 = "305a71fbff83e270a5dbd81fb7cf65203f56e0b1caba8ea42edc16c6e8a2ba18", + urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmarks.pbtxt?generation=1681244254964356"], ) http_file(