Refactor the loss functions to initialize the VGG loss function in the init function to avoid duplicated initialization.
PiperOrigin-RevId: 527424556
This commit is contained in:
parent
baed44ab10
commit
b05fd21709
|
@ -144,6 +144,7 @@ class ImagePerceptualQualityLoss(tf.keras.losses.Loss):
|
||||||
"""Initializes ImagePerceptualQualityLoss."""
|
"""Initializes ImagePerceptualQualityLoss."""
|
||||||
self._loss_weight = loss_weight
|
self._loss_weight = loss_weight
|
||||||
self._losses = {}
|
self._losses = {}
|
||||||
|
self._vgg_loss = VGGPerceptualLoss(self._loss_weight)
|
||||||
self._reduction = reduction
|
self._reduction = reduction
|
||||||
|
|
||||||
def _l1_loss(
|
def _l1_loss(
|
||||||
|
@ -164,7 +165,7 @@ class ImagePerceptualQualityLoss(tf.keras.losses.Loss):
|
||||||
self._loss_weight = PerceptualLossWeight()
|
self._loss_weight = PerceptualLossWeight()
|
||||||
|
|
||||||
if self._loss_weight.content > 0 or self._loss_weight.style > 0:
|
if self._loss_weight.content > 0 or self._loss_weight.style > 0:
|
||||||
vgg_loss = VGGPerceptualLoss(self._loss_weight)(img1, img2)
|
vgg_loss = self._vgg_loss(img1, img2)
|
||||||
vgg_loss_value = tf.math.add_n(vgg_loss.values())
|
vgg_loss_value = tf.math.add_n(vgg_loss.values())
|
||||||
loss_value.append(vgg_loss_value)
|
loss_value.append(vgg_loss_value)
|
||||||
if self._loss_weight.l1 > 0:
|
if self._loss_weight.l1 > 0:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user