Modify Kinetics dataset reader for custom datasets.

Signed-off-by: Cheoljun Lee <cheoljun.lee@samsung.com>
Cheoljun Lee 2021-06-28 16:03:54 +09:00
parent f42c8fd442
commit 9d70a4c5f9


@@ -73,6 +73,7 @@ import subprocess
 import sys
 import tarfile
 import tempfile
+from multiprocessing import Process, Queue

 from absl import app
 from absl import flags
@@ -80,7 +81,7 @@ from absl import logging
 from six.moves import range
 from six.moves import urllib
 from six.moves import zip
-import tensorflow.compat.v1 as tf
+import tensorflow as tf

 from mediapipe.util.sequence import media_sequence as ms
@@ -127,7 +128,7 @@ class Kinetics(object):
   def as_dataset(self, split, shuffle=False, repeat=False,
                  serialized_prefetch_size=32, decoded_prefetch_size=32,
-                 parse_labels=True):
+                 parse_labels=True, num_classes=NUM_CLASSES):
     """Returns Kinetics as a tf.data.Dataset.

     After running this function, calling padded_batch() on the Dataset object
@@ -195,7 +196,7 @@ class Kinetics(object):
           "num_frames": num_frames,
       }
       if parse_labels:
-        target = tf.one_hot(parsed_context[ms.get_clip_label_index_key()], 700)
+        target = tf.one_hot(parsed_context[ms.get_clip_label_index_key()], num_classes)
         output_dict["labels"] = target
       return output_dict

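These two hunks replace the hard-coded Kinetics-700 one-hot depth with a caller-supplied num_classes that defaults to the script's NUM_CLASSES constant, so the reader can serve record sets with any label count. A hypothetical call site, assuming records built from a 10-class custom CSV; the path and split name below are illustrative, not from the script:

from kinetics_dataset import Kinetics

data = Kinetics("path/to/custom_data")          # illustrative path
dataset = data.as_dataset("custom_train", shuffle=True, repeat=True,
                          num_classes=10)       # one-hot labels get depth 10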
@@ -265,14 +266,18 @@ class Kinetics(object):
     logging.info("An example of the metadata: ")
     logging.info(all_metadata[0])
     random.seed(47)
-    random.shuffle(all_metadata)
+    #random.shuffle(all_metadata)
     shards = SPLITS[key]["shards"]
     shard_names = [os.path.join(
         self.path_to_data, FILEPATTERN % key + "-%05d-of-%05d" % (
             i, shards)) for i in range(shards)]
     writers = [tf.io.TFRecordWriter(shard_name) for shard_name in shard_names]
-    with _close_on_exit(writers) as writers:
-      for i, seq_ex in enumerate(all_metadata):
+
+    def generate(q, writer):
+      while True:
+        i, seq_ex = q.get()
+        if i is None:
+          break
         if not only_generate_metadata:
           print("Processing example %d of %d (%d%%) \r" % (
               i, len(all_metadata), i * 100 / len(all_metadata)), end="")
@@ -280,7 +285,20 @@ class Kinetics(object):
           graph_path = os.path.join(path_to_graph_directory, graph)
           seq_ex = self._run_mediapipe(
               path_to_mediapipe_binary, seq_ex, graph_path)
-        writers[i % len(writers)].write(seq_ex.SerializeToString())
+        writer.write(seq_ex.SerializeToString())
+      writer.close()
+    queue = Queue()
+    workers = []
+    for i, seq_ex in enumerate(all_metadata):
+      queue.put((i, seq_ex))
+    for i in range(shards):
+      queue.put((None, None))
+    for i in range(shards):
+      worker = Process(target=generate, args=(queue,writers[i]))
+      worker.start()
+      workers.append(worker)
+    for worker in workers:
+      worker.join()
     logging.info("Data extraction complete.")

   def _generate_metadata(self, key, download_output,
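Taken together, the last two hunks replace the single-process loop, which wrote example i to shard i % len(writers), with one worker process per output shard, all fed from a shared Queue and stopped by one (None, None) sentinel apiece. Shard assignment therefore now depends on which worker dequeues an item rather than on a deterministic round-robin. A minimal standalone sketch of this fan-out pattern, with illustrative names throughout:

from multiprocessing import Process, Queue

def worker(q, shard_id):
  # Drain the shared queue until this worker's sentinel arrives.
  while True:
    i, item = q.get()
    if i is None:  # sentinel: no more work
      break
    print("shard %d handled item %d: %r" % (shard_id, i, item))

if __name__ == "__main__":
  num_shards = 4
  q = Queue()
  for i, item in enumerate(["a", "b", "c", "d", "e", "f"]):
    q.put((i, item))
  for _ in range(num_shards):  # one sentinel per worker
    q.put((None, None))
  procs = [Process(target=worker, args=(q, s)) for s in range(num_shards)]
  for p in procs:
    p.start()
  for p in procs:
    p.join()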
@@ -320,6 +338,8 @@ class Kinetics(object):
       ms.set_clip_data_path(bytes23(filepath), metadata)
       assert row["start"].isdigit(), "Invalid row: %s" % str(row)
       assert row["end"].isdigit(), "Invalid row: %s" % str(row)
+      ms.set_clip_start_timestamp(int(float(row["start"]) * SECONDS_TO_MICROSECONDS), metadata)
+      ms.set_clip_end_timestamp(int(float(row["end"]) * SECONDS_TO_MICROSECONDS), metadata)
       if "label_name" in row:
         ms.set_clip_label_string([bytes23(row["label_name"])], metadata)
       if label_map:
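The two added lines store the clip boundaries in the SequenceExample metadata, converting the CSV's start/end fields from seconds to microseconds. A small worked example of the arithmetic, assuming the script's SECONDS_TO_MICROSECONDS = 1000000 constant; note that float() would tolerate fractional seconds, but the isdigit asserts just above still restrict the fields to integer strings:

SECONDS_TO_MICROSECONDS = 1000000  # assumed, defined elsewhere in the script

row = {"start": "4", "end": "14"}  # CSV fields arrive as strings
start_us = int(float(row["start"]) * SECONDS_TO_MICROSECONDS)
end_us = int(float(row["end"]) * SECONDS_TO_MICROSECONDS)
assert (start_us, end_us) == (4000000, 14000000)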
@@ -394,13 +414,13 @@ class Kinetics(object):

   def get_label_map_and_verify_example_counts(self, paths):
     """Verify the number of examples and labels have not changed."""
+    label_map = None
     for name, path in paths.items():
       with open(path, "r") as f:
         lines = f.readlines()
       # the header adds one line and one "key".
       num_examples = len(lines) - 1
       keys = [l.split(",")[0] for l in lines]
-      label_map = None
       if name == "train":
         classes = sorted(list(set(keys[1:])))
         num_keys = len(set(keys)) - 1
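Hoisting label_map above the loop fixes a latent bug: the old in-loop placement reset it to None on every iteration, so the map built for the train split survived only if "train" happened to be the last path processed. A minimal illustration of the old behavior:

label_map = None
for name in ["train", "validate", "test"]:
  label_map = None  # old placement: wipes any previously built map
  if name == "train":
    label_map = {"abseiling": 0}  # stand-in for the real class map
print(label_map)  # None, even though "train" was visited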