Merge branch 'master' into c-landmarker-apis

This commit is contained in:
Kinar R 2023-11-29 16:41:12 +05:30 committed by GitHub
commit 6ed5e3d0df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
51 changed files with 6211 additions and 14 deletions

View File: mediapipe/model_maker/requirements.txt

@ -7,4 +7,4 @@ tensorflow-addons
tensorflow-datasets
tensorflow-hub
tensorflow-text
tf-models-official>=2.13.1
tf-models-official>=2.13.2

View File: mediapipe/python/solutions/drawing_utils.py

@ -125,7 +125,8 @@ def draw_landmarks(
color=RED_COLOR),
connection_drawing_spec: Union[DrawingSpec,
Mapping[Tuple[int, int],
DrawingSpec]] = DrawingSpec()):
DrawingSpec]] = DrawingSpec(),
is_drawing_landmarks: bool = True):
"""Draws the landmarks and the connections on the image.
Args:
@ -142,6 +143,8 @@ def draw_landmarks(
connections to the DrawingSpecs that specifies the connections' drawing
settings such as color and line thickness. If this argument is explicitly
set to None, no landmark connections will be drawn.
is_drawing_landmarks: Whether to draw landmarks. If set to False, landmark
drawing is skipped and only the connections are drawn.
Raises:
ValueError: If one of the following:
@ -181,7 +184,7 @@ def draw_landmarks(
drawing_spec.thickness)
# Draws landmark points after finishing the connection lines, which is
# aesthetically better.
if landmark_drawing_spec:
if is_drawing_landmarks and landmark_drawing_spec:
for idx, landmark_px in idx_to_coordinates.items():
drawing_spec = landmark_drawing_spec[idx] if isinstance(
landmark_drawing_spec, Mapping) else landmark_drawing_spec

View File: mediapipe/tasks/c/components/containers/BUILD

@ -70,6 +70,60 @@ cc_test(
],
)
cc_library(
name = "rect",
hdrs = ["rect.h"],
)
cc_library(
name = "rect_converter",
srcs = ["rect_converter.cc"],
hdrs = ["rect_converter.h"],
deps = [
":rect",
"//mediapipe/tasks/cc/components/containers:rect",
],
)
cc_test(
name = "rect_converter_test",
srcs = ["rect_converter_test.cc"],
deps = [
":rect",
":rect_converter",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/cc/components/containers:rect",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "keypoint",
hdrs = ["keypoint.h"],
)
cc_library(
name = "keypoint_converter",
srcs = ["keypoint_converter.cc"],
hdrs = ["keypoint_converter.h"],
deps = [
":keypoint",
"//mediapipe/tasks/cc/components/containers:keypoint",
],
)
cc_test(
name = "keypoint_converter_test",
srcs = ["keypoint_converter_test.cc"],
deps = [
":keypoint",
":keypoint_converter",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/cc/components/containers:keypoint",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "classification_result",
hdrs = ["classification_result.h"],
@ -99,6 +153,40 @@ cc_test(
],
)
cc_library(
name = "detection_result",
hdrs = ["detection_result.h"],
deps = [":rect"],
)
cc_library(
name = "detection_result_converter",
srcs = ["detection_result_converter.cc"],
hdrs = ["detection_result_converter.h"],
deps = [
":category",
":category_converter",
":detection_result",
":keypoint",
":keypoint_converter",
":rect",
":rect_converter",
"//mediapipe/tasks/cc/components/containers:detection_result",
],
)
cc_test(
name = "detection_result_converter_test",
srcs = ["detection_result_converter_test.cc"],
deps = [
":detection_result",
":detection_result_converter",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/cc/components/containers:detection_result",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "embedding_result",
hdrs = ["embedding_result.h"],

View File: mediapipe/tasks/c/components/containers/detection_result.h

@ -0,0 +1,63 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
#include <stdint.h>
#include "mediapipe/tasks/c/components/containers/rect.h"
#ifdef __cplusplus
extern "C" {
#endif
// Detection for a single bounding box.
struct Detection {
// An array of detected categories.
struct Category* categories;
// The number of elements in the categories array.
uint32_t categories_count;
// The bounding box location.
struct MPRect bounding_box;
// Optional list of keypoints associated with the detection. Keypoints
// represent interesting points related to the detection. For example, the
// keypoints can be the eyes, ears and mouth from a face detection model, or
// the feature points for template matching in a detection model such as
// KNIFT.
// `nullptr` if keypoints are not present.
struct NormalizedKeypoint* keypoints;
// The number of elements in the keypoints array. 0 if keypoints do not exist.
uint32_t keypoints_count;
};
// Detection results of a model.
struct DetectionResult {
// An array of Detections.
struct Detection* detections;
// The number of detections in the detections array.
uint32_t detections_count;
};
#ifdef __cplusplus
} // extern C
#endif
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
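These are plain C containers: each raw array travels with its element count, and optional data is signalled with `nullptr`. Below is a minimal consumer sketch (not part of this commit); it assumes the result was populated by one of the C tasks and is released afterwards through the matching close function.

#include <cstdint>
#include <cstdio>

#include "mediapipe/tasks/c/components/containers/category.h"
#include "mediapipe/tasks/c/components/containers/detection_result.h"
#include "mediapipe/tasks/c/components/containers/keypoint.h"

// Walks a populated DetectionResult and prints its contents.
void PrintDetections(const DetectionResult& result) {
  for (uint32_t i = 0; i < result.detections_count; ++i) {
    const Detection& d = result.detections[i];
    std::printf("detection %u: box=(%d, %d, %d, %d)\n", i, d.bounding_box.left,
                d.bounding_box.top, d.bounding_box.right,
                d.bounding_box.bottom);
    for (uint32_t j = 0; j < d.categories_count; ++j) {
      // category_name may be nullptr if the model metadata provides no name.
      std::printf("  category: %s (score %.3f)\n",
                  d.categories[j].category_name ? d.categories[j].category_name
                                                : "<none>",
                  d.categories[j].score);
    }
    // `keypoints` is nullptr and `keypoints_count` is 0 when absent.
    for (uint32_t j = 0; j < d.keypoints_count; ++j) {
      std::printf("  keypoint: (%.3f, %.3f)\n", d.keypoints[j].x,
                  d.keypoints[j].y);
    }
  }
}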

View File: mediapipe/tasks/c/components/containers/detection_result_converter.cc

@ -0,0 +1,86 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/components/containers/detection_result_converter.h"
#include <cstdlib>
#include "mediapipe/tasks/c/components/containers/category.h"
#include "mediapipe/tasks/c/components/containers/category_converter.h"
#include "mediapipe/tasks/c/components/containers/detection_result.h"
#include "mediapipe/tasks/c/components/containers/keypoint.h"
#include "mediapipe/tasks/c/components/containers/keypoint_converter.h"
#include "mediapipe/tasks/c/components/containers/rect_converter.h"
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
namespace mediapipe::tasks::c::components::containers {
void CppConvertToDetection(
const mediapipe::tasks::components::containers::Detection& in,
::Detection* out) {
out->categories_count = in.categories.size();
out->categories = new Category[out->categories_count];
for (size_t i = 0; i < out->categories_count; ++i) {
CppConvertToCategory(in.categories[i], &out->categories[i]);
}
CppConvertToRect(in.bounding_box, &out->bounding_box);
if (in.keypoints.has_value()) {
auto& keypoints = in.keypoints.value();
out->keypoints_count = keypoints.size();
out->keypoints = new NormalizedKeypoint[out->keypoints_count];
for (size_t i = 0; i < out->keypoints_count; ++i) {
CppConvertToNormalizedKeypoint(keypoints[i], &out->keypoints[i]);
}
} else {
out->keypoints = nullptr;
out->keypoints_count = 0;
}
}
void CppConvertToDetectionResult(
const mediapipe::tasks::components::containers::DetectionResult& in,
::DetectionResult* out) {
out->detections_count = in.detections.size();
out->detections = new ::Detection[out->detections_count];
for (size_t i = 0; i < out->detections_count; ++i) {
CppConvertToDetection(in.detections[i], &out->detections[i]);
}
}
// Functions to free the memory of C structures.
void CppCloseDetection(::Detection* in) {
for (size_t i = 0; i < in->categories_count; ++i) {
CppCloseCategory(&in->categories[i]);
}
delete[] in->categories;
in->categories = nullptr;
for (size_t i = 0; i < in->keypoints_count; ++i) {
CppCloseNormalizedKeypoint(&in->keypoints[i]);
}
delete[] in->keypoints;
in->keypoints = nullptr;
}
void CppCloseDetectionResult(::DetectionResult* in) {
for (size_t i = 0; i < in->detections_count; ++i) {
CppCloseDetection(&in->detections[i]);
}
delete[] in->detections;
in->detections = nullptr;
}
} // namespace mediapipe::tasks::c::components::containers
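Each `CppConvertTo*` function above allocates the C arrays with `new[]`, and the matching `CppClose*` function releases them, so every conversion must be paired with a close. For C++ callers, a small RAII guard (a sketch, not part of this commit) can make that pairing automatic:

#include "mediapipe/tasks/c/components/containers/detection_result.h"
#include "mediapipe/tasks/c/components/containers/detection_result_converter.h"
#include "mediapipe/tasks/cc/components/containers/detection_result.h"

namespace mediapipe::tasks::c::components::containers {

// Owns the C arrays produced by CppConvertToDetectionResult and releases
// them via CppCloseDetectionResult on destruction.
class ScopedDetectionResult {
 public:
  explicit ScopedDetectionResult(
      const mediapipe::tasks::components::containers::DetectionResult& cpp) {
    CppConvertToDetectionResult(cpp, &c_result_);
  }
  ~ScopedDetectionResult() { CppCloseDetectionResult(&c_result_); }
  // Non-copyable: the destructor frees the owned arrays exactly once.
  ScopedDetectionResult(const ScopedDetectionResult&) = delete;
  ScopedDetectionResult& operator=(const ScopedDetectionResult&) = delete;

  const ::DetectionResult* get() const { return &c_result_; }

 private:
  ::DetectionResult c_result_;
};

}  // namespace mediapipe::tasks::c::components::containers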

View File: mediapipe/tasks/c/components/containers/detection_result_converter.h

@ -0,0 +1,38 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_CONVERTER_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_CONVERTER_H_
#include "mediapipe/tasks/c/components/containers/detection_result.h"
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
namespace mediapipe::tasks::c::components::containers {
void CppConvertToDetection(
const mediapipe::tasks::components::containers::Detection& in,
Detection* out);
void CppConvertToDetectionResult(
const mediapipe::tasks::components::containers::DetectionResult& in,
DetectionResult* out);
void CppCloseDetection(Detection* in);
void CppCloseDetectionResult(DetectionResult* in);
} // namespace mediapipe::tasks::c::components::containers
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_CONVERTER_H_

View File: mediapipe/tasks/c/components/containers/detection_result_converter_test.cc

@ -0,0 +1,74 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/components/containers/detection_result_converter.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/components/containers/detection_result.h"
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
namespace mediapipe::tasks::c::components::containers {
TEST(DetectionResultConverterTest, ConvertsDetectionResultCustomCategory) {
mediapipe::tasks::components::containers::DetectionResult
cpp_detection_result = {/* detections= */ {
{/* categories= */ {{/* index= */ 1, /* score= */ 0.1,
/* category_name= */ "cat",
/* display_name= */ "cat"}},
/* bounding_box= */ {10, 11, 12, 13},
{/* keypoints */ {{0.1, 0.1, "foo", 0.5}}}}}};
DetectionResult c_detection_result;
CppConvertToDetectionResult(cpp_detection_result, &c_detection_result);
EXPECT_NE(c_detection_result.detections, nullptr);
EXPECT_EQ(c_detection_result.detections_count, 1);
EXPECT_NE(c_detection_result.detections[0].categories, nullptr);
EXPECT_EQ(c_detection_result.detections[0].categories_count, 1);
EXPECT_EQ(c_detection_result.detections[0].bounding_box.left, 10);
EXPECT_EQ(c_detection_result.detections[0].bounding_box.top, 11);
EXPECT_EQ(c_detection_result.detections[0].bounding_box.right, 12);
EXPECT_EQ(c_detection_result.detections[0].bounding_box.bottom, 13);
EXPECT_NE(c_detection_result.detections[0].keypoints, nullptr);
CppCloseDetectionResult(&c_detection_result);
}
TEST(DetectionResultConverterTest, ConvertsDetectionResultNoCategory) {
mediapipe::tasks::components::containers::DetectionResult
cpp_detection_result = {/* detections= */ {/* categories= */ {}}};
DetectionResult c_detection_result;
CppConvertToDetectionResult(cpp_detection_result, &c_detection_result);
EXPECT_NE(c_detection_result.detections, nullptr);
EXPECT_EQ(c_detection_result.detections_count, 1);
EXPECT_NE(c_detection_result.detections[0].categories, nullptr);
EXPECT_EQ(c_detection_result.detections[0].categories_count, 0);
CppCloseDetectionResult(&c_detection_result);
}
TEST(DetectionResultConverterTest, FreesMemory) {
mediapipe::tasks::components::containers::DetectionResult
cpp_detection_result = {/* detections= */ {{/* categories= */ {}}}};
DetectionResult c_detection_result;
CppConvertToDetectionResult(cpp_detection_result, &c_detection_result);
EXPECT_NE(c_detection_result.detections, nullptr);
CppCloseDetectionResult(&c_detection_result);
EXPECT_EQ(c_detection_result.detections, nullptr);
}
} // namespace mediapipe::tasks::c::components::containers

View File: mediapipe/tasks/c/components/containers/keypoint.h

@ -0,0 +1,46 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_H_
#include <stdbool.h>  // needed for `bool` when this header is consumed from C
#ifdef __cplusplus
extern "C" {
#endif
// A keypoint, defined by the coordinates (x, y), normalized by the image
// dimensions.
struct NormalizedKeypoint {
// x in normalized image coordinates.
float x;
// y in normalized image coordinates.
float y;
// Optional label of the keypoint. `nullptr` if the label is not present.
char* label;
// Optional score of the keypoint.
float score;
// `True` if the score is valid.
bool has_score;
};
#ifdef __cplusplus
} // extern C
#endif
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_H_
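Because `label` may be `nullptr` and `score` is only meaningful when `has_score` is true, reading a keypoint defensively looks like the short sketch below (not part of this commit).

#include <cstdio>

#include "mediapipe/tasks/c/components/containers/keypoint.h"

// Prints a NormalizedKeypoint, honoring the optional fields documented above.
void PrintKeypoint(const NormalizedKeypoint& kp) {
  std::printf("(%.3f, %.3f) label=%s", kp.x, kp.y,
              kp.label ? kp.label : "<none>");
  if (kp.has_score) {  // `score` carries no information unless has_score is set
    std::printf(" score=%.3f", kp.score);
  }
  std::printf("\n");
}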

View File: mediapipe/tasks/c/components/containers/keypoint_converter.cc

@ -0,0 +1,45 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/components/containers/keypoint_converter.h"
#include <string.h> // IWYU pragma: for open source compile
#include <cstdlib>
#include "mediapipe/tasks/c/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
namespace mediapipe::tasks::c::components::containers {
void CppConvertToNormalizedKeypoint(
const mediapipe::tasks::components::containers::NormalizedKeypoint& in,
NormalizedKeypoint* out) {
out->x = in.x;
out->y = in.y;
out->label = in.label.has_value() ? strdup(in.label->c_str()) : nullptr;
out->has_score = in.score.has_value();
out->score = out->has_score ? in.score.value() : 0;
}
void CppCloseNormalizedKeypoint(NormalizedKeypoint* keypoint) {
if (keypoint && keypoint->label) {
free(keypoint->label);
keypoint->label = nullptr;
}
}
} // namespace mediapipe::tasks::c::components::containers

View File: mediapipe/tasks/c/components/containers/keypoint_converter.h

@ -0,0 +1,32 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_CONVERTER_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_CONVERTER_H_
#include "mediapipe/tasks/c/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
namespace mediapipe::tasks::c::components::containers {
void CppConvertToNormalizedKeypoint(
const mediapipe::tasks::components::containers::NormalizedKeypoint& in,
NormalizedKeypoint* out);
void CppCloseNormalizedKeypoint(NormalizedKeypoint* keypoint);
} // namespace mediapipe::tasks::c::components::containers
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_CONVERTER_H_

View File: mediapipe/tasks/c/components/containers/keypoint_converter_test.cc

@ -0,0 +1,52 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/components/containers/keypoint_converter.h"
#include <string>
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
namespace mediapipe::tasks::c::components::containers {
constexpr float kPrecision = 1e-6;
TEST(KeypointConverterTest, ConvertsKeypointCustomValues) {
mediapipe::tasks::components::containers::NormalizedKeypoint cpp_keypoint = {
0.1, 0.2, "foo", 0.5};
NormalizedKeypoint c_keypoint;
CppConvertToNormalizedKeypoint(cpp_keypoint, &c_keypoint);
EXPECT_NEAR(c_keypoint.x, 0.1f, kPrecision);
EXPECT_NEAR(c_keypoint.y, 0.2f, kPrecision);
EXPECT_EQ(std::string(c_keypoint.label), "foo");
EXPECT_NEAR(c_keypoint.score, 0.5f, kPrecision);
CppCloseNormalizedKeypoint(&c_keypoint);
}
TEST(KeypointConverterTest, FreesMemory) {
mediapipe::tasks::components::containers::NormalizedKeypoint cpp_keypoint = {
0.1, 0.2, "foo", 0.5};
NormalizedKeypoint c_keypoint;
CppConvertToNormalizedKeypoint(cpp_keypoint, &c_keypoint);
EXPECT_NE(c_keypoint.label, nullptr);
CppCloseNormalizedKeypoint(&c_keypoint);
EXPECT_EQ(c_keypoint.label, nullptr);
}
} // namespace mediapipe::tasks::c::components::containers

View File: mediapipe/tasks/c/components/containers/rect.h

@ -0,0 +1,46 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_H_
#ifdef __cplusplus
extern "C" {
#endif
// Defines a rectangle, used e.g. as part of detection results or as input
// region-of-interest.
struct MPRect {
int left;
int top;
int bottom;
int right;
};
// The coordinates are normalized with respect to the image dimensions, i.e.
// generally in [0,1], but they may exceed these bounds when describing a
// region overlapping the image. The origin is at the top-left corner of the
// image.
struct MPRectF {
float left;
float top;
float bottom;
float right;
};
#ifdef __cplusplus
} // extern C
#endif
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_H_
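One subtlety worth flagging: MPRect declares `bottom` before `right`, while the converter and tests below use the order left/top/right/bottom, so positional initialization of MPRect is easy to get wrong. A small sketch (not part of this commit) that uses designated initializers to make the order explicit:

#include "mediapipe/tasks/c/components/containers/rect.h"

// Designated initializers (C99; C++20 requires declaration order, which is
// left, top, bottom, right here) make each field assignment explicit.
struct MPRect MakeRect(int left, int top, int right, int bottom) {
  struct MPRect rect = {
      .left = left, .top = top, .bottom = bottom, .right = right};
  return rect;
}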

View File: mediapipe/tasks/c/components/containers/rect_converter.cc

@ -0,0 +1,41 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/components/containers/rect_converter.h"
#include "mediapipe/tasks/c/components/containers/rect.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
namespace mediapipe::tasks::c::components::containers {
// Converts a C++ Rect to a C Rect.
void CppConvertToRect(const mediapipe::tasks::components::containers::Rect& in,
struct MPRect* out) {
out->left = in.left;
out->top = in.top;
out->right = in.right;
out->bottom = in.bottom;
}
// Converts a C++ RectF to a C RectF.
void CppConvertToRectF(
const mediapipe::tasks::components::containers::RectF& in, MPRectF* out) {
out->left = in.left;
out->top = in.top;
out->right = in.right;
out->bottom = in.bottom;
}
} // namespace mediapipe::tasks::c::components::containers

View File: mediapipe/tasks/c/components/containers/rect_converter.h

@ -0,0 +1,32 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_CONVERTER_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_CONVERTER_H_
#include "mediapipe/tasks/c/components/containers/rect.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
namespace mediapipe::tasks::c::components::containers {
void CppConvertToRect(const mediapipe::tasks::components::containers::Rect& in,
MPRect* out);
void CppConvertToRectF(
const mediapipe::tasks::components::containers::RectF& in, MPRectF* out);
} // namespace mediapipe::tasks::c::components::containers
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_CONVERTER_H_

View File: mediapipe/tasks/c/components/containers/rect_converter_test.cc

@ -0,0 +1,47 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/components/containers/rect_converter.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/components/containers/rect.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
namespace mediapipe::tasks::c::components::containers {
TEST(RectConverterTest, ConvertsRectCustomValues) {
mediapipe::tasks::components::containers::Rect cpp_rect = {0, 1, 2, 3};
MPRect c_rect;
CppConvertToRect(cpp_rect, &c_rect);
EXPECT_EQ(c_rect.left, 0);
EXPECT_EQ(c_rect.top, 1);
EXPECT_EQ(c_rect.right, 2);
EXPECT_EQ(c_rect.bottom, 3);
}
TEST(RectFConverterTest, ConvertsRectFCustomValues) {
mediapipe::tasks::components::containers::RectF cpp_rect = {0.1, 0.2, 0.3,
0.4};
MPRectF c_rect;
CppConvertToRectF(cpp_rect, &c_rect);
EXPECT_FLOAT_EQ(c_rect.left, 0.1);
EXPECT_FLOAT_EQ(c_rect.top, 0.2);
EXPECT_FLOAT_EQ(c_rect.right, 0.3);
EXPECT_FLOAT_EQ(c_rect.bottom, 0.4);
}
} // namespace mediapipe::tasks::c::components::containers

View File: mediapipe/tasks/c/vision/image_classifier/image_classifier.h

@ -60,12 +60,12 @@ struct ImageClassifierOptions {
//
// The caller is responsible for closing the image classifier result.
typedef void (*result_callback_fn)(ImageClassifierResult* result,
const MpImage image, int64_t timestamp_ms,
const MpImage& image, int64_t timestamp_ms,
char* error_msg);
result_callback_fn result_callback;
};
// Creates an ImageClassifier from provided `options`.
// Creates an ImageClassifier from the provided `options`.
// Returns a pointer to the image classifier on success.
// If an error occurs, returns `nullptr` and sets the error parameter to
// an error message (if `error_msg` is not `nullptr`). You must free the memory

View File: mediapipe/tasks/c/vision/image_classifier/image_classifier_test.cc

@ -142,7 +142,7 @@ TEST(ImageClassifierTest, VideoModeTest) {
// timestamp is greater than the previous one.
struct LiveStreamModeCallback {
static int64_t last_timestamp;
static void Fn(ImageClassifierResult* classifier_result, const MpImage image,
static void Fn(ImageClassifierResult* classifier_result, const MpImage& image,
int64_t timestamp, char* error_msg) {
ASSERT_NE(classifier_result, nullptr);
ASSERT_EQ(error_msg, nullptr);

View File: mediapipe/tasks/c/vision/image_embedder/image_embedder.h

@ -62,12 +62,12 @@ struct ImageEmbedderOptions {
//
// The caller is responsible for closing the image embedder result.
typedef void (*result_callback_fn)(ImageEmbedderResult* result,
const MpImage image, int64_t timestamp_ms,
const MpImage& image, int64_t timestamp_ms,
char* error_msg);
result_callback_fn result_callback;
};
// Creates an ImageEmbedder from provided `options`.
// Creates an ImageEmbedder from the provided `options`.
// Returns a pointer to the image embedder on success.
// If an error occurs, returns `nullptr` and sets the error parameter to
// an error message (if `error_msg` is not `nullptr`). You must free the memory

View File: mediapipe/tasks/c/vision/image_embedder/image_embedder_test.cc

@ -199,7 +199,7 @@ TEST(ImageEmbedderTest, VideoModeTest) {
// timestamp is greater than the previous one.
struct LiveStreamModeCallback {
static int64_t last_timestamp;
static void Fn(ImageEmbedderResult* embedder_result, const MpImage image,
static void Fn(ImageEmbedderResult* embedder_result, const MpImage& image,
int64_t timestamp, char* error_msg) {
ASSERT_NE(embedder_result, nullptr);
ASSERT_EQ(error_msg, nullptr);

View File: mediapipe/tasks/c/vision/object_detector/BUILD

@ -0,0 +1,65 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "object_detector_lib",
srcs = ["object_detector.cc"],
hdrs = ["object_detector.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/tasks/c/components/containers:detection_result",
"//mediapipe/tasks/c/components/containers:detection_result_converter",
"//mediapipe/tasks/c/core:base_options",
"//mediapipe/tasks/c/core:base_options_converter",
"//mediapipe/tasks/c/vision/core:common",
"//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/object_detector",
"//mediapipe/tasks/cc/vision/utils:image_utils",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
)
cc_test(
name = "object_detector_test",
srcs = ["object_detector_test.cc"],
data = [
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/tasks/testdata/vision:test_images",
"//mediapipe/tasks/testdata/vision:test_models",
],
linkstatic = 1,
deps = [
":object_detector_lib",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/c/components/containers:category",
"//mediapipe/tasks/c/vision/core:common",
"//mediapipe/tasks/cc/vision/utils:image_utils",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
)

View File: mediapipe/tasks/c/vision/object_detector/object_detector.cc

@ -0,0 +1,290 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/vision/object_detector/object_detector.h"
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <utility>
#include <vector>
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/tasks/c/components/containers/detection_result_converter.h"
#include "mediapipe/tasks/c/core/base_options_converter.h"
#include "mediapipe/tasks/c/vision/core/common.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/object_detector/object_detector.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
namespace mediapipe::tasks::c::vision::object_detector {
namespace {
using ::mediapipe::tasks::c::components::containers::CppCloseDetectionResult;
using ::mediapipe::tasks::c::components::containers::
CppConvertToDetectionResult;
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
using ::mediapipe::tasks::vision::CreateImageFromBuffer;
using ::mediapipe::tasks::vision::ObjectDetector;
using ::mediapipe::tasks::vision::core::RunningMode;
typedef ::mediapipe::tasks::vision::ObjectDetectorResult
CppObjectDetectorResult;
int CppProcessError(absl::Status status, char** error_msg) {
if (error_msg) {
*error_msg = strdup(status.ToString().c_str());
}
return status.raw_code();
}
} // namespace
void CppConvertToDetectorOptions(
const ObjectDetectorOptions& in,
mediapipe::tasks::vision::ObjectDetectorOptions* out) {
out->display_names_locale =
in.display_names_locale ? std::string(in.display_names_locale) : "en";
out->max_results = in.max_results;
out->score_threshold = in.score_threshold;
out->category_allowlist =
std::vector<std::string>(in.category_allowlist_count);
for (uint32_t i = 0; i < in.category_allowlist_count; ++i) {
out->category_allowlist[i] = in.category_allowlist[i];
}
out->category_denylist = std::vector<std::string>(in.category_denylist_count);
for (uint32_t i = 0; i < in.category_denylist_count; ++i) {
out->category_denylist[i] = in.category_denylist[i];
}
}
ObjectDetector* CppObjectDetectorCreate(const ObjectDetectorOptions& options,
char** error_msg) {
auto cpp_options =
std::make_unique<::mediapipe::tasks::vision::ObjectDetectorOptions>();
CppConvertToBaseOptions(options.base_options, &cpp_options->base_options);
CppConvertToDetectorOptions(options, cpp_options.get());
cpp_options->running_mode = static_cast<RunningMode>(options.running_mode);
// Enable callback for processing live stream data when the running mode is
// set to RunningMode::LIVE_STREAM.
if (cpp_options->running_mode == RunningMode::LIVE_STREAM) {
if (options.result_callback == nullptr) {
const absl::Status status = absl::InvalidArgumentError(
"Provided null pointer to callback function.");
ABSL_LOG(ERROR) << "Failed to create ObjectDetector: " << status;
CppProcessError(status, error_msg);
return nullptr;
}
ObjectDetectorOptions::result_callback_fn result_callback =
options.result_callback;
cpp_options->result_callback =
[result_callback](absl::StatusOr<CppObjectDetectorResult> cpp_result,
const Image& image, int64_t timestamp) {
char* error_msg = nullptr;
if (!cpp_result.ok()) {
ABSL_LOG(ERROR) << "Detection failed: " << cpp_result.status();
CppProcessError(cpp_result.status(), &error_msg);
result_callback(nullptr, MpImage(), timestamp, error_msg);
free(error_msg);
return;
}
// Result is valid for the lifetime of the callback function.
ObjectDetectorResult result;
CppConvertToDetectionResult(*cpp_result, &result);
const auto& image_frame = image.GetImageFrameSharedPtr();
const MpImage mp_image = {
.type = MpImage::IMAGE_FRAME,
.image_frame = {
.format = static_cast<::ImageFormat>(image_frame->Format()),
.image_buffer = image_frame->PixelData(),
.width = image_frame->Width(),
.height = image_frame->Height()}};
result_callback(&result, mp_image, timestamp,
/* error_msg= */ nullptr);
CppCloseDetectionResult(&result);
};
}
auto detector = ObjectDetector::Create(std::move(cpp_options));
if (!detector.ok()) {
ABSL_LOG(ERROR) << "Failed to create ObjectDetector: " << detector.status();
CppProcessError(detector.status(), error_msg);
return nullptr;
}
return detector->release();
}
int CppObjectDetectorDetect(void* detector, const MpImage* image,
ObjectDetectorResult* result, char** error_msg) {
if (image->type == MpImage::GPU_BUFFER) {
const absl::Status status =
absl::InvalidArgumentError("GPU Buffer not supported yet.");
ABSL_LOG(ERROR) << "Detection failed: " << status.message();
return CppProcessError(status, error_msg);
}
const auto img = CreateImageFromBuffer(
static_cast<ImageFormat::Format>(image->image_frame.format),
image->image_frame.image_buffer, image->image_frame.width,
image->image_frame.height);
if (!img.ok()) {
ABSL_LOG(ERROR) << "Failed to create Image: " << img.status();
return CppProcessError(img.status(), error_msg);
}
auto cpp_detector = static_cast<ObjectDetector*>(detector);
auto cpp_result = cpp_detector->Detect(*img);
if (!cpp_result.ok()) {
ABSL_LOG(ERROR) << "Detection failed: " << cpp_result.status();
return CppProcessError(cpp_result.status(), error_msg);
}
CppConvertToDetectionResult(*cpp_result, result);
return 0;
}
int CppObjectDetectorDetectForVideo(void* detector, const MpImage* image,
int64_t timestamp_ms,
ObjectDetectorResult* result,
char** error_msg) {
if (image->type == MpImage::GPU_BUFFER) {
absl::Status status =
absl::InvalidArgumentError("GPU Buffer not supported yet");
ABSL_LOG(ERROR) << "Detection failed: " << status.message();
return CppProcessError(status, error_msg);
}
const auto img = CreateImageFromBuffer(
static_cast<ImageFormat::Format>(image->image_frame.format),
image->image_frame.image_buffer, image->image_frame.width,
image->image_frame.height);
if (!img.ok()) {
ABSL_LOG(ERROR) << "Failed to create Image: " << img.status();
return CppProcessError(img.status(), error_msg);
}
auto cpp_detector = static_cast<ObjectDetector*>(detector);
auto cpp_result = cpp_detector->DetectForVideo(*img, timestamp_ms);
if (!cpp_result.ok()) {
ABSL_LOG(ERROR) << "Detection failed: " << cpp_result.status();
return CppProcessError(cpp_result.status(), error_msg);
}
CppConvertToDetectionResult(*cpp_result, result);
return 0;
}
int CppObjectDetectorDetectAsync(void* detector, const MpImage* image,
int64_t timestamp_ms, char** error_msg) {
if (image->type == MpImage::GPU_BUFFER) {
absl::Status status =
absl::InvalidArgumentError("GPU Buffer not supported yet");
ABSL_LOG(ERROR) << "Detection failed: " << status.message();
return CppProcessError(status, error_msg);
}
const auto img = CreateImageFromBuffer(
static_cast<ImageFormat::Format>(image->image_frame.format),
image->image_frame.image_buffer, image->image_frame.width,
image->image_frame.height);
if (!img.ok()) {
ABSL_LOG(ERROR) << "Failed to create Image: " << img.status();
return CppProcessError(img.status(), error_msg);
}
auto cpp_detector = static_cast<ObjectDetector*>(detector);
auto cpp_result = cpp_detector->DetectAsync(*img, timestamp_ms);
if (!cpp_result.ok()) {
ABSL_LOG(ERROR) << "Data preparation for the object detection failed: "
<< cpp_result;
return CppProcessError(cpp_result, error_msg);
}
return 0;
}
void CppObjectDetectorCloseResult(ObjectDetectorResult* result) {
CppCloseDetectionResult(result);
}
int CppObjectDetectorClose(void* detector, char** error_msg) {
auto cpp_detector = static_cast<ObjectDetector*>(detector);
auto result = cpp_detector->Close();
if (!result.ok()) {
ABSL_LOG(ERROR) << "Failed to close ObjectDetector: " << result;
return CppProcessError(result, error_msg);
}
delete cpp_detector;
return 0;
}
} // namespace mediapipe::tasks::c::vision::object_detector
extern "C" {
void* object_detector_create(struct ObjectDetectorOptions* options,
char** error_msg) {
return mediapipe::tasks::c::vision::object_detector::CppObjectDetectorCreate(
*options, error_msg);
}
int object_detector_detect_image(void* detector, const MpImage* image,
ObjectDetectorResult* result,
char** error_msg) {
return mediapipe::tasks::c::vision::object_detector::CppObjectDetectorDetect(
detector, image, result, error_msg);
}
int object_detector_detect_for_video(void* detector, const MpImage* image,
int64_t timestamp_ms,
ObjectDetectorResult* result,
char** error_msg) {
return mediapipe::tasks::c::vision::object_detector::
CppObjectDetectorDetectForVideo(detector, image, timestamp_ms, result,
error_msg);
}
int object_detector_detect_async(void* detector, const MpImage* image,
int64_t timestamp_ms, char** error_msg) {
return mediapipe::tasks::c::vision::object_detector::
CppObjectDetectorDetectAsync(detector, image, timestamp_ms, error_msg);
}
void object_detector_close_result(ObjectDetectorResult* result) {
mediapipe::tasks::c::vision::object_detector::CppObjectDetectorCloseResult(
result);
}
int object_detector_close(void* detector, char** error_msg) {
return mediapipe::tasks::c::vision::object_detector::CppObjectDetectorClose(
detector, error_msg);
}
} // extern "C"

View File: mediapipe/tasks/c/vision/object_detector/object_detector.h

@ -0,0 +1,157 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_VISION_OBJECT_DETECTOR_OBJECT_DETECTOR_H_
#define MEDIAPIPE_TASKS_C_VISION_OBJECT_DETECTOR_OBJECT_DETECTOR_H_
#include "mediapipe/tasks/c/components/containers/detection_result.h"
#include "mediapipe/tasks/c/core/base_options.h"
#include "mediapipe/tasks/c/vision/core/common.h"
#ifndef MP_EXPORT
#define MP_EXPORT __attribute__((visibility("default")))
#endif // MP_EXPORT
#ifdef __cplusplus
extern "C" {
#endif
typedef DetectionResult ObjectDetectorResult;
// The options for configuring a MediaPipe object detector task.
struct ObjectDetectorOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model
// file with metadata, accelerator options, op resolver, etc.
struct BaseOptions base_options;
// The running mode of the task. Defaults to the image mode.
// The object detector has three running modes:
// 1) The image mode for detecting objects on single image inputs.
// 2) The video mode for detecting objects on the decoded frames of a video.
// 3) The live stream mode for detecting objects on a live stream of input
// data, such as from a camera. In this mode, the "result_callback" below must
// be specified to receive the detection results asynchronously.
RunningMode running_mode;
// The locale to use for display names specified through the TFLite Model
// Metadata, if any. Defaults to English.
const char* display_names_locale;
// The maximum number of top-scored detection results to return. If < 0,
// all available results will be returned. If 0, an invalid argument error is
// returned.
int max_results;
// Score threshold to override the one provided in the model metadata (if
// any). Results below this value are rejected.
float score_threshold;
// The allowlist of category names. If non-empty, detection results whose
// category name is not in this set will be filtered out. Duplicate or unknown
// category names are ignored. Mutually exclusive with category_denylist.
const char** category_allowlist;
// The number of elements in the category allowlist.
uint32_t category_allowlist_count;
// The denylist of category names. If non-empty, detection results whose
// category name is in this set will be filtered out. Duplicate or unknown
// category names are ignored. Mutually exclusive with category_allowlist.
const char** category_denylist;
// The number of elements in the category denylist.
uint32_t category_denylist_count;
// The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM. Arguments of the callback function include:
// the pointer to the detection result, the image that the result was obtained
// on, the timestamp relevant to the detection results, and a pointer to an
// error message in case of any failure. The passed arguments remain valid
// only for the lifetime of the callback function.
//
// The caller is responsible for closing the object detector result.
typedef void (*result_callback_fn)(ObjectDetectorResult* result,
const MpImage& image, int64_t timestamp_ms,
char* error_msg);
result_callback_fn result_callback;
};
// Creates an ObjectDetector from the provided `options`.
// Returns a pointer to the object detector on success.
// If an error occurs, returns `nullptr` and sets the error parameter to
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT void* object_detector_create(struct ObjectDetectorOptions* options,
char** error_msg);
// Performs object detection on the input `image`. Returns `0` on success.
// If an error occurs, returns an error code and sets the error parameter to
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT int object_detector_detect_image(void* detector, const MpImage* image,
ObjectDetectorResult* result,
char** error_msg);
// Performs object detection on the provided video frame.
// Only use this method when the ObjectDetector is created with the video
// running mode.
// The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing.
// If an error occurs, returns an error code and sets the error parameter to
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT int object_detector_detect_for_video(void* detector,
const MpImage* image,
int64_t timestamp_ms,
ObjectDetectorResult* result,
char** error_msg);
// Sends live image data to object detection, and the results will be
// available via the `result_callback` provided in the ObjectDetectorOptions.
// Only use this method when the ObjectDetector is created with the live
// stream running mode.
// The image can be of any size with format RGB or RGBA. It's required to
// provide a timestamp (in milliseconds) to indicate when the input image is
// sent to the object detector. The input timestamps must be monotonically
// increasing.
// The `result_callback` provides:
// - The detection results as an ObjectDetectorResult object.
// - The const reference to the corresponding input image that the object
// detector runs on. Note that the const reference to the image will no
// longer be valid when the callback returns. To access the image data
// outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds.
// If an error occurs, returns an error code and sets the error parameter to
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT int object_detector_detect_async(void* detector, const MpImage* image,
int64_t timestamp_ms,
char** error_msg);
// Frees the memory allocated inside an ObjectDetectorResult.
// Does not free the result pointer itself.
MP_EXPORT void object_detector_close_result(ObjectDetectorResult* result);
// Frees the object detector.
// If an error occurs, returns an error code and sets the error parameter to
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT int object_detector_close(void* detector, char** error_msg);
#ifdef __cplusplus
} // extern C
#endif
#endif // MEDIAPIPE_TASKS_C_VISION_OBJECT_DETECTOR_OBJECT_DETECTOR_H_
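Restating the contract above as a minimal image-mode sketch (not part of this commit): every entry point reports failure through its return value plus an optional strdup'ed `error_msg` the caller must free, and `object_detector_close_result` releases only the result's internals. The model path is a placeholder; the test file that follows shows a complete, runnable setup.

#include <cstdio>
#include <cstdlib>

#include "mediapipe/tasks/c/vision/object_detector/object_detector.h"

int RunDetectorOnce(const MpImage* image) {
  ObjectDetectorOptions options = {};
  // Placeholder path; any TFLite detection model with metadata works here.
  options.base_options.model_asset_path = "/path/to/model.tflite";
  options.running_mode = RunningMode::IMAGE;
  options.max_results = -1;  // < 0 returns all available results

  char* error_msg = nullptr;
  void* detector = object_detector_create(&options, &error_msg);
  if (detector == nullptr) {
    std::fprintf(stderr, "create failed: %s\n", error_msg);
    std::free(error_msg);  // the API strdup's the message; the caller frees it
    return 1;
  }

  ObjectDetectorResult result;
  if (object_detector_detect_image(detector, image, &result, &error_msg) != 0) {
    std::fprintf(stderr, "detect failed: %s\n", error_msg);
    std::free(error_msg);
  } else {
    std::printf("%u detections\n", result.detections_count);
    object_detector_close_result(&result);  // frees internals, not `result`
  }
  return object_detector_close(detector, /* error_msg= */ nullptr);
}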

View File: mediapipe/tasks/c/vision/object_detector/object_detector_test.cc

@ -0,0 +1,253 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/vision/object_detector/object_detector.h"
#include <cstdint>
#include <cstdlib>
#include <string>
#include "absl/flags/flag.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/components/containers/category.h"
#include "mediapipe/tasks/c/vision/core/common.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
namespace {
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::vision::DecodeImageFromFile;
using testing::HasSubstr;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kImageFile[] = "cats_and_dogs.jpg";
constexpr char kModelName[] =
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
constexpr float kPrecision = 1e-4;
constexpr int kIterations = 100;
std::string GetFullPath(absl::string_view file_name) {
return JoinPath("./", kTestDataDirectory, file_name);
}
TEST(ObjectDetectorTest, ImageModeTest) {
const auto image = DecodeImageFromFile(GetFullPath(kImageFile));
ASSERT_TRUE(image.ok());
const std::string model_path = GetFullPath(kModelName);
ObjectDetectorOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::IMAGE,
/* display_names_locale= */ nullptr,
/* max_results= */ -1,
/* score_threshold= */ 0.0,
/* category_allowlist= */ nullptr,
/* category_allowlist_count= */ 0,
/* category_denylist= */ nullptr,
/* category_denylist_count= */ 0,
};
void* detector = object_detector_create(&options, /* error_msg */ nullptr);
EXPECT_NE(detector, nullptr);
const auto& image_frame = image->GetImageFrameSharedPtr();
const MpImage mp_image = {
.type = MpImage::IMAGE_FRAME,
.image_frame = {.format = static_cast<ImageFormat>(image_frame->Format()),
.image_buffer = image_frame->PixelData(),
.width = image_frame->Width(),
.height = image_frame->Height()}};
ObjectDetectorResult result;
object_detector_detect_image(detector, &mp_image, &result,
/* error_msg */ nullptr);
EXPECT_EQ(result.detections_count, 10);
EXPECT_EQ(result.detections[0].categories_count, 1);
EXPECT_EQ(std::string{result.detections[0].categories[0].category_name},
"cat");
EXPECT_NEAR(result.detections[0].categories[0].score, 0.6992f, kPrecision);
object_detector_close_result(&result);
object_detector_close(detector, /* error_msg */ nullptr);
}
TEST(ObjectDetectorTest, VideoModeTest) {
const auto image = DecodeImageFromFile(GetFullPath(kImageFile));
ASSERT_TRUE(image.ok());
const std::string model_path = GetFullPath(kModelName);
ObjectDetectorOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::VIDEO,
/* display_names_locale= */ nullptr,
/* max_results= */ 3,
/* score_threshold= */ 0.0,
/* category_allowlist= */ nullptr,
/* category_allowlist_count= */ 0,
/* category_denylist= */ nullptr,
/* category_denylist_count= */ 0,
};
void* detector = object_detector_create(&options, /* error_msg */ nullptr);
EXPECT_NE(detector, nullptr);
const auto& image_frame = image->GetImageFrameSharedPtr();
const MpImage mp_image = {
.type = MpImage::IMAGE_FRAME,
.image_frame = {.format = static_cast<ImageFormat>(image_frame->Format()),
.image_buffer = image_frame->PixelData(),
.width = image_frame->Width(),
.height = image_frame->Height()}};
for (int i = 0; i < kIterations; ++i) {
ObjectDetectorResult result;
object_detector_detect_for_video(detector, &mp_image, i, &result,
/* error_msg */ nullptr);
EXPECT_EQ(result.detections_count, 3);
EXPECT_EQ(result.detections[0].categories_count, 1);
EXPECT_EQ(std::string{result.detections[0].categories[0].category_name},
"cat");
EXPECT_NEAR(result.detections[0].categories[0].score, 0.6992f, kPrecision);
object_detector_close_result(&result);
}
object_detector_close(detector, /* error_msg */ nullptr);
}
// A structure to support LiveStreamModeTest below. This structure holds a
// static method `Fn` for a callback function of the C API. A `static`
// qualifier allows taking the address of the method to match the API style.
// Another static struct member, `last_timestamp`, is used to verify that the
// current timestamp is greater than the previous one.
struct LiveStreamModeCallback {
static int64_t last_timestamp;
static void Fn(ObjectDetectorResult* detector_result, const MpImage& image,
int64_t timestamp, char* error_msg) {
ASSERT_NE(detector_result, nullptr);
ASSERT_EQ(error_msg, nullptr);
EXPECT_EQ(detector_result->detections_count, 3);
EXPECT_EQ(detector_result->detections[0].categories_count, 1);
EXPECT_EQ(
std::string{detector_result->detections[0].categories[0].category_name},
"cat");
EXPECT_NEAR(detector_result->detections[0].categories[0].score, 0.6992f,
kPrecision);
EXPECT_GT(image.image_frame.width, 0);
EXPECT_GT(image.image_frame.height, 0);
EXPECT_GT(timestamp, last_timestamp);
last_timestamp++;
}
};
int64_t LiveStreamModeCallback::last_timestamp = -1;
TEST(ObjectDetectorTest, LiveStreamModeTest) {
const auto image = DecodeImageFromFile(GetFullPath(kImageFile));
ASSERT_TRUE(image.ok());
const std::string model_path = GetFullPath(kModelName);
ObjectDetectorOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::LIVE_STREAM,
/* display_names_locale= */ nullptr,
/* max_results= */ 3,
/* score_threshold= */ 0.0,
/* category_allowlist= */ nullptr,
/* category_allowlist_count= */ 0,
/* category_denylist= */ nullptr,
/* category_denylist_count= */ 0,
/* result_callback= */ LiveStreamModeCallback::Fn,
};
void* detector = object_detector_create(&options, /* error_msg */
nullptr);
EXPECT_NE(detector, nullptr);
const auto& image_frame = image->GetImageFrameSharedPtr();
const MpImage mp_image = {
.type = MpImage::IMAGE_FRAME,
.image_frame = {.format = static_cast<ImageFormat>(image_frame->Format()),
.image_buffer = image_frame->PixelData(),
.width = image_frame->Width(),
.height = image_frame->Height()}};
for (int i = 0; i < kIterations; ++i) {
EXPECT_GE(object_detector_detect_async(detector, &mp_image, i,
/* error_msg */ nullptr),
0);
}
object_detector_close(detector, /* error_msg */ nullptr);
// Due to the flow limiter, the total of outputs might be smaller than the
// number of iterations.
EXPECT_LE(LiveStreamModeCallback::last_timestamp, kIterations);
EXPECT_GT(LiveStreamModeCallback::last_timestamp, 0);
}
TEST(ObjectDetectorTest, InvalidArgumentHandling) {
// It is an error to set neither the asset buffer nor the path.
ObjectDetectorOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ nullptr},
};
char* error_msg;
void* detector = object_detector_create(&options, &error_msg);
EXPECT_EQ(detector, nullptr);
EXPECT_THAT(error_msg, HasSubstr("ExternalFile must specify"));
free(error_msg);
}
TEST(ObjectDetectorTest, FailedDetectionHandling) {
const std::string model_path = GetFullPath(kModelName);
ObjectDetectorOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::IMAGE,
/* display_names_locale= */ nullptr,
/* max_results= */ -1,
/* score_threshold= */ 0.0,
/* category_allowlist= */ nullptr,
/* category_allowlist_count= */ 0,
/* category_denylist= */ nullptr,
/* category_denylist_count= */ 0,
};
void* detector = object_detector_create(&options, /* error_msg */
nullptr);
EXPECT_NE(detector, nullptr);
const MpImage mp_image = {.type = MpImage::GPU_BUFFER, .gpu_buffer = {}};
ObjectDetectorResult result;
char* error_msg;
object_detector_detect_image(detector, &mp_image, &result, &error_msg);
EXPECT_THAT(error_msg, HasSubstr("GPU Buffer not supported yet"));
free(error_msg);
object_detector_close(detector, /* error_msg */ nullptr);
}
} // namespace

View File: mediapipe/tasks/cc/vision/hand_landmarker/BUILD

@ -155,6 +155,37 @@ cc_library(
# TODO: open source hand joints graph
cc_library(
name = "hand_roi_refinement_graph",
srcs = ["hand_roi_refinement_graph.cc"],
deps = [
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
"//mediapipe/calculators/tensor:inference_calculator",
"//mediapipe/calculators/tensor:tensors_to_landmarks_calculator",
"//mediapipe/calculators/tensor:tensors_to_landmarks_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2/stream:detections_to_rects",
"//mediapipe/framework/api2/stream:landmarks_projection",
"//mediapipe/framework/api2/stream:landmarks_to_detection",
"//mediapipe/framework/api2/stream:rect_transformation",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_roi_refinement_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
)
cc_library(
name = "hand_landmarker_result",
srcs = ["hand_landmarker_result.cc"],

View File: mediapipe/tasks/cc/vision/hand_landmarker/hand_roi_refinement_graph.cc

@ -0,0 +1,154 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <array>
#include <optional>
#include <utility>
#include <vector>
#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_landmarks_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/detections_to_rects.h"
#include "mediapipe/framework/api2/stream/landmarks_projection.h"
#include "mediapipe/framework/api2/stream/landmarks_to_detection.h"
#include "mediapipe/framework/api2/stream/rect_transformation.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace hand_landmarker {
using ::mediapipe::api2::builder::ConvertAlignmentPointsDetectionToRect;
using ::mediapipe::api2::builder::ConvertLandmarksToDetection;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::ProjectLandmarks;
using ::mediapipe::api2::builder::ScaleAndShiftAndMakeSquareLong;
using ::mediapipe::api2::builder::Stream;
// Refines the input hand RoI with the hand_roi_refinement model.
//
// Inputs:
// IMAGE - Image
// The image to preprocess.
// NORM_RECT - NormalizedRect
// Coarse RoI of hand.
// Outputs:
// NORM_RECT - NormalizedRect
// Refined RoI of hand.
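//
// Example (a hypothetical builder-based usage sketch; `image` and
// `coarse_roi` are pre-existing streams, and the node name matches the
// registration at the bottom of this file):
//
//   Graph graph;
//   auto& refine = graph.AddNode(
//       "mediapipe.tasks.vision.hand_landmarker.HandRoiRefinementGraph");
//   image >> refine.In("IMAGE");
//   coarse_roi >> refine.In("NORM_RECT");
//   Stream<NormalizedRect> refined_roi =
//       refine.Out("NORM_RECT").Cast<NormalizedRect>();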
class HandRoiRefinementGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
mediapipe::SubgraphContext* context) override {
Graph graph;
Stream<Image> image_in = graph.In("IMAGE").Cast<Image>();
Stream<NormalizedRect> roi_in =
graph.In("NORM_RECT").Cast<NormalizedRect>();
auto& graph_options =
*context->MutableOptions<proto::HandRoiRefinementGraphOptions>();
MP_ASSIGN_OR_RETURN(
const auto* model_resources,
GetOrCreateModelResources<proto::HandRoiRefinementGraphOptions>(
context));
auto& preprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.ImagePreprocessingGraph");
bool use_gpu =
components::processors::DetermineImagePreprocessingGpuBackend(
graph_options.base_options().acceleration());
auto& image_to_tensor_options =
*preprocessing
.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>()
.mutable_image_to_tensor_options();
image_to_tensor_options.set_keep_aspect_ratio(true);
image_to_tensor_options.set_border_mode(
mediapipe::ImageToTensorCalculatorOptions::BORDER_REPLICATE);
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
*model_resources, use_gpu, graph_options.base_options().gpu_origin(),
&preprocessing.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>()));
image_in >> preprocessing.In("IMAGE");
roi_in >> preprocessing.In("NORM_RECT");
auto tensors_in = preprocessing.Out("TENSORS");
auto matrix = preprocessing.Out("MATRIX").Cast<std::array<float, 16>>();
auto image_size =
preprocessing.Out("IMAGE_SIZE").Cast<std::pair<int, int>>();
auto& inference = AddInference(
*model_resources, graph_options.base_options().acceleration(), graph);
tensors_in >> inference.In("TENSORS");
auto tensors_out = inference.Out("TENSORS").Cast<std::vector<Tensor>>();
MP_ASSIGN_OR_RETURN(auto image_tensor_specs,
BuildInputImageTensorSpecs(*model_resources));
// Converts tensors to landmarks. The re-crop model outputs two points: a
// center point and a guide point.
auto& to_landmarks = graph.AddNode("TensorsToLandmarksCalculator");
auto& to_landmarks_opts =
to_landmarks
.GetOptions<mediapipe::TensorsToLandmarksCalculatorOptions>();
to_landmarks_opts.set_num_landmarks(/*num_landmarks=*/2);
to_landmarks_opts.set_input_image_width(image_tensor_specs.image_width);
to_landmarks_opts.set_input_image_height(image_tensor_specs.image_height);
to_landmarks_opts.set_normalize_z(/*z_norm_factor=*/1.0f);
tensors_out.ConnectTo(to_landmarks.In("TENSORS"));
auto recrop_landmarks = to_landmarks.Out("NORM_LANDMARKS")
.Cast<mediapipe::NormalizedLandmarkList>();
// Project landmarks.
auto projected_recrop_landmarks =
ProjectLandmarks(recrop_landmarks, matrix, graph);
// Convert re-crop landmarks to detection.
auto recrop_detection =
ConvertLandmarksToDetection(projected_recrop_landmarks, graph);
// Convert re-crop detection to rect.
auto recrop_rect = ConvertAlignmentPointsDetectionToRect(
recrop_detection, image_size, /*start_keypoint_index=*/0,
/*end_keypoint_index=*/1, /*target_angle=*/-90, graph);
auto refined_roi =
ScaleAndShiftAndMakeSquareLong(recrop_rect, image_size,
/*scale_x_factor=*/1.0,
/*scale_y_factor=*/1.0, /*shift_x=*/0,
/*shift_y=*/-0.1, graph);
refined_roi >> graph.Out("NORM_RECT").Cast<NormalizedRect>();
return graph.GetConfig();
}
};
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::hand_landmarker::HandRoiRefinementGraph);
} // namespace hand_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,152 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(
default_visibility = ["//mediapipe/tasks:internal"],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "holistic_face_tracking",
srcs = ["holistic_face_tracking.cc"],
hdrs = ["holistic_face_tracking.h"],
deps = [
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2/stream:detections_to_rects",
"//mediapipe/framework/api2/stream:image_size",
"//mediapipe/framework/api2/stream:landmarks_to_detection",
"//mediapipe/framework/api2/stream:loopback",
"//mediapipe/framework/api2/stream:rect_transformation",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:status",
"//mediapipe/modules/holistic_landmark/calculators:roi_tracking_calculator",
"//mediapipe/modules/holistic_landmark/calculators:roi_tracking_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/face_detector:face_detector_graph",
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/face_landmarker:face_blendshapes_graph",
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarks_detector_graph",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
],
)
cc_library(
name = "holistic_hand_tracking",
srcs = ["holistic_hand_tracking.cc"],
hdrs = ["holistic_hand_tracking.h"],
deps = [
"//mediapipe/calculators/util:align_hand_to_pose_in_world_calculator",
"//mediapipe/calculators/util:align_hand_to_pose_in_world_calculator_cc_proto",
"//mediapipe/calculators/util:landmark_visibility_calculator",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2/stream:image_size",
"//mediapipe/framework/api2/stream:landmarks_to_detection",
"//mediapipe/framework/api2/stream:loopback",
"//mediapipe/framework/api2/stream:rect_transformation",
"//mediapipe/framework/api2/stream:split",
"//mediapipe/framework/api2/stream:threshold",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:status",
"//mediapipe/modules/holistic_landmark/calculators:hand_detections_from_pose_to_rects_calculator",
"//mediapipe/modules/holistic_landmark/calculators:roi_tracking_calculator",
"//mediapipe/modules/holistic_landmark/calculators:roi_tracking_calculator_cc_proto",
"//mediapipe/tasks/cc/components/utils:gate",
"//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarker_graph",
"//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph",
"//mediapipe/tasks/cc/vision/hand_landmarker:hand_roi_refinement_graph",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_roi_refinement_graph_options_cc_proto",
"@com_google_absl//absl/status:statusor",
],
)
cc_library(
name = "holistic_pose_tracking",
srcs = ["holistic_pose_tracking.cc"],
hdrs = ["holistic_pose_tracking.h"],
deps = [
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2/stream:detections_to_rects",
"//mediapipe/framework/api2/stream:image_size",
"//mediapipe/framework/api2/stream:landmarks_to_detection",
"//mediapipe/framework/api2/stream:loopback",
"//mediapipe/framework/api2/stream:merge",
"//mediapipe/framework/api2/stream:presence",
"//mediapipe/framework/api2/stream:rect_transformation",
"//mediapipe/framework/api2/stream:segmentation_smoothing",
"//mediapipe/framework/api2/stream:smoothing",
"//mediapipe/framework/api2/stream:split",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc/components/utils:gate",
"//mediapipe/tasks/cc/vision/pose_detector:pose_detector_graph",
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/pose_landmarker:pose_landmarks_detector_graph",
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
"@com_google_absl//absl/status:statusor",
],
)
cc_library(
name = "holistic_landmarker_graph",
srcs = ["holistic_landmarker_graph.cc"],
deps = [
":holistic_face_tracking",
":holistic_hand_tracking",
":holistic_pose_tracking",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2/stream:split",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc/core:model_asset_bundle_resources",
"//mediapipe/tasks/cc/core:model_resources_cache",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_roi_refinement_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_landmarker_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/pose_landmarker:pose_topology",
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
"//mediapipe/util:graph_builder_utils",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)

View File

@ -0,0 +1,260 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h"
#include <functional>
#include <optional>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/detections_to_rects.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/api2/stream/landmarks_to_detection.h"
#include "mediapipe/framework/api2/stream/loopback.h"
#include "mediapipe/framework/api2/stream/rect_transformation.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
namespace {
using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::builder::ConvertDetectionsToRectUsingKeypoints;
using ::mediapipe::api2::builder::ConvertDetectionToRect;
using ::mediapipe::api2::builder::ConvertLandmarksToDetection;
using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::GetLoopbackData;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Scale;
using ::mediapipe::api2::builder::ScaleAndMakeSquare;
using ::mediapipe::api2::builder::Stream;
struct FaceLandmarksResult {
std::optional<Stream<NormalizedLandmarkList>> landmarks;
std::optional<Stream<ClassificationList>> classifications;
};
absl::Status ValidateGraphOptions(
const face_detector::proto::FaceDetectorGraphOptions&
face_detector_graph_options,
const face_landmarker::proto::FaceLandmarksDetectorGraphOptions&
face_landmarks_detector_graph_options,
const HolisticFaceTrackingRequest& request) {
if (face_detector_graph_options.num_faces() != 1) {
return absl::InvalidArgumentError(absl::StrFormat(
"Only support num_faces to be 1, but got num_faces = %d.",
face_detector_graph_options.num_faces()));
}
if (request.classifications && !face_landmarks_detector_graph_options
.has_face_blendshapes_graph_options()) {
return absl::InvalidArgumentError(
"Blendshapes detection is requested, but "
"face_blendshapes_graph_options is not configured.");
}
return absl::OkStatus();
}
Stream<NormalizedRect> GetFaceRoiFromPoseFaceLandmarks(
Stream<NormalizedLandmarkList> pose_face_landmarks,
Stream<std::pair<int, int>> image_size, Graph& graph) {
Stream<mediapipe::Detection> detection =
ConvertLandmarksToDetection(pose_face_landmarks, graph);
// Refer to the pose face landmark indices here:
// https://developers.google.com/mediapipe/solutions/vision/pose_landmarker#pose_landmarker_model
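// In that 33-point topology, keypoint 5 is the right eye and keypoint 2 is
// the left eye, so the rect is aligned between the eyes (target_angle 0).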
Stream<NormalizedRect> rect = ConvertDetectionToRect(
detection, image_size, /*start_keypoint_index=*/5,
/*end_keypoint_index=*/2, /*target_angle=*/0, graph);
// Scales the face RoI from a tight rect enclosing the pose face landmarks to
// a larger square so that the whole face is within the RoI.
return ScaleAndMakeSquare(rect, image_size,
/*scale_x_factor=*/3.0,
/*scale_y_factor=*/3.0, graph);
}
Stream<NormalizedRect> GetFaceRoiFromFaceLandmarks(
Stream<NormalizedLandmarkList> face_landmarks,
Stream<std::pair<int, int>> image_size, Graph& graph) {
Stream<mediapipe::Detection> detection =
ConvertLandmarksToDetection(face_landmarks, graph);
Stream<NormalizedRect> rect = ConvertDetectionToRect(
detection, image_size, /*start_keypoint_index=*/33,
/*end_keypoint_index=*/263, /*target_angle=*/0, graph);
return Scale(rect, image_size,
/*scale_x_factor=*/1.5,
/*scale_y_factor=*/1.5, graph);
}
Stream<std::vector<Detection>> GetFaceDetections(
Stream<Image> image, Stream<NormalizedRect> roi,
const face_detector::proto::FaceDetectorGraphOptions&
face_detector_graph_options,
Graph& graph) {
auto& face_detector_graph =
graph.AddNode("mediapipe.tasks.vision.face_detector.FaceDetectorGraph");
face_detector_graph
.GetOptions<face_detector::proto::FaceDetectorGraphOptions>() =
face_detector_graph_options;
image >> face_detector_graph.In("IMAGE");
roi >> face_detector_graph.In("NORM_RECT");
return face_detector_graph.Out("DETECTIONS").Cast<std::vector<Detection>>();
}
Stream<NormalizedRect> GetFaceRoiFromFaceDetections(
Stream<std::vector<Detection>> face_detections,
Stream<std::pair<int, int>> image_size, Graph& graph) {
// Convert detection to rect.
Stream<NormalizedRect> rect = ConvertDetectionsToRectUsingKeypoints(
face_detections, image_size, /*start_keypoint_index=*/0,
/*end_keypoint_index=*/1, /*target_angle=*/0, graph);
return ScaleAndMakeSquare(rect, image_size,
/*scale_x_factor=*/2.0,
/*scale_y_factor=*/2.0, graph);
}
Stream<NormalizedRect> TrackFaceRoi(
Stream<NormalizedLandmarkList> prev_landmarks, Stream<NormalizedRect> roi,
Stream<std::pair<int, int>> image_size, Graph& graph) {
// Gets face ROI from previous frame face landmarks.
Stream<NormalizedRect> prev_roi =
GetFaceRoiFromFaceLandmarks(prev_landmarks, image_size, graph);
auto& tracking_node = graph.AddNode("RoiTrackingCalculator");
auto& tracking_node_opts =
tracking_node.GetOptions<RoiTrackingCalculatorOptions>();
auto* rect_requirements = tracking_node_opts.mutable_rect_requirements();
rect_requirements->set_rotation_degrees(15.0);
rect_requirements->set_translation(0.1);
rect_requirements->set_scale(0.3);
auto* landmarks_requirements =
tracking_node_opts.mutable_landmarks_requirements();
landmarks_requirements->set_recrop_rect_margin(-0.2);
prev_landmarks.ConnectTo(tracking_node.In("PREV_LANDMARKS"));
prev_roi.ConnectTo(tracking_node.In("PREV_LANDMARKS_RECT"));
roi.ConnectTo(tracking_node.In("RECROP_RECT"));
image_size.ConnectTo(tracking_node.In("IMAGE_SIZE"));
return tracking_node.Out("TRACKING_RECT").Cast<NormalizedRect>();
}
FaceLandmarksResult GetFaceLandmarksDetection(
Stream<Image> image, Stream<NormalizedRect> roi,
Stream<std::pair<int, int>> image_size,
const face_landmarker::proto::FaceLandmarksDetectorGraphOptions&
face_landmarks_detector_graph_options,
const HolisticFaceTrackingRequest& request, Graph& graph) {
FaceLandmarksResult result;
auto& face_landmarks_detector_graph = graph.AddNode(
"mediapipe.tasks.vision.face_landmarker."
"SingleFaceLandmarksDetectorGraph");
face_landmarks_detector_graph
.GetOptions<face_landmarker::proto::FaceLandmarksDetectorGraphOptions>() =
face_landmarks_detector_graph_options;
image >> face_landmarks_detector_graph.In("IMAGE");
roi >> face_landmarks_detector_graph.In("NORM_RECT");
auto landmarks = face_landmarks_detector_graph.Out("NORM_LANDMARKS")
.Cast<NormalizedLandmarkList>();
result.landmarks = landmarks;
if (request.classifications) {
auto& blendshapes_graph = graph.AddNode(
"mediapipe.tasks.vision.face_landmarker.FaceBlendshapesGraph");
blendshapes_graph
.GetOptions<face_landmarker::proto::FaceBlendshapesGraphOptions>() =
face_landmarks_detector_graph_options.face_blendshapes_graph_options();
landmarks >> blendshapes_graph.In("LANDMARKS");
image_size >> blendshapes_graph.In("IMAGE_SIZE");
result.classifications =
blendshapes_graph.Out("BLENDSHAPES").Cast<ClassificationList>();
}
return result;
}
} // namespace
absl::StatusOr<HolisticFaceTrackingOutput> TrackHolisticFace(
Stream<Image> image, Stream<NormalizedLandmarkList> pose_face_landmarks,
const face_detector::proto::FaceDetectorGraphOptions&
face_detector_graph_options,
const face_landmarker::proto::FaceLandmarksDetectorGraphOptions&
face_landmarks_detector_graph_options,
const HolisticFaceTrackingRequest& request, Graph& graph) {
MP_RETURN_IF_ERROR(ValidateGraphOptions(face_detector_graph_options,
face_landmarks_detector_graph_options,
request));
// Extracts the image size from the input image.
Stream<std::pair<int, int>> image_size = GetImageSize(image, graph);
// Gets face ROI from pose face landmarks.
Stream<NormalizedRect> roi_from_pose =
GetFaceRoiFromPoseFaceLandmarks(pose_face_landmarks, image_size, graph);
// Detects faces within ROI of pose face.
Stream<std::vector<Detection>> face_detections = GetFaceDetections(
image, roi_from_pose, face_detector_graph_options, graph);
// Gets face ROI from face detector.
Stream<NormalizedRect> roi_from_detection =
GetFaceRoiFromFaceDetections(face_detections, image_size, graph);
// Loopback for previous frame landmarks.
auto [prev_landmarks, set_prev_landmarks_fn] =
GetLoopbackData<NormalizedLandmarkList>(/*tick=*/image_size, graph);
// Tracks face ROI.
auto tracking_roi =
TrackFaceRoi(prev_landmarks, roi_from_detection, image_size, graph);
// Predicts face landmarks.
auto landmarks_detection_result = GetFaceLandmarksDetection(
image, tracking_roi, image_size, face_landmarks_detector_graph_options,
request, graph);
// Sets previous landmarks for ROI tracking.
set_prev_landmarks_fn(landmarks_detection_result.landmarks.value());
return {{.landmarks = landmarks_detection_result.landmarks,
.classifications = landmarks_detection_result.classifications,
.debug_output = {
.roi_from_pose = roi_from_pose,
.roi_from_detection = roi_from_detection,
.tracking_roi = tracking_roi,
}}};
}
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,89 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_FACE_TRACKING_H_
#define MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_FACE_TRACKING_H_
#include <optional>
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
struct HolisticFaceTrackingRequest {
bool classifications = false;
};
struct HolisticFaceTrackingOutput {
std::optional<api2::builder::Stream<mediapipe::NormalizedLandmarkList>>
landmarks;
std::optional<api2::builder::Stream<mediapipe::ClassificationList>>
classifications;
struct DebugOutput {
api2::builder::Stream<mediapipe::NormalizedRect> roi_from_pose;
api2::builder::Stream<mediapipe::NormalizedRect> roi_from_detection;
api2::builder::Stream<mediapipe::NormalizedRect> tracking_roi;
};
DebugOutput debug_output;
};
// Updates @graph to track a single face in @image based on pose landmarks.
//
// To track a single face, this subgraph uses pose face landmarks to obtain an
// approximate face location, refines it with the face detector model and then
// runs the face landmarks model. It can also reuse the face ROI from the
// previous frame if the face hasn't moved too much.
//
// @image - Image to track a single face in.
// @pose_face_landmarks - Pose face landmarks to derive initial face location
// from.
// @face_detector_graph_options - face detector graph options used to detect the
// face within the RoI constructed from the pose face landmarks.
// @face_landmarks_detector_graph_options - face landmarks detector graph
// options used to detect face landmarks within the RoI given by the face
// detector graph.
// @request - object to request specific face tracking outputs.
// NOTE: Outputs that were not requested won't be returned and corresponding
// parts of the graph won't be generated.
// @graph - graph to update.
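//
// Example (a hypothetical sketch; `image` and `pose_face_landmarks` are
// pre-existing streams and both options protos are assumed to be configured
// with valid models):
//
//   HolisticFaceTrackingRequest request;
//   request.classifications = true;
//   MP_ASSIGN_OR_RETURN(
//       HolisticFaceTrackingOutput output,
//       TrackHolisticFace(image, pose_face_landmarks, detector_options,
//                         landmarks_detector_options, request, graph));
//   output.landmarks.value() >> graph.Out("FACE_LANDMARKS");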
absl::StatusOr<HolisticFaceTrackingOutput> TrackHolisticFace(
api2::builder::Stream<Image> image,
api2::builder::Stream<mediapipe::NormalizedLandmarkList>
pose_face_landmarks,
const face_detector::proto::FaceDetectorGraphOptions&
face_detector_graph_options,
const face_landmarker::proto::FaceLandmarksDetectorGraphOptions&
face_landmarks_detector_graph_options,
const HolisticFaceTrackingRequest& request,
mediapipe::api2::builder::Graph& graph);
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_FACE_TRACKING_H_

View File

@ -0,0 +1,227 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h"
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h"
#include "mediapipe/calculators/util/rect_to_render_data_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/api2/stream/split.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_connections.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.pb.h"
#include "mediapipe/tasks/cc/vision/utils/data_renderer.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
namespace {
using ::mediapipe::Image;
using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::SplitToRanges;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::core::ModelAssetBundleResources;
using ::mediapipe::tasks::core::TaskRunner;
using ::mediapipe::tasks::core::proto::ExternalFile;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;
constexpr float kAbsMargin = 0.015;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kTestImageFile[] = "male_full_height_hands.jpg";
constexpr char kHolisticResultFile[] =
"male_full_height_hands_result_cpu.pbtxt";
constexpr char kImageInStream[] = "image_in";
constexpr char kPoseLandmarksInStream[] = "pose_landmarks_in";
constexpr char kFaceLandmarksOutStream[] = "face_landmarks_out";
constexpr char kRenderedImageOutStream[] = "rendered_image_out";
constexpr char kFaceDetectorTFLiteName[] = "face_detector.tflite";
constexpr char kFaceLandmarksDetectorTFLiteName[] =
"face_landmarks_detector.tflite";
std::string GetFilePath(absl::string_view filename) {
return file::JoinPath("./", kTestDataDirectory, filename);
}
mediapipe::LandmarksToRenderDataCalculatorOptions GetFaceRendererOptions() {
mediapipe::LandmarksToRenderDataCalculatorOptions render_options;
for (const auto& connection :
face_landmarker::FaceLandmarksConnections::kFaceLandmarksConnectors) {
render_options.add_landmark_connections(connection[0]);
render_options.add_landmark_connections(connection[1]);
}
render_options.mutable_landmark_color()->set_r(255);
render_options.mutable_landmark_color()->set_g(255);
render_options.mutable_landmark_color()->set_b(255);
render_options.mutable_connection_color()->set_r(255);
render_options.mutable_connection_color()->set_g(255);
render_options.mutable_connection_color()->set_b(255);
render_options.set_thickness(0.5);
render_options.set_visualize_landmark_depth(false);
return render_options;
}
absl::StatusOr<std::unique_ptr<ModelAssetBundleResources>>
CreateModelAssetBundleResources(const std::string& model_asset_filename) {
auto external_model_bundle = std::make_unique<ExternalFile>();
external_model_bundle->set_file_name(model_asset_filename);
return ModelAssetBundleResources::Create("",
std::move(external_model_bundle));
}
// Helper function to create a TaskRunner.
absl::StatusOr<std::unique_ptr<tasks::core::TaskRunner>> CreateTaskRunner() {
Graph graph;
Stream<Image> image = graph.In("IMAGE").Cast<Image>().SetName(kImageInStream);
Stream<mediapipe::NormalizedLandmarkList> pose_landmarks =
graph.In("POSE_LANDMARKS")
.Cast<mediapipe::NormalizedLandmarkList>()
.SetName(kPoseLandmarksInStream);
Stream<NormalizedLandmarkList> face_landmarks_from_pose =
SplitToRanges(pose_landmarks, {{0, 11}}, graph)[0];
// Create face landmarker model bundle.
MP_ASSIGN_OR_RETURN(
auto model_bundle,
CreateModelAssetBundleResources(GetFilePath("face_landmarker_v2.task")));
face_detector::proto::FaceDetectorGraphOptions detector_options;
face_landmarker::proto::FaceLandmarksDetectorGraphOptions
landmarks_detector_options;
// Set face detection model.
MP_ASSIGN_OR_RETURN(auto face_detector_model_file,
model_bundle->GetFile(kFaceDetectorTFLiteName));
core::proto::FilePointerMeta face_detection_file_pointer;
face_detection_file_pointer.set_pointer(
reinterpret_cast<uint64_t>(face_detector_model_file.data()));
face_detection_file_pointer.set_length(face_detector_model_file.size());
detector_options.mutable_base_options()
->mutable_model_asset()
->mutable_file_pointer_meta()
->Swap(&face_detection_file_pointer);
detector_options.set_num_faces(1);
// Set face landmarks model.
MP_ASSIGN_OR_RETURN(auto face_landmarks_model_file,
model_bundle->GetFile(kFaceLandmarksDetectorTFLiteName));
core::proto::FilePointerMeta face_landmarks_detector_file_pointer;
face_landmarks_detector_file_pointer.set_pointer(
reinterpret_cast<uint64_t>(face_landmarks_model_file.data()));
face_landmarks_detector_file_pointer.set_length(
face_landmarks_model_file.size());
landmarks_detector_options.mutable_base_options()
->mutable_model_asset()
->mutable_file_pointer_meta()
->Swap(&face_landmarks_detector_file_pointer);
// Track holistic face.
HolisticFaceTrackingRequest request;
MP_ASSIGN_OR_RETURN(
HolisticFaceTrackingOutput result,
TrackHolisticFace(image, face_landmarks_from_pose, detector_options,
landmarks_detector_options, request, graph));
auto face_landmarks =
result.landmarks.value().SetName(kFaceLandmarksOutStream);
auto image_size = GetImageSize(image, graph);
auto render_scale = utils::GetRenderScale(
image_size, result.debug_output.roi_from_pose, 0.0001, graph);
auto face_landmarks_render_data = utils::RenderLandmarks(
face_landmarks, render_scale, GetFaceRendererOptions(), graph);
std::vector<Stream<mediapipe::RenderData>> render_list = {
face_landmarks_render_data};
auto rendered_image =
utils::Render(
image, absl::Span<Stream<mediapipe::RenderData>>(render_list), graph)
.SetName(kRenderedImageOutStream);
face_landmarks >> graph.Out("FACE_LANDMARKS");
rendered_image >> graph.Out("RENDERED_IMAGE");
auto config = graph.GetConfig();
core::FixGraphBackEdges(config);
return TaskRunner::Create(
config, std::make_unique<core::MediaPipeBuiltinOpResolver>());
}
class HolisticFaceTrackingTest : public ::testing::Test {};
TEST_F(HolisticFaceTrackingTest, SmokeTest) {
MP_ASSERT_OK_AND_ASSIGN(Image image,
DecodeImageFromFile(GetFilePath(kTestImageFile)));
proto::HolisticResult holistic_result;
MP_ASSERT_OK(GetTextProto(GetFilePath(kHolisticResultFile), &holistic_result,
::file::Defaults()));
MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner());
MP_ASSERT_OK_AND_ASSIGN(
auto output_packets,
task_runner->Process(
{{kImageInStream, MakePacket<Image>(image)},
{kPoseLandmarksInStream, MakePacket<NormalizedLandmarkList>(
holistic_result.pose_landmarks())}}));
ASSERT_TRUE(output_packets.find(kFaceLandmarksOutStream) !=
output_packets.end());
auto face_landmarks = output_packets.find(kFaceLandmarksOutStream)
->second.Get<NormalizedLandmarkList>();
EXPECT_THAT(
face_landmarks,
Approximately(Partially(EqualsProto(holistic_result.face_landmarks())),
/*margin=*/kAbsMargin));
auto rendered_image = output_packets.at(kRenderedImageOutStream).Get<Image>();
MP_EXPECT_OK(SavePngTestOutput(*rendered_image.GetImageFrameSharedPtr(),
"holistic_face_landmarks"));
}
} // namespace
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,272 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h"
#include <functional>
#include <optional>
#include <utility>
#include "absl/status/statusor.h"
#include "mediapipe/calculators/util/align_hand_to_pose_in_world_calculator.h"
#include "mediapipe/calculators/util/align_hand_to_pose_in_world_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/api2/stream/landmarks_to_detection.h"
#include "mediapipe/framework/api2/stream/loopback.h"
#include "mediapipe/framework/api2/stream/rect_transformation.h"
#include "mediapipe/framework/api2/stream/split.h"
#include "mediapipe/framework/api2/stream/threshold.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.pb.h"
#include "mediapipe/tasks/cc/components/utils/gate.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
namespace {
using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::AlignHandToPoseInWorldCalculator;
using ::mediapipe::api2::builder::ConvertLandmarksToDetection;
using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::GetLoopbackData;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::IsOverThreshold;
using ::mediapipe::api2::builder::ScaleAndShiftAndMakeSquareLong;
using ::mediapipe::api2::builder::SplitAndCombine;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::components::utils::AllowIf;
struct HandLandmarksResult {
std::optional<Stream<NormalizedLandmarkList>> landmarks;
std::optional<Stream<LandmarkList>> world_landmarks;
};
Stream<LandmarkList> AlignHandToPoseInWorldCalculator(
Stream<LandmarkList> hand_world_landmarks,
Stream<LandmarkList> pose_world_landmarks, int pose_wrist_idx,
Graph& graph) {
auto& node = graph.AddNode("AlignHandToPoseInWorldCalculator");
auto& opts = node.GetOptions<AlignHandToPoseInWorldCalculatorOptions>();
opts.set_hand_wrist_idx(0);
opts.set_pose_wrist_idx(pose_wrist_idx);
hand_world_landmarks.ConnectTo(
node[AlignHandToPoseInWorldCalculator::kInHandLandmarks]);
pose_world_landmarks.ConnectTo(
node[AlignHandToPoseInWorldCalculator::kInPoseLandmarks]);
return node[AlignHandToPoseInWorldCalculator::kOutHandLandmarks];
}
Stream<bool> GetPosePalmVisibility(
Stream<NormalizedLandmarkList> pose_palm_landmarks, Graph& graph) {
// Get wrist landmark.
auto pose_wrist = SplitAndCombine(pose_palm_landmarks, {0}, graph);
// Get visibility score.
auto& score_node = graph.AddNode("LandmarkVisibilityCalculator");
pose_wrist.ConnectTo(score_node.In("NORM_LANDMARKS"));
Stream<float> score = score_node.Out("VISIBILITY").Cast<float>();
// Convert score into flag.
return IsOverThreshold(score, /*threshold=*/0.1, graph);
}
Stream<NormalizedRect> GetHandRoiFromPosePalmLandmarks(
Stream<NormalizedLandmarkList> pose_palm_landmarks,
Stream<std::pair<int, int>> image_size, Graph& graph) {
// Convert pose palm landmarks to detection.
auto detection = ConvertLandmarksToDetection(pose_palm_landmarks, graph);
// Convert detection to rect.
auto& rect_node = graph.AddNode("HandDetectionsFromPoseToRectsCalculator");
detection.ConnectTo(rect_node.In("DETECTION"));
image_size.ConnectTo(rect_node.In("IMAGE_SIZE"));
Stream<NormalizedRect> rect =
rect_node.Out("NORM_RECT").Cast<NormalizedRect>();
return ScaleAndShiftAndMakeSquareLong(rect, image_size,
/*scale_x_factor=*/2.7,
/*scale_y_factor=*/2.7, /*shift_x=*/0,
/*shift_y=*/-0.1, graph);
}
absl::StatusOr<Stream<NormalizedRect>> RefineHandRoi(
Stream<Image> image, Stream<NormalizedRect> roi,
const hand_landmarker::proto::HandRoiRefinementGraphOptions&
        hand_roi_refinement_graph_options,
Graph& graph) {
auto& hand_roi_refinement = graph.AddNode(
"mediapipe.tasks.vision.hand_landmarker.HandRoiRefinementGraph");
hand_roi_refinement
.GetOptions<hand_landmarker::proto::HandRoiRefinementGraphOptions>() =
      hand_roi_refinement_graph_options;
image >> hand_roi_refinement.In("IMAGE");
roi >> hand_roi_refinement.In("NORM_RECT");
return hand_roi_refinement.Out("NORM_RECT").Cast<NormalizedRect>();
}
Stream<NormalizedRect> TrackHandRoi(
Stream<NormalizedLandmarkList> prev_landmarks, Stream<NormalizedRect> roi,
Stream<std::pair<int, int>> image_size, Graph& graph) {
// Convert hand landmarks to tight rect.
auto& prev_rect_node = graph.AddNode("HandLandmarksToRectCalculator");
prev_landmarks.ConnectTo(prev_rect_node.In("NORM_LANDMARKS"));
image_size.ConnectTo(prev_rect_node.In("IMAGE_SIZE"));
Stream<NormalizedRect> prev_rect =
prev_rect_node.Out("NORM_RECT").Cast<NormalizedRect>();
// Convert tight hand rect to hand roi.
Stream<NormalizedRect> prev_roi =
ScaleAndShiftAndMakeSquareLong(prev_rect, image_size,
/*scale_x_factor=*/2.0,
/*scale_y_factor=*/2.0, /*shift_x=*/0,
/*shift_y=*/-0.1, graph);
auto& tracking_node = graph.AddNode("RoiTrackingCalculator");
auto& tracking_node_opts =
tracking_node.GetOptions<RoiTrackingCalculatorOptions>();
auto* rect_requirements = tracking_node_opts.mutable_rect_requirements();
rect_requirements->set_rotation_degrees(40.0);
rect_requirements->set_translation(0.2);
rect_requirements->set_scale(0.4);
auto* landmarks_requirements =
tracking_node_opts.mutable_landmarks_requirements();
landmarks_requirements->set_recrop_rect_margin(-0.1);
prev_landmarks.ConnectTo(tracking_node.In("PREV_LANDMARKS"));
prev_roi.ConnectTo(tracking_node.In("PREV_LANDMARKS_RECT"));
roi.ConnectTo(tracking_node.In("RECROP_RECT"));
image_size.ConnectTo(tracking_node.In("IMAGE_SIZE"));
return tracking_node.Out("TRACKING_RECT").Cast<NormalizedRect>();
}
HandLandmarksResult GetHandLandmarksDetection(
Stream<Image> image, Stream<NormalizedRect> roi,
const hand_landmarker::proto::HandLandmarksDetectorGraphOptions&
hand_landmarks_detector_graph_options,
const HolisticHandTrackingRequest& request, Graph& graph) {
HandLandmarksResult result;
auto& hand_landmarks_detector_graph = graph.AddNode(
"mediapipe.tasks.vision.hand_landmarker."
"SingleHandLandmarksDetectorGraph");
hand_landmarks_detector_graph
.GetOptions<hand_landmarker::proto::HandLandmarksDetectorGraphOptions>() =
hand_landmarks_detector_graph_options;
image >> hand_landmarks_detector_graph.In("IMAGE");
roi >> hand_landmarks_detector_graph.In("HAND_RECT");
if (request.landmarks) {
result.landmarks = hand_landmarks_detector_graph.Out("LANDMARKS")
.Cast<NormalizedLandmarkList>();
}
if (request.world_landmarks) {
result.world_landmarks =
hand_landmarks_detector_graph.Out("WORLD_LANDMARKS")
.Cast<LandmarkList>();
}
return result;
}
} // namespace
absl::StatusOr<HolisticHandTrackingOutput> TrackHolisticHand(
Stream<Image> image, Stream<NormalizedLandmarkList> pose_landmarks,
Stream<LandmarkList> pose_world_landmarks,
const hand_landmarker::proto::HandLandmarksDetectorGraphOptions&
hand_landmarks_detector_graph_options,
const hand_landmarker::proto::HandRoiRefinementGraphOptions&
hand_roi_refinement_graph_options,
const PoseIndices& pose_indices, const HolisticHandTrackingRequest& request,
Graph& graph) {
// Extracts pose palm landmarks.
Stream<NormalizedLandmarkList> pose_palm_landmarks = SplitAndCombine(
pose_landmarks,
{pose_indices.wrist_idx, pose_indices.pinky_idx, pose_indices.index_idx},
graph);
// Get pose palm visibility.
Stream<bool> is_pose_palm_visible =
GetPosePalmVisibility(pose_palm_landmarks, graph);
// Drop pose palm landmarks if pose palm is invisible.
pose_palm_landmarks =
AllowIf(pose_palm_landmarks, is_pose_palm_visible, graph);
// Extracts the image size from the input image.
Stream<std::pair<int, int>> image_size = GetImageSize(image, graph);
// Get hand ROI from pose palm landmarks.
Stream<NormalizedRect> roi_from_pose =
GetHandRoiFromPosePalmLandmarks(pose_palm_landmarks, image_size, graph);
// Refine hand ROI with re-crop model.
MP_ASSIGN_OR_RETURN(Stream<NormalizedRect> roi_from_recrop,
RefineHandRoi(image, roi_from_pose,
hand_roi_refinement_graph_options, graph));
// Loopback for previous frame landmarks.
auto [prev_landmarks, set_prev_landmarks_fn] =
GetLoopbackData<NormalizedLandmarkList>(/*tick=*/image_size, graph);
// Track hand ROI.
auto tracking_roi =
TrackHandRoi(prev_landmarks, roi_from_recrop, image_size, graph);
// Predict hand landmarks.
auto landmarks_detection_result = GetHandLandmarksDetection(
image, tracking_roi, hand_landmarks_detector_graph_options, request,
graph);
// Set previous landmarks for ROI tracking.
set_prev_landmarks_fn(landmarks_detection_result.landmarks.value());
// Output landmarks.
std::optional<Stream<NormalizedLandmarkList>> hand_landmarks;
if (request.landmarks) {
hand_landmarks = landmarks_detection_result.landmarks;
}
// Output world landmarks.
std::optional<Stream<LandmarkList>> hand_world_landmarks;
if (request.world_landmarks) {
hand_world_landmarks = landmarks_detection_result.world_landmarks;
// Align hand world landmarks with pose world landmarks.
hand_world_landmarks = AlignHandToPoseInWorldCalculator(
hand_world_landmarks.value(), pose_world_landmarks,
pose_indices.wrist_idx, graph);
}
return {{.landmarks = hand_landmarks,
.world_landmarks = hand_world_landmarks,
.debug_output = {
.roi_from_pose = roi_from_pose,
.roi_from_recrop = roi_from_recrop,
.tracking_roi = tracking_roi,
}}};
}
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,94 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_HAND_TRACKING_H_
#define MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_HAND_TRACKING_H_
#include <optional>
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
struct PoseIndices {
int wrist_idx;
int pinky_idx;
int index_idx;
};
struct HolisticHandTrackingRequest {
bool landmarks = false;
bool world_landmarks = false;
};
struct HolisticHandTrackingOutput {
std::optional<api2::builder::Stream<mediapipe::NormalizedLandmarkList>>
landmarks;
std::optional<api2::builder::Stream<mediapipe::LandmarkList>> world_landmarks;
struct DebugOutput {
api2::builder::Stream<mediapipe::NormalizedRect> roi_from_pose;
api2::builder::Stream<mediapipe::NormalizedRect> roi_from_recrop;
api2::builder::Stream<mediapipe::NormalizedRect> tracking_roi;
};
DebugOutput debug_output;
};
// Updates @graph to track a single hand in @image based on pose landmarks.
//
// To track a single hand, this subgraph uses pose palm landmarks to obtain an
// approximate hand location, refines it with the re-crop model and then runs
// the hand landmarks model. It can also reuse the hand ROI from the previous
// frame if the hand hasn't moved too much.
//
// @image - Image to track a single hand in.
// @pose_landmarks - Pose landmarks to derive initial hand location from.
// @pose_world_landmarks - Pose world landmarks to align hand world landmarks
// wrist with.
// @hand_landmarks_detector_graph_options - Options of the
// HandLandmarksDetectorGraph used to detect the hand landmarks.
// @hand_roi_refinement_graph_options - Options of the HandRoiRefinementGraph
// used to refine the hand RoIs obtained from the pose landmarks.
// @request - object to request specific hand tracking outputs.
// NOTE: Outputs that were not requested won't be returned and corresponding
// parts of the graph won't be generated.
// @graph - graph to update.
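//
// Example (a hypothetical sketch for the left hand; indices 15/17/19 are the
// left wrist/pinky/index landmarks in the 33-point pose topology, and both
// options protos are assumed to be configured with valid models):
//
//   HolisticHandTrackingRequest request;
//   request.landmarks = true;
//   MP_ASSIGN_OR_RETURN(
//       HolisticHandTrackingOutput output,
//       TrackHolisticHand(image, pose_landmarks, pose_world_landmarks,
//                         hand_landmarks_detector_options,
//                         hand_roi_refinement_options,
//                         PoseIndices{/*wrist_idx=*/15, /*pinky_idx=*/17,
//                                     /*index_idx=*/19},
//                         request, graph));
//   output.landmarks.value() >> graph.Out("LEFT_HAND_LANDMARKS");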
absl::StatusOr<HolisticHandTrackingOutput> TrackHolisticHand(
api2::builder::Stream<Image> image,
api2::builder::Stream<mediapipe::NormalizedLandmarkList> pose_landmarks,
api2::builder::Stream<mediapipe::LandmarkList> pose_world_landmarks,
const hand_landmarker::proto::HandLandmarksDetectorGraphOptions&
hand_landmarks_detector_graph_options,
const hand_landmarker::proto::HandRoiRefinementGraphOptions&
hand_roi_refinement_graph_options,
const PoseIndices& pose_indices, const HolisticHandTrackingRequest& request,
mediapipe::api2::builder::Graph& graph);
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_HAND_TRACKING_H_

View File

@ -0,0 +1,303 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h"
#include <memory>
#include <string>
#include <vector>
#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/substitute.h"
#include "absl/types/span.h"
#include "file/base/helpers.h"
#include "file/base/options.h"
#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_connections.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_topology.h"
#include "mediapipe/tasks/cc/vision/utils/data_renderer.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
namespace {
using ::file::Defaults;
using ::file::GetTextProto;
using ::mediapipe::Image;
using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::core::TaskRunner;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;
constexpr float kAbsMargin = 0.018;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kHolisticHandTrackingLeft[] =
"holistic_hand_tracking_left_hand_graph.pbtxt";
constexpr char kTestImageFile[] = "male_full_height_hands.jpg";
constexpr char kHolisticResultFile[] =
"male_full_height_hands_result_cpu.pbtxt";
constexpr char kImageInStream[] = "image_in";
constexpr char kPoseLandmarksInStream[] = "pose_landmarks_in";
constexpr char kPoseWorldLandmarksInStream[] = "pose_world_landmarks_in";
constexpr char kLeftHandLandmarksOutStream[] = "left_hand_landmarks_out";
constexpr char kLeftHandWorldLandmarksOutStream[] =
"left_hand_world_landmarks_out";
constexpr char kRightHandLandmarksOutStream[] = "right_hand_landmarks_out";
constexpr char kRenderedImageOutStream[] = "rendered_image_out";
constexpr char kHandLandmarksModelFile[] = "hand_landmark_full.tflite";
constexpr char kHandRoiRefinementModelFile[] =
"handrecrop_2020_07_21_v0.f16.tflite";
std::string GetFilePath(const std::string& filename) {
return file::JoinPath("./", kTestDataDirectory, filename);
}
mediapipe::LandmarksToRenderDataCalculatorOptions GetHandRendererOptions() {
mediapipe::LandmarksToRenderDataCalculatorOptions renderer_options;
for (const auto& connection : hand_landmarker::kHandConnections) {
renderer_options.add_landmark_connections(connection[0]);
renderer_options.add_landmark_connections(connection[1]);
}
renderer_options.mutable_landmark_color()->set_r(255);
renderer_options.mutable_landmark_color()->set_g(255);
renderer_options.mutable_landmark_color()->set_b(255);
renderer_options.mutable_connection_color()->set_r(255);
renderer_options.mutable_connection_color()->set_g(255);
renderer_options.mutable_connection_color()->set_b(255);
renderer_options.set_thickness(0.5);
renderer_options.set_visualize_landmark_depth(false);
return renderer_options;
}
void ConfigHandTrackingModelsOptions(
hand_landmarker::proto::HandLandmarksDetectorGraphOptions&
hand_landmarks_detector_graph_options,
hand_landmarker::proto::HandRoiRefinementGraphOptions&
hand_roi_refinement_options) {
hand_landmarks_detector_graph_options.mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath(kHandLandmarksModelFile));
hand_roi_refinement_options.mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath(kHandRoiRefinementModelFile));
}
// Helper function to create a TaskRunner.
absl::StatusOr<std::unique_ptr<tasks::core::TaskRunner>> CreateTaskRunner() {
Graph graph;
Stream<Image> image = graph.In("IMAGE").Cast<Image>().SetName(kImageInStream);
Stream<mediapipe::NormalizedLandmarkList> pose_landmarks =
graph.In("POSE_LANDMARKS")
.Cast<mediapipe::NormalizedLandmarkList>()
.SetName(kPoseLandmarksInStream);
Stream<mediapipe::LandmarkList> pose_world_landmarks =
graph.In("POSE_WORLD_LANDMARKS")
.Cast<mediapipe::LandmarkList>()
.SetName(kPoseWorldLandmarksInStream);
hand_landmarker::proto::HandLandmarksDetectorGraphOptions
hand_landmarks_detector_options;
hand_landmarker::proto::HandRoiRefinementGraphOptions
hand_roi_refinement_options;
ConfigHandTrackingModelsOptions(hand_landmarks_detector_options,
hand_roi_refinement_options);
HolisticHandTrackingRequest request;
request.landmarks = true;
MP_ASSIGN_OR_RETURN(
HolisticHandTrackingOutput left_hand_result,
TrackHolisticHand(
image, pose_landmarks, pose_world_landmarks,
hand_landmarks_detector_options, hand_roi_refinement_options,
PoseIndices{
/*wrist_idx=*/static_cast<int>(
pose_landmarker::PoseLandmarkName::kLeftWrist),
/*pinky_idx=*/
static_cast<int>(pose_landmarker::PoseLandmarkName::kLeftPinky1),
/*index_idx=*/
static_cast<int>(pose_landmarker::PoseLandmarkName::kLeftIndex1)},
request, graph));
MP_ASSIGN_OR_RETURN(
HolisticHandTrackingOutput right_hand_result,
TrackHolisticHand(
image, pose_landmarks, pose_world_landmarks,
hand_landmarks_detector_options, hand_roi_refinement_options,
PoseIndices{
/*wrist_idx=*/static_cast<int>(
pose_landmarker::PoseLandmarkName::kRightWrist),
/*pinky_idx=*/
static_cast<int>(pose_landmarker::PoseLandmarkName::kRightPinky1),
/*index_idx=*/
static_cast<int>(
pose_landmarker::PoseLandmarkName::kRightIndex1)},
request, graph));
auto image_size = GetImageSize(image, graph);
auto left_hand_landmarks_render_data = utils::RenderLandmarks(
*left_hand_result.landmarks,
utils::GetRenderScale(image_size,
left_hand_result.debug_output.roi_from_pose, 0.0001,
graph),
GetHandRendererOptions(), graph);
auto right_hand_landmarks_render_data = utils::RenderLandmarks(
*right_hand_result.landmarks,
utils::GetRenderScale(image_size,
right_hand_result.debug_output.roi_from_pose,
0.0001, graph),
GetHandRendererOptions(), graph);
std::vector<Stream<mediapipe::RenderData>> render_list = {
left_hand_landmarks_render_data, right_hand_landmarks_render_data};
auto rendered_image =
utils::Render(
image, absl::Span<Stream<mediapipe::RenderData>>(render_list), graph)
.SetName(kRenderedImageOutStream);
left_hand_result.landmarks->SetName(kLeftHandLandmarksOutStream) >>
graph.Out("LEFT_HAND_LANDMARKS");
right_hand_result.landmarks->SetName(kRightHandLandmarksOutStream) >>
graph.Out("RIGHT_HAND_LANDMARKS");
rendered_image >> graph.Out("RENDERED_IMAGE");
auto config = graph.GetConfig();
core::FixGraphBackEdges(config);
return TaskRunner::Create(
config, std::make_unique<core::MediaPipeBuiltinOpResolver>());
}
class HolisticHandTrackingTest : public ::testing::Test {};
TEST_F(HolisticHandTrackingTest, VerifyGraph) {
Graph graph;
Stream<Image> image = graph.In("IMAGE").Cast<Image>().SetName(kImageInStream);
Stream<mediapipe::NormalizedLandmarkList> pose_landmarks =
graph.In("POSE_LANDMARKS")
.Cast<mediapipe::NormalizedLandmarkList>()
.SetName(kPoseLandmarksInStream);
Stream<mediapipe::LandmarkList> pose_world_landmarks =
graph.In("POSE_WORLD_LANDMARKS")
.Cast<mediapipe::LandmarkList>()
.SetName(kPoseWorldLandmarksInStream);
hand_landmarker::proto::HandLandmarksDetectorGraphOptions
hand_landmarks_detector_options;
hand_landmarker::proto::HandRoiRefinementGraphOptions
hand_roi_refinement_options;
ConfigHandTrackingModelsOptions(hand_landmarks_detector_options,
hand_roi_refinement_options);
HolisticHandTrackingRequest request;
request.landmarks = true;
request.world_landmarks = true;
MP_ASSERT_OK_AND_ASSIGN(
HolisticHandTrackingOutput left_hand_result,
TrackHolisticHand(
image, pose_landmarks, pose_world_landmarks,
hand_landmarks_detector_options, hand_roi_refinement_options,
PoseIndices{
/*wrist_idx=*/static_cast<int>(
pose_landmarker::PoseLandmarkName::kLeftWrist),
/*pinky_idx=*/
static_cast<int>(pose_landmarker::PoseLandmarkName::kLeftPinky1),
/*index_idx=*/
static_cast<int>(pose_landmarker::PoseLandmarkName::kLeftIndex1)},
request, graph));
left_hand_result.landmarks->SetName(kLeftHandLandmarksOutStream) >>
graph.Out("LEFT_HAND_LANDMARKS");
left_hand_result.world_landmarks->SetName(kLeftHandWorldLandmarksOutStream) >>
graph.Out("LEFT_HAND_WORLD_LANDMARKS");
// Read the expected graph config.
std::string expected_graph_contents;
MP_ASSERT_OK(file::GetContents(
file::JoinPath("./", kTestDataDirectory, kHolisticHandTrackingLeft),
&expected_graph_contents));
// Substitute the test srcdir into the expected graph config, because each
// run has a different test dir on TAP.
expected_graph_contents = absl::Substitute(
expected_graph_contents, FLAGS_test_srcdir, FLAGS_test_srcdir);
CalculatorGraphConfig expected_graph =
ParseTextProtoOrDie<CalculatorGraphConfig>(expected_graph_contents);
EXPECT_THAT(graph.GetConfig(), testing::proto::IgnoringRepeatedFieldOrdering(
testing::EqualsProto(expected_graph)));
}
TEST_F(HolisticHandTrackingTest, SmokeTest) {
MP_ASSERT_OK_AND_ASSIGN(Image image,
DecodeImageFromFile(GetFilePath(kTestImageFile)));
proto::HolisticResult holistic_result;
MP_ASSERT_OK(GetTextProto(GetFilePath(kHolisticResultFile), &holistic_result,
Defaults()));
MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner());
MP_ASSERT_OK_AND_ASSIGN(
auto output_packets,
task_runner->Process(
{{kImageInStream, MakePacket<Image>(image)},
{kPoseLandmarksInStream, MakePacket<NormalizedLandmarkList>(
holistic_result.pose_landmarks())},
{kPoseWorldLandmarksInStream,
MakePacket<LandmarkList>(
holistic_result.pose_world_landmarks())}}));
auto left_hand_landmarks = output_packets.at(kLeftHandLandmarksOutStream)
.Get<NormalizedLandmarkList>();
auto right_hand_landmarks = output_packets.at(kRightHandLandmarksOutStream)
.Get<NormalizedLandmarkList>();
EXPECT_THAT(left_hand_landmarks,
Approximately(
Partially(EqualsProto(holistic_result.left_hand_landmarks())),
/*margin=*/kAbsMargin));
EXPECT_THAT(
right_hand_landmarks,
Approximately(
Partially(EqualsProto(holistic_result.right_hand_landmarks())),
/*margin=*/kAbsMargin));
auto rendered_image = output_packets.at(kRenderedImageOutStream).Get<Image>();
MP_EXPECT_OK(SavePngTestOutput(*rendered_image.GetImageFrameSharedPtr(),
"holistic_hand_landmarks"));
}
} // namespace
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,521 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/split.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_topology.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
#include "mediapipe/util/graph_builder_utils.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
namespace {
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::metadata::SetExternalFile;
constexpr absl::string_view kHandLandmarksDetectorModelName =
"hand_landmarks_detector.tflite";
constexpr absl::string_view kHandRoiRefinementModelName =
"hand_roi_refinement.tflite";
constexpr absl::string_view kFaceDetectorModelName = "face_detector.tflite";
constexpr absl::string_view kFaceLandmarksDetectorModelName =
"face_landmarks_detector.tflite";
constexpr absl::string_view kFaceBlendshapesModelName =
"face_blendshapes.tflite";
constexpr absl::string_view kPoseDetectorModelName = "pose_detector.tflite";
constexpr absl::string_view kPoseLandmarksDetectorModelName =
"pose_landmarks_detector.tflite";
absl::Status SetGraphPoseOutputs(
const HolisticPoseTrackingRequest& pose_request,
const CalculatorGraphConfig::Node& node,
HolisticPoseTrackingOutput& pose_output, Graph& graph) {
// Main outputs.
if (pose_request.landmarks) {
RET_CHECK(pose_output.landmarks.has_value())
<< "POSE_LANDMARKS output is not supported.";
pose_output.landmarks->ConnectTo(graph.Out("POSE_LANDMARKS"));
}
if (pose_request.world_landmarks) {
RET_CHECK(pose_output.world_landmarks.has_value())
<< "POSE_WORLD_LANDMARKS output is not supported.";
pose_output.world_landmarks->ConnectTo(graph.Out("POSE_WORLD_LANDMARKS"));
}
if (pose_request.segmentation_mask) {
RET_CHECK(pose_output.segmentation_mask.has_value())
<< "POSE_SEGMENTATION_MASK output is not supported.";
pose_output.segmentation_mask->ConnectTo(
graph.Out("POSE_SEGMENTATION_MASK"));
}
// Debug outputs.
if (HasOutput(node, "POSE_AUXILIARY_LANDMARKS")) {
pose_output.debug_output.auxiliary_landmarks.ConnectTo(
graph.Out("POSE_AUXILIARY_LANDMARKS"));
}
if (HasOutput(node, "POSE_LANDMARKS_ROI")) {
pose_output.debug_output.roi_from_landmarks.ConnectTo(
graph.Out("POSE_LANDMARKS_ROI"));
}
return absl::OkStatus();
}
// Sets the base options for the sub-tasks.
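// When loading from a model bundle, each sub-task that has no explicit model
// asset receives its model bytes (or a pointer to them, depending on
// `is_copy`) from the corresponding file inside the bundle, while
// acceleration, stream mode, and GPU origin are always inherited from the
// holistic base options.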
template <typename T>
absl::Status SetSubTaskBaseOptions(
const core::ModelAssetBundleResources* resources,
proto::HolisticLandmarkerGraphOptions* options, T* sub_task_options,
absl::string_view model_name, bool is_copy) {
if (!sub_task_options->base_options().has_model_asset()) {
MP_ASSIGN_OR_RETURN(const auto model_file_content,
resources->GetFile(std::string(model_name)));
SetExternalFile(
model_file_content,
sub_task_options->mutable_base_options()->mutable_model_asset(),
is_copy);
}
sub_task_options->mutable_base_options()->mutable_acceleration()->CopyFrom(
options->base_options().acceleration());
sub_task_options->mutable_base_options()->set_use_stream_mode(
options->base_options().use_stream_mode());
sub_task_options->mutable_base_options()->set_gpu_origin(
options->base_options().gpu_origin());
return absl::OkStatus();
}
void SetGraphHandOutputs(bool is_left, const CalculatorGraphConfig::Node& node,
HolisticHandTrackingOutput& hand_output,
Graph& graph) {
const std::string hand_side = is_left ? "LEFT" : "RIGHT";
if (hand_output.landmarks) {
hand_output.landmarks->ConnectTo(graph.Out(hand_side + "_HAND_LANDMARKS"));
}
if (hand_output.world_landmarks) {
hand_output.world_landmarks->ConnectTo(
graph.Out(hand_side + "_HAND_WORLD_LANDMARKS"));
}
// Debug outputs.
if (HasOutput(node, hand_side + "_HAND_ROI_FROM_POSE")) {
hand_output.debug_output.roi_from_pose.ConnectTo(
graph.Out(hand_side + "_HAND_ROI_FROM_POSE"));
}
if (HasOutput(node, hand_side + "_HAND_ROI_FROM_RECROP")) {
hand_output.debug_output.roi_from_recrop.ConnectTo(
graph.Out(hand_side + "_HAND_ROI_FROM_RECROP"));
}
if (HasOutput(node, hand_side + "_HAND_TRACKING_ROI")) {
hand_output.debug_output.tracking_roi.ConnectTo(
graph.Out(hand_side + "_HAND_TRACKING_ROI"));
}
}
void SetGraphFaceOutputs(const CalculatorGraphConfig::Node& node,
HolisticFaceTrackingOutput& face_output,
Graph& graph) {
if (face_output.landmarks) {
face_output.landmarks->ConnectTo(graph.Out("FACE_LANDMARKS"));
}
if (face_output.classifications) {
face_output.classifications->ConnectTo(graph.Out("FACE_BLENDSHAPES"));
}
// Face detection debug outputs
if (HasOutput(node, "FACE_ROI_FROM_POSE")) {
face_output.debug_output.roi_from_pose.ConnectTo(
graph.Out("FACE_ROI_FROM_POSE"));
}
if (HasOutput(node, "FACE_ROI_FROM_DETECTION")) {
face_output.debug_output.roi_from_detection.ConnectTo(
graph.Out("FACE_ROI_FROM_DETECTION"));
}
if (HasOutput(node, "FACE_TRACKING_ROI")) {
face_output.debug_output.tracking_roi.ConnectTo(
graph.Out("FACE_TRACKING_ROI"));
}
}
} // namespace
// Tracks pose and detects hands and face.
//
// NOTE: on GPU this graph works only with images having GpuOrigin::TOP_LEFT
//
// Inputs:
// IMAGE - Image
// Image to perform detection on.
//
// Outputs:
// POSE_LANDMARKS - NormalizedLandmarkList
// 33 landmarks (see pose_landmarker/pose_topology.h)
// 0 - nose
// 1 - left eye (inner)
// 2 - left eye
// 3 - left eye (outer)
// 4 - right eye (inner)
// 5 - right eye
// 6 - right eye (outer)
// 7 - left ear
// 8 - right ear
// 9 - mouth (left)
// 10 - mouth (right)
// 11 - left shoulder
// 12 - right shoulder
// 13 - left elbow
// 14 - right elbow
// 15 - left wrist
// 16 - right wrist
// 17 - left pinky
// 18 - right pinky
// 19 - left index
// 20 - right index
// 21 - left thumb
// 22 - right thumb
// 23 - left hip
// 24 - right hip
// 25 - left knee
// 26 - right knee
// 27 - left ankle
// 28 - right ankle
// 29 - left heel
// 30 - right heel
// 31 - left foot index
// 32 - right foot index
// POSE_WORLD_LANDMARKS - LandmarkList
// World landmarks are real-world 3D coordinates, in meters, with the origin
// at the center of the hips. To understand the difference: the
// POSE_LANDMARKS stream provides coordinates (in pixels) of the 3D object
// projected onto the 2D surface of the image (see how perspective
// projection works), while the POSE_WORLD_LANDMARKS stream provides
// coordinates (in meters) of the 3D object itself. POSE_WORLD_LANDMARKS
// has the same landmark topology, visibility, and presence as
// POSE_LANDMARKS.
// POSE_SEGMENTATION_MASK - Image
// Separates the person from the background. On CPU the mask is stored as a
// gray float32 image with a [0.0, 1.0] range per pixel (1 for person, 0 for
// background); on GPU it is an RGBA texture with the R channel indicating
// the person vs. background probability.
// LEFT_HAND_LANDMARKS - NormalizedLandmarkList
// 21 left hand landmarks.
// RIGHT_HAND_LANDMARKS - NormalizedLandmarkList
// 21 right hand landmarks.
// FACE_LANDMARKS - NormalizedLandmarkList
// 468 face landmarks.
// FACE_BLENDSHAPES - ClassificationList
// Supplementary blendshape coefficients that are predicted directly from
// the input image.
// LEFT_HAND_WORLD_LANDMARKS - LandmarkList
// 21 left hand world 3D landmarks.
// Hand landmarks are aligned with pose landmarks: translated so that the
// wrist from the hand matches the wrist from the pose, in the pose
// coordinate system.
// RIGHT_HAND_WORLD_LANDMARKS - LandmarkList
// 21 right hand world 3D landmarks.
// Hand landmarks are aligned with pose landmarks: translated so that the
// wrist from the hand matches the wrist from the pose, in the pose
// coordinate system.
// IMAGE - Image
// The input image that the holistic landmarker runs on, with the pixel
// data stored on the target storage (CPU vs. GPU).
//
// Debug outputs:
// POSE_AUXILIARY_LANDMARKS - NormalizedLandmarkList
// TODO: Return ROI rather than auxiliary landmarks
// Auxiliary landmarks for deriving the ROI in the subsequent image.
// 0 - hidden center point
// 1 - hidden scale point
// POSE_LANDMARKS_ROI - NormalizedRect
// Region of interest calculated based on landmarks.
// LEFT_HAND_ROI_FROM_POSE - NormalizedLandmarkList
// LEFT_HAND_ROI_FROM_RECROP - NormalizedLandmarkList
// LEFT_HAND_TRACKING_ROI - NormalizedLandmarkList
// RIGHT_HAND_ROI_FROM_POSE - NormalizedLandmarkList
// RIGHT_HAND_ROI_FROM_RECROP - NormalizedLandmarkList
// RIGHT_HAND_TRACKING_ROI - NormalizedLandmarkList
// FACE_ROI_FROM_POSE - NormalizedLandmarkList
// FACE_ROI_FROM_DETECTION - NormalizedLandmarkList
// FACE_TRACKING_ROI - NormalizedLandmarkList
//
// NOTE: failure is reported if some output has been requested, but the
// specified model doesn't support it.
//
// NOTE: there will not be an output packet in an output stream for a
// particular timestamp if nothing is detected. However, the MediaPipe
// framework will internally inform the downstream calculators of the
// absence of this packet so that they don't wait for it unnecessarily.
//
// Example:
// node {
// calculator:
// "mediapipe.tasks.vision.holistic_landmarker.HolisticLandmarkerGraph"
// input_stream: "IMAGE:input_frames_image"
// output_stream: "POSE_LANDMARKS:pose_landmarks"
// output_stream: "POSE_WORLD_LANDMARKS:pose_world_landmarks"
// output_stream: "FACE_LANDMARKS:face_landmarks"
// output_stream: "FACE_BLENDSHAPES:extra_blendshapes"
// output_stream: "LEFT_HAND_LANDMARKS:left_hand_landmarks"
// output_stream: "LEFT_HAND_WORLD_LANDMARKS:left_hand_world_landmarks"
// output_stream: "RIGHT_HAND_LANDMARKS:right_hand_landmarks"
// output_stream: "RIGHT_HAND_WORLD_LANDMARKS:right_hand_world_landmarks"
// node_options {
// [type.googleapis.com/mediapipe.tasks.vision.holistic_landmarker.proto.HolisticLandmarkerGraphOptions]
// {
// base_options {
// model_asset {
// file_name:
// "mediapipe/tasks/testdata/vision/holistic_landmarker.task"
// }
// }
// face_detector_graph_options: {
// num_faces: 1
// }
// pose_detector_graph_options: {
// num_poses: 1
// }
// }
// }
// }
class HolisticLandmarkerGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
Graph graph;
const auto& holistic_node = sc->OriginalNode();
proto::HolisticLandmarkerGraphOptions* holistic_options =
sc->MutableOptions<proto::HolisticLandmarkerGraphOptions>();
const core::ModelAssetBundleResources* model_asset_bundle_resources = nullptr;
if (holistic_options->base_options().has_model_asset()) {
MP_ASSIGN_OR_RETURN(model_asset_bundle_resources,
CreateModelAssetBundleResources<
proto::HolisticLandmarkerGraphOptions>(sc));
}
// Copies the file content instead of passing the pointer of file in
// memory if the subgraph model resource service is not available.
bool create_copy =
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
.IsAvailable();
Stream<Image> image = graph.In("IMAGE").Cast<Image>();
// Check whether hand outputs are requested.
const bool is_left_hand_requested =
HasOutput(holistic_node, "LEFT_HAND_LANDMARKS");
const bool is_right_hand_requested =
HasOutput(holistic_node, "RIGHT_HAND_LANDMARKS");
const bool is_left_hand_world_requested =
HasOutput(holistic_node, "LEFT_HAND_WORLD_LANDMARKS");
const bool is_right_hand_world_requested =
HasOutput(holistic_node, "RIGHT_HAND_WORLD_LANDMARKS");
const bool hands_requested =
is_left_hand_requested || is_right_hand_requested ||
is_left_hand_world_requested || is_right_hand_world_requested;
if (hands_requested) {
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
model_asset_bundle_resources, holistic_options,
holistic_options->mutable_hand_landmarks_detector_graph_options(),
kHandLandmarksDetectorModelName, create_copy));
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
model_asset_bundle_resources, holistic_options,
holistic_options->mutable_hand_roi_refinement_graph_options(),
kHandRoiRefinementModelName, create_copy));
}
// Check whether face outputs are requested.
const bool is_face_requested = HasOutput(holistic_node, "FACE_LANDMARKS");
const bool is_face_blendshapes_requested =
HasOutput(holistic_node, "FACE_BLENDSHAPES");
const bool face_requested =
is_face_requested || is_face_blendshapes_requested;
if (face_requested) {
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
model_asset_bundle_resources, holistic_options,
holistic_options->mutable_face_detector_graph_options(),
kFaceDetectorModelName, create_copy));
// Forcibly set num_faces to 1, because the holistic landmarker only supports
// a single subject for now.
holistic_options->mutable_face_detector_graph_options()->set_num_faces(1);
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
model_asset_bundle_resources, holistic_options,
holistic_options->mutable_face_landmarks_detector_graph_options(),
kFaceLandmarksDetectorModelName, create_copy));
if (is_face_blendshapes_requested) {
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
model_asset_bundle_resources, holistic_options,
holistic_options->mutable_face_landmarks_detector_graph_options()
->mutable_face_blendshapes_graph_options(),
kFaceBlendshapesModelName, create_copy));
}
}
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
model_asset_bundle_resources, holistic_options,
holistic_options->mutable_pose_detector_graph_options(),
kPoseDetectorModelName, create_copy));
// Forcibly set num_poses to 1, because the holistic landmarker only supports
// a single subject for now.
holistic_options->mutable_pose_detector_graph_options()->set_num_poses(1);
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
model_asset_bundle_resources, holistic_options,
holistic_options->mutable_pose_landmarks_detector_graph_options(),
kPoseLandmarksDetectorModelName, create_copy));
HolisticPoseTrackingRequest pose_request = {
.landmarks = HasOutput(holistic_node, "POSE_LANDMARKS") ||
hands_requested || face_requested,
.world_landmarks =
HasOutput(holistic_node, "POSE_WORLD_LANDMARKS") || hands_requested,
.segmentation_mask =
HasOutput(holistic_node, "POSE_SEGMENTATION_MASK")};
// Detect and track pose.
MP_ASSIGN_OR_RETURN(
HolisticPoseTrackingOutput pose_output,
TrackHolisticPose(
image, holistic_options->pose_detector_graph_options(),
holistic_options->pose_landmarks_detector_graph_options(),
pose_request, graph));
MP_RETURN_IF_ERROR(
SetGraphPoseOutputs(pose_request, holistic_node, pose_output, graph));
// Detect and track hand.
if (hands_requested) {
if (is_left_hand_requested || is_left_hand_world_requested) {
RET_CHECK(pose_output.landmarks.has_value());
RET_CHECK(pose_output.world_landmarks.has_value());
PoseIndices pose_indices = {
.wrist_idx =
static_cast<int>(pose_landmarker::PoseLandmarkName::kLeftWrist),
.pinky_idx = static_cast<int>(
pose_landmarker::PoseLandmarkName::kLeftPinky1),
.index_idx = static_cast<int>(
pose_landmarker::PoseLandmarkName::kLeftIndex1),
};
HolisticHandTrackingRequest hand_request = {
.landmarks = is_left_hand_requested,
.world_landmarks = is_left_hand_world_requested,
};
MP_ASSIGN_OR_RETURN(
HolisticHandTrackingOutput hand_output,
TrackHolisticHand(
image, *pose_output.landmarks, *pose_output.world_landmarks,
holistic_options->hand_landmarks_detector_graph_options(),
holistic_options->hand_roi_refinement_graph_options(),
pose_indices, hand_request, graph
));
SetGraphHandOutputs(/*is_left=*/true, holistic_node, hand_output,
graph);
}
if (is_right_hand_requested || is_right_hand_world_requested) {
RET_CHECK(pose_output.landmarks.has_value());
RET_CHECK(pose_output.world_landmarks.has_value());
PoseIndices pose_indices = {
.wrist_idx = static_cast<int>(
pose_landmarker::PoseLandmarkName::kRightWrist),
.pinky_idx = static_cast<int>(
pose_landmarker::PoseLandmarkName::kRightPinky1),
.index_idx = static_cast<int>(
pose_landmarker::PoseLandmarkName::kRightIndex1),
};
HolisticHandTrackingRequest hand_request = {
.landmarks = is_right_hand_requested,
.world_landmarks = is_right_hand_world_requested,
};
MP_ASSIGN_OR_RETURN(
HolisticHandTrackingOutput hand_output,
TrackHolisticHand(
image, *pose_output.landmarks, *pose_output.world_landmarks,
holistic_options->hand_landmarks_detector_graph_options(),
holistic_options->hand_roi_refinement_graph_options(),
pose_indices, hand_request, graph
));
SetGraphHandOutputs(/*is_left=*/false, holistic_node, hand_output,
graph);
}
}
// Detect and track face.
if (face_requested) {
RET_CHECK(pose_output.landmarks.has_value());
Stream<mediapipe::NormalizedLandmarkList> face_landmarks_from_pose =
api2::builder::SplitToRanges(*pose_output.landmarks, {{0, 11}},
graph)[0];
HolisticFaceTrackingRequest face_request = {
.classifications = is_face_blendshapes_requested,
};
MP_ASSIGN_OR_RETURN(
HolisticFaceTrackingOutput face_output,
TrackHolisticFace(
image, face_landmarks_from_pose,
holistic_options->face_detector_graph_options(),
holistic_options->face_landmarks_detector_graph_options(),
face_request, graph));
SetGraphFaceOutputs(holistic_node, face_output, graph);
}
auto& pass_through = graph.AddNode("PassThroughCalculator");
image >> pass_through.In("");
pass_through.Out("") >> graph.Out("IMAGE");
auto config = graph.GetConfig();
core::FixGraphBackEdges(config);
return config;
}
};
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::holistic_landmarker::HolisticLandmarkerGraph);
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,595 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <array>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "file/base/helpers.h"
#include "file/base/options.h"
#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_connections.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_connections.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.pb.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_connections.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/data_renderer.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"
#include "testing/base/public/gmock.h"
#include "testing/base/public/googletest.h"
#include "testing/base/public/gunit.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
namespace {
using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::core::TaskRunner;
using ::testing::TestParamInfo;
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;
constexpr float kAbsMargin = 0.025;
constexpr absl::string_view kTestDataDirectory =
"/mediapipe/tasks/testdata/vision/";
constexpr char kHolisticResultFile[] =
"male_full_height_hands_result_cpu.pbtxt";
constexpr absl::string_view kTestImageFile = "male_full_height_hands.jpg";
constexpr absl::string_view kImageInStream = "image_in";
constexpr absl::string_view kLeftHandLandmarksStream = "left_hand_landmarks";
constexpr absl::string_view kRightHandLandmarksStream = "right_hand_landmarks";
constexpr absl::string_view kFaceLandmarksStream = "face_landmarks";
constexpr absl::string_view kFaceBlendshapesStream = "face_blendshapes";
constexpr absl::string_view kPoseLandmarksStream = "pose_landmarks";
constexpr absl::string_view kRenderedImageOutStream = "rendered_image_out";
constexpr absl::string_view kPoseSegmentationMaskStream =
"pose_segmentation_mask";
constexpr absl::string_view kHolisticLandmarkerModelBundleFile =
"holistic_landmarker.task";
constexpr absl::string_view kHandLandmarksModelFile =
"hand_landmark_full.tflite";
constexpr absl::string_view kHandRoiRefinementModelFile =
"handrecrop_2020_07_21_v0.f16.tflite";
constexpr absl::string_view kPoseDetectionModelFile = "pose_detection.tflite";
constexpr absl::string_view kPoseLandmarksModelFile =
"pose_landmark_lite.tflite";
constexpr absl::string_view kFaceDetectionModelFile =
"face_detection_short_range.tflite";
constexpr absl::string_view kFaceLandmarksModelFile =
"facemesh2_lite_iris_faceflag_2023_02_14.tflite";
constexpr absl::string_view kFaceBlendshapesModelFile =
"face_blendshapes.tflite";
enum RenderPart {
HAND = 0,
POSE = 1,
FACE = 2,
};
mediapipe::Color GetColor(RenderPart render_part) {
mediapipe::Color color;
switch (render_part) {
case HAND:
color.set_b(255);
color.set_g(255);
color.set_r(255);
break;
case POSE:
color.set_b(0);
color.set_g(255);
color.set_r(0);
break;
case FACE:
color.set_b(0);
color.set_g(0);
color.set_r(255);
break;
}
return color;
}
std::string GetFilePath(absl::string_view filename) {
return file::JoinPath("./", kTestDataDirectory, filename);
}
template <std::size_t N>
mediapipe::LandmarksToRenderDataCalculatorOptions GetRendererOptions(
const std::array<std::array<int, 2>, N>& connections,
mediapipe::Color color) {
mediapipe::LandmarksToRenderDataCalculatorOptions renderer_options;
for (const auto& connection : connections) {
renderer_options.add_landmark_connections(connection[0]);
renderer_options.add_landmark_connections(connection[1]);
}
*renderer_options.mutable_landmark_color() = color;
*renderer_options.mutable_connection_color() = color;
renderer_options.set_thickness(0.5);
renderer_options.set_visualize_landmark_depth(false);
return renderer_options;
}
void ConfigureHandProtoOptions(proto::HolisticLandmarkerGraphOptions& options) {
options.mutable_hand_landmarks_detector_graph_options()
->mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath(kHandLandmarksModelFile));
options.mutable_hand_roi_refinement_graph_options()
->mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath(kHandRoiRefinementModelFile));
}
void ConfigureFaceProtoOptions(proto::HolisticLandmarkerGraphOptions& options) {
// Set face detection model.
face_detector::proto::FaceDetectorGraphOptions& face_detector_graph_options =
*options.mutable_face_detector_graph_options();
face_detector_graph_options.mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath(kFaceDetectionModelFile));
face_detector_graph_options.set_num_faces(1);
// Set face landmarks model.
face_landmarker::proto::FaceLandmarksDetectorGraphOptions&
face_landmarks_graph_options =
*options.mutable_face_landmarks_detector_graph_options();
face_landmarks_graph_options.mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath(kFaceLandmarksModelFile));
face_landmarks_graph_options.mutable_face_blendshapes_graph_options()
->mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath(kFaceBlendshapesModelFile));
}
void ConfigurePoseProtoOptions(proto::HolisticLandmarkerGraphOptions& options) {
pose_detector::proto::PoseDetectorGraphOptions& pose_detector_graph_options =
*options.mutable_pose_detector_graph_options();
pose_detector_graph_options.mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath(kPoseDetectionModelFile));
pose_detector_graph_options.set_num_poses(1);
options.mutable_pose_landmarks_detector_graph_options()
->mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath(kPoseLandmarksModelFile));
}
struct HolisticRequest {
bool is_left_hand_requested = false;
bool is_right_hand_requested = false;
bool is_face_requested = false;
bool is_face_blendshapes_requested = false;
};
// Helper function to create a TaskRunner.
absl::StatusOr<std::unique_ptr<tasks::core::TaskRunner>> CreateTaskRunner(
bool use_model_bundle, HolisticRequest holistic_request) {
Graph graph;
Stream<Image> image = graph.In("IMAEG").Cast<Image>().SetName(kImageInStream);
auto& holistic_graph = graph.AddNode(
"mediapipe.tasks.vision.holistic_landmarker.HolisticLandmarkerGraph");
proto::HolisticLandmarkerGraphOptions& options =
holistic_graph.GetOptions<proto::HolisticLandmarkerGraphOptions>();
if (use_model_bundle) {
options.mutable_base_options()->mutable_model_asset()->set_file_name(
GetFilePath(kHolisticLandmarkerModelBundleFile));
} else {
ConfigureHandProtoOptions(options);
ConfigurePoseProtoOptions(options);
ConfigureFaceProtoOptions(options);
}
std::vector<Stream<mediapipe::RenderData>> render_list;
image >> holistic_graph.In("IMAGE");
Stream<std::pair<int, int>> image_size = GetImageSize(image, graph);
if (holistic_request.is_left_hand_requested) {
Stream<NormalizedLandmarkList> left_hand_landmarks =
holistic_graph.Out("LEFT_HAND_LANDMARKS")
.Cast<NormalizedLandmarkList>()
.SetName(kLeftHandLandmarksStream);
Stream<NormalizedRect> left_hand_tracking_roi =
holistic_graph.Out("LEFT_HAND_TRACKING_ROI").Cast<NormalizedRect>();
auto left_hand_landmarks_render_data = utils::RenderLandmarks(
left_hand_landmarks,
utils::GetRenderScale(image_size, left_hand_tracking_roi, 0.0001,
graph),
GetRendererOptions(hand_landmarker::kHandConnections,
GetColor(RenderPart::HAND)),
graph);
render_list.push_back(left_hand_landmarks_render_data);
left_hand_landmarks >> graph.Out("LEFT_HAND_LANDMARKS");
}
if (holistic_request.is_right_hand_requested) {
Stream<NormalizedLandmarkList> right_hand_landmarks =
holistic_graph.Out("RIGHT_HAND_LANDMARKS")
.Cast<NormalizedLandmarkList>()
.SetName(kRightHandLandmarksStream);
Stream<NormalizedRect> right_hand_tracking_roi =
holistic_graph.Out("RIGHT_HAND_TRACKING_ROI").Cast<NormalizedRect>();
auto right_hand_landmarks_render_data = utils::RenderLandmarks(
right_hand_landmarks,
utils::GetRenderScale(image_size, right_hand_tracking_roi, 0.0001,
graph),
GetRendererOptions(hand_landmarker::kHandConnections,
GetColor(RenderPart::HAND)),
graph);
render_list.push_back(right_hand_landmarks_render_data);
right_hand_landmarks >> graph.Out("RIGHT_HAND_LANDMARKS");
}
if (holistic_request.is_face_requested) {
Stream<NormalizedLandmarkList> face_landmarks =
holistic_graph.Out("FACE_LANDMARKS")
.Cast<NormalizedLandmarkList>()
.SetName(kFaceLandmarksStream);
Stream<NormalizedRect> face_tracking_roi =
holistic_graph.Out("FACE_TRACKING_ROI").Cast<NormalizedRect>();
auto face_landmarks_render_data = utils::RenderLandmarks(
face_landmarks,
utils::GetRenderScale(image_size, face_tracking_roi, 0.0001, graph),
GetRendererOptions(
face_landmarker::FaceLandmarksConnections::kFaceLandmarksConnectors,
GetColor(RenderPart::FACE)),
graph);
render_list.push_back(face_landmarks_render_data);
face_landmarks >> graph.Out("FACE_LANDMARKS");
}
if (holistic_request.is_face_blendshapes_requested) {
Stream<ClassificationList> face_blendshapes =
holistic_graph.Out("FACE_BLENDSHAPES")
.Cast<ClassificationList>()
.SetName(kFaceBlendshapesStream);
face_blendshapes >> graph.Out("FACE_BLENDSHAPES");
}
Stream<NormalizedLandmarkList> pose_landmarks =
holistic_graph.Out("POSE_LANDMARKS")
.Cast<NormalizedLandmarkList>()
.SetName(kPoseLandmarksStream);
Stream<NormalizedRect> pose_tracking_roi =
holistic_graph.Out("POSE_LANDMARKS_ROI").Cast<NormalizedRect>();
Stream<Image> pose_segmentation_mask =
holistic_graph.Out("POSE_SEGMENTATION_MASK")
.Cast<Image>()
.SetName(kPoseSegmentationMaskStream);
auto pose_landmarks_render_data = utils::RenderLandmarks(
pose_landmarks,
utils::GetRenderScale(image_size, pose_tracking_roi, 0.0001, graph),
GetRendererOptions(pose_landmarker::kPoseLandmarksConnections,
GetColor(RenderPart::POSE)),
graph);
render_list.push_back(pose_landmarks_render_data);
auto rendered_image =
utils::Render(
image, absl::Span<Stream<mediapipe::RenderData>>(render_list), graph)
.SetName(kRenderedImageOutStream);
pose_landmarks >> graph.Out("POSE_LANDMARKS");
pose_segmentation_mask >> graph.Out("POSE_SEGMENTATION_MASK");
rendered_image >> graph.Out("RENDERED_IMAGE");
auto config = graph.GetConfig();
core::FixGraphBackEdges(config);
return TaskRunner::Create(
config, std::make_unique<core::MediaPipeBuiltinOpResolver>());
}
template <typename T>
absl::StatusOr<T> FetchResult(const core::PacketMap& output_packets,
absl::string_view stream_name) {
auto it = output_packets.find(std::string(stream_name));
RET_CHECK(it != output_packets.end());
return it->second.Get<T>();
}
// Removes fields that are not checked in the result, since the model
// generating the expected result differs from the model under test.
void RemoveUncheckedResult(proto::HolisticResult& holistic_result) {
auto clear_unchecked_fields = [](NormalizedLandmarkList& landmarks) {
for (auto& landmark : *landmarks.mutable_landmark()) {
landmark.clear_z();
landmark.clear_visibility();
landmark.clear_presence();
}
};
clear_unchecked_fields(*holistic_result.mutable_pose_landmarks());
clear_unchecked_fields(*holistic_result.mutable_face_landmarks());
clear_unchecked_fields(*holistic_result.mutable_left_hand_landmarks());
clear_unchecked_fields(*holistic_result.mutable_right_hand_landmarks());
}
std::string RequestToString(HolisticRequest request) {
return absl::StrFormat(
"%s_%s_%s_%s",
request.is_left_hand_requested ? "left_hand" : "no_left_hand",
request.is_right_hand_requested ? "right_hand" : "no_right_hand",
request.is_face_requested ? "face" : "no_face",
request.is_face_blendshapes_requested ? "face_blendshapes"
: "no_face_blendshapes");
}
struct TestParams {
// The name of this test, for convenience when displaying test results.
std::string test_name;
// The filename of test image.
std::string test_image_name;
// Whether to use holistic model bundle to test.
bool use_model_bundle;
// Requests of holistic parts.
HolisticRequest holistic_request;
};
class SmokeTest : public testing::TestWithParam<TestParams> {};
TEST_P(SmokeTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(GetFilePath(GetParam().test_image_name)));
proto::HolisticResult holistic_result;
MP_ASSERT_OK(GetTextProto(GetFilePath(kHolisticResultFile), &holistic_result,
::file::Defaults()));
RemoveUncheckedResult(holistic_result);
MP_ASSERT_OK_AND_ASSIGN(auto task_runner,
CreateTaskRunner(GetParam().use_model_bundle,
GetParam().holistic_request));
MP_ASSERT_OK_AND_ASSIGN(auto output_packets,
task_runner->Process({{std::string(kImageInStream),
MakePacket<Image>(image)}}));
// Check face landmarks
if (GetParam().holistic_request.is_face_requested) {
MP_ASSERT_OK_AND_ASSIGN(auto face_landmarks,
FetchResult<NormalizedLandmarkList>(
output_packets, kFaceLandmarksStream));
EXPECT_THAT(
face_landmarks,
Approximately(Partially(EqualsProto(holistic_result.face_landmarks())),
/*margin=*/kAbsMargin));
} else {
ASSERT_FALSE(output_packets.contains(std::string(kFaceLandmarksStream)));
}
if (GetParam().holistic_request.is_face_blendshapes_requested) {
MP_ASSERT_OK_AND_ASSIGN(auto face_blendshapes,
FetchResult<ClassificationList>(
output_packets, kFaceBlendshapesStream));
EXPECT_THAT(face_blendshapes,
Approximately(
Partially(EqualsProto(holistic_result.face_blendshapes())),
/*margin=*/kAbsMargin));
} else {
ASSERT_FALSE(output_packets.contains(std::string(kFaceBlendshapesStream)));
}
// Check Pose landmarks
MP_ASSERT_OK_AND_ASSIGN(auto pose_landmarks,
FetchResult<NormalizedLandmarkList>(
output_packets, kPoseLandmarksStream));
EXPECT_THAT(
pose_landmarks,
Approximately(Partially(EqualsProto(holistic_result.pose_landmarks())),
/*margin=*/kAbsMargin));
// Check Hand landmarks
if (GetParam().holistic_request.is_left_hand_requested) {
MP_ASSERT_OK_AND_ASSIGN(auto left_hand_landmarks,
FetchResult<NormalizedLandmarkList>(
output_packets, kLeftHandLandmarksStream));
EXPECT_THAT(
left_hand_landmarks,
Approximately(
Partially(EqualsProto(holistic_result.left_hand_landmarks())),
/*margin=*/kAbsMargin));
} else {
ASSERT_FALSE(
output_packets.contains(std::string(kLeftHandLandmarksStream)));
}
if (GetParam().holistic_request.is_right_hand_requested) {
MP_ASSERT_OK_AND_ASSIGN(auto right_hand_landmarks,
FetchResult<NormalizedLandmarkList>(
output_packets, kRightHandLandmarksStream));
EXPECT_THAT(
right_hand_landmarks,
Approximately(
Partially(EqualsProto(holistic_result.right_hand_landmarks())),
/*margin=*/kAbsMargin));
} else {
ASSERT_FALSE(
output_packets.contains(std::string(kRightHandLandmarksStream)));
}
auto rendered_image =
output_packets.at(std::string(kRenderedImageOutStream)).Get<Image>();
MP_EXPECT_OK(SavePngTestOutput(
*rendered_image.GetImageFrameSharedPtr(),
absl::StrCat("holistic_landmark_",
RequestToString(GetParam().holistic_request))));
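// Visualize the [0.0, 1.0] float32 pose segmentation mask by scaling it to
// an 8-bit grayscale image that can be saved as a PNG for inspection.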
auto pose_segmentation_mask =
output_packets.at(std::string(kPoseSegmentationMaskStream)).Get<Image>();
cv::Mat matting_mask = mediapipe::formats::MatView(
pose_segmentation_mask.GetImageFrameSharedPtr().get());
cv::Mat visualized_mask;
matting_mask.convertTo(visualized_mask, CV_8UC1, 255);
ImageFrame visualized_image(mediapipe::ImageFormat::GRAY8,
visualized_mask.cols, visualized_mask.rows,
visualized_mask.step, visualized_mask.data,
[visualized_mask](uint8_t[]) {});
MP_EXPECT_OK(
SavePngTestOutput(visualized_image, "holistic_pose_segmentation_mask"));
}
INSTANTIATE_TEST_SUITE_P(
HolisticLandmarkerGraphTest, SmokeTest,
Values(TestParams{
/* test_name= */ "UseModelBundle",
/* test_image_name= */ std::string(kTestImageFile),
/* use_model_bundle= */ true,
/* holistic_request= */
{
/*is_left_hand_requested= */ true,
/*is_right_hand_requested= */ true,
/*is_face_requested= */ true,
/*is_face_blendshapes_requested= */ true,
},
},
TestParams{
/* test_name= */ "UseSeparateModelFiles",
/* test_image_name= */ std::string(kTestImageFile),
/* use_model_bundle= */ false,
/* holistic_request= */
{
/*is_left_hand_requested= */ true,
/*is_right_hand_requested= */ true,
/*is_face_requested= */ true,
/*is_face_blendshapes_requested= */ true,
},
},
TestParams{
/* test_name= */ "ModelBundleNoLeftHand",
/* test_image_name= */ std::string(kTestImageFile),
/* use_model_bundle= */ true,
/* holistic_request= */
{
/*is_left_hand_requested= */ false,
/*is_right_hand_requested= */ true,
/*is_face_requested= */ true,
/*is_face_blendshapes_requested= */ true,
},
},
TestParams{
/* test_name= */ "ModelBundleNoRightHand",
/* test_image_name= */ std::string(kTestImageFile),
/* use_model_bundle= */ true,
/* holistic_request= */
{
/*is_left_hand_requested= */ true,
/*is_right_hand_requested= */ false,
/*is_face_requested= */ true,
/*is_face_blendshapes_requested= */ true,
},
},
TestParams{
/* test_name= */ "ModelBundleNoHand",
/* test_image_name= */ std::string(kTestImageFile),
/* use_model_bundle= */ true,
/* holistic_request= */
{
/*is_left_hand_requested= */ false,
/*is_right_hand_requested= */ false,
/*is_face_requested= */ true,
/*is_face_blendshapes_requested= */ true,
},
},
TestParams{
/* test_name= */ "ModelBundleNoFace",
/* test_image_name= */ std::string(kTestImageFile),
/* use_model_bundle= */ true,
/* holistic_request= */
{
/*is_left_hand_requested= */ true,
/*is_right_hand_requested= */ true,
/*is_face_requested= */ false,
/*is_face_blendshapes_requested= */ false,
},
},
TestParams{
/* test_name= */ "ModelBundleNoFaceBlendshapes",
/* test_image_name= */ std::string(kTestImageFile),
/* use_model_bundle= */ true,
/* holistic_request= */
{
/*is_left_hand_requested= */ true,
/*is_right_hand_requested= */ true,
/*is_face_requested= */ true,
/*is_face_blendshapes_requested= */ false,
},
}),
[](const TestParamInfo<SmokeTest::ParamType>& info) {
return info.param.test_name;
});
} // namespace
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,307 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h"
#include <optional>
#include <utility>
#include <vector>
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/detections_to_rects.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/api2/stream/landmarks_to_detection.h"
#include "mediapipe/framework/api2/stream/loopback.h"
#include "mediapipe/framework/api2/stream/merge.h"
#include "mediapipe/framework/api2/stream/presence.h"
#include "mediapipe/framework/api2/stream/rect_transformation.h"
#include "mediapipe/framework/api2/stream/segmentation_smoothing.h"
#include "mediapipe/framework/api2/stream/smoothing.h"
#include "mediapipe/framework/api2/stream/split.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/components/utils/gate.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
namespace {
using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::builder::ConvertAlignmentPointsDetectionsToRect;
using ::mediapipe::api2::builder::ConvertAlignmentPointsDetectionToRect;
using ::mediapipe::api2::builder::ConvertLandmarksToDetection;
using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::GetLoopbackData;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::IsPresent;
using ::mediapipe::api2::builder::Merge;
using ::mediapipe::api2::builder::ScaleAndMakeSquare;
using ::mediapipe::api2::builder::SmoothLandmarks;
using ::mediapipe::api2::builder::SmoothLandmarksVisibility;
using ::mediapipe::api2::builder::SmoothSegmentationMask;
using ::mediapipe::api2::builder::SplitToRanges;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::components::utils::DisallowIf;
using Size = std::pair<int, int>;
constexpr int kAuxLandmarksStartKeypointIndex = 0;
constexpr int kAuxLandmarksEndKeypointIndex = 1;
constexpr float kAuxLandmarksTargetAngle = 90;
constexpr float kRoiFromDetectionScaleFactor = 1.25f;
constexpr float kRoiFromLandmarksScaleFactor = 1.25f;
Stream<NormalizedRect> CalculateRoiFromDetections(
Stream<std::vector<Detection>> detections, Stream<Size> image_size,
Graph& graph) {
auto roi = ConvertAlignmentPointsDetectionsToRect(detections, image_size,
/*start_keypoint_index=*/0,
/*end_keypoint_index=*/1,
/*target_angle=*/90, graph);
return ScaleAndMakeSquare(
roi, image_size, /*scale_x_factor=*/kRoiFromDetectionScaleFactor,
/*scale_y_factor=*/kRoiFromDetectionScaleFactor, graph);
}
Stream<NormalizedRect> CalculateScaleRoiFromAuxiliaryLandmarks(
Stream<NormalizedLandmarkList> landmarks, Stream<Size> image_size,
Graph& graph) {
// TODO: consider calculating ROI directly from landmarks.
auto detection = ConvertLandmarksToDetection(landmarks, graph);
return ConvertAlignmentPointsDetectionToRect(
detection, image_size, kAuxLandmarksStartKeypointIndex,
kAuxLandmarksEndKeypointIndex, kAuxLandmarksTargetAngle, graph);
}
Stream<NormalizedRect> CalculateRoiFromAuxiliaryLandmarks(
Stream<NormalizedLandmarkList> landmarks, Stream<Size> image_size,
Graph& graph) {
// TODO: consider calculating ROI directly from landmarks.
auto detection = ConvertLandmarksToDetection(landmarks, graph);
auto roi = ConvertAlignmentPointsDetectionToRect(
detection, image_size, kAuxLandmarksStartKeypointIndex,
kAuxLandmarksEndKeypointIndex, kAuxLandmarksTargetAngle, graph);
return ScaleAndMakeSquare(
roi, image_size, /*scale_x_factor=*/kRoiFromLandmarksScaleFactor,
/*scale_y_factor=*/kRoiFromLandmarksScaleFactor, graph);
}
struct PoseLandmarksResult {
std::optional<Stream<NormalizedLandmarkList>> landmarks;
std::optional<Stream<LandmarkList>> world_landmarks;
std::optional<Stream<NormalizedLandmarkList>> auxiliary_landmarks;
std::optional<Stream<Image>> segmentation_mask;
};
PoseLandmarksResult RunLandmarksDetection(
Stream<Image> image, Stream<NormalizedRect> roi,
const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions&
pose_landmarks_detector_graph_options,
const HolisticPoseTrackingRequest& request, Graph& graph) {
GenericNode& landmarks_graph = graph.AddNode(
"mediapipe.tasks.vision.pose_landmarker."
"SinglePoseLandmarksDetectorGraph");
landmarks_graph
.GetOptions<pose_landmarker::proto::PoseLandmarksDetectorGraphOptions>() =
pose_landmarks_detector_graph_options;
image >> landmarks_graph.In("IMAGE");
roi >> landmarks_graph.In("NORM_RECT");
PoseLandmarksResult result;
if (request.landmarks) {
result.landmarks =
landmarks_graph.Out("LANDMARKS").Cast<NormalizedLandmarkList>();
result.auxiliary_landmarks = landmarks_graph.Out("AUXILIARY_LANDMARKS")
.Cast<NormalizedLandmarkList>();
}
if (request.world_landmarks) {
result.world_landmarks =
landmarks_graph.Out("WORLD_LANDMARKS").Cast<LandmarkList>();
}
if (request.segmentation_mask) {
result.segmentation_mask =
landmarks_graph.Out("SEGMENTATION_MASK").Cast<Image>();
}
return result;
}
} // namespace
absl::StatusOr<HolisticPoseTrackingOutput>
TrackHolisticPoseUsingCustomPoseDetection(
Stream<Image> image, PoseDetectionFn pose_detection_fn,
const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions&
pose_landmarks_detector_graph_options,
const HolisticPoseTrackingRequest& request, Graph& graph) {
// Calculate ROI from scratch (pose detection) or reuse one from the
// previous run if available.
auto [previous_roi, set_previous_roi_fn] =
GetLoopbackData<NormalizedRect>(/*tick=*/image, graph);
auto is_previous_roi_available = IsPresent(previous_roi, graph);
auto image_for_detection =
DisallowIf(image, is_previous_roi_available, graph);
MP_ASSIGN_OR_RETURN(auto pose_detections,
pose_detection_fn(image_for_detection, graph));
auto roi_from_detections = CalculateRoiFromDetections(
pose_detections, GetImageSize(image_for_detection, graph), graph);
// Take first non-empty.
auto roi = Merge(roi_from_detections, previous_roi, graph);
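// The resulting tracking loop behaves roughly as follows (a sketch, assuming
// one packet per frame): on the first frame previous_roi is absent, so the
// detector runs and roi comes from the detection; on subsequent frames
// previous_roi holds the ROI derived from the previous frame's landmarks
// (see set_previous_roi_fn below), the detector is skipped via DisallowIf,
// and roi reuses previous_roi.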
// Calculate landmarks and other outputs (if requested) in the specified ROI.
auto landmarks_detection_result = RunLandmarksDetection(
image, roi, pose_landmarks_detector_graph_options,
{
// Landmarks are required for tracking, hence force-requesting them.
.landmarks = true,
.world_landmarks = request.world_landmarks,
.segmentation_mask = request.segmentation_mask,
},
graph);
RET_CHECK(landmarks_detection_result.landmarks.has_value() &&
landmarks_detection_result.auxiliary_landmarks.has_value())
<< "Failed to calculate landmarks required for tracking.";
// Split landmarks to pose landmarks and auxiliary landmarks.
auto pose_landmarks_raw = *landmarks_detection_result.landmarks;
auto auxiliary_landmarks = *landmarks_detection_result.auxiliary_landmarks;
auto image_size = GetImageSize(image, graph);
// TODO: b/305750053 - Apply adaptive crop by adding AdaptiveCropCalculator.
// Calculate ROI from smoothed auxiliary landmarks.
auto scale_roi = CalculateScaleRoiFromAuxiliaryLandmarks(auxiliary_landmarks,
image_size, graph);
auto auxiliary_landmarks_smoothed = SmoothLandmarks(
auxiliary_landmarks, image_size, scale_roi,
{// Min cutoff 0.01 results in ~0.002 alpha in the landmark EMA filter
// when the landmark is static.
.min_cutoff = 0.01,
// Beta 10.0 in combination with min_cutoff 0.01 results in ~0.68 alpha
// in the landmark EMA filter when the landmark is moving fast.
.beta = 10.0,
// Derivative cutoff 1.0 results in ~0.17 alpha in the landmark velocity
// EMA filter.
.derivate_cutoff = 1.0},
graph);
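// For reference, the alpha values quoted in these comments follow from the
// one-euro filter relation (a sketch, assuming ~30 FPS input, i.e. a sample
// period Te = 1/30 s; the smoothing parameters further below follow the same
// relation): with cutoff frequency fc = min_cutoff + beta * |velocity|, the
// EMA smoothing factor is
//   alpha = 1 / (1 + tau / Te), where tau = 1 / (2 * pi * fc).
// E.g. a static landmark (fc = min_cutoff = 0.01) gives
//   alpha = 1 / (1 + 15.9 * 30) ≈ 0.002.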
auto roi_from_auxiliary_landmarks = CalculateRoiFromAuxiliaryLandmarks(
auxiliary_landmarks_smoothed, image_size, graph);
// Make ROI from auxiliary landmarks to be used as "previous" ROI for a
// subsequent run.
set_previous_roi_fn(roi_from_auxiliary_landmarks);
// Populate and smooth pose landmarks if corresponding output has been
// requested.
std::optional<Stream<NormalizedLandmarkList>> pose_landmarks;
if (request.landmarks) {
pose_landmarks = SmoothLandmarksVisibility(
pose_landmarks_raw, /*low_pass_filter_alpha=*/0.1f, graph);
pose_landmarks = SmoothLandmarks(
*pose_landmarks, image_size, scale_roi,
{// Min cutoff 0.05 results in ~0.01 alpha in the landmark EMA filter
// when the landmark is static.
.min_cutoff = 0.05f,
// Beta 80.0 in combination with min_cutoff 0.05 results in ~0.94 alpha
// in the landmark EMA filter when the landmark is moving fast.
.beta = 80.0f,
// Derivative cutoff 1.0 results in ~0.17 alpha in the landmark velocity
// EMA filter.
.derivate_cutoff = 1.0f},
graph);
}
// Populate and smooth world landmarks if available.
std::optional<Stream<LandmarkList>> world_landmarks;
if (landmarks_detection_result.world_landmarks) {
world_landmarks = SplitToRanges(*landmarks_detection_result.world_landmarks,
/*ranges=*/{{0, 33}}, graph)[0];
world_landmarks = SmoothLandmarksVisibility(
*world_landmarks, /*low_pass_filter_alpha=*/0.1f, graph);
world_landmarks = SmoothLandmarks(
*world_landmarks,
/*scale_roi=*/std::nullopt,
{// Min cutoff 0.1 results in ~0.02 alpha in the landmark EMA filter
// when the landmark is static.
.min_cutoff = 0.1f,
// Beta 40.0 in combination with min_cutoff 0.1 results in ~0.8 alpha
// in the landmark EMA filter when the landmark is moving fast.
.beta = 40.0f,
// Derivative cutoff 1.0 results in ~0.17 alpha in the landmark velocity
// EMA filter.
.derivate_cutoff = 1.0f},
graph);
}
// Populate and smooth segmentation mask if available.
std::optional<Stream<Image>> segmentation_mask;
if (landmarks_detection_result.segmentation_mask) {
auto mask = *landmarks_detection_result.segmentation_mask;
auto [prev_mask_as_img, set_prev_mask_as_img_fn] =
GetLoopbackData<mediapipe::Image>(
/*tick=*/*landmarks_detection_result.segmentation_mask, graph);
auto mask_smoothed =
SmoothSegmentationMask(mask, prev_mask_as_img,
/*combine_with_previous_ratio=*/0.7f, graph);
set_prev_mask_as_img_fn(mask_smoothed);
segmentation_mask = mask_smoothed;
}
return {{/*landmarks=*/pose_landmarks,
/*world_landmarks=*/world_landmarks,
/*segmentation_mask=*/segmentation_mask,
/*debug_output=*/
{/*auxiliary_landmarks=*/auxiliary_landmarks_smoothed,
/*roi_from_landmarks=*/roi_from_auxiliary_landmarks,
/*detections=*/pose_detections}}};
}
absl::StatusOr<HolisticPoseTrackingOutput> TrackHolisticPose(
Stream<Image> image,
const pose_detector::proto::PoseDetectorGraphOptions&
pose_detector_graph_options,
const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions&
pose_landmarks_detector_graph_options,
const HolisticPoseTrackingRequest& request, Graph& graph) {
PoseDetectionFn pose_detection_fn = [&pose_detector_graph_options](
Stream<Image> image, Graph& graph)
-> absl::StatusOr<Stream<std::vector<mediapipe::Detection>>> {
GenericNode& pose_detector =
graph.AddNode("mediapipe.tasks.vision.pose_detector.PoseDetectorGraph");
pose_detector.GetOptions<pose_detector::proto::PoseDetectorGraphOptions>() =
pose_detector_graph_options;
image >> pose_detector.In("IMAGE");
return pose_detector.Out("DETECTIONS")
.Cast<std::vector<mediapipe::Detection>>();
};
return TrackHolisticPoseUsingCustomPoseDetection(
image, pose_detection_fn, pose_landmarks_detector_graph_options, request,
graph);
}
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,110 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_POSE_TRACKING_H_
#define MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_POSE_TRACKING_H_
#include <functional>
#include <optional>
#include <vector>
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
// Type of pose detection function that can be used to customize pose tracking
// by supplying it to `TrackHolisticPoseUsingCustomPoseDetection`.
//
// The function should update the provided graph with a node (or nodes) that
// accepts an image stream and produces a stream of detections.
using PoseDetectionFn = std::function<
absl::StatusOr<api2::builder::Stream<std::vector<mediapipe::Detection>>>(
api2::builder::Stream<Image>, api2::builder::Graph&)>;
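//
// For example, a custom detector can be plugged in as follows (a sketch;
// `MyPoseDetectorGraph` is a hypothetical node name):
//
//   PoseDetectionFn fn = [](api2::builder::Stream<Image> image,
//                           api2::builder::Graph& graph)
//       -> absl::StatusOr<
//           api2::builder::Stream<std::vector<mediapipe::Detection>>> {
//     auto& detector = graph.AddNode("MyPoseDetectorGraph");
//     image >> detector.In("IMAGE");
//     return detector.Out("DETECTIONS")
//         .Cast<std::vector<mediapipe::Detection>>();
//   };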
struct HolisticPoseTrackingRequest {
bool landmarks = false;
bool world_landmarks = false;
bool segmentation_mask = false;
};
struct HolisticPoseTrackingOutput {
std::optional<api2::builder::Stream<mediapipe::NormalizedLandmarkList>>
landmarks;
std::optional<api2::builder::Stream<mediapipe::LandmarkList>> world_landmarks;
std::optional<api2::builder::Stream<Image>> segmentation_mask;
struct DebugOutput {
api2::builder::Stream<mediapipe::NormalizedLandmarkList>
auxiliary_landmarks;
api2::builder::Stream<NormalizedRect> roi_from_landmarks;
api2::builder::Stream<std::vector<mediapipe::Detection>> detections;
};
DebugOutput debug_output;
};
// Updates @graph to track pose in @image.
//
// @image - ImageFrame/GpuBuffer to track pose in.
// @pose_detection_fn - pose detection function that takes @image as input and
// produces stream of pose detections.
// @pose_landmarks_detector_graph_options - options of the
// PoseLandmarksDetectorGraph used to detect the pose landmarks.
// @request - object to request specific pose tracking outputs.
// NOTE: Outputs that were not requested won't be returned and the
// corresponding parts of the graph won't be generated at all.
// @graph - graph to update.
absl::StatusOr<HolisticPoseTrackingOutput>
TrackHolisticPoseUsingCustomPoseDetection(
api2::builder::Stream<Image> image, PoseDetectionFn pose_detection_fn,
const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions&
pose_landmarks_detector_graph_options,
const HolisticPoseTrackingRequest& request, api2::builder::Graph& graph);
// Updates @graph to track pose in @image.
//
// @image - ImageFrame/GpuBuffer to track pose in.
// @pose_detector_graph_options - options of the PoseDetectorGraph used to
// detect the pose.
// @pose_landmarks_detector_graph_options - options of the
// PoseLandmarksDetectorGraph used to detect the pose landmarks.
// @request - object to request specific pose tracking outputs.
// NOTE: Outputs that were not requested won't be returned and the
// corresponding parts of the graph won't be generated at all.
// @graph - graph to update.
absl::StatusOr<HolisticPoseTrackingOutput> TrackHolisticPose(
api2::builder::Stream<Image> image,
const pose_detector::proto::PoseDetectorGraphOptions&
pose_detector_graph_options,
const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions&
pose_landmarks_detector_graph_options,
const HolisticPoseTrackingRequest& request, api2::builder::Graph& graph);
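//
// Example usage (a sketch; option values would normally be populated from
// model assets, as in the corresponding tests):
//
//   api2::builder::Graph graph;
//   auto image = graph.In("IMAGE").Cast<Image>();
//   pose_detector::proto::PoseDetectorGraphOptions detector_options;
//   pose_landmarker::proto::PoseLandmarksDetectorGraphOptions landmarks_options;
//   HolisticPoseTrackingRequest request;
//   request.landmarks = true;
//   absl::StatusOr<HolisticPoseTrackingOutput> result = TrackHolisticPose(
//       image, detector_options, landmarks_options, request, graph);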
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_POSE_TRACKING_H_

View File

@ -0,0 +1,243 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h"
#include <memory>
#include <string>
#include <vector>
#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "absl/types/span.h"
#include "file/base/helpers.h"
#include "file/base/options.h"
#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.pb.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_connections.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/data_renderer.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"
#include "testing/base/public/googletest.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
namespace {
using ::file::Defaults;
using ::file::GetTextProto;
using ::mediapipe::Image;
using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::core::TaskRunner;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;
constexpr float kAbsMargin = 0.025;
constexpr absl::string_view kTestDataDirectory =
"/mediapipe/tasks/testdata/vision/";
constexpr absl::string_view kTestImageFile = "male_full_height_hands.jpg";
constexpr absl::string_view kImageInStream = "image_in";
constexpr absl::string_view kPoseLandmarksOutStream = "pose_landmarks_out";
constexpr absl::string_view kPoseWorldLandmarksOutStream =
"pose_world_landmarks_out";
constexpr absl::string_view kRenderedImageOutStream = "rendered_image_out";
constexpr absl::string_view kHolisticResultFile =
"male_full_height_hands_result_cpu.pbtxt";
constexpr absl::string_view kHolisticPoseTrackingGraph =
"holistic_pose_tracking_graph.pbtxt";
std::string GetFilePath(absl::string_view filename) {
return file::JoinPath("./", kTestDataDirectory, filename);
}
mediapipe::LandmarksToRenderDataCalculatorOptions GetPoseRendererOptions() {
mediapipe::LandmarksToRenderDataCalculatorOptions renderer_options;
for (const auto& connection : pose_landmarker::kPoseLandmarksConnections) {
renderer_options.add_landmark_connections(connection[0]);
renderer_options.add_landmark_connections(connection[1]);
}
renderer_options.mutable_landmark_color()->set_r(255);
renderer_options.mutable_landmark_color()->set_g(255);
renderer_options.mutable_landmark_color()->set_b(255);
renderer_options.mutable_connection_color()->set_r(255);
renderer_options.mutable_connection_color()->set_g(255);
renderer_options.mutable_connection_color()->set_b(255);
renderer_options.set_thickness(0.5);
renderer_options.set_visualize_landmark_depth(false);
return renderer_options;
}
// Helper function to create a TaskRunner.
absl::StatusOr<std::unique_ptr<tasks::core::TaskRunner>> CreateTaskRunner() {
Graph graph;
Stream<Image> image = graph.In("IMAGE").Cast<Image>().SetName(kImageInStream);
pose_detector::proto::PoseDetectorGraphOptions pose_detector_graph_options;
pose_detector_graph_options.mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath("pose_detection.tflite"));
pose_detector_graph_options.set_num_poses(1);
pose_landmarker::proto::PoseLandmarksDetectorGraphOptions
pose_landmarks_detector_graph_options;
pose_landmarks_detector_graph_options.mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath("pose_landmark_lite.tflite"));
HolisticPoseTrackingRequest request;
request.landmarks = true;
request.world_landmarks = true;
MP_ASSIGN_OR_RETURN(
HolisticPoseTrackingOutput result,
TrackHolisticPose(image, pose_detector_graph_options,
pose_landmarks_detector_graph_options, request, graph));
auto image_size = GetImageSize(image, graph);
auto render_data = utils::RenderLandmarks(
*result.landmarks,
utils::GetRenderScale(image_size, result.debug_output.roi_from_landmarks,
0.0001, graph),
GetPoseRendererOptions(), graph);
std::vector<Stream<mediapipe::RenderData>> render_list = {render_data};
auto rendered_image =
utils::Render(
image, absl::Span<Stream<mediapipe::RenderData>>(render_list), graph)
.SetName(kRenderedImageOutStream);
rendered_image >> graph.Out("RENDERED_IMAGE");
result.landmarks->SetName(kPoseLandmarksOutStream) >>
graph.Out("POSE_LANDMARKS");
result.world_landmarks->SetName(kPoseWorldLandmarksOutStream) >>
graph.Out("POSE_WORLD_LANDMARKS");
auto config = graph.GetConfig();
core::FixGraphBackEdges(config);
return TaskRunner::Create(
config, std::make_unique<core::MediaPipeBuiltinOpResolver>());
}
// Remove fields not to be checked in the result, since the model generating
// the expected result is different from the model under test.
void RemoveUncheckedResult(proto::HolisticResult& holistic_result) {
for (auto& landmark :
*holistic_result.mutable_pose_landmarks()->mutable_landmark()) {
landmark.clear_z();
landmark.clear_visibility();
landmark.clear_presence();
}
}
class HolisticPoseTrackingTest : public testing::Test {};
TEST_F(HolisticPoseTrackingTest, VerifyGraph) {
Graph graph;
Stream<Image> image = graph.In("IMAGE").Cast<Image>().SetName(kImageInStream);
pose_detector::proto::PoseDetectorGraphOptions pose_detector_graph_options;
pose_detector_graph_options.mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath("pose_detection.tflite"));
pose_detector_graph_options.set_num_poses(1);
pose_landmarker::proto::PoseLandmarksDetectorGraphOptions
pose_landmarks_detector_graph_options;
pose_landmarks_detector_graph_options.mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath("pose_landmark_lite.tflite"));
HolisticPoseTrackingRequest request;
request.landmarks = true;
request.world_landmarks = true;
MP_ASSERT_OK_AND_ASSIGN(
HolisticPoseTrackingOutput result,
TrackHolisticPose(image, pose_detector_graph_options,
pose_landmarks_detector_graph_options, request, graph));
result.landmarks->SetName(kPoseLandmarksOutStream) >>
graph.Out("POSE_LANDMARKS");
result.world_landmarks->SetName(kPoseWorldLandmarksOutStream) >>
graph.Out("POSE_WORLD_LANDMARKS");
auto config = graph.GetConfig();
core::FixGraphBackEdges(config);
// Read the expected graph config.
std::string expected_graph_contents;
MP_ASSERT_OK(file::GetContents(
file::JoinPath("./", kTestDataDirectory, kHolisticPoseTrackingGraph),
&expected_graph_contents));
  // Substitute the test srcdir into the expected graph config, because each
  // run has a different test dir on TAP.
expected_graph_contents = absl::Substitute(
expected_graph_contents, FLAGS_test_srcdir, FLAGS_test_srcdir);
CalculatorGraphConfig expected_graph =
ParseTextProtoOrDie<CalculatorGraphConfig>(expected_graph_contents);
EXPECT_THAT(config, testing::proto::IgnoringRepeatedFieldOrdering(
testing::EqualsProto(expected_graph)));
}
TEST_F(HolisticPoseTrackingTest, SmokeTest) {
MP_ASSERT_OK_AND_ASSIGN(Image image,
DecodeImageFromFile(GetFilePath(kTestImageFile)));
proto::HolisticResult holistic_result;
MP_ASSERT_OK(GetTextProto(GetFilePath(kHolisticResultFile), &holistic_result,
Defaults()));
RemoveUncheckedResult(holistic_result);
MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner());
MP_ASSERT_OK_AND_ASSIGN(auto output_packets,
task_runner->Process({{std::string(kImageInStream),
MakePacket<Image>(image)}}));
auto pose_landmarks = output_packets.at(std::string(kPoseLandmarksOutStream))
.Get<NormalizedLandmarkList>();
EXPECT_THAT(
pose_landmarks,
Approximately(Partially(EqualsProto(holistic_result.pose_landmarks())),
/*margin=*/kAbsMargin));
auto rendered_image =
output_packets.at(std::string(kRenderedImageOutStream)).Get<Image>();
MP_EXPECT_OK(SavePngTestOutput(*rendered_image.GetImageFrameSharedPtr(),
"pose_landmarks"));
}
} // namespace
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,44 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
package(default_visibility = [
"//mediapipe/tasks:internal",
])
licenses(["notice"])
mediapipe_proto_library(
name = "holistic_result_proto",
srcs = ["holistic_result.proto"],
deps = [
"//mediapipe/framework/formats:classification_proto",
"//mediapipe/framework/formats:landmark_proto",
],
)
mediapipe_proto_library(
name = "holistic_landmarker_graph_options_proto",
srcs = ["holistic_landmarker_graph_options.proto"],
deps = [
"//mediapipe/tasks/cc/core/proto:base_options_proto",
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_proto",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_roi_refinement_graph_options_proto",
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_proto",
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_proto",
],
)

View File

@ -0,0 +1,57 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package mediapipe.tasks.vision.holistic_landmarker.proto;
import "mediapipe/tasks/cc/core/proto/base_options.proto";
import "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto";
import "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.proto";
import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto";
import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.proto";
import "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.proto";
import "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.holisticlandmarker.proto";
option java_outer_classname = "HolisticLandmarkerGraphOptionsProto";
message HolisticLandmarkerGraphOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model
// asset bundle file with metadata, accelerator options, etc.
core.proto.BaseOptions base_options = 1;
// Options for hand landmarks graph.
hand_landmarker.proto.HandLandmarksDetectorGraphOptions
hand_landmarks_detector_graph_options = 2;
// Options for hand roi refinement graph.
hand_landmarker.proto.HandRoiRefinementGraphOptions
hand_roi_refinement_graph_options = 3;
// Options for face detector graph.
face_detector.proto.FaceDetectorGraphOptions face_detector_graph_options = 4;
// Options for face landmarks detector graph.
face_landmarker.proto.FaceLandmarksDetectorGraphOptions
face_landmarks_detector_graph_options = 5;
// Options for pose detector graph.
pose_detector.proto.PoseDetectorGraphOptions pose_detector_graph_options = 6;
// Options for pose landmarks detector graph.
pose_landmarker.proto.PoseLandmarksDetectorGraphOptions
pose_landmarks_detector_graph_options = 7;
}

View File

@ -0,0 +1,34 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package mediapipe.tasks.vision.holistic_landmarker.proto;
import "mediapipe/framework/formats/classification.proto";
import "mediapipe/framework/formats/landmark.proto";
option java_package = "com.google.mediapipe.tasks.vision.holisticlandmarker";
option java_outer_classname = "HolisticResultProto";
message HolisticResult {
mediapipe.NormalizedLandmarkList pose_landmarks = 1;
mediapipe.LandmarkList pose_world_landmarks = 7;
mediapipe.NormalizedLandmarkList left_hand_landmarks = 2;
mediapipe.NormalizedLandmarkList right_hand_landmarks = 3;
mediapipe.NormalizedLandmarkList face_landmarks = 4;
mediapipe.ClassificationList face_blendshapes = 6;
mediapipe.NormalizedLandmarkList auxiliary_landmarks = 5;
}

View File

@ -47,13 +47,14 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_landmarker_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_java_proto_lite",

View File

@ -67,6 +67,7 @@ cc_binary(
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
"//mediapipe/tasks/cc/vision/face_stylizer:face_stylizer_graph",
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
"//mediapipe/tasks/cc/vision/holistic_landmarker:holistic_landmarker_graph",
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
@ -429,6 +430,7 @@ filegroup(
android_library(
name = "holisticlandmarker",
srcs = [
"holisticlandmarker/HolisticLandmarker.java",
"holisticlandmarker/HolisticLandmarkerResult.java",
],
javacopts = [
@ -439,10 +441,20 @@ android_library(
":core",
"//mediapipe/framework/formats:classification_java_proto_lite",
"//mediapipe/framework/formats:landmark_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_landmarker_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:any_java_proto",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],

View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.vision.holisticlandmarker">
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
</manifest>

View File

@ -0,0 +1,668 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.holisticlandmarker;
import android.content.Context;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
import com.google.mediapipe.framework.AndroidPacketGetter;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.facedetector.proto.FaceDetectorGraphOptionsProto.FaceDetectorGraphOptions;
import com.google.mediapipe.tasks.vision.facelandmarker.proto.FaceLandmarksDetectorGraphOptionsProto.FaceLandmarksDetectorGraphOptions;
import com.google.mediapipe.tasks.vision.handlandmarker.proto.HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions;
import com.google.mediapipe.tasks.vision.holisticlandmarker.proto.HolisticLandmarkerGraphOptionsProto.HolisticLandmarkerGraphOptions;
import com.google.mediapipe.tasks.vision.posedetector.proto.PoseDetectorGraphOptionsProto.PoseDetectorGraphOptions;
import com.google.mediapipe.tasks.vision.poselandmarker.proto.PoseLandmarksDetectorGraphOptionsProto.PoseLandmarksDetectorGraphOptions;
import com.google.protobuf.Any;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
/**
* Performs holistic landmarks detection on images.
*
* <p>This API expects a pre-trained holistic landmarks model asset bundle.
*
* <ul>
* <li>Input image {@link MPImage}
* <ul>
* <li>The image that holistic landmarks detection runs on.
* </ul>
* <li>Output {@link HolisticLandmarkerResult}
* <ul>
* <li>A HolisticLandmarkerResult containing holistic landmarks.
* </ul>
* </ul>
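 *
 * <p>Example usage (a sketch; assumes a holistic landmarker model bundle is available under the
 * given asset path):
 *
 * <pre>{@code
 * HolisticLandmarker landmarker =
 *     HolisticLandmarker.createFromFile(context, "holistic_landmarker.task");
 * HolisticLandmarkerResult result = landmarker.detect(mpImage);
 * }</pre>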
*/
public final class HolisticLandmarker extends BaseVisionTaskApi {
private static final String TAG = HolisticLandmarker.class.getSimpleName();
private static final String IMAGE_IN_STREAM_NAME = "image_in";
private static final String POSE_LANDMARKS_STREAM = "pose_landmarks";
private static final String POSE_WORLD_LANDMARKS_STREAM = "pose_world_landmarks";
private static final String POSE_SEGMENTATION_MASK_STREAM = "pose_segmentation_mask";
private static final String FACE_LANDMARKS_STREAM = "face_landmarks";
private static final String FACE_BLENDSHAPES_STREAM = "extra_blendshapes";
private static final String LEFT_HAND_LANDMARKS_STREAM = "left_hand_landmarks";
private static final String LEFT_HAND_WORLD_LANDMARKS_STREAM = "left_hand_world_landmarks";
private static final String RIGHT_HAND_LANDMARKS_STREAM = "right_hand_landmarks";
private static final String RIGHT_HAND_WORLD_LANDMARKS_STREAM = "right_hand_world_landmarks";
private static final String IMAGE_OUT_STREAM_NAME = "image_out";
private static final int FACE_LANDMARKS_OUT_STREAM_INDEX = 0;
private static final int POSE_LANDMARKS_OUT_STREAM_INDEX = 1;
private static final int POSE_WORLD_LANDMARKS_OUT_STREAM_INDEX = 2;
private static final int LEFT_HAND_LANDMARKS_OUT_STREAM_INDEX = 3;
private static final int LEFT_HAND_WORLD_LANDMARKS_OUT_STREAM_INDEX = 4;
private static final int RIGHT_HAND_LANDMARKS_OUT_STREAM_INDEX = 5;
private static final int RIGHT_HAND_WORLD_LANDMARKS_OUT_STREAM_INDEX = 6;
private static final int IMAGE_OUT_STREAM_INDEX = 7;
private static final float DEFAULT_PRESENCE_THRESHOLD = 0.5f;
private static final float DEFAULT_SUPPRESSION_THRESHOLD = 0.3f;
private static final boolean DEFAULT_OUTPUT_FACE_BLENDSHAPES = false;
private static final boolean DEFAULT_OUTPUT_SEGMENTATION_MASKS = false;
private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.holistic_landmarker.HolisticLandmarkerGraph";
@SuppressWarnings("ConstantCaseForConstants")
private static final List<String> INPUT_STREAMS =
Collections.unmodifiableList(Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME));
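// Loads the MediaPipe Tasks Vision JNI library once, when this class is first used.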
static {
System.loadLibrary("mediapipe_tasks_vision_jni");
}
/**
* Creates a {@link HolisticLandmarker} instance from a model asset bundle path and the default
* {@link HolisticLandmarkerOptions}.
*
* @param context an Android {@link Context}.
* @param modelAssetPath path to the holistic landmarks model with metadata in the assets.
* @throws MediaPipeException if there is an error during {@link HolisticLandmarker} creation.
*/
public static HolisticLandmarker createFromFile(Context context, String modelAssetPath) {
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelAssetPath).build();
return createFromOptions(
context, HolisticLandmarkerOptions.builder().setBaseOptions(baseOptions).build());
}
/**
* Creates a {@link HolisticLandmarker} instance from a model asset bundle file and the default
* {@link HolisticLandmarkerOptions}.
*
* @param context an Android {@link Context}.
* @param modelAssetFile the holistic landmarks model {@link File} instance.
* @throws IOException if an I/O error occurs when opening the tflite model file.
* @throws MediaPipeException if there is an error during {@link HolisticLandmarker} creation.
*/
public static HolisticLandmarker createFromFile(Context context, File modelAssetFile)
throws IOException {
try (ParcelFileDescriptor descriptor =
ParcelFileDescriptor.open(modelAssetFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
BaseOptions baseOptions =
BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
return createFromOptions(
context, HolisticLandmarkerOptions.builder().setBaseOptions(baseOptions).build());
}
}
/**
* Creates a {@link HolisticLandmarker} instance from a model asset bundle buffer and the default
* {@link HolisticLandmarkerOptions}.
*
* @param context an Android {@link Context}.
* @param modelAssetBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
* detection model.
* @throws MediaPipeException if there is an error during {@link HolisticLandmarker} creation.
*/
public static HolisticLandmarker createFromBuffer(
Context context, final ByteBuffer modelAssetBuffer) {
BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelAssetBuffer).build();
return createFromOptions(
context, HolisticLandmarkerOptions.builder().setBaseOptions(baseOptions).build());
}
/**
* Creates a {@link HolisticLandmarker} instance from a {@link HolisticLandmarkerOptions}.
*
* @param context an Android {@link Context}.
* @param landmarkerOptions a {@link HolisticLandmarkerOptions} instance.
* @throws MediaPipeException if there is an error during {@link HolisticLandmarker} creation.
*/
public static HolisticLandmarker createFromOptions(
Context context, HolisticLandmarkerOptions landmarkerOptions) {
List<String> outputStreams = new ArrayList<>();
outputStreams.add("FACE_LANDMARKS:" + FACE_LANDMARKS_STREAM);
outputStreams.add("POSE_LANDMARKS:" + POSE_LANDMARKS_STREAM);
outputStreams.add("POSE_WORLD_LANDMARKS:" + POSE_WORLD_LANDMARKS_STREAM);
outputStreams.add("LEFT_HAND_LANDMARKS:" + LEFT_HAND_LANDMARKS_STREAM);
outputStreams.add("LEFT_HAND_WORLD_LANDMARKS:" + LEFT_HAND_WORLD_LANDMARKS_STREAM);
outputStreams.add("RIGHT_HAND_LANDMARKS:" + RIGHT_HAND_LANDMARKS_STREAM);
outputStreams.add("RIGHT_HAND_WORLD_LANDMARKS:" + RIGHT_HAND_WORLD_LANDMARKS_STREAM);
outputStreams.add("IMAGE:" + IMAGE_OUT_STREAM_NAME);
int[] faceBlendshapesOutStreamIndex = new int[] {-1};
if (landmarkerOptions.outputFaceBlendshapes()) {
outputStreams.add("FACE_BLENDSHAPES:" + FACE_BLENDSHAPES_STREAM);
faceBlendshapesOutStreamIndex[0] = outputStreams.size() - 1;
}
int[] poseSegmentationMasksOutStreamIndex = new int[] {-1};
if (landmarkerOptions.outputPoseSegmentationMasks()) {
outputStreams.add("POSE_SEGMENTATION_MASK:" + POSE_SEGMENTATION_MASK_STREAM);
poseSegmentationMasksOutStreamIndex[0] = outputStreams.size() - 1;
}
OutputHandler<HolisticLandmarkerResult, MPImage> handler = new OutputHandler<>();
handler.setOutputPacketConverter(
new OutputHandler.OutputPacketConverter<HolisticLandmarkerResult, MPImage>() {
@Override
public HolisticLandmarkerResult convertToTaskResult(List<Packet> packets) {
// If there are no detected landmarks, just return an empty result.
if (packets.get(FACE_LANDMARKS_OUT_STREAM_INDEX).isEmpty()) {
return HolisticLandmarkerResult.createEmpty(
BaseVisionTaskApi.generateResultTimestampMs(
landmarkerOptions.runningMode(),
packets.get(FACE_LANDMARKS_OUT_STREAM_INDEX)));
}
NormalizedLandmarkList faceLandmarkProtos =
PacketGetter.getProto(
packets.get(FACE_LANDMARKS_OUT_STREAM_INDEX), NormalizedLandmarkList.parser());
Optional<ClassificationList> faceBlendshapeProtos =
landmarkerOptions.outputFaceBlendshapes()
? Optional.of(
PacketGetter.getProto(
packets.get(faceBlendshapesOutStreamIndex[0]),
ClassificationList.parser()))
: Optional.empty();
NormalizedLandmarkList poseLandmarkProtos =
PacketGetter.getProto(
packets.get(POSE_LANDMARKS_OUT_STREAM_INDEX), NormalizedLandmarkList.parser());
LandmarkList poseWorldLandmarkProtos =
PacketGetter.getProto(
packets.get(POSE_WORLD_LANDMARKS_OUT_STREAM_INDEX), LandmarkList.parser());
Optional<MPImage> segmentationMask =
landmarkerOptions.outputPoseSegmentationMasks()
? Optional.of(
getSegmentationMask(packets, poseSegmentationMasksOutStreamIndex[0]))
: Optional.empty();
NormalizedLandmarkList leftHandLandmarkProtos =
PacketGetter.getProto(
packets.get(LEFT_HAND_LANDMARKS_OUT_STREAM_INDEX),
NormalizedLandmarkList.parser());
LandmarkList leftHandWorldLandmarkProtos =
PacketGetter.getProto(
packets.get(LEFT_HAND_WORLD_LANDMARKS_OUT_STREAM_INDEX), LandmarkList.parser());
NormalizedLandmarkList rightHandLandmarkProtos =
PacketGetter.getProto(
packets.get(RIGHT_HAND_LANDMARKS_OUT_STREAM_INDEX),
NormalizedLandmarkList.parser());
LandmarkList rightHandWorldLandmarkProtos =
PacketGetter.getProto(
packets.get(RIGHT_HAND_WORLD_LANDMARKS_OUT_STREAM_INDEX),
LandmarkList.parser());
return HolisticLandmarkerResult.create(
faceLandmarkProtos,
faceBlendshapeProtos,
poseLandmarkProtos,
poseWorldLandmarkProtos,
segmentationMask,
leftHandLandmarkProtos,
leftHandWorldLandmarkProtos,
rightHandLandmarkProtos,
rightHandWorldLandmarkProtos,
BaseVisionTaskApi.generateResultTimestampMs(
landmarkerOptions.runningMode(), packets.get(FACE_LANDMARKS_OUT_STREAM_INDEX)));
}
@Override
public MPImage convertToTaskInput(List<Packet> packets) {
return new BitmapImageBuilder(
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
.build();
}
});
landmarkerOptions.resultListener().ifPresent(handler::setResultListener);
landmarkerOptions.errorListener().ifPresent(handler::setErrorListener);
TaskRunner runner =
TaskRunner.create(
context,
TaskInfo.<HolisticLandmarkerOptions>builder()
.setTaskName(HolisticLandmarker.class.getSimpleName())
.setTaskRunningModeName(landmarkerOptions.runningMode().name())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(outputStreams)
.setTaskOptions(landmarkerOptions)
.setEnableFlowLimiting(landmarkerOptions.runningMode() == RunningMode.LIVE_STREAM)
.build(),
handler);
return new HolisticLandmarker(runner, landmarkerOptions.runningMode());
}
/**
* Constructor to initialize a {@link HolisticLandmarker} from a {@link TaskRunner} and a {@link
* RunningMode}.
*
* @param taskRunner a {@link TaskRunner}.
* @param runningMode a mediapipe vision task {@link RunningMode}.
*/
private HolisticLandmarker(TaskRunner taskRunner, RunningMode runningMode) {
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, /* normRectStreamName= */ "");
}
/**
* Performs holistic landmarks detection on the provided single image with default image
* processing options, i.e. without any rotation applied. Only use this method when the {@link
* HolisticLandmarker} is created with {@link RunningMode.IMAGE}.
*
* <p>{@link HolisticLandmarker} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @throws MediaPipeException if there is an internal error.
*/
public HolisticLandmarkerResult detect(MPImage image) {
return detect(image, ImageProcessingOptions.builder().build());
}
/**
* Performs holistic landmarks detection on the provided single image. Only use this method when
* the {@link HolisticLandmarker} is created with {@link RunningMode.IMAGE}.
*
* <p>{@link HolisticLandmarker} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference. Note that region-of-interest is <b>not</b> supported
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
* this method throwing an IllegalArgumentException.
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error.
*/
public HolisticLandmarkerResult detect(
MPImage image, ImageProcessingOptions imageProcessingOptions) {
validateImageProcessingOptions(imageProcessingOptions);
return (HolisticLandmarkerResult) processImageData(image, imageProcessingOptions);
}
/**
* Performs holistic landmarks detection on the provided video frame with default image processing
* options, i.e. without any rotation applied. Only use this method when the {@link
* HolisticLandmarker} is created with {@link RunningMode.VIDEO}.
*
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
* must be monotonically increasing.
*
* <p>{@link HolisticLandmarker} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public HolisticLandmarkerResult detectForVideo(MPImage image, long timestampMs) {
return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
}
/**
* Performs holistic landmarks detection on the provided video frame. Only use this method when
* the {@link HolisticLandmarker} is created with {@link RunningMode.VIDEO}.
*
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
* must be monotonically increasing.
*
* <p>{@link HolisticLandmarker} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference. Note that region-of-interest is <b>not</b> supported
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
* this method throwing an IllegalArgumentException.
* @param timestampMs the input timestamp (in milliseconds).
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error.
*/
public HolisticLandmarkerResult detectForVideo(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
validateImageProcessingOptions(imageProcessingOptions);
return (HolisticLandmarkerResult) processVideoData(image, imageProcessingOptions, timestampMs);
}
/**
* Sends live image data to perform holistic landmarks detection with default image processing
* options, i.e. without any rotation applied, and the results will be available via the {@link
* ResultListener} provided in the {@link HolisticLandmarkerOptions}. Only use this method when
* the {@link HolisticLandmarker} is created with {@link RunningMode.LIVE_STREAM}.
*
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
* sent to the holistic landmarker. The input timestamps must be monotonically increasing.
*
* <p>{@link HolisticLandmarker} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public void detectAsync(MPImage image, long timestampMs) {
detectAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
}
/**
* Sends live image data to perform holistic landmarks detection, and the results will be
* available via the {@link ResultListener} provided in the {@link HolisticLandmarkerOptions}.
* Only use this method when the {@link HolisticLandmarker} is created with {@link
* RunningMode.LIVE_STREAM}.
*
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
* sent to the holistic landmarker. The input timestamps must be monotonically increasing.
*
* <p>{@link HolisticLandmarker} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference. Note that region-of-interest is <b>not</b> supported
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
* this method throwing an IllegalArgumentException.
* @param timestampMs the input timestamp (in milliseconds).
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error.
*/
public void detectAsync(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
validateImageProcessingOptions(imageProcessingOptions);
sendLiveStreamData(image, imageProcessingOptions, timestampMs);
}
/** Options for setting up a {@link HolisticLandmarker}. */
@AutoValue
public abstract static class HolisticLandmarkerOptions extends TaskOptions {
/** Builder for {@link HolisticLandmarkerOptions}. */
@AutoValue.Builder
public abstract static class Builder {
/** Sets the base options for the holistic landmarker task. */
public abstract Builder setBaseOptions(BaseOptions value);
/**
* Sets the running mode for the holistic landmarker task. Defaults to the image mode.
* Holistic landmarker has three modes:
*
* <ul>
* <li>IMAGE: The mode for detecting holistic landmarks on single image inputs.
* <li>VIDEO: The mode for detecting holistic landmarks on the decoded frames of a video.
* <li>LIVE_STREAM: The mode for detecting holistic landmarks on a live stream of input
* data, such as from a camera. In this mode, {@code setResultListener} must be called to
* set up a listener to receive the detection results asynchronously.
* </ul>
*/
public abstract Builder setRunningMode(RunningMode value);
/**
* Sets the minimum confidence score for the face detection to be considered successful. Defaults
* to 0.5.
*/
public abstract Builder setMinFaceDetectionConfidence(Float value);
/**
* The minimum threshold for the face suppression score in the face detection. Defaults to
* 0.3.
*/
public abstract Builder setMinFaceSuppressionThreshold(Float value);
/**
* Sets the minimum confidence score for the face landmark detection to be considered successful.
* Defaults to 0.5.
*/
public abstract Builder setMinFaceLandmarksConfidence(Float value);
/**
* The minimum confidence score for the pose detection to be considered successful. Defaults
* to 0.5.
*/
public abstract Builder setMinPoseDetectionConfidence(Float value);
/**
* The minimum threshold for the pose suppression score in the pose detection. Defaults to
* 0.3.
*/
public abstract Builder setMinPoseSuppressionThreshold(Float value);
/**
* The minimum confidence score for the pose landmarks detection to be considered successful.
* Defaults to 0.5.
*/
public abstract Builder setMinPoseLandmarksConfidence(Float value);
/**
* The minimum confidence score for the hand landmark detection to be considered successful.
* Defaults to 0.5.
*/
public abstract Builder setMinHandLandmarksConfidence(Float value);
/** Whether to output pose segmentation masks. Defaults to false. */
public abstract Builder setOutputPoseSegmentationMasks(Boolean value);
/** Whether to output face blendshapes. Defaults to false. */
public abstract Builder setOutputFaceBlendshapes(Boolean value);
/**
* Sets the result listener to receive the detection results asynchronously when the holistic
* landmarker is in the live stream mode.
*/
public abstract Builder setResultListener(
ResultListener<HolisticLandmarkerResult, MPImage> value);
/** Sets an optional error listener. */
public abstract Builder setErrorListener(ErrorListener value);
abstract HolisticLandmarkerOptions autoBuild();
/**
* Validates and builds the {@link HolisticLandmarkerOptions} instance.
*
* @throws IllegalArgumentException if the result listener and the running mode are not
* properly configured. The result listener should only be set when the holistic
* landmarker is in the live stream mode.
*/
public final HolisticLandmarkerOptions build() {
HolisticLandmarkerOptions options = autoBuild();
if (options.runningMode() == RunningMode.LIVE_STREAM) {
if (!options.resultListener().isPresent()) {
throw new IllegalArgumentException(
"The holistic landmarker is in the live stream mode, a user-defined result listener"
+ " must be provided in HolisticLandmarkerOptions.");
}
} else if (options.resultListener().isPresent()) {
throw new IllegalArgumentException(
"The holistic landmarker is in the image or the video mode, a user-defined result"
+ " listener shouldn't be provided in HolisticLandmarkerOptions.");
}
return options;
}
}
abstract BaseOptions baseOptions();
abstract RunningMode runningMode();
abstract Optional<Float> minFaceDetectionConfidence();
abstract Optional<Float> minFaceSuppressionThreshold();
abstract Optional<Float> minFaceLandmarksConfidence();
abstract Optional<Float> minPoseDetectionConfidence();
abstract Optional<Float> minPoseSuppressionThreshold();
abstract Optional<Float> minPoseLandmarksConfidence();
abstract Optional<Float> minHandLandmarksConfidence();
abstract Boolean outputFaceBlendshapes();
abstract Boolean outputPoseSegmentationMasks();
abstract Optional<ResultListener<HolisticLandmarkerResult, MPImage>> resultListener();
abstract Optional<ErrorListener> errorListener();
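/**
 * Returns a builder with default option values: image running mode, 0.5 detection/landmark
 * confidence thresholds, 0.3 suppression thresholds, and no face blendshapes or pose
 * segmentation masks.
 */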
public static Builder builder() {
return new AutoValue_HolisticLandmarker_HolisticLandmarkerOptions.Builder()
.setRunningMode(RunningMode.IMAGE)
.setMinFaceDetectionConfidence(DEFAULT_PRESENCE_THRESHOLD)
.setMinFaceSuppressionThreshold(DEFAULT_SUPPRESSION_THRESHOLD)
.setMinFaceLandmarksConfidence(DEFAULT_PRESENCE_THRESHOLD)
.setMinPoseDetectionConfidence(DEFAULT_PRESENCE_THRESHOLD)
.setMinPoseSuppressionThreshold(DEFAULT_SUPPRESSION_THRESHOLD)
.setMinPoseLandmarksConfidence(DEFAULT_PRESENCE_THRESHOLD)
.setMinHandLandmarksConfidence(DEFAULT_PRESENCE_THRESHOLD)
.setOutputFaceBlendshapes(DEFAULT_OUTPUT_FACE_BLENDSHAPES)
.setOutputPoseSegmentationMasks(DEFAULT_OUTPUT_SEGMENTATION_MASKS);
}
/** Converts a {@link HolisticLandmarkerOptions} to a {@link Any} protobuf message. */
@Override
public Any convertToAnyProto() {
HolisticLandmarkerGraphOptions.Builder holisticLandmarkerGraphOptions =
HolisticLandmarkerGraphOptions.newBuilder()
.setBaseOptions(
BaseOptionsProto.BaseOptions.newBuilder()
.setUseStreamMode(runningMode() != RunningMode.IMAGE)
.mergeFrom(convertBaseOptionsToProto(baseOptions()))
.build());
HandLandmarksDetectorGraphOptions.Builder handLandmarksDetectorGraphOptions =
HandLandmarksDetectorGraphOptions.newBuilder();
FaceDetectorGraphOptions.Builder faceDetectorGraphOptions =
FaceDetectorGraphOptions.newBuilder();
FaceLandmarksDetectorGraphOptions.Builder faceLandmarksDetectorGraphOptions =
FaceLandmarksDetectorGraphOptions.newBuilder();
PoseDetectorGraphOptions.Builder poseDetectorGraphOptions =
PoseDetectorGraphOptions.newBuilder();
PoseLandmarksDetectorGraphOptions.Builder poseLandmarkerGraphOptions =
PoseLandmarksDetectorGraphOptions.newBuilder();
// Configure hand detector options.
minHandLandmarksConfidence()
.ifPresent(handLandmarksDetectorGraphOptions::setMinDetectionConfidence);
// Configure pose detector options.
minPoseDetectionConfidence().ifPresent(poseDetectorGraphOptions::setMinDetectionConfidence);
minPoseSuppressionThreshold().ifPresent(poseDetectorGraphOptions::setMinSuppressionThreshold);
minPoseLandmarksConfidence().ifPresent(poseLandmarkerGraphOptions::setMinDetectionConfidence);
// Configure face detector options.
minFaceDetectionConfidence().ifPresent(faceDetectorGraphOptions::setMinDetectionConfidence);
minFaceSuppressionThreshold().ifPresent(faceDetectorGraphOptions::setMinSuppressionThreshold);
minFaceLandmarksConfidence()
.ifPresent(faceLandmarksDetectorGraphOptions::setMinDetectionConfidence);
holisticLandmarkerGraphOptions
.setHandLandmarksDetectorGraphOptions(handLandmarksDetectorGraphOptions.build())
.setFaceDetectorGraphOptions(faceDetectorGraphOptions.build())
.setFaceLandmarksDetectorGraphOptions(faceLandmarksDetectorGraphOptions.build())
.setPoseDetectorGraphOptions(poseDetectorGraphOptions.build())
.setPoseLandmarksDetectorGraphOptions(poseLandmarkerGraphOptions.build());
return Any.newBuilder()
.setTypeUrl(
"type.googleapis.com/mediapipe.tasks.vision.holistic_landmarker.proto.HolisticLandmarkerGraphOptions")
.setValue(holisticLandmarkerGraphOptions.build().toByteString())
.build();
}
}
/**
* Validates that the provided {@link ImageProcessingOptions} doesn't contain a
* region-of-interest.
*/
private static void validateImageProcessingOptions(
ImageProcessingOptions imageProcessingOptions) {
if (imageProcessingOptions.regionOfInterest().isPresent()) {
throw new IllegalArgumentException("HolisticLandmarker doesn't support region-of-interest.");
}
}
private static MPImage getSegmentationMask(List<Packet> packets, int packetIndex) {
int width = PacketGetter.getImageWidth(packets.get(packetIndex));
int height = PacketGetter.getImageHeight(packets.get(packetIndex));
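// Single-channel float32 mask: 4 bytes per pixel.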
ByteBuffer buffer = ByteBuffer.allocateDirect(width * height * 4);
if (!PacketGetter.getImageData(packets.get(packetIndex), buffer)) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(),
"There was an error getting the sefmentation mask.");
}
ByteBufferImageBuilder builder =
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1);
return builder.build();
}
}

View File

@ -15,7 +15,6 @@
package com.google.mediapipe.tasks.vision.imagesegmenter;
import android.content.Context;
import android.util.Log;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;

View File

@ -0,0 +1,24 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.vision.holisticlandmarkertest"
android:versionCode="1"
android:versionName="1.0" >
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
<application
android:label="holisticlandmarkertest"
android:name="android.support.multidex.MultiDexApplication"
android:taskAffinity="">
<uses-library android:name="android.test.runner" />
</application>
<instrumentation
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
android:targetPackage="com.google.mediapipe.tasks.vision.holisticlandmarkertest" />
</manifest>

View File

@ -0,0 +1,19 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# TODO: Enable this in OSS

View File

@ -0,0 +1,512 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.holisticlandmarker;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;
import android.content.res.AssetManager;
import android.graphics.BitmapFactory;
import android.graphics.RectF;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.common.truth.Correspondence;
import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.NormalizedLandmark;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.holisticlandmarker.HolisticLandmarker.HolisticLandmarkerOptions;
import com.google.mediapipe.tasks.vision.holisticlandmarker.HolisticResultProto.HolisticResult;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.Optional;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.junit.runners.Suite.SuiteClasses;
/** Test for {@link HolisticLandmarker}. */
@RunWith(Suite.class)
@SuiteClasses({HolisticLandmarkerTest.General.class, HolisticLandmarkerTest.RunningModeTest.class})
public class HolisticLandmarkerTest {
private static final String HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE = "holistic_landmarker.task";
private static final String POSE_IMAGE = "male_full_height_hands.jpg";
private static final String CAT_IMAGE = "cat.jpg";
private static final String HOLISTIC_RESULT = "male_full_height_hands_result_cpu.pb";
private static final String TAG = "Holistic Landmarker Test";
private static final float FACE_LANDMARKS_ERROR_TOLERANCE = 0.03f;
private static final float FACE_BLENDSHAPES_ERROR_TOLERANCE = 0.13f;
private static final MPImage PLACEHOLDER_MASK =
new ByteBufferImageBuilder(
ByteBuffer.allocate(0), /* width= */ 0, /* height= */ 0, MPImage.IMAGE_FORMAT_VEC32F1)
.build();
private static final int IMAGE_WIDTH = 638;
private static final int IMAGE_HEIGHT = 1000;
private static final Correspondence<NormalizedLandmark, NormalizedLandmark> VALIDATE_LANDMARKS =
Correspondence.from(
(Correspondence.BinaryPredicate<NormalizedLandmark, NormalizedLandmark>)
(actual, expected) -> {
return Correspondence.tolerance(FACE_LANDMARKS_ERROR_TOLERANCE)
.compare(actual.x(), expected.x())
&& Correspondence.tolerance(FACE_LANDMARKS_ERROR_TOLERANCE)
.compare(actual.y(), expected.y());
},
"landmarks approximately equal to");
private static final Correspondence<Category, Category> VALIDATE_BLENDSHAPES =
Correspondence.from(
(Correspondence.BinaryPredicate<Category, Category>)
(actual, expected) ->
Correspondence.tolerance(FACE_BLENDSHAPES_ERROR_TOLERANCE)
.compare(actual.score(), expected.score())
&& actual.index() == expected.index()
&& actual.categoryName().equals(expected.categoryName()),
"face blendshapes approximately equal to");
@RunWith(AndroidJUnit4.class)
public static final class General extends HolisticLandmarkerTest {
@Test
public void detect_successWithValidModels() throws Exception {
HolisticLandmarkerOptions options =
HolisticLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
.build())
.build();
HolisticLandmarker holisticLandmarker =
HolisticLandmarker.createFromOptions(
ApplicationProvider.getApplicationContext(), options);
HolisticLandmarkerResult actualResult =
holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE));
HolisticLandmarkerResult expectedResult =
getExpectedHolisticLandmarkerResult(
HOLISTIC_RESULT, /* hasFaceBlendshapes= */ false, /* hasSegmentationMask= */ false);
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
}
@Test
public void detect_successWithBlendshapes() throws Exception {
HolisticLandmarkerOptions options =
HolisticLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
.build())
.setOutputFaceBlendshapes(true)
.build();
HolisticLandmarker holisticLandmarker =
HolisticLandmarker.createFromOptions(
ApplicationProvider.getApplicationContext(), options);
HolisticLandmarkerResult actualResult =
holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE));
HolisticLandmarkerResult expectedResult =
getExpectedHolisticLandmarkerResult(
HOLISTIC_RESULT, /* hasFaceBlendshapes= */ true, /* hasSegmentationMask= */ false);
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
}
@Test
public void detect_successWithSegmentationMasks() throws Exception {
HolisticLandmarkerOptions options =
HolisticLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
.build())
.setOutputPoseSegmentationMasks(true)
.build();
HolisticLandmarker holisticLandmarker =
HolisticLandmarker.createFromOptions(
ApplicationProvider.getApplicationContext(), options);
HolisticLandmarkerResult actualResult =
holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE));
HolisticLandmarkerResult expectedResult =
getExpectedHolisticLandmarkerResult(
HOLISTIC_RESULT, /* hasFaceBlendshapes= */ false, /* hasSegmentationMask= */ true);
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
}
@Test
public void detect_successWithEmptyResult() throws Exception {
HolisticLandmarkerOptions options =
HolisticLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
.build())
.build();
HolisticLandmarker holisticLandmarker =
HolisticLandmarker.createFromOptions(
ApplicationProvider.getApplicationContext(), options);
HolisticLandmarkerResult actualResult =
holisticLandmarker.detect(getImageFromAsset(CAT_IMAGE));
assertThat(actualResult.faceLandmarks()).isEmpty();
}
@Test
public void detect_failsWithRegionOfInterest() throws Exception {
HolisticLandmarkerOptions options =
HolisticLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
.build())
.build();
HolisticLandmarker holisticLandmarker =
HolisticLandmarker.createFromOptions(
ApplicationProvider.getApplicationContext(), options);
ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build();
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE), imageProcessingOptions));
assertThat(exception)
.hasMessageThat()
.contains("HolisticLandmarker doesn't support region-of-interest");
}
}
@RunWith(AndroidJUnit4.class)
public static final class RunningModeTest extends HolisticLandmarkerTest {
private void assertCreationFailsWithResultListenerInNonLiveStreamMode(RunningMode runningMode)
throws Exception {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
HolisticLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
.build())
.setRunningMode(runningMode)
.setResultListener((result, inputImage) -> {})
.build());
assertThat(exception)
.hasMessageThat()
.contains("a user-defined result listener shouldn't be provided");
}
@Test
public void create_failsWithIllegalResultListenerInVideoMode() throws Exception {
assertCreationFailsWithResultListenerInNonLiveStreamMode(RunningMode.VIDEO);
}
@Test
public void create_failsWithIllegalResultListenerInImageMode() throws Exception {
assertCreationFailsWithResultListenerInNonLiveStreamMode(RunningMode.IMAGE);
}
@Test
public void create_failsWithMissingResultListenerInLiveStreamMode() throws Exception {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
HolisticLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
.build())
.setRunningMode(RunningMode.LIVE_STREAM)
.build());
assertThat(exception)
.hasMessageThat()
.contains("a user-defined result listener must be provided");
}
@Test
public void detect_failsWithCallingWrongApiInImageMode() throws Exception {
HolisticLandmarkerOptions options =
HolisticLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
.build())
.setRunningMode(RunningMode.IMAGE)
.build();
HolisticLandmarker holisticLandmarker =
HolisticLandmarker.createFromOptions(
ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() ->
holisticLandmarker.detectForVideo(
getImageFromAsset(POSE_IMAGE), /* timestampMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
exception =
assertThrows(
MediaPipeException.class,
() ->
holisticLandmarker.detectAsync(
getImageFromAsset(POSE_IMAGE), /* timestampMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@Test
public void detect_failsWithCallingWrongApiInVideoMode() throws Exception {
HolisticLandmarkerOptions options =
HolisticLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
.build())
.setRunningMode(RunningMode.VIDEO)
.build();
HolisticLandmarker holisticLandmarker =
HolisticLandmarker.createFromOptions(
ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE)));
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
exception =
assertThrows(
MediaPipeException.class,
() ->
holisticLandmarker.detectAsync(
getImageFromAsset(POSE_IMAGE), /* timestampMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@Test
public void detect_failsWithCallingWrongApiInLiveStreamMode() throws Exception {
HolisticLandmarkerOptions options =
HolisticLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
.build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener((result, inputImage) -> {})
.build();
HolisticLandmarker holisticLandmarker =
HolisticLandmarker.createFromOptions(
ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE)));
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
exception =
assertThrows(
MediaPipeException.class,
() ->
holisticLandmarker.detectForVideo(
getImageFromAsset(POSE_IMAGE), /* timestampMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
}
@Test
public void detect_successWithImageMode() throws Exception {
HolisticLandmarkerOptions options =
HolisticLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
.build())
.setRunningMode(RunningMode.IMAGE)
.build();
HolisticLandmarker holisticLandmarker =
HolisticLandmarker.createFromOptions(
ApplicationProvider.getApplicationContext(), options);
HolisticLandmarkerResult actualResult =
holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE));
HolisticLandmarkerResult expectedResult =
getExpectedHolisticLandmarkerResult(
HOLISTIC_RESULT, /* hasFaceBlendshapes= */ false, /* hasSegmentationMask= */ false);
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
}
@Test
public void detect_successWithVideoMode() throws Exception {
HolisticLandmarkerOptions options =
HolisticLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
.build())
.setRunningMode(RunningMode.VIDEO)
.build();
HolisticLandmarker holisticLandmarker =
HolisticLandmarker.createFromOptions(
ApplicationProvider.getApplicationContext(), options);
HolisticLandmarkerResult expectedResult =
getExpectedHolisticLandmarkerResult(
HOLISTIC_RESULT, /* hasFaceBlendshapes= */ false, /* hasSegmentationMask= */ false);
for (int i = 0; i < 3; i++) {
HolisticLandmarkerResult actualResult =
holisticLandmarker.detectForVideo(getImageFromAsset(POSE_IMAGE), /* timestampMs= */ i);
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
}
}
@Test
public void detect_failsWithOutOfOrderInputTimestamps() throws Exception {
MPImage image = getImageFromAsset(POSE_IMAGE);
HolisticLandmarkerOptions options =
HolisticLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
.build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener((actualResult, inputImage) -> {})
.build();
try (HolisticLandmarker holisticLandmarker =
HolisticLandmarker.createFromOptions(
ApplicationProvider.getApplicationContext(), options)) {
holisticLandmarker.detectAsync(image, /* timestampMs= */ 1);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> holisticLandmarker.detectAsync(image, /* timestampMs= */ 0));
assertThat(exception)
.hasMessageThat()
.contains("having a smaller timestamp than the processed timestamp");
}
}
@Test
public void detect_successWithLiveStreamMode() throws Exception {
MPImage image = getImageFromAsset(POSE_IMAGE);
HolisticLandmarkerResult expectedResult =
getExpectedHolisticLandmarkerResult(
HOLISTIC_RESULT, /* hasFaceBlendshapes= */ false, /* hasSegmentationMask= */ false);
HolisticLandmarkerOptions options =
HolisticLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
.build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener(
(actualResult, inputImage) -> {
assertActualResultApproximatelyEqualsToExpectedResult(
actualResult, expectedResult);
assertImageSizeIsExpected(inputImage);
})
.build();
try (HolisticLandmarker holisticLandmarker =
HolisticLandmarker.createFromOptions(
ApplicationProvider.getApplicationContext(), options)) {
for (int i = 0; i < 3; i++) {
holisticLandmarker.detectAsync(image, /* timestampMs= */ i);
}
}
}
}
private static MPImage getImageFromAsset(String filePath) throws Exception {
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
try (InputStream istr = assetManager.open(filePath)) {
return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
}
}
private static HolisticLandmarkerResult getExpectedHolisticLandmarkerResult(
String resultPath, boolean hasFaceBlendshapes, boolean hasSegmentationMask) throws Exception {
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
HolisticResult holisticResult = HolisticResult.parseFrom(assetManager.open(resultPath));
Optional<ClassificationList> blendshapes =
hasFaceBlendshapes
? Optional.of(holisticResult.getFaceBlendshapes())
: Optional.<ClassificationList>empty();
Optional<MPImage> segmentationMask =
hasSegmentationMask ? Optional.of(PLACEHOLDER_MASK) : Optional.<MPImage>empty();
return HolisticLandmarkerResult.create(
holisticResult.getFaceLandmarks(),
blendshapes,
holisticResult.getPoseLandmarks(),
LandmarkList.getDefaultInstance(),
segmentationMask,
holisticResult.getLeftHandLandmarks(),
LandmarkList.getDefaultInstance(),
holisticResult.getRightHandLandmarks(),
LandmarkList.getDefaultInstance(),
/* timestampMs= */ 0);
}
private static void assertActualResultApproximatelyEqualsToExpectedResult(
HolisticLandmarkerResult actualResult, HolisticLandmarkerResult expectedResult) {
// Expects both results to contain the same number of landmarks and the same optional outputs.
assertThat(actualResult.faceLandmarks()).hasSize(expectedResult.faceLandmarks().size());
assertThat(actualResult.faceBlendshapes().isPresent())
.isEqualTo(expectedResult.faceBlendshapes().isPresent());
assertThat(actualResult.poseLandmarks()).hasSize(expectedResult.poseLandmarks().size());
assertThat(actualResult.segmentationMask().isPresent())
.isEqualTo(expectedResult.segmentationMask().isPresent());
assertThat(actualResult.leftHandLandmarks()).hasSize(expectedResult.leftHandLandmarks().size());
assertThat(actualResult.rightHandLandmarks())
.hasSize(expectedResult.rightHandLandmarks().size());
// Actual face landmarks match expected face landmarks.
assertThat(actualResult.faceLandmarks())
.comparingElementsUsing(VALIDATE_LANDMARKS)
.containsExactlyElementsIn(expectedResult.faceLandmarks());
// Actual face blendshapes match expected face blendshapes.
if (actualResult.faceBlendshapes().isPresent()) {
assertThat(actualResult.faceBlendshapes().get())
.comparingElementsUsing(VALIDATE_BLENDSHAPES)
.containsExactlyElementsIn(expectedResult.faceBlendshapes().get());
}
// Actual pose landmarks match expected pose landmarks.
assertThat(actualResult.poseLandmarks())
.comparingElementsUsing(VALIDATE_LANDMARKS)
.containsExactlyElementsIn(expectedResult.poseLandmarks());
if (actualResult.segmentationMask().isPresent()) {
assertImageSizeIsExpected(actualResult.segmentationMask().get());
}
// Actual left hand landmarks match expected left hand landmarks.
assertThat(actualResult.leftHandLandmarks())
.comparingElementsUsing(VALIDATE_LANDMARKS)
.containsExactlyElementsIn(expectedResult.leftHandLandmarks());
// Actual right hand landmarks match expected right hand landmarks.
assertThat(actualResult.rightHandLandmarks())
.comparingElementsUsing(VALIDATE_LANDMARKS)
.containsExactlyElementsIn(expectedResult.rightHandLandmarks());
}
private static void assertImageSizeIsExpected(MPImage inputImage) {
assertThat(inputImage).isNotNull();
assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH);
assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT);
}
}
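
Since this test exercises the Java surface end to end, a compact TypeScript counterpart may help web readers. This is purely an illustrative sketch: the option names and the local holistic_landmarker.task path follow the pattern of the other tasks-vision snippets and are assumptions, not code from this commit.

import { FilesetResolver, HolisticLandmarker } from "@mediapipe/tasks-vision";

// Resolve the WASM assets backing the vision tasks.
const vision = await FilesetResolver.forVisionTasks(
  "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision/wasm"
);

// Create the landmarker from a bundled model (path is illustrative).
const holisticLandmarker = await HolisticLandmarker.createFromOptions(vision, {
  baseOptions: { modelAssetPath: "holistic_landmarker.task" },
  runningMode: "IMAGE",
});

const image = document.getElementById("image") as HTMLImageElement;
const result = holisticLandmarker.detect(image);

// Mirrors the Java result asserted above: landmarks grouped per body part.
console.log(result.faceLandmarks.length, result.poseLandmarks.length);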

View File

@ -225,6 +225,7 @@ filegroup(
"hand_detector_result_one_hand.pbtxt",
"hand_detector_result_one_hand_rotated.pbtxt",
"hand_detector_result_two_hands.pbtxt",
"male_full_height_hands_result_cpu.pbtxt",
"pointing_up_landmarks.pbtxt",
"pointing_up_rotated_landmarks.pbtxt",
"portrait_expected_detection.pbtxt",

View File

@ -66,7 +66,7 @@ const vision = await FilesetResolver.forVisionTasks(
"https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision/wasm"
);
const gestureRecognizer = await GestureRecognizer.createFromModelPath(vision,
"hhttps://storage.googleapis.com/mediapipe-models/gesture_recognizer/gesture_recognizer/float16/1/gesture_recognizer.task"
"https://storage.googleapis.com/mediapipe-models/gesture_recognizer/gesture_recognizer/float16/1/gesture_recognizer.task"
);
const image = document.getElementById("image") as HTMLImageElement;
const recognitions = gestureRecognizer.recognize(image);
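
For completeness, a short sketch of how the `recognitions` object from the snippet above might be consumed; the `gestures` field layout (one ranked `Category` list per detected hand) is taken from the published tasks-vision typings and should be treated as an assumption here.

// Each detected hand yields a ranked list of gesture categories; the first
// entry is the top-scoring gesture for that hand (field shape assumed from
// the tasks-vision typings, not from this change).
for (const handGestures of recognitions.gestures) {
  const topGesture = handGestures[0];
  console.log(`${topGesture.categoryName} (score: ${topGesture.score.toFixed(2)})`);
}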

View File

@ -391,6 +391,7 @@ export class DrawingUtils {
drawCategoryMask(
mask: MPMask, categoryToColorMap: RGBAColor[],
background?: RGBAColor|ImageSource): void;
/** @export */
drawCategoryMask(
mask: MPMask, categoryToColorMap: CategoryToColorMap,
background: RGBAColor|ImageSource = [0, 0, 0, 255]): void {
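
To illustrate the newly exported `drawCategoryMask` overloads, a minimal usage sketch; the `MPMask` input is declared rather than produced, and rendering category 0 as transparent background is an assumption made for the example.

import { DrawingUtils, MPMask } from "@mediapipe/tasks-vision";

// An MPMask produced elsewhere (e.g. by a segmentation task); declared here
// only to keep the sketch self-contained.
declare const categoryMask: MPMask;

const canvas = document.getElementById("overlay") as HTMLCanvasElement;
const ctx = canvas.getContext("2d")!;
const drawingUtils = new DrawingUtils(ctx);

// One RGBA color per category index, matching the RGBAColor[] overload shown
// above; category 0 is left fully transparent so the frame shows through.
drawingUtils.drawCategoryMask(categoryMask, [
  [0, 0, 0, 0],       // category 0: background
  [255, 0, 0, 255],   // category 1: highlighted in red
]);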
@ -480,6 +481,7 @@ export class DrawingUtils {
* frame, you can reduce the cost of re-uploading these images by passing a
* `HTMLCanvasElement` instead.
*
* @export
* @param mask A confidence mask that was returned from a segmentation task.
* @param defaultTexture An image or a four-channel color that will be used
* when confidence values are low.