Add the dataset module for face stylizer in model maker.
PiperOrigin-RevId: 516628350
This commit is contained in:
parent
ade31b567b
commit
6774794d02
48
mediapipe/model_maker/python/vision/face_stylizer/BUILD
Normal file
48
mediapipe/model_maker/python/vision/face_stylizer/BUILD
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Placeholder for internal Python strict test compatibility macro.
|
||||||
|
# Placeholder for internal Python strict library and test compatibility macro.
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe:__subpackages__"])
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "testdata",
|
||||||
|
srcs = glob([
|
||||||
|
"testdata/**",
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "dataset",
|
||||||
|
srcs = ["dataset.py"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||||
|
"//mediapipe/model_maker/python/vision/core:image_utils",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "dataset_test",
|
||||||
|
srcs = ["dataset_test.py"],
|
||||||
|
data = [
|
||||||
|
":testdata",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":dataset",
|
||||||
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,14 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""MediaPipe Model Maker Python Public API For Face Stylization."""
|
98
mediapipe/model_maker/python/vision/face_stylizer/dataset.py
Normal file
98
mediapipe/model_maker/python/vision/face_stylizer/dataset.py
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Face stylizer dataset library."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||||
|
from mediapipe.model_maker.python.vision.core import image_utils
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Change to a unlabeled dataset if it makes sense.
|
||||||
|
class Dataset(classification_dataset.ClassificationDataset):
|
||||||
|
"""Dataset library for face stylizer fine tuning."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_folder(
|
||||||
|
cls, dirname: str
|
||||||
|
) -> classification_dataset.ClassificationDataset:
|
||||||
|
"""Loads images from the given directory.
|
||||||
|
|
||||||
|
The style image dataset directory is expected to contain one subdirectory
|
||||||
|
whose name represents the label of the style. There can be one or multiple
|
||||||
|
images of the same style in that subdirectory. Supported input image formats
|
||||||
|
include 'jpg', 'jpeg', 'png'.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dirname: Name of the directory containing the image files.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset containing images and labels and other related info.
|
||||||
|
Raises:
|
||||||
|
ValueError: if the input data directory is empty.
|
||||||
|
"""
|
||||||
|
data_root = os.path.abspath(dirname)
|
||||||
|
|
||||||
|
# Assumes the image data of the same label are in the same subdirectory,
|
||||||
|
# gets image path and label names.
|
||||||
|
all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
|
||||||
|
all_image_size = len(all_image_paths)
|
||||||
|
if all_image_size == 0:
|
||||||
|
raise ValueError('Invalid input data directory')
|
||||||
|
if not any(
|
||||||
|
fname.endswith(('.jpg', '.jpeg', '.png')) for fname in all_image_paths
|
||||||
|
):
|
||||||
|
raise ValueError('No images found under given directory')
|
||||||
|
|
||||||
|
label_names = sorted(
|
||||||
|
name
|
||||||
|
for name in os.listdir(data_root)
|
||||||
|
if os.path.isdir(os.path.join(data_root, name))
|
||||||
|
)
|
||||||
|
all_label_size = len(label_names)
|
||||||
|
index_by_label = dict(
|
||||||
|
(name, index) for index, name in enumerate(label_names)
|
||||||
|
)
|
||||||
|
# Get the style label from the subdirectory name.
|
||||||
|
all_image_labels = [
|
||||||
|
index_by_label[os.path.basename(os.path.dirname(path))]
|
||||||
|
for path in all_image_paths
|
||||||
|
]
|
||||||
|
|
||||||
|
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
|
||||||
|
|
||||||
|
image_ds = path_ds.map(
|
||||||
|
image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load label
|
||||||
|
label_ds = tf.data.Dataset.from_tensor_slices(
|
||||||
|
tf.cast(all_image_labels, tf.int64)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a dataset of (image, label) pairs
|
||||||
|
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
'Load images dataset with size: %d, num_label: %d, labels: %s.',
|
||||||
|
all_image_size,
|
||||||
|
all_label_size,
|
||||||
|
', '.join(label_names),
|
||||||
|
)
|
||||||
|
return Dataset(
|
||||||
|
dataset=image_label_ds, size=all_image_size, label_names=label_names
|
||||||
|
)
|
|
@ -0,0 +1,48 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision.face_stylizer import dataset
|
||||||
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
# TODO: Replace the stylize image dataset with licensed images.
|
||||||
|
self._test_data_dirname = 'testdata'
|
||||||
|
|
||||||
|
def test_from_folder(self):
|
||||||
|
input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
|
||||||
|
data = dataset.Dataset.from_folder(dirname=input_data_dir)
|
||||||
|
self.assertEqual(data.num_classes, 2)
|
||||||
|
self.assertEqual(data.label_names, ['cartoon', 'sketch'])
|
||||||
|
self.assertLen(data, 2)
|
||||||
|
|
||||||
|
def test_from_folder_raise_value_error_for_invalid_path(self):
|
||||||
|
with self.assertRaisesRegex(ValueError, 'Invalid input data directory'):
|
||||||
|
dataset.Dataset.from_folder(dirname='invalid')
|
||||||
|
|
||||||
|
def test_from_folder_raise_value_error_for_valid_no_data_path(self):
|
||||||
|
input_data_dir = test_utils.get_test_data_path('face_stylizer')
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, 'No images found under given directory'
|
||||||
|
):
|
||||||
|
dataset.Dataset.from_folder(dirname=input_data_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
BIN
mediapipe/model_maker/python/vision/face_stylizer/testdata/cartoon/disney.png
vendored
Normal file
BIN
mediapipe/model_maker/python/vision/face_stylizer/testdata/cartoon/disney.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 347 KiB |
BIN
mediapipe/model_maker/python/vision/face_stylizer/testdata/sketch/sketch.png
vendored
Normal file
BIN
mediapipe/model_maker/python/vision/face_stylizer/testdata/sketch/sketch.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 336 KiB |
Loading…
Reference in New Issue
Block a user