Refactor common methods into vision/core/image_utils.py and vision/core/test_utils.py
PiperOrigin-RevId: 509968910
This commit is contained in:
parent
3d4ed305bc
commit
bdd1c24990
|
@ -31,3 +31,22 @@ py_test(
|
|||
srcs = ["image_preprocessing_test.py"],
|
||||
deps = [":image_preprocessing"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "image_utils",
|
||||
srcs = ["image_utils.py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "image_utils_test",
|
||||
srcs = ["image_utils_test.py"],
|
||||
deps = [
|
||||
":image_utils",
|
||||
":test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "test_utils",
|
||||
srcs = ["test_utils.py"],
|
||||
)
|
||||
|
|
28
mediapipe/model_maker/python/vision/core/image_utils.py
Normal file
28
mediapipe/model_maker/python/vision/core/image_utils.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utilities for Images."""
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def load_image(path: str) -> tf.Tensor:
|
||||
"""Loads a jpeg/png image and returns an image tensor."""
|
||||
image_raw = tf.io.read_file(path)
|
||||
image_tensor = tf.cond(
|
||||
tf.io.is_jpeg(image_raw),
|
||||
lambda: tf.io.decode_jpeg(image_raw, channels=3),
|
||||
lambda: tf.io.decode_png(image_raw, channels=3),
|
||||
)
|
||||
return image_tensor
|
37
mediapipe/model_maker/python/vision/core/image_utils_test.py
Normal file
37
mediapipe/model_maker/python/vision/core/image_utils_test.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.vision.core import image_utils
|
||||
from mediapipe.model_maker.python.vision.core import test_utils
|
||||
|
||||
|
||||
class ImageUtilsTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.jpeg_img = os.path.join(self.get_temp_dir(), 'image.jpeg')
|
||||
if os.path.exists(self.jpeg_img):
|
||||
return
|
||||
test_utils.write_filled_jpeg_file(self.jpeg_img, [0, 125, 255], 224)
|
||||
|
||||
def test_load_image(self):
|
||||
img_tensor = image_utils.load_image(self.jpeg_img)
|
||||
self.assertEqual(img_tensor.shape, (224, 224, 3))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
51
mediapipe/model_maker/python/vision/core/test_utils.py
Normal file
51
mediapipe/model_maker/python/vision/core/test_utils.py
Normal file
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Test utilities for model maker vision module."""
|
||||
|
||||
from typing import Collection
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def fill_image(rgb: Collection[int], image_size: int):
|
||||
"""Test helper function to create images.
|
||||
|
||||
Args:
|
||||
rgb: A tuple or array of rgb values in [r, g, b] format
|
||||
image_size: Int specifying the edge of the square image
|
||||
|
||||
Returns:
|
||||
Numpy array of shape (image_size, image_size, 3) filled with the rgb color
|
||||
"""
|
||||
r, g, b = rgb
|
||||
return np.broadcast_to(
|
||||
np.array([[[r, g, b]]], dtype=np.uint8), shape=(image_size, image_size, 3)
|
||||
)
|
||||
|
||||
|
||||
def write_filled_jpeg_file(path: str, rgb: Collection[int], image_size: int):
|
||||
"""Writes an image to a file path.
|
||||
|
||||
Args:
|
||||
path: location to write the image
|
||||
rgb: A tuple or array of rgb values in [r, g, b] format
|
||||
image_size: Int specifying the edge of the square image
|
||||
"""
|
||||
tf.keras.preprocessing.image.save_img(
|
||||
path=path,
|
||||
x=fill_image(rgb, image_size),
|
||||
data_format='channels_last',
|
||||
file_format='jpeg',
|
||||
)
|
|
@ -55,13 +55,20 @@ py_test(
|
|||
py_library(
|
||||
name = "dataset",
|
||||
srcs = ["dataset.py"],
|
||||
deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"],
|
||||
deps = [
|
||||
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||
"//mediapipe/model_maker/python/vision/core:image_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "dataset_test",
|
||||
srcs = ["dataset_test.py"],
|
||||
deps = [":dataset"],
|
||||
deps = [
|
||||
":dataset",
|
||||
"//mediapipe/model_maker/python/vision/core:image_utils",
|
||||
"//mediapipe/model_maker/python/vision/core:test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
|
|
@ -21,16 +21,7 @@ import tensorflow as tf
|
|||
import tensorflow_datasets as tfds
|
||||
|
||||
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||
|
||||
|
||||
def _load_image(path: str) -> tf.Tensor:
|
||||
"""Loads a jpeg/png image and returns an image tensor."""
|
||||
image_raw = tf.io.read_file(path)
|
||||
image_tensor = tf.cond(
|
||||
tf.io.is_jpeg(image_raw),
|
||||
lambda: tf.io.decode_jpeg(image_raw, channels=3),
|
||||
lambda: tf.io.decode_png(image_raw, channels=3))
|
||||
return image_tensor
|
||||
from mediapipe.model_maker.python.vision.core import image_utils
|
||||
|
||||
|
||||
def _create_data(
|
||||
|
@ -93,7 +84,9 @@ class Dataset(classification_dataset.ClassificationDataset):
|
|||
|
||||
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
|
||||
|
||||
image_ds = path_ds.map(_load_image, num_parallel_calls=tf.data.AUTOTUNE)
|
||||
image_ds = path_ds.map(
|
||||
image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
|
||||
)
|
||||
|
||||
# Load label
|
||||
label_ds = tf.data.Dataset.from_tensor_slices(
|
||||
|
|
|
@ -17,21 +17,11 @@ import random
|
|||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.vision.core import image_utils
|
||||
from mediapipe.model_maker.python.vision.core import test_utils
|
||||
from mediapipe.model_maker.python.vision.image_classifier import dataset
|
||||
|
||||
|
||||
def _fill_image(rgb, image_size):
|
||||
r, g, b = rgb
|
||||
return np.broadcast_to(
|
||||
np.array([[[r, g, b]]], dtype=np.uint8),
|
||||
shape=(image_size, image_size, 3))
|
||||
|
||||
|
||||
def _write_filled_jpeg_file(path, rgb, image_size):
|
||||
tf.keras.preprocessing.image.save_img(path, _fill_image(rgb, image_size),
|
||||
'channels_last', 'jpeg')
|
||||
|
||||
|
||||
class DatasetTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -43,9 +33,11 @@ class DatasetTest(tf.test.TestCase):
|
|||
for class_name in ('daisy', 'tulips'):
|
||||
class_subdir = os.path.join(self.image_path, class_name)
|
||||
os.mkdir(class_subdir)
|
||||
_write_filled_jpeg_file(
|
||||
test_utils.write_filled_jpeg_file(
|
||||
os.path.join(class_subdir, '0.jpeg'),
|
||||
[random.uniform(0, 255) for _ in range(3)], 224)
|
||||
[random.uniform(0, 255) for _ in range(3)],
|
||||
224,
|
||||
)
|
||||
|
||||
def test_split(self):
|
||||
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
|
||||
|
@ -73,11 +65,13 @@ class DatasetTest(tf.test.TestCase):
|
|||
for image, label in data.gen_tf_dataset():
|
||||
self.assertTrue(label.numpy() == 1 or label.numpy() == 0)
|
||||
if label.numpy() == 0:
|
||||
raw_image_tensor = dataset._load_image(
|
||||
os.path.join(self.image_path, 'daisy', '0.jpeg'))
|
||||
raw_image_tensor = image_utils.load_image(
|
||||
os.path.join(self.image_path, 'daisy', '0.jpeg')
|
||||
)
|
||||
else:
|
||||
raw_image_tensor = dataset._load_image(
|
||||
os.path.join(self.image_path, 'tulips', '0.jpeg'))
|
||||
raw_image_tensor = image_utils.load_image(
|
||||
os.path.join(self.image_path, 'tulips', '0.jpeg')
|
||||
)
|
||||
self.assertTrue((image.numpy() == raw_image_tensor.numpy()).all())
|
||||
|
||||
def test_from_tfds(self):
|
||||
|
|
Loading…
Reference in New Issue
Block a user