mediapipe/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py
2022-10-14 10:47:34 -07:00

109 lines
4.0 KiB
Python

# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.vision.image_classifier import dataset
def _fill_image(rgb, image_size):
  """Builds a square uint8 image in which every pixel has the colour `rgb`.

  Args:
    rgb: Sequence of exactly three channel values (red, green, blue).
    image_size: Height and width of the resulting image, in pixels.

  Returns:
    A `numpy.broadcast_to` view of shape (image_size, image_size, 3) and dtype
    uint8 whose pixels all equal the given colour.
  """
  red, green, blue = rgb
  # A single 1x1 pixel is enough; broadcasting repeats it without copying.
  single_pixel = np.array([[[red, green, blue]]], dtype=np.uint8)
  target_shape = (image_size, image_size, 3)
  return np.broadcast_to(single_pixel, shape=target_shape)
def _write_filled_jpeg_file(path, rgb, image_size):
  """Saves a solid-colour square image to `path` as a JPEG file.

  Args:
    path: Destination file path for the image.
    rgb: Sequence of three channel values (red, green, blue) used for every
      pixel.
    image_size: Height and width of the image, in pixels.
  """
  pixels = _fill_image(rgb, image_size)
  tf.keras.preprocessing.image.save_img(
      path, pixels, data_format='channels_last', file_format='jpeg')
class DatasetTest(tf.test.TestCase):
  """Tests for `dataset.Dataset` of the image classifier model maker."""

  def setUp(self):
    """Creates an image folder holding one random-colour JPEG per class.

    The folder lives under the test temp dir and contains the sub-directories
    'daisy' and 'tulips', each with a single 224x224 image named '0.jpeg'.
    If the folder already exists it is reused as-is.
    """
    super().setUp()
    self.image_path = os.path.join(self.get_temp_dir(), 'random_image_dir')
    if os.path.exists(self.image_path):
      # Fixture already written; nothing more to do.
      return
    os.mkdir(self.image_path)
    for class_name in ('daisy', 'tulips'):
      class_subdir = os.path.join(self.image_path, class_name)
      os.mkdir(class_subdir)
      _write_filled_jpeg_file(
          os.path.join(class_subdir, '0.jpeg'),
          [random.uniform(0, 255) for _ in range(3)], 224)

  def test_split(self):
    """Checks that `split(0.5)` halves the data in order, keeping metadata."""
    ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
    data = dataset.Dataset(ds, 4, ['pos', 'neg'])
    train_data, test_data = data.split(0.5)

    # The first two elements are expected in the training split.
    self.assertLen(train_data, 2)
    for i, elem in enumerate(train_data._dataset):
      # assertAllEqual also checks the shape and reports differing values,
      # unlike assertTrue((a == b).all()).
      self.assertAllEqual(elem.numpy(), [i, 1])
    self.assertEqual(train_data.num_classes, 2)
    self.assertEqual(train_data.index_to_label, ['pos', 'neg'])

    # The last two elements are expected in the test split.
    self.assertLen(test_data, 2)
    for i, elem in enumerate(test_data._dataset):
      self.assertAllEqual(elem.numpy(), [i, 0])
    self.assertEqual(test_data.num_classes, 2)
    self.assertEqual(test_data.index_to_label, ['pos', 'neg'])

  def test_from_folder(self):
    """Checks that `from_folder` loads the images and labels from setUp."""
    data = dataset.Dataset.from_folder(self.image_path)

    self.assertLen(data, 2)
    self.assertEqual(data.num_classes, 2)
    self.assertEqual(data.index_to_label, ['daisy', 'tulips'])
    for image, label in data.gen_tf_dataset():
      label_value = label.numpy()
      self.assertIn(label_value, (0, 1))
      # Label 0 corresponds to 'daisy' and label 1 to 'tulips'.
      class_name = 'daisy' if label_value == 0 else 'tulips'
      raw_image_tensor = dataset._load_image(
          os.path.join(self.image_path, class_name, '0.jpeg'))
      self.assertAllEqual(image.numpy(), raw_image_tensor.numpy())

  def test_from_tfds(self):
    """Checks the splits loaded from the 'beans' TFDS dataset.

    Currently skipped unconditionally; the assertions below the `skipTest`
    call do not run until that call is removed.
    """
    # TODO: Remove this once tfds download error is fixed.
    self.skipTest('Temporarily skip the unittest due to tfds download error.')
    train_data, validation_data, test_data = (
        dataset.Dataset.from_tfds('beans'))
    self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset)
    self.assertLen(train_data, 1034)
    self.assertEqual(train_data.num_classes, 3)
    self.assertEqual(train_data.index_to_label,
                     ['angular_leaf_spot', 'bean_rust', 'healthy'])

    self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset)
    self.assertLen(validation_data, 133)
    self.assertEqual(validation_data.num_classes, 3)
    self.assertEqual(validation_data.index_to_label,
                     ['angular_leaf_spot', 'bean_rust', 'healthy'])

    self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset)
    self.assertLen(test_data, 128)
    self.assertEqual(test_data.num_classes, 3)
    self.assertEqual(test_data.index_to_label,
                     ['angular_leaf_spot', 'bean_rust', 'healthy'])
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
  tf.test.main()