Internal change

PiperOrigin-RevId: 526300079
This commit is contained in:
MediaPipe Team 2023-04-22 10:50:10 -07:00 committed by Copybara-Service
parent a6c1bb6324
commit abded49e5b
2 changed files with 92 additions and 1 deletions

View File

@ -16,7 +16,7 @@
import abc
from typing import Mapping, Sequence
import dataclasses
from typing import Optional
from typing import Any, Optional
import numpy as np
import tensorflow as tf
@ -130,6 +130,51 @@ class PerceptualLossWeight:
style: float = 1.0
class ImagePerceptualQualityLoss(tf.keras.losses.Loss):
"""Image perceptual quality loss.
It obtains a weighted loss between the VGGPerceptualLoss and L1 loss.
"""
def __init__(
self,
loss_weight: Optional[PerceptualLossWeight] = None,
reduction: tf.keras.losses.Reduction = tf.keras.losses.Reduction.NONE,
):
"""Initializes ImagePerceptualQualityLoss."""
self._loss_weight = loss_weight
self._losses = {}
self._reduction = reduction
def _l1_loss(
self,
reduction: tf.keras.losses.Reduction = tf.keras.losses.Reduction.NONE,
) -> Any:
"""Calculates L1 loss."""
return tf.keras.losses.MeanAbsoluteError(reduction)
def __call__(
self,
img1: tf.Tensor,
img2: tf.Tensor,
) -> tf.Tensor:
"""Computes image perceptual quality loss."""
loss_value = []
if self._loss_weight is None:
self._loss_weight = PerceptualLossWeight()
if self._loss_weight.content > 0 or self._loss_weight.style > 0:
vgg_loss = VGGPerceptualLoss(self._loss_weight)(img1, img2)
vgg_loss_value = tf.math.add_n(vgg_loss.values())
loss_value.append(vgg_loss_value)
if self._loss_weight.l1 > 0:
l1_loss = self._l1_loss(reduction=self._reduction)(img1, img2)
l1_loss_value = tf_utils.safe_mean(l1_loss * self._loss_weight.l1)
loss_value.append(l1_loss_value)
total_loss = tf.math.add_n(loss_value)
return total_loss
class PerceptualLoss(tf.keras.Model, metaclass=abc.ABCMeta):
"""Base class for perceptual loss model."""

View File

@ -278,5 +278,51 @@ class VGGPerceptualLossTest(tf.test.TestCase, parameterized.TestCase):
)
class ImagePerceptualQualityLossTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
# Mock tempfile.gettempdir() to be unique for each test to avoid race
# condition when downloading model since these tests may run in parallel.
mock_gettempdir = unittest_mock.patch.object(
tempfile,
'gettempdir',
return_value=self.create_tempdir(),
autospec=True,
)
self.mock_gettempdir = mock_gettempdir.start()
self.addCleanup(mock_gettempdir.stop)
self._img1 = tf.fill(dims=(1, 256, 256, 3), value=0.1)
self._img2 = tf.fill(dims=(1, 256, 256, 3), value=0.9)
@parameterized.named_parameters(
dict(
testcase_name='default_loss_weight',
loss_weight=None,
loss_value=2.501612,
),
dict(
testcase_name='customized_loss_weight_zero_l1',
loss_weight=loss_functions.PerceptualLossWeight(
l1=0.0, style=10.0, content=20.0
),
loss_value=34.032139,
),
dict(
testcase_name='customized_loss_weight_nonzero_l1',
loss_weight=loss_functions.PerceptualLossWeight(
l1=10.0, style=10.0, content=20.0
),
loss_value=42.032139,
),
)
def test_image_perceptual_quality_loss(self, loss_weight, loss_value):
image_quality_loss = loss_functions.ImagePerceptualQualityLoss(
loss_weight=loss_weight
)
loss = image_quality_loss(img1=self._img1, img2=self._img2)
self.assertNear(loss, loss_value, 1e-4)
if __name__ == '__main__':
tf.test.main()