Object Detector remove nms operation from exported tflite
PiperOrigin-RevId: 529559380
This commit is contained in:
		
							parent
							
								
									12b0b6fad1
								
							
						
					
					
						commit
						61cfe2ca9b
					
				| 
						 | 
					@ -88,6 +88,17 @@ py_test(
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					py_library(
 | 
				
			||||||
 | 
					    name = "detection",
 | 
				
			||||||
 | 
					    srcs = ["detection.py"],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					py_test(
 | 
				
			||||||
 | 
					    name = "detection_test",
 | 
				
			||||||
 | 
					    srcs = ["detection_test.py"],
 | 
				
			||||||
 | 
					    deps = [":detection"],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
py_library(
 | 
					py_library(
 | 
				
			||||||
    name = "hyperparameters",
 | 
					    name = "hyperparameters",
 | 
				
			||||||
    srcs = ["hyperparameters.py"],
 | 
					    srcs = ["hyperparameters.py"],
 | 
				
			||||||
| 
						 | 
					@ -116,6 +127,7 @@ py_library(
 | 
				
			||||||
    name = "model",
 | 
					    name = "model",
 | 
				
			||||||
    srcs = ["model.py"],
 | 
					    srcs = ["model.py"],
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":detection",
 | 
				
			||||||
        ":model_options",
 | 
					        ":model_options",
 | 
				
			||||||
        ":model_spec",
 | 
					        ":model_spec",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
| 
						 | 
					@ -163,6 +175,7 @@ py_library(
 | 
				
			||||||
        "//mediapipe/model_maker/python/core/tasks:classifier",
 | 
					        "//mediapipe/model_maker/python/core/tasks:classifier",
 | 
				
			||||||
        "//mediapipe/model_maker/python/core/utils:model_util",
 | 
					        "//mediapipe/model_maker/python/core/utils:model_util",
 | 
				
			||||||
        "//mediapipe/model_maker/python/core/utils:quantization",
 | 
					        "//mediapipe/model_maker/python/core/utils:quantization",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/metadata/metadata_writers:metadata_info",
 | 
				
			||||||
        "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
 | 
					        "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
 | 
				
			||||||
        "//mediapipe/tasks/python/metadata/metadata_writers:object_detector",
 | 
					        "//mediapipe/tasks/python/metadata/metadata_writers:object_detector",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -32,6 +32,7 @@ ObjectDetectorOptions = object_detector_options.ObjectDetectorOptions
 | 
				
			||||||
# Remove duplicated and non-public API
 | 
					# Remove duplicated and non-public API
 | 
				
			||||||
del dataset
 | 
					del dataset
 | 
				
			||||||
del dataset_util  # pylint: disable=undefined-variable
 | 
					del dataset_util  # pylint: disable=undefined-variable
 | 
				
			||||||
 | 
					del detection  # pylint: disable=undefined-variable
 | 
				
			||||||
del hyperparameters
 | 
					del hyperparameters
 | 
				
			||||||
del model  # pylint: disable=undefined-variable
 | 
					del model  # pylint: disable=undefined-variable
 | 
				
			||||||
del model_options
 | 
					del model_options
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,34 @@
 | 
				
			||||||
 | 
					# Copyright 2023 The MediaPipe Authors.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					"""Custom Detection export module for Object Detection."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from typing import Any, Mapping
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from official.vision.serving import detection
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DetectionModule(detection.DetectionModule):
 | 
				
			||||||
 | 
					  """A serving detection module for exporting the model.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  This module overrides the tensorflow_models DetectionModule by only outputting
 | 
				
			||||||
 | 
					    the pre-nms detection_boxes and detection_scores.
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def serve(self, images) -> Mapping[str, Any]:
 | 
				
			||||||
 | 
					    result = super().serve(images)
 | 
				
			||||||
 | 
					    final_outputs = {
 | 
				
			||||||
 | 
					        'detection_boxes': result['detection_boxes'],
 | 
				
			||||||
 | 
					        'detection_scores': result['detection_scores'],
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    return final_outputs
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,73 @@
 | 
				
			||||||
 | 
					# Copyright 2023 The MediaPipe Authors.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the 'License');
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an 'AS IS' BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from unittest import mock
 | 
				
			||||||
 | 
					import tensorflow as tf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.vision.object_detector import detection
 | 
				
			||||||
 | 
					from official.core import config_definitions as cfg
 | 
				
			||||||
 | 
					from official.vision import configs
 | 
				
			||||||
 | 
					from official.vision.serving import detection as detection_module
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ObjectDetectorTest(tf.test.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @mock.patch.object(detection_module.DetectionModule, 'serve', autospec=True)
 | 
				
			||||||
 | 
					  def test_detection_module(self, mock_serve):
 | 
				
			||||||
 | 
					    mock_serve.return_value = {
 | 
				
			||||||
 | 
					        'detection_boxes': 1,
 | 
				
			||||||
 | 
					        'detection_scores': 2,
 | 
				
			||||||
 | 
					        'detection_classes': 3,
 | 
				
			||||||
 | 
					        'num_detections': 4,
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    model_config = configs.retinanet.RetinaNet(
 | 
				
			||||||
 | 
					        min_level=3,
 | 
				
			||||||
 | 
					        max_level=7,
 | 
				
			||||||
 | 
					        num_classes=10,
 | 
				
			||||||
 | 
					        input_size=[256, 256, 3],
 | 
				
			||||||
 | 
					        anchor=configs.retinanet.Anchor(
 | 
				
			||||||
 | 
					            num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=3
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        backbone=configs.backbones.Backbone(
 | 
				
			||||||
 | 
					            type='mobilenet', mobilenet=configs.backbones.MobileNet()
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        decoder=configs.decoders.Decoder(
 | 
				
			||||||
 | 
					            type='fpn',
 | 
				
			||||||
 | 
					            fpn=configs.decoders.FPN(
 | 
				
			||||||
 | 
					                num_filters=128, use_separable_conv=True, use_keras_layer=True
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        head=configs.retinanet.RetinaNetHead(
 | 
				
			||||||
 | 
					            num_filters=128, use_separable_conv=True
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        detection_generator=configs.retinanet.DetectionGenerator(),
 | 
				
			||||||
 | 
					        norm_activation=configs.common.NormActivation(activation='relu6'),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    task_config = configs.retinanet.RetinaNetTask(model=model_config)
 | 
				
			||||||
 | 
					    params = cfg.ExperimentConfig(
 | 
				
			||||||
 | 
					        task=task_config,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    detection_instance = detection.DetectionModule(
 | 
				
			||||||
 | 
					        params=params, batch_size=1, input_image_size=[256, 256]
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    outputs = detection_instance.serve(0)
 | 
				
			||||||
 | 
					    expected_outputs = {
 | 
				
			||||||
 | 
					        'detection_boxes': 1,
 | 
				
			||||||
 | 
					        'detection_scores': 2,
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    self.assertAllEqual(outputs, expected_outputs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == '__main__':
 | 
				
			||||||
 | 
					  tf.test.main()
 | 
				
			||||||
| 
						 | 
					@ -18,6 +18,7 @@ from typing import Mapping, Optional, Sequence, Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import tensorflow as tf
 | 
					import tensorflow as tf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.vision.object_detector import detection
 | 
				
			||||||
from mediapipe.model_maker.python.vision.object_detector import model_options as model_opt
 | 
					from mediapipe.model_maker.python.vision.object_detector import model_options as model_opt
 | 
				
			||||||
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
 | 
					from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
 | 
				
			||||||
from official.core import config_definitions as cfg
 | 
					from official.core import config_definitions as cfg
 | 
				
			||||||
| 
						 | 
					@ -29,7 +30,6 @@ from official.vision.losses import loss_utils
 | 
				
			||||||
from official.vision.modeling import factory
 | 
					from official.vision.modeling import factory
 | 
				
			||||||
from official.vision.modeling import retinanet_model
 | 
					from official.vision.modeling import retinanet_model
 | 
				
			||||||
from official.vision.modeling.layers import detection_generator
 | 
					from official.vision.modeling.layers import detection_generator
 | 
				
			||||||
from official.vision.serving import detection
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ObjectDetectorModel(tf.keras.Model):
 | 
					class ObjectDetectorModel(tf.keras.Model):
 | 
				
			||||||
| 
						 | 
					@ -199,6 +199,7 @@ class ObjectDetectorModel(tf.keras.Model):
 | 
				
			||||||
            max_detections=10,
 | 
					            max_detections=10,
 | 
				
			||||||
            max_classes_per_detection=1,
 | 
					            max_classes_per_detection=1,
 | 
				
			||||||
            normalize_anchor_coordinates=True,
 | 
					            normalize_anchor_coordinates=True,
 | 
				
			||||||
 | 
					            omit_nms=True,
 | 
				
			||||||
        ),
 | 
					        ),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    tflite_post_processing_config = (
 | 
					    tflite_post_processing_config = (
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -28,6 +28,7 @@ from mediapipe.model_maker.python.vision.object_detector import model_options as
 | 
				
			||||||
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
 | 
					from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
 | 
				
			||||||
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
 | 
					from mediapipe.model_maker.python.vision.object_detector import object_detector_options
 | 
				
			||||||
from mediapipe.model_maker.python.vision.object_detector import preprocessor
 | 
					from mediapipe.model_maker.python.vision.object_detector import preprocessor
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
 | 
				
			||||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
 | 
					from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
 | 
				
			||||||
from mediapipe.tasks.python.metadata.metadata_writers import object_detector as object_detector_writer
 | 
					from mediapipe.tasks.python.metadata.metadata_writers import object_detector as object_detector_writer
 | 
				
			||||||
from official.vision.evaluation import coco_evaluator
 | 
					from official.vision.evaluation import coco_evaluator
 | 
				
			||||||
| 
						 | 
					@ -264,6 +265,27 @@ class ObjectDetector(classifier.Classifier):
 | 
				
			||||||
    coco_metrics = coco_eval.result()
 | 
					    coco_metrics = coco_eval.result()
 | 
				
			||||||
    return losses, coco_metrics
 | 
					    return losses, coco_metrics
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _create_fixed_anchor(
 | 
				
			||||||
 | 
					      self, anchor_box: List[float]
 | 
				
			||||||
 | 
					  ) -> object_detector_writer.FixedAnchor:
 | 
				
			||||||
 | 
					    """Helper function to create FixedAnchor objects from an anchor box array.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      anchor_box: List of anchor box coordinates in the format of [x_min, y_min,
 | 
				
			||||||
 | 
					        x_max, y_max].
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      A FixedAnchor object representing the anchor_box.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    image_shape = self._model_spec.input_image_shape[:2]
 | 
				
			||||||
 | 
					    y_center_norm = (anchor_box[0] + anchor_box[2]) / (2 * image_shape[0])
 | 
				
			||||||
 | 
					    x_center_norm = (anchor_box[1] + anchor_box[3]) / (2 * image_shape[1])
 | 
				
			||||||
 | 
					    height_norm = (anchor_box[2] - anchor_box[0]) / image_shape[0]
 | 
				
			||||||
 | 
					    width_norm = (anchor_box[3] - anchor_box[1]) / image_shape[1]
 | 
				
			||||||
 | 
					    return object_detector_writer.FixedAnchor(
 | 
				
			||||||
 | 
					        x_center_norm, y_center_norm, width_norm, height_norm
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def export_model(
 | 
					  def export_model(
 | 
				
			||||||
      self,
 | 
					      self,
 | 
				
			||||||
      model_name: str = 'model.tflite',
 | 
					      model_name: str = 'model.tflite',
 | 
				
			||||||
| 
						 | 
					@ -328,11 +350,40 @@ class ObjectDetector(classifier.Classifier):
 | 
				
			||||||
      converter.target_spec.supported_ops = (tf.lite.OpsSet.TFLITE_BUILTINS,)
 | 
					      converter.target_spec.supported_ops = (tf.lite.OpsSet.TFLITE_BUILTINS,)
 | 
				
			||||||
      tflite_model = converter.convert()
 | 
					      tflite_model = converter.convert()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    writer = object_detector_writer.MetadataWriter.create_for_models_with_nms(
 | 
					    # Build anchors
 | 
				
			||||||
 | 
					    raw_anchor_boxes = self._preprocessor.anchor_boxes
 | 
				
			||||||
 | 
					    anchors = []
 | 
				
			||||||
 | 
					    for _, anchor_boxes in raw_anchor_boxes.items():
 | 
				
			||||||
 | 
					      anchor_boxes_reshaped = anchor_boxes.numpy().reshape((-1, 4))
 | 
				
			||||||
 | 
					      for ab in anchor_boxes_reshaped:
 | 
				
			||||||
 | 
					        anchors.append(self._create_fixed_anchor(ab))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ssd_anchors_options = object_detector_writer.SsdAnchorsOptions(
 | 
				
			||||||
 | 
					        object_detector_writer.FixedAnchorsSchema(anchors)
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    tensor_decoding_options = object_detector_writer.TensorsDecodingOptions(
 | 
				
			||||||
 | 
					        num_classes=self._num_classes,
 | 
				
			||||||
 | 
					        num_boxes=len(anchors),
 | 
				
			||||||
 | 
					        num_coords=4,
 | 
				
			||||||
 | 
					        keypoint_coord_offset=0,
 | 
				
			||||||
 | 
					        num_keypoints=0,
 | 
				
			||||||
 | 
					        num_values_per_keypoint=2,
 | 
				
			||||||
 | 
					        x_scale=1,
 | 
				
			||||||
 | 
					        y_scale=1,
 | 
				
			||||||
 | 
					        w_scale=1,
 | 
				
			||||||
 | 
					        h_scale=1,
 | 
				
			||||||
 | 
					        apply_exponential_on_box_size=True,
 | 
				
			||||||
 | 
					        sigmoid_score=False,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    writer = object_detector_writer.MetadataWriter.create_for_models_without_nms(
 | 
				
			||||||
        tflite_model,
 | 
					        tflite_model,
 | 
				
			||||||
        self._model_spec.mean_rgb,
 | 
					        self._model_spec.mean_rgb,
 | 
				
			||||||
        self._model_spec.stddev_rgb,
 | 
					        self._model_spec.stddev_rgb,
 | 
				
			||||||
        labels=metadata_writer.Labels().add(list(self._label_names)),
 | 
					        labels=metadata_writer.Labels().add(list(self._label_names)),
 | 
				
			||||||
 | 
					        ssd_anchors_options=ssd_anchors_options,
 | 
				
			||||||
 | 
					        tensors_decoding_options=tensor_decoding_options,
 | 
				
			||||||
 | 
					        output_tensors_order=metadata_info.RawDetectionOutputTensorsOrder.LOCATION_SCORE,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    tflite_model_with_metadata, metadata_json = writer.populate()
 | 
					    tflite_model_with_metadata, metadata_json = writer.populate()
 | 
				
			||||||
    model_util.save_tflite(tflite_model_with_metadata, tflite_file)
 | 
					    model_util.save_tflite(tflite_model_with_metadata, tflite_file)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -44,6 +44,26 @@ class Preprocessor(object):
 | 
				
			||||||
    self._aug_scale_max = 2.0
 | 
					    self._aug_scale_max = 2.0
 | 
				
			||||||
    self._max_num_instances = 100
 | 
					    self._max_num_instances = 100
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    self._padded_size = preprocess_ops.compute_padded_size(
 | 
				
			||||||
 | 
					        self._output_size, 2**self._max_level
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    input_anchor = anchor.build_anchor_generator(
 | 
				
			||||||
 | 
					        min_level=self._min_level,
 | 
				
			||||||
 | 
					        max_level=self._max_level,
 | 
				
			||||||
 | 
					        num_scales=self._num_scales,
 | 
				
			||||||
 | 
					        aspect_ratios=self._aspect_ratios,
 | 
				
			||||||
 | 
					        anchor_size=self._anchor_size,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    self._anchor_boxes = input_anchor(image_size=self._output_size)
 | 
				
			||||||
 | 
					    self._anchor_labeler = anchor.AnchorLabeler(
 | 
				
			||||||
 | 
					        self._match_threshold, self._unmatched_threshold
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @property
 | 
				
			||||||
 | 
					  def anchor_boxes(self):
 | 
				
			||||||
 | 
					    return self._anchor_boxes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def __call__(
 | 
					  def __call__(
 | 
				
			||||||
      self, data: Mapping[str, Any], is_training: bool = True
 | 
					      self, data: Mapping[str, Any], is_training: bool = True
 | 
				
			||||||
  ) -> Tuple[tf.Tensor, Mapping[str, Any]]:
 | 
					  ) -> Tuple[tf.Tensor, Mapping[str, Any]]:
 | 
				
			||||||
| 
						 | 
					@ -90,13 +110,10 @@ class Preprocessor(object):
 | 
				
			||||||
    image, image_info = preprocess_ops.resize_and_crop_image(
 | 
					    image, image_info = preprocess_ops.resize_and_crop_image(
 | 
				
			||||||
        image,
 | 
					        image,
 | 
				
			||||||
        self._output_size,
 | 
					        self._output_size,
 | 
				
			||||||
        padded_size=preprocess_ops.compute_padded_size(
 | 
					        padded_size=self._padded_size,
 | 
				
			||||||
            self._output_size, 2**self._max_level
 | 
					 | 
				
			||||||
        ),
 | 
					 | 
				
			||||||
        aug_scale_min=(self._aug_scale_min if is_training else 1.0),
 | 
					        aug_scale_min=(self._aug_scale_min if is_training else 1.0),
 | 
				
			||||||
        aug_scale_max=(self._aug_scale_max if is_training else 1.0),
 | 
					        aug_scale_max=(self._aug_scale_max if is_training else 1.0),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    image_height, image_width, _ = image.get_shape().as_list()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Resize and crop boxes.
 | 
					    # Resize and crop boxes.
 | 
				
			||||||
    image_scale = image_info[2, :]
 | 
					    image_scale = image_info[2, :]
 | 
				
			||||||
| 
						 | 
					@ -110,20 +127,9 @@ class Preprocessor(object):
 | 
				
			||||||
    classes = tf.gather(classes, indices)
 | 
					    classes = tf.gather(classes, indices)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Assign anchors.
 | 
					    # Assign anchors.
 | 
				
			||||||
    input_anchor = anchor.build_anchor_generator(
 | 
					 | 
				
			||||||
        min_level=self._min_level,
 | 
					 | 
				
			||||||
        max_level=self._max_level,
 | 
					 | 
				
			||||||
        num_scales=self._num_scales,
 | 
					 | 
				
			||||||
        aspect_ratios=self._aspect_ratios,
 | 
					 | 
				
			||||||
        anchor_size=self._anchor_size,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    anchor_boxes = input_anchor(image_size=(image_height, image_width))
 | 
					 | 
				
			||||||
    anchor_labeler = anchor.AnchorLabeler(
 | 
					 | 
				
			||||||
        self._match_threshold, self._unmatched_threshold
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    (cls_targets, box_targets, _, cls_weights, box_weights) = (
 | 
					    (cls_targets, box_targets, _, cls_weights, box_weights) = (
 | 
				
			||||||
        anchor_labeler.label_anchors(
 | 
					        self._anchor_labeler.label_anchors(
 | 
				
			||||||
            anchor_boxes, boxes, tf.expand_dims(classes, axis=1)
 | 
					            self.anchor_boxes, boxes, tf.expand_dims(classes, axis=1)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -134,7 +140,7 @@ class Preprocessor(object):
 | 
				
			||||||
    labels = {
 | 
					    labels = {
 | 
				
			||||||
        'cls_targets': cls_targets,
 | 
					        'cls_targets': cls_targets,
 | 
				
			||||||
        'box_targets': box_targets,
 | 
					        'box_targets': box_targets,
 | 
				
			||||||
        'anchor_boxes': anchor_boxes,
 | 
					        'anchor_boxes': self.anchor_boxes,
 | 
				
			||||||
        'cls_weights': cls_weights,
 | 
					        'cls_weights': cls_weights,
 | 
				
			||||||
        'box_weights': box_weights,
 | 
					        'box_weights': box_weights,
 | 
				
			||||||
        'image_info': image_info,
 | 
					        'image_info': image_info,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user