263 lines
8.7 KiB
Python
263 lines
8.7 KiB
Python
"""Server for YouTube8M Model Inference Demo.
|
|
|
|
Serves up both the static files for the website and provides a service that
|
|
fetches the video id and timestamp based labels for a video analyzed in a
|
|
tfrecord files.
|
|
|
|
"""
|
|
from __future__ import print_function
|
|
import json
|
|
import os
|
|
import re
|
|
import socket
|
|
import subprocess
|
|
import sys
|
|
|
|
from absl import app
|
|
from absl import flags
|
|
import http.client
|
|
import http.server
|
|
from six.moves.urllib import parse
|
|
|
|
FLAGS = flags.FLAGS
|
|
flags.DEFINE_bool("show_label_at_center", False,
|
|
"Show labels at the center of the segment.")
|
|
flags.DEFINE_integer("port", 8008, "Port that the API is served over.")
|
|
flags.DEFINE_string("tmp_dir", "/tmp/mediapipe",
|
|
"Temporary asset storage location.")
|
|
flags.DEFINE_string("root", "", "MediaPipe root directory.")
|
|
# binary, pbtxt, label_map paths are relative to 'root' path
|
|
flags.DEFINE_string(
|
|
"binary",
|
|
"bazel-bin/mediapipe/examples/desktop/youtube8m/model_inference",
|
|
"Inference binary location.")
|
|
flags.DEFINE_string(
|
|
"pbtxt",
|
|
"mediapipe/graphs/youtube8m/yt8m_dataset_model_inference.pbtxt",
|
|
"Default pbtxt graph file.")
|
|
flags.DEFINE_string("label_map", "mediapipe/graphs/youtube8m/label_map.txt",
|
|
"Default label map text file.")
|
|
|
|
|
|
class HTTPServerV6(http.server.HTTPServer):
|
|
address_family = socket.AF_INET6
|
|
|
|
|
|
class Youtube8MRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|
"""Static file server with /healthz support."""
|
|
|
|
def do_GET(self):
|
|
if self.path.startswith("/healthz"):
|
|
self.send_response(200)
|
|
self.send_header("Content-type", "text/plain")
|
|
self.send_header("Content-length", 2)
|
|
self.end_headers()
|
|
self.wfile.write("ok")
|
|
if self.path.startswith("/video"):
|
|
parsed_params = parse.urlparse(self.path)
|
|
url_params = parse.parse_qs(parsed_params.query)
|
|
|
|
tfrecord_path = ""
|
|
segment_size = 5
|
|
|
|
print(url_params)
|
|
if "file" in url_params:
|
|
tfrecord_path = url_params["file"][0]
|
|
if "segments" in url_params:
|
|
segment_size = int(url_params["segments"][0])
|
|
|
|
self.fetch(tfrecord_path, segment_size)
|
|
|
|
else:
|
|
if self.path == "/":
|
|
self.path = "/index.html"
|
|
# Default to serve up a local file
|
|
self.path = "/static" + self.path
|
|
http.server.SimpleHTTPRequestHandler.do_GET(self)
|
|
|
|
def report_error(self, msg):
|
|
"""Simplifies sending out a string as a 500 http response."""
|
|
self.send_response(500)
|
|
self.send_header("Content-type", "text/plain")
|
|
self.end_headers()
|
|
if sys.version_info[0] < 3:
|
|
self.wfile.write(str(msg).encode("utf-8"))
|
|
else:
|
|
self.wfile.write(bytes(msg, "utf-8"))
|
|
|
|
def report_missing_files(self, files):
|
|
"""Sends out 500 response with missing files."""
|
|
accumulate = ""
|
|
for file_path in files:
|
|
if not os.path.exists(file_path):
|
|
accumulate = "%s '%s'" % (accumulate, file_path)
|
|
|
|
if accumulate:
|
|
self.report_error("Could not find:%s" % accumulate)
|
|
return True
|
|
|
|
return False
|
|
|
|
def fetch(self, path, segment_size):
|
|
"""Returns the video id and labels for a tfrecord at a provided index."""
|
|
|
|
print("Received request. File=", path, "Segment Size =", segment_size)
|
|
|
|
if (self.report_missing_files([
|
|
"%s/%s" % (FLAGS.root, FLAGS.pbtxt),
|
|
"%s/%s" % (FLAGS.root, FLAGS.binary),
|
|
"%s/%s" % (FLAGS.root, FLAGS.label_map)
|
|
])):
|
|
return
|
|
|
|
# Parse the youtube video id off the end of the link or as a standalone id.
|
|
filename_match = re.match(
|
|
"(?:.*youtube.*v=)?([a-zA-Z-0-9_]{2})([a-zA-Z-0-9_]+)", path)
|
|
tfrecord_url = filename_match.expand(r"data.yt8m.org/2/j/r/\1/\1\2.js")
|
|
|
|
print("Trying to get tfrecord via", tfrecord_url)
|
|
|
|
connection = http.client.HTTPConnection("data.yt8m.org")
|
|
connection.request("GET", tfrecord_url)
|
|
response = connection.getresponse()
|
|
|
|
response_object = json.loads(response.read())
|
|
filename = response_object["filename_raw"]
|
|
index = response_object["index"]
|
|
|
|
print("TFRecord discovered: ", filename, ", index", index)
|
|
|
|
output_file = r"%s/%s" % (FLAGS.tmp_dir, filename)
|
|
tfrecord_url = r"http://us.data.yt8m.org/2/frame/train/%s" % filename
|
|
|
|
connection = http.client.HTTPConnection("us.data.yt8m.org")
|
|
connection.request("HEAD",
|
|
filename_match.expand(r"/2/frame/train/%s" % filename))
|
|
response = connection.getresponse()
|
|
if response.getheader("Content-Type") != "application/octet-stream":
|
|
self.report_error("Filename '%s' is invalid." % path)
|
|
|
|
print(output_file, "exists on yt8m.org. Did we fetch this before?")
|
|
|
|
if not os.path.exists(output_file):
|
|
print(output_file, "doesn't exist locally, download it now.")
|
|
return_code = subprocess.call(
|
|
["curl", "--output", output_file, tfrecord_url],
|
|
stdout=subprocess.PIPE,
|
|
stderr=subprocess.PIPE)
|
|
if return_code:
|
|
self.report_error("Could not retrieve contents from %s" % tfrecord_url)
|
|
return
|
|
else:
|
|
print(output_file, "exist locally, reuse it.")
|
|
|
|
print("Run the graph...")
|
|
process = subprocess.Popen([
|
|
"%s/%s" % (FLAGS.root, FLAGS.binary),
|
|
"--calculator_graph_config_file=%s/%s" % (FLAGS.root, FLAGS.pbtxt),
|
|
"--input_side_packets=tfrecord_path=%s" % output_file +
|
|
",record_index=%d" % index + ",desired_segment_size=%d" % segment_size,
|
|
"--output_stream=annotation_summary",
|
|
"--output_stream_file=%s/labels" % FLAGS.tmp_dir,
|
|
"--output_side_packets=yt8m_id",
|
|
"--output_side_packets_file=%s/yt8m_id" % FLAGS.tmp_dir
|
|
],
|
|
stdout=subprocess.PIPE,
|
|
stderr=subprocess.PIPE)
|
|
stdout_str, stderr_str = process.communicate()
|
|
process.wait()
|
|
|
|
if stderr_str and "success" not in str(stderr_str).lower():
|
|
self.report_error("Error executing server binary: \n%s" % stderr_str)
|
|
return
|
|
|
|
f = open("%s/yt8m_id" % FLAGS.tmp_dir, "r")
|
|
contents = f.read()
|
|
print("yt8m_id is", contents[-5:-1])
|
|
|
|
curl_arg = "data.yt8m.org/2/j/i/%s/%s.js" % (contents[-5:-3],
|
|
contents[-5:-1])
|
|
print("Grab labels from", curl_arg)
|
|
process = subprocess.Popen(["curl", curl_arg],
|
|
stdout=subprocess.PIPE,
|
|
stderr=subprocess.PIPE)
|
|
stdout = process.communicate()
|
|
process.wait()
|
|
|
|
stdout_str = stdout[0].decode("utf-8")
|
|
|
|
match = re.match(""".+"([^"]+)"[^"]+""", stdout_str)
|
|
final_results = {
|
|
"video_id": match.group(1),
|
|
"link": "https://www.youtube.com/watch?v=%s" % match.group(1),
|
|
"entries": []
|
|
}
|
|
f = open("%s/labels" % FLAGS.tmp_dir, "r")
|
|
lines = f.readlines()
|
|
show_at_center = FLAGS.show_label_at_center
|
|
|
|
print("%s/labels" % FLAGS.tmp_dir, "holds", len(lines), "entries")
|
|
for line in lines:
|
|
entry = {"labels": []}
|
|
final_results["entries"].append(entry)
|
|
first = True
|
|
for column in line.split(","):
|
|
if first:
|
|
subtract = segment_size / 2.0 if show_at_center else 0.0
|
|
entry["time"] = float(int(column)) / 1000000.0 - subtract
|
|
first = False
|
|
else:
|
|
label_score = re.match("(.+):([0-9.]+).*", column)
|
|
if label_score:
|
|
score = float(label_score.group(2))
|
|
entry["labels"].append({
|
|
"label": label_score.group(1),
|
|
"score": score
|
|
})
|
|
else:
|
|
print("empty score")
|
|
|
|
response_json = json.dumps(final_results, indent=2, separators=(",", ": "))
|
|
self.send_response(200)
|
|
self.send_header("Content-type", "application/json")
|
|
self.end_headers()
|
|
if sys.version_info[0] < 3:
|
|
self.wfile.write(str(response_json).encode("utf-8"))
|
|
else:
|
|
self.wfile.write(bytes(response_json, "utf-8"))
|
|
|
|
|
|
def update_pbtxt():
|
|
"""Update graph.pbtxt to use full path to label_map.txt."""
|
|
edited_line = ""
|
|
lines = []
|
|
with open("%s/%s" % (FLAGS.root, FLAGS.pbtxt), "r") as f:
|
|
lines = f.readlines()
|
|
for line in lines:
|
|
if "label_map_path" in line:
|
|
kv = line.split(":")
|
|
edited_line = kv[0] + (": \"%s/%s\"\n" % (FLAGS.root, FLAGS.label_map))
|
|
with open("%s/%s" % (FLAGS.root, FLAGS.pbtxt), "w") as f:
|
|
for line in lines:
|
|
if "label_map_path" in line:
|
|
f.write(edited_line)
|
|
else:
|
|
f.write(line)
|
|
|
|
|
|
def main(unused_args):
|
|
dname = os.path.dirname(os.path.abspath(__file__))
|
|
os.chdir(dname)
|
|
if not FLAGS.root:
|
|
print("Must specify MediaPipe root directory: --root `pwd`")
|
|
return
|
|
update_pbtxt()
|
|
port = FLAGS.port
|
|
print("Listening on port %s" % port) # pylint: disable=superfluous-parens
|
|
server = HTTPServerV6(("::", int(port)), Youtube8MRequestHandler)
|
|
server.serve_forever()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
app.run(main)
|