Refactor the image classifier Create API by defining an ImageClassifierOptions
class which wraps the model_spec, model_options, and hparams.
Also migrate the definition of HParams by extending it from BaseHParams.
PiperOrigin-RevId: 486796657
This commit is contained in:
parent
1049ef781d
commit
c3bb4bb5da
|
@ -19,7 +19,6 @@ import tempfile
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
# TODO: Integrate this class into ImageClassifier and other tasks.
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class BaseHParams:
|
class BaseHParams:
|
||||||
"""Hyperparameters used for training models.
|
"""Hyperparameters used for training models.
|
||||||
|
|
|
@ -28,6 +28,8 @@ py_library(
|
||||||
":dataset",
|
":dataset",
|
||||||
":hyperparameters",
|
":hyperparameters",
|
||||||
":image_classifier",
|
":image_classifier",
|
||||||
|
":image_classifier_options",
|
||||||
|
":model_options",
|
||||||
":model_spec",
|
":model_spec",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -58,6 +60,24 @@ py_test(
|
||||||
py_library(
|
py_library(
|
||||||
name = "hyperparameters",
|
name = "hyperparameters",
|
||||||
srcs = ["hyperparameters.py"],
|
srcs = ["hyperparameters.py"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "model_options",
|
||||||
|
srcs = ["model_options.py"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "image_classifier_options",
|
||||||
|
srcs = ["image_classifier_options.py"],
|
||||||
|
deps = [
|
||||||
|
":hyperparameters",
|
||||||
|
":model_options",
|
||||||
|
":model_spec",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
|
@ -74,6 +94,8 @@ py_library(
|
||||||
srcs = ["image_classifier.py"],
|
srcs = ["image_classifier.py"],
|
||||||
deps = [
|
deps = [
|
||||||
":hyperparameters",
|
":hyperparameters",
|
||||||
|
":image_classifier_options",
|
||||||
|
":model_options",
|
||||||
":model_spec",
|
":model_spec",
|
||||||
":train_image_classifier_lib",
|
":train_image_classifier_lib",
|
||||||
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||||
|
@ -99,6 +121,7 @@ py_library(
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "image_classifier_test",
|
name = "image_classifier_test",
|
||||||
|
size = "large",
|
||||||
srcs = ["image_classifier_test.py"],
|
srcs = ["image_classifier_test.py"],
|
||||||
shard_count = 2,
|
shard_count = 2,
|
||||||
tags = ["requires-net:external"],
|
tags = ["requires-net:external"],
|
||||||
|
|
|
@ -16,10 +16,14 @@
|
||||||
from mediapipe.model_maker.python.vision.image_classifier import dataset
|
from mediapipe.model_maker.python.vision.image_classifier import dataset
|
||||||
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
|
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
|
||||||
from mediapipe.model_maker.python.vision.image_classifier import image_classifier
|
from mediapipe.model_maker.python.vision.image_classifier import image_classifier
|
||||||
|
from mediapipe.model_maker.python.vision.image_classifier import image_classifier_options
|
||||||
|
from mediapipe.model_maker.python.vision.image_classifier import model_options
|
||||||
from mediapipe.model_maker.python.vision.image_classifier import model_spec
|
from mediapipe.model_maker.python.vision.image_classifier import model_spec
|
||||||
|
|
||||||
ImageClassifier = image_classifier.ImageClassifier
|
ImageClassifier = image_classifier.ImageClassifier
|
||||||
HParams = hyperparameters.HParams
|
HParams = hyperparameters.HParams
|
||||||
Dataset = dataset.Dataset
|
Dataset = dataset.Dataset
|
||||||
|
ModelOptions = model_options.ImageClassifierModelOptions
|
||||||
ModelSpec = model_spec.ModelSpec
|
ModelSpec = model_spec.ModelSpec
|
||||||
SupportedModels = model_spec.SupportedModels
|
SupportedModels = model_spec.SupportedModels
|
||||||
|
ImageClassifierOptions = image_classifier_options.ImageClassifierOptions
|
||||||
|
|
|
@ -14,28 +14,20 @@
|
||||||
"""Hyperparameters for training image classification models."""
|
"""Hyperparameters for training image classification models."""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import tempfile
|
|
||||||
from typing import Optional
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
|
||||||
|
|
||||||
# TODO: Expose other hyperparameters, e.g. data augmentation
|
|
||||||
# hyperparameters if requested.
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class HParams:
|
class HParams(hp.BaseHParams):
|
||||||
"""The hyperparameters for training image classifiers.
|
"""The hyperparameters for training image classifiers.
|
||||||
|
|
||||||
The hyperparameters include:
|
Attributes:
|
||||||
# Parameters about training data.
|
learning_rate: Learning rate to use for gradient descent training.
|
||||||
|
batch_size: Batch size for training.
|
||||||
|
epochs: Number of training iterations over the dataset.
|
||||||
do_fine_tuning: If true, the base module is trained together with the
|
do_fine_tuning: If true, the base module is trained together with the
|
||||||
classification layer on top.
|
classification layer on top.
|
||||||
shuffle: A boolean controlling if shuffle the dataset. Default to false.
|
|
||||||
|
|
||||||
# Parameters about training configuration
|
|
||||||
train_epochs: Training will do this many iterations over the dataset.
|
|
||||||
batch_size: Each training step samples a batch of this many images.
|
|
||||||
learning_rate: The learning rate to use for gradient descent training.
|
|
||||||
dropout_rate: The fraction of the input units to drop, used in dropout
|
|
||||||
layer.
|
|
||||||
l1_regularizer: A regularizer that applies a L1 regularization penalty.
|
l1_regularizer: A regularizer that applies a L1 regularization penalty.
|
||||||
l2_regularizer: A regularizer that applies a L2 regularization penalty.
|
l2_regularizer: A regularizer that applies a L2 regularization penalty.
|
||||||
label_smoothing: Amount of label smoothing to apply. See tf.keras.losses for
|
label_smoothing: Amount of label smoothing to apply. See tf.keras.losses for
|
||||||
|
@ -43,32 +35,21 @@ class HParams:
|
||||||
do_data_augmentation: A boolean controlling whether the training dataset is
|
do_data_augmentation: A boolean controlling whether the training dataset is
|
||||||
augmented by randomly distorting input images, including random cropping,
|
augmented by randomly distorting input images, including random cropping,
|
||||||
flipping, etc. See utils.image_preprocessing documentation for details.
|
flipping, etc. See utils.image_preprocessing documentation for details.
|
||||||
steps_per_epoch: An optional integer indicate the number of training steps
|
|
||||||
per epoch. If not set, the training pipeline calculates the default steps
|
|
||||||
per epoch as the training dataset size devided by batch size.
|
|
||||||
decay_samples: Number of training samples used to calculate the decay steps
|
decay_samples: Number of training samples used to calculate the decay steps
|
||||||
and create the training optimizer.
|
and create the training optimizer.
|
||||||
warmup_steps: Number of warmup steps for a linear increasing warmup schedule
|
warmup_steps: Number of warmup steps for a linear increasing warmup schedule
|
||||||
on learning rate. Used to set up warmup schedule by model_util.WarmUp.
|
on learning rate. Used to set up warmup schedule by model_util.WarmUp.
|
||||||
|
|
||||||
# Parameters about the saved checkpoint
|
|
||||||
model_dir: The location of model checkpoint files and exported model files.
|
|
||||||
"""
|
"""
|
||||||
# Parameters about training data
|
# Parameters from BaseHParams class.
|
||||||
do_fine_tuning: bool = False
|
learning_rate: float = 0.001
|
||||||
shuffle: bool = False
|
batch_size: int = 2
|
||||||
|
epochs: int = 10
|
||||||
# Parameters about training configuration
|
# Parameters about training configuration
|
||||||
train_epochs: int = 5
|
do_fine_tuning: bool = False
|
||||||
batch_size: int = 32
|
|
||||||
learning_rate: float = 0.005
|
|
||||||
dropout_rate: float = 0.2
|
|
||||||
l1_regularizer: float = 0.0
|
l1_regularizer: float = 0.0
|
||||||
l2_regularizer: float = 0.0001
|
l2_regularizer: float = 0.0001
|
||||||
label_smoothing: float = 0.1
|
label_smoothing: float = 0.1
|
||||||
do_data_augmentation: bool = True
|
do_data_augmentation: bool = True
|
||||||
steps_per_epoch: Optional[int] = None
|
# TODO: Use lr_decay in hp.BaseHParams to infer decay_samples.
|
||||||
decay_samples: int = 10000 * 256
|
decay_samples: int = 10000 * 256
|
||||||
warmup_epochs: int = 2
|
warmup_epochs: int = 2
|
||||||
|
|
||||||
# Parameters about the saved checkpoint
|
|
||||||
model_dir: str = tempfile.mkdtemp()
|
|
||||||
|
|
|
@ -25,6 +25,8 @@ from mediapipe.model_maker.python.core.utils import model_util
|
||||||
from mediapipe.model_maker.python.core.utils import quantization
|
from mediapipe.model_maker.python.core.utils import quantization
|
||||||
from mediapipe.model_maker.python.vision.core import image_preprocessing
|
from mediapipe.model_maker.python.vision.core import image_preprocessing
|
||||||
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
|
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.vision.image_classifier import image_classifier_options
|
||||||
|
from mediapipe.model_maker.python.vision.image_classifier import model_options as model_opt
|
||||||
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
|
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
|
||||||
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
|
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
|
||||||
from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer
|
from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer
|
||||||
|
@ -35,17 +37,20 @@ class ImageClassifier(classifier.Classifier):
|
||||||
"""ImageClassifier for building image classification model."""
|
"""ImageClassifier for building image classification model."""
|
||||||
|
|
||||||
def __init__(self, model_spec: ms.ModelSpec, label_names: List[str],
|
def __init__(self, model_spec: ms.ModelSpec, label_names: List[str],
|
||||||
hparams: hp.HParams):
|
hparams: hp.HParams,
|
||||||
|
model_options: model_opt.ImageClassifierModelOptions):
|
||||||
"""Initializes ImageClassifier class.
|
"""Initializes ImageClassifier class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_spec: Specification for the model.
|
model_spec: Specification for the model.
|
||||||
label_names: A list of label names for the classes.
|
label_names: A list of label names for the classes.
|
||||||
hparams: The hyperparameters for training image classifier.
|
hparams: The hyperparameters for training image classifier.
|
||||||
|
model_options: Model options for creating image classifier.
|
||||||
"""
|
"""
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
|
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
|
||||||
self._hparams = hparams
|
self._hparams = hparams
|
||||||
|
self._model_options = model_options
|
||||||
self._preprocess = image_preprocessing.Preprocessor(
|
self._preprocess = image_preprocessing.Preprocessor(
|
||||||
input_shape=self._model_spec.input_image_shape,
|
input_shape=self._model_spec.input_image_shape,
|
||||||
num_classes=self._num_classes,
|
num_classes=self._num_classes,
|
||||||
|
@ -57,30 +62,34 @@ class ImageClassifier(classifier.Classifier):
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
cls,
|
cls,
|
||||||
model_spec: ms.SupportedModels,
|
|
||||||
train_data: classification_ds.ClassificationDataset,
|
train_data: classification_ds.ClassificationDataset,
|
||||||
validation_data: classification_ds.ClassificationDataset,
|
validation_data: classification_ds.ClassificationDataset,
|
||||||
hparams: Optional[hp.HParams] = None,
|
options: image_classifier_options.ImageClassifierOptions,
|
||||||
) -> 'ImageClassifier':
|
) -> 'ImageClassifier':
|
||||||
"""Creates and trains an image classifier.
|
"""Creates and trains an image classifier.
|
||||||
|
|
||||||
Loads data and trains the model based on data for image classification.
|
Loads data and trains the model based on data for image classification.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_spec: Specification for the model.
|
|
||||||
train_data: Training data.
|
train_data: Training data.
|
||||||
validation_data: Validation data.
|
validation_data: Validation data.
|
||||||
hparams: Hyperparameters for training image classifier.
|
options: configuration to create image classifier.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An instance based on ImageClassifier.
|
An instance based on ImageClassifier.
|
||||||
"""
|
"""
|
||||||
if hparams is None:
|
if options.hparams is None:
|
||||||
hparams = hp.HParams()
|
options.hparams = hp.HParams()
|
||||||
|
|
||||||
spec = ms.SupportedModels.get(model_spec)
|
if options.model_options is None:
|
||||||
|
options.model_options = model_opt.ImageClassifierModelOptions()
|
||||||
|
|
||||||
|
spec = ms.SupportedModels.get(options.supported_model)
|
||||||
image_classifier = cls(
|
image_classifier = cls(
|
||||||
model_spec=spec, label_names=train_data.label_names, hparams=hparams)
|
model_spec=spec,
|
||||||
|
label_names=train_data.label_names,
|
||||||
|
hparams=options.hparams,
|
||||||
|
model_options=options.model_options)
|
||||||
|
|
||||||
image_classifier._create_model()
|
image_classifier._create_model()
|
||||||
|
|
||||||
|
@ -90,6 +99,7 @@ class ImageClassifier(classifier.Classifier):
|
||||||
|
|
||||||
return image_classifier
|
return image_classifier
|
||||||
|
|
||||||
|
# TODO: Migrate to the shared training library of Model Maker.
|
||||||
def _train(self, train_data: classification_ds.ClassificationDataset,
|
def _train(self, train_data: classification_ds.ClassificationDataset,
|
||||||
validation_data: classification_ds.ClassificationDataset):
|
validation_data: classification_ds.ClassificationDataset):
|
||||||
"""Trains the model with input train_data.
|
"""Trains the model with input train_data.
|
||||||
|
@ -142,7 +152,7 @@ class ImageClassifier(classifier.Classifier):
|
||||||
|
|
||||||
self._model = tf.keras.Sequential([
|
self._model = tf.keras.Sequential([
|
||||||
tf.keras.Input(shape=(image_size[0], image_size[1], 3)), module_layer,
|
tf.keras.Input(shape=(image_size[0], image_size[1], 3)), module_layer,
|
||||||
tf.keras.layers.Dropout(rate=self._hparams.dropout_rate),
|
tf.keras.layers.Dropout(rate=self._model_options.dropout_rate),
|
||||||
tf.keras.layers.Dense(
|
tf.keras.layers.Dense(
|
||||||
units=self._num_classes,
|
units=self._num_classes,
|
||||||
activation='softmax',
|
activation='softmax',
|
||||||
|
@ -167,10 +177,10 @@ class ImageClassifier(classifier.Classifier):
|
||||||
path is {self._hparams.model_dir}/{model_name}.
|
path is {self._hparams.model_dir}/{model_name}.
|
||||||
quantization_config: The configuration for model quantization.
|
quantization_config: The configuration for model quantization.
|
||||||
"""
|
"""
|
||||||
if not tf.io.gfile.exists(self._hparams.model_dir):
|
if not tf.io.gfile.exists(self._hparams.export_dir):
|
||||||
tf.io.gfile.makedirs(self._hparams.model_dir)
|
tf.io.gfile.makedirs(self._hparams.export_dir)
|
||||||
tflite_file = os.path.join(self._hparams.model_dir, model_name)
|
tflite_file = os.path.join(self._hparams.export_dir, model_name)
|
||||||
metadata_file = os.path.join(self._hparams.model_dir, 'metadata.json')
|
metadata_file = os.path.join(self._hparams.export_dir, 'metadata.json')
|
||||||
|
|
||||||
tflite_model = model_util.convert_to_tflite(
|
tflite_model = model_util.convert_to_tflite(
|
||||||
model=self._model,
|
model=self._model,
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Options for building image classifier."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
|
||||||
|
from mediapipe.model_maker.python.vision.image_classifier import model_options as model_opt
|
||||||
|
from mediapipe.model_maker.python.vision.image_classifier import model_spec
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class ImageClassifierOptions:
|
||||||
|
"""Configurable options for building image classifier.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
supported_model: A model from the SupportedModels enum.
|
||||||
|
model_options: A set of options for configuring the selected model.
|
||||||
|
hparams: A set of hyperparameters used to train the image classifier.
|
||||||
|
"""
|
||||||
|
supported_model: model_spec.SupportedModels
|
||||||
|
model_options: Optional[model_opt.ImageClassifierModelOptions] = None
|
||||||
|
hparams: Optional[hyperparameters.HParams] = None
|
|
@ -15,6 +15,7 @@
|
||||||
import filecmp
|
import filecmp
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from unittest import mock
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
@ -54,54 +55,59 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
super(ImageClassifierTest, self).setUp()
|
super(ImageClassifierTest, self).setUp()
|
||||||
all_data = self._gen_cmy_data()
|
all_data = self._gen_cmy_data()
|
||||||
# Splits data, 90% data for training, 10% for testing
|
# Splits data, 90% data for training, 10% for testing
|
||||||
self.train_data, self.test_data = all_data.split(0.9)
|
self._train_data, self._test_data = all_data.split(0.9)
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
dict(
|
dict(
|
||||||
testcase_name='mobilenet_v2',
|
testcase_name='mobilenet_v2',
|
||||||
model_spec=image_classifier.SupportedModels.MOBILENET_V2,
|
options=image_classifier.ImageClassifierOptions(
|
||||||
hparams=image_classifier.HParams(
|
supported_model=image_classifier.SupportedModels.MOBILENET_V2,
|
||||||
train_epochs=1, batch_size=1, shuffle=True)),
|
hparams=image_classifier.HParams(
|
||||||
|
epochs=1, batch_size=1, shuffle=True))),
|
||||||
dict(
|
dict(
|
||||||
testcase_name='efficientnet_lite0',
|
testcase_name='efficientnet_lite0',
|
||||||
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
|
options=image_classifier.ImageClassifierOptions(
|
||||||
hparams=image_classifier.HParams(
|
supported_model=(
|
||||||
train_epochs=1, batch_size=1, shuffle=True)),
|
image_classifier.SupportedModels.EFFICIENTNET_LITE0),
|
||||||
|
hparams=image_classifier.HParams(
|
||||||
|
epochs=1, batch_size=1, shuffle=True))),
|
||||||
|
dict(
|
||||||
|
testcase_name='efficientnet_lite0_change_dropout_rate',
|
||||||
|
options=image_classifier.ImageClassifierOptions(
|
||||||
|
supported_model=(
|
||||||
|
image_classifier.SupportedModels.EFFICIENTNET_LITE0),
|
||||||
|
model_options=image_classifier.ModelOptions(dropout_rate=0.1),
|
||||||
|
hparams=image_classifier.HParams(
|
||||||
|
epochs=1, batch_size=1, shuffle=True))),
|
||||||
dict(
|
dict(
|
||||||
testcase_name='efficientnet_lite2',
|
testcase_name='efficientnet_lite2',
|
||||||
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE2,
|
options=image_classifier.ImageClassifierOptions(
|
||||||
hparams=image_classifier.HParams(
|
supported_model=(
|
||||||
train_epochs=1, batch_size=1, shuffle=True)),
|
image_classifier.SupportedModels.EFFICIENTNET_LITE2),
|
||||||
|
hparams=image_classifier.HParams(
|
||||||
|
epochs=1, batch_size=1, shuffle=True))),
|
||||||
dict(
|
dict(
|
||||||
testcase_name='efficientnet_lite4',
|
testcase_name='efficientnet_lite4',
|
||||||
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE4,
|
options=image_classifier.ImageClassifierOptions(
|
||||||
hparams=image_classifier.HParams(
|
supported_model=(
|
||||||
train_epochs=1, batch_size=1, shuffle=True)),
|
image_classifier.SupportedModels.EFFICIENTNET_LITE4),
|
||||||
|
hparams=image_classifier.HParams(
|
||||||
|
epochs=1, batch_size=1, shuffle=True))),
|
||||||
)
|
)
|
||||||
def test_create_and_train_model(self,
|
def test_create_and_train_model(
|
||||||
model_spec: image_classifier.SupportedModels,
|
self, options: image_classifier.ImageClassifierOptions):
|
||||||
hparams: image_classifier.HParams):
|
|
||||||
model = image_classifier.ImageClassifier.create(
|
model = image_classifier.ImageClassifier.create(
|
||||||
model_spec=model_spec,
|
train_data=self._train_data,
|
||||||
train_data=self.train_data,
|
validation_data=self._test_data,
|
||||||
hparams=hparams,
|
options=options)
|
||||||
validation_data=self.test_data)
|
|
||||||
self._test_accuracy(model)
|
|
||||||
|
|
||||||
def test_efficientnetlite0_model_train_and_export(self):
|
|
||||||
hparams = image_classifier.HParams(
|
|
||||||
train_epochs=1, batch_size=1, shuffle=True)
|
|
||||||
model = image_classifier.ImageClassifier.create(
|
|
||||||
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
|
|
||||||
train_data=self.train_data,
|
|
||||||
hparams=hparams,
|
|
||||||
validation_data=self.test_data)
|
|
||||||
self._test_accuracy(model)
|
self._test_accuracy(model)
|
||||||
|
|
||||||
# Test export_model
|
# Test export_model
|
||||||
model.export_model()
|
model.export_model()
|
||||||
output_metadata_file = os.path.join(hparams.model_dir, 'metadata.json')
|
output_metadata_file = os.path.join(options.hparams.export_dir,
|
||||||
output_tflite_file = os.path.join(hparams.model_dir, 'model.tflite')
|
'metadata.json')
|
||||||
|
output_tflite_file = os.path.join(options.hparams.export_dir,
|
||||||
|
'model.tflite')
|
||||||
expected_metadata_file = test_utils.get_test_data_path('metadata.json')
|
expected_metadata_file = test_utils.get_test_data_path('metadata.json')
|
||||||
|
|
||||||
self.assertTrue(os.path.exists(output_tflite_file))
|
self.assertTrue(os.path.exists(output_tflite_file))
|
||||||
|
@ -112,9 +118,30 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
|
self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
|
||||||
|
|
||||||
def _test_accuracy(self, model, threshold=0.0):
|
def _test_accuracy(self, model, threshold=0.0):
|
||||||
_, accuracy = model.evaluate(self.test_data)
|
_, accuracy = model.evaluate(self._test_data)
|
||||||
self.assertGreaterEqual(accuracy, threshold)
|
self.assertGreaterEqual(accuracy, threshold)
|
||||||
|
|
||||||
|
@mock.patch.object(
|
||||||
|
image_classifier.hyperparameters,
|
||||||
|
'HParams',
|
||||||
|
autospec=True,
|
||||||
|
return_value=image_classifier.HParams(epochs=1))
|
||||||
|
@mock.patch.object(
|
||||||
|
image_classifier.model_options,
|
||||||
|
'ImageClassifierModelOptions',
|
||||||
|
autospec=True,
|
||||||
|
return_value=image_classifier.ModelOptions())
|
||||||
|
def test_create_hparams_and_model_options_if_none_in_image_classifier_options(
|
||||||
|
self, mock_hparams, mock_model_options):
|
||||||
|
options = image_classifier.ImageClassifierOptions(
|
||||||
|
supported_model=(image_classifier.SupportedModels.EFFICIENTNET_LITE0))
|
||||||
|
image_classifier.ImageClassifier.create(
|
||||||
|
train_data=self._train_data,
|
||||||
|
validation_data=self._test_data,
|
||||||
|
options=options)
|
||||||
|
mock_hparams.assert_called_once()
|
||||||
|
mock_model_options.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# Load compressed models from tensorflow_hub
|
# Load compressed models from tensorflow_hub
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Configurable model options for image classifier models."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class ImageClassifierModelOptions:
|
||||||
|
"""Configurable options for image classifier model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
dropout_rate: The fraction of the input units to drop, used in dropout
|
||||||
|
layer.
|
||||||
|
"""
|
||||||
|
dropout_rate: float = 0.2
|
|
@ -49,13 +49,14 @@ def _create_optimizer(init_lr: float, decay_steps: int,
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
def _get_default_callbacks(model_dir: str) -> List[tf.keras.callbacks.Callback]:
|
def _get_default_callbacks(
|
||||||
|
export_dir: str) -> List[tf.keras.callbacks.Callback]:
|
||||||
"""Gets default callbacks."""
|
"""Gets default callbacks."""
|
||||||
summary_dir = os.path.join(model_dir, 'summaries')
|
summary_dir = os.path.join(export_dir, 'summaries')
|
||||||
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
|
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
|
||||||
# Save checkpoint every 20 epochs.
|
# Save checkpoint every 20 epochs.
|
||||||
|
|
||||||
checkpoint_path = os.path.join(model_dir, 'checkpoint')
|
checkpoint_path = os.path.join(export_dir, 'checkpoint')
|
||||||
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||||
checkpoint_path, save_weights_only=True, period=20)
|
checkpoint_path, save_weights_only=True, period=20)
|
||||||
return [summary_callback, checkpoint_callback]
|
return [summary_callback, checkpoint_callback]
|
||||||
|
@ -81,7 +82,8 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
|
||||||
learning_rate = hparams.learning_rate * hparams.batch_size / 256
|
learning_rate = hparams.learning_rate * hparams.batch_size / 256
|
||||||
|
|
||||||
# Get decay steps.
|
# Get decay steps.
|
||||||
total_training_steps = hparams.steps_per_epoch * hparams.train_epochs
|
# NOMUTANTS--(b/256493858):Plan to test it in the unified training library.
|
||||||
|
total_training_steps = hparams.steps_per_epoch * hparams.epochs
|
||||||
default_decay_steps = hparams.decay_samples // hparams.batch_size
|
default_decay_steps = hparams.decay_samples // hparams.batch_size
|
||||||
decay_steps = max(total_training_steps, default_decay_steps)
|
decay_steps = max(total_training_steps, default_decay_steps)
|
||||||
|
|
||||||
|
@ -92,11 +94,11 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
|
||||||
loss = tf.keras.losses.CategoricalCrossentropy(
|
loss = tf.keras.losses.CategoricalCrossentropy(
|
||||||
label_smoothing=hparams.label_smoothing)
|
label_smoothing=hparams.label_smoothing)
|
||||||
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
|
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
|
||||||
callbacks = _get_default_callbacks(hparams.model_dir)
|
callbacks = _get_default_callbacks(export_dir=hparams.export_dir)
|
||||||
|
|
||||||
# Train the model.
|
# Train the model.
|
||||||
return model.fit(
|
return model.fit(
|
||||||
x=train_ds,
|
x=train_ds,
|
||||||
epochs=hparams.train_epochs,
|
epochs=hparams.epochs,
|
||||||
validation_data=validation_ds,
|
validation_data=validation_ds,
|
||||||
callbacks=callbacks)
|
callbacks=callbacks)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user