Object Detector: remove NMS operation from exported TFLite model
PiperOrigin-RevId: 529559380
parent 12b0b6fad1
commit 61cfe2ca9b
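With the NMS step dropped from the exported graph, the serving signature ends at the raw, pre-NMS tensors, and the anchors plus decoding parameters are carried in the TFLite metadata instead. A quick sanity check on an exported file is to list its output tensors, which should now be just the raw locations and scores. Minimal sketch, assuming a hypothetical exported_model.tflite written by ObjectDetector.export_model() after this change:

import tensorflow as tf

# 'exported_model.tflite' is a placeholder path, not a file from this commit.
interpreter = tf.lite.Interpreter(model_path='exported_model.tflite')
interpreter.allocate_tensors()

# With the NMS op omitted, only the raw detection_boxes and detection_scores
# outputs should show up here.
for detail in interpreter.get_output_details():
  print(detail['name'], detail['shape'])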
--- a/mediapipe/model_maker/python/vision/object_detector/BUILD
+++ b/mediapipe/model_maker/python/vision/object_detector/BUILD
@@ -88,6 +88,17 @@ py_test(
     ],
 )
 
+py_library(
+    name = "detection",
+    srcs = ["detection.py"],
+)
+
+py_test(
+    name = "detection_test",
+    srcs = ["detection_test.py"],
+    deps = [":detection"],
+)
+
 py_library(
     name = "hyperparameters",
     srcs = ["hyperparameters.py"],
@@ -116,6 +127,7 @@ py_library(
     name = "model",
     srcs = ["model.py"],
     deps = [
+        ":detection",
        ":model_options",
        ":model_spec",
    ],
@@ -163,6 +175,7 @@ py_library(
        "//mediapipe/model_maker/python/core/tasks:classifier",
        "//mediapipe/model_maker/python/core/utils:model_util",
        "//mediapipe/model_maker/python/core/utils:quantization",
+        "//mediapipe/tasks/python/metadata/metadata_writers:metadata_info",
        "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
        "//mediapipe/tasks/python/metadata/metadata_writers:object_detector",
    ],
--- a/mediapipe/model_maker/python/vision/object_detector/__init__.py
+++ b/mediapipe/model_maker/python/vision/object_detector/__init__.py
@@ -32,6 +32,7 @@ ObjectDetectorOptions = object_detector_options.ObjectDetectorOptions
 # Remove duplicated and non-public API
 del dataset
 del dataset_util  # pylint: disable=undefined-variable
+del detection  # pylint: disable=undefined-variable
 del hyperparameters
 del model  # pylint: disable=undefined-variable
 del model_options
--- /dev/null
+++ b/mediapipe/model_maker/python/vision/object_detector/detection.py
@@ -0,0 +1,34 @@
+# Copyright 2023 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Custom Detection export module for Object Detection."""
+
+from typing import Any, Mapping
+
+from official.vision.serving import detection
+
+
+class DetectionModule(detection.DetectionModule):
+  """A serving detection module for exporting the model.
+
+  This module overrides the tensorflow_models DetectionModule by only outputting
+  the pre-nms detection_boxes and detection_scores.
+  """
+
+  def serve(self, images) -> Mapping[str, Any]:
+    result = super().serve(images)
+    final_outputs = {
+        'detection_boxes': result['detection_boxes'],
+        'detection_scores': result['detection_scores'],
+    }
+    return final_outputs
--- /dev/null
+++ b/mediapipe/model_maker/python/vision/object_detector/detection_test.py
@@ -0,0 +1,73 @@
+# Copyright 2023 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest import mock
+import tensorflow as tf
+
+from mediapipe.model_maker.python.vision.object_detector import detection
+from official.core import config_definitions as cfg
+from official.vision import configs
+from official.vision.serving import detection as detection_module
+
+
+class ObjectDetectorTest(tf.test.TestCase):
+
+  @mock.patch.object(detection_module.DetectionModule, 'serve', autospec=True)
+  def test_detection_module(self, mock_serve):
+    mock_serve.return_value = {
+        'detection_boxes': 1,
+        'detection_scores': 2,
+        'detection_classes': 3,
+        'num_detections': 4,
+    }
+    model_config = configs.retinanet.RetinaNet(
+        min_level=3,
+        max_level=7,
+        num_classes=10,
+        input_size=[256, 256, 3],
+        anchor=configs.retinanet.Anchor(
+            num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=3
+        ),
+        backbone=configs.backbones.Backbone(
+            type='mobilenet', mobilenet=configs.backbones.MobileNet()
+        ),
+        decoder=configs.decoders.Decoder(
+            type='fpn',
+            fpn=configs.decoders.FPN(
+                num_filters=128, use_separable_conv=True, use_keras_layer=True
+            ),
+        ),
+        head=configs.retinanet.RetinaNetHead(
+            num_filters=128, use_separable_conv=True
+        ),
+        detection_generator=configs.retinanet.DetectionGenerator(),
+        norm_activation=configs.common.NormActivation(activation='relu6'),
+    )
+    task_config = configs.retinanet.RetinaNetTask(model=model_config)
+    params = cfg.ExperimentConfig(
+        task=task_config,
+    )
+    detection_instance = detection.DetectionModule(
+        params=params, batch_size=1, input_image_size=[256, 256]
+    )
+    outputs = detection_instance.serve(0)
+    expected_outputs = {
+        'detection_boxes': 1,
+        'detection_scores': 2,
+    }
+    self.assertAllEqual(outputs, expected_outputs)
+
+
+if __name__ == '__main__':
+  tf.test.main()
--- a/mediapipe/model_maker/python/vision/object_detector/model.py
+++ b/mediapipe/model_maker/python/vision/object_detector/model.py
@@ -18,6 +18,7 @@ from typing import Mapping, Optional, Sequence, Union
 
 import tensorflow as tf
 
+from mediapipe.model_maker.python.vision.object_detector import detection
 from mediapipe.model_maker.python.vision.object_detector import model_options as model_opt
 from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
 from official.core import config_definitions as cfg
@@ -29,7 +30,6 @@ from official.vision.losses import loss_utils
 from official.vision.modeling import factory
 from official.vision.modeling import retinanet_model
 from official.vision.modeling.layers import detection_generator
-from official.vision.serving import detection
 
 
 class ObjectDetectorModel(tf.keras.Model):
@@ -199,6 +199,7 @@ class ObjectDetectorModel(tf.keras.Model):
             max_detections=10,
             max_classes_per_detection=1,
             normalize_anchor_coordinates=True,
+            omit_nms=True,
         ),
     )
     tflite_post_processing_config = (
--- a/mediapipe/model_maker/python/vision/object_detector/object_detector.py
+++ b/mediapipe/model_maker/python/vision/object_detector/object_detector.py
@@ -28,6 +28,7 @@ from mediapipe.model_maker.python.vision.object_detector import model_options as
 from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
 from mediapipe.model_maker.python.vision.object_detector import object_detector_options
 from mediapipe.model_maker.python.vision.object_detector import preprocessor
+from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
 from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
 from mediapipe.tasks.python.metadata.metadata_writers import object_detector as object_detector_writer
 from official.vision.evaluation import coco_evaluator
@@ -264,6 +265,27 @@ class ObjectDetector(classifier.Classifier):
     coco_metrics = coco_eval.result()
     return losses, coco_metrics
 
+  def _create_fixed_anchor(
+      self, anchor_box: List[float]
+  ) -> object_detector_writer.FixedAnchor:
+    """Helper function to create FixedAnchor objects from an anchor box array.
+
+    Args:
+      anchor_box: List of anchor box coordinates in the format of [x_min, y_min,
+        x_max, y_max].
+
+    Returns:
+      A FixedAnchor object representing the anchor_box.
+    """
+    image_shape = self._model_spec.input_image_shape[:2]
+    y_center_norm = (anchor_box[0] + anchor_box[2]) / (2 * image_shape[0])
+    x_center_norm = (anchor_box[1] + anchor_box[3]) / (2 * image_shape[1])
+    height_norm = (anchor_box[2] - anchor_box[0]) / image_shape[0]
+    width_norm = (anchor_box[3] - anchor_box[1]) / image_shape[1]
+    return object_detector_writer.FixedAnchor(
+        x_center_norm, y_center_norm, width_norm, height_norm
+    )
+
   def export_model(
       self,
       model_name: str = 'model.tflite',
@@ -328,11 +350,40 @@
     converter.target_spec.supported_ops = (tf.lite.OpsSet.TFLITE_BUILTINS,)
     tflite_model = converter.convert()
 
-    writer = object_detector_writer.MetadataWriter.create_for_models_with_nms(
+    # Build anchors
+    raw_anchor_boxes = self._preprocessor.anchor_boxes
+    anchors = []
+    for _, anchor_boxes in raw_anchor_boxes.items():
+      anchor_boxes_reshaped = anchor_boxes.numpy().reshape((-1, 4))
+      for ab in anchor_boxes_reshaped:
+        anchors.append(self._create_fixed_anchor(ab))
+
+    ssd_anchors_options = object_detector_writer.SsdAnchorsOptions(
+        object_detector_writer.FixedAnchorsSchema(anchors)
+    )
+
+    tensor_decoding_options = object_detector_writer.TensorsDecodingOptions(
+        num_classes=self._num_classes,
+        num_boxes=len(anchors),
+        num_coords=4,
+        keypoint_coord_offset=0,
+        num_keypoints=0,
+        num_values_per_keypoint=2,
+        x_scale=1,
+        y_scale=1,
+        w_scale=1,
+        h_scale=1,
+        apply_exponential_on_box_size=True,
+        sigmoid_score=False,
+    )
+    writer = object_detector_writer.MetadataWriter.create_for_models_without_nms(
         tflite_model,
         self._model_spec.mean_rgb,
         self._model_spec.stddev_rgb,
         labels=metadata_writer.Labels().add(list(self._label_names)),
+        ssd_anchors_options=ssd_anchors_options,
+        tensors_decoding_options=tensor_decoding_options,
+        output_tensors_order=metadata_info.RawDetectionOutputTensorsOrder.LOCATION_SCORE,
     )
     tflite_model_with_metadata, metadata_json = writer.populate()
     model_util.save_tflite(tflite_model_with_metadata, tflite_file)
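The TensorsDecodingOptions written above tell a downstream consumer how to turn the raw outputs back into boxes. A rough sketch of that decoding under the options set here (all scales 1, exponential box size, no sigmoid on scores), assuming the common SSD [dy, dx, dh, dw] regression ordering; anchor columns follow the FixedAnchor order (x_center, y_center, width, height), all normalized:

import numpy as np

def decode_boxes(raw_boxes, anchors):
  # raw_boxes: (num_boxes, 4) regressions, assumed to be [dy, dx, dh, dw].
  # anchors:   (num_boxes, 4) as [x_center, y_center, width, height].
  dy, dx, dh, dw = (raw_boxes[:, i] for i in range(4))
  xa, ya, wa, ha = (anchors[:, i] for i in range(4))
  ycenter = dy * ha + ya   # y_scale == 1
  xcenter = dx * wa + xa   # x_scale == 1
  h = np.exp(dh) * ha      # apply_exponential_on_box_size=True, h_scale == 1
  w = np.exp(dw) * wa      # w_scale == 1
  return np.stack(
      [ycenter - h / 2, xcenter - w / 2, ycenter + h / 2, xcenter + w / 2],
      axis=-1,
  )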
--- a/mediapipe/model_maker/python/vision/object_detector/preprocessor.py
+++ b/mediapipe/model_maker/python/vision/object_detector/preprocessor.py
@@ -44,6 +44,26 @@ class Preprocessor(object):
     self._aug_scale_max = 2.0
     self._max_num_instances = 100
 
+    self._padded_size = preprocess_ops.compute_padded_size(
+        self._output_size, 2**self._max_level
+    )
+
+    input_anchor = anchor.build_anchor_generator(
+        min_level=self._min_level,
+        max_level=self._max_level,
+        num_scales=self._num_scales,
+        aspect_ratios=self._aspect_ratios,
+        anchor_size=self._anchor_size,
+    )
+    self._anchor_boxes = input_anchor(image_size=self._output_size)
+    self._anchor_labeler = anchor.AnchorLabeler(
+        self._match_threshold, self._unmatched_threshold
+    )
+
+  @property
+  def anchor_boxes(self):
+    return self._anchor_boxes
+
   def __call__(
       self, data: Mapping[str, Any], is_training: bool = True
   ) -> Tuple[tf.Tensor, Mapping[str, Any]]:
@@ -90,13 +110,10 @@
     image, image_info = preprocess_ops.resize_and_crop_image(
         image,
         self._output_size,
-        padded_size=preprocess_ops.compute_padded_size(
-            self._output_size, 2**self._max_level
-        ),
+        padded_size=self._padded_size,
         aug_scale_min=(self._aug_scale_min if is_training else 1.0),
         aug_scale_max=(self._aug_scale_max if is_training else 1.0),
     )
-    image_height, image_width, _ = image.get_shape().as_list()
 
     # Resize and crop boxes.
     image_scale = image_info[2, :]
@@ -110,20 +127,9 @@
       classes = tf.gather(classes, indices)
 
     # Assign anchors.
-    input_anchor = anchor.build_anchor_generator(
-        min_level=self._min_level,
-        max_level=self._max_level,
-        num_scales=self._num_scales,
-        aspect_ratios=self._aspect_ratios,
-        anchor_size=self._anchor_size,
-    )
-    anchor_boxes = input_anchor(image_size=(image_height, image_width))
-    anchor_labeler = anchor.AnchorLabeler(
-        self._match_threshold, self._unmatched_threshold
-    )
     (cls_targets, box_targets, _, cls_weights, box_weights) = (
-        anchor_labeler.label_anchors(
-            anchor_boxes, boxes, tf.expand_dims(classes, axis=1)
+        self._anchor_labeler.label_anchors(
+            self.anchor_boxes, boxes, tf.expand_dims(classes, axis=1)
         )
     )
 
@@ -134,7 +140,7 @@
     labels = {
         'cls_targets': cls_targets,
         'box_targets': box_targets,
-        'anchor_boxes': anchor_boxes,
+        'anchor_boxes': self.anchor_boxes,
         'cls_weights': cls_weights,
         'box_weights': box_weights,
         'image_info': image_info,
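export_model() above flattens every per-level anchor tensor exposed by the new anchor_boxes property into the num_boxes count. As rough arithmetic only, assuming the usual one anchor set per feature-map cell at stride 2**level, the RetinaNet config used in detection_test.py (256x256 input, levels 3-7, 3 scales x 3 aspect ratios) would give:

num_anchors_per_cell = 3 * 3  # num_scales * len(aspect_ratios)
total = sum((256 // 2**level) ** 2 * num_anchors_per_cell for level in range(3, 8))
print(total)  # 12276 anchors -> the num_boxes / len(anchors) value in TensorsDecodingOptions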