Make TextEmbedder and TextClassifier tests pass on Windows
PiperOrigin-RevId: 506421383
This commit is contained in:
parent
bdd77b0d61
commit
286dde97ad
|
@ -135,28 +135,46 @@ TEST_F(TextClassifierTest, TextClassifierWithBert) {
|
||||||
options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
|
options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
|
||||||
TextClassifier::Create(std::move(options)));
|
TextClassifier::Create(std::move(options)));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
|
||||||
TextClassifierResult negative_result,
|
|
||||||
classifier->Classify("unflinchingly bleak and desperate"));
|
|
||||||
TextClassifierResult negative_expected;
|
TextClassifierResult negative_expected;
|
||||||
|
TextClassifierResult positive_expected;
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
negative_expected.classifications.emplace_back(Classifications{
|
||||||
|
/*categories=*/{
|
||||||
|
{/*index=*/0, /*score=*/0.956124, /*category_name=*/"negative"},
|
||||||
|
{/*index=*/1, /*score=*/0.043875, /*category_name=*/"positive"}},
|
||||||
|
/*head_index=*/0,
|
||||||
|
/*head_name=*/"probability"});
|
||||||
|
positive_expected.classifications.emplace_back(Classifications{
|
||||||
|
/*categories=*/{
|
||||||
|
{/*index=*/1, /*score=*/0.999951, /*category_name=*/"positive"},
|
||||||
|
{/*index=*/0, /*score=*/0.000048, /*category_name=*/"negative"}},
|
||||||
|
/*head_index=*/0,
|
||||||
|
/*head_name=*/"probability"});
|
||||||
|
#else
|
||||||
negative_expected.classifications.emplace_back(Classifications{
|
negative_expected.classifications.emplace_back(Classifications{
|
||||||
/*categories=*/{
|
/*categories=*/{
|
||||||
{/*index=*/0, /*score=*/0.956317, /*category_name=*/"negative"},
|
{/*index=*/0, /*score=*/0.956317, /*category_name=*/"negative"},
|
||||||
{/*index=*/1, /*score=*/0.043683, /*category_name=*/"positive"}},
|
{/*index=*/1, /*score=*/0.043683, /*category_name=*/"positive"}},
|
||||||
/*head_index=*/0,
|
/*head_index=*/0,
|
||||||
/*head_name=*/"probability"});
|
/*head_name=*/"probability"});
|
||||||
ExpectApproximatelyEqual(negative_result, negative_expected);
|
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
|
||||||
TextClassifierResult positive_result,
|
|
||||||
classifier->Classify("it's a charming and often affecting journey"));
|
|
||||||
TextClassifierResult positive_expected;
|
|
||||||
positive_expected.classifications.emplace_back(Classifications{
|
positive_expected.classifications.emplace_back(Classifications{
|
||||||
/*categories=*/{
|
/*categories=*/{
|
||||||
{/*index=*/1, /*score=*/0.999945, /*category_name=*/"positive"},
|
{/*index=*/1, /*score=*/0.999945, /*category_name=*/"positive"},
|
||||||
{/*index=*/0, /*score=*/0.000056, /*category_name=*/"negative"}},
|
{/*index=*/0, /*score=*/0.000056, /*category_name=*/"negative"}},
|
||||||
/*head_index=*/0,
|
/*head_index=*/0,
|
||||||
/*head_name=*/"probability"});
|
/*head_name=*/"probability"});
|
||||||
|
#endif // _WIN32
|
||||||
|
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
TextClassifierResult negative_result,
|
||||||
|
classifier->Classify("unflinchingly bleak and desperate"));
|
||||||
|
ExpectApproximatelyEqual(negative_result, negative_expected);
|
||||||
|
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
TextClassifierResult positive_result,
|
||||||
|
classifier->Classify("it's a charming and often affecting journey"));
|
||||||
ExpectApproximatelyEqual(positive_result, positive_expected);
|
ExpectApproximatelyEqual(positive_result, positive_expected);
|
||||||
|
|
||||||
MP_ASSERT_OK(classifier->Close());
|
MP_ASSERT_OK(classifier->Close());
|
||||||
|
@ -233,12 +251,17 @@ TEST_F(TextClassifierTest, BertLongPositive) {
|
||||||
TextClassifierResult expected;
|
TextClassifierResult expected;
|
||||||
std::vector<Category> categories;
|
std::vector<Category> categories;
|
||||||
|
|
||||||
// Predicted scores are slightly different on Mac OS.
|
// Predicted scores are slightly different across platforms.
|
||||||
#ifdef __APPLE__
|
#ifdef __APPLE__
|
||||||
categories.push_back(
|
categories.push_back(
|
||||||
{/*index=*/1, /*score=*/0.974181, /*category_name=*/"positive"});
|
{/*index=*/1, /*score=*/0.974181, /*category_name=*/"positive"});
|
||||||
categories.push_back(
|
categories.push_back(
|
||||||
{/*index=*/0, /*score=*/0.025819, /*category_name=*/"negative"});
|
{/*index=*/0, /*score=*/0.025819, /*category_name=*/"negative"});
|
||||||
|
#elif defined _WIN32
|
||||||
|
categories.push_back(
|
||||||
|
{/*index=*/1, /*score=*/0.976686, /*category_name=*/"positive"});
|
||||||
|
categories.push_back(
|
||||||
|
{/*index=*/0, /*score=*/0.023313, /*category_name=*/"negative"});
|
||||||
#else
|
#else
|
||||||
categories.push_back(
|
categories.push_back(
|
||||||
{/*index=*/1, /*score=*/0.985889, /*category_name=*/"positive"});
|
{/*index=*/1, /*score=*/0.985889, /*category_name=*/"positive"});
|
||||||
|
|
|
@ -75,7 +75,11 @@ TEST_F(EmbedderTest, SucceedsWithMobileBert) {
|
||||||
text_embedder->Embed("it's a charming and often affecting journey"));
|
text_embedder->Embed("it's a charming and often affecting journey"));
|
||||||
ASSERT_EQ(result0.embeddings.size(), 1);
|
ASSERT_EQ(result0.embeddings.size(), 1);
|
||||||
ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 512);
|
ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 512);
|
||||||
|
#ifdef _WIN32
|
||||||
|
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 21.2148f, kEpsilon);
|
||||||
|
#else
|
||||||
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 19.9016f, kEpsilon);
|
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 19.9016f, kEpsilon);
|
||||||
|
#endif // _WIN32
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto result1, text_embedder->Embed("what a great and fantastic trip"));
|
auto result1, text_embedder->Embed("what a great and fantastic trip"));
|
||||||
|
@ -87,7 +91,11 @@ TEST_F(EmbedderTest, SucceedsWithMobileBert) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
|
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
|
||||||
result1.embeddings[0]));
|
result1.embeddings[0]));
|
||||||
|
#ifdef _WIN32
|
||||||
|
EXPECT_NEAR(similarity, 0.971417, kSimilarityTolerancy);
|
||||||
|
#else
|
||||||
EXPECT_NEAR(similarity, 0.969514, kSimilarityTolerancy);
|
EXPECT_NEAR(similarity, 0.969514, kSimilarityTolerancy);
|
||||||
|
#endif // _WIN32
|
||||||
|
|
||||||
MP_ASSERT_OK(text_embedder->Close());
|
MP_ASSERT_OK(text_embedder->Close());
|
||||||
}
|
}
|
||||||
|
@ -160,8 +168,12 @@ TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
|
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
|
||||||
result1.embeddings[0]));
|
result1.embeddings[0]));
|
||||||
// TODO: The similarity should likely be lower
|
// TODO: These similarity should likely be lower
|
||||||
|
#ifdef _WIN32
|
||||||
|
EXPECT_NEAR(similarity, 0.98152, kSimilarityTolerancy);
|
||||||
|
#else
|
||||||
EXPECT_NEAR(similarity, 0.98088, kSimilarityTolerancy);
|
EXPECT_NEAR(similarity, 0.98088, kSimilarityTolerancy);
|
||||||
|
#endif // _WIN32
|
||||||
|
|
||||||
MP_ASSERT_OK(text_embedder->Close());
|
MP_ASSERT_OK(text_embedder->Close());
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user