Internal change
PiperOrigin-RevId: 526300079
This commit is contained in:
parent
a6c1bb6324
commit
abded49e5b
|
@ -16,7 +16,7 @@
|
|||
import abc
|
||||
from typing import Mapping, Sequence
|
||||
import dataclasses
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
@ -130,6 +130,51 @@ class PerceptualLossWeight:
|
|||
style: float = 1.0
|
||||
|
||||
|
||||
class ImagePerceptualQualityLoss(tf.keras.losses.Loss):
|
||||
"""Image perceptual quality loss.
|
||||
|
||||
It obtains a weighted loss between the VGGPerceptualLoss and L1 loss.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loss_weight: Optional[PerceptualLossWeight] = None,
|
||||
reduction: tf.keras.losses.Reduction = tf.keras.losses.Reduction.NONE,
|
||||
):
|
||||
"""Initializes ImagePerceptualQualityLoss."""
|
||||
self._loss_weight = loss_weight
|
||||
self._losses = {}
|
||||
self._reduction = reduction
|
||||
|
||||
def _l1_loss(
|
||||
self,
|
||||
reduction: tf.keras.losses.Reduction = tf.keras.losses.Reduction.NONE,
|
||||
) -> Any:
|
||||
"""Calculates L1 loss."""
|
||||
return tf.keras.losses.MeanAbsoluteError(reduction)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
img1: tf.Tensor,
|
||||
img2: tf.Tensor,
|
||||
) -> tf.Tensor:
|
||||
"""Computes image perceptual quality loss."""
|
||||
loss_value = []
|
||||
if self._loss_weight is None:
|
||||
self._loss_weight = PerceptualLossWeight()
|
||||
|
||||
if self._loss_weight.content > 0 or self._loss_weight.style > 0:
|
||||
vgg_loss = VGGPerceptualLoss(self._loss_weight)(img1, img2)
|
||||
vgg_loss_value = tf.math.add_n(vgg_loss.values())
|
||||
loss_value.append(vgg_loss_value)
|
||||
if self._loss_weight.l1 > 0:
|
||||
l1_loss = self._l1_loss(reduction=self._reduction)(img1, img2)
|
||||
l1_loss_value = tf_utils.safe_mean(l1_loss * self._loss_weight.l1)
|
||||
loss_value.append(l1_loss_value)
|
||||
total_loss = tf.math.add_n(loss_value)
|
||||
return total_loss
|
||||
|
||||
|
||||
class PerceptualLoss(tf.keras.Model, metaclass=abc.ABCMeta):
|
||||
"""Base class for perceptual loss model."""
|
||||
|
||||
|
|
|
@ -278,5 +278,51 @@ class VGGPerceptualLossTest(tf.test.TestCase, parameterized.TestCase):
|
|||
)
|
||||
|
||||
|
||||
class ImagePerceptualQualityLossTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||
# condition when downloading model since these tests may run in parallel.
|
||||
mock_gettempdir = unittest_mock.patch.object(
|
||||
tempfile,
|
||||
'gettempdir',
|
||||
return_value=self.create_tempdir(),
|
||||
autospec=True,
|
||||
)
|
||||
self.mock_gettempdir = mock_gettempdir.start()
|
||||
self.addCleanup(mock_gettempdir.stop)
|
||||
self._img1 = tf.fill(dims=(1, 256, 256, 3), value=0.1)
|
||||
self._img2 = tf.fill(dims=(1, 256, 256, 3), value=0.9)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name='default_loss_weight',
|
||||
loss_weight=None,
|
||||
loss_value=2.501612,
|
||||
),
|
||||
dict(
|
||||
testcase_name='customized_loss_weight_zero_l1',
|
||||
loss_weight=loss_functions.PerceptualLossWeight(
|
||||
l1=0.0, style=10.0, content=20.0
|
||||
),
|
||||
loss_value=34.032139,
|
||||
),
|
||||
dict(
|
||||
testcase_name='customized_loss_weight_nonzero_l1',
|
||||
loss_weight=loss_functions.PerceptualLossWeight(
|
||||
l1=10.0, style=10.0, content=20.0
|
||||
),
|
||||
loss_value=42.032139,
|
||||
),
|
||||
)
|
||||
def test_image_perceptual_quality_loss(self, loss_weight, loss_value):
|
||||
image_quality_loss = loss_functions.ImagePerceptualQualityLoss(
|
||||
loss_weight=loss_weight
|
||||
)
|
||||
loss = image_quality_loss(img1=self._img1, img2=self._img2)
|
||||
self.assertNear(loss, loss_value, 1e-4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
|
Loading…
Reference in New Issue
Block a user