diff --git a/mediapipe/tasks/c/components/containers/BUILD b/mediapipe/tasks/c/components/containers/BUILD index 4bb580873..ae6876a2e 100644 --- a/mediapipe/tasks/c/components/containers/BUILD +++ b/mediapipe/tasks/c/components/containers/BUILD @@ -43,6 +43,60 @@ cc_test( ], ) +cc_library( + name = "rect", + hdrs = ["rect.h"], +) + +cc_library( + name = "rect_converter", + srcs = ["rect_converter.cc"], + hdrs = ["rect_converter.h"], + deps = [ + ":rect", + "//mediapipe/tasks/cc/components/containers:rect", + ], +) + +cc_test( + name = "rect_converter_test", + srcs = ["rect_converter_test.cc"], + deps = [ + ":rect", + ":rect_converter", + "//mediapipe/framework/port:gtest", + "//mediapipe/tasks/cc/components/containers:rect", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "keypoint", + hdrs = ["keypoint.h"], +) + +cc_library( + name = "keypoint_converter", + srcs = ["keypoint_converter.cc"], + hdrs = ["keypoint_converter.h"], + deps = [ + ":keypoint", + "//mediapipe/tasks/cc/components/containers:keypoint", + ], +) + +cc_test( + name = "keypoint_converter_test", + srcs = ["keypoint_converter_test.cc"], + deps = [ + ":keypoint", + ":keypoint_converter", + "//mediapipe/framework/port:gtest", + "//mediapipe/tasks/cc/components/containers:keypoint", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "classification_result", hdrs = ["classification_result.h"], @@ -72,6 +126,40 @@ cc_test( ], ) +cc_library( + name = "detection_result", + hdrs = ["detection_result.h"], + deps = [":rect"], +) + +cc_library( + name = "detection_result_converter", + srcs = ["detection_result_converter.cc"], + hdrs = ["detection_result_converter.h"], + deps = [ + ":category", + ":category_converter", + ":detection_result", + ":keypoint", + ":keypoint_converter", + ":rect", + ":rect_converter", + "//mediapipe/tasks/cc/components/containers:detection_result", + ], +) + +cc_test( + name = "detection_result_converter_test", + srcs = ["detection_result_converter_test.cc"], + deps = [ + ":detection_result", + ":detection_result_converter", + "//mediapipe/framework/port:gtest", + "//mediapipe/tasks/cc/components/containers:detection_result", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "embedding_result", hdrs = ["embedding_result.h"], diff --git a/mediapipe/tasks/c/components/containers/detection_result.h b/mediapipe/tasks/c/components/containers/detection_result.h new file mode 100644 index 000000000..0fd7722a1 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/detection_result.h @@ -0,0 +1,63 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_ +#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_ + +#include + +#include "mediapipe/tasks/c/components/containers/rect.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// Detection for a single bounding box. +struct Detection { + // An array of detected categories. + struct Category* categories; + + // The number of elements in the categories array. + uint32_t categories_count; + + // The bounding box location. + struct MPRect bounding_box; + + // Optional list of keypoints associated with the detection. Keypoints + // represent interesting points related to the detection. For example, the + // keypoints represent the eye, ear and mouth from face detection model. Or + // in the template matching detection, e.g. KNIFT, they can represent the + // feature points for template matching. + // `nullptr` if keypoints is not present. + struct NormalizedKeypoint* keypoints; + + // The number of elements in the keypoints array. 0 if keypoints do not exist. + uint32_t keypoints_count; +}; + +// Detection results of a model. +struct DetectionResult { + // An array of Detections. + struct Detection* detections; + + // The number of detections in the detections array. + uint32_t detections_count; +}; + +#ifdef __cplusplus +} // extern C +#endif + +#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_ diff --git a/mediapipe/tasks/c/components/containers/detection_result_converter.cc b/mediapipe/tasks/c/components/containers/detection_result_converter.cc new file mode 100644 index 000000000..dc76579bc --- /dev/null +++ b/mediapipe/tasks/c/components/containers/detection_result_converter.cc @@ -0,0 +1,86 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/components/containers/detection_result_converter.h" + +#include + +#include "mediapipe/tasks/c/components/containers/category.h" +#include "mediapipe/tasks/c/components/containers/category_converter.h" +#include "mediapipe/tasks/c/components/containers/detection_result.h" +#include "mediapipe/tasks/c/components/containers/keypoint.h" +#include "mediapipe/tasks/c/components/containers/keypoint_converter.h" +#include "mediapipe/tasks/c/components/containers/rect_converter.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" + +namespace mediapipe::tasks::c::components::containers { + +void CppConvertToDetection( + const mediapipe::tasks::components::containers::Detection& in, + ::Detection* out) { + out->categories_count = in.categories.size(); + out->categories = new Category[out->categories_count]; + for (size_t i = 0; i < out->categories_count; ++i) { + CppConvertToCategory(in.categories[i], &out->categories[i]); + } + + CppConvertToRect(in.bounding_box, &out->bounding_box); + + if (in.keypoints.has_value()) { + auto& keypoints = in.keypoints.value(); + out->keypoints_count = keypoints.size(); + out->keypoints = new NormalizedKeypoint[out->keypoints_count]; + for (size_t i = 0; i < out->keypoints_count; ++i) { + CppConvertToNormalizedKeypoint(keypoints[i], &out->keypoints[i]); + } + } else { + out->keypoints = nullptr; + out->keypoints_count = 0; + } +} + +void CppConvertToDetectionResult( + const mediapipe::tasks::components::containers::DetectionResult& in, + ::DetectionResult* out) { + out->detections_count = in.detections.size(); + out->detections = new ::Detection[out->detections_count]; + for (size_t i = 0; i < out->detections_count; ++i) { + CppConvertToDetection(in.detections[i], &out->detections[i]); + } +} + +// Functions to free the memory of C structures. +void CppCloseDetection(::Detection* in) { + for (size_t i = 0; i < in->categories_count; ++i) { + CppCloseCategory(&in->categories[i]); + } + delete[] in->categories; + in->categories = nullptr; + for (size_t i = 0; i < in->keypoints_count; ++i) { + CppCloseNormalizedKeypoint(&in->keypoints[i]); + } + delete[] in->keypoints; + in->keypoints = nullptr; +} + +void CppCloseDetectionResult(::DetectionResult* in) { + for (size_t i = 0; i < in->detections_count; ++i) { + CppCloseDetection(&in->detections[i]); + } + delete[] in->detections; + in->detections = nullptr; +} + +} // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/detection_result_converter.h b/mediapipe/tasks/c/components/containers/detection_result_converter.h new file mode 100644 index 000000000..e338e47e9 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/detection_result_converter.h @@ -0,0 +1,38 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_CONVERTER_H_ +#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_CONVERTER_H_ + +#include "mediapipe/tasks/c/components/containers/detection_result.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" + +namespace mediapipe::tasks::c::components::containers { + +void CppConvertToDetection( + const mediapipe::tasks::components::containers::Detection& in, + Detection* out); + +void CppConvertToDetectionResult( + const mediapipe::tasks::components::containers::DetectionResult& in, + DetectionResult* out); + +void CppCloseDetection(Detection* in); + +void CppCloseDetectionResult(DetectionResult* in); + +} // namespace mediapipe::tasks::c::components::containers + +#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_RESULT_CONVERTER_H_ diff --git a/mediapipe/tasks/c/components/containers/detection_result_converter_test.cc b/mediapipe/tasks/c/components/containers/detection_result_converter_test.cc new file mode 100644 index 000000000..16c0a76c2 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/detection_result_converter_test.cc @@ -0,0 +1,74 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/components/containers/detection_result_converter.h" + +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/c/components/containers/detection_result.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" + +namespace mediapipe::tasks::c::components::containers { + +TEST(DetectionResultConverterTest, ConvertsDetectionResultCustomCategory) { + mediapipe::tasks::components::containers::DetectionResult + cpp_detection_result = {/* detections= */ { + {/* categories= */ {{/* index= */ 1, /* score= */ 0.1, + /* category_name= */ "cat", + /* display_name= */ "cat"}}, + /* bounding_box= */ {10, 11, 12, 13}, + {/* keypoints */ {{0.1, 0.1, "foo", 0.5}}}}}}; + + DetectionResult c_detection_result; + CppConvertToDetectionResult(cpp_detection_result, &c_detection_result); + EXPECT_NE(c_detection_result.detections, nullptr); + EXPECT_EQ(c_detection_result.detections_count, 1); + EXPECT_NE(c_detection_result.detections[0].categories, nullptr); + EXPECT_EQ(c_detection_result.detections[0].categories_count, 1); + EXPECT_EQ(c_detection_result.detections[0].bounding_box.left, 10); + EXPECT_EQ(c_detection_result.detections[0].bounding_box.top, 11); + EXPECT_EQ(c_detection_result.detections[0].bounding_box.right, 12); + EXPECT_EQ(c_detection_result.detections[0].bounding_box.bottom, 13); + EXPECT_NE(c_detection_result.detections[0].keypoints, nullptr); + + CppCloseDetectionResult(&c_detection_result); +} + +TEST(DetectionResultConverterTest, ConvertsDetectionResultNoCategory) { + mediapipe::tasks::components::containers::DetectionResult + cpp_detection_result = {/* detections= */ {/* categories= */ {}}}; + + DetectionResult c_detection_result; + CppConvertToDetectionResult(cpp_detection_result, &c_detection_result); + EXPECT_NE(c_detection_result.detections, nullptr); + EXPECT_EQ(c_detection_result.detections_count, 1); + EXPECT_NE(c_detection_result.detections[0].categories, nullptr); + EXPECT_EQ(c_detection_result.detections[0].categories_count, 0); + + CppCloseDetectionResult(&c_detection_result); +} + +TEST(DetectionResultConverterTest, FreesMemory) { + mediapipe::tasks::components::containers::DetectionResult + cpp_detection_result = {/* detections= */ {{/* categories= */ {}}}}; + + DetectionResult c_detection_result; + CppConvertToDetectionResult(cpp_detection_result, &c_detection_result); + EXPECT_NE(c_detection_result.detections, nullptr); + + CppCloseDetectionResult(&c_detection_result); + EXPECT_EQ(c_detection_result.detections, nullptr); +} + +} // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/keypoint.h b/mediapipe/tasks/c/components/containers/keypoint.h new file mode 100644 index 000000000..e70d0325d --- /dev/null +++ b/mediapipe/tasks/c/components/containers/keypoint.h @@ -0,0 +1,46 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_H_ +#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +// A keypoint, defined by the coordinates (x, y), normalized by the image +// dimensions. +struct NormalizedKeypoint { + // x in normalized image coordinates. + float x; + + // y in normalized image coordinates. + float y; + + // Optional label of the keypoint. `nullptr` if the label is not present. + char* label; + + // Optional score of the keypoint. + float score; + + // `True` if the score is valid. + bool has_score; +}; + +#ifdef __cplusplus +} // extern C +#endif + +#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_H_ diff --git a/mediapipe/tasks/c/components/containers/keypoint_converter.cc b/mediapipe/tasks/c/components/containers/keypoint_converter.cc new file mode 100644 index 000000000..d7fb9aa8a --- /dev/null +++ b/mediapipe/tasks/c/components/containers/keypoint_converter.cc @@ -0,0 +1,45 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/components/containers/keypoint_converter.h" + +#include // IWYU pragma: for open source compule + +#include + +#include "mediapipe/tasks/c/components/containers/keypoint.h" +#include "mediapipe/tasks/cc/components/containers/keypoint.h" + +namespace mediapipe::tasks::c::components::containers { + +void CppConvertToNormalizedKeypoint( + const mediapipe::tasks::components::containers::NormalizedKeypoint& in, + NormalizedKeypoint* out) { + out->x = in.x; + out->y = in.y; + + out->label = in.label.has_value() ? strdup(in.label->c_str()) : nullptr; + out->has_score = in.score.has_value(); + out->score = out->has_score ? in.score.value() : 0; +} + +void CppCloseNormalizedKeypoint(NormalizedKeypoint* keypoint) { + if (keypoint && keypoint->label) { + free(keypoint->label); + keypoint->label = nullptr; + } +} + +} // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/keypoint_converter.h b/mediapipe/tasks/c/components/containers/keypoint_converter.h new file mode 100644 index 000000000..a4bd725f2 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/keypoint_converter.h @@ -0,0 +1,32 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_CONVERTER_H_ +#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_CONVERTER_H_ + +#include "mediapipe/tasks/c/components/containers/keypoint.h" +#include "mediapipe/tasks/cc/components/containers/keypoint.h" + +namespace mediapipe::tasks::c::components::containers { + +void CppConvertToNormalizedKeypoint( + const mediapipe::tasks::components::containers::NormalizedKeypoint& in, + NormalizedKeypoint* out); + +void CppCloseNormalizedKeypoint(NormalizedKeypoint* keypoint); + +} // namespace mediapipe::tasks::c::components::containers + +#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_KEYPOINT_CONVERTER_H_ diff --git a/mediapipe/tasks/c/components/containers/keypoint_converter_test.cc b/mediapipe/tasks/c/components/containers/keypoint_converter_test.cc new file mode 100644 index 000000000..7c9ba6fe2 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/keypoint_converter_test.cc @@ -0,0 +1,52 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/components/containers/keypoint_converter.h" + +#include + +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/c/components/containers/keypoint.h" +#include "mediapipe/tasks/cc/components/containers/keypoint.h" + +namespace mediapipe::tasks::c::components::containers { + +constexpr float kPrecision = 1e-6; + +TEST(KeypointConverterTest, ConvertsKeypointCustomValues) { + mediapipe::tasks::components::containers::NormalizedKeypoint cpp_keypoint = { + 0.1, 0.2, "foo", 0.5}; + + NormalizedKeypoint c_keypoint; + CppConvertToNormalizedKeypoint(cpp_keypoint, &c_keypoint); + EXPECT_NEAR(c_keypoint.x, 0.1f, kPrecision); + EXPECT_NEAR(c_keypoint.y, 0.2f, kPrecision); + EXPECT_EQ(std::string(c_keypoint.label), "foo"); + EXPECT_NEAR(c_keypoint.score, 0.5f, kPrecision); + CppCloseNormalizedKeypoint(&c_keypoint); +} + +TEST(KeypointConverterTest, FreesMemory) { + mediapipe::tasks::components::containers::NormalizedKeypoint cpp_keypoint = { + 0.1, 0.2, "foo", 0.5}; + + NormalizedKeypoint c_keypoint; + CppConvertToNormalizedKeypoint(cpp_keypoint, &c_keypoint); + EXPECT_NE(c_keypoint.label, nullptr); + CppCloseNormalizedKeypoint(&c_keypoint); + EXPECT_EQ(c_keypoint.label, nullptr); +} + +} // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/rect.h b/mediapipe/tasks/c/components/containers/rect.h new file mode 100644 index 000000000..c21857d2f --- /dev/null +++ b/mediapipe/tasks/c/components/containers/rect.h @@ -0,0 +1,46 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_H_ +#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +// Defines a rectangle, used e.g. as part of detection results or as input +// region-of-interest. +struct MPRect { + int left; + int top; + int bottom; + int right; +}; + +// The coordinates are normalized wrt the image dimensions, i.e. generally in +// [0,1] but they may exceed these bounds if describing a region overlapping the +// image. The origin is on the top-left corner of the image. +struct MPRectF { + float left; + float top; + float bottom; + float right; +}; + +#ifdef __cplusplus +} // extern C +#endif + +#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_H_ diff --git a/mediapipe/tasks/c/components/containers/rect_converter.cc b/mediapipe/tasks/c/components/containers/rect_converter.cc new file mode 100644 index 000000000..42c574566 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/rect_converter.cc @@ -0,0 +1,41 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/components/containers/rect_converter.h" + +#include "mediapipe/tasks/c/components/containers/rect.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" + +namespace mediapipe::tasks::c::components::containers { + +// Converts a C++ Rect to a C Rect. +void CppConvertToRect(const mediapipe::tasks::components::containers::Rect& in, + struct MPRect* out) { + out->left = in.left; + out->top = in.top; + out->right = in.right; + out->bottom = in.bottom; +} + +// Converts a C++ RectF to a C RectF. +void CppConvertToRectF( + const mediapipe::tasks::components::containers::RectF& in, MPRectF* out) { + out->left = in.left; + out->top = in.top; + out->right = in.right; + out->bottom = in.bottom; +} + +} // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/rect_converter.h b/mediapipe/tasks/c/components/containers/rect_converter.h new file mode 100644 index 000000000..ee446a816 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/rect_converter.h @@ -0,0 +1,32 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_CONVERTER_H_ +#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_CONVERTER_H_ + +#include "mediapipe/tasks/c/components/containers/rect.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" + +namespace mediapipe::tasks::c::components::containers { + +void CppConvertToRect(const mediapipe::tasks::components::containers::Rect& in, + MPRect* out); + +void CppConvertToRectF( + const mediapipe::tasks::components::containers::RectF& in, MPRectF* out); + +} // namespace mediapipe::tasks::c::components::containers + +#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_RECT_CONVERTER_H_ diff --git a/mediapipe/tasks/c/components/containers/rect_converter_test.cc b/mediapipe/tasks/c/components/containers/rect_converter_test.cc new file mode 100644 index 000000000..7aa2daed3 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/rect_converter_test.cc @@ -0,0 +1,47 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/components/containers/rect_converter.h" + +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/c/components/containers/rect.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" + +namespace mediapipe::tasks::c::components::containers { + +TEST(RectConverterTest, ConvertsRectCustomValues) { + mediapipe::tasks::components::containers::Rect cpp_rect = {0, 1, 2, 3}; + + MPRect c_rect; + CppConvertToRect(cpp_rect, &c_rect); + EXPECT_EQ(c_rect.left, 0); + EXPECT_EQ(c_rect.top, 1); + EXPECT_EQ(c_rect.right, 2); + EXPECT_EQ(c_rect.bottom, 3); +} + +TEST(RectFConverterTest, ConvertsRectFCustomValues) { + mediapipe::tasks::components::containers::RectF cpp_rect = {0.1, 0.2, 0.3, + 0.4}; + + MPRectF c_rect; + CppConvertToRectF(cpp_rect, &c_rect); + EXPECT_FLOAT_EQ(c_rect.left, 0.1); + EXPECT_FLOAT_EQ(c_rect.top, 0.2); + EXPECT_FLOAT_EQ(c_rect.right, 0.3); + EXPECT_FLOAT_EQ(c_rect.bottom, 0.4); +} + +} // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/vision/image_classifier/image_classifier.h b/mediapipe/tasks/c/vision/image_classifier/image_classifier.h index 56e63bacc..2a1691d3c 100644 --- a/mediapipe/tasks/c/vision/image_classifier/image_classifier.h +++ b/mediapipe/tasks/c/vision/image_classifier/image_classifier.h @@ -60,12 +60,12 @@ struct ImageClassifierOptions { // // A caller is responsible for closing image classifier result. typedef void (*result_callback_fn)(ImageClassifierResult* result, - const MpImage image, int64_t timestamp_ms, + const MpImage& image, int64_t timestamp_ms, char* error_msg); result_callback_fn result_callback; }; -// Creates an ImageClassifier from provided `options`. +// Creates an ImageClassifier from the provided `options`. // Returns a pointer to the image classifier on success. // If an error occurs, returns `nullptr` and sets the error parameter to an // an error message (if `error_msg` is not `nullptr`). You must free the memory diff --git a/mediapipe/tasks/c/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/c/vision/image_classifier/image_classifier_test.cc index 22a716dfd..2b0114dc6 100644 --- a/mediapipe/tasks/c/vision/image_classifier/image_classifier_test.cc +++ b/mediapipe/tasks/c/vision/image_classifier/image_classifier_test.cc @@ -142,7 +142,7 @@ TEST(ImageClassifierTest, VideoModeTest) { // timestamp is greater than the previous one. struct LiveStreamModeCallback { static int64_t last_timestamp; - static void Fn(ImageClassifierResult* classifier_result, const MpImage image, + static void Fn(ImageClassifierResult* classifier_result, const MpImage& image, int64_t timestamp, char* error_msg) { ASSERT_NE(classifier_result, nullptr); ASSERT_EQ(error_msg, nullptr); diff --git a/mediapipe/tasks/c/vision/image_embedder/image_embedder.h b/mediapipe/tasks/c/vision/image_embedder/image_embedder.h index 68a72dc80..809c7f2f8 100644 --- a/mediapipe/tasks/c/vision/image_embedder/image_embedder.h +++ b/mediapipe/tasks/c/vision/image_embedder/image_embedder.h @@ -62,12 +62,12 @@ struct ImageEmbedderOptions { // // A caller is responsible for closing image embedder result. typedef void (*result_callback_fn)(ImageEmbedderResult* result, - const MpImage image, int64_t timestamp_ms, + const MpImage& image, int64_t timestamp_ms, char* error_msg); result_callback_fn result_callback; }; -// Creates an ImageEmbedder from provided `options`. +// Creates an ImageEmbedder from the provided `options`. // Returns a pointer to the image embedder on success. // If an error occurs, returns `nullptr` and sets the error parameter to an // an error message (if `error_msg` is not `nullptr`). You must free the memory diff --git a/mediapipe/tasks/c/vision/image_embedder/image_embedder_test.cc b/mediapipe/tasks/c/vision/image_embedder/image_embedder_test.cc index 17bc8580c..5daeac949 100644 --- a/mediapipe/tasks/c/vision/image_embedder/image_embedder_test.cc +++ b/mediapipe/tasks/c/vision/image_embedder/image_embedder_test.cc @@ -199,7 +199,7 @@ TEST(ImageEmbedderTest, VideoModeTest) { // timestamp is greater than the previous one. struct LiveStreamModeCallback { static int64_t last_timestamp; - static void Fn(ImageEmbedderResult* embedder_result, const MpImage image, + static void Fn(ImageEmbedderResult* embedder_result, const MpImage& image, int64_t timestamp, char* error_msg) { ASSERT_NE(embedder_result, nullptr); ASSERT_EQ(error_msg, nullptr); diff --git a/mediapipe/tasks/c/vision/object_detector/BUILD b/mediapipe/tasks/c/vision/object_detector/BUILD new file mode 100644 index 000000000..28bb6fa91 --- /dev/null +++ b/mediapipe/tasks/c/vision/object_detector/BUILD @@ -0,0 +1,65 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "object_detector_lib", + srcs = ["object_detector.cc"], + hdrs = ["object_detector.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/tasks/c/components/containers:detection_result", + "//mediapipe/tasks/c/components/containers:detection_result_converter", + "//mediapipe/tasks/c/core:base_options", + "//mediapipe/tasks/c/core:base_options_converter", + "//mediapipe/tasks/c/vision/core:common", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/object_detector", + "//mediapipe/tasks/cc/vision/utils:image_utils", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + +cc_test( + name = "object_detector_test", + srcs = ["object_detector_test.cc"], + data = [ + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/tasks/testdata/vision:test_images", + "//mediapipe/tasks/testdata/vision:test_models", + ], + linkstatic = 1, + deps = [ + ":object_detector_lib", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/port:gtest", + "//mediapipe/tasks/c/components/containers:category", + "//mediapipe/tasks/c/vision/core:common", + "//mediapipe/tasks/cc/vision/utils:image_utils", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/mediapipe/tasks/c/vision/object_detector/object_detector.cc b/mediapipe/tasks/c/vision/object_detector/object_detector.cc new file mode 100644 index 000000000..70f35ec95 --- /dev/null +++ b/mediapipe/tasks/c/vision/object_detector/object_detector.cc @@ -0,0 +1,290 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/vision/object_detector/object_detector.h" + +#include +#include +#include +#include +#include + +#include "absl/log/absl_log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/tasks/c/components/containers/detection_result_converter.h" +#include "mediapipe/tasks/c/core/base_options_converter.h" +#include "mediapipe/tasks/c/vision/core/common.h" +#include "mediapipe/tasks/cc/vision/core/running_mode.h" +#include "mediapipe/tasks/cc/vision/object_detector/object_detector.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" + +namespace mediapipe::tasks::c::vision::object_detector { + +namespace { + +using ::mediapipe::tasks::c::components::containers::CppCloseDetectionResult; +using ::mediapipe::tasks::c::components::containers:: + CppConvertToDetectionResult; +using ::mediapipe::tasks::c::core::CppConvertToBaseOptions; +using ::mediapipe::tasks::vision::CreateImageFromBuffer; +using ::mediapipe::tasks::vision::ObjectDetector; +using ::mediapipe::tasks::vision::core::RunningMode; +typedef ::mediapipe::tasks::vision::ObjectDetectorResult + CppObjectDetectorResult; + +int CppProcessError(absl::Status status, char** error_msg) { + if (error_msg) { + *error_msg = strdup(status.ToString().c_str()); + } + return status.raw_code(); +} + +} // namespace + +void CppConvertToDetectorOptions( + const ObjectDetectorOptions& in, + mediapipe::tasks::vision::ObjectDetectorOptions* out) { + out->display_names_locale = + in.display_names_locale ? std::string(in.display_names_locale) : "en"; + out->max_results = in.max_results; + out->score_threshold = in.score_threshold; + out->category_allowlist = + std::vector(in.category_allowlist_count); + for (uint32_t i = 0; i < in.category_allowlist_count; ++i) { + out->category_allowlist[i] = in.category_allowlist[i]; + } + out->category_denylist = std::vector(in.category_denylist_count); + for (uint32_t i = 0; i < in.category_denylist_count; ++i) { + out->category_denylist[i] = in.category_denylist[i]; + } +} + +ObjectDetector* CppObjectDetectorCreate(const ObjectDetectorOptions& options, + char** error_msg) { + auto cpp_options = + std::make_unique<::mediapipe::tasks::vision::ObjectDetectorOptions>(); + + CppConvertToBaseOptions(options.base_options, &cpp_options->base_options); + CppConvertToDetectorOptions(options, cpp_options.get()); + cpp_options->running_mode = static_cast(options.running_mode); + + // Enable callback for processing live stream data when the running mode is + // set to RunningMode::LIVE_STREAM. + if (cpp_options->running_mode == RunningMode::LIVE_STREAM) { + if (options.result_callback == nullptr) { + const absl::Status status = absl::InvalidArgumentError( + "Provided null pointer to callback function."); + ABSL_LOG(ERROR) << "Failed to create ObjectDetector: " << status; + CppProcessError(status, error_msg); + return nullptr; + } + + ObjectDetectorOptions::result_callback_fn result_callback = + options.result_callback; + cpp_options->result_callback = + [result_callback](absl::StatusOr cpp_result, + const Image& image, int64_t timestamp) { + char* error_msg = nullptr; + + if (!cpp_result.ok()) { + ABSL_LOG(ERROR) << "Detection failed: " << cpp_result.status(); + CppProcessError(cpp_result.status(), &error_msg); + result_callback(nullptr, MpImage(), timestamp, error_msg); + free(error_msg); + return; + } + + // Result is valid for the lifetime of the callback function. + ObjectDetectorResult result; + CppConvertToDetectionResult(*cpp_result, &result); + + const auto& image_frame = image.GetImageFrameSharedPtr(); + const MpImage mp_image = { + .type = MpImage::IMAGE_FRAME, + .image_frame = { + .format = static_cast<::ImageFormat>(image_frame->Format()), + .image_buffer = image_frame->PixelData(), + .width = image_frame->Width(), + .height = image_frame->Height()}}; + + result_callback(&result, mp_image, timestamp, + /* error_msg= */ nullptr); + + CppCloseDetectionResult(&result); + }; + } + + auto detector = ObjectDetector::Create(std::move(cpp_options)); + if (!detector.ok()) { + ABSL_LOG(ERROR) << "Failed to create ObjectDetector: " << detector.status(); + CppProcessError(detector.status(), error_msg); + return nullptr; + } + return detector->release(); +} + +int CppObjectDetectorDetect(void* detector, const MpImage* image, + ObjectDetectorResult* result, char** error_msg) { + if (image->type == MpImage::GPU_BUFFER) { + const absl::Status status = + absl::InvalidArgumentError("GPU Buffer not supported yet."); + + ABSL_LOG(ERROR) << "Detection failed: " << status.message(); + return CppProcessError(status, error_msg); + } + + const auto img = CreateImageFromBuffer( + static_cast(image->image_frame.format), + image->image_frame.image_buffer, image->image_frame.width, + image->image_frame.height); + + if (!img.ok()) { + ABSL_LOG(ERROR) << "Failed to create Image: " << img.status(); + return CppProcessError(img.status(), error_msg); + } + + auto cpp_detector = static_cast(detector); + auto cpp_result = cpp_detector->Detect(*img); + if (!cpp_result.ok()) { + ABSL_LOG(ERROR) << "Detection failed: " << cpp_result.status(); + return CppProcessError(cpp_result.status(), error_msg); + } + CppConvertToDetectionResult(*cpp_result, result); + return 0; +} + +int CppObjectDetectorDetectForVideo(void* detector, const MpImage* image, + int64_t timestamp_ms, + ObjectDetectorResult* result, + char** error_msg) { + if (image->type == MpImage::GPU_BUFFER) { + absl::Status status = + absl::InvalidArgumentError("GPU Buffer not supported yet"); + + ABSL_LOG(ERROR) << "Detection failed: " << status.message(); + return CppProcessError(status, error_msg); + } + + const auto img = CreateImageFromBuffer( + static_cast(image->image_frame.format), + image->image_frame.image_buffer, image->image_frame.width, + image->image_frame.height); + + if (!img.ok()) { + ABSL_LOG(ERROR) << "Failed to create Image: " << img.status(); + return CppProcessError(img.status(), error_msg); + } + + auto cpp_detector = static_cast(detector); + auto cpp_result = cpp_detector->DetectForVideo(*img, timestamp_ms); + if (!cpp_result.ok()) { + ABSL_LOG(ERROR) << "Detection failed: " << cpp_result.status(); + return CppProcessError(cpp_result.status(), error_msg); + } + CppConvertToDetectionResult(*cpp_result, result); + return 0; +} + +int CppObjectDetectorDetectAsync(void* detector, const MpImage* image, + int64_t timestamp_ms, char** error_msg) { + if (image->type == MpImage::GPU_BUFFER) { + absl::Status status = + absl::InvalidArgumentError("GPU Buffer not supported yet"); + + ABSL_LOG(ERROR) << "Detection failed: " << status.message(); + return CppProcessError(status, error_msg); + } + + const auto img = CreateImageFromBuffer( + static_cast(image->image_frame.format), + image->image_frame.image_buffer, image->image_frame.width, + image->image_frame.height); + + if (!img.ok()) { + ABSL_LOG(ERROR) << "Failed to create Image: " << img.status(); + return CppProcessError(img.status(), error_msg); + } + + auto cpp_detector = static_cast(detector); + auto cpp_result = cpp_detector->DetectAsync(*img, timestamp_ms); + if (!cpp_result.ok()) { + ABSL_LOG(ERROR) << "Data preparation for the object detection failed: " + << cpp_result; + return CppProcessError(cpp_result, error_msg); + } + return 0; +} + +void CppObjectDetectorCloseResult(ObjectDetectorResult* result) { + CppCloseDetectionResult(result); +} + +int CppObjectDetectorClose(void* detector, char** error_msg) { + auto cpp_detector = static_cast(detector); + auto result = cpp_detector->Close(); + if (!result.ok()) { + ABSL_LOG(ERROR) << "Failed to close ObjectDetector: " << result; + return CppProcessError(result, error_msg); + } + delete cpp_detector; + return 0; +} + +} // namespace mediapipe::tasks::c::vision::object_detector + +extern "C" { + +void* object_detector_create(struct ObjectDetectorOptions* options, + char** error_msg) { + return mediapipe::tasks::c::vision::object_detector::CppObjectDetectorCreate( + *options, error_msg); +} + +int object_detector_detect_image(void* detector, const MpImage* image, + ObjectDetectorResult* result, + char** error_msg) { + return mediapipe::tasks::c::vision::object_detector::CppObjectDetectorDetect( + detector, image, result, error_msg); +} + +int object_detector_detect_for_video(void* detector, const MpImage* image, + int64_t timestamp_ms, + ObjectDetectorResult* result, + char** error_msg) { + return mediapipe::tasks::c::vision::object_detector:: + CppObjectDetectorDetectForVideo(detector, image, timestamp_ms, result, + error_msg); +} + +int object_detector_detect_async(void* detector, const MpImage* image, + int64_t timestamp_ms, char** error_msg) { + return mediapipe::tasks::c::vision::object_detector:: + CppObjectDetectorDetectAsync(detector, image, timestamp_ms, error_msg); +} + +void object_detector_close_result(ObjectDetectorResult* result) { + mediapipe::tasks::c::vision::object_detector::CppObjectDetectorCloseResult( + result); +} + +int object_detector_close(void* detector, char** error_ms) { + return mediapipe::tasks::c::vision::object_detector::CppObjectDetectorClose( + detector, error_ms); +} + +} // extern "C" diff --git a/mediapipe/tasks/c/vision/object_detector/object_detector.h b/mediapipe/tasks/c/vision/object_detector/object_detector.h new file mode 100644 index 000000000..e14523a49 --- /dev/null +++ b/mediapipe/tasks/c/vision/object_detector/object_detector.h @@ -0,0 +1,157 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_VISION_OBJECT_DETECTOR_OBJECT_DETECTOR_H_ +#define MEDIAPIPE_TASKS_C_VISION_OBJECT_DETECTOR_OBJECT_DETECTOR_H_ + +#include "mediapipe/tasks/c/components/containers/detection_result.h" +#include "mediapipe/tasks/c/core/base_options.h" +#include "mediapipe/tasks/c/vision/core/common.h" + +#ifndef MP_EXPORT +#define MP_EXPORT __attribute__((visibility("default"))) +#endif // MP_EXPORT + +#ifdef __cplusplus +extern "C" { +#endif + +typedef DetectionResult ObjectDetectorResult; + +// The options for configuring a MediaPipe object detector task. +struct ObjectDetectorOptions { + // Base options for configuring MediaPipe Tasks, such as specifying the model + // file with metadata, accelerator options, op resolver, etc. + struct BaseOptions base_options; + + // The running mode of the task. Default to the image mode. + // Object detector has three running modes: + // 1) The image mode for detecting objects on single image inputs. + // 2) The video mode for detecting objects on the decoded frames of a video. + // 3) The live stream mode for detecting objects on the live stream of input + // data, such as from camera. In this mode, the "result_callback" below must + // be specified to receive the detection results asynchronously. + RunningMode running_mode; + + // The locale to use for display names specified through the TFLite Model + // Metadata, if any. Defaults to English. + const char* display_names_locale; + + // The maximum number of top-scored detection results to return. If < 0, + // all available results will be returned. If 0, an invalid argument error is + // returned. + int max_results; + + // Score threshold to override the one provided in the model metadata (if + // any). Results below this value are rejected. + float score_threshold; + + // The allowlist of category names. If non-empty, detection results whose + // category name is not in this set will be filtered out. Duplicate or unknown + // category names are ignored. Mutually exclusive with category_denylist. + const char** category_allowlist; + // The number of elements in the category allowlist. + uint32_t category_allowlist_count; + + // The denylist of category names. If non-empty, detection results whose + // category name is in this set will be filtered out. Duplicate or unknown + // category names are ignored. Mutually exclusive with category_allowlist. + const char** category_denylist; + // The number of elements in the category denylist. + uint32_t category_denylist_count; + + // The user-defined result callback for processing live stream data. + // The result callback should only be specified when the running mode is set + // to RunningMode::LIVE_STREAM. Arguments of the callback function include: + // the pointer to detection result, the image that result was obtained + // on, the timestamp relevant to detection results and pointer to error + // message in case of any failure. The validity of the passed arguments is + // true for the lifetime of the callback function. + // + // A caller is responsible for closing object detector result. + typedef void (*result_callback_fn)(ObjectDetectorResult* result, + const MpImage& image, int64_t timestamp_ms, + char* error_msg); + result_callback_fn result_callback; +}; + +// Creates an ObjectDetector from the provided `options`. +// Returns a pointer to the image detector on success. +// If an error occurs, returns `nullptr` and sets the error parameter to an +// an error message (if `error_msg` is not `nullptr`). You must free the memory +// allocated for the error message. +MP_EXPORT void* object_detector_create(struct ObjectDetectorOptions* options, + char** error_msg); + +// Performs image detection on the input `image`. Returns `0` on success. +// If an error occurs, returns an error code and sets the error parameter to an +// an error message (if `error_msg` is not `nullptr`). You must free the memory +// allocated for the error message. +MP_EXPORT int object_detector_detect_image(void* detector, const MpImage* image, + ObjectDetectorResult* result, + char** error_msg); + +// Performs image detection on the provided video frame. +// Only use this method when the ObjectDetector is created with the video +// running mode. +// The image can be of any size with format RGB or RGBA. It's required to +// provide the video frame's timestamp (in milliseconds). The input timestamps +// must be monotonically increasing. +// If an error occurs, returns an error code and sets the error parameter to an +// an error message (if `error_msg` is not `nullptr`). You must free the memory +// allocated for the error message. +MP_EXPORT int object_detector_detect_for_video(void* detector, + const MpImage* image, + int64_t timestamp_ms, + ObjectDetectorResult* result, + char** error_msg); + +// Sends live image data to image detection, and the results will be +// available via the `result_callback` provided in the ObjectDetectorOptions. +// Only use this method when the ObjectDetector is created with the live +// stream running mode. +// The image can be of any size with format RGB or RGBA. It's required to +// provide a timestamp (in milliseconds) to indicate when the input image is +// sent to the object detector. The input timestamps must be monotonically +// increasing. +// The `result_callback` provides: +// - The detection results as an ObjectDetectorResult object. +// - The const reference to the corresponding input image that the image +// detector runs on. Note that the const reference to the image will no +// longer be valid when the callback returns. To access the image data +// outside of the callback, callers need to make a copy of the image. +// - The input timestamp in milliseconds. +// If an error occurs, returns an error code and sets the error parameter to an +// an error message (if `error_msg` is not `nullptr`). You must free the memory +// allocated for the error message. +MP_EXPORT int object_detector_detect_async(void* detector, const MpImage* image, + int64_t timestamp_ms, + char** error_msg); + +// Frees the memory allocated inside a ObjectDetectorResult result. +// Does not free the result pointer itself. +MP_EXPORT void object_detector_close_result(ObjectDetectorResult* result); + +// Frees object detector. +// If an error occurs, returns an error code and sets the error parameter to an +// an error message (if `error_msg` is not `nullptr`). You must free the memory +// allocated for the error message. +MP_EXPORT int object_detector_close(void* detector, char** error_msg); + +#ifdef __cplusplus +} // extern C +#endif + +#endif // MEDIAPIPE_TASKS_C_VISION_OBJECT_DETECTOR_OBJECT_DETECTOR_H_ diff --git a/mediapipe/tasks/c/vision/object_detector/object_detector_test.cc b/mediapipe/tasks/c/vision/object_detector/object_detector_test.cc new file mode 100644 index 000000000..8e53fa5c9 --- /dev/null +++ b/mediapipe/tasks/c/vision/object_detector/object_detector_test.cc @@ -0,0 +1,253 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/vision/object_detector/object_detector.h" + +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/c/components/containers/category.h" +#include "mediapipe/tasks/c/vision/core/common.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" + +namespace { + +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::vision::DecodeImageFromFile; +using testing::HasSubstr; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kImageFile[] = "cats_and_dogs.jpg"; +constexpr char kModelName[] = + "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite"; +constexpr float kPrecision = 1e-4; +constexpr int kIterations = 100; + +std::string GetFullPath(absl::string_view file_name) { + return JoinPath("./", kTestDataDirectory, file_name); +} + +TEST(ObjectDetectorTest, ImageModeTest) { + const auto image = DecodeImageFromFile(GetFullPath(kImageFile)); + ASSERT_TRUE(image.ok()); + + const std::string model_path = GetFullPath(kModelName); + ObjectDetectorOptions options = { + /* base_options= */ {/* model_asset_buffer= */ nullptr, + /* model_asset_buffer_count= */ 0, + /* model_asset_path= */ model_path.c_str()}, + /* running_mode= */ RunningMode::IMAGE, + /* display_names_locale= */ nullptr, + /* max_results= */ -1, + /* score_threshold= */ 0.0, + /* category_allowlist= */ nullptr, + /* category_allowlist_count= */ 0, + /* category_denylist= */ nullptr, + /* category_denylist_count= */ 0, + }; + + void* detector = object_detector_create(&options, /* error_msg */ nullptr); + EXPECT_NE(detector, nullptr); + + const auto& image_frame = image->GetImageFrameSharedPtr(); + const MpImage mp_image = { + .type = MpImage::IMAGE_FRAME, + .image_frame = {.format = static_cast(image_frame->Format()), + .image_buffer = image_frame->PixelData(), + .width = image_frame->Width(), + .height = image_frame->Height()}}; + + ObjectDetectorResult result; + object_detector_detect_image(detector, &mp_image, &result, + /* error_msg */ nullptr); + EXPECT_EQ(result.detections_count, 10); + EXPECT_EQ(result.detections[0].categories_count, 1); + EXPECT_EQ(std::string{result.detections[0].categories[0].category_name}, + "cat"); + EXPECT_NEAR(result.detections[0].categories[0].score, 0.6992f, kPrecision); + object_detector_close_result(&result); + object_detector_close(detector, /* error_msg */ nullptr); +} + +TEST(ObjectDetectorTest, VideoModeTest) { + const auto image = DecodeImageFromFile(GetFullPath(kImageFile)); + ASSERT_TRUE(image.ok()); + + const std::string model_path = GetFullPath(kModelName); + ObjectDetectorOptions options = { + /* base_options= */ {/* model_asset_buffer= */ nullptr, + /* model_asset_buffer_count= */ 0, + /* model_asset_path= */ model_path.c_str()}, + /* running_mode= */ RunningMode::VIDEO, + /* display_names_locale= */ nullptr, + /* max_results= */ 3, + /* score_threshold= */ 0.0, + /* category_allowlist= */ nullptr, + /* category_allowlist_count= */ 0, + /* category_denylist= */ nullptr, + /* category_denylist_count= */ 0, + }; + + void* detector = object_detector_create(&options, /* error_msg */ nullptr); + EXPECT_NE(detector, nullptr); + + const auto& image_frame = image->GetImageFrameSharedPtr(); + const MpImage mp_image = { + .type = MpImage::IMAGE_FRAME, + .image_frame = {.format = static_cast(image_frame->Format()), + .image_buffer = image_frame->PixelData(), + .width = image_frame->Width(), + .height = image_frame->Height()}}; + + for (int i = 0; i < kIterations; ++i) { + ObjectDetectorResult result; + object_detector_detect_for_video(detector, &mp_image, i, &result, + /* error_msg */ nullptr); + EXPECT_EQ(result.detections_count, 3); + EXPECT_EQ(result.detections[0].categories_count, 1); + EXPECT_EQ(std::string{result.detections[0].categories[0].category_name}, + "cat"); + EXPECT_NEAR(result.detections[0].categories[0].score, 0.6992f, kPrecision); + object_detector_close_result(&result); + } + object_detector_close(detector, /* error_msg */ nullptr); +} + +// A structure to support LiveStreamModeTest below. This structure holds a +// static method `Fn` for a callback function of C API. A `static` qualifier +// allows to take an address of the method to follow API style. Another static +// struct member is `last_timestamp` that is used to verify that current +// timestamp is greater than the previous one. +struct LiveStreamModeCallback { + static int64_t last_timestamp; + static void Fn(ObjectDetectorResult* detector_result, const MpImage& image, + int64_t timestamp, char* error_msg) { + ASSERT_NE(detector_result, nullptr); + ASSERT_EQ(error_msg, nullptr); + EXPECT_EQ(detector_result->detections_count, 3); + EXPECT_EQ(detector_result->detections[0].categories_count, 1); + EXPECT_EQ( + std::string{detector_result->detections[0].categories[0].category_name}, + "cat"); + EXPECT_NEAR(detector_result->detections[0].categories[0].score, 0.6992f, + kPrecision); + EXPECT_GT(image.image_frame.width, 0); + EXPECT_GT(image.image_frame.height, 0); + EXPECT_GT(timestamp, last_timestamp); + last_timestamp++; + } +}; +int64_t LiveStreamModeCallback::last_timestamp = -1; + +TEST(ObjectDetectorTest, LiveStreamModeTest) { + const auto image = DecodeImageFromFile(GetFullPath(kImageFile)); + ASSERT_TRUE(image.ok()); + + const std::string model_path = GetFullPath(kModelName); + + ObjectDetectorOptions options = { + /* base_options= */ {/* model_asset_buffer= */ nullptr, + /* model_asset_buffer_count= */ 0, + /* model_asset_path= */ model_path.c_str()}, + /* running_mode= */ RunningMode::LIVE_STREAM, + /* display_names_locale= */ nullptr, + /* max_results= */ 3, + /* score_threshold= */ 0.0, + /* category_allowlist= */ nullptr, + /* category_allowlist_count= */ 0, + /* category_denylist= */ nullptr, + /* category_denylist_count= */ 0, + /* result_callback= */ LiveStreamModeCallback::Fn, + }; + + void* detector = object_detector_create(&options, /* error_msg */ + nullptr); + EXPECT_NE(detector, nullptr); + + const auto& image_frame = image->GetImageFrameSharedPtr(); + const MpImage mp_image = { + .type = MpImage::IMAGE_FRAME, + .image_frame = {.format = static_cast(image_frame->Format()), + .image_buffer = image_frame->PixelData(), + .width = image_frame->Width(), + .height = image_frame->Height()}}; + + for (int i = 0; i < kIterations; ++i) { + EXPECT_GE(object_detector_detect_async(detector, &mp_image, i, + /* error_msg */ nullptr), + 0); + } + object_detector_close(detector, /* error_msg */ nullptr); + + // Due to the flow limiter, the total of outputs might be smaller than the + // number of iterations. + EXPECT_LE(LiveStreamModeCallback::last_timestamp, kIterations); + EXPECT_GT(LiveStreamModeCallback::last_timestamp, 0); +} + +TEST(ObjectDetectorTest, InvalidArgumentHandling) { + // It is an error to set neither the asset buffer nor the path. + ObjectDetectorOptions options = { + /* base_options= */ {/* model_asset_buffer= */ nullptr, + /* model_asset_buffer_count= */ 0, + /* model_asset_path= */ nullptr}, + }; + + char* error_msg; + void* detector = object_detector_create(&options, &error_msg); + EXPECT_EQ(detector, nullptr); + + EXPECT_THAT(error_msg, HasSubstr("ExternalFile must specify")); + + free(error_msg); +} + +TEST(ObjectDetectorTest, FailedDetectionHandling) { + const std::string model_path = GetFullPath(kModelName); + ObjectDetectorOptions options = { + /* base_options= */ {/* model_asset_buffer= */ nullptr, + /* model_asset_buffer_count= */ 0, + /* model_asset_path= */ model_path.c_str()}, + /* running_mode= */ RunningMode::IMAGE, + /* display_names_locale= */ nullptr, + /* max_results= */ -1, + /* score_threshold= */ 0.0, + /* category_allowlist= */ nullptr, + /* category_allowlist_count= */ 0, + /* category_denylist= */ nullptr, + /* category_denylist_count= */ 0, + }; + + void* detector = object_detector_create(&options, /* error_msg */ + nullptr); + EXPECT_NE(detector, nullptr); + + const MpImage mp_image = {.type = MpImage::GPU_BUFFER, .gpu_buffer = {}}; + ObjectDetectorResult result; + char* error_msg; + object_detector_detect_image(detector, &mp_image, &result, &error_msg); + EXPECT_THAT(error_msg, HasSubstr("GPU Buffer not supported yet")); + free(error_msg); + object_detector_close(detector, /* error_msg */ nullptr); +} + +} // namespace