From 6915a79e288afec408edc4bc3f851613200efb8f Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 29 Sep 2023 10:24:35 -0700 Subject: [PATCH] Add tests for C API containers PiperOrigin-RevId: 569526282 --- mediapipe/tasks/c/components/containers/BUILD | 24 +++++ .../tasks/c/components/containers/category.h | 4 +- .../containers/category_converter_test.cc | 63 +++++++++++ .../containers/classification_result.h | 2 +- .../classification_result_converter.cc | 13 ++- .../classification_result_converter_test.cc | 102 ++++++++++++++++++ 6 files changed, 200 insertions(+), 8 deletions(-) create mode 100644 mediapipe/tasks/c/components/containers/category_converter_test.cc create mode 100644 mediapipe/tasks/c/components/containers/classification_result_converter_test.cc diff --git a/mediapipe/tasks/c/components/containers/BUILD b/mediapipe/tasks/c/components/containers/BUILD index 4b1841ef8..fd00261f2 100644 --- a/mediapipe/tasks/c/components/containers/BUILD +++ b/mediapipe/tasks/c/components/containers/BUILD @@ -31,6 +31,18 @@ cc_library( ], ) +cc_test( + name = "category_converter_test", + srcs = ["category_converter_test.cc"], + deps = [ + ":category", + ":category_converter", + "//mediapipe/framework/port:gtest", + "//mediapipe/tasks/cc/components/containers:category", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "classification_result", hdrs = ["classification_result.h"], @@ -47,3 +59,15 @@ cc_library( "//mediapipe/tasks/cc/components/containers:classification_result", ], ) + +cc_test( + name = "classification_result_converter_test", + srcs = ["classification_result_converter_test.cc"], + deps = [ + ":classification_result", + ":classification_result_converter", + "//mediapipe/framework/port:gtest", + "//mediapipe/tasks/cc/components/containers:classification_result", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/mediapipe/tasks/c/components/containers/category.h b/mediapipe/tasks/c/components/containers/category.h index b6eede40c..9a47815ab 100644 --- a/mediapipe/tasks/c/components/containers/category.h +++ b/mediapipe/tasks/c/components/containers/category.h @@ -36,11 +36,11 @@ struct Category { // The optional ID for the category, read from the label map packed in the // TFLite Model Metadata if present. Not necessarily human-readable. - const char* category_name; + char* category_name; // The optional human-readable name for the category, read from the label map // packed in the TFLite Model Metadata if present. - const char* display_name; + char* display_name; }; #ifdef __cplusplus diff --git a/mediapipe/tasks/c/components/containers/category_converter_test.cc b/mediapipe/tasks/c/components/containers/category_converter_test.cc new file mode 100644 index 000000000..566bfa803 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/category_converter_test.cc @@ -0,0 +1,63 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/components/containers/category_converter.h" + +#include +#include +#include + +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/c/components/containers/category.h" +#include "mediapipe/tasks/cc/components/containers/category.h" + +namespace mediapipe::tasks::c::components::containers { + +TEST(CategoryConverterTest, ConvertsCategoryCustomValues) { + mediapipe::tasks::components::containers::Category cpp_category = { + /* index= */ 1, + /* score= */ 0.1, + /* category_name= */ "category_name", + /* display_name= */ "display_name", + }; + + Category c_category; + CppConvertToCategory(cpp_category, &c_category); + EXPECT_EQ(c_category.index, 1); + EXPECT_FLOAT_EQ(c_category.score, 0.1); + EXPECT_EQ(std::string{c_category.category_name}, "category_name"); + EXPECT_EQ(std::string{c_category.display_name}, "display_name"); + + free(c_category.category_name); + free(c_category.display_name); +} + +TEST(CategoryConverterTest, ConvertsCategoryDefaultValues) { + mediapipe::tasks::components::containers::Category cpp_category = { + /* index= */ 1, + /* score= */ 0.1, + /* category_name= */ std::nullopt, + /* display_name= */ std::nullopt, + }; + + Category c_category; + CppConvertToCategory(cpp_category, &c_category); + EXPECT_EQ(c_category.index, 1); + EXPECT_FLOAT_EQ(c_category.score, 0.1); + EXPECT_EQ(c_category.category_name, nullptr); + EXPECT_EQ(c_category.display_name, nullptr); +} + +} // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/classification_result.h b/mediapipe/tasks/c/components/containers/classification_result.h index ef2914e5d..d03fe05ed 100644 --- a/mediapipe/tasks/c/components/containers/classification_result.h +++ b/mediapipe/tasks/c/components/containers/classification_result.h @@ -39,7 +39,7 @@ struct Classifications { // Metadata [1] if present. This is useful for multi-head models. // // [1]: https://www.tensorflow.org/lite/convert/metadata - const char* head_name; + char* head_name; }; // Defines classification results of a model. diff --git a/mediapipe/tasks/c/components/containers/classification_result_converter.cc b/mediapipe/tasks/c/components/containers/classification_result_converter.cc index 210ea3c82..a1e77ce48 100644 --- a/mediapipe/tasks/c/components/containers/classification_result_converter.cc +++ b/mediapipe/tasks/c/components/containers/classification_result_converter.cc @@ -19,6 +19,7 @@ limitations under the License. #include "mediapipe/tasks/c/components/containers/category.h" #include "mediapipe/tasks/c/components/containers/category_converter.h" +#include "mediapipe/tasks/c/components/containers/classification_result.h" #include "mediapipe/tasks/cc/components/containers/classification_result.h" namespace mediapipe::tasks::c::components::containers { @@ -27,12 +28,12 @@ void CppConvertToClassificationResult( const mediapipe::tasks::components::containers::ClassificationResult& in, ClassificationResult* out) { out->has_timestamp_ms = in.timestamp_ms.has_value(); - if (out->has_timestamp_ms) { - out->timestamp_ms = in.timestamp_ms.value(); - } + out->timestamp_ms = out->has_timestamp_ms ? in.timestamp_ms.value() : 0; out->classifications_count = in.classifications.size(); - out->classifications = new Classifications[out->classifications_count]; + out->classifications = out->classifications_count + ? new Classifications[out->classifications_count] + : nullptr; for (uint32_t i = 0; i < out->classifications_count; ++i) { auto classification_in = in.classifications[i]; @@ -40,7 +41,9 @@ void CppConvertToClassificationResult( classification_out.categories_count = classification_in.categories.size(); classification_out.categories = - new Category[classification_out.categories_count]; + classification_out.categories_count + ? new Category[classification_out.categories_count] + : nullptr; for (uint32_t j = 0; j < classification_out.categories_count; ++j) { CppConvertToCategory(classification_in.categories[j], &(classification_out.categories[j])); diff --git a/mediapipe/tasks/c/components/containers/classification_result_converter_test.cc b/mediapipe/tasks/c/components/containers/classification_result_converter_test.cc new file mode 100644 index 000000000..6166ffdba --- /dev/null +++ b/mediapipe/tasks/c/components/containers/classification_result_converter_test.cc @@ -0,0 +1,102 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/components/containers/classification_result_converter.h" + +#include +#include +#include + +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/c/components/containers/classification_result.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" + +namespace mediapipe::tasks::c::components::containers { + +TEST(ClassificationResultConverterTest, + ConvertsClassificationResulCustomCategory) { + mediapipe::tasks::components::containers::ClassificationResult + cpp_classification_result = { + /* classifications= */ {{/* categories= */ {{ + /* index= */ 1, + /* score= */ 0.1, + /* category_name= */ std::nullopt, + /* display_name= */ std::nullopt, + }}, + /* head_index= */ 0, + /* head_name= */ "foo"}}, + /* timestamp_ms= */ 42, + }; + + ClassificationResult c_classification_result; + CppConvertToClassificationResult(cpp_classification_result, + &c_classification_result); + EXPECT_NE(c_classification_result.classifications, nullptr); + EXPECT_EQ(c_classification_result.classifications_count, 1); + EXPECT_NE(c_classification_result.classifications[0].categories, nullptr); + EXPECT_EQ(c_classification_result.classifications[0].categories_count, 1); + EXPECT_EQ(c_classification_result.classifications[0].head_index, 0); + EXPECT_EQ(std::string(c_classification_result.classifications[0].head_name), + "foo"); + EXPECT_EQ(c_classification_result.timestamp_ms, 42); + EXPECT_EQ(c_classification_result.has_timestamp_ms, true); + + free(c_classification_result.classifications[0].categories); + free(c_classification_result.classifications[0].head_name); + free(c_classification_result.classifications); +} + +TEST(ClassificationResultConverterTest, + ConvertsClassificationResulEmptyCategory) { + mediapipe::tasks::components::containers::ClassificationResult + cpp_classification_result = { + /* classifications= */ {{/* categories= */ {}, /* head_index= */ 0, + /* head_name= */ std::nullopt}}, + /* timestamp_ms= */ std::nullopt, + }; + + ClassificationResult c_classification_result; + CppConvertToClassificationResult(cpp_classification_result, + &c_classification_result); + EXPECT_NE(c_classification_result.classifications, nullptr); + EXPECT_EQ(c_classification_result.classifications_count, 1); + EXPECT_EQ(c_classification_result.classifications[0].categories, nullptr); + EXPECT_EQ(c_classification_result.classifications[0].categories_count, 0); + EXPECT_EQ(c_classification_result.classifications[0].head_index, 0); + EXPECT_EQ(c_classification_result.classifications[0].head_name, nullptr); + EXPECT_EQ(c_classification_result.timestamp_ms, 0); + EXPECT_EQ(c_classification_result.has_timestamp_ms, false); + + free(c_classification_result.classifications); +} + +TEST(ClassificationResultConverterTest, + ConvertsClassificationResultNoCategory) { + mediapipe::tasks::components::containers::ClassificationResult + cpp_classification_result = { + /* classifications= */ {}, + /* timestamp_ms= */ std::nullopt, + }; + + ClassificationResult c_classification_result; + CppConvertToClassificationResult(cpp_classification_result, + &c_classification_result); + EXPECT_EQ(c_classification_result.classifications, nullptr); + EXPECT_EQ(c_classification_result.classifications_count, 0); + EXPECT_EQ(c_classification_result.timestamp_ms, 0); + EXPECT_EQ(c_classification_result.has_timestamp_ms, false); +} + +} // namespace mediapipe::tasks::c::components::containers