Add a string-to-bool test model to TextClassifier.
PiperOrigin-RevId: 479803799
This commit is contained in:
parent
08ae99688c
commit
1ab332835a
|
@ -320,6 +320,8 @@ cc_library(
|
|||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||
"@org_tensorflow//tensorflow/lite:string_util",
|
||||
"@org_tensorflow//tensorflow/lite/c:c_api_types",
|
||||
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -21,8 +21,10 @@
|
|||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "tensorflow/lite/c/c_api_types.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/interpreter_builder.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
|
@ -39,6 +41,19 @@ void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
|
|||
std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
|
||||
}
|
||||
|
||||
template <>
|
||||
void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
|
||||
tflite::Interpreter* interpreter,
|
||||
int input_tensor_index) {
|
||||
const char* input_tensor_buffer =
|
||||
input_tensor.GetCpuReadView().buffer<char>();
|
||||
tflite::DynamicBuffer dynamic_buffer;
|
||||
dynamic_buffer.AddString(input_tensor_buffer,
|
||||
input_tensor.shape().num_elements());
|
||||
dynamic_buffer.WriteToTensorAsVector(
|
||||
interpreter->tensor(interpreter->inputs()[input_tensor_index]));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
|
||||
int output_tensor_index,
|
||||
|
@ -87,13 +102,13 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
|
|||
break;
|
||||
}
|
||||
case TfLiteType::kTfLiteUInt8: {
|
||||
CopyTensorBufferToInterpreter<uint8>(input_tensors[i],
|
||||
interpreter_.get(), i);
|
||||
CopyTensorBufferToInterpreter<uint8_t>(input_tensors[i],
|
||||
interpreter_.get(), i);
|
||||
break;
|
||||
}
|
||||
case TfLiteType::kTfLiteInt8: {
|
||||
CopyTensorBufferToInterpreter<int8>(input_tensors[i],
|
||||
interpreter_.get(), i);
|
||||
CopyTensorBufferToInterpreter<int8_t>(input_tensors[i],
|
||||
interpreter_.get(), i);
|
||||
break;
|
||||
}
|
||||
case TfLiteType::kTfLiteInt32: {
|
||||
|
@ -101,6 +116,14 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
|
|||
interpreter_.get(), i);
|
||||
break;
|
||||
}
|
||||
case TfLiteType::kTfLiteString: {
|
||||
CopyTensorBufferToInterpreter<char>(input_tensors[i],
|
||||
interpreter_.get(), i);
|
||||
break;
|
||||
}
|
||||
case TfLiteType::kTfLiteBool:
|
||||
// No current use-case for copying MediaPipe Tensors with bool type to
|
||||
// TfLiteTensors.
|
||||
default:
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("Unsupported input tensor type:", input_tensor_type));
|
||||
|
@ -146,6 +169,15 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
|
|||
CopyTensorBufferFromInterpreter<int32_t>(interpreter_.get(), i,
|
||||
&output_tensors.back());
|
||||
break;
|
||||
case TfLiteType::kTfLiteBool:
|
||||
output_tensors.emplace_back(Tensor::ElementType::kBool, shape,
|
||||
Tensor::QuantizationParameters{1.0f, 0});
|
||||
CopyTensorBufferFromInterpreter<bool>(interpreter_.get(), i,
|
||||
&output_tensors.back());
|
||||
break;
|
||||
case TfLiteType::kTfLiteString:
|
||||
// No current use-case for copying TfLiteTensors with string type to
|
||||
// MediaPipe Tensors.
|
||||
default:
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("Unsupported output tensor type:",
|
||||
|
|
|
@ -87,6 +87,9 @@ absl::Status TensorsDequantizationCalculator::Process(CalculatorContext* cc) {
|
|||
case Tensor::ElementType::kInt8:
|
||||
Dequantize<int8>(input_tensor, &output_tensors->back());
|
||||
break;
|
||||
case Tensor::ElementType::kBool:
|
||||
Dequantize<bool>(input_tensor, &output_tensors->back());
|
||||
break;
|
||||
default:
|
||||
return absl::InvalidArgumentError(absl::StrCat(
|
||||
"Unsupported input tensor type: ", input_tensor.element_type()));
|
||||
|
|
|
@ -124,5 +124,15 @@ TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithInt8Tensors) {
|
|||
ValidateResult(GetOutput(), {-1.007874, 0, 1});
|
||||
}
|
||||
|
||||
TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithBoolTensors) {
|
||||
std::vector<bool> tensor = {true, false, true};
|
||||
PushTensor(Tensor::ElementType::kBool, tensor,
|
||||
Tensor::QuantizationParameters{1.0f, 0});
|
||||
|
||||
MP_ASSERT_OK(runner_.Run());
|
||||
|
||||
ValidateResult(GetOutput(), {1, 0, 1});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -97,8 +97,8 @@ class Tensor {
|
|||
kUInt8,
|
||||
kInt8,
|
||||
kInt32,
|
||||
// TODO: Update the inference runner to handle kTfLiteString.
|
||||
kChar
|
||||
kChar,
|
||||
kBool
|
||||
};
|
||||
struct Shape {
|
||||
Shape() = default;
|
||||
|
@ -330,6 +330,8 @@ class Tensor {
|
|||
return sizeof(int32_t);
|
||||
case ElementType::kChar:
|
||||
return sizeof(char);
|
||||
case ElementType::kBool:
|
||||
return sizeof(bool);
|
||||
}
|
||||
}
|
||||
int bytes() const { return shape_.num_elements() * element_size(); }
|
||||
|
|
|
@ -29,6 +29,9 @@ TEST(General, TestDataTypes) {
|
|||
|
||||
Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4});
|
||||
EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char));
|
||||
|
||||
Tensor t_bool(Tensor::ElementType::kBool, Tensor::Shape{2, 3});
|
||||
EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool));
|
||||
}
|
||||
|
||||
TEST(Cpu, TestMemoryAllocation) {
|
||||
|
|
|
@ -123,15 +123,17 @@ absl::StatusOr<ClassificationHeadsProperties> GetClassificationHeadsProperties(
|
|||
const auto* tensor =
|
||||
primary_subgraph->tensors()->Get(primary_subgraph->outputs()->Get(i));
|
||||
if (tensor->type() != tflite::TensorType_FLOAT32 &&
|
||||
tensor->type() != tflite::TensorType_UINT8) {
|
||||
tensor->type() != tflite::TensorType_UINT8 &&
|
||||
tensor->type() != tflite::TensorType_BOOL) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::StrFormat("Expected output tensor at index %d to have type "
|
||||
"UINT8 or FLOAT32, found %s instead.",
|
||||
"UINT8 or FLOAT32 or BOOL, found %s instead.",
|
||||
i, tflite::EnumNameTensorType(tensor->type())),
|
||||
MediaPipeTasksStatus::kInvalidOutputTensorTypeError);
|
||||
}
|
||||
if (tensor->type() == tflite::TensorType_UINT8) {
|
||||
if (tensor->type() == tflite::TensorType_UINT8 ||
|
||||
tensor->type() == tflite::TensorType_BOOL) {
|
||||
num_quantized_tensors++;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,6 +31,8 @@ message TextPreprocessingGraphOptions {
|
|||
BERT_PREPROCESSOR = 1;
|
||||
// Used for the RegexPreprocessorCalculator.
|
||||
REGEX_PREPROCESSOR = 2;
|
||||
// Used for the TextToTensorCalculator.
|
||||
STRING_PREPROCESSOR = 3;
|
||||
}
|
||||
optional PreprocessorType preprocessor_type = 1;
|
||||
|
||||
|
|
|
@ -65,6 +65,8 @@ absl::StatusOr<std::string> GetCalculatorNameFromPreprocessorType(
|
|||
return "BertPreprocessorCalculator";
|
||||
case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR:
|
||||
return "RegexPreprocessorCalculator";
|
||||
case TextPreprocessingGraphOptions::STRING_PREPROCESSOR:
|
||||
return "TextToTensorCalculator";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -91,11 +93,7 @@ GetPreprocessorType(const ModelResources& model_resources) {
|
|||
MediaPipeTasksStatus::kInvalidInputTensorTypeError);
|
||||
}
|
||||
if (all_string_tensors) {
|
||||
// TODO: Support a TextToTensor calculator for string tensors.
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
"String tensors are not supported yet",
|
||||
MediaPipeTasksStatus::kInvalidInputTensorTypeError);
|
||||
return TextPreprocessingGraphOptions::STRING_PREPROCESSOR;
|
||||
}
|
||||
|
||||
// Otherwise, all tensors should have type int32
|
||||
|
@ -185,10 +183,19 @@ absl::Status ConfigureTextPreprocessingSubgraph(
|
|||
TextPreprocessingGraphOptions::PreprocessorType preprocessor_type,
|
||||
GetPreprocessorType(model_resources));
|
||||
options.set_preprocessor_type(preprocessor_type);
|
||||
ASSIGN_OR_RETURN(
|
||||
int max_seq_len,
|
||||
GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0]));
|
||||
options.set_max_seq_len(max_seq_len);
|
||||
switch (preprocessor_type) {
|
||||
case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
|
||||
case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
|
||||
break;
|
||||
}
|
||||
case TextPreprocessingGraphOptions::BERT_PREPROCESSOR:
|
||||
case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: {
|
||||
ASSIGN_OR_RETURN(
|
||||
int max_seq_len,
|
||||
GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0]));
|
||||
options.set_max_seq_len(max_seq_len);
|
||||
}
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -236,7 +243,8 @@ class TextPreprocessingSubgraph : public mediapipe::Subgraph {
|
|||
GetCalculatorNameFromPreprocessorType(options.preprocessor_type()));
|
||||
auto& text_preprocessor = graph.AddNode(preprocessor_name);
|
||||
switch (options.preprocessor_type()) {
|
||||
case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: {
|
||||
case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
|
||||
case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
|
||||
break;
|
||||
}
|
||||
case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: {
|
||||
|
|
1
mediapipe/tasks/testdata/text/BUILD
vendored
1
mediapipe/tasks/testdata/text/BUILD
vendored
|
@ -27,6 +27,7 @@ mediapipe_files(srcs = [
|
|||
"albert_with_metadata.tflite",
|
||||
"bert_text_classifier.tflite",
|
||||
"mobilebert_with_metadata.tflite",
|
||||
"test_model_text_classifier_bool_output.tflite",
|
||||
"test_model_text_classifier_with_regex_tokenizer.tflite",
|
||||
])
|
||||
|
||||
|
|
6
third_party/external_files.bzl
vendored
6
third_party/external_files.bzl
vendored
|
@ -562,6 +562,12 @@ def external_files():
|
|||
urls = ["https://storage.googleapis.com/mediapipe-assets/test_model_add_op.tflite?generation=1661875950076192"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_test_model_text_classifier_bool_output_tflite",
|
||||
sha256 = "09877ac6d718d78da6380e21fe8179854909d116632d6d770c12f8a51792e310",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/test_model_text_classifier_bool_output.tflite?generation=1664904110313163"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_test_model_text_classifier_with_regex_tokenizer_tflite",
|
||||
sha256 = "cb12618d084b813cb7b90ceb39c9fe4b18dae4de9880b912cdcd4b577cd65b4f",
|
||||
|
|
Loading…
Reference in New Issue
Block a user