Refactor common methods into vision/core/image_utils.py and vision/core/test_utils.py
PiperOrigin-RevId: 509968910
This commit is contained in:
parent
3d4ed305bc
commit
bdd1c24990
|
@ -31,3 +31,22 @@ py_test(
|
||||||
srcs = ["image_preprocessing_test.py"],
|
srcs = ["image_preprocessing_test.py"],
|
||||||
deps = [":image_preprocessing"],
|
deps = [":image_preprocessing"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Library target for the image utilities in image_utils.py.
py_library(
    name = "image_utils",
    srcs = ["image_utils.py"],
)
|
||||||
|
|
||||||
|
# Unit test for image_utils; test_utils supplies the generated image fixtures.
py_test(
    name = "image_utils_test",
    srcs = ["image_utils_test.py"],
    deps = [
        ":image_utils",
        ":test_utils",
    ],
)
|
||||||
|
|
||||||
|
# Library target for the test helpers in test_utils.py.
py_library(
    name = "test_utils",
    srcs = ["test_utils.py"],
)
|
||||||
|
|
28
mediapipe/model_maker/python/vision/core/image_utils.py
Normal file
28
mediapipe/model_maker/python/vision/core/image_utils.py
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Utilities for Images."""
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(path: str) -> tf.Tensor:
  """Reads an image file and decodes it into a 3-channel image tensor.

  The file contents are probed with `tf.io.is_jpeg`; JPEG data is decoded with
  `tf.io.decode_jpeg`, and anything else is handed to `tf.io.decode_png`.

  Args:
    path: Location of the jpeg/png file to read.

  Returns:
    The decoded image tensor, decoded with channels=3.
  """
  encoded = tf.io.read_file(path)

  def _decode_as_jpeg():
    return tf.io.decode_jpeg(encoded, channels=3)

  def _decode_as_png():
    return tf.io.decode_png(encoded, channels=3)

  # tf.cond keeps the format switch usable when this function is traced into
  # a graph, where a Python `if` on a tensor would not work.
  return tf.cond(tf.io.is_jpeg(encoded), _decode_as_jpeg, _decode_as_png)
|
37
mediapipe/model_maker/python/vision/core/image_utils_test.py
Normal file
37
mediapipe/model_maker/python/vision/core/image_utils_test.py
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision.core import image_utils
|
||||||
|
from mediapipe.model_maker.python.vision.core import test_utils
|
||||||
|
|
||||||
|
|
||||||
|
class ImageUtilsTest(tf.test.TestCase):
  """Tests image_utils.load_image against generated solid-color images."""

  def setUp(self):
    super().setUp()
    temp_dir = self.get_temp_dir()

    # JPEG fixture: a 224x224 solid-color image, written only if missing.
    self.jpeg_img = os.path.join(temp_dir, 'image.jpeg')
    if not os.path.exists(self.jpeg_img):
      test_utils.write_filled_jpeg_file(self.jpeg_img, [0, 125, 255], 224)

    # PNG fixture: exercises the non-JPEG branch of load_image. PNG is
    # lossless, so the decoded pixels can be compared exactly.
    self.png_pixels = test_utils.fill_image([0, 125, 255], 32)
    self.png_img = os.path.join(temp_dir, 'image.png')
    if not os.path.exists(self.png_img):
      tf.io.write_file(self.png_img, tf.io.encode_png(self.png_pixels))

  def test_load_image(self):
    img_tensor = image_utils.load_image(self.jpeg_img)
    self.assertEqual(img_tensor.shape, (224, 224, 3))

  def test_load_png_image(self):
    img_tensor = image_utils.load_image(self.png_img)
    self.assertEqual(img_tensor.shape, (32, 32, 3))
    self.assertAllEqual(img_tensor, self.png_pixels)


if __name__ == '__main__':
  tf.test.main()
|
51
mediapipe/model_maker/python/vision/core/test_utils.py
Normal file
51
mediapipe/model_maker/python/vision/core/test_utils.py
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Test utilities for model maker vision module."""
|
||||||
|
|
||||||
|
from typing import Collection
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
def fill_image(rgb: Collection[int], image_size: int) -> np.ndarray:
  """Test helper function to create a square, solid-color image.

  Args:
    rgb: Exactly three channel values in [r, g, b] order; they are stored as
      uint8.
    image_size: Int specifying the edge of the square image.

  Returns:
    Numpy uint8 array of shape (image_size, image_size, 3) filled with the rgb
    color. The result is a read-only broadcast view of a single pixel, so copy
    it before modifying any values.

  Raises:
    ValueError: If rgb does not unpack into exactly three values.
  """
  r, g, b = rgb
  # One pixel of shape (1, 1, 3) is broadcast to the full image rather than
  # materializing image_size * image_size copies.
  pixel = np.array([[[r, g, b]]], dtype=np.uint8)
  return np.broadcast_to(pixel, shape=(image_size, image_size, 3))
|
||||||
|
|
||||||
|
|
||||||
|
def write_filled_jpeg_file(path: str, rgb: Collection[int], image_size: int):
  """Saves a square, solid-color image to `path` in JPEG format.

  Args:
    path: location to write the image
    rgb: A tuple or array of rgb values in [r, g, b] format
    image_size: Int specifying the edge of the square image
  """
  solid_color_pixels = fill_image(rgb, image_size)
  tf.keras.preprocessing.image.save_img(
      path=path,
      x=solid_color_pixels,
      data_format='channels_last',
      file_format='jpeg',
  )
|
|
@ -55,13 +55,20 @@ py_test(
|
||||||
py_library(
|
py_library(
|
||||||
name = "dataset",
|
name = "dataset",
|
||||||
srcs = ["dataset.py"],
|
srcs = ["dataset.py"],
|
||||||
deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"],
|
deps = [
|
||||||
|
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||||
|
"//mediapipe/model_maker/python/vision/core:image_utils",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "dataset_test",
|
name = "dataset_test",
|
||||||
srcs = ["dataset_test.py"],
|
srcs = ["dataset_test.py"],
|
||||||
deps = [":dataset"],
|
deps = [
|
||||||
|
":dataset",
|
||||||
|
"//mediapipe/model_maker/python/vision/core:image_utils",
|
||||||
|
"//mediapipe/model_maker/python/vision/core:test_utils",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
|
|
|
@ -21,16 +21,7 @@ import tensorflow as tf
|
||||||
import tensorflow_datasets as tfds
|
import tensorflow_datasets as tfds
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core.data import classification_dataset
|
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||||
|
from mediapipe.model_maker.python.vision.core import image_utils
|
||||||
|
|
||||||
def _load_image(path: str) -> tf.Tensor:
|
|
||||||
"""Loads a jpeg/png image and returns an image tensor."""
|
|
||||||
image_raw = tf.io.read_file(path)
|
|
||||||
image_tensor = tf.cond(
|
|
||||||
tf.io.is_jpeg(image_raw),
|
|
||||||
lambda: tf.io.decode_jpeg(image_raw, channels=3),
|
|
||||||
lambda: tf.io.decode_png(image_raw, channels=3))
|
|
||||||
return image_tensor
|
|
||||||
|
|
||||||
|
|
||||||
def _create_data(
|
def _create_data(
|
||||||
|
@ -93,7 +84,9 @@ class Dataset(classification_dataset.ClassificationDataset):
|
||||||
|
|
||||||
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
|
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
|
||||||
|
|
||||||
image_ds = path_ds.map(_load_image, num_parallel_calls=tf.data.AUTOTUNE)
|
image_ds = path_ds.map(
|
||||||
|
image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
|
||||||
|
)
|
||||||
|
|
||||||
# Load label
|
# Load label
|
||||||
label_ds = tf.data.Dataset.from_tensor_slices(
|
label_ds = tf.data.Dataset.from_tensor_slices(
|
||||||
|
|
|
@ -17,21 +17,11 @@ import random
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision.core import image_utils
|
||||||
|
from mediapipe.model_maker.python.vision.core import test_utils
|
||||||
from mediapipe.model_maker.python.vision.image_classifier import dataset
|
from mediapipe.model_maker.python.vision.image_classifier import dataset
|
||||||
|
|
||||||
|
|
||||||
def _fill_image(rgb, image_size):
|
|
||||||
r, g, b = rgb
|
|
||||||
return np.broadcast_to(
|
|
||||||
np.array([[[r, g, b]]], dtype=np.uint8),
|
|
||||||
shape=(image_size, image_size, 3))
|
|
||||||
|
|
||||||
|
|
||||||
def _write_filled_jpeg_file(path, rgb, image_size):
|
|
||||||
tf.keras.preprocessing.image.save_img(path, _fill_image(rgb, image_size),
|
|
||||||
'channels_last', 'jpeg')
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetTest(tf.test.TestCase):
|
class DatasetTest(tf.test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
@ -43,9 +33,11 @@ class DatasetTest(tf.test.TestCase):
|
||||||
for class_name in ('daisy', 'tulips'):
|
for class_name in ('daisy', 'tulips'):
|
||||||
class_subdir = os.path.join(self.image_path, class_name)
|
class_subdir = os.path.join(self.image_path, class_name)
|
||||||
os.mkdir(class_subdir)
|
os.mkdir(class_subdir)
|
||||||
_write_filled_jpeg_file(
|
test_utils.write_filled_jpeg_file(
|
||||||
os.path.join(class_subdir, '0.jpeg'),
|
os.path.join(class_subdir, '0.jpeg'),
|
||||||
[random.uniform(0, 255) for _ in range(3)], 224)
|
[random.uniform(0, 255) for _ in range(3)],
|
||||||
|
224,
|
||||||
|
)
|
||||||
|
|
||||||
def test_split(self):
|
def test_split(self):
|
||||||
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
|
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
|
||||||
|
@ -73,11 +65,13 @@ class DatasetTest(tf.test.TestCase):
|
||||||
for image, label in data.gen_tf_dataset():
|
for image, label in data.gen_tf_dataset():
|
||||||
self.assertTrue(label.numpy() == 1 or label.numpy() == 0)
|
self.assertTrue(label.numpy() == 1 or label.numpy() == 0)
|
||||||
if label.numpy() == 0:
|
if label.numpy() == 0:
|
||||||
raw_image_tensor = dataset._load_image(
|
raw_image_tensor = image_utils.load_image(
|
||||||
os.path.join(self.image_path, 'daisy', '0.jpeg'))
|
os.path.join(self.image_path, 'daisy', '0.jpeg')
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raw_image_tensor = dataset._load_image(
|
raw_image_tensor = image_utils.load_image(
|
||||||
os.path.join(self.image_path, 'tulips', '0.jpeg'))
|
os.path.join(self.image_path, 'tulips', '0.jpeg')
|
||||||
|
)
|
||||||
self.assertTrue((image.numpy() == raw_image_tensor.numpy()).all())
|
self.assertTrue((image.numpy() == raw_image_tensor.numpy()).all())
|
||||||
|
|
||||||
def test_from_tfds(self):
|
def test_from_tfds(self):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user