Add unit tests for C layer for the input types of Text Classifier
PiperOrigin-RevId: 569553038
This commit is contained in:
parent
6915a79e28
commit
96fa10b906
|
@ -30,3 +30,15 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "classifier_options_converter_test",
|
||||
srcs = ["classifier_options_converter_test.cc"],
|
||||
deps = [
|
||||
":classifier_options",
|
||||
":classifier_options_converter",
|
||||
"//mediapipe/framework/port:gtest",
|
||||
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -26,7 +26,7 @@ extern "C" {
|
|||
struct ClassifierOptions {
|
||||
// The locale to use for display names specified through the TFLite Model
|
||||
// Metadata, if any. Defaults to English.
|
||||
char* display_names_locale;
|
||||
const char* display_names_locale;
|
||||
|
||||
// The maximum number of top-scored classification results to return. If < 0,
|
||||
// all available results will be returned. If 0, an invalid argument error is
|
||||
|
@ -40,14 +40,14 @@ struct ClassifierOptions {
|
|||
// The allowlist of category names. If non-empty, detection results whose
|
||||
// category name is not in this set will be filtered out. Duplicate or unknown
|
||||
// category names are ignored. Mutually exclusive with category_denylist.
|
||||
char** category_allowlist;
|
||||
const char** category_allowlist;
|
||||
// The number of elements in the category allowlist.
|
||||
uint32_t category_allowlist_count;
|
||||
|
||||
// The denylist of category names. If non-empty, detection results whose
|
||||
// category name is in this set will be filtered out. Duplicate or unknown
|
||||
// category names are ignored. Mutually exclusive with category_allowlist.
|
||||
char** category_denylist;
|
||||
const char** category_denylist;
|
||||
// The number of elements in the category denylist.
|
||||
uint32_t category_denylist_count;
|
||||
};
|
||||
|
|
|
@ -26,7 +26,7 @@ void CppConvertToClassifierOptions(
|
|||
const ClassifierOptions& in,
|
||||
mediapipe::tasks::components::processors::ClassifierOptions* out) {
|
||||
out->display_names_locale =
|
||||
in.display_names_locale ? std::string(in.display_names_locale) : "";
|
||||
in.display_names_locale ? std::string(in.display_names_locale) : "en";
|
||||
out->max_results = in.max_results;
|
||||
out->score_threshold = in.score_threshold;
|
||||
out->category_allowlist =
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/c/components/processors/classifier_options_converter.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/tasks/c/components/processors/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||
|
||||
namespace mediapipe::tasks::c::components::processors {
|
||||
|
||||
constexpr char kCategoryAllowlist[] = "fruit";
|
||||
constexpr char kCategoryDenylist[] = "veggies";
|
||||
constexpr char kDisplayNamesLocaleGerman[] = "de";
|
||||
|
||||
TEST(ClassifierOptionsConverterTest, ConvertsClassifierOptionsCustomValues) {
|
||||
std::vector<const char*> category_allowlist = {kCategoryAllowlist};
|
||||
std::vector<const char*> category_denylist = {kCategoryDenylist};
|
||||
|
||||
ClassifierOptions c_classifier_options = {
|
||||
/* display_names_locale= */ kDisplayNamesLocaleGerman,
|
||||
/* max_results= */ 1,
|
||||
/* score_threshold= */ 0.1,
|
||||
/* category_allowlist= */ category_allowlist.data(),
|
||||
/* category_allowlist_count= */ 1,
|
||||
/* category_denylist= */ category_denylist.data(),
|
||||
/* category_denylist_count= */ 1};
|
||||
|
||||
mediapipe::tasks::components::processors::ClassifierOptions
|
||||
cpp_classifier_options = {};
|
||||
|
||||
CppConvertToClassifierOptions(c_classifier_options, &cpp_classifier_options);
|
||||
EXPECT_EQ(cpp_classifier_options.display_names_locale, "de");
|
||||
EXPECT_EQ(cpp_classifier_options.max_results, 1);
|
||||
EXPECT_FLOAT_EQ(cpp_classifier_options.score_threshold, 0.1);
|
||||
EXPECT_EQ(cpp_classifier_options.category_allowlist,
|
||||
std::vector<std::string>{"fruit"});
|
||||
EXPECT_EQ(cpp_classifier_options.category_denylist,
|
||||
std::vector<std::string>{"veggies"});
|
||||
}
|
||||
|
||||
TEST(ClassifierOptionsConverterTest, ConvertsClassifierOptionsDefaultValues) {
|
||||
std::vector<const char*> category_allowlist = {kCategoryAllowlist};
|
||||
std::vector<const char*> category_denylist = {kCategoryDenylist};
|
||||
|
||||
ClassifierOptions c_classifier_options = {/* display_names_locale= */ nullptr,
|
||||
/* max_results= */ -1,
|
||||
/* score_threshold= */ 0.0,
|
||||
/* category_allowlist= */ nullptr,
|
||||
/* category_allowlist_count= */ 0,
|
||||
/* category_denylist= */ nullptr,
|
||||
/* category_denylist_count= */ 0};
|
||||
|
||||
mediapipe::tasks::components::processors::ClassifierOptions
|
||||
cpp_classifier_options = {};
|
||||
|
||||
CppConvertToClassifierOptions(c_classifier_options, &cpp_classifier_options);
|
||||
EXPECT_EQ(cpp_classifier_options.display_names_locale, "en");
|
||||
EXPECT_EQ(cpp_classifier_options.max_results, -1);
|
||||
EXPECT_FLOAT_EQ(cpp_classifier_options.score_threshold, 0.0);
|
||||
EXPECT_EQ(cpp_classifier_options.category_allowlist,
|
||||
std::vector<std::string>{});
|
||||
EXPECT_EQ(cpp_classifier_options.category_denylist,
|
||||
std::vector<std::string>{});
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::c::components::processors
|
|
@ -30,3 +30,15 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/core:base_options",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "base_options_converter_test",
|
||||
srcs = ["base_options_converter_test.cc"],
|
||||
deps = [
|
||||
":base_options",
|
||||
":base_options_converter",
|
||||
"//mediapipe/framework/port:gtest",
|
||||
"//mediapipe/tasks/cc/core:base_options",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -23,10 +23,10 @@ extern "C" {
|
|||
// Base options for MediaPipe C Tasks.
|
||||
struct BaseOptions {
|
||||
// The model asset file contents as a string.
|
||||
char* model_asset_buffer;
|
||||
const char* model_asset_buffer;
|
||||
|
||||
// The path to the model asset to open and mmap in memory.
|
||||
char* model_asset_path;
|
||||
const char* model_asset_path;
|
||||
};
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -21,7 +21,7 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/c/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
|
||||
namespace mediapipe::tasks::c::components::containers {
|
||||
namespace mediapipe::tasks::c::core {
|
||||
|
||||
void CppConvertToBaseOptions(const BaseOptions& in,
|
||||
mediapipe::tasks::core::BaseOptions* out) {
|
||||
|
@ -33,4 +33,4 @@ void CppConvertToBaseOptions(const BaseOptions& in,
|
|||
in.model_asset_path ? std::string(in.model_asset_path) : "";
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::c::components::containers
|
||||
} // namespace mediapipe::tasks::c::core
|
||||
|
|
|
@ -19,11 +19,11 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/c/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
|
||||
namespace mediapipe::tasks::c::components::containers {
|
||||
namespace mediapipe::tasks::c::core {
|
||||
|
||||
void CppConvertToBaseOptions(const BaseOptions& in,
|
||||
mediapipe::tasks::core::BaseOptions* out);
|
||||
|
||||
} // namespace mediapipe::tasks::c::components::containers
|
||||
} // namespace mediapipe::tasks::c::core
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_
|
||||
|
|
51
mediapipe/tasks/c/core/base_options_converter_test.cc
Normal file
51
mediapipe/tasks/c/core/base_options_converter_test.cc
Normal file
|
@ -0,0 +1,51 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/c/core/base_options_converter.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/tasks/c/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
|
||||
namespace mediapipe::tasks::c::core {
|
||||
|
||||
constexpr char kAssetBuffer[] = "abc";
|
||||
constexpr char kModelAssetPath[] = "abc.tflite";
|
||||
|
||||
TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetBuffer) {
|
||||
BaseOptions c_base_options = {/* model_asset_buffer= */ kAssetBuffer,
|
||||
/* model_asset_path= */ nullptr};
|
||||
|
||||
mediapipe::tasks::core::BaseOptions cpp_base_options = {};
|
||||
|
||||
CppConvertToBaseOptions(c_base_options, &cpp_base_options);
|
||||
EXPECT_EQ(*cpp_base_options.model_asset_buffer, std::string{kAssetBuffer});
|
||||
EXPECT_EQ(cpp_base_options.model_asset_path, "");
|
||||
}
|
||||
|
||||
TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetPath) {
|
||||
BaseOptions c_base_options = {/* model_asset_buffer= */ nullptr,
|
||||
/* model_asset_path= */ kModelAssetPath};
|
||||
|
||||
mediapipe::tasks::core::BaseOptions cpp_base_options = {};
|
||||
|
||||
CppConvertToBaseOptions(c_base_options, &cpp_base_options);
|
||||
EXPECT_EQ(cpp_base_options.model_asset_buffer.get(), nullptr);
|
||||
EXPECT_EQ(cpp_base_options.model_asset_path, std::string{kModelAssetPath});
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::c::core
|
|
@ -30,11 +30,11 @@ namespace mediapipe::tasks::c::text::text_classifier {
|
|||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::tasks::c::components::containers::CppConvertToBaseOptions;
|
||||
using ::mediapipe::tasks::c::components::containers::
|
||||
CppConvertToClassificationResult;
|
||||
using ::mediapipe::tasks::c::components::processors::
|
||||
CppConvertToClassifierOptions;
|
||||
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
|
||||
using ::mediapipe::tasks::text::text_classifier::TextClassifier;
|
||||
} // namespace
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user