Internal change

PiperOrigin-RevId: 547404737
MediaPipe Team 2023-07-12 00:01:53 -07:00 committed by Copybara-Service
parent 917af2ce6b
commit 3e93cbc838
7 changed files with 270 additions and 101 deletions

mediapipe/model_maker/python/core/data/BUILD View File

@@ -57,3 +57,14 @@ py_test(
     srcs = ["classification_dataset_test.py"],
     deps = [":classification_dataset"],
 )
+
+py_library(
+    name = "cache_files",
+    srcs = ["cache_files.py"],
+)
+
+py_test(
+    name = "cache_files_test",
+    srcs = ["cache_files_test.py"],
+    deps = [":cache_files"],
+)

mediapipe/model_maker/python/core/data/cache_files.py View File

@@ -0,0 +1,112 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common TFRecord cache files library."""
import dataclasses
import os
import tempfile
from typing import Any, Mapping, Sequence
import tensorflow as tf
import yaml
# Suffix of the meta data file name.
METADATA_FILE_SUFFIX = '_metadata.yaml'
@dataclasses.dataclass(frozen=True)
class TFRecordCacheFiles:
"""TFRecordCacheFiles dataclass to store and load cached TFRecord files.
Attributes:
cache_prefix_filename: The cache prefix filename. This is usually provided
as a hash of the original data source to avoid different data sources
resulting in the same cache file.
cache_dir: The cache directory to save TFRecord and metadata file. When
cache_dir is None, a temporary folder will be created and will not be
removed automatically after training which makes it can be used later.
num_shards: Number of shards for output tfrecord files.
"""
cache_prefix_filename: str = 'cache_prefix'
cache_dir: str = dataclasses.field(default_factory=tempfile.mkdtemp)
num_shards: int = 1
def __post_init__(self):
if not self.cache_prefix_filename:
raise ValueError('cache_prefix_filename cannot be empty.')
if self.num_shards <= 0:
raise ValueError(
f'num_shards must be greater than 0, got {self.num_shards}'
)
@property
def cache_prefix(self) -> str:
"""The cache prefix including the cache directory and the cache prefix filename."""
return os.path.join(self.cache_dir, self.cache_prefix_filename)
@property
def tfrecord_files(self) -> Sequence[str]:
"""The TFRecord files."""
tfrecord_files = [
self.cache_prefix + '-%05d-of-%05d.tfrecord' % (i, self.num_shards)
for i in range(self.num_shards)
]
return tfrecord_files
@property
def metadata_file(self) -> str:
"""The metadata file."""
return self.cache_prefix + METADATA_FILE_SUFFIX
def get_writers(self) -> Sequence[tf.io.TFRecordWriter]:
"""Gets an array of TFRecordWriter objects.
Note that these writers should each be closed using .close() when done.
Returns:
Array of TFRecordWriter objects
"""
if not tf.io.gfile.exists(self.cache_dir):
tf.io.gfile.makedirs(self.cache_dir)
return [tf.io.TFRecordWriter(path) for path in self.tfrecord_files]
def save_metadata(self, metadata):
"""Writes metadata to file.
Args:
metadata: A dictionary of metadata content to write. Exact format is
dependent on the specific dataset, but typically includes a 'size' and
'label_names' entry.
"""
with tf.io.gfile.GFile(self.metadata_file, 'w') as f:
yaml.dump(metadata, f)
def load_metadata(self) -> Mapping[Any, Any]:
"""Reads metadata from file.
Returns:
Dictionary object containing metadata
"""
if not tf.io.gfile.exists(self.metadata_file):
return {}
with tf.io.gfile.GFile(self.metadata_file, 'r') as f:
metadata = yaml.load(f, Loader=yaml.FullLoader)
return metadata
def is_cached(self) -> bool:
"""Checks whether this CacheFiles is already cached."""
all_cached_files = list(self.tfrecord_files) + [self.metadata_file]
return all(tf.io.gfile.exists(f) for f in all_cached_files)
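
To make the new API concrete, here is a minimal usage sketch of TFRecordCacheFiles (not part of the commit; the cache directory and metadata values are illustrative placeholders):

from mediapipe.model_maker.python.core.data import cache_files

# Describe the cache layout: two TFRecord shards under an explicit directory.
cf = cache_files.TFRecordCacheFiles(
    cache_prefix_filename='train',
    cache_dir='/tmp/od_cache',  # placeholder path
    num_shards=2,
)

if not cf.is_cached():
  # get_writers() creates the cache directory if it does not exist yet.
  writers = cf.get_writers()
  # ... write serialized tf.train.Example protos via writers[i].write(...) ...
  for writer in writers:
    writer.close()
  # The cache only counts as complete once the metadata file exists too.
  cf.save_metadata({'size': 0, 'label_names': []})

print(cf.tfrecord_files)  # ['/tmp/od_cache/train-00000-of-00002.tfrecord', ...]
print(cf.load_metadata())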

mediapipe/model_maker/python/core/data/cache_files_test.py View File

@@ -0,0 +1,77 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf

from mediapipe.model_maker.python.core.data import cache_files


class CacheFilesTest(tf.test.TestCase):

  def test_tfrecord_cache_files(self):
    cf = cache_files.TFRecordCacheFiles(
        cache_prefix_filename='tfrecord',
        cache_dir='/tmp/cache_dir',
        num_shards=2,
    )
    self.assertEqual(cf.cache_prefix, '/tmp/cache_dir/tfrecord')
    self.assertEqual(
        cf.metadata_file,
        '/tmp/cache_dir/tfrecord' + cache_files.METADATA_FILE_SUFFIX,
    )
    expected_tfrecord_files = [
        '/tmp/cache_dir/tfrecord-%05d-of-%05d.tfrecord' % (i, 2)
        for i in range(2)
    ]
    self.assertEqual(cf.tfrecord_files, expected_tfrecord_files)

    # Writing TFRecord files.
    self.assertFalse(cf.is_cached())
    for tfrecord_file in cf.tfrecord_files:
      self.assertFalse(tf.io.gfile.exists(tfrecord_file))
    writers = cf.get_writers()
    for writer in writers:
      writer.close()
    for tfrecord_file in cf.tfrecord_files:
      self.assertTrue(tf.io.gfile.exists(tfrecord_file))
    self.assertFalse(cf.is_cached())

    # Writing the metadata file.
    original_metadata = {'size': 10, 'label_names': ['label1', 'label2']}
    cf.save_metadata(original_metadata)
    self.assertTrue(cf.is_cached())
    metadata = cf.load_metadata()
    self.assertEqual(metadata, original_metadata)

  def test_tfrecord_cache_files_error(self):
    with self.assertRaisesRegex(
        ValueError, 'cache_prefix_filename cannot be empty'
    ):
      cache_files.TFRecordCacheFiles(
          cache_prefix_filename='',
          cache_dir='/tmp/cache_dir',
          num_shards=2,
      )
    with self.assertRaisesRegex(
        ValueError, 'num_shards must be greater than 0, got 0'
    ):
      cache_files.TFRecordCacheFiles(
          cache_prefix_filename='tfrecord',
          cache_dir='/tmp/cache_dir',
          num_shards=0,
      )


if __name__ == '__main__':
  tf.test.main()

mediapipe/model_maker/python/vision/object_detector/BUILD View File

@@ -54,6 +54,7 @@ py_library(
     srcs = ["dataset.py"],
     deps = [
         ":dataset_util",
+        "//mediapipe/model_maker/python/core/data:cache_files",
         "//mediapipe/model_maker/python/core/data:classification_dataset",
     ],
 )
@@ -73,6 +74,7 @@ py_test(
 py_library(
     name = "dataset_util",
     srcs = ["dataset_util.py"],
+    deps = ["//mediapipe/model_maker/python/core/data:cache_files"],
 )
 
 py_test(

mediapipe/model_maker/python/vision/object_detector/dataset.py View File

@@ -16,8 +16,8 @@
 from typing import Optional
 
 import tensorflow as tf
-import yaml
 
+from mediapipe.model_maker.python.core.data import cache_files
 from mediapipe.model_maker.python.core.data import classification_dataset
 from mediapipe.model_maker.python.vision.object_detector import dataset_util
 from official.vision.dataloaders import tf_example_decoder
@@ -76,14 +76,16 @@ class Dataset(classification_dataset.ClassificationDataset):
       ValueError: If the label_name for id 0 is set to something other than
         the 'background' class.
     """
-    cache_files = dataset_util.get_cache_files_coco(data_dir, cache_dir)
-    if not dataset_util.is_cached(cache_files):
+    tfrecord_cache_files = dataset_util.get_cache_files_coco(
+        data_dir, cache_dir
+    )
+    if not tfrecord_cache_files.is_cached():
       label_map = dataset_util.get_label_map_coco(data_dir)
       cache_writer = dataset_util.COCOCacheFilesWriter(
           label_map=label_map, max_num_images=max_num_images
       )
-      cache_writer.write_files(cache_files, data_dir)
-    return cls.from_cache(cache_files.cache_prefix)
+      cache_writer.write_files(tfrecord_cache_files, data_dir)
+    return cls.from_cache(tfrecord_cache_files)
 
   @classmethod
   def from_pascal_voc_folder(
@@ -134,47 +136,48 @@ class Dataset(classification_dataset.ClassificationDataset):
     Raises:
       ValueError: if the input data directory is empty.
     """
-    cache_files = dataset_util.get_cache_files_pascal_voc(data_dir, cache_dir)
-    if not dataset_util.is_cached(cache_files):
+    tfrecord_cache_files = dataset_util.get_cache_files_pascal_voc(
+        data_dir, cache_dir
+    )
+    if not tfrecord_cache_files.is_cached():
       label_map = dataset_util.get_label_map_pascal_voc(data_dir)
       cache_writer = dataset_util.PascalVocCacheFilesWriter(
           label_map=label_map, max_num_images=max_num_images
       )
-      cache_writer.write_files(cache_files, data_dir)
-    return cls.from_cache(cache_files.cache_prefix)
+      cache_writer.write_files(tfrecord_cache_files, data_dir)
+    return cls.from_cache(tfrecord_cache_files)
 
   @classmethod
-  def from_cache(cls, cache_prefix: str) -> 'Dataset':
+  def from_cache(
+      cls, tfrecord_cache_files: cache_files.TFRecordCacheFiles
+  ) -> 'Dataset':
     """Loads the TFRecord data from cache.
 
     Args:
-      cache_prefix: The cache prefix including the cache directory and the
-        cache prefix filename, e.g: '/tmp/cache/train'.
+      tfrecord_cache_files: The TFRecordCacheFiles object containing the
+        already cached TFRecord and metadata files.
 
     Returns:
       ObjectDetectorDataset object.
+
+    Raises:
+      ValueError: If tfrecord_cache_files is not already cached.
     """
-    # Get TFRecord Files
-    tfrecord_file_pattern = cache_prefix + '*.tfrecord'
-    matched_files = tf.io.gfile.glob(tfrecord_file_pattern)
-    if not matched_files:
-      raise ValueError('TFRecord files are empty.')
+    if not tfrecord_cache_files.is_cached():
+      raise ValueError(
+          'Cache files must be already cached to use the from_cache method.'
+      )
 
-    # Load meta_data.
-    meta_data_file = cache_prefix + dataset_util.META_DATA_FILE_SUFFIX
-    if not tf.io.gfile.exists(meta_data_file):
-      raise ValueError("Metadata file %s doesn't exist." % meta_data_file)
-    with tf.io.gfile.GFile(meta_data_file, 'r') as f:
-      meta_data = yaml.load(f, Loader=yaml.FullLoader)
+    metadata = tfrecord_cache_files.load_metadata()
 
-    dataset = tf.data.TFRecordDataset(matched_files)
+    dataset = tf.data.TFRecordDataset(tfrecord_cache_files.tfrecord_files)
     decoder = tf_example_decoder.TfExampleDecoder(regenerate_source_id=False)
     dataset = dataset.map(decoder.decode, num_parallel_calls=tf.data.AUTOTUNE)
 
-    label_map = meta_data['label_map']
+    label_map = metadata['label_map']
     label_names = [label_map[k] for k in sorted(label_map.keys())]
 
     return Dataset(
-        dataset=dataset, label_names=label_names, size=meta_data['size']
+        dataset=dataset, label_names=label_names, size=metadata['size']
     )
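
The net effect of the dataset.py changes: from_cache now takes the TFRecordCacheFiles object itself instead of a bare path prefix, and fails fast on an incomplete cache rather than globbing for '<prefix>*.tfrecord'. A sketch of the new call pattern (the paths are illustrative, not from the commit):

from mediapipe.model_maker.python.vision.object_detector import dataset
from mediapipe.model_maker.python.vision.object_detector import dataset_util

# Resolve the cache descriptor for a COCO-formatted folder (placeholder paths).
tfrecord_cache_files = dataset_util.get_cache_files_coco(
    '/path/to/coco_data', '/tmp/od_cache'
)

# from_cache requires that all shards and the metadata file already exist;
# otherwise it raises ValueError.
if tfrecord_cache_files.is_cached():
  data = dataset.Dataset.from_cache(tfrecord_cache_files)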

mediapipe/model_maker/python/vision/object_detector/dataset_util.py View File

@@ -15,25 +15,20 @@
 import abc
 import collections
-import dataclasses
 import hashlib
 import json
 import math
 import os
 import tempfile
-from typing import Any, Dict, List, Mapping, Optional, Sequence
+from typing import Any, Dict, List, Mapping, Optional
 import xml.etree.ElementTree as ET
 
 import tensorflow as tf
-import yaml
 
+from mediapipe.model_maker.python.core.data import cache_files
 from official.vision.data import tfrecord_lib
-
-# Suffix of the meta data file name.
-META_DATA_FILE_SUFFIX = '_meta_data.yaml'
 
 
 def _xml_get(node: ET.Element, name: str) -> ET.Element:
   """Gets a named child from an XML Element node.
@@ -71,18 +66,9 @@ def _get_dir_basename(data_dir: str) -> str:
   return os.path.basename(os.path.abspath(data_dir))
 
 
-@dataclasses.dataclass(frozen=True)
-class CacheFiles:
-  """Cache files for object detection."""
-
-  cache_prefix: str
-  tfrecord_files: Sequence[str]
-  meta_data_file: str
-
-
 def _get_cache_files(
     cache_dir: Optional[str], cache_prefix_filename: str, num_shards: int = 10
-) -> CacheFiles:
+) -> cache_files.TFRecordCacheFiles:
   """Creates an object of CacheFiles class.
 
   Args:
@@ -96,28 +82,16 @@ def _get_cache_files(
     An object of CacheFiles class.
   """
   cache_dir = _get_cache_dir_or_create(cache_dir)
-  # The cache prefix including the cache directory and the cache prefix
-  # filename, e.g: '/tmp/cache/train'.
-  cache_prefix = os.path.join(cache_dir, cache_prefix_filename)
-  tf.compat.v1.logging.info(
-      'Cache will be stored in %s with prefix filename %s. Cache_prefix is %s'
-      % (cache_dir, cache_prefix_filename, cache_prefix)
-  )
-  # Cached files including the TFRecord files and the meta data file.
-  tfrecord_files = [
-      cache_prefix + '-%05d-of-%05d.tfrecord' % (i, num_shards)
-      for i in range(num_shards)
-  ]
-  meta_data_file = cache_prefix + META_DATA_FILE_SUFFIX
-  return CacheFiles(
-      cache_prefix=cache_prefix,
-      tfrecord_files=tuple(tfrecord_files),
-      meta_data_file=meta_data_file,
+  return cache_files.TFRecordCacheFiles(
+      cache_prefix_filename=cache_prefix_filename,
+      cache_dir=cache_dir,
+      num_shards=num_shards,
   )
 
 
-def get_cache_files_coco(data_dir: str, cache_dir: str) -> CacheFiles:
+def get_cache_files_coco(
+    data_dir: str, cache_dir: str
+) -> cache_files.TFRecordCacheFiles:
   """Creates an object of CacheFiles class using a COCO formatted dataset.
 
   Args:
@@ -152,7 +126,9 @@ def get_cache_files_coco(data_dir: str, cache_dir: str) -> CacheFiles:
   return _get_cache_files(cache_dir, cache_prefix_filename, num_shards)
 
 
-def get_cache_files_pascal_voc(data_dir: str, cache_dir: str) -> CacheFiles:
+def get_cache_files_pascal_voc(
+    data_dir: str, cache_dir: str
+) -> cache_files.TFRecordCacheFiles:
   """Gets an object of CacheFiles using a PASCAL VOC formatted dataset.
 
   Args:
@@ -181,14 +157,6 @@ def get_cache_files_pascal_voc(data_dir: str, cache_dir: str) -> CacheFiles:
   return _get_cache_files(cache_dir, cache_prefix_filename, num_shards)
 
 
-def is_cached(cache_files: CacheFiles) -> bool:
-  """Checks whether cache files are already cached."""
-  all_cached_files = list(cache_files.tfrecord_files) + [
-      cache_files.meta_data_file
-  ]
-  return all(tf.io.gfile.exists(path) for path in all_cached_files)
-
-
 class CacheFilesWriter(abc.ABC):
   """CacheFilesWriter class to write the cached files."""
@@ -208,19 +176,22 @@ class CacheFilesWriter(abc.ABC):
     self.label_map = label_map
     self.max_num_images = max_num_images
 
-  def write_files(self, cache_files: CacheFiles, *args, **kwargs) -> None:
-    """Writes TFRecord and meta_data files.
+  def write_files(
+      self,
+      tfrecord_cache_files: cache_files.TFRecordCacheFiles,
+      *args,
+      **kwargs,
+  ) -> None:
+    """Writes TFRecord and metadata files.
 
     Args:
-      cache_files: CacheFiles object including a list of TFRecord files and
-        the meta data yaml file to save the meta_data including data size and
-        label_map.
+      tfrecord_cache_files: TFRecordCacheFiles object including a list of
+        TFRecord files and the metadata YAML file, used to save the metadata
+        including data size and label_map.
       *args: Non-keyword parameters used in the `_get_example` method.
       **kwargs: Keyword parameters used in the `_get_example` method.
     """
-    writers = [
-        tf.io.TFRecordWriter(path) for path in cache_files.tfrecord_files
-    ]
+    writers = tfrecord_cache_files.get_writers()
 
     # Writes tf.Example into TFRecord files.
     size = 0
@@ -235,10 +206,9 @@ class CacheFilesWriter(abc.ABC):
     for writer in writers:
       writer.close()
 
-    # Writes meta_data into meta_data_file.
-    meta_data = {'size': size, 'label_map': self.label_map}
-    with tf.io.gfile.GFile(cache_files.meta_data_file, 'w') as f:
-      yaml.dump(meta_data, f)
+    # Writes metadata into the metadata file.
+    metadata = {'size': size, 'label_map': self.label_map}
+    tfrecord_cache_files.save_metadata(metadata)
 
   @abc.abstractmethod
   def _get_example(self, *args, **kwargs):
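
For downstream code that used the removed dataset_util helpers, the migration onto TFRecordCacheFiles members is one-to-one; the left-hand forms below are taken from the deleted lines in the hunks above:

# dataset_util.is_cached(cache_files)                    -> tfrecord_cache_files.is_cached()
# [tf.io.TFRecordWriter(p) for p in cf.tfrecord_files]   -> tfrecord_cache_files.get_writers()
# cache_files.meta_data_file                             -> tfrecord_cache_files.metadata_file
# yaml.dump(meta_data, f)   # f open on meta_data_file   -> tfrecord_cache_files.save_metadata(metadata)
# yaml.load(f, Loader=yaml.FullLoader)                   -> tfrecord_cache_files.load_metadata()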

mediapipe/model_maker/python/vision/object_detector/dataset_util_test.py View File

@@ -19,7 +19,6 @@ import shutil
 from unittest import mock as unittest_mock
 
 import tensorflow as tf
-import yaml
 
 from mediapipe.model_maker.python.vision.core import test_utils
 from mediapipe.model_maker.python.vision.object_detector import dataset_util
@@ -30,13 +29,10 @@ class DatasetUtilTest(tf.test.TestCase):
 
   def _assert_cache_files_equal(self, cf1, cf2):
     self.assertEqual(cf1.cache_prefix, cf2.cache_prefix)
     self.assertCountEqual(cf1.tfrecord_files, cf2.tfrecord_files)
-    self.assertEqual(cf1.meta_data_file, cf2.meta_data_file)
+    self.assertEqual(cf1.num_shards, cf2.num_shards)
 
   def _assert_cache_files_not_equal(self, cf1, cf2):
     self.assertNotEqual(cf1.cache_prefix, cf2.cache_prefix)
     self.assertNotEqual(cf1.tfrecord_files, cf2.tfrecord_files)
-    self.assertNotEqual(cf1.meta_data_file, cf2.meta_data_file)
 
   def _get_cache_files_and_assert_neq_fn(self, cache_files_fn):
     def get_cache_files_and_assert_neq(cf, data_dir, cache_dir):
@@ -57,7 +53,7 @@ class DatasetUtilTest(tf.test.TestCase):
     self.assertEqual(
         cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord'
     )
-    self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml')
+    self.assertEqual(cache_files.metadata_file, '/tmp/train_metadata.yaml')
 
   def test_matching_get_cache_files_coco(self):
     cache_dir = self.create_tempdir()
@@ -118,7 +114,7 @@ class DatasetUtilTest(tf.test.TestCase):
     self.assertEqual(
         cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord'
     )
-    self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml')
+    self.assertEqual(cache_files.metadata_file, '/tmp/train_metadata.yaml')
 
   def test_matching_get_cache_files_pascal_voc(self):
     cache_dir = self.create_tempdir()
@@ -173,13 +169,13 @@ class DatasetUtilTest(tf.test.TestCase):
     cache_files = dataset_util.get_cache_files_coco(
         tasks_test_utils.get_test_data_path('coco_data'), cache_dir=tempdir
     )
-    self.assertFalse(dataset_util.is_cached(cache_files))
+    self.assertFalse(cache_files.is_cached())
     with open(cache_files.tfrecord_files[0], 'w') as f:
       f.write('test')
-    self.assertFalse(dataset_util.is_cached(cache_files))
-    with open(cache_files.meta_data_file, 'w') as f:
+    self.assertFalse(cache_files.is_cached())
+    with open(cache_files.metadata_file, 'w') as f:
       f.write('test')
-    self.assertTrue(dataset_util.is_cached(cache_files))
+    self.assertTrue(cache_files.is_cached())
 
   def test_get_label_map_coco(self):
     coco_dir = tasks_test_utils.get_test_data_path('coco_data')
@@ -203,13 +199,11 @@ class DatasetUtilTest(tf.test.TestCase):
     self.assertTrue(os.path.isfile(cache_files.tfrecord_files[0]))
     self.assertGreater(os.path.getsize(cache_files.tfrecord_files[0]), 0)
 
-    # Checks the meta_data file
-    self.assertTrue(os.path.isfile(cache_files.meta_data_file))
-    self.assertGreater(os.path.getsize(cache_files.meta_data_file), 0)
-    with tf.io.gfile.GFile(cache_files.meta_data_file, 'r') as f:
-      meta_data_dict = yaml.load(f, Loader=yaml.FullLoader)
-      # Size is 3 because some examples are skipped for having poor bboxes
-      self.assertEqual(meta_data_dict['size'], expected_size)
+    # Checks the metadata file.
+    self.assertTrue(os.path.isfile(cache_files.metadata_file))
+    self.assertGreater(os.path.getsize(cache_files.metadata_file), 0)
+    metadata_dict = cache_files.load_metadata()
+    self.assertEqual(metadata_dict['size'], expected_size)
 
   def test_coco_cache_files_writer(self):
     tempdir = self.create_tempdir()