Update C++ HandLandmarksDetectionResult to HandLandmarkerResult.

PiperOrigin-RevId: 487443827
This commit is contained in:
MediaPipe Team 2022-11-09 23:18:02 -08:00 committed by Copybara-Service
parent d2142e86a9
commit f11c757629
6 changed files with 70 additions and 82 deletions

View File

@ -30,15 +30,6 @@ cc_library(
], ],
) )
cc_library(
name = "hand_landmarks_detection_result",
hdrs = ["hand_landmarks_detection_result.h"],
deps = [
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
],
)
cc_library( cc_library(
name = "category", name = "category",
srcs = ["category.cc"], srcs = ["category.cc"],

View File

@ -110,12 +110,22 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "hand_landmarker_result",
hdrs = ["hand_landmarker_result.h"],
deps = [
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
],
)
cc_library( cc_library(
name = "hand_landmarker", name = "hand_landmarker",
srcs = ["hand_landmarker.cc"], srcs = ["hand_landmarker.cc"],
hdrs = ["hand_landmarker.h"], hdrs = ["hand_landmarker.h"],
deps = [ deps = [
":hand_landmarker_graph", ":hand_landmarker_graph",
":hand_landmarker_result",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:classification_cc_proto",
@ -124,7 +134,6 @@ cc_library(
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components/containers:hand_landmarks_detection_result",
"//mediapipe/tasks/cc/components/processors:classifier_options", "//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:base_options",

View File

@ -22,7 +22,6 @@ limitations under the License.
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/base_task_api.h" #include "mediapipe/tasks/cc/core/base_task_api.h"
@ -34,6 +33,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
@ -47,8 +47,6 @@ namespace {
using HandLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision:: using HandLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision::
hand_landmarker::proto::HandLandmarkerGraphOptions; hand_landmarker::proto::HandLandmarkerGraphOptions;
using ::mediapipe::tasks::components::containers::HandLandmarksDetectionResult;
constexpr char kHandLandmarkerGraphTypeName[] = constexpr char kHandLandmarkerGraphTypeName[] =
"mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph"; "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph";
@ -145,7 +143,7 @@ absl::StatusOr<std::unique_ptr<HandLandmarker>> HandLandmarker::Create(
Packet empty_packet = Packet empty_packet =
status_or_packets.value()[kHandLandmarksStreamName]; status_or_packets.value()[kHandLandmarksStreamName];
result_callback( result_callback(
{HandLandmarksDetectionResult()}, image_packet.Get<Image>(), {HandLandmarkerResult()}, image_packet.Get<Image>(),
empty_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); empty_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
return; return;
} }
@ -173,7 +171,7 @@ absl::StatusOr<std::unique_ptr<HandLandmarker>> HandLandmarker::Create(
std::move(packets_callback)); std::move(packets_callback));
} }
absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::Detect( absl::StatusOr<HandLandmarkerResult> HandLandmarker::Detect(
mediapipe::Image image, mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
@ -192,7 +190,7 @@ absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::Detect(
{kNormRectStreamName, {kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))}})); MakePacket<NormalizedRect>(std::move(norm_rect))}}));
if (output_packets[kHandLandmarksStreamName].IsEmpty()) { if (output_packets[kHandLandmarksStreamName].IsEmpty()) {
return {HandLandmarksDetectionResult()}; return {HandLandmarkerResult()};
} }
return {{/* handedness= */ return {{/* handedness= */
{output_packets[kHandednessStreamName] {output_packets[kHandednessStreamName]
@ -205,7 +203,7 @@ absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::Detect(
.Get<std::vector<mediapipe::LandmarkList>>()}}}; .Get<std::vector<mediapipe::LandmarkList>>()}}};
} }
absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::DetectForVideo( absl::StatusOr<HandLandmarkerResult> HandLandmarker::DetectForVideo(
mediapipe::Image image, int64 timestamp_ms, mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
@ -227,7 +225,7 @@ absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::DetectForVideo(
MakePacket<NormalizedRect>(std::move(norm_rect)) MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
if (output_packets[kHandLandmarksStreamName].IsEmpty()) { if (output_packets[kHandLandmarksStreamName].IsEmpty()) {
return {HandLandmarksDetectionResult()}; return {HandLandmarkerResult()};
} }
return { return {
{/* handedness= */ {/* handedness= */

View File

@ -24,12 +24,12 @@ limitations under the License.
#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
@ -70,9 +70,7 @@ struct HandLandmarkerOptions {
// The user-defined result callback for processing live stream data. // The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set // The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM. // to RunningMode::LIVE_STREAM.
std::function<void( std::function<void(absl::StatusOr<HandLandmarkerResult>, const Image&, int64)>
absl::StatusOr<components::containers::HandLandmarksDetectionResult>,
const Image&, int64)>
result_callback = nullptr; result_callback = nullptr;
}; };
@ -92,7 +90,7 @@ struct HandLandmarkerOptions {
// 'y_center', 'width' and 'height' fields is NOT supported and will // 'y_center', 'width' and 'height' fields is NOT supported and will
// result in an invalid argument error being returned. // result in an invalid argument error being returned.
// Outputs: // Outputs:
// HandLandmarksDetectionResult // HandLandmarkerResult
// - The hand landmarks detection results. // - The hand landmarks detection results.
class HandLandmarker : tasks::vision::core::BaseVisionTaskApi { class HandLandmarker : tasks::vision::core::BaseVisionTaskApi {
public: public:
@ -129,7 +127,7 @@ class HandLandmarker : tasks::vision::core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA. // The image can be of any size with format RGB or RGBA.
// TODO: Describes how the input image will be preprocessed // TODO: Describes how the input image will be preprocessed
// after the yuv support is implemented. // after the yuv support is implemented.
absl::StatusOr<components::containers::HandLandmarksDetectionResult> Detect( absl::StatusOr<HandLandmarkerResult> Detect(
Image image, Image image,
std::optional<core::ImageProcessingOptions> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);
@ -147,10 +145,10 @@ class HandLandmarker : tasks::vision::core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA. It's required to // The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps // provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing. // must be monotonically increasing.
absl::StatusOr<components::containers::HandLandmarksDetectionResult> absl::StatusOr<HandLandmarkerResult> DetectForVideo(
DetectForVideo(Image image, int64 timestamp_ms, Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> std::optional<core::ImageProcessingOptions> image_processing_options =
image_processing_options = std::nullopt); std::nullopt);
// Sends live image data to perform hand landmarks detection, and the results // Sends live image data to perform hand landmarks detection, and the results
// will be available via the "result_callback" provided in the // will be available via the "result_callback" provided in the
@ -169,7 +167,7 @@ class HandLandmarker : tasks::vision::core::BaseVisionTaskApi {
// invalid argument error being returned. // invalid argument error being returned.
// //
// The "result_callback" provides // The "result_callback" provides
// - A vector of HandLandmarksDetectionResult, each is the detected results // - A vector of HandLandmarkerResult, each is the detected results
// for a input frame. // for a input frame.
// - The const reference to the corresponding input image that the hand // - The const reference to the corresponding input image that the hand
// landmarker runs on. Note that the const reference to the image will no // landmarker runs on. Note that the const reference to the image will no

View File

@ -13,20 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_ #ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_RESULT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_ #define MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_RESULT_H_
#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace vision {
namespace containers { namespace hand_landmarker {
// The hand landmarks detection result from HandLandmarker, where each vector // The hand landmarks detection result from HandLandmarker, where each vector
// element represents a single hand detected in the image. // element represents a single hand detected in the image.
struct HandLandmarksDetectionResult { struct HandLandmarkerResult {
// Classification of handedness. // Classification of handedness.
std::vector<mediapipe::ClassificationList> handedness; std::vector<mediapipe::ClassificationList> handedness;
// Detected hand landmarks in normalized image coordinates. // Detected hand landmarks in normalized image coordinates.
@ -35,9 +35,9 @@ struct HandLandmarksDetectionResult {
std::vector<mediapipe::LandmarkList> hand_world_landmarks; std::vector<mediapipe::LandmarkList> hand_world_landmarks;
}; };
} // namespace containers } // namespace hand_landmarker
} // namespace components } // namespace vision
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_ #endif // MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_RESULT_H_

View File

@ -32,12 +32,12 @@ limitations under the License.
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h"
#include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
@ -50,7 +50,6 @@ namespace {
using ::file::Defaults; using ::file::Defaults;
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::HandLandmarksDetectionResult;
using ::mediapipe::tasks::components::containers::Rect; using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult; using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
@ -95,9 +94,9 @@ LandmarksDetectionResult GetLandmarksDetectionResult(
return result; return result;
} }
HandLandmarksDetectionResult GetExpectedHandLandmarksDetectionResult( HandLandmarkerResult GetExpectedHandLandmarkerResult(
const std::vector<absl::string_view>& landmarks_file_names) { const std::vector<absl::string_view>& landmarks_file_names) {
HandLandmarksDetectionResult expected_results; HandLandmarkerResult expected_results;
for (const auto& file_name : landmarks_file_names) { for (const auto& file_name : landmarks_file_names) {
const auto landmarks_detection_result = const auto landmarks_detection_result =
GetLandmarksDetectionResult(file_name); GetLandmarksDetectionResult(file_name);
@ -109,9 +108,9 @@ HandLandmarksDetectionResult GetExpectedHandLandmarksDetectionResult(
return expected_results; return expected_results;
} }
void ExpectHandLandmarksDetectionResultsCorrect( void ExpectHandLandmarkerResultsCorrect(
const HandLandmarksDetectionResult& actual_results, const HandLandmarkerResult& actual_results,
const HandLandmarksDetectionResult& expected_results) { const HandLandmarkerResult& expected_results) {
const auto& actual_landmarks = actual_results.hand_landmarks; const auto& actual_landmarks = actual_results.hand_landmarks;
const auto& actual_handedness = actual_results.handedness; const auto& actual_handedness = actual_results.handedness;
@ -145,7 +144,7 @@ struct TestParams {
// clockwise. // clockwise.
int rotation; int rotation;
// Expected results from the hand landmarker model output. // Expected results from the hand landmarker model output.
HandLandmarksDetectionResult expected_results; HandLandmarkerResult expected_results;
}; };
class ImageModeTest : public testing::TestWithParam<TestParams> {}; class ImageModeTest : public testing::TestWithParam<TestParams> {};
@ -213,7 +212,7 @@ TEST_P(ImageModeTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
HandLandmarker::Create(std::move(options))); HandLandmarker::Create(std::move(options)));
HandLandmarksDetectionResult hand_landmarker_results; HandLandmarkerResult hand_landmarker_results;
if (GetParam().rotation != 0) { if (GetParam().rotation != 0) {
ImageProcessingOptions image_processing_options; ImageProcessingOptions image_processing_options;
image_processing_options.rotation_degrees = GetParam().rotation; image_processing_options.rotation_degrees = GetParam().rotation;
@ -224,7 +223,7 @@ TEST_P(ImageModeTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(hand_landmarker_results, MP_ASSERT_OK_AND_ASSIGN(hand_landmarker_results,
hand_landmarker->Detect(image)); hand_landmarker->Detect(image));
} }
ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results, ExpectHandLandmarkerResultsCorrect(hand_landmarker_results,
GetParam().expected_results); GetParam().expected_results);
MP_ASSERT_OK(hand_landmarker->Close()); MP_ASSERT_OK(hand_landmarker->Close());
} }
@ -237,8 +236,7 @@ INSTANTIATE_TEST_SUITE_P(
/* test_model_file= */ kHandLandmarkerBundleAsset, /* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ 0, /* rotation= */ 0,
/* expected_results = */ /* expected_results = */
GetExpectedHandLandmarksDetectionResult( GetExpectedHandLandmarkerResult({kThumbUpLandmarksFilename}),
{kThumbUpLandmarksFilename}),
}, },
TestParams{ TestParams{
/* test_name= */ "LandmarksPointingUp", /* test_name= */ "LandmarksPointingUp",
@ -246,8 +244,7 @@ INSTANTIATE_TEST_SUITE_P(
/* test_model_file= */ kHandLandmarkerBundleAsset, /* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ 0, /* rotation= */ 0,
/* expected_results = */ /* expected_results = */
GetExpectedHandLandmarksDetectionResult( GetExpectedHandLandmarkerResult({kPointingUpLandmarksFilename}),
{kPointingUpLandmarksFilename}),
}, },
TestParams{ TestParams{
/* test_name= */ "LandmarksPointingUpRotated", /* test_name= */ "LandmarksPointingUpRotated",
@ -255,7 +252,7 @@ INSTANTIATE_TEST_SUITE_P(
/* test_model_file= */ kHandLandmarkerBundleAsset, /* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ -90, /* rotation= */ -90,
/* expected_results = */ /* expected_results = */
GetExpectedHandLandmarksDetectionResult( GetExpectedHandLandmarkerResult(
{kPointingUpRotatedLandmarksFilename}), {kPointingUpRotatedLandmarksFilename}),
}, },
TestParams{ TestParams{
@ -315,7 +312,7 @@ TEST_P(VideoModeTest, Succeeds) {
HandLandmarker::Create(std::move(options))); HandLandmarker::Create(std::move(options)));
const auto expected_results = GetParam().expected_results; const auto expected_results = GetParam().expected_results;
for (int i = 0; i < iterations; ++i) { for (int i = 0; i < iterations; ++i) {
HandLandmarksDetectionResult hand_landmarker_results; HandLandmarkerResult hand_landmarker_results;
if (GetParam().rotation != 0) { if (GetParam().rotation != 0) {
ImageProcessingOptions image_processing_options; ImageProcessingOptions image_processing_options;
image_processing_options.rotation_degrees = GetParam().rotation; image_processing_options.rotation_degrees = GetParam().rotation;
@ -326,7 +323,7 @@ TEST_P(VideoModeTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(hand_landmarker_results, MP_ASSERT_OK_AND_ASSIGN(hand_landmarker_results,
hand_landmarker->DetectForVideo(image, i)); hand_landmarker->DetectForVideo(image, i));
} }
ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results, ExpectHandLandmarkerResultsCorrect(hand_landmarker_results,
expected_results); expected_results);
} }
MP_ASSERT_OK(hand_landmarker->Close()); MP_ASSERT_OK(hand_landmarker->Close());
@ -340,8 +337,7 @@ INSTANTIATE_TEST_SUITE_P(
/* test_model_file= */ kHandLandmarkerBundleAsset, /* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ 0, /* rotation= */ 0,
/* expected_results = */ /* expected_results = */
GetExpectedHandLandmarksDetectionResult( GetExpectedHandLandmarkerResult({kThumbUpLandmarksFilename}),
{kThumbUpLandmarksFilename}),
}, },
TestParams{ TestParams{
/* test_name= */ "LandmarksPointingUp", /* test_name= */ "LandmarksPointingUp",
@ -349,8 +345,7 @@ INSTANTIATE_TEST_SUITE_P(
/* test_model_file= */ kHandLandmarkerBundleAsset, /* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ 0, /* rotation= */ 0,
/* expected_results = */ /* expected_results = */
GetExpectedHandLandmarksDetectionResult( GetExpectedHandLandmarkerResult({kPointingUpLandmarksFilename}),
{kPointingUpLandmarksFilename}),
}, },
TestParams{ TestParams{
/* test_name= */ "LandmarksPointingUpRotated", /* test_name= */ "LandmarksPointingUpRotated",
@ -358,7 +353,7 @@ INSTANTIATE_TEST_SUITE_P(
/* test_model_file= */ kHandLandmarkerBundleAsset, /* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ -90, /* rotation= */ -90,
/* expected_results = */ /* expected_results = */
GetExpectedHandLandmarksDetectionResult( GetExpectedHandLandmarkerResult(
{kPointingUpRotatedLandmarksFilename}), {kPointingUpRotatedLandmarksFilename}),
}, },
TestParams{ TestParams{
@ -383,8 +378,7 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset); JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset);
options->running_mode = core::RunningMode::LIVE_STREAM; options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback = options->result_callback = [](absl::StatusOr<HandLandmarkerResult> results,
[](absl::StatusOr<HandLandmarksDetectionResult> results,
const Image& image, int64 timestamp_ms) {}; const Image& image, int64 timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
@ -416,12 +410,12 @@ TEST_P(LiveStreamModeTest, Succeeds) {
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, GetParam().test_model_file); JoinPath("./", kTestDataDirectory, GetParam().test_model_file);
options->running_mode = core::RunningMode::LIVE_STREAM; options->running_mode = core::RunningMode::LIVE_STREAM;
std::vector<HandLandmarksDetectionResult> hand_landmarker_results; std::vector<HandLandmarkerResult> hand_landmarker_results;
std::vector<std::pair<int, int>> image_sizes; std::vector<std::pair<int, int>> image_sizes;
std::vector<int64> timestamps; std::vector<int64> timestamps;
options->result_callback = options->result_callback = [&hand_landmarker_results, &image_sizes,
[&hand_landmarker_results, &image_sizes, &timestamps]( &timestamps](
absl::StatusOr<HandLandmarksDetectionResult> results, absl::StatusOr<HandLandmarkerResult> results,
const Image& image, int64 timestamp_ms) { const Image& image, int64 timestamp_ms) {
MP_ASSERT_OK(results.status()); MP_ASSERT_OK(results.status());
hand_landmarker_results.push_back(std::move(results.value())); hand_landmarker_results.push_back(std::move(results.value()));
@ -432,7 +426,7 @@ TEST_P(LiveStreamModeTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
HandLandmarker::Create(std::move(options))); HandLandmarker::Create(std::move(options)));
for (int i = 0; i < iterations; ++i) { for (int i = 0; i < iterations; ++i) {
HandLandmarksDetectionResult hand_landmarker_results; HandLandmarkerResult hand_landmarker_results;
if (GetParam().rotation != 0) { if (GetParam().rotation != 0) {
ImageProcessingOptions image_processing_options; ImageProcessingOptions image_processing_options;
image_processing_options.rotation_degrees = GetParam().rotation; image_processing_options.rotation_degrees = GetParam().rotation;
@ -450,7 +444,7 @@ TEST_P(LiveStreamModeTest, Succeeds) {
const auto expected_results = GetParam().expected_results; const auto expected_results = GetParam().expected_results;
for (int i = 0; i < hand_landmarker_results.size(); ++i) { for (int i = 0; i < hand_landmarker_results.size(); ++i) {
ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results[i], ExpectHandLandmarkerResultsCorrect(hand_landmarker_results[i],
expected_results); expected_results);
} }
for (const auto& image_size : image_sizes) { for (const auto& image_size : image_sizes) {
@ -472,8 +466,7 @@ INSTANTIATE_TEST_SUITE_P(
/* test_model_file= */ kHandLandmarkerBundleAsset, /* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ 0, /* rotation= */ 0,
/* expected_results = */ /* expected_results = */
GetExpectedHandLandmarksDetectionResult( GetExpectedHandLandmarkerResult({kThumbUpLandmarksFilename}),
{kThumbUpLandmarksFilename}),
}, },
TestParams{ TestParams{
/* test_name= */ "LandmarksPointingUp", /* test_name= */ "LandmarksPointingUp",
@ -481,8 +474,7 @@ INSTANTIATE_TEST_SUITE_P(
/* test_model_file= */ kHandLandmarkerBundleAsset, /* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ 0, /* rotation= */ 0,
/* expected_results = */ /* expected_results = */
GetExpectedHandLandmarksDetectionResult( GetExpectedHandLandmarkerResult({kPointingUpLandmarksFilename}),
{kPointingUpLandmarksFilename}),
}, },
TestParams{ TestParams{
/* test_name= */ "LandmarksPointingUpRotated", /* test_name= */ "LandmarksPointingUpRotated",
@ -490,7 +482,7 @@ INSTANTIATE_TEST_SUITE_P(
/* test_model_file= */ kHandLandmarkerBundleAsset, /* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ -90, /* rotation= */ -90,
/* expected_results = */ /* expected_results = */
GetExpectedHandLandmarksDetectionResult( GetExpectedHandLandmarkerResult(
{kPointingUpRotatedLandmarksFilename}), {kPointingUpRotatedLandmarksFilename}),
}, },
TestParams{ TestParams{