From 7196db275efae1738bc31f18fb2ed366f1b41b1d Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 21 Oct 2022 17:26:48 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 482925717 --- mediapipe/calculators/tensor/BUILD | 20 +++ mediapipe/tasks/cc/text/text_classifier/BUILD | 23 ++++ .../text_classifier/text_classifier_test.cc | 124 ------------------ mediapipe/tasks/cc/text/tokenizers/BUILD | 4 - mediapipe/tasks/testdata/text/BUILD | 7 +- 5 files changed, 47 insertions(+), 131 deletions(-) diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index e953342da..99b5b3e91 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -253,6 +253,26 @@ cc_library( alwayslink = 1, ) +cc_test( + name = "regex_preprocessor_calculator_test", + srcs = ["regex_preprocessor_calculator_test.cc"], + data = ["//mediapipe/tasks/testdata/text:text_classifier_models"], + linkopts = ["-ldl"], + deps = [ + ":regex_preprocessor_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/tool:sink", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "text_to_tensor_calculator", srcs = ["text_to_tensor_calculator.cc"], diff --git a/mediapipe/tasks/cc/text/text_classifier/BUILD b/mediapipe/tasks/cc/text/text_classifier/BUILD index a85538631..336b1bb45 100644 --- a/mediapipe/tasks/cc/text/text_classifier/BUILD +++ b/mediapipe/tasks/cc/text/text_classifier/BUILD @@ -63,6 +63,29 @@ cc_library( ], ) +cc_test( + name = "text_classifier_test", + srcs = ["text_classifier_test.cc"], + data = [ + "//mediapipe/tasks/testdata/text:bert_text_classifier_models", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + deps = [ + ":text_classifier", + ":text_classifier_test_utils", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + ], +) + cc_library( name = "text_classifier_test_utils", srcs = ["text_classifier_test_utils.cc"], diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc index 5b33f6606..62837be8c 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc @@ -49,9 +49,6 @@ using ::mediapipe::tasks::kMediaPipeTasksPayload; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::testing::HasSubstr; using ::testing::Optional; -using ::testing::proto::Approximately; -using ::testing::proto::IgnoringRepeatedFieldOrdering; -using ::testing::proto::Partially; constexpr float kEpsilon = 0.001; constexpr int kMaxSeqLen = 128; @@ -110,127 +107,6 @@ TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) { MP_ASSERT_OK(TextClassifier::Create(std::move(options))); } -TEST_F(TextClassifierTest, TextClassifierWithBert) { - auto options = std::make_unique(); - options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, - TextClassifier::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN( - ClassificationResult negative_result, - classifier->Classify("unflinchingly bleak and desperate")); - ASSERT_THAT(negative_result, - Partially(IgnoringRepeatedFieldOrdering(Approximately( - EqualsProto(R"pb( - classifications { - entries { - categories { category_name: "negative" score: 0.956 } - categories { category_name: "positive" score: 0.044 } - } - } - )pb"), - kEpsilon)))); - - MP_ASSERT_OK_AND_ASSIGN( - ClassificationResult positive_result, - classifier->Classify("it's a charming and often affecting journey")); - ASSERT_THAT(positive_result, - Partially(IgnoringRepeatedFieldOrdering(Approximately( - EqualsProto(R"pb( - classifications { - entries { - categories { category_name: "negative" score: 0.0 } - categories { category_name: "positive" score: 1.0 } - } - } - )pb"), - kEpsilon)))); - MP_ASSERT_OK(classifier->Close()); -} - -TEST_F(TextClassifierTest, TextClassifierWithIntInputs) { - auto options = std::make_unique(); - options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath); - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, - TextClassifier::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(ClassificationResult negative_result, - classifier->Classify("What a waste of my time.")); - ASSERT_THAT(negative_result, - Partially(IgnoringRepeatedFieldOrdering(Approximately( - EqualsProto(R"pb( - classifications { - entries { - categories { category_name: "Negative" score: 0.813 } - categories { category_name: "Positive" score: 0.187 } - } - } - )pb"), - kEpsilon)))); - - MP_ASSERT_OK_AND_ASSIGN( - ClassificationResult positive_result, - classifier->Classify("This is the best movie I’ve seen in recent years. " - "Strongly recommend it!")); - ASSERT_THAT(positive_result, - Partially(IgnoringRepeatedFieldOrdering(Approximately( - EqualsProto(R"pb( - classifications { - entries { - categories { category_name: "Negative" score: 0.487 } - categories { category_name: "Positive" score: 0.513 } - } - } - )pb"), - kEpsilon)))); - MP_ASSERT_OK(classifier->Close()); -} - -TEST_F(TextClassifierTest, TextClassifierWithStringToBool) { - auto options = std::make_unique(); - options->base_options.model_asset_path = GetFullPath(kStringToBoolModelPath); - options->base_options.op_resolver = CreateCustomResolver(); - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, - TextClassifier::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result, - classifier->Classify("hello")); - ASSERT_THAT(result, Partially(IgnoringRepeatedFieldOrdering(EqualsProto(R"pb( - classifications { - entries { - categories { index: 1 score: 1 } - categories { index: 0 score: 1 } - categories { index: 2 score: 0 } - } - } - )pb")))); -} - -TEST_F(TextClassifierTest, BertLongPositive) { - std::stringstream ss_for_positive_review; - ss_for_positive_review - << "it's a charming and often affecting journey and this is a long"; - for (int i = 0; i < kMaxSeqLen; ++i) { - ss_for_positive_review << " long"; - } - ss_for_positive_review << " movie review"; - auto options = std::make_unique(); - options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, - TextClassifier::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result, - classifier->Classify(ss_for_positive_review.str())); - ASSERT_THAT(result, - Partially(IgnoringRepeatedFieldOrdering(Approximately( - EqualsProto(R"pb( - classifications { - entries { - categories { category_name: "negative" score: 0.014 } - categories { category_name: "positive" score: 0.986 } - } - } - )pb"), - kEpsilon)))); - MP_ASSERT_OK(classifier->Close()); -} - } // namespace } // namespace text_classifier } // namespace text diff --git a/mediapipe/tasks/cc/text/tokenizers/BUILD b/mediapipe/tasks/cc/text/tokenizers/BUILD index 048c7021d..e76d943c5 100644 --- a/mediapipe/tasks/cc/text/tokenizers/BUILD +++ b/mediapipe/tasks/cc/text/tokenizers/BUILD @@ -73,8 +73,6 @@ cc_library( ], ) -# TODO: This test fails in OSS - cc_library( name = "tokenizer_utils", srcs = ["tokenizer_utils.cc"], @@ -97,8 +95,6 @@ cc_library( ], ) -# TODO: This test fails in OSS - cc_library( name = "regex_tokenizer", srcs = [ diff --git a/mediapipe/tasks/testdata/text/BUILD b/mediapipe/tasks/testdata/text/BUILD index 6cce5ae41..14999a03e 100644 --- a/mediapipe/tasks/testdata/text/BUILD +++ b/mediapipe/tasks/testdata/text/BUILD @@ -76,9 +76,10 @@ filegroup( filegroup( name = "text_classifier_models", - srcs = glob([ - "test_model_text_classifier*.tflite", - ]), + srcs = [ + "test_model_text_classifier_bool_output.tflite", + "test_model_text_classifier_with_regex_tokenizer.tflite", + ], ) filegroup(