Support MultiHW AVG Architecture for object detector
PiperOrigin-RevId: 529221127
This commit is contained in:
		
							parent
							
								
									8c324fbd77
								
							
						
					
					
						commit
						b350f72394
					
				| 
						 | 
				
			
			@ -59,7 +59,9 @@ class ObjectDetectorModel(tf.keras.Model):
 | 
			
		|||
    self._num_classes = num_classes
 | 
			
		||||
    self._model = self._build_model()
 | 
			
		||||
    checkpoint_folder = self._model_spec.downloaded_files.get_path()
 | 
			
		||||
    checkpoint_file = os.path.join(checkpoint_folder, 'ckpt-277200')
 | 
			
		||||
    checkpoint_file = os.path.join(
 | 
			
		||||
        checkpoint_folder, self._model_spec.checkpoint_name
 | 
			
		||||
    )
 | 
			
		||||
    self.load_checkpoint(checkpoint_file)
 | 
			
		||||
    self._model.summary()
 | 
			
		||||
    self.loss_trackers = [
 | 
			
		||||
| 
						 | 
				
			
			@ -80,7 +82,10 @@ class ObjectDetectorModel(tf.keras.Model):
 | 
			
		|||
            num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=3
 | 
			
		||||
        ),
 | 
			
		||||
        backbone=configs.backbones.Backbone(
 | 
			
		||||
            type='mobilenet', mobilenet=configs.backbones.MobileNet()
 | 
			
		||||
            type='mobilenet',
 | 
			
		||||
            mobilenet=configs.backbones.MobileNet(
 | 
			
		||||
                model_id=self._model_spec.model_id
 | 
			
		||||
            ),
 | 
			
		||||
        ),
 | 
			
		||||
        decoder=configs.decoders.Decoder(
 | 
			
		||||
            type='fpn',
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -26,6 +26,12 @@ MOBILENET_V2_FILES = file_util.DownloadedFiles(
 | 
			
		|||
    is_folder=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
MOBILENET_MULTI_AVG_FILES = file_util.DownloadedFiles(
 | 
			
		||||
    'object_detector/mobilenetmultiavg',
 | 
			
		||||
    'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv3.5_ssd_coco/mobilenetv3.5_ssd_i256_ckpt.tar.gz',
 | 
			
		||||
    is_folder=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
class ModelSpec(object):
 | 
			
		||||
| 
						 | 
				
			
			@ -38,13 +44,25 @@ class ModelSpec(object):
 | 
			
		|||
  stddev_rgb = (127.5,)
 | 
			
		||||
 | 
			
		||||
  downloaded_files: file_util.DownloadedFiles
 | 
			
		||||
  checkpoint_name: str
 | 
			
		||||
  input_image_shape: List[int]
 | 
			
		||||
  model_id: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
mobilenet_v2_spec = functools.partial(
 | 
			
		||||
    ModelSpec,
 | 
			
		||||
    downloaded_files=MOBILENET_V2_FILES,
 | 
			
		||||
    checkpoint_name='ckpt-277200',
 | 
			
		||||
    input_image_shape=[256, 256, 3],
 | 
			
		||||
    model_id='MobileNetV2',
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
mobilenet_multi_avg_spec = functools.partial(
 | 
			
		||||
    ModelSpec,
 | 
			
		||||
    downloaded_files=MOBILENET_MULTI_AVG_FILES,
 | 
			
		||||
    checkpoint_name='ckpt-277200',
 | 
			
		||||
    input_image_shape=[256, 256, 3],
 | 
			
		||||
    model_id='MobileNetMultiAVG',
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -53,6 +71,7 @@ class SupportedModels(enum.Enum):
 | 
			
		|||
  """Predefined object detector model specs supported by Model Maker."""
 | 
			
		||||
 | 
			
		||||
  MOBILENET_V2 = mobilenet_v2_spec
 | 
			
		||||
  MOBILENET_MULTI_AVG = mobilenet_multi_avg_spec
 | 
			
		||||
 | 
			
		||||
  @classmethod
 | 
			
		||||
  def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user