Modify the Kinetics dataset reader for custom datasets.
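
This commit parameterizes the number of classes in as_dataset() (defaulting
to NUM_CLASSES), switches from the tensorflow.compat.v1 import to the native
tensorflow namespace, disables the metadata shuffle, parallelizes media
extraction across the shard writers with multiprocessing, records clip
start/end timestamps from the CSV rows, and hoists the label_map
initialization out of the per-file loop.
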
Signed-off-by: Cheoljun Lee <cheoljun.lee@samsung.com>

commit 9d70a4c5f9
parent f42c8fd442

@@ -73,6 +73,7 @@ import subprocess
 import sys
 import tarfile
 import tempfile
+from multiprocessing import Process, Queue

 from absl import app
 from absl import flags

@@ -80,7 +81,7 @@ from absl import logging
 from six.moves import range
 from six.moves import urllib
 from six.moves import zip
-import tensorflow.compat.v1 as tf
+import tensorflow as tf

 from mediapipe.util.sequence import media_sequence as ms

@@ -127,7 +128,7 @@ class Kinetics(object):

   def as_dataset(self, split, shuffle=False, repeat=False,
                  serialized_prefetch_size=32, decoded_prefetch_size=32,
-                 parse_labels=True):
+                 parse_labels=True, num_classes=NUM_CLASSES):
     """Returns Kinetics as a tf.data.Dataset.

     After running this function, calling padded_batch() on the Dataset object

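With the new num_classes parameter, a reader pointed at a custom dataset can
size its one-hot labels at call time instead of inheriting the Kinetics
default. A minimal usage sketch, assuming a Kinetics instance built from a
data directory (the path, split name, and class count below are
hypothetical):

    data = Kinetics("/tmp/kinetics_data")  # hypothetical data path
    dataset = data.as_dataset("custom_train",  # hypothetical split name
                              shuffle=True, repeat=True,
                              num_classes=10)  # overrides NUM_CLASSES default
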
@@ -195,7 +196,7 @@ class Kinetics(object):
           "num_frames": num_frames,
       }
       if parse_labels:
-        target = tf.one_hot(parsed_context[ms.get_clip_label_index_key()], 700)
+        target = tf.one_hot(parsed_context[ms.get_clip_label_index_key()], num_classes)
         output_dict["labels"] = target
       return output_dict

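tf.one_hot expands the scalar label index into a dense vector of length
num_classes, so the hard-coded Kinetics-700 width no longer leaks into
custom datasets. For example:

    import tensorflow as tf
    tf.one_hot(3, 5)  # -> [0., 0., 0., 1., 0.]
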
@@ -265,14 +266,18 @@ class Kinetics(object):
       logging.info("An example of the metadata: ")
       logging.info(all_metadata[0])
       random.seed(47)
-      random.shuffle(all_metadata)
+      #random.shuffle(all_metadata)
       shards = SPLITS[key]["shards"]
       shard_names = [os.path.join(
           self.path_to_data, FILEPATTERN % key + "-%05d-of-%05d" % (
               i, shards)) for i in range(shards)]
       writers = [tf.io.TFRecordWriter(shard_name) for shard_name in shard_names]
       with _close_on_exit(writers) as writers:
-        for i, seq_ex in enumerate(all_metadata):
+        def generate(q, writer):
+          while True:
+            i, seq_ex = q.get()
+            if i is None:
+              break
           if not only_generate_metadata:
             print("Processing example %d of %d (%d%%) \r" % (
                 i, len(all_metadata), i * 100 / len(all_metadata)), end="")

@@ -280,7 +285,20 @@ class Kinetics(object):
             graph_path = os.path.join(path_to_graph_directory, graph)
             seq_ex = self._run_mediapipe(
                 path_to_mediapipe_binary, seq_ex, graph_path)
-          writers[i % len(writers)].write(seq_ex.SerializeToString())
+            writer.write(seq_ex.SerializeToString())
+          writer.close()
+        queue = Queue()
+        workers = []
+        for i, seq_ex in enumerate(all_metadata):
+          queue.put((i, seq_ex))
+        for i in range(shards):
+          queue.put((None, None))
+        for i in range(shards):
+          worker = Process(target=generate, args=(queue, writers[i]))
+          worker.start()
+          workers.append(worker)
+        for worker in workers:
+          worker.join()
     logging.info("Data extraction complete.")

   def _generate_metadata(self, key, download_output,

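The extraction loop above is a standard producer/consumer setup: every
metadata example is enqueued up front, one (None, None) sentinel per worker
signals shutdown, and each of the `shards` worker processes drains the
shared queue and writes to its own TFRecord shard. A stand-alone sketch of
the same pattern with illustrative names (plain text files stand in for the
TFRecord writers):

    from multiprocessing import Process, Queue

    def worker(q, shard_path):
        with open(shard_path, "w") as out:  # stand-in for a TFRecordWriter
            while True:
                i, item = q.get()
                if i is None:               # sentinel: no more work
                    break
                out.write("%d: %s\n" % (i, item))

    if __name__ == "__main__":
        num_workers = 2
        q = Queue()
        for i, item in enumerate(["a", "b", "c", "d"]):
            q.put((i, item))
        for _ in range(num_workers):
            q.put((None, None))             # one sentinel per worker
        procs = [Process(target=worker, args=(q, "/tmp/shard-%d.txt" % w))
                 for w in range(num_workers)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()

Because the single producer enqueues all real items before any sentinel, no
worker can exit while work remains on the queue.
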
@@ -320,6 +338,8 @@ class Kinetics(object):
       ms.set_clip_data_path(bytes23(filepath), metadata)
       assert row["start"].isdigit(), "Invalid row: %s" % str(row)
       assert row["end"].isdigit(), "Invalid row: %s" % str(row)
+      ms.set_clip_start_timestamp(int(float(row["start"]) * SECONDS_TO_MICROSECONDS), metadata)
+      ms.set_clip_end_timestamp(int(float(row["end"]) * SECONDS_TO_MICROSECONDS), metadata)
       if "label_name" in row:
         ms.set_clip_label_string([bytes23(row["label_name"])], metadata)
       if label_map:

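Assuming SECONDS_TO_MICROSECONDS is the usual 1000000 scale factor, parsing
the CSV field with float() before scaling and truncating does the conversion
in one step:

    SECONDS_TO_MICROSECONDS = 1000000
    int(float("64") * SECONDS_TO_MICROSECONDS)  # -> 64000000

The surrounding asserts still require integer-valued strings, but the
float() wrapper would also tolerate fractional start/end times such as
"4.5" if those asserts were ever relaxed.
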
@@ -394,13 +414,13 @@ class Kinetics(object):

   def get_label_map_and_verify_example_counts(self, paths):
     """Verify the number of examples and labels have not changed."""
+    label_map = None
     for name, path in paths.items():
       with open(path, "r") as f:
         lines = f.readlines()
       # the header adds one line and one "key".
       num_examples = len(lines) - 1
       keys = [l.split(",")[0] for l in lines]
-      label_map = None
       if name == "train":
         classes = sorted(list(set(keys[1:])))
         num_keys = len(set(keys)) - 1

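Moving label_map = None out of the per-file loop means the mapping built
while processing the "train" CSV is no longer wiped out by later iterations.
A toy illustration of the scoping difference:

    label_map = None
    for name in ("train", "validate", "test"):
        if name == "train":
            label_map = {"abseiling": 0}  # toy value
    # label_map still holds the train mapping here; with the old placement
    # (reassigning None at the top of each iteration) the "validate" and
    # "test" passes would have reset it.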