Support MultiHW AVG Architecture for object detector

PiperOrigin-RevId: 529221127
This commit is contained in:
MediaPipe Team 2023-05-03 16:10:51 -07:00 committed by Copybara-Service
parent 8c324fbd77
commit b350f72394
2 changed files with 26 additions and 2 deletions

View File

@ -59,7 +59,9 @@ class ObjectDetectorModel(tf.keras.Model):
self._num_classes = num_classes self._num_classes = num_classes
self._model = self._build_model() self._model = self._build_model()
checkpoint_folder = self._model_spec.downloaded_files.get_path() checkpoint_folder = self._model_spec.downloaded_files.get_path()
checkpoint_file = os.path.join(checkpoint_folder, 'ckpt-277200') checkpoint_file = os.path.join(
checkpoint_folder, self._model_spec.checkpoint_name
)
self.load_checkpoint(checkpoint_file) self.load_checkpoint(checkpoint_file)
self._model.summary() self._model.summary()
self.loss_trackers = [ self.loss_trackers = [
@ -80,7 +82,10 @@ class ObjectDetectorModel(tf.keras.Model):
num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=3 num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=3
), ),
backbone=configs.backbones.Backbone( backbone=configs.backbones.Backbone(
type='mobilenet', mobilenet=configs.backbones.MobileNet() type='mobilenet',
mobilenet=configs.backbones.MobileNet(
model_id=self._model_spec.model_id
),
), ),
decoder=configs.decoders.Decoder( decoder=configs.decoders.Decoder(
type='fpn', type='fpn',

View File

@ -26,6 +26,12 @@ MOBILENET_V2_FILES = file_util.DownloadedFiles(
is_folder=True, is_folder=True,
) )
MOBILENET_MULTI_AVG_FILES = file_util.DownloadedFiles(
'object_detector/mobilenetmultiavg',
'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv3.5_ssd_coco/mobilenetv3.5_ssd_i256_ckpt.tar.gz',
is_folder=True,
)
@dataclasses.dataclass @dataclasses.dataclass
class ModelSpec(object): class ModelSpec(object):
@ -38,13 +44,25 @@ class ModelSpec(object):
stddev_rgb = (127.5,) stddev_rgb = (127.5,)
downloaded_files: file_util.DownloadedFiles downloaded_files: file_util.DownloadedFiles
checkpoint_name: str
input_image_shape: List[int] input_image_shape: List[int]
model_id: str
mobilenet_v2_spec = functools.partial( mobilenet_v2_spec = functools.partial(
ModelSpec, ModelSpec,
downloaded_files=MOBILENET_V2_FILES, downloaded_files=MOBILENET_V2_FILES,
checkpoint_name='ckpt-277200',
input_image_shape=[256, 256, 3], input_image_shape=[256, 256, 3],
model_id='MobileNetV2',
)
mobilenet_multi_avg_spec = functools.partial(
ModelSpec,
downloaded_files=MOBILENET_MULTI_AVG_FILES,
checkpoint_name='ckpt-277200',
input_image_shape=[256, 256, 3],
model_id='MobileNetMultiAVG',
) )
@ -53,6 +71,7 @@ class SupportedModels(enum.Enum):
"""Predefined object detector model specs supported by Model Maker.""" """Predefined object detector model specs supported by Model Maker."""
MOBILENET_V2 = mobilenet_v2_spec MOBILENET_V2 = mobilenet_v2_spec
MOBILENET_MULTI_AVG = mobilenet_multi_avg_spec
@classmethod @classmethod
def get(cls, spec: 'SupportedModels') -> 'ModelSpec': def get(cls, spec: 'SupportedModels') -> 'ModelSpec':