From af43687f2e3c774ff7b0f1f4881d456952a6aadd Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 5 Dec 2022 20:07:29 -0800 Subject: [PATCH] Open-sources a unit test. PiperOrigin-RevId: 493184055 --- .../text_classifier/text_classifier_test.cc | 51 +++++++++++++++---- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc index 8f73914fc..799885eac 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc @@ -38,10 +38,7 @@ limitations under the License. #include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" -namespace mediapipe { -namespace tasks { -namespace text { -namespace text_classifier { +namespace mediapipe::tasks::text::text_classifier { namespace { using ::mediapipe::file::JoinPath; @@ -88,6 +85,8 @@ void ExpectApproximatelyEqual(const TextClassifierResult& actual, } } +} // namespace + class TextClassifierTest : public tflite_shims::testing::Test {}; TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) { @@ -217,8 +216,42 @@ TEST_F(TextClassifierTest, TextClassifierWithStringToBool) { MP_ASSERT_OK(classifier->Close()); } -} // namespace -} // namespace text_classifier -} // namespace text -} // namespace tasks -} // namespace mediapipe +TEST_F(TextClassifierTest, BertLongPositive) { + std::stringstream ss_for_positive_review; + ss_for_positive_review + << "it's a charming and often affecting journey and this is a long"; + for (int i = 0; i < kMaxSeqLen; ++i) { + ss_for_positive_review << " long"; + } + ss_for_positive_review << " movie review"; + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, + TextClassifier::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult result, + classifier->Classify(ss_for_positive_review.str())); + TextClassifierResult expected; + std::vector categories; + +// Predicted scores are slightly different on Mac OS. +#ifdef __APPLE__ + categories.push_back( + {/*index=*/1, /*score=*/0.974181, /*category_name=*/"positive"}); + categories.push_back( + {/*index=*/0, /*score=*/0.025819, /*category_name=*/"negative"}); +#else + categories.push_back( + {/*index=*/1, /*score=*/0.985889, /*category_name=*/"positive"}); + categories.push_back( + {/*index=*/0, /*score=*/0.014112, /*category_name=*/"negative"}); +#endif // __APPLE__ + + expected.classifications.emplace_back( + Classifications{/*categories=*/categories, + /*head_index=*/0, + /*head_name=*/"probability"}); + ExpectApproximatelyEqual(result, expected); + MP_ASSERT_OK(classifier->Close()); +} + +} // namespace mediapipe::tasks::text::text_classifier