Internal change

PiperOrigin-RevId: 538313290
This commit is contained in:
MediaPipe Team 2023-06-06 15:51:00 -07:00 committed by Copybara-Service
parent d6f34f6aef
commit 4b0f3cacae

View File

@ -101,14 +101,17 @@ class ObjectDetectorModel(tf.keras.Model):
)
return model_config
def _build_model(self) -> tf.keras.Model:
def _build_model(self, omit_l2=False) -> tf.keras.Model:
"""Builds a RetinaNet object detector model."""
input_specs = tf.keras.layers.InputSpec(
shape=[None] + self._model_spec.input_image_shape
)
l2_regularizer = tf.keras.regularizers.l2(
self._model_options.l2_weight_decay / 2.0
)
if omit_l2:
l2_regularizer = None
else:
l2_regularizer = tf.keras.regularizers.l2(
self._model_options.l2_weight_decay / 2.0
)
model_config = self._get_model_config()
return factory.build_retinanet(input_specs, model_config, l2_regularizer)
@ -167,7 +170,7 @@ class ObjectDetectorModel(tf.keras.Model):
def convert_to_qat(self) -> None:
"""Converts the model to a QAT RetinaNet model."""
model = self._build_model()
model = self._build_model(omit_l2=True)
dummy_input = tf.zeros([1] + self._model_spec.input_image_shape)
model(dummy_input, training=True)
model.set_weights(self._model.get_weights())