Internal change
PiperOrigin-RevId: 538313290
This commit is contained in:
parent
d6f34f6aef
commit
4b0f3cacae
|
@ -101,14 +101,17 @@ class ObjectDetectorModel(tf.keras.Model):
|
|||
)
|
||||
return model_config
|
||||
|
||||
def _build_model(self) -> tf.keras.Model:
|
||||
def _build_model(self, omit_l2=False) -> tf.keras.Model:
|
||||
"""Builds a RetinaNet object detector model."""
|
||||
input_specs = tf.keras.layers.InputSpec(
|
||||
shape=[None] + self._model_spec.input_image_shape
|
||||
)
|
||||
l2_regularizer = tf.keras.regularizers.l2(
|
||||
self._model_options.l2_weight_decay / 2.0
|
||||
)
|
||||
if omit_l2:
|
||||
l2_regularizer = None
|
||||
else:
|
||||
l2_regularizer = tf.keras.regularizers.l2(
|
||||
self._model_options.l2_weight_decay / 2.0
|
||||
)
|
||||
model_config = self._get_model_config()
|
||||
|
||||
return factory.build_retinanet(input_specs, model_config, l2_regularizer)
|
||||
|
@ -167,7 +170,7 @@ class ObjectDetectorModel(tf.keras.Model):
|
|||
|
||||
def convert_to_qat(self) -> None:
|
||||
"""Converts the model to a QAT RetinaNet model."""
|
||||
model = self._build_model()
|
||||
model = self._build_model(omit_l2=True)
|
||||
dummy_input = tf.zeros([1] + self._model_spec.input_image_shape)
|
||||
model(dummy_input, training=True)
|
||||
model.set_weights(self._model.get_weights())
|
||||
|
|
Loading…
Reference in New Issue
Block a user