Support MultiHW AVG Architecture for object detector
PiperOrigin-RevId: 529221127
This commit is contained in:
parent
8c324fbd77
commit
b350f72394
|
@ -59,7 +59,9 @@ class ObjectDetectorModel(tf.keras.Model):
|
||||||
self._num_classes = num_classes
|
self._num_classes = num_classes
|
||||||
self._model = self._build_model()
|
self._model = self._build_model()
|
||||||
checkpoint_folder = self._model_spec.downloaded_files.get_path()
|
checkpoint_folder = self._model_spec.downloaded_files.get_path()
|
||||||
checkpoint_file = os.path.join(checkpoint_folder, 'ckpt-277200')
|
checkpoint_file = os.path.join(
|
||||||
|
checkpoint_folder, self._model_spec.checkpoint_name
|
||||||
|
)
|
||||||
self.load_checkpoint(checkpoint_file)
|
self.load_checkpoint(checkpoint_file)
|
||||||
self._model.summary()
|
self._model.summary()
|
||||||
self.loss_trackers = [
|
self.loss_trackers = [
|
||||||
|
@ -80,7 +82,10 @@ class ObjectDetectorModel(tf.keras.Model):
|
||||||
num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=3
|
num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=3
|
||||||
),
|
),
|
||||||
backbone=configs.backbones.Backbone(
|
backbone=configs.backbones.Backbone(
|
||||||
type='mobilenet', mobilenet=configs.backbones.MobileNet()
|
type='mobilenet',
|
||||||
|
mobilenet=configs.backbones.MobileNet(
|
||||||
|
model_id=self._model_spec.model_id
|
||||||
|
),
|
||||||
),
|
),
|
||||||
decoder=configs.decoders.Decoder(
|
decoder=configs.decoders.Decoder(
|
||||||
type='fpn',
|
type='fpn',
|
||||||
|
|
|
@ -26,6 +26,12 @@ MOBILENET_V2_FILES = file_util.DownloadedFiles(
|
||||||
is_folder=True,
|
is_folder=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
MOBILENET_MULTI_AVG_FILES = file_util.DownloadedFiles(
|
||||||
|
'object_detector/mobilenetmultiavg',
|
||||||
|
'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv3.5_ssd_coco/mobilenetv3.5_ssd_i256_ckpt.tar.gz',
|
||||||
|
is_folder=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class ModelSpec(object):
|
class ModelSpec(object):
|
||||||
|
@ -38,13 +44,25 @@ class ModelSpec(object):
|
||||||
stddev_rgb = (127.5,)
|
stddev_rgb = (127.5,)
|
||||||
|
|
||||||
downloaded_files: file_util.DownloadedFiles
|
downloaded_files: file_util.DownloadedFiles
|
||||||
|
checkpoint_name: str
|
||||||
input_image_shape: List[int]
|
input_image_shape: List[int]
|
||||||
|
model_id: str
|
||||||
|
|
||||||
|
|
||||||
mobilenet_v2_spec = functools.partial(
|
mobilenet_v2_spec = functools.partial(
|
||||||
ModelSpec,
|
ModelSpec,
|
||||||
downloaded_files=MOBILENET_V2_FILES,
|
downloaded_files=MOBILENET_V2_FILES,
|
||||||
|
checkpoint_name='ckpt-277200',
|
||||||
input_image_shape=[256, 256, 3],
|
input_image_shape=[256, 256, 3],
|
||||||
|
model_id='MobileNetV2',
|
||||||
|
)
|
||||||
|
|
||||||
|
mobilenet_multi_avg_spec = functools.partial(
|
||||||
|
ModelSpec,
|
||||||
|
downloaded_files=MOBILENET_MULTI_AVG_FILES,
|
||||||
|
checkpoint_name='ckpt-277200',
|
||||||
|
input_image_shape=[256, 256, 3],
|
||||||
|
model_id='MobileNetMultiAVG',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,6 +71,7 @@ class SupportedModels(enum.Enum):
|
||||||
"""Predefined object detector model specs supported by Model Maker."""
|
"""Predefined object detector model specs supported by Model Maker."""
|
||||||
|
|
||||||
MOBILENET_V2 = mobilenet_v2_spec
|
MOBILENET_V2 = mobilenet_v2_spec
|
||||||
|
MOBILENET_MULTI_AVG = mobilenet_multi_avg_spec
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
|
def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
|
||||||
|
|
Loading…
Reference in New Issue
Block a user