Internal change
PiperOrigin-RevId: 538313290
This commit is contained in:
parent
d6f34f6aef
commit
4b0f3cacae
|
@ -101,14 +101,17 @@ class ObjectDetectorModel(tf.keras.Model):
|
||||||
)
|
)
|
||||||
return model_config
|
return model_config
|
||||||
|
|
||||||
def _build_model(self) -> tf.keras.Model:
|
def _build_model(self, omit_l2=False) -> tf.keras.Model:
|
||||||
"""Builds a RetinaNet object detector model."""
|
"""Builds a RetinaNet object detector model."""
|
||||||
input_specs = tf.keras.layers.InputSpec(
|
input_specs = tf.keras.layers.InputSpec(
|
||||||
shape=[None] + self._model_spec.input_image_shape
|
shape=[None] + self._model_spec.input_image_shape
|
||||||
)
|
)
|
||||||
l2_regularizer = tf.keras.regularizers.l2(
|
if omit_l2:
|
||||||
self._model_options.l2_weight_decay / 2.0
|
l2_regularizer = None
|
||||||
)
|
else:
|
||||||
|
l2_regularizer = tf.keras.regularizers.l2(
|
||||||
|
self._model_options.l2_weight_decay / 2.0
|
||||||
|
)
|
||||||
model_config = self._get_model_config()
|
model_config = self._get_model_config()
|
||||||
|
|
||||||
return factory.build_retinanet(input_specs, model_config, l2_regularizer)
|
return factory.build_retinanet(input_specs, model_config, l2_regularizer)
|
||||||
|
@ -167,7 +170,7 @@ class ObjectDetectorModel(tf.keras.Model):
|
||||||
|
|
||||||
def convert_to_qat(self) -> None:
|
def convert_to_qat(self) -> None:
|
||||||
"""Converts the model to a QAT RetinaNet model."""
|
"""Converts the model to a QAT RetinaNet model."""
|
||||||
model = self._build_model()
|
model = self._build_model(omit_l2=True)
|
||||||
dummy_input = tf.zeros([1] + self._model_spec.input_image_shape)
|
dummy_input = tf.zeros([1] + self._model_spec.input_image_shape)
|
||||||
model(dummy_input, training=True)
|
model(dummy_input, training=True)
|
||||||
model.set_weights(self._model.get_weights())
|
model.set_weights(self._model.get_weights())
|
||||||
|
|
Loading…
Reference in New Issue
Block a user