Object Detector remove nms operation from exported tflite
PiperOrigin-RevId: 529559380
This commit is contained in:
parent
12b0b6fad1
commit
61cfe2ca9b
|
@ -88,6 +88,17 @@ py_test(
|
|||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "detection",
|
||||
srcs = ["detection.py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "detection_test",
|
||||
srcs = ["detection_test.py"],
|
||||
deps = [":detection"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "hyperparameters",
|
||||
srcs = ["hyperparameters.py"],
|
||||
|
@ -116,6 +127,7 @@ py_library(
|
|||
name = "model",
|
||||
srcs = ["model.py"],
|
||||
deps = [
|
||||
":detection",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
],
|
||||
|
@ -163,6 +175,7 @@ py_library(
|
|||
"//mediapipe/model_maker/python/core/tasks:classifier",
|
||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_info",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:object_detector",
|
||||
],
|
||||
|
|
|
@ -32,6 +32,7 @@ ObjectDetectorOptions = object_detector_options.ObjectDetectorOptions
|
|||
# Remove duplicated and non-public API
|
||||
del dataset
|
||||
del dataset_util # pylint: disable=undefined-variable
|
||||
del detection # pylint: disable=undefined-variable
|
||||
del hyperparameters
|
||||
del model # pylint: disable=undefined-variable
|
||||
del model_options
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Custom Detection export module for Object Detection."""
|
||||
|
||||
from typing import Any, Mapping
|
||||
|
||||
from official.vision.serving import detection
|
||||
|
||||
|
||||
class DetectionModule(detection.DetectionModule):
|
||||
"""A serving detection module for exporting the model.
|
||||
|
||||
This module overrides the tensorflow_models DetectionModule by only outputting
|
||||
the pre-nms detection_boxes and detection_scores.
|
||||
"""
|
||||
|
||||
def serve(self, images) -> Mapping[str, Any]:
|
||||
result = super().serve(images)
|
||||
final_outputs = {
|
||||
'detection_boxes': result['detection_boxes'],
|
||||
'detection_scores': result['detection_scores'],
|
||||
}
|
||||
return final_outputs
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest import mock
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.vision.object_detector import detection
|
||||
from official.core import config_definitions as cfg
|
||||
from official.vision import configs
|
||||
from official.vision.serving import detection as detection_module
|
||||
|
||||
|
||||
class ObjectDetectorTest(tf.test.TestCase):
|
||||
|
||||
@mock.patch.object(detection_module.DetectionModule, 'serve', autospec=True)
|
||||
def test_detection_module(self, mock_serve):
|
||||
mock_serve.return_value = {
|
||||
'detection_boxes': 1,
|
||||
'detection_scores': 2,
|
||||
'detection_classes': 3,
|
||||
'num_detections': 4,
|
||||
}
|
||||
model_config = configs.retinanet.RetinaNet(
|
||||
min_level=3,
|
||||
max_level=7,
|
||||
num_classes=10,
|
||||
input_size=[256, 256, 3],
|
||||
anchor=configs.retinanet.Anchor(
|
||||
num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=3
|
||||
),
|
||||
backbone=configs.backbones.Backbone(
|
||||
type='mobilenet', mobilenet=configs.backbones.MobileNet()
|
||||
),
|
||||
decoder=configs.decoders.Decoder(
|
||||
type='fpn',
|
||||
fpn=configs.decoders.FPN(
|
||||
num_filters=128, use_separable_conv=True, use_keras_layer=True
|
||||
),
|
||||
),
|
||||
head=configs.retinanet.RetinaNetHead(
|
||||
num_filters=128, use_separable_conv=True
|
||||
),
|
||||
detection_generator=configs.retinanet.DetectionGenerator(),
|
||||
norm_activation=configs.common.NormActivation(activation='relu6'),
|
||||
)
|
||||
task_config = configs.retinanet.RetinaNetTask(model=model_config)
|
||||
params = cfg.ExperimentConfig(
|
||||
task=task_config,
|
||||
)
|
||||
detection_instance = detection.DetectionModule(
|
||||
params=params, batch_size=1, input_image_size=[256, 256]
|
||||
)
|
||||
outputs = detection_instance.serve(0)
|
||||
expected_outputs = {
|
||||
'detection_boxes': 1,
|
||||
'detection_scores': 2,
|
||||
}
|
||||
self.assertAllEqual(outputs, expected_outputs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -18,6 +18,7 @@ from typing import Mapping, Optional, Sequence, Union
|
|||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.vision.object_detector import detection
|
||||
from mediapipe.model_maker.python.vision.object_detector import model_options as model_opt
|
||||
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
|
||||
from official.core import config_definitions as cfg
|
||||
|
@ -29,7 +30,6 @@ from official.vision.losses import loss_utils
|
|||
from official.vision.modeling import factory
|
||||
from official.vision.modeling import retinanet_model
|
||||
from official.vision.modeling.layers import detection_generator
|
||||
from official.vision.serving import detection
|
||||
|
||||
|
||||
class ObjectDetectorModel(tf.keras.Model):
|
||||
|
@ -199,6 +199,7 @@ class ObjectDetectorModel(tf.keras.Model):
|
|||
max_detections=10,
|
||||
max_classes_per_detection=1,
|
||||
normalize_anchor_coordinates=True,
|
||||
omit_nms=True,
|
||||
),
|
||||
)
|
||||
tflite_post_processing_config = (
|
||||
|
|
|
@ -28,6 +28,7 @@ from mediapipe.model_maker.python.vision.object_detector import model_options as
|
|||
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
|
||||
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
|
||||
from mediapipe.model_maker.python.vision.object_detector import preprocessor
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import object_detector as object_detector_writer
|
||||
from official.vision.evaluation import coco_evaluator
|
||||
|
@ -264,6 +265,27 @@ class ObjectDetector(classifier.Classifier):
|
|||
coco_metrics = coco_eval.result()
|
||||
return losses, coco_metrics
|
||||
|
||||
def _create_fixed_anchor(
|
||||
self, anchor_box: List[float]
|
||||
) -> object_detector_writer.FixedAnchor:
|
||||
"""Helper function to create FixedAnchor objects from an anchor box array.
|
||||
|
||||
Args:
|
||||
anchor_box: List of anchor box coordinates in the format of [x_min, y_min,
|
||||
x_max, y_max].
|
||||
|
||||
Returns:
|
||||
A FixedAnchor object representing the anchor_box.
|
||||
"""
|
||||
image_shape = self._model_spec.input_image_shape[:2]
|
||||
y_center_norm = (anchor_box[0] + anchor_box[2]) / (2 * image_shape[0])
|
||||
x_center_norm = (anchor_box[1] + anchor_box[3]) / (2 * image_shape[1])
|
||||
height_norm = (anchor_box[2] - anchor_box[0]) / image_shape[0]
|
||||
width_norm = (anchor_box[3] - anchor_box[1]) / image_shape[1]
|
||||
return object_detector_writer.FixedAnchor(
|
||||
x_center_norm, y_center_norm, width_norm, height_norm
|
||||
)
|
||||
|
||||
def export_model(
|
||||
self,
|
||||
model_name: str = 'model.tflite',
|
||||
|
@ -328,11 +350,40 @@ class ObjectDetector(classifier.Classifier):
|
|||
converter.target_spec.supported_ops = (tf.lite.OpsSet.TFLITE_BUILTINS,)
|
||||
tflite_model = converter.convert()
|
||||
|
||||
writer = object_detector_writer.MetadataWriter.create_for_models_with_nms(
|
||||
# Build anchors
|
||||
raw_anchor_boxes = self._preprocessor.anchor_boxes
|
||||
anchors = []
|
||||
for _, anchor_boxes in raw_anchor_boxes.items():
|
||||
anchor_boxes_reshaped = anchor_boxes.numpy().reshape((-1, 4))
|
||||
for ab in anchor_boxes_reshaped:
|
||||
anchors.append(self._create_fixed_anchor(ab))
|
||||
|
||||
ssd_anchors_options = object_detector_writer.SsdAnchorsOptions(
|
||||
object_detector_writer.FixedAnchorsSchema(anchors)
|
||||
)
|
||||
|
||||
tensor_decoding_options = object_detector_writer.TensorsDecodingOptions(
|
||||
num_classes=self._num_classes,
|
||||
num_boxes=len(anchors),
|
||||
num_coords=4,
|
||||
keypoint_coord_offset=0,
|
||||
num_keypoints=0,
|
||||
num_values_per_keypoint=2,
|
||||
x_scale=1,
|
||||
y_scale=1,
|
||||
w_scale=1,
|
||||
h_scale=1,
|
||||
apply_exponential_on_box_size=True,
|
||||
sigmoid_score=False,
|
||||
)
|
||||
writer = object_detector_writer.MetadataWriter.create_for_models_without_nms(
|
||||
tflite_model,
|
||||
self._model_spec.mean_rgb,
|
||||
self._model_spec.stddev_rgb,
|
||||
labels=metadata_writer.Labels().add(list(self._label_names)),
|
||||
ssd_anchors_options=ssd_anchors_options,
|
||||
tensors_decoding_options=tensor_decoding_options,
|
||||
output_tensors_order=metadata_info.RawDetectionOutputTensorsOrder.LOCATION_SCORE,
|
||||
)
|
||||
tflite_model_with_metadata, metadata_json = writer.populate()
|
||||
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
|
||||
|
|
|
@ -44,6 +44,26 @@ class Preprocessor(object):
|
|||
self._aug_scale_max = 2.0
|
||||
self._max_num_instances = 100
|
||||
|
||||
self._padded_size = preprocess_ops.compute_padded_size(
|
||||
self._output_size, 2**self._max_level
|
||||
)
|
||||
|
||||
input_anchor = anchor.build_anchor_generator(
|
||||
min_level=self._min_level,
|
||||
max_level=self._max_level,
|
||||
num_scales=self._num_scales,
|
||||
aspect_ratios=self._aspect_ratios,
|
||||
anchor_size=self._anchor_size,
|
||||
)
|
||||
self._anchor_boxes = input_anchor(image_size=self._output_size)
|
||||
self._anchor_labeler = anchor.AnchorLabeler(
|
||||
self._match_threshold, self._unmatched_threshold
|
||||
)
|
||||
|
||||
@property
|
||||
def anchor_boxes(self):
|
||||
return self._anchor_boxes
|
||||
|
||||
def __call__(
|
||||
self, data: Mapping[str, Any], is_training: bool = True
|
||||
) -> Tuple[tf.Tensor, Mapping[str, Any]]:
|
||||
|
@ -90,13 +110,10 @@ class Preprocessor(object):
|
|||
image, image_info = preprocess_ops.resize_and_crop_image(
|
||||
image,
|
||||
self._output_size,
|
||||
padded_size=preprocess_ops.compute_padded_size(
|
||||
self._output_size, 2**self._max_level
|
||||
),
|
||||
padded_size=self._padded_size,
|
||||
aug_scale_min=(self._aug_scale_min if is_training else 1.0),
|
||||
aug_scale_max=(self._aug_scale_max if is_training else 1.0),
|
||||
)
|
||||
image_height, image_width, _ = image.get_shape().as_list()
|
||||
|
||||
# Resize and crop boxes.
|
||||
image_scale = image_info[2, :]
|
||||
|
@ -110,20 +127,9 @@ class Preprocessor(object):
|
|||
classes = tf.gather(classes, indices)
|
||||
|
||||
# Assign anchors.
|
||||
input_anchor = anchor.build_anchor_generator(
|
||||
min_level=self._min_level,
|
||||
max_level=self._max_level,
|
||||
num_scales=self._num_scales,
|
||||
aspect_ratios=self._aspect_ratios,
|
||||
anchor_size=self._anchor_size,
|
||||
)
|
||||
anchor_boxes = input_anchor(image_size=(image_height, image_width))
|
||||
anchor_labeler = anchor.AnchorLabeler(
|
||||
self._match_threshold, self._unmatched_threshold
|
||||
)
|
||||
(cls_targets, box_targets, _, cls_weights, box_weights) = (
|
||||
anchor_labeler.label_anchors(
|
||||
anchor_boxes, boxes, tf.expand_dims(classes, axis=1)
|
||||
self._anchor_labeler.label_anchors(
|
||||
self.anchor_boxes, boxes, tf.expand_dims(classes, axis=1)
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -134,7 +140,7 @@ class Preprocessor(object):
|
|||
labels = {
|
||||
'cls_targets': cls_targets,
|
||||
'box_targets': box_targets,
|
||||
'anchor_boxes': anchor_boxes,
|
||||
'anchor_boxes': self.anchor_boxes,
|
||||
'cls_weights': cls_weights,
|
||||
'box_weights': box_weights,
|
||||
'image_info': image_info,
|
||||
|
|
Loading…
Reference in New Issue
Block a user