mediapipe/mediapipe2/examples/desktop/media_sequence/demo_dataset.py
2021-06-10 23:01:19 +00:00

319 lines
13 KiB
Python

# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""A demo data set constructed with MediaSequence and MediaPipe.
This code demonstrates the steps for constructing a data set with MediaSequence.
This code has two functions. First, it can be run as a module to download and
prepare a toy dataset. Second, it can be imported and used to provide a
tf.data.Dataset reading that data from disk via as_dataset().
Running as a module prepares the data in three stages via generate_examples().
First, the actual data files are downloaded. If the download is disrupted, the
incomplete files will need to be removed before running the script again.
Second, the annotations are parsed and reformatted into metadata as described in
the MediaSequence documentation. Third, MediaPipe is run to extract subsequences
of frames for subsequent training via _run_mediapipe().
The toy task is to classify each clip as a panning shot of a galaxy or a nebula.
The clips come from videos released under the [Creative Commons Attribution 4.0 International
license](http://creativecommons.org/licenses/by/4.0/) on the ESA/Hubble site.
(The use of these ESA/Hubble materials does not imply the endorsement by
ESA/Hubble or any ESA/Hubble employee of a commercial product or service.) Each
video is split into 5 or 6 ten-second clips with a label of "galaxy" or "nebula"
and downsampled to 10 frames per second. (The last clip for each test example is
only 6 seconds.) There is one video of each class in each of the training and
testing splits.
Reading the data as a tf.data.Dataset is accomplished with the following lines:
demo = DemoDataset("demo_data_path")
dataset = demo.as_dataset("test")
# implement additional processing and batching here
images_and_labels = dataset.make_one_shot_iterator().get_next()
images = images_and_labels["images"]
labels = images_and_labels["labels"]
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import csv
import os
import random
import subprocess
import sys
import tempfile
from absl import app
from absl import flags
from absl import logging
from six.moves import range
from six.moves import urllib
import tensorflow.compat.v1 as tf
from mediapipe.util.sequence import media_sequence as ms
# Per-split CSV manifests. Columns: source video URL, integer label index,
# label string, video duration in seconds, and credits. generate_examples()
# downloads every listed video and cuts it into SECONDS_PER_EXAMPLE clips.
SPLITS = {
"train":
"""url,label index,label string,duration,credits
https://cdn.spacetelescope.org/archives/videos/medium_podcast/heic1608c.mp4,0,nebula,50,"ESA/Hubble; Music: Johan B. Monell"
https://cdn.spacetelescope.org/archives/videos/medium_podcast/heic1712b.mp4,1,galaxy,50,"ESA/Hubble, Digitized Sky Survey, Nick Risinger (skysurvey.org) Music: Johan B Monell"
""",
"test":
"""url,label index,label string,duration,credits
https://cdn.spacetelescope.org/archives/videos/medium_podcast/heic1301b.m4v,0,nebula,56,"NASA, ESA. Acknowledgement: Josh Lake"
https://cdn.spacetelescope.org/archives/videos/medium_podcast/heic1305b.m4v,1,galaxy,56,"NASA, ESA, Digitized Sky Survey 2. Acknowledgement: A. van der Hoeven"
"""
}
# Number of label classes; used as the one-hot depth when parsing examples.
NUM_CLASSES = 2
# Number of TFRecord shard files written per split.
NUM_SHARDS = 2
# Length of each clip cut from a source video.
SECONDS_PER_EXAMPLE = 10
# Clip timestamps are stored in microseconds.
MICROSECONDS_PER_SECOND = 1000000
# Shard filename prefix; "%s" is replaced by the split name.
TF_RECORD_PATTERN = "demo_space_dataset_%s_tfrecord"
# MediaPipe graph files (relative to the graph directory) run, in order, on
# every clip's metadata.
# NOTE(review): the graph name says 24fps while the module docstring says the
# clips are downsampled to 10 fps — confirm which is accurate.
GRAPHS = ["clipped_images_from_file_at_24fps.pbtxt"]
class DemoDataset(object):
  """Generates and loads a demo data set.

  generate_examples() downloads the source videos into `path_to_data` and
  writes sharded TFRecord files of tf.train.SequenceExample protos there;
  as_dataset() reads those shards back as a tf.data.Dataset.
  """

  def __init__(self, path_to_data):
    """Creates a handle to the data set stored under `path_to_data`.

    Args:
      path_to_data: directory that holds (or will hold) the downloaded videos
        and the generated TFRecord shards.

    Raises:
      ValueError: if `path_to_data` is empty.
    """
    if not path_to_data:
      raise ValueError("You must supply the path to the data directory.")
    self.path_to_data = path_to_data

  def as_dataset(self,
                 split,
                 shuffle=False,
                 repeat=False,
                 serialized_prefetch_size=32,
                 decoded_prefetch_size=32):
    """Returns the dataset as a tf.data.Dataset.

    Args:
      split: either "train" or "test"
      shuffle: if true, shuffles both files and examples.
      repeat: if true, repeats the data set forever.
      serialized_prefetch_size: the buffer size for reading from disk.
      decoded_prefetch_size: the buffer size after decoding.

    Returns:
      A tf.data.Dataset object with the following structure: {
        "images": uint8 tensor, shape [time, height, width, channels]
        "labels": one hot encoded label tensor, shape
          [num_labels, NUM_CLASSES]. The label index is parsed as a sparse
          feature, so there is a leading label axis; every clip generated
          here has exactly one label, giving shape [1, 2].
        "id": a unique string id for each example, shape []
      }

    Raises:
      ValueError: if `split` is unknown, or if no TFRecord shards for it exist
        under `path_to_data` (e.g. generate_examples() has not been run).
    """
    def parse_fn(sequence_example):
      """Parses a clip classification example."""
      context_features = {
          ms.get_example_id_key():
              ms.get_example_id_default_parser(),
          ms.get_clip_label_index_key():
              ms.get_clip_label_index_default_parser(),
          ms.get_clip_label_string_key():
              ms.get_clip_label_string_default_parser()
      }
      sequence_features = {
          ms.get_image_encoded_key(): ms.get_image_encoded_default_parser(),
      }
      parsed_context, parsed_sequence = tf.io.parse_single_sequence_example(
          sequence_example, context_features, sequence_features)
      example_id = parsed_context[ms.get_example_id_key()]
      classification_target = tf.one_hot(
          tf.sparse_tensor_to_dense(
              parsed_context[ms.get_clip_label_index_key()]), NUM_CLASSES)
      # Decode each JPEG-encoded frame; no gradients are needed for decoding.
      images = tf.map_fn(
          tf.image.decode_jpeg,
          parsed_sequence[ms.get_image_encoded_key()],
          back_prop=False,
          dtype=tf.uint8)
      return {
          "id": example_id,
          "labels": classification_target,
          "images": images,
      }

    if split not in SPLITS:
      raise ValueError("split '%s' is unknown." % split)
    all_shards = tf.io.gfile.glob(
        os.path.join(self.path_to_data, TF_RECORD_PATTERN % split + "-*-of-*"))
    if not all_shards:
      # Fail early with a clear message: with no shards, cycle_length below
      # would be 0 and the pipeline would fail later with a less clear error.
      raise ValueError(
          "No TFRecord shards found for split '%s' in '%s'; run "
          "generate_examples() first." % (split, self.path_to_data))
    if shuffle:
      random.shuffle(all_shards)
    all_shards_dataset = tf.data.Dataset.from_tensor_slices(all_shards)
    # Read up to 16 shards concurrently, but never more readers than shards.
    cycle_length = min(16, len(all_shards))
    dataset = all_shards_dataset.apply(
        tf.data.experimental.parallel_interleave(
            tf.data.TFRecordDataset,
            cycle_length=cycle_length,
            block_length=1,
            # Trades deterministic element order for throughput.
            sloppy=True,
            buffer_output_elements=serialized_prefetch_size))
    dataset = dataset.prefetch(serialized_prefetch_size)
    if shuffle:
      dataset = dataset.shuffle(serialized_prefetch_size)
    if repeat:
      dataset = dataset.repeat()
    dataset = dataset.map(parse_fn)
    dataset = dataset.prefetch(decoded_prefetch_size)
    return dataset

  def generate_examples(self, path_to_mediapipe_binary,
                        path_to_graph_directory):
    """Downloads data and generates sharded TFRecords.

    Downloads the data files, generates metadata, and processes the metadata
    with MediaPipe to produce tf.SequenceExamples for training. The resulting
    files can be read with as_dataset(). After running this function the
    original data files can be deleted.

    Args:
      path_to_mediapipe_binary: Path to the compiled binary for the BUILD target
        mediapipe/examples/desktop/demo:media_sequence_demo.
        NOTE(review): this file lives under
        mediapipe/examples/desktop/media_sequence/, so the target may actually
        be .../media_sequence:media_sequence_demo — confirm against BUILD.
      path_to_graph_directory: Path to the directory with MediaPipe graphs in
        mediapipe/graphs/media_sequence/.

    Raises:
      ValueError: if either path argument is empty.
    """
    if not path_to_mediapipe_binary:
      raise ValueError("You must supply the path to the MediaPipe binary for "
                       "mediapipe/examples/desktop/demo:media_sequence_demo.")
    if not path_to_graph_directory:
      raise ValueError(
          "You must supply the path to the directory with MediaPipe graphs in "
          "mediapipe/graphs/media_sequence/.")
    logging.info("Downloading data.")
    tf.io.gfile.makedirs(self.path_to_data)
    for split in SPLITS:
      all_metadata = self._download_and_build_metadata(split)
      # A private fixed-seed generator yields the same permutation as
      # random.seed(47) + random.shuffle, so shard contents are reproducible,
      # but it does not reseed the process-wide random module.
      random.Random(47).shuffle(all_metadata)
      shard_names = [self._indexed_shard(split, i) for i in range(NUM_SHARDS)]
      writers = [tf.io.TFRecordWriter(shard_name) for shard_name in shard_names]
      with _close_on_exit(writers) as writers:
        for i, seq_ex in enumerate(all_metadata):
          # Each graph adds data to the output of the previous one.
          for graph in GRAPHS:
            graph_path = os.path.join(path_to_graph_directory, graph)
            seq_ex = self._run_mediapipe(path_to_mediapipe_binary, seq_ex,
                                         graph_path)
          # Distribute examples round-robin across the shard files.
          writers[i % len(writers)].write(seq_ex.SerializeToString())

  def _download_and_build_metadata(self, split):
    """Downloads a split's videos and builds one metadata example per clip.

    Videos already present in `path_to_data` are not downloaded again.

    Args:
      split: a key of SPLITS.

    Returns:
      A list of tf.train.SequenceExample protos, one per SECONDS_PER_EXAMPLE
      clip of each video. Each holds the example id, the local video path, the
      clip start/end timestamps in microseconds, and the clip label index and
      string.
    """
    all_metadata = []
    # The manifest ends with a newline; csv.DictReader skips the resulting
    # empty row.
    for row in csv.DictReader(SPLITS[split].split("\n")):
      url = row["url"]
      basename = url.split("/")[-1]
      local_path = os.path.join(self.path_to_data, basename)
      if not tf.io.gfile.exists(local_path):
        urllib.request.urlretrieve(url, local_path)
      for start_time in range(0, int(row["duration"]), SECONDS_PER_EXAMPLE):
        metadata = tf.train.SequenceExample()
        ms.set_example_id(bytes23(basename + "_" + str(start_time)),
                          metadata)
        ms.set_clip_data_path(bytes23(local_path), metadata)
        ms.set_clip_start_timestamp(start_time * MICROSECONDS_PER_SECOND,
                                    metadata)
        # For the final clip the end can lie past the end of the video.
        ms.set_clip_end_timestamp(
            (start_time + SECONDS_PER_EXAMPLE) * MICROSECONDS_PER_SECOND,
            metadata)
        ms.set_clip_label_index((int(row["label index"]),), metadata)
        ms.set_clip_label_string((bytes23(row["label string"]),),
                                 metadata)
        all_metadata.append(metadata)
    return all_metadata

  def _indexed_shard(self, split, index):
    """Constructs a sharded filename.

    For example, split "train" and index 0 give
    <path_to_data>/demo_space_dataset_train_tfrecord-00000-of-00002 when
    NUM_SHARDS is 2.
    """
    return os.path.join(
        self.path_to_data,
        TF_RECORD_PATTERN % split + "-%05d-of-%05d" % (index, NUM_SHARDS))

  def _run_mediapipe(self, path_to_mediapipe_binary, sequence_example, graph):
    """Runs MediaPipe over MediaSequence tf.train.SequenceExamples.

    Args:
      path_to_mediapipe_binary: Path to the compiled binary for the BUILD target
        mediapipe/examples/desktop/demo:media_sequence_demo.
      sequence_example: The SequenceExample with metadata or partial data file.
      graph: The path to the graph that extracts data to add to the
        SequenceExample.

    Returns:
      A copy of the input SequenceExample with additional data fields added
      by the MediaPipe graph.

    Raises:
      ValueError: if `path_to_mediapipe_binary` is empty.
      RuntimeError: if MediaPipe's stdout reports that the graph failed to run.
      subprocess.CalledProcessError: if the binary exits with a non-zero status.
      OSError: if the binary cannot be executed.
    """
    if not path_to_mediapipe_binary:
      raise ValueError("--path_to_mediapipe_binary must be specified.")
    input_fd, input_filename = tempfile.mkstemp()
    output_fd, output_filename = tempfile.mkstemp()
    # Both files are reopened by name (here and by the subprocess), so the raw
    # descriptors returned by mkstemp are not needed.
    os.close(input_fd)
    os.close(output_fd)
    try:
      cmd = [
          path_to_mediapipe_binary,
          "--calculator_graph_config_file=%s" % graph,
          "--input_side_packets=input_sequence_example=%s" % input_filename,
          "--output_side_packets=output_sequence_example=%s" % output_filename
      ]
      with open(input_filename, "wb") as input_file:
        input_file.write(sequence_example.SerializeToString())
      # NOTE(review): only stdout is captured here; if MediaPipe logs its
      # failure message to stderr, this check will not see it.
      mediapipe_output = subprocess.check_output(cmd)
      if b"Failed to run the graph" in mediapipe_output:
        raise RuntimeError(mediapipe_output)
      with open(output_filename, "rb") as output_file:
        output_example = tf.train.SequenceExample()
        output_example.ParseFromString(output_file.read())
      return output_example
    finally:
      # Always remove the temporary files, even when MediaPipe fails, so
      # repeated failures do not accumulate files in the temp directory.
      os.remove(input_filename)
      os.remove(output_filename)
def bytes23(string):
  """Converts `string` to a bytes object on both Python 2 and Python 3.

  Args:
    string: the text to convert.

  Returns:
    On Python 3, `string` encoded as UTF-8; on Python 2, `bytes(string)`.
  """
  if sys.version_info < (3, 0):
    # Python 2's `bytes` is `str` and takes no encoding argument.
    return bytes(string)
  return bytes(string, "utf8")
@contextlib.contextmanager
def _close_on_exit(writers):
"""Call close on all writers on exit."""
try:
yield writers
finally:
for writer in writers:
writer.close()
def main(argv):
  """Command-line entry point that generates the demo data set.

  Reads the --path_to_demo_data, --path_to_mediapipe_binary and
  --path_to_graph_directory flags and passes them to
  DemoDataset.generate_examples().

  Args:
    argv: the command-line arguments passed by app.run(); anything beyond the
      program name is rejected.

  Raises:
    app.UsageError: if extra positional arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  parsed_flags = flags.FLAGS
  demo = DemoDataset(parsed_flags.path_to_demo_data)
  demo.generate_examples(parsed_flags.path_to_mediapipe_binary,
                         parsed_flags.path_to_graph_directory)
if __name__ == "__main__":
  # The flags are defined only when this file runs as a script, so importing
  # the module (e.g. to call DemoDataset.as_dataset) does not register them.
  flags.DEFINE_string(
      name="path_to_demo_data",
      default="",
      help="Path to directory to write data to.")
  flags.DEFINE_string(
      name="path_to_mediapipe_binary",
      default="",
      help="Path to the MediaPipe run_graph_file_io_main binary.")
  flags.DEFINE_string(
      name="path_to_graph_directory",
      default="",
      help="Path to directory containing the graph files.")
  app.run(main)