Open source model_maker/python/core/tasks and model_maker/python/vision/image_classifier

PiperOrigin-RevId: 481182271
This commit is contained in:
MediaPipe Team 2022-10-14 10:45:23 -07:00 committed by Copybara-Service
parent 6f3e8381ed
commit 0428550d75
20 changed files with 1544 additions and 0 deletions

View File

@ -0,0 +1,64 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
# Targets in this package are visible to everything under //mediapipe.
package(
    default_visibility = ["//mediapipe:__subpackages__"],
)

licenses(["notice"])

# Library for custom_model.py (defines the abstract CustomModel base class).
py_library(
    name = "custom_model",
    srcs = ["custom_model.py"],
    srcs_version = "PY3",
    deps = [
        "//mediapipe/model_maker/python/core/data:dataset",
        "//mediapipe/model_maker/python/core/utils:model_util",
        "//mediapipe/model_maker/python/core/utils:quantization",
    ],
)

# Unit test for custom_model.py.
py_test(
    name = "custom_model_test",
    srcs = ["custom_model_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":custom_model",
        "//mediapipe/model_maker/python/core/utils:test_util",
    ],
)

# Library for classifier.py (Classifier builds on CustomModel).
py_library(
    name = "classifier",
    srcs = ["classifier.py"],
    srcs_version = "PY3",
    deps = [
        ":custom_model",
        "//mediapipe/model_maker/python/core/data:dataset",
    ],
)

# Unit test for classifier.py.
py_test(
    name = "classifier_test",
    srcs = ["classifier_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":classifier",
        "//mediapipe/model_maker/python/core/utils:test_util",
    ],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,77 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom classifier."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from typing import Any, List
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset
from mediapipe.model_maker.python.core.tasks import custom_model
class Classifier(custom_model.CustomModel):
  """Abstract base class for TensorFlow classifiers built on `CustomModel`."""

  def __init__(self, model_spec: Any, index_to_label: List[str], shuffle: bool,
               full_train: bool):
    """Initializes a classifier with its specifications.

    Args:
      model_spec: Specification for the model.
      index_to_label: A list that maps from class index to label name.
      shuffle: Whether the dataset should be shuffled.
      full_train: If true, train the model end-to-end including the backbone
        and the classification layers on top. Otherwise, only train the top
        classification layers.
    """
    super().__init__(model_spec, shuffle)
    self._index_to_label = index_to_label
    self._full_train = full_train
    # The number of classes is derived from the label list itself.
    self._num_classes = len(index_to_label)

  def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
    """Evaluates the classifier with the provided evaluation dataset.

    Args:
      data: Evaluation dataset.
      batch_size: Number of samples per evaluation step.

    Returns:
      Whatever `self._model.evaluate` returns for the generated dataset
      (described upstream as the loss value and accuracy).
    """
    eval_ds = data.gen_tf_dataset(
        batch_size, is_training=False, preprocess=self._preprocess)
    return self._model.evaluate(eval_ds)

  def export_labels(self, export_dir: str, label_filename: str = 'labels.txt'):
    """Writes the class labels, one per line, into a label file.

    Args:
      export_dir: The directory to save exported files. It is created when it
        does not exist yet.
      label_filename: File name of the label file. The full export path is
        {export_dir}/{label_filename}.
    """
    if not tf.io.gfile.exists(export_dir):
      tf.io.gfile.makedirs(export_dir)

    label_filepath = os.path.join(export_dir, label_filename)
    tf.compat.v1.logging.info('Saving labels in %s', label_filepath)
    with tf.io.gfile.GFile(label_filepath, 'w') as label_file:
      # Labels are written in index order, separated by newlines.
      label_file.write('\n'.join(self._index_to_label))

View File

@ -0,0 +1,58 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
# Dependency imports
import tensorflow as tf
from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import test_util
class MockClassifier(classifier.Classifier):
  """Test double that stubs out `train` and `evaluate` with no-ops."""

  def train(self, train_data, validation_data=None, **kwargs):
    """Does nothing; present so the mock offers a training entry point."""
    return None

  def evaluate(self, data, **kwargs):
    """Does nothing; satisfies the abstract `evaluate` of the base class."""
    return None
class ClassifierTest(tf.test.TestCase):
  """Tests the label-export behaviour of `classifier.Classifier`."""

  def setUp(self):
    """Builds a two-class mock classifier backed by a small test model."""
    super().setUp()
    index_to_label = ['cat', 'dog']
    self.model = MockClassifier(
        model_spec=None,
        index_to_label=index_to_label,
        shuffle=False,
        full_train=False)
    # CustomModel keeps its Keras model in the private `_model` attribute, so
    # the test model must be attached there (not to a new `model` attribute)
    # for methods that read `self._model` to see it.
    self.model._model = test_util.build_model(input_shape=[4], num_classes=2)

  def _check_nonempty_file(self, filepath):
    """Asserts that `filepath` is an existing file with a non-zero size."""
    self.assertTrue(os.path.isfile(filepath))
    self.assertGreater(os.path.getsize(filepath), 0)

  def test_export_labels(self):
    """export_labels() must create a non-empty labels.txt in export_dir."""
    export_path = os.path.join(self.get_temp_dir(), 'export/')
    self.model.export_labels(export_dir=export_path)
    self._check_nonempty_file(os.path.join(export_path, 'labels.txt'))


if __name__ == '__main__':
  tf.test.main()

View File

@ -0,0 +1,85 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interface to define a custom model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import os
from typing import Any, Callable, Optional
# Dependency imports
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
class CustomModel(abc.ABC):
  """The abstract base class that represents a custom TensorFlow model."""

  def __init__(self, model_spec: Any, shuffle: bool):
    """Initializes a custom model with model specs and other parameters.

    Args:
      model_spec: Specification for the model.
      shuffle: Whether the training data need be shuffled.
    """
    self._model_spec = model_spec
    self._shuffle = shuffle
    # Both are left unset here; subclasses or callers are expected to assign
    # them before using methods that read them.
    self._preprocess = None
    self._model = None

  @abc.abstractmethod
  def evaluate(self, data: dataset.Dataset, **kwargs):
    """Evaluates the model with the provided data."""
    return

  def summary(self):
    """Prints a summary of the model by delegating to `self._model.summary`."""
    self._model.summary()

  def export_tflite(
      self,
      export_dir: str,
      tflite_filename: str = 'model.tflite',
      quantization_config: Optional[quantization.QuantizationConfig] = None,
      preprocess: Optional[Callable[..., bool]] = None):
    """Converts the model to TFLite and writes it under `export_dir`.

    Args:
      export_dir: The directory to save exported files. It is created when it
        does not exist yet.
      tflite_filename: File name to save tflite model. The full export path is
        {export_dir}/{tflite_filename}.
      quantization_config: The configuration for model quantization.
      preprocess: A callable to preprocess the representative dataset for
        quantization. The callable takes three arguments in order: feature,
        label, and is_training.
    """
    if not tf.io.gfile.exists(export_dir):
      tf.io.gfile.makedirs(export_dir)

    tflite_filepath = os.path.join(export_dir, tflite_filename)
    # TODO: Populate metadata to the exported TFLite model.
    model_util.export_tflite(
        self._model,
        tflite_filepath,
        quantization_config,
        preprocess=preprocess)
    # Pass the path as a logging argument (lazy formatting), matching the
    # logging style used by Classifier.export_labels.
    tf.compat.v1.logging.info(
        'TensorFlow Lite model exported successfully: %s', tflite_filepath)

View File

@ -0,0 +1,56 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# Dependency imports
import tensorflow as tf
from mediapipe.model_maker.python.core.tasks import custom_model
from mediapipe.model_maker.python.core.utils import test_util
class MockCustomModel(custom_model.CustomModel):
  """Test double that stubs out `train` and `evaluate` with no-ops."""

  def train(self, train_data, validation_data=None, **kwargs):
    """Does nothing; present so the mock offers a training entry point."""
    return None

  def evaluate(self, data, **kwargs):
    """Does nothing; satisfies the abstract `evaluate` of `CustomModel`."""
    return None
class CustomModelTest(tf.test.TestCase):
  """Tests the TFLite export path of `custom_model.CustomModel`."""

  def setUp(self):
    """Creates a mock model and attaches a small test Keras model to it."""
    super().setUp()
    self.model = MockCustomModel(model_spec=None, shuffle=False)
    keras_model = test_util.build_model(input_shape=[4], num_classes=2)
    self.model._model = keras_model

  def _check_nonempty_file(self, filepath):
    """Asserts that `filepath` is an existing file with a non-zero size."""
    self.assertTrue(os.path.isfile(filepath))
    size_in_bytes = os.path.getsize(filepath)
    self.assertGreater(size_in_bytes, 0)

  def test_export_tflite(self):
    """export_tflite() must create a non-empty model.tflite in export_dir."""
    export_dir = os.path.join(self.get_temp_dir(), 'export/')
    self.model.export_tflite(export_dir=export_dir)
    tflite_path = os.path.join(export_dir, 'model.tflite')
    self._check_nonempty_file(tflite_path)


if __name__ == '__main__':
  tf.test.main()

View File

@ -0,0 +1,19 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This package declares no targets of its own; it only sets package defaults.
# Anything defined here later is visible to everything under //mediapipe.
package(
    default_visibility = ["//mediapipe:__subpackages__"],
)

licenses(["notice"])

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,111 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python library rule.
# Placeholder for internal Python strict library and test compatibility macro.
licenses(["notice"])

# Targets in this package are visible to everything under //mediapipe.
package(
    default_visibility = ["//mediapipe:__subpackages__"],
)

# Package entry point: __init__.py re-exports names from the four libraries
# listed in deps.
py_library(
    name = "image_classifier_import",
    srcs = ["__init__.py"],
    deps = [
        ":dataset",
        ":hyperparameters",
        ":image_classifier",
        ":model_spec",
    ],
)

# Library for model_spec.py.
py_library(
    name = "model_spec",
    srcs = ["model_spec.py"],
)

py_test(
    name = "model_spec_test",
    srcs = ["model_spec_test.py"],
    deps = [":model_spec"],
)

# Image classification dataset (dataset.py), built on the core
# classification_dataset library.
py_library(
    name = "dataset",
    srcs = ["dataset.py"],
    deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"],
)

py_test(
    name = "dataset_test",
    srcs = ["dataset_test.py"],
    deps = [":dataset"],
)

# Training hyperparameters (hyperparameters.py defines the HParams dataclass).
py_library(
    name = "hyperparameters",
    srcs = ["hyperparameters.py"],
)

# Library for train_image_classifier_lib.py.
py_library(
    name = "train_image_classifier_lib",
    srcs = ["train_image_classifier_lib.py"],
    deps = [
        ":hyperparameters",
        "//mediapipe/model_maker/python/core/utils:model_util",
    ],
)

# The ImageClassifier task (image_classifier.py).
py_library(
    name = "image_classifier",
    srcs = ["image_classifier.py"],
    deps = [
        ":hyperparameters",
        ":model_spec",
        ":train_image_classifier_lib",
        "//mediapipe/model_maker/python/core/data:classification_dataset",
        "//mediapipe/model_maker/python/core/tasks:classifier",
        "//mediapipe/model_maker/python/core/utils:image_preprocessing",
        "//mediapipe/model_maker/python/core/utils:model_util",
        "//mediapipe/model_maker/python/core/utils:quantization",
    ],
)

# Test sources packaged as a test-only library; the py_test below depends on it
# and lists the same source file.
py_library(
    name = "image_classifier_test_lib",
    testonly = 1,
    srcs = ["image_classifier_test.py"],
    deps = [":image_classifier_import"],
)

# Sharded test, tagged as requiring external network access.
# NOTE(review): presumably the network is needed to fetch pretrained models -
# confirm against image_classifier_test.py.
py_test(
    name = "image_classifier_test",
    srcs = ["image_classifier_test.py"],
    shard_count = 2,
    tags = ["requires-net:external"],
    deps = [
        ":image_classifier_test_lib",
    ],
)

# Demo binary for the image classifier.
py_binary(
    name = "image_classifier_demo",
    srcs = ["image_classifier_demo.py"],
    deps = [
        ":image_classifier_import",
        "//mediapipe/model_maker/python/core/utils:quantization",
    ],
)

View File

@ -0,0 +1,25 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe Model Maker Python Public API For Image Classifier."""
from mediapipe.model_maker.python.vision.image_classifier import dataset
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
from mediapipe.model_maker.python.vision.image_classifier import image_classifier
from mediapipe.model_maker.python.vision.image_classifier import model_spec
# Public aliases: re-export the main classes of the submodules at package
# level so that callers can use e.g. `image_classifier.ImageClassifier`
# without importing the individual submodules.
ImageClassifier = image_classifier.ImageClassifier
HParams = hyperparameters.HParams
Dataset = dataset.Dataset
ModelSpec = model_spec.ModelSpec
SupportedModels = model_spec.SupportedModels

View File

@ -0,0 +1,139 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image classifier dataset library."""
import os
import random
from typing import List, Optional, Tuple
import tensorflow as tf
import tensorflow_datasets as tfds
from mediapipe.model_maker.python.core.data import classification_dataset
def _load_image(path: str) -> tf.Tensor:
  """Reads the file at `path` and decodes it into a 3-channel image tensor.

  The file is decoded as JPEG when `tf.image.is_jpeg` reports a JPEG payload
  and as PNG otherwise.

  Args:
    path: Path of the image file to read.

  Returns:
    The decoded image tensor with 3 channels.
  """
  raw_bytes = tf.io.read_file(path)

  def _decode_as_jpeg():
    return tf.image.decode_jpeg(raw_bytes, channels=3)

  def _decode_as_png():
    return tf.image.decode_png(raw_bytes, channels=3)

  return tf.cond(tf.image.is_jpeg(raw_bytes), _decode_as_jpeg, _decode_as_png)
def _create_data(
    name: str, data: tf.data.Dataset, info: tfds.core.DatasetInfo,
    label_names: List[str]
) -> Optional[classification_dataset.ClassificationDataset]:
  """Builds a `Dataset` for one named split of loaded tfds data.

  Args:
    name: Name of the split to extract, used as a key into `data` and
      `info.splits`.
    data: Container of the loaded splits, looked up by split name.
    info: The tfds dataset info; supplies the number of examples of the split.
    label_names: Label names passed through to the created `Dataset`.

  Returns:
    A `Dataset` of (image, label) pairs for the split, or None when `name` is
    not present in `data`.
  """
  if name not in data:
    return None

  # Keep only the image and label of every example, as a tuple.
  split_ds = data[name].map(
      lambda example: (example['image'], example['label']))
  num_examples = info.splits[name].num_examples
  return Dataset(split_ds, num_examples, label_names)
class Dataset(classification_dataset.ClassificationDataset):
  """Dataset library for image classifier."""

  @classmethod
  def from_folder(
      cls,
      dirname: str,
      shuffle: bool = True) -> classification_dataset.ClassificationDataset:
    """Loads images and labels from the given directory.

    Assume the image data of the same label are in the same subdirectory.

    Args:
      dirname: Name of the directory containing the data files.
      shuffle: boolean, if shuffle, random shuffle data.

    Returns:
      Dataset containing images and labels and other related info. The
      instance is created through `cls`, so subclasses get their own type.

    Raises:
      ValueError: if the input data directory is empty.
    """
    data_root = os.path.abspath(dirname)

    # Assumes the image data of the same label are in the same subdirectory,
    # gets image path and label names.
    all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
    all_image_size = len(all_image_paths)
    if all_image_size == 0:
      raise ValueError('Image size is zero')

    if shuffle:
      # Random shuffle data. Labels are derived from the shuffled paths below,
      # so images and labels stay aligned.
      random.shuffle(all_image_paths)

    # Every subdirectory name is a label; sorting makes the label index stable.
    label_names = sorted(
        name for name in os.listdir(data_root)
        if os.path.isdir(os.path.join(data_root, name)))
    all_label_size = len(label_names)
    label_to_index = {name: index for index, name in enumerate(label_names)}
    # The label of an image is the name of its parent directory.
    all_image_labels = [
        label_to_index[os.path.basename(os.path.dirname(path))]
        for path in all_image_paths
    ]

    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)

    autotune = tf.data.AUTOTUNE
    image_ds = path_ds.map(_load_image, num_parallel_calls=autotune)

    # Loads label.
    label_ds = tf.data.Dataset.from_tensor_slices(
        tf.cast(all_image_labels, tf.int64))

    # Creates a dataset of (image, label) pairs.
    image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))

    tf.compat.v1.logging.info(
        'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
        all_label_size, ', '.join(label_names))
    # Use `cls` rather than the hard-coded class so that this alternate
    # constructor also works for subclasses.
    return cls(image_label_ds, all_image_size, label_names)

  @classmethod
  def load_tf_dataset(
      cls, name: str
  ) -> Tuple[Optional[classification_dataset.ClassificationDataset],
             Optional[classification_dataset.ClassificationDataset],
             Optional[classification_dataset.ClassificationDataset]]:
    """Loads data from tensorflow_datasets.

    Args:
      name: the registered name of the tfds.core.DatasetBuilder. Refer to the
        documentation of tfds.load for more details.

    Returns:
      A tuple of Datasets for the train/validation/test splits. An element is
      None when the loaded data does not contain the corresponding split.

    Raises:
      ValueError: if the features of the tf dataset do not contain a 'label'
        key.
    """
    data, info = tfds.load(name, with_info=True)
    if 'label' not in info.features:
      raise ValueError('info.features need to contain \'label\' key.')
    label_names = info.features['label'].names

    train_data = _create_data('train', data, info, label_names)
    validation_data = _create_data('validation', data, info, label_names)
    test_data = _create_data('test', data, info, label_names)
    return train_data, validation_data, test_data

View File

@ -0,0 +1,108 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.vision.image_classifier import dataset
def _fill_image(rgb, image_size):
r, g, b = rgb
return np.broadcast_to(
np.array([[[r, g, b]]], dtype=np.uint8),
shape=(image_size, image_size, 3))
def _write_filled_jpeg_file(path, rgb, image_size):
  """Saves a single-colour square image as a JPEG file at `path`.

  Args:
    path: Destination file path.
    rgb: Sequence of three channel values used to fill the image.
    image_size: Height and width of the image.
  """
  filled_image = _fill_image(rgb, image_size)
  tf.keras.preprocessing.image.save_img(path, filled_image, 'channels_last',
                                        'jpeg')
class DatasetTest(tf.test.TestCase):
  """Tests for the image classifier `dataset.Dataset`."""

  def setUp(self):
    """Creates a two-class image folder with one random-colour JPEG each."""
    super().setUp()
    self.image_path = os.path.join(self.get_temp_dir(), 'random_image_dir')
    # Reuse the folder when it has already been created.
    if os.path.exists(self.image_path):
      return
    os.mkdir(self.image_path)
    for class_name in ('daisy', 'tulips'):
      class_subdir = os.path.join(self.image_path, class_name)
      os.mkdir(class_subdir)
      _write_filled_jpeg_file(
          os.path.join(class_subdir, '0.jpeg'),
          [random.uniform(0, 255) for _ in range(3)], 224)

  def test_split(self):
    """split(0.5) must halve the data in order and keep the label info."""
    ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
    data = dataset.Dataset(ds, 4, ['pos', 'neg'])
    train_data, test_data = data.split(0.5)

    self.assertLen(train_data, 2)
    for i, elem in enumerate(train_data._dataset):
      self.assertTrue((elem.numpy() == np.array([i, 1])).all())
    self.assertEqual(train_data.num_classes, 2)
    self.assertEqual(train_data.index_to_label, ['pos', 'neg'])

    self.assertLen(test_data, 2)
    for i, elem in enumerate(test_data._dataset):
      self.assertTrue((elem.numpy() == np.array([i, 0])).all())
    self.assertEqual(test_data.num_classes, 2)
    self.assertEqual(test_data.index_to_label, ['pos', 'neg'])

  def test_from_folder(self):
    """from_folder() must pair every image with the label of its folder."""
    data = dataset.Dataset.from_folder(self.image_path)

    self.assertLen(data, 2)
    self.assertEqual(data.num_classes, 2)
    self.assertEqual(data.index_to_label, ['daisy', 'tulips'])
    for image, label in data.gen_tf_dataset():
      label_index = label.numpy()
      self.assertIn(label_index, (0, 1))
      # Index 0 is 'daisy' and index 1 is 'tulips' (sorted folder names).
      class_dir = 'daisy' if label_index == 0 else 'tulips'
      raw_image_tensor = dataset._load_image(
          os.path.join(self.image_path, class_dir, '0.jpeg'))
      self.assertTrue((image.numpy() == raw_image_tensor.numpy()).all())

  def test_from_tfds(self):
    """load_tf_dataset() must return the three splits of the 'beans' data."""
    # TODO: Remove this once tfds download error is fixed.
    self.skipTest('Temporarily skip the unittest due to tfds download error.')
    # The implementation exposes this loader as `load_tf_dataset`; there is no
    # `from_tfds` method on `dataset.Dataset`.
    train_data, validation_data, test_data = (
        dataset.Dataset.load_tf_dataset('beans'))
    self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset)
    self.assertLen(train_data, 1034)
    self.assertEqual(train_data.num_classes, 3)
    self.assertEqual(train_data.index_to_label,
                     ['angular_leaf_spot', 'bean_rust', 'healthy'])

    self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset)
    self.assertLen(validation_data, 133)
    self.assertEqual(validation_data.num_classes, 3)
    self.assertEqual(validation_data.index_to_label,
                     ['angular_leaf_spot', 'bean_rust', 'healthy'])

    self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset)
    self.assertLen(test_data, 128)
    self.assertEqual(test_data.num_classes, 3)
    self.assertEqual(test_data.index_to_label,
                     ['angular_leaf_spot', 'bean_rust', 'healthy'])


if __name__ == '__main__':
  tf.test.main()

View File

@ -0,0 +1,74 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hyperparameters for training image classification models."""
import dataclasses
import tempfile
from typing import Optional
# TODO: Expose other hyperparameters, e.g. data augmentation
# hyperparameters if requested.
@dataclasses.dataclass
class HParams:
  """The hyperparameters for training image classifiers.

  The hyperparameters include:
    # Parameters about training data.
    do_fine_tuning: If true, the base module is trained together with the
      classification layer on top.
    shuffle: A boolean controlling if shuffle the dataset. Default to false.

    # Parameters about training configuration
    train_epochs: Training will do this many iterations over the dataset.
    batch_size: Each training step samples a batch of this many images.
    learning_rate: The learning rate to use for gradient descent training.
    dropout_rate: The fraction of the input units to drop, used in dropout
      layer.
    l1_regularizer: A regularizer that applies a L1 regularization penalty.
    l2_regularizer: A regularizer that applies a L2 regularization penalty.
    label_smoothing: Amount of label smoothing to apply. See tf.keras.losses for
      more details.
    do_data_augmentation: A boolean controlling whether the training dataset is
      augmented by randomly distorting input images, including random cropping,
      flipping, etc. See utils.image_preprocessing documentation for details.
    steps_per_epoch: An optional integer indicate the number of training steps
      per epoch. If not set, the training pipeline calculates the default steps
      per epoch as the training dataset size divided by batch size.
    decay_samples: Number of training samples used to calculate the decay steps
      and create the training optimizer.
    warmup_epochs: Number of warmup epochs for a linear increasing warmup
      schedule on learning rate. Used to set up warmup schedule by
      model_util.WarmUp.

    # Parameters about the saved checkpoint
    model_dir: The location of model checkpoint files and exported model files.
      When not given, a fresh temporary directory is created for each
      instance.
  """
  # Parameters about training data
  do_fine_tuning: bool = False
  shuffle: bool = False
  # Parameters about training configuration
  train_epochs: int = 5
  batch_size: int = 32
  learning_rate: float = 0.005
  dropout_rate: float = 0.2
  l1_regularizer: float = 0.0
  l2_regularizer: float = 0.0001
  label_smoothing: float = 0.1
  do_data_augmentation: bool = True
  steps_per_epoch: Optional[int] = None
  decay_samples: int = 10000 * 256
  warmup_epochs: int = 2

  # Parameters about the saved checkpoint
  # A default_factory defers directory creation to instantiation time. A plain
  # `tempfile.mkdtemp()` default would run once at import time and make every
  # default-constructed instance share the same directory.
  model_dir: str = dataclasses.field(default_factory=tempfile.mkdtemp)

View File

@ -0,0 +1,172 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""APIs to train image classifier model."""
from typing import Any, List, Optional
import tensorflow as tf
import tensorflow_hub as hub
from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import image_preprocessing
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
class ImageClassifier(classifier.Classifier):
"""ImageClassifier for building image classification model."""
def __init__(self, model_spec: ms.ModelSpec, index_to_label: List[Any],
hparams: hp.HParams):
"""Initializes ImageClassifier class.
Args:
model_spec: Specification for the model.
index_to_label: A list that maps from index to label class name.
hparams: The hyperparameters for training image classifier.
"""
super(ImageClassifier, self).__init__(
model_spec=model_spec,
index_to_label=index_to_label,
shuffle=hparams.shuffle,
full_train=hparams.do_fine_tuning)
self._hparams = hparams
self._preprocess = image_preprocessing.Preprocessor(
input_shape=self._model_spec.input_image_shape,
num_classes=self._num_classes,
mean_rgb=self._model_spec.mean_rgb,
stddev_rgb=self._model_spec.stddev_rgb,
use_augmentation=hparams.do_data_augmentation)
self._history = None # Training history returned from `keras_model.fit`.
@classmethod
def create(
cls,
model_spec: ms.SupportedModels,
train_data: classification_ds.ClassificationDataset,
validation_data: classification_ds.ClassificationDataset,
hparams: Optional[hp.HParams] = None,
) -> 'ImageClassifier':
"""Creates and trains an image classifier.
Loads data and trains the model based on data for image classification.
Args:
model_spec: Specification for the model.
train_data: Training data.
validation_data: Validation data.
hparams: Hyperparameters for training image classifier.
Returns:
An instance based on ImageClassifier.
"""
if hparams is None:
hparams = hp.HParams()
spec = ms.SupportedModels.get(model_spec)
image_classifier = cls(
model_spec=spec,
index_to_label=train_data.index_to_label,
hparams=hparams)
image_classifier._create_model()
tf.compat.v1.logging.info('Training the models...')
image_classifier._train(
train_data=train_data, validation_data=validation_data)
return image_classifier
def _train(self, train_data: classification_ds.ClassificationDataset,
validation_data: classification_ds.ClassificationDataset):
"""Trains the model with input train_data.
The training results are recorded by a self._history object returned by
tf.keras.Model.fit().
Args:
train_data: Training data.
validation_data: Validation data.
"""
tf.compat.v1.logging.info('Training the models...')
hparams = self._hparams
if len(train_data) < hparams.batch_size:
raise ValueError('The size of the train_data (%d) couldn\'t be smaller '
'than batch_size (%d). To solve this problem, set '
'the batch_size smaller or increase the size of the '
'train_data.' % (len(train_data), hparams.batch_size))
train_dataset = train_data.gen_tf_dataset(
batch_size=hparams.batch_size,
is_training=True,
shuffle=self._shuffle,
preprocess=self._preprocess)
hparams.steps_per_epoch = model_util.get_steps_per_epoch(
steps_per_epoch=hparams.steps_per_epoch,
batch_size=hparams.batch_size,
train_data=train_data)
train_dataset = train_dataset.take(count=hparams.steps_per_epoch)
validation_dataset = validation_data.gen_tf_dataset(
batch_size=hparams.batch_size,
is_training=False,
preprocess=self._preprocess)
# Train the model.
self._history = train_image_classifier_lib.train_model(
model=self._model,
hparams=hparams,
train_ds=train_dataset,
validation_ds=validation_dataset)
def _create_model(self):
  """Creates the classifier model from TFHub pretrained models.

  Stacks the hub layer loaded from the model spec URI, a dropout layer and a
  softmax dense layer with self._num_classes units. The result is stored in
  self._model and its summary is printed.
  """
  module_layer = hub.KerasLayer(
      handle=self._model_spec.uri, trainable=self._hparams.do_fine_tuning)

  image_size = self._model_spec.input_image_shape

  self._model = tf.keras.Sequential([
      tf.keras.Input(shape=(image_size[0], image_size[1], 3)), module_layer,
      tf.keras.layers.Dropout(rate=self._hparams.dropout_rate),
      tf.keras.layers.Dense(
          units=self._num_classes,
          activation='softmax',
          kernel_regularizer=tf.keras.regularizers.l1_l2(
              l1=self._hparams.l1_regularizer,
              l2=self._hparams.l2_regularizer))
  ])
  # Model.summary() prints the table itself and returns None, so wrapping it in
  # print() would emit an extra "None" line after the summary.
  self._model.summary()
def export_model(
    self,
    model_name: str = 'model.tflite',
    quantization_config: Optional[quantization.QuantizationConfig] = None):
  """Exports the model as a tflite file through the parent's export_tflite().

  Args:
    model_name: File name to save tflite model. It is handed to
      export_tflite() together with the model_dir hyperparameter, which
      serves as the export directory.
    quantization_config: The configuration for model quantization.
  """
  export_dir = self._hparams.model_dir
  super().export_tflite(
      export_dir,
      model_name,
      quantization_config,
      preprocess=self._preprocess)

View File

@ -0,0 +1,106 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Demo for making an image classifier model by MediaPipe Model Maker."""
import os
# Dependency imports
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.vision import image_classifier
# Shorthand for absl's flag container; main() reads the demo's flags from it.
FLAGS = flags.FLAGS
def define_flags() -> None:
  """Registers the command-line flags read by the image classifier demo."""
  input_data_help = """The directory with input training data. If the training data is not
specified, the pipeline will download a default training dataset."""

  flags.DEFINE_string(
      'export_dir', None, 'The directory to save exported files.')
  flags.DEFINE_string('input_data_dir', None, input_data_help)
  flags.DEFINE_enum_class(
      'spec',
      image_classifier.SupportedModels.EFFICIENTNET_LITE0,
      image_classifier.SupportedModels,
      'The image classifier to run.')
  flags.DEFINE_enum(
      'quantization',
      None,
      ['dynamic', 'int8', 'float16'],
      'The quantization method to use when exporting the model.')

  # export_dir has no default, so the demo refuses to start without it.
  flags.mark_flag_as_required('export_dir')
def download_demo_data() -> str:
  """Fetches the flower_photos archive and returns the data directory path."""
  archive_path = tf.keras.utils.get_file(
      fname='flower_photos.tgz',
      origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
      extract=True)
  # The 'flower_photos' folder is looked up next to the downloaded file.
  download_root = os.path.dirname(archive_path)
  return os.path.join(download_root, 'flower_photos')
def run(data_dir: str, export_dir: str,
        model_spec: image_classifier.SupportedModels,
        quantization_option: str) -> None:
  """Trains, evaluates and exports an image classifier for the demo.

  Args:
    data_dir: Folder with the input images, read by Dataset.from_folder().
    export_dir: Directory used as the model_dir hyperparameter and passed to
      export_labels().
    model_spec: The supported model to train.
    quantization_option: None for no quantization, or one of 'dynamic', 'int8'
      and 'float16'.

  Raises:
    ValueError: If quantization_option is not one of the values listed above.
  """
  data = image_classifier.Dataset.from_folder(data_dir)
  # 80% train; the remaining 20% is halved into validation and test.
  train_data, rest_data = data.split(0.8)
  validation_data, test_data = rest_data.split(0.5)

  model = image_classifier.ImageClassifier.create(
      model_spec=model_spec,
      train_data=train_data,
      validation_data=validation_data,
      hparams=image_classifier.HParams(model_dir=export_dir))

  _, acc = model.evaluate(test_data)
  print('Test accuracy: %f' % acc)

  if quantization_option is None:
    quantization_config = None
  elif quantization_option == 'dynamic':
    quantization_config = quantization.QuantizationConfig.for_dynamic()
  elif quantization_option == 'int8':
    quantization_config = quantization.QuantizationConfig.for_int8(train_data)
  elif quantization_option == 'float16':
    quantization_config = quantization.QuantizationConfig.for_float16()
  else:
    # Report the rejected option itself; the previous message interpolated the
    # imported `quantization` module by mistake.
    raise ValueError(
        f'Quantization: {quantization_option} is not recognized')

  model.export_model(quantization_config=quantization_config)
  model.export_labels(export_dir)
def main(_) -> None:
  """Resolves the data and export directories from flags and runs the demo."""
  logging.set_verbosity(logging.INFO)

  # Fall back to the downloadable demo dataset when no input folder is given.
  data_dir = FLAGS.input_data_dir
  if data_dir is None:
    data_dir = download_demo_data()

  run(
      data_dir=data_dir,
      export_dir=os.path.expanduser(FLAGS.export_dir),
      model_spec=FLAGS.spec,
      quantization_option=FLAGS.quantization)
# Flags are registered only when this file runs as a script, right before
# app.run(main) is invoked.
if __name__ == '__main__':
define_flags()
app.run(main)

View File

@ -0,0 +1,122 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.vision import image_classifier
def _fill_image(rgb, image_size):
  """Returns a square uint8 image in which every pixel equals `rgb`.

  Args:
    rgb: Sequence of exactly three channel values (red, green, blue).
    image_size: Height and width of the image.

  Returns:
    A broadcast array of shape (image_size, image_size, 3).
  """
  red, green, blue = rgb
  pixel = np.array([[[red, green, blue]]], dtype=np.uint8)
  target_shape = (image_size, image_size, 3)
  return np.broadcast_to(pixel, shape=target_shape)
class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
  """Trains supported models on a tiny synthetic cyan/magenta/yellow dataset."""

  IMAGE_SIZE = 24
  IMAGES_PER_CLASS = 2
  CMY_NAMES_AND_RGB_VALUES = (
      ('cyan', (0, 255, 255)),
      ('magenta', (255, 0, 255)),
      ('yellow', (255, 255, 0)),
  )

  def _gen(self):
    """Yields (image, class index) pairs, IMAGES_PER_CLASS per colour."""
    for class_index, entry in enumerate(self.CMY_NAMES_AND_RGB_VALUES):
      rgb = entry[1]
      for _ in range(self.IMAGES_PER_CLASS):
        yield (_fill_image(rgb, self.IMAGE_SIZE), class_index)

  def _gen_cmy_data(self):
    """Wraps the generated images in an image_classifier.Dataset."""
    image_shape = tf.TensorShape([self.IMAGE_SIZE, self.IMAGE_SIZE, 3])
    label_shape = tf.TensorShape([])
    ds = tf.data.Dataset.from_generator(
        self._gen, (tf.uint8, tf.int64), (image_shape, label_shape))
    label_names = [name for name, _ in self.CMY_NAMES_AND_RGB_VALUES]
    return image_classifier.Dataset(
        ds, self.IMAGES_PER_CLASS * len(label_names), label_names)

  def setUp(self):
    super().setUp()
    # Splits data, 90% data for training, 10% for testing
    self.train_data, self.test_data = self._gen_cmy_data().split(0.9)

  # One test case per supported model, all sharing the same minimal training
  # configuration; each case gets its own HParams instance.
  @parameterized.named_parameters(*(
      dict(
          testcase_name=testcase_name,
          model_spec=supported_model,
          hparams=image_classifier.HParams(
              train_epochs=1, batch_size=1, shuffle=True))
      for testcase_name, supported_model in (
          ('mobilenet_v2', image_classifier.SupportedModels.MOBILENET_V2),
          ('resnet_50', image_classifier.SupportedModels.RESNET_50),
          ('efficientnet_lite0',
           image_classifier.SupportedModels.EFFICIENTNET_LITE0),
          ('efficientnet_lite1',
           image_classifier.SupportedModels.EFFICIENTNET_LITE1),
          ('efficientnet_lite2',
           image_classifier.SupportedModels.EFFICIENTNET_LITE2),
          ('efficientnet_lite3',
           image_classifier.SupportedModels.EFFICIENTNET_LITE3),
          ('efficientnet_lite4',
           image_classifier.SupportedModels.EFFICIENTNET_LITE4),
      )))
  def test_create_and_train_model(self,
                                  model_spec: image_classifier.SupportedModels,
                                  hparams: image_classifier.HParams):
    trained = image_classifier.ImageClassifier.create(
        model_spec=model_spec,
        train_data=self.train_data,
        hparams=hparams,
        validation_data=self.test_data)
    self._test_accuracy(trained)

  def test_efficientnetlite0_model_with_model_maker_retraining_lib(self):
    trained = image_classifier.ImageClassifier.create(
        model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
        train_data=self.train_data,
        hparams=image_classifier.HParams(
            train_epochs=1, batch_size=1, shuffle=True),
        validation_data=self.test_data)
    self._test_accuracy(trained)

  def _test_accuracy(self, model, threshold=0.0):
    """Asserts the model's accuracy on the test split reaches `threshold`."""
    _loss, accuracy = model.evaluate(self.test_data)
    self.assertGreaterEqual(accuracy, threshold)
if __name__ == '__main__':
# Load compressed models from tensorflow_hub. The variable is set before
# tf.test.main() starts running the tests.
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
tf.test.main()

View File

@ -0,0 +1,104 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image classifier model specification."""
import enum
import functools
from typing import List, Optional
class ModelSpec(object):
  """Specification of image classifier model."""

  # Class-level values shared by all instances. Presumably the per-channel
  # mean/stddev used to normalize input pixels -- confirm against the
  # preprocessing code.
  mean_rgb = [0.0]
  stddev_rgb = [255.0]

  def __init__(self,
               uri: str,
               input_image_shape: Optional[List[int]] = None,
               name: str = ''):
    """Initializes a new instance of the `ModelSpec` class.

    Args:
      uri: str, URI to the pretrained model.
      input_image_shape: list of int, input image shape. Default: [224, 224].
      name: str, model spec name.
    """
    self.uri = uri
    self.name = name

    if input_image_shape is None:
      input_image_shape = [224, 224]
    # Store a private copy: the predefined specs bind a single list object in a
    # functools.partial, so without the copy every instance created from the
    # same partial (and the caller) would share one mutable list.
    self.input_image_shape = list(input_image_shape)
# Predefined model specs. Each name below is a functools.partial over
# ModelSpec, so calling it with no arguments builds a new ModelSpec with the
# bound URI and name. Specs that do not bind input_image_shape fall back to
# ModelSpec's default of [224, 224].
mobilenet_v2_spec = functools.partial(
ModelSpec,
uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4',
name='mobilenet_v2')
resnet_50_spec = functools.partial(
ModelSpec,
uri='https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4',
name='resnet_50')
efficientnet_lite0_spec = functools.partial(
ModelSpec,
uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2',
name='efficientnet_lite0')
# The larger EfficientNet-Lite variants each bind their own input image shape.
efficientnet_lite1_spec = functools.partial(
ModelSpec,
uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2',
input_image_shape=[240, 240],
name='efficientnet_lite1')
efficientnet_lite2_spec = functools.partial(
ModelSpec,
uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2',
input_image_shape=[260, 260],
name='efficientnet_lite2')
efficientnet_lite3_spec = functools.partial(
ModelSpec,
uri='https://tfhub.dev/tensorflow/efficientnet/lite3/feature-vector/2',
input_image_shape=[280, 280],
name='efficientnet_lite3')
efficientnet_lite4_spec = functools.partial(
ModelSpec,
uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2',
input_image_shape=[300, 300],
name='efficientnet_lite4')
# functools.partial objects are becoming method descriptors in newer Python
# versions (Python 3.13 already warns about it), and descriptors assigned in an
# Enum body are treated as methods rather than members. enum.member()
# (Python 3.11+) keeps them as members; on older versions partials are plain
# values and need no wrapping.
_as_member = getattr(enum, 'member', lambda value: value)


# TODO: Document the exposed models.
@enum.unique
class SupportedModels(enum.Enum):
  """Image classifier model supported by model maker.

  The value of every member is a zero-argument factory (a functools.partial of
  ModelSpec); use `SupportedModels.get()` to obtain the initialized spec.
  """
  MOBILENET_V2 = _as_member(mobilenet_v2_spec)
  RESNET_50 = _as_member(resnet_50_spec)
  EFFICIENTNET_LITE0 = _as_member(efficientnet_lite0_spec)
  EFFICIENTNET_LITE1 = _as_member(efficientnet_lite1_spec)
  EFFICIENTNET_LITE2 = _as_member(efficientnet_lite2_spec)
  EFFICIENTNET_LITE3 = _as_member(efficientnet_lite3_spec)
  EFFICIENTNET_LITE4 = _as_member(efficientnet_lite4_spec)

  @classmethod
  def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
    """Gets model spec from the input enum and initializes it.

    Args:
      spec: A member of this enum.

    Returns:
      The ModelSpec built by calling the member's factory.

    Raises:
      TypeError: If `spec` is not a member of this enum.
    """
    # isinstance() instead of `in`: on Python 3.12+ `in` also accepts raw
    # member values, which have no `.value` attribute, and on older versions it
    # raises its own less helpful TypeError for non-enum arguments.
    if not isinstance(spec, cls):
      raise TypeError('Unsupported image classifier spec: {}'.format(spec))
    return spec.value()

View File

@ -0,0 +1,93 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Callable, List
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
class ModelSpecTest(tf.test.TestCase, parameterized.TestCase):
  """Checks the predefined spec factories and direct ModelSpec construction."""

  # Rows: (test case name, spec factory, expected URI, expected name,
  # expected input image shape).
  @parameterized.named_parameters(*(
      dict(
          testcase_name=testcase_name,
          model_spec=spec_factory,
          expected_uri=uri,
          expected_name=spec_name,
          expected_input_image_shape=image_shape)
      for testcase_name, spec_factory, uri, spec_name, image_shape in (
          ('mobilenet_v2_spec_test', ms.mobilenet_v2_spec,
           'https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4',
           'mobilenet_v2', [224, 224]),
          ('resnet_50_spec_test', ms.resnet_50_spec,
           'https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4',
           'resnet_50', [224, 224]),
          ('efficientnet_lite0_spec_test', ms.efficientnet_lite0_spec,
           'https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2',
           'efficientnet_lite0', [224, 224]),
          ('efficientnet_lite1_spec_test', ms.efficientnet_lite1_spec,
           'https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2',
           'efficientnet_lite1', [240, 240]),
          ('efficientnet_lite2_spec_test', ms.efficientnet_lite2_spec,
           'https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2',
           'efficientnet_lite2', [260, 260]),
          ('efficientnet_lite3_spec_test', ms.efficientnet_lite3_spec,
           'https://tfhub.dev/tensorflow/efficientnet/lite3/feature-vector/2',
           'efficientnet_lite3', [280, 280]),
          ('efficientnet_lite4_spec_test', ms.efficientnet_lite4_spec,
           'https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2',
           'efficientnet_lite4', [300, 300]),
      )))
  def test_predefiend_spec(self, model_spec: Callable[..., ms.ModelSpec],
                           expected_uri: str, expected_name: str,
                           expected_input_image_shape: List[int]):
    built_spec = model_spec()
    self.assertIsInstance(built_spec, ms.ModelSpec)
    self.assertEqual(built_spec.uri, expected_uri)
    self.assertEqual(built_spec.name, expected_name)
    self.assertEqual(built_spec.input_image_shape,
                     expected_input_image_shape)

  def test_create_spec(self):
    spec = ms.ModelSpec(
        uri='https://custom_model',
        input_image_shape=[128, 128],
        name='custom_model')
    self.assertEqual(spec.uri, 'https://custom_model')
    self.assertEqual(spec.name, 'custom_model')
    self.assertEqual(spec.input_image_shape, [128, 128])
if __name__ == '__main__':
# Load compressed models from tensorflow_hub. The variable is set before
# tf.test.main() starts running the tests.
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
tf.test.main()

View File

@ -0,0 +1,103 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to train model."""
import os
from typing import List
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
def _create_optimizer(init_lr: float, decay_steps: int,
                      warmup_steps: int) -> tf.keras.optimizers.Optimizer:
  """Builds an RMSprop optimizer driven by a cosine-decay learning rate.

  When warmup_steps is non-zero, the cosine schedule is wrapped in
  model_util.WarmUp.

  Args:
    init_lr: Initial learning rate.
    decay_steps: Number of steps to decay over.
    warmup_steps: Number of steps to do warmup for.

  Returns:
    A tf.keras.optimizers.Optimizer for model training.
  """
  schedule = tf.keras.experimental.CosineDecay(
      initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0)
  if warmup_steps:
    schedule = model_util.WarmUp(
        initial_learning_rate=init_lr,
        decay_schedule_fn=schedule,
        warmup_steps=warmup_steps)
  return tf.keras.optimizers.RMSprop(
      learning_rate=schedule, rho=0.9, momentum=0.9, epsilon=0.001)
def _get_default_callbacks(model_dir: str) -> List[tf.keras.callbacks.Callback]:
  """Returns the TensorBoard and checkpoint callbacks rooted at model_dir."""
  tensorboard_cb = tf.keras.callbacks.TensorBoard(
      os.path.join(model_dir, 'summaries'))
  # Save checkpoint every 20 epochs.
  checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
      os.path.join(model_dir, 'checkpoint'),
      save_weights_only=True,
      period=20)
  return [tensorboard_cb, checkpoint_cb]
def train_model(model: tf.keras.Model, hparams: hp.HParams,
                train_ds: tf.data.Dataset,
                validation_ds: tf.data.Dataset) -> tf.keras.callbacks.History:
  """Compiles the model and fits it with the given hyperparameters.

  Args:
    model: Input tf.keras.Model.
    hparams: Hyperparameters for training image classifier.
    train_ds: tf.data.Dataset, training data to be fed in tf.keras.Model.fit().
    validation_ds: tf.data.Dataset, validation data to be fed in
      tf.keras.Model.fit().

  Returns:
    The tf.keras.callbacks.History object returned by tf.keras.Model.fit().
  """
  # Learning rate is linear to batch size.
  scaled_lr = hparams.learning_rate * hparams.batch_size / 256

  # Decay over whichever is longer: the whole training run or the default
  # number of steps derived from decay_samples.
  steps_in_run = hparams.steps_per_epoch * hparams.train_epochs
  steps_from_samples = hparams.decay_samples // hparams.batch_size
  decay_steps = max(steps_in_run, steps_from_samples)
  warmup_steps = hparams.warmup_epochs * hparams.steps_per_epoch

  model.compile(
      optimizer=_create_optimizer(
          init_lr=scaled_lr,
          decay_steps=decay_steps,
          warmup_steps=warmup_steps),
      loss=tf.keras.losses.CategoricalCrossentropy(
          label_smoothing=hparams.label_smoothing),
      metrics=['accuracy'])

  # Train the model.
  fit_callbacks = _get_default_callbacks(hparams.model_dir)
  return model.fit(
      x=train_ds,
      epochs=hparams.train_epochs,
      steps_per_epoch=hparams.steps_per_epoch,
      validation_data=validation_ds,
      callbacks=fit_callbacks)

View File

@ -2,3 +2,5 @@ absl-py
numpy numpy
opencv-contrib-python opencv-contrib-python
tensorflow tensorflow
tensorflow-datasets
tensorflow-hub