Open sourcing model_maker/core/util and model_maker/core/data
PiperOrigin-RevId: 478835650
This commit is contained in:
parent
38baaa00b1
commit
fb3b0e788e
|
@ -53,7 +53,7 @@ RUN pip3 install wheel
|
||||||
RUN pip3 install future
|
RUN pip3 install future
|
||||||
RUN pip3 install absl-py numpy opencv-contrib-python protobuf==3.20.1
|
RUN pip3 install absl-py numpy opencv-contrib-python protobuf==3.20.1
|
||||||
RUN pip3 install six==1.14.0
|
RUN pip3 install six==1.14.0
|
||||||
RUN pip3 install tensorflow==2.2.0
|
RUN pip3 install tensorflow
|
||||||
RUN pip3 install tf_slim
|
RUN pip3 install tf_slim
|
||||||
|
|
||||||
RUN ln -s /usr/bin/python3 /usr/bin/python
|
RUN ln -s /usr/bin/python3 /usr/bin/python
|
||||||
|
|
22
mediapipe/model_maker/BUILD
Normal file
22
mediapipe/model_maker/BUILD
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//visibility:public"])
|
||||||
|
|
||||||
|
package_group(
|
||||||
|
name = "internal",
|
||||||
|
packages = [
|
||||||
|
"//mediapipe/model_maker/...",
|
||||||
|
],
|
||||||
|
)
|
13
mediapipe/model_maker/__init__.py
Normal file
13
mediapipe/model_maker/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
22
mediapipe/model_maker/python/BUILD
Normal file
22
mediapipe/model_maker/python/BUILD
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//visibility:public"])
|
||||||
|
|
||||||
|
package_group(
|
||||||
|
name = "internal",
|
||||||
|
packages = [
|
||||||
|
"//mediapipe/model_maker/...",
|
||||||
|
],
|
||||||
|
)
|
13
mediapipe/model_maker/python/__init__.py
Normal file
13
mediapipe/model_maker/python/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
19
mediapipe/model_maker/python/core/BUILD
Normal file
19
mediapipe/model_maker/python/core/BUILD
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//mediapipe:__subpackages__"],
|
||||||
|
)
|
||||||
|
|
||||||
|
licenses(["notice"])
|
13
mediapipe/model_maker/python/core/__init__.py
Normal file
13
mediapipe/model_maker/python/core/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
68
mediapipe/model_maker/python/core/data/BUILD
Normal file
68
mediapipe/model_maker/python/core/data/BUILD
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Placeholder for internal Python strict library and test compatibility macro.
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//mediapipe:__subpackages__"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "data_util",
|
||||||
|
srcs = ["data_util.py"],
|
||||||
|
srcs_version = "PY3",
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "data_util_test",
|
||||||
|
srcs = ["data_util_test.py"],
|
||||||
|
data = ["//mediapipe/model_maker/python/core/data/testdata"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [":data_util"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "dataset",
|
||||||
|
srcs = ["dataset.py"],
|
||||||
|
srcs_version = "PY3",
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "dataset_test",
|
||||||
|
srcs = ["dataset_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":dataset",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:test_util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "classification_dataset",
|
||||||
|
srcs = ["classification_dataset.py"],
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [":dataset"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "classification_dataset_test",
|
||||||
|
srcs = ["classification_dataset_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [":classification_dataset"],
|
||||||
|
)
|
13
mediapipe/model_maker/python/core/data/__init__.py
Normal file
13
mediapipe/model_maker/python/core/data/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
|
@ -0,0 +1,47 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Common classification dataset library."""
|
||||||
|
|
||||||
|
from typing import Any, Tuple
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import dataset as ds
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationDataset(ds.Dataset):
|
||||||
|
"""DataLoader for classification models."""
|
||||||
|
|
||||||
|
def __init__(self, dataset: tf.data.Dataset, size: int, index_to_label: Any):
|
||||||
|
super().__init__(dataset, size)
|
||||||
|
self.index_to_label = index_to_label
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_classes(self: ds._DatasetT) -> int:
|
||||||
|
return len(self.index_to_label)
|
||||||
|
|
||||||
|
def split(self: ds._DatasetT,
|
||||||
|
fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:
|
||||||
|
"""Splits dataset into two sub-datasets with the given fraction.
|
||||||
|
|
||||||
|
Primarily used for splitting the data set into training and testing sets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fraction: float, demonstrates the fraction of the first returned
|
||||||
|
subdataset in the original data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The splitted two sub datasets.
|
||||||
|
"""
|
||||||
|
return self._split(fraction, self.index_to_label)
|
|
@ -0,0 +1,68 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Dependency imports
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationDataLoaderTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test_split(self):
|
||||||
|
|
||||||
|
class MagicClassificationDataLoader(
|
||||||
|
classification_dataset.ClassificationDataset):
|
||||||
|
|
||||||
|
def __init__(self, dataset, size, index_to_label, value):
|
||||||
|
super(MagicClassificationDataLoader,
|
||||||
|
self).__init__(dataset, size, index_to_label)
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def split(self, fraction):
|
||||||
|
return self._split(fraction, self.index_to_label, self.value)
|
||||||
|
|
||||||
|
# Some dummy inputs.
|
||||||
|
magic_value = 42
|
||||||
|
num_classes = 2
|
||||||
|
index_to_label = (False, True)
|
||||||
|
|
||||||
|
# Create data loader from sample data.
|
||||||
|
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
|
||||||
|
data = MagicClassificationDataLoader(ds, len(ds), index_to_label,
|
||||||
|
magic_value)
|
||||||
|
|
||||||
|
# Train/Test data split.
|
||||||
|
fraction = .25
|
||||||
|
train_data, test_data = data.split(fraction)
|
||||||
|
|
||||||
|
# `split` should return instances of child DataLoader.
|
||||||
|
self.assertIsInstance(train_data, MagicClassificationDataLoader)
|
||||||
|
self.assertIsInstance(test_data, MagicClassificationDataLoader)
|
||||||
|
|
||||||
|
# Make sure number of entries are right.
|
||||||
|
self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data))
|
||||||
|
self.assertLen(train_data, fraction * len(ds))
|
||||||
|
self.assertLen(test_data, len(ds) - len(train_data))
|
||||||
|
|
||||||
|
# Make sure attributes propagated correctly.
|
||||||
|
self.assertEqual(train_data.num_classes, num_classes)
|
||||||
|
self.assertEqual(test_data.index_to_label, index_to_label)
|
||||||
|
self.assertEqual(train_data.value, magic_value)
|
||||||
|
self.assertEqual(test_data.value, magic_value)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
35
mediapipe/model_maker/python/core/data/data_util.py
Normal file
35
mediapipe/model_maker/python/core/data/data_util.py
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Data utility library."""
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(path: str) -> np.ndarray:
|
||||||
|
"""Loads an image as an RGB numpy array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: input image file absolute path.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An RGB image in numpy.ndarray.
|
||||||
|
"""
|
||||||
|
tf.compat.v1.logging.info('Loading RGB image %s', path)
|
||||||
|
# TODO Replace the OpenCV image load and conversion library by
|
||||||
|
# MediaPipe image utility library once it is ready.
|
||||||
|
image = cv2.imread(path)
|
||||||
|
return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
44
mediapipe/model_maker/python/core/data/data_util_test.py
Normal file
44
mediapipe/model_maker/python/core/data/data_util_test.py
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Dependency imports
|
||||||
|
|
||||||
|
from absl import flags
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import data_util
|
||||||
|
|
||||||
|
_WORKSPACE = "mediapipe"
|
||||||
|
_TEST_DATA_DIR = os.path.join(
|
||||||
|
_WORKSPACE, 'mediapipe/model_maker/python/core/data/testdata')
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
|
||||||
|
class DataUtilTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test_load_rgb_image(self):
|
||||||
|
image_path = os.path.join(FLAGS.test_srcdir, _TEST_DATA_DIR, 'test.jpg')
|
||||||
|
image_data = data_util.load_image(image_path)
|
||||||
|
self.assertEqual(image_data.shape, (5184, 3456, 3))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
164
mediapipe/model_maker/python/core/data/dataset.py
Normal file
164
mediapipe/model_maker/python/core/data/dataset.py
Normal file
|
@ -0,0 +1,164 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Common dataset for model training and evaluation."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from typing import Callable, Optional, Tuple, TypeVar
|
||||||
|
|
||||||
|
# Dependency imports
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
_DatasetT = TypeVar('_DatasetT', bound='Dataset')
|
||||||
|
|
||||||
|
|
||||||
|
class Dataset(object):
|
||||||
|
"""A generic dataset class for loading model training and evaluation dataset.
|
||||||
|
|
||||||
|
For each ML task, such as image classification, text classification etc., a
|
||||||
|
subclass can be derived from this class to provide task-specific data loading
|
||||||
|
utilities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, tf_dataset: tf.data.Dataset, size: Optional[int] = None):
|
||||||
|
"""Initializes Dataset class.
|
||||||
|
|
||||||
|
To build dataset from raw data, consider using the task specific utilities,
|
||||||
|
e.g. from_folder().
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tf_dataset: A tf.data.Dataset object that contains a potentially large set
|
||||||
|
of elements, where each element is a pair of (input_data, target). The
|
||||||
|
`input_data` means the raw input data, like an image, a text etc., while
|
||||||
|
the `target` means the ground truth of the raw input data, e.g. the
|
||||||
|
classification label of the image etc.
|
||||||
|
size: The size of the dataset. tf.data.Dataset donesn't support a function
|
||||||
|
to get the length directly since it's lazy-loaded and may be infinite.
|
||||||
|
"""
|
||||||
|
self._dataset = tf_dataset
|
||||||
|
self._size = size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def size(self) -> Optional[int]:
|
||||||
|
"""Returns the size of the dataset.
|
||||||
|
|
||||||
|
Note that this function may return None becuase the exact size of the
|
||||||
|
dataset isn't a necessary parameter to create an instance of this class,
|
||||||
|
and tf.data.Dataset donesn't support a function to get the length directly
|
||||||
|
since it's lazy-loaded and may be infinite.
|
||||||
|
In most cases, however, when an instance of this class is created by helper
|
||||||
|
functions like 'from_folder', the size of the dataset will be preprocessed,
|
||||||
|
and this function can return an int representing the size of the dataset.
|
||||||
|
"""
|
||||||
|
return self._size
|
||||||
|
|
||||||
|
def gen_tf_dataset(self,
|
||||||
|
batch_size: int = 1,
|
||||||
|
is_training: bool = False,
|
||||||
|
shuffle: bool = False,
|
||||||
|
preprocess: Optional[Callable[..., bool]] = None,
|
||||||
|
drop_remainder: bool = False) -> tf.data.Dataset:
|
||||||
|
"""Generates a batched tf.data.Dataset for training/evaluation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_size: An integer, the returned dataset will be batched by this size.
|
||||||
|
is_training: A boolean, when True, the returned dataset will be optionally
|
||||||
|
shuffled and repeated as an endless dataset.
|
||||||
|
shuffle: A boolean, when True, the returned dataset will be shuffled to
|
||||||
|
create randomness during model training.
|
||||||
|
preprocess: A function taking three arguments in order, feature, label and
|
||||||
|
boolean is_training.
|
||||||
|
drop_remainder: boolean, whether the finaly batch drops remainder.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A TF dataset ready to be consumed by Keras model.
|
||||||
|
"""
|
||||||
|
dataset = self._dataset
|
||||||
|
|
||||||
|
if preprocess:
|
||||||
|
preprocess = functools.partial(preprocess, is_training=is_training)
|
||||||
|
dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
|
||||||
|
|
||||||
|
if is_training:
|
||||||
|
if shuffle:
|
||||||
|
# Shuffle size should be bigger than the batch_size. Otherwise it's only
|
||||||
|
# shuffling within the batch, which equals to not having shuffle.
|
||||||
|
buffer_size = 3 * batch_size
|
||||||
|
# But since we are doing shuffle before repeat, it doesn't make sense to
|
||||||
|
# shuffle more than total available entries.
|
||||||
|
# TODO: Investigate if shuffling before / after repeat
|
||||||
|
# dataset can get a better performance?
|
||||||
|
# Shuffle after repeat will give a more randomized dataset and mix the
|
||||||
|
# epoch boundary: https://www.tensorflow.org/guide/data
|
||||||
|
if self._size:
|
||||||
|
buffer_size = min(self._size, buffer_size)
|
||||||
|
dataset = dataset.shuffle(buffer_size=buffer_size)
|
||||||
|
|
||||||
|
dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
|
||||||
|
dataset = dataset.prefetch(tf.data.AUTOTUNE)
|
||||||
|
# TODO: Consider converting dataset to distributed dataset
|
||||||
|
# here.
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""Returns the number of element of the dataset."""
|
||||||
|
if self._size is not None:
|
||||||
|
return self._size
|
||||||
|
else:
|
||||||
|
return len(self._dataset)
|
||||||
|
|
||||||
|
def split(self: _DatasetT, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
|
||||||
|
"""Splits dataset into two sub-datasets with the given fraction.
|
||||||
|
|
||||||
|
Primarily used for splitting the data set into training and testing sets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fraction: A float value defines the fraction of the first returned
|
||||||
|
subdataset in the original data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The splitted two sub datasets.
|
||||||
|
"""
|
||||||
|
return self._split(fraction)
|
||||||
|
|
||||||
|
def _split(self: _DatasetT, fraction: float,
|
||||||
|
*args) -> Tuple[_DatasetT, _DatasetT]:
|
||||||
|
"""Implementation for `split` method and returns sub-class instances.
|
||||||
|
|
||||||
|
Child DataLoader classes, if requires additional constructor arguments,
|
||||||
|
should implement their own `split` method by calling `_split` with all
|
||||||
|
arguments to the constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fraction: A float value defines the fraction of the first returned
|
||||||
|
subdataset in the original data.
|
||||||
|
*args: additional arguments passed to the sub-class constructor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The splitted two sub datasets.
|
||||||
|
"""
|
||||||
|
assert (fraction > 0 and fraction < 1)
|
||||||
|
|
||||||
|
dataset = self._dataset
|
||||||
|
|
||||||
|
train_size = int(self._size * fraction)
|
||||||
|
trainset = self.__class__(dataset.take(train_size), train_size, *args)
|
||||||
|
|
||||||
|
test_size = self._size - train_size
|
||||||
|
testset = self.__class__(dataset.skip(train_size), test_size, *args)
|
||||||
|
|
||||||
|
return trainset, testset
|
78
mediapipe/model_maker/python/core/data/dataset_test.py
Normal file
78
mediapipe/model_maker/python/core/data/dataset_test.py
Normal file
|
@ -0,0 +1,78 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# Dependency imports
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import dataset as ds
|
||||||
|
from mediapipe.model_maker.python.core.utils import test_util
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test_split(self):
|
||||||
|
dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
|
||||||
|
[1, 0]])
|
||||||
|
data = ds.Dataset(dataset, 4)
|
||||||
|
train_data, test_data = data.split(0.5)
|
||||||
|
|
||||||
|
self.assertLen(train_data, 2)
|
||||||
|
self.assertIsInstance(train_data, ds.Dataset)
|
||||||
|
self.assertIsInstance(test_data, ds.Dataset)
|
||||||
|
for i, elem in enumerate(train_data.gen_tf_dataset()):
|
||||||
|
self.assertTrue((elem.numpy() == np.array([i, 1])).all())
|
||||||
|
|
||||||
|
self.assertLen(test_data, 2)
|
||||||
|
for i, elem in enumerate(test_data.gen_tf_dataset()):
|
||||||
|
self.assertTrue((elem.numpy() == np.array([i, 0])).all())
|
||||||
|
|
||||||
|
def test_len(self):
|
||||||
|
size = 4
|
||||||
|
dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
|
||||||
|
[1, 0]])
|
||||||
|
data = ds.Dataset(dataset, size)
|
||||||
|
self.assertLen(data, size)
|
||||||
|
|
||||||
|
def test_gen_tf_dataset(self):
|
||||||
|
input_dim = 8
|
||||||
|
data = test_util.create_dataset(
|
||||||
|
data_size=2, input_shape=[input_dim], num_classes=2)
|
||||||
|
|
||||||
|
dataset = data.gen_tf_dataset()
|
||||||
|
self.assertLen(dataset, 2)
|
||||||
|
for (feature, label) in dataset:
|
||||||
|
self.assertTrue((tf.shape(feature).numpy() == np.array([1, 8])).all())
|
||||||
|
self.assertTrue((tf.shape(label).numpy() == np.array([1])).all())
|
||||||
|
|
||||||
|
dataset2 = data.gen_tf_dataset(batch_size=2)
|
||||||
|
self.assertLen(dataset2, 1)
|
||||||
|
for (feature, label) in dataset2:
|
||||||
|
self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all())
|
||||||
|
self.assertTrue((tf.shape(label).numpy() == np.array([2])).all())
|
||||||
|
|
||||||
|
dataset3 = data.gen_tf_dataset(batch_size=2, is_training=True, shuffle=True)
|
||||||
|
self.assertEqual(dataset3.cardinality(), 1)
|
||||||
|
for (feature, label) in dataset3.take(10):
|
||||||
|
self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all())
|
||||||
|
self.assertTrue((tf.shape(label).numpy() == np.array([2])).all())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
30
mediapipe/model_maker/python/core/data/testdata/BUILD
vendored
Normal file
30
mediapipe/model_maker/python/core/data/testdata/BUILD
vendored
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
load(
|
||||||
|
"//mediapipe/framework/tool:mediapipe_files.bzl",
|
||||||
|
"mediapipe_files",
|
||||||
|
)
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//mediapipe/model_maker/python/core/data:__subpackages__"],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
mediapipe_files(srcs = ["test.jpg"])
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "testdata",
|
||||||
|
srcs = ["test.jpg"],
|
||||||
|
)
|
100
mediapipe/model_maker/python/core/utils/BUILD
Normal file
100
mediapipe/model_maker/python/core/utils/BUILD
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Placeholder for internal Python strict library and test compatibility macro.
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//mediapipe:__subpackages__"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "test_util",
|
||||||
|
testonly = 1,
|
||||||
|
srcs = ["test_util.py"],
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":model_util",
|
||||||
|
"//mediapipe/model_maker/python/core/data:dataset",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "image_preprocessing",
|
||||||
|
srcs = ["image_preprocessing.py"],
|
||||||
|
srcs_version = "PY3",
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "image_preprocessing_test",
|
||||||
|
srcs = ["image_preprocessing_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [":image_preprocessing"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "model_util",
|
||||||
|
srcs = ["model_util.py"],
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":quantization",
|
||||||
|
"//mediapipe/model_maker/python/core/data:dataset",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "model_util_test",
|
||||||
|
srcs = ["model_util_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":model_util",
|
||||||
|
":quantization",
|
||||||
|
":test_util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "loss_functions",
|
||||||
|
srcs = ["loss_functions.py"],
|
||||||
|
srcs_version = "PY3",
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "loss_functions_test",
|
||||||
|
srcs = ["loss_functions_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [":loss_functions"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "quantization",
|
||||||
|
srcs = ["quantization.py"],
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = ["//mediapipe/model_maker/python/core/data:dataset"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "quantization_test",
|
||||||
|
srcs = ["quantization_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":quantization",
|
||||||
|
":test_util",
|
||||||
|
],
|
||||||
|
)
|
13
mediapipe/model_maker/python/core/utils/__init__.py
Normal file
13
mediapipe/model_maker/python/core/utils/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
228
mediapipe/model_maker/python/core/utils/image_preprocessing.py
Normal file
228
mediapipe/model_maker/python/core/utils/image_preprocessing.py
Normal file
|
@ -0,0 +1,228 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""ImageNet preprocessing."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# Dependency imports
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
IMAGE_SIZE = 224
|
||||||
|
CROP_PADDING = 32
|
||||||
|
|
||||||
|
|
||||||
|
class Preprocessor(object):
|
||||||
|
"""Preprocessor for image classification."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
input_shape,
|
||||||
|
num_classes,
|
||||||
|
mean_rgb,
|
||||||
|
stddev_rgb,
|
||||||
|
use_augmentation=False):
|
||||||
|
self.input_shape = input_shape
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.mean_rgb = mean_rgb
|
||||||
|
self.stddev_rgb = stddev_rgb
|
||||||
|
self.use_augmentation = use_augmentation
|
||||||
|
|
||||||
|
def __call__(self, image, label, is_training=True):
|
||||||
|
if self.use_augmentation:
|
||||||
|
return self._preprocess_with_augmentation(image, label, is_training)
|
||||||
|
return self._preprocess_without_augmentation(image, label)
|
||||||
|
|
||||||
|
def _preprocess_with_augmentation(self, image, label, is_training):
|
||||||
|
"""Image preprocessing method with data augmentation."""
|
||||||
|
image_size = self.input_shape[0]
|
||||||
|
if is_training:
|
||||||
|
image = preprocess_for_train(image, image_size)
|
||||||
|
else:
|
||||||
|
image = preprocess_for_eval(image, image_size)
|
||||||
|
|
||||||
|
image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
|
||||||
|
image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
|
||||||
|
|
||||||
|
label = tf.one_hot(label, depth=self.num_classes)
|
||||||
|
return image, label
|
||||||
|
|
||||||
|
# TODO: Changes to preprocess to support batch input.
|
||||||
|
def _preprocess_without_augmentation(self, image, label):
|
||||||
|
"""Image preprocessing method without data augmentation."""
|
||||||
|
image = tf.cast(image, tf.float32)
|
||||||
|
|
||||||
|
image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
|
||||||
|
image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
|
||||||
|
|
||||||
|
image = tf.compat.v1.image.resize(image, self.input_shape)
|
||||||
|
label = tf.one_hot(label, depth=self.num_classes)
|
||||||
|
return image, label
|
||||||
|
|
||||||
|
|
||||||
|
def _distorted_bounding_box_crop(image,
|
||||||
|
bbox,
|
||||||
|
min_object_covered=0.1,
|
||||||
|
aspect_ratio_range=(0.75, 1.33),
|
||||||
|
area_range=(0.05, 1.0),
|
||||||
|
max_attempts=100):
|
||||||
|
"""Generates cropped_image using one of the bboxes randomly distorted.
|
||||||
|
|
||||||
|
See `tf.image.sample_distorted_bounding_box` for more documentation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of
|
||||||
|
shape [height, width, channels].
|
||||||
|
bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` where
|
||||||
|
each coordinate is [0, 1) and the coordinates are arranged as `[ymin,
|
||||||
|
xmin, ymax, xmax]`. If num_boxes is 0 then use the whole image.
|
||||||
|
min_object_covered: An optional `float`. Defaults to `0.1`. The cropped area
|
||||||
|
of the image must contain at least this fraction of any bounding box
|
||||||
|
supplied.
|
||||||
|
aspect_ratio_range: An optional list of `float`s. The cropped area of the
|
||||||
|
image must have an aspect ratio = width / height within this range.
|
||||||
|
area_range: An optional list of `float`s. The cropped area of the image must
|
||||||
|
contain a fraction of the supplied image within in this range.
|
||||||
|
max_attempts: An optional `int`. Number of attempts at generating a cropped
|
||||||
|
region of the image of the specified constraints. After `max_attempts`
|
||||||
|
failures, return the entire image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A cropped image `Tensor`
|
||||||
|
"""
|
||||||
|
with tf.name_scope('distorted_bounding_box_crop'):
|
||||||
|
shape = tf.shape(image)
|
||||||
|
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
|
||||||
|
shape,
|
||||||
|
bounding_boxes=bbox,
|
||||||
|
min_object_covered=min_object_covered,
|
||||||
|
aspect_ratio_range=aspect_ratio_range,
|
||||||
|
area_range=area_range,
|
||||||
|
max_attempts=max_attempts,
|
||||||
|
use_image_if_no_bounding_boxes=True)
|
||||||
|
bbox_begin, bbox_size, _ = sample_distorted_bounding_box
|
||||||
|
|
||||||
|
# Crop the image to the specified bounding box.
|
||||||
|
offset_y, offset_x, _ = tf.unstack(bbox_begin)
|
||||||
|
target_height, target_width, _ = tf.unstack(bbox_size)
|
||||||
|
image = tf.image.crop_to_bounding_box(image, offset_y, offset_x,
|
||||||
|
target_height, target_width)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def _at_least_x_are_equal(a, b, x):
|
||||||
|
"""At least `x` of `a` and `b` `Tensors` are equal."""
|
||||||
|
match = tf.equal(a, b)
|
||||||
|
match = tf.cast(match, tf.int32)
|
||||||
|
return tf.greater_equal(tf.reduce_sum(match), x)
|
||||||
|
|
||||||
|
|
||||||
|
def _resize_image(image, image_size, method=None):
|
||||||
|
if method is not None:
|
||||||
|
tf.compat.v1.logging.info('Use customized resize method {}'.format(method))
|
||||||
|
return tf.compat.v1.image.resize([image], [image_size, image_size],
|
||||||
|
method)[0]
|
||||||
|
tf.compat.v1.logging.info('Use default resize_bicubic.')
|
||||||
|
return tf.compat.v1.image.resize_bicubic([image], [image_size, image_size])[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_and_random_crop(original_image, image_size, resize_method=None):
|
||||||
|
"""Makes a random crop of image_size."""
|
||||||
|
bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
|
||||||
|
image = _distorted_bounding_box_crop(
|
||||||
|
original_image,
|
||||||
|
bbox,
|
||||||
|
min_object_covered=0.1,
|
||||||
|
aspect_ratio_range=(3. / 4, 4. / 3.),
|
||||||
|
area_range=(0.08, 1.0),
|
||||||
|
max_attempts=10)
|
||||||
|
original_shape = tf.shape(original_image)
|
||||||
|
bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)
|
||||||
|
|
||||||
|
image = tf.cond(bad,
|
||||||
|
lambda: _decode_and_center_crop(original_image, image_size),
|
||||||
|
lambda: _resize_image(image, image_size, resize_method))
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_and_center_crop(image, image_size, resize_method=None):
|
||||||
|
"""Crops to center of image with padding then scales image_size."""
|
||||||
|
shape = tf.shape(image)
|
||||||
|
image_height = shape[0]
|
||||||
|
image_width = shape[1]
|
||||||
|
|
||||||
|
padded_center_crop_size = tf.cast(
|
||||||
|
((image_size / (image_size + CROP_PADDING)) *
|
||||||
|
tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32)
|
||||||
|
|
||||||
|
offset_height = ((image_height - padded_center_crop_size) + 1) // 2
|
||||||
|
offset_width = ((image_width - padded_center_crop_size) + 1) // 2
|
||||||
|
image = tf.image.crop_to_bounding_box(image, offset_height, offset_width,
|
||||||
|
padded_center_crop_size,
|
||||||
|
padded_center_crop_size)
|
||||||
|
image = _resize_image(image, image_size, resize_method)
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def _flip(image):
|
||||||
|
"""Random horizontal image flip."""
|
||||||
|
image = tf.image.random_flip_left_right(image)
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_for_train(
|
||||||
|
image: tf.Tensor,
|
||||||
|
image_size: int = IMAGE_SIZE,
|
||||||
|
resize_method: str = tf.image.ResizeMethod.BILINEAR) -> tf.Tensor:
|
||||||
|
"""Preprocesses the given image for evaluation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of
|
||||||
|
shape [height, width, channels].
|
||||||
|
image_size: image size.
|
||||||
|
resize_method: resize method. If none, use bicubic.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A preprocessed image `Tensor`.
|
||||||
|
"""
|
||||||
|
image = _decode_and_random_crop(image, image_size, resize_method)
|
||||||
|
image = _flip(image)
|
||||||
|
image = tf.reshape(image, [image_size, image_size, 3])
|
||||||
|
|
||||||
|
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_for_eval(
|
||||||
|
image: tf.Tensor,
|
||||||
|
image_size: int = IMAGE_SIZE,
|
||||||
|
resize_method: str = tf.image.ResizeMethod.BILINEAR) -> tf.Tensor:
|
||||||
|
"""Preprocesses the given image for evaluation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of
|
||||||
|
shape [height, width, channels].
|
||||||
|
image_size: image size.
|
||||||
|
resize_method: if None, use bicubic.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A preprocessed image `Tensor`.
|
||||||
|
"""
|
||||||
|
image = _decode_and_center_crop(image, image_size, resize_method)
|
||||||
|
image = tf.reshape(image, [image_size, image_size, 3])
|
||||||
|
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
|
||||||
|
return image
|
|
@ -0,0 +1,85 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# Dependency imports
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.utils import image_preprocessing
|
||||||
|
|
||||||
|
|
||||||
|
def _get_preprocessed_image(preprocessor, is_training=False):
|
||||||
|
image_placeholder = tf.compat.v1.placeholder(tf.uint8, [24, 24, 3])
|
||||||
|
label_placeholder = tf.compat.v1.placeholder(tf.int32, [1])
|
||||||
|
image_tensor, _ = preprocessor(image_placeholder, label_placeholder,
|
||||||
|
is_training)
|
||||||
|
|
||||||
|
with tf.compat.v1.Session() as sess:
|
||||||
|
input_image = np.arange(24 * 24 * 3, dtype=np.uint8).reshape([24, 24, 3])
|
||||||
|
image = sess.run(
|
||||||
|
image_tensor,
|
||||||
|
feed_dict={
|
||||||
|
image_placeholder: input_image,
|
||||||
|
label_placeholder: [0]
|
||||||
|
})
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
class PreprocessorTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test_preprocess_without_augmentation(self):
|
||||||
|
preprocessor = image_preprocessing.Preprocessor(input_shape=[2, 2],
|
||||||
|
num_classes=2,
|
||||||
|
mean_rgb=[0.0],
|
||||||
|
stddev_rgb=[255.0],
|
||||||
|
use_augmentation=False)
|
||||||
|
actual_image = np.array([[[0., 0.00392157, 0.00784314],
|
||||||
|
[0.14117648, 0.14509805, 0.14901961]],
|
||||||
|
[[0.37647063, 0.3803922, 0.38431376],
|
||||||
|
[0.5176471, 0.52156866, 0.5254902]]])
|
||||||
|
|
||||||
|
image = _get_preprocessed_image(preprocessor)
|
||||||
|
self.assertTrue(np.allclose(image, actual_image, atol=1e-05))
|
||||||
|
|
||||||
|
def test_preprocess_with_augmentation(self):
|
||||||
|
image_preprocessing.CROP_PADDING = 1
|
||||||
|
preprocessor = image_preprocessing.Preprocessor(input_shape=[2, 2],
|
||||||
|
num_classes=2,
|
||||||
|
mean_rgb=[0.0],
|
||||||
|
stddev_rgb=[255.0],
|
||||||
|
use_augmentation=True)
|
||||||
|
# Tests validation image.
|
||||||
|
actual_eval_image = np.array([[[0.17254902, 0.1764706, 0.18039216],
|
||||||
|
[0.26666668, 0.27058825, 0.27450982]],
|
||||||
|
[[0.42352945, 0.427451, 0.43137258],
|
||||||
|
[0.5176471, 0.52156866, 0.5254902]]])
|
||||||
|
|
||||||
|
image = _get_preprocessed_image(preprocessor, is_training=False)
|
||||||
|
self.assertTrue(np.allclose(image, actual_eval_image, atol=1e-05))
|
||||||
|
|
||||||
|
# Tests training image.
|
||||||
|
image1 = _get_preprocessed_image(preprocessor, is_training=True)
|
||||||
|
image2 = _get_preprocessed_image(preprocessor, is_training=True)
|
||||||
|
self.assertFalse(np.allclose(image1, image2, atol=1e-05))
|
||||||
|
self.assertEqual(image1.shape, (2, 2, 3))
|
||||||
|
self.assertEqual(image2.shape, (2, 2, 3))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.compat.v1.disable_eager_execution()
|
||||||
|
tf.test.main()
|
105
mediapipe/model_maker/python/core/utils/loss_functions.py
Normal file
105
mediapipe/model_maker/python/core/utils/loss_functions.py
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Loss function utility library."""
|
||||||
|
|
||||||
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
class FocalLoss(tf.keras.losses.Loss):
|
||||||
|
"""Implementation of focal loss (https://arxiv.org/pdf/1708.02002.pdf).
|
||||||
|
|
||||||
|
This class computes the focal loss between labels and prediction. Focal loss
|
||||||
|
is a weighted loss function that modulates the standard cross-entropy loss
|
||||||
|
based on how well the neural network performs on a specific example of a
|
||||||
|
class. The labels should be provided in a `one_hot` vector representation.
|
||||||
|
There should be `#classes` floating point values per prediction.
|
||||||
|
The loss is reduced across all samples using 'sum_over_batch_size' reduction
|
||||||
|
(see https://www.tensorflow.org/api_docs/python/tf/keras/losses/Reduction).
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
>>> y_true = [[0, 1, 0], [0, 0, 1]]
|
||||||
|
>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
|
||||||
|
>>> gamma = 2
|
||||||
|
>>> focal_loss = FocalLoss(gamma)
|
||||||
|
>>> focal_loss(y_true, y_pred).numpy()
|
||||||
|
0.9326
|
||||||
|
|
||||||
|
>>> # Calling with 'sample_weight'.
|
||||||
|
>>> focal_loss(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
|
||||||
|
0.6528
|
||||||
|
|
||||||
|
Usage with the `compile()` API:
|
||||||
|
```python
|
||||||
|
model.compile(optimizer='sgd', loss=FocalLoss(gamma))
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
|
||||||
|
"""Constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gamma: Focal loss gamma, as described in class docs.
|
||||||
|
class_weight: A weight to apply to the loss, one for each class. The
|
||||||
|
weight is applied for each input where the ground truth label matches.
|
||||||
|
"""
|
||||||
|
super(tf.keras.losses.Loss, self).__init__()
|
||||||
|
# Used for clipping min/max values of probability values in y_pred to avoid
|
||||||
|
# NaNs and Infs in computation.
|
||||||
|
self._epsilon = 1e-7
|
||||||
|
# This is a tunable "focusing parameter"; should be >= 0.
|
||||||
|
# When gamma = 0, the loss returned is the standard categorical
|
||||||
|
# cross-entropy loss.
|
||||||
|
self._gamma = gamma
|
||||||
|
self._class_weight = class_weight
|
||||||
|
# tf.keras.losses.Loss class implementation requires a Reduction specified
|
||||||
|
# in self.reduction. To use this reduction, we should use tensorflow's
|
||||||
|
# compute_weighted_loss function however it is only compatible with v1 of
|
||||||
|
# Tensorflow: https://www.tensorflow.org/api_docs/python/tf/compat/v1/losses/compute_weighted_loss?hl=en. pylint: disable=line-too-long
|
||||||
|
# So even though it is specified here, we don't use self.reduction in the
|
||||||
|
# loss function call.
|
||||||
|
self.reduction = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
|
||||||
|
|
||||||
|
def __call__(self,
|
||||||
|
y_true: tf.Tensor,
|
||||||
|
y_pred: tf.Tensor,
|
||||||
|
sample_weight: Optional[tf.Tensor] = None) -> tf.Tensor:
|
||||||
|
if self._class_weight:
|
||||||
|
class_weight = tf.convert_to_tensor(self._class_weight, dtype=tf.float32)
|
||||||
|
label = tf.argmax(y_true, axis=1)
|
||||||
|
loss_weight = tf.gather(class_weight, label)
|
||||||
|
else:
|
||||||
|
loss_weight = tf.ones(tf.shape(y_true)[0])
|
||||||
|
y_true = tf.cast(y_true, y_pred.dtype)
|
||||||
|
y_pred = tf.clip_by_value(y_pred, self._epsilon, 1 - self._epsilon)
|
||||||
|
batch_size = tf.cast(tf.shape(y_pred)[0], y_pred.dtype)
|
||||||
|
if sample_weight is None:
|
||||||
|
sample_weight = tf.constant(1.0)
|
||||||
|
weight_shape = sample_weight.shape
|
||||||
|
weight_rank = weight_shape.ndims
|
||||||
|
y_pred_rank = y_pred.shape.ndims
|
||||||
|
if y_pred_rank - weight_rank == 1:
|
||||||
|
sample_weight = tf.expand_dims(sample_weight, [-1])
|
||||||
|
elif weight_rank != 0:
|
||||||
|
raise ValueError(f'Unexpected sample_weights, should be either a scalar'
|
||||||
|
f'or a vector of batch_size:{batch_size.numpy()}')
|
||||||
|
ce = -tf.math.log(y_pred)
|
||||||
|
modulating_factor = tf.math.pow(1 - y_pred, self._gamma)
|
||||||
|
losses = y_true * modulating_factor * ce * sample_weight
|
||||||
|
losses = losses * loss_weight[:, tf.newaxis]
|
||||||
|
# By default, this function uses "sum_over_batch_size" reduction for the
|
||||||
|
# loss per batch.
|
||||||
|
return tf.reduce_sum(losses) / batch_size
|
103
mediapipe/model_maker/python/core/utils/loss_functions_test.py
Normal file
103
mediapipe/model_maker/python/core/utils/loss_functions_test.py
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.utils import loss_functions
|
||||||
|
|
||||||
|
|
||||||
|
class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
dict(testcase_name='no_sample_weight', sample_weight=None),
|
||||||
|
dict(
|
||||||
|
testcase_name='with_sample_weight',
|
||||||
|
sample_weight=tf.constant([0.2, 0.2, 0.3, 0.1, 0.2])))
|
||||||
|
def test_focal_loss_gamma_0_is_cross_entropy(
|
||||||
|
self, sample_weight: Optional[tf.Tensor]):
|
||||||
|
y_true = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1,
|
||||||
|
0]])
|
||||||
|
y_pred = tf.constant([[0.7, 0.1, 0.2], [0.6, 0.3, 0.1], [0.1, 0.5, 0.4],
|
||||||
|
[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
|
||||||
|
|
||||||
|
tf_cce = tf.keras.losses.CategoricalCrossentropy(
|
||||||
|
from_logits=False,
|
||||||
|
reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
|
||||||
|
focal_loss = loss_functions.FocalLoss(gamma=0)
|
||||||
|
self.assertAllClose(
|
||||||
|
tf_cce(y_true, y_pred, sample_weight=sample_weight),
|
||||||
|
focal_loss(y_true, y_pred, sample_weight=sample_weight), 1e-4)
|
||||||
|
|
||||||
|
def test_focal_loss_with_sample_weight(self):
|
||||||
|
y_true = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1,
|
||||||
|
0]])
|
||||||
|
y_pred = tf.constant([[0.7, 0.1, 0.2], [0.6, 0.3, 0.1], [0.1, 0.5, 0.4],
|
||||||
|
[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
|
||||||
|
|
||||||
|
focal_loss = loss_functions.FocalLoss(gamma=0)
|
||||||
|
|
||||||
|
sample_weight = tf.constant([0.2, 0.2, 0.3, 0.1, 0.2])
|
||||||
|
|
||||||
|
self.assertGreater(
|
||||||
|
focal_loss(y_true=y_true, y_pred=y_pred),
|
||||||
|
focal_loss(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight))
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
dict(testcase_name='gt_0.1', y_pred=tf.constant([0.1, 0.9])),
|
||||||
|
dict(testcase_name='gt_0.3', y_pred=tf.constant([0.3, 0.7])),
|
||||||
|
dict(testcase_name='gt_0.5', y_pred=tf.constant([0.5, 0.5])),
|
||||||
|
dict(testcase_name='gt_0.7', y_pred=tf.constant([0.7, 0.3])),
|
||||||
|
dict(testcase_name='gt_0.9', y_pred=tf.constant([0.9, 0.1])),
|
||||||
|
)
|
||||||
|
def test_focal_loss_decreases_with_increasing_gamma(self, y_pred: tf.Tensor):
|
||||||
|
y_true = tf.constant([[1, 0]])
|
||||||
|
|
||||||
|
focal_loss_gamma_0 = loss_functions.FocalLoss(gamma=0)
|
||||||
|
loss_gamma_0 = focal_loss_gamma_0(y_true, y_pred)
|
||||||
|
focal_loss_gamma_0p5 = loss_functions.FocalLoss(gamma=0.5)
|
||||||
|
loss_gamma_0p5 = focal_loss_gamma_0p5(y_true, y_pred)
|
||||||
|
focal_loss_gamma_1 = loss_functions.FocalLoss(gamma=1)
|
||||||
|
loss_gamma_1 = focal_loss_gamma_1(y_true, y_pred)
|
||||||
|
focal_loss_gamma_2 = loss_functions.FocalLoss(gamma=2)
|
||||||
|
loss_gamma_2 = focal_loss_gamma_2(y_true, y_pred)
|
||||||
|
focal_loss_gamma_5 = loss_functions.FocalLoss(gamma=5)
|
||||||
|
loss_gamma_5 = focal_loss_gamma_5(y_true, y_pred)
|
||||||
|
|
||||||
|
self.assertGreater(loss_gamma_0, loss_gamma_0p5)
|
||||||
|
self.assertGreater(loss_gamma_0p5, loss_gamma_1)
|
||||||
|
self.assertGreater(loss_gamma_1, loss_gamma_2)
|
||||||
|
self.assertGreater(loss_gamma_2, loss_gamma_5)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
dict(testcase_name='index_0', true_class=0),
|
||||||
|
dict(testcase_name='index_1', true_class=1),
|
||||||
|
dict(testcase_name='index_2', true_class=2),
|
||||||
|
)
|
||||||
|
def test_focal_loss_class_weight_is_applied(self, true_class: int):
|
||||||
|
class_weight = [1.0, 3.0, 10.0]
|
||||||
|
y_pred = tf.constant([[1.0, 1.0, 1.0]]) / 3.0
|
||||||
|
y_true = tf.one_hot(true_class, depth=3)[tf.newaxis, :]
|
||||||
|
expected_loss = -math.log(1.0 / 3.0) * class_weight[true_class]
|
||||||
|
|
||||||
|
loss_fn = loss_functions.FocalLoss(gamma=0, class_weight=class_weight)
|
||||||
|
loss = loss_fn(y_true, y_pred)
|
||||||
|
self.assertNear(loss, expected_loss, 1e-4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
241
mediapipe/model_maker/python/core/utils/model_util.py
Normal file
241
mediapipe/model_maker/python/core/utils/model_util.py
Normal file
|
@ -0,0 +1,241 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Utilities for keras models."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union
|
||||||
|
|
||||||
|
# Dependency imports
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import dataset
|
||||||
|
from mediapipe.model_maker.python.core.utils import quantization
|
||||||
|
|
||||||
|
DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
|
||||||
|
ESTIMITED_STEPS_PER_EPOCH = 1000
|
||||||
|
|
||||||
|
|
||||||
|
def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
|
||||||
|
batch_size: Optional[int] = None,
|
||||||
|
train_data: Optional[dataset.Dataset] = None) -> int:
|
||||||
|
"""Gets the estimated training steps per epoch.
|
||||||
|
|
||||||
|
1. If `steps_per_epoch` is set, returns `steps_per_epoch` directly.
|
||||||
|
2. Else if we can get the length of training data successfully, returns
|
||||||
|
`train_data_length // batch_size`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
steps_per_epoch: int, training steps per epoch.
|
||||||
|
batch_size: int, batch size.
|
||||||
|
train_data: training data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated training steps per epoch.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if both steps_per_epoch and train_data are not set.
|
||||||
|
"""
|
||||||
|
if steps_per_epoch is not None:
|
||||||
|
# steps_per_epoch is set by users manually.
|
||||||
|
return steps_per_epoch
|
||||||
|
else:
|
||||||
|
if train_data is None:
|
||||||
|
raise ValueError('Input train_data cannot be None.')
|
||||||
|
# Gets the steps by the length of the training data.
|
||||||
|
return len(train_data) // batch_size
|
||||||
|
|
||||||
|
|
||||||
|
def export_tflite(
|
||||||
|
model: tf.keras.Model,
|
||||||
|
tflite_filepath: str,
|
||||||
|
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
||||||
|
supported_ops: Tuple[tf.lite.OpsSet,
|
||||||
|
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,)):
|
||||||
|
"""Converts the model to tflite format and saves it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: model to be converted to tflite.
|
||||||
|
tflite_filepath: File path to save tflite model.
|
||||||
|
quantization_config: Configuration for post-training quantization.
|
||||||
|
supported_ops: A list of supported ops in the converted TFLite file.
|
||||||
|
"""
|
||||||
|
if tflite_filepath is None:
|
||||||
|
raise ValueError(
|
||||||
|
"TFLite filepath couldn't be None when exporting to tflite.")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
save_path = os.path.join(temp_dir, 'saved_model')
|
||||||
|
model.save(save_path, include_optimizer=False, save_format='tf')
|
||||||
|
converter = tf.lite.TFLiteConverter.from_saved_model(save_path)
|
||||||
|
|
||||||
|
if quantization_config:
|
||||||
|
converter = quantization_config.set_converter_with_quantization(converter)
|
||||||
|
|
||||||
|
converter.target_spec.supported_ops = supported_ops
|
||||||
|
tflite_model = converter.convert()
|
||||||
|
|
||||||
|
with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
|
||||||
|
f.write(tflite_model)
|
||||||
|
|
||||||
|
|
||||||
|
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
|
||||||
|
"""Applies a warmup schedule on a given learning rate decay schedule."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
initial_learning_rate: float,
|
||||||
|
decay_schedule_fn: Callable[[Any], Any],
|
||||||
|
warmup_steps: int,
|
||||||
|
name: Optional[str] = None):
|
||||||
|
"""Initializes a new instance of the `WarmUp` class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initial_learning_rate: learning rate after the warmup.
|
||||||
|
decay_schedule_fn: A function maps step to learning rate. Will be applied
|
||||||
|
for values of step larger than 'warmup_steps'.
|
||||||
|
warmup_steps: Number of steps to do warmup for.
|
||||||
|
name: TF namescope under which to perform the learning rate calculation.
|
||||||
|
"""
|
||||||
|
super(WarmUp, self).__init__()
|
||||||
|
self.initial_learning_rate = initial_learning_rate
|
||||||
|
self.warmup_steps = warmup_steps
|
||||||
|
self.decay_schedule_fn = decay_schedule_fn
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def __call__(self, step: Union[int, tf.Tensor]) -> tf.Tensor:
|
||||||
|
with tf.name_scope(self.name or 'WarmUp') as name:
|
||||||
|
# Implements linear warmup. i.e., if global_step < warmup_steps, the
|
||||||
|
# learning rate will be `global_step/num_warmup_steps * init_lr`.
|
||||||
|
global_step_float = tf.cast(step, tf.float32)
|
||||||
|
warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
|
||||||
|
warmup_percent_done = global_step_float / warmup_steps_float
|
||||||
|
warmup_learning_rate = self.initial_learning_rate * warmup_percent_done
|
||||||
|
return tf.cond(
|
||||||
|
global_step_float < warmup_steps_float,
|
||||||
|
lambda: warmup_learning_rate,
|
||||||
|
lambda: self.decay_schedule_fn(step),
|
||||||
|
name=name)
|
||||||
|
|
||||||
|
def get_config(self) -> Dict[Text, Any]:
|
||||||
|
return {
|
||||||
|
'initial_learning_rate': self.initial_learning_rate,
|
||||||
|
'decay_schedule_fn': self.decay_schedule_fn,
|
||||||
|
'warmup_steps': self.warmup_steps,
|
||||||
|
'name': self.name
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class LiteRunner(object):
|
||||||
|
"""A runner to do inference with the TFLite model."""
|
||||||
|
|
||||||
|
def __init__(self, tflite_filepath: str):
|
||||||
|
"""Initializes Lite runner with tflite model file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tflite_filepath: File path to the TFLite model.
|
||||||
|
"""
|
||||||
|
with tf.io.gfile.GFile(tflite_filepath, 'rb') as f:
|
||||||
|
tflite_model = f.read()
|
||||||
|
self.interpreter = tf.lite.Interpreter(model_content=tflite_model)
|
||||||
|
self.interpreter.allocate_tensors()
|
||||||
|
self.input_details = self.interpreter.get_input_details()
|
||||||
|
self.output_details = self.interpreter.get_output_details()
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self, input_tensors: Union[List[tf.Tensor], Dict[str, tf.Tensor]]
|
||||||
|
) -> Union[List[tf.Tensor], tf.Tensor]:
|
||||||
|
"""Runs inference with the TFLite model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_tensors: List / Dict of the input tensors of the TFLite model. The
|
||||||
|
order should be the same as the keras model if it's a list. It also
|
||||||
|
accepts tensor directly if the model has only 1 input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of the output tensors for multi-output models, otherwise just
|
||||||
|
the output tensor. The order should be the same as the keras model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not isinstance(input_tensors, list) and not isinstance(
|
||||||
|
input_tensors, dict):
|
||||||
|
input_tensors = [input_tensors]
|
||||||
|
|
||||||
|
interpreter = self.interpreter
|
||||||
|
|
||||||
|
# Reshape inputs
|
||||||
|
for i, input_detail in enumerate(self.input_details):
|
||||||
|
input_tensor = _get_input_tensor(
|
||||||
|
input_tensors=input_tensors,
|
||||||
|
input_details=self.input_details,
|
||||||
|
index=i)
|
||||||
|
interpreter.resize_tensor_input(
|
||||||
|
input_index=input_detail['index'], tensor_size=input_tensor.shape)
|
||||||
|
interpreter.allocate_tensors()
|
||||||
|
|
||||||
|
# Feed input to the interpreter
|
||||||
|
for i, input_detail in enumerate(self.input_details):
|
||||||
|
input_tensor = _get_input_tensor(
|
||||||
|
input_tensors=input_tensors,
|
||||||
|
input_details=self.input_details,
|
||||||
|
index=i)
|
||||||
|
if input_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT):
|
||||||
|
# Quantize the input
|
||||||
|
scale, zero_point = input_detail['quantization']
|
||||||
|
input_tensor = input_tensor / scale + zero_point
|
||||||
|
input_tensor = np.array(input_tensor, dtype=input_detail['dtype'])
|
||||||
|
interpreter.set_tensor(input_detail['index'], input_tensor)
|
||||||
|
|
||||||
|
interpreter.invoke()
|
||||||
|
|
||||||
|
output_tensors = []
|
||||||
|
for output_detail in self.output_details:
|
||||||
|
output_tensor = interpreter.get_tensor(output_detail['index'])
|
||||||
|
if output_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT):
|
||||||
|
# Dequantize the output
|
||||||
|
scale, zero_point = output_detail['quantization']
|
||||||
|
output_tensor = output_tensor.astype(np.float32)
|
||||||
|
output_tensor = (output_tensor - zero_point) * scale
|
||||||
|
output_tensors.append(output_tensor)
|
||||||
|
|
||||||
|
if len(output_tensors) == 1:
|
||||||
|
return output_tensors[0]
|
||||||
|
return output_tensors
|
||||||
|
|
||||||
|
|
||||||
|
def get_lite_runner(tflite_filepath: str) -> 'LiteRunner':
|
||||||
|
"""Returns a `LiteRunner` from file path to TFLite model."""
|
||||||
|
lite_runner = LiteRunner(tflite_filepath)
|
||||||
|
return lite_runner
|
||||||
|
|
||||||
|
|
||||||
|
def _get_input_tensor(input_tensors: Union[List[tf.Tensor], Dict[str,
|
||||||
|
tf.Tensor]],
|
||||||
|
input_details: Dict[str, Any], index: int) -> tf.Tensor:
|
||||||
|
"""Returns input tensor in `input_tensors` that maps `input_detail[i]`."""
|
||||||
|
if isinstance(input_tensors, dict):
|
||||||
|
# Gets the mapped input tensor.
|
||||||
|
input_detail = input_details
|
||||||
|
for input_tensor_name, input_tensor in input_tensors.items():
|
||||||
|
if input_tensor_name in input_detail['name']:
|
||||||
|
return input_tensor
|
||||||
|
raise ValueError('Input tensors don\'t contains a tensor that mapped the '
|
||||||
|
'input detail %s' % str(input_detail))
|
||||||
|
else:
|
||||||
|
return input_tensors[index]
|
137
mediapipe/model_maker/python/core/utils/model_util_test.py
Normal file
137
mediapipe/model_maker/python/core/utils/model_util_test.py
Normal file
|
@ -0,0 +1,137 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.utils import model_util
|
||||||
|
from mediapipe.model_maker.python.core.utils import quantization
|
||||||
|
from mediapipe.model_maker.python.core.utils import test_util
|
||||||
|
|
||||||
|
|
||||||
|
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
dict(
|
||||||
|
testcase_name='input_only_steps_per_epoch',
|
||||||
|
steps_per_epoch=1000,
|
||||||
|
batch_size=None,
|
||||||
|
train_data=None,
|
||||||
|
expected_steps_per_epoch=1000),
|
||||||
|
dict(
|
||||||
|
testcase_name='input_steps_per_epoch_and_batch_size',
|
||||||
|
steps_per_epoch=1000,
|
||||||
|
batch_size=32,
|
||||||
|
train_data=None,
|
||||||
|
expected_steps_per_epoch=1000),
|
||||||
|
dict(
|
||||||
|
testcase_name='input_steps_per_epoch_batch_size_and_train_data',
|
||||||
|
steps_per_epoch=1000,
|
||||||
|
batch_size=32,
|
||||||
|
train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
|
||||||
|
[1, 0]]),
|
||||||
|
expected_steps_per_epoch=1000),
|
||||||
|
dict(
|
||||||
|
testcase_name='input_batch_size_and_train_data',
|
||||||
|
steps_per_epoch=None,
|
||||||
|
batch_size=2,
|
||||||
|
train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
|
||||||
|
[1, 0]]),
|
||||||
|
expected_steps_per_epoch=2))
|
||||||
|
def test_get_steps_per_epoch(self, steps_per_epoch, batch_size, train_data,
|
||||||
|
expected_steps_per_epoch):
|
||||||
|
estimated_steps_per_epoch = model_util.get_steps_per_epoch(
|
||||||
|
steps_per_epoch=steps_per_epoch,
|
||||||
|
batch_size=batch_size,
|
||||||
|
train_data=train_data)
|
||||||
|
self.assertEqual(estimated_steps_per_epoch, expected_steps_per_epoch)
|
||||||
|
|
||||||
|
def test_get_steps_per_epoch_raise_value_error(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
model_util.get_steps_per_epoch(
|
||||||
|
steps_per_epoch=None, batch_size=16, train_data=None)
|
||||||
|
|
||||||
|
def test_warmup(self):
|
||||||
|
init_lr = 0.1
|
||||||
|
warmup_steps = 1000
|
||||||
|
num_decay_steps = 100
|
||||||
|
learning_rate_fn = tf.keras.experimental.CosineDecay(
|
||||||
|
initial_learning_rate=init_lr, decay_steps=num_decay_steps)
|
||||||
|
warmup_object = model_util.WarmUp(
|
||||||
|
initial_learning_rate=init_lr,
|
||||||
|
decay_schedule_fn=learning_rate_fn,
|
||||||
|
warmup_steps=1000,
|
||||||
|
name='test')
|
||||||
|
self.assertEqual(
|
||||||
|
warmup_object.get_config(), {
|
||||||
|
'initial_learning_rate': init_lr,
|
||||||
|
'decay_schedule_fn': learning_rate_fn,
|
||||||
|
'warmup_steps': warmup_steps,
|
||||||
|
'name': 'test'
|
||||||
|
})
|
||||||
|
|
||||||
|
def test_export_tflite(self):
|
||||||
|
input_dim = 4
|
||||||
|
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||||
|
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
|
||||||
|
model_util.export_tflite(model, tflite_file)
|
||||||
|
self._test_tflite(model, tflite_file, input_dim)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
dict(
|
||||||
|
testcase_name='dynamic_quantize',
|
||||||
|
config=quantization.QuantizationConfig.for_dynamic(),
|
||||||
|
model_size=1288),
|
||||||
|
dict(
|
||||||
|
testcase_name='int8_quantize',
|
||||||
|
config=quantization.QuantizationConfig.for_int8(
|
||||||
|
representative_data=test_util.create_dataset(
|
||||||
|
data_size=10, input_shape=[16], num_classes=3)),
|
||||||
|
model_size=1832),
|
||||||
|
dict(
|
||||||
|
testcase_name='float16_quantize',
|
||||||
|
config=quantization.QuantizationConfig.for_float16(),
|
||||||
|
model_size=1468))
|
||||||
|
def test_export_tflite_quantized(self, config, model_size):
|
||||||
|
input_dim = 16
|
||||||
|
num_classes = 2
|
||||||
|
max_input_value = 5
|
||||||
|
model = test_util.build_model([input_dim], num_classes)
|
||||||
|
tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
|
||||||
|
|
||||||
|
model_util.export_tflite(model, tflite_file, config)
|
||||||
|
self._test_tflite(
|
||||||
|
model, tflite_file, input_dim, max_input_value, atol=1e-00)
|
||||||
|
self.assertNear(os.path.getsize(tflite_file), model_size, 300)
|
||||||
|
|
||||||
|
def _test_tflite(self,
|
||||||
|
keras_model: tf.keras.Model,
|
||||||
|
tflite_model_file: str,
|
||||||
|
input_dim: int,
|
||||||
|
max_input_value: int = 1000,
|
||||||
|
atol: float = 1e-04):
|
||||||
|
np.random.seed(0)
|
||||||
|
random_input = np.random.uniform(
|
||||||
|
low=0, high=max_input_value, size=(1, input_dim)).astype(np.float32)
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
test_util.is_same_output(
|
||||||
|
tflite_model_file, keras_model, random_input, atol=atol))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
213
mediapipe/model_maker/python/core/utils/quantization.py
Normal file
213
mediapipe/model_maker/python/core/utils/quantization.py
Normal file
|
@ -0,0 +1,213 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Libraries for post-training quantization."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from typing import Any, Callable, List, Optional, Union
|
||||||
|
|
||||||
|
# Dependency imports
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import dataset as ds
|
||||||
|
|
||||||
|
DEFAULT_QUANTIZATION_STEPS = 500
|
||||||
|
|
||||||
|
|
||||||
|
def _get_representative_dataset_generator(dataset: tf.data.Dataset,
|
||||||
|
num_steps: int) -> Callable[[], Any]:
|
||||||
|
"""Gets a representative dataset generator for post-training quantization.
|
||||||
|
|
||||||
|
The generator is to provide a small dataset to calibrate or estimate the
|
||||||
|
range, i.e, (min, max) of all floating-point arrays in the model for
|
||||||
|
quantization. Usually, this is a small subset of a few hundred samples
|
||||||
|
randomly chosen, in no particular order, from the training or evaluation
|
||||||
|
dataset. See tf.lite.RepresentativeDataset for more details.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: Input dataset for extracting representative sub dataset.
|
||||||
|
num_steps: The number of quantization steps which also reflects the size of
|
||||||
|
the representative dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A representative dataset generator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def representative_dataset_gen():
|
||||||
|
"""Generates representative dataset for quantization."""
|
||||||
|
for data, _ in dataset.take(num_steps):
|
||||||
|
yield [data]
|
||||||
|
|
||||||
|
return representative_dataset_gen
|
||||||
|
|
||||||
|
|
||||||
|
class QuantizationConfig(object):
|
||||||
|
"""Configuration for post-training quantization.
|
||||||
|
|
||||||
|
Refer to
|
||||||
|
https://www.tensorflow.org/lite/performance/post_training_quantization
|
||||||
|
for different post-training quantization options.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
optimizations: Optional[Union[tf.lite.Optimize,
|
||||||
|
List[tf.lite.Optimize]]] = None,
|
||||||
|
representative_data: Optional[ds.Dataset] = None,
|
||||||
|
quantization_steps: Optional[int] = None,
|
||||||
|
inference_input_type: Optional[tf.dtypes.DType] = None,
|
||||||
|
inference_output_type: Optional[tf.dtypes.DType] = None,
|
||||||
|
supported_ops: Optional[Union[tf.lite.OpsSet,
|
||||||
|
List[tf.lite.OpsSet]]] = None,
|
||||||
|
supported_types: Optional[Union[tf.dtypes.DType,
|
||||||
|
List[tf.dtypes.DType]]] = None,
|
||||||
|
experimental_new_quantizer: bool = False,
|
||||||
|
):
|
||||||
|
"""Constructs QuantizationConfig.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizations: A list of optimizations to apply when converting the model.
|
||||||
|
If not set, use `[Optimize.DEFAULT]` by default.
|
||||||
|
representative_data: A representative ds.Dataset for post-training
|
||||||
|
quantization.
|
||||||
|
quantization_steps: Number of post-training quantization calibration steps
|
||||||
|
to run (default to DEFAULT_QUANTIZATION_STEPS).
|
||||||
|
inference_input_type: Target data type of real-number input arrays. Allows
|
||||||
|
for a different type for input arrays. Defaults to None. If set, must be
|
||||||
|
be `{tf.float32, tf.uint8, tf.int8}`.
|
||||||
|
inference_output_type: Target data type of real-number output arrays.
|
||||||
|
Allows for a different type for output arrays. Defaults to None. If set,
|
||||||
|
must be `{tf.float32, tf.uint8, tf.int8}`.
|
||||||
|
supported_ops: Set of OpsSet options supported by the device. Used to Set
|
||||||
|
converter.target_spec.supported_ops.
|
||||||
|
supported_types: List of types for constant values on the target device.
|
||||||
|
Supported values are types exported by lite.constants. Frequently, an
|
||||||
|
optimization choice is driven by the most compact (i.e. smallest) type
|
||||||
|
in this list (default [constants.FLOAT]).
|
||||||
|
experimental_new_quantizer: Whether to enable experimental new quantizer.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if inference_input_type or inference_output_type are set but
|
||||||
|
not in {tf.float32, tf.uint8, tf.int8}.
|
||||||
|
"""
|
||||||
|
if inference_input_type is not None and inference_input_type not in {
|
||||||
|
tf.float32, tf.uint8, tf.int8
|
||||||
|
}:
|
||||||
|
raise ValueError('Unsupported inference_input_type %s' %
|
||||||
|
inference_input_type)
|
||||||
|
if inference_output_type is not None and inference_output_type not in {
|
||||||
|
tf.float32, tf.uint8, tf.int8
|
||||||
|
}:
|
||||||
|
raise ValueError('Unsupported inference_output_type %s' %
|
||||||
|
inference_output_type)
|
||||||
|
|
||||||
|
if optimizations is None:
|
||||||
|
optimizations = [tf.lite.Optimize.DEFAULT]
|
||||||
|
if not isinstance(optimizations, list):
|
||||||
|
optimizations = [optimizations]
|
||||||
|
self.optimizations = optimizations
|
||||||
|
|
||||||
|
self.representative_data = representative_data
|
||||||
|
if self.representative_data is not None and quantization_steps is None:
|
||||||
|
quantization_steps = DEFAULT_QUANTIZATION_STEPS
|
||||||
|
self.quantization_steps = quantization_steps
|
||||||
|
|
||||||
|
self.inference_input_type = inference_input_type
|
||||||
|
self.inference_output_type = inference_output_type
|
||||||
|
|
||||||
|
if supported_ops is not None and not isinstance(supported_ops, list):
|
||||||
|
supported_ops = [supported_ops]
|
||||||
|
self.supported_ops = supported_ops
|
||||||
|
|
||||||
|
if supported_types is not None and not isinstance(supported_types, list):
|
||||||
|
supported_types = [supported_types]
|
||||||
|
self.supported_types = supported_types
|
||||||
|
|
||||||
|
self.experimental_new_quantizer = experimental_new_quantizer
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def for_dynamic(cls) -> 'QuantizationConfig':
|
||||||
|
"""Creates configuration for dynamic range quantization."""
|
||||||
|
return QuantizationConfig()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def for_int8(
|
||||||
|
cls,
|
||||||
|
representative_data: ds.Dataset,
|
||||||
|
quantization_steps: int = DEFAULT_QUANTIZATION_STEPS,
|
||||||
|
inference_input_type: tf.dtypes.DType = tf.uint8,
|
||||||
|
inference_output_type: tf.dtypes.DType = tf.uint8,
|
||||||
|
supported_ops: tf.lite.OpsSet = tf.lite.OpsSet.TFLITE_BUILTINS_INT8
|
||||||
|
) -> 'QuantizationConfig':
|
||||||
|
"""Creates configuration for full integer quantization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
representative_data: Representative data used for post-training
|
||||||
|
quantization.
|
||||||
|
quantization_steps: Number of post-training quantization calibration steps
|
||||||
|
to run.
|
||||||
|
inference_input_type: Target data type of real-number input arrays.
|
||||||
|
inference_output_type: Target data type of real-number output arrays.
|
||||||
|
supported_ops: Set of `tf.lite.OpsSet` options, where each option
|
||||||
|
represents a set of operators supported by the target device.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
QuantizationConfig.
|
||||||
|
"""
|
||||||
|
return QuantizationConfig(
|
||||||
|
representative_data=representative_data,
|
||||||
|
quantization_steps=quantization_steps,
|
||||||
|
inference_input_type=inference_input_type,
|
||||||
|
inference_output_type=inference_output_type,
|
||||||
|
supported_ops=supported_ops)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def for_float16(cls) -> 'QuantizationConfig':
|
||||||
|
"""Creates configuration for float16 quantization."""
|
||||||
|
return QuantizationConfig(supported_types=[tf.float16])
|
||||||
|
|
||||||
|
def set_converter_with_quantization(self, converter: tf.lite.TFLiteConverter,
|
||||||
|
**kwargs: Any) -> tf.lite.TFLiteConverter:
|
||||||
|
"""Sets input TFLite converter with quantization configurations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
converter: input tf.lite.TFLiteConverter.
|
||||||
|
**kwargs: arguments used by ds.Dataset.gen_tf_dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tf.lite.TFLiteConverter with quantization configurations.
|
||||||
|
"""
|
||||||
|
converter.optimizations = self.optimizations
|
||||||
|
|
||||||
|
if self.representative_data is not None:
|
||||||
|
tf_ds = self.representative_data.gen_tf_dataset(
|
||||||
|
batch_size=1, is_training=False, **kwargs)
|
||||||
|
converter.representative_dataset = tf.lite.RepresentativeDataset(
|
||||||
|
_get_representative_dataset_generator(tf_ds, self.quantization_steps))
|
||||||
|
|
||||||
|
if self.inference_input_type:
|
||||||
|
converter.inference_input_type = self.inference_input_type
|
||||||
|
if self.inference_output_type:
|
||||||
|
converter.inference_output_type = self.inference_output_type
|
||||||
|
if self.supported_ops:
|
||||||
|
converter.target_spec.supported_ops = self.supported_ops
|
||||||
|
if self.supported_types:
|
||||||
|
converter.target_spec.supported_types = self.supported_types
|
||||||
|
|
||||||
|
if self.experimental_new_quantizer is not None:
|
||||||
|
converter.experimental_new_quantizer = self.experimental_new_quantizer
|
||||||
|
return converter
|
108
mediapipe/model_maker/python/core/utils/quantization_test.py
Normal file
108
mediapipe/model_maker/python/core/utils/quantization_test.py
Normal file
|
@ -0,0 +1,108 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.utils import quantization
|
||||||
|
from mediapipe.model_maker.python.core.utils import test_util
|
||||||
|
|
||||||
|
|
||||||
|
class QuantizationTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def test_create_dynamic_quantization_config(self):
|
||||||
|
config = quantization.QuantizationConfig.for_dynamic()
|
||||||
|
self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
|
||||||
|
self.assertIsNone(config.representative_data)
|
||||||
|
self.assertIsNone(config.inference_input_type)
|
||||||
|
self.assertIsNone(config.inference_output_type)
|
||||||
|
self.assertIsNone(config.supported_ops)
|
||||||
|
self.assertIsNone(config.supported_types)
|
||||||
|
self.assertFalse(config.experimental_new_quantizer)
|
||||||
|
|
||||||
|
def test_create_int8_quantization_config(self):
|
||||||
|
representative_data = test_util.create_dataset(
|
||||||
|
data_size=10, input_shape=[4], num_classes=3)
|
||||||
|
config = quantization.QuantizationConfig.for_int8(
|
||||||
|
representative_data=representative_data)
|
||||||
|
self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
|
||||||
|
self.assertEqual(config.inference_input_type, tf.uint8)
|
||||||
|
self.assertEqual(config.inference_output_type, tf.uint8)
|
||||||
|
self.assertEqual(config.supported_ops,
|
||||||
|
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8])
|
||||||
|
self.assertFalse(config.experimental_new_quantizer)
|
||||||
|
|
||||||
|
def test_set_converter_with_quantization_from_int8_config(self):
|
||||||
|
representative_data = test_util.create_dataset(
|
||||||
|
data_size=10, input_shape=[4], num_classes=3)
|
||||||
|
config = quantization.QuantizationConfig.for_int8(
|
||||||
|
representative_data=representative_data)
|
||||||
|
model = test_util.build_model(input_shape=[4], num_classes=3)
|
||||||
|
saved_model_dir = self.get_temp_dir()
|
||||||
|
model.save(saved_model_dir)
|
||||||
|
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
|
||||||
|
converter = config.set_converter_with_quantization(converter=converter)
|
||||||
|
self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
|
||||||
|
self.assertEqual(config.inference_input_type, tf.uint8)
|
||||||
|
self.assertEqual(config.inference_output_type, tf.uint8)
|
||||||
|
self.assertEqual(config.supported_ops,
|
||||||
|
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8])
|
||||||
|
tflite_model = converter.convert()
|
||||||
|
interpreter = tf.lite.Interpreter(model_content=tflite_model)
|
||||||
|
self.assertEqual(interpreter.get_input_details()[0]['dtype'], tf.uint8)
|
||||||
|
self.assertEqual(interpreter.get_output_details()[0]['dtype'], tf.uint8)
|
||||||
|
|
||||||
|
def test_create_float16_quantization_config(self):
|
||||||
|
config = quantization.QuantizationConfig.for_float16()
|
||||||
|
self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
|
||||||
|
self.assertIsNone(config.representative_data)
|
||||||
|
self.assertIsNone(config.inference_input_type)
|
||||||
|
self.assertIsNone(config.inference_output_type)
|
||||||
|
self.assertIsNone(config.supported_ops)
|
||||||
|
self.assertEqual(config.supported_types, [tf.float16])
|
||||||
|
self.assertFalse(config.experimental_new_quantizer)
|
||||||
|
|
||||||
|
def test_set_converter_with_quantization_from_float16_config(self):
|
||||||
|
config = quantization.QuantizationConfig.for_float16()
|
||||||
|
model = test_util.build_model(input_shape=[4], num_classes=3)
|
||||||
|
saved_model_dir = self.get_temp_dir()
|
||||||
|
model.save(saved_model_dir)
|
||||||
|
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
|
||||||
|
converter = config.set_converter_with_quantization(converter=converter)
|
||||||
|
self.assertEqual(config.supported_types, [tf.float16])
|
||||||
|
tflite_model = converter.convert()
|
||||||
|
interpreter = tf.lite.Interpreter(model_content=tflite_model)
|
||||||
|
# The input and output are expected to be set to float32 by default.
|
||||||
|
self.assertEqual(interpreter.get_input_details()[0]['dtype'], tf.float32)
|
||||||
|
self.assertEqual(interpreter.get_output_details()[0]['dtype'], tf.float32)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
dict(
|
||||||
|
testcase_name='invalid_inference_input_type',
|
||||||
|
inference_input_type=tf.uint8,
|
||||||
|
inference_output_type=tf.int64),
|
||||||
|
dict(
|
||||||
|
testcase_name='invalid_inference_output_type',
|
||||||
|
inference_input_type=tf.int64,
|
||||||
|
inference_output_type=tf.float32))
|
||||||
|
def test_create_quantization_config_failure(self, inference_input_type,
|
||||||
|
inference_output_type):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_ = quantization.QuantizationConfig(
|
||||||
|
inference_input_type=inference_input_type,
|
||||||
|
inference_output_type=inference_output_type)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
76
mediapipe/model_maker/python/core/utils/test_util.py
Normal file
76
mediapipe/model_maker/python/core/utils/test_util.py
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Test utilities for model maker."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
# Dependency imports
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import dataset as ds
|
||||||
|
from mediapipe.model_maker.python.core.utils import model_util
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset(data_size: int,
|
||||||
|
input_shape: List[int],
|
||||||
|
num_classes: int,
|
||||||
|
max_input_value: int = 1000) -> ds.Dataset:
|
||||||
|
"""Creates and returns a simple `Dataset` object for test."""
|
||||||
|
features = tf.random.uniform(
|
||||||
|
shape=[data_size] + input_shape,
|
||||||
|
minval=0,
|
||||||
|
maxval=max_input_value,
|
||||||
|
dtype=tf.float32)
|
||||||
|
|
||||||
|
labels = tf.random.uniform(
|
||||||
|
shape=[data_size], minval=0, maxval=num_classes, dtype=tf.int32)
|
||||||
|
|
||||||
|
tf_dataset = tf.data.Dataset.from_tensor_slices((features, labels))
|
||||||
|
dataset = ds.Dataset(tf_dataset, data_size)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def build_model(input_shape: List[int], num_classes: int) -> tf.keras.Model:
|
||||||
|
"""Builds a simple Keras model for test."""
|
||||||
|
inputs = tf.keras.layers.Input(shape=input_shape)
|
||||||
|
if len(input_shape) == 3: # Image inputs.
|
||||||
|
outputs = tf.keras.layers.GlobalAveragePooling2D()(inputs)
|
||||||
|
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(outputs)
|
||||||
|
elif len(input_shape) == 1: # Text inputs.
|
||||||
|
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(inputs)
|
||||||
|
else:
|
||||||
|
raise ValueError("Model inputs should be 2D tensor or 4D tensor.")
|
||||||
|
|
||||||
|
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def is_same_output(tflite_file: str,
|
||||||
|
keras_model: tf.keras.Model,
|
||||||
|
input_tensors: Union[List[tf.Tensor], tf.Tensor],
|
||||||
|
atol: float = 1e-04) -> bool:
|
||||||
|
"""Returns if the output of TFLite model and keras model are identical."""
|
||||||
|
# Gets output from lite model.
|
||||||
|
lite_runner = model_util.get_lite_runner(tflite_file)
|
||||||
|
lite_output = lite_runner.run(input_tensors)
|
||||||
|
|
||||||
|
# Gets output from keras model.
|
||||||
|
keras_output = keras_model.predict_on_batch(input_tensors)
|
||||||
|
|
||||||
|
return np.allclose(lite_output, keras_output, atol=atol)
|
4
mediapipe/model_maker/requirements.txt
Normal file
4
mediapipe/model_maker/requirements.txt
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
absl-py
|
||||||
|
numpy
|
||||||
|
opencv-contrib-python
|
||||||
|
tensorflow
|
6
third_party/external_files.bzl
vendored
6
third_party/external_files.bzl
vendored
|
@ -550,6 +550,12 @@ def external_files():
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/ssd_mobilenet_v1.tflite?generation=1661875947436302"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/ssd_mobilenet_v1.tflite?generation=1661875947436302"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
http_file(
|
||||||
|
name = "com_google_mediapipe_test_jpg",
|
||||||
|
sha256 = "798a12a466933842528d8438f553320eebe5137f02650f12dd68706a2f94fb4f",
|
||||||
|
urls = ["https://storage.googleapis.com/mediapipe-assets/test.jpg?generation=1664672140191116"],
|
||||||
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_test_model_add_op_tflite",
|
name = "com_google_mediapipe_test_model_add_op_tflite",
|
||||||
sha256 = "298300ca8a9193b80ada1dca39d36f20bffeebde09e85385049b3bfe7be2272f",
|
sha256 = "298300ca8a9193b80ada1dca39d36f20bffeebde09e85385049b3bfe7be2272f",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user