diff --git a/mediapipe/model_maker/python/vision/object_detector/model.py b/mediapipe/model_maker/python/vision/object_detector/model.py index eac669786..e3eb3a651 100644 --- a/mediapipe/model_maker/python/vision/object_detector/model.py +++ b/mediapipe/model_maker/python/vision/object_detector/model.py @@ -59,7 +59,9 @@ class ObjectDetectorModel(tf.keras.Model): self._num_classes = num_classes self._model = self._build_model() checkpoint_folder = self._model_spec.downloaded_files.get_path() - checkpoint_file = os.path.join(checkpoint_folder, 'ckpt-277200') + checkpoint_file = os.path.join( + checkpoint_folder, self._model_spec.checkpoint_name + ) self.load_checkpoint(checkpoint_file) self._model.summary() self.loss_trackers = [ @@ -80,7 +82,10 @@ class ObjectDetectorModel(tf.keras.Model): num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=3 ), backbone=configs.backbones.Backbone( - type='mobilenet', mobilenet=configs.backbones.MobileNet() + type='mobilenet', + mobilenet=configs.backbones.MobileNet( + model_id=self._model_spec.model_id + ), ), decoder=configs.decoders.Decoder( type='fpn', diff --git a/mediapipe/model_maker/python/vision/object_detector/model_spec.py b/mediapipe/model_maker/python/vision/object_detector/model_spec.py index 2ce838c71..9c89c4ed0 100644 --- a/mediapipe/model_maker/python/vision/object_detector/model_spec.py +++ b/mediapipe/model_maker/python/vision/object_detector/model_spec.py @@ -26,6 +26,12 @@ MOBILENET_V2_FILES = file_util.DownloadedFiles( is_folder=True, ) +MOBILENET_MULTI_AVG_FILES = file_util.DownloadedFiles( + 'object_detector/mobilenetmultiavg', + 'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv3.5_ssd_coco/mobilenetv3.5_ssd_i256_ckpt.tar.gz', + is_folder=True, +) + @dataclasses.dataclass class ModelSpec(object): @@ -38,13 +44,25 @@ class ModelSpec(object): stddev_rgb = (127.5,) downloaded_files: file_util.DownloadedFiles + checkpoint_name: str input_image_shape: List[int] + model_id: str mobilenet_v2_spec = functools.partial( ModelSpec, downloaded_files=MOBILENET_V2_FILES, + checkpoint_name='ckpt-277200', input_image_shape=[256, 256, 3], + model_id='MobileNetV2', +) + +mobilenet_multi_avg_spec = functools.partial( + ModelSpec, + downloaded_files=MOBILENET_MULTI_AVG_FILES, + checkpoint_name='ckpt-277200', + input_image_shape=[256, 256, 3], + model_id='MobileNetMultiAVG', ) @@ -53,6 +71,7 @@ class SupportedModels(enum.Enum): """Predefined object detector model specs supported by Model Maker.""" MOBILENET_V2 = mobilenet_v2_spec + MOBILENET_MULTI_AVG = mobilenet_multi_avg_spec @classmethod def get(cls, spec: 'SupportedModels') -> 'ModelSpec':