Add End to End test for Text Classifier C API

PiperOrigin-RevId: 569658768
This commit is contained in:
Sebastian Schmidt 2023-09-29 20:50:52 -07:00 committed by Copybara-Service
parent 96fa10b906
commit c7402efe5e
10 changed files with 139 additions and 9 deletions

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "mediapipe/tasks/c/components/containers/category_converter.h" #include "mediapipe/tasks/c/components/containers/category_converter.h"
#include <cstdlib>
#include "mediapipe/tasks/c/components/containers/category.h" #include "mediapipe/tasks/c/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/category.h" #include "mediapipe/tasks/cc/components/containers/category.h"
@ -32,4 +34,9 @@ void CppConvertToCategory(
in.display_name.has_value() ? strdup(in.display_name->c_str()) : nullptr; in.display_name.has_value() ? strdup(in.display_name->c_str()) : nullptr;
} }
void CppCloseCategory(Category* in) {
free(in->category_name);
free(in->display_name);
}
} // namespace mediapipe::tasks::c::components::containers } // namespace mediapipe::tasks::c::components::containers

View File

@ -25,6 +25,8 @@ void CppConvertToCategory(
const mediapipe::tasks::components::containers::Category& in, const mediapipe::tasks::components::containers::Category& in,
Category* out); Category* out);
void CppCloseCategory(Category* in);
} // namespace mediapipe::tasks::c::components::containers } // namespace mediapipe::tasks::c::components::containers
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_CONVERTER_H_ #endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_CONVERTER_H_

View File

@ -40,8 +40,7 @@ TEST(CategoryConverterTest, ConvertsCategoryCustomValues) {
EXPECT_EQ(std::string{c_category.category_name}, "category_name"); EXPECT_EQ(std::string{c_category.category_name}, "category_name");
EXPECT_EQ(std::string{c_category.display_name}, "display_name"); EXPECT_EQ(std::string{c_category.display_name}, "display_name");
free(c_category.category_name); CppCloseCategory(&c_category);
free(c_category.display_name);
} }
TEST(CategoryConverterTest, ConvertsCategoryDefaultValues) { TEST(CategoryConverterTest, ConvertsCategoryDefaultValues) {
@ -58,6 +57,8 @@ TEST(CategoryConverterTest, ConvertsCategoryDefaultValues) {
EXPECT_FLOAT_EQ(c_category.score, 0.1); EXPECT_FLOAT_EQ(c_category.score, 0.1);
EXPECT_EQ(c_category.category_name, nullptr); EXPECT_EQ(c_category.category_name, nullptr);
EXPECT_EQ(c_category.display_name, nullptr); EXPECT_EQ(c_category.display_name, nullptr);
CppCloseCategory(&c_category);
} }
} // namespace mediapipe::tasks::c::components::containers } // namespace mediapipe::tasks::c::components::containers

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "mediapipe/tasks/c/components/containers/classification_result_converter.h" #include "mediapipe/tasks/c/components/containers/classification_result_converter.h"
#include <cstdint> #include <cstdint>
#include <cstdlib>
#include "mediapipe/tasks/c/components/containers/category.h" #include "mediapipe/tasks/c/components/containers/category.h"
#include "mediapipe/tasks/c/components/containers/category_converter.h" #include "mediapipe/tasks/c/components/containers/category_converter.h"
@ -57,4 +58,19 @@ void CppConvertToClassificationResult(
} }
} }
void CppCloseClassificationResult(ClassificationResult* in) {
for (uint32_t i = 0; i < in->classifications_count; ++i) {
auto classification_in = in->classifications[i];
for (uint32_t j = 0; j < classification_in.categories_count; ++j) {
CppCloseCategory(&classification_in.categories[j]);
}
delete[] classification_in.categories;
free(classification_in.head_name);
}
delete[] in->classifications;
}
} // namespace mediapipe::tasks::c::components::containers } // namespace mediapipe::tasks::c::components::containers

View File

@ -25,6 +25,8 @@ void CppConvertToClassificationResult(
const mediapipe::tasks::components::containers::ClassificationResult& in, const mediapipe::tasks::components::containers::ClassificationResult& in,
ClassificationResult* out); ClassificationResult* out);
void CppCloseClassificationResult(ClassificationResult* in);
} // namespace mediapipe::tasks::c::components::containers } // namespace mediapipe::tasks::c::components::containers
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_CONVERTER_H_ #endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_CONVERTER_H_

View File

@ -53,9 +53,7 @@ TEST(ClassificationResultConverterTest,
EXPECT_EQ(c_classification_result.timestamp_ms, 42); EXPECT_EQ(c_classification_result.timestamp_ms, 42);
EXPECT_EQ(c_classification_result.has_timestamp_ms, true); EXPECT_EQ(c_classification_result.has_timestamp_ms, true);
free(c_classification_result.classifications[0].categories); CppCloseClassificationResult(&c_classification_result);
free(c_classification_result.classifications[0].head_name);
free(c_classification_result.classifications);
} }
TEST(ClassificationResultConverterTest, TEST(ClassificationResultConverterTest,
@ -79,7 +77,7 @@ TEST(ClassificationResultConverterTest,
EXPECT_EQ(c_classification_result.timestamp_ms, 0); EXPECT_EQ(c_classification_result.timestamp_ms, 0);
EXPECT_EQ(c_classification_result.has_timestamp_ms, false); EXPECT_EQ(c_classification_result.has_timestamp_ms, false);
free(c_classification_result.classifications); CppCloseClassificationResult(&c_classification_result);
} }
TEST(ClassificationResultConverterTest, TEST(ClassificationResultConverterTest,
@ -97,6 +95,8 @@ TEST(ClassificationResultConverterTest,
EXPECT_EQ(c_classification_result.classifications_count, 0); EXPECT_EQ(c_classification_result.classifications_count, 0);
EXPECT_EQ(c_classification_result.timestamp_ms, 0); EXPECT_EQ(c_classification_result.timestamp_ms, 0);
EXPECT_EQ(c_classification_result.has_timestamp_ms, false); EXPECT_EQ(c_classification_result.has_timestamp_ms, false);
CppCloseClassificationResult(&c_classification_result);
} }
} // namespace mediapipe::tasks::c::components::containers } // namespace mediapipe::tasks::c::components::containers

View File

@ -67,3 +67,20 @@ cc_binary(
], ],
deps = [":text_classifier_lib"], deps = [":text_classifier_lib"],
) )
cc_test(
name = "text_classifier_test",
srcs = ["text_classifier_test.cc"],
data = ["//mediapipe/tasks/testdata/text:bert_text_classifier_models"],
linkstatic = 1,
deps = [
":text_classifier_lib",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/c/components/containers:category",
"//mediapipe/tasks/cc/components/containers:category",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
)

View File

@ -30,6 +30,8 @@ namespace mediapipe::tasks::c::text::text_classifier {
namespace { namespace {
using ::mediapipe::tasks::c::components::containers::
CppCloseClassificationResult;
using ::mediapipe::tasks::c::components::containers:: using ::mediapipe::tasks::c::components::containers::
CppConvertToClassificationResult; CppConvertToClassificationResult;
using ::mediapipe::tasks::c::components::processors:: using ::mediapipe::tasks::c::components::processors::
@ -55,7 +57,7 @@ TextClassifier* CppTextClassifierCreate(const TextClassifierOptions& options) {
return classifier->release(); return classifier->release();
} }
bool CppTextClassifierClassify(void* classifier, char* utf8_str, bool CppTextClassifierClassify(void* classifier, const char* utf8_str,
TextClassifierResult* result) { TextClassifierResult* result) {
auto cpp_classifier = static_cast<TextClassifier*>(classifier); auto cpp_classifier = static_cast<TextClassifier*>(classifier);
auto cpp_result = cpp_classifier->Classify(utf8_str); auto cpp_result = cpp_classifier->Classify(utf8_str);
@ -67,6 +69,10 @@ bool CppTextClassifierClassify(void* classifier, char* utf8_str,
return true; return true;
} }
void CppTextClassifierCloseResult(TextClassifierResult* result) {
CppCloseClassificationResult(result);
}
void CppTextClassifierClose(void* classifier) { void CppTextClassifierClose(void* classifier) {
auto cpp_classifier = static_cast<TextClassifier*>(classifier); auto cpp_classifier = static_cast<TextClassifier*>(classifier);
auto result = cpp_classifier->Close(); auto result = cpp_classifier->Close();
@ -85,12 +91,17 @@ void* text_classifier_create(struct TextClassifierOptions* options) {
*options); *options);
} }
int text_classifier_classify(void* classifier, char* utf8_str, int text_classifier_classify(void* classifier, const char* utf8_str,
TextClassifierResult* result) { TextClassifierResult* result) {
return mediapipe::tasks::c::text::text_classifier::CppTextClassifierClassify( return mediapipe::tasks::c::text::text_classifier::CppTextClassifierClassify(
classifier, utf8_str, result); classifier, utf8_str, result);
} }
void text_classifier_close_result(TextClassifierResult* result) {
mediapipe::tasks::c::text::text_classifier::CppTextClassifierCloseResult(
result);
}
void text_classifier_close(void* classifier) { void text_classifier_close(void* classifier) {
mediapipe::tasks::c::text::text_classifier::CppTextClassifierClose( mediapipe::tasks::c::text::text_classifier::CppTextClassifierClose(
classifier); classifier);

View File

@ -45,9 +45,13 @@ struct TextClassifierOptions {
MP_EXPORT void* text_classifier_create(struct TextClassifierOptions* options); MP_EXPORT void* text_classifier_create(struct TextClassifierOptions* options);
// Performs classification on the input `text`. // Performs classification on the input `text`.
MP_EXPORT int text_classifier_classify(void* classifier, char* utf8_str, MP_EXPORT int text_classifier_classify(void* classifier, const char* utf8_str,
TextClassifierResult* result); TextClassifierResult* result);
// Frees the memory allocated inside a TextClassifierResult result. Does not
// free the result pointer itself.
MP_EXPORT void text_classifier_close_result(TextClassifierResult* result);
// Shuts down the TextClassifier when all the work is done. Frees all memory. // Shuts down the TextClassifier when all the work is done. Frees all memory.
MP_EXPORT void text_classifier_close(void* classifier); MP_EXPORT void text_classifier_close(void* classifier);

View File

@ -0,0 +1,70 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/text/text_classifier/text_classifier.h"
#include <string>
#include "absl/flags/flag.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/components/containers/category.h"
namespace {
using ::mediapipe::file::JoinPath;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite";
constexpr char kTestString[] = "It's beautiful outside.";
constexpr float kPrecision = 1e-6;
std::string GetFullPath(absl::string_view file_name) {
return JoinPath("./", kTestDataDirectory, file_name);
}
TEST(TextClassifierTest, SmokeTest) {
std::string model_path = GetFullPath(kTestBertModelPath);
TextClassifierOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_path= */ model_path.c_str()},
/* classifier_options= */
{/* display_names_locale= */ nullptr,
/* max_results= */ -1,
/* score_threshold= */ 0.0,
/* category_allowlist= */ nullptr,
/* category_allowlist_count= */ 0,
/* category_denylist= */ nullptr,
/* category_denylist_count= */ 0},
};
void* classifier = text_classifier_create(&options);
EXPECT_NE(classifier, nullptr);
TextClassifierResult result;
text_classifier_classify(classifier, kTestString, &result);
EXPECT_EQ(result.classifications_count, 1);
EXPECT_EQ(result.classifications[0].categories_count, 2);
EXPECT_EQ(std::string{result.classifications[0].categories[0].category_name},
"positive");
EXPECT_NEAR(result.classifications[0].categories[0].score, 0.999371,
kPrecision);
text_classifier_close_result(&result);
text_classifier_close(classifier);
}
} // namespace