diff --git a/mediapipe/examples/desktop/media_sequence/kinetics_dataset.py b/mediapipe/examples/desktop/media_sequence/kinetics_dataset.py
index eafe18f77..4b0bc5449 100644
--- a/mediapipe/examples/desktop/media_sequence/kinetics_dataset.py
+++ b/mediapipe/examples/desktop/media_sequence/kinetics_dataset.py
@@ -73,6 +73,7 @@ import subprocess
 import sys
 import tarfile
 import tempfile
+from multiprocessing import Process, Queue
 
 from absl import app
 from absl import flags
@@ -80,7 +81,7 @@ from absl import logging
 from six.moves import range
 from six.moves import urllib
 from six.moves import zip
-import tensorflow.compat.v1 as tf
+import tensorflow as tf
 
 from mediapipe.util.sequence import media_sequence as ms
 
@@ -127,7 +128,7 @@ class Kinetics(object):
 
   def as_dataset(self, split, shuffle=False, repeat=False,
                  serialized_prefetch_size=32, decoded_prefetch_size=32,
-                 parse_labels=True):
+                 parse_labels=True, num_classes=NUM_CLASSES):
     """Returns Kinetics as a tf.data.Dataset.
 
     After running this function, calling padded_batch() on the Dataset object
@@ -195,7 +196,7 @@ class Kinetics(object):
           "num_frames": num_frames,
       }
       if parse_labels:
-        target = tf.one_hot(parsed_context[ms.get_clip_label_index_key()], 700)
+        target = tf.one_hot(parsed_context[ms.get_clip_label_index_key()], num_classes)
         output_dict["labels"] = target
       return output_dict
 
@@ -265,14 +266,18 @@ class Kinetics(object):
       logging.info("An example of the metadata: ")
       logging.info(all_metadata[0])
       random.seed(47)
-      random.shuffle(all_metadata)
+      #random.shuffle(all_metadata)
       shards = SPLITS[key]["shards"]
       shard_names = [os.path.join(
           self.path_to_data, FILEPATTERN % key + "-%05d-of-%05d" % (
               i, shards)) for i in range(shards)]
       writers = [tf.io.TFRecordWriter(shard_name) for shard_name in shard_names]
-      with _close_on_exit(writers) as writers:
-        for i, seq_ex in enumerate(all_metadata):
+
+      def generate(q, writer):
+        while True:
+          i, seq_ex = q.get()
+          if i is None:
+            break
           if not only_generate_metadata:
             print("Processing example %d of %d (%d%%) \r" % (
                 i, len(all_metadata), i * 100 / len(all_metadata)), end="")
@@ -280,7 +285,20 @@ class Kinetics(object):
               graph_path = os.path.join(path_to_graph_directory, graph)
               seq_ex = self._run_mediapipe(
                   path_to_mediapipe_binary, seq_ex, graph_path)
-          writers[i % len(writers)].write(seq_ex.SerializeToString())
+          writer.write(seq_ex.SerializeToString())
+        writer.close()
+      queue = Queue()
+      workers = []
+      for i, seq_ex in enumerate(all_metadata):
+        queue.put((i, seq_ex))
+      for i in range(shards):
+        queue.put((None, None))
+      for i in range(shards):
+        worker = Process(target=generate, args=(queue, writers[i]))
+        worker.start()
+        workers.append(worker)
+      for worker in workers:
+        worker.join()
     logging.info("Data extraction complete.")
 
   def _generate_metadata(self, key, download_output,
@@ -320,6 +338,8 @@ class Kinetics(object):
           ms.set_clip_data_path(bytes23(filepath), metadata)
         assert row["start"].isdigit(), "Invalid row: %s" % str(row)
         assert row["end"].isdigit(), "Invalid row: %s" % str(row)
+        ms.set_clip_start_timestamp(int(float(row["start"]) * SECONDS_TO_MICROSECONDS), metadata)
+        ms.set_clip_end_timestamp(int(float(row["end"]) * SECONDS_TO_MICROSECONDS), metadata)
         if "label_name" in row:
           ms.set_clip_label_string([bytes23(row["label_name"])], metadata)
           if label_map:
@@ -394,13 +414,13 @@ class Kinetics(object):
 
   def get_label_map_and_verify_example_counts(self, paths):
     """Verify the number of examples and labels have not changed."""
+    label_map = None
     for name, path in paths.items():
       with open(path, "r") as f:
         lines = f.readlines()
       # the header adds one line and one "key".
       num_examples = len(lines) - 1
       keys = [l.split(",")[0] for l in lines]
-      label_map = None
       if name == "train":
         classes = sorted(list(set(keys[1:])))
         num_keys = len(set(keys)) - 1