Refactor image classifier Create
API by defining an ImageClassifierOption
which wraps the model_spec, model_options, and hparams. Also migrate the definition of HParams by extending it from BaseHParams.
PiperOrigin-RevId: 486796657
This commit is contained in:
parent
1049ef781d
commit
c3bb4bb5da
|
@ -19,7 +19,6 @@ import tempfile
|
|||
from typing import Optional
|
||||
|
||||
|
||||
# TODO: Integrate this class into ImageClassifier and other tasks.
|
||||
@dataclasses.dataclass
|
||||
class BaseHParams:
|
||||
"""Hyperparameters used for training models.
|
||||
|
|
|
@ -28,6 +28,8 @@ py_library(
|
|||
":dataset",
|
||||
":hyperparameters",
|
||||
":image_classifier",
|
||||
":image_classifier_options",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
],
|
||||
)
|
||||
|
@ -58,6 +60,24 @@ py_test(
|
|||
py_library(
|
||||
name = "hyperparameters",
|
||||
srcs = ["hyperparameters.py"],
|
||||
deps = [
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model_options",
|
||||
srcs = ["model_options.py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "image_classifier_options",
|
||||
srcs = ["image_classifier_options.py"],
|
||||
deps = [
|
||||
":hyperparameters",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
@ -74,6 +94,8 @@ py_library(
|
|||
srcs = ["image_classifier.py"],
|
||||
deps = [
|
||||
":hyperparameters",
|
||||
":image_classifier_options",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
":train_image_classifier_lib",
|
||||
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||
|
@ -99,6 +121,7 @@ py_library(
|
|||
|
||||
py_test(
|
||||
name = "image_classifier_test",
|
||||
size = "large",
|
||||
srcs = ["image_classifier_test.py"],
|
||||
shard_count = 2,
|
||||
tags = ["requires-net:external"],
|
||||
|
|
|
@ -16,10 +16,14 @@
|
|||
from mediapipe.model_maker.python.vision.image_classifier import dataset
|
||||
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
|
||||
from mediapipe.model_maker.python.vision.image_classifier import image_classifier
|
||||
from mediapipe.model_maker.python.vision.image_classifier import image_classifier_options
|
||||
from mediapipe.model_maker.python.vision.image_classifier import model_options
|
||||
from mediapipe.model_maker.python.vision.image_classifier import model_spec
|
||||
|
||||
ImageClassifier = image_classifier.ImageClassifier
|
||||
HParams = hyperparameters.HParams
|
||||
Dataset = dataset.Dataset
|
||||
ModelOptions = model_options.ImageClassifierModelOptions
|
||||
ModelSpec = model_spec.ModelSpec
|
||||
SupportedModels = model_spec.SupportedModels
|
||||
ImageClassifierOptions = image_classifier_options.ImageClassifierOptions
|
||||
|
|
|
@ -14,28 +14,20 @@
|
|||
"""Hyperparameters for training image classification models."""
|
||||
|
||||
import dataclasses
|
||||
import tempfile
|
||||
from typing import Optional
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
|
||||
|
||||
# TODO: Expose other hyperparameters, e.g. data augmentation
|
||||
# hyperparameters if requested.
|
||||
@dataclasses.dataclass
|
||||
class HParams:
|
||||
class HParams(hp.BaseHParams):
|
||||
"""The hyperparameters for training image classifiers.
|
||||
|
||||
The hyperparameters include:
|
||||
# Parameters about training data.
|
||||
Attributes:
|
||||
learning_rate: Learning rate to use for gradient descent training.
|
||||
batch_size: Batch size for training.
|
||||
epochs: Number of training iterations over the dataset.
|
||||
do_fine_tuning: If true, the base module is trained together with the
|
||||
classification layer on top.
|
||||
shuffle: A boolean controlling if shuffle the dataset. Default to false.
|
||||
|
||||
# Parameters about training configuration
|
||||
train_epochs: Training will do this many iterations over the dataset.
|
||||
batch_size: Each training step samples a batch of this many images.
|
||||
learning_rate: The learning rate to use for gradient descent training.
|
||||
dropout_rate: The fraction of the input units to drop, used in dropout
|
||||
layer.
|
||||
l1_regularizer: A regularizer that applies a L1 regularization penalty.
|
||||
l2_regularizer: A regularizer that applies a L2 regularization penalty.
|
||||
label_smoothing: Amount of label smoothing to apply. See tf.keras.losses for
|
||||
|
@ -43,32 +35,21 @@ class HParams:
|
|||
do_data_augmentation: A boolean controlling whether the training dataset is
|
||||
augmented by randomly distorting input images, including random cropping,
|
||||
flipping, etc. See utils.image_preprocessing documentation for details.
|
||||
steps_per_epoch: An optional integer indicate the number of training steps
|
||||
per epoch. If not set, the training pipeline calculates the default steps
|
||||
per epoch as the training dataset size devided by batch size.
|
||||
decay_samples: Number of training samples used to calculate the decay steps
|
||||
and create the training optimizer.
|
||||
warmup_steps: Number of warmup steps for a linear increasing warmup schedule
|
||||
on learning rate. Used to set up warmup schedule by model_util.WarmUp.
|
||||
|
||||
# Parameters about the saved checkpoint
|
||||
model_dir: The location of model checkpoint files and exported model files.
|
||||
on learning rate. Used to set up warmup schedule by model_util.WarmUp.s
|
||||
"""
|
||||
# Parameters about training data
|
||||
do_fine_tuning: bool = False
|
||||
shuffle: bool = False
|
||||
# Parameters from BaseHParams class.
|
||||
learning_rate: float = 0.001
|
||||
batch_size: int = 2
|
||||
epochs: int = 10
|
||||
# Parameters about training configuration
|
||||
train_epochs: int = 5
|
||||
batch_size: int = 32
|
||||
learning_rate: float = 0.005
|
||||
dropout_rate: float = 0.2
|
||||
do_fine_tuning: bool = False
|
||||
l1_regularizer: float = 0.0
|
||||
l2_regularizer: float = 0.0001
|
||||
label_smoothing: float = 0.1
|
||||
do_data_augmentation: bool = True
|
||||
steps_per_epoch: Optional[int] = None
|
||||
# TODO: Use lr_decay in hp.baseHParams to infer decay_samples.
|
||||
decay_samples: int = 10000 * 256
|
||||
warmup_epochs: int = 2
|
||||
|
||||
# Parameters about the saved checkpoint
|
||||
model_dir: str = tempfile.mkdtemp()
|
||||
|
|
|
@ -25,6 +25,8 @@ from mediapipe.model_maker.python.core.utils import model_util
|
|||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
from mediapipe.model_maker.python.vision.core import image_preprocessing
|
||||
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.vision.image_classifier import image_classifier_options
|
||||
from mediapipe.model_maker.python.vision.image_classifier import model_options as model_opt
|
||||
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
|
||||
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer
|
||||
|
@ -35,17 +37,20 @@ class ImageClassifier(classifier.Classifier):
|
|||
"""ImageClassifier for building image classification model."""
|
||||
|
||||
def __init__(self, model_spec: ms.ModelSpec, label_names: List[str],
|
||||
hparams: hp.HParams):
|
||||
hparams: hp.HParams,
|
||||
model_options: model_opt.ImageClassifierModelOptions):
|
||||
"""Initializes ImageClassifier class.
|
||||
|
||||
Args:
|
||||
model_spec: Specification for the model.
|
||||
label_names: A list of label names for the classes.
|
||||
hparams: The hyperparameters for training image classifier.
|
||||
model_options: Model options for creating image classifier.
|
||||
"""
|
||||
super().__init__(
|
||||
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
|
||||
self._hparams = hparams
|
||||
self._model_options = model_options
|
||||
self._preprocess = image_preprocessing.Preprocessor(
|
||||
input_shape=self._model_spec.input_image_shape,
|
||||
num_classes=self._num_classes,
|
||||
|
@ -57,30 +62,34 @@ class ImageClassifier(classifier.Classifier):
|
|||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
model_spec: ms.SupportedModels,
|
||||
train_data: classification_ds.ClassificationDataset,
|
||||
validation_data: classification_ds.ClassificationDataset,
|
||||
hparams: Optional[hp.HParams] = None,
|
||||
options: image_classifier_options.ImageClassifierOptions,
|
||||
) -> 'ImageClassifier':
|
||||
"""Creates and trains an image classifier.
|
||||
|
||||
Loads data and trains the model based on data for image classification.
|
||||
|
||||
Args:
|
||||
model_spec: Specification for the model.
|
||||
train_data: Training data.
|
||||
validation_data: Validation data.
|
||||
hparams: Hyperparameters for training image classifier.
|
||||
options: configuration to create image classifier.
|
||||
|
||||
Returns:
|
||||
An instance based on ImageClassifier.
|
||||
"""
|
||||
if hparams is None:
|
||||
hparams = hp.HParams()
|
||||
if options.hparams is None:
|
||||
options.hparams = hp.HParams()
|
||||
|
||||
spec = ms.SupportedModels.get(model_spec)
|
||||
if options.model_options is None:
|
||||
options.model_options = model_opt.ImageClassifierModelOptions()
|
||||
|
||||
spec = ms.SupportedModels.get(options.supported_model)
|
||||
image_classifier = cls(
|
||||
model_spec=spec, label_names=train_data.label_names, hparams=hparams)
|
||||
model_spec=spec,
|
||||
label_names=train_data.label_names,
|
||||
hparams=options.hparams,
|
||||
model_options=options.model_options)
|
||||
|
||||
image_classifier._create_model()
|
||||
|
||||
|
@ -90,6 +99,7 @@ class ImageClassifier(classifier.Classifier):
|
|||
|
||||
return image_classifier
|
||||
|
||||
# TODO: Migrate to the shared training library of Model Maker.
|
||||
def _train(self, train_data: classification_ds.ClassificationDataset,
|
||||
validation_data: classification_ds.ClassificationDataset):
|
||||
"""Trains the model with input train_data.
|
||||
|
@ -142,7 +152,7 @@ class ImageClassifier(classifier.Classifier):
|
|||
|
||||
self._model = tf.keras.Sequential([
|
||||
tf.keras.Input(shape=(image_size[0], image_size[1], 3)), module_layer,
|
||||
tf.keras.layers.Dropout(rate=self._hparams.dropout_rate),
|
||||
tf.keras.layers.Dropout(rate=self._model_options.dropout_rate),
|
||||
tf.keras.layers.Dense(
|
||||
units=self._num_classes,
|
||||
activation='softmax',
|
||||
|
@ -167,10 +177,10 @@ class ImageClassifier(classifier.Classifier):
|
|||
path is {self._hparams.model_dir}/{model_name}.
|
||||
quantization_config: The configuration for model quantization.
|
||||
"""
|
||||
if not tf.io.gfile.exists(self._hparams.model_dir):
|
||||
tf.io.gfile.makedirs(self._hparams.model_dir)
|
||||
tflite_file = os.path.join(self._hparams.model_dir, model_name)
|
||||
metadata_file = os.path.join(self._hparams.model_dir, 'metadata.json')
|
||||
if not tf.io.gfile.exists(self._hparams.export_dir):
|
||||
tf.io.gfile.makedirs(self._hparams.export_dir)
|
||||
tflite_file = os.path.join(self._hparams.export_dir, model_name)
|
||||
metadata_file = os.path.join(self._hparams.export_dir, 'metadata.json')
|
||||
|
||||
tflite_model = model_util.convert_to_tflite(
|
||||
model=self._model,
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Options for building image classifier."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Optional
|
||||
|
||||
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
|
||||
from mediapipe.model_maker.python.vision.image_classifier import model_options as model_opt
|
||||
from mediapipe.model_maker.python.vision.image_classifier import model_spec
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ImageClassifierOptions:
|
||||
"""Configurable options for building image classifier.
|
||||
|
||||
Attributes:
|
||||
supported_model: A model from the SupportedModels enum.
|
||||
model_options: A set of options for configuring the selected model.
|
||||
hparams: A set of hyperparameters used to train the image classifier.
|
||||
"""
|
||||
supported_model: model_spec.SupportedModels
|
||||
model_options: Optional[model_opt.ImageClassifierModelOptions] = None
|
||||
hparams: Optional[hyperparameters.HParams] = None
|
|
@ -15,6 +15,7 @@
|
|||
import filecmp
|
||||
import os
|
||||
|
||||
from unittest import mock
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
@ -54,54 +55,59 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
|||
super(ImageClassifierTest, self).setUp()
|
||||
all_data = self._gen_cmy_data()
|
||||
# Splits data, 90% data for training, 10% for testing
|
||||
self.train_data, self.test_data = all_data.split(0.9)
|
||||
self._train_data, self._test_data = all_data.split(0.9)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name='mobilenet_v2',
|
||||
model_spec=image_classifier.SupportedModels.MOBILENET_V2,
|
||||
options=image_classifier.ImageClassifierOptions(
|
||||
supported_model=image_classifier.SupportedModels.MOBILENET_V2,
|
||||
hparams=image_classifier.HParams(
|
||||
train_epochs=1, batch_size=1, shuffle=True)),
|
||||
epochs=1, batch_size=1, shuffle=True))),
|
||||
dict(
|
||||
testcase_name='efficientnet_lite0',
|
||||
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
|
||||
options=image_classifier.ImageClassifierOptions(
|
||||
supported_model=(
|
||||
image_classifier.SupportedModels.EFFICIENTNET_LITE0),
|
||||
hparams=image_classifier.HParams(
|
||||
train_epochs=1, batch_size=1, shuffle=True)),
|
||||
epochs=1, batch_size=1, shuffle=True))),
|
||||
dict(
|
||||
testcase_name='efficientnet_lite0_change_dropout_rate',
|
||||
options=image_classifier.ImageClassifierOptions(
|
||||
supported_model=(
|
||||
image_classifier.SupportedModels.EFFICIENTNET_LITE0),
|
||||
model_options=image_classifier.ModelOptions(dropout_rate=0.1),
|
||||
hparams=image_classifier.HParams(
|
||||
epochs=1, batch_size=1, shuffle=True))),
|
||||
dict(
|
||||
testcase_name='efficientnet_lite2',
|
||||
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE2,
|
||||
options=image_classifier.ImageClassifierOptions(
|
||||
supported_model=(
|
||||
image_classifier.SupportedModels.EFFICIENTNET_LITE2),
|
||||
hparams=image_classifier.HParams(
|
||||
train_epochs=1, batch_size=1, shuffle=True)),
|
||||
epochs=1, batch_size=1, shuffle=True))),
|
||||
dict(
|
||||
testcase_name='efficientnet_lite4',
|
||||
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE4,
|
||||
options=image_classifier.ImageClassifierOptions(
|
||||
supported_model=(
|
||||
image_classifier.SupportedModels.EFFICIENTNET_LITE4),
|
||||
hparams=image_classifier.HParams(
|
||||
train_epochs=1, batch_size=1, shuffle=True)),
|
||||
epochs=1, batch_size=1, shuffle=True))),
|
||||
)
|
||||
def test_create_and_train_model(self,
|
||||
model_spec: image_classifier.SupportedModels,
|
||||
hparams: image_classifier.HParams):
|
||||
def test_create_and_train_model(
|
||||
self, options: image_classifier.ImageClassifierOptions):
|
||||
model = image_classifier.ImageClassifier.create(
|
||||
model_spec=model_spec,
|
||||
train_data=self.train_data,
|
||||
hparams=hparams,
|
||||
validation_data=self.test_data)
|
||||
self._test_accuracy(model)
|
||||
|
||||
def test_efficientnetlite0_model_train_and_export(self):
|
||||
hparams = image_classifier.HParams(
|
||||
train_epochs=1, batch_size=1, shuffle=True)
|
||||
model = image_classifier.ImageClassifier.create(
|
||||
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
|
||||
train_data=self.train_data,
|
||||
hparams=hparams,
|
||||
validation_data=self.test_data)
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
options=options)
|
||||
self._test_accuracy(model)
|
||||
|
||||
# Test export_model
|
||||
model.export_model()
|
||||
output_metadata_file = os.path.join(hparams.model_dir, 'metadata.json')
|
||||
output_tflite_file = os.path.join(hparams.model_dir, 'model.tflite')
|
||||
output_metadata_file = os.path.join(options.hparams.export_dir,
|
||||
'metadata.json')
|
||||
output_tflite_file = os.path.join(options.hparams.export_dir,
|
||||
'model.tflite')
|
||||
expected_metadata_file = test_utils.get_test_data_path('metadata.json')
|
||||
|
||||
self.assertTrue(os.path.exists(output_tflite_file))
|
||||
|
@ -112,9 +118,30 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
|
||||
|
||||
def _test_accuracy(self, model, threshold=0.0):
|
||||
_, accuracy = model.evaluate(self.test_data)
|
||||
_, accuracy = model.evaluate(self._test_data)
|
||||
self.assertGreaterEqual(accuracy, threshold)
|
||||
|
||||
@mock.patch.object(
|
||||
image_classifier.hyperparameters,
|
||||
'HParams',
|
||||
autospec=True,
|
||||
return_value=image_classifier.HParams(epochs=1))
|
||||
@mock.patch.object(
|
||||
image_classifier.model_options,
|
||||
'ImageClassifierModelOptions',
|
||||
autospec=True,
|
||||
return_value=image_classifier.ModelOptions())
|
||||
def test_create_hparams_and_model_options_if_none_in_image_classifier_options(
|
||||
self, mock_hparams, mock_model_options):
|
||||
options = image_classifier.ImageClassifierOptions(
|
||||
supported_model=(image_classifier.SupportedModels.EFFICIENTNET_LITE0))
|
||||
image_classifier.ImageClassifier.create(
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
options=options)
|
||||
mock_hparams.assert_called_once()
|
||||
mock_model_options.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Load compressed models from tensorflow_hub
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Configurable model options for image classifier models."""
|
||||
|
||||
import dataclasses
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ImageClassifierModelOptions:
|
||||
"""Configurable options for image classifier model.
|
||||
|
||||
Attributes:
|
||||
dropout_rate: The fraction of the input units to drop, used in dropout
|
||||
layer.
|
||||
"""
|
||||
dropout_rate: float = 0.2
|
|
@ -49,13 +49,14 @@ def _create_optimizer(init_lr: float, decay_steps: int,
|
|||
return optimizer
|
||||
|
||||
|
||||
def _get_default_callbacks(model_dir: str) -> List[tf.keras.callbacks.Callback]:
|
||||
def _get_default_callbacks(
|
||||
export_dir: str) -> List[tf.keras.callbacks.Callback]:
|
||||
"""Gets default callbacks."""
|
||||
summary_dir = os.path.join(model_dir, 'summaries')
|
||||
summary_dir = os.path.join(export_dir, 'summaries')
|
||||
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
|
||||
# Save checkpoint every 20 epochs.
|
||||
|
||||
checkpoint_path = os.path.join(model_dir, 'checkpoint')
|
||||
checkpoint_path = os.path.join(export_dir, 'checkpoint')
|
||||
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||
checkpoint_path, save_weights_only=True, period=20)
|
||||
return [summary_callback, checkpoint_callback]
|
||||
|
@ -81,7 +82,8 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
|
|||
learning_rate = hparams.learning_rate * hparams.batch_size / 256
|
||||
|
||||
# Get decay steps.
|
||||
total_training_steps = hparams.steps_per_epoch * hparams.train_epochs
|
||||
# NOMUTANTS--(b/256493858):Plan to test it in the unified training library.
|
||||
total_training_steps = hparams.steps_per_epoch * hparams.epochs
|
||||
default_decay_steps = hparams.decay_samples // hparams.batch_size
|
||||
decay_steps = max(total_training_steps, default_decay_steps)
|
||||
|
||||
|
@ -92,11 +94,11 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
|
|||
loss = tf.keras.losses.CategoricalCrossentropy(
|
||||
label_smoothing=hparams.label_smoothing)
|
||||
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
|
||||
callbacks = _get_default_callbacks(hparams.model_dir)
|
||||
callbacks = _get_default_callbacks(export_dir=hparams.export_dir)
|
||||
|
||||
# Train the model.
|
||||
return model.fit(
|
||||
x=train_ds,
|
||||
epochs=hparams.train_epochs,
|
||||
epochs=hparams.epochs,
|
||||
validation_data=validation_ds,
|
||||
callbacks=callbacks)
|
||||
|
|
Loading…
Reference in New Issue
Block a user