Internal change
PiperOrigin-RevId: 526300079
This commit is contained in:
parent
a6c1bb6324
commit
abded49e5b
|
@ -16,7 +16,7 @@
|
||||||
import abc
|
import abc
|
||||||
from typing import Mapping, Sequence
|
from typing import Mapping, Sequence
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
@ -130,6 +130,51 @@ class PerceptualLossWeight:
|
||||||
style: float = 1.0
|
style: float = 1.0
|
||||||
|
|
||||||
|
|
||||||
|
class ImagePerceptualQualityLoss(tf.keras.losses.Loss):
|
||||||
|
"""Image perceptual quality loss.
|
||||||
|
|
||||||
|
It obtains a weighted loss between the VGGPerceptualLoss and L1 loss.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
loss_weight: Optional[PerceptualLossWeight] = None,
|
||||||
|
reduction: tf.keras.losses.Reduction = tf.keras.losses.Reduction.NONE,
|
||||||
|
):
|
||||||
|
"""Initializes ImagePerceptualQualityLoss."""
|
||||||
|
self._loss_weight = loss_weight
|
||||||
|
self._losses = {}
|
||||||
|
self._reduction = reduction
|
||||||
|
|
||||||
|
def _l1_loss(
|
||||||
|
self,
|
||||||
|
reduction: tf.keras.losses.Reduction = tf.keras.losses.Reduction.NONE,
|
||||||
|
) -> Any:
|
||||||
|
"""Calculates L1 loss."""
|
||||||
|
return tf.keras.losses.MeanAbsoluteError(reduction)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
img1: tf.Tensor,
|
||||||
|
img2: tf.Tensor,
|
||||||
|
) -> tf.Tensor:
|
||||||
|
"""Computes image perceptual quality loss."""
|
||||||
|
loss_value = []
|
||||||
|
if self._loss_weight is None:
|
||||||
|
self._loss_weight = PerceptualLossWeight()
|
||||||
|
|
||||||
|
if self._loss_weight.content > 0 or self._loss_weight.style > 0:
|
||||||
|
vgg_loss = VGGPerceptualLoss(self._loss_weight)(img1, img2)
|
||||||
|
vgg_loss_value = tf.math.add_n(vgg_loss.values())
|
||||||
|
loss_value.append(vgg_loss_value)
|
||||||
|
if self._loss_weight.l1 > 0:
|
||||||
|
l1_loss = self._l1_loss(reduction=self._reduction)(img1, img2)
|
||||||
|
l1_loss_value = tf_utils.safe_mean(l1_loss * self._loss_weight.l1)
|
||||||
|
loss_value.append(l1_loss_value)
|
||||||
|
total_loss = tf.math.add_n(loss_value)
|
||||||
|
return total_loss
|
||||||
|
|
||||||
|
|
||||||
class PerceptualLoss(tf.keras.Model, metaclass=abc.ABCMeta):
|
class PerceptualLoss(tf.keras.Model, metaclass=abc.ABCMeta):
|
||||||
"""Base class for perceptual loss model."""
|
"""Base class for perceptual loss model."""
|
||||||
|
|
||||||
|
|
|
@ -278,5 +278,51 @@ class VGGPerceptualLossTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ImagePerceptualQualityLossTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||||
|
# condition when downloading model since these tests may run in parallel.
|
||||||
|
mock_gettempdir = unittest_mock.patch.object(
|
||||||
|
tempfile,
|
||||||
|
'gettempdir',
|
||||||
|
return_value=self.create_tempdir(),
|
||||||
|
autospec=True,
|
||||||
|
)
|
||||||
|
self.mock_gettempdir = mock_gettempdir.start()
|
||||||
|
self.addCleanup(mock_gettempdir.stop)
|
||||||
|
self._img1 = tf.fill(dims=(1, 256, 256, 3), value=0.1)
|
||||||
|
self._img2 = tf.fill(dims=(1, 256, 256, 3), value=0.9)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
dict(
|
||||||
|
testcase_name='default_loss_weight',
|
||||||
|
loss_weight=None,
|
||||||
|
loss_value=2.501612,
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
testcase_name='customized_loss_weight_zero_l1',
|
||||||
|
loss_weight=loss_functions.PerceptualLossWeight(
|
||||||
|
l1=0.0, style=10.0, content=20.0
|
||||||
|
),
|
||||||
|
loss_value=34.032139,
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
testcase_name='customized_loss_weight_nonzero_l1',
|
||||||
|
loss_weight=loss_functions.PerceptualLossWeight(
|
||||||
|
l1=10.0, style=10.0, content=20.0
|
||||||
|
),
|
||||||
|
loss_value=42.032139,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def test_image_perceptual_quality_loss(self, loss_weight, loss_value):
|
||||||
|
image_quality_loss = loss_functions.ImagePerceptualQualityLoss(
|
||||||
|
loss_weight=loss_weight
|
||||||
|
)
|
||||||
|
loss = image_quality_loss(img1=self._img1, img2=self._img2)
|
||||||
|
self.assertNear(loss, loss_value, 1e-4)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user