Open-sources a unit test.
PiperOrigin-RevId: 493184055
This commit is contained in:
		
							parent
							
								
									3174b20fbe
								
							
						
					
					
						commit
						af43687f2e
					
				| 
						 | 
				
			
			@ -38,10 +38,7 @@ limitations under the License.
 | 
			
		|||
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
 | 
			
		||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe {
 | 
			
		||||
namespace tasks {
 | 
			
		||||
namespace text {
 | 
			
		||||
namespace text_classifier {
 | 
			
		||||
namespace mediapipe::tasks::text::text_classifier {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
using ::mediapipe::file::JoinPath;
 | 
			
		||||
| 
						 | 
				
			
			@ -88,6 +85,8 @@ void ExpectApproximatelyEqual(const TextClassifierResult& actual,
 | 
			
		|||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
class TextClassifierTest : public tflite_shims::testing::Test {};
 | 
			
		||||
 | 
			
		||||
TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) {
 | 
			
		||||
| 
						 | 
				
			
			@ -217,8 +216,42 @@ TEST_F(TextClassifierTest, TextClassifierWithStringToBool) {
 | 
			
		|||
  MP_ASSERT_OK(classifier->Close());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
}  // namespace text_classifier
 | 
			
		||||
}  // namespace text
 | 
			
		||||
}  // namespace tasks
 | 
			
		||||
}  // namespace mediapipe
 | 
			
		||||
TEST_F(TextClassifierTest, BertLongPositive) {
 | 
			
		||||
  std::stringstream ss_for_positive_review;
 | 
			
		||||
  ss_for_positive_review
 | 
			
		||||
      << "it's a charming and often affecting journey and this is a long";
 | 
			
		||||
  for (int i = 0; i < kMaxSeqLen; ++i) {
 | 
			
		||||
    ss_for_positive_review << " long";
 | 
			
		||||
  }
 | 
			
		||||
  ss_for_positive_review << " movie review";
 | 
			
		||||
  auto options = std::make_unique<TextClassifierOptions>();
 | 
			
		||||
  options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
 | 
			
		||||
                          TextClassifier::Create(std::move(options)));
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult result,
 | 
			
		||||
                          classifier->Classify(ss_for_positive_review.str()));
 | 
			
		||||
  TextClassifierResult expected;
 | 
			
		||||
  std::vector<Category> categories;
 | 
			
		||||
 | 
			
		||||
// Predicted scores are slightly different on Mac OS.
 | 
			
		||||
#ifdef __APPLE__
 | 
			
		||||
  categories.push_back(
 | 
			
		||||
      {/*index=*/1, /*score=*/0.974181, /*category_name=*/"positive"});
 | 
			
		||||
  categories.push_back(
 | 
			
		||||
      {/*index=*/0, /*score=*/0.025819, /*category_name=*/"negative"});
 | 
			
		||||
#else
 | 
			
		||||
  categories.push_back(
 | 
			
		||||
      {/*index=*/1, /*score=*/0.985889, /*category_name=*/"positive"});
 | 
			
		||||
  categories.push_back(
 | 
			
		||||
      {/*index=*/0, /*score=*/0.014112, /*category_name=*/"negative"});
 | 
			
		||||
#endif  // __APPLE__
 | 
			
		||||
 | 
			
		||||
  expected.classifications.emplace_back(
 | 
			
		||||
      Classifications{/*categories=*/categories,
 | 
			
		||||
                      /*head_index=*/0,
 | 
			
		||||
                      /*head_name=*/"probability"});
 | 
			
		||||
  ExpectApproximatelyEqual(result, expected);
 | 
			
		||||
  MP_ASSERT_OK(classifier->Close());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe::tasks::text::text_classifier
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user