Refactor image classifier Create API by defining an ImageClassifierOption which wraps the model_spec, model_options, and hparams. Also migrate the definition of HParams by extending it from BaseHParams.
				
					
				
			PiperOrigin-RevId: 486796657
This commit is contained in:
		
							parent
							
								
									1049ef781d
								
							
						
					
					
						commit
						c3bb4bb5da
					
				| 
						 | 
					@ -19,7 +19,6 @@ import tempfile
 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# TODO: Integrate this class into ImageClassifier and other tasks.
 | 
					 | 
				
			||||||
@dataclasses.dataclass
 | 
					@dataclasses.dataclass
 | 
				
			||||||
class BaseHParams:
 | 
					class BaseHParams:
 | 
				
			||||||
  """Hyperparameters used for training models.
 | 
					  """Hyperparameters used for training models.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -28,6 +28,8 @@ py_library(
 | 
				
			||||||
        ":dataset",
 | 
					        ":dataset",
 | 
				
			||||||
        ":hyperparameters",
 | 
					        ":hyperparameters",
 | 
				
			||||||
        ":image_classifier",
 | 
					        ":image_classifier",
 | 
				
			||||||
 | 
					        ":image_classifier_options",
 | 
				
			||||||
 | 
					        ":model_options",
 | 
				
			||||||
        ":model_spec",
 | 
					        ":model_spec",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					@ -58,6 +60,24 @@ py_test(
 | 
				
			||||||
py_library(
 | 
					py_library(
 | 
				
			||||||
    name = "hyperparameters",
 | 
					    name = "hyperparameters",
 | 
				
			||||||
    srcs = ["hyperparameters.py"],
 | 
					    srcs = ["hyperparameters.py"],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        "//mediapipe/model_maker/python/core:hyperparameters",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					py_library(
 | 
				
			||||||
 | 
					    name = "model_options",
 | 
				
			||||||
 | 
					    srcs = ["model_options.py"],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					py_library(
 | 
				
			||||||
 | 
					    name = "image_classifier_options",
 | 
				
			||||||
 | 
					    srcs = ["image_classifier_options.py"],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":hyperparameters",
 | 
				
			||||||
 | 
					        ":model_options",
 | 
				
			||||||
 | 
					        ":model_spec",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
py_library(
 | 
					py_library(
 | 
				
			||||||
| 
						 | 
					@ -74,6 +94,8 @@ py_library(
 | 
				
			||||||
    srcs = ["image_classifier.py"],
 | 
					    srcs = ["image_classifier.py"],
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
        ":hyperparameters",
 | 
					        ":hyperparameters",
 | 
				
			||||||
 | 
					        ":image_classifier_options",
 | 
				
			||||||
 | 
					        ":model_options",
 | 
				
			||||||
        ":model_spec",
 | 
					        ":model_spec",
 | 
				
			||||||
        ":train_image_classifier_lib",
 | 
					        ":train_image_classifier_lib",
 | 
				
			||||||
        "//mediapipe/model_maker/python/core/data:classification_dataset",
 | 
					        "//mediapipe/model_maker/python/core/data:classification_dataset",
 | 
				
			||||||
| 
						 | 
					@ -99,6 +121,7 @@ py_library(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
py_test(
 | 
					py_test(
 | 
				
			||||||
    name = "image_classifier_test",
 | 
					    name = "image_classifier_test",
 | 
				
			||||||
 | 
					    size = "large",
 | 
				
			||||||
    srcs = ["image_classifier_test.py"],
 | 
					    srcs = ["image_classifier_test.py"],
 | 
				
			||||||
    shard_count = 2,
 | 
					    shard_count = 2,
 | 
				
			||||||
    tags = ["requires-net:external"],
 | 
					    tags = ["requires-net:external"],
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -16,10 +16,14 @@
 | 
				
			||||||
from mediapipe.model_maker.python.vision.image_classifier import dataset
 | 
					from mediapipe.model_maker.python.vision.image_classifier import dataset
 | 
				
			||||||
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
 | 
					from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
 | 
				
			||||||
from mediapipe.model_maker.python.vision.image_classifier import image_classifier
 | 
					from mediapipe.model_maker.python.vision.image_classifier import image_classifier
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.vision.image_classifier import image_classifier_options
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.vision.image_classifier import model_options
 | 
				
			||||||
from mediapipe.model_maker.python.vision.image_classifier import model_spec
 | 
					from mediapipe.model_maker.python.vision.image_classifier import model_spec
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ImageClassifier = image_classifier.ImageClassifier
 | 
					ImageClassifier = image_classifier.ImageClassifier
 | 
				
			||||||
HParams = hyperparameters.HParams
 | 
					HParams = hyperparameters.HParams
 | 
				
			||||||
Dataset = dataset.Dataset
 | 
					Dataset = dataset.Dataset
 | 
				
			||||||
 | 
					ModelOptions = model_options.ImageClassifierModelOptions
 | 
				
			||||||
ModelSpec = model_spec.ModelSpec
 | 
					ModelSpec = model_spec.ModelSpec
 | 
				
			||||||
SupportedModels = model_spec.SupportedModels
 | 
					SupportedModels = model_spec.SupportedModels
 | 
				
			||||||
 | 
					ImageClassifierOptions = image_classifier_options.ImageClassifierOptions
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -14,28 +14,20 @@
 | 
				
			||||||
"""Hyperparameters for training image classification models."""
 | 
					"""Hyperparameters for training image classification models."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import dataclasses
 | 
					import dataclasses
 | 
				
			||||||
import tempfile
 | 
					
 | 
				
			||||||
from typing import Optional
 | 
					from mediapipe.model_maker.python.core import hyperparameters as hp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# TODO: Expose other hyperparameters, e.g. data augmentation
 | 
					 | 
				
			||||||
# hyperparameters if requested.
 | 
					 | 
				
			||||||
@dataclasses.dataclass
 | 
					@dataclasses.dataclass
 | 
				
			||||||
class HParams:
 | 
					class HParams(hp.BaseHParams):
 | 
				
			||||||
  """The hyperparameters for training image classifiers.
 | 
					  """The hyperparameters for training image classifiers.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  The hyperparameters include:
 | 
					  Attributes:
 | 
				
			||||||
    # Parameters about training data.
 | 
					    learning_rate: Learning rate to use for gradient descent training.
 | 
				
			||||||
 | 
					    batch_size: Batch size for training.
 | 
				
			||||||
 | 
					    epochs: Number of training iterations over the dataset.
 | 
				
			||||||
    do_fine_tuning: If true, the base module is trained together with the
 | 
					    do_fine_tuning: If true, the base module is trained together with the
 | 
				
			||||||
      classification layer on top.
 | 
					      classification layer on top.
 | 
				
			||||||
    shuffle: A boolean controlling if shuffle the dataset. Default to false.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Parameters about training configuration
 | 
					 | 
				
			||||||
    train_epochs: Training will do this many iterations over the dataset.
 | 
					 | 
				
			||||||
    batch_size: Each training step samples a batch of this many images.
 | 
					 | 
				
			||||||
    learning_rate: The learning rate to use for gradient descent training.
 | 
					 | 
				
			||||||
    dropout_rate: The fraction of the input units to drop, used in dropout
 | 
					 | 
				
			||||||
      layer.
 | 
					 | 
				
			||||||
    l1_regularizer: A regularizer that applies a L1 regularization penalty.
 | 
					    l1_regularizer: A regularizer that applies a L1 regularization penalty.
 | 
				
			||||||
    l2_regularizer: A regularizer that applies a L2 regularization penalty.
 | 
					    l2_regularizer: A regularizer that applies a L2 regularization penalty.
 | 
				
			||||||
    label_smoothing: Amount of label smoothing to apply. See tf.keras.losses for
 | 
					    label_smoothing: Amount of label smoothing to apply. See tf.keras.losses for
 | 
				
			||||||
| 
						 | 
					@ -43,32 +35,21 @@ class HParams:
 | 
				
			||||||
    do_data_augmentation: A boolean controlling whether the training dataset is
 | 
					    do_data_augmentation: A boolean controlling whether the training dataset is
 | 
				
			||||||
      augmented by randomly distorting input images, including random cropping,
 | 
					      augmented by randomly distorting input images, including random cropping,
 | 
				
			||||||
      flipping, etc. See utils.image_preprocessing documentation for details.
 | 
					      flipping, etc. See utils.image_preprocessing documentation for details.
 | 
				
			||||||
    steps_per_epoch: An optional integer indicate the number of training steps
 | 
					 | 
				
			||||||
      per epoch. If not set, the training pipeline calculates the default steps
 | 
					 | 
				
			||||||
      per epoch as the training dataset size devided by batch size.
 | 
					 | 
				
			||||||
    decay_samples: Number of training samples used to calculate the decay steps
 | 
					    decay_samples: Number of training samples used to calculate the decay steps
 | 
				
			||||||
      and create the training optimizer.
 | 
					      and create the training optimizer.
 | 
				
			||||||
    warmup_steps: Number of warmup steps for a linear increasing warmup schedule
 | 
					    warmup_steps: Number of warmup steps for a linear increasing warmup schedule
 | 
				
			||||||
       on learning rate. Used to set up warmup schedule by model_util.WarmUp.
 | 
					       on learning rate. Used to set up warmup schedule by model_util.WarmUp.s
 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Parameters about the saved checkpoint
 | 
					 | 
				
			||||||
    model_dir: The location of model checkpoint files and exported model files.
 | 
					 | 
				
			||||||
  """
 | 
					  """
 | 
				
			||||||
  # Parameters about training data
 | 
					  # Parameters from BaseHParams class.
 | 
				
			||||||
  do_fine_tuning: bool = False
 | 
					  learning_rate: float = 0.001
 | 
				
			||||||
  shuffle: bool = False
 | 
					  batch_size: int = 2
 | 
				
			||||||
 | 
					  epochs: int = 10
 | 
				
			||||||
  # Parameters about training configuration
 | 
					  # Parameters about training configuration
 | 
				
			||||||
  train_epochs: int = 5
 | 
					  do_fine_tuning: bool = False
 | 
				
			||||||
  batch_size: int = 32
 | 
					 | 
				
			||||||
  learning_rate: float = 0.005
 | 
					 | 
				
			||||||
  dropout_rate: float = 0.2
 | 
					 | 
				
			||||||
  l1_regularizer: float = 0.0
 | 
					  l1_regularizer: float = 0.0
 | 
				
			||||||
  l2_regularizer: float = 0.0001
 | 
					  l2_regularizer: float = 0.0001
 | 
				
			||||||
  label_smoothing: float = 0.1
 | 
					  label_smoothing: float = 0.1
 | 
				
			||||||
  do_data_augmentation: bool = True
 | 
					  do_data_augmentation: bool = True
 | 
				
			||||||
  steps_per_epoch: Optional[int] = None
 | 
					  # TODO: Use lr_decay in hp.baseHParams to infer decay_samples.
 | 
				
			||||||
  decay_samples: int = 10000 * 256
 | 
					  decay_samples: int = 10000 * 256
 | 
				
			||||||
  warmup_epochs: int = 2
 | 
					  warmup_epochs: int = 2
 | 
				
			||||||
 | 
					 | 
				
			||||||
  # Parameters about the saved checkpoint
 | 
					 | 
				
			||||||
  model_dir: str = tempfile.mkdtemp()
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -25,6 +25,8 @@ from mediapipe.model_maker.python.core.utils import model_util
 | 
				
			||||||
from mediapipe.model_maker.python.core.utils import quantization
 | 
					from mediapipe.model_maker.python.core.utils import quantization
 | 
				
			||||||
from mediapipe.model_maker.python.vision.core import image_preprocessing
 | 
					from mediapipe.model_maker.python.vision.core import image_preprocessing
 | 
				
			||||||
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
 | 
					from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.vision.image_classifier import image_classifier_options
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.vision.image_classifier import model_options as model_opt
 | 
				
			||||||
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
 | 
					from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
 | 
				
			||||||
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
 | 
					from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
 | 
				
			||||||
from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer
 | 
					from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer
 | 
				
			||||||
| 
						 | 
					@ -35,17 +37,20 @@ class ImageClassifier(classifier.Classifier):
 | 
				
			||||||
  """ImageClassifier for building image classification model."""
 | 
					  """ImageClassifier for building image classification model."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def __init__(self, model_spec: ms.ModelSpec, label_names: List[str],
 | 
					  def __init__(self, model_spec: ms.ModelSpec, label_names: List[str],
 | 
				
			||||||
               hparams: hp.HParams):
 | 
					               hparams: hp.HParams,
 | 
				
			||||||
 | 
					               model_options: model_opt.ImageClassifierModelOptions):
 | 
				
			||||||
    """Initializes ImageClassifier class.
 | 
					    """Initializes ImageClassifier class.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Args:
 | 
					    Args:
 | 
				
			||||||
      model_spec: Specification for the model.
 | 
					      model_spec: Specification for the model.
 | 
				
			||||||
      label_names: A list of label names for the classes.
 | 
					      label_names: A list of label names for the classes.
 | 
				
			||||||
      hparams: The hyperparameters for training image classifier.
 | 
					      hparams: The hyperparameters for training image classifier.
 | 
				
			||||||
 | 
					      model_options: Model options for creating image classifier.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    super().__init__(
 | 
					    super().__init__(
 | 
				
			||||||
        model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
 | 
					        model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
 | 
				
			||||||
    self._hparams = hparams
 | 
					    self._hparams = hparams
 | 
				
			||||||
 | 
					    self._model_options = model_options
 | 
				
			||||||
    self._preprocess = image_preprocessing.Preprocessor(
 | 
					    self._preprocess = image_preprocessing.Preprocessor(
 | 
				
			||||||
        input_shape=self._model_spec.input_image_shape,
 | 
					        input_shape=self._model_spec.input_image_shape,
 | 
				
			||||||
        num_classes=self._num_classes,
 | 
					        num_classes=self._num_classes,
 | 
				
			||||||
| 
						 | 
					@ -57,30 +62,34 @@ class ImageClassifier(classifier.Classifier):
 | 
				
			||||||
  @classmethod
 | 
					  @classmethod
 | 
				
			||||||
  def create(
 | 
					  def create(
 | 
				
			||||||
      cls,
 | 
					      cls,
 | 
				
			||||||
      model_spec: ms.SupportedModels,
 | 
					 | 
				
			||||||
      train_data: classification_ds.ClassificationDataset,
 | 
					      train_data: classification_ds.ClassificationDataset,
 | 
				
			||||||
      validation_data: classification_ds.ClassificationDataset,
 | 
					      validation_data: classification_ds.ClassificationDataset,
 | 
				
			||||||
      hparams: Optional[hp.HParams] = None,
 | 
					      options: image_classifier_options.ImageClassifierOptions,
 | 
				
			||||||
  ) -> 'ImageClassifier':
 | 
					  ) -> 'ImageClassifier':
 | 
				
			||||||
    """Creates and trains an image classifier.
 | 
					    """Creates and trains an image classifier.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Loads data and trains the model based on data for image classification.
 | 
					    Loads data and trains the model based on data for image classification.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Args:
 | 
					    Args:
 | 
				
			||||||
      model_spec: Specification for the model.
 | 
					 | 
				
			||||||
      train_data: Training data.
 | 
					      train_data: Training data.
 | 
				
			||||||
      validation_data: Validation data.
 | 
					      validation_data: Validation data.
 | 
				
			||||||
      hparams: Hyperparameters for training image classifier.
 | 
					      options: configuration to create image classifier.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Returns:
 | 
					    Returns:
 | 
				
			||||||
      An instance based on ImageClassifier.
 | 
					      An instance based on ImageClassifier.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    if hparams is None:
 | 
					    if options.hparams is None:
 | 
				
			||||||
      hparams = hp.HParams()
 | 
					      options.hparams = hp.HParams()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    spec = ms.SupportedModels.get(model_spec)
 | 
					    if options.model_options is None:
 | 
				
			||||||
 | 
					      options.model_options = model_opt.ImageClassifierModelOptions()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    spec = ms.SupportedModels.get(options.supported_model)
 | 
				
			||||||
    image_classifier = cls(
 | 
					    image_classifier = cls(
 | 
				
			||||||
        model_spec=spec, label_names=train_data.label_names, hparams=hparams)
 | 
					        model_spec=spec,
 | 
				
			||||||
 | 
					        label_names=train_data.label_names,
 | 
				
			||||||
 | 
					        hparams=options.hparams,
 | 
				
			||||||
 | 
					        model_options=options.model_options)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    image_classifier._create_model()
 | 
					    image_classifier._create_model()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -90,6 +99,7 @@ class ImageClassifier(classifier.Classifier):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return image_classifier
 | 
					    return image_classifier
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  # TODO: Migrate to the shared training library of Model Maker.
 | 
				
			||||||
  def _train(self, train_data: classification_ds.ClassificationDataset,
 | 
					  def _train(self, train_data: classification_ds.ClassificationDataset,
 | 
				
			||||||
             validation_data: classification_ds.ClassificationDataset):
 | 
					             validation_data: classification_ds.ClassificationDataset):
 | 
				
			||||||
    """Trains the model with input train_data.
 | 
					    """Trains the model with input train_data.
 | 
				
			||||||
| 
						 | 
					@ -142,7 +152,7 @@ class ImageClassifier(classifier.Classifier):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    self._model = tf.keras.Sequential([
 | 
					    self._model = tf.keras.Sequential([
 | 
				
			||||||
        tf.keras.Input(shape=(image_size[0], image_size[1], 3)), module_layer,
 | 
					        tf.keras.Input(shape=(image_size[0], image_size[1], 3)), module_layer,
 | 
				
			||||||
        tf.keras.layers.Dropout(rate=self._hparams.dropout_rate),
 | 
					        tf.keras.layers.Dropout(rate=self._model_options.dropout_rate),
 | 
				
			||||||
        tf.keras.layers.Dense(
 | 
					        tf.keras.layers.Dense(
 | 
				
			||||||
            units=self._num_classes,
 | 
					            units=self._num_classes,
 | 
				
			||||||
            activation='softmax',
 | 
					            activation='softmax',
 | 
				
			||||||
| 
						 | 
					@ -167,10 +177,10 @@ class ImageClassifier(classifier.Classifier):
 | 
				
			||||||
        path is {self._hparams.model_dir}/{model_name}.
 | 
					        path is {self._hparams.model_dir}/{model_name}.
 | 
				
			||||||
      quantization_config: The configuration for model quantization.
 | 
					      quantization_config: The configuration for model quantization.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    if not tf.io.gfile.exists(self._hparams.model_dir):
 | 
					    if not tf.io.gfile.exists(self._hparams.export_dir):
 | 
				
			||||||
      tf.io.gfile.makedirs(self._hparams.model_dir)
 | 
					      tf.io.gfile.makedirs(self._hparams.export_dir)
 | 
				
			||||||
    tflite_file = os.path.join(self._hparams.model_dir, model_name)
 | 
					    tflite_file = os.path.join(self._hparams.export_dir, model_name)
 | 
				
			||||||
    metadata_file = os.path.join(self._hparams.model_dir, 'metadata.json')
 | 
					    metadata_file = os.path.join(self._hparams.export_dir, 'metadata.json')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    tflite_model = model_util.convert_to_tflite(
 | 
					    tflite_model = model_util.convert_to_tflite(
 | 
				
			||||||
        model=self._model,
 | 
					        model=self._model,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,35 @@
 | 
				
			||||||
 | 
					# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					"""Options for building image classifier."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import dataclasses
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.vision.image_classifier import model_options as model_opt
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.vision.image_classifier import model_spec
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dataclasses.dataclass
 | 
				
			||||||
 | 
					class ImageClassifierOptions:
 | 
				
			||||||
 | 
					  """Configurable options for building image classifier.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Attributes:
 | 
				
			||||||
 | 
					    supported_model: A model from the SupportedModels enum.
 | 
				
			||||||
 | 
					    model_options: A set of options for configuring the selected model.
 | 
				
			||||||
 | 
					    hparams: A set of hyperparameters used to train the image classifier.
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					  supported_model: model_spec.SupportedModels
 | 
				
			||||||
 | 
					  model_options: Optional[model_opt.ImageClassifierModelOptions] = None
 | 
				
			||||||
 | 
					  hparams: Optional[hyperparameters.HParams] = None
 | 
				
			||||||
| 
						 | 
					@ -15,6 +15,7 @@
 | 
				
			||||||
import filecmp
 | 
					import filecmp
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from unittest import mock
 | 
				
			||||||
from absl.testing import parameterized
 | 
					from absl.testing import parameterized
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
import tensorflow as tf
 | 
					import tensorflow as tf
 | 
				
			||||||
| 
						 | 
					@ -54,54 +55,59 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
 | 
				
			||||||
    super(ImageClassifierTest, self).setUp()
 | 
					    super(ImageClassifierTest, self).setUp()
 | 
				
			||||||
    all_data = self._gen_cmy_data()
 | 
					    all_data = self._gen_cmy_data()
 | 
				
			||||||
    # Splits data, 90% data for training, 10% for testing
 | 
					    # Splits data, 90% data for training, 10% for testing
 | 
				
			||||||
    self.train_data, self.test_data = all_data.split(0.9)
 | 
					    self._train_data, self._test_data = all_data.split(0.9)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @parameterized.named_parameters(
 | 
					  @parameterized.named_parameters(
 | 
				
			||||||
      dict(
 | 
					      dict(
 | 
				
			||||||
          testcase_name='mobilenet_v2',
 | 
					          testcase_name='mobilenet_v2',
 | 
				
			||||||
          model_spec=image_classifier.SupportedModels.MOBILENET_V2,
 | 
					          options=image_classifier.ImageClassifierOptions(
 | 
				
			||||||
 | 
					              supported_model=image_classifier.SupportedModels.MOBILENET_V2,
 | 
				
			||||||
              hparams=image_classifier.HParams(
 | 
					              hparams=image_classifier.HParams(
 | 
				
			||||||
              train_epochs=1, batch_size=1, shuffle=True)),
 | 
					                  epochs=1, batch_size=1, shuffle=True))),
 | 
				
			||||||
      dict(
 | 
					      dict(
 | 
				
			||||||
          testcase_name='efficientnet_lite0',
 | 
					          testcase_name='efficientnet_lite0',
 | 
				
			||||||
          model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
 | 
					          options=image_classifier.ImageClassifierOptions(
 | 
				
			||||||
 | 
					              supported_model=(
 | 
				
			||||||
 | 
					                  image_classifier.SupportedModels.EFFICIENTNET_LITE0),
 | 
				
			||||||
              hparams=image_classifier.HParams(
 | 
					              hparams=image_classifier.HParams(
 | 
				
			||||||
              train_epochs=1, batch_size=1, shuffle=True)),
 | 
					                  epochs=1, batch_size=1, shuffle=True))),
 | 
				
			||||||
 | 
					      dict(
 | 
				
			||||||
 | 
					          testcase_name='efficientnet_lite0_change_dropout_rate',
 | 
				
			||||||
 | 
					          options=image_classifier.ImageClassifierOptions(
 | 
				
			||||||
 | 
					              supported_model=(
 | 
				
			||||||
 | 
					                  image_classifier.SupportedModels.EFFICIENTNET_LITE0),
 | 
				
			||||||
 | 
					              model_options=image_classifier.ModelOptions(dropout_rate=0.1),
 | 
				
			||||||
 | 
					              hparams=image_classifier.HParams(
 | 
				
			||||||
 | 
					                  epochs=1, batch_size=1, shuffle=True))),
 | 
				
			||||||
      dict(
 | 
					      dict(
 | 
				
			||||||
          testcase_name='efficientnet_lite2',
 | 
					          testcase_name='efficientnet_lite2',
 | 
				
			||||||
          model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE2,
 | 
					          options=image_classifier.ImageClassifierOptions(
 | 
				
			||||||
 | 
					              supported_model=(
 | 
				
			||||||
 | 
					                  image_classifier.SupportedModels.EFFICIENTNET_LITE2),
 | 
				
			||||||
              hparams=image_classifier.HParams(
 | 
					              hparams=image_classifier.HParams(
 | 
				
			||||||
              train_epochs=1, batch_size=1, shuffle=True)),
 | 
					                  epochs=1, batch_size=1, shuffle=True))),
 | 
				
			||||||
      dict(
 | 
					      dict(
 | 
				
			||||||
          testcase_name='efficientnet_lite4',
 | 
					          testcase_name='efficientnet_lite4',
 | 
				
			||||||
          model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE4,
 | 
					          options=image_classifier.ImageClassifierOptions(
 | 
				
			||||||
 | 
					              supported_model=(
 | 
				
			||||||
 | 
					                  image_classifier.SupportedModels.EFFICIENTNET_LITE4),
 | 
				
			||||||
              hparams=image_classifier.HParams(
 | 
					              hparams=image_classifier.HParams(
 | 
				
			||||||
              train_epochs=1, batch_size=1, shuffle=True)),
 | 
					                  epochs=1, batch_size=1, shuffle=True))),
 | 
				
			||||||
  )
 | 
					  )
 | 
				
			||||||
  def test_create_and_train_model(self,
 | 
					  def test_create_and_train_model(
 | 
				
			||||||
                                  model_spec: image_classifier.SupportedModels,
 | 
					      self, options: image_classifier.ImageClassifierOptions):
 | 
				
			||||||
                                  hparams: image_classifier.HParams):
 | 
					 | 
				
			||||||
    model = image_classifier.ImageClassifier.create(
 | 
					    model = image_classifier.ImageClassifier.create(
 | 
				
			||||||
        model_spec=model_spec,
 | 
					        train_data=self._train_data,
 | 
				
			||||||
        train_data=self.train_data,
 | 
					        validation_data=self._test_data,
 | 
				
			||||||
        hparams=hparams,
 | 
					        options=options)
 | 
				
			||||||
        validation_data=self.test_data)
 | 
					 | 
				
			||||||
    self._test_accuracy(model)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  def test_efficientnetlite0_model_train_and_export(self):
 | 
					 | 
				
			||||||
    hparams = image_classifier.HParams(
 | 
					 | 
				
			||||||
        train_epochs=1, batch_size=1, shuffle=True)
 | 
					 | 
				
			||||||
    model = image_classifier.ImageClassifier.create(
 | 
					 | 
				
			||||||
        model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
 | 
					 | 
				
			||||||
        train_data=self.train_data,
 | 
					 | 
				
			||||||
        hparams=hparams,
 | 
					 | 
				
			||||||
        validation_data=self.test_data)
 | 
					 | 
				
			||||||
    self._test_accuracy(model)
 | 
					    self._test_accuracy(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Test export_model
 | 
					    # Test export_model
 | 
				
			||||||
    model.export_model()
 | 
					    model.export_model()
 | 
				
			||||||
    output_metadata_file = os.path.join(hparams.model_dir, 'metadata.json')
 | 
					    output_metadata_file = os.path.join(options.hparams.export_dir,
 | 
				
			||||||
    output_tflite_file = os.path.join(hparams.model_dir, 'model.tflite')
 | 
					                                        'metadata.json')
 | 
				
			||||||
 | 
					    output_tflite_file = os.path.join(options.hparams.export_dir,
 | 
				
			||||||
 | 
					                                      'model.tflite')
 | 
				
			||||||
    expected_metadata_file = test_utils.get_test_data_path('metadata.json')
 | 
					    expected_metadata_file = test_utils.get_test_data_path('metadata.json')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    self.assertTrue(os.path.exists(output_tflite_file))
 | 
					    self.assertTrue(os.path.exists(output_tflite_file))
 | 
				
			||||||
| 
						 | 
					@ -112,9 +118,30 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
 | 
				
			||||||
    self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
 | 
					    self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def _test_accuracy(self, model, threshold=0.0):
 | 
					  def _test_accuracy(self, model, threshold=0.0):
 | 
				
			||||||
    _, accuracy = model.evaluate(self.test_data)
 | 
					    _, accuracy = model.evaluate(self._test_data)
 | 
				
			||||||
    self.assertGreaterEqual(accuracy, threshold)
 | 
					    self.assertGreaterEqual(accuracy, threshold)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @mock.patch.object(
 | 
				
			||||||
 | 
					      image_classifier.hyperparameters,
 | 
				
			||||||
 | 
					      'HParams',
 | 
				
			||||||
 | 
					      autospec=True,
 | 
				
			||||||
 | 
					      return_value=image_classifier.HParams(epochs=1))
 | 
				
			||||||
 | 
					  @mock.patch.object(
 | 
				
			||||||
 | 
					      image_classifier.model_options,
 | 
				
			||||||
 | 
					      'ImageClassifierModelOptions',
 | 
				
			||||||
 | 
					      autospec=True,
 | 
				
			||||||
 | 
					      return_value=image_classifier.ModelOptions())
 | 
				
			||||||
 | 
					  def test_create_hparams_and_model_options_if_none_in_image_classifier_options(
 | 
				
			||||||
 | 
					      self, mock_hparams, mock_model_options):
 | 
				
			||||||
 | 
					    options = image_classifier.ImageClassifierOptions(
 | 
				
			||||||
 | 
					        supported_model=(image_classifier.SupportedModels.EFFICIENTNET_LITE0))
 | 
				
			||||||
 | 
					    image_classifier.ImageClassifier.create(
 | 
				
			||||||
 | 
					        train_data=self._train_data,
 | 
				
			||||||
 | 
					        validation_data=self._test_data,
 | 
				
			||||||
 | 
					        options=options)
 | 
				
			||||||
 | 
					    mock_hparams.assert_called_once()
 | 
				
			||||||
 | 
					    mock_model_options.assert_called_once()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == '__main__':
 | 
					if __name__ == '__main__':
 | 
				
			||||||
  # Load compressed models from tensorflow_hub
 | 
					  # Load compressed models from tensorflow_hub
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,27 @@
 | 
				
			||||||
 | 
					# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					"""Configurable model options for image classifier models."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import dataclasses
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dataclasses.dataclass
 | 
				
			||||||
 | 
					class ImageClassifierModelOptions:
 | 
				
			||||||
 | 
					  """Configurable options for image classifier model.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Attributes:
 | 
				
			||||||
 | 
					    dropout_rate: The fraction of the input units to drop, used in dropout
 | 
				
			||||||
 | 
					      layer.
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					  dropout_rate: float = 0.2
 | 
				
			||||||
| 
						 | 
					@ -49,13 +49,14 @@ def _create_optimizer(init_lr: float, decay_steps: int,
 | 
				
			||||||
  return optimizer
 | 
					  return optimizer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _get_default_callbacks(model_dir: str) -> List[tf.keras.callbacks.Callback]:
 | 
					def _get_default_callbacks(
 | 
				
			||||||
 | 
					    export_dir: str) -> List[tf.keras.callbacks.Callback]:
 | 
				
			||||||
  """Gets default callbacks."""
 | 
					  """Gets default callbacks."""
 | 
				
			||||||
  summary_dir = os.path.join(model_dir, 'summaries')
 | 
					  summary_dir = os.path.join(export_dir, 'summaries')
 | 
				
			||||||
  summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
 | 
					  summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
 | 
				
			||||||
  # Save checkpoint every 20 epochs.
 | 
					  # Save checkpoint every 20 epochs.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  checkpoint_path = os.path.join(model_dir, 'checkpoint')
 | 
					  checkpoint_path = os.path.join(export_dir, 'checkpoint')
 | 
				
			||||||
  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
 | 
					  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
 | 
				
			||||||
      checkpoint_path, save_weights_only=True, period=20)
 | 
					      checkpoint_path, save_weights_only=True, period=20)
 | 
				
			||||||
  return [summary_callback, checkpoint_callback]
 | 
					  return [summary_callback, checkpoint_callback]
 | 
				
			||||||
| 
						 | 
					@ -81,7 +82,8 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
 | 
				
			||||||
  learning_rate = hparams.learning_rate * hparams.batch_size / 256
 | 
					  learning_rate = hparams.learning_rate * hparams.batch_size / 256
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  # Get decay steps.
 | 
					  # Get decay steps.
 | 
				
			||||||
  total_training_steps = hparams.steps_per_epoch * hparams.train_epochs
 | 
					  # NOMUTANTS--(b/256493858):Plan to test it in the unified training library.
 | 
				
			||||||
 | 
					  total_training_steps = hparams.steps_per_epoch * hparams.epochs
 | 
				
			||||||
  default_decay_steps = hparams.decay_samples // hparams.batch_size
 | 
					  default_decay_steps = hparams.decay_samples // hparams.batch_size
 | 
				
			||||||
  decay_steps = max(total_training_steps, default_decay_steps)
 | 
					  decay_steps = max(total_training_steps, default_decay_steps)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -92,11 +94,11 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
 | 
				
			||||||
  loss = tf.keras.losses.CategoricalCrossentropy(
 | 
					  loss = tf.keras.losses.CategoricalCrossentropy(
 | 
				
			||||||
      label_smoothing=hparams.label_smoothing)
 | 
					      label_smoothing=hparams.label_smoothing)
 | 
				
			||||||
  model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
 | 
					  model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
 | 
				
			||||||
  callbacks = _get_default_callbacks(hparams.model_dir)
 | 
					  callbacks = _get_default_callbacks(export_dir=hparams.export_dir)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  # Train the model.
 | 
					  # Train the model.
 | 
				
			||||||
  return model.fit(
 | 
					  return model.fit(
 | 
				
			||||||
      x=train_ds,
 | 
					      x=train_ds,
 | 
				
			||||||
      epochs=hparams.train_epochs,
 | 
					      epochs=hparams.epochs,
 | 
				
			||||||
      validation_data=validation_ds,
 | 
					      validation_data=validation_ds,
 | 
				
			||||||
      callbacks=callbacks)
 | 
					      callbacks=callbacks)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user