Internal text tasks change.
PiperOrigin-RevId: 506957718
This commit is contained in:
parent
386445c8dd
commit
632a3602dd
|
@ -27,5 +27,7 @@ message TextModelType {
|
|||
REGEX_MODEL = 2;
|
||||
// A model taking a string tensor input.
|
||||
STRING_MODEL = 3;
|
||||
// A UniversalSentenceEncoder-based model.
|
||||
USE_MODEL = 4;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -58,6 +58,8 @@ absl::StatusOr<std::string> GetCalculatorNameFromModelType(
|
|||
TextModelType::ModelType model_type) {
|
||||
switch (model_type) {
|
||||
case TextModelType::UNSPECIFIED_MODEL:
|
||||
// TODO: Support the UniversalSentenceEncoder model.
|
||||
case TextModelType::USE_MODEL:
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument, "Unspecified model type",
|
||||
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||
|
@ -129,7 +131,8 @@ absl::Status ConfigureTextPreprocessingGraph(
|
|||
options.set_model_type(model_type);
|
||||
switch (model_type) {
|
||||
case TextModelType::UNSPECIFIED_MODEL:
|
||||
case TextModelType::STRING_MODEL: {
|
||||
case TextModelType::STRING_MODEL:
|
||||
case TextModelType::USE_MODEL: {
|
||||
break;
|
||||
}
|
||||
case TextModelType::BERT_MODEL:
|
||||
|
@ -186,7 +189,8 @@ class TextPreprocessingGraph : public mediapipe::Subgraph {
|
|||
auto& text_preprocessor = graph.AddNode(preprocessor_name);
|
||||
switch (options.model_type()) {
|
||||
case TextModelType::UNSPECIFIED_MODEL:
|
||||
case TextModelType::STRING_MODEL: {
|
||||
case TextModelType::STRING_MODEL:
|
||||
case TextModelType::USE_MODEL: {
|
||||
break;
|
||||
}
|
||||
case TextModelType::BERT_MODEL: {
|
||||
|
|
|
@ -69,6 +69,7 @@ cc_test(
|
|||
"//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
|
||||
"//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
|
||||
"//mediapipe/tasks/testdata/text:text_classifier_models",
|
||||
"//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa",
|
||||
],
|
||||
deps = [
|
||||
":text_model_utils",
|
||||
|
|
|
@ -35,6 +35,7 @@ using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
|||
constexpr int kNumInputTensorsForBert = 3;
|
||||
constexpr int kNumInputTensorsForRegex = 1;
|
||||
constexpr int kNumInputTensorsForStringPreprocessor = 1;
|
||||
constexpr int kNumInputTensorsForUSE = 3;
|
||||
|
||||
// Determines the ModelType for a model with int32 input tensors based
|
||||
// on the number of input tensors. Returns an error if there is missing metadata
|
||||
|
@ -78,12 +79,16 @@ absl::StatusOr<TextModelType::ModelType> GetStringTensorModelType(
|
|||
return TextModelType::STRING_MODEL;
|
||||
}
|
||||
|
||||
if (num_input_tensors == kNumInputTensorsForUSE) {
|
||||
return TextModelType::USE_MODEL;
|
||||
}
|
||||
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::Substitute("Models with string input tensors should take exactly "
|
||||
"$0 tensors, but found $1",
|
||||
"$0 or $1 input tensors, but found $2",
|
||||
kNumInputTensorsForStringPreprocessor,
|
||||
num_input_tensors),
|
||||
kNumInputTensorsForUSE, num_input_tensors),
|
||||
MediaPipeTasksStatus::kInvalidNumInputTensorsError);
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -57,6 +57,8 @@ constexpr absl::string_view kRegexOneEmbeddingModel =
|
|||
// Classification model that takes a string tensor and outputs a bool tensor.
|
||||
constexpr absl::string_view kStringToBoolModelPath =
|
||||
"test_model_text_classifier_bool_output.tflite";
|
||||
constexpr char kUniversalSentenceEncoderModel[] =
|
||||
"universal_sentence_encoder_qa_with_metadata.tflite";
|
||||
|
||||
std::string GetFullPath(absl::string_view file_name) {
|
||||
return JoinPath("./", kTestDataDirectory, file_name);
|
||||
|
@ -105,4 +107,10 @@ TEST_F(TextModelUtilsTest, StringInputModelTest) {
|
|||
ASSERT_EQ(model_type, TextModelType::STRING_MODEL);
|
||||
}
|
||||
|
||||
TEST_F(TextModelUtilsTest, USEModelTest) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto model_type,
|
||||
GetModelTypeFromFile(kUniversalSentenceEncoderModel));
|
||||
ASSERT_EQ(model_type, TextModelType::USE_MODEL);
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::text::utils
|
||||
|
|
Loading…
Reference in New Issue
Block a user