Metadata Writer: Add metadata writer for Text Classifier.

PiperOrigin-RevId: 486844428
This commit is contained in:
Yuqi Li 2022-11-07 21:24:19 -08:00 committed by Copybara-Service
parent b14178d305
commit 0a08e4768b
16 changed files with 10470 additions and 10 deletions

View File

@ -43,3 +43,9 @@ py_library(
srcs = ["image_classifier.py"],
deps = [":metadata_writer"],
)
# Metadata writer library for text classification models; builds on the
# shared metadata_writer helpers.
py_library(
    name = "text_classifier",
    srcs = ["text_classifier.py"],
    deps = [":metadata_writer"],
)

View File

@ -62,10 +62,10 @@ class MetadataWriter(metadata_writer.MetadataWriterBase):
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
Returns:
An MetadataWrite object.
A MetadataWriter object.
"""
writer = metadata_writer.MetadataWriter(model_buffer)
writer.add_genernal_info(_MODEL_NAME, _MODEL_DESCRIPTION)
writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
writer.add_image_input(input_norm_mean, input_norm_std)
writer.add_classification_output(labels, score_calibration)
return cls(writer)

View File

@ -228,6 +228,45 @@ class ScoreThresholdingMd:
return score_thresholding
class RegexTokenizerMd:
  """A container for the Regex tokenizer [1] metadata information.

  [1]:
  https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
  """

  def __init__(self, delim_regex_pattern: str, vocab_file_path: str):
    """Initializes a RegexTokenizerMd object.

    Args:
      delim_regex_pattern: the regular expression used to segment strings into
        tokens.
      vocab_file_path: path to the vocabulary file.
    """
    self._delim_regex_pattern = delim_regex_pattern
    self._vocab_file_path = vocab_file_path

  def create_metadata(self) -> _metadata_fb.ProcessUnitT:
    """Creates the Regex tokenizer metadata based on the information.

    Returns:
      A Flatbuffers Python object of the Regex tokenizer metadata.
    """
    # The vocabulary is attached as an associated file so that consumers can
    # map natural-language tokens to embedding indices.
    vocab_file = _metadata_fb.AssociatedFileT()
    vocab_file.name = self._vocab_file_path
    vocab_file.description = _VOCAB_FILE_DESCRIPTION
    vocab_file.type = _metadata_fb.AssociatedFileType.VOCABULARY

    options = _metadata_fb.RegexTokenizerOptionsT()
    options.delimRegexPattern = self._delim_regex_pattern
    options.vocabFile = [vocab_file]

    process_unit = _metadata_fb.ProcessUnitT()
    process_unit.optionsType = (
        _metadata_fb.ProcessUnitOptions.RegexTokenizerOptions)
    process_unit.options = options
    return process_unit
class TensorMd:
"""A container for common tensor metadata information.
@ -397,6 +436,56 @@ class InputImageTensorMd(TensorMd):
return tensor_metadata
class InputTextTensorMd(TensorMd):
  """A container for the input text tensor metadata information.

  Attributes:
    tokenizer_md: information of the tokenizer in the input text tensor, if
      any.
  """

  def __init__(self,
               name: Optional[str] = None,
               description: Optional[str] = None,
               tokenizer_md: Optional[RegexTokenizerMd] = None):
    """Initializes the instance of InputTextTensorMd.

    Args:
      name: name of the tensor.
      description: description of what the tensor is.
      tokenizer_md: information of the tokenizer in the input text tensor, if
        any. Only `RegexTokenizer` [1] is currently supported. If the tokenizer
        is `BertTokenizer` [2] or `SentencePieceTokenizer` [3], refer to
        `BertInputTensorsMd` class.
    [1]:
      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
    [2]:
      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
    [3]:
      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
    """
    super().__init__(name, description)
    self.tokenizer_md = tokenizer_md

  def create_metadata(self) -> _metadata_fb.TensorMetadataT:
    """Creates the input text metadata based on the information.

    Returns:
      A Flatbuffers Python object of the input text metadata.

    Raises:
      ValueError: if the type of tokenizer_md is unsupported.
    """
    # Only the regex tokenizer (or no tokenizer at all) is supported for
    # plain text inputs; Bert/SentencePiece tokenizers are handled elsewhere.
    if not isinstance(self.tokenizer_md, (type(None), RegexTokenizerMd)):
      # Fixed: the error previously referred to "tokenizer_options", which is
      # not the name of the attribute being validated.
      raise ValueError(
          f"The type of tokenizer_md, {type(self.tokenizer_md)}, is "
          f"unsupported")

    tensor_metadata = super().create_metadata()
    if self.tokenizer_md:
      tensor_metadata.processUnits = [self.tokenizer_md.create_metadata()]
    return tensor_metadata
class ClassificationTensorMd(TensorMd):
"""A container for the classification tensor metadata information.

View File

@ -29,6 +29,9 @@ from mediapipe.tasks.python.metadata.metadata_writers import writer_utils
_INPUT_IMAGE_NAME = 'image'
_INPUT_IMAGE_DESCRIPTION = 'Input image to be processed.'
_INPUT_REGEX_TEXT_NAME = 'input_text'
_INPUT_REGEX_TEXT_DESCRIPTION = ('Embedding vectors representing the input '
'text to be processed.')
_OUTPUT_CLASSIFICATION_NAME = 'score'
_OUTPUT_CLASSIFICATION_DESCRIPTION = 'Score of the labels respectively.'
@ -82,6 +85,22 @@ class ScoreThresholding:
global_score_threshold: float
@dataclasses.dataclass
class RegexTokenizer:
  """Parameters of the Regex tokenizer [1] metadata information.

  [1]:
  https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500

  Attributes:
    delim_regex_pattern: the regular expression to segment strings and create
      tokens.
    vocab_file_path: path to the vocabulary file.
  """

  # Regular expression used to split input text into tokens.
  delim_regex_pattern: str
  # Path to the vocabulary file that is packed into the model metadata.
  vocab_file_path: str
class Labels(object):
"""Simple container holding classification labels of a particular tensor.
@ -355,11 +374,11 @@ class MetadataWriter(object):
if os.path.exists(self._temp_folder.name):
self._temp_folder.cleanup()
def add_genernal_info(
def add_general_info(
self,
model_name: str,
model_description: Optional[str] = None) -> 'MetadataWriter':
"""Adds a genernal info metadata for the general metadata informantion."""
"""Adds a general info metadata for the general metadata informantion."""
# Will overwrite the previous `self._general_md` if exists.
self._general_md = metadata_info.GeneralMd(
name=model_name, description=model_description)
@ -415,6 +434,34 @@ class MetadataWriter(object):
self._input_mds.append(input_md)
return self
def add_regex_text_input(
    self,
    regex_tokenizer: RegexTokenizer,
    name: str = _INPUT_REGEX_TEXT_NAME,
    description: str = _INPUT_REGEX_TEXT_DESCRIPTION) -> 'MetadataWriter':
  """Adds an input text metadata for the text input with regex tokenizer.

  Args:
    regex_tokenizer: information of the regex tokenizer [1] used to process
      the input string.
    name: Name of the input tensor.
    description: Description of the input tensor.

  Returns:
    The MetadataWriter instance, can be used for chained operation.

  [1]:
  https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
  """
  text_tensor_md = metadata_info.InputTextTensorMd(
      name=name,
      description=description,
      tokenizer_md=metadata_info.RegexTokenizerMd(
          delim_regex_pattern=regex_tokenizer.delim_regex_pattern,
          vocab_file_path=regex_tokenizer.vocab_file_path))
  # The vocabulary file is packed into the output model alongside the
  # metadata so the tokenizer can resolve it at runtime.
  self._associated_files.append(regex_tokenizer.vocab_file_path)
  self._input_mds.append(text_tensor_md)
  return self
def add_classification_output(
self,
labels: Optional[Labels] = None,

View File

@ -0,0 +1,64 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Writes metadata and label file to the Text classifier models."""
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
# Fixed model name and description written into the generated metadata.
_MODEL_NAME = "TextClassifier"
_MODEL_DESCRIPTION = ("Classify the input text into a set of known categories.")
class MetadataWriter(metadata_writer.MetadataWriterBase):
  """MetadataWriter to write the metadata into the text classifier."""

  @classmethod
  def create_for_regex_model(
      cls, model_buffer: bytearray,
      regex_tokenizer: metadata_writer.RegexTokenizer,
      labels: metadata_writer.Labels) -> "MetadataWriter":
    """Creates MetadataWriter for TFLite model with regex tokenizer.

    The parameters required in this method are mandatory when using MediaPipe
    Tasks.

    Note that only the output TFLite is used for deployment. The output JSON
    content is used to interpret the metadata content.

    Args:
      model_buffer: A valid flatbuffer loaded from the TFLite model file.
      regex_tokenizer: information of the regex tokenizer [1] used to process
        the input string. If the tokenizer is `BertTokenizer` [2] or
        `SentencePieceTokenizer` [3], please refer to
        `create_for_bert_model`.
      labels: an instance of Labels helper class used in the output
        classification tensor [4].

      [1]:
        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
      [2]:
        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
      [3]:
        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
      [4]:
        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99

    Returns:
      A MetadataWriter object.
    """
    writer = metadata_writer.MetadataWriter(model_buffer)
    writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
    writer.add_regex_text_input(regex_tokenizer)
    writer.add_classification_output(labels)
    return cls(writer)

View File

@ -53,3 +53,17 @@ py_test(
"//mediapipe/tasks/python/test:test_utils",
],
)
# Unit tests for the text classifier metadata writer; requires the metadata
# test data and model files.
py_test(
    name = "text_classifier_test",
    srcs = ["text_classifier_test.py"],
    data = [
        "//mediapipe/tasks/testdata/metadata:data_files",
        "//mediapipe/tasks/testdata/metadata:model_files",
    ],
    deps = [
        "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
        "//mediapipe/tasks/python/metadata/metadata_writers:text_classifier",
        "//mediapipe/tasks/python/test:test_utils",
    ],
)

View File

@ -191,6 +191,43 @@ class InputImageTensorMdTest(parameterized.TestCase):
f"{len(norm_mean)} and {len(norm_std)}", str(error.exception))
class InputTextTensorMdTest(absltest.TestCase):
  _NAME = "input text"
  _DESCRIPTION = "The input string."
  _VOCAB_FILE = "vocab.txt"
  _DELIM_REGEX_PATTERN = r"[^\w\']+"
  _EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
      os.path.join(_TEST_DATA_DIR, "input_text_tensor_meta.json"))
  _EXPECTED_TENSOR_DEFAULT_JSON = test_utils.get_test_data_path(
      os.path.join(_TEST_DATA_DIR, "input_text_tensor_default_meta.json"))

  def test_create_metadata_should_succeed(self):
    # A fully-populated text tensor (name, description, regex tokenizer)
    # should serialize to the golden JSON.
    tokenizer_md = metadata_info.RegexTokenizerMd(self._DELIM_REGEX_PATTERN,
                                                  self._VOCAB_FILE)
    tensor_md = metadata_info.InputTextTensorMd(self._NAME, self._DESCRIPTION,
                                                tokenizer_md)
    actual_json = _metadata.convert_to_json(
        _create_dummy_model_metadata_with_tensor(tensor_md.create_metadata()))
    with open(self._EXPECTED_TENSOR_JSON, "r") as f:
      self.assertEqual(actual_json, f.read())

  def test_create_metadata_by_default_should_succeed(self):
    # With no arguments, the tensor metadata should match the default golden.
    tensor_md = metadata_info.InputTextTensorMd()
    actual_json = _metadata.convert_to_json(
        _create_dummy_model_metadata_with_tensor(tensor_md.create_metadata()))
    with open(self._EXPECTED_TENSOR_DEFAULT_JSON, "r") as f:
      self.assertEqual(actual_json, f.read())
class ClassificationTensorMdTest(parameterized.TestCase):
_NAME = "probability"

View File

@ -113,7 +113,7 @@ class MetadataWriterForTaskTest(absltest.TestCase):
def test_initialize_and_populate(self):
writer = metadata_writer.MetadataWriter.create(
self.image_classifier_model_buffer)
writer.add_genernal_info(
writer.add_general_info(
model_name='my_image_model', model_description='my_description')
tflite_model, metadata_json = writer.populate()
self.assertLen(tflite_model, 1882986)
@ -142,7 +142,7 @@ class MetadataWriterForTaskTest(absltest.TestCase):
def test_add_feature_input_output(self):
writer = metadata_writer.MetadataWriter.create(
self.image_classifier_model_buffer)
writer.add_genernal_info(
writer.add_general_info(
model_name='my_model', model_description='my_description')
writer.add_feature_input(
name='input_tesnor', description='a feature input tensor')
@ -191,7 +191,7 @@ class MetadataWriterForTaskTest(absltest.TestCase):
def test_image_classifier(self):
writer = metadata_writer.MetadataWriter.create(
self.image_classifier_model_buffer)
writer.add_genernal_info(
writer.add_general_info(
model_name='image_classifier',
model_description='Imagenet classification model')
writer.add_image_input(
@ -282,7 +282,7 @@ class MetadataWriterForTaskTest(absltest.TestCase):
def test_image_classifier_with_locale_and_score_calibration(self):
writer = metadata_writer.MetadataWriter(self.image_classifier_model_buffer)
writer.add_genernal_info(
writer.add_general_info(
model_name='image_classifier',
model_description='Classify the input image.')
writer.add_image_input(

View File

@ -0,0 +1,51 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for metadata_writer.text_classifier."""
from absl.testing import absltest
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import text_classifier
from mediapipe.tasks.python.test import test_utils
_TEST_DIR = "mediapipe/tasks/testdata/metadata/"
_MODEL = test_utils.get_test_data_path(_TEST_DIR + "movie_review.tflite")
_LABEL_FILE = test_utils.get_test_data_path(_TEST_DIR +
"movie_review_labels.txt")
_VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR + "regex_vocab.txt")
_DELIM_REGEX_PATTERN = r"[^\w\']+"
_JSON_FILE = test_utils.get_test_data_path("movie_review.json")
class TextClassifierTest(absltest.TestCase):

  def test_write_metadata(self):
    # End-to-end: write regex-tokenizer metadata into the movie review model
    # and compare the resulting JSON against the golden file.
    with open(_MODEL, "rb") as f:
      model_buffer = f.read()
    tokenizer = metadata_writer.RegexTokenizer(
        delim_regex_pattern=_DELIM_REGEX_PATTERN, vocab_file_path=_VOCAB_FILE)
    label_info = metadata_writer.Labels().add_from_file(_LABEL_FILE)
    writer = text_classifier.MetadataWriter.create_for_regex_model(
        model_buffer, regex_tokenizer=tokenizer, labels=label_info)
    _, metadata_json = writer.populate()
    with open(_JSON_FILE, "r") as f:
      self.assertEqual(metadata_json, f.read())


if __name__ == "__main__":
  absltest.main()

View File

@ -31,6 +31,7 @@ mediapipe_files(srcs = [
"mobilenet_v2_1.0_224_quant.tflite",
"mobilenet_v2_1.0_224_quant_without_metadata.tflite",
"mobilenet_v2_1.0_224_without_metadata.tflite",
"movie_review.tflite",
])
exports_files([
@ -54,6 +55,11 @@ exports_files([
"labels.txt",
"mobilenet_v2_1.0_224.json",
"mobilenet_v2_1.0_224_quant.json",
"input_text_tensor_meta.json",
"input_text_tensor_default_meta.json",
"movie_review_labels.txt",
"regex_vocab.txt",
"movie_review.json",
])
filegroup(
@ -67,6 +73,7 @@ filegroup(
"mobilenet_v2_1.0_224_quant.tflite",
"mobilenet_v2_1.0_224_quant_without_metadata.tflite",
"mobilenet_v2_1.0_224_without_metadata.tflite",
"movie_review.tflite",
],
)
@ -86,9 +93,14 @@ filegroup(
"input_image_tensor_float_meta.json",
"input_image_tensor_uint8_meta.json",
"input_image_tensor_unsupported_meta.json",
"input_text_tensor_default_meta.json",
"input_text_tensor_meta.json",
"labels.txt",
"mobilenet_v2_1.0_224.json",
"mobilenet_v2_1.0_224_quant.json",
"movie_review.json",
"movie_review_labels.txt",
"regex_vocab.txt",
"score_calibration.txt",
"score_calibration_file_meta.json",
"score_calibration_tensor_meta.json",

View File

@ -0,0 +1,17 @@
{
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
}
]
}
]
}

View File

@ -0,0 +1,34 @@
{
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "input text",
"description": "The input string.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"process_units": [
{
"options_type": "RegexTokenizerOptions",
"options": {
"delim_regex_pattern": "[^\\w\\']+",
"vocab_file": [
{
"name": "vocab.txt",
"description": "Vocabulary file to convert natural language words to embedding vectors.",
"type": "VOCABULARY"
}
]
}
}
],
"stats": {
}
}
]
}
]
}

View File

@ -0,0 +1,63 @@
{
"name": "TextClassifier",
"description": "Classify the input text into a set of known categories.",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "input_text",
"description": "Embedding vectors representing the input text to be processed.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"process_units": [
{
"options_type": "RegexTokenizerOptions",
"options": {
"delim_regex_pattern": "[^\\w\\']+",
"vocab_file": [
{
"name": "regex_vocab.txt",
"description": "Vocabulary file to convert natural language words to embedding vectors.",
"type": "VOCABULARY"
}
]
}
}
],
"stats": {
}
}
],
"output_tensor_metadata": [
{
"name": "score",
"description": "Score of the labels respectively.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS"
}
]
}
]
}
],
"min_parser_version": "1.2.1"
}

View File

@ -0,0 +1,2 @@
Negative
Positive

File diff suppressed because it is too large Load Diff

View File

@ -346,6 +346,18 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/input_image_tensor_unsupported_meta.json?generation=1665422835757143"],
)
http_file(
name = "com_google_mediapipe_input_text_tensor_default_meta_json",
sha256 = "9723e59960b0e6ca60d120494c32e798b054ea6e5a441b359c84f759bd2b3a36",
urls = ["https://storage.googleapis.com/mediapipe-assets/input_text_tensor_default_meta.json?generation=1667855382021347"],
)
http_file(
name = "com_google_mediapipe_input_text_tensor_meta_json",
sha256 = "c6782f676220e2cc89b70bacccb649fc848c18e33bedc449bf49f5d839b3cc6c",
urls = ["https://storage.googleapis.com/mediapipe-assets/input_text_tensor_meta.json?generation=1667855384891533"],
)
http_file(
name = "com_google_mediapipe_iris_and_gaze_tflite",
sha256 = "b6dcb860a92a3c7264a8e50786f46cecb529672cdafc17d39c78931257da661d",
@ -390,8 +402,8 @@ def external_files():
http_file(
name = "com_google_mediapipe_labels_txt",
sha256 = "536feacc519de3d418de26b2effb4d75694a8c4c0063e36499a46fa8061e2da9",
urls = ["https://storage.googleapis.com/mediapipe-assets/labels.txt?generation=1665988394538324"],
sha256 = "4b9b26392f765e7a872372131cd4cee8ad7c02e496b5a1228279619b138c4b7a",
urls = ["https://storage.googleapis.com/mediapipe-assets/labels.txt?generation=1667855388142641"],
)
http_file(
@ -538,6 +550,18 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/model_without_metadata.tflite?generation=1661875850966737"],
)
http_file(
name = "com_google_mediapipe_movie_review_json",
sha256 = "89ad347ad1cb7c587da144de6efbadec1d3e8ff0cd13e379dd16661a8186fbb5",
urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review.json?generation=1667855392734031"],
)
http_file(
name = "com_google_mediapipe_movie_review_tflite",
sha256 = "3935ee73b13d435327d05af4d6f37dc3c146e117e1c3d572ae4d2ae0f5f412fe",
urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review.tflite?generation=1667855395736217"],
)
http_file(
name = "com_google_mediapipe_mozart_square_jpg",
sha256 = "4feb4dadc5d6f853ade57b8c9d4c9a1f5ececd6469616c8e505f9a14823392b6",