Internal change

PiperOrigin-RevId: 538313290
This commit is contained in:
MediaPipe Team 2023-06-06 15:51:00 -07:00 committed by Copybara-Service
parent d6f34f6aef
commit 4b0f3cacae

View File

@ -101,14 +101,17 @@ class ObjectDetectorModel(tf.keras.Model):
) )
return model_config return model_config
def _build_model(self) -> tf.keras.Model: def _build_model(self, omit_l2=False) -> tf.keras.Model:
"""Builds a RetinaNet object detector model.""" """Builds a RetinaNet object detector model."""
input_specs = tf.keras.layers.InputSpec( input_specs = tf.keras.layers.InputSpec(
shape=[None] + self._model_spec.input_image_shape shape=[None] + self._model_spec.input_image_shape
) )
l2_regularizer = tf.keras.regularizers.l2( if omit_l2:
self._model_options.l2_weight_decay / 2.0 l2_regularizer = None
) else:
l2_regularizer = tf.keras.regularizers.l2(
self._model_options.l2_weight_decay / 2.0
)
model_config = self._get_model_config() model_config = self._get_model_config()
return factory.build_retinanet(input_specs, model_config, l2_regularizer) return factory.build_retinanet(input_specs, model_config, l2_regularizer)
@ -167,7 +170,7 @@ class ObjectDetectorModel(tf.keras.Model):
def convert_to_qat(self) -> None: def convert_to_qat(self) -> None:
"""Converts the model to a QAT RetinaNet model.""" """Converts the model to a QAT RetinaNet model."""
model = self._build_model() model = self._build_model(omit_l2=True)
dummy_input = tf.zeros([1] + self._model_spec.input_image_shape) dummy_input = tf.zeros([1] + self._model_spec.input_image_shape)
model(dummy_input, training=True) model(dummy_input, training=True)
model.set_weights(self._model.get_weights()) model.set_weights(self._model.get_weights())