Support MultiHW AVG Architecture for object detector
PiperOrigin-RevId: 529221127
This commit is contained in:
parent
8c324fbd77
commit
b350f72394
|
@ -59,7 +59,9 @@ class ObjectDetectorModel(tf.keras.Model):
|
|||
self._num_classes = num_classes
|
||||
self._model = self._build_model()
|
||||
checkpoint_folder = self._model_spec.downloaded_files.get_path()
|
||||
checkpoint_file = os.path.join(checkpoint_folder, 'ckpt-277200')
|
||||
checkpoint_file = os.path.join(
|
||||
checkpoint_folder, self._model_spec.checkpoint_name
|
||||
)
|
||||
self.load_checkpoint(checkpoint_file)
|
||||
self._model.summary()
|
||||
self.loss_trackers = [
|
||||
|
@ -80,7 +82,10 @@ class ObjectDetectorModel(tf.keras.Model):
|
|||
num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=3
|
||||
),
|
||||
backbone=configs.backbones.Backbone(
|
||||
type='mobilenet', mobilenet=configs.backbones.MobileNet()
|
||||
type='mobilenet',
|
||||
mobilenet=configs.backbones.MobileNet(
|
||||
model_id=self._model_spec.model_id
|
||||
),
|
||||
),
|
||||
decoder=configs.decoders.Decoder(
|
||||
type='fpn',
|
||||
|
|
|
@ -26,6 +26,12 @@ MOBILENET_V2_FILES = file_util.DownloadedFiles(
|
|||
is_folder=True,
|
||||
)
|
||||
|
||||
MOBILENET_MULTI_AVG_FILES = file_util.DownloadedFiles(
|
||||
'object_detector/mobilenetmultiavg',
|
||||
'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv3.5_ssd_coco/mobilenetv3.5_ssd_i256_ckpt.tar.gz',
|
||||
is_folder=True,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelSpec(object):
|
||||
|
@ -38,13 +44,25 @@ class ModelSpec(object):
|
|||
stddev_rgb = (127.5,)
|
||||
|
||||
downloaded_files: file_util.DownloadedFiles
|
||||
checkpoint_name: str
|
||||
input_image_shape: List[int]
|
||||
model_id: str
|
||||
|
||||
|
||||
mobilenet_v2_spec = functools.partial(
|
||||
ModelSpec,
|
||||
downloaded_files=MOBILENET_V2_FILES,
|
||||
checkpoint_name='ckpt-277200',
|
||||
input_image_shape=[256, 256, 3],
|
||||
model_id='MobileNetV2',
|
||||
)
|
||||
|
||||
mobilenet_multi_avg_spec = functools.partial(
|
||||
ModelSpec,
|
||||
downloaded_files=MOBILENET_MULTI_AVG_FILES,
|
||||
checkpoint_name='ckpt-277200',
|
||||
input_image_shape=[256, 256, 3],
|
||||
model_id='MobileNetMultiAVG',
|
||||
)
|
||||
|
||||
|
||||
|
@ -53,6 +71,7 @@ class SupportedModels(enum.Enum):
|
|||
"""Predefined object detector model specs supported by Model Maker."""
|
||||
|
||||
MOBILENET_V2 = mobilenet_v2_spec
|
||||
MOBILENET_MULTI_AVG = mobilenet_multi_avg_spec
|
||||
|
||||
@classmethod
|
||||
def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
|
||||
|
|
Loading…
Reference in New Issue
Block a user