From c2ac040a6c3cd502ab1a5c65018bb0d029e0470d Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 11 Nov 2022 16:50:38 -0800 Subject: [PATCH] Adds a public import API for `TextClassifier`. PiperOrigin-RevId: 487949023 --- .../python/text/text_classifier/BUILD | 26 ++++---- .../python/text/text_classifier/__init__.py | 18 ++++++ .../text/text_classifier/model_options.py | 10 ++-- .../python/text/text_classifier/model_spec.py | 4 +- .../text/text_classifier/model_spec_test.py | 8 +-- .../text/text_classifier/text_classifier.py | 10 ++-- .../text_classifier/text_classifier_demo.py | 32 +++++----- .../text_classifier/text_classifier_test.py | 60 ++++++++++--------- 8 files changed, 97 insertions(+), 71 deletions(-) diff --git a/mediapipe/model_maker/python/text/text_classifier/BUILD b/mediapipe/model_maker/python/text/text_classifier/BUILD index 357263678..0c35e7966 100644 --- a/mediapipe/model_maker/python/text/text_classifier/BUILD +++ b/mediapipe/model_maker/python/text/text_classifier/BUILD @@ -21,6 +21,19 @@ package( licenses(["notice"]) +py_library( + name = "text_classifier_import", + srcs = ["__init__.py"], + deps = [ + ":dataset", + ":model_options", + ":model_spec", + ":text_classifier", + ":text_classifier_options", + "//mediapipe/model_maker/python/core:hyperparameters", + ], +) + py_library( name = "model_options", srcs = ["model_options.py"], @@ -114,12 +127,7 @@ py_test( ], tags = ["requires-net:external"], deps = [ - ":dataset", - ":model_options", - ":model_spec", - ":text_classifier", - ":text_classifier_options", - "//mediapipe/model_maker/python/core:hyperparameters", + ":text_classifier_import", "//mediapipe/tasks/python/test:test_utils", ], ) @@ -128,11 +136,7 @@ py_library( name = "text_classifier_demo_lib", srcs = ["text_classifier_demo.py"], deps = [ - ":dataset", - ":model_spec", - ":text_classifier", - ":text_classifier_options", - "//mediapipe/model_maker/python/core:hyperparameters", + ":text_classifier_import", "//mediapipe/model_maker/python/core/utils:quantization", ], ) diff --git a/mediapipe/model_maker/python/text/text_classifier/__init__.py b/mediapipe/model_maker/python/text/text_classifier/__init__.py index 7ca2f9216..5f34fe866 100644 --- a/mediapipe/model_maker/python/text/text_classifier/__init__.py +++ b/mediapipe/model_maker/python/text/text_classifier/__init__.py @@ -11,3 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""MediaPipe Public Python API for Text Classifier.""" + +from mediapipe.model_maker.python.core import hyperparameters +from mediapipe.model_maker.python.text.text_classifier import dataset +from mediapipe.model_maker.python.text.text_classifier import model_options +from mediapipe.model_maker.python.text.text_classifier import model_spec +from mediapipe.model_maker.python.text.text_classifier import text_classifier +from mediapipe.model_maker.python.text.text_classifier import text_classifier_options + +HParams = hyperparameters.BaseHParams +CSVParams = dataset.CSVParameters +Dataset = dataset.Dataset +AverageWordEmbeddingClassifierModelOptions = ( + model_options.AverageWordEmbeddingClassifierModelOptions) +BertClassifierModelOptions = model_options.BertClassifierModelOptions +SupportedModels = model_spec.SupportedModels +TextClassifier = text_classifier.TextClassifier +TextClassifierOptions = text_classifier_options.TextClassifierOptions diff --git a/mediapipe/model_maker/python/text/text_classifier/model_options.py b/mediapipe/model_maker/python/text/text_classifier/model_options.py index b48e38da1..3dfce316b 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_options.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_options.py @@ -18,12 +18,12 @@ from typing import Union from mediapipe.model_maker.python.text.core import bert_model_options -# BERT text classifier options inherited from BertModelOptions. -BertClassifierOptions = bert_model_options.BertModelOptions +# BERT text classifier model options inherited from BertModelOptions. +BertClassifierModelOptions = bert_model_options.BertModelOptions @dataclasses.dataclass -class AverageWordEmbeddingClassifierOptions: +class AverageWordEmbeddingClassifierModelOptions: """Configurable model options for an Average Word Embedding classifier. Attributes: @@ -41,5 +41,5 @@ class AverageWordEmbeddingClassifierOptions: dropout_rate: float = 0.2 -TextClassifierModelOptions = Union[AverageWordEmbeddingClassifierOptions, - BertClassifierOptions] +TextClassifierModelOptions = Union[AverageWordEmbeddingClassifierModelOptions, + BertClassifierModelOptions] diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec.py b/mediapipe/model_maker/python/text/text_classifier/model_spec.py index c2694786c..1e215c528 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec.py @@ -38,8 +38,8 @@ class AverageWordEmbeddingClassifierSpec: # `learning_rate` is unused for the average word embedding model hparams: hp.BaseHParams = hp.BaseHParams( epochs=10, batch_size=32, learning_rate=0) - model_options: mo.AverageWordEmbeddingClassifierOptions = ( - mo.AverageWordEmbeddingClassifierOptions()) + model_options: mo.AverageWordEmbeddingClassifierModelOptions = ( + mo.AverageWordEmbeddingClassifierModelOptions()) name: str = 'AverageWordEmbedding' diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py index 118b84fdc..6cd5408bc 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py @@ -40,7 +40,7 @@ class ModelSpecTest(tf.test.TestCase): }) self.assertEqual( model_spec_obj.model_options, - classifier_model_options.BertClassifierOptions( + classifier_model_options.BertClassifierModelOptions( seq_len=128, do_fine_tuning=True, dropout_rate=0.1)) self.assertEqual( model_spec_obj.hparams, @@ -57,7 +57,7 @@ class ModelSpecTest(tf.test.TestCase): self.assertEqual(model_spec_obj.name, 'AverageWordEmbedding') self.assertEqual( model_spec_obj.model_options, - classifier_model_options.AverageWordEmbeddingClassifierOptions( + classifier_model_options.AverageWordEmbeddingClassifierModelOptions( seq_len=256, wordvec_dim=16, do_lower_case=True, @@ -77,7 +77,7 @@ class ModelSpecTest(tf.test.TestCase): def test_custom_bert_spec(self): custom_bert_classifier_options = ( - classifier_model_options.BertClassifierOptions( + classifier_model_options.BertClassifierModelOptions( seq_len=512, do_fine_tuning=False, dropout_rate=0.3)) model_spec_obj = ( ms.SupportedModels.MOBILEBERT_CLASSIFIER.value( @@ -97,7 +97,7 @@ class ModelSpecTest(tf.test.TestCase): num_gpus=3, tpu='tpu/address') custom_average_word_embedding_model_options = ( - classifier_model_options.AverageWordEmbeddingClassifierOptions( + classifier_model_options.AverageWordEmbeddingClassifierModelOptions( seq_len=512, wordvec_dim=32, do_lower_case=False, diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py index 919277b8a..a6ad9ab55 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py @@ -51,12 +51,12 @@ def _validate(options: text_classifier_options.TextClassifierOptions): return if (isinstance(options.model_options, - mo.AverageWordEmbeddingClassifierOptions) and + mo.AverageWordEmbeddingClassifierModelOptions) and (options.supported_model != ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)): raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER," f" got {options.supported_model}") - if (isinstance(options.model_options, mo.BertClassifierOptions) and + if (isinstance(options.model_options, mo.BertClassifierModelOptions) and (options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)): raise ValueError( f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}") @@ -194,7 +194,7 @@ class _AverageWordEmbeddingClassifier(TextClassifier): _DELIM_REGEX_PATTERN = r"[^\w\']+" def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec, - model_options: mo.AverageWordEmbeddingClassifierOptions, + model_options: mo.AverageWordEmbeddingClassifierModelOptions, hparams: hp.BaseHParams, label_names: Sequence[str]): super().__init__(model_spec, hparams, label_names) self._model_options = model_options @@ -304,8 +304,8 @@ class _BertClassifier(TextClassifier): _INITIALIZER_RANGE = 0.02 def __init__(self, model_spec: ms.BertClassifierSpec, - model_options: mo.BertClassifierOptions, hparams: hp.BaseHParams, - label_names: Sequence[str]): + model_options: mo.BertClassifierModelOptions, + hparams: hp.BaseHParams, label_names: Sequence[str]): super().__init__(model_spec, hparams, label_names) self._model_options = model_options self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy() diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_demo.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_demo.py index de6b85751..08f4c2ad3 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_demo.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_demo.py @@ -23,12 +23,8 @@ from absl import flags from absl import logging import tensorflow as tf -from mediapipe.model_maker.python.core import hyperparameters as hp from mediapipe.model_maker.python.core.utils import quantization -from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds -from mediapipe.model_maker.python.text.text_classifier import model_spec as ms -from mediapipe.model_maker.python.text.text_classifier import text_classifier -from mediapipe.model_maker.python.text.text_classifier import text_classifier_options +from mediapipe.model_maker.python.text import text_classifier FLAGS = flags.FLAGS @@ -53,31 +49,34 @@ def download_demo_data(): def run(data_dir, export_dir=tempfile.mkdtemp(), - supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER): + supported_model=( + text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)): """Runs demo.""" # Gets training data and validation data. - csv_params = text_ds.CSVParameters( + csv_params = text_classifier.CSVParams( text_column='sentence', label_column='label', delimiter='\t') - train_data = text_ds.Dataset.from_csv( + train_data = text_classifier.Dataset.from_csv( filename=os.path.join(os.path.join(data_dir, 'train.tsv')), csv_params=csv_params) - validation_data = text_ds.Dataset.from_csv( + validation_data = text_classifier.Dataset.from_csv( filename=os.path.join(os.path.join(data_dir, 'dev.tsv')), csv_params=csv_params) quantization_config = None - if supported_model == ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER: - hparams = hp.BaseHParams( + if (supported_model == + text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER): + hparams = text_classifier.HParams( epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir) # Warning: This takes extremely long to run on CPU - elif supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER: + elif ( + supported_model == text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER): quantization_config = quantization.QuantizationConfig.for_dynamic() - hparams = hp.BaseHParams( + hparams = text_classifier.HParams( epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir) # Fine-tunes the model. - options = text_classifier_options.TextClassifierOptions( + options = text_classifier.TextClassifierOptions( supported_model=supported_model, hparams=hparams) model = text_classifier.TextClassifier.create(train_data, validation_data, options) @@ -96,9 +95,10 @@ def main(_): export_dir = os.path.expanduser(FLAGS.export_dir) if FLAGS.supported_model == 'average_word_embedding': - supported_model = ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER + supported_model = ( + text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER) elif FLAGS.supported_model == 'bert': - supported_model = ms.SupportedModels.MOBILEBERT_CLASSIFIER + supported_model = text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER run(data_dir, export_dir, supported_model) diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py index 41dbb464a..55ffd6a7b 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py @@ -18,12 +18,7 @@ import os import tensorflow as tf -from mediapipe.model_maker.python.core import hyperparameters as hp -from mediapipe.model_maker.python.text.text_classifier import dataset -from mediapipe.model_maker.python.text.text_classifier import model_options as mo -from mediapipe.model_maker.python.text.text_classifier import model_spec as ms -from mediapipe.model_maker.python.text.text_classifier import text_classifier -from mediapipe.model_maker.python.text.text_classifier import text_classifier_options +from mediapipe.model_maker.python.text import text_classifier from mediapipe.tasks.python.test import test_utils @@ -43,18 +38,23 @@ class TextClassifierTest(tf.test.TestCase): writer.writeheader() for label, text in labels_and_text: writer.writerow({'text': text, 'label': label}) - csv_params = dataset.CSVParameters(text_column='text', label_column='label') - all_data = dataset.Dataset.from_csv( + csv_params = text_classifier.CSVParams( + text_column='text', label_column='label') + all_data = text_classifier.Dataset.from_csv( filename=csv_file, csv_params=csv_params) return all_data.split(0.5) def test_create_and_train_average_word_embedding_model(self): train_data, validation_data = self._get_data() - options = text_classifier_options.TextClassifierOptions( - supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER, - hparams=hp.BaseHParams(epochs=1, batch_size=1, learning_rate=0)) - average_word_embedding_classifier = text_classifier.TextClassifier.create( - train_data, validation_data, options) + options = ( + text_classifier.TextClassifierOptions( + supported_model=(text_classifier.SupportedModels + .AVERAGE_WORD_EMBEDDING_CLASSIFIER), + hparams=text_classifier.HParams( + epochs=1, batch_size=1, learning_rate=0))) + average_word_embedding_classifier = ( + text_classifier.TextClassifier.create(train_data, validation_data, + options)) _, accuracy = average_word_embedding_classifier.evaluate(validation_data) self.assertGreaterEqual(accuracy, 0.0) @@ -77,10 +77,11 @@ class TextClassifierTest(tf.test.TestCase): def test_create_and_train_bert(self): train_data, validation_data = self._get_data() - options = text_classifier_options.TextClassifierOptions( - supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER, - model_options=mo.BertClassifierOptions(do_fine_tuning=False, seq_len=2), - hparams=hp.BaseHParams( + options = text_classifier.TextClassifierOptions( + supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER, + model_options=text_classifier.BertClassifierModelOptions( + do_fine_tuning=False, seq_len=2), + hparams=text_classifier.HParams( epochs=1, batch_size=1, learning_rate=3e-5, @@ -94,12 +95,13 @@ class TextClassifierTest(tf.test.TestCase): def test_label_mismatch(self): options = ( - text_classifier_options.TextClassifierOptions( - supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER)) + text_classifier.TextClassifierOptions( + supported_model=( + text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER))) train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]]) - train_data = dataset.Dataset(train_tf_dataset, 1, ['foo']) + train_data = text_classifier.Dataset(train_tf_dataset, 1, ['foo']) validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]]) - validation_data = dataset.Dataset(validation_tf_dataset, 1, ['bar']) + validation_data = text_classifier.Dataset(validation_tf_dataset, 1, ['bar']) with self.assertRaisesRegex( ValueError, 'Training data label names .* not equal to validation data label names' @@ -111,9 +113,11 @@ class TextClassifierTest(tf.test.TestCase): train_data, validation_data = self._get_data() avg_options = ( - text_classifier_options.TextClassifierOptions( - supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER, - model_options=mo.AverageWordEmbeddingClassifierOptions())) + text_classifier.TextClassifierOptions( + supported_model=( + text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER), + model_options=( + text_classifier.AverageWordEmbeddingClassifierModelOptions()))) with self.assertRaisesRegex( ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got' ' SupportedModels.MOBILEBERT_CLASSIFIER'): @@ -121,10 +125,10 @@ class TextClassifierTest(tf.test.TestCase): avg_options) bert_options = ( - text_classifier_options.TextClassifierOptions( - supported_model=( - ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER), - model_options=mo.BertClassifierOptions())) + text_classifier.TextClassifierOptions( + supported_model=(text_classifier.SupportedModels + .AVERAGE_WORD_EMBEDDING_CLASSIFIER), + model_options=text_classifier.BertClassifierModelOptions())) with self.assertRaisesRegex( ValueError, 'Expected MOBILEBERT_CLASSIFIER, got' ' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):