Refactor the image classifier Create API by defining an ImageClassifierOptions class which wraps the model_spec, model_options, and hparams. Also migrate the definition of HParams by extending it from BaseHParams.

PiperOrigin-RevId: 486796657
This commit is contained in:
MediaPipe Team 2022-11-07 16:37:40 -08:00 committed by Copybara-Service
parent 1049ef781d
commit c3bb4bb5da
9 changed files with 195 additions and 87 deletions

View File

@ -19,7 +19,6 @@ import tempfile
from typing import Optional
# TODO: Integrate this class into ImageClassifier and other tasks.
@dataclasses.dataclass
class BaseHParams:
"""Hyperparameters used for training models.

View File

@ -28,6 +28,8 @@ py_library(
":dataset",
":hyperparameters",
":image_classifier",
":image_classifier_options",
":model_options",
":model_spec",
],
)
@ -58,6 +60,24 @@ py_test(
py_library(
name = "hyperparameters",
srcs = ["hyperparameters.py"],
deps = [
"//mediapipe/model_maker/python/core:hyperparameters",
],
)
py_library(
name = "model_options",
srcs = ["model_options.py"],
)
py_library(
name = "image_classifier_options",
srcs = ["image_classifier_options.py"],
deps = [
":hyperparameters",
":model_options",
":model_spec",
],
)
py_library(
@ -74,6 +94,8 @@ py_library(
srcs = ["image_classifier.py"],
deps = [
":hyperparameters",
":image_classifier_options",
":model_options",
":model_spec",
":train_image_classifier_lib",
"//mediapipe/model_maker/python/core/data:classification_dataset",
@ -99,6 +121,7 @@ py_library(
py_test(
name = "image_classifier_test",
size = "large",
srcs = ["image_classifier_test.py"],
shard_count = 2,
tags = ["requires-net:external"],

View File

@ -16,10 +16,14 @@
from mediapipe.model_maker.python.vision.image_classifier import dataset
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
from mediapipe.model_maker.python.vision.image_classifier import image_classifier
from mediapipe.model_maker.python.vision.image_classifier import image_classifier_options
from mediapipe.model_maker.python.vision.image_classifier import model_options
from mediapipe.model_maker.python.vision.image_classifier import model_spec
ImageClassifier = image_classifier.ImageClassifier
HParams = hyperparameters.HParams
Dataset = dataset.Dataset
ModelOptions = model_options.ImageClassifierModelOptions
ModelSpec = model_spec.ModelSpec
SupportedModels = model_spec.SupportedModels
ImageClassifierOptions = image_classifier_options.ImageClassifierOptions

View File

@ -14,28 +14,20 @@
"""Hyperparameters for training image classification models."""
import dataclasses
import tempfile
from typing import Optional
from mediapipe.model_maker.python.core import hyperparameters as hp
# TODO: Expose other hyperparameters, e.g. data augmentation
# hyperparameters if requested.
@dataclasses.dataclass
class HParams:
class HParams(hp.BaseHParams):
"""The hyperparameters for training image classifiers.
The hyperparameters include:
# Parameters about training data.
Attributes:
learning_rate: Learning rate to use for gradient descent training.
batch_size: Batch size for training.
epochs: Number of training iterations over the dataset.
do_fine_tuning: If true, the base module is trained together with the
classification layer on top.
shuffle: A boolean controlling whether to shuffle the dataset. Defaults to
False.
# Parameters about training configuration
train_epochs: Training will do this many iterations over the dataset.
batch_size: Each training step samples a batch of this many images.
learning_rate: The learning rate to use for gradient descent training.
dropout_rate: The fraction of the input units to drop, used in dropout
layer.
l1_regularizer: A regularizer that applies a L1 regularization penalty.
l2_regularizer: A regularizer that applies a L2 regularization penalty.
label_smoothing: Amount of label smoothing to apply. See tf.keras.losses for
@ -43,32 +35,21 @@ class HParams:
do_data_augmentation: A boolean controlling whether the training dataset is
augmented by randomly distorting input images, including random cropping,
flipping, etc. See utils.image_preprocessing documentation for details.
steps_per_epoch: An optional integer indicating the number of training steps
per epoch. If not set, the training pipeline calculates the default steps
per epoch as the training dataset size divided by batch size.
decay_samples: Number of training samples used to calculate the decay steps
and create the training optimizer.
warmup_epochs: Number of warmup epochs for a linear increasing warmup schedule
on learning rate. Used to set up warmup schedule by model_util.WarmUp.
# Parameters about the saved checkpoint
model_dir: The location of model checkpoint files and exported model files.
on learning rate. Used to set up warmup schedule by model_util.WarmUp.
"""
# Parameters about training data
do_fine_tuning: bool = False
shuffle: bool = False
# Parameters from BaseHParams class.
learning_rate: float = 0.001
batch_size: int = 2
epochs: int = 10
# Parameters about training configuration
train_epochs: int = 5
batch_size: int = 32
learning_rate: float = 0.005
dropout_rate: float = 0.2
do_fine_tuning: bool = False
l1_regularizer: float = 0.0
l2_regularizer: float = 0.0001
label_smoothing: float = 0.1
do_data_augmentation: bool = True
steps_per_epoch: Optional[int] = None
# TODO: Use lr_decay in hp.BaseHParams to infer decay_samples.
decay_samples: int = 10000 * 256
warmup_epochs: int = 2
# Parameters about the saved checkpoint
model_dir: str = tempfile.mkdtemp()

View File

@ -25,6 +25,8 @@ from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.vision.core import image_preprocessing
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
from mediapipe.model_maker.python.vision.image_classifier import image_classifier_options
from mediapipe.model_maker.python.vision.image_classifier import model_options as model_opt
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer
@ -35,17 +37,20 @@ class ImageClassifier(classifier.Classifier):
"""ImageClassifier for building image classification model."""
def __init__(self, model_spec: ms.ModelSpec, label_names: List[str],
hparams: hp.HParams):
hparams: hp.HParams,
model_options: model_opt.ImageClassifierModelOptions):
"""Initializes ImageClassifier class.
Args:
model_spec: Specification for the model.
label_names: A list of label names for the classes.
hparams: The hyperparameters for training image classifier.
model_options: Model options for creating image classifier.
"""
super().__init__(
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
self._hparams = hparams
self._model_options = model_options
self._preprocess = image_preprocessing.Preprocessor(
input_shape=self._model_spec.input_image_shape,
num_classes=self._num_classes,
@ -57,30 +62,34 @@ class ImageClassifier(classifier.Classifier):
@classmethod
def create(
cls,
model_spec: ms.SupportedModels,
train_data: classification_ds.ClassificationDataset,
validation_data: classification_ds.ClassificationDataset,
hparams: Optional[hp.HParams] = None,
options: image_classifier_options.ImageClassifierOptions,
) -> 'ImageClassifier':
"""Creates and trains an image classifier.
Loads data and trains the model based on data for image classification.
Args:
model_spec: Specification for the model.
train_data: Training data.
validation_data: Validation data.
hparams: Hyperparameters for training image classifier.
options: configuration to create image classifier.
Returns:
An instance based on ImageClassifier.
"""
if hparams is None:
hparams = hp.HParams()
if options.hparams is None:
options.hparams = hp.HParams()
spec = ms.SupportedModels.get(model_spec)
if options.model_options is None:
options.model_options = model_opt.ImageClassifierModelOptions()
spec = ms.SupportedModels.get(options.supported_model)
image_classifier = cls(
model_spec=spec, label_names=train_data.label_names, hparams=hparams)
model_spec=spec,
label_names=train_data.label_names,
hparams=options.hparams,
model_options=options.model_options)
image_classifier._create_model()
@ -90,6 +99,7 @@ class ImageClassifier(classifier.Classifier):
return image_classifier
# TODO: Migrate to the shared training library of Model Maker.
def _train(self, train_data: classification_ds.ClassificationDataset,
validation_data: classification_ds.ClassificationDataset):
"""Trains the model with input train_data.
@ -142,7 +152,7 @@ class ImageClassifier(classifier.Classifier):
self._model = tf.keras.Sequential([
tf.keras.Input(shape=(image_size[0], image_size[1], 3)), module_layer,
tf.keras.layers.Dropout(rate=self._hparams.dropout_rate),
tf.keras.layers.Dropout(rate=self._model_options.dropout_rate),
tf.keras.layers.Dense(
units=self._num_classes,
activation='softmax',
@ -167,10 +177,10 @@ class ImageClassifier(classifier.Classifier):
path is {self._hparams.export_dir}/{model_name}.
quantization_config: The configuration for model quantization.
"""
if not tf.io.gfile.exists(self._hparams.model_dir):
tf.io.gfile.makedirs(self._hparams.model_dir)
tflite_file = os.path.join(self._hparams.model_dir, model_name)
metadata_file = os.path.join(self._hparams.model_dir, 'metadata.json')
if not tf.io.gfile.exists(self._hparams.export_dir):
tf.io.gfile.makedirs(self._hparams.export_dir)
tflite_file = os.path.join(self._hparams.export_dir, model_name)
metadata_file = os.path.join(self._hparams.export_dir, 'metadata.json')
tflite_model = model_util.convert_to_tflite(
model=self._model,

View File

@ -0,0 +1,35 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Options for building image classifier."""
import dataclasses
from typing import Optional
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
from mediapipe.model_maker.python.vision.image_classifier import model_options as model_opt
from mediapipe.model_maker.python.vision.image_classifier import model_spec
@dataclasses.dataclass
class ImageClassifierOptions:
  """Configurable options for building image classifier.

  Groups the model selection together with the optional model options and
  training hyperparameters so they can be passed around as a single object.

  Attributes:
    supported_model: A model from the SupportedModels enum.
    model_options: A set of options for configuring the selected model.
    hparams: A set of hyperparameters used to train the image classifier.
  """
  # Required: declared without a default, so callers must always supply it.
  supported_model: model_spec.SupportedModels
  # Optional: both fields below default to None when not provided.
  model_options: Optional[model_opt.ImageClassifierModelOptions] = None
  hparams: Optional[hyperparameters.HParams] = None

View File

@ -15,6 +15,7 @@
import filecmp
import os
from unittest import mock
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
@ -54,54 +55,59 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
super(ImageClassifierTest, self).setUp()
all_data = self._gen_cmy_data()
# Splits data, 90% data for training, 10% for testing
self.train_data, self.test_data = all_data.split(0.9)
self._train_data, self._test_data = all_data.split(0.9)
@parameterized.named_parameters(
dict(
testcase_name='mobilenet_v2',
model_spec=image_classifier.SupportedModels.MOBILENET_V2,
hparams=image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)),
options=image_classifier.ImageClassifierOptions(
supported_model=image_classifier.SupportedModels.MOBILENET_V2,
hparams=image_classifier.HParams(
epochs=1, batch_size=1, shuffle=True))),
dict(
testcase_name='efficientnet_lite0',
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
hparams=image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)),
options=image_classifier.ImageClassifierOptions(
supported_model=(
image_classifier.SupportedModels.EFFICIENTNET_LITE0),
hparams=image_classifier.HParams(
epochs=1, batch_size=1, shuffle=True))),
dict(
testcase_name='efficientnet_lite0_change_dropout_rate',
options=image_classifier.ImageClassifierOptions(
supported_model=(
image_classifier.SupportedModels.EFFICIENTNET_LITE0),
model_options=image_classifier.ModelOptions(dropout_rate=0.1),
hparams=image_classifier.HParams(
epochs=1, batch_size=1, shuffle=True))),
dict(
testcase_name='efficientnet_lite2',
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE2,
hparams=image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)),
options=image_classifier.ImageClassifierOptions(
supported_model=(
image_classifier.SupportedModels.EFFICIENTNET_LITE2),
hparams=image_classifier.HParams(
epochs=1, batch_size=1, shuffle=True))),
dict(
testcase_name='efficientnet_lite4',
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE4,
hparams=image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)),
options=image_classifier.ImageClassifierOptions(
supported_model=(
image_classifier.SupportedModels.EFFICIENTNET_LITE4),
hparams=image_classifier.HParams(
epochs=1, batch_size=1, shuffle=True))),
)
def test_create_and_train_model(self,
model_spec: image_classifier.SupportedModels,
hparams: image_classifier.HParams):
def test_create_and_train_model(
self, options: image_classifier.ImageClassifierOptions):
model = image_classifier.ImageClassifier.create(
model_spec=model_spec,
train_data=self.train_data,
hparams=hparams,
validation_data=self.test_data)
self._test_accuracy(model)
def test_efficientnetlite0_model_train_and_export(self):
hparams = image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)
model = image_classifier.ImageClassifier.create(
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
train_data=self.train_data,
hparams=hparams,
validation_data=self.test_data)
train_data=self._train_data,
validation_data=self._test_data,
options=options)
self._test_accuracy(model)
# Test export_model
model.export_model()
output_metadata_file = os.path.join(hparams.model_dir, 'metadata.json')
output_tflite_file = os.path.join(hparams.model_dir, 'model.tflite')
output_metadata_file = os.path.join(options.hparams.export_dir,
'metadata.json')
output_tflite_file = os.path.join(options.hparams.export_dir,
'model.tflite')
expected_metadata_file = test_utils.get_test_data_path('metadata.json')
self.assertTrue(os.path.exists(output_tflite_file))
@ -112,9 +118,30 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
def _test_accuracy(self, model, threshold=0.0):
_, accuracy = model.evaluate(self.test_data)
_, accuracy = model.evaluate(self._test_data)
self.assertGreaterEqual(accuracy, threshold)
@mock.patch.object(
    image_classifier.hyperparameters,
    'HParams',
    autospec=True,
    return_value=image_classifier.HParams(epochs=1))
@mock.patch.object(
    image_classifier.model_options,
    'ImageClassifierModelOptions',
    autospec=True,
    return_value=image_classifier.ModelOptions())
def test_create_hparams_and_model_options_if_none_in_image_classifier_options(
    self, mock_model_options, mock_hparams):
  """Checks that create() builds default hparams and model options.

  The options object is constructed with only `supported_model`, leaving
  `hparams` and `model_options` as None, and the test verifies that each of
  the two patched constructors is invoked exactly once.

  Args:
    mock_model_options: Mock replacing
      `model_options.ImageClassifierModelOptions`.
    mock_hparams: Mock replacing `hyperparameters.HParams`.
  """
  # Stacked patch decorators are applied bottom-up, so the decorator closest
  # to the function (ImageClassifierModelOptions) provides the first mock
  # argument and the outermost one (HParams) provides the last.
  options = image_classifier.ImageClassifierOptions(
      supported_model=(image_classifier.SupportedModels.EFFICIENTNET_LITE0))
  image_classifier.ImageClassifier.create(
      train_data=self._train_data,
      validation_data=self._test_data,
      options=options)
  mock_hparams.assert_called_once()
  mock_model_options.assert_called_once()
if __name__ == '__main__':
# Load compressed models from tensorflow_hub

View File

@ -0,0 +1,27 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configurable model options for image classifier models."""
import dataclasses
@dataclasses.dataclass
class ImageClassifierModelOptions:
  """Configurable options for image classifier model.

  Attributes:
    dropout_rate: The fraction of the input units to drop, used in dropout
      layer. Must lie in the closed interval [0, 1] when given as a number.

  Raises:
    ValueError: If `dropout_rate` is an int or float outside [0, 1].
  """
  dropout_rate: float = 0.2

  def __post_init__(self):
    """Validates `dropout_rate` as soon as the options object is created."""
    rate = self.dropout_rate
    # Only plain numbers are range-checked; any other value type is left
    # untouched so that it is handled by whatever consumes these options.
    if isinstance(rate, (int, float)) and not 0 <= rate <= 1:
      raise ValueError(
          f'dropout_rate must be a fraction between 0 and 1, got {rate}.')

View File

@ -49,13 +49,14 @@ def _create_optimizer(init_lr: float, decay_steps: int,
return optimizer
def _get_default_callbacks(model_dir: str) -> List[tf.keras.callbacks.Callback]:
def _get_default_callbacks(
export_dir: str) -> List[tf.keras.callbacks.Callback]:
"""Gets default callbacks."""
summary_dir = os.path.join(model_dir, 'summaries')
summary_dir = os.path.join(export_dir, 'summaries')
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
# Save checkpoint every 20 epochs.
checkpoint_path = os.path.join(model_dir, 'checkpoint')
checkpoint_path = os.path.join(export_dir, 'checkpoint')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
checkpoint_path, save_weights_only=True, period=20)
return [summary_callback, checkpoint_callback]
@ -81,7 +82,8 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
learning_rate = hparams.learning_rate * hparams.batch_size / 256
# Get decay steps.
total_training_steps = hparams.steps_per_epoch * hparams.train_epochs
# NOMUTANTS--(b/256493858):Plan to test it in the unified training library.
total_training_steps = hparams.steps_per_epoch * hparams.epochs
default_decay_steps = hparams.decay_samples // hparams.batch_size
decay_steps = max(total_training_steps, default_decay_steps)
@ -92,11 +94,11 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
loss = tf.keras.losses.CategoricalCrossentropy(
label_smoothing=hparams.label_smoothing)
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
callbacks = _get_default_callbacks(hparams.model_dir)
callbacks = _get_default_callbacks(export_dir=hparams.export_dir)
# Train the model.
return model.fit(
x=train_ds,
epochs=hparams.train_epochs,
epochs=hparams.epochs,
validation_data=validation_ds,
callbacks=callbacks)