Open Source Object Detector
PiperOrigin-RevId: 519201221

@@ -17,6 +17,7 @@ from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.vision import image_classifier
from mediapipe.model_maker.python.vision import gesture_recognizer
from mediapipe.model_maker.python.text import text_classifier
from mediapipe.model_maker.python.vision import object_detector

# Remove duplicated and non-public API
del python
			
		|||
							
								
								
									
										195
									
								
								mediapipe/model_maker/python/vision/object_detector/BUILD
									
									
									
									
									
										Normal file
									
								
							
							
						
						| 
						 | 
				
			
			@ -0,0 +1,195 @@
 | 
			
		|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
# Placeholder for internal Python strict library and test compatibility macro.
 | 
			
		||||
# Placeholder for internal Python strict test compatibility macro.
 | 
			
		||||
 | 
			
		||||
licenses(["notice"])
 | 
			
		||||
 | 
			
		||||
package(
 | 
			
		||||
    default_visibility = ["//mediapipe:__subpackages__"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "object_detector_import",
 | 
			
		||||
    srcs = ["__init__.py"],
 | 
			
		||||
    visibility = ["//visibility:public"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":dataset",
 | 
			
		||||
        ":hyperparameters",
 | 
			
		||||
        ":model_options",
 | 
			
		||||
        ":model_spec",
 | 
			
		||||
        ":object_detector",
 | 
			
		||||
        ":object_detector_options",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_binary(
 | 
			
		||||
    name = "object_detector_demo",
 | 
			
		||||
    srcs = ["object_detector_demo.py"],
 | 
			
		||||
    data = [":testdata"],
 | 
			
		||||
    python_version = "PY3",
 | 
			
		||||
    tags = ["requires-net:external"],
 | 
			
		||||
    deps = [":object_detector_import"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
filegroup(
 | 
			
		||||
    name = "testdata",
 | 
			
		||||
    srcs = glob([
 | 
			
		||||
        "testdata/**",
 | 
			
		||||
    ]),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "dataset",
 | 
			
		||||
    srcs = ["dataset.py"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":dataset_util",
 | 
			
		||||
        "//mediapipe/model_maker/python/core/data:classification_dataset",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
    name = "dataset_test",
 | 
			
		||||
    srcs = ["dataset_test.py"],
 | 
			
		||||
    data = [":testdata"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":dataset",
 | 
			
		||||
        "//mediapipe/model_maker/python/vision/core:image_utils",
 | 
			
		||||
        "//mediapipe/model_maker/python/vision/core:test_utils",
 | 
			
		||||
        "//mediapipe/tasks/python/test:test_utils",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "dataset_util",
 | 
			
		||||
    srcs = ["dataset_util.py"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
    name = "dataset_util_test",
 | 
			
		||||
    srcs = ["dataset_util_test.py"],
 | 
			
		||||
    data = [":testdata"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":dataset_util",
 | 
			
		||||
        "//mediapipe/model_maker/python/vision/core:test_utils",
 | 
			
		||||
        "//mediapipe/tasks/python/test:test_utils",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "hyperparameters",
 | 
			
		||||
    srcs = ["hyperparameters.py"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/model_maker/python/core:hyperparameters",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "preprocessor",
 | 
			
		||||
    srcs = ["preprocessor.py"],
 | 
			
		||||
    deps = [":model_spec"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
    name = "preprocessor_test",
 | 
			
		||||
    srcs = ["preprocessor_test.py"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":model_spec",
 | 
			
		||||
        ":preprocessor",
 | 
			
		||||
        "//mediapipe/model_maker/python/vision/core:test_utils",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "model",
 | 
			
		||||
    srcs = ["model.py"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":model_options",
 | 
			
		||||
        ":model_spec",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
    name = "model_test",
 | 
			
		||||
    size = "large",
 | 
			
		||||
    srcs = ["model_test.py"],
 | 
			
		||||
    data = [":testdata"],
 | 
			
		||||
    shard_count = 4,
 | 
			
		||||
    tags = ["requires-net:external"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":dataset",
 | 
			
		||||
        ":model",
 | 
			
		||||
        ":model_options",
 | 
			
		||||
        ":model_spec",
 | 
			
		||||
        ":preprocessor",
 | 
			
		||||
        "//mediapipe/tasks/python/test:test_utils",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "model_options",
 | 
			
		||||
    srcs = ["model_options.py"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "model_spec",
 | 
			
		||||
    srcs = ["model_spec.py"],
 | 
			
		||||
    deps = ["//mediapipe/model_maker/python/core/utils:file_util"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "object_detector",
 | 
			
		||||
    srcs = ["object_detector.py"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":dataset",
 | 
			
		||||
        ":hyperparameters",
 | 
			
		||||
        ":model",
 | 
			
		||||
        ":model_options",
 | 
			
		||||
        ":model_spec",
 | 
			
		||||
        ":object_detector_options",
 | 
			
		||||
        ":preprocessor",
 | 
			
		||||
        "//mediapipe/model_maker/python/core/tasks:classifier",
 | 
			
		||||
        "//mediapipe/model_maker/python/core/utils:model_util",
 | 
			
		||||
        "//mediapipe/model_maker/python/core/utils:quantization",
 | 
			
		||||
        "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
 | 
			
		||||
        "//mediapipe/tasks/python/metadata/metadata_writers:object_detector",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
    name = "object_detector_test",
 | 
			
		||||
    size = "enormous",
 | 
			
		||||
    srcs = ["object_detector_test.py"],
 | 
			
		||||
    data = [":testdata"],
 | 
			
		||||
    tags = ["requires-net:external"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":dataset",
 | 
			
		||||
        ":hyperparameters",
 | 
			
		||||
        ":model_spec",
 | 
			
		||||
        ":object_detector",
 | 
			
		||||
        ":object_detector_options",
 | 
			
		||||
        "//mediapipe/tasks/python/test:test_utils",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "object_detector_options",
 | 
			
		||||
    srcs = ["object_detector_options.py"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":hyperparameters",
 | 
			
		||||
        ":model_options",
 | 
			
		||||
        ":model_spec",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||

mediapipe/model_maker/python/vision/object_detector/__init__.py (new file, 30 lines)
@@ -0,0 +1,30 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe Model Maker Python Public API For Object Detector."""

from mediapipe.model_maker.python.vision.object_detector import dataset
from mediapipe.model_maker.python.vision.object_detector import hyperparameters
from mediapipe.model_maker.python.vision.object_detector import model_options
from mediapipe.model_maker.python.vision.object_detector import model_spec
from mediapipe.model_maker.python.vision.object_detector import object_detector
from mediapipe.model_maker.python.vision.object_detector import object_detector_options

ObjectDetector = object_detector.ObjectDetector
ModelOptions = model_options.ObjectDetectorModelOptions
ModelSpec = model_spec.ModelSpec
SupportedModels = model_spec.SupportedModels
HParams = hyperparameters.HParams
QATHParams = hyperparameters.QATHParams
Dataset = dataset.Dataset
ObjectDetectorOptions = object_detector_options.ObjectDetectorOptions
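
The names re-exported above compose into a typical training flow. The sketch below is illustrative only and is not part of this change: the `ObjectDetector.create` signature, the `SupportedModels` member, the `HParams` fields, and `export_model()` are assumptions; only the re-exported names and `Dataset.from_coco_folder` come from the files in this diff.

# Illustrative usage sketch (not part of this diff). The create() signature,
# SupportedModels member, HParams fields, and export_model() are assumptions;
# only the re-exported names above and Dataset.from_coco_folder are from the diff.
from mediapipe.model_maker.python.vision import object_detector

train_data = object_detector.Dataset.from_coco_folder('/path/to/train')
validation_data = object_detector.Dataset.from_coco_folder('/path/to/validation')

options = object_detector.ObjectDetectorOptions(
    supported_model=object_detector.SupportedModels.MOBILENET_V2,  # assumed member
    hparams=object_detector.HParams(export_dir='exported_model'),  # assumed field
)
model = object_detector.ObjectDetector.create(  # assumed factory signature
    train_data=train_data,
    validation_data=validation_data,
    options=options,
)
model.export_model()  # assumed export helper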

mediapipe/model_maker/python/vision/object_detector/dataset.py (new file, 179 lines)
@@ -0,0 +1,179 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Object detector dataset library."""

from typing import Optional

import tensorflow as tf
import yaml

from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.vision.object_detector import dataset_util
from official.vision.dataloaders import tf_example_decoder


class Dataset(classification_dataset.ClassificationDataset):
  """Dataset library for object detector."""

  @classmethod
  def from_coco_folder(
      cls,
      data_dir: str,
      max_num_images: Optional[int] = None,
      cache_dir: Optional[str] = None,
  ) -> 'Dataset':
    """Loads images and labels from the given directory in COCO format.

    - https://cocodataset.org/#home

    Folder structure should be:
      <data_dir>/
        images/
          <file0>.jpg
          ...
        labels.json

    The `labels.json` annotations file should have the following format:
    {
        "categories": [{"id": 0, "name": "background"}, ...],
        "images": [{"id": 0, "file_name": "<file0>.jpg"}, ...],
        "annotations": [{
           "id": 0,
           "image_id": 0,
           "category_id": 2,
           "bbox": [x-top left, y-top left, width, height],
           }, ...]
    }
    Note that category id 0 is reserved for the "background" class. It is
    optional to include, but if included it must be set to "background".

    Args:
      data_dir: Name of the directory containing the data files.
      max_num_images: Max number of images to process.
      cache_dir: The cache directory to save TFRecord and metadata files. The
        TFRecord files are a standardized format for training object detection
        while the metadata file is used to store information like dataset size
        and label mapping of id to label name. If the cache_dir is not set, a
        temporary folder will be created and will not be removed automatically
        after training which means it can be reused later.

    Returns:
      Dataset containing images and labels and other related info.
    Raises:
      ValueError: If the input data directory is empty.
      ValueError: If the label_name for id 0 is set to something other than
        the 'background' class.
    """
    cache_files = dataset_util.get_cache_files_coco(data_dir, cache_dir)
    if not dataset_util.is_cached(cache_files):
      label_map = dataset_util.get_label_map_coco(data_dir)
      cache_writer = dataset_util.COCOCacheFilesWriter(
          label_map=label_map, max_num_images=max_num_images
      )
      cache_writer.write_files(cache_files, data_dir)
    return cls.from_cache(cache_files.cache_prefix)

  @classmethod
  def from_pascal_voc_folder(
      cls,
      data_dir: str,
      max_num_images: Optional[int] = None,
      cache_dir: Optional[str] = None,
  ) -> 'Dataset':
    """Loads images and labels from the given directory in PASCAL VOC format.

    - http://host.robots.ox.ac.uk/pascal/VOC.

    Folder structure should be:
      <data_dir>/
        images/
          <file0>.jpg
          ...
        Annotations/
          <file0>.xml
          ...
    Each <file0>.xml annotation file should have the following format:
      <annotation>
        <filename>file0.jpg</filename>
        <object>
          <name>kangaroo</name>
          <bndbox>
            <xmin>233</xmin>
            <ymin>89</ymin>
            <xmax>386</xmax>
            <ymax>262</ymax>
          </bndbox>
        </object>
        <object>...</object>
      </annotation>

    Args:
      data_dir: Name of the directory containing the data files.
      max_num_images: Max number of images to process.
      cache_dir: The cache directory to save TFRecord and metadata files. The
        TFRecord files are a standardized format for training object detection
        while the metadata file is used to store information like dataset size
        and label mapping of id to label name. If the cache_dir is not set, a
        temporary folder will be created and will not be removed automatically
        after training which means it can be reused later.

    Returns:
      Dataset containing images and labels and other related info.
    Raises:
      ValueError: if the input data directory is empty.
    """
    cache_files = dataset_util.get_cache_files_pascal_voc(data_dir, cache_dir)
    if not dataset_util.is_cached(cache_files):
      label_map = dataset_util.get_label_map_pascal_voc(data_dir)
      cache_writer = dataset_util.PascalVocCacheFilesWriter(
          label_map=label_map, max_num_images=max_num_images
      )
      cache_writer.write_files(cache_files, data_dir)

    return cls.from_cache(cache_files.cache_prefix)

  @classmethod
  def from_cache(cls, cache_prefix: str) -> 'Dataset':
    """Loads the TFRecord data from cache.

    Args:
      cache_prefix: The cache prefix including the cache directory and the cache
        prefix filename, e.g: '/tmp/cache/train'.

    Returns:
      ObjectDetectorDataset object.
    """
    # Get TFRecord Files
    tfrecord_file_patten = cache_prefix + '*.tfrecord'
    matched_files = tf.io.gfile.glob(tfrecord_file_patten)
    if not matched_files:
      raise ValueError('TFRecord files are empty.')

    # Load meta_data.
    meta_data_file = cache_prefix + dataset_util.META_DATA_FILE_SUFFIX
    if not tf.io.gfile.exists(meta_data_file):
      raise ValueError("Metadata file %s doesn't exist." % meta_data_file)
    with tf.io.gfile.GFile(meta_data_file, 'r') as f:
      meta_data = yaml.load(f, Loader=yaml.FullLoader)

    dataset = tf.data.TFRecordDataset(matched_files)
    decoder = tf_example_decoder.TfExampleDecoder(regenerate_source_id=False)
    dataset = dataset.map(decoder.decode, num_parallel_calls=tf.data.AUTOTUNE)

    label_map = meta_data['label_map']
    label_names = [label_map[k] for k in sorted(label_map.keys())]

    return Dataset(
        dataset=dataset, size=meta_data['size'], label_names=label_names
    )
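
For reference, the cache that `from_coco_folder` and `from_pascal_voc_folder` write, and that `from_cache` reads back, consists of sharded TFRecord files plus a YAML metadata file under a hashed prefix produced by `dataset_util` (shown further below). A minimal sketch of reusing an existing cache, with a hypothetical prefix:

# Illustrative sketch (not part of this diff); the '<md5-hash>' prefix is hypothetical.
# On-disk layout written by the cache writers:
#   /tmp/cache/<md5-hash>-00000-of-00002.tfrecord
#   /tmp/cache/<md5-hash>-00001-of-00002.tfrecord
#   /tmp/cache/<md5-hash>_meta_data.yaml   # {'size': ..., 'label_map': {0: 'background', ...}}
from mediapipe.model_maker.python.vision.object_detector import dataset

data = dataset.Dataset.from_cache('/tmp/cache/<md5-hash>')  # hypothetical prefix
print(len(data), data.label_names)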

mediapipe/model_maker/python/vision/object_detector/dataset_test.py (new file, 119 lines)
@@ -0,0 +1,119 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import random
import tensorflow as tf

from mediapipe.model_maker.python.vision.core import image_utils
from mediapipe.model_maker.python.vision.core import test_utils
from mediapipe.model_maker.python.vision.object_detector import dataset
from mediapipe.tasks.python.test import test_utils as tasks_test_utils

IMAGE_SIZE = 224


class DatasetTest(tf.test.TestCase):

  def _get_rand_bbox(self):
    x1, x2 = random.uniform(0, IMAGE_SIZE), random.uniform(0, IMAGE_SIZE)
    y1, y2 = random.uniform(0, IMAGE_SIZE), random.uniform(0, IMAGE_SIZE)
    return [min(x1, x2), min(y1, y2), abs(x1 - x2), abs(y1 - y2)]

  def setUp(self):
    super().setUp()
    self.coco_dataset_path = os.path.join(self.get_temp_dir(), 'coco_dataset')
    if os.path.exists(self.coco_dataset_path):
      return
    os.mkdir(self.coco_dataset_path)
    categories = [{'id': 1, 'name': 'daisy'}, {'id': 2, 'name': 'tulips'}]
    images = [
        {'id': 1, 'file_name': 'img1.jpeg'},
        {'id': 2, 'file_name': 'img2.jpeg'},
    ]
    annotations = [
        {'image_id': 1, 'category_id': 1, 'bbox': self._get_rand_bbox()},
        {'image_id': 2, 'category_id': 1, 'bbox': self._get_rand_bbox()},
        {'image_id': 2, 'category_id': 2, 'bbox': self._get_rand_bbox()},
    ]
    labels_dict = {
        'categories': categories,
        'images': images,
        'annotations': annotations,
    }
    labels_json = json.dumps(labels_dict)
    with open(os.path.join(self.coco_dataset_path, 'labels.json'), 'w') as f:
      f.write(labels_json)
    images_dir = os.path.join(self.coco_dataset_path, 'images')
    os.mkdir(images_dir)
    for item in images:
      test_utils.write_filled_jpeg_file(
          os.path.join(images_dir, item['file_name']),
          [random.uniform(0, 255) for _ in range(3)],
          IMAGE_SIZE,
      )

  def test_from_coco_folder(self):
    data = dataset.Dataset.from_coco_folder(
        self.coco_dataset_path, cache_dir=self.get_temp_dir()
    )
    self.assertLen(data, 2)
    self.assertEqual(data.num_classes, 3)
    self.assertEqual(data.label_names, ['background', 'daisy', 'tulips'])
    for example in data.gen_tf_dataset():
      boxes = example['groundtruth_boxes']
      classes = example['groundtruth_classes']
      self.assertNotEmpty(boxes)
      self.assertAllLessEqual(boxes, 1)
      self.assertAllGreaterEqual(boxes, 0)
      self.assertNotEmpty(classes)
      self.assertTrue(
          (classes.numpy() == [1]).all() or (classes.numpy() == [1, 2]).all()
      )
      if (classes.numpy() == [1, 1]).all():
        raw_image_tensor = image_utils.load_image(
            os.path.join(self.coco_dataset_path, 'images', 'img1.jpeg')
        )
      else:
        raw_image_tensor = image_utils.load_image(
            os.path.join(self.coco_dataset_path, 'images', 'img2.jpeg')
        )
      self.assertTrue(
          (example['image'].numpy() == raw_image_tensor.numpy()).all()
      )

  def test_from_pascal_voc_folder(self):
    pascal_voc_folder = tasks_test_utils.get_test_data_path('pascal_voc_data')
    data = dataset.Dataset.from_pascal_voc_folder(
        pascal_voc_folder, cache_dir=self.get_temp_dir()
    )
    self.assertLen(data, 4)
    self.assertEqual(data.num_classes, 3)
    self.assertEqual(data.label_names, ['background', 'android', 'pig_android'])
    for example in data.gen_tf_dataset():
      boxes = example['groundtruth_boxes']
      classes = example['groundtruth_classes']
      self.assertNotEmpty(boxes)
      self.assertAllLessEqual(boxes, 1)
      self.assertAllGreaterEqual(boxes, 0)
      self.assertNotEmpty(classes)
      image = example['image']
      self.assertNotEmpty(image)
      self.assertAllGreaterEqual(image, 0)
      self.assertAllLessEqual(image, 255)


if __name__ == '__main__':
  tf.test.main()

mediapipe/model_maker/python/vision/object_detector/dataset_util.py (new file, 484 lines)
@@ -0,0 +1,484 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for Object Detector Dataset Library."""

import abc
import collections
import dataclasses
import hashlib
import json
import math
import os
import tempfile
from typing import Any, Dict, List, Mapping, Optional, Sequence
import xml.etree.ElementTree as ET

import tensorflow as tf
import yaml

from official.vision.data import tfrecord_lib


# Suffix of the meta data file name.
META_DATA_FILE_SUFFIX = '_meta_data.yaml'


def _xml_get(node: ET.Element, name: str) -> ET.Element:
  """Gets a named child from an XML Element node.

  This method is used to retrieve an XML element that is expected to exist as a
  subelement of the `node` passed into this method. If the subelement is not
  found, then an error is thrown.

  Raises:
    ValueError: If the subelement is not found.

  Args:
    node: XML Element Tree node.
    name: Name of the child node to get.

  Returns:
    A child node of the parameter node with the matching name.
  """
  result = node.find(name)
  if result is None:
    raise ValueError(f'Unexpected xml format: {name} not found in {node}')
  return result


def _get_cache_dir_or_create(cache_dir: Optional[str]) -> str:
  """Gets the cache directory or creates it if it does not exist."""
  if cache_dir is None:
    cache_dir = tempfile.mkdtemp()
  if not tf.io.gfile.exists(cache_dir):
    tf.io.gfile.makedirs(cache_dir)
  return cache_dir


def _get_dir_basename(data_dir: str) -> str:
  """Gets the base name of the directory."""
  return os.path.basename(os.path.abspath(data_dir))


@dataclasses.dataclass(frozen=True)
class CacheFiles:
  """Cache files for object detection."""

  cache_prefix: str
  tfrecord_files: Sequence[str]
  meta_data_file: str


def _get_cache_files(
    cache_dir: Optional[str], cache_prefix_filename: str, num_shards: int = 10
) -> CacheFiles:
  """Creates an object of CacheFiles class.

  Args:
    cache_dir: The cache directory to save TFRecord and metadata file. When
      cache_dir is None, a temporary folder will be created and will not be
      removed automatically after training, which means it can be reused later.
    cache_prefix_filename: The cache prefix filename.
    num_shards: Number of shards for output file.

  Returns:
    An object of CacheFiles class.
  """
  cache_dir = _get_cache_dir_or_create(cache_dir)
  # The cache prefix including the cache directory and the cache prefix
  # filename, e.g: '/tmp/cache/train'.
  cache_prefix = os.path.join(cache_dir, cache_prefix_filename)
  tf.compat.v1.logging.info(
      'Cache will be stored in %s with prefix filename %s. Cache_prefix is %s'
      % (cache_dir, cache_prefix_filename, cache_prefix)
  )

  # Cached files including the TFRecord files and the meta data file.
  tfrecord_files = [
      cache_prefix + '-%05d-of-%05d.tfrecord' % (i, num_shards)
      for i in range(num_shards)
  ]
  meta_data_file = cache_prefix + META_DATA_FILE_SUFFIX
  return CacheFiles(
      cache_prefix=cache_prefix,
      tfrecord_files=tuple(tfrecord_files),
      meta_data_file=meta_data_file,
  )


def get_cache_files_coco(data_dir: str, cache_dir: str) -> CacheFiles:
  """Creates an object of CacheFiles class using a COCO formatted dataset.

  Args:
    data_dir: Folder path of the COCO dataset.
    cache_dir: Folder path of the cache location. When cache_dir is None, a
      temporary folder will be created and will not be removed automatically
      after training, which means it can be reused later.

  Returns:
    An object of CacheFiles class.
  """
  hasher = hashlib.md5()
  # Update with dataset folder name
  hasher.update(_get_dir_basename(data_dir).encode('utf-8'))
  # Update with image filenames
  for image_file in sorted(os.listdir(os.path.join(data_dir, 'images'))):
    hasher.update(os.path.basename(image_file).encode('utf-8'))
  # Update with labels.json file content
  label_file = os.path.join(data_dir, 'labels.json')
  with open(label_file, 'r') as f:
    label_json = json.load(f)
    hasher.update(str(label_json).encode('utf-8'))
    num_examples = len(label_json['images'])
  # Num_shards automatically set to 100 images per shard, up to 10 shards total.
  # See https://www.tensorflow.org/tutorials/load_data/tfrecord for more info
  # on sharding.
  num_shards = min(math.ceil(num_examples / 100), 10)
  # Update with num shards
  hasher.update(str(num_shards).encode('utf-8'))
  cache_prefix_filename = hasher.hexdigest()

  return _get_cache_files(cache_dir, cache_prefix_filename, num_shards)


def get_cache_files_pascal_voc(data_dir: str, cache_dir: str) -> CacheFiles:
  """Gets an object of CacheFiles using a PASCAL VOC formatted dataset.

  Args:
    data_dir: Folder path of the PASCAL VOC dataset.
    cache_dir: Folder path of the cache location. When cache_dir is None, a
      temporary folder will be created and will not be removed automatically
      after training, which means it can be reused later.

  Returns:
    An object of CacheFiles class.
  """
  hasher = hashlib.md5()
  hasher.update(_get_dir_basename(data_dir).encode('utf-8'))
  annotation_files = tf.io.gfile.glob(
      os.path.join(data_dir, 'Annotations') + r'/*.xml'
  )
  annotation_filenames = [
      os.path.basename(ann_file) for ann_file in annotation_files
  ]
  hasher.update(' '.join(annotation_filenames).encode('utf-8'))
  num_examples = len(annotation_filenames)
  num_shards = min(math.ceil(num_examples / 100), 10)
  hasher.update(str(num_shards).encode('utf-8'))
  cache_prefix_filename = hasher.hexdigest()

  return _get_cache_files(cache_dir, cache_prefix_filename, num_shards)


def is_cached(cache_files: CacheFiles) -> bool:
  """Checks whether cache files are already cached."""
  all_cached_files = list(cache_files.tfrecord_files) + [
      cache_files.meta_data_file
  ]
  return all(tf.io.gfile.exists(path) for path in all_cached_files)


class CacheFilesWriter(abc.ABC):
  """CacheFilesWriter class to write the cached files."""

  def __init__(
      self, label_map: Dict[int, str], max_num_images: Optional[int] = None
  ) -> None:
    """Initializes CacheFilesWriter for object detector.

    Args:
      label_map: Dict, map label integer ids to string label names such as {1:
        'person', 2: 'notperson'}. 0 is the reserved key for `background` and
        doesn't need to be included in `label_map`. Label names can't be
        duplicated.
      max_num_images: Max number of images to process. If None, process all the
        images.
    """
    self.label_map = label_map
    self.max_num_images = max_num_images

  def write_files(self, cache_files: CacheFiles, *args, **kwargs) -> None:
    """Writes TFRecord and meta_data files.

    Args:
      cache_files: CacheFiles object including a list of TFRecord files and the
        meta data yaml file to save the meta_data including data size and
        label_map.
      *args: Non-keyword parameters used in the `_get_example` method.
      **kwargs: Keyword parameters used in the `_get_example` method.
    """
    writers = [
        tf.io.TFRecordWriter(path) for path in cache_files.tfrecord_files
    ]

    # Writes tf.Example into TFRecord files.
    size = 0
    for idx, tf_example in enumerate(self._get_example(*args, **kwargs)):
      if self.max_num_images and idx >= self.max_num_images:
        break
      if idx % 100 == 0:
        tf.compat.v1.logging.info('On image %d' % idx)
      writers[idx % len(writers)].write(tf_example.SerializeToString())
      size = idx + 1

    for writer in writers:
      writer.close()

    # Writes meta_data into meta_data_file.
    meta_data = {'size': size, 'label_map': self.label_map}
    with tf.io.gfile.GFile(cache_files.meta_data_file, 'w') as f:
      yaml.dump(meta_data, f)

  @abc.abstractmethod
  def _get_example(self, *args, **kwargs):
    raise NotImplementedError


def get_label_map_coco(data_dir: str):
  """Gets the label map from a COCO formatted dataset directory.

  Note that id 0 is reserved for the background class. If id=0 is set, it needs
  to be set to "background". It is optional to include id=0 if it is unused, and
  it will be automatically added by this method.

  Args:
    data_dir: Path of the dataset directory

  Returns:
    label_map dictionary of the format {<id>:<label_name>}

  Raises:
    ValueError: If the label_name for id 0 is set to something other than
    the "background" class.
  """
  data_dir = os.path.abspath(data_dir)
  # Process labels.json file
  label_file = os.path.join(data_dir, 'labels.json')
  with open(label_file, 'r') as f:
    data = json.load(f)

  # Categories
  label_map = {}
  for category in data['categories']:
    label_map[int(category['id'])] = category['name']

  if 0 in label_map and label_map[0] != 'background':
    raise ValueError(
        (
            'Label index 0 is reserved for the background class, but '
            f'it was found to be {label_map[0]}'
        ),
    )
  if 0 not in label_map:
    label_map[0] = 'background'

  return label_map


def get_label_map_pascal_voc(data_dir: str):
  """Gets the label map from a PASCAL VOC formatted dataset directory.

  The id to label_name mapping is determined by sorting all label_names and
  numbering them starting from 1. Id=0 is set as the 'background' class.

  Args:
    data_dir: Path of the dataset directory

  Returns:
    label_map dictionary of the format {<id>:<label_name>}
  """
  data_dir = os.path.abspath(data_dir)
  all_label_names = set()
  annotations_dir = os.path.join(data_dir, 'Annotations')
  all_annotation_files = tf.io.gfile.glob(annotations_dir + r'/*.xml')
  for ann_file in all_annotation_files:
    tree = ET.parse(ann_file)
    root = tree.getroot()
    for child in root.iter('object'):
      label_name = _xml_get(child, 'name').text
      all_label_names.add(label_name)
  label_map = {0: 'background'}
  for ind, label_name in enumerate(sorted(all_label_names)):
    label_map[ind + 1] = label_name
  return label_map


def _bbox_data_to_feature_dict(data):
  """Converts a dictionary of bbox annotations to a feature dictionary.

  Args:
    data: Dict with keys 'xmin', 'xmax', 'ymin', 'ymax', 'category_id'

  Returns:
    Feature dictionary
  """
  bbox_feature_dict = {
      'image/object/bbox/xmin': tfrecord_lib.convert_to_feature(data['xmin']),
      'image/object/bbox/xmax': tfrecord_lib.convert_to_feature(data['xmax']),
      'image/object/bbox/ymin': tfrecord_lib.convert_to_feature(data['ymin']),
      'image/object/bbox/ymax': tfrecord_lib.convert_to_feature(data['ymax']),
      'image/object/class/label': tfrecord_lib.convert_to_feature(
          data['category_id']
      ),
  }
  return bbox_feature_dict


def _coco_annotations_to_lists(
    bbox_annotations: List[Mapping[str, Any]],
    image_height: int,
    image_width: int,
):
  """Converts COCO annotations to feature lists.

  Args:
    bbox_annotations: List of dicts with keys ['bbox', 'category_id']
    image_height: Height of image
    image_width: Width of image

  Returns:
    (data, num_annotations_skipped) tuple where data contains the keys:
    ['xmin', 'xmax', 'ymin', 'ymax', 'category_id'] and
    num_annotations_skipped is the number of annotations skipped because the
    bbox has zero area or falls outside the image.
  """

  data = collections.defaultdict(list)

  num_annotations_skipped = 0

  for object_annotations in bbox_annotations:
    (x, y, width, height) = tuple(object_annotations['bbox'])

    if width <= 0 or height <= 0:
      num_annotations_skipped += 1
      continue
    if x + width > image_width or y + height > image_height:
      num_annotations_skipped += 1
      continue
    data['xmin'].append(float(x) / image_width)
    data['xmax'].append(float(x + width) / image_width)
    data['ymin'].append(float(y) / image_height)
    data['ymax'].append(float(y + height) / image_height)
    category_id = int(object_annotations['category_id'])
    data['category_id'].append(category_id)

  return data, num_annotations_skipped


class COCOCacheFilesWriter(CacheFilesWriter):
  """CacheFilesWriter class to write the cached files for COCO data."""

  def _get_example(self, data_dir: str) -> tf.train.Example:
    """Iterates over all examples in the COCO formatted dataset directory.

    Args:
      data_dir: Path of the dataset directory

    Yields:
      tf.train.Example
    """
    data_dir = os.path.abspath(data_dir)
    # Process labels.json file
    label_file = os.path.join(data_dir, 'labels.json')
    with open(label_file, 'r') as f:
      data = json.load(f)

    # Load all Annotations
    img_to_annotations = collections.defaultdict(list)
    for annotation in data['annotations']:
      image_id = annotation['image_id']
      img_to_annotations[image_id].append(annotation)

    # For each Image:
    for image in data['images']:
      img_id = image['id']
      file_name = image['file_name']
      full_path = os.path.join(data_dir, 'images', file_name)
      with tf.io.gfile.GFile(full_path, 'rb') as fid:
        encoded_jpg = fid.read()
      image = tf.io.decode_jpeg(encoded_jpg, channels=3)
      height, width, _ = image.shape
      feature_dict = tfrecord_lib.image_info_to_feature_dict(
          height, width, file_name, img_id, encoded_jpg, 'jpg'
      )
      data, _ = _coco_annotations_to_lists(
          img_to_annotations[img_id], height, width
      )
      if not data['xmin']:
        # Skip examples which have no annotations
        continue
      bbox_feature_dict = _bbox_data_to_feature_dict(data)
      feature_dict.update(bbox_feature_dict)
      example = tf.train.Example(
          features=tf.train.Features(feature=feature_dict)
      )
      yield example


class PascalVocCacheFilesWriter(CacheFilesWriter):
  """CacheFilesWriter class to write the cached files for PASCAL VOC data."""

  def _get_example(self, data_dir: str) -> tf.train.Example:
    """Iterates over all examples in the PASCAL VOC formatted dataset directory.

    Args:
      data_dir: Path of the dataset directory

    Yields:
      tf.train.Example
    """
    label_name_to_id = {name: i for (i, name) in self.label_map.items()}
    annotations_dir = os.path.join(data_dir, 'Annotations')
    images_dir = os.path.join(data_dir, 'images')
    all_annotation_paths = tf.io.gfile.glob(annotations_dir + r'/*.xml')

    data = collections.defaultdict(list)
    for ind, ann_file in enumerate(all_annotation_paths):
      tree = ET.parse(ann_file)
      root = tree.getroot()
      img_filename = _xml_get(root, 'filename').text
      img_file = os.path.join(images_dir, img_filename)
      with tf.io.gfile.GFile(img_file, 'rb') as fid:
        encoded_jpg = fid.read()
      image = tf.io.decode_jpeg(encoded_jpg, channels=3)
      height, width, _ = image.shape
      for child in root.iter('object'):
        category_name = _xml_get(child, 'name').text
        category_id = label_name_to_id[category_name]
        bndbox = _xml_get(child, 'bndbox')
        xmin = float(_xml_get(bndbox, 'xmin').text)
        xmax = float(_xml_get(bndbox, 'xmax').text)
        ymin = float(_xml_get(bndbox, 'ymin').text)
        ymax = float(_xml_get(bndbox, 'ymax').text)
        if xmax <= xmin or ymax <= ymin or xmax > width or ymax > height:
          # Skip annotations that have no area or are larger than the image
          continue
        data['xmin'].append(xmin / width)
        data['ymin'].append(ymin / height)
        data['xmax'].append(xmax / width)
        data['ymax'].append(ymax / height)
        data['category_id'].append(category_id)
      if not data['xmin']:
        # Skip examples which have no valid annotations
        continue
      feature_dict = tfrecord_lib.image_info_to_feature_dict(
          height, width, img_filename, ind, encoded_jpg, 'jpg'
      )
      bbox_feature_dict = _bbox_data_to_feature_dict(data)
      feature_dict.update(bbox_feature_dict)
      example = tf.train.Example(
          features=tf.train.Features(feature=feature_dict)
      )
      yield example
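
The shard count used by both `get_cache_files_coco` and `get_cache_files_pascal_voc` above follows a simple rule: roughly 100 examples per TFRecord shard, capped at 10 shards. A small self-contained restatement of that rule, for illustration only:

import math

def num_shards_for(num_examples: int) -> int:
  # Same formula as in the cache-file helpers above:
  # ~100 examples per shard, at most 10 shards.
  return min(math.ceil(num_examples / 100), 10)

assert num_shards_for(1) == 1      # tiny datasets still get one shard
assert num_shards_for(250) == 3    # 250 examples -> 3 shards
assert num_shards_for(5000) == 10  # capped at 10 shards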

mediapipe/model_maker/python/vision/object_detector/dataset_util_test.py (new file, 236 lines)
@@ -0,0 +1,236 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import json
import os
import shutil
from unittest import mock as unittest_mock

import tensorflow as tf
import yaml

from mediapipe.model_maker.python.vision.core import test_utils
from mediapipe.model_maker.python.vision.object_detector import dataset_util
from mediapipe.tasks.python.test import test_utils as tasks_test_utils


class DatasetUtilTest(tf.test.TestCase):

  def _assert_cache_files_equal(self, cf1, cf2):
    self.assertEqual(cf1.cache_prefix, cf2.cache_prefix)
    self.assertCountEqual(cf1.tfrecord_files, cf2.tfrecord_files)
    self.assertEqual(cf1.meta_data_file, cf2.meta_data_file)

  def _assert_cache_files_not_equal(self, cf1, cf2):
    self.assertNotEqual(cf1.cache_prefix, cf2.cache_prefix)
    self.assertNotEqual(cf1.tfrecord_files, cf2.tfrecord_files)
    self.assertNotEqual(cf1.meta_data_file, cf2.meta_data_file)

  def _get_cache_files_and_assert_neq_fn(self, cache_files_fn):
    def get_cache_files_and_assert_neq(cf, data_dir, cache_dir):
      new_cf = cache_files_fn(data_dir, cache_dir)
      self._assert_cache_files_not_equal(cf, new_cf)
      return new_cf

    return get_cache_files_and_assert_neq

  @unittest_mock.patch.object(hashlib, 'md5', autospec=True)
  def test_get_cache_files_coco(self, mock_md5):
    mock_md5.return_value.hexdigest.return_value = 'train'
    cache_files = dataset_util.get_cache_files_coco(
        tasks_test_utils.get_test_data_path('coco_data'), cache_dir='/tmp/'
    )
    self.assertEqual(cache_files.cache_prefix, '/tmp/train')
    self.assertLen(cache_files.tfrecord_files, 1)
    self.assertEqual(
        cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord'
    )
    self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml')

  def test_matching_get_cache_files_coco(self):
    cache_dir = self.create_tempdir()
    coco_folder = tasks_test_utils.get_test_data_path('coco_data')
    coco_folder_tmp = os.path.join(self.create_tempdir(), 'coco_data')
    shutil.copytree(coco_folder, coco_folder_tmp)
    cache_files1 = dataset_util.get_cache_files_coco(coco_folder, cache_dir)
    cache_files2 = dataset_util.get_cache_files_coco(coco_folder, cache_dir)
    self._assert_cache_files_equal(cache_files1, cache_files2)
    cache_files3 = dataset_util.get_cache_files_coco(coco_folder_tmp, cache_dir)
    self._assert_cache_files_equal(cache_files1, cache_files3)

  def test_not_matching_get_cache_files_coco(self):
    cache_dir = self.create_tempdir()
    temp_dir = self.create_tempdir()
    coco_folder = os.path.join(temp_dir, 'coco_data')
    shutil.copytree(
        tasks_test_utils.get_test_data_path('coco_data'), coco_folder
    )
    prev_cache_file = dataset_util.get_cache_files_coco(coco_folder, cache_dir)
    os.chmod(coco_folder, 0o700)
    os.chmod(os.path.join(coco_folder, 'images'), 0o700)
    os.chmod(os.path.join(coco_folder, 'labels.json'), 0o700)
    get_cache_files_and_assert_neq = self._get_cache_files_and_assert_neq_fn(
        dataset_util.get_cache_files_coco
    )
    # Test adding image
    test_utils.write_filled_jpeg_file(
        os.path.join(coco_folder, 'images', 'test.jpg'), [0, 0, 0], 50
    )
    prev_cache_file = get_cache_files_and_assert_neq(
        prev_cache_file, coco_folder, cache_dir
    )
    # Test modifying labels.json
    with open(os.path.join(coco_folder, 'labels.json'), 'w') as f:
      json.dump({'images': [{'id': 1, 'file_name': '000000000078.jpg'}]}, f)
    prev_cache_file = get_cache_files_and_assert_neq(
        prev_cache_file, coco_folder, cache_dir
    )

    # Test rename folder
    new_coco_folder = os.path.join(temp_dir, 'coco_data_renamed')
    shutil.move(coco_folder, new_coco_folder)
    coco_folder = new_coco_folder
    prev_cache_file = get_cache_files_and_assert_neq(
        prev_cache_file, new_coco_folder, cache_dir
    )

  @unittest_mock.patch.object(hashlib, 'md5', autospec=True)
  def test_get_cache_files_pascal_voc(self, mock_md5):
    mock_md5.return_value.hexdigest.return_value = 'train'
    cache_files = dataset_util.get_cache_files_pascal_voc(
        tasks_test_utils.get_test_data_path('pascal_voc_data'),
        cache_dir='/tmp/',
    )
    self.assertEqual(cache_files.cache_prefix, '/tmp/train')
    self.assertLen(cache_files.tfrecord_files, 1)
    self.assertEqual(
        cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord'
    )
    self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml')

  def test_matching_get_cache_files_pascal_voc(self):
    cache_dir = self.create_tempdir()
    pascal_folder = tasks_test_utils.get_test_data_path('pascal_voc_data')
    pascal_folder_temp = os.path.join(self.create_tempdir(), 'pascal_voc_data')
    shutil.copytree(pascal_folder, pascal_folder_temp)
    cache_files1 = dataset_util.get_cache_files_pascal_voc(
        pascal_folder, cache_dir
    )
    cache_files2 = dataset_util.get_cache_files_pascal_voc(
        pascal_folder, cache_dir
    )
    self._assert_cache_files_equal(cache_files1, cache_files2)
    cache_files3 = dataset_util.get_cache_files_pascal_voc(
        pascal_folder_temp, cache_dir
    )
    self._assert_cache_files_equal(cache_files1, cache_files3)

  def test_not_matching_get_cache_files_pascal_voc(self):
    cache_dir = self.create_tempdir()
    temp_dir = self.create_tempdir()
    pascal_folder = os.path.join(temp_dir, 'pascal_voc_data')
    shutil.copytree(
        tasks_test_utils.get_test_data_path('pascal_voc_data'), pascal_folder
    )
    prev_cache_files = dataset_util.get_cache_files_pascal_voc(
        pascal_folder, cache_dir
    )
    os.chmod(pascal_folder, 0o700)
    os.chmod(os.path.join(pascal_folder, 'images'), 0o700)
    os.chmod(os.path.join(pascal_folder, 'Annotations'), 0o700)
    get_cache_files_and_assert_neq = self._get_cache_files_and_assert_neq_fn(
        dataset_util.get_cache_files_pascal_voc
    )
    # Test adding xml file
    with open(os.path.join(pascal_folder, 'Annotations', 'test.xml'), 'w') as f:
      f.write('test')
    prev_cache_files = get_cache_files_and_assert_neq(
        prev_cache_files, pascal_folder, cache_dir
    )

    # Test rename folder
    new_pascal_folder = os.path.join(temp_dir, 'pascal_voc_data_renamed')
    shutil.move(pascal_folder, new_pascal_folder)
    pascal_folder = new_pascal_folder
    prev_cache_files = get_cache_files_and_assert_neq(
        prev_cache_files, new_pascal_folder, cache_dir
    )

  def test_is_cached(self):
    tempdir = self.create_tempdir()
    cache_files = dataset_util.get_cache_files_coco(
        tasks_test_utils.get_test_data_path('coco_data'), cache_dir=tempdir
    )
    self.assertFalse(dataset_util.is_cached(cache_files))
    with open(cache_files.tfrecord_files[0], 'w') as f:
      f.write('test')
    self.assertFalse(dataset_util.is_cached(cache_files))
    with open(cache_files.meta_data_file, 'w') as f:
      f.write('test')
    self.assertTrue(dataset_util.is_cached(cache_files))

  def test_get_label_map_coco(self):
 | 
			
		||||
    coco_dir = tasks_test_utils.get_test_data_path('coco_data')
 | 
			
		||||
    label_map = dataset_util.get_label_map_coco(coco_dir)
 | 
			
		||||
    all_keys = sorted(label_map.keys())
 | 
			
		||||
    self.assertEqual(all_keys[0], 0)
 | 
			
		||||
    self.assertEqual(all_keys[-1], 11)
 | 
			
		||||
    self.assertLen(all_keys, 12)
 | 
			
		||||
 | 
			
		||||
  def test_get_label_map_pascal_voc(self):
 | 
			
		||||
    pascal_dir = tasks_test_utils.get_test_data_path('pascal_voc_data')
 | 
			
		||||
    label_map = dataset_util.get_label_map_pascal_voc(pascal_dir)
 | 
			
		||||
    all_keys = sorted(label_map.keys())
 | 
			
		||||
    self.assertEqual(label_map[0], 'background')
 | 
			
		||||
    self.assertEqual(all_keys[0], 0)
 | 
			
		||||
    self.assertEqual(all_keys[-1], 2)
 | 
			
		||||
    self.assertLen(all_keys, 3)
 | 
			
		||||
 | 
			
		||||
  def _validate_cache_files(self, cache_files, expected_size):
 | 
			
		||||
    # Checks the TFRecord file
 | 
			
		||||
    self.assertTrue(os.path.isfile(cache_files.tfrecord_files[0]))
 | 
			
		||||
    self.assertGreater(os.path.getsize(cache_files.tfrecord_files[0]), 0)
 | 
			
		||||
 | 
			
		||||
    # Checks the meta_data file
 | 
			
		||||
    self.assertTrue(os.path.isfile(cache_files.meta_data_file))
 | 
			
		||||
    self.assertGreater(os.path.getsize(cache_files.meta_data_file), 0)
 | 
			
		||||
    with tf.io.gfile.GFile(cache_files.meta_data_file, 'r') as f:
 | 
			
		||||
      meta_data_dict = yaml.load(f, Loader=yaml.FullLoader)
 | 
			
		||||
      # Some examples are skipped for having poor bboxes, so the cached size
      # may be smaller than the raw number of examples in the dataset.
 | 
			
		||||
      self.assertEqual(meta_data_dict['size'], expected_size)
 | 
			
		||||
 | 
			
		||||
  def test_coco_cache_files_writer(self):
 | 
			
		||||
    tempdir = self.create_tempdir()
 | 
			
		||||
    coco_dir = tasks_test_utils.get_test_data_path('coco_data')
 | 
			
		||||
    label_map = dataset_util.get_label_map_coco(coco_dir)
 | 
			
		||||
    cache_writer = dataset_util.COCOCacheFilesWriter(label_map)
 | 
			
		||||
    cache_files = dataset_util.get_cache_files_coco(coco_dir, cache_dir=tempdir)
 | 
			
		||||
    cache_writer.write_files(cache_files, coco_dir)
 | 
			
		||||
    self._validate_cache_files(cache_files, 3)
 | 
			
		||||
 | 
			
		||||
  def test_pascal_voc_cache_files_writer(self):
 | 
			
		||||
    tempdir = self.create_tempdir()
 | 
			
		||||
    pascal_dir = tasks_test_utils.get_test_data_path('pascal_voc_data')
 | 
			
		||||
    label_map = dataset_util.get_label_map_pascal_voc(pascal_dir)
 | 
			
		||||
    cache_writer = dataset_util.PascalVocCacheFilesWriter(label_map)
 | 
			
		||||
    cache_files = dataset_util.get_cache_files_pascal_voc(
 | 
			
		||||
        pascal_dir, cache_dir=tempdir
 | 
			
		||||
    )
 | 
			
		||||
    cache_writer.write_files(cache_files, pascal_dir)
 | 
			
		||||
    self._validate_cache_files(cache_files, 4)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
  tf.test.main()

@@ -0,0 +1,101 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Hyperparameters for training object detection models."""
 | 
			
		||||
 | 
			
		||||
import dataclasses
 | 
			
		||||
from typing import List
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.core import hyperparameters as hp
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
class HParams(hp.BaseHParams):
 | 
			
		||||
  """The hyperparameters for training object detectors.
 | 
			
		||||
 | 
			
		||||
  Attributes:
 | 
			
		||||
    learning_rate: Learning rate to use for gradient descent training.
 | 
			
		||||
    batch_size: Batch size for training.
 | 
			
		||||
    epochs: Number of training iterations over the dataset.
 | 
			
		||||
    do_fine_tuning: If true, the base module is trained together with the
 | 
			
		||||
      classification layer on top.
 | 
			
		||||
    learning_rate_boundaries: List of epoch boundaries where
 | 
			
		||||
      learning_rate_boundaries[i] is the epoch where the learning rate will
 | 
			
		||||
      decay to learning_rate * learning_rate_decay_multipliers[i].
 | 
			
		||||
    learning_rate_decay_multipliers: List of learning rate multipliers that
      set the learning rate at the ith boundary to learning_rate *
      learning_rate_decay_multipliers[i].
 | 
			
		||||
  """
 | 
			
		||||
 | 
			
		||||
  # Parameters from BaseHParams class.
 | 
			
		||||
  learning_rate: float = 0.003
 | 
			
		||||
  batch_size: int = 32
 | 
			
		||||
  epochs: int = 10
 | 
			
		||||
 | 
			
		||||
  # Parameters for learning rate decay
 | 
			
		||||
  learning_rate_boundaries: List[int] = dataclasses.field(
 | 
			
		||||
      default_factory=lambda: [5, 8]
 | 
			
		||||
  )
 | 
			
		||||
  learning_rate_decay_multipliers: List[float] = dataclasses.field(
 | 
			
		||||
      default_factory=lambda: [0.1, 0.01]
 | 
			
		||||
  )
 | 
			
		||||
 | 
			
		||||
  def __post_init__(self):
 | 
			
		||||
    # Validate stepwise learning rate parameters
 | 
			
		||||
    lr_boundary_len = len(self.learning_rate_boundaries)
 | 
			
		||||
    lr_decay_multipliers_len = len(self.learning_rate_decay_multipliers)
 | 
			
		||||
    if lr_boundary_len != lr_decay_multipliers_len:
 | 
			
		||||
      raise ValueError(
          "Length of learning_rate_boundaries and "
          "learning_rate_decay_multipliers do not match: "
          f"{lr_boundary_len}!={lr_decay_multipliers_len}"
      )
 | 
			
		||||
    # Validate learning_rate_boundaries
 | 
			
		||||
    if sorted(self.learning_rate_boundaries) != self.learning_rate_boundaries:
 | 
			
		||||
      raise ValueError(
          "learning_rate_boundaries is not in ascending order: "
          f"{self.learning_rate_boundaries}"
      )
 | 
			
		||||
    if (
 | 
			
		||||
        self.learning_rate_boundaries
 | 
			
		||||
        and self.learning_rate_boundaries[-1] > self.epochs
 | 
			
		||||
    ):
 | 
			
		||||
      raise ValueError(
          "Values in learning_rate_boundaries cannot be greater than epochs"
      )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
class QATHParams:
 | 
			
		||||
  """The hyperparameters for running quantization aware training (QAT) on object detectors.
 | 
			
		||||
 | 
			
		||||
  For more information on QAT, see:
 | 
			
		||||
    https://www.tensorflow.org/model_optimization/guide/quantization/training
 | 
			
		||||
 | 
			
		||||
  Attributes:
 | 
			
		||||
    learning_rate: Learning rate to use for gradient descent QAT.
 | 
			
		||||
    batch_size: Batch size for QAT.
 | 
			
		||||
    epochs: Number of training iterations over the dataset.
 | 
			
		||||
    decay_steps: Learning rate decay steps for Exponential Decay. See
 | 
			
		||||
      https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/ExponentialDecay
 | 
			
		||||
        for more information.
 | 
			
		||||
    decay_rate: Learning rate decay rate for Exponential Decay. See
 | 
			
		||||
      https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/ExponentialDecay
 | 
			
		||||
        for more information.
 | 
			
		||||
  """
 | 
			
		||||
 | 
			
		||||
  learning_rate: float = 0.03
 | 
			
		||||
  batch_size: int = 32
 | 
			
		||||
  epochs: int = 10
 | 
			
		||||
  decay_steps: int = 231
 | 
			
		||||
  decay_rate: float = 0.96
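

# Illustrative sketch, not part of this change: how the stepwise schedule
# documented on HParams above relates epochs to learning rates.
# `lr_for_epoch` is a hypothetical helper written only for this example.
def lr_for_epoch(hparams: HParams, epoch: int) -> float:
  """Returns the learning rate used during `epoch` under stepwise decay."""
  lr = hparams.learning_rate
  for boundary, multiplier in zip(
      hparams.learning_rate_boundaries, hparams.learning_rate_decay_multipliers
  ):
    if epoch >= boundary:
      lr = hparams.learning_rate * multiplier
  return lr


# With the HParams defaults (learning_rate=0.003, boundaries=[5, 8],
# multipliers=[0.1, 0.01]): epochs 0-4 train at 3e-3, 5-7 at 3e-4, 8-9 at 3e-5.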

mediapipe/model_maker/python/vision/object_detector/model.py  (355 lines, Normal file)
@@ -0,0 +1,355 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Custom Model for Object Detection."""
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
from typing import Mapping, Optional, Sequence, Union
 | 
			
		||||
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.vision.object_detector import model_options as model_opt
 | 
			
		||||
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
 | 
			
		||||
from official.core import config_definitions as cfg
 | 
			
		||||
from official.projects.qat.vision.configs import common as qat_common
 | 
			
		||||
from official.projects.qat.vision.modeling import factory as qat_factory
 | 
			
		||||
from official.vision import configs
 | 
			
		||||
from official.vision.losses import focal_loss
 | 
			
		||||
from official.vision.losses import loss_utils
 | 
			
		||||
from official.vision.modeling import factory
 | 
			
		||||
from official.vision.modeling import retinanet_model
 | 
			
		||||
from official.vision.modeling.layers import detection_generator
 | 
			
		||||
from official.vision.serving import detection
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ObjectDetectorModel(tf.keras.Model):
 | 
			
		||||
  """An object detector model which can be trained using Model Maker's training API.
 | 
			
		||||
 | 
			
		||||
  Attributes:
 | 
			
		||||
    loss_trackers: List of tf.keras.metrics.Mean objects used to track the loss
 | 
			
		||||
      during training.
 | 
			
		||||
  """
 | 
			
		||||
 | 
			
		||||
  def __init__(
 | 
			
		||||
      self,
 | 
			
		||||
      model_spec: ms.ModelSpec,
 | 
			
		||||
      model_options: model_opt.ObjectDetectorModelOptions,
 | 
			
		||||
      num_classes: int,
 | 
			
		||||
  ) -> None:
 | 
			
		||||
    """Initializes an ObjectDetectorModel.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
      model_spec: Specification for the model.
 | 
			
		||||
      model_options: Model options for creating the model.
 | 
			
		||||
      num_classes: Number of classes for object detection.
 | 
			
		||||
    """
 | 
			
		||||
    super().__init__()
 | 
			
		||||
    self._model_spec = model_spec
 | 
			
		||||
    self._model_options = model_options
 | 
			
		||||
    self._num_classes = num_classes
 | 
			
		||||
    self._model = self._build_model()
 | 
			
		||||
    checkpoint_folder = self._model_spec.downloaded_files.get_path()
 | 
			
		||||
    checkpoint_file = os.path.join(checkpoint_folder, 'ckpt-277200')
 | 
			
		||||
    self.load_checkpoint(checkpoint_file)
 | 
			
		||||
    self._model.summary()
 | 
			
		||||
    self.loss_trackers = [
 | 
			
		||||
        tf.keras.metrics.Mean(name=n)
 | 
			
		||||
        for n in ['total_loss', 'cls_loss', 'box_loss', 'model_loss']
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
  def _get_model_config(
 | 
			
		||||
      self,
 | 
			
		||||
      generator_config: configs.retinanet.DetectionGenerator = configs.retinanet.DetectionGenerator(),
 | 
			
		||||
  ) -> configs.retinanet.RetinaNet:
 | 
			
		||||
    model_config = configs.retinanet.RetinaNet(
 | 
			
		||||
        min_level=3,
 | 
			
		||||
        max_level=7,
 | 
			
		||||
        num_classes=self._num_classes,
 | 
			
		||||
        input_size=self._model_spec.input_image_shape,
 | 
			
		||||
        anchor=configs.retinanet.Anchor(
 | 
			
		||||
            num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=3
 | 
			
		||||
        ),
 | 
			
		||||
        backbone=configs.backbones.Backbone(
 | 
			
		||||
            type='mobilenet', mobilenet=configs.backbones.MobileNet()
 | 
			
		||||
        ),
 | 
			
		||||
        decoder=configs.decoders.Decoder(
 | 
			
		||||
            type='fpn',
 | 
			
		||||
            fpn=configs.decoders.FPN(
 | 
			
		||||
                num_filters=128, use_separable_conv=True, use_keras_layer=True
 | 
			
		||||
            ),
 | 
			
		||||
        ),
 | 
			
		||||
        head=configs.retinanet.RetinaNetHead(
 | 
			
		||||
            num_filters=128, use_separable_conv=True
 | 
			
		||||
        ),
 | 
			
		||||
        detection_generator=generator_config,
 | 
			
		||||
        norm_activation=configs.common.NormActivation(activation='relu6'),
 | 
			
		||||
    )
 | 
			
		||||
    return model_config
 | 
			
		||||
 | 
			
		||||
  def _build_model(self) -> tf.keras.Model:
 | 
			
		||||
    """Builds a RetinaNet object detector model."""
 | 
			
		||||
    input_specs = tf.keras.layers.InputSpec(
 | 
			
		||||
        shape=[None] + self._model_spec.input_image_shape
 | 
			
		||||
    )
 | 
			
		||||
    l2_regularizer = tf.keras.regularizers.l2(
 | 
			
		||||
        self._model_options.l2_weight_decay / 2.0
 | 
			
		||||
    )
 | 
			
		||||
    model_config = self._get_model_config()
 | 
			
		||||
 | 
			
		||||
    return factory.build_retinanet(input_specs, model_config, l2_regularizer)
 | 
			
		||||
 | 
			
		||||
  def save_checkpoint(self, checkpoint_path: str) -> None:
 | 
			
		||||
    """Saves a model checkpoint to checkpoint_path.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
      checkpoint_path: The path to save checkpoint.
 | 
			
		||||
    """
 | 
			
		||||
    ckpt_items = {
 | 
			
		||||
        'backbone': self._model.backbone,
 | 
			
		||||
        'decoder': self._model.decoder,
 | 
			
		||||
        'head': self._model.head,
 | 
			
		||||
    }
 | 
			
		||||
    tf.train.Checkpoint(**ckpt_items).write(checkpoint_path)
 | 
			
		||||
 | 
			
		||||
  def load_checkpoint(
 | 
			
		||||
      self, checkpoint_path: str, include_last_layer: bool = False
 | 
			
		||||
  ) -> None:
 | 
			
		||||
    """Loads a model checkpoint from checkpoint_path.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
      checkpoint_path: The path to load a checkpoint from.
 | 
			
		||||
      include_last_layer: Whether or not to load the last classification layer.
 | 
			
		||||
        The size of the last classification layer will differ depending on the
 | 
			
		||||
        number of classes. When loading from the pre-trained checkpoint, this
 | 
			
		||||
        parameter should be False to avoid shape mismatch on the last layer.
 | 
			
		||||
        Defaults to False.
 | 
			
		||||
    """
 | 
			
		||||
    dummy_input = tf.zeros([1] + self._model_spec.input_image_shape)
 | 
			
		||||
    self._model(dummy_input, training=True)
 | 
			
		||||
    if include_last_layer:
 | 
			
		||||
      head = self._model.head
 | 
			
		||||
    else:
 | 
			
		||||
      head_classifier = tf.train.Checkpoint(
 | 
			
		||||
          depthwise_kernel=self._model.head._classifier.depthwise_kernel  # pylint:disable=protected-access
 | 
			
		||||
      )
 | 
			
		||||
      head_items = {
 | 
			
		||||
          '_classifier': head_classifier,
 | 
			
		||||
          '_box_norms': self._model.head._box_norms,  # pylint:disable=protected-access
 | 
			
		||||
          '_box_regressor': self._model.head._box_regressor,  # pylint:disable=protected-access
 | 
			
		||||
          '_cls_convs': self._model.head._cls_convs,  # pylint:disable=protected-access
 | 
			
		||||
          '_cls_norms': self._model.head._cls_norms,  # pylint:disable=protected-access
 | 
			
		||||
          '_box_convs': self._model.head._box_convs,  # pylint:disable=protected-access
 | 
			
		||||
      }
 | 
			
		||||
      head = tf.train.Checkpoint(**head_items)
 | 
			
		||||
    ckpt_items = {
 | 
			
		||||
        'backbone': self._model.backbone,
 | 
			
		||||
        'decoder': self._model.decoder,
 | 
			
		||||
        'head': head,
 | 
			
		||||
    }
 | 
			
		||||
    ckpt = tf.train.Checkpoint(**ckpt_items)
 | 
			
		||||
    status = ckpt.read(checkpoint_path)
 | 
			
		||||
    status.expect_partial().assert_existing_objects_matched()
 | 
			
		||||
 | 
			
		||||
  def convert_to_qat(self) -> None:
 | 
			
		||||
    """Converts the model to a QAT RetinaNet model."""
 | 
			
		||||
    model = self._build_model()
 | 
			
		||||
    dummy_input = tf.zeros([1] + self._model_spec.input_image_shape)
 | 
			
		||||
    model(dummy_input, training=True)
 | 
			
		||||
    model.set_weights(self._model.get_weights())
 | 
			
		||||
    quantization_config = qat_common.Quantization(
 | 
			
		||||
        quantize_detection_decoder=True, quantize_detection_head=True
 | 
			
		||||
    )
 | 
			
		||||
    model_config = self._get_model_config()
 | 
			
		||||
    qat_model = qat_factory.build_qat_retinanet(
 | 
			
		||||
        model, quantization_config, model_config
 | 
			
		||||
    )
 | 
			
		||||
    self._model = qat_model
 | 
			
		||||
 | 
			
		||||
  def export_saved_model(self, save_path: str):
 | 
			
		||||
    """Exports a saved_model for tflite conversion.
 | 
			
		||||
 | 
			
		||||
    The export process modifies the model in the following two ways:
 | 
			
		||||
      1. Replaces the nms operation in the detection generator with a custom
 | 
			
		||||
        TFLite compatible nms operation.
 | 
			
		||||
      2. Wraps the model with a DetectionModule which handles pre-processing
 | 
			
		||||
        and post-processing when running inference.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
      save_path: Path to export the saved model.
 | 
			
		||||
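
    Example usage (an illustrative sketch, not part of this change; it mirrors
    how ObjectDetector.export_model consumes the exported saved_model):
      model.export_saved_model('/tmp/saved_model')
      converter = tf.lite.TFLiteConverter.from_saved_model('/tmp/saved_model')
      tflite_model = converter.convert()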
    """
 | 
			
		||||
    generator_config = configs.retinanet.DetectionGenerator(
 | 
			
		||||
        nms_version='tflite',
 | 
			
		||||
        tflite_post_processing=configs.common.TFLitePostProcessingConfig(
 | 
			
		||||
            nms_score_threshold=0,
 | 
			
		||||
            max_detections=10,
 | 
			
		||||
            max_classes_per_detection=1,
 | 
			
		||||
            normalize_anchor_coordinates=True,
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    tflite_post_processing_config = (
 | 
			
		||||
        generator_config.tflite_post_processing.as_dict()
 | 
			
		||||
    )
 | 
			
		||||
    tflite_post_processing_config['input_image_size'] = (
 | 
			
		||||
        self._model_spec.input_image_shape[0],
 | 
			
		||||
        self._model_spec.input_image_shape[1],
 | 
			
		||||
    )
 | 
			
		||||
    detection_generator_obj = detection_generator.MultilevelDetectionGenerator(
 | 
			
		||||
        apply_nms=generator_config.apply_nms,
 | 
			
		||||
        pre_nms_top_k=generator_config.pre_nms_top_k,
 | 
			
		||||
        pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
 | 
			
		||||
        nms_iou_threshold=generator_config.nms_iou_threshold,
 | 
			
		||||
        max_num_detections=generator_config.max_num_detections,
 | 
			
		||||
        nms_version=generator_config.nms_version,
 | 
			
		||||
        use_cpu_nms=generator_config.use_cpu_nms,
 | 
			
		||||
        soft_nms_sigma=generator_config.soft_nms_sigma,
 | 
			
		||||
        tflite_post_processing_config=tflite_post_processing_config,
 | 
			
		||||
        return_decoded=generator_config.return_decoded,
 | 
			
		||||
        use_class_agnostic_nms=generator_config.use_class_agnostic_nms,
 | 
			
		||||
    )
 | 
			
		||||
    model_config = self._get_model_config(generator_config)
 | 
			
		||||
    model = retinanet_model.RetinaNetModel(
 | 
			
		||||
        self._model.backbone,
 | 
			
		||||
        self._model.decoder,
 | 
			
		||||
        self._model.head,
 | 
			
		||||
        detection_generator_obj,
 | 
			
		||||
        min_level=model_config.min_level,
 | 
			
		||||
        max_level=model_config.max_level,
 | 
			
		||||
        num_scales=model_config.anchor.num_scales,
 | 
			
		||||
        aspect_ratios=model_config.anchor.aspect_ratios,
 | 
			
		||||
        anchor_size=model_config.anchor.anchor_size,
 | 
			
		||||
    )
 | 
			
		||||
    task_config = configs.retinanet.RetinaNetTask(model=model_config)
 | 
			
		||||
    params = cfg.ExperimentConfig(
 | 
			
		||||
        task=task_config,
 | 
			
		||||
    )
 | 
			
		||||
    export_module = detection.DetectionModule(
 | 
			
		||||
        params=params,
 | 
			
		||||
        batch_size=1,
 | 
			
		||||
        input_image_size=self._model_spec.input_image_shape[:2],
 | 
			
		||||
        input_type='tflite',
 | 
			
		||||
        num_channels=self._model_spec.input_image_shape[2],
 | 
			
		||||
        model=model,
 | 
			
		||||
    )
 | 
			
		||||
    function_keys = {'tflite': tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY}
 | 
			
		||||
    signatures = export_module.get_inference_signatures(function_keys)
 | 
			
		||||
 | 
			
		||||
    tf.saved_model.save(export_module, save_path, signatures=signatures)
 | 
			
		||||
 | 
			
		||||
  # The remaining method overrides are used to train this object detector model
 | 
			
		||||
  # using model.fit().
 | 
			
		||||
  def call(
 | 
			
		||||
      self,
 | 
			
		||||
      images: Union[tf.Tensor, Sequence[tf.Tensor]],
 | 
			
		||||
      image_shape: Optional[tf.Tensor] = None,
 | 
			
		||||
      anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
 | 
			
		||||
      output_intermediate_features: bool = False,
 | 
			
		||||
      training: Optional[bool] = None,
 | 
			
		||||
  ) -> Mapping[str, tf.Tensor]:
 | 
			
		||||
    """Overrides call from tf.keras.Model."""
 | 
			
		||||
    return self._model(
 | 
			
		||||
        images,
 | 
			
		||||
        image_shape,
 | 
			
		||||
        anchor_boxes,
 | 
			
		||||
        output_intermediate_features,
 | 
			
		||||
        training,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
  def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
 | 
			
		||||
    """Overrides compute_loss from tf.keras.Model."""
 | 
			
		||||
    cls_loss_fn = focal_loss.FocalLoss(
 | 
			
		||||
        alpha=0.25, gamma=1.5, reduction=tf.keras.losses.Reduction.SUM
 | 
			
		||||
    )
 | 
			
		||||
    box_loss_fn = tf.keras.losses.Huber(
 | 
			
		||||
        0.1, reduction=tf.keras.losses.Reduction.SUM
 | 
			
		||||
    )
 | 
			
		||||
    labels = y
 | 
			
		||||
    outputs = y_pred
 | 
			
		||||
    # Sums all positives in a batch for normalization and avoids zero
 | 
			
		||||
    # num_positives_sum, which would lead to inf loss during training
 | 
			
		||||
    cls_sample_weight = labels['cls_weights']
 | 
			
		||||
    box_sample_weight = labels['box_weights']
 | 
			
		||||
    num_positives = tf.reduce_sum(box_sample_weight) + 1.0
 | 
			
		||||
    cls_sample_weight = cls_sample_weight / num_positives
 | 
			
		||||
    box_sample_weight = box_sample_weight / num_positives
 | 
			
		||||
    y_true_cls = loss_utils.multi_level_flatten(
 | 
			
		||||
        labels['cls_targets'], last_dim=None
 | 
			
		||||
    )
 | 
			
		||||
    y_true_cls = tf.one_hot(y_true_cls, self._num_classes)
 | 
			
		||||
    y_pred_cls = loss_utils.multi_level_flatten(
 | 
			
		||||
        outputs['cls_outputs'], last_dim=self._num_classes
 | 
			
		||||
    )
 | 
			
		||||
    y_true_box = loss_utils.multi_level_flatten(
 | 
			
		||||
        labels['box_targets'], last_dim=4
 | 
			
		||||
    )
 | 
			
		||||
    y_pred_box = loss_utils.multi_level_flatten(
 | 
			
		||||
        outputs['box_outputs'], last_dim=4
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    cls_loss = cls_loss_fn(
 | 
			
		||||
        y_true=y_true_cls, y_pred=y_pred_cls, sample_weight=cls_sample_weight
 | 
			
		||||
    )
 | 
			
		||||
    box_loss = box_loss_fn(
 | 
			
		||||
        y_true=y_true_box, y_pred=y_pred_box, sample_weight=box_sample_weight
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
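    # The Huber box loss is weighted 50x relative to the focal classification
    # loss when forming the model loss below.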
    model_loss = cls_loss + 50 * box_loss
 | 
			
		||||
    total_loss = model_loss
 | 
			
		||||
    regularization_losses = self._model.losses
 | 
			
		||||
    if regularization_losses:
 | 
			
		||||
      reg_loss = tf.reduce_sum(regularization_losses)
 | 
			
		||||
      total_loss = model_loss + reg_loss
 | 
			
		||||
    all_losses = {
 | 
			
		||||
        'total_loss': total_loss,
 | 
			
		||||
        'cls_loss': cls_loss,
 | 
			
		||||
        'box_loss': box_loss,
 | 
			
		||||
        'model_loss': model_loss,
 | 
			
		||||
    }
 | 
			
		||||
    for m in self.metrics:
 | 
			
		||||
      m.update_state(all_losses[m.name])
 | 
			
		||||
    return total_loss
 | 
			
		||||
 | 
			
		||||
  @property
 | 
			
		||||
  def metrics(self):
 | 
			
		||||
    """Overrides metrics from tf.keras.Model."""
 | 
			
		||||
    return self.loss_trackers
 | 
			
		||||
 | 
			
		||||
  def compute_metrics(self, x, y, y_pred, sample_weight=None):
 | 
			
		||||
    """Overrides compute_metrics from tf.keras.Model."""
 | 
			
		||||
    return self.get_metrics_result()
 | 
			
		||||
 | 
			
		||||
  def train_step(self, data):
 | 
			
		||||
    """Overrides train_step from tf.keras.Model."""
 | 
			
		||||
    tf.keras.backend.set_learning_phase(1)
 | 
			
		||||
    x, y = data
 | 
			
		||||
    # Run forward pass.
 | 
			
		||||
    with tf.GradientTape() as tape:
 | 
			
		||||
      y_pred = self(x, training=True)
 | 
			
		||||
      loss = self.compute_loss(x, y, y_pred)
 | 
			
		||||
    self._validate_target_and_loss(y, loss)
 | 
			
		||||
    # Run backwards pass.
 | 
			
		||||
    self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
 | 
			
		||||
    return self.compute_metrics(x, y, y_pred)
 | 
			
		||||
 | 
			
		||||
  def test_step(self, data):
 | 
			
		||||
    """Overrides test_step from tf.keras.Model."""
 | 
			
		||||
    tf.keras.backend.set_learning_phase(0)
 | 
			
		||||
    x, y = data
 | 
			
		||||
    y_pred = self(
 | 
			
		||||
        x,
 | 
			
		||||
        anchor_boxes=y['anchor_boxes'],
 | 
			
		||||
        image_shape=y['image_info'][:, 1, :],
 | 
			
		||||
        training=False,
 | 
			
		||||
    )
 | 
			
		||||
    # Updates stateful loss metrics.
 | 
			
		||||
    self.compute_loss(x, y, y_pred)
 | 
			
		||||
    return self.compute_metrics(x, y, y_pred)

@@ -0,0 +1,28 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configurable model options for object detector models."""

import dataclasses


@dataclasses.dataclass
class ObjectDetectorModelOptions:
  """Configurable options for object detector model.

  Attributes:
    l2_weight_decay: L2 regularization penalty used in
      https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/L2.
  """

  l2_weight_decay: float = 3.0e-05
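

# Illustrative usage, not part of this change: overriding the default L2
# regularization strength. `_example_options` is a throwaway name used only in
# this sketch.
_example_options = ObjectDetectorModelOptions(l2_weight_decay=1e-4)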

@@ -0,0 +1,62 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Object detector model specification."""
 | 
			
		||||
import dataclasses
 | 
			
		||||
import enum
 | 
			
		||||
import functools
 | 
			
		||||
from typing import List
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.core.utils import file_util
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
MOBILENET_V2_FILES = file_util.DownloadedFiles(
 | 
			
		||||
    'object_detector/mobilenetv2',
 | 
			
		||||
    'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i256_ckpt.tar.gz',
 | 
			
		||||
    is_folder=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
class ModelSpec(object):
 | 
			
		||||
  """Specification of object detector model."""
 | 
			
		||||
 | 
			
		||||
  # Mean and Stddev image preprocessing normalization values.
 | 
			
		||||
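  # With mean_rgb/stddev_rgb of 127.5, pixel values in [0, 255] are mapped to
  # roughly [-1, 1]; mean_norm/stddev_norm express the same scaling in [0, 1]
  # space.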
  mean_norm = (0.5,)
 | 
			
		||||
  stddev_norm = (0.5,)
 | 
			
		||||
  mean_rgb = (127.5,)
 | 
			
		||||
  stddev_rgb = (127.5,)
 | 
			
		||||
 | 
			
		||||
  downloaded_files: file_util.DownloadedFiles
 | 
			
		||||
  input_image_shape: List[int]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
mobilenet_v2_spec = functools.partial(
 | 
			
		||||
    ModelSpec,
 | 
			
		||||
    downloaded_files=MOBILENET_V2_FILES,
 | 
			
		||||
    input_image_shape=[256, 256, 3],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@enum.unique
 | 
			
		||||
class SupportedModels(enum.Enum):
 | 
			
		||||
  """Predefined object detector model specs supported by Model Maker."""
 | 
			
		||||
 | 
			
		||||
  MOBILENET_V2 = mobilenet_v2_spec
 | 
			
		||||
 | 
			
		||||
  @classmethod
 | 
			
		||||
  def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
 | 
			
		||||
    """Get model spec from the input enum and initializes it."""
 | 
			
		||||
    if spec not in cls:
 | 
			
		||||
      raise TypeError(f'Unsupported object detector spec: {spec}')
 | 
			
		||||
    return spec.value()
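

# Illustrative usage, not part of this change: resolving and instantiating the
# MobileNet V2 spec from the enum. `_example_spec` is a throwaway name used
# only in this sketch.
_example_spec = SupportedModels.get(SupportedModels.MOBILENET_V2)
assert _example_spec.input_image_shape == [256, 256, 3]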

@@ -0,0 +1,147 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the 'License');
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an 'AS IS' BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
import tempfile
 | 
			
		||||
from unittest import mock as unittest_mock
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.vision.object_detector import dataset as ds
 | 
			
		||||
from mediapipe.model_maker.python.vision.object_detector import model as model_lib
 | 
			
		||||
from mediapipe.model_maker.python.vision.object_detector import model_options as model_opt
 | 
			
		||||
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
 | 
			
		||||
from mediapipe.model_maker.python.vision.object_detector import preprocessor
 | 
			
		||||
from mediapipe.tasks.python.test import test_utils as task_test_utils
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _dicts_match(dict_1, dict_2):
 | 
			
		||||
  for key in dict_1:
 | 
			
		||||
    if key not in dict_2 or np.any(dict_1[key] != dict_2[key]):
 | 
			
		||||
      return False
 | 
			
		||||
  return True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _outputs_match(output1, output2):
 | 
			
		||||
  return _dicts_match(
 | 
			
		||||
      output1['cls_outputs'], output2['cls_outputs']
 | 
			
		||||
  ) and _dicts_match(output1['box_outputs'], output2['box_outputs'])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ObjectDetectorModelTest(tf.test.TestCase):
 | 
			
		||||
 | 
			
		||||
  def setUp(self):
 | 
			
		||||
    super().setUp()
 | 
			
		||||
    dataset_folder = task_test_utils.get_test_data_path('coco_data')
 | 
			
		||||
    cache_dir = self.create_tempdir()
 | 
			
		||||
    self.data = ds.Dataset.from_coco_folder(dataset_folder, cache_dir=cache_dir)
 | 
			
		||||
    self.model_spec = ms.SupportedModels.MOBILENET_V2.value()
 | 
			
		||||
    self.preprocessor = preprocessor.Preprocessor(self.model_spec)
 | 
			
		||||
    self.fake_inputs = np.random.uniform(
 | 
			
		||||
        low=0, high=1, size=(1, 256, 256, 3)
 | 
			
		||||
    ).astype(np.float32)
 | 
			
		||||
    # Mock tempfile.gettempdir() to be unique for each test to avoid race
 | 
			
		||||
    # condition when downloading model since these tests may run in parallel.
 | 
			
		||||
    mock_gettempdir = unittest_mock.patch.object(
 | 
			
		||||
        tempfile,
 | 
			
		||||
        'gettempdir',
 | 
			
		||||
        return_value=self.create_tempdir(),
 | 
			
		||||
        autospec=True,
 | 
			
		||||
    )
 | 
			
		||||
    self.mock_gettempdir = mock_gettempdir.start()
 | 
			
		||||
    self.addCleanup(mock_gettempdir.stop)
 | 
			
		||||
 | 
			
		||||
  def _create_model(self):
 | 
			
		||||
    model_options = model_opt.ObjectDetectorModelOptions()
 | 
			
		||||
    model = model_lib.ObjectDetectorModel(
 | 
			
		||||
        self.model_spec, model_options, self.data.num_classes
 | 
			
		||||
    )
 | 
			
		||||
    return model
 | 
			
		||||
 | 
			
		||||
  def _train_model(self, model):
 | 
			
		||||
    """Helper to run a simple training run on the model."""
 | 
			
		||||
    dataset = self.data.gen_tf_dataset(
 | 
			
		||||
        batch_size=2,
 | 
			
		||||
        is_training=True,
 | 
			
		||||
        shuffle=False,
 | 
			
		||||
        preprocess=self.preprocessor,
 | 
			
		||||
    )
 | 
			
		||||
    optimizer = tf.keras.optimizers.experimental.SGD(
 | 
			
		||||
        learning_rate=0.03, momentum=0.9
 | 
			
		||||
    )
 | 
			
		||||
    model.compile(optimizer=optimizer)
 | 
			
		||||
    model.fit(
 | 
			
		||||
        x=dataset, epochs=2, steps_per_epoch=None, validation_data=dataset
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
  def test_model(self):
 | 
			
		||||
    model = self._create_model()
 | 
			
		||||
    outputs_before = model(self.fake_inputs, training=True)
 | 
			
		||||
    self._train_model(model)
 | 
			
		||||
    outputs_after = model(self.fake_inputs, training=True)
 | 
			
		||||
    self.assertFalse(_outputs_match(outputs_before, outputs_after))
 | 
			
		||||
 | 
			
		||||
  def test_model_convert_to_qat(self):
 | 
			
		||||
    model_options = model_opt.ObjectDetectorModelOptions()
 | 
			
		||||
    model = model_lib.ObjectDetectorModel(
 | 
			
		||||
        self.model_spec, model_options, self.data.num_classes
 | 
			
		||||
    )
 | 
			
		||||
    outputs_before = model(self.fake_inputs, training=True)
 | 
			
		||||
    model.convert_to_qat()
 | 
			
		||||
    outputs_after = model(self.fake_inputs, training=True)
 | 
			
		||||
    self.assertFalse(_outputs_match(outputs_before, outputs_after))
 | 
			
		||||
    outputs_before = outputs_after
 | 
			
		||||
    self._train_model(model)
 | 
			
		||||
    outputs_after = model(self.fake_inputs, training=True)
 | 
			
		||||
    self.assertFalse(_outputs_match(outputs_before, outputs_after))
 | 
			
		||||
 | 
			
		||||
  def test_model_save_and_load_checkpoint(self):
 | 
			
		||||
    model = self._create_model()
 | 
			
		||||
    checkpoint_path = os.path.join(self.create_tempdir(), 'ckpt')
 | 
			
		||||
    model.save_checkpoint(checkpoint_path)
 | 
			
		||||
    data_checkpoint_file = checkpoint_path + '.data-00000-of-00001'
 | 
			
		||||
    index_checkpoint_file = checkpoint_path + '.index'
 | 
			
		||||
    self.assertTrue(os.path.exists(data_checkpoint_file))
 | 
			
		||||
    self.assertTrue(os.path.exists(index_checkpoint_file))
 | 
			
		||||
    self.assertGreater(os.path.getsize(data_checkpoint_file), 0)
 | 
			
		||||
    self.assertGreater(os.path.getsize(index_checkpoint_file), 0)
 | 
			
		||||
    outputs_before = model(self.fake_inputs, training=True)
 | 
			
		||||
 | 
			
		||||
    # Check model output is different after training
 | 
			
		||||
    self._train_model(model)
 | 
			
		||||
    outputs_after = model(self.fake_inputs, training=True)
 | 
			
		||||
    self.assertFalse(_outputs_match(outputs_before, outputs_after))
 | 
			
		||||
 | 
			
		||||
    # Check model output is the same after loading previous checkpoint
 | 
			
		||||
    model.load_checkpoint(checkpoint_path, include_last_layer=True)
 | 
			
		||||
    outputs_after = model(self.fake_inputs, training=True)
 | 
			
		||||
    self.assertTrue(_outputs_match(outputs_before, outputs_after))
 | 
			
		||||
 | 
			
		||||
  def test_export_saved_model(self):
 | 
			
		||||
    export_dir = self.create_tempdir()
 | 
			
		||||
    export_path = os.path.join(export_dir, 'saved_model')
 | 
			
		||||
    model = self._create_model()
 | 
			
		||||
    model.export_saved_model(export_path)
 | 
			
		||||
    self.assertTrue(os.path.exists(export_path))
 | 
			
		||||
    self.assertGreater(os.path.getsize(export_path), 0)
 | 
			
		||||
    model.convert_to_qat()
 | 
			
		||||
    export_path = os.path.join(export_dir, 'saved_model_qat')
 | 
			
		||||
    model.export_saved_model(export_path)
 | 
			
		||||
    self.assertTrue(os.path.exists(export_path))
 | 
			
		||||
    self.assertGreater(os.path.getsize(export_path), 0)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
  tf.test.main()

@@ -0,0 +1,353 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""APIs to train object detector model."""
 | 
			
		||||
import os
 | 
			
		||||
import tempfile
 | 
			
		||||
from typing import Dict, List, Optional, Tuple
 | 
			
		||||
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.core.tasks import classifier
 | 
			
		||||
from mediapipe.model_maker.python.core.utils import model_util
 | 
			
		||||
from mediapipe.model_maker.python.core.utils import quantization
 | 
			
		||||
from mediapipe.model_maker.python.vision.object_detector import dataset as ds
 | 
			
		||||
from mediapipe.model_maker.python.vision.object_detector import hyperparameters as hp
 | 
			
		||||
from mediapipe.model_maker.python.vision.object_detector import model as model_lib
 | 
			
		||||
from mediapipe.model_maker.python.vision.object_detector import model_options as model_opt
 | 
			
		||||
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
 | 
			
		||||
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
 | 
			
		||||
from mediapipe.model_maker.python.vision.object_detector import preprocessor
 | 
			
		||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
 | 
			
		||||
from mediapipe.tasks.python.metadata.metadata_writers import object_detector as object_detector_writer
 | 
			
		||||
from official.vision.evaluation import coco_evaluator
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ObjectDetector(classifier.Classifier):
 | 
			
		||||
  """ObjectDetector for building object detection model."""
 | 
			
		||||
 | 
			
		||||
  def __init__(
 | 
			
		||||
      self,
 | 
			
		||||
      model_spec: ms.ModelSpec,
 | 
			
		||||
      label_names: List[str],
 | 
			
		||||
      hparams: hp.HParams,
 | 
			
		||||
      model_options: model_opt.ObjectDetectorModelOptions,
 | 
			
		||||
  ) -> None:
 | 
			
		||||
    """Initializes ObjectDetector class.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
      model_spec: Specifications for the model.
 | 
			
		||||
      label_names: A list of label names for the classes.
 | 
			
		||||
      hparams: The hyperparameters for training object detector.
 | 
			
		||||
      model_options: Options for creating the object detector model.
 | 
			
		||||
    """
 | 
			
		||||
    super().__init__(
 | 
			
		||||
        model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle
 | 
			
		||||
    )
 | 
			
		||||
    self._preprocessor = preprocessor.Preprocessor(model_spec)
 | 
			
		||||
    self._hparams = hparams
 | 
			
		||||
    self._model_options = model_options
 | 
			
		||||
    self._optimizer = self._create_optimizer()
 | 
			
		||||
    self._is_qat = False
 | 
			
		||||
 | 
			
		||||
  @classmethod
 | 
			
		||||
  def create(
 | 
			
		||||
      cls,
 | 
			
		||||
      train_data: ds.Dataset,
 | 
			
		||||
      validation_data: ds.Dataset,
 | 
			
		||||
      options: object_detector_options.ObjectDetectorOptions,
 | 
			
		||||
  ) -> 'ObjectDetector':
 | 
			
		||||
    """Creates and trains an ObjectDetector.
 | 
			
		||||
 | 
			
		||||
    Loads data and trains the model based on data for object detection.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
      train_data: Training data.
 | 
			
		||||
      validation_data: Validation data.
 | 
			
		||||
      options: Configurations for creating and training object detector.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
      An instance of ObjectDetector.
 | 
			
		||||
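
    Example usage (an illustrative sketch, not part of this change; the
    ObjectDetectorOptions field names are assumed from the options module,
    which is not shown in this diff):
      options = object_detector_options.ObjectDetectorOptions(
          supported_model=ms.SupportedModels.MOBILENET_V2,
          hparams=hp.HParams(export_dir='/tmp/object_detector_export'),
      )
      model = ObjectDetector.create(train_data, validation_data, options)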
    """
 | 
			
		||||
    if options.hparams is None:
 | 
			
		||||
      options.hparams = hp.HParams()
 | 
			
		||||
 | 
			
		||||
    if options.model_options is None:
 | 
			
		||||
      options.model_options = model_opt.ObjectDetectorModelOptions()
 | 
			
		||||
 | 
			
		||||
    spec = ms.SupportedModels.get(options.supported_model)
 | 
			
		||||
    object_detector = cls(
 | 
			
		||||
        model_spec=spec,
 | 
			
		||||
        label_names=train_data.label_names,
 | 
			
		||||
        hparams=options.hparams,
 | 
			
		||||
        model_options=options.model_options,
 | 
			
		||||
    )
 | 
			
		||||
    object_detector._create_and_train_model(train_data, validation_data)
 | 
			
		||||
    return object_detector
 | 
			
		||||
 | 
			
		||||
  def _create_and_train_model(
 | 
			
		||||
      self, train_data: ds.Dataset, validation_data: ds.Dataset
 | 
			
		||||
  ):
 | 
			
		||||
    """Creates and trains the model.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
      train_data: Training data.
 | 
			
		||||
      validation_data: Validation data.
 | 
			
		||||
    """
 | 
			
		||||
    self._create_model()
 | 
			
		||||
    self._train_model(
 | 
			
		||||
        train_data, validation_data, preprocessor=self._preprocessor
 | 
			
		||||
    )
 | 
			
		||||
    self._save_float_ckpt()
 | 
			
		||||
 | 
			
		||||
  def _create_model(self) -> None:
 | 
			
		||||
    """Creates the object detector model."""
 | 
			
		||||
    self._model = model_lib.ObjectDetectorModel(
 | 
			
		||||
        model_spec=self._model_spec,
 | 
			
		||||
        model_options=self._model_options,
 | 
			
		||||
        num_classes=self._num_classes,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
  def _save_float_ckpt(self) -> None:
 | 
			
		||||
    """Saves a checkpoint of the trained float model.
 | 
			
		||||
 | 
			
		||||
    The default save path is {hparams.export_dir}/float_ckpt. Note that
 | 
			
		||||
      `float_ckpt` represents a file prefix, not a directory. The resulting files
 | 
			
		||||
      saved to {hparams.export_dir} will be:
 | 
			
		||||
        - float_ckpt.data-00000-of-00001
 | 
			
		||||
        - float_ckpt.index
 | 
			
		||||
    """
 | 
			
		||||
    save_path = os.path.join(self._hparams.export_dir, 'float_ckpt')
 | 
			
		||||
    if not os.path.exists(self._hparams.export_dir):
 | 
			
		||||
      os.makedirs(self._hparams.export_dir)
 | 
			
		||||
    self._model.save_checkpoint(save_path)
 | 
			
		||||
 | 
			
		||||
  def restore_float_ckpt(self) -> None:
 | 
			
		||||
    """Loads a float checkpoint of the model from {hparams.export_dir}/float_ckpt.
 | 
			
		||||
 | 
			
		||||
    The float checkpoint at {hparams.export_dir}/float_ckpt is automatically
 | 
			
		||||
    saved after training an ObjectDetector using the `create` method. This
 | 
			
		||||
    method is used to restore the trained float checkpoint state of the model in
 | 
			
		||||
    order to run `quantization_aware_training` multiple times. Example usage:
 | 
			
		||||
 | 
			
		||||
    # Train a model
 | 
			
		||||
    model = object_detector.create(...)
 | 
			
		||||
    # Run QAT
 | 
			
		||||
    model.quantization_aware_training(...)
 | 
			
		||||
    model.evaluate(...)
 | 
			
		||||
    # Restore the float checkpoint to run QAT again
 | 
			
		||||
    model.restore_float_ckpt()
 | 
			
		||||
    # Run QAT with different parameters
 | 
			
		||||
    model.quantization_aware_training(...)
 | 
			
		||||
    model.evaluate(...)
 | 
			
		||||
    """
 | 
			
		||||
    self._create_model()
 | 
			
		||||
    self._model.load_checkpoint(
 | 
			
		||||
        os.path.join(self._hparams.export_dir, 'float_ckpt'),
 | 
			
		||||
        include_last_layer=True,
 | 
			
		||||
    )
 | 
			
		||||
    self._model.compile()
 | 
			
		||||
    self._is_qat = False
 | 
			
		||||
 | 
			
		||||
  # TODO: Refactor this method to utilize shared training function
 | 
			
		||||
  def quantization_aware_training(
 | 
			
		||||
      self,
 | 
			
		||||
      train_data: ds.Dataset,
 | 
			
		||||
      validation_data: ds.Dataset,
 | 
			
		||||
      qat_hparams: hp.QATHParams,
 | 
			
		||||
  ) -> None:
 | 
			
		||||
    """Runs quantization aware training(QAT) on the model.
 | 
			
		||||
 | 
			
		||||
    The QAT step happens after training a regular float model from the `create`
 | 
			
		||||
    method. This additional step will fine-tune the model with a lower precision
 | 
			
		||||
    in order to mimic the behavior of a quantized model. The resulting quantized
 | 
			
		||||
    model generally has better performance than a model which is quantized
 | 
			
		||||
    without running QAT. See the following link for more information:
 | 
			
		||||
    - https://www.tensorflow.org/model_optimization/guide/quantization/training
 | 
			
		||||
 | 
			
		||||
    Just like training the float model using the `create` method, the QAT step
 | 
			
		||||
    also requires some manual tuning of hyperparameters. In order to run QAT
 | 
			
		||||
    more than once for purposes such as hyperparameter tuning, use the
 | 
			
		||||
    `restore_float_ckpt` method to restore the model state to the trained float
 | 
			
		||||
    checkpoint without having to rerun the `create` method.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
      train_data: Training dataset.
 | 
			
		||||
      validation_data: Validation dataset.
 | 
			
		||||
      qat_hparams: Configuration for QAT.
 | 
			
		||||
    """
 | 
			
		||||
    self._model.convert_to_qat()
 | 
			
		||||
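    # Scale the base QAT learning rate linearly with the batch size (relative
    # to a reference batch size of 256) and decay it with a staircase
    # exponential schedule.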
    learning_rate_fn = tf.keras.optimizers.schedules.ExponentialDecay(
 | 
			
		||||
        qat_hparams.learning_rate * qat_hparams.batch_size / 256,
 | 
			
		||||
        decay_steps=qat_hparams.decay_steps,
 | 
			
		||||
        decay_rate=qat_hparams.decay_rate,
 | 
			
		||||
        staircase=True,
 | 
			
		||||
    )
 | 
			
		||||
    optimizer = tf.keras.optimizers.experimental.SGD(
 | 
			
		||||
        learning_rate=learning_rate_fn, momentum=0.9
 | 
			
		||||
    )
 | 
			
		||||
    if len(train_data) < qat_hparams.batch_size:
 | 
			
		||||
      raise ValueError(
 | 
			
		||||
          f"The size of the train_data {len(train_data)} can't be smaller than"
 | 
			
		||||
          f' batch_size {qat_hparams.batch_size}. To solve this problem, set'
 | 
			
		||||
          ' the batch_size smaller or increase the size of the train_data.'
 | 
			
		||||
      )
 | 
			
		||||
 | 
			
		||||
    train_dataset = train_data.gen_tf_dataset(
 | 
			
		||||
        batch_size=qat_hparams.batch_size,
 | 
			
		||||
        is_training=True,
 | 
			
		||||
        shuffle=self._shuffle,
 | 
			
		||||
        preprocess=self._preprocessor,
 | 
			
		||||
    )
 | 
			
		||||
    steps_per_epoch = model_util.get_steps_per_epoch(
 | 
			
		||||
        steps_per_epoch=None,
 | 
			
		||||
        batch_size=qat_hparams.batch_size,
 | 
			
		||||
        train_data=train_data,
 | 
			
		||||
    )
 | 
			
		||||
    train_dataset = train_dataset.take(count=steps_per_epoch)
 | 
			
		||||
    validation_dataset = validation_data.gen_tf_dataset(
 | 
			
		||||
        batch_size=qat_hparams.batch_size,
 | 
			
		||||
        is_training=False,
 | 
			
		||||
        preprocess=self._preprocessor,
 | 
			
		||||
    )
 | 
			
		||||
    self._model.compile(optimizer=optimizer)
 | 
			
		||||
    self._model.fit(
 | 
			
		||||
        x=train_dataset,
 | 
			
		||||
        epochs=qat_hparams.epochs,
 | 
			
		||||
        steps_per_epoch=None,
 | 
			
		||||
        validation_data=validation_dataset,
 | 
			
		||||
    )
 | 
			
		||||
    self._is_qat = True
 | 
			
		||||
 | 
			
		||||
  def evaluate(
      self, dataset: ds.Dataset, batch_size: int = 1
  ) -> Tuple[List[float], Dict[str, float]]:
    """Overrides Classifier.evaluate to calculate COCO metrics."""
    dataset = dataset.gen_tf_dataset(
        batch_size, is_training=False, preprocess=self._preprocessor
    )
    losses = self._model.evaluate(dataset)
    coco_eval = coco_evaluator.COCOEvaluator(
        annotation_file=None,
        include_mask=False,
        per_category_metrics=True,
        max_num_eval_detections=100,
    )
    for batch in dataset:
      x, y = batch
      y_pred = self._model(
          x,
          anchor_boxes=y['anchor_boxes'],
          image_shape=y['image_info'][:, 1, :],
          training=False,
      )
      groundtruths = y['groundtruths']
      y_pred['image_info'] = groundtruths['image_info']
      y_pred['source_id'] = groundtruths['source_id']
      coco_eval.update_state(groundtruths, y_pred)
    coco_metrics = coco_eval.result()
    return losses, coco_metrics

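A minimal sketch of consuming the evaluation output, assuming `model` and `test_data` have been created as in the demo script later in this change; the COCO metrics dict is keyed by names such as 'AP':

# `model` and `test_data` as created in the demo script later in this change.
losses, coco_metrics = model.evaluate(test_data, batch_size=1)
print('loss terms:', losses)
print('mean average precision (AP):', coco_metrics['AP'])
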
  def export_model(
      self,
      model_name: str = 'model.tflite',
      quantization_config: Optional[quantization.QuantizationConfig] = None,
  ):
    """Converts and saves the model to a TFLite file with metadata included.

    The model export format is automatically set based on whether or not
    `quantization_aware_training` (QAT) was run. The model exports to float32
    by default and will export to an int8 quantized model if QAT was run. To
    export a float32 model after running QAT, run `restore_float_ckpt` before
    this method. For custom post-training quantization without QAT, use the
    quantization_config parameter.

    Note that only the TFLite file is needed for deployment. This function also
    saves a metadata.json file to the same directory as the TFLite file which
    can be used to interpret the metadata content in the TFLite file.

    Args:
      model_name: File name to save TFLite model with metadata. The full export
        path is {self._hparams.export_dir}/{model_name}.
      quantization_config: The configuration for model quantization. Note that
        int8 quantization aware training is automatically applied when
        possible. This parameter is used to specify other post-training
        quantization options such as fp16 and int8 without QAT.

    Raises:
      ValueError: If a custom quantization_config is specified when the model
        has quantization aware training enabled.
    """
    if quantization_config:
      if self._is_qat:
        raise ValueError(
            'Exporting a qat model with a custom quantization_config is not '
            'supported.'
        )
      else:
        print(
            'Exporting with custom post-training-quantization: ',
            quantization_config,
        )
    else:
      if self._is_qat:
        print('Exporting a qat int8 model')
        quantization_config = quantization.QuantizationConfig(
            inference_input_type=tf.uint8, inference_output_type=tf.uint8
        )
      else:
        print('Exporting a floating point model')

    tflite_file = os.path.join(self._hparams.export_dir, model_name)
    metadata_file = os.path.join(self._hparams.export_dir, 'metadata.json')
    with tempfile.TemporaryDirectory() as temp_dir:
      save_path = os.path.join(temp_dir, 'saved_model')
      self._model.export_saved_model(save_path)
      converter = tf.lite.TFLiteConverter.from_saved_model(save_path)
      if quantization_config:
        converter = quantization_config.set_converter_with_quantization(
            converter, preprocess=self._preprocessor
        )

      converter.target_spec.supported_ops = (tf.lite.OpsSet.TFLITE_BUILTINS,)
      tflite_model = converter.convert()

    writer = object_detector_writer.MetadataWriter.create(
        tflite_model,
        self._model_spec.mean_rgb,
        self._model_spec.stddev_rgb,
        labels=metadata_writer.Labels().add(list(self._label_names)),
    )
    tflite_model_with_metadata, metadata_json = writer.populate()
    model_util.save_tflite(tflite_model_with_metadata, tflite_file)
    with open(metadata_file, 'w') as f:
      f.write(metadata_json)

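A usage sketch of the three export paths described in the docstring, assuming `model`, `train_data`, `validation_data`, and `qat_hparams` exist as in the demo script later in this change; `QuantizationConfig.for_float16()` is assumed to be available on the core quantization utilities:

from mediapipe.model_maker.python.core.utils import quantization

# 1. Default export: a float32 TFLite model plus metadata.json written to
#    hparams.export_dir.
model.export_model('model_fp32.tflite')

# 2. After quantization_aware_training(), the same call writes an int8 model.
model.quantization_aware_training(train_data, validation_data, qat_hparams)
model.export_model('model_int8_qat.tflite')

# 3. Custom post-training quantization is only valid when QAT is not active,
#    so restore the float checkpoint first; for_float16() is an assumed helper
#    on QuantizationConfig.
model.restore_float_ckpt()
config = quantization.QuantizationConfig.for_float16()
model.export_model('model_fp16.tflite', quantization_config=config)
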
  def _create_optimizer(self) -> tf.keras.optimizers.Optimizer:
    """Creates an optimizer with learning rate schedule for regular training.

    Uses Keras PiecewiseConstantDecay schedule by default.

    Returns:
      A tf.keras.optimizers.Optimizer for model training.
    """
    init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
    lr_values = [init_lr] + [
        init_lr * m for m in self._hparams.learning_rate_decay_multipliers
    ]
    learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        self._hparams.learning_rate_boundaries, lr_values
    )
    return tf.keras.optimizers.experimental.SGD(
        learning_rate=learning_rate_fn, momentum=0.9
    )
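For reference, a self-contained sketch of how this piecewise schedule behaves, with hypothetical hyperparameter values (not the library defaults):

import tensorflow as tf

# Hypothetical values: learning_rate=0.3, batch_size=8,
# learning_rate_boundaries=[240, 270],
# learning_rate_decay_multipliers=[0.1, 0.01].
init_lr = 0.3 * 8 / 256  # 0.009375
lr_values = [init_lr, init_lr * 0.1, init_lr * 0.01]
schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    [240, 270], lr_values
)
for step in (0, 240, 241, 271):
  print(step, float(schedule(step)))
# Steps up to and including 240 use init_lr, steps 241-270 use init_lr * 0.1,
# and later steps use init_lr * 0.01.
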
@ -0,0 +1,84 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Demo for making an object detector model by MediaPipe Model Maker."""

import os
# Dependency imports

from absl import app
from absl import flags
from absl import logging

from mediapipe.model_maker.python.vision import object_detector

FLAGS = flags.FLAGS

TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/object_detector/testdata/coco_data'


def define_flags() -> None:
  """Define flags for the object detection model maker demo."""
  flags.DEFINE_string(
      'export_dir', None, 'The directory to save exported files.'
  )
  flags.DEFINE_string(
      'input_data_dir',
      None,
      """The directory with input training data. If the training data is not
      specified, the pipeline will use the test dataset.""",
  )
  flags.DEFINE_bool('qat', True, 'Whether or not to do QAT.')
  flags.mark_flag_as_required('export_dir')


def run(data_dir: str, export_dir: str, qat: bool):
  """Runs demo."""
  data = object_detector.Dataset.from_coco_folder(data_dir)
  train_data, rest_data = data.split(0.6)
  validation_data, test_data = rest_data.split(0.5)

  hparams = object_detector.HParams(batch_size=1, export_dir=export_dir)
  options = object_detector.ObjectDetectorOptions(
      supported_model=object_detector.SupportedModels.MOBILENET_V2,
      hparams=hparams,
  )
  model = object_detector.ObjectDetector.create(
      train_data=train_data, validation_data=validation_data, options=options
  )
  loss, coco_metrics = model.evaluate(test_data, batch_size=1)
  print(f'Evaluation loss:{loss}, coco_metrics:{coco_metrics}')
  if qat:
    qat_hparams = object_detector.QATHParams(batch_size=1)
    model.quantization_aware_training(train_data, validation_data, qat_hparams)
    qat_loss, qat_coco_metrics = model.evaluate(test_data, batch_size=1)
    print(f'QAT Evaluation loss:{qat_loss}, coco_metrics:{qat_coco_metrics}')

  model.export_model()


def main(_) -> None:
  logging.set_verbosity(logging.INFO)

  if FLAGS.input_data_dir is None:
    data_dir = os.path.join(FLAGS.test_srcdir, TEST_DATA_DIR)
  else:
    data_dir = FLAGS.input_data_dir

  export_dir = os.path.expanduser(FLAGS.export_dir)
  run(data_dir=data_dir, export_dir=export_dir, qat=FLAGS.qat)


if __name__ == '__main__':
  define_flags()
  app.run(main)
@ -0,0 +1,36 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Options for building object detector."""

import dataclasses
from typing import Optional

from mediapipe.model_maker.python.vision.object_detector import hyperparameters
from mediapipe.model_maker.python.vision.object_detector import model_options as model_opt
from mediapipe.model_maker.python.vision.object_detector import model_spec


@dataclasses.dataclass
class ObjectDetectorOptions:
  """Configurable options for building object detector.

  Attributes:
    supported_model: A model from the SupportedModels enum.
    model_options: A set of options for configuring the selected model.
    hparams: A set of hyperparameters used to train the object detector.
  """

  supported_model: model_spec.SupportedModels
  model_options: Optional[model_opt.ObjectDetectorModelOptions] = None
  hparams: Optional[hyperparameters.HParams] = None
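A minimal construction sketch for these options, using the package-level aliases shown in the demo script above; only the fields listed in the dataclass are assumed:

from mediapipe.model_maker.python.vision import object_detector

# Only supported_model is required; model_options and hparams fall back to the
# library defaults when left as None.
options = object_detector.ObjectDetectorOptions(
    supported_model=object_detector.SupportedModels.MOBILENET_V2,
)

# Hyperparameters can also be set explicitly, as in the demo script above.
options_with_hparams = object_detector.ObjectDetectorOptions(
    supported_model=object_detector.SupportedModels.MOBILENET_V2,
    hparams=object_detector.HParams(batch_size=8, export_dir='/tmp/od_export'),
)
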
@ -0,0 +1,121 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
from unittest import mock as unittest_mock

from absl.testing import parameterized
import tensorflow as tf

from mediapipe.model_maker.python.vision.object_detector import dataset
from mediapipe.model_maker.python.vision.object_detector import hyperparameters
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
from mediapipe.model_maker.python.vision.object_detector import object_detector
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
from mediapipe.tasks.python.test import test_utils as task_test_utils


class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    dataset_folder = task_test_utils.get_test_data_path('coco_data')
    cache_dir = self.create_tempdir()
    self.data = dataset.Dataset.from_coco_folder(
        dataset_folder, cache_dir=cache_dir
    )
    # Mock tempfile.gettempdir() to be unique for each test to avoid race
    # condition when downloading model since these tests may run in parallel.
    mock_gettempdir = unittest_mock.patch.object(
        tempfile,
        'gettempdir',
        return_value=self.create_tempdir(),
        autospec=True,
    )
    self.mock_gettempdir = mock_gettempdir.start()
    self.addCleanup(mock_gettempdir.stop)

  def test_object_detector(self):
    hparams = hyperparameters.HParams(
        epochs=10,
        batch_size=2,
        learning_rate=0.9,
        shuffle=False,
        export_dir=self.create_tempdir(),
    )
    options = object_detector_options.ObjectDetectorOptions(
        supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams
    )
    # Test `create`
    model = object_detector.ObjectDetector.create(
        train_data=self.data, validation_data=self.data, options=options
    )
    losses, coco_metrics = model.evaluate(self.data)
    self._assert_ap_greater(coco_metrics)
    self.assertFalse(model._is_qat)
    # Test float export_model
    model.export_model()
    output_metadata_file = os.path.join(
        options.hparams.export_dir, 'metadata.json'
    )
    output_tflite_file = os.path.join(
        options.hparams.export_dir, 'model.tflite'
    )
    print('ASDF float', os.path.getsize(output_tflite_file))
    self.assertTrue(os.path.exists(output_tflite_file))
    self.assertGreater(os.path.getsize(output_tflite_file), 0)
    self.assertTrue(os.path.exists(output_metadata_file))
    self.assertGreater(os.path.getsize(output_metadata_file), 0)

    # Test `quantization_aware_training`
    qat_hparams = hyperparameters.QATHParams(
        learning_rate=0.9,
        batch_size=2,
        epochs=5,
        decay_steps=6,
        decay_rate=0.96,
    )
    model.quantization_aware_training(self.data, self.data, qat_hparams)
    qat_losses, qat_coco_metrics = model.evaluate(self.data)
    self._assert_ap_greater(qat_coco_metrics)
    self.assertNotAllEqual(losses, qat_losses)
    self.assertTrue(model._is_qat)
    model.export_model('model_qat.tflite')
    output_metadata_file = os.path.join(
        options.hparams.export_dir, 'metadata.json'
    )
    output_tflite_file = os.path.join(
        options.hparams.export_dir, 'model_qat.tflite'
    )
    print('ASDF qat', os.path.getsize(output_tflite_file))
    self.assertTrue(os.path.exists(output_tflite_file))
    self.assertGreater(os.path.getsize(output_tflite_file), 0)
    self.assertLess(os.path.getsize(output_tflite_file), 3500000)
    self.assertTrue(os.path.exists(output_metadata_file))
    self.assertGreater(os.path.getsize(output_metadata_file), 0)

    # Load float ckpt test
    model.restore_float_ckpt()
    losses_2, _ = model.evaluate(self.data)
    self.assertAllEqual(losses, losses_2)
    self.assertNotAllEqual(qat_losses, losses_2)
    self.assertFalse(model._is_qat)

  def _assert_ap_greater(self, coco_metrics, threshold=0.0):
    self.assertGreaterEqual(coco_metrics['AP'], threshold)


if __name__ == '__main__':
  tf.test.main()
@ -0,0 +1,163 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preprocessor for object detector."""
from typing import Any, Mapping, Tuple

import tensorflow as tf

from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
from official.vision.dataloaders import utils
from official.vision.ops import anchor
from official.vision.ops import box_ops
from official.vision.ops import preprocess_ops


# TODO Combine preprocessing logic with image_preprocessor.
class Preprocessor(object):
  """Preprocessor for object detector."""

  def __init__(self, model_spec: ms.ModelSpec):
    """Initialize a Preprocessor."""
    self._mean_norm = model_spec.mean_norm
    self._stddev_norm = model_spec.stddev_norm
    self._output_size = model_spec.input_image_shape[:2]
    self._min_level = 3
    self._max_level = 7
    self._num_scales = 3
    self._aspect_ratios = [0.5, 1, 2]
    self._anchor_size = 3
    self._dtype = tf.float32
    self._match_threshold = 0.5
    self._unmatched_threshold = 0.5
    self._aug_scale_min = 0.5
    self._aug_scale_max = 2.0
    self._max_num_instances = 100

  def __call__(
      self, data: Mapping[str, Any], is_training: bool = True
  ) -> Tuple[tf.Tensor, Mapping[str, Any]]:
    """Run the preprocessor on an example.

    The data dict should contain the following keys always:
      - image
      - groundtruth_classes
      - groundtruth_boxes
      - groundtruth_is_crowd
    Additional keys needed when is_training is set to True:
      - groundtruth_area
      - source_id
      - height
      - width

    Args:
      data: A dict of object detector inputs.
      is_training: Whether or not the data is used for training.

    Returns:
      A tuple of (image, labels) where image is a Tensor and labels is a dict.
    """
    classes = data['groundtruth_classes']
    boxes = data['groundtruth_boxes']

    # Get original image.
    image = data['image']
    image_shape = tf.shape(input=image)[0:2]

    # Normalize image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, self._mean_norm, self._stddev_norm
    )

    # Flip image randomly during training.
    if is_training:
      image, boxes, _ = preprocess_ops.random_horizontal_flip(image, boxes)

    # Convert boxes from normalized coordinates to pixel coordinates.
    boxes = box_ops.denormalize_boxes(boxes, image_shape)

    # Resize and crop image.
    image, image_info = preprocess_ops.resize_and_crop_image(
        image,
        self._output_size,
        padded_size=preprocess_ops.compute_padded_size(
            self._output_size, 2**self._max_level
        ),
        aug_scale_min=(self._aug_scale_min if is_training else 1.0),
        aug_scale_max=(self._aug_scale_max if is_training else 1.0),
    )
    image_height, image_width, _ = image.get_shape().as_list()

    # Resize and crop boxes.
    image_scale = image_info[2, :]
    offset = image_info[3, :]
    boxes = preprocess_ops.resize_and_crop_boxes(
        boxes, image_scale, image_info[1, :], offset
    )
    # Filter out ground-truth boxes that are all zeros.
    indices = box_ops.get_non_empty_box_indices(boxes)
    boxes = tf.gather(boxes, indices)
    classes = tf.gather(classes, indices)

    # Assign anchors.
    input_anchor = anchor.build_anchor_generator(
        min_level=self._min_level,
        max_level=self._max_level,
        num_scales=self._num_scales,
        aspect_ratios=self._aspect_ratios,
        anchor_size=self._anchor_size,
    )
    anchor_boxes = input_anchor(image_size=(image_height, image_width))
    anchor_labeler = anchor.AnchorLabeler(
        self._match_threshold, self._unmatched_threshold
    )
    (cls_targets, box_targets, _, cls_weights, box_weights) = (
        anchor_labeler.label_anchors(
            anchor_boxes, boxes, tf.expand_dims(classes, axis=1)
        )
    )

    # Cast input image to desired data type.
    image = tf.cast(image, dtype=self._dtype)

    # Pack labels for model_fn outputs.
    labels = {
        'cls_targets': cls_targets,
        'box_targets': box_targets,
        'anchor_boxes': anchor_boxes,
        'cls_weights': cls_weights,
        'box_weights': box_weights,
        'image_info': image_info,
    }
    if not is_training:
      groundtruths = {
          'source_id': data['source_id'],
          'height': data['height'],
          'width': data['width'],
          'num_detections': tf.shape(data['groundtruth_classes']),
          'image_info': image_info,
          'boxes': box_ops.denormalize_boxes(
              data['groundtruth_boxes'], image_shape
          ),
          'classes': data['groundtruth_classes'],
          'areas': data['groundtruth_area'],
          'is_crowds': tf.cast(data['groundtruth_is_crowd'], tf.int32),
      }
      groundtruths['source_id'] = utils.process_source_id(
          groundtruths['source_id']
      )
      groundtruths = utils.pad_groundtruths_to_fixed_size(
          groundtruths, self._max_num_instances
      )
      labels.update({'groundtruths': groundtruths})
    return image, labels
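A minimal wiring sketch, assuming `train_data` is an `object_detector.Dataset`; it uses the same `gen_tf_dataset(..., preprocess=...)` hook that `ObjectDetector` itself uses earlier in this change:

from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
from mediapipe.model_maker.python.vision.object_detector import preprocessor as preprocessor_lib

# Build the preprocessor from the selected model spec (MobileNetV2 here).
spec = ms.SupportedModels.MOBILENET_V2.value()
preprocessor = preprocessor_lib.Preprocessor(spec)

# `train_data` is assumed to be an object_detector.Dataset; each raw example
# dict is mapped to an (image, labels) pair with anchor targets attached.
tf_dataset = train_data.gen_tf_dataset(
    batch_size=8,
    is_training=True,
    shuffle=True,
    preprocess=preprocessor,
)
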
@ -0,0 +1,158 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

from absl.testing import parameterized
import tensorflow as tf

from mediapipe.model_maker.python.vision.core import test_utils
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
from mediapipe.model_maker.python.vision.object_detector import preprocessor as preprocessor_lib


class DatasetTest(tf.test.TestCase, parameterized.TestCase):
  MAX_IMAGE_SIZE = 360
  OUTPUT_SIZE = 256
  NUM_CLASSES = 10
  NUM_EXAMPLES = 3
  MIN_LEVEL = 3
  MAX_LEVEL = 7
  NUM_SCALES = 3
  ASPECT_RATIOS = [0.5, 1, 2]
  MAX_NUM_INSTANCES = 100

  def _get_rand_example(self):
    num_annotations = random.randint(1, 3)
    bboxes, classes, is_crowds = [], [], []
    image_size = random.randint(10, self.MAX_IMAGE_SIZE + 1)
    rgb = [random.uniform(0, 255) for _ in range(3)]
    image = test_utils.fill_image(rgb, image_size)
    for _ in range(num_annotations):
      x1, x2 = random.uniform(0, image_size), random.uniform(0, image_size)
      y1, y2 = random.uniform(0, image_size), random.uniform(0, image_size)
      bbox = [min(x1, x2), min(y1, y2), abs(x1 - x2), abs(y1 - y2)]
      bboxes.append(bbox)
      classes.append(random.randint(0, self.NUM_CLASSES - 1))
      is_crowds.append(0)
    return {
        'image': tf.cast(image, dtype=tf.float32),
        'groundtruth_boxes': tf.cast(bboxes, dtype=tf.float32),
        'groundtruth_classes': tf.cast(classes, dtype=tf.int64),
        'groundtruth_is_crowd': tf.cast(is_crowds, dtype=tf.bool),
        'groundtruth_area': tf.cast(is_crowds, dtype=tf.float32),
        'source_id': tf.cast(1, dtype=tf.int64),
        'height': tf.cast(image_size, dtype=tf.int64),
        'width': tf.cast(image_size, dtype=tf.int64),
    }

  def setUp(self):
    super().setUp()
    dataset = [self._get_rand_example() for _ in range(self.NUM_EXAMPLES)]

    def my_generator(data):
      for item in data:
        yield item

    self.dataset = tf.data.Dataset.from_generator(
        lambda: my_generator(dataset),
        output_types={
            'image': tf.float32,
            'groundtruth_classes': tf.int64,
            'groundtruth_boxes': tf.float32,
            'groundtruth_is_crowd': tf.bool,
            'groundtruth_area': tf.float32,
            'source_id': tf.int64,
            'height': tf.int64,
            'width': tf.int64,
        },
    )

  @parameterized.named_parameters(
      dict(
          testcase_name='training',
          is_training=True,
      ),
      dict(
          testcase_name='evaluation',
          is_training=False,
      ),
  )
  def test_preprocessor(self, is_training):
    model_spec = ms.SupportedModels.MOBILENET_V2.value()
    labels_keys = [
        'cls_targets',
        'box_targets',
        'anchor_boxes',
        'cls_weights',
        'box_weights',
        'image_info',
    ]
    if not is_training:
      labels_keys.append('groundtruths')
    preprocessor = preprocessor_lib.Preprocessor(model_spec)
    for example in self.dataset:
      result = preprocessor(example, is_training=is_training)
      image, labels = result
      self.assertAllEqual(image.shape, (256, 256, 3))
      self.assertCountEqual(labels_keys, labels.keys())
      np_labels = tf.nest.map_structure(lambda x: x.numpy(), labels)
      # Checks shapes of `image_info` and `anchor_boxes`.
      self.assertEqual(np_labels['image_info'].shape, (4, 2))
      n_anchors = 0
      for level in range(self.MIN_LEVEL, self.MAX_LEVEL + 1):
        stride = 2**level
        output_size_l = [self.OUTPUT_SIZE / stride, self.OUTPUT_SIZE / stride]
        anchors_per_location = self.NUM_SCALES * len(self.ASPECT_RATIOS)
        self.assertEqual(
            list(np_labels['anchor_boxes'][str(level)].shape),
            [output_size_l[0], output_size_l[1], 4 * anchors_per_location],
        )
        n_anchors += output_size_l[0] * output_size_l[1] * anchors_per_location
      # Checks shapes of training objectives.
      self.assertEqual(np_labels['cls_weights'].shape, (int(n_anchors),))
      for level in range(self.MIN_LEVEL, self.MAX_LEVEL + 1):
        stride = 2**level
        output_size_l = [self.OUTPUT_SIZE / stride, self.OUTPUT_SIZE / stride]
        anchors_per_location = self.NUM_SCALES * len(self.ASPECT_RATIOS)
        self.assertEqual(
            list(np_labels['cls_targets'][str(level)].shape),
            [output_size_l[0], output_size_l[1], anchors_per_location],
        )
        self.assertEqual(
            list(np_labels['box_targets'][str(level)].shape),
            [output_size_l[0], output_size_l[1], 4 * anchors_per_location],
        )
      # Checks shape of groundtruths for eval.
      if not is_training:
        self.assertEqual(np_labels['groundtruths']['source_id'].shape, ())
        self.assertEqual(
            np_labels['groundtruths']['classes'].shape,
            (self.MAX_NUM_INSTANCES,),
        )
        self.assertEqual(
            np_labels['groundtruths']['boxes'].shape,
            (self.MAX_NUM_INSTANCES, 4),
        )
        self.assertEqual(
            np_labels['groundtruths']['areas'].shape, (self.MAX_NUM_INSTANCES,)
        )
        self.assertEqual(
            np_labels['groundtruths']['is_crowds'].shape,
            (self.MAX_NUM_INSTANCES,),
        )


if __name__ == '__main__':
  tf.test.main()
BIN  mediapipe/model_maker/python/vision/object_detector/testdata/coco_data/images/000000000072.jpg  (new file, 81 KiB)
BIN  mediapipe/model_maker/python/vision/object_detector/testdata/coco_data/images/000000000078.jpg  (new file, 72 KiB)
BIN  mediapipe/model_maker/python/vision/object_detector/testdata/coco_data/images/000000000315.jpg  (new file, 53 KiB)
BIN  mediapipe/model_maker/python/vision/object_detector/testdata/coco_data/images/000000000431.jpg  (new file, 42 KiB)
BIN  mediapipe/model_maker/python/vision/object_detector/testdata/coco_data/images/000000000446.jpg  (new file, 36 KiB)
mediapipe/model_maker/python/vision/object_detector/testdata/coco_data/labels.json  (new file, 1 line)
@ -0,0 +1,34 @@
<?xml version="1.0" encoding="utf-8"?>
<annotation>
<folder>images</folder>
<filename>37ca2a3d-IMG_0520.jpg</filename>
<source>
<database>MyDatabase</database>
<annotation>COCO2017</annotation>
<image>flickr</image>
<flickrid>NULL</flickrid>
<annotator>1</annotator>
</source>
<owner>
<flickrid>NULL</flickrid>
<name>Label Studio</name>
</owner>
<size>
<width>800</width>
<height>600</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<object>
<name>pig_android</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>242</xmin>
<ymin>17</ymin>
<xmax>556</xmax>
<ymax>476</ymax>
</bndbox>
</object>
</annotation>
@ -0,0 +1,34 @@
<?xml version="1.0" encoding="utf-8"?>
<annotation>
<folder>images</folder>
<filename>3d3382d3-IMG_0514.jpg</filename>
<source>
<database>MyDatabase</database>
<annotation>COCO2017</annotation>
<image>flickr</image>
<flickrid>NULL</flickrid>
<annotator>1</annotator>
</source>
<owner>
<flickrid>NULL</flickrid>
<name>Label Studio</name>
</owner>
<size>
<width>800</width>
<height>600</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<object>
<name>android</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>306</xmin>
<ymin>130</ymin>
<xmax>550</xmax>
<ymax>471</ymax>
</bndbox>
</object>
</annotation>
@ -0,0 +1,46 @@
<?xml version="1.0" encoding="utf-8"?>
<annotation>
<folder>images</folder>
<filename>d1c65813-IMG_0546.jpg</filename>
<source>
<database>MyDatabase</database>
<annotation>COCO2017</annotation>
<image>flickr</image>
<flickrid>NULL</flickrid>
<annotator>1</annotator>
</source>
<owner>
<flickrid>NULL</flickrid>
<name>Label Studio</name>
</owner>
<size>
<width>800</width>
<height>600</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<object>
<name>android</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>93</xmin>
<ymin>101</ymin>
<xmax>358</xmax>
<ymax>378</ymax>
</bndbox>
</object>
<object>
<name>pig_android</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>438</xmin>
<ymin>28</ymin>
<xmax>654</xmax>
<ymax>296</ymax>
</bndbox>
</object>
</annotation>
@ -0,0 +1,46 @@
<?xml version="1.0" encoding="utf-8"?>
<annotation>
  <folder>images</folder>
  <filename>d86b20e0-IMG_0509.jpg</filename>
  <source>
    <database>MyDatabase</database>
    <annotation>COCO2017</annotation>
    <image>flickr</image>
    <flickrid>NULL</flickrid>
    <annotator>1</annotator>
  </source>
  <owner>
    <flickrid>NULL</flickrid>
    <name>Label Studio</name>
  </owner>
  <size>
    <width>800</width>
    <height>600</height>
    <depth>3</depth>
  </size>
  <segmented>0</segmented>
  <object>
    <name>android</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
    <xmin>7</xmin>
    <ymin>122</ymin>
    <xmax>296</xmax>
    <ymax>402</ymax>
    </bndbox>
  </object>
  <object>
    <name>pig_android</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
    <xmin>523</xmin>
    <ymin>69</ymin>
    <xmax>723</xmax>
    <ymax>329</ymax>
    </bndbox>
  </object>
</annotation>
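The annotation files above are PASCAL VOC style XML. A loading sketch, assuming the dataset module exposes a `from_pascal_voc_folder` constructor analogous to the `from_coco_folder` call used in the demo script; the constructor name and the folder path below are assumptions, not shown in this excerpt:

from mediapipe.model_maker.python.vision import object_detector

# Assumed layout: a folder containing images/ and Annotations/ subdirectories,
# matching the <folder>images</folder> entries above.
voc_data = object_detector.Dataset.from_pascal_voc_folder(
    'path/to/pascal_voc_style_testdata'
)
print('examples:', len(voc_data))
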
		 After Width: | Height: | Size: 54 KiB  | 
| 
		 After Width: | Height: | Size: 73 KiB  | 
| 
		 After Width: | Height: | Size: 110 KiB  | 
| 
		 After Width: | Height: | Size: 120 KiB  |