Object Detector: remove NMS operation from exported TFLite

PiperOrigin-RevId: 529559380
MediaPipe Team 2023-05-04 17:33:40 -07:00 committed by Copybara-Service
parent 12b0b6fad1
commit 61cfe2ca9b
7 changed files with 199 additions and 20 deletions

mediapipe/model_maker/python/vision/object_detector/BUILD

@@ -88,6 +88,17 @@ py_test(
     ],
 )

+py_library(
+    name = "detection",
+    srcs = ["detection.py"],
+)
+
+py_test(
+    name = "detection_test",
+    srcs = ["detection_test.py"],
+    deps = [":detection"],
+)
+
 py_library(
     name = "hyperparameters",
     srcs = ["hyperparameters.py"],
@@ -116,6 +127,7 @@ py_library(
     name = "model",
     srcs = ["model.py"],
     deps = [
+        ":detection",
         ":model_options",
         ":model_spec",
     ],
@@ -163,6 +175,7 @@ py_library(
         "//mediapipe/model_maker/python/core/tasks:classifier",
         "//mediapipe/model_maker/python/core/utils:model_util",
         "//mediapipe/model_maker/python/core/utils:quantization",
+        "//mediapipe/tasks/python/metadata/metadata_writers:metadata_info",
         "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
         "//mediapipe/tasks/python/metadata/metadata_writers:object_detector",
     ],

mediapipe/model_maker/python/vision/object_detector/__init__.py

@@ -32,6 +32,7 @@ ObjectDetectorOptions = object_detector_options.ObjectDetectorOptions
 # Remove duplicated and non-public API
 del dataset
 del dataset_util  # pylint: disable=undefined-variable
+del detection  # pylint: disable=undefined-variable
 del hyperparameters
 del model  # pylint: disable=undefined-variable
 del model_options

mediapipe/model_maker/python/vision/object_detector/detection.py (new file)

@@ -0,0 +1,34 @@
+# Copyright 2023 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Custom Detection export module for Object Detection."""
+
+from typing import Any, Mapping
+
+from official.vision.serving import detection
+
+
+class DetectionModule(detection.DetectionModule):
+  """A serving detection module for exporting the model.
+
+  This module overrides the tensorflow_models DetectionModule so that it
+  outputs only the pre-NMS detection_boxes and detection_scores.
+  """
+
+  def serve(self, images) -> Mapping[str, Any]:
+    result = super().serve(images)
+    final_outputs = {
+        'detection_boxes': result['detection_boxes'],
+        'detection_scores': result['detection_scores'],
+    }
+    return final_outputs
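For orientation, a minimal usage sketch (not part of this commit, and untested here): assuming `params` is a valid RetinaNet ExperimentConfig such as the one built in detection_test.py below, the override drops two of the four tensors the parent serve() returns.

import tensorflow as tf

# `params` is assumed to be a RetinaNet cfg.ExperimentConfig; the constructor
# call mirrors the one in detection_test.py below.
module = DetectionModule(params=params, batch_size=1, input_image_size=[256, 256])
images = tf.zeros([1, 256, 256, 3], dtype=tf.uint8)  # dummy input batch
outputs = module.serve(images)
# The parent serve() also returns 'detection_classes' and 'num_detections';
# only the two pre-NMS tensors survive the override.
assert sorted(outputs) == ['detection_boxes', 'detection_scores']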

mediapipe/model_maker/python/vision/object_detector/detection_test.py (new file)

@@ -0,0 +1,73 @@
+# Copyright 2023 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest import mock
+
+import tensorflow as tf
+
+from mediapipe.model_maker.python.vision.object_detector import detection
+from official.core import config_definitions as cfg
+from official.vision import configs
+from official.vision.serving import detection as detection_module
+
+
+class ObjectDetectorTest(tf.test.TestCase):
+
+  @mock.patch.object(detection_module.DetectionModule, 'serve', autospec=True)
+  def test_detection_module(self, mock_serve):
+    mock_serve.return_value = {
+        'detection_boxes': 1,
+        'detection_scores': 2,
+        'detection_classes': 3,
+        'num_detections': 4,
+    }
+    model_config = configs.retinanet.RetinaNet(
+        min_level=3,
+        max_level=7,
+        num_classes=10,
+        input_size=[256, 256, 3],
+        anchor=configs.retinanet.Anchor(
+            num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=3
+        ),
+        backbone=configs.backbones.Backbone(
+            type='mobilenet', mobilenet=configs.backbones.MobileNet()
+        ),
+        decoder=configs.decoders.Decoder(
+            type='fpn',
+            fpn=configs.decoders.FPN(
+                num_filters=128, use_separable_conv=True, use_keras_layer=True
+            ),
+        ),
+        head=configs.retinanet.RetinaNetHead(
+            num_filters=128, use_separable_conv=True
+        ),
+        detection_generator=configs.retinanet.DetectionGenerator(),
+        norm_activation=configs.common.NormActivation(activation='relu6'),
+    )
+    task_config = configs.retinanet.RetinaNetTask(model=model_config)
+    params = cfg.ExperimentConfig(
+        task=task_config,
+    )
+    detection_instance = detection.DetectionModule(
+        params=params, batch_size=1, input_image_size=[256, 256]
+    )
+    outputs = detection_instance.serve(0)
+    expected_outputs = {
+        'detection_boxes': 1,
+        'detection_scores': 2,
+    }
+    self.assertAllEqual(outputs, expected_outputs)
+
+
+if __name__ == '__main__':
+  tf.test.main()

mediapipe/model_maker/python/vision/object_detector/model.py

@@ -18,6 +18,7 @@ from typing import Mapping, Optional, Sequence, Union

 import tensorflow as tf

+from mediapipe.model_maker.python.vision.object_detector import detection
 from mediapipe.model_maker.python.vision.object_detector import model_options as model_opt
 from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
 from official.core import config_definitions as cfg
@@ -29,7 +30,6 @@ from official.vision.losses import loss_utils
 from official.vision.modeling import factory
 from official.vision.modeling import retinanet_model
 from official.vision.modeling.layers import detection_generator
-from official.vision.serving import detection


 class ObjectDetectorModel(tf.keras.Model):
@@ -199,6 +199,7 @@ class ObjectDetectorModel(tf.keras.Model):
             max_detections=10,
             max_classes_per_detection=1,
             normalize_anchor_coordinates=True,
+            omit_nms=True,
         ),
     )
     tflite_post_processing_config = (
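With omit_nms=True, the TFLite detection generator emits raw location and score tensors instead of appending the TFLite_Detection_PostProcess custom op. A quick way to see the difference on an exported file (the path below is a placeholder, not an artifact of this commit):

import tensorflow as tf

# Placeholder path; inspect whichever .tflite file export_model wrote.
interpreter = tf.lite.Interpreter(model_path='exported_model.tflite')
interpreter.allocate_tensors()
for detail in interpreter.get_output_details():
  # Expect two raw outputs (boxes, scores) rather than the four outputs of
  # the TFLite_Detection_PostProcess op (boxes, classes, scores, count).
  print(detail['name'], detail['shape'])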

mediapipe/model_maker/python/vision/object_detector/object_detector.py

@@ -28,6 +28,7 @@ from mediapipe.model_maker.python.vision.object_detector import model_options as
 from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
 from mediapipe.model_maker.python.vision.object_detector import object_detector_options
 from mediapipe.model_maker.python.vision.object_detector import preprocessor
+from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
 from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
 from mediapipe.tasks.python.metadata.metadata_writers import object_detector as object_detector_writer
 from official.vision.evaluation import coco_evaluator
@@ -264,6 +265,27 @@ class ObjectDetector(classifier.Classifier):
     coco_metrics = coco_eval.result()
     return losses, coco_metrics

+  def _create_fixed_anchor(
+      self, anchor_box: List[float]
+  ) -> object_detector_writer.FixedAnchor:
+    """Helper function to create FixedAnchor objects from an anchor box array.
+
+    Args:
+      anchor_box: List of anchor box coordinates in the format of [y_min,
+        x_min, y_max, x_max].
+
+    Returns:
+      A FixedAnchor object representing the anchor_box.
+    """
+    image_shape = self._model_spec.input_image_shape[:2]
+    y_center_norm = (anchor_box[0] + anchor_box[2]) / (2 * image_shape[0])
+    x_center_norm = (anchor_box[1] + anchor_box[3]) / (2 * image_shape[1])
+    height_norm = (anchor_box[2] - anchor_box[0]) / image_shape[0]
+    width_norm = (anchor_box[3] - anchor_box[1]) / image_shape[1]
+    return object_detector_writer.FixedAnchor(
+        x_center_norm, y_center_norm, width_norm, height_norm
+    )
+
   def export_model(
       self,
       model_name: str = 'model.tflite',
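As a standalone check of the arithmetic in _create_fixed_anchor (the numbers are illustrative): a [y_min, x_min, y_max, x_max] anchor of [0, 0, 64, 32] on a 256x256 input normalizes to center (0.0625, 0.125) with size (0.125, 0.25).

anchor_box = [0.0, 0.0, 64.0, 32.0]  # [y_min, x_min, y_max, x_max]
height, width = 256, 256             # self._model_spec.input_image_shape[:2]
y_center_norm = (anchor_box[0] + anchor_box[2]) / (2 * height)  # 0.125
x_center_norm = (anchor_box[1] + anchor_box[3]) / (2 * width)   # 0.0625
height_norm = (anchor_box[2] - anchor_box[0]) / height          # 0.25
width_norm = (anchor_box[3] - anchor_box[1]) / width            # 0.125
# FixedAnchor is then constructed as (x_center, y_center, width, height).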
@@ -328,11 +350,40 @@ class ObjectDetector(classifier.Classifier):
       converter.target_spec.supported_ops = (tf.lite.OpsSet.TFLITE_BUILTINS,)
       tflite_model = converter.convert()

-    writer = object_detector_writer.MetadataWriter.create_for_models_with_nms(
+    # Build anchors
+    raw_anchor_boxes = self._preprocessor.anchor_boxes
+    anchors = []
+    for _, anchor_boxes in raw_anchor_boxes.items():
+      anchor_boxes_reshaped = anchor_boxes.numpy().reshape((-1, 4))
+      for ab in anchor_boxes_reshaped:
+        anchors.append(self._create_fixed_anchor(ab))
+
+    ssd_anchors_options = object_detector_writer.SsdAnchorsOptions(
+        object_detector_writer.FixedAnchorsSchema(anchors)
+    )
+    tensor_decoding_options = object_detector_writer.TensorsDecodingOptions(
+        num_classes=self._num_classes,
+        num_boxes=len(anchors),
+        num_coords=4,
+        keypoint_coord_offset=0,
+        num_keypoints=0,
+        num_values_per_keypoint=2,
+        x_scale=1,
+        y_scale=1,
+        w_scale=1,
+        h_scale=1,
+        apply_exponential_on_box_size=True,
+        sigmoid_score=False,
+    )
+
+    writer = object_detector_writer.MetadataWriter.create_for_models_without_nms(
         tflite_model,
         self._model_spec.mean_rgb,
         self._model_spec.stddev_rgb,
         labels=metadata_writer.Labels().add(list(self._label_names)),
+        ssd_anchors_options=ssd_anchors_options,
+        tensors_decoding_options=tensor_decoding_options,
+        output_tensors_order=metadata_info.RawDetectionOutputTensorsOrder.LOCATION_SCORE,
     )
     tflite_model_with_metadata, metadata_json = writer.populate()
     model_util.save_tflite(tflite_model_with_metadata, tflite_file)
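The TensorsDecodingOptions above tell downstream consumers of the metadata how to turn the raw output tensors back into boxes. With all four scales set to 1 and apply_exponential_on_box_size=True, the decode is the standard SSD form; here is a NumPy sketch under assumed tensor layouts (an illustration, not MediaPipe's decoder API):

import numpy as np

def decode_boxes(raw_boxes, anchors):
  # raw_boxes: [N, 4] predicted (dy, dx, dh, dw); anchors: [N, 4] as
  # (y_center, x_center, h, w), normalized. Both layouts are assumptions.
  dy, dx, dh, dw = np.moveaxis(raw_boxes, -1, 0)
  ya, xa, ha, wa = np.moveaxis(anchors, -1, 0)
  y_center = dy * ha + ya  # y_scale == 1
  x_center = dx * wa + xa  # x_scale == 1
  h = np.exp(dh) * ha      # apply_exponential_on_box_size=True, h_scale == 1
  w = np.exp(dw) * wa      # w_scale == 1
  return np.stack(
      [y_center - h / 2, x_center - w / 2,
       y_center + h / 2, x_center + w / 2], axis=-1)  # [ymin, xmin, ymax, xmax]

Likewise, sigmoid_score=False declares that the score tensor should be read as-is rather than passed through a sigmoid.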

mediapipe/model_maker/python/vision/object_detector/preprocessor.py

@@ -44,6 +44,26 @@ class Preprocessor(object):
     self._aug_scale_max = 2.0
     self._max_num_instances = 100

+    self._padded_size = preprocess_ops.compute_padded_size(
+        self._output_size, 2**self._max_level
+    )
+
+    input_anchor = anchor.build_anchor_generator(
+        min_level=self._min_level,
+        max_level=self._max_level,
+        num_scales=self._num_scales,
+        aspect_ratios=self._aspect_ratios,
+        anchor_size=self._anchor_size,
+    )
+    self._anchor_boxes = input_anchor(image_size=self._output_size)
+    self._anchor_labeler = anchor.AnchorLabeler(
+        self._match_threshold, self._unmatched_threshold
+    )
+
+  @property
+  def anchor_boxes(self):
+    return self._anchor_boxes
+
   def __call__(
       self, data: Mapping[str, Any], is_training: bool = True
   ) -> Tuple[tf.Tensor, Mapping[str, Any]]:
@@ -90,13 +110,10 @@ class Preprocessor(object):
     image, image_info = preprocess_ops.resize_and_crop_image(
         image,
         self._output_size,
-        padded_size=preprocess_ops.compute_padded_size(
-            self._output_size, 2**self._max_level
-        ),
+        padded_size=self._padded_size,
         aug_scale_min=(self._aug_scale_min if is_training else 1.0),
         aug_scale_max=(self._aug_scale_max if is_training else 1.0),
     )
-    image_height, image_width, _ = image.get_shape().as_list()

     # Resize and crop boxes.
     image_scale = image_info[2, :]
@@ -110,20 +127,9 @@ class Preprocessor(object):
     classes = tf.gather(classes, indices)

     # Assign anchors.
-    input_anchor = anchor.build_anchor_generator(
-        min_level=self._min_level,
-        max_level=self._max_level,
-        num_scales=self._num_scales,
-        aspect_ratios=self._aspect_ratios,
-        anchor_size=self._anchor_size,
-    )
-    anchor_boxes = input_anchor(image_size=(image_height, image_width))
-    anchor_labeler = anchor.AnchorLabeler(
-        self._match_threshold, self._unmatched_threshold
-    )

     (cls_targets, box_targets, _, cls_weights, box_weights) = (
-        anchor_labeler.label_anchors(
-            anchor_boxes, boxes, tf.expand_dims(classes, axis=1)
+        self._anchor_labeler.label_anchors(
+            self.anchor_boxes, boxes, tf.expand_dims(classes, axis=1)
         )
     )
@@ -134,7 +140,7 @@ class Preprocessor(object):
     labels = {
         'cls_targets': cls_targets,
         'box_targets': box_targets,
-        'anchor_boxes': anchor_boxes,
+        'anchor_boxes': self.anchor_boxes,
         'cls_weights': cls_weights,
         'box_weights': box_weights,
         'image_info': image_info,
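The preprocessor change above hoists anchor generation and the AnchorLabeler out of __call__ and into __init__, so they are built once per Preprocessor rather than once per training example, and the new anchor_boxes property is what export_model reads when assembling the metadata anchors. A small sketch of that consumer view (it assumes an already constructed `preprocessor` instance):

# anchor_boxes maps each pyramid level to a tensor of [y_min, x_min, y_max,
# x_max] corner boxes; flattening every level to [-1, 4] yields the anchor
# list, so the total below should equal num_boxes=len(anchors) in export_model.
num_anchors = 0
for _, boxes in preprocessor.anchor_boxes.items():
  num_anchors += boxes.numpy().reshape((-1, 4)).shape[0]
print(num_anchors)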