Add the dataset module for face stylizer in model maker.

PiperOrigin-RevId: 516628350
This commit is contained in:
MediaPipe Team 2023-03-14 14:08:49 -07:00 committed by Copybara-Service
parent ade31b567b
commit 6774794d02
6 changed files with 208 additions and 0 deletions

View File

@ -0,0 +1,48 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict test compatibility macro.
# Placeholder for internal Python strict library and test compatibility macro.
licenses(["notice"])
package(default_visibility = ["//mediapipe:__subpackages__"])
filegroup(
name = "testdata",
srcs = glob([
"testdata/**",
]),
)
py_library(
name = "dataset",
srcs = ["dataset.py"],
deps = [
"//mediapipe/model_maker/python/core/data:classification_dataset",
"//mediapipe/model_maker/python/vision/core:image_utils",
],
)
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
data = [
":testdata",
],
deps = [
":dataset",
"//mediapipe/tasks/python/test:test_utils",
],
)

View File

@ -0,0 +1,14 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe Model Maker Python Public API For Face Stylization."""

View File

@ -0,0 +1,98 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Face stylizer dataset library."""
import logging
import os
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.vision.core import image_utils
# TODO: Change to a unlabeled dataset if it makes sense.
class Dataset(classification_dataset.ClassificationDataset):
"""Dataset library for face stylizer fine tuning."""
@classmethod
def from_folder(
cls, dirname: str
) -> classification_dataset.ClassificationDataset:
"""Loads images from the given directory.
The style image dataset directory is expected to contain one subdirectory
whose name represents the label of the style. There can be one or multiple
images of the same style in that subdirectory. Supported input image formats
include 'jpg', 'jpeg', 'png'.
Args:
dirname: Name of the directory containing the image files.
Returns:
Dataset containing images and labels and other related info.
Raises:
ValueError: if the input data directory is empty.
"""
data_root = os.path.abspath(dirname)
# Assumes the image data of the same label are in the same subdirectory,
# gets image path and label names.
all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
all_image_size = len(all_image_paths)
if all_image_size == 0:
raise ValueError('Invalid input data directory')
if not any(
fname.endswith(('.jpg', '.jpeg', '.png')) for fname in all_image_paths
):
raise ValueError('No images found under given directory')
label_names = sorted(
name
for name in os.listdir(data_root)
if os.path.isdir(os.path.join(data_root, name))
)
all_label_size = len(label_names)
index_by_label = dict(
(name, index) for index, name in enumerate(label_names)
)
# Get the style label from the subdirectory name.
all_image_labels = [
index_by_label[os.path.basename(os.path.dirname(path))]
for path in all_image_paths
]
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
image_ds = path_ds.map(
image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
)
# Load label
label_ds = tf.data.Dataset.from_tensor_slices(
tf.cast(all_image_labels, tf.int64)
)
# Create a dataset of (image, label) pairs
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
logging.info(
'Load images dataset with size: %d, num_label: %d, labels: %s.',
all_image_size,
all_label_size,
', '.join(label_names),
)
return Dataset(
dataset=image_label_ds, size=all_image_size, label_names=label_names
)

View File

@ -0,0 +1,48 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from mediapipe.model_maker.python.vision.face_stylizer import dataset
from mediapipe.tasks.python.test import test_utils
class DatasetTest(tf.test.TestCase):
def setUp(self):
super().setUp()
# TODO: Replace the stylize image dataset with licensed images.
self._test_data_dirname = 'testdata'
def test_from_folder(self):
input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
data = dataset.Dataset.from_folder(dirname=input_data_dir)
self.assertEqual(data.num_classes, 2)
self.assertEqual(data.label_names, ['cartoon', 'sketch'])
self.assertLen(data, 2)
def test_from_folder_raise_value_error_for_invalid_path(self):
with self.assertRaisesRegex(ValueError, 'Invalid input data directory'):
dataset.Dataset.from_folder(dirname='invalid')
def test_from_folder_raise_value_error_for_valid_no_data_path(self):
input_data_dir = test_utils.get_test_data_path('face_stylizer')
with self.assertRaisesRegex(
ValueError, 'No images found under given directory'
):
dataset.Dataset.from_folder(dirname=input_data_dir)
if __name__ == '__main__':
tf.test.main()

Binary file not shown.

After

Width:  |  Height:  |  Size: 347 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 336 KiB