From c7402efe5e57c0ab6368e194031feee3f531b343 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 29 Sep 2023 20:50:52 -0700 Subject: [PATCH] Add End to End test for Text Classifier C API PiperOrigin-RevId: 569658768 --- .../containers/category_converter.cc | 7 ++ .../containers/category_converter.h | 2 + .../containers/category_converter_test.cc | 5 +- .../classification_result_converter.cc | 16 +++++ .../classification_result_converter.h | 2 + .../classification_result_converter_test.cc | 8 +-- mediapipe/tasks/c/text/text_classifier/BUILD | 17 +++++ .../c/text/text_classifier/text_classifier.cc | 15 +++- .../c/text/text_classifier/text_classifier.h | 6 +- .../text_classifier/text_classifier_test.cc | 70 +++++++++++++++++++ 10 files changed, 139 insertions(+), 9 deletions(-) create mode 100644 mediapipe/tasks/c/text/text_classifier/text_classifier_test.cc diff --git a/mediapipe/tasks/c/components/containers/category_converter.cc b/mediapipe/tasks/c/components/containers/category_converter.cc index cb42ce0dc..a38ef10b0 100644 --- a/mediapipe/tasks/c/components/containers/category_converter.cc +++ b/mediapipe/tasks/c/components/containers/category_converter.cc @@ -15,6 +15,8 @@ limitations under the License. #include "mediapipe/tasks/c/components/containers/category_converter.h" +#include + #include "mediapipe/tasks/c/components/containers/category.h" #include "mediapipe/tasks/cc/components/containers/category.h" @@ -32,4 +34,9 @@ void CppConvertToCategory( in.display_name.has_value() ? strdup(in.display_name->c_str()) : nullptr; } +void CppCloseCategory(Category* in) { + free(in->category_name); + free(in->display_name); +} + } // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/category_converter.h b/mediapipe/tasks/c/components/containers/category_converter.h index c3e48f6e3..9edf539b2 100644 --- a/mediapipe/tasks/c/components/containers/category_converter.h +++ b/mediapipe/tasks/c/components/containers/category_converter.h @@ -25,6 +25,8 @@ void CppConvertToCategory( const mediapipe::tasks::components::containers::Category& in, Category* out); +void CppCloseCategory(Category* in); + } // namespace mediapipe::tasks::c::components::containers #endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_CONVERTER_H_ diff --git a/mediapipe/tasks/c/components/containers/category_converter_test.cc b/mediapipe/tasks/c/components/containers/category_converter_test.cc index 566bfa803..49c627ef0 100644 --- a/mediapipe/tasks/c/components/containers/category_converter_test.cc +++ b/mediapipe/tasks/c/components/containers/category_converter_test.cc @@ -40,8 +40,7 @@ TEST(CategoryConverterTest, ConvertsCategoryCustomValues) { EXPECT_EQ(std::string{c_category.category_name}, "category_name"); EXPECT_EQ(std::string{c_category.display_name}, "display_name"); - free(c_category.category_name); - free(c_category.display_name); + CppCloseCategory(&c_category); } TEST(CategoryConverterTest, ConvertsCategoryDefaultValues) { @@ -58,6 +57,8 @@ TEST(CategoryConverterTest, ConvertsCategoryDefaultValues) { EXPECT_FLOAT_EQ(c_category.score, 0.1); EXPECT_EQ(c_category.category_name, nullptr); EXPECT_EQ(c_category.display_name, nullptr); + + CppCloseCategory(&c_category); } } // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/classification_result_converter.cc b/mediapipe/tasks/c/components/containers/classification_result_converter.cc index a1e77ce48..64e62c309 100644 --- a/mediapipe/tasks/c/components/containers/classification_result_converter.cc +++ b/mediapipe/tasks/c/components/containers/classification_result_converter.cc @@ -16,6 +16,7 @@ limitations under the License. #include "mediapipe/tasks/c/components/containers/classification_result_converter.h" #include +#include #include "mediapipe/tasks/c/components/containers/category.h" #include "mediapipe/tasks/c/components/containers/category_converter.h" @@ -57,4 +58,19 @@ void CppConvertToClassificationResult( } } +void CppCloseClassificationResult(ClassificationResult* in) { + for (uint32_t i = 0; i < in->classifications_count; ++i) { + auto classification_in = in->classifications[i]; + + for (uint32_t j = 0; j < classification_in.categories_count; ++j) { + CppCloseCategory(&classification_in.categories[j]); + } + delete[] classification_in.categories; + + free(classification_in.head_name); + } + + delete[] in->classifications; +} + } // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/classification_result_converter.h b/mediapipe/tasks/c/components/containers/classification_result_converter.h index be4c745bc..2e84d019d 100644 --- a/mediapipe/tasks/c/components/containers/classification_result_converter.h +++ b/mediapipe/tasks/c/components/containers/classification_result_converter.h @@ -25,6 +25,8 @@ void CppConvertToClassificationResult( const mediapipe::tasks::components::containers::ClassificationResult& in, ClassificationResult* out); +void CppCloseClassificationResult(ClassificationResult* in); + } // namespace mediapipe::tasks::c::components::containers #endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_CONVERTER_H_ diff --git a/mediapipe/tasks/c/components/containers/classification_result_converter_test.cc b/mediapipe/tasks/c/components/containers/classification_result_converter_test.cc index 6166ffdba..59eb53aaf 100644 --- a/mediapipe/tasks/c/components/containers/classification_result_converter_test.cc +++ b/mediapipe/tasks/c/components/containers/classification_result_converter_test.cc @@ -53,9 +53,7 @@ TEST(ClassificationResultConverterTest, EXPECT_EQ(c_classification_result.timestamp_ms, 42); EXPECT_EQ(c_classification_result.has_timestamp_ms, true); - free(c_classification_result.classifications[0].categories); - free(c_classification_result.classifications[0].head_name); - free(c_classification_result.classifications); + CppCloseClassificationResult(&c_classification_result); } TEST(ClassificationResultConverterTest, @@ -79,7 +77,7 @@ TEST(ClassificationResultConverterTest, EXPECT_EQ(c_classification_result.timestamp_ms, 0); EXPECT_EQ(c_classification_result.has_timestamp_ms, false); - free(c_classification_result.classifications); + CppCloseClassificationResult(&c_classification_result); } TEST(ClassificationResultConverterTest, @@ -97,6 +95,8 @@ TEST(ClassificationResultConverterTest, EXPECT_EQ(c_classification_result.classifications_count, 0); EXPECT_EQ(c_classification_result.timestamp_ms, 0); EXPECT_EQ(c_classification_result.has_timestamp_ms, false); + + CppCloseClassificationResult(&c_classification_result); } } // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/text/text_classifier/BUILD b/mediapipe/tasks/c/text/text_classifier/BUILD index ca6936658..29e421567 100644 --- a/mediapipe/tasks/c/text/text_classifier/BUILD +++ b/mediapipe/tasks/c/text/text_classifier/BUILD @@ -67,3 +67,20 @@ cc_binary( ], deps = [":text_classifier_lib"], ) + +cc_test( + name = "text_classifier_test", + srcs = ["text_classifier_test.cc"], + data = ["//mediapipe/tasks/testdata/text:bert_text_classifier_models"], + linkstatic = 1, + deps = [ + ":text_classifier_lib", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest", + "//mediapipe/tasks/c/components/containers:category", + "//mediapipe/tasks/cc/components/containers:category", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/mediapipe/tasks/c/text/text_classifier/text_classifier.cc b/mediapipe/tasks/c/text/text_classifier/text_classifier.cc index 53fe34f75..cc3d44a69 100644 --- a/mediapipe/tasks/c/text/text_classifier/text_classifier.cc +++ b/mediapipe/tasks/c/text/text_classifier/text_classifier.cc @@ -30,6 +30,8 @@ namespace mediapipe::tasks::c::text::text_classifier { namespace { +using ::mediapipe::tasks::c::components::containers:: + CppCloseClassificationResult; using ::mediapipe::tasks::c::components::containers:: CppConvertToClassificationResult; using ::mediapipe::tasks::c::components::processors:: @@ -55,7 +57,7 @@ TextClassifier* CppTextClassifierCreate(const TextClassifierOptions& options) { return classifier->release(); } -bool CppTextClassifierClassify(void* classifier, char* utf8_str, +bool CppTextClassifierClassify(void* classifier, const char* utf8_str, TextClassifierResult* result) { auto cpp_classifier = static_cast(classifier); auto cpp_result = cpp_classifier->Classify(utf8_str); @@ -67,6 +69,10 @@ bool CppTextClassifierClassify(void* classifier, char* utf8_str, return true; } +void CppTextClassifierCloseResult(TextClassifierResult* result) { + CppCloseClassificationResult(result); +} + void CppTextClassifierClose(void* classifier) { auto cpp_classifier = static_cast(classifier); auto result = cpp_classifier->Close(); @@ -85,12 +91,17 @@ void* text_classifier_create(struct TextClassifierOptions* options) { *options); } -int text_classifier_classify(void* classifier, char* utf8_str, +int text_classifier_classify(void* classifier, const char* utf8_str, TextClassifierResult* result) { return mediapipe::tasks::c::text::text_classifier::CppTextClassifierClassify( classifier, utf8_str, result); } +void text_classifier_close_result(TextClassifierResult* result) { + mediapipe::tasks::c::text::text_classifier::CppTextClassifierCloseResult( + result); +} + void text_classifier_close(void* classifier) { mediapipe::tasks::c::text::text_classifier::CppTextClassifierClose( classifier); diff --git a/mediapipe/tasks/c/text/text_classifier/text_classifier.h b/mediapipe/tasks/c/text/text_classifier/text_classifier.h index 1ba140883..7794eb0b3 100644 --- a/mediapipe/tasks/c/text/text_classifier/text_classifier.h +++ b/mediapipe/tasks/c/text/text_classifier/text_classifier.h @@ -45,9 +45,13 @@ struct TextClassifierOptions { MP_EXPORT void* text_classifier_create(struct TextClassifierOptions* options); // Performs classification on the input `text`. -MP_EXPORT int text_classifier_classify(void* classifier, char* utf8_str, +MP_EXPORT int text_classifier_classify(void* classifier, const char* utf8_str, TextClassifierResult* result); +// Frees the memory allocated inside a TextClassifierResult result. Does not +// free the result pointer itself. +MP_EXPORT void text_classifier_close_result(TextClassifierResult* result); + // Shuts down the TextClassifier when all the work is done. Frees all memory. MP_EXPORT void text_classifier_close(void* classifier); diff --git a/mediapipe/tasks/c/text/text_classifier/text_classifier_test.cc b/mediapipe/tasks/c/text/text_classifier/text_classifier_test.cc new file mode 100644 index 000000000..e3815dd5f --- /dev/null +++ b/mediapipe/tasks/c/text/text_classifier/text_classifier_test.cc @@ -0,0 +1,70 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/text/text_classifier/text_classifier.h" + +#include + +#include "absl/flags/flag.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/c/components/containers/category.h" + +namespace { + +using ::mediapipe::file::JoinPath; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/"; +constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite"; +constexpr char kTestString[] = "It's beautiful outside."; +constexpr float kPrecision = 1e-6; + +std::string GetFullPath(absl::string_view file_name) { + return JoinPath("./", kTestDataDirectory, file_name); +} + +TEST(TextClassifierTest, SmokeTest) { + std::string model_path = GetFullPath(kTestBertModelPath); + TextClassifierOptions options = { + /* base_options= */ {/* model_asset_buffer= */ nullptr, + /* model_asset_path= */ model_path.c_str()}, + /* classifier_options= */ + {/* display_names_locale= */ nullptr, + /* max_results= */ -1, + /* score_threshold= */ 0.0, + /* category_allowlist= */ nullptr, + /* category_allowlist_count= */ 0, + /* category_denylist= */ nullptr, + /* category_denylist_count= */ 0}, + }; + + void* classifier = text_classifier_create(&options); + EXPECT_NE(classifier, nullptr); + + TextClassifierResult result; + text_classifier_classify(classifier, kTestString, &result); + EXPECT_EQ(result.classifications_count, 1); + EXPECT_EQ(result.classifications[0].categories_count, 2); + EXPECT_EQ(std::string{result.classifications[0].categories[0].category_name}, + "positive"); + EXPECT_NEAR(result.classifications[0].categories[0].score, 0.999371, + kPrecision); + + text_classifier_close_result(&result); + text_classifier_close(classifier); +} + +} // namespace