From 1c860cace655c02b55bfe95eee5d24d2da1a50bf Mon Sep 17 00:00:00 2001 From: Kinar Date: Mon, 13 Nov 2023 09:53:37 -0800 Subject: [PATCH 1/9] Added files for the Object Detector C Tasks API --- mediapipe/tasks/c/components/containers/BUILD | 87 ++++++ .../components/containers/detection_result.h | 63 ++++ .../containers/detection_result_converter.cc | 80 +++++ .../containers/detection_result_converter.h | 38 +++ .../detection_result_converter_test.cc | 28 ++ .../tasks/c/components/containers/keypoint.h | 48 +++ .../containers/keypoint_converter.cc | 45 +++ .../containers/keypoint_converter.h | 32 ++ .../containers/keypoint_converter_test.cc | 35 +++ .../tasks/c/components/containers/rect.h | 46 +++ .../c/components/containers/rect_converter.cc | 43 +++ .../c/components/containers/rect_converter.h | 32 ++ .../containers/rect_converter_test.cc | 51 ++++ .../tasks/c/vision/object_detector/BUILD | 64 ++++ .../vision/object_detector/object_detector.cc | 288 ++++++++++++++++++ .../vision/object_detector/object_detector.h | 157 ++++++++++ .../object_detector/object_detector_test.cc | 253 +++++++++++++++ 17 files changed, 1390 insertions(+) create mode 100644 mediapipe/tasks/c/components/containers/detection_result.h create mode 100644 mediapipe/tasks/c/components/containers/detection_result_converter.cc create mode 100644 mediapipe/tasks/c/components/containers/detection_result_converter.h create mode 100644 mediapipe/tasks/c/components/containers/detection_result_converter_test.cc create mode 100644 mediapipe/tasks/c/components/containers/keypoint.h create mode 100644 mediapipe/tasks/c/components/containers/keypoint_converter.cc create mode 100644 mediapipe/tasks/c/components/containers/keypoint_converter.h create mode 100644 mediapipe/tasks/c/components/containers/keypoint_converter_test.cc create mode 100644 mediapipe/tasks/c/components/containers/rect.h create mode 100644 mediapipe/tasks/c/components/containers/rect_converter.cc create mode 100644 
mediapipe/tasks/c/components/containers/rect_converter.h create mode 100644 mediapipe/tasks/c/components/containers/rect_converter_test.cc create mode 100644 mediapipe/tasks/c/vision/object_detector/BUILD create mode 100644 mediapipe/tasks/c/vision/object_detector/object_detector.cc create mode 100644 mediapipe/tasks/c/vision/object_detector/object_detector.h create mode 100644 mediapipe/tasks/c/vision/object_detector/object_detector_test.cc diff --git a/mediapipe/tasks/c/components/containers/BUILD b/mediapipe/tasks/c/components/containers/BUILD index 4bb580873..d015ccb00 100644 --- a/mediapipe/tasks/c/components/containers/BUILD +++ b/mediapipe/tasks/c/components/containers/BUILD @@ -43,6 +43,60 @@ cc_test( ], ) +cc_library( + name = "rect", + hdrs = ["rect.h"], +) + +cc_library( + name = "rect_converter", + srcs = ["rect_converter.cc"], + hdrs = ["rect_converter.h"], + deps = [ + ":rect", + "//mediapipe/tasks/cc/components/containers:rect", + ], +) + +cc_test( + name = "rect_converter_test", + srcs = ["rect_converter_test.cc"], + deps = [ + ":rect", + ":rect_converter", + "//mediapipe/framework/port:gtest", + "//mediapipe/tasks/cc/components/containers:rect", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "keypoint", + hdrs = ["keypoint.h"], +) + +cc_library( + name = "keypoint_converter", + srcs = ["keypoint_converter.cc"], + hdrs = ["keypoint_converter.h"], + deps = [ + ":keypoint", + "//mediapipe/tasks/cc/components/containers:keypoint", + ], +) + +cc_test( + name = "keypoint_converter_test", + srcs = ["keypoint_converter_test.cc"], + deps = [ + ":keypoint", + ":keypoint_converter", + "//mediapipe/framework/port:gtest", + "//mediapipe/tasks/cc/components/containers:keypoint", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "classification_result", hdrs = ["classification_result.h"], @@ -72,6 +126,41 @@ cc_test( ], ) +cc_library( + name = "detection_result", + hdrs = ["detection_result.h"], + # detection_result.h includes rect.h, so the dependency must be declared. + deps = [":rect"], +) + +cc_library( + 
name = "detection_result_converter", + srcs = ["detection_result_converter.cc"], + hdrs = ["detection_result_converter.h"], + deps = [ + ":rect", + ":rect_converter", + ":category", + ":category_converter", + ":keypoint", + ":keypoint_converter", + ":detection_result", + "//mediapipe/tasks/cc/components/containers:detection_result", + ], +) + +cc_test( + name = "detection_result_converter_test", + srcs = ["detection_result_converter_test.cc"], + deps = [ + ":detection_result", + ":detection_result_converter", + "//mediapipe/framework/port:gtest", + "//mediapipe/tasks/cc/components/containers:detection_result", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "embedding_result", hdrs = ["embedding_result.h"], diff --git a/mediapipe/tasks/c/components/containers/detection_result.h b/mediapipe/tasks/c/components/containers/detection_result.h new file mode 100644 index 000000000..48ce200f0 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/detection_result.h @@ -0,0 +1,63 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_ +#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_ + +#include + +#include "mediapipe/tasks/c/components/containers/rect.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// Detection for a single bounding box. 
+struct DetectionC { + // An array of detected categories. + struct Category* categories; + + // The number of elements in the categories array. + uint32_t categories_count; + + // The bounding box location. + struct Rect bounding_box; + + // Optional list of keypoints associated with the detection. Keypoints + // represent interesting points related to the detection. For example, the + // keypoints represent the eye, ear and mouth from face detection model. Or + // in the template matching detection, e.g. KNIFT, they can represent the + // feature points for template matching. + // `nullptr` if keypoints is not present. + struct NormalizedKeypoint* keypoints; + + // The number of elements in the keypoints array. 0 if keypoints do not exist. + uint32_t keypoints_count; +}; + +// Detection results of a model. +struct DetectionResult { + // An array of Detections. + struct DetectionC* detections; + + // The number of detections in the detections array. + uint32_t detections_count; +}; + +#ifdef __cplusplus +} // extern C +#endif + +#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_ diff --git a/mediapipe/tasks/c/components/containers/detection_result_converter.cc b/mediapipe/tasks/c/components/containers/detection_result_converter.cc new file mode 100644 index 000000000..a752e50fe --- /dev/null +++ b/mediapipe/tasks/c/components/containers/detection_result_converter.cc @@ -0,0 +1,88 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/components/containers/detection_result_converter.h" + +#include + +#include "mediapipe/tasks/c/components/containers/category.h" +#include "mediapipe/tasks/c/components/containers/category_converter.h" +#include "mediapipe/tasks/c/components/containers/detection_result.h" +#include "mediapipe/tasks/c/components/containers/keypoint.h" +#include "mediapipe/tasks/c/components/containers/keypoint_converter.h" +#include "mediapipe/tasks/c/components/containers/rect_converter.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" + +namespace mediapipe::tasks::c::components::containers { + +void CppConvertToDetection( + const mediapipe::tasks::components::containers::Detection& in, + DetectionC* out) { + out->categories_count = in.categories.size(); + out->categories = new Category[out->categories_count]; + for (size_t i = 0; i < out->categories_count; ++i) { + CppConvertToCategory(in.categories[i], &out->categories[i]); + } + + CppConvertToRect(in.bounding_box, &out->bounding_box); + + if (in.keypoints.has_value()) { + auto& keypoints = in.keypoints.value(); + out->keypoints_count = keypoints.size(); + out->keypoints = new NormalizedKeypoint[out->keypoints_count]; + for (size_t i = 0; i < out->keypoints_count; ++i) { + CppConvertToNormalizedKeypoint(keypoints[i], &out->keypoints[i]); + } + } else { + out->keypoints = nullptr; + out->keypoints_count = 0; + } +} + +void CppConvertToDetectionResult( + const mediapipe::tasks::components::containers::DetectionResult& in, + DetectionResult* out) { + out->detections_count = in.detections.size(); + out->detections = new DetectionC[out->detections_count]; + for (size_t i = 0; i < out->detections_count; ++i) { + CppConvertToDetection(in.detections[i], &out->detections[i]); + } +} + +// Functions to 
free the memory of C structures. +void CppCloseDetection(DetectionC* in) { + // Release what each element owns (e.g. the strdup'd keypoint label) before + // freeing the arrays themselves. + for (size_t i = 0; i < in->categories_count; ++i) { + CppCloseCategory(&in->categories[i]); + } + delete[] in->categories; + in->categories = nullptr; + for (size_t i = 0; i < in->keypoints_count; ++i) { + CppCloseNormalizedKeypoint(&in->keypoints[i]); + } + delete[] in->keypoints; + in->keypoints = nullptr; +} + +void CppCloseDetectionResult(DetectionResult* in) { + for (size_t i = 0; i < in->detections_count; ++i) { + CppCloseDetection(&in->detections[i]); + } + delete[] in->detections; + in->detections = nullptr; +} + +} // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/detection_result_converter.h b/mediapipe/tasks/c/components/containers/detection_result_converter.h new file mode 100644 index 000000000..e338e47e9 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/detection_result_converter.h @@ -0,0 +1,44 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_CONVERTER_H_ +#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_CONVERTER_H_ + +#include "mediapipe/tasks/c/components/containers/detection_result.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" + +namespace mediapipe::tasks::c::components::containers { + +// Deep-copies a C++ Detection into the C `DetectionC` struct. Release the +// result with CppCloseDetection. +void CppConvertToDetection( + const mediapipe::tasks::components::containers::Detection& in, + DetectionC* out); + +// Deep-copies a C++ DetectionResult into the C `DetectionResult` struct. +// Release the result with CppCloseDetectionResult. +void CppConvertToDetectionResult( + const mediapipe::tasks::components::containers::DetectionResult& in, + DetectionResult* out); + +// Frees the arrays owned by `in` and everything their elements own. +void CppCloseDetection(DetectionC* in); + +// Frees every detection in `in` and the detections array itself. +void CppCloseDetectionResult(DetectionResult* in); + +} // namespace mediapipe::tasks::c::components::containers + +#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_CONVERTER_H_ diff --git a/mediapipe/tasks/c/components/containers/detection_result_converter_test.cc b/mediapipe/tasks/c/components/containers/detection_result_converter_test.cc new file mode 100644 index 000000000..884481f29 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/detection_result_converter_test.cc @@ -0,0 +1,28 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/c/components/containers/detection_result_converter.h" + +#include +#include +#include + +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/c/components/containers/detection_result.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" + +namespace mediapipe::tasks::c::components::containers { + +} // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/keypoint.h b/mediapipe/tasks/c/components/containers/keypoint.h new file mode 100644 index 000000000..8381b1850 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/keypoint.h @@ -0,0 +1,48 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_H_ +#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// A keypoint, defined by the coordinates (x, y), normalized by the image +// dimensions. +struct NormalizedKeypoint { + // x in normalized image coordinates. + float x; + + // y in normalized image coordinates. + float y; + + // Optional label of the keypoint. + char* label; // `nullptr` if the label is not present. + + // Optional score of the keypoint. 
+ float score; + + // Indicates if score is valid + bool has_score; +}; + +#ifdef __cplusplus +} // extern C +#endif + +#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_H_ diff --git a/mediapipe/tasks/c/components/containers/keypoint_converter.cc b/mediapipe/tasks/c/components/containers/keypoint_converter.cc new file mode 100644 index 000000000..53e8a5da1 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/keypoint_converter.cc @@ -0,0 +1,45 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/components/containers/keypoint_converter.h" + +#include +#include +#include + +#include "mediapipe/tasks/c/components/containers/keypoint.h" +#include "mediapipe/tasks/cc/components/containers/keypoint.h" + +namespace mediapipe::tasks::c::components::containers { + +void CppConvertToNormalizedKeypoint( + const mediapipe::tasks::components::containers::NormalizedKeypoint& in, + NormalizedKeypoint* out) { + out->x = in.x; + out->y = in.y; + + out->label = in.label.has_value() ? strdup(in.label->c_str()) : nullptr; + out->has_score = in.score.has_value(); + out->score = out->has_score ? 
in.score.value() : 0; +} + +void CppCloseNormalizedKeypoint(NormalizedKeypoint* keypoint) { + if (keypoint && keypoint->label) { + free(keypoint->label); + keypoint->label = NULL; + } +} + +} // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/keypoint_converter.h b/mediapipe/tasks/c/components/containers/keypoint_converter.h new file mode 100644 index 000000000..a4bd725f2 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/keypoint_converter.h @@ -0,0 +1,32 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_CONVERTER_H_ +#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_CONVERTER_H_ + +#include "mediapipe/tasks/c/components/containers/keypoint.h" +#include "mediapipe/tasks/cc/components/containers/keypoint.h" + +namespace mediapipe::tasks::c::components::containers { + +void CppConvertToNormalizedKeypoint( + const mediapipe::tasks::components::containers::NormalizedKeypoint& in, + NormalizedKeypoint* out); + +void CppCloseNormalizedKeypoint(NormalizedKeypoint* keypoint); + +} // namespace mediapipe::tasks::c::components::containers + +#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_CONVERTER_H_ diff --git a/mediapipe/tasks/c/components/containers/keypoint_converter_test.cc b/mediapipe/tasks/c/components/containers/keypoint_converter_test.cc new file mode 100644 index 000000000..ca09154c3 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/keypoint_converter_test.cc @@ -0,0 +1,47 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/c/components/containers/keypoint_converter.h" + +#include +#include +#include + +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/c/components/containers/keypoint.h" +#include "mediapipe/tasks/cc/components/containers/keypoint.h" + +namespace mediapipe::tasks::c::components::containers { + +TEST(KeypointConverterTest, ConvertsKeypointCustomValues) { + mediapipe::tasks::components::containers::NormalizedKeypoint cpp_keypoint; + cpp_keypoint.x = 0.1f; + cpp_keypoint.y = 0.2f; + cpp_keypoint.label = "foo"; + cpp_keypoint.score = 0.5f; + + NormalizedKeypoint c_keypoint; + CppConvertToNormalizedKeypoint(cpp_keypoint, &c_keypoint); + EXPECT_FLOAT_EQ(c_keypoint.x, 0.1f); + EXPECT_FLOAT_EQ(c_keypoint.y, 0.2f); + EXPECT_STREQ(c_keypoint.label, "foo"); + EXPECT_TRUE(c_keypoint.has_score); + EXPECT_FLOAT_EQ(c_keypoint.score, 0.5f); + + // The converter strdup()s the label, so it has to be released explicitly. + CppCloseNormalizedKeypoint(&c_keypoint); +} + +} // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/rect.h b/mediapipe/tasks/c/components/containers/rect.h new file mode 100644 index 000000000..ae6c7a3ee --- /dev/null +++ b/mediapipe/tasks/c/components/containers/rect.h @@ -0,0 +1,46 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_H_ +#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +// Defines a rectangle, used e.g. as part of detection results or as input +// region-of-interest. +struct Rect { + int left; + int top; + int right; + int bottom; +}; + +// The coordinates are normalized wrt the image dimensions, i.e. 
generally in +// [0,1] but they may exceed these bounds if describing a region overlapping the +// image. The origin is on the top-left corner of the image. +struct RectF { + float left; + float top; + float right; + float bottom; +}; + +#ifdef __cplusplus +} // extern C +#endif + +#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_H_ diff --git a/mediapipe/tasks/c/components/containers/rect_converter.cc b/mediapipe/tasks/c/components/containers/rect_converter.cc new file mode 100644 index 000000000..ff700acee --- /dev/null +++ b/mediapipe/tasks/c/components/containers/rect_converter.cc @@ -0,0 +1,43 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/components/containers/rect_converter.h" + +#include + +#include "mediapipe/tasks/c/components/containers/rect.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" + +namespace mediapipe::tasks::c::components::containers { + +// Converts a C++ Rect to a C Rect. +void CppConvertToRect(const mediapipe::tasks::components::containers::Rect& in, + struct Rect* out) { + out->left = in.left; + out->top = in.top; + out->right = in.right; + out->bottom = in.bottom; +} + +// Converts a C++ RectF to a C RectF. 
+void CppConvertToRectF( + const mediapipe::tasks::components::containers::RectF& in, RectF* out) { + out->left = in.left; + out->top = in.top; + out->right = in.right; + out->bottom = in.bottom; +} + +} // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/rect_converter.h b/mediapipe/tasks/c/components/containers/rect_converter.h new file mode 100644 index 000000000..75e21ff52 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/rect_converter.h @@ -0,0 +1,32 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_CONVERTER_H_ +#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_CONVERTER_H_ + +#include "mediapipe/tasks/c/components/containers/rect.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" + +namespace mediapipe::tasks::c::components::containers { + +void CppConvertToRect(const mediapipe::tasks::components::containers::Rect& in, + Rect* out); + +void CppConvertToRectF( + const mediapipe::tasks::components::containers::RectF& in, RectF* out); + +} // namespace mediapipe::tasks::c::components::containers + +#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_CONVERTER_H_ diff --git a/mediapipe/tasks/c/components/containers/rect_converter_test.cc b/mediapipe/tasks/c/components/containers/rect_converter_test.cc new file mode 100644 index 000000000..3e8848094 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/rect_converter_test.cc @@ -0,0 +1,51 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/c/components/containers/rect_converter.h" + +#include +#include +#include + +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/c/components/containers/rect.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" + +namespace mediapipe::tasks::c::components::containers { + +TEST(RectConverterTest, ConvertsRectCustomValues) { + mediapipe::tasks::components::containers::Rect cpp_rect = {0, 0, 0, 0}; + + Rect c_rect; + CppConvertToRect(cpp_rect, &c_rect); + EXPECT_EQ(c_rect.left, 0); + EXPECT_EQ(c_rect.right, 0); + EXPECT_EQ(c_rect.top, 0); + EXPECT_EQ(c_rect.bottom, 0); +} + +TEST(RectFConverterTest, ConvertsRectFCustomValues) { + mediapipe::tasks::components::containers::RectF cpp_rect = {0.1, 0.1, 0.1, + 0.1}; + + RectF c_rect; + CppConvertToRectF(cpp_rect, &c_rect); + EXPECT_FLOAT_EQ(c_rect.left, 0.1); + EXPECT_FLOAT_EQ(c_rect.right, 0.1); + EXPECT_FLOAT_EQ(c_rect.top, 0.1); + EXPECT_FLOAT_EQ(c_rect.bottom, 0.1); +} + +} // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/vision/object_detector/BUILD b/mediapipe/tasks/c/vision/object_detector/BUILD new file mode 100644 index 000000000..ee405b6df --- /dev/null +++ b/mediapipe/tasks/c/vision/object_detector/BUILD @@ -0,0 +1,64 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "object_detector_lib", + srcs = ["object_detector.cc"], + hdrs = ["object_detector.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/tasks/c/components/containers:detection_result", + "//mediapipe/tasks/c/components/containers:detection_result_converter", + "//mediapipe/tasks/c/core:base_options", + "//mediapipe/tasks/c/core:base_options_converter", + "//mediapipe/tasks/c/vision/core:common", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/object_detector", + "//mediapipe/tasks/cc/vision/utils:image_utils", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + +cc_test( + name = "object_detector_test", + srcs = ["object_detector_test.cc"], + data = [ + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/tasks/testdata/vision:test_images", + "//mediapipe/tasks/testdata/vision:test_models", + ], + linkstatic = 1, + deps = [ + ":object_detector_lib", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/port:gtest", + "//mediapipe/tasks/c/components/containers:category", + "//mediapipe/tasks/cc/vision/utils:image_utils", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/mediapipe/tasks/c/vision/object_detector/object_detector.cc b/mediapipe/tasks/c/vision/object_detector/object_detector.cc new file mode 100644 index 000000000..9fdf77821 --- /dev/null +++ b/mediapipe/tasks/c/vision/object_detector/object_detector.cc @@ -0,0 +1,288 @@ +/* Copyright 2023 The MediaPipe Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/vision/object_detector/object_detector.h" + +#include +#include +#include +#include + +#include "absl/log/absl_log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/tasks/c/components/containers/detection_result_converter.h" +#include "mediapipe/tasks/c/core/base_options_converter.h" +#include "mediapipe/tasks/cc/vision/core/running_mode.h" +#include "mediapipe/tasks/cc/vision/object_detector/object_detector.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" + +namespace mediapipe::tasks::c::vision::object_detector { + +namespace { + +using ::mediapipe::tasks::c::components::containers::CppCloseDetectionResult; +using ::mediapipe::tasks::c::components::containers:: + CppConvertToDetectionResult; +using ::mediapipe::tasks::c::core::CppConvertToBaseOptions; +using ::mediapipe::tasks::vision::CreateImageFromBuffer; +using ::mediapipe::tasks::vision::ObjectDetector; +using ::mediapipe::tasks::vision::core::RunningMode; +typedef ::mediapipe::tasks::vision::ObjectDetectorResult + CppObjectDetectorResult; + +int CppProcessError(absl::Status status, char** error_msg) { + if (error_msg) { + *error_msg = strdup(status.ToString().c_str()); + } + return status.raw_code(); +} + +} // 
namespace + +void CppConvertToDetectorOptions( + const ObjectDetectorOptions& in, + mediapipe::tasks::vision::ObjectDetectorOptions* out) { + out->display_names_locale = + in.display_names_locale ? std::string(in.display_names_locale) : "en"; + out->max_results = in.max_results; + out->score_threshold = in.score_threshold; + out->category_allowlist = + std::vector(in.category_allowlist_count); + for (uint32_t i = 0; i < in.category_allowlist_count; ++i) { + out->category_allowlist[i] = in.category_allowlist[i]; + } + out->category_denylist = std::vector(in.category_denylist_count); + for (uint32_t i = 0; i < in.category_denylist_count; ++i) { + out->category_denylist[i] = in.category_denylist[i]; + } +} + +ObjectDetector* CppObjectDetectorCreate(const ObjectDetectorOptions& options, + char** error_msg) { + auto cpp_options = + std::make_unique<::mediapipe::tasks::vision::ObjectDetectorOptions>(); + + CppConvertToBaseOptions(options.base_options, &cpp_options->base_options); + CppConvertToDetectorOptions(options, cpp_options.get()); + cpp_options->running_mode = static_cast(options.running_mode); + + // Enable callback for processing live stream data when the running mode is + // set to RunningMode::LIVE_STREAM. 
+ if (cpp_options->running_mode == RunningMode::LIVE_STREAM) { + if (options.result_callback == nullptr) { + const absl::Status status = absl::InvalidArgumentError( + "Provided null pointer to callback function."); + ABSL_LOG(ERROR) << "Failed to create ObjectDetector: " << status; + CppProcessError(status, error_msg); + return nullptr; + } + + ObjectDetectorOptions::result_callback_fn result_callback = + options.result_callback; + cpp_options->result_callback = + [result_callback](absl::StatusOr cpp_result, + const Image& image, int64_t timestamp) { + char* error_msg = nullptr; + + if (!cpp_result.ok()) { + ABSL_LOG(ERROR) << "Detection failed: " << cpp_result.status(); + CppProcessError(cpp_result.status(), &error_msg); + result_callback(nullptr, MpImage(), timestamp, error_msg); + free(error_msg); + return; + } + + // Result is valid for the lifetime of the callback function. + ObjectDetectorResult result; + CppConvertToDetectionResult(*cpp_result, &result); + + const auto& image_frame = image.GetImageFrameSharedPtr(); + const MpImage mp_image = { + .type = MpImage::IMAGE_FRAME, + .image_frame = { + .format = static_cast<::ImageFormat>(image_frame->Format()), + .image_buffer = image_frame->PixelData(), + .width = image_frame->Width(), + .height = image_frame->Height()}}; + + result_callback(&result, mp_image, timestamp, + /* error_msg= */ nullptr); + + CppCloseDetectionResult(&result); + }; + } + + auto detector = ObjectDetector::Create(std::move(cpp_options)); + if (!detector.ok()) { + ABSL_LOG(ERROR) << "Failed to create ObjectDetector: " << detector.status(); + CppProcessError(detector.status(), error_msg); + return nullptr; + } + return detector->release(); +} + +int CppObjectDetectorDetect(void* detector, const MpImage* image, + ObjectDetectorResult* result, char** error_msg) { + if (image->type == MpImage::GPU_BUFFER) { + const absl::Status status = + absl::InvalidArgumentError("GPU Buffer not supported yet."); + + ABSL_LOG(ERROR) << "Detection failed: " 
<< status.message(); + return CppProcessError(status, error_msg); + } + + const auto img = CreateImageFromBuffer( + static_cast(image->image_frame.format), + image->image_frame.image_buffer, image->image_frame.width, + image->image_frame.height); + + if (!img.ok()) { + ABSL_LOG(ERROR) << "Failed to create Image: " << img.status(); + return CppProcessError(img.status(), error_msg); + } + + auto cpp_detector = static_cast(detector); + auto cpp_result = cpp_detector->Detect(*img); + if (!cpp_result.ok()) { + ABSL_LOG(ERROR) << "Detection failed: " << cpp_result.status(); + return CppProcessError(cpp_result.status(), error_msg); + } + CppConvertToDetectionResult(*cpp_result, result); + return 0; +} + +int CppObjectDetectorDetectForVideo(void* detector, const MpImage* image, + int64_t timestamp_ms, + ObjectDetectorResult* result, + char** error_msg) { + if (image->type == MpImage::GPU_BUFFER) { + absl::Status status = + absl::InvalidArgumentError("GPU Buffer not supported yet"); + + ABSL_LOG(ERROR) << "Detection failed: " << status.message(); + return CppProcessError(status, error_msg); + } + + const auto img = CreateImageFromBuffer( + static_cast(image->image_frame.format), + image->image_frame.image_buffer, image->image_frame.width, + image->image_frame.height); + + if (!img.ok()) { + ABSL_LOG(ERROR) << "Failed to create Image: " << img.status(); + return CppProcessError(img.status(), error_msg); + } + + auto cpp_detector = static_cast(detector); + auto cpp_result = cpp_detector->DetectForVideo(*img, timestamp_ms); + if (!cpp_result.ok()) { + ABSL_LOG(ERROR) << "Detection failed: " << cpp_result.status(); + return CppProcessError(cpp_result.status(), error_msg); + } + CppConvertToDetectionResult(*cpp_result, result); + return 0; +} + +int CppObjectDetectorDetectAsync(void* detector, const MpImage* image, + int64_t timestamp_ms, char** error_msg) { + if (image->type == MpImage::GPU_BUFFER) { + absl::Status status = + absl::InvalidArgumentError("GPU Buffer not supported 
yet"); + + ABSL_LOG(ERROR) << "Detection failed: " << status.message(); + return CppProcessError(status, error_msg); + } + + const auto img = CreateImageFromBuffer( + static_cast(image->image_frame.format), + image->image_frame.image_buffer, image->image_frame.width, + image->image_frame.height); + + if (!img.ok()) { + ABSL_LOG(ERROR) << "Failed to create Image: " << img.status(); + return CppProcessError(img.status(), error_msg); + } + + auto cpp_detector = static_cast(detector); + auto cpp_result = cpp_detector->DetectAsync(*img, timestamp_ms); + if (!cpp_result.ok()) { + ABSL_LOG(ERROR) << "Data preparation for the object detection failed: " + << cpp_result; + return CppProcessError(cpp_result, error_msg); + } + return 0; +} + +void CppObjectDetectorCloseResult(ObjectDetectorResult* result) { + CppCloseDetectionResult(result); +} + +int CppObjectDetectorClose(void* detector, char** error_msg) { + auto cpp_detector = static_cast(detector); + auto result = cpp_detector->Close(); + if (!result.ok()) { + ABSL_LOG(ERROR) << "Failed to close ObjectDetector: " << result; + return CppProcessError(result, error_msg); + } + delete cpp_detector; + return 0; +} + +} // namespace mediapipe::tasks::c::vision::object_detector + +extern "C" { + +void* object_detector_create(struct ObjectDetectorOptions* options, + char** error_msg) { + return mediapipe::tasks::c::vision::object_detector::CppObjectDetectorCreate( + *options, error_msg); +} + +int object_detector_detect_image(void* detector, const MpImage* image, + ObjectDetectorResult* result, + char** error_msg) { + return mediapipe::tasks::c::vision::object_detector::CppObjectDetectorDetect( + detector, image, result, error_msg); +} + +int object_detector_detect_for_video(void* detector, const MpImage* image, + int64_t timestamp_ms, + ObjectDetectorResult* result, + char** error_msg) { + return mediapipe::tasks::c::vision::object_detector:: + CppObjectDetectorDetectForVideo(detector, image, timestamp_ms, result, + error_msg); 
+} + +int object_detector_detect_async(void* detector, const MpImage* image, + int64_t timestamp_ms, char** error_msg) { + return mediapipe::tasks::c::vision::object_detector:: + CppObjectDetectorDetectAsync(detector, image, timestamp_ms, error_msg); +} + +void object_detector_close_result(ObjectDetectorResult* result) { + mediapipe::tasks::c::vision::object_detector::CppObjectDetectorCloseResult( + result); +} + +int object_detector_close(void* detector, char** error_msg) { + return mediapipe::tasks::c::vision::object_detector::CppObjectDetectorClose( + detector, error_msg); +} + +} // extern "C" diff --git a/mediapipe/tasks/c/vision/object_detector/object_detector.h b/mediapipe/tasks/c/vision/object_detector/object_detector.h new file mode 100644 index 000000000..16ec32477 --- /dev/null +++ b/mediapipe/tasks/c/vision/object_detector/object_detector.h @@ -0,0 +1,157 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_VISION_OBJECT_DETECTOR_OBJECT_DETECTOR_H_ +#define MEDIAPIPE_TASKS_C_VISION_OBJECT_DETECTOR_OBJECT_DETECTOR_H_ + +#include "mediapipe/tasks/c/components/containers/detection_result.h" +#include "mediapipe/tasks/c/core/base_options.h" +#include "mediapipe/tasks/c/vision/core/common.h" + +#ifndef MP_EXPORT +#define MP_EXPORT __attribute__((visibility("default"))) +#endif // MP_EXPORT + +#ifdef __cplusplus +extern "C" { +#endif + +typedef DetectionResult ObjectDetectorResult; + +// The options for configuring a MediaPipe object detector task. +struct ObjectDetectorOptions { + // Base options for configuring MediaPipe Tasks, such as specifying the model + // file with metadata, accelerator options, op resolver, etc. + struct BaseOptions base_options; + + // The running mode of the task. Default to the image mode. + // Object detector has three running modes: + // 1) The image mode for detecting objects on single image inputs. + // 2) The video mode for detecting objects on the decoded frames of a video. + // 3) The live stream mode for detecting objects on the live stream of input + // data, such as from camera. In this mode, the "result_callback" below must + // be specified to receive the detection results asynchronously. + RunningMode running_mode; + + // The locale to use for display names specified through the TFLite Model + // Metadata, if any. Defaults to English. + const char* display_names_locale; + + // The maximum number of top-scored detection results to return. If < 0, + // all available results will be returned. If 0, an invalid argument error is + // returned. + int max_results; + + // Score threshold to override the one provided in the model metadata (if + // any). Results below this value are rejected. + float score_threshold; + + // The allowlist of category names. 
If non-empty, detection results whose + // category name is not in this set will be filtered out. Duplicate or unknown + // category names are ignored. Mutually exclusive with category_denylist. + const char** category_allowlist; + // The number of elements in the category allowlist. + uint32_t category_allowlist_count; + + // The denylist of category names. If non-empty, detection results whose + // category name is in this set will be filtered out. Duplicate or unknown + // category names are ignored. Mutually exclusive with category_allowlist. + const char** category_denylist; + // The number of elements in the category denylist. + uint32_t category_denylist_count; + + // The user-defined result callback for processing live stream data. + // The result callback should only be specified when the running mode is set + // to RunningMode::LIVE_STREAM. Arguments of the callback function include: + // the pointer to detection result, the image that result was obtained + // on, the timestamp relevant to detection results and pointer to error + // message in case of any failure. The validity of the passed arguments is + // true for the lifetime of the callback function. + // + // The detection result is closed by the library once the callback returns. + typedef void (*result_callback_fn)(ObjectDetectorResult* result, + const MpImage image, int64_t timestamp_ms, + char* error_msg); + result_callback_fn result_callback; +}; + +// Creates an ObjectDetector from provided `options`. +// Returns a pointer to the object detector on success. +// If an error occurs, returns `nullptr` and sets the error parameter to an +// error message (if `error_msg` is not `nullptr`). You must free the memory +// allocated for the error message. +MP_EXPORT void* object_detector_create(struct ObjectDetectorOptions* options, + char** error_msg); + +// Performs object detection on the input `image`. Returns `0` on success.
+// If an error occurs, returns an error code and sets the error parameter to an +// error message (if `error_msg` is not `nullptr`). You must free the memory +// allocated for the error message. +MP_EXPORT int object_detector_detect_image(void* detector, const MpImage* image, + ObjectDetectorResult* result, + char** error_msg); + +// Performs object detection on the provided video frame. +// Only use this method when the ObjectDetector is created with the video +// running mode. +// The image can be of any size with format RGB or RGBA. It's required to +// provide the video frame's timestamp (in milliseconds). The input timestamps +// must be monotonically increasing. +// If an error occurs, returns an error code and sets the error parameter to an +// error message (if `error_msg` is not `nullptr`). You must free the memory +// allocated for the error message. +MP_EXPORT int object_detector_detect_for_video(void* detector, + const MpImage* image, + int64_t timestamp_ms, + ObjectDetectorResult* result, + char** error_msg); + +// Sends live image data to perform object detection, and the results will be +// available via the `result_callback` provided in the ObjectDetectorOptions. +// Only use this method when the ObjectDetector is created with the live +// stream running mode. +// The image can be of any size with format RGB or RGBA. It's required to +// provide a timestamp (in milliseconds) to indicate when the input image is +// sent to the object detector. The input timestamps must be monotonically +// increasing. +// The `result_callback` provides: +// - The detection results as an ObjectDetectorResult object. +// - The corresponding input image that the object detector runs on. Note +// that the image's pixel buffer will no longer be valid when the callback +// returns. To access the image data outside of the callback, callers need +// to make a copy of the image. +// - The input timestamp in milliseconds.
+// If an error occurs, returns an error code and sets the error parameter to an +// error message (if `error_msg` is not `nullptr`). You must free the memory +// allocated for the error message. +MP_EXPORT int object_detector_detect_async(void* detector, const MpImage* image, + int64_t timestamp_ms, + char** error_msg); + +// Frees the memory allocated inside an ObjectDetectorResult result. +// Does not free the result pointer itself. +MP_EXPORT void object_detector_close_result(ObjectDetectorResult* result); + +// Frees the object detector. +// If an error occurs, returns an error code and sets the error parameter to an +// error message (if `error_msg` is not `nullptr`). You must free the memory +// allocated for the error message. +MP_EXPORT int object_detector_close(void* detector, char** error_msg); + +#ifdef __cplusplus +} // extern C +#endif + +#endif // MEDIAPIPE_TASKS_C_VISION_OBJECT_DETECTOR_OBJECT_DETECTOR_H_ diff --git a/mediapipe/tasks/c/vision/object_detector/object_detector_test.cc b/mediapipe/tasks/c/vision/object_detector/object_detector_test.cc new file mode 100644 index 000000000..ac404b0e7 --- /dev/null +++ b/mediapipe/tasks/c/vision/object_detector/object_detector_test.cc @@ -0,0 +1,253 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include "mediapipe/tasks/c/vision/object_detector/object_detector.h" + +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/c/components/containers/category.h" +#include "mediapipe/tasks/c/components/containers/rect.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" + +namespace { + +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::vision::DecodeImageFromFile; +using testing::HasSubstr; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kImageFile[] = "cats_and_dogs.jpg"; +constexpr char kModelName[] = + "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite"; +constexpr float kPrecision = 1e-4; +constexpr int kIterations = 100; + +std::string GetFullPath(absl::string_view file_name) { + return JoinPath("./", kTestDataDirectory, file_name); +} + +TEST(ObjectDetectorTest, ImageModeTest) { + const auto image = DecodeImageFromFile(GetFullPath(kImageFile)); + ASSERT_TRUE(image.ok()); + + const std::string model_path = GetFullPath(kModelName); + ObjectDetectorOptions options = { + /* base_options= */ {/* model_asset_buffer= */ nullptr, + /* model_asset_buffer_count= */ 0, + /* model_asset_path= */ model_path.c_str()}, + /* running_mode= */ RunningMode::IMAGE, + /* display_names_locale= */ nullptr, + /* max_results= */ -1, + /* score_threshold= */ 0.0, + /* category_allowlist= */ nullptr, + /* category_allowlist_count= */ 0, + /* category_denylist= */ nullptr, + /* category_denylist_count= */ 0, + }; + + void* detector = object_detector_create(&options, /* error_msg */ nullptr); + EXPECT_NE(detector, nullptr); + + const auto& image_frame = 
image->GetImageFrameSharedPtr(); + const MpImage mp_image = { + .type = MpImage::IMAGE_FRAME, + .image_frame = {.format = static_cast(image_frame->Format()), + .image_buffer = image_frame->PixelData(), + .width = image_frame->Width(), + .height = image_frame->Height()}}; + + ObjectDetectorResult result; + object_detector_detect_image(detector, &mp_image, &result, + /* error_msg */ nullptr); + EXPECT_EQ(result.detections_count, 10); + EXPECT_EQ(result.detections[0].categories_count, 1); + EXPECT_EQ(std::string{result.detections[0].categories[0].category_name}, + "cat"); + EXPECT_NEAR(result.detections[0].categories[0].score, 0.6992f, kPrecision); + object_detector_close_result(&result); + object_detector_close(detector, /* error_msg */ nullptr); +} + +TEST(ObjectDetectorTest, VideoModeTest) { + const auto image = DecodeImageFromFile(GetFullPath(kImageFile)); + ASSERT_TRUE(image.ok()); + + const std::string model_path = GetFullPath(kModelName); + ObjectDetectorOptions options = { + /* base_options= */ {/* model_asset_buffer= */ nullptr, + /* model_asset_buffer_count= */ 0, + /* model_asset_path= */ model_path.c_str()}, + /* running_mode= */ RunningMode::VIDEO, + /* display_names_locale= */ nullptr, + /* max_results= */ 3, + /* score_threshold= */ 0.0, + /* category_allowlist= */ nullptr, + /* category_allowlist_count= */ 0, + /* category_denylist= */ nullptr, + /* category_denylist_count= */ 0, + }; + + void* detector = object_detector_create(&options, /* error_msg */ nullptr); + EXPECT_NE(detector, nullptr); + + const auto& image_frame = image->GetImageFrameSharedPtr(); + const MpImage mp_image = { + .type = MpImage::IMAGE_FRAME, + .image_frame = {.format = static_cast(image_frame->Format()), + .image_buffer = image_frame->PixelData(), + .width = image_frame->Width(), + .height = image_frame->Height()}}; + + for (int i = 0; i < kIterations; ++i) { + ObjectDetectorResult result; + object_detector_detect_for_video(detector, &mp_image, i, &result, + /* error_msg */ 
nullptr); + EXPECT_EQ(result.detections_count, 3); + EXPECT_EQ(result.detections[0].categories_count, 1); + EXPECT_EQ(std::string{result.detections[0].categories[0].category_name}, + "cat"); + EXPECT_NEAR(result.detections[0].categories[0].score, 0.6992f, kPrecision); + object_detector_close_result(&result); + } + object_detector_close(detector, /* error_msg */ nullptr); +} + +// A structure to support LiveStreamModeTest below. This structure holds a +// static method `Fn` for a callback function of C API. A `static` qualifier +// allows to take an address of the method to follow API style. Another static +// struct member is `last_timestamp` that is used to verify that current +// timestamp is greater than the previous one. +struct LiveStreamModeCallback { + static int64_t last_timestamp; + static void Fn(ObjectDetectorResult* detector_result, const MpImage image, + int64_t timestamp, char* error_msg) { + ASSERT_NE(detector_result, nullptr); + ASSERT_EQ(error_msg, nullptr); + EXPECT_EQ(detector_result->detections_count, 3); + EXPECT_EQ(detector_result->detections[0].categories_count, 1); + EXPECT_EQ( + std::string{detector_result->detections[0].categories[0].category_name}, + "cat"); + EXPECT_NEAR(detector_result->detections[0].categories[0].score, 0.6992f, + kPrecision); + EXPECT_GT(image.image_frame.width, 0); + EXPECT_GT(image.image_frame.height, 0); + EXPECT_GT(timestamp, last_timestamp); + last_timestamp++; + } +}; +int64_t LiveStreamModeCallback::last_timestamp = -1; + +TEST(ObjectDetectorTest, LiveStreamModeTest) { + const auto image = DecodeImageFromFile(GetFullPath(kImageFile)); + ASSERT_TRUE(image.ok()); + + const std::string model_path = GetFullPath(kModelName); + + ObjectDetectorOptions options = { + /* base_options= */ {/* model_asset_buffer= */ nullptr, + /* model_asset_buffer_count= */ 0, + /* model_asset_path= */ model_path.c_str()}, + /* running_mode= */ RunningMode::LIVE_STREAM, + /* display_names_locale= */ nullptr, + /* max_results= */ 3, + /* 
score_threshold= */ 0.0, + /* category_allowlist= */ nullptr, + /* category_allowlist_count= */ 0, + /* category_denylist= */ nullptr, + /* category_denylist_count= */ 0, + /* result_callback= */ LiveStreamModeCallback::Fn, + }; + + void* detector = object_detector_create(&options, /* error_msg */ + nullptr); + EXPECT_NE(detector, nullptr); + + const auto& image_frame = image->GetImageFrameSharedPtr(); + const MpImage mp_image = { + .type = MpImage::IMAGE_FRAME, + .image_frame = {.format = static_cast(image_frame->Format()), + .image_buffer = image_frame->PixelData(), + .width = image_frame->Width(), + .height = image_frame->Height()}}; + + for (int i = 0; i < kIterations; ++i) { + EXPECT_GE(object_detector_detect_async(detector, &mp_image, i, + /* error_msg */ nullptr), + 0); + } + object_detector_close(detector, /* error_msg */ nullptr); + + // Due to the flow limiter, the total of outputs might be smaller than the + // number of iterations. + EXPECT_LE(LiveStreamModeCallback::last_timestamp, kIterations); + EXPECT_GT(LiveStreamModeCallback::last_timestamp, 0); +} + +TEST(ObjectDetectorTest, InvalidArgumentHandling) { + // It is an error to set neither the asset buffer nor the path. 
+ ObjectDetectorOptions options = { + /* base_options= */ {/* model_asset_buffer= */ nullptr, + /* model_asset_buffer_count= */ 0, + /* model_asset_path= */ nullptr}, + }; + + char* error_msg; + void* detector = object_detector_create(&options, &error_msg); + EXPECT_EQ(detector, nullptr); + + EXPECT_THAT(error_msg, HasSubstr("ExternalFile must specify")); + + free(error_msg); +} + +TEST(ObjectDetectorTest, FailedDetectionHandling) { + const std::string model_path = GetFullPath(kModelName); + ObjectDetectorOptions options = { + /* base_options= */ {/* model_asset_buffer= */ nullptr, + /* model_asset_buffer_count= */ 0, + /* model_asset_path= */ model_path.c_str()}, + /* running_mode= */ RunningMode::IMAGE, + /* display_names_locale= */ nullptr, + /* max_results= */ -1, + /* score_threshold= */ 0.0, + /* category_allowlist= */ nullptr, + /* category_allowlist_count= */ 0, + /* category_denylist= */ nullptr, + /* category_denylist_count= */ 0, + }; + + void* detector = object_detector_create(&options, /* error_msg */ + nullptr); + EXPECT_NE(detector, nullptr); + + const MpImage mp_image = {.type = MpImage::GPU_BUFFER, .gpu_buffer = {}}; + ObjectDetectorResult result; + char* error_msg; + object_detector_detect_image(detector, &mp_image, &result, &error_msg); + EXPECT_THAT(error_msg, HasSubstr("GPU Buffer not supported yet")); + free(error_msg); + object_detector_close(detector, /* error_msg */ nullptr); +} + +} // namespace From b879e3a2041b71c8a5264a650a8f3fed6fb1437f Mon Sep 17 00:00:00 2001 From: Kinar Date: Thu, 16 Nov 2023 10:05:34 -0800 Subject: [PATCH 2/9] Updated components and their tests in the C Tasks API --- .../detection_result_converter_test.cc | 50 +++++++++++++++++++ .../containers/keypoint_converter.cc | 3 +- .../containers/keypoint_converter_test.cc | 26 ++++++++-- .../c/components/containers/rect_converter.cc | 2 - .../containers/rect_converter_test.cc | 2 +- 5 files changed, 74 insertions(+), 9 deletions(-) diff --git 
a/mediapipe/tasks/c/components/containers/detection_result_converter_test.cc b/mediapipe/tasks/c/components/containers/detection_result_converter_test.cc index 884481f29..2fd85bf31 100644 --- a/mediapipe/tasks/c/components/containers/detection_result_converter_test.cc +++ b/mediapipe/tasks/c/components/containers/detection_result_converter_test.cc @@ -25,4 +25,54 @@ limitations under the License. namespace mediapipe::tasks::c::components::containers { +TEST(DetectionResultConverterTest, ConvertsDetectionResultCustomCategory) { + mediapipe::tasks::components::containers::DetectionResult + cpp_detection_result = {/* detections= */ { + {/* categories= */ {{/* index= */ 1, /* score= */ 0.1, + /* category_name= */ "cat", + /* display_name= */ "cat"}}, + /* bounding_box= */ {10, 10, 10, 10}, + {/* keypoints */ {{0.1, 0.1, "foo", 0.5}}}}}}; + + DetectionResult c_detection_result; + CppConvertToDetectionResult(cpp_detection_result, &c_detection_result); + EXPECT_NE(c_detection_result.detections, nullptr); + EXPECT_EQ(c_detection_result.detections_count, 1); + EXPECT_NE(c_detection_result.detections[0].categories, nullptr); + EXPECT_EQ(c_detection_result.detections[0].categories_count, 1); + EXPECT_EQ(c_detection_result.detections[0].bounding_box.left, 10); + EXPECT_EQ(c_detection_result.detections[0].bounding_box.top, 10); + EXPECT_EQ(c_detection_result.detections[0].bounding_box.right, 10); + EXPECT_EQ(c_detection_result.detections[0].bounding_box.bottom, 10); + EXPECT_NE(c_detection_result.detections[0].keypoints, nullptr); + + CppCloseDetectionResult(&c_detection_result); +} + +TEST(DetectionResultConverterTest, ConvertsDetectionResultNoCategory) { + mediapipe::tasks::components::containers::DetectionResult + cpp_detection_result = {/* detections= */ {/* categories= */ {}}}; + + DetectionResult c_detection_result; + CppConvertToDetectionResult(cpp_detection_result, &c_detection_result); + EXPECT_NE(c_detection_result.detections, nullptr); + 
EXPECT_EQ(c_detection_result.detections_count, 1); + EXPECT_NE(c_detection_result.detections[0].categories, nullptr); + EXPECT_EQ(c_detection_result.detections[0].categories_count, 0); + + CppCloseDetectionResult(&c_detection_result); +} + +TEST(DetectionResultConverterTest, FreesMemory) { + mediapipe::tasks::components::containers::DetectionResult + cpp_detection_result = {/* detections= */ {{/* categories= */ {}}}}; + + DetectionResult c_detection_result; + CppConvertToDetectionResult(cpp_detection_result, &c_detection_result); + EXPECT_NE(c_detection_result.detections, nullptr); + + CppCloseDetectionResult(&c_detection_result); + EXPECT_EQ(c_detection_result.detections, nullptr); +} + } // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/keypoint_converter.cc b/mediapipe/tasks/c/components/containers/keypoint_converter.cc index 53e8a5da1..2d64e8063 100644 --- a/mediapipe/tasks/c/components/containers/keypoint_converter.cc +++ b/mediapipe/tasks/c/components/containers/keypoint_converter.cc @@ -15,7 +15,6 @@ limitations under the License. #include "mediapipe/tasks/c/components/containers/keypoint_converter.h" -#include #include #include @@ -38,7 +37,7 @@ void CppConvertToNormalizedKeypoint( void CppCloseNormalizedKeypoint(NormalizedKeypoint* keypoint) { if (keypoint && keypoint->label) { free(keypoint->label); - keypoint->label = NULL; + keypoint->label = nullptr; } } diff --git a/mediapipe/tasks/c/components/containers/keypoint_converter_test.cc b/mediapipe/tasks/c/components/containers/keypoint_converter_test.cc index ca09154c3..38bf1e3c6 100644 --- a/mediapipe/tasks/c/components/containers/keypoint_converter_test.cc +++ b/mediapipe/tasks/c/components/containers/keypoint_converter_test.cc @@ -25,11 +25,29 @@ limitations under the License. 
namespace mediapipe::tasks::c::components::containers { -TEST(RectConverterTest, ConvertsRectCustomValues) { - mediapipe::tasks::components::containers::Rect cpp_rect = {0, 0, 0, 0}; +constexpr float kPrecision = 1e-6; - Rect c_rect; - CppConvertToRect(cpp_rect, &c_rect); +TEST(KeypointConverterTest, ConvertsKeypointCustomValues) { + mediapipe::tasks::components::containers::NormalizedKeypoint cpp_keypoint = { + 0.1, 0.1, "foo", 0.5}; + + NormalizedKeypoint c_keypoint; + CppConvertToNormalizedKeypoint(cpp_keypoint, &c_keypoint); + EXPECT_NEAR(c_keypoint.x, 0.1f, kPrecision); + EXPECT_NEAR(c_keypoint.y, 0.1f, kPrecision); + EXPECT_EQ(std::string(c_keypoint.label), "foo"); + EXPECT_NEAR(c_keypoint.score, 0.5f, kPrecision); +} + +TEST(KeypointConverterTest, FreesMemory) { + mediapipe::tasks::components::containers::NormalizedKeypoint cpp_keypoint = { + 0.1, 0.1, "foo", 0.5}; + + NormalizedKeypoint c_keypoint; + CppConvertToNormalizedKeypoint(cpp_keypoint, &c_keypoint); + EXPECT_NE(c_keypoint.label, nullptr); + CppCloseNormalizedKeypoint(&c_keypoint); + EXPECT_EQ(c_keypoint.label, nullptr); } } // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/rect_converter.cc b/mediapipe/tasks/c/components/containers/rect_converter.cc index ff700acee..9f30bec4e 100644 --- a/mediapipe/tasks/c/components/containers/rect_converter.cc +++ b/mediapipe/tasks/c/components/containers/rect_converter.cc @@ -15,8 +15,6 @@ limitations under the License.
#include "mediapipe/tasks/c/components/containers/rect_converter.h" -#include - #include "mediapipe/tasks/c/components/containers/rect.h" #include "mediapipe/tasks/cc/components/containers/rect.h" diff --git a/mediapipe/tasks/c/components/containers/rect_converter_test.cc b/mediapipe/tasks/c/components/containers/rect_converter_test.cc index 3e8848094..eb2107240 100644 --- a/mediapipe/tasks/c/components/containers/rect_converter_test.cc +++ b/mediapipe/tasks/c/components/containers/rect_converter_test.cc @@ -41,7 +41,7 @@ TEST(RectFConverterTest, ConvertsRectFCustomValues) { 0.1}; RectF c_rect; - CppConvertToRect(cpp_rect, &c_rect); + CppConvertToRectF(cpp_rect, &c_rect); EXPECT_FLOAT_EQ(c_rect.left, 0.1); EXPECT_FLOAT_EQ(c_rect.right, 0.1); EXPECT_FLOAT_EQ(c_rect.top, 0.1); From 5ca859f90b56ca3dffa54beab6b64e421ad04a09 Mon Sep 17 00:00:00 2001 From: Alex Macdonald-Smith Date: Tue, 21 Nov 2023 16:03:57 -0500 Subject: [PATCH 3/9] Updated mediapipe/mediapipe/tasks/web /vision/README.md There was a typo in the url referencing Gesture Recognizer ``` const gestureRecognizer = await GestureRecognizer.createFromModelPath(vision, "hhttps://storage.googleapis.com/mediapipe-models/gesture_recognizer/gesture_recognizer/float16/1/gesture_recognizer.task" ); ``` changed to ``` const gestureRecognizer = await GestureRecognizer.createFromModelPath(vision, "https://storage.googleapis.com/mediapipe-models/gesture_recognizer/gesture_recognizer/float16/1/gesture_recognizer.task" ); ``` The extra 'h' was dropped. Let me know if there are anymore updates needed for this. 
--- mediapipe/tasks/web/vision/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/web/vision/README.md b/mediapipe/tasks/web/vision/README.md index 816ef9e4f..c603beaea 100644 --- a/mediapipe/tasks/web/vision/README.md +++ b/mediapipe/tasks/web/vision/README.md @@ -66,7 +66,7 @@ const vision = await FilesetResolver.forVisionTasks( "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision/wasm" ); const gestureRecognizer = await GestureRecognizer.createFromModelPath(vision, - "hhttps://storage.googleapis.com/mediapipe-models/gesture_recognizer/gesture_recognizer/float16/1/gesture_recognizer.task" + "https://storage.googleapis.com/mediapipe-models/gesture_recognizer/gesture_recognizer/float16/1/gesture_recognizer.task" ); const image = document.getElementById("image") as HTMLImageElement; const recognitions = gestureRecognizer.recognize(image); From 8d57a9e2e815dd11a7434cb764382cc7ce4831a5 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 27 Nov 2023 11:12:24 -0800 Subject: [PATCH 4/9] Add missing export declarations to DrawingUtils Fixes https://github.com/google/mediapipe/issues/4980 PiperOrigin-RevId: 585705106 --- mediapipe/tasks/web/vision/core/drawing_utils.ts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mediapipe/tasks/web/vision/core/drawing_utils.ts b/mediapipe/tasks/web/vision/core/drawing_utils.ts index 520f9e2b3..f3c3d5d75 100644 --- a/mediapipe/tasks/web/vision/core/drawing_utils.ts +++ b/mediapipe/tasks/web/vision/core/drawing_utils.ts @@ -391,6 +391,7 @@ export class DrawingUtils { drawCategoryMask( mask: MPMask, categoryToColorMap: RGBAColor[], background?: RGBAColor|ImageSource): void; + /** @export */ drawCategoryMask( mask: MPMask, categoryToColorMap: CategoryToColorMap, background: RGBAColor|ImageSource = [0, 0, 0, 255]): void { @@ -480,6 +481,7 @@ export class DrawingUtils { * frame, you can reduce the cost of re-uploading these images by passing a * `HTMLCanvasElement` instead. 
* + * @export * @param mask A confidence mask that was returned from a segmentation task. * @param defaultTexture An image or a four-channel color that will be used * when confidence values are low. From 1ff7e95295aed0b6a4c4c77d92432d48ff4ba041 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 27 Nov 2023 11:59:35 -0800 Subject: [PATCH 5/9] No public description PiperOrigin-RevId: 585719403 --- mediapipe/model_maker/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/model_maker/requirements.txt b/mediapipe/model_maker/requirements.txt index ff43fa3f0..ffb547b82 100644 --- a/mediapipe/model_maker/requirements.txt +++ b/mediapipe/model_maker/requirements.txt @@ -7,4 +7,4 @@ tensorflow-addons tensorflow-datasets tensorflow-hub tensorflow-text -tf-models-official>=2.13.1 +tf-models-official>=2.13.2 From 95601ff98b6ba5a5eb9bd4a84761dc25a9d5aab4 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 27 Nov 2023 15:48:37 -0800 Subject: [PATCH 6/9] Remove internal logs. 
PiperOrigin-RevId: 585782033 --- .../mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java | 1 - 1 file changed, 1 deletion(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java index b673b00c9..813dba93c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java @@ -15,7 +15,6 @@ package com.google.mediapipe.tasks.vision.imagesegmenter; import android.content.Context; -import android.util.Log; import com.google.auto.value.AutoValue; import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig; From a898215c52a6c406cee993c97cb705ab6d66bc96 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 28 Nov 2023 14:33:15 -0800 Subject: [PATCH 7/9] Holistic Landmarker C++ Graph PiperOrigin-RevId: 586105983 --- .../tasks/cc/vision/hand_landmarker/BUILD | 31 + .../hand_roi_refinement_graph.cc | 154 +++++ .../tasks/cc/vision/holistic_landmarker/BUILD | 152 +++++ .../holistic_face_tracking.cc | 260 ++++++++ .../holistic_face_tracking.h | 89 +++ .../holistic_face_tracking_test.cc | 227 +++++++ .../holistic_hand_tracking.cc | 272 ++++++++ .../holistic_hand_tracking.h | 94 +++ .../holistic_hand_tracking_test.cc | 303 +++++++++ .../holistic_landmarker_graph.cc | 521 +++++++++++++++ .../holistic_landmarker_graph_test.cc | 595 ++++++++++++++++++ .../holistic_pose_tracking.cc | 307 +++++++++ .../holistic_pose_tracking.h | 110 ++++ .../holistic_pose_tracking_test.cc | 243 +++++++ .../cc/vision/holistic_landmarker/proto/BUILD | 44 ++ .../holistic_landmarker_graph_options.proto | 57 ++ .../proto/holistic_result.proto | 34 + 17 files changed, 3493 insertions(+) create mode 100644 
mediapipe/tasks/cc/vision/hand_landmarker/hand_roi_refinement_graph.cc create mode 100644 mediapipe/tasks/cc/vision/holistic_landmarker/BUILD create mode 100644 mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.cc create mode 100644 mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h create mode 100644 mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking_test.cc create mode 100644 mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.cc create mode 100644 mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h create mode 100644 mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking_test.cc create mode 100644 mediapipe/tasks/cc/vision/holistic_landmarker/holistic_landmarker_graph.cc create mode 100644 mediapipe/tasks/cc/vision/holistic_landmarker/holistic_landmarker_graph_test.cc create mode 100644 mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.cc create mode 100644 mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h create mode 100644 mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking_test.cc create mode 100644 mediapipe/tasks/cc/vision/holistic_landmarker/proto/BUILD create mode 100644 mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_landmarker_graph_options.proto create mode 100644 mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.proto diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index 5b75ef8fc..6db49c668 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -155,6 +155,37 @@ cc_library( # TODO: open source hand joints graph +cc_library( + name = "hand_roi_refinement_graph", + srcs = ["hand_roi_refinement_graph.cc"], + deps = [ + "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", + 
"//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:tensors_to_landmarks_calculator", + "//mediapipe/calculators/tensor:tensors_to_landmarks_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2/stream:detections_to_rects", + "//mediapipe/framework/api2/stream:landmarks_projection", + "//mediapipe/framework/api2/stream:landmarks_to_detection", + "//mediapipe/framework/api2/stream:rect_transformation", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_roi_refinement_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + cc_library( name = "hand_landmarker_result", srcs = ["hand_landmarker_result.cc"], diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_roi_refinement_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_roi_refinement_graph.cc new file mode 100644 index 000000000..e7e9b94d0 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_roi_refinement_graph.cc @@ -0,0 +1,154 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" +#include "mediapipe/calculators/tensor/tensors_to_landmarks_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/stream/detections_to_rects.h" +#include "mediapipe/framework/api2/stream/landmarks_projection.h" +#include "mediapipe/framework/api2/stream/landmarks_to_detection.h" +#include "mediapipe/framework/api2/stream/rect_transformation.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace hand_landmarker { + +using ::mediapipe::api2::builder::ConvertAlignmentPointsDetectionToRect; +using ::mediapipe::api2::builder::ConvertLandmarksToDetection; +using ::mediapipe::api2::builder::Graph; +using 
::mediapipe::api2::builder::ProjectLandmarks; +using ::mediapipe::api2::builder::ScaleAndShiftAndMakeSquareLong; +using ::mediapipe::api2::builder::Stream; + +// Refine the input hand RoI with hand_roi_refinement model. +// +// Inputs: +// IMAGE - Image +// The image to preprocess. +// NORM_RECT - NormalizedRect +// Coarse RoI of hand. +// Outputs: +// NORM_RECT - NormalizedRect +// Refined RoI of hand. +class HandRoiRefinementGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + mediapipe::SubgraphContext* context) override { + Graph graph; + Stream image_in = graph.In("IMAGE").Cast(); + Stream roi_in = + graph.In("NORM_RECT").Cast(); + + auto& graph_options = + *context->MutableOptions(); + + MP_ASSIGN_OR_RETURN( + const auto* model_resources, + GetOrCreateModelResources( + context)); + + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + graph_options.base_options().acceleration()); + auto& image_to_tensor_options = + *preprocessing + .GetOptions() + .mutable_image_to_tensor_options(); + image_to_tensor_options.set_keep_aspect_ratio(true); + image_to_tensor_options.set_border_mode( + mediapipe::ImageToTensorCalculatorOptions::BORDER_REPLICATE); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( + *model_resources, use_gpu, graph_options.base_options().gpu_origin(), + &preprocessing.GetOptions())); + image_in >> preprocessing.In("IMAGE"); + roi_in >> preprocessing.In("NORM_RECT"); + auto tensors_in = preprocessing.Out("TENSORS"); + auto matrix = preprocessing.Out("MATRIX").Cast>(); + auto image_size = + preprocessing.Out("IMAGE_SIZE").Cast>(); + + auto& inference = AddInference( + *model_resources, graph_options.base_options().acceleration(), graph); + tensors_in >> inference.In("TENSORS"); + auto tensors_out = inference.Out("TENSORS").Cast>(); + + MP_ASSIGN_OR_RETURN(auto 
image_tensor_specs, + BuildInputImageTensorSpecs(*model_resources)); + + // Convert tensors to landmarks. Recrop model outputs two points, + // center point and guide point. + auto& to_landmarks = graph.AddNode("TensorsToLandmarksCalculator"); + auto& to_landmarks_opts = + to_landmarks + .GetOptions(); + to_landmarks_opts.set_num_landmarks(/*num_landmarks=*/2); + to_landmarks_opts.set_input_image_width(image_tensor_specs.image_width); + to_landmarks_opts.set_input_image_height(image_tensor_specs.image_height); + to_landmarks_opts.set_normalize_z(/*z_norm_factor=*/1.0f); + tensors_out.ConnectTo(to_landmarks.In("TENSORS")); + auto recrop_landmarks = to_landmarks.Out("NORM_LANDMARKS") + .Cast(); + + // Project landmarks. + auto projected_recrop_landmarks = + ProjectLandmarks(recrop_landmarks, matrix, graph); + + // Convert re-crop landmarks to detection. + auto recrop_detection = + ConvertLandmarksToDetection(projected_recrop_landmarks, graph); + + // Convert re-crop detection to rect. + auto recrop_rect = ConvertAlignmentPointsDetectionToRect( + recrop_detection, image_size, /*start_keypoint_index=*/0, + /*end_keypoint_index=*/1, /*target_angle=*/-90, graph); + + auto refined_roi = + ScaleAndShiftAndMakeSquareLong(recrop_rect, image_size, + /*scale_x_factor=*/1.0, + /*scale_y_factor=*/1.0, /*shift_x=*/0, + /*shift_y=*/-0.1, graph); + refined_roi >> graph.Out("NORM_RECT").Cast(); + return graph.GetConfig(); + } +}; + +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::hand_landmarker::HandRoiRefinementGraph); + +} // namespace hand_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/BUILD b/mediapipe/tasks/cc/vision/holistic_landmarker/BUILD new file mode 100644 index 000000000..446cf1e09 --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/BUILD @@ -0,0 +1,152 @@ +# Copyright 2023 The MediaPipe Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + default_visibility = ["//mediapipe/tasks:internal"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "holistic_face_tracking", + srcs = ["holistic_face_tracking.cc"], + hdrs = ["holistic_face_tracking.h"], + deps = [ + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2/stream:detections_to_rects", + "//mediapipe/framework/api2/stream:image_size", + "//mediapipe/framework/api2/stream:landmarks_to_detection", + "//mediapipe/framework/api2/stream:loopback", + "//mediapipe/framework/api2/stream:rect_transformation", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:status", + "//mediapipe/modules/holistic_landmark/calculators:roi_tracking_calculator", + "//mediapipe/modules/holistic_landmark/calculators:roi_tracking_calculator_cc_proto", + "//mediapipe/tasks/cc/vision/face_detector:face_detector_graph", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/face_landmarker:face_blendshapes_graph", + "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph", + "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarks_detector_graph", + 
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "holistic_hand_tracking", + srcs = ["holistic_hand_tracking.cc"], + hdrs = ["holistic_hand_tracking.h"], + deps = [ + "//mediapipe/calculators/util:align_hand_to_pose_in_world_calculator", + "//mediapipe/calculators/util:align_hand_to_pose_in_world_calculator_cc_proto", + "//mediapipe/calculators/util:landmark_visibility_calculator", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2/stream:image_size", + "//mediapipe/framework/api2/stream:landmarks_to_detection", + "//mediapipe/framework/api2/stream:loopback", + "//mediapipe/framework/api2/stream:rect_transformation", + "//mediapipe/framework/api2/stream:split", + "//mediapipe/framework/api2/stream:threshold", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:status", + "//mediapipe/modules/holistic_landmark/calculators:hand_detections_from_pose_to_rects_calculator", + "//mediapipe/modules/holistic_landmark/calculators:roi_tracking_calculator", + "//mediapipe/modules/holistic_landmark/calculators:roi_tracking_calculator_cc_proto", + "//mediapipe/tasks/cc/components/utils:gate", + "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarker_graph", + "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph", + "//mediapipe/tasks/cc/vision/hand_landmarker:hand_roi_refinement_graph", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + 
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_roi_refinement_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "holistic_pose_tracking", + srcs = ["holistic_pose_tracking.cc"], + hdrs = ["holistic_pose_tracking.h"], + deps = [ + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2/stream:detections_to_rects", + "//mediapipe/framework/api2/stream:image_size", + "//mediapipe/framework/api2/stream:landmarks_to_detection", + "//mediapipe/framework/api2/stream:loopback", + "//mediapipe/framework/api2/stream:merge", + "//mediapipe/framework/api2/stream:presence", + "//mediapipe/framework/api2/stream:rect_transformation", + "//mediapipe/framework/api2/stream:segmentation_smoothing", + "//mediapipe/framework/api2/stream:smoothing", + "//mediapipe/framework/api2/stream:split", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc/components/utils:gate", + "//mediapipe/tasks/cc/vision/pose_detector:pose_detector_graph", + "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/pose_landmarker:pose_landmarks_detector_graph", + "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "holistic_landmarker_graph", + srcs = ["holistic_landmarker_graph.cc"], + deps = [ + ":holistic_face_tracking", + ":holistic_hand_tracking", + ":holistic_pose_tracking", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2/stream:split", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:image", + 
"//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc/core:model_asset_bundle_resources", + "//mediapipe/tasks/cc/core:model_resources_cache", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata/utils:zip_utils", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_roi_refinement_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/pose_landmarker:pose_topology", + "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto", + "//mediapipe/util:graph_builder_utils", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.cc b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.cc new file mode 100644 index 000000000..1116cda21 --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.cc @@ -0,0 +1,260 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/stream/detections_to_rects.h" +#include "mediapipe/framework/api2/stream/image_size.h" +#include "mediapipe/framework/api2/stream/landmarks_to_detection.h" +#include "mediapipe/framework/api2/stream/loopback.h" +#include "mediapipe/framework/api2/stream/rect_transformation.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace holistic_landmarker { + +namespace { + +using ::mediapipe::NormalizedRect; +using 
::mediapipe::api2::builder::ConvertDetectionsToRectUsingKeypoints; +using ::mediapipe::api2::builder::ConvertDetectionToRect; +using ::mediapipe::api2::builder::ConvertLandmarksToDetection; +using ::mediapipe::api2::builder::GetImageSize; +using ::mediapipe::api2::builder::GetLoopbackData; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Scale; +using ::mediapipe::api2::builder::ScaleAndMakeSquare; +using ::mediapipe::api2::builder::Stream; + +struct FaceLandmarksResult { + std::optional> landmarks; + std::optional> classifications; +}; + +absl::Status ValidateGraphOptions( + const face_detector::proto::FaceDetectorGraphOptions& + face_detector_graph_options, + const face_landmarker::proto::FaceLandmarksDetectorGraphOptions& + face_landmarks_detector_graph_options, + const HolisticFaceTrackingRequest& request) { + if (face_detector_graph_options.num_faces() != 1) { + return absl::InvalidArgumentError(absl::StrFormat( + "Only support num_faces to be 1, but got num_faces = %d.", + face_detector_graph_options.num_faces())); + } + if (request.classifications && !face_landmarks_detector_graph_options + .has_face_blendshapes_graph_options()) { + return absl::InvalidArgumentError( + "Blendshapes detection is requested, but " + "face_blendshapes_graph_options is not configured."); + } + return absl::OkStatus(); +} + +Stream GetFaceRoiFromPoseFaceLandmarks( + Stream pose_face_landmarks, + Stream> image_size, Graph& graph) { + Stream detection = + ConvertLandmarksToDetection(pose_face_landmarks, graph); + + // Refer the pose face landmarks indices here: + // https://developers.google.com/mediapipe/solutions/vision/pose_landmarker#pose_landmarker_model + Stream rect = ConvertDetectionToRect( + detection, image_size, /*start_keypoint_index=*/5, + /*end_keypoint_index=*/2, /*target_angle=*/0, graph); + + // Scale the face RoI from a tight rect enclosing the pose face landmarks, to + // a larger square so that the whole face is within the RoI. 
+ return ScaleAndMakeSquare(rect, image_size, + /*scale_x_factor=*/3.0, + /*scale_y_factor=*/3.0, graph); +} + +Stream GetFaceRoiFromFaceLandmarks( + Stream face_landmarks, + Stream> image_size, Graph& graph) { + Stream detection = + ConvertLandmarksToDetection(face_landmarks, graph); + + Stream rect = ConvertDetectionToRect( + detection, image_size, /*start_keypoint_index=*/33, + /*end_keypoint_index=*/263, /*target_angle=*/0, graph); + + return Scale(rect, image_size, + /*scale_x_factor=*/1.5, + /*scale_y_factor=*/1.5, graph); +} + +Stream> GetFaceDetections( + Stream image, Stream roi, + const face_detector::proto::FaceDetectorGraphOptions& + face_detector_graph_options, + Graph& graph) { + auto& face_detector_graph = + graph.AddNode("mediapipe.tasks.vision.face_detector.FaceDetectorGraph"); + face_detector_graph + .GetOptions() = + face_detector_graph_options; + image >> face_detector_graph.In("IMAGE"); + roi >> face_detector_graph.In("NORM_RECT"); + return face_detector_graph.Out("DETECTIONS").Cast>(); +} + +Stream GetFaceRoiFromFaceDetections( + Stream> face_detections, + Stream> image_size, Graph& graph) { + // Convert detection to rect. + Stream rect = ConvertDetectionsToRectUsingKeypoints( + face_detections, image_size, /*start_keypoint_index=*/0, + /*end_keypoint_index=*/1, /*target_angle=*/0, graph); + + return ScaleAndMakeSquare(rect, image_size, + /*scale_x_factor=*/2.0, + /*scale_y_factor=*/2.0, graph); +} + +Stream TrackFaceRoi( + Stream prev_landmarks, Stream roi, + Stream> image_size, Graph& graph) { + // Gets face ROI from previous frame face landmarks. 
+ Stream prev_roi = + GetFaceRoiFromFaceLandmarks(prev_landmarks, image_size, graph); + + auto& tracking_node = graph.AddNode("RoiTrackingCalculator"); + auto& tracking_node_opts = + tracking_node.GetOptions(); + auto* rect_requirements = tracking_node_opts.mutable_rect_requirements(); + rect_requirements->set_rotation_degrees(15.0); + rect_requirements->set_translation(0.1); + rect_requirements->set_scale(0.3); + auto* landmarks_requirements = + tracking_node_opts.mutable_landmarks_requirements(); + landmarks_requirements->set_recrop_rect_margin(-0.2); + prev_landmarks.ConnectTo(tracking_node.In("PREV_LANDMARKS")); + prev_roi.ConnectTo(tracking_node.In("PREV_LANDMARKS_RECT")); + roi.ConnectTo(tracking_node.In("RECROP_RECT")); + image_size.ConnectTo(tracking_node.In("IMAGE_SIZE")); + return tracking_node.Out("TRACKING_RECT").Cast(); +} + +FaceLandmarksResult GetFaceLandmarksDetection( + Stream image, Stream roi, + Stream> image_size, + const face_landmarker::proto::FaceLandmarksDetectorGraphOptions& + face_landmarks_detector_graph_options, + const HolisticFaceTrackingRequest& request, Graph& graph) { + FaceLandmarksResult result; + auto& face_landmarks_detector_graph = graph.AddNode( + "mediapipe.tasks.vision.face_landmarker." 
+ "SingleFaceLandmarksDetectorGraph"); + face_landmarks_detector_graph + .GetOptions() = + face_landmarks_detector_graph_options; + image >> face_landmarks_detector_graph.In("IMAGE"); + roi >> face_landmarks_detector_graph.In("NORM_RECT"); + auto landmarks = face_landmarks_detector_graph.Out("NORM_LANDMARKS") + .Cast(); + result.landmarks = landmarks; + if (request.classifications) { + auto& blendshapes_graph = graph.AddNode( + "mediapipe.tasks.vision.face_landmarker.FaceBlendshapesGraph"); + blendshapes_graph + .GetOptions() = + face_landmarks_detector_graph_options.face_blendshapes_graph_options(); + landmarks >> blendshapes_graph.In("LANDMARKS"); + image_size >> blendshapes_graph.In("IMAGE_SIZE"); + result.classifications = + blendshapes_graph.Out("BLENDSHAPES").Cast(); + } + return result; +} + +} // namespace + +absl::StatusOr TrackHolisticFace( + Stream image, Stream pose_face_landmarks, + const face_detector::proto::FaceDetectorGraphOptions& + face_detector_graph_options, + const face_landmarker::proto::FaceLandmarksDetectorGraphOptions& + face_landmarks_detector_graph_options, + const HolisticFaceTrackingRequest& request, Graph& graph) { + MP_RETURN_IF_ERROR(ValidateGraphOptions(face_detector_graph_options, + face_landmarks_detector_graph_options, + request)); + + // Extracts image size from the input images. + Stream> image_size = GetImageSize(image, graph); + + // Gets face ROI from pose face landmarks. + Stream roi_from_pose = + GetFaceRoiFromPoseFaceLandmarks(pose_face_landmarks, image_size, graph); + + // Detects faces within ROI of pose face. + Stream> face_detections = GetFaceDetections( + image, roi_from_pose, face_detector_graph_options, graph); + + // Gets face ROI from face detector. + Stream roi_from_detection = + GetFaceRoiFromFaceDetections(face_detections, image_size, graph); + + // Loop for previous frame landmarks. + auto [prev_landmarks, set_prev_landmarks_fn] = + GetLoopbackData(/*tick=*/image_size, graph); + + // Tracks face ROI. 
+ auto tracking_roi = + TrackFaceRoi(prev_landmarks, roi_from_detection, image_size, graph); + + // Predicts face landmarks. + auto landmarks_detection_result = GetFaceLandmarksDetection( + image, tracking_roi, image_size, face_landmarks_detector_graph_options, + request, graph); + + // Sets previous landmarks for ROI tracking. + set_prev_landmarks_fn(landmarks_detection_result.landmarks.value()); + + return {{.landmarks = landmarks_detection_result.landmarks, + .classifications = landmarks_detection_result.classifications, + .debug_output = { + .roi_from_pose = roi_from_pose, + .roi_from_detection = roi_from_detection, + .tracking_roi = tracking_roi, + }}}; +} + +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h new file mode 100644 index 000000000..835767ebc --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h @@ -0,0 +1,89 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_FACE_TRACKING_H_ +#define MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_FACE_TRACKING_H_ + +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace holistic_landmarker { + +struct HolisticFaceTrackingRequest { + bool classifications = false; +}; + +struct HolisticFaceTrackingOutput { + std::optional> + landmarks; + std::optional> + classifications; + + struct DebugOutput { + api2::builder::Stream roi_from_pose; + api2::builder::Stream roi_from_detection; + api2::builder::Stream tracking_roi; + }; + + DebugOutput debug_output; +}; + +// Updates @graph to track a single face in @image based on pose landmarks. +// +// To track single face this subgraph uses pose face landmarks to obtain +// approximate face location, refines it with face detector model and then runs +// face landmarks model. It can also reuse face ROI from the previous frame if +// face hasn't moved too much. +// +// @image - Image to track a single face in. +// @pose_face_landmarks - Pose face landmarks to derive initial face location +// from. +// @face_detector_graph_options - face detector graph options used to detect the +// face within the RoI constructed from the pose face landmarks. 
+// @face_landmarks_detector_graph_options - face landmarks detector graph +// options used to detect face landmarks within the RoI given by the face +// detector graph. +// @request - object to request specific face tracking outputs. +// NOTE: Outputs that were not requested won't be returned and corresponding +// parts of the graph won't be generated. +// @graph - graph to update. +absl::StatusOr TrackHolisticFace( + api2::builder::Stream image, + api2::builder::Stream + pose_face_landmarks, + const face_detector::proto::FaceDetectorGraphOptions& + face_detector_graph_options, + const face_landmarker::proto::FaceLandmarksDetectorGraphOptions& + face_landmarks_detector_graph_options, + const HolisticFaceTrackingRequest& request, + mediapipe::api2::builder::Graph& graph); + +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_FACE_TRACKING_H_ diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking_test.cc b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking_test.cc new file mode 100644 index 000000000..314c330b3 --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking_test.cc @@ -0,0 +1,227 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h" + +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h" +#include "mediapipe/calculators/util/rect_to_render_data_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/stream/image_size.h" +#include "mediapipe/framework/api2/stream/split.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/test_util.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_connections.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.pb.h" +#include "mediapipe/tasks/cc/vision/utils/data_renderer.h" +#include 
"mediapipe/tasks/cc/vision/utils/image_utils.h" +#include "mediapipe/util/color.pb.h" +#include "mediapipe/util/render_data.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace holistic_landmarker { + +namespace { + +using ::mediapipe::Image; +using ::mediapipe::api2::builder::GetImageSize; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::SplitToRanges; +using ::mediapipe::api2::builder::Stream; +using ::mediapipe::tasks::core::ModelAssetBundleResources; +using ::mediapipe::tasks::core::TaskRunner; +using ::mediapipe::tasks::core::proto::ExternalFile; +using ::testing::proto::Approximately; +using ::testing::proto::Partially; + +constexpr float kAbsMargin = 0.015; +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kTestImageFile[] = "male_full_height_hands.jpg"; +constexpr char kHolisticResultFile[] = + "male_full_height_hands_result_cpu.pbtxt"; +constexpr char kImageInStream[] = "image_in"; +constexpr char kPoseLandmarksInStream[] = "pose_landmarks_in"; +constexpr char kFaceLandmarksOutStream[] = "face_landmarks_out"; +constexpr char kRenderedImageOutStream[] = "rendered_image_out"; +constexpr char kFaceDetectorTFLiteName[] = "face_detector.tflite"; +constexpr char kFaceLandmarksDetectorTFLiteName[] = + "face_landmarks_detector.tflite"; + +std::string GetFilePath(absl::string_view filename) { + return file::JoinPath("./", kTestDataDirectory, filename); +} + +mediapipe::LandmarksToRenderDataCalculatorOptions GetFaceRendererOptions() { + mediapipe::LandmarksToRenderDataCalculatorOptions render_options; + for (const auto& connection : + face_landmarker::FaceLandmarksConnections::kFaceLandmarksConnectors) { + render_options.add_landmark_connections(connection[0]); + render_options.add_landmark_connections(connection[1]); + } + render_options.mutable_landmark_color()->set_r(255); + render_options.mutable_landmark_color()->set_g(255); + 
render_options.mutable_landmark_color()->set_b(255); + render_options.mutable_connection_color()->set_r(255); + render_options.mutable_connection_color()->set_g(255); + render_options.mutable_connection_color()->set_b(255); + render_options.set_thickness(0.5); + render_options.set_visualize_landmark_depth(false); + return render_options; +} + +absl::StatusOr> +CreateModelAssetBundleResources(const std::string& model_asset_filename) { + auto external_model_bundle = std::make_unique(); + external_model_bundle->set_file_name(model_asset_filename); + return ModelAssetBundleResources::Create("", + std::move(external_model_bundle)); +} + +// Helper function to create a TaskRunner. +absl::StatusOr> CreateTaskRunner() { + Graph graph; + Stream image = graph.In("IMAGE").Cast().SetName(kImageInStream); + Stream pose_landmarks = + graph.In("POSE_LANDMARKS") + .Cast() + .SetName(kPoseLandmarksInStream); + Stream face_landmarks_from_pose = + SplitToRanges(pose_landmarks, {{0, 11}}, graph)[0]; + // Create face landmarker model bundle. + MP_ASSIGN_OR_RETURN( + auto model_bundle, + CreateModelAssetBundleResources(GetFilePath("face_landmarker_v2.task"))); + face_detector::proto::FaceDetectorGraphOptions detector_options; + face_landmarker::proto::FaceLandmarksDetectorGraphOptions + landmarks_detector_options; + + // Set face detection model. + MP_ASSIGN_OR_RETURN(auto face_detector_model_file, + model_bundle->GetFile(kFaceDetectorTFLiteName)); + core::proto::FilePointerMeta face_detection_file_pointer; + face_detection_file_pointer.set_pointer( + reinterpret_cast(face_detector_model_file.data())); + face_detection_file_pointer.set_length(face_detector_model_file.size()); + detector_options.mutable_base_options() + ->mutable_model_asset() + ->mutable_file_pointer_meta() + ->Swap(&face_detection_file_pointer); + detector_options.set_num_faces(1); + + // Set face landmarks model. 
+ MP_ASSIGN_OR_RETURN(auto face_landmarks_model_file, + model_bundle->GetFile(kFaceLandmarksDetectorTFLiteName)); + core::proto::FilePointerMeta face_landmarks_detector_file_pointer; + face_landmarks_detector_file_pointer.set_pointer( + reinterpret_cast(face_landmarks_model_file.data())); + face_landmarks_detector_file_pointer.set_length( + face_landmarks_model_file.size()); + landmarks_detector_options.mutable_base_options() + ->mutable_model_asset() + ->mutable_file_pointer_meta() + ->Swap(&face_landmarks_detector_file_pointer); + + // Track holistic face. + HolisticFaceTrackingRequest request; + MP_ASSIGN_OR_RETURN( + HolisticFaceTrackingOutput result, + TrackHolisticFace(image, face_landmarks_from_pose, detector_options, + landmarks_detector_options, request, graph)); + auto face_landmarks = + result.landmarks.value().SetName(kFaceLandmarksOutStream); + + auto image_size = GetImageSize(image, graph); + auto render_scale = utils::GetRenderScale( + image_size, result.debug_output.roi_from_pose, 0.0001, graph); + + auto face_landmarks_render_data = utils::RenderLandmarks( + face_landmarks, render_scale, GetFaceRendererOptions(), graph); + std::vector> render_list = { + face_landmarks_render_data}; + + auto rendered_image = + utils::Render( + image, absl::Span>(render_list), graph) + .SetName(kRenderedImageOutStream); + face_landmarks >> graph.Out("FACE_LANDMARKS"); + rendered_image >> graph.Out("RENDERED_IMAGE"); + + auto config = graph.GetConfig(); + core::FixGraphBackEdges(config); + return TaskRunner::Create( + config, std::make_unique()); +} + +class HolisticFaceTrackingTest : public ::testing::Test {}; + +TEST_F(HolisticFaceTrackingTest, SmokeTest) { + MP_ASSERT_OK_AND_ASSIGN(Image image, + DecodeImageFromFile(GetFilePath(kTestImageFile))); + + proto::HolisticResult holistic_result; + MP_ASSERT_OK(GetTextProto(GetFilePath(kHolisticResultFile), &holistic_result, + ::file::Defaults())); + MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner()); + 
MP_ASSERT_OK_AND_ASSIGN( + auto output_packets, + task_runner->Process( + {{kImageInStream, MakePacket(image)}, + {kPoseLandmarksInStream, MakePacket( + holistic_result.pose_landmarks())}})); + ASSERT_TRUE(output_packets.find(kFaceLandmarksOutStream) != + output_packets.end()); + auto face_landmarks = output_packets.find(kFaceLandmarksOutStream) + ->second.Get(); + EXPECT_THAT( + face_landmarks, + Approximately(Partially(EqualsProto(holistic_result.face_landmarks())), + /*margin=*/kAbsMargin)); + auto rendered_image = output_packets.at(kRenderedImageOutStream).Get(); + MP_EXPECT_OK(SavePngTestOutput(*rendered_image.GetImageFrameSharedPtr(), + "holistic_face_landmarks")); +} + +} // namespace +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.cc b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.cc new file mode 100644 index 000000000..2c57aa059 --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.cc @@ -0,0 +1,272 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/calculators/util/align_hand_to_pose_in_world_calculator.h" +#include "mediapipe/calculators/util/align_hand_to_pose_in_world_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/stream/image_size.h" +#include "mediapipe/framework/api2/stream/landmarks_to_detection.h" +#include "mediapipe/framework/api2/stream/loopback.h" +#include "mediapipe/framework/api2/stream/rect_transformation.h" +#include "mediapipe/framework/api2/stream/split.h" +#include "mediapipe/framework/api2/stream/threshold.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.pb.h" +#include "mediapipe/tasks/cc/components/utils/gate.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace holistic_landmarker { + +namespace { + +using ::mediapipe::NormalizedRect; +using ::mediapipe::api2::AlignHandToPoseInWorldCalculator; +using ::mediapipe::api2::builder::ConvertLandmarksToDetection; +using ::mediapipe::api2::builder::GetImageSize; +using ::mediapipe::api2::builder::GetLoopbackData; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::IsOverThreshold; +using ::mediapipe::api2::builder::ScaleAndShiftAndMakeSquareLong; +using ::mediapipe::api2::builder::SplitAndCombine; +using ::mediapipe::api2::builder::Stream; +using 
::mediapipe::tasks::components::utils::AllowIf; + +struct HandLandmarksResult { + std::optional> landmarks; + std::optional> world_landmarks; +}; + +Stream AlignHandToPoseInWorldCalculator( + Stream hand_world_landmarks, + Stream pose_world_landmarks, int pose_wrist_idx, + Graph& graph) { + auto& node = graph.AddNode("AlignHandToPoseInWorldCalculator"); + auto& opts = node.GetOptions(); + opts.set_hand_wrist_idx(0); + opts.set_pose_wrist_idx(pose_wrist_idx); + hand_world_landmarks.ConnectTo( + node[AlignHandToPoseInWorldCalculator::kInHandLandmarks]); + pose_world_landmarks.ConnectTo( + node[AlignHandToPoseInWorldCalculator::kInPoseLandmarks]); + return node[AlignHandToPoseInWorldCalculator::kOutHandLandmarks]; +} + +Stream GetPosePalmVisibility( + Stream pose_palm_landmarks, Graph& graph) { + // Get wrist landmark. + auto pose_wrist = SplitAndCombine(pose_palm_landmarks, {0}, graph); + + // Get visibility score. + auto& score_node = graph.AddNode("LandmarkVisibilityCalculator"); + pose_wrist.ConnectTo(score_node.In("NORM_LANDMARKS")); + Stream score = score_node.Out("VISIBILITY").Cast(); + + // Convert score into flag. + return IsOverThreshold(score, /*threshold=*/0.1, graph); +} + +Stream GetHandRoiFromPosePalmLandmarks( + Stream pose_palm_landmarks, + Stream> image_size, Graph& graph) { + // Convert pose palm landmarks to detection. + auto detection = ConvertLandmarksToDetection(pose_palm_landmarks, graph); + + // Convert detection to rect. 
+ auto& rect_node = graph.AddNode("HandDetectionsFromPoseToRectsCalculator"); + detection.ConnectTo(rect_node.In("DETECTION")); + image_size.ConnectTo(rect_node.In("IMAGE_SIZE")); + Stream rect = + rect_node.Out("NORM_RECT").Cast(); + + return ScaleAndShiftAndMakeSquareLong(rect, image_size, + /*scale_x_factor=*/2.7, + /*scale_y_factor=*/2.7, /*shift_x=*/0, + /*shift_y=*/-0.1, graph); +} + +absl::StatusOr> RefineHandRoi( + Stream image, Stream roi, + const hand_landmarker::proto::HandRoiRefinementGraphOptions& + hand_roi_refinenement_graph_options, + Graph& graph) { + auto& hand_roi_refinement = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker.HandRoiRefinementGraph"); + hand_roi_refinement + .GetOptions() = + hand_roi_refinenement_graph_options; + image >> hand_roi_refinement.In("IMAGE"); + roi >> hand_roi_refinement.In("NORM_RECT"); + return hand_roi_refinement.Out("NORM_RECT").Cast(); +} + +Stream TrackHandRoi( + Stream prev_landmarks, Stream roi, + Stream> image_size, Graph& graph) { + // Convert hand landmarks to tight rect. + auto& prev_rect_node = graph.AddNode("HandLandmarksToRectCalculator"); + prev_landmarks.ConnectTo(prev_rect_node.In("NORM_LANDMARKS")); + image_size.ConnectTo(prev_rect_node.In("IMAGE_SIZE")); + Stream prev_rect = + prev_rect_node.Out("NORM_RECT").Cast(); + + // Convert tight hand rect to hand roi. 
+ Stream prev_roi = + ScaleAndShiftAndMakeSquareLong(prev_rect, image_size, + /*scale_x_factor=*/2.0, + /*scale_y_factor=*/2.0, /*shift_x=*/0, + /*shift_y=*/-0.1, graph); + + auto& tracking_node = graph.AddNode("RoiTrackingCalculator"); + auto& tracking_node_opts = + tracking_node.GetOptions(); + auto* rect_requirements = tracking_node_opts.mutable_rect_requirements(); + rect_requirements->set_rotation_degrees(40.0); + rect_requirements->set_translation(0.2); + rect_requirements->set_scale(0.4); + auto* landmarks_requirements = + tracking_node_opts.mutable_landmarks_requirements(); + landmarks_requirements->set_recrop_rect_margin(-0.1); + prev_landmarks.ConnectTo(tracking_node.In("PREV_LANDMARKS")); + prev_roi.ConnectTo(tracking_node.In("PREV_LANDMARKS_RECT")); + roi.ConnectTo(tracking_node.In("RECROP_RECT")); + image_size.ConnectTo(tracking_node.In("IMAGE_SIZE")); + return tracking_node.Out("TRACKING_RECT").Cast(); +} + +HandLandmarksResult GetHandLandmarksDetection( + Stream image, Stream roi, + const hand_landmarker::proto::HandLandmarksDetectorGraphOptions& + hand_landmarks_detector_graph_options, + const HolisticHandTrackingRequest& request, Graph& graph) { + HandLandmarksResult result; + auto& hand_landmarks_detector_graph = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker." 
+ "SingleHandLandmarksDetectorGraph"); + hand_landmarks_detector_graph + .GetOptions() = + hand_landmarks_detector_graph_options; + + image >> hand_landmarks_detector_graph.In("IMAGE"); + roi >> hand_landmarks_detector_graph.In("HAND_RECT"); + + if (request.landmarks) { + result.landmarks = hand_landmarks_detector_graph.Out("LANDMARKS") + .Cast(); + } + if (request.world_landmarks) { + result.world_landmarks = + hand_landmarks_detector_graph.Out("WORLD_LANDMARKS") + .Cast(); + } + return result; +} + +} // namespace + +absl::StatusOr TrackHolisticHand( + Stream image, Stream pose_landmarks, + Stream pose_world_landmarks, + const hand_landmarker::proto::HandLandmarksDetectorGraphOptions& + hand_landmarks_detector_graph_options, + const hand_landmarker::proto::HandRoiRefinementGraphOptions& + hand_roi_refinement_graph_options, + const PoseIndices& pose_indices, const HolisticHandTrackingRequest& request, + Graph& graph) { + // Extracts pose palm landmarks. + Stream pose_palm_landmarks = SplitAndCombine( + pose_landmarks, + {pose_indices.wrist_idx, pose_indices.pinky_idx, pose_indices.index_idx}, + graph); + + // Get pose palm visibility. + Stream is_pose_palm_visible = + GetPosePalmVisibility(pose_palm_landmarks, graph); + + // Drop pose palm landmarks if pose palm is invisible. + pose_palm_landmarks = + AllowIf(pose_palm_landmarks, is_pose_palm_visible, graph); + + // Extracts image size from the input images. + Stream> image_size = GetImageSize(image, graph); + + // Get hand ROI from pose palm landmarks. + Stream roi_from_pose = + GetHandRoiFromPosePalmLandmarks(pose_palm_landmarks, image_size, graph); + + // Refine hand ROI with re-crop model. + MP_ASSIGN_OR_RETURN(Stream roi_from_recrop, + RefineHandRoi(image, roi_from_pose, + hand_roi_refinement_graph_options, graph)); + + // Loop for previous frame landmarks. + auto [prev_landmarks, set_prev_landmarks_fn] = + GetLoopbackData(/*tick=*/image_size, graph); + + // Track hand ROI. 
+ auto tracking_roi = + TrackHandRoi(prev_landmarks, roi_from_recrop, image_size, graph); + + // Predict hand landmarks. + auto landmarks_detection_result = GetHandLandmarksDetection( + image, tracking_roi, hand_landmarks_detector_graph_options, request, + graph); + + // Set previous landmarks for ROI tracking. + set_prev_landmarks_fn(landmarks_detection_result.landmarks.value()); + + // Output landmarks. + std::optional> hand_landmarks; + if (request.landmarks) { + hand_landmarks = landmarks_detection_result.landmarks; + } + + // Output world landmarks. + std::optional> hand_world_landmarks; + if (request.world_landmarks) { + hand_world_landmarks = landmarks_detection_result.world_landmarks; + + // Align hand world landmarks with pose world landmarks. + hand_world_landmarks = AlignHandToPoseInWorldCalculator( + hand_world_landmarks.value(), pose_world_landmarks, + pose_indices.wrist_idx, graph); + } + + return {{.landmarks = hand_landmarks, + .world_landmarks = hand_world_landmarks, + .debug_output = { + .roi_from_pose = roi_from_pose, + .roi_from_recrop = roi_from_recrop, + .tracking_roi = tracking_roi, + }}}; +} + +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h new file mode 100644 index 000000000..463f4979b --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h @@ -0,0 +1,94 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_HAND_TRACKING_H_ +#define MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_HAND_TRACKING_H_ + +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace holistic_landmarker { + +struct PoseIndices { + int wrist_idx; + int pinky_idx; + int index_idx; +}; + +struct HolisticHandTrackingRequest { + bool landmarks = false; + bool world_landmarks = false; +}; + +struct HolisticHandTrackingOutput { + std::optional> + landmarks; + std::optional> world_landmarks; + + struct DebugOutput { + api2::builder::Stream roi_from_pose; + api2::builder::Stream roi_from_recrop; + api2::builder::Stream tracking_roi; + }; + + DebugOutput debug_output; +}; + +// Updates @graph to track a single hand in @image based on pose landmarks. +// +// To track single hand this subgraph uses pose palm landmarks to obtain +// approximate hand location, refines it with re-crop model and then runs hand +// landmarks model. 
It can also reuse hand ROI from the previous frame if hand +// hasn't moved too much. +// +// @image - ImageFrame/GpuBuffer to track a single hand in. +// @pose_landmarks - Pose landmarks to derive initial hand location from. +// @pose_world_landmarks - Pose world landmarks to align hand world landmarks +// wrist with. +// @hand_landmarks_detector_graph_options - Options of the +// HandLandmarksDetectorGraph used to detect the hand landmarks. +// @hand_roi_refinement_graph_options - Options of HandRoiRefinementGraph used +// to refine the hand RoIs obtained from pose landmarks. +// @request - object to request specific hand tracking outputs. +// NOTE: Outputs that were not requested won't be returned and corresponding +// parts of the graph won't be generated. +// @graph - graph to update. +absl::StatusOr TrackHolisticHand( + api2::builder::Stream image, + api2::builder::Stream pose_landmarks, + api2::builder::Stream pose_world_landmarks, + const hand_landmarker::proto::HandLandmarksDetectorGraphOptions& + hand_landmarks_detector_graph_options, + const hand_landmarker::proto::HandRoiRefinementGraphOptions& + hand_roi_refinement_graph_options, + const PoseIndices& pose_indices, const HolisticHandTrackingRequest& request, + mediapipe::api2::builder::Graph& graph); + +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_HAND_TRACKING_H_ diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking_test.cc b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking_test.cc new file mode 100644 index 000000000..4ae4a37ed --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking_test.cc @@ -0,0 +1,303 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h" + +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/substitute.h" +#include "absl/types/span.h" +#include "file/base/helpers.h" +#include "file/base/options.h" +#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/stream/image_size.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/framework/tool/test_util.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_connections.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" +#include 
"mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.pb.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_topology.h" +#include "mediapipe/tasks/cc/vision/utils/data_renderer.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" +#include "mediapipe/util/color.pb.h" +#include "mediapipe/util/render_data.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace holistic_landmarker { + +namespace { + +using ::file::Defaults; +using ::file::GetTextProto; +using ::mediapipe::Image; +using ::mediapipe::api2::builder::GetImageSize; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Stream; +using ::mediapipe::tasks::core::TaskRunner; +using ::testing::proto::Approximately; +using ::testing::proto::Partially; + +constexpr float kAbsMargin = 0.018; +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kHolisticHandTrackingLeft[] = + "holistic_hand_tracking_left_hand_graph.pbtxt"; +constexpr char kTestImageFile[] = "male_full_height_hands.jpg"; +constexpr char kHolisticResultFile[] = + "male_full_height_hands_result_cpu.pbtxt"; +constexpr char kImageInStream[] = "image_in"; +constexpr char kPoseLandmarksInStream[] = "pose_landmarks_in"; +constexpr char kPoseWorldLandmarksInStream[] = "pose_world_landmarks_in"; +constexpr char kLeftHandLandmarksOutStream[] = "left_hand_landmarks_out"; +constexpr char kLeftHandWorldLandmarksOutStream[] = + "left_hand_world_landmarks_out"; +constexpr char kRightHandLandmarksOutStream[] = "right_hand_landmarks_out"; +constexpr char kRenderedImageOutStream[] = "rendered_image_out"; +constexpr char kHandLandmarksModelFile[] = "hand_landmark_full.tflite"; +constexpr char kHandRoiRefinementModelFile[] = + "handrecrop_2020_07_21_v0.f16.tflite"; + +std::string GetFilePath(const std::string& filename) { + return file::JoinPath("./", 
kTestDataDirectory, filename); +} + +mediapipe::LandmarksToRenderDataCalculatorOptions GetHandRendererOptions() { + mediapipe::LandmarksToRenderDataCalculatorOptions renderer_options; + for (const auto& connection : hand_landmarker::kHandConnections) { + renderer_options.add_landmark_connections(connection[0]); + renderer_options.add_landmark_connections(connection[1]); + } + renderer_options.mutable_landmark_color()->set_r(255); + renderer_options.mutable_landmark_color()->set_g(255); + renderer_options.mutable_landmark_color()->set_b(255); + renderer_options.mutable_connection_color()->set_r(255); + renderer_options.mutable_connection_color()->set_g(255); + renderer_options.mutable_connection_color()->set_b(255); + renderer_options.set_thickness(0.5); + renderer_options.set_visualize_landmark_depth(false); + return renderer_options; +} + +void ConfigHandTrackingModelsOptions( + hand_landmarker::proto::HandLandmarksDetectorGraphOptions& + hand_landmarks_detector_graph_options, + hand_landmarker::proto::HandRoiRefinementGraphOptions& + hand_roi_refinement_options) { + hand_landmarks_detector_graph_options.mutable_base_options() + ->mutable_model_asset() + ->set_file_name(GetFilePath(kHandLandmarksModelFile)); + + hand_roi_refinement_options.mutable_base_options() + ->mutable_model_asset() + ->set_file_name(GetFilePath(kHandRoiRefinementModelFile)); +} + +// Helper function to create a TaskRunner. 
+absl::StatusOr> CreateTaskRunner() { + Graph graph; + Stream image = graph.In("IMAGE").Cast().SetName(kImageInStream); + Stream pose_landmarks = + graph.In("POSE_LANDMARKS") + .Cast() + .SetName(kPoseLandmarksInStream); + Stream pose_world_landmarks = + graph.In("POSE_WORLD_LANDMARKS") + .Cast() + .SetName(kPoseWorldLandmarksInStream); + hand_landmarker::proto::HandLandmarksDetectorGraphOptions + hand_landmarks_detector_options; + hand_landmarker::proto::HandRoiRefinementGraphOptions + hand_roi_refinement_options; + ConfigHandTrackingModelsOptions(hand_landmarks_detector_options, + hand_roi_refinement_options); + HolisticHandTrackingRequest request; + request.landmarks = true; + MP_ASSIGN_OR_RETURN( + HolisticHandTrackingOutput left_hand_result, + TrackHolisticHand( + image, pose_landmarks, pose_world_landmarks, + hand_landmarks_detector_options, hand_roi_refinement_options, + PoseIndices{ + /*wrist_idx=*/static_cast( + pose_landmarker::PoseLandmarkName::kLeftWrist), + /*pinky_idx=*/ + static_cast(pose_landmarker::PoseLandmarkName::kLeftPinky1), + /*index_idx=*/ + static_cast(pose_landmarker::PoseLandmarkName::kLeftIndex1)}, + request, graph)); + MP_ASSIGN_OR_RETURN( + HolisticHandTrackingOutput right_hand_result, + TrackHolisticHand( + image, pose_landmarks, pose_world_landmarks, + hand_landmarks_detector_options, hand_roi_refinement_options, + PoseIndices{ + /*wrist_idx=*/static_cast( + pose_landmarker::PoseLandmarkName::kRightWrist), + /*pinky_idx=*/ + static_cast(pose_landmarker::PoseLandmarkName::kRightPinky1), + /*index_idx=*/ + static_cast( + pose_landmarker::PoseLandmarkName::kRightIndex1)}, + request, graph)); + + auto image_size = GetImageSize(image, graph); + auto left_hand_landmarks_render_data = utils::RenderLandmarks( + *left_hand_result.landmarks, + utils::GetRenderScale(image_size, + left_hand_result.debug_output.roi_from_pose, 0.0001, + graph), + GetHandRendererOptions(), graph); + auto right_hand_landmarks_render_data = utils::RenderLandmarks( + 
*right_hand_result.landmarks, + utils::GetRenderScale(image_size, + right_hand_result.debug_output.roi_from_pose, + 0.0001, graph), + GetHandRendererOptions(), graph); + std::vector> render_list = { + left_hand_landmarks_render_data, right_hand_landmarks_render_data}; + auto rendered_image = + utils::Render( + image, absl::Span>(render_list), graph) + .SetName(kRenderedImageOutStream); + left_hand_result.landmarks->SetName(kLeftHandLandmarksOutStream) >> + graph.Out("LEFT_HAND_LANDMARKS"); + right_hand_result.landmarks->SetName(kRightHandLandmarksOutStream) >> + graph.Out("RIGHT_HAND_LANDMARKS"); + rendered_image >> graph.Out("RENDERED_IMAGE"); + + auto config = graph.GetConfig(); + core::FixGraphBackEdges(config); + + return TaskRunner::Create( + config, std::make_unique()); +} + +class HolisticHandTrackingTest : public ::testing::Test {}; + +TEST_F(HolisticHandTrackingTest, VerifyGraph) { + Graph graph; + Stream image = graph.In("IMAGE").Cast().SetName(kImageInStream); + Stream pose_landmarks = + graph.In("POSE_LANDMARKS") + .Cast() + .SetName(kPoseLandmarksInStream); + Stream pose_world_landmarks = + graph.In("POSE_WORLD_LANDMARKS") + .Cast() + .SetName(kPoseWorldLandmarksInStream); + hand_landmarker::proto::HandLandmarksDetectorGraphOptions + hand_landmarks_detector_options; + hand_landmarker::proto::HandRoiRefinementGraphOptions + hand_roi_refinement_options; + ConfigHandTrackingModelsOptions(hand_landmarks_detector_options, + hand_roi_refinement_options); + HolisticHandTrackingRequest request; + request.landmarks = true; + request.world_landmarks = true; + MP_ASSERT_OK_AND_ASSIGN( + HolisticHandTrackingOutput left_hand_result, + TrackHolisticHand( + image, pose_landmarks, pose_world_landmarks, + hand_landmarks_detector_options, hand_roi_refinement_options, + PoseIndices{ + /*wrist_idx=*/static_cast( + pose_landmarker::PoseLandmarkName::kLeftWrist), + /*pinky_idx=*/ + static_cast(pose_landmarker::PoseLandmarkName::kLeftPinky1), + /*index_idx=*/ + 
static_cast(pose_landmarker::PoseLandmarkName::kLeftIndex1)}, + request, graph)); + left_hand_result.landmarks->SetName(kLeftHandLandmarksOutStream) >> + graph.Out("LEFT_HAND_LANDMARKS"); + left_hand_result.world_landmarks->SetName(kLeftHandWorldLandmarksOutStream) >> + graph.Out("LEFT_HAND_WORLD_LANDMARKS"); + + // Read the expected graph config. + std::string expected_graph_contents; + MP_ASSERT_OK(file::GetContents( + file::JoinPath("./", kTestDataDirectory, kHolisticHandTrackingLeft), + &expected_graph_contents)); + + // Need to replace the expected graph config with the test srcdir, because + // each run has different test dir on TAP. + expected_graph_contents = absl::Substitute( + expected_graph_contents, FLAGS_test_srcdir, FLAGS_test_srcdir); + CalculatorGraphConfig expected_graph = + ParseTextProtoOrDie(expected_graph_contents); + + EXPECT_THAT(graph.GetConfig(), testing::proto::IgnoringRepeatedFieldOrdering( + testing::EqualsProto(expected_graph))); +} + +TEST_F(HolisticHandTrackingTest, SmokeTest) { + MP_ASSERT_OK_AND_ASSIGN(Image image, + DecodeImageFromFile(GetFilePath(kTestImageFile))); + + proto::HolisticResult holistic_result; + MP_ASSERT_OK(GetTextProto(GetFilePath(kHolisticResultFile), &holistic_result, + Defaults())); + MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner()); + MP_ASSERT_OK_AND_ASSIGN( + auto output_packets, + task_runner->Process( + {{kImageInStream, MakePacket(image)}, + {kPoseLandmarksInStream, MakePacket( + holistic_result.pose_landmarks())}, + {kPoseWorldLandmarksInStream, + MakePacket( + holistic_result.pose_world_landmarks())}})); + auto left_hand_landmarks = output_packets.at(kLeftHandLandmarksOutStream) + .Get(); + auto right_hand_landmarks = output_packets.at(kRightHandLandmarksOutStream) + .Get(); + EXPECT_THAT(left_hand_landmarks, + Approximately( + Partially(EqualsProto(holistic_result.left_hand_landmarks())), + /*margin=*/kAbsMargin)); + EXPECT_THAT( + right_hand_landmarks, + Approximately( + 
Partially(EqualsProto(holistic_result.right_hand_landmarks())), + /*margin=*/kAbsMargin)); + auto rendered_image = output_packets.at(kRenderedImageOutStream).Get(); + MP_EXPECT_OK(SavePngTestOutput(*rendered_image.GetImageFrameSharedPtr(), + "holistic_hand_landmarks")); +} + +} // namespace +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_landmarker_graph.cc b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_landmarker_graph.cc new file mode 100644 index 000000000..2de358a6c --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_landmarker_graph.cc @@ -0,0 +1,521 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/stream/split.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" +#include "mediapipe/tasks/cc/core/model_resources_cache.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" +#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h" +#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h" +#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h" +#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_topology.h" +#include 
"mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h" +#include "mediapipe/util/graph_builder_utils.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace holistic_landmarker { +namespace { + +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Stream; +using ::mediapipe::tasks::metadata::SetExternalFile; + +constexpr absl::string_view kHandLandmarksDetectorModelName = + "hand_landmarks_detector.tflite"; +constexpr absl::string_view kHandRoiRefinementModelName = + "hand_roi_refinement.tflite"; +constexpr absl::string_view kFaceDetectorModelName = "face_detector.tflite"; +constexpr absl::string_view kFaceLandmarksDetectorModelName = + "face_landmarks_detector.tflite"; +constexpr absl::string_view kFaceBlendshapesModelName = + "face_blendshapes.tflite"; +constexpr absl::string_view kPoseDetectorModelName = "pose_detector.tflite"; +constexpr absl::string_view kPoseLandmarksDetectorModelName = + "pose_landmarks_detector.tflite"; + +absl::Status SetGraphPoseOutputs( + const HolisticPoseTrackingRequest& pose_request, + const CalculatorGraphConfig::Node& node, + HolisticPoseTrackingOutput& pose_output, Graph& graph) { + // Main outputs. + if (pose_request.landmarks) { + RET_CHECK(pose_output.landmarks.has_value()) + << "POSE_LANDMARKS output is not supported."; + pose_output.landmarks->ConnectTo(graph.Out("POSE_LANDMARKS")); + } + if (pose_request.world_landmarks) { + RET_CHECK(pose_output.world_landmarks.has_value()) + << "POSE_WORLD_LANDMARKS output is not supported."; + pose_output.world_landmarks->ConnectTo(graph.Out("POSE_WORLD_LANDMARKS")); + } + if (pose_request.segmentation_mask) { + RET_CHECK(pose_output.segmentation_mask.has_value()) + << "POSE_SEGMENTATION_MASK output is not supported."; + pose_output.segmentation_mask->ConnectTo( + graph.Out("POSE_SEGMENTATION_MASK")); + } + + // Debug outputs. 
+ if (HasOutput(node, "POSE_AUXILIARY_LANDMARKS")) { + pose_output.debug_output.auxiliary_landmarks.ConnectTo( + graph.Out("POSE_AUXILIARY_LANDMARKS")); + } + if (HasOutput(node, "POSE_LANDMARKS_ROI")) { + pose_output.debug_output.roi_from_landmarks.ConnectTo( + graph.Out("POSE_LANDMARKS_ROI")); + } + + return absl::OkStatus(); +} + +// Sets the base options in the sub tasks. +template +absl::Status SetSubTaskBaseOptions( + const core::ModelAssetBundleResources* resources, + proto::HolisticLandmarkerGraphOptions* options, T* sub_task_options, + absl::string_view model_name, bool is_copy) { + if (!sub_task_options->base_options().has_model_asset()) { + MP_ASSIGN_OR_RETURN(const auto model_file_content, + resources->GetFile(std::string(model_name))); + SetExternalFile( + model_file_content, + sub_task_options->mutable_base_options()->mutable_model_asset(), + is_copy); + } + sub_task_options->mutable_base_options()->mutable_acceleration()->CopyFrom( + options->base_options().acceleration()); + sub_task_options->mutable_base_options()->set_use_stream_mode( + options->base_options().use_stream_mode()); + sub_task_options->mutable_base_options()->set_gpu_origin( + options->base_options().gpu_origin()); + return absl::OkStatus(); +} + +void SetGraphHandOutputs(bool is_left, const CalculatorGraphConfig::Node& node, + HolisticHandTrackingOutput& hand_output, + Graph& graph) { + const std::string hand_side = is_left ? "LEFT" : "RIGHT"; + + if (hand_output.landmarks) { + hand_output.landmarks->ConnectTo(graph.Out(hand_side + "_HAND_LANDMARKS")); + } + if (hand_output.world_landmarks) { + hand_output.world_landmarks->ConnectTo( + graph.Out(hand_side + "_HAND_WORLD_LANDMARKS")); + } + + // Debug outputs. 
+ if (HasOutput(node, hand_side + "_HAND_ROI_FROM_POSE")) { + hand_output.debug_output.roi_from_pose.ConnectTo( + graph.Out(hand_side + "_HAND_ROI_FROM_POSE")); + } + if (HasOutput(node, hand_side + "_HAND_ROI_FROM_RECROP")) { + hand_output.debug_output.roi_from_recrop.ConnectTo( + graph.Out(hand_side + "_HAND_ROI_FROM_RECROP")); + } + if (HasOutput(node, hand_side + "_HAND_TRACKING_ROI")) { + hand_output.debug_output.tracking_roi.ConnectTo( + graph.Out(hand_side + "_HAND_TRACKING_ROI")); + } +} + +void SetGraphFaceOutputs(const CalculatorGraphConfig::Node& node, + HolisticFaceTrackingOutput& face_output, + Graph& graph) { + if (face_output.landmarks) { + face_output.landmarks->ConnectTo(graph.Out("FACE_LANDMARKS")); + } + if (face_output.classifications) { + face_output.classifications->ConnectTo(graph.Out("FACE_BLENDSHAPES")); + } + + // Face detection debug outputs + if (HasOutput(node, "FACE_ROI_FROM_POSE")) { + face_output.debug_output.roi_from_pose.ConnectTo( + graph.Out("FACE_ROI_FROM_POSE")); + } + if (HasOutput(node, "FACE_ROI_FROM_DETECTION")) { + face_output.debug_output.roi_from_detection.ConnectTo( + graph.Out("FACE_ROI_FROM_DETECTION")); + } + if (HasOutput(node, "FACE_TRACKING_ROI")) { + face_output.debug_output.tracking_roi.ConnectTo( + graph.Out("FACE_TRACKING_ROI")); + } +} + +} // namespace + +// Tracks pose and detects hands and face. +// +// NOTE: for GPU works only with image having GpuOrigin::TOP_LEFT +// +// Inputs: +// IMAGE - Image +// Image to perform detection on. 
+// +// Outputs: +// POSE_LANDMARKS - NormalizedLandmarkList +// 33 landmarks (see pose_landmarker/pose_topology.h) +// 0 - nose +// 1 - left eye (inner) +// 2 - left eye +// 3 - left eye (outer) +// 4 - right eye (inner) +// 5 - right eye +// 6 - right eye (outer) +// 7 - left ear +// 8 - right ear +// 9 - mouth (left) +// 10 - mouth (right) +// 11 - left shoulder +// 12 - right shoulder +// 13 - left elbow +// 14 - right elbow +// 15 - left wrist +// 16 - right wrist +// 17 - left pinky +// 18 - right pinky +// 19 - left index +// 20 - right index +// 21 - left thumb +// 22 - right thumb +// 23 - left hip +// 24 - right hip +// 25 - left knee +// 26 - right knee +// 27 - left ankle +// 28 - right ankle +// 29 - left heel +// 30 - right heel +// 31 - left foot index +// 32 - right foot index +// POSE_WORLD_LANDMARKS - LandmarkList +// World landmarks are real world 3D coordinates with origin in hips center +// and coordinates in meters. To understand the difference: POSE_LANDMARKS +// stream provides coordinates (in pixels) of 3D object projected on a 2D +// surface of the image (check on how perspective projection works), while +// POSE_WORLD_LANDMARKS stream provides coordinates (in meters) of the 3D +// object itself. POSE_WORLD_LANDMARKS has the same landmarks topology, +// visibility and presence as POSE_LANDMARKS. +// POSE_SEGMENTATION_MASK - Image +// Separates person from background. Mask is stored as gray float32 image +// with [0.0, 1.0] range for pixels (1 for person and 0 for background) on +// CPU and, on GPU - RGBA texture with R channel indicating person vs. +// background probability. +// LEFT_HAND_LANDMARKS - NormalizedLandmarkList +// 21 left hand landmarks. +// RIGHT_HAND_LANDMARKS - NormalizedLandmarkList +// 21 right hand landmarks. +// FACE_LANDMARKS - NormalizedLandmarkList +// 468 face landmarks. +// FACE_BLENDSHAPES - ClassificationList +// Supplementary blendshape coefficients that are predicted directly from +// the input image. 
+// LEFT_HAND_WORLD_LANDMARKS - LandmarkList +// 21 left hand world 3D landmarks. +// Hand landmarks are aligned with pose landmarks: translated so that wrist +// from the hand matches wrist from pose in the pose coordinate system. +// RIGHT_HAND_WORLD_LANDMARKS - LandmarkList +// 21 right hand world 3D landmarks. +// Hand landmarks are aligned with pose landmarks: translated so that wrist +// from the hand matches wrist from pose in the pose coordinate system. +// IMAGE - Image +// The input image that the holistic landmarker runs on and has the pixel +// data stored on the target storage (CPU vs GPU). +// +// Debug outputs: +// POSE_AUXILIARY_LANDMARKS - NormalizedLandmarkList +// TODO: Return ROI rather than auxiliary landmarks +// Auxiliary landmarks for deriving the ROI in the subsequent image. +// 0 - hidden center point +// 1 - hidden scale point +// POSE_LANDMARKS_ROI - NormalizedRect +// Region of interest calculated based on landmarks. +// LEFT_HAND_ROI_FROM_POSE - NormalizedLandmarkList +// LEFT_HAND_ROI_FROM_RECROP - NormalizedLandmarkList +// LEFT_HAND_TRACKING_ROI - NormalizedLandmarkList +// RIGHT_HAND_ROI_FROM_POSE - NormalizedLandmarkList +// RIGHT_HAND_ROI_FROM_RECROP - NormalizedLandmarkList +// RIGHT_HAND_TRACKING_ROI - NormalizedLandmarkList +// FACE_ROI_FROM_POSE - NormalizedLandmarkList +// FACE_ROI_FROM_DETECTION - NormalizedLandmarkList +// FACE_TRACKING_ROI - NormalizedLandmarkList +// +// NOTE: failure is reported if some output has been requested, but the +// specified model doesn't support it. +// +// NOTE: there will not be an output packet in an output stream for a +// particular timestamp if nothing is detected. However, the MediaPipe +// framework will internally inform the downstream calculators of the +// absence of this packet so that they don't wait for it unnecessarily.
+// +// Example: +// node { +// calculator: +// "mediapipe.tasks.vision.holistic_landmarker.HolisticLandmarkerGraph" +// input_stream: "IMAGE:input_frames_image" +// output_stream: "POSE_LANDMARKS:pose_landmarks" +// output_stream: "POSE_WORLD_LANDMARKS:pose_world_landmarks" +// output_stream: "FACE_LANDMARKS:face_landmarks" +// output_stream: "FACE_BLENDSHAPES:extra_blendshapes" +// output_stream: "LEFT_HAND_LANDMARKS:left_hand_landmarks" +// output_stream: "LEFT_HAND_WORLD_LANDMARKS:left_hand_world_landmarks" +// output_stream: "RIGHT_HAND_LANDMARKS:right_hand_landmarks" +// output_stream: "RIGHT_HAND_WORLD_LANDMARKS:right_hand_world_landmarks" +// node_options { +// [type.googleapis.com/mediapipe.tasks.vision.holistic_landmarker.proto.HolisticLandmarkerGraphOptions] +// { +// base_options { +// model_asset { +// file_name: +// "mediapipe/tasks/testdata/vision/holistic_landmarker.task" +// } +// } +// face_detector_graph_options: { +// num_faces: 1 +// } +// pose_detector_graph_options: { +// num_poses: 1 +// } +// } +// } +// } +class HolisticLandmarkerGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + Graph graph; + const auto& holistic_node = sc->OriginalNode(); + proto::HolisticLandmarkerGraphOptions* holistic_options = + sc->MutableOptions(); + const core::ModelAssetBundleResources* model_asset_bundle_resources; + if (holistic_options->base_options().has_model_asset()) { + MP_ASSIGN_OR_RETURN(model_asset_bundle_resources, + CreateModelAssetBundleResources< + proto::HolisticLandmarkerGraphOptions>(sc)); + } + // Copies the file content instead of passing the pointer of file in + // memory if the subgraph model resource service is not available.
+ bool create_copy = + !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) + .IsAvailable(); + + Stream image = graph.In("IMAGE").Cast(); + + // Check whether any hand output is requested. + const bool is_left_hand_requested = + HasOutput(holistic_node, "LEFT_HAND_LANDMARKS"); + const bool is_right_hand_requested = + HasOutput(holistic_node, "RIGHT_HAND_LANDMARKS"); + const bool is_left_hand_world_requested = + HasOutput(holistic_node, "LEFT_HAND_WORLD_LANDMARKS"); + const bool is_right_hand_world_requested = + HasOutput(holistic_node, "RIGHT_HAND_WORLD_LANDMARKS"); + const bool hands_requested = + is_left_hand_requested || is_right_hand_requested || + is_left_hand_world_requested || is_right_hand_world_requested; + if (hands_requested) { + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + model_asset_bundle_resources, holistic_options, + holistic_options->mutable_hand_landmarks_detector_graph_options(), + kHandLandmarksDetectorModelName, create_copy)); + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + model_asset_bundle_resources, holistic_options, + holistic_options->mutable_hand_roi_refinement_graph_options(), + kHandRoiRefinementModelName, create_copy)); + } + + // Check whether any face output is requested. + const bool is_face_requested = HasOutput(holistic_node, "FACE_LANDMARKS"); + const bool is_face_blendshapes_requested = + HasOutput(holistic_node, "FACE_BLENDSHAPES"); + const bool face_requested = + is_face_requested || is_face_blendshapes_requested; + if (face_requested) { + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + model_asset_bundle_resources, holistic_options, + holistic_options->mutable_face_detector_graph_options(), + kFaceDetectorModelName, create_copy)); + // Forcibly set num_faces to 1, because holistic landmarker only supports a + // single subject for now.
+ holistic_options->mutable_face_detector_graph_options()->set_num_faces(1); + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + model_asset_bundle_resources, holistic_options, + holistic_options->mutable_face_landmarks_detector_graph_options(), + kFaceLandmarksDetectorModelName, create_copy)); + if (is_face_blendshapes_requested) { + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + model_asset_bundle_resources, holistic_options, + holistic_options->mutable_face_landmarks_detector_graph_options() + ->mutable_face_blendshapes_graph_options(), + kFaceBlendshapesModelName, create_copy)); + } + } + + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + model_asset_bundle_resources, holistic_options, + holistic_options->mutable_pose_detector_graph_options(), + kPoseDetectorModelName, create_copy)); + // Forcibly set num_poses to 1, because holistic landmarker only supports a + // single subject for now. + holistic_options->mutable_pose_detector_graph_options()->set_num_poses(1); + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + model_asset_bundle_resources, holistic_options, + holistic_options->mutable_pose_landmarks_detector_graph_options(), + kPoseLandmarksDetectorModelName, create_copy)); + + HolisticPoseTrackingRequest pose_request = { + .landmarks = HasOutput(holistic_node, "POSE_LANDMARKS") || + hands_requested || face_requested, + .world_landmarks = + HasOutput(holistic_node, "POSE_WORLD_LANDMARKS") || hands_requested, + .segmentation_mask = + HasOutput(holistic_node, "POSE_SEGMENTATION_MASK")}; + + // Detect and track pose. + MP_ASSIGN_OR_RETURN( + HolisticPoseTrackingOutput pose_output, + TrackHolisticPose( + image, holistic_options->pose_detector_graph_options(), + holistic_options->pose_landmarks_detector_graph_options(), + pose_request, graph)); + MP_RETURN_IF_ERROR( + SetGraphPoseOutputs(pose_request, holistic_node, pose_output, graph)); + + // Detect and track hand.
+ if (hands_requested) { + if (is_left_hand_requested || is_left_hand_world_requested) { + RET_CHECK(pose_output.landmarks.has_value()); + RET_CHECK(pose_output.world_landmarks.has_value()); + + PoseIndices pose_indices = { + .wrist_idx = + static_cast(pose_landmarker::PoseLandmarkName::kLeftWrist), + .pinky_idx = static_cast( + pose_landmarker::PoseLandmarkName::kLeftPinky1), + .index_idx = static_cast( + pose_landmarker::PoseLandmarkName::kLeftIndex1), + }; + HolisticHandTrackingRequest hand_request = { + .landmarks = is_left_hand_requested, + .world_landmarks = is_left_hand_world_requested, + }; + MP_ASSIGN_OR_RETURN( + HolisticHandTrackingOutput hand_output, + TrackHolisticHand( + image, *pose_output.landmarks, *pose_output.world_landmarks, + holistic_options->hand_landmarks_detector_graph_options(), + holistic_options->hand_roi_refinement_graph_options(), + pose_indices, hand_request, graph + + )); + SetGraphHandOutputs(/*is_left=*/true, holistic_node, hand_output, + graph); + } + + if (is_right_hand_requested || is_right_hand_world_requested) { + RET_CHECK(pose_output.landmarks.has_value()); + RET_CHECK(pose_output.world_landmarks.has_value()); + + PoseIndices pose_indices = { + .wrist_idx = static_cast( + pose_landmarker::PoseLandmarkName::kRightWrist), + .pinky_idx = static_cast( + pose_landmarker::PoseLandmarkName::kRightPinky1), + .index_idx = static_cast( + pose_landmarker::PoseLandmarkName::kRightIndex1), + }; + HolisticHandTrackingRequest hand_request = { + .landmarks = is_right_hand_requested, + .world_landmarks = is_right_hand_world_requested, + }; + MP_ASSIGN_OR_RETURN( + HolisticHandTrackingOutput hand_output, + TrackHolisticHand( + image, *pose_output.landmarks, *pose_output.world_landmarks, + holistic_options->hand_landmarks_detector_graph_options(), + holistic_options->hand_roi_refinement_graph_options(), + pose_indices, hand_request, graph + + )); + SetGraphHandOutputs(/*is_left=*/false, holistic_node, hand_output, + graph); + } + } + + // 
Detect and track face. + if (face_requested) { + RET_CHECK(pose_output.landmarks.has_value()); + + Stream face_landmarks_from_pose = + api2::builder::SplitToRanges(*pose_output.landmarks, {{0, 11}}, + graph)[0]; + + HolisticFaceTrackingRequest face_request = { + .classifications = is_face_blendshapes_requested, + }; + MP_ASSIGN_OR_RETURN( + HolisticFaceTrackingOutput face_output, + TrackHolisticFace( + image, face_landmarks_from_pose, + holistic_options->face_detector_graph_options(), + holistic_options->face_landmarks_detector_graph_options(), + face_request, graph)); + SetGraphFaceOutputs(holistic_node, face_output, graph); + } + + auto& pass_through = graph.AddNode("PassThroughCalculator"); + image >> pass_through.In(""); + pass_through.Out("") >> graph.Out("IMAGE"); + + auto config = graph.GetConfig(); + core::FixGraphBackEdges(config); + return config; + } +}; + +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::holistic_landmarker::HolisticLandmarkerGraph); + +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_landmarker_graph_test.cc b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_landmarker_graph_test.cc new file mode 100644 index 000000000..c549a022b --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_landmarker_graph_test.cc @@ -0,0 +1,595 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "file/base/helpers.h" +#include "file/base/options.h" +#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/stream/image_size.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/tool/test_util.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_connections.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h" +#include 
"mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_connections.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.pb.h" +#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_connections.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/data_renderer.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" +#include "mediapipe/util/color.pb.h" +#include "mediapipe/util/render_data.pb.h" +#include "testing/base/public/gmock.h" +#include "testing/base/public/googletest.h" +#include "testing/base/public/gunit.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace holistic_landmarker { +namespace { + +using ::mediapipe::api2::builder::GetImageSize; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Stream; +using ::mediapipe::tasks::core::TaskRunner; +using ::testing::TestParamInfo; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::proto::Approximately; +using ::testing::proto::Partially; + +constexpr float kAbsMargin = 0.025; +constexpr absl::string_view kTestDataDirectory = + "/mediapipe/tasks/testdata/vision/"; +constexpr char kHolisticResultFile[] = + "male_full_height_hands_result_cpu.pbtxt"; +constexpr absl::string_view kTestImageFile = "male_full_height_hands.jpg"; +constexpr absl::string_view kImageInStream = "image_in"; +constexpr 
absl::string_view kLeftHandLandmarksStream = "left_hand_landmarks"; +constexpr absl::string_view kRightHandLandmarksStream = "right_hand_landmarks"; +constexpr absl::string_view kFaceLandmarksStream = "face_landmarks"; +constexpr absl::string_view kFaceBlendshapesStream = "face_blendshapes"; +constexpr absl::string_view kPoseLandmarksStream = "pose_landmarks"; +constexpr absl::string_view kRenderedImageOutStream = "rendered_image_out"; +constexpr absl::string_view kPoseSegmentationMaskStream = + "pose_segmentation_mask"; +constexpr absl::string_view kHolisticLandmarkerModelBundleFile = + "holistic_landmarker.task"; +constexpr absl::string_view kHandLandmarksModelFile = + "hand_landmark_full.tflite"; +constexpr absl::string_view kHandRoiRefinementModelFile = + "handrecrop_2020_07_21_v0.f16.tflite"; +constexpr absl::string_view kPoseDetectionModelFile = "pose_detection.tflite"; +constexpr absl::string_view kPoseLandmarksModelFile = + "pose_landmark_lite.tflite"; +constexpr absl::string_view kFaceDetectionModelFile = + "face_detection_short_range.tflite"; +constexpr absl::string_view kFaceLandmarksModelFile = + "facemesh2_lite_iris_faceflag_2023_02_14.tflite"; +constexpr absl::string_view kFaceBlendshapesModelFile = + "face_blendshapes.tflite"; + +enum RenderPart { + HAND = 0, + POSE = 1, + FACE = 2, +}; + +mediapipe::Color GetColor(RenderPart render_part) { + mediapipe::Color color; + switch (render_part) { + case HAND: + color.set_b(255); + color.set_g(255); + color.set_r(255); + break; + case POSE: + color.set_b(0); + color.set_g(255); + color.set_r(0); + break; + case FACE: + color.set_b(0); + color.set_g(0); + color.set_r(255); + break; + } + return color; +} + +std::string GetFilePath(absl::string_view filename) { + return file::JoinPath("./", kTestDataDirectory, filename); +} + +template +mediapipe::LandmarksToRenderDataCalculatorOptions GetRendererOptions( + const std::array, N>& connections, + mediapipe::Color color) { + 
mediapipe::LandmarksToRenderDataCalculatorOptions renderer_options; + for (const auto& connection : connections) { + renderer_options.add_landmark_connections(connection[0]); + renderer_options.add_landmark_connections(connection[1]); + } + *renderer_options.mutable_landmark_color() = color; + *renderer_options.mutable_connection_color() = color; + renderer_options.set_thickness(0.5); + renderer_options.set_visualize_landmark_depth(false); + return renderer_options; +} + +void ConfigureHandProtoOptions(proto::HolisticLandmarkerGraphOptions& options) { + options.mutable_hand_landmarks_detector_graph_options() + ->mutable_base_options() + ->mutable_model_asset() + ->set_file_name(GetFilePath(kHandLandmarksModelFile)); + + options.mutable_hand_roi_refinement_graph_options() + ->mutable_base_options() + ->mutable_model_asset() + ->set_file_name(GetFilePath(kHandRoiRefinementModelFile)); +} + +void ConfigureFaceProtoOptions(proto::HolisticLandmarkerGraphOptions& options) { + // Set face detection model. + face_detector::proto::FaceDetectorGraphOptions& face_detector_graph_options = + *options.mutable_face_detector_graph_options(); + face_detector_graph_options.mutable_base_options() + ->mutable_model_asset() + ->set_file_name(GetFilePath(kFaceDetectionModelFile)); + face_detector_graph_options.set_num_faces(1); + + // Set face landmarks model. 
+ face_landmarker::proto::FaceLandmarksDetectorGraphOptions& + face_landmarks_graph_options = + *options.mutable_face_landmarks_detector_graph_options(); + face_landmarks_graph_options.mutable_base_options() + ->mutable_model_asset() + ->set_file_name(GetFilePath(kFaceLandmarksModelFile)); + face_landmarks_graph_options.mutable_face_blendshapes_graph_options() + ->mutable_base_options() + ->mutable_model_asset() + ->set_file_name(GetFilePath(kFaceBlendshapesModelFile)); +} + +void ConfigurePoseProtoOptions(proto::HolisticLandmarkerGraphOptions& options) { + pose_detector::proto::PoseDetectorGraphOptions& pose_detector_graph_options = + *options.mutable_pose_detector_graph_options(); + pose_detector_graph_options.mutable_base_options() + ->mutable_model_asset() + ->set_file_name(GetFilePath(kPoseDetectionModelFile)); + pose_detector_graph_options.set_num_poses(1); + options.mutable_pose_landmarks_detector_graph_options() + ->mutable_base_options() + ->mutable_model_asset() + ->set_file_name(GetFilePath(kPoseLandmarksModelFile)); +} + +struct HolisticRequest { + bool is_left_hand_requested = false; + bool is_right_hand_requested = false; + bool is_face_requested = false; + bool is_face_blendshapes_requested = false; +}; + +// Helper function to create a TaskRunner. 
+absl::StatusOr> CreateTaskRunner( + bool use_model_bundle, HolisticRequest holistic_request) { + Graph graph; + + Stream image = graph.In("IMAEG").Cast().SetName(kImageInStream); + + auto& holistic_graph = graph.AddNode( + "mediapipe.tasks.vision.holistic_landmarker.HolisticLandmarkerGraph"); + proto::HolisticLandmarkerGraphOptions& options = + holistic_graph.GetOptions(); + if (use_model_bundle) { + options.mutable_base_options()->mutable_model_asset()->set_file_name( + GetFilePath(kHolisticLandmarkerModelBundleFile)); + } else { + ConfigureHandProtoOptions(options); + ConfigurePoseProtoOptions(options); + ConfigureFaceProtoOptions(options); + } + + std::vector> render_list; + image >> holistic_graph.In("IMAGE"); + Stream> image_size = GetImageSize(image, graph); + + if (holistic_request.is_left_hand_requested) { + Stream left_hand_landmarks = + holistic_graph.Out("LEFT_HAND_LANDMARKS") + .Cast() + .SetName(kLeftHandLandmarksStream); + Stream left_hand_tracking_roi = + holistic_graph.Out("LEFT_HAND_TRACKING_ROI").Cast(); + auto left_hand_landmarks_render_data = utils::RenderLandmarks( + left_hand_landmarks, + utils::GetRenderScale(image_size, left_hand_tracking_roi, 0.0001, + graph), + GetRendererOptions(hand_landmarker::kHandConnections, + GetColor(RenderPart::HAND)), + graph); + render_list.push_back(left_hand_landmarks_render_data); + left_hand_landmarks >> graph.Out("LEFT_HAND_LANDMARKS"); + } + if (holistic_request.is_right_hand_requested) { + Stream right_hand_landmarks = + holistic_graph.Out("RIGHT_HAND_LANDMARKS") + .Cast() + .SetName(kRightHandLandmarksStream); + Stream right_hand_tracking_roi = + holistic_graph.Out("RIGHT_HAND_TRACKING_ROI").Cast(); + auto right_hand_landmarks_render_data = utils::RenderLandmarks( + right_hand_landmarks, + utils::GetRenderScale(image_size, right_hand_tracking_roi, 0.0001, + graph), + GetRendererOptions(hand_landmarker::kHandConnections, + GetColor(RenderPart::HAND)), + graph); + 
render_list.push_back(right_hand_landmarks_render_data); + right_hand_landmarks >> graph.Out("RIGHT_HAND_LANDMARKS"); + } + if (holistic_request.is_face_requested) { + Stream face_landmarks = + holistic_graph.Out("FACE_LANDMARKS") + .Cast() + .SetName(kFaceLandmarksStream); + Stream face_tracking_roi = + holistic_graph.Out("FACE_TRACKING_ROI").Cast(); + auto face_landmarks_render_data = utils::RenderLandmarks( + face_landmarks, + utils::GetRenderScale(image_size, face_tracking_roi, 0.0001, graph), + GetRendererOptions( + face_landmarker::FaceLandmarksConnections::kFaceLandmarksConnectors, + GetColor(RenderPart::FACE)), + graph); + render_list.push_back(face_landmarks_render_data); + face_landmarks >> graph.Out("FACE_LANDMARKS"); + } + if (holistic_request.is_face_blendshapes_requested) { + Stream face_blendshapes = + holistic_graph.Out("FACE_BLENDSHAPES") + .Cast() + .SetName(kFaceBlendshapesStream); + face_blendshapes >> graph.Out("FACE_BLENDSHAPES"); + } + Stream pose_landmarks = + holistic_graph.Out("POSE_LANDMARKS") + .Cast() + .SetName(kPoseLandmarksStream); + Stream pose_tracking_roi = + holistic_graph.Out("POSE_LANDMARKS_ROI").Cast(); + Stream pose_segmentation_mask = + holistic_graph.Out("POSE_SEGMENTATION_MASK") + .Cast() + .SetName(kPoseSegmentationMaskStream); + + auto pose_landmarks_render_data = utils::RenderLandmarks( + pose_landmarks, + utils::GetRenderScale(image_size, pose_tracking_roi, 0.0001, graph), + GetRendererOptions(pose_landmarker::kPoseLandmarksConnections, + GetColor(RenderPart::POSE)), + graph); + render_list.push_back(pose_landmarks_render_data); + auto rendered_image = + utils::Render( + image, absl::Span>(render_list), graph) + .SetName(kRenderedImageOutStream); + + pose_landmarks >> graph.Out("POSE_LANDMARKS"); + pose_segmentation_mask >> graph.Out("POSE_SEGMENTATION_MASK"); + rendered_image >> graph.Out("RENDERED_IMAGE"); + + auto config = graph.GetConfig(); + core::FixGraphBackEdges(config); + + return TaskRunner::Create( + config, 
std::make_unique()); +} + +template +absl::StatusOr FetchResult(const core::PacketMap& output_packets, + absl::string_view stream_name) { + auto it = output_packets.find(std::string(stream_name)); + RET_CHECK(it != output_packets.end()); + return it->second.Get(); +} + +// Remove fields not to be checked in the result, since the model +// generating expected result is different from the testing model. +void RemoveUncheckedResult(proto::HolisticResult& holistic_result) { + for (auto& landmark : + *holistic_result.mutable_pose_landmarks()->mutable_landmark()) { + landmark.clear_z(); + landmark.clear_visibility(); + landmark.clear_presence(); + } + for (auto& landmark : + *holistic_result.mutable_face_landmarks()->mutable_landmark()) { + landmark.clear_z(); + landmark.clear_visibility(); + landmark.clear_presence(); + } + for (auto& landmark : + *holistic_result.mutable_left_hand_landmarks()->mutable_landmark()) { + landmark.clear_z(); + landmark.clear_visibility(); + landmark.clear_presence(); + } + for (auto& landmark : + *holistic_result.mutable_right_hand_landmarks()->mutable_landmark()) { + landmark.clear_z(); + landmark.clear_visibility(); + landmark.clear_presence(); + } +} + +std::string RequestToString(HolisticRequest request) { + return absl::StrFormat( + "%s_%s_%s_%s", + request.is_left_hand_requested ? "left_hand" : "no_left_hand", + request.is_right_hand_requested ? "right_hand" : "no_right_hand", + request.is_face_requested ? "face" : "no_face", + request.is_face_blendshapes_requested ? "face_blendshapes" + : "no_face_blendshapes"); +} + +struct TestParams { + // The name of this test, for convenience when displaying test results. + std::string test_name; + // The filename of test image. + std::string test_image_name; + // Whether to use holistic model bundle to test. + bool use_model_bundle; + // Requests of holistic parts. 
+ HolisticRequest holistic_request; +}; + +class SmokeTest : public testing::TestWithParam {}; + +TEST_P(SmokeTest, Succeeds) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(GetFilePath(GetParam().test_image_name))); + + proto::HolisticResult holistic_result; + MP_ASSERT_OK(GetTextProto(GetFilePath(kHolisticResultFile), &holistic_result, + ::file::Defaults())); + RemoveUncheckedResult(holistic_result); + + MP_ASSERT_OK_AND_ASSIGN(auto task_runner, + CreateTaskRunner(GetParam().use_model_bundle, + GetParam().holistic_request)); + MP_ASSERT_OK_AND_ASSIGN(auto output_packets, + task_runner->Process({{std::string(kImageInStream), + MakePacket(image)}})); + + // Check face landmarks + if (GetParam().holistic_request.is_face_requested) { + MP_ASSERT_OK_AND_ASSIGN(auto face_landmarks, + FetchResult( + output_packets, kFaceLandmarksStream)); + EXPECT_THAT( + face_landmarks, + Approximately(Partially(EqualsProto(holistic_result.face_landmarks())), + /*margin=*/kAbsMargin)); + } else { + ASSERT_FALSE(output_packets.contains(std::string(kFaceLandmarksStream))); + } + + if (GetParam().holistic_request.is_face_blendshapes_requested) { + MP_ASSERT_OK_AND_ASSIGN(auto face_blendshapes, + FetchResult( + output_packets, kFaceBlendshapesStream)); + EXPECT_THAT(face_blendshapes, + Approximately( + Partially(EqualsProto(holistic_result.face_blendshapes())), + /*margin=*/kAbsMargin)); + } else { + ASSERT_FALSE(output_packets.contains(std::string(kFaceBlendshapesStream))); + } + + // Check Pose landmarks + MP_ASSERT_OK_AND_ASSIGN(auto pose_landmarks, + FetchResult( + output_packets, kPoseLandmarksStream)); + EXPECT_THAT( + pose_landmarks, + Approximately(Partially(EqualsProto(holistic_result.pose_landmarks())), + /*margin=*/kAbsMargin)); + + // Check Hand landmarks + if (GetParam().holistic_request.is_left_hand_requested) { + MP_ASSERT_OK_AND_ASSIGN(auto left_hand_landmarks, + FetchResult( + output_packets, kLeftHandLandmarksStream)); + EXPECT_THAT( + 
left_hand_landmarks, + Approximately( + Partially(EqualsProto(holistic_result.left_hand_landmarks())), + /*margin=*/kAbsMargin)); + } else { + ASSERT_FALSE( + output_packets.contains(std::string(kLeftHandLandmarksStream))); + } + + if (GetParam().holistic_request.is_right_hand_requested) { + MP_ASSERT_OK_AND_ASSIGN(auto right_hand_landmarks, + FetchResult( + output_packets, kRightHandLandmarksStream)); + EXPECT_THAT( + right_hand_landmarks, + Approximately( + Partially(EqualsProto(holistic_result.right_hand_landmarks())), + /*margin=*/kAbsMargin)); + } else { + ASSERT_FALSE( + output_packets.contains(std::string(kRightHandLandmarksStream))); + } + + auto rendered_image = + output_packets.at(std::string(kRenderedImageOutStream)).Get(); + MP_EXPECT_OK(SavePngTestOutput( + *rendered_image.GetImageFrameSharedPtr(), + absl::StrCat("holistic_landmark_", + RequestToString(GetParam().holistic_request)))); + + auto pose_segmentation_mask = + output_packets.at(std::string(kPoseSegmentationMaskStream)).Get(); + + cv::Mat matting_mask = mediapipe::formats::MatView( + pose_segmentation_mask.GetImageFrameSharedPtr().get()); + cv::Mat visualized_mask; + matting_mask.convertTo(visualized_mask, CV_8UC1, 255); + ImageFrame visualized_image(mediapipe::ImageFormat::GRAY8, + visualized_mask.cols, visualized_mask.rows, + visualized_mask.step, visualized_mask.data, + [visualized_mask](uint8_t[]) {}); + + MP_EXPECT_OK( + SavePngTestOutput(visualized_image, "holistic_pose_segmentation_mask")); +} + +INSTANTIATE_TEST_SUITE_P( + HolisticLandmarkerGraphTest, SmokeTest, + Values(TestParams{ + /* test_name= */ "UseModelBundle", + /* test_image_name= */ std::string(kTestImageFile), + /* use_model_bundle= */ true, + /* holistic_request= */ + { + /*is_left_hand_requested= */ true, + /*is_right_hand_requested= */ true, + /*is_face_requested= */ true, + /*is_face_blendshapes_requested= */ true, + }, + }, + TestParams{ + /* test_name= */ "UseSeparateModelFiles", + /* test_image_name= */ 
std::string(kTestImageFile), + /* use_model_bundle= */ false, + /* holistic_request= */ + { + /*is_left_hand_requested= */ true, + /*is_right_hand_requested= */ true, + /*is_face_requested= */ true, + /*is_face_blendshapes_requested= */ true, + }, + }, + TestParams{ + /* test_name= */ "ModelBundleNoLeftHand", + /* test_image_name= */ std::string(kTestImageFile), + /* use_model_bundle= */ true, + /* holistic_request= */ + { + /*is_left_hand_requested= */ false, + /*is_right_hand_requested= */ true, + /*is_face_requested= */ true, + /*is_face_blendshapes_requested= */ true, + }, + }, + TestParams{ + /* test_name= */ "ModelBundleNoRightHand", + /* test_image_name= */ std::string(kTestImageFile), + /* use_model_bundle= */ true, + /* holistic_request= */ + { + /*is_left_hand_requested= */ true, + /*is_right_hand_requested= */ false, + /*is_face_requested= */ true, + /*is_face_blendshapes_requested= */ true, + }, + }, + TestParams{ + /* test_name= */ "ModelBundleNoHand", + /* test_image_name= */ std::string(kTestImageFile), + /* use_model_bundle= */ true, + /* holistic_request= */ + { + /*is_left_hand_requested= */ false, + /*is_right_hand_requested= */ false, + /*is_face_requested= */ true, + /*is_face_blendshapes_requested= */ true, + }, + }, + TestParams{ + /* test_name= */ "ModelBundleNoFace", + /* test_image_name= */ std::string(kTestImageFile), + /* use_model_bundle= */ true, + /* holistic_request= */ + { + /*is_left_hand_requested= */ true, + /*is_right_hand_requested= */ true, + /*is_face_requested= */ false, + /*is_face_blendshapes_requested= */ false, + }, + }, + TestParams{ + /* test_name= */ "ModelBundleNoFaceBlendshapes", + /* test_image_name= */ std::string(kTestImageFile), + /* use_model_bundle= */ true, + /* holistic_request= */ + { + /*is_left_hand_requested= */ true, + /*is_right_hand_requested= */ true, + /*is_face_requested= */ true, + /*is_face_blendshapes_requested= */ false, + }, + }), + [](const TestParamInfo& info) { + return 
info.param.test_name; + }); + +} // namespace +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.cc b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.cc new file mode 100644 index 000000000..860035ad0 --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.cc @@ -0,0 +1,307 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/stream/detections_to_rects.h" +#include "mediapipe/framework/api2/stream/image_size.h" +#include "mediapipe/framework/api2/stream/landmarks_to_detection.h" +#include "mediapipe/framework/api2/stream/loopback.h" +#include "mediapipe/framework/api2/stream/merge.h" +#include "mediapipe/framework/api2/stream/presence.h" +#include "mediapipe/framework/api2/stream/rect_transformation.h" +#include "mediapipe/framework/api2/stream/segmentation_smoothing.h" +#include "mediapipe/framework/api2/stream/smoothing.h" +#include "mediapipe/framework/api2/stream/split.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/components/utils/gate.h" +#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace holistic_landmarker { + +namespace { + +using ::mediapipe::NormalizedRect; +using ::mediapipe::api2::builder::ConvertAlignmentPointsDetectionsToRect; +using ::mediapipe::api2::builder::ConvertAlignmentPointsDetectionToRect; +using ::mediapipe::api2::builder::ConvertLandmarksToDetection; +using ::mediapipe::api2::builder::GenericNode; +using ::mediapipe::api2::builder::GetImageSize; +using ::mediapipe::api2::builder::GetLoopbackData; +using ::mediapipe::api2::builder::Graph; +using 
::mediapipe::api2::builder::IsPresent; +using ::mediapipe::api2::builder::Merge; +using ::mediapipe::api2::builder::ScaleAndMakeSquare; +using ::mediapipe::api2::builder::SmoothLandmarks; +using ::mediapipe::api2::builder::SmoothLandmarksVisibility; +using ::mediapipe::api2::builder::SmoothSegmentationMask; +using ::mediapipe::api2::builder::SplitToRanges; +using ::mediapipe::api2::builder::Stream; +using ::mediapipe::tasks::components::utils::DisallowIf; +using Size = std::pair; + +constexpr int kAuxLandmarksStartKeypointIndex = 0; +constexpr int kAuxLandmarksEndKeypointIndex = 1; +constexpr float kAuxLandmarksTargetAngle = 90; +constexpr float kRoiFromDetectionScaleFactor = 1.25f; +constexpr float kRoiFromLandmarksScaleFactor = 1.25f; + +Stream CalculateRoiFromDetections( + Stream> detections, Stream image_size, + Graph& graph) { + auto roi = ConvertAlignmentPointsDetectionsToRect(detections, image_size, + /*start_keypoint_index=*/0, + /*end_keypoint_index=*/1, + /*target_angle=*/90, graph); + return ScaleAndMakeSquare( + roi, image_size, /*scale_x_factor=*/kRoiFromDetectionScaleFactor, + /*scale_y_factor=*/kRoiFromDetectionScaleFactor, graph); +} + +Stream CalculateScaleRoiFromAuxiliaryLandmarks( + Stream landmarks, Stream image_size, + Graph& graph) { + // TODO: consider calculating ROI directly from landmarks. + auto detection = ConvertLandmarksToDetection(landmarks, graph); + return ConvertAlignmentPointsDetectionToRect( + detection, image_size, kAuxLandmarksStartKeypointIndex, + kAuxLandmarksEndKeypointIndex, kAuxLandmarksTargetAngle, graph); +} + +Stream CalculateRoiFromAuxiliaryLandmarks( + Stream landmarks, Stream image_size, + Graph& graph) { + // TODO: consider calculating ROI directly from landmarks. 
+ auto detection = ConvertLandmarksToDetection(landmarks, graph); + auto roi = ConvertAlignmentPointsDetectionToRect( + detection, image_size, kAuxLandmarksStartKeypointIndex, + kAuxLandmarksEndKeypointIndex, kAuxLandmarksTargetAngle, graph); + return ScaleAndMakeSquare( + roi, image_size, /*scale_x_factor=*/kRoiFromLandmarksScaleFactor, + /*scale_y_factor=*/kRoiFromLandmarksScaleFactor, graph); +} + +struct PoseLandmarksResult { + std::optional> landmarks; + std::optional> world_landmarks; + std::optional> auxiliary_landmarks; + std::optional> segmentation_mask; +}; + +PoseLandmarksResult RunLandmarksDetection( + Stream image, Stream roi, + const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions& + pose_landmarks_detector_graph_options, + const HolisticPoseTrackingRequest& request, Graph& graph) { + GenericNode& landmarks_graph = graph.AddNode( + "mediapipe.tasks.vision.pose_landmarker." + "SinglePoseLandmarksDetectorGraph"); + landmarks_graph + .GetOptions() = + pose_landmarks_detector_graph_options; + image >> landmarks_graph.In("IMAGE"); + roi >> landmarks_graph.In("NORM_RECT"); + + PoseLandmarksResult result; + if (request.landmarks) { + result.landmarks = + landmarks_graph.Out("LANDMARKS").Cast(); + result.auxiliary_landmarks = landmarks_graph.Out("AUXILIARY_LANDMARKS") + .Cast(); + } + if (request.world_landmarks) { + result.world_landmarks = + landmarks_graph.Out("WORLD_LANDMARKS").Cast(); + } + if (request.segmentation_mask) { + result.segmentation_mask = + landmarks_graph.Out("SEGMENTATION_MASK").Cast(); + } + return result; +} + +} // namespace + +absl::StatusOr +TrackHolisticPoseUsingCustomPoseDetection( + Stream image, PoseDetectionFn pose_detection_fn, + const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions& + pose_landmarks_detector_graph_options, + const HolisticPoseTrackingRequest& request, Graph& graph) { + // Calculate ROI from scratch (pose detection) or reuse one from the + // previous run if available. 
+ auto [previous_roi, set_previous_roi_fn] = + GetLoopbackData(/*tick=*/image, graph); + auto is_previous_roi_available = IsPresent(previous_roi, graph); + auto image_for_detection = + DisallowIf(image, is_previous_roi_available, graph); + MP_ASSIGN_OR_RETURN(auto pose_detections, + pose_detection_fn(image_for_detection, graph)); + auto roi_from_detections = CalculateRoiFromDetections( + pose_detections, GetImageSize(image_for_detection, graph), graph); + // Take first non-empty. + auto roi = Merge(roi_from_detections, previous_roi, graph); + + // Calculate landmarks and other outputs (if requested) in the specified ROI. + auto landmarks_detection_result = RunLandmarksDetection( + image, roi, pose_landmarks_detector_graph_options, + { + // Landmarks are required for tracking, hence force-requesting them. + .landmarks = true, + .world_landmarks = request.world_landmarks, + .segmentation_mask = request.segmentation_mask, + }, + graph); + RET_CHECK(landmarks_detection_result.landmarks.has_value() && + landmarks_detection_result.auxiliary_landmarks.has_value()) + << "Failed to calculate landmarks required for tracking."; + + // Split landmarks to pose landmarks and auxiliary landmarks. + auto pose_landmarks_raw = *landmarks_detection_result.landmarks; + auto auxiliary_landmarks = *landmarks_detection_result.auxiliary_landmarks; + + auto image_size = GetImageSize(image, graph); + + // TODO: b/305750053 - Apply adaptive crop by adding AdaptiveCropCalculator. + + // Calculate ROI from smoothed auxiliary landmarks. + auto scale_roi = CalculateScaleRoiFromAuxiliaryLandmarks(auxiliary_landmarks, + image_size, graph); + auto auxiliary_landmarks_smoothed = SmoothLandmarks( + auxiliary_landmarks, image_size, scale_roi, + {// Min cutoff 0.01 results into ~0.002 alpha in landmark EMA filter when + // landmark is static. + .min_cutoff = 0.01, + // Beta 10.0 in combination with min_cutoff 0.01 results into ~0.68 + // alpha in landmark EMA filter when landmark is moving fast. 
+ .beta = 10.0, + // Derivative cutoff 1.0 results into ~0.17 alpha in landmark velocity + // EMA filter. + .derivate_cutoff = 1.0}, + graph); + auto roi_from_auxiliary_landmarks = CalculateRoiFromAuxiliaryLandmarks( + auxiliary_landmarks_smoothed, image_size, graph); + + // Make ROI from auxiliary landmarks to be used as "previous" ROI for a + // subsequent run. + set_previous_roi_fn(roi_from_auxiliary_landmarks); + + // Populate and smooth pose landmarks if corresponding output has been + // requested. + std::optional> pose_landmarks; + if (request.landmarks) { + pose_landmarks = SmoothLandmarksVisibility( + pose_landmarks_raw, /*low_pass_filter_alpha=*/0.1f, graph); + pose_landmarks = SmoothLandmarks( + *pose_landmarks, image_size, scale_roi, + {// Min cutoff 0.05 results into ~0.01 alpha in landmark EMA filter when + // landmark is static. + .min_cutoff = 0.05f, + // Beta 80.0 in combination with min_cutoff 0.05 results into ~0.94 + // alpha in landmark EMA filter when landmark is moving fast. + .beta = 80.0f, + // Derivative cutoff 1.0 results into ~0.17 alpha in landmark velocity + // EMA filter. + .derivate_cutoff = 1.0f}, + graph); + } + + // Populate and smooth world landmarks if available. + std::optional> world_landmarks; + if (landmarks_detection_result.world_landmarks) { + world_landmarks = SplitToRanges(*landmarks_detection_result.world_landmarks, + /*ranges*/ {{0, 33}}, graph)[0]; + world_landmarks = SmoothLandmarksVisibility( + *world_landmarks, /*low_pass_filter_alpha=*/0.1f, graph); + world_landmarks = SmoothLandmarks( + *world_landmarks, + /*scale_roi=*/std::nullopt, + {// Min cutoff 0.1 results into ~ 0.02 alpha in landmark EMA filter when + // landmark is static. + .min_cutoff = 0.1f, + // Beta 40.0 in combination with min_cutoff 0.1 results into ~0.8 + // alpha in landmark EMA filter when landmark is moving fast. + .beta = 40.0f, + // Derivative cutoff 1.0 results into ~0.17 alpha in landmark velocity + // EMA filter. 
+ .derivate_cutoff = 1.0f}, + graph); + } + + // Populate and smooth segmentation mask if available. + std::optional> segmentation_mask; + if (landmarks_detection_result.segmentation_mask) { + auto mask = *landmarks_detection_result.segmentation_mask; + auto [prev_mask_as_img, set_prev_mask_as_img_fn] = + GetLoopbackData( + /*tick=*/*landmarks_detection_result.segmentation_mask, graph); + auto mask_smoothed = + SmoothSegmentationMask(mask, prev_mask_as_img, + /*combine_with_previous_ratio=*/0.7f, graph); + set_prev_mask_as_img_fn(mask_smoothed); + segmentation_mask = mask_smoothed; + } + + return {{/*landmarks=*/pose_landmarks, + /*world_landmarks=*/world_landmarks, + /*segmentation_mask=*/segmentation_mask, + /*debug_output=*/ + {/*auxiliary_landmarks=*/auxiliary_landmarks_smoothed, + /*roi_from_landmarks=*/roi_from_auxiliary_landmarks, + /*detections*/ pose_detections}}}; +} + +absl::StatusOr TrackHolisticPose( + Stream image, + const pose_detector::proto::PoseDetectorGraphOptions& + pose_detector_graph_options, + const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions& + pose_landmarks_detector_graph_options, + const HolisticPoseTrackingRequest& request, Graph& graph) { + PoseDetectionFn pose_detection_fn = [&pose_detector_graph_options]( + Stream image, Graph& graph) + -> absl::StatusOr>> { + GenericNode& pose_detector = + graph.AddNode("mediapipe.tasks.vision.pose_detector.PoseDetectorGraph"); + pose_detector.GetOptions() = + pose_detector_graph_options; + image >> pose_detector.In("IMAGE"); + return pose_detector.Out("DETECTIONS") + .Cast>(); + }; + return TrackHolisticPoseUsingCustomPoseDetection( + image, pose_detection_fn, pose_landmarks_detector_graph_options, request, + graph); +} + +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h new 
file mode 100644 index 000000000..f51ccc283 --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h @@ -0,0 +1,110 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_POSE_TRACKING_H_ +#define MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_POSE_TRACKING_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace holistic_landmarker { + +// Type of pose detection function that can be used to customize pose tracking, +// by supplying the function to `TrackHolisticPoseUsingCustomPoseDetection`. +// +// Function should update provided graph with node/nodes that accept image +// stream and produce stream of detections. 
+using PoseDetectionFn = std::function< + absl::StatusOr>>( + api2::builder::Stream, api2::builder::Graph&)>; + +struct HolisticPoseTrackingRequest { + bool landmarks = false; + bool world_landmarks = false; + bool segmentation_mask = false; +}; + +struct HolisticPoseTrackingOutput { + std::optional> + landmarks; + std::optional> world_landmarks; + std::optional> segmentation_mask; + + struct DebugOutput { + api2::builder::Stream + auxiliary_landmarks; + api2::builder::Stream roi_from_landmarks; + api2::builder::Stream> detections; + }; + + DebugOutput debug_output; +}; + +// Updates @graph to track pose in @image. +// +// @image - ImageFrame/GpuBuffer to track pose in. +// @pose_detection_fn - pose detection function that takes @image as input and +// produces stream of pose detections. +// @pose_landmarks_detector_graph_options - options of the +// PoseLandmarksDetectorGraph used to detect the pose landmarks. +// @request - object to request specific pose tracking outputs. +// NOTE: Outputs that were not requested won't be returned and corresponding +// parts of the graph won't be generated at all. +// @graph - graph to update. +absl::StatusOr +TrackHolisticPoseUsingCustomPoseDetection( + api2::builder::Stream image, PoseDetectionFn pose_detection_fn, + const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions& + pose_landmarks_detector_graph_options, + const HolisticPoseTrackingRequest& request, api2::builder::Graph& graph); + +// Updates @graph to track pose in @image. +// +// @image - ImageFrame/GpuBuffer to track pose in. +// @pose_detector_graph_options - options of the PoseDetectorGraph used to +// detect the pose. +// @pose_landmarks_detector_graph_options - options of the +// PoseLandmarksDetectorGraph used to detect the pose landmarks. +// @request - object to request specific pose tracking outputs. +// NOTE: Outputs that were not requested won't be returned and corresponding +// parts of the graph won't be generated at all. +// @graph - graph to update. 
+absl::StatusOr TrackHolisticPose( + api2::builder::Stream image, + const pose_detector::proto::PoseDetectorGraphOptions& + pose_detector_graph_options, + const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions& + pose_landmarks_detector_graph_options, + const HolisticPoseTrackingRequest& request, api2::builder::Graph& graph); + +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_POSE_TRACKING_H_ diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking_test.cc b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking_test.cc new file mode 100644 index 000000000..0bf7259e8 --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking_test.cc @@ -0,0 +1,243 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h" + +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "absl/types/span.h" +#include "file/base/helpers.h" +#include "file/base/options.h" +#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/stream/image_size.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/framework/tool/test_util.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.pb.h" +#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_connections.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/data_renderer.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" +#include "mediapipe/util/color.pb.h" +#include "mediapipe/util/render_data.pb.h" +#include "testing/base/public/googletest.h" + 
+namespace mediapipe { +namespace tasks { +namespace vision { +namespace holistic_landmarker { + +namespace { + +using ::file::Defaults; +using ::file::GetTextProto; +using ::mediapipe::Image; +using ::mediapipe::api2::builder::GetImageSize; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Stream; +using ::mediapipe::tasks::core::TaskRunner; +using ::testing::proto::Approximately; +using ::testing::proto::Partially; + +constexpr float kAbsMargin = 0.025; +constexpr absl::string_view kTestDataDirectory = + "/mediapipe/tasks/testdata/vision/"; +constexpr absl::string_view kTestImageFile = "male_full_height_hands.jpg"; +constexpr absl::string_view kImageInStream = "image_in"; +constexpr absl::string_view kPoseLandmarksOutStream = "pose_landmarks_out"; +constexpr absl::string_view kPoseWorldLandmarksOutStream = + "pose_world_landmarks_out"; +constexpr absl::string_view kRenderedImageOutStream = "rendered_image_out"; +constexpr absl::string_view kHolisticResultFile = + "male_full_height_hands_result_cpu.pbtxt"; +constexpr absl::string_view kHolisticPoseTrackingGraph = + "holistic_pose_tracking_graph.pbtxt"; + +std::string GetFilePath(absl::string_view filename) { + return file::JoinPath("./", kTestDataDirectory, filename); +} + +mediapipe::LandmarksToRenderDataCalculatorOptions GetPoseRendererOptions() { + mediapipe::LandmarksToRenderDataCalculatorOptions renderer_options; + for (const auto& connection : pose_landmarker::kPoseLandmarksConnections) { + renderer_options.add_landmark_connections(connection[0]); + renderer_options.add_landmark_connections(connection[1]); + } + renderer_options.mutable_landmark_color()->set_r(255); + renderer_options.mutable_landmark_color()->set_g(255); + renderer_options.mutable_landmark_color()->set_b(255); + renderer_options.mutable_connection_color()->set_r(255); + renderer_options.mutable_connection_color()->set_g(255); + renderer_options.mutable_connection_color()->set_b(255); + 
renderer_options.set_thickness(0.5); + renderer_options.set_visualize_landmark_depth(false); + return renderer_options; +} + +// Helper function to create a TaskRunner. +absl::StatusOr> CreateTaskRunner() { + Graph graph; + Stream image = graph.In("IMAGE").Cast().SetName(kImageInStream); + pose_detector::proto::PoseDetectorGraphOptions pose_detector_graph_options; + pose_detector_graph_options.mutable_base_options() + ->mutable_model_asset() + ->set_file_name(GetFilePath("pose_detection.tflite")); + pose_detector_graph_options.set_num_poses(1); + pose_landmarker::proto::PoseLandmarksDetectorGraphOptions + pose_landmarks_detector_graph_options; + pose_landmarks_detector_graph_options.mutable_base_options() + ->mutable_model_asset() + ->set_file_name(GetFilePath("pose_landmark_lite.tflite")); + + HolisticPoseTrackingRequest request; + request.landmarks = true; + request.world_landmarks = true; + MP_ASSIGN_OR_RETURN( + HolisticPoseTrackingOutput result, + TrackHolisticPose(image, pose_detector_graph_options, + pose_landmarks_detector_graph_options, request, graph)); + + auto image_size = GetImageSize(image, graph); + auto render_data = utils::RenderLandmarks( + *result.landmarks, + utils::GetRenderScale(image_size, result.debug_output.roi_from_landmarks, + 0.0001, graph), + GetPoseRendererOptions(), graph); + std::vector> render_list = {render_data}; + auto rendered_image = + utils::Render( + image, absl::Span>(render_list), graph) + .SetName(kRenderedImageOutStream); + + rendered_image >> graph.Out("RENDERED_IMAGE"); + result.landmarks->SetName(kPoseLandmarksOutStream) >> + graph.Out("POSE_LANDMARKS"); + result.world_landmarks->SetName(kPoseWorldLandmarksOutStream) >> + graph.Out("POSE_WORLD_LANDMARKS"); + + auto config = graph.GetConfig(); + core::FixGraphBackEdges(config); + + return TaskRunner::Create( + config, std::make_unique()); +} + +// Remove fields not to be checked in the result, since the model +// generating expected result is different from the testing 
model. +void RemoveUncheckedResult(proto::HolisticResult& holistic_result) { + for (auto& landmark : + *holistic_result.mutable_pose_landmarks()->mutable_landmark()) { + landmark.clear_z(); + landmark.clear_visibility(); + landmark.clear_presence(); + } +} + +class HolisticPoseTrackingTest : public testing::Test {}; + +TEST_F(HolisticPoseTrackingTest, VerifyGraph) { + Graph graph; + Stream image = graph.In("IMAGE").Cast().SetName(kImageInStream); + pose_detector::proto::PoseDetectorGraphOptions pose_detector_graph_options; + pose_detector_graph_options.mutable_base_options() + ->mutable_model_asset() + ->set_file_name(GetFilePath("pose_detection.tflite")); + pose_detector_graph_options.set_num_poses(1); + pose_landmarker::proto::PoseLandmarksDetectorGraphOptions + pose_landmarks_detector_graph_options; + pose_landmarks_detector_graph_options.mutable_base_options() + ->mutable_model_asset() + ->set_file_name(GetFilePath("pose_landmark_lite.tflite")); + HolisticPoseTrackingRequest request; + request.landmarks = true; + request.world_landmarks = true; + MP_ASSERT_OK_AND_ASSIGN( + HolisticPoseTrackingOutput result, + TrackHolisticPose(image, pose_detector_graph_options, + pose_landmarks_detector_graph_options, request, graph)); + result.landmarks->SetName(kPoseLandmarksOutStream) >> + graph.Out("POSE_LANDMARKS"); + result.world_landmarks->SetName(kPoseWorldLandmarksOutStream) >> + graph.Out("POSE_WORLD_LANDMARKS"); + + auto config = graph.GetConfig(); + core::FixGraphBackEdges(config); + + // Read the expected graph config. + std::string expected_graph_contents; + MP_ASSERT_OK(file::GetContents( + file::JoinPath("./", kTestDataDirectory, kHolisticPoseTrackingGraph), + &expected_graph_contents)); + + // Need to replace the expected graph config with the test srcdir, because + // each run has different test dir on TAP. 
+ expected_graph_contents = absl::Substitute( + expected_graph_contents, FLAGS_test_srcdir, FLAGS_test_srcdir); + CalculatorGraphConfig expected_graph = + ParseTextProtoOrDie(expected_graph_contents); + + EXPECT_THAT(config, testing::proto::IgnoringRepeatedFieldOrdering( + testing::EqualsProto(expected_graph))); +} + +TEST_F(HolisticPoseTrackingTest, SmokeTest) { + MP_ASSERT_OK_AND_ASSIGN(Image image, + DecodeImageFromFile(GetFilePath(kTestImageFile))); + + proto::HolisticResult holistic_result; + MP_ASSERT_OK(GetTextProto(GetFilePath(kHolisticResultFile), &holistic_result, + Defaults())); + RemoveUncheckedResult(holistic_result); + MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner()); + MP_ASSERT_OK_AND_ASSIGN(auto output_packets, + task_runner->Process({{std::string(kImageInStream), + MakePacket(image)}})); + auto pose_landmarks = output_packets.at(std::string(kPoseLandmarksOutStream)) + .Get(); + EXPECT_THAT( + pose_landmarks, + Approximately(Partially(EqualsProto(holistic_result.pose_landmarks())), + /*margin=*/kAbsMargin)); + auto rendered_image = + output_packets.at(std::string(kRenderedImageOutStream)).Get(); + MP_EXPECT_OK(SavePngTestOutput(*rendered_image.GetImageFrameSharedPtr(), + "pose_landmarks")); +} + +} // namespace +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/proto/BUILD b/mediapipe/tasks/cc/vision/holistic_landmarker/proto/BUILD new file mode 100644 index 000000000..147f3cc86 --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/proto/BUILD @@ -0,0 +1,44 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +package(default_visibility = [ + "//mediapipe/tasks:internal", +]) + +licenses(["notice"]) + +mediapipe_proto_library( + name = "holistic_result_proto", + srcs = ["holistic_result.proto"], + deps = [ + "//mediapipe/framework/formats:classification_proto", + "//mediapipe/framework/formats:landmark_proto", + ], +) + +mediapipe_proto_library( + name = "holistic_landmarker_graph_options_proto", + srcs = ["holistic_landmarker_graph_options.proto"], + deps = [ + "//mediapipe/tasks/cc/core/proto:base_options_proto", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_proto", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_roi_refinement_graph_options_proto", + "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_proto", + "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_proto", + ], +) diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_landmarker_graph_options.proto b/mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_landmarker_graph_options.proto new file mode 100644 index 000000000..86aba8887 --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_landmarker_graph_options.proto @@ -0,0 +1,57 @@ +/* Copyright 2023 The MediaPipe 
Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package mediapipe.tasks.vision.holistic_landmarker.proto; + +import "mediapipe/tasks/cc/core/proto/base_options.proto"; +import "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto"; +import "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.proto"; +import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto"; +import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.proto"; +import "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.proto"; +import "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.proto"; + +option java_package = "com.google.mediapipe.tasks.vision.holisticlandmarker.proto"; +option java_outer_classname = "HolisticLandmarkerGraphOptionsProto"; + +message HolisticLandmarkerGraphOptions { + // Base options for configuring MediaPipe Tasks, such as specifying the model + // asset bundle file with metadata, accelerator options, etc. + core.proto.BaseOptions base_options = 1; + + // Options for hand landmarks graph. + hand_landmarker.proto.HandLandmarksDetectorGraphOptions + hand_landmarks_detector_graph_options = 2; + + // Options for hand roi refinement graph. 
+ hand_landmarker.proto.HandRoiRefinementGraphOptions + hand_roi_refinement_graph_options = 3; + + // Options for face detector graph. + face_detector.proto.FaceDetectorGraphOptions face_detector_graph_options = 4; + + // Options for face landmarks detector graph. + face_landmarker.proto.FaceLandmarksDetectorGraphOptions + face_landmarks_detector_graph_options = 5; + + // Options for pose detector graph. + pose_detector.proto.PoseDetectorGraphOptions pose_detector_graph_options = 6; + + // Options for pose landmarks detector graph. + pose_landmarker.proto.PoseLandmarksDetectorGraphOptions + pose_landmarks_detector_graph_options = 7; +} diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.proto b/mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.proto new file mode 100644 index 000000000..356da45d9 --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.proto @@ -0,0 +1,34 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +syntax = "proto3"; + +package mediapipe.tasks.vision.holistic_landmarker.proto; + +import "mediapipe/framework/formats/classification.proto"; +import "mediapipe/framework/formats/landmark.proto"; + +option java_package = "com.google.mediapipe.tasks.vision.holisticlandmarker"; +option java_outer_classname = "HolisticResultProto"; + +message HolisticResult { + mediapipe.NormalizedLandmarkList pose_landmarks = 1; + mediapipe.LandmarkList pose_world_landmarks = 7; + mediapipe.NormalizedLandmarkList left_hand_landmarks = 2; + mediapipe.NormalizedLandmarkList right_hand_landmarks = 3; + mediapipe.NormalizedLandmarkList face_landmarks = 4; + mediapipe.ClassificationList face_blendshapes = 6; + mediapipe.NormalizedLandmarkList auxiliary_landmarks = 5; +} From 62bafd39bb5ff53ffecbfff2de85b00d875e86d7 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 28 Nov 2023 14:57:41 -0800 Subject: [PATCH 8/9] HolisticLandmarker Java API PiperOrigin-RevId: 586113048 --- .../mediapipe/tasks/mediapipe_tasks_aar.bzl | 7 +- .../com/google/mediapipe/tasks/vision/BUILD | 12 + .../holisticlandmarker/AndroidManifest.xml | 8 + .../HolisticLandmarker.java | 668 ++++++++++++++++++ .../holisticlandmarker/AndroidManifest.xml | 24 + .../tasks/vision/holisticlandmarker/BUILD | 19 + .../HolisticLandmarkerTest.java | 512 ++++++++++++++ mediapipe/tasks/testdata/vision/BUILD | 1 + 8 files changed, 1248 insertions(+), 3 deletions(-) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/AndroidManifest.xml create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarker.java create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/AndroidManifest.xml create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/BUILD create mode 100644 
mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarkerTest.java diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index 916323372..e63695e31 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -47,13 +47,14 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_landmarker_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", - "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", 
"//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite", "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_java_proto_lite", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index fc56bfa27..2d5ef7a9c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -67,6 +67,7 @@ cc_binary( "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph", "//mediapipe/tasks/cc/vision/face_stylizer:face_stylizer_graph", "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", + "//mediapipe/tasks/cc/vision/holistic_landmarker:holistic_landmarker_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", @@ -429,6 +430,7 @@ filegroup( android_library( name = "holisticlandmarker", srcs = [ + "holisticlandmarker/HolisticLandmarker.java", "holisticlandmarker/HolisticLandmarkerResult.java", ], javacopts = [ @@ -439,10 +441,20 @@ android_library( ":core", "//mediapipe/framework/formats:classification_java_proto_lite", "//mediapipe/framework/formats:landmark_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", + 
"//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_landmarker_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:any_java_proto", "//third_party:autovalue", "@maven//:com_google_guava_guava", ], diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/AndroidManifest.xml new file mode 100644 index 000000000..a90c388f4 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarker.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarker.java new file mode 100644 index 000000000..e80da4fca --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarker.java @@ -0,0 +1,668 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.holisticlandmarker; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList; +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; +import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; +import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.ByteBufferImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import 
com.google.mediapipe.tasks.vision.facedetector.proto.FaceDetectorGraphOptionsProto.FaceDetectorGraphOptions; +import com.google.mediapipe.tasks.vision.facelandmarker.proto.FaceLandmarksDetectorGraphOptionsProto.FaceLandmarksDetectorGraphOptions; +import com.google.mediapipe.tasks.vision.handlandmarker.proto.HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions; +import com.google.mediapipe.tasks.vision.holisticlandmarker.proto.HolisticLandmarkerGraphOptionsProto.HolisticLandmarkerGraphOptions; +import com.google.mediapipe.tasks.vision.posedetector.proto.PoseDetectorGraphOptionsProto.PoseDetectorGraphOptions; +import com.google.mediapipe.tasks.vision.poselandmarker.proto.PoseLandmarksDetectorGraphOptionsProto.PoseLandmarksDetectorGraphOptions; +import com.google.protobuf.Any; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Performs holistic landmarks detection on images. + * + *

<p>This API expects a pre-trained holistic landmarks model asset bundle. + * + * <ul> + * <li>Input image {@link MPImage} + * <ul> + * <li>The image that holistic landmarks detection runs on. + * </ul> + * <li>Output {@link HolisticLandmarkerResult} + * <ul> + * <li>A HolisticLandmarkerResult containing holistic landmarks. + * </ul> + * </ul>
+ */ +public final class HolisticLandmarker extends BaseVisionTaskApi { + private static final String TAG = HolisticLandmarker.class.getSimpleName(); + + private static final String IMAGE_IN_STREAM_NAME = "image_in"; + private static final String POSE_LANDMARKS_STREAM = "pose_landmarks"; + private static final String POSE_WORLD_LANDMARKS_STREAM = "pose_world_landmarks"; + private static final String POSE_SEGMENTATION_MASK_STREAM = "pose_segmentation_mask"; + private static final String FACE_LANDMARKS_STREAM = "face_landmarks"; + private static final String FACE_BLENDSHAPES_STREAM = "extra_blendshapes"; + private static final String LEFT_HAND_LANDMARKS_STREAM = "left_hand_landmarks"; + private static final String LEFT_HAND_WORLD_LANDMARKS_STREAM = "left_hand_world_landmarks"; + private static final String RIGHT_HAND_LANDMARKS_STREAM = "right_hand_landmarks"; + private static final String RIGHT_HAND_WORLD_LANDMARKS_STREAM = "right_hand_world_landmarks"; + private static final String IMAGE_OUT_STREAM_NAME = "image_out"; + + private static final int FACE_LANDMARKS_OUT_STREAM_INDEX = 0; + private static final int POSE_LANDMARKS_OUT_STREAM_INDEX = 1; + private static final int POSE_WORLD_LANDMARKS_OUT_STREAM_INDEX = 2; + private static final int LEFT_HAND_LANDMARKS_OUT_STREAM_INDEX = 3; + private static final int LEFT_HAND_WORLD_LANDMARKS_OUT_STREAM_INDEX = 4; + private static final int RIGHT_HAND_LANDMARKS_OUT_STREAM_INDEX = 5; + private static final int RIGHT_HAND_WORLD_LANDMARKS_OUT_STREAM_INDEX = 6; + private static final int IMAGE_OUT_STREAM_INDEX = 7; + + private static final float DEFAULT_PRESENCE_THRESHOLD = 0.5f; + private static final float DEFAULT_SUPPRESION_THRESHOLD = 0.3f; + private static final boolean DEFAULT_OUTPUT_FACE_BLENDSHAPES = false; + private static final boolean DEFAULT_OUTPUT_SEGMENTATION_MASKS = false; + + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.vision.holistic_landmarker.HolisticLandmarkerGraph"; + + 
@SuppressWarnings("ConstantCaseForConstants") + private static final List INPUT_STREAMS = + Collections.unmodifiableList(Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME)); + + static { + System.loadLibrary("mediapipe_tasks_vision_jni"); + } + + /** + * Creates a {@link HolisticLandmarker} instance from a model asset bundle path and the default + * {@link HolisticLandmarkerOptions}. + * + * @param context an Android {@link Context}. + * @param modelAssetPath path to the holistic landmarks model with metadata in the assets. + * @throws MediaPipeException if there is an error during {@link HolisticLandmarker} creation. + */ + public static HolisticLandmarker createFromFile(Context context, String modelAssetPath) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelAssetPath).build(); + return createFromOptions( + context, HolisticLandmarkerOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates a {@link HolisticLandmarker} instance from a model asset bundle file and the default + * {@link HolisticLandmarkerOptions}. + * + * @param context an Android {@link Context}. + * @param modelAssetFile the holistic landmarks model {@link File} instance. + * @throws IOException if an I/O error occurs when opening the tflite model file. + * @throws MediaPipeException if there is an error during {@link HolisticLandmarker} creation. + */ + public static HolisticLandmarker createFromFile(Context context, File modelAssetFile) + throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelAssetFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + BaseOptions baseOptions = + BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build(); + return createFromOptions( + context, HolisticLandmarkerOptions.builder().setBaseOptions(baseOptions).build()); + } + } + + /** + * Creates a {@link HolisticLandmarker} instance from a model asset bundle buffer and the default + * {@link HolisticLandmarkerOptions}. 
+ * + * @param context an Android {@link Context}. + * @param modelAssetBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the + * detection model. + * @throws MediaPipeException if there is an error during {@link HolisticLandmarker} creation. + */ + public static HolisticLandmarker createFromBuffer( + Context context, final ByteBuffer modelAssetBuffer) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelAssetBuffer).build(); + return createFromOptions( + context, HolisticLandmarkerOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates a {@link HolisticLandmarker} instance from a {@link HolisticLandmarkerOptions}. + * + * @param context an Android {@link Context}. + * @param landmarkerOptions a {@link HolisticLandmarkerOptions} instance. + * @throws MediaPipeException if there is an error during {@link HolisticLandmarker} creation. + */ + public static HolisticLandmarker createFromOptions( + Context context, HolisticLandmarkerOptions landmarkerOptions) { + List outputStreams = new ArrayList<>(); + outputStreams.add("FACE_LANDMARKS:" + FACE_LANDMARKS_STREAM); + outputStreams.add("POSE_LANDMARKS:" + POSE_LANDMARKS_STREAM); + outputStreams.add("POSE_WORLD_LANDMARKS:" + POSE_WORLD_LANDMARKS_STREAM); + outputStreams.add("LEFT_HAND_LANDMARKS:" + LEFT_HAND_LANDMARKS_STREAM); + outputStreams.add("LEFT_HAND_WORLD_LANDMARKS:" + LEFT_HAND_WORLD_LANDMARKS_STREAM); + outputStreams.add("RIGHT_HAND_LANDMARKS:" + RIGHT_HAND_LANDMARKS_STREAM); + outputStreams.add("RIGHT_HAND_WORLD_LANDMARKS:" + RIGHT_HAND_WORLD_LANDMARKS_STREAM); + outputStreams.add("IMAGE:" + IMAGE_OUT_STREAM_NAME); + + int[] faceBlendshapesOutStreamIndex = new int[] {-1}; + if (landmarkerOptions.outputFaceBlendshapes()) { + outputStreams.add("FACE_BLENDSHAPES:" + FACE_BLENDSHAPES_STREAM); + faceBlendshapesOutStreamIndex[0] = outputStreams.size() - 1; + } + + int[] poseSegmentationMasksOutStreamIndex = new int[] {-1}; + if 
(landmarkerOptions.outputPoseSegmentationMasks()) { + outputStreams.add("POSE_SEGMENTATION_MASK:" + POSE_SEGMENTATION_MASK_STREAM); + poseSegmentationMasksOutStreamIndex[0] = outputStreams.size() - 1; + } + + OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public HolisticLandmarkerResult convertToTaskResult(List packets) { + // If there are no detected landmarks, just returns empty lists. + if (packets.get(FACE_LANDMARKS_OUT_STREAM_INDEX).isEmpty()) { + return HolisticLandmarkerResult.createEmpty( + BaseVisionTaskApi.generateResultTimestampMs( + landmarkerOptions.runningMode(), + packets.get(FACE_LANDMARKS_OUT_STREAM_INDEX))); + } + + NormalizedLandmarkList faceLandmarkProtos = + PacketGetter.getProto( + packets.get(FACE_LANDMARKS_OUT_STREAM_INDEX), NormalizedLandmarkList.parser()); + Optional faceBlendshapeProtos = + landmarkerOptions.outputFaceBlendshapes() + ? Optional.of( + PacketGetter.getProto( + packets.get(faceBlendshapesOutStreamIndex[0]), + ClassificationList.parser())) + : Optional.empty(); + NormalizedLandmarkList poseLandmarkProtos = + PacketGetter.getProto( + packets.get(POSE_LANDMARKS_OUT_STREAM_INDEX), NormalizedLandmarkList.parser()); + LandmarkList poseWorldLandmarkProtos = + PacketGetter.getProto( + packets.get(POSE_WORLD_LANDMARKS_OUT_STREAM_INDEX), LandmarkList.parser()); + Optional segmentationMask = + landmarkerOptions.outputPoseSegmentationMasks() + ? 
Optional.of( + getSegmentationMask(packets, poseSegmentationMasksOutStreamIndex[0])) + : Optional.empty(); + NormalizedLandmarkList leftHandLandmarkProtos = + PacketGetter.getProto( + packets.get(LEFT_HAND_LANDMARKS_OUT_STREAM_INDEX), + NormalizedLandmarkList.parser()); + LandmarkList leftHandWorldLandmarkProtos = + PacketGetter.getProto( + packets.get(LEFT_HAND_WORLD_LANDMARKS_OUT_STREAM_INDEX), LandmarkList.parser()); + NormalizedLandmarkList rightHandLandmarkProtos = + PacketGetter.getProto( + packets.get(RIGHT_HAND_LANDMARKS_OUT_STREAM_INDEX), + NormalizedLandmarkList.parser()); + LandmarkList rightHandWorldLandmarkProtos = + PacketGetter.getProto( + packets.get(RIGHT_HAND_WORLD_LANDMARKS_OUT_STREAM_INDEX), + LandmarkList.parser()); + + return HolisticLandmarkerResult.create( + faceLandmarkProtos, + faceBlendshapeProtos, + poseLandmarkProtos, + poseWorldLandmarkProtos, + segmentationMask, + leftHandLandmarkProtos, + leftHandWorldLandmarkProtos, + rightHandLandmarkProtos, + rightHandWorldLandmarkProtos, + BaseVisionTaskApi.generateResultTimestampMs( + landmarkerOptions.runningMode(), packets.get(FACE_LANDMARKS_OUT_STREAM_INDEX))); + } + + @Override + public MPImage convertToTaskInput(List packets) { + return new BitmapImageBuilder( + AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) + .build(); + } + }); + landmarkerOptions.resultListener().ifPresent(handler::setResultListener); + landmarkerOptions.errorListener().ifPresent(handler::setErrorListener); + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskName(HolisticLandmarker.class.getSimpleName()) + .setTaskRunningModeName(landmarkerOptions.runningMode().name()) + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(outputStreams) + .setTaskOptions(landmarkerOptions) + .setEnableFlowLimiting(landmarkerOptions.runningMode() == RunningMode.LIVE_STREAM) + .build(), + handler); + return new HolisticLandmarker(runner, 
landmarkerOptions.runningMode()); + } + + /** + * Constructor to initialize an {@link HolisticLandmarker} from a {@link TaskRunner} and a {@link + * RunningMode}. + * + * @param taskRunner a {@link TaskRunner}. + * @param runningMode a mediapipe vision task {@link RunningMode}. + */ + private HolisticLandmarker(TaskRunner taskRunner, RunningMode runningMode) { + super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, /* normRectStreamName= */ ""); + } + + /** + * Performs holistic landmarks detection on the provided single image with default image + * processing options, i.e. without any rotation applied. Only use this method when the {@link + * HolisticLandmarker} is created with {@link RunningMode.IMAGE}. + * + *

<p>{@link HolisticLandmarker} supports the following color space types: + * + *
<ul> + *
  <li>{@link Bitmap.Config.ARGB_8888} + *
</ul>
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public HolisticLandmarkerResult detect(MPImage image) { + return detect(image, ImageProcessingOptions.builder().build()); + } + + /** + * Performs holistic landmarks detection on the provided single image. Only use this method when + * the {@link HolisticLandmarker} is created with {@link RunningMode.IMAGE}. + * + *

<p>{@link HolisticLandmarker} supports the following color space types: + * + *
<ul> + *
  <li>{@link Bitmap.Config.ARGB_8888} + *
</ul>
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public HolisticLandmarkerResult detect( + MPImage image, ImageProcessingOptions imageProcessingOptions) { + validateImageProcessingOptions(imageProcessingOptions); + return (HolisticLandmarkerResult) processImageData(image, imageProcessingOptions); + } + + /** + * Performs holistic landmarks detection on the provided video frame with default image processing + * options, i.e. without any rotation applied. Only use this method when the {@link + * HolisticLandmarker} is created with {@link RunningMode.VIDEO}. + * + *

<p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

<p>{@link HolisticLandmarker} supports the following color space types: + * + *
<ul> + *
  <li>{@link Bitmap.Config.ARGB_8888} + *
</ul>
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public HolisticLandmarkerResult detectForVideo(MPImage image, long timestampMs) { + return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Performs holistic landmarks detection on the provided video frame. Only use this method when + * the {@link HolisticLandmarker} is created with {@link RunningMode.VIDEO}. + * + *

<p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

<p>{@link HolisticLandmarker} supports the following color space types: + * + *
<ul> + *
  <li>{@link Bitmap.Config.ARGB_8888} + *
</ul>
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public HolisticLandmarkerResult detectForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + return (HolisticLandmarkerResult) processVideoData(image, imageProcessingOptions, timestampMs); + } + + /** + * Sends live image data to perform holistic landmarks detection with default image processing + * options, i.e. without any rotation applied, and the results will be available via the {@link + * ResultListener} provided in the {@link HolisticLandmarkerOptions}. Only use this method when + * the {@link HolisticLandmarker } is created with {@link RunningMode.LIVE_STREAM}. + * + *

<p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the holistic landmarker. The input timestamps must be monotonically increasing. + * + *
<p>{@link HolisticLandmarker} supports the following color space types: + * + *
<ul> + *
  <li>{@link Bitmap.Config.ARGB_8888} + *
</ul>
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void detectAsync(MPImage image, long timestampMs) { + detectAsync(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Sends live image data to perform holistic landmarks detection, and the results will be + * available via the {@link ResultListener} provided in the {@link HolisticLandmarkerOptions}. + * Only use this method when the {@link HolisticLandmarker} is created with {@link + * RunningMode.LIVE_STREAM}. + * + *

<p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the holistic landmarker. The input timestamps must be monotonically increasing. + * + *
<p>{@link HolisticLandmarker} supports the following color space types: + * + *
<ul> + *
  <li>{@link Bitmap.Config.ARGB_8888} + *
</ul>
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public void detectAsync( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + sendLiveStreamData(image, imageProcessingOptions, timestampMs); + } + + /** Options for setting up an {@link HolisticLandmarker}. */ + @AutoValue + public abstract static class HolisticLandmarkerOptions extends TaskOptions { + + /** Builder for {@link HolisticLandmarkerOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the base options for the holistic landmarker task. */ + public abstract Builder setBaseOptions(BaseOptions value); + + /** + * Sets the running mode for the holistic landmarker task. Defaults to the image mode. + * Holistic landmarker has three modes: + * + *
<ul> + *
  <li>IMAGE: The mode for detecting holistic landmarks on single image inputs. + *
  <li>VIDEO: The mode for detecting holistic landmarks on the decoded frames of a video. + *
  <li>LIVE_STREAM: The mode for detecting holistic landmarks on a live stream of input + * data, such as from camera. In this mode, {@code setResultListener} must be called to + * set up a listener to receive the detection results asynchronously. + *
</ul>
+ */ + public abstract Builder setRunningMode(RunningMode value); + + /** + * Sets minimum confidence score for the face detection to be considered successful. Defaults + * to 0.5. + */ + public abstract Builder setMinFaceDetectionConfidence(Float value); + + /** + * The minimum threshold for the face suppression score in the face detection. Defaults to + * 0.3. + */ + public abstract Builder setMinFaceSuppressionThreshold(Float value); + + /** + * Sets minimum confidence score for the face landmark detection to be considered successful. + * Defaults to 0.5. + */ + public abstract Builder setMinFaceLandmarksConfidence(Float value); + + /** + * The minimum confidence score for the pose detection to be considered successful. Defaults + * to 0.5. + */ + public abstract Builder setMinPoseDetectionConfidence(Float value); + + /** + * The minimum threshold for the pose suppression score in the pose detection. Defaults to + * 0.3. + */ + public abstract Builder setMinPoseSuppressionThreshold(Float value); + + /** + * The minimum confidence score for the pose landmarks detection to be considered successful. + * Defaults to 0.5. + */ + public abstract Builder setMinPoseLandmarksConfidence(Float value); + + /** + * The minimum confidence score for the hand landmark detection to be considered successful. + * Defaults to 0.5. + */ + public abstract Builder setMinHandLandmarksConfidence(Float value); + + /** Whether to output segmentation masks. Defaults to false. */ + public abstract Builder setOutputPoseSegmentationMasks(Boolean value); + + /** Whether to output face blendshapes. Defaults to false. */ + public abstract Builder setOutputFaceBlendshapes(Boolean value); + + /** + * Sets the result listener to receive the detection results asynchronously when the holistic + * landmarker is in the live stream mode. + */ + public abstract Builder setResultListener( + ResultListener value); + + /** Sets an optional error listener. 
*/ + public abstract Builder setErrorListener(ErrorListener value); + + abstract HolisticLandmarkerOptions autoBuild(); + + /** + * Validates and builds the {@link HolisticLandmarkerOptions} instance. + * + * @throws IllegalArgumentException if the result listener and the running mode are not + * properly configured. The result listener should only be set when the holistic + * landmarker is in the live stream mode. + */ + public final HolisticLandmarkerOptions build() { + HolisticLandmarkerOptions options = autoBuild(); + if (options.runningMode() == RunningMode.LIVE_STREAM) { + if (!options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The holistic landmarker is in the live stream mode, a user-defined result listener" + + " must be provided in HolisticLandmarkerOptions."); + } + } else if (options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The holistic landmarker is in the image or the video mode, a user-defined result" + + " listener shouldn't be provided in HolisticLandmarkerOptions."); + } + return options; + } + } + + abstract BaseOptions baseOptions(); + + abstract RunningMode runningMode(); + + abstract Optional minFaceDetectionConfidence(); + + abstract Optional minFaceSuppressionThreshold(); + + abstract Optional minFaceLandmarksConfidence(); + + abstract Optional minPoseDetectionConfidence(); + + abstract Optional minPoseSuppressionThreshold(); + + abstract Optional minPoseLandmarksConfidence(); + + abstract Optional minHandLandmarksConfidence(); + + abstract Boolean outputFaceBlendshapes(); + + abstract Boolean outputPoseSegmentationMasks(); + + abstract Optional> resultListener(); + + abstract Optional errorListener(); + + public static Builder builder() { + return new AutoValue_HolisticLandmarker_HolisticLandmarkerOptions.Builder() + .setRunningMode(RunningMode.IMAGE) + .setMinFaceDetectionConfidence(DEFAULT_PRESENCE_THRESHOLD) + .setMinFaceSuppressionThreshold(DEFAULT_SUPPRESION_THRESHOLD) + 
.setMinFaceLandmarksConfidence(DEFAULT_PRESENCE_THRESHOLD) + .setMinPoseDetectionConfidence(DEFAULT_PRESENCE_THRESHOLD) + .setMinPoseSuppressionThreshold(DEFAULT_SUPPRESION_THRESHOLD) + .setMinPoseLandmarksConfidence(DEFAULT_PRESENCE_THRESHOLD) + .setMinHandLandmarksConfidence(DEFAULT_PRESENCE_THRESHOLD) + .setOutputFaceBlendshapes(DEFAULT_OUTPUT_FACE_BLENDSHAPES) + .setOutputPoseSegmentationMasks(DEFAULT_OUTPUT_SEGMENTATION_MASKS); + } + + /** Converts a {@link HolisticLandmarkerOptions} to a {@link Any} protobuf message. */ + @Override + public Any convertToAnyProto() { + HolisticLandmarkerGraphOptions.Builder holisticLandmarkerGraphOptions = + HolisticLandmarkerGraphOptions.newBuilder() + .setBaseOptions( + BaseOptionsProto.BaseOptions.newBuilder() + .setUseStreamMode(runningMode() != RunningMode.IMAGE) + .mergeFrom(convertBaseOptionsToProto(baseOptions())) + .build()); + + HandLandmarksDetectorGraphOptions.Builder handLandmarksDetectorGraphOptions = + HandLandmarksDetectorGraphOptions.newBuilder(); + FaceDetectorGraphOptions.Builder faceDetectorGraphOptions = + FaceDetectorGraphOptions.newBuilder(); + FaceLandmarksDetectorGraphOptions.Builder faceLandmarksDetectorGraphOptions = + FaceLandmarksDetectorGraphOptions.newBuilder(); + PoseDetectorGraphOptions.Builder poseDetectorGraphOptions = + PoseDetectorGraphOptions.newBuilder(); + PoseLandmarksDetectorGraphOptions.Builder poseLandmarkerGraphOptions = + PoseLandmarksDetectorGraphOptions.newBuilder(); + + // Configure hand detector options. + minHandLandmarksConfidence() + .ifPresent(handLandmarksDetectorGraphOptions::setMinDetectionConfidence); + + // Configure pose detector options. + minPoseDetectionConfidence().ifPresent(poseDetectorGraphOptions::setMinDetectionConfidence); + minPoseSuppressionThreshold().ifPresent(poseDetectorGraphOptions::setMinSuppressionThreshold); + minPoseLandmarksConfidence().ifPresent(poseLandmarkerGraphOptions::setMinDetectionConfidence); + + // Configure face detector options. 
+ minFaceDetectionConfidence().ifPresent(faceDetectorGraphOptions::setMinDetectionConfidence); + minFaceSuppressionThreshold().ifPresent(faceDetectorGraphOptions::setMinSuppressionThreshold); + minFaceLandmarksConfidence() + .ifPresent(faceLandmarksDetectorGraphOptions::setMinDetectionConfidence); + + holisticLandmarkerGraphOptions + .setHandLandmarksDetectorGraphOptions(handLandmarksDetectorGraphOptions.build()) + .setFaceDetectorGraphOptions(faceDetectorGraphOptions.build()) + .setFaceLandmarksDetectorGraphOptions(faceLandmarksDetectorGraphOptions.build()) + .setPoseDetectorGraphOptions(poseDetectorGraphOptions.build()) + .setPoseLandmarksDetectorGraphOptions(poseLandmarkerGraphOptions.build()); + + return Any.newBuilder() + .setTypeUrl( + "type.googleapis.com/mediapipe.tasks.vision.holistic_landmarker.proto.HolisticLandmarkerGraphOptions") + .setValue(holisticLandmarkerGraphOptions.build().toByteString()) + .build(); + } + } + + /** + * Validates that the provided {@link ImageProcessingOptions} doesn"t contain a + * region-of-interest. 
+ */ + private static void validateImageProcessingOptions( + ImageProcessingOptions imageProcessingOptions) { + if (imageProcessingOptions.regionOfInterest().isPresent()) { + throw new IllegalArgumentException("HolisticLandmarker doesn't support region-of-interest."); + } + } + + private static MPImage getSegmentationMask(List packets, int packetIndex) { + int width = PacketGetter.getImageWidth(packets.get(packetIndex)); + int height = PacketGetter.getImageHeight(packets.get(packetIndex)); + ByteBuffer buffer = ByteBuffer.allocateDirect(width * height * 4); + + if (!PacketGetter.getImageData(packets.get(packetIndex), buffer)) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), + "There was an error getting the sefmentation mask."); + } + + ByteBufferImageBuilder builder = + new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1); + return builder.build(); + } +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/AndroidManifest.xml new file mode 100644 index 000000000..22b19b702 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/BUILD new file mode 100644 index 000000000..287602c85 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/BUILD @@ -0,0 +1,19 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarkerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarkerTest.java new file mode 100644 index 000000000..f8c87c798 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarkerTest.java @@ -0,0 +1,512 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.vision.holisticlandmarker; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import android.content.res.AssetManager; +import android.graphics.BitmapFactory; +import android.graphics.RectF; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.common.truth.Correspondence; +import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList; +import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.ByteBufferImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.holisticlandmarker.HolisticLandmarker.HolisticLandmarkerOptions; +import com.google.mediapipe.tasks.vision.holisticlandmarker.HolisticResultProto.HolisticResult; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.Optional; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** Test for {@link HolisticLandmarker}. 
*/ +@RunWith(Suite.class) +@SuiteClasses({HolisticLandmarkerTest.General.class, HolisticLandmarkerTest.RunningModeTest.class}) +public class HolisticLandmarkerTest { + private static final String HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE = "holistic_landmarker.task"; + private static final String POSE_IMAGE = "male_full_height_hands.jpg"; + private static final String CAT_IMAGE = "cat.jpg"; + private static final String HOLISTIC_RESULT = "male_full_height_hands_result_cpu.pb"; + private static final String TAG = "Holistic Landmarker Test"; + private static final float FACE_LANDMARKS_ERROR_TOLERANCE = 0.03f; + private static final float FACE_BLENDSHAPES_ERROR_TOLERANCE = 0.13f; + private static final MPImage PLACEHOLDER_MASK = + new ByteBufferImageBuilder( + ByteBuffer.allocate(0), /* widht= */ 0, /* height= */ 0, MPImage.IMAGE_FORMAT_VEC32F1) + .build(); + private static final int IMAGE_WIDTH = 638; + private static final int IMAGE_HEIGHT = 1000; + + private static final Correspondence VALIDATE_LANDMARRKS = + Correspondence.from( + (Correspondence.BinaryPredicate) + (actual, expected) -> { + return Correspondence.tolerance(FACE_LANDMARKS_ERROR_TOLERANCE) + .compare(actual.x(), expected.x()) + && Correspondence.tolerance(FACE_LANDMARKS_ERROR_TOLERANCE) + .compare(actual.y(), expected.y()); + }, + "landmarks approximately equal to"); + + private static final Correspondence VALIDATE_BLENDSHAPES = + Correspondence.from( + (Correspondence.BinaryPredicate) + (actual, expected) -> + Correspondence.tolerance(FACE_BLENDSHAPES_ERROR_TOLERANCE) + .compare(actual.score(), expected.score()) + && actual.index() == expected.index() + && actual.categoryName().equals(expected.categoryName()), + "face blendshapes approximately equal to"); + + @RunWith(AndroidJUnit4.class) + public static final class General extends HolisticLandmarkerTest { + + @Test + public void detect_successWithValidModels() throws Exception { + HolisticLandmarkerOptions options = + 
HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .build(); + HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + HolisticLandmarkerResult actualResult = + holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE)); + HolisticLandmarkerResult expectedResult = + getExpectedHolisticLandmarkerResult( + HOLISTIC_RESULT, /* hasFaceBlendshapes= */ false, /* hasSegmentationMask= */ false); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + + @Test + public void detect_successWithBlendshapes() throws Exception { + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setOutputFaceBlendshapes(true) + .build(); + HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + HolisticLandmarkerResult actualResult = + holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE)); + HolisticLandmarkerResult expectedResult = + getExpectedHolisticLandmarkerResult( + HOLISTIC_RESULT, /* hasFaceBlendshapes= */ true, /* hasSegmentationMask= */ false); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + + @Test + public void detect_successWithSegmentationMasks() throws Exception { + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setOutputPoseSegmentationMasks(true) + .build(); + HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + HolisticLandmarkerResult actualResult = + 
holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE)); + HolisticLandmarkerResult expectedResult = + getExpectedHolisticLandmarkerResult( + HOLISTIC_RESULT, /* hasFaceBlendshapes= */ false, /* hasSegmentationMask= */ true); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + + @Test + public void detect_successWithEmptyResult() throws Exception { + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .build(); + HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + HolisticLandmarkerResult actualResult = + holisticLandmarker.detect(getImageFromAsset(CAT_IMAGE)); + assertThat(actualResult.faceLandmarks()).isEmpty(); + } + + @Test + public void detect_failsWithRegionOfInterest() throws Exception { + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .build(); + HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build(); + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE), imageProcessingOptions)); + assertThat(exception) + .hasMessageThat() + .contains("HolisticLandmarker doesn't support region-of-interest"); + } + } + + @RunWith(AndroidJUnit4.class) + public static final class RunningModeTest extends HolisticLandmarkerTest { + private void assertCreationFailsWithResultListenerInNonLiveStreamMode(RunningMode runningMode) + throws Exception { + 
IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(runningMode) + .setResultListener((HolisticLandmarkerResult, inputImage) -> {}) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener shouldn't be provided"); + } + + @Test + public void create_failsWithIllegalResultListenerInVideoMode() throws Exception { + assertCreationFailsWithResultListenerInNonLiveStreamMode(RunningMode.VIDEO); + } + + @Test + public void create_failsWithIllegalResultListenerInImageMode() throws Exception { + assertCreationFailsWithResultListenerInNonLiveStreamMode(RunningMode.IMAGE); + } + + @Test + public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener must be provided"); + } + + @Test + public void detect_failsWithCallingWrongApiInImageMode() throws Exception { + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + holisticLandmarker.detectForVideo( + getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ 0)); + 
assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + holisticLandmarker.detectAsync( + getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void detect_failsWithCallingWrongApiInVideoMode() throws Exception { + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + + HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + holisticLandmarker.detectAsync( + getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void detect_failsWithCallingWrongApiInLiveSteamMode() throws Exception { + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener((HolisticLandmarkerResult, inputImage) -> {}) + .build(); + + HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE))); + 
assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + holisticLandmarker.detectForVideo( + getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + } + + @Test + public void detect_successWithImageMode() throws Exception { + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + HolisticLandmarkerResult actualResult = + holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE)); + HolisticLandmarkerResult expectedResult = + getExpectedHolisticLandmarkerResult( + HOLISTIC_RESULT, /* hasFaceBlendshapes= */ false, /* hasSegmentationMask= */ false); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + + @Test + public void detect_successWithVideoMode() throws Exception { + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + HolisticLandmarkerResult expectedResult = + getExpectedHolisticLandmarkerResult( + HOLISTIC_RESULT, /* hasFaceBlendshapes= */ false, /* hasSegmentationMask= */ false); + for (int i = 0; i < 3; i++) { + HolisticLandmarkerResult actualResult = + holisticLandmarker.detectForVideo(getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ i); + 
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + } + + @Test + public void detect_failsWithOutOfOrderInputTimestamps() throws Exception { + MPImage image = getImageFromAsset(POSE_IMAGE); + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener((actualResult, inputImage) -> {}) + .build(); + try (HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options)) { + holisticLandmarker.detectAsync(image, /* timestampsMs= */ 1); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> holisticLandmarker.detectAsync(image, /* timestampsMs= */ 0)); + assertThat(exception) + .hasMessageThat() + .contains("having a smaller timestamp than the processed timestamp"); + } + } + + @Test + public void detect_successWithLiveSteamMode() throws Exception { + MPImage image = getImageFromAsset(POSE_IMAGE); + HolisticLandmarkerResult expectedResult = + getExpectedHolisticLandmarkerResult( + HOLISTIC_RESULT, /* hasFaceBlendshapes= */ false, /* hasSegmentationMask= */ false); + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (actualResult, inputImage) -> { + assertActualResultApproximatelyEqualsToExpectedResult( + actualResult, expectedResult); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options)) { + for (int i = 0; i < 3; i++) { + holisticLandmarker.detectAsync(image, /* timestampsMs= */ i); + } + } + } 
+ } + + private static MPImage getImageFromAsset(String filePath) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + InputStream istr = assetManager.open(filePath); + return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); + } + + private static HolisticLandmarkerResult getExpectedHolisticLandmarkerResult( + String resultPath, boolean hasFaceBlendshapes, boolean hasSegmentationMask) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + + HolisticResult holisticResult = HolisticResult.parseFrom(assetManager.open(resultPath)); + + Optional blendshapes = + hasFaceBlendshapes + ? Optional.of(holisticResult.getFaceBlendshapes()) + : Optional.empty(); + Optional segmentationMask = + hasSegmentationMask ? Optional.of(PLACEHOLDER_MASK) : Optional.empty(); + + return HolisticLandmarkerResult.create( + holisticResult.getFaceLandmarks(), + blendshapes, + holisticResult.getPoseLandmarks(), + LandmarkList.getDefaultInstance(), + segmentationMask, + holisticResult.getLeftHandLandmarks(), + LandmarkList.getDefaultInstance(), + holisticResult.getRightHandLandmarks(), + LandmarkList.getDefaultInstance(), + /* timestampMs= */ 0); + } + + private static void assertActualResultApproximatelyEqualsToExpectedResult( + HolisticLandmarkerResult actualResult, HolisticLandmarkerResult expectedResult) { + // Expects to have the same number of holistics detected. 
+ assertThat(actualResult.faceLandmarks()).hasSize(expectedResult.faceLandmarks().size()); + assertThat(actualResult.faceBlendshapes().isPresent()) + .isEqualTo(expectedResult.faceBlendshapes().isPresent()); + assertThat(actualResult.poseLandmarks()).hasSize(expectedResult.poseLandmarks().size()); + assertThat(actualResult.segmentationMask().isPresent()) + .isEqualTo(expectedResult.segmentationMask().isPresent()); + assertThat(actualResult.leftHandLandmarks()).hasSize(expectedResult.leftHandLandmarks().size()); + assertThat(actualResult.rightHandLandmarks()) + .hasSize(expectedResult.rightHandLandmarks().size()); + + // Actual face landmarks match expected face landmarks. + assertThat(actualResult.faceLandmarks()) + .comparingElementsUsing(VALIDATE_LANDMARRKS) + .containsExactlyElementsIn(expectedResult.faceLandmarks()); + + // Actual face blendshapes match expected face blendshapes. + if (actualResult.faceBlendshapes().isPresent()) { + assertThat(actualResult.faceBlendshapes().get()) + .comparingElementsUsing(VALIDATE_BLENDSHAPES) + .containsExactlyElementsIn(expectedResult.faceBlendshapes().get()); + } + + // Actual pose landmarks match expected pose landmarks. + assertThat(actualResult.poseLandmarks()) + .comparingElementsUsing(VALIDATE_LANDMARRKS) + .containsExactlyElementsIn(expectedResult.poseLandmarks()); + + if (actualResult.segmentationMask().isPresent()) { + assertImageSizeIsExpected(actualResult.segmentationMask().get()); + } + + // Actual left hand landmarks match expected left hand landmarks. + assertThat(actualResult.leftHandLandmarks()) + .comparingElementsUsing(VALIDATE_LANDMARRKS) + .containsExactlyElementsIn(expectedResult.leftHandLandmarks()); + + // Actual right hand landmarks match expected right hand landmarks. 
+ assertThat(actualResult.rightHandLandmarks()) + .comparingElementsUsing(VALIDATE_LANDMARRKS) + .containsExactlyElementsIn(expectedResult.rightHandLandmarks()); + } + + private static void assertImageSizeIsExpected(MPImage inputImage) { + assertThat(inputImage).isNotNull(); + assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH); + assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT); + } +} diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index 3f83118b0..422241081 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -224,6 +224,7 @@ filegroup( "hand_detector_result_one_hand.pbtxt", "hand_detector_result_one_hand_rotated.pbtxt", "hand_detector_result_two_hands.pbtxt", + "male_full_height_hands_result_cpu.pbtxt", "pointing_up_landmarks.pbtxt", "pointing_up_rotated_landmarks.pbtxt", "portrait_expected_detection.pbtxt", From 91589b10d3c684af00cf8e3d14e4683797ab55bd Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 28 Nov 2023 18:05:07 -0800 Subject: [PATCH 9/9] internal change. PiperOrigin-RevId: 586156439 --- mediapipe/python/solutions/drawing_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mediapipe/python/solutions/drawing_utils.py b/mediapipe/python/solutions/drawing_utils.py index a1acc0be2..78e931264 100644 --- a/mediapipe/python/solutions/drawing_utils.py +++ b/mediapipe/python/solutions/drawing_utils.py @@ -125,7 +125,8 @@ def draw_landmarks( color=RED_COLOR), connection_drawing_spec: Union[DrawingSpec, Mapping[Tuple[int, int], - DrawingSpec]] = DrawingSpec()): + DrawingSpec]] = DrawingSpec(), + is_drawing_landmarks: bool = True): """Draws the landmarks and the connections on the image. Args: @@ -142,6 +143,8 @@ def draw_landmarks( connections to the DrawingSpecs that specifies the connections' drawing settings such as color and line thickness. 
If this argument is explicitly set to None, no landmark connections will be drawn. + is_drawing_landmarks: Whether to draw landmarks. If False, skip drawing + landmarks; only contours will be drawn. Raises: ValueError: If one of the followings: @@ -181,7 +184,7 @@ def draw_landmarks( drawing_spec.thickness) # Draws landmark points after finishing the connection lines, which is # aesthetically better. - if landmark_drawing_spec: + if is_drawing_landmarks and landmark_drawing_spec: for idx, landmark_px in idx_to_coordinates.items(): drawing_spec = landmark_drawing_spec[idx] if isinstance( landmark_drawing_spec, Mapping) else landmark_drawing_spec