Add unit tests for C layer for the input types of Text Classifier

PiperOrigin-RevId: 569553038
This commit is contained in:
Sebastian Schmidt 2023-09-29 12:03:07 -07:00 committed by Copybara-Service
parent 6915a79e28
commit 96fa10b906
10 changed files with 168 additions and 11 deletions

View File

@ -30,3 +30,15 @@ cc_library(
"//mediapipe/tasks/cc/components/processors:classifier_options", "//mediapipe/tasks/cc/components/processors:classifier_options",
], ],
) )
cc_test(
name = "classifier_options_converter_test",
srcs = ["classifier_options_converter_test.cc"],
deps = [
":classifier_options",
":classifier_options_converter",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/cc/components/processors:classifier_options",
"@com_google_googletest//:gtest_main",
],
)

View File

@ -26,7 +26,7 @@ extern "C" {
struct ClassifierOptions { struct ClassifierOptions {
// The locale to use for display names specified through the TFLite Model // The locale to use for display names specified through the TFLite Model
// Metadata, if any. Defaults to English. // Metadata, if any. Defaults to English.
char* display_names_locale; const char* display_names_locale;
// The maximum number of top-scored classification results to return. If < 0, // The maximum number of top-scored classification results to return. If < 0,
// all available results will be returned. If 0, an invalid argument error is // all available results will be returned. If 0, an invalid argument error is
@ -40,14 +40,14 @@ struct ClassifierOptions {
// The allowlist of category names. If non-empty, detection results whose // The allowlist of category names. If non-empty, detection results whose
// category name is not in this set will be filtered out. Duplicate or unknown // category name is not in this set will be filtered out. Duplicate or unknown
// category names are ignored. Mutually exclusive with category_denylist. // category names are ignored. Mutually exclusive with category_denylist.
char** category_allowlist; const char** category_allowlist;
// The number of elements in the category allowlist. // The number of elements in the category allowlist.
uint32_t category_allowlist_count; uint32_t category_allowlist_count;
// The denylist of category names. If non-empty, detection results whose // The denylist of category names. If non-empty, detection results whose
// category name is in this set will be filtered out. Duplicate or unknown // category name is in this set will be filtered out. Duplicate or unknown
// category names are ignored. Mutually exclusive with category_allowlist. // category names are ignored. Mutually exclusive with category_allowlist.
char** category_denylist; const char** category_denylist;
// The number of elements in the category denylist. // The number of elements in the category denylist.
uint32_t category_denylist_count; uint32_t category_denylist_count;
}; };

View File

@ -26,7 +26,7 @@ void CppConvertToClassifierOptions(
const ClassifierOptions& in, const ClassifierOptions& in,
mediapipe::tasks::components::processors::ClassifierOptions* out) { mediapipe::tasks::components::processors::ClassifierOptions* out) {
out->display_names_locale = out->display_names_locale =
in.display_names_locale ? std::string(in.display_names_locale) : ""; in.display_names_locale ? std::string(in.display_names_locale) : "en";
out->max_results = in.max_results; out->max_results = in.max_results;
out->score_threshold = in.score_threshold; out->score_threshold = in.score_threshold;
out->category_allowlist = out->category_allowlist =

View File

@ -0,0 +1,82 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/components/processors/classifier_options_converter.h"
#include <string>
#include <vector>
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
namespace mediapipe::tasks::c::components::processors {
constexpr char kCategoryAllowlist[] = "fruit";
constexpr char kCategoryDenylist[] = "veggies";
constexpr char kDisplayNamesLocaleGerman[] = "de";
TEST(ClassifierOptionsConverterTest, ConvertsClassifierOptionsCustomValues) {
std::vector<const char*> category_allowlist = {kCategoryAllowlist};
std::vector<const char*> category_denylist = {kCategoryDenylist};
ClassifierOptions c_classifier_options = {
/* display_names_locale= */ kDisplayNamesLocaleGerman,
/* max_results= */ 1,
/* score_threshold= */ 0.1,
/* category_allowlist= */ category_allowlist.data(),
/* category_allowlist_count= */ 1,
/* category_denylist= */ category_denylist.data(),
/* category_denylist_count= */ 1};
mediapipe::tasks::components::processors::ClassifierOptions
cpp_classifier_options = {};
CppConvertToClassifierOptions(c_classifier_options, &cpp_classifier_options);
EXPECT_EQ(cpp_classifier_options.display_names_locale, "de");
EXPECT_EQ(cpp_classifier_options.max_results, 1);
EXPECT_FLOAT_EQ(cpp_classifier_options.score_threshold, 0.1);
EXPECT_EQ(cpp_classifier_options.category_allowlist,
std::vector<std::string>{"fruit"});
EXPECT_EQ(cpp_classifier_options.category_denylist,
std::vector<std::string>{"veggies"});
}
TEST(ClassifierOptionsConverterTest, ConvertsClassifierOptionsDefaultValues) {
std::vector<const char*> category_allowlist = {kCategoryAllowlist};
std::vector<const char*> category_denylist = {kCategoryDenylist};
ClassifierOptions c_classifier_options = {/* display_names_locale= */ nullptr,
/* max_results= */ -1,
/* score_threshold= */ 0.0,
/* category_allowlist= */ nullptr,
/* category_allowlist_count= */ 0,
/* category_denylist= */ nullptr,
/* category_denylist_count= */ 0};
mediapipe::tasks::components::processors::ClassifierOptions
cpp_classifier_options = {};
CppConvertToClassifierOptions(c_classifier_options, &cpp_classifier_options);
EXPECT_EQ(cpp_classifier_options.display_names_locale, "en");
EXPECT_EQ(cpp_classifier_options.max_results, -1);
EXPECT_FLOAT_EQ(cpp_classifier_options.score_threshold, 0.0);
EXPECT_EQ(cpp_classifier_options.category_allowlist,
std::vector<std::string>{});
EXPECT_EQ(cpp_classifier_options.category_denylist,
std::vector<std::string>{});
}
} // namespace mediapipe::tasks::c::components::processors

View File

@ -30,3 +30,15 @@ cc_library(
"//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:base_options",
], ],
) )
cc_test(
name = "base_options_converter_test",
srcs = ["base_options_converter_test.cc"],
deps = [
":base_options",
":base_options_converter",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/cc/core:base_options",
"@com_google_googletest//:gtest_main",
],
)

View File

@ -23,10 +23,10 @@ extern "C" {
// Base options for MediaPipe C Tasks. // Base options for MediaPipe C Tasks.
struct BaseOptions { struct BaseOptions {
// The model asset file contents as a string. // The model asset file contents as a string.
char* model_asset_buffer; const char* model_asset_buffer;
// The path to the model asset to open and mmap in memory. // The path to the model asset to open and mmap in memory.
char* model_asset_path; const char* model_asset_path;
}; };
#ifdef __cplusplus #ifdef __cplusplus

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "mediapipe/tasks/c/core/base_options.h" #include "mediapipe/tasks/c/core/base_options.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
namespace mediapipe::tasks::c::components::containers { namespace mediapipe::tasks::c::core {
void CppConvertToBaseOptions(const BaseOptions& in, void CppConvertToBaseOptions(const BaseOptions& in,
mediapipe::tasks::core::BaseOptions* out) { mediapipe::tasks::core::BaseOptions* out) {
@ -33,4 +33,4 @@ void CppConvertToBaseOptions(const BaseOptions& in,
in.model_asset_path ? std::string(in.model_asset_path) : ""; in.model_asset_path ? std::string(in.model_asset_path) : "";
} }
} // namespace mediapipe::tasks::c::components::containers } // namespace mediapipe::tasks::c::core

View File

@ -19,11 +19,11 @@ limitations under the License.
#include "mediapipe/tasks/c/core/base_options.h" #include "mediapipe/tasks/c/core/base_options.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
namespace mediapipe::tasks::c::components::containers { namespace mediapipe::tasks::c::core {
void CppConvertToBaseOptions(const BaseOptions& in, void CppConvertToBaseOptions(const BaseOptions& in,
mediapipe::tasks::core::BaseOptions* out); mediapipe::tasks::core::BaseOptions* out);
} // namespace mediapipe::tasks::c::components::containers } // namespace mediapipe::tasks::c::core
#endif // MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_ #endif // MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_

View File

@ -0,0 +1,51 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/core/base_options_converter.h"
#include <string>
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/core/base_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
namespace mediapipe::tasks::c::core {
constexpr char kAssetBuffer[] = "abc";
constexpr char kModelAssetPath[] = "abc.tflite";
TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetBuffer) {
BaseOptions c_base_options = {/* model_asset_buffer= */ kAssetBuffer,
/* model_asset_path= */ nullptr};
mediapipe::tasks::core::BaseOptions cpp_base_options = {};
CppConvertToBaseOptions(c_base_options, &cpp_base_options);
EXPECT_EQ(*cpp_base_options.model_asset_buffer, std::string{kAssetBuffer});
EXPECT_EQ(cpp_base_options.model_asset_path, "");
}
TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetPath) {
BaseOptions c_base_options = {/* model_asset_buffer= */ nullptr,
/* model_asset_path= */ kModelAssetPath};
mediapipe::tasks::core::BaseOptions cpp_base_options = {};
CppConvertToBaseOptions(c_base_options, &cpp_base_options);
EXPECT_EQ(cpp_base_options.model_asset_buffer.get(), nullptr);
EXPECT_EQ(cpp_base_options.model_asset_path, std::string{kModelAssetPath});
}
} // namespace mediapipe::tasks::c::core

View File

@ -30,11 +30,11 @@ namespace mediapipe::tasks::c::text::text_classifier {
namespace { namespace {
using ::mediapipe::tasks::c::components::containers::CppConvertToBaseOptions;
using ::mediapipe::tasks::c::components::containers:: using ::mediapipe::tasks::c::components::containers::
CppConvertToClassificationResult; CppConvertToClassificationResult;
using ::mediapipe::tasks::c::components::processors:: using ::mediapipe::tasks::c::components::processors::
CppConvertToClassifierOptions; CppConvertToClassifierOptions;
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
using ::mediapipe::tasks::text::text_classifier::TextClassifier; using ::mediapipe::tasks::text::text_classifier::TextClassifier;
} // namespace } // namespace