Open source Model Maker text tasks.
PiperOrigin-RevId: 487706929
This commit is contained in:
parent
d2284083b3
commit
3e05871f98
35
mediapipe/model_maker/python/text/core/BUILD
Normal file
35
mediapipe/model_maker/python/text/core/BUILD
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Placeholder for internal Python strict library and test compatibility macro.
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//mediapipe:__subpackages__"],
|
||||||
|
)
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "bert_model_options",
|
||||||
|
srcs = ["bert_model_options.py"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "bert_model_spec",
|
||||||
|
srcs = ["bert_model_spec.py"],
|
||||||
|
deps = [
|
||||||
|
":bert_model_options",
|
||||||
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
],
|
||||||
|
)
|
13
mediapipe/model_maker/python/text/core/__init__.py
Normal file
13
mediapipe/model_maker/python/text/core/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
33
mediapipe/model_maker/python/text/core/bert_model_options.py
Normal file
33
mediapipe/model_maker/python/text/core/bert_model_options.py
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Configurable model options for a BERT model."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class BertModelOptions:
|
||||||
|
"""Configurable model options for a BERT model.
|
||||||
|
|
||||||
|
See https://arxiv.org/abs/1810.04805 (BERT: Pre-training of Deep Bidirectional
|
||||||
|
Transformers for Language Understanding) for more details.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
seq_len: Length of the sequence to feed into the model.
|
||||||
|
do_fine_tuning: If true, then the BERT model is not frozen for training.
|
||||||
|
dropout_rate: The rate for dropout.
|
||||||
|
"""
|
||||||
|
seq_len: int = 128
|
||||||
|
do_fine_tuning: bool = True
|
||||||
|
dropout_rate: float = 0.1
|
58
mediapipe/model_maker/python/text/core/bert_model_spec.py
Normal file
58
mediapipe/model_maker/python/text/core/bert_model_spec.py
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Specification for a BERT model."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.text.core import bert_model_options
|
||||||
|
|
||||||
|
_DEFAULT_TFLITE_INPUT_NAME = {
|
||||||
|
'ids': 'serving_default_input_word_ids:0',
|
||||||
|
'mask': 'serving_default_input_mask:0',
|
||||||
|
'segment_ids': 'serving_default_input_type_ids:0'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class BertModelSpec:
|
||||||
|
"""Specification for a BERT model.
|
||||||
|
|
||||||
|
See https://arxiv.org/abs/1810.04805 (BERT: Pre-training of Deep Bidirectional
|
||||||
|
Transformers for Language Understanding) for more details.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
hparams: Hyperparameters used for training.
|
||||||
|
model_options: Configurable options for a BERT model.
|
||||||
|
do_lower_case: boolean, whether to lower case the input text. Should be
|
||||||
|
True / False for uncased / cased models respectively, where the models
|
||||||
|
are specified by the `uri`.
|
||||||
|
tflite_input_name: Dict, input names for the TFLite model.
|
||||||
|
uri: URI for the BERT module.
|
||||||
|
name: The name of the object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
hparams: hp.BaseHParams = hp.BaseHParams(
|
||||||
|
epochs=3,
|
||||||
|
batch_size=32,
|
||||||
|
learning_rate=3e-5,
|
||||||
|
distribution_strategy='mirrored')
|
||||||
|
model_options: bert_model_options.BertModelOptions = (
|
||||||
|
bert_model_options.BertModelOptions())
|
||||||
|
do_lower_case: bool = True
|
||||||
|
tflite_input_name: Dict[str, str] = dataclasses.field(
|
||||||
|
default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME)
|
||||||
|
uri: str = 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1'
|
||||||
|
name: str = 'Bert'
|
146
mediapipe/model_maker/python/text/text_classifier/BUILD
Normal file
146
mediapipe/model_maker/python/text/text_classifier/BUILD
Normal file
|
@ -0,0 +1,146 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Placeholder for internal Python strict library and test compatibility macro.
|
||||||
|
# Placeholder for internal Python strict test compatibility macro.
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//mediapipe:__subpackages__"],
|
||||||
|
)
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "model_options",
|
||||||
|
srcs = ["model_options.py"],
|
||||||
|
deps = ["//mediapipe/model_maker/python/text/core:bert_model_options"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "model_spec",
|
||||||
|
srcs = ["model_spec.py"],
|
||||||
|
deps = [
|
||||||
|
":model_options",
|
||||||
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
"//mediapipe/model_maker/python/text/core:bert_model_spec",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "model_spec_test",
|
||||||
|
srcs = ["model_spec_test.py"],
|
||||||
|
deps = [
|
||||||
|
":model_options",
|
||||||
|
":model_spec",
|
||||||
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "dataset",
|
||||||
|
srcs = ["dataset.py"],
|
||||||
|
deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "dataset_test",
|
||||||
|
srcs = ["dataset_test.py"],
|
||||||
|
deps = [":dataset"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "preprocessor",
|
||||||
|
srcs = ["preprocessor.py"],
|
||||||
|
deps = [":dataset"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "preprocessor_test",
|
||||||
|
srcs = ["preprocessor_test.py"],
|
||||||
|
tags = ["requires-net:external"],
|
||||||
|
deps = [
|
||||||
|
":dataset",
|
||||||
|
":model_spec",
|
||||||
|
":preprocessor",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "text_classifier_options",
|
||||||
|
srcs = ["text_classifier_options.py"],
|
||||||
|
deps = [
|
||||||
|
":model_options",
|
||||||
|
":model_spec",
|
||||||
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "text_classifier",
|
||||||
|
srcs = ["text_classifier.py"],
|
||||||
|
deps = [
|
||||||
|
":dataset",
|
||||||
|
":model_options",
|
||||||
|
":model_spec",
|
||||||
|
":preprocessor",
|
||||||
|
":text_classifier_options",
|
||||||
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
"//mediapipe/model_maker/python/core/data:dataset",
|
||||||
|
"//mediapipe/model_maker/python/core/tasks:classifier",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||||
|
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||||
|
"//mediapipe/tasks/python/metadata/metadata_writers:text_classifier",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "text_classifier_test",
|
||||||
|
size = "large",
|
||||||
|
srcs = ["text_classifier_test.py"],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/model_maker/python/text/text_classifier/testdata",
|
||||||
|
],
|
||||||
|
tags = ["requires-net:external"],
|
||||||
|
deps = [
|
||||||
|
":dataset",
|
||||||
|
":model_options",
|
||||||
|
":model_spec",
|
||||||
|
":text_classifier",
|
||||||
|
":text_classifier_options",
|
||||||
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "text_classifier_demo_lib",
|
||||||
|
srcs = ["text_classifier_demo.py"],
|
||||||
|
deps = [
|
||||||
|
":dataset",
|
||||||
|
":model_spec",
|
||||||
|
":text_classifier",
|
||||||
|
":text_classifier_options",
|
||||||
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_binary(
|
||||||
|
name = "text_classifier_demo",
|
||||||
|
srcs = ["text_classifier_demo.py"],
|
||||||
|
deps = [
|
||||||
|
":text_classifier_demo_lib",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
88
mediapipe/model_maker/python/text/text_classifier/dataset.py
Normal file
88
mediapipe/model_maker/python/text/text_classifier/dataset.py
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Text classifier dataset library."""
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import dataclasses
|
||||||
|
import random
|
||||||
|
|
||||||
|
from typing import Optional, Sequence
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class CSVParameters:
|
||||||
|
"""Parameters used when reading a CSV file.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
text_column: Column name for the input text.
|
||||||
|
label_column: Column name for the labels.
|
||||||
|
fieldnames: Sequence of keys for the CSV columns. If None, the first row of
|
||||||
|
the CSV file is used as the keys.
|
||||||
|
delimiter: Character that separates fields.
|
||||||
|
quotechar: Character used to quote fields that contain special characters
|
||||||
|
like the `delimiter`.
|
||||||
|
"""
|
||||||
|
text_column: str
|
||||||
|
label_column: str
|
||||||
|
fieldnames: Optional[Sequence[str]] = None
|
||||||
|
delimiter: str = ","
|
||||||
|
quotechar: str = '"'
|
||||||
|
|
||||||
|
|
||||||
|
class Dataset(classification_dataset.ClassificationDataset):
|
||||||
|
"""Dataset library for text classifier."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_csv(cls,
|
||||||
|
filename: str,
|
||||||
|
csv_params: CSVParameters,
|
||||||
|
shuffle: bool = True) -> "Dataset":
|
||||||
|
"""Loads text with labels from a CSV file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Name of the CSV file.
|
||||||
|
csv_params: Parameters used for reading the CSV file.
|
||||||
|
shuffle: If True, randomly shuffle the data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset containing (text, label) pairs and other related info.
|
||||||
|
"""
|
||||||
|
with tf.io.gfile.GFile(filename, "r") as f:
|
||||||
|
reader = csv.DictReader(
|
||||||
|
f,
|
||||||
|
fieldnames=csv_params.fieldnames,
|
||||||
|
delimiter=csv_params.delimiter,
|
||||||
|
quotechar=csv_params.quotechar)
|
||||||
|
|
||||||
|
lines = list(reader)
|
||||||
|
if shuffle:
|
||||||
|
random.shuffle(lines)
|
||||||
|
|
||||||
|
label_names = sorted(set([line[csv_params.label_column] for line in lines]))
|
||||||
|
index_by_label = {label: index for index, label in enumerate(label_names)}
|
||||||
|
|
||||||
|
texts = [line[csv_params.text_column] for line in lines]
|
||||||
|
text_ds = tf.data.Dataset.from_tensor_slices(tf.cast(texts, tf.string))
|
||||||
|
label_indices = [
|
||||||
|
index_by_label[line[csv_params.label_column]] for line in lines
|
||||||
|
]
|
||||||
|
label_index_ds = tf.data.Dataset.from_tensor_slices(
|
||||||
|
tf.cast(label_indices, tf.int64))
|
||||||
|
text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds))
|
||||||
|
|
||||||
|
return Dataset(
|
||||||
|
dataset=text_label_ds, size=len(texts), label_names=label_names)
|
|
@ -0,0 +1,75 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import dataset
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def _get_csv_file(self):
|
||||||
|
labels_and_text = (('neutral', 'indifferent'), ('pos', 'extremely great'),
|
||||||
|
('neg', 'totally awful'), ('pos', 'super good'),
|
||||||
|
('neg', 'really bad'))
|
||||||
|
csv_file = os.path.join(self.get_temp_dir(), 'data.csv')
|
||||||
|
if os.path.exists(csv_file):
|
||||||
|
return csv_file
|
||||||
|
fieldnames = ['text', 'label']
|
||||||
|
with open(csv_file, 'w') as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||||
|
writer.writeheader()
|
||||||
|
for label, text in labels_and_text:
|
||||||
|
writer.writerow({'text': text, 'label': label})
|
||||||
|
return csv_file
|
||||||
|
|
||||||
|
def test_from_csv(self):
|
||||||
|
csv_file = self._get_csv_file()
|
||||||
|
csv_params = dataset.CSVParameters(text_column='text', label_column='label')
|
||||||
|
data = dataset.Dataset.from_csv(filename=csv_file, csv_params=csv_params)
|
||||||
|
self.assertLen(data, 5)
|
||||||
|
self.assertEqual(data.num_classes, 3)
|
||||||
|
self.assertEqual(data.label_names, ['neg', 'neutral', 'pos'])
|
||||||
|
data_values = set([(text.numpy()[0], label.numpy()[0])
|
||||||
|
for text, label in data.gen_tf_dataset()])
|
||||||
|
expected_data_values = set([(b'indifferent', 1), (b'extremely great', 2),
|
||||||
|
(b'totally awful', 0), (b'super good', 2),
|
||||||
|
(b'really bad', 0)])
|
||||||
|
self.assertEqual(data_values, expected_data_values)
|
||||||
|
|
||||||
|
def test_split(self):
|
||||||
|
ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd'])
|
||||||
|
data = dataset.Dataset(ds, 4, ['pos', 'neg'])
|
||||||
|
train_data, test_data = data.split(0.5)
|
||||||
|
expected_train_data = [b'good', b'bad']
|
||||||
|
expected_test_data = [b'neutral', b'odd']
|
||||||
|
|
||||||
|
self.assertLen(train_data, 2)
|
||||||
|
train_data_values = [elem.numpy() for elem in train_data._dataset]
|
||||||
|
self.assertEqual(train_data_values, expected_train_data)
|
||||||
|
self.assertEqual(train_data.num_classes, 2)
|
||||||
|
self.assertEqual(train_data.label_names, ['pos', 'neg'])
|
||||||
|
|
||||||
|
self.assertLen(test_data, 2)
|
||||||
|
test_data_values = [elem.numpy() for elem in test_data._dataset]
|
||||||
|
self.assertEqual(test_data_values, expected_test_data)
|
||||||
|
self.assertEqual(test_data.num_classes, 2)
|
||||||
|
self.assertEqual(test_data.label_names, ['pos', 'neg'])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
|
@ -0,0 +1,45 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Configurable model options for text classifier models."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.text.core import bert_model_options
|
||||||
|
|
||||||
|
# BERT text classifier options inherited from BertModelOptions.
|
||||||
|
BertClassifierOptions = bert_model_options.BertModelOptions
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class AverageWordEmbeddingClassifierOptions:
|
||||||
|
"""Configurable model options for an Average Word Embedding classifier.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
seq_len: Length of the sequence to feed into the model.
|
||||||
|
wordvec_dim: Dimension of the word embedding.
|
||||||
|
do_lower_case: Whether to convert all uppercase characters to lowercase
|
||||||
|
during preprocessing.
|
||||||
|
vocab_size: Number of words to generate the vocabulary from data.
|
||||||
|
dropout_rate: The rate for dropout.
|
||||||
|
"""
|
||||||
|
seq_len: int = 256
|
||||||
|
wordvec_dim: int = 16
|
||||||
|
do_lower_case: bool = True
|
||||||
|
vocab_size: int = 10000
|
||||||
|
dropout_rate: float = 0.2
|
||||||
|
|
||||||
|
|
||||||
|
TextClassifierModelOptions = Union[AverageWordEmbeddingClassifierOptions,
|
||||||
|
BertClassifierOptions]
|
|
@ -0,0 +1,70 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Specifications for text classifier models."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import enum
|
||||||
|
import functools
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.text.core import bert_model_spec
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||||
|
|
||||||
|
# BERT-based text classifier spec inherited from BertModelSpec
|
||||||
|
BertClassifierSpec = bert_model_spec.BertModelSpec
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class AverageWordEmbeddingClassifierSpec:
|
||||||
|
"""Specification for an average word embedding classifier model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
hparams: Configurable hyperparameters for training.
|
||||||
|
model_options: Configurable options for the average word embedding model.
|
||||||
|
name: The name of the object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# `learning_rate` is unused for the average word embedding model
|
||||||
|
hparams: hp.BaseHParams = hp.BaseHParams(
|
||||||
|
epochs=10, batch_size=32, learning_rate=0)
|
||||||
|
model_options: mo.AverageWordEmbeddingClassifierOptions = (
|
||||||
|
mo.AverageWordEmbeddingClassifierOptions())
|
||||||
|
name: str = 'AverageWordEmbedding'
|
||||||
|
|
||||||
|
|
||||||
|
average_word_embedding_classifier_spec = functools.partial(
|
||||||
|
AverageWordEmbeddingClassifierSpec)
|
||||||
|
|
||||||
|
mobilebert_classifier_spec = functools.partial(
|
||||||
|
BertClassifierSpec,
|
||||||
|
hparams=hp.BaseHParams(
|
||||||
|
epochs=3,
|
||||||
|
batch_size=48,
|
||||||
|
learning_rate=3e-5,
|
||||||
|
distribution_strategy='off'),
|
||||||
|
name='MobileBert',
|
||||||
|
uri='https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1',
|
||||||
|
tflite_input_name={
|
||||||
|
'ids': 'serving_default_input_1:0',
|
||||||
|
'mask': 'serving_default_input_3:0',
|
||||||
|
'segment_ids': 'serving_default_input_2:0'
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@enum.unique
|
||||||
|
class SupportedModels(enum.Enum):
|
||||||
|
"""Predefined text classifier model specs supported by Model Maker."""
|
||||||
|
AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec
|
||||||
|
MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec
|
|
@ -0,0 +1,118 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tests for model_spec."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_options as classifier_model_options
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSpecTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test_predefined_bert_spec(self):
|
||||||
|
model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||||
|
self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec)
|
||||||
|
self.assertEqual(model_spec_obj.name, 'MobileBert')
|
||||||
|
self.assertEqual(
|
||||||
|
model_spec_obj.uri, 'https://tfhub.dev/tensorflow/'
|
||||||
|
'mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1')
|
||||||
|
self.assertTrue(model_spec_obj.do_lower_case)
|
||||||
|
self.assertEqual(
|
||||||
|
model_spec_obj.tflite_input_name, {
|
||||||
|
'ids': 'serving_default_input_1:0',
|
||||||
|
'mask': 'serving_default_input_3:0',
|
||||||
|
'segment_ids': 'serving_default_input_2:0'
|
||||||
|
})
|
||||||
|
self.assertEqual(
|
||||||
|
model_spec_obj.model_options,
|
||||||
|
classifier_model_options.BertClassifierOptions(
|
||||||
|
seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
|
||||||
|
self.assertEqual(
|
||||||
|
model_spec_obj.hparams,
|
||||||
|
hp.BaseHParams(
|
||||||
|
epochs=3,
|
||||||
|
batch_size=48,
|
||||||
|
learning_rate=3e-5,
|
||||||
|
distribution_strategy='off'))
|
||||||
|
|
||||||
|
def test_predefined_average_word_embedding_spec(self):
|
||||||
|
model_spec_obj = (
|
||||||
|
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER.value())
|
||||||
|
self.assertIsInstance(model_spec_obj, ms.AverageWordEmbeddingClassifierSpec)
|
||||||
|
self.assertEqual(model_spec_obj.name, 'AverageWordEmbedding')
|
||||||
|
self.assertEqual(
|
||||||
|
model_spec_obj.model_options,
|
||||||
|
classifier_model_options.AverageWordEmbeddingClassifierOptions(
|
||||||
|
seq_len=256,
|
||||||
|
wordvec_dim=16,
|
||||||
|
do_lower_case=True,
|
||||||
|
vocab_size=10000,
|
||||||
|
dropout_rate=0.2))
|
||||||
|
self.assertEqual(
|
||||||
|
model_spec_obj.hparams,
|
||||||
|
hp.BaseHParams(
|
||||||
|
epochs=10,
|
||||||
|
batch_size=32,
|
||||||
|
learning_rate=0,
|
||||||
|
steps_per_epoch=None,
|
||||||
|
shuffle=False,
|
||||||
|
distribution_strategy='off',
|
||||||
|
num_gpus=-1,
|
||||||
|
tpu=''))
|
||||||
|
|
||||||
|
def test_custom_bert_spec(self):
|
||||||
|
custom_bert_classifier_options = (
|
||||||
|
classifier_model_options.BertClassifierOptions(
|
||||||
|
seq_len=512, do_fine_tuning=False, dropout_rate=0.3))
|
||||||
|
model_spec_obj = (
|
||||||
|
ms.SupportedModels.MOBILEBERT_CLASSIFIER.value(
|
||||||
|
model_options=custom_bert_classifier_options))
|
||||||
|
self.assertEqual(model_spec_obj.model_options,
|
||||||
|
custom_bert_classifier_options)
|
||||||
|
|
||||||
|
def test_custom_average_word_embedding_spec(self):
|
||||||
|
custom_hparams = hp.BaseHParams(
|
||||||
|
learning_rate=0.4,
|
||||||
|
batch_size=64,
|
||||||
|
epochs=10,
|
||||||
|
steps_per_epoch=10,
|
||||||
|
shuffle=True,
|
||||||
|
export_dir='foo/bar',
|
||||||
|
distribution_strategy='mirrored',
|
||||||
|
num_gpus=3,
|
||||||
|
tpu='tpu/address')
|
||||||
|
custom_average_word_embedding_model_options = (
|
||||||
|
classifier_model_options.AverageWordEmbeddingClassifierOptions(
|
||||||
|
seq_len=512,
|
||||||
|
wordvec_dim=32,
|
||||||
|
do_lower_case=False,
|
||||||
|
vocab_size=5000,
|
||||||
|
dropout_rate=0.5))
|
||||||
|
model_spec_obj = (
|
||||||
|
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER.value(
|
||||||
|
model_options=custom_average_word_embedding_model_options,
|
||||||
|
hparams=custom_hparams))
|
||||||
|
self.assertEqual(model_spec_obj.model_options,
|
||||||
|
custom_average_word_embedding_model_options)
|
||||||
|
self.assertEqual(model_spec_obj.hparams, custom_hparams)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# Load compressed models from tensorflow_hub
|
||||||
|
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
|
||||||
|
tf.test.main()
|
|
@ -0,0 +1,285 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Preprocessors for text classification."""
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import tempfile
|
||||||
|
from typing import Mapping, Sequence, Tuple, Union
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow_hub
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
|
||||||
|
from official.nlp.data import classifier_data_lib
|
||||||
|
from official.nlp.tools import tokenization
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_text_and_label(text: tf.Tensor, label: tf.Tensor) -> None:
|
||||||
|
"""Validates the shape and type of `text` and `label`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Stores text data. Should have shape [1] and dtype tf.string.
|
||||||
|
label: Stores the label for the corresponding `text`. Should have shape [1]
|
||||||
|
and dtype tf.int64.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If either tensor has the wrong shape or type.
|
||||||
|
"""
|
||||||
|
if text.shape != [1]:
|
||||||
|
raise ValueError(f"`text` should have shape [1], got {text.shape}")
|
||||||
|
if text.dtype != tf.string:
|
||||||
|
raise ValueError(f"Expected dtype string for `text`, got {text.dtype}")
|
||||||
|
if label.shape != [1]:
|
||||||
|
raise ValueError(f"`label` should have shape [1], got {text.shape}")
|
||||||
|
if label.dtype != tf.int64:
|
||||||
|
raise ValueError(f"Expected dtype int64 for `label`, got {label.dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_record(
|
||||||
|
record: tf.Tensor, name_to_features: Mapping[str, tf.io.FixedLenFeature]
|
||||||
|
) -> Tuple[Mapping[str, tf.Tensor], tf.Tensor]:
|
||||||
|
"""Decodes a record into input for a BERT model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
record: Stores serialized example.
|
||||||
|
name_to_features: Maps record keys to feature types.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BERT model input features and label for the record.
|
||||||
|
"""
|
||||||
|
example = tf.io.parse_single_example(record, name_to_features)
|
||||||
|
|
||||||
|
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
|
||||||
|
for name in list(example.keys()):
|
||||||
|
example[name] = tf.cast(example[name], tf.int32)
|
||||||
|
|
||||||
|
bert_features = {
|
||||||
|
"input_word_ids": example["input_ids"],
|
||||||
|
"input_mask": example["input_mask"],
|
||||||
|
"input_type_ids": example["segment_ids"]
|
||||||
|
}
|
||||||
|
return bert_features, example["label_ids"]
|
||||||
|
|
||||||
|
|
||||||
|
def _single_file_dataset(
|
||||||
|
input_file: str, name_to_features: Mapping[str, tf.io.FixedLenFeature]
|
||||||
|
) -> tf.data.TFRecordDataset:
|
||||||
|
"""Creates a single-file dataset to be passed for BERT custom training.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_file: Filepath for the dataset.
|
||||||
|
name_to_features: Maps record keys to feature types.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset containing BERT model input features and labels.
|
||||||
|
"""
|
||||||
|
d = tf.data.TFRecordDataset(input_file)
|
||||||
|
d = d.map(
|
||||||
|
lambda record: _decode_record(record, name_to_features),
|
||||||
|
num_parallel_calls=tf.data.AUTOTUNE)
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
class AverageWordEmbeddingClassifierPreprocessor:
|
||||||
|
"""Preprocessor for an Average Word Embedding model.
|
||||||
|
|
||||||
|
Takes (text, label) data and applies regex tokenization and padding to the
|
||||||
|
text to generate (token IDs, label) data.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
seq_len: Length of the input sequence to the model.
|
||||||
|
do_lower_case: Whether text inputs should be converted to lower-case.
|
||||||
|
vocab: Vocabulary of tokens used by the model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
PAD: str = "<PAD>" # Index: 0
|
||||||
|
START: str = "<START>" # Index: 1
|
||||||
|
UNKNOWN: str = "<UNKNOWN>" # Index: 2
|
||||||
|
|
||||||
|
def __init__(self, seq_len: int, do_lower_case: bool, texts: Sequence[str],
|
||||||
|
vocab_size: int):
|
||||||
|
self._seq_len = seq_len
|
||||||
|
self._do_lower_case = do_lower_case
|
||||||
|
self._vocab = self._gen_vocab(texts, vocab_size)
|
||||||
|
|
||||||
|
def _gen_vocab(self, texts: Sequence[str],
|
||||||
|
vocab_size: int) -> Mapping[str, int]:
|
||||||
|
"""Generates vocabulary list in `texts` with size `vocab_size`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: All texts (across training and validation data) that will be
|
||||||
|
preprocessed by the model.
|
||||||
|
vocab_size: Size of the vocab.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The vocab mapping tokens to IDs.
|
||||||
|
"""
|
||||||
|
vocab_counter = collections.Counter()
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
tokens = self._regex_tokenize(text)
|
||||||
|
for token in tokens:
|
||||||
|
vocab_counter[token] += 1
|
||||||
|
|
||||||
|
vocab_freq = vocab_counter.most_common(vocab_size)
|
||||||
|
vocab_list = [self.PAD, self.START, self.UNKNOWN
|
||||||
|
] + [word for word, _ in vocab_freq]
|
||||||
|
return collections.OrderedDict(((v, i) for i, v in enumerate(vocab_list)))
|
||||||
|
|
||||||
|
def get_vocab(self) -> Mapping[str, int]:
|
||||||
|
"""Returns the vocab of the AverageWordEmbeddingClassifierPreprocessor."""
|
||||||
|
return self._vocab
|
||||||
|
|
||||||
|
# TODO: Align with MediaPipe's RegexTokenizer.
|
||||||
|
def _regex_tokenize(self, text: str) -> Sequence[str]:
|
||||||
|
"""Splits `text` by words but does not split on single quotes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to be tokenized.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tokens.
|
||||||
|
"""
|
||||||
|
text = tf.compat.as_text(text)
|
||||||
|
if self._do_lower_case:
|
||||||
|
text = text.lower()
|
||||||
|
tokens = re.compile(r"[^\w\']+").split(text.strip())
|
||||||
|
# Filters out any empty strings in `tokens`.
|
||||||
|
return list(filter(None, tokens))
|
||||||
|
|
||||||
|
def _tokenize_and_pad(self, text: str) -> Sequence[int]:
|
||||||
|
"""Tokenizes `text` and pads the tokens to `seq_len`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to be tokenized and padded.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of token IDs padded to have length `seq_len`.
|
||||||
|
"""
|
||||||
|
tokens = self._regex_tokenize(text)
|
||||||
|
|
||||||
|
# Gets ids for START, PAD and UNKNOWN tokens.
|
||||||
|
start_id = self._vocab[self.START]
|
||||||
|
pad_id = self._vocab[self.PAD]
|
||||||
|
unknown_id = self._vocab[self.UNKNOWN]
|
||||||
|
|
||||||
|
token_ids = [self._vocab.get(token, unknown_id) for token in tokens]
|
||||||
|
token_ids = [start_id] + token_ids
|
||||||
|
|
||||||
|
if len(token_ids) < self._seq_len:
|
||||||
|
pad_length = self._seq_len - len(token_ids)
|
||||||
|
token_ids = token_ids + pad_length * [pad_id]
|
||||||
|
else:
|
||||||
|
token_ids = token_ids[:self._seq_len]
|
||||||
|
return token_ids
|
||||||
|
|
||||||
|
def preprocess(
|
||||||
|
self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset:
|
||||||
|
"""Preprocesses data into input for an Average Word Embedding model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: Stores (text, label) data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset containing (token IDs, label) data.
|
||||||
|
"""
|
||||||
|
token_ids_list = []
|
||||||
|
labels_list = []
|
||||||
|
for text, label in dataset.gen_tf_dataset():
|
||||||
|
_validate_text_and_label(text, label)
|
||||||
|
token_ids = self._tokenize_and_pad(text.numpy()[0].decode("utf-8"))
|
||||||
|
token_ids_list.append(token_ids)
|
||||||
|
labels_list.append(label.numpy()[0])
|
||||||
|
|
||||||
|
token_ids_ds = tf.data.Dataset.from_tensor_slices(token_ids_list)
|
||||||
|
labels_ds = tf.data.Dataset.from_tensor_slices(labels_list)
|
||||||
|
preprocessed_ds = tf.data.Dataset.zip((token_ids_ds, labels_ds))
|
||||||
|
return text_classifier_ds.Dataset(
|
||||||
|
dataset=preprocessed_ds,
|
||||||
|
size=dataset.size,
|
||||||
|
label_names=dataset.label_names)
|
||||||
|
|
||||||
|
|
||||||
|
class BertClassifierPreprocessor:
|
||||||
|
"""Preprocessor for a BERT-based classifier.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
seq_len: Length of the input sequence to the model.
|
||||||
|
vocab_file: File containing the BERT vocab.
|
||||||
|
tokenizer: BERT tokenizer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, seq_len: int, do_lower_case: bool, uri: str):
|
||||||
|
self._seq_len = seq_len
|
||||||
|
# Vocab filepath is tied to the BERT module's URI.
|
||||||
|
self._vocab_file = os.path.join(
|
||||||
|
tensorflow_hub.resolve(uri), "assets", "vocab.txt")
|
||||||
|
self._tokenizer = tokenization.FullTokenizer(self._vocab_file,
|
||||||
|
do_lower_case)
|
||||||
|
|
||||||
|
def _get_name_to_features(self):
|
||||||
|
"""Gets the dictionary mapping record keys to feature types."""
|
||||||
|
return {
|
||||||
|
"input_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
|
||||||
|
"input_mask": tf.io.FixedLenFeature([self._seq_len], tf.int64),
|
||||||
|
"segment_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
|
||||||
|
"label_ids": tf.io.FixedLenFeature([], tf.int64),
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_vocab_file(self) -> str:
|
||||||
|
"""Returns the vocab file of the BertClassifierPreprocessor."""
|
||||||
|
return self._vocab_file
|
||||||
|
|
||||||
|
def preprocess(
|
||||||
|
self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset:
|
||||||
|
"""Preprocesses data into input for a BERT-based classifier.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: Stores (text, label) data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset containing (bert_features, label) data.
|
||||||
|
"""
|
||||||
|
examples = []
|
||||||
|
for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
|
||||||
|
_validate_text_and_label(text, label)
|
||||||
|
examples.append(
|
||||||
|
classifier_data_lib.InputExample(
|
||||||
|
guid=str(index),
|
||||||
|
text_a=text.numpy()[0].decode("utf-8"),
|
||||||
|
text_b=None,
|
||||||
|
# InputExample expects the label name rather than the int ID
|
||||||
|
label=dataset.label_names[label.numpy()[0]]))
|
||||||
|
|
||||||
|
tfrecord_file = os.path.join(tempfile.mkdtemp(), "bert_features.tfrecord")
|
||||||
|
classifier_data_lib.file_based_convert_examples_to_features(
|
||||||
|
examples=examples,
|
||||||
|
label_list=dataset.label_names,
|
||||||
|
max_seq_length=self._seq_len,
|
||||||
|
tokenizer=self._tokenizer,
|
||||||
|
output_file=tfrecord_file)
|
||||||
|
preprocessed_ds = _single_file_dataset(tfrecord_file,
|
||||||
|
self._get_name_to_features())
|
||||||
|
return text_classifier_ds.Dataset(
|
||||||
|
dataset=preprocessed_ds,
|
||||||
|
size=dataset.size,
|
||||||
|
label_names=dataset.label_names)
|
||||||
|
|
||||||
|
|
||||||
|
TextClassifierPreprocessor = (
|
||||||
|
Union[BertClassifierPreprocessor,
|
||||||
|
AverageWordEmbeddingClassifierPreprocessor])
|
|
@ -0,0 +1,96 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import numpy.testing as npt
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
||||||
|
|
||||||
|
|
||||||
|
class PreprocessorTest(tf.test.TestCase):
|
||||||
|
CSV_PARAMS_ = text_classifier_ds.CSVParameters(
|
||||||
|
text_column='text', label_column='label')
|
||||||
|
|
||||||
|
def _get_csv_file(self):
|
||||||
|
labels_and_text = (('pos', 'super super super super good'),
|
||||||
|
(('neg', 'really bad')))
|
||||||
|
csv_file = os.path.join(self.get_temp_dir(), 'data.csv')
|
||||||
|
if os.path.exists(csv_file):
|
||||||
|
return csv_file
|
||||||
|
fieldnames = ['text', 'label']
|
||||||
|
with open(csv_file, 'w') as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||||
|
writer.writeheader()
|
||||||
|
for label, text in labels_and_text:
|
||||||
|
writer.writerow({'text': text, 'label': label})
|
||||||
|
return csv_file
|
||||||
|
|
||||||
|
def test_average_word_embedding_preprocessor(self):
|
||||||
|
csv_file = self._get_csv_file()
|
||||||
|
dataset = text_classifier_ds.Dataset.from_csv(
|
||||||
|
filename=csv_file, csv_params=self.CSV_PARAMS_)
|
||||||
|
average_word_embedding_preprocessor = (
|
||||||
|
preprocessor.AverageWordEmbeddingClassifierPreprocessor(
|
||||||
|
seq_len=5,
|
||||||
|
do_lower_case=True,
|
||||||
|
texts=['super super super super good', 'really bad'],
|
||||||
|
vocab_size=7))
|
||||||
|
preprocessed_dataset = (
|
||||||
|
average_word_embedding_preprocessor.preprocess(dataset))
|
||||||
|
labels = []
|
||||||
|
features_list = []
|
||||||
|
for features, label in preprocessed_dataset.gen_tf_dataset():
|
||||||
|
self.assertEqual(label.shape, [1])
|
||||||
|
labels.append(label.numpy()[0])
|
||||||
|
self.assertEqual(features.shape, [1, 5])
|
||||||
|
features_list.append(features.numpy()[0])
|
||||||
|
self.assertEqual(labels, [1, 0])
|
||||||
|
npt.assert_array_equal(
|
||||||
|
np.stack(features_list), np.array([[1, 3, 3, 3, 3], [1, 5, 6, 0, 0]]))
|
||||||
|
|
||||||
|
def test_bert_preprocessor(self):
|
||||||
|
csv_file = self._get_csv_file()
|
||||||
|
dataset = text_classifier_ds.Dataset.from_csv(
|
||||||
|
filename=csv_file, csv_params=self.CSV_PARAMS_)
|
||||||
|
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||||
|
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||||
|
seq_len=5, do_lower_case=bert_spec.do_lower_case, uri=bert_spec.uri)
|
||||||
|
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
|
||||||
|
labels = []
|
||||||
|
input_masks = []
|
||||||
|
for features, label in preprocessed_dataset.gen_tf_dataset():
|
||||||
|
self.assertEqual(label.shape, [1])
|
||||||
|
labels.append(label.numpy()[0])
|
||||||
|
self.assertSameElements(
|
||||||
|
features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids'])
|
||||||
|
for feature in features.values():
|
||||||
|
self.assertEqual(feature.shape, [1, 5])
|
||||||
|
input_masks.append(features['input_mask'].numpy()[0])
|
||||||
|
npt.assert_array_equal(features['input_type_ids'].numpy()[0],
|
||||||
|
[0, 0, 0, 0, 0])
|
||||||
|
npt.assert_array_equal(
|
||||||
|
np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]))
|
||||||
|
self.assertEqual(labels, [1, 0])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# Load compressed models from tensorflow_hub
|
||||||
|
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
|
||||||
|
tf.test.main()
|
23
mediapipe/model_maker/python/text/text_classifier/testdata/BUILD
vendored
Normal file
23
mediapipe/model_maker/python/text/text_classifier/testdata/BUILD
vendored
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//mediapipe/model_maker/python/text/text_classifier:__subpackages__"],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "testdata",
|
||||||
|
srcs = ["average_word_embedding_metadata.json"],
|
||||||
|
)
|
63
mediapipe/model_maker/python/text/text_classifier/testdata/average_word_embedding_metadata.json
vendored
Normal file
63
mediapipe/model_maker/python/text/text_classifier/testdata/average_word_embedding_metadata.json
vendored
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
{
|
||||||
|
"name": "TextClassifier",
|
||||||
|
"description": "Classify the input text into a set of known categories.",
|
||||||
|
"subgraph_metadata": [
|
||||||
|
{
|
||||||
|
"input_tensor_metadata": [
|
||||||
|
{
|
||||||
|
"name": "input_text",
|
||||||
|
"description": "Embedding vectors representing the input text to be processed.",
|
||||||
|
"content": {
|
||||||
|
"content_properties_type": "FeatureProperties",
|
||||||
|
"content_properties": {
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"process_units": [
|
||||||
|
{
|
||||||
|
"options_type": "RegexTokenizerOptions",
|
||||||
|
"options": {
|
||||||
|
"delim_regex_pattern": "[^\\w\\']+",
|
||||||
|
"vocab_file": [
|
||||||
|
{
|
||||||
|
"name": "vocab.txt",
|
||||||
|
"description": "Vocabulary file to convert natural language words to embedding vectors.",
|
||||||
|
"type": "VOCABULARY"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stats": {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"output_tensor_metadata": [
|
||||||
|
{
|
||||||
|
"name": "score",
|
||||||
|
"description": "Score of the labels respectively.",
|
||||||
|
"content": {
|
||||||
|
"content_properties_type": "FeatureProperties",
|
||||||
|
"content_properties": {
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"stats": {
|
||||||
|
"max": [
|
||||||
|
1.0
|
||||||
|
],
|
||||||
|
"min": [
|
||||||
|
0.0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"associated_files": [
|
||||||
|
{
|
||||||
|
"name": "labels.txt",
|
||||||
|
"description": "Labels for categories that the model can recognize.",
|
||||||
|
"type": "TENSOR_AXIS_LABELS"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"min_parser_version": "1.2.1"
|
||||||
|
}
|
|
@ -0,0 +1,437 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""API for text classification."""
|
||||||
|
|
||||||
|
import abc
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from typing import Any, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow_hub as hub
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.core.data import dataset as ds
|
||||||
|
from mediapipe.model_maker.python.core.tasks import classifier
|
||||||
|
from mediapipe.model_maker.python.core.utils import model_util
|
||||||
|
from mediapipe.model_maker.python.core.utils import quantization
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
||||||
|
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||||
|
from mediapipe.tasks.python.metadata.metadata_writers import text_classifier as text_classifier_writer
|
||||||
|
from official.nlp import optimization
|
||||||
|
|
||||||
|
|
||||||
|
def _validate(options: text_classifier_options.TextClassifierOptions):
|
||||||
|
"""Validates that `model_options` and `supported_model` are compatible.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
options: Options for creating and training a text classifier.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError if there is a mismatch between `model_options` and
|
||||||
|
`supported_model`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if options.model_options is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if (isinstance(options.model_options,
|
||||||
|
mo.AverageWordEmbeddingClassifierOptions) and
|
||||||
|
(options.supported_model !=
|
||||||
|
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
|
||||||
|
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
|
||||||
|
f" got {options.supported_model}")
|
||||||
|
if (isinstance(options.model_options, mo.BertClassifierOptions) and
|
||||||
|
(options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
|
||||||
|
|
||||||
|
|
||||||
|
class TextClassifier(classifier.Classifier):
|
||||||
|
"""API for creating and training a text classification model."""
|
||||||
|
|
||||||
|
def __init__(self, model_spec: Any, hparams: hp.BaseHParams,
|
||||||
|
label_names: Sequence[str]):
|
||||||
|
super().__init__(
|
||||||
|
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
|
||||||
|
self._model_spec = model_spec
|
||||||
|
self._hparams = hparams
|
||||||
|
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
|
||||||
|
self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
|
||||||
|
options: text_classifier_options.TextClassifierOptions
|
||||||
|
) -> "TextClassifier":
|
||||||
|
"""Factory function that creates and trains a text classifier.
|
||||||
|
|
||||||
|
Note that `train_data` and `validation_data` are expected to share the same
|
||||||
|
`label_names` since they should be split from the same dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
validation_data: Validation data.
|
||||||
|
options: Options for creating and training the text classifier.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A text classifier.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError if `train_data` and `validation_data` do not have the
|
||||||
|
same label_names or `options` contains an unknown `supported_model`
|
||||||
|
"""
|
||||||
|
if train_data.label_names != validation_data.label_names:
|
||||||
|
raise ValueError(
|
||||||
|
f"Training data label names {train_data.label_names} not equal to "
|
||||||
|
f"validation data label names {validation_data.label_names}")
|
||||||
|
|
||||||
|
_validate(options)
|
||||||
|
if options.model_options is None:
|
||||||
|
options.model_options = options.supported_model.value().model_options
|
||||||
|
|
||||||
|
if options.hparams is None:
|
||||||
|
options.hparams = options.supported_model.value().hparams
|
||||||
|
|
||||||
|
if options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
|
||||||
|
text_classifier = (
|
||||||
|
_BertClassifier.create_bert_classifier(train_data, validation_data,
|
||||||
|
options,
|
||||||
|
train_data.label_names))
|
||||||
|
elif (options.supported_model ==
|
||||||
|
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
|
||||||
|
text_classifier = (
|
||||||
|
_AverageWordEmbeddingClassifier
|
||||||
|
.create_average_word_embedding_classifier(train_data, validation_data,
|
||||||
|
options,
|
||||||
|
train_data.label_names))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown model {options.supported_model}")
|
||||||
|
|
||||||
|
return text_classifier
|
||||||
|
|
||||||
|
def evaluate(self, data: ds.Dataset, batch_size: int = 32) -> Any:
|
||||||
|
"""Overrides Classifier.evaluate().
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Evaluation dataset. Must be a TextClassifier Dataset.
|
||||||
|
batch_size: Number of samples per evaluation step.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The loss value and accuracy.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError if `data` is not a TextClassifier Dataset.
|
||||||
|
"""
|
||||||
|
# This override is needed because TextClassifier preprocesses its data
|
||||||
|
# outside of the `gen_tf_dataset()` method. The preprocess call also
|
||||||
|
# requires a TextClassifier Dataset instead of a core Dataset.
|
||||||
|
if not isinstance(data, text_ds.Dataset):
|
||||||
|
raise ValueError("Need a TextClassifier Dataset.")
|
||||||
|
|
||||||
|
processed_data = self._text_preprocessor.preprocess(data)
|
||||||
|
dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)
|
||||||
|
return self._model.evaluate(dataset)
|
||||||
|
|
||||||
|
def export_model(
|
||||||
|
self,
|
||||||
|
model_name: str = "model.tflite",
|
||||||
|
quantization_config: Optional[quantization.QuantizationConfig] = None):
|
||||||
|
"""Converts and saves the model to a TFLite file with metadata included.
|
||||||
|
|
||||||
|
Note that only the TFLite file is needed for deployment. This function also
|
||||||
|
saves a metadata.json file to the same directory as the TFLite file which
|
||||||
|
can be used to interpret the metadata content in the TFLite file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: File name to save TFLite model with metadata. The full export
|
||||||
|
path is {self._hparams.export_dir}/{model_name}.
|
||||||
|
quantization_config: The configuration for model quantization.
|
||||||
|
"""
|
||||||
|
if not tf.io.gfile.exists(self._hparams.export_dir):
|
||||||
|
tf.io.gfile.makedirs(self._hparams.export_dir)
|
||||||
|
tflite_file = os.path.join(self._hparams.export_dir, model_name)
|
||||||
|
metadata_file = os.path.join(self._hparams.export_dir, "metadata.json")
|
||||||
|
|
||||||
|
tflite_model = model_util.convert_to_tflite(
|
||||||
|
model=self._model, quantization_config=quantization_config)
|
||||||
|
vocab_filepath = os.path.join(tempfile.mkdtemp(), "vocab.txt")
|
||||||
|
self._save_vocab(vocab_filepath)
|
||||||
|
|
||||||
|
writer = self._get_metadata_writer(tflite_model, vocab_filepath)
|
||||||
|
tflite_model_with_metadata, metadata_json = writer.populate()
|
||||||
|
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
|
||||||
|
with open(metadata_file, "w") as f:
|
||||||
|
f.write(metadata_json)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _save_vocab(self, vocab_filepath: str):
|
||||||
|
"""Saves the preprocessor's vocab to `vocab_filepath`."""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str):
|
||||||
|
"""Gets the metadata writer for the text classifier TFLite model."""
|
||||||
|
|
||||||
|
|
||||||
|
class _AverageWordEmbeddingClassifier(TextClassifier):
|
||||||
|
"""APIs to help create and train an Average Word Embedding text classifier."""
|
||||||
|
|
||||||
|
_DELIM_REGEX_PATTERN = r"[^\w\']+"
|
||||||
|
|
||||||
|
def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
|
||||||
|
model_options: mo.AverageWordEmbeddingClassifierOptions,
|
||||||
|
hparams: hp.BaseHParams, label_names: Sequence[str]):
|
||||||
|
super().__init__(model_spec, hparams, label_names)
|
||||||
|
self._model_options = model_options
|
||||||
|
self._loss_function = "sparse_categorical_crossentropy"
|
||||||
|
self._metric_function = "accuracy"
|
||||||
|
self._text_preprocessor: (
|
||||||
|
preprocessor.AverageWordEmbeddingClassifierPreprocessor) = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_average_word_embedding_classifier(
|
||||||
|
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
|
||||||
|
options: text_classifier_options.TextClassifierOptions,
|
||||||
|
label_names: Sequence[str]) -> "_AverageWordEmbeddingClassifier":
|
||||||
|
"""Creates, trains, and returns an Average Word Embedding classifier.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
validation_data: Validation data.
|
||||||
|
options: Options for creating and training the text classifier.
|
||||||
|
label_names: Label names used in the data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An Average Word Embedding classifier.
|
||||||
|
"""
|
||||||
|
average_word_embedding_classifier = _AverageWordEmbeddingClassifier(
|
||||||
|
model_spec=options.supported_model.value(),
|
||||||
|
model_options=options.model_options,
|
||||||
|
hparams=options.hparams,
|
||||||
|
label_names=train_data.label_names)
|
||||||
|
average_word_embedding_classifier._create_and_train_model(
|
||||||
|
train_data, validation_data)
|
||||||
|
return average_word_embedding_classifier
|
||||||
|
|
||||||
|
def _create_and_train_model(self, train_data: text_ds.Dataset,
|
||||||
|
validation_data: text_ds.Dataset):
|
||||||
|
"""Creates the Average Word Embedding classifier keras model and trains it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
validation_data: Validation data.
|
||||||
|
"""
|
||||||
|
(processed_train_data, processed_validation_data) = (
|
||||||
|
self._load_and_run_preprocessor(train_data, validation_data))
|
||||||
|
self._create_model()
|
||||||
|
self._optimizer = "rmsprop"
|
||||||
|
self._train_model(processed_train_data, processed_validation_data)
|
||||||
|
|
||||||
|
def _load_and_run_preprocessor(
|
||||||
|
self, train_data: text_ds.Dataset, validation_data: text_ds.Dataset
|
||||||
|
) -> Tuple[text_ds.Dataset, text_ds.Dataset]:
|
||||||
|
"""Runs an AverageWordEmbeddingClassifierPreprocessor on the data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
validation_data: Validation data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Preprocessed training data and preprocessed validation data.
|
||||||
|
"""
|
||||||
|
train_texts = [text.numpy()[0] for text, _ in train_data.gen_tf_dataset()]
|
||||||
|
validation_texts = [
|
||||||
|
text.numpy()[0] for text, _ in validation_data.gen_tf_dataset()
|
||||||
|
]
|
||||||
|
self._text_preprocessor = (
|
||||||
|
preprocessor.AverageWordEmbeddingClassifierPreprocessor(
|
||||||
|
seq_len=self._model_options.seq_len,
|
||||||
|
do_lower_case=self._model_options.do_lower_case,
|
||||||
|
texts=train_texts + validation_texts,
|
||||||
|
vocab_size=self._model_options.vocab_size))
|
||||||
|
return self._text_preprocessor.preprocess(
|
||||||
|
train_data), self._text_preprocessor.preprocess(validation_data)
|
||||||
|
|
||||||
|
def _create_model(self):
|
||||||
|
"""Creates an Average Word Embedding model."""
|
||||||
|
self._model = tf.keras.Sequential([
|
||||||
|
tf.keras.layers.InputLayer(
|
||||||
|
input_shape=[self._model_options.seq_len], dtype=tf.int32),
|
||||||
|
tf.keras.layers.Embedding(
|
||||||
|
len(self._text_preprocessor.get_vocab()),
|
||||||
|
self._model_options.wordvec_dim,
|
||||||
|
input_length=self._model_options.seq_len),
|
||||||
|
tf.keras.layers.GlobalAveragePooling1D(),
|
||||||
|
tf.keras.layers.Dense(
|
||||||
|
self._model_options.wordvec_dim, activation=tf.nn.relu),
|
||||||
|
tf.keras.layers.Dropout(self._model_options.dropout_rate),
|
||||||
|
tf.keras.layers.Dense(self._num_classes, activation="softmax")
|
||||||
|
])
|
||||||
|
|
||||||
|
def _save_vocab(self, vocab_filepath: str):
|
||||||
|
with tf.io.gfile.GFile(vocab_filepath, "w") as f:
|
||||||
|
for token, index in self._text_preprocessor.get_vocab().items():
|
||||||
|
f.write(f"{token} {index}\n")
|
||||||
|
|
||||||
|
def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str):
|
||||||
|
return text_classifier_writer.MetadataWriter.create_for_regex_model(
|
||||||
|
model_buffer=tflite_model,
|
||||||
|
regex_tokenizer=metadata_writer.RegexTokenizer(
|
||||||
|
# TODO: Align with MediaPipe's RegexTokenizer.
|
||||||
|
delim_regex_pattern=self._DELIM_REGEX_PATTERN,
|
||||||
|
vocab_file_path=vocab_filepath),
|
||||||
|
labels=metadata_writer.Labels().add(list(self._label_names)))
|
||||||
|
|
||||||
|
|
||||||
|
class _BertClassifier(TextClassifier):
|
||||||
|
"""APIs to help create and train a BERT-based text classifier."""
|
||||||
|
|
||||||
|
_INITIALIZER_RANGE = 0.02
|
||||||
|
|
||||||
|
def __init__(self, model_spec: ms.BertClassifierSpec,
|
||||||
|
model_options: mo.BertClassifierOptions, hparams: hp.BaseHParams,
|
||||||
|
label_names: Sequence[str]):
|
||||||
|
super().__init__(model_spec, hparams, label_names)
|
||||||
|
self._model_options = model_options
|
||||||
|
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||||
|
self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy(
|
||||||
|
"test_accuracy", dtype=tf.float32)
|
||||||
|
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_bert_classifier(
|
||||||
|
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
|
||||||
|
options: text_classifier_options.TextClassifierOptions,
|
||||||
|
label_names: Sequence[str]) -> "_BertClassifier":
|
||||||
|
"""Creates, trains, and returns a BERT-based classifier.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
validation_data: Validation data.
|
||||||
|
options: Options for creating and training the text classifier.
|
||||||
|
label_names: Label names used in the data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A BERT-based classifier.
|
||||||
|
"""
|
||||||
|
bert_classifier = _BertClassifier(
|
||||||
|
model_spec=options.supported_model.value(),
|
||||||
|
model_options=options.model_options,
|
||||||
|
hparams=options.hparams,
|
||||||
|
label_names=train_data.label_names)
|
||||||
|
bert_classifier._create_and_train_model(train_data, validation_data)
|
||||||
|
return bert_classifier
|
||||||
|
|
||||||
|
def _create_and_train_model(self, train_data: text_ds.Dataset,
|
||||||
|
validation_data: text_ds.Dataset):
|
||||||
|
"""Creates the BERT-based classifier keras model and trains it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
validation_data: Validation data.
|
||||||
|
"""
|
||||||
|
(processed_train_data, processed_validation_data) = (
|
||||||
|
self._load_and_run_preprocessor(train_data, validation_data))
|
||||||
|
self._create_model()
|
||||||
|
self._create_optimizer(processed_train_data)
|
||||||
|
self._train_model(processed_train_data, processed_validation_data)
|
||||||
|
|
||||||
|
def _load_and_run_preprocessor(
|
||||||
|
self, train_data: text_ds.Dataset, validation_data: text_ds.Dataset
|
||||||
|
) -> Tuple[text_ds.Dataset, text_ds.Dataset]:
|
||||||
|
"""Loads a BertClassifierPreprocessor and runs it on the data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
validation_data: Validation data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Preprocessed training data and preprocessed validation data.
|
||||||
|
"""
|
||||||
|
self._text_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||||
|
seq_len=self._model_options.seq_len,
|
||||||
|
do_lower_case=self._model_spec.do_lower_case,
|
||||||
|
uri=self._model_spec.uri)
|
||||||
|
return (self._text_preprocessor.preprocess(train_data),
|
||||||
|
self._text_preprocessor.preprocess(validation_data))
|
||||||
|
|
||||||
|
def _create_model(self):
|
||||||
|
"""Creates a BERT-based classifier model.
|
||||||
|
|
||||||
|
The model architecture consists of stacking a dense classification layer and
|
||||||
|
dropout layer on top of the BERT encoder outputs.
|
||||||
|
"""
|
||||||
|
encoder_inputs = dict(
|
||||||
|
input_word_ids=tf.keras.layers.Input(
|
||||||
|
shape=(self._model_options.seq_len,), dtype=tf.int32),
|
||||||
|
input_mask=tf.keras.layers.Input(
|
||||||
|
shape=(self._model_options.seq_len,), dtype=tf.int32),
|
||||||
|
input_type_ids=tf.keras.layers.Input(
|
||||||
|
shape=(self._model_options.seq_len,), dtype=tf.int32),
|
||||||
|
)
|
||||||
|
encoder = hub.KerasLayer(
|
||||||
|
self._model_spec.uri, trainable=self._model_options.do_fine_tuning)
|
||||||
|
encoder_outputs = encoder(encoder_inputs)
|
||||||
|
pooled_output = encoder_outputs["pooled_output"]
|
||||||
|
|
||||||
|
output = tf.keras.layers.Dropout(rate=self._model_options.dropout_rate)(
|
||||||
|
pooled_output)
|
||||||
|
initializer = tf.keras.initializers.TruncatedNormal(
|
||||||
|
stddev=self._INITIALIZER_RANGE)
|
||||||
|
output = tf.keras.layers.Dense(
|
||||||
|
self._num_classes,
|
||||||
|
kernel_initializer=initializer,
|
||||||
|
name="output",
|
||||||
|
activation="softmax",
|
||||||
|
dtype=tf.float32)(
|
||||||
|
output)
|
||||||
|
self._model = tf.keras.Model(inputs=encoder_inputs, outputs=output)
|
||||||
|
|
||||||
|
def _create_optimizer(self, train_data: text_ds.Dataset):
|
||||||
|
"""Loads an optimizer with a learning rate schedule.
|
||||||
|
|
||||||
|
The decay steps in the learning rate schedule depend on the
|
||||||
|
`steps_per_epoch` which may depend on the size of the training data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
"""
|
||||||
|
self._hparams.steps_per_epoch = model_util.get_steps_per_epoch(
|
||||||
|
steps_per_epoch=self._hparams.steps_per_epoch,
|
||||||
|
batch_size=self._hparams.batch_size,
|
||||||
|
train_data=train_data)
|
||||||
|
total_steps = self._hparams.steps_per_epoch * self._hparams.epochs
|
||||||
|
warmup_steps = int(total_steps * 0.1)
|
||||||
|
initial_lr = self._hparams.learning_rate
|
||||||
|
self._optimizer = optimization.create_optimizer(initial_lr, total_steps,
|
||||||
|
warmup_steps)
|
||||||
|
|
||||||
|
def _save_vocab(self, vocab_filepath: str):
|
||||||
|
tf.io.gfile.copy(
|
||||||
|
self._text_preprocessor.get_vocab_file(),
|
||||||
|
vocab_filepath,
|
||||||
|
overwrite=True)
|
||||||
|
|
||||||
|
def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str):
|
||||||
|
return text_classifier_writer.MetadataWriter.create_for_bert_model(
|
||||||
|
model_buffer=tflite_model,
|
||||||
|
tokenizer=metadata_writer.BertTokenizer(vocab_filepath),
|
||||||
|
labels=metadata_writer.Labels().add(list(self._label_names)),
|
||||||
|
ids_name=self._model_spec.tflite_input_name["ids"],
|
||||||
|
mask_name=self._model_spec.tflite_input_name["mask"],
|
||||||
|
segment_name=self._model_spec.tflite_input_name["segment_ids"])
|
|
@ -0,0 +1,108 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Demo for making a text classifier model by MediaPipe Model Maker."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
# Dependency imports
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
from absl import logging
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.core.utils import quantization
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import text_classifier
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
|
||||||
|
def define_flags():
|
||||||
|
flags.DEFINE_string('export_dir', None,
|
||||||
|
'The directory to save exported files.')
|
||||||
|
flags.DEFINE_enum('supported_model', 'average_word_embedding',
|
||||||
|
['average_word_embedding', 'bert'],
|
||||||
|
'The text classifier to run.')
|
||||||
|
flags.mark_flag_as_required('export_dir')
|
||||||
|
|
||||||
|
|
||||||
|
def download_demo_data():
|
||||||
|
"""Downloads demo data, and returns directory path."""
|
||||||
|
data_path = tf.keras.utils.get_file(
|
||||||
|
fname='SST-2.zip',
|
||||||
|
origin='https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',
|
||||||
|
extract=True)
|
||||||
|
return os.path.join(os.path.dirname(data_path), 'SST-2') # folder name
|
||||||
|
|
||||||
|
|
||||||
|
def run(data_dir,
|
||||||
|
export_dir=tempfile.mkdtemp(),
|
||||||
|
supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
|
||||||
|
"""Runs demo."""
|
||||||
|
|
||||||
|
# Gets training data and validation data.
|
||||||
|
csv_params = text_ds.CSVParameters(
|
||||||
|
text_column='sentence', label_column='label', delimiter='\t')
|
||||||
|
train_data = text_ds.Dataset.from_csv(
|
||||||
|
filename=os.path.join(os.path.join(data_dir, 'train.tsv')),
|
||||||
|
csv_params=csv_params)
|
||||||
|
validation_data = text_ds.Dataset.from_csv(
|
||||||
|
filename=os.path.join(os.path.join(data_dir, 'dev.tsv')),
|
||||||
|
csv_params=csv_params)
|
||||||
|
|
||||||
|
quantization_config = None
|
||||||
|
if supported_model == ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER:
|
||||||
|
hparams = hp.BaseHParams(
|
||||||
|
epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir)
|
||||||
|
# Warning: This takes extremely long to run on CPU
|
||||||
|
elif supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
|
||||||
|
quantization_config = quantization.QuantizationConfig.for_dynamic()
|
||||||
|
hparams = hp.BaseHParams(
|
||||||
|
epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir)
|
||||||
|
|
||||||
|
# Fine-tunes the model.
|
||||||
|
options = text_classifier_options.TextClassifierOptions(
|
||||||
|
supported_model=supported_model, hparams=hparams)
|
||||||
|
model = text_classifier.TextClassifier.create(train_data, validation_data,
|
||||||
|
options)
|
||||||
|
|
||||||
|
# Gets evaluation results.
|
||||||
|
_, acc = model.evaluate(validation_data)
|
||||||
|
print('Eval accuracy: %f' % acc)
|
||||||
|
|
||||||
|
model.export_model(quantization_config=quantization_config)
|
||||||
|
model.export_labels(export_dir=options.hparams.export_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def main(_):
|
||||||
|
logging.set_verbosity(logging.INFO)
|
||||||
|
data_dir = download_demo_data()
|
||||||
|
export_dir = os.path.expanduser(FLAGS.export_dir)
|
||||||
|
|
||||||
|
if FLAGS.supported_model == 'average_word_embedding':
|
||||||
|
supported_model = ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
|
||||||
|
elif FLAGS.supported_model == 'bert':
|
||||||
|
supported_model = ms.SupportedModels.MOBILEBERT_CLASSIFIER
|
||||||
|
|
||||||
|
run(data_dir, export_dir, supported_model)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
define_flags()
|
||||||
|
app.run(main)
|
|
@ -0,0 +1,38 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""User-facing customization options to create and train a text classifier."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class TextClassifierOptions:
|
||||||
|
"""User-facing options for creating the text classifier.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
supported_model: A preconfigured model spec.
|
||||||
|
hparams: Training hyperparameters the user can set to override the ones in
|
||||||
|
`supported_model`.
|
||||||
|
model_options: Model options the user can set to override the ones in
|
||||||
|
`supported_model`. The model options type should be consistent with the
|
||||||
|
architecture of the `supported_model`.
|
||||||
|
"""
|
||||||
|
supported_model: ms.SupportedModels
|
||||||
|
hparams: Optional[hp.BaseHParams] = None
|
||||||
|
model_options: Optional[mo.TextClassifierModelOptions] = None
|
|
@ -0,0 +1,138 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import filecmp
|
||||||
|
import os
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import dataset
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import text_classifier
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
||||||
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
|
|
||||||
|
class TextClassifierTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
_AVERAGE_WORD_EMBEDDING_JSON_FILE = (
|
||||||
|
test_utils.get_test_data_path('average_word_embedding_metadata.json'))
|
||||||
|
|
||||||
|
def _get_data(self):
|
||||||
|
labels_and_text = (('pos', 'super good'), (('neg', 'really bad')))
|
||||||
|
csv_file = os.path.join(self.get_temp_dir(), 'data.csv')
|
||||||
|
if os.path.exists(csv_file):
|
||||||
|
return csv_file
|
||||||
|
fieldnames = ['text', 'label']
|
||||||
|
with open(csv_file, 'w') as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||||
|
writer.writeheader()
|
||||||
|
for label, text in labels_and_text:
|
||||||
|
writer.writerow({'text': text, 'label': label})
|
||||||
|
csv_params = dataset.CSVParameters(text_column='text', label_column='label')
|
||||||
|
all_data = dataset.Dataset.from_csv(
|
||||||
|
filename=csv_file, csv_params=csv_params)
|
||||||
|
return all_data.split(0.5)
|
||||||
|
|
||||||
|
def test_create_and_train_average_word_embedding_model(self):
|
||||||
|
train_data, validation_data = self._get_data()
|
||||||
|
options = text_classifier_options.TextClassifierOptions(
|
||||||
|
supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER,
|
||||||
|
hparams=hp.BaseHParams(epochs=1, batch_size=1, learning_rate=0))
|
||||||
|
average_word_embedding_classifier = text_classifier.TextClassifier.create(
|
||||||
|
train_data, validation_data, options)
|
||||||
|
|
||||||
|
_, accuracy = average_word_embedding_classifier.evaluate(validation_data)
|
||||||
|
self.assertGreaterEqual(accuracy, 0.0)
|
||||||
|
|
||||||
|
# Test export_model
|
||||||
|
average_word_embedding_classifier.export_model()
|
||||||
|
output_metadata_file = os.path.join(options.hparams.export_dir,
|
||||||
|
'metadata.json')
|
||||||
|
output_tflite_file = os.path.join(options.hparams.export_dir,
|
||||||
|
'model.tflite')
|
||||||
|
|
||||||
|
self.assertTrue(os.path.exists(output_tflite_file))
|
||||||
|
self.assertGreater(os.path.getsize(output_tflite_file), 0)
|
||||||
|
|
||||||
|
self.assertTrue(os.path.exists(output_metadata_file))
|
||||||
|
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
||||||
|
self.assertTrue(
|
||||||
|
filecmp.cmp(output_metadata_file,
|
||||||
|
self._AVERAGE_WORD_EMBEDDING_JSON_FILE))
|
||||||
|
|
||||||
|
def test_create_and_train_bert(self):
|
||||||
|
train_data, validation_data = self._get_data()
|
||||||
|
options = text_classifier_options.TextClassifierOptions(
|
||||||
|
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER,
|
||||||
|
model_options=mo.BertClassifierOptions(do_fine_tuning=False, seq_len=2),
|
||||||
|
hparams=hp.BaseHParams(
|
||||||
|
epochs=1,
|
||||||
|
batch_size=1,
|
||||||
|
learning_rate=3e-5,
|
||||||
|
distribution_strategy='off'))
|
||||||
|
bert_classifier = text_classifier.TextClassifier.create(
|
||||||
|
train_data, validation_data, options)
|
||||||
|
|
||||||
|
_, accuracy = bert_classifier.evaluate(validation_data)
|
||||||
|
self.assertGreaterEqual(accuracy, 0.0)
|
||||||
|
# TODO: Add a unit test that does not run OOM.
|
||||||
|
|
||||||
|
def test_label_mismatch(self):
|
||||||
|
options = (
|
||||||
|
text_classifier_options.TextClassifierOptions(
|
||||||
|
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER))
|
||||||
|
train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
|
||||||
|
train_data = dataset.Dataset(train_tf_dataset, 1, ['foo'])
|
||||||
|
validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
|
||||||
|
validation_data = dataset.Dataset(validation_tf_dataset, 1, ['bar'])
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError,
|
||||||
|
'Training data label names .* not equal to validation data label names'
|
||||||
|
):
|
||||||
|
text_classifier.TextClassifier.create(train_data, validation_data,
|
||||||
|
options)
|
||||||
|
|
||||||
|
def test_options_mismatch(self):
|
||||||
|
train_data, validation_data = self._get_data()
|
||||||
|
|
||||||
|
avg_options = (
|
||||||
|
text_classifier_options.TextClassifierOptions(
|
||||||
|
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER,
|
||||||
|
model_options=mo.AverageWordEmbeddingClassifierOptions()))
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
|
||||||
|
' SupportedModels.MOBILEBERT_CLASSIFIER'):
|
||||||
|
text_classifier.TextClassifier.create(train_data, validation_data,
|
||||||
|
avg_options)
|
||||||
|
|
||||||
|
bert_options = (
|
||||||
|
text_classifier_options.TextClassifierOptions(
|
||||||
|
supported_model=(
|
||||||
|
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
|
||||||
|
model_options=mo.BertClassifierOptions()))
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
|
||||||
|
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):
|
||||||
|
text_classifier.TextClassifier.create(train_data, validation_data,
|
||||||
|
bert_options)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# Load compressed models from tensorflow_hub
|
||||||
|
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
|
||||||
|
tf.test.main()
|
Loading…
Reference in New Issue
Block a user