Merge pull request #4860 from kinaryml:c-language-detector-api
PiperOrigin-RevId: 572385111
This commit is contained in:
commit
84f6959f9d
|
@ -98,3 +98,26 @@ cc_test(
|
|||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "language_detection_result_converter",
|
||||
srcs = ["language_detection_result_converter.cc"],
|
||||
hdrs = ["language_detection_result_converter.h"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/c/text/language_detector",
|
||||
"//mediapipe/tasks/cc/text/language_detector",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "language_detection_result_converter_test",
|
||||
srcs = ["language_detection_result_converter_test.cc"],
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
":language_detection_result_converter",
|
||||
"//mediapipe/framework/port:gtest",
|
||||
"//mediapipe/tasks/c/text/language_detector",
|
||||
"//mediapipe/tasks/cc/text/language_detector",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/c/components/containers/language_detection_result_converter.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "mediapipe/tasks/c/text/language_detector/language_detector.h"
|
||||
#include "mediapipe/tasks/cc/text/language_detector/language_detector.h"
|
||||
|
||||
namespace mediapipe::tasks::c::components::containers {
|
||||
|
||||
void CppConvertToLanguageDetectionResult(
|
||||
const mediapipe::tasks::text::language_detector::LanguageDetectorResult& in,
|
||||
LanguageDetectorResult* out) {
|
||||
out->predictions_count = in.size();
|
||||
out->predictions =
|
||||
out->predictions_count
|
||||
? new LanguageDetectorPrediction[out->predictions_count]
|
||||
: nullptr;
|
||||
|
||||
for (uint32_t i = 0; i < out->predictions_count; ++i) {
|
||||
auto language_detection_prediction_in = in[i];
|
||||
auto& language_detection_prediction_out = out->predictions[i];
|
||||
language_detection_prediction_out.probability =
|
||||
language_detection_prediction_in.probability;
|
||||
language_detection_prediction_out.language_code =
|
||||
strdup(language_detection_prediction_in.language_code.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
void CppCloseLanguageDetectionResult(LanguageDetectorResult* in) {
|
||||
for (uint32_t i = 0; i < in->predictions_count; ++i) {
|
||||
auto prediction_in = in->predictions[i];
|
||||
|
||||
free(prediction_in.language_code);
|
||||
prediction_in.language_code = nullptr;
|
||||
}
|
||||
delete[] in->predictions;
|
||||
in->predictions = nullptr;
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::c::components::containers
|
|
@ -0,0 +1,32 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTION_RESULT_CONVERTER_H_
|
||||
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTION_RESULT_CONVERTER_H_
|
||||
|
||||
#include "mediapipe/tasks/c/text/language_detector/language_detector.h"
|
||||
#include "mediapipe/tasks/cc/text/language_detector/language_detector.h"
|
||||
|
||||
namespace mediapipe::tasks::c::components::containers {
|
||||
|
||||
void CppConvertToLanguageDetectionResult(
|
||||
const mediapipe::tasks::text::language_detector::LanguageDetectorResult& in,
|
||||
LanguageDetectorResult* out);
|
||||
|
||||
void CppCloseLanguageDetectionResult(LanguageDetectorResult* in);
|
||||
|
||||
} // namespace mediapipe::tasks::c::components::containers
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTION_RESULT_CONVERTER_H_
|
|
@ -0,0 +1,54 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/c/components/containers/language_detection_result_converter.h"
|
||||
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/tasks/c/text/language_detector/language_detector.h"
|
||||
#include "mediapipe/tasks/cc/text/language_detector/language_detector.h"
|
||||
|
||||
namespace mediapipe::tasks::c::components::containers {
|
||||
|
||||
TEST(LanguageDetectionResultConverterTest,
|
||||
ConvertsLanguageDetectionResultCustomResult) {
|
||||
mediapipe::tasks::text::language_detector::LanguageDetectorResult
|
||||
cpp_detector_result = {{/* language_code= */ "fr",
|
||||
/* probability= */ 0.5},
|
||||
{/* language_code= */ "en",
|
||||
/* probability= */ 0.5}};
|
||||
|
||||
LanguageDetectorResult c_detector_result;
|
||||
CppConvertToLanguageDetectionResult(cpp_detector_result, &c_detector_result);
|
||||
EXPECT_NE(c_detector_result.predictions, nullptr);
|
||||
EXPECT_EQ(c_detector_result.predictions_count, 2);
|
||||
EXPECT_NE(c_detector_result.predictions[0].language_code, "fr");
|
||||
EXPECT_EQ(c_detector_result.predictions[0].probability, 0.5);
|
||||
|
||||
CppCloseLanguageDetectionResult(&c_detector_result);
|
||||
}
|
||||
|
||||
TEST(LanguageDetectionResultConverterTest, FreesMemory) {
|
||||
mediapipe::tasks::text::language_detector::LanguageDetectorResult
|
||||
cpp_detector_result = {{"fr", 0.5}};
|
||||
|
||||
LanguageDetectorResult c_detector_result;
|
||||
CppConvertToLanguageDetectionResult(cpp_detector_result, &c_detector_result);
|
||||
EXPECT_NE(c_detector_result.predictions, nullptr);
|
||||
|
||||
CppCloseLanguageDetectionResult(&c_detector_result);
|
||||
EXPECT_EQ(c_detector_result.predictions, nullptr);
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::c::components::containers
|
93
mediapipe/tasks/c/text/language_detector/BUILD
Normal file
93
mediapipe/tasks/c/text/language_detector/BUILD
Normal file
|
@ -0,0 +1,93 @@
|
|||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
cc_library(
|
||||
name = "language_detector_lib",
|
||||
srcs = ["language_detector.cc"],
|
||||
hdrs = ["language_detector.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/c/components/containers:language_detection_result_converter",
|
||||
"//mediapipe/tasks/c/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/c/components/processors:classifier_options_converter",
|
||||
"//mediapipe/tasks/c/core:base_options",
|
||||
"//mediapipe/tasks/c/core:base_options_converter",
|
||||
"//mediapipe/tasks/cc/text/language_detector",
|
||||
"@com_google_absl//absl/log:absl_log",
|
||||
"@com_google_absl//absl/status",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# bazel build -c opt --linkopt -s --strip always --define MEDIAPIPE_DISABLE_GPU=1 \
|
||||
# //mediapipe/tasks/c/text/language_detector:liblanguage_detector.so
|
||||
cc_binary(
|
||||
name = "liblanguage_detector.so",
|
||||
linkopts = [
|
||||
"-Wl,-soname=liblanguage_detector.so",
|
||||
"-fvisibility=hidden",
|
||||
],
|
||||
linkshared = True,
|
||||
tags = [
|
||||
"manual",
|
||||
"nobuilder",
|
||||
"notap",
|
||||
],
|
||||
deps = [":language_detector_lib"],
|
||||
)
|
||||
|
||||
# bazel build --config darwin_arm64 -c opt --strip always --define MEDIAPIPE_DISABLE_GPU=1 \
|
||||
# //mediapipe/tasks/c/text/language_detector:liblanguage_detector.dylib
|
||||
cc_binary(
|
||||
name = "liblanguage_detector.dylib",
|
||||
linkopts = [
|
||||
"-Wl,-install_name,liblanguage_detector.dylib",
|
||||
"-fvisibility=hidden",
|
||||
],
|
||||
linkshared = True,
|
||||
tags = [
|
||||
"manual",
|
||||
"nobuilder",
|
||||
"notap",
|
||||
],
|
||||
deps = [":language_detector_lib"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "language_detector",
|
||||
hdrs = ["language_detector.h"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/c/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/c/core:base_options",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "language_detector_test",
|
||||
srcs = ["language_detector_test.cc"],
|
||||
data = ["//mediapipe/tasks/testdata/text:language_detector"],
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
":language_detector_lib",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/port:gtest",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
124
mediapipe/tasks/c/text/language_detector/language_detector.cc
Normal file
124
mediapipe/tasks/c/text/language_detector/language_detector.cc
Normal file
|
@ -0,0 +1,124 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/c/text/language_detector/language_detector.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/log/absl_log.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "mediapipe/tasks/c/components/containers/language_detection_result_converter.h"
|
||||
#include "mediapipe/tasks/c/components/processors/classifier_options_converter.h"
|
||||
#include "mediapipe/tasks/c/core/base_options_converter.h"
|
||||
#include "mediapipe/tasks/cc/text/language_detector/language_detector.h"
|
||||
|
||||
namespace mediapipe::tasks::c::text::language_detector {
|
||||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::tasks::c::components::containers::
|
||||
CppCloseLanguageDetectionResult;
|
||||
using ::mediapipe::tasks::c::components::containers::
|
||||
CppConvertToLanguageDetectionResult;
|
||||
using ::mediapipe::tasks::c::components::processors::
|
||||
CppConvertToClassifierOptions;
|
||||
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
|
||||
using ::mediapipe::tasks::text::language_detector::LanguageDetector;
|
||||
|
||||
int CppProcessError(absl::Status status, char** error_msg) {
|
||||
if (error_msg) {
|
||||
*error_msg = strdup(status.ToString().c_str());
|
||||
}
|
||||
return status.raw_code();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LanguageDetector* CppLanguageDetectorCreate(
|
||||
const LanguageDetectorOptions& options, char** error_msg) {
|
||||
auto cpp_options = std::make_unique<
|
||||
::mediapipe::tasks::text::language_detector::LanguageDetectorOptions>();
|
||||
|
||||
CppConvertToBaseOptions(options.base_options, &cpp_options->base_options);
|
||||
CppConvertToClassifierOptions(options.classifier_options,
|
||||
&cpp_options->classifier_options);
|
||||
|
||||
auto detector = LanguageDetector::Create(std::move(cpp_options));
|
||||
if (!detector.ok()) {
|
||||
ABSL_LOG(ERROR) << "Failed to create LanguageDetector: "
|
||||
<< detector.status();
|
||||
CppProcessError(detector.status(), error_msg);
|
||||
return nullptr;
|
||||
}
|
||||
return detector->release();
|
||||
}
|
||||
|
||||
int CppLanguageDetectorDetect(void* detector, const char* utf8_str,
|
||||
LanguageDetectorResult* result,
|
||||
char** error_msg) {
|
||||
auto cpp_detector = static_cast<LanguageDetector*>(detector);
|
||||
auto cpp_result = cpp_detector->Detect(utf8_str);
|
||||
if (!cpp_result.ok()) {
|
||||
ABSL_LOG(ERROR) << "Language Detection failed: " << cpp_result.status();
|
||||
return CppProcessError(cpp_result.status(), error_msg);
|
||||
}
|
||||
|
||||
CppConvertToLanguageDetectionResult(*cpp_result, result);
|
||||
return 0;
|
||||
}
|
||||
|
||||
void CppLanguageDetectorCloseResult(LanguageDetectorResult* result) {
|
||||
CppCloseLanguageDetectionResult(result);
|
||||
}
|
||||
|
||||
int CppLanguageDetectorClose(void* detector, char** error_msg) {
|
||||
auto cpp_detector = static_cast<LanguageDetector*>(detector);
|
||||
auto result = cpp_detector->Close();
|
||||
if (!result.ok()) {
|
||||
ABSL_LOG(ERROR) << "Failed to close LanguageDetector: " << result;
|
||||
return CppProcessError(result, error_msg);
|
||||
}
|
||||
delete cpp_detector;
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::c::text::language_detector
|
||||
|
||||
extern "C" {
|
||||
|
||||
void* language_detector_create(struct LanguageDetectorOptions* options,
|
||||
char** error_msg) {
|
||||
return mediapipe::tasks::c::text::language_detector::
|
||||
CppLanguageDetectorCreate(*options, error_msg);
|
||||
}
|
||||
|
||||
int language_detector_detect(void* detector, const char* utf8_str,
|
||||
LanguageDetectorResult* result, char** error_msg) {
|
||||
return mediapipe::tasks::c::text::language_detector::
|
||||
CppLanguageDetectorDetect(detector, utf8_str, result, error_msg);
|
||||
}
|
||||
|
||||
void language_detector_close_result(LanguageDetectorResult* result) {
|
||||
mediapipe::tasks::c::text::language_detector::CppLanguageDetectorCloseResult(
|
||||
result);
|
||||
}
|
||||
|
||||
int language_detector_close(void* detector, char** error_ms) {
|
||||
return mediapipe::tasks::c::text::language_detector::CppLanguageDetectorClose(
|
||||
detector, error_ms);
|
||||
}
|
||||
|
||||
} // extern "C"
|
91
mediapipe/tasks/c/text/language_detector/language_detector.h
Normal file
91
mediapipe/tasks/c/text/language_detector/language_detector.h
Normal file
|
@ -0,0 +1,91 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_C_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_
|
||||
#define MEDIAPIPE_TASKS_C_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "mediapipe/tasks/c/components/processors/classifier_options.h"
|
||||
#include "mediapipe/tasks/c/core/base_options.h"
|
||||
|
||||
#ifndef MP_EXPORT
|
||||
#define MP_EXPORT __attribute__((visibility("default")))
|
||||
#endif // MP_EXPORT
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// A language code and its probability.
|
||||
struct LanguageDetectorPrediction {
|
||||
// An i18n language / locale code, e.g. "en" for English, "uz" for Uzbek,
|
||||
// "ja"-Latn for Japanese (romaji).
|
||||
char* language_code;
|
||||
|
||||
float probability;
|
||||
};
|
||||
|
||||
// Task output.
|
||||
struct LanguageDetectorResult {
|
||||
struct LanguageDetectorPrediction* predictions;
|
||||
|
||||
// The count of predictions.
|
||||
uint32_t predictions_count;
|
||||
};
|
||||
|
||||
// The options for configuring a MediaPipe language detector task.
|
||||
struct LanguageDetectorOptions {
|
||||
// Base options for configuring MediaPipe Tasks, such as specifying the model
|
||||
// file with metadata, accelerator options, op resolver, etc.
|
||||
struct BaseOptions base_options;
|
||||
|
||||
// Options for configuring the detector behavior, such as score threshold,
|
||||
// number of results, etc.
|
||||
struct ClassifierOptions classifier_options;
|
||||
};
|
||||
|
||||
// Creates a LanguageDetector from the provided `options`.
|
||||
// Returns a pointer to the language detector on success.
|
||||
// If an error occurs, returns `nullptr` and sets the error parameter to an
|
||||
// an error message (if `error_msg` is not nullptr). You must free the memory
|
||||
// allocated for the error message.
|
||||
MP_EXPORT void* language_detector_create(
|
||||
struct LanguageDetectorOptions* options, char** error_msg = nullptr);
|
||||
|
||||
// Performs language detection on the input `text`. Returns `0` on success.
|
||||
// If an error occurs, returns an error code and sets the error parameter to an
|
||||
// an error message (if `error_msg` is not nullptr). You must free the memory
|
||||
// allocated for the error message.
|
||||
MP_EXPORT int language_detector_detect(void* detector, const char* utf8_str,
|
||||
LanguageDetectorResult* result,
|
||||
char** error_msg = nullptr);
|
||||
|
||||
// Frees the memory allocated inside a LanguageDetectorResult result. Does not
|
||||
// free the result pointer itself.
|
||||
MP_EXPORT void language_detector_close_result(LanguageDetectorResult* result);
|
||||
|
||||
// Shuts down the LanguageDetector when all the work is done. Frees all memory.
|
||||
// If an error occurs, returns an error code and sets the error parameter to an
|
||||
// an error message (if `error_msg` is not nullptr). You must free the memory
|
||||
// allocated for the error message.
|
||||
MP_EXPORT int language_detector_close(void* detector,
|
||||
char** error_msg = nullptr);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern C
|
||||
#endif
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_C_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_
|
|
@ -0,0 +1,87 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/c/text/language_detector/language_detector.h"
|
||||
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using testing::HasSubstr;
|
||||
|
||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
|
||||
constexpr char kTestLanguageDetectorModelPath[] = "language_detector.tflite";
|
||||
constexpr char kTestString[] =
|
||||
"Il y a beaucoup de bouches qui parlent et fort peu "
|
||||
"de têtes qui pensent.";
|
||||
constexpr float kPrecision = 1e-6;
|
||||
|
||||
std::string GetFullPath(absl::string_view file_name) {
|
||||
return JoinPath("./", kTestDataDirectory, file_name);
|
||||
}
|
||||
|
||||
TEST(LanguageDetectorTest, SmokeTest) {
|
||||
std::string model_path = GetFullPath(kTestLanguageDetectorModelPath);
|
||||
LanguageDetectorOptions options = {
|
||||
/* base_options= */ {/* model_asset_buffer= */ nullptr,
|
||||
/* model_asset_path= */ model_path.c_str()},
|
||||
/* classifier_options= */
|
||||
{/* display_names_locale= */ nullptr,
|
||||
/* max_results= */ -1,
|
||||
/* score_threshold= */ 0.0,
|
||||
/* category_allowlist= */ nullptr,
|
||||
/* category_allowlist_count= */ 0,
|
||||
/* category_denylist= */ nullptr,
|
||||
/* category_denylist_count= */ 0},
|
||||
};
|
||||
|
||||
void* detector = language_detector_create(&options);
|
||||
EXPECT_NE(detector, nullptr);
|
||||
|
||||
LanguageDetectorResult result;
|
||||
language_detector_detect(detector, kTestString, &result);
|
||||
EXPECT_EQ(std::string(result.predictions[0].language_code), "fr");
|
||||
EXPECT_NEAR(result.predictions[0].probability, 0.999781, kPrecision);
|
||||
|
||||
language_detector_close_result(&result);
|
||||
language_detector_close(detector);
|
||||
}
|
||||
|
||||
TEST(LanguageDetectorTest, ErrorHandling) {
|
||||
// It is an error to set neither the asset buffer nor the path.
|
||||
LanguageDetectorOptions options = {
|
||||
/* base_options= */ {/* model_asset_buffer= */ nullptr,
|
||||
/* model_asset_path= */ nullptr},
|
||||
/* classifier_options= */ {},
|
||||
};
|
||||
|
||||
char* error_msg;
|
||||
void* detector = language_detector_create(&options, &error_msg);
|
||||
EXPECT_EQ(detector, nullptr);
|
||||
|
||||
EXPECT_THAT(error_msg, HasSubstr("INVALID_ARGUMENT"));
|
||||
|
||||
free(error_msg);
|
||||
}
|
||||
|
||||
} // namespace
|
Loading…
Reference in New Issue
Block a user