Add MobileNetV2_I320 and MobileNetMultiHWAVG_I384 to support larger input image sizes.

PiperOrigin-RevId: 544393692
commit 0bb4ee8941 (parent 8278dbc38f)
MediaPipe Team, 2023-06-29 10:22:18 -07:00, committed by Copybara-Service
4 changed files with 68 additions and 13 deletions


@@ -74,8 +74,8 @@ class ObjectDetectorModel(tf.keras.Model):
       generator_config: configs.retinanet.DetectionGenerator = configs.retinanet.DetectionGenerator(),
   ) -> configs.retinanet.RetinaNet:
     model_config = configs.retinanet.RetinaNet(
-        min_level=3,
-        max_level=7,
+        min_level=self._model_spec.min_level,
+        max_level=self._model_spec.max_level,
         num_classes=self._num_classes,
         input_size=self._model_spec.input_image_shape,
         anchor=configs.retinanet.Anchor(
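
For context: min_level and max_level select which feature pyramid (FPN) levels RetinaNet builds its heads and anchors on, and the grid at level l has stride 2**l. Wiring them through the model spec lets the 320x320 variant stop at level 6 instead of 7. A quick illustrative computation (not MediaPipe code) of the per-level grid sizes:

# Illustrative only: per-level feature grid sizes for a square input.
# Level l has stride 2**l, so an S x S input yields ceil(S / 2**l) cells per side.
import math

def pyramid_sizes(input_size: int, min_level: int, max_level: int) -> dict:
  return {
      level: math.ceil(input_size / 2**level)
      for level in range(min_level, max_level + 1)
  }

print(pyramid_sizes(256, 3, 7))  # {3: 32, 4: 16, 5: 8, 6: 4, 7: 2}
print(pyramid_sizes(320, 3, 6))  # {3: 40, 4: 20, 5: 10, 6: 5}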


@@ -20,18 +20,30 @@ from typing import List
 
 from mediapipe.model_maker.python.core.utils import file_util
 
-MOBILENET_V2_FILES = file_util.DownloadedFiles(
-    'object_detector/mobilenetv2',
+MOBILENET_V2_I256_FILES = file_util.DownloadedFiles(
+    'object_detector/mobilenetv2_i256',
     'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i256_ckpt.tar.gz',
     is_folder=True,
 )
 
+MOBILENET_V2_I320_FILES = file_util.DownloadedFiles(
+    'object_detector/mobilenetv2_i320',
+    'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i320_ckpt.tar.gz',
+    is_folder=True,
+)
+
 MOBILENET_MULTI_AVG_FILES = file_util.DownloadedFiles(
     'object_detector/mobilenetmultiavg',
     'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv3.5_ssd_coco/mobilenetv3.5_ssd_i256_ckpt.tar.gz',
     is_folder=True,
 )
+
+MOBILENET_MULTI_AVG_I384_FILES = file_util.DownloadedFiles(
+    'object_detector/mobilenetmultiavg_i384',
+    'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv3.5_ssd_i384_ckpt.tar.gz',
+    is_folder=True,
+)
 
 
 @dataclasses.dataclass
 class ModelSpec(object):
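
Side note: DownloadedFiles entries like these are fetched and cached on first use. Assuming file_util.DownloadedFiles.get_path() as the accessor (an assumption, not a documented contract), resolving a spec's checkpoint directory would look roughly like:

# Assumed accessor; illustrative only.
import os

ckpt_dir = MOBILENET_V2_I320_FILES.get_path()  # downloads and extracts on first use
checkpoint_path = os.path.join(ckpt_dir, 'ckpt-277200')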
@@ -48,30 +60,66 @@ class ModelSpec(object):
   input_image_shape: List[int]
   model_id: str
+  # Model Config values
+  min_level: int
+  max_level: int
 
 
-mobilenet_v2_spec = functools.partial(
+mobilenet_v2_i256_spec = functools.partial(
     ModelSpec,
-    downloaded_files=MOBILENET_V2_FILES,
+    downloaded_files=MOBILENET_V2_I256_FILES,
     checkpoint_name='ckpt-277200',
     input_image_shape=[256, 256, 3],
     model_id='MobileNetV2',
+    min_level=3,
+    max_level=7,
 )
 
-mobilenet_multi_avg_spec = functools.partial(
+mobilenet_v2_i320_spec = functools.partial(
+    ModelSpec,
+    downloaded_files=MOBILENET_V2_I320_FILES,
+    checkpoint_name='ckpt-277200',
+    input_image_shape=[320, 320, 3],
+    model_id='MobileNetV2',
+    min_level=3,
+    max_level=6,
+)
+
+mobilenet_multi_avg_i256_spec = functools.partial(
     ModelSpec,
     downloaded_files=MOBILENET_MULTI_AVG_FILES,
     checkpoint_name='ckpt-277200',
     input_image_shape=[256, 256, 3],
     model_id='MobileNetMultiAVG',
+    min_level=3,
+    max_level=7,
+)
+
+mobilenet_multi_avg_i384_spec = functools.partial(
+    ModelSpec,
+    downloaded_files=MOBILENET_MULTI_AVG_I384_FILES,
+    checkpoint_name='ckpt-277200',
+    input_image_shape=[384, 384, 3],
+    model_id='MobileNetMultiAVG',
+    min_level=3,
+    max_level=7,
 )
 
 
 @enum.unique
 class SupportedModels(enum.Enum):
-  """Predefined object detector model specs supported by Model Maker."""
+  """Predefined object detector model specs supported by Model Maker.
+
+  Supported models include the following:
+  - MOBILENET_V2: MobileNetV2 256x256 input
+  - MOBILENET_V2_I320: MobileNetV2 320x320 input
+  - MOBILENET_MULTI_AVG: MobileNet-MultiHW-AVG 256x256 input
+  - MOBILENET_MULTI_AVG_I384: MobileNet-MultiHW-AVG 384x384 input
+  """
 
-  MOBILENET_V2 = mobilenet_v2_spec
-  MOBILENET_MULTI_AVG = mobilenet_multi_avg_spec
+  MOBILENET_V2 = mobilenet_v2_i256_spec
+  MOBILENET_V2_I320 = mobilenet_v2_i320_spec
+  MOBILENET_MULTI_AVG = mobilenet_multi_avg_i256_spec
+  MOBILENET_MULTI_AVG_I384 = mobilenet_multi_avg_i384_spec
 
   @classmethod
   def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
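
With these specs registered, selecting a larger-input model goes through the usual Model Maker entry points. A hedged usage sketch (dataset paths are placeholders; API names follow the public mediapipe_model_maker package):

# Sketch: training with the new 384x384 MobileNet-MultiHW-AVG spec.
from mediapipe_model_maker import object_detector

train_data = object_detector.Dataset.from_coco_folder('path/to/train')  # placeholder
validation_data = object_detector.Dataset.from_coco_folder('path/to/val')  # placeholder

options = object_detector.ObjectDetectorOptions(
    supported_model=object_detector.SupportedModels.MOBILENET_MULTI_AVG_I384,
    hparams=object_detector.HParams(export_dir='exported_model'),
)
model = object_detector.ObjectDetector.create(
    train_data=train_data,
    validation_data=validation_data,
    options=options,
)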


@@ -395,7 +395,7 @@ class ObjectDetector(classifier.Classifier):
   ) -> tf.keras.optimizers.Optimizer:
     """Creates an optimizer with learning rate schedule for regular training.
 
-    Uses Keras PiecewiseConstantDecay schedule by default.
+    Uses Keras CosineDecay schedule by default.
 
     Args:
       steps_per_epoch: Steps per epoch to calculate the step boundaries from the
@@ -404,6 +404,8 @@ class ObjectDetector(classifier.Classifier):
     Returns:
       A tf.keras.optimizer.Optimizer for model training.
     """
+    total_steps = steps_per_epoch * self._hparams.epochs
+    warmup_steps = int(total_steps * 0.1)
     init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
     decay_epochs = (
         self._hparams.cosine_decay_epochs
@@ -415,6 +417,11 @@ class ObjectDetector(classifier.Classifier):
         steps_per_epoch * decay_epochs,
         self._hparams.cosine_decay_alpha,
     )
+    learning_rate = model_util.WarmUp(
+        initial_learning_rate=init_lr,
+        decay_schedule_fn=learning_rate,
+        warmup_steps=warmup_steps,
+    )
     return tf.keras.optimizers.experimental.SGD(
         learning_rate=learning_rate, momentum=0.9
     )
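
The optimizer change adds linear warmup over the first 10% of total training steps before handing control to the cosine decay schedule; init_lr follows the linear-scaling rule (base rate times batch_size / 256). model_util.WarmUp is MediaPipe's own helper; a minimal sketch of the behavior it is assumed to implement:

# Minimal warmup wrapper sketch (assumed behavior of model_util.WarmUp;
# not MediaPipe's actual implementation).
import tensorflow as tf

class WarmUpSketch(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Ramps the LR linearly up to initial_learning_rate, then delegates."""

  def __init__(self, initial_learning_rate, decay_schedule_fn, warmup_steps):
    super().__init__()
    self.initial_learning_rate = initial_learning_rate
    self.decay_schedule_fn = decay_schedule_fn
    self.warmup_steps = warmup_steps

  def __call__(self, step):
    step = tf.cast(step, tf.float32)
    warmup_steps = tf.cast(self.warmup_steps, tf.float32)
    warmup_lr = self.initial_learning_rate * step / warmup_steps
    # Linear ramp while step < warmup_steps, then the wrapped schedule.
    return tf.cond(
        step < warmup_steps,
        lambda: warmup_lr,
        lambda: self.decay_schedule_fn(step),
    )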


@@ -32,8 +32,8 @@ class Preprocessor(object):
     self._mean_norm = model_spec.mean_norm
     self._stddev_norm = model_spec.stddev_norm
     self._output_size = model_spec.input_image_shape[:2]
-    self._min_level = 3
-    self._max_level = 7
+    self._min_level = model_spec.min_level
+    self._max_level = model_spec.max_level
     self._num_scales = 3
     self._aspect_ratios = [0.5, 1, 2]
     self._anchor_size = 3
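
With num_scales = 3 and three aspect ratios, every grid cell carries 9 anchors, so the total anchor count the preprocessor generates depends directly on min_level/max_level. A back-of-the-envelope check (illustrative, not MediaPipe code):

# 9 anchors per cell: num_scales * len(aspect_ratios).
ANCHORS_PER_CELL = 3 * 3

def total_anchors(input_size: int, min_level: int, max_level: int) -> int:
  return sum(
      (input_size // 2**level) ** 2 * ANCHORS_PER_CELL
      for level in range(min_level, max_level + 1)
  )

print(total_anchors(256, 3, 7))  # 12276
print(total_anchors(320, 3, 6))  # 19125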