Move LanguageDetectorResult converter to LanguageDetector task

PiperOrigin-RevId: 586812754
This commit is contained in:
Sebastian Schmidt 2023-11-30 15:56:39 -08:00 committed by Copybara-Service
parent 80e4e1599a
commit 3433ba083a
6 changed files with 46 additions and 46 deletions

View File

@ -212,26 +212,3 @@ cc_test(
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "language_detection_result_converter",
srcs = ["language_detection_result_converter.cc"],
hdrs = ["language_detection_result_converter.h"],
deps = [
"//mediapipe/tasks/c/text/language_detector",
"//mediapipe/tasks/cc/text/language_detector",
],
)
cc_test(
name = "language_detection_result_converter_test",
srcs = ["language_detection_result_converter_test.cc"],
linkstatic = 1,
deps = [
":language_detection_result_converter",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/c/text/language_detector",
"//mediapipe/tasks/cc/text/language_detector",
"@com_google_googletest//:gtest_main",
],
)

View File

@ -22,7 +22,7 @@ cc_library(
hdrs = ["language_detector.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/tasks/c/components/containers:language_detection_result_converter",
":language_detector_result_converter",
"//mediapipe/tasks/c/components/processors:classifier_options",
"//mediapipe/tasks/c/components/processors:classifier_options_converter",
"//mediapipe/tasks/c/core:base_options",
@ -77,6 +77,29 @@ cc_library(
],
)
cc_library(
name = "language_detector_result_converter",
srcs = ["language_detector_result_converter.cc"],
hdrs = ["language_detector_result_converter.h"],
deps = [
":language_detector",
"//mediapipe/tasks/cc/text/language_detector",
],
)
cc_test(
name = "language_detector_result_converter_test",
srcs = ["language_detector_result_converter_test.cc"],
linkstatic = 1,
deps = [
":language_detector",
":language_detector_result_converter",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/cc/text/language_detector",
"@com_google_googletest//:gtest_main",
],
)
cc_test(
name = "language_detector_test",
srcs = ["language_detector_test.cc"],

View File

@ -20,9 +20,9 @@ limitations under the License.
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "mediapipe/tasks/c/components/containers/language_detection_result_converter.h"
#include "mediapipe/tasks/c/components/processors/classifier_options_converter.h"
#include "mediapipe/tasks/c/core/base_options_converter.h"
#include "mediapipe/tasks/c/text/language_detector/language_detector_result_converter.h"
#include "mediapipe/tasks/cc/text/language_detector/language_detector.h"
namespace mediapipe::tasks::c::text::language_detector {
@ -30,9 +30,9 @@ namespace mediapipe::tasks::c::text::language_detector {
namespace {
using ::mediapipe::tasks::c::components::containers::
CppCloseLanguageDetectionResult;
CppCloseLanguageDetectorResult;
using ::mediapipe::tasks::c::components::containers::
CppConvertToLanguageDetectionResult;
CppConvertToLanguageDetectorResult;
using ::mediapipe::tasks::c::components::processors::
CppConvertToClassifierOptions;
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
@ -72,16 +72,16 @@ int CppLanguageDetectorDetect(void* detector, const char* utf8_str,
auto cpp_detector = static_cast<LanguageDetector*>(detector);
auto cpp_result = cpp_detector->Detect(utf8_str);
if (!cpp_result.ok()) {
ABSL_LOG(ERROR) << "Language Detection failed: " << cpp_result.status();
ABSL_LOG(ERROR) << "Language Detector failed: " << cpp_result.status();
return CppProcessError(cpp_result.status(), error_msg);
}
CppConvertToLanguageDetectionResult(*cpp_result, result);
CppConvertToLanguageDetectorResult(*cpp_result, result);
return 0;
}
void CppLanguageDetectorCloseResult(LanguageDetectorResult* result) {
CppCloseLanguageDetectionResult(result);
CppCloseLanguageDetectorResult(result);
}
int CppLanguageDetectorClose(void* detector, char** error_msg) {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/components/containers/language_detection_result_converter.h"
#include "mediapipe/tasks/c/text/language_detector/language_detector_result_converter.h"
#include <cstdint>
#include <cstdlib>
@ -23,7 +23,7 @@ limitations under the License.
namespace mediapipe::tasks::c::components::containers {
void CppConvertToLanguageDetectionResult(
void CppConvertToLanguageDetectorResult(
const mediapipe::tasks::text::language_detector::LanguageDetectorResult& in,
LanguageDetectorResult* out) {
out->predictions_count = in.size();
@ -42,7 +42,7 @@ void CppConvertToLanguageDetectionResult(
}
}
void CppCloseLanguageDetectionResult(LanguageDetectorResult* in) {
void CppCloseLanguageDetectorResult(LanguageDetectorResult* in) {
for (uint32_t i = 0; i < in->predictions_count; ++i) {
auto prediction_in = in->predictions[i];

View File

@ -13,20 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTION_RESULT_CONVERTER_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTION_RESULT_CONVERTER_H_
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTOR_RESULT_CONVERTER_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTOR_RESULT_CONVERTER_H_
#include "mediapipe/tasks/c/text/language_detector/language_detector.h"
#include "mediapipe/tasks/cc/text/language_detector/language_detector.h"
namespace mediapipe::tasks::c::components::containers {
void CppConvertToLanguageDetectionResult(
void CppConvertToLanguageDetectorResult(
const mediapipe::tasks::text::language_detector::LanguageDetectorResult& in,
LanguageDetectorResult* out);
void CppCloseLanguageDetectionResult(LanguageDetectorResult* in);
void CppCloseLanguageDetectorResult(LanguageDetectorResult* in);
} // namespace mediapipe::tasks::c::components::containers
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTION_RESULT_CONVERTER_H_
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTOR_RESULT_CONVERTER_H_

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/components/containers/language_detection_result_converter.h"
#include "mediapipe/tasks/c/text/language_detector/language_detector_result_converter.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/text/language_detector/language_detector.h"
@ -21,8 +21,8 @@ limitations under the License.
namespace mediapipe::tasks::c::components::containers {
TEST(LanguageDetectionResultConverterTest,
ConvertsLanguageDetectionResultCustomResult) {
TEST(LanguageDetectorResultConverterTest,
ConvertsLanguageDetectorResultCustomResult) {
mediapipe::tasks::text::language_detector::LanguageDetectorResult
cpp_detector_result = {{/* language_code= */ "fr",
/* probability= */ 0.5},
@ -30,24 +30,24 @@ TEST(LanguageDetectionResultConverterTest,
/* probability= */ 0.5}};
LanguageDetectorResult c_detector_result;
CppConvertToLanguageDetectionResult(cpp_detector_result, &c_detector_result);
CppConvertToLanguageDetectorResult(cpp_detector_result, &c_detector_result);
EXPECT_NE(c_detector_result.predictions, nullptr);
EXPECT_EQ(c_detector_result.predictions_count, 2);
EXPECT_NE(c_detector_result.predictions[0].language_code, "fr");
EXPECT_EQ(c_detector_result.predictions[0].probability, 0.5);
CppCloseLanguageDetectionResult(&c_detector_result);
CppCloseLanguageDetectorResult(&c_detector_result);
}
TEST(LanguageDetectionResultConverterTest, FreesMemory) {
TEST(LanguageDetectorResultConverterTest, FreesMemory) {
mediapipe::tasks::text::language_detector::LanguageDetectorResult
cpp_detector_result = {{"fr", 0.5}};
LanguageDetectorResult c_detector_result;
CppConvertToLanguageDetectionResult(cpp_detector_result, &c_detector_result);
CppConvertToLanguageDetectorResult(cpp_detector_result, &c_detector_result);
EXPECT_NE(c_detector_result.predictions, nullptr);
CppCloseLanguageDetectionResult(&c_detector_result);
CppCloseLanguageDetectorResult(&c_detector_result);
EXPECT_EQ(c_detector_result.predictions, nullptr);
}