Internal text tasks change.

PiperOrigin-RevId: 506957718
This commit is contained in:
MediaPipe Team 2023-02-03 11:49:14 -08:00 committed by Copybara-Service
parent 386445c8dd
commit 632a3602dd
5 changed files with 24 additions and 4 deletions

View File

@ -27,5 +27,7 @@ message TextModelType {
REGEX_MODEL = 2; REGEX_MODEL = 2;
// A model taking a string tensor input. // A model taking a string tensor input.
STRING_MODEL = 3; STRING_MODEL = 3;
// A UniversalSentenceEncoder-based model.
USE_MODEL = 4;
} }
} }

View File

@ -58,6 +58,8 @@ absl::StatusOr<std::string> GetCalculatorNameFromModelType(
TextModelType::ModelType model_type) { TextModelType::ModelType model_type) {
switch (model_type) { switch (model_type) {
case TextModelType::UNSPECIFIED_MODEL: case TextModelType::UNSPECIFIED_MODEL:
// TODO: Support the UniversalSentenceEncoder model.
case TextModelType::USE_MODEL:
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, "Unspecified model type", absl::StatusCode::kInvalidArgument, "Unspecified model type",
MediaPipeTasksStatus::kInvalidArgumentError); MediaPipeTasksStatus::kInvalidArgumentError);
@ -129,7 +131,8 @@ absl::Status ConfigureTextPreprocessingGraph(
options.set_model_type(model_type); options.set_model_type(model_type);
switch (model_type) { switch (model_type) {
case TextModelType::UNSPECIFIED_MODEL: case TextModelType::UNSPECIFIED_MODEL:
case TextModelType::STRING_MODEL: { case TextModelType::STRING_MODEL:
case TextModelType::USE_MODEL: {
break; break;
} }
case TextModelType::BERT_MODEL: case TextModelType::BERT_MODEL:
@ -186,7 +189,8 @@ class TextPreprocessingGraph : public mediapipe::Subgraph {
auto& text_preprocessor = graph.AddNode(preprocessor_name); auto& text_preprocessor = graph.AddNode(preprocessor_name);
switch (options.model_type()) { switch (options.model_type()) {
case TextModelType::UNSPECIFIED_MODEL: case TextModelType::UNSPECIFIED_MODEL:
case TextModelType::STRING_MODEL: { case TextModelType::STRING_MODEL:
case TextModelType::USE_MODEL: {
break; break;
} }
case TextModelType::BERT_MODEL: { case TextModelType::BERT_MODEL: {

View File

@ -69,6 +69,7 @@ cc_test(
"//mediapipe/tasks/testdata/text:mobilebert_embedding_model", "//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
"//mediapipe/tasks/testdata/text:regex_embedding_with_metadata", "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
"//mediapipe/tasks/testdata/text:text_classifier_models", "//mediapipe/tasks/testdata/text:text_classifier_models",
"//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa",
], ],
deps = [ deps = [
":text_model_utils", ":text_model_utils",

View File

@ -35,6 +35,7 @@ using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
constexpr int kNumInputTensorsForBert = 3; constexpr int kNumInputTensorsForBert = 3;
constexpr int kNumInputTensorsForRegex = 1; constexpr int kNumInputTensorsForRegex = 1;
constexpr int kNumInputTensorsForStringPreprocessor = 1; constexpr int kNumInputTensorsForStringPreprocessor = 1;
constexpr int kNumInputTensorsForUSE = 3;
// Determines the ModelType for a model with int32 input tensors based // Determines the ModelType for a model with int32 input tensors based
// on the number of input tensors. Returns an error if there is missing metadata // on the number of input tensors. Returns an error if there is missing metadata
@ -78,12 +79,16 @@ absl::StatusOr<TextModelType::ModelType> GetStringTensorModelType(
return TextModelType::STRING_MODEL; return TextModelType::STRING_MODEL;
} }
if (num_input_tensors == kNumInputTensorsForUSE) {
return TextModelType::USE_MODEL;
}
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
absl::Substitute("Models with string input tensors should take exactly " absl::Substitute("Models with string input tensors should take exactly "
"$0 tensors, but found $1", "$0 or $1 input tensors, but found $2",
kNumInputTensorsForStringPreprocessor, kNumInputTensorsForStringPreprocessor,
num_input_tensors), kNumInputTensorsForUSE, num_input_tensors),
MediaPipeTasksStatus::kInvalidNumInputTensorsError); MediaPipeTasksStatus::kInvalidNumInputTensorsError);
} }
} // namespace } // namespace

View File

@ -57,6 +57,8 @@ constexpr absl::string_view kRegexOneEmbeddingModel =
// Classification model that takes a string tensor and outputs a bool tensor. // Classification model that takes a string tensor and outputs a bool tensor.
constexpr absl::string_view kStringToBoolModelPath = constexpr absl::string_view kStringToBoolModelPath =
"test_model_text_classifier_bool_output.tflite"; "test_model_text_classifier_bool_output.tflite";
constexpr char kUniversalSentenceEncoderModel[] =
"universal_sentence_encoder_qa_with_metadata.tflite";
std::string GetFullPath(absl::string_view file_name) { std::string GetFullPath(absl::string_view file_name) {
return JoinPath("./", kTestDataDirectory, file_name); return JoinPath("./", kTestDataDirectory, file_name);
@ -105,4 +107,10 @@ TEST_F(TextModelUtilsTest, StringInputModelTest) {
ASSERT_EQ(model_type, TextModelType::STRING_MODEL); ASSERT_EQ(model_type, TextModelType::STRING_MODEL);
} }
TEST_F(TextModelUtilsTest, USEModelTest) {
MP_ASSERT_OK_AND_ASSIGN(auto model_type,
GetModelTypeFromFile(kUniversalSentenceEncoderModel));
ASSERT_EQ(model_type, TextModelType::USE_MODEL);
}
} // namespace mediapipe::tasks::text::utils } // namespace mediapipe::tasks::text::utils