Open Source Object Detector

PiperOrigin-RevId: 519201221

@@ -17,6 +17,7 @@ from mediapipe.model_maker.python.core.utils import quantization
 from mediapipe.model_maker.python.vision import image_classifier
 from mediapipe.model_maker.python.vision import gesture_recognizer
 from mediapipe.model_maker.python.text import text_classifier
+from mediapipe.model_maker.python.vision import object_detector

 # Remove duplicated and non-public API
 del python
mediapipe/model_maker/python/vision/object_detector/BUILD (new file)

@@ -0,0 +1,195 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict test compatibility macro.

licenses(["notice"])

package(
    default_visibility = ["//mediapipe:__subpackages__"],
)

py_library(
    name = "object_detector_import",
    srcs = ["__init__.py"],
    visibility = ["//visibility:public"],
    deps = [
        ":dataset",
        ":hyperparameters",
        ":model_options",
        ":model_spec",
        ":object_detector",
        ":object_detector_options",
    ],
)

py_binary(
    name = "object_detector_demo",
    srcs = ["object_detector_demo.py"],
    data = [":testdata"],
    python_version = "PY3",
    tags = ["requires-net:external"],
    deps = [":object_detector_import"],
)

filegroup(
    name = "testdata",
    srcs = glob([
        "testdata/**",
    ]),
)

py_library(
    name = "dataset",
    srcs = ["dataset.py"],
    deps = [
        ":dataset_util",
        "//mediapipe/model_maker/python/core/data:classification_dataset",
    ],
)

py_test(
    name = "dataset_test",
    srcs = ["dataset_test.py"],
    data = [":testdata"],
    deps = [
        ":dataset",
        "//mediapipe/model_maker/python/vision/core:image_utils",
        "//mediapipe/model_maker/python/vision/core:test_utils",
        "//mediapipe/tasks/python/test:test_utils",
    ],
)

py_library(
    name = "dataset_util",
    srcs = ["dataset_util.py"],
)

py_test(
    name = "dataset_util_test",
    srcs = ["dataset_util_test.py"],
    data = [":testdata"],
    deps = [
        ":dataset_util",
        "//mediapipe/model_maker/python/vision/core:test_utils",
        "//mediapipe/tasks/python/test:test_utils",
    ],
)

py_library(
    name = "hyperparameters",
    srcs = ["hyperparameters.py"],
    deps = [
        "//mediapipe/model_maker/python/core:hyperparameters",
    ],
)

py_library(
    name = "preprocessor",
    srcs = ["preprocessor.py"],
    deps = [":model_spec"],
)

py_test(
    name = "preprocessor_test",
    srcs = ["preprocessor_test.py"],
    deps = [
        ":model_spec",
        ":preprocessor",
        "//mediapipe/model_maker/python/vision/core:test_utils",
    ],
)

py_library(
    name = "model",
    srcs = ["model.py"],
    deps = [
        ":model_options",
        ":model_spec",
    ],
)

py_test(
    name = "model_test",
    size = "large",
    srcs = ["model_test.py"],
    data = [":testdata"],
    shard_count = 4,
    tags = ["requires-net:external"],
    deps = [
        ":dataset",
        ":model",
        ":model_options",
        ":model_spec",
        ":preprocessor",
        "//mediapipe/tasks/python/test:test_utils",
    ],
)

py_library(
    name = "model_options",
    srcs = ["model_options.py"],
)

py_library(
    name = "model_spec",
    srcs = ["model_spec.py"],
    deps = ["//mediapipe/model_maker/python/core/utils:file_util"],
)

py_library(
    name = "object_detector",
    srcs = ["object_detector.py"],
    deps = [
        ":dataset",
        ":hyperparameters",
        ":model",
        ":model_options",
        ":model_spec",
        ":object_detector_options",
        ":preprocessor",
        "//mediapipe/model_maker/python/core/tasks:classifier",
        "//mediapipe/model_maker/python/core/utils:model_util",
        "//mediapipe/model_maker/python/core/utils:quantization",
        "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
        "//mediapipe/tasks/python/metadata/metadata_writers:object_detector",
    ],
)

py_test(
    name = "object_detector_test",
    size = "enormous",
    srcs = ["object_detector_test.py"],
    data = [":testdata"],
    tags = ["requires-net:external"],
    deps = [
        ":dataset",
        ":hyperparameters",
        ":model_spec",
        ":object_detector",
        ":object_detector_options",
        "//mediapipe/tasks/python/test:test_utils",
    ],
)

py_library(
    name = "object_detector_options",
    srcs = ["object_detector_options.py"],
    deps = [
        ":hyperparameters",
        ":model_options",
        ":model_spec",
    ],
)
mediapipe/model_maker/python/vision/object_detector/__init__.py (new file)

@@ -0,0 +1,30 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe Model Maker Python Public API For Object Detector."""

from mediapipe.model_maker.python.vision.object_detector import dataset
from mediapipe.model_maker.python.vision.object_detector import hyperparameters
from mediapipe.model_maker.python.vision.object_detector import model_options
from mediapipe.model_maker.python.vision.object_detector import model_spec
from mediapipe.model_maker.python.vision.object_detector import object_detector
from mediapipe.model_maker.python.vision.object_detector import object_detector_options

ObjectDetector = object_detector.ObjectDetector
ModelOptions = model_options.ObjectDetectorModelOptions
ModelSpec = model_spec.ModelSpec
SupportedModels = model_spec.SupportedModels
HParams = hyperparameters.HParams
QATHParams = hyperparameters.QATHParams
Dataset = dataset.Dataset
ObjectDetectorOptions = object_detector_options.ObjectDetectorOptions
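The aliases above flatten the package into a single public namespace. A minimal usage sketch (illustrative only; the paths passed to `from_coco_folder` are hypothetical, and only symbols introduced in this change are referenced):

from mediapipe.model_maker.python.vision import object_detector

# Build a dataset from a COCO-formatted folder; TFRecords are cached under cache_dir.
train_data = object_detector.Dataset.from_coco_folder(
    '/tmp/coco_train', cache_dir='/tmp/od_cache'
)
# Training and quantization-aware-training hyperparameters live in the same namespace.
hparams = object_detector.HParams(learning_rate=0.003, batch_size=32, epochs=10)
qat_hparams = object_detector.QATHParams(learning_rate=0.03, decay_rate=0.96)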
mediapipe/model_maker/python/vision/object_detector/dataset.py (new file)

@@ -0,0 +1,179 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Object detector dataset library."""

from typing import Optional

import tensorflow as tf
import yaml

from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.vision.object_detector import dataset_util
from official.vision.dataloaders import tf_example_decoder


class Dataset(classification_dataset.ClassificationDataset):
  """Dataset library for object detector."""

  @classmethod
  def from_coco_folder(
      cls,
      data_dir: str,
      max_num_images: Optional[int] = None,
      cache_dir: Optional[str] = None,
  ) -> 'Dataset':
    """Loads images and labels from the given directory in COCO format.

    - https://cocodataset.org/#home

    Folder structure should be:
      <data_dir>/
        images/
          <file0>.jpg
          ...
        labels.json

    The `labels.json` annotations file should have the following format:
      {
        "categories": [{"id": 0, "name": "background"}, ...],
        "images": [{"id": 0, "file_name": "<file0>.jpg"}, ...],
        "annotations": [{
          "id": 0,
          "image_id": 0,
          "category_id": 2,
          "bbox": [x-top left, y-top left, width, height],
        }, ...]
      }
    Note that category id 0 is reserved for the "background" class. It is
    optional to include, but if included it must be set to "background".

    Args:
      data_dir: Name of the directory containing the data files.
      max_num_images: Max number of images to process.
      cache_dir: The cache directory to save TFRecord and metadata files. The
        TFRecord files are a standardized format for training object detection
        while the metadata file is used to store information like dataset size
        and label mapping of id to label name. If the cache_dir is not set, a
        temporary folder will be created and will not be removed automatically
        after training, which means it can be reused later.

    Returns:
      Dataset containing images and labels and other related info.
    Raises:
      ValueError: If the input data directory is empty.
      ValueError: If the label_name for id 0 is set to something other than
        the 'background' class.
    """
    cache_files = dataset_util.get_cache_files_coco(data_dir, cache_dir)
    if not dataset_util.is_cached(cache_files):
      label_map = dataset_util.get_label_map_coco(data_dir)
      cache_writer = dataset_util.COCOCacheFilesWriter(
          label_map=label_map, max_num_images=max_num_images
      )
      cache_writer.write_files(cache_files, data_dir)
    return cls.from_cache(cache_files.cache_prefix)

  @classmethod
  def from_pascal_voc_folder(
      cls,
      data_dir: str,
      max_num_images: Optional[int] = None,
      cache_dir: Optional[str] = None,
  ) -> 'Dataset':
    """Loads images and labels from the given directory in PASCAL VOC format.

    - http://host.robots.ox.ac.uk/pascal/VOC.

    Folder structure should be:
      <data_dir>/
        images/
          <file0>.jpg
          ...
        Annotations/
          <file0>.xml
          ...
    Each <file0>.xml annotation file should have the following format:
      <annotation>
        <filename>file0.jpg</filename>
        <object>
          <name>kangaroo</name>
          <bndbox>
            <xmin>233</xmin>
            <ymin>89</ymin>
            <xmax>386</xmax>
            <ymax>262</ymax>
          </bndbox>
        </object>
        <object>...</object>
      </annotation>

    Args:
      data_dir: Name of the directory containing the data files.
      max_num_images: Max number of images to process.
      cache_dir: The cache directory to save TFRecord and metadata files. The
        TFRecord files are a standardized format for training object detection
        while the metadata file is used to store information like dataset size
        and label mapping of id to label name. If the cache_dir is not set, a
        temporary folder will be created and will not be removed automatically
        after training, which means it can be reused later.

    Returns:
      Dataset containing images and labels and other related info.
    Raises:
      ValueError: If the input data directory is empty.
    """
    cache_files = dataset_util.get_cache_files_pascal_voc(data_dir, cache_dir)
    if not dataset_util.is_cached(cache_files):
      label_map = dataset_util.get_label_map_pascal_voc(data_dir)
      cache_writer = dataset_util.PascalVocCacheFilesWriter(
          label_map=label_map, max_num_images=max_num_images
      )
      cache_writer.write_files(cache_files, data_dir)

    return cls.from_cache(cache_files.cache_prefix)

  @classmethod
  def from_cache(cls, cache_prefix: str) -> 'Dataset':
    """Loads the TFRecord data from cache.

    Args:
      cache_prefix: The cache prefix including the cache directory and the cache
        prefix filename, e.g: '/tmp/cache/train'.

    Returns:
      Dataset object.
    """
    # Get TFRecord Files
    tfrecord_file_pattern = cache_prefix + '*.tfrecord'
    matched_files = tf.io.gfile.glob(tfrecord_file_pattern)
    if not matched_files:
      raise ValueError('TFRecord files are empty.')

    # Load meta_data.
    meta_data_file = cache_prefix + dataset_util.META_DATA_FILE_SUFFIX
    if not tf.io.gfile.exists(meta_data_file):
      raise ValueError("Metadata file %s doesn't exist." % meta_data_file)
    with tf.io.gfile.GFile(meta_data_file, 'r') as f:
      meta_data = yaml.load(f, Loader=yaml.FullLoader)

    dataset = tf.data.TFRecordDataset(matched_files)
    decoder = tf_example_decoder.TfExampleDecoder(regenerate_source_id=False)
    dataset = dataset.map(decoder.decode, num_parallel_calls=tf.data.AUTOTUNE)

    label_map = meta_data['label_map']
    label_names = [label_map[k] for k in sorted(label_map.keys())]

    return Dataset(
        dataset=dataset, size=meta_data['size'], label_names=label_names
    )
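Because the cache key is derived from the folder name, the image filenames, and the contents of labels.json (see dataset_util.get_cache_files_coco below), a repeated call with the same data_dir and cache_dir skips the TFRecord conversion. A rough sketch of that flow, with hypothetical paths and an abbreviated hash:

data = Dataset.from_coco_folder('/data/coco_train', cache_dir='/data/cache')
# First call: COCOCacheFilesWriter writes /data/cache/<md5>-00000-of-0000N.tfrecord
# shards plus /data/cache/<md5>_meta_data.yaml.
data_again = Dataset.from_coco_folder('/data/coco_train', cache_dir='/data/cache')
# Second call: dataset_util.is_cached(...) returns True, so the data is loaded
# directly from the cached files through from_cache().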
mediapipe/model_maker/python/vision/object_detector/dataset_test.py (new file)

@@ -0,0 +1,119 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import random
import tensorflow as tf

from mediapipe.model_maker.python.vision.core import image_utils
from mediapipe.model_maker.python.vision.core import test_utils
from mediapipe.model_maker.python.vision.object_detector import dataset
from mediapipe.tasks.python.test import test_utils as tasks_test_utils

IMAGE_SIZE = 224


class DatasetTest(tf.test.TestCase):

  def _get_rand_bbox(self):
    x1, x2 = random.uniform(0, IMAGE_SIZE), random.uniform(0, IMAGE_SIZE)
    y1, y2 = random.uniform(0, IMAGE_SIZE), random.uniform(0, IMAGE_SIZE)
    return [min(x1, x2), min(y1, y2), abs(x1 - x2), abs(y1 - y2)]

  def setUp(self):
    super().setUp()
    self.coco_dataset_path = os.path.join(self.get_temp_dir(), 'coco_dataset')
    if os.path.exists(self.coco_dataset_path):
      return
    os.mkdir(self.coco_dataset_path)
    categories = [{'id': 1, 'name': 'daisy'}, {'id': 2, 'name': 'tulips'}]
    images = [
        {'id': 1, 'file_name': 'img1.jpeg'},
        {'id': 2, 'file_name': 'img2.jpeg'},
    ]
    annotations = [
        {'image_id': 1, 'category_id': 1, 'bbox': self._get_rand_bbox()},
        {'image_id': 2, 'category_id': 1, 'bbox': self._get_rand_bbox()},
        {'image_id': 2, 'category_id': 2, 'bbox': self._get_rand_bbox()},
    ]
    labels_dict = {
        'categories': categories,
        'images': images,
        'annotations': annotations,
    }
    labels_json = json.dumps(labels_dict)
    with open(os.path.join(self.coco_dataset_path, 'labels.json'), 'w') as f:
      f.write(labels_json)
    images_dir = os.path.join(self.coco_dataset_path, 'images')
    os.mkdir(images_dir)
    for item in images:
      test_utils.write_filled_jpeg_file(
          os.path.join(images_dir, item['file_name']),
          [random.uniform(0, 255) for _ in range(3)],
          IMAGE_SIZE,
      )

  def test_from_coco_folder(self):
    data = dataset.Dataset.from_coco_folder(
        self.coco_dataset_path, cache_dir=self.get_temp_dir()
    )
    self.assertLen(data, 2)
    self.assertEqual(data.num_classes, 3)
    self.assertEqual(data.label_names, ['background', 'daisy', 'tulips'])
    for example in data.gen_tf_dataset():
      boxes = example['groundtruth_boxes']
      classes = example['groundtruth_classes']
      self.assertNotEmpty(boxes)
      self.assertAllLessEqual(boxes, 1)
      self.assertAllGreaterEqual(boxes, 0)
      self.assertNotEmpty(classes)
      self.assertTrue(
          (classes.numpy() == [1]).all() or (classes.numpy() == [1, 2]).all()
      )
      if (classes.numpy() == [1, 1]).all():
        raw_image_tensor = image_utils.load_image(
            os.path.join(self.coco_dataset_path, 'images', 'img1.jpeg')
        )
      else:
        raw_image_tensor = image_utils.load_image(
            os.path.join(self.coco_dataset_path, 'images', 'img2.jpeg')
        )
      self.assertTrue(
          (example['image'].numpy() == raw_image_tensor.numpy()).all()
      )

  def test_from_pascal_voc_folder(self):
    pascal_voc_folder = tasks_test_utils.get_test_data_path('pascal_voc_data')
    data = dataset.Dataset.from_pascal_voc_folder(
        pascal_voc_folder, cache_dir=self.get_temp_dir()
    )
    self.assertLen(data, 4)
    self.assertEqual(data.num_classes, 3)
    self.assertEqual(data.label_names, ['background', 'android', 'pig_android'])
    for example in data.gen_tf_dataset():
      boxes = example['groundtruth_boxes']
      classes = example['groundtruth_classes']
      self.assertNotEmpty(boxes)
      self.assertAllLessEqual(boxes, 1)
      self.assertAllGreaterEqual(boxes, 0)
      self.assertNotEmpty(classes)
      image = example['image']
      self.assertNotEmpty(image)
      self.assertAllGreaterEqual(image, 0)
      self.assertAllLessEqual(image, 255)


if __name__ == '__main__':
  tf.test.main()
mediapipe/model_maker/python/vision/object_detector/dataset_util.py (new file)

@@ -0,0 +1,484 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for Object Detector Dataset Library."""

import abc
import collections
import dataclasses
import hashlib
import json
import math
import os
import tempfile
from typing import Any, Dict, List, Mapping, Optional, Sequence
import xml.etree.ElementTree as ET

import tensorflow as tf
import yaml

from official.vision.data import tfrecord_lib


# Suffix of the meta data file name.
META_DATA_FILE_SUFFIX = '_meta_data.yaml'


def _xml_get(node: ET.Element, name: str) -> ET.Element:
  """Gets a named child from an XML Element node.

  This method is used to retrieve an XML element that is expected to exist as a
  subelement of the `node` passed into this function. If the subelement is not
  found, an error is raised.

  Raises:
    ValueError: If the subelement is not found.

  Args:
    node: XML Element Tree node.
    name: Name of the child node to get.

  Returns:
    A child node of the parameter node with the matching name.
  """
  result = node.find(name)
  if result is None:
    raise ValueError(f'Unexpected xml format: {name} not found in {node}')
  return result


def _get_cache_dir_or_create(cache_dir: Optional[str]) -> str:
  """Gets the cache directory, creating it if it does not exist."""
  if cache_dir is None:
    cache_dir = tempfile.mkdtemp()
  if not tf.io.gfile.exists(cache_dir):
    tf.io.gfile.makedirs(cache_dir)
  return cache_dir


def _get_dir_basename(data_dir: str) -> str:
  """Gets the base name of the directory."""
  return os.path.basename(os.path.abspath(data_dir))


@dataclasses.dataclass(frozen=True)
class CacheFiles:
  """Cache files for object detection."""

  cache_prefix: str
  tfrecord_files: Sequence[str]
  meta_data_file: str


def _get_cache_files(
    cache_dir: Optional[str], cache_prefix_filename: str, num_shards: int = 10
) -> CacheFiles:
  """Creates an object of CacheFiles class.

  Args:
    cache_dir: The cache directory to save TFRecord and metadata file. When
      cache_dir is None, a temporary folder will be created and will not be
      removed automatically after training, so it can be reused later.
    cache_prefix_filename: The cache prefix filename.
    num_shards: Number of shards for output file.

  Returns:
    An object of CacheFiles class.
  """
  cache_dir = _get_cache_dir_or_create(cache_dir)
  # The cache prefix including the cache directory and the cache prefix
  # filename, e.g: '/tmp/cache/train'.
  cache_prefix = os.path.join(cache_dir, cache_prefix_filename)
  tf.compat.v1.logging.info(
      'Cache will be stored in %s with prefix filename %s. Cache_prefix is %s'
      % (cache_dir, cache_prefix_filename, cache_prefix)
  )

  # Cached files including the TFRecord files and the meta data file.
  tfrecord_files = [
      cache_prefix + '-%05d-of-%05d.tfrecord' % (i, num_shards)
      for i in range(num_shards)
  ]
  meta_data_file = cache_prefix + META_DATA_FILE_SUFFIX
  return CacheFiles(
      cache_prefix=cache_prefix,
      tfrecord_files=tuple(tfrecord_files),
      meta_data_file=meta_data_file,
  )


def get_cache_files_coco(data_dir: str, cache_dir: str) -> CacheFiles:
  """Creates an object of CacheFiles class using a COCO formatted dataset.

  Args:
    data_dir: Folder path of the COCO dataset.
    cache_dir: Folder path of the cache location. When cache_dir is None, a
      temporary folder will be created and will not be removed automatically
      after training, so it can be reused later.

  Returns:
    An object of CacheFiles class.
  """
  hasher = hashlib.md5()
  # Update with dataset folder name
  hasher.update(_get_dir_basename(data_dir).encode('utf-8'))
  # Update with image filenames
  for image_file in sorted(os.listdir(os.path.join(data_dir, 'images'))):
    hasher.update(os.path.basename(image_file).encode('utf-8'))
  # Update with labels.json file content
  label_file = os.path.join(data_dir, 'labels.json')
  with open(label_file, 'r') as f:
    label_json = json.load(f)
    hasher.update(str(label_json).encode('utf-8'))
  num_examples = len(label_json['images'])
  # num_shards is chosen so that each shard holds roughly 100 images, up to 10
  # shards total. See https://www.tensorflow.org/tutorials/load_data/tfrecord
  # for more info on sharding.
  num_shards = min(math.ceil(num_examples / 100), 10)
  # Update with num shards
  hasher.update(str(num_shards).encode('utf-8'))
  cache_prefix_filename = hasher.hexdigest()

  return _get_cache_files(cache_dir, cache_prefix_filename, num_shards)
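# A worked illustration of the sharding rule and the resulting file names
# (hypothetical cache_dir and an abbreviated md5 digest, for clarity only):
#   num_examples = 250  -> num_shards = min(ceil(250 / 100), 10) = 3
#   num_examples = 5000 -> num_shards = min(ceil(5000 / 100), 10) = 10
# With cache_dir='/tmp/cache' and cache_prefix_filename='9b1a7d', the
# CacheFiles returned by _get_cache_files would contain:
#   tfrecord_files: /tmp/cache/9b1a7d-00000-of-00003.tfrecord, ...,
#                   /tmp/cache/9b1a7d-00002-of-00003.tfrecord
#   meta_data_file: /tmp/cache/9b1a7d_meta_data.yaml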


def get_cache_files_pascal_voc(data_dir: str, cache_dir: str) -> CacheFiles:
  """Gets an object of CacheFiles using a PASCAL VOC formatted dataset.

  Args:
    data_dir: Folder path of the PASCAL VOC dataset.
    cache_dir: Folder path of the cache location. When cache_dir is None, a
      temporary folder will be created and will not be removed automatically
      after training, so it can be reused later.

  Returns:
    An object of CacheFiles class.
  """
  hasher = hashlib.md5()
  hasher.update(_get_dir_basename(data_dir).encode('utf-8'))
  annotation_files = tf.io.gfile.glob(
      os.path.join(data_dir, 'Annotations') + r'/*.xml'
  )
  annotation_filenames = [
      os.path.basename(ann_file) for ann_file in annotation_files
  ]
  hasher.update(' '.join(annotation_filenames).encode('utf-8'))
  num_examples = len(annotation_filenames)
  num_shards = min(math.ceil(num_examples / 100), 10)
  hasher.update(str(num_shards).encode('utf-8'))
  cache_prefix_filename = hasher.hexdigest()

  return _get_cache_files(cache_dir, cache_prefix_filename, num_shards)


def is_cached(cache_files: CacheFiles) -> bool:
  """Checks whether all cache files already exist."""
  all_cached_files = list(cache_files.tfrecord_files) + [
      cache_files.meta_data_file
  ]
  return all(tf.io.gfile.exists(path) for path in all_cached_files)


class CacheFilesWriter(abc.ABC):
  """CacheFilesWriter class to write the cached files."""

  def __init__(
      self, label_map: Dict[int, str], max_num_images: Optional[int] = None
  ) -> None:
    """Initializes CacheFilesWriter for object detector.

    Args:
      label_map: Dict, map label integer ids to string label names such as {1:
        'person', 2: 'notperson'}. 0 is the reserved key for `background` and
        doesn't need to be included in `label_map`. Label names can't be
        duplicated.
      max_num_images: Max number of images to process. If None, process all the
        images.
    """
    self.label_map = label_map
    self.max_num_images = max_num_images

  def write_files(self, cache_files: CacheFiles, *args, **kwargs) -> None:
    """Writes TFRecord and meta_data files.

    Args:
      cache_files: CacheFiles object including a list of TFRecord files and the
        meta data yaml file to save the meta_data including data size and
        label_map.
      *args: Non-keyword parameters used in the `_get_example` method.
      **kwargs: Keyword parameters used in the `_get_example` method.
    """
    writers = [
        tf.io.TFRecordWriter(path) for path in cache_files.tfrecord_files
    ]

    # Writes tf.Example into TFRecord files.
    size = 0
    for idx, tf_example in enumerate(self._get_example(*args, **kwargs)):
      if self.max_num_images and idx >= self.max_num_images:
        break
      if idx % 100 == 0:
        tf.compat.v1.logging.info('On image %d' % idx)
      writers[idx % len(writers)].write(tf_example.SerializeToString())
      size = idx + 1

    for writer in writers:
      writer.close()

    # Writes meta_data into meta_data_file.
    meta_data = {'size': size, 'label_map': self.label_map}
    with tf.io.gfile.GFile(cache_files.meta_data_file, 'w') as f:
      yaml.dump(meta_data, f)

  @abc.abstractmethod
  def _get_example(self, *args, **kwargs):
    raise NotImplementedError


def get_label_map_coco(data_dir: str):
  """Gets the label map from a COCO formatted dataset directory.

  Note that id 0 is reserved for the background class. If id=0 is set, it needs
  to be set to "background". It is optional to include id=0 if it is unused, and
  it will be automatically added by this method.

  Args:
    data_dir: Path of the dataset directory.

  Returns:
    label_map dictionary of the format {<id>:<label_name>}

  Raises:
    ValueError: If the label_name for id 0 is set to something other than
      the "background" class.
  """
  data_dir = os.path.abspath(data_dir)
  # Process labels.json file
  label_file = os.path.join(data_dir, 'labels.json')
  with open(label_file, 'r') as f:
    data = json.load(f)

  # Categories
  label_map = {}
  for category in data['categories']:
    label_map[int(category['id'])] = category['name']

  if 0 in label_map and label_map[0] != 'background':
    raise ValueError(
        (
            'Label index 0 is reserved for the background class, but '
            f'it was found to be {label_map[0]}'
        ),
    )
  if 0 not in label_map:
    label_map[0] = 'background'

  return label_map
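# For instance, with the two-category labels.json that dataset_test.py builds
# ({'id': 1, 'name': 'daisy'}, {'id': 2, 'name': 'tulips'}), this returns:
#   {0: 'background', 1: 'daisy', 2: 'tulips'}
# because id 0 is absent from the file and is added automatically.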


def get_label_map_pascal_voc(data_dir: str):
  """Gets the label map from a PASCAL VOC formatted dataset directory.

  The id to label_name mapping is determined by sorting all label_names and
  numbering them starting from 1. Id=0 is set as the 'background' class.

  Args:
    data_dir: Path of the dataset directory.

  Returns:
    label_map dictionary of the format {<id>:<label_name>}
  """
  data_dir = os.path.abspath(data_dir)
  all_label_names = set()
  annotations_dir = os.path.join(data_dir, 'Annotations')
  all_annotation_files = tf.io.gfile.glob(annotations_dir + r'/*.xml')
  for ann_file in all_annotation_files:
    tree = ET.parse(ann_file)
    root = tree.getroot()
    for child in root.iter('object'):
      label_name = _xml_get(child, 'name').text
      all_label_names.add(label_name)
  label_map = {0: 'background'}
  for ind, label_name in enumerate(sorted(all_label_names)):
    label_map[ind + 1] = label_name
  return label_map


def _bbox_data_to_feature_dict(data):
  """Converts a dictionary of bbox annotations to a feature dictionary.

  Args:
    data: Dict with keys 'xmin', 'xmax', 'ymin', 'ymax', 'category_id'

  Returns:
    Feature dictionary
  """
  bbox_feature_dict = {
      'image/object/bbox/xmin': tfrecord_lib.convert_to_feature(data['xmin']),
      'image/object/bbox/xmax': tfrecord_lib.convert_to_feature(data['xmax']),
      'image/object/bbox/ymin': tfrecord_lib.convert_to_feature(data['ymin']),
      'image/object/bbox/ymax': tfrecord_lib.convert_to_feature(data['ymax']),
      'image/object/class/label': tfrecord_lib.convert_to_feature(
          data['category_id']
      ),
  }
  return bbox_feature_dict


def _coco_annotations_to_lists(
    bbox_annotations: List[Mapping[str, Any]],
    image_height: int,
    image_width: int,
):
  """Converts COCO annotations to feature lists.

  Args:
    bbox_annotations: List of dicts with keys ['bbox', 'category_id']
    image_height: Height of image
    image_width: Width of image

  Returns:
    (data, num_annotations_skipped) tuple where data contains the keys:
    ['xmin', 'xmax', 'ymin', 'ymax', 'category_id'] and
    num_annotations_skipped is the number of annotations skipped because the
    bbox has zero area or falls outside the image.
  """

  data = collections.defaultdict(list)

  num_annotations_skipped = 0

  for object_annotations in bbox_annotations:
    (x, y, width, height) = tuple(object_annotations['bbox'])

    if width <= 0 or height <= 0:
      num_annotations_skipped += 1
      continue
    if x + width > image_width or y + height > image_height:
      num_annotations_skipped += 1
      continue
    data['xmin'].append(float(x) / image_width)
    data['xmax'].append(float(x + width) / image_width)
    data['ymin'].append(float(y) / image_height)
    data['ymax'].append(float(y + height) / image_height)
    category_id = int(object_annotations['category_id'])
    data['category_id'].append(category_id)

  return data, num_annotations_skipped
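# A small worked example of the normalization above (numbers are illustrative):
#   COCO bbox [x, y, width, height] = [40, 30, 80, 60] in an image that is
#   200 pixels wide and 100 pixels tall gives
#     xmin = 40 / 200 = 0.2,  xmax = (40 + 80) / 200 = 0.6
#     ymin = 30 / 100 = 0.3,  ymax = (30 + 60) / 100 = 0.9
#   A bbox with non-positive width/height, or one extending past the image
#   edge, is dropped and counted in num_annotations_skipped instead.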


class COCOCacheFilesWriter(CacheFilesWriter):
  """CacheFilesWriter class to write the cached files for COCO data."""

  def _get_example(self, data_dir: str) -> tf.train.Example:
    """Iterates over all examples in the COCO formatted dataset directory.

    Args:
      data_dir: Path of the dataset directory

    Yields:
      tf.train.Example
    """
    data_dir = os.path.abspath(data_dir)
    # Process labels.json file
    label_file = os.path.join(data_dir, 'labels.json')
    with open(label_file, 'r') as f:
      data = json.load(f)

    # Load all Annotations
    img_to_annotations = collections.defaultdict(list)
    for annotation in data['annotations']:
      image_id = annotation['image_id']
      img_to_annotations[image_id].append(annotation)

    # For each Image:
    for image in data['images']:
      img_id = image['id']
      file_name = image['file_name']
      full_path = os.path.join(data_dir, 'images', file_name)
      with tf.io.gfile.GFile(full_path, 'rb') as fid:
        encoded_jpg = fid.read()
      image = tf.io.decode_jpeg(encoded_jpg, channels=3)
      height, width, _ = image.shape
      feature_dict = tfrecord_lib.image_info_to_feature_dict(
          height, width, file_name, img_id, encoded_jpg, 'jpg'
      )
      data, _ = _coco_annotations_to_lists(
          img_to_annotations[img_id], height, width
      )
      if not data['xmin']:
        # Skip examples which have no annotations
        continue
      bbox_feature_dict = _bbox_data_to_feature_dict(data)
      feature_dict.update(bbox_feature_dict)
      example = tf.train.Example(
          features=tf.train.Features(feature=feature_dict)
      )
      yield example


class PascalVocCacheFilesWriter(CacheFilesWriter):
  """CacheFilesWriter class to write the cached files for PASCAL VOC data."""

  def _get_example(self, data_dir: str) -> tf.train.Example:
    """Iterates over all examples in the PASCAL VOC formatted dataset directory.

    Args:
      data_dir: Path of the dataset directory

    Yields:
      tf.train.Example
    """
    label_name_to_id = {name: i for (i, name) in self.label_map.items()}
    annotations_dir = os.path.join(data_dir, 'Annotations')
    images_dir = os.path.join(data_dir, 'images')
    all_annotation_paths = tf.io.gfile.glob(annotations_dir + r'/*.xml')

    for ind, ann_file in enumerate(all_annotation_paths):
      # Reset the per-image annotation lists for every annotation file so that
      # boxes from earlier files do not leak into later examples.
      data = collections.defaultdict(list)
      tree = ET.parse(ann_file)
      root = tree.getroot()
      img_filename = _xml_get(root, 'filename').text
      img_file = os.path.join(images_dir, img_filename)
      with tf.io.gfile.GFile(img_file, 'rb') as fid:
        encoded_jpg = fid.read()
      image = tf.io.decode_jpeg(encoded_jpg, channels=3)
      height, width, _ = image.shape
      for child in root.iter('object'):
        category_name = _xml_get(child, 'name').text
        category_id = label_name_to_id[category_name]
        bndbox = _xml_get(child, 'bndbox')
        xmin = float(_xml_get(bndbox, 'xmin').text)
        xmax = float(_xml_get(bndbox, 'xmax').text)
        ymin = float(_xml_get(bndbox, 'ymin').text)
        ymax = float(_xml_get(bndbox, 'ymax').text)
        if xmax <= xmin or ymax <= ymin or xmax > width or ymax > height:
          # Skip annotations that have no area or are larger than the image
          continue
        data['xmin'].append(xmin / width)
        data['ymin'].append(ymin / height)
        data['xmax'].append(xmax / width)
        data['ymax'].append(ymax / height)
        data['category_id'].append(category_id)
      if not data['xmin']:
        # Skip examples which have no valid annotations
        continue
      feature_dict = tfrecord_lib.image_info_to_feature_dict(
          height, width, img_filename, ind, encoded_jpg, 'jpg'
      )
      bbox_feature_dict = _bbox_data_to_feature_dict(data)
      feature_dict.update(bbox_feature_dict)
      example = tf.train.Example(
          features=tf.train.Features(feature=feature_dict)
      )
      yield example
mediapipe/model_maker/python/vision/object_detector/dataset_util_test.py (new file)

@@ -0,0 +1,236 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import json
import os
import shutil
from unittest import mock as unittest_mock

import tensorflow as tf
import yaml

from mediapipe.model_maker.python.vision.core import test_utils
from mediapipe.model_maker.python.vision.object_detector import dataset_util
from mediapipe.tasks.python.test import test_utils as tasks_test_utils


class DatasetUtilTest(tf.test.TestCase):

  def _assert_cache_files_equal(self, cf1, cf2):
    self.assertEqual(cf1.cache_prefix, cf2.cache_prefix)
    self.assertCountEqual(cf1.tfrecord_files, cf2.tfrecord_files)
    self.assertEqual(cf1.meta_data_file, cf2.meta_data_file)

  def _assert_cache_files_not_equal(self, cf1, cf2):
    self.assertNotEqual(cf1.cache_prefix, cf2.cache_prefix)
    self.assertNotEqual(cf1.tfrecord_files, cf2.tfrecord_files)
    self.assertNotEqual(cf1.meta_data_file, cf2.meta_data_file)

  def _get_cache_files_and_assert_neq_fn(self, cache_files_fn):
    def get_cache_files_and_assert_neq(cf, data_dir, cache_dir):
      new_cf = cache_files_fn(data_dir, cache_dir)
      self._assert_cache_files_not_equal(cf, new_cf)
      return new_cf

    return get_cache_files_and_assert_neq

  @unittest_mock.patch.object(hashlib, 'md5', autospec=True)
  def test_get_cache_files_coco(self, mock_md5):
    mock_md5.return_value.hexdigest.return_value = 'train'
    cache_files = dataset_util.get_cache_files_coco(
        tasks_test_utils.get_test_data_path('coco_data'), cache_dir='/tmp/'
    )
    self.assertEqual(cache_files.cache_prefix, '/tmp/train')
    self.assertLen(cache_files.tfrecord_files, 1)
    self.assertEqual(
        cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord'
    )
    self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml')

  def test_matching_get_cache_files_coco(self):
    cache_dir = self.create_tempdir()
    coco_folder = tasks_test_utils.get_test_data_path('coco_data')
    coco_folder_tmp = os.path.join(self.create_tempdir(), 'coco_data')
    shutil.copytree(coco_folder, coco_folder_tmp)
    cache_files1 = dataset_util.get_cache_files_coco(coco_folder, cache_dir)
    cache_files2 = dataset_util.get_cache_files_coco(coco_folder, cache_dir)
    self._assert_cache_files_equal(cache_files1, cache_files2)
    cache_files3 = dataset_util.get_cache_files_coco(coco_folder_tmp, cache_dir)
    self._assert_cache_files_equal(cache_files1, cache_files3)

  def test_not_matching_get_cache_files_coco(self):
    cache_dir = self.create_tempdir()
    temp_dir = self.create_tempdir()
    coco_folder = os.path.join(temp_dir, 'coco_data')
    shutil.copytree(
        tasks_test_utils.get_test_data_path('coco_data'), coco_folder
    )
    prev_cache_file = dataset_util.get_cache_files_coco(coco_folder, cache_dir)
    os.chmod(coco_folder, 0o700)
    os.chmod(os.path.join(coco_folder, 'images'), 0o700)
    os.chmod(os.path.join(coco_folder, 'labels.json'), 0o700)
    get_cache_files_and_assert_neq = self._get_cache_files_and_assert_neq_fn(
        dataset_util.get_cache_files_coco
    )
    # Test adding image
    test_utils.write_filled_jpeg_file(
        os.path.join(coco_folder, 'images', 'test.jpg'), [0, 0, 0], 50
    )
    prev_cache_file = get_cache_files_and_assert_neq(
        prev_cache_file, coco_folder, cache_dir
    )
    # Test modifying labels.json
    with open(os.path.join(coco_folder, 'labels.json'), 'w') as f:
      json.dump({'images': [{'id': 1, 'file_name': '000000000078.jpg'}]}, f)
    prev_cache_file = get_cache_files_and_assert_neq(
        prev_cache_file, coco_folder, cache_dir
    )

    # Test rename folder
    new_coco_folder = os.path.join(temp_dir, 'coco_data_renamed')
    shutil.move(coco_folder, new_coco_folder)
    coco_folder = new_coco_folder
    prev_cache_file = get_cache_files_and_assert_neq(
        prev_cache_file, new_coco_folder, cache_dir
    )

  @unittest_mock.patch.object(hashlib, 'md5', autospec=True)
  def test_get_cache_files_pascal_voc(self, mock_md5):
    mock_md5.return_value.hexdigest.return_value = 'train'
    cache_files = dataset_util.get_cache_files_pascal_voc(
        tasks_test_utils.get_test_data_path('pascal_voc_data'),
        cache_dir='/tmp/',
    )
    self.assertEqual(cache_files.cache_prefix, '/tmp/train')
    self.assertLen(cache_files.tfrecord_files, 1)
    self.assertEqual(
        cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord'
    )
    self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml')

  def test_matching_get_cache_files_pascal_voc(self):
    cache_dir = self.create_tempdir()
    pascal_folder = tasks_test_utils.get_test_data_path('pascal_voc_data')
    pascal_folder_temp = os.path.join(self.create_tempdir(), 'pascal_voc_data')
    shutil.copytree(pascal_folder, pascal_folder_temp)
    cache_files1 = dataset_util.get_cache_files_pascal_voc(
        pascal_folder, cache_dir
    )
    cache_files2 = dataset_util.get_cache_files_pascal_voc(
        pascal_folder, cache_dir
    )
    self._assert_cache_files_equal(cache_files1, cache_files2)
    cache_files3 = dataset_util.get_cache_files_pascal_voc(
        pascal_folder_temp, cache_dir
    )
    self._assert_cache_files_equal(cache_files1, cache_files3)

  def test_not_matching_get_cache_files_pascal_voc(self):
    cache_dir = self.create_tempdir()
    temp_dir = self.create_tempdir()
    pascal_folder = os.path.join(temp_dir, 'pascal_voc_data')
    shutil.copytree(
        tasks_test_utils.get_test_data_path('pascal_voc_data'), pascal_folder
    )
    prev_cache_files = dataset_util.get_cache_files_pascal_voc(
        pascal_folder, cache_dir
    )
    os.chmod(pascal_folder, 0o700)
    os.chmod(os.path.join(pascal_folder, 'images'), 0o700)
    os.chmod(os.path.join(pascal_folder, 'Annotations'), 0o700)
    get_cache_files_and_assert_neq = self._get_cache_files_and_assert_neq_fn(
        dataset_util.get_cache_files_pascal_voc
    )
    # Test adding xml file
    with open(os.path.join(pascal_folder, 'Annotations', 'test.xml'), 'w') as f:
      f.write('test')
    prev_cache_files = get_cache_files_and_assert_neq(
        prev_cache_files, pascal_folder, cache_dir
    )

    # Test rename folder
    new_pascal_folder = os.path.join(temp_dir, 'pascal_voc_data_renamed')
    shutil.move(pascal_folder, new_pascal_folder)
    pascal_folder = new_pascal_folder
    prev_cache_files = get_cache_files_and_assert_neq(
        prev_cache_files, new_pascal_folder, cache_dir
    )

  def test_is_cached(self):
    tempdir = self.create_tempdir()
    cache_files = dataset_util.get_cache_files_coco(
        tasks_test_utils.get_test_data_path('coco_data'), cache_dir=tempdir
    )
    self.assertFalse(dataset_util.is_cached(cache_files))
    with open(cache_files.tfrecord_files[0], 'w') as f:
      f.write('test')
    self.assertFalse(dataset_util.is_cached(cache_files))
    with open(cache_files.meta_data_file, 'w') as f:
      f.write('test')
    self.assertTrue(dataset_util.is_cached(cache_files))

  def test_get_label_map_coco(self):
    coco_dir = tasks_test_utils.get_test_data_path('coco_data')
    label_map = dataset_util.get_label_map_coco(coco_dir)
    all_keys = sorted(label_map.keys())
    self.assertEqual(all_keys[0], 0)
    self.assertEqual(all_keys[-1], 11)
    self.assertLen(all_keys, 12)

  def test_get_label_map_pascal_voc(self):
    pascal_dir = tasks_test_utils.get_test_data_path('pascal_voc_data')
    label_map = dataset_util.get_label_map_pascal_voc(pascal_dir)
    all_keys = sorted(label_map.keys())
    self.assertEqual(label_map[0], 'background')
    self.assertEqual(all_keys[0], 0)
    self.assertEqual(all_keys[-1], 2)
    self.assertLen(all_keys, 3)

  def _validate_cache_files(self, cache_files, expected_size):
    # Checks the TFRecord file
    self.assertTrue(os.path.isfile(cache_files.tfrecord_files[0]))
    self.assertGreater(os.path.getsize(cache_files.tfrecord_files[0]), 0)

    # Checks the meta_data file
    self.assertTrue(os.path.isfile(cache_files.meta_data_file))
    self.assertGreater(os.path.getsize(cache_files.meta_data_file), 0)
    with tf.io.gfile.GFile(cache_files.meta_data_file, 'r') as f:
      meta_data_dict = yaml.load(f, Loader=yaml.FullLoader)
      # Examples with invalid bboxes are skipped, so the recorded size can be
      # smaller than the number of images in the dataset.
      self.assertEqual(meta_data_dict['size'], expected_size)

  def test_coco_cache_files_writer(self):
    tempdir = self.create_tempdir()
    coco_dir = tasks_test_utils.get_test_data_path('coco_data')
    label_map = dataset_util.get_label_map_coco(coco_dir)
    cache_writer = dataset_util.COCOCacheFilesWriter(label_map)
    cache_files = dataset_util.get_cache_files_coco(coco_dir, cache_dir=tempdir)
    cache_writer.write_files(cache_files, coco_dir)
    self._validate_cache_files(cache_files, 3)

  def test_pascal_voc_cache_files_writer(self):
    tempdir = self.create_tempdir()
    pascal_dir = tasks_test_utils.get_test_data_path('pascal_voc_data')
    label_map = dataset_util.get_label_map_pascal_voc(pascal_dir)
    cache_writer = dataset_util.PascalVocCacheFilesWriter(label_map)
    cache_files = dataset_util.get_cache_files_pascal_voc(
        pascal_dir, cache_dir=tempdir
    )
    cache_writer.write_files(cache_files, pascal_dir)
    self._validate_cache_files(cache_files, 4)


if __name__ == '__main__':
  tf.test.main()
@ -0,0 +1,101 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hyperparameters for training object detection models."""

import dataclasses
from typing import List

from mediapipe.model_maker.python.core import hyperparameters as hp


@dataclasses.dataclass
class HParams(hp.BaseHParams):
  """The hyperparameters for training object detectors.

  Attributes:
    learning_rate: Learning rate to use for gradient descent training.
    batch_size: Batch size for training.
    epochs: Number of training iterations over the dataset.
    do_fine_tuning: If true, the base module is trained together with the
      classification layer on top.
    learning_rate_boundaries: List of epoch boundaries where
      learning_rate_boundaries[i] is the epoch where the learning rate will
      decay to learning_rate * learning_rate_decay_multipliers[i].
    learning_rate_decay_multipliers: List of learning rate multipliers which
      calculates the learning rate at the ith boundary as learning_rate *
      learning_rate_decay_multipliers[i].
  """

  # Parameters from BaseHParams class.
  learning_rate: float = 0.003
  batch_size: int = 32
  epochs: int = 10

  # Parameters for learning rate decay.
  learning_rate_boundaries: List[int] = dataclasses.field(
      default_factory=lambda: [5, 8]
  )
  learning_rate_decay_multipliers: List[float] = dataclasses.field(
      default_factory=lambda: [0.1, 0.01]
  )

  def __post_init__(self):
    # Validate stepwise learning rate parameters.
    lr_boundary_len = len(self.learning_rate_boundaries)
    lr_decay_multipliers_len = len(self.learning_rate_decay_multipliers)
    if lr_boundary_len != lr_decay_multipliers_len:
      raise ValueError(
          "Length of learning_rate_boundaries and "
          "learning_rate_decay_multipliers do not match: "
          f"{lr_boundary_len}!={lr_decay_multipliers_len}"
      )
    # Validate learning_rate_boundaries.
    if sorted(self.learning_rate_boundaries) != self.learning_rate_boundaries:
      raise ValueError(
          "learning_rate_boundaries is not in ascending order: "
          f"{self.learning_rate_boundaries}"
      )
    if (
        self.learning_rate_boundaries
        and self.learning_rate_boundaries[-1] > self.epochs
    ):
      raise ValueError(
          "Values in learning_rate_boundaries cannot be greater than epochs"
      )


@dataclasses.dataclass
class QATHParams:
  """The hyperparameters for running quantization aware training (QAT) on object detectors.

  For more information on QAT, see:
    https://www.tensorflow.org/model_optimization/guide/quantization/training

  Attributes:
    learning_rate: Learning rate to use for gradient descent QAT.
    batch_size: Batch size for QAT.
    epochs: Number of training iterations over the dataset.
    decay_steps: Learning rate decay steps for Exponential Decay. See
      https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/ExponentialDecay
      for more information.
    decay_rate: Learning rate decay rate for Exponential Decay. See
      https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/ExponentialDecay
      for more information.
  """

  learning_rate: float = 0.03
  batch_size: int = 32
  epochs: int = 10
  decay_steps: int = 231
  decay_rate: float = 0.96
|
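The stepwise fields above are consumed by the trainer as a Keras PiecewiseConstantDecay schedule, with the base rate scaled by batch_size / 256. A minimal sketch with the default values (the standalone variable names are illustrative only):

import tensorflow as tf

# Defaults from HParams above.
learning_rate = 0.003
batch_size = 32
learning_rate_boundaries = [5, 8]              # epochs where the rate steps down
learning_rate_decay_multipliers = [0.1, 0.01]  # multipliers applied after each boundary

# Mirror of the trainer's _create_optimizer construction.
init_lr = learning_rate * batch_size / 256
lr_values = [init_lr] + [init_lr * m for m in learning_rate_decay_multipliers]
schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    learning_rate_boundaries, lr_values
)

for epoch in range(10):
  # 3.75e-4 through epoch 5, then 3.75e-5, then 3.75e-6 after epoch 8.
  print(epoch, float(schedule(epoch)))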
355
mediapipe/model_maker/python/vision/object_detector/model.py
Normal file
|
@ -0,0 +1,355 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Custom Model for Object Detection."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Mapping, Optional, Sequence, Union
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision.object_detector import model_options as model_opt
|
||||||
|
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
|
||||||
|
from official.core import config_definitions as cfg
|
||||||
|
from official.projects.qat.vision.configs import common as qat_common
|
||||||
|
from official.projects.qat.vision.modeling import factory as qat_factory
|
||||||
|
from official.vision import configs
|
||||||
|
from official.vision.losses import focal_loss
|
||||||
|
from official.vision.losses import loss_utils
|
||||||
|
from official.vision.modeling import factory
|
||||||
|
from official.vision.modeling import retinanet_model
|
||||||
|
from official.vision.modeling.layers import detection_generator
|
||||||
|
from official.vision.serving import detection
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectDetectorModel(tf.keras.Model):
|
||||||
|
"""An object detector model which can be trained using Model Maker's training API.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
loss_trackers: List of tf.keras.metrics.Mean objects used to track the loss
|
||||||
|
during training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_spec: ms.ModelSpec,
|
||||||
|
model_options: model_opt.ObjectDetectorModelOptions,
|
||||||
|
num_classes: int,
|
||||||
|
) -> None:
|
||||||
|
"""Initializes an ObjectDetectorModel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_spec: Specification for the model.
|
||||||
|
model_options: Model options for creating the model.
|
||||||
|
num_classes: Number of classes for object detection.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self._model_spec = model_spec
|
||||||
|
self._model_options = model_options
|
||||||
|
self._num_classes = num_classes
|
||||||
|
self._model = self._build_model()
|
||||||
|
checkpoint_folder = self._model_spec.downloaded_files.get_path()
|
||||||
|
checkpoint_file = os.path.join(checkpoint_folder, 'ckpt-277200')
|
||||||
|
self.load_checkpoint(checkpoint_file)
|
||||||
|
self._model.summary()
|
||||||
|
self.loss_trackers = [
|
||||||
|
tf.keras.metrics.Mean(name=n)
|
||||||
|
for n in ['total_loss', 'cls_loss', 'box_loss', 'model_loss']
|
||||||
|
]
|
||||||
|
|
||||||
|
def _get_model_config(
|
||||||
|
self,
|
||||||
|
generator_config: configs.retinanet.DetectionGenerator = configs.retinanet.DetectionGenerator(),
|
||||||
|
) -> configs.retinanet.RetinaNet:
|
||||||
|
model_config = configs.retinanet.RetinaNet(
|
||||||
|
min_level=3,
|
||||||
|
max_level=7,
|
||||||
|
num_classes=self._num_classes,
|
||||||
|
input_size=self._model_spec.input_image_shape,
|
||||||
|
anchor=configs.retinanet.Anchor(
|
||||||
|
num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=3
|
||||||
|
),
|
||||||
|
backbone=configs.backbones.Backbone(
|
||||||
|
type='mobilenet', mobilenet=configs.backbones.MobileNet()
|
||||||
|
),
|
||||||
|
decoder=configs.decoders.Decoder(
|
||||||
|
type='fpn',
|
||||||
|
fpn=configs.decoders.FPN(
|
||||||
|
num_filters=128, use_separable_conv=True, use_keras_layer=True
|
||||||
|
),
|
||||||
|
),
|
||||||
|
head=configs.retinanet.RetinaNetHead(
|
||||||
|
num_filters=128, use_separable_conv=True
|
||||||
|
),
|
||||||
|
detection_generator=generator_config,
|
||||||
|
norm_activation=configs.common.NormActivation(activation='relu6'),
|
||||||
|
)
|
||||||
|
return model_config
|
||||||
|
|
||||||
|
def _build_model(self) -> tf.keras.Model:
|
||||||
|
"""Builds a RetinaNet object detector model."""
|
||||||
|
input_specs = tf.keras.layers.InputSpec(
|
||||||
|
shape=[None] + self._model_spec.input_image_shape
|
||||||
|
)
|
||||||
|
l2_regularizer = tf.keras.regularizers.l2(
|
||||||
|
self._model_options.l2_weight_decay / 2.0
|
||||||
|
)
|
||||||
|
model_config = self._get_model_config()
|
||||||
|
|
||||||
|
return factory.build_retinanet(input_specs, model_config, l2_regularizer)
|
||||||
|
|
||||||
|
def save_checkpoint(self, checkpoint_path: str) -> None:
|
||||||
|
"""Saves a model checkpoint to checkpoint_path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_path: The path to save checkpoint.
|
||||||
|
"""
|
||||||
|
ckpt_items = {
|
||||||
|
'backbone': self._model.backbone,
|
||||||
|
'decoder': self._model.decoder,
|
||||||
|
'head': self._model.head,
|
||||||
|
}
|
||||||
|
tf.train.Checkpoint(**ckpt_items).write(checkpoint_path)
|
||||||
|
|
||||||
|
def load_checkpoint(
|
||||||
|
self, checkpoint_path: str, include_last_layer: bool = False
|
||||||
|
) -> None:
|
||||||
|
"""Loads a model checkpoint from checkpoint_path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_path: The path to load a checkpoint from.
|
||||||
|
include_last_layer: Whether or not to load the last classification layer.
|
||||||
|
The size of the last classification layer will differ depending on the
|
||||||
|
number of classes. When loading from the pre-trained checkpoint, this
|
||||||
|
parameter should be False to avoid shape mismatch on the last layer.
|
||||||
|
Defaults to False.
|
||||||
|
"""
|
||||||
|
dummy_input = tf.zeros([1] + self._model_spec.input_image_shape)
|
||||||
|
self._model(dummy_input, training=True)
|
||||||
|
if include_last_layer:
|
||||||
|
head = self._model.head
|
||||||
|
else:
|
||||||
|
head_classifier = tf.train.Checkpoint(
|
||||||
|
depthwise_kernel=self._model.head._classifier.depthwise_kernel # pylint:disable=protected-access
|
||||||
|
)
|
||||||
|
head_items = {
|
||||||
|
'_classifier': head_classifier,
|
||||||
|
'_box_norms': self._model.head._box_norms, # pylint:disable=protected-access
|
||||||
|
'_box_regressor': self._model.head._box_regressor, # pylint:disable=protected-access
|
||||||
|
'_cls_convs': self._model.head._cls_convs, # pylint:disable=protected-access
|
||||||
|
'_cls_norms': self._model.head._cls_norms, # pylint:disable=protected-access
|
||||||
|
'_box_convs': self._model.head._box_convs, # pylint:disable=protected-access
|
||||||
|
}
|
||||||
|
head = tf.train.Checkpoint(**head_items)
|
||||||
|
ckpt_items = {
|
||||||
|
'backbone': self._model.backbone,
|
||||||
|
'decoder': self._model.decoder,
|
||||||
|
'head': head,
|
||||||
|
}
|
||||||
|
ckpt = tf.train.Checkpoint(**ckpt_items)
|
||||||
|
status = ckpt.read(checkpoint_path)
|
||||||
|
status.expect_partial().assert_existing_objects_matched()
|
||||||
|
|
||||||
|
def convert_to_qat(self) -> None:
|
||||||
|
"""Converts the model to a QAT RetinaNet model."""
|
||||||
|
model = self._build_model()
|
||||||
|
dummy_input = tf.zeros([1] + self._model_spec.input_image_shape)
|
||||||
|
model(dummy_input, training=True)
|
||||||
|
model.set_weights(self._model.get_weights())
|
||||||
|
quantization_config = qat_common.Quantization(
|
||||||
|
quantize_detection_decoder=True, quantize_detection_head=True
|
||||||
|
)
|
||||||
|
model_config = self._get_model_config()
|
||||||
|
qat_model = qat_factory.build_qat_retinanet(
|
||||||
|
model, quantization_config, model_config
|
||||||
|
)
|
||||||
|
self._model = qat_model
|
||||||
|
|
||||||
|
def export_saved_model(self, save_path: str):
|
||||||
|
"""Exports a saved_model for tflite conversion.
|
||||||
|
|
||||||
|
The export process modifies the model in the following two ways:
|
||||||
|
1. Replaces the nms operation in the detection generator with a custom
|
||||||
|
TFLite compatible nms operation.
|
||||||
|
2. Wraps the model with a DetectionModule which handles pre-processing
|
||||||
|
and post-processing when running inference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
save_path: Path to export the saved model.
|
||||||
|
"""
|
||||||
|
generator_config = configs.retinanet.DetectionGenerator(
|
||||||
|
nms_version='tflite',
|
||||||
|
tflite_post_processing=configs.common.TFLitePostProcessingConfig(
|
||||||
|
nms_score_threshold=0,
|
||||||
|
max_detections=10,
|
||||||
|
max_classes_per_detection=1,
|
||||||
|
normalize_anchor_coordinates=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
tflite_post_processing_config = (
|
||||||
|
generator_config.tflite_post_processing.as_dict()
|
||||||
|
)
|
||||||
|
tflite_post_processing_config['input_image_size'] = (
|
||||||
|
self._model_spec.input_image_shape[0],
|
||||||
|
self._model_spec.input_image_shape[1],
|
||||||
|
)
|
||||||
|
detection_generator_obj = detection_generator.MultilevelDetectionGenerator(
|
||||||
|
apply_nms=generator_config.apply_nms,
|
||||||
|
pre_nms_top_k=generator_config.pre_nms_top_k,
|
||||||
|
pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
|
||||||
|
nms_iou_threshold=generator_config.nms_iou_threshold,
|
||||||
|
max_num_detections=generator_config.max_num_detections,
|
||||||
|
nms_version=generator_config.nms_version,
|
||||||
|
use_cpu_nms=generator_config.use_cpu_nms,
|
||||||
|
soft_nms_sigma=generator_config.soft_nms_sigma,
|
||||||
|
tflite_post_processing_config=tflite_post_processing_config,
|
||||||
|
return_decoded=generator_config.return_decoded,
|
||||||
|
use_class_agnostic_nms=generator_config.use_class_agnostic_nms,
|
||||||
|
)
|
||||||
|
model_config = self._get_model_config(generator_config)
|
||||||
|
model = retinanet_model.RetinaNetModel(
|
||||||
|
self._model.backbone,
|
||||||
|
self._model.decoder,
|
||||||
|
self._model.head,
|
||||||
|
detection_generator_obj,
|
||||||
|
min_level=model_config.min_level,
|
||||||
|
max_level=model_config.max_level,
|
||||||
|
num_scales=model_config.anchor.num_scales,
|
||||||
|
aspect_ratios=model_config.anchor.aspect_ratios,
|
||||||
|
anchor_size=model_config.anchor.anchor_size,
|
||||||
|
)
|
||||||
|
task_config = configs.retinanet.RetinaNetTask(model=model_config)
|
||||||
|
params = cfg.ExperimentConfig(
|
||||||
|
task=task_config,
|
||||||
|
)
|
||||||
|
export_module = detection.DetectionModule(
|
||||||
|
params=params,
|
||||||
|
batch_size=1,
|
||||||
|
input_image_size=self._model_spec.input_image_shape[:2],
|
||||||
|
input_type='tflite',
|
||||||
|
num_channels=self._model_spec.input_image_shape[2],
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
function_keys = {'tflite': tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY}
|
||||||
|
signatures = export_module.get_inference_signatures(function_keys)
|
||||||
|
|
||||||
|
tf.saved_model.save(export_module, save_path, signatures=signatures)
|
||||||
|
|
||||||
|
# The remaining method overrides are used to train this object detector model
|
||||||
|
# using model.fit().
|
||||||
|
def call(
|
||||||
|
self,
|
||||||
|
images: Union[tf.Tensor, Sequence[tf.Tensor]],
|
||||||
|
image_shape: Optional[tf.Tensor] = None,
|
||||||
|
anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
|
||||||
|
output_intermediate_features: bool = False,
|
||||||
|
training: Optional[bool] = None,
|
||||||
|
) -> Mapping[str, tf.Tensor]:
|
||||||
|
"""Overrides call from tf.keras.Model."""
|
||||||
|
return self._model(
|
||||||
|
images,
|
||||||
|
image_shape,
|
||||||
|
anchor_boxes,
|
||||||
|
output_intermediate_features,
|
||||||
|
training,
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
|
||||||
|
"""Overrides compute_loss from tf.keras.Model."""
|
||||||
|
cls_loss_fn = focal_loss.FocalLoss(
|
||||||
|
alpha=0.25, gamma=1.5, reduction=tf.keras.losses.Reduction.SUM
|
||||||
|
)
|
||||||
|
box_loss_fn = tf.keras.losses.Huber(
|
||||||
|
0.1, reduction=tf.keras.losses.Reduction.SUM
|
||||||
|
)
|
||||||
|
labels = y
|
||||||
|
outputs = y_pred
|
||||||
|
# Sums all positives in a batch for normalization and avoids zero
|
||||||
|
# num_positives_sum, which would lead to inf loss during training
|
||||||
|
cls_sample_weight = labels['cls_weights']
|
||||||
|
box_sample_weight = labels['box_weights']
|
||||||
|
num_positives = tf.reduce_sum(box_sample_weight) + 1.0
|
||||||
|
cls_sample_weight = cls_sample_weight / num_positives
|
||||||
|
box_sample_weight = box_sample_weight / num_positives
|
||||||
|
y_true_cls = loss_utils.multi_level_flatten(
|
||||||
|
labels['cls_targets'], last_dim=None
|
||||||
|
)
|
||||||
|
y_true_cls = tf.one_hot(y_true_cls, self._num_classes)
|
||||||
|
y_pred_cls = loss_utils.multi_level_flatten(
|
||||||
|
outputs['cls_outputs'], last_dim=self._num_classes
|
||||||
|
)
|
||||||
|
y_true_box = loss_utils.multi_level_flatten(
|
||||||
|
labels['box_targets'], last_dim=4
|
||||||
|
)
|
||||||
|
y_pred_box = loss_utils.multi_level_flatten(
|
||||||
|
outputs['box_outputs'], last_dim=4
|
||||||
|
)
|
||||||
|
|
||||||
|
cls_loss = cls_loss_fn(
|
||||||
|
y_true=y_true_cls, y_pred=y_pred_cls, sample_weight=cls_sample_weight
|
||||||
|
)
|
||||||
|
box_loss = box_loss_fn(
|
||||||
|
y_true=y_true_box, y_pred=y_pred_box, sample_weight=box_sample_weight
|
||||||
|
)
|
||||||
|
|
||||||
|
model_loss = cls_loss + 50 * box_loss
|
||||||
|
total_loss = model_loss
|
||||||
|
regularization_losses = self._model.losses
|
||||||
|
if regularization_losses:
|
||||||
|
reg_loss = tf.reduce_sum(regularization_losses)
|
||||||
|
total_loss = model_loss + reg_loss
|
||||||
|
all_losses = {
|
||||||
|
'total_loss': total_loss,
|
||||||
|
'cls_loss': cls_loss,
|
||||||
|
'box_loss': box_loss,
|
||||||
|
'model_loss': model_loss,
|
||||||
|
}
|
||||||
|
for m in self.metrics:
|
||||||
|
m.update_state(all_losses[m.name])
|
||||||
|
return total_loss
|
||||||
|
|
||||||
|
@property
|
||||||
|
def metrics(self):
|
||||||
|
"""Overrides metrics from tf.keras.Model."""
|
||||||
|
return self.loss_trackers
|
||||||
|
|
||||||
|
def compute_metrics(self, x, y, y_pred, sample_weight=None):
|
||||||
|
"""Overrides compute_metrics from tf.keras.Model."""
|
||||||
|
return self.get_metrics_result()
|
||||||
|
|
||||||
|
def train_step(self, data):
|
||||||
|
"""Overrides train_step from tf.keras.Model."""
|
||||||
|
tf.keras.backend.set_learning_phase(1)
|
||||||
|
x, y = data
|
||||||
|
# Run forward pass.
|
||||||
|
with tf.GradientTape() as tape:
|
||||||
|
y_pred = self(x, training=True)
|
||||||
|
loss = self.compute_loss(x, y, y_pred)
|
||||||
|
self._validate_target_and_loss(y, loss)
|
||||||
|
# Run backwards pass.
|
||||||
|
self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
|
||||||
|
return self.compute_metrics(x, y, y_pred)
|
||||||
|
|
||||||
|
def test_step(self, data):
|
||||||
|
"""Overrides test_step from tf.keras.Model."""
|
||||||
|
tf.keras.backend.set_learning_phase(0)
|
||||||
|
x, y = data
|
||||||
|
y_pred = self(
|
||||||
|
x,
|
||||||
|
anchor_boxes=y['anchor_boxes'],
|
||||||
|
image_shape=y['image_info'][:, 1, :],
|
||||||
|
training=False,
|
||||||
|
)
|
||||||
|
# Updates stateful loss metrics.
|
||||||
|
self.compute_loss(x, y, y_pred)
|
||||||
|
return self.compute_metrics(x, y, y_pred)
|
|
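The loss combination implemented by compute_loss above boils down to a focal classification loss plus a Huber box loss, both normalized by the number of positive anchors and with the box term weighted by 50. A simplified sketch, assuming targets and outputs have already been flattened the way loss_utils.multi_level_flatten produces them:

import tensorflow as tf
from official.vision.losses import focal_loss


def detection_loss_sketch(y_true_cls, y_pred_cls, y_true_box, y_pred_box,
                          cls_weights, box_weights, num_classes):
  """Illustrative recombination of the RetinaNet losses used above."""
  cls_loss_fn = focal_loss.FocalLoss(
      alpha=0.25, gamma=1.5, reduction=tf.keras.losses.Reduction.SUM
  )
  box_loss_fn = tf.keras.losses.Huber(
      0.1, reduction=tf.keras.losses.Reduction.SUM
  )

  # Normalize by the number of positive anchors; the +1 avoids an inf loss
  # when a batch happens to contain no positives.
  num_positives = tf.reduce_sum(box_weights) + 1.0
  cls_weights = cls_weights / num_positives
  box_weights = box_weights / num_positives

  cls_loss = cls_loss_fn(
      y_true=tf.one_hot(y_true_cls, num_classes),
      y_pred=y_pred_cls,
      sample_weight=cls_weights,
  )
  box_loss = box_loss_fn(
      y_true=y_true_box, y_pred=y_pred_box, sample_weight=box_weights
  )
  # The box term is up-weighted by 50, matching the model above; regularization
  # losses are added on top of this inside compute_loss.
  return cls_loss + 50.0 * box_loss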
@ -0,0 +1,28 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configurable model options for object detector models."""

import dataclasses


@dataclasses.dataclass
class ObjectDetectorModelOptions:
  """Configurable options for object detector model.

  Attributes:
    l2_weight_decay: L2 regularization penalty used in
      https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/L2.
  """

  l2_weight_decay: float = 3.0e-05
|
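For reference, l2_weight_decay is handed to Keras as an L2 regularizer when the RetinaNet is built; since tf.keras.regularizers.l2 computes l2 * sum(w**2), the conventional weight-decay value is halved first. A one-line sketch of that conversion:

import tensorflow as tf

l2_weight_decay = 3.0e-05  # default from ObjectDetectorModelOptions above
l2_regularizer = tf.keras.regularizers.l2(l2_weight_decay / 2.0)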
|
@ -0,0 +1,62 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Object detector model specification."""
import dataclasses
import enum
import functools
from typing import List

from mediapipe.model_maker.python.core.utils import file_util


MOBILENET_V2_FILES = file_util.DownloadedFiles(
    'object_detector/mobilenetv2',
    'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i256_ckpt.tar.gz',
    is_folder=True,
)


@dataclasses.dataclass
class ModelSpec(object):
  """Specification of object detector model."""

  # Mean and Stddev image preprocessing normalization values.
  mean_norm = (0.5,)
  stddev_norm = (0.5,)
  mean_rgb = (127.5,)
  stddev_rgb = (127.5,)

  downloaded_files: file_util.DownloadedFiles
  input_image_shape: List[int]


mobilenet_v2_spec = functools.partial(
    ModelSpec,
    downloaded_files=MOBILENET_V2_FILES,
    input_image_shape=[256, 256, 3],
)


@enum.unique
class SupportedModels(enum.Enum):
  """Predefined object detector model specs supported by Model Maker."""

  MOBILENET_V2 = mobilenet_v2_spec

  @classmethod
  def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
    """Gets the model spec from the input enum and initializes it."""
    if spec not in cls:
      raise TypeError(f'Unsupported object detector spec: {spec}')
    return spec.value()
|
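A short usage sketch for the spec registry above; the enum value is a functools.partial, so SupportedModels.get (or calling .value() directly) returns a freshly constructed ModelSpec:

from mediapipe.model_maker.python.vision.object_detector import model_spec as ms

spec = ms.SupportedModels.get(ms.SupportedModels.MOBILENET_V2)
print(spec.input_image_shape)            # [256, 256, 3]
print(spec.mean_norm, spec.stddev_norm)  # (0.5,) (0.5,)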
|
@ -0,0 +1,147 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from unittest import mock as unittest_mock
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision.object_detector import dataset as ds
|
||||||
|
from mediapipe.model_maker.python.vision.object_detector import model as model_lib
|
||||||
|
from mediapipe.model_maker.python.vision.object_detector import model_options as model_opt
|
||||||
|
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
|
||||||
|
from mediapipe.model_maker.python.vision.object_detector import preprocessor
|
||||||
|
from mediapipe.tasks.python.test import test_utils as task_test_utils
|
||||||
|
|
||||||
|
|
||||||
|
def _dicts_match(dict_1, dict_2):
|
||||||
|
for key in dict_1:
|
||||||
|
if key not in dict_2 or np.any(dict_1[key] != dict_2[key]):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _outputs_match(output1, output2):
|
||||||
|
return _dicts_match(
|
||||||
|
output1['cls_outputs'], output2['cls_outputs']
|
||||||
|
) and _dicts_match(output1['box_outputs'], output2['box_outputs'])
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectDetectorModelTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
dataset_folder = task_test_utils.get_test_data_path('coco_data')
|
||||||
|
cache_dir = self.create_tempdir()
|
||||||
|
self.data = ds.Dataset.from_coco_folder(dataset_folder, cache_dir=cache_dir)
|
||||||
|
self.model_spec = ms.SupportedModels.MOBILENET_V2.value()
|
||||||
|
self.preprocessor = preprocessor.Preprocessor(self.model_spec)
|
||||||
|
self.fake_inputs = np.random.uniform(
|
||||||
|
low=0, high=1, size=(1, 256, 256, 3)
|
||||||
|
).astype(np.float32)
|
||||||
|
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||||
|
# condition when downloading model since these tests may run in parallel.
|
||||||
|
mock_gettempdir = unittest_mock.patch.object(
|
||||||
|
tempfile,
|
||||||
|
'gettempdir',
|
||||||
|
return_value=self.create_tempdir(),
|
||||||
|
autospec=True,
|
||||||
|
)
|
||||||
|
self.mock_gettempdir = mock_gettempdir.start()
|
||||||
|
self.addCleanup(mock_gettempdir.stop)
|
||||||
|
|
||||||
|
def _create_model(self):
|
||||||
|
model_options = model_opt.ObjectDetectorModelOptions()
|
||||||
|
model = model_lib.ObjectDetectorModel(
|
||||||
|
self.model_spec, model_options, self.data.num_classes
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _train_model(self, model):
|
||||||
|
"""Helper to run a simple training run on the model."""
|
||||||
|
dataset = self.data.gen_tf_dataset(
|
||||||
|
batch_size=2,
|
||||||
|
is_training=True,
|
||||||
|
shuffle=False,
|
||||||
|
preprocess=self.preprocessor,
|
||||||
|
)
|
||||||
|
optimizer = tf.keras.optimizers.experimental.SGD(
|
||||||
|
learning_rate=0.03, momentum=0.9
|
||||||
|
)
|
||||||
|
model.compile(optimizer=optimizer)
|
||||||
|
model.fit(
|
||||||
|
x=dataset, epochs=2, steps_per_epoch=None, validation_data=dataset
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_model(self):
|
||||||
|
model = self._create_model()
|
||||||
|
outputs_before = model(self.fake_inputs, training=True)
|
||||||
|
self._train_model(model)
|
||||||
|
outputs_after = model(self.fake_inputs, training=True)
|
||||||
|
self.assertFalse(_outputs_match(outputs_before, outputs_after))
|
||||||
|
|
||||||
|
def test_model_convert_to_qat(self):
|
||||||
|
model_options = model_opt.ObjectDetectorModelOptions()
|
||||||
|
model = model_lib.ObjectDetectorModel(
|
||||||
|
self.model_spec, model_options, self.data.num_classes
|
||||||
|
)
|
||||||
|
outputs_before = model(self.fake_inputs, training=True)
|
||||||
|
model.convert_to_qat()
|
||||||
|
outputs_after = model(self.fake_inputs, training=True)
|
||||||
|
self.assertFalse(_outputs_match(outputs_before, outputs_after))
|
||||||
|
outputs_before = outputs_after
|
||||||
|
self._train_model(model)
|
||||||
|
outputs_after = model(self.fake_inputs, training=True)
|
||||||
|
self.assertFalse(_outputs_match(outputs_before, outputs_after))
|
||||||
|
|
||||||
|
def test_model_save_and_load_checkpoint(self):
|
||||||
|
model = self._create_model()
|
||||||
|
checkpoint_path = os.path.join(self.create_tempdir(), 'ckpt')
|
||||||
|
model.save_checkpoint(checkpoint_path)
|
||||||
|
data_checkpoint_file = checkpoint_path + '.data-00000-of-00001'
|
||||||
|
index_checkpoint_file = checkpoint_path + '.index'
|
||||||
|
self.assertTrue(os.path.exists(data_checkpoint_file))
|
||||||
|
self.assertTrue(os.path.exists(index_checkpoint_file))
|
||||||
|
self.assertGreater(os.path.getsize(data_checkpoint_file), 0)
|
||||||
|
self.assertGreater(os.path.getsize(index_checkpoint_file), 0)
|
||||||
|
outputs_before = model(self.fake_inputs, training=True)
|
||||||
|
|
||||||
|
# Check model output is different after training
|
||||||
|
self._train_model(model)
|
||||||
|
outputs_after = model(self.fake_inputs, training=True)
|
||||||
|
self.assertFalse(_outputs_match(outputs_before, outputs_after))
|
||||||
|
|
||||||
|
# Check model output is the same after loading previous checkpoint
|
||||||
|
model.load_checkpoint(checkpoint_path, include_last_layer=True)
|
||||||
|
outputs_after = model(self.fake_inputs, training=True)
|
||||||
|
self.assertTrue(_outputs_match(outputs_before, outputs_after))
|
||||||
|
|
||||||
|
def test_export_saved_model(self):
|
||||||
|
export_dir = self.create_tempdir()
|
||||||
|
export_path = os.path.join(export_dir, 'saved_model')
|
||||||
|
model = self._create_model()
|
||||||
|
model.export_saved_model(export_path)
|
||||||
|
self.assertTrue(os.path.exists(export_path))
|
||||||
|
self.assertGreater(os.path.getsize(export_path), 0)
|
||||||
|
model.convert_to_qat()
|
||||||
|
export_path = os.path.join(export_dir, 'saved_model_qat')
|
||||||
|
model.export_saved_model(export_path)
|
||||||
|
self.assertTrue(os.path.exists(export_path))
|
||||||
|
self.assertGreater(os.path.getsize(export_path), 0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
|
@ -0,0 +1,353 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""APIs to train object detector model."""
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.tasks import classifier
|
||||||
|
from mediapipe.model_maker.python.core.utils import model_util
|
||||||
|
from mediapipe.model_maker.python.core.utils import quantization
|
||||||
|
from mediapipe.model_maker.python.vision.object_detector import dataset as ds
|
||||||
|
from mediapipe.model_maker.python.vision.object_detector import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.vision.object_detector import model as model_lib
|
||||||
|
from mediapipe.model_maker.python.vision.object_detector import model_options as model_opt
|
||||||
|
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
|
||||||
|
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
|
||||||
|
from mediapipe.model_maker.python.vision.object_detector import preprocessor
|
||||||
|
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||||
|
from mediapipe.tasks.python.metadata.metadata_writers import object_detector as object_detector_writer
|
||||||
|
from official.vision.evaluation import coco_evaluator
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectDetector(classifier.Classifier):
|
||||||
|
"""ObjectDetector for building object detection model."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_spec: ms.ModelSpec,
|
||||||
|
label_names: List[str],
|
||||||
|
hparams: hp.HParams,
|
||||||
|
model_options: model_opt.ObjectDetectorModelOptions,
|
||||||
|
) -> None:
|
||||||
|
"""Initializes ObjectDetector class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_spec: Specifications for the model.
|
||||||
|
label_names: A list of label names for the classes.
|
||||||
|
hparams: The hyperparameters for training object detector.
|
||||||
|
model_options: Options for creating the object detector model.
|
||||||
|
"""
|
||||||
|
super().__init__(
|
||||||
|
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle
|
||||||
|
)
|
||||||
|
self._preprocessor = preprocessor.Preprocessor(model_spec)
|
||||||
|
self._hparams = hparams
|
||||||
|
self._model_options = model_options
|
||||||
|
self._optimizer = self._create_optimizer()
|
||||||
|
self._is_qat = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
train_data: ds.Dataset,
|
||||||
|
validation_data: ds.Dataset,
|
||||||
|
options: object_detector_options.ObjectDetectorOptions,
|
||||||
|
) -> 'ObjectDetector':
|
||||||
|
"""Creates and trains an ObjectDetector.
|
||||||
|
|
||||||
|
Loads data and trains the model based on data for object detection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
validation_data: Validation data.
|
||||||
|
options: Configurations for creating and training object detector.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An instance of ObjectDetector.
|
||||||
|
"""
|
||||||
|
if options.hparams is None:
|
||||||
|
options.hparams = hp.HParams()
|
||||||
|
|
||||||
|
if options.model_options is None:
|
||||||
|
options.model_options = model_opt.ObjectDetectorModelOptions()
|
||||||
|
|
||||||
|
spec = ms.SupportedModels.get(options.supported_model)
|
||||||
|
object_detector = cls(
|
||||||
|
model_spec=spec,
|
||||||
|
label_names=train_data.label_names,
|
||||||
|
hparams=options.hparams,
|
||||||
|
model_options=options.model_options,
|
||||||
|
)
|
||||||
|
object_detector._create_and_train_model(train_data, validation_data)
|
||||||
|
return object_detector
|
||||||
|
|
||||||
|
def _create_and_train_model(
|
||||||
|
self, train_data: ds.Dataset, validation_data: ds.Dataset
|
||||||
|
):
|
||||||
|
"""Creates and trains the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
validation_data: Validation data.
|
||||||
|
"""
|
||||||
|
self._create_model()
|
||||||
|
self._train_model(
|
||||||
|
train_data, validation_data, preprocessor=self._preprocessor
|
||||||
|
)
|
||||||
|
self._save_float_ckpt()
|
||||||
|
|
||||||
|
def _create_model(self) -> None:
|
||||||
|
"""Creates the object detector model."""
|
||||||
|
self._model = model_lib.ObjectDetectorModel(
|
||||||
|
model_spec=self._model_spec,
|
||||||
|
model_options=self._model_options,
|
||||||
|
num_classes=self._num_classes,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _save_float_ckpt(self) -> None:
|
||||||
|
"""Saves a checkpoint of the trained float model.
|
||||||
|
|
||||||
|
The default save path is {hparams.export_dir}/float_ckpt. Note that
|
||||||
|
`float_ckpt` represents a file prefix, not a directory. The resulting files
|
||||||
|
saved to {hparams.export_dir} will be:
|
||||||
|
- float_ckpt.data-00000-of-00001
|
||||||
|
- float_ckpt.index
|
||||||
|
"""
|
||||||
|
save_path = os.path.join(self._hparams.export_dir, 'float_ckpt')
|
||||||
|
if not os.path.exists(self._hparams.export_dir):
|
||||||
|
os.makedirs(self._hparams.export_dir)
|
||||||
|
self._model.save_checkpoint(save_path)
|
||||||
|
|
||||||
|
def restore_float_ckpt(self) -> None:
|
||||||
|
"""Loads a float checkpoint of the model from {hparams.export_dir}/float_ckpt.
|
||||||
|
|
||||||
|
The float checkpoint at {hparams.export_dir}/float_ckpt is automatically
|
||||||
|
saved after training an ObjectDetector using the `create` method. This
|
||||||
|
method is used to restore the trained float checkpoint state of the model in
|
||||||
|
order to run `quantization_aware_training` multiple times. Example usage:
|
||||||
|
|
||||||
|
# Train a model
|
||||||
|
model = object_detector.create(...)
|
||||||
|
# Run QAT
|
||||||
|
model.quantization_aware_training(...)
|
||||||
|
model.evaluate(...)
|
||||||
|
# Restore the float checkpoint to run QAT again
|
||||||
|
model.restore_float_ckpt()
|
||||||
|
# Run QAT with different parameters
|
||||||
|
model.quantization_aware_training(...)
|
||||||
|
model.evaluate(...)
|
||||||
|
"""
|
||||||
|
self._create_model()
|
||||||
|
self._model.load_checkpoint(
|
||||||
|
os.path.join(self._hparams.export_dir, 'float_ckpt'),
|
||||||
|
include_last_layer=True,
|
||||||
|
)
|
||||||
|
self._model.compile()
|
||||||
|
self._is_qat = False
|
||||||
|
|
||||||
|
# TODO: Refactor this method to utilize shared training function
|
||||||
|
def quantization_aware_training(
|
||||||
|
self,
|
||||||
|
train_data: ds.Dataset,
|
||||||
|
validation_data: ds.Dataset,
|
||||||
|
qat_hparams: hp.QATHParams,
|
||||||
|
) -> None:
|
||||||
|
"""Runs quantization aware training(QAT) on the model.
|
||||||
|
|
||||||
|
The QAT step happens after training a regular float model from the `create`
|
||||||
|
method. This additional step will fine-tune the model with a lower precision
|
||||||
|
in order to mimic the behavior of a quantized model. The resulting quantized
|
||||||
|
model generally has better performance than a model which is quantized
|
||||||
|
without running QAT. See the following link for more information:
|
||||||
|
- https://www.tensorflow.org/model_optimization/guide/quantization/training
|
||||||
|
|
||||||
|
Just like training the float model using the `create` method, the QAT step
|
||||||
|
also requires some manual tuning of hyperparameters. In order to run QAT
|
||||||
|
more than once for purposes such as hyperparameter tuning, use the
|
||||||
|
`restore_float_ckpt` method to restore the model state to the trained float
|
||||||
|
checkpoint without having to rerun the `create` method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training dataset.
|
||||||
|
validation_data: Validation dataset.
|
||||||
|
qat_hparams: Configuration for QAT.
|
||||||
|
"""
|
||||||
|
self._model.convert_to_qat()
|
||||||
|
learning_rate_fn = tf.keras.optimizers.schedules.ExponentialDecay(
|
||||||
|
qat_hparams.learning_rate * qat_hparams.batch_size / 256,
|
||||||
|
decay_steps=qat_hparams.decay_steps,
|
||||||
|
decay_rate=qat_hparams.decay_rate,
|
||||||
|
staircase=True,
|
||||||
|
)
|
||||||
|
optimizer = tf.keras.optimizers.experimental.SGD(
|
||||||
|
learning_rate=learning_rate_fn, momentum=0.9
|
||||||
|
)
|
||||||
|
if len(train_data) < qat_hparams.batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"The size of the train_data {len(train_data)} can't be smaller than"
|
||||||
|
f' batch_size {qat_hparams.batch_size}. To solve this problem, set'
|
||||||
|
' the batch_size smaller or increase the size of the train_data.'
|
||||||
|
)
|
||||||
|
|
||||||
|
train_dataset = train_data.gen_tf_dataset(
|
||||||
|
batch_size=qat_hparams.batch_size,
|
||||||
|
is_training=True,
|
||||||
|
shuffle=self._shuffle,
|
||||||
|
preprocess=self._preprocessor,
|
||||||
|
)
|
||||||
|
steps_per_epoch = model_util.get_steps_per_epoch(
|
||||||
|
steps_per_epoch=None,
|
||||||
|
batch_size=qat_hparams.batch_size,
|
||||||
|
train_data=train_data,
|
||||||
|
)
|
||||||
|
train_dataset = train_dataset.take(count=steps_per_epoch)
|
||||||
|
validation_dataset = validation_data.gen_tf_dataset(
|
||||||
|
batch_size=qat_hparams.batch_size,
|
||||||
|
is_training=False,
|
||||||
|
preprocess=self._preprocessor,
|
||||||
|
)
|
||||||
|
self._model.compile(optimizer=optimizer)
|
||||||
|
self._model.fit(
|
||||||
|
x=train_dataset,
|
||||||
|
epochs=qat_hparams.epochs,
|
||||||
|
steps_per_epoch=None,
|
||||||
|
validation_data=validation_dataset,
|
||||||
|
)
|
||||||
|
self._is_qat = True
|
||||||
|
|
||||||
|
def evaluate(
|
||||||
|
self, dataset: ds.Dataset, batch_size: int = 1
|
||||||
|
) -> Tuple[List[float], Dict[str, float]]:
|
||||||
|
"""Overrides Classifier.evaluate to calculate COCO metrics."""
|
||||||
|
dataset = dataset.gen_tf_dataset(
|
||||||
|
batch_size, is_training=False, preprocess=self._preprocessor
|
||||||
|
)
|
||||||
|
losses = self._model.evaluate(dataset)
|
||||||
|
coco_eval = coco_evaluator.COCOEvaluator(
|
||||||
|
annotation_file=None,
|
||||||
|
include_mask=False,
|
||||||
|
per_category_metrics=True,
|
||||||
|
max_num_eval_detections=100,
|
||||||
|
)
|
||||||
|
for batch in dataset:
|
||||||
|
x, y = batch
|
||||||
|
y_pred = self._model(
|
||||||
|
x,
|
||||||
|
anchor_boxes=y['anchor_boxes'],
|
||||||
|
image_shape=y['image_info'][:, 1, :],
|
||||||
|
training=False,
|
||||||
|
)
|
||||||
|
groundtruths = y['groundtruths']
|
||||||
|
y_pred['image_info'] = groundtruths['image_info']
|
||||||
|
y_pred['source_id'] = groundtruths['source_id']
|
||||||
|
coco_eval.update_state(groundtruths, y_pred)
|
||||||
|
coco_metrics = coco_eval.result()
|
||||||
|
return losses, coco_metrics
|
||||||
|
|
||||||
|
def export_model(
|
||||||
|
self,
|
||||||
|
model_name: str = 'model.tflite',
|
||||||
|
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
"""Converts and saves the model to a TFLite file with metadata included.
|
||||||
|
|
||||||
|
The model export format is automatically set based on whether or not
|
||||||
|
`quantization_aware_training`(QAT) was run. The model exports to float32 by
|
||||||
|
default and will export to an int8 quantized model if QAT was run. To export
|
||||||
|
a float32 model after running QAT, run `restore_float_ckpt` before this
|
||||||
|
method. For custom post-training quantization without QAT, use the
|
||||||
|
quantization_config parameter.
|
||||||
|
|
||||||
|
Note that only the TFLite file is needed for deployment. This function also
|
||||||
|
saves a metadata.json file to the same directory as the TFLite file which
|
||||||
|
can be used to interpret the metadata content in the TFLite file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: File name to save TFLite model with metadata. The full export
|
||||||
|
path is {self._hparams.export_dir}/{model_name}.
|
||||||
|
quantization_config: The configuration for model quantization. Note that
|
||||||
|
int8 quantization aware training is automatically applied when possible.
|
||||||
|
This parameter is used to specify other post-training quantization
|
||||||
|
options such as fp16 and int8 without QAT.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If a custom quantization_config is specified when the model
|
||||||
|
has quantization aware training enabled.
|
||||||
|
"""
|
||||||
|
if quantization_config:
|
||||||
|
if self._is_qat:
|
||||||
|
raise ValueError(
|
||||||
|
'Exporting a qat model with a custom quantization_config is not '
|
||||||
|
'supported.'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
'Exporting with custom post-training-quantization: ',
|
||||||
|
quantization_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if self._is_qat:
|
||||||
|
print('Exporting a qat int8 model')
|
||||||
|
quantization_config = quantization.QuantizationConfig(
|
||||||
|
inference_input_type=tf.uint8, inference_output_type=tf.uint8
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print('Exporting a floating point model')
|
||||||
|
|
||||||
|
tflite_file = os.path.join(self._hparams.export_dir, model_name)
|
||||||
|
metadata_file = os.path.join(self._hparams.export_dir, 'metadata.json')
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
save_path = os.path.join(temp_dir, 'saved_model')
|
||||||
|
self._model.export_saved_model(save_path)
|
||||||
|
converter = tf.lite.TFLiteConverter.from_saved_model(save_path)
|
||||||
|
if quantization_config:
|
||||||
|
converter = quantization_config.set_converter_with_quantization(
|
||||||
|
converter, preprocess=self._preprocessor
|
||||||
|
)
|
||||||
|
|
||||||
|
converter.target_spec.supported_ops = (tf.lite.OpsSet.TFLITE_BUILTINS,)
|
||||||
|
tflite_model = converter.convert()
|
||||||
|
|
||||||
|
writer = object_detector_writer.MetadataWriter.create(
|
||||||
|
tflite_model,
|
||||||
|
self._model_spec.mean_rgb,
|
||||||
|
self._model_spec.stddev_rgb,
|
||||||
|
labels=metadata_writer.Labels().add(list(self._label_names)),
|
||||||
|
)
|
||||||
|
tflite_model_with_metadata, metadata_json = writer.populate()
|
||||||
|
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
|
||||||
|
with open(metadata_file, 'w') as f:
|
||||||
|
f.write(metadata_json)
|
||||||
|
|
||||||
|
def _create_optimizer(self) -> tf.keras.optimizers.Optimizer:
|
||||||
|
"""Creates an optimizer with learning rate schedule for regular training.
|
||||||
|
|
||||||
|
Uses Keras PiecewiseConstantDecay schedule by default.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tf.keras.optimizers.Optimizer for model training.
|
||||||
|
"""
|
||||||
|
init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
|
||||||
|
lr_values = [init_lr] + [
|
||||||
|
init_lr * m for m in self._hparams.learning_rate_decay_multipliers
|
||||||
|
]
|
||||||
|
learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
|
||||||
|
self._hparams.learning_rate_boundaries, lr_values
|
||||||
|
)
|
||||||
|
return tf.keras.optimizers.experimental.SGD(
|
||||||
|
learning_rate=learning_rate_fn, momentum=0.9
|
||||||
|
)
|
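Because export_model refuses a custom quantization_config once QAT has been run, a post-training float16 export typically restores the float checkpoint first. A hedged sketch, assuming the core quantization module exposes a for_float16() helper (the helper name is an assumption, not part of this change):

from mediapipe.model_maker.python.core.utils import quantization

# `model` is an ObjectDetector previously returned by ObjectDetector.create(...).
# Go back to the trained float weights before applying a custom
# post-training quantization config.
model.restore_float_ckpt()

# Hypothetical helper; substitute however QuantizationConfig is built in your
# version of the quantization module.
fp16_config = quantization.QuantizationConfig.for_float16()
model.export_model(
    model_name='model_fp16.tflite',
    quantization_config=fp16_config,
)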
|
@ -0,0 +1,84 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Demo for making an object detector model by MediaPipe Model Maker."""

import os
# Dependency imports

from absl import app
from absl import flags
from absl import logging

from mediapipe.model_maker.python.vision import object_detector

FLAGS = flags.FLAGS

TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/object_detector/testdata/coco_data'


def define_flags() -> None:
  """Define flags for the object detection model maker demo."""
  flags.DEFINE_string(
      'export_dir', None, 'The directory to save exported files.'
  )
  flags.DEFINE_string(
      'input_data_dir',
      None,
      """The directory with input training data. If the training data is not
      specified, the pipeline will use the test dataset.""",
  )
  flags.DEFINE_bool('qat', True, 'Whether or not to do QAT.')
  flags.mark_flag_as_required('export_dir')


def run(data_dir: str, export_dir: str, qat: bool):
  """Runs demo."""
  data = object_detector.Dataset.from_coco_folder(data_dir)
  train_data, rest_data = data.split(0.6)
  validation_data, test_data = rest_data.split(0.5)

  hparams = object_detector.HParams(batch_size=1, export_dir=export_dir)
  options = object_detector.ObjectDetectorOptions(
      supported_model=object_detector.SupportedModels.MOBILENET_V2,
      hparams=hparams,
  )
  model = object_detector.ObjectDetector.create(
      train_data=train_data, validation_data=validation_data, options=options
  )
  loss, coco_metrics = model.evaluate(test_data, batch_size=1)
  print(f'Evaluation loss:{loss}, coco_metrics:{coco_metrics}')
  if qat:
    qat_hparams = object_detector.QATHParams(batch_size=1)
    model.quantization_aware_training(train_data, validation_data, qat_hparams)
    qat_loss, qat_coco_metrics = model.evaluate(test_data, batch_size=1)
    print(f'QAT Evaluation loss:{qat_loss}, coco_metrics:{qat_coco_metrics}')

  model.export_model()


def main(_) -> None:
  logging.set_verbosity(logging.INFO)

  if FLAGS.input_data_dir is None:
    data_dir = os.path.join(FLAGS.test_srcdir, TEST_DATA_DIR)
  else:
    data_dir = FLAGS.input_data_dir

  export_dir = os.path.expanduser(FLAGS.export_dir)
  run(data_dir=data_dir, export_dir=export_dir, qat=FLAGS.qat)


if __name__ == '__main__':
  define_flags()
  app.run(main)
|
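When the flag-driven entry point above is not convenient (for example from a notebook), run can be called directly; the paths below are placeholders:

from mediapipe.model_maker.python.vision.object_detector import object_detector_demo

# Placeholder paths; point these at a COCO-format folder and an output directory.
object_detector_demo.run(
    data_dir='/tmp/coco_data',
    export_dir='/tmp/object_detector_export',
    qat=False,
)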
|
@ -0,0 +1,36 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Options for building object detector."""

import dataclasses
from typing import Optional

from mediapipe.model_maker.python.vision.object_detector import hyperparameters
from mediapipe.model_maker.python.vision.object_detector import model_options as model_opt
from mediapipe.model_maker.python.vision.object_detector import model_spec


@dataclasses.dataclass
class ObjectDetectorOptions:
  """Configurable options for building object detector.

  Attributes:
    supported_model: A model from the SupportedModels enum.
    model_options: A set of options for configuring the selected model.
    hparams: A set of hyperparameters used to train the object detector.
  """

  supported_model: model_spec.SupportedModels
  model_options: Optional[model_opt.ObjectDetectorModelOptions] = None
  hparams: Optional[hyperparameters.HParams] = None
|
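A minimal construction example for the options dataclass above; when model_options or hparams is left as None, ObjectDetector.create fills in the defaults defined earlier in this change (the export_dir value is a placeholder):

from mediapipe.model_maker.python.vision.object_detector import hyperparameters
from mediapipe.model_maker.python.vision.object_detector import model_spec
from mediapipe.model_maker.python.vision.object_detector import object_detector_options

options = object_detector_options.ObjectDetectorOptions(
    supported_model=model_spec.SupportedModels.MOBILENET_V2,
    hparams=hyperparameters.HParams(
        batch_size=8, epochs=20, export_dir='/tmp/object_detector_export'
    ),
)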
|
@ -0,0 +1,121 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from unittest import mock as unittest_mock
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision.object_detector import dataset
|
||||||
|
from mediapipe.model_maker.python.vision.object_detector import hyperparameters
|
||||||
|
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
|
||||||
|
from mediapipe.model_maker.python.vision.object_detector import object_detector
|
||||||
|
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
|
||||||
|
from mediapipe.tasks.python.test import test_utils as task_test_utils
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
dataset_folder = task_test_utils.get_test_data_path('coco_data')
|
||||||
|
cache_dir = self.create_tempdir()
|
||||||
|
self.data = dataset.Dataset.from_coco_folder(
|
||||||
|
dataset_folder, cache_dir=cache_dir
|
||||||
|
)
|
||||||
|
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||||
|
# condition when downloading model since these tests may run in parallel.
|
||||||
|
mock_gettempdir = unittest_mock.patch.object(
|
||||||
|
tempfile,
|
||||||
|
'gettempdir',
|
||||||
|
return_value=self.create_tempdir(),
|
||||||
|
autospec=True,
|
||||||
|
)
|
||||||
|
self.mock_gettempdir = mock_gettempdir.start()
|
||||||
|
self.addCleanup(mock_gettempdir.stop)
|
||||||
|
|
||||||
|
def test_object_detector(self):
|
||||||
|
hparams = hyperparameters.HParams(
|
||||||
|
epochs=10,
|
||||||
|
batch_size=2,
|
||||||
|
learning_rate=0.9,
|
||||||
|
shuffle=False,
|
||||||
|
export_dir=self.create_tempdir(),
|
||||||
|
)
|
||||||
|
options = object_detector_options.ObjectDetectorOptions(
|
||||||
|
supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams
|
||||||
|
)
|
||||||
|
# Test `create`
|
||||||
|
model = object_detector.ObjectDetector.create(
|
||||||
|
train_data=self.data, validation_data=self.data, options=options
|
||||||
|
)
|
||||||
|
losses, coco_metrics = model.evaluate(self.data)
|
||||||
|
self._assert_ap_greater(coco_metrics)
|
||||||
|
self.assertFalse(model._is_qat)
|
||||||
|
# Test float export_model
|
||||||
|
model.export_model()
|
||||||
|
output_metadata_file = os.path.join(
|
||||||
|
options.hparams.export_dir, 'metadata.json'
|
||||||
|
)
|
||||||
|
output_tflite_file = os.path.join(
|
||||||
|
options.hparams.export_dir, 'model.tflite'
|
||||||
|
)
|
||||||
|
print('Exported float model size:', os.path.getsize(output_tflite_file))
|
||||||
|
self.assertTrue(os.path.exists(output_tflite_file))
|
||||||
|
self.assertGreater(os.path.getsize(output_tflite_file), 0)
|
||||||
|
self.assertTrue(os.path.exists(output_metadata_file))
|
||||||
|
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
||||||
|
|
||||||
|
# Test `quantization_aware_training`
|
||||||
|
qat_hparams = hyperparameters.QATHParams(
|
||||||
|
learning_rate=0.9,
|
||||||
|
batch_size=2,
|
||||||
|
epochs=5,
|
||||||
|
decay_steps=6,
|
||||||
|
decay_rate=0.96,
|
||||||
|
)
|
||||||
|
model.quantization_aware_training(self.data, self.data, qat_hparams)
|
||||||
|
qat_losses, qat_coco_metrics = model.evaluate(self.data)
|
||||||
|
self._assert_ap_greater(qat_coco_metrics)
|
||||||
|
self.assertNotAllEqual(losses, qat_losses)
|
||||||
|
self.assertTrue(model._is_qat)
|
||||||
|
model.export_model('model_qat.tflite')
|
||||||
|
output_metadata_file = os.path.join(
|
||||||
|
options.hparams.export_dir, 'metadata.json'
|
||||||
|
)
|
||||||
|
output_tflite_file = os.path.join(
|
||||||
|
options.hparams.export_dir, 'model_qat.tflite'
|
||||||
|
)
|
||||||
|
print('Exported QAT model size:', os.path.getsize(output_tflite_file))
|
||||||
|
self.assertTrue(os.path.exists(output_tflite_file))
|
||||||
|
self.assertGreater(os.path.getsize(output_tflite_file), 0)
|
||||||
|
self.assertLess(os.path.getsize(output_tflite_file), 3500000)
|
||||||
|
self.assertTrue(os.path.exists(output_metadata_file))
|
||||||
|
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
||||||
|
|
||||||
|
# Load float ckpt test
|
||||||
|
model.restore_float_ckpt()
|
||||||
|
losses_2, _ = model.evaluate(self.data)
|
||||||
|
self.assertAllEqual(losses, losses_2)
|
||||||
|
self.assertNotAllEqual(qat_losses, losses_2)
|
||||||
|
self.assertFalse(model._is_qat)
|
||||||
|
|
||||||
|
def _assert_ap_greater(self, coco_metrics, threshold=0.0):
|
||||||
|
self.assertGreaterEqual(coco_metrics['AP'], threshold)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
|
@ -0,0 +1,163 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preprocessor for object detector."""

from typing import Any, Mapping, Tuple

import tensorflow as tf

from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
from official.vision.dataloaders import utils
from official.vision.ops import anchor
from official.vision.ops import box_ops
from official.vision.ops import preprocess_ops


# TODO: Combine preprocessing logic with image_preprocessor.
class Preprocessor(object):
  """Preprocessor for object detector."""

  def __init__(self, model_spec: ms.ModelSpec):
    """Initializes a Preprocessor."""
    self._mean_norm = model_spec.mean_norm
    self._stddev_norm = model_spec.stddev_norm
    self._output_size = model_spec.input_image_shape[:2]
    self._min_level = 3
    self._max_level = 7
    self._num_scales = 3
    self._aspect_ratios = [0.5, 1, 2]
    self._anchor_size = 3
    self._dtype = tf.float32
    self._match_threshold = 0.5
    self._unmatched_threshold = 0.5
    self._aug_scale_min = 0.5
    self._aug_scale_max = 2.0
    self._max_num_instances = 100

  def __call__(
      self, data: Mapping[str, Any], is_training: bool = True
  ) -> Tuple[tf.Tensor, Mapping[str, Any]]:
    """Runs the preprocessor on an example.

    The data dict should always contain the following keys:
      - image
      - groundtruth_classes
      - groundtruth_boxes
      - groundtruth_is_crowd
    Additional keys are needed when is_training is set to False:
      - groundtruth_area
      - source_id
      - height
      - width

    Args:
      data: A dict of object detector inputs.
      is_training: Whether or not the data is used for training.

    Returns:
      A tuple of (image, labels) where image is a Tensor and labels is a dict.
    """
    classes = data['groundtruth_classes']
    boxes = data['groundtruth_boxes']

    # Get original image.
    image = data['image']
    image_shape = tf.shape(input=image)[0:2]

    # Normalize image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, self._mean_norm, self._stddev_norm
    )

    # Flip image randomly during training.
    if is_training:
      image, boxes, _ = preprocess_ops.random_horizontal_flip(image, boxes)

    # Convert boxes from normalized coordinates to pixel coordinates.
    boxes = box_ops.denormalize_boxes(boxes, image_shape)

    # Resize and crop image.
    image, image_info = preprocess_ops.resize_and_crop_image(
        image,
        self._output_size,
        padded_size=preprocess_ops.compute_padded_size(
            self._output_size, 2**self._max_level
        ),
        aug_scale_min=(self._aug_scale_min if is_training else 1.0),
        aug_scale_max=(self._aug_scale_max if is_training else 1.0),
    )
    image_height, image_width, _ = image.get_shape().as_list()

    # Resize and crop boxes.
    image_scale = image_info[2, :]
    offset = image_info[3, :]
    boxes = preprocess_ops.resize_and_crop_boxes(
        boxes, image_scale, image_info[1, :], offset
    )
    # Filter out ground-truth boxes that are all zeros.
    indices = box_ops.get_non_empty_box_indices(boxes)
    boxes = tf.gather(boxes, indices)
    classes = tf.gather(classes, indices)

    # Assign anchors.
    input_anchor = anchor.build_anchor_generator(
        min_level=self._min_level,
        max_level=self._max_level,
        num_scales=self._num_scales,
        aspect_ratios=self._aspect_ratios,
        anchor_size=self._anchor_size,
    )
    anchor_boxes = input_anchor(image_size=(image_height, image_width))
    anchor_labeler = anchor.AnchorLabeler(
        self._match_threshold, self._unmatched_threshold
    )
    (cls_targets, box_targets, _, cls_weights, box_weights) = (
        anchor_labeler.label_anchors(
            anchor_boxes, boxes, tf.expand_dims(classes, axis=1)
        )
    )

    # Cast input image to desired data type.
    image = tf.cast(image, dtype=self._dtype)

    # Pack labels for model_fn outputs.
    labels = {
        'cls_targets': cls_targets,
        'box_targets': box_targets,
        'anchor_boxes': anchor_boxes,
        'cls_weights': cls_weights,
        'box_weights': box_weights,
        'image_info': image_info,
    }
    if not is_training:
      groundtruths = {
          'source_id': data['source_id'],
          'height': data['height'],
          'width': data['width'],
          'num_detections': tf.shape(data['groundtruth_classes']),
          'image_info': image_info,
          'boxes': box_ops.denormalize_boxes(
              data['groundtruth_boxes'], image_shape
          ),
          'classes': data['groundtruth_classes'],
          'areas': data['groundtruth_area'],
          'is_crowds': tf.cast(data['groundtruth_is_crowd'], tf.int32),
      }
      groundtruths['source_id'] = utils.process_source_id(
          groundtruths['source_id']
      )
      groundtruths = utils.pad_groundtruths_to_fixed_size(
          groundtruths, self._max_num_instances
      )
      labels.update({'groundtruths': groundtruths})
    return image, labels
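The preprocessor above is a callable over a single example dict, so it can be mapped over a tf.data pipeline. A minimal sketch (not part of the change; `raw_dataset` is a hypothetical dataset yielding dicts with the keys listed in the `__call__` docstring):

# Sketch only: `raw_dataset` is assumed to yield dicts with keys such as
# 'image', 'groundtruth_boxes', 'groundtruth_classes', 'groundtruth_is_crowd'.
spec = ms.SupportedModels.MOBILENET_V2.value()
preprocessor = Preprocessor(spec)

train_ds = (
    raw_dataset.map(
        lambda example: preprocessor(example, is_training=True),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    .batch(8)
    .prefetch(tf.data.AUTOTUNE)
)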
@ -0,0 +1,158 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

from absl.testing import parameterized
import tensorflow as tf

from mediapipe.model_maker.python.vision.core import test_utils
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
from mediapipe.model_maker.python.vision.object_detector import preprocessor as preprocessor_lib


class DatasetTest(tf.test.TestCase, parameterized.TestCase):
  MAX_IMAGE_SIZE = 360
  OUTPUT_SIZE = 256
  NUM_CLASSES = 10
  NUM_EXAMPLES = 3
  MIN_LEVEL = 3
  MAX_LEVEL = 7
  NUM_SCALES = 3
  ASPECT_RATIOS = [0.5, 1, 2]
  MAX_NUM_INSTANCES = 100

  def _get_rand_example(self):
    num_annotations = random.randint(1, 3)
    bboxes, classes, is_crowds = [], [], []
    image_size = random.randint(10, self.MAX_IMAGE_SIZE + 1)
    rgb = [random.uniform(0, 255) for _ in range(3)]
    image = test_utils.fill_image(rgb, image_size)
    for _ in range(num_annotations):
      x1, x2 = random.uniform(0, image_size), random.uniform(0, image_size)
      y1, y2 = random.uniform(0, image_size), random.uniform(0, image_size)
      bbox = [min(x1, x2), min(y1, y2), abs(x1 - x2), abs(y1 - y2)]
      bboxes.append(bbox)
      classes.append(random.randint(0, self.NUM_CLASSES - 1))
      is_crowds.append(0)
    return {
        'image': tf.cast(image, dtype=tf.float32),
        'groundtruth_boxes': tf.cast(bboxes, dtype=tf.float32),
        'groundtruth_classes': tf.cast(classes, dtype=tf.int64),
        'groundtruth_is_crowd': tf.cast(is_crowds, dtype=tf.bool),
        'groundtruth_area': tf.cast(is_crowds, dtype=tf.float32),
        'source_id': tf.cast(1, dtype=tf.int64),
        'height': tf.cast(image_size, dtype=tf.int64),
        'width': tf.cast(image_size, dtype=tf.int64),
    }

  def setUp(self):
    super().setUp()
    dataset = [self._get_rand_example() for _ in range(self.NUM_EXAMPLES)]

    def my_generator(data):
      for item in data:
        yield item

    self.dataset = tf.data.Dataset.from_generator(
        lambda: my_generator(dataset),
        output_types={
            'image': tf.float32,
            'groundtruth_classes': tf.int64,
            'groundtruth_boxes': tf.float32,
            'groundtruth_is_crowd': tf.bool,
            'groundtruth_area': tf.float32,
            'source_id': tf.int64,
            'height': tf.int64,
            'width': tf.int64,
        },
    )

  @parameterized.named_parameters(
      dict(
          testcase_name='training',
          is_training=True,
      ),
      dict(
          testcase_name='evaluation',
          is_training=False,
      ),
  )
  def test_preprocessor(self, is_training):
    model_spec = ms.SupportedModels.MOBILENET_V2.value()
    labels_keys = [
        'cls_targets',
        'box_targets',
        'anchor_boxes',
        'cls_weights',
        'box_weights',
        'image_info',
    ]
    if not is_training:
      labels_keys.append('groundtruths')
    preprocessor = preprocessor_lib.Preprocessor(model_spec)
    for example in self.dataset:
      result = preprocessor(example, is_training=is_training)
      image, labels = result
      self.assertAllEqual(image.shape, (256, 256, 3))
      self.assertCountEqual(labels_keys, labels.keys())
      np_labels = tf.nest.map_structure(lambda x: x.numpy(), labels)
      # Checks shapes of `image_info` and `anchor_boxes`.
      self.assertEqual(np_labels['image_info'].shape, (4, 2))
      n_anchors = 0
      for level in range(self.MIN_LEVEL, self.MAX_LEVEL + 1):
        stride = 2**level
        output_size_l = [self.OUTPUT_SIZE / stride, self.OUTPUT_SIZE / stride]
        anchors_per_location = self.NUM_SCALES * len(self.ASPECT_RATIOS)
        self.assertEqual(
            list(np_labels['anchor_boxes'][str(level)].shape),
            [output_size_l[0], output_size_l[1], 4 * anchors_per_location],
        )
        n_anchors += output_size_l[0] * output_size_l[1] * anchors_per_location
      # Checks shapes of training objectives.
      self.assertEqual(np_labels['cls_weights'].shape, (int(n_anchors),))
      for level in range(self.MIN_LEVEL, self.MAX_LEVEL + 1):
        stride = 2**level
        output_size_l = [self.OUTPUT_SIZE / stride, self.OUTPUT_SIZE / stride]
        anchors_per_location = self.NUM_SCALES * len(self.ASPECT_RATIOS)
        self.assertEqual(
            list(np_labels['cls_targets'][str(level)].shape),
            [output_size_l[0], output_size_l[1], anchors_per_location],
        )
        self.assertEqual(
            list(np_labels['box_targets'][str(level)].shape),
            [output_size_l[0], output_size_l[1], 4 * anchors_per_location],
        )
      # Checks shape of groundtruths for eval.
      if not is_training:
        self.assertEqual(np_labels['groundtruths']['source_id'].shape, ())
        self.assertEqual(
            np_labels['groundtruths']['classes'].shape,
            (self.MAX_NUM_INSTANCES,),
        )
        self.assertEqual(
            np_labels['groundtruths']['boxes'].shape,
            (self.MAX_NUM_INSTANCES, 4),
        )
        self.assertEqual(
            np_labels['groundtruths']['areas'].shape, (self.MAX_NUM_INSTANCES,)
        )
        self.assertEqual(
            np_labels['groundtruths']['is_crowds'].shape,
            (self.MAX_NUM_INSTANCES,),
        )


if __name__ == '__main__':
  tf.test.main()
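As a cross-check of the shapes asserted above: with a 256x256 input, levels 3 through 7, 3 scales, and 3 aspect ratios, the total anchor count works out to 12,276, so `cls_weights` is expected to have shape (12276,). A quick calculation (not part of the change):

# 3 scales x 3 aspect ratios = 9 anchors per feature-map location.
output_size = 256
anchors_per_location = 3 * 3
n_anchors = sum(
    (output_size // 2**level) ** 2 * anchors_per_location
    for level in range(3, 8)  # levels 3..7, strides 8..128
)
print(n_anchors)  # (32*32 + 16*16 + 8*8 + 4*4 + 2*2) * 9 = 12276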
BIN mediapipe/model_maker/python/vision/object_detector/testdata/coco_data/images/000000000072.jpg (vendored, new file, 81 KiB)
BIN mediapipe/model_maker/python/vision/object_detector/testdata/coco_data/images/000000000078.jpg (vendored, new file, 72 KiB)
BIN mediapipe/model_maker/python/vision/object_detector/testdata/coco_data/images/000000000315.jpg (vendored, new file, 53 KiB)
BIN mediapipe/model_maker/python/vision/object_detector/testdata/coco_data/images/000000000431.jpg (vendored, new file, 42 KiB)
BIN mediapipe/model_maker/python/vision/object_detector/testdata/coco_data/images/000000000446.jpg (vendored, new file, 36 KiB)
mediapipe/model_maker/python/vision/object_detector/testdata/coco_data/labels.json (vendored, new file)
@ -0,0 +1,34 @@
<?xml version="1.0" encoding="utf-8"?>
<annotation>
  <folder>images</folder>
  <filename>37ca2a3d-IMG_0520.jpg</filename>
  <source>
    <database>MyDatabase</database>
    <annotation>COCO2017</annotation>
    <image>flickr</image>
    <flickrid>NULL</flickrid>
    <annotator>1</annotator>
  </source>
  <owner>
    <flickrid>NULL</flickrid>
    <name>Label Studio</name>
  </owner>
  <size>
    <width>800</width>
    <height>600</height>
    <depth>3</depth>
  </size>
  <segmented>0</segmented>
  <object>
    <name>pig_android</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
      <xmin>242</xmin>
      <ymin>17</ymin>
      <xmax>556</xmax>
      <ymax>476</ymax>
    </bndbox>
  </object>
</annotation>
@ -0,0 +1,34 @@
<?xml version="1.0" encoding="utf-8"?>
<annotation>
  <folder>images</folder>
  <filename>3d3382d3-IMG_0514.jpg</filename>
  <source>
    <database>MyDatabase</database>
    <annotation>COCO2017</annotation>
    <image>flickr</image>
    <flickrid>NULL</flickrid>
    <annotator>1</annotator>
  </source>
  <owner>
    <flickrid>NULL</flickrid>
    <name>Label Studio</name>
  </owner>
  <size>
    <width>800</width>
    <height>600</height>
    <depth>3</depth>
  </size>
  <segmented>0</segmented>
  <object>
    <name>android</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
      <xmin>306</xmin>
      <ymin>130</ymin>
      <xmax>550</xmax>
      <ymax>471</ymax>
    </bndbox>
  </object>
</annotation>
@ -0,0 +1,46 @@
<?xml version="1.0" encoding="utf-8"?>
<annotation>
  <folder>images</folder>
  <filename>d1c65813-IMG_0546.jpg</filename>
  <source>
    <database>MyDatabase</database>
    <annotation>COCO2017</annotation>
    <image>flickr</image>
    <flickrid>NULL</flickrid>
    <annotator>1</annotator>
  </source>
  <owner>
    <flickrid>NULL</flickrid>
    <name>Label Studio</name>
  </owner>
  <size>
    <width>800</width>
    <height>600</height>
    <depth>3</depth>
  </size>
  <segmented>0</segmented>
  <object>
    <name>android</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
      <xmin>93</xmin>
      <ymin>101</ymin>
      <xmax>358</xmax>
      <ymax>378</ymax>
    </bndbox>
  </object>
  <object>
    <name>pig_android</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
      <xmin>438</xmin>
      <ymin>28</ymin>
      <xmax>654</xmax>
      <ymax>296</ymax>
    </bndbox>
  </object>
</annotation>
@ -0,0 +1,46 @@
<?xml version="1.0" encoding="utf-8"?>
<annotation>
  <folder>images</folder>
  <filename>d86b20e0-IMG_0509.jpg</filename>
  <source>
    <database>MyDatabase</database>
    <annotation>COCO2017</annotation>
    <image>flickr</image>
    <flickrid>NULL</flickrid>
    <annotator>1</annotator>
  </source>
  <owner>
    <flickrid>NULL</flickrid>
    <name>Label Studio</name>
  </owner>
  <size>
    <width>800</width>
    <height>600</height>
    <depth>3</depth>
  </size>
  <segmented>0</segmented>
  <object>
    <name>android</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
      <xmin>7</xmin>
      <ymin>122</ymin>
      <xmax>296</xmax>
      <ymax>402</ymax>
    </bndbox>
  </object>
  <object>
    <name>pig_android</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
      <xmin>523</xmin>
      <ymin>69</ymin>
      <xmax>723</xmax>
      <ymax>329</ymax>
    </bndbox>
  </object>
</annotation>
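The annotation files above follow the PASCAL VOC XML layout. A minimal sketch (not part of the change) of reading the objects out of one such file with the Python standard library; the path below is hypothetical:

import xml.etree.ElementTree as ET

# Hypothetical path; any of the annotation files above has the same structure.
root = ET.parse('annotations/37ca2a3d-IMG_0520.xml').getroot()
for obj in root.iter('object'):
  name = obj.findtext('name')  # e.g. 'pig_android' or 'android'
  box = obj.find('bndbox')
  xmin, ymin, xmax, ymax = (
      int(box.findtext(tag)) for tag in ('xmin', 'ymin', 'xmax', 'ymax')
  )
  print(name, (xmin, ymin, xmax, ymax))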