Add the model configuration and training hyperparameters for BlazeFaceStylizer.
PiperOrigin-RevId: 520767282
This commit is contained in:
		
							parent
							
								
									a4923ca7aa
								
							
						
					
					
						commit
						1eeb89e95f
					
				| 
						 | 
				
			
			@ -26,6 +26,40 @@ filegroup(
 | 
			
		|||
    ]),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# Face stylizer training hyperparameters; depends on the core model_maker
# hyperparameters library.
py_library(
 | 
			
		||||
    name = "hyperparameters",
 | 
			
		||||
    srcs = ["hyperparameters.py"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/model_maker/python/core:hyperparameters",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# Configurable model options for face stylizer models; declares no deps.
py_library(
 | 
			
		||||
    name = "model_options",
 | 
			
		||||
    srcs = ["model_options.py"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# Face stylizer model specification; declares no deps.
py_library(
 | 
			
		||||
    name = "model_spec",
 | 
			
		||||
    srcs = ["model_spec.py"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# Unit test for the :model_spec library.
py_test(
 | 
			
		||||
    name = "model_spec_test",
 | 
			
		||||
    srcs = ["model_spec_test.py"],
 | 
			
		||||
    deps = [":model_spec"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# Top-level options for building a face stylizer; depends on the
# hyperparameters, model options and model spec libraries in this package.
py_library(
 | 
			
		||||
    name = "face_stylizer_options",
 | 
			
		||||
    srcs = ["face_stylizer_options.py"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":hyperparameters",
 | 
			
		||||
        ":model_options",
 | 
			
		||||
        ":model_spec",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "dataset",
 | 
			
		||||
    srcs = ["dataset.py"],
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,36 @@
 | 
			
		|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Options for building face stylizer."""
 | 
			
		||||
 | 
			
		||||
import dataclasses
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.vision.face_stylizer import hyperparameters
 | 
			
		||||
from mediapipe.model_maker.python.vision.face_stylizer import model_options as model_opt
 | 
			
		||||
from mediapipe.model_maker.python.vision.face_stylizer import model_spec
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
class FaceStylizerOptions:
  """Configurable options for building face stylizer.

  Attributes are listed in declaration order, which is also the order of the
  positional constructor arguments generated by `dataclasses.dataclass`.

  Attributes:
    model: A model from the SupportedModels enum. Required.
    model_options: A set of options for configuring the model. Defaults to
      None.
    hparams: A set of hyperparameters used to train the face stylizer.
      Defaults to None.
  """

  model: model_spec.SupportedModels
  model_options: Optional[model_opt.FaceStylizerModelOptions] = None
  hparams: Optional[hyperparameters.HParams] = None
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,39 @@
 | 
			
		|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Hyperparameters for training on-device face stylization models."""
 | 
			
		||||
 | 
			
		||||
import dataclasses
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.core import hyperparameters as hp
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
class HParams(hp.BaseHParams):
  """The hyperparameters for training face stylizers.

  Attributes:
    learning_rate: Learning rate to use for gradient descent training.
    batch_size: Batch size for training.
    epochs: Number of training iterations.
    beta_1: beta_1 used in tf.keras.optimizers.Adam.
    beta_2: beta_2 used in tf.keras.optimizers.Adam.
  """

  # Parameters from BaseHParams class.
  learning_rate: float = 5e-5
  batch_size: int = 4
  epochs: int = 100
  # Parameters for face stylizer.
  # The annotations are required: without them `dataclasses.dataclass` treats
  # these as plain class attributes, so they could not be set through the
  # constructor and would be missing from fields(), repr() and ==.
  beta_1: float = 0.0
  beta_2: float = 0.99
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,37 @@
 | 
			
		|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Configurable model options for face stylizer models."""
 | 
			
		||||
 | 
			
		||||
import dataclasses
 | 
			
		||||
from typing import List
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO: Add more detailed instructions about hyperparameter tuning.
 | 
			
		||||
@dataclasses.dataclass
class FaceStylizerModelOptions:
  """Model-level knobs for face stylizer models.

  Attributes:
    swap_layers: Indices of the feature layers to be interpolated between the
      encoding features and the StyleGAN input features. Defaults to layers 4
      through 11 inclusive; every instance gets its own list.
    alpha: Weighting coefficient for swapping layer interpolation.
    adv_loss_weight: Weighting coefficient of adversarial loss versus
      perceptual loss.
  """

  swap_layers: List[int] = dataclasses.field(
      default_factory=lambda: list(range(4, 12))
  )
  alpha: float = 1.0
  adv_loss_weight: float = 1.0
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,63 @@
 | 
			
		|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Face stylizer model specification."""
 | 
			
		||||
 | 
			
		||||
import enum
 | 
			
		||||
import functools
 | 
			
		||||
from typing import List
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelSpec(object):
  """Specification of face stylizer model."""

  # Class-level constants shared by every spec instance.
  # NOTE(review): presumably the per-channel mean/stddev used to normalize
  # input pixels -- confirm against the preprocessing code that reads them.
  mean_rgb = [127.5]
  stddev_rgb = [127.5]

  def __init__(
      self, style_block_num: int, input_image_shape: List[int], name: str = ''
  ):
    """Initializes a new instance of the `ModelSpec` class for face stylizer.

    Args:
      style_block_num: int, number of style block in the decoder.
      input_image_shape: list of int, input image shape. The values are copied,
        so later changes to the caller's list do not affect this spec.
      name: str, model spec name.
    """
    self.style_block_num = style_block_num
    # Defensive copy: a factory built with functools.partial binds a single
    # list object and passes it to every instance it creates, so storing the
    # argument by reference would make all of those instances share (and be
    # able to corrupt) one list.
    self.input_image_shape = list(input_image_shape)
    self.name = name


# Zero-argument factory for the predefined spec: 12 style blocks, 256x256
# input. Calling it returns a new `ModelSpec` instance.
blaze_face_stylizer_256_spec = functools.partial(
    ModelSpec,
    style_block_num=12,
    input_image_shape=[256, 256],
    name='blaze_face_stylizer_256',
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO: Document the exposed models.
 | 
			
		||||
@enum.unique
class SupportedModels(enum.Enum):
  """Face stylizer model supported by MediaPipe model maker."""

  # The member value is a callable; `get` invokes it with no arguments to
  # build the model spec.
  # NOTE(review): newer Python releases change how `functools.partial` objects
  # are treated inside an Enum body (they may stop becoming members unless
  # wrapped in `enum.member()`) -- confirm against the supported Python
  # versions.
  BLAZE_FACE_STYLIZER_256 = blaze_face_stylizer_256_spec

  @classmethod
  def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
    """Gets model spec from the input enum and initializes it.

    Args:
      spec: A member of `SupportedModels`.

    Returns:
      The object produced by calling the member's value.

    Raises:
      TypeError: If `spec` is not a member of this enum.
    """
    # isinstance() rather than `spec not in cls`: enum containment raises its
    # own TypeError for non-enum operands on some Python versions and accepts
    # raw member values on others, which would then fail at `spec.value`.
    if not isinstance(spec, cls):
      raise TypeError('Unsupported face stylizer spec: {}'.format(spec))

    return spec.value()
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,44 @@
 | 
			
		|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the 'License');
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an 'AS IS' BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.vision.face_stylizer import model_spec as ms
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelSpecTest(tf.test.TestCase):
  """Checks the predefined face stylizer spec and its enum entry."""

  def test_predefine_spec(self):
    """Calling the predefined factory yields the expected ModelSpec."""
    spec = ms.blaze_face_stylizer_256_spec()
    self.assertIsInstance(spec, ms.ModelSpec)
    self.assertEqual(spec.style_block_num, 12)
    self.assertAllEqual(spec.input_image_shape, [256, 256])
    self.assertEqual(spec.name, 'blaze_face_stylizer_256')

  def test_predefine_spec_enum(self):
    """Resolving the enum member through `get` yields the same spec values."""
    spec = ms.SupportedModels.get(ms.SupportedModels.BLAZE_FACE_STYLIZER_256)
    self.assertIsInstance(spec, ms.ModelSpec)
    self.assertEqual(spec.style_block_num, 12)
    self.assertAllEqual(spec.input_image_shape, [256, 256])
    self.assertEqual(spec.name, 'blaze_face_stylizer_256')


if __name__ == '__main__':
  tf.test.main()
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user