Add MobileNetV2_I320 and MobileNetMultiHWAVG_I384 to support larger input image sizes.
PiperOrigin-RevId: 544393692
commit 0bb4ee8941
parent 8278dbc38f
@@ -74,8 +74,8 @@ class ObjectDetectorModel(tf.keras.Model):
       generator_config: configs.retinanet.DetectionGenerator = configs.retinanet.DetectionGenerator(),
   ) -> configs.retinanet.RetinaNet:
     model_config = configs.retinanet.RetinaNet(
-        min_level=3,
-        max_level=7,
+        min_level=self._model_spec.min_level,
+        max_level=self._model_spec.max_level,
         num_classes=self._num_classes,
         input_size=self._model_spec.input_image_shape,
         anchor=configs.retinanet.Anchor(
@@ -20,18 +20,30 @@ from typing import List
 from mediapipe.model_maker.python.core.utils import file_util
 
 
-MOBILENET_V2_FILES = file_util.DownloadedFiles(
-    'object_detector/mobilenetv2',
+MOBILENET_V2_I256_FILES = file_util.DownloadedFiles(
+    'object_detector/mobilenetv2_i256',
     'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i256_ckpt.tar.gz',
     is_folder=True,
 )
 
+MOBILENET_V2_I320_FILES = file_util.DownloadedFiles(
+    'object_detector/mobilenetv2_i320',
+    'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i320_ckpt.tar.gz',
+    is_folder=True,
+)
+
 MOBILENET_MULTI_AVG_FILES = file_util.DownloadedFiles(
     'object_detector/mobilenetmultiavg',
     'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv3.5_ssd_coco/mobilenetv3.5_ssd_i256_ckpt.tar.gz',
     is_folder=True,
 )
 
+MOBILENET_MULTI_AVG_I384_FILES = file_util.DownloadedFiles(
+    'object_detector/mobilenetmultiavg_i384',
+    'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv3.5_ssd_i384_ckpt.tar.gz',
+    is_folder=True,
+)
+
 
 @dataclasses.dataclass
 class ModelSpec(object):
@@ -48,30 +60,66 @@ class ModelSpec(object):
   input_image_shape: List[int]
   model_id: str
 
+  # Model Config values
+  min_level: int
+  max_level: int
 
-mobilenet_v2_spec = functools.partial(
+
+mobilenet_v2_i256_spec = functools.partial(
     ModelSpec,
-    downloaded_files=MOBILENET_V2_FILES,
+    downloaded_files=MOBILENET_V2_I256_FILES,
     checkpoint_name='ckpt-277200',
     input_image_shape=[256, 256, 3],
     model_id='MobileNetV2',
+    min_level=3,
+    max_level=7,
+)
+
+mobilenet_v2_i320_spec = functools.partial(
+    ModelSpec,
+    downloaded_files=MOBILENET_V2_I320_FILES,
+    checkpoint_name='ckpt-277200',
+    input_image_shape=[320, 320, 3],
+    model_id='MobileNetV2',
+    min_level=3,
+    max_level=6,
 )
 
-mobilenet_multi_avg_spec = functools.partial(
+mobilenet_multi_avg_i256_spec = functools.partial(
     ModelSpec,
     downloaded_files=MOBILENET_MULTI_AVG_FILES,
     checkpoint_name='ckpt-277200',
     input_image_shape=[256, 256, 3],
     model_id='MobileNetMultiAVG',
+    min_level=3,
+    max_level=7,
+)
+
+mobilenet_multi_avg_i384_spec = functools.partial(
+    ModelSpec,
+    downloaded_files=MOBILENET_MULTI_AVG_I384_FILES,
+    checkpoint_name='ckpt-277200',
+    input_image_shape=[384, 384, 3],
+    model_id='MobileNetMultiAVG',
+    min_level=3,
+    max_level=7,
 )
 
 
 @enum.unique
 class SupportedModels(enum.Enum):
-  """Predefined object detector model specs supported by Model Maker."""
+  """Predefined object detector model specs supported by Model Maker.
 
-  MOBILENET_V2 = mobilenet_v2_spec
-  MOBILENET_MULTI_AVG = mobilenet_multi_avg_spec
+  Supported models include the following:
+  - MOBILENET_V2: MobileNetV2 256x256 input
+  - MOBILENET_V2_I320: MobileNetV2 320x320 input
+  - MOBILENET_MULTI_AVG: MobileNet-MultiHW-AVG 256x256 input
+  - MOBILENET_MULTI_AVG_I384: MobileNet-MultiHW-AVG 384x384 input
+  """
+  MOBILENET_V2 = mobilenet_v2_i256_spec
+  MOBILENET_V2_I320 = mobilenet_v2_i320_spec
+  MOBILENET_MULTI_AVG = mobilenet_multi_avg_i256_spec
+  MOBILENET_MULTI_AVG_I384 = mobilenet_multi_avg_i384_spec
 
   @classmethod
   def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
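For reference, a minimal usage sketch of the new specs; it assumes the documented mediapipe_model_maker object_detector API, and the dataset paths below are placeholders:

from mediapipe_model_maker import object_detector

train_data = object_detector.Dataset.from_coco_folder(
    'path/to/train', cache_dir='/tmp/od_data/train')
validation_data = object_detector.Dataset.from_coco_folder(
    'path/to/validation', cache_dir='/tmp/od_data/validation')

# MOBILENET_V2_I320 (320x320) and MOBILENET_MULTI_AVG_I384 (384x384) are the
# enum members introduced by this change.
options = object_detector.ObjectDetectorOptions(
    supported_model=object_detector.SupportedModels.MOBILENET_V2_I320,
    hparams=object_detector.HParams(export_dir='exported_model'),
)
model = object_detector.ObjectDetector.create(
    train_data=train_data,
    validation_data=validation_data,
    options=options,
)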
@@ -395,7 +395,7 @@ class ObjectDetector(classifier.Classifier):
   ) -> tf.keras.optimizers.Optimizer:
     """Creates an optimizer with learning rate schedule for regular training.
 
-    Uses Keras PiecewiseConstantDecay schedule by default.
+    Uses Keras CosineDecay schedule by default.
 
     Args:
       steps_per_epoch: Steps per epoch to calculate the step boundaries from the
@@ -404,6 +404,8 @@ class ObjectDetector(classifier.Classifier):
     Returns:
       A tf.keras.optimizer.Optimizer for model training.
     """
     total_steps = steps_per_epoch * self._hparams.epochs
+    warmup_steps = int(total_steps * 0.1)
+    init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
     decay_epochs = (
         self._hparams.cosine_decay_epochs
@@ -415,6 +417,11 @@ class ObjectDetector(classifier.Classifier):
         steps_per_epoch * decay_epochs,
         self._hparams.cosine_decay_alpha,
     )
+    learning_rate = model_util.WarmUp(
+        initial_learning_rate=init_lr,
+        decay_schedule_fn=learning_rate,
+        warmup_steps=warmup_steps,
+    )
     return tf.keras.optimizers.experimental.SGD(
         learning_rate=learning_rate, momentum=0.9
     )
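Taken together, the optimizer changes swap the stepwise schedule for a linear warmup over the first 10% of training followed by cosine decay, with the base learning rate scaled linearly by batch size (learning_rate * batch_size / 256). The sketch below rebuilds an equivalent schedule from stock Keras pieces; model_util.WarmUp is MediaPipe's own wrapper, so the LinearWarmUp class here is an illustrative stand-in, and the decay horizon is simplified to the full run (the real code derives it from hparams.cosine_decay_epochs):

import tensorflow as tf


class LinearWarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Ramps the learning rate linearly from 0, then hands off to a decay schedule."""

  def __init__(self, initial_learning_rate, decay_schedule_fn, warmup_steps):
    super().__init__()
    self.initial_learning_rate = initial_learning_rate
    self.decay_schedule_fn = decay_schedule_fn
    self.warmup_steps = warmup_steps

  def __call__(self, step):
    step = tf.cast(step, tf.float32)
    warmup_lr = self.initial_learning_rate * step / float(self.warmup_steps)
    return tf.cond(
        step < self.warmup_steps,
        lambda: warmup_lr,
        lambda: self.decay_schedule_fn(step),
    )


# Illustrative numbers only; the real values come from hparams and the dataset size.
steps_per_epoch, epochs, batch_size, base_lr = 100, 30, 32, 0.3
total_steps = steps_per_epoch * epochs
warmup_steps = int(total_steps * 0.1)   # first 10% of training warms up
init_lr = base_lr * batch_size / 256    # linear batch-size scaling
cosine = tf.keras.optimizers.schedules.CosineDecay(
    init_lr, decay_steps=total_steps, alpha=0.0)
learning_rate = LinearWarmUp(init_lr, cosine, warmup_steps)
optimizer = tf.keras.optimizers.experimental.SGD(
    learning_rate=learning_rate, momentum=0.9)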
@@ -32,8 +32,8 @@ class Preprocessor(object):
     self._mean_norm = model_spec.mean_norm
     self._stddev_norm = model_spec.stddev_norm
     self._output_size = model_spec.input_image_shape[:2]
-    self._min_level = 3
-    self._max_level = 7
+    self._min_level = model_spec.min_level
+    self._max_level = model_spec.max_level
     self._num_scales = 3
     self._aspect_ratios = [0.5, 1, 2]
     self._anchor_size = 3
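min_level and max_level select the feature-pyramid levels used for anchor generation; level L has stride 2**L, so a larger input or a different level range changes the anchor grids directly. A small sketch for the new MOBILENET_V2_I320 spec (320x320 input, levels 3-6) using the num_scales and aspect_ratios values above; the per-cell anchor count of num_scales * len(aspect_ratios) follows the standard RetinaNet anchor generator:

# Hypothetical walkthrough of the anchor grids implied by the I320 spec.
input_size = 320
min_level, max_level = 3, 6
num_scales, aspect_ratios = 3, [0.5, 1, 2]

for level in range(min_level, max_level + 1):
  stride = 2 ** level
  fm = input_size // stride                              # feature-map side length
  anchors = fm * fm * num_scales * len(aspect_ratios)    # 9 anchors per cell here
  print(f'level {level}: stride {stride}, {fm}x{fm} grid, {anchors} anchors')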