diff --git a/mediapipe/model_maker/python/core/tasks/BUILD b/mediapipe/model_maker/python/core/tasks/BUILD new file mode 100644 index 000000000..b3588f0be --- /dev/null +++ b/mediapipe/model_maker/python/core/tasks/BUILD @@ -0,0 +1,64 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library and test compatibility macro. + +package( + default_visibility = ["//mediapipe:__subpackages__"], +) + +licenses(["notice"]) + +py_library( + name = "custom_model", + srcs = ["custom_model.py"], + srcs_version = "PY3", + deps = [ + "//mediapipe/model_maker/python/core/data:dataset", + "//mediapipe/model_maker/python/core/utils:model_util", + "//mediapipe/model_maker/python/core/utils:quantization", + ], +) + +py_test( + name = "custom_model_test", + srcs = ["custom_model_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":custom_model", + "//mediapipe/model_maker/python/core/utils:test_util", + ], +) + +py_library( + name = "classifier", + srcs = ["classifier.py"], + srcs_version = "PY3", + deps = [ + ":custom_model", + "//mediapipe/model_maker/python/core/data:dataset", + ], +) + +py_test( + name = "classifier_test", + srcs = ["classifier_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":classifier", + "//mediapipe/model_maker/python/core/utils:test_util", + ], +) diff --git a/mediapipe/model_maker/python/core/tasks/__init__.py b/mediapipe/model_maker/python/core/tasks/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/python/core/tasks/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py new file mode 100644 index 000000000..6b366f6dc --- /dev/null +++ b/mediapipe/model_maker/python/core/tasks/classifier.py @@ -0,0 +1,77 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Custom classifier.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from typing import Any, List + +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import dataset +from mediapipe.model_maker.python.core.tasks import custom_model + + +class Classifier(custom_model.CustomModel): + """An abstract base class that represents a TensorFlow classifier.""" + + def __init__(self, model_spec: Any, index_to_label: List[str], shuffle: bool, + full_train: bool): + """Initilizes a classifier with its specifications. + + Args: + model_spec: Specification for the model. + index_to_label: A list that map from index to label class name. + shuffle: Whether the dataset should be shuffled. + full_train: If true, train the model end-to-end including the backbone + and the classification layers on top. Otherwise, only train the top + classification layers. + """ + super(Classifier, self).__init__(model_spec, shuffle) + self._index_to_label = index_to_label + self._full_train = full_train + self._num_classes = len(index_to_label) + + def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any: + """Evaluates the classifier with the provided evaluation dataset. + + Args: + data: Evaluation dataset + batch_size: Number of samples per evaluation step. + + Returns: + The loss value and accuracy. + """ + ds = data.gen_tf_dataset( + batch_size, is_training=False, preprocess=self._preprocess) + return self._model.evaluate(ds) + + def export_labels(self, export_dir: str, label_filename: str = 'labels.txt'): + """Exports classification labels into a label file. + + Args: + export_dir: The directory to save exported files. + label_filename: File name to save labels model. The full export path is + {export_dir}/{label_filename}. + """ + if not tf.io.gfile.exists(export_dir): + tf.io.gfile.makedirs(export_dir) + + label_filepath = os.path.join(export_dir, label_filename) + tf.compat.v1.logging.info('Saving labels in %s', label_filepath) + with tf.io.gfile.GFile(label_filepath, 'w') as f: + f.write('\n'.join(self._index_to_label)) diff --git a/mediapipe/model_maker/python/core/tasks/classifier_test.py b/mediapipe/model_maker/python/core/tasks/classifier_test.py new file mode 100644 index 000000000..fbf231d8b --- /dev/null +++ b/mediapipe/model_maker/python/core/tasks/classifier_test.py @@ -0,0 +1,58 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +# Dependency imports + +import tensorflow as tf + +from mediapipe.model_maker.python.core.tasks import classifier +from mediapipe.model_maker.python.core.utils import test_util + + +class MockClassifier(classifier.Classifier): + """A mock class with implementation of abstract methods for testing.""" + + def train(self, train_data, validation_data=None, **kwargs): + pass + + def evaluate(self, data, **kwargs): + pass + + +class ClassifierTest(tf.test.TestCase): + + def setUp(self): + super(ClassifierTest, self).setUp() + index_to_label = ['cat', 'dog'] + self.model = MockClassifier( + model_spec=None, + index_to_label=index_to_label, + shuffle=False, + full_train=False) + self.model.model = test_util.build_model(input_shape=[4], num_classes=2) + + def _check_nonempty_file(self, filepath): + self.assertTrue(os.path.isfile(filepath)) + self.assertGreater(os.path.getsize(filepath), 0) + + def test_export_labels(self): + export_path = os.path.join(self.get_temp_dir(), 'export/') + self.model.export_labels(export_dir=export_path) + self._check_nonempty_file(os.path.join(export_path, 'labels.txt')) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/core/tasks/custom_model.py b/mediapipe/model_maker/python/core/tasks/custom_model.py new file mode 100644 index 000000000..2cea4e0a1 --- /dev/null +++ b/mediapipe/model_maker/python/core/tasks/custom_model.py @@ -0,0 +1,85 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Interface to define a custom model.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import os +from typing import Any, Callable, Optional + +# Dependency imports + +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import dataset +from mediapipe.model_maker.python.core.utils import model_util +from mediapipe.model_maker.python.core.utils import quantization + + +class CustomModel(abc.ABC): + """The abstract base class that represents a custom TensorFlow model.""" + + def __init__(self, model_spec: Any, shuffle: bool): + """Initializes a custom model with model specs and other parameters. + + Args: + model_spec: Specification for the model. + shuffle: Whether the training data need be shuffled. + """ + self._model_spec = model_spec + self._shuffle = shuffle + self._preprocess = None + self._model = None + + @abc.abstractmethod + def evaluate(self, data: dataset.Dataset, **kwargs): + """Evaluates the model with the provided data.""" + return + + def summary(self): + """Prints a summary of the model.""" + self._model.summary() + + def export_tflite( + self, + export_dir: str, + tflite_filename: str = 'model.tflite', + quantization_config: Optional[quantization.QuantizationConfig] = None, + preprocess: Optional[Callable[..., bool]] = None): + """Converts the model to requested formats. + + Args: + export_dir: The directory to save exported files. + tflite_filename: File name to save tflite model. The full export path is + {export_dir}/{tflite_filename}. + quantization_config: The configuration for model quantization. + preprocess: A callable to preprocess the representative dataset for + quantization. The callable takes three arguments in order: feature, + label, and is_training. + """ + if not tf.io.gfile.exists(export_dir): + tf.io.gfile.makedirs(export_dir) + + tflite_filepath = os.path.join(export_dir, tflite_filename) + # TODO: Populate metadata to the exported TFLite model. + model_util.export_tflite( + self._model, + tflite_filepath, + quantization_config, + preprocess=preprocess) + tf.compat.v1.logging.info( + 'TensorFlow Lite model exported successfully: %s' % tflite_filepath) diff --git a/mediapipe/model_maker/python/core/tasks/custom_model_test.py b/mediapipe/model_maker/python/core/tasks/custom_model_test.py new file mode 100644 index 000000000..e693e1275 --- /dev/null +++ b/mediapipe/model_maker/python/core/tasks/custom_model_test.py @@ -0,0 +1,56 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +# Dependency imports + +import tensorflow as tf + +from mediapipe.model_maker.python.core.tasks import custom_model +from mediapipe.model_maker.python.core.utils import test_util + + +class MockCustomModel(custom_model.CustomModel): + """A mock class with implementation of abstract methods for testing.""" + + def train(self, train_data, validation_data=None, **kwargs): + pass + + def evaluate(self, data, **kwargs): + pass + + +class CustomModelTest(tf.test.TestCase): + + def setUp(self): + super(CustomModelTest, self).setUp() + self.model = MockCustomModel(model_spec=None, shuffle=False) + self.model._model = test_util.build_model(input_shape=[4], num_classes=2) + + def _check_nonempty_file(self, filepath): + self.assertTrue(os.path.isfile(filepath)) + self.assertGreater(os.path.getsize(filepath), 0) + + def test_export_tflite(self): + export_path = os.path.join(self.get_temp_dir(), 'export/') + self.model.export_tflite(export_dir=export_path) + self._check_nonempty_file(os.path.join(export_path, 'model.tflite')) + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/vision/BUILD b/mediapipe/model_maker/python/vision/BUILD new file mode 100644 index 000000000..10aef8c33 --- /dev/null +++ b/mediapipe/model_maker/python/vision/BUILD @@ -0,0 +1,19 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + default_visibility = ["//mediapipe:__subpackages__"], +) + +licenses(["notice"]) diff --git a/mediapipe/model_maker/python/vision/__init__.py b/mediapipe/model_maker/python/vision/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/python/vision/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/model_maker/python/vision/image_classifier/BUILD b/mediapipe/model_maker/python/vision/image_classifier/BUILD new file mode 100644 index 000000000..a9386d56e --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/BUILD @@ -0,0 +1,111 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python library rule. +# Placeholder for internal Python strict library and test compatibility macro. + +licenses(["notice"]) + +package( + default_visibility = ["//mediapipe:__subpackages__"], +) + +py_library( + name = "image_classifier_import", + srcs = ["__init__.py"], + deps = [ + ":dataset", + ":hyperparameters", + ":image_classifier", + ":model_spec", + ], +) + +py_library( + name = "model_spec", + srcs = ["model_spec.py"], +) + +py_test( + name = "model_spec_test", + srcs = ["model_spec_test.py"], + deps = [":model_spec"], +) + +py_library( + name = "dataset", + srcs = ["dataset.py"], + deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"], +) + +py_test( + name = "dataset_test", + srcs = ["dataset_test.py"], + deps = [":dataset"], +) + +py_library( + name = "hyperparameters", + srcs = ["hyperparameters.py"], +) + +py_library( + name = "train_image_classifier_lib", + srcs = ["train_image_classifier_lib.py"], + deps = [ + ":hyperparameters", + "//mediapipe/model_maker/python/core/utils:model_util", + ], +) + +py_library( + name = "image_classifier", + srcs = ["image_classifier.py"], + deps = [ + ":hyperparameters", + ":model_spec", + ":train_image_classifier_lib", + "//mediapipe/model_maker/python/core/data:classification_dataset", + "//mediapipe/model_maker/python/core/tasks:classifier", + "//mediapipe/model_maker/python/core/utils:image_preprocessing", + "//mediapipe/model_maker/python/core/utils:model_util", + "//mediapipe/model_maker/python/core/utils:quantization", + ], +) + +py_library( + name = "image_classifier_test_lib", + testonly = 1, + srcs = ["image_classifier_test.py"], + deps = [":image_classifier_import"], +) + +py_test( + name = "image_classifier_test", + srcs = ["image_classifier_test.py"], + shard_count = 2, + tags = ["requires-net:external"], + deps = [ + ":image_classifier_test_lib", + ], +) + +py_binary( + name = "image_classifier_demo", + srcs = ["image_classifier_demo.py"], + deps = [ + ":image_classifier_import", + "//mediapipe/model_maker/python/core/utils:quantization", + ], +) diff --git a/mediapipe/model_maker/python/vision/image_classifier/__init__.py b/mediapipe/model_maker/python/vision/image_classifier/__init__.py new file mode 100644 index 000000000..3ba6b0764 --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MediaPipe Model Maker Python Public API For Image Classifier.""" + +from mediapipe.model_maker.python.vision.image_classifier import dataset +from mediapipe.model_maker.python.vision.image_classifier import hyperparameters +from mediapipe.model_maker.python.vision.image_classifier import image_classifier +from mediapipe.model_maker.python.vision.image_classifier import model_spec + +ImageClassifier = image_classifier.ImageClassifier +HParams = hyperparameters.HParams +Dataset = dataset.Dataset +ModelSpec = model_spec.ModelSpec +SupportedModels = model_spec.SupportedModels diff --git a/mediapipe/model_maker/python/vision/image_classifier/dataset.py b/mediapipe/model_maker/python/vision/image_classifier/dataset.py new file mode 100644 index 000000000..e57bae3dd --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/dataset.py @@ -0,0 +1,139 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image classifier dataset library.""" + +import os +import random + +from typing import List, Optional, Tuple +import tensorflow as tf +import tensorflow_datasets as tfds + +from mediapipe.model_maker.python.core.data import classification_dataset + + +def _load_image(path: str) -> tf.Tensor: + """Loads image.""" + image_raw = tf.io.read_file(path) + image_tensor = tf.cond( + tf.image.is_jpeg(image_raw), + lambda: tf.image.decode_jpeg(image_raw, channels=3), + lambda: tf.image.decode_png(image_raw, channels=3)) + return image_tensor + + +def _create_data( + name: str, data: tf.data.Dataset, info: tfds.core.DatasetInfo, + label_names: List[str] +) -> Optional[classification_dataset.ClassificationDataset]: + """Creates a Dataset object from tfds data.""" + if name not in data: + return None + data = data[name] + data = data.map(lambda a: (a['image'], a['label'])) + size = info.splits[name].num_examples + return Dataset(data, size, label_names) + + +class Dataset(classification_dataset.ClassificationDataset): + """Dataset library for image classifier.""" + + @classmethod + def from_folder( + cls, + dirname: str, + shuffle: bool = True) -> classification_dataset.ClassificationDataset: + """Loads images and labels from the given directory. + + Assume the image data of the same label are in the same subdirectory. + + Args: + dirname: Name of the directory containing the data files. + shuffle: boolean, if shuffle, random shuffle data. + + Returns: + Dataset containing images and labels and other related info. + + Raises: + ValueError: if the input data directory is empty. + """ + data_root = os.path.abspath(dirname) + + # Assumes the image data of the same label are in the same subdirectory, + # gets image path and label names. + all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*')) + all_image_size = len(all_image_paths) + if all_image_size == 0: + raise ValueError('Image size is zero') + + if shuffle: + # Random shuffle data. + random.shuffle(all_image_paths) + + label_names = sorted( + name for name in os.listdir(data_root) + if os.path.isdir(os.path.join(data_root, name))) + all_label_size = len(label_names) + label_to_index = dict( + (name, index) for index, name in enumerate(label_names)) + all_image_labels = [ + label_to_index[os.path.basename(os.path.dirname(path))] + for path in all_image_paths + ] + + path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths) + + autotune = tf.data.AUTOTUNE + image_ds = path_ds.map(_load_image, num_parallel_calls=autotune) + + # Loads label. + label_ds = tf.data.Dataset.from_tensor_slices( + tf.cast(all_image_labels, tf.int64)) + + # Creates a dataset if (image, label) pairs. + image_label_ds = tf.data.Dataset.zip((image_ds, label_ds)) + + tf.compat.v1.logging.info( + 'Load image with size: %d, num_label: %d, labels: %s.', all_image_size, + all_label_size, ', '.join(label_names)) + return Dataset(image_label_ds, all_image_size, label_names) + + @classmethod + def load_tf_dataset( + cls, name: str + ) -> Tuple[Optional[classification_dataset.ClassificationDataset], + Optional[classification_dataset.ClassificationDataset], + Optional[classification_dataset.ClassificationDataset]]: + """Loads data from tensorflow_datasets. + + Args: + name: the registered name of the tfds.core.DatasetBuilder. Refer to the + documentation of tfds.load for more details. + + Returns: + A tuple of Datasets for the train/validation/test. + + Raises: + ValueError: if the input tf dataset does not have train/validation/test + labels. + """ + data, info = tfds.load(name, with_info=True) + if 'label' not in info.features: + raise ValueError('info.features need to contain \'label\' key.') + label_names = info.features['label'].names + + train_data = _create_data('train', data, info, label_names) + validation_data = _create_data('validation', data, info, label_names) + test_data = _create_data('test', data, info, label_names) + return train_data, validation_data, test_data diff --git a/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py b/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py new file mode 100644 index 000000000..3a5d198b4 --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py @@ -0,0 +1,108 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +import numpy as np +import tensorflow as tf + +from mediapipe.model_maker.python.vision.image_classifier import dataset + + +def _fill_image(rgb, image_size): + r, g, b = rgb + return np.broadcast_to( + np.array([[[r, g, b]]], dtype=np.uint8), + shape=(image_size, image_size, 3)) + + +def _write_filled_jpeg_file(path, rgb, image_size): + tf.keras.preprocessing.image.save_img(path, _fill_image(rgb, image_size), + 'channels_last', 'jpeg') + + +class DatasetTest(tf.test.TestCase): + + def setUp(self): + super().setUp() + self.image_path = os.path.join(self.get_temp_dir(), 'random_image_dir') + if os.path.exists(self.image_path): + return + os.mkdir(self.image_path) + for class_name in ('daisy', 'tulips'): + class_subdir = os.path.join(self.image_path, class_name) + os.mkdir(class_subdir) + _write_filled_jpeg_file( + os.path.join(class_subdir, '0.jpeg'), + [random.uniform(0, 255) for _ in range(3)], 224) + + def test_split(self): + ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]) + data = dataset.Dataset(ds, 4, ['pos', 'neg']) + train_data, test_data = data.split(0.5) + + self.assertLen(train_data, 2) + for i, elem in enumerate(train_data._dataset): + self.assertTrue((elem.numpy() == np.array([i, 1])).all()) + self.assertEqual(train_data.num_classes, 2) + self.assertEqual(train_data.index_to_label, ['pos', 'neg']) + + self.assertLen(test_data, 2) + for i, elem in enumerate(test_data._dataset): + self.assertTrue((elem.numpy() == np.array([i, 0])).all()) + self.assertEqual(test_data.num_classes, 2) + self.assertEqual(test_data.index_to_label, ['pos', 'neg']) + + def test_from_folder(self): + data = dataset.Dataset.from_folder(self.image_path) + + self.assertLen(data, 2) + self.assertEqual(data.num_classes, 2) + self.assertEqual(data.index_to_label, ['daisy', 'tulips']) + for image, label in data.gen_tf_dataset(): + self.assertTrue(label.numpy() == 1 or label.numpy() == 0) + if label.numpy() == 0: + raw_image_tensor = dataset._load_image( + os.path.join(self.image_path, 'daisy', '0.jpeg')) + else: + raw_image_tensor = dataset._load_image( + os.path.join(self.image_path, 'tulips', '0.jpeg')) + self.assertTrue((image.numpy() == raw_image_tensor.numpy()).all()) + + def test_from_tfds(self): + # TODO: Remove this once tfds download error is fixed. + self.skipTest('Temporarily skip the unittest due to tfds download error.') + train_data, validation_data, test_data = ( + dataset.Dataset.from_tfds('beans')) + self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset) + self.assertLen(train_data, 1034) + self.assertEqual(train_data.num_classes, 3) + self.assertEqual(train_data.index_to_label, + ['angular_leaf_spot', 'bean_rust', 'healthy']) + + self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset) + self.assertLen(validation_data, 133) + self.assertEqual(validation_data.num_classes, 3) + self.assertEqual(validation_data.index_to_label, + ['angular_leaf_spot', 'bean_rust', 'healthy']) + + self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset) + self.assertLen(test_data, 128) + self.assertEqual(test_data.num_classes, 3) + self.assertEqual(test_data.index_to_label, + ['angular_leaf_spot', 'bean_rust', 'healthy']) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/vision/image_classifier/hyperparameters.py b/mediapipe/model_maker/python/vision/image_classifier/hyperparameters.py new file mode 100644 index 000000000..6df18579a --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/hyperparameters.py @@ -0,0 +1,74 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Hyperparameters for training image classification models.""" + +import dataclasses +import tempfile +from typing import Optional + + +# TODO: Expose other hyperparameters, e.g. data augmentation +# hyperparameters if requested. +@dataclasses.dataclass +class HParams: + """The hyperparameters for training image classifiers. + + The hyperparameters include: + # Parameters about training data. + do_fine_tuning: If true, the base module is trained together with the + classification layer on top. + shuffle: A boolean controlling if shuffle the dataset. Default to false. + + # Parameters about training configuration + train_epochs: Training will do this many iterations over the dataset. + batch_size: Each training step samples a batch of this many images. + learning_rate: The learning rate to use for gradient descent training. + dropout_rate: The fraction of the input units to drop, used in dropout + layer. + l1_regularizer: A regularizer that applies a L1 regularization penalty. + l2_regularizer: A regularizer that applies a L2 regularization penalty. + label_smoothing: Amount of label smoothing to apply. See tf.keras.losses for + more details. + do_data_augmentation: A boolean controlling whether the training dataset is + augmented by randomly distorting input images, including random cropping, + flipping, etc. See utils.image_preprocessing documentation for details. + steps_per_epoch: An optional integer indicate the number of training steps + per epoch. If not set, the training pipeline calculates the default steps + per epoch as the training dataset size devided by batch size. + decay_samples: Number of training samples used to calculate the decay steps + and create the training optimizer. + warmup_steps: Number of warmup steps for a linear increasing warmup schedule + on learning rate. Used to set up warmup schedule by model_util.WarmUp. + + # Parameters about the saved checkpoint + model_dir: The location of model checkpoint files and exported model files. + """ + # Parameters about training data + do_fine_tuning: bool = False + shuffle: bool = False + # Parameters about training configuration + train_epochs: int = 5 + batch_size: int = 32 + learning_rate: float = 0.005 + dropout_rate: float = 0.2 + l1_regularizer: float = 0.0 + l2_regularizer: float = 0.0001 + label_smoothing: float = 0.1 + do_data_augmentation: bool = True + steps_per_epoch: Optional[int] = None + decay_samples: int = 10000 * 256 + warmup_epochs: int = 2 + + # Parameters about the saved checkpoint + model_dir: str = tempfile.mkdtemp() diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py new file mode 100644 index 000000000..7a99f9ae0 --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py @@ -0,0 +1,172 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""APIs to train image classifier model.""" + +from typing import Any, List, Optional + +import tensorflow as tf +import tensorflow_hub as hub + +from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds +from mediapipe.model_maker.python.core.tasks import classifier +from mediapipe.model_maker.python.core.utils import image_preprocessing +from mediapipe.model_maker.python.core.utils import model_util +from mediapipe.model_maker.python.core.utils import quantization +from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp +from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms +from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib + + +class ImageClassifier(classifier.Classifier): + """ImageClassifier for building image classification model.""" + + def __init__(self, model_spec: ms.ModelSpec, index_to_label: List[Any], + hparams: hp.HParams): + """Initializes ImageClassifier class. + + Args: + model_spec: Specification for the model. + index_to_label: A list that maps from index to label class name. + hparams: The hyperparameters for training image classifier. + """ + super(ImageClassifier, self).__init__( + model_spec=model_spec, + index_to_label=index_to_label, + shuffle=hparams.shuffle, + full_train=hparams.do_fine_tuning) + self._hparams = hparams + self._preprocess = image_preprocessing.Preprocessor( + input_shape=self._model_spec.input_image_shape, + num_classes=self._num_classes, + mean_rgb=self._model_spec.mean_rgb, + stddev_rgb=self._model_spec.stddev_rgb, + use_augmentation=hparams.do_data_augmentation) + self._history = None # Training history returned from `keras_model.fit`. + + @classmethod + def create( + cls, + model_spec: ms.SupportedModels, + train_data: classification_ds.ClassificationDataset, + validation_data: classification_ds.ClassificationDataset, + hparams: Optional[hp.HParams] = None, + ) -> 'ImageClassifier': + """Creates and trains an image classifier. + + Loads data and trains the model based on data for image classification. + + Args: + model_spec: Specification for the model. + train_data: Training data. + validation_data: Validation data. + hparams: Hyperparameters for training image classifier. + + Returns: + An instance based on ImageClassifier. + """ + if hparams is None: + hparams = hp.HParams() + + spec = ms.SupportedModels.get(model_spec) + image_classifier = cls( + model_spec=spec, + index_to_label=train_data.index_to_label, + hparams=hparams) + + image_classifier._create_model() + + tf.compat.v1.logging.info('Training the models...') + image_classifier._train( + train_data=train_data, validation_data=validation_data) + + return image_classifier + + def _train(self, train_data: classification_ds.ClassificationDataset, + validation_data: classification_ds.ClassificationDataset): + """Trains the model with input train_data. + + The training results are recorded by a self._history object returned by + tf.keras.Model.fit(). + + Args: + train_data: Training data. + validation_data: Validation data. + """ + + tf.compat.v1.logging.info('Training the models...') + hparams = self._hparams + if len(train_data) < hparams.batch_size: + raise ValueError('The size of the train_data (%d) couldn\'t be smaller ' + 'than batch_size (%d). To solve this problem, set ' + 'the batch_size smaller or increase the size of the ' + 'train_data.' % (len(train_data), hparams.batch_size)) + + train_dataset = train_data.gen_tf_dataset( + batch_size=hparams.batch_size, + is_training=True, + shuffle=self._shuffle, + preprocess=self._preprocess) + hparams.steps_per_epoch = model_util.get_steps_per_epoch( + steps_per_epoch=hparams.steps_per_epoch, + batch_size=hparams.batch_size, + train_data=train_data) + train_dataset = train_dataset.take(count=hparams.steps_per_epoch) + + validation_dataset = validation_data.gen_tf_dataset( + batch_size=hparams.batch_size, + is_training=False, + preprocess=self._preprocess) + + # Train the model. + self._history = train_image_classifier_lib.train_model( + model=self._model, + hparams=hparams, + train_ds=train_dataset, + validation_ds=validation_dataset) + + def _create_model(self): + """Creates the classifier model from TFHub pretrained models.""" + module_layer = hub.KerasLayer( + handle=self._model_spec.uri, trainable=self._hparams.do_fine_tuning) + + image_size = self._model_spec.input_image_shape + + self._model = tf.keras.Sequential([ + tf.keras.Input(shape=(image_size[0], image_size[1], 3)), module_layer, + tf.keras.layers.Dropout(rate=self._hparams.dropout_rate), + tf.keras.layers.Dense( + units=self._num_classes, + activation='softmax', + kernel_regularizer=tf.keras.regularizers.l1_l2( + l1=self._hparams.l1_regularizer, + l2=self._hparams.l2_regularizer)) + ]) + print(self._model.summary()) + + def export_model( + self, + model_name: str = 'model.tflite', + quantization_config: Optional[quantization.QuantizationConfig] = None): + """Converts the model to the requested formats and exports to a file. + + Args: + model_name: File name to save tflite model. The full export path is + {export_dir}/{tflite_filename}. + quantization_config: The configuration for model quantization. + """ + super().export_tflite( + self._hparams.model_dir, + model_name, + quantization_config, + preprocess=self._preprocess) diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py new file mode 100644 index 000000000..5832ea53a --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py @@ -0,0 +1,106 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Demo for making an image classifier model by MediaPipe Model Maker.""" + +import os + +# Dependency imports + +from absl import app +from absl import flags +from absl import logging +import tensorflow as tf + +from mediapipe.model_maker.python.core.utils import quantization +from mediapipe.model_maker.python.vision import image_classifier + +FLAGS = flags.FLAGS + + +def define_flags() -> None: + """Define flags for the image classifier model maker demo.""" + flags.DEFINE_string('export_dir', None, + 'The directory to save exported files.') + flags.DEFINE_string( + 'input_data_dir', None, + """The directory with input training data. If the training data is not + specified, the pipeline will download a default training dataset.""") + flags.DEFINE_enum_class('spec', + image_classifier.SupportedModels.EFFICIENTNET_LITE0, + image_classifier.SupportedModels, + 'The image classifier to run.') + flags.DEFINE_enum('quantization', None, ['dynamic', 'int8', 'float16'], + 'The quantization method to use when exporting the model.') + flags.mark_flag_as_required('export_dir') + + +def download_demo_data() -> str: + """Downloads demo data, and returns directory path.""" + data_dir = tf.keras.utils.get_file( + fname='flower_photos.tgz', + origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz', + extract=True) + return os.path.join(os.path.dirname(data_dir), 'flower_photos') # folder name + + +def run(data_dir: str, export_dir: str, + model_spec: image_classifier.SupportedModels, + quantization_option: str) -> None: + """Runs demo.""" + data = image_classifier.Dataset.from_folder(data_dir) + train_data, rest_data = data.split(0.8) + validation_data, test_data = rest_data.split(0.5) + + model = image_classifier.ImageClassifier.create( + model_spec=model_spec, + train_data=train_data, + validation_data=validation_data, + hparams=image_classifier.HParams(model_dir=export_dir)) + + _, acc = model.evaluate(test_data) + print('Test accuracy: %f' % acc) + + if quantization_option is None: + quantization_config = None + elif quantization_option == 'dynamic': + quantization_config = quantization.QuantizationConfig.for_dynamic() + elif quantization_option == 'int8': + quantization_config = quantization.QuantizationConfig.for_int8(train_data) + elif quantization_option == 'float16': + quantization_config = quantization.QuantizationConfig.for_float16() + else: + raise ValueError(f'Quantization: {quantization} is not recognized') + + model.export_model(quantization_config=quantization_config) + model.export_labels(export_dir) + + +def main(_) -> None: + logging.set_verbosity(logging.INFO) + + if FLAGS.input_data_dir is None: + data_dir = download_demo_data() + else: + data_dir = FLAGS.input_data_dir + + export_dir = os.path.expanduser(FLAGS.export_dir) + run(data_dir=data_dir, + export_dir=export_dir, + model_spec=FLAGS.spec, + quantization_option=FLAGS.quantization) + + +if __name__ == '__main__': + define_flags() + app.run(main) diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py new file mode 100644 index 000000000..a7faab5b6 --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py @@ -0,0 +1,122 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from absl.testing import parameterized +import numpy as np +import tensorflow as tf + +from mediapipe.model_maker.python.vision import image_classifier + + +def _fill_image(rgb, image_size): + r, g, b = rgb + return np.broadcast_to( + np.array([[[r, g, b]]], dtype=np.uint8), + shape=(image_size, image_size, 3)) + + +class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): + IMAGE_SIZE = 24 + IMAGES_PER_CLASS = 2 + CMY_NAMES_AND_RGB_VALUES = (('cyan', (0, 255, 255)), + ('magenta', (255, 0, 255)), ('yellow', (255, 255, + 0))) + + def _gen(self): + for i, (_, rgb) in enumerate(self.CMY_NAMES_AND_RGB_VALUES): + for _ in range(self.IMAGES_PER_CLASS): + yield (_fill_image(rgb, self.IMAGE_SIZE), i) + + def _gen_cmy_data(self): + ds = tf.data.Dataset.from_generator( + self._gen, (tf.uint8, tf.int64), (tf.TensorShape( + [self.IMAGE_SIZE, self.IMAGE_SIZE, 3]), tf.TensorShape([]))) + data = image_classifier.Dataset(ds, self.IMAGES_PER_CLASS * 3, + ['cyan', 'magenta', 'yellow']) + return data + + def setUp(self): + super(ImageClassifierTest, self).setUp() + all_data = self._gen_cmy_data() + # Splits data, 90% data for training, 10% for testing + self.train_data, self.test_data = all_data.split(0.9) + + @parameterized.named_parameters( + dict( + testcase_name='mobilenet_v2', + model_spec=image_classifier.SupportedModels.MOBILENET_V2, + hparams=image_classifier.HParams( + train_epochs=1, batch_size=1, shuffle=True)), + dict( + testcase_name='resnet_50', + model_spec=image_classifier.SupportedModels.RESNET_50, + hparams=image_classifier.HParams( + train_epochs=1, batch_size=1, shuffle=True)), + dict( + testcase_name='efficientnet_lite0', + model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0, + hparams=image_classifier.HParams( + train_epochs=1, batch_size=1, shuffle=True)), + dict( + testcase_name='efficientnet_lite1', + model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE1, + hparams=image_classifier.HParams( + train_epochs=1, batch_size=1, shuffle=True)), + dict( + testcase_name='efficientnet_lite2', + model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE2, + hparams=image_classifier.HParams( + train_epochs=1, batch_size=1, shuffle=True)), + dict( + testcase_name='efficientnet_lite3', + model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE3, + hparams=image_classifier.HParams( + train_epochs=1, batch_size=1, shuffle=True)), + dict( + testcase_name='efficientnet_lite4', + model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE4, + hparams=image_classifier.HParams( + train_epochs=1, batch_size=1, shuffle=True)), + ) + def test_create_and_train_model(self, + model_spec: image_classifier.SupportedModels, + hparams: image_classifier.HParams): + model = image_classifier.ImageClassifier.create( + model_spec=model_spec, + train_data=self.train_data, + hparams=hparams, + validation_data=self.test_data) + self._test_accuracy(model) + + def test_efficientnetlite0_model_with_model_maker_retraining_lib(self): + hparams = image_classifier.HParams( + train_epochs=1, batch_size=1, shuffle=True) + model = image_classifier.ImageClassifier.create( + model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0, + train_data=self.train_data, + hparams=hparams, + validation_data=self.test_data) + self._test_accuracy(model) + + def _test_accuracy(self, model, threshold=0.0): + _, accuracy = model.evaluate(self.test_data) + self.assertGreaterEqual(accuracy, threshold) + + +if __name__ == '__main__': + # Load compressed models from tensorflow_hub + os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED' + tf.test.main() diff --git a/mediapipe/model_maker/python/vision/image_classifier/model_spec.py b/mediapipe/model_maker/python/vision/image_classifier/model_spec.py new file mode 100644 index 000000000..4e9565274 --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/model_spec.py @@ -0,0 +1,104 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image classifier model specification.""" + +import enum +import functools +from typing import List, Optional + + +class ModelSpec(object): + """Specification of image classifier model.""" + + mean_rgb = [0.0] + stddev_rgb = [255.0] + + def __init__(self, + uri: str, + input_image_shape: Optional[List[int]] = None, + name: str = ''): + """Initializes a new instance of the `ImageModelSpec` class. + + Args: + uri: str, URI to the pretrained model. + input_image_shape: list of int, input image shape. Default: [224, 224]. + name: str, model spec name. + """ + self.uri = uri + self.name = name + + if input_image_shape is None: + input_image_shape = [224, 224] + self.input_image_shape = input_image_shape + + +mobilenet_v2_spec = functools.partial( + ModelSpec, + uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4', + name='mobilenet_v2') + +resnet_50_spec = functools.partial( + ModelSpec, + uri='https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4', + name='resnet_50') + +efficientnet_lite0_spec = functools.partial( + ModelSpec, + uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2', + name='efficientnet_lite0') + +efficientnet_lite1_spec = functools.partial( + ModelSpec, + uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2', + input_image_shape=[240, 240], + name='efficientnet_lite1') + +efficientnet_lite2_spec = functools.partial( + ModelSpec, + uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2', + input_image_shape=[260, 260], + name='efficientnet_lite2') + +efficientnet_lite3_spec = functools.partial( + ModelSpec, + uri='https://tfhub.dev/tensorflow/efficientnet/lite3/feature-vector/2', + input_image_shape=[280, 280], + name='efficientnet_lite3') + +efficientnet_lite4_spec = functools.partial( + ModelSpec, + uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2', + input_image_shape=[300, 300], + name='efficientnet_lite4') + + +# TODO: Document the exposed models. +@enum.unique +class SupportedModels(enum.Enum): + """Image classifier model supported by model maker.""" + MOBILENET_V2 = mobilenet_v2_spec + RESNET_50 = resnet_50_spec + EFFICIENTNET_LITE0 = efficientnet_lite0_spec + EFFICIENTNET_LITE1 = efficientnet_lite1_spec + EFFICIENTNET_LITE2 = efficientnet_lite2_spec + EFFICIENTNET_LITE3 = efficientnet_lite3_spec + EFFICIENTNET_LITE4 = efficientnet_lite4_spec + + @classmethod + def get(cls, spec: 'SupportedModels') -> 'ModelSpec': + """Gets model spec from the input enum and initializes it.""" + if spec not in cls: + raise TypeError('Unsupported image classifier spec: {}'.format(spec)) + + return spec.value() diff --git a/mediapipe/model_maker/python/vision/image_classifier/model_spec_test.py b/mediapipe/model_maker/python/vision/image_classifier/model_spec_test.py new file mode 100644 index 000000000..bacab016e --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/model_spec_test.py @@ -0,0 +1,93 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from typing import Callable, List +from absl.testing import parameterized +import tensorflow as tf + +from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms + + +class ModelSpecTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name='mobilenet_v2_spec_test', + model_spec=ms.mobilenet_v2_spec, + expected_uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4', + expected_name='mobilenet_v2', + expected_input_image_shape=[224, 224]), + dict( + testcase_name='resnet_50_spec_test', + model_spec=ms.resnet_50_spec, + expected_uri='https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4', + expected_name='resnet_50', + expected_input_image_shape=[224, 224]), + dict( + testcase_name='efficientnet_lite0_spec_test', + model_spec=ms.efficientnet_lite0_spec, + expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2', + expected_name='efficientnet_lite0', + expected_input_image_shape=[224, 224]), + dict( + testcase_name='efficientnet_lite1_spec_test', + model_spec=ms.efficientnet_lite1_spec, + expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2', + expected_name='efficientnet_lite1', + expected_input_image_shape=[240, 240]), + dict( + testcase_name='efficientnet_lite2_spec_test', + model_spec=ms.efficientnet_lite2_spec, + expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2', + expected_name='efficientnet_lite2', + expected_input_image_shape=[260, 260]), + dict( + testcase_name='efficientnet_lite3_spec_test', + model_spec=ms.efficientnet_lite3_spec, + expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite3/feature-vector/2', + expected_name='efficientnet_lite3', + expected_input_image_shape=[280, 280]), + dict( + testcase_name='efficientnet_lite4_spec_test', + model_spec=ms.efficientnet_lite4_spec, + expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2', + expected_name='efficientnet_lite4', + expected_input_image_shape=[300, 300]), + ) + def test_predefiend_spec(self, model_spec: Callable[..., ms.ModelSpec], + expected_uri: str, expected_name: str, + expected_input_image_shape: List[int]): + model_spec_obj = model_spec() + self.assertIsInstance(model_spec_obj, ms.ModelSpec) + self.assertEqual(model_spec_obj.uri, expected_uri) + self.assertEqual(model_spec_obj.name, expected_name) + self.assertEqual(model_spec_obj.input_image_shape, + expected_input_image_shape) + + def test_create_spec(self): + custom_model_spec = ms.ModelSpec( + uri='https://custom_model', + input_image_shape=[128, 128], + name='custom_model') + self.assertEqual(custom_model_spec.uri, 'https://custom_model') + self.assertEqual(custom_model_spec.name, 'custom_model') + self.assertEqual(custom_model_spec.input_image_shape, [128, 128]) + + +if __name__ == '__main__': + # Load compressed models from tensorflow_hub + os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED' + tf.test.main() diff --git a/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py b/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py new file mode 100644 index 000000000..704d71a5a --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py @@ -0,0 +1,103 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Library to train model.""" + +import os +from typing import List + +import tensorflow as tf + +from mediapipe.model_maker.python.core.utils import model_util +from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp + + +def _create_optimizer(init_lr: float, decay_steps: int, + warmup_steps: int) -> tf.keras.optimizers.Optimizer: + """Creates an optimizer with learning rate schedule. + + Uses Keras CosineDecay schedule for the learning rate by default. + + Args: + init_lr: Initial learning rate. + decay_steps: Number of steps to decay over. + warmup_steps: Number of steps to do warmup for. + + Returns: + A tf.keras.optimizers.Optimizer for model training. + """ + learning_rate_fn = tf.keras.experimental.CosineDecay( + initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0) + if warmup_steps: + learning_rate_fn = model_util.WarmUp( + initial_learning_rate=init_lr, + decay_schedule_fn=learning_rate_fn, + warmup_steps=warmup_steps) + optimizer = tf.keras.optimizers.RMSprop( + learning_rate=learning_rate_fn, rho=0.9, momentum=0.9, epsilon=0.001) + + return optimizer + + +def _get_default_callbacks(model_dir: str) -> List[tf.keras.callbacks.Callback]: + """Gets default callbacks.""" + summary_dir = os.path.join(model_dir, 'summaries') + summary_callback = tf.keras.callbacks.TensorBoard(summary_dir) + # Save checkpoint every 20 epochs. + + checkpoint_path = os.path.join(model_dir, 'checkpoint') + checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( + checkpoint_path, save_weights_only=True, period=20) + return [summary_callback, checkpoint_callback] + + +def train_model(model: tf.keras.Model, hparams: hp.HParams, + train_ds: tf.data.Dataset, + validation_ds: tf.data.Dataset) -> tf.keras.callbacks.History: + """Trains model with the input data and hyperparameters. + + Args: + model: Input tf.keras.Model. + hparams: Hyperparameters for training image classifier. + train_ds: tf.data.Dataset, training data to be fed in tf.keras.Model.fit(). + validation_ds: tf.data.Dataset, validation data to be fed in + tf.keras.Model.fit(). + + Returns: + The tf.keras.callbacks.History object returned by tf.keras.Model.fit(). + """ + + # Learning rate is linear to batch size. + learning_rate = hparams.learning_rate * hparams.batch_size / 256 + + # Get decay steps. + total_training_steps = hparams.steps_per_epoch * hparams.train_epochs + default_decay_steps = hparams.decay_samples // hparams.batch_size + decay_steps = max(total_training_steps, default_decay_steps) + + warmup_steps = hparams.warmup_epochs * hparams.steps_per_epoch + optimizer = _create_optimizer( + init_lr=learning_rate, decay_steps=decay_steps, warmup_steps=warmup_steps) + + loss = tf.keras.losses.CategoricalCrossentropy( + label_smoothing=hparams.label_smoothing) + model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy']) + callbacks = _get_default_callbacks(hparams.model_dir) + + # Train the model. + return model.fit( + x=train_ds, + epochs=hparams.train_epochs, + steps_per_epoch=hparams.steps_per_epoch, + validation_data=validation_ds, + callbacks=callbacks) diff --git a/mediapipe/model_maker/requirements.txt b/mediapipe/model_maker/requirements.txt index 5e3832b09..389ee484a 100644 --- a/mediapipe/model_maker/requirements.txt +++ b/mediapipe/model_maker/requirements.txt @@ -2,3 +2,5 @@ absl-py numpy opencv-contrib-python tensorflow +tensorflow-datasets +tensorflow-hub