Internal change

PiperOrigin-RevId: 482925717
This commit is contained in:
MediaPipe Team 2022-10-21 17:26:48 -07:00 committed by Copybara-Service
parent ea1d85d811
commit 7196db275e
5 changed files with 47 additions and 131 deletions

View File

@ -253,6 +253,26 @@ cc_library(
alwayslink = 1,
)
cc_test(
name = "regex_preprocessor_calculator_test",
srcs = ["regex_preprocessor_calculator_test.cc"],
data = ["//mediapipe/tasks/testdata/text:text_classifier_models"],
linkopts = ["-ldl"],
deps = [
":regex_preprocessor_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:sink",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "text_to_tensor_calculator",
srcs = ["text_to_tensor_calculator.cc"],

View File

@ -63,6 +63,29 @@ cc_library(
],
)
cc_test(
name = "text_classifier_test",
srcs = ["text_classifier_test.cc"],
data = [
"//mediapipe/tasks/testdata/text:bert_text_classifier_models",
"//mediapipe/tasks/testdata/text:text_classifier_models",
],
deps = [
":text_classifier",
":text_classifier_test_utils",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
],
)
cc_library(
name = "text_classifier_test_utils",
srcs = ["text_classifier_test_utils.cc"],

View File

@ -49,9 +49,6 @@ using ::mediapipe::tasks::kMediaPipeTasksPayload;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::testing::HasSubstr;
using ::testing::Optional;
using ::testing::proto::Approximately;
using ::testing::proto::IgnoringRepeatedFieldOrdering;
using ::testing::proto::Partially;
constexpr float kEpsilon = 0.001;
constexpr int kMaxSeqLen = 128;
@ -110,127 +107,6 @@ TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) {
MP_ASSERT_OK(TextClassifier::Create(std::move(options)));
}
TEST_F(TextClassifierTest, TextClassifierWithBert) {
auto options = std::make_unique<TextClassifierOptions>();
options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
TextClassifier::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
ClassificationResult negative_result,
classifier->Classify("unflinchingly bleak and desperate"));
ASSERT_THAT(negative_result,
Partially(IgnoringRepeatedFieldOrdering(Approximately(
EqualsProto(R"pb(
classifications {
entries {
categories { category_name: "negative" score: 0.956 }
categories { category_name: "positive" score: 0.044 }
}
}
)pb"),
kEpsilon))));
MP_ASSERT_OK_AND_ASSIGN(
ClassificationResult positive_result,
classifier->Classify("it's a charming and often affecting journey"));
ASSERT_THAT(positive_result,
Partially(IgnoringRepeatedFieldOrdering(Approximately(
EqualsProto(R"pb(
classifications {
entries {
categories { category_name: "negative" score: 0.0 }
categories { category_name: "positive" score: 1.0 }
}
}
)pb"),
kEpsilon))));
MP_ASSERT_OK(classifier->Close());
}
TEST_F(TextClassifierTest, TextClassifierWithIntInputs) {
auto options = std::make_unique<TextClassifierOptions>();
options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
TextClassifier::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(ClassificationResult negative_result,
classifier->Classify("What a waste of my time."));
ASSERT_THAT(negative_result,
Partially(IgnoringRepeatedFieldOrdering(Approximately(
EqualsProto(R"pb(
classifications {
entries {
categories { category_name: "Negative" score: 0.813 }
categories { category_name: "Positive" score: 0.187 }
}
}
)pb"),
kEpsilon))));
MP_ASSERT_OK_AND_ASSIGN(
ClassificationResult positive_result,
classifier->Classify("This is the best movie Ive seen in recent years. "
"Strongly recommend it!"));
ASSERT_THAT(positive_result,
Partially(IgnoringRepeatedFieldOrdering(Approximately(
EqualsProto(R"pb(
classifications {
entries {
categories { category_name: "Negative" score: 0.487 }
categories { category_name: "Positive" score: 0.513 }
}
}
)pb"),
kEpsilon))));
MP_ASSERT_OK(classifier->Close());
}
TEST_F(TextClassifierTest, TextClassifierWithStringToBool) {
auto options = std::make_unique<TextClassifierOptions>();
options->base_options.model_asset_path = GetFullPath(kStringToBoolModelPath);
options->base_options.op_resolver = CreateCustomResolver();
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
TextClassifier::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
classifier->Classify("hello"));
ASSERT_THAT(result, Partially(IgnoringRepeatedFieldOrdering(EqualsProto(R"pb(
classifications {
entries {
categories { index: 1 score: 1 }
categories { index: 0 score: 1 }
categories { index: 2 score: 0 }
}
}
)pb"))));
}
TEST_F(TextClassifierTest, BertLongPositive) {
std::stringstream ss_for_positive_review;
ss_for_positive_review
<< "it's a charming and often affecting journey and this is a long";
for (int i = 0; i < kMaxSeqLen; ++i) {
ss_for_positive_review << " long";
}
ss_for_positive_review << " movie review";
auto options = std::make_unique<TextClassifierOptions>();
options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
TextClassifier::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
classifier->Classify(ss_for_positive_review.str()));
ASSERT_THAT(result,
Partially(IgnoringRepeatedFieldOrdering(Approximately(
EqualsProto(R"pb(
classifications {
entries {
categories { category_name: "negative" score: 0.014 }
categories { category_name: "positive" score: 0.986 }
}
}
)pb"),
kEpsilon))));
MP_ASSERT_OK(classifier->Close());
}
} // namespace
} // namespace text_classifier
} // namespace text

View File

@ -73,8 +73,6 @@ cc_library(
],
)
# TODO: This test fails in OSS
cc_library(
name = "tokenizer_utils",
srcs = ["tokenizer_utils.cc"],
@ -97,8 +95,6 @@ cc_library(
],
)
# TODO: This test fails in OSS
cc_library(
name = "regex_tokenizer",
srcs = [

View File

@ -76,9 +76,10 @@ filegroup(
filegroup(
name = "text_classifier_models",
srcs = glob([
"test_model_text_classifier*.tflite",
]),
srcs = [
"test_model_text_classifier_bool_output.tflite",
"test_model_text_classifier_with_regex_tokenizer.tflite",
],
)
filegroup(