Add unit tests for C layer for the input types of Text Classifier
PiperOrigin-RevId: 569553038
This commit is contained in:
parent
6915a79e28
commit
96fa10b906
|
@ -30,3 +30,15 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "classifier_options_converter_test",
|
||||||
|
srcs = ["classifier_options_converter_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":classifier_options",
|
||||||
|
":classifier_options_converter",
|
||||||
|
"//mediapipe/framework/port:gtest",
|
||||||
|
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
||||||
|
"@com_google_googletest//:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -26,7 +26,7 @@ extern "C" {
|
||||||
struct ClassifierOptions {
|
struct ClassifierOptions {
|
||||||
// The locale to use for display names specified through the TFLite Model
|
// The locale to use for display names specified through the TFLite Model
|
||||||
// Metadata, if any. Defaults to English.
|
// Metadata, if any. Defaults to English.
|
||||||
char* display_names_locale;
|
const char* display_names_locale;
|
||||||
|
|
||||||
// The maximum number of top-scored classification results to return. If < 0,
|
// The maximum number of top-scored classification results to return. If < 0,
|
||||||
// all available results will be returned. If 0, an invalid argument error is
|
// all available results will be returned. If 0, an invalid argument error is
|
||||||
|
@ -40,14 +40,14 @@ struct ClassifierOptions {
|
||||||
// The allowlist of category names. If non-empty, detection results whose
|
// The allowlist of category names. If non-empty, detection results whose
|
||||||
// category name is not in this set will be filtered out. Duplicate or unknown
|
// category name is not in this set will be filtered out. Duplicate or unknown
|
||||||
// category names are ignored. Mutually exclusive with category_denylist.
|
// category names are ignored. Mutually exclusive with category_denylist.
|
||||||
char** category_allowlist;
|
const char** category_allowlist;
|
||||||
// The number of elements in the category allowlist.
|
// The number of elements in the category allowlist.
|
||||||
uint32_t category_allowlist_count;
|
uint32_t category_allowlist_count;
|
||||||
|
|
||||||
// The denylist of category names. If non-empty, detection results whose
|
// The denylist of category names. If non-empty, detection results whose
|
||||||
// category name is in this set will be filtered out. Duplicate or unknown
|
// category name is in this set will be filtered out. Duplicate or unknown
|
||||||
// category names are ignored. Mutually exclusive with category_allowlist.
|
// category names are ignored. Mutually exclusive with category_allowlist.
|
||||||
char** category_denylist;
|
const char** category_denylist;
|
||||||
// The number of elements in the category denylist.
|
// The number of elements in the category denylist.
|
||||||
uint32_t category_denylist_count;
|
uint32_t category_denylist_count;
|
||||||
};
|
};
|
||||||
|
|
|
@ -26,7 +26,7 @@ void CppConvertToClassifierOptions(
|
||||||
const ClassifierOptions& in,
|
const ClassifierOptions& in,
|
||||||
mediapipe::tasks::components::processors::ClassifierOptions* out) {
|
mediapipe::tasks::components::processors::ClassifierOptions* out) {
|
||||||
out->display_names_locale =
|
out->display_names_locale =
|
||||||
in.display_names_locale ? std::string(in.display_names_locale) : "";
|
in.display_names_locale ? std::string(in.display_names_locale) : "en";
|
||||||
out->max_results = in.max_results;
|
out->max_results = in.max_results;
|
||||||
out->score_threshold = in.score_threshold;
|
out->score_threshold = in.score_threshold;
|
||||||
out->category_allowlist =
|
out->category_allowlist =
|
||||||
|
|
|
@ -0,0 +1,82 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/c/components/processors/classifier_options_converter.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/tasks/c/components/processors/classifier_options.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::c::components::processors {
|
||||||
|
|
||||||
|
constexpr char kCategoryAllowlist[] = "fruit";
|
||||||
|
constexpr char kCategoryDenylist[] = "veggies";
|
||||||
|
constexpr char kDisplayNamesLocaleGerman[] = "de";
|
||||||
|
|
||||||
|
TEST(ClassifierOptionsConverterTest, ConvertsClassifierOptionsCustomValues) {
|
||||||
|
std::vector<const char*> category_allowlist = {kCategoryAllowlist};
|
||||||
|
std::vector<const char*> category_denylist = {kCategoryDenylist};
|
||||||
|
|
||||||
|
ClassifierOptions c_classifier_options = {
|
||||||
|
/* display_names_locale= */ kDisplayNamesLocaleGerman,
|
||||||
|
/* max_results= */ 1,
|
||||||
|
/* score_threshold= */ 0.1,
|
||||||
|
/* category_allowlist= */ category_allowlist.data(),
|
||||||
|
/* category_allowlist_count= */ 1,
|
||||||
|
/* category_denylist= */ category_denylist.data(),
|
||||||
|
/* category_denylist_count= */ 1};
|
||||||
|
|
||||||
|
mediapipe::tasks::components::processors::ClassifierOptions
|
||||||
|
cpp_classifier_options = {};
|
||||||
|
|
||||||
|
CppConvertToClassifierOptions(c_classifier_options, &cpp_classifier_options);
|
||||||
|
EXPECT_EQ(cpp_classifier_options.display_names_locale, "de");
|
||||||
|
EXPECT_EQ(cpp_classifier_options.max_results, 1);
|
||||||
|
EXPECT_FLOAT_EQ(cpp_classifier_options.score_threshold, 0.1);
|
||||||
|
EXPECT_EQ(cpp_classifier_options.category_allowlist,
|
||||||
|
std::vector<std::string>{"fruit"});
|
||||||
|
EXPECT_EQ(cpp_classifier_options.category_denylist,
|
||||||
|
std::vector<std::string>{"veggies"});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ClassifierOptionsConverterTest, ConvertsClassifierOptionsDefaultValues) {
|
||||||
|
std::vector<const char*> category_allowlist = {kCategoryAllowlist};
|
||||||
|
std::vector<const char*> category_denylist = {kCategoryDenylist};
|
||||||
|
|
||||||
|
ClassifierOptions c_classifier_options = {/* display_names_locale= */ nullptr,
|
||||||
|
/* max_results= */ -1,
|
||||||
|
/* score_threshold= */ 0.0,
|
||||||
|
/* category_allowlist= */ nullptr,
|
||||||
|
/* category_allowlist_count= */ 0,
|
||||||
|
/* category_denylist= */ nullptr,
|
||||||
|
/* category_denylist_count= */ 0};
|
||||||
|
|
||||||
|
mediapipe::tasks::components::processors::ClassifierOptions
|
||||||
|
cpp_classifier_options = {};
|
||||||
|
|
||||||
|
CppConvertToClassifierOptions(c_classifier_options, &cpp_classifier_options);
|
||||||
|
EXPECT_EQ(cpp_classifier_options.display_names_locale, "en");
|
||||||
|
EXPECT_EQ(cpp_classifier_options.max_results, -1);
|
||||||
|
EXPECT_FLOAT_EQ(cpp_classifier_options.score_threshold, 0.0);
|
||||||
|
EXPECT_EQ(cpp_classifier_options.category_allowlist,
|
||||||
|
std::vector<std::string>{});
|
||||||
|
EXPECT_EQ(cpp_classifier_options.category_denylist,
|
||||||
|
std::vector<std::string>{});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::c::components::processors
|
|
@ -30,3 +30,15 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/core:base_options",
|
"//mediapipe/tasks/cc/core:base_options",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "base_options_converter_test",
|
||||||
|
srcs = ["base_options_converter_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":base_options",
|
||||||
|
":base_options_converter",
|
||||||
|
"//mediapipe/framework/port:gtest",
|
||||||
|
"//mediapipe/tasks/cc/core:base_options",
|
||||||
|
"@com_google_googletest//:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -23,10 +23,10 @@ extern "C" {
|
||||||
// Base options for MediaPipe C Tasks.
|
// Base options for MediaPipe C Tasks.
|
||||||
struct BaseOptions {
|
struct BaseOptions {
|
||||||
// The model asset file contents as a string.
|
// The model asset file contents as a string.
|
||||||
char* model_asset_buffer;
|
const char* model_asset_buffer;
|
||||||
|
|
||||||
// The path to the model asset to open and mmap in memory.
|
// The path to the model asset to open and mmap in memory.
|
||||||
char* model_asset_path;
|
const char* model_asset_path;
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
|
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/c/core/base_options.h"
|
#include "mediapipe/tasks/c/core/base_options.h"
|
||||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
|
|
||||||
namespace mediapipe::tasks::c::components::containers {
|
namespace mediapipe::tasks::c::core {
|
||||||
|
|
||||||
void CppConvertToBaseOptions(const BaseOptions& in,
|
void CppConvertToBaseOptions(const BaseOptions& in,
|
||||||
mediapipe::tasks::core::BaseOptions* out) {
|
mediapipe::tasks::core::BaseOptions* out) {
|
||||||
|
@ -33,4 +33,4 @@ void CppConvertToBaseOptions(const BaseOptions& in,
|
||||||
in.model_asset_path ? std::string(in.model_asset_path) : "";
|
in.model_asset_path ? std::string(in.model_asset_path) : "";
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mediapipe::tasks::c::components::containers
|
} // namespace mediapipe::tasks::c::core
|
||||||
|
|
|
@ -19,11 +19,11 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/c/core/base_options.h"
|
#include "mediapipe/tasks/c/core/base_options.h"
|
||||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
|
|
||||||
namespace mediapipe::tasks::c::components::containers {
|
namespace mediapipe::tasks::c::core {
|
||||||
|
|
||||||
void CppConvertToBaseOptions(const BaseOptions& in,
|
void CppConvertToBaseOptions(const BaseOptions& in,
|
||||||
mediapipe::tasks::core::BaseOptions* out);
|
mediapipe::tasks::core::BaseOptions* out);
|
||||||
|
|
||||||
} // namespace mediapipe::tasks::c::components::containers
|
} // namespace mediapipe::tasks::c::core
|
||||||
|
|
||||||
#endif // MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_
|
#endif // MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_
|
||||||
|
|
51
mediapipe/tasks/c/core/base_options_converter_test.cc
Normal file
51
mediapipe/tasks/c/core/base_options_converter_test.cc
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/c/core/base_options_converter.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/tasks/c/core/base_options.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::c::core {
|
||||||
|
|
||||||
|
constexpr char kAssetBuffer[] = "abc";
|
||||||
|
constexpr char kModelAssetPath[] = "abc.tflite";
|
||||||
|
|
||||||
|
TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetBuffer) {
|
||||||
|
BaseOptions c_base_options = {/* model_asset_buffer= */ kAssetBuffer,
|
||||||
|
/* model_asset_path= */ nullptr};
|
||||||
|
|
||||||
|
mediapipe::tasks::core::BaseOptions cpp_base_options = {};
|
||||||
|
|
||||||
|
CppConvertToBaseOptions(c_base_options, &cpp_base_options);
|
||||||
|
EXPECT_EQ(*cpp_base_options.model_asset_buffer, std::string{kAssetBuffer});
|
||||||
|
EXPECT_EQ(cpp_base_options.model_asset_path, "");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetPath) {
|
||||||
|
BaseOptions c_base_options = {/* model_asset_buffer= */ nullptr,
|
||||||
|
/* model_asset_path= */ kModelAssetPath};
|
||||||
|
|
||||||
|
mediapipe::tasks::core::BaseOptions cpp_base_options = {};
|
||||||
|
|
||||||
|
CppConvertToBaseOptions(c_base_options, &cpp_base_options);
|
||||||
|
EXPECT_EQ(cpp_base_options.model_asset_buffer.get(), nullptr);
|
||||||
|
EXPECT_EQ(cpp_base_options.model_asset_path, std::string{kModelAssetPath});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::c::core
|
|
@ -30,11 +30,11 @@ namespace mediapipe::tasks::c::text::text_classifier {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::mediapipe::tasks::c::components::containers::CppConvertToBaseOptions;
|
|
||||||
using ::mediapipe::tasks::c::components::containers::
|
using ::mediapipe::tasks::c::components::containers::
|
||||||
CppConvertToClassificationResult;
|
CppConvertToClassificationResult;
|
||||||
using ::mediapipe::tasks::c::components::processors::
|
using ::mediapipe::tasks::c::components::processors::
|
||||||
CppConvertToClassifierOptions;
|
CppConvertToClassifierOptions;
|
||||||
|
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
|
||||||
using ::mediapipe::tasks::text::text_classifier::TextClassifier;
|
using ::mediapipe::tasks::text::text_classifier::TextClassifier;
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user