71 lines
2.5 KiB
C++
71 lines
2.5 KiB
C++
/* Copyright 2023 The MediaPipe Authors.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "mediapipe/tasks/c/text/text_classifier/text_classifier.h"
|
|
|
|
#include <string>
|
|
|
|
#include "absl/flags/flag.h"
|
|
#include "absl/strings/string_view.h"
|
|
#include "mediapipe/framework/deps/file_path.h"
|
|
#include "mediapipe/framework/port/gtest.h"
|
|
#include "mediapipe/tasks/c/components/containers/category.h"
|
|
|
|
namespace {
|
|
|
|
using ::mediapipe::file::JoinPath;
|
|
|
|
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
|
|
constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite";
|
|
constexpr char kTestString[] = "It's beautiful outside.";
|
|
constexpr float kPrecision = 1e-6;
|
|
|
|
std::string GetFullPath(absl::string_view file_name) {
|
|
return JoinPath("./", kTestDataDirectory, file_name);
|
|
}
|
|
|
|
TEST(TextClassifierTest, SmokeTest) {
|
|
std::string model_path = GetFullPath(kTestBertModelPath);
|
|
TextClassifierOptions options = {
|
|
/* base_options= */ {/* model_asset_buffer= */ nullptr,
|
|
/* model_asset_path= */ model_path.c_str()},
|
|
/* classifier_options= */
|
|
{/* display_names_locale= */ nullptr,
|
|
/* max_results= */ -1,
|
|
/* score_threshold= */ 0.0,
|
|
/* category_allowlist= */ nullptr,
|
|
/* category_allowlist_count= */ 0,
|
|
/* category_denylist= */ nullptr,
|
|
/* category_denylist_count= */ 0},
|
|
};
|
|
|
|
void* classifier = text_classifier_create(&options);
|
|
EXPECT_NE(classifier, nullptr);
|
|
|
|
TextClassifierResult result;
|
|
text_classifier_classify(classifier, kTestString, &result);
|
|
EXPECT_EQ(result.classifications_count, 1);
|
|
EXPECT_EQ(result.classifications[0].categories_count, 2);
|
|
EXPECT_EQ(std::string{result.classifications[0].categories[0].category_name},
|
|
"positive");
|
|
EXPECT_NEAR(result.classifications[0].categories[0].score, 0.999371,
|
|
kPrecision);
|
|
|
|
text_classifier_close_result(&result);
|
|
text_classifier_close(classifier);
|
|
}
|
|
|
|
} // namespace
|