Add End to End test for Text Classifier C API
PiperOrigin-RevId: 569658768
This commit is contained in:
		
							parent
							
								
									96fa10b906
								
							
						
					
					
						commit
						c7402efe5e
					
				|  | @ -15,6 +15,8 @@ limitations under the License. | |||
| 
 | ||||
| #include "mediapipe/tasks/c/components/containers/category_converter.h" | ||||
| 
 | ||||
| #include <cstdlib> | ||||
| 
 | ||||
| #include "mediapipe/tasks/c/components/containers/category.h" | ||||
| #include "mediapipe/tasks/cc/components/containers/category.h" | ||||
| 
 | ||||
|  | @ -32,4 +34,9 @@ void CppConvertToCategory( | |||
|       in.display_name.has_value() ? strdup(in.display_name->c_str()) : nullptr; | ||||
| } | ||||
| 
 | ||||
| void CppCloseCategory(Category* in) { | ||||
|   free(in->category_name); | ||||
|   free(in->display_name); | ||||
| } | ||||
| 
 | ||||
| }  // namespace mediapipe::tasks::c::components::containers
 | ||||
|  |  | |||
|  | @ -25,6 +25,8 @@ void CppConvertToCategory( | |||
|     const mediapipe::tasks::components::containers::Category& in, | ||||
|     Category* out); | ||||
| 
 | ||||
| void CppCloseCategory(Category* in); | ||||
| 
 | ||||
| }  // namespace mediapipe::tasks::c::components::containers
 | ||||
| 
 | ||||
| #endif  // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_CONVERTER_H_
 | ||||
|  |  | |||
|  | @ -40,8 +40,7 @@ TEST(CategoryConverterTest, ConvertsCategoryCustomValues) { | |||
|   EXPECT_EQ(std::string{c_category.category_name}, "category_name"); | ||||
|   EXPECT_EQ(std::string{c_category.display_name}, "display_name"); | ||||
| 
 | ||||
|   free(c_category.category_name); | ||||
|   free(c_category.display_name); | ||||
|   CppCloseCategory(&c_category); | ||||
| } | ||||
| 
 | ||||
| TEST(CategoryConverterTest, ConvertsCategoryDefaultValues) { | ||||
|  | @ -58,6 +57,8 @@ TEST(CategoryConverterTest, ConvertsCategoryDefaultValues) { | |||
|   EXPECT_FLOAT_EQ(c_category.score, 0.1); | ||||
|   EXPECT_EQ(c_category.category_name, nullptr); | ||||
|   EXPECT_EQ(c_category.display_name, nullptr); | ||||
| 
 | ||||
|   CppCloseCategory(&c_category); | ||||
| } | ||||
| 
 | ||||
| }  // namespace mediapipe::tasks::c::components::containers
 | ||||
|  |  | |||
|  | @ -16,6 +16,7 @@ limitations under the License. | |||
| #include "mediapipe/tasks/c/components/containers/classification_result_converter.h" | ||||
| 
 | ||||
| #include <cstdint> | ||||
| #include <cstdlib> | ||||
| 
 | ||||
| #include "mediapipe/tasks/c/components/containers/category.h" | ||||
| #include "mediapipe/tasks/c/components/containers/category_converter.h" | ||||
|  | @ -57,4 +58,19 @@ void CppConvertToClassificationResult( | |||
|   } | ||||
| } | ||||
| 
 | ||||
| void CppCloseClassificationResult(ClassificationResult* in) { | ||||
|   for (uint32_t i = 0; i < in->classifications_count; ++i) { | ||||
|     auto classification_in = in->classifications[i]; | ||||
| 
 | ||||
|     for (uint32_t j = 0; j < classification_in.categories_count; ++j) { | ||||
|       CppCloseCategory(&classification_in.categories[j]); | ||||
|     } | ||||
|     delete[] classification_in.categories; | ||||
| 
 | ||||
|     free(classification_in.head_name); | ||||
|   } | ||||
| 
 | ||||
|   delete[] in->classifications; | ||||
| } | ||||
| 
 | ||||
| }  // namespace mediapipe::tasks::c::components::containers
 | ||||
|  |  | |||
|  | @ -25,6 +25,8 @@ void CppConvertToClassificationResult( | |||
|     const mediapipe::tasks::components::containers::ClassificationResult& in, | ||||
|     ClassificationResult* out); | ||||
| 
 | ||||
| void CppCloseClassificationResult(ClassificationResult* in); | ||||
| 
 | ||||
| }  // namespace mediapipe::tasks::c::components::containers
 | ||||
| 
 | ||||
| #endif  // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_CONVERTER_H_
 | ||||
|  |  | |||
|  | @ -53,9 +53,7 @@ TEST(ClassificationResultConverterTest, | |||
|   EXPECT_EQ(c_classification_result.timestamp_ms, 42); | ||||
|   EXPECT_EQ(c_classification_result.has_timestamp_ms, true); | ||||
| 
 | ||||
|   free(c_classification_result.classifications[0].categories); | ||||
|   free(c_classification_result.classifications[0].head_name); | ||||
|   free(c_classification_result.classifications); | ||||
|   CppCloseClassificationResult(&c_classification_result); | ||||
| } | ||||
| 
 | ||||
| TEST(ClassificationResultConverterTest, | ||||
|  | @ -79,7 +77,7 @@ TEST(ClassificationResultConverterTest, | |||
|   EXPECT_EQ(c_classification_result.timestamp_ms, 0); | ||||
|   EXPECT_EQ(c_classification_result.has_timestamp_ms, false); | ||||
| 
 | ||||
|   free(c_classification_result.classifications); | ||||
|   CppCloseClassificationResult(&c_classification_result); | ||||
| } | ||||
| 
 | ||||
| TEST(ClassificationResultConverterTest, | ||||
|  | @ -97,6 +95,8 @@ TEST(ClassificationResultConverterTest, | |||
|   EXPECT_EQ(c_classification_result.classifications_count, 0); | ||||
|   EXPECT_EQ(c_classification_result.timestamp_ms, 0); | ||||
|   EXPECT_EQ(c_classification_result.has_timestamp_ms, false); | ||||
| 
 | ||||
|   CppCloseClassificationResult(&c_classification_result); | ||||
| } | ||||
| 
 | ||||
| }  // namespace mediapipe::tasks::c::components::containers
 | ||||
|  |  | |||
|  | @ -67,3 +67,20 @@ cc_binary( | |||
|     ], | ||||
|     deps = [":text_classifier_lib"], | ||||
| ) | ||||
| 
 | ||||
| cc_test( | ||||
|     name = "text_classifier_test", | ||||
|     srcs = ["text_classifier_test.cc"], | ||||
|     data = ["//mediapipe/tasks/testdata/text:bert_text_classifier_models"], | ||||
|     linkstatic = 1, | ||||
|     deps = [ | ||||
|         ":text_classifier_lib", | ||||
|         "//mediapipe/framework/deps:file_path", | ||||
|         "//mediapipe/framework/port:gtest", | ||||
|         "//mediapipe/tasks/c/components/containers:category", | ||||
|         "//mediapipe/tasks/cc/components/containers:category", | ||||
|         "@com_google_absl//absl/flags:flag", | ||||
|         "@com_google_absl//absl/strings", | ||||
|         "@com_google_googletest//:gtest_main", | ||||
|     ], | ||||
| ) | ||||
|  |  | |||
|  | @ -30,6 +30,8 @@ namespace mediapipe::tasks::c::text::text_classifier { | |||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
| using ::mediapipe::tasks::c::components::containers:: | ||||
|     CppCloseClassificationResult; | ||||
| using ::mediapipe::tasks::c::components::containers:: | ||||
|     CppConvertToClassificationResult; | ||||
| using ::mediapipe::tasks::c::components::processors:: | ||||
|  | @ -55,7 +57,7 @@ TextClassifier* CppTextClassifierCreate(const TextClassifierOptions& options) { | |||
|   return classifier->release(); | ||||
| } | ||||
| 
 | ||||
| bool CppTextClassifierClassify(void* classifier, char* utf8_str, | ||||
| bool CppTextClassifierClassify(void* classifier, const char* utf8_str, | ||||
|                                TextClassifierResult* result) { | ||||
|   auto cpp_classifier = static_cast<TextClassifier*>(classifier); | ||||
|   auto cpp_result = cpp_classifier->Classify(utf8_str); | ||||
|  | @ -67,6 +69,10 @@ bool CppTextClassifierClassify(void* classifier, char* utf8_str, | |||
|   return true; | ||||
| } | ||||
| 
 | ||||
| void CppTextClassifierCloseResult(TextClassifierResult* result) { | ||||
|   CppCloseClassificationResult(result); | ||||
| } | ||||
| 
 | ||||
| void CppTextClassifierClose(void* classifier) { | ||||
|   auto cpp_classifier = static_cast<TextClassifier*>(classifier); | ||||
|   auto result = cpp_classifier->Close(); | ||||
|  | @ -85,12 +91,17 @@ void* text_classifier_create(struct TextClassifierOptions* options) { | |||
|       *options); | ||||
| } | ||||
| 
 | ||||
| int text_classifier_classify(void* classifier, char* utf8_str, | ||||
| int text_classifier_classify(void* classifier, const char* utf8_str, | ||||
|                              TextClassifierResult* result) { | ||||
|   return mediapipe::tasks::c::text::text_classifier::CppTextClassifierClassify( | ||||
|       classifier, utf8_str, result); | ||||
| } | ||||
| 
 | ||||
| void text_classifier_close_result(TextClassifierResult* result) { | ||||
|   mediapipe::tasks::c::text::text_classifier::CppTextClassifierCloseResult( | ||||
|       result); | ||||
| } | ||||
| 
 | ||||
| void text_classifier_close(void* classifier) { | ||||
|   mediapipe::tasks::c::text::text_classifier::CppTextClassifierClose( | ||||
|       classifier); | ||||
|  |  | |||
|  | @ -45,9 +45,13 @@ struct TextClassifierOptions { | |||
| MP_EXPORT void* text_classifier_create(struct TextClassifierOptions* options); | ||||
| 
 | ||||
| // Performs classification on the input `text`.
 | ||||
| MP_EXPORT int text_classifier_classify(void* classifier, char* utf8_str, | ||||
| MP_EXPORT int text_classifier_classify(void* classifier, const char* utf8_str, | ||||
|                                        TextClassifierResult* result); | ||||
| 
 | ||||
| // Frees the memory allocated inside a TextClassifierResult result. Does not
 | ||||
| // free the result pointer itself.
 | ||||
| MP_EXPORT void text_classifier_close_result(TextClassifierResult* result); | ||||
| 
 | ||||
| // Shuts down the TextClassifier when all the work is done. Frees all memory.
 | ||||
| MP_EXPORT void text_classifier_close(void* classifier); | ||||
| 
 | ||||
|  |  | |||
|  | @ -0,0 +1,70 @@ | |||
| /* Copyright 2023 The MediaPipe Authors.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "mediapipe/tasks/c/text/text_classifier/text_classifier.h" | ||||
| 
 | ||||
| #include <string> | ||||
| 
 | ||||
| #include "absl/flags/flag.h" | ||||
| #include "absl/strings/string_view.h" | ||||
| #include "mediapipe/framework/deps/file_path.h" | ||||
| #include "mediapipe/framework/port/gtest.h" | ||||
| #include "mediapipe/tasks/c/components/containers/category.h" | ||||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
| using ::mediapipe::file::JoinPath; | ||||
| 
 | ||||
| constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/"; | ||||
| constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite"; | ||||
| constexpr char kTestString[] = "It's beautiful outside."; | ||||
| constexpr float kPrecision = 1e-6; | ||||
| 
 | ||||
| std::string GetFullPath(absl::string_view file_name) { | ||||
|   return JoinPath("./", kTestDataDirectory, file_name); | ||||
| } | ||||
| 
 | ||||
| TEST(TextClassifierTest, SmokeTest) { | ||||
|   std::string model_path = GetFullPath(kTestBertModelPath); | ||||
|   TextClassifierOptions options = { | ||||
|       /* base_options= */ {/* model_asset_buffer= */ nullptr, | ||||
|                            /* model_asset_path= */ model_path.c_str()}, | ||||
|       /* classifier_options= */ | ||||
|       {/* display_names_locale= */ nullptr, | ||||
|        /* max_results= */ -1, | ||||
|        /* score_threshold= */ 0.0, | ||||
|        /* category_allowlist= */ nullptr, | ||||
|        /* category_allowlist_count= */ 0, | ||||
|        /* category_denylist= */ nullptr, | ||||
|        /* category_denylist_count= */ 0}, | ||||
|   }; | ||||
| 
 | ||||
|   void* classifier = text_classifier_create(&options); | ||||
|   EXPECT_NE(classifier, nullptr); | ||||
| 
 | ||||
|   TextClassifierResult result; | ||||
|   text_classifier_classify(classifier, kTestString, &result); | ||||
|   EXPECT_EQ(result.classifications_count, 1); | ||||
|   EXPECT_EQ(result.classifications[0].categories_count, 2); | ||||
|   EXPECT_EQ(std::string{result.classifications[0].categories[0].category_name}, | ||||
|             "positive"); | ||||
|   EXPECT_NEAR(result.classifications[0].categories[0].score, 0.999371, | ||||
|               kPrecision); | ||||
| 
 | ||||
|   text_classifier_close_result(&result); | ||||
|   text_classifier_close(classifier); | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user