From 3e05871f980a7d78c9dc48c308890cbe82e27599 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 10 Nov 2022 19:51:08 -0800 Subject: [PATCH] Open source Model Maker text tasks. PiperOrigin-RevId: 487706929 --- mediapipe/model_maker/python/text/core/BUILD | 35 ++ .../model_maker/python/text/core/__init__.py | 13 + .../python/text/core/bert_model_options.py | 33 ++ .../python/text/core/bert_model_spec.py | 58 +++ .../python/text/text_classifier/BUILD | 146 ++++++ .../python/text/text_classifier/__init__.py | 13 + .../python/text/text_classifier/dataset.py | 88 ++++ .../text/text_classifier/dataset_test.py | 75 +++ .../text/text_classifier/model_options.py | 45 ++ .../python/text/text_classifier/model_spec.py | 70 +++ .../text/text_classifier/model_spec_test.py | 118 +++++ .../text/text_classifier/preprocessor.py | 285 ++++++++++++ .../text/text_classifier/preprocessor_test.py | 96 ++++ .../text/text_classifier/testdata/BUILD | 23 + .../average_word_embedding_metadata.json | 63 +++ .../text/text_classifier/text_classifier.py | 437 ++++++++++++++++++ .../text_classifier/text_classifier_demo.py | 108 +++++ .../text_classifier_options.py | 38 ++ .../text_classifier/text_classifier_test.py | 138 ++++++ 19 files changed, 1882 insertions(+) create mode 100644 mediapipe/model_maker/python/text/core/BUILD create mode 100644 mediapipe/model_maker/python/text/core/__init__.py create mode 100644 mediapipe/model_maker/python/text/core/bert_model_options.py create mode 100644 mediapipe/model_maker/python/text/core/bert_model_spec.py create mode 100644 mediapipe/model_maker/python/text/text_classifier/BUILD create mode 100644 mediapipe/model_maker/python/text/text_classifier/__init__.py create mode 100644 mediapipe/model_maker/python/text/text_classifier/dataset.py create mode 100644 mediapipe/model_maker/python/text/text_classifier/dataset_test.py create mode 100644 mediapipe/model_maker/python/text/text_classifier/model_options.py create mode 100644 mediapipe/model_maker/python/text/text_classifier/model_spec.py create mode 100644 mediapipe/model_maker/python/text/text_classifier/model_spec_test.py create mode 100644 mediapipe/model_maker/python/text/text_classifier/preprocessor.py create mode 100644 mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py create mode 100644 mediapipe/model_maker/python/text/text_classifier/testdata/BUILD create mode 100644 mediapipe/model_maker/python/text/text_classifier/testdata/average_word_embedding_metadata.json create mode 100644 mediapipe/model_maker/python/text/text_classifier/text_classifier.py create mode 100644 mediapipe/model_maker/python/text/text_classifier/text_classifier_demo.py create mode 100644 mediapipe/model_maker/python/text/text_classifier/text_classifier_options.py create mode 100644 mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py diff --git a/mediapipe/model_maker/python/text/core/BUILD b/mediapipe/model_maker/python/text/core/BUILD new file mode 100644 index 000000000..db06c3a75 --- /dev/null +++ b/mediapipe/model_maker/python/text/core/BUILD @@ -0,0 +1,35 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library and test compatibility macro. + +package( + default_visibility = ["//mediapipe:__subpackages__"], +) + +licenses(["notice"]) + +py_library( + name = "bert_model_options", + srcs = ["bert_model_options.py"], +) + +py_library( + name = "bert_model_spec", + srcs = ["bert_model_spec.py"], + deps = [ + ":bert_model_options", + "//mediapipe/model_maker/python/core:hyperparameters", + ], +) diff --git a/mediapipe/model_maker/python/text/core/__init__.py b/mediapipe/model_maker/python/text/core/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/python/text/core/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/model_maker/python/text/core/bert_model_options.py b/mediapipe/model_maker/python/text/core/bert_model_options.py new file mode 100644 index 000000000..ce5ef6af4 --- /dev/null +++ b/mediapipe/model_maker/python/text/core/bert_model_options.py @@ -0,0 +1,33 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Configurable model options for a BERT model.""" + +import dataclasses + + +@dataclasses.dataclass +class BertModelOptions: + """Configurable model options for a BERT model. + + See https://arxiv.org/abs/1810.04805 (BERT: Pre-training of Deep Bidirectional + Transformers for Language Understanding) for more details. + + Attributes: + seq_len: Length of the sequence to feed into the model. + do_fine_tuning: If true, then the BERT model is not frozen for training. + dropout_rate: The rate for dropout. + """ + seq_len: int = 128 + do_fine_tuning: bool = True + dropout_rate: float = 0.1 diff --git a/mediapipe/model_maker/python/text/core/bert_model_spec.py b/mediapipe/model_maker/python/text/core/bert_model_spec.py new file mode 100644 index 000000000..6c0085617 --- /dev/null +++ b/mediapipe/model_maker/python/text/core/bert_model_spec.py @@ -0,0 +1,58 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Specification for a BERT model.""" + +import dataclasses +from typing import Dict + +from mediapipe.model_maker.python.core import hyperparameters as hp +from mediapipe.model_maker.python.text.core import bert_model_options + +_DEFAULT_TFLITE_INPUT_NAME = { + 'ids': 'serving_default_input_word_ids:0', + 'mask': 'serving_default_input_mask:0', + 'segment_ids': 'serving_default_input_type_ids:0' +} + + +@dataclasses.dataclass +class BertModelSpec: + """Specification for a BERT model. + + See https://arxiv.org/abs/1810.04805 (BERT: Pre-training of Deep Bidirectional + Transformers for Language Understanding) for more details. + + Attributes: + hparams: Hyperparameters used for training. + model_options: Configurable options for a BERT model. + do_lower_case: boolean, whether to lower case the input text. Should be + True / False for uncased / cased models respectively, where the models + are specified by the `uri`. + tflite_input_name: Dict, input names for the TFLite model. + uri: URI for the BERT module. + name: The name of the object. + """ + + hparams: hp.BaseHParams = hp.BaseHParams( + epochs=3, + batch_size=32, + learning_rate=3e-5, + distribution_strategy='mirrored') + model_options: bert_model_options.BertModelOptions = ( + bert_model_options.BertModelOptions()) + do_lower_case: bool = True + tflite_input_name: Dict[str, str] = dataclasses.field( + default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME) + uri: str = 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1' + name: str = 'Bert' diff --git a/mediapipe/model_maker/python/text/text_classifier/BUILD b/mediapipe/model_maker/python/text/text_classifier/BUILD new file mode 100644 index 000000000..357263678 --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/BUILD @@ -0,0 +1,146 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python strict test compatibility macro. + +package( + default_visibility = ["//mediapipe:__subpackages__"], +) + +licenses(["notice"]) + +py_library( + name = "model_options", + srcs = ["model_options.py"], + deps = ["//mediapipe/model_maker/python/text/core:bert_model_options"], +) + +py_library( + name = "model_spec", + srcs = ["model_spec.py"], + deps = [ + ":model_options", + "//mediapipe/model_maker/python/core:hyperparameters", + "//mediapipe/model_maker/python/text/core:bert_model_spec", + ], +) + +py_test( + name = "model_spec_test", + srcs = ["model_spec_test.py"], + deps = [ + ":model_options", + ":model_spec", + "//mediapipe/model_maker/python/core:hyperparameters", + ], +) + +py_library( + name = "dataset", + srcs = ["dataset.py"], + deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"], +) + +py_test( + name = "dataset_test", + srcs = ["dataset_test.py"], + deps = [":dataset"], +) + +py_library( + name = "preprocessor", + srcs = ["preprocessor.py"], + deps = [":dataset"], +) + +py_test( + name = "preprocessor_test", + srcs = ["preprocessor_test.py"], + tags = ["requires-net:external"], + deps = [ + ":dataset", + ":model_spec", + ":preprocessor", + ], +) + +py_library( + name = "text_classifier_options", + srcs = ["text_classifier_options.py"], + deps = [ + ":model_options", + ":model_spec", + "//mediapipe/model_maker/python/core:hyperparameters", + ], +) + +py_library( + name = "text_classifier", + srcs = ["text_classifier.py"], + deps = [ + ":dataset", + ":model_options", + ":model_spec", + ":preprocessor", + ":text_classifier_options", + "//mediapipe/model_maker/python/core:hyperparameters", + "//mediapipe/model_maker/python/core/data:dataset", + "//mediapipe/model_maker/python/core/tasks:classifier", + "//mediapipe/model_maker/python/core/utils:model_util", + "//mediapipe/model_maker/python/core/utils:quantization", + "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer", + "//mediapipe/tasks/python/metadata/metadata_writers:text_classifier", + ], +) + +py_test( + name = "text_classifier_test", + size = "large", + srcs = ["text_classifier_test.py"], + data = [ + "//mediapipe/model_maker/python/text/text_classifier/testdata", + ], + tags = ["requires-net:external"], + deps = [ + ":dataset", + ":model_options", + ":model_spec", + ":text_classifier", + ":text_classifier_options", + "//mediapipe/model_maker/python/core:hyperparameters", + "//mediapipe/tasks/python/test:test_utils", + ], +) + +py_library( + name = "text_classifier_demo_lib", + srcs = ["text_classifier_demo.py"], + deps = [ + ":dataset", + ":model_spec", + ":text_classifier", + ":text_classifier_options", + "//mediapipe/model_maker/python/core:hyperparameters", + "//mediapipe/model_maker/python/core/utils:quantization", + ], +) + +py_binary( + name = "text_classifier_demo", + srcs = ["text_classifier_demo.py"], + deps = [ + ":text_classifier_demo_lib", + ], +) diff --git a/mediapipe/model_maker/python/text/text_classifier/__init__.py b/mediapipe/model_maker/python/text/text_classifier/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/model_maker/python/text/text_classifier/dataset.py b/mediapipe/model_maker/python/text/text_classifier/dataset.py new file mode 100644 index 000000000..3679b67ae --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/dataset.py @@ -0,0 +1,88 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Text classifier dataset library.""" + +import csv +import dataclasses +import random + +from typing import Optional, Sequence +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import classification_dataset + + +@dataclasses.dataclass +class CSVParameters: + """Parameters used when reading a CSV file. + + Attributes: + text_column: Column name for the input text. + label_column: Column name for the labels. + fieldnames: Sequence of keys for the CSV columns. If None, the first row of + the CSV file is used as the keys. + delimiter: Character that separates fields. + quotechar: Character used to quote fields that contain special characters + like the `delimiter`. + """ + text_column: str + label_column: str + fieldnames: Optional[Sequence[str]] = None + delimiter: str = "," + quotechar: str = '"' + + +class Dataset(classification_dataset.ClassificationDataset): + """Dataset library for text classifier.""" + + @classmethod + def from_csv(cls, + filename: str, + csv_params: CSVParameters, + shuffle: bool = True) -> "Dataset": + """Loads text with labels from a CSV file. + + Args: + filename: Name of the CSV file. + csv_params: Parameters used for reading the CSV file. + shuffle: If True, randomly shuffle the data. + + Returns: + Dataset containing (text, label) pairs and other related info. + """ + with tf.io.gfile.GFile(filename, "r") as f: + reader = csv.DictReader( + f, + fieldnames=csv_params.fieldnames, + delimiter=csv_params.delimiter, + quotechar=csv_params.quotechar) + + lines = list(reader) + if shuffle: + random.shuffle(lines) + + label_names = sorted(set([line[csv_params.label_column] for line in lines])) + index_by_label = {label: index for index, label in enumerate(label_names)} + + texts = [line[csv_params.text_column] for line in lines] + text_ds = tf.data.Dataset.from_tensor_slices(tf.cast(texts, tf.string)) + label_indices = [ + index_by_label[line[csv_params.label_column]] for line in lines + ] + label_index_ds = tf.data.Dataset.from_tensor_slices( + tf.cast(label_indices, tf.int64)) + text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds)) + + return Dataset( + dataset=text_label_ds, size=len(texts), label_names=label_names) diff --git a/mediapipe/model_maker/python/text/text_classifier/dataset_test.py b/mediapipe/model_maker/python/text/text_classifier/dataset_test.py new file mode 100644 index 000000000..ec9e8fa2d --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/dataset_test.py @@ -0,0 +1,75 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +import os + +import tensorflow as tf + +from mediapipe.model_maker.python.text.text_classifier import dataset + + +class DatasetTest(tf.test.TestCase): + + def _get_csv_file(self): + labels_and_text = (('neutral', 'indifferent'), ('pos', 'extremely great'), + ('neg', 'totally awful'), ('pos', 'super good'), + ('neg', 'really bad')) + csv_file = os.path.join(self.get_temp_dir(), 'data.csv') + if os.path.exists(csv_file): + return csv_file + fieldnames = ['text', 'label'] + with open(csv_file, 'w') as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for label, text in labels_and_text: + writer.writerow({'text': text, 'label': label}) + return csv_file + + def test_from_csv(self): + csv_file = self._get_csv_file() + csv_params = dataset.CSVParameters(text_column='text', label_column='label') + data = dataset.Dataset.from_csv(filename=csv_file, csv_params=csv_params) + self.assertLen(data, 5) + self.assertEqual(data.num_classes, 3) + self.assertEqual(data.label_names, ['neg', 'neutral', 'pos']) + data_values = set([(text.numpy()[0], label.numpy()[0]) + for text, label in data.gen_tf_dataset()]) + expected_data_values = set([(b'indifferent', 1), (b'extremely great', 2), + (b'totally awful', 0), (b'super good', 2), + (b'really bad', 0)]) + self.assertEqual(data_values, expected_data_values) + + def test_split(self): + ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd']) + data = dataset.Dataset(ds, 4, ['pos', 'neg']) + train_data, test_data = data.split(0.5) + expected_train_data = [b'good', b'bad'] + expected_test_data = [b'neutral', b'odd'] + + self.assertLen(train_data, 2) + train_data_values = [elem.numpy() for elem in train_data._dataset] + self.assertEqual(train_data_values, expected_train_data) + self.assertEqual(train_data.num_classes, 2) + self.assertEqual(train_data.label_names, ['pos', 'neg']) + + self.assertLen(test_data, 2) + test_data_values = [elem.numpy() for elem in test_data._dataset] + self.assertEqual(test_data_values, expected_test_data) + self.assertEqual(test_data.num_classes, 2) + self.assertEqual(test_data.label_names, ['pos', 'neg']) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/text/text_classifier/model_options.py b/mediapipe/model_maker/python/text/text_classifier/model_options.py new file mode 100644 index 000000000..b48e38da1 --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/model_options.py @@ -0,0 +1,45 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Configurable model options for text classifier models.""" + +import dataclasses +from typing import Union + +from mediapipe.model_maker.python.text.core import bert_model_options + +# BERT text classifier options inherited from BertModelOptions. +BertClassifierOptions = bert_model_options.BertModelOptions + + +@dataclasses.dataclass +class AverageWordEmbeddingClassifierOptions: + """Configurable model options for an Average Word Embedding classifier. + + Attributes: + seq_len: Length of the sequence to feed into the model. + wordvec_dim: Dimension of the word embedding. + do_lower_case: Whether to convert all uppercase characters to lowercase + during preprocessing. + vocab_size: Number of words to generate the vocabulary from data. + dropout_rate: The rate for dropout. + """ + seq_len: int = 256 + wordvec_dim: int = 16 + do_lower_case: bool = True + vocab_size: int = 10000 + dropout_rate: float = 0.2 + + +TextClassifierModelOptions = Union[AverageWordEmbeddingClassifierOptions, + BertClassifierOptions] diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec.py b/mediapipe/model_maker/python/text/text_classifier/model_spec.py new file mode 100644 index 000000000..c2694786c --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec.py @@ -0,0 +1,70 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Specifications for text classifier models.""" + +import dataclasses +import enum +import functools + +from mediapipe.model_maker.python.core import hyperparameters as hp +from mediapipe.model_maker.python.text.core import bert_model_spec +from mediapipe.model_maker.python.text.text_classifier import model_options as mo + +# BERT-based text classifier spec inherited from BertModelSpec +BertClassifierSpec = bert_model_spec.BertModelSpec + + +@dataclasses.dataclass +class AverageWordEmbeddingClassifierSpec: + """Specification for an average word embedding classifier model. + + Attributes: + hparams: Configurable hyperparameters for training. + model_options: Configurable options for the average word embedding model. + name: The name of the object. + """ + + # `learning_rate` is unused for the average word embedding model + hparams: hp.BaseHParams = hp.BaseHParams( + epochs=10, batch_size=32, learning_rate=0) + model_options: mo.AverageWordEmbeddingClassifierOptions = ( + mo.AverageWordEmbeddingClassifierOptions()) + name: str = 'AverageWordEmbedding' + + +average_word_embedding_classifier_spec = functools.partial( + AverageWordEmbeddingClassifierSpec) + +mobilebert_classifier_spec = functools.partial( + BertClassifierSpec, + hparams=hp.BaseHParams( + epochs=3, + batch_size=48, + learning_rate=3e-5, + distribution_strategy='off'), + name='MobileBert', + uri='https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1', + tflite_input_name={ + 'ids': 'serving_default_input_1:0', + 'mask': 'serving_default_input_3:0', + 'segment_ids': 'serving_default_input_2:0' + }, +) + + +@enum.unique +class SupportedModels(enum.Enum): + """Predefined text classifier model specs supported by Model Maker.""" + AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec + MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py new file mode 100644 index 000000000..118b84fdc --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py @@ -0,0 +1,118 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for model_spec.""" + +import os + +import tensorflow as tf + +from mediapipe.model_maker.python.core import hyperparameters as hp +from mediapipe.model_maker.python.text.text_classifier import model_options as classifier_model_options +from mediapipe.model_maker.python.text.text_classifier import model_spec as ms + + +class ModelSpecTest(tf.test.TestCase): + + def test_predefined_bert_spec(self): + model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value() + self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec) + self.assertEqual(model_spec_obj.name, 'MobileBert') + self.assertEqual( + model_spec_obj.uri, 'https://tfhub.dev/tensorflow/' + 'mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1') + self.assertTrue(model_spec_obj.do_lower_case) + self.assertEqual( + model_spec_obj.tflite_input_name, { + 'ids': 'serving_default_input_1:0', + 'mask': 'serving_default_input_3:0', + 'segment_ids': 'serving_default_input_2:0' + }) + self.assertEqual( + model_spec_obj.model_options, + classifier_model_options.BertClassifierOptions( + seq_len=128, do_fine_tuning=True, dropout_rate=0.1)) + self.assertEqual( + model_spec_obj.hparams, + hp.BaseHParams( + epochs=3, + batch_size=48, + learning_rate=3e-5, + distribution_strategy='off')) + + def test_predefined_average_word_embedding_spec(self): + model_spec_obj = ( + ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER.value()) + self.assertIsInstance(model_spec_obj, ms.AverageWordEmbeddingClassifierSpec) + self.assertEqual(model_spec_obj.name, 'AverageWordEmbedding') + self.assertEqual( + model_spec_obj.model_options, + classifier_model_options.AverageWordEmbeddingClassifierOptions( + seq_len=256, + wordvec_dim=16, + do_lower_case=True, + vocab_size=10000, + dropout_rate=0.2)) + self.assertEqual( + model_spec_obj.hparams, + hp.BaseHParams( + epochs=10, + batch_size=32, + learning_rate=0, + steps_per_epoch=None, + shuffle=False, + distribution_strategy='off', + num_gpus=-1, + tpu='')) + + def test_custom_bert_spec(self): + custom_bert_classifier_options = ( + classifier_model_options.BertClassifierOptions( + seq_len=512, do_fine_tuning=False, dropout_rate=0.3)) + model_spec_obj = ( + ms.SupportedModels.MOBILEBERT_CLASSIFIER.value( + model_options=custom_bert_classifier_options)) + self.assertEqual(model_spec_obj.model_options, + custom_bert_classifier_options) + + def test_custom_average_word_embedding_spec(self): + custom_hparams = hp.BaseHParams( + learning_rate=0.4, + batch_size=64, + epochs=10, + steps_per_epoch=10, + shuffle=True, + export_dir='foo/bar', + distribution_strategy='mirrored', + num_gpus=3, + tpu='tpu/address') + custom_average_word_embedding_model_options = ( + classifier_model_options.AverageWordEmbeddingClassifierOptions( + seq_len=512, + wordvec_dim=32, + do_lower_case=False, + vocab_size=5000, + dropout_rate=0.5)) + model_spec_obj = ( + ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER.value( + model_options=custom_average_word_embedding_model_options, + hparams=custom_hparams)) + self.assertEqual(model_spec_obj.model_options, + custom_average_word_embedding_model_options) + self.assertEqual(model_spec_obj.hparams, custom_hparams) + + +if __name__ == '__main__': + # Load compressed models from tensorflow_hub + os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED' + tf.test.main() diff --git a/mediapipe/model_maker/python/text/text_classifier/preprocessor.py b/mediapipe/model_maker/python/text/text_classifier/preprocessor.py new file mode 100644 index 000000000..0a48f459c --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/preprocessor.py @@ -0,0 +1,285 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preprocessors for text classification.""" + +import collections +import os +import re +import tempfile +from typing import Mapping, Sequence, Tuple, Union + +import tensorflow as tf +import tensorflow_hub + +from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds +from official.nlp.data import classifier_data_lib +from official.nlp.tools import tokenization + + +def _validate_text_and_label(text: tf.Tensor, label: tf.Tensor) -> None: + """Validates the shape and type of `text` and `label`. + + Args: + text: Stores text data. Should have shape [1] and dtype tf.string. + label: Stores the label for the corresponding `text`. Should have shape [1] + and dtype tf.int64. + + Raises: + ValueError: If either tensor has the wrong shape or type. + """ + if text.shape != [1]: + raise ValueError(f"`text` should have shape [1], got {text.shape}") + if text.dtype != tf.string: + raise ValueError(f"Expected dtype string for `text`, got {text.dtype}") + if label.shape != [1]: + raise ValueError(f"`label` should have shape [1], got {text.shape}") + if label.dtype != tf.int64: + raise ValueError(f"Expected dtype int64 for `label`, got {label.dtype}") + + +def _decode_record( + record: tf.Tensor, name_to_features: Mapping[str, tf.io.FixedLenFeature] +) -> Tuple[Mapping[str, tf.Tensor], tf.Tensor]: + """Decodes a record into input for a BERT model. + + Args: + record: Stores serialized example. + name_to_features: Maps record keys to feature types. + + Returns: + BERT model input features and label for the record. + """ + example = tf.io.parse_single_example(record, name_to_features) + + # tf.Example only supports tf.int64, but the TPU only supports tf.int32. + for name in list(example.keys()): + example[name] = tf.cast(example[name], tf.int32) + + bert_features = { + "input_word_ids": example["input_ids"], + "input_mask": example["input_mask"], + "input_type_ids": example["segment_ids"] + } + return bert_features, example["label_ids"] + + +def _single_file_dataset( + input_file: str, name_to_features: Mapping[str, tf.io.FixedLenFeature] +) -> tf.data.TFRecordDataset: + """Creates a single-file dataset to be passed for BERT custom training. + + Args: + input_file: Filepath for the dataset. + name_to_features: Maps record keys to feature types. + + Returns: + Dataset containing BERT model input features and labels. + """ + d = tf.data.TFRecordDataset(input_file) + d = d.map( + lambda record: _decode_record(record, name_to_features), + num_parallel_calls=tf.data.AUTOTUNE) + return d + + +class AverageWordEmbeddingClassifierPreprocessor: + """Preprocessor for an Average Word Embedding model. + + Takes (text, label) data and applies regex tokenization and padding to the + text to generate (token IDs, label) data. + + Attributes: + seq_len: Length of the input sequence to the model. + do_lower_case: Whether text inputs should be converted to lower-case. + vocab: Vocabulary of tokens used by the model. + """ + + PAD: str = "" # Index: 0 + START: str = "" # Index: 1 + UNKNOWN: str = "" # Index: 2 + + def __init__(self, seq_len: int, do_lower_case: bool, texts: Sequence[str], + vocab_size: int): + self._seq_len = seq_len + self._do_lower_case = do_lower_case + self._vocab = self._gen_vocab(texts, vocab_size) + + def _gen_vocab(self, texts: Sequence[str], + vocab_size: int) -> Mapping[str, int]: + """Generates vocabulary list in `texts` with size `vocab_size`. + + Args: + texts: All texts (across training and validation data) that will be + preprocessed by the model. + vocab_size: Size of the vocab. + + Returns: + The vocab mapping tokens to IDs. + """ + vocab_counter = collections.Counter() + + for text in texts: + tokens = self._regex_tokenize(text) + for token in tokens: + vocab_counter[token] += 1 + + vocab_freq = vocab_counter.most_common(vocab_size) + vocab_list = [self.PAD, self.START, self.UNKNOWN + ] + [word for word, _ in vocab_freq] + return collections.OrderedDict(((v, i) for i, v in enumerate(vocab_list))) + + def get_vocab(self) -> Mapping[str, int]: + """Returns the vocab of the AverageWordEmbeddingClassifierPreprocessor.""" + return self._vocab + + # TODO: Align with MediaPipe's RegexTokenizer. + def _regex_tokenize(self, text: str) -> Sequence[str]: + """Splits `text` by words but does not split on single quotes. + + Args: + text: Text to be tokenized. + + Returns: + List of tokens. + """ + text = tf.compat.as_text(text) + if self._do_lower_case: + text = text.lower() + tokens = re.compile(r"[^\w\']+").split(text.strip()) + # Filters out any empty strings in `tokens`. + return list(filter(None, tokens)) + + def _tokenize_and_pad(self, text: str) -> Sequence[int]: + """Tokenizes `text` and pads the tokens to `seq_len`. + + Args: + text: Text to be tokenized and padded. + + Returns: + List of token IDs padded to have length `seq_len`. + """ + tokens = self._regex_tokenize(text) + + # Gets ids for START, PAD and UNKNOWN tokens. + start_id = self._vocab[self.START] + pad_id = self._vocab[self.PAD] + unknown_id = self._vocab[self.UNKNOWN] + + token_ids = [self._vocab.get(token, unknown_id) for token in tokens] + token_ids = [start_id] + token_ids + + if len(token_ids) < self._seq_len: + pad_length = self._seq_len - len(token_ids) + token_ids = token_ids + pad_length * [pad_id] + else: + token_ids = token_ids[:self._seq_len] + return token_ids + + def preprocess( + self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset: + """Preprocesses data into input for an Average Word Embedding model. + + Args: + dataset: Stores (text, label) data. + + Returns: + Dataset containing (token IDs, label) data. + """ + token_ids_list = [] + labels_list = [] + for text, label in dataset.gen_tf_dataset(): + _validate_text_and_label(text, label) + token_ids = self._tokenize_and_pad(text.numpy()[0].decode("utf-8")) + token_ids_list.append(token_ids) + labels_list.append(label.numpy()[0]) + + token_ids_ds = tf.data.Dataset.from_tensor_slices(token_ids_list) + labels_ds = tf.data.Dataset.from_tensor_slices(labels_list) + preprocessed_ds = tf.data.Dataset.zip((token_ids_ds, labels_ds)) + return text_classifier_ds.Dataset( + dataset=preprocessed_ds, + size=dataset.size, + label_names=dataset.label_names) + + +class BertClassifierPreprocessor: + """Preprocessor for a BERT-based classifier. + + Attributes: + seq_len: Length of the input sequence to the model. + vocab_file: File containing the BERT vocab. + tokenizer: BERT tokenizer. + """ + + def __init__(self, seq_len: int, do_lower_case: bool, uri: str): + self._seq_len = seq_len + # Vocab filepath is tied to the BERT module's URI. + self._vocab_file = os.path.join( + tensorflow_hub.resolve(uri), "assets", "vocab.txt") + self._tokenizer = tokenization.FullTokenizer(self._vocab_file, + do_lower_case) + + def _get_name_to_features(self): + """Gets the dictionary mapping record keys to feature types.""" + return { + "input_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64), + "input_mask": tf.io.FixedLenFeature([self._seq_len], tf.int64), + "segment_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64), + "label_ids": tf.io.FixedLenFeature([], tf.int64), + } + + def get_vocab_file(self) -> str: + """Returns the vocab file of the BertClassifierPreprocessor.""" + return self._vocab_file + + def preprocess( + self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset: + """Preprocesses data into input for a BERT-based classifier. + + Args: + dataset: Stores (text, label) data. + + Returns: + Dataset containing (bert_features, label) data. + """ + examples = [] + for index, (text, label) in enumerate(dataset.gen_tf_dataset()): + _validate_text_and_label(text, label) + examples.append( + classifier_data_lib.InputExample( + guid=str(index), + text_a=text.numpy()[0].decode("utf-8"), + text_b=None, + # InputExample expects the label name rather than the int ID + label=dataset.label_names[label.numpy()[0]])) + + tfrecord_file = os.path.join(tempfile.mkdtemp(), "bert_features.tfrecord") + classifier_data_lib.file_based_convert_examples_to_features( + examples=examples, + label_list=dataset.label_names, + max_seq_length=self._seq_len, + tokenizer=self._tokenizer, + output_file=tfrecord_file) + preprocessed_ds = _single_file_dataset(tfrecord_file, + self._get_name_to_features()) + return text_classifier_ds.Dataset( + dataset=preprocessed_ds, + size=dataset.size, + label_names=dataset.label_names) + + +TextClassifierPreprocessor = ( + Union[BertClassifierPreprocessor, + AverageWordEmbeddingClassifierPreprocessor]) diff --git a/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py b/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py new file mode 100644 index 000000000..b9558b2b5 --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py @@ -0,0 +1,96 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +import os + +import numpy as np +import numpy.testing as npt +import tensorflow as tf + +from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds +from mediapipe.model_maker.python.text.text_classifier import model_spec +from mediapipe.model_maker.python.text.text_classifier import preprocessor + + +class PreprocessorTest(tf.test.TestCase): + CSV_PARAMS_ = text_classifier_ds.CSVParameters( + text_column='text', label_column='label') + + def _get_csv_file(self): + labels_and_text = (('pos', 'super super super super good'), + (('neg', 'really bad'))) + csv_file = os.path.join(self.get_temp_dir(), 'data.csv') + if os.path.exists(csv_file): + return csv_file + fieldnames = ['text', 'label'] + with open(csv_file, 'w') as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for label, text in labels_and_text: + writer.writerow({'text': text, 'label': label}) + return csv_file + + def test_average_word_embedding_preprocessor(self): + csv_file = self._get_csv_file() + dataset = text_classifier_ds.Dataset.from_csv( + filename=csv_file, csv_params=self.CSV_PARAMS_) + average_word_embedding_preprocessor = ( + preprocessor.AverageWordEmbeddingClassifierPreprocessor( + seq_len=5, + do_lower_case=True, + texts=['super super super super good', 'really bad'], + vocab_size=7)) + preprocessed_dataset = ( + average_word_embedding_preprocessor.preprocess(dataset)) + labels = [] + features_list = [] + for features, label in preprocessed_dataset.gen_tf_dataset(): + self.assertEqual(label.shape, [1]) + labels.append(label.numpy()[0]) + self.assertEqual(features.shape, [1, 5]) + features_list.append(features.numpy()[0]) + self.assertEqual(labels, [1, 0]) + npt.assert_array_equal( + np.stack(features_list), np.array([[1, 3, 3, 3, 3], [1, 5, 6, 0, 0]])) + + def test_bert_preprocessor(self): + csv_file = self._get_csv_file() + dataset = text_classifier_ds.Dataset.from_csv( + filename=csv_file, csv_params=self.CSV_PARAMS_) + bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value() + bert_preprocessor = preprocessor.BertClassifierPreprocessor( + seq_len=5, do_lower_case=bert_spec.do_lower_case, uri=bert_spec.uri) + preprocessed_dataset = bert_preprocessor.preprocess(dataset) + labels = [] + input_masks = [] + for features, label in preprocessed_dataset.gen_tf_dataset(): + self.assertEqual(label.shape, [1]) + labels.append(label.numpy()[0]) + self.assertSameElements( + features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids']) + for feature in features.values(): + self.assertEqual(feature.shape, [1, 5]) + input_masks.append(features['input_mask'].numpy()[0]) + npt.assert_array_equal(features['input_type_ids'].numpy()[0], + [0, 0, 0, 0, 0]) + npt.assert_array_equal( + np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])) + self.assertEqual(labels, [1, 0]) + + +if __name__ == '__main__': + # Load compressed models from tensorflow_hub + os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED' + tf.test.main() diff --git a/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD b/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD new file mode 100644 index 000000000..663c72082 --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD @@ -0,0 +1,23 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + default_visibility = ["//mediapipe/model_maker/python/text/text_classifier:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "testdata", + srcs = ["average_word_embedding_metadata.json"], +) diff --git a/mediapipe/model_maker/python/text/text_classifier/testdata/average_word_embedding_metadata.json b/mediapipe/model_maker/python/text/text_classifier/testdata/average_word_embedding_metadata.json new file mode 100644 index 000000000..bf5cb3640 --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/testdata/average_word_embedding_metadata.json @@ -0,0 +1,63 @@ +{ + "name": "TextClassifier", + "description": "Classify the input text into a set of known categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "input_text", + "description": "Embedding vectors representing the input text to be processed.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "RegexTokenizerOptions", + "options": { + "delim_regex_pattern": "[^\\w\\']+", + "vocab_file": [ + { + "name": "vocab.txt", + "description": "Vocabulary file to convert natural language words to embedding vectors.", + "type": "VOCABULARY" + } + ] + } + } + ], + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "score", + "description": "Score of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + } + ] + } + ] + } + ], + "min_parser_version": "1.2.1" +} diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py new file mode 100644 index 000000000..919277b8a --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py @@ -0,0 +1,437 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""API for text classification.""" + +import abc +import os +import tempfile +from typing import Any, Optional, Sequence, Tuple + +import tensorflow as tf +import tensorflow_hub as hub + +from mediapipe.model_maker.python.core import hyperparameters as hp +from mediapipe.model_maker.python.core.data import dataset as ds +from mediapipe.model_maker.python.core.tasks import classifier +from mediapipe.model_maker.python.core.utils import model_util +from mediapipe.model_maker.python.core.utils import quantization +from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds +from mediapipe.model_maker.python.text.text_classifier import model_options as mo +from mediapipe.model_maker.python.text.text_classifier import model_spec as ms +from mediapipe.model_maker.python.text.text_classifier import preprocessor +from mediapipe.model_maker.python.text.text_classifier import text_classifier_options +from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer +from mediapipe.tasks.python.metadata.metadata_writers import text_classifier as text_classifier_writer +from official.nlp import optimization + + +def _validate(options: text_classifier_options.TextClassifierOptions): + """Validates that `model_options` and `supported_model` are compatible. + + Args: + options: Options for creating and training a text classifier. + + Raises: + ValueError if there is a mismatch between `model_options` and + `supported_model`. + """ + + if options.model_options is None: + return + + if (isinstance(options.model_options, + mo.AverageWordEmbeddingClassifierOptions) and + (options.supported_model != + ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)): + raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER," + f" got {options.supported_model}") + if (isinstance(options.model_options, mo.BertClassifierOptions) and + (options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)): + raise ValueError( + f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}") + + +class TextClassifier(classifier.Classifier): + """API for creating and training a text classification model.""" + + def __init__(self, model_spec: Any, hparams: hp.BaseHParams, + label_names: Sequence[str]): + super().__init__( + model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle) + self._model_spec = model_spec + self._hparams = hparams + self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir) + self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None + + @classmethod + def create( + cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset, + options: text_classifier_options.TextClassifierOptions + ) -> "TextClassifier": + """Factory function that creates and trains a text classifier. + + Note that `train_data` and `validation_data` are expected to share the same + `label_names` since they should be split from the same dataset. + + Args: + train_data: Training data. + validation_data: Validation data. + options: Options for creating and training the text classifier. + + Returns: + A text classifier. + + Raises: + ValueError if `train_data` and `validation_data` do not have the + same label_names or `options` contains an unknown `supported_model` + """ + if train_data.label_names != validation_data.label_names: + raise ValueError( + f"Training data label names {train_data.label_names} not equal to " + f"validation data label names {validation_data.label_names}") + + _validate(options) + if options.model_options is None: + options.model_options = options.supported_model.value().model_options + + if options.hparams is None: + options.hparams = options.supported_model.value().hparams + + if options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER: + text_classifier = ( + _BertClassifier.create_bert_classifier(train_data, validation_data, + options, + train_data.label_names)) + elif (options.supported_model == + ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER): + text_classifier = ( + _AverageWordEmbeddingClassifier + .create_average_word_embedding_classifier(train_data, validation_data, + options, + train_data.label_names)) + else: + raise ValueError(f"Unknown model {options.supported_model}") + + return text_classifier + + def evaluate(self, data: ds.Dataset, batch_size: int = 32) -> Any: + """Overrides Classifier.evaluate(). + + Args: + data: Evaluation dataset. Must be a TextClassifier Dataset. + batch_size: Number of samples per evaluation step. + + Returns: + The loss value and accuracy. + + Raises: + ValueError if `data` is not a TextClassifier Dataset. + """ + # This override is needed because TextClassifier preprocesses its data + # outside of the `gen_tf_dataset()` method. The preprocess call also + # requires a TextClassifier Dataset instead of a core Dataset. + if not isinstance(data, text_ds.Dataset): + raise ValueError("Need a TextClassifier Dataset.") + + processed_data = self._text_preprocessor.preprocess(data) + dataset = processed_data.gen_tf_dataset(batch_size, is_training=False) + return self._model.evaluate(dataset) + + def export_model( + self, + model_name: str = "model.tflite", + quantization_config: Optional[quantization.QuantizationConfig] = None): + """Converts and saves the model to a TFLite file with metadata included. + + Note that only the TFLite file is needed for deployment. This function also + saves a metadata.json file to the same directory as the TFLite file which + can be used to interpret the metadata content in the TFLite file. + + Args: + model_name: File name to save TFLite model with metadata. The full export + path is {self._hparams.export_dir}/{model_name}. + quantization_config: The configuration for model quantization. + """ + if not tf.io.gfile.exists(self._hparams.export_dir): + tf.io.gfile.makedirs(self._hparams.export_dir) + tflite_file = os.path.join(self._hparams.export_dir, model_name) + metadata_file = os.path.join(self._hparams.export_dir, "metadata.json") + + tflite_model = model_util.convert_to_tflite( + model=self._model, quantization_config=quantization_config) + vocab_filepath = os.path.join(tempfile.mkdtemp(), "vocab.txt") + self._save_vocab(vocab_filepath) + + writer = self._get_metadata_writer(tflite_model, vocab_filepath) + tflite_model_with_metadata, metadata_json = writer.populate() + model_util.save_tflite(tflite_model_with_metadata, tflite_file) + with open(metadata_file, "w") as f: + f.write(metadata_json) + + @abc.abstractmethod + def _save_vocab(self, vocab_filepath: str): + """Saves the preprocessor's vocab to `vocab_filepath`.""" + + @abc.abstractmethod + def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str): + """Gets the metadata writer for the text classifier TFLite model.""" + + +class _AverageWordEmbeddingClassifier(TextClassifier): + """APIs to help create and train an Average Word Embedding text classifier.""" + + _DELIM_REGEX_PATTERN = r"[^\w\']+" + + def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec, + model_options: mo.AverageWordEmbeddingClassifierOptions, + hparams: hp.BaseHParams, label_names: Sequence[str]): + super().__init__(model_spec, hparams, label_names) + self._model_options = model_options + self._loss_function = "sparse_categorical_crossentropy" + self._metric_function = "accuracy" + self._text_preprocessor: ( + preprocessor.AverageWordEmbeddingClassifierPreprocessor) = None + + @classmethod + def create_average_word_embedding_classifier( + cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset, + options: text_classifier_options.TextClassifierOptions, + label_names: Sequence[str]) -> "_AverageWordEmbeddingClassifier": + """Creates, trains, and returns an Average Word Embedding classifier. + + Args: + train_data: Training data. + validation_data: Validation data. + options: Options for creating and training the text classifier. + label_names: Label names used in the data. + + Returns: + An Average Word Embedding classifier. + """ + average_word_embedding_classifier = _AverageWordEmbeddingClassifier( + model_spec=options.supported_model.value(), + model_options=options.model_options, + hparams=options.hparams, + label_names=train_data.label_names) + average_word_embedding_classifier._create_and_train_model( + train_data, validation_data) + return average_word_embedding_classifier + + def _create_and_train_model(self, train_data: text_ds.Dataset, + validation_data: text_ds.Dataset): + """Creates the Average Word Embedding classifier keras model and trains it. + + Args: + train_data: Training data. + validation_data: Validation data. + """ + (processed_train_data, processed_validation_data) = ( + self._load_and_run_preprocessor(train_data, validation_data)) + self._create_model() + self._optimizer = "rmsprop" + self._train_model(processed_train_data, processed_validation_data) + + def _load_and_run_preprocessor( + self, train_data: text_ds.Dataset, validation_data: text_ds.Dataset + ) -> Tuple[text_ds.Dataset, text_ds.Dataset]: + """Runs an AverageWordEmbeddingClassifierPreprocessor on the data. + + Args: + train_data: Training data. + validation_data: Validation data. + + Returns: + Preprocessed training data and preprocessed validation data. + """ + train_texts = [text.numpy()[0] for text, _ in train_data.gen_tf_dataset()] + validation_texts = [ + text.numpy()[0] for text, _ in validation_data.gen_tf_dataset() + ] + self._text_preprocessor = ( + preprocessor.AverageWordEmbeddingClassifierPreprocessor( + seq_len=self._model_options.seq_len, + do_lower_case=self._model_options.do_lower_case, + texts=train_texts + validation_texts, + vocab_size=self._model_options.vocab_size)) + return self._text_preprocessor.preprocess( + train_data), self._text_preprocessor.preprocess(validation_data) + + def _create_model(self): + """Creates an Average Word Embedding model.""" + self._model = tf.keras.Sequential([ + tf.keras.layers.InputLayer( + input_shape=[self._model_options.seq_len], dtype=tf.int32), + tf.keras.layers.Embedding( + len(self._text_preprocessor.get_vocab()), + self._model_options.wordvec_dim, + input_length=self._model_options.seq_len), + tf.keras.layers.GlobalAveragePooling1D(), + tf.keras.layers.Dense( + self._model_options.wordvec_dim, activation=tf.nn.relu), + tf.keras.layers.Dropout(self._model_options.dropout_rate), + tf.keras.layers.Dense(self._num_classes, activation="softmax") + ]) + + def _save_vocab(self, vocab_filepath: str): + with tf.io.gfile.GFile(vocab_filepath, "w") as f: + for token, index in self._text_preprocessor.get_vocab().items(): + f.write(f"{token} {index}\n") + + def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str): + return text_classifier_writer.MetadataWriter.create_for_regex_model( + model_buffer=tflite_model, + regex_tokenizer=metadata_writer.RegexTokenizer( + # TODO: Align with MediaPipe's RegexTokenizer. + delim_regex_pattern=self._DELIM_REGEX_PATTERN, + vocab_file_path=vocab_filepath), + labels=metadata_writer.Labels().add(list(self._label_names))) + + +class _BertClassifier(TextClassifier): + """APIs to help create and train a BERT-based text classifier.""" + + _INITIALIZER_RANGE = 0.02 + + def __init__(self, model_spec: ms.BertClassifierSpec, + model_options: mo.BertClassifierOptions, hparams: hp.BaseHParams, + label_names: Sequence[str]): + super().__init__(model_spec, hparams, label_names) + self._model_options = model_options + self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy() + self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy( + "test_accuracy", dtype=tf.float32) + self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None + + @classmethod + def create_bert_classifier( + cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset, + options: text_classifier_options.TextClassifierOptions, + label_names: Sequence[str]) -> "_BertClassifier": + """Creates, trains, and returns a BERT-based classifier. + + Args: + train_data: Training data. + validation_data: Validation data. + options: Options for creating and training the text classifier. + label_names: Label names used in the data. + + Returns: + A BERT-based classifier. + """ + bert_classifier = _BertClassifier( + model_spec=options.supported_model.value(), + model_options=options.model_options, + hparams=options.hparams, + label_names=train_data.label_names) + bert_classifier._create_and_train_model(train_data, validation_data) + return bert_classifier + + def _create_and_train_model(self, train_data: text_ds.Dataset, + validation_data: text_ds.Dataset): + """Creates the BERT-based classifier keras model and trains it. + + Args: + train_data: Training data. + validation_data: Validation data. + """ + (processed_train_data, processed_validation_data) = ( + self._load_and_run_preprocessor(train_data, validation_data)) + self._create_model() + self._create_optimizer(processed_train_data) + self._train_model(processed_train_data, processed_validation_data) + + def _load_and_run_preprocessor( + self, train_data: text_ds.Dataset, validation_data: text_ds.Dataset + ) -> Tuple[text_ds.Dataset, text_ds.Dataset]: + """Loads a BertClassifierPreprocessor and runs it on the data. + + Args: + train_data: Training data. + validation_data: Validation data. + + Returns: + Preprocessed training data and preprocessed validation data. + """ + self._text_preprocessor = preprocessor.BertClassifierPreprocessor( + seq_len=self._model_options.seq_len, + do_lower_case=self._model_spec.do_lower_case, + uri=self._model_spec.uri) + return (self._text_preprocessor.preprocess(train_data), + self._text_preprocessor.preprocess(validation_data)) + + def _create_model(self): + """Creates a BERT-based classifier model. + + The model architecture consists of stacking a dense classification layer and + dropout layer on top of the BERT encoder outputs. + """ + encoder_inputs = dict( + input_word_ids=tf.keras.layers.Input( + shape=(self._model_options.seq_len,), dtype=tf.int32), + input_mask=tf.keras.layers.Input( + shape=(self._model_options.seq_len,), dtype=tf.int32), + input_type_ids=tf.keras.layers.Input( + shape=(self._model_options.seq_len,), dtype=tf.int32), + ) + encoder = hub.KerasLayer( + self._model_spec.uri, trainable=self._model_options.do_fine_tuning) + encoder_outputs = encoder(encoder_inputs) + pooled_output = encoder_outputs["pooled_output"] + + output = tf.keras.layers.Dropout(rate=self._model_options.dropout_rate)( + pooled_output) + initializer = tf.keras.initializers.TruncatedNormal( + stddev=self._INITIALIZER_RANGE) + output = tf.keras.layers.Dense( + self._num_classes, + kernel_initializer=initializer, + name="output", + activation="softmax", + dtype=tf.float32)( + output) + self._model = tf.keras.Model(inputs=encoder_inputs, outputs=output) + + def _create_optimizer(self, train_data: text_ds.Dataset): + """Loads an optimizer with a learning rate schedule. + + The decay steps in the learning rate schedule depend on the + `steps_per_epoch` which may depend on the size of the training data. + + Args: + train_data: Training data. + """ + self._hparams.steps_per_epoch = model_util.get_steps_per_epoch( + steps_per_epoch=self._hparams.steps_per_epoch, + batch_size=self._hparams.batch_size, + train_data=train_data) + total_steps = self._hparams.steps_per_epoch * self._hparams.epochs + warmup_steps = int(total_steps * 0.1) + initial_lr = self._hparams.learning_rate + self._optimizer = optimization.create_optimizer(initial_lr, total_steps, + warmup_steps) + + def _save_vocab(self, vocab_filepath: str): + tf.io.gfile.copy( + self._text_preprocessor.get_vocab_file(), + vocab_filepath, + overwrite=True) + + def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str): + return text_classifier_writer.MetadataWriter.create_for_bert_model( + model_buffer=tflite_model, + tokenizer=metadata_writer.BertTokenizer(vocab_filepath), + labels=metadata_writer.Labels().add(list(self._label_names)), + ids_name=self._model_spec.tflite_input_name["ids"], + mask_name=self._model_spec.tflite_input_name["mask"], + segment_name=self._model_spec.tflite_input_name["segment_ids"]) diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_demo.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_demo.py new file mode 100644 index 000000000..de6b85751 --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_demo.py @@ -0,0 +1,108 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Demo for making a text classifier model by MediaPipe Model Maker.""" + +import os +import tempfile + +# Dependency imports + +from absl import app +from absl import flags +from absl import logging +import tensorflow as tf + +from mediapipe.model_maker.python.core import hyperparameters as hp +from mediapipe.model_maker.python.core.utils import quantization +from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds +from mediapipe.model_maker.python.text.text_classifier import model_spec as ms +from mediapipe.model_maker.python.text.text_classifier import text_classifier +from mediapipe.model_maker.python.text.text_classifier import text_classifier_options + +FLAGS = flags.FLAGS + + +def define_flags(): + flags.DEFINE_string('export_dir', None, + 'The directory to save exported files.') + flags.DEFINE_enum('supported_model', 'average_word_embedding', + ['average_word_embedding', 'bert'], + 'The text classifier to run.') + flags.mark_flag_as_required('export_dir') + + +def download_demo_data(): + """Downloads demo data, and returns directory path.""" + data_path = tf.keras.utils.get_file( + fname='SST-2.zip', + origin='https://dl.fbaipublicfiles.com/glue/data/SST-2.zip', + extract=True) + return os.path.join(os.path.dirname(data_path), 'SST-2') # folder name + + +def run(data_dir, + export_dir=tempfile.mkdtemp(), + supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER): + """Runs demo.""" + + # Gets training data and validation data. + csv_params = text_ds.CSVParameters( + text_column='sentence', label_column='label', delimiter='\t') + train_data = text_ds.Dataset.from_csv( + filename=os.path.join(os.path.join(data_dir, 'train.tsv')), + csv_params=csv_params) + validation_data = text_ds.Dataset.from_csv( + filename=os.path.join(os.path.join(data_dir, 'dev.tsv')), + csv_params=csv_params) + + quantization_config = None + if supported_model == ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER: + hparams = hp.BaseHParams( + epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir) + # Warning: This takes extremely long to run on CPU + elif supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER: + quantization_config = quantization.QuantizationConfig.for_dynamic() + hparams = hp.BaseHParams( + epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir) + + # Fine-tunes the model. + options = text_classifier_options.TextClassifierOptions( + supported_model=supported_model, hparams=hparams) + model = text_classifier.TextClassifier.create(train_data, validation_data, + options) + + # Gets evaluation results. + _, acc = model.evaluate(validation_data) + print('Eval accuracy: %f' % acc) + + model.export_model(quantization_config=quantization_config) + model.export_labels(export_dir=options.hparams.export_dir) + + +def main(_): + logging.set_verbosity(logging.INFO) + data_dir = download_demo_data() + export_dir = os.path.expanduser(FLAGS.export_dir) + + if FLAGS.supported_model == 'average_word_embedding': + supported_model = ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER + elif FLAGS.supported_model == 'bert': + supported_model = ms.SupportedModels.MOBILEBERT_CLASSIFIER + + run(data_dir, export_dir, supported_model) + + +if __name__ == '__main__': + define_flags() + app.run(main) diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_options.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_options.py new file mode 100644 index 000000000..a02f17347 --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_options.py @@ -0,0 +1,38 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""User-facing customization options to create and train a text classifier.""" + +import dataclasses +from typing import Optional + +from mediapipe.model_maker.python.core import hyperparameters as hp +from mediapipe.model_maker.python.text.text_classifier import model_options as mo +from mediapipe.model_maker.python.text.text_classifier import model_spec as ms + + +@dataclasses.dataclass +class TextClassifierOptions: + """User-facing options for creating the text classifier. + + Attributes: + supported_model: A preconfigured model spec. + hparams: Training hyperparameters the user can set to override the ones in + `supported_model`. + model_options: Model options the user can set to override the ones in + `supported_model`. The model options type should be consistent with the + architecture of the `supported_model`. + """ + supported_model: ms.SupportedModels + hparams: Optional[hp.BaseHParams] = None + model_options: Optional[mo.TextClassifierModelOptions] = None diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py new file mode 100644 index 000000000..41dbb464a --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py @@ -0,0 +1,138 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +import filecmp +import os + +import tensorflow as tf + +from mediapipe.model_maker.python.core import hyperparameters as hp +from mediapipe.model_maker.python.text.text_classifier import dataset +from mediapipe.model_maker.python.text.text_classifier import model_options as mo +from mediapipe.model_maker.python.text.text_classifier import model_spec as ms +from mediapipe.model_maker.python.text.text_classifier import text_classifier +from mediapipe.model_maker.python.text.text_classifier import text_classifier_options +from mediapipe.tasks.python.test import test_utils + + +class TextClassifierTest(tf.test.TestCase): + + _AVERAGE_WORD_EMBEDDING_JSON_FILE = ( + test_utils.get_test_data_path('average_word_embedding_metadata.json')) + + def _get_data(self): + labels_and_text = (('pos', 'super good'), (('neg', 'really bad'))) + csv_file = os.path.join(self.get_temp_dir(), 'data.csv') + if os.path.exists(csv_file): + return csv_file + fieldnames = ['text', 'label'] + with open(csv_file, 'w') as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for label, text in labels_and_text: + writer.writerow({'text': text, 'label': label}) + csv_params = dataset.CSVParameters(text_column='text', label_column='label') + all_data = dataset.Dataset.from_csv( + filename=csv_file, csv_params=csv_params) + return all_data.split(0.5) + + def test_create_and_train_average_word_embedding_model(self): + train_data, validation_data = self._get_data() + options = text_classifier_options.TextClassifierOptions( + supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER, + hparams=hp.BaseHParams(epochs=1, batch_size=1, learning_rate=0)) + average_word_embedding_classifier = text_classifier.TextClassifier.create( + train_data, validation_data, options) + + _, accuracy = average_word_embedding_classifier.evaluate(validation_data) + self.assertGreaterEqual(accuracy, 0.0) + + # Test export_model + average_word_embedding_classifier.export_model() + output_metadata_file = os.path.join(options.hparams.export_dir, + 'metadata.json') + output_tflite_file = os.path.join(options.hparams.export_dir, + 'model.tflite') + + self.assertTrue(os.path.exists(output_tflite_file)) + self.assertGreater(os.path.getsize(output_tflite_file), 0) + + self.assertTrue(os.path.exists(output_metadata_file)) + self.assertGreater(os.path.getsize(output_metadata_file), 0) + self.assertTrue( + filecmp.cmp(output_metadata_file, + self._AVERAGE_WORD_EMBEDDING_JSON_FILE)) + + def test_create_and_train_bert(self): + train_data, validation_data = self._get_data() + options = text_classifier_options.TextClassifierOptions( + supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER, + model_options=mo.BertClassifierOptions(do_fine_tuning=False, seq_len=2), + hparams=hp.BaseHParams( + epochs=1, + batch_size=1, + learning_rate=3e-5, + distribution_strategy='off')) + bert_classifier = text_classifier.TextClassifier.create( + train_data, validation_data, options) + + _, accuracy = bert_classifier.evaluate(validation_data) + self.assertGreaterEqual(accuracy, 0.0) + # TODO: Add a unit test that does not run OOM. + + def test_label_mismatch(self): + options = ( + text_classifier_options.TextClassifierOptions( + supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER)) + train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]]) + train_data = dataset.Dataset(train_tf_dataset, 1, ['foo']) + validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]]) + validation_data = dataset.Dataset(validation_tf_dataset, 1, ['bar']) + with self.assertRaisesRegex( + ValueError, + 'Training data label names .* not equal to validation data label names' + ): + text_classifier.TextClassifier.create(train_data, validation_data, + options) + + def test_options_mismatch(self): + train_data, validation_data = self._get_data() + + avg_options = ( + text_classifier_options.TextClassifierOptions( + supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER, + model_options=mo.AverageWordEmbeddingClassifierOptions())) + with self.assertRaisesRegex( + ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got' + ' SupportedModels.MOBILEBERT_CLASSIFIER'): + text_classifier.TextClassifier.create(train_data, validation_data, + avg_options) + + bert_options = ( + text_classifier_options.TextClassifierOptions( + supported_model=( + ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER), + model_options=mo.BertClassifierOptions())) + with self.assertRaisesRegex( + ValueError, 'Expected MOBILEBERT_CLASSIFIER, got' + ' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'): + text_classifier.TextClassifier.create(train_data, validation_data, + bert_options) + + +if __name__ == '__main__': + # Load compressed models from tensorflow_hub + os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED' + tf.test.main()