Internal change
PiperOrigin-RevId: 482925717
This commit is contained in:
parent
ea1d85d811
commit
7196db275e
|
@ -253,6 +253,26 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "regex_preprocessor_calculator_test",
|
||||||
|
srcs = ["regex_preprocessor_calculator_test.cc"],
|
||||||
|
data = ["//mediapipe/tasks/testdata/text:text_classifier_models"],
|
||||||
|
linkopts = ["-ldl"],
|
||||||
|
deps = [
|
||||||
|
":regex_preprocessor_calculator",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"//mediapipe/framework/tool:sink",
|
||||||
|
"//mediapipe/tasks/cc/core:utils",
|
||||||
|
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "text_to_tensor_calculator",
|
name = "text_to_tensor_calculator",
|
||||||
srcs = ["text_to_tensor_calculator.cc"],
|
srcs = ["text_to_tensor_calculator.cc"],
|
||||||
|
|
|
@ -63,6 +63,29 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "text_classifier_test",
|
||||||
|
srcs = ["text_classifier_test.cc"],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/tasks/testdata/text:bert_text_classifier_models",
|
||||||
|
"//mediapipe/tasks/testdata/text:text_classifier_models",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":text_classifier",
|
||||||
|
":text_classifier_test_utils",
|
||||||
|
"//mediapipe/framework/deps:file_path",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/tasks/cc:common",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||||
|
"@com_google_absl//absl/flags:flag",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/strings:cord",
|
||||||
|
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "text_classifier_test_utils",
|
name = "text_classifier_test_utils",
|
||||||
srcs = ["text_classifier_test_utils.cc"],
|
srcs = ["text_classifier_test_utils.cc"],
|
||||||
|
|
|
@ -49,9 +49,6 @@ using ::mediapipe::tasks::kMediaPipeTasksPayload;
|
||||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
using ::testing::Optional;
|
using ::testing::Optional;
|
||||||
using ::testing::proto::Approximately;
|
|
||||||
using ::testing::proto::IgnoringRepeatedFieldOrdering;
|
|
||||||
using ::testing::proto::Partially;
|
|
||||||
|
|
||||||
constexpr float kEpsilon = 0.001;
|
constexpr float kEpsilon = 0.001;
|
||||||
constexpr int kMaxSeqLen = 128;
|
constexpr int kMaxSeqLen = 128;
|
||||||
|
@ -110,127 +107,6 @@ TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) {
|
||||||
MP_ASSERT_OK(TextClassifier::Create(std::move(options)));
|
MP_ASSERT_OK(TextClassifier::Create(std::move(options)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TextClassifierTest, TextClassifierWithBert) {
|
|
||||||
auto options = std::make_unique<TextClassifierOptions>();
|
|
||||||
options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
|
|
||||||
TextClassifier::Create(std::move(options)));
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
|
||||||
ClassificationResult negative_result,
|
|
||||||
classifier->Classify("unflinchingly bleak and desperate"));
|
|
||||||
ASSERT_THAT(negative_result,
|
|
||||||
Partially(IgnoringRepeatedFieldOrdering(Approximately(
|
|
||||||
EqualsProto(R"pb(
|
|
||||||
classifications {
|
|
||||||
entries {
|
|
||||||
categories { category_name: "negative" score: 0.956 }
|
|
||||||
categories { category_name: "positive" score: 0.044 }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)pb"),
|
|
||||||
kEpsilon))));
|
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
|
||||||
ClassificationResult positive_result,
|
|
||||||
classifier->Classify("it's a charming and often affecting journey"));
|
|
||||||
ASSERT_THAT(positive_result,
|
|
||||||
Partially(IgnoringRepeatedFieldOrdering(Approximately(
|
|
||||||
EqualsProto(R"pb(
|
|
||||||
classifications {
|
|
||||||
entries {
|
|
||||||
categories { category_name: "negative" score: 0.0 }
|
|
||||||
categories { category_name: "positive" score: 1.0 }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)pb"),
|
|
||||||
kEpsilon))));
|
|
||||||
MP_ASSERT_OK(classifier->Close());
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(TextClassifierTest, TextClassifierWithIntInputs) {
|
|
||||||
auto options = std::make_unique<TextClassifierOptions>();
|
|
||||||
options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath);
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
|
|
||||||
TextClassifier::Create(std::move(options)));
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(ClassificationResult negative_result,
|
|
||||||
classifier->Classify("What a waste of my time."));
|
|
||||||
ASSERT_THAT(negative_result,
|
|
||||||
Partially(IgnoringRepeatedFieldOrdering(Approximately(
|
|
||||||
EqualsProto(R"pb(
|
|
||||||
classifications {
|
|
||||||
entries {
|
|
||||||
categories { category_name: "Negative" score: 0.813 }
|
|
||||||
categories { category_name: "Positive" score: 0.187 }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)pb"),
|
|
||||||
kEpsilon))));
|
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
|
||||||
ClassificationResult positive_result,
|
|
||||||
classifier->Classify("This is the best movie I’ve seen in recent years. "
|
|
||||||
"Strongly recommend it!"));
|
|
||||||
ASSERT_THAT(positive_result,
|
|
||||||
Partially(IgnoringRepeatedFieldOrdering(Approximately(
|
|
||||||
EqualsProto(R"pb(
|
|
||||||
classifications {
|
|
||||||
entries {
|
|
||||||
categories { category_name: "Negative" score: 0.487 }
|
|
||||||
categories { category_name: "Positive" score: 0.513 }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)pb"),
|
|
||||||
kEpsilon))));
|
|
||||||
MP_ASSERT_OK(classifier->Close());
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(TextClassifierTest, TextClassifierWithStringToBool) {
|
|
||||||
auto options = std::make_unique<TextClassifierOptions>();
|
|
||||||
options->base_options.model_asset_path = GetFullPath(kStringToBoolModelPath);
|
|
||||||
options->base_options.op_resolver = CreateCustomResolver();
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
|
|
||||||
TextClassifier::Create(std::move(options)));
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
|
|
||||||
classifier->Classify("hello"));
|
|
||||||
ASSERT_THAT(result, Partially(IgnoringRepeatedFieldOrdering(EqualsProto(R"pb(
|
|
||||||
classifications {
|
|
||||||
entries {
|
|
||||||
categories { index: 1 score: 1 }
|
|
||||||
categories { index: 0 score: 1 }
|
|
||||||
categories { index: 2 score: 0 }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)pb"))));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(TextClassifierTest, BertLongPositive) {
|
|
||||||
std::stringstream ss_for_positive_review;
|
|
||||||
ss_for_positive_review
|
|
||||||
<< "it's a charming and often affecting journey and this is a long";
|
|
||||||
for (int i = 0; i < kMaxSeqLen; ++i) {
|
|
||||||
ss_for_positive_review << " long";
|
|
||||||
}
|
|
||||||
ss_for_positive_review << " movie review";
|
|
||||||
auto options = std::make_unique<TextClassifierOptions>();
|
|
||||||
options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
|
|
||||||
TextClassifier::Create(std::move(options)));
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
|
|
||||||
classifier->Classify(ss_for_positive_review.str()));
|
|
||||||
ASSERT_THAT(result,
|
|
||||||
Partially(IgnoringRepeatedFieldOrdering(Approximately(
|
|
||||||
EqualsProto(R"pb(
|
|
||||||
classifications {
|
|
||||||
entries {
|
|
||||||
categories { category_name: "negative" score: 0.014 }
|
|
||||||
categories { category_name: "positive" score: 0.986 }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)pb"),
|
|
||||||
kEpsilon))));
|
|
||||||
MP_ASSERT_OK(classifier->Close());
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace text_classifier
|
} // namespace text_classifier
|
||||||
} // namespace text
|
} // namespace text
|
||||||
|
|
|
@ -73,8 +73,6 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: This test fails in OSS
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tokenizer_utils",
|
name = "tokenizer_utils",
|
||||||
srcs = ["tokenizer_utils.cc"],
|
srcs = ["tokenizer_utils.cc"],
|
||||||
|
@ -97,8 +95,6 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: This test fails in OSS
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "regex_tokenizer",
|
name = "regex_tokenizer",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
|
7
mediapipe/tasks/testdata/text/BUILD
vendored
7
mediapipe/tasks/testdata/text/BUILD
vendored
|
@ -76,9 +76,10 @@ filegroup(
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "text_classifier_models",
|
name = "text_classifier_models",
|
||||||
srcs = glob([
|
srcs = [
|
||||||
"test_model_text_classifier*.tflite",
|
"test_model_text_classifier_bool_output.tflite",
|
||||||
]),
|
"test_model_text_classifier_with_regex_tokenizer.tflite",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user