diff --git a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD
index 43c3d42f9..907706b3a 100644
--- a/mediapipe/model_maker/python/core/utils/BUILD
+++ b/mediapipe/model_maker/python/core/utils/BUILD
@@ -67,11 +67,18 @@ py_library(
     name = "loss_functions",
     srcs = ["loss_functions.py"],
     srcs_version = "PY3",
+    deps = [
+        ":file_util",
+        ":model_util",
+    ],
 )
 
 py_test(
     name = "loss_functions_test",
     srcs = ["loss_functions_test.py"],
+    tags = [
+        "requires-net:external",
+    ],
     deps = [":loss_functions"],
 )
 
diff --git a/mediapipe/model_maker/python/core/utils/loss_functions.py b/mediapipe/model_maker/python/core/utils/loss_functions.py
index 5b0aa32bf..e05cf6f59 100644
--- a/mediapipe/model_maker/python/core/utils/loss_functions.py
+++ b/mediapipe/model_maker/python/core/utils/loss_functions.py
@@ -13,10 +13,20 @@
 # limitations under the License.
 """Loss function utility library."""
 
-from typing import Optional, Sequence
+import abc
+import dataclasses
+from typing import Mapping, Optional, Sequence
 
+import numpy as np
 import tensorflow as tf
 
+from mediapipe.model_maker.python.core.utils import file_util
+from mediapipe.model_maker.python.core.utils import model_util
+from official.modeling import tf_utils
+
+
+_VGG_IMAGENET_PERCEPTUAL_MODEL_URL = 'https://storage.googleapis.com/mediapipe-assets/vgg_feature_extractor.tar.gz'
+
 
 class FocalLoss(tf.keras.losses.Loss):
   """Implementation of focal loss (https://arxiv.org/pdf/1708.02002.pdf).
@@ -45,7 +55,6 @@ class FocalLoss(tf.keras.losses.Loss):
   ```python
   model.compile(optimizer='sgd', loss=FocalLoss(gamma))
   ```
-
   """
 
   def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
@@ -103,3 +112,206 @@
     # By default, this function uses "sum_over_batch_size" reduction for the
     # loss per batch.
     return tf.reduce_sum(losses) / batch_size
+
+
+@dataclasses.dataclass
+class PerceptualLossWeight:
+  """The weight for each perceptual loss.
+
+  Attributes:
+    l1: weight for L1 loss.
+    content: weight for content loss.
+    style: weight for style loss.
+  """
+
+  l1: float = 1.0
+  content: float = 1.0
+  style: float = 1.0
+
+
+class PerceptualLoss(tf.keras.Model, metaclass=abc.ABCMeta):
+  """Base class for perceptual loss model."""
+
+  def __init__(
+      self,
+      feature_weight: Optional[Sequence[float]] = None,
+      loss_weight: Optional[PerceptualLossWeight] = None,
+  ):
+    """Instantiates perceptual loss.
+
+    Args:
+      feature_weight: The weight coefficients of the multiple model-extracted
+        features used for calculating the perceptual loss.
+      loss_weight: The weight coefficients between `style_loss` and
+        `content_loss`.
+    """
+    super().__init__()
+    self._loss_op = lambda x, y: tf.math.reduce_mean(tf.abs(x - y))
+    self._loss_style = tf.constant(0.0)
+    self._loss_content = tf.constant(0.0)
+    self._feature_weight = feature_weight
+    self._loss_weight = loss_weight
+
+  def __call__(
+      self,
+      img1: tf.Tensor,
+      img2: tf.Tensor,
+  ) -> Mapping[str, tf.Tensor]:
+    """Computes perceptual loss between two images.
+
+    Args:
+      img1: First batch of images. The pixel values should be normalized to
+        [-1, 1].
+      img2: Second batch of images. The pixel values should be normalized to
+        [-1, 1].
+
+    Returns:
+      A mapping between loss name and loss tensors.
+ """ + x_features = self._compute_features(img1) + y_features = self._compute_features(img2) + + if self._loss_weight is None: + self._loss_weight = PerceptualLossWeight() + + # If the _feature_weight is not initialized, then initialize it as a list of + # all the element equals to 1.0. + if self._feature_weight is None: + self._feature_weight = [1.0] * len(x_features) + + # If the length of _feature_weight smallert than the length of the feature, + # raise a ValueError. Otherwise, only use the first len(x_features) weight + # for computing the loss. + if len(self._feature_weight) < len(x_features): + raise ValueError( + f'Input feature weight length {len(self._feature_weight)} is smaller' + f' than feature length {len(x_features)}' + ) + + if self._loss_weight.style > 0.0: + self._loss_style = tf_utils.safe_mean( + self._loss_weight.style + * self._get_style_loss(x_feats=x_features, y_feats=y_features) + ) + if self._loss_weight.content > 0.0: + self._loss_content = tf_utils.safe_mean( + self._loss_weight.content + * self._get_content_loss(x_feats=x_features, y_feats=y_features) + ) + + return {'style_loss': self._loss_style, 'content_loss': self._loss_content} + + @abc.abstractmethod + def _compute_features(self, img: tf.Tensor) -> Sequence[tf.Tensor]: + """Computes features from the given image tensor. + + Args: + img: Image tensor. + + Returns: + A list of multi-scale feature maps. + """ + + def _get_content_loss( + self, x_feats: Sequence[tf.Tensor], y_feats: Sequence[tf.Tensor] + ) -> tf.Tensor: + """Gets weighted multi-scale content loss. + + Args: + x_feats: Reconstructed face image. + y_feats: Target face image. + + Returns: + A scalar tensor for the content loss. + """ + content_losses = [] + for coef, x_feat, y_feat in zip(self._feature_weight, x_feats, y_feats): + content_loss = self._loss_op(x_feat, y_feat) * coef + content_losses.append(content_loss) + return tf.math.reduce_sum(content_losses) + + def _get_style_loss( + self, x_feats: Sequence[tf.Tensor], y_feats: Sequence[tf.Tensor] + ) -> tf.Tensor: + """Gets weighted multi-scale style loss. + + Args: + x_feats: Reconstructed face image. + y_feats: Target face image. + + Returns: + A scalar tensor for the style loss. + """ + style_losses = [] + i = 0 + for coef, x_feat, y_feat in zip(self._feature_weight, x_feats, y_feats): + x_feat_g = _compute_gram_matrix(x_feat) + y_feat_g = _compute_gram_matrix(y_feat) + style_loss = self._loss_op(x_feat_g, y_feat_g) * coef + style_losses.append(style_loss) + i = i + 1 + + return tf.math.reduce_sum(style_loss) + + +class VGGPerceptualLoss(PerceptualLoss): + """Perceptual loss based on VGG19 pretrained on the ImageNet dataset. + + Reference: + - [Perceptual Losses for Real-Time Style Transfer and Super-Resolution]( + https://arxiv.org/abs/1603.08155) (ECCV 2016) + + Perceptual loss measures high-level perceptual and semantic differences + between images. + """ + + def __init__( + self, + loss_weight: Optional[PerceptualLossWeight] = None, + ): + """Initializes image quality loss essentials. + + Args: + loss_weight: Loss weight coefficients. 
+ """ + super().__init__( + feature_weight=np.array([0.1, 0.1, 1.0, 1.0, 1.0]), + loss_weight=loss_weight, + ) + + rgb_mean = tf.constant([0.485, 0.456, 0.406]) + rgb_std = tf.constant([0.229, 0.224, 0.225]) + + self._rgb_mean = tf.reshape(rgb_mean, (1, 1, 1, 3)) + self._rgb_std = tf.reshape(rgb_std, (1, 1, 1, 3)) + + model_path = file_util.DownloadedFiles( + 'vgg_feature_extractor', + _VGG_IMAGENET_PERCEPTUAL_MODEL_URL, + is_folder=True, + ) + self._vgg19 = model_util.load_keras_model(model_path.get_path()) + + def _compute_features(self, img: tf.Tensor) -> Sequence[tf.Tensor]: + """Computes VGG19 features.""" + img = (img + 1) / 2.0 + norm_img = (img - self._rgb_mean) / self._rgb_std + # no grad, as it only serves as a frozen feature extractor. + return self._vgg19(norm_img) + + +def _compute_gram_matrix(feature: tf.Tensor) -> tf.Tensor: + """Computes gram matrix for the feature map. + + Args: + feature: [B, H, W, C] feature map. + + Returns: + [B, C, C] gram matrix. + """ + h, w, c = feature.shape[1:].as_list() + feat_reshaped = tf.reshape(feature, shape=(-1, h * w, c)) + feat_gram = tf.matmul( + tf.transpose(feat_reshaped, perm=[0, 2, 1]), feat_reshaped + ) + return feat_gram / (c * h * w) diff --git a/mediapipe/model_maker/python/core/utils/loss_functions_test.py b/mediapipe/model_maker/python/core/utils/loss_functions_test.py index 716c329ef..a3d9a8aa7 100644 --- a/mediapipe/model_maker/python/core/utils/loss_functions_test.py +++ b/mediapipe/model_maker/python/core/utils/loss_functions_test.py @@ -13,7 +13,9 @@ # limitations under the License. import math -from typing import Optional +import tempfile +from typing import Dict, Optional, Sequence +from unittest import mock as unittest_mock from absl.testing import parameterized import tensorflow as tf @@ -21,7 +23,7 @@ import tensorflow as tf from mediapipe.model_maker.python.core.utils import loss_functions -class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase): +class FocalLossTest(tf.test.TestCase, parameterized.TestCase): @parameterized.named_parameters( dict(testcase_name='no_sample_weight', sample_weight=None), @@ -99,5 +101,182 @@ class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase): self.assertNear(loss, expected_loss, 1e-4) +class MockPerceptualLoss(loss_functions.PerceptualLoss): + """A mock class with implementation of abstract methods for testing.""" + + def __init__( + self, + use_mock_loss_op: bool = False, + feature_weight: Optional[Sequence[float]] = None, + loss_weight: Optional[loss_functions.PerceptualLossWeight] = None, + ): + super().__init__(feature_weight=feature_weight, loss_weight=loss_weight) + if use_mock_loss_op: + self._loss_op = lambda x, y: tf.math.reduce_mean(x - y) + + def _compute_features(self, img: tf.Tensor) -> Sequence[tf.Tensor]: + return [tf.random.normal(shape=(1, 8, 8, 3))] * 5 + + +class PerceptualLossTest(tf.test.TestCase, parameterized.TestCase): + + def setUp(self): + super().setUp() + self._img1 = tf.fill(dims=(8, 8), value=0.2) + self._img2 = tf.fill(dims=(8, 8), value=0.8) + + def test_invalid_feature_weight_raise_value_error(self): + with self.assertRaisesRegex( + ValueError, + 'Input feature weight length 2 is smaller than feature length 5', + ): + MockPerceptualLoss(feature_weight=[1.0, 2.0])( + img1=self._img1, img2=self._img2 + ) + + @parameterized.named_parameters( + dict( + testcase_name='default_loss_weight_and_loss_op', + use_mock_loss_op=False, + feature_weight=None, + loss_weight=None, + loss_values={ + 'style_loss': 0.032839, + 'content_loss': 
+          },
+      ),
+      dict(
+          testcase_name='style_loss_weight_is_0_default_loss_op',
+          use_mock_loss_op=False,
+          feature_weight=None,
+          loss_weight=loss_functions.PerceptualLossWeight(style=0),
+          loss_values={
+              'style_loss': 0,
+              'content_loss': 5.639870,
+          },
+      ),
+      dict(
+          testcase_name='content_loss_weight_is_0_default_loss_op',
+          use_mock_loss_op=False,
+          feature_weight=None,
+          loss_weight=loss_functions.PerceptualLossWeight(content=0),
+          loss_values={
+              'style_loss': 0.032839,
+              'content_loss': 0,
+          },
+      ),
+      dict(
+          testcase_name='customized_loss_weight_default_loss_op',
+          use_mock_loss_op=False,
+          feature_weight=None,
+          loss_weight=loss_functions.PerceptualLossWeight(
+              style=1.0, content=2.0
+          ),
+          loss_values={'style_loss': 0.032839, 'content_loss': 11.279739},
+      ),
+      dict(
+          testcase_name=(
+              'customized_feature_weight_and_loss_weight_default_loss_op'
+          ),
+          use_mock_loss_op=False,
+          feature_weight=[1.0, 2.0, 3.0, 4.0, 5.0],
+          loss_weight=loss_functions.PerceptualLossWeight(
+              style=1.0, content=2.0
+          ),
+          loss_values={'style_loss': 0.164193, 'content_loss': 33.839218},
+      ),
+      dict(
+          testcase_name='no_loss_change_if_extra_feature_weight_provided',
+          use_mock_loss_op=False,
+          feature_weight=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+          loss_weight=loss_functions.PerceptualLossWeight(
+              style=1.0, content=2.0
+          ),
+          loss_values={
+              'style_loss': 0.164193,
+              'content_loss': 33.839218,
+          },
+      ),
+      dict(
+          testcase_name='customized_loss_weight_custom_loss_op',
+          use_mock_loss_op=True,
+          feature_weight=None,
+          loss_weight=loss_functions.PerceptualLossWeight(
+              style=1.0, content=2.0
+          ),
+          loss_values={'style_loss': 0.000395, 'content_loss': -1.533469},
+      ),
+  )
+  def test_weighted_perceptual_loss(
+      self,
+      use_mock_loss_op: bool,
+      feature_weight: Sequence[float],
+      loss_weight: loss_functions.PerceptualLossWeight,
+      loss_values: Dict[str, float],
+  ):
+    perceptual_loss = MockPerceptualLoss(
+        use_mock_loss_op=use_mock_loss_op,
+        feature_weight=feature_weight,
+        loss_weight=loss_weight,
+    )
+    loss = perceptual_loss(img1=self._img1, img2=self._img2)
+    self.assertEqual(list(loss.keys()), ['style_loss', 'content_loss'])
+    self.assertNear(loss['style_loss'], loss_values['style_loss'], 1e-4)
+    self.assertNear(loss['content_loss'], loss_values['content_loss'], 1e-4)
+
+
+class VGGPerceptualLossTest(tf.test.TestCase, parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    # Mock tempfile.gettempdir() to be unique per test to avoid a race
+    # condition when downloading the model, since tests may run in parallel.
+    mock_gettempdir = unittest_mock.patch.object(
+        tempfile,
+        'gettempdir',
+        return_value=self.create_tempdir(),
+        autospec=True,
+    )
+    self.mock_gettempdir = mock_gettempdir.start()
+    self.addCleanup(mock_gettempdir.stop)
+    self._img1 = tf.fill(dims=(1, 256, 256, 3), value=0.1)
+    self._img2 = tf.fill(dims=(1, 256, 256, 3), value=0.9)
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='default_loss_weight',
+          loss_weight=None,
+          loss_values={
+              'style_loss': 5.8363257e-06,
+              'content_loss': 1.7016045,
+          },
+      ),
+      dict(
+          testcase_name='customized_loss_weight',
+          loss_weight=loss_functions.PerceptualLossWeight(
+              style=10.0, content=20.0
+          ),
+          loss_values={
+              'style_loss': 5.8363257e-05,
+              'content_loss': 34.03208,
+          },
+      ),
+  )
+  def test_vgg_perceptual_loss(self, loss_weight, loss_values):
+    vgg_loss = loss_functions.VGGPerceptualLoss(loss_weight=loss_weight)
+    loss = vgg_loss(img1=self._img1, img2=self._img2)
+    self.assertEqual(list(loss.keys()), ['style_loss', 'content_loss'])
+    self.assertNear(
+        loss['style_loss'],
+        loss_values['style_loss'],
+        loss_values['style_loss'] / 1e5,
+    )
+    self.assertNear(
+        loss['content_loss'],
+        loss_values['content_loss'],
+        loss_values['content_loss'] / 1e5,
+    )
+
+
 if __name__ == '__main__':
   tf.test.main()
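
For reviewers, a minimal usage sketch of the API this patch adds (not part of the patch itself; the image tensors and weight values below are illustrative, everything else comes from the diff):

```python
import tensorflow as tf

from mediapipe.model_maker.python.core.utils import loss_functions

# Two batches of RGB images; pixel values are expected in [-1, 1].
img1 = tf.random.uniform(shape=(1, 256, 256, 3), minval=-1.0, maxval=1.0)
img2 = tf.random.uniform(shape=(1, 256, 256, 3), minval=-1.0, maxval=1.0)

# Emphasize content over style; the l1 field keeps its default of 1.0.
weight = loss_functions.PerceptualLossWeight(style=1.0, content=2.0)

# Downloads the frozen VGG19 feature extractor on first use, which requires
# network access (hence the new `requires-net:external` tag on the test).
vgg_loss = loss_functions.VGGPerceptualLoss(loss_weight=weight)
losses = vgg_loss(img1=img1, img2=img2)
total_loss = losses['style_loss'] + losses['content_loss']
```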