From c31aaa94a6ba862b5148922ebff8ec0f6f127316 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 8 Nov 2022 16:47:48 -0800 Subject: [PATCH] Adds a `BertClassifier`. PiperOrigin-RevId: 487086744 --- mediapipe/model_maker/python/core/tasks/BUILD | 3 + .../python/core/tasks/classifier.py | 67 +++++++++++++++++-- .../image_classifier/image_classifier.py | 2 +- 3 files changed, 64 insertions(+), 8 deletions(-) diff --git a/mediapipe/model_maker/python/core/tasks/BUILD b/mediapipe/model_maker/python/core/tasks/BUILD index 124de621a..8c5448556 100644 --- a/mediapipe/model_maker/python/core/tasks/BUILD +++ b/mediapipe/model_maker/python/core/tasks/BUILD @@ -45,7 +45,10 @@ py_library( srcs = ["classifier.py"], deps = [ ":custom_model", + "//mediapipe/model_maker/python/core:hyperparameters", + "//mediapipe/model_maker/python/core/data:classification_dataset", "//mediapipe/model_maker/python/core/data:dataset", + "//mediapipe/model_maker/python/core/utils:model_util", ], ) diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py index 5d0fbd066..200726864 100644 --- a/mediapipe/model_maker/python/core/tasks/classifier.py +++ b/mediapipe/model_maker/python/core/tasks/classifier.py @@ -13,24 +13,24 @@ # limitations under the License. """Custom classifier.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import os -from typing import Any, List +from typing import Any, Callable, Optional, Sequence, Union import tensorflow as tf +from mediapipe.model_maker.python.core import hyperparameters as hp +from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds from mediapipe.model_maker.python.core.data import dataset from mediapipe.model_maker.python.core.tasks import custom_model +from mediapipe.model_maker.python.core.utils import model_util class Classifier(custom_model.CustomModel): """An abstract base class that represents a TensorFlow classifier.""" - def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool): - """Initilizes a classifier with its specifications. + def __init__(self, model_spec: Any, label_names: Sequence[str], + shuffle: bool): + """Initializes a classifier with its specifications. Args: model_spec: Specification for the model. @@ -40,6 +40,59 @@ class Classifier(custom_model.CustomModel): super(Classifier, self).__init__(model_spec, shuffle) self._label_names = label_names self._num_classes = len(label_names) + self._model: tf.keras.Model = None + self._optimizer: Union[str, tf.keras.optimizers.Optimizer] = None + self._loss_function: Union[str, tf.keras.losses.Loss] = None + self._metric_function: Union[str, tf.keras.metrics.Metric] = None + self._callbacks: Sequence[tf.keras.callbacks.Callback] = None + self._hparams: hp.BaseHParams = None + self._history: tf.keras.callbacks.History = None + + # TODO: Integrate this into all Model Maker tasks. + def _train_model(self, + train_data: classification_ds.ClassificationDataset, + validation_data: classification_ds.ClassificationDataset, + preprocessor: Optional[Callable[..., bool]] = None): + """Trains the classifier model. + + Compiles and fits the tf.keras `_model` and records the `_history`. + + Args: + train_data: Training data. + validation_data: Validation data. + preprocessor: An optional data preprocessor that can be used when + generating a tf.data.Dataset. + """ + tf.compat.v1.logging.info('Training the models...') + if len(train_data) < self._hparams.batch_size: + raise ValueError( + f'The size of the train_data {len(train_data)} can\'t be smaller than' + f' batch_size {self._hparams.batch_size}. To solve this problem, set' + ' the batch_size smaller or increase the size of the train_data.') + + train_dataset = train_data.gen_tf_dataset( + batch_size=self._hparams.batch_size, + is_training=True, + shuffle=self._shuffle, + preprocess=preprocessor) + self._hparams.steps_per_epoch = model_util.get_steps_per_epoch( + steps_per_epoch=self._hparams.steps_per_epoch, + batch_size=self._hparams.batch_size, + train_data=train_data) + train_dataset = train_dataset.take(count=self._hparams.steps_per_epoch) + validation_dataset = validation_data.gen_tf_dataset( + batch_size=self._hparams.batch_size, + is_training=False, + preprocess=preprocessor) + self._model.compile( + optimizer=self._optimizer, + loss=self._loss_function, + metrics=[self._metric_function]) + self._history = self._model.fit( + x=train_dataset, + epochs=self._hparams.epochs, + validation_data=validation_dataset, + callbacks=self._callbacks) def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any: """Evaluates the classifier with the provided evaluation dataset. diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py index f6edbeab4..1ff6132b4 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py @@ -193,7 +193,7 @@ class ImageClassifier(classifier.Classifier): tflite_model, self._model_spec.mean_rgb, self._model_spec.stddev_rgb, - labels=metadata_writer.Labels().add(self._label_names)) + labels=metadata_writer.Labels().add(list(self._label_names))) tflite_model_with_metadata, metadata_json = writer.populate() model_util.save_tflite(tflite_model_with_metadata, tflite_file) with open(metadata_file, 'w') as f: