Add End to End test for Text Classifier C API
PiperOrigin-RevId: 569658768
This commit is contained in:
parent
96fa10b906
commit
c7402efe5e
|
@ -15,6 +15,8 @@ limitations under the License.
|
|||
|
||||
#include "mediapipe/tasks/c/components/containers/category_converter.h"
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "mediapipe/tasks/c/components/containers/category.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||
|
||||
|
@ -32,4 +34,9 @@ void CppConvertToCategory(
|
|||
in.display_name.has_value() ? strdup(in.display_name->c_str()) : nullptr;
|
||||
}
|
||||
|
||||
void CppCloseCategory(Category* in) {
|
||||
free(in->category_name);
|
||||
free(in->display_name);
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::c::components::containers
|
||||
|
|
|
@ -25,6 +25,8 @@ void CppConvertToCategory(
|
|||
const mediapipe::tasks::components::containers::Category& in,
|
||||
Category* out);
|
||||
|
||||
void CppCloseCategory(Category* in);
|
||||
|
||||
} // namespace mediapipe::tasks::c::components::containers
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_CONVERTER_H_
|
||||
|
|
|
@ -40,8 +40,7 @@ TEST(CategoryConverterTest, ConvertsCategoryCustomValues) {
|
|||
EXPECT_EQ(std::string{c_category.category_name}, "category_name");
|
||||
EXPECT_EQ(std::string{c_category.display_name}, "display_name");
|
||||
|
||||
free(c_category.category_name);
|
||||
free(c_category.display_name);
|
||||
CppCloseCategory(&c_category);
|
||||
}
|
||||
|
||||
TEST(CategoryConverterTest, ConvertsCategoryDefaultValues) {
|
||||
|
@ -58,6 +57,8 @@ TEST(CategoryConverterTest, ConvertsCategoryDefaultValues) {
|
|||
EXPECT_FLOAT_EQ(c_category.score, 0.1);
|
||||
EXPECT_EQ(c_category.category_name, nullptr);
|
||||
EXPECT_EQ(c_category.display_name, nullptr);
|
||||
|
||||
CppCloseCategory(&c_category);
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::c::components::containers
|
||||
|
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/c/components/containers/classification_result_converter.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "mediapipe/tasks/c/components/containers/category.h"
|
||||
#include "mediapipe/tasks/c/components/containers/category_converter.h"
|
||||
|
@ -57,4 +58,19 @@ void CppConvertToClassificationResult(
|
|||
}
|
||||
}
|
||||
|
||||
void CppCloseClassificationResult(ClassificationResult* in) {
|
||||
for (uint32_t i = 0; i < in->classifications_count; ++i) {
|
||||
auto classification_in = in->classifications[i];
|
||||
|
||||
for (uint32_t j = 0; j < classification_in.categories_count; ++j) {
|
||||
CppCloseCategory(&classification_in.categories[j]);
|
||||
}
|
||||
delete[] classification_in.categories;
|
||||
|
||||
free(classification_in.head_name);
|
||||
}
|
||||
|
||||
delete[] in->classifications;
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::c::components::containers
|
||||
|
|
|
@ -25,6 +25,8 @@ void CppConvertToClassificationResult(
|
|||
const mediapipe::tasks::components::containers::ClassificationResult& in,
|
||||
ClassificationResult* out);
|
||||
|
||||
void CppCloseClassificationResult(ClassificationResult* in);
|
||||
|
||||
} // namespace mediapipe::tasks::c::components::containers
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_CONVERTER_H_
|
||||
|
|
|
@ -53,9 +53,7 @@ TEST(ClassificationResultConverterTest,
|
|||
EXPECT_EQ(c_classification_result.timestamp_ms, 42);
|
||||
EXPECT_EQ(c_classification_result.has_timestamp_ms, true);
|
||||
|
||||
free(c_classification_result.classifications[0].categories);
|
||||
free(c_classification_result.classifications[0].head_name);
|
||||
free(c_classification_result.classifications);
|
||||
CppCloseClassificationResult(&c_classification_result);
|
||||
}
|
||||
|
||||
TEST(ClassificationResultConverterTest,
|
||||
|
@ -79,7 +77,7 @@ TEST(ClassificationResultConverterTest,
|
|||
EXPECT_EQ(c_classification_result.timestamp_ms, 0);
|
||||
EXPECT_EQ(c_classification_result.has_timestamp_ms, false);
|
||||
|
||||
free(c_classification_result.classifications);
|
||||
CppCloseClassificationResult(&c_classification_result);
|
||||
}
|
||||
|
||||
TEST(ClassificationResultConverterTest,
|
||||
|
@ -97,6 +95,8 @@ TEST(ClassificationResultConverterTest,
|
|||
EXPECT_EQ(c_classification_result.classifications_count, 0);
|
||||
EXPECT_EQ(c_classification_result.timestamp_ms, 0);
|
||||
EXPECT_EQ(c_classification_result.has_timestamp_ms, false);
|
||||
|
||||
CppCloseClassificationResult(&c_classification_result);
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::c::components::containers
|
||||
|
|
|
@ -67,3 +67,20 @@ cc_binary(
|
|||
],
|
||||
deps = [":text_classifier_lib"],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "text_classifier_test",
|
||||
srcs = ["text_classifier_test.cc"],
|
||||
data = ["//mediapipe/tasks/testdata/text:bert_text_classifier_models"],
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
":text_classifier_lib",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/port:gtest",
|
||||
"//mediapipe/tasks/c/components/containers:category",
|
||||
"//mediapipe/tasks/cc/components/containers:category",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -30,6 +30,8 @@ namespace mediapipe::tasks::c::text::text_classifier {
|
|||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::tasks::c::components::containers::
|
||||
CppCloseClassificationResult;
|
||||
using ::mediapipe::tasks::c::components::containers::
|
||||
CppConvertToClassificationResult;
|
||||
using ::mediapipe::tasks::c::components::processors::
|
||||
|
@ -55,7 +57,7 @@ TextClassifier* CppTextClassifierCreate(const TextClassifierOptions& options) {
|
|||
return classifier->release();
|
||||
}
|
||||
|
||||
bool CppTextClassifierClassify(void* classifier, char* utf8_str,
|
||||
bool CppTextClassifierClassify(void* classifier, const char* utf8_str,
|
||||
TextClassifierResult* result) {
|
||||
auto cpp_classifier = static_cast<TextClassifier*>(classifier);
|
||||
auto cpp_result = cpp_classifier->Classify(utf8_str);
|
||||
|
@ -67,6 +69,10 @@ bool CppTextClassifierClassify(void* classifier, char* utf8_str,
|
|||
return true;
|
||||
}
|
||||
|
||||
void CppTextClassifierCloseResult(TextClassifierResult* result) {
|
||||
CppCloseClassificationResult(result);
|
||||
}
|
||||
|
||||
void CppTextClassifierClose(void* classifier) {
|
||||
auto cpp_classifier = static_cast<TextClassifier*>(classifier);
|
||||
auto result = cpp_classifier->Close();
|
||||
|
@ -85,12 +91,17 @@ void* text_classifier_create(struct TextClassifierOptions* options) {
|
|||
*options);
|
||||
}
|
||||
|
||||
int text_classifier_classify(void* classifier, char* utf8_str,
|
||||
int text_classifier_classify(void* classifier, const char* utf8_str,
|
||||
TextClassifierResult* result) {
|
||||
return mediapipe::tasks::c::text::text_classifier::CppTextClassifierClassify(
|
||||
classifier, utf8_str, result);
|
||||
}
|
||||
|
||||
void text_classifier_close_result(TextClassifierResult* result) {
|
||||
mediapipe::tasks::c::text::text_classifier::CppTextClassifierCloseResult(
|
||||
result);
|
||||
}
|
||||
|
||||
void text_classifier_close(void* classifier) {
|
||||
mediapipe::tasks::c::text::text_classifier::CppTextClassifierClose(
|
||||
classifier);
|
||||
|
|
|
@ -45,9 +45,13 @@ struct TextClassifierOptions {
|
|||
MP_EXPORT void* text_classifier_create(struct TextClassifierOptions* options);
|
||||
|
||||
// Performs classification on the input `text`.
|
||||
MP_EXPORT int text_classifier_classify(void* classifier, char* utf8_str,
|
||||
MP_EXPORT int text_classifier_classify(void* classifier, const char* utf8_str,
|
||||
TextClassifierResult* result);
|
||||
|
||||
// Frees the memory allocated inside a TextClassifierResult result. Does not
|
||||
// free the result pointer itself.
|
||||
MP_EXPORT void text_classifier_close_result(TextClassifierResult* result);
|
||||
|
||||
// Shuts down the TextClassifier when all the work is done. Frees all memory.
|
||||
MP_EXPORT void text_classifier_close(void* classifier);
|
||||
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/c/text/text_classifier/text_classifier.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/tasks/c/components/containers/category.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::file::JoinPath;
|
||||
|
||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
|
||||
constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite";
|
||||
constexpr char kTestString[] = "It's beautiful outside.";
|
||||
constexpr float kPrecision = 1e-6;
|
||||
|
||||
std::string GetFullPath(absl::string_view file_name) {
|
||||
return JoinPath("./", kTestDataDirectory, file_name);
|
||||
}
|
||||
|
||||
TEST(TextClassifierTest, SmokeTest) {
|
||||
std::string model_path = GetFullPath(kTestBertModelPath);
|
||||
TextClassifierOptions options = {
|
||||
/* base_options= */ {/* model_asset_buffer= */ nullptr,
|
||||
/* model_asset_path= */ model_path.c_str()},
|
||||
/* classifier_options= */
|
||||
{/* display_names_locale= */ nullptr,
|
||||
/* max_results= */ -1,
|
||||
/* score_threshold= */ 0.0,
|
||||
/* category_allowlist= */ nullptr,
|
||||
/* category_allowlist_count= */ 0,
|
||||
/* category_denylist= */ nullptr,
|
||||
/* category_denylist_count= */ 0},
|
||||
};
|
||||
|
||||
void* classifier = text_classifier_create(&options);
|
||||
EXPECT_NE(classifier, nullptr);
|
||||
|
||||
TextClassifierResult result;
|
||||
text_classifier_classify(classifier, kTestString, &result);
|
||||
EXPECT_EQ(result.classifications_count, 1);
|
||||
EXPECT_EQ(result.classifications[0].categories_count, 2);
|
||||
EXPECT_EQ(std::string{result.classifications[0].categories[0].category_name},
|
||||
"positive");
|
||||
EXPECT_NEAR(result.classifications[0].categories[0].score, 0.999371,
|
||||
kPrecision);
|
||||
|
||||
text_classifier_close_result(&result);
|
||||
text_classifier_close(classifier);
|
||||
}
|
||||
|
||||
} // namespace
|
Loading…
Reference in New Issue
Block a user