Move LanguageDetectorResult converter to LanguageDetector task

PiperOrigin-RevId: 586812754
This commit is contained in:
Sebastian Schmidt 2023-11-30 15:56:39 -08:00 committed by Copybara-Service
parent 80e4e1599a
commit 3433ba083a
6 changed files with 46 additions and 46 deletions

View File

@ -212,26 +212,3 @@ cc_test(
"@com_google_googletest//:gtest_main", "@com_google_googletest//:gtest_main",
], ],
) )
cc_library(
name = "language_detection_result_converter",
srcs = ["language_detection_result_converter.cc"],
hdrs = ["language_detection_result_converter.h"],
deps = [
"//mediapipe/tasks/c/text/language_detector",
"//mediapipe/tasks/cc/text/language_detector",
],
)
cc_test(
name = "language_detection_result_converter_test",
srcs = ["language_detection_result_converter_test.cc"],
linkstatic = 1,
deps = [
":language_detection_result_converter",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/c/text/language_detector",
"//mediapipe/tasks/cc/text/language_detector",
"@com_google_googletest//:gtest_main",
],
)

View File

@ -22,7 +22,7 @@ cc_library(
hdrs = ["language_detector.h"], hdrs = ["language_detector.h"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
"//mediapipe/tasks/c/components/containers:language_detection_result_converter", ":language_detector_result_converter",
"//mediapipe/tasks/c/components/processors:classifier_options", "//mediapipe/tasks/c/components/processors:classifier_options",
"//mediapipe/tasks/c/components/processors:classifier_options_converter", "//mediapipe/tasks/c/components/processors:classifier_options_converter",
"//mediapipe/tasks/c/core:base_options", "//mediapipe/tasks/c/core:base_options",
@ -77,6 +77,29 @@ cc_library(
], ],
) )
cc_library(
name = "language_detector_result_converter",
srcs = ["language_detector_result_converter.cc"],
hdrs = ["language_detector_result_converter.h"],
deps = [
":language_detector",
"//mediapipe/tasks/cc/text/language_detector",
],
)
cc_test(
name = "language_detector_result_converter_test",
srcs = ["language_detector_result_converter_test.cc"],
linkstatic = 1,
deps = [
":language_detector",
":language_detector_result_converter",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/cc/text/language_detector",
"@com_google_googletest//:gtest_main",
],
)
cc_test( cc_test(
name = "language_detector_test", name = "language_detector_test",
srcs = ["language_detector_test.cc"], srcs = ["language_detector_test.cc"],

View File

@ -20,9 +20,9 @@ limitations under the License.
#include "absl/log/absl_log.h" #include "absl/log/absl_log.h"
#include "absl/status/status.h" #include "absl/status/status.h"
#include "mediapipe/tasks/c/components/containers/language_detection_result_converter.h"
#include "mediapipe/tasks/c/components/processors/classifier_options_converter.h" #include "mediapipe/tasks/c/components/processors/classifier_options_converter.h"
#include "mediapipe/tasks/c/core/base_options_converter.h" #include "mediapipe/tasks/c/core/base_options_converter.h"
#include "mediapipe/tasks/c/text/language_detector/language_detector_result_converter.h"
#include "mediapipe/tasks/cc/text/language_detector/language_detector.h" #include "mediapipe/tasks/cc/text/language_detector/language_detector.h"
namespace mediapipe::tasks::c::text::language_detector { namespace mediapipe::tasks::c::text::language_detector {
@ -30,9 +30,9 @@ namespace mediapipe::tasks::c::text::language_detector {
namespace { namespace {
using ::mediapipe::tasks::c::components::containers:: using ::mediapipe::tasks::c::components::containers::
CppCloseLanguageDetectionResult; CppCloseLanguageDetectorResult;
using ::mediapipe::tasks::c::components::containers:: using ::mediapipe::tasks::c::components::containers::
CppConvertToLanguageDetectionResult; CppConvertToLanguageDetectorResult;
using ::mediapipe::tasks::c::components::processors:: using ::mediapipe::tasks::c::components::processors::
CppConvertToClassifierOptions; CppConvertToClassifierOptions;
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions; using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
@ -72,16 +72,16 @@ int CppLanguageDetectorDetect(void* detector, const char* utf8_str,
auto cpp_detector = static_cast<LanguageDetector*>(detector); auto cpp_detector = static_cast<LanguageDetector*>(detector);
auto cpp_result = cpp_detector->Detect(utf8_str); auto cpp_result = cpp_detector->Detect(utf8_str);
if (!cpp_result.ok()) { if (!cpp_result.ok()) {
ABSL_LOG(ERROR) << "Language Detection failed: " << cpp_result.status(); ABSL_LOG(ERROR) << "Language Detector failed: " << cpp_result.status();
return CppProcessError(cpp_result.status(), error_msg); return CppProcessError(cpp_result.status(), error_msg);
} }
CppConvertToLanguageDetectionResult(*cpp_result, result); CppConvertToLanguageDetectorResult(*cpp_result, result);
return 0; return 0;
} }
void CppLanguageDetectorCloseResult(LanguageDetectorResult* result) { void CppLanguageDetectorCloseResult(LanguageDetectorResult* result) {
CppCloseLanguageDetectionResult(result); CppCloseLanguageDetectorResult(result);
} }
int CppLanguageDetectorClose(void* detector, char** error_msg) { int CppLanguageDetectorClose(void* detector, char** error_msg) {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mediapipe/tasks/c/components/containers/language_detection_result_converter.h" #include "mediapipe/tasks/c/text/language_detector/language_detector_result_converter.h"
#include <cstdint> #include <cstdint>
#include <cstdlib> #include <cstdlib>
@ -23,7 +23,7 @@ limitations under the License.
namespace mediapipe::tasks::c::components::containers { namespace mediapipe::tasks::c::components::containers {
void CppConvertToLanguageDetectionResult( void CppConvertToLanguageDetectorResult(
const mediapipe::tasks::text::language_detector::LanguageDetectorResult& in, const mediapipe::tasks::text::language_detector::LanguageDetectorResult& in,
LanguageDetectorResult* out) { LanguageDetectorResult* out) {
out->predictions_count = in.size(); out->predictions_count = in.size();
@ -42,7 +42,7 @@ void CppConvertToLanguageDetectionResult(
} }
} }
void CppCloseLanguageDetectionResult(LanguageDetectorResult* in) { void CppCloseLanguageDetectorResult(LanguageDetectorResult* in) {
for (uint32_t i = 0; i < in->predictions_count; ++i) { for (uint32_t i = 0; i < in->predictions_count; ++i) {
auto prediction_in = in->predictions[i]; auto prediction_in = in->predictions[i];

View File

@ -13,20 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTION_RESULT_CONVERTER_H_ #ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTOR_RESULT_CONVERTER_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTION_RESULT_CONVERTER_H_ #define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTOR_RESULT_CONVERTER_H_
#include "mediapipe/tasks/c/text/language_detector/language_detector.h" #include "mediapipe/tasks/c/text/language_detector/language_detector.h"
#include "mediapipe/tasks/cc/text/language_detector/language_detector.h" #include "mediapipe/tasks/cc/text/language_detector/language_detector.h"
namespace mediapipe::tasks::c::components::containers { namespace mediapipe::tasks::c::components::containers {
void CppConvertToLanguageDetectionResult( void CppConvertToLanguageDetectorResult(
const mediapipe::tasks::text::language_detector::LanguageDetectorResult& in, const mediapipe::tasks::text::language_detector::LanguageDetectorResult& in,
LanguageDetectorResult* out); LanguageDetectorResult* out);
void CppCloseLanguageDetectionResult(LanguageDetectorResult* in); void CppCloseLanguageDetectorResult(LanguageDetectorResult* in);
} // namespace mediapipe::tasks::c::components::containers } // namespace mediapipe::tasks::c::components::containers
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTION_RESULT_CONVERTER_H_ #endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTOR_RESULT_CONVERTER_H_

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mediapipe/tasks/c/components/containers/language_detection_result_converter.h" #include "mediapipe/tasks/c/text/language_detector/language_detector_result_converter.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/text/language_detector/language_detector.h" #include "mediapipe/tasks/c/text/language_detector/language_detector.h"
@ -21,8 +21,8 @@ limitations under the License.
namespace mediapipe::tasks::c::components::containers { namespace mediapipe::tasks::c::components::containers {
TEST(LanguageDetectionResultConverterTest, TEST(LanguageDetectorResultConverterTest,
ConvertsLanguageDetectionResultCustomResult) { ConvertsLanguageDetectorResultCustomResult) {
mediapipe::tasks::text::language_detector::LanguageDetectorResult mediapipe::tasks::text::language_detector::LanguageDetectorResult
cpp_detector_result = {{/* language_code= */ "fr", cpp_detector_result = {{/* language_code= */ "fr",
/* probability= */ 0.5}, /* probability= */ 0.5},
@ -30,24 +30,24 @@ TEST(LanguageDetectionResultConverterTest,
/* probability= */ 0.5}}; /* probability= */ 0.5}};
LanguageDetectorResult c_detector_result; LanguageDetectorResult c_detector_result;
CppConvertToLanguageDetectionResult(cpp_detector_result, &c_detector_result); CppConvertToLanguageDetectorResult(cpp_detector_result, &c_detector_result);
EXPECT_NE(c_detector_result.predictions, nullptr); EXPECT_NE(c_detector_result.predictions, nullptr);
EXPECT_EQ(c_detector_result.predictions_count, 2); EXPECT_EQ(c_detector_result.predictions_count, 2);
EXPECT_NE(c_detector_result.predictions[0].language_code, "fr"); EXPECT_NE(c_detector_result.predictions[0].language_code, "fr");
EXPECT_EQ(c_detector_result.predictions[0].probability, 0.5); EXPECT_EQ(c_detector_result.predictions[0].probability, 0.5);
CppCloseLanguageDetectionResult(&c_detector_result); CppCloseLanguageDetectorResult(&c_detector_result);
} }
TEST(LanguageDetectionResultConverterTest, FreesMemory) { TEST(LanguageDetectorResultConverterTest, FreesMemory) {
mediapipe::tasks::text::language_detector::LanguageDetectorResult mediapipe::tasks::text::language_detector::LanguageDetectorResult
cpp_detector_result = {{"fr", 0.5}}; cpp_detector_result = {{"fr", 0.5}};
LanguageDetectorResult c_detector_result; LanguageDetectorResult c_detector_result;
CppConvertToLanguageDetectionResult(cpp_detector_result, &c_detector_result); CppConvertToLanguageDetectorResult(cpp_detector_result, &c_detector_result);
EXPECT_NE(c_detector_result.predictions, nullptr); EXPECT_NE(c_detector_result.predictions, nullptr);
CppCloseLanguageDetectionResult(&c_detector_result); CppCloseLanguageDetectorResult(&c_detector_result);
EXPECT_EQ(c_detector_result.predictions, nullptr); EXPECT_EQ(c_detector_result.predictions, nullptr);
} }