Internal change

PiperOrigin-RevId: 547404737

parent 917af2ce6b
commit 3e93cbc838
@@ -57,3 +57,14 @@ py_test(
     srcs = ["classification_dataset_test.py"],
     deps = [":classification_dataset"],
 )
+
+py_library(
+    name = "cache_files",
+    srcs = ["cache_files.py"],
+)
+
+py_test(
+    name = "cache_files_test",
+    srcs = ["cache_files_test.py"],
+    deps = [":cache_files"],
+)
mediapipe/model_maker/python/core/data/cache_files.py (new file, 112 lines)
@@ -0,0 +1,112 @@
+# Copyright 2023 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Common TFRecord cache files library."""
+
+import dataclasses
+import os
+import tempfile
+from typing import Any, Mapping, Sequence
+
+import tensorflow as tf
+import yaml
+
+
+# Suffix of the metadata file name.
+METADATA_FILE_SUFFIX = '_metadata.yaml'
+
+
+@dataclasses.dataclass(frozen=True)
+class TFRecordCacheFiles:
+  """TFRecordCacheFiles dataclass to store and load cached TFRecord files.
+
+  Attributes:
+    cache_prefix_filename: The cache prefix filename. This is usually provided
+      as a hash of the original data source to avoid different data sources
+      resulting in the same cache file.
+    cache_dir: The cache directory to save the TFRecord and metadata files.
+      When cache_dir is None, a temporary folder is created that is not
+      removed automatically after training, so the cache can be reused later.
+    num_shards: Number of shards for the output TFRecord files.
+  """
+
+  cache_prefix_filename: str = 'cache_prefix'
+  cache_dir: str = dataclasses.field(default_factory=tempfile.mkdtemp)
+  num_shards: int = 1
+
+  def __post_init__(self):
+    if not self.cache_prefix_filename:
+      raise ValueError('cache_prefix_filename cannot be empty.')
+    if self.num_shards <= 0:
+      raise ValueError(
+          f'num_shards must be greater than 0, got {self.num_shards}'
+      )
+
+  @property
+  def cache_prefix(self) -> str:
+    """The cache prefix including the cache directory and the cache prefix filename."""
+    return os.path.join(self.cache_dir, self.cache_prefix_filename)
+
+  @property
+  def tfrecord_files(self) -> Sequence[str]:
+    """The TFRecord files."""
+    tfrecord_files = [
+        self.cache_prefix + '-%05d-of-%05d.tfrecord' % (i, self.num_shards)
+        for i in range(self.num_shards)
+    ]
+    return tfrecord_files
+
+  @property
+  def metadata_file(self) -> str:
+    """The metadata file."""
+    return self.cache_prefix + METADATA_FILE_SUFFIX
+
+  def get_writers(self) -> Sequence[tf.io.TFRecordWriter]:
+    """Gets a list of TFRecordWriter objects.
+
+    Note that these writers should each be closed using .close() when done.
+
+    Returns:
+      List of TFRecordWriter objects.
+    """
+    if not tf.io.gfile.exists(self.cache_dir):
+      tf.io.gfile.makedirs(self.cache_dir)
+    return [tf.io.TFRecordWriter(path) for path in self.tfrecord_files]
+
+  def save_metadata(self, metadata):
+    """Writes metadata to file.
+
+    Args:
+      metadata: A dictionary of metadata content to write. The exact format
+        depends on the specific dataset, but it typically includes 'size' and
+        'label_names' entries.
+    """
+    with tf.io.gfile.GFile(self.metadata_file, 'w') as f:
+      yaml.dump(metadata, f)
+
+  def load_metadata(self) -> Mapping[Any, Any]:
+    """Reads metadata from file.
+
+    Returns:
+      Dictionary object containing metadata.
+    """
+    if not tf.io.gfile.exists(self.metadata_file):
+      return {}
+    with tf.io.gfile.GFile(self.metadata_file, 'r') as f:
+      metadata = yaml.load(f, Loader=yaml.FullLoader)
+    return metadata
+
+  def is_cached(self) -> bool:
+    """Checks whether this CacheFiles is already cached."""
+    all_cached_files = list(self.tfrecord_files) + [self.metadata_file]
+    return all(tf.io.gfile.exists(f) for f in all_cached_files)
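Note: for orientation, here is a minimal usage sketch of the TFRecordCacheFiles class added above. The cache directory, prefix, and example payload are illustrative, not part of this commit.

# Illustrative sketch: write two shards, then save and reload metadata.
import tensorflow as tf

from mediapipe.model_maker.python.core.data import cache_files

cf = cache_files.TFRecordCacheFiles(
    cache_prefix_filename='train',  # hypothetical prefix
    cache_dir='/tmp/demo_cache',  # hypothetical directory
    num_shards=2,
)
if not cf.is_cached():
  # get_writers() creates cache_dir if needed and returns one writer per shard.
  for writer in cf.get_writers():
    writer.write(tf.train.Example().SerializeToString())
    writer.close()
  cf.save_metadata({'size': 2, 'label_names': ['cat', 'dog']})
print(cf.load_metadata())  # e.g. {'label_names': ['cat', 'dog'], 'size': 2}

is_cached() only returns True once every shard and the metadata file exist, which is why save_metadata is the final step of the write path.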
mediapipe/model_maker/python/core/data/cache_files_test.py (new file, 77 lines)
@@ -0,0 +1,77 @@
+# Copyright 2023 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+
+from mediapipe.model_maker.python.core.data import cache_files
+
+
+class CacheFilesTest(tf.test.TestCase):
+
+  def test_tfrecord_cache_files(self):
+    cf = cache_files.TFRecordCacheFiles(
+        cache_prefix_filename='tfrecord',
+        cache_dir='/tmp/cache_dir',
+        num_shards=2,
+    )
+    self.assertEqual(cf.cache_prefix, '/tmp/cache_dir/tfrecord')
+    self.assertEqual(
+        cf.metadata_file,
+        '/tmp/cache_dir/tfrecord' + cache_files.METADATA_FILE_SUFFIX,
+    )
+    expected_tfrecord_files = [
+        '/tmp/cache_dir/tfrecord-%05d-of-%05d.tfrecord' % (i, 2)
+        for i in range(2)
+    ]
+    self.assertEqual(cf.tfrecord_files, expected_tfrecord_files)
+
+    # Writing TFRecord Files
+    self.assertFalse(cf.is_cached())
+    for tfrecord_file in cf.tfrecord_files:
+      self.assertFalse(tf.io.gfile.exists(tfrecord_file))
+    writers = cf.get_writers()
+    for writer in writers:
+      writer.close()
+    for tfrecord_file in cf.tfrecord_files:
+      self.assertTrue(tf.io.gfile.exists(tfrecord_file))
+    self.assertFalse(cf.is_cached())
+
+    # Writing Metadata Files
+    original_metadata = {'size': 10, 'label_names': ['label1', 'label2']}
+    cf.save_metadata(original_metadata)
+    self.assertTrue(cf.is_cached())
+    metadata = cf.load_metadata()
+    self.assertEqual(metadata, original_metadata)
+
+  def test_recordio_cache_files_error(self):
+    with self.assertRaisesRegex(
+        ValueError, 'cache_prefix_filename cannot be empty'
+    ):
+      cache_files.TFRecordCacheFiles(
+          cache_prefix_filename='',
+          cache_dir='/tmp/cache_dir',
+          num_shards=2,
+      )
+    with self.assertRaisesRegex(
+        ValueError, 'num_shards must be greater than 0, got 0'
+    ):
+      cache_files.TFRecordCacheFiles(
+          cache_prefix_filename='tfrecord',
+          cache_dir='/tmp/cache_dir',
+          num_shards=0,
+      )
+
+
+if __name__ == '__main__':
+  tf.test.main()
@@ -54,6 +54,7 @@ py_library(
     srcs = ["dataset.py"],
     deps = [
         ":dataset_util",
+        "//mediapipe/model_maker/python/core/data:cache_files",
         "//mediapipe/model_maker/python/core/data:classification_dataset",
     ],
 )
@@ -73,6 +74,7 @@ py_test(
 py_library(
     name = "dataset_util",
     srcs = ["dataset_util.py"],
+    deps = ["//mediapipe/model_maker/python/core/data:cache_files"],
 )

 py_test(
@@ -16,8 +16,8 @@
 from typing import Optional

 import tensorflow as tf
-import yaml

+from mediapipe.model_maker.python.core.data import cache_files
 from mediapipe.model_maker.python.core.data import classification_dataset
 from mediapipe.model_maker.python.vision.object_detector import dataset_util
 from official.vision.dataloaders import tf_example_decoder
@@ -76,14 +76,16 @@ class Dataset(classification_dataset.ClassificationDataset):
       ValueError: If the label_name for id 0 is set to something other than
         the 'background' class.
     """
-    cache_files = dataset_util.get_cache_files_coco(data_dir, cache_dir)
-    if not dataset_util.is_cached(cache_files):
+    tfrecord_cache_files = dataset_util.get_cache_files_coco(
+        data_dir, cache_dir
+    )
+    if not tfrecord_cache_files.is_cached():
       label_map = dataset_util.get_label_map_coco(data_dir)
       cache_writer = dataset_util.COCOCacheFilesWriter(
           label_map=label_map, max_num_images=max_num_images
       )
-      cache_writer.write_files(cache_files, data_dir)
-    return cls.from_cache(cache_files.cache_prefix)
+      cache_writer.write_files(tfrecord_cache_files, data_dir)
+    return cls.from_cache(tfrecord_cache_files)

   @classmethod
   def from_pascal_voc_folder(
@@ -134,47 +136,48 @@ class Dataset(classification_dataset.ClassificationDataset):
     Raises:
       ValueError: if the input data directory is empty.
     """
-    cache_files = dataset_util.get_cache_files_pascal_voc(data_dir, cache_dir)
-    if not dataset_util.is_cached(cache_files):
+    tfrecord_cache_files = dataset_util.get_cache_files_pascal_voc(
+        data_dir, cache_dir
+    )
+    if not tfrecord_cache_files.is_cached():
       label_map = dataset_util.get_label_map_pascal_voc(data_dir)
       cache_writer = dataset_util.PascalVocCacheFilesWriter(
           label_map=label_map, max_num_images=max_num_images
       )
-      cache_writer.write_files(cache_files, data_dir)
+      cache_writer.write_files(tfrecord_cache_files, data_dir)

-    return cls.from_cache(cache_files.cache_prefix)
+    return cls.from_cache(tfrecord_cache_files)

   @classmethod
-  def from_cache(cls, cache_prefix: str) -> 'Dataset':
+  def from_cache(
+      cls, tfrecord_cache_files: cache_files.TFRecordCacheFiles
+  ) -> 'Dataset':
     """Loads the TFRecord data from cache.

     Args:
-      cache_prefix: The cache prefix including the cache directory and the
-        cache prefix filename, e.g: '/tmp/cache/train'.
+      tfrecord_cache_files: The TFRecordCacheFiles object containing the
+        already cached TFRecord and metadata files.

     Returns:
       ObjectDetectorDataset object.
+
+    Raises:
+      ValueError: if tfrecord_cache_files are not already cached.
     """
-    # Get TFRecord Files
-    tfrecord_file_pattern = cache_prefix + '*.tfrecord'
-    matched_files = tf.io.gfile.glob(tfrecord_file_pattern)
-    if not matched_files:
-      raise ValueError('TFRecord files are empty.')
-
-    # Load meta_data.
-    meta_data_file = cache_prefix + dataset_util.META_DATA_FILE_SUFFIX
-    if not tf.io.gfile.exists(meta_data_file):
-      raise ValueError("Metadata file %s doesn't exist." % meta_data_file)
-    with tf.io.gfile.GFile(meta_data_file, 'r') as f:
-      meta_data = yaml.load(f, Loader=yaml.FullLoader)
+    if not tfrecord_cache_files.is_cached():
+      raise ValueError(
+          'Cache files must be already cached to use the from_cache method.'
+      )
+
+    metadata = tfrecord_cache_files.load_metadata()

-    dataset = tf.data.TFRecordDataset(matched_files)
+    dataset = tf.data.TFRecordDataset(tfrecord_cache_files.tfrecord_files)
     decoder = tf_example_decoder.TfExampleDecoder(regenerate_source_id=False)
     dataset = dataset.map(decoder.decode, num_parallel_calls=tf.data.AUTOTUNE)

-    label_map = meta_data['label_map']
+    label_map = metadata['label_map']
     label_names = [label_map[k] for k in sorted(label_map.keys())]

     return Dataset(
-        dataset=dataset, label_names=label_names, size=meta_data['size']
+        dataset=dataset, label_names=label_names, size=metadata['size']
     )
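Note: the from_cache change above is the main call-site migration point in this commit: callers now pass the whole TFRecordCacheFiles object instead of a string prefix. A hedged before/after sketch (paths and shard count illustrative):

from mediapipe.model_maker.python.core.data import cache_files

# Before this commit (string prefix):
#   dataset = Dataset.from_cache('/tmp/cache/train')
# After this commit (cache-files object):
tfrecord_cache_files = cache_files.TFRecordCacheFiles(
    cache_prefix_filename='train',  # hypothetical
    cache_dir='/tmp/cache',  # hypothetical
    num_shards=10,
)
dataset = Dataset.from_cache(tfrecord_cache_files)
# Raises ValueError unless all shards and the metadata file already exist.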
@@ -15,25 +15,20 @@

 import abc
 import collections
-import dataclasses
 import hashlib
 import json
 import math
 import os
 import tempfile
-from typing import Any, Dict, List, Mapping, Optional, Sequence
+from typing import Any, Dict, List, Mapping, Optional
 import xml.etree.ElementTree as ET

 import tensorflow as tf
-import yaml

+from mediapipe.model_maker.python.core.data import cache_files
 from official.vision.data import tfrecord_lib


-# Suffix of the meta data file name.
-META_DATA_FILE_SUFFIX = '_meta_data.yaml'
-
-
 def _xml_get(node: ET.Element, name: str) -> ET.Element:
   """Gets a named child from an XML Element node.
@@ -71,18 +66,9 @@ def _get_dir_basename(data_dir: str) -> str:
   return os.path.basename(os.path.abspath(data_dir))


-@dataclasses.dataclass(frozen=True)
-class CacheFiles:
-  """Cache files for object detection."""
-
-  cache_prefix: str
-  tfrecord_files: Sequence[str]
-  meta_data_file: str
-
-
 def _get_cache_files(
     cache_dir: Optional[str], cache_prefix_filename: str, num_shards: int = 10
-) -> CacheFiles:
+) -> cache_files.TFRecordCacheFiles:
   """Creates an object of CacheFiles class.

   Args:
@@ -96,28 +82,16 @@ def _get_cache_files(
     An object of CacheFiles class.
   """
   cache_dir = _get_cache_dir_or_create(cache_dir)
-  # The cache prefix including the cache directory and the cache prefix
-  # filename, e.g: '/tmp/cache/train'.
-  cache_prefix = os.path.join(cache_dir, cache_prefix_filename)
-  tf.compat.v1.logging.info(
-      'Cache will be stored in %s with prefix filename %s. Cache_prefix is %s'
-      % (cache_dir, cache_prefix_filename, cache_prefix)
-  )
-
-  # Cached files including the TFRecord files and the meta data file.
-  tfrecord_files = [
-      cache_prefix + '-%05d-of-%05d.tfrecord' % (i, num_shards)
-      for i in range(num_shards)
-  ]
-  meta_data_file = cache_prefix + META_DATA_FILE_SUFFIX
-  return CacheFiles(
-      cache_prefix=cache_prefix,
-      tfrecord_files=tuple(tfrecord_files),
-      meta_data_file=meta_data_file,
+  return cache_files.TFRecordCacheFiles(
+      cache_prefix_filename=cache_prefix_filename,
+      cache_dir=cache_dir,
+      num_shards=num_shards,
   )


-def get_cache_files_coco(data_dir: str, cache_dir: str) -> CacheFiles:
+def get_cache_files_coco(
+    data_dir: str, cache_dir: str
+) -> cache_files.TFRecordCacheFiles:
   """Creates an object of CacheFiles class using a COCO formatted dataset.

   Args:
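Note: _get_cache_files now simply forwards its arguments to TFRecordCacheFiles. Per the new class docstring, cache_prefix_filename is usually a hash of the data source; the helper below is a hypothetical illustration of that idea, not the hashing actually performed by get_cache_files_coco or get_cache_files_pascal_voc.

import hashlib
import json

def make_cache_prefix_filename(data_dir: str, num_shards: int) -> str:
  # Hypothetical: hash the identifying inputs so that different data
  # sources never collide on the same cache files.
  payload = json.dumps(
      {'data_dir': data_dir, 'num_shards': num_shards}, sort_keys=True
  )
  return hashlib.md5(payload.encode('utf-8')).hexdigest()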
@@ -152,7 +126,9 @@ def get_cache_files_coco(data_dir: str, cache_dir: str) -> CacheFiles:
   return _get_cache_files(cache_dir, cache_prefix_filename, num_shards)


-def get_cache_files_pascal_voc(data_dir: str, cache_dir: str) -> CacheFiles:
+def get_cache_files_pascal_voc(
+    data_dir: str, cache_dir: str
+) -> cache_files.TFRecordCacheFiles:
   """Gets an object of CacheFiles using a PASCAL VOC formatted dataset.

   Args:
@@ -181,14 +157,6 @@ def get_cache_files_pascal_voc(data_dir: str, cache_dir: str) -> CacheFiles:
   return _get_cache_files(cache_dir, cache_prefix_filename, num_shards)


-def is_cached(cache_files: CacheFiles) -> bool:
-  """Checks whether cache files are already cached."""
-  all_cached_files = list(cache_files.tfrecord_files) + [
-      cache_files.meta_data_file
-  ]
-  return all(tf.io.gfile.exists(path) for path in all_cached_files)
-
-
 class CacheFilesWriter(abc.ABC):
   """CacheFilesWriter class to write the cached files."""

@@ -208,19 +176,22 @@ class CacheFilesWriter(abc.ABC):
     self.label_map = label_map
     self.max_num_images = max_num_images

-  def write_files(self, cache_files: CacheFiles, *args, **kwargs) -> None:
-    """Writes TFRecord and meta_data files.
+  def write_files(
+      self,
+      tfrecord_cache_files: cache_files.TFRecordCacheFiles,
+      *args,
+      **kwargs,
+  ) -> None:
+    """Writes TFRecord and metadata files.

     Args:
-      cache_files: CacheFiles object including a list of TFRecord files and
-        the meta data yaml file to save the meta_data including data size and
-        label_map.
+      tfrecord_cache_files: TFRecordCacheFiles object including a list of
+        TFRecord files and the metadata yaml file to save the metadata
+        including data size and label_map.
       *args: Non-keyword parameters used in the `_get_example` method.
       **kwargs: Keyword parameters used in the `_get_example` method.
     """
-    writers = [
-        tf.io.TFRecordWriter(path) for path in cache_files.tfrecord_files
-    ]
+    writers = tfrecord_cache_files.get_writers()

     # Writes tf.Example into TFRecord files.
     size = 0
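Note: with write_files delegating writer creation to get_writers(), the caller-side flow mirrors the dataset.py change earlier in this commit. A sketch under those assumptions (paths are illustrative; max_num_images is left at its default, which this diff does not show):

from mediapipe.model_maker.python.vision.object_detector import dataset_util

tfrecord_cache_files = dataset_util.get_cache_files_coco(
    '/path/to/coco_data', cache_dir='/tmp/cache'  # hypothetical paths
)
if not tfrecord_cache_files.is_cached():
  label_map = dataset_util.get_label_map_coco('/path/to/coco_data')
  cache_writer = dataset_util.COCOCacheFilesWriter(label_map=label_map)
  # Writer creation and metadata handling now live in TFRecordCacheFiles.
  cache_writer.write_files(tfrecord_cache_files, '/path/to/coco_data')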
@@ -235,10 +206,9 @@ class CacheFilesWriter(abc.ABC):
     for writer in writers:
       writer.close()

-    # Writes meta_data into meta_data_file.
-    meta_data = {'size': size, 'label_map': self.label_map}
-    with tf.io.gfile.GFile(cache_files.meta_data_file, 'w') as f:
-      yaml.dump(meta_data, f)
+    # Writes metadata into metadata_file.
+    metadata = {'size': size, 'label_map': self.label_map}
+    tfrecord_cache_files.save_metadata(metadata)

   @abc.abstractmethod
   def _get_example(self, *args, **kwargs):
@@ -19,7 +19,6 @@ import shutil
 from unittest import mock as unittest_mock

 import tensorflow as tf
-import yaml

 from mediapipe.model_maker.python.vision.core import test_utils
 from mediapipe.model_maker.python.vision.object_detector import dataset_util
@@ -30,13 +29,10 @@ class DatasetUtilTest(tf.test.TestCase):

   def _assert_cache_files_equal(self, cf1, cf2):
     self.assertEqual(cf1.cache_prefix, cf2.cache_prefix)
-    self.assertCountEqual(cf1.tfrecord_files, cf2.tfrecord_files)
-    self.assertEqual(cf1.meta_data_file, cf2.meta_data_file)
+    self.assertEqual(cf1.num_shards, cf2.num_shards)

   def _assert_cache_files_not_equal(self, cf1, cf2):
     self.assertNotEqual(cf1.cache_prefix, cf2.cache_prefix)
-    self.assertNotEqual(cf1.tfrecord_files, cf2.tfrecord_files)
-    self.assertNotEqual(cf1.meta_data_file, cf2.meta_data_file)

   def _get_cache_files_and_assert_neq_fn(self, cache_files_fn):
     def get_cache_files_and_assert_neq(cf, data_dir, cache_dir):
@@ -57,7 +53,7 @@ class DatasetUtilTest(tf.test.TestCase):
     self.assertEqual(
         cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord'
     )
-    self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml')
+    self.assertEqual(cache_files.metadata_file, '/tmp/train_metadata.yaml')

   def test_matching_get_cache_files_coco(self):
     cache_dir = self.create_tempdir()
@@ -118,7 +114,7 @@ class DatasetUtilTest(tf.test.TestCase):
     self.assertEqual(
         cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord'
     )
-    self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml')
+    self.assertEqual(cache_files.metadata_file, '/tmp/train_metadata.yaml')

   def test_matching_get_cache_files_pascal_voc(self):
     cache_dir = self.create_tempdir()
@@ -173,13 +169,13 @@ class DatasetUtilTest(tf.test.TestCase):
     cache_files = dataset_util.get_cache_files_coco(
         tasks_test_utils.get_test_data_path('coco_data'), cache_dir=tempdir
     )
-    self.assertFalse(dataset_util.is_cached(cache_files))
+    self.assertFalse(cache_files.is_cached())
     with open(cache_files.tfrecord_files[0], 'w') as f:
       f.write('test')
-    self.assertFalse(dataset_util.is_cached(cache_files))
-    with open(cache_files.meta_data_file, 'w') as f:
+    self.assertFalse(cache_files.is_cached())
+    with open(cache_files.metadata_file, 'w') as f:
       f.write('test')
-    self.assertTrue(dataset_util.is_cached(cache_files))
+    self.assertTrue(cache_files.is_cached())

   def test_get_label_map_coco(self):
     coco_dir = tasks_test_utils.get_test_data_path('coco_data')
@@ -203,13 +199,11 @@ class DatasetUtilTest(tf.test.TestCase):
     self.assertTrue(os.path.isfile(cache_files.tfrecord_files[0]))
     self.assertGreater(os.path.getsize(cache_files.tfrecord_files[0]), 0)

-    # Checks the meta_data file
-    self.assertTrue(os.path.isfile(cache_files.meta_data_file))
-    self.assertGreater(os.path.getsize(cache_files.meta_data_file), 0)
-    with tf.io.gfile.GFile(cache_files.meta_data_file, 'r') as f:
-      meta_data_dict = yaml.load(f, Loader=yaml.FullLoader)
-      # Size is 3 because some examples are skipped for having poor bboxes
-      self.assertEqual(meta_data_dict['size'], expected_size)
+    # Checks the metadata file
+    self.assertTrue(os.path.isfile(cache_files.metadata_file))
+    self.assertGreater(os.path.getsize(cache_files.metadata_file), 0)
+    metadata_dict = cache_files.load_metadata()
+    self.assertEqual(metadata_dict['size'], expected_size)

   def test_coco_cache_files_writer(self):
     tempdir = self.create_tempdir()