diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index 5004383d2..b5c9427a3 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -30,6 +30,15 @@ cc_library( ], ) +cc_library( + name = "hand_landmarks_detection_result", + hdrs = ["hand_landmarks_detection_result.h"], + deps = [ + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:landmark_cc_proto", + ], +) + cc_library( name = "category", srcs = ["category.cc"], diff --git a/mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h b/mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h new file mode 100644 index 000000000..b341cd8d3 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h @@ -0,0 +1,43 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_ + +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/landmark.pb.h" + +namespace mediapipe { +namespace tasks { +namespace components { +namespace containers { + +// The hand landmarks detection result from HandLandmarker, where each vector +// element represents a single hand detected in the image. +struct HandLandmarksDetectionResult { + // Classification of handedness. + std::vector<mediapipe::ClassificationList> handedness; + // Detected hand landmarks in normalized image coordinates. + std::vector<mediapipe::NormalizedLandmarkList> hand_landmarks; + // Detected hand landmarks in world coordinates. + std::vector<mediapipe::LandmarkList> hand_world_landmarks; +}; + +} // namespace containers +} // namespace components +} // namespace tasks +} // namespace mediapipe + +#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_ diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index 9090fc7b3..084934286 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -110,4 +110,38 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "hand_landmarker", + srcs = ["hand_landmarker.cc"], + hdrs = ["hand_landmarker.h"], + deps = [ + ":hand_landmarker_graph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components/containers:hand_landmarks_detection_result", + 
"//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], +) + # TODO: Enable this test diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc new file mode 100644 index 000000000..ec9790c30 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc @@ -0,0 +1,269 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h" + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h" +#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/core/base_task_api.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" +#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace hand_landmarker { + +namespace { + +using HandLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision:: + hand_landmarker::proto::HandLandmarkerGraphOptions; + +using ::mediapipe::tasks::components::containers::HandLandmarksDetectionResult; + +constexpr char kHandLandmarkerGraphTypeName[] = + "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph"; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char 
kImageInStreamName[] = "image_in"; +constexpr char kImageOutStreamName[] = "image_out"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kNormRectStreamName[] = "norm_rect_in"; +constexpr char kHandednessTag[] = "HANDEDNESS"; +constexpr char kHandednessStreamName[] = "handedness"; +constexpr char kHandLandmarksTag[] = "LANDMARKS"; +constexpr char kHandLandmarksStreamName[] = "landmarks"; +constexpr char kHandWorldLandmarksTag[] = "WORLD_LANDMARKS"; +constexpr char kHandWorldLandmarksStreamName[] = "world_landmarks"; +constexpr int kMicroSecondsPerMilliSecond = 1000; + +// Creates a MediaPipe graph config that contains a subgraph node of +// "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph". If the task is +// running in the live stream mode, a "FlowLimiterCalculator" will be added to +// limit the number of frames in flight. +CalculatorGraphConfig CreateGraphConfig( + std::unique_ptr<HandLandmarkerGraphOptionsProto> options, + bool enable_flow_limiting) { + api2::builder::Graph graph; + auto& subgraph = graph.AddNode(kHandLandmarkerGraphTypeName); + subgraph.GetOptions<HandLandmarkerGraphOptionsProto>().Swap(options.get()); + graph.In(kImageTag).SetName(kImageInStreamName); + graph.In(kNormRectTag).SetName(kNormRectStreamName); + subgraph.Out(kHandednessTag).SetName(kHandednessStreamName) >> + graph.Out(kHandednessTag); + subgraph.Out(kHandLandmarksTag).SetName(kHandLandmarksStreamName) >> + graph.Out(kHandLandmarksTag); + subgraph.Out(kHandWorldLandmarksTag).SetName(kHandWorldLandmarksStreamName) >> + graph.Out(kHandWorldLandmarksTag); + subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); + if (enable_flow_limiting) { + return tasks::core::AddFlowLimiterCalculator( + graph, subgraph, {kImageTag, kNormRectTag}, kHandLandmarksTag); + } + graph.In(kImageTag) >> subgraph.In(kImageTag); + graph.In(kNormRectTag) >> subgraph.In(kNormRectTag); + return graph.GetConfig(); +} + +// Converts the user-facing HandLandmarkerOptions struct to the internal +// HandLandmarkerGraphOptions proto. 
+std::unique_ptr<HandLandmarkerGraphOptionsProto> +ConvertHandLandmarkerGraphOptionsProto(HandLandmarkerOptions* options) { + auto options_proto = std::make_unique<HandLandmarkerGraphOptionsProto>(); + auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>( + tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); + options_proto->mutable_base_options()->Swap(base_options_proto.get()); + options_proto->mutable_base_options()->set_use_stream_mode( + options->running_mode != core::RunningMode::IMAGE); + + // Configure hand detector options. + auto* hand_detector_graph_options = + options_proto->mutable_hand_detector_graph_options(); + hand_detector_graph_options->set_num_hands(options->num_hands); + hand_detector_graph_options->set_min_detection_confidence( + options->min_hand_detection_confidence); + + // Configure hand landmark detector options. + options_proto->set_min_tracking_confidence(options->min_tracking_confidence); + auto* hand_landmarks_detector_graph_options = + options_proto->mutable_hand_landmarks_detector_graph_options(); + hand_landmarks_detector_graph_options->set_min_detection_confidence( + options->min_hand_presence_confidence); + + return options_proto; +} + +} // namespace + +absl::StatusOr<std::unique_ptr<HandLandmarker>> HandLandmarker::Create( + std::unique_ptr<HandLandmarkerOptions> options) { + auto options_proto = ConvertHandLandmarkerGraphOptionsProto(options.get()); + tasks::core::PacketsCallback packets_callback = nullptr; + if (options->result_callback) { + auto result_callback = options->result_callback; + packets_callback = [=](absl::StatusOr<tasks::core::PacketMap> + status_or_packets) { + if (!status_or_packets.ok()) { + Image image; + result_callback(status_or_packets.status(), image, + Timestamp::Unset().Value()); + return; + } + if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { + return; + } + Packet image_packet = status_or_packets.value()[kImageOutStreamName]; + if (status_or_packets.value()[kHandLandmarksStreamName].IsEmpty()) { + Packet empty_packet = + status_or_packets.value()[kHandLandmarksStreamName]; + result_callback( + 
{HandLandmarksDetectionResult()}, image_packet.Get(), + empty_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); + return; + } + Packet handedness_packet = + status_or_packets.value()[kHandednessStreamName]; + Packet hand_landmarks_packet = + status_or_packets.value()[kHandLandmarksStreamName]; + Packet hand_world_landmarks_packet = + status_or_packets.value()[kHandWorldLandmarksStreamName]; + result_callback( + {{handedness_packet.Get>(), + hand_landmarks_packet.Get>(), + hand_world_landmarks_packet.Get>()}}, + image_packet.Get(), + hand_landmarks_packet.Timestamp().Value() / + kMicroSecondsPerMilliSecond); + }; + } + return core::VisionTaskApiFactory::Create( + CreateGraphConfig( + std::move(options_proto), + options->running_mode == core::RunningMode::LIVE_STREAM), + std::move(options->base_options.op_resolver), options->running_mode, + std::move(packets_callback)); +} + +absl::StatusOr HandLandmarker::Detect( + mediapipe::Image image, + std::optional image_processing_options) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "GPU input images are currently not supported.", + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); + ASSIGN_OR_RETURN( + auto output_packets, + ProcessImageData( + {{kImageInStreamName, MakePacket(std::move(image))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect))}})); + if (output_packets[kHandLandmarksStreamName].IsEmpty()) { + return {HandLandmarksDetectionResult()}; + } + return {{/* handedness= */ + {output_packets[kHandednessStreamName] + .Get>()}, + /* hand_landmarks= */ + {output_packets[kHandLandmarksStreamName] + .Get>()}, + /* hand_world_landmarks */ + {output_packets[kHandWorldLandmarksStreamName] + .Get>()}}}; +} + +absl::StatusOr HandLandmarker::DetectForVideo( + mediapipe::Image image, int64 timestamp_ms, + std::optional 
image_processing_options) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrCat("GPU input images are currently not supported."), + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); + ASSIGN_OR_RETURN( + auto output_packets, + ProcessVideoData( + {{kImageInStreamName, + MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); + if (output_packets[kHandLandmarksStreamName].IsEmpty()) { + return {HandLandmarksDetectionResult()}; + } + return { + {/* handedness= */ + {output_packets[kHandednessStreamName] + .Get>()}, + /* hand_landmarks= */ + {output_packets[kHandLandmarksStreamName] + .Get>()}, + /* hand_world_landmarks */ + {output_packets[kHandWorldLandmarksStreamName] + .Get>()}}, + }; +} + +absl::Status HandLandmarker::DetectAsync( + mediapipe::Image image, int64 timestamp_ms, + std::optional image_processing_options) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrCat("GPU input images are currently not supported."), + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); + return SendLiveStreamData( + {{kImageInStreamName, + MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); +} + +} // namespace hand_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h 
b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h new file mode 100644 index 000000000..3538ab3f5 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h @@ -0,0 +1,192 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_H_ +#define MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/core/running_mode.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace hand_landmarker { + +struct HandLandmarkerOptions { + // Base options for configuring MediaPipe Tasks library, such as specifying + // the TfLite model bundle file with metadata, accelerator options, op + // resolver, etc. 
+ tasks::core::BaseOptions base_options; + + // The running mode of the task. Default to the image mode. + // HandLandmarker has three running modes: + // 1) The image mode for detecting hand landmarks on single image inputs. + // 2) The video mode for detecting hand landmarks on the decoded frames of a + // video. + // 3) The live stream mode for detecting hand landmarks on the live stream of + // input data, such as from camera. In this mode, the "result_callback" + // below must be specified to receive the detection results asynchronously. + core::RunningMode running_mode = core::RunningMode::IMAGE; + + // The maximum number of hands can be detected by the HandLandmarker. + int num_hands = 1; + + // The minimum confidence score for the hand detection to be considered + // successful. + float min_hand_detection_confidence = 0.5; + + // The minimum confidence score of hand presence score in the hand landmark + // detection. + float min_hand_presence_confidence = 0.5; + + // The minimum confidence score for the hand tracking to be considered + // successful. + float min_tracking_confidence = 0.5; + + // The user-defined result callback for processing live stream data. + // The result callback should only be specified when the running mode is set + // to RunningMode::LIVE_STREAM. + std::function, + const Image&, int64)> + result_callback = nullptr; +}; + +// Performs hand landmarks detection on the given image. +// +// TODO add the link to DevSite. +// This API expects a pre-trained hand landmarker model asset bundle. +// +// Inputs: +// Image +// - The image that hand landmarks detection runs on. +// std::optional +// - If provided, can be used to specify the rotation to apply to the image +// before performing hand landmarks detection, by setting its 'rotation' +// field in radians (e.g. 'M_PI / 2' for a 90° anti-clockwise rotation). 
+// Note that specifying a region-of-interest using the 'x_center', +// 'y_center', 'width' and 'height' fields is NOT supported and will +// result in an invalid argument error being returned. +// Outputs: +// HandLandmarksDetectionResult +// - The hand landmarks detection results. +class HandLandmarker : tasks::vision::core::BaseVisionTaskApi { + public: + using BaseVisionTaskApi::BaseVisionTaskApi; + + // Creates a HandLandmarker from a HandLandmarkerOptions to process image data + // or streaming data. Hand landmarker can be created with one of the following + // three running modes: + // 1) Image mode for detecting hand landmarks on single image inputs. Users + // provide mediapipe::Image to the `Detect` method, and will receive the + // detected hand landmarks results as the return value. + // 2) Video mode for detecting hand landmarks on the decoded frames of a + // video. Users call `DetectForVideo` method, and will receive the detected + // hand landmarks results as the return value. + // 3) Live stream mode for detecting hand landmarks on the live stream of the + // input data, such as from camera. Users call `DetectAsync` to push the + // image data into the HandLandmarker, the detected results along with the + // input timestamp and the image that hand landmarker runs on will be + // available in the result callback when the hand landmarker finishes the + // work. + static absl::StatusOr<std::unique_ptr<HandLandmarker>> Create( + std::unique_ptr<HandLandmarkerOptions> options); + + // Performs hand landmarks detection on the given image. + // Only use this method when the HandLandmarker is created with the image + // running mode. + // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing detection, by setting + // its 'rotation_degrees' field. Note that specifying a region-of-interest + // using the 'region_of_interest' field is NOT supported and will result in an + // invalid argument error being returned. 
+ // + // The image can be of any size with format RGB or RGBA. + // TODO: Describes how the input image will be preprocessed + // after the yuv support is implemented. + absl::StatusOr Detect( + Image image, + std::optional image_processing_options = + std::nullopt); + + // Performs hand landmarks detection on the provided video frame. + // Only use this method when the HandLandmarker is created with the video + // running mode. + // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing detection, by setting + // its 'rotation_degrees' field. Note that specifying a region-of-interest + // using the 'region_of_interest' field is NOT supported and will result in an + // invalid argument error being returned. + // + // The image can be of any size with format RGB or RGBA. It's required to + // provide the video frame's timestamp (in milliseconds). The input timestamps + // must be monotonically increasing. + absl::StatusOr + DetectForVideo(Image image, int64 timestamp_ms, + std::optional + image_processing_options = std::nullopt); + + // Sends live image data to perform hand landmarks detection, and the results + // will be available via the "result_callback" provided in the + // HandLandmarkerOptions. Only use this method when the HandLandmarker + // is created with the live stream running mode. + // + // The image can be of any size with format RGB or RGBA. It's required to + // provide a timestamp (in milliseconds) to indicate when the input image is + // sent to the hand landmarker. The input timestamps must be monotonically + // increasing. + // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing detection, by setting + // its 'rotation_degrees' field. 
Note that specifying a region-of-interest + // using the 'region_of_interest' field is NOT supported and will result in an + // invalid argument error being returned. + // + // The "result_callback" provides + // - A vector of HandLandmarksDetectionResult, each is the detected results + // for an input frame. + // - The const reference to the corresponding input image that the hand + // landmarker runs on. Note that the const reference to the image will no + // longer be valid when the callback returns. To access the image data + // outside of the callback, callers need to make a copy of the image. + // - The input timestamp in milliseconds. + absl::Status DetectAsync(Image image, int64 timestamp_ms, + std::optional<core::ImageProcessingOptions> + image_processing_options = std::nullopt); + + // Shuts down the HandLandmarker when all works are done. + absl::Status Close() { return runner_->Close(); } +}; + +} // namespace hand_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif  // MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_H_ diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc new file mode 100644 index 000000000..ee8c3c10d --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc @@ -0,0 +1,511 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h" + +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h" +#include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace hand_landmarker { + +namespace { + +using ::file::Defaults; +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::HandLandmarksDetectionResult; +using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult; +using ::mediapipe::tasks::vision::core::ImageProcessingOptions; +using ::testing::EqualsProto; +using ::testing::HasSubstr; +using ::testing::Optional; +using ::testing::Pointwise; +using ::testing::TestParamInfo; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::proto::Approximately; +using 
::testing::proto::Partially; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kHandLandmarkerBundleAsset[] = "hand_landmarker.task"; +constexpr char kThumbUpLandmarksFilename[] = "thumb_up_landmarks.pbtxt"; +constexpr char kPointingUpLandmarksFilename[] = "pointing_up_landmarks.pbtxt"; +constexpr char kPointingUpRotatedLandmarksFilename[] = + "pointing_up_rotated_landmarks.pbtxt"; +constexpr char kThumbUpImage[] = "thumb_up.jpg"; +constexpr char kPointingUpImage[] = "pointing_up.jpg"; +constexpr char kPointingUpRotatedImage[] = "pointing_up_rotated.jpg"; +constexpr char kNoHandsImage[] = "cats_and_dogs.jpg"; + +constexpr float kLandmarksFractionDiff = 0.03; // percentage +constexpr float kLandmarksAbsMargin = 0.03; +constexpr float kHandednessMargin = 0.05; + +LandmarksDetectionResult GetLandmarksDetectionResult( + absl::string_view landmarks_file_name) { + LandmarksDetectionResult result; + MP_EXPECT_OK(GetTextProto( + file::JoinPath("./", kTestDataDirectory, landmarks_file_name), &result, + Defaults())); + // Remove z position of landmarks, because they are not used in correctness + // testing. For video or live stream mode, the z positions varies a lot during + // tracking from frame to frame. 
+ for (int i = 0; i < result.landmarks().landmark().size(); i++) { + auto& landmark = *result.mutable_landmarks()->mutable_landmark(i); + landmark.clear_z(); + } + return result; +} + +HandLandmarksDetectionResult GetExpectedHandLandmarksDetectionResult( + const std::vector& landmarks_file_names) { + HandLandmarksDetectionResult expected_results; + for (const auto& file_name : landmarks_file_names) { + const auto landmarks_detection_result = + GetLandmarksDetectionResult(file_name); + expected_results.hand_landmarks.push_back( + landmarks_detection_result.landmarks()); + expected_results.handedness.push_back( + landmarks_detection_result.classifications()); + } + return expected_results; +} + +void ExpectHandLandmarksDetectionResultsCorrect( + const HandLandmarksDetectionResult& actual_results, + const HandLandmarksDetectionResult& expected_results) { + const auto& actual_landmarks = actual_results.hand_landmarks; + const auto& actual_handedness = actual_results.handedness; + + const auto& expected_landmarks = expected_results.hand_landmarks; + const auto& expected_handedness = expected_results.handedness; + + ASSERT_EQ(actual_landmarks.size(), expected_landmarks.size()); + ASSERT_EQ(actual_handedness.size(), expected_handedness.size()); + + EXPECT_THAT( + actual_handedness, + Pointwise(Approximately(Partially(EqualsProto()), kHandednessMargin), + expected_handedness)); + EXPECT_THAT(actual_landmarks, + Pointwise(Approximately(Partially(EqualsProto()), + /*margin=*/kLandmarksAbsMargin, + /*fraction=*/kLandmarksFractionDiff), + expected_landmarks)); +} + +} // namespace + +struct TestParams { + // The name of this test, for convenience when displaying test results. + std::string test_name; + // The filename of test image. + std::string test_image_name; + // The filename of test model. + std::string test_model_file; + // The rotation to apply to the test image before processing, in degrees + // clockwise. 
+ int rotation; + // Expected results from the hand landmarker model output. + HandLandmarksDetectionResult expected_results; +}; + +class ImageModeTest : public testing::TestWithParam {}; + +TEST_F(ImageModeTest, FailsWithCallingWrongMethod) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kThumbUpImage))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset); + options->running_mode = core::RunningMode::IMAGE; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr hand_landmarker, + HandLandmarker::Create(std::move(options))); + auto results = hand_landmarker->DetectForVideo(image, 0); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the video mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + + results = hand_landmarker->DetectAsync(image, 0); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the live stream mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + MP_ASSERT_OK(hand_landmarker->Close()); +} + +TEST_F(ImageModeTest, FailsWithRegionOfInterest) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kThumbUpImage))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset); + options->running_mode = core::RunningMode::IMAGE; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr hand_landmarker, + HandLandmarker::Create(std::move(options))); + Rect 
roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; + + auto results = hand_landmarker->Detect(image, image_processing_options); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("This task doesn't support region-of-interest")); + EXPECT_THAT( + results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); +} + +TEST_P(ImageModeTest, Succeeds) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + GetParam().test_image_name))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, GetParam().test_model_file); + options->running_mode = core::RunningMode::IMAGE; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr hand_landmarker, + HandLandmarker::Create(std::move(options))); + HandLandmarksDetectionResult hand_landmarker_results; + if (GetParam().rotation != 0) { + ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = GetParam().rotation; + MP_ASSERT_OK_AND_ASSIGN( + hand_landmarker_results, + hand_landmarker->Detect(image, image_processing_options)); + } else { + MP_ASSERT_OK_AND_ASSIGN(hand_landmarker_results, + hand_landmarker->Detect(image)); + } + ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results, + GetParam().expected_results); + MP_ASSERT_OK(hand_landmarker->Close()); +} + +INSTANTIATE_TEST_SUITE_P( + HandGestureTest, ImageModeTest, + Values(TestParams{ + /* test_name= */ "LandmarksThumbUp", + /* test_image_name= */ kThumbUpImage, + /* test_model_file= */ kHandLandmarkerBundleAsset, + /* rotation= */ 0, + /* expected_results = */ + GetExpectedHandLandmarksDetectionResult( + {kThumbUpLandmarksFilename}), + }, + TestParams{ + /* test_name= */ 
"LandmarksPointingUp", + /* test_image_name= */ kPointingUpImage, + /* test_model_file= */ kHandLandmarkerBundleAsset, + /* rotation= */ 0, + /* expected_results = */ + GetExpectedHandLandmarksDetectionResult( + {kPointingUpLandmarksFilename}), + }, + TestParams{ + /* test_name= */ "LandmarksPointingUpRotated", + /* test_image_name= */ kPointingUpRotatedImage, + /* test_model_file= */ kHandLandmarkerBundleAsset, + /* rotation= */ -90, + /* expected_results = */ + GetExpectedHandLandmarksDetectionResult( + {kPointingUpRotatedLandmarksFilename}), + }, + TestParams{ + /* test_name= */ "NoHands", + /* test_image_name= */ kNoHandsImage, + /* test_model_file= */ kHandLandmarkerBundleAsset, + /* rotation= */ 0, + /* expected_results = */ + {{}, {}, {}}, + }), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +class VideoModeTest : public testing::TestWithParam {}; + +TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kThumbUpImage))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset); + options->running_mode = core::RunningMode::VIDEO; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr hand_landmarker, + HandLandmarker::Create(std::move(options))); + auto results = hand_landmarker->Detect(image); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the image mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + + results = hand_landmarker->DetectAsync(image, 0); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the live stream mode")); + 
EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + MP_ASSERT_OK(hand_landmarker->Close()); +} + +TEST_P(VideoModeTest, Succeeds) { + const int iterations = 100; + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + GetParam().test_image_name))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, GetParam().test_model_file); + options->running_mode = core::RunningMode::VIDEO; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr hand_landmarker, + HandLandmarker::Create(std::move(options))); + const auto expected_results = GetParam().expected_results; + for (int i = 0; i < iterations; ++i) { + HandLandmarksDetectionResult hand_landmarker_results; + if (GetParam().rotation != 0) { + ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = GetParam().rotation; + MP_ASSERT_OK_AND_ASSIGN( + hand_landmarker_results, + hand_landmarker->DetectForVideo(image, i, image_processing_options)); + } else { + MP_ASSERT_OK_AND_ASSIGN(hand_landmarker_results, + hand_landmarker->DetectForVideo(image, i)); + } + ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results, + expected_results); + } + MP_ASSERT_OK(hand_landmarker->Close()); +} + +INSTANTIATE_TEST_SUITE_P( + HandGestureTest, VideoModeTest, + Values(TestParams{ + /* test_name= */ "LandmarksThumbUp", + /* test_image_name= */ kThumbUpImage, + /* test_model_file= */ kHandLandmarkerBundleAsset, + /* rotation= */ 0, + /* expected_results = */ + GetExpectedHandLandmarksDetectionResult( + {kThumbUpLandmarksFilename}), + }, + TestParams{ + /* test_name= */ "LandmarksPointingUp", + /* test_image_name= */ kPointingUpImage, + /* test_model_file= */ kHandLandmarkerBundleAsset, + /* rotation= */ 0, + /* expected_results = */ + GetExpectedHandLandmarksDetectionResult( + 
{kPointingUpLandmarksFilename}), + }, + TestParams{ + /* test_name= */ "LandmarksPointingUpRotated", + /* test_image_name= */ kPointingUpRotatedImage, + /* test_model_file= */ kHandLandmarkerBundleAsset, + /* rotation= */ -90, + /* expected_results = */ + GetExpectedHandLandmarksDetectionResult( + {kPointingUpRotatedLandmarksFilename}), + }, + TestParams{ + /* test_name= */ "NoHands", + /* test_image_name= */ kNoHandsImage, + /* test_model_file= */ kHandLandmarkerBundleAsset, + /* rotation= */ 0, + /* expected_results = */ + {{}, {}, {}}, + }), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +class LiveStreamModeTest : public testing::TestWithParam {}; + +TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kThumbUpImage))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset); + options->running_mode = core::RunningMode::LIVE_STREAM; + options->result_callback = + [](absl::StatusOr results, + const Image& image, int64 timestamp_ms) {}; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr hand_landmarker, + HandLandmarker::Create(std::move(options))); + auto results = hand_landmarker->Detect(image); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the image mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + + results = hand_landmarker->DetectForVideo(image, 0); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the video mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + 
MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + MP_ASSERT_OK(hand_landmarker->Close()); +} + +TEST_P(LiveStreamModeTest, Succeeds) { + const int iterations = 100; + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + GetParam().test_image_name))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, GetParam().test_model_file); + options->running_mode = core::RunningMode::LIVE_STREAM; + std::vector hand_landmarker_results; + std::vector> image_sizes; + std::vector timestamps; + options->result_callback = + [&hand_landmarker_results, &image_sizes, ×tamps]( + absl::StatusOr results, + const Image& image, int64 timestamp_ms) { + MP_ASSERT_OK(results.status()); + hand_landmarker_results.push_back(std::move(results.value())); + image_sizes.push_back({image.width(), image.height()}); + timestamps.push_back(timestamp_ms); + }; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr hand_landmarker, + HandLandmarker::Create(std::move(options))); + for (int i = 0; i < iterations; ++i) { + HandLandmarksDetectionResult hand_landmarker_results; + if (GetParam().rotation != 0) { + ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = GetParam().rotation; + MP_ASSERT_OK( + hand_landmarker->DetectAsync(image, i, image_processing_options)); + } else { + MP_ASSERT_OK(hand_landmarker->DetectAsync(image, i)); + } + } + MP_ASSERT_OK(hand_landmarker->Close()); + // Due to the flow limiter, the total of outputs will be smaller than the + // number of iterations. 
+ ASSERT_LE(hand_landmarker_results.size(), iterations); + ASSERT_GT(hand_landmarker_results.size(), 0); + + const auto expected_results = GetParam().expected_results; + for (int i = 0; i < hand_landmarker_results.size(); ++i) { + ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results[i], + expected_results); + } + for (const auto& image_size : image_sizes) { + EXPECT_EQ(image_size.first, image.width()); + EXPECT_EQ(image_size.second, image.height()); + } + int64 timestamp_ms = -1; + for (const auto& timestamp : timestamps) { + EXPECT_GT(timestamp, timestamp_ms); + timestamp_ms = timestamp; + } +} + +INSTANTIATE_TEST_SUITE_P( + HandGestureTest, LiveStreamModeTest, + Values(TestParams{ + /* test_name= */ "LandmarksThumbUp", + /* test_image_name= */ kThumbUpImage, + /* test_model_file= */ kHandLandmarkerBundleAsset, + /* rotation= */ 0, + /* expected_results = */ + GetExpectedHandLandmarksDetectionResult( + {kThumbUpLandmarksFilename}), + }, + TestParams{ + /* test_name= */ "LandmarksPointingUp", + /* test_image_name= */ kPointingUpImage, + /* test_model_file= */ kHandLandmarkerBundleAsset, + /* rotation= */ 0, + /* expected_results = */ + GetExpectedHandLandmarksDetectionResult( + {kPointingUpLandmarksFilename}), + }, + TestParams{ + /* test_name= */ "LandmarksPointingUpRotated", + /* test_image_name= */ kPointingUpRotatedImage, + /* test_model_file= */ kHandLandmarkerBundleAsset, + /* rotation= */ -90, + /* expected_results = */ + GetExpectedHandLandmarksDetectionResult( + {kPointingUpRotatedLandmarksFilename}), + }, + TestParams{ + /* test_name= */ "NoHands", + /* test_image_name= */ kNoHandsImage, + /* test_model_file= */ kHandLandmarkerBundleAsset, + /* rotation= */ 0, + /* expected_results = */ + {{}, {}, {}}, + }), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +} // namespace hand_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe