Modify Kinetics dataset reader for custom datasets.

Signed-off-by: Cheoljun Lee <cheoljun.lee@samsung.com>
commit 9d70a4c5f9 (parent f42c8fd442)
Cheoljun Lee, 2021-06-28 16:03:54 +09:00


@@ -73,6 +73,7 @@ import subprocess
 import sys
 import tarfile
 import tempfile
+from multiprocessing import Process, Queue
 from absl import app
 from absl import flags
@@ -80,7 +81,7 @@ from absl import logging
 from six.moves import range
 from six.moves import urllib
 from six.moves import zip
-import tensorflow.compat.v1 as tf
+import tensorflow as tf
 from mediapipe.util.sequence import media_sequence as ms
@@ -127,7 +128,7 @@ class Kinetics(object):
   def as_dataset(self, split, shuffle=False, repeat=False,
                  serialized_prefetch_size=32, decoded_prefetch_size=32,
-                 parse_labels=True):
+                 parse_labels=True, num_classes=NUM_CLASSES):
     """Returns Kinetics as a tf.data.Dataset.
 
     After running this function, calling padded_batch() on the Dataset object
@@ -195,7 +196,7 @@
           "num_frames": num_frames,
       }
       if parse_labels:
-        target = tf.one_hot(parsed_context[ms.get_clip_label_index_key()], 700)
+        target = tf.one_hot(parsed_context[ms.get_clip_label_index_key()], num_classes)
         output_dict["labels"] = target
       return output_dict
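
Together with the signature change above, the one-hot depth now follows the dataset instead of being pinned to Kinetics-700 (NUM_CLASSES is the upstream default of 700). A minimal sketch of what the parameter enables, with illustrative values:

    import tensorflow as tf

    # The one-hot depth must equal the number of classes in the label map;
    # hard-coding 700 breaks any custom dataset of a different size.
    num_classes = 5                # e.g. a small custom dataset
    label_index = tf.constant(2)   # index parsed from the SequenceExample
    print(tf.one_hot(label_index, num_classes).numpy())
    # -> [0. 0. 1. 0. 0.]

    # With the new parameter a caller could then do something like
    # (constructor usage hypothetical, not shown in this diff):
    # Kinetics("/path/to/data").as_dataset("train", num_classes=num_classes)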
@@ -265,14 +266,18 @@
     logging.info("An example of the metadata: ")
     logging.info(all_metadata[0])
     random.seed(47)
-    random.shuffle(all_metadata)
+    #random.shuffle(all_metadata)
     shards = SPLITS[key]["shards"]
     shard_names = [os.path.join(
         self.path_to_data, FILEPATTERN % key + "-%05d-of-%05d" % (
             i, shards)) for i in range(shards)]
     writers = [tf.io.TFRecordWriter(shard_name) for shard_name in shard_names]
-    with _close_on_exit(writers) as writers:
-      for i, seq_ex in enumerate(all_metadata):
+    def generate(q, writer):
+      while True:
+        i, seq_ex = q.get()
+        if i is None:
+          break
         if not only_generate_metadata:
           print("Processing example %d of %d (%d%%) \r" % (
               i, len(all_metadata), i * 100 / len(all_metadata)), end="")
@@ -280,7 +285,20 @@
           graph_path = os.path.join(path_to_graph_directory, graph)
           seq_ex = self._run_mediapipe(
               path_to_mediapipe_binary, seq_ex, graph_path)
-        writers[i % len(writers)].write(seq_ex.SerializeToString())
+        writer.write(seq_ex.SerializeToString())
+      writer.close()
+    queue = Queue()
+    workers = []
+    for i, seq_ex in enumerate(all_metadata):
+      queue.put((i, seq_ex))
+    for i in range(shards):
+      queue.put((None, None))
+    for i in range(shards):
+      worker = Process(target=generate, args=(queue, writers[i]))
+      worker.start()
+      workers.append(worker)
+    for worker in workers:
+      worker.join()
     logging.info("Data extraction complete.")
 
   def _generate_metadata(self, key, download_output,
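
The rewritten writer loop is a standard multiprocessing fan-out: the parent enqueues every (index, example) pair plus one (None, None) sentinel per worker, and each worker drains the queue until it sees a sentinel, writing to its own shard. The same pattern in isolation (names illustrative):

    from multiprocessing import Process, Queue

    def worker(q, shard_id):
        # Drain the queue until the (None, None) sentinel arrives, as
        # generate() does above; each process owns exactly one shard.
        while True:
            i, item = q.get()
            if i is None:
                break
            print("shard %d handled item %d" % (shard_id, i))

    if __name__ == "__main__":
        num_workers = 4
        q = Queue()
        for i, item in enumerate(range(20)):
            q.put((i, item))
        for _ in range(num_workers):
            q.put((None, None))  # one sentinel per worker so all exit
        procs = [Process(target=worker, args=(q, w))
                 for w in range(num_workers)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()

One consequence of the change worth noting: which example lands in which shard now depends on scheduling rather than on the old deterministic `i % len(writers)` assignment, and the TFRecordWriter handles are created in the parent and inherited by the children, which relies on fork-style process start (the Linux default).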
@@ -320,6 +338,8 @@
       ms.set_clip_data_path(bytes23(filepath), metadata)
       assert row["start"].isdigit(), "Invalid row: %s" % str(row)
       assert row["end"].isdigit(), "Invalid row: %s" % str(row)
+      ms.set_clip_start_timestamp(int(float(row["start"]) * SECONDS_TO_MICROSECONDS), metadata)
+      ms.set_clip_end_timestamp(int(float(row["end"]) * SECONDS_TO_MICROSECONDS), metadata)
       if "label_name" in row:
         ms.set_clip_label_string([bytes23(row["label_name"])], metadata)
       if label_map:
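
The two added calls convert the CSV start/end fields to microseconds via float, which tolerates fractional seconds. A quick standalone check (SECONDS_TO_MICROSECONDS is assumed to be the upstream constant of 1000000):

    SECONDS_TO_MICROSECONDS = 1000000  # assumed upstream value

    row = {"start": "12.5", "end": "22.5"}  # illustrative CSV row
    # int(float(...)) accepts "12.5"; int("12.5") alone raises ValueError.
    # Note "12.5".isdigit() is False, so fractional values would still trip
    # the isdigit() asserts above; integral strings like "12" pass both.
    print(int(float(row["start"]) * SECONDS_TO_MICROSECONDS))  # 12500000
    print(int(float(row["end"]) * SECONDS_TO_MICROSECONDS))    # 22500000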
@@ -394,13 +414,13 @@
   def get_label_map_and_verify_example_counts(self, paths):
     """Verify the number of examples and labels have not changed."""
+    label_map = None
     for name, path in paths.items():
       with open(path, "r") as f:
         lines = f.readlines()
       # the header adds one line and one "key".
       num_examples = len(lines) - 1
       keys = [l.split(",")[0] for l in lines]
-      label_map = None
       if name == "train":
         classes = sorted(list(set(keys[1:])))
         num_keys = len(set(keys)) - 1
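
Hoisting `label_map = None` out of the loop fixes a real bug: with the assignment inside the loop, any split processed after "train" reset the map that "train" had just built, so the method returned None whenever "train" was not the last entry in `paths`. A stripped-down illustration (names illustrative):

    paths = {"train": "train.csv", "validate": "validate.csv"}

    label_map = None                  # new placement: survives the loop
    for name, path in paths.items():
        # label_map = None           # old placement: wiped each iteration
        if name == "train":
            label_map = {"abseiling": 0, "zumba": 1}  # illustrative map
    print(label_map)  # {'abseiling': 0, 'zumba': 1} instead of None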