Adds a BertClassifier.
				
					
				
			PiperOrigin-RevId: 487086744
This commit is contained in:
		
							parent
							
								
									669d539551
								
							
						
					
					
						commit
						c31aaa94a6
					
				| 
						 | 
					@ -45,7 +45,10 @@ py_library(
 | 
				
			||||||
    srcs = ["classifier.py"],
 | 
					    srcs = ["classifier.py"],
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
        ":custom_model",
 | 
					        ":custom_model",
 | 
				
			||||||
 | 
					        "//mediapipe/model_maker/python/core:hyperparameters",
 | 
				
			||||||
 | 
					        "//mediapipe/model_maker/python/core/data:classification_dataset",
 | 
				
			||||||
        "//mediapipe/model_maker/python/core/data:dataset",
 | 
					        "//mediapipe/model_maker/python/core/data:dataset",
 | 
				
			||||||
 | 
					        "//mediapipe/model_maker/python/core/utils:model_util",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -13,24 +13,24 @@
 | 
				
			||||||
# limitations under the License.
 | 
					# limitations under the License.
 | 
				
			||||||
"""Custom classifier."""
 | 
					"""Custom classifier."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from __future__ import absolute_import
 | 
					 | 
				
			||||||
from __future__ import division
 | 
					 | 
				
			||||||
from __future__ import print_function
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
from typing import Any, List
 | 
					from typing import Any, Callable, Optional, Sequence, Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import tensorflow as tf
 | 
					import tensorflow as tf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.core import hyperparameters as hp
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
 | 
				
			||||||
from mediapipe.model_maker.python.core.data import dataset
 | 
					from mediapipe.model_maker.python.core.data import dataset
 | 
				
			||||||
from mediapipe.model_maker.python.core.tasks import custom_model
 | 
					from mediapipe.model_maker.python.core.tasks import custom_model
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.core.utils import model_util
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Classifier(custom_model.CustomModel):
 | 
					class Classifier(custom_model.CustomModel):
 | 
				
			||||||
  """An abstract base class that represents a TensorFlow classifier."""
 | 
					  """An abstract base class that represents a TensorFlow classifier."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool):
 | 
					  def __init__(self, model_spec: Any, label_names: Sequence[str],
 | 
				
			||||||
    """Initilizes a classifier with its specifications.
 | 
					               shuffle: bool):
 | 
				
			||||||
 | 
					    """Initializes a classifier with its specifications.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Args:
 | 
					    Args:
 | 
				
			||||||
        model_spec: Specification for the model.
 | 
					        model_spec: Specification for the model.
 | 
				
			||||||
| 
						 | 
					@ -40,6 +40,59 @@ class Classifier(custom_model.CustomModel):
 | 
				
			||||||
    super(Classifier, self).__init__(model_spec, shuffle)
 | 
					    super(Classifier, self).__init__(model_spec, shuffle)
 | 
				
			||||||
    self._label_names = label_names
 | 
					    self._label_names = label_names
 | 
				
			||||||
    self._num_classes = len(label_names)
 | 
					    self._num_classes = len(label_names)
 | 
				
			||||||
 | 
					    self._model: tf.keras.Model = None
 | 
				
			||||||
 | 
					    self._optimizer: Union[str, tf.keras.optimizers.Optimizer] = None
 | 
				
			||||||
 | 
					    self._loss_function: Union[str, tf.keras.losses.Loss] = None
 | 
				
			||||||
 | 
					    self._metric_function: Union[str, tf.keras.metrics.Metric] = None
 | 
				
			||||||
 | 
					    self._callbacks: Sequence[tf.keras.callbacks.Callback] = None
 | 
				
			||||||
 | 
					    self._hparams: hp.BaseHParams = None
 | 
				
			||||||
 | 
					    self._history: tf.keras.callbacks.History = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  # TODO: Integrate this into all Model Maker tasks.
 | 
				
			||||||
 | 
					  def _train_model(self,
 | 
				
			||||||
 | 
					                   train_data: classification_ds.ClassificationDataset,
 | 
				
			||||||
 | 
					                   validation_data: classification_ds.ClassificationDataset,
 | 
				
			||||||
 | 
					                   preprocessor: Optional[Callable[..., bool]] = None):
 | 
				
			||||||
 | 
					    """Trains the classifier model.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Compiles and fits the tf.keras `_model` and records the `_history`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      train_data: Training data.
 | 
				
			||||||
 | 
					      validation_data: Validation data.
 | 
				
			||||||
 | 
					      preprocessor: An optional data preprocessor that can be used when
 | 
				
			||||||
 | 
					        generating a tf.data.Dataset.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    tf.compat.v1.logging.info('Training the models...')
 | 
				
			||||||
 | 
					    if len(train_data) < self._hparams.batch_size:
 | 
				
			||||||
 | 
					      raise ValueError(
 | 
				
			||||||
 | 
					          f'The size of the train_data {len(train_data)} can\'t be smaller than'
 | 
				
			||||||
 | 
					          f' batch_size {self._hparams.batch_size}. To solve this problem, set'
 | 
				
			||||||
 | 
					          ' the batch_size smaller or increase the size of the train_data.')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    train_dataset = train_data.gen_tf_dataset(
 | 
				
			||||||
 | 
					        batch_size=self._hparams.batch_size,
 | 
				
			||||||
 | 
					        is_training=True,
 | 
				
			||||||
 | 
					        shuffle=self._shuffle,
 | 
				
			||||||
 | 
					        preprocess=preprocessor)
 | 
				
			||||||
 | 
					    self._hparams.steps_per_epoch = model_util.get_steps_per_epoch(
 | 
				
			||||||
 | 
					        steps_per_epoch=self._hparams.steps_per_epoch,
 | 
				
			||||||
 | 
					        batch_size=self._hparams.batch_size,
 | 
				
			||||||
 | 
					        train_data=train_data)
 | 
				
			||||||
 | 
					    train_dataset = train_dataset.take(count=self._hparams.steps_per_epoch)
 | 
				
			||||||
 | 
					    validation_dataset = validation_data.gen_tf_dataset(
 | 
				
			||||||
 | 
					        batch_size=self._hparams.batch_size,
 | 
				
			||||||
 | 
					        is_training=False,
 | 
				
			||||||
 | 
					        preprocess=preprocessor)
 | 
				
			||||||
 | 
					    self._model.compile(
 | 
				
			||||||
 | 
					        optimizer=self._optimizer,
 | 
				
			||||||
 | 
					        loss=self._loss_function,
 | 
				
			||||||
 | 
					        metrics=[self._metric_function])
 | 
				
			||||||
 | 
					    self._history = self._model.fit(
 | 
				
			||||||
 | 
					        x=train_dataset,
 | 
				
			||||||
 | 
					        epochs=self._hparams.epochs,
 | 
				
			||||||
 | 
					        validation_data=validation_dataset,
 | 
				
			||||||
 | 
					        callbacks=self._callbacks)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
 | 
					  def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
 | 
				
			||||||
    """Evaluates the classifier with the provided evaluation dataset.
 | 
					    """Evaluates the classifier with the provided evaluation dataset.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -193,7 +193,7 @@ class ImageClassifier(classifier.Classifier):
 | 
				
			||||||
        tflite_model,
 | 
					        tflite_model,
 | 
				
			||||||
        self._model_spec.mean_rgb,
 | 
					        self._model_spec.mean_rgb,
 | 
				
			||||||
        self._model_spec.stddev_rgb,
 | 
					        self._model_spec.stddev_rgb,
 | 
				
			||||||
        labels=metadata_writer.Labels().add(self._label_names))
 | 
					        labels=metadata_writer.Labels().add(list(self._label_names)))
 | 
				
			||||||
    tflite_model_with_metadata, metadata_json = writer.populate()
 | 
					    tflite_model_with_metadata, metadata_json = writer.populate()
 | 
				
			||||||
    model_util.save_tflite(tflite_model_with_metadata, tflite_file)
 | 
					    model_util.save_tflite(tflite_model_with_metadata, tflite_file)
 | 
				
			||||||
    with open(metadata_file, 'w') as f:
 | 
					    with open(metadata_file, 'w') as f:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user