Add MobileNetV2_I320 and MobileNetMultiHWAVG_I384 to support larger input image sizes.
PiperOrigin-RevId: 544393692
This commit is contained in:
parent
8278dbc38f
commit
0bb4ee8941
|
@ -74,8 +74,8 @@ class ObjectDetectorModel(tf.keras.Model):
|
|||
generator_config: configs.retinanet.DetectionGenerator = configs.retinanet.DetectionGenerator(),
|
||||
) -> configs.retinanet.RetinaNet:
|
||||
model_config = configs.retinanet.RetinaNet(
|
||||
min_level=3,
|
||||
max_level=7,
|
||||
min_level=self._model_spec.min_level,
|
||||
max_level=self._model_spec.max_level,
|
||||
num_classes=self._num_classes,
|
||||
input_size=self._model_spec.input_image_shape,
|
||||
anchor=configs.retinanet.Anchor(
|
||||
|
|
|
@ -20,18 +20,30 @@ from typing import List
|
|||
from mediapipe.model_maker.python.core.utils import file_util
|
||||
|
||||
|
||||
MOBILENET_V2_FILES = file_util.DownloadedFiles(
|
||||
'object_detector/mobilenetv2',
|
||||
MOBILENET_V2_I256_FILES = file_util.DownloadedFiles(
|
||||
'object_detector/mobilenetv2_i256',
|
||||
'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i256_ckpt.tar.gz',
|
||||
is_folder=True,
|
||||
)
|
||||
|
||||
MOBILENET_V2_I320_FILES = file_util.DownloadedFiles(
|
||||
'object_detector/mobilenetv2_i320',
|
||||
'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i320_ckpt.tar.gz',
|
||||
is_folder=True,
|
||||
)
|
||||
|
||||
MOBILENET_MULTI_AVG_FILES = file_util.DownloadedFiles(
|
||||
'object_detector/mobilenetmultiavg',
|
||||
'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv3.5_ssd_coco/mobilenetv3.5_ssd_i256_ckpt.tar.gz',
|
||||
is_folder=True,
|
||||
)
|
||||
|
||||
MOBILENET_MULTI_AVG_I384_FILES = file_util.DownloadedFiles(
|
||||
'object_detector/mobilenetmultiavg_i384',
|
||||
'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv3.5_ssd_i384_ckpt.tar.gz',
|
||||
is_folder=True,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelSpec(object):
|
||||
|
@ -48,30 +60,66 @@ class ModelSpec(object):
|
|||
input_image_shape: List[int]
|
||||
model_id: str
|
||||
|
||||
# Model Config values
|
||||
min_level: int
|
||||
max_level: int
|
||||
|
||||
mobilenet_v2_spec = functools.partial(
|
||||
|
||||
mobilenet_v2_i256_spec = functools.partial(
|
||||
ModelSpec,
|
||||
downloaded_files=MOBILENET_V2_FILES,
|
||||
downloaded_files=MOBILENET_V2_I256_FILES,
|
||||
checkpoint_name='ckpt-277200',
|
||||
input_image_shape=[256, 256, 3],
|
||||
model_id='MobileNetV2',
|
||||
min_level=3,
|
||||
max_level=7,
|
||||
)
|
||||
|
||||
mobilenet_multi_avg_spec = functools.partial(
|
||||
mobilenet_v2_i320_spec = functools.partial(
|
||||
ModelSpec,
|
||||
downloaded_files=MOBILENET_V2_I320_FILES,
|
||||
checkpoint_name='ckpt-277200',
|
||||
input_image_shape=[320, 320, 3],
|
||||
model_id='MobileNetV2',
|
||||
min_level=3,
|
||||
max_level=6,
|
||||
)
|
||||
|
||||
mobilenet_multi_avg_i256_spec = functools.partial(
|
||||
ModelSpec,
|
||||
downloaded_files=MOBILENET_MULTI_AVG_FILES,
|
||||
checkpoint_name='ckpt-277200',
|
||||
input_image_shape=[256, 256, 3],
|
||||
model_id='MobileNetMultiAVG',
|
||||
min_level=3,
|
||||
max_level=7,
|
||||
)
|
||||
|
||||
mobilenet_multi_avg_i384_spec = functools.partial(
|
||||
ModelSpec,
|
||||
downloaded_files=MOBILENET_MULTI_AVG_I384_FILES,
|
||||
checkpoint_name='ckpt-277200',
|
||||
input_image_shape=[384, 384, 3],
|
||||
model_id='MobileNetMultiAVG',
|
||||
min_level=3,
|
||||
max_level=7,
|
||||
)
|
||||
|
||||
|
||||
@enum.unique
|
||||
class SupportedModels(enum.Enum):
|
||||
"""Predefined object detector model specs supported by Model Maker."""
|
||||
"""Predefined object detector model specs supported by Model Maker.
|
||||
|
||||
MOBILENET_V2 = mobilenet_v2_spec
|
||||
MOBILENET_MULTI_AVG = mobilenet_multi_avg_spec
|
||||
Supported models include the following:
|
||||
- MOBILENET_V2: MobileNetV2 256x256 input
|
||||
- MOBILENET_V2_I320: MobileNetV2 320x320 input
|
||||
- MOBILENET_MULTI_AVG: MobileNet-MultiHW-AVG 256x256 input
|
||||
- MOBILENET_MULTI_AVG_I384: MobileNet-MultiHW-AVG 384x384 input
|
||||
"""
|
||||
MOBILENET_V2 = mobilenet_v2_i256_spec
|
||||
MOBILENET_V2_I320 = mobilenet_v2_i320_spec
|
||||
MOBILENET_MULTI_AVG = mobilenet_multi_avg_i256_spec
|
||||
MOBILENET_MULTI_AVG_I384 = mobilenet_multi_avg_i384_spec
|
||||
|
||||
@classmethod
|
||||
def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
|
||||
|
|
|
@ -395,7 +395,7 @@ class ObjectDetector(classifier.Classifier):
|
|||
) -> tf.keras.optimizers.Optimizer:
|
||||
"""Creates an optimizer with learning rate schedule for regular training.
|
||||
|
||||
Uses Keras PiecewiseConstantDecay schedule by default.
|
||||
Uses Keras CosineDecay schedule by default.
|
||||
|
||||
Args:
|
||||
steps_per_epoch: Steps per epoch to calculate the step boundaries from the
|
||||
|
@ -404,6 +404,8 @@ class ObjectDetector(classifier.Classifier):
|
|||
Returns:
|
||||
A tf.keras.optimizer.Optimizer for model training.
|
||||
"""
|
||||
total_steps = steps_per_epoch * self._hparams.epochs
|
||||
warmup_steps = int(total_steps * 0.1)
|
||||
init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
|
||||
decay_epochs = (
|
||||
self._hparams.cosine_decay_epochs
|
||||
|
@ -415,6 +417,11 @@ class ObjectDetector(classifier.Classifier):
|
|||
steps_per_epoch * decay_epochs,
|
||||
self._hparams.cosine_decay_alpha,
|
||||
)
|
||||
learning_rate = model_util.WarmUp(
|
||||
initial_learning_rate=init_lr,
|
||||
decay_schedule_fn=learning_rate,
|
||||
warmup_steps=warmup_steps,
|
||||
)
|
||||
return tf.keras.optimizers.experimental.SGD(
|
||||
learning_rate=learning_rate, momentum=0.9
|
||||
)
|
||||
|
|
|
@ -32,8 +32,8 @@ class Preprocessor(object):
|
|||
self._mean_norm = model_spec.mean_norm
|
||||
self._stddev_norm = model_spec.stddev_norm
|
||||
self._output_size = model_spec.input_image_shape[:2]
|
||||
self._min_level = 3
|
||||
self._max_level = 7
|
||||
self._min_level = model_spec.min_level
|
||||
self._max_level = model_spec.max_level
|
||||
self._num_scales = 3
|
||||
self._aspect_ratios = [0.5, 1, 2]
|
||||
self._anchor_size = 3
|
||||
|
|
Loading…
Reference in New Issue
Block a user