Add MobileNetV2_I320 and MobileNetMultiHWAVG_I384 to support larger input image sizes.

PiperOrigin-RevId: 544393692
This commit is contained in:
MediaPipe Team 2023-06-29 10:22:18 -07:00 committed by Copybara-Service
parent 8278dbc38f
commit 0bb4ee8941
4 changed files with 68 additions and 13 deletions

View File

@ -74,8 +74,8 @@ class ObjectDetectorModel(tf.keras.Model):
generator_config: configs.retinanet.DetectionGenerator = configs.retinanet.DetectionGenerator(),
) -> configs.retinanet.RetinaNet:
model_config = configs.retinanet.RetinaNet(
min_level=3,
max_level=7,
min_level=self._model_spec.min_level,
max_level=self._model_spec.max_level,
num_classes=self._num_classes,
input_size=self._model_spec.input_image_shape,
anchor=configs.retinanet.Anchor(

View File

@ -20,18 +20,30 @@ from typing import List
from mediapipe.model_maker.python.core.utils import file_util
MOBILENET_V2_FILES = file_util.DownloadedFiles(
'object_detector/mobilenetv2',
MOBILENET_V2_I256_FILES = file_util.DownloadedFiles(
'object_detector/mobilenetv2_i256',
'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i256_ckpt.tar.gz',
is_folder=True,
)
MOBILENET_V2_I320_FILES = file_util.DownloadedFiles(
'object_detector/mobilenetv2_i320',
'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i320_ckpt.tar.gz',
is_folder=True,
)
MOBILENET_MULTI_AVG_FILES = file_util.DownloadedFiles(
'object_detector/mobilenetmultiavg',
'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv3.5_ssd_coco/mobilenetv3.5_ssd_i256_ckpt.tar.gz',
is_folder=True,
)
MOBILENET_MULTI_AVG_I384_FILES = file_util.DownloadedFiles(
'object_detector/mobilenetmultiavg_i384',
'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv3.5_ssd_i384_ckpt.tar.gz',
is_folder=True,
)
@dataclasses.dataclass
class ModelSpec(object):
@ -48,30 +60,66 @@ class ModelSpec(object):
input_image_shape: List[int]
model_id: str
# Model Config values
min_level: int
max_level: int
mobilenet_v2_spec = functools.partial(
mobilenet_v2_i256_spec = functools.partial(
ModelSpec,
downloaded_files=MOBILENET_V2_FILES,
downloaded_files=MOBILENET_V2_I256_FILES,
checkpoint_name='ckpt-277200',
input_image_shape=[256, 256, 3],
model_id='MobileNetV2',
min_level=3,
max_level=7,
)
mobilenet_multi_avg_spec = functools.partial(
mobilenet_v2_i320_spec = functools.partial(
ModelSpec,
downloaded_files=MOBILENET_V2_I320_FILES,
checkpoint_name='ckpt-277200',
input_image_shape=[320, 320, 3],
model_id='MobileNetV2',
min_level=3,
max_level=6,
)
mobilenet_multi_avg_i256_spec = functools.partial(
ModelSpec,
downloaded_files=MOBILENET_MULTI_AVG_FILES,
checkpoint_name='ckpt-277200',
input_image_shape=[256, 256, 3],
model_id='MobileNetMultiAVG',
min_level=3,
max_level=7,
)
mobilenet_multi_avg_i384_spec = functools.partial(
ModelSpec,
downloaded_files=MOBILENET_MULTI_AVG_I384_FILES,
checkpoint_name='ckpt-277200',
input_image_shape=[384, 384, 3],
model_id='MobileNetMultiAVG',
min_level=3,
max_level=7,
)
@enum.unique
class SupportedModels(enum.Enum):
"""Predefined object detector model specs supported by Model Maker."""
"""Predefined object detector model specs supported by Model Maker.
MOBILENET_V2 = mobilenet_v2_spec
MOBILENET_MULTI_AVG = mobilenet_multi_avg_spec
Supported models include the following:
- MOBILENET_V2: MobileNetV2 256x256 input
- MOBILENET_V2_I320: MobileNetV2 320x320 input
- MOBILENET_MULTI_AVG: MobileNet-MultiHW-AVG 256x256 input
- MOBILENET_MULTI_AVG_I384: MobileNet-MultiHW-AVG 384x384 input
"""
MOBILENET_V2 = mobilenet_v2_i256_spec
MOBILENET_V2_I320 = mobilenet_v2_i320_spec
MOBILENET_MULTI_AVG = mobilenet_multi_avg_i256_spec
MOBILENET_MULTI_AVG_I384 = mobilenet_multi_avg_i384_spec
@classmethod
def get(cls, spec: 'SupportedModels') -> 'ModelSpec':

View File

@ -395,7 +395,7 @@ class ObjectDetector(classifier.Classifier):
) -> tf.keras.optimizers.Optimizer:
"""Creates an optimizer with learning rate schedule for regular training.
Uses Keras PiecewiseConstantDecay schedule by default.
Uses Keras CosineDecay schedule by default.
Args:
steps_per_epoch: Steps per epoch to calculate the step boundaries from the
@ -404,6 +404,8 @@ class ObjectDetector(classifier.Classifier):
Returns:
A tf.keras.optimizer.Optimizer for model training.
"""
total_steps = steps_per_epoch * self._hparams.epochs
warmup_steps = int(total_steps * 0.1)
init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
decay_epochs = (
self._hparams.cosine_decay_epochs
@ -415,6 +417,11 @@ class ObjectDetector(classifier.Classifier):
steps_per_epoch * decay_epochs,
self._hparams.cosine_decay_alpha,
)
learning_rate = model_util.WarmUp(
initial_learning_rate=init_lr,
decay_schedule_fn=learning_rate,
warmup_steps=warmup_steps,
)
return tf.keras.optimizers.experimental.SGD(
learning_rate=learning_rate, momentum=0.9
)

View File

@ -32,8 +32,8 @@ class Preprocessor(object):
self._mean_norm = model_spec.mean_norm
self._stddev_norm = model_spec.stddev_norm
self._output_size = model_spec.input_image_shape[:2]
self._min_level = 3
self._max_level = 7
self._min_level = model_spec.min_level
self._max_level = model_spec.max_level
self._num_scales = 3
self._aspect_ratios = [0.5, 1, 2]
self._anchor_size = 3