diff --git a/mediapipe/model_maker/python/core/data/BUILD b/mediapipe/model_maker/python/core/data/BUILD index 1c2fb7a44..4364b7744 100644 --- a/mediapipe/model_maker/python/core/data/BUILD +++ b/mediapipe/model_maker/python/core/data/BUILD @@ -57,3 +57,14 @@ py_test( srcs = ["classification_dataset_test.py"], deps = [":classification_dataset"], ) + +py_library( + name = "cache_files", + srcs = ["cache_files.py"], +) + +py_test( + name = "cache_files_test", + srcs = ["cache_files_test.py"], + deps = [":cache_files"], +) diff --git a/mediapipe/model_maker/python/core/data/cache_files.py b/mediapipe/model_maker/python/core/data/cache_files.py new file mode 100644 index 000000000..7324891eb --- /dev/null +++ b/mediapipe/model_maker/python/core/data/cache_files.py @@ -0,0 +1,112 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Common TFRecord cache files library.""" + +import dataclasses +import os +import tempfile +from typing import Any, Mapping, Sequence + +import tensorflow as tf +import yaml + + +# Suffix of the meta data file name. +METADATA_FILE_SUFFIX = '_metadata.yaml' + + +@dataclasses.dataclass(frozen=True) +class TFRecordCacheFiles: + """TFRecordCacheFiles dataclass to store and load cached TFRecord files. + + Attributes: + cache_prefix_filename: The cache prefix filename. This is usually provided + as a hash of the original data source to avoid different data sources + resulting in the same cache file. + cache_dir: The cache directory to save TFRecord and metadata file. When + cache_dir is None, a temporary folder will be created and will not be + removed automatically after training which makes it can be used later. + num_shards: Number of shards for output tfrecord files. + """ + + cache_prefix_filename: str = 'cache_prefix' + cache_dir: str = dataclasses.field(default_factory=tempfile.mkdtemp) + num_shards: int = 1 + + def __post_init__(self): + if not self.cache_prefix_filename: + raise ValueError('cache_prefix_filename cannot be empty.') + if self.num_shards <= 0: + raise ValueError( + f'num_shards must be greater than 0, got {self.num_shards}' + ) + + @property + def cache_prefix(self) -> str: + """The cache prefix including the cache directory and the cache prefix filename.""" + return os.path.join(self.cache_dir, self.cache_prefix_filename) + + @property + def tfrecord_files(self) -> Sequence[str]: + """The TFRecord files.""" + tfrecord_files = [ + self.cache_prefix + '-%05d-of-%05d.tfrecord' % (i, self.num_shards) + for i in range(self.num_shards) + ] + return tfrecord_files + + @property + def metadata_file(self) -> str: + """The metadata file.""" + return self.cache_prefix + METADATA_FILE_SUFFIX + + def get_writers(self) -> Sequence[tf.io.TFRecordWriter]: + """Gets an array of TFRecordWriter objects. + + Note that these writers should each be closed using .close() when done. + + Returns: + Array of TFRecordWriter objects + """ + if not tf.io.gfile.exists(self.cache_dir): + tf.io.gfile.makedirs(self.cache_dir) + return [tf.io.TFRecordWriter(path) for path in self.tfrecord_files] + + def save_metadata(self, metadata): + """Writes metadata to file. + + Args: + metadata: A dictionary of metadata content to write. Exact format is + dependent on the specific dataset, but typically includes a 'size' and + 'label_names' entry. + """ + with tf.io.gfile.GFile(self.metadata_file, 'w') as f: + yaml.dump(metadata, f) + + def load_metadata(self) -> Mapping[Any, Any]: + """Reads metadata from file. + + Returns: + Dictionary object containing metadata + """ + if not tf.io.gfile.exists(self.metadata_file): + return {} + with tf.io.gfile.GFile(self.metadata_file, 'r') as f: + metadata = yaml.load(f, Loader=yaml.FullLoader) + return metadata + + def is_cached(self) -> bool: + """Checks whether this CacheFiles is already cached.""" + all_cached_files = list(self.tfrecord_files) + [self.metadata_file] + return all(tf.io.gfile.exists(f) for f in all_cached_files) diff --git a/mediapipe/model_maker/python/core/data/cache_files_test.py b/mediapipe/model_maker/python/core/data/cache_files_test.py new file mode 100644 index 000000000..ac727b3fe --- /dev/null +++ b/mediapipe/model_maker/python/core/data/cache_files_test.py @@ -0,0 +1,77 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import cache_files + + +class CacheFilesTest(tf.test.TestCase): + + def test_tfrecord_cache_files(self): + cf = cache_files.TFRecordCacheFiles( + cache_prefix_filename='tfrecord', + cache_dir='/tmp/cache_dir', + num_shards=2, + ) + self.assertEqual(cf.cache_prefix, '/tmp/cache_dir/tfrecord') + self.assertEqual( + cf.metadata_file, + '/tmp/cache_dir/tfrecord' + cache_files.METADATA_FILE_SUFFIX, + ) + expected_tfrecord_files = [ + '/tmp/cache_dir/tfrecord-%05d-of-%05d.tfrecord' % (i, 2) + for i in range(2) + ] + self.assertEqual(cf.tfrecord_files, expected_tfrecord_files) + + # Writing TFRecord Files + self.assertFalse(cf.is_cached()) + for tfrecord_file in cf.tfrecord_files: + self.assertFalse(tf.io.gfile.exists(tfrecord_file)) + writers = cf.get_writers() + for writer in writers: + writer.close() + for tfrecord_file in cf.tfrecord_files: + self.assertTrue(tf.io.gfile.exists(tfrecord_file)) + self.assertFalse(cf.is_cached()) + + # Writing Metadata Files + original_metadata = {'size': 10, 'label_names': ['label1', 'label2']} + cf.save_metadata(original_metadata) + self.assertTrue(cf.is_cached()) + metadata = cf.load_metadata() + self.assertEqual(metadata, original_metadata) + + def test_recordio_cache_files_error(self): + with self.assertRaisesRegex( + ValueError, 'cache_prefix_filename cannot be empty' + ): + cache_files.TFRecordCacheFiles( + cache_prefix_filename='', + cache_dir='/tmp/cache_dir', + num_shards=2, + ) + with self.assertRaisesRegex( + ValueError, 'num_shards must be greater than 0, got 0' + ): + cache_files.TFRecordCacheFiles( + cache_prefix_filename='tfrecord', + cache_dir='/tmp/cache_dir', + num_shards=0, + ) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/vision/object_detector/BUILD b/mediapipe/model_maker/python/vision/object_detector/BUILD index 75c08dbc8..3a0460544 100644 --- a/mediapipe/model_maker/python/vision/object_detector/BUILD +++ b/mediapipe/model_maker/python/vision/object_detector/BUILD @@ -54,6 +54,7 @@ py_library( srcs = ["dataset.py"], deps = [ ":dataset_util", + "//mediapipe/model_maker/python/core/data:cache_files", "//mediapipe/model_maker/python/core/data:classification_dataset", ], ) @@ -73,6 +74,7 @@ py_test( py_library( name = "dataset_util", srcs = ["dataset_util.py"], + deps = ["//mediapipe/model_maker/python/core/data:cache_files"], ) py_test( diff --git a/mediapipe/model_maker/python/vision/object_detector/dataset.py b/mediapipe/model_maker/python/vision/object_detector/dataset.py index bec1a8446..f7751915e 100644 --- a/mediapipe/model_maker/python/vision/object_detector/dataset.py +++ b/mediapipe/model_maker/python/vision/object_detector/dataset.py @@ -16,8 +16,8 @@ from typing import Optional import tensorflow as tf -import yaml +from mediapipe.model_maker.python.core.data import cache_files from mediapipe.model_maker.python.core.data import classification_dataset from mediapipe.model_maker.python.vision.object_detector import dataset_util from official.vision.dataloaders import tf_example_decoder @@ -76,14 +76,16 @@ class Dataset(classification_dataset.ClassificationDataset): ValueError: If the label_name for id 0 is set to something other than the 'background' class. """ - cache_files = dataset_util.get_cache_files_coco(data_dir, cache_dir) - if not dataset_util.is_cached(cache_files): + tfrecord_cache_files = dataset_util.get_cache_files_coco( + data_dir, cache_dir + ) + if not tfrecord_cache_files.is_cached(): label_map = dataset_util.get_label_map_coco(data_dir) cache_writer = dataset_util.COCOCacheFilesWriter( label_map=label_map, max_num_images=max_num_images ) - cache_writer.write_files(cache_files, data_dir) - return cls.from_cache(cache_files.cache_prefix) + cache_writer.write_files(tfrecord_cache_files, data_dir) + return cls.from_cache(tfrecord_cache_files) @classmethod def from_pascal_voc_folder( @@ -134,47 +136,48 @@ class Dataset(classification_dataset.ClassificationDataset): Raises: ValueError: if the input data directory is empty. """ - cache_files = dataset_util.get_cache_files_pascal_voc(data_dir, cache_dir) - if not dataset_util.is_cached(cache_files): + tfrecord_cache_files = dataset_util.get_cache_files_pascal_voc( + data_dir, cache_dir + ) + if not tfrecord_cache_files.is_cached(): label_map = dataset_util.get_label_map_pascal_voc(data_dir) cache_writer = dataset_util.PascalVocCacheFilesWriter( label_map=label_map, max_num_images=max_num_images ) - cache_writer.write_files(cache_files, data_dir) + cache_writer.write_files(tfrecord_cache_files, data_dir) - return cls.from_cache(cache_files.cache_prefix) + return cls.from_cache(tfrecord_cache_files) @classmethod - def from_cache(cls, cache_prefix: str) -> 'Dataset': + def from_cache( + cls, tfrecord_cache_files: cache_files.TFRecordCacheFiles + ) -> 'Dataset': """Loads the TFRecord data from cache. Args: - cache_prefix: The cache prefix including the cache directory and the cache - prefix filename, e.g: '/tmp/cache/train'. + tfrecord_cache_files: The TFRecordCacheFiles object containing the already + cached TFRecord and metadata files. Returns: ObjectDetectorDataset object. + + Raises: + ValueError if tfrecord_cache_files are not already cached. """ - # Get TFRecord Files - tfrecord_file_pattern = cache_prefix + '*.tfrecord' - matched_files = tf.io.gfile.glob(tfrecord_file_pattern) - if not matched_files: - raise ValueError('TFRecord files are empty.') + if not tfrecord_cache_files.is_cached(): + raise ValueError( + 'Cache files must be already cached to use the from_cache method.' + ) - # Load meta_data. - meta_data_file = cache_prefix + dataset_util.META_DATA_FILE_SUFFIX - if not tf.io.gfile.exists(meta_data_file): - raise ValueError("Metadata file %s doesn't exist." % meta_data_file) - with tf.io.gfile.GFile(meta_data_file, 'r') as f: - meta_data = yaml.load(f, Loader=yaml.FullLoader) + metadata = tfrecord_cache_files.load_metadata() - dataset = tf.data.TFRecordDataset(matched_files) + dataset = tf.data.TFRecordDataset(tfrecord_cache_files.tfrecord_files) decoder = tf_example_decoder.TfExampleDecoder(regenerate_source_id=False) dataset = dataset.map(decoder.decode, num_parallel_calls=tf.data.AUTOTUNE) - label_map = meta_data['label_map'] + label_map = metadata['label_map'] label_names = [label_map[k] for k in sorted(label_map.keys())] return Dataset( - dataset=dataset, label_names=label_names, size=meta_data['size'] + dataset=dataset, label_names=label_names, size=metadata['size'] ) diff --git a/mediapipe/model_maker/python/vision/object_detector/dataset_util.py b/mediapipe/model_maker/python/vision/object_detector/dataset_util.py index 74d082f9f..fbb821b3b 100644 --- a/mediapipe/model_maker/python/vision/object_detector/dataset_util.py +++ b/mediapipe/model_maker/python/vision/object_detector/dataset_util.py @@ -15,25 +15,20 @@ import abc import collections -import dataclasses import hashlib import json import math import os import tempfile -from typing import Any, Dict, List, Mapping, Optional, Sequence +from typing import Any, Dict, List, Mapping, Optional import xml.etree.ElementTree as ET import tensorflow as tf -import yaml +from mediapipe.model_maker.python.core.data import cache_files from official.vision.data import tfrecord_lib -# Suffix of the meta data file name. -META_DATA_FILE_SUFFIX = '_meta_data.yaml' - - def _xml_get(node: ET.Element, name: str) -> ET.Element: """Gets a named child from an XML Element node. @@ -71,18 +66,9 @@ def _get_dir_basename(data_dir: str) -> str: return os.path.basename(os.path.abspath(data_dir)) -@dataclasses.dataclass(frozen=True) -class CacheFiles: - """Cache files for object detection.""" - - cache_prefix: str - tfrecord_files: Sequence[str] - meta_data_file: str - - def _get_cache_files( cache_dir: Optional[str], cache_prefix_filename: str, num_shards: int = 10 -) -> CacheFiles: +) -> cache_files.TFRecordCacheFiles: """Creates an object of CacheFiles class. Args: @@ -96,28 +82,16 @@ def _get_cache_files( An object of CacheFiles class. """ cache_dir = _get_cache_dir_or_create(cache_dir) - # The cache prefix including the cache directory and the cache prefix - # filename, e.g: '/tmp/cache/train'. - cache_prefix = os.path.join(cache_dir, cache_prefix_filename) - tf.compat.v1.logging.info( - 'Cache will be stored in %s with prefix filename %s. Cache_prefix is %s' - % (cache_dir, cache_prefix_filename, cache_prefix) - ) - - # Cached files including the TFRecord files and the meta data file. - tfrecord_files = [ - cache_prefix + '-%05d-of-%05d.tfrecord' % (i, num_shards) - for i in range(num_shards) - ] - meta_data_file = cache_prefix + META_DATA_FILE_SUFFIX - return CacheFiles( - cache_prefix=cache_prefix, - tfrecord_files=tuple(tfrecord_files), - meta_data_file=meta_data_file, + return cache_files.TFRecordCacheFiles( + cache_prefix_filename=cache_prefix_filename, + cache_dir=cache_dir, + num_shards=num_shards, ) -def get_cache_files_coco(data_dir: str, cache_dir: str) -> CacheFiles: +def get_cache_files_coco( + data_dir: str, cache_dir: str +) -> cache_files.TFRecordCacheFiles: """Creates an object of CacheFiles class using a COCO formatted dataset. Args: @@ -152,7 +126,9 @@ def get_cache_files_coco(data_dir: str, cache_dir: str) -> CacheFiles: return _get_cache_files(cache_dir, cache_prefix_filename, num_shards) -def get_cache_files_pascal_voc(data_dir: str, cache_dir: str) -> CacheFiles: +def get_cache_files_pascal_voc( + data_dir: str, cache_dir: str +) -> cache_files.TFRecordCacheFiles: """Gets an object of CacheFiles using a PASCAL VOC formatted dataset. Args: @@ -181,14 +157,6 @@ def get_cache_files_pascal_voc(data_dir: str, cache_dir: str) -> CacheFiles: return _get_cache_files(cache_dir, cache_prefix_filename, num_shards) -def is_cached(cache_files: CacheFiles) -> bool: - """Checks whether cache files are already cached.""" - all_cached_files = list(cache_files.tfrecord_files) + [ - cache_files.meta_data_file - ] - return all(tf.io.gfile.exists(path) for path in all_cached_files) - - class CacheFilesWriter(abc.ABC): """CacheFilesWriter class to write the cached files.""" @@ -208,19 +176,22 @@ class CacheFilesWriter(abc.ABC): self.label_map = label_map self.max_num_images = max_num_images - def write_files(self, cache_files: CacheFiles, *args, **kwargs) -> None: - """Writes TFRecord and meta_data files. + def write_files( + self, + tfrecord_cache_files: cache_files.TFRecordCacheFiles, + *args, + **kwargs, + ) -> None: + """Writes TFRecord and metadata files. Args: - cache_files: CacheFiles object including a list of TFRecord files and the - meta data yaml file to save the meta_data including data size and - label_map. + tfrecord_cache_files: TFRecordCacheFiles object including a list of + TFRecord files and the meta data yaml file to save the metadata + including data size and label_map. *args: Non-keyword of parameters used in the `_get_example` method. **kwargs: Keyword parameters used in the `_get_example` method. """ - writers = [ - tf.io.TFRecordWriter(path) for path in cache_files.tfrecord_files - ] + writers = tfrecord_cache_files.get_writers() # Writes tf.Example into TFRecord files. size = 0 @@ -235,10 +206,9 @@ class CacheFilesWriter(abc.ABC): for writer in writers: writer.close() - # Writes meta_data into meta_data_file. - meta_data = {'size': size, 'label_map': self.label_map} - with tf.io.gfile.GFile(cache_files.meta_data_file, 'w') as f: - yaml.dump(meta_data, f) + # Writes metadata into metadata_file. + metadata = {'size': size, 'label_map': self.label_map} + tfrecord_cache_files.save_metadata(metadata) @abc.abstractmethod def _get_example(self, *args, **kwargs): diff --git a/mediapipe/model_maker/python/vision/object_detector/dataset_util_test.py b/mediapipe/model_maker/python/vision/object_detector/dataset_util_test.py index 6daea1f47..250c5d45e 100644 --- a/mediapipe/model_maker/python/vision/object_detector/dataset_util_test.py +++ b/mediapipe/model_maker/python/vision/object_detector/dataset_util_test.py @@ -19,7 +19,6 @@ import shutil from unittest import mock as unittest_mock import tensorflow as tf -import yaml from mediapipe.model_maker.python.vision.core import test_utils from mediapipe.model_maker.python.vision.object_detector import dataset_util @@ -30,13 +29,10 @@ class DatasetUtilTest(tf.test.TestCase): def _assert_cache_files_equal(self, cf1, cf2): self.assertEqual(cf1.cache_prefix, cf2.cache_prefix) - self.assertCountEqual(cf1.tfrecord_files, cf2.tfrecord_files) - self.assertEqual(cf1.meta_data_file, cf2.meta_data_file) + self.assertEqual(cf1.num_shards, cf2.num_shards) def _assert_cache_files_not_equal(self, cf1, cf2): self.assertNotEqual(cf1.cache_prefix, cf2.cache_prefix) - self.assertNotEqual(cf1.tfrecord_files, cf2.tfrecord_files) - self.assertNotEqual(cf1.meta_data_file, cf2.meta_data_file) def _get_cache_files_and_assert_neq_fn(self, cache_files_fn): def get_cache_files_and_assert_neq(cf, data_dir, cache_dir): @@ -57,7 +53,7 @@ class DatasetUtilTest(tf.test.TestCase): self.assertEqual( cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord' ) - self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml') + self.assertEqual(cache_files.metadata_file, '/tmp/train_metadata.yaml') def test_matching_get_cache_files_coco(self): cache_dir = self.create_tempdir() @@ -118,7 +114,7 @@ class DatasetUtilTest(tf.test.TestCase): self.assertEqual( cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord' ) - self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml') + self.assertEqual(cache_files.metadata_file, '/tmp/train_metadata.yaml') def test_matching_get_cache_files_pascal_voc(self): cache_dir = self.create_tempdir() @@ -173,13 +169,13 @@ class DatasetUtilTest(tf.test.TestCase): cache_files = dataset_util.get_cache_files_coco( tasks_test_utils.get_test_data_path('coco_data'), cache_dir=tempdir ) - self.assertFalse(dataset_util.is_cached(cache_files)) + self.assertFalse(cache_files.is_cached()) with open(cache_files.tfrecord_files[0], 'w') as f: f.write('test') - self.assertFalse(dataset_util.is_cached(cache_files)) - with open(cache_files.meta_data_file, 'w') as f: + self.assertFalse(cache_files.is_cached()) + with open(cache_files.metadata_file, 'w') as f: f.write('test') - self.assertTrue(dataset_util.is_cached(cache_files)) + self.assertTrue(cache_files.is_cached()) def test_get_label_map_coco(self): coco_dir = tasks_test_utils.get_test_data_path('coco_data') @@ -203,13 +199,11 @@ class DatasetUtilTest(tf.test.TestCase): self.assertTrue(os.path.isfile(cache_files.tfrecord_files[0])) self.assertGreater(os.path.getsize(cache_files.tfrecord_files[0]), 0) - # Checks the meta_data file - self.assertTrue(os.path.isfile(cache_files.meta_data_file)) - self.assertGreater(os.path.getsize(cache_files.meta_data_file), 0) - with tf.io.gfile.GFile(cache_files.meta_data_file, 'r') as f: - meta_data_dict = yaml.load(f, Loader=yaml.FullLoader) - # Size is 3 because some examples are skipped for having poor bboxes - self.assertEqual(meta_data_dict['size'], expected_size) + # Checks the metadata file + self.assertTrue(os.path.isfile(cache_files.metadata_file)) + self.assertGreater(os.path.getsize(cache_files.metadata_file), 0) + metadata_dict = cache_files.load_metadata() + self.assertEqual(metadata_dict['size'], expected_size) def test_coco_cache_files_writer(self): tempdir = self.create_tempdir()