Internal change

PiperOrigin-RevId: 526300079
MediaPipe Team 2023-04-22 10:50:10 -07:00 committed by Copybara-Service
parent a6c1bb6324
commit abded49e5b
2 changed files with 92 additions and 1 deletion

@@ -16,7 +16,7 @@
import abc
from typing import Mapping, Sequence
import dataclasses
-from typing import Optional
+from typing import Any, Optional
import numpy as np
import tensorflow as tf
@@ -130,6 +130,51 @@ class PerceptualLossWeight:
  style: float = 1.0
class ImagePerceptualQualityLoss(tf.keras.losses.Loss):
  """Image perceptual quality loss.

  It obtains a weighted loss between the VGGPerceptualLoss and L1 loss.
  """

  def __init__(
      self,
      loss_weight: Optional[PerceptualLossWeight] = None,
      reduction: tf.keras.losses.Reduction = tf.keras.losses.Reduction.NONE,
  ):
    """Initializes ImagePerceptualQualityLoss."""
    self._loss_weight = loss_weight
    self._losses = {}
    self._reduction = reduction

  def _l1_loss(
      self,
      reduction: tf.keras.losses.Reduction = tf.keras.losses.Reduction.NONE,
  ) -> Any:
    """Calculates L1 loss."""
    return tf.keras.losses.MeanAbsoluteError(reduction)

  def __call__(
      self,
      img1: tf.Tensor,
      img2: tf.Tensor,
  ) -> tf.Tensor:
    """Computes image perceptual quality loss."""
    loss_value = []
    if self._loss_weight is None:
      self._loss_weight = PerceptualLossWeight()

    if self._loss_weight.content > 0 or self._loss_weight.style > 0:
      vgg_loss = VGGPerceptualLoss(self._loss_weight)(img1, img2)
      vgg_loss_value = tf.math.add_n(vgg_loss.values())
      loss_value.append(vgg_loss_value)
    if self._loss_weight.l1 > 0:
      l1_loss = self._l1_loss(reduction=self._reduction)(img1, img2)
      l1_loss_value = tf_utils.safe_mean(l1_loss * self._loss_weight.l1)
      loss_value.append(l1_loss_value)
    total_loss = tf.math.add_n(loss_value)
    return total_loss
class PerceptualLoss(tf.keras.Model, metaclass=abc.ABCMeta):
  """Base class for perceptual loss model."""

@@ -278,5 +278,51 @@ class VGGPerceptualLossTest(tf.test.TestCase, parameterized.TestCase):
)
class ImagePerceptualQualityLossTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    # Mock tempfile.gettempdir() to be unique for each test to avoid race
    # condition when downloading model since these tests may run in parallel.
    mock_gettempdir = unittest_mock.patch.object(
        tempfile,
        'gettempdir',
        return_value=self.create_tempdir(),
        autospec=True,
    )
    self.mock_gettempdir = mock_gettempdir.start()
    self.addCleanup(mock_gettempdir.stop)
    self._img1 = tf.fill(dims=(1, 256, 256, 3), value=0.1)
    self._img2 = tf.fill(dims=(1, 256, 256, 3), value=0.9)

  @parameterized.named_parameters(
      dict(
          testcase_name='default_loss_weight',
          loss_weight=None,
          loss_value=2.501612,
      ),
      dict(
          testcase_name='customized_loss_weight_zero_l1',
          loss_weight=loss_functions.PerceptualLossWeight(
              l1=0.0, style=10.0, content=20.0
          ),
          loss_value=34.032139,
      ),
      dict(
          testcase_name='customized_loss_weight_nonzero_l1',
          loss_weight=loss_functions.PerceptualLossWeight(
              l1=10.0, style=10.0, content=20.0
          ),
          loss_value=42.032139,
      ),
  )
  def test_image_perceptual_quality_loss(self, loss_weight, loss_value):
    image_quality_loss = loss_functions.ImagePerceptualQualityLoss(
        loss_weight=loss_weight
    )
    loss = image_quality_loss(img1=self._img1, img2=self._img2)
    self.assertNear(loss, loss_value, 1e-4)
if __name__ == '__main__':
  tf.test.main()
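
A quick sanity check on the expected values above (an editorial note, not part of the commit): with constant images of 0.1 and 0.9 the mean absolute error is 0.8, so with l1=10.0 the L1 term contributes 10.0 * 0.8 = 8.0, exactly the gap between the two customized cases (42.032139 - 34.032139 = 8.0); the remaining 34.032139 comes from the VGG content and style terms.

import tensorflow as tf

img1 = tf.fill(dims=(1, 256, 256, 3), value=0.1)
img2 = tf.fill(dims=(1, 256, 256, 3), value=0.9)
mae = tf.keras.losses.MeanAbsoluteError()(img1, img2)  # ~0.8
print(float(mae) * 10.0)  # ~8.0, the L1 contribution with l1=10.0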