From b879e3a2041b71c8a5264a650a8f3fed6fb1437f Mon Sep 17 00:00:00 2001 From: Kinar Date: Thu, 16 Nov 2023 10:05:34 -0800 Subject: [PATCH] Updated components and their tests in the C Tasks API --- .../detection_result_converter_test.cc | 50 +++++++++++++++++++ .../containers/keypoint_converter.cc | 3 +- .../containers/keypoint_converter_test.cc | 26 ++++++++-- .../c/components/containers/rect_converter.cc | 2 - .../containers/rect_converter_test.cc | 2 +- 5 files changed, 74 insertions(+), 9 deletions(-) diff --git a/mediapipe/tasks/c/components/containers/detection_result_converter_test.cc b/mediapipe/tasks/c/components/containers/detection_result_converter_test.cc index 884481f29..2fd85bf31 100644 --- a/mediapipe/tasks/c/components/containers/detection_result_converter_test.cc +++ b/mediapipe/tasks/c/components/containers/detection_result_converter_test.cc @@ -25,4 +25,54 @@ limitations under the License. namespace mediapipe::tasks::c::components::containers { +TEST(DetectionResultConverterTest, ConvertsDetectionResultCustomCategory) { + mediapipe::tasks::components::containers::DetectionResult + cpp_detection_result = {/* detections= */ { + {/* categories= */ {{/* index= */ 1, /* score= */ 0.1, + /* category_name= */ "cat", + /* display_name= */ "cat"}}, + /* bounding_box= */ {10, 10, 10, 10}, + {/* keypoints */ {{0.1, 0.1, "foo", 0.5}}}}}}; + + DetectionResult c_detection_result; + CppConvertToDetectionResult(cpp_detection_result, &c_detection_result); + EXPECT_NE(c_detection_result.detections, nullptr); + EXPECT_EQ(c_detection_result.detections_count, 1); + EXPECT_NE(c_detection_result.detections[0].categories, nullptr); + EXPECT_EQ(c_detection_result.detections[0].categories_count, 1); + EXPECT_EQ(c_detection_result.detections[0].bounding_box.left, 10); + EXPECT_EQ(c_detection_result.detections[0].bounding_box.top, 10); + EXPECT_EQ(c_detection_result.detections[0].bounding_box.right, 10); + EXPECT_EQ(c_detection_result.detections[0].bounding_box.bottom, 10); + EXPECT_NE(c_detection_result.detections[0].keypoints, nullptr); + + CppCloseDetectionResult(&c_detection_result); +} + +TEST(DetectionResultConverterTest, ConvertsDetectionResultNoCategory) { + mediapipe::tasks::components::containers::DetectionResult + cpp_detection_result = {/* detections= */ {/* categories= */ {}}}; + + DetectionResult c_detection_result; + CppConvertToDetectionResult(cpp_detection_result, &c_detection_result); + EXPECT_NE(c_detection_result.detections, nullptr); + EXPECT_EQ(c_detection_result.detections_count, 1); + EXPECT_NE(c_detection_result.detections[0].categories, nullptr); + EXPECT_EQ(c_detection_result.detections[0].categories_count, 0); + + CppCloseDetectionResult(&c_detection_result); +} + +TEST(DetectionResultConverterTest, FreesMemory) { + mediapipe::tasks::components::containers::DetectionResult + cpp_detection_result = {/* detections= */ {{/* categories= */ {}}}}; + + DetectionResult c_detection_result; + CppConvertToDetectionResult(cpp_detection_result, &c_detection_result); + EXPECT_NE(c_detection_result.detections, nullptr); + + CppCloseDetectionResult(&c_detection_result); + EXPECT_EQ(c_detection_result.detections, nullptr); +} + } // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/keypoint_converter.cc b/mediapipe/tasks/c/components/containers/keypoint_converter.cc index 53e8a5da1..2d64e8063 100644 --- a/mediapipe/tasks/c/components/containers/keypoint_converter.cc +++ b/mediapipe/tasks/c/components/containers/keypoint_converter.cc @@ -15,7 +15,6 @@ limitations under the License. #include "mediapipe/tasks/c/components/containers/keypoint_converter.h" -#include #include #include @@ -38,7 +37,7 @@ void CppConvertToNormalizedKeypoint( void CppCloseNormalizedKeypoint(NormalizedKeypoint* keypoint) { if (keypoint && keypoint->label) { free(keypoint->label); - keypoint->label = NULL; + keypoint->label = nullptr; } } diff --git a/mediapipe/tasks/c/components/containers/keypoint_converter_test.cc b/mediapipe/tasks/c/components/containers/keypoint_converter_test.cc index ca09154c3..38bf1e3c6 100644 --- a/mediapipe/tasks/c/components/containers/keypoint_converter_test.cc +++ b/mediapipe/tasks/c/components/containers/keypoint_converter_test.cc @@ -25,11 +25,29 @@ limitations under the License. namespace mediapipe::tasks::c::components::containers { -TEST(RectConverterTest, ConvertsRectCustomValues) { - mediapipe::tasks::components::containers::Rect cpp_rect = {0, 0, 0, 0}; +constexpr float kPrecision = 1e-6; - Rect c_rect; - CppConvertToRect(cpp_rect, &c_rect); +TEST(KeypointConverterTest, ConvertsKeypointCustomValues) { + mediapipe::tasks::components::containers::NormalizedKeypoint cpp_keypoint = { + 0.1, 0.1, "foo", 0.5}; + + NormalizedKeypoint c_keypoint; + CppConvertToNormalizedKeypoint(cpp_keypoint, &c_keypoint); + EXPECT_NEAR(c_keypoint.x, 0.1f, kPrecision); + EXPECT_NEAR(c_keypoint.x, 0.1f, kPrecision); + EXPECT_EQ(std::string(c_keypoint.label), "foo"); + EXPECT_NEAR(c_keypoint.score, 0.5f, kPrecision); +} + +TEST(KeypointConverterTest, FreesMemory) { + mediapipe::tasks::components::containers::NormalizedKeypoint cpp_keypoint = { + 0.1, 0.1, "foo", 0.5}; + + NormalizedKeypoint c_keypoint; + CppConvertToNormalizedKeypoint(cpp_keypoint, &c_keypoint); + EXPECT_NE(c_keypoint.label, nullptr); + CppCloseNormalizedKeypoint(&c_keypoint); + EXPECT_EQ(c_keypoint.label, nullptr); } } // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/rect_converter.cc b/mediapipe/tasks/c/components/containers/rect_converter.cc index ff700acee..9f30bec4e 100644 --- a/mediapipe/tasks/c/components/containers/rect_converter.cc +++ b/mediapipe/tasks/c/components/containers/rect_converter.cc @@ -15,8 +15,6 @@ limitations under the License. #include "mediapipe/tasks/c/components/containers/rect_converter.h" -#include - #include "mediapipe/tasks/c/components/containers/rect.h" #include "mediapipe/tasks/cc/components/containers/rect.h" diff --git a/mediapipe/tasks/c/components/containers/rect_converter_test.cc b/mediapipe/tasks/c/components/containers/rect_converter_test.cc index 3e8848094..eb2107240 100644 --- a/mediapipe/tasks/c/components/containers/rect_converter_test.cc +++ b/mediapipe/tasks/c/components/containers/rect_converter_test.cc @@ -41,7 +41,7 @@ TEST(RectFConverterTest, ConvertsRectFCustomValues) { 0.1}; RectF c_rect; - CppConvertToRect(cpp_rect, &c_rect); + CppConvertToRectF(cpp_rect, &c_rect); EXPECT_FLOAT_EQ(c_rect.left, 0.1); EXPECT_FLOAT_EQ(c_rect.right, 0.1); EXPECT_FLOAT_EQ(c_rect.top, 0.1);