Add a string-to-bool test model to TextClassifier.

PiperOrigin-RevId: 479803799
This commit is contained in:
MediaPipe Team 2022-10-08 09:42:20 -07:00 committed by Copybara-Service
parent 08ae99688c
commit 1ab332835a
11 changed files with 90 additions and 19 deletions

View File

@ -320,6 +320,8 @@ cc_library(
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite:framework_stable", "@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite:string_util",
"@org_tensorflow//tensorflow/lite/c:c_api_types",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver", "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
], ],
) )

View File

@ -21,8 +21,10 @@
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h" #include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/string_util.h"
namespace mediapipe { namespace mediapipe {
@ -39,6 +41,19 @@ void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes()); std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
} }
template <>
void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
tflite::Interpreter* interpreter,
int input_tensor_index) {
const char* input_tensor_buffer =
input_tensor.GetCpuReadView().buffer<char>();
tflite::DynamicBuffer dynamic_buffer;
dynamic_buffer.AddString(input_tensor_buffer,
input_tensor.shape().num_elements());
dynamic_buffer.WriteToTensorAsVector(
interpreter->tensor(interpreter->inputs()[input_tensor_index]));
}
template <typename T> template <typename T>
void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter, void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
int output_tensor_index, int output_tensor_index,
@ -87,12 +102,12 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
break; break;
} }
case TfLiteType::kTfLiteUInt8: { case TfLiteType::kTfLiteUInt8: {
CopyTensorBufferToInterpreter<uint8>(input_tensors[i], CopyTensorBufferToInterpreter<uint8_t>(input_tensors[i],
interpreter_.get(), i); interpreter_.get(), i);
break; break;
} }
case TfLiteType::kTfLiteInt8: { case TfLiteType::kTfLiteInt8: {
CopyTensorBufferToInterpreter<int8>(input_tensors[i], CopyTensorBufferToInterpreter<int8_t>(input_tensors[i],
interpreter_.get(), i); interpreter_.get(), i);
break; break;
} }
@ -101,6 +116,14 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
interpreter_.get(), i); interpreter_.get(), i);
break; break;
} }
case TfLiteType::kTfLiteString: {
CopyTensorBufferToInterpreter<char>(input_tensors[i],
interpreter_.get(), i);
break;
}
case TfLiteType::kTfLiteBool:
// No current use-case for copying MediaPipe Tensors with bool type to
// TfLiteTensors.
default: default:
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
absl::StrCat("Unsupported input tensor type:", input_tensor_type)); absl::StrCat("Unsupported input tensor type:", input_tensor_type));
@ -146,6 +169,15 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
CopyTensorBufferFromInterpreter<int32_t>(interpreter_.get(), i, CopyTensorBufferFromInterpreter<int32_t>(interpreter_.get(), i,
&output_tensors.back()); &output_tensors.back());
break; break;
case TfLiteType::kTfLiteBool:
output_tensors.emplace_back(Tensor::ElementType::kBool, shape,
Tensor::QuantizationParameters{1.0f, 0});
CopyTensorBufferFromInterpreter<bool>(interpreter_.get(), i,
&output_tensors.back());
break;
case TfLiteType::kTfLiteString:
// No current use-case for copying TfLiteTensors with string type to
// MediaPipe Tensors.
default: default:
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
absl::StrCat("Unsupported output tensor type:", absl::StrCat("Unsupported output tensor type:",

View File

@ -87,6 +87,9 @@ absl::Status TensorsDequantizationCalculator::Process(CalculatorContext* cc) {
case Tensor::ElementType::kInt8: case Tensor::ElementType::kInt8:
Dequantize<int8>(input_tensor, &output_tensors->back()); Dequantize<int8>(input_tensor, &output_tensors->back());
break; break;
case Tensor::ElementType::kBool:
Dequantize<bool>(input_tensor, &output_tensors->back());
break;
default: default:
return absl::InvalidArgumentError(absl::StrCat( return absl::InvalidArgumentError(absl::StrCat(
"Unsupported input tensor type: ", input_tensor.element_type())); "Unsupported input tensor type: ", input_tensor.element_type()));

View File

@ -124,5 +124,15 @@ TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithInt8Tensors) {
ValidateResult(GetOutput(), {-1.007874, 0, 1}); ValidateResult(GetOutput(), {-1.007874, 0, 1});
} }
TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithBoolTensors) {
std::vector<bool> tensor = {true, false, true};
PushTensor(Tensor::ElementType::kBool, tensor,
Tensor::QuantizationParameters{1.0f, 0});
MP_ASSERT_OK(runner_.Run());
ValidateResult(GetOutput(), {1, 0, 1});
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -97,8 +97,8 @@ class Tensor {
kUInt8, kUInt8,
kInt8, kInt8,
kInt32, kInt32,
// TODO: Update the inference runner to handle kTfLiteString. kChar,
kChar kBool
}; };
struct Shape { struct Shape {
Shape() = default; Shape() = default;
@ -330,6 +330,8 @@ class Tensor {
return sizeof(int32_t); return sizeof(int32_t);
case ElementType::kChar: case ElementType::kChar:
return sizeof(char); return sizeof(char);
case ElementType::kBool:
return sizeof(bool);
} }
} }
int bytes() const { return shape_.num_elements() * element_size(); } int bytes() const { return shape_.num_elements() * element_size(); }

View File

@ -29,6 +29,9 @@ TEST(General, TestDataTypes) {
Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4}); Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4});
EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char)); EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char));
Tensor t_bool(Tensor::ElementType::kBool, Tensor::Shape{2, 3});
EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool));
} }
TEST(Cpu, TestMemoryAllocation) { TEST(Cpu, TestMemoryAllocation) {

View File

@ -123,15 +123,17 @@ absl::StatusOr<ClassificationHeadsProperties> GetClassificationHeadsProperties(
const auto* tensor = const auto* tensor =
primary_subgraph->tensors()->Get(primary_subgraph->outputs()->Get(i)); primary_subgraph->tensors()->Get(primary_subgraph->outputs()->Get(i));
if (tensor->type() != tflite::TensorType_FLOAT32 && if (tensor->type() != tflite::TensorType_FLOAT32 &&
tensor->type() != tflite::TensorType_UINT8) { tensor->type() != tflite::TensorType_UINT8 &&
tensor->type() != tflite::TensorType_BOOL) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
absl::StrFormat("Expected output tensor at index %d to have type " absl::StrFormat("Expected output tensor at index %d to have type "
"UINT8 or FLOAT32, found %s instead.", "UINT8 or FLOAT32 or BOOL, found %s instead.",
i, tflite::EnumNameTensorType(tensor->type())), i, tflite::EnumNameTensorType(tensor->type())),
MediaPipeTasksStatus::kInvalidOutputTensorTypeError); MediaPipeTasksStatus::kInvalidOutputTensorTypeError);
} }
if (tensor->type() == tflite::TensorType_UINT8) { if (tensor->type() == tflite::TensorType_UINT8 ||
tensor->type() == tflite::TensorType_BOOL) {
num_quantized_tensors++; num_quantized_tensors++;
} }
} }

View File

@ -31,6 +31,8 @@ message TextPreprocessingGraphOptions {
BERT_PREPROCESSOR = 1; BERT_PREPROCESSOR = 1;
// Used for the RegexPreprocessorCalculator. // Used for the RegexPreprocessorCalculator.
REGEX_PREPROCESSOR = 2; REGEX_PREPROCESSOR = 2;
// Used for the TextToTensorCalculator.
STRING_PREPROCESSOR = 3;
} }
optional PreprocessorType preprocessor_type = 1; optional PreprocessorType preprocessor_type = 1;

View File

@ -65,6 +65,8 @@ absl::StatusOr<std::string> GetCalculatorNameFromPreprocessorType(
return "BertPreprocessorCalculator"; return "BertPreprocessorCalculator";
case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR:
return "RegexPreprocessorCalculator"; return "RegexPreprocessorCalculator";
case TextPreprocessingGraphOptions::STRING_PREPROCESSOR:
return "TextToTensorCalculator";
} }
} }
@ -91,11 +93,7 @@ GetPreprocessorType(const ModelResources& model_resources) {
MediaPipeTasksStatus::kInvalidInputTensorTypeError); MediaPipeTasksStatus::kInvalidInputTensorTypeError);
} }
if (all_string_tensors) { if (all_string_tensors) {
// TODO: Support a TextToTensor calculator for string tensors. return TextPreprocessingGraphOptions::STRING_PREPROCESSOR;
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"String tensors are not supported yet",
MediaPipeTasksStatus::kInvalidInputTensorTypeError);
} }
// Otherwise, all tensors should have type int32 // Otherwise, all tensors should have type int32
@ -185,10 +183,19 @@ absl::Status ConfigureTextPreprocessingSubgraph(
TextPreprocessingGraphOptions::PreprocessorType preprocessor_type, TextPreprocessingGraphOptions::PreprocessorType preprocessor_type,
GetPreprocessorType(model_resources)); GetPreprocessorType(model_resources));
options.set_preprocessor_type(preprocessor_type); options.set_preprocessor_type(preprocessor_type);
switch (preprocessor_type) {
case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
break;
}
case TextPreprocessingGraphOptions::BERT_PREPROCESSOR:
case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: {
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
int max_seq_len, int max_seq_len,
GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0])); GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0]));
options.set_max_seq_len(max_seq_len); options.set_max_seq_len(max_seq_len);
}
}
return absl::OkStatus(); return absl::OkStatus();
} }
@ -236,7 +243,8 @@ class TextPreprocessingSubgraph : public mediapipe::Subgraph {
GetCalculatorNameFromPreprocessorType(options.preprocessor_type())); GetCalculatorNameFromPreprocessorType(options.preprocessor_type()));
auto& text_preprocessor = graph.AddNode(preprocessor_name); auto& text_preprocessor = graph.AddNode(preprocessor_name);
switch (options.preprocessor_type()) { switch (options.preprocessor_type()) {
case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: { case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
break; break;
} }
case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: { case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: {

View File

@ -27,6 +27,7 @@ mediapipe_files(srcs = [
"albert_with_metadata.tflite", "albert_with_metadata.tflite",
"bert_text_classifier.tflite", "bert_text_classifier.tflite",
"mobilebert_with_metadata.tflite", "mobilebert_with_metadata.tflite",
"test_model_text_classifier_bool_output.tflite",
"test_model_text_classifier_with_regex_tokenizer.tflite", "test_model_text_classifier_with_regex_tokenizer.tflite",
]) ])

View File

@ -562,6 +562,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/test_model_add_op.tflite?generation=1661875950076192"], urls = ["https://storage.googleapis.com/mediapipe-assets/test_model_add_op.tflite?generation=1661875950076192"],
) )
http_file(
name = "com_google_mediapipe_test_model_text_classifier_bool_output_tflite",
sha256 = "09877ac6d718d78da6380e21fe8179854909d116632d6d770c12f8a51792e310",
urls = ["https://storage.googleapis.com/mediapipe-assets/test_model_text_classifier_bool_output.tflite?generation=1664904110313163"],
)
http_file( http_file(
name = "com_google_mediapipe_test_model_text_classifier_with_regex_tokenizer_tflite", name = "com_google_mediapipe_test_model_text_classifier_with_regex_tokenizer_tflite",
sha256 = "cb12618d084b813cb7b90ceb39c9fe4b18dae4de9880b912cdcd4b577cd65b4f", sha256 = "cb12618d084b813cb7b90ceb39c9fe4b18dae4de9880b912cdcd4b577cd65b4f",