Open Source Object Detector

PiperOrigin-RevId: 519201221

@@ -17,6 +17,7 @@ from mediapipe.model_maker.python.core.utils import quantization
 from mediapipe.model_maker.python.vision import image_classifier
 from mediapipe.model_maker.python.vision import gesture_recognizer
 from mediapipe.model_maker.python.text import text_classifier
+from mediapipe.model_maker.python.vision import object_detector

 # Remove duplicated and non-public API
 del python
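A minimal usage sketch of the newly exported module (hedged: the paths and the
cache directory are illustrative, and the ObjectDetector training API itself
lives in object_detector.py rather than in the excerpts below):

    from mediapipe.model_maker.python.vision import object_detector

    train_data = object_detector.Dataset.from_coco_folder(
        '/tmp/coco_train', cache_dir='/tmp/od_cache/train')
    print(train_data.num_classes, train_data.label_names)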

mediapipe/model_maker/python/vision/object_detector/BUILD (new file, 195 lines)
@@ -0,0 +1,195 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python strict library and test compatibility macro.
|
||||
# Placeholder for internal Python strict test compatibility macro.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe:__subpackages__"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "object_detector_import",
|
||||
srcs = ["__init__.py"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":dataset",
|
||||
":hyperparameters",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
":object_detector",
|
||||
":object_detector_options",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "object_detector_demo",
|
||||
srcs = ["object_detector_demo.py"],
|
||||
data = [":testdata"],
|
||||
python_version = "PY3",
|
||||
tags = ["requires-net:external"],
|
||||
deps = [":object_detector_import"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "testdata",
|
||||
srcs = glob([
|
||||
"testdata/**",
|
||||
]),
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "dataset",
|
||||
srcs = ["dataset.py"],
|
||||
deps = [
|
||||
":dataset_util",
|
||||
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "dataset_test",
|
||||
srcs = ["dataset_test.py"],
|
||||
data = [":testdata"],
|
||||
deps = [
|
||||
":dataset",
|
||||
"//mediapipe/model_maker/python/vision/core:image_utils",
|
||||
"//mediapipe/model_maker/python/vision/core:test_utils",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "dataset_util",
|
||||
srcs = ["dataset_util.py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "dataset_util_test",
|
||||
srcs = ["dataset_util_test.py"],
|
||||
data = [":testdata"],
|
||||
deps = [
|
||||
":dataset_util",
|
||||
"//mediapipe/model_maker/python/vision/core:test_utils",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "hyperparameters",
|
||||
srcs = ["hyperparameters.py"],
|
||||
deps = [
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "preprocessor",
|
||||
srcs = ["preprocessor.py"],
|
||||
deps = [":model_spec"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "preprocessor_test",
|
||||
srcs = ["preprocessor_test.py"],
|
||||
deps = [
|
||||
":model_spec",
|
||||
":preprocessor",
|
||||
"//mediapipe/model_maker/python/vision/core:test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model",
|
||||
srcs = ["model.py"],
|
||||
deps = [
|
||||
":model_options",
|
||||
":model_spec",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "model_test",
|
||||
size = "large",
|
||||
srcs = ["model_test.py"],
|
||||
data = [":testdata"],
|
||||
shard_count = 4,
|
||||
tags = ["requires-net:external"],
|
||||
deps = [
|
||||
":dataset",
|
||||
":model",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
":preprocessor",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model_options",
|
||||
srcs = ["model_options.py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model_spec",
|
||||
srcs = ["model_spec.py"],
|
||||
deps = ["//mediapipe/model_maker/python/core/utils:file_util"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "object_detector",
|
||||
srcs = ["object_detector.py"],
|
||||
deps = [
|
||||
":dataset",
|
||||
":hyperparameters",
|
||||
":model",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
":object_detector_options",
|
||||
":preprocessor",
|
||||
"//mediapipe/model_maker/python/core/tasks:classifier",
|
||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:object_detector",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "object_detector_test",
|
||||
size = "enormous",
|
||||
srcs = ["object_detector_test.py"],
|
||||
data = [":testdata"],
|
||||
tags = ["requires-net:external"],
|
||||
deps = [
|
||||
":dataset",
|
||||
":hyperparameters",
|
||||
":model_spec",
|
||||
":object_detector",
|
||||
":object_detector_options",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "object_detector_options",
|
||||
srcs = ["object_detector_options.py"],
|
||||
deps = [
|
||||
":hyperparameters",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
],
|
||||
)
|
|
mediapipe/model_maker/python/vision/object_detector/__init__.py (new file, 30 lines)
@@ -0,0 +1,30 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""MediaPipe Model Maker Python Public API For Object Detector."""
|
||||
|
||||
from mediapipe.model_maker.python.vision.object_detector import dataset
|
||||
from mediapipe.model_maker.python.vision.object_detector import hyperparameters
|
||||
from mediapipe.model_maker.python.vision.object_detector import model_options
|
||||
from mediapipe.model_maker.python.vision.object_detector import model_spec
|
||||
from mediapipe.model_maker.python.vision.object_detector import object_detector
|
||||
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
|
||||
|
||||
ObjectDetector = object_detector.ObjectDetector
|
||||
ModelOptions = model_options.ObjectDetectorModelOptions
|
||||
ModelSpec = model_spec.ModelSpec
|
||||
SupportedModels = model_spec.SupportedModels
|
||||
HParams = hyperparameters.HParams
|
||||
QATHParams = hyperparameters.QATHParams
|
||||
Dataset = dataset.Dataset
|
||||
ObjectDetectorOptions = object_detector_options.ObjectDetectorOptions
|
mediapipe/model_maker/python/vision/object_detector/dataset.py (new file, 179 lines)
@@ -0,0 +1,179 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Object detector dataset library."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import tensorflow as tf
|
||||
import yaml
|
||||
|
||||
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||
from mediapipe.model_maker.python.vision.object_detector import dataset_util
|
||||
from official.vision.dataloaders import tf_example_decoder
|
||||
|
||||
|
||||
class Dataset(classification_dataset.ClassificationDataset):
|
||||
"""Dataset library for object detector."""
|
||||
|
||||
@classmethod
|
||||
def from_coco_folder(
|
||||
cls,
|
||||
data_dir: str,
|
||||
max_num_images: Optional[int] = None,
|
||||
cache_dir: Optional[str] = None,
|
||||
) -> 'Dataset':
|
||||
"""Loads images and labels from the given directory in COCO format.
|
||||
|
||||
- https://cocodataset.org/#home
|
||||
|
||||
Folder structure should be:
|
||||
<data_dir>/
|
||||
images/
|
||||
<file0>.jpg
|
||||
...
|
||||
labels.json
|
||||
|
||||
The `labels.json` annotations file should have the following format:
|
||||
{
|
||||
"categories": [{"id": 0, "name": "background"}, ...],
|
||||
"images": [{"id": 0, "file_name": "<file0>.jpg"}, ...],
|
||||
"annotations": [{
|
||||
"id": 0,
|
||||
"image_id": 0,
|
||||
"category_id": 2,
|
||||
"bbox": [x-top left, y-top left, width, height],
|
||||
}, ...]
|
||||
}
|
||||
Note that category id 0 is reserved for the "background" class. It is
|
||||
optional to include, but if included it must be set to "background".
|
||||
|
||||
|
||||
Args:
|
||||
data_dir: Name of the directory containing the data files.
|
||||
max_num_images: Max number of images to process.
|
||||
cache_dir: The cache directory to save TFRecord and metadata files. The
|
||||
TFRecord files are a standardized format for training object detection
|
||||
while the metadata file is used to store information like dataset size
|
||||
and label mapping of id to label name. If the cache_dir is not set, a
|
||||
temporary folder will be created and will not be removed automatically
|
||||
after training, which means it can be reused later.
|
||||
|
||||
Returns:
|
||||
Dataset containing images and labels and other related info.
|
||||
Raises:
|
||||
ValueError: If the input data directory is empty.
|
||||
ValueError: If the label_name for id 0 is set to something other than
|
||||
the 'background' class.
|
||||
"""
|
||||
cache_files = dataset_util.get_cache_files_coco(data_dir, cache_dir)
|
||||
if not dataset_util.is_cached(cache_files):
|
||||
label_map = dataset_util.get_label_map_coco(data_dir)
|
||||
cache_writer = dataset_util.COCOCacheFilesWriter(
|
||||
label_map=label_map, max_num_images=max_num_images
|
||||
)
|
||||
cache_writer.write_files(cache_files, data_dir)
|
||||
return cls.from_cache(cache_files.cache_prefix)
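# Example usage (sketch, illustrative paths): load and cache a COCO-format
# folder, then inspect the parsed labels.
#
#   data = Dataset.from_coco_folder(
#       '/tmp/coco_train', max_num_images=1000, cache_dir='/tmp/cache/train')
#   print(len(data), data.num_classes, data.label_names)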
|
||||
|
||||
@classmethod
|
||||
def from_pascal_voc_folder(
|
||||
cls,
|
||||
data_dir: str,
|
||||
max_num_images: Optional[int] = None,
|
||||
cache_dir: Optional[str] = None,
|
||||
) -> 'Dataset':
|
||||
"""Loads images and labels from the given directory in PASCAL VOC format.
|
||||
|
||||
- http://host.robots.ox.ac.uk/pascal/VOC.
|
||||
|
||||
Folder structure should be:
|
||||
<data_dir>/
|
||||
images/
|
||||
<file0>.jpg
|
||||
...
|
||||
Annotations/
|
||||
<file0>.xml
|
||||
...
|
||||
Each <file0>.xml annotation file should have the following format:
|
||||
<annotation>
|
||||
<filename>file0.jpg<filename>
|
||||
<object>
|
||||
<name>kangaroo</name>
|
||||
<bndbox>
|
||||
<xmin>233</xmin>
|
||||
<ymin>89</ymin>
|
||||
<xmax>386</xmax>
|
||||
<ymax>262</ymax>
|
||||
</object>
|
||||
<object>...</object>
|
||||
</annotation>
|
||||
|
||||
Args:
|
||||
data_dir: Name of the directory containing the data files.
|
||||
max_num_images: Max number of images to process.
|
||||
cache_dir: The cache directory to save TFRecord and metadata files. The
|
||||
TFRecord files are a standardized format for training object detection
|
||||
while the metadata file is used to store information like dataset size
|
||||
and label mapping of id to label name. If the cache_dir is not set, a
|
||||
temporary folder will be created and will not be removed automatically
|
||||
after training, which means it can be reused later.
|
||||
|
||||
Returns:
|
||||
Dataset containing images and labels and other related info.
|
||||
Raises:
|
||||
ValueError: if the input data directory is empty.
|
||||
"""
|
||||
cache_files = dataset_util.get_cache_files_pascal_voc(data_dir, cache_dir)
|
||||
if not dataset_util.is_cached(cache_files):
|
||||
label_map = dataset_util.get_label_map_pascal_voc(data_dir)
|
||||
cache_writer = dataset_util.PascalVocCacheFilesWriter(
|
||||
label_map=label_map, max_num_images=max_num_images
|
||||
)
|
||||
cache_writer.write_files(cache_files, data_dir)
|
||||
|
||||
return cls.from_cache(cache_files.cache_prefix)
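# Example usage (sketch, illustrative paths): PASCAL VOC folders load the same
# way; label ids come from the sorted <name> tags, with id 0 kept for
# 'background'.
#
#   data = Dataset.from_pascal_voc_folder(
#       '/tmp/pascal_voc_data', cache_dir='/tmp/cache/train')
#   print(data.label_names)  # e.g. ['background', 'android', 'pig_android']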
|
||||
|
||||
@classmethod
|
||||
def from_cache(cls, cache_prefix: str) -> 'Dataset':
|
||||
"""Loads the TFRecord data from cache.
|
||||
|
||||
Args:
|
||||
cache_prefix: The cache prefix including the cache directory and the cache
|
||||
prefix filename, e.g. '/tmp/cache/train'.
|
||||
|
||||
Returns:
|
||||
ObjectDetectorDataset object.
|
||||
"""
|
||||
# Get TFRecord Files
|
||||
tfrecord_file_pattern = cache_prefix + '*.tfrecord'
|
||||
matched_files = tf.io.gfile.glob(tfrecord_file_pattern)
|
||||
if not matched_files:
|
||||
raise ValueError('TFRecord files are empty.')
|
||||
|
||||
# Load meta_data.
|
||||
meta_data_file = cache_prefix + dataset_util.META_DATA_FILE_SUFFIX
|
||||
if not tf.io.gfile.exists(meta_data_file):
|
||||
raise ValueError("Metadata file %s doesn't exist." % meta_data_file)
|
||||
with tf.io.gfile.GFile(meta_data_file, 'r') as f:
|
||||
meta_data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
dataset = tf.data.TFRecordDataset(matched_files)
|
||||
decoder = tf_example_decoder.TfExampleDecoder(regenerate_source_id=False)
|
||||
dataset = dataset.map(decoder.decode, num_parallel_calls=tf.data.AUTOTUNE)
|
||||
|
||||
label_map = meta_data['label_map']
|
||||
label_names = [label_map[k] for k in sorted(label_map.keys())]
|
||||
|
||||
return Dataset(
|
||||
dataset=dataset, size=meta_data['size'], label_names=label_names
|
||||
)
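# Note (sketch): `from_cache` expects the files written by dataset_util, e.g.
# for cache_prefix '/tmp/cache/train' with a single shard:
#   /tmp/cache/train-00000-of-00001.tfrecord
#   /tmp/cache/train_meta_data.yaml
# where the YAML file stores the dataset 'size' and the id-to-name 'label_map'.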
|
|
mediapipe/model_maker/python/vision/object_detector/dataset_test.py (new file, 119 lines)
@@ -0,0 +1,119 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.vision.core import image_utils
|
||||
from mediapipe.model_maker.python.vision.core import test_utils
|
||||
from mediapipe.model_maker.python.vision.object_detector import dataset
|
||||
from mediapipe.tasks.python.test import test_utils as tasks_test_utils
|
||||
|
||||
IMAGE_SIZE = 224
|
||||
|
||||
|
||||
class DatasetTest(tf.test.TestCase):
|
||||
|
||||
def _get_rand_bbox(self):
|
||||
x1, x2 = random.uniform(0, IMAGE_SIZE), random.uniform(0, IMAGE_SIZE)
|
||||
y1, y2 = random.uniform(0, IMAGE_SIZE), random.uniform(0, IMAGE_SIZE)
|
||||
return [min(x1, x2), min(y1, y2), abs(x1 - x2), abs(y1 - y2)]
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.coco_dataset_path = os.path.join(self.get_temp_dir(), 'coco_dataset')
|
||||
if os.path.exists(self.coco_dataset_path):
|
||||
return
|
||||
os.mkdir(self.coco_dataset_path)
|
||||
categories = [{'id': 1, 'name': 'daisy'}, {'id': 2, 'name': 'tulips'}]
|
||||
images = [
|
||||
{'id': 1, 'file_name': 'img1.jpeg'},
|
||||
{'id': 2, 'file_name': 'img2.jpeg'},
|
||||
]
|
||||
annotations = [
|
||||
{'image_id': 1, 'category_id': 1, 'bbox': self._get_rand_bbox()},
|
||||
{'image_id': 2, 'category_id': 1, 'bbox': self._get_rand_bbox()},
|
||||
{'image_id': 2, 'category_id': 2, 'bbox': self._get_rand_bbox()},
|
||||
]
|
||||
labels_dict = {
|
||||
'categories': categories,
|
||||
'images': images,
|
||||
'annotations': annotations,
|
||||
}
|
||||
labels_json = json.dumps(labels_dict)
|
||||
with open(os.path.join(self.coco_dataset_path, 'labels.json'), 'w') as f:
|
||||
f.write(labels_json)
|
||||
images_dir = os.path.join(self.coco_dataset_path, 'images')
|
||||
os.mkdir(images_dir)
|
||||
for item in images:
|
||||
test_utils.write_filled_jpeg_file(
|
||||
os.path.join(images_dir, item['file_name']),
|
||||
[random.uniform(0, 255) for _ in range(3)],
|
||||
IMAGE_SIZE,
|
||||
)
|
||||
|
||||
def test_from_coco_folder(self):
|
||||
data = dataset.Dataset.from_coco_folder(
|
||||
self.coco_dataset_path, cache_dir=self.get_temp_dir()
|
||||
)
|
||||
self.assertLen(data, 2)
|
||||
self.assertEqual(data.num_classes, 3)
|
||||
self.assertEqual(data.label_names, ['background', 'daisy', 'tulips'])
|
||||
for example in data.gen_tf_dataset():
|
||||
boxes = example['groundtruth_boxes']
|
||||
classes = example['groundtruth_classes']
|
||||
self.assertNotEmpty(boxes)
|
||||
self.assertAllLessEqual(boxes, 1)
|
||||
self.assertAllGreaterEqual(boxes, 0)
|
||||
self.assertNotEmpty(classes)
|
||||
self.assertTrue(
|
||||
(classes.numpy() == [1]).all() or (classes.numpy() == [1, 2]).all()
|
||||
)
|
||||
if (classes.numpy() == [1, 1]).all():
|
||||
raw_image_tensor = image_utils.load_image(
|
||||
os.path.join(self.coco_dataset_path, 'images', 'img1.jpeg')
|
||||
)
|
||||
else:
|
||||
raw_image_tensor = image_utils.load_image(
|
||||
os.path.join(self.coco_dataset_path, 'images', 'img2.jpeg')
|
||||
)
|
||||
self.assertTrue(
|
||||
(example['image'].numpy() == raw_image_tensor.numpy()).all()
|
||||
)
|
||||
|
||||
def test_from_pascal_voc_folder(self):
|
||||
pascal_voc_folder = tasks_test_utils.get_test_data_path('pascal_voc_data')
|
||||
data = dataset.Dataset.from_pascal_voc_folder(
|
||||
pascal_voc_folder, cache_dir=self.get_temp_dir()
|
||||
)
|
||||
self.assertLen(data, 4)
|
||||
self.assertEqual(data.num_classes, 3)
|
||||
self.assertEqual(data.label_names, ['background', 'android', 'pig_android'])
|
||||
for example in data.gen_tf_dataset():
|
||||
boxes = example['groundtruth_boxes']
|
||||
classes = example['groundtruth_classes']
|
||||
self.assertNotEmpty(boxes)
|
||||
self.assertAllLessEqual(boxes, 1)
|
||||
self.assertAllGreaterEqual(boxes, 0)
|
||||
self.assertNotEmpty(classes)
|
||||
image = example['image']
|
||||
self.assertNotEmpty(image)
|
||||
self.assertAllGreaterEqual(image, 0)
|
||||
self.assertAllLessEqual(image, 255)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
mediapipe/model_maker/python/vision/object_detector/dataset_util.py (new file, 484 lines)
@@ -0,0 +1,484 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utilities for Object Detector Dataset Library."""
|
||||
|
||||
import abc
|
||||
import collections
|
||||
import dataclasses
|
||||
import hashlib
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Dict, List, Mapping, Optional, Sequence
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
import tensorflow as tf
|
||||
import yaml
|
||||
|
||||
from official.vision.data import tfrecord_lib
|
||||
|
||||
|
||||
# Suffix of the meta data file name.
|
||||
META_DATA_FILE_SUFFIX = '_meta_data.yaml'
|
||||
|
||||
|
||||
def _xml_get(node: ET.Element, name: str) -> ET.Element:
|
||||
"""Gets a named child from an XML Element node.
|
||||
|
||||
This method is used to retrieve an XML element that is expected to exist as a
|
||||
subelement of the `node` passed into this argument. If the subelement is not
|
||||
found, then an error is thrown.
|
||||
|
||||
Raises:
|
||||
ValueError: If the subelement is not found.
|
||||
|
||||
Args:
|
||||
node: XML Element Tree node.
|
||||
name: Name of the child node to get
|
||||
|
||||
Returns:
|
||||
A child node of the parameter node with the matching name.
|
||||
"""
|
||||
result = node.find(name)
|
||||
if result is None:
|
||||
raise ValueError(f'Unexpected xml format: {name} not found in {node}')
|
||||
return result
|
||||
|
||||
|
||||
def _get_cache_dir_or_create(cache_dir: Optional[str]) -> str:
|
||||
"""Gets the cache directory or creates it if not exists."""
|
||||
if cache_dir is None:
|
||||
cache_dir = tempfile.mkdtemp()
|
||||
if not tf.io.gfile.exists(cache_dir):
|
||||
tf.io.gfile.makedirs(cache_dir)
|
||||
return cache_dir
|
||||
|
||||
|
||||
def _get_dir_basename(data_dir: str) -> str:
|
||||
"""Gets the base name of the directory."""
|
||||
return os.path.basename(os.path.abspath(data_dir))
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class CacheFiles:
|
||||
"""Cache files for object detection."""
|
||||
|
||||
cache_prefix: str
|
||||
tfrecord_files: Sequence[str]
|
||||
meta_data_file: str
|
||||
|
||||
|
||||
def _get_cache_files(
|
||||
cache_dir: Optional[str], cache_prefix_filename: str, num_shards: int = 10
|
||||
) -> CacheFiles:
|
||||
"""Creates an object of CacheFiles class.
|
||||
|
||||
Args:
|
||||
cache_dir: The cache directory to save TFRecord and metadata file. When
|
||||
cache_dir is None, a temporary folder will be created and will not be
|
||||
removed automatically after training, so it can be reused later.
|
||||
cache_prefix_filename: The cache prefix filename.
|
||||
num_shards: Number of shards for output file.
|
||||
|
||||
Returns:
|
||||
An object of CacheFiles class.
|
||||
"""
|
||||
cache_dir = _get_cache_dir_or_create(cache_dir)
|
||||
# The cache prefix including the cache directory and the cache prefix
|
||||
# filename, e.g. '/tmp/cache/train'.
|
||||
cache_prefix = os.path.join(cache_dir, cache_prefix_filename)
|
||||
tf.compat.v1.logging.info(
|
||||
'Cache will be stored in %s with prefix filename %s. Cache_prefix is %s'
|
||||
% (cache_dir, cache_prefix_filename, cache_prefix)
|
||||
)
|
||||
|
||||
# Cached files including the TFRecord files and the meta data file.
|
||||
tfrecord_files = [
|
||||
cache_prefix + '-%05d-of-%05d.tfrecord' % (i, num_shards)
|
||||
for i in range(num_shards)
|
||||
]
|
||||
meta_data_file = cache_prefix + META_DATA_FILE_SUFFIX
|
||||
return CacheFiles(
|
||||
cache_prefix=cache_prefix,
|
||||
tfrecord_files=tuple(tfrecord_files),
|
||||
meta_data_file=meta_data_file,
|
||||
)
|
||||
|
||||
|
||||
def get_cache_files_coco(data_dir: str, cache_dir: str) -> CacheFiles:
|
||||
"""Creates an object of CacheFiles class using a COCO formatted dataset.
|
||||
|
||||
Args:
|
||||
data_dir: Folder path of the coco dataset
|
||||
cache_dir: Folder path of the cache location. When cache_dir is None, a
|
||||
temporary folder will be created and will not be removed automatically
|
||||
after training, so it can be reused later.
|
||||
|
||||
Returns:
|
||||
An object of CacheFiles class.
|
||||
"""
|
||||
hasher = hashlib.md5()
|
||||
# Update with dataset folder name
|
||||
hasher.update(_get_dir_basename(data_dir).encode('utf-8'))
|
||||
# Update with image filenames
|
||||
for image_file in sorted(os.listdir(os.path.join(data_dir, 'images'))):
|
||||
hasher.update(os.path.basename(image_file).encode('utf-8'))
|
||||
# Update with labels.json file content
|
||||
label_file = os.path.join(data_dir, 'labels.json')
|
||||
with open(label_file, 'r') as f:
|
||||
label_json = json.load(f)
|
||||
hasher.update(str(label_json).encode('utf-8'))
|
||||
num_examples = len(label_json['images'])
|
||||
# num_shards is chosen automatically: roughly 100 images per shard, capped at 10 shards total.
|
||||
# See https://www.tensorflow.org/tutorials/load_data/tfrecord for more info
|
||||
# on sharding.
|
||||
num_shards = min(math.ceil(num_examples / 100), 10)
|
||||
# Update with num shards
|
||||
hasher.update(str(num_shards).encode('utf-8'))
|
||||
cache_prefix_filename = hasher.hexdigest()
|
||||
|
||||
return _get_cache_files(cache_dir, cache_prefix_filename, num_shards)
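# Worked example (sketch): with 250 images listed in labels.json,
# num_shards = min(ceil(250 / 100), 10) = 3; with 2,500 images the count is
# capped at 10. The shard count is also hashed into the cache prefix, so the
# cache is regenerated whenever it changes.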
|
||||
|
||||
|
||||
def get_cache_files_pascal_voc(data_dir: str, cache_dir: str) -> CacheFiles:
|
||||
"""Gets an object of CacheFiles using a PASCAL VOC formatted dataset.
|
||||
|
||||
Args:
|
||||
data_dir: Folder path of the pascal voc dataset.
|
||||
cache_dir: Folder path of the cache location. When cache_dir is None, a
|
||||
temporary folder will be created and will not be removed automatically
|
||||
after training, so it can be reused later.
|
||||
|
||||
Returns:
|
||||
An object of CacheFiles class.
|
||||
"""
|
||||
hasher = hashlib.md5()
|
||||
hasher.update(_get_dir_basename(data_dir).encode('utf-8'))
|
||||
annotation_files = tf.io.gfile.glob(
|
||||
os.path.join(data_dir, 'Annotations') + r'/*.xml'
|
||||
)
|
||||
annotation_filenames = [
|
||||
os.path.basename(ann_file) for ann_file in annotation_files
|
||||
]
|
||||
hasher.update(' '.join(annotation_filenames).encode('utf-8'))
|
||||
num_examples = len(annotation_filenames)
|
||||
num_shards = min(math.ceil(num_examples / 100), 10)
|
||||
hasher.update(str(num_shards).encode('utf-8'))
|
||||
cache_prefix_filename = hasher.hexdigest()
|
||||
|
||||
return _get_cache_files(cache_dir, cache_prefix_filename, num_shards)
|
||||
|
||||
|
||||
def is_cached(cache_files: CacheFiles) -> bool:
|
||||
"""Checks whether cache files are already cached."""
|
||||
all_cached_files = list(cache_files.tfrecord_files) + [
|
||||
cache_files.meta_data_file
|
||||
]
|
||||
return all(tf.io.gfile.exists(path) for path in all_cached_files)
|
||||
|
||||
|
||||
class CacheFilesWriter(abc.ABC):
|
||||
"""CacheFilesWriter class to write the cached files."""
|
||||
|
||||
def __init__(
|
||||
self, label_map: Dict[int, str], max_num_images: Optional[int] = None
|
||||
) -> None:
|
||||
"""Initializes CacheFilesWriter for object detector.
|
||||
|
||||
Args:
|
||||
label_map: Dict, map label integer ids to string label names such as {1:
|
||||
'person', 2: 'notperson'}. 0 is the reserved key for `background` and
|
||||
doesn't need to be included in `label_map`. Label names can't be
|
||||
duplicated.
|
||||
max_num_images: Max number of images to process. If None, process all the
|
||||
images.
|
||||
"""
|
||||
self.label_map = label_map
|
||||
self.max_num_images = max_num_images
|
||||
|
||||
def write_files(self, cache_files: CacheFiles, *args, **kwargs) -> None:
|
||||
"""Writes TFRecord and meta_data files.
|
||||
|
||||
Args:
|
||||
cache_files: CacheFiles object including a list of TFRecord files and the
|
||||
meta data yaml file to save the meta_data including data size and
|
||||
label_map.
|
||||
*args: Non-keyword of parameters used in the `_get_example` method.
|
||||
**kwargs: Keyword parameters used in the `_get_example` method.
|
||||
"""
|
||||
writers = [
|
||||
tf.io.TFRecordWriter(path) for path in cache_files.tfrecord_files
|
||||
]
|
||||
|
||||
# Writes tf.Example into TFRecord files.
|
||||
size = 0
|
||||
for idx, tf_example in enumerate(self._get_example(*args, **kwargs)):
|
||||
if self.max_num_images and idx >= self.max_num_images:
|
||||
break
|
||||
if idx % 100 == 0:
|
||||
tf.compat.v1.logging.info('On image %d' % idx)
|
||||
writers[idx % len(writers)].write(tf_example.SerializeToString())
|
||||
size = idx + 1
|
||||
|
||||
for writer in writers:
|
||||
writer.close()
|
||||
|
||||
# Writes meta_data into meta_data_file.
|
||||
meta_data = {'size': size, 'label_map': self.label_map}
|
||||
with tf.io.gfile.GFile(cache_files.meta_data_file, 'w') as f:
|
||||
yaml.dump(meta_data, f)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_example(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def get_label_map_coco(data_dir: str):
|
||||
"""Gets the label map from a COCO formatted dataset directory.
|
||||
|
||||
Note that id 0 is reserved for the background class. If id=0 is set, it needs
|
||||
to be set to "background". Including id=0 is optional; if it is omitted, it
will be added automatically by this method.
|
||||
|
||||
Args:
|
||||
data_dir: Path of the dataset directory
|
||||
|
||||
Returns:
|
||||
label_map dictionary of the format {<id>:<label_name>}
|
||||
|
||||
Raises:
|
||||
ValueError: If the label_name for id 0 is set to something other than
|
||||
the "background" class.
|
||||
"""
|
||||
data_dir = os.path.abspath(data_dir)
|
||||
# Process labels.json file
|
||||
label_file = os.path.join(data_dir, 'labels.json')
|
||||
with open(label_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Categories
|
||||
label_map = {}
|
||||
for category in data['categories']:
|
||||
label_map[int(category['id'])] = category['name']
|
||||
|
||||
if 0 in label_map and label_map[0] != 'background':
|
||||
raise ValueError(
|
||||
(
|
||||
'Label index 0 is reserved for the background class, but '
|
||||
f'it was found to be {label_map[0]}'
|
||||
),
|
||||
)
|
||||
if 0 not in label_map:
|
||||
label_map[0] = 'background'
|
||||
|
||||
return label_map
|
||||
|
||||
|
||||
def get_label_map_pascal_voc(data_dir: str):
|
||||
"""Gets the label map from a PASCAL VOC formatted dataset directory.
|
||||
|
||||
The id to label_name mapping is determined by sorting all label_names and
|
||||
numbering them starting from 1. Id=0 is set as the 'background' class.
|
||||
|
||||
Args:
|
||||
data_dir: Path of the dataset directory
|
||||
|
||||
Returns:
|
||||
label_map dictionary of the format {<id>:<label_name>}
|
||||
"""
|
||||
data_dir = os.path.abspath(data_dir)
|
||||
all_label_names = set()
|
||||
annotations_dir = os.path.join(data_dir, 'Annotations')
|
||||
all_annotation_files = tf.io.gfile.glob(annotations_dir + r'/*.xml')
|
||||
for ann_file in all_annotation_files:
|
||||
tree = ET.parse(ann_file)
|
||||
root = tree.getroot()
|
||||
for child in root.iter('object'):
|
||||
label_name = _xml_get(child, 'name').text
|
||||
all_label_names.add(label_name)
|
||||
label_map = {0: 'background'}
|
||||
for ind, label_name in enumerate(sorted(all_label_names)):
|
||||
label_map[ind + 1] = label_name
|
||||
return label_map
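# Example (sketch): if the Annotations/*.xml files contain objects named
# 'android' and 'pig_android', the returned map is
# {0: 'background', 1: 'android', 2: 'pig_android'} (names sorted, ids from 1).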
|
||||
|
||||
|
||||
def _bbox_data_to_feature_dict(data):
|
||||
"""Converts a dictionary of bbox annotations to a feature dictionary.
|
||||
|
||||
Args:
|
||||
data: Dict with keys 'xmin', 'xmax', 'ymin', 'ymax', 'category_id'
|
||||
|
||||
Returns:
|
||||
Feature dictionary
|
||||
"""
|
||||
bbox_feature_dict = {
|
||||
'image/object/bbox/xmin': tfrecord_lib.convert_to_feature(data['xmin']),
|
||||
'image/object/bbox/xmax': tfrecord_lib.convert_to_feature(data['xmax']),
|
||||
'image/object/bbox/ymin': tfrecord_lib.convert_to_feature(data['ymin']),
|
||||
'image/object/bbox/ymax': tfrecord_lib.convert_to_feature(data['ymax']),
|
||||
'image/object/class/label': tfrecord_lib.convert_to_feature(
|
||||
data['category_id']
|
||||
),
|
||||
}
|
||||
return bbox_feature_dict
|
||||
|
||||
|
||||
def _coco_annotations_to_lists(
|
||||
bbox_annotations: List[Mapping[str, Any]],
|
||||
image_height: int,
|
||||
image_width: int,
|
||||
):
|
||||
"""Converts COCO annotations to feature lists.
|
||||
|
||||
Args:
|
||||
bbox_annotations: List of dicts with keys ['bbox', 'category_id']
|
||||
image_height: Height of image
|
||||
image_width: Width of image
|
||||
|
||||
Returns:
|
||||
(data, num_annotations_skipped) tuple where data contains the keys:
|
||||
['xmin', 'xmax', 'ymin', 'ymax', 'is_crowd', 'category_id', 'area'] and
|
||||
num_annotations_skipped is the number of skipped annotations because of the
|
||||
bbox having 0 area.
|
||||
"""
|
||||
|
||||
data = collections.defaultdict(list)
|
||||
|
||||
num_annotations_skipped = 0
|
||||
|
||||
for object_annotations in bbox_annotations:
|
||||
(x, y, width, height) = tuple(object_annotations['bbox'])
|
||||
|
||||
if width <= 0 or height <= 0:
|
||||
num_annotations_skipped += 1
|
||||
continue
|
||||
if x + width > image_width or y + height > image_height:
|
||||
num_annotations_skipped += 1
|
||||
continue
|
||||
data['xmin'].append(float(x) / image_width)
|
||||
data['xmax'].append(float(x + width) / image_width)
|
||||
data['ymin'].append(float(y) / image_height)
|
||||
data['ymax'].append(float(y + height) / image_height)
|
||||
category_id = int(object_annotations['category_id'])
|
||||
data['category_id'].append(category_id)
|
||||
|
||||
return data, num_annotations_skipped
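# Worked example (sketch): for a 100x200 (width x height) image and a COCO
# bbox of [x, y, w, h] = [10, 20, 30, 40], the normalized values are
# xmin = 10/100 = 0.1, xmax = 40/100 = 0.4, ymin = 20/200 = 0.1 and
# ymax = 60/200 = 0.3. A box with zero area or one that extends past the image
# is counted in num_annotations_skipped instead.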
|
||||
|
||||
|
||||
class COCOCacheFilesWriter(CacheFilesWriter):
|
||||
"""CacheFilesWriter class to write the cached files for COCO data."""
|
||||
|
||||
def _get_example(self, data_dir: str) -> tf.train.Example:
|
||||
"""Iterates over all examples in the COCO formatted dataset directory.
|
||||
|
||||
Args:
|
||||
data_dir: Path of the dataset directory
|
||||
|
||||
Yields:
|
||||
tf.train.Example
|
||||
"""
|
||||
data_dir = os.path.abspath(data_dir)
|
||||
# Process labels.json file
|
||||
label_file = os.path.join(data_dir, 'labels.json')
|
||||
with open(label_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Load all Annotations
|
||||
img_to_annotations = collections.defaultdict(list)
|
||||
for annotation in data['annotations']:
|
||||
image_id = annotation['image_id']
|
||||
img_to_annotations[image_id].append(annotation)
|
||||
|
||||
# For each Image:
|
||||
for image in data['images']:
|
||||
img_id = image['id']
|
||||
file_name = image['file_name']
|
||||
full_path = os.path.join(data_dir, 'images', file_name)
|
||||
with tf.io.gfile.GFile(full_path, 'rb') as fid:
|
||||
encoded_jpg = fid.read()
|
||||
image = tf.io.decode_jpeg(encoded_jpg, channels=3)
|
||||
height, width, _ = image.shape
|
||||
feature_dict = tfrecord_lib.image_info_to_feature_dict(
|
||||
height, width, file_name, img_id, encoded_jpg, 'jpg'
|
||||
)
|
||||
data, _ = _coco_annotations_to_lists(
|
||||
img_to_annotations[img_id], height, width
|
||||
)
|
||||
if not data['xmin']:
|
||||
# Skip examples which have no annotations
|
||||
continue
|
||||
bbox_feature_dict = _bbox_data_to_feature_dict(data)
|
||||
feature_dict.update(bbox_feature_dict)
|
||||
example = tf.train.Example(
|
||||
features=tf.train.Features(feature=feature_dict)
|
||||
)
|
||||
yield example
|
||||
|
||||
|
||||
class PascalVocCacheFilesWriter(CacheFilesWriter):
|
||||
"""CacheFilesWriter class to write the cached files for PASCAL VOC data."""
|
||||
|
||||
def _get_example(self, data_dir: str) -> tf.train.Example:
|
||||
"""Iterates over all examples in the PASCAL VOC formatted dataset directory.
|
||||
|
||||
Args:
|
||||
data_dir: Path of the dataset directory
|
||||
|
||||
Yields:
|
||||
tf.train.Example
|
||||
"""
|
||||
label_name_to_id = {name: i for (i, name) in self.label_map.items()}
|
||||
annotations_dir = os.path.join(data_dir, 'Annotations')
|
||||
images_dir = os.path.join(data_dir, 'images')
|
||||
all_annotation_paths = tf.io.gfile.glob(annotations_dir + r'/*.xml')
|
||||
|
||||
data = collections.defaultdict(list)
|
||||
for ind, ann_file in enumerate(all_annotation_paths):
|
||||
tree = ET.parse(ann_file)
|
||||
root = tree.getroot()
|
||||
img_filename = _xml_get(root, 'filename').text
|
||||
img_file = os.path.join(images_dir, img_filename)
|
||||
with tf.io.gfile.GFile(img_file, 'rb') as fid:
|
||||
encoded_jpg = fid.read()
|
||||
image = tf.io.decode_jpeg(encoded_jpg, channels=3)
|
||||
height, width, _ = image.shape
|
||||
for child in root.iter('object'):
|
||||
category_name = _xml_get(child, 'name').text
|
||||
category_id = label_name_to_id[category_name]
|
||||
bndbox = _xml_get(child, 'bndbox')
|
||||
xmin = float(_xml_get(bndbox, 'xmin').text)
|
||||
xmax = float(_xml_get(bndbox, 'xmax').text)
|
||||
ymin = float(_xml_get(bndbox, 'ymin').text)
|
||||
ymax = float(_xml_get(bndbox, 'ymax').text)
|
||||
if xmax <= xmin or ymax <= ymin or xmax > width or ymax > height:
|
||||
# Skip annotations that have no area or are larger than the image
|
||||
continue
|
||||
data['xmin'].append(xmin / width)
|
||||
data['ymin'].append(ymin / height)
|
||||
data['xmax'].append(xmax / width)
|
||||
data['ymax'].append(ymax / height)
|
||||
data['category_id'].append(category_id)
|
||||
if not data['xmin']:
|
||||
# Skip examples which have no valid annotations
|
||||
continue
|
||||
feature_dict = tfrecord_lib.image_info_to_feature_dict(
|
||||
height, width, img_filename, ind, encoded_jpg, 'jpg'
|
||||
)
|
||||
bbox_feature_dict = _bbox_data_to_feature_dict(data)
|
||||
feature_dict.update(bbox_feature_dict)
|
||||
example = tf.train.Example(
|
||||
features=tf.train.Features(feature=feature_dict)
|
||||
)
|
||||
yield example
|
|
mediapipe/model_maker/python/vision/object_detector/dataset_util_test.py (new file, 236 lines)
@@ -0,0 +1,236 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from unittest import mock as unittest_mock
|
||||
|
||||
import tensorflow as tf
|
||||
import yaml
|
||||
|
||||
from mediapipe.model_maker.python.vision.core import test_utils
|
||||
from mediapipe.model_maker.python.vision.object_detector import dataset_util
|
||||
from mediapipe.tasks.python.test import test_utils as tasks_test_utils
|
||||
|
||||
|
||||
class DatasetUtilTest(tf.test.TestCase):
|
||||
|
||||
def _assert_cache_files_equal(self, cf1, cf2):
|
||||
self.assertEqual(cf1.cache_prefix, cf2.cache_prefix)
|
||||
self.assertCountEqual(cf1.tfrecord_files, cf2.tfrecord_files)
|
||||
self.assertEqual(cf1.meta_data_file, cf2.meta_data_file)
|
||||
|
||||
def _assert_cache_files_not_equal(self, cf1, cf2):
|
||||
self.assertNotEqual(cf1.cache_prefix, cf2.cache_prefix)
|
||||
self.assertNotEqual(cf1.tfrecord_files, cf2.tfrecord_files)
|
||||
self.assertNotEqual(cf1.meta_data_file, cf2.meta_data_file)
|
||||
|
||||
def _get_cache_files_and_assert_neq_fn(self, cache_files_fn):
|
||||
def get_cache_files_and_assert_neq(cf, data_dir, cache_dir):
|
||||
new_cf = cache_files_fn(data_dir, cache_dir)
|
||||
self._assert_cache_files_not_equal(cf, new_cf)
|
||||
return new_cf
|
||||
|
||||
return get_cache_files_and_assert_neq
|
||||
|
||||
@unittest_mock.patch.object(hashlib, 'md5', autospec=True)
|
||||
def test_get_cache_files_coco(self, mock_md5):
|
||||
mock_md5.return_value.hexdigest.return_value = 'train'
|
||||
cache_files = dataset_util.get_cache_files_coco(
|
||||
tasks_test_utils.get_test_data_path('coco_data'), cache_dir='/tmp/'
|
||||
)
|
||||
self.assertEqual(cache_files.cache_prefix, '/tmp/train')
|
||||
self.assertLen(cache_files.tfrecord_files, 1)
|
||||
self.assertEqual(
|
||||
cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord'
|
||||
)
|
||||
self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml')
|
||||
|
||||
def test_matching_get_cache_files_coco(self):
|
||||
cache_dir = self.create_tempdir()
|
||||
coco_folder = tasks_test_utils.get_test_data_path('coco_data')
|
||||
coco_folder_tmp = os.path.join(self.create_tempdir(), 'coco_data')
|
||||
shutil.copytree(coco_folder, coco_folder_tmp)
|
||||
cache_files1 = dataset_util.get_cache_files_coco(coco_folder, cache_dir)
|
||||
cache_files2 = dataset_util.get_cache_files_coco(coco_folder, cache_dir)
|
||||
self._assert_cache_files_equal(cache_files1, cache_files2)
|
||||
cache_files3 = dataset_util.get_cache_files_coco(coco_folder_tmp, cache_dir)
|
||||
self._assert_cache_files_equal(cache_files1, cache_files3)
|
||||
|
||||
def test_not_matching_get_cache_files_coco(self):
|
||||
cache_dir = self.create_tempdir()
|
||||
temp_dir = self.create_tempdir()
|
||||
coco_folder = os.path.join(temp_dir, 'coco_data')
|
||||
shutil.copytree(
|
||||
tasks_test_utils.get_test_data_path('coco_data'), coco_folder
|
||||
)
|
||||
prev_cache_file = dataset_util.get_cache_files_coco(coco_folder, cache_dir)
|
||||
os.chmod(coco_folder, 0o700)
|
||||
os.chmod(os.path.join(coco_folder, 'images'), 0o700)
|
||||
os.chmod(os.path.join(coco_folder, 'labels.json'), 0o700)
|
||||
get_cache_files_and_assert_neq = self._get_cache_files_and_assert_neq_fn(
|
||||
dataset_util.get_cache_files_coco
|
||||
)
|
||||
# Test adding image
|
||||
test_utils.write_filled_jpeg_file(
|
||||
os.path.join(coco_folder, 'images', 'test.jpg'), [0, 0, 0], 50
|
||||
)
|
||||
prev_cache_file = get_cache_files_and_assert_neq(
|
||||
prev_cache_file, coco_folder, cache_dir
|
||||
)
|
||||
# Test modifying labels.json
|
||||
with open(os.path.join(coco_folder, 'labels.json'), 'w') as f:
|
||||
json.dump({'images': [{'id': 1, 'file_name': '000000000078.jpg'}]}, f)
|
||||
prev_cache_file = get_cache_files_and_assert_neq(
|
||||
prev_cache_file, coco_folder, cache_dir
|
||||
)
|
||||
|
||||
# Test rename folder
|
||||
new_coco_folder = os.path.join(temp_dir, 'coco_data_renamed')
|
||||
shutil.move(coco_folder, new_coco_folder)
|
||||
coco_folder = new_coco_folder
|
||||
prev_cache_file = get_cache_files_and_assert_neq(
|
||||
prev_cache_file, new_coco_folder, cache_dir
|
||||
)
|
||||
|
||||
@unittest_mock.patch.object(hashlib, 'md5', autospec=True)
|
||||
def test_get_cache_files_pascal_voc(self, mock_md5):
|
||||
mock_md5.return_value.hexdigest.return_value = 'train'
|
||||
cache_files = dataset_util.get_cache_files_pascal_voc(
|
||||
tasks_test_utils.get_test_data_path('pascal_voc_data'),
|
||||
cache_dir='/tmp/',
|
||||
)
|
||||
self.assertEqual(cache_files.cache_prefix, '/tmp/train')
|
||||
self.assertLen(cache_files.tfrecord_files, 1)
|
||||
self.assertEqual(
|
||||
cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord'
|
||||
)
|
||||
self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml')
|
||||
|
||||
def test_matching_get_cache_files_pascal_voc(self):
|
||||
cache_dir = self.create_tempdir()
|
||||
pascal_folder = tasks_test_utils.get_test_data_path('pascal_voc_data')
|
||||
pascal_folder_temp = os.path.join(self.create_tempdir(), 'pascal_voc_data')
|
||||
shutil.copytree(pascal_folder, pascal_folder_temp)
|
||||
cache_files1 = dataset_util.get_cache_files_pascal_voc(
|
||||
pascal_folder, cache_dir
|
||||
)
|
||||
cache_files2 = dataset_util.get_cache_files_pascal_voc(
|
||||
pascal_folder, cache_dir
|
||||
)
|
||||
self._assert_cache_files_equal(cache_files1, cache_files2)
|
||||
cache_files3 = dataset_util.get_cache_files_pascal_voc(
|
||||
pascal_folder_temp, cache_dir
|
||||
)
|
||||
self._assert_cache_files_equal(cache_files1, cache_files3)
|
||||
|
||||
def test_not_matching_get_cache_files_pascal_voc(self):
|
||||
cache_dir = self.create_tempdir()
|
||||
temp_dir = self.create_tempdir()
|
||||
pascal_folder = os.path.join(temp_dir, 'pascal_voc_data')
|
||||
shutil.copytree(
|
||||
tasks_test_utils.get_test_data_path('pascal_voc_data'), pascal_folder
|
||||
)
|
||||
prev_cache_files = dataset_util.get_cache_files_pascal_voc(
|
||||
pascal_folder, cache_dir
|
||||
)
|
||||
os.chmod(pascal_folder, 0o700)
|
||||
os.chmod(os.path.join(pascal_folder, 'images'), 0o700)
|
||||
os.chmod(os.path.join(pascal_folder, 'Annotations'), 0o700)
|
||||
get_cache_files_and_assert_neq = self._get_cache_files_and_assert_neq_fn(
|
||||
dataset_util.get_cache_files_pascal_voc
|
||||
)
|
||||
# Test adding xml file
|
||||
with open(os.path.join(pascal_folder, 'Annotations', 'test.xml'), 'w') as f:
|
||||
f.write('test')
|
||||
prev_cache_files = get_cache_files_and_assert_neq(
|
||||
prev_cache_files, pascal_folder, cache_dir
|
||||
)
|
||||
|
||||
# Test rename folder
|
||||
new_pascal_folder = os.path.join(temp_dir, 'pascal_voc_data_renamed')
|
||||
shutil.move(pascal_folder, new_pascal_folder)
|
||||
pascal_folder = new_pascal_folder
|
||||
prev_cache_files = get_cache_files_and_assert_neq(
|
||||
prev_cache_files, new_pascal_folder, cache_dir
|
||||
)
|
||||
|
||||
def test_is_cached(self):
|
||||
tempdir = self.create_tempdir()
|
||||
cache_files = dataset_util.get_cache_files_coco(
|
||||
tasks_test_utils.get_test_data_path('coco_data'), cache_dir=tempdir
|
||||
)
|
||||
self.assertFalse(dataset_util.is_cached(cache_files))
|
||||
with open(cache_files.tfrecord_files[0], 'w') as f:
|
||||
f.write('test')
|
||||
self.assertFalse(dataset_util.is_cached(cache_files))
|
||||
with open(cache_files.meta_data_file, 'w') as f:
|
||||
f.write('test')
|
||||
self.assertTrue(dataset_util.is_cached(cache_files))
|
||||
|
||||
def test_get_label_map_coco(self):
|
||||
coco_dir = tasks_test_utils.get_test_data_path('coco_data')
|
||||
label_map = dataset_util.get_label_map_coco(coco_dir)
|
||||
all_keys = sorted(label_map.keys())
|
||||
self.assertEqual(all_keys[0], 0)
|
||||
self.assertEqual(all_keys[-1], 11)
|
||||
self.assertLen(all_keys, 12)
|
||||
|
||||
def test_get_label_map_pascal_voc(self):
|
||||
pascal_dir = tasks_test_utils.get_test_data_path('pascal_voc_data')
|
||||
label_map = dataset_util.get_label_map_pascal_voc(pascal_dir)
|
||||
all_keys = sorted(label_map.keys())
|
||||
self.assertEqual(label_map[0], 'background')
|
||||
self.assertEqual(all_keys[0], 0)
|
||||
self.assertEqual(all_keys[-1], 2)
|
||||
self.assertLen(all_keys, 3)
|
||||
|
||||
def _validate_cache_files(self, cache_files, expected_size):
|
||||
# Checks the TFRecord file
|
||||
self.assertTrue(os.path.isfile(cache_files.tfrecord_files[0]))
|
||||
self.assertGreater(os.path.getsize(cache_files.tfrecord_files[0]), 0)
|
||||
|
||||
# Checks the meta_data file
|
||||
self.assertTrue(os.path.isfile(cache_files.meta_data_file))
|
||||
self.assertGreater(os.path.getsize(cache_files.meta_data_file), 0)
|
||||
with tf.io.gfile.GFile(cache_files.meta_data_file, 'r') as f:
|
||||
meta_data_dict = yaml.load(f, Loader=yaml.FullLoader)
|
||||
# Some examples are skipped for having invalid bboxes, so the cached size can
# be smaller than the number of images.
|
||||
self.assertEqual(meta_data_dict['size'], expected_size)
|
||||
|
||||
def test_coco_cache_files_writer(self):
|
||||
tempdir = self.create_tempdir()
|
||||
coco_dir = tasks_test_utils.get_test_data_path('coco_data')
|
||||
label_map = dataset_util.get_label_map_coco(coco_dir)
|
||||
cache_writer = dataset_util.COCOCacheFilesWriter(label_map)
|
||||
cache_files = dataset_util.get_cache_files_coco(coco_dir, cache_dir=tempdir)
|
||||
cache_writer.write_files(cache_files, coco_dir)
|
||||
self._validate_cache_files(cache_files, 3)
|
||||
|
||||
def test_pascal_voc_cache_files_writer(self):
|
||||
tempdir = self.create_tempdir()
|
||||
pascal_dir = tasks_test_utils.get_test_data_path('pascal_voc_data')
|
||||
label_map = dataset_util.get_label_map_pascal_voc(pascal_dir)
|
||||
cache_writer = dataset_util.PascalVocCacheFilesWriter(label_map)
|
||||
cache_files = dataset_util.get_cache_files_pascal_voc(
|
||||
pascal_dir, cache_dir=tempdir
|
||||
)
|
||||
cache_writer.write_files(cache_files, pascal_dir)
|
||||
self._validate_cache_files(cache_files, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
mediapipe/model_maker/python/vision/object_detector/hyperparameters.py (new file, 101 lines)
@@ -0,0 +1,101 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Hyperparameters for training object detection models."""
|
||||
|
||||
import dataclasses
|
||||
from typing import List
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class HParams(hp.BaseHParams):
|
||||
"""The hyperparameters for training object detectors.
|
||||
|
||||
Attributes:
|
||||
learning_rate: Learning rate to use for gradient descent training.
|
||||
batch_size: Batch size for training.
|
||||
epochs: Number of training iterations over the dataset.
|
||||
do_fine_tuning: If true, the base module is trained together with the
|
||||
classification layer on top.
|
||||
learning_rate_boundaries: List of epoch boundaries where
|
||||
learning_rate_boundaries[i] is the epoch where the learning rate will
|
||||
decay to learning_rate * learning_rate_decay_multipliers[i].
|
||||
learning_rate_decay_multipliers: List of learning rate multipliers which
|
||||
calculates the learning rate at the ith boundary as learning_rate *
|
||||
learning_rate_decay_multipliers[i].
|
||||
"""
|
||||
|
||||
# Parameters from BaseHParams class.
|
||||
learning_rate: float = 0.003
|
||||
batch_size: int = 32
|
||||
epochs: int = 10
|
||||
|
||||
# Parameters for learning rate decay
|
||||
learning_rate_boundaries: List[int] = dataclasses.field(
|
||||
default_factory=lambda: [5, 8]
|
||||
)
|
||||
learning_rate_decay_multipliers: List[float] = dataclasses.field(
|
||||
default_factory=lambda: [0.1, 0.01]
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# Validate stepwise learning rate parameters
|
||||
lr_boundary_len = len(self.learning_rate_boundaries)
|
||||
lr_decay_multipliers_len = len(self.learning_rate_decay_multipliers)
|
||||
if lr_boundary_len != lr_decay_multipliers_len:
|
||||
raise ValueError(
"Length of learning_rate_boundaries and "
"learning_rate_decay_multipliers do not match: "
f"{lr_boundary_len}!={lr_decay_multipliers_len}"
|
||||
)
|
||||
# Validate learning_rate_boundaries
|
||||
if sorted(self.learning_rate_boundaries) != self.learning_rate_boundaries:
|
||||
raise ValueError(
"learning_rate_boundaries is not in ascending order: "
f"{self.learning_rate_boundaries}"
|
||||
)
|
||||
if (
|
||||
self.learning_rate_boundaries
|
||||
and self.learning_rate_boundaries[-1] > self.epochs
|
||||
):
|
||||
raise ValueError(
"Values in learning_rate_boundaries cannot be greater than epochs"
|
||||
)
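# Worked example (sketch): with the defaults above (learning_rate=0.003,
# epochs=10, boundaries=[5, 8], multipliers=[0.1, 0.01]) the intended schedule
# is 0.003 for epochs 0-4, 0.003 * 0.1 = 3e-4 for epochs 5-7, and
# 0.003 * 0.01 = 3e-5 from epoch 8 on; the trainer, not this file, applies it.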
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class QATHParams:
|
||||
"""The hyperparameters for running quantization aware training (QAT) on object detectors.
|
||||
|
||||
For more information on QAT, see:
|
||||
https://www.tensorflow.org/model_optimization/guide/quantization/training
|
||||
|
||||
Attributes:
|
||||
learning_rate: Learning rate to use for gradient descent QAT.
|
||||
batch_size: Batch size for QAT.
|
||||
epochs: Number of training iterations over the dataset.
|
||||
decay_steps: Learning rate decay steps for Exponential Decay. See
|
||||
https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/ExponentialDecay
|
||||
for more information.
|
||||
decay_rate: Learning rate decay rate for Exponential Decay. See
|
||||
https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/ExponentialDecay
|
||||
for more information.
|
||||
"""
|
||||
|
||||
learning_rate: float = 0.03
|
||||
batch_size: int = 32
|
||||
epochs: int = 10
|
||||
decay_steps: int = 231
|
||||
decay_rate: float = 0.96
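# Sketch (assumption: these values feed a
# tf.keras.optimizers.schedules.ExponentialDecay, as the attribute docs
# suggest): with learning_rate=0.03, decay_steps=231 and decay_rate=0.96, the
# rate after `step` training steps is roughly 0.03 * 0.96 ** (step / 231),
# i.e. about a 4% decay every 231 steps.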
|
mediapipe/model_maker/python/vision/object_detector/model.py (new file, 355 lines)
@@ -0,0 +1,355 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Custom Model for Object Detection."""
|
||||
|
||||
import os
|
||||
from typing import Mapping, Optional, Sequence, Union
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.vision.object_detector import model_options as model_opt
|
||||
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
|
||||
from official.core import config_definitions as cfg
|
||||
from official.projects.qat.vision.configs import common as qat_common
|
||||
from official.projects.qat.vision.modeling import factory as qat_factory
|
||||
from official.vision import configs
|
||||
from official.vision.losses import focal_loss
|
||||
from official.vision.losses import loss_utils
|
||||
from official.vision.modeling import factory
|
||||
from official.vision.modeling import retinanet_model
|
||||
from official.vision.modeling.layers import detection_generator
|
||||
from official.vision.serving import detection
|
||||
|
||||
|
||||
class ObjectDetectorModel(tf.keras.Model):
|
||||
"""An object detector model which can be trained using Model Maker's training API.
|
||||
|
||||
Attributes:
|
||||
loss_trackers: List of tf.keras.metrics.Mean objects used to track the loss
|
||||
during training.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_spec: ms.ModelSpec,
|
||||
model_options: model_opt.ObjectDetectorModelOptions,
|
||||
num_classes: int,
|
||||
) -> None:
|
||||
"""Initializes an ObjectDetectorModel.
|
||||
|
||||
Args:
|
||||
model_spec: Specification for the model.
|
||||
model_options: Model options for creating the model.
|
||||
num_classes: Number of classes for object detection.
|
||||
"""
|
||||
super().__init__()
|
||||
self._model_spec = model_spec
|
||||
self._model_options = model_options
|
||||
self._num_classes = num_classes
|
||||
self._model = self._build_model()
|
||||
checkpoint_folder = self._model_spec.downloaded_files.get_path()
|
||||
checkpoint_file = os.path.join(checkpoint_folder, 'ckpt-277200')
|
||||
self.load_checkpoint(checkpoint_file)
|
||||
self._model.summary()
|
||||
self.loss_trackers = [
|
||||
tf.keras.metrics.Mean(name=n)
|
||||
for n in ['total_loss', 'cls_loss', 'box_loss', 'model_loss']
|
||||
]
|
||||
|
||||
def _get_model_config(
|
||||
self,
|
||||
generator_config: configs.retinanet.DetectionGenerator = configs.retinanet.DetectionGenerator(),
|
||||
) -> configs.retinanet.RetinaNet:
|
||||
model_config = configs.retinanet.RetinaNet(
|
||||
min_level=3,
|
||||
max_level=7,
|
||||
num_classes=self._num_classes,
|
||||
input_size=self._model_spec.input_image_shape,
|
||||
anchor=configs.retinanet.Anchor(
|
||||
num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=3
|
||||
),
|
||||
backbone=configs.backbones.Backbone(
|
||||
type='mobilenet', mobilenet=configs.backbones.MobileNet()
|
||||
),
|
||||
decoder=configs.decoders.Decoder(
|
||||
type='fpn',
|
||||
fpn=configs.decoders.FPN(
|
||||
num_filters=128, use_separable_conv=True, use_keras_layer=True
|
||||
),
|
||||
),
|
||||
head=configs.retinanet.RetinaNetHead(
|
||||
num_filters=128, use_separable_conv=True
|
||||
),
|
||||
detection_generator=generator_config,
|
||||
norm_activation=configs.common.NormActivation(activation='relu6'),
|
||||
)
|
||||
return model_config
|
||||
|
||||
def _build_model(self) -> tf.keras.Model:
|
||||
"""Builds a RetinaNet object detector model."""
|
||||
input_specs = tf.keras.layers.InputSpec(
|
||||
shape=[None] + self._model_spec.input_image_shape
|
||||
)
|
||||
l2_regularizer = tf.keras.regularizers.l2(
|
||||
self._model_options.l2_weight_decay / 2.0
|
||||
)
|
||||
model_config = self._get_model_config()
|
||||
|
||||
return factory.build_retinanet(input_specs, model_config, l2_regularizer)
|
||||
|
||||
def save_checkpoint(self, checkpoint_path: str) -> None:
|
||||
"""Saves a model checkpoint to checkpoint_path.
|
||||
|
||||
Args:
|
||||
checkpoint_path: The path to save checkpoint.
|
||||
"""
|
||||
ckpt_items = {
|
||||
'backbone': self._model.backbone,
|
||||
'decoder': self._model.decoder,
|
||||
'head': self._model.head,
|
||||
}
|
||||
tf.train.Checkpoint(**ckpt_items).write(checkpoint_path)
|
||||
|
||||
def load_checkpoint(
|
||||
self, checkpoint_path: str, include_last_layer: bool = False
|
||||
) -> None:
|
||||
"""Loads a model checkpoint from checkpoint_path.
|
||||
|
||||
Args:
|
||||
checkpoint_path: The path to load a checkpoint from.
|
||||
include_last_layer: Whether or not to load the last classification layer.
|
||||
The size of the last classification layer will differ depending on the
|
||||
number of classes. When loading from the pre-trained checkpoint, this
|
||||
parameter should be False to avoid shape mismatch on the last layer.
|
||||
Defaults to False.
|
||||
"""
|
||||
dummy_input = tf.zeros([1] + self._model_spec.input_image_shape)
|
||||
self._model(dummy_input, training=True)
|
||||
if include_last_layer:
|
||||
head = self._model.head
|
||||
else:
|
||||
head_classifier = tf.train.Checkpoint(
|
||||
depthwise_kernel=self._model.head._classifier.depthwise_kernel # pylint:disable=protected-access
|
||||
)
|
||||
head_items = {
|
||||
'_classifier': head_classifier,
|
||||
'_box_norms': self._model.head._box_norms, # pylint:disable=protected-access
|
||||
'_box_regressor': self._model.head._box_regressor, # pylint:disable=protected-access
|
||||
'_cls_convs': self._model.head._cls_convs, # pylint:disable=protected-access
|
||||
'_cls_norms': self._model.head._cls_norms, # pylint:disable=protected-access
|
||||
'_box_convs': self._model.head._box_convs, # pylint:disable=protected-access
|
||||
}
|
||||
head = tf.train.Checkpoint(**head_items)
|
||||
ckpt_items = {
|
||||
'backbone': self._model.backbone,
|
||||
'decoder': self._model.decoder,
|
||||
'head': head,
|
||||
}
|
||||
ckpt = tf.train.Checkpoint(**ckpt_items)
|
||||
status = ckpt.read(checkpoint_path)
|
||||
status.expect_partial().assert_existing_objects_matched()
|
||||
|
||||
def convert_to_qat(self) -> None:
|
||||
"""Converts the model to a QAT RetinaNet model."""
|
||||
model = self._build_model()
|
||||
dummy_input = tf.zeros([1] + self._model_spec.input_image_shape)
|
||||
model(dummy_input, training=True)
|
||||
model.set_weights(self._model.get_weights())
|
||||
quantization_config = qat_common.Quantization(
|
||||
quantize_detection_decoder=True, quantize_detection_head=True
|
||||
)
|
||||
model_config = self._get_model_config()
|
||||
qat_model = qat_factory.build_qat_retinanet(
|
||||
model, quantization_config, model_config
|
||||
)
|
||||
self._model = qat_model
|
||||
|
||||
def export_saved_model(self, save_path: str):
|
||||
"""Exports a saved_model for tflite conversion.
|
||||
|
||||
The export process modifies the model in the following two ways:
|
||||
1. Replaces the nms operation in the detection generator with a custom
|
||||
TFLite compatible nms operation.
|
||||
2. Wraps the model with a DetectionModule which handles pre-processing
|
||||
and post-processing when running inference.
|
||||
|
||||
Args:
|
||||
save_path: Path to export the saved model.
|
||||
"""
|
||||
generator_config = configs.retinanet.DetectionGenerator(
|
||||
nms_version='tflite',
|
||||
tflite_post_processing=configs.common.TFLitePostProcessingConfig(
|
||||
nms_score_threshold=0,
|
||||
max_detections=10,
|
||||
max_classes_per_detection=1,
|
||||
normalize_anchor_coordinates=True,
|
||||
),
|
||||
)
|
||||
tflite_post_processing_config = (
|
||||
generator_config.tflite_post_processing.as_dict()
|
||||
)
|
||||
tflite_post_processing_config['input_image_size'] = (
|
||||
self._model_spec.input_image_shape[0],
|
||||
self._model_spec.input_image_shape[1],
|
||||
)
|
||||
detection_generator_obj = detection_generator.MultilevelDetectionGenerator(
|
||||
apply_nms=generator_config.apply_nms,
|
||||
pre_nms_top_k=generator_config.pre_nms_top_k,
|
||||
pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
|
||||
nms_iou_threshold=generator_config.nms_iou_threshold,
|
||||
max_num_detections=generator_config.max_num_detections,
|
||||
nms_version=generator_config.nms_version,
|
||||
use_cpu_nms=generator_config.use_cpu_nms,
|
||||
soft_nms_sigma=generator_config.soft_nms_sigma,
|
||||
tflite_post_processing_config=tflite_post_processing_config,
|
||||
return_decoded=generator_config.return_decoded,
|
||||
use_class_agnostic_nms=generator_config.use_class_agnostic_nms,
|
||||
)
|
||||
model_config = self._get_model_config(generator_config)
|
||||
model = retinanet_model.RetinaNetModel(
|
||||
self._model.backbone,
|
||||
self._model.decoder,
|
||||
self._model.head,
|
||||
detection_generator_obj,
|
||||
min_level=model_config.min_level,
|
||||
max_level=model_config.max_level,
|
||||
num_scales=model_config.anchor.num_scales,
|
||||
aspect_ratios=model_config.anchor.aspect_ratios,
|
||||
anchor_size=model_config.anchor.anchor_size,
|
||||
)
|
||||
task_config = configs.retinanet.RetinaNetTask(model=model_config)
|
||||
params = cfg.ExperimentConfig(
|
||||
task=task_config,
|
||||
)
|
||||
export_module = detection.DetectionModule(
|
||||
params=params,
|
||||
batch_size=1,
|
||||
input_image_size=self._model_spec.input_image_shape[:2],
|
||||
input_type='tflite',
|
||||
num_channels=self._model_spec.input_image_shape[2],
|
||||
model=model,
|
||||
)
|
||||
function_keys = {'tflite': tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY}
|
||||
signatures = export_module.get_inference_signatures(function_keys)
|
||||
|
||||
tf.saved_model.save(export_module, save_path, signatures=signatures)
|
||||
|
||||
# The remaining method overrides are used to train this object detector model
|
||||
# using model.fit().
|
||||
def call(
|
||||
self,
|
||||
images: Union[tf.Tensor, Sequence[tf.Tensor]],
|
||||
image_shape: Optional[tf.Tensor] = None,
|
||||
anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
|
||||
output_intermediate_features: bool = False,
|
||||
training: bool = None,
|
||||
) -> Mapping[str, tf.Tensor]:
|
||||
"""Overrides call from tf.keras.Model."""
|
||||
return self._model(
|
||||
images,
|
||||
image_shape,
|
||||
anchor_boxes,
|
||||
output_intermediate_features,
|
||||
training,
|
||||
)
|
||||
|
||||
def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
|
||||
"""Overrides compute_loss from tf.keras.Model."""
|
||||
cls_loss_fn = focal_loss.FocalLoss(
|
||||
alpha=0.25, gamma=1.5, reduction=tf.keras.losses.Reduction.SUM
|
||||
)
|
||||
box_loss_fn = tf.keras.losses.Huber(
|
||||
0.1, reduction=tf.keras.losses.Reduction.SUM
|
||||
)
|
||||
labels = y
|
||||
outputs = y_pred
|
||||
# Sums all positives in a batch for normalization and avoids zero
|
||||
# num_positives_sum, which would lead to inf loss during training
|
||||
cls_sample_weight = labels['cls_weights']
|
||||
box_sample_weight = labels['box_weights']
|
||||
num_positives = tf.reduce_sum(box_sample_weight) + 1.0
|
||||
cls_sample_weight = cls_sample_weight / num_positives
|
||||
box_sample_weight = box_sample_weight / num_positives
|
||||
y_true_cls = loss_utils.multi_level_flatten(
|
||||
labels['cls_targets'], last_dim=None
|
||||
)
|
||||
y_true_cls = tf.one_hot(y_true_cls, self._num_classes)
|
||||
y_pred_cls = loss_utils.multi_level_flatten(
|
||||
outputs['cls_outputs'], last_dim=self._num_classes
|
||||
)
|
||||
y_true_box = loss_utils.multi_level_flatten(
|
||||
labels['box_targets'], last_dim=4
|
||||
)
|
||||
y_pred_box = loss_utils.multi_level_flatten(
|
||||
outputs['box_outputs'], last_dim=4
|
||||
)
|
||||
|
||||
cls_loss = cls_loss_fn(
|
||||
y_true=y_true_cls, y_pred=y_pred_cls, sample_weight=cls_sample_weight
|
||||
)
|
||||
box_loss = box_loss_fn(
|
||||
y_true=y_true_box, y_pred=y_pred_box, sample_weight=box_sample_weight
|
||||
)
|
||||
|
||||
model_loss = cls_loss + 50 * box_loss
|
||||
total_loss = model_loss
|
||||
regularization_losses = self._model.losses
|
||||
if regularization_losses:
|
||||
reg_loss = tf.reduce_sum(regularization_losses)
|
||||
total_loss = model_loss + reg_loss
|
||||
all_losses = {
|
||||
'total_loss': total_loss,
|
||||
'cls_loss': cls_loss,
|
||||
'box_loss': box_loss,
|
||||
'model_loss': model_loss,
|
||||
}
|
||||
for m in self.metrics:
|
||||
m.update_state(all_losses[m.name])
|
||||
return total_loss
|
||||
|
||||
@property
|
||||
def metrics(self):
|
||||
"""Overrides metrics from tf.keras.Model."""
|
||||
return self.loss_trackers
|
||||
|
||||
def compute_metrics(self, x, y, y_pred, sample_weight=None):
|
||||
"""Overrides compute_metrics from tf.keras.Model."""
|
||||
return self.get_metrics_result()
|
||||
|
||||
def train_step(self, data):
|
||||
"""Overrides train_step from tf.keras.Model."""
|
||||
tf.keras.backend.set_learning_phase(1)
|
||||
x, y = data
|
||||
# Run forward pass.
|
||||
with tf.GradientTape() as tape:
|
||||
y_pred = self(x, training=True)
|
||||
loss = self.compute_loss(x, y, y_pred)
|
||||
self._validate_target_and_loss(y, loss)
|
||||
# Run backwards pass.
|
||||
self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
|
||||
return self.compute_metrics(x, y, y_pred)
|
||||
|
||||
def test_step(self, data):
|
||||
"""Overrides test_step from tf.keras.Model."""
|
||||
tf.keras.backend.set_learning_phase(0)
|
||||
x, y = data
|
||||
y_pred = self(
|
||||
x,
|
||||
anchor_boxes=y['anchor_boxes'],
|
||||
image_shape=y['image_info'][:, 1, :],
|
||||
training=False,
|
||||
)
|
||||
# Updates stateful loss metrics.
|
||||
self.compute_loss(x, y, y_pred)
|
||||
return self.compute_metrics(x, y, y_pred)
|
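A minimal usage sketch of the class above, assuming the pretrained checkpoint download succeeds and that `train_ds` is a tf.data.Dataset yielding (image, labels) pairs in the format produced by the Preprocessor added later in this change; `num_classes=3` and the optimizer settings are illustrative placeholders, mirroring the values used in the tests:

import tensorflow as tf

from mediapipe.model_maker.python.vision.object_detector import model as model_lib
from mediapipe.model_maker.python.vision.object_detector import model_options as model_opt
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms

spec = ms.SupportedModels.MOBILENET_V2.value()  # Resolves the MobileNetV2 SSD spec.
options = model_opt.ObjectDetectorModelOptions()
detector = model_lib.ObjectDetectorModel(spec, options, num_classes=3)  # Downloads and loads the checkpoint.
detector.compile(
    optimizer=tf.keras.optimizers.experimental.SGD(learning_rate=0.03, momentum=0.9)
)
# detector.fit(x=train_ds, epochs=2)  # `train_ds` must yield (image, labels) batches.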
|
@ -0,0 +1,28 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Configurable model options for object detector models."""
|
||||
|
||||
import dataclasses
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ObjectDetectorModelOptions:
|
||||
"""Configurable options for object detector model.
|
||||
|
||||
Attributes:
|
||||
l2_weight_decay: L2 regularization penalty used in
|
||||
https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/L2.
|
||||
"""
|
||||
|
||||
l2_weight_decay: float = 3.0e-05
|
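A small, hedged illustration of how this option is consumed: `_build_model` earlier in this change halves `l2_weight_decay` before handing it to a Keras L2 regularizer, so a custom value can be supplied like this:

from mediapipe.model_maker.python.vision.object_detector import model_options as model_opt

options = model_opt.ObjectDetectorModelOptions(l2_weight_decay=1e-4)
# ObjectDetectorModel._build_model applies this as tf.keras.regularizers.l2(1e-4 / 2.0).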
|
@ -0,0 +1,62 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Object detector model specification."""
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
from typing import List
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import file_util
|
||||
|
||||
|
||||
MOBILENET_V2_FILES = file_util.DownloadedFiles(
|
||||
'object_detector/mobilenetv2',
|
||||
'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i256_ckpt.tar.gz',
|
||||
is_folder=True,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelSpec(object):
|
||||
"""Specification of object detector model."""
|
||||
|
||||
# Mean and Stddev image preprocessing normalization values.
|
||||
mean_norm = (0.5,)
|
||||
stddev_norm = (0.5,)
|
||||
mean_rgb = (127.5,)
|
||||
stddev_rgb = (127.5,)
|
||||
|
||||
downloaded_files: file_util.DownloadedFiles
|
||||
input_image_shape: List[int]
|
||||
|
||||
|
||||
mobilenet_v2_spec = functools.partial(
|
||||
ModelSpec,
|
||||
downloaded_files=MOBILENET_V2_FILES,
|
||||
input_image_shape=[256, 256, 3],
|
||||
)
|
||||
|
||||
|
||||
@enum.unique
|
||||
class SupportedModels(enum.Enum):
|
||||
"""Predefined object detector model specs supported by Model Maker."""
|
||||
|
||||
MOBILENET_V2 = mobilenet_v2_spec
|
||||
|
||||
@classmethod
|
||||
def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
|
||||
"""Get model spec from the input enum and initializes it."""
|
||||
if spec not in cls:
|
||||
raise TypeError(f'Unsupported object detector spec: {spec}')
|
||||
return spec.value()
|
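A brief sketch of resolving a spec from the enum; note that no checkpoint download happens at this point, since `DownloadedFiles.get_path()` is only called when the model itself is built:

from mediapipe.model_maker.python.vision.object_detector import model_spec as ms

spec = ms.SupportedModels.get(ms.SupportedModels.MOBILENET_V2)
assert spec.input_image_shape == [256, 256, 3]
assert spec.mean_norm == (0.5,) and spec.stddev_norm == (0.5,)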
|
@ -0,0 +1,147 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from unittest import mock as unittest_mock
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.vision.object_detector import dataset as ds
|
||||
from mediapipe.model_maker.python.vision.object_detector import model as model_lib
|
||||
from mediapipe.model_maker.python.vision.object_detector import model_options as model_opt
|
||||
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
|
||||
from mediapipe.model_maker.python.vision.object_detector import preprocessor
|
||||
from mediapipe.tasks.python.test import test_utils as task_test_utils
|
||||
|
||||
|
||||
def _dicts_match(dict_1, dict_2):
|
||||
for key in dict_1:
|
||||
if key not in dict_2 or np.any(dict_1[key] != dict_2[key]):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _outputs_match(output1, output2):
|
||||
return _dicts_match(
|
||||
output1['cls_outputs'], output2['cls_outputs']
|
||||
) and _dicts_match(output1['box_outputs'], output2['box_outputs'])
|
||||
|
||||
|
||||
class ObjectDetectorModelTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
dataset_folder = task_test_utils.get_test_data_path('coco_data')
|
||||
cache_dir = self.create_tempdir()
|
||||
self.data = ds.Dataset.from_coco_folder(dataset_folder, cache_dir=cache_dir)
|
||||
self.model_spec = ms.SupportedModels.MOBILENET_V2.value()
|
||||
self.preprocessor = preprocessor.Preprocessor(self.model_spec)
|
||||
self.fake_inputs = np.random.uniform(
|
||||
low=0, high=1, size=(1, 256, 256, 3)
|
||||
).astype(np.float32)
|
||||
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||
# condition when downloading model since these tests may run in parallel.
|
||||
mock_gettempdir = unittest_mock.patch.object(
|
||||
tempfile,
|
||||
'gettempdir',
|
||||
return_value=self.create_tempdir(),
|
||||
autospec=True,
|
||||
)
|
||||
self.mock_gettempdir = mock_gettempdir.start()
|
||||
self.addCleanup(mock_gettempdir.stop)
|
||||
|
||||
def _create_model(self):
|
||||
model_options = model_opt.ObjectDetectorModelOptions()
|
||||
model = model_lib.ObjectDetectorModel(
|
||||
self.model_spec, model_options, self.data.num_classes
|
||||
)
|
||||
return model
|
||||
|
||||
def _train_model(self, model):
|
||||
"""Helper to run a simple training run on the model."""
|
||||
dataset = self.data.gen_tf_dataset(
|
||||
batch_size=2,
|
||||
is_training=True,
|
||||
shuffle=False,
|
||||
preprocess=self.preprocessor,
|
||||
)
|
||||
optimizer = tf.keras.optimizers.experimental.SGD(
|
||||
learning_rate=0.03, momentum=0.9
|
||||
)
|
||||
model.compile(optimizer=optimizer)
|
||||
model.fit(
|
||||
x=dataset, epochs=2, steps_per_epoch=None, validation_data=dataset
|
||||
)
|
||||
|
||||
def test_model(self):
|
||||
model = self._create_model()
|
||||
outputs_before = model(self.fake_inputs, training=True)
|
||||
self._train_model(model)
|
||||
outputs_after = model(self.fake_inputs, training=True)
|
||||
self.assertFalse(_outputs_match(outputs_before, outputs_after))
|
||||
|
||||
def test_model_convert_to_qat(self):
|
||||
model_options = model_opt.ObjectDetectorModelOptions()
|
||||
model = model_lib.ObjectDetectorModel(
|
||||
self.model_spec, model_options, self.data.num_classes
|
||||
)
|
||||
outputs_before = model(self.fake_inputs, training=True)
|
||||
model.convert_to_qat()
|
||||
outputs_after = model(self.fake_inputs, training=True)
|
||||
self.assertFalse(_outputs_match(outputs_before, outputs_after))
|
||||
outputs_before = outputs_after
|
||||
self._train_model(model)
|
||||
outputs_after = model(self.fake_inputs, training=True)
|
||||
self.assertFalse(_outputs_match(outputs_before, outputs_after))
|
||||
|
||||
def test_model_save_and_load_checkpoint(self):
|
||||
model = self._create_model()
|
||||
checkpoint_path = os.path.join(self.create_tempdir(), 'ckpt')
|
||||
model.save_checkpoint(checkpoint_path)
|
||||
data_checkpoint_file = checkpoint_path + '.data-00000-of-00001'
|
||||
index_checkpoint_file = checkpoint_path + '.index'
|
||||
self.assertTrue(os.path.exists(data_checkpoint_file))
|
||||
self.assertTrue(os.path.exists(index_checkpoint_file))
|
||||
self.assertGreater(os.path.getsize(data_checkpoint_file), 0)
|
||||
self.assertGreater(os.path.getsize(index_checkpoint_file), 0)
|
||||
outputs_before = model(self.fake_inputs, training=True)
|
||||
|
||||
# Check model output is different after training
|
||||
self._train_model(model)
|
||||
outputs_after = model(self.fake_inputs, training=True)
|
||||
self.assertFalse(_outputs_match(outputs_before, outputs_after))
|
||||
|
||||
# Check model output is the same after loading previous checkpoint
|
||||
model.load_checkpoint(checkpoint_path, include_last_layer=True)
|
||||
outputs_after = model(self.fake_inputs, training=True)
|
||||
self.assertTrue(_outputs_match(outputs_before, outputs_after))
|
||||
|
||||
def test_export_saved_model(self):
|
||||
export_dir = self.create_tempdir()
|
||||
export_path = os.path.join(export_dir, 'saved_model')
|
||||
model = self._create_model()
|
||||
model.export_saved_model(export_path)
|
||||
self.assertTrue(os.path.exists(export_path))
|
||||
self.assertGreater(os.path.getsize(export_path), 0)
|
||||
model.convert_to_qat()
|
||||
export_path = os.path.join(export_dir, 'saved_model_qat')
|
||||
model.export_saved_model(export_path)
|
||||
self.assertTrue(os.path.exists(export_path))
|
||||
self.assertGreater(os.path.getsize(export_path), 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -0,0 +1,353 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""APIs to train object detector model."""
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.tasks import classifier
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
from mediapipe.model_maker.python.vision.object_detector import dataset as ds
|
||||
from mediapipe.model_maker.python.vision.object_detector import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.vision.object_detector import model as model_lib
|
||||
from mediapipe.model_maker.python.vision.object_detector import model_options as model_opt
|
||||
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
|
||||
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
|
||||
from mediapipe.model_maker.python.vision.object_detector import preprocessor
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import object_detector as object_detector_writer
|
||||
from official.vision.evaluation import coco_evaluator
|
||||
|
||||
|
||||
class ObjectDetector(classifier.Classifier):
|
||||
"""ObjectDetector for building object detection model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_spec: ms.ModelSpec,
|
||||
label_names: List[str],
|
||||
hparams: hp.HParams,
|
||||
model_options: model_opt.ObjectDetectorModelOptions,
|
||||
) -> None:
|
||||
"""Initializes ObjectDetector class.
|
||||
|
||||
Args:
|
||||
model_spec: Specifications for the model.
|
||||
label_names: A list of label names for the classes.
|
||||
hparams: The hyperparameters for training object detector.
|
||||
model_options: Options for creating the object detector model.
|
||||
"""
|
||||
super().__init__(
|
||||
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle
|
||||
)
|
||||
self._preprocessor = preprocessor.Preprocessor(model_spec)
|
||||
self._hparams = hparams
|
||||
self._model_options = model_options
|
||||
self._optimizer = self._create_optimizer()
|
||||
self._is_qat = False
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
train_data: ds.Dataset,
|
||||
validation_data: ds.Dataset,
|
||||
options: object_detector_options.ObjectDetectorOptions,
|
||||
) -> 'ObjectDetector':
|
||||
"""Creates and trains an ObjectDetector.
|
||||
|
||||
Loads data and trains the model based on data for object detection.
|
||||
|
||||
Args:
|
||||
train_data: Training data.
|
||||
validation_data: Validation data.
|
||||
options: Configurations for creating and training object detector.
|
||||
|
||||
Returns:
|
||||
An instance of ObjectDetector.
|
||||
"""
|
||||
if options.hparams is None:
|
||||
options.hparams = hp.HParams()
|
||||
|
||||
if options.model_options is None:
|
||||
options.model_options = model_opt.ObjectDetectorModelOptions()
|
||||
|
||||
spec = ms.SupportedModels.get(options.supported_model)
|
||||
object_detector = cls(
|
||||
model_spec=spec,
|
||||
label_names=train_data.label_names,
|
||||
hparams=options.hparams,
|
||||
model_options=options.model_options,
|
||||
)
|
||||
object_detector._create_and_train_model(train_data, validation_data)
|
||||
return object_detector
|
||||
|
||||
def _create_and_train_model(
|
||||
self, train_data: ds.Dataset, validation_data: ds.Dataset
|
||||
):
|
||||
"""Creates and trains the model.
|
||||
|
||||
Args:
|
||||
train_data: Training data.
|
||||
validation_data: Validation data.
|
||||
"""
|
||||
self._create_model()
|
||||
self._train_model(
|
||||
train_data, validation_data, preprocessor=self._preprocessor
|
||||
)
|
||||
self._save_float_ckpt()
|
||||
|
||||
def _create_model(self) -> None:
|
||||
"""Creates the object detector model."""
|
||||
self._model = model_lib.ObjectDetectorModel(
|
||||
model_spec=self._model_spec,
|
||||
model_options=self._model_options,
|
||||
num_classes=self._num_classes,
|
||||
)
|
||||
|
||||
def _save_float_ckpt(self) -> None:
|
||||
"""Saves a checkpoint of the trained float model.
|
||||
|
||||
The default save path is {hparams.export_dir}/float_ckpt. Note that
|
||||
`float_ckpt` represents a file prefix, not a directory. The resulting files
|
||||
saved to {hparams.export_dir} will be:
|
||||
- float_ckpt.data-00000-of-00001
|
||||
- float_ckpt.index
|
||||
"""
|
||||
save_path = os.path.join(self._hparams.export_dir, 'float_ckpt')
|
||||
if not os.path.exists(self._hparams.export_dir):
|
||||
os.makedirs(self._hparams.export_dir)
|
||||
self._model.save_checkpoint(save_path)
|
||||
|
||||
def restore_float_ckpt(self) -> None:
|
||||
"""Loads a float checkpoint of the model from {hparams.export_dir}/float_ckpt.
|
||||
|
||||
The float checkpoint at {hparams.export_dir}/float_ckpt is automatically
|
||||
saved after training an ObjectDetector using the `create` method. This
|
||||
method is used to restore the trained float checkpoint state of the model in
|
||||
order to run `quantization_aware_training` multiple times. Example usage:
|
||||
|
||||
# Train a model
|
||||
model = object_detector.create(...)
|
||||
# Run QAT
|
||||
model.quantization_aware_training(...)
|
||||
model.evaluate(...)
|
||||
# Restore the float checkpoint to run QAT again
|
||||
model.restore_float_ckpt()
|
||||
# Run QAT with different parameters
|
||||
model.quantization_aware_training(...)
|
||||
model.evaluate(...)
|
||||
"""
|
||||
self._create_model()
|
||||
self._model.load_checkpoint(
|
||||
os.path.join(self._hparams.export_dir, 'float_ckpt'),
|
||||
include_last_layer=True,
|
||||
)
|
||||
self._model.compile()
|
||||
self._is_qat = False
|
||||
|
||||
# TODO: Refactor this method to utilize shared training function
|
||||
def quantization_aware_training(
|
||||
self,
|
||||
train_data: ds.Dataset,
|
||||
validation_data: ds.Dataset,
|
||||
qat_hparams: hp.QATHParams,
|
||||
) -> None:
|
||||
"""Runs quantization aware training(QAT) on the model.
|
||||
|
||||
The QAT step happens after training a regular float model from the `create`
|
||||
method. This additional step will fine-tune the model with a lower precision
|
||||
in order to mimic the behavior of a quantized model. The resulting quantized
|
||||
model generally has better performance than a model which is quantized
|
||||
without running QAT. See the following link for more information:
|
||||
- https://www.tensorflow.org/model_optimization/guide/quantization/training
|
||||
|
||||
Just like training the float model using the `create` method, the QAT step
|
||||
also requires some manual tuning of hyperparameters. In order to run QAT
|
||||
more than once for purposes such as hyperparameter tuning, use the
|
||||
`restore_float_ckpt` method to restore the model state to the trained float
|
||||
checkpoint without having to rerun the `create` method.
|
||||
|
||||
Args:
|
||||
train_data: Training dataset.
|
||||
validation_data: Validation dataset.
|
||||
qat_hparams: Configuration for QAT.
|
||||
"""
|
||||
self._model.convert_to_qat()
|
||||
learning_rate_fn = tf.keras.optimizers.schedules.ExponentialDecay(
|
||||
qat_hparams.learning_rate * qat_hparams.batch_size / 256,
|
||||
decay_steps=qat_hparams.decay_steps,
|
||||
decay_rate=qat_hparams.decay_rate,
|
||||
staircase=True,
|
||||
)
|
||||
optimizer = tf.keras.optimizers.experimental.SGD(
|
||||
learning_rate=learning_rate_fn, momentum=0.9
|
||||
)
|
||||
if len(train_data) < qat_hparams.batch_size:
|
||||
raise ValueError(
|
||||
f"The size of the train_data {len(train_data)} can't be smaller than"
|
||||
f' batch_size {qat_hparams.batch_size}. To solve this problem, set'
|
||||
' the batch_size smaller or increase the size of the train_data.'
|
||||
)
|
||||
|
||||
train_dataset = train_data.gen_tf_dataset(
|
||||
batch_size=qat_hparams.batch_size,
|
||||
is_training=True,
|
||||
shuffle=self._shuffle,
|
||||
preprocess=self._preprocessor,
|
||||
)
|
||||
steps_per_epoch = model_util.get_steps_per_epoch(
|
||||
steps_per_epoch=None,
|
||||
batch_size=qat_hparams.batch_size,
|
||||
train_data=train_data,
|
||||
)
|
||||
train_dataset = train_dataset.take(count=steps_per_epoch)
|
||||
validation_dataset = validation_data.gen_tf_dataset(
|
||||
batch_size=qat_hparams.batch_size,
|
||||
is_training=False,
|
||||
preprocess=self._preprocessor,
|
||||
)
|
||||
self._model.compile(optimizer=optimizer)
|
||||
self._model.fit(
|
||||
x=train_dataset,
|
||||
epochs=qat_hparams.epochs,
|
||||
steps_per_epoch=None,
|
||||
validation_data=validation_dataset,
|
||||
)
|
||||
self._is_qat = True
|
||||
|
||||
def evaluate(
|
||||
self, dataset: ds.Dataset, batch_size: int = 1
|
||||
) -> Tuple[List[float], Dict[str, float]]:
|
||||
"""Overrides Classifier.evaluate to calculate COCO metrics."""
|
||||
dataset = dataset.gen_tf_dataset(
|
||||
batch_size, is_training=False, preprocess=self._preprocessor
|
||||
)
|
||||
losses = self._model.evaluate(dataset)
|
||||
coco_eval = coco_evaluator.COCOEvaluator(
|
||||
annotation_file=None,
|
||||
include_mask=False,
|
||||
per_category_metrics=True,
|
||||
max_num_eval_detections=100,
|
||||
)
|
||||
for batch in dataset:
|
||||
x, y = batch
|
||||
y_pred = self._model(
|
||||
x,
|
||||
anchor_boxes=y['anchor_boxes'],
|
||||
image_shape=y['image_info'][:, 1, :],
|
||||
training=False,
|
||||
)
|
||||
groundtruths = y['groundtruths']
|
||||
y_pred['image_info'] = groundtruths['image_info']
|
||||
y_pred['source_id'] = groundtruths['source_id']
|
||||
coco_eval.update_state(groundtruths, y_pred)
|
||||
coco_metrics = coco_eval.result()
|
||||
return losses, coco_metrics
|
||||
|
||||
def export_model(
|
||||
self,
|
||||
model_name: str = 'model.tflite',
|
||||
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
||||
):
|
||||
"""Converts and saves the model to a TFLite file with metadata included.
|
||||
|
||||
The model export format is automatically set based on whether or not
|
||||
`quantization_aware_training` (QAT) was run. The model exports to float32 by
|
||||
default and will export to an int8 quantized model if QAT was run. To export
|
||||
a float32 model after running QAT, run `restore_float_ckpt` before this
|
||||
method. For custom post-training quantization without QAT, use the
|
||||
quantization_config parameter.
|
||||
|
||||
Note that only the TFLite file is needed for deployment. This function also
|
||||
saves a metadata.json file to the same directory as the TFLite file which
|
||||
can be used to interpret the metadata content in the TFLite file.
|
||||
|
||||
Args:
|
||||
model_name: File name to save TFLite model with metadata. The full export
|
||||
path is {self._hparams.export_dir}/{model_name}.
|
||||
quantization_config: The configuration for model quantization. Note that
|
||||
int8 quantization aware training is automatically applied when possible.
|
||||
This parameter is used to specify other post-training quantization
|
||||
options such as fp16 and int8 without QAT.
|
||||
|
||||
Raises:
|
||||
ValueError: If a custom quantization_config is specified when the model
|
||||
has quantization aware training enabled.
|
||||
"""
|
||||
if quantization_config:
|
||||
if self._is_qat:
|
||||
raise ValueError(
|
||||
'Exporting a qat model with a custom quantization_config is not '
|
||||
'supported.'
|
||||
)
|
||||
else:
|
||||
print(
|
||||
'Exporting with custom post-training-quantization: ',
|
||||
quantization_config,
|
||||
)
|
||||
else:
|
||||
if self._is_qat:
|
||||
print('Exporting a qat int8 model')
|
||||
quantization_config = quantization.QuantizationConfig(
|
||||
inference_input_type=tf.uint8, inference_output_type=tf.uint8
|
||||
)
|
||||
else:
|
||||
print('Exporting a floating point model')
|
||||
|
||||
tflite_file = os.path.join(self._hparams.export_dir, model_name)
|
||||
metadata_file = os.path.join(self._hparams.export_dir, 'metadata.json')
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
save_path = os.path.join(temp_dir, 'saved_model')
|
||||
self._model.export_saved_model(save_path)
|
||||
converter = tf.lite.TFLiteConverter.from_saved_model(save_path)
|
||||
if quantization_config:
|
||||
converter = quantization_config.set_converter_with_quantization(
|
||||
converter, preprocess=self._preprocessor
|
||||
)
|
||||
|
||||
converter.target_spec.supported_ops = (tf.lite.OpsSet.TFLITE_BUILTINS,)
|
||||
tflite_model = converter.convert()
|
||||
|
||||
writer = object_detector_writer.MetadataWriter.create(
|
||||
tflite_model,
|
||||
self._model_spec.mean_rgb,
|
||||
self._model_spec.stddev_rgb,
|
||||
labels=metadata_writer.Labels().add(list(self._label_names)),
|
||||
)
|
||||
tflite_model_with_metadata, metadata_json = writer.populate()
|
||||
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
|
||||
with open(metadata_file, 'w') as f:
|
||||
f.write(metadata_json)
|
||||
|
||||
def _create_optimizer(self) -> tf.keras.optimizers.Optimizer:
|
||||
"""Creates an optimizer with learning rate schedule for regular training.
|
||||
|
||||
Uses Keras PiecewiseConstantDecay schedule by default.
|
||||
|
||||
Returns:
|
||||
A tf.keras.optimizer.Optimizer for model training.
|
||||
"""
|
||||
init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
|
||||
lr_values = [init_lr] + [
|
||||
init_lr * m for m in self._hparams.learning_rate_decay_multipliers
|
||||
]
|
||||
learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
|
||||
self._hparams.learning_rate_boundaries, lr_values
|
||||
)
|
||||
return tf.keras.optimizers.experimental.SGD(
|
||||
learning_rate=learning_rate_fn, momentum=0.9
|
||||
)
|
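A worked sketch of the schedule built by `_create_optimizer` above, using illustrative hyperparameter values rather than the actual defaults (which live in hyperparameters.py): with learning_rate=0.3, batch_size=8, boundaries [3000, 6000] and decay multipliers [0.3, 0.1], the schedule comes out as follows:

import tensorflow as tf

learning_rate, batch_size = 0.3, 8            # Illustrative values, not the defaults.
boundaries, multipliers = [3000, 6000], [0.3, 0.1]

init_lr = learning_rate * batch_size / 256    # Initial rate scales linearly with batch size.
lr_values = [init_lr] + [init_lr * m for m in multipliers]
schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries, lr_values)
optimizer = tf.keras.optimizers.experimental.SGD(learning_rate=schedule, momentum=0.9)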
|
@ -0,0 +1,84 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Demo for making an object detector model by MediaPipe Model Maker."""
|
||||
|
||||
import os
|
||||
# Dependency imports
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
|
||||
from mediapipe.model_maker.python.vision import object_detector
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/object_detector/testdata/coco_data'
|
||||
|
||||
|
||||
def define_flags() -> None:
|
||||
"""Define flags for the object detection model maker demo."""
|
||||
flags.DEFINE_string(
|
||||
'export_dir', None, 'The directory to save exported files.'
|
||||
)
|
||||
flags.DEFINE_string(
|
||||
'input_data_dir',
|
||||
None,
|
||||
"""The directory with input training data. If the training data is not
|
||||
specified, the pipeline will use the test dataset.""",
|
||||
)
|
||||
flags.DEFINE_bool('qat', True, 'Whether or not to do QAT.')
|
||||
flags.mark_flag_as_required('export_dir')
|
||||
|
||||
|
||||
def run(data_dir: str, export_dir: str, qat: bool):
|
||||
"""Runs demo."""
|
||||
data = object_detector.Dataset.from_coco_folder(data_dir)
|
||||
train_data, rest_data = data.split(0.6)
|
||||
validation_data, test_data = rest_data.split(0.5)
|
||||
|
||||
hparams = object_detector.HParams(batch_size=1, export_dir=export_dir)
|
||||
options = object_detector.ObjectDetectorOptions(
|
||||
supported_model=object_detector.SupportedModels.MOBILENET_V2,
|
||||
hparams=hparams,
|
||||
)
|
||||
model = object_detector.ObjectDetector.create(
|
||||
train_data=train_data, validation_data=validation_data, options=options
|
||||
)
|
||||
loss, coco_metrics = model.evaluate(test_data, batch_size=1)
|
||||
print(f'Evaluation loss:{loss}, coco_metrics:{coco_metrics}')
|
||||
if qat:
|
||||
qat_hparams = object_detector.QATHParams(batch_size=1)
|
||||
model.quantization_aware_training(train_data, validation_data, qat_hparams)
|
||||
qat_loss, qat_coco_metrics = model.evaluate(test_data, batch_size=1)
|
||||
print(f'QAT Evaluation loss:{qat_loss}, coco_metrics:{qat_coco_metrics}')
|
||||
|
||||
model.export_model()
|
||||
|
||||
|
||||
def main(_) -> None:
|
||||
logging.set_verbosity(logging.INFO)
|
||||
|
||||
if FLAGS.input_data_dir is None:
|
||||
data_dir = os.path.join(FLAGS.test_srcdir, TEST_DATA_DIR)
|
||||
else:
|
||||
data_dir = FLAGS.input_data_dir
|
||||
|
||||
export_dir = os.path.expanduser(FLAGS.export_dir)
|
||||
run(data_dir=data_dir, export_dir=export_dir, qat=FLAGS.qat)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
define_flags()
|
||||
app.run(main)
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Options for building object detector."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Optional
|
||||
|
||||
from mediapipe.model_maker.python.vision.object_detector import hyperparameters
|
||||
from mediapipe.model_maker.python.vision.object_detector import model_options as model_opt
|
||||
from mediapipe.model_maker.python.vision.object_detector import model_spec
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ObjectDetectorOptions:
|
||||
"""Configurable options for building object detector.
|
||||
|
||||
Attributes:
|
||||
supported_model: A model from the SupportedModels enum.
|
||||
model_options: A set of options for configuring the selected model.
|
||||
hparams: A set of hyperparameters used to train the object detector.
|
||||
"""
|
||||
|
||||
supported_model: model_spec.SupportedModels
|
||||
model_options: Optional[model_opt.ObjectDetectorModelOptions] = None
|
||||
hparams: Optional[hyperparameters.HParams] = None
|
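A hedged sketch of the options object in use; when `hparams` or `model_options` are left as None, `ObjectDetector.create` (earlier in this change) substitutes the `HParams()` and `ObjectDetectorModelOptions()` defaults:

from mediapipe.model_maker.python.vision.object_detector import model_spec
from mediapipe.model_maker.python.vision.object_detector import object_detector_options

options = object_detector_options.ObjectDetectorOptions(
    supported_model=model_spec.SupportedModels.MOBILENET_V2
)
# options.hparams and options.model_options are None here; create() fills in the defaults.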
|
@ -0,0 +1,121 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from unittest import mock as unittest_mock
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.vision.object_detector import dataset
|
||||
from mediapipe.model_maker.python.vision.object_detector import hyperparameters
|
||||
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
|
||||
from mediapipe.model_maker.python.vision.object_detector import object_detector
|
||||
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
|
||||
from mediapipe.tasks.python.test import test_utils as task_test_utils
|
||||
|
||||
|
||||
class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
dataset_folder = task_test_utils.get_test_data_path('coco_data')
|
||||
cache_dir = self.create_tempdir()
|
||||
self.data = dataset.Dataset.from_coco_folder(
|
||||
dataset_folder, cache_dir=cache_dir
|
||||
)
|
||||
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||
# condition when downloading model since these tests may run in parallel.
|
||||
mock_gettempdir = unittest_mock.patch.object(
|
||||
tempfile,
|
||||
'gettempdir',
|
||||
return_value=self.create_tempdir(),
|
||||
autospec=True,
|
||||
)
|
||||
self.mock_gettempdir = mock_gettempdir.start()
|
||||
self.addCleanup(mock_gettempdir.stop)
|
||||
|
||||
def test_object_detector(self):
|
||||
hparams = hyperparameters.HParams(
|
||||
epochs=10,
|
||||
batch_size=2,
|
||||
learning_rate=0.9,
|
||||
shuffle=False,
|
||||
export_dir=self.create_tempdir(),
|
||||
)
|
||||
options = object_detector_options.ObjectDetectorOptions(
|
||||
supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams
|
||||
)
|
||||
# Test `create`
|
||||
model = object_detector.ObjectDetector.create(
|
||||
train_data=self.data, validation_data=self.data, options=options
|
||||
)
|
||||
losses, coco_metrics = model.evaluate(self.data)
|
||||
self._assert_ap_greater(coco_metrics)
|
||||
self.assertFalse(model._is_qat)
|
||||
# Test float export_model
|
||||
model.export_model()
|
||||
output_metadata_file = os.path.join(
|
||||
options.hparams.export_dir, 'metadata.json'
|
||||
)
|
||||
output_tflite_file = os.path.join(
|
||||
options.hparams.export_dir, 'model.tflite'
|
||||
)
|
||||
|
||||
self.assertTrue(os.path.exists(output_tflite_file))
|
||||
self.assertGreater(os.path.getsize(output_tflite_file), 0)
|
||||
self.assertTrue(os.path.exists(output_metadata_file))
|
||||
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
||||
|
||||
# Test `quantization_aware_training`
|
||||
qat_hparams = hyperparameters.QATHParams(
|
||||
learning_rate=0.9,
|
||||
batch_size=2,
|
||||
epochs=5,
|
||||
decay_steps=6,
|
||||
decay_rate=0.96,
|
||||
)
|
||||
model.quantization_aware_training(self.data, self.data, qat_hparams)
|
||||
qat_losses, qat_coco_metrics = model.evaluate(self.data)
|
||||
self._assert_ap_greater(qat_coco_metrics)
|
||||
self.assertNotAllEqual(losses, qat_losses)
|
||||
self.assertTrue(model._is_qat)
|
||||
model.export_model('model_qat.tflite')
|
||||
output_metadata_file = os.path.join(
|
||||
options.hparams.export_dir, 'metadata.json'
|
||||
)
|
||||
output_tflite_file = os.path.join(
|
||||
options.hparams.export_dir, 'model_qat.tflite'
|
||||
)
|
||||
|
||||
self.assertTrue(os.path.exists(output_tflite_file))
|
||||
self.assertGreater(os.path.getsize(output_tflite_file), 0)
|
||||
self.assertLess(os.path.getsize(output_tflite_file), 3500000)
|
||||
self.assertTrue(os.path.exists(output_metadata_file))
|
||||
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
||||
|
||||
# Load float ckpt test
|
||||
model.restore_float_ckpt()
|
||||
losses_2, _ = model.evaluate(self.data)
|
||||
self.assertAllEqual(losses, losses_2)
|
||||
self.assertNotAllEqual(qat_losses, losses_2)
|
||||
self.assertFalse(model._is_qat)
|
||||
|
||||
def _assert_ap_greater(self, coco_metrics, threshold=0.0):
|
||||
self.assertGreaterEqual(coco_metrics['AP'], threshold)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -0,0 +1,163 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Preprocessor for object detector."""
|
||||
from typing import Any, Mapping, Tuple
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
|
||||
from official.vision.dataloaders import utils
|
||||
from official.vision.ops import anchor
|
||||
from official.vision.ops import box_ops
|
||||
from official.vision.ops import preprocess_ops
|
||||
|
||||
|
||||
# TODO Combine preprocessing logic with image_preprocessor.
|
||||
class Preprocessor(object):
|
||||
"""Preprocessor for object detector."""
|
||||
|
||||
def __init__(self, model_spec: ms.ModelSpec):
|
||||
"""Initialize a Preprocessor."""
|
||||
self._mean_norm = model_spec.mean_norm
|
||||
self._stddev_norm = model_spec.stddev_norm
|
||||
self._output_size = model_spec.input_image_shape[:2]
|
||||
self._min_level = 3
|
||||
self._max_level = 7
|
||||
self._num_scales = 3
|
||||
self._aspect_ratios = [0.5, 1, 2]
|
||||
self._anchor_size = 3
|
||||
self._dtype = tf.float32
|
||||
self._match_threshold = 0.5
|
||||
self._unmatched_threshold = 0.5
|
||||
self._aug_scale_min = 0.5
|
||||
self._aug_scale_max = 2.0
|
||||
self._max_num_instances = 100
|
||||
|
||||
def __call__(
|
||||
self, data: Mapping[str, Any], is_training: bool = True
|
||||
) -> Tuple[tf.Tensor, Mapping[str, Any]]:
|
||||
"""Run the preprocessor on an example.
|
||||
|
||||
The data dict should always contain the following keys:
|
||||
- image
|
||||
- groundtruth_classes
|
||||
- groundtruth_boxes
|
||||
- groundtruth_is_crowd
|
||||
Additional keys needed when is_training is set to False:
|
||||
- groundtruth_area
|
||||
- source_id
|
||||
- height
|
||||
- width
|
||||
|
||||
Args:
|
||||
data: A dict of object detector inputs.
|
||||
is_training: Whether or not the data is used for training.
|
||||
|
||||
Returns:
|
||||
A tuple of (image, labels) where image is a Tensor and labels is a dict.
|
||||
"""
|
||||
classes = data['groundtruth_classes']
|
||||
boxes = data['groundtruth_boxes']
|
||||
|
||||
# Get original image.
|
||||
image = data['image']
|
||||
image_shape = tf.shape(input=image)[0:2]
|
||||
|
||||
# Normalize image with mean and std pixel values.
|
||||
image = preprocess_ops.normalize_image(
|
||||
image, self._mean_norm, self._stddev_norm
|
||||
)
|
||||
|
||||
# Flip image randomly during training.
|
||||
if is_training:
|
||||
image, boxes, _ = preprocess_ops.random_horizontal_flip(image, boxes)
|
||||
|
||||
# Convert boxes from normalized coordinates to pixel coordinates.
|
||||
boxes = box_ops.denormalize_boxes(boxes, image_shape)
|
||||
|
||||
# Resize and crop image.
|
||||
image, image_info = preprocess_ops.resize_and_crop_image(
|
||||
image,
|
||||
self._output_size,
|
||||
padded_size=preprocess_ops.compute_padded_size(
|
||||
self._output_size, 2**self._max_level
|
||||
),
|
||||
aug_scale_min=(self._aug_scale_min if is_training else 1.0),
|
||||
aug_scale_max=(self._aug_scale_max if is_training else 1.0),
|
||||
)
|
||||
image_height, image_width, _ = image.get_shape().as_list()
|
||||
|
||||
# Resize and crop boxes.
|
||||
image_scale = image_info[2, :]
|
||||
offset = image_info[3, :]
|
||||
boxes = preprocess_ops.resize_and_crop_boxes(
|
||||
boxes, image_scale, image_info[1, :], offset
|
||||
)
|
||||
# Filter out ground-truth boxes that are all zeros.
|
||||
indices = box_ops.get_non_empty_box_indices(boxes)
|
||||
boxes = tf.gather(boxes, indices)
|
||||
classes = tf.gather(classes, indices)
|
||||
|
||||
# Assign anchors.
|
||||
input_anchor = anchor.build_anchor_generator(
|
||||
min_level=self._min_level,
|
||||
max_level=self._max_level,
|
||||
num_scales=self._num_scales,
|
||||
aspect_ratios=self._aspect_ratios,
|
||||
anchor_size=self._anchor_size,
|
||||
)
|
||||
anchor_boxes = input_anchor(image_size=(image_height, image_width))
|
||||
anchor_labeler = anchor.AnchorLabeler(
|
||||
self._match_threshold, self._unmatched_threshold
|
||||
)
|
||||
(cls_targets, box_targets, _, cls_weights, box_weights) = (
|
||||
anchor_labeler.label_anchors(
|
||||
anchor_boxes, boxes, tf.expand_dims(classes, axis=1)
|
||||
)
|
||||
)
|
||||
|
||||
# Cast input image to desired data type.
|
||||
image = tf.cast(image, dtype=self._dtype)
|
||||
|
||||
# Pack labels for model_fn outputs.
|
||||
labels = {
|
||||
'cls_targets': cls_targets,
|
||||
'box_targets': box_targets,
|
||||
'anchor_boxes': anchor_boxes,
|
||||
'cls_weights': cls_weights,
|
||||
'box_weights': box_weights,
|
||||
'image_info': image_info,
|
||||
}
|
||||
if not is_training:
|
||||
groundtruths = {
|
||||
'source_id': data['source_id'],
|
||||
'height': data['height'],
|
||||
'width': data['width'],
|
||||
'num_detections': tf.shape(data['groundtruth_classes']),
|
||||
'image_info': image_info,
|
||||
'boxes': box_ops.denormalize_boxes(
|
||||
data['groundtruth_boxes'], image_shape
|
||||
),
|
||||
'classes': data['groundtruth_classes'],
|
||||
'areas': data['groundtruth_area'],
|
||||
'is_crowds': tf.cast(data['groundtruth_is_crowd'], tf.int32),
|
||||
}
|
||||
groundtruths['source_id'] = utils.process_source_id(
|
||||
groundtruths['source_id']
|
||||
)
|
||||
groundtruths = utils.pad_groundtruths_to_fixed_size(
|
||||
groundtruths, self._max_num_instances
|
||||
)
|
||||
labels.update({'groundtruths': groundtruths})
|
||||
return image, labels
|
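A short sketch of plugging this preprocessor into the dataset pipeline, assuming `/path/to/coco_folder` is a hypothetical folder in the layout expected by `Dataset.from_coco_folder`; the gen_tf_dataset arguments mirror those used in the tests:

from mediapipe.model_maker.python.vision.object_detector import dataset as ds
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
from mediapipe.model_maker.python.vision.object_detector import preprocessor

spec = ms.SupportedModels.MOBILENET_V2.value()
data = ds.Dataset.from_coco_folder('/path/to/coco_folder')  # Hypothetical path.
train_ds = data.gen_tf_dataset(
    batch_size=2, is_training=True, shuffle=False, preprocess=preprocessor.Preprocessor(spec)
)
for images, labels in train_ds.take(1):
    print(images.shape, sorted(labels.keys()))  # Expect (2, 256, 256, 3) plus the label keys listed above.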
|
@ -0,0 +1,158 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.vision.core import test_utils
|
||||
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
|
||||
from mediapipe.model_maker.python.vision.object_detector import preprocessor as preprocessor_lib
|
||||
|
||||
|
||||
class DatasetTest(tf.test.TestCase, parameterized.TestCase):
|
||||
MAX_IMAGE_SIZE = 360
|
||||
OUTPUT_SIZE = 256
|
||||
NUM_CLASSES = 10
|
||||
NUM_EXAMPLES = 3
|
||||
MIN_LEVEL = 3
|
||||
MAX_LEVEL = 7
|
||||
NUM_SCALES = 3
|
||||
ASPECT_RATIOS = [0.5, 1, 2]
|
||||
MAX_NUM_INSTANCES = 100
|
||||
|
||||
def _get_rand_example(self):
|
||||
num_annotations = random.randint(1, 3)
|
||||
bboxes, classes, is_crowds = [], [], []
|
||||
image_size = random.randint(10, self.MAX_IMAGE_SIZE + 1)
|
||||
rgb = [random.uniform(0, 255) for _ in range(3)]
|
||||
image = test_utils.fill_image(rgb, image_size)
|
||||
for _ in range(num_annotations):
|
||||
x1, x2 = random.uniform(0, image_size), random.uniform(0, image_size)
|
||||
y1, y2 = random.uniform(0, image_size), random.uniform(0, image_size)
|
||||
bbox = [min(x1, x2), min(y1, y2), abs(x1 - x2), abs(y1 - y2)]
|
||||
bboxes.append(bbox)
|
||||
classes.append(random.randint(0, self.NUM_CLASSES - 1))
|
||||
is_crowds.append(0)
|
||||
return {
|
||||
'image': tf.cast(image, dtype=tf.float32),
|
||||
'groundtruth_boxes': tf.cast(bboxes, dtype=tf.float32),
|
||||
'groundtruth_classes': tf.cast(classes, dtype=tf.int64),
|
||||
'groundtruth_is_crowd': tf.cast(is_crowds, dtype=tf.bool),
|
||||
'groundtruth_area': tf.cast(is_crowds, dtype=tf.float32),
|
||||
'source_id': tf.cast(1, dtype=tf.int64),
|
||||
'height': tf.cast(image_size, dtype=tf.int64),
|
||||
'width': tf.cast(image_size, dtype=tf.int64),
|
||||
}
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
dataset = [self._get_rand_example() for _ in range(self.NUM_EXAMPLES)]
|
||||
|
||||
def my_generator(data):
|
||||
for item in data:
|
||||
yield item
|
||||
|
||||
self.dataset = tf.data.Dataset.from_generator(
|
||||
lambda: my_generator(dataset),
|
||||
output_types={
|
||||
'image': tf.float32,
|
||||
'groundtruth_classes': tf.int64,
|
||||
'groundtruth_boxes': tf.float32,
|
||||
'groundtruth_is_crowd': tf.bool,
|
||||
'groundtruth_area': tf.float32,
|
||||
'source_id': tf.int64,
|
||||
'height': tf.int64,
|
||||
'width': tf.int64,
|
||||
},
|
||||
)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name='training',
|
||||
is_training=True,
|
||||
),
|
||||
dict(
|
||||
testcase_name='evaluation',
|
||||
is_training=False,
|
||||
),
|
||||
)
|
||||
def test_preprocessor(self, is_training):
|
||||
model_spec = ms.SupportedModels.MOBILENET_V2.value()
|
||||
labels_keys = [
|
||||
'cls_targets',
|
||||
'box_targets',
|
||||
'anchor_boxes',
|
||||
'cls_weights',
|
||||
'box_weights',
|
||||
'image_info',
|
||||
]
|
||||
if not is_training:
|
||||
labels_keys.append('groundtruths')
|
||||
preprocessor = preprocessor_lib.Preprocessor(model_spec)
|
||||
for example in self.dataset:
|
||||
result = preprocessor(example, is_training=is_training)
|
||||
image, labels = result
|
||||
self.assertAllEqual(image.shape, (256, 256, 3))
|
||||
self.assertCountEqual(labels_keys, labels.keys())
|
||||
np_labels = tf.nest.map_structure(lambda x: x.numpy(), labels)
|
||||
# Checks shapes of `image_info` and `anchor_boxes`.
|
||||
self.assertEqual(np_labels['image_info'].shape, (4, 2))
|
||||
n_anchors = 0
|
||||
for level in range(self.MIN_LEVEL, self.MAX_LEVEL + 1):
|
||||
stride = 2**level
|
||||
output_size_l = [self.OUTPUT_SIZE / stride, self.OUTPUT_SIZE / stride]
|
||||
anchors_per_location = self.NUM_SCALES * len(self.ASPECT_RATIOS)
|
||||
self.assertEqual(
|
||||
list(np_labels['anchor_boxes'][str(level)].shape),
|
||||
[output_size_l[0], output_size_l[1], 4 * anchors_per_location],
|
||||
)
|
||||
n_anchors += output_size_l[0] * output_size_l[1] * anchors_per_location
|
||||
# Checks shapes of training objectives.
|
||||
self.assertEqual(np_labels['cls_weights'].shape, (int(n_anchors),))
|
||||
for level in range(self.MIN_LEVEL, self.MAX_LEVEL + 1):
|
||||
stride = 2**level
|
||||
output_size_l = [self.OUTPUT_SIZE / stride, self.OUTPUT_SIZE / stride]
|
||||
anchors_per_location = self.NUM_SCALES * len(self.ASPECT_RATIOS)
|
||||
self.assertEqual(
|
||||
list(np_labels['cls_targets'][str(level)].shape),
|
||||
[output_size_l[0], output_size_l[1], anchors_per_location],
|
||||
)
|
||||
self.assertEqual(
|
||||
list(np_labels['box_targets'][str(level)].shape),
|
||||
[output_size_l[0], output_size_l[1], 4 * anchors_per_location],
|
||||
)
|
||||
# Checks shape of groundtruths for eval.
|
||||
if not is_training:
|
||||
self.assertEqual(np_labels['groundtruths']['source_id'].shape, ())
|
||||
self.assertEqual(
|
||||
np_labels['groundtruths']['classes'].shape,
|
||||
(self.MAX_NUM_INSTANCES,),
|
||||
)
|
||||
self.assertEqual(
|
||||
np_labels['groundtruths']['boxes'].shape,
|
||||
(self.MAX_NUM_INSTANCES, 4),
|
||||
)
|
||||
self.assertEqual(
|
||||
np_labels['groundtruths']['areas'].shape, (self.MAX_NUM_INSTANCES,)
|
||||
)
|
||||
self.assertEqual(
|
||||
np_labels['groundtruths']['is_crowds'].shape,
|
||||
(self.MAX_NUM_INSTANCES,),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
BIN  mediapipe/model_maker/python/vision/object_detector/testdata/coco_data/images/000000000072.jpg  (new binary file, 81 KiB)
BIN  mediapipe/model_maker/python/vision/object_detector/testdata/coco_data/images/000000000078.jpg  (new binary file, 72 KiB)
BIN  mediapipe/model_maker/python/vision/object_detector/testdata/coco_data/images/000000000315.jpg  (new binary file, 53 KiB)
BIN  mediapipe/model_maker/python/vision/object_detector/testdata/coco_data/images/000000000431.jpg  (new binary file, 42 KiB)
BIN  mediapipe/model_maker/python/vision/object_detector/testdata/coco_data/images/000000000446.jpg  (new binary file, 36 KiB)
1  mediapipe/model_maker/python/vision/object_detector/testdata/coco_data/labels.json  (new file)
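The coco_data testdata folder added above (an images/ directory plus labels.json) follows the COCO-style layout that the object detector's Dataset loader consumes. A minimal usage sketch, assuming the released mediapipe_model_maker package and its documented from_coco_folder constructor; the paths and cache directory are placeholders, not part of this change.

# Minimal sketch: loading a COCO-style folder (images/ + labels.json).
# Assumes the published mediapipe_model_maker API; paths are placeholders.
from mediapipe_model_maker import object_detector

train_data = object_detector.Dataset.from_coco_folder(
    'path/to/coco_data',             # folder containing images/ and labels.json
    cache_dir='/tmp/od_data/train',  # caches preprocessed records between runs
)
print('train_data size:', train_data.size)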
@ -0,0 +1,34 @@
<?xml version="1.0" encoding="utf-8"?>
<annotation>
  <folder>images</folder>
  <filename>37ca2a3d-IMG_0520.jpg</filename>
  <source>
    <database>MyDatabase</database>
    <annotation>COCO2017</annotation>
    <image>flickr</image>
    <flickrid>NULL</flickrid>
    <annotator>1</annotator>
  </source>
  <owner>
    <flickrid>NULL</flickrid>
    <name>Label Studio</name>
  </owner>
  <size>
    <width>800</width>
    <height>600</height>
    <depth>3</depth>
  </size>
  <segmented>0</segmented>
  <object>
    <name>pig_android</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
      <xmin>242</xmin>
      <ymin>17</ymin>
      <xmax>556</xmax>
      <ymax>476</ymax>
    </bndbox>
  </object>
</annotation>
@ -0,0 +1,34 @@
<?xml version="1.0" encoding="utf-8"?>
<annotation>
  <folder>images</folder>
  <filename>3d3382d3-IMG_0514.jpg</filename>
  <source>
    <database>MyDatabase</database>
    <annotation>COCO2017</annotation>
    <image>flickr</image>
    <flickrid>NULL</flickrid>
    <annotator>1</annotator>
  </source>
  <owner>
    <flickrid>NULL</flickrid>
    <name>Label Studio</name>
  </owner>
  <size>
    <width>800</width>
    <height>600</height>
    <depth>3</depth>
  </size>
  <segmented>0</segmented>
  <object>
    <name>android</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
      <xmin>306</xmin>
      <ymin>130</ymin>
      <xmax>550</xmax>
      <ymax>471</ymax>
    </bndbox>
  </object>
</annotation>
@ -0,0 +1,46 @@
<?xml version="1.0" encoding="utf-8"?>
<annotation>
  <folder>images</folder>
  <filename>d1c65813-IMG_0546.jpg</filename>
  <source>
    <database>MyDatabase</database>
    <annotation>COCO2017</annotation>
    <image>flickr</image>
    <flickrid>NULL</flickrid>
    <annotator>1</annotator>
  </source>
  <owner>
    <flickrid>NULL</flickrid>
    <name>Label Studio</name>
  </owner>
  <size>
    <width>800</width>
    <height>600</height>
    <depth>3</depth>
  </size>
  <segmented>0</segmented>
  <object>
    <name>android</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
      <xmin>93</xmin>
      <ymin>101</ymin>
      <xmax>358</xmax>
      <ymax>378</ymax>
    </bndbox>
  </object>
  <object>
    <name>pig_android</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
      <xmin>438</xmin>
      <ymin>28</ymin>
      <xmax>654</xmax>
      <ymax>296</ymax>
    </bndbox>
  </object>
</annotation>
@ -0,0 +1,46 @@
<?xml version="1.0" encoding="utf-8"?>
<annotation>
  <folder>images</folder>
  <filename>d86b20e0-IMG_0509.jpg</filename>
  <source>
    <database>MyDatabase</database>
    <annotation>COCO2017</annotation>
    <image>flickr</image>
    <flickrid>NULL</flickrid>
    <annotator>1</annotator>
  </source>
  <owner>
    <flickrid>NULL</flickrid>
    <name>Label Studio</name>
  </owner>
  <size>
    <width>800</width>
    <height>600</height>
    <depth>3</depth>
  </size>
  <segmented>0</segmented>
  <object>
    <name>android</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
      <xmin>7</xmin>
      <ymin>122</ymin>
      <xmax>296</xmax>
      <ymax>402</ymax>
    </bndbox>
  </object>
  <object>
    <name>pig_android</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
      <xmin>523</xmin>
      <ymin>69</ymin>
      <xmax>723</xmax>
      <ymax>329</ymax>
    </bndbox>
  </object>
</annotation>
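The hunks above are PASCAL VOC XML annotations: one <object> element per box, with pixel-space corner coordinates measured against the <size> element. A minimal, standard-library sketch of turning one such file into the raw-example fields the preprocessor test feeds through its generator; the helper name, label ids, and the normalized [ymin, xmin, ymax, xmax] box ordering are assumptions here, not part of this diff.

# Illustrative VOC parser (not part of this change); maps one annotation file
# onto the 'groundtruth_*' fields used by the preprocessor test's raw examples.
import xml.etree.ElementTree as ET

def parse_voc_annotation(xml_path, label_to_id):
  root = ET.parse(xml_path).getroot()
  width = float(root.find('size/width').text)
  height = float(root.find('size/height').text)
  classes, boxes = [], []
  for obj in root.findall('object'):
    classes.append(label_to_id[obj.find('name').text])
    box = obj.find('bndbox')
    ymin = float(box.find('ymin').text) / height
    xmin = float(box.find('xmin').text) / width
    ymax = float(box.find('ymax').text) / height
    xmax = float(box.find('xmax').text) / width
    boxes.append([ymin, xmin, ymax, xmax])  # normalized; ordering is an assumption
  return {
      'groundtruth_classes': classes,
      'groundtruth_boxes': boxes,
      'height': int(height),
      'width': int(width),
  }

# e.g. parse_voc_annotation('d86b20e0-IMG_0509.xml', {'android': 1, 'pig_android': 2})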
BIN  4 additional testdata images  (new binary files: 54 KiB, 73 KiB, 110 KiB, 120 KiB)