Modify the Kinetics dataset reader to support custom datasets.

Thread a num_classes argument through as_dataset() in place of the hard-coded Kinetics-700 label count, write explicit clip start/end timestamps into the metadata, disable the metadata pre-shuffle, parallelize TFRecord shard writing with one worker process per shard, import tensorflow directly rather than through the compat.v1 alias, and initialize label_map once so the map built from the train split survives the per-split loop.
Signed-off-by: Cheoljun Lee <cheoljun.lee@samsung.com>
parent f42c8fd442
commit 9d70a4c5f9
@@ -73,6 +73,7 @@ import subprocess
 import sys
 import tarfile
 import tempfile
+from multiprocessing import Process, Queue
 
 from absl import app
 from absl import flags
@@ -80,7 +81,7 @@ from absl import logging
 from six.moves import range
 from six.moves import urllib
 from six.moves import zip
-import tensorflow.compat.v1 as tf
+import tensorflow as tf
 
 from mediapipe.util.sequence import media_sequence as ms
 
@@ -127,7 +128,7 @@ class Kinetics(object):
 
   def as_dataset(self, split, shuffle=False, repeat=False,
                  serialized_prefetch_size=32, decoded_prefetch_size=32,
-                 parse_labels=True):
+                 parse_labels=True, num_classes=NUM_CLASSES):
     """Returns Kinetics as a tf.data.Dataset.
 
     After running this function, calling padded_batch() on the Dataset object
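With num_classes exposed (defaulting to the NUM_CLASSES constant referenced in the new signature), the one-hot label width can follow a custom taxonomy. A minimal usage sketch, assuming the class is constructed with a data path as elsewhere in this script; the path, split name, and class count below are placeholders, not part of this commit:

    demo = Kinetics("/tmp/kinetics_data")  # hypothetical data root
    dataset = demo.as_dataset("custom_train", shuffle=True, num_classes=10)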
@@ -195,7 +196,7 @@ class Kinetics(object):
           "num_frames": num_frames,
       }
       if parse_labels:
-        target = tf.one_hot(parsed_context[ms.get_clip_label_index_key()], 700)
+        target = tf.one_hot(parsed_context[ms.get_clip_label_index_key()], num_classes)
         output_dict["labels"] = target
       return output_dict
 
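tf.one_hot expands a scalar label index into a dense target vector of the requested width, so the change above simply lets that width track the custom class count. For example:

    import tensorflow as tf
    # A label index of 2 with num_classes=5 becomes [0., 0., 1., 0., 0.].
    print(tf.one_hot(2, 5).numpy())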
@@ -265,14 +266,18 @@ class Kinetics(object):
     logging.info("An example of the metadata: ")
     logging.info(all_metadata[0])
     random.seed(47)
-    random.shuffle(all_metadata)
+    #random.shuffle(all_metadata)
     shards = SPLITS[key]["shards"]
     shard_names = [os.path.join(
         self.path_to_data, FILEPATTERN % key + "-%05d-of-%05d" % (
             i, shards)) for i in range(shards)]
     writers = [tf.io.TFRecordWriter(shard_name) for shard_name in shard_names]
-    with _close_on_exit(writers) as writers:
-      for i, seq_ex in enumerate(all_metadata):
+
+    def generate(q, writer):
+      while True:
+        i, seq_ex = q.get()
+        if i is None:
+          break
         if not only_generate_metadata:
           print("Processing example %d of %d (%d%%) \r" % (
               i, len(all_metadata), i * 100 / len(all_metadata)), end="")
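The generate() worker introduced above is the standard queue-plus-sentinel idiom: each process pulls (index, example) pairs from a shared Queue until it sees a (None, None) sentinel. A self-contained sketch of the same pattern, independent of this script:

    from multiprocessing import Process, Queue

    def worker(q):
        while True:
            item = q.get()
            if item is None:  # sentinel: no more work
                break
            print("processing", item)

    if __name__ == "__main__":
        q = Queue()
        for item in range(8):       # enqueue all work up front
            q.put(item)
        n_workers = 4
        for _ in range(n_workers):  # one sentinel per worker
            q.put(None)
        procs = [Process(target=worker, args=(q,)) for _ in range(n_workers)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()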
@@ -280,7 +285,20 @@ class Kinetics(object):
           graph_path = os.path.join(path_to_graph_directory, graph)
           seq_ex = self._run_mediapipe(
               path_to_mediapipe_binary, seq_ex, graph_path)
-        writers[i % len(writers)].write(seq_ex.SerializeToString())
+        writer.write(seq_ex.SerializeToString())
+      writer.close()
+    queue = Queue()
+    workers = []
+    for i, seq_ex in enumerate(all_metadata):
+      queue.put((i, seq_ex))
+    for i in range(shards):
+      queue.put((None, None))
+    for i in range(shards):
+      worker = Process(target=generate, args=(queue, writers[i]))
+      worker.start()
+      workers.append(worker)
+    for worker in workers:
+      worker.join()
     logging.info("Data extraction complete.")
 
   def _generate_metadata(self, key, download_output,
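Design note on the hunk above: each worker process owns exactly one shard's TFRecordWriter, so serialized examples from different processes never interleave within a file, and the old _close_on_exit context manager is replaced by each worker closing its own writer. A consequence, given that the pre-shuffle is also disabled, is that a shard's contents now depend on which (index, example) pairs its worker happens to pull from the shared queue rather than on the previous fixed i % len(writers) assignment.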
@@ -320,6 +338,8 @@ class Kinetics(object):
       ms.set_clip_data_path(bytes23(filepath), metadata)
       assert row["start"].isdigit(), "Invalid row: %s" % str(row)
       assert row["end"].isdigit(), "Invalid row: %s" % str(row)
+      ms.set_clip_start_timestamp(int(float(row["start"]) * SECONDS_TO_MICROSECONDS), metadata)
+      ms.set_clip_end_timestamp(int(float(row["end"]) * SECONDS_TO_MICROSECONDS), metadata)
       if "label_name" in row:
         ms.set_clip_label_string([bytes23(row["label_name"])], metadata)
         if label_map:
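The two added calls convert the CSV's second-resolution start/end fields into the microsecond timestamps MediaSequence expects, scaling by the script's SECONDS_TO_MICROSECONDS constant (1000000, as its name indicates). A worked example with a hypothetical row:

    SECONDS_TO_MICROSECONDS = 1000000
    row = {"start": "12", "end": "22"}  # placeholder CSV row
    start_us = int(float(row["start"]) * SECONDS_TO_MICROSECONDS)  # 12000000
    end_us = int(float(row["end"]) * SECONDS_TO_MICROSECONDS)      # 22000000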
@@ -394,13 +414,13 @@ class Kinetics(object):
 
   def get_label_map_and_verify_example_counts(self, paths):
     """Verify the number of examples and labels have not changed."""
+    label_map = None
     for name, path in paths.items():
       with open(path, "r") as f:
         lines = f.readlines()
       # the header adds one line and one "key".
       num_examples = len(lines) - 1
       keys = [l.split(",")[0] for l in lines]
-      label_map = None
       if name == "train":
         classes = sorted(list(set(keys[1:])))
         num_keys = len(set(keys)) - 1
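Why the label_map hoist in the last hunk matters: previously label_map was reset to None on every iteration of the per-split loop, so the map built while processing the "train" CSV was clobbered as soon as the next split was read; initializing it once before the loop preserves that map for the rest of the function (and, as the method's name suggests, for its eventual return value).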