319 lines
13 KiB
Python
319 lines
13 KiB
Python
# Copyright 2019 The MediaPipe Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
r"""A demo data set constructed with MediaSequence and MediaPipe.
|
|
|
|
This code demonstrates the steps for constructing a data set with MediaSequence.
|
|
This code has two functions. First, it can be run as a module to download and
|
|
prepare a toy dataset. Second, it can be imported and used to provide a
|
|
tf.data.Dataset reading that data from disk via as_dataset().
|
|
|
|
Running as a module prepares the data in three stages via generate_examples().
|
|
First, the actual data files are downloaded. If the download is disrupted, the
|
|
incomplete files will need to be removed before running the script again.
|
|
Second, the annotations are parsed and reformatted into metadata as described in
|
|
the MediaSequence documentation. Third, MediaPipe is run to extract subsequences
|
|
of frames for subsequent training via _run_mediapipe().
|
|
|
|
The toy data set is classifying a clip as a panning shot of galaxy or nebula
|
|
from videos released under the [Creative Commons Attribution 4.0 International
|
|
license](http://creativecommons.org/licenses/by/4.0/) on the ESA/Hubble site.
|
|
(The use of these ESA/Hubble materials does not imply the endorsement by
|
|
ESA/Hubble or any ESA/Hubble employee of a commercial product or service.) Each
|
|
video is split into 5 or 6 ten-second clips with a label of "galaxy" or "nebula"
|
|
and downsampled to 10 frames per second. (The last clip for each test example is
|
|
only 6 seconds.) There is one video of each class in each of the training and
|
|
testing splits.
|
|
|
|
Reading the data as a tf.data.Dataset is accomplished with the following lines:
|
|
|
|
demo = DemoDataset("demo_data_path")
|
|
dataset = demo.as_dataset("test")
|
|
# implement additional processing and batching here
|
|
images_and_labels = dataset.make_one_shot_iterator().get_next()
|
|
images = images_and_labels["images"]
|
|
  labels = images_and_labels["labels"]
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import contextlib
|
|
import csv
|
|
import os
|
|
import random
|
|
import subprocess
|
|
import sys
|
|
import tempfile
|
|
|
|
from absl import app
|
|
from absl import flags
|
|
from absl import logging
|
|
from six.moves import range
|
|
from six.moves import urllib
|
|
import tensorflow.compat.v1 as tf
|
|
|
|
from mediapipe.util.sequence import media_sequence as ms
|
|
|
|
SPLITS = {
|
|
"train":
|
|
"""url,label index,label string,duration,credits
|
|
https://cdn.spacetelescope.org/archives/videos/medium_podcast/heic1608c.mp4,0,nebula,50,"ESA/Hubble; Music: Johan B. Monell"
|
|
https://cdn.spacetelescope.org/archives/videos/medium_podcast/heic1712b.mp4,1,galaxy,50,"ESA/Hubble, Digitized Sky Survey, Nick Risinger (skysurvey.org) Music: Johan B Monell"
|
|
""",
|
|
"test":
|
|
"""url,label index,label string,duration,credits
|
|
https://cdn.spacetelescope.org/archives/videos/medium_podcast/heic1301b.m4v,0,nebula,56,"NASA, ESA. Acknowledgement: Josh Lake"
|
|
https://cdn.spacetelescope.org/archives/videos/medium_podcast/heic1305b.m4v,1,galaxy,56,"NASA, ESA, Digitized Sky Survey 2. Acknowledgement: A. van der Hoeven"
|
|
"""
|
|
}
|
|
NUM_CLASSES = 2
|
|
NUM_SHARDS = 2
|
|
SECONDS_PER_EXAMPLE = 10
|
|
MICROSECONDS_PER_SECOND = 1000000
|
|
TF_RECORD_PATTERN = "demo_space_dataset_%s_tfrecord"
|
|
GRAPHS = ["clipped_images_from_file_at_24fps.pbtxt"]
|
|
|
|
|
|
class DemoDataset(object):
  """Generates and loads a demo data set.

  generate_examples() downloads the source videos, builds MediaSequence
  metadata for fixed-length clips, runs MediaPipe to extract frames, and
  writes sharded TFRecords under `path_to_data`. as_dataset() reads those
  shards back as a tf.data.Dataset.
  """

  def __init__(self, path_to_data):
    """Initializes the dataset helper.

    Args:
      path_to_data: directory where TFRecords are (or will be) stored.

    Raises:
      ValueError: if path_to_data is empty.
    """
    if not path_to_data:
      raise ValueError("You must supply the path to the data directory.")
    self.path_to_data = path_to_data

  def as_dataset(self,
                 split,
                 shuffle=False,
                 repeat=False,
                 serialized_prefetch_size=32,
                 decoded_prefetch_size=32):
    """Returns the dataset as a tf.data.Dataset.

    Args:
      split: either "train" or "test"
      shuffle: if true, shuffles both files and examples.
      repeat: if true, repeats the data set forever.
      serialized_prefetch_size: the buffer size for reading from disk.
      decoded_prefetch_size: the buffer size after decoding.

    Returns:
      A tf.data.Dataset object with the following structure: {
        "images": uint8 tensor, shape [time, height, width, channels]
        "labels": one hot encoded label tensor, shape [2]
        "id": a unique string id for each example, shape []
      }

    Raises:
      ValueError: if the split is unknown or no shard files exist yet.
    """

    def parse_fn(sequence_example):
      """Parses a clip classification example."""
      context_features = {
          ms.get_example_id_key():
              ms.get_example_id_default_parser(),
          ms.get_clip_label_index_key():
              ms.get_clip_label_index_default_parser(),
          ms.get_clip_label_string_key():
              ms.get_clip_label_string_default_parser()
      }
      sequence_features = {
          ms.get_image_encoded_key(): ms.get_image_encoded_default_parser(),
      }
      parsed_context, parsed_sequence = tf.io.parse_single_sequence_example(
          sequence_example, context_features, sequence_features)
      example_id = parsed_context[ms.get_example_id_key()]
      # The label index is parsed as a sparse tensor; densify it before
      # one-hot encoding into NUM_CLASSES classes.
      classification_target = tf.one_hot(
          tf.sparse_tensor_to_dense(
              parsed_context[ms.get_clip_label_index_key()]), NUM_CLASSES)
      # Decode each JPEG-encoded frame independently; back_prop is disabled
      # because decoding is not differentiable.
      images = tf.map_fn(
          tf.image.decode_jpeg,
          parsed_sequence[ms.get_image_encoded_key()],
          back_prop=False,
          dtype=tf.uint8)
      return {
          "id": example_id,
          "labels": classification_target,
          "images": images,
      }

    if split not in SPLITS:
      raise ValueError("split '%s' is unknown." % split)
    all_shards = tf.io.gfile.glob(
        os.path.join(self.path_to_data, TF_RECORD_PATTERN % split + "-*-of-*"))
    # Fail early with a clear message instead of letting parallel_interleave
    # choke on a cycle_length of 0 when the data has not been generated.
    if not all_shards:
      raise ValueError(
          "No TFRecord shards found for split '%s' under '%s'. Run "
          "generate_examples() first." % (split, self.path_to_data))
    if shuffle:
      random.shuffle(all_shards)
    all_shards_dataset = tf.data.Dataset.from_tensor_slices(all_shards)
    cycle_length = min(16, len(all_shards))
    dataset = all_shards_dataset.apply(
        tf.data.experimental.parallel_interleave(
            tf.data.TFRecordDataset,
            cycle_length=cycle_length,
            block_length=1,
            sloppy=True,
            buffer_output_elements=serialized_prefetch_size))
    dataset = dataset.prefetch(serialized_prefetch_size)
    if shuffle:
      dataset = dataset.shuffle(serialized_prefetch_size)
    if repeat:
      dataset = dataset.repeat()
    dataset = dataset.map(parse_fn)
    dataset = dataset.prefetch(decoded_prefetch_size)
    return dataset

  def generate_examples(self, path_to_mediapipe_binary,
                        path_to_graph_directory):
    """Downloads data and generates sharded TFRecords.

    Downloads the data files, generates metadata, and processes the metadata
    with MediaPipe to produce tf.SequenceExamples for training. The resulting
    files can be read with as_dataset(). After running this function the
    original data files can be deleted.

    Args:
      path_to_mediapipe_binary: Path to the compiled binary for the BUILD
        target mediapipe/examples/desktop/demo:media_sequence_demo.
      path_to_graph_directory: Path to the directory with MediaPipe graphs in
        mediapipe/graphs/media_sequence/.

    Raises:
      ValueError: if either path is empty.
    """
    if not path_to_mediapipe_binary:
      raise ValueError("You must supply the path to the MediaPipe binary for "
                       "mediapipe/examples/desktop/demo:media_sequence_demo.")
    if not path_to_graph_directory:
      raise ValueError(
          "You must supply the path to the directory with MediaPipe graphs in "
          "mediapipe/graphs/media_sequence/.")
    logging.info("Downloading data.")
    tf.io.gfile.makedirs(self.path_to_data)
    # six.moves.urllib already resolves to the correct implementation on both
    # Python 2 and 3, so no version check is needed here.
    urlretrieve = urllib.request.urlretrieve
    for split in SPLITS:
      reader = csv.DictReader(SPLITS[split].split("\n"))
      all_metadata = []
      for row in reader:
        url = row["url"]
        basename = url.split("/")[-1]
        local_path = os.path.join(self.path_to_data, basename)
        # Skip files that already exist; a partially downloaded file must be
        # deleted manually before rerunning.
        if not tf.io.gfile.exists(local_path):
          urlretrieve(url, local_path)

        # Emit one metadata SequenceExample per SECONDS_PER_EXAMPLE-second
        # clip; MediaPipe later fills in the frames for each clip.
        for start_time in range(0, int(row["duration"]), SECONDS_PER_EXAMPLE):
          metadata = tf.train.SequenceExample()
          ms.set_example_id(bytes23(basename + "_" + str(start_time)),
                            metadata)
          ms.set_clip_data_path(bytes23(local_path), metadata)
          ms.set_clip_start_timestamp(start_time * MICROSECONDS_PER_SECOND,
                                      metadata)
          ms.set_clip_end_timestamp(
              (start_time + SECONDS_PER_EXAMPLE) * MICROSECONDS_PER_SECOND,
              metadata)
          ms.set_clip_label_index((int(row["label index"]),), metadata)
          ms.set_clip_label_string((bytes23(row["label string"]),),
                                   metadata)
          all_metadata.append(metadata)
      # Fixed seed so shard assignment is reproducible across runs.
      random.seed(47)
      random.shuffle(all_metadata)
      shard_names = [self._indexed_shard(split, i) for i in range(NUM_SHARDS)]
      writers = [tf.io.TFRecordWriter(shard_name) for shard_name in shard_names]
      with _close_on_exit(writers) as writers:
        for i, seq_ex in enumerate(all_metadata):
          for graph in GRAPHS:
            graph_path = os.path.join(path_to_graph_directory, graph)
            seq_ex = self._run_mediapipe(path_to_mediapipe_binary, seq_ex,
                                         graph_path)
          # Round-robin examples across shards.
          writers[i % len(writers)].write(seq_ex.SerializeToString())

  def _indexed_shard(self, split, index):
    """Constructs a sharded filename."""
    return os.path.join(
        self.path_to_data,
        TF_RECORD_PATTERN % split + "-%05d-of-%05d" % (index, NUM_SHARDS))

  def _run_mediapipe(self, path_to_mediapipe_binary, sequence_example, graph):
    """Runs MediaPipe over MediaSequence tf.train.SequenceExamples.

    Args:
      path_to_mediapipe_binary: Path to the compiled binary for the BUILD
        target mediapipe/examples/desktop/demo:media_sequence_demo.
      sequence_example: The SequenceExample with metadata or partial data file.
      graph: The path to the graph that extracts data to add to the
        SequenceExample.

    Returns:
      A copy of the input SequenceExample with additional data fields added
      by the MediaPipe graph.
    Raises:
      RuntimeError: if MediaPipe returns an error or fails to run the graph.
    """
    if not path_to_mediapipe_binary:
      raise ValueError("--path_to_mediapipe_binary must be specified.")
    input_fd, input_filename = tempfile.mkstemp()
    output_fd, output_filename = tempfile.mkstemp()
    # The binary communicates via files: the serialized input example is
    # written to one temp file and the enriched example read from another.
    try:
      cmd = [
          path_to_mediapipe_binary,
          "--calculator_graph_config_file=%s" % graph,
          "--input_side_packets=input_sequence_example=%s" % input_filename,
          "--output_side_packets=output_sequence_example=%s" % output_filename
      ]
      with open(input_filename, "wb") as input_file:
        input_file.write(sequence_example.SerializeToString())
      mediapipe_output = subprocess.check_output(cmd)
      if b"Failed to run the graph" in mediapipe_output:
        raise RuntimeError(mediapipe_output)
      with open(output_filename, "rb") as output_file:
        output_example = tf.train.SequenceExample()
        output_example.ParseFromString(output_file.read())
    finally:
      # Always clean up the temp files, even if the subprocess or parsing
      # raised; mkstemp leaves both the fd and the file for us to release.
      os.close(input_fd)
      os.remove(input_filename)
      os.close(output_fd)
      os.remove(output_filename)
    return output_example
|
|
|
|
|
|
def bytes23(string):
  """Converts a text string to a bytes object on either Python 2 or 3."""
  return bytes(string, "utf8") if sys.version_info >= (3, 0) else bytes(string)
|
|
|
|
|
|
@contextlib.contextmanager
|
|
def _close_on_exit(writers):
|
|
"""Call close on all writers on exit."""
|
|
try:
|
|
yield writers
|
|
finally:
|
|
for writer in writers:
|
|
writer.close()
|
|
|
|
|
|
def main(argv):
  """Builds the demo data set from the paths given as command-line flags."""
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  dataset = DemoDataset(flags.FLAGS.path_to_demo_data)
  dataset.generate_examples(flags.FLAGS.path_to_mediapipe_binary,
                            flags.FLAGS.path_to_graph_directory)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
flags.DEFINE_string("path_to_demo_data", "",
|
|
"Path to directory to write data to.")
|
|
flags.DEFINE_string("path_to_mediapipe_binary", "",
|
|
"Path to the MediaPipe run_graph_file_io_main binary.")
|
|
flags.DEFINE_string("path_to_graph_directory", "",
|
|
"Path to directory containing the graph files.")
|
|
app.run(main)
|