Use c++ struct as hand landmark detection results.

PiperOrigin-RevId: 504048095
This commit is contained in:
MediaPipe Team 2023-01-23 12:09:41 -08:00 committed by Copybara-Service
parent 1124569c29
commit 69d354fc89
11 changed files with 439 additions and 42 deletions

View File

@ -62,3 +62,12 @@ cc_library(
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
],
)
cc_library(
name = "landmark",
srcs = ["landmark.cc"],
hdrs = ["landmark.h"],
deps = [
"//mediapipe/framework/formats:landmark_cc_proto",
],
)

View File

@ -40,6 +40,19 @@ Classifications ConvertToClassifications(const proto::Classifications& proto) {
return classifications;
}
Classifications ConvertToClassifications(
const mediapipe::ClassificationList& proto, int head_index,
std::optional<std::string> head_name) {
Classifications classifications;
classifications.categories.reserve(proto.classification_size());
for (const auto& classification : proto.classification()) {
classifications.categories.push_back(ConvertToCategory(classification));
}
classifications.head_index = head_index;
classifications.head_name = head_name;
return classifications;
}
ClassificationResult ConvertToClassificationResult(
const proto::ClassificationResult& proto) {
ClassificationResult classification_result;

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <vector>
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
@ -58,6 +59,12 @@ struct ClassificationResult {
// Classifications struct.
Classifications ConvertToClassifications(const proto::Classifications& proto);
// Utility function to convert from ClassificationList proto to
// Classifications struct.
Classifications ConvertToClassifications(
const mediapipe::ClassificationList& proto, int head_index = 0,
std::optional<std::string> head_name = std::nullopt);
// Utility function to convert from ClassificationResult proto to
// ClassificationResult struct.
ClassificationResult ConvertToClassificationResult(

View File

@ -0,0 +1,65 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/containers/landmark.h"
#include <optional>
#include <utility>
#include "mediapipe/framework/formats/landmark.pb.h"
namespace mediapipe::tasks::components::containers {
Landmark ConvertToLandmark(const mediapipe::Landmark& proto) {
return {/*x=*/proto.x(), /*y=*/proto.y(), /*z=*/proto.z(),
/*visibility=*/proto.has_visibility()
? std::optional<float>(proto.visibility())
: std::nullopt,
/*presence=*/proto.has_presence()
? std::optional<float>(proto.presence())
: std::nullopt};
}
NormalizedLandmark ConvertToNormalizedLandmark(
const mediapipe::NormalizedLandmark& proto) {
return {/*x=*/proto.x(), /*y=*/proto.y(), /*z=*/proto.z(),
/*visibility=*/proto.has_visibility()
? std::optional<float>(proto.visibility())
: std::nullopt,
/*presence=*/proto.has_presence()
? std::optional<float>(proto.presence())
: std::nullopt};
}
Landmarks ConvertToLandmarks(const mediapipe::LandmarkList& proto) {
Landmarks landmarks;
landmarks.landmarks.reserve(proto.landmark_size());
for (const auto& landmark : proto.landmark()) {
landmarks.landmarks.push_back(ConvertToLandmark(landmark));
}
return landmarks;
}
NormalizedLandmarks ConvertToNormalizedLandmarks(
const mediapipe::NormalizedLandmarkList& proto) {
NormalizedLandmarks landmarks;
landmarks.landmarks.reserve(proto.landmark_size());
for (const auto& landmark : proto.landmark()) {
landmarks.landmarks.push_back(ConvertToNormalizedLandmark(landmark));
}
return landmarks;
}
} // namespace mediapipe::tasks::components::containers

View File

@ -0,0 +1,103 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_
#include <cstdlib>
#include <string>
#include <vector>
#include "mediapipe/framework/formats/landmark.pb.h"
namespace mediapipe::tasks::components::containers {
constexpr float kLandmarkTolerance = 1e-6;
// Landmark represents a point in 3D space with x, y, z coordinates. The
// landmark coordinates are in meters. z represents the landmark depth, and the
// smaller the value the closer the world landmark is to the camera.
struct Landmark {
float x;
float y;
float z;
// Landmark visibility. Should stay unset if not supported.
// Float score of whether landmark is visible or occluded by other objects.
// Landmark considered as invisible also if it is not present on the screen
// (out of scene bounds). Depending on the model, visibility value is either a
// sigmoid or an argument of sigmoid.
std::optional<float> visibility = std::nullopt;
// Landmark presence. Should stay unset if not supported.
// Float score of whether landmark is present on the scene (located within
// scene bounds). Depending on the model, presence value is either a result of
// sigmoid or an argument of sigmoid function to get landmark presence
// probability.
std::optional<float> presence = std::nullopt;
// Landmark name. Should stay unset if not supported.
std::optional<std::string> name = std::nullopt;
};
inline bool operator==(const Landmark& lhs, const Landmark& rhs) {
return abs(lhs.x - rhs.x) < kLandmarkTolerance &&
abs(lhs.y - rhs.y) < kLandmarkTolerance &&
abs(lhs.z - rhs.z) < kLandmarkTolerance;
}
// A normalized version of above Landmark struct. All coordinates should be
// within [0, 1].
struct NormalizedLandmark {
float x;
float y;
float z;
std::optional<float> visibility = std::nullopt;
std::optional<float> presence = std::nullopt;
std::optional<std::string> name = std::nullopt;
};
inline bool operator==(const NormalizedLandmark& lhs,
const NormalizedLandmark& rhs) {
return abs(lhs.x - rhs.x) < kLandmarkTolerance &&
abs(lhs.y - rhs.y) < kLandmarkTolerance &&
abs(lhs.z - rhs.z) < kLandmarkTolerance;
}
// A list of Landmarks.
struct Landmarks {
std::vector<Landmark> landmarks;
};
// A list of NormalizedLandmarks.
struct NormalizedLandmarks {
std::vector<NormalizedLandmark> landmarks;
};
// Utility function to convert from Landmark proto to Landmark struct.
Landmark ConvertToLandmark(const mediapipe::Landmark& proto);
// Utility function to convert from NormalizedLandmark proto to
// NormalizedLandmark struct.
NormalizedLandmark ConvertToNormalizedLandmark(
const mediapipe::NormalizedLandmark& proto);
// Utility function to convert from LandmarkList proto to Landmarks struct.
Landmarks ConvertToLandmarks(const mediapipe::LandmarkList& proto);
// Utility function to convert from NormalizedLandmarkList proto to
// NormalizedLandmarks struct.
NormalizedLandmarks ConvertToNormalizedLandmarks(
const mediapipe::NormalizedLandmarkList& proto);
} // namespace mediapipe::tasks::components::containers
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_

View File

@ -154,11 +154,14 @@ cc_library(
cc_library(
name = "hand_landmarker_result",
srcs = ["hand_landmarker_result.cc"],
hdrs = ["hand_landmarker_result.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/tasks/cc/components/containers:classification_result",
"//mediapipe/tasks/cc/components/containers:landmark",
],
)

View File

@ -155,9 +155,13 @@ absl::StatusOr<std::unique_ptr<HandLandmarker>> HandLandmarker::Create(
Packet hand_world_landmarks_packet =
status_or_packets.value()[kHandWorldLandmarksStreamName];
result_callback(
{{handedness_packet.Get<std::vector<ClassificationList>>(),
ConvertToHandLandmarkerResult(
/* handedness= */ handedness_packet
.Get<std::vector<ClassificationList>>(),
/* hand_landmarks= */
hand_landmarks_packet.Get<std::vector<NormalizedLandmarkList>>(),
hand_world_landmarks_packet.Get<std::vector<LandmarkList>>()}},
/* hand_world_landmarks= */
hand_world_landmarks_packet.Get<std::vector<LandmarkList>>()),
image_packet.Get<Image>(),
hand_landmarks_packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond);
@ -193,15 +197,21 @@ absl::StatusOr<HandLandmarkerResult> HandLandmarker::Detect(
if (output_packets[kHandLandmarksStreamName].IsEmpty()) {
return {HandLandmarkerResult()};
}
return {{/* handedness= */
{output_packets[kHandednessStreamName]
.Get<std::vector<mediapipe::ClassificationList>>()},
return ConvertToHandLandmarkerResult(/* handedness= */
output_packets[kHandednessStreamName]
.Get<std::vector<
mediapipe::
ClassificationList>>(),
/* hand_landmarks= */
{output_packets[kHandLandmarksStreamName]
.Get<std::vector<mediapipe::NormalizedLandmarkList>>()},
output_packets[kHandLandmarksStreamName]
.Get<std::vector<
mediapipe::
NormalizedLandmarkList>>(),
/* hand_world_landmarks */
{output_packets[kHandWorldLandmarksStreamName]
.Get<std::vector<mediapipe::LandmarkList>>()}}};
output_packets
[kHandWorldLandmarksStreamName]
.Get<std::vector<
mediapipe::LandmarkList>>());
}
absl::StatusOr<HandLandmarkerResult> HandLandmarker::DetectForVideo(
@ -228,17 +238,21 @@ absl::StatusOr<HandLandmarkerResult> HandLandmarker::DetectForVideo(
if (output_packets[kHandLandmarksStreamName].IsEmpty()) {
return {HandLandmarkerResult()};
}
return {
{/* handedness= */
{output_packets[kHandednessStreamName]
.Get<std::vector<mediapipe::ClassificationList>>()},
return ConvertToHandLandmarkerResult(/* handedness= */
output_packets[kHandednessStreamName]
.Get<std::vector<
mediapipe::
ClassificationList>>(),
/* hand_landmarks= */
{output_packets[kHandLandmarksStreamName]
.Get<std::vector<mediapipe::NormalizedLandmarkList>>()},
output_packets[kHandLandmarksStreamName]
.Get<std::vector<
mediapipe::
NormalizedLandmarkList>>(),
/* hand_world_landmarks */
{output_packets[kHandWorldLandmarksStreamName]
.Get<std::vector<mediapipe::LandmarkList>>()}},
};
output_packets
[kHandWorldLandmarksStreamName]
.Get<std::vector<
mediapipe::LandmarkList>>());
}
absl::Status HandLandmarker::DetectAsync(

View File

@ -0,0 +1,56 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"
#include <algorithm>
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/landmark.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace hand_landmarker {
HandLandmarkerResult ConvertToHandLandmarkerResult(
const std::vector<mediapipe::ClassificationList>& handedness_proto,
const std::vector<mediapipe::NormalizedLandmarkList>& hand_landmarks_proto,
const std::vector<mediapipe::LandmarkList>& hand_world_landmarks_proto) {
HandLandmarkerResult result;
result.handedness.resize(handedness_proto.size());
result.hand_landmarks.resize(hand_landmarks_proto.size());
result.hand_world_landmarks.resize(hand_world_landmarks_proto.size());
std::transform(handedness_proto.begin(), handedness_proto.end(),
result.handedness.begin(),
[](const mediapipe::ClassificationList& classification_list) {
return components::containers::ConvertToClassifications(
classification_list);
});
std::transform(hand_landmarks_proto.begin(), hand_landmarks_proto.end(),
result.hand_landmarks.begin(),
components::containers::ConvertToNormalizedLandmarks);
std::transform(hand_world_landmarks_proto.begin(),
hand_world_landmarks_proto.end(),
result.hand_world_landmarks.begin(),
components::containers::ConvertToLandmarks);
return result;
}
} // namespace hand_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -1,4 +1,4 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -18,6 +18,8 @@ limitations under the License.
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/landmark.h"
namespace mediapipe {
namespace tasks {
@ -28,13 +30,18 @@ namespace hand_landmarker {
// element represents a single hand detected in the image.
struct HandLandmarkerResult {
// Classification of handedness.
std::vector<mediapipe::ClassificationList> handedness;
std::vector<components::containers::Classifications> handedness;
// Detected hand landmarks in normalized image coordinates.
std::vector<mediapipe::NormalizedLandmarkList> hand_landmarks;
std::vector<components::containers::NormalizedLandmarks> hand_landmarks;
// Detected hand landmarks in world coordinates.
std::vector<mediapipe::LandmarkList> hand_world_landmarks;
std::vector<components::containers::Landmarks> hand_world_landmarks;
};
HandLandmarkerResult ConvertToHandLandmarkerResult(
const std::vector<mediapipe::ClassificationList>& handedness_proto,
const std::vector<mediapipe::NormalizedLandmarkList>& hand_landmarks_proto,
const std::vector<mediapipe::LandmarkList>& hand_world_landmarks_proto);
} // namespace hand_landmarker
} // namespace vision
} // namespace tasks

View File

@ -0,0 +1,88 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"
#include <optional>
#include <vector>
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/landmark.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace hand_landmarker {
TEST(ConvertFromProto, Succeeds) {
mediapipe::ClassificationList classification_list_proto;
mediapipe::Classification& classification_proto =
*classification_list_proto.add_classification();
classification_proto.set_index(1);
classification_proto.set_score(0.5);
classification_proto.set_label("Left");
classification_proto.set_display_name("Left_Hand");
mediapipe::NormalizedLandmarkList normalized_landmark_list_proto;
mediapipe::NormalizedLandmark& normalized_landmark_proto =
*normalized_landmark_list_proto.add_landmark();
normalized_landmark_proto.set_x(0.1);
normalized_landmark_proto.set_y(0.2);
normalized_landmark_proto.set_z(0.3);
mediapipe::LandmarkList landmark_list_proto;
mediapipe::Landmark& landmark_proto = *landmark_list_proto.add_landmark();
landmark_proto.set_x(3.1);
landmark_proto.set_y(5.2);
landmark_proto.set_z(4.3);
std::vector<mediapipe::ClassificationList> classification_lists = {
classification_list_proto};
std::vector<mediapipe::NormalizedLandmarkList> normalized_landmarks_lists = {
normalized_landmark_list_proto};
std::vector<mediapipe::LandmarkList> landmarks_lists = {landmark_list_proto};
HandLandmarkerResult hand_landmarker_result = ConvertToHandLandmarkerResult(
classification_lists, normalized_landmarks_lists, landmarks_lists);
EXPECT_EQ(hand_landmarker_result.handedness.size(), 1);
EXPECT_EQ(hand_landmarker_result.handedness[0].categories.size(), 1);
EXPECT_THAT(
hand_landmarker_result.handedness[0].categories[0],
testing::FieldsAre(1, testing::FloatEq(0.5), "Left", "Left_Hand"));
EXPECT_EQ(hand_landmarker_result.hand_landmarks.size(), 1);
EXPECT_EQ(hand_landmarker_result.hand_landmarks[0].landmarks.size(), 1);
EXPECT_THAT(hand_landmarker_result.hand_landmarks[0].landmarks[0],
testing::FieldsAre(testing::FloatEq(0.1), testing::FloatEq(0.2),
testing::FloatEq(0.3), std::nullopt,
std::nullopt, std::nullopt));
EXPECT_EQ(hand_landmarker_result.hand_world_landmarks.size(), 1);
EXPECT_EQ(hand_landmarker_result.hand_world_landmarks[0].landmarks.size(), 1);
EXPECT_THAT(hand_landmarker_result.hand_world_landmarks[0].landmarks[0],
testing::FieldsAre(testing::FloatEq(3.1), testing::FloatEq(5.2),
testing::FloatEq(4.3), std::nullopt,
std::nullopt, std::nullopt));
}
} // namespace hand_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -32,6 +32,8 @@ limitations under the License.
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/landmark.h"
#include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
@ -50,18 +52,16 @@ namespace {
using ::file::Defaults;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::ConvertToClassifications;
using ::mediapipe::tasks::components::containers::ConvertToNormalizedLandmarks;
using ::mediapipe::tasks::components::containers::RectF;
using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::EqualsProto;
using ::testing::HasSubstr;
using ::testing::Optional;
using ::testing::Pointwise;
using ::testing::TestParamInfo;
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kHandLandmarkerBundleAsset[] = "hand_landmarker.task";
@ -74,7 +74,6 @@ constexpr char kPointingUpImage[] = "pointing_up.jpg";
constexpr char kPointingUpRotatedImage[] = "pointing_up_rotated.jpg";
constexpr char kNoHandsImage[] = "cats_and_dogs.jpg";
constexpr float kLandmarksFractionDiff = 0.03; // percentage
constexpr float kLandmarksAbsMargin = 0.03;
constexpr float kHandednessMargin = 0.05;
@ -101,13 +100,47 @@ HandLandmarkerResult GetExpectedHandLandmarkerResult(
const auto landmarks_detection_result =
GetLandmarksDetectionResult(file_name);
expected_results.hand_landmarks.push_back(
landmarks_detection_result.landmarks());
ConvertToNormalizedLandmarks(landmarks_detection_result.landmarks()));
expected_results.handedness.push_back(
landmarks_detection_result.classifications());
ConvertToClassifications(landmarks_detection_result.classifications()));
}
return expected_results;
}
MATCHER_P2(HandednessMatches, expected_handedness, tolerance, "") {
for (int i = 0; i < arg.size(); i++) {
for (int j = 0; j < arg[i].categories.size(); j++) {
if (arg[i].categories[j].index !=
expected_handedness[i].categories[j].index) {
return false;
}
if (std::abs(arg[i].categories[j].score -
expected_handedness[i].categories[j].score) > tolerance) {
return false;
}
if (arg[i].categories[j].category_name !=
expected_handedness[i].categories[j].category_name) {
return false;
}
}
}
return true;
}
MATCHER_P2(LandmarksMatches, expected_landmarks, toleration, "") {
for (int i = 0; i < arg.size(); i++) {
for (int j = 0; j < arg[i].landmarks.size(); j++) {
if (std::abs(arg[i].landmarks[j].x -
expected_landmarks[i].landmarks[j].x) > toleration ||
std::abs(arg[i].landmarks[j].y -
expected_landmarks[i].landmarks[j].y) > toleration) {
return false;
}
}
}
return true;
}
void ExpectHandLandmarkerResultsCorrect(
const HandLandmarkerResult& actual_results,
const HandLandmarkerResult& expected_results) {
@ -119,16 +152,15 @@ void ExpectHandLandmarkerResultsCorrect(
ASSERT_EQ(actual_landmarks.size(), expected_landmarks.size());
ASSERT_EQ(actual_handedness.size(), expected_handedness.size());
if (actual_landmarks.empty()) {
return;
}
ASSERT_GE(actual_landmarks.size(), 1);
EXPECT_THAT(
actual_handedness,
Pointwise(Approximately(Partially(EqualsProto()), kHandednessMargin),
expected_handedness));
EXPECT_THAT(actual_handedness,
HandednessMatches(expected_handedness, kHandednessMargin));
EXPECT_THAT(actual_landmarks,
Pointwise(Approximately(Partially(EqualsProto()),
/*margin=*/kLandmarksAbsMargin,
/*fraction=*/kLandmarksFractionDiff),
expected_landmarks));
LandmarksMatches(expected_landmarks, kLandmarksAbsMargin));
}
} // namespace