Open source Model Maker text tasks.

PiperOrigin-RevId: 487706929
This commit is contained in:
MediaPipe Team 2022-11-10 19:51:08 -08:00 committed by Copybara-Service
parent d2284083b3
commit 3e05871f98
19 changed files with 1882 additions and 0 deletions

View File

@ -0,0 +1,35 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
licenses(["notice"])
py_library(
name = "bert_model_options",
srcs = ["bert_model_options.py"],
)
py_library(
name = "bert_model_spec",
srcs = ["bert_model_spec.py"],
deps = [
":bert_model_options",
"//mediapipe/model_maker/python/core:hyperparameters",
],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,33 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configurable model options for a BERT model."""
import dataclasses
@dataclasses.dataclass
class BertModelOptions:
"""Configurable model options for a BERT model.
See https://arxiv.org/abs/1810.04805 (BERT: Pre-training of Deep Bidirectional
Transformers for Language Understanding) for more details.
Attributes:
seq_len: Length of the sequence to feed into the model.
do_fine_tuning: If true, then the BERT model is not frozen for training.
dropout_rate: The rate for dropout.
"""
seq_len: int = 128
do_fine_tuning: bool = True
dropout_rate: float = 0.1

View File

@ -0,0 +1,58 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Specification for a BERT model."""
import dataclasses
from typing import Dict
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.text.core import bert_model_options
_DEFAULT_TFLITE_INPUT_NAME = {
'ids': 'serving_default_input_word_ids:0',
'mask': 'serving_default_input_mask:0',
'segment_ids': 'serving_default_input_type_ids:0'
}
@dataclasses.dataclass
class BertModelSpec:
"""Specification for a BERT model.
See https://arxiv.org/abs/1810.04805 (BERT: Pre-training of Deep Bidirectional
Transformers for Language Understanding) for more details.
Attributes:
hparams: Hyperparameters used for training.
model_options: Configurable options for a BERT model.
do_lower_case: boolean, whether to lower case the input text. Should be
True / False for uncased / cased models respectively, where the models
are specified by the `uri`.
tflite_input_name: Dict, input names for the TFLite model.
uri: URI for the BERT module.
name: The name of the object.
"""
hparams: hp.BaseHParams = hp.BaseHParams(
epochs=3,
batch_size=32,
learning_rate=3e-5,
distribution_strategy='mirrored')
model_options: bert_model_options.BertModelOptions = (
bert_model_options.BertModelOptions())
do_lower_case: bool = True
tflite_input_name: Dict[str, str] = dataclasses.field(
default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME)
uri: str = 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1'
name: str = 'Bert'

View File

@ -0,0 +1,146 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict test compatibility macro.
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
licenses(["notice"])
py_library(
name = "model_options",
srcs = ["model_options.py"],
deps = ["//mediapipe/model_maker/python/text/core:bert_model_options"],
)
py_library(
name = "model_spec",
srcs = ["model_spec.py"],
deps = [
":model_options",
"//mediapipe/model_maker/python/core:hyperparameters",
"//mediapipe/model_maker/python/text/core:bert_model_spec",
],
)
py_test(
name = "model_spec_test",
srcs = ["model_spec_test.py"],
deps = [
":model_options",
":model_spec",
"//mediapipe/model_maker/python/core:hyperparameters",
],
)
py_library(
name = "dataset",
srcs = ["dataset.py"],
deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"],
)
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
deps = [":dataset"],
)
py_library(
name = "preprocessor",
srcs = ["preprocessor.py"],
deps = [":dataset"],
)
py_test(
name = "preprocessor_test",
srcs = ["preprocessor_test.py"],
tags = ["requires-net:external"],
deps = [
":dataset",
":model_spec",
":preprocessor",
],
)
py_library(
name = "text_classifier_options",
srcs = ["text_classifier_options.py"],
deps = [
":model_options",
":model_spec",
"//mediapipe/model_maker/python/core:hyperparameters",
],
)
py_library(
name = "text_classifier",
srcs = ["text_classifier.py"],
deps = [
":dataset",
":model_options",
":model_spec",
":preprocessor",
":text_classifier_options",
"//mediapipe/model_maker/python/core:hyperparameters",
"//mediapipe/model_maker/python/core/data:dataset",
"//mediapipe/model_maker/python/core/tasks:classifier",
"//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/model_maker/python/core/utils:quantization",
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
"//mediapipe/tasks/python/metadata/metadata_writers:text_classifier",
],
)
py_test(
name = "text_classifier_test",
size = "large",
srcs = ["text_classifier_test.py"],
data = [
"//mediapipe/model_maker/python/text/text_classifier/testdata",
],
tags = ["requires-net:external"],
deps = [
":dataset",
":model_options",
":model_spec",
":text_classifier",
":text_classifier_options",
"//mediapipe/model_maker/python/core:hyperparameters",
"//mediapipe/tasks/python/test:test_utils",
],
)
py_library(
name = "text_classifier_demo_lib",
srcs = ["text_classifier_demo.py"],
deps = [
":dataset",
":model_spec",
":text_classifier",
":text_classifier_options",
"//mediapipe/model_maker/python/core:hyperparameters",
"//mediapipe/model_maker/python/core/utils:quantization",
],
)
py_binary(
name = "text_classifier_demo",
srcs = ["text_classifier_demo.py"],
deps = [
":text_classifier_demo_lib",
],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,88 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text classifier dataset library."""
import csv
import dataclasses
import random
from typing import Optional, Sequence
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset
@dataclasses.dataclass
class CSVParameters:
"""Parameters used when reading a CSV file.
Attributes:
text_column: Column name for the input text.
label_column: Column name for the labels.
fieldnames: Sequence of keys for the CSV columns. If None, the first row of
the CSV file is used as the keys.
delimiter: Character that separates fields.
quotechar: Character used to quote fields that contain special characters
like the `delimiter`.
"""
text_column: str
label_column: str
fieldnames: Optional[Sequence[str]] = None
delimiter: str = ","
quotechar: str = '"'
class Dataset(classification_dataset.ClassificationDataset):
"""Dataset library for text classifier."""
@classmethod
def from_csv(cls,
filename: str,
csv_params: CSVParameters,
shuffle: bool = True) -> "Dataset":
"""Loads text with labels from a CSV file.
Args:
filename: Name of the CSV file.
csv_params: Parameters used for reading the CSV file.
shuffle: If True, randomly shuffle the data.
Returns:
Dataset containing (text, label) pairs and other related info.
"""
with tf.io.gfile.GFile(filename, "r") as f:
reader = csv.DictReader(
f,
fieldnames=csv_params.fieldnames,
delimiter=csv_params.delimiter,
quotechar=csv_params.quotechar)
lines = list(reader)
if shuffle:
random.shuffle(lines)
label_names = sorted(set([line[csv_params.label_column] for line in lines]))
index_by_label = {label: index for index, label in enumerate(label_names)}
texts = [line[csv_params.text_column] for line in lines]
text_ds = tf.data.Dataset.from_tensor_slices(tf.cast(texts, tf.string))
label_indices = [
index_by_label[line[csv_params.label_column]] for line in lines
]
label_index_ds = tf.data.Dataset.from_tensor_slices(
tf.cast(label_indices, tf.int64))
text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds))
return Dataset(
dataset=text_label_ds, size=len(texts), label_names=label_names)

View File

@ -0,0 +1,75 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import os
import tensorflow as tf
from mediapipe.model_maker.python.text.text_classifier import dataset
class DatasetTest(tf.test.TestCase):
def _get_csv_file(self):
labels_and_text = (('neutral', 'indifferent'), ('pos', 'extremely great'),
('neg', 'totally awful'), ('pos', 'super good'),
('neg', 'really bad'))
csv_file = os.path.join(self.get_temp_dir(), 'data.csv')
if os.path.exists(csv_file):
return csv_file
fieldnames = ['text', 'label']
with open(csv_file, 'w') as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for label, text in labels_and_text:
writer.writerow({'text': text, 'label': label})
return csv_file
def test_from_csv(self):
csv_file = self._get_csv_file()
csv_params = dataset.CSVParameters(text_column='text', label_column='label')
data = dataset.Dataset.from_csv(filename=csv_file, csv_params=csv_params)
self.assertLen(data, 5)
self.assertEqual(data.num_classes, 3)
self.assertEqual(data.label_names, ['neg', 'neutral', 'pos'])
data_values = set([(text.numpy()[0], label.numpy()[0])
for text, label in data.gen_tf_dataset()])
expected_data_values = set([(b'indifferent', 1), (b'extremely great', 2),
(b'totally awful', 0), (b'super good', 2),
(b'really bad', 0)])
self.assertEqual(data_values, expected_data_values)
def test_split(self):
ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd'])
data = dataset.Dataset(ds, 4, ['pos', 'neg'])
train_data, test_data = data.split(0.5)
expected_train_data = [b'good', b'bad']
expected_test_data = [b'neutral', b'odd']
self.assertLen(train_data, 2)
train_data_values = [elem.numpy() for elem in train_data._dataset]
self.assertEqual(train_data_values, expected_train_data)
self.assertEqual(train_data.num_classes, 2)
self.assertEqual(train_data.label_names, ['pos', 'neg'])
self.assertLen(test_data, 2)
test_data_values = [elem.numpy() for elem in test_data._dataset]
self.assertEqual(test_data_values, expected_test_data)
self.assertEqual(test_data.num_classes, 2)
self.assertEqual(test_data.label_names, ['pos', 'neg'])
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,45 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configurable model options for text classifier models."""
import dataclasses
from typing import Union
from mediapipe.model_maker.python.text.core import bert_model_options
# BERT text classifier options inherited from BertModelOptions.
BertClassifierOptions = bert_model_options.BertModelOptions
@dataclasses.dataclass
class AverageWordEmbeddingClassifierOptions:
"""Configurable model options for an Average Word Embedding classifier.
Attributes:
seq_len: Length of the sequence to feed into the model.
wordvec_dim: Dimension of the word embedding.
do_lower_case: Whether to convert all uppercase characters to lowercase
during preprocessing.
vocab_size: Number of words to generate the vocabulary from data.
dropout_rate: The rate for dropout.
"""
seq_len: int = 256
wordvec_dim: int = 16
do_lower_case: bool = True
vocab_size: int = 10000
dropout_rate: float = 0.2
TextClassifierModelOptions = Union[AverageWordEmbeddingClassifierOptions,
BertClassifierOptions]

View File

@ -0,0 +1,70 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Specifications for text classifier models."""
import dataclasses
import enum
import functools
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.text.core import bert_model_spec
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
# BERT-based text classifier spec inherited from BertModelSpec
BertClassifierSpec = bert_model_spec.BertModelSpec
@dataclasses.dataclass
class AverageWordEmbeddingClassifierSpec:
"""Specification for an average word embedding classifier model.
Attributes:
hparams: Configurable hyperparameters for training.
model_options: Configurable options for the average word embedding model.
name: The name of the object.
"""
# `learning_rate` is unused for the average word embedding model
hparams: hp.BaseHParams = hp.BaseHParams(
epochs=10, batch_size=32, learning_rate=0)
model_options: mo.AverageWordEmbeddingClassifierOptions = (
mo.AverageWordEmbeddingClassifierOptions())
name: str = 'AverageWordEmbedding'
average_word_embedding_classifier_spec = functools.partial(
AverageWordEmbeddingClassifierSpec)
mobilebert_classifier_spec = functools.partial(
BertClassifierSpec,
hparams=hp.BaseHParams(
epochs=3,
batch_size=48,
learning_rate=3e-5,
distribution_strategy='off'),
name='MobileBert',
uri='https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1',
tflite_input_name={
'ids': 'serving_default_input_1:0',
'mask': 'serving_default_input_3:0',
'segment_ids': 'serving_default_input_2:0'
},
)
@enum.unique
class SupportedModels(enum.Enum):
"""Predefined text classifier model specs supported by Model Maker."""
AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec
MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec

View File

@ -0,0 +1,118 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for model_spec."""
import os
import tensorflow as tf
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import model_options as classifier_model_options
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
class ModelSpecTest(tf.test.TestCase):
def test_predefined_bert_spec(self):
model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value()
self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec)
self.assertEqual(model_spec_obj.name, 'MobileBert')
self.assertEqual(
model_spec_obj.uri, 'https://tfhub.dev/tensorflow/'
'mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1')
self.assertTrue(model_spec_obj.do_lower_case)
self.assertEqual(
model_spec_obj.tflite_input_name, {
'ids': 'serving_default_input_1:0',
'mask': 'serving_default_input_3:0',
'segment_ids': 'serving_default_input_2:0'
})
self.assertEqual(
model_spec_obj.model_options,
classifier_model_options.BertClassifierOptions(
seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
self.assertEqual(
model_spec_obj.hparams,
hp.BaseHParams(
epochs=3,
batch_size=48,
learning_rate=3e-5,
distribution_strategy='off'))
def test_predefined_average_word_embedding_spec(self):
model_spec_obj = (
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER.value())
self.assertIsInstance(model_spec_obj, ms.AverageWordEmbeddingClassifierSpec)
self.assertEqual(model_spec_obj.name, 'AverageWordEmbedding')
self.assertEqual(
model_spec_obj.model_options,
classifier_model_options.AverageWordEmbeddingClassifierOptions(
seq_len=256,
wordvec_dim=16,
do_lower_case=True,
vocab_size=10000,
dropout_rate=0.2))
self.assertEqual(
model_spec_obj.hparams,
hp.BaseHParams(
epochs=10,
batch_size=32,
learning_rate=0,
steps_per_epoch=None,
shuffle=False,
distribution_strategy='off',
num_gpus=-1,
tpu=''))
def test_custom_bert_spec(self):
custom_bert_classifier_options = (
classifier_model_options.BertClassifierOptions(
seq_len=512, do_fine_tuning=False, dropout_rate=0.3))
model_spec_obj = (
ms.SupportedModels.MOBILEBERT_CLASSIFIER.value(
model_options=custom_bert_classifier_options))
self.assertEqual(model_spec_obj.model_options,
custom_bert_classifier_options)
def test_custom_average_word_embedding_spec(self):
custom_hparams = hp.BaseHParams(
learning_rate=0.4,
batch_size=64,
epochs=10,
steps_per_epoch=10,
shuffle=True,
export_dir='foo/bar',
distribution_strategy='mirrored',
num_gpus=3,
tpu='tpu/address')
custom_average_word_embedding_model_options = (
classifier_model_options.AverageWordEmbeddingClassifierOptions(
seq_len=512,
wordvec_dim=32,
do_lower_case=False,
vocab_size=5000,
dropout_rate=0.5))
model_spec_obj = (
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER.value(
model_options=custom_average_word_embedding_model_options,
hparams=custom_hparams))
self.assertEqual(model_spec_obj.model_options,
custom_average_word_embedding_model_options)
self.assertEqual(model_spec_obj.hparams, custom_hparams)
if __name__ == '__main__':
# Load compressed models from tensorflow_hub
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
tf.test.main()

View File

@ -0,0 +1,285 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocessors for text classification."""
import collections
import os
import re
import tempfile
from typing import Mapping, Sequence, Tuple, Union
import tensorflow as tf
import tensorflow_hub
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
from official.nlp.data import classifier_data_lib
from official.nlp.tools import tokenization
def _validate_text_and_label(text: tf.Tensor, label: tf.Tensor) -> None:
"""Validates the shape and type of `text` and `label`.
Args:
text: Stores text data. Should have shape [1] and dtype tf.string.
label: Stores the label for the corresponding `text`. Should have shape [1]
and dtype tf.int64.
Raises:
ValueError: If either tensor has the wrong shape or type.
"""
if text.shape != [1]:
raise ValueError(f"`text` should have shape [1], got {text.shape}")
if text.dtype != tf.string:
raise ValueError(f"Expected dtype string for `text`, got {text.dtype}")
if label.shape != [1]:
raise ValueError(f"`label` should have shape [1], got {text.shape}")
if label.dtype != tf.int64:
raise ValueError(f"Expected dtype int64 for `label`, got {label.dtype}")
def _decode_record(
record: tf.Tensor, name_to_features: Mapping[str, tf.io.FixedLenFeature]
) -> Tuple[Mapping[str, tf.Tensor], tf.Tensor]:
"""Decodes a record into input for a BERT model.
Args:
record: Stores serialized example.
name_to_features: Maps record keys to feature types.
Returns:
BERT model input features and label for the record.
"""
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
for name in list(example.keys()):
example[name] = tf.cast(example[name], tf.int32)
bert_features = {
"input_word_ids": example["input_ids"],
"input_mask": example["input_mask"],
"input_type_ids": example["segment_ids"]
}
return bert_features, example["label_ids"]
def _single_file_dataset(
input_file: str, name_to_features: Mapping[str, tf.io.FixedLenFeature]
) -> tf.data.TFRecordDataset:
"""Creates a single-file dataset to be passed for BERT custom training.
Args:
input_file: Filepath for the dataset.
name_to_features: Maps record keys to feature types.
Returns:
Dataset containing BERT model input features and labels.
"""
d = tf.data.TFRecordDataset(input_file)
d = d.map(
lambda record: _decode_record(record, name_to_features),
num_parallel_calls=tf.data.AUTOTUNE)
return d
class AverageWordEmbeddingClassifierPreprocessor:
"""Preprocessor for an Average Word Embedding model.
Takes (text, label) data and applies regex tokenization and padding to the
text to generate (token IDs, label) data.
Attributes:
seq_len: Length of the input sequence to the model.
do_lower_case: Whether text inputs should be converted to lower-case.
vocab: Vocabulary of tokens used by the model.
"""
PAD: str = "<PAD>" # Index: 0
START: str = "<START>" # Index: 1
UNKNOWN: str = "<UNKNOWN>" # Index: 2
def __init__(self, seq_len: int, do_lower_case: bool, texts: Sequence[str],
vocab_size: int):
self._seq_len = seq_len
self._do_lower_case = do_lower_case
self._vocab = self._gen_vocab(texts, vocab_size)
def _gen_vocab(self, texts: Sequence[str],
vocab_size: int) -> Mapping[str, int]:
"""Generates vocabulary list in `texts` with size `vocab_size`.
Args:
texts: All texts (across training and validation data) that will be
preprocessed by the model.
vocab_size: Size of the vocab.
Returns:
The vocab mapping tokens to IDs.
"""
vocab_counter = collections.Counter()
for text in texts:
tokens = self._regex_tokenize(text)
for token in tokens:
vocab_counter[token] += 1
vocab_freq = vocab_counter.most_common(vocab_size)
vocab_list = [self.PAD, self.START, self.UNKNOWN
] + [word for word, _ in vocab_freq]
return collections.OrderedDict(((v, i) for i, v in enumerate(vocab_list)))
def get_vocab(self) -> Mapping[str, int]:
"""Returns the vocab of the AverageWordEmbeddingClassifierPreprocessor."""
return self._vocab
# TODO: Align with MediaPipe's RegexTokenizer.
def _regex_tokenize(self, text: str) -> Sequence[str]:
"""Splits `text` by words but does not split on single quotes.
Args:
text: Text to be tokenized.
Returns:
List of tokens.
"""
text = tf.compat.as_text(text)
if self._do_lower_case:
text = text.lower()
tokens = re.compile(r"[^\w\']+").split(text.strip())
# Filters out any empty strings in `tokens`.
return list(filter(None, tokens))
def _tokenize_and_pad(self, text: str) -> Sequence[int]:
"""Tokenizes `text` and pads the tokens to `seq_len`.
Args:
text: Text to be tokenized and padded.
Returns:
List of token IDs padded to have length `seq_len`.
"""
tokens = self._regex_tokenize(text)
# Gets ids for START, PAD and UNKNOWN tokens.
start_id = self._vocab[self.START]
pad_id = self._vocab[self.PAD]
unknown_id = self._vocab[self.UNKNOWN]
token_ids = [self._vocab.get(token, unknown_id) for token in tokens]
token_ids = [start_id] + token_ids
if len(token_ids) < self._seq_len:
pad_length = self._seq_len - len(token_ids)
token_ids = token_ids + pad_length * [pad_id]
else:
token_ids = token_ids[:self._seq_len]
return token_ids
def preprocess(
self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset:
"""Preprocesses data into input for an Average Word Embedding model.
Args:
dataset: Stores (text, label) data.
Returns:
Dataset containing (token IDs, label) data.
"""
token_ids_list = []
labels_list = []
for text, label in dataset.gen_tf_dataset():
_validate_text_and_label(text, label)
token_ids = self._tokenize_and_pad(text.numpy()[0].decode("utf-8"))
token_ids_list.append(token_ids)
labels_list.append(label.numpy()[0])
token_ids_ds = tf.data.Dataset.from_tensor_slices(token_ids_list)
labels_ds = tf.data.Dataset.from_tensor_slices(labels_list)
preprocessed_ds = tf.data.Dataset.zip((token_ids_ds, labels_ds))
return text_classifier_ds.Dataset(
dataset=preprocessed_ds,
size=dataset.size,
label_names=dataset.label_names)
class BertClassifierPreprocessor:
"""Preprocessor for a BERT-based classifier.
Attributes:
seq_len: Length of the input sequence to the model.
vocab_file: File containing the BERT vocab.
tokenizer: BERT tokenizer.
"""
def __init__(self, seq_len: int, do_lower_case: bool, uri: str):
self._seq_len = seq_len
# Vocab filepath is tied to the BERT module's URI.
self._vocab_file = os.path.join(
tensorflow_hub.resolve(uri), "assets", "vocab.txt")
self._tokenizer = tokenization.FullTokenizer(self._vocab_file,
do_lower_case)
def _get_name_to_features(self):
"""Gets the dictionary mapping record keys to feature types."""
return {
"input_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
"input_mask": tf.io.FixedLenFeature([self._seq_len], tf.int64),
"segment_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
"label_ids": tf.io.FixedLenFeature([], tf.int64),
}
def get_vocab_file(self) -> str:
"""Returns the vocab file of the BertClassifierPreprocessor."""
return self._vocab_file
def preprocess(
self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset:
"""Preprocesses data into input for a BERT-based classifier.
Args:
dataset: Stores (text, label) data.
Returns:
Dataset containing (bert_features, label) data.
"""
examples = []
for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
_validate_text_and_label(text, label)
examples.append(
classifier_data_lib.InputExample(
guid=str(index),
text_a=text.numpy()[0].decode("utf-8"),
text_b=None,
# InputExample expects the label name rather than the int ID
label=dataset.label_names[label.numpy()[0]]))
tfrecord_file = os.path.join(tempfile.mkdtemp(), "bert_features.tfrecord")
classifier_data_lib.file_based_convert_examples_to_features(
examples=examples,
label_list=dataset.label_names,
max_seq_length=self._seq_len,
tokenizer=self._tokenizer,
output_file=tfrecord_file)
preprocessed_ds = _single_file_dataset(tfrecord_file,
self._get_name_to_features())
return text_classifier_ds.Dataset(
dataset=preprocessed_ds,
size=dataset.size,
label_names=dataset.label_names)
TextClassifierPreprocessor = (
Union[BertClassifierPreprocessor,
AverageWordEmbeddingClassifierPreprocessor])

View File

@ -0,0 +1,96 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import os
import numpy as np
import numpy.testing as npt
import tensorflow as tf
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
from mediapipe.model_maker.python.text.text_classifier import model_spec
from mediapipe.model_maker.python.text.text_classifier import preprocessor
class PreprocessorTest(tf.test.TestCase):
CSV_PARAMS_ = text_classifier_ds.CSVParameters(
text_column='text', label_column='label')
def _get_csv_file(self):
labels_and_text = (('pos', 'super super super super good'),
(('neg', 'really bad')))
csv_file = os.path.join(self.get_temp_dir(), 'data.csv')
if os.path.exists(csv_file):
return csv_file
fieldnames = ['text', 'label']
with open(csv_file, 'w') as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for label, text in labels_and_text:
writer.writerow({'text': text, 'label': label})
return csv_file
def test_average_word_embedding_preprocessor(self):
csv_file = self._get_csv_file()
dataset = text_classifier_ds.Dataset.from_csv(
filename=csv_file, csv_params=self.CSV_PARAMS_)
average_word_embedding_preprocessor = (
preprocessor.AverageWordEmbeddingClassifierPreprocessor(
seq_len=5,
do_lower_case=True,
texts=['super super super super good', 'really bad'],
vocab_size=7))
preprocessed_dataset = (
average_word_embedding_preprocessor.preprocess(dataset))
labels = []
features_list = []
for features, label in preprocessed_dataset.gen_tf_dataset():
self.assertEqual(label.shape, [1])
labels.append(label.numpy()[0])
self.assertEqual(features.shape, [1, 5])
features_list.append(features.numpy()[0])
self.assertEqual(labels, [1, 0])
npt.assert_array_equal(
np.stack(features_list), np.array([[1, 3, 3, 3, 3], [1, 5, 6, 0, 0]]))
def test_bert_preprocessor(self):
csv_file = self._get_csv_file()
dataset = text_classifier_ds.Dataset.from_csv(
filename=csv_file, csv_params=self.CSV_PARAMS_)
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=5, do_lower_case=bert_spec.do_lower_case, uri=bert_spec.uri)
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
labels = []
input_masks = []
for features, label in preprocessed_dataset.gen_tf_dataset():
self.assertEqual(label.shape, [1])
labels.append(label.numpy()[0])
self.assertSameElements(
features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids'])
for feature in features.values():
self.assertEqual(feature.shape, [1, 5])
input_masks.append(features['input_mask'].numpy()[0])
npt.assert_array_equal(features['input_type_ids'].numpy()[0],
[0, 0, 0, 0, 0])
npt.assert_array_equal(
np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]))
self.assertEqual(labels, [1, 0])
if __name__ == '__main__':
# Load compressed models from tensorflow_hub
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
tf.test.main()

View File

@ -0,0 +1,23 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(
default_visibility = ["//mediapipe/model_maker/python/text/text_classifier:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)
filegroup(
name = "testdata",
srcs = ["average_word_embedding_metadata.json"],
)

View File

@ -0,0 +1,63 @@
{
"name": "TextClassifier",
"description": "Classify the input text into a set of known categories.",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "input_text",
"description": "Embedding vectors representing the input text to be processed.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"process_units": [
{
"options_type": "RegexTokenizerOptions",
"options": {
"delim_regex_pattern": "[^\\w\\']+",
"vocab_file": [
{
"name": "vocab.txt",
"description": "Vocabulary file to convert natural language words to embedding vectors.",
"type": "VOCABULARY"
}
]
}
}
],
"stats": {
}
}
],
"output_tensor_metadata": [
{
"name": "score",
"description": "Score of the labels respectively.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS"
}
]
}
]
}
],
"min_parser_version": "1.2.1"
}

View File

@ -0,0 +1,437 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""API for text classification."""
import abc
import os
import tempfile
from typing import Any, Optional, Sequence, Tuple
import tensorflow as tf
import tensorflow_hub as hub
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
from mediapipe.model_maker.python.text.text_classifier import preprocessor
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import text_classifier as text_classifier_writer
from official.nlp import optimization
def _validate(options: text_classifier_options.TextClassifierOptions):
"""Validates that `model_options` and `supported_model` are compatible.
Args:
options: Options for creating and training a text classifier.
Raises:
ValueError if there is a mismatch between `model_options` and
`supported_model`.
"""
if options.model_options is None:
return
if (isinstance(options.model_options,
mo.AverageWordEmbeddingClassifierOptions) and
(options.supported_model !=
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
f" got {options.supported_model}")
if (isinstance(options.model_options, mo.BertClassifierOptions) and
(options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
raise ValueError(
f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
class TextClassifier(classifier.Classifier):
"""API for creating and training a text classification model."""
def __init__(self, model_spec: Any, hparams: hp.BaseHParams,
label_names: Sequence[str]):
super().__init__(
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
self._model_spec = model_spec
self._hparams = hparams
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None
@classmethod
def create(
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
options: text_classifier_options.TextClassifierOptions
) -> "TextClassifier":
"""Factory function that creates and trains a text classifier.
Note that `train_data` and `validation_data` are expected to share the same
`label_names` since they should be split from the same dataset.
Args:
train_data: Training data.
validation_data: Validation data.
options: Options for creating and training the text classifier.
Returns:
A text classifier.
Raises:
ValueError if `train_data` and `validation_data` do not have the
same label_names or `options` contains an unknown `supported_model`
"""
if train_data.label_names != validation_data.label_names:
raise ValueError(
f"Training data label names {train_data.label_names} not equal to "
f"validation data label names {validation_data.label_names}")
_validate(options)
if options.model_options is None:
options.model_options = options.supported_model.value().model_options
if options.hparams is None:
options.hparams = options.supported_model.value().hparams
if options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
text_classifier = (
_BertClassifier.create_bert_classifier(train_data, validation_data,
options,
train_data.label_names))
elif (options.supported_model ==
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
text_classifier = (
_AverageWordEmbeddingClassifier
.create_average_word_embedding_classifier(train_data, validation_data,
options,
train_data.label_names))
else:
raise ValueError(f"Unknown model {options.supported_model}")
return text_classifier
def evaluate(self, data: ds.Dataset, batch_size: int = 32) -> Any:
"""Overrides Classifier.evaluate().
Args:
data: Evaluation dataset. Must be a TextClassifier Dataset.
batch_size: Number of samples per evaluation step.
Returns:
The loss value and accuracy.
Raises:
ValueError if `data` is not a TextClassifier Dataset.
"""
# This override is needed because TextClassifier preprocesses its data
# outside of the `gen_tf_dataset()` method. The preprocess call also
# requires a TextClassifier Dataset instead of a core Dataset.
if not isinstance(data, text_ds.Dataset):
raise ValueError("Need a TextClassifier Dataset.")
processed_data = self._text_preprocessor.preprocess(data)
dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)
return self._model.evaluate(dataset)
def export_model(
self,
model_name: str = "model.tflite",
quantization_config: Optional[quantization.QuantizationConfig] = None):
"""Converts and saves the model to a TFLite file with metadata included.
Note that only the TFLite file is needed for deployment. This function also
saves a metadata.json file to the same directory as the TFLite file which
can be used to interpret the metadata content in the TFLite file.
Args:
model_name: File name to save TFLite model with metadata. The full export
path is {self._hparams.export_dir}/{model_name}.
quantization_config: The configuration for model quantization.
"""
if not tf.io.gfile.exists(self._hparams.export_dir):
tf.io.gfile.makedirs(self._hparams.export_dir)
tflite_file = os.path.join(self._hparams.export_dir, model_name)
metadata_file = os.path.join(self._hparams.export_dir, "metadata.json")
tflite_model = model_util.convert_to_tflite(
model=self._model, quantization_config=quantization_config)
vocab_filepath = os.path.join(tempfile.mkdtemp(), "vocab.txt")
self._save_vocab(vocab_filepath)
writer = self._get_metadata_writer(tflite_model, vocab_filepath)
tflite_model_with_metadata, metadata_json = writer.populate()
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
with open(metadata_file, "w") as f:
f.write(metadata_json)
@abc.abstractmethod
def _save_vocab(self, vocab_filepath: str):
"""Saves the preprocessor's vocab to `vocab_filepath`."""
@abc.abstractmethod
def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str):
"""Gets the metadata writer for the text classifier TFLite model."""
class _AverageWordEmbeddingClassifier(TextClassifier):
"""APIs to help create and train an Average Word Embedding text classifier."""
_DELIM_REGEX_PATTERN = r"[^\w\']+"
def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
model_options: mo.AverageWordEmbeddingClassifierOptions,
hparams: hp.BaseHParams, label_names: Sequence[str]):
super().__init__(model_spec, hparams, label_names)
self._model_options = model_options
self._loss_function = "sparse_categorical_crossentropy"
self._metric_function = "accuracy"
self._text_preprocessor: (
preprocessor.AverageWordEmbeddingClassifierPreprocessor) = None
@classmethod
def create_average_word_embedding_classifier(
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
options: text_classifier_options.TextClassifierOptions,
label_names: Sequence[str]) -> "_AverageWordEmbeddingClassifier":
"""Creates, trains, and returns an Average Word Embedding classifier.
Args:
train_data: Training data.
validation_data: Validation data.
options: Options for creating and training the text classifier.
label_names: Label names used in the data.
Returns:
An Average Word Embedding classifier.
"""
average_word_embedding_classifier = _AverageWordEmbeddingClassifier(
model_spec=options.supported_model.value(),
model_options=options.model_options,
hparams=options.hparams,
label_names=train_data.label_names)
average_word_embedding_classifier._create_and_train_model(
train_data, validation_data)
return average_word_embedding_classifier
def _create_and_train_model(self, train_data: text_ds.Dataset,
validation_data: text_ds.Dataset):
"""Creates the Average Word Embedding classifier keras model and trains it.
Args:
train_data: Training data.
validation_data: Validation data.
"""
(processed_train_data, processed_validation_data) = (
self._load_and_run_preprocessor(train_data, validation_data))
self._create_model()
self._optimizer = "rmsprop"
self._train_model(processed_train_data, processed_validation_data)
def _load_and_run_preprocessor(
self, train_data: text_ds.Dataset, validation_data: text_ds.Dataset
) -> Tuple[text_ds.Dataset, text_ds.Dataset]:
"""Runs an AverageWordEmbeddingClassifierPreprocessor on the data.
Args:
train_data: Training data.
validation_data: Validation data.
Returns:
Preprocessed training data and preprocessed validation data.
"""
train_texts = [text.numpy()[0] for text, _ in train_data.gen_tf_dataset()]
validation_texts = [
text.numpy()[0] for text, _ in validation_data.gen_tf_dataset()
]
self._text_preprocessor = (
preprocessor.AverageWordEmbeddingClassifierPreprocessor(
seq_len=self._model_options.seq_len,
do_lower_case=self._model_options.do_lower_case,
texts=train_texts + validation_texts,
vocab_size=self._model_options.vocab_size))
return self._text_preprocessor.preprocess(
train_data), self._text_preprocessor.preprocess(validation_data)
def _create_model(self):
"""Creates an Average Word Embedding model."""
self._model = tf.keras.Sequential([
tf.keras.layers.InputLayer(
input_shape=[self._model_options.seq_len], dtype=tf.int32),
tf.keras.layers.Embedding(
len(self._text_preprocessor.get_vocab()),
self._model_options.wordvec_dim,
input_length=self._model_options.seq_len),
tf.keras.layers.GlobalAveragePooling1D(),
tf.keras.layers.Dense(
self._model_options.wordvec_dim, activation=tf.nn.relu),
tf.keras.layers.Dropout(self._model_options.dropout_rate),
tf.keras.layers.Dense(self._num_classes, activation="softmax")
])
def _save_vocab(self, vocab_filepath: str):
with tf.io.gfile.GFile(vocab_filepath, "w") as f:
for token, index in self._text_preprocessor.get_vocab().items():
f.write(f"{token} {index}\n")
def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str):
return text_classifier_writer.MetadataWriter.create_for_regex_model(
model_buffer=tflite_model,
regex_tokenizer=metadata_writer.RegexTokenizer(
# TODO: Align with MediaPipe's RegexTokenizer.
delim_regex_pattern=self._DELIM_REGEX_PATTERN,
vocab_file_path=vocab_filepath),
labels=metadata_writer.Labels().add(list(self._label_names)))
class _BertClassifier(TextClassifier):
"""APIs to help create and train a BERT-based text classifier."""
_INITIALIZER_RANGE = 0.02
def __init__(self, model_spec: ms.BertClassifierSpec,
model_options: mo.BertClassifierOptions, hparams: hp.BaseHParams,
label_names: Sequence[str]):
super().__init__(model_spec, hparams, label_names)
self._model_options = model_options
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy(
"test_accuracy", dtype=tf.float32)
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
@classmethod
def create_bert_classifier(
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
options: text_classifier_options.TextClassifierOptions,
label_names: Sequence[str]) -> "_BertClassifier":
"""Creates, trains, and returns a BERT-based classifier.
Args:
train_data: Training data.
validation_data: Validation data.
options: Options for creating and training the text classifier.
label_names: Label names used in the data.
Returns:
A BERT-based classifier.
"""
bert_classifier = _BertClassifier(
model_spec=options.supported_model.value(),
model_options=options.model_options,
hparams=options.hparams,
label_names=train_data.label_names)
bert_classifier._create_and_train_model(train_data, validation_data)
return bert_classifier
def _create_and_train_model(self, train_data: text_ds.Dataset,
validation_data: text_ds.Dataset):
"""Creates the BERT-based classifier keras model and trains it.
Args:
train_data: Training data.
validation_data: Validation data.
"""
(processed_train_data, processed_validation_data) = (
self._load_and_run_preprocessor(train_data, validation_data))
self._create_model()
self._create_optimizer(processed_train_data)
self._train_model(processed_train_data, processed_validation_data)
def _load_and_run_preprocessor(
self, train_data: text_ds.Dataset, validation_data: text_ds.Dataset
) -> Tuple[text_ds.Dataset, text_ds.Dataset]:
"""Loads a BertClassifierPreprocessor and runs it on the data.
Args:
train_data: Training data.
validation_data: Validation data.
Returns:
Preprocessed training data and preprocessed validation data.
"""
self._text_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=self._model_options.seq_len,
do_lower_case=self._model_spec.do_lower_case,
uri=self._model_spec.uri)
return (self._text_preprocessor.preprocess(train_data),
self._text_preprocessor.preprocess(validation_data))
def _create_model(self):
"""Creates a BERT-based classifier model.
The model architecture consists of stacking a dense classification layer and
dropout layer on top of the BERT encoder outputs.
"""
encoder_inputs = dict(
input_word_ids=tf.keras.layers.Input(
shape=(self._model_options.seq_len,), dtype=tf.int32),
input_mask=tf.keras.layers.Input(
shape=(self._model_options.seq_len,), dtype=tf.int32),
input_type_ids=tf.keras.layers.Input(
shape=(self._model_options.seq_len,), dtype=tf.int32),
)
encoder = hub.KerasLayer(
self._model_spec.uri, trainable=self._model_options.do_fine_tuning)
encoder_outputs = encoder(encoder_inputs)
pooled_output = encoder_outputs["pooled_output"]
output = tf.keras.layers.Dropout(rate=self._model_options.dropout_rate)(
pooled_output)
initializer = tf.keras.initializers.TruncatedNormal(
stddev=self._INITIALIZER_RANGE)
output = tf.keras.layers.Dense(
self._num_classes,
kernel_initializer=initializer,
name="output",
activation="softmax",
dtype=tf.float32)(
output)
self._model = tf.keras.Model(inputs=encoder_inputs, outputs=output)
def _create_optimizer(self, train_data: text_ds.Dataset):
"""Loads an optimizer with a learning rate schedule.
The decay steps in the learning rate schedule depend on the
`steps_per_epoch` which may depend on the size of the training data.
Args:
train_data: Training data.
"""
self._hparams.steps_per_epoch = model_util.get_steps_per_epoch(
steps_per_epoch=self._hparams.steps_per_epoch,
batch_size=self._hparams.batch_size,
train_data=train_data)
total_steps = self._hparams.steps_per_epoch * self._hparams.epochs
warmup_steps = int(total_steps * 0.1)
initial_lr = self._hparams.learning_rate
self._optimizer = optimization.create_optimizer(initial_lr, total_steps,
warmup_steps)
def _save_vocab(self, vocab_filepath: str):
tf.io.gfile.copy(
self._text_preprocessor.get_vocab_file(),
vocab_filepath,
overwrite=True)
def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str):
return text_classifier_writer.MetadataWriter.create_for_bert_model(
model_buffer=tflite_model,
tokenizer=metadata_writer.BertTokenizer(vocab_filepath),
labels=metadata_writer.Labels().add(list(self._label_names)),
ids_name=self._model_spec.tflite_input_name["ids"],
mask_name=self._model_spec.tflite_input_name["mask"],
segment_name=self._model_spec.tflite_input_name["segment_ids"])

View File

@ -0,0 +1,108 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Demo for making a text classifier model by MediaPipe Model Maker."""
import os
import tempfile
# Dependency imports
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
from mediapipe.model_maker.python.text.text_classifier import text_classifier
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
FLAGS = flags.FLAGS
def define_flags():
flags.DEFINE_string('export_dir', None,
'The directory to save exported files.')
flags.DEFINE_enum('supported_model', 'average_word_embedding',
['average_word_embedding', 'bert'],
'The text classifier to run.')
flags.mark_flag_as_required('export_dir')
def download_demo_data():
"""Downloads demo data, and returns directory path."""
data_path = tf.keras.utils.get_file(
fname='SST-2.zip',
origin='https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',
extract=True)
return os.path.join(os.path.dirname(data_path), 'SST-2') # folder name
def run(data_dir,
export_dir=tempfile.mkdtemp(),
supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
"""Runs demo."""
# Gets training data and validation data.
csv_params = text_ds.CSVParameters(
text_column='sentence', label_column='label', delimiter='\t')
train_data = text_ds.Dataset.from_csv(
filename=os.path.join(os.path.join(data_dir, 'train.tsv')),
csv_params=csv_params)
validation_data = text_ds.Dataset.from_csv(
filename=os.path.join(os.path.join(data_dir, 'dev.tsv')),
csv_params=csv_params)
quantization_config = None
if supported_model == ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER:
hparams = hp.BaseHParams(
epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir)
# Warning: This takes extremely long to run on CPU
elif supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
quantization_config = quantization.QuantizationConfig.for_dynamic()
hparams = hp.BaseHParams(
epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir)
# Fine-tunes the model.
options = text_classifier_options.TextClassifierOptions(
supported_model=supported_model, hparams=hparams)
model = text_classifier.TextClassifier.create(train_data, validation_data,
options)
# Gets evaluation results.
_, acc = model.evaluate(validation_data)
print('Eval accuracy: %f' % acc)
model.export_model(quantization_config=quantization_config)
model.export_labels(export_dir=options.hparams.export_dir)
def main(_):
logging.set_verbosity(logging.INFO)
data_dir = download_demo_data()
export_dir = os.path.expanduser(FLAGS.export_dir)
if FLAGS.supported_model == 'average_word_embedding':
supported_model = ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
elif FLAGS.supported_model == 'bert':
supported_model = ms.SupportedModels.MOBILEBERT_CLASSIFIER
run(data_dir, export_dir, supported_model)
if __name__ == '__main__':
define_flags()
app.run(main)

View File

@ -0,0 +1,38 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""User-facing customization options to create and train a text classifier."""
import dataclasses
from typing import Optional
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
@dataclasses.dataclass
class TextClassifierOptions:
"""User-facing options for creating the text classifier.
Attributes:
supported_model: A preconfigured model spec.
hparams: Training hyperparameters the user can set to override the ones in
`supported_model`.
model_options: Model options the user can set to override the ones in
`supported_model`. The model options type should be consistent with the
architecture of the `supported_model`.
"""
supported_model: ms.SupportedModels
hparams: Optional[hp.BaseHParams] = None
model_options: Optional[mo.TextClassifierModelOptions] = None

View File

@ -0,0 +1,138 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import filecmp
import os
import tensorflow as tf
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import dataset
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
from mediapipe.model_maker.python.text.text_classifier import text_classifier
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
from mediapipe.tasks.python.test import test_utils
class TextClassifierTest(tf.test.TestCase):
_AVERAGE_WORD_EMBEDDING_JSON_FILE = (
test_utils.get_test_data_path('average_word_embedding_metadata.json'))
def _get_data(self):
labels_and_text = (('pos', 'super good'), (('neg', 'really bad')))
csv_file = os.path.join(self.get_temp_dir(), 'data.csv')
if os.path.exists(csv_file):
return csv_file
fieldnames = ['text', 'label']
with open(csv_file, 'w') as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for label, text in labels_and_text:
writer.writerow({'text': text, 'label': label})
csv_params = dataset.CSVParameters(text_column='text', label_column='label')
all_data = dataset.Dataset.from_csv(
filename=csv_file, csv_params=csv_params)
return all_data.split(0.5)
def test_create_and_train_average_word_embedding_model(self):
train_data, validation_data = self._get_data()
options = text_classifier_options.TextClassifierOptions(
supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER,
hparams=hp.BaseHParams(epochs=1, batch_size=1, learning_rate=0))
average_word_embedding_classifier = text_classifier.TextClassifier.create(
train_data, validation_data, options)
_, accuracy = average_word_embedding_classifier.evaluate(validation_data)
self.assertGreaterEqual(accuracy, 0.0)
# Test export_model
average_word_embedding_classifier.export_model()
output_metadata_file = os.path.join(options.hparams.export_dir,
'metadata.json')
output_tflite_file = os.path.join(options.hparams.export_dir,
'model.tflite')
self.assertTrue(os.path.exists(output_tflite_file))
self.assertGreater(os.path.getsize(output_tflite_file), 0)
self.assertTrue(os.path.exists(output_metadata_file))
self.assertGreater(os.path.getsize(output_metadata_file), 0)
self.assertTrue(
filecmp.cmp(output_metadata_file,
self._AVERAGE_WORD_EMBEDDING_JSON_FILE))
def test_create_and_train_bert(self):
train_data, validation_data = self._get_data()
options = text_classifier_options.TextClassifierOptions(
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER,
model_options=mo.BertClassifierOptions(do_fine_tuning=False, seq_len=2),
hparams=hp.BaseHParams(
epochs=1,
batch_size=1,
learning_rate=3e-5,
distribution_strategy='off'))
bert_classifier = text_classifier.TextClassifier.create(
train_data, validation_data, options)
_, accuracy = bert_classifier.evaluate(validation_data)
self.assertGreaterEqual(accuracy, 0.0)
# TODO: Add a unit test that does not run OOM.
def test_label_mismatch(self):
options = (
text_classifier_options.TextClassifierOptions(
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER))
train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
train_data = dataset.Dataset(train_tf_dataset, 1, ['foo'])
validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
validation_data = dataset.Dataset(validation_tf_dataset, 1, ['bar'])
with self.assertRaisesRegex(
ValueError,
'Training data label names .* not equal to validation data label names'
):
text_classifier.TextClassifier.create(train_data, validation_data,
options)
def test_options_mismatch(self):
train_data, validation_data = self._get_data()
avg_options = (
text_classifier_options.TextClassifierOptions(
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER,
model_options=mo.AverageWordEmbeddingClassifierOptions()))
with self.assertRaisesRegex(
ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
' SupportedModels.MOBILEBERT_CLASSIFIER'):
text_classifier.TextClassifier.create(train_data, validation_data,
avg_options)
bert_options = (
text_classifier_options.TextClassifierOptions(
supported_model=(
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
model_options=mo.BertClassifierOptions()))
with self.assertRaisesRegex(
ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):
text_classifier.TextClassifier.create(train_data, validation_data,
bert_options)
if __name__ == '__main__':
# Load compressed models from tensorflow_hub
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
tf.test.main()