Internal text tasks change.
PiperOrigin-RevId: 506957718
This commit is contained in:
parent
386445c8dd
commit
632a3602dd
|
@ -27,5 +27,7 @@ message TextModelType {
|
||||||
REGEX_MODEL = 2;
|
REGEX_MODEL = 2;
|
||||||
// A model taking a string tensor input.
|
// A model taking a string tensor input.
|
||||||
STRING_MODEL = 3;
|
STRING_MODEL = 3;
|
||||||
|
// A UniversalSentenceEncoder-based model.
|
||||||
|
USE_MODEL = 4;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,6 +58,8 @@ absl::StatusOr<std::string> GetCalculatorNameFromModelType(
|
||||||
TextModelType::ModelType model_type) {
|
TextModelType::ModelType model_type) {
|
||||||
switch (model_type) {
|
switch (model_type) {
|
||||||
case TextModelType::UNSPECIFIED_MODEL:
|
case TextModelType::UNSPECIFIED_MODEL:
|
||||||
|
// TODO: Support the UniversalSentenceEncoder model.
|
||||||
|
case TextModelType::USE_MODEL:
|
||||||
return CreateStatusWithPayload(
|
return CreateStatusWithPayload(
|
||||||
absl::StatusCode::kInvalidArgument, "Unspecified model type",
|
absl::StatusCode::kInvalidArgument, "Unspecified model type",
|
||||||
MediaPipeTasksStatus::kInvalidArgumentError);
|
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||||
|
@ -129,7 +131,8 @@ absl::Status ConfigureTextPreprocessingGraph(
|
||||||
options.set_model_type(model_type);
|
options.set_model_type(model_type);
|
||||||
switch (model_type) {
|
switch (model_type) {
|
||||||
case TextModelType::UNSPECIFIED_MODEL:
|
case TextModelType::UNSPECIFIED_MODEL:
|
||||||
case TextModelType::STRING_MODEL: {
|
case TextModelType::STRING_MODEL:
|
||||||
|
case TextModelType::USE_MODEL: {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case TextModelType::BERT_MODEL:
|
case TextModelType::BERT_MODEL:
|
||||||
|
@ -186,7 +189,8 @@ class TextPreprocessingGraph : public mediapipe::Subgraph {
|
||||||
auto& text_preprocessor = graph.AddNode(preprocessor_name);
|
auto& text_preprocessor = graph.AddNode(preprocessor_name);
|
||||||
switch (options.model_type()) {
|
switch (options.model_type()) {
|
||||||
case TextModelType::UNSPECIFIED_MODEL:
|
case TextModelType::UNSPECIFIED_MODEL:
|
||||||
case TextModelType::STRING_MODEL: {
|
case TextModelType::STRING_MODEL:
|
||||||
|
case TextModelType::USE_MODEL: {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case TextModelType::BERT_MODEL: {
|
case TextModelType::BERT_MODEL: {
|
||||||
|
|
|
@ -69,6 +69,7 @@ cc_test(
|
||||||
"//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
|
"//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
|
||||||
"//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
|
"//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
|
||||||
"//mediapipe/tasks/testdata/text:text_classifier_models",
|
"//mediapipe/tasks/testdata/text:text_classifier_models",
|
||||||
|
"//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":text_model_utils",
|
":text_model_utils",
|
||||||
|
|
|
@ -35,6 +35,7 @@ using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
||||||
constexpr int kNumInputTensorsForBert = 3;
|
constexpr int kNumInputTensorsForBert = 3;
|
||||||
constexpr int kNumInputTensorsForRegex = 1;
|
constexpr int kNumInputTensorsForRegex = 1;
|
||||||
constexpr int kNumInputTensorsForStringPreprocessor = 1;
|
constexpr int kNumInputTensorsForStringPreprocessor = 1;
|
||||||
|
constexpr int kNumInputTensorsForUSE = 3;
|
||||||
|
|
||||||
// Determines the ModelType for a model with int32 input tensors based
|
// Determines the ModelType for a model with int32 input tensors based
|
||||||
// on the number of input tensors. Returns an error if there is missing metadata
|
// on the number of input tensors. Returns an error if there is missing metadata
|
||||||
|
@ -78,12 +79,16 @@ absl::StatusOr<TextModelType::ModelType> GetStringTensorModelType(
|
||||||
return TextModelType::STRING_MODEL;
|
return TextModelType::STRING_MODEL;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (num_input_tensors == kNumInputTensorsForUSE) {
|
||||||
|
return TextModelType::USE_MODEL;
|
||||||
|
}
|
||||||
|
|
||||||
return CreateStatusWithPayload(
|
return CreateStatusWithPayload(
|
||||||
absl::StatusCode::kInvalidArgument,
|
absl::StatusCode::kInvalidArgument,
|
||||||
absl::Substitute("Models with string input tensors should take exactly "
|
absl::Substitute("Models with string input tensors should take exactly "
|
||||||
"$0 tensors, but found $1",
|
"$0 or $1 input tensors, but found $2",
|
||||||
kNumInputTensorsForStringPreprocessor,
|
kNumInputTensorsForStringPreprocessor,
|
||||||
num_input_tensors),
|
kNumInputTensorsForUSE, num_input_tensors),
|
||||||
MediaPipeTasksStatus::kInvalidNumInputTensorsError);
|
MediaPipeTasksStatus::kInvalidNumInputTensorsError);
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -57,6 +57,8 @@ constexpr absl::string_view kRegexOneEmbeddingModel =
|
||||||
// Classification model that takes a string tensor and outputs a bool tensor.
|
// Classification model that takes a string tensor and outputs a bool tensor.
|
||||||
constexpr absl::string_view kStringToBoolModelPath =
|
constexpr absl::string_view kStringToBoolModelPath =
|
||||||
"test_model_text_classifier_bool_output.tflite";
|
"test_model_text_classifier_bool_output.tflite";
|
||||||
|
constexpr char kUniversalSentenceEncoderModel[] =
|
||||||
|
"universal_sentence_encoder_qa_with_metadata.tflite";
|
||||||
|
|
||||||
std::string GetFullPath(absl::string_view file_name) {
|
std::string GetFullPath(absl::string_view file_name) {
|
||||||
return JoinPath("./", kTestDataDirectory, file_name);
|
return JoinPath("./", kTestDataDirectory, file_name);
|
||||||
|
@ -105,4 +107,10 @@ TEST_F(TextModelUtilsTest, StringInputModelTest) {
|
||||||
ASSERT_EQ(model_type, TextModelType::STRING_MODEL);
|
ASSERT_EQ(model_type, TextModelType::STRING_MODEL);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TextModelUtilsTest, USEModelTest) {
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto model_type,
|
||||||
|
GetModelTypeFromFile(kUniversalSentenceEncoderModel));
|
||||||
|
ASSERT_EQ(model_type, TextModelType::USE_MODEL);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mediapipe::tasks::text::utils
|
} // namespace mediapipe::tasks::text::utils
|
||||||
|
|
Loading…
Reference in New Issue
Block a user