Add a string-to-bool test model to TextClassifier.

PiperOrigin-RevId: 479803799
This commit is contained in:
MediaPipe Team 2022-10-08 09:42:20 -07:00 committed by Copybara-Service
parent 08ae99688c
commit 1ab332835a
11 changed files with 90 additions and 19 deletions

View File

@ -320,6 +320,8 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite:string_util",
"@org_tensorflow//tensorflow/lite/c:c_api_types",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
],
)

View File

@ -21,8 +21,10 @@
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/string_util.h"
namespace mediapipe {
@ -39,6 +41,19 @@ void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
}
template <>
void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
tflite::Interpreter* interpreter,
int input_tensor_index) {
const char* input_tensor_buffer =
input_tensor.GetCpuReadView().buffer<char>();
tflite::DynamicBuffer dynamic_buffer;
dynamic_buffer.AddString(input_tensor_buffer,
input_tensor.shape().num_elements());
dynamic_buffer.WriteToTensorAsVector(
interpreter->tensor(interpreter->inputs()[input_tensor_index]));
}
template <typename T>
void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
int output_tensor_index,
@ -87,13 +102,13 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
break;
}
case TfLiteType::kTfLiteUInt8: {
CopyTensorBufferToInterpreter<uint8>(input_tensors[i],
interpreter_.get(), i);
CopyTensorBufferToInterpreter<uint8_t>(input_tensors[i],
interpreter_.get(), i);
break;
}
case TfLiteType::kTfLiteInt8: {
CopyTensorBufferToInterpreter<int8>(input_tensors[i],
interpreter_.get(), i);
CopyTensorBufferToInterpreter<int8_t>(input_tensors[i],
interpreter_.get(), i);
break;
}
case TfLiteType::kTfLiteInt32: {
@ -101,6 +116,14 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
interpreter_.get(), i);
break;
}
case TfLiteType::kTfLiteString: {
CopyTensorBufferToInterpreter<char>(input_tensors[i],
interpreter_.get(), i);
break;
}
case TfLiteType::kTfLiteBool:
// No current use-case for copying MediaPipe Tensors with bool type to
// TfLiteTensors.
default:
return absl::InvalidArgumentError(
absl::StrCat("Unsupported input tensor type:", input_tensor_type));
@ -146,6 +169,15 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
CopyTensorBufferFromInterpreter<int32_t>(interpreter_.get(), i,
&output_tensors.back());
break;
case TfLiteType::kTfLiteBool:
output_tensors.emplace_back(Tensor::ElementType::kBool, shape,
Tensor::QuantizationParameters{1.0f, 0});
CopyTensorBufferFromInterpreter<bool>(interpreter_.get(), i,
&output_tensors.back());
break;
case TfLiteType::kTfLiteString:
// No current use-case for copying TfLiteTensors with string type to
// MediaPipe Tensors.
default:
return absl::InvalidArgumentError(
absl::StrCat("Unsupported output tensor type:",

View File

@ -87,6 +87,9 @@ absl::Status TensorsDequantizationCalculator::Process(CalculatorContext* cc) {
case Tensor::ElementType::kInt8:
Dequantize<int8>(input_tensor, &output_tensors->back());
break;
case Tensor::ElementType::kBool:
Dequantize<bool>(input_tensor, &output_tensors->back());
break;
default:
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported input tensor type: ", input_tensor.element_type()));

View File

@ -124,5 +124,15 @@ TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithInt8Tensors) {
ValidateResult(GetOutput(), {-1.007874, 0, 1});
}
TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithBoolTensors) {
std::vector<bool> tensor = {true, false, true};
PushTensor(Tensor::ElementType::kBool, tensor,
Tensor::QuantizationParameters{1.0f, 0});
MP_ASSERT_OK(runner_.Run());
ValidateResult(GetOutput(), {1, 0, 1});
}
} // namespace
} // namespace mediapipe

View File

@ -97,8 +97,8 @@ class Tensor {
kUInt8,
kInt8,
kInt32,
// TODO: Update the inference runner to handle kTfLiteString.
kChar
kChar,
kBool
};
struct Shape {
Shape() = default;
@ -330,6 +330,8 @@ class Tensor {
return sizeof(int32_t);
case ElementType::kChar:
return sizeof(char);
case ElementType::kBool:
return sizeof(bool);
}
}
int bytes() const { return shape_.num_elements() * element_size(); }

View File

@ -29,6 +29,9 @@ TEST(General, TestDataTypes) {
Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4});
EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char));
Tensor t_bool(Tensor::ElementType::kBool, Tensor::Shape{2, 3});
EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool));
}
TEST(Cpu, TestMemoryAllocation) {

View File

@ -123,15 +123,17 @@ absl::StatusOr<ClassificationHeadsProperties> GetClassificationHeadsProperties(
const auto* tensor =
primary_subgraph->tensors()->Get(primary_subgraph->outputs()->Get(i));
if (tensor->type() != tflite::TensorType_FLOAT32 &&
tensor->type() != tflite::TensorType_UINT8) {
tensor->type() != tflite::TensorType_UINT8 &&
tensor->type() != tflite::TensorType_BOOL) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat("Expected output tensor at index %d to have type "
"UINT8 or FLOAT32, found %s instead.",
"UINT8 or FLOAT32 or BOOL, found %s instead.",
i, tflite::EnumNameTensorType(tensor->type())),
MediaPipeTasksStatus::kInvalidOutputTensorTypeError);
}
if (tensor->type() == tflite::TensorType_UINT8) {
if (tensor->type() == tflite::TensorType_UINT8 ||
tensor->type() == tflite::TensorType_BOOL) {
num_quantized_tensors++;
}
}

View File

@ -31,6 +31,8 @@ message TextPreprocessingGraphOptions {
BERT_PREPROCESSOR = 1;
// Used for the RegexPreprocessorCalculator.
REGEX_PREPROCESSOR = 2;
// Used for the TextToTensorCalculator.
STRING_PREPROCESSOR = 3;
}
optional PreprocessorType preprocessor_type = 1;

View File

@ -65,6 +65,8 @@ absl::StatusOr<std::string> GetCalculatorNameFromPreprocessorType(
return "BertPreprocessorCalculator";
case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR:
return "RegexPreprocessorCalculator";
case TextPreprocessingGraphOptions::STRING_PREPROCESSOR:
return "TextToTensorCalculator";
}
}
@ -91,11 +93,7 @@ GetPreprocessorType(const ModelResources& model_resources) {
MediaPipeTasksStatus::kInvalidInputTensorTypeError);
}
if (all_string_tensors) {
// TODO: Support a TextToTensor calculator for string tensors.
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"String tensors are not supported yet",
MediaPipeTasksStatus::kInvalidInputTensorTypeError);
return TextPreprocessingGraphOptions::STRING_PREPROCESSOR;
}
// Otherwise, all tensors should have type int32
@ -185,10 +183,19 @@ absl::Status ConfigureTextPreprocessingSubgraph(
TextPreprocessingGraphOptions::PreprocessorType preprocessor_type,
GetPreprocessorType(model_resources));
options.set_preprocessor_type(preprocessor_type);
ASSIGN_OR_RETURN(
int max_seq_len,
GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0]));
options.set_max_seq_len(max_seq_len);
switch (preprocessor_type) {
case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
break;
}
case TextPreprocessingGraphOptions::BERT_PREPROCESSOR:
case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: {
ASSIGN_OR_RETURN(
int max_seq_len,
GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0]));
options.set_max_seq_len(max_seq_len);
}
}
return absl::OkStatus();
}
@ -236,7 +243,8 @@ class TextPreprocessingSubgraph : public mediapipe::Subgraph {
GetCalculatorNameFromPreprocessorType(options.preprocessor_type()));
auto& text_preprocessor = graph.AddNode(preprocessor_name);
switch (options.preprocessor_type()) {
case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: {
case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
break;
}
case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: {

View File

@ -27,6 +27,7 @@ mediapipe_files(srcs = [
"albert_with_metadata.tflite",
"bert_text_classifier.tflite",
"mobilebert_with_metadata.tflite",
"test_model_text_classifier_bool_output.tflite",
"test_model_text_classifier_with_regex_tokenizer.tflite",
])

View File

@ -562,6 +562,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/test_model_add_op.tflite?generation=1661875950076192"],
)
http_file(
name = "com_google_mediapipe_test_model_text_classifier_bool_output_tflite",
sha256 = "09877ac6d718d78da6380e21fe8179854909d116632d6d770c12f8a51792e310",
urls = ["https://storage.googleapis.com/mediapipe-assets/test_model_text_classifier_bool_output.tflite?generation=1664904110313163"],
)
http_file(
name = "com_google_mediapipe_test_model_text_classifier_with_regex_tokenizer_tflite",
sha256 = "cb12618d084b813cb7b90ceb39c9fe4b18dae4de9880b912cdcd4b577cd65b4f",