Internal change
PiperOrigin-RevId: 526235882
This commit is contained in:
		
							parent
							
								
									58dcbc9833
								
							
						
					
					
						commit
						a6c1bb6324
					
				| 
						 | 
					@ -67,11 +67,18 @@ py_library(
 | 
				
			||||||
    name = "loss_functions",
 | 
					    name = "loss_functions",
 | 
				
			||||||
    srcs = ["loss_functions.py"],
 | 
					    srcs = ["loss_functions.py"],
 | 
				
			||||||
    srcs_version = "PY3",
 | 
					    srcs_version = "PY3",
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":file_util",
 | 
				
			||||||
 | 
					        ":model_util",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
py_test(
 | 
					py_test(
 | 
				
			||||||
    name = "loss_functions_test",
 | 
					    name = "loss_functions_test",
 | 
				
			||||||
    srcs = ["loss_functions_test.py"],
 | 
					    srcs = ["loss_functions_test.py"],
 | 
				
			||||||
 | 
					    tags = [
 | 
				
			||||||
 | 
					        "requires-net:external",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
    deps = [":loss_functions"],
 | 
					    deps = [":loss_functions"],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -13,10 +13,21 @@
 | 
				
			||||||
# limitations under the License.
 | 
					# limitations under the License.
 | 
				
			||||||
"""Loss function utility library."""
 | 
					"""Loss function utility library."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from typing import Optional, Sequence
 | 
					import abc
 | 
				
			||||||
 | 
					from typing import Mapping, Sequence
 | 
				
			||||||
 | 
					import dataclasses
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
import tensorflow as tf
 | 
					import tensorflow as tf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.core.utils import file_util
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.core.utils import model_util
 | 
				
			||||||
 | 
					from official.modeling import tf_utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_VGG_IMAGENET_PERCEPTUAL_MODEL_URL = 'https://storage.googleapis.com/mediapipe-assets/vgg_feature_extractor.tar.gz'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class FocalLoss(tf.keras.losses.Loss):
 | 
					class FocalLoss(tf.keras.losses.Loss):
 | 
				
			||||||
  """Implementation of focal loss (https://arxiv.org/pdf/1708.02002.pdf).
 | 
					  """Implementation of focal loss (https://arxiv.org/pdf/1708.02002.pdf).
 | 
				
			||||||
| 
						 | 
					@ -45,7 +56,6 @@ class FocalLoss(tf.keras.losses.Loss):
 | 
				
			||||||
  ```python
 | 
					  ```python
 | 
				
			||||||
  model.compile(optimizer='sgd', loss=FocalLoss(gamma))
 | 
					  model.compile(optimizer='sgd', loss=FocalLoss(gamma))
 | 
				
			||||||
  ```
 | 
					  ```
 | 
				
			||||||
 | 
					 | 
				
			||||||
  """
 | 
					  """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
 | 
					  def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
 | 
				
			||||||
| 
						 | 
					@ -103,3 +113,206 @@ class FocalLoss(tf.keras.losses.Loss):
 | 
				
			||||||
    # By default, this function uses "sum_over_batch_size" reduction for the
 | 
					    # By default, this function uses "sum_over_batch_size" reduction for the
 | 
				
			||||||
    # loss per batch.
 | 
					    # loss per batch.
 | 
				
			||||||
    return tf.reduce_sum(losses) / batch_size
 | 
					    return tf.reduce_sum(losses) / batch_size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dataclasses.dataclass
					class PerceptualLossWeight:
					  """Scalar multipliers for the individual perceptual loss terms.

					  Every multiplier defaults to 1.0 when not supplied.

					  Attributes:
					    l1: Multiplier for the L1 loss term.
					    content: Multiplier for the content loss term.
					    style: Multiplier for the style loss term.
					  """

					  # Field order is part of the positional constructor signature.
					  l1: float = 1.0
					  content: float = 1.0
					  style: float = 1.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PerceptualLoss(tf.keras.Model, metaclass=abc.ABCMeta):
					  """Base class for perceptual loss model.

					  Subclasses provide `_compute_features`. Calling an instance with two image
					  batches extracts features from both and returns the weighted style and
					  content losses computed from those features.
					  """

					  def __init__(
					      self,
					      feature_weight: Optional[Sequence[float]] = None,
					      loss_weight: Optional[PerceptualLossWeight] = None,
					  ):
					    """Instantiates perceptual loss.

					    Args:
					      feature_weight: The weight coefficients of multiple model extracted
					        features used for calculating the perceptual loss. If None, it is
					        filled in on the first call with 1.0 for every extracted feature.
					      loss_weight: The weight coefficients between `style_loss` and
					        `content_loss`. If None, a default `PerceptualLossWeight()` is created
					        on the first call.
					    """
					    super().__init__()
					    # Mean absolute difference; used for both content and style terms.
					    self._loss_op = lambda x, y: tf.math.reduce_mean(tf.abs(x - y))
					    # Returned unchanged for a term whose weight is not positive.
					    self._loss_style = tf.constant(0.0)
					    self._loss_content = tf.constant(0.0)
					    self._feature_weight = feature_weight
					    self._loss_weight = loss_weight

					  def __call__(
					      self,
					      img1: tf.Tensor,
					      img2: tf.Tensor,
					  ) -> Mapping[str, tf.Tensor]:
					    """Computes perceptual loss between two images.

					    Args:
					      img1: First batch of images. The pixel values should be normalized to [-1,
					        1].
					      img2: Second batch of images. The pixel values should be normalized to
					        [-1, 1].

					    Returns:
					      A mapping between loss name and loss tensors, with the keys
					      `style_loss` and `content_loss`.

					    Raises:
					      ValueError: If fewer feature weights than extracted features are
					        available.
					    """
					    x_features = self._compute_features(img1)
					    y_features = self._compute_features(img2)

					    if self._loss_weight is None:
					      self._loss_weight = PerceptualLossWeight()

					    # If the _feature_weight is not initialized, then initialize it as a list
					    # with every element equal to 1.0.
					    if self._feature_weight is None:
					      self._feature_weight = [1.0] * len(x_features)

					    # If the length of _feature_weight is smaller than the number of features,
					    # raise a ValueError. Otherwise, only the first len(x_features) weights
					    # are used for computing the loss (the helpers zip weights with features).
					    if len(self._feature_weight) < len(x_features):
					      raise ValueError(
					          f'Input feature weight length {len(self._feature_weight)} is smaller'
					          f' than feature length {len(x_features)}'
					      )

					    if self._loss_weight.style > 0.0:
					      self._loss_style = tf_utils.safe_mean(
					          self._loss_weight.style
					          * self._get_style_loss(x_feats=x_features, y_feats=y_features)
					      )
					    if self._loss_weight.content > 0.0:
					      self._loss_content = tf_utils.safe_mean(
					          self._loss_weight.content
					          * self._get_content_loss(x_feats=x_features, y_feats=y_features)
					      )

					    return {'style_loss': self._loss_style, 'content_loss': self._loss_content}

					  @abc.abstractmethod
					  def _compute_features(self, img: tf.Tensor) -> Sequence[tf.Tensor]:
					    """Computes features from the given image tensor.

					    Args:
					      img: Image tensor.

					    Returns:
					      A list of multi-scale feature maps.
					    """

					  def _get_content_loss(
					      self, x_feats: Sequence[tf.Tensor], y_feats: Sequence[tf.Tensor]
					  ) -> tf.Tensor:
					    """Gets weighted multi-scale content loss.

					    Args:
					      x_feats: Feature maps extracted from the first image batch.
					      y_feats: Feature maps extracted from the second image batch.

					    Returns:
					      A scalar tensor for the content loss.
					    """
					    content_losses = [
					        self._loss_op(x_feat, y_feat) * coef
					        for coef, x_feat, y_feat in zip(self._feature_weight, x_feats, y_feats)
					    ]
					    return tf.math.reduce_sum(content_losses)

					  def _get_style_loss(
					      self, x_feats: Sequence[tf.Tensor], y_feats: Sequence[tf.Tensor]
					  ) -> tf.Tensor:
					    """Gets weighted multi-scale style loss.

					    Args:
					      x_feats: Feature maps extracted from the first image batch.
					      y_feats: Feature maps extracted from the second image batch.

					    Returns:
					      A scalar tensor for the style loss.
					    """
					    style_losses = []
					    for coef, x_feat, y_feat in zip(self._feature_weight, x_feats, y_feats):
					      x_feat_g = _compute_gram_matrix(x_feat)
					      y_feat_g = _compute_gram_matrix(y_feat)
					      style_losses.append(self._loss_op(x_feat_g, y_feat_g) * coef)

					    # Sum the terms of every scale (the whole list, not just the final
					    # term), mirroring _get_content_loss.
					    return tf.math.reduce_sum(style_losses)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class VGGPerceptualLoss(PerceptualLoss):
					  """Perceptual loss based on VGG19 pretrained on the ImageNet dataset.

					  Reference:
					  - [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](
					      https://arxiv.org/abs/1603.08155) (ECCV 2016)

					  Perceptual loss measures high-level perceptual and semantic differences
					  between images.
					  """

					  def __init__(self, loss_weight: Optional[PerceptualLossWeight] = None):
					    """Initializes image quality loss essentials.

					    Fetches the feature extractor through `file_util.DownloadedFiles` and loads
					    it with `model_util.load_keras_model`.

					    Args:
					      loss_weight: Loss weight coefficients.
					    """
					    # Five fixed per-feature weights, one per expected feature map.
					    super().__init__(
					        feature_weight=np.array([0.1, 0.1, 1.0, 1.0, 1.0]),
					        loss_weight=loss_weight,
					    )

					    # Per-channel statistics reshaped to (1, 1, 1, 3) so they broadcast over
					    # the leading three axes of the input.
					    # NOTE(review): presumably the ImageNet RGB mean/std - confirm.
					    self._rgb_mean = tf.reshape(
					        tf.constant([0.485, 0.456, 0.406]), (1, 1, 1, 3)
					    )
					    self._rgb_std = tf.reshape(
					        tf.constant([0.229, 0.224, 0.225]), (1, 1, 1, 3)
					    )

					    downloaded_model = file_util.DownloadedFiles(
					        'vgg_feature_extractor',
					        _VGG_IMAGENET_PERCEPTUAL_MODEL_URL,
					        is_folder=True,
					    )
					    self._vgg19 = model_util.load_keras_model(downloaded_model.get_path())

					  def _compute_features(self, img: tf.Tensor) -> Sequence[tf.Tensor]:
					    """Computes VGG19 features."""
					    # Shift/scale the input by (x + 1) / 2, then standardize per channel.
					    rescaled_img = (img + 1) / 2.0
					    standardized_img = (rescaled_img - self._rgb_mean) / self._rgb_std
					    # NOTE(review): the model is meant to act as a frozen feature extractor,
					    # but nothing in this method blocks gradients - confirm it is frozen.
					    return self._vgg19(standardized_img)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _compute_gram_matrix(feature: tf.Tensor) -> tf.Tensor:
					  """Computes gram matrix for the feature map.

					  The spatial axes are flattened, the channel-by-channel product is taken per
					  batch element, and the result is divided by C * H * W.

					  Args:
					    feature: [B, H, W, C] feature map. H, W and C are read from the static
					      shape.

					  Returns:
					    [B, C, C] gram matrix.
					  """
					  height, width, channels = feature.shape[1:].as_list()
					  num_positions = height * width
					  # [B, H*W, C]
					  flattened = tf.reshape(feature, shape=(-1, num_positions, channels))
					  # [B, C, H*W]
					  flattened_t = tf.transpose(flattened, perm=[0, 2, 1])
					  gram = tf.matmul(flattened_t, flattened)
					  return gram / (channels * num_positions)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -13,7 +13,9 @@
 | 
				
			||||||
# limitations under the License.
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import math
 | 
					import math
 | 
				
			||||||
from typing import Optional
 | 
					import tempfile
 | 
				
			||||||
 | 
					from typing import Dict, Optional, Sequence
 | 
				
			||||||
 | 
					from unittest import mock as unittest_mock
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from absl.testing import parameterized
 | 
					from absl.testing import parameterized
 | 
				
			||||||
import tensorflow as tf
 | 
					import tensorflow as tf
 | 
				
			||||||
| 
						 | 
					@ -21,7 +23,7 @@ import tensorflow as tf
 | 
				
			||||||
from mediapipe.model_maker.python.core.utils import loss_functions
 | 
					from mediapipe.model_maker.python.core.utils import loss_functions
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase):
 | 
					class FocalLossTest(tf.test.TestCase, parameterized.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @parameterized.named_parameters(
 | 
					  @parameterized.named_parameters(
 | 
				
			||||||
      dict(testcase_name='no_sample_weight', sample_weight=None),
 | 
					      dict(testcase_name='no_sample_weight', sample_weight=None),
 | 
				
			||||||
| 
						 | 
					@ -99,5 +101,182 @@ class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase):
 | 
				
			||||||
    self.assertNear(loss, expected_loss, 1e-4)
 | 
					    self.assertNear(loss, expected_loss, 1e-4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MockPerceptualLoss(loss_functions.PerceptualLoss):
					  """Concrete PerceptualLoss stand-in that fabricates features for tests."""

					  def __init__(
					      self,
					      use_mock_loss_op: bool = False,
					      feature_weight: Optional[Sequence[float]] = None,
					      loss_weight: Optional[loss_functions.PerceptualLossWeight] = None,
					  ):
					    """Builds the mock, optionally swapping in a signed-mean loss op.

					    Args:
					      use_mock_loss_op: When True, replaces the base class loss op with the
					        mean of the signed difference `x - y`.
					      feature_weight: Forwarded to the base class.
					      loss_weight: Forwarded to the base class.
					    """
					    super().__init__(feature_weight=feature_weight, loss_weight=loss_weight)
					    if not use_mock_loss_op:
					      return
					    self._loss_op = lambda x, y: tf.math.reduce_mean(x - y)

					  def _compute_features(self, img: tf.Tensor) -> Sequence[tf.Tensor]:
					    """Returns five references to one random (1, 8, 8, 3) tensor; ignores img."""
					    # A single draw is shared by all five entries.
					    shared_feature = tf.random.normal(shape=(1, 8, 8, 3))
					    return [shared_feature for _ in range(5)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PerceptualLossTest(tf.test.TestCase, parameterized.TestCase):
					  """Exercises the PerceptualLoss base class through MockPerceptualLoss."""

					  def setUp(self):
					    super().setUp()
					    # The mock ignores the image content, so any tensors work as inputs.
					    self._img1 = tf.fill(dims=(8, 8), value=0.2)
					    self._img2 = tf.fill(dims=(8, 8), value=0.8)

					  def test_invalid_feature_weight_raise_value_error(self):
					    """Two weights for five mock features must raise ValueError."""
					    expected_message = (
					        'Input feature weight length 2 is smaller than feature length 5'
					    )
					    with self.assertRaisesRegex(ValueError, expected_message):
					      MockPerceptualLoss(feature_weight=[1.0, 2.0])(
					          img1=self._img1, img2=self._img2
					      )

					  @parameterized.named_parameters(
					      {
					          'testcase_name': 'default_loss_weight_and_loss_op',
					          'use_mock_loss_op': False,
					          'feature_weight': None,
					          'loss_weight': None,
					          'loss_values': {
					              'style_loss': 0.032839,
					              'content_loss': 5.639870,
					          },
					      },
					      {
					          'testcase_name': 'style_loss_weight_is_0_default_loss_op',
					          'use_mock_loss_op': False,
					          'feature_weight': None,
					          'loss_weight': loss_functions.PerceptualLossWeight(style=0),
					          'loss_values': {
					              'style_loss': 0,
					              'content_loss': 5.639870,
					          },
					      },
					      {
					          'testcase_name': 'content_loss_weight_is_0_default_loss_op',
					          'use_mock_loss_op': False,
					          'feature_weight': None,
					          'loss_weight': loss_functions.PerceptualLossWeight(content=0),
					          'loss_values': {
					              'style_loss': 0.032839,
					              'content_loss': 0,
					          },
					      },
					      {
					          'testcase_name': 'customized_loss_weight_default_loss_op',
					          'use_mock_loss_op': False,
					          'feature_weight': None,
					          'loss_weight': loss_functions.PerceptualLossWeight(
					              style=1.0, content=2.0
					          ),
					          'loss_values': {'style_loss': 0.032839, 'content_loss': 11.279739},
					      },
					      {
					          'testcase_name': (
					              'customized_feature_weight_and_loss_weight_default_loss_op'
					          ),
					          'use_mock_loss_op': False,
					          'feature_weight': [1.0, 2.0, 3.0, 4.0, 5.0],
					          'loss_weight': loss_functions.PerceptualLossWeight(
					              style=1.0, content=2.0
					          ),
					          'loss_values': {'style_loss': 0.164193, 'content_loss': 33.839218},
					      },
					      {
					          'testcase_name': 'no_loss_change_if_extra_feature_weight_provided',
					          'use_mock_loss_op': False,
					          'feature_weight': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
					          'loss_weight': loss_functions.PerceptualLossWeight(
					              style=1.0, content=2.0
					          ),
					          'loss_values': {
					              'style_loss': 0.164193,
					              'content_loss': 33.839218,
					          },
					      },
					      {
					          'testcase_name': 'customized_loss_weight_custom_loss_op',
					          'use_mock_loss_op': True,
					          'feature_weight': None,
					          'loss_weight': loss_functions.PerceptualLossWeight(
					              style=1.0, content=2.0
					          ),
					          'loss_values': {'style_loss': 0.000395, 'content_loss': -1.533469},
					      },
					  )
					  def test_weighted_perceptul_loss(
					      self,
					      use_mock_loss_op: bool,
					      feature_weight: Sequence[float],
					      loss_weight: loss_functions.PerceptualLossWeight,
					      loss_values: Dict[str, float],
					  ):
					    """Compares both returned loss terms against recorded values."""
					    model = MockPerceptualLoss(
					        use_mock_loss_op=use_mock_loss_op,
					        feature_weight=feature_weight,
					        loss_weight=loss_weight,
					    )
					    losses = model(img1=self._img1, img2=self._img2)
					    self.assertEqual(list(losses.keys()), ['style_loss', 'content_loss'])
					    self.assertNear(losses['style_loss'], loss_values['style_loss'], 1e-4)
					    self.assertNear(losses['content_loss'], loss_values['content_loss'], 1e-4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class VGGPerceptualLossTest(tf.test.TestCase, parameterized.TestCase):
					  """Runs VGGPerceptualLoss end to end against recorded loss values."""

					  def setUp(self):
					    super().setUp()
					    # Mock tempfile.gettempdir() to be unique for each test to avoid race
					    # condition when downloading model since these tests may run in parallel.
					    gettempdir_patcher = unittest_mock.patch.object(
					        tempfile,
					        'gettempdir',
					        return_value=self.create_tempdir(),
					        autospec=True,
					    )
					    self.mock_gettempdir = gettempdir_patcher.start()
					    self.addCleanup(gettempdir_patcher.stop)
					    self._img1 = tf.fill(dims=(1, 256, 256, 3), value=0.1)
					    self._img2 = tf.fill(dims=(1, 256, 256, 3), value=0.9)

					  @parameterized.named_parameters(
					      {
					          'testcase_name': 'default_loss_weight',
					          'loss_weight': None,
					          'loss_values': {
					              'style_loss': 5.8363257e-06,
					              'content_loss': 1.7016045,
					          },
					      },
					      {
					          'testcase_name': 'customized_loss_weight',
					          'loss_weight': loss_functions.PerceptualLossWeight(
					              style=10.0, content=20.0
					          ),
					          'loss_values': {
					              'style_loss': 5.8363257e-05,
					              'content_loss': 34.03208,
					          },
					      },
					  )
					  def test_vgg_perceptual_loss(self, loss_weight, loss_values):
					    """Checks each loss term to within 1e-5 of its recorded value, relatively."""
					    vgg_loss = loss_functions.VGGPerceptualLoss(loss_weight=loss_weight)
					    losses = vgg_loss(img1=self._img1, img2=self._img2)
					    self.assertEqual(list(losses.keys()), ['style_loss', 'content_loss'])
					    # The tolerance scales with the expected magnitude of each term.
					    for loss_name in ('style_loss', 'content_loss'):
					      self.assertNear(
					          losses[loss_name],
					          loss_values[loss_name],
					          loss_values[loss_name] / 1e5,
					      )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == '__main__':
 | 
					if __name__ == '__main__':
 | 
				
			||||||
  tf.test.main()
 | 
					  tf.test.main()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user