Adds a BertClassifier
.
PiperOrigin-RevId: 487086744
This commit is contained in:
parent
669d539551
commit
c31aaa94a6
|
@ -45,7 +45,10 @@ py_library(
|
||||||
srcs = ["classifier.py"],
|
srcs = ["classifier.py"],
|
||||||
deps = [
|
deps = [
|
||||||
":custom_model",
|
":custom_model",
|
||||||
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||||
"//mediapipe/model_maker/python/core/data:dataset",
|
"//mediapipe/model_maker/python/core/data:dataset",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -13,24 +13,24 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Custom classifier."""
|
"""Custom classifier."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Any, List
|
from typing import Any, Callable, Optional, Sequence, Union
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
|
||||||
from mediapipe.model_maker.python.core.data import dataset
|
from mediapipe.model_maker.python.core.data import dataset
|
||||||
from mediapipe.model_maker.python.core.tasks import custom_model
|
from mediapipe.model_maker.python.core.tasks import custom_model
|
||||||
|
from mediapipe.model_maker.python.core.utils import model_util
|
||||||
|
|
||||||
|
|
||||||
class Classifier(custom_model.CustomModel):
|
class Classifier(custom_model.CustomModel):
|
||||||
"""An abstract base class that represents a TensorFlow classifier."""
|
"""An abstract base class that represents a TensorFlow classifier."""
|
||||||
|
|
||||||
def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool):
|
def __init__(self, model_spec: Any, label_names: Sequence[str],
|
||||||
"""Initilizes a classifier with its specifications.
|
shuffle: bool):
|
||||||
|
"""Initializes a classifier with its specifications.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_spec: Specification for the model.
|
model_spec: Specification for the model.
|
||||||
|
@ -40,6 +40,59 @@ class Classifier(custom_model.CustomModel):
|
||||||
super(Classifier, self).__init__(model_spec, shuffle)
|
super(Classifier, self).__init__(model_spec, shuffle)
|
||||||
self._label_names = label_names
|
self._label_names = label_names
|
||||||
self._num_classes = len(label_names)
|
self._num_classes = len(label_names)
|
||||||
|
self._model: tf.keras.Model = None
|
||||||
|
self._optimizer: Union[str, tf.keras.optimizers.Optimizer] = None
|
||||||
|
self._loss_function: Union[str, tf.keras.losses.Loss] = None
|
||||||
|
self._metric_function: Union[str, tf.keras.metrics.Metric] = None
|
||||||
|
self._callbacks: Sequence[tf.keras.callbacks.Callback] = None
|
||||||
|
self._hparams: hp.BaseHParams = None
|
||||||
|
self._history: tf.keras.callbacks.History = None
|
||||||
|
|
||||||
|
# TODO: Integrate this into all Model Maker tasks.
|
||||||
|
def _train_model(self,
|
||||||
|
train_data: classification_ds.ClassificationDataset,
|
||||||
|
validation_data: classification_ds.ClassificationDataset,
|
||||||
|
preprocessor: Optional[Callable[..., bool]] = None):
|
||||||
|
"""Trains the classifier model.
|
||||||
|
|
||||||
|
Compiles and fits the tf.keras `_model` and records the `_history`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
validation_data: Validation data.
|
||||||
|
preprocessor: An optional data preprocessor that can be used when
|
||||||
|
generating a tf.data.Dataset.
|
||||||
|
"""
|
||||||
|
tf.compat.v1.logging.info('Training the models...')
|
||||||
|
if len(train_data) < self._hparams.batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
f'The size of the train_data {len(train_data)} can\'t be smaller than'
|
||||||
|
f' batch_size {self._hparams.batch_size}. To solve this problem, set'
|
||||||
|
' the batch_size smaller or increase the size of the train_data.')
|
||||||
|
|
||||||
|
train_dataset = train_data.gen_tf_dataset(
|
||||||
|
batch_size=self._hparams.batch_size,
|
||||||
|
is_training=True,
|
||||||
|
shuffle=self._shuffle,
|
||||||
|
preprocess=preprocessor)
|
||||||
|
self._hparams.steps_per_epoch = model_util.get_steps_per_epoch(
|
||||||
|
steps_per_epoch=self._hparams.steps_per_epoch,
|
||||||
|
batch_size=self._hparams.batch_size,
|
||||||
|
train_data=train_data)
|
||||||
|
train_dataset = train_dataset.take(count=self._hparams.steps_per_epoch)
|
||||||
|
validation_dataset = validation_data.gen_tf_dataset(
|
||||||
|
batch_size=self._hparams.batch_size,
|
||||||
|
is_training=False,
|
||||||
|
preprocess=preprocessor)
|
||||||
|
self._model.compile(
|
||||||
|
optimizer=self._optimizer,
|
||||||
|
loss=self._loss_function,
|
||||||
|
metrics=[self._metric_function])
|
||||||
|
self._history = self._model.fit(
|
||||||
|
x=train_dataset,
|
||||||
|
epochs=self._hparams.epochs,
|
||||||
|
validation_data=validation_dataset,
|
||||||
|
callbacks=self._callbacks)
|
||||||
|
|
||||||
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
|
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
|
||||||
"""Evaluates the classifier with the provided evaluation dataset.
|
"""Evaluates the classifier with the provided evaluation dataset.
|
||||||
|
|
|
@ -193,7 +193,7 @@ class ImageClassifier(classifier.Classifier):
|
||||||
tflite_model,
|
tflite_model,
|
||||||
self._model_spec.mean_rgb,
|
self._model_spec.mean_rgb,
|
||||||
self._model_spec.stddev_rgb,
|
self._model_spec.stddev_rgb,
|
||||||
labels=metadata_writer.Labels().add(self._label_names))
|
labels=metadata_writer.Labels().add(list(self._label_names)))
|
||||||
tflite_model_with_metadata, metadata_json = writer.populate()
|
tflite_model_with_metadata, metadata_json = writer.populate()
|
||||||
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
|
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
|
||||||
with open(metadata_file, 'w') as f:
|
with open(metadata_file, 'w') as f:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user