Check if the image contains valid face that can be aligned for stylization. If not, throw an exception for invalid input image. This is applied to both input stylized face and raw face.

PiperOrigin-RevId: 561439600
This commit is contained in:
MediaPipe Team 2023-08-30 13:51:38 -07:00 committed by Copybara-Service
parent c92570f844
commit 612162d765
3 changed files with 14 additions and 2 deletions

View File

@ -40,6 +40,12 @@ def _preprocess_face_dataset(
tf.compat.v1.logging.info('Preprocess image %s', path) tf.compat.v1.logging.info('Preprocess image %s', path)
image = image_module.Image.create_from_file(path) image = image_module.Image.create_from_file(path)
aligned_image = aligner.align(image) aligned_image = aligner.align(image)
if aligned_image is None:
raise ValueError(
'ERROR: Invalid image. No face is detected and aligned. Please make'
' sure the image has a single face that is facing straightforward and'
' not significantly rotated.'
)
aligned_image_tensor = tf.convert_to_tensor(aligned_image.numpy_view()) aligned_image_tensor = tf.convert_to_tensor(aligned_image.numpy_view())
preprocessed_images.append(aligned_image_tensor) preprocessed_images.append(aligned_image_tensor)

View File

@ -27,8 +27,8 @@ class DatasetTest(tf.test.TestCase):
def test_from_image(self): def test_from_image(self):
test_image_file = 'input/style/cartoon/cartoon.jpg' test_image_file = 'input/style/cartoon/cartoon.jpg'
input_data_dir = test_utils.get_test_data_path(test_image_file) input_image_path = test_utils.get_test_data_path(test_image_file)
data = dataset.Dataset.from_image(filename=input_data_dir) data = dataset.Dataset.from_image(filename=input_image_path)
self.assertEqual(data.num_classes, 1) self.assertEqual(data.num_classes, 1)
self.assertEqual(data.label_names, ['cartoon']) self.assertEqual(data.label_names, ['cartoon'])
self.assertLen(data, 1) self.assertLen(data, 1)
@ -37,6 +37,12 @@ class DatasetTest(tf.test.TestCase):
with self.assertRaisesRegex(ValueError, 'Unsupported image formats: .zip'): with self.assertRaisesRegex(ValueError, 'Unsupported image formats: .zip'):
dataset.Dataset.from_image(filename='input/style/cartoon/cartoon.zip') dataset.Dataset.from_image(filename='input/style/cartoon/cartoon.zip')
def test_from_image_raise_value_error_for_invalid_image(self):
with self.assertRaisesRegex(ValueError, 'Invalid image'):
test_image_file = 'input/style/sketch/boy-6030802_1280.jpg'
input_image_path = test_utils.get_test_data_path(test_image_file)
dataset.Dataset.from_image(filename=input_image_path)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()

Binary file not shown.

After

Width:  |  Height:  |  Size: 275 KiB