From bdd1c24990e715269e49e4c9fc58ec7a8c6a8711 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 15 Feb 2023 16:52:50 -0800 Subject: [PATCH] Refactor common methods into vision/core/image_utils.py and vision/core/test_utils.py PiperOrigin-RevId: 509968910 --- .../model_maker/python/vision/core/BUILD | 19 +++++++ .../python/vision/core/image_utils.py | 28 ++++++++++ .../python/vision/core/image_utils_test.py | 37 ++++++++++++++ .../python/vision/core/test_utils.py | 51 +++++++++++++++++++ .../python/vision/image_classifier/BUILD | 11 +++- .../python/vision/image_classifier/dataset.py | 15 ++---- .../vision/image_classifier/dataset_test.py | 30 +++++------ 7 files changed, 160 insertions(+), 31 deletions(-) create mode 100644 mediapipe/model_maker/python/vision/core/image_utils.py create mode 100644 mediapipe/model_maker/python/vision/core/image_utils_test.py create mode 100644 mediapipe/model_maker/python/vision/core/test_utils.py diff --git a/mediapipe/model_maker/python/vision/core/BUILD b/mediapipe/model_maker/python/vision/core/BUILD index 0b15a0276..6dd547ff1 100644 --- a/mediapipe/model_maker/python/vision/core/BUILD +++ b/mediapipe/model_maker/python/vision/core/BUILD @@ -31,3 +31,22 @@ py_test( srcs = ["image_preprocessing_test.py"], deps = [":image_preprocessing"], ) + +py_library( + name = "image_utils", + srcs = ["image_utils.py"], +) + +py_test( + name = "image_utils_test", + srcs = ["image_utils_test.py"], + deps = [ + ":image_utils", + ":test_utils", + ], +) + +py_library( + name = "test_utils", + srcs = ["test_utils.py"], +) diff --git a/mediapipe/model_maker/python/vision/core/image_utils.py b/mediapipe/model_maker/python/vision/core/image_utils.py new file mode 100644 index 000000000..80d0616e5 --- /dev/null +++ b/mediapipe/model_maker/python/vision/core/image_utils.py @@ -0,0 +1,28 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for Images.""" + +import tensorflow as tf + + +def load_image(path: str) -> tf.Tensor: + """Loads a jpeg/png image and returns an image tensor.""" + image_raw = tf.io.read_file(path) + image_tensor = tf.cond( + tf.io.is_jpeg(image_raw), + lambda: tf.io.decode_jpeg(image_raw, channels=3), + lambda: tf.io.decode_png(image_raw, channels=3), + ) + return image_tensor diff --git a/mediapipe/model_maker/python/vision/core/image_utils_test.py b/mediapipe/model_maker/python/vision/core/image_utils_test.py new file mode 100644 index 000000000..84101113c --- /dev/null +++ b/mediapipe/model_maker/python/vision/core/image_utils_test.py @@ -0,0 +1,37 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tensorflow as tf + +from mediapipe.model_maker.python.vision.core import image_utils +from mediapipe.model_maker.python.vision.core import test_utils + + +class ImageUtilsTest(tf.test.TestCase): + + def setUp(self): + super().setUp() + self.jpeg_img = os.path.join(self.get_temp_dir(), 'image.jpeg') + if os.path.exists(self.jpeg_img): + return + test_utils.write_filled_jpeg_file(self.jpeg_img, [0, 125, 255], 224) + + def test_load_image(self): + img_tensor = image_utils.load_image(self.jpeg_img) + self.assertEqual(img_tensor.shape, (224, 224, 3)) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/vision/core/test_utils.py b/mediapipe/model_maker/python/vision/core/test_utils.py new file mode 100644 index 000000000..528b2ca7b --- /dev/null +++ b/mediapipe/model_maker/python/vision/core/test_utils.py @@ -0,0 +1,51 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test utilities for model maker vision module.""" + +from typing import Collection + +import numpy as np +import tensorflow as tf + + +def fill_image(rgb: Collection[int], image_size: int): + """Test helper function to create images. + + Args: + rgb: A tuple or array of rgb values in [r, g, b] format + image_size: Int specifying the edge of the square image + + Returns: + Numpy array of shape (image_size, image_size, 3) filled with the rgb color + """ + r, g, b = rgb + return np.broadcast_to( + np.array([[[r, g, b]]], dtype=np.uint8), shape=(image_size, image_size, 3) + ) + + +def write_filled_jpeg_file(path: str, rgb: Collection[int], image_size: int): + """Writes an image to a file path. + + Args: + path: location to write the image + rgb: A tuple or array of rgb values in [r, g, b] format + image_size: Int specifying the edge of the square image + """ + tf.keras.preprocessing.image.save_img( + path=path, + x=fill_image(rgb, image_size), + data_format='channels_last', + file_format='jpeg', + ) diff --git a/mediapipe/model_maker/python/vision/image_classifier/BUILD b/mediapipe/model_maker/python/vision/image_classifier/BUILD index bd916a92b..f88616690 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/BUILD +++ b/mediapipe/model_maker/python/vision/image_classifier/BUILD @@ -55,13 +55,20 @@ py_test( py_library( name = "dataset", srcs = ["dataset.py"], - deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"], + deps = [ + "//mediapipe/model_maker/python/core/data:classification_dataset", + "//mediapipe/model_maker/python/vision/core:image_utils", + ], ) py_test( name = "dataset_test", srcs = ["dataset_test.py"], - deps = [":dataset"], + deps = [ + ":dataset", + "//mediapipe/model_maker/python/vision/core:image_utils", + "//mediapipe/model_maker/python/vision/core:test_utils", + ], ) py_library( diff --git a/mediapipe/model_maker/python/vision/image_classifier/dataset.py b/mediapipe/model_maker/python/vision/image_classifier/dataset.py index 071fe483e..bf4bbc4b6 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/dataset.py +++ b/mediapipe/model_maker/python/vision/image_classifier/dataset.py @@ -21,16 +21,7 @@ import tensorflow as tf import tensorflow_datasets as tfds from mediapipe.model_maker.python.core.data import classification_dataset - - -def _load_image(path: str) -> tf.Tensor: - """Loads a jpeg/png image and returns an image tensor.""" - image_raw = tf.io.read_file(path) - image_tensor = tf.cond( - tf.io.is_jpeg(image_raw), - lambda: tf.io.decode_jpeg(image_raw, channels=3), - lambda: tf.io.decode_png(image_raw, channels=3)) - return image_tensor +from mediapipe.model_maker.python.vision.core import image_utils def _create_data( @@ -93,7 +84,9 @@ class Dataset(classification_dataset.ClassificationDataset): path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths) - image_ds = path_ds.map(_load_image, num_parallel_calls=tf.data.AUTOTUNE) + image_ds = path_ds.map( + image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE + ) # Load label label_ds = tf.data.Dataset.from_tensor_slices( diff --git a/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py b/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py index 0eed547eb..1f290b327 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py @@ -17,21 +17,11 @@ import random import numpy as np import tensorflow as tf +from mediapipe.model_maker.python.vision.core import image_utils +from mediapipe.model_maker.python.vision.core import test_utils from mediapipe.model_maker.python.vision.image_classifier import dataset -def _fill_image(rgb, image_size): - r, g, b = rgb - return np.broadcast_to( - np.array([[[r, g, b]]], dtype=np.uint8), - shape=(image_size, image_size, 3)) - - -def _write_filled_jpeg_file(path, rgb, image_size): - tf.keras.preprocessing.image.save_img(path, _fill_image(rgb, image_size), - 'channels_last', 'jpeg') - - class DatasetTest(tf.test.TestCase): def setUp(self): @@ -43,9 +33,11 @@ class DatasetTest(tf.test.TestCase): for class_name in ('daisy', 'tulips'): class_subdir = os.path.join(self.image_path, class_name) os.mkdir(class_subdir) - _write_filled_jpeg_file( + test_utils.write_filled_jpeg_file( os.path.join(class_subdir, '0.jpeg'), - [random.uniform(0, 255) for _ in range(3)], 224) + [random.uniform(0, 255) for _ in range(3)], + 224, + ) def test_split(self): ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]) @@ -73,11 +65,13 @@ class DatasetTest(tf.test.TestCase): for image, label in data.gen_tf_dataset(): self.assertTrue(label.numpy() == 1 or label.numpy() == 0) if label.numpy() == 0: - raw_image_tensor = dataset._load_image( - os.path.join(self.image_path, 'daisy', '0.jpeg')) + raw_image_tensor = image_utils.load_image( + os.path.join(self.image_path, 'daisy', '0.jpeg') + ) else: - raw_image_tensor = dataset._load_image( - os.path.join(self.image_path, 'tulips', '0.jpeg')) + raw_image_tensor = image_utils.load_image( + os.path.join(self.image_path, 'tulips', '0.jpeg') + ) self.assertTrue((image.numpy() == raw_image_tensor.numpy()).all()) def test_from_tfds(self):