Internal change
PiperOrigin-RevId: 482925717
This commit is contained in:
parent
ea1d85d811
commit
7196db275e
|
@ -253,6 +253,26 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "regex_preprocessor_calculator_test",
|
||||
srcs = ["regex_preprocessor_calculator_test.cc"],
|
||||
data = ["//mediapipe/tasks/testdata/text:text_classifier_models"],
|
||||
linkopts = ["-ldl"],
|
||||
deps = [
|
||||
":regex_preprocessor_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/tool:sink",
|
||||
"//mediapipe/tasks/cc/core:utils",
|
||||
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "text_to_tensor_calculator",
|
||||
srcs = ["text_to_tensor_calculator.cc"],
|
||||
|
|
|
@ -63,6 +63,29 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "text_classifier_test",
|
||||
srcs = ["text_classifier_test.cc"],
|
||||
data = [
|
||||
"//mediapipe/tasks/testdata/text:bert_text_classifier_models",
|
||||
"//mediapipe/tasks/testdata/text:text_classifier_models",
|
||||
],
|
||||
deps = [
|
||||
":text_classifier",
|
||||
":text_classifier_test_utils",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:cord",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "text_classifier_test_utils",
|
||||
srcs = ["text_classifier_test_utils.cc"],
|
||||
|
|
|
@ -49,9 +49,6 @@ using ::mediapipe::tasks::kMediaPipeTasksPayload;
|
|||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::Optional;
|
||||
using ::testing::proto::Approximately;
|
||||
using ::testing::proto::IgnoringRepeatedFieldOrdering;
|
||||
using ::testing::proto::Partially;
|
||||
|
||||
constexpr float kEpsilon = 0.001;
|
||||
constexpr int kMaxSeqLen = 128;
|
||||
|
@ -110,127 +107,6 @@ TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) {
|
|||
MP_ASSERT_OK(TextClassifier::Create(std::move(options)));
|
||||
}
|
||||
|
||||
TEST_F(TextClassifierTest, TextClassifierWithBert) {
|
||||
auto options = std::make_unique<TextClassifierOptions>();
|
||||
options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
|
||||
TextClassifier::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
ClassificationResult negative_result,
|
||||
classifier->Classify("unflinchingly bleak and desperate"));
|
||||
ASSERT_THAT(negative_result,
|
||||
Partially(IgnoringRepeatedFieldOrdering(Approximately(
|
||||
EqualsProto(R"pb(
|
||||
classifications {
|
||||
entries {
|
||||
categories { category_name: "negative" score: 0.956 }
|
||||
categories { category_name: "positive" score: 0.044 }
|
||||
}
|
||||
}
|
||||
)pb"),
|
||||
kEpsilon))));
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
ClassificationResult positive_result,
|
||||
classifier->Classify("it's a charming and often affecting journey"));
|
||||
ASSERT_THAT(positive_result,
|
||||
Partially(IgnoringRepeatedFieldOrdering(Approximately(
|
||||
EqualsProto(R"pb(
|
||||
classifications {
|
||||
entries {
|
||||
categories { category_name: "negative" score: 0.0 }
|
||||
categories { category_name: "positive" score: 1.0 }
|
||||
}
|
||||
}
|
||||
)pb"),
|
||||
kEpsilon))));
|
||||
MP_ASSERT_OK(classifier->Close());
|
||||
}
|
||||
|
||||
TEST_F(TextClassifierTest, TextClassifierWithIntInputs) {
|
||||
auto options = std::make_unique<TextClassifierOptions>();
|
||||
options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath);
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
|
||||
TextClassifier::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(ClassificationResult negative_result,
|
||||
classifier->Classify("What a waste of my time."));
|
||||
ASSERT_THAT(negative_result,
|
||||
Partially(IgnoringRepeatedFieldOrdering(Approximately(
|
||||
EqualsProto(R"pb(
|
||||
classifications {
|
||||
entries {
|
||||
categories { category_name: "Negative" score: 0.813 }
|
||||
categories { category_name: "Positive" score: 0.187 }
|
||||
}
|
||||
}
|
||||
)pb"),
|
||||
kEpsilon))));
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
ClassificationResult positive_result,
|
||||
classifier->Classify("This is the best movie I’ve seen in recent years. "
|
||||
"Strongly recommend it!"));
|
||||
ASSERT_THAT(positive_result,
|
||||
Partially(IgnoringRepeatedFieldOrdering(Approximately(
|
||||
EqualsProto(R"pb(
|
||||
classifications {
|
||||
entries {
|
||||
categories { category_name: "Negative" score: 0.487 }
|
||||
categories { category_name: "Positive" score: 0.513 }
|
||||
}
|
||||
}
|
||||
)pb"),
|
||||
kEpsilon))));
|
||||
MP_ASSERT_OK(classifier->Close());
|
||||
}
|
||||
|
||||
TEST_F(TextClassifierTest, TextClassifierWithStringToBool) {
|
||||
auto options = std::make_unique<TextClassifierOptions>();
|
||||
options->base_options.model_asset_path = GetFullPath(kStringToBoolModelPath);
|
||||
options->base_options.op_resolver = CreateCustomResolver();
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
|
||||
TextClassifier::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
|
||||
classifier->Classify("hello"));
|
||||
ASSERT_THAT(result, Partially(IgnoringRepeatedFieldOrdering(EqualsProto(R"pb(
|
||||
classifications {
|
||||
entries {
|
||||
categories { index: 1 score: 1 }
|
||||
categories { index: 0 score: 1 }
|
||||
categories { index: 2 score: 0 }
|
||||
}
|
||||
}
|
||||
)pb"))));
|
||||
}
|
||||
|
||||
TEST_F(TextClassifierTest, BertLongPositive) {
|
||||
std::stringstream ss_for_positive_review;
|
||||
ss_for_positive_review
|
||||
<< "it's a charming and often affecting journey and this is a long";
|
||||
for (int i = 0; i < kMaxSeqLen; ++i) {
|
||||
ss_for_positive_review << " long";
|
||||
}
|
||||
ss_for_positive_review << " movie review";
|
||||
auto options = std::make_unique<TextClassifierOptions>();
|
||||
options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
|
||||
TextClassifier::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
|
||||
classifier->Classify(ss_for_positive_review.str()));
|
||||
ASSERT_THAT(result,
|
||||
Partially(IgnoringRepeatedFieldOrdering(Approximately(
|
||||
EqualsProto(R"pb(
|
||||
classifications {
|
||||
entries {
|
||||
categories { category_name: "negative" score: 0.014 }
|
||||
categories { category_name: "positive" score: 0.986 }
|
||||
}
|
||||
}
|
||||
)pb"),
|
||||
kEpsilon))));
|
||||
MP_ASSERT_OK(classifier->Close());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace text_classifier
|
||||
} // namespace text
|
||||
|
|
|
@ -73,8 +73,6 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
# TODO: This test fails in OSS
|
||||
|
||||
cc_library(
|
||||
name = "tokenizer_utils",
|
||||
srcs = ["tokenizer_utils.cc"],
|
||||
|
@ -97,8 +95,6 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
# TODO: This test fails in OSS
|
||||
|
||||
cc_library(
|
||||
name = "regex_tokenizer",
|
||||
srcs = [
|
||||
|
|
7
mediapipe/tasks/testdata/text/BUILD
vendored
7
mediapipe/tasks/testdata/text/BUILD
vendored
|
@ -76,9 +76,10 @@ filegroup(
|
|||
|
||||
filegroup(
|
||||
name = "text_classifier_models",
|
||||
srcs = glob([
|
||||
"test_model_text_classifier*.tflite",
|
||||
]),
|
||||
srcs = [
|
||||
"test_model_text_classifier_bool_output.tflite",
|
||||
"test_model_text_classifier_with_regex_tokenizer.tflite",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
|
|
Loading…
Reference in New Issue
Block a user