mediapipe/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc
MediaPipe Team af43687f2e Open-sources a unit test.
PiperOrigin-RevId: 493184055
2022-12-05 20:11:07 -08:00

258 lines
11 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"

#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe::tasks::text::text_classifier {
namespace {
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::kMediaPipeTasksPayload;
using ::mediapipe::tasks::components::containers::Category;
using ::mediapipe::tasks::components::containers::Classifications;
using ::testing::HasSubstr;
using ::testing::Optional;
// Number of times the filler word " long" is appended when building the long
// review in the BertLongPositive test.
// NOTE(review): presumably mirrors the BERT model's maximum input sequence
// length — confirm against the model metadata.
constexpr int kMaxSeqLen = 128;
// Directory holding the test models; GetFullPath() joins it onto "./".
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
// File name of the model loaded by the BERT-based tests.
constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite";
// Path used by CreateFailsWithMissingModel, which expects creation to fail
// with a not-found error.
constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite";
// File name of the model loaded by the regex-tokenizer tests.
constexpr char kTestRegexModelPath[] =
"test_model_text_classifier_with_regex_tokenizer.tflite";
// File name of the model loaded by TextClassifierWithStringToBool, together
// with the resolver returned by CreateCustomResolver().
constexpr char kStringToBoolModelPath[] =
"test_model_text_classifier_bool_output.tflite";
// Builds the path of a test asset by joining "./", the test data directory
// and `file_name`.
std::string GetFullPath(absl::string_view file_name) {
  std::string full_path = JoinPath("./", kTestDataDirectory, file_name);
  return full_path;
}
// Checks that the two provided `TextClassifierResult` are equal, with a
// tolerance on floating-point scores to account for numerical instabilities.
//
// `actual` is the result produced by the classifier and `expected` the
// reference value. Size mismatches (number of classifications, or number of
// categories within a classification) are reported as fatal failures and stop
// the comparison, because the element-wise checks index both results in
// lockstep. All other mismatches are reported as non-fatal failures.
// TODO: create shared matcher for ClassificationResult.
void ExpectApproximatelyEqual(const TextClassifierResult& actual,
                              const TextClassifierResult& expected) {
  constexpr float kPrecision = 1e-6f;
  ASSERT_EQ(actual.classifications.size(), expected.classifications.size());
  for (size_t i = 0; i < actual.classifications.size(); ++i) {
    const Classifications& a = actual.classifications[i];
    const Classifications& b = expected.classifications[i];
    EXPECT_EQ(a.head_index, b.head_index);
    EXPECT_EQ(a.head_name, b.head_name);
    // Must be fatal: with a non-fatal check the loop below would read
    // `b.categories[j]` out of bounds whenever `b` has fewer categories.
    ASSERT_EQ(a.categories.size(), b.categories.size());
    for (size_t j = 0; j < a.categories.size(); ++j) {
      const Category& x = a.categories[j];
      const Category& y = b.categories[j];
      EXPECT_EQ(x.index, y.index);
      EXPECT_NEAR(x.score, y.score, kPrecision);
      EXPECT_EQ(x.category_name, y.category_name);
      EXPECT_EQ(x.display_name, y.display_name);
    }
  }
}
} // namespace
// Fixture shared by all TEST_F cases in this file. It adds no members of its
// own on top of the tflite_shims test base class.
class TextClassifierTest : public tflite_shims::testing::Test {};
// Creating a classifier from the BERT test model must succeed.
TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) {
  std::string bert_model_path = GetFullPath(kTestBertModelPath);
  auto bert_options = std::make_unique<TextClassifierOptions>();
  bert_options->base_options.model_asset_path = std::move(bert_model_path);
  MP_ASSERT_OK(TextClassifier::Create(std::move(bert_options)));
}
// Creating a classifier from default-constructed options (no model asset set)
// must fail with kInvalidArgument, an explanatory message and the
// runner-initialization-error payload.
TEST_F(TextClassifierTest, CreateFailsWithMissingBaseOptions) {
  StatusOr<std::unique_ptr<TextClassifier>> classifier =
      TextClassifier::Create(std::make_unique<TextClassifierOptions>());

  const absl::Status& creation_status = classifier.status();
  EXPECT_EQ(creation_status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(
      creation_status.message(),
      HasSubstr("ExternalFile must specify at least one of 'file_content', "
                "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));

  const absl::Cord expected_payload(
      absl::StrCat(MediaPipeTasksStatus::kRunnerInitializationError));
  EXPECT_THAT(creation_status.GetPayload(kMediaPipeTasksPayload),
              Optional(expected_payload));
}
// Pointing the options at a model file that does not exist must fail with
// kNotFound, an explanatory message and the runner-initialization-error
// payload.
TEST_F(TextClassifierTest, CreateFailsWithMissingModel) {
  auto missing_model_options = std::make_unique<TextClassifierOptions>();
  missing_model_options->base_options.model_asset_path =
      GetFullPath(kInvalidModelPath);
  StatusOr<std::unique_ptr<TextClassifier>> classifier =
      TextClassifier::Create(std::move(missing_model_options));

  const absl::Status& creation_status = classifier.status();
  EXPECT_EQ(creation_status.code(), absl::StatusCode::kNotFound);
  EXPECT_THAT(creation_status.message(), HasSubstr("Unable to open file at"));

  const absl::Cord expected_payload(
      absl::StrCat(MediaPipeTasksStatus::kRunnerInitializationError));
  EXPECT_THAT(creation_status.GetPayload(kMediaPipeTasksPayload),
              Optional(expected_payload));
}
// Creating a classifier from the regex-tokenizer test model must succeed.
TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) {
  std::string regex_model_path = GetFullPath(kTestRegexModelPath);
  auto regex_options = std::make_unique<TextClassifierOptions>();
  regex_options->base_options.model_asset_path = std::move(regex_model_path);
  MP_ASSERT_OK(TextClassifier::Create(std::move(regex_options)));
}
// Classifies one negative and one positive sentence with the BERT test model
// and compares the returned categories against reference scores.
TEST_F(TextClassifierTest, TextClassifierWithBert) {
  auto bert_options = std::make_unique<TextClassifierOptions>();
  bert_options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> bert_classifier,
                          TextClassifier::Create(std::move(bert_options)));

  // Negative sentence: "negative" is expected to rank first.
  TextClassifierResult want_negative;
  want_negative.classifications.push_back(Classifications{
      /*categories=*/{
          {/*index=*/0, /*score=*/0.956317, /*category_name=*/"negative"},
          {/*index=*/1, /*score=*/0.043683, /*category_name=*/"positive"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
  MP_ASSERT_OK_AND_ASSIGN(
      TextClassifierResult got_negative,
      bert_classifier->Classify("unflinchingly bleak and desperate"));
  ExpectApproximatelyEqual(got_negative, want_negative);

  // Positive sentence: "positive" is expected to rank first.
  TextClassifierResult want_positive;
  want_positive.classifications.push_back(Classifications{
      /*categories=*/{
          {/*index=*/1, /*score=*/0.999945, /*category_name=*/"positive"},
          {/*index=*/0, /*score=*/0.000056, /*category_name=*/"negative"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
  MP_ASSERT_OK_AND_ASSIGN(
      TextClassifierResult got_positive,
      bert_classifier->Classify("it's a charming and often affecting journey"));
  ExpectApproximatelyEqual(got_positive, want_positive);

  MP_ASSERT_OK(bert_classifier->Close());
}
// Classifies one negative and one positive sentence with the regex-tokenizer
// test model and compares the returned categories against reference scores.
// The input sentences are kept verbatim: the reference scores were recorded
// for exactly these strings.
TEST_F(TextClassifierTest, TextClassifierWithIntInputs) {
  auto regex_options = std::make_unique<TextClassifierOptions>();
  regex_options->base_options.model_asset_path =
      GetFullPath(kTestRegexModelPath);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> regex_classifier,
                          TextClassifier::Create(std::move(regex_options)));

  // Negative sentence: "Negative" is expected to rank first.
  TextClassifierResult want_negative;
  want_negative.classifications.push_back(Classifications{
      /*categories=*/{
          {/*index=*/0, /*score=*/0.813130, /*category_name=*/"Negative"},
          {/*index=*/1, /*score=*/0.186870, /*category_name=*/"Positive"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
  MP_ASSERT_OK_AND_ASSIGN(
      TextClassifierResult got_negative,
      regex_classifier->Classify("What a waste of my time."));
  ExpectApproximatelyEqual(got_negative, want_negative);

  // Positive sentence: "Positive" is expected to rank first.
  TextClassifierResult want_positive;
  want_positive.classifications.push_back(Classifications{
      /*categories=*/{
          {/*index=*/1, /*score=*/0.513427, /*category_name=*/"Positive"},
          {/*index=*/0, /*score=*/0.486573, /*category_name=*/"Negative"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
  MP_ASSERT_OK_AND_ASSIGN(
      TextClassifierResult got_positive,
      regex_classifier->Classify(
          "This is the best movie Ive seen in recent years."
          "Strongly recommend it!"));
  ExpectApproximatelyEqual(got_positive, want_positive);

  MP_ASSERT_OK(regex_classifier->Close());
}
// Runs the string-to-bool test model (which needs the custom op resolver) on
// a single input and checks the three returned categories by hand.
TEST_F(TextClassifierTest, TextClassifierWithStringToBool) {
  auto bool_options = std::make_unique<TextClassifierOptions>();
  bool_options->base_options.model_asset_path =
      GetFullPath(kStringToBoolModelPath);
  bool_options->base_options.op_resolver = CreateCustomResolver();
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> bool_classifier,
                          TextClassifier::Create(std::move(bool_options)));
  MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult result,
                          bool_classifier->Classify("hello"));

  // Binary outputs cause flaky ordering between the two categories that share
  // the same score, so the fields are compared manually.
  ASSERT_EQ(result.classifications.size(), 1);
  const Classifications& head = result.classifications[0];
  ASSERT_EQ(head.head_index, 0);
  const auto& categories = head.categories;
  ASSERT_EQ(categories.size(), 3);
  ASSERT_EQ(categories[0].score, 1);
  ASSERT_LT(categories[0].index, 2);  // i.e. 0 or 1.
  ASSERT_EQ(categories[1].score, 1);
  ASSERT_LT(categories[1].index, 2);  // i.e. 0 or 1.
  ASSERT_EQ(categories[2].score, 0);
  ASSERT_EQ(categories[2].index, 2);

  MP_ASSERT_OK(bool_classifier->Close());
}
// Classifies a positive review padded with kMaxSeqLen repetitions of " long"
// using the BERT test model, and compares the result against
// platform-specific reference scores.
TEST_F(TextClassifierTest, BertLongPositive) {
  // Assemble the input: fixed prefix, the repeated filler word, fixed suffix.
  std::string long_positive_review =
      "it's a charming and often affecting journey and this is a long";
  for (int repeat = 0; repeat < kMaxSeqLen; ++repeat) {
    long_positive_review += " long";
  }
  long_positive_review += " movie review";

  auto bert_options = std::make_unique<TextClassifierOptions>();
  bert_options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> bert_classifier,
                          TextClassifier::Create(std::move(bert_options)));
  MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult result,
                          bert_classifier->Classify(long_positive_review));

  // Predicted scores are slightly different on Mac OS.
#ifdef __APPLE__
  const std::vector<Category> categories = {
      {/*index=*/1, /*score=*/0.974181, /*category_name=*/"positive"},
      {/*index=*/0, /*score=*/0.025819, /*category_name=*/"negative"}};
#else
  const std::vector<Category> categories = {
      {/*index=*/1, /*score=*/0.985889, /*category_name=*/"positive"},
      {/*index=*/0, /*score=*/0.014112, /*category_name=*/"negative"}};
#endif  // __APPLE__

  TextClassifierResult expected;
  expected.classifications.push_back(
      Classifications{/*categories=*/categories,
                      /*head_index=*/0,
                      /*head_name=*/"probability"});
  ExpectApproximatelyEqual(result, expected);

  MP_ASSERT_OK(bert_classifier->Close());
}
} // namespace mediapipe::tasks::text::text_classifier