Add a string-to-bool test model to TextClassifier.
PiperOrigin-RevId: 479803799
This commit is contained in:
parent 08ae99688c
commit 1ab332835a

@@ -320,6 +320,8 @@ cc_library(
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
+        "@org_tensorflow//tensorflow/lite:framework_stable",
+        "@org_tensorflow//tensorflow/lite:string_util",
         "@org_tensorflow//tensorflow/lite/c:c_api_types",
         "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
     ],
 )

@@ -21,8 +21,10 @@
 #include "absl/status/statusor.h"
 #include "mediapipe/framework/formats/tensor.h"
 #include "mediapipe/framework/port/ret_check.h"
 #include "tensorflow/lite/c/c_api_types.h"
 #include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/interpreter_builder.h"
+#include "tensorflow/lite/string_util.h"
 
 namespace mediapipe {
 
@@ -39,6 +41,19 @@ void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
   std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
 }
 
+template <>
+void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
+                                         tflite::Interpreter* interpreter,
+                                         int input_tensor_index) {
+  const char* input_tensor_buffer =
+      input_tensor.GetCpuReadView().buffer<char>();
+  tflite::DynamicBuffer dynamic_buffer;
+  dynamic_buffer.AddString(input_tensor_buffer,
+                           input_tensor.shape().num_elements());
+  dynamic_buffer.WriteToTensorAsVector(
+      interpreter->tensor(interpreter->inputs()[input_tensor_index]));
+}
+
 template <typename T>
 void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
                                      int output_tensor_index,

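The specialization above treats the entire kChar tensor as a single string: DynamicBuffer::AddString copies num_elements() bytes, and WriteToTensorAsVector serializes them into the TF Lite string tensor's packed format. As a minimal sketch of how text could be packaged for this path — the MakeStringTensor helper is hypothetical, not part of this change:

#include <cstring>
#include <string>

#include "mediapipe/framework/formats/tensor.h"

// Hypothetical helper (illustration only): wraps the raw bytes of `text` in
// a kChar Tensor so the char specialization above can forward them to a
// TF Lite string tensor.
mediapipe::Tensor MakeStringTensor(const std::string& text) {
  mediapipe::Tensor tensor(
      mediapipe::Tensor::ElementType::kChar,
      mediapipe::Tensor::Shape{static_cast<int>(text.size())});
  {
    // Scope the write view so the tensor can be moved out afterwards.
    auto view = tensor.GetCpuWriteView();
    std::memcpy(view.buffer<char>(), text.data(), text.size());
  }
  return tensor;
}
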
@@ -87,13 +102,13 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
         break;
       }
       case TfLiteType::kTfLiteUInt8: {
-        CopyTensorBufferToInterpreter<uint8>(input_tensors[i],
-                                             interpreter_.get(), i);
+        CopyTensorBufferToInterpreter<uint8_t>(input_tensors[i],
+                                               interpreter_.get(), i);
         break;
       }
       case TfLiteType::kTfLiteInt8: {
-        CopyTensorBufferToInterpreter<int8>(input_tensors[i],
-                                            interpreter_.get(), i);
+        CopyTensorBufferToInterpreter<int8_t>(input_tensors[i],
+                                              interpreter_.get(), i);
         break;
       }
       case TfLiteType::kTfLiteInt32: {

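(The uint8 → uint8_t and int8 → int8_t changes are an incidental cleanup bundled into this commit: the copy helpers now use the standard <cstdint> fixed-width types instead of the legacy int8/uint8 aliases.)
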
@@ -101,6 +116,14 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
                                                interpreter_.get(), i);
         break;
       }
+      case TfLiteType::kTfLiteString: {
+        CopyTensorBufferToInterpreter<char>(input_tensors[i],
+                                            interpreter_.get(), i);
+        break;
+      }
+      case TfLiteType::kTfLiteBool:
+        // No current use-case for copying MediaPipe Tensors with bool type to
+        // TfLiteTensors.
       default:
         return absl::InvalidArgumentError(
             absl::StrCat("Unsupported input tensor type:", input_tensor_type));

@@ -146,6 +169,15 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
         CopyTensorBufferFromInterpreter<int32_t>(interpreter_.get(), i,
                                                  &output_tensors.back());
         break;
+      case TfLiteType::kTfLiteBool:
+        output_tensors.emplace_back(Tensor::ElementType::kBool, shape,
+                                    Tensor::QuantizationParameters{1.0f, 0});
+        CopyTensorBufferFromInterpreter<bool>(interpreter_.get(), i,
+                                              &output_tensors.back());
+        break;
+      case TfLiteType::kTfLiteString:
+        // No current use-case for copying TfLiteTensors with string type to
+        // MediaPipe Tensors.
       default:
         return absl::InvalidArgumentError(
             absl::StrCat("Unsupported output tensor type:",

@@ -87,6 +87,9 @@ absl::Status TensorsDequantizationCalculator::Process(CalculatorContext* cc) {
       case Tensor::ElementType::kInt8:
         Dequantize<int8>(input_tensor, &output_tensors->back());
         break;
+      case Tensor::ElementType::kBool:
+        Dequantize<bool>(input_tensor, &output_tensors->back());
+        break;
       default:
         return absl::InvalidArgumentError(absl::StrCat(
             "Unsupported input tensor type: ", input_tensor.element_type()));

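Bool outputs work end to end because the runner (earlier in this diff) tags them with identity quantization parameters {1.0f, 0}; the usual dequantization formula out = scale * (in - zero_point) then maps true to 1.0f and false to 0.0f. A standalone sketch of that arithmetic — the calculator's actual Dequantize<T> template is not shown on this page, so the signature here is illustrative only:

#include <vector>

// Illustrative only: mirrors the dequantization arithmetic as applied to a
// bool tensor tagged with QuantizationParameters{1.0f, 0}.
std::vector<float> DequantizeBoolValues(const std::vector<bool>& values,
                                        float scale, int zero_point) {
  std::vector<float> result;
  result.reserve(values.size());
  for (bool v : values) {
    // With scale 1.0 and zero point 0: true -> 1.0f, false -> 0.0f.
    result.push_back(scale * (static_cast<int>(v) - zero_point));
  }
  return result;
}
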
@@ -124,5 +124,15 @@ TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithInt8Tensors) {
   ValidateResult(GetOutput(), {-1.007874, 0, 1});
 }
 
+TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithBoolTensors) {
+  std::vector<bool> tensor = {true, false, true};
+  PushTensor(Tensor::ElementType::kBool, tensor,
+             Tensor::QuantizationParameters{1.0f, 0});
+
+  MP_ASSERT_OK(runner_.Run());
+
+  ValidateResult(GetOutput(), {1, 0, 1});
+}
+
 }  // namespace
 }  // namespace mediapipe

@@ -97,8 +97,8 @@ class Tensor {
     kUInt8,
     kInt8,
     kInt32,
-    // TODO: Update the inference runner to handle kTfLiteString.
-    kChar
+    kChar,
+    kBool
   };
   struct Shape {
     Shape() = default;

@@ -330,6 +330,8 @@ class Tensor {
         return sizeof(int32_t);
       case ElementType::kChar:
         return sizeof(char);
+      case ElementType::kBool:
+        return sizeof(bool);
     }
   }
   int bytes() const { return shape_.num_elements() * element_size(); }

@@ -29,6 +29,9 @@ TEST(General, TestDataTypes) {
 
   Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4});
   EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char));
+
+  Tensor t_bool(Tensor::ElementType::kBool, Tensor::Shape{2, 3});
+  EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool));
 }
 
 TEST(Cpu, TestMemoryAllocation) {

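For reference, a small sketch of the kBool round trip the test above implies, using the Tensor CPU view API that already appears elsewhere in this diff (buffer<T>() on a read/write view):

#include "mediapipe/framework/formats/tensor.h"

// Sketch: fill and read back a kBool tensor. bytes() is
// num_elements() * sizeof(bool), matching the EXPECT_EQ above.
void BoolTensorRoundTrip() {
  mediapipe::Tensor t_bool(mediapipe::Tensor::ElementType::kBool,
                           mediapipe::Tensor::Shape{2, 3});
  {
    auto write_view = t_bool.GetCpuWriteView();
    bool* data = write_view.buffer<bool>();
    for (int i = 0; i < t_bool.shape().num_elements(); ++i) {
      data[i] = (i % 2 == 0);  // Alternate true/false.
    }
  }
  auto read_view = t_bool.GetCpuReadView();
  const bool* data = read_view.buffer<bool>();
  (void)data;  // data[0] == true, data[1] == false, ...
}
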
@@ -123,15 +123,17 @@ absl::StatusOr<ClassificationHeadsProperties> GetClassificationHeadsProperties(
     const auto* tensor =
         primary_subgraph->tensors()->Get(primary_subgraph->outputs()->Get(i));
     if (tensor->type() != tflite::TensorType_FLOAT32 &&
-        tensor->type() != tflite::TensorType_UINT8) {
+        tensor->type() != tflite::TensorType_UINT8 &&
+        tensor->type() != tflite::TensorType_BOOL) {
       return CreateStatusWithPayload(
           absl::StatusCode::kInvalidArgument,
           absl::StrFormat("Expected output tensor at index %d to have type "
-                          "UINT8 or FLOAT32, found %s instead.",
+                          "UINT8 or FLOAT32 or BOOL, found %s instead.",
                           i, tflite::EnumNameTensorType(tensor->type())),
           MediaPipeTasksStatus::kInvalidOutputTensorTypeError);
     }
-    if (tensor->type() == tflite::TensorType_UINT8) {
+    if (tensor->type() == tflite::TensorType_UINT8 ||
+        tensor->type() == tflite::TensorType_BOOL) {
       num_quantized_tensors++;
     }
   }

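Counting BOOL output heads in num_quantized_tensors routes them through the same postprocessing path as UINT8 heads, which is presumably why TensorsDequantizationCalculator learns to handle kBool in this same change: with the identity parameters {1.0f, 0}, dequantization turns a bool head into 0.0/1.0 classification scores.
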
@@ -31,6 +31,8 @@ message TextPreprocessingGraphOptions {
     BERT_PREPROCESSOR = 1;
     // Used for the RegexPreprocessorCalculator.
     REGEX_PREPROCESSOR = 2;
+    // Used for the TextToTensorCalculator.
+    STRING_PREPROCESSOR = 3;
   }
   optional PreprocessorType preprocessor_type = 1;
 
@@ -65,6 +65,8 @@ absl::StatusOr<std::string> GetCalculatorNameFromPreprocessorType(
       return "BertPreprocessorCalculator";
     case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR:
       return "RegexPreprocessorCalculator";
+    case TextPreprocessingGraphOptions::STRING_PREPROCESSOR:
+      return "TextToTensorCalculator";
   }
 }
 
@@ -91,11 +93,7 @@ GetPreprocessorType(const ModelResources& model_resources) {
         MediaPipeTasksStatus::kInvalidInputTensorTypeError);
   }
   if (all_string_tensors) {
-    // TODO: Support a TextToTensor calculator for string tensors.
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        "String tensors are not supported yet",
-        MediaPipeTasksStatus::kInvalidInputTensorTypeError);
+    return TextPreprocessingGraphOptions::STRING_PREPROCESSOR;
   }
 
   // Otherwise, all tensors should have type int32

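With the new enum value, the calculator-name mapping, and this selection change in place, a model whose inputs are all string tensors ends up configured roughly as follows. This is a hand-written sketch using the generated proto API; the C++ namespace is assumed, since this page does not show the proto's package or header path:

// Assumed namespace; adjust to the actual generated proto location.
using ::mediapipe::tasks::components::proto::TextPreprocessingGraphOptions;

TextPreprocessingGraphOptions MakeStringPreprocessorOptions() {
  TextPreprocessingGraphOptions options;
  options.set_preprocessor_type(
      TextPreprocessingGraphOptions::STRING_PREPROCESSOR);
  // max_seq_len stays unset: per the switch added below, only the BERT and
  // regex preprocessors read it.
  return options;
}
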
@@ -185,10 +183,19 @@ absl::Status ConfigureTextPreprocessingSubgraph(
       TextPreprocessingGraphOptions::PreprocessorType preprocessor_type,
       GetPreprocessorType(model_resources));
   options.set_preprocessor_type(preprocessor_type);
-  ASSIGN_OR_RETURN(
-      int max_seq_len,
-      GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0]));
-  options.set_max_seq_len(max_seq_len);
+  switch (preprocessor_type) {
+    case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
+    case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
+      break;
+    }
+    case TextPreprocessingGraphOptions::BERT_PREPROCESSOR:
+    case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: {
+      ASSIGN_OR_RETURN(
+          int max_seq_len,
+          GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0]));
+      options.set_max_seq_len(max_seq_len);
+    }
+  }
 
   return absl::OkStatus();
 }

@@ -236,7 +243,8 @@ class TextPreprocessingSubgraph : public mediapipe::Subgraph {
         GetCalculatorNameFromPreprocessorType(options.preprocessor_type()));
     auto& text_preprocessor = graph.AddNode(preprocessor_name);
     switch (options.preprocessor_type()) {
-      case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: {
+      case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
+      case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
         break;
       }
       case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: {

mediapipe/tasks/testdata/text/BUILD (1 change, vendored)

@@ -27,6 +27,7 @@ mediapipe_files(srcs = [
     "albert_with_metadata.tflite",
     "bert_text_classifier.tflite",
     "mobilebert_with_metadata.tflite",
+    "test_model_text_classifier_bool_output.tflite",
     "test_model_text_classifier_with_regex_tokenizer.tflite",
 ])
 
third_party/external_files.bzl (6 changes, vendored)

@@ -562,6 +562,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/test_model_add_op.tflite?generation=1661875950076192"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_test_model_text_classifier_bool_output_tflite",
+        sha256 = "09877ac6d718d78da6380e21fe8179854909d116632d6d770c12f8a51792e310",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/test_model_text_classifier_bool_output.tflite?generation=1664904110313163"],
+    )
+
     http_file(
         name = "com_google_mediapipe_test_model_text_classifier_with_regex_tokenizer_tflite",
         sha256 = "cb12618d084b813cb7b90ceb39c9fe4b18dae4de9880b912cdcd4b577cd65b4f",

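The new bool-output test model is fetched as a sha256-pinned http_file and vendored into mediapipe/tasks/testdata/text alongside the existing classifier models, so the TextClassifier tests can load it the same way as the other test assets.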