Refactor common methods into vision/core/image_utils.py and vision/core/test_utils.py
PiperOrigin-RevId: 509968910
This commit is contained in:
		
							parent
							
								
									3d4ed305bc
								
							
						
					
					
						commit
						bdd1c24990
					
				|  | @ -31,3 +31,22 @@ py_test( | |||
|     srcs = ["image_preprocessing_test.py"], | ||||
|     deps = [":image_preprocessing"], | ||||
| ) | ||||
| 
 | ||||
| py_library( | ||||
|     name = "image_utils", | ||||
|     srcs = ["image_utils.py"], | ||||
| ) | ||||
| 
 | ||||
| py_test( | ||||
|     name = "image_utils_test", | ||||
|     srcs = ["image_utils_test.py"], | ||||
|     deps = [ | ||||
|         ":image_utils", | ||||
|         ":test_utils", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| py_library( | ||||
|     name = "test_utils", | ||||
|     srcs = ["test_utils.py"], | ||||
| ) | ||||
|  |  | |||
							
								
								
									
										28
									
								
								mediapipe/model_maker/python/vision/core/image_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								mediapipe/model_maker/python/vision/core/image_utils.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,28 @@ | |||
| # Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================== | ||||
| """Utilities for Images.""" | ||||
| 
 | ||||
| import tensorflow as tf | ||||
| 
 | ||||
| 
 | ||||
| def load_image(path: str) -> tf.Tensor: | ||||
|   """Loads a jpeg/png image and returns an image tensor.""" | ||||
|   image_raw = tf.io.read_file(path) | ||||
|   image_tensor = tf.cond( | ||||
|       tf.io.is_jpeg(image_raw), | ||||
|       lambda: tf.io.decode_jpeg(image_raw, channels=3), | ||||
|       lambda: tf.io.decode_png(image_raw, channels=3), | ||||
|   ) | ||||
|   return image_tensor | ||||
							
								
								
									
										37
									
								
								mediapipe/model_maker/python/vision/core/image_utils_test.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								mediapipe/model_maker/python/vision/core/image_utils_test.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,37 @@ | |||
| # Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the 'License'); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an 'AS IS' BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import os | ||||
| import tensorflow as tf | ||||
| 
 | ||||
| from mediapipe.model_maker.python.vision.core import image_utils | ||||
| from mediapipe.model_maker.python.vision.core import test_utils | ||||
| 
 | ||||
| 
 | ||||
| class ImageUtilsTest(tf.test.TestCase): | ||||
| 
 | ||||
|   def setUp(self): | ||||
|     super().setUp() | ||||
|     self.jpeg_img = os.path.join(self.get_temp_dir(), 'image.jpeg') | ||||
|     if os.path.exists(self.jpeg_img): | ||||
|       return | ||||
|     test_utils.write_filled_jpeg_file(self.jpeg_img, [0, 125, 255], 224) | ||||
| 
 | ||||
|   def test_load_image(self): | ||||
|     img_tensor = image_utils.load_image(self.jpeg_img) | ||||
|     self.assertEqual(img_tensor.shape, (224, 224, 3)) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|   tf.test.main() | ||||
							
								
								
									
										51
									
								
								mediapipe/model_maker/python/vision/core/test_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								mediapipe/model_maker/python/vision/core/test_utils.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,51 @@ | |||
| # Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| """Test utilities for model maker vision module.""" | ||||
| 
 | ||||
| from typing import Collection | ||||
| 
 | ||||
| import numpy as np | ||||
| import tensorflow as tf | ||||
| 
 | ||||
| 
 | ||||
| def fill_image(rgb: Collection[int], image_size: int): | ||||
|   """Test helper function to create images. | ||||
| 
 | ||||
|   Args: | ||||
|     rgb: A tuple or array of rgb values in [r, g, b] format | ||||
|     image_size: Int specifying the edge of the square image | ||||
| 
 | ||||
|   Returns: | ||||
|     Numpy array of shape (image_size, image_size, 3) filled with the rgb color | ||||
|   """ | ||||
|   r, g, b = rgb | ||||
|   return np.broadcast_to( | ||||
|       np.array([[[r, g, b]]], dtype=np.uint8), shape=(image_size, image_size, 3) | ||||
|   ) | ||||
| 
 | ||||
| 
 | ||||
| def write_filled_jpeg_file(path: str, rgb: Collection[int], image_size: int): | ||||
|   """Writes an image to a file path. | ||||
| 
 | ||||
|   Args: | ||||
|     path: location to write the image | ||||
|     rgb: A tuple or array of rgb values in [r, g, b] format | ||||
|     image_size: Int specifying the edge of the square image | ||||
|   """ | ||||
|   tf.keras.preprocessing.image.save_img( | ||||
|       path=path, | ||||
|       x=fill_image(rgb, image_size), | ||||
|       data_format='channels_last', | ||||
|       file_format='jpeg', | ||||
|   ) | ||||
|  | @ -55,13 +55,20 @@ py_test( | |||
| py_library( | ||||
|     name = "dataset", | ||||
|     srcs = ["dataset.py"], | ||||
|     deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"], | ||||
|     deps = [ | ||||
|         "//mediapipe/model_maker/python/core/data:classification_dataset", | ||||
|         "//mediapipe/model_maker/python/vision/core:image_utils", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| py_test( | ||||
|     name = "dataset_test", | ||||
|     srcs = ["dataset_test.py"], | ||||
|     deps = [":dataset"], | ||||
|     deps = [ | ||||
|         ":dataset", | ||||
|         "//mediapipe/model_maker/python/vision/core:image_utils", | ||||
|         "//mediapipe/model_maker/python/vision/core:test_utils", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| py_library( | ||||
|  |  | |||
|  | @ -21,16 +21,7 @@ import tensorflow as tf | |||
| import tensorflow_datasets as tfds | ||||
| 
 | ||||
| from mediapipe.model_maker.python.core.data import classification_dataset | ||||
| 
 | ||||
| 
 | ||||
| def _load_image(path: str) -> tf.Tensor: | ||||
|   """Loads a jpeg/png image and returns an image tensor.""" | ||||
|   image_raw = tf.io.read_file(path) | ||||
|   image_tensor = tf.cond( | ||||
|       tf.io.is_jpeg(image_raw), | ||||
|       lambda: tf.io.decode_jpeg(image_raw, channels=3), | ||||
|       lambda: tf.io.decode_png(image_raw, channels=3)) | ||||
|   return image_tensor | ||||
| from mediapipe.model_maker.python.vision.core import image_utils | ||||
| 
 | ||||
| 
 | ||||
| def _create_data( | ||||
|  | @ -93,7 +84,9 @@ class Dataset(classification_dataset.ClassificationDataset): | |||
| 
 | ||||
|     path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths) | ||||
| 
 | ||||
|     image_ds = path_ds.map(_load_image, num_parallel_calls=tf.data.AUTOTUNE) | ||||
|     image_ds = path_ds.map( | ||||
|         image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE | ||||
|     ) | ||||
| 
 | ||||
|     # Load label | ||||
|     label_ds = tf.data.Dataset.from_tensor_slices( | ||||
|  |  | |||
|  | @ -17,21 +17,11 @@ import random | |||
| import numpy as np | ||||
| import tensorflow as tf | ||||
| 
 | ||||
| from mediapipe.model_maker.python.vision.core import image_utils | ||||
| from mediapipe.model_maker.python.vision.core import test_utils | ||||
| from mediapipe.model_maker.python.vision.image_classifier import dataset | ||||
| 
 | ||||
| 
 | ||||
| def _fill_image(rgb, image_size): | ||||
|   r, g, b = rgb | ||||
|   return np.broadcast_to( | ||||
|       np.array([[[r, g, b]]], dtype=np.uint8), | ||||
|       shape=(image_size, image_size, 3)) | ||||
| 
 | ||||
| 
 | ||||
| def _write_filled_jpeg_file(path, rgb, image_size): | ||||
|   tf.keras.preprocessing.image.save_img(path, _fill_image(rgb, image_size), | ||||
|                                         'channels_last', 'jpeg') | ||||
| 
 | ||||
| 
 | ||||
| class DatasetTest(tf.test.TestCase): | ||||
| 
 | ||||
|   def setUp(self): | ||||
|  | @ -43,9 +33,11 @@ class DatasetTest(tf.test.TestCase): | |||
|     for class_name in ('daisy', 'tulips'): | ||||
|       class_subdir = os.path.join(self.image_path, class_name) | ||||
|       os.mkdir(class_subdir) | ||||
|       _write_filled_jpeg_file( | ||||
|       test_utils.write_filled_jpeg_file( | ||||
|           os.path.join(class_subdir, '0.jpeg'), | ||||
|           [random.uniform(0, 255) for _ in range(3)], 224) | ||||
|           [random.uniform(0, 255) for _ in range(3)], | ||||
|           224, | ||||
|       ) | ||||
| 
 | ||||
|   def test_split(self): | ||||
|     ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]) | ||||
|  | @ -73,11 +65,13 @@ class DatasetTest(tf.test.TestCase): | |||
|     for image, label in data.gen_tf_dataset(): | ||||
|       self.assertTrue(label.numpy() == 1 or label.numpy() == 0) | ||||
|       if label.numpy() == 0: | ||||
|         raw_image_tensor = dataset._load_image( | ||||
|             os.path.join(self.image_path, 'daisy', '0.jpeg')) | ||||
|         raw_image_tensor = image_utils.load_image( | ||||
|             os.path.join(self.image_path, 'daisy', '0.jpeg') | ||||
|         ) | ||||
|       else: | ||||
|         raw_image_tensor = dataset._load_image( | ||||
|             os.path.join(self.image_path, 'tulips', '0.jpeg')) | ||||
|         raw_image_tensor = image_utils.load_image( | ||||
|             os.path.join(self.image_path, 'tulips', '0.jpeg') | ||||
|         ) | ||||
|       self.assertTrue((image.numpy() == raw_image_tensor.numpy()).all()) | ||||
| 
 | ||||
|   def test_from_tfds(self): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user