Add MobileNetV2_I320 and MobileNetMultiHWAVG_I384 to support larger input image sizes.
PiperOrigin-RevId: 544393692
This commit is contained in:
		
							parent
							
								
									8278dbc38f
								
							
						
					
					
						commit
						0bb4ee8941
					
				| 
						 | 
					@ -74,8 +74,8 @@ class ObjectDetectorModel(tf.keras.Model):
 | 
				
			||||||
      generator_config: configs.retinanet.DetectionGenerator = configs.retinanet.DetectionGenerator(),
 | 
					      generator_config: configs.retinanet.DetectionGenerator = configs.retinanet.DetectionGenerator(),
 | 
				
			||||||
  ) -> configs.retinanet.RetinaNet:
 | 
					  ) -> configs.retinanet.RetinaNet:
 | 
				
			||||||
    model_config = configs.retinanet.RetinaNet(
 | 
					    model_config = configs.retinanet.RetinaNet(
 | 
				
			||||||
        min_level=3,
 | 
					        min_level=self._model_spec.min_level,
 | 
				
			||||||
        max_level=7,
 | 
					        max_level=self._model_spec.max_level,
 | 
				
			||||||
        num_classes=self._num_classes,
 | 
					        num_classes=self._num_classes,
 | 
				
			||||||
        input_size=self._model_spec.input_image_shape,
 | 
					        input_size=self._model_spec.input_image_shape,
 | 
				
			||||||
        anchor=configs.retinanet.Anchor(
 | 
					        anchor=configs.retinanet.Anchor(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -20,18 +20,30 @@ from typing import List
 | 
				
			||||||
from mediapipe.model_maker.python.core.utils import file_util
 | 
					from mediapipe.model_maker.python.core.utils import file_util
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
MOBILENET_V2_FILES = file_util.DownloadedFiles(
 | 
					MOBILENET_V2_I256_FILES = file_util.DownloadedFiles(
 | 
				
			||||||
    'object_detector/mobilenetv2',
 | 
					    'object_detector/mobilenetv2_i256',
 | 
				
			||||||
    'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i256_ckpt.tar.gz',
 | 
					    'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i256_ckpt.tar.gz',
 | 
				
			||||||
    is_folder=True,
 | 
					    is_folder=True,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					MOBILENET_V2_I320_FILES = file_util.DownloadedFiles(
 | 
				
			||||||
 | 
					    'object_detector/mobilenetv2_i320',
 | 
				
			||||||
 | 
					    'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i320_ckpt.tar.gz',
 | 
				
			||||||
 | 
					    is_folder=True,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
MOBILENET_MULTI_AVG_FILES = file_util.DownloadedFiles(
 | 
					MOBILENET_MULTI_AVG_FILES = file_util.DownloadedFiles(
 | 
				
			||||||
    'object_detector/mobilenetmultiavg',
 | 
					    'object_detector/mobilenetmultiavg',
 | 
				
			||||||
    'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv3.5_ssd_coco/mobilenetv3.5_ssd_i256_ckpt.tar.gz',
 | 
					    'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv3.5_ssd_coco/mobilenetv3.5_ssd_i256_ckpt.tar.gz',
 | 
				
			||||||
    is_folder=True,
 | 
					    is_folder=True,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					MOBILENET_MULTI_AVG_I384_FILES = file_util.DownloadedFiles(
 | 
				
			||||||
 | 
					    'object_detector/mobilenetmultiavg_i384',
 | 
				
			||||||
 | 
					    'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv3.5_ssd_i384_ckpt.tar.gz',
 | 
				
			||||||
 | 
					    is_folder=True,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@dataclasses.dataclass
 | 
					@dataclasses.dataclass
 | 
				
			||||||
class ModelSpec(object):
 | 
					class ModelSpec(object):
 | 
				
			||||||
| 
						 | 
					@ -48,30 +60,66 @@ class ModelSpec(object):
 | 
				
			||||||
  input_image_shape: List[int]
 | 
					  input_image_shape: List[int]
 | 
				
			||||||
  model_id: str
 | 
					  model_id: str
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  # Model Config values
 | 
				
			||||||
 | 
					  min_level: int
 | 
				
			||||||
 | 
					  max_level: int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
mobilenet_v2_spec = functools.partial(
 | 
					
 | 
				
			||||||
 | 
					mobilenet_v2_i256_spec = functools.partial(
 | 
				
			||||||
    ModelSpec,
 | 
					    ModelSpec,
 | 
				
			||||||
    downloaded_files=MOBILENET_V2_FILES,
 | 
					    downloaded_files=MOBILENET_V2_I256_FILES,
 | 
				
			||||||
    checkpoint_name='ckpt-277200',
 | 
					    checkpoint_name='ckpt-277200',
 | 
				
			||||||
    input_image_shape=[256, 256, 3],
 | 
					    input_image_shape=[256, 256, 3],
 | 
				
			||||||
    model_id='MobileNetV2',
 | 
					    model_id='MobileNetV2',
 | 
				
			||||||
 | 
					    min_level=3,
 | 
				
			||||||
 | 
					    max_level=7,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
mobilenet_multi_avg_spec = functools.partial(
 | 
					mobilenet_v2_i320_spec = functools.partial(
 | 
				
			||||||
 | 
					    ModelSpec,
 | 
				
			||||||
 | 
					    downloaded_files=MOBILENET_V2_I320_FILES,
 | 
				
			||||||
 | 
					    checkpoint_name='ckpt-277200',
 | 
				
			||||||
 | 
					    input_image_shape=[320, 320, 3],
 | 
				
			||||||
 | 
					    model_id='MobileNetV2',
 | 
				
			||||||
 | 
					    min_level=3,
 | 
				
			||||||
 | 
					    max_level=6,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					mobilenet_multi_avg_i256_spec = functools.partial(
 | 
				
			||||||
    ModelSpec,
 | 
					    ModelSpec,
 | 
				
			||||||
    downloaded_files=MOBILENET_MULTI_AVG_FILES,
 | 
					    downloaded_files=MOBILENET_MULTI_AVG_FILES,
 | 
				
			||||||
    checkpoint_name='ckpt-277200',
 | 
					    checkpoint_name='ckpt-277200',
 | 
				
			||||||
    input_image_shape=[256, 256, 3],
 | 
					    input_image_shape=[256, 256, 3],
 | 
				
			||||||
    model_id='MobileNetMultiAVG',
 | 
					    model_id='MobileNetMultiAVG',
 | 
				
			||||||
 | 
					    min_level=3,
 | 
				
			||||||
 | 
					    max_level=7,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					mobilenet_multi_avg_i384_spec = functools.partial(
 | 
				
			||||||
 | 
					    ModelSpec,
 | 
				
			||||||
 | 
					    downloaded_files=MOBILENET_MULTI_AVG_I384_FILES,
 | 
				
			||||||
 | 
					    checkpoint_name='ckpt-277200',
 | 
				
			||||||
 | 
					    input_image_shape=[384, 384, 3],
 | 
				
			||||||
 | 
					    model_id='MobileNetMultiAVG',
 | 
				
			||||||
 | 
					    min_level=3,
 | 
				
			||||||
 | 
					    max_level=7,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@enum.unique
 | 
					@enum.unique
 | 
				
			||||||
class SupportedModels(enum.Enum):
 | 
					class SupportedModels(enum.Enum):
 | 
				
			||||||
  """Predefined object detector model specs supported by Model Maker."""
 | 
					  """Predefined object detector model specs supported by Model Maker.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  MOBILENET_V2 = mobilenet_v2_spec
 | 
					  Supported models include the following:
 | 
				
			||||||
  MOBILENET_MULTI_AVG = mobilenet_multi_avg_spec
 | 
					  - MOBILENET_V2: MobileNetV2 256x256 input
 | 
				
			||||||
 | 
					  - MOBILENET_V2_I320: MobileNetV2 320x320 input
 | 
				
			||||||
 | 
					  - MOBILENET_MULTI_AVG: MobileNet-MultiHW-AVG 256x256 input
 | 
				
			||||||
 | 
					  - MOBILENET_MULTI_AVG_I384: MobileNet-MultiHW-AVG 384x384 input
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					  MOBILENET_V2 = mobilenet_v2_i256_spec
 | 
				
			||||||
 | 
					  MOBILENET_V2_I320 = mobilenet_v2_i320_spec
 | 
				
			||||||
 | 
					  MOBILENET_MULTI_AVG = mobilenet_multi_avg_i256_spec
 | 
				
			||||||
 | 
					  MOBILENET_MULTI_AVG_I384 = mobilenet_multi_avg_i384_spec
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @classmethod
 | 
					  @classmethod
 | 
				
			||||||
  def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
 | 
					  def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -395,7 +395,7 @@ class ObjectDetector(classifier.Classifier):
 | 
				
			||||||
  ) -> tf.keras.optimizers.Optimizer:
 | 
					  ) -> tf.keras.optimizers.Optimizer:
 | 
				
			||||||
    """Creates an optimizer with learning rate schedule for regular training.
 | 
					    """Creates an optimizer with learning rate schedule for regular training.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Uses Keras PiecewiseConstantDecay schedule by default.
 | 
					    Uses Keras CosineDecay schedule by default.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Args:
 | 
					    Args:
 | 
				
			||||||
      steps_per_epoch: Steps per epoch to calculate the step boundaries from the
 | 
					      steps_per_epoch: Steps per epoch to calculate the step boundaries from the
 | 
				
			||||||
| 
						 | 
					@ -404,6 +404,8 @@ class ObjectDetector(classifier.Classifier):
 | 
				
			||||||
    Returns:
 | 
					    Returns:
 | 
				
			||||||
      A tf.keras.optimizers.Optimizer for model training.
 | 
					      A tf.keras.optimizers.Optimizer for model training.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					    total_steps = steps_per_epoch * self._hparams.epochs
 | 
				
			||||||
 | 
					    warmup_steps = int(total_steps * 0.1)
 | 
				
			||||||
    init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
 | 
					    init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
 | 
				
			||||||
    decay_epochs = (
 | 
					    decay_epochs = (
 | 
				
			||||||
        self._hparams.cosine_decay_epochs
 | 
					        self._hparams.cosine_decay_epochs
 | 
				
			||||||
| 
						 | 
					@ -415,6 +417,11 @@ class ObjectDetector(classifier.Classifier):
 | 
				
			||||||
        steps_per_epoch * decay_epochs,
 | 
					        steps_per_epoch * decay_epochs,
 | 
				
			||||||
        self._hparams.cosine_decay_alpha,
 | 
					        self._hparams.cosine_decay_alpha,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					    learning_rate = model_util.WarmUp(
 | 
				
			||||||
 | 
					        initial_learning_rate=init_lr,
 | 
				
			||||||
 | 
					        decay_schedule_fn=learning_rate,
 | 
				
			||||||
 | 
					        warmup_steps=warmup_steps,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
    return tf.keras.optimizers.experimental.SGD(
 | 
					    return tf.keras.optimizers.experimental.SGD(
 | 
				
			||||||
        learning_rate=learning_rate, momentum=0.9
 | 
					        learning_rate=learning_rate, momentum=0.9
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -32,8 +32,8 @@ class Preprocessor(object):
 | 
				
			||||||
    self._mean_norm = model_spec.mean_norm
 | 
					    self._mean_norm = model_spec.mean_norm
 | 
				
			||||||
    self._stddev_norm = model_spec.stddev_norm
 | 
					    self._stddev_norm = model_spec.stddev_norm
 | 
				
			||||||
    self._output_size = model_spec.input_image_shape[:2]
 | 
					    self._output_size = model_spec.input_image_shape[:2]
 | 
				
			||||||
    self._min_level = 3
 | 
					    self._min_level = model_spec.min_level
 | 
				
			||||||
    self._max_level = 7
 | 
					    self._max_level = model_spec.max_level
 | 
				
			||||||
    self._num_scales = 3
 | 
					    self._num_scales = 3
 | 
				
			||||||
    self._aspect_ratios = [0.5, 1, 2]
 | 
					    self._aspect_ratios = [0.5, 1, 2]
 | 
				
			||||||
    self._anchor_size = 3
 | 
					    self._anchor_size = 3
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user