diff --git a/mediapipe/model_maker/python/vision/object_detector/model.py b/mediapipe/model_maker/python/vision/object_detector/model.py index 70e63d5b5..b1b4951fd 100644 --- a/mediapipe/model_maker/python/vision/object_detector/model.py +++ b/mediapipe/model_maker/python/vision/object_detector/model.py @@ -101,14 +101,17 @@ class ObjectDetectorModel(tf.keras.Model): ) return model_config - def _build_model(self) -> tf.keras.Model: + def _build_model(self, omit_l2=False) -> tf.keras.Model: """Builds a RetinaNet object detector model.""" input_specs = tf.keras.layers.InputSpec( shape=[None] + self._model_spec.input_image_shape ) - l2_regularizer = tf.keras.regularizers.l2( - self._model_options.l2_weight_decay / 2.0 - ) + if omit_l2: + l2_regularizer = None + else: + l2_regularizer = tf.keras.regularizers.l2( + self._model_options.l2_weight_decay / 2.0 + ) model_config = self._get_model_config() return factory.build_retinanet(input_specs, model_config, l2_regularizer) @@ -167,7 +170,7 @@ class ObjectDetectorModel(tf.keras.Model): def convert_to_qat(self) -> None: """Converts the model to a QAT RetinaNet model.""" - model = self._build_model() + model = self._build_model(omit_l2=True) dummy_input = tf.zeros([1] + self._model_spec.input_image_shape) model(dummy_input, training=True) model.set_weights(self._model.get_weights())