Merge branch 'master' into c-landmarker-apis
commit 6ed5e3d0df
@@ -7,4 +7,4 @@ tensorflow-addons
 tensorflow-datasets
 tensorflow-hub
 tensorflow-text
-tf-models-official>=2.13.1
+tf-models-official>=2.13.2
mediapipe/python/solutions/drawing_utils.py
@@ -125,7 +125,8 @@ def draw_landmarks(
         color=RED_COLOR),
     connection_drawing_spec: Union[DrawingSpec,
                                    Mapping[Tuple[int, int],
-                                           DrawingSpec]] = DrawingSpec()):
+                                           DrawingSpec]] = DrawingSpec(),
+    is_drawing_landmarks: bool = True):
   """Draws the landmarks and the connections on the image.

   Args:
@@ -142,6 +143,8 @@ def draw_landmarks(
       connections to the DrawingSpecs that specifies the connections' drawing
       settings such as color and line thickness. If this argument is explicitly
       set to None, no landmark connections will be drawn.
+    is_drawing_landmarks: Whether to draw the landmarks. If set to False,
+      landmark drawing is skipped and only the contours are drawn.

   Raises:
     ValueError: If one of the followings:
@@ -181,7 +184,7 @@ def draw_landmarks(
                    drawing_spec.thickness)
   # Draws landmark points after finishing the connection lines, which is
   # aesthetically better.
-  if landmark_drawing_spec:
+  if is_drawing_landmarks and landmark_drawing_spec:
     for idx, landmark_px in idx_to_coordinates.items():
       drawing_spec = landmark_drawing_spec[idx] if isinstance(
           landmark_drawing_spec, Mapping) else landmark_drawing_spec
mediapipe/tasks/c/components/containers/BUILD
@@ -70,6 +70,60 @@ cc_test(
     ],
 )
 
+cc_library(
+    name = "rect",
+    hdrs = ["rect.h"],
+)
+
+cc_library(
+    name = "rect_converter",
+    srcs = ["rect_converter.cc"],
+    hdrs = ["rect_converter.h"],
+    deps = [
+        ":rect",
+        "//mediapipe/tasks/cc/components/containers:rect",
+    ],
+)
+
+cc_test(
+    name = "rect_converter_test",
+    srcs = ["rect_converter_test.cc"],
+    deps = [
+        ":rect",
+        ":rect_converter",
+        "//mediapipe/framework/port:gtest",
+        "//mediapipe/tasks/cc/components/containers:rect",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "keypoint",
+    hdrs = ["keypoint.h"],
+)
+
+cc_library(
+    name = "keypoint_converter",
+    srcs = ["keypoint_converter.cc"],
+    hdrs = ["keypoint_converter.h"],
+    deps = [
+        ":keypoint",
+        "//mediapipe/tasks/cc/components/containers:keypoint",
+    ],
+)
+
+cc_test(
+    name = "keypoint_converter_test",
+    srcs = ["keypoint_converter_test.cc"],
+    deps = [
+        ":keypoint",
+        ":keypoint_converter",
+        "//mediapipe/framework/port:gtest",
+        "//mediapipe/tasks/cc/components/containers:keypoint",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
 cc_library(
     name = "classification_result",
     hdrs = ["classification_result.h"],
@@ -99,6 +153,40 @@ cc_test(
     ],
 )
 
+cc_library(
+    name = "detection_result",
+    hdrs = ["detection_result.h"],
+    deps = [":rect"],
+)
+
+cc_library(
+    name = "detection_result_converter",
+    srcs = ["detection_result_converter.cc"],
+    hdrs = ["detection_result_converter.h"],
+    deps = [
+        ":category",
+        ":category_converter",
+        ":detection_result",
+        ":keypoint",
+        ":keypoint_converter",
+        ":rect",
+        ":rect_converter",
+        "//mediapipe/tasks/cc/components/containers:detection_result",
+    ],
+)
+
+cc_test(
+    name = "detection_result_converter_test",
+    srcs = ["detection_result_converter_test.cc"],
+    deps = [
+        ":detection_result",
+        ":detection_result_converter",
+        "//mediapipe/framework/port:gtest",
+        "//mediapipe/tasks/cc/components/containers:detection_result",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
 cc_library(
     name = "embedding_result",
     hdrs = ["embedding_result.h"],
mediapipe/tasks/c/components/containers/detection_result.h (new file, 63 lines)
@@ -0,0 +1,63 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_

#include <stdint.h>

#include "mediapipe/tasks/c/components/containers/rect.h"

#ifdef __cplusplus
extern "C" {
#endif

// Detection for a single bounding box.
struct Detection {
  // An array of detected categories.
  struct Category* categories;

  // The number of elements in the categories array.
  uint32_t categories_count;

  // The bounding box location.
  struct MPRect bounding_box;

  // Optional list of keypoints associated with the detection. Keypoints
  // represent interesting points related to the detection. For example, the
  // keypoints of a face detection model represent the eyes, ears and mouth.
  // In template matching detection, e.g. KNIFT, they can instead represent
  // the feature points for template matching.
  // `nullptr` if keypoints are not present.
  struct NormalizedKeypoint* keypoints;

  // The number of elements in the keypoints array. 0 if keypoints do not
  // exist.
  uint32_t keypoints_count;
};

// Detection results of a model.
struct DetectionResult {
  // An array of Detections.
  struct Detection* detections;

  // The number of detections in the detections array.
  uint32_t detections_count;
};

#ifdef __cplusplus
}  // extern C
#endif

#endif  // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_
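The structs above are plain data: every array carries an explicit `*_count`, and the optional `keypoints` field is signalled by a null pointer. A minimal consumption sketch follows; `PrintDetections` is a hypothetical helper, not part of this change, and it assumes the `Category` fields from `category.h`, which this change already depends on.

// Minimal sketch (not part of the change): walks a populated DetectionResult.
#include <cstdint>
#include <cstdio>

#include "mediapipe/tasks/c/components/containers/category.h"
#include "mediapipe/tasks/c/components/containers/detection_result.h"

void PrintDetections(const DetectionResult* result) {
  for (uint32_t i = 0; i < result->detections_count; ++i) {
    const Detection& detection = result->detections[i];
    std::printf("box (l, t, r, b): %d %d %d %d\n",
                detection.bounding_box.left, detection.bounding_box.top,
                detection.bounding_box.right, detection.bounding_box.bottom);
    for (uint32_t j = 0; j < detection.categories_count; ++j) {
      std::printf("  category: %s, score: %.3f\n",
                  detection.categories[j].category_name,
                  detection.categories[j].score);
    }
    // keypoints is nullptr (and keypoints_count is 0) when absent.
    for (uint32_t k = 0; detection.keypoints && k < detection.keypoints_count;
         ++k) {
      std::printf("  keypoint: (%.3f, %.3f)\n", detection.keypoints[k].x,
                  detection.keypoints[k].y);
    }
  }
}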
mediapipe/tasks/c/components/containers/detection_result_converter.cc (new file, 86 lines)
@@ -0,0 +1,86 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/c/components/containers/detection_result_converter.h"

#include <cstdlib>

#include "mediapipe/tasks/c/components/containers/category.h"
#include "mediapipe/tasks/c/components/containers/category_converter.h"
#include "mediapipe/tasks/c/components/containers/detection_result.h"
#include "mediapipe/tasks/c/components/containers/keypoint.h"
#include "mediapipe/tasks/c/components/containers/keypoint_converter.h"
#include "mediapipe/tasks/c/components/containers/rect_converter.h"
#include "mediapipe/tasks/cc/components/containers/detection_result.h"

namespace mediapipe::tasks::c::components::containers {

void CppConvertToDetection(
    const mediapipe::tasks::components::containers::Detection& in,
    ::Detection* out) {
  out->categories_count = in.categories.size();
  out->categories = new Category[out->categories_count];
  for (size_t i = 0; i < out->categories_count; ++i) {
    CppConvertToCategory(in.categories[i], &out->categories[i]);
  }

  CppConvertToRect(in.bounding_box, &out->bounding_box);

  if (in.keypoints.has_value()) {
    auto& keypoints = in.keypoints.value();
    out->keypoints_count = keypoints.size();
    out->keypoints = new NormalizedKeypoint[out->keypoints_count];
    for (size_t i = 0; i < out->keypoints_count; ++i) {
      CppConvertToNormalizedKeypoint(keypoints[i], &out->keypoints[i]);
    }
  } else {
    out->keypoints = nullptr;
    out->keypoints_count = 0;
  }
}

void CppConvertToDetectionResult(
    const mediapipe::tasks::components::containers::DetectionResult& in,
    ::DetectionResult* out) {
  out->detections_count = in.detections.size();
  out->detections = new ::Detection[out->detections_count];
  for (size_t i = 0; i < out->detections_count; ++i) {
    CppConvertToDetection(in.detections[i], &out->detections[i]);
  }
}

// Functions to free the memory of C structures.
void CppCloseDetection(::Detection* in) {
  for (size_t i = 0; i < in->categories_count; ++i) {
    CppCloseCategory(&in->categories[i]);
  }
  delete[] in->categories;
  in->categories = nullptr;
  for (size_t i = 0; i < in->keypoints_count; ++i) {
    CppCloseNormalizedKeypoint(&in->keypoints[i]);
  }
  delete[] in->keypoints;
  in->keypoints = nullptr;
}

void CppCloseDetectionResult(::DetectionResult* in) {
  for (size_t i = 0; i < in->detections_count; ++i) {
    CppCloseDetection(&in->detections[i]);
  }
  delete[] in->detections;
  in->detections = nullptr;
}

}  // namespace mediapipe::tasks::c::components::containers
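Each CppConvertTo* above allocates with new[] (and, for keypoint labels, strdup) and expects the matching CppClose* to release it. A caller-side RAII wrapper is one way to keep the pair together; this is a sketch under the assumption that such a guard lives with the caller, and `ScopedDetectionResult` is a hypothetical name, not part of this change.

// Hypothetical caller-side RAII guard pairing CppConvertToDetectionResult
// with CppCloseDetectionResult; not part of this change.
#include "mediapipe/tasks/c/components/containers/detection_result.h"
#include "mediapipe/tasks/c/components/containers/detection_result_converter.h"
#include "mediapipe/tasks/cc/components/containers/detection_result.h"

namespace mediapipe::tasks::c::components::containers {

class ScopedDetectionResult {
 public:
  explicit ScopedDetectionResult(
      const mediapipe::tasks::components::containers::DetectionResult& cpp) {
    CppConvertToDetectionResult(cpp, &c_result_);
  }
  ~ScopedDetectionResult() { CppCloseDetectionResult(&c_result_); }
  // Non-copyable: the destructor releases the converter's allocations.
  ScopedDetectionResult(const ScopedDetectionResult&) = delete;
  ScopedDetectionResult& operator=(const ScopedDetectionResult&) = delete;

  const ::DetectionResult& get() const { return c_result_; }

 private:
  ::DetectionResult c_result_{};
};

}  // namespace mediapipe::tasks::c::components::containers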
mediapipe/tasks/c/components/containers/detection_result_converter.h (new file, 38 lines)
@@ -0,0 +1,38 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_CONVERTER_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_CONVERTER_H_

#include "mediapipe/tasks/c/components/containers/detection_result.h"
#include "mediapipe/tasks/cc/components/containers/detection_result.h"

namespace mediapipe::tasks::c::components::containers {

void CppConvertToDetection(
    const mediapipe::tasks::components::containers::Detection& in,
    Detection* out);

void CppConvertToDetectionResult(
    const mediapipe::tasks::components::containers::DetectionResult& in,
    DetectionResult* out);

void CppCloseDetection(Detection* in);

void CppCloseDetectionResult(DetectionResult* in);

}  // namespace mediapipe::tasks::c::components::containers

#endif  // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_CONVERTER_H_
mediapipe/tasks/c/components/containers/detection_result_converter_test.cc (new file, 74 lines)
@@ -0,0 +1,74 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/c/components/containers/detection_result_converter.h"

#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/components/containers/detection_result.h"
#include "mediapipe/tasks/cc/components/containers/detection_result.h"

namespace mediapipe::tasks::c::components::containers {

TEST(DetectionResultConverterTest, ConvertsDetectionResultCustomCategory) {
  mediapipe::tasks::components::containers::DetectionResult
      cpp_detection_result = {/* detections= */ {
          {/* categories= */ {{/* index= */ 1, /* score= */ 0.1,
                               /* category_name= */ "cat",
                               /* display_name= */ "cat"}},
           /* bounding_box= */ {10, 11, 12, 13},
           {/* keypoints */ {{0.1, 0.1, "foo", 0.5}}}}}};

  DetectionResult c_detection_result;
  CppConvertToDetectionResult(cpp_detection_result, &c_detection_result);
  EXPECT_NE(c_detection_result.detections, nullptr);
  EXPECT_EQ(c_detection_result.detections_count, 1);
  EXPECT_NE(c_detection_result.detections[0].categories, nullptr);
  EXPECT_EQ(c_detection_result.detections[0].categories_count, 1);
  EXPECT_EQ(c_detection_result.detections[0].bounding_box.left, 10);
  EXPECT_EQ(c_detection_result.detections[0].bounding_box.top, 11);
  EXPECT_EQ(c_detection_result.detections[0].bounding_box.right, 12);
  EXPECT_EQ(c_detection_result.detections[0].bounding_box.bottom, 13);
  EXPECT_NE(c_detection_result.detections[0].keypoints, nullptr);

  CppCloseDetectionResult(&c_detection_result);
}

TEST(DetectionResultConverterTest, ConvertsDetectionResultNoCategory) {
  mediapipe::tasks::components::containers::DetectionResult
      cpp_detection_result = {/* detections= */ {/* categories= */ {}}};

  DetectionResult c_detection_result;
  CppConvertToDetectionResult(cpp_detection_result, &c_detection_result);
  EXPECT_NE(c_detection_result.detections, nullptr);
  EXPECT_EQ(c_detection_result.detections_count, 1);
  EXPECT_NE(c_detection_result.detections[0].categories, nullptr);
  EXPECT_EQ(c_detection_result.detections[0].categories_count, 0);

  CppCloseDetectionResult(&c_detection_result);
}

TEST(DetectionResultConverterTest, FreesMemory) {
  mediapipe::tasks::components::containers::DetectionResult
      cpp_detection_result = {/* detections= */ {{/* categories= */ {}}}};

  DetectionResult c_detection_result;
  CppConvertToDetectionResult(cpp_detection_result, &c_detection_result);
  EXPECT_NE(c_detection_result.detections, nullptr);

  CppCloseDetectionResult(&c_detection_result);
  EXPECT_EQ(c_detection_result.detections, nullptr);
}

}  // namespace mediapipe::tasks::c::components::containers
mediapipe/tasks/c/components/containers/keypoint.h (new file, 46 lines)
@@ -0,0 +1,46 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_H_

#ifdef __cplusplus
extern "C" {
#endif

// A keypoint, defined by the coordinates (x, y), normalized by the image
// dimensions.
struct NormalizedKeypoint {
  // x in normalized image coordinates.
  float x;

  // y in normalized image coordinates.
  float y;

  // Optional label of the keypoint. `nullptr` if the label is not present.
  char* label;

  // Optional score of the keypoint.
  float score;

  // `True` if the score is valid.
  bool has_score;
};

#ifdef __cplusplus
}  // extern C
#endif

#endif  // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_H_
mediapipe/tasks/c/components/containers/keypoint_converter.cc (new file, 45 lines)
@@ -0,0 +1,45 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/c/components/containers/keypoint_converter.h"

#include <string.h>  // IWYU pragma: for open source compile

#include <cstdlib>

#include "mediapipe/tasks/c/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/components/containers/keypoint.h"

namespace mediapipe::tasks::c::components::containers {

void CppConvertToNormalizedKeypoint(
    const mediapipe::tasks::components::containers::NormalizedKeypoint& in,
    NormalizedKeypoint* out) {
  out->x = in.x;
  out->y = in.y;

  out->label = in.label.has_value() ? strdup(in.label->c_str()) : nullptr;
  out->has_score = in.score.has_value();
  out->score = out->has_score ? in.score.value() : 0;
}

void CppCloseNormalizedKeypoint(NormalizedKeypoint* keypoint) {
  if (keypoint && keypoint->label) {
    free(keypoint->label);
    keypoint->label = nullptr;
  }
}

}  // namespace mediapipe::tasks::c::components::containers
mediapipe/tasks/c/components/containers/keypoint_converter.h (new file, 32 lines)
@@ -0,0 +1,32 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_CONVERTER_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_CONVERTER_H_

#include "mediapipe/tasks/c/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/components/containers/keypoint.h"

namespace mediapipe::tasks::c::components::containers {

void CppConvertToNormalizedKeypoint(
    const mediapipe::tasks::components::containers::NormalizedKeypoint& in,
    NormalizedKeypoint* out);

void CppCloseNormalizedKeypoint(NormalizedKeypoint* keypoint);

}  // namespace mediapipe::tasks::c::components::containers

#endif  // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_CONVERTER_H_
mediapipe/tasks/c/components/containers/keypoint_converter_test.cc (new file, 52 lines)
@@ -0,0 +1,52 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/c/components/containers/keypoint_converter.h"

#include <string>

#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/components/containers/keypoint.h"

namespace mediapipe::tasks::c::components::containers {

constexpr float kPrecision = 1e-6;

TEST(KeypointConverterTest, ConvertsKeypointCustomValues) {
  mediapipe::tasks::components::containers::NormalizedKeypoint cpp_keypoint = {
      0.1, 0.2, "foo", 0.5};

  NormalizedKeypoint c_keypoint;
  CppConvertToNormalizedKeypoint(cpp_keypoint, &c_keypoint);
  EXPECT_NEAR(c_keypoint.x, 0.1f, kPrecision);
  EXPECT_NEAR(c_keypoint.y, 0.2f, kPrecision);
  EXPECT_EQ(std::string(c_keypoint.label), "foo");
  EXPECT_NEAR(c_keypoint.score, 0.5f, kPrecision);
  CppCloseNormalizedKeypoint(&c_keypoint);
}

TEST(KeypointConverterTest, FreesMemory) {
  mediapipe::tasks::components::containers::NormalizedKeypoint cpp_keypoint = {
      0.1, 0.2, "foo", 0.5};

  NormalizedKeypoint c_keypoint;
  CppConvertToNormalizedKeypoint(cpp_keypoint, &c_keypoint);
  EXPECT_NE(c_keypoint.label, nullptr);
  CppCloseNormalizedKeypoint(&c_keypoint);
  EXPECT_EQ(c_keypoint.label, nullptr);
}

}  // namespace mediapipe::tasks::c::components::containers
mediapipe/tasks/c/components/containers/rect.h (new file, 46 lines)
@@ -0,0 +1,46 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_H_

#ifdef __cplusplus
extern "C" {
#endif

// Defines a rectangle, used e.g. as part of detection results or as input
// region-of-interest.
struct MPRect {
  int left;
  int top;
  int bottom;
  int right;
};

// The coordinates are normalized wrt the image dimensions, i.e. generally in
// [0,1] but they may exceed these bounds if describing a region overlapping the
// image. The origin is on the top-left corner of the image.
struct MPRectF {
  float left;
  float top;
  float bottom;
  float right;
};

#ifdef __cplusplus
}  // extern C
#endif

#endif  // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_H_
mediapipe/tasks/c/components/containers/rect_converter.cc (new file, 41 lines)
@@ -0,0 +1,41 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/c/components/containers/rect_converter.h"

#include "mediapipe/tasks/c/components/containers/rect.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"

namespace mediapipe::tasks::c::components::containers {

// Converts a C++ Rect to a C Rect.
void CppConvertToRect(const mediapipe::tasks::components::containers::Rect& in,
                      struct MPRect* out) {
  out->left = in.left;
  out->top = in.top;
  out->right = in.right;
  out->bottom = in.bottom;
}

// Converts a C++ RectF to a C RectF.
void CppConvertToRectF(
    const mediapipe::tasks::components::containers::RectF& in, MPRectF* out) {
  out->left = in.left;
  out->top = in.top;
  out->right = in.right;
  out->bottom = in.bottom;
}

}  // namespace mediapipe::tasks::c::components::containers
mediapipe/tasks/c/components/containers/rect_converter.h (new file, 32 lines)
@@ -0,0 +1,32 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_CONVERTER_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_CONVERTER_H_

#include "mediapipe/tasks/c/components/containers/rect.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"

namespace mediapipe::tasks::c::components::containers {

void CppConvertToRect(const mediapipe::tasks::components::containers::Rect& in,
                      MPRect* out);

void CppConvertToRectF(
    const mediapipe::tasks::components::containers::RectF& in, MPRectF* out);

}  // namespace mediapipe::tasks::c::components::containers

#endif  // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_CONVERTER_H_
mediapipe/tasks/c/components/containers/rect_converter_test.cc (new file, 47 lines)
@@ -0,0 +1,47 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/c/components/containers/rect_converter.h"

#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/components/containers/rect.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"

namespace mediapipe::tasks::c::components::containers {

TEST(RectConverterTest, ConvertsRectCustomValues) {
  mediapipe::tasks::components::containers::Rect cpp_rect = {0, 1, 2, 3};

  MPRect c_rect;
  CppConvertToRect(cpp_rect, &c_rect);
  EXPECT_EQ(c_rect.left, 0);
  EXPECT_EQ(c_rect.top, 1);
  EXPECT_EQ(c_rect.right, 2);
  EXPECT_EQ(c_rect.bottom, 3);
}

TEST(RectFConverterTest, ConvertsRectFCustomValues) {
  mediapipe::tasks::components::containers::RectF cpp_rect = {0.1, 0.2, 0.3,
                                                              0.4};

  MPRectF c_rect;
  CppConvertToRectF(cpp_rect, &c_rect);
  EXPECT_FLOAT_EQ(c_rect.left, 0.1);
  EXPECT_FLOAT_EQ(c_rect.top, 0.2);
  EXPECT_FLOAT_EQ(c_rect.right, 0.3);
  EXPECT_FLOAT_EQ(c_rect.bottom, 0.4);
}

}  // namespace mediapipe::tasks::c::components::containers
mediapipe/tasks/c/vision/image_classifier/image_classifier.h
@@ -60,12 +60,12 @@ struct ImageClassifierOptions {
   //
   // The caller is responsible for closing the image classifier result.
   typedef void (*result_callback_fn)(ImageClassifierResult* result,
-                                     const MpImage image, int64_t timestamp_ms,
+                                     const MpImage& image, int64_t timestamp_ms,
                                      char* error_msg);
   result_callback_fn result_callback;
 };
 
-// Creates an ImageClassifier from provided `options`.
+// Creates an ImageClassifier from the provided `options`.
 // Returns a pointer to the image classifier on success.
 // If an error occurs, returns `nullptr` and sets the error parameter to
 // an error message (if `error_msg` is not `nullptr`). You must free the memory
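With the callback parameter now a const reference, a conforming handler looks like the following. This is a sketch; `OnClassification` and `ConfigureOptions` are hypothetical names, and the note on `error_msg` ownership is an assumption based on the object detector implementation later in this diff.

// Hypothetical handler matching the updated result_callback_fn signature;
// not part of this change.
#include <cstdint>

#include "mediapipe/tasks/c/vision/image_classifier/image_classifier.h"

void OnClassification(ImageClassifierResult* result, const MpImage& image,
                      int64_t timestamp_ms, char* error_msg) {
  if (error_msg != nullptr) {
    // Assumed: the task frees error_msg after the callback returns, as the
    // object detector implementation below does, so do not free it here.
    return;
  }
  // `image` is only valid while the callback runs; copy the pixel data if it
  // must outlive the call.
}

void ConfigureOptions(ImageClassifierOptions* options) {
  options->result_callback = OnClassification;
}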
mediapipe/tasks/c/vision/image_classifier/image_classifier_test.cc
@@ -142,7 +142,7 @@ TEST(ImageClassifierTest, VideoModeTest) {
 // timestamp is greater than the previous one.
 struct LiveStreamModeCallback {
   static int64_t last_timestamp;
-  static void Fn(ImageClassifierResult* classifier_result, const MpImage image,
+  static void Fn(ImageClassifierResult* classifier_result, const MpImage& image,
                  int64_t timestamp, char* error_msg) {
     ASSERT_NE(classifier_result, nullptr);
     ASSERT_EQ(error_msg, nullptr);
mediapipe/tasks/c/vision/image_embedder/image_embedder.h
@@ -62,12 +62,12 @@ struct ImageEmbedderOptions {
   //
   // The caller is responsible for closing the image embedder result.
   typedef void (*result_callback_fn)(ImageEmbedderResult* result,
-                                     const MpImage image, int64_t timestamp_ms,
+                                     const MpImage& image, int64_t timestamp_ms,
                                      char* error_msg);
   result_callback_fn result_callback;
 };
 
-// Creates an ImageEmbedder from provided `options`.
+// Creates an ImageEmbedder from the provided `options`.
 // Returns a pointer to the image embedder on success.
 // If an error occurs, returns `nullptr` and sets the error parameter to
 // an error message (if `error_msg` is not `nullptr`). You must free the memory
mediapipe/tasks/c/vision/image_embedder/image_embedder_test.cc
@@ -199,7 +199,7 @@ TEST(ImageEmbedderTest, VideoModeTest) {
 // timestamp is greater than the previous one.
 struct LiveStreamModeCallback {
   static int64_t last_timestamp;
-  static void Fn(ImageEmbedderResult* embedder_result, const MpImage image,
+  static void Fn(ImageEmbedderResult* embedder_result, const MpImage& image,
                  int64_t timestamp, char* error_msg) {
     ASSERT_NE(embedder_result, nullptr);
     ASSERT_EQ(error_msg, nullptr);
mediapipe/tasks/c/vision/object_detector/BUILD (new file, 65 lines)
@@ -0,0 +1,65 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

cc_library(
    name = "object_detector_lib",
    srcs = ["object_detector.cc"],
    hdrs = ["object_detector.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:image_frame",
        "//mediapipe/tasks/c/components/containers:detection_result",
        "//mediapipe/tasks/c/components/containers:detection_result_converter",
        "//mediapipe/tasks/c/core:base_options",
        "//mediapipe/tasks/c/core:base_options_converter",
        "//mediapipe/tasks/c/vision/core:common",
        "//mediapipe/tasks/cc/vision/core:running_mode",
        "//mediapipe/tasks/cc/vision/object_detector",
        "//mediapipe/tasks/cc/vision/utils:image_utils",
        "@com_google_absl//absl/log:absl_log",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
    alwayslink = 1,
)

cc_test(
    name = "object_detector_test",
    srcs = ["object_detector_test.cc"],
    data = [
        "//mediapipe/framework/formats:image_frame_opencv",
        "//mediapipe/framework/port:opencv_core",
        "//mediapipe/framework/port:opencv_imgproc",
        "//mediapipe/tasks/testdata/vision:test_images",
        "//mediapipe/tasks/testdata/vision:test_models",
    ],
    linkstatic = 1,
    deps = [
        ":object_detector_lib",
        "//mediapipe/framework/deps:file_path",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/port:gtest",
        "//mediapipe/tasks/c/components/containers:category",
        "//mediapipe/tasks/c/vision/core:common",
        "//mediapipe/tasks/cc/vision/utils:image_utils",
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest_main",
    ],
)
mediapipe/tasks/c/vision/object_detector/object_detector.cc (new file, 290 lines)
@@ -0,0 +1,290 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/c/vision/object_detector/object_detector.h"

#include <cstdint>
#include <cstdlib>
#include <memory>
#include <utility>
#include <vector>

#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/tasks/c/components/containers/detection_result_converter.h"
#include "mediapipe/tasks/c/core/base_options_converter.h"
#include "mediapipe/tasks/c/vision/core/common.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/object_detector/object_detector.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"

namespace mediapipe::tasks::c::vision::object_detector {

namespace {

using ::mediapipe::tasks::c::components::containers::CppCloseDetectionResult;
using ::mediapipe::tasks::c::components::containers::
    CppConvertToDetectionResult;
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
using ::mediapipe::tasks::vision::CreateImageFromBuffer;
using ::mediapipe::tasks::vision::ObjectDetector;
using ::mediapipe::tasks::vision::core::RunningMode;
typedef ::mediapipe::tasks::vision::ObjectDetectorResult
    CppObjectDetectorResult;

int CppProcessError(absl::Status status, char** error_msg) {
  if (error_msg) {
    *error_msg = strdup(status.ToString().c_str());
  }
  return status.raw_code();
}

}  // namespace

void CppConvertToDetectorOptions(
    const ObjectDetectorOptions& in,
    mediapipe::tasks::vision::ObjectDetectorOptions* out) {
  out->display_names_locale =
      in.display_names_locale ? std::string(in.display_names_locale) : "en";
  out->max_results = in.max_results;
  out->score_threshold = in.score_threshold;
  out->category_allowlist =
      std::vector<std::string>(in.category_allowlist_count);
  for (uint32_t i = 0; i < in.category_allowlist_count; ++i) {
    out->category_allowlist[i] = in.category_allowlist[i];
  }
  out->category_denylist = std::vector<std::string>(in.category_denylist_count);
  for (uint32_t i = 0; i < in.category_denylist_count; ++i) {
    out->category_denylist[i] = in.category_denylist[i];
  }
}

ObjectDetector* CppObjectDetectorCreate(const ObjectDetectorOptions& options,
                                        char** error_msg) {
  auto cpp_options =
      std::make_unique<::mediapipe::tasks::vision::ObjectDetectorOptions>();

  CppConvertToBaseOptions(options.base_options, &cpp_options->base_options);
  CppConvertToDetectorOptions(options, cpp_options.get());
  cpp_options->running_mode = static_cast<RunningMode>(options.running_mode);

  // Enable callback for processing live stream data when the running mode is
  // set to RunningMode::LIVE_STREAM.
  if (cpp_options->running_mode == RunningMode::LIVE_STREAM) {
    if (options.result_callback == nullptr) {
      const absl::Status status = absl::InvalidArgumentError(
          "Provided null pointer to callback function.");
      ABSL_LOG(ERROR) << "Failed to create ObjectDetector: " << status;
      CppProcessError(status, error_msg);
      return nullptr;
    }

    ObjectDetectorOptions::result_callback_fn result_callback =
        options.result_callback;
    cpp_options->result_callback =
        [result_callback](absl::StatusOr<CppObjectDetectorResult> cpp_result,
                          const Image& image, int64_t timestamp) {
          char* error_msg = nullptr;

          if (!cpp_result.ok()) {
            ABSL_LOG(ERROR) << "Detection failed: " << cpp_result.status();
            CppProcessError(cpp_result.status(), &error_msg);
            result_callback(nullptr, MpImage(), timestamp, error_msg);
            free(error_msg);
            return;
          }

          // Result is valid for the lifetime of the callback function.
          ObjectDetectorResult result;
          CppConvertToDetectionResult(*cpp_result, &result);

          const auto& image_frame = image.GetImageFrameSharedPtr();
          const MpImage mp_image = {
              .type = MpImage::IMAGE_FRAME,
              .image_frame = {
                  .format = static_cast<::ImageFormat>(image_frame->Format()),
                  .image_buffer = image_frame->PixelData(),
                  .width = image_frame->Width(),
                  .height = image_frame->Height()}};

          result_callback(&result, mp_image, timestamp,
                          /* error_msg= */ nullptr);

          CppCloseDetectionResult(&result);
        };
  }

  auto detector = ObjectDetector::Create(std::move(cpp_options));
  if (!detector.ok()) {
    ABSL_LOG(ERROR) << "Failed to create ObjectDetector: " << detector.status();
    CppProcessError(detector.status(), error_msg);
    return nullptr;
  }
  return detector->release();
}

int CppObjectDetectorDetect(void* detector, const MpImage* image,
                            ObjectDetectorResult* result, char** error_msg) {
  if (image->type == MpImage::GPU_BUFFER) {
    const absl::Status status =
        absl::InvalidArgumentError("GPU Buffer not supported yet.");

    ABSL_LOG(ERROR) << "Detection failed: " << status.message();
    return CppProcessError(status, error_msg);
  }

  const auto img = CreateImageFromBuffer(
      static_cast<ImageFormat::Format>(image->image_frame.format),
      image->image_frame.image_buffer, image->image_frame.width,
      image->image_frame.height);

  if (!img.ok()) {
    ABSL_LOG(ERROR) << "Failed to create Image: " << img.status();
    return CppProcessError(img.status(), error_msg);
  }

  auto cpp_detector = static_cast<ObjectDetector*>(detector);
  auto cpp_result = cpp_detector->Detect(*img);
  if (!cpp_result.ok()) {
    ABSL_LOG(ERROR) << "Detection failed: " << cpp_result.status();
    return CppProcessError(cpp_result.status(), error_msg);
  }
  CppConvertToDetectionResult(*cpp_result, result);
  return 0;
}

int CppObjectDetectorDetectForVideo(void* detector, const MpImage* image,
                                    int64_t timestamp_ms,
                                    ObjectDetectorResult* result,
                                    char** error_msg) {
  if (image->type == MpImage::GPU_BUFFER) {
    absl::Status status =
        absl::InvalidArgumentError("GPU Buffer not supported yet");

    ABSL_LOG(ERROR) << "Detection failed: " << status.message();
    return CppProcessError(status, error_msg);
  }

  const auto img = CreateImageFromBuffer(
      static_cast<ImageFormat::Format>(image->image_frame.format),
      image->image_frame.image_buffer, image->image_frame.width,
      image->image_frame.height);

  if (!img.ok()) {
    ABSL_LOG(ERROR) << "Failed to create Image: " << img.status();
    return CppProcessError(img.status(), error_msg);
  }

  auto cpp_detector = static_cast<ObjectDetector*>(detector);
  auto cpp_result = cpp_detector->DetectForVideo(*img, timestamp_ms);
  if (!cpp_result.ok()) {
    ABSL_LOG(ERROR) << "Detection failed: " << cpp_result.status();
    return CppProcessError(cpp_result.status(), error_msg);
  }
  CppConvertToDetectionResult(*cpp_result, result);
  return 0;
}

int CppObjectDetectorDetectAsync(void* detector, const MpImage* image,
                                 int64_t timestamp_ms, char** error_msg) {
  if (image->type == MpImage::GPU_BUFFER) {
    absl::Status status =
        absl::InvalidArgumentError("GPU Buffer not supported yet");

    ABSL_LOG(ERROR) << "Detection failed: " << status.message();
    return CppProcessError(status, error_msg);
  }

  const auto img = CreateImageFromBuffer(
      static_cast<ImageFormat::Format>(image->image_frame.format),
      image->image_frame.image_buffer, image->image_frame.width,
      image->image_frame.height);

  if (!img.ok()) {
    ABSL_LOG(ERROR) << "Failed to create Image: " << img.status();
    return CppProcessError(img.status(), error_msg);
  }

  auto cpp_detector = static_cast<ObjectDetector*>(detector);
  auto cpp_result = cpp_detector->DetectAsync(*img, timestamp_ms);
  if (!cpp_result.ok()) {
    ABSL_LOG(ERROR) << "Data preparation for the object detection failed: "
                    << cpp_result;
    return CppProcessError(cpp_result, error_msg);
  }
  return 0;
}

void CppObjectDetectorCloseResult(ObjectDetectorResult* result) {
  CppCloseDetectionResult(result);
}

int CppObjectDetectorClose(void* detector, char** error_msg) {
  auto cpp_detector = static_cast<ObjectDetector*>(detector);
  auto result = cpp_detector->Close();
  if (!result.ok()) {
    ABSL_LOG(ERROR) << "Failed to close ObjectDetector: " << result;
    return CppProcessError(result, error_msg);
  }
  delete cpp_detector;
  return 0;
}

}  // namespace mediapipe::tasks::c::vision::object_detector

extern "C" {

void* object_detector_create(struct ObjectDetectorOptions* options,
                             char** error_msg) {
  return mediapipe::tasks::c::vision::object_detector::CppObjectDetectorCreate(
      *options, error_msg);
}

int object_detector_detect_image(void* detector, const MpImage* image,
                                 ObjectDetectorResult* result,
                                 char** error_msg) {
  return mediapipe::tasks::c::vision::object_detector::CppObjectDetectorDetect(
      detector, image, result, error_msg);
}

int object_detector_detect_for_video(void* detector, const MpImage* image,
                                     int64_t timestamp_ms,
                                     ObjectDetectorResult* result,
                                     char** error_msg) {
  return mediapipe::tasks::c::vision::object_detector::
      CppObjectDetectorDetectForVideo(detector, image, timestamp_ms, result,
                                      error_msg);
}

int object_detector_detect_async(void* detector, const MpImage* image,
                                 int64_t timestamp_ms, char** error_msg) {
  return mediapipe::tasks::c::vision::object_detector::
      CppObjectDetectorDetectAsync(detector, image, timestamp_ms, error_msg);
}

void object_detector_close_result(ObjectDetectorResult* result) {
  mediapipe::tasks::c::vision::object_detector::CppObjectDetectorCloseResult(
      result);
}

int object_detector_close(void* detector, char** error_msg) {
  return mediapipe::tasks::c::vision::object_detector::CppObjectDetectorClose(
      detector, error_msg);
}

}  // extern "C"
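Every entry point above reports failure the same way: create returns nullptr and the others return a non-zero status code, with the status message strdup()ed into *error_msg whenever that pointer is non-null, so the caller must free() it. A minimal caller-side sketch of that contract; CreateDetectorOrLog is a hypothetical name, not part of this change.

// Minimal sketch of the error contract (not part of the change).
#include <cstdio>
#include <cstdlib>

#include "mediapipe/tasks/c/vision/object_detector/object_detector.h"

void* CreateDetectorOrLog(ObjectDetectorOptions* options) {
  char* error_msg = nullptr;
  void* detector = object_detector_create(options, &error_msg);
  if (detector == nullptr) {
    std::fprintf(stderr, "object_detector_create failed: %s\n",
                 error_msg ? error_msg : "(no message)");
    free(error_msg);  // The message was allocated with strdup() by the task.
  }
  return detector;
}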
mediapipe/tasks/c/vision/object_detector/object_detector.h (new file, 157 lines)
@@ -0,0 +1,157 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_C_VISION_OBJECT_DETECTOR_OBJECT_DETECTOR_H_
#define MEDIAPIPE_TASKS_C_VISION_OBJECT_DETECTOR_OBJECT_DETECTOR_H_

#include "mediapipe/tasks/c/components/containers/detection_result.h"
#include "mediapipe/tasks/c/core/base_options.h"
#include "mediapipe/tasks/c/vision/core/common.h"

#ifndef MP_EXPORT
#define MP_EXPORT __attribute__((visibility("default")))
#endif  // MP_EXPORT

#ifdef __cplusplus
extern "C" {
#endif

typedef DetectionResult ObjectDetectorResult;

// The options for configuring a MediaPipe object detector task.
struct ObjectDetectorOptions {
  // Base options for configuring MediaPipe Tasks, such as specifying the model
  // file with metadata, accelerator options, op resolver, etc.
  struct BaseOptions base_options;

  // The running mode of the task. Defaults to the image mode.
  // The object detector has three running modes:
  // 1) The image mode for detecting objects on single image inputs.
  // 2) The video mode for detecting objects on the decoded frames of a video.
  // 3) The live stream mode for detecting objects on the live stream of input
  // data, such as from a camera. In this mode, the "result_callback" below
  // must be specified to receive the detection results asynchronously.
  RunningMode running_mode;

  // The locale to use for display names specified through the TFLite Model
  // Metadata, if any. Defaults to English.
  const char* display_names_locale;

  // The maximum number of top-scored detection results to return. If < 0,
  // all available results will be returned. If 0, an invalid argument error is
  // returned.
  int max_results;

  // Score threshold to override the one provided in the model metadata (if
  // any). Results below this value are rejected.
  float score_threshold;

  // The allowlist of category names. If non-empty, detection results whose
  // category name is not in this set will be filtered out. Duplicate or
  // unknown category names are ignored. Mutually exclusive with
  // category_denylist.
  const char** category_allowlist;
  // The number of elements in the category allowlist.
  uint32_t category_allowlist_count;

  // The denylist of category names. If non-empty, detection results whose
  // category name is in this set will be filtered out. Duplicate or unknown
  // category names are ignored. Mutually exclusive with category_allowlist.
  const char** category_denylist;
  // The number of elements in the category denylist.
  uint32_t category_denylist_count;

  // The user-defined result callback for processing live stream data.
  // The result callback should only be specified when the running mode is set
  // to RunningMode::LIVE_STREAM. The callback receives: a pointer to the
  // detection result, the image the result was obtained on, the timestamp of
  // the detection result, and a pointer to an error message in case of any
  // failure. The arguments are only valid for the lifetime of the callback.
  //
  // The caller is responsible for closing the object detector result.
  typedef void (*result_callback_fn)(ObjectDetectorResult* result,
                                     const MpImage& image, int64_t timestamp_ms,
                                     char* error_msg);
  result_callback_fn result_callback;
};

// Creates an ObjectDetector from the provided `options`.
// Returns a pointer to the object detector on success.
// If an error occurs, returns `nullptr` and sets the error parameter to
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT void* object_detector_create(struct ObjectDetectorOptions* options,
                                       char** error_msg);

// Performs object detection on the input `image`. Returns `0` on success.
// If an error occurs, returns an error code and sets the error parameter to
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT int object_detector_detect_image(void* detector, const MpImage* image,
                                           ObjectDetectorResult* result,
                                           char** error_msg);

// Performs object detection on the provided video frame.
// Only use this method when the ObjectDetector is created with the video
// running mode.
// The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing.
// If an error occurs, returns an error code and sets the error parameter to
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT int object_detector_detect_for_video(void* detector,
                                               const MpImage* image,
                                               int64_t timestamp_ms,
                                               ObjectDetectorResult* result,
                                               char** error_msg);

// Sends live image data to object detection, and the results will be
// available via the `result_callback` provided in the ObjectDetectorOptions.
// Only use this method when the ObjectDetector is created with the live
// stream running mode.
// The image can be of any size with format RGB or RGBA. It's required to
// provide a timestamp (in milliseconds) to indicate when the input image is
// sent to the object detector. The input timestamps must be monotonically
// increasing.
// The `result_callback` provides:
//   - The detection results as an ObjectDetectorResult object.
//   - The const reference to the corresponding input image that the object
//     detector runs on. Note that the const reference to the image will no
//     longer be valid when the callback returns. To access the image data
//     outside of the callback, callers need to make a copy of the image.
//   - The input timestamp in milliseconds.
// If an error occurs, returns an error code and sets the error parameter to
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT int object_detector_detect_async(void* detector, const MpImage* image,
                                           int64_t timestamp_ms,
                                           char** error_msg);

// Frees the memory allocated inside an ObjectDetectorResult.
// Does not free the result pointer itself.
MP_EXPORT void object_detector_close_result(ObjectDetectorResult* result);

// Frees the object detector.
// If an error occurs, returns an error code and sets the error parameter to
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT int object_detector_close(void* detector, char** error_msg);

#ifdef __cplusplus
}  // extern C
#endif

#endif  // MEDIAPIPE_TASKS_C_VISION_OBJECT_DETECTOR_OBJECT_DETECTOR_H_
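For live stream use, the callback goes into the options and frames are pushed through object_detector_detect_async. A sketch of the wiring; OnDetection and the single-frame driver are illustrative, not part of this change.

// Hypothetical live-stream wiring (not part of the change).
#include <cstdint>

#include "mediapipe/tasks/c/vision/object_detector/object_detector.h"

void OnDetection(ObjectDetectorResult* result, const MpImage& image,
                 int64_t timestamp_ms, char* error_msg) {
  if (result == nullptr) return;  // error_msg describes the failure.
  // Read detections here. The task converts the result before this call and
  // closes it again right after the callback returns, so copy anything that
  // must outlive the call.
}

int DetectOneFrame(ObjectDetectorOptions options, const MpImage* frame,
                   int64_t timestamp_ms) {
  options.running_mode = RunningMode::LIVE_STREAM;
  options.result_callback = OnDetection;

  void* detector = object_detector_create(&options, /* error_msg= */ nullptr);
  if (detector == nullptr) return -1;

  // Timestamps must be monotonically increasing across calls.
  int status = object_detector_detect_async(detector, frame, timestamp_ms,
                                            /* error_msg= */ nullptr);
  object_detector_close(detector, /* error_msg= */ nullptr);
  return status;
}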
mediapipe/tasks/c/vision/object_detector/object_detector_test.cc (new file, 253 lines)
@ -0,0 +1,253 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/c/vision/object_detector/object_detector.h"

#include <cstdint>
#include <cstdlib>
#include <string>

#include "absl/flags/flag.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/components/containers/category.h"
#include "mediapipe/tasks/c/vision/core/common.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"

namespace {

using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::vision::DecodeImageFromFile;
using testing::HasSubstr;

constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kImageFile[] = "cats_and_dogs.jpg";
constexpr char kModelName[] =
    "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
constexpr float kPrecision = 1e-4;
constexpr int kIterations = 100;

std::string GetFullPath(absl::string_view file_name) {
  return JoinPath("./", kTestDataDirectory, file_name);
}

TEST(ObjectDetectorTest, ImageModeTest) {
  const auto image = DecodeImageFromFile(GetFullPath(kImageFile));
  ASSERT_TRUE(image.ok());

  const std::string model_path = GetFullPath(kModelName);
  ObjectDetectorOptions options = {
      /* base_options= */ {/* model_asset_buffer= */ nullptr,
                           /* model_asset_buffer_count= */ 0,
                           /* model_asset_path= */ model_path.c_str()},
      /* running_mode= */ RunningMode::IMAGE,
      /* display_names_locale= */ nullptr,
      /* max_results= */ -1,
      /* score_threshold= */ 0.0,
      /* category_allowlist= */ nullptr,
      /* category_allowlist_count= */ 0,
      /* category_denylist= */ nullptr,
      /* category_denylist_count= */ 0,
  };

  void* detector = object_detector_create(&options, /* error_msg */ nullptr);
  EXPECT_NE(detector, nullptr);

  const auto& image_frame = image->GetImageFrameSharedPtr();
  const MpImage mp_image = {
      .type = MpImage::IMAGE_FRAME,
      .image_frame = {.format = static_cast<ImageFormat>(image_frame->Format()),
                      .image_buffer = image_frame->PixelData(),
                      .width = image_frame->Width(),
                      .height = image_frame->Height()}};

  ObjectDetectorResult result;
  object_detector_detect_image(detector, &mp_image, &result,
                               /* error_msg */ nullptr);
  EXPECT_EQ(result.detections_count, 10);
  EXPECT_EQ(result.detections[0].categories_count, 1);
  EXPECT_EQ(std::string{result.detections[0].categories[0].category_name},
            "cat");
  EXPECT_NEAR(result.detections[0].categories[0].score, 0.6992f, kPrecision);
  object_detector_close_result(&result);
  object_detector_close(detector, /* error_msg */ nullptr);
}

TEST(ObjectDetectorTest, VideoModeTest) {
  const auto image = DecodeImageFromFile(GetFullPath(kImageFile));
  ASSERT_TRUE(image.ok());

  const std::string model_path = GetFullPath(kModelName);
  ObjectDetectorOptions options = {
      /* base_options= */ {/* model_asset_buffer= */ nullptr,
                           /* model_asset_buffer_count= */ 0,
                           /* model_asset_path= */ model_path.c_str()},
      /* running_mode= */ RunningMode::VIDEO,
      /* display_names_locale= */ nullptr,
      /* max_results= */ 3,
      /* score_threshold= */ 0.0,
      /* category_allowlist= */ nullptr,
      /* category_allowlist_count= */ 0,
      /* category_denylist= */ nullptr,
      /* category_denylist_count= */ 0,
  };

  void* detector = object_detector_create(&options, /* error_msg */ nullptr);
  EXPECT_NE(detector, nullptr);

  const auto& image_frame = image->GetImageFrameSharedPtr();
  const MpImage mp_image = {
      .type = MpImage::IMAGE_FRAME,
      .image_frame = {.format = static_cast<ImageFormat>(image_frame->Format()),
                      .image_buffer = image_frame->PixelData(),
                      .width = image_frame->Width(),
                      .height = image_frame->Height()}};

  for (int i = 0; i < kIterations; ++i) {
    ObjectDetectorResult result;
    object_detector_detect_for_video(detector, &mp_image, i, &result,
                                     /* error_msg */ nullptr);
    EXPECT_EQ(result.detections_count, 3);
    EXPECT_EQ(result.detections[0].categories_count, 1);
    EXPECT_EQ(std::string{result.detections[0].categories[0].category_name},
              "cat");
    EXPECT_NEAR(result.detections[0].categories[0].score, 0.6992f, kPrecision);
    object_detector_close_result(&result);
  }
  object_detector_close(detector, /* error_msg */ nullptr);
}

// A structure to support LiveStreamModeTest below. This structure holds a
// static method `Fn` that serves as a callback for the C API. The `static`
// qualifier makes it possible to take the address of the method, as the API
// style requires. Another static struct member, `last_timestamp`, is used to
// verify that the current timestamp is greater than the previous one.
struct LiveStreamModeCallback {
  static int64_t last_timestamp;
  static void Fn(ObjectDetectorResult* detector_result, const MpImage& image,
                 int64_t timestamp, char* error_msg) {
    ASSERT_NE(detector_result, nullptr);
    ASSERT_EQ(error_msg, nullptr);
    EXPECT_EQ(detector_result->detections_count, 3);
    EXPECT_EQ(detector_result->detections[0].categories_count, 1);
    EXPECT_EQ(
        std::string{detector_result->detections[0].categories[0].category_name},
        "cat");
    EXPECT_NEAR(detector_result->detections[0].categories[0].score, 0.6992f,
                kPrecision);
    EXPECT_GT(image.image_frame.width, 0);
    EXPECT_GT(image.image_frame.height, 0);
    EXPECT_GT(timestamp, last_timestamp);
    last_timestamp++;
  }
};
int64_t LiveStreamModeCallback::last_timestamp = -1;

TEST(ObjectDetectorTest, LiveStreamModeTest) {
  const auto image = DecodeImageFromFile(GetFullPath(kImageFile));
  ASSERT_TRUE(image.ok());

  const std::string model_path = GetFullPath(kModelName);

  ObjectDetectorOptions options = {
      /* base_options= */ {/* model_asset_buffer= */ nullptr,
                           /* model_asset_buffer_count= */ 0,
                           /* model_asset_path= */ model_path.c_str()},
      /* running_mode= */ RunningMode::LIVE_STREAM,
      /* display_names_locale= */ nullptr,
      /* max_results= */ 3,
      /* score_threshold= */ 0.0,
      /* category_allowlist= */ nullptr,
      /* category_allowlist_count= */ 0,
      /* category_denylist= */ nullptr,
      /* category_denylist_count= */ 0,
      /* result_callback= */ LiveStreamModeCallback::Fn,
  };

  void* detector = object_detector_create(&options, /* error_msg */ nullptr);
  EXPECT_NE(detector, nullptr);

  const auto& image_frame = image->GetImageFrameSharedPtr();
  const MpImage mp_image = {
      .type = MpImage::IMAGE_FRAME,
      .image_frame = {.format = static_cast<ImageFormat>(image_frame->Format()),
                      .image_buffer = image_frame->PixelData(),
                      .width = image_frame->Width(),
                      .height = image_frame->Height()}};

  for (int i = 0; i < kIterations; ++i) {
    EXPECT_GE(object_detector_detect_async(detector, &mp_image, i,
                                           /* error_msg */ nullptr),
              0);
  }
  object_detector_close(detector, /* error_msg */ nullptr);

  // Due to the flow limiter, the total number of outputs might be smaller
  // than the number of iterations.
  EXPECT_LE(LiveStreamModeCallback::last_timestamp, kIterations);
  EXPECT_GT(LiveStreamModeCallback::last_timestamp, 0);
}

TEST(ObjectDetectorTest, InvalidArgumentHandling) {
  // It is an error to set neither the asset buffer nor the path.
  ObjectDetectorOptions options = {
      /* base_options= */ {/* model_asset_buffer= */ nullptr,
                           /* model_asset_buffer_count= */ 0,
                           /* model_asset_path= */ nullptr},
  };

  char* error_msg;
  void* detector = object_detector_create(&options, &error_msg);
  EXPECT_EQ(detector, nullptr);

  EXPECT_THAT(error_msg, HasSubstr("ExternalFile must specify"));

  free(error_msg);
}

TEST(ObjectDetectorTest, FailedDetectionHandling) {
  const std::string model_path = GetFullPath(kModelName);
  ObjectDetectorOptions options = {
      /* base_options= */ {/* model_asset_buffer= */ nullptr,
                           /* model_asset_buffer_count= */ 0,
                           /* model_asset_path= */ model_path.c_str()},
      /* running_mode= */ RunningMode::IMAGE,
      /* display_names_locale= */ nullptr,
      /* max_results= */ -1,
      /* score_threshold= */ 0.0,
      /* category_allowlist= */ nullptr,
      /* category_allowlist_count= */ 0,
      /* category_denylist= */ nullptr,
      /* category_denylist_count= */ 0,
  };

  void* detector = object_detector_create(&options, /* error_msg */ nullptr);
  EXPECT_NE(detector, nullptr);

  const MpImage mp_image = {.type = MpImage::GPU_BUFFER, .gpu_buffer = {}};
  ObjectDetectorResult result;
  char* error_msg;
  object_detector_detect_image(detector, &mp_image, &result, &error_msg);
  EXPECT_THAT(error_msg, HasSubstr("GPU Buffer not supported yet"));
  free(error_msg);
  object_detector_close(detector, /* error_msg */ nullptr);
}

}  // namespace
@ -155,6 +155,37 @@ cc_library(

# TODO: open source hand joints graph

cc_library(
    name = "hand_roi_refinement_graph",
    srcs = ["hand_roi_refinement_graph.cc"],
    deps = [
        "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
        "//mediapipe/calculators/tensor:inference_calculator",
        "//mediapipe/calculators/tensor:tensors_to_landmarks_calculator",
        "//mediapipe/calculators/tensor:tensors_to_landmarks_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2/stream:detections_to_rects",
        "//mediapipe/framework/api2/stream:landmarks_projection",
        "//mediapipe/framework/api2/stream:landmarks_to_detection",
        "//mediapipe/framework/api2/stream:rect_transformation",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/port:statusor",
        "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
        "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto",
        "//mediapipe/tasks/cc/core:model_task_graph",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_roi_refinement_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
    alwayslink = 1,
)

cc_library(
    name = "hand_landmarker_result",
    srcs = ["hand_landmarker_result.cc"],
@ -0,0 +1,154 @@

/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <array>
#include <optional>
#include <utility>
#include <vector>

#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_landmarks_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/detections_to_rects.h"
#include "mediapipe/framework/api2/stream/landmarks_projection.h"
#include "mediapipe/framework/api2/stream/landmarks_to_detection.h"
#include "mediapipe/framework/api2/stream/rect_transformation.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace hand_landmarker {

using ::mediapipe::api2::builder::ConvertAlignmentPointsDetectionToRect;
using ::mediapipe::api2::builder::ConvertLandmarksToDetection;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::ProjectLandmarks;
using ::mediapipe::api2::builder::ScaleAndShiftAndMakeSquareLong;
using ::mediapipe::api2::builder::Stream;

// Refines the input hand RoI with the hand_roi_refinement model.
//
// Inputs:
//   IMAGE - Image
//     The image to preprocess.
//   NORM_RECT - NormalizedRect
//     Coarse RoI of the hand.
// Outputs:
//   NORM_RECT - NormalizedRect
//     Refined RoI of the hand.
class HandRoiRefinementGraph : public core::ModelTaskGraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      mediapipe::SubgraphContext* context) override {
    Graph graph;
    Stream<Image> image_in = graph.In("IMAGE").Cast<Image>();
    Stream<NormalizedRect> roi_in =
        graph.In("NORM_RECT").Cast<NormalizedRect>();

    auto& graph_options =
        *context->MutableOptions<proto::HandRoiRefinementGraphOptions>();

    MP_ASSIGN_OR_RETURN(
        const auto* model_resources,
        GetOrCreateModelResources<proto::HandRoiRefinementGraphOptions>(
            context));

    auto& preprocessing = graph.AddNode(
        "mediapipe.tasks.components.processors.ImagePreprocessingGraph");
    bool use_gpu =
        components::processors::DetermineImagePreprocessingGpuBackend(
            graph_options.base_options().acceleration());
    auto& image_to_tensor_options =
        *preprocessing
             .GetOptions<tasks::components::processors::proto::
                             ImagePreprocessingGraphOptions>()
             .mutable_image_to_tensor_options();
    image_to_tensor_options.set_keep_aspect_ratio(true);
    image_to_tensor_options.set_border_mode(
        mediapipe::ImageToTensorCalculatorOptions::BORDER_REPLICATE);
    MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
        *model_resources, use_gpu, graph_options.base_options().gpu_origin(),
        &preprocessing.GetOptions<tasks::components::processors::proto::
                                      ImagePreprocessingGraphOptions>()));
    image_in >> preprocessing.In("IMAGE");
    roi_in >> preprocessing.In("NORM_RECT");
    auto tensors_in = preprocessing.Out("TENSORS");
    auto matrix = preprocessing.Out("MATRIX").Cast<std::array<float, 16>>();
    auto image_size =
        preprocessing.Out("IMAGE_SIZE").Cast<std::pair<int, int>>();

    auto& inference = AddInference(
        *model_resources, graph_options.base_options().acceleration(), graph);
    tensors_in >> inference.In("TENSORS");
    auto tensors_out = inference.Out("TENSORS").Cast<std::vector<Tensor>>();

    MP_ASSIGN_OR_RETURN(auto image_tensor_specs,
                        BuildInputImageTensorSpecs(*model_resources));

    // Convert tensors to landmarks. The re-crop model outputs two points:
    // the center point and the guide point.
    auto& to_landmarks = graph.AddNode("TensorsToLandmarksCalculator");
    auto& to_landmarks_opts =
        to_landmarks
            .GetOptions<mediapipe::TensorsToLandmarksCalculatorOptions>();
    to_landmarks_opts.set_num_landmarks(/*num_landmarks=*/2);
    to_landmarks_opts.set_input_image_width(image_tensor_specs.image_width);
    to_landmarks_opts.set_input_image_height(image_tensor_specs.image_height);
    to_landmarks_opts.set_normalize_z(/*z_norm_factor=*/1.0f);
    tensors_out.ConnectTo(to_landmarks.In("TENSORS"));
    auto recrop_landmarks = to_landmarks.Out("NORM_LANDMARKS")
                                .Cast<mediapipe::NormalizedLandmarkList>();

    // Project landmarks.
    auto projected_recrop_landmarks =
        ProjectLandmarks(recrop_landmarks, matrix, graph);

    // Convert re-crop landmarks to detection.
    auto recrop_detection =
        ConvertLandmarksToDetection(projected_recrop_landmarks, graph);

    // Convert re-crop detection to rect.
    auto recrop_rect = ConvertAlignmentPointsDetectionToRect(
        recrop_detection, image_size, /*start_keypoint_index=*/0,
        /*end_keypoint_index=*/1, /*target_angle=*/-90, graph);

    auto refined_roi =
        ScaleAndShiftAndMakeSquareLong(recrop_rect, image_size,
                                       /*scale_x_factor=*/1.0,
                                       /*scale_y_factor=*/1.0, /*shift_x=*/0,
                                       /*shift_y=*/-0.1, graph);
    refined_roi >> graph.Out("NORM_RECT").Cast<NormalizedRect>();
    return graph.GetConfig();
  }
};

REGISTER_MEDIAPIPE_GRAPH(
    ::mediapipe::tasks::vision::hand_landmarker::HandRoiRefinementGraph);
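
// A minimal sketch (assumed client code, mirroring how holistic hand tracking
// consumes this subgraph) of attaching the registered graph from another api2
// builder graph; `graph`, `image`, `coarse_roi`, and `options` are assumed to
// be prepared by the caller:
//
//   auto& refine = graph.AddNode(
//       "mediapipe.tasks.vision.hand_landmarker.HandRoiRefinementGraph");
//   refine.GetOptions<proto::HandRoiRefinementGraphOptions>() = options;
//   image >> refine.In("IMAGE");
//   coarse_roi >> refine.In("NORM_RECT");
//   Stream<NormalizedRect> refined_roi =
//       refine.Out("NORM_RECT").Cast<NormalizedRect>();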

}  // namespace hand_landmarker
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe
152  mediapipe/tasks/cc/vision/holistic_landmarker/BUILD  Normal file

@ -0,0 +1,152 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(
    default_visibility = ["//mediapipe/tasks:internal"],
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "holistic_face_tracking",
    srcs = ["holistic_face_tracking.cc"],
    hdrs = ["holistic_face_tracking.h"],
    deps = [
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2/stream:detections_to_rects",
        "//mediapipe/framework/api2/stream:image_size",
        "//mediapipe/framework/api2/stream:landmarks_to_detection",
        "//mediapipe/framework/api2/stream:loopback",
        "//mediapipe/framework/api2/stream:rect_transformation",
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:detection_cc_proto",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/port:status",
        "//mediapipe/modules/holistic_landmark/calculators:roi_tracking_calculator",
        "//mediapipe/modules/holistic_landmark/calculators:roi_tracking_calculator_cc_proto",
        "//mediapipe/tasks/cc/vision/face_detector:face_detector_graph",
        "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/face_landmarker:face_blendshapes_graph",
        "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
        "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarks_detector_graph",
        "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:str_format",
    ],
)

cc_library(
    name = "holistic_hand_tracking",
    srcs = ["holistic_hand_tracking.cc"],
    hdrs = ["holistic_hand_tracking.h"],
    deps = [
        "//mediapipe/calculators/util:align_hand_to_pose_in_world_calculator",
        "//mediapipe/calculators/util:align_hand_to_pose_in_world_calculator_cc_proto",
        "//mediapipe/calculators/util:landmark_visibility_calculator",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2/stream:image_size",
        "//mediapipe/framework/api2/stream:landmarks_to_detection",
        "//mediapipe/framework/api2/stream:loopback",
        "//mediapipe/framework/api2/stream:rect_transformation",
        "//mediapipe/framework/api2/stream:split",
        "//mediapipe/framework/api2/stream:threshold",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/port:status",
        "//mediapipe/modules/holistic_landmark/calculators:hand_detections_from_pose_to_rects_calculator",
        "//mediapipe/modules/holistic_landmark/calculators:roi_tracking_calculator",
        "//mediapipe/modules/holistic_landmark/calculators:roi_tracking_calculator_cc_proto",
        "//mediapipe/tasks/cc/components/utils:gate",
        "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarker_graph",
        "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph",
        "//mediapipe/tasks/cc/vision/hand_landmarker:hand_roi_refinement_graph",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_roi_refinement_graph_options_cc_proto",
        "@com_google_absl//absl/status:statusor",
    ],
)

cc_library(
    name = "holistic_pose_tracking",
    srcs = ["holistic_pose_tracking.cc"],
    hdrs = ["holistic_pose_tracking.h"],
    deps = [
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2/stream:detections_to_rects",
        "//mediapipe/framework/api2/stream:image_size",
        "//mediapipe/framework/api2/stream:landmarks_to_detection",
        "//mediapipe/framework/api2/stream:loopback",
        "//mediapipe/framework/api2/stream:merge",
        "//mediapipe/framework/api2/stream:presence",
        "//mediapipe/framework/api2/stream:rect_transformation",
        "//mediapipe/framework/api2/stream:segmentation_smoothing",
        "//mediapipe/framework/api2/stream:smoothing",
        "//mediapipe/framework/api2/stream:split",
        "//mediapipe/framework/formats:detection_cc_proto",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/tasks/cc/components/utils:gate",
        "//mediapipe/tasks/cc/vision/pose_detector:pose_detector_graph",
        "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/pose_landmarker:pose_landmarks_detector_graph",
        "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
        "@com_google_absl//absl/status:statusor",
    ],
)

cc_library(
    name = "holistic_landmarker_graph",
    srcs = ["holistic_landmarker_graph.cc"],
    deps = [
        ":holistic_face_tracking",
        ":holistic_hand_tracking",
        ":holistic_pose_tracking",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2/stream:split",
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/tasks/cc/core:model_asset_bundle_resources",
        "//mediapipe/tasks/cc/core:model_resources_cache",
        "//mediapipe/tasks/cc/core:model_task_graph",
        "//mediapipe/tasks/cc/core:utils",
        "//mediapipe/tasks/cc/metadata/utils:zip_utils",
        "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_roi_refinement_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_landmarker_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/pose_landmarker:pose_topology",
        "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
        "//mediapipe/util:graph_builder_utils",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,
)
@ -0,0 +1,260 @@

/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h"

#include <functional>
#include <optional>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/detections_to_rects.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/api2/stream/landmarks_to_detection.h"
#include "mediapipe/framework/api2/stream/loopback.h"
#include "mediapipe/framework/api2/stream/rect_transformation.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {

namespace {

using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::builder::ConvertDetectionsToRectUsingKeypoints;
using ::mediapipe::api2::builder::ConvertDetectionToRect;
using ::mediapipe::api2::builder::ConvertLandmarksToDetection;
using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::GetLoopbackData;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Scale;
using ::mediapipe::api2::builder::ScaleAndMakeSquare;
using ::mediapipe::api2::builder::Stream;

struct FaceLandmarksResult {
  std::optional<Stream<NormalizedLandmarkList>> landmarks;
  std::optional<Stream<ClassificationList>> classifications;
};

absl::Status ValidateGraphOptions(
    const face_detector::proto::FaceDetectorGraphOptions&
        face_detector_graph_options,
    const face_landmarker::proto::FaceLandmarksDetectorGraphOptions&
        face_landmarks_detector_graph_options,
    const HolisticFaceTrackingRequest& request) {
  if (face_detector_graph_options.num_faces() != 1) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "Only support num_faces to be 1, but got num_faces = %d.",
        face_detector_graph_options.num_faces()));
  }
  if (request.classifications && !face_landmarks_detector_graph_options
                                      .has_face_blendshapes_graph_options()) {
    return absl::InvalidArgumentError(
        "Blendshapes detection is requested, but "
        "face_blendshapes_graph_options is not configured.");
  }
  return absl::OkStatus();
}

Stream<NormalizedRect> GetFaceRoiFromPoseFaceLandmarks(
    Stream<NormalizedLandmarkList> pose_face_landmarks,
    Stream<std::pair<int, int>> image_size, Graph& graph) {
  Stream<mediapipe::Detection> detection =
      ConvertLandmarksToDetection(pose_face_landmarks, graph);

  // Refer to the pose face landmark indices here:
  // https://developers.google.com/mediapipe/solutions/vision/pose_landmarker#pose_landmarker_model
  Stream<NormalizedRect> rect = ConvertDetectionToRect(
      detection, image_size, /*start_keypoint_index=*/5,
      /*end_keypoint_index=*/2, /*target_angle=*/0, graph);

  // Scale the face RoI from a tight rect enclosing the pose face landmarks, to
  // a larger square so that the whole face is within the RoI.
  return ScaleAndMakeSquare(rect, image_size,
                            /*scale_x_factor=*/3.0,
                            /*scale_y_factor=*/3.0, graph);
}

Stream<NormalizedRect> GetFaceRoiFromFaceLandmarks(
    Stream<NormalizedLandmarkList> face_landmarks,
    Stream<std::pair<int, int>> image_size, Graph& graph) {
  Stream<mediapipe::Detection> detection =
      ConvertLandmarksToDetection(face_landmarks, graph);

  Stream<NormalizedRect> rect = ConvertDetectionToRect(
      detection, image_size, /*start_keypoint_index=*/33,
      /*end_keypoint_index=*/263, /*target_angle=*/0, graph);

  return Scale(rect, image_size,
               /*scale_x_factor=*/1.5,
               /*scale_y_factor=*/1.5, graph);
}

Stream<std::vector<Detection>> GetFaceDetections(
    Stream<Image> image, Stream<NormalizedRect> roi,
    const face_detector::proto::FaceDetectorGraphOptions&
        face_detector_graph_options,
    Graph& graph) {
  auto& face_detector_graph =
      graph.AddNode("mediapipe.tasks.vision.face_detector.FaceDetectorGraph");
  face_detector_graph
      .GetOptions<face_detector::proto::FaceDetectorGraphOptions>() =
      face_detector_graph_options;
  image >> face_detector_graph.In("IMAGE");
  roi >> face_detector_graph.In("NORM_RECT");
  return face_detector_graph.Out("DETECTIONS").Cast<std::vector<Detection>>();
}

Stream<NormalizedRect> GetFaceRoiFromFaceDetections(
    Stream<std::vector<Detection>> face_detections,
    Stream<std::pair<int, int>> image_size, Graph& graph) {
  // Convert detection to rect.
  Stream<NormalizedRect> rect = ConvertDetectionsToRectUsingKeypoints(
      face_detections, image_size, /*start_keypoint_index=*/0,
      /*end_keypoint_index=*/1, /*target_angle=*/0, graph);

  return ScaleAndMakeSquare(rect, image_size,
                            /*scale_x_factor=*/2.0,
                            /*scale_y_factor=*/2.0, graph);
}

Stream<NormalizedRect> TrackFaceRoi(
    Stream<NormalizedLandmarkList> prev_landmarks, Stream<NormalizedRect> roi,
    Stream<std::pair<int, int>> image_size, Graph& graph) {
  // Gets face ROI from previous frame face landmarks.
  Stream<NormalizedRect> prev_roi =
      GetFaceRoiFromFaceLandmarks(prev_landmarks, image_size, graph);

  auto& tracking_node = graph.AddNode("RoiTrackingCalculator");
  auto& tracking_node_opts =
      tracking_node.GetOptions<RoiTrackingCalculatorOptions>();
  auto* rect_requirements = tracking_node_opts.mutable_rect_requirements();
  rect_requirements->set_rotation_degrees(15.0);
  rect_requirements->set_translation(0.1);
  rect_requirements->set_scale(0.3);
  auto* landmarks_requirements =
      tracking_node_opts.mutable_landmarks_requirements();
  landmarks_requirements->set_recrop_rect_margin(-0.2);
  prev_landmarks.ConnectTo(tracking_node.In("PREV_LANDMARKS"));
  prev_roi.ConnectTo(tracking_node.In("PREV_LANDMARKS_RECT"));
  roi.ConnectTo(tracking_node.In("RECROP_RECT"));
  image_size.ConnectTo(tracking_node.In("IMAGE_SIZE"));
  return tracking_node.Out("TRACKING_RECT").Cast<NormalizedRect>();
}

FaceLandmarksResult GetFaceLandmarksDetection(
    Stream<Image> image, Stream<NormalizedRect> roi,
    Stream<std::pair<int, int>> image_size,
    const face_landmarker::proto::FaceLandmarksDetectorGraphOptions&
        face_landmarks_detector_graph_options,
    const HolisticFaceTrackingRequest& request, Graph& graph) {
  FaceLandmarksResult result;
  auto& face_landmarks_detector_graph = graph.AddNode(
      "mediapipe.tasks.vision.face_landmarker."
      "SingleFaceLandmarksDetectorGraph");
  face_landmarks_detector_graph
      .GetOptions<face_landmarker::proto::FaceLandmarksDetectorGraphOptions>() =
      face_landmarks_detector_graph_options;
  image >> face_landmarks_detector_graph.In("IMAGE");
  roi >> face_landmarks_detector_graph.In("NORM_RECT");
  auto landmarks = face_landmarks_detector_graph.Out("NORM_LANDMARKS")
                       .Cast<NormalizedLandmarkList>();
  result.landmarks = landmarks;
  if (request.classifications) {
    auto& blendshapes_graph = graph.AddNode(
        "mediapipe.tasks.vision.face_landmarker.FaceBlendshapesGraph");
    blendshapes_graph
        .GetOptions<face_landmarker::proto::FaceBlendshapesGraphOptions>() =
        face_landmarks_detector_graph_options.face_blendshapes_graph_options();
    landmarks >> blendshapes_graph.In("LANDMARKS");
    image_size >> blendshapes_graph.In("IMAGE_SIZE");
    result.classifications =
        blendshapes_graph.Out("BLENDSHAPES").Cast<ClassificationList>();
  }
  return result;
}

}  // namespace

absl::StatusOr<HolisticFaceTrackingOutput> TrackHolisticFace(
    Stream<Image> image, Stream<NormalizedLandmarkList> pose_face_landmarks,
    const face_detector::proto::FaceDetectorGraphOptions&
        face_detector_graph_options,
    const face_landmarker::proto::FaceLandmarksDetectorGraphOptions&
        face_landmarks_detector_graph_options,
    const HolisticFaceTrackingRequest& request, Graph& graph) {
  MP_RETURN_IF_ERROR(ValidateGraphOptions(face_detector_graph_options,
                                          face_landmarks_detector_graph_options,
                                          request));

  // Extracts image size from the input images.
  Stream<std::pair<int, int>> image_size = GetImageSize(image, graph);

  // Gets face ROI from pose face landmarks.
  Stream<NormalizedRect> roi_from_pose =
      GetFaceRoiFromPoseFaceLandmarks(pose_face_landmarks, image_size, graph);

  // Detects faces within ROI of pose face.
  Stream<std::vector<Detection>> face_detections = GetFaceDetections(
      image, roi_from_pose, face_detector_graph_options, graph);

  // Gets face ROI from face detector.
  Stream<NormalizedRect> roi_from_detection =
      GetFaceRoiFromFaceDetections(face_detections, image_size, graph);

  // Loop for previous frame landmarks.
  auto [prev_landmarks, set_prev_landmarks_fn] =
      GetLoopbackData<NormalizedLandmarkList>(/*tick=*/image_size, graph);

  // Tracks face ROI.
  auto tracking_roi =
      TrackFaceRoi(prev_landmarks, roi_from_detection, image_size, graph);

  // Predicts face landmarks.
  auto landmarks_detection_result = GetFaceLandmarksDetection(
      image, tracking_roi, image_size, face_landmarks_detector_graph_options,
      request, graph);

  // Sets previous landmarks for ROI tracking.
  set_prev_landmarks_fn(landmarks_detection_result.landmarks.value());

  return {{.landmarks = landmarks_detection_result.landmarks,
           .classifications = landmarks_detection_result.classifications,
           .debug_output = {
               .roi_from_pose = roi_from_pose,
               .roi_from_detection = roi_from_detection,
               .tracking_roi = tracking_roi,
           }}};
}

}  // namespace holistic_landmarker
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe
@ -0,0 +1,89 @@

/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_FACE_TRACKING_H_
#define MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_FACE_TRACKING_H_

#include <optional>

#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {

struct HolisticFaceTrackingRequest {
  bool classifications = false;
};

struct HolisticFaceTrackingOutput {
  std::optional<api2::builder::Stream<mediapipe::NormalizedLandmarkList>>
      landmarks;
  std::optional<api2::builder::Stream<mediapipe::ClassificationList>>
      classifications;

  struct DebugOutput {
    api2::builder::Stream<mediapipe::NormalizedRect> roi_from_pose;
    api2::builder::Stream<mediapipe::NormalizedRect> roi_from_detection;
    api2::builder::Stream<mediapipe::NormalizedRect> tracking_roi;
  };

  DebugOutput debug_output;
};

// Updates @graph to track a single face in @image based on pose landmarks.
//
// To track a single face, this subgraph uses the pose face landmarks to obtain
// an approximate face location, refines it with the face detector model, and
// then runs the face landmarks model. It can also reuse the face ROI from the
// previous frame if the face hasn't moved too much.
//
// @image - Image to track a single face in.
// @pose_face_landmarks - Pose face landmarks to derive the initial face
//   location from.
// @face_detector_graph_options - face detector graph options used to detect
//   the face within the RoI constructed from the pose face landmarks.
// @face_landmarks_detector_graph_options - face landmarks detector graph
//   options used to detect face landmarks within the RoI given by the face
//   detector graph.
// @request - object to request specific face tracking outputs.
//   NOTE: Outputs that were not requested won't be returned, and the
//   corresponding parts of the graph won't be generated.
// @graph - graph to update.
absl::StatusOr<HolisticFaceTrackingOutput> TrackHolisticFace(
    api2::builder::Stream<Image> image,
    api2::builder::Stream<mediapipe::NormalizedLandmarkList>
        pose_face_landmarks,
    const face_detector::proto::FaceDetectorGraphOptions&
        face_detector_graph_options,
    const face_landmarker::proto::FaceLandmarksDetectorGraphOptions&
        face_landmarks_detector_graph_options,
    const HolisticFaceTrackingRequest& request,
    mediapipe::api2::builder::Graph& graph);
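
// A minimal wiring sketch (illustrative only); `image`,
// `face_landmarks_from_pose`, and the two options protos are assumed to be
// prepared by the caller, as in the accompanying test below:
//
//   HolisticFaceTrackingRequest request;
//   MP_ASSIGN_OR_RETURN(
//       HolisticFaceTrackingOutput output,
//       TrackHolisticFace(image, face_landmarks_from_pose,
//                         face_detector_graph_options,
//                         face_landmarks_detector_graph_options, request,
//                         graph));
//   if (output.landmarks.has_value()) {
//     output.landmarks.value() >> graph.Out("FACE_LANDMARKS");
//   }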

}  // namespace holistic_landmarker
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe

#endif  // MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_FACE_TRACKING_H_
@ -0,0 +1,227 @@

/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h"

#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h"
#include "mediapipe/calculators/util/rect_to_render_data_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/api2/stream/split.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_connections.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.pb.h"
#include "mediapipe/tasks/cc/vision/utils/data_renderer.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {

namespace {

using ::mediapipe::Image;
using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::SplitToRanges;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::core::ModelAssetBundleResources;
using ::mediapipe::tasks::core::TaskRunner;
using ::mediapipe::tasks::core::proto::ExternalFile;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;

constexpr float kAbsMargin = 0.015;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kTestImageFile[] = "male_full_height_hands.jpg";
constexpr char kHolisticResultFile[] =
    "male_full_height_hands_result_cpu.pbtxt";
constexpr char kImageInStream[] = "image_in";
constexpr char kPoseLandmarksInStream[] = "pose_landmarks_in";
constexpr char kFaceLandmarksOutStream[] = "face_landmarks_out";
constexpr char kRenderedImageOutStream[] = "rendered_image_out";
constexpr char kFaceDetectorTFLiteName[] = "face_detector.tflite";
constexpr char kFaceLandmarksDetectorTFLiteName[] =
    "face_landmarks_detector.tflite";

std::string GetFilePath(absl::string_view filename) {
  return file::JoinPath("./", kTestDataDirectory, filename);
}

mediapipe::LandmarksToRenderDataCalculatorOptions GetFaceRendererOptions() {
  mediapipe::LandmarksToRenderDataCalculatorOptions render_options;
  for (const auto& connection :
       face_landmarker::FaceLandmarksConnections::kFaceLandmarksConnectors) {
    render_options.add_landmark_connections(connection[0]);
    render_options.add_landmark_connections(connection[1]);
  }
  render_options.mutable_landmark_color()->set_r(255);
  render_options.mutable_landmark_color()->set_g(255);
  render_options.mutable_landmark_color()->set_b(255);
  render_options.mutable_connection_color()->set_r(255);
  render_options.mutable_connection_color()->set_g(255);
  render_options.mutable_connection_color()->set_b(255);
  render_options.set_thickness(0.5);
  render_options.set_visualize_landmark_depth(false);
  return render_options;
}

absl::StatusOr<std::unique_ptr<ModelAssetBundleResources>>
CreateModelAssetBundleResources(const std::string& model_asset_filename) {
  auto external_model_bundle = std::make_unique<ExternalFile>();
  external_model_bundle->set_file_name(model_asset_filename);
  return ModelAssetBundleResources::Create("",
                                           std::move(external_model_bundle));
}

// Helper function to create a TaskRunner.
absl::StatusOr<std::unique_ptr<tasks::core::TaskRunner>> CreateTaskRunner() {
  Graph graph;
  Stream<Image> image = graph.In("IMAGE").Cast<Image>().SetName(kImageInStream);
  Stream<mediapipe::NormalizedLandmarkList> pose_landmarks =
      graph.In("POSE_LANDMARKS")
          .Cast<mediapipe::NormalizedLandmarkList>()
          .SetName(kPoseLandmarksInStream);
  Stream<NormalizedLandmarkList> face_landmarks_from_pose =
      SplitToRanges(pose_landmarks, {{0, 11}}, graph)[0];
  // Create face landmarker model bundle.
  MP_ASSIGN_OR_RETURN(
      auto model_bundle,
      CreateModelAssetBundleResources(GetFilePath("face_landmarker_v2.task")));
  face_detector::proto::FaceDetectorGraphOptions detector_options;
  face_landmarker::proto::FaceLandmarksDetectorGraphOptions
      landmarks_detector_options;

  // Set face detection model.
  MP_ASSIGN_OR_RETURN(auto face_detector_model_file,
                      model_bundle->GetFile(kFaceDetectorTFLiteName));
  core::proto::FilePointerMeta face_detection_file_pointer;
  face_detection_file_pointer.set_pointer(
      reinterpret_cast<uint64_t>(face_detector_model_file.data()));
  face_detection_file_pointer.set_length(face_detector_model_file.size());
  detector_options.mutable_base_options()
      ->mutable_model_asset()
      ->mutable_file_pointer_meta()
      ->Swap(&face_detection_file_pointer);
  detector_options.set_num_faces(1);

  // Set face landmarks model.
  MP_ASSIGN_OR_RETURN(auto face_landmarks_model_file,
                      model_bundle->GetFile(kFaceLandmarksDetectorTFLiteName));
  core::proto::FilePointerMeta face_landmarks_detector_file_pointer;
  face_landmarks_detector_file_pointer.set_pointer(
      reinterpret_cast<uint64_t>(face_landmarks_model_file.data()));
  face_landmarks_detector_file_pointer.set_length(
      face_landmarks_model_file.size());
  landmarks_detector_options.mutable_base_options()
      ->mutable_model_asset()
      ->mutable_file_pointer_meta()
      ->Swap(&face_landmarks_detector_file_pointer);

  // Track holistic face.
  HolisticFaceTrackingRequest request;
  MP_ASSIGN_OR_RETURN(
      HolisticFaceTrackingOutput result,
      TrackHolisticFace(image, face_landmarks_from_pose, detector_options,
                        landmarks_detector_options, request, graph));
  auto face_landmarks =
      result.landmarks.value().SetName(kFaceLandmarksOutStream);

  auto image_size = GetImageSize(image, graph);
  auto render_scale = utils::GetRenderScale(
      image_size, result.debug_output.roi_from_pose, 0.0001, graph);

  auto face_landmarks_render_data = utils::RenderLandmarks(
      face_landmarks, render_scale, GetFaceRendererOptions(), graph);
  std::vector<Stream<mediapipe::RenderData>> render_list = {
      face_landmarks_render_data};

  auto rendered_image =
      utils::Render(
          image, absl::Span<Stream<mediapipe::RenderData>>(render_list), graph)
          .SetName(kRenderedImageOutStream);
  face_landmarks >> graph.Out("FACE_LANDMARKS");
  rendered_image >> graph.Out("RENDERED_IMAGE");

  auto config = graph.GetConfig();
  core::FixGraphBackEdges(config);
  return TaskRunner::Create(
      config, std::make_unique<core::MediaPipeBuiltinOpResolver>());
}

class HolisticFaceTrackingTest : public ::testing::Test {};

TEST_F(HolisticFaceTrackingTest, SmokeTest) {
  MP_ASSERT_OK_AND_ASSIGN(Image image,
                          DecodeImageFromFile(GetFilePath(kTestImageFile)));

  proto::HolisticResult holistic_result;
  MP_ASSERT_OK(GetTextProto(GetFilePath(kHolisticResultFile), &holistic_result,
                            ::file::Defaults()));
  MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner());
  MP_ASSERT_OK_AND_ASSIGN(
      auto output_packets,
      task_runner->Process(
          {{kImageInStream, MakePacket<Image>(image)},
           {kPoseLandmarksInStream, MakePacket<NormalizedLandmarkList>(
                                        holistic_result.pose_landmarks())}}));
  ASSERT_TRUE(output_packets.find(kFaceLandmarksOutStream) !=
              output_packets.end());
  auto face_landmarks = output_packets.find(kFaceLandmarksOutStream)
                            ->second.Get<NormalizedLandmarkList>();
  EXPECT_THAT(
      face_landmarks,
      Approximately(Partially(EqualsProto(holistic_result.face_landmarks())),
                    /*margin=*/kAbsMargin));
  auto rendered_image = output_packets.at(kRenderedImageOutStream).Get<Image>();
  MP_EXPECT_OK(SavePngTestOutput(*rendered_image.GetImageFrameSharedPtr(),
                                 "holistic_face_landmarks"));
}

}  // namespace
}  // namespace holistic_landmarker
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe
@ -0,0 +1,272 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h"

#include <functional>
#include <optional>
#include <utility>

#include "absl/status/statusor.h"
#include "mediapipe/calculators/util/align_hand_to_pose_in_world_calculator.h"
#include "mediapipe/calculators/util/align_hand_to_pose_in_world_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/api2/stream/landmarks_to_detection.h"
#include "mediapipe/framework/api2/stream/loopback.h"
#include "mediapipe/framework/api2/stream/rect_transformation.h"
#include "mediapipe/framework/api2/stream/split.h"
#include "mediapipe/framework/api2/stream/threshold.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.pb.h"
#include "mediapipe/tasks/cc/components/utils/gate.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {

namespace {

using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::AlignHandToPoseInWorldCalculator;
using ::mediapipe::api2::builder::ConvertLandmarksToDetection;
using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::GetLoopbackData;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::IsOverThreshold;
using ::mediapipe::api2::builder::ScaleAndShiftAndMakeSquareLong;
using ::mediapipe::api2::builder::SplitAndCombine;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::components::utils::AllowIf;

struct HandLandmarksResult {
  std::optional<Stream<NormalizedLandmarkList>> landmarks;
  std::optional<Stream<LandmarkList>> world_landmarks;
};
// Aligns hand world landmarks with pose world landmarks by translating them
// so that the hand wrist matches the pose wrist.
Stream<LandmarkList> AlignHandToPoseInWorldCalculator(
    Stream<LandmarkList> hand_world_landmarks,
    Stream<LandmarkList> pose_world_landmarks, int pose_wrist_idx,
    Graph& graph) {
  auto& node = graph.AddNode("AlignHandToPoseInWorldCalculator");
  auto& opts = node.GetOptions<AlignHandToPoseInWorldCalculatorOptions>();
  opts.set_hand_wrist_idx(0);
  opts.set_pose_wrist_idx(pose_wrist_idx);
  hand_world_landmarks.ConnectTo(
      node[AlignHandToPoseInWorldCalculator::kInHandLandmarks]);
  pose_world_landmarks.ConnectTo(
      node[AlignHandToPoseInWorldCalculator::kInPoseLandmarks]);
  return node[AlignHandToPoseInWorldCalculator::kOutHandLandmarks];
}

Stream<bool> GetPosePalmVisibility(
    Stream<NormalizedLandmarkList> pose_palm_landmarks, Graph& graph) {
  // Get wrist landmark.
  auto pose_wrist = SplitAndCombine(pose_palm_landmarks, {0}, graph);

  // Get visibility score.
  auto& score_node = graph.AddNode("LandmarkVisibilityCalculator");
  pose_wrist.ConnectTo(score_node.In("NORM_LANDMARKS"));
  Stream<float> score = score_node.Out("VISIBILITY").Cast<float>();

  // Convert score into flag.
  return IsOverThreshold(score, /*threshold=*/0.1, graph);
}

Stream<NormalizedRect> GetHandRoiFromPosePalmLandmarks(
    Stream<NormalizedLandmarkList> pose_palm_landmarks,
    Stream<std::pair<int, int>> image_size, Graph& graph) {
  // Convert pose palm landmarks to detection.
  auto detection = ConvertLandmarksToDetection(pose_palm_landmarks, graph);

  // Convert detection to rect.
  auto& rect_node = graph.AddNode("HandDetectionsFromPoseToRectsCalculator");
  detection.ConnectTo(rect_node.In("DETECTION"));
  image_size.ConnectTo(rect_node.In("IMAGE_SIZE"));
  Stream<NormalizedRect> rect =
      rect_node.Out("NORM_RECT").Cast<NormalizedRect>();

  return ScaleAndShiftAndMakeSquareLong(rect, image_size,
                                        /*scale_x_factor=*/2.7,
                                        /*scale_y_factor=*/2.7, /*shift_x=*/0,
                                        /*shift_y=*/-0.1, graph);
}

absl::StatusOr<Stream<NormalizedRect>> RefineHandRoi(
    Stream<Image> image, Stream<NormalizedRect> roi,
    const hand_landmarker::proto::HandRoiRefinementGraphOptions&
        hand_roi_refinement_graph_options,
    Graph& graph) {
  auto& hand_roi_refinement = graph.AddNode(
      "mediapipe.tasks.vision.hand_landmarker.HandRoiRefinementGraph");
  hand_roi_refinement
      .GetOptions<hand_landmarker::proto::HandRoiRefinementGraphOptions>() =
      hand_roi_refinement_graph_options;
  image >> hand_roi_refinement.In("IMAGE");
  roi >> hand_roi_refinement.In("NORM_RECT");
  return hand_roi_refinement.Out("NORM_RECT").Cast<NormalizedRect>();
}

Stream<NormalizedRect> TrackHandRoi(
    Stream<NormalizedLandmarkList> prev_landmarks, Stream<NormalizedRect> roi,
    Stream<std::pair<int, int>> image_size, Graph& graph) {
  // Convert hand landmarks to tight rect.
  auto& prev_rect_node = graph.AddNode("HandLandmarksToRectCalculator");
  prev_landmarks.ConnectTo(prev_rect_node.In("NORM_LANDMARKS"));
  image_size.ConnectTo(prev_rect_node.In("IMAGE_SIZE"));
  Stream<NormalizedRect> prev_rect =
      prev_rect_node.Out("NORM_RECT").Cast<NormalizedRect>();

  // Convert tight hand rect to hand roi.
  Stream<NormalizedRect> prev_roi =
      ScaleAndShiftAndMakeSquareLong(prev_rect, image_size,
                                     /*scale_x_factor=*/2.0,
                                     /*scale_y_factor=*/2.0, /*shift_x=*/0,
                                     /*shift_y=*/-0.1, graph);

  auto& tracking_node = graph.AddNode("RoiTrackingCalculator");
  auto& tracking_node_opts =
      tracking_node.GetOptions<RoiTrackingCalculatorOptions>();
  auto* rect_requirements = tracking_node_opts.mutable_rect_requirements();
  rect_requirements->set_rotation_degrees(40.0);
  rect_requirements->set_translation(0.2);
  rect_requirements->set_scale(0.4);
  auto* landmarks_requirements =
      tracking_node_opts.mutable_landmarks_requirements();
  landmarks_requirements->set_recrop_rect_margin(-0.1);
  prev_landmarks.ConnectTo(tracking_node.In("PREV_LANDMARKS"));
  prev_roi.ConnectTo(tracking_node.In("PREV_LANDMARKS_RECT"));
  roi.ConnectTo(tracking_node.In("RECROP_RECT"));
  image_size.ConnectTo(tracking_node.In("IMAGE_SIZE"));
  return tracking_node.Out("TRACKING_RECT").Cast<NormalizedRect>();
}

HandLandmarksResult GetHandLandmarksDetection(
    Stream<Image> image, Stream<NormalizedRect> roi,
    const hand_landmarker::proto::HandLandmarksDetectorGraphOptions&
        hand_landmarks_detector_graph_options,
    const HolisticHandTrackingRequest& request, Graph& graph) {
  HandLandmarksResult result;
  auto& hand_landmarks_detector_graph = graph.AddNode(
      "mediapipe.tasks.vision.hand_landmarker."
      "SingleHandLandmarksDetectorGraph");
  hand_landmarks_detector_graph
      .GetOptions<hand_landmarker::proto::HandLandmarksDetectorGraphOptions>() =
      hand_landmarks_detector_graph_options;

  image >> hand_landmarks_detector_graph.In("IMAGE");
  roi >> hand_landmarks_detector_graph.In("HAND_RECT");

  if (request.landmarks) {
    result.landmarks = hand_landmarks_detector_graph.Out("LANDMARKS")
                           .Cast<NormalizedLandmarkList>();
  }
  if (request.world_landmarks) {
    result.world_landmarks =
        hand_landmarks_detector_graph.Out("WORLD_LANDMARKS")
            .Cast<LandmarkList>();
  }
  return result;
}

}  // namespace
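// Note on the ROI tracking loop in TrackHolisticHand() below:
// GetLoopbackData() creates a back edge, so the landmarks detected at frame N
// come back as `prev_landmarks` at frame N + 1 (ticked by `image_size`).
// TrackHandRoi() then reuses the previous ROI when the hand hasn't moved too
// much, instead of re-deriving it from the pose on every frame.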
absl::StatusOr<HolisticHandTrackingOutput> TrackHolisticHand(
    Stream<Image> image, Stream<NormalizedLandmarkList> pose_landmarks,
    Stream<LandmarkList> pose_world_landmarks,
    const hand_landmarker::proto::HandLandmarksDetectorGraphOptions&
        hand_landmarks_detector_graph_options,
    const hand_landmarker::proto::HandRoiRefinementGraphOptions&
        hand_roi_refinement_graph_options,
    const PoseIndices& pose_indices, const HolisticHandTrackingRequest& request,
    Graph& graph) {
  // Extracts pose palm landmarks.
  Stream<NormalizedLandmarkList> pose_palm_landmarks = SplitAndCombine(
      pose_landmarks,
      {pose_indices.wrist_idx, pose_indices.pinky_idx, pose_indices.index_idx},
      graph);

  // Get pose palm visibility.
  Stream<bool> is_pose_palm_visible =
      GetPosePalmVisibility(pose_palm_landmarks, graph);

  // Drop pose palm landmarks if pose palm is invisible.
  pose_palm_landmarks =
      AllowIf(pose_palm_landmarks, is_pose_palm_visible, graph);

  // Extracts image size from the input images.
  Stream<std::pair<int, int>> image_size = GetImageSize(image, graph);

  // Get hand ROI from pose palm landmarks.
  Stream<NormalizedRect> roi_from_pose =
      GetHandRoiFromPosePalmLandmarks(pose_palm_landmarks, image_size, graph);

  // Refine hand ROI with re-crop model.
  MP_ASSIGN_OR_RETURN(Stream<NormalizedRect> roi_from_recrop,
                      RefineHandRoi(image, roi_from_pose,
                                    hand_roi_refinement_graph_options, graph));

  // Loop for previous frame landmarks.
  auto [prev_landmarks, set_prev_landmarks_fn] =
      GetLoopbackData<NormalizedLandmarkList>(/*tick=*/image_size, graph);

  // Track hand ROI.
  auto tracking_roi =
      TrackHandRoi(prev_landmarks, roi_from_recrop, image_size, graph);

  // Predict hand landmarks.
  auto landmarks_detection_result = GetHandLandmarksDetection(
      image, tracking_roi, hand_landmarks_detector_graph_options, request,
      graph);

  // Set previous landmarks for ROI tracking. Note that the tracking loop
  // consumes the detected landmarks, so `request.landmarks` must be set for
  // the loopback to have a value.
  set_prev_landmarks_fn(landmarks_detection_result.landmarks.value());

  // Output landmarks.
  std::optional<Stream<NormalizedLandmarkList>> hand_landmarks;
  if (request.landmarks) {
    hand_landmarks = landmarks_detection_result.landmarks;
  }

  // Output world landmarks.
  std::optional<Stream<LandmarkList>> hand_world_landmarks;
  if (request.world_landmarks) {
    hand_world_landmarks = landmarks_detection_result.world_landmarks;

    // Align hand world landmarks with pose world landmarks.
    hand_world_landmarks = AlignHandToPoseInWorldCalculator(
        hand_world_landmarks.value(), pose_world_landmarks,
        pose_indices.wrist_idx, graph);
  }

  return {{.landmarks = hand_landmarks,
           .world_landmarks = hand_world_landmarks,
           .debug_output = {
               .roi_from_pose = roi_from_pose,
               .roi_from_recrop = roi_from_recrop,
               .tracking_roi = tracking_roi,
           }}};
}

}  // namespace holistic_landmarker
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe

@@ -0,0 +1,94 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_HAND_TRACKING_H_
#define MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_HAND_TRACKING_H_

#include <optional>

#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {

// Indices of the pose landmarks used to derive the initial hand location.
struct PoseIndices {
  int wrist_idx;
  int pinky_idx;
  int index_idx;
};

// Requested hand tracking outputs; outputs that are not requested are not
// generated.
struct HolisticHandTrackingRequest {
  bool landmarks = false;
  bool world_landmarks = false;
};
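// For example, requesting only image-space landmarks (as the unit tests for
// this header do) is a matter of:
//
//   HolisticHandTrackingRequest request;
//   request.landmarks = true;  // world_landmarks stays false.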

struct HolisticHandTrackingOutput {
  std::optional<api2::builder::Stream<mediapipe::NormalizedLandmarkList>>
      landmarks;
  std::optional<api2::builder::Stream<mediapipe::LandmarkList>> world_landmarks;

  struct DebugOutput {
    api2::builder::Stream<mediapipe::NormalizedRect> roi_from_pose;
    api2::builder::Stream<mediapipe::NormalizedRect> roi_from_recrop;
    api2::builder::Stream<mediapipe::NormalizedRect> tracking_roi;
  };

  DebugOutput debug_output;
};

// Updates @graph to track a single hand in @image based on pose landmarks.
//
// To track a single hand, this subgraph uses pose palm landmarks to obtain an
// approximate hand location, refines it with a re-crop model, and then runs
// the hand landmarks model. It can also reuse the hand ROI from the previous
// frame if the hand hasn't moved too much.
//
// @image - ImageFrame/GpuBuffer to track a single hand in.
// @pose_landmarks - Pose landmarks to derive the initial hand location from.
// @pose_world_landmarks - Pose world landmarks to align the hand world
//   landmarks wrist with.
// @hand_landmarks_detector_graph_options - Options of the
//   HandLandmarksDetectorGraph used to detect the hand landmarks.
// @hand_roi_refinement_graph_options - Options of HandRoiRefinementGraph used
//   to refine the hand ROIs derived from pose landmarks.
// @request - object to request specific hand tracking outputs.
//   NOTE: Outputs that were not requested won't be returned and the
//   corresponding parts of the graph won't be generated.
// @graph - graph to update.
absl::StatusOr<HolisticHandTrackingOutput> TrackHolisticHand(
    api2::builder::Stream<Image> image,
    api2::builder::Stream<mediapipe::NormalizedLandmarkList> pose_landmarks,
    api2::builder::Stream<mediapipe::LandmarkList> pose_world_landmarks,
    const hand_landmarker::proto::HandLandmarksDetectorGraphOptions&
        hand_landmarks_detector_graph_options,
    const hand_landmarker::proto::HandRoiRefinementGraphOptions&
        hand_roi_refinement_graph_options,
    const PoseIndices& pose_indices, const HolisticHandTrackingRequest& request,
    mediapipe::api2::builder::Graph& graph);
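//
// A minimal usage sketch, mirroring the unit test for this header (stream
// names, options setup and the explicit landmark indices below are taken
// from that test and are illustrative, not mandated by this API; 15/17/19
// are the left wrist/pinky/index pose landmarks):
//
//   Graph graph;
//   Stream<Image> image = graph.In("IMAGE").Cast<Image>();
//   Stream<NormalizedLandmarkList> pose_landmarks =
//       graph.In("POSE_LANDMARKS").Cast<NormalizedLandmarkList>();
//   Stream<LandmarkList> pose_world_landmarks =
//       graph.In("POSE_WORLD_LANDMARKS").Cast<LandmarkList>();
//   MP_ASSIGN_OR_RETURN(
//       HolisticHandTrackingOutput output,
//       TrackHolisticHand(image, pose_landmarks, pose_world_landmarks,
//                         hand_landmarks_detector_graph_options,
//                         hand_roi_refinement_graph_options,
//                         PoseIndices{/*wrist_idx=*/15, /*pinky_idx=*/17,
//                                     /*index_idx=*/19},
//                         HolisticHandTrackingRequest{.landmarks = true},
//                         graph));
//   output.landmarks->ConnectTo(graph.Out("LEFT_HAND_LANDMARKS"));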

}  // namespace holistic_landmarker
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe

#endif  // MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_HAND_TRACKING_H_

@@ -0,0 +1,303 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h"

#include <memory>
#include <string>
#include <vector>

#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/substitute.h"
#include "absl/types/span.h"
#include "file/base/helpers.h"
#include "file/base/options.h"
#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_connections.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_topology.h"
#include "mediapipe/tasks/cc/vision/utils/data_renderer.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {

namespace {

using ::file::Defaults;
using ::file::GetTextProto;
using ::mediapipe::Image;
using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::core::TaskRunner;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;

constexpr float kAbsMargin = 0.018;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kHolisticHandTrackingLeft[] =
    "holistic_hand_tracking_left_hand_graph.pbtxt";
constexpr char kTestImageFile[] = "male_full_height_hands.jpg";
constexpr char kHolisticResultFile[] =
    "male_full_height_hands_result_cpu.pbtxt";
constexpr char kImageInStream[] = "image_in";
constexpr char kPoseLandmarksInStream[] = "pose_landmarks_in";
constexpr char kPoseWorldLandmarksInStream[] = "pose_world_landmarks_in";
constexpr char kLeftHandLandmarksOutStream[] = "left_hand_landmarks_out";
constexpr char kLeftHandWorldLandmarksOutStream[] =
    "left_hand_world_landmarks_out";
constexpr char kRightHandLandmarksOutStream[] = "right_hand_landmarks_out";
constexpr char kRenderedImageOutStream[] = "rendered_image_out";
constexpr char kHandLandmarksModelFile[] = "hand_landmark_full.tflite";
constexpr char kHandRoiRefinementModelFile[] =
    "handrecrop_2020_07_21_v0.f16.tflite";

std::string GetFilePath(const std::string& filename) {
  return file::JoinPath("./", kTestDataDirectory, filename);
}

mediapipe::LandmarksToRenderDataCalculatorOptions GetHandRendererOptions() {
  mediapipe::LandmarksToRenderDataCalculatorOptions renderer_options;
  for (const auto& connection : hand_landmarker::kHandConnections) {
    renderer_options.add_landmark_connections(connection[0]);
    renderer_options.add_landmark_connections(connection[1]);
  }
  renderer_options.mutable_landmark_color()->set_r(255);
  renderer_options.mutable_landmark_color()->set_g(255);
  renderer_options.mutable_landmark_color()->set_b(255);
  renderer_options.mutable_connection_color()->set_r(255);
  renderer_options.mutable_connection_color()->set_g(255);
  renderer_options.mutable_connection_color()->set_b(255);
  renderer_options.set_thickness(0.5);
  renderer_options.set_visualize_landmark_depth(false);
  return renderer_options;
}

void ConfigHandTrackingModelsOptions(
    hand_landmarker::proto::HandLandmarksDetectorGraphOptions&
        hand_landmarks_detector_graph_options,
    hand_landmarker::proto::HandRoiRefinementGraphOptions&
        hand_roi_refinement_options) {
  hand_landmarks_detector_graph_options.mutable_base_options()
      ->mutable_model_asset()
      ->set_file_name(GetFilePath(kHandLandmarksModelFile));

  hand_roi_refinement_options.mutable_base_options()
      ->mutable_model_asset()
      ->set_file_name(GetFilePath(kHandRoiRefinementModelFile));
}
// Helper function to create a TaskRunner.
absl::StatusOr<std::unique_ptr<tasks::core::TaskRunner>> CreateTaskRunner() {
  Graph graph;
  Stream<Image> image = graph.In("IMAGE").Cast<Image>().SetName(kImageInStream);
  Stream<mediapipe::NormalizedLandmarkList> pose_landmarks =
      graph.In("POSE_LANDMARKS")
          .Cast<mediapipe::NormalizedLandmarkList>()
          .SetName(kPoseLandmarksInStream);
  Stream<mediapipe::LandmarkList> pose_world_landmarks =
      graph.In("POSE_WORLD_LANDMARKS")
          .Cast<mediapipe::LandmarkList>()
          .SetName(kPoseWorldLandmarksInStream);
  hand_landmarker::proto::HandLandmarksDetectorGraphOptions
      hand_landmarks_detector_options;
  hand_landmarker::proto::HandRoiRefinementGraphOptions
      hand_roi_refinement_options;
  ConfigHandTrackingModelsOptions(hand_landmarks_detector_options,
                                  hand_roi_refinement_options);
  HolisticHandTrackingRequest request;
  request.landmarks = true;
  MP_ASSIGN_OR_RETURN(
      HolisticHandTrackingOutput left_hand_result,
      TrackHolisticHand(
          image, pose_landmarks, pose_world_landmarks,
          hand_landmarks_detector_options, hand_roi_refinement_options,
          PoseIndices{
              /*wrist_idx=*/static_cast<int>(
                  pose_landmarker::PoseLandmarkName::kLeftWrist),
              /*pinky_idx=*/
              static_cast<int>(pose_landmarker::PoseLandmarkName::kLeftPinky1),
              /*index_idx=*/
              static_cast<int>(pose_landmarker::PoseLandmarkName::kLeftIndex1)},
          request, graph));
  MP_ASSIGN_OR_RETURN(
      HolisticHandTrackingOutput right_hand_result,
      TrackHolisticHand(
          image, pose_landmarks, pose_world_landmarks,
          hand_landmarks_detector_options, hand_roi_refinement_options,
          PoseIndices{
              /*wrist_idx=*/static_cast<int>(
                  pose_landmarker::PoseLandmarkName::kRightWrist),
              /*pinky_idx=*/
              static_cast<int>(pose_landmarker::PoseLandmarkName::kRightPinky1),
              /*index_idx=*/
              static_cast<int>(
                  pose_landmarker::PoseLandmarkName::kRightIndex1)},
          request, graph));

  auto image_size = GetImageSize(image, graph);
  auto left_hand_landmarks_render_data = utils::RenderLandmarks(
      *left_hand_result.landmarks,
      utils::GetRenderScale(image_size,
                            left_hand_result.debug_output.roi_from_pose, 0.0001,
                            graph),
      GetHandRendererOptions(), graph);
  auto right_hand_landmarks_render_data = utils::RenderLandmarks(
      *right_hand_result.landmarks,
      utils::GetRenderScale(image_size,
                            right_hand_result.debug_output.roi_from_pose,
                            0.0001, graph),
      GetHandRendererOptions(), graph);
  std::vector<Stream<mediapipe::RenderData>> render_list = {
      left_hand_landmarks_render_data, right_hand_landmarks_render_data};
  auto rendered_image =
      utils::Render(
          image, absl::Span<Stream<mediapipe::RenderData>>(render_list), graph)
          .SetName(kRenderedImageOutStream);
  left_hand_result.landmarks->SetName(kLeftHandLandmarksOutStream) >>
      graph.Out("LEFT_HAND_LANDMARKS");
  right_hand_result.landmarks->SetName(kRightHandLandmarksOutStream) >>
      graph.Out("RIGHT_HAND_LANDMARKS");
  rendered_image >> graph.Out("RENDERED_IMAGE");

  auto config = graph.GetConfig();
  core::FixGraphBackEdges(config);

  return TaskRunner::Create(
      config, std::make_unique<core::MediaPipeBuiltinOpResolver>());
}
class HolisticHandTrackingTest : public ::testing::Test {};

TEST_F(HolisticHandTrackingTest, VerifyGraph) {
  Graph graph;
  Stream<Image> image = graph.In("IMAGE").Cast<Image>().SetName(kImageInStream);
  Stream<mediapipe::NormalizedLandmarkList> pose_landmarks =
      graph.In("POSE_LANDMARKS")
          .Cast<mediapipe::NormalizedLandmarkList>()
          .SetName(kPoseLandmarksInStream);
  Stream<mediapipe::LandmarkList> pose_world_landmarks =
      graph.In("POSE_WORLD_LANDMARKS")
          .Cast<mediapipe::LandmarkList>()
          .SetName(kPoseWorldLandmarksInStream);
  hand_landmarker::proto::HandLandmarksDetectorGraphOptions
      hand_landmarks_detector_options;
  hand_landmarker::proto::HandRoiRefinementGraphOptions
      hand_roi_refinement_options;
  ConfigHandTrackingModelsOptions(hand_landmarks_detector_options,
                                  hand_roi_refinement_options);
  HolisticHandTrackingRequest request;
  request.landmarks = true;
  request.world_landmarks = true;
  MP_ASSERT_OK_AND_ASSIGN(
      HolisticHandTrackingOutput left_hand_result,
      TrackHolisticHand(
          image, pose_landmarks, pose_world_landmarks,
          hand_landmarks_detector_options, hand_roi_refinement_options,
          PoseIndices{
              /*wrist_idx=*/static_cast<int>(
                  pose_landmarker::PoseLandmarkName::kLeftWrist),
              /*pinky_idx=*/
              static_cast<int>(pose_landmarker::PoseLandmarkName::kLeftPinky1),
              /*index_idx=*/
              static_cast<int>(pose_landmarker::PoseLandmarkName::kLeftIndex1)},
          request, graph));
  left_hand_result.landmarks->SetName(kLeftHandLandmarksOutStream) >>
      graph.Out("LEFT_HAND_LANDMARKS");
  left_hand_result.world_landmarks->SetName(kLeftHandWorldLandmarksOutStream) >>
      graph.Out("LEFT_HAND_WORLD_LANDMARKS");

  // Read the expected graph config.
  std::string expected_graph_contents;
  MP_ASSERT_OK(file::GetContents(
      file::JoinPath("./", kTestDataDirectory, kHolisticHandTrackingLeft),
      &expected_graph_contents));

  // Need to substitute the test srcdir into the expected graph config,
  // because each run has a different test dir on TAP.
  expected_graph_contents = absl::Substitute(
      expected_graph_contents, FLAGS_test_srcdir, FLAGS_test_srcdir);
  CalculatorGraphConfig expected_graph =
      ParseTextProtoOrDie<CalculatorGraphConfig>(expected_graph_contents);

  EXPECT_THAT(graph.GetConfig(), testing::proto::IgnoringRepeatedFieldOrdering(
                                     testing::EqualsProto(expected_graph)));
}

TEST_F(HolisticHandTrackingTest, SmokeTest) {
  MP_ASSERT_OK_AND_ASSIGN(Image image,
                          DecodeImageFromFile(GetFilePath(kTestImageFile)));

  proto::HolisticResult holistic_result;
  MP_ASSERT_OK(GetTextProto(GetFilePath(kHolisticResultFile), &holistic_result,
                            Defaults()));
  MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner());
  MP_ASSERT_OK_AND_ASSIGN(
      auto output_packets,
      task_runner->Process(
          {{kImageInStream, MakePacket<Image>(image)},
           {kPoseLandmarksInStream, MakePacket<NormalizedLandmarkList>(
                                        holistic_result.pose_landmarks())},
           {kPoseWorldLandmarksInStream,
            MakePacket<LandmarkList>(
                holistic_result.pose_world_landmarks())}}));
  auto left_hand_landmarks = output_packets.at(kLeftHandLandmarksOutStream)
                                 .Get<NormalizedLandmarkList>();
  auto right_hand_landmarks = output_packets.at(kRightHandLandmarksOutStream)
                                  .Get<NormalizedLandmarkList>();
  EXPECT_THAT(left_hand_landmarks,
              Approximately(
                  Partially(EqualsProto(holistic_result.left_hand_landmarks())),
                  /*margin=*/kAbsMargin));
  EXPECT_THAT(
      right_hand_landmarks,
      Approximately(
          Partially(EqualsProto(holistic_result.right_hand_landmarks())),
          /*margin=*/kAbsMargin));
  auto rendered_image = output_packets.at(kRenderedImageOutStream).Get<Image>();
  MP_EXPECT_OK(SavePngTestOutput(*rendered_image.GetImageFrameSharedPtr(),
                                 "holistic_hand_landmarks"));
}

}  // namespace
}  // namespace holistic_landmarker
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe

@@ -0,0 +1,521 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/split.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_topology.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
#include "mediapipe/util/graph_builder_utils.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
namespace {

using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::metadata::SetExternalFile;

constexpr absl::string_view kHandLandmarksDetectorModelName =
    "hand_landmarks_detector.tflite";
constexpr absl::string_view kHandRoiRefinementModelName =
    "hand_roi_refinement.tflite";
constexpr absl::string_view kFaceDetectorModelName = "face_detector.tflite";
constexpr absl::string_view kFaceLandmarksDetectorModelName =
    "face_landmarks_detector.tflite";
constexpr absl::string_view kFaceBlendshapesModelName =
    "face_blendshapes.tflite";
constexpr absl::string_view kPoseDetectorModelName = "pose_detector.tflite";
constexpr absl::string_view kPoseLandmarksDetectorModelName =
    "pose_landmarks_detector.tflite";
absl::Status SetGraphPoseOutputs(
    const HolisticPoseTrackingRequest& pose_request,
    const CalculatorGraphConfig::Node& node,
    HolisticPoseTrackingOutput& pose_output, Graph& graph) {
  // Main outputs.
  if (pose_request.landmarks) {
    RET_CHECK(pose_output.landmarks.has_value())
        << "POSE_LANDMARKS output is not supported.";
    pose_output.landmarks->ConnectTo(graph.Out("POSE_LANDMARKS"));
  }
  if (pose_request.world_landmarks) {
    RET_CHECK(pose_output.world_landmarks.has_value())
        << "POSE_WORLD_LANDMARKS output is not supported.";
    pose_output.world_landmarks->ConnectTo(graph.Out("POSE_WORLD_LANDMARKS"));
  }
  if (pose_request.segmentation_mask) {
    RET_CHECK(pose_output.segmentation_mask.has_value())
        << "POSE_SEGMENTATION_MASK output is not supported.";
    pose_output.segmentation_mask->ConnectTo(
        graph.Out("POSE_SEGMENTATION_MASK"));
  }

  // Debug outputs.
  if (HasOutput(node, "POSE_AUXILIARY_LANDMARKS")) {
    pose_output.debug_output.auxiliary_landmarks.ConnectTo(
        graph.Out("POSE_AUXILIARY_LANDMARKS"));
  }
  if (HasOutput(node, "POSE_LANDMARKS_ROI")) {
    pose_output.debug_output.roi_from_landmarks.ConnectTo(
        graph.Out("POSE_LANDMARKS_ROI"));
  }

  return absl::OkStatus();
}
// Sets the base options in the sub tasks. If the sub-task options don't
// already carry a model asset, the model is looked up in the asset bundle by
// `model_name`.
template <typename T>
absl::Status SetSubTaskBaseOptions(
    const core::ModelAssetBundleResources* resources,
    proto::HolisticLandmarkerGraphOptions* options, T* sub_task_options,
    absl::string_view model_name, bool is_copy) {
  if (!sub_task_options->base_options().has_model_asset()) {
    MP_ASSIGN_OR_RETURN(const auto model_file_content,
                        resources->GetFile(std::string(model_name)));
    SetExternalFile(
        model_file_content,
        sub_task_options->mutable_base_options()->mutable_model_asset(),
        is_copy);
  }
  sub_task_options->mutable_base_options()->mutable_acceleration()->CopyFrom(
      options->base_options().acceleration());
  sub_task_options->mutable_base_options()->set_use_stream_mode(
      options->base_options().use_stream_mode());
  sub_task_options->mutable_base_options()->set_gpu_origin(
      options->base_options().gpu_origin());
  return absl::OkStatus();
}
void SetGraphHandOutputs(bool is_left, const CalculatorGraphConfig::Node& node,
                         HolisticHandTrackingOutput& hand_output,
                         Graph& graph) {
  const std::string hand_side = is_left ? "LEFT" : "RIGHT";

  if (hand_output.landmarks) {
    hand_output.landmarks->ConnectTo(graph.Out(hand_side + "_HAND_LANDMARKS"));
  }
  if (hand_output.world_landmarks) {
    hand_output.world_landmarks->ConnectTo(
        graph.Out(hand_side + "_HAND_WORLD_LANDMARKS"));
  }

  // Debug outputs.
  if (HasOutput(node, hand_side + "_HAND_ROI_FROM_POSE")) {
    hand_output.debug_output.roi_from_pose.ConnectTo(
        graph.Out(hand_side + "_HAND_ROI_FROM_POSE"));
  }
  if (HasOutput(node, hand_side + "_HAND_ROI_FROM_RECROP")) {
    hand_output.debug_output.roi_from_recrop.ConnectTo(
        graph.Out(hand_side + "_HAND_ROI_FROM_RECROP"));
  }
  if (HasOutput(node, hand_side + "_HAND_TRACKING_ROI")) {
    hand_output.debug_output.tracking_roi.ConnectTo(
        graph.Out(hand_side + "_HAND_TRACKING_ROI"));
  }
}
void SetGraphFaceOutputs(const CalculatorGraphConfig::Node& node,
                         HolisticFaceTrackingOutput& face_output,
                         Graph& graph) {
  if (face_output.landmarks) {
    face_output.landmarks->ConnectTo(graph.Out("FACE_LANDMARKS"));
  }
  if (face_output.classifications) {
    face_output.classifications->ConnectTo(graph.Out("FACE_BLENDSHAPES"));
  }

  // Face detection debug outputs.
  if (HasOutput(node, "FACE_ROI_FROM_POSE")) {
    face_output.debug_output.roi_from_pose.ConnectTo(
        graph.Out("FACE_ROI_FROM_POSE"));
  }
  if (HasOutput(node, "FACE_ROI_FROM_DETECTION")) {
    face_output.debug_output.roi_from_detection.ConnectTo(
        graph.Out("FACE_ROI_FROM_DETECTION"));
  }
  if (HasOutput(node, "FACE_TRACKING_ROI")) {
    face_output.debug_output.tracking_roi.ConnectTo(
        graph.Out("FACE_TRACKING_ROI"));
  }
}

}  // namespace
// Tracks pose and detects hands and face.
//
// NOTE: for GPU, this graph works only with images having
// GpuOrigin::TOP_LEFT.
//
// Inputs:
//   IMAGE - Image
//     Image to perform detection on.
//
// Outputs:
//   POSE_LANDMARKS - NormalizedLandmarkList
//     33 landmarks (see pose_landmarker/pose_topology.h)
//     0 - nose
//     1 - left eye (inner)
//     2 - left eye
//     3 - left eye (outer)
//     4 - right eye (inner)
//     5 - right eye
//     6 - right eye (outer)
//     7 - left ear
//     8 - right ear
//     9 - mouth (left)
//     10 - mouth (right)
//     11 - left shoulder
//     12 - right shoulder
//     13 - left elbow
//     14 - right elbow
//     15 - left wrist
//     16 - right wrist
//     17 - left pinky
//     18 - right pinky
//     19 - left index
//     20 - right index
//     21 - left thumb
//     22 - right thumb
//     23 - left hip
//     24 - right hip
//     25 - left knee
//     26 - right knee
//     27 - left ankle
//     28 - right ankle
//     29 - left heel
//     30 - right heel
//     31 - left foot index
//     32 - right foot index
//   POSE_WORLD_LANDMARKS - LandmarkList
//     World landmarks are real-world 3D coordinates with origin at the hips
//     center and coordinates in meters. To understand the difference: the
//     POSE_LANDMARKS stream provides coordinates (in pixels) of the 3D object
//     projected onto the 2D surface of the image (check how perspective
//     projection works), while the POSE_WORLD_LANDMARKS stream provides
//     coordinates (in meters) of the 3D object itself. POSE_WORLD_LANDMARKS
//     has the same landmark topology, visibility and presence as
//     POSE_LANDMARKS.
//   POSE_SEGMENTATION_MASK - Image
//     Separates person from background. The mask is stored as a gray float32
//     image with a [0.0, 1.0] range for pixels (1 for person, 0 for
//     background) on CPU, and on GPU as an RGBA texture with the R channel
//     indicating the person vs. background probability.
//   LEFT_HAND_LANDMARKS - NormalizedLandmarkList
//     21 left hand landmarks.
//   RIGHT_HAND_LANDMARKS - NormalizedLandmarkList
//     21 right hand landmarks.
//   FACE_LANDMARKS - NormalizedLandmarkList
//     468 face landmarks.
//   FACE_BLENDSHAPES - ClassificationList
//     Supplementary blendshape coefficients that are predicted directly from
//     the input image.
//   LEFT_HAND_WORLD_LANDMARKS - LandmarkList
//     21 left hand world 3D landmarks.
//     Hand landmarks are aligned with pose landmarks: translated so that the
//     hand wrist matches the pose wrist, in the pose coordinate system.
//   RIGHT_HAND_WORLD_LANDMARKS - LandmarkList
//     21 right hand world 3D landmarks.
//     Hand landmarks are aligned with pose landmarks: translated so that the
//     hand wrist matches the pose wrist, in the pose coordinate system.
//   IMAGE - Image
//     The input image that the holistic landmarker runs on, with the pixel
//     data stored on the target storage (CPU vs GPU).
//
// Debug outputs:
//   POSE_AUXILIARY_LANDMARKS - NormalizedLandmarkList
//     TODO: Return ROI rather than auxiliary landmarks
//     Auxiliary landmarks for deriving the ROI in the subsequent image.
//     0 - hidden center point
//     1 - hidden scale point
//   POSE_LANDMARKS_ROI - NormalizedRect
//     Region of interest calculated based on landmarks.
//   LEFT_HAND_ROI_FROM_POSE - NormalizedRect
//   LEFT_HAND_ROI_FROM_RECROP - NormalizedRect
//   LEFT_HAND_TRACKING_ROI - NormalizedRect
//   RIGHT_HAND_ROI_FROM_POSE - NormalizedRect
//   RIGHT_HAND_ROI_FROM_RECROP - NormalizedRect
//   RIGHT_HAND_TRACKING_ROI - NormalizedRect
//   FACE_ROI_FROM_POSE - NormalizedRect
//   FACE_ROI_FROM_DETECTION - NormalizedRect
//   FACE_TRACKING_ROI - NormalizedRect
//
// NOTE: failure is reported if some output has been requested, but the
// specified model doesn't support it.
//
// NOTE: there will not be an output packet in an output stream for a
// particular timestamp if nothing is detected. However, the MediaPipe
// framework will internally inform the downstream calculators of the
// absence of this packet so that they don't wait for it unnecessarily.
//
// Example:
//   node {
//     calculator:
//         "mediapipe.tasks.vision.holistic_landmarker.HolisticLandmarkerGraph"
//     input_stream: "IMAGE:input_frames_image"
//     output_stream: "POSE_LANDMARKS:pose_landmarks"
//     output_stream: "POSE_WORLD_LANDMARKS:pose_world_landmarks"
//     output_stream: "FACE_LANDMARKS:face_landmarks"
//     output_stream: "FACE_BLENDSHAPES:extra_blendshapes"
//     output_stream: "LEFT_HAND_LANDMARKS:left_hand_landmarks"
//     output_stream: "LEFT_HAND_WORLD_LANDMARKS:left_hand_world_landmarks"
//     output_stream: "RIGHT_HAND_LANDMARKS:right_hand_landmarks"
//     output_stream: "RIGHT_HAND_WORLD_LANDMARKS:right_hand_world_landmarks"
//     node_options {
//       [type.googleapis.com/mediapipe.tasks.vision.holistic_landmarker.proto.HolisticLandmarkerGraphOptions]
//       {
//         base_options {
//           model_asset {
//             file_name:
//                 "mediapipe/tasks/testdata/vision/holistic_landmarker.task"
//           }
//         }
//         face_detector_graph_options: {
//           num_faces: 1
//         }
//         pose_detector_graph_options: {
//           num_poses: 1
//         }
//       }
//     }
//   }
class HolisticLandmarkerGraph : public core::ModelTaskGraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      SubgraphContext* sc) override {
    Graph graph;
    const auto& holistic_node = sc->OriginalNode();
    proto::HolisticLandmarkerGraphOptions* holistic_options =
        sc->MutableOptions<proto::HolisticLandmarkerGraphOptions>();
    const core::ModelAssetBundleResources* model_asset_bundle_resources =
        nullptr;
    if (holistic_options->base_options().has_model_asset()) {
      MP_ASSIGN_OR_RETURN(model_asset_bundle_resources,
                          CreateModelAssetBundleResources<
                              proto::HolisticLandmarkerGraphOptions>(sc));
    }
    // Copies the file content instead of passing the pointer of the file in
    // memory if the subgraph model resource service is not available.
    bool create_copy =
        !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
             .IsAvailable();

    Stream<Image> image = graph.In("IMAGE").Cast<Image>();

    // Check whether hand outputs are requested.
    const bool is_left_hand_requested =
        HasOutput(holistic_node, "LEFT_HAND_LANDMARKS");
    const bool is_right_hand_requested =
        HasOutput(holistic_node, "RIGHT_HAND_LANDMARKS");
    const bool is_left_hand_world_requested =
        HasOutput(holistic_node, "LEFT_HAND_WORLD_LANDMARKS");
    const bool is_right_hand_world_requested =
        HasOutput(holistic_node, "RIGHT_HAND_WORLD_LANDMARKS");
    const bool hands_requested =
        is_left_hand_requested || is_right_hand_requested ||
        is_left_hand_world_requested || is_right_hand_world_requested;
    if (hands_requested) {
      MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
          model_asset_bundle_resources, holistic_options,
          holistic_options->mutable_hand_landmarks_detector_graph_options(),
          kHandLandmarksDetectorModelName, create_copy));
      MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
          model_asset_bundle_resources, holistic_options,
          holistic_options->mutable_hand_roi_refinement_graph_options(),
          kHandRoiRefinementModelName, create_copy));
    }

    // Check whether face outputs are requested.
    const bool is_face_requested = HasOutput(holistic_node, "FACE_LANDMARKS");
    const bool is_face_blendshapes_requested =
        HasOutput(holistic_node, "FACE_BLENDSHAPES");
    const bool face_requested =
        is_face_requested || is_face_blendshapes_requested;
    if (face_requested) {
      MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
          model_asset_bundle_resources, holistic_options,
          holistic_options->mutable_face_detector_graph_options(),
          kFaceDetectorModelName, create_copy));
      // Forcibly set num_faces to 1, because the holistic landmarker only
      // supports a single subject for now.
      holistic_options->mutable_face_detector_graph_options()->set_num_faces(1);
      MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
          model_asset_bundle_resources, holistic_options,
          holistic_options->mutable_face_landmarks_detector_graph_options(),
          kFaceLandmarksDetectorModelName, create_copy));
      if (is_face_blendshapes_requested) {
        MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
            model_asset_bundle_resources, holistic_options,
            holistic_options->mutable_face_landmarks_detector_graph_options()
                ->mutable_face_blendshapes_graph_options(),
            kFaceBlendshapesModelName, create_copy));
      }
    }

    MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
        model_asset_bundle_resources, holistic_options,
        holistic_options->mutable_pose_detector_graph_options(),
        kPoseDetectorModelName, create_copy));
    // Forcibly set num_poses to 1, because the holistic landmarker only
    // supports a single subject for now.
    holistic_options->mutable_pose_detector_graph_options()->set_num_poses(1);
    MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
        model_asset_bundle_resources, holistic_options,
        holistic_options->mutable_pose_landmarks_detector_graph_options(),
        kPoseLandmarksDetectorModelName, create_copy));

    HolisticPoseTrackingRequest pose_request = {
        .landmarks = HasOutput(holistic_node, "POSE_LANDMARKS") ||
                     hands_requested || face_requested,
        .world_landmarks =
            HasOutput(holistic_node, "POSE_WORLD_LANDMARKS") || hands_requested,
        .segmentation_mask =
            HasOutput(holistic_node, "POSE_SEGMENTATION_MASK")};

    // Detect and track pose.
    MP_ASSIGN_OR_RETURN(
        HolisticPoseTrackingOutput pose_output,
        TrackHolisticPose(
            image, holistic_options->pose_detector_graph_options(),
            holistic_options->pose_landmarks_detector_graph_options(),
            pose_request, graph));
    MP_RETURN_IF_ERROR(
        SetGraphPoseOutputs(pose_request, holistic_node, pose_output, graph));

    // Detect and track hands.
    if (hands_requested) {
      if (is_left_hand_requested || is_left_hand_world_requested) {
        RET_CHECK(pose_output.landmarks.has_value());
        RET_CHECK(pose_output.world_landmarks.has_value());

        PoseIndices pose_indices = {
            .wrist_idx =
                static_cast<int>(pose_landmarker::PoseLandmarkName::kLeftWrist),
            .pinky_idx = static_cast<int>(
                pose_landmarker::PoseLandmarkName::kLeftPinky1),
            .index_idx = static_cast<int>(
                pose_landmarker::PoseLandmarkName::kLeftIndex1),
        };
        HolisticHandTrackingRequest hand_request = {
            .landmarks = is_left_hand_requested,
            .world_landmarks = is_left_hand_world_requested,
        };
        MP_ASSIGN_OR_RETURN(
            HolisticHandTrackingOutput hand_output,
            TrackHolisticHand(
                image, *pose_output.landmarks, *pose_output.world_landmarks,
                holistic_options->hand_landmarks_detector_graph_options(),
                holistic_options->hand_roi_refinement_graph_options(),
                pose_indices, hand_request, graph));
        SetGraphHandOutputs(/*is_left=*/true, holistic_node, hand_output,
                            graph);
      }

      if (is_right_hand_requested || is_right_hand_world_requested) {
        RET_CHECK(pose_output.landmarks.has_value());
        RET_CHECK(pose_output.world_landmarks.has_value());

        PoseIndices pose_indices = {
            .wrist_idx = static_cast<int>(
                pose_landmarker::PoseLandmarkName::kRightWrist),
            .pinky_idx = static_cast<int>(
                pose_landmarker::PoseLandmarkName::kRightPinky1),
            .index_idx = static_cast<int>(
                pose_landmarker::PoseLandmarkName::kRightIndex1),
        };
        HolisticHandTrackingRequest hand_request = {
            .landmarks = is_right_hand_requested,
            .world_landmarks = is_right_hand_world_requested,
        };
        MP_ASSIGN_OR_RETURN(
            HolisticHandTrackingOutput hand_output,
            TrackHolisticHand(
                image, *pose_output.landmarks, *pose_output.world_landmarks,
                holistic_options->hand_landmarks_detector_graph_options(),
                holistic_options->hand_roi_refinement_graph_options(),
                pose_indices, hand_request, graph));
        SetGraphHandOutputs(/*is_left=*/false, holistic_node, hand_output,
                            graph);
      }
    }

    // Detect and track face.
    if (face_requested) {
      RET_CHECK(pose_output.landmarks.has_value());

      Stream<mediapipe::NormalizedLandmarkList> face_landmarks_from_pose =
          api2::builder::SplitToRanges(*pose_output.landmarks, {{0, 11}},
                                       graph)[0];

      HolisticFaceTrackingRequest face_request = {
          .classifications = is_face_blendshapes_requested,
      };
      MP_ASSIGN_OR_RETURN(
          HolisticFaceTrackingOutput face_output,
          TrackHolisticFace(
              image, face_landmarks_from_pose,
              holistic_options->face_detector_graph_options(),
              holistic_options->face_landmarks_detector_graph_options(),
              face_request, graph));
      SetGraphFaceOutputs(holistic_node, face_output, graph);
    }

    auto& pass_through = graph.AddNode("PassThroughCalculator");
    image >> pass_through.In("");
    pass_through.Out("") >> graph.Out("IMAGE");

    auto config = graph.GetConfig();
    core::FixGraphBackEdges(config);
    return config;
  }
};

REGISTER_MEDIAPIPE_GRAPH(
    ::mediapipe::tasks::vision::holistic_landmarker::HolisticLandmarkerGraph);
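
// A minimal sketch of instantiating the registered subgraph from C++; the
// pbtxt mirrors the Example in the class comment above, and the stream names
// are illustrative (at runtime a model asset bundle must also be supplied via
// node_options, as shown in that Example):
//
//   CalculatorGraphConfig config =
//       ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
//         input_stream: "image"
//         output_stream: "pose_landmarks"
//         node {
//           calculator:
//               "mediapipe.tasks.vision.holistic_landmarker.HolisticLandmarkerGraph"
//           input_stream: "IMAGE:image"
//           output_stream: "POSE_LANDMARKS:pose_landmarks"
//         }
//       )pb");
//   CalculatorGraph calculator_graph;
//   MP_RETURN_IF_ERROR(calculator_graph.Initialize(std::move(config)));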

}  // namespace holistic_landmarker
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe

@@ -0,0 +1,595 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <array>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "file/base/helpers.h"
#include "file/base/options.h"
#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_connections.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_connections.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.pb.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_connections.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/data_renderer.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"
#include "testing/base/public/gmock.h"
#include "testing/base/public/googletest.h"
#include "testing/base/public/gunit.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
namespace {

using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::core::TaskRunner;
using ::testing::TestParamInfo;
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;

constexpr float kAbsMargin = 0.025;
constexpr absl::string_view kTestDataDirectory =
    "/mediapipe/tasks/testdata/vision/";
constexpr char kHolisticResultFile[] =
    "male_full_height_hands_result_cpu.pbtxt";
constexpr absl::string_view kTestImageFile = "male_full_height_hands.jpg";
constexpr absl::string_view kImageInStream = "image_in";
constexpr absl::string_view kLeftHandLandmarksStream = "left_hand_landmarks";
constexpr absl::string_view kRightHandLandmarksStream = "right_hand_landmarks";
constexpr absl::string_view kFaceLandmarksStream = "face_landmarks";
constexpr absl::string_view kFaceBlendshapesStream = "face_blendshapes";
constexpr absl::string_view kPoseLandmarksStream = "pose_landmarks";
constexpr absl::string_view kRenderedImageOutStream = "rendered_image_out";
constexpr absl::string_view kPoseSegmentationMaskStream =
    "pose_segmentation_mask";
constexpr absl::string_view kHolisticLandmarkerModelBundleFile =
    "holistic_landmarker.task";
constexpr absl::string_view kHandLandmarksModelFile =
    "hand_landmark_full.tflite";
constexpr absl::string_view kHandRoiRefinementModelFile =
    "handrecrop_2020_07_21_v0.f16.tflite";
constexpr absl::string_view kPoseDetectionModelFile = "pose_detection.tflite";
constexpr absl::string_view kPoseLandmarksModelFile =
    "pose_landmark_lite.tflite";
constexpr absl::string_view kFaceDetectionModelFile =
    "face_detection_short_range.tflite";
constexpr absl::string_view kFaceLandmarksModelFile =
    "facemesh2_lite_iris_faceflag_2023_02_14.tflite";
constexpr absl::string_view kFaceBlendshapesModelFile =
    "face_blendshapes.tflite";

enum RenderPart {
  HAND = 0,
  POSE = 1,
  FACE = 2,
};

mediapipe::Color GetColor(RenderPart render_part) {
  mediapipe::Color color;
  switch (render_part) {
    case HAND:
      color.set_b(255);
      color.set_g(255);
      color.set_r(255);
      break;
    case POSE:
      color.set_b(0);
      color.set_g(255);
      color.set_r(0);
      break;
    case FACE:
      color.set_b(0);
      color.set_g(0);
      color.set_r(255);
      break;
  }
  return color;
}

std::string GetFilePath(absl::string_view filename) {
  return file::JoinPath("./", kTestDataDirectory, filename);
}

template <std::size_t N>
mediapipe::LandmarksToRenderDataCalculatorOptions GetRendererOptions(
    const std::array<std::array<int, 2>, N>& connections,
    mediapipe::Color color) {
  mediapipe::LandmarksToRenderDataCalculatorOptions renderer_options;
  for (const auto& connection : connections) {
    renderer_options.add_landmark_connections(connection[0]);
    renderer_options.add_landmark_connections(connection[1]);
  }
  *renderer_options.mutable_landmark_color() = color;
  *renderer_options.mutable_connection_color() = color;
  renderer_options.set_thickness(0.5);
  renderer_options.set_visualize_landmark_depth(false);
  return renderer_options;
}

void ConfigureHandProtoOptions(proto::HolisticLandmarkerGraphOptions& options) {
  options.mutable_hand_landmarks_detector_graph_options()
      ->mutable_base_options()
      ->mutable_model_asset()
      ->set_file_name(GetFilePath(kHandLandmarksModelFile));

  options.mutable_hand_roi_refinement_graph_options()
      ->mutable_base_options()
      ->mutable_model_asset()
      ->set_file_name(GetFilePath(kHandRoiRefinementModelFile));
}

void ConfigureFaceProtoOptions(proto::HolisticLandmarkerGraphOptions& options) {
  // Set face detection model.
  face_detector::proto::FaceDetectorGraphOptions& face_detector_graph_options =
      *options.mutable_face_detector_graph_options();
  face_detector_graph_options.mutable_base_options()
      ->mutable_model_asset()
      ->set_file_name(GetFilePath(kFaceDetectionModelFile));
  face_detector_graph_options.set_num_faces(1);

  // Set face landmarks model.
  face_landmarker::proto::FaceLandmarksDetectorGraphOptions&
      face_landmarks_graph_options =
|
||||
*options.mutable_face_landmarks_detector_graph_options();
|
||||
face_landmarks_graph_options.mutable_base_options()
|
||||
->mutable_model_asset()
|
||||
->set_file_name(GetFilePath(kFaceLandmarksModelFile));
|
||||
face_landmarks_graph_options.mutable_face_blendshapes_graph_options()
|
||||
->mutable_base_options()
|
||||
->mutable_model_asset()
|
||||
->set_file_name(GetFilePath(kFaceBlendshapesModelFile));
|
||||
}
|
||||
|
||||
void ConfigurePoseProtoOptions(proto::HolisticLandmarkerGraphOptions& options) {
|
||||
pose_detector::proto::PoseDetectorGraphOptions& pose_detector_graph_options =
|
||||
*options.mutable_pose_detector_graph_options();
|
||||
pose_detector_graph_options.mutable_base_options()
|
||||
->mutable_model_asset()
|
||||
->set_file_name(GetFilePath(kPoseDetectionModelFile));
|
||||
pose_detector_graph_options.set_num_poses(1);
|
||||
options.mutable_pose_landmarks_detector_graph_options()
|
||||
->mutable_base_options()
|
||||
->mutable_model_asset()
|
||||
->set_file_name(GetFilePath(kPoseLandmarksModelFile));
|
||||
}
|
||||
|
||||
struct HolisticRequest {
|
||||
bool is_left_hand_requested = false;
|
||||
bool is_right_hand_requested = false;
|
||||
bool is_face_requested = false;
|
||||
bool is_face_blendshapes_requested = false;
|
||||
};
|
||||
|
||||
// Helper function to create a TaskRunner.
|
||||
absl::StatusOr<std::unique_ptr<tasks::core::TaskRunner>> CreateTaskRunner(
|
||||
bool use_model_bundle, HolisticRequest holistic_request) {
|
||||
Graph graph;
|
||||
|
||||
Stream<Image> image = graph.In("IMAEG").Cast<Image>().SetName(kImageInStream);
|
||||

  auto& holistic_graph = graph.AddNode(
      "mediapipe.tasks.vision.holistic_landmarker.HolisticLandmarkerGraph");
  proto::HolisticLandmarkerGraphOptions& options =
      holistic_graph.GetOptions<proto::HolisticLandmarkerGraphOptions>();
  if (use_model_bundle) {
    options.mutable_base_options()->mutable_model_asset()->set_file_name(
        GetFilePath(kHolisticLandmarkerModelBundleFile));
  } else {
    ConfigureHandProtoOptions(options);
    ConfigurePoseProtoOptions(options);
    ConfigureFaceProtoOptions(options);
  }

  std::vector<Stream<mediapipe::RenderData>> render_list;
  image >> holistic_graph.In("IMAGE");
  Stream<std::pair<int, int>> image_size = GetImageSize(image, graph);

  if (holistic_request.is_left_hand_requested) {
    Stream<NormalizedLandmarkList> left_hand_landmarks =
        holistic_graph.Out("LEFT_HAND_LANDMARKS")
            .Cast<NormalizedLandmarkList>()
            .SetName(kLeftHandLandmarksStream);
    Stream<NormalizedRect> left_hand_tracking_roi =
        holistic_graph.Out("LEFT_HAND_TRACKING_ROI").Cast<NormalizedRect>();
    auto left_hand_landmarks_render_data = utils::RenderLandmarks(
        left_hand_landmarks,
        utils::GetRenderScale(image_size, left_hand_tracking_roi, 0.0001,
                              graph),
        GetRendererOptions(hand_landmarker::kHandConnections,
                           GetColor(RenderPart::HAND)),
        graph);
    render_list.push_back(left_hand_landmarks_render_data);
    left_hand_landmarks >> graph.Out("LEFT_HAND_LANDMARKS");
  }
  if (holistic_request.is_right_hand_requested) {
    Stream<NormalizedLandmarkList> right_hand_landmarks =
        holistic_graph.Out("RIGHT_HAND_LANDMARKS")
            .Cast<NormalizedLandmarkList>()
            .SetName(kRightHandLandmarksStream);
    Stream<NormalizedRect> right_hand_tracking_roi =
        holistic_graph.Out("RIGHT_HAND_TRACKING_ROI").Cast<NormalizedRect>();
    auto right_hand_landmarks_render_data = utils::RenderLandmarks(
        right_hand_landmarks,
        utils::GetRenderScale(image_size, right_hand_tracking_roi, 0.0001,
                              graph),
        GetRendererOptions(hand_landmarker::kHandConnections,
                           GetColor(RenderPart::HAND)),
        graph);
    render_list.push_back(right_hand_landmarks_render_data);
    right_hand_landmarks >> graph.Out("RIGHT_HAND_LANDMARKS");
  }
  if (holistic_request.is_face_requested) {
    Stream<NormalizedLandmarkList> face_landmarks =
        holistic_graph.Out("FACE_LANDMARKS")
            .Cast<NormalizedLandmarkList>()
            .SetName(kFaceLandmarksStream);
    Stream<NormalizedRect> face_tracking_roi =
        holistic_graph.Out("FACE_TRACKING_ROI").Cast<NormalizedRect>();
    auto face_landmarks_render_data = utils::RenderLandmarks(
        face_landmarks,
        utils::GetRenderScale(image_size, face_tracking_roi, 0.0001, graph),
        GetRendererOptions(
            face_landmarker::FaceLandmarksConnections::kFaceLandmarksConnectors,
            GetColor(RenderPart::FACE)),
        graph);
    render_list.push_back(face_landmarks_render_data);
    face_landmarks >> graph.Out("FACE_LANDMARKS");
  }
  if (holistic_request.is_face_blendshapes_requested) {
    Stream<ClassificationList> face_blendshapes =
        holistic_graph.Out("FACE_BLENDSHAPES")
            .Cast<ClassificationList>()
            .SetName(kFaceBlendshapesStream);
    face_blendshapes >> graph.Out("FACE_BLENDSHAPES");
  }
  Stream<NormalizedLandmarkList> pose_landmarks =
      holistic_graph.Out("POSE_LANDMARKS")
          .Cast<NormalizedLandmarkList>()
          .SetName(kPoseLandmarksStream);
  Stream<NormalizedRect> pose_tracking_roi =
      holistic_graph.Out("POSE_LANDMARKS_ROI").Cast<NormalizedRect>();
  Stream<Image> pose_segmentation_mask =
      holistic_graph.Out("POSE_SEGMENTATION_MASK")
          .Cast<Image>()
          .SetName(kPoseSegmentationMaskStream);

  auto pose_landmarks_render_data = utils::RenderLandmarks(
      pose_landmarks,
      utils::GetRenderScale(image_size, pose_tracking_roi, 0.0001, graph),
      GetRendererOptions(pose_landmarker::kPoseLandmarksConnections,
                         GetColor(RenderPart::POSE)),
      graph);
  render_list.push_back(pose_landmarks_render_data);
  auto rendered_image =
      utils::Render(
          image, absl::Span<Stream<mediapipe::RenderData>>(render_list), graph)
          .SetName(kRenderedImageOutStream);

  pose_landmarks >> graph.Out("POSE_LANDMARKS");
  pose_segmentation_mask >> graph.Out("POSE_SEGMENTATION_MASK");
  rendered_image >> graph.Out("RENDERED_IMAGE");

  auto config = graph.GetConfig();
  core::FixGraphBackEdges(config);

  return TaskRunner::Create(
      config, std::make_unique<core::MediaPipeBuiltinOpResolver>());
}

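// Returns the typed contents of the packet that `stream_name` produced; fails
// if the output map contains no packet for that stream.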
template <typename T>
absl::StatusOr<T> FetchResult(const core::PacketMap& output_packets,
                              absl::string_view stream_name) {
  auto it = output_packets.find(std::string(stream_name));
  RET_CHECK(it != output_packets.end());
  return it->second.Get<T>();
}

// Remove fields not to be checked in the result, since the model
// generating expected result is different from the testing model.
void RemoveUncheckedResult(proto::HolisticResult& holistic_result) {
  for (auto& landmark :
       *holistic_result.mutable_pose_landmarks()->mutable_landmark()) {
    landmark.clear_z();
    landmark.clear_visibility();
    landmark.clear_presence();
  }
  for (auto& landmark :
       *holistic_result.mutable_face_landmarks()->mutable_landmark()) {
    landmark.clear_z();
    landmark.clear_visibility();
    landmark.clear_presence();
  }
  for (auto& landmark :
       *holistic_result.mutable_left_hand_landmarks()->mutable_landmark()) {
    landmark.clear_z();
    landmark.clear_visibility();
    landmark.clear_presence();
  }
  for (auto& landmark :
       *holistic_result.mutable_right_hand_landmarks()->mutable_landmark()) {
    landmark.clear_z();
    landmark.clear_visibility();
    landmark.clear_presence();
  }
}

std::string RequestToString(HolisticRequest request) {
  return absl::StrFormat(
      "%s_%s_%s_%s",
      request.is_left_hand_requested ? "left_hand" : "no_left_hand",
      request.is_right_hand_requested ? "right_hand" : "no_right_hand",
      request.is_face_requested ? "face" : "no_face",
      request.is_face_blendshapes_requested ? "face_blendshapes"
                                            : "no_face_blendshapes");
}

struct TestParams {
  // The name of this test, for convenience when displaying test results.
  std::string test_name;
  // The filename of the test image.
  std::string test_image_name;
  // Whether to use the holistic model bundle in the test.
  bool use_model_bundle;
  // Requests of holistic parts.
  HolisticRequest holistic_request;
};

class SmokeTest : public testing::TestWithParam<TestParams> {};

TEST_P(SmokeTest, Succeeds) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(GetFilePath(GetParam().test_image_name)));

  proto::HolisticResult holistic_result;
  MP_ASSERT_OK(GetTextProto(GetFilePath(kHolisticResultFile), &holistic_result,
                            ::file::Defaults()));
  RemoveUncheckedResult(holistic_result);

  MP_ASSERT_OK_AND_ASSIGN(auto task_runner,
                          CreateTaskRunner(GetParam().use_model_bundle,
                                           GetParam().holistic_request));
  MP_ASSERT_OK_AND_ASSIGN(auto output_packets,
                          task_runner->Process({{std::string(kImageInStream),
                                                 MakePacket<Image>(image)}}));

  // Check face landmarks.
  if (GetParam().holistic_request.is_face_requested) {
    MP_ASSERT_OK_AND_ASSIGN(auto face_landmarks,
                            FetchResult<NormalizedLandmarkList>(
                                output_packets, kFaceLandmarksStream));
    EXPECT_THAT(
        face_landmarks,
        Approximately(Partially(EqualsProto(holistic_result.face_landmarks())),
                      /*margin=*/kAbsMargin));
  } else {
    ASSERT_FALSE(output_packets.contains(std::string(kFaceLandmarksStream)));
  }

  if (GetParam().holistic_request.is_face_blendshapes_requested) {
    MP_ASSERT_OK_AND_ASSIGN(auto face_blendshapes,
                            FetchResult<ClassificationList>(
                                output_packets, kFaceBlendshapesStream));
    EXPECT_THAT(face_blendshapes,
                Approximately(
                    Partially(EqualsProto(holistic_result.face_blendshapes())),
                    /*margin=*/kAbsMargin));
  } else {
    ASSERT_FALSE(output_packets.contains(std::string(kFaceBlendshapesStream)));
  }

  // Check pose landmarks.
  MP_ASSERT_OK_AND_ASSIGN(auto pose_landmarks,
                          FetchResult<NormalizedLandmarkList>(
                              output_packets, kPoseLandmarksStream));
  EXPECT_THAT(
      pose_landmarks,
      Approximately(Partially(EqualsProto(holistic_result.pose_landmarks())),
                    /*margin=*/kAbsMargin));

  // Check hand landmarks.
  if (GetParam().holistic_request.is_left_hand_requested) {
    MP_ASSERT_OK_AND_ASSIGN(auto left_hand_landmarks,
                            FetchResult<NormalizedLandmarkList>(
                                output_packets, kLeftHandLandmarksStream));
    EXPECT_THAT(
        left_hand_landmarks,
        Approximately(
            Partially(EqualsProto(holistic_result.left_hand_landmarks())),
            /*margin=*/kAbsMargin));
  } else {
    ASSERT_FALSE(
        output_packets.contains(std::string(kLeftHandLandmarksStream)));
  }

  if (GetParam().holistic_request.is_right_hand_requested) {
    MP_ASSERT_OK_AND_ASSIGN(auto right_hand_landmarks,
                            FetchResult<NormalizedLandmarkList>(
                                output_packets, kRightHandLandmarksStream));
    EXPECT_THAT(
        right_hand_landmarks,
        Approximately(
            Partially(EqualsProto(holistic_result.right_hand_landmarks())),
            /*margin=*/kAbsMargin));
  } else {
    ASSERT_FALSE(
        output_packets.contains(std::string(kRightHandLandmarksStream)));
  }

  auto rendered_image =
      output_packets.at(std::string(kRenderedImageOutStream)).Get<Image>();
  MP_EXPECT_OK(SavePngTestOutput(
      *rendered_image.GetImageFrameSharedPtr(),
      absl::StrCat("holistic_landmark_",
                   RequestToString(GetParam().holistic_request))));

  auto pose_segmentation_mask =
      output_packets.at(std::string(kPoseSegmentationMaskStream)).Get<Image>();

  cv::Mat matting_mask = mediapipe::formats::MatView(
      pose_segmentation_mask.GetImageFrameSharedPtr().get());
  cv::Mat visualized_mask;
  matting_mask.convertTo(visualized_mask, CV_8UC1, 255);
  ImageFrame visualized_image(mediapipe::ImageFormat::GRAY8,
                              visualized_mask.cols, visualized_mask.rows,
                              visualized_mask.step, visualized_mask.data,
                              [visualized_mask](uint8_t[]) {});
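  // The no-op deleter above captures `visualized_mask` by value so that the
  // pixel buffer stays alive for as long as `visualized_image` references it.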

  MP_EXPECT_OK(
      SavePngTestOutput(visualized_image, "holistic_pose_segmentation_mask"));
}

INSTANTIATE_TEST_SUITE_P(
    HolisticLandmarkerGraphTest, SmokeTest,
    Values(TestParams{
               /* test_name= */ "UseModelBundle",
               /* test_image_name= */ std::string(kTestImageFile),
               /* use_model_bundle= */ true,
               /* holistic_request= */
               {
                   /*is_left_hand_requested= */ true,
                   /*is_right_hand_requested= */ true,
                   /*is_face_requested= */ true,
                   /*is_face_blendshapes_requested= */ true,
               },
           },
           TestParams{
               /* test_name= */ "UseSeparateModelFiles",
               /* test_image_name= */ std::string(kTestImageFile),
               /* use_model_bundle= */ false,
               /* holistic_request= */
               {
                   /*is_left_hand_requested= */ true,
                   /*is_right_hand_requested= */ true,
                   /*is_face_requested= */ true,
                   /*is_face_blendshapes_requested= */ true,
               },
           },
           TestParams{
               /* test_name= */ "ModelBundleNoLeftHand",
               /* test_image_name= */ std::string(kTestImageFile),
               /* use_model_bundle= */ true,
               /* holistic_request= */
               {
                   /*is_left_hand_requested= */ false,
                   /*is_right_hand_requested= */ true,
                   /*is_face_requested= */ true,
                   /*is_face_blendshapes_requested= */ true,
               },
           },
           TestParams{
               /* test_name= */ "ModelBundleNoRightHand",
               /* test_image_name= */ std::string(kTestImageFile),
               /* use_model_bundle= */ true,
               /* holistic_request= */
               {
                   /*is_left_hand_requested= */ true,
                   /*is_right_hand_requested= */ false,
                   /*is_face_requested= */ true,
                   /*is_face_blendshapes_requested= */ true,
               },
           },
           TestParams{
               /* test_name= */ "ModelBundleNoHand",
               /* test_image_name= */ std::string(kTestImageFile),
               /* use_model_bundle= */ true,
               /* holistic_request= */
               {
                   /*is_left_hand_requested= */ false,
                   /*is_right_hand_requested= */ false,
                   /*is_face_requested= */ true,
                   /*is_face_blendshapes_requested= */ true,
               },
           },
           TestParams{
               /* test_name= */ "ModelBundleNoFace",
               /* test_image_name= */ std::string(kTestImageFile),
               /* use_model_bundle= */ true,
               /* holistic_request= */
               {
                   /*is_left_hand_requested= */ true,
                   /*is_right_hand_requested= */ true,
                   /*is_face_requested= */ false,
                   /*is_face_blendshapes_requested= */ false,
               },
           },
           TestParams{
               /* test_name= */ "ModelBundleNoFaceBlendshapes",
               /* test_image_name= */ std::string(kTestImageFile),
               /* use_model_bundle= */ true,
               /* holistic_request= */
               {
                   /*is_left_hand_requested= */ true,
                   /*is_right_hand_requested= */ true,
                   /*is_face_requested= */ true,
                   /*is_face_blendshapes_requested= */ false,
               },
           }),
    [](const TestParamInfo<SmokeTest::ParamType>& info) {
      return info.param.test_name;
    });

}  // namespace
}  // namespace holistic_landmarker
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe
@ -0,0 +1,307 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h"

#include <optional>
#include <utility>
#include <vector>

#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/detections_to_rects.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/api2/stream/landmarks_to_detection.h"
#include "mediapipe/framework/api2/stream/loopback.h"
#include "mediapipe/framework/api2/stream/merge.h"
#include "mediapipe/framework/api2/stream/presence.h"
#include "mediapipe/framework/api2/stream/rect_transformation.h"
#include "mediapipe/framework/api2/stream/segmentation_smoothing.h"
#include "mediapipe/framework/api2/stream/smoothing.h"
#include "mediapipe/framework/api2/stream/split.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/components/utils/gate.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {

namespace {

using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::builder::ConvertAlignmentPointsDetectionsToRect;
using ::mediapipe::api2::builder::ConvertAlignmentPointsDetectionToRect;
using ::mediapipe::api2::builder::ConvertLandmarksToDetection;
using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::GetLoopbackData;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::IsPresent;
using ::mediapipe::api2::builder::Merge;
using ::mediapipe::api2::builder::ScaleAndMakeSquare;
using ::mediapipe::api2::builder::SmoothLandmarks;
using ::mediapipe::api2::builder::SmoothLandmarksVisibility;
using ::mediapipe::api2::builder::SmoothSegmentationMask;
using ::mediapipe::api2::builder::SplitToRanges;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::components::utils::DisallowIf;
using Size = std::pair<int, int>;

constexpr int kAuxLandmarksStartKeypointIndex = 0;
constexpr int kAuxLandmarksEndKeypointIndex = 1;
constexpr float kAuxLandmarksTargetAngle = 90;
constexpr float kRoiFromDetectionScaleFactor = 1.25f;
constexpr float kRoiFromLandmarksScaleFactor = 1.25f;

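// Converts pose detections into a tracking ROI: the detection's alignment
// keypoints (indices 0 and 1) define a rotated rect, which is then scaled by
// kRoiFromDetectionScaleFactor and made square.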
Stream<NormalizedRect> CalculateRoiFromDetections(
    Stream<std::vector<Detection>> detections, Stream<Size> image_size,
    Graph& graph) {
  auto roi = ConvertAlignmentPointsDetectionsToRect(detections, image_size,
                                                    /*start_keypoint_index=*/0,
                                                    /*end_keypoint_index=*/1,
                                                    /*target_angle=*/90, graph);
  return ScaleAndMakeSquare(
      roi, image_size, /*scale_x_factor=*/kRoiFromDetectionScaleFactor,
      /*scale_y_factor=*/kRoiFromDetectionScaleFactor, graph);
}

Stream<NormalizedRect> CalculateScaleRoiFromAuxiliaryLandmarks(
    Stream<NormalizedLandmarkList> landmarks, Stream<Size> image_size,
    Graph& graph) {
  // TODO: consider calculating ROI directly from landmarks.
  auto detection = ConvertLandmarksToDetection(landmarks, graph);
  return ConvertAlignmentPointsDetectionToRect(
      detection, image_size, kAuxLandmarksStartKeypointIndex,
      kAuxLandmarksEndKeypointIndex, kAuxLandmarksTargetAngle, graph);
}

Stream<NormalizedRect> CalculateRoiFromAuxiliaryLandmarks(
    Stream<NormalizedLandmarkList> landmarks, Stream<Size> image_size,
    Graph& graph) {
  // TODO: consider calculating ROI directly from landmarks.
  auto detection = ConvertLandmarksToDetection(landmarks, graph);
  auto roi = ConvertAlignmentPointsDetectionToRect(
      detection, image_size, kAuxLandmarksStartKeypointIndex,
      kAuxLandmarksEndKeypointIndex, kAuxLandmarksTargetAngle, graph);
  return ScaleAndMakeSquare(
      roi, image_size, /*scale_x_factor=*/kRoiFromLandmarksScaleFactor,
      /*scale_y_factor=*/kRoiFromLandmarksScaleFactor, graph);
}

struct PoseLandmarksResult {
  std::optional<Stream<NormalizedLandmarkList>> landmarks;
  std::optional<Stream<LandmarkList>> world_landmarks;
  std::optional<Stream<NormalizedLandmarkList>> auxiliary_landmarks;
  std::optional<Stream<Image>> segmentation_mask;
};

PoseLandmarksResult RunLandmarksDetection(
    Stream<Image> image, Stream<NormalizedRect> roi,
    const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions&
        pose_landmarks_detector_graph_options,
    const HolisticPoseTrackingRequest& request, Graph& graph) {
  GenericNode& landmarks_graph = graph.AddNode(
      "mediapipe.tasks.vision.pose_landmarker."
      "SinglePoseLandmarksDetectorGraph");
  landmarks_graph
      .GetOptions<pose_landmarker::proto::PoseLandmarksDetectorGraphOptions>() =
      pose_landmarks_detector_graph_options;
  image >> landmarks_graph.In("IMAGE");
  roi >> landmarks_graph.In("NORM_RECT");

  PoseLandmarksResult result;
  if (request.landmarks) {
    result.landmarks =
        landmarks_graph.Out("LANDMARKS").Cast<NormalizedLandmarkList>();
    result.auxiliary_landmarks = landmarks_graph.Out("AUXILIARY_LANDMARKS")
                                     .Cast<NormalizedLandmarkList>();
  }
  if (request.world_landmarks) {
    result.world_landmarks =
        landmarks_graph.Out("WORLD_LANDMARKS").Cast<LandmarkList>();
  }
  if (request.segmentation_mask) {
    result.segmentation_mask =
        landmarks_graph.Out("SEGMENTATION_MASK").Cast<Image>();
  }
  return result;
}

}  // namespace

absl::StatusOr<HolisticPoseTrackingOutput>
TrackHolisticPoseUsingCustomPoseDetection(
    Stream<Image> image, PoseDetectionFn pose_detection_fn,
    const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions&
        pose_landmarks_detector_graph_options,
    const HolisticPoseTrackingRequest& request, Graph& graph) {
  // Calculate ROI from scratch (pose detection) or reuse one from the
  // previous run if available.
  auto [previous_roi, set_previous_roi_fn] =
      GetLoopbackData<NormalizedRect>(/*tick=*/image, graph);
  auto is_previous_roi_available = IsPresent(previous_roi, graph);
  auto image_for_detection =
      DisallowIf(image, is_previous_roi_available, graph);
  MP_ASSIGN_OR_RETURN(auto pose_detections,
                      pose_detection_fn(image_for_detection, graph));
  auto roi_from_detections = CalculateRoiFromDetections(
      pose_detections, GetImageSize(image_for_detection, graph), graph);
  // Take first non-empty.
  auto roi = Merge(roi_from_detections, previous_roi, graph);

  // Calculate landmarks and other outputs (if requested) in the specified ROI.
  auto landmarks_detection_result = RunLandmarksDetection(
      image, roi, pose_landmarks_detector_graph_options,
      {
          // Landmarks are required for tracking, hence force-requesting them.
          .landmarks = true,
          .world_landmarks = request.world_landmarks,
          .segmentation_mask = request.segmentation_mask,
      },
      graph);
  RET_CHECK(landmarks_detection_result.landmarks.has_value() &&
            landmarks_detection_result.auxiliary_landmarks.has_value())
      << "Failed to calculate landmarks required for tracking.";

  // Split landmarks to pose landmarks and auxiliary landmarks.
  auto pose_landmarks_raw = *landmarks_detection_result.landmarks;
  auto auxiliary_landmarks = *landmarks_detection_result.auxiliary_landmarks;

  auto image_size = GetImageSize(image, graph);

  // TODO: b/305750053 - Apply adaptive crop by adding AdaptiveCropCalculator.

  // Calculate ROI from smoothed auxiliary landmarks.
  auto scale_roi = CalculateScaleRoiFromAuxiliaryLandmarks(auxiliary_landmarks,
                                                           image_size, graph);
  auto auxiliary_landmarks_smoothed = SmoothLandmarks(
      auxiliary_landmarks, image_size, scale_roi,
      {// Min cutoff 0.01 results in ~0.002 alpha in landmark EMA filter when
       // landmark is static.
       .min_cutoff = 0.01,
       // Beta 10.0 in combination with min_cutoff 0.01 results in ~0.68
       // alpha in landmark EMA filter when landmark is moving fast.
       .beta = 10.0,
       // Derivative cutoff 1.0 results in ~0.17 alpha in landmark velocity
       // EMA filter.
       .derivate_cutoff = 1.0},
      graph);
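  // For reference (assuming ~30 fps input, an assumption not stated here): the
  // One Euro filter uses cutoff fc = min_cutoff + beta * |velocity| and EMA
  // factor alpha = 1 / (1 + tau / Te), with tau = 1 / (2 * pi * fc) and
  // Te = 1 / fps, which yields the alpha values quoted in the comments above.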
  auto roi_from_auxiliary_landmarks = CalculateRoiFromAuxiliaryLandmarks(
      auxiliary_landmarks_smoothed, image_size, graph);

  // Make ROI from auxiliary landmarks to be used as "previous" ROI for a
  // subsequent run.
  set_previous_roi_fn(roi_from_auxiliary_landmarks);

  // Populate and smooth pose landmarks if corresponding output has been
  // requested.
  std::optional<Stream<NormalizedLandmarkList>> pose_landmarks;
  if (request.landmarks) {
    pose_landmarks = SmoothLandmarksVisibility(
        pose_landmarks_raw, /*low_pass_filter_alpha=*/0.1f, graph);
    pose_landmarks = SmoothLandmarks(
        *pose_landmarks, image_size, scale_roi,
        {// Min cutoff 0.05 results in ~0.01 alpha in landmark EMA filter when
         // landmark is static.
         .min_cutoff = 0.05f,
         // Beta 80.0 in combination with min_cutoff 0.05 results in ~0.94
         // alpha in landmark EMA filter when landmark is moving fast.
         .beta = 80.0f,
         // Derivative cutoff 1.0 results in ~0.17 alpha in landmark velocity
         // EMA filter.
         .derivate_cutoff = 1.0f},
        graph);
  }

  // Populate and smooth world landmarks if available.
  std::optional<Stream<LandmarkList>> world_landmarks;
  if (landmarks_detection_result.world_landmarks) {
    world_landmarks = SplitToRanges(*landmarks_detection_result.world_landmarks,
                                    /*ranges*/ {{0, 33}}, graph)[0];
    world_landmarks = SmoothLandmarksVisibility(
        *world_landmarks, /*low_pass_filter_alpha=*/0.1f, graph);
    world_landmarks = SmoothLandmarks(
        *world_landmarks,
        /*scale_roi=*/std::nullopt,
        {// Min cutoff 0.1 results in ~0.02 alpha in landmark EMA filter when
         // landmark is static.
         .min_cutoff = 0.1f,
         // Beta 40.0 in combination with min_cutoff 0.1 results in ~0.8
         // alpha in landmark EMA filter when landmark is moving fast.
         .beta = 40.0f,
         // Derivative cutoff 1.0 results in ~0.17 alpha in landmark velocity
         // EMA filter.
         .derivate_cutoff = 1.0f},
        graph);
  }

  // Populate and smooth segmentation mask if available.
  std::optional<Stream<Image>> segmentation_mask;
  if (landmarks_detection_result.segmentation_mask) {
    auto mask = *landmarks_detection_result.segmentation_mask;
    auto [prev_mask_as_img, set_prev_mask_as_img_fn] =
        GetLoopbackData<mediapipe::Image>(
            /*tick=*/*landmarks_detection_result.segmentation_mask, graph);
    auto mask_smoothed =
        SmoothSegmentationMask(mask, prev_mask_as_img,
                               /*combine_with_previous_ratio=*/0.7f, graph);
    set_prev_mask_as_img_fn(mask_smoothed);
    segmentation_mask = mask_smoothed;
  }
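  // The smoothed mask is looped back as the "previous" mask for the next
  // frame, so each frame's mask is blended with its predecessor
  // (combine_with_previous_ratio of 0.7), smoothing the mask over time.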

  return {{/*landmarks=*/pose_landmarks,
           /*world_landmarks=*/world_landmarks,
           /*segmentation_mask=*/segmentation_mask,
           /*debug_output=*/
           {/*auxiliary_landmarks=*/auxiliary_landmarks_smoothed,
            /*roi_from_landmarks=*/roi_from_auxiliary_landmarks,
            /*detections*/ pose_detections}}};
}

absl::StatusOr<HolisticPoseTrackingOutput> TrackHolisticPose(
    Stream<Image> image,
    const pose_detector::proto::PoseDetectorGraphOptions&
        pose_detector_graph_options,
    const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions&
        pose_landmarks_detector_graph_options,
    const HolisticPoseTrackingRequest& request, Graph& graph) {
  PoseDetectionFn pose_detection_fn = [&pose_detector_graph_options](
                                          Stream<Image> image, Graph& graph)
      -> absl::StatusOr<Stream<std::vector<mediapipe::Detection>>> {
    GenericNode& pose_detector =
        graph.AddNode("mediapipe.tasks.vision.pose_detector.PoseDetectorGraph");
    pose_detector.GetOptions<pose_detector::proto::PoseDetectorGraphOptions>() =
        pose_detector_graph_options;
    image >> pose_detector.In("IMAGE");
    return pose_detector.Out("DETECTIONS")
        .Cast<std::vector<mediapipe::Detection>>();
  };
  return TrackHolisticPoseUsingCustomPoseDetection(
      image, pose_detection_fn, pose_landmarks_detector_graph_options, request,
      graph);
}

}  // namespace holistic_landmarker
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe
@ -0,0 +1,110 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_POSE_TRACKING_H_
#define MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_POSE_TRACKING_H_

#include <functional>
#include <optional>
#include <vector>

#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {

// Type of pose detection function that can be used to customize pose tracking,
// by supplying the function to the corresponding
// `TrackHolisticPoseUsingCustomPoseDetection` overload.
//
// The function should update the provided graph with a node (or nodes) that
// accepts the image stream and produces a stream of detections.
using PoseDetectionFn = std::function<
    absl::StatusOr<api2::builder::Stream<std::vector<mediapipe::Detection>>>(
        api2::builder::Stream<Image>, api2::builder::Graph&)>;
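
// A minimal sketch of a custom detection function (the subgraph name below is
// illustrative, not an existing node):
//
//   PoseDetectionFn detection_fn =
//       [](api2::builder::Stream<Image> image, api2::builder::Graph& graph)
//       -> absl::StatusOr<
//           api2::builder::Stream<std::vector<mediapipe::Detection>>> {
//     auto& detector = graph.AddNode("YourCustomPoseDetectorGraph");
//     image >> detector.In("IMAGE");
//     return detector.Out("DETECTIONS")
//         .Cast<std::vector<mediapipe::Detection>>();
//   };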

struct HolisticPoseTrackingRequest {
  bool landmarks = false;
  bool world_landmarks = false;
  bool segmentation_mask = false;
};
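
// Note that even when `landmarks` is false, normalized landmarks are still
// computed internally to drive ROI tracking; they are just not returned.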

struct HolisticPoseTrackingOutput {
  std::optional<api2::builder::Stream<mediapipe::NormalizedLandmarkList>>
      landmarks;
  std::optional<api2::builder::Stream<mediapipe::LandmarkList>> world_landmarks;
  std::optional<api2::builder::Stream<Image>> segmentation_mask;

  struct DebugOutput {
    api2::builder::Stream<mediapipe::NormalizedLandmarkList>
        auxiliary_landmarks;
    api2::builder::Stream<NormalizedRect> roi_from_landmarks;
    api2::builder::Stream<std::vector<mediapipe::Detection>> detections;
  };

  DebugOutput debug_output;
};

// Updates @graph to track pose in @image.
//
// @image - ImageFrame/GpuBuffer to track pose in.
// @pose_detection_fn - pose detection function that takes @image as input and
//   produces a stream of pose detections.
// @pose_landmarks_detector_graph_options - options of the
//   PoseLandmarksDetectorGraph used to detect the pose landmarks.
// @request - object to request specific pose tracking outputs.
//   NOTE: Outputs that were not requested won't be returned and corresponding
//   parts of the graph won't be generated at all.
// @graph - graph to update.
absl::StatusOr<HolisticPoseTrackingOutput>
TrackHolisticPoseUsingCustomPoseDetection(
    api2::builder::Stream<Image> image, PoseDetectionFn pose_detection_fn,
    const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions&
        pose_landmarks_detector_graph_options,
    const HolisticPoseTrackingRequest& request, api2::builder::Graph& graph);

// Updates @graph to track pose in @image.
//
// @image - ImageFrame/GpuBuffer to track pose in.
// @pose_detector_graph_options - options of the PoseDetectorGraph used to
//   detect the pose.
// @pose_landmarks_detector_graph_options - options of the
//   PoseLandmarksDetectorGraph used to detect the pose landmarks.
// @request - object to request specific pose tracking outputs.
//   NOTE: Outputs that were not requested won't be returned and corresponding
//   parts of the graph won't be generated at all.
// @graph - graph to update.
absl::StatusOr<HolisticPoseTrackingOutput> TrackHolisticPose(
    api2::builder::Stream<Image> image,
    const pose_detector::proto::PoseDetectorGraphOptions&
        pose_detector_graph_options,
    const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions&
        pose_landmarks_detector_graph_options,
    const HolisticPoseTrackingRequest& request, api2::builder::Graph& graph);

}  // namespace holistic_landmarker
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe

#endif  // MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_POSE_TRACKING_H_
@ -0,0 +1,243 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h"

#include <memory>
#include <string>
#include <vector>

#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "absl/types/span.h"
#include "file/base/helpers.h"
#include "file/base/options.h"
#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.pb.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_connections.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/data_renderer.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"
#include "testing/base/public/googletest.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {

namespace {

using ::file::Defaults;
using ::file::GetTextProto;
using ::mediapipe::Image;
using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::core::TaskRunner;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;

constexpr float kAbsMargin = 0.025;
constexpr absl::string_view kTestDataDirectory =
    "/mediapipe/tasks/testdata/vision/";
constexpr absl::string_view kTestImageFile = "male_full_height_hands.jpg";
constexpr absl::string_view kImageInStream = "image_in";
constexpr absl::string_view kPoseLandmarksOutStream = "pose_landmarks_out";
constexpr absl::string_view kPoseWorldLandmarksOutStream =
    "pose_world_landmarks_out";
constexpr absl::string_view kRenderedImageOutStream = "rendered_image_out";
constexpr absl::string_view kHolisticResultFile =
    "male_full_height_hands_result_cpu.pbtxt";
constexpr absl::string_view kHolisticPoseTrackingGraph =
    "holistic_pose_tracking_graph.pbtxt";

std::string GetFilePath(absl::string_view filename) {
  return file::JoinPath("./", kTestDataDirectory, filename);
}

mediapipe::LandmarksToRenderDataCalculatorOptions GetPoseRendererOptions() {
  mediapipe::LandmarksToRenderDataCalculatorOptions renderer_options;
  for (const auto& connection : pose_landmarker::kPoseLandmarksConnections) {
    renderer_options.add_landmark_connections(connection[0]);
    renderer_options.add_landmark_connections(connection[1]);
  }
  renderer_options.mutable_landmark_color()->set_r(255);
  renderer_options.mutable_landmark_color()->set_g(255);
  renderer_options.mutable_landmark_color()->set_b(255);
  renderer_options.mutable_connection_color()->set_r(255);
  renderer_options.mutable_connection_color()->set_g(255);
  renderer_options.mutable_connection_color()->set_b(255);
  renderer_options.set_thickness(0.5);
  renderer_options.set_visualize_landmark_depth(false);
  return renderer_options;
}

// Helper function to create a TaskRunner.
absl::StatusOr<std::unique_ptr<tasks::core::TaskRunner>> CreateTaskRunner() {
  Graph graph;
  Stream<Image> image = graph.In("IMAGE").Cast<Image>().SetName(kImageInStream);
  pose_detector::proto::PoseDetectorGraphOptions pose_detector_graph_options;
  pose_detector_graph_options.mutable_base_options()
      ->mutable_model_asset()
      ->set_file_name(GetFilePath("pose_detection.tflite"));
  pose_detector_graph_options.set_num_poses(1);
  pose_landmarker::proto::PoseLandmarksDetectorGraphOptions
      pose_landmarks_detector_graph_options;
  pose_landmarks_detector_graph_options.mutable_base_options()
      ->mutable_model_asset()
      ->set_file_name(GetFilePath("pose_landmark_lite.tflite"));

  HolisticPoseTrackingRequest request;
  request.landmarks = true;
  request.world_landmarks = true;
  MP_ASSIGN_OR_RETURN(
      HolisticPoseTrackingOutput result,
      TrackHolisticPose(image, pose_detector_graph_options,
                        pose_landmarks_detector_graph_options, request, graph));

  auto image_size = GetImageSize(image, graph);
  auto render_data = utils::RenderLandmarks(
      *result.landmarks,
      utils::GetRenderScale(image_size, result.debug_output.roi_from_landmarks,
                            0.0001, graph),
      GetPoseRendererOptions(), graph);
  std::vector<Stream<mediapipe::RenderData>> render_list = {render_data};
  auto rendered_image =
      utils::Render(
          image, absl::Span<Stream<mediapipe::RenderData>>(render_list), graph)
          .SetName(kRenderedImageOutStream);

  rendered_image >> graph.Out("RENDERED_IMAGE");
  result.landmarks->SetName(kPoseLandmarksOutStream) >>
      graph.Out("POSE_LANDMARKS");
  result.world_landmarks->SetName(kPoseWorldLandmarksOutStream) >>
      graph.Out("POSE_WORLD_LANDMARKS");

  auto config = graph.GetConfig();
  core::FixGraphBackEdges(config);

  return TaskRunner::Create(
      config, std::make_unique<core::MediaPipeBuiltinOpResolver>());
}

// Remove fields not to be checked in the result, since the model
// generating expected result is different from the testing model.
void RemoveUncheckedResult(proto::HolisticResult& holistic_result) {
  for (auto& landmark :
       *holistic_result.mutable_pose_landmarks()->mutable_landmark()) {
    landmark.clear_z();
    landmark.clear_visibility();
    landmark.clear_presence();
  }
}

class HolisticPoseTrackingTest : public testing::Test {};

TEST_F(HolisticPoseTrackingTest, VerifyGraph) {
  Graph graph;
  Stream<Image> image = graph.In("IMAGE").Cast<Image>().SetName(kImageInStream);
  pose_detector::proto::PoseDetectorGraphOptions pose_detector_graph_options;
  pose_detector_graph_options.mutable_base_options()
      ->mutable_model_asset()
      ->set_file_name(GetFilePath("pose_detection.tflite"));
  pose_detector_graph_options.set_num_poses(1);
  pose_landmarker::proto::PoseLandmarksDetectorGraphOptions
      pose_landmarks_detector_graph_options;
  pose_landmarks_detector_graph_options.mutable_base_options()
      ->mutable_model_asset()
      ->set_file_name(GetFilePath("pose_landmark_lite.tflite"));
  HolisticPoseTrackingRequest request;
  request.landmarks = true;
  request.world_landmarks = true;
  MP_ASSERT_OK_AND_ASSIGN(
      HolisticPoseTrackingOutput result,
      TrackHolisticPose(image, pose_detector_graph_options,
                        pose_landmarks_detector_graph_options, request, graph));
  result.landmarks->SetName(kPoseLandmarksOutStream) >>
      graph.Out("POSE_LANDMARKS");
  result.world_landmarks->SetName(kPoseWorldLandmarksOutStream) >>
      graph.Out("POSE_WORLD_LANDMARKS");

  auto config = graph.GetConfig();
  core::FixGraphBackEdges(config);

  // Read the expected graph config.
  std::string expected_graph_contents;
  MP_ASSERT_OK(file::GetContents(
      file::JoinPath("./", kTestDataDirectory, kHolisticPoseTrackingGraph),
      &expected_graph_contents));

  // Need to replace the expected graph config with the test srcdir, because
  // each run has a different test dir on TAP.
  expected_graph_contents = absl::Substitute(
      expected_graph_contents, FLAGS_test_srcdir, FLAGS_test_srcdir);
  CalculatorGraphConfig expected_graph =
      ParseTextProtoOrDie<CalculatorGraphConfig>(expected_graph_contents);

  EXPECT_THAT(config, testing::proto::IgnoringRepeatedFieldOrdering(
                          testing::EqualsProto(expected_graph)));
}

TEST_F(HolisticPoseTrackingTest, SmokeTest) {
  MP_ASSERT_OK_AND_ASSIGN(Image image,
                          DecodeImageFromFile(GetFilePath(kTestImageFile)));

  proto::HolisticResult holistic_result;
  MP_ASSERT_OK(GetTextProto(GetFilePath(kHolisticResultFile), &holistic_result,
                            Defaults()));
  RemoveUncheckedResult(holistic_result);
  MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner());
  MP_ASSERT_OK_AND_ASSIGN(auto output_packets,
                          task_runner->Process({{std::string(kImageInStream),
                                                 MakePacket<Image>(image)}}));
  auto pose_landmarks = output_packets.at(std::string(kPoseLandmarksOutStream))
                            .Get<NormalizedLandmarkList>();
  EXPECT_THAT(
      pose_landmarks,
      Approximately(Partially(EqualsProto(holistic_result.pose_landmarks())),
                    /*margin=*/kAbsMargin));
  auto rendered_image =
      output_packets.at(std::string(kRenderedImageOutStream)).Get<Image>();
  MP_EXPECT_OK(SavePngTestOutput(*rendered_image.GetImageFrameSharedPtr(),
                                 "pose_landmarks"));
}

}  // namespace
}  // namespace holistic_landmarker
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe
44  mediapipe/tasks/cc/vision/holistic_landmarker/proto/BUILD  Normal file
@ -0,0 +1,44 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")

package(default_visibility = [
    "//mediapipe/tasks:internal",
])

licenses(["notice"])

mediapipe_proto_library(
    name = "holistic_result_proto",
    srcs = ["holistic_result.proto"],
    deps = [
        "//mediapipe/framework/formats:classification_proto",
        "//mediapipe/framework/formats:landmark_proto",
    ],
)

mediapipe_proto_library(
    name = "holistic_landmarker_graph_options_proto",
    srcs = ["holistic_landmarker_graph_options.proto"],
    deps = [
        "//mediapipe/tasks/cc/core/proto:base_options_proto",
        "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_proto",
        "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_proto",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_proto",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_roi_refinement_graph_options_proto",
        "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_proto",
        "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_proto",
    ],
)
@ -0,0 +1,57 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package mediapipe.tasks.vision.holistic_landmarker.proto;

import "mediapipe/tasks/cc/core/proto/base_options.proto";
import "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto";
import "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.proto";
import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto";
import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.proto";
import "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.proto";
import "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.proto";

option java_package = "com.google.mediapipe.tasks.vision.holisticlandmarker.proto";
option java_outer_classname = "HolisticLandmarkerGraphOptionsProto";

message HolisticLandmarkerGraphOptions {
  // Base options for configuring MediaPipe Tasks, such as specifying the model
  // asset bundle file with metadata, accelerator options, etc.
  core.proto.BaseOptions base_options = 1;

  // Options for hand landmarks graph.
  hand_landmarker.proto.HandLandmarksDetectorGraphOptions
      hand_landmarks_detector_graph_options = 2;

  // Options for hand roi refinement graph.
  hand_landmarker.proto.HandRoiRefinementGraphOptions
      hand_roi_refinement_graph_options = 3;

  // Options for face detector graph.
  face_detector.proto.FaceDetectorGraphOptions face_detector_graph_options = 4;

  // Options for face landmarks detector graph.
  face_landmarker.proto.FaceLandmarksDetectorGraphOptions
      face_landmarks_detector_graph_options = 5;

  // Options for pose detector graph.
  pose_detector.proto.PoseDetectorGraphOptions pose_detector_graph_options = 6;

  // Options for pose landmarks detector graph.
  pose_landmarker.proto.PoseLandmarksDetectorGraphOptions
      pose_landmarks_detector_graph_options = 7;
}
@ -0,0 +1,34 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package mediapipe.tasks.vision.holistic_landmarker.proto;

import "mediapipe/framework/formats/classification.proto";
import "mediapipe/framework/formats/landmark.proto";

option java_package = "com.google.mediapipe.tasks.vision.holisticlandmarker";
option java_outer_classname = "HolisticResultProto";

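// Holistic landmarker outputs bundled into one message; the tests in this
// change load a text-proto instance of it as the golden result.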
message HolisticResult {
|
||||
mediapipe.NormalizedLandmarkList pose_landmarks = 1;
|
||||
mediapipe.LandmarkList pose_world_landmarks = 7;
|
||||
mediapipe.NormalizedLandmarkList left_hand_landmarks = 2;
|
||||
mediapipe.NormalizedLandmarkList right_hand_landmarks = 3;
|
||||
mediapipe.NormalizedLandmarkList face_landmarks = 4;
|
||||
mediapipe.ClassificationList face_blendshapes = 6;
|
||||
mediapipe.NormalizedLandmarkList auxiliary_landmarks = 5;
|
||||
}
|
|
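The Java tests later in this change read golden results serialized in this format. Below is a small sketch of parsing one back, assuming the generated Java bindings for the message above; the wrapper class and the file-path argument are hypothetical.

import com.google.mediapipe.tasks.vision.holisticlandmarker.HolisticResultProto.HolisticResult;
import java.io.FileInputStream;
import java.io.IOException;

final class HolisticResultSketch {
  // Reads a serialized HolisticResult golden file and counts its pose landmarks.
  static int countPoseLandmarks(String path) throws IOException {
    try (FileInputStream in = new FileInputStream(path)) {
      HolisticResult result = HolisticResult.parseFrom(in);
      return result.getPoseLandmarks().getLandmarkCount();
    }
  }
}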
@ -47,13 +47,14 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
|||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_landmarker_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_java_proto_lite",
|
||||
|
|
|
@ -67,6 +67,7 @@ cc_binary(
|
|||
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
|
||||
"//mediapipe/tasks/cc/vision/face_stylizer:face_stylizer_graph",
|
||||
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
|
||||
"//mediapipe/tasks/cc/vision/holistic_landmarker:holistic_landmarker_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
|
||||
|
@ -429,6 +430,7 @@ filegroup(
|
|||
android_library(
|
||||
name = "holisticlandmarker",
|
||||
srcs = [
|
||||
"holisticlandmarker/HolisticLandmarker.java",
|
||||
"holisticlandmarker/HolisticLandmarkerResult.java",
|
||||
],
|
||||
javacopts = [
|
||||
|
@ -439,10 +441,20 @@ android_library(
|
|||
":core",
|
||||
"//mediapipe/framework/formats:classification_java_proto_lite",
|
||||
"//mediapipe/framework/formats:landmark_java_proto_lite",
|
||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_landmarker_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||
"//third_party:any_java_proto",
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.google.mediapipe.tasks.vision.holisticlandmarker">
|
||||
|
||||
<uses-sdk android:minSdkVersion="24"
|
||||
android:targetSdkVersion="30" />
|
||||
|
||||
</manifest>
|
|
@ -0,0 +1,668 @@
|
|||
// Copyright 2023 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.holisticlandmarker;
|
||||
|
||||
import android.content.Context;
|
||||
import android.os.ParcelFileDescriptor;
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;
|
||||
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
|
||||
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
|
||||
import com.google.mediapipe.framework.AndroidPacketGetter;
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.Packet;
|
||||
import com.google.mediapipe.framework.PacketGetter;
|
||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||
import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
|
||||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
|
||||
import com.google.mediapipe.tasks.core.TaskInfo;
|
||||
import com.google.mediapipe.tasks.core.TaskOptions;
|
||||
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
||||
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
|
||||
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
|
||||
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||
import com.google.mediapipe.tasks.vision.facedetector.proto.FaceDetectorGraphOptionsProto.FaceDetectorGraphOptions;
|
||||
import com.google.mediapipe.tasks.vision.facelandmarker.proto.FaceLandmarksDetectorGraphOptionsProto.FaceLandmarksDetectorGraphOptions;
|
||||
import com.google.mediapipe.tasks.vision.handlandmarker.proto.HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions;
|
||||
import com.google.mediapipe.tasks.vision.holisticlandmarker.proto.HolisticLandmarkerGraphOptionsProto.HolisticLandmarkerGraphOptions;
|
||||
import com.google.mediapipe.tasks.vision.posedetector.proto.PoseDetectorGraphOptionsProto.PoseDetectorGraphOptions;
|
||||
import com.google.mediapipe.tasks.vision.poselandmarker.proto.PoseLandmarksDetectorGraphOptionsProto.PoseLandmarksDetectorGraphOptions;
|
||||
import com.google.protobuf.Any;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* Performs holistic landmarks detection on images.
|
||||
*
|
||||
* <p>This API expects a pre-trained holistic landmarks model asset bundle.
|
||||
*
|
||||
* <ul>
|
||||
* <li>Input image {@link MPImage}
|
||||
* <ul>
|
||||
* <li>The image that holistic landmarks detection runs on.
|
||||
* </ul>
|
||||
* <li>Output {@link HolisticLandmarkerResult}
|
||||
* <ul>
|
||||
* <li>A HolisticLandmarkerResult containing holistic landmarks.
|
||||
* </ul>
|
||||
* </ul>
|
||||
*/
|
||||
public final class HolisticLandmarker extends BaseVisionTaskApi {
|
||||
private static final String TAG = HolisticLandmarker.class.getSimpleName();
|
||||
|
||||
private static final String IMAGE_IN_STREAM_NAME = "image_in";
|
||||
private static final String POSE_LANDMARKS_STREAM = "pose_landmarks";
|
||||
private static final String POSE_WORLD_LANDMARKS_STREAM = "pose_world_landmarks";
|
||||
private static final String POSE_SEGMENTATION_MASK_STREAM = "pose_segmentation_mask";
|
||||
private static final String FACE_LANDMARKS_STREAM = "face_landmarks";
|
||||
private static final String FACE_BLENDSHAPES_STREAM = "extra_blendshapes";
|
||||
private static final String LEFT_HAND_LANDMARKS_STREAM = "left_hand_landmarks";
|
||||
private static final String LEFT_HAND_WORLD_LANDMARKS_STREAM = "left_hand_world_landmarks";
|
||||
private static final String RIGHT_HAND_LANDMARKS_STREAM = "right_hand_landmarks";
|
||||
private static final String RIGHT_HAND_WORLD_LANDMARKS_STREAM = "right_hand_world_landmarks";
|
||||
private static final String IMAGE_OUT_STREAM_NAME = "image_out";
|
||||
|
||||
private static final int FACE_LANDMARKS_OUT_STREAM_INDEX = 0;
|
||||
private static final int POSE_LANDMARKS_OUT_STREAM_INDEX = 1;
|
||||
private static final int POSE_WORLD_LANDMARKS_OUT_STREAM_INDEX = 2;
|
||||
private static final int LEFT_HAND_LANDMARKS_OUT_STREAM_INDEX = 3;
|
||||
private static final int LEFT_HAND_WORLD_LANDMARKS_OUT_STREAM_INDEX = 4;
|
||||
private static final int RIGHT_HAND_LANDMARKS_OUT_STREAM_INDEX = 5;
|
||||
private static final int RIGHT_HAND_WORLD_LANDMARKS_OUT_STREAM_INDEX = 6;
|
||||
private static final int IMAGE_OUT_STREAM_INDEX = 7;
|
||||
|
||||
private static final float DEFAULT_PRESENCE_THRESHOLD = 0.5f;
|
||||
private static final float DEFAULT_SUPPRESSION_THRESHOLD = 0.3f;
|
||||
private static final boolean DEFAULT_OUTPUT_FACE_BLENDSHAPES = false;
|
||||
private static final boolean DEFAULT_OUTPUT_SEGMENTATION_MASKS = false;
|
||||
|
||||
private static final String TASK_GRAPH_NAME =
|
||||
"mediapipe.tasks.vision.holistic_landmarker.HolisticLandmarkerGraph";
|
||||
|
||||
@SuppressWarnings("ConstantCaseForConstants")
|
||||
private static final List<String> INPUT_STREAMS =
|
||||
Collections.unmodifiableList(Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME));
|
||||
|
||||
static {
|
||||
System.loadLibrary("mediapipe_tasks_vision_jni");
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link HolisticLandmarker} instance from a model asset bundle path and the default
|
||||
* {@link HolisticLandmarkerOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param modelAssetPath path to the holistic landmarks model with metadata in the assets.
|
||||
* @throws MediaPipeException if there is an error during {@link HolisticLandmarker} creation.
|
||||
*/
|
||||
public static HolisticLandmarker createFromFile(Context context, String modelAssetPath) {
|
||||
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelAssetPath).build();
|
||||
return createFromOptions(
|
||||
context, HolisticLandmarkerOptions.builder().setBaseOptions(baseOptions).build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link HolisticLandmarker} instance from a model asset bundle file and the default
|
||||
* {@link HolisticLandmarkerOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param modelAssetFile the holistic landmarks model {@link File} instance.
|
||||
* @throws IOException if an I/O error occurs when opening the tflite model file.
|
||||
* @throws MediaPipeException if there is an error during {@link HolisticLandmarker} creation.
|
||||
*/
|
||||
public static HolisticLandmarker createFromFile(Context context, File modelAssetFile)
|
||||
throws IOException {
|
||||
try (ParcelFileDescriptor descriptor =
|
||||
ParcelFileDescriptor.open(modelAssetFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
|
||||
BaseOptions baseOptions =
|
||||
BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
|
||||
return createFromOptions(
|
||||
context, HolisticLandmarkerOptions.builder().setBaseOptions(baseOptions).build());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link HolisticLandmarker} instance from a model asset bundle buffer and the default
|
||||
* {@link HolisticLandmarkerOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param modelAssetBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
|
||||
*     holistic landmarks model.
|
||||
* @throws MediaPipeException if there is an error during {@link HolisticLandmarker} creation.
|
||||
*/
|
||||
public static HolisticLandmarker createFromBuffer(
|
||||
Context context, final ByteBuffer modelAssetBuffer) {
|
||||
BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelAssetBuffer).build();
|
||||
return createFromOptions(
|
||||
context, HolisticLandmarkerOptions.builder().setBaseOptions(baseOptions).build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link HolisticLandmarker} instance from a {@link HolisticLandmarkerOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param landmarkerOptions a {@link HolisticLandmarkerOptions} instance.
|
||||
* @throws MediaPipeException if there is an error during {@link HolisticLandmarker} creation.
|
||||
*/
|
||||
public static HolisticLandmarker createFromOptions(
|
||||
Context context, HolisticLandmarkerOptions landmarkerOptions) {
|
||||
List<String> outputStreams = new ArrayList<>();
|
||||
outputStreams.add("FACE_LANDMARKS:" + FACE_LANDMARKS_STREAM);
|
||||
outputStreams.add("POSE_LANDMARKS:" + POSE_LANDMARKS_STREAM);
|
||||
outputStreams.add("POSE_WORLD_LANDMARKS:" + POSE_WORLD_LANDMARKS_STREAM);
|
||||
outputStreams.add("LEFT_HAND_LANDMARKS:" + LEFT_HAND_LANDMARKS_STREAM);
|
||||
outputStreams.add("LEFT_HAND_WORLD_LANDMARKS:" + LEFT_HAND_WORLD_LANDMARKS_STREAM);
|
||||
outputStreams.add("RIGHT_HAND_LANDMARKS:" + RIGHT_HAND_LANDMARKS_STREAM);
|
||||
outputStreams.add("RIGHT_HAND_WORLD_LANDMARKS:" + RIGHT_HAND_WORLD_LANDMARKS_STREAM);
|
||||
outputStreams.add("IMAGE:" + IMAGE_OUT_STREAM_NAME);
|
||||
|
||||
int[] faceBlendshapesOutStreamIndex = new int[] {-1};
|
||||
if (landmarkerOptions.outputFaceBlendshapes()) {
|
||||
outputStreams.add("FACE_BLENDSHAPES:" + FACE_BLENDSHAPES_STREAM);
|
||||
faceBlendshapesOutStreamIndex[0] = outputStreams.size() - 1;
|
||||
}
|
||||
|
||||
int[] poseSegmentationMasksOutStreamIndex = new int[] {-1};
|
||||
if (landmarkerOptions.outputPoseSegmentationMasks()) {
|
||||
outputStreams.add("POSE_SEGMENTATION_MASK:" + POSE_SEGMENTATION_MASK_STREAM);
|
||||
poseSegmentationMasksOutStreamIndex[0] = outputStreams.size() - 1;
|
||||
}
|
||||
|
||||
OutputHandler<HolisticLandmarkerResult, MPImage> handler = new OutputHandler<>();
|
||||
handler.setOutputPacketConverter(
|
||||
new OutputHandler.OutputPacketConverter<HolisticLandmarkerResult, MPImage>() {
|
||||
@Override
|
||||
public HolisticLandmarkerResult convertToTaskResult(List<Packet> packets) {
|
||||
// If there are no detected landmarks, return an empty result.
|
||||
if (packets.get(FACE_LANDMARKS_OUT_STREAM_INDEX).isEmpty()) {
|
||||
return HolisticLandmarkerResult.createEmpty(
|
||||
BaseVisionTaskApi.generateResultTimestampMs(
|
||||
landmarkerOptions.runningMode(),
|
||||
packets.get(FACE_LANDMARKS_OUT_STREAM_INDEX)));
|
||||
}
|
||||
|
||||
NormalizedLandmarkList faceLandmarkProtos =
|
||||
PacketGetter.getProto(
|
||||
packets.get(FACE_LANDMARKS_OUT_STREAM_INDEX), NormalizedLandmarkList.parser());
|
||||
Optional<ClassificationList> faceBlendshapeProtos =
|
||||
landmarkerOptions.outputFaceBlendshapes()
|
||||
? Optional.of(
|
||||
PacketGetter.getProto(
|
||||
packets.get(faceBlendshapesOutStreamIndex[0]),
|
||||
ClassificationList.parser()))
|
||||
: Optional.empty();
|
||||
NormalizedLandmarkList poseLandmarkProtos =
|
||||
PacketGetter.getProto(
|
||||
packets.get(POSE_LANDMARKS_OUT_STREAM_INDEX), NormalizedLandmarkList.parser());
|
||||
LandmarkList poseWorldLandmarkProtos =
|
||||
PacketGetter.getProto(
|
||||
packets.get(POSE_WORLD_LANDMARKS_OUT_STREAM_INDEX), LandmarkList.parser());
|
||||
Optional<MPImage> segmentationMask =
|
||||
landmarkerOptions.outputPoseSegmentationMasks()
|
||||
? Optional.of(
|
||||
getSegmentationMask(packets, poseSegmentationMasksOutStreamIndex[0]))
|
||||
: Optional.empty();
|
||||
NormalizedLandmarkList leftHandLandmarkProtos =
|
||||
PacketGetter.getProto(
|
||||
packets.get(LEFT_HAND_LANDMARKS_OUT_STREAM_INDEX),
|
||||
NormalizedLandmarkList.parser());
|
||||
LandmarkList leftHandWorldLandmarkProtos =
|
||||
PacketGetter.getProto(
|
||||
packets.get(LEFT_HAND_WORLD_LANDMARKS_OUT_STREAM_INDEX), LandmarkList.parser());
|
||||
NormalizedLandmarkList rightHandLandmarkProtos =
|
||||
PacketGetter.getProto(
|
||||
packets.get(RIGHT_HAND_LANDMARKS_OUT_STREAM_INDEX),
|
||||
NormalizedLandmarkList.parser());
|
||||
LandmarkList rightHandWorldLandmarkProtos =
|
||||
PacketGetter.getProto(
|
||||
packets.get(RIGHT_HAND_WORLD_LANDMARKS_OUT_STREAM_INDEX),
|
||||
LandmarkList.parser());
|
||||
|
||||
return HolisticLandmarkerResult.create(
|
||||
faceLandmarkProtos,
|
||||
faceBlendshapeProtos,
|
||||
poseLandmarkProtos,
|
||||
poseWorldLandmarkProtos,
|
||||
segmentationMask,
|
||||
leftHandLandmarkProtos,
|
||||
leftHandWorldLandmarkProtos,
|
||||
rightHandLandmarkProtos,
|
||||
rightHandWorldLandmarkProtos,
|
||||
BaseVisionTaskApi.generateResultTimestampMs(
|
||||
landmarkerOptions.runningMode(), packets.get(FACE_LANDMARKS_OUT_STREAM_INDEX)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public MPImage convertToTaskInput(List<Packet> packets) {
|
||||
return new BitmapImageBuilder(
|
||||
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
|
||||
.build();
|
||||
}
|
||||
});
|
||||
landmarkerOptions.resultListener().ifPresent(handler::setResultListener);
|
||||
landmarkerOptions.errorListener().ifPresent(handler::setErrorListener);
|
||||
TaskRunner runner =
|
||||
TaskRunner.create(
|
||||
context,
|
||||
TaskInfo.<HolisticLandmarkerOptions>builder()
|
||||
.setTaskName(HolisticLandmarker.class.getSimpleName())
|
||||
.setTaskRunningModeName(landmarkerOptions.runningMode().name())
|
||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||
.setInputStreams(INPUT_STREAMS)
|
||||
.setOutputStreams(outputStreams)
|
||||
.setTaskOptions(landmarkerOptions)
|
||||
.setEnableFlowLimiting(landmarkerOptions.runningMode() == RunningMode.LIVE_STREAM)
|
||||
.build(),
|
||||
handler);
|
||||
return new HolisticLandmarker(runner, landmarkerOptions.runningMode());
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor to initialize a {@link HolisticLandmarker} from a {@link TaskRunner} and a {@link
|
||||
* RunningMode}.
|
||||
*
|
||||
* @param taskRunner a {@link TaskRunner}.
|
||||
* @param runningMode a mediapipe vision task {@link RunningMode}.
|
||||
*/
|
||||
private HolisticLandmarker(TaskRunner taskRunner, RunningMode runningMode) {
|
||||
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, /* normRectStreamName= */ "");
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs holistic landmarks detection on the provided single image with default image
|
||||
* processing options, i.e. without any rotation applied. Only use this method when the {@link
|
||||
* HolisticLandmarker} is created with {@link RunningMode.IMAGE}.
|
||||
*
|
||||
* <p>{@link HolisticLandmarker} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public HolisticLandmarkerResult detect(MPImage image) {
|
||||
return detect(image, ImageProcessingOptions.builder().build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs holistic landmarks detection on the provided single image. Only use this method when
|
||||
* the {@link HolisticLandmarker} is created with {@link RunningMode.IMAGE}.
|
||||
*
|
||||
* <p>{@link HolisticLandmarker} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||
* input image before running inference. Note that region-of-interest is <b>not</b> supported
|
||||
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
|
||||
* this method throwing an IllegalArgumentException.
|
||||
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||
* region-of-interest.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public HolisticLandmarkerResult detect(
|
||||
MPImage image, ImageProcessingOptions imageProcessingOptions) {
|
||||
validateImageProcessingOptions(imageProcessingOptions);
|
||||
return (HolisticLandmarkerResult) processImageData(image, imageProcessingOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs holistic landmarks detection on the provided video frame with default image processing
|
||||
* options, i.e. without any rotation applied. Only use this method when the {@link
|
||||
* HolisticLandmarker} is created with {@link RunningMode.VIDEO}.
|
||||
*
|
||||
* <p>It's required to provide the video frame"s timestamp (in milliseconds). The input timestamps
|
||||
* must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link HolisticLandmarker} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public HolisticLandmarkerResult detectForVideo(MPImage image, long timestampMs) {
|
||||
return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs holistic landmarks detection on the provided video frame. Only use this method when
|
||||
* the {@link HolisticLandmarker} is created with {@link RunningMode.VIDEO}.
|
||||
*
|
||||
* <p>It's required to provide the video frame"s timestamp (in milliseconds). The input timestamps
|
||||
* must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link HolisticLandmarker} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||
* input image before running inference. Note that region-of-interest is <b>not</b> supported
|
||||
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
|
||||
* this method throwing an IllegalArgumentException.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||
* region-of-interest.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public HolisticLandmarkerResult detectForVideo(
|
||||
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||
validateImageProcessingOptions(imageProcessingOptions);
|
||||
return (HolisticLandmarkerResult) processVideoData(image, imageProcessingOptions, timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends live image data to perform holistic landmarks detection with default image processing
|
||||
* options, i.e. without any rotation applied, and the results will be available via the {@link
|
||||
* ResultListener} provided in the {@link HolisticLandmarkerOptions}. Only use this method when
|
||||
* the {@link HolisticLandmarker} is created with {@link RunningMode.LIVE_STREAM}.
|
||||
*
|
||||
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
|
||||
* sent to the holistic landmarker. The input timestamps must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link HolisticLandmarker} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public void detectAsync(MPImage image, long timestampMs) {
|
||||
detectAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends live image data to perform holistic landmarks detection, and the results will be
|
||||
* available via the {@link ResultListener} provided in the {@link HolisticLandmarkerOptions}.
|
||||
* Only use this method when the {@link HolisticLandmarker} is created with {@link
|
||||
* RunningMode.LIVE_STREAM}.
|
||||
*
|
||||
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
|
||||
* sent to the holistic landmarker. The input timestamps must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link HolisticLandmarker} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||
* input image before running inference. Note that region-of-interest is <b>not</b> supported
|
||||
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
|
||||
* this method throwing an IllegalArgumentException.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||
* region-of-interest.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public void detectAsync(
|
||||
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||
validateImageProcessingOptions(imageProcessingOptions);
|
||||
sendLiveStreamData(image, imageProcessingOptions, timestampMs);
|
||||
}
|
||||
|
||||
/** Options for setting up a {@link HolisticLandmarker}. */
|
||||
@AutoValue
|
||||
public abstract static class HolisticLandmarkerOptions extends TaskOptions {
|
||||
|
||||
/** Builder for {@link HolisticLandmarkerOptions}. */
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
/** Sets the base options for the holistic landmarker task. */
|
||||
public abstract Builder setBaseOptions(BaseOptions value);
|
||||
|
||||
/**
|
||||
* Sets the running mode for the holistic landmarker task. Defaults to the image mode.
|
||||
* Holistic landmarker has three modes:
|
||||
*
|
||||
* <ul>
|
||||
* <li>IMAGE: The mode for detecting holistic landmarks on single image inputs.
|
||||
* <li>VIDEO: The mode for detecting holistic landmarks on the decoded frames of a video.
|
||||
* <li>LIVE_STREAM: The mode for detecting holistic landmarks on a live stream of input
|
||||
* data, such as from a camera. In this mode, {@code setResultListener} must be called to
|
||||
* set up a listener to receive the detection results asynchronously.
|
||||
* </ul>
|
||||
*/
|
||||
public abstract Builder setRunningMode(RunningMode value);
|
||||
|
||||
/**
|
||||
* Sets the minimum confidence score for the face detection to be considered successful. Defaults
|
||||
* to 0.5.
|
||||
*/
|
||||
public abstract Builder setMinFaceDetectionConfidence(Float value);
|
||||
|
||||
/**
|
||||
* The minimum threshold for the face suppression score in the face detection. Defaults to
|
||||
* 0.3.
|
||||
*/
|
||||
public abstract Builder setMinFaceSuppressionThreshold(Float value);
|
||||
|
||||
/**
|
||||
* Sets the minimum confidence score for the face landmark detection to be considered successful.
|
||||
* Defaults to 0.5.
|
||||
*/
|
||||
public abstract Builder setMinFaceLandmarksConfidence(Float value);
|
||||
|
||||
/**
|
||||
* The minimum confidence score for the pose detection to be considered successful. Defaults
|
||||
* to 0.5.
|
||||
*/
|
||||
public abstract Builder setMinPoseDetectionConfidence(Float value);
|
||||
|
||||
/**
|
||||
* The minimum threshold for the pose suppression score in the pose detection. Defaults to
|
||||
* 0.3.
|
||||
*/
|
||||
public abstract Builder setMinPoseSuppressionThreshold(Float value);
|
||||
|
||||
/**
|
||||
* The minimum confidence score for the pose landmarks detection to be considered successful.
|
||||
* Defaults to 0.5.
|
||||
*/
|
||||
public abstract Builder setMinPoseLandmarksConfidence(Float value);
|
||||
|
||||
/**
|
||||
* The minimum confidence score for the hand landmark detection to be considered successful.
|
||||
* Defaults to 0.5.
|
||||
*/
|
||||
public abstract Builder setMinHandLandmarksConfidence(Float value);
|
||||
|
||||
/** Whether to output pose segmentation masks. Defaults to false. */
|
||||
public abstract Builder setOutputPoseSegmentationMasks(Boolean value);
|
||||
|
||||
/** Whether to output face blendshapes. Defaults to false. */
|
||||
public abstract Builder setOutputFaceBlendshapes(Boolean value);
|
||||
|
||||
/**
|
||||
* Sets the result listener to receive the detection results asynchronously when the holistic
|
||||
* landmarker is in the live stream mode.
|
||||
*/
|
||||
public abstract Builder setResultListener(
|
||||
ResultListener<HolisticLandmarkerResult, MPImage> value);
|
||||
|
||||
/** Sets an optional error listener. */
|
||||
public abstract Builder setErrorListener(ErrorListener value);
|
||||
|
||||
abstract HolisticLandmarkerOptions autoBuild();
|
||||
|
||||
/**
|
||||
* Validates and builds the {@link HolisticLandmarkerOptions} instance.
|
||||
*
|
||||
* @throws IllegalArgumentException if the result listener and the running mode are not
|
||||
* properly configured. The result listener should only be set when the holistic
|
||||
* landmarker is in the live stream mode.
|
||||
*/
|
||||
public final HolisticLandmarkerOptions build() {
|
||||
HolisticLandmarkerOptions options = autoBuild();
|
||||
if (options.runningMode() == RunningMode.LIVE_STREAM) {
|
||||
if (!options.resultListener().isPresent()) {
|
||||
throw new IllegalArgumentException(
|
||||
"The holistic landmarker is in the live stream mode, a user-defined result listener"
|
||||
+ " must be provided in HolisticLandmarkerOptions.");
|
||||
}
|
||||
} else if (options.resultListener().isPresent()) {
|
||||
throw new IllegalArgumentException(
|
||||
"The holistic landmarker is in the image or the video mode, a user-defined result"
|
||||
+ " listener shouldn't be provided in HolisticLandmarkerOptions.");
|
||||
}
|
||||
return options;
|
||||
}
|
||||
}
|
||||
|
||||
abstract BaseOptions baseOptions();
|
||||
|
||||
abstract RunningMode runningMode();
|
||||
|
||||
abstract Optional<Float> minFaceDetectionConfidence();
|
||||
|
||||
abstract Optional<Float> minFaceSuppressionThreshold();
|
||||
|
||||
abstract Optional<Float> minFaceLandmarksConfidence();
|
||||
|
||||
abstract Optional<Float> minPoseDetectionConfidence();
|
||||
|
||||
abstract Optional<Float> minPoseSuppressionThreshold();
|
||||
|
||||
abstract Optional<Float> minPoseLandmarksConfidence();
|
||||
|
||||
abstract Optional<Float> minHandLandmarksConfidence();
|
||||
|
||||
abstract Boolean outputFaceBlendshapes();
|
||||
|
||||
abstract Boolean outputPoseSegmentationMasks();
|
||||
|
||||
abstract Optional<ResultListener<HolisticLandmarkerResult, MPImage>> resultListener();
|
||||
|
||||
abstract Optional<ErrorListener> errorListener();
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_HolisticLandmarker_HolisticLandmarkerOptions.Builder()
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.setMinFaceDetectionConfidence(DEFAULT_PRESENCE_THRESHOLD)
|
||||
.setMinFaceSuppressionThreshold(DEFAULT_SUPPRESSION_THRESHOLD)
|
||||
.setMinFaceLandmarksConfidence(DEFAULT_PRESENCE_THRESHOLD)
|
||||
.setMinPoseDetectionConfidence(DEFAULT_PRESENCE_THRESHOLD)
|
||||
.setMinPoseSuppressionThreshold(DEFAULT_SUPPRESSION_THRESHOLD)
|
||||
.setMinPoseLandmarksConfidence(DEFAULT_PRESENCE_THRESHOLD)
|
||||
.setMinHandLandmarksConfidence(DEFAULT_PRESENCE_THRESHOLD)
|
||||
.setOutputFaceBlendshapes(DEFAULT_OUTPUT_FACE_BLENDSHAPES)
|
||||
.setOutputPoseSegmentationMasks(DEFAULT_OUTPUT_SEGMENTATION_MASKS);
|
||||
}
|
||||
|
||||
/** Converts a {@link HolisticLandmarkerOptions} to an {@link Any} protobuf message. */
|
||||
@Override
|
||||
public Any convertToAnyProto() {
|
||||
HolisticLandmarkerGraphOptions.Builder holisticLandmarkerGraphOptions =
|
||||
HolisticLandmarkerGraphOptions.newBuilder()
|
||||
.setBaseOptions(
|
||||
BaseOptionsProto.BaseOptions.newBuilder()
|
||||
.setUseStreamMode(runningMode() != RunningMode.IMAGE)
|
||||
.mergeFrom(convertBaseOptionsToProto(baseOptions()))
|
||||
.build());
|
||||
|
||||
HandLandmarksDetectorGraphOptions.Builder handLandmarksDetectorGraphOptions =
|
||||
HandLandmarksDetectorGraphOptions.newBuilder();
|
||||
FaceDetectorGraphOptions.Builder faceDetectorGraphOptions =
|
||||
FaceDetectorGraphOptions.newBuilder();
|
||||
FaceLandmarksDetectorGraphOptions.Builder faceLandmarksDetectorGraphOptions =
|
||||
FaceLandmarksDetectorGraphOptions.newBuilder();
|
||||
PoseDetectorGraphOptions.Builder poseDetectorGraphOptions =
|
||||
PoseDetectorGraphOptions.newBuilder();
|
||||
PoseLandmarksDetectorGraphOptions.Builder poseLandmarkerGraphOptions =
|
||||
PoseLandmarksDetectorGraphOptions.newBuilder();
|
||||
|
||||
// Configure hand landmarks detector options.
|
||||
minHandLandmarksConfidence()
|
||||
.ifPresent(handLandmarksDetectorGraphOptions::setMinDetectionConfidence);
|
||||
|
||||
// Configure pose detector and pose landmarks detector options.
|
||||
minPoseDetectionConfidence().ifPresent(poseDetectorGraphOptions::setMinDetectionConfidence);
|
||||
minPoseSuppressionThreshold().ifPresent(poseDetectorGraphOptions::setMinSuppressionThreshold);
|
||||
minPoseLandmarksConfidence().ifPresent(poseLandmarkerGraphOptions::setMinDetectionConfidence);
|
||||
|
||||
// Configure face detector and face landmarks detector options.
|
||||
minFaceDetectionConfidence().ifPresent(faceDetectorGraphOptions::setMinDetectionConfidence);
|
||||
minFaceSuppressionThreshold().ifPresent(faceDetectorGraphOptions::setMinSuppressionThreshold);
|
||||
minFaceLandmarksConfidence()
|
||||
.ifPresent(faceLandmarksDetectorGraphOptions::setMinDetectionConfidence);
|
||||
|
||||
holisticLandmarkerGraphOptions
|
||||
.setHandLandmarksDetectorGraphOptions(handLandmarksDetectorGraphOptions.build())
|
||||
.setFaceDetectorGraphOptions(faceDetectorGraphOptions.build())
|
||||
.setFaceLandmarksDetectorGraphOptions(faceLandmarksDetectorGraphOptions.build())
|
||||
.setPoseDetectorGraphOptions(poseDetectorGraphOptions.build())
|
||||
.setPoseLandmarksDetectorGraphOptions(poseLandmarkerGraphOptions.build());
|
||||
|
||||
return Any.newBuilder()
|
||||
.setTypeUrl(
|
||||
"type.googleapis.com/mediapipe.tasks.vision.holistic_landmarker.proto.HolisticLandmarkerGraphOptions")
|
||||
.setValue(holisticLandmarkerGraphOptions.build().toByteString())
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates that the provided {@link ImageProcessingOptions} doesn't contain a
|
||||
* region-of-interest.
|
||||
*/
|
||||
private static void validateImageProcessingOptions(
|
||||
ImageProcessingOptions imageProcessingOptions) {
|
||||
if (imageProcessingOptions.regionOfInterest().isPresent()) {
|
||||
throw new IllegalArgumentException("HolisticLandmarker doesn't support region-of-interest.");
|
||||
}
|
||||
}
|
||||
|
||||
private static MPImage getSegmentationMask(List<Packet> packets, int packetIndex) {
|
||||
int width = PacketGetter.getImageWidth(packets.get(packetIndex));
|
||||
int height = PacketGetter.getImageHeight(packets.get(packetIndex));
|
||||
ByteBuffer buffer = ByteBuffer.allocateDirect(width * height * 4);
|
||||
|
||||
if (!PacketGetter.getImageData(packets.get(packetIndex), buffer)) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
||||
"There was an error getting the sefmentation mask.");
|
||||
}
|
||||
|
||||
ByteBufferImageBuilder builder =
|
||||
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1);
|
||||
return builder.build();
|
||||
}
|
||||
}
|
|
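For reference, a minimal image-mode usage sketch of the class above. The wrapper class is hypothetical, the Context and MPImage are assumed to come from the caller, and the .task bundle name follows the test asset used later in this change.

import android.content.Context;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.holisticlandmarker.HolisticLandmarker;
import com.google.mediapipe.tasks.vision.holisticlandmarker.HolisticLandmarker.HolisticLandmarkerOptions;
import com.google.mediapipe.tasks.vision.holisticlandmarker.HolisticLandmarkerResult;

final class HolisticLandmarkerUsageSketch {
  // Creates an image-mode landmarker, runs one detection, and releases the task.
  static HolisticLandmarkerResult detectOnce(Context context, MPImage image) {
    HolisticLandmarkerOptions options =
        HolisticLandmarkerOptions.builder()
            .setBaseOptions(
                BaseOptions.builder().setModelAssetPath("holistic_landmarker.task").build())
            .setRunningMode(RunningMode.IMAGE)
            .build();
    try (HolisticLandmarker landmarker =
        HolisticLandmarker.createFromOptions(context, options)) {
      return landmarker.detect(image);
    }
  }
}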
@ -15,7 +15,6 @@
|
|||
package com.google.mediapipe.tasks.vision.imagesegmenter;
|
||||
|
||||
import android.content.Context;
|
||||
import android.util.Log;
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.google.mediapipe.tasks.vision.holisticlandmarkertest"
|
||||
android:versionCode="1"
|
||||
android:versionName="1.0" >
|
||||
|
||||
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
|
||||
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
|
||||
|
||||
<uses-sdk android:minSdkVersion="24"
|
||||
android:targetSdkVersion="30" />
|
||||
|
||||
<application
|
||||
android:label="holisticlandmarkertest"
|
||||
android:name="android.support.multidex.MultiDexApplication"
|
||||
android:taskAffinity="">
|
||||
<uses-library android:name="android.test.runner" />
|
||||
</application>
|
||||
|
||||
<instrumentation
|
||||
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
|
||||
android:targetPackage="com.google.mediapipe.tasks.vision.holisticlandmarkertest" />
|
||||
|
||||
</manifest>
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
# TODO: Enable this in OSS
|
|
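The test file that follows exercises all three running modes. As a quick orientation, live-stream setup looks roughly like the sketch below; the wrapper class and the listener body are illustrative, and timestamps later passed to detectAsync must be monotonically increasing.

import android.content.Context;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.holisticlandmarker.HolisticLandmarker;
import com.google.mediapipe.tasks.vision.holisticlandmarker.HolisticLandmarker.HolisticLandmarkerOptions;

final class HolisticLiveStreamSketch {
  // Results are delivered asynchronously to the result listener, not returned.
  static HolisticLandmarker createForLiveStream(Context context) {
    HolisticLandmarkerOptions options =
        HolisticLandmarkerOptions.builder()
            .setBaseOptions(
                BaseOptions.builder().setModelAssetPath("holistic_landmarker.task").build())
            .setRunningMode(RunningMode.LIVE_STREAM)
            .setResultListener(
                (result, inputImage) -> {
                  // Consume result.faceLandmarks() and the other landmark lists here.
                })
            .build();
    return HolisticLandmarker.createFromOptions(context, options);
  }
}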
@ -0,0 +1,512 @@
|
|||
// Copyright 2023 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.holisticlandmarker;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static org.junit.Assert.assertThrows;
|
||||
|
||||
import android.content.res.AssetManager;
|
||||
import android.graphics.BitmapFactory;
|
||||
import android.graphics.RectF;
|
||||
import androidx.test.core.app.ApplicationProvider;
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||
import com.google.common.truth.Correspondence;
|
||||
import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;
|
||||
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||
import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
|
||||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.tasks.components.containers.Category;
|
||||
import com.google.mediapipe.tasks.components.containers.NormalizedLandmark;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
|
||||
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||
import com.google.mediapipe.tasks.vision.holisticlandmarker.HolisticLandmarker.HolisticLandmarkerOptions;
|
||||
import com.google.mediapipe.tasks.vision.holisticlandmarker.HolisticResultProto.HolisticResult;
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Optional;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Suite;
|
||||
import org.junit.runners.Suite.SuiteClasses;
|
||||
|
||||
/** Test for {@link HolisticLandmarker}. */
|
||||
@RunWith(Suite.class)
|
||||
@SuiteClasses({HolisticLandmarkerTest.General.class, HolisticLandmarkerTest.RunningModeTest.class})
|
||||
public class HolisticLandmarkerTest {
|
||||
private static final String HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE = "holistic_landmarker.task";
|
||||
private static final String POSE_IMAGE = "male_full_height_hands.jpg";
|
||||
private static final String CAT_IMAGE = "cat.jpg";
|
||||
private static final String HOLISTIC_RESULT = "male_full_height_hands_result_cpu.pb";
|
||||
private static final String TAG = "Holistic Landmarker Test";
|
||||
private static final float FACE_LANDMARKS_ERROR_TOLERANCE = 0.03f;
|
||||
private static final float FACE_BLENDSHAPES_ERROR_TOLERANCE = 0.13f;
|
||||
private static final MPImage PLACEHOLDER_MASK =
|
||||
new ByteBufferImageBuilder(
|
||||
ByteBuffer.allocate(0), /* width= */ 0, /* height= */ 0, MPImage.IMAGE_FORMAT_VEC32F1)
|
||||
.build();
|
||||
private static final int IMAGE_WIDTH = 638;
|
||||
private static final int IMAGE_HEIGHT = 1000;
|
||||
|
||||
private static final Correspondence<NormalizedLandmark, NormalizedLandmark> VALIDATE_LANDMARRKS =
|
||||
Correspondence.from(
|
||||
(Correspondence.BinaryPredicate<NormalizedLandmark, NormalizedLandmark>)
|
||||
(actual, expected) -> {
|
||||
return Correspondence.tolerance(FACE_LANDMARKS_ERROR_TOLERANCE)
|
||||
.compare(actual.x(), expected.x())
|
||||
&& Correspondence.tolerance(FACE_LANDMARKS_ERROR_TOLERANCE)
|
||||
.compare(actual.y(), expected.y());
|
||||
},
|
||||
"landmarks approximately equal to");
|
||||
|
||||
private static final Correspondence<Category, Category> VALIDATE_BLENDSHAPES =
|
||||
Correspondence.from(
|
||||
(Correspondence.BinaryPredicate<Category, Category>)
|
||||
(actual, expected) ->
|
||||
Correspondence.tolerance(FACE_BLENDSHAPES_ERROR_TOLERANCE)
|
||||
.compare(actual.score(), expected.score())
|
||||
&& actual.index() == expected.index()
|
||||
&& actual.categoryName().equals(expected.categoryName()),
|
||||
"face blendshapes approximately equal to");
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public static final class General extends HolisticLandmarkerTest {
|
||||
|
||||
@Test
|
||||
public void detect_successWithValidModels() throws Exception {
|
||||
HolisticLandmarkerOptions options =
|
||||
HolisticLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder()
|
||||
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
|
||||
.build())
|
||||
.build();
|
||||
HolisticLandmarker holisticLandmarker =
|
||||
HolisticLandmarker.createFromOptions(
|
||||
ApplicationProvider.getApplicationContext(), options);
|
||||
HolisticLandmarkerResult actualResult =
|
||||
holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE));
|
||||
HolisticLandmarkerResult expectedResult =
|
||||
getExpectedHolisticLandmarkerResult(
|
||||
HOLISTIC_RESULT, /* hasFaceBlendshapes= */ false, /* hasSegmentationMask= */ false);
|
||||
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_successWithBlendshapes() throws Exception {
|
||||
HolisticLandmarkerOptions options =
|
||||
HolisticLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder()
|
||||
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
|
||||
.build())
|
||||
.setOutputFaceBlendshapes(true)
|
||||
.build();
|
||||
HolisticLandmarker holisticLandmarker =
|
||||
HolisticLandmarker.createFromOptions(
|
||||
ApplicationProvider.getApplicationContext(), options);
|
||||
HolisticLandmarkerResult actualResult =
|
||||
holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE));
|
||||
HolisticLandmarkerResult expectedResult =
|
||||
getExpectedHolisticLandmarkerResult(
|
||||
HOLISTIC_RESULT, /* hasFaceBlendshapes= */ true, /* hasSegmentationMask= */ false);
|
||||
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_successWithSegmentationMasks() throws Exception {
|
||||
HolisticLandmarkerOptions options =
|
||||
HolisticLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder()
|
||||
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
|
||||
.build())
|
||||
.setOutputPoseSegmentationMasks(true)
|
||||
.build();
|
||||
HolisticLandmarker holisticLandmarker =
|
||||
HolisticLandmarker.createFromOptions(
|
||||
ApplicationProvider.getApplicationContext(), options);
|
||||
HolisticLandmarkerResult actualResult =
|
||||
holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE));
|
||||
HolisticLandmarkerResult expectedResult =
|
||||
getExpectedHolisticLandmarkerResult(
|
||||
HOLISTIC_RESULT, /* hasFaceBlendshapes= */ false, /* hasSegmentationMask= */ true);
|
||||
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_successWithEmptyResult() throws Exception {
|
||||
HolisticLandmarkerOptions options =
|
||||
HolisticLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder()
|
||||
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
|
||||
.build())
|
||||
.build();
|
||||
HolisticLandmarker holisticLandmarker =
|
||||
HolisticLandmarker.createFromOptions(
|
||||
ApplicationProvider.getApplicationContext(), options);
|
||||
HolisticLandmarkerResult actualResult =
|
||||
holisticLandmarker.detect(getImageFromAsset(CAT_IMAGE));
|
||||
assertThat(actualResult.faceLandmarks()).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_failsWithRegionOfInterest() throws Exception {
|
||||
HolisticLandmarkerOptions options =
|
||||
HolisticLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder()
|
||||
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
|
||||
.build())
|
||||
.build();
|
||||
HolisticLandmarker holisticLandmarker =
|
||||
HolisticLandmarker.createFromOptions(
|
||||
ApplicationProvider.getApplicationContext(), options);
|
||||
ImageProcessingOptions imageProcessingOptions =
|
||||
ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build();
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
() ->
|
||||
holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE), imageProcessingOptions));
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("HolisticLandmarker doesn't support region-of-interest");
|
||||
}
|
||||
}
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public static final class RunningModeTest extends HolisticLandmarkerTest {
|
||||
private void assertCreationFailsWithResultListenerInNonLiveStreamMode(RunningMode runningMode)
|
||||
throws Exception {
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
() ->
|
||||
HolisticLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder()
|
||||
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
|
||||
.build())
|
||||
.setRunningMode(runningMode)
|
||||
.setResultListener((holisticLandmarkerResult, inputImage) -> {})
|
||||
.build());
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("a user-defined result listener shouldn't be provided");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void create_failsWithIllegalResultListenerInVideoMode() throws Exception {
|
||||
assertCreationFailsWithResultListenerInNonLiveStreamMode(RunningMode.VIDEO);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void create_failsWithIllegalResultListenerInImageMode() throws Exception {
|
||||
assertCreationFailsWithResultListenerInNonLiveStreamMode(RunningMode.IMAGE);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception {
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
() ->
|
||||
HolisticLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder()
|
||||
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
|
||||
.build())
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.build());
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("a user-defined result listener must be provided");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_failsWithCallingWrongApiInImageMode() throws Exception {
|
||||
HolisticLandmarkerOptions options =
|
||||
HolisticLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder()
|
||||
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
|
||||
.build())
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.build();
|
||||
|
||||
HolisticLandmarker holisticLandmarker =
|
||||
HolisticLandmarker.createFromOptions(
|
||||
ApplicationProvider.getApplicationContext(), options);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
holisticLandmarker.detectForVideo(
|
||||
getImageFromAsset(POSE_IMAGE), /* timestampMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
holisticLandmarker.detectAsync(
|
||||
getImageFromAsset(POSE_IMAGE), /* timestampMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_failsWithCallingWrongApiInVideoMode() throws Exception {
|
||||
HolisticLandmarkerOptions options =
|
||||
HolisticLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder()
|
||||
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
|
||||
.build())
|
||||
.setRunningMode(RunningMode.VIDEO)
|
||||
.build();
|
||||
|
||||
HolisticLandmarker holisticLandmarker =
|
||||
HolisticLandmarker.createFromOptions(
|
||||
ApplicationProvider.getApplicationContext(), options);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() -> holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE)));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
holisticLandmarker.detectAsync(
|
||||
getImageFromAsset(POSE_IMAGE), /* timestampMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_failsWithCallingWrongApiInLiveSteamMode() throws Exception {
|
||||
HolisticLandmarkerOptions options =
|
||||
HolisticLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder()
|
||||
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
|
||||
.build())
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener((holisticLandmarkerResult, inputImage) -> {})
|
||||
.build();
|
||||
|
||||
HolisticLandmarker holisticLandmarker =
|
||||
HolisticLandmarker.createFromOptions(
|
||||
ApplicationProvider.getApplicationContext(), options);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() -> holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE)));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
holisticLandmarker.detectForVideo(
|
||||
getImageFromAsset(POSE_IMAGE), /* timestampMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_successWithImageMode() throws Exception {
|
||||
HolisticLandmarkerOptions options =
|
||||
HolisticLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder()
|
||||
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
|
||||
.build())
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.build();
|
||||
|
||||
HolisticLandmarker holisticLandmarker =
|
||||
HolisticLandmarker.createFromOptions(
|
||||
ApplicationProvider.getApplicationContext(), options);
|
||||
HolisticLandmarkerResult actualResult =
|
||||
holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE));
|
||||
HolisticLandmarkerResult expectedResult =
|
||||
getExpectedHolisticLandmarkerResult(
|
||||
HOLISTIC_RESULT, /* hasFaceBlendshapes= */ false, /* hasSegmentationMask= */ false);
|
||||
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_successWithVideoMode() throws Exception {
|
||||
HolisticLandmarkerOptions options =
|
||||
HolisticLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder()
|
||||
.setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
|
||||
.build())
|
||||
.setRunningMode(RunningMode.VIDEO)
|
||||
.build();
|
||||
HolisticLandmarker holisticLandmarker =
|
||||
HolisticLandmarker.createFromOptions(
|
||||
ApplicationProvider.getApplicationContext(), options);
|
||||
HolisticLandmarkerResult expectedResult =
|
||||
getExpectedHolisticLandmarkerResult(
|
||||
HOLISTIC_RESULT, /* hasFaceBlendshapes= */ false, /* hasSegmentationMask= */ false);
|
||||
for (int i = 0; i < 3; i++) {
|
||||
HolisticLandmarkerResult actualResult =
|
||||
holisticLandmarker.detectForVideo(getImageFromAsset(POSE_IMAGE), /* timestampMs= */ i);
|
||||
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
|
||||
}
|
||||
}
|
||||
|
||||
  @Test
  public void detect_failsWithOutOfOrderInputTimestamps() throws Exception {
    MPImage image = getImageFromAsset(POSE_IMAGE);
    HolisticLandmarkerOptions options =
        HolisticLandmarkerOptions.builder()
            .setBaseOptions(
                BaseOptions.builder()
                    .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
                    .build())
            .setRunningMode(RunningMode.LIVE_STREAM)
            .setResultListener((actualResult, inputImage) -> {})
            .build();
    try (HolisticLandmarker holisticLandmarker =
        HolisticLandmarker.createFromOptions(
            ApplicationProvider.getApplicationContext(), options)) {
      holisticLandmarker.detectAsync(image, /* timestampsMs= */ 1);
      MediaPipeException exception =
          assertThrows(
              MediaPipeException.class,
              () -> holisticLandmarker.detectAsync(image, /* timestampsMs= */ 0));
      assertThat(exception)
          .hasMessageThat()
          .contains("having a smaller timestamp than the processed timestamp");
    }
  }

  @Test
  public void detect_successWithLiveStreamMode() throws Exception {
    MPImage image = getImageFromAsset(POSE_IMAGE);
    HolisticLandmarkerResult expectedResult =
        getExpectedHolisticLandmarkerResult(
            HOLISTIC_RESULT, /* hasFaceBlendshapes= */ false, /* hasSegmentationMask= */ false);
    HolisticLandmarkerOptions options =
        HolisticLandmarkerOptions.builder()
            .setBaseOptions(
                BaseOptions.builder()
                    .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE)
                    .build())
            .setRunningMode(RunningMode.LIVE_STREAM)
            .setResultListener(
                (actualResult, inputImage) -> {
                  assertActualResultApproximatelyEqualsToExpectedResult(
                      actualResult, expectedResult);
                  assertImageSizeIsExpected(inputImage);
                })
            .build();
    try (HolisticLandmarker holisticLandmarker =
        HolisticLandmarker.createFromOptions(
            ApplicationProvider.getApplicationContext(), options)) {
      for (int i = 0; i < 3; i++) {
        holisticLandmarker.detectAsync(image, /* timestampsMs= */ i);
      }
    }
  }
}

  private static MPImage getImageFromAsset(String filePath) throws Exception {
    AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
    InputStream istr = assetManager.open(filePath);
    return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
  }

  private static HolisticLandmarkerResult getExpectedHolisticLandmarkerResult(
      String resultPath, boolean hasFaceBlendshapes, boolean hasSegmentationMask)
      throws Exception {
    AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();

    HolisticResult holisticResult = HolisticResult.parseFrom(assetManager.open(resultPath));

    Optional<ClassificationList> blendshapes =
        hasFaceBlendshapes
            ? Optional.of(holisticResult.getFaceBlendshapes())
            : Optional.<ClassificationList>empty();
    Optional<MPImage> segmentationMask =
        hasSegmentationMask ? Optional.of(PLACEHOLDER_MASK) : Optional.<MPImage>empty();

    return HolisticLandmarkerResult.create(
        holisticResult.getFaceLandmarks(),
        blendshapes,
        holisticResult.getPoseLandmarks(),
        LandmarkList.getDefaultInstance(),
        segmentationMask,
        holisticResult.getLeftHandLandmarks(),
        LandmarkList.getDefaultInstance(),
        holisticResult.getRightHandLandmarks(),
        LandmarkList.getDefaultInstance(),
        /* timestampMs= */ 0);
  }

  private static void assertActualResultApproximatelyEqualsToExpectedResult(
      HolisticLandmarkerResult actualResult, HolisticLandmarkerResult expectedResult) {
    // Expects the same number of detections for each landmark type, and matching
    // presence of the optional outputs.
    assertThat(actualResult.faceLandmarks()).hasSize(expectedResult.faceLandmarks().size());
    assertThat(actualResult.faceBlendshapes().isPresent())
        .isEqualTo(expectedResult.faceBlendshapes().isPresent());
    assertThat(actualResult.poseLandmarks()).hasSize(expectedResult.poseLandmarks().size());
    assertThat(actualResult.segmentationMask().isPresent())
        .isEqualTo(expectedResult.segmentationMask().isPresent());
    assertThat(actualResult.leftHandLandmarks()).hasSize(expectedResult.leftHandLandmarks().size());
    assertThat(actualResult.rightHandLandmarks())
        .hasSize(expectedResult.rightHandLandmarks().size());

    // Actual face landmarks match expected face landmarks.
    assertThat(actualResult.faceLandmarks())
        .comparingElementsUsing(VALIDATE_LANDMARRKS)
        .containsExactlyElementsIn(expectedResult.faceLandmarks());

    // Actual face blendshapes match expected face blendshapes.
    if (actualResult.faceBlendshapes().isPresent()) {
      assertThat(actualResult.faceBlendshapes().get())
          .comparingElementsUsing(VALIDATE_BLENDSHAPES)
          .containsExactlyElementsIn(expectedResult.faceBlendshapes().get());
    }

    // Actual pose landmarks match expected pose landmarks.
    assertThat(actualResult.poseLandmarks())
        .comparingElementsUsing(VALIDATE_LANDMARRKS)
        .containsExactlyElementsIn(expectedResult.poseLandmarks());

    if (actualResult.segmentationMask().isPresent()) {
      assertImageSizeIsExpected(actualResult.segmentationMask().get());
    }

    // Actual left hand landmarks match expected left hand landmarks.
    assertThat(actualResult.leftHandLandmarks())
        .comparingElementsUsing(VALIDATE_LANDMARRKS)
        .containsExactlyElementsIn(expectedResult.leftHandLandmarks());

    // Actual right hand landmarks match expected right hand landmarks.
    assertThat(actualResult.rightHandLandmarks())
        .comparingElementsUsing(VALIDATE_LANDMARRKS)
        .containsExactlyElementsIn(expectedResult.rightHandLandmarks());
  }

  private static void assertImageSizeIsExpected(MPImage inputImage) {
    assertThat(inputImage).isNotNull();
    assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH);
    assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT);
  }
}
mediapipe/tasks/testdata/vision/BUILD
@ -225,6 +225,7 @@ filegroup(
        "hand_detector_result_one_hand.pbtxt",
        "hand_detector_result_one_hand_rotated.pbtxt",
        "hand_detector_result_two_hands.pbtxt",
        "male_full_height_hands_result_cpu.pbtxt",
        "pointing_up_landmarks.pbtxt",
        "pointing_up_rotated_landmarks.pbtxt",
        "portrait_expected_detection.pbtxt",
@ -66,7 +66,7 @@ const vision = await FilesetResolver.forVisionTasks(
  "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision/wasm"
);
const gestureRecognizer = await GestureRecognizer.createFromModelPath(vision,
  "hhttps://storage.googleapis.com/mediapipe-models/gesture_recognizer/gesture_recognizer/float16/1/gesture_recognizer.task"
  "https://storage.googleapis.com/mediapipe-models/gesture_recognizer/gesture_recognizer/float16/1/gesture_recognizer.task"
);
const image = document.getElementById("image") as HTMLImageElement;
const recognitions = gestureRecognizer.recognize(image);
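With the stray "h" removed from the model URL, the snippet runs as written. For completeness, a minimal hedged sketch of consuming the result follows; it assumes at least one hand is detected and that `gestures` holds a ranked list of categories per hand, as in the tasks-vision typings:

// Hedged sketch: log the top-ranked gesture for the first detected hand.
const topGesture = recognitions.gestures[0]?.[0];
if (topGesture) {
  console.log(`${topGesture.categoryName} (score ${topGesture.score.toFixed(2)})`);
}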
@ -391,6 +391,7 @@ export class DrawingUtils {
  drawCategoryMask(
      mask: MPMask, categoryToColorMap: RGBAColor[],
      background?: RGBAColor|ImageSource): void;
  /** @export */
  drawCategoryMask(
      mask: MPMask, categoryToColorMap: CategoryToColorMap,
      background: RGBAColor|ImageSource = [0, 0, 0, 255]): void {
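This hunk adds an overload that accepts a plain `RGBAColor[]` in place of a `Map`, with the array index serving as the category index. A minimal hedged sketch of calling it; the import list, the helper name, and the two-category palette are illustrative assumptions, not part of this change:

// Hedged usage sketch for the new array-based overload.
import {DrawingUtils, MPMask, RGBAColor} from "@mediapipe/tasks-vision";

function renderCategoryMask(drawingUtils: DrawingUtils, mask: MPMask) {
  // Index i in the palette is the RGBA color drawn for category i.
  const palette: RGBAColor[] = [
    [0, 0, 0, 255],    // category 0, e.g. background
    [0, 255, 0, 255],  // category 1, e.g. person
  ];
  drawingUtils.drawCategoryMask(mask, palette);  // background defaults to opaque black
}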
@ -480,6 +481,7 @@
   * frame, you can reduce the cost of re-uploading these images by passing a
   * `HTMLCanvasElement` instead.
   *
   * @export
   * @param mask A confidence mask that was returned from a segmentation task.
   * @param defaultTexture An image or a four-channel color that will be used
   *     when confidence values are low.
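The documented parameters suggest the following usage. A hedged sketch only: the second texture argument (used where confidence is high) is inferred from the surrounding documentation rather than shown in this hunk, and the helper name and colors are illustrative:

// Hedged sketch: blend toward red where the segmenter is confident.
import {DrawingUtils, MPMask} from "@mediapipe/tasks-vision";

function renderConfidenceMask(drawingUtils: DrawingUtils, mask: MPMask) {
  drawingUtils.drawConfidenceMask(
      mask,
      [0, 0, 0, 255],     // defaultTexture: used where confidence is low
      [255, 0, 0, 255]);  // overlay texture: used where confidence is high (assumed)
}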