Add a string-to-bool test model to TextClassifier.
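As a rough illustration of the new element type this change introduces (a minimal sketch assuming a MediaPipe build environment; it mirrors the tensor_test expectation added in the diff below rather than the TextClassifier API itself):

    // Minimal sketch, assuming a MediaPipe build: the new kBool element
    // type sizes its buffer with sizeof(bool), as the added tensor test
    // expects.
    #include "mediapipe/framework/formats/tensor.h"

    int main() {
      mediapipe::Tensor t_bool(mediapipe::Tensor::ElementType::kBool,
                               mediapipe::Tensor::Shape{2, 3});
      // 2 * 3 bool elements.
      return t_bool.bytes() == 6 * static_cast<int>(sizeof(bool)) ? 0 : 1;
    }
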
PiperOrigin-RevId: 479803799
parent 08ae99688c
commit 1ab332835a
@@ -320,6 +320,8 @@ cc_library(
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@org_tensorflow//tensorflow/lite:framework_stable",
+        "@org_tensorflow//tensorflow/lite:string_util",
+        "@org_tensorflow//tensorflow/lite/c:c_api_types",
         "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
     ],
 )
@@ -21,8 +21,10 @@
 #include "absl/status/statusor.h"
 #include "mediapipe/framework/formats/tensor.h"
 #include "mediapipe/framework/port/ret_check.h"
+#include "tensorflow/lite/c/c_api_types.h"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/interpreter_builder.h"
+#include "tensorflow/lite/string_util.h"
 
 namespace mediapipe {
 
@@ -39,6 +41,19 @@ void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
   std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
 }
 
+template <>
+void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
+                                         tflite::Interpreter* interpreter,
+                                         int input_tensor_index) {
+  const char* input_tensor_buffer =
+      input_tensor.GetCpuReadView().buffer<char>();
+  tflite::DynamicBuffer dynamic_buffer;
+  dynamic_buffer.AddString(input_tensor_buffer,
+                           input_tensor.shape().num_elements());
+  dynamic_buffer.WriteToTensorAsVector(
+      interpreter->tensor(interpreter->inputs()[input_tensor_index]));
+}
+
 template <typename T>
 void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
                                      int output_tensor_index,
@@ -87,13 +102,13 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
       break;
     }
     case TfLiteType::kTfLiteUInt8: {
-      CopyTensorBufferToInterpreter<uint8>(input_tensors[i],
+      CopyTensorBufferToInterpreter<uint8_t>(input_tensors[i],
                                            interpreter_.get(), i);
       break;
     }
     case TfLiteType::kTfLiteInt8: {
-      CopyTensorBufferToInterpreter<int8>(input_tensors[i],
+      CopyTensorBufferToInterpreter<int8_t>(input_tensors[i],
                                           interpreter_.get(), i);
       break;
     }
     case TfLiteType::kTfLiteInt32: {
@@ -101,6 +116,14 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
                                            interpreter_.get(), i);
       break;
     }
+    case TfLiteType::kTfLiteString: {
+      CopyTensorBufferToInterpreter<char>(input_tensors[i],
+                                          interpreter_.get(), i);
+      break;
+    }
+    case TfLiteType::kTfLiteBool:
+      // No current use-case for copying MediaPipe Tensors with bool type to
+      // TfLiteTensors.
     default:
       return absl::InvalidArgumentError(
           absl::StrCat("Unsupported input tensor type:", input_tensor_type));
@@ -146,6 +169,15 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
       CopyTensorBufferFromInterpreter<int32_t>(interpreter_.get(), i,
                                                &output_tensors.back());
       break;
+    case TfLiteType::kTfLiteBool:
+      output_tensors.emplace_back(Tensor::ElementType::kBool, shape,
+                                  Tensor::QuantizationParameters{1.0f, 0});
+      CopyTensorBufferFromInterpreter<bool>(interpreter_.get(), i,
+                                            &output_tensors.back());
+      break;
+    case TfLiteType::kTfLiteString:
+      // No current use-case for copying TfLiteTensors with string type to
+      // MediaPipe Tensors.
     default:
       return absl::InvalidArgumentError(
           absl::StrCat("Unsupported output tensor type:",
@@ -87,6 +87,9 @@ absl::Status TensorsDequantizationCalculator::Process(CalculatorContext* cc) {
     case Tensor::ElementType::kInt8:
       Dequantize<int8>(input_tensor, &output_tensors->back());
       break;
+    case Tensor::ElementType::kBool:
+      Dequantize<bool>(input_tensor, &output_tensors->back());
+      break;
     default:
       return absl::InvalidArgumentError(absl::StrCat(
           "Unsupported input tensor type: ", input_tensor.element_type()));
@@ -124,5 +124,15 @@ TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithInt8Tensors) {
   ValidateResult(GetOutput(), {-1.007874, 0, 1});
 }
+
+TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithBoolTensors) {
+  std::vector<bool> tensor = {true, false, true};
+  PushTensor(Tensor::ElementType::kBool, tensor,
+             Tensor::QuantizationParameters{1.0f, 0});
+
+  MP_ASSERT_OK(runner_.Run());
+
+  ValidateResult(GetOutput(), {1, 0, 1});
+}
 
 }  // namespace
 }  // namespace mediapipe
@@ -97,8 +97,8 @@ class Tensor {
     kUInt8,
     kInt8,
     kInt32,
-    // TODO: Update the inference runner to handle kTfLiteString.
-    kChar
+    kChar,
+    kBool
   };
   struct Shape {
     Shape() = default;
@@ -330,6 +330,8 @@ class Tensor {
         return sizeof(int32_t);
       case ElementType::kChar:
         return sizeof(char);
+      case ElementType::kBool:
+        return sizeof(bool);
     }
   }
   int bytes() const { return shape_.num_elements() * element_size(); }
@@ -29,6 +29,9 @@ TEST(General, TestDataTypes) {
 
   Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4});
   EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char));
+
+  Tensor t_bool(Tensor::ElementType::kBool, Tensor::Shape{2, 3});
+  EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool));
 }
 
 TEST(Cpu, TestMemoryAllocation) {
@@ -123,15 +123,17 @@ absl::StatusOr<ClassificationHeadsProperties> GetClassificationHeadsProperties(
     const auto* tensor =
         primary_subgraph->tensors()->Get(primary_subgraph->outputs()->Get(i));
     if (tensor->type() != tflite::TensorType_FLOAT32 &&
-        tensor->type() != tflite::TensorType_UINT8) {
+        tensor->type() != tflite::TensorType_UINT8 &&
+        tensor->type() != tflite::TensorType_BOOL) {
       return CreateStatusWithPayload(
           absl::StatusCode::kInvalidArgument,
           absl::StrFormat("Expected output tensor at index %d to have type "
-                          "UINT8 or FLOAT32, found %s instead.",
+                          "UINT8 or FLOAT32 or BOOL, found %s instead.",
                           i, tflite::EnumNameTensorType(tensor->type())),
           MediaPipeTasksStatus::kInvalidOutputTensorTypeError);
     }
-    if (tensor->type() == tflite::TensorType_UINT8) {
+    if (tensor->type() == tflite::TensorType_UINT8 ||
+        tensor->type() == tflite::TensorType_BOOL) {
       num_quantized_tensors++;
     }
   }
@@ -31,6 +31,8 @@ message TextPreprocessingGraphOptions {
     BERT_PREPROCESSOR = 1;
     // Used for the RegexPreprocessorCalculator.
     REGEX_PREPROCESSOR = 2;
+    // Used for the TextToTensorCalculator.
+    STRING_PREPROCESSOR = 3;
   }
   optional PreprocessorType preprocessor_type = 1;
 
@@ -65,6 +65,8 @@ absl::StatusOr<std::string> GetCalculatorNameFromPreprocessorType(
       return "BertPreprocessorCalculator";
     case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR:
       return "RegexPreprocessorCalculator";
+    case TextPreprocessingGraphOptions::STRING_PREPROCESSOR:
+      return "TextToTensorCalculator";
   }
 }
 
@@ -91,11 +93,7 @@ GetPreprocessorType(const ModelResources& model_resources) {
         MediaPipeTasksStatus::kInvalidInputTensorTypeError);
   }
   if (all_string_tensors) {
-    // TODO: Support a TextToTensor calculator for string tensors.
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        "String tensors are not supported yet",
-        MediaPipeTasksStatus::kInvalidInputTensorTypeError);
+    return TextPreprocessingGraphOptions::STRING_PREPROCESSOR;
   }
 
   // Otherwise, all tensors should have type int32
@@ -185,10 +183,19 @@ absl::Status ConfigureTextPreprocessingSubgraph(
       TextPreprocessingGraphOptions::PreprocessorType preprocessor_type,
       GetPreprocessorType(model_resources));
   options.set_preprocessor_type(preprocessor_type);
-  ASSIGN_OR_RETURN(
-      int max_seq_len,
-      GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0]));
-  options.set_max_seq_len(max_seq_len);
+  switch (preprocessor_type) {
+    case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
+    case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
+      break;
+    }
+    case TextPreprocessingGraphOptions::BERT_PREPROCESSOR:
+    case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: {
+      ASSIGN_OR_RETURN(
+          int max_seq_len,
+          GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0]));
+      options.set_max_seq_len(max_seq_len);
+    }
+  }
 
   return absl::OkStatus();
 }
@@ -236,7 +243,8 @@ class TextPreprocessingSubgraph : public mediapipe::Subgraph {
         GetCalculatorNameFromPreprocessorType(options.preprocessor_type()));
     auto& text_preprocessor = graph.AddNode(preprocessor_name);
     switch (options.preprocessor_type()) {
-      case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: {
+      case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
+      case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
         break;
       }
       case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: {
mediapipe/tasks/testdata/text/BUILD
@@ -27,6 +27,7 @@ mediapipe_files(srcs = [
     "albert_with_metadata.tflite",
     "bert_text_classifier.tflite",
     "mobilebert_with_metadata.tflite",
+    "test_model_text_classifier_bool_output.tflite",
    "test_model_text_classifier_with_regex_tokenizer.tflite",
 ])
 
third_party/external_files.bzl
@@ -562,6 +562,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/test_model_add_op.tflite?generation=1661875950076192"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_test_model_text_classifier_bool_output_tflite",
+        sha256 = "09877ac6d718d78da6380e21fe8179854909d116632d6d770c12f8a51792e310",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/test_model_text_classifier_bool_output.tflite?generation=1664904110313163"],
+    )
+
     http_file(
         name = "com_google_mediapipe_test_model_text_classifier_with_regex_tokenizer_tflite",
         sha256 = "cb12618d084b813cb7b90ceb39c9fe4b18dae4de9880b912cdcd4b577cd65b4f",