Adds a BertClassifier
.
PiperOrigin-RevId: 487086744
This commit is contained in:
parent
669d539551
commit
c31aaa94a6
|
@ -45,7 +45,10 @@ py_library(
|
|||
srcs = ["classifier.py"],
|
||||
deps = [
|
||||
":custom_model",
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||
"//mediapipe/model_maker/python/core/data:dataset",
|
||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -13,24 +13,24 @@
|
|||
# limitations under the License.
|
||||
"""Custom classifier."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
from typing import Any, List
|
||||
from typing import Any, Callable, Optional, Sequence, Union
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
|
||||
from mediapipe.model_maker.python.core.data import dataset
|
||||
from mediapipe.model_maker.python.core.tasks import custom_model
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
|
||||
|
||||
class Classifier(custom_model.CustomModel):
|
||||
"""An abstract base class that represents a TensorFlow classifier."""
|
||||
|
||||
def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool):
|
||||
"""Initilizes a classifier with its specifications.
|
||||
def __init__(self, model_spec: Any, label_names: Sequence[str],
|
||||
shuffle: bool):
|
||||
"""Initializes a classifier with its specifications.
|
||||
|
||||
Args:
|
||||
model_spec: Specification for the model.
|
||||
|
@ -40,6 +40,59 @@ class Classifier(custom_model.CustomModel):
|
|||
super(Classifier, self).__init__(model_spec, shuffle)
|
||||
self._label_names = label_names
|
||||
self._num_classes = len(label_names)
|
||||
self._model: tf.keras.Model = None
|
||||
self._optimizer: Union[str, tf.keras.optimizers.Optimizer] = None
|
||||
self._loss_function: Union[str, tf.keras.losses.Loss] = None
|
||||
self._metric_function: Union[str, tf.keras.metrics.Metric] = None
|
||||
self._callbacks: Sequence[tf.keras.callbacks.Callback] = None
|
||||
self._hparams: hp.BaseHParams = None
|
||||
self._history: tf.keras.callbacks.History = None
|
||||
|
||||
# TODO: Integrate this into all Model Maker tasks.
|
||||
def _train_model(self,
|
||||
train_data: classification_ds.ClassificationDataset,
|
||||
validation_data: classification_ds.ClassificationDataset,
|
||||
preprocessor: Optional[Callable[..., bool]] = None):
|
||||
"""Trains the classifier model.
|
||||
|
||||
Compiles and fits the tf.keras `_model` and records the `_history`.
|
||||
|
||||
Args:
|
||||
train_data: Training data.
|
||||
validation_data: Validation data.
|
||||
preprocessor: An optional data preprocessor that can be used when
|
||||
generating a tf.data.Dataset.
|
||||
"""
|
||||
tf.compat.v1.logging.info('Training the models...')
|
||||
if len(train_data) < self._hparams.batch_size:
|
||||
raise ValueError(
|
||||
f'The size of the train_data {len(train_data)} can\'t be smaller than'
|
||||
f' batch_size {self._hparams.batch_size}. To solve this problem, set'
|
||||
' the batch_size smaller or increase the size of the train_data.')
|
||||
|
||||
train_dataset = train_data.gen_tf_dataset(
|
||||
batch_size=self._hparams.batch_size,
|
||||
is_training=True,
|
||||
shuffle=self._shuffle,
|
||||
preprocess=preprocessor)
|
||||
self._hparams.steps_per_epoch = model_util.get_steps_per_epoch(
|
||||
steps_per_epoch=self._hparams.steps_per_epoch,
|
||||
batch_size=self._hparams.batch_size,
|
||||
train_data=train_data)
|
||||
train_dataset = train_dataset.take(count=self._hparams.steps_per_epoch)
|
||||
validation_dataset = validation_data.gen_tf_dataset(
|
||||
batch_size=self._hparams.batch_size,
|
||||
is_training=False,
|
||||
preprocess=preprocessor)
|
||||
self._model.compile(
|
||||
optimizer=self._optimizer,
|
||||
loss=self._loss_function,
|
||||
metrics=[self._metric_function])
|
||||
self._history = self._model.fit(
|
||||
x=train_dataset,
|
||||
epochs=self._hparams.epochs,
|
||||
validation_data=validation_dataset,
|
||||
callbacks=self._callbacks)
|
||||
|
||||
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
|
||||
"""Evaluates the classifier with the provided evaluation dataset.
|
||||
|
|
|
@ -193,7 +193,7 @@ class ImageClassifier(classifier.Classifier):
|
|||
tflite_model,
|
||||
self._model_spec.mean_rgb,
|
||||
self._model_spec.stddev_rgb,
|
||||
labels=metadata_writer.Labels().add(self._label_names))
|
||||
labels=metadata_writer.Labels().add(list(self._label_names)))
|
||||
tflite_model_with_metadata, metadata_json = writer.populate()
|
||||
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
|
||||
with open(metadata_file, 'w') as f:
|
||||
|
|
Loading…
Reference in New Issue
Block a user