Add hand landmarker C++ API.

PiperOrigin-RevId: 486791711
This commit is contained in:
MediaPipe Team 2022-11-07 16:15:54 -08:00 committed by Copybara-Service
parent 571c0b1fef
commit 6f38a7a21f
6 changed files with 1058 additions and 0 deletions

View File

@ -30,6 +30,15 @@ cc_library(
], ],
) )
# Result container for the hand landmarker task: handedness plus image-space
# and world-space landmarks. Header-only; depends only on the proto formats.
cc_library(
    name = "hand_landmarks_detection_result",
    hdrs = ["hand_landmarks_detection_result.h"],
    deps = [
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:landmark_cc_proto",
    ],
)
cc_library( cc_library(
name = "category", name = "category",
srcs = ["category.cc"], srcs = ["category.cc"],

View File

@ -0,0 +1,43 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_
#include <vector>

#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
namespace mediapipe {
namespace tasks {
namespace components {
namespace containers {
// The hand landmarks detection result from HandLandmarker, where each vector
// element represents a single hand detected in the image. The vectors hold
// one entry per detected hand.
struct HandLandmarksDetectionResult {
  // Classification of handedness, one ClassificationList per detected hand.
  std::vector<mediapipe::ClassificationList> handedness;
  // Detected hand landmarks in normalized image coordinates.
  std::vector<mediapipe::NormalizedLandmarkList> hand_landmarks;
  // Detected hand landmarks in world coordinates.
  std::vector<mediapipe::LandmarkList> hand_world_landmarks;
};
} // namespace containers
} // namespace components
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_

View File

@ -110,4 +110,38 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
# Public C++ HandLandmarker task API (image / video / live-stream modes).
cc_library(
    name = "hand_landmarker",
    srcs = ["hand_landmarker.cc"],
    hdrs = ["hand_landmarker.h"],
    deps = [
        ":hand_landmarker_graph",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components:image_preprocessing",
        "//mediapipe/tasks/cc/components/containers:hand_landmarks_detection_result",
        "//mediapipe/tasks/cc/components/processors:classifier_options",
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
        "//mediapipe/tasks/cc/core:base_options",
        "//mediapipe/tasks/cc/core:base_task_api",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/core:task_runner",
        "//mediapipe/tasks/cc/core:utils",
        "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
        "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
        "//mediapipe/tasks/cc/vision/core:image_processing_options",
        "//mediapipe/tasks/cc/vision/core:running_mode",
        "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
        "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
        "@com_google_absl//absl/status:statusor",
    ],
)
# TODO: Enable this test # TODO: Enable this test

View File

@ -0,0 +1,269 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h"

#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/base_task_api.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace hand_landmarker {
namespace {
using HandLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision::
    hand_landmarker::proto::HandLandmarkerGraphOptions;

using ::mediapipe::tasks::components::containers::HandLandmarksDetectionResult;

// Registered type name of the subgraph this task wraps.
constexpr char kHandLandmarkerGraphTypeName[] =
    "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph";

// Graph input/output tags and the stream names used to feed inputs and to
// look up result packets.
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kNormRectStreamName[] = "norm_rect_in";
constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kHandednessStreamName[] = "handedness";
constexpr char kHandLandmarksTag[] = "LANDMARKS";
constexpr char kHandLandmarksStreamName[] = "landmarks";
constexpr char kHandWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kHandWorldLandmarksStreamName[] = "world_landmarks";

// User-facing timestamps are milliseconds; MediaPipe Timestamps are
// microseconds.
constexpr int kMicroSecondsPerMilliSecond = 1000;
// Creates a MediaPipe graph config that contains a subgraph node of
// "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph". If the task is
// running in the live stream mode, a "FlowLimiterCalculator" will be added to
// limit the number of frames in flight.
CalculatorGraphConfig CreateGraphConfig(
    std::unique_ptr<HandLandmarkerGraphOptionsProto> options,
    bool enable_flow_limiting) {
  api2::builder::Graph graph;
  auto& subgraph = graph.AddNode(kHandLandmarkerGraphTypeName);
  // Transfer ownership of the options proto into the subgraph node.
  subgraph.GetOptions<HandLandmarkerGraphOptionsProto>().Swap(options.get());
  graph.In(kImageTag).SetName(kImageInStreamName);
  graph.In(kNormRectTag).SetName(kNormRectStreamName);
  // Expose all subgraph outputs (handedness, image/world landmarks, and the
  // pass-through image) as graph outputs under their stream names.
  subgraph.Out(kHandednessTag).SetName(kHandednessStreamName) >>
      graph.Out(kHandednessTag);
  subgraph.Out(kHandLandmarksTag).SetName(kHandLandmarksStreamName) >>
      graph.Out(kHandLandmarksTag);
  subgraph.Out(kHandWorldLandmarksTag).SetName(kHandWorldLandmarksStreamName) >>
      graph.Out(kHandWorldLandmarksTag);
  subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
  // In live-stream mode the flow limiter also wires the graph inputs to the
  // subgraph; otherwise connect them directly below.
  if (enable_flow_limiting) {
    return tasks::core::AddFlowLimiterCalculator(
        graph, subgraph, {kImageTag, kNormRectTag}, kHandLandmarksTag);
  }
  graph.In(kImageTag) >> subgraph.In(kImageTag);
  graph.In(kNormRectTag) >> subgraph.In(kNormRectTag);
  return graph.GetConfig();
}
// Converts the user-facing HandLandmarkerOptions struct to the internal
// HandLandmarkerGraphOptions proto.
std::unique_ptr<HandLandmarkerGraphOptionsProto>
ConvertHandLandmarkerGraphOptionsProto(HandLandmarkerOptions* options) {
  auto options_proto = std::make_unique<HandLandmarkerGraphOptionsProto>();
  auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
      tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
  options_proto->mutable_base_options()->Swap(base_options_proto.get());
  // Stream mode is enabled for both the video and live-stream running modes.
  options_proto->mutable_base_options()->set_use_stream_mode(
      options->running_mode != core::RunningMode::IMAGE);
  // Configure hand detector options.
  auto* hand_detector_graph_options =
      options_proto->mutable_hand_detector_graph_options();
  hand_detector_graph_options->set_num_hands(options->num_hands);
  hand_detector_graph_options->set_min_detection_confidence(
      options->min_hand_detection_confidence);
  // Configure hand landmark detector options. The user's hand presence
  // confidence maps to the landmark detector's min_detection_confidence.
  options_proto->set_min_tracking_confidence(options->min_tracking_confidence);
  auto* hand_landmarks_detector_graph_options =
      options_proto->mutable_hand_landmarks_detector_graph_options();
  hand_landmarks_detector_graph_options->set_min_detection_confidence(
      options->min_hand_presence_confidence);
  return options_proto;
}
} // namespace
// Builds the task graph from the options and instantiates the task through
// the vision task API factory. In live-stream mode, wraps the user's result
// callback into a packet callback that converts output packets into a
// HandLandmarksDetectionResult.
absl::StatusOr<std::unique_ptr<HandLandmarker>> HandLandmarker::Create(
    std::unique_ptr<HandLandmarkerOptions> options) {
  auto options_proto = ConvertHandLandmarkerGraphOptionsProto(options.get());
  tasks::core::PacketsCallback packets_callback = nullptr;
  if (options->result_callback) {
    auto result_callback = options->result_callback;
    packets_callback = [=](absl::StatusOr<tasks::core::PacketMap>
                               status_or_packets) {
      if (!status_or_packets.ok()) {
        // Propagate the error with an empty image and an unset timestamp.
        Image image;
        result_callback(status_or_packets.status(), image,
                        Timestamp::Unset().Value());
        return;
      }
      if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
        return;
      }
      Packet image_packet = status_or_packets.value()[kImageOutStreamName];
      if (status_or_packets.value()[kHandLandmarksStreamName].IsEmpty()) {
        // No hands detected in this frame: report an empty result at the
        // frame's timestamp (converted from microseconds to milliseconds).
        Packet empty_packet =
            status_or_packets.value()[kHandLandmarksStreamName];
        result_callback(
            {HandLandmarksDetectionResult()}, image_packet.Get<Image>(),
            empty_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
        return;
      }
      // Assemble the result from the three aligned output streams.
      Packet handedness_packet =
          status_or_packets.value()[kHandednessStreamName];
      Packet hand_landmarks_packet =
          status_or_packets.value()[kHandLandmarksStreamName];
      Packet hand_world_landmarks_packet =
          status_or_packets.value()[kHandWorldLandmarksStreamName];
      result_callback(
          {{handedness_packet.Get<std::vector<ClassificationList>>(),
            hand_landmarks_packet.Get<std::vector<NormalizedLandmarkList>>(),
            hand_world_landmarks_packet.Get<std::vector<LandmarkList>>()}},
          image_packet.Get<Image>(),
          hand_landmarks_packet.Timestamp().Value() /
              kMicroSecondsPerMilliSecond);
    };
  }
  return core::VisionTaskApiFactory::Create<HandLandmarker,
                                            HandLandmarkerGraphOptionsProto>(
      CreateGraphConfig(
          std::move(options_proto),
          options->running_mode == core::RunningMode::LIVE_STREAM),
      std::move(options->base_options.op_resolver), options->running_mode,
      std::move(packets_callback));
}
// Synchronous single-image detection (image running mode only).
absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::Detect(
    mediapipe::Image image,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  // Region-of-interest is rejected; only rotation may be specified.
  ASSIGN_OR_RETURN(
      NormalizedRect norm_rect,
      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessImageData(
          {{kImageInStreamName, MakePacket<Image>(std::move(image))},
           {kNormRectStreamName,
            MakePacket<NormalizedRect>(std::move(norm_rect))}}));
  // No hands detected: return an empty result rather than an error.
  if (output_packets[kHandLandmarksStreamName].IsEmpty()) {
    return {HandLandmarksDetectionResult()};
  }
  return {{/* handedness= */
           {output_packets[kHandednessStreamName]
                .Get<std::vector<mediapipe::ClassificationList>>()},
           /* hand_landmarks= */
           {output_packets[kHandLandmarksStreamName]
                .Get<std::vector<mediapipe::NormalizedLandmarkList>>()},
           /* hand_world_landmarks= */
           {output_packets[kHandWorldLandmarksStreamName]
                .Get<std::vector<mediapipe::LandmarkList>>()}}};
}
// Per-frame detection for decoded video (video running mode only).
// `timestamp_ms` must be monotonically increasing across calls.
absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::DetectForVideo(
    mediapipe::Image image, int64 timestamp_ms,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        // Plain literal (no absl::StrCat needed), consistent with Detect().
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  // Region-of-interest is rejected; only rotation may be specified.
  ASSIGN_OR_RETURN(
      NormalizedRect norm_rect,
      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
  // Feed both inputs at the frame timestamp, converted from milliseconds to
  // MediaPipe's microsecond Timestamps.
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessVideoData(
          {{kImageInStreamName,
            MakePacket<Image>(std::move(image))
                .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
           {kNormRectStreamName,
            MakePacket<NormalizedRect>(std::move(norm_rect))
                .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
  // No hands detected: return an empty result rather than an error.
  if (output_packets[kHandLandmarksStreamName].IsEmpty()) {
    return {HandLandmarksDetectionResult()};
  }
  return {
      {/* handedness= */
       {output_packets[kHandednessStreamName]
            .Get<std::vector<mediapipe::ClassificationList>>()},
       /* hand_landmarks= */
       {output_packets[kHandLandmarksStreamName]
            .Get<std::vector<mediapipe::NormalizedLandmarkList>>()},
       /* hand_world_landmarks= */
       {output_packets[kHandWorldLandmarksStreamName]
            .Get<std::vector<mediapipe::LandmarkList>>()}},
  };
}
// Pushes an image into the live-stream graph; results arrive through the
// result callback supplied in HandLandmarkerOptions. `timestamp_ms` must be
// monotonically increasing across calls.
absl::Status HandLandmarker::DetectAsync(
    mediapipe::Image image, int64 timestamp_ms,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        // Plain literal (no absl::StrCat needed), consistent with Detect().
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  // Region-of-interest is rejected; only rotation may be specified.
  ASSIGN_OR_RETURN(
      NormalizedRect norm_rect,
      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
  // Timestamps are provided in milliseconds but MediaPipe Timestamps are in
  // microseconds.
  return SendLiveStreamData(
      {{kImageInStreamName,
        MakePacket<Image>(std::move(image))
            .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
       {kNormRectStreamName,
        MakePacket<NormalizedRect>(std::move(norm_rect))
            .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
}
} // namespace hand_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,192 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_H_
#define MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_H_
#include <functional>
#include <memory>
#include <optional>
#include <vector>

#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace hand_landmarker {
// User-facing options for creating a HandLandmarker task.
struct HandLandmarkerOptions {
  // Base options for configuring MediaPipe Tasks library, such as specifying
  // the TfLite model bundle file with metadata, accelerator options, op
  // resolver, etc.
  tasks::core::BaseOptions base_options;
  // The running mode of the task. Default to the image mode.
  // HandLandmarker has three running modes:
  // 1) The image mode for detecting hand landmarks on single image inputs.
  // 2) The video mode for detecting hand landmarks on the decoded frames of a
  //    video.
  // 3) The live stream mode for detecting hand landmarks on the live stream of
  //    input data, such as from camera. In this mode, the "result_callback"
  //    below must be specified to receive the detection results asynchronously.
  core::RunningMode running_mode = core::RunningMode::IMAGE;
  // The maximum number of hands that can be detected by the HandLandmarker.
  int num_hands = 1;
  // The minimum confidence score for the hand detection to be considered
  // successful.
  float min_hand_detection_confidence = 0.5;
  // The minimum confidence score of hand presence score in the hand landmark
  // detection.
  float min_hand_presence_confidence = 0.5;
  // The minimum confidence score for the hand tracking to be considered
  // successful.
  float min_tracking_confidence = 0.5;
  // The user-defined result callback for processing live stream data.
  // The result callback should only be specified when the running mode is set
  // to RunningMode::LIVE_STREAM. It receives the detection result (or an
  // error status), the input image, and the input timestamp in milliseconds.
  std::function<void(
      absl::StatusOr<components::containers::HandLandmarksDetectionResult>,
      const Image&, int64)>
      result_callback = nullptr;
};
// Performs hand landmarks detection on the given image.
//
// TODO add the link to DevSite.
// This API expects a pre-trained hand landmarker model asset bundle.
//
// Inputs:
//   Image
//     - The image that hand landmarks detection runs on.
//   std::optional<NormalizedRect>
//     - If provided, can be used to specify the rotation to apply to the image
//       before performing hand landmarks detection, by setting its 'rotation'
//       field in radians (e.g. 'M_PI / 2' for a 90° anti-clockwise rotation).
//       Note that specifying a region-of-interest using the 'x_center',
//       'y_center', 'width' and 'height' fields is NOT supported and will
//       result in an invalid argument error being returned.
// Outputs:
//   HandLandmarksDetectionResult
//     - The hand landmarks detection results.
class HandLandmarker : tasks::vision::core::BaseVisionTaskApi {
 public:
  using BaseVisionTaskApi::BaseVisionTaskApi;

  // Creates a HandLandmarker from a HandLandmarkerOptions to process image data
  // or streaming data. Hand landmarker can be created with one of the following
  // three running modes:
  // 1) Image mode for detecting hand landmarks on single image inputs. Users
  //    provide mediapipe::Image to the `Detect` method, and will receive the
  //    detected hand landmarks results as the return value.
  // 2) Video mode for detecting hand landmarks on the decoded frames of a
  //    video. Users call `DetectForVideo` method, and will receive the detected
  //    hand landmarks results as the return value.
  // 3) Live stream mode for detecting hand landmarks on the live stream of the
  //    input data, such as from camera. Users call `DetectAsync` to push the
  //    image data into the HandLandmarker, the detected results along with the
  //    input timestamp and the image that hand landmarker runs on will be
  //    available in the result callback when the hand landmarker finishes the
  //    work.
  static absl::StatusOr<std::unique_ptr<HandLandmarker>> Create(
      std::unique_ptr<HandLandmarkerOptions> options);

  // Performs hand landmarks detection on the given image.
  // Only use this method when the HandLandmarker is created with the image
  // running mode.
  //
  // The optional 'image_processing_options' parameter can be used to specify
  // the rotation to apply to the image before performing detection, by setting
  // its 'rotation_degrees' field. Note that specifying a region-of-interest
  // using the 'region_of_interest' field is NOT supported and will result in an
  // invalid argument error being returned.
  //
  // The image can be of any size with format RGB or RGBA.
  // TODO: Describes how the input image will be preprocessed
  // after the yuv support is implemented.
  absl::StatusOr<components::containers::HandLandmarksDetectionResult> Detect(
      Image image,
      std::optional<core::ImageProcessingOptions> image_processing_options =
          std::nullopt);

  // Performs hand landmarks detection on the provided video frame.
  // Only use this method when the HandLandmarker is created with the video
  // running mode.
  //
  // The optional 'image_processing_options' parameter can be used to specify
  // the rotation to apply to the image before performing detection, by setting
  // its 'rotation_degrees' field. Note that specifying a region-of-interest
  // using the 'region_of_interest' field is NOT supported and will result in an
  // invalid argument error being returned.
  //
  // The image can be of any size with format RGB or RGBA. It's required to
  // provide the video frame's timestamp (in milliseconds). The input timestamps
  // must be monotonically increasing.
  absl::StatusOr<components::containers::HandLandmarksDetectionResult>
  DetectForVideo(Image image, int64 timestamp_ms,
                 std::optional<core::ImageProcessingOptions>
                     image_processing_options = std::nullopt);

  // Sends live image data to perform hand landmarks detection, and the results
  // will be available via the "result_callback" provided in the
  // HandLandmarkerOptions. Only use this method when the HandLandmarker
  // is created with the live stream running mode.
  //
  // The image can be of any size with format RGB or RGBA. It's required to
  // provide a timestamp (in milliseconds) to indicate when the input image is
  // sent to the hand landmarker. The input timestamps must be monotonically
  // increasing.
  //
  // The optional 'image_processing_options' parameter can be used to specify
  // the rotation to apply to the image before performing detection, by setting
  // its 'rotation_degrees' field. Note that specifying a region-of-interest
  // using the 'region_of_interest' field is NOT supported and will result in an
  // invalid argument error being returned.
  //
  // The "result_callback" provides
  //   - A vector of HandLandmarksDetectionResult, each is the detected results
  //     for a input frame.
  //   - The const reference to the corresponding input image that the hand
  //     landmarker runs on. Note that the const reference to the image will no
  //     longer be valid when the callback returns. To access the image data
  //     outside of the callback, callers need to make a copy of the image.
  //   - The input timestamp in milliseconds.
  absl::Status DetectAsync(Image image, int64 timestamp_ms,
                           std::optional<core::ImageProcessingOptions>
                               image_processing_options = std::nullopt);

  // Shuts down the HandLandmarker when all works are done.
  absl::Status Close() { return runner_->Close(); }
};
} // namespace hand_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_H_

View File

@ -0,0 +1,511 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h"
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace hand_landmarker {
namespace {
using ::file::Defaults;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::HandLandmarksDetectionResult;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::EqualsProto;
using ::testing::HasSubstr;
using ::testing::Optional;
using ::testing::Pointwise;
using ::testing::TestParamInfo;
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kHandLandmarkerBundleAsset[] = "hand_landmarker.task";
constexpr char kThumbUpLandmarksFilename[] = "thumb_up_landmarks.pbtxt";
constexpr char kPointingUpLandmarksFilename[] = "pointing_up_landmarks.pbtxt";
constexpr char kPointingUpRotatedLandmarksFilename[] =
"pointing_up_rotated_landmarks.pbtxt";
constexpr char kThumbUpImage[] = "thumb_up.jpg";
constexpr char kPointingUpImage[] = "pointing_up.jpg";
constexpr char kPointingUpRotatedImage[] = "pointing_up_rotated.jpg";
constexpr char kNoHandsImage[] = "cats_and_dogs.jpg";
constexpr float kLandmarksFractionDiff = 0.03; // percentage
constexpr float kLandmarksAbsMargin = 0.03;
constexpr float kHandednessMargin = 0.05;
// Loads a golden LandmarksDetectionResult text proto from the test data
// directory.
LandmarksDetectionResult GetLandmarksDetectionResult(
    absl::string_view landmarks_file_name) {
  LandmarksDetectionResult result;
  MP_EXPECT_OK(GetTextProto(
      file::JoinPath("./", kTestDataDirectory, landmarks_file_name), &result,
      Defaults()));
  // Remove z position of landmarks, because they are not used in correctness
  // testing. For video or live stream mode, the z positions varies a lot during
  // tracking from frame to frame.
  for (int i = 0; i < result.landmarks().landmark().size(); i++) {
    auto& landmark = *result.mutable_landmarks()->mutable_landmark(i);
    landmark.clear_z();
  }
  return result;
}
// Builds the expected HandLandmarksDetectionResult by loading one golden
// LandmarksDetectionResult per file name and collecting its landmarks and
// handedness classifications.
HandLandmarksDetectionResult GetExpectedHandLandmarksDetectionResult(
    const std::vector<absl::string_view>& landmarks_file_names) {
  HandLandmarksDetectionResult expected;
  for (const auto& landmarks_file : landmarks_file_names) {
    const auto detection_result = GetLandmarksDetectionResult(landmarks_file);
    expected.hand_landmarks.push_back(detection_result.landmarks());
    expected.handedness.push_back(detection_result.classifications());
  }
  return expected;
}
// Asserts that the actual detection results match the expected golden
// results, comparing handedness and landmarks with approximate proto
// matchers (tolerances defined by the margin constants above).
void ExpectHandLandmarksDetectionResultsCorrect(
    const HandLandmarksDetectionResult& actual_results,
    const HandLandmarksDetectionResult& expected_results) {
  const auto& actual_landmarks = actual_results.hand_landmarks;
  const auto& actual_handedness = actual_results.handedness;
  const auto& expected_landmarks = expected_results.hand_landmarks;
  const auto& expected_handedness = expected_results.handedness;
  // Per-hand vectors must line up before element-wise comparison.
  ASSERT_EQ(actual_landmarks.size(), expected_landmarks.size());
  ASSERT_EQ(actual_handedness.size(), expected_handedness.size());
  EXPECT_THAT(
      actual_handedness,
      Pointwise(Approximately(Partially(EqualsProto()), kHandednessMargin),
                expected_handedness));
  EXPECT_THAT(actual_landmarks,
              Pointwise(Approximately(Partially(EqualsProto()),
                                      /*margin=*/kLandmarksAbsMargin,
                                      /*fraction=*/kLandmarksFractionDiff),
                        expected_landmarks));
}
} // namespace
// Parameters for the value-parameterized image-mode tests below.
struct TestParams {
  // The name of this test, for convenience when displaying test results.
  std::string test_name;
  // The filename of test image.
  std::string test_image_name;
  // The filename of test model.
  std::string test_model_file;
  // The rotation to apply to the test image before processing, in degrees
  // clockwise.
  int rotation;
  // Expected results from the hand landmarker model output.
  HandLandmarksDetectionResult expected_results;
};
class ImageModeTest : public testing::TestWithParam<TestParams> {};
// An image-mode task must reject the video and live-stream entry points with
// an invalid-argument error carrying the wrong-mode payload.
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kThumbUpImage)));
  auto options = std::make_unique<HandLandmarkerOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset);
  options->running_mode = core::RunningMode::IMAGE;

  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
                          HandLandmarker::Create(std::move(options)));
  // Video-mode method on an image-mode task.
  auto results = hand_landmarker->DetectForVideo(image, 0);
  EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(results.status().message(),
              HasSubstr("not initialized with the video mode"));
  EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
  // Live-stream-mode method on an image-mode task.
  results = hand_landmarker->DetectAsync(image, 0);
  EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(results.status().message(),
              HasSubstr("not initialized with the live stream mode"));
  EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
  MP_ASSERT_OK(hand_landmarker->Close());
}
// Supplying a region-of-interest (anything beyond rotation) must be rejected
// with an invalid-argument error carrying the image-processing payload.
TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kThumbUpImage)));
  auto options = std::make_unique<HandLandmarkerOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset);
  options->running_mode = core::RunningMode::IMAGE;

  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
                          HandLandmarker::Create(std::move(options)));
  Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
  ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};

  auto results = hand_landmarker->Detect(image, image_processing_options);
  EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(results.status().message(),
              HasSubstr("This task doesn't support region-of-interest"));
  EXPECT_THAT(
      results.status().GetPayload(kMediaPipeTasksPayload),
      Optional(absl::Cord(absl::StrCat(
          MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
  // Release the task runner, consistent with the other image-mode tests.
  MP_ASSERT_OK(hand_landmarker->Close());
}
// Runs detection on each parameterized image (optionally rotated) and checks
// the result against the golden landmarks/handedness within tolerances.
TEST_P(ImageModeTest, Succeeds) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                GetParam().test_image_name)));
  auto options = std::make_unique<HandLandmarkerOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, GetParam().test_model_file);
  options->running_mode = core::RunningMode::IMAGE;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
                          HandLandmarker::Create(std::move(options)));
  HandLandmarksDetectionResult hand_landmarker_results;
  // Only pass image processing options when a rotation is requested, so the
  // default-arguments path is also exercised.
  if (GetParam().rotation != 0) {
    ImageProcessingOptions image_processing_options;
    image_processing_options.rotation_degrees = GetParam().rotation;
    MP_ASSERT_OK_AND_ASSIGN(
        hand_landmarker_results,
        hand_landmarker->Detect(image, image_processing_options));
  } else {
    MP_ASSERT_OK_AND_ASSIGN(hand_landmarker_results,
                            hand_landmarker->Detect(image));
  }
  ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results,
                                             GetParam().expected_results);
  MP_ASSERT_OK(hand_landmarker->Close());
}
// Registers the IMAGE-mode test cases. Each TestParams entry supplies the
// input image, the model asset bundle, an optional rotation in degrees, and
// the golden expected result; the trailing lambda derives the reported test
// name from TestParams::test_name.
// NOTE(review): the suite prefix "HandGestureTest" looks like a copy-paste
// from the gesture recognizer tests — confirm the intended prefix.
INSTANTIATE_TEST_SUITE_P(
    HandGestureTest, ImageModeTest,
    Values(TestParams{
               /* test_name= */ "LandmarksThumbUp",
               /* test_image_name= */ kThumbUpImage,
               /* test_model_file= */ kHandLandmarkerBundleAsset,
               /* rotation= */ 0,
               /* expected_results = */
               GetExpectedHandLandmarksDetectionResult(
                   {kThumbUpLandmarksFilename}),
           },
           TestParams{
               /* test_name= */ "LandmarksPointingUp",
               /* test_image_name= */ kPointingUpImage,
               /* test_model_file= */ kHandLandmarkerBundleAsset,
               /* rotation= */ 0,
               /* expected_results = */
               GetExpectedHandLandmarksDetectionResult(
                   {kPointingUpLandmarksFilename}),
           },
           TestParams{
               /* test_name= */ "LandmarksPointingUpRotated",
               /* test_image_name= */ kPointingUpRotatedImage,
               /* test_model_file= */ kHandLandmarkerBundleAsset,
               /* rotation= */ -90,
               /* expected_results = */
               GetExpectedHandLandmarksDetectionResult(
                   {kPointingUpRotatedLandmarksFilename}),
           },
           TestParams{
               /* test_name= */ "NoHands",
               /* test_image_name= */ kNoHandsImage,
               /* test_model_file= */ kHandLandmarkerBundleAsset,
               /* rotation= */ 0,
               /* expected_results = */
               // No hands expected: empty handedness, landmarks, and world
               // landmarks.
               {{}, {}, {}},
           }),
    [](const TestParamInfo<ImageModeTest::ParamType>& info) {
      return info.param.test_name;
    });
class VideoModeTest : public testing::TestWithParam<TestParams> {};
// Verifies that the IMAGE-mode and LIVE_STREAM-mode entry points are
// rejected with kInvalidArgument when the task is configured for VIDEO mode.
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kThumbUpImage)));
  auto landmarker_options = std::make_unique<HandLandmarkerOptions>();
  landmarker_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset);
  landmarker_options->running_mode = core::RunningMode::VIDEO;
  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HandLandmarker> hand_landmarker,
      HandLandmarker::Create(std::move(landmarker_options)));
  // Image-mode call must fail.
  auto detection_result = hand_landmarker->Detect(image);
  EXPECT_EQ(detection_result.status().code(),
            absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(detection_result.status().message(),
              HasSubstr("not initialized with the image mode"));
  EXPECT_THAT(detection_result.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
  // Live-stream-mode call must fail as well.
  detection_result = hand_landmarker->DetectAsync(image, 0);
  EXPECT_EQ(detection_result.status().code(),
            absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(detection_result.status().message(),
              HasSubstr("not initialized with the live stream mode"));
  EXPECT_THAT(detection_result.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
  MP_ASSERT_OK(hand_landmarker->Close());
}
// Feeds the same frame repeatedly in VIDEO mode with increasing timestamps
// and checks every per-frame result against the golden expectation.
TEST_P(VideoModeTest, Succeeds) {
  constexpr int kIterations = 100;
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                GetParam().test_image_name)));
  auto landmarker_options = std::make_unique<HandLandmarkerOptions>();
  landmarker_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, GetParam().test_model_file);
  landmarker_options->running_mode = core::RunningMode::VIDEO;
  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HandLandmarker> hand_landmarker,
      HandLandmarker::Create(std::move(landmarker_options)));
  const auto expected_results = GetParam().expected_results;
  for (int frame = 0; frame < kIterations; ++frame) {
    HandLandmarksDetectionResult actual_results;
    if (GetParam().rotation == 0) {
      MP_ASSERT_OK_AND_ASSIGN(actual_results,
                              hand_landmarker->DetectForVideo(image, frame));
    } else {
      // Forward the test-specified rotation to the task.
      ImageProcessingOptions processing_options;
      processing_options.rotation_degrees = GetParam().rotation;
      MP_ASSERT_OK_AND_ASSIGN(
          actual_results,
          hand_landmarker->DetectForVideo(image, frame, processing_options));
    }
    ExpectHandLandmarksDetectionResultsCorrect(actual_results,
                                               expected_results);
  }
  MP_ASSERT_OK(hand_landmarker->Close());
}
// Registers the VIDEO-mode test cases. Each TestParams entry supplies the
// input image, the model asset bundle, an optional rotation in degrees, and
// the golden expected result; the trailing lambda derives the reported test
// name from TestParams::test_name.
// Fix: the name-generator lambda previously referenced
// ImageModeTest::ParamType (a copy-paste slip). The underlying type is the
// same (TestParams), so behavior is unchanged, but it now names the correct
// fixture for consistency.
INSTANTIATE_TEST_SUITE_P(
    HandGestureTest, VideoModeTest,
    Values(TestParams{
               /* test_name= */ "LandmarksThumbUp",
               /* test_image_name= */ kThumbUpImage,
               /* test_model_file= */ kHandLandmarkerBundleAsset,
               /* rotation= */ 0,
               /* expected_results = */
               GetExpectedHandLandmarksDetectionResult(
                   {kThumbUpLandmarksFilename}),
           },
           TestParams{
               /* test_name= */ "LandmarksPointingUp",
               /* test_image_name= */ kPointingUpImage,
               /* test_model_file= */ kHandLandmarkerBundleAsset,
               /* rotation= */ 0,
               /* expected_results = */
               GetExpectedHandLandmarksDetectionResult(
                   {kPointingUpLandmarksFilename}),
           },
           TestParams{
               /* test_name= */ "LandmarksPointingUpRotated",
               /* test_image_name= */ kPointingUpRotatedImage,
               /* test_model_file= */ kHandLandmarkerBundleAsset,
               /* rotation= */ -90,
               /* expected_results = */
               GetExpectedHandLandmarksDetectionResult(
                   {kPointingUpRotatedLandmarksFilename}),
           },
           TestParams{
               /* test_name= */ "NoHands",
               /* test_image_name= */ kNoHandsImage,
               /* test_model_file= */ kHandLandmarkerBundleAsset,
               /* rotation= */ 0,
               /* expected_results = */
               {{}, {}, {}},
           }),
    [](const TestParamInfo<VideoModeTest::ParamType>& info) {
      return info.param.test_name;
    });
class LiveStreamModeTest : public testing::TestWithParam<TestParams> {};
// Verifies that the IMAGE-mode and VIDEO-mode entry points are rejected with
// kInvalidArgument when the task is configured for LIVE_STREAM mode.
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kThumbUpImage)));
  auto landmarker_options = std::make_unique<HandLandmarkerOptions>();
  landmarker_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset);
  landmarker_options->running_mode = core::RunningMode::LIVE_STREAM;
  // Live-stream mode requires a result callback; a no-op suffices here.
  landmarker_options->result_callback =
      [](absl::StatusOr<HandLandmarksDetectionResult> result,
         const Image& input_image, int64 timestamp_ms) {};
  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HandLandmarker> hand_landmarker,
      HandLandmarker::Create(std::move(landmarker_options)));
  // Image-mode call must fail.
  auto detection_result = hand_landmarker->Detect(image);
  EXPECT_EQ(detection_result.status().code(),
            absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(detection_result.status().message(),
              HasSubstr("not initialized with the image mode"));
  EXPECT_THAT(detection_result.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
  // Video-mode call must fail as well.
  detection_result = hand_landmarker->DetectForVideo(image, 0);
  EXPECT_EQ(detection_result.status().code(),
            absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(detection_result.status().message(),
              HasSubstr("not initialized with the video mode"));
  EXPECT_THAT(detection_result.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError))));
  MP_ASSERT_OK(hand_landmarker->Close());
}
// Streams the same frame repeatedly in LIVE_STREAM mode and verifies the
// asynchronously delivered results, the reported image sizes, and that the
// output timestamps are strictly increasing.
// Fixes: removed an unused local `HandLandmarksDetectionResult
// hand_landmarker_results;` inside the send loop that shadowed the outer
// result vector, and made the result-verification loop index unsigned to
// avoid a signed/unsigned comparison with size().
TEST_P(LiveStreamModeTest, Succeeds) {
  const int iterations = 100;
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                GetParam().test_image_name)));
  auto options = std::make_unique<HandLandmarkerOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, GetParam().test_model_file);
  options->running_mode = core::RunningMode::LIVE_STREAM;
  std::vector<HandLandmarksDetectionResult> hand_landmarker_results;
  std::vector<std::pair<int, int>> image_sizes;
  std::vector<int64> timestamps;
  // Collect every successful asynchronous result together with the input
  // image dimensions and the output timestamp for later verification.
  options->result_callback =
      [&hand_landmarker_results, &image_sizes, &timestamps](
          absl::StatusOr<HandLandmarksDetectionResult> results,
          const Image& image, int64 timestamp_ms) {
        MP_ASSERT_OK(results.status());
        hand_landmarker_results.push_back(std::move(results.value()));
        image_sizes.push_back({image.width(), image.height()});
        timestamps.push_back(timestamp_ms);
      };
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
                          HandLandmarker::Create(std::move(options)));
  for (int i = 0; i < iterations; ++i) {
    if (GetParam().rotation != 0) {
      ImageProcessingOptions image_processing_options;
      image_processing_options.rotation_degrees = GetParam().rotation;
      MP_ASSERT_OK(
          hand_landmarker->DetectAsync(image, i, image_processing_options));
    } else {
      MP_ASSERT_OK(hand_landmarker->DetectAsync(image, i));
    }
  }
  MP_ASSERT_OK(hand_landmarker->Close());
  // Due to the flow limiter, the total of outputs will be smaller than the
  // number of iterations.
  ASSERT_LE(hand_landmarker_results.size(), iterations);
  ASSERT_GT(hand_landmarker_results.size(), 0);
  const auto expected_results = GetParam().expected_results;
  for (size_t i = 0; i < hand_landmarker_results.size(); ++i) {
    ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results[i],
                                               expected_results);
  }
  for (const auto& image_size : image_sizes) {
    EXPECT_EQ(image_size.first, image.width());
    EXPECT_EQ(image_size.second, image.height());
  }
  // Timestamps delivered by the callback must be strictly increasing.
  int64 timestamp_ms = -1;
  for (const auto& timestamp : timestamps) {
    EXPECT_GT(timestamp, timestamp_ms);
    timestamp_ms = timestamp;
  }
}
// Registers the LIVE_STREAM-mode test cases. Each TestParams entry supplies
// the input image, the model asset bundle, an optional rotation in degrees,
// and the golden expected result; the trailing lambda derives the reported
// test name from TestParams::test_name.
// Fix: the name-generator lambda previously referenced
// ImageModeTest::ParamType (a copy-paste slip). The underlying type is the
// same (TestParams), so behavior is unchanged, but it now names the correct
// fixture for consistency.
INSTANTIATE_TEST_SUITE_P(
    HandGestureTest, LiveStreamModeTest,
    Values(TestParams{
               /* test_name= */ "LandmarksThumbUp",
               /* test_image_name= */ kThumbUpImage,
               /* test_model_file= */ kHandLandmarkerBundleAsset,
               /* rotation= */ 0,
               /* expected_results = */
               GetExpectedHandLandmarksDetectionResult(
                   {kThumbUpLandmarksFilename}),
           },
           TestParams{
               /* test_name= */ "LandmarksPointingUp",
               /* test_image_name= */ kPointingUpImage,
               /* test_model_file= */ kHandLandmarkerBundleAsset,
               /* rotation= */ 0,
               /* expected_results = */
               GetExpectedHandLandmarksDetectionResult(
                   {kPointingUpLandmarksFilename}),
           },
           TestParams{
               /* test_name= */ "LandmarksPointingUpRotated",
               /* test_image_name= */ kPointingUpRotatedImage,
               /* test_model_file= */ kHandLandmarkerBundleAsset,
               /* rotation= */ -90,
               /* expected_results = */
               GetExpectedHandLandmarksDetectionResult(
                   {kPointingUpRotatedLandmarksFilename}),
           },
           TestParams{
               /* test_name= */ "NoHands",
               /* test_image_name= */ kNoHandsImage,
               /* test_model_file= */ kHandLandmarkerBundleAsset,
               /* rotation= */ 0,
               /* expected_results = */
               {{}, {}, {}},
           }),
    [](const TestParamInfo<LiveStreamModeTest::ParamType>& info) {
      return info.param.test_name;
    });
} // namespace hand_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe