Internal change

PiperOrigin-RevId: 526235882
This commit is contained in:
MediaPipe Team 2023-04-22 00:26:56 -07:00 committed by Copybara-Service
parent 58dcbc9833
commit a6c1bb6324
3 changed files with 403 additions and 4 deletions

View File

@ -67,11 +67,18 @@ py_library(
name = "loss_functions",
srcs = ["loss_functions.py"],
srcs_version = "PY3",
deps = [
":file_util",
":model_util",
],
)
py_test(
name = "loss_functions_test",
srcs = ["loss_functions_test.py"],
tags = [
"requires-net:external",
],
deps = [":loss_functions"],
)

View File

@ -13,10 +13,21 @@
# limitations under the License.
"""Loss function utility library."""
from typing import Optional, Sequence
import abc
from typing import Mapping, Sequence
import dataclasses
from typing import Optional
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import file_util
from mediapipe.model_maker.python.core.utils import model_util
from official.modeling import tf_utils
_VGG_IMAGENET_PERCEPTUAL_MODEL_URL = 'https://storage.googleapis.com/mediapipe-assets/vgg_feature_extractor.tar.gz'
class FocalLoss(tf.keras.losses.Loss):
"""Implementation of focal loss (https://arxiv.org/pdf/1708.02002.pdf).
@ -45,7 +56,6 @@ class FocalLoss(tf.keras.losses.Loss):
```python
model.compile(optimizer='sgd', loss=FocalLoss(gamma))
```
"""
def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
@ -103,3 +113,206 @@ class FocalLoss(tf.keras.losses.Loss):
# By default, this function uses "sum_over_batch_size" reduction for the
# loss per batch.
return tf.reduce_sum(losses) / batch_size
@dataclasses.dataclass
class PerceptualLossWeight:
"""The weight for each perceptual loss.
Attributes:
l1: weight for L1 loss.
content: weight for content loss.
style: weight for style loss.
"""
l1: float = 1.0
content: float = 1.0
style: float = 1.0
class PerceptualLoss(tf.keras.Model, metaclass=abc.ABCMeta):
"""Base class for perceptual loss model."""
def __init__(
self,
feature_weight: Optional[Sequence[float]] = None,
loss_weight: Optional[PerceptualLossWeight] = None,
):
"""Instantiates perceptual loss.
Args:
feature_weight: The weight coeffcients of multiple model extracted
features used for calculating the perceptual loss.
loss_weight: The weight coefficients between `style_loss` and
`content_loss`.
"""
super().__init__()
self._loss_op = lambda x, y: tf.math.reduce_mean(tf.abs(x - y))
self._loss_style = tf.constant(0.0)
self._loss_content = tf.constant(0.0)
self._feature_weight = feature_weight
self._loss_weight = loss_weight
def __call__(
self,
img1: tf.Tensor,
img2: tf.Tensor,
) -> Mapping[str, tf.Tensor]:
"""Computes perceptual loss between two images.
Args:
img1: First batch of images. The pixel values should be normalized to [-1,
1].
img2: Second batch of images. The pixel values should be normalized to
[-1, 1].
Returns:
A mapping between loss name and loss tensors.
"""
x_features = self._compute_features(img1)
y_features = self._compute_features(img2)
if self._loss_weight is None:
self._loss_weight = PerceptualLossWeight()
# If the _feature_weight is not initialized, then initialize it as a list of
# all the element equals to 1.0.
if self._feature_weight is None:
self._feature_weight = [1.0] * len(x_features)
# If the length of _feature_weight smallert than the length of the feature,
# raise a ValueError. Otherwise, only use the first len(x_features) weight
# for computing the loss.
if len(self._feature_weight) < len(x_features):
raise ValueError(
f'Input feature weight length {len(self._feature_weight)} is smaller'
f' than feature length {len(x_features)}'
)
if self._loss_weight.style > 0.0:
self._loss_style = tf_utils.safe_mean(
self._loss_weight.style
* self._get_style_loss(x_feats=x_features, y_feats=y_features)
)
if self._loss_weight.content > 0.0:
self._loss_content = tf_utils.safe_mean(
self._loss_weight.content
* self._get_content_loss(x_feats=x_features, y_feats=y_features)
)
return {'style_loss': self._loss_style, 'content_loss': self._loss_content}
@abc.abstractmethod
def _compute_features(self, img: tf.Tensor) -> Sequence[tf.Tensor]:
"""Computes features from the given image tensor.
Args:
img: Image tensor.
Returns:
A list of multi-scale feature maps.
"""
def _get_content_loss(
self, x_feats: Sequence[tf.Tensor], y_feats: Sequence[tf.Tensor]
) -> tf.Tensor:
"""Gets weighted multi-scale content loss.
Args:
x_feats: Reconstructed face image.
y_feats: Target face image.
Returns:
A scalar tensor for the content loss.
"""
content_losses = []
for coef, x_feat, y_feat in zip(self._feature_weight, x_feats, y_feats):
content_loss = self._loss_op(x_feat, y_feat) * coef
content_losses.append(content_loss)
return tf.math.reduce_sum(content_losses)
def _get_style_loss(
self, x_feats: Sequence[tf.Tensor], y_feats: Sequence[tf.Tensor]
) -> tf.Tensor:
"""Gets weighted multi-scale style loss.
Args:
x_feats: Reconstructed face image.
y_feats: Target face image.
Returns:
A scalar tensor for the style loss.
"""
style_losses = []
i = 0
for coef, x_feat, y_feat in zip(self._feature_weight, x_feats, y_feats):
x_feat_g = _compute_gram_matrix(x_feat)
y_feat_g = _compute_gram_matrix(y_feat)
style_loss = self._loss_op(x_feat_g, y_feat_g) * coef
style_losses.append(style_loss)
i = i + 1
return tf.math.reduce_sum(style_loss)
class VGGPerceptualLoss(PerceptualLoss):
"""Perceptual loss based on VGG19 pretrained on the ImageNet dataset.
Reference:
- [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](
https://arxiv.org/abs/1603.08155) (ECCV 2016)
Perceptual loss measures high-level perceptual and semantic differences
between images.
"""
def __init__(
self,
loss_weight: Optional[PerceptualLossWeight] = None,
):
"""Initializes image quality loss essentials.
Args:
loss_weight: Loss weight coefficients.
"""
super().__init__(
feature_weight=np.array([0.1, 0.1, 1.0, 1.0, 1.0]),
loss_weight=loss_weight,
)
rgb_mean = tf.constant([0.485, 0.456, 0.406])
rgb_std = tf.constant([0.229, 0.224, 0.225])
self._rgb_mean = tf.reshape(rgb_mean, (1, 1, 1, 3))
self._rgb_std = tf.reshape(rgb_std, (1, 1, 1, 3))
model_path = file_util.DownloadedFiles(
'vgg_feature_extractor',
_VGG_IMAGENET_PERCEPTUAL_MODEL_URL,
is_folder=True,
)
self._vgg19 = model_util.load_keras_model(model_path.get_path())
def _compute_features(self, img: tf.Tensor) -> Sequence[tf.Tensor]:
"""Computes VGG19 features."""
img = (img + 1) / 2.0
norm_img = (img - self._rgb_mean) / self._rgb_std
# no grad, as it only serves as a frozen feature extractor.
return self._vgg19(norm_img)
def _compute_gram_matrix(feature: tf.Tensor) -> tf.Tensor:
"""Computes gram matrix for the feature map.
Args:
feature: [B, H, W, C] feature map.
Returns:
[B, C, C] gram matrix.
"""
h, w, c = feature.shape[1:].as_list()
feat_reshaped = tf.reshape(feature, shape=(-1, h * w, c))
feat_gram = tf.matmul(
tf.transpose(feat_reshaped, perm=[0, 2, 1]), feat_reshaped
)
return feat_gram / (c * h * w)

View File

@ -13,7 +13,9 @@
# limitations under the License.
import math
from typing import Optional
import tempfile
from typing import Dict, Optional, Sequence
from unittest import mock as unittest_mock
from absl.testing import parameterized
import tensorflow as tf
@ -21,7 +23,7 @@ import tensorflow as tf
from mediapipe.model_maker.python.core.utils import loss_functions
class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase):
class FocalLossTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
dict(testcase_name='no_sample_weight', sample_weight=None),
@ -99,5 +101,182 @@ class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase):
self.assertNear(loss, expected_loss, 1e-4)
class MockPerceptualLoss(loss_functions.PerceptualLoss):
"""A mock class with implementation of abstract methods for testing."""
def __init__(
self,
use_mock_loss_op: bool = False,
feature_weight: Optional[Sequence[float]] = None,
loss_weight: Optional[loss_functions.PerceptualLossWeight] = None,
):
super().__init__(feature_weight=feature_weight, loss_weight=loss_weight)
if use_mock_loss_op:
self._loss_op = lambda x, y: tf.math.reduce_mean(x - y)
def _compute_features(self, img: tf.Tensor) -> Sequence[tf.Tensor]:
return [tf.random.normal(shape=(1, 8, 8, 3))] * 5
class PerceptualLossTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self._img1 = tf.fill(dims=(8, 8), value=0.2)
self._img2 = tf.fill(dims=(8, 8), value=0.8)
def test_invalid_feature_weight_raise_value_error(self):
with self.assertRaisesRegex(
ValueError,
'Input feature weight length 2 is smaller than feature length 5',
):
MockPerceptualLoss(feature_weight=[1.0, 2.0])(
img1=self._img1, img2=self._img2
)
@parameterized.named_parameters(
dict(
testcase_name='default_loss_weight_and_loss_op',
use_mock_loss_op=False,
feature_weight=None,
loss_weight=None,
loss_values={
'style_loss': 0.032839,
'content_loss': 5.639870,
},
),
dict(
testcase_name='style_loss_weight_is_0_default_loss_op',
use_mock_loss_op=False,
feature_weight=None,
loss_weight=loss_functions.PerceptualLossWeight(style=0),
loss_values={
'style_loss': 0,
'content_loss': 5.639870,
},
),
dict(
testcase_name='content_loss_weight_is_0_default_loss_op',
use_mock_loss_op=False,
feature_weight=None,
loss_weight=loss_functions.PerceptualLossWeight(content=0),
loss_values={
'style_loss': 0.032839,
'content_loss': 0,
},
),
dict(
testcase_name='customized_loss_weight_default_loss_op',
use_mock_loss_op=False,
feature_weight=None,
loss_weight=loss_functions.PerceptualLossWeight(
style=1.0, content=2.0
),
loss_values={'style_loss': 0.032839, 'content_loss': 11.279739},
),
dict(
testcase_name=(
'customized_feature_weight_and_loss_weight_default_loss_op'
),
use_mock_loss_op=False,
feature_weight=[1.0, 2.0, 3.0, 4.0, 5.0],
loss_weight=loss_functions.PerceptualLossWeight(
style=1.0, content=2.0
),
loss_values={'style_loss': 0.164193, 'content_loss': 33.839218},
),
dict(
testcase_name='no_loss_change_if_extra_feature_weight_provided',
use_mock_loss_op=False,
feature_weight=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
loss_weight=loss_functions.PerceptualLossWeight(
style=1.0, content=2.0
),
loss_values={
'style_loss': 0.164193,
'content_loss': 33.839218,
},
),
dict(
testcase_name='customized_loss_weight_custom_loss_op',
use_mock_loss_op=True,
feature_weight=None,
loss_weight=loss_functions.PerceptualLossWeight(
style=1.0, content=2.0
),
loss_values={'style_loss': 0.000395, 'content_loss': -1.533469},
),
)
def test_weighted_perceptul_loss(
self,
use_mock_loss_op: bool,
feature_weight: Sequence[float],
loss_weight: loss_functions.PerceptualLossWeight,
loss_values: Dict[str, float],
):
perceptual_loss = MockPerceptualLoss(
use_mock_loss_op=use_mock_loss_op,
feature_weight=feature_weight,
loss_weight=loss_weight,
)
loss = perceptual_loss(img1=self._img1, img2=self._img2)
self.assertEqual(list(loss.keys()), ['style_loss', 'content_loss'])
self.assertNear(loss['style_loss'], loss_values['style_loss'], 1e-4)
self.assertNear(loss['content_loss'], loss_values['content_loss'], 1e-4)
class VGGPerceptualLossTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
# Mock tempfile.gettempdir() to be unique for each test to avoid race
# condition when downloading model since these tests may run in parallel.
mock_gettempdir = unittest_mock.patch.object(
tempfile,
'gettempdir',
return_value=self.create_tempdir(),
autospec=True,
)
self.mock_gettempdir = mock_gettempdir.start()
self.addCleanup(mock_gettempdir.stop)
self._img1 = tf.fill(dims=(1, 256, 256, 3), value=0.1)
self._img2 = tf.fill(dims=(1, 256, 256, 3), value=0.9)
@parameterized.named_parameters(
dict(
testcase_name='default_loss_weight',
loss_weight=None,
loss_values={
'style_loss': 5.8363257e-06,
'content_loss': 1.7016045,
},
),
dict(
testcase_name='customized_loss_weight',
loss_weight=loss_functions.PerceptualLossWeight(
style=10.0, content=20.0
),
loss_values={
'style_loss': 5.8363257e-05,
'content_loss': 34.03208,
},
),
)
def test_vgg_perceptual_loss(self, loss_weight, loss_values):
vgg_loss = loss_functions.VGGPerceptualLoss(loss_weight=loss_weight)
loss = vgg_loss(img1=self._img1, img2=self._img2)
self.assertEqual(list(loss.keys()), ['style_loss', 'content_loss'])
self.assertNear(
loss['style_loss'],
loss_values['style_loss'],
loss_values['style_loss'] / 1e5,
)
self.assertNear(
loss['content_loss'],
loss_values['content_loss'],
loss_values['content_loss'] / 1e5,
)
if __name__ == '__main__':
tf.test.main()