Refactor the loss functions to initialize the VGG loss function in the init function to avoid duplicated initialization.

PiperOrigin-RevId: 527424556
This commit is contained in:
MediaPipe Team 2023-04-26 17:47:14 -07:00 committed by Copybara-Service
parent baed44ab10
commit b05fd21709

View File

@ -144,6 +144,7 @@ class ImagePerceptualQualityLoss(tf.keras.losses.Loss):
"""Initializes ImagePerceptualQualityLoss.""" """Initializes ImagePerceptualQualityLoss."""
self._loss_weight = loss_weight self._loss_weight = loss_weight
self._losses = {} self._losses = {}
self._vgg_loss = VGGPerceptualLoss(self._loss_weight)
self._reduction = reduction self._reduction = reduction
def _l1_loss( def _l1_loss(
@ -164,7 +165,7 @@ class ImagePerceptualQualityLoss(tf.keras.losses.Loss):
self._loss_weight = PerceptualLossWeight() self._loss_weight = PerceptualLossWeight()
if self._loss_weight.content > 0 or self._loss_weight.style > 0: if self._loss_weight.content > 0 or self._loss_weight.style > 0:
vgg_loss = VGGPerceptualLoss(self._loss_weight)(img1, img2) vgg_loss = self._vgg_loss(img1, img2)
vgg_loss_value = tf.math.add_n(vgg_loss.values()) vgg_loss_value = tf.math.add_n(vgg_loss.values())
loss_value.append(vgg_loss_value) loss_value.append(vgg_loss_value)
if self._loss_weight.l1 > 0: if self._loss_weight.l1 > 0: