Refactor the image classifier Create API by defining an ImageClassifierOptions dataclass that wraps the supported model, model options, and hparams. Also migrate HParams to extend BaseHParams.

PiperOrigin-RevId: 486796657
MediaPipe Team 2022-11-07 16:37:40 -08:00 committed by Copybara-Service
parent 1049ef781d
commit c3bb4bb5da
9 changed files with 195 additions and 87 deletions
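
For orientation, here is a minimal sketch of how the refactored Create API is meant to be called after this change, pieced together from the new ImageClassifierOptions dataclass and the updated ImageClassifier.create signature in the diff below. The Dataset.from_folder call and the data path are illustrative assumptions, not part of this commit.

from mediapipe.model_maker.python.vision import image_classifier

# Illustrative data loading step: any pair of classification datasets works here.
data = image_classifier.Dataset.from_folder('path/to/labeled_images')
train_data, validation_data = data.split(0.9)

# The options object now bundles the model choice, model options, and hparams.
options = image_classifier.ImageClassifierOptions(
    supported_model=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
    model_options=image_classifier.ModelOptions(dropout_rate=0.1),
    hparams=image_classifier.HParams(epochs=5, batch_size=32))

# create() takes the options object instead of separate model_spec/hparams arguments.
model = image_classifier.ImageClassifier.create(
    train_data=train_data,
    validation_data=validation_data,
    options=options)
model.export_model()

Any field left unset on the options object (model_options, hparams) is filled with defaults inside create(), as the new unit test in this commit verifies.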

View File

@@ -19,7 +19,6 @@ import tempfile
 from typing import Optional
 
 
-# TODO: Integrate this class into ImageClassifier and other tasks.
 @dataclasses.dataclass
 class BaseHParams:
   """Hyperparameters used for training models.

View File

@@ -28,6 +28,8 @@ py_library(
         ":dataset",
         ":hyperparameters",
         ":image_classifier",
+        ":image_classifier_options",
+        ":model_options",
         ":model_spec",
     ],
 )
@@ -58,6 +60,24 @@ py_test(
 py_library(
     name = "hyperparameters",
     srcs = ["hyperparameters.py"],
+    deps = [
+        "//mediapipe/model_maker/python/core:hyperparameters",
+    ],
+)
+
+py_library(
+    name = "model_options",
+    srcs = ["model_options.py"],
+)
+
+py_library(
+    name = "image_classifier_options",
+    srcs = ["image_classifier_options.py"],
+    deps = [
+        ":hyperparameters",
+        ":model_options",
+        ":model_spec",
+    ],
 )
 
 py_library(
@@ -74,6 +94,8 @@ py_library(
     srcs = ["image_classifier.py"],
     deps = [
         ":hyperparameters",
+        ":image_classifier_options",
+        ":model_options",
         ":model_spec",
         ":train_image_classifier_lib",
         "//mediapipe/model_maker/python/core/data:classification_dataset",
@@ -99,6 +121,7 @@ py_library(
 py_test(
     name = "image_classifier_test",
+    size = "large",
     srcs = ["image_classifier_test.py"],
     shard_count = 2,
     tags = ["requires-net:external"],

View File

@@ -16,10 +16,14 @@
 from mediapipe.model_maker.python.vision.image_classifier import dataset
 from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
 from mediapipe.model_maker.python.vision.image_classifier import image_classifier
+from mediapipe.model_maker.python.vision.image_classifier import image_classifier_options
+from mediapipe.model_maker.python.vision.image_classifier import model_options
 from mediapipe.model_maker.python.vision.image_classifier import model_spec
 
 ImageClassifier = image_classifier.ImageClassifier
 HParams = hyperparameters.HParams
 Dataset = dataset.Dataset
+ModelOptions = model_options.ImageClassifierModelOptions
 ModelSpec = model_spec.ModelSpec
 SupportedModels = model_spec.SupportedModels
+ImageClassifierOptions = image_classifier_options.ImageClassifierOptions

View File

@@ -14,28 +14,20 @@
 """Hyperparameters for training image classification models."""
 
 import dataclasses
-import tempfile
-from typing import Optional
+
+from mediapipe.model_maker.python.core import hyperparameters as hp
 
 
-# TODO: Expose other hyperparameters, e.g. data augmentation
-# hyperparameters if requested.
 @dataclasses.dataclass
-class HParams:
+class HParams(hp.BaseHParams):
   """The hyperparameters for training image classifiers.
 
-  The hyperparameters include:
-    # Parameters about training data.
+  Attributes:
+    learning_rate: Learning rate to use for gradient descent training.
+    batch_size: Batch size for training.
+    epochs: Number of training iterations over the dataset.
     do_fine_tuning: If true, the base module is trained together with the
       classification layer on top.
-    shuffle: A boolean controlling if shuffle the dataset. Default to false.
-
-    # Parameters about training configuration
-    train_epochs: Training will do this many iterations over the dataset.
-    batch_size: Each training step samples a batch of this many images.
-    learning_rate: The learning rate to use for gradient descent training.
-    dropout_rate: The fraction of the input units to drop, used in dropout
-      layer.
     l1_regularizer: A regularizer that applies a L1 regularization penalty.
     l2_regularizer: A regularizer that applies a L2 regularization penalty.
     label_smoothing: Amount of label smoothing to apply. See tf.keras.losses for
@@ -43,32 +35,21 @@ class HParams:
     do_data_augmentation: A boolean controlling whether the training dataset is
       augmented by randomly distorting input images, including random cropping,
      flipping, etc. See utils.image_preprocessing documentation for details.
-    steps_per_epoch: An optional integer indicate the number of training steps
-      per epoch. If not set, the training pipeline calculates the default steps
-      per epoch as the training dataset size devided by batch size.
     decay_samples: Number of training samples used to calculate the decay steps
       and create the training optimizer.
     warmup_steps: Number of warmup steps for a linear increasing warmup schedule
-      on learning rate. Used to set up warmup schedule by model_util.WarmUp.
-
-    # Parameters about the saved checkpoint
-    model_dir: The location of model checkpoint files and exported model files.
+      on learning rate. Used to set up warmup schedule by model_util.WarmUp.s
   """
-  # Parameters about training data
-  do_fine_tuning: bool = False
-  shuffle: bool = False
+  # Parameters from BaseHParams class.
+  learning_rate: float = 0.001
+  batch_size: int = 2
+  epochs: int = 10
 
   # Parameters about training configuration
-  train_epochs: int = 5
-  batch_size: int = 32
-  learning_rate: float = 0.005
-  dropout_rate: float = 0.2
+  do_fine_tuning: bool = False
   l1_regularizer: float = 0.0
   l2_regularizer: float = 0.0001
   label_smoothing: float = 0.1
   do_data_augmentation: bool = True
-  steps_per_epoch: Optional[int] = None
+  # TODO: Use lr_decay in hp.baseHParams to infer decay_samples.
   decay_samples: int = 10000 * 256
   warmup_epochs: int = 2
-
-  # Parameters about the saved checkpoint
-  model_dir: str = tempfile.mkdtemp()

View File

@@ -25,6 +25,8 @@ from mediapipe.model_maker.python.core.utils import model_util
 from mediapipe.model_maker.python.core.utils import quantization
 from mediapipe.model_maker.python.vision.core import image_preprocessing
 from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
+from mediapipe.model_maker.python.vision.image_classifier import image_classifier_options
+from mediapipe.model_maker.python.vision.image_classifier import model_options as model_opt
 from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
 from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
 from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer
@@ -35,17 +37,20 @@ class ImageClassifier(classifier.Classifier):
   """ImageClassifier for building image classification model."""
 
   def __init__(self, model_spec: ms.ModelSpec, label_names: List[str],
-               hparams: hp.HParams):
+               hparams: hp.HParams,
+               model_options: model_opt.ImageClassifierModelOptions):
     """Initializes ImageClassifier class.
 
     Args:
       model_spec: Specification for the model.
       label_names: A list of label names for the classes.
       hparams: The hyperparameters for training image classifier.
+      model_options: Model options for creating image classifier.
     """
     super().__init__(
         model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
     self._hparams = hparams
+    self._model_options = model_options
     self._preprocess = image_preprocessing.Preprocessor(
         input_shape=self._model_spec.input_image_shape,
         num_classes=self._num_classes,
@@ -57,30 +62,34 @@ class ImageClassifier(classifier.Classifier):
   @classmethod
   def create(
       cls,
-      model_spec: ms.SupportedModels,
       train_data: classification_ds.ClassificationDataset,
       validation_data: classification_ds.ClassificationDataset,
-      hparams: Optional[hp.HParams] = None,
+      options: image_classifier_options.ImageClassifierOptions,
   ) -> 'ImageClassifier':
     """Creates and trains an image classifier.
 
     Loads data and trains the model based on data for image classification.
 
     Args:
-      model_spec: Specification for the model.
       train_data: Training data.
       validation_data: Validation data.
-      hparams: Hyperparameters for training image classifier.
+      options: configuration to create image classifier.
 
     Returns:
       An instance based on ImageClassifier.
     """
-    if hparams is None:
-      hparams = hp.HParams()
+    if options.hparams is None:
+      options.hparams = hp.HParams()
+    if options.model_options is None:
+      options.model_options = model_opt.ImageClassifierModelOptions()
 
-    spec = ms.SupportedModels.get(model_spec)
+    spec = ms.SupportedModels.get(options.supported_model)
     image_classifier = cls(
-        model_spec=spec, label_names=train_data.label_names, hparams=hparams)
+        model_spec=spec,
+        label_names=train_data.label_names,
+        hparams=options.hparams,
+        model_options=options.model_options)
 
     image_classifier._create_model()
@@ -90,6 +99,7 @@ class ImageClassifier(classifier.Classifier):
 
     return image_classifier
 
+  # TODO: Migrate to the shared training library of Model Maker.
   def _train(self, train_data: classification_ds.ClassificationDataset,
              validation_data: classification_ds.ClassificationDataset):
     """Trains the model with input train_data.
@@ -142,7 +152,7 @@ class ImageClassifier(classifier.Classifier):
     self._model = tf.keras.Sequential([
         tf.keras.Input(shape=(image_size[0], image_size[1], 3)), module_layer,
-        tf.keras.layers.Dropout(rate=self._hparams.dropout_rate),
+        tf.keras.layers.Dropout(rate=self._model_options.dropout_rate),
         tf.keras.layers.Dense(
             units=self._num_classes,
             activation='softmax',
@@ -167,10 +177,10 @@ class ImageClassifier(classifier.Classifier):
         path is {self._hparams.model_dir}/{model_name}.
       quantization_config: The configuration for model quantization.
     """
-    if not tf.io.gfile.exists(self._hparams.model_dir):
-      tf.io.gfile.makedirs(self._hparams.model_dir)
-    tflite_file = os.path.join(self._hparams.model_dir, model_name)
-    metadata_file = os.path.join(self._hparams.model_dir, 'metadata.json')
+    if not tf.io.gfile.exists(self._hparams.export_dir):
+      tf.io.gfile.makedirs(self._hparams.export_dir)
+    tflite_file = os.path.join(self._hparams.export_dir, model_name)
+    metadata_file = os.path.join(self._hparams.export_dir, 'metadata.json')
     tflite_model = model_util.convert_to_tflite(
         model=self._model,

View File

@@ -0,0 +1,35 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Options for building image classifier."""
+
+import dataclasses
+from typing import Optional
+
+from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
+from mediapipe.model_maker.python.vision.image_classifier import model_options as model_opt
+from mediapipe.model_maker.python.vision.image_classifier import model_spec
+
+
+@dataclasses.dataclass
+class ImageClassifierOptions:
+  """Configurable options for building image classifier.
+
+  Attributes:
+    supported_model: A model from the SupportedModels enum.
+    model_options: A set of options for configuring the selected model.
+    hparams: A set of hyperparameters used to train the image classifier.
+  """
+  supported_model: model_spec.SupportedModels
+  model_options: Optional[model_opt.ImageClassifierModelOptions] = None
+  hparams: Optional[hyperparameters.HParams] = None

View File

@@ -15,6 +15,7 @@
 import filecmp
 import os
+from unittest import mock
 
 from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
@@ -54,54 +55,59 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
     super(ImageClassifierTest, self).setUp()
     all_data = self._gen_cmy_data()
     # Splits data, 90% data for training, 10% for testing
-    self.train_data, self.test_data = all_data.split(0.9)
+    self._train_data, self._test_data = all_data.split(0.9)
 
   @parameterized.named_parameters(
       dict(
          testcase_name='mobilenet_v2',
-          model_spec=image_classifier.SupportedModels.MOBILENET_V2,
-          hparams=image_classifier.HParams(
-              train_epochs=1, batch_size=1, shuffle=True)),
+          options=image_classifier.ImageClassifierOptions(
+              supported_model=image_classifier.SupportedModels.MOBILENET_V2,
+              hparams=image_classifier.HParams(
+                  epochs=1, batch_size=1, shuffle=True))),
      dict(
          testcase_name='efficientnet_lite0',
-          model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
-          hparams=image_classifier.HParams(
-              train_epochs=1, batch_size=1, shuffle=True)),
+          options=image_classifier.ImageClassifierOptions(
+              supported_model=(
+                  image_classifier.SupportedModels.EFFICIENTNET_LITE0),
+              hparams=image_classifier.HParams(
+                  epochs=1, batch_size=1, shuffle=True))),
+      dict(
+          testcase_name='efficientnet_lite0_change_dropout_rate',
+          options=image_classifier.ImageClassifierOptions(
+              supported_model=(
+                  image_classifier.SupportedModels.EFFICIENTNET_LITE0),
+              model_options=image_classifier.ModelOptions(dropout_rate=0.1),
+              hparams=image_classifier.HParams(
+                  epochs=1, batch_size=1, shuffle=True))),
      dict(
          testcase_name='efficientnet_lite2',
-          model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE2,
-          hparams=image_classifier.HParams(
-              train_epochs=1, batch_size=1, shuffle=True)),
+          options=image_classifier.ImageClassifierOptions(
+              supported_model=(
+                  image_classifier.SupportedModels.EFFICIENTNET_LITE2),
+              hparams=image_classifier.HParams(
+                  epochs=1, batch_size=1, shuffle=True))),
      dict(
          testcase_name='efficientnet_lite4',
-          model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE4,
-          hparams=image_classifier.HParams(
-              train_epochs=1, batch_size=1, shuffle=True)),
+          options=image_classifier.ImageClassifierOptions(
+              supported_model=(
+                  image_classifier.SupportedModels.EFFICIENTNET_LITE4),
+              hparams=image_classifier.HParams(
+                  epochs=1, batch_size=1, shuffle=True))),
  )
-  def test_create_and_train_model(self,
-                                  model_spec: image_classifier.SupportedModels,
-                                  hparams: image_classifier.HParams):
+  def test_create_and_train_model(
+      self, options: image_classifier.ImageClassifierOptions):
     model = image_classifier.ImageClassifier.create(
-        model_spec=model_spec,
-        train_data=self.train_data,
-        hparams=hparams,
-        validation_data=self.test_data)
-    self._test_accuracy(model)
-
-  def test_efficientnetlite0_model_train_and_export(self):
-    hparams = image_classifier.HParams(
-        train_epochs=1, batch_size=1, shuffle=True)
-    model = image_classifier.ImageClassifier.create(
-        model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
-        train_data=self.train_data,
-        hparams=hparams,
-        validation_data=self.test_data)
+        train_data=self._train_data,
+        validation_data=self._test_data,
+        options=options)
     self._test_accuracy(model)
 
     # Test export_model
     model.export_model()
-    output_metadata_file = os.path.join(hparams.model_dir, 'metadata.json')
-    output_tflite_file = os.path.join(hparams.model_dir, 'model.tflite')
+    output_metadata_file = os.path.join(options.hparams.export_dir,
+                                        'metadata.json')
+    output_tflite_file = os.path.join(options.hparams.export_dir,
+                                      'model.tflite')
     expected_metadata_file = test_utils.get_test_data_path('metadata.json')
 
     self.assertTrue(os.path.exists(output_tflite_file))
@@ -112,9 +118,30 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
     self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
 
   def _test_accuracy(self, model, threshold=0.0):
-    _, accuracy = model.evaluate(self.test_data)
+    _, accuracy = model.evaluate(self._test_data)
     self.assertGreaterEqual(accuracy, threshold)
 
+  @mock.patch.object(
+      image_classifier.hyperparameters,
+      'HParams',
+      autospec=True,
+      return_value=image_classifier.HParams(epochs=1))
+  @mock.patch.object(
+      image_classifier.model_options,
+      'ImageClassifierModelOptions',
+      autospec=True,
+      return_value=image_classifier.ModelOptions())
+  def test_create_hparams_and_model_options_if_none_in_image_classifier_options(
+      self, mock_hparams, mock_model_options):
+    options = image_classifier.ImageClassifierOptions(
+        supported_model=(image_classifier.SupportedModels.EFFICIENTNET_LITE0))
+    image_classifier.ImageClassifier.create(
+        train_data=self._train_data,
+        validation_data=self._test_data,
+        options=options)
+    mock_hparams.assert_called_once()
+    mock_model_options.assert_called_once()
+
 
 if __name__ == '__main__':
   # Load compressed models from tensorflow_hub

View File

@@ -0,0 +1,27 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Configurable model options for image classifier models."""
+
+import dataclasses
+
+
+@dataclasses.dataclass
+class ImageClassifierModelOptions:
+  """Configurable options for image classifier model.
+
+  Attributes:
+    dropout_rate: The fraction of the input units to drop, used in dropout
+      layer.
+  """
+  dropout_rate: float = 0.2

View File

@@ -49,13 +49,14 @@ def _create_optimizer(init_lr: float, decay_steps: int,
   return optimizer
 
 
-def _get_default_callbacks(model_dir: str) -> List[tf.keras.callbacks.Callback]:
+def _get_default_callbacks(
+    export_dir: str) -> List[tf.keras.callbacks.Callback]:
   """Gets default callbacks."""
-  summary_dir = os.path.join(model_dir, 'summaries')
+  summary_dir = os.path.join(export_dir, 'summaries')
   summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
   # Save checkpoint every 20 epochs.
-  checkpoint_path = os.path.join(model_dir, 'checkpoint')
+  checkpoint_path = os.path.join(export_dir, 'checkpoint')
   checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
       checkpoint_path, save_weights_only=True, period=20)
   return [summary_callback, checkpoint_callback]
@@ -81,7 +82,8 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
   learning_rate = hparams.learning_rate * hparams.batch_size / 256
 
   # Get decay steps.
-  total_training_steps = hparams.steps_per_epoch * hparams.train_epochs
+  # NOMUTANTS--(b/256493858):Plan to test it in the unified training library.
+  total_training_steps = hparams.steps_per_epoch * hparams.epochs
   default_decay_steps = hparams.decay_samples // hparams.batch_size
   decay_steps = max(total_training_steps, default_decay_steps)
@@ -92,11 +94,11 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
   loss = tf.keras.losses.CategoricalCrossentropy(
       label_smoothing=hparams.label_smoothing)
   model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
-  callbacks = _get_default_callbacks(hparams.model_dir)
+  callbacks = _get_default_callbacks(export_dir=hparams.export_dir)
 
   # Train the model.
   return model.fit(
       x=train_ds,
-      epochs=hparams.train_epochs,
+      epochs=hparams.epochs,
       validation_data=validation_ds,
       callbacks=callbacks)