Open source model_maker/python/core/tasks and model_maker/python/vision/image_classifier
PiperOrigin-RevId: 481182271
This commit is contained in:
parent
6f3e8381ed
commit
0428550d75
64
mediapipe/model_maker/python/core/tasks/BUILD
Normal file
64
mediapipe/model_maker/python/core/tasks/BUILD
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Placeholder for internal Python strict library and test compatibility macro.
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//mediapipe:__subpackages__"],
|
||||||
|
)
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "custom_model",
|
||||||
|
srcs = ["custom_model.py"],
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/model_maker/python/core/data:dataset",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "custom_model_test",
|
||||||
|
srcs = ["custom_model_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":custom_model",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:test_util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "classifier",
|
||||||
|
srcs = ["classifier.py"],
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":custom_model",
|
||||||
|
"//mediapipe/model_maker/python/core/data:dataset",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "classifier_test",
|
||||||
|
srcs = ["classifier_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":classifier",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:test_util",
|
||||||
|
],
|
||||||
|
)
|
13
mediapipe/model_maker/python/core/tasks/__init__.py
Normal file
13
mediapipe/model_maker/python/core/tasks/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
77
mediapipe/model_maker/python/core/tasks/classifier.py
Normal file
77
mediapipe/model_maker/python/core/tasks/classifier.py
Normal file
|
@ -0,0 +1,77 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Custom classifier."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import dataset
|
||||||
|
from mediapipe.model_maker.python.core.tasks import custom_model
|
||||||
|
|
||||||
|
|
||||||
|
class Classifier(custom_model.CustomModel):
|
||||||
|
"""An abstract base class that represents a TensorFlow classifier."""
|
||||||
|
|
||||||
|
def __init__(self, model_spec: Any, index_to_label: List[str], shuffle: bool,
|
||||||
|
full_train: bool):
|
||||||
|
"""Initilizes a classifier with its specifications.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_spec: Specification for the model.
|
||||||
|
index_to_label: A list that map from index to label class name.
|
||||||
|
shuffle: Whether the dataset should be shuffled.
|
||||||
|
full_train: If true, train the model end-to-end including the backbone
|
||||||
|
and the classification layers on top. Otherwise, only train the top
|
||||||
|
classification layers.
|
||||||
|
"""
|
||||||
|
super(Classifier, self).__init__(model_spec, shuffle)
|
||||||
|
self._index_to_label = index_to_label
|
||||||
|
self._full_train = full_train
|
||||||
|
self._num_classes = len(index_to_label)
|
||||||
|
|
||||||
|
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
|
||||||
|
"""Evaluates the classifier with the provided evaluation dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Evaluation dataset
|
||||||
|
batch_size: Number of samples per evaluation step.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The loss value and accuracy.
|
||||||
|
"""
|
||||||
|
ds = data.gen_tf_dataset(
|
||||||
|
batch_size, is_training=False, preprocess=self._preprocess)
|
||||||
|
return self._model.evaluate(ds)
|
||||||
|
|
||||||
|
def export_labels(self, export_dir: str, label_filename: str = 'labels.txt'):
|
||||||
|
"""Exports classification labels into a label file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
export_dir: The directory to save exported files.
|
||||||
|
label_filename: File name to save labels model. The full export path is
|
||||||
|
{export_dir}/{label_filename}.
|
||||||
|
"""
|
||||||
|
if not tf.io.gfile.exists(export_dir):
|
||||||
|
tf.io.gfile.makedirs(export_dir)
|
||||||
|
|
||||||
|
label_filepath = os.path.join(export_dir, label_filename)
|
||||||
|
tf.compat.v1.logging.info('Saving labels in %s', label_filepath)
|
||||||
|
with tf.io.gfile.GFile(label_filepath, 'w') as f:
|
||||||
|
f.write('\n'.join(self._index_to_label))
|
58
mediapipe/model_maker/python/core/tasks/classifier_test.py
Normal file
58
mediapipe/model_maker/python/core/tasks/classifier_test.py
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Dependency imports
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.tasks import classifier
|
||||||
|
from mediapipe.model_maker.python.core.utils import test_util
|
||||||
|
|
||||||
|
|
||||||
|
class MockClassifier(classifier.Classifier):
|
||||||
|
"""A mock class with implementation of abstract methods for testing."""
|
||||||
|
|
||||||
|
def train(self, train_data, validation_data=None, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def evaluate(self, data, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ClassifierTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(ClassifierTest, self).setUp()
|
||||||
|
index_to_label = ['cat', 'dog']
|
||||||
|
self.model = MockClassifier(
|
||||||
|
model_spec=None,
|
||||||
|
index_to_label=index_to_label,
|
||||||
|
shuffle=False,
|
||||||
|
full_train=False)
|
||||||
|
self.model.model = test_util.build_model(input_shape=[4], num_classes=2)
|
||||||
|
|
||||||
|
def _check_nonempty_file(self, filepath):
|
||||||
|
self.assertTrue(os.path.isfile(filepath))
|
||||||
|
self.assertGreater(os.path.getsize(filepath), 0)
|
||||||
|
|
||||||
|
def test_export_labels(self):
|
||||||
|
export_path = os.path.join(self.get_temp_dir(), 'export/')
|
||||||
|
self.model.export_labels(export_dir=export_path)
|
||||||
|
self._check_nonempty_file(os.path.join(export_path, 'labels.txt'))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
85
mediapipe/model_maker/python/core/tasks/custom_model.py
Normal file
85
mediapipe/model_maker/python/core/tasks/custom_model.py
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Interface to define a custom model."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import abc
|
||||||
|
import os
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
# Dependency imports
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import dataset
|
||||||
|
from mediapipe.model_maker.python.core.utils import model_util
|
||||||
|
from mediapipe.model_maker.python.core.utils import quantization
|
||||||
|
|
||||||
|
|
||||||
|
class CustomModel(abc.ABC):
|
||||||
|
"""The abstract base class that represents a custom TensorFlow model."""
|
||||||
|
|
||||||
|
def __init__(self, model_spec: Any, shuffle: bool):
|
||||||
|
"""Initializes a custom model with model specs and other parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_spec: Specification for the model.
|
||||||
|
shuffle: Whether the training data need be shuffled.
|
||||||
|
"""
|
||||||
|
self._model_spec = model_spec
|
||||||
|
self._shuffle = shuffle
|
||||||
|
self._preprocess = None
|
||||||
|
self._model = None
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def evaluate(self, data: dataset.Dataset, **kwargs):
|
||||||
|
"""Evaluates the model with the provided data."""
|
||||||
|
return
|
||||||
|
|
||||||
|
def summary(self):
|
||||||
|
"""Prints a summary of the model."""
|
||||||
|
self._model.summary()
|
||||||
|
|
||||||
|
def export_tflite(
|
||||||
|
self,
|
||||||
|
export_dir: str,
|
||||||
|
tflite_filename: str = 'model.tflite',
|
||||||
|
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
||||||
|
preprocess: Optional[Callable[..., bool]] = None):
|
||||||
|
"""Converts the model to requested formats.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
export_dir: The directory to save exported files.
|
||||||
|
tflite_filename: File name to save tflite model. The full export path is
|
||||||
|
{export_dir}/{tflite_filename}.
|
||||||
|
quantization_config: The configuration for model quantization.
|
||||||
|
preprocess: A callable to preprocess the representative dataset for
|
||||||
|
quantization. The callable takes three arguments in order: feature,
|
||||||
|
label, and is_training.
|
||||||
|
"""
|
||||||
|
if not tf.io.gfile.exists(export_dir):
|
||||||
|
tf.io.gfile.makedirs(export_dir)
|
||||||
|
|
||||||
|
tflite_filepath = os.path.join(export_dir, tflite_filename)
|
||||||
|
# TODO: Populate metadata to the exported TFLite model.
|
||||||
|
model_util.export_tflite(
|
||||||
|
self._model,
|
||||||
|
tflite_filepath,
|
||||||
|
quantization_config,
|
||||||
|
preprocess=preprocess)
|
||||||
|
tf.compat.v1.logging.info(
|
||||||
|
'TensorFlow Lite model exported successfully: %s' % tflite_filepath)
|
56
mediapipe/model_maker/python/core/tasks/custom_model_test.py
Normal file
56
mediapipe/model_maker/python/core/tasks/custom_model_test.py
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Dependency imports
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.tasks import custom_model
|
||||||
|
from mediapipe.model_maker.python.core.utils import test_util
|
||||||
|
|
||||||
|
|
||||||
|
class MockCustomModel(custom_model.CustomModel):
|
||||||
|
"""A mock class with implementation of abstract methods for testing."""
|
||||||
|
|
||||||
|
def train(self, train_data, validation_data=None, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def evaluate(self, data, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CustomModelTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(CustomModelTest, self).setUp()
|
||||||
|
self.model = MockCustomModel(model_spec=None, shuffle=False)
|
||||||
|
self.model._model = test_util.build_model(input_shape=[4], num_classes=2)
|
||||||
|
|
||||||
|
def _check_nonempty_file(self, filepath):
|
||||||
|
self.assertTrue(os.path.isfile(filepath))
|
||||||
|
self.assertGreater(os.path.getsize(filepath), 0)
|
||||||
|
|
||||||
|
def test_export_tflite(self):
|
||||||
|
export_path = os.path.join(self.get_temp_dir(), 'export/')
|
||||||
|
self.model.export_tflite(export_dir=export_path)
|
||||||
|
self._check_nonempty_file(os.path.join(export_path, 'model.tflite'))
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
19
mediapipe/model_maker/python/vision/BUILD
Normal file
19
mediapipe/model_maker/python/vision/BUILD
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//mediapipe:__subpackages__"],
|
||||||
|
)
|
||||||
|
|
||||||
|
licenses(["notice"])
|
13
mediapipe/model_maker/python/vision/__init__.py
Normal file
13
mediapipe/model_maker/python/vision/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
111
mediapipe/model_maker/python/vision/image_classifier/BUILD
Normal file
111
mediapipe/model_maker/python/vision/image_classifier/BUILD
Normal file
|
@ -0,0 +1,111 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Placeholder for internal Python library rule.
|
||||||
|
# Placeholder for internal Python strict library and test compatibility macro.
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//mediapipe:__subpackages__"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "image_classifier_import",
|
||||||
|
srcs = ["__init__.py"],
|
||||||
|
deps = [
|
||||||
|
":dataset",
|
||||||
|
":hyperparameters",
|
||||||
|
":image_classifier",
|
||||||
|
":model_spec",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "model_spec",
|
||||||
|
srcs = ["model_spec.py"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "model_spec_test",
|
||||||
|
srcs = ["model_spec_test.py"],
|
||||||
|
deps = [":model_spec"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "dataset",
|
||||||
|
srcs = ["dataset.py"],
|
||||||
|
deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "dataset_test",
|
||||||
|
srcs = ["dataset_test.py"],
|
||||||
|
deps = [":dataset"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "hyperparameters",
|
||||||
|
srcs = ["hyperparameters.py"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "train_image_classifier_lib",
|
||||||
|
srcs = ["train_image_classifier_lib.py"],
|
||||||
|
deps = [
|
||||||
|
":hyperparameters",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "image_classifier",
|
||||||
|
srcs = ["image_classifier.py"],
|
||||||
|
deps = [
|
||||||
|
":hyperparameters",
|
||||||
|
":model_spec",
|
||||||
|
":train_image_classifier_lib",
|
||||||
|
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||||
|
"//mediapipe/model_maker/python/core/tasks:classifier",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:image_preprocessing",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "image_classifier_test_lib",
|
||||||
|
testonly = 1,
|
||||||
|
srcs = ["image_classifier_test.py"],
|
||||||
|
deps = [":image_classifier_import"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "image_classifier_test",
|
||||||
|
srcs = ["image_classifier_test.py"],
|
||||||
|
shard_count = 2,
|
||||||
|
tags = ["requires-net:external"],
|
||||||
|
deps = [
|
||||||
|
":image_classifier_test_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_binary(
|
||||||
|
name = "image_classifier_demo",
|
||||||
|
srcs = ["image_classifier_demo.py"],
|
||||||
|
deps = [
|
||||||
|
":image_classifier_import",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,25 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""MediaPipe Model Maker Python Public API For Image Classifier."""
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision.image_classifier import dataset
|
||||||
|
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
|
||||||
|
from mediapipe.model_maker.python.vision.image_classifier import image_classifier
|
||||||
|
from mediapipe.model_maker.python.vision.image_classifier import model_spec
|
||||||
|
|
||||||
|
ImageClassifier = image_classifier.ImageClassifier
|
||||||
|
HParams = hyperparameters.HParams
|
||||||
|
Dataset = dataset.Dataset
|
||||||
|
ModelSpec = model_spec.ModelSpec
|
||||||
|
SupportedModels = model_spec.SupportedModels
|
139
mediapipe/model_maker/python/vision/image_classifier/dataset.py
Normal file
139
mediapipe/model_maker/python/vision/image_classifier/dataset.py
Normal file
|
@ -0,0 +1,139 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Image classifier dataset library."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow_datasets as tfds
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def _load_image(path: str) -> tf.Tensor:
|
||||||
|
"""Loads image."""
|
||||||
|
image_raw = tf.io.read_file(path)
|
||||||
|
image_tensor = tf.cond(
|
||||||
|
tf.image.is_jpeg(image_raw),
|
||||||
|
lambda: tf.image.decode_jpeg(image_raw, channels=3),
|
||||||
|
lambda: tf.image.decode_png(image_raw, channels=3))
|
||||||
|
return image_tensor
|
||||||
|
|
||||||
|
|
||||||
|
def _create_data(
|
||||||
|
name: str, data: tf.data.Dataset, info: tfds.core.DatasetInfo,
|
||||||
|
label_names: List[str]
|
||||||
|
) -> Optional[classification_dataset.ClassificationDataset]:
|
||||||
|
"""Creates a Dataset object from tfds data."""
|
||||||
|
if name not in data:
|
||||||
|
return None
|
||||||
|
data = data[name]
|
||||||
|
data = data.map(lambda a: (a['image'], a['label']))
|
||||||
|
size = info.splits[name].num_examples
|
||||||
|
return Dataset(data, size, label_names)
|
||||||
|
|
||||||
|
|
||||||
|
class Dataset(classification_dataset.ClassificationDataset):
|
||||||
|
"""Dataset library for image classifier."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_folder(
|
||||||
|
cls,
|
||||||
|
dirname: str,
|
||||||
|
shuffle: bool = True) -> classification_dataset.ClassificationDataset:
|
||||||
|
"""Loads images and labels from the given directory.
|
||||||
|
|
||||||
|
Assume the image data of the same label are in the same subdirectory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dirname: Name of the directory containing the data files.
|
||||||
|
shuffle: boolean, if shuffle, random shuffle data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset containing images and labels and other related info.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if the input data directory is empty.
|
||||||
|
"""
|
||||||
|
data_root = os.path.abspath(dirname)
|
||||||
|
|
||||||
|
# Assumes the image data of the same label are in the same subdirectory,
|
||||||
|
# gets image path and label names.
|
||||||
|
all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
|
||||||
|
all_image_size = len(all_image_paths)
|
||||||
|
if all_image_size == 0:
|
||||||
|
raise ValueError('Image size is zero')
|
||||||
|
|
||||||
|
if shuffle:
|
||||||
|
# Random shuffle data.
|
||||||
|
random.shuffle(all_image_paths)
|
||||||
|
|
||||||
|
label_names = sorted(
|
||||||
|
name for name in os.listdir(data_root)
|
||||||
|
if os.path.isdir(os.path.join(data_root, name)))
|
||||||
|
all_label_size = len(label_names)
|
||||||
|
label_to_index = dict(
|
||||||
|
(name, index) for index, name in enumerate(label_names))
|
||||||
|
all_image_labels = [
|
||||||
|
label_to_index[os.path.basename(os.path.dirname(path))]
|
||||||
|
for path in all_image_paths
|
||||||
|
]
|
||||||
|
|
||||||
|
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
|
||||||
|
|
||||||
|
autotune = tf.data.AUTOTUNE
|
||||||
|
image_ds = path_ds.map(_load_image, num_parallel_calls=autotune)
|
||||||
|
|
||||||
|
# Loads label.
|
||||||
|
label_ds = tf.data.Dataset.from_tensor_slices(
|
||||||
|
tf.cast(all_image_labels, tf.int64))
|
||||||
|
|
||||||
|
# Creates a dataset if (image, label) pairs.
|
||||||
|
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
|
||||||
|
|
||||||
|
tf.compat.v1.logging.info(
|
||||||
|
'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
|
||||||
|
all_label_size, ', '.join(label_names))
|
||||||
|
return Dataset(image_label_ds, all_image_size, label_names)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_tf_dataset(
|
||||||
|
cls, name: str
|
||||||
|
) -> Tuple[Optional[classification_dataset.ClassificationDataset],
|
||||||
|
Optional[classification_dataset.ClassificationDataset],
|
||||||
|
Optional[classification_dataset.ClassificationDataset]]:
|
||||||
|
"""Loads data from tensorflow_datasets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: the registered name of the tfds.core.DatasetBuilder. Refer to the
|
||||||
|
documentation of tfds.load for more details.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of Datasets for the train/validation/test.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if the input tf dataset does not have train/validation/test
|
||||||
|
labels.
|
||||||
|
"""
|
||||||
|
data, info = tfds.load(name, with_info=True)
|
||||||
|
if 'label' not in info.features:
|
||||||
|
raise ValueError('info.features need to contain \'label\' key.')
|
||||||
|
label_names = info.features['label'].names
|
||||||
|
|
||||||
|
train_data = _create_data('train', data, info, label_names)
|
||||||
|
validation_data = _create_data('validation', data, info, label_names)
|
||||||
|
test_data = _create_data('test', data, info, label_names)
|
||||||
|
return train_data, validation_data, test_data
|
|
@ -0,0 +1,108 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision.image_classifier import dataset
|
||||||
|
|
||||||
|
|
||||||
|
def _fill_image(rgb, image_size):
|
||||||
|
r, g, b = rgb
|
||||||
|
return np.broadcast_to(
|
||||||
|
np.array([[[r, g, b]]], dtype=np.uint8),
|
||||||
|
shape=(image_size, image_size, 3))
|
||||||
|
|
||||||
|
|
||||||
|
def _write_filled_jpeg_file(path, rgb, image_size):
|
||||||
|
tf.keras.preprocessing.image.save_img(path, _fill_image(rgb, image_size),
|
||||||
|
'channels_last', 'jpeg')
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.image_path = os.path.join(self.get_temp_dir(), 'random_image_dir')
|
||||||
|
if os.path.exists(self.image_path):
|
||||||
|
return
|
||||||
|
os.mkdir(self.image_path)
|
||||||
|
for class_name in ('daisy', 'tulips'):
|
||||||
|
class_subdir = os.path.join(self.image_path, class_name)
|
||||||
|
os.mkdir(class_subdir)
|
||||||
|
_write_filled_jpeg_file(
|
||||||
|
os.path.join(class_subdir, '0.jpeg'),
|
||||||
|
[random.uniform(0, 255) for _ in range(3)], 224)
|
||||||
|
|
||||||
|
def test_split(self):
|
||||||
|
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
|
||||||
|
data = dataset.Dataset(ds, 4, ['pos', 'neg'])
|
||||||
|
train_data, test_data = data.split(0.5)
|
||||||
|
|
||||||
|
self.assertLen(train_data, 2)
|
||||||
|
for i, elem in enumerate(train_data._dataset):
|
||||||
|
self.assertTrue((elem.numpy() == np.array([i, 1])).all())
|
||||||
|
self.assertEqual(train_data.num_classes, 2)
|
||||||
|
self.assertEqual(train_data.index_to_label, ['pos', 'neg'])
|
||||||
|
|
||||||
|
self.assertLen(test_data, 2)
|
||||||
|
for i, elem in enumerate(test_data._dataset):
|
||||||
|
self.assertTrue((elem.numpy() == np.array([i, 0])).all())
|
||||||
|
self.assertEqual(test_data.num_classes, 2)
|
||||||
|
self.assertEqual(test_data.index_to_label, ['pos', 'neg'])
|
||||||
|
|
||||||
|
def test_from_folder(self):
|
||||||
|
data = dataset.Dataset.from_folder(self.image_path)
|
||||||
|
|
||||||
|
self.assertLen(data, 2)
|
||||||
|
self.assertEqual(data.num_classes, 2)
|
||||||
|
self.assertEqual(data.index_to_label, ['daisy', 'tulips'])
|
||||||
|
for image, label in data.gen_tf_dataset():
|
||||||
|
self.assertTrue(label.numpy() == 1 or label.numpy() == 0)
|
||||||
|
if label.numpy() == 0:
|
||||||
|
raw_image_tensor = dataset._load_image(
|
||||||
|
os.path.join(self.image_path, 'daisy', '0.jpeg'))
|
||||||
|
else:
|
||||||
|
raw_image_tensor = dataset._load_image(
|
||||||
|
os.path.join(self.image_path, 'tulips', '0.jpeg'))
|
||||||
|
self.assertTrue((image.numpy() == raw_image_tensor.numpy()).all())
|
||||||
|
|
||||||
|
def test_from_tfds(self):
|
||||||
|
# TODO: Remove this once tfds download error is fixed.
|
||||||
|
self.skipTest('Temporarily skip the unittest due to tfds download error.')
|
||||||
|
train_data, validation_data, test_data = (
|
||||||
|
dataset.Dataset.from_tfds('beans'))
|
||||||
|
self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset)
|
||||||
|
self.assertLen(train_data, 1034)
|
||||||
|
self.assertEqual(train_data.num_classes, 3)
|
||||||
|
self.assertEqual(train_data.index_to_label,
|
||||||
|
['angular_leaf_spot', 'bean_rust', 'healthy'])
|
||||||
|
|
||||||
|
self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset)
|
||||||
|
self.assertLen(validation_data, 133)
|
||||||
|
self.assertEqual(validation_data.num_classes, 3)
|
||||||
|
self.assertEqual(validation_data.index_to_label,
|
||||||
|
['angular_leaf_spot', 'bean_rust', 'healthy'])
|
||||||
|
|
||||||
|
self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset)
|
||||||
|
self.assertLen(test_data, 128)
|
||||||
|
self.assertEqual(test_data.num_classes, 3)
|
||||||
|
self.assertEqual(test_data.index_to_label,
|
||||||
|
['angular_leaf_spot', 'bean_rust', 'healthy'])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
|
@ -0,0 +1,74 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Hyperparameters for training image classification models."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import tempfile
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Expose other hyperparameters, e.g. data augmentation
|
||||||
|
# hyperparameters if requested.
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class HParams:
|
||||||
|
"""The hyperparameters for training image classifiers.
|
||||||
|
|
||||||
|
The hyperparameters include:
|
||||||
|
# Parameters about training data.
|
||||||
|
do_fine_tuning: If true, the base module is trained together with the
|
||||||
|
classification layer on top.
|
||||||
|
shuffle: A boolean controlling if shuffle the dataset. Default to false.
|
||||||
|
|
||||||
|
# Parameters about training configuration
|
||||||
|
train_epochs: Training will do this many iterations over the dataset.
|
||||||
|
batch_size: Each training step samples a batch of this many images.
|
||||||
|
learning_rate: The learning rate to use for gradient descent training.
|
||||||
|
dropout_rate: The fraction of the input units to drop, used in dropout
|
||||||
|
layer.
|
||||||
|
l1_regularizer: A regularizer that applies a L1 regularization penalty.
|
||||||
|
l2_regularizer: A regularizer that applies a L2 regularization penalty.
|
||||||
|
label_smoothing: Amount of label smoothing to apply. See tf.keras.losses for
|
||||||
|
more details.
|
||||||
|
do_data_augmentation: A boolean controlling whether the training dataset is
|
||||||
|
augmented by randomly distorting input images, including random cropping,
|
||||||
|
flipping, etc. See utils.image_preprocessing documentation for details.
|
||||||
|
steps_per_epoch: An optional integer indicate the number of training steps
|
||||||
|
per epoch. If not set, the training pipeline calculates the default steps
|
||||||
|
per epoch as the training dataset size devided by batch size.
|
||||||
|
decay_samples: Number of training samples used to calculate the decay steps
|
||||||
|
and create the training optimizer.
|
||||||
|
warmup_steps: Number of warmup steps for a linear increasing warmup schedule
|
||||||
|
on learning rate. Used to set up warmup schedule by model_util.WarmUp.
|
||||||
|
|
||||||
|
# Parameters about the saved checkpoint
|
||||||
|
model_dir: The location of model checkpoint files and exported model files.
|
||||||
|
"""
|
||||||
|
# Parameters about training data
|
||||||
|
do_fine_tuning: bool = False
|
||||||
|
shuffle: bool = False
|
||||||
|
# Parameters about training configuration
|
||||||
|
train_epochs: int = 5
|
||||||
|
batch_size: int = 32
|
||||||
|
learning_rate: float = 0.005
|
||||||
|
dropout_rate: float = 0.2
|
||||||
|
l1_regularizer: float = 0.0
|
||||||
|
l2_regularizer: float = 0.0001
|
||||||
|
label_smoothing: float = 0.1
|
||||||
|
do_data_augmentation: bool = True
|
||||||
|
steps_per_epoch: Optional[int] = None
|
||||||
|
decay_samples: int = 10000 * 256
|
||||||
|
warmup_epochs: int = 2
|
||||||
|
|
||||||
|
# Parameters about the saved checkpoint
|
||||||
|
model_dir: str = tempfile.mkdtemp()
|
|
@ -0,0 +1,172 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""APIs to train image classifier model."""
|
||||||
|
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow_hub as hub
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
|
||||||
|
from mediapipe.model_maker.python.core.tasks import classifier
|
||||||
|
from mediapipe.model_maker.python.core.utils import image_preprocessing
|
||||||
|
from mediapipe.model_maker.python.core.utils import model_util
|
||||||
|
from mediapipe.model_maker.python.core.utils import quantization
|
||||||
|
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
|
||||||
|
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
|
||||||
|
|
||||||
|
|
||||||
|
class ImageClassifier(classifier.Classifier):
|
||||||
|
"""ImageClassifier for building image classification model."""
|
||||||
|
|
||||||
|
def __init__(self, model_spec: ms.ModelSpec, index_to_label: List[Any],
|
||||||
|
hparams: hp.HParams):
|
||||||
|
"""Initializes ImageClassifier class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_spec: Specification for the model.
|
||||||
|
index_to_label: A list that maps from index to label class name.
|
||||||
|
hparams: The hyperparameters for training image classifier.
|
||||||
|
"""
|
||||||
|
super(ImageClassifier, self).__init__(
|
||||||
|
model_spec=model_spec,
|
||||||
|
index_to_label=index_to_label,
|
||||||
|
shuffle=hparams.shuffle,
|
||||||
|
full_train=hparams.do_fine_tuning)
|
||||||
|
self._hparams = hparams
|
||||||
|
self._preprocess = image_preprocessing.Preprocessor(
|
||||||
|
input_shape=self._model_spec.input_image_shape,
|
||||||
|
num_classes=self._num_classes,
|
||||||
|
mean_rgb=self._model_spec.mean_rgb,
|
||||||
|
stddev_rgb=self._model_spec.stddev_rgb,
|
||||||
|
use_augmentation=hparams.do_data_augmentation)
|
||||||
|
self._history = None # Training history returned from `keras_model.fit`.
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
model_spec: ms.SupportedModels,
|
||||||
|
train_data: classification_ds.ClassificationDataset,
|
||||||
|
validation_data: classification_ds.ClassificationDataset,
|
||||||
|
hparams: Optional[hp.HParams] = None,
|
||||||
|
) -> 'ImageClassifier':
|
||||||
|
"""Creates and trains an image classifier.
|
||||||
|
|
||||||
|
Loads data and trains the model based on data for image classification.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_spec: Specification for the model.
|
||||||
|
train_data: Training data.
|
||||||
|
validation_data: Validation data.
|
||||||
|
hparams: Hyperparameters for training image classifier.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An instance based on ImageClassifier.
|
||||||
|
"""
|
||||||
|
if hparams is None:
|
||||||
|
hparams = hp.HParams()
|
||||||
|
|
||||||
|
spec = ms.SupportedModels.get(model_spec)
|
||||||
|
image_classifier = cls(
|
||||||
|
model_spec=spec,
|
||||||
|
index_to_label=train_data.index_to_label,
|
||||||
|
hparams=hparams)
|
||||||
|
|
||||||
|
image_classifier._create_model()
|
||||||
|
|
||||||
|
tf.compat.v1.logging.info('Training the models...')
|
||||||
|
image_classifier._train(
|
||||||
|
train_data=train_data, validation_data=validation_data)
|
||||||
|
|
||||||
|
return image_classifier
|
||||||
|
|
||||||
|
def _train(self, train_data: classification_ds.ClassificationDataset,
|
||||||
|
validation_data: classification_ds.ClassificationDataset):
|
||||||
|
"""Trains the model with input train_data.
|
||||||
|
|
||||||
|
The training results are recorded by a self._history object returned by
|
||||||
|
tf.keras.Model.fit().
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
validation_data: Validation data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tf.compat.v1.logging.info('Training the models...')
|
||||||
|
hparams = self._hparams
|
||||||
|
if len(train_data) < hparams.batch_size:
|
||||||
|
raise ValueError('The size of the train_data (%d) couldn\'t be smaller '
|
||||||
|
'than batch_size (%d). To solve this problem, set '
|
||||||
|
'the batch_size smaller or increase the size of the '
|
||||||
|
'train_data.' % (len(train_data), hparams.batch_size))
|
||||||
|
|
||||||
|
train_dataset = train_data.gen_tf_dataset(
|
||||||
|
batch_size=hparams.batch_size,
|
||||||
|
is_training=True,
|
||||||
|
shuffle=self._shuffle,
|
||||||
|
preprocess=self._preprocess)
|
||||||
|
hparams.steps_per_epoch = model_util.get_steps_per_epoch(
|
||||||
|
steps_per_epoch=hparams.steps_per_epoch,
|
||||||
|
batch_size=hparams.batch_size,
|
||||||
|
train_data=train_data)
|
||||||
|
train_dataset = train_dataset.take(count=hparams.steps_per_epoch)
|
||||||
|
|
||||||
|
validation_dataset = validation_data.gen_tf_dataset(
|
||||||
|
batch_size=hparams.batch_size,
|
||||||
|
is_training=False,
|
||||||
|
preprocess=self._preprocess)
|
||||||
|
|
||||||
|
# Train the model.
|
||||||
|
self._history = train_image_classifier_lib.train_model(
|
||||||
|
model=self._model,
|
||||||
|
hparams=hparams,
|
||||||
|
train_ds=train_dataset,
|
||||||
|
validation_ds=validation_dataset)
|
||||||
|
|
||||||
|
def _create_model(self):
|
||||||
|
"""Creates the classifier model from TFHub pretrained models."""
|
||||||
|
module_layer = hub.KerasLayer(
|
||||||
|
handle=self._model_spec.uri, trainable=self._hparams.do_fine_tuning)
|
||||||
|
|
||||||
|
image_size = self._model_spec.input_image_shape
|
||||||
|
|
||||||
|
self._model = tf.keras.Sequential([
|
||||||
|
tf.keras.Input(shape=(image_size[0], image_size[1], 3)), module_layer,
|
||||||
|
tf.keras.layers.Dropout(rate=self._hparams.dropout_rate),
|
||||||
|
tf.keras.layers.Dense(
|
||||||
|
units=self._num_classes,
|
||||||
|
activation='softmax',
|
||||||
|
kernel_regularizer=tf.keras.regularizers.l1_l2(
|
||||||
|
l1=self._hparams.l1_regularizer,
|
||||||
|
l2=self._hparams.l2_regularizer))
|
||||||
|
])
|
||||||
|
print(self._model.summary())
|
||||||
|
|
||||||
|
def export_model(
|
||||||
|
self,
|
||||||
|
model_name: str = 'model.tflite',
|
||||||
|
quantization_config: Optional[quantization.QuantizationConfig] = None):
|
||||||
|
"""Converts the model to the requested formats and exports to a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: File name to save tflite model. The full export path is
|
||||||
|
{export_dir}/{tflite_filename}.
|
||||||
|
quantization_config: The configuration for model quantization.
|
||||||
|
"""
|
||||||
|
super().export_tflite(
|
||||||
|
self._hparams.model_dir,
|
||||||
|
model_name,
|
||||||
|
quantization_config,
|
||||||
|
preprocess=self._preprocess)
|
|
@ -0,0 +1,106 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Demo for making an image classifier model by MediaPipe Model Maker."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Dependency imports
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
from absl import logging
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.utils import quantization
|
||||||
|
from mediapipe.model_maker.python.vision import image_classifier
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
|
||||||
|
def define_flags() -> None:
|
||||||
|
"""Define flags for the image classifier model maker demo."""
|
||||||
|
flags.DEFINE_string('export_dir', None,
|
||||||
|
'The directory to save exported files.')
|
||||||
|
flags.DEFINE_string(
|
||||||
|
'input_data_dir', None,
|
||||||
|
"""The directory with input training data. If the training data is not
|
||||||
|
specified, the pipeline will download a default training dataset.""")
|
||||||
|
flags.DEFINE_enum_class('spec',
|
||||||
|
image_classifier.SupportedModels.EFFICIENTNET_LITE0,
|
||||||
|
image_classifier.SupportedModels,
|
||||||
|
'The image classifier to run.')
|
||||||
|
flags.DEFINE_enum('quantization', None, ['dynamic', 'int8', 'float16'],
|
||||||
|
'The quantization method to use when exporting the model.')
|
||||||
|
flags.mark_flag_as_required('export_dir')
|
||||||
|
|
||||||
|
|
||||||
|
def download_demo_data() -> str:
|
||||||
|
"""Downloads demo data, and returns directory path."""
|
||||||
|
data_dir = tf.keras.utils.get_file(
|
||||||
|
fname='flower_photos.tgz',
|
||||||
|
origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
|
||||||
|
extract=True)
|
||||||
|
return os.path.join(os.path.dirname(data_dir), 'flower_photos') # folder name
|
||||||
|
|
||||||
|
|
||||||
|
def run(data_dir: str, export_dir: str,
|
||||||
|
model_spec: image_classifier.SupportedModels,
|
||||||
|
quantization_option: str) -> None:
|
||||||
|
"""Runs demo."""
|
||||||
|
data = image_classifier.Dataset.from_folder(data_dir)
|
||||||
|
train_data, rest_data = data.split(0.8)
|
||||||
|
validation_data, test_data = rest_data.split(0.5)
|
||||||
|
|
||||||
|
model = image_classifier.ImageClassifier.create(
|
||||||
|
model_spec=model_spec,
|
||||||
|
train_data=train_data,
|
||||||
|
validation_data=validation_data,
|
||||||
|
hparams=image_classifier.HParams(model_dir=export_dir))
|
||||||
|
|
||||||
|
_, acc = model.evaluate(test_data)
|
||||||
|
print('Test accuracy: %f' % acc)
|
||||||
|
|
||||||
|
if quantization_option is None:
|
||||||
|
quantization_config = None
|
||||||
|
elif quantization_option == 'dynamic':
|
||||||
|
quantization_config = quantization.QuantizationConfig.for_dynamic()
|
||||||
|
elif quantization_option == 'int8':
|
||||||
|
quantization_config = quantization.QuantizationConfig.for_int8(train_data)
|
||||||
|
elif quantization_option == 'float16':
|
||||||
|
quantization_config = quantization.QuantizationConfig.for_float16()
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Quantization: {quantization} is not recognized')
|
||||||
|
|
||||||
|
model.export_model(quantization_config=quantization_config)
|
||||||
|
model.export_labels(export_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def main(_) -> None:
|
||||||
|
logging.set_verbosity(logging.INFO)
|
||||||
|
|
||||||
|
if FLAGS.input_data_dir is None:
|
||||||
|
data_dir = download_demo_data()
|
||||||
|
else:
|
||||||
|
data_dir = FLAGS.input_data_dir
|
||||||
|
|
||||||
|
export_dir = os.path.expanduser(FLAGS.export_dir)
|
||||||
|
run(data_dir=data_dir,
|
||||||
|
export_dir=export_dir,
|
||||||
|
model_spec=FLAGS.spec,
|
||||||
|
quantization_option=FLAGS.quantization)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
define_flags()
|
||||||
|
app.run(main)
|
|
@ -0,0 +1,122 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision import image_classifier
|
||||||
|
|
||||||
|
|
||||||
|
def _fill_image(rgb, image_size):
|
||||||
|
r, g, b = rgb
|
||||||
|
return np.broadcast_to(
|
||||||
|
np.array([[[r, g, b]]], dtype=np.uint8),
|
||||||
|
shape=(image_size, image_size, 3))
|
||||||
|
|
||||||
|
|
||||||
|
class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
IMAGE_SIZE = 24
|
||||||
|
IMAGES_PER_CLASS = 2
|
||||||
|
CMY_NAMES_AND_RGB_VALUES = (('cyan', (0, 255, 255)),
|
||||||
|
('magenta', (255, 0, 255)), ('yellow', (255, 255,
|
||||||
|
0)))
|
||||||
|
|
||||||
|
def _gen(self):
|
||||||
|
for i, (_, rgb) in enumerate(self.CMY_NAMES_AND_RGB_VALUES):
|
||||||
|
for _ in range(self.IMAGES_PER_CLASS):
|
||||||
|
yield (_fill_image(rgb, self.IMAGE_SIZE), i)
|
||||||
|
|
||||||
|
def _gen_cmy_data(self):
|
||||||
|
ds = tf.data.Dataset.from_generator(
|
||||||
|
self._gen, (tf.uint8, tf.int64), (tf.TensorShape(
|
||||||
|
[self.IMAGE_SIZE, self.IMAGE_SIZE, 3]), tf.TensorShape([])))
|
||||||
|
data = image_classifier.Dataset(ds, self.IMAGES_PER_CLASS * 3,
|
||||||
|
['cyan', 'magenta', 'yellow'])
|
||||||
|
return data
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(ImageClassifierTest, self).setUp()
|
||||||
|
all_data = self._gen_cmy_data()
|
||||||
|
# Splits data, 90% data for training, 10% for testing
|
||||||
|
self.train_data, self.test_data = all_data.split(0.9)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
dict(
|
||||||
|
testcase_name='mobilenet_v2',
|
||||||
|
model_spec=image_classifier.SupportedModels.MOBILENET_V2,
|
||||||
|
hparams=image_classifier.HParams(
|
||||||
|
train_epochs=1, batch_size=1, shuffle=True)),
|
||||||
|
dict(
|
||||||
|
testcase_name='resnet_50',
|
||||||
|
model_spec=image_classifier.SupportedModels.RESNET_50,
|
||||||
|
hparams=image_classifier.HParams(
|
||||||
|
train_epochs=1, batch_size=1, shuffle=True)),
|
||||||
|
dict(
|
||||||
|
testcase_name='efficientnet_lite0',
|
||||||
|
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
|
||||||
|
hparams=image_classifier.HParams(
|
||||||
|
train_epochs=1, batch_size=1, shuffle=True)),
|
||||||
|
dict(
|
||||||
|
testcase_name='efficientnet_lite1',
|
||||||
|
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE1,
|
||||||
|
hparams=image_classifier.HParams(
|
||||||
|
train_epochs=1, batch_size=1, shuffle=True)),
|
||||||
|
dict(
|
||||||
|
testcase_name='efficientnet_lite2',
|
||||||
|
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE2,
|
||||||
|
hparams=image_classifier.HParams(
|
||||||
|
train_epochs=1, batch_size=1, shuffle=True)),
|
||||||
|
dict(
|
||||||
|
testcase_name='efficientnet_lite3',
|
||||||
|
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE3,
|
||||||
|
hparams=image_classifier.HParams(
|
||||||
|
train_epochs=1, batch_size=1, shuffle=True)),
|
||||||
|
dict(
|
||||||
|
testcase_name='efficientnet_lite4',
|
||||||
|
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE4,
|
||||||
|
hparams=image_classifier.HParams(
|
||||||
|
train_epochs=1, batch_size=1, shuffle=True)),
|
||||||
|
)
|
||||||
|
def test_create_and_train_model(self,
|
||||||
|
model_spec: image_classifier.SupportedModels,
|
||||||
|
hparams: image_classifier.HParams):
|
||||||
|
model = image_classifier.ImageClassifier.create(
|
||||||
|
model_spec=model_spec,
|
||||||
|
train_data=self.train_data,
|
||||||
|
hparams=hparams,
|
||||||
|
validation_data=self.test_data)
|
||||||
|
self._test_accuracy(model)
|
||||||
|
|
||||||
|
def test_efficientnetlite0_model_with_model_maker_retraining_lib(self):
|
||||||
|
hparams = image_classifier.HParams(
|
||||||
|
train_epochs=1, batch_size=1, shuffle=True)
|
||||||
|
model = image_classifier.ImageClassifier.create(
|
||||||
|
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
|
||||||
|
train_data=self.train_data,
|
||||||
|
hparams=hparams,
|
||||||
|
validation_data=self.test_data)
|
||||||
|
self._test_accuracy(model)
|
||||||
|
|
||||||
|
def _test_accuracy(self, model, threshold=0.0):
|
||||||
|
_, accuracy = model.evaluate(self.test_data)
|
||||||
|
self.assertGreaterEqual(accuracy, threshold)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# Load compressed models from tensorflow_hub
|
||||||
|
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
|
||||||
|
tf.test.main()
|
|
@ -0,0 +1,104 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Image classifier model specification."""
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import functools
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSpec(object):
|
||||||
|
"""Specification of image classifier model."""
|
||||||
|
|
||||||
|
mean_rgb = [0.0]
|
||||||
|
stddev_rgb = [255.0]
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
uri: str,
|
||||||
|
input_image_shape: Optional[List[int]] = None,
|
||||||
|
name: str = ''):
|
||||||
|
"""Initializes a new instance of the `ImageModelSpec` class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
uri: str, URI to the pretrained model.
|
||||||
|
input_image_shape: list of int, input image shape. Default: [224, 224].
|
||||||
|
name: str, model spec name.
|
||||||
|
"""
|
||||||
|
self.uri = uri
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
if input_image_shape is None:
|
||||||
|
input_image_shape = [224, 224]
|
||||||
|
self.input_image_shape = input_image_shape
|
||||||
|
|
||||||
|
|
||||||
|
mobilenet_v2_spec = functools.partial(
|
||||||
|
ModelSpec,
|
||||||
|
uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4',
|
||||||
|
name='mobilenet_v2')
|
||||||
|
|
||||||
|
resnet_50_spec = functools.partial(
|
||||||
|
ModelSpec,
|
||||||
|
uri='https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4',
|
||||||
|
name='resnet_50')
|
||||||
|
|
||||||
|
efficientnet_lite0_spec = functools.partial(
|
||||||
|
ModelSpec,
|
||||||
|
uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2',
|
||||||
|
name='efficientnet_lite0')
|
||||||
|
|
||||||
|
efficientnet_lite1_spec = functools.partial(
|
||||||
|
ModelSpec,
|
||||||
|
uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2',
|
||||||
|
input_image_shape=[240, 240],
|
||||||
|
name='efficientnet_lite1')
|
||||||
|
|
||||||
|
efficientnet_lite2_spec = functools.partial(
|
||||||
|
ModelSpec,
|
||||||
|
uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2',
|
||||||
|
input_image_shape=[260, 260],
|
||||||
|
name='efficientnet_lite2')
|
||||||
|
|
||||||
|
efficientnet_lite3_spec = functools.partial(
|
||||||
|
ModelSpec,
|
||||||
|
uri='https://tfhub.dev/tensorflow/efficientnet/lite3/feature-vector/2',
|
||||||
|
input_image_shape=[280, 280],
|
||||||
|
name='efficientnet_lite3')
|
||||||
|
|
||||||
|
efficientnet_lite4_spec = functools.partial(
|
||||||
|
ModelSpec,
|
||||||
|
uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2',
|
||||||
|
input_image_shape=[300, 300],
|
||||||
|
name='efficientnet_lite4')
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Document the exposed models.
|
||||||
|
@enum.unique
|
||||||
|
class SupportedModels(enum.Enum):
|
||||||
|
"""Image classifier model supported by model maker."""
|
||||||
|
MOBILENET_V2 = mobilenet_v2_spec
|
||||||
|
RESNET_50 = resnet_50_spec
|
||||||
|
EFFICIENTNET_LITE0 = efficientnet_lite0_spec
|
||||||
|
EFFICIENTNET_LITE1 = efficientnet_lite1_spec
|
||||||
|
EFFICIENTNET_LITE2 = efficientnet_lite2_spec
|
||||||
|
EFFICIENTNET_LITE3 = efficientnet_lite3_spec
|
||||||
|
EFFICIENTNET_LITE4 = efficientnet_lite4_spec
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
|
||||||
|
"""Gets model spec from the input enum and initializes it."""
|
||||||
|
if spec not in cls:
|
||||||
|
raise TypeError('Unsupported image classifier spec: {}'.format(spec))
|
||||||
|
|
||||||
|
return spec.value()
|
|
@ -0,0 +1,93 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from typing import Callable, List
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSpecTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
dict(
|
||||||
|
testcase_name='mobilenet_v2_spec_test',
|
||||||
|
model_spec=ms.mobilenet_v2_spec,
|
||||||
|
expected_uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4',
|
||||||
|
expected_name='mobilenet_v2',
|
||||||
|
expected_input_image_shape=[224, 224]),
|
||||||
|
dict(
|
||||||
|
testcase_name='resnet_50_spec_test',
|
||||||
|
model_spec=ms.resnet_50_spec,
|
||||||
|
expected_uri='https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4',
|
||||||
|
expected_name='resnet_50',
|
||||||
|
expected_input_image_shape=[224, 224]),
|
||||||
|
dict(
|
||||||
|
testcase_name='efficientnet_lite0_spec_test',
|
||||||
|
model_spec=ms.efficientnet_lite0_spec,
|
||||||
|
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2',
|
||||||
|
expected_name='efficientnet_lite0',
|
||||||
|
expected_input_image_shape=[224, 224]),
|
||||||
|
dict(
|
||||||
|
testcase_name='efficientnet_lite1_spec_test',
|
||||||
|
model_spec=ms.efficientnet_lite1_spec,
|
||||||
|
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2',
|
||||||
|
expected_name='efficientnet_lite1',
|
||||||
|
expected_input_image_shape=[240, 240]),
|
||||||
|
dict(
|
||||||
|
testcase_name='efficientnet_lite2_spec_test',
|
||||||
|
model_spec=ms.efficientnet_lite2_spec,
|
||||||
|
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2',
|
||||||
|
expected_name='efficientnet_lite2',
|
||||||
|
expected_input_image_shape=[260, 260]),
|
||||||
|
dict(
|
||||||
|
testcase_name='efficientnet_lite3_spec_test',
|
||||||
|
model_spec=ms.efficientnet_lite3_spec,
|
||||||
|
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite3/feature-vector/2',
|
||||||
|
expected_name='efficientnet_lite3',
|
||||||
|
expected_input_image_shape=[280, 280]),
|
||||||
|
dict(
|
||||||
|
testcase_name='efficientnet_lite4_spec_test',
|
||||||
|
model_spec=ms.efficientnet_lite4_spec,
|
||||||
|
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2',
|
||||||
|
expected_name='efficientnet_lite4',
|
||||||
|
expected_input_image_shape=[300, 300]),
|
||||||
|
)
|
||||||
|
def test_predefiend_spec(self, model_spec: Callable[..., ms.ModelSpec],
|
||||||
|
expected_uri: str, expected_name: str,
|
||||||
|
expected_input_image_shape: List[int]):
|
||||||
|
model_spec_obj = model_spec()
|
||||||
|
self.assertIsInstance(model_spec_obj, ms.ModelSpec)
|
||||||
|
self.assertEqual(model_spec_obj.uri, expected_uri)
|
||||||
|
self.assertEqual(model_spec_obj.name, expected_name)
|
||||||
|
self.assertEqual(model_spec_obj.input_image_shape,
|
||||||
|
expected_input_image_shape)
|
||||||
|
|
||||||
|
def test_create_spec(self):
|
||||||
|
custom_model_spec = ms.ModelSpec(
|
||||||
|
uri='https://custom_model',
|
||||||
|
input_image_shape=[128, 128],
|
||||||
|
name='custom_model')
|
||||||
|
self.assertEqual(custom_model_spec.uri, 'https://custom_model')
|
||||||
|
self.assertEqual(custom_model_spec.name, 'custom_model')
|
||||||
|
self.assertEqual(custom_model_spec.input_image_shape, [128, 128])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# Load compressed models from tensorflow_hub
|
||||||
|
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
|
||||||
|
tf.test.main()
|
|
@ -0,0 +1,103 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Library to train model."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.utils import model_util
|
||||||
|
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
|
||||||
|
|
||||||
|
|
||||||
|
def _create_optimizer(init_lr: float, decay_steps: int,
|
||||||
|
warmup_steps: int) -> tf.keras.optimizers.Optimizer:
|
||||||
|
"""Creates an optimizer with learning rate schedule.
|
||||||
|
|
||||||
|
Uses Keras CosineDecay schedule for the learning rate by default.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
init_lr: Initial learning rate.
|
||||||
|
decay_steps: Number of steps to decay over.
|
||||||
|
warmup_steps: Number of steps to do warmup for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tf.keras.optimizers.Optimizer for model training.
|
||||||
|
"""
|
||||||
|
learning_rate_fn = tf.keras.experimental.CosineDecay(
|
||||||
|
initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0)
|
||||||
|
if warmup_steps:
|
||||||
|
learning_rate_fn = model_util.WarmUp(
|
||||||
|
initial_learning_rate=init_lr,
|
||||||
|
decay_schedule_fn=learning_rate_fn,
|
||||||
|
warmup_steps=warmup_steps)
|
||||||
|
optimizer = tf.keras.optimizers.RMSprop(
|
||||||
|
learning_rate=learning_rate_fn, rho=0.9, momentum=0.9, epsilon=0.001)
|
||||||
|
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def _get_default_callbacks(model_dir: str) -> List[tf.keras.callbacks.Callback]:
|
||||||
|
"""Gets default callbacks."""
|
||||||
|
summary_dir = os.path.join(model_dir, 'summaries')
|
||||||
|
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
|
||||||
|
# Save checkpoint every 20 epochs.
|
||||||
|
|
||||||
|
checkpoint_path = os.path.join(model_dir, 'checkpoint')
|
||||||
|
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||||
|
checkpoint_path, save_weights_only=True, period=20)
|
||||||
|
return [summary_callback, checkpoint_callback]
|
||||||
|
|
||||||
|
|
||||||
|
def train_model(model: tf.keras.Model, hparams: hp.HParams,
|
||||||
|
train_ds: tf.data.Dataset,
|
||||||
|
validation_ds: tf.data.Dataset) -> tf.keras.callbacks.History:
|
||||||
|
"""Trains model with the input data and hyperparameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Input tf.keras.Model.
|
||||||
|
hparams: Hyperparameters for training image classifier.
|
||||||
|
train_ds: tf.data.Dataset, training data to be fed in tf.keras.Model.fit().
|
||||||
|
validation_ds: tf.data.Dataset, validation data to be fed in
|
||||||
|
tf.keras.Model.fit().
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The tf.keras.callbacks.History object returned by tf.keras.Model.fit().
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Learning rate is linear to batch size.
|
||||||
|
learning_rate = hparams.learning_rate * hparams.batch_size / 256
|
||||||
|
|
||||||
|
# Get decay steps.
|
||||||
|
total_training_steps = hparams.steps_per_epoch * hparams.train_epochs
|
||||||
|
default_decay_steps = hparams.decay_samples // hparams.batch_size
|
||||||
|
decay_steps = max(total_training_steps, default_decay_steps)
|
||||||
|
|
||||||
|
warmup_steps = hparams.warmup_epochs * hparams.steps_per_epoch
|
||||||
|
optimizer = _create_optimizer(
|
||||||
|
init_lr=learning_rate, decay_steps=decay_steps, warmup_steps=warmup_steps)
|
||||||
|
|
||||||
|
loss = tf.keras.losses.CategoricalCrossentropy(
|
||||||
|
label_smoothing=hparams.label_smoothing)
|
||||||
|
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
|
||||||
|
callbacks = _get_default_callbacks(hparams.model_dir)
|
||||||
|
|
||||||
|
# Train the model.
|
||||||
|
return model.fit(
|
||||||
|
x=train_ds,
|
||||||
|
epochs=hparams.train_epochs,
|
||||||
|
steps_per_epoch=hparams.steps_per_epoch,
|
||||||
|
validation_data=validation_ds,
|
||||||
|
callbacks=callbacks)
|
|
@ -2,3 +2,5 @@ absl-py
|
||||||
numpy
|
numpy
|
||||||
opencv-contrib-python
|
opencv-contrib-python
|
||||||
tensorflow
|
tensorflow
|
||||||
|
tensorflow-datasets
|
||||||
|
tensorflow-hub
|
||||||
|
|
Loading…
Reference in New Issue
Block a user