Add the dataset module for face stylizer in model maker.
PiperOrigin-RevId: 516628350
This commit is contained in:
parent
ade31b567b
commit
6774794d02
48
mediapipe/model_maker/python/vision/face_stylizer/BUILD
Normal file
48
mediapipe/model_maker/python/vision/face_stylizer/BUILD
Normal file
|
@ -0,0 +1,48 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python strict test compatibility macro.
|
||||
# Placeholder for internal Python strict library and test compatibility macro.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//mediapipe:__subpackages__"])
|
||||
|
||||
filegroup(
|
||||
name = "testdata",
|
||||
srcs = glob([
|
||||
"testdata/**",
|
||||
]),
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "dataset",
|
||||
srcs = ["dataset.py"],
|
||||
deps = [
|
||||
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||
"//mediapipe/model_maker/python/vision/core:image_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "dataset_test",
|
||||
srcs = ["dataset_test.py"],
|
||||
data = [
|
||||
":testdata",
|
||||
],
|
||||
deps = [
|
||||
":dataset",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""MediaPipe Model Maker Python Public API For Face Stylization."""
|
98
mediapipe/model_maker/python/vision/face_stylizer/dataset.py
Normal file
98
mediapipe/model_maker/python/vision/face_stylizer/dataset.py
Normal file
|
@ -0,0 +1,98 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Face stylizer dataset library."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||
from mediapipe.model_maker.python.vision.core import image_utils
|
||||
|
||||
|
||||
# TODO: Change to a unlabeled dataset if it makes sense.
|
||||
class Dataset(classification_dataset.ClassificationDataset):
|
||||
"""Dataset library for face stylizer fine tuning."""
|
||||
|
||||
@classmethod
|
||||
def from_folder(
|
||||
cls, dirname: str
|
||||
) -> classification_dataset.ClassificationDataset:
|
||||
"""Loads images from the given directory.
|
||||
|
||||
The style image dataset directory is expected to contain one subdirectory
|
||||
whose name represents the label of the style. There can be one or multiple
|
||||
images of the same style in that subdirectory. Supported input image formats
|
||||
include 'jpg', 'jpeg', 'png'.
|
||||
|
||||
Args:
|
||||
dirname: Name of the directory containing the image files.
|
||||
|
||||
Returns:
|
||||
Dataset containing images and labels and other related info.
|
||||
Raises:
|
||||
ValueError: if the input data directory is empty.
|
||||
"""
|
||||
data_root = os.path.abspath(dirname)
|
||||
|
||||
# Assumes the image data of the same label are in the same subdirectory,
|
||||
# gets image path and label names.
|
||||
all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
|
||||
all_image_size = len(all_image_paths)
|
||||
if all_image_size == 0:
|
||||
raise ValueError('Invalid input data directory')
|
||||
if not any(
|
||||
fname.endswith(('.jpg', '.jpeg', '.png')) for fname in all_image_paths
|
||||
):
|
||||
raise ValueError('No images found under given directory')
|
||||
|
||||
label_names = sorted(
|
||||
name
|
||||
for name in os.listdir(data_root)
|
||||
if os.path.isdir(os.path.join(data_root, name))
|
||||
)
|
||||
all_label_size = len(label_names)
|
||||
index_by_label = dict(
|
||||
(name, index) for index, name in enumerate(label_names)
|
||||
)
|
||||
# Get the style label from the subdirectory name.
|
||||
all_image_labels = [
|
||||
index_by_label[os.path.basename(os.path.dirname(path))]
|
||||
for path in all_image_paths
|
||||
]
|
||||
|
||||
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
|
||||
|
||||
image_ds = path_ds.map(
|
||||
image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
|
||||
)
|
||||
|
||||
# Load label
|
||||
label_ds = tf.data.Dataset.from_tensor_slices(
|
||||
tf.cast(all_image_labels, tf.int64)
|
||||
)
|
||||
|
||||
# Create a dataset of (image, label) pairs
|
||||
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
|
||||
|
||||
logging.info(
|
||||
'Load images dataset with size: %d, num_label: %d, labels: %s.',
|
||||
all_image_size,
|
||||
all_label_size,
|
||||
', '.join(label_names),
|
||||
)
|
||||
return Dataset(
|
||||
dataset=image_label_ds, size=all_image_size, label_names=label_names
|
||||
)
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.vision.face_stylizer import dataset
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
|
||||
class DatasetTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# TODO: Replace the stylize image dataset with licensed images.
|
||||
self._test_data_dirname = 'testdata'
|
||||
|
||||
def test_from_folder(self):
|
||||
input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
|
||||
data = dataset.Dataset.from_folder(dirname=input_data_dir)
|
||||
self.assertEqual(data.num_classes, 2)
|
||||
self.assertEqual(data.label_names, ['cartoon', 'sketch'])
|
||||
self.assertLen(data, 2)
|
||||
|
||||
def test_from_folder_raise_value_error_for_invalid_path(self):
|
||||
with self.assertRaisesRegex(ValueError, 'Invalid input data directory'):
|
||||
dataset.Dataset.from_folder(dirname='invalid')
|
||||
|
||||
def test_from_folder_raise_value_error_for_valid_no_data_path(self):
|
||||
input_data_dir = test_utils.get_test_data_path('face_stylizer')
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'No images found under given directory'
|
||||
):
|
||||
dataset.Dataset.from_folder(dirname=input_data_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
BIN
mediapipe/model_maker/python/vision/face_stylizer/testdata/cartoon/disney.png
vendored
Normal file
BIN
mediapipe/model_maker/python/vision/face_stylizer/testdata/cartoon/disney.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 347 KiB |
BIN
mediapipe/model_maker/python/vision/face_stylizer/testdata/sketch/sketch.png
vendored
Normal file
BIN
mediapipe/model_maker/python/vision/face_stylizer/testdata/sketch/sketch.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 336 KiB |
Loading…
Reference in New Issue
Block a user