From 96fa10b9061504f9d81499f161df0d6e7454c531 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 29 Sep 2023 12:03:07 -0700 Subject: [PATCH] Add unit tests for C layer for the input types of Text Classifier PiperOrigin-RevId: 569553038 --- mediapipe/tasks/c/components/processors/BUILD | 12 +++ .../processors/classifier_options.h | 6 +- .../classifier_options_converter.cc | 2 +- .../classifier_options_converter_test.cc | 82 +++++++++++++++++++ mediapipe/tasks/c/core/BUILD | 12 +++ mediapipe/tasks/c/core/base_options.h | 4 +- .../tasks/c/core/base_options_converter.cc | 4 +- .../tasks/c/core/base_options_converter.h | 4 +- .../c/core/base_options_converter_test.cc | 51 ++++++++++++ .../c/text/text_classifier/text_classifier.cc | 2 +- 10 files changed, 168 insertions(+), 11 deletions(-) create mode 100644 mediapipe/tasks/c/components/processors/classifier_options_converter_test.cc create mode 100644 mediapipe/tasks/c/core/base_options_converter_test.cc diff --git a/mediapipe/tasks/c/components/processors/BUILD b/mediapipe/tasks/c/components/processors/BUILD index e90437d59..5794769d2 100644 --- a/mediapipe/tasks/c/components/processors/BUILD +++ b/mediapipe/tasks/c/components/processors/BUILD @@ -30,3 +30,15 @@ cc_library( "//mediapipe/tasks/cc/components/processors:classifier_options", ], ) + +cc_test( + name = "classifier_options_converter_test", + srcs = ["classifier_options_converter_test.cc"], + deps = [ + ":classifier_options", + ":classifier_options_converter", + "//mediapipe/framework/port:gtest", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/mediapipe/tasks/c/components/processors/classifier_options.h b/mediapipe/tasks/c/components/processors/classifier_options.h index 4658fb42b..32ad22b0e 100644 --- a/mediapipe/tasks/c/components/processors/classifier_options.h +++ b/mediapipe/tasks/c/components/processors/classifier_options.h @@ -26,7 +26,7 @@ extern "C" { struct ClassifierOptions { // The locale to use for display names specified through the TFLite Model // Metadata, if any. Defaults to English. - char* display_names_locale; + const char* display_names_locale; // The maximum number of top-scored classification results to return. If < 0, // all available results will be returned. If 0, an invalid argument error is @@ -40,14 +40,14 @@ struct ClassifierOptions { // The allowlist of category names. If non-empty, detection results whose // category name is not in this set will be filtered out. Duplicate or unknown // category names are ignored. Mutually exclusive with category_denylist. - char** category_allowlist; + const char** category_allowlist; // The number of elements in the category allowlist. uint32_t category_allowlist_count; // The denylist of category names. If non-empty, detection results whose // category name is in this set will be filtered out. Duplicate or unknown // category names are ignored. Mutually exclusive with category_allowlist. - char** category_denylist; + const char** category_denylist; // The number of elements in the category denylist. uint32_t category_denylist_count; }; diff --git a/mediapipe/tasks/c/components/processors/classifier_options_converter.cc b/mediapipe/tasks/c/components/processors/classifier_options_converter.cc index eca2b3d33..2a026ba3b 100644 --- a/mediapipe/tasks/c/components/processors/classifier_options_converter.cc +++ b/mediapipe/tasks/c/components/processors/classifier_options_converter.cc @@ -26,7 +26,7 @@ void CppConvertToClassifierOptions( const ClassifierOptions& in, mediapipe::tasks::components::processors::ClassifierOptions* out) { out->display_names_locale = - in.display_names_locale ? std::string(in.display_names_locale) : ""; + in.display_names_locale ? std::string(in.display_names_locale) : "en"; out->max_results = in.max_results; out->score_threshold = in.score_threshold; out->category_allowlist = diff --git a/mediapipe/tasks/c/components/processors/classifier_options_converter_test.cc b/mediapipe/tasks/c/components/processors/classifier_options_converter_test.cc new file mode 100644 index 000000000..9eee18fdf --- /dev/null +++ b/mediapipe/tasks/c/components/processors/classifier_options_converter_test.cc @@ -0,0 +1,82 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/components/processors/classifier_options_converter.h" + +#include +#include + +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/c/components/processors/classifier_options.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" + +namespace mediapipe::tasks::c::components::processors { + +constexpr char kCategoryAllowlist[] = "fruit"; +constexpr char kCategoryDenylist[] = "veggies"; +constexpr char kDisplayNamesLocaleGerman[] = "de"; + +TEST(ClassifierOptionsConverterTest, ConvertsClassifierOptionsCustomValues) { + std::vector category_allowlist = {kCategoryAllowlist}; + std::vector category_denylist = {kCategoryDenylist}; + + ClassifierOptions c_classifier_options = { + /* display_names_locale= */ kDisplayNamesLocaleGerman, + /* max_results= */ 1, + /* score_threshold= */ 0.1, + /* category_allowlist= */ category_allowlist.data(), + /* category_allowlist_count= */ 1, + /* category_denylist= */ category_denylist.data(), + /* category_denylist_count= */ 1}; + + mediapipe::tasks::components::processors::ClassifierOptions + cpp_classifier_options = {}; + + CppConvertToClassifierOptions(c_classifier_options, &cpp_classifier_options); + EXPECT_EQ(cpp_classifier_options.display_names_locale, "de"); + EXPECT_EQ(cpp_classifier_options.max_results, 1); + EXPECT_FLOAT_EQ(cpp_classifier_options.score_threshold, 0.1); + EXPECT_EQ(cpp_classifier_options.category_allowlist, + std::vector{"fruit"}); + EXPECT_EQ(cpp_classifier_options.category_denylist, + std::vector{"veggies"}); +} + +TEST(ClassifierOptionsConverterTest, ConvertsClassifierOptionsDefaultValues) { + std::vector category_allowlist = {kCategoryAllowlist}; + std::vector category_denylist = {kCategoryDenylist}; + + ClassifierOptions c_classifier_options = {/* display_names_locale= */ nullptr, + /* max_results= */ -1, + /* score_threshold= */ 0.0, + /* category_allowlist= */ nullptr, + /* category_allowlist_count= */ 0, + /* category_denylist= */ nullptr, + /* category_denylist_count= */ 0}; + + mediapipe::tasks::components::processors::ClassifierOptions + cpp_classifier_options = {}; + + CppConvertToClassifierOptions(c_classifier_options, &cpp_classifier_options); + EXPECT_EQ(cpp_classifier_options.display_names_locale, "en"); + EXPECT_EQ(cpp_classifier_options.max_results, -1); + EXPECT_FLOAT_EQ(cpp_classifier_options.score_threshold, 0.0); + EXPECT_EQ(cpp_classifier_options.category_allowlist, + std::vector{}); + EXPECT_EQ(cpp_classifier_options.category_denylist, + std::vector{}); +} + +} // namespace mediapipe::tasks::c::components::processors diff --git a/mediapipe/tasks/c/core/BUILD b/mediapipe/tasks/c/core/BUILD index 9a360404e..a7c3aa9cf 100644 --- a/mediapipe/tasks/c/core/BUILD +++ b/mediapipe/tasks/c/core/BUILD @@ -30,3 +30,15 @@ cc_library( "//mediapipe/tasks/cc/core:base_options", ], ) + +cc_test( + name = "base_options_converter_test", + srcs = ["base_options_converter_test.cc"], + deps = [ + ":base_options", + ":base_options_converter", + "//mediapipe/framework/port:gtest", + "//mediapipe/tasks/cc/core:base_options", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/mediapipe/tasks/c/core/base_options.h b/mediapipe/tasks/c/core/base_options.h index d23b6884c..78d89ce8c 100644 --- a/mediapipe/tasks/c/core/base_options.h +++ b/mediapipe/tasks/c/core/base_options.h @@ -23,10 +23,10 @@ extern "C" { // Base options for MediaPipe C Tasks. struct BaseOptions { // The model asset file contents as a string. - char* model_asset_buffer; + const char* model_asset_buffer; // The path to the model asset to open and mmap in memory. - char* model_asset_path; + const char* model_asset_path; }; #ifdef __cplusplus diff --git a/mediapipe/tasks/c/core/base_options_converter.cc b/mediapipe/tasks/c/core/base_options_converter.cc index 78f5edb49..3f126168b 100644 --- a/mediapipe/tasks/c/core/base_options_converter.cc +++ b/mediapipe/tasks/c/core/base_options_converter.cc @@ -21,7 +21,7 @@ limitations under the License. #include "mediapipe/tasks/c/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h" -namespace mediapipe::tasks::c::components::containers { +namespace mediapipe::tasks::c::core { void CppConvertToBaseOptions(const BaseOptions& in, mediapipe::tasks::core::BaseOptions* out) { @@ -33,4 +33,4 @@ void CppConvertToBaseOptions(const BaseOptions& in, in.model_asset_path ? std::string(in.model_asset_path) : ""; } -} // namespace mediapipe::tasks::c::components::containers +} // namespace mediapipe::tasks::c::core diff --git a/mediapipe/tasks/c/core/base_options_converter.h b/mediapipe/tasks/c/core/base_options_converter.h index 90db6397d..f22740b6e 100644 --- a/mediapipe/tasks/c/core/base_options_converter.h +++ b/mediapipe/tasks/c/core/base_options_converter.h @@ -19,11 +19,11 @@ limitations under the License. #include "mediapipe/tasks/c/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h" -namespace mediapipe::tasks::c::components::containers { +namespace mediapipe::tasks::c::core { void CppConvertToBaseOptions(const BaseOptions& in, mediapipe::tasks::core::BaseOptions* out); -} // namespace mediapipe::tasks::c::components::containers +} // namespace mediapipe::tasks::c::core #endif // MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_ diff --git a/mediapipe/tasks/c/core/base_options_converter_test.cc b/mediapipe/tasks/c/core/base_options_converter_test.cc new file mode 100644 index 000000000..27c7fb3ec --- /dev/null +++ b/mediapipe/tasks/c/core/base_options_converter_test.cc @@ -0,0 +1,51 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/core/base_options_converter.h" + +#include + +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/c/core/base_options.h" +#include "mediapipe/tasks/cc/core/base_options.h" + +namespace mediapipe::tasks::c::core { + +constexpr char kAssetBuffer[] = "abc"; +constexpr char kModelAssetPath[] = "abc.tflite"; + +TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetBuffer) { + BaseOptions c_base_options = {/* model_asset_buffer= */ kAssetBuffer, + /* model_asset_path= */ nullptr}; + + mediapipe::tasks::core::BaseOptions cpp_base_options = {}; + + CppConvertToBaseOptions(c_base_options, &cpp_base_options); + EXPECT_EQ(*cpp_base_options.model_asset_buffer, std::string{kAssetBuffer}); + EXPECT_EQ(cpp_base_options.model_asset_path, ""); +} + +TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetPath) { + BaseOptions c_base_options = {/* model_asset_buffer= */ nullptr, + /* model_asset_path= */ kModelAssetPath}; + + mediapipe::tasks::core::BaseOptions cpp_base_options = {}; + + CppConvertToBaseOptions(c_base_options, &cpp_base_options); + EXPECT_EQ(cpp_base_options.model_asset_buffer.get(), nullptr); + EXPECT_EQ(cpp_base_options.model_asset_path, std::string{kModelAssetPath}); +} + +} // namespace mediapipe::tasks::c::core diff --git a/mediapipe/tasks/c/text/text_classifier/text_classifier.cc b/mediapipe/tasks/c/text/text_classifier/text_classifier.cc index 0de123965..53fe34f75 100644 --- a/mediapipe/tasks/c/text/text_classifier/text_classifier.cc +++ b/mediapipe/tasks/c/text/text_classifier/text_classifier.cc @@ -30,11 +30,11 @@ namespace mediapipe::tasks::c::text::text_classifier { namespace { -using ::mediapipe::tasks::c::components::containers::CppConvertToBaseOptions; using ::mediapipe::tasks::c::components::containers:: CppConvertToClassificationResult; using ::mediapipe::tasks::c::components::processors:: CppConvertToClassifierOptions; +using ::mediapipe::tasks::c::core::CppConvertToBaseOptions; using ::mediapipe::tasks::text::text_classifier::TextClassifier; } // namespace