Adds a BertClassifier.

PiperOrigin-RevId: 487086744
MediaPipe Team 2022-11-08 16:47:48 -08:00 committed by Copybara-Service
parent 669d539551
commit c31aaa94a6
3 changed files with 64 additions and 8 deletions

View File

@@ -45,7 +45,10 @@ py_library(
    srcs = ["classifier.py"],
    deps = [
        ":custom_model",
        "//mediapipe/model_maker/python/core:hyperparameters",
        "//mediapipe/model_maker/python/core/data:classification_dataset",
        "//mediapipe/model_maker/python/core/data:dataset",
        "//mediapipe/model_maker/python/core/utils:model_util",
    ],
)

View File

@@ -13,24 +13,24 @@
# limitations under the License.
"""Custom classifier."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from typing import Any, List
from typing import Any, Callable, Optional, Sequence, Union
import tensorflow as tf
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
from mediapipe.model_maker.python.core.data import dataset
from mediapipe.model_maker.python.core.tasks import custom_model
from mediapipe.model_maker.python.core.utils import model_util
class Classifier(custom_model.CustomModel):
"""An abstract base class that represents a TensorFlow classifier."""
def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool):
"""Initilizes a classifier with its specifications.
def __init__(self, model_spec: Any, label_names: Sequence[str],
shuffle: bool):
"""Initializes a classifier with its specifications.
Args:
model_spec: Specification for the model.
@@ -40,6 +40,59 @@ class Classifier(custom_model.CustomModel):
    super(Classifier, self).__init__(model_spec, shuffle)
    self._label_names = label_names
    self._num_classes = len(label_names)
    self._model: tf.keras.Model = None
    self._optimizer: Union[str, tf.keras.optimizers.Optimizer] = None
    self._loss_function: Union[str, tf.keras.losses.Loss] = None
    self._metric_function: Union[str, tf.keras.metrics.Metric] = None
    self._callbacks: Sequence[tf.keras.callbacks.Callback] = None
    self._hparams: hp.BaseHParams = None
    self._history: tf.keras.callbacks.History = None
  # TODO: Integrate this into all Model Maker tasks.
  def _train_model(self,
                   train_data: classification_ds.ClassificationDataset,
                   validation_data: classification_ds.ClassificationDataset,
                   preprocessor: Optional[Callable[..., bool]] = None):
    """Trains the classifier model.
    Compiles and fits the tf.keras `_model` and records the `_history`.
    Args:
      train_data: Training data.
      validation_data: Validation data.
      preprocessor: An optional data preprocessor that can be used when
        generating a tf.data.Dataset.
    """
    tf.compat.v1.logging.info('Training the models...')
    if len(train_data) < self._hparams.batch_size:
      raise ValueError(
          f'The size of the train_data {len(train_data)} can\'t be smaller than'
          f' batch_size {self._hparams.batch_size}. To solve this problem, set'
          ' the batch_size smaller or increase the size of the train_data.')
    train_dataset = train_data.gen_tf_dataset(
        batch_size=self._hparams.batch_size,
        is_training=True,
        shuffle=self._shuffle,
        preprocess=preprocessor)
    self._hparams.steps_per_epoch = model_util.get_steps_per_epoch(
        steps_per_epoch=self._hparams.steps_per_epoch,
        batch_size=self._hparams.batch_size,
        train_data=train_data)
    train_dataset = train_dataset.take(count=self._hparams.steps_per_epoch)
    validation_dataset = validation_data.gen_tf_dataset(
        batch_size=self._hparams.batch_size,
        is_training=False,
        preprocess=preprocessor)
    self._model.compile(
        optimizer=self._optimizer,
        loss=self._loss_function,
        metrics=[self._metric_function])
    self._history = self._model.fit(
        x=train_dataset,
        epochs=self._hparams.epochs,
        validation_data=validation_dataset,
        callbacks=self._callbacks)
  def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
    """Evaluates the classifier with the provided evaluation dataset.

View File

@@ -193,7 +193,7 @@ class ImageClassifier(classifier.Classifier):
        tflite_model,
        self._model_spec.mean_rgb,
        self._model_spec.stddev_rgb,
        labels=metadata_writer.Labels().add(self._label_names))
        labels=metadata_writer.Labels().add(list(self._label_names)))
    tflite_model_with_metadata, metadata_json = writer.populate()
    model_util.save_tflite(tflite_model_with_metadata, tflite_file)
    with open(metadata_file, 'w') as f:
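
The only change in this last file wraps `self._label_names` in `list()`: with the constructor now typed to accept any `Sequence[str]`, callers may pass a tuple, so the labels are copied into a plain list before being handed to the metadata writer. A small, self-contained illustration of that distinction (the helper below is a stand-in, not part of the metadata_writer API):

from typing import Sequence

def labels_as_list(label_names: Sequence[str]) -> list:
  # label_names may arrive as a tuple or other immutable Sequence; the metadata
  # writer is assumed to want a concrete list, so make an explicit copy.
  return list(label_names)

assert labels_as_list(('cat', 'dog')) == ['cat', 'dog']
assert labels_as_list(['cat', 'dog']) == ['cat', 'dog']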