Internal change
PiperOrigin-RevId: 526235882
This commit is contained in:
parent
58dcbc9833
commit
a6c1bb6324
|
@ -67,11 +67,18 @@ py_library(
|
||||||
name = "loss_functions",
|
name = "loss_functions",
|
||||||
srcs = ["loss_functions.py"],
|
srcs = ["loss_functions.py"],
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":file_util",
|
||||||
|
":model_util",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "loss_functions_test",
|
name = "loss_functions_test",
|
||||||
srcs = ["loss_functions_test.py"],
|
srcs = ["loss_functions_test.py"],
|
||||||
|
tags = [
|
||||||
|
"requires-net:external",
|
||||||
|
],
|
||||||
deps = [":loss_functions"],
|
deps = [":loss_functions"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -13,10 +13,21 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Loss function utility library."""
|
"""Loss function utility library."""
|
||||||
|
|
||||||
from typing import Optional, Sequence
|
import abc
|
||||||
|
from typing import Mapping, Sequence
|
||||||
|
import dataclasses
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.utils import file_util
|
||||||
|
from mediapipe.model_maker.python.core.utils import model_util
|
||||||
|
from official.modeling import tf_utils
|
||||||
|
|
||||||
|
|
||||||
|
_VGG_IMAGENET_PERCEPTUAL_MODEL_URL = 'https://storage.googleapis.com/mediapipe-assets/vgg_feature_extractor.tar.gz'
|
||||||
|
|
||||||
|
|
||||||
class FocalLoss(tf.keras.losses.Loss):
|
class FocalLoss(tf.keras.losses.Loss):
|
||||||
"""Implementation of focal loss (https://arxiv.org/pdf/1708.02002.pdf).
|
"""Implementation of focal loss (https://arxiv.org/pdf/1708.02002.pdf).
|
||||||
|
@ -45,7 +56,6 @@ class FocalLoss(tf.keras.losses.Loss):
|
||||||
```python
|
```python
|
||||||
model.compile(optimizer='sgd', loss=FocalLoss(gamma))
|
model.compile(optimizer='sgd', loss=FocalLoss(gamma))
|
||||||
```
|
```
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
|
def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
|
||||||
|
@ -103,3 +113,206 @@ class FocalLoss(tf.keras.losses.Loss):
|
||||||
# By default, this function uses "sum_over_batch_size" reduction for the
|
# By default, this function uses "sum_over_batch_size" reduction for the
|
||||||
# loss per batch.
|
# loss per batch.
|
||||||
return tf.reduce_sum(losses) / batch_size
|
return tf.reduce_sum(losses) / batch_size
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class PerceptualLossWeight:
|
||||||
|
"""The weight for each perceptual loss.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
l1: weight for L1 loss.
|
||||||
|
content: weight for content loss.
|
||||||
|
style: weight for style loss.
|
||||||
|
"""
|
||||||
|
|
||||||
|
l1: float = 1.0
|
||||||
|
content: float = 1.0
|
||||||
|
style: float = 1.0
|
||||||
|
|
||||||
|
|
||||||
|
class PerceptualLoss(tf.keras.Model, metaclass=abc.ABCMeta):
|
||||||
|
"""Base class for perceptual loss model."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
feature_weight: Optional[Sequence[float]] = None,
|
||||||
|
loss_weight: Optional[PerceptualLossWeight] = None,
|
||||||
|
):
|
||||||
|
"""Instantiates perceptual loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature_weight: The weight coeffcients of multiple model extracted
|
||||||
|
features used for calculating the perceptual loss.
|
||||||
|
loss_weight: The weight coefficients between `style_loss` and
|
||||||
|
`content_loss`.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self._loss_op = lambda x, y: tf.math.reduce_mean(tf.abs(x - y))
|
||||||
|
self._loss_style = tf.constant(0.0)
|
||||||
|
self._loss_content = tf.constant(0.0)
|
||||||
|
self._feature_weight = feature_weight
|
||||||
|
self._loss_weight = loss_weight
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
img1: tf.Tensor,
|
||||||
|
img2: tf.Tensor,
|
||||||
|
) -> Mapping[str, tf.Tensor]:
|
||||||
|
"""Computes perceptual loss between two images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img1: First batch of images. The pixel values should be normalized to [-1,
|
||||||
|
1].
|
||||||
|
img2: Second batch of images. The pixel values should be normalized to
|
||||||
|
[-1, 1].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A mapping between loss name and loss tensors.
|
||||||
|
"""
|
||||||
|
x_features = self._compute_features(img1)
|
||||||
|
y_features = self._compute_features(img2)
|
||||||
|
|
||||||
|
if self._loss_weight is None:
|
||||||
|
self._loss_weight = PerceptualLossWeight()
|
||||||
|
|
||||||
|
# If the _feature_weight is not initialized, then initialize it as a list of
|
||||||
|
# all the element equals to 1.0.
|
||||||
|
if self._feature_weight is None:
|
||||||
|
self._feature_weight = [1.0] * len(x_features)
|
||||||
|
|
||||||
|
# If the length of _feature_weight smallert than the length of the feature,
|
||||||
|
# raise a ValueError. Otherwise, only use the first len(x_features) weight
|
||||||
|
# for computing the loss.
|
||||||
|
if len(self._feature_weight) < len(x_features):
|
||||||
|
raise ValueError(
|
||||||
|
f'Input feature weight length {len(self._feature_weight)} is smaller'
|
||||||
|
f' than feature length {len(x_features)}'
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._loss_weight.style > 0.0:
|
||||||
|
self._loss_style = tf_utils.safe_mean(
|
||||||
|
self._loss_weight.style
|
||||||
|
* self._get_style_loss(x_feats=x_features, y_feats=y_features)
|
||||||
|
)
|
||||||
|
if self._loss_weight.content > 0.0:
|
||||||
|
self._loss_content = tf_utils.safe_mean(
|
||||||
|
self._loss_weight.content
|
||||||
|
* self._get_content_loss(x_feats=x_features, y_feats=y_features)
|
||||||
|
)
|
||||||
|
|
||||||
|
return {'style_loss': self._loss_style, 'content_loss': self._loss_content}
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _compute_features(self, img: tf.Tensor) -> Sequence[tf.Tensor]:
|
||||||
|
"""Computes features from the given image tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img: Image tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of multi-scale feature maps.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _get_content_loss(
|
||||||
|
self, x_feats: Sequence[tf.Tensor], y_feats: Sequence[tf.Tensor]
|
||||||
|
) -> tf.Tensor:
|
||||||
|
"""Gets weighted multi-scale content loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x_feats: Reconstructed face image.
|
||||||
|
y_feats: Target face image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A scalar tensor for the content loss.
|
||||||
|
"""
|
||||||
|
content_losses = []
|
||||||
|
for coef, x_feat, y_feat in zip(self._feature_weight, x_feats, y_feats):
|
||||||
|
content_loss = self._loss_op(x_feat, y_feat) * coef
|
||||||
|
content_losses.append(content_loss)
|
||||||
|
return tf.math.reduce_sum(content_losses)
|
||||||
|
|
||||||
|
def _get_style_loss(
|
||||||
|
self, x_feats: Sequence[tf.Tensor], y_feats: Sequence[tf.Tensor]
|
||||||
|
) -> tf.Tensor:
|
||||||
|
"""Gets weighted multi-scale style loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x_feats: Reconstructed face image.
|
||||||
|
y_feats: Target face image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A scalar tensor for the style loss.
|
||||||
|
"""
|
||||||
|
style_losses = []
|
||||||
|
i = 0
|
||||||
|
for coef, x_feat, y_feat in zip(self._feature_weight, x_feats, y_feats):
|
||||||
|
x_feat_g = _compute_gram_matrix(x_feat)
|
||||||
|
y_feat_g = _compute_gram_matrix(y_feat)
|
||||||
|
style_loss = self._loss_op(x_feat_g, y_feat_g) * coef
|
||||||
|
style_losses.append(style_loss)
|
||||||
|
i = i + 1
|
||||||
|
|
||||||
|
return tf.math.reduce_sum(style_loss)
|
||||||
|
|
||||||
|
|
||||||
|
class VGGPerceptualLoss(PerceptualLoss):
|
||||||
|
"""Perceptual loss based on VGG19 pretrained on the ImageNet dataset.
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](
|
||||||
|
https://arxiv.org/abs/1603.08155) (ECCV 2016)
|
||||||
|
|
||||||
|
Perceptual loss measures high-level perceptual and semantic differences
|
||||||
|
between images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
loss_weight: Optional[PerceptualLossWeight] = None,
|
||||||
|
):
|
||||||
|
"""Initializes image quality loss essentials.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss_weight: Loss weight coefficients.
|
||||||
|
"""
|
||||||
|
super().__init__(
|
||||||
|
feature_weight=np.array([0.1, 0.1, 1.0, 1.0, 1.0]),
|
||||||
|
loss_weight=loss_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
rgb_mean = tf.constant([0.485, 0.456, 0.406])
|
||||||
|
rgb_std = tf.constant([0.229, 0.224, 0.225])
|
||||||
|
|
||||||
|
self._rgb_mean = tf.reshape(rgb_mean, (1, 1, 1, 3))
|
||||||
|
self._rgb_std = tf.reshape(rgb_std, (1, 1, 1, 3))
|
||||||
|
|
||||||
|
model_path = file_util.DownloadedFiles(
|
||||||
|
'vgg_feature_extractor',
|
||||||
|
_VGG_IMAGENET_PERCEPTUAL_MODEL_URL,
|
||||||
|
is_folder=True,
|
||||||
|
)
|
||||||
|
self._vgg19 = model_util.load_keras_model(model_path.get_path())
|
||||||
|
|
||||||
|
def _compute_features(self, img: tf.Tensor) -> Sequence[tf.Tensor]:
|
||||||
|
"""Computes VGG19 features."""
|
||||||
|
img = (img + 1) / 2.0
|
||||||
|
norm_img = (img - self._rgb_mean) / self._rgb_std
|
||||||
|
# no grad, as it only serves as a frozen feature extractor.
|
||||||
|
return self._vgg19(norm_img)
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_gram_matrix(feature: tf.Tensor) -> tf.Tensor:
|
||||||
|
"""Computes gram matrix for the feature map.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature: [B, H, W, C] feature map.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[B, C, C] gram matrix.
|
||||||
|
"""
|
||||||
|
h, w, c = feature.shape[1:].as_list()
|
||||||
|
feat_reshaped = tf.reshape(feature, shape=(-1, h * w, c))
|
||||||
|
feat_gram = tf.matmul(
|
||||||
|
tf.transpose(feat_reshaped, perm=[0, 2, 1]), feat_reshaped
|
||||||
|
)
|
||||||
|
return feat_gram / (c * h * w)
|
||||||
|
|
|
@ -13,7 +13,9 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional
|
import tempfile
|
||||||
|
from typing import Dict, Optional, Sequence
|
||||||
|
from unittest import mock as unittest_mock
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
@ -21,7 +23,7 @@ import tensorflow as tf
|
||||||
from mediapipe.model_maker.python.core.utils import loss_functions
|
from mediapipe.model_maker.python.core.utils import loss_functions
|
||||||
|
|
||||||
|
|
||||||
class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase):
|
class FocalLossTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
dict(testcase_name='no_sample_weight', sample_weight=None),
|
dict(testcase_name='no_sample_weight', sample_weight=None),
|
||||||
|
@ -99,5 +101,182 @@ class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.assertNear(loss, expected_loss, 1e-4)
|
self.assertNear(loss, expected_loss, 1e-4)
|
||||||
|
|
||||||
|
|
||||||
|
class MockPerceptualLoss(loss_functions.PerceptualLoss):
|
||||||
|
"""A mock class with implementation of abstract methods for testing."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
use_mock_loss_op: bool = False,
|
||||||
|
feature_weight: Optional[Sequence[float]] = None,
|
||||||
|
loss_weight: Optional[loss_functions.PerceptualLossWeight] = None,
|
||||||
|
):
|
||||||
|
super().__init__(feature_weight=feature_weight, loss_weight=loss_weight)
|
||||||
|
if use_mock_loss_op:
|
||||||
|
self._loss_op = lambda x, y: tf.math.reduce_mean(x - y)
|
||||||
|
|
||||||
|
def _compute_features(self, img: tf.Tensor) -> Sequence[tf.Tensor]:
|
||||||
|
return [tf.random.normal(shape=(1, 8, 8, 3))] * 5
|
||||||
|
|
||||||
|
|
||||||
|
class PerceptualLossTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self._img1 = tf.fill(dims=(8, 8), value=0.2)
|
||||||
|
self._img2 = tf.fill(dims=(8, 8), value=0.8)
|
||||||
|
|
||||||
|
def test_invalid_feature_weight_raise_value_error(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError,
|
||||||
|
'Input feature weight length 2 is smaller than feature length 5',
|
||||||
|
):
|
||||||
|
MockPerceptualLoss(feature_weight=[1.0, 2.0])(
|
||||||
|
img1=self._img1, img2=self._img2
|
||||||
|
)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
dict(
|
||||||
|
testcase_name='default_loss_weight_and_loss_op',
|
||||||
|
use_mock_loss_op=False,
|
||||||
|
feature_weight=None,
|
||||||
|
loss_weight=None,
|
||||||
|
loss_values={
|
||||||
|
'style_loss': 0.032839,
|
||||||
|
'content_loss': 5.639870,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
testcase_name='style_loss_weight_is_0_default_loss_op',
|
||||||
|
use_mock_loss_op=False,
|
||||||
|
feature_weight=None,
|
||||||
|
loss_weight=loss_functions.PerceptualLossWeight(style=0),
|
||||||
|
loss_values={
|
||||||
|
'style_loss': 0,
|
||||||
|
'content_loss': 5.639870,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
testcase_name='content_loss_weight_is_0_default_loss_op',
|
||||||
|
use_mock_loss_op=False,
|
||||||
|
feature_weight=None,
|
||||||
|
loss_weight=loss_functions.PerceptualLossWeight(content=0),
|
||||||
|
loss_values={
|
||||||
|
'style_loss': 0.032839,
|
||||||
|
'content_loss': 0,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
testcase_name='customized_loss_weight_default_loss_op',
|
||||||
|
use_mock_loss_op=False,
|
||||||
|
feature_weight=None,
|
||||||
|
loss_weight=loss_functions.PerceptualLossWeight(
|
||||||
|
style=1.0, content=2.0
|
||||||
|
),
|
||||||
|
loss_values={'style_loss': 0.032839, 'content_loss': 11.279739},
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
testcase_name=(
|
||||||
|
'customized_feature_weight_and_loss_weight_default_loss_op'
|
||||||
|
),
|
||||||
|
use_mock_loss_op=False,
|
||||||
|
feature_weight=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||||
|
loss_weight=loss_functions.PerceptualLossWeight(
|
||||||
|
style=1.0, content=2.0
|
||||||
|
),
|
||||||
|
loss_values={'style_loss': 0.164193, 'content_loss': 33.839218},
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
testcase_name='no_loss_change_if_extra_feature_weight_provided',
|
||||||
|
use_mock_loss_op=False,
|
||||||
|
feature_weight=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
|
||||||
|
loss_weight=loss_functions.PerceptualLossWeight(
|
||||||
|
style=1.0, content=2.0
|
||||||
|
),
|
||||||
|
loss_values={
|
||||||
|
'style_loss': 0.164193,
|
||||||
|
'content_loss': 33.839218,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
testcase_name='customized_loss_weight_custom_loss_op',
|
||||||
|
use_mock_loss_op=True,
|
||||||
|
feature_weight=None,
|
||||||
|
loss_weight=loss_functions.PerceptualLossWeight(
|
||||||
|
style=1.0, content=2.0
|
||||||
|
),
|
||||||
|
loss_values={'style_loss': 0.000395, 'content_loss': -1.533469},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def test_weighted_perceptul_loss(
|
||||||
|
self,
|
||||||
|
use_mock_loss_op: bool,
|
||||||
|
feature_weight: Sequence[float],
|
||||||
|
loss_weight: loss_functions.PerceptualLossWeight,
|
||||||
|
loss_values: Dict[str, float],
|
||||||
|
):
|
||||||
|
perceptual_loss = MockPerceptualLoss(
|
||||||
|
use_mock_loss_op=use_mock_loss_op,
|
||||||
|
feature_weight=feature_weight,
|
||||||
|
loss_weight=loss_weight,
|
||||||
|
)
|
||||||
|
loss = perceptual_loss(img1=self._img1, img2=self._img2)
|
||||||
|
self.assertEqual(list(loss.keys()), ['style_loss', 'content_loss'])
|
||||||
|
self.assertNear(loss['style_loss'], loss_values['style_loss'], 1e-4)
|
||||||
|
self.assertNear(loss['content_loss'], loss_values['content_loss'], 1e-4)
|
||||||
|
|
||||||
|
|
||||||
|
class VGGPerceptualLossTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||||
|
# condition when downloading model since these tests may run in parallel.
|
||||||
|
mock_gettempdir = unittest_mock.patch.object(
|
||||||
|
tempfile,
|
||||||
|
'gettempdir',
|
||||||
|
return_value=self.create_tempdir(),
|
||||||
|
autospec=True,
|
||||||
|
)
|
||||||
|
self.mock_gettempdir = mock_gettempdir.start()
|
||||||
|
self.addCleanup(mock_gettempdir.stop)
|
||||||
|
self._img1 = tf.fill(dims=(1, 256, 256, 3), value=0.1)
|
||||||
|
self._img2 = tf.fill(dims=(1, 256, 256, 3), value=0.9)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
dict(
|
||||||
|
testcase_name='default_loss_weight',
|
||||||
|
loss_weight=None,
|
||||||
|
loss_values={
|
||||||
|
'style_loss': 5.8363257e-06,
|
||||||
|
'content_loss': 1.7016045,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
testcase_name='customized_loss_weight',
|
||||||
|
loss_weight=loss_functions.PerceptualLossWeight(
|
||||||
|
style=10.0, content=20.0
|
||||||
|
),
|
||||||
|
loss_values={
|
||||||
|
'style_loss': 5.8363257e-05,
|
||||||
|
'content_loss': 34.03208,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def test_vgg_perceptual_loss(self, loss_weight, loss_values):
|
||||||
|
vgg_loss = loss_functions.VGGPerceptualLoss(loss_weight=loss_weight)
|
||||||
|
loss = vgg_loss(img1=self._img1, img2=self._img2)
|
||||||
|
self.assertEqual(list(loss.keys()), ['style_loss', 'content_loss'])
|
||||||
|
self.assertNear(
|
||||||
|
loss['style_loss'],
|
||||||
|
loss_values['style_loss'],
|
||||||
|
loss_values['style_loss'] / 1e5,
|
||||||
|
)
|
||||||
|
self.assertNear(
|
||||||
|
loss['content_loss'],
|
||||||
|
loss_values['content_loss'],
|
||||||
|
loss_values['content_loss'] / 1e5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user