Refactor common methods into vision/core/image_utils.py and vision/core/test_utils.py

PiperOrigin-RevId: 509968910
This commit is contained in:
MediaPipe Team 2023-02-15 16:52:50 -08:00 committed by Copybara-Service
parent 3d4ed305bc
commit bdd1c24990
7 changed files with 160 additions and 31 deletions

View File

@ -31,3 +31,22 @@ py_test(
srcs = ["image_preprocessing_test.py"], srcs = ["image_preprocessing_test.py"],
deps = [":image_preprocessing"], deps = [":image_preprocessing"],
) )
py_library(
name = "image_utils",
srcs = ["image_utils.py"],
)
py_test(
name = "image_utils_test",
srcs = ["image_utils_test.py"],
deps = [
":image_utils",
":test_utils",
],
)
py_library(
name = "test_utils",
srcs = ["test_utils.py"],
)

View File

@ -0,0 +1,28 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for Images."""
import tensorflow as tf
def load_image(path: str) -> tf.Tensor:
"""Loads a jpeg/png image and returns an image tensor."""
image_raw = tf.io.read_file(path)
image_tensor = tf.cond(
tf.io.is_jpeg(image_raw),
lambda: tf.io.decode_jpeg(image_raw, channels=3),
lambda: tf.io.decode_png(image_raw, channels=3),
)
return image_tensor

View File

@ -0,0 +1,37 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tensorflow as tf
from mediapipe.model_maker.python.vision.core import image_utils
from mediapipe.model_maker.python.vision.core import test_utils
class ImageUtilsTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self.jpeg_img = os.path.join(self.get_temp_dir(), 'image.jpeg')
if os.path.exists(self.jpeg_img):
return
test_utils.write_filled_jpeg_file(self.jpeg_img, [0, 125, 255], 224)
def test_load_image(self):
img_tensor = image_utils.load_image(self.jpeg_img)
self.assertEqual(img_tensor.shape, (224, 224, 3))
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,51 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test utilities for model maker vision module."""
from typing import Collection
import numpy as np
import tensorflow as tf
def fill_image(rgb: Collection[int], image_size: int):
"""Test helper function to create images.
Args:
rgb: A tuple or array of rgb values in [r, g, b] format
image_size: Int specifying the edge of the square image
Returns:
Numpy array of shape (image_size, image_size, 3) filled with the rgb color
"""
r, g, b = rgb
return np.broadcast_to(
np.array([[[r, g, b]]], dtype=np.uint8), shape=(image_size, image_size, 3)
)
def write_filled_jpeg_file(path: str, rgb: Collection[int], image_size: int):
"""Writes an image to a file path.
Args:
path: location to write the image
rgb: A tuple or array of rgb values in [r, g, b] format
image_size: Int specifying the edge of the square image
"""
tf.keras.preprocessing.image.save_img(
path=path,
x=fill_image(rgb, image_size),
data_format='channels_last',
file_format='jpeg',
)

View File

@ -55,13 +55,20 @@ py_test(
py_library( py_library(
name = "dataset", name = "dataset",
srcs = ["dataset.py"], srcs = ["dataset.py"],
deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"], deps = [
"//mediapipe/model_maker/python/core/data:classification_dataset",
"//mediapipe/model_maker/python/vision/core:image_utils",
],
) )
py_test( py_test(
name = "dataset_test", name = "dataset_test",
srcs = ["dataset_test.py"], srcs = ["dataset_test.py"],
deps = [":dataset"], deps = [
":dataset",
"//mediapipe/model_maker/python/vision/core:image_utils",
"//mediapipe/model_maker/python/vision/core:test_utils",
],
) )
py_library( py_library(

View File

@ -21,16 +21,7 @@ import tensorflow as tf
import tensorflow_datasets as tfds import tensorflow_datasets as tfds
from mediapipe.model_maker.python.core.data import classification_dataset from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.vision.core import image_utils
def _load_image(path: str) -> tf.Tensor:
"""Loads a jpeg/png image and returns an image tensor."""
image_raw = tf.io.read_file(path)
image_tensor = tf.cond(
tf.io.is_jpeg(image_raw),
lambda: tf.io.decode_jpeg(image_raw, channels=3),
lambda: tf.io.decode_png(image_raw, channels=3))
return image_tensor
def _create_data( def _create_data(
@ -93,7 +84,9 @@ class Dataset(classification_dataset.ClassificationDataset):
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths) path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
image_ds = path_ds.map(_load_image, num_parallel_calls=tf.data.AUTOTUNE) image_ds = path_ds.map(
image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
)
# Load label # Load label
label_ds = tf.data.Dataset.from_tensor_slices( label_ds = tf.data.Dataset.from_tensor_slices(

View File

@ -17,21 +17,11 @@ import random
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.vision.core import image_utils
from mediapipe.model_maker.python.vision.core import test_utils
from mediapipe.model_maker.python.vision.image_classifier import dataset from mediapipe.model_maker.python.vision.image_classifier import dataset
def _fill_image(rgb, image_size):
r, g, b = rgb
return np.broadcast_to(
np.array([[[r, g, b]]], dtype=np.uint8),
shape=(image_size, image_size, 3))
def _write_filled_jpeg_file(path, rgb, image_size):
tf.keras.preprocessing.image.save_img(path, _fill_image(rgb, image_size),
'channels_last', 'jpeg')
class DatasetTest(tf.test.TestCase): class DatasetTest(tf.test.TestCase):
def setUp(self): def setUp(self):
@ -43,9 +33,11 @@ class DatasetTest(tf.test.TestCase):
for class_name in ('daisy', 'tulips'): for class_name in ('daisy', 'tulips'):
class_subdir = os.path.join(self.image_path, class_name) class_subdir = os.path.join(self.image_path, class_name)
os.mkdir(class_subdir) os.mkdir(class_subdir)
_write_filled_jpeg_file( test_utils.write_filled_jpeg_file(
os.path.join(class_subdir, '0.jpeg'), os.path.join(class_subdir, '0.jpeg'),
[random.uniform(0, 255) for _ in range(3)], 224) [random.uniform(0, 255) for _ in range(3)],
224,
)
def test_split(self): def test_split(self):
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]) ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
@ -73,11 +65,13 @@ class DatasetTest(tf.test.TestCase):
for image, label in data.gen_tf_dataset(): for image, label in data.gen_tf_dataset():
self.assertTrue(label.numpy() == 1 or label.numpy() == 0) self.assertTrue(label.numpy() == 1 or label.numpy() == 0)
if label.numpy() == 0: if label.numpy() == 0:
raw_image_tensor = dataset._load_image( raw_image_tensor = image_utils.load_image(
os.path.join(self.image_path, 'daisy', '0.jpeg')) os.path.join(self.image_path, 'daisy', '0.jpeg')
)
else: else:
raw_image_tensor = dataset._load_image( raw_image_tensor = image_utils.load_image(
os.path.join(self.image_path, 'tulips', '0.jpeg')) os.path.join(self.image_path, 'tulips', '0.jpeg')
)
self.assertTrue((image.numpy() == raw_image_tensor.numpy()).all()) self.assertTrue((image.numpy() == raw_image_tensor.numpy()).all())
def test_from_tfds(self): def test_from_tfds(self):