Add End to End test for Text Classifier C API
PiperOrigin-RevId: 569658768
This commit is contained in:
parent
96fa10b906
commit
c7402efe5e
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||||
|
|
||||||
#include "mediapipe/tasks/c/components/containers/category_converter.h"
|
#include "mediapipe/tasks/c/components/containers/category_converter.h"
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
#include "mediapipe/tasks/c/components/containers/category.h"
|
#include "mediapipe/tasks/c/components/containers/category.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/category.h"
|
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||||
|
|
||||||
|
@ -32,4 +34,9 @@ void CppConvertToCategory(
|
||||||
in.display_name.has_value() ? strdup(in.display_name->c_str()) : nullptr;
|
in.display_name.has_value() ? strdup(in.display_name->c_str()) : nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CppCloseCategory(Category* in) {
|
||||||
|
free(in->category_name);
|
||||||
|
free(in->display_name);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mediapipe::tasks::c::components::containers
|
} // namespace mediapipe::tasks::c::components::containers
|
||||||
|
|
|
@ -25,6 +25,8 @@ void CppConvertToCategory(
|
||||||
const mediapipe::tasks::components::containers::Category& in,
|
const mediapipe::tasks::components::containers::Category& in,
|
||||||
Category* out);
|
Category* out);
|
||||||
|
|
||||||
|
void CppCloseCategory(Category* in);
|
||||||
|
|
||||||
} // namespace mediapipe::tasks::c::components::containers
|
} // namespace mediapipe::tasks::c::components::containers
|
||||||
|
|
||||||
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_CONVERTER_H_
|
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_CONVERTER_H_
|
||||||
|
|
|
@ -40,8 +40,7 @@ TEST(CategoryConverterTest, ConvertsCategoryCustomValues) {
|
||||||
EXPECT_EQ(std::string{c_category.category_name}, "category_name");
|
EXPECT_EQ(std::string{c_category.category_name}, "category_name");
|
||||||
EXPECT_EQ(std::string{c_category.display_name}, "display_name");
|
EXPECT_EQ(std::string{c_category.display_name}, "display_name");
|
||||||
|
|
||||||
free(c_category.category_name);
|
CppCloseCategory(&c_category);
|
||||||
free(c_category.display_name);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CategoryConverterTest, ConvertsCategoryDefaultValues) {
|
TEST(CategoryConverterTest, ConvertsCategoryDefaultValues) {
|
||||||
|
@ -58,6 +57,8 @@ TEST(CategoryConverterTest, ConvertsCategoryDefaultValues) {
|
||||||
EXPECT_FLOAT_EQ(c_category.score, 0.1);
|
EXPECT_FLOAT_EQ(c_category.score, 0.1);
|
||||||
EXPECT_EQ(c_category.category_name, nullptr);
|
EXPECT_EQ(c_category.category_name, nullptr);
|
||||||
EXPECT_EQ(c_category.display_name, nullptr);
|
EXPECT_EQ(c_category.display_name, nullptr);
|
||||||
|
|
||||||
|
CppCloseCategory(&c_category);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mediapipe::tasks::c::components::containers
|
} // namespace mediapipe::tasks::c::components::containers
|
||||||
|
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/c/components/containers/classification_result_converter.h"
|
#include "mediapipe/tasks/c/components/containers/classification_result_converter.h"
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
#include "mediapipe/tasks/c/components/containers/category.h"
|
#include "mediapipe/tasks/c/components/containers/category.h"
|
||||||
#include "mediapipe/tasks/c/components/containers/category_converter.h"
|
#include "mediapipe/tasks/c/components/containers/category_converter.h"
|
||||||
|
@ -57,4 +58,19 @@ void CppConvertToClassificationResult(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CppCloseClassificationResult(ClassificationResult* in) {
|
||||||
|
for (uint32_t i = 0; i < in->classifications_count; ++i) {
|
||||||
|
auto classification_in = in->classifications[i];
|
||||||
|
|
||||||
|
for (uint32_t j = 0; j < classification_in.categories_count; ++j) {
|
||||||
|
CppCloseCategory(&classification_in.categories[j]);
|
||||||
|
}
|
||||||
|
delete[] classification_in.categories;
|
||||||
|
|
||||||
|
free(classification_in.head_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
delete[] in->classifications;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mediapipe::tasks::c::components::containers
|
} // namespace mediapipe::tasks::c::components::containers
|
||||||
|
|
|
@ -25,6 +25,8 @@ void CppConvertToClassificationResult(
|
||||||
const mediapipe::tasks::components::containers::ClassificationResult& in,
|
const mediapipe::tasks::components::containers::ClassificationResult& in,
|
||||||
ClassificationResult* out);
|
ClassificationResult* out);
|
||||||
|
|
||||||
|
void CppCloseClassificationResult(ClassificationResult* in);
|
||||||
|
|
||||||
} // namespace mediapipe::tasks::c::components::containers
|
} // namespace mediapipe::tasks::c::components::containers
|
||||||
|
|
||||||
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_CONVERTER_H_
|
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_CONVERTER_H_
|
||||||
|
|
|
@ -53,9 +53,7 @@ TEST(ClassificationResultConverterTest,
|
||||||
EXPECT_EQ(c_classification_result.timestamp_ms, 42);
|
EXPECT_EQ(c_classification_result.timestamp_ms, 42);
|
||||||
EXPECT_EQ(c_classification_result.has_timestamp_ms, true);
|
EXPECT_EQ(c_classification_result.has_timestamp_ms, true);
|
||||||
|
|
||||||
free(c_classification_result.classifications[0].categories);
|
CppCloseClassificationResult(&c_classification_result);
|
||||||
free(c_classification_result.classifications[0].head_name);
|
|
||||||
free(c_classification_result.classifications);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ClassificationResultConverterTest,
|
TEST(ClassificationResultConverterTest,
|
||||||
|
@ -79,7 +77,7 @@ TEST(ClassificationResultConverterTest,
|
||||||
EXPECT_EQ(c_classification_result.timestamp_ms, 0);
|
EXPECT_EQ(c_classification_result.timestamp_ms, 0);
|
||||||
EXPECT_EQ(c_classification_result.has_timestamp_ms, false);
|
EXPECT_EQ(c_classification_result.has_timestamp_ms, false);
|
||||||
|
|
||||||
free(c_classification_result.classifications);
|
CppCloseClassificationResult(&c_classification_result);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ClassificationResultConverterTest,
|
TEST(ClassificationResultConverterTest,
|
||||||
|
@ -97,6 +95,8 @@ TEST(ClassificationResultConverterTest,
|
||||||
EXPECT_EQ(c_classification_result.classifications_count, 0);
|
EXPECT_EQ(c_classification_result.classifications_count, 0);
|
||||||
EXPECT_EQ(c_classification_result.timestamp_ms, 0);
|
EXPECT_EQ(c_classification_result.timestamp_ms, 0);
|
||||||
EXPECT_EQ(c_classification_result.has_timestamp_ms, false);
|
EXPECT_EQ(c_classification_result.has_timestamp_ms, false);
|
||||||
|
|
||||||
|
CppCloseClassificationResult(&c_classification_result);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mediapipe::tasks::c::components::containers
|
} // namespace mediapipe::tasks::c::components::containers
|
||||||
|
|
|
@ -67,3 +67,20 @@ cc_binary(
|
||||||
],
|
],
|
||||||
deps = [":text_classifier_lib"],
|
deps = [":text_classifier_lib"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "text_classifier_test",
|
||||||
|
srcs = ["text_classifier_test.cc"],
|
||||||
|
data = ["//mediapipe/tasks/testdata/text:bert_text_classifier_models"],
|
||||||
|
linkstatic = 1,
|
||||||
|
deps = [
|
||||||
|
":text_classifier_lib",
|
||||||
|
"//mediapipe/framework/deps:file_path",
|
||||||
|
"//mediapipe/framework/port:gtest",
|
||||||
|
"//mediapipe/tasks/c/components/containers:category",
|
||||||
|
"//mediapipe/tasks/cc/components/containers:category",
|
||||||
|
"@com_google_absl//absl/flags:flag",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_googletest//:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -30,6 +30,8 @@ namespace mediapipe::tasks::c::text::text_classifier {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::tasks::c::components::containers::
|
||||||
|
CppCloseClassificationResult;
|
||||||
using ::mediapipe::tasks::c::components::containers::
|
using ::mediapipe::tasks::c::components::containers::
|
||||||
CppConvertToClassificationResult;
|
CppConvertToClassificationResult;
|
||||||
using ::mediapipe::tasks::c::components::processors::
|
using ::mediapipe::tasks::c::components::processors::
|
||||||
|
@ -55,7 +57,7 @@ TextClassifier* CppTextClassifierCreate(const TextClassifierOptions& options) {
|
||||||
return classifier->release();
|
return classifier->release();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CppTextClassifierClassify(void* classifier, char* utf8_str,
|
bool CppTextClassifierClassify(void* classifier, const char* utf8_str,
|
||||||
TextClassifierResult* result) {
|
TextClassifierResult* result) {
|
||||||
auto cpp_classifier = static_cast<TextClassifier*>(classifier);
|
auto cpp_classifier = static_cast<TextClassifier*>(classifier);
|
||||||
auto cpp_result = cpp_classifier->Classify(utf8_str);
|
auto cpp_result = cpp_classifier->Classify(utf8_str);
|
||||||
|
@ -67,6 +69,10 @@ bool CppTextClassifierClassify(void* classifier, char* utf8_str,
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CppTextClassifierCloseResult(TextClassifierResult* result) {
|
||||||
|
CppCloseClassificationResult(result);
|
||||||
|
}
|
||||||
|
|
||||||
void CppTextClassifierClose(void* classifier) {
|
void CppTextClassifierClose(void* classifier) {
|
||||||
auto cpp_classifier = static_cast<TextClassifier*>(classifier);
|
auto cpp_classifier = static_cast<TextClassifier*>(classifier);
|
||||||
auto result = cpp_classifier->Close();
|
auto result = cpp_classifier->Close();
|
||||||
|
@ -85,12 +91,17 @@ void* text_classifier_create(struct TextClassifierOptions* options) {
|
||||||
*options);
|
*options);
|
||||||
}
|
}
|
||||||
|
|
||||||
int text_classifier_classify(void* classifier, char* utf8_str,
|
int text_classifier_classify(void* classifier, const char* utf8_str,
|
||||||
TextClassifierResult* result) {
|
TextClassifierResult* result) {
|
||||||
return mediapipe::tasks::c::text::text_classifier::CppTextClassifierClassify(
|
return mediapipe::tasks::c::text::text_classifier::CppTextClassifierClassify(
|
||||||
classifier, utf8_str, result);
|
classifier, utf8_str, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void text_classifier_close_result(TextClassifierResult* result) {
|
||||||
|
mediapipe::tasks::c::text::text_classifier::CppTextClassifierCloseResult(
|
||||||
|
result);
|
||||||
|
}
|
||||||
|
|
||||||
void text_classifier_close(void* classifier) {
|
void text_classifier_close(void* classifier) {
|
||||||
mediapipe::tasks::c::text::text_classifier::CppTextClassifierClose(
|
mediapipe::tasks::c::text::text_classifier::CppTextClassifierClose(
|
||||||
classifier);
|
classifier);
|
||||||
|
|
|
@ -45,9 +45,13 @@ struct TextClassifierOptions {
|
||||||
MP_EXPORT void* text_classifier_create(struct TextClassifierOptions* options);
|
MP_EXPORT void* text_classifier_create(struct TextClassifierOptions* options);
|
||||||
|
|
||||||
// Performs classification on the input `text`.
|
// Performs classification on the input `text`.
|
||||||
MP_EXPORT int text_classifier_classify(void* classifier, char* utf8_str,
|
MP_EXPORT int text_classifier_classify(void* classifier, const char* utf8_str,
|
||||||
TextClassifierResult* result);
|
TextClassifierResult* result);
|
||||||
|
|
||||||
|
// Frees the memory allocated inside a TextClassifierResult result. Does not
|
||||||
|
// free the result pointer itself.
|
||||||
|
MP_EXPORT void text_classifier_close_result(TextClassifierResult* result);
|
||||||
|
|
||||||
// Shuts down the TextClassifier when all the work is done. Frees all memory.
|
// Shuts down the TextClassifier when all the work is done. Frees all memory.
|
||||||
MP_EXPORT void text_classifier_close(void* classifier);
|
MP_EXPORT void text_classifier_close(void* classifier);
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,70 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/c/text/text_classifier/text_classifier.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/flags/flag.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "mediapipe/framework/deps/file_path.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/tasks/c/components/containers/category.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::file::JoinPath;
|
||||||
|
|
||||||
|
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
|
||||||
|
constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite";
|
||||||
|
constexpr char kTestString[] = "It's beautiful outside.";
|
||||||
|
constexpr float kPrecision = 1e-6;
|
||||||
|
|
||||||
|
std::string GetFullPath(absl::string_view file_name) {
|
||||||
|
return JoinPath("./", kTestDataDirectory, file_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TextClassifierTest, SmokeTest) {
|
||||||
|
std::string model_path = GetFullPath(kTestBertModelPath);
|
||||||
|
TextClassifierOptions options = {
|
||||||
|
/* base_options= */ {/* model_asset_buffer= */ nullptr,
|
||||||
|
/* model_asset_path= */ model_path.c_str()},
|
||||||
|
/* classifier_options= */
|
||||||
|
{/* display_names_locale= */ nullptr,
|
||||||
|
/* max_results= */ -1,
|
||||||
|
/* score_threshold= */ 0.0,
|
||||||
|
/* category_allowlist= */ nullptr,
|
||||||
|
/* category_allowlist_count= */ 0,
|
||||||
|
/* category_denylist= */ nullptr,
|
||||||
|
/* category_denylist_count= */ 0},
|
||||||
|
};
|
||||||
|
|
||||||
|
void* classifier = text_classifier_create(&options);
|
||||||
|
EXPECT_NE(classifier, nullptr);
|
||||||
|
|
||||||
|
TextClassifierResult result;
|
||||||
|
text_classifier_classify(classifier, kTestString, &result);
|
||||||
|
EXPECT_EQ(result.classifications_count, 1);
|
||||||
|
EXPECT_EQ(result.classifications[0].categories_count, 2);
|
||||||
|
EXPECT_EQ(std::string{result.classifications[0].categories[0].category_name},
|
||||||
|
"positive");
|
||||||
|
EXPECT_NEAR(result.classifications[0].categories[0].score, 0.999371,
|
||||||
|
kPrecision);
|
||||||
|
|
||||||
|
text_classifier_close_result(&result);
|
||||||
|
text_classifier_close(classifier);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
Loading…
Reference in New Issue
Block a user