Refactor the loss functions to initialize the VGG loss function in the init function to avoid duplicated initialization.
PiperOrigin-RevId: 527424556
This commit is contained in:
parent
baed44ab10
commit
b05fd21709
|
@ -144,6 +144,7 @@ class ImagePerceptualQualityLoss(tf.keras.losses.Loss):
|
|||
"""Initializes ImagePerceptualQualityLoss."""
|
||||
self._loss_weight = loss_weight
|
||||
self._losses = {}
|
||||
self._vgg_loss = VGGPerceptualLoss(self._loss_weight)
|
||||
self._reduction = reduction
|
||||
|
||||
def _l1_loss(
|
||||
|
@ -164,7 +165,7 @@ class ImagePerceptualQualityLoss(tf.keras.losses.Loss):
|
|||
self._loss_weight = PerceptualLossWeight()
|
||||
|
||||
if self._loss_weight.content > 0 or self._loss_weight.style > 0:
|
||||
vgg_loss = VGGPerceptualLoss(self._loss_weight)(img1, img2)
|
||||
vgg_loss = self._vgg_loss(img1, img2)
|
||||
vgg_loss_value = tf.math.add_n(vgg_loss.values())
|
||||
loss_value.append(vgg_loss_value)
|
||||
if self._loss_weight.l1 > 0:
|
||||
|
|
Loading…
Reference in New Issue
Block a user