Open-sources a unit test.
PiperOrigin-RevId: 493184055
This commit is contained in:
parent
3174b20fbe
commit
af43687f2e
|
@ -38,10 +38,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
|
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe::tasks::text::text_classifier {
|
||||||
namespace tasks {
|
|
||||||
namespace text {
|
|
||||||
namespace text_classifier {
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::mediapipe::file::JoinPath;
|
using ::mediapipe::file::JoinPath;
|
||||||
|
@ -88,6 +85,8 @@ void ExpectApproximatelyEqual(const TextClassifierResult& actual,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
class TextClassifierTest : public tflite_shims::testing::Test {};
|
class TextClassifierTest : public tflite_shims::testing::Test {};
|
||||||
|
|
||||||
TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) {
|
TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) {
|
||||||
|
@ -217,8 +216,42 @@ TEST_F(TextClassifierTest, TextClassifierWithStringToBool) {
|
||||||
MP_ASSERT_OK(classifier->Close());
|
MP_ASSERT_OK(classifier->Close());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
TEST_F(TextClassifierTest, BertLongPositive) {
|
||||||
} // namespace text_classifier
|
std::stringstream ss_for_positive_review;
|
||||||
} // namespace text
|
ss_for_positive_review
|
||||||
} // namespace tasks
|
<< "it's a charming and often affecting journey and this is a long";
|
||||||
} // namespace mediapipe
|
for (int i = 0; i < kMaxSeqLen; ++i) {
|
||||||
|
ss_for_positive_review << " long";
|
||||||
|
}
|
||||||
|
ss_for_positive_review << " movie review";
|
||||||
|
auto options = std::make_unique<TextClassifierOptions>();
|
||||||
|
options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
|
||||||
|
TextClassifier::Create(std::move(options)));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult result,
|
||||||
|
classifier->Classify(ss_for_positive_review.str()));
|
||||||
|
TextClassifierResult expected;
|
||||||
|
std::vector<Category> categories;
|
||||||
|
|
||||||
|
// Predicted scores are slightly different on Mac OS.
|
||||||
|
#ifdef __APPLE__
|
||||||
|
categories.push_back(
|
||||||
|
{/*index=*/1, /*score=*/0.974181, /*category_name=*/"positive"});
|
||||||
|
categories.push_back(
|
||||||
|
{/*index=*/0, /*score=*/0.025819, /*category_name=*/"negative"});
|
||||||
|
#else
|
||||||
|
categories.push_back(
|
||||||
|
{/*index=*/1, /*score=*/0.985889, /*category_name=*/"positive"});
|
||||||
|
categories.push_back(
|
||||||
|
{/*index=*/0, /*score=*/0.014112, /*category_name=*/"negative"});
|
||||||
|
#endif // __APPLE__
|
||||||
|
|
||||||
|
expected.classifications.emplace_back(
|
||||||
|
Classifications{/*categories=*/categories,
|
||||||
|
/*head_index=*/0,
|
||||||
|
/*head_name=*/"probability"});
|
||||||
|
ExpectApproximatelyEqual(result, expected);
|
||||||
|
MP_ASSERT_OK(classifier->Close());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::text::text_classifier
|
||||||
|
|
Loading…
Reference in New Issue
Block a user