Add MobileNetV2_I320 and MobileNetMultiHWAVG_I384 to support larger input image sizes.
PiperOrigin-RevId: 544393692
This commit is contained in:
parent
8278dbc38f
commit
0bb4ee8941
|
@ -74,8 +74,8 @@ class ObjectDetectorModel(tf.keras.Model):
|
||||||
generator_config: configs.retinanet.DetectionGenerator = configs.retinanet.DetectionGenerator(),
|
generator_config: configs.retinanet.DetectionGenerator = configs.retinanet.DetectionGenerator(),
|
||||||
) -> configs.retinanet.RetinaNet:
|
) -> configs.retinanet.RetinaNet:
|
||||||
model_config = configs.retinanet.RetinaNet(
|
model_config = configs.retinanet.RetinaNet(
|
||||||
min_level=3,
|
min_level=self._model_spec.min_level,
|
||||||
max_level=7,
|
max_level=self._model_spec.max_level,
|
||||||
num_classes=self._num_classes,
|
num_classes=self._num_classes,
|
||||||
input_size=self._model_spec.input_image_shape,
|
input_size=self._model_spec.input_image_shape,
|
||||||
anchor=configs.retinanet.Anchor(
|
anchor=configs.retinanet.Anchor(
|
||||||
|
|
|
@ -20,18 +20,30 @@ from typing import List
|
||||||
from mediapipe.model_maker.python.core.utils import file_util
|
from mediapipe.model_maker.python.core.utils import file_util
|
||||||
|
|
||||||
|
|
||||||
MOBILENET_V2_FILES = file_util.DownloadedFiles(
|
MOBILENET_V2_I256_FILES = file_util.DownloadedFiles(
|
||||||
'object_detector/mobilenetv2',
|
'object_detector/mobilenetv2_i256',
|
||||||
'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i256_ckpt.tar.gz',
|
'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i256_ckpt.tar.gz',
|
||||||
is_folder=True,
|
is_folder=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
MOBILENET_V2_I320_FILES = file_util.DownloadedFiles(
|
||||||
|
'object_detector/mobilenetv2_i320',
|
||||||
|
'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i320_ckpt.tar.gz',
|
||||||
|
is_folder=True,
|
||||||
|
)
|
||||||
|
|
||||||
MOBILENET_MULTI_AVG_FILES = file_util.DownloadedFiles(
|
MOBILENET_MULTI_AVG_FILES = file_util.DownloadedFiles(
|
||||||
'object_detector/mobilenetmultiavg',
|
'object_detector/mobilenetmultiavg',
|
||||||
'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv3.5_ssd_coco/mobilenetv3.5_ssd_i256_ckpt.tar.gz',
|
'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv3.5_ssd_coco/mobilenetv3.5_ssd_i256_ckpt.tar.gz',
|
||||||
is_folder=True,
|
is_folder=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
MOBILENET_MULTI_AVG_I384_FILES = file_util.DownloadedFiles(
|
||||||
|
'object_detector/mobilenetmultiavg_i384',
|
||||||
|
'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv3.5_ssd_i384_ckpt.tar.gz',
|
||||||
|
is_folder=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class ModelSpec(object):
|
class ModelSpec(object):
|
||||||
|
@ -48,30 +60,66 @@ class ModelSpec(object):
|
||||||
input_image_shape: List[int]
|
input_image_shape: List[int]
|
||||||
model_id: str
|
model_id: str
|
||||||
|
|
||||||
|
# Model Config values
|
||||||
|
min_level: int
|
||||||
|
max_level: int
|
||||||
|
|
||||||
mobilenet_v2_spec = functools.partial(
|
|
||||||
|
mobilenet_v2_i256_spec = functools.partial(
|
||||||
ModelSpec,
|
ModelSpec,
|
||||||
downloaded_files=MOBILENET_V2_FILES,
|
downloaded_files=MOBILENET_V2_I256_FILES,
|
||||||
checkpoint_name='ckpt-277200',
|
checkpoint_name='ckpt-277200',
|
||||||
input_image_shape=[256, 256, 3],
|
input_image_shape=[256, 256, 3],
|
||||||
model_id='MobileNetV2',
|
model_id='MobileNetV2',
|
||||||
|
min_level=3,
|
||||||
|
max_level=7,
|
||||||
)
|
)
|
||||||
|
|
||||||
mobilenet_multi_avg_spec = functools.partial(
|
mobilenet_v2_i320_spec = functools.partial(
|
||||||
|
ModelSpec,
|
||||||
|
downloaded_files=MOBILENET_V2_I320_FILES,
|
||||||
|
checkpoint_name='ckpt-277200',
|
||||||
|
input_image_shape=[320, 320, 3],
|
||||||
|
model_id='MobileNetV2',
|
||||||
|
min_level=3,
|
||||||
|
max_level=6,
|
||||||
|
)
|
||||||
|
|
||||||
|
mobilenet_multi_avg_i256_spec = functools.partial(
|
||||||
ModelSpec,
|
ModelSpec,
|
||||||
downloaded_files=MOBILENET_MULTI_AVG_FILES,
|
downloaded_files=MOBILENET_MULTI_AVG_FILES,
|
||||||
checkpoint_name='ckpt-277200',
|
checkpoint_name='ckpt-277200',
|
||||||
input_image_shape=[256, 256, 3],
|
input_image_shape=[256, 256, 3],
|
||||||
model_id='MobileNetMultiAVG',
|
model_id='MobileNetMultiAVG',
|
||||||
|
min_level=3,
|
||||||
|
max_level=7,
|
||||||
|
)
|
||||||
|
|
||||||
|
mobilenet_multi_avg_i384_spec = functools.partial(
|
||||||
|
ModelSpec,
|
||||||
|
downloaded_files=MOBILENET_MULTI_AVG_I384_FILES,
|
||||||
|
checkpoint_name='ckpt-277200',
|
||||||
|
input_image_shape=[384, 384, 3],
|
||||||
|
model_id='MobileNetMultiAVG',
|
||||||
|
min_level=3,
|
||||||
|
max_level=7,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@enum.unique
|
@enum.unique
|
||||||
class SupportedModels(enum.Enum):
|
class SupportedModels(enum.Enum):
|
||||||
"""Predefined object detector model specs supported by Model Maker."""
|
"""Predefined object detector model specs supported by Model Maker.
|
||||||
|
|
||||||
MOBILENET_V2 = mobilenet_v2_spec
|
Supported models include the following:
|
||||||
MOBILENET_MULTI_AVG = mobilenet_multi_avg_spec
|
- MOBILENET_V2: MobileNetV2 256x256 input
|
||||||
|
- MOBILENET_V2_I320: MobileNetV2 320x320 input
|
||||||
|
- MOBILENET_MULTI_AVG: MobileNet-MultiHW-AVG 256x256 input
|
||||||
|
- MOBILENET_MULTI_AVG_I384: MobileNet-MultiHW-AVG 384x384 input
|
||||||
|
"""
|
||||||
|
MOBILENET_V2 = mobilenet_v2_i256_spec
|
||||||
|
MOBILENET_V2_I320 = mobilenet_v2_i320_spec
|
||||||
|
MOBILENET_MULTI_AVG = mobilenet_multi_avg_i256_spec
|
||||||
|
MOBILENET_MULTI_AVG_I384 = mobilenet_multi_avg_i384_spec
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
|
def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
|
||||||
|
|
|
@ -395,7 +395,7 @@ class ObjectDetector(classifier.Classifier):
|
||||||
) -> tf.keras.optimizers.Optimizer:
|
) -> tf.keras.optimizers.Optimizer:
|
||||||
"""Creates an optimizer with learning rate schedule for regular training.
|
"""Creates an optimizer with learning rate schedule for regular training.
|
||||||
|
|
||||||
Uses Keras PiecewiseConstantDecay schedule by default.
|
Uses Keras CosineDecay schedule by default.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
steps_per_epoch: Steps per epoch to calculate the step boundaries from the
|
steps_per_epoch: Steps per epoch to calculate the step boundaries from the
|
||||||
|
@ -404,6 +404,8 @@ class ObjectDetector(classifier.Classifier):
|
||||||
Returns:
|
Returns:
|
||||||
A tf.keras.optimizer.Optimizer for model training.
|
A tf.keras.optimizer.Optimizer for model training.
|
||||||
"""
|
"""
|
||||||
|
total_steps = steps_per_epoch * self._hparams.epochs
|
||||||
|
warmup_steps = int(total_steps * 0.1)
|
||||||
init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
|
init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
|
||||||
decay_epochs = (
|
decay_epochs = (
|
||||||
self._hparams.cosine_decay_epochs
|
self._hparams.cosine_decay_epochs
|
||||||
|
@ -415,6 +417,11 @@ class ObjectDetector(classifier.Classifier):
|
||||||
steps_per_epoch * decay_epochs,
|
steps_per_epoch * decay_epochs,
|
||||||
self._hparams.cosine_decay_alpha,
|
self._hparams.cosine_decay_alpha,
|
||||||
)
|
)
|
||||||
|
learning_rate = model_util.WarmUp(
|
||||||
|
initial_learning_rate=init_lr,
|
||||||
|
decay_schedule_fn=learning_rate,
|
||||||
|
warmup_steps=warmup_steps,
|
||||||
|
)
|
||||||
return tf.keras.optimizers.experimental.SGD(
|
return tf.keras.optimizers.experimental.SGD(
|
||||||
learning_rate=learning_rate, momentum=0.9
|
learning_rate=learning_rate, momentum=0.9
|
||||||
)
|
)
|
||||||
|
|
|
@ -32,8 +32,8 @@ class Preprocessor(object):
|
||||||
self._mean_norm = model_spec.mean_norm
|
self._mean_norm = model_spec.mean_norm
|
||||||
self._stddev_norm = model_spec.stddev_norm
|
self._stddev_norm = model_spec.stddev_norm
|
||||||
self._output_size = model_spec.input_image_shape[:2]
|
self._output_size = model_spec.input_image_shape[:2]
|
||||||
self._min_level = 3
|
self._min_level = model_spec.min_level
|
||||||
self._max_level = 7
|
self._max_level = model_spec.max_level
|
||||||
self._num_scales = 3
|
self._num_scales = 3
|
||||||
self._aspect_ratios = [0.5, 1, 2]
|
self._aspect_ratios = [0.5, 1, 2]
|
||||||
self._anchor_size = 3
|
self._anchor_size = 3
|
||||||
|
|
Loading…
Reference in New Issue
Block a user