diff --git a/.bazelrc b/.bazelrc index 3433e846d..45de17ff9 100644 --- a/.bazelrc +++ b/.bazelrc @@ -34,3 +34,30 @@ build:android_arm --fat_apk_cpu=armeabi-v7a build:android_arm64 --config=android build:android_arm64 --cpu=arm64-v8a build:android_arm64 --fat_apk_cpu=arm64-v8a + +# iOS configs. +build:ios --apple_platform_type=ios + +build:ios_i386 --config=ios +build:ios_i386 --cpu=ios_i386 +build:ios_i386 --watchos_cpus=i386 + +build:ios_x86_64 --config=ios +build:ios_x86_64 --cpu=ios_x86_64 +build:ios_x86_64 --watchos_cpus=i386 + +build:ios_armv7 --config=ios +build:ios_armv7 --cpu=ios_armv7 +build:ios_armv7 --watchos_cpus=armv7k + +build:ios_arm64 --config=ios +build:ios_arm64 --cpu=ios_arm64 +build:ios_arm64 --watchos_cpus=armv7k + +build:ios_arm64e --config=ios +build:ios_arm64e --cpu=ios_arm64e +build:ios_arm64e --watchos_cpus=armv7k + +build:ios_fat --config=ios +build:ios_fat --ios_multi_cpus=armv7,arm64 +build:ios_fat --watchos_cpus=armv7k diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..cc06ef42e --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +mediapipe/provisioning_profile.mobileprovision diff --git a/BUILD b/BUILD index 38d7cc1d7..f225f24e3 100644 --- a/BUILD +++ b/BUILD @@ -1,4 +1,4 @@ -# Copyright 2019 The MediaPipeOSS Authors. +# Copyright 2019 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/Dockerfile b/Dockerfile index ad7c6f909..972198b52 100644 --- a/Dockerfile +++ b/Dockerfile @@ -28,14 +28,20 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ wget \ unzip \ python \ + python-pip \ libopencv-core-dev \ libopencv-highgui-dev \ libopencv-imgproc-dev \ libopencv-video-dev \ - && \ + software-properties-common && \ + add-apt-repository -y ppa:openjdk-r/ppa && \ + apt-get update && apt-get install -y openjdk-11-jdk && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* +RUN pip install --upgrade setuptools +RUN pip install future + # Install bazel ARG BAZEL_VERSION=0.26.1 RUN mkdir /bazel && \ @@ -49,4 +55,4 @@ azel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ COPY . 
/mediapipe/ # If we want the docker image to contain the pre-built object_detection_offline_demo binary, do the following -# RUN bazel build -c opt --define 'MEDIAPIPE_DISABLE_GPU=1' mediapipe/examples/desktop/demo:object_detection_tensorflow_demo +# RUN bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/demo:object_detection_tensorflow_demo diff --git a/WORKSPACE b/WORKSPACE index b935a70d7..21f5f224a 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -2,11 +2,12 @@ workspace(name = "mediapipe") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +skylib_version = "0.8.0" http_archive( name = "bazel_skylib", - sha256 = "bbccf674aa441c266df9894182d80de104cabd19be98be002f6d478aaa31574d", - strip_prefix = "bazel-skylib-2169ae1c374aab4a09aa90e65efe1a3aad4e279b", - urls = ["https://github.com/bazelbuild/bazel-skylib/archive/2169ae1c374aab4a09aa90e65efe1a3aad4e279b.tar.gz"], + type = "tar.gz", + url = "https://github.com/bazelbuild/bazel-skylib/releases/download/{}/bazel-skylib.{}.tar.gz".format (skylib_version, skylib_version), + sha256 = "2ef429f5d7ce7111263289644d233707dba35e39696377ebab8b0bc701f7818e", ) load("@bazel_skylib//lib:versions.bzl", "versions") versions.check(minimum_bazel_version = "0.23.0") @@ -52,7 +53,7 @@ http_archive( # glog http_archive( - name = "com_google_glog", + name = "com_github_glog_glog", url = "https://github.com/google/glog/archive/v0.3.5.zip", sha256 = "267103f8a1e9578978aa1dc256001e6529ef593e5aea38193d31c2872ee025e8", strip_prefix = "glog-0.3.5", @@ -73,6 +74,12 @@ http_archive( urls = ["https://github.com/google/protobuf/archive/384989534b2246d413dbcd750744faab2607b516.zip"], ) +http_archive( + name = "com_google_audio_tools", + strip_prefix = "multichannel-audio-tools-master", + urls = ["https://github.com/google/multichannel-audio-tools/archive/master.zip"], +) + # Needed by TensorFlow http_archive( name = "io_bazel_rules_closure", @@ -84,12 +91,24 @@ http_archive( ], ) -# TensorFlow r1.14-rc0 +# 2019-08-15 +_TENSORFLOW_GIT_COMMIT = "67def62936e28f97c16182dfcc467d8d1cae02b4" +_TENSORFLOW_SHA256= "ddd4e3c056e7c0ff2ef29133b30fa62781dfbf8a903e99efb91a02d292fa9562" http_archive( name = "org_tensorflow", - strip_prefix = "tensorflow-1.14.0-rc0", - sha256 = "76404a6157a45e8d7a07e4f5690275256260130145924c2a7c73f6eda2a3de10", - urls = ["https://github.com/tensorflow/tensorflow/archive/v1.14.0-rc0.zip"], + urls = [ + "https://mirror.bazel.build/github.com/tensorflow/tensorflow/archive/%s.tar.gz" % _TENSORFLOW_GIT_COMMIT, + "https://github.com/tensorflow/tensorflow/archive/%s.tar.gz" % _TENSORFLOW_GIT_COMMIT, + ], + strip_prefix = "tensorflow-%s" % _TENSORFLOW_GIT_COMMIT, + sha256 = _TENSORFLOW_SHA256, + patches = [ + "@//third_party:tensorflow_065c20bf79253257c87bd4614bb9a7fdef015cbb.diff", + "@//third_party:tensorflow_f67fcbefce906cd419e4657f0d41e21019b71abd.diff", + ], + patch_args = [ + "-p1", + ], ) load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace") @@ -102,6 +121,12 @@ new_local_repository( path = "/usr", ) +new_local_repository( + name = "linux_ffmpeg", + build_file = "@//third_party:ffmpeg_linux.BUILD", + path = "/usr" +) + # Please run $ brew install opencv new_local_repository( name = "macos_opencv", @@ -109,15 +134,33 @@ new_local_repository( path = "/usr", ) +new_local_repository( + name = "macos_ffmpeg", + build_file = "@//third_party:ffmpeg_macos.BUILD", + path = "/usr", +) + http_archive( name = "android_opencv", - sha256="056b849842e4fa8751d09edbb64530cfa7a63c84ccd232d0ace330e27ba55d0b", + sha256 = 
"056b849842e4fa8751d09edbb64530cfa7a63c84ccd232d0ace330e27ba55d0b", build_file = "@//third_party:opencv_android.BUILD", strip_prefix = "OpenCV-android-sdk", type = "zip", url = "https://github.com/opencv/opencv/releases/download/4.1.0/opencv-4.1.0-android-sdk.zip", ) +# After OpenCV 3.2.0, the pre-compiled opencv2.framework has google protobuf symbols, which will +# trigger duplicate symbol errors in the linking stage of building a mediapipe ios app. +# To get a higher version of OpenCV for iOS, opencv2.framework needs to be built from source with +# '-DBUILD_PROTOBUF=OFF -DBUILD_opencv_dnn=OFF'. +http_archive( + name = "ios_opencv", + sha256 = "7dd536d06f59e6e1156b546bd581523d8df92ce83440002885ec5abc06558de2", + build_file = "@//third_party:opencv_ios.BUILD", + type = "zip", + url = "https://github.com/opencv/opencv/releases/download/3.2.0/opencv-3.2.0-ios-framework.zip", +) + RULES_JVM_EXTERNAL_TAG = "2.2" RULES_JVM_EXTERNAL_SHA = "f1203ce04e232ab6fdd81897cf0ff76f2c04c0741424d192f28e65ae752ce2d6" @@ -132,12 +175,15 @@ load("@rules_jvm_external//:defs.bzl", "maven_install") maven_install( artifacts = [ - "com.android.support.constraint:constraint-layout:aar:1.0.2", - "androidx.appcompat:appcompat:aar:1.0.2", - ], - repositories = [ - "https://dl.google.com/dl/android/maven2", + "androidx.annotation:annotation:aar:1.1.0", + "androidx.appcompat:appcompat:aar:1.1.0-rc01", + "androidx.constraintlayout:constraintlayout:aar:1.1.3", + "androidx.core:core:aar:1.1.0-rc03", + "androidx.legacy:legacy-support-v4:aar:1.0.0", + "androidx.recyclerview:recyclerview:aar:1.1.0-beta02", + "com.google.android.material:material:aar:1.0.0-rc01", ], + repositories = ["https://dl.google.com/dl/android/maven2"], ) maven_server( @@ -191,3 +237,50 @@ android_ndk_repository( android_sdk_repository( name = "androidsdk", ) + +# iOS basic build deps. + +load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") + +git_repository( + name = "build_bazel_rules_apple", + remote = "https://github.com/bazelbuild/rules_apple.git", + tag = "0.18.0", + patches = [ + "@//third_party:rules_apple_c0863d0596ae6b769a29fa3fb72ff036444fd249.diff", + ], + patch_args = [ + "-p1", + ], +) + +load( + "@build_bazel_rules_apple//apple:repositories.bzl", + "apple_rules_dependencies", +) + +apple_rules_dependencies() + +load( + "@build_bazel_rules_swift//swift:repositories.bzl", + "swift_rules_dependencies", +) + +swift_rules_dependencies() + +load( + "@build_bazel_apple_support//lib:repositories.bzl", + "apple_support_dependencies", +) + +apple_support_dependencies() + +# More iOS deps. + +http_archive( + name = "google_toolbox_for_mac", + url = "https://github.com/google/google-toolbox-for-mac/archive/v2.2.1.zip", + sha256 = "e3ac053813c989a88703556df4dc4466e424e30d32108433ed6beaec76ba4fdc", + strip_prefix = "google-toolbox-for-mac-2.2.1", + build_file = "@//third_party:google_toolbox_for_mac.BUILD", +) diff --git a/mediapipe/BUILD b/mediapipe/BUILD index a5fc4d1f6..d8b37ef20 100644 --- a/mediapipe/BUILD +++ b/mediapipe/BUILD @@ -65,11 +65,73 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( +# Note: this cannot just match "apple_platform_type": "macos" because that option +# defaults to "macos" even when building on Linux! +alias( name = "macos", + actual = select({ + ":macos_i386": ":macos_i386", + ":macos_x86_64": ":macos_x86_64", + "//conditions:default": ":macos_i386", # Arbitrarily chosen from above. 
+ }), + visibility = ["//visibility:public"], +) + +# Note: this also matches on crosstool_top so that it does not produce ambiguous +# selectors when used together with "android". +config_setting( + name = "ios", + values = { + "crosstool_top": "@bazel_tools//tools/cpp:toolchain", + "apple_platform_type": "ios", + }, + visibility = ["//visibility:public"], +) + +alias( + name = "apple", + actual = select({ + ":macos": ":macos", + ":ios": ":ios", + "//conditions:default": ":ios", # Arbitrarily chosen from above. + }), + visibility = ["//visibility:public"], +) + +config_setting( + name = "macos_i386", values = { "apple_platform_type": "macos", "cpu": "darwin", }, visibility = ["//visibility:public"], ) + +config_setting( + name = "macos_x86_64", + values = { + "apple_platform_type": "macos", + "cpu": "darwin_x86_64", + }, + visibility = ["//visibility:public"], +) + +[ + config_setting( + name = arch, + values = {"cpu": arch}, + visibility = ["//visibility:public"], + ) + for arch in [ + "ios_i386", + "ios_x86_64", + "ios_armv7", + "ios_arm64", + "ios_arm64e", + ] +] + +exports_files( + ["provisioning_profile.mobileprovision"], + visibility = ["//visibility:public"], +) diff --git a/mediapipe/calculators/audio/BUILD b/mediapipe/calculators/audio/BUILD new file mode 100644 index 000000000..0e845b20f --- /dev/null +++ b/mediapipe/calculators/audio/BUILD @@ -0,0 +1,305 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:private"]) + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") + +proto_library( + name = "mfcc_mel_calculators_proto", + srcs = ["mfcc_mel_calculators.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + ], +) + +mediapipe_cc_proto_library( + name = "mfcc_mel_calculators_cc_proto", + srcs = ["mfcc_mel_calculators.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//visibility:public"], + deps = [":mfcc_mel_calculators_proto"], +) + +proto_library( + name = "rational_factor_resample_calculator_proto", + srcs = ["rational_factor_resample_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + ], +) + +mediapipe_cc_proto_library( + name = "rational_factor_resample_calculator_cc_proto", + srcs = ["rational_factor_resample_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//visibility:public"], + deps = [":rational_factor_resample_calculator_proto"], +) + +proto_library( + name = "spectrogram_calculator_proto", + srcs = ["spectrogram_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "spectrogram_calculator_cc_proto", + srcs = ["spectrogram_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//visibility:public"], + deps = [":spectrogram_calculator_proto"], +) + +proto_library( + name = "time_series_framer_calculator_proto", + srcs = ["time_series_framer_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + ], +) + +mediapipe_cc_proto_library( + name = "time_series_framer_calculator_cc_proto", + srcs = ["time_series_framer_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//visibility:public"], + deps = [":time_series_framer_calculator_proto"], +) + +cc_library( + name = "audio_decoder_calculator", + srcs = ["audio_decoder_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:status", + "//mediapipe/util:audio_decoder", + "//mediapipe/util:audio_decoder_cc_proto", + ], + alwayslink = 1, +) + +cc_library( + name = "basic_time_series_calculators", + srcs = ["basic_time_series_calculators.cc"], + hdrs = ["basic_time_series_calculators.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/util:time_series_util", + "@com_google_absl//absl/strings", + "@eigen_archive//:eigen", + ], + alwayslink = 1, +) + +cc_library( + name = "mfcc_mel_calculators", + srcs = ["mfcc_mel_calculators.cc"], + visibility = ["//visibility:public"], + deps = [ + ":mfcc_mel_calculators_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:status", + "//mediapipe/util:time_series_util", + 
"@com_google_absl//absl/strings", + "@com_google_audio_tools//audio/dsp/mfcc", + "@eigen_archive//:eigen", + ], + alwayslink = 1, +) + +cc_library( + name = "rational_factor_resample_calculator", + srcs = ["rational_factor_resample_calculator.cc"], + hdrs = ["rational_factor_resample_calculator.h"], + visibility = ["//visibility:public"], + deps = [ + ":rational_factor_resample_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/util:time_series_util", + "@com_google_absl//absl/strings", + "@com_google_audio_tools//audio/dsp:resampler", + "@com_google_audio_tools//audio/dsp:resampler_rational_factor", + "@eigen_archive//:eigen", + ], + alwayslink = 1, +) + +cc_library( + name = "spectrogram_calculator", + srcs = ["spectrogram_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":spectrogram_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework/port:core_proto", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:source_location", + "//mediapipe/framework/port:status", + "//mediapipe/util:time_series_util", + "@com_google_absl//absl/strings", + "@com_google_audio_tools//audio/dsp:window_functions", + "@com_google_audio_tools//audio/dsp/spectrogram", + "@eigen_archive//:eigen", + ], + alwayslink = 1, +) + +cc_library( + name = "time_series_framer_calculator", + srcs = ["time_series_framer_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":time_series_framer_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/util:time_series_util", + "@com_google_audio_tools//audio/dsp:window_functions", + "@eigen_archive//:eigen", + ], + alwayslink = 1, +) + +cc_test( + name = "audio_decoder_calculator_test", + srcs = ["audio_decoder_calculator_test.cc"], + data = ["//mediapipe/calculators/audio/testdata:test_audios"], + deps = [ + ":audio_decoder_calculator", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + ], +) + +cc_test( + name = "basic_time_series_calculators_test", + srcs = ["basic_time_series_calculators_test.cc"], + deps = [ + ":basic_time_series_calculators", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/util:time_series_test_util", + "@eigen_archive//:eigen", + ], +) + +cc_test( + name = "mfcc_mel_calculators_test", + srcs = ["mfcc_mel_calculators_test.cc"], + deps = [ + ":mfcc_mel_calculators", + 
":mfcc_mel_calculators_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:status", + "//mediapipe/util:time_series_test_util", + "@eigen_archive//:eigen", + ], +) + +cc_test( + name = "spectrogram_calculator_test", + srcs = ["spectrogram_calculator_test.cc"], + deps = [ + ":spectrogram_calculator", + ":spectrogram_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework/port:benchmark", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:status", + "//mediapipe/util:time_series_test_util", + "@com_google_audio_tools//audio/dsp:number_util", + "@eigen_archive//:eigen", + ], +) + +cc_test( + name = "time_series_framer_calculator_test", + srcs = ["time_series_framer_calculator_test.cc"], + deps = [ + ":time_series_framer_calculator", + ":time_series_framer_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:status", + "//mediapipe/util:time_series_test_util", + "@com_google_audio_tools//audio/dsp:window_functions", + "@eigen_archive//:eigen", + ], +) + +cc_test( + name = "rational_factor_resample_calculator_test", + srcs = ["rational_factor_resample_calculator_test.cc"], + deps = [ + ":rational_factor_resample_calculator", + ":rational_factor_resample_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:validate_type", + "//mediapipe/util:time_series_test_util", + "@com_google_audio_tools//audio/dsp:signal_vector_util", + "@eigen_archive//:eigen", + ], +) diff --git a/mediapipe/calculators/audio/audio_decoder_calculator.cc b/mediapipe/calculators/audio/audio_decoder_calculator.cc new file mode 100644 index 000000000..24dae2b44 --- /dev/null +++ b/mediapipe/calculators/audio/audio_decoder_calculator.cc @@ -0,0 +1,106 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/formats/matrix.h"
+#include "mediapipe/framework/port/logging.h"
+#include "mediapipe/framework/port/status.h"
+#include "mediapipe/util/audio_decoder.h"
+#include "mediapipe/util/audio_decoder.pb.h"
+
+namespace mediapipe {
+
+// The AudioDecoderCalculator decodes an audio stream of the media file. It
+// produces two output streams containing the audio packets and the header
+// information.
+//
+// Output Streams:
+//   AUDIO: Output audio frames (Matrix).
+//   AUDIO_HEADER:
+//     Optional audio header information output.
+// Input Side Packets:
+//   INPUT_FILE_PATH: The input file path.
+//
+// Example config:
+// node {
+//   calculator: "AudioDecoderCalculator"
+//   input_side_packet: "INPUT_FILE_PATH:input_file_path"
+//   output_stream: "AUDIO:audio"
+//   output_stream: "AUDIO_HEADER:audio_header"
+//   node_options {
+//     [type.googleapis.com/mediapipe.AudioDecoderOptions]: {
+//       audio_stream { stream_index: 0 }
+//       start_time: 0
+//       end_time: 1
+//     }
+//   }
+// }
+//
+// TODO: support decoding multiple streams.
+class AudioDecoderCalculator : public CalculatorBase {
+ public:
+  static ::mediapipe::Status GetContract(CalculatorContract* cc);
+
+  ::mediapipe::Status Open(CalculatorContext* cc) override;
+  ::mediapipe::Status Process(CalculatorContext* cc) override;
+  ::mediapipe::Status Close(CalculatorContext* cc) override;
+
+ private:
+  std::unique_ptr<AudioDecoder> decoder_;
+};
+
+::mediapipe::Status AudioDecoderCalculator::GetContract(
+    CalculatorContract* cc) {
+  cc->InputSidePackets().Tag("INPUT_FILE_PATH").Set<std::string>();
+
+  cc->Outputs().Tag("AUDIO").Set<Matrix>();
+  if (cc->Outputs().HasTag("AUDIO_HEADER")) {
+    cc->Outputs().Tag("AUDIO_HEADER").Set<mediapipe::TimeSeriesHeader>();
+  }
+  return ::mediapipe::OkStatus();
+}
+
+::mediapipe::Status AudioDecoderCalculator::Open(CalculatorContext* cc) {
+  const std::string& input_file_path =
+      cc->InputSidePackets().Tag("INPUT_FILE_PATH").Get<std::string>();
+  const auto& decoder_options = cc->Options<mediapipe::AudioDecoderOptions>();
+  decoder_ = absl::make_unique<AudioDecoder>();
+  RETURN_IF_ERROR(decoder_->Initialize(input_file_path, decoder_options));
+  std::unique_ptr<mediapipe::TimeSeriesHeader> header =
+      absl::make_unique<mediapipe::TimeSeriesHeader>();
+  if (decoder_->FillAudioHeader(decoder_options.audio_stream(0), header.get())
+          .ok()) {
+    // Only pass on a header if the decoder could actually produce one;
+    // otherwise, the header will be empty.
+    cc->Outputs().Tag("AUDIO_HEADER").SetHeader(Adopt(header.release()));
+  }
+  cc->Outputs().Tag("AUDIO_HEADER").Close();
+  return ::mediapipe::OkStatus();
+}
+
+::mediapipe::Status AudioDecoderCalculator::Process(CalculatorContext* cc) {
+  Packet data;
+  int options_index = -1;
+  auto status = decoder_->GetData(&options_index, &data);
+  if (status.ok()) {
+    cc->Outputs().Tag("AUDIO").AddPacket(data);
+  }
+  return status;
+}
+
+::mediapipe::Status AudioDecoderCalculator::Close(CalculatorContext* cc) {
+  return decoder_->Close();
+}
+
+REGISTER_CALCULATOR(AudioDecoderCalculator);
+
+}  // namespace mediapipe
diff --git a/mediapipe/calculators/audio/audio_decoder_calculator_test.cc b/mediapipe/calculators/audio/audio_decoder_calculator_test.cc
new file mode 100644
index 000000000..e65fe1e41
--- /dev/null
+++ b/mediapipe/calculators/audio/audio_decoder_calculator_test.cc
@@ -0,0 +1,153 @@
+// Copyright 2019 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +TEST(AudioDecoderCalculatorTest, TestWAV) { + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie(R"( + calculator: "AudioDecoderCalculator" + input_side_packet: "INPUT_FILE_PATH:input_file_path" + output_stream: "AUDIO:audio" + output_stream: "AUDIO_HEADER:audio_header" + node_options { + [type.googleapis.com/mediapipe.AudioDecoderOptions]: { + audio_stream { stream_index: 0 } + } + })"); + CalculatorRunner runner(node_config); + runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket( + file::JoinPath("./", + "/mediapipe/calculators/audio/" + "testdata/sine_wave_1k_44100_mono_2_sec_wav.audio")); + MEDIAPIPE_ASSERT_OK(runner.Run()); + MEDIAPIPE_EXPECT_OK( + runner.Outputs() + .Tag("AUDIO_HEADER") + .header.ValidateAsType()); + const mediapipe::TimeSeriesHeader& header = + runner.Outputs() + .Tag("AUDIO_HEADER") + .header.Get(); + EXPECT_EQ(44100, header.sample_rate()); + EXPECT_EQ(1, header.num_channels()); + EXPECT_TRUE(runner.Outputs().Tag("AUDIO").packets.size() >= + std::ceil(44100.0 * 2 / 2048)); +} + +TEST(AudioDecoderCalculatorTest, Test48KWAV) { + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie(R"( + calculator: "AudioDecoderCalculator" + input_side_packet: "INPUT_FILE_PATH:input_file_path" + output_stream: "AUDIO:audio" + output_stream: "AUDIO_HEADER:audio_header" + node_options { + [type.googleapis.com/mediapipe.AudioDecoderOptions]: { + audio_stream { stream_index: 0 } + } + })"); + CalculatorRunner runner(node_config); + runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket( + file::JoinPath("./", + "/mediapipe/calculators/audio/" + "testdata/sine_wave_1k_48000_stereo_2_sec_wav.audio")); + MEDIAPIPE_ASSERT_OK(runner.Run()); + MEDIAPIPE_EXPECT_OK( + runner.Outputs() + .Tag("AUDIO_HEADER") + .header.ValidateAsType()); + const mediapipe::TimeSeriesHeader& header = + runner.Outputs() + .Tag("AUDIO_HEADER") + .header.Get(); + EXPECT_EQ(48000, header.sample_rate()); + EXPECT_EQ(2, header.num_channels()); + EXPECT_TRUE(runner.Outputs().Tag("AUDIO").packets.size() >= + std::ceil(48000.0 * 2 / 1024)); +} + +TEST(AudioDecoderCalculatorTest, TestMP3) { + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie(R"( + calculator: "AudioDecoderCalculator" + input_side_packet: "INPUT_FILE_PATH:input_file_path" + output_stream: "AUDIO:audio" + output_stream: "AUDIO_HEADER:audio_header" + node_options { + [type.googleapis.com/mediapipe.AudioDecoderOptions]: { + audio_stream { stream_index: 0 } + } + })"); + CalculatorRunner runner(node_config); + runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket( + file::JoinPath("./", + "/mediapipe/calculators/audio/" + "testdata/sine_wave_1k_44100_stereo_2_sec_mp3.audio")); + 
MEDIAPIPE_ASSERT_OK(runner.Run()); + MEDIAPIPE_EXPECT_OK( + runner.Outputs() + .Tag("AUDIO_HEADER") + .header.ValidateAsType()); + const mediapipe::TimeSeriesHeader& header = + runner.Outputs() + .Tag("AUDIO_HEADER") + .header.Get(); + EXPECT_EQ(44100, header.sample_rate()); + EXPECT_EQ(2, header.num_channels()); + EXPECT_TRUE(runner.Outputs().Tag("AUDIO").packets.size() >= + std::ceil(44100.0 * 2 / 1152)); +} + +TEST(AudioDecoderCalculatorTest, TestAAC) { + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie(R"( + calculator: "AudioDecoderCalculator" + input_side_packet: "INPUT_FILE_PATH:input_file_path" + output_stream: "AUDIO:audio" + output_stream: "AUDIO_HEADER:audio_header" + node_options { + [type.googleapis.com/mediapipe.AudioDecoderOptions]: { + audio_stream { stream_index: 0 } + } + })"); + CalculatorRunner runner(node_config); + runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket( + file::JoinPath("./", + "/mediapipe/calculators/audio/" + "testdata/sine_wave_1k_44100_stereo_2_sec_aac.audio")); + MEDIAPIPE_ASSERT_OK(runner.Run()); + MEDIAPIPE_EXPECT_OK( + runner.Outputs() + .Tag("AUDIO_HEADER") + .header.ValidateAsType()); + const mediapipe::TimeSeriesHeader& header = + runner.Outputs() + .Tag("AUDIO_HEADER") + .header.Get(); + EXPECT_EQ(44100, header.sample_rate()); + EXPECT_EQ(2, header.num_channels()); + EXPECT_TRUE(runner.Outputs().Tag("AUDIO").packets.size() >= + std::ceil(44100.0 * 2 / 1024)); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/audio/basic_time_series_calculators.cc b/mediapipe/calculators/audio/basic_time_series_calculators.cc new file mode 100644 index 000000000..e05dde6a0 --- /dev/null +++ b/mediapipe/calculators/audio/basic_time_series_calculators.cc @@ -0,0 +1,403 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Basic Calculators that operate on TimeSeries streams. +#include "mediapipe/calculators/audio/basic_time_series_calculators.h" + +#include +#include + +#include "Eigen/Core" +#include "absl/strings/str_cat.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/util/time_series_util.h" + +namespace mediapipe { +namespace { +static bool SafeMultiply(int x, int y, int* result) { + static_assert(sizeof(int64) >= 2 * sizeof(int), + "Unable to detect overflow after multiplication"); + const int64 big = static_cast(x) * static_cast(y); + if (big > static_cast(INT_MIN) && big < static_cast(INT_MAX)) { + if (result != nullptr) *result = static_cast(big); + return true; + } else { + return false; + } +} +} // namespace + +::mediapipe::Status BasicTimeSeriesCalculatorBase::GetContract( + CalculatorContract* cc) { + cc->Inputs().Index(0).Set( + // Input stream with TimeSeriesHeader. + ); + cc->Outputs().Index(0).Set( + // Output stream with TimeSeriesHeader. 
+ ); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status BasicTimeSeriesCalculatorBase::Open(CalculatorContext* cc) { + TimeSeriesHeader input_header; + RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid( + cc->Inputs().Index(0).Header(), &input_header)); + + auto output_header = new TimeSeriesHeader(input_header); + RETURN_IF_ERROR(MutateHeader(output_header)); + cc->Outputs().Index(0).SetHeader(Adopt(output_header)); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status BasicTimeSeriesCalculatorBase::Process( + CalculatorContext* cc) { + const Matrix& input = cc->Inputs().Index(0).Get(); + RETURN_IF_ERROR(time_series_util::IsMatrixShapeConsistentWithHeader( + input, cc->Inputs().Index(0).Header().Get())); + + std::unique_ptr output(new Matrix(ProcessMatrix(input))); + RETURN_IF_ERROR(time_series_util::IsMatrixShapeConsistentWithHeader( + *output, cc->Outputs().Index(0).Header().Get())); + + cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status BasicTimeSeriesCalculatorBase::MutateHeader( + TimeSeriesHeader* output_header) { + return ::mediapipe::OkStatus(); +} + +// Calculator to sum an input time series across channels. This is +// useful for e.g. computing 'summary SAI' pitchogram features. +// +// Options proto: None. +class SumTimeSeriesAcrossChannelsCalculator + : public BasicTimeSeriesCalculatorBase { + protected: + ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + output_header->set_num_channels(1); + return ::mediapipe::OkStatus(); + } + + Matrix ProcessMatrix(const Matrix& input_matrix) final { + return input_matrix.colwise().sum(); + } +}; +REGISTER_CALCULATOR(SumTimeSeriesAcrossChannelsCalculator); + +// Calculator to average an input time series across channels. This is +// useful for e.g. converting stereo or multi-channel files to mono. +// +// Options proto: None. +class AverageTimeSeriesAcrossChannelsCalculator + : public BasicTimeSeriesCalculatorBase { + protected: + ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + output_header->set_num_channels(1); + return ::mediapipe::OkStatus(); + } + + Matrix ProcessMatrix(const Matrix& input_matrix) final { + return input_matrix.colwise().mean(); + } +}; +REGISTER_CALCULATOR(AverageTimeSeriesAcrossChannelsCalculator); + +// Calculator to convert a (temporal) summary SAI stream (a single-channel +// stream output by SumTimeSeriesAcrossChannelsCalculator) into pitchogram +// frames by transposing the input packets, swapping the time and channel axes. +// +// Options proto: None. +class SummarySaiToPitchogramCalculator : public BasicTimeSeriesCalculatorBase { + protected: + ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + if (output_header->num_channels() != 1) { + return tool::StatusInvalid( + absl::StrCat("Expected single-channel input, got ", + output_header->num_channels())); + } + output_header->set_num_channels(output_header->num_samples()); + output_header->set_num_samples(1); + output_header->set_sample_rate(output_header->packet_rate()); + return ::mediapipe::OkStatus(); + } + + Matrix ProcessMatrix(const Matrix& input_matrix) final { + return input_matrix.transpose(); + } +}; +REGISTER_CALCULATOR(SummarySaiToPitchogramCalculator); + +// Calculator to reverse the order of channels in TimeSeries packets. +// This is useful for e.g. interfacing with the speech pipeline which uses the +// opposite convention to the hearing filterbanks. 
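+// For example, a packet whose channel rows are [a; b; c] comes out with rows
+// [c; b; a]; the sample (column) dimension is unchanged.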
+// +// Options proto: None. +class ReverseChannelOrderCalculator : public BasicTimeSeriesCalculatorBase { + protected: + Matrix ProcessMatrix(const Matrix& input_matrix) final { + return input_matrix.colwise().reverse(); + } +}; +REGISTER_CALCULATOR(ReverseChannelOrderCalculator); + +// Calculator to flatten all samples in a TimeSeries packet down into +// a single 'sample' vector. This is useful for e.g. stacking several +// frames of features into a single feature vector. +// +// Options proto: None. +class FlattenPacketCalculator : public BasicTimeSeriesCalculatorBase { + protected: + ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + const int num_input_channels = output_header->num_channels(); + const int num_input_samples = output_header->num_samples(); + RET_CHECK(num_input_channels >= 0) + << "FlattenPacketCalculator: num_input_channels < 0"; + RET_CHECK(num_input_samples >= 0) + << "FlattenPacketCalculator: num_input_samples < 0"; + int output_num_channels; + RET_CHECK(SafeMultiply(num_input_channels, num_input_samples, + &output_num_channels)) + << "FlattenPacketCalculator: Multiplication failed."; + output_header->set_num_channels(output_num_channels); + output_header->set_num_samples(1); + output_header->set_sample_rate(output_header->packet_rate()); + return ::mediapipe::OkStatus(); + } + + Matrix ProcessMatrix(const Matrix& input_matrix) final { + // Flatten by interleaving channels so that full samples are + // stacked on top of each other instead of interleaving samples + // from the same channel. + Matrix output(input_matrix.size(), 1); + for (int sample = 0; sample < input_matrix.cols(); ++sample) { + output.middleRows(sample * input_matrix.rows(), input_matrix.rows()) = + input_matrix.col(sample); + } + return output; + } +}; +REGISTER_CALCULATOR(FlattenPacketCalculator); + +// Calculator to subtract the within-packet mean for each channel from each +// corresponding channel. +// +// Options proto: None. +class SubtractMeanCalculator : public BasicTimeSeriesCalculatorBase { + protected: + Matrix ProcessMatrix(const Matrix& input_matrix) final { + Matrix mean = input_matrix.rowwise().mean(); + return input_matrix - mean.replicate(1, input_matrix.cols()); + } +}; +REGISTER_CALCULATOR(SubtractMeanCalculator); + +// Calculator to subtract the mean over all values (across all times and +// channels) in a Packet from the values in that Packet. +// +// Options proto: None. +class SubtractMeanAcrossChannelsCalculator + : public BasicTimeSeriesCalculatorBase { + protected: + Matrix ProcessMatrix(const Matrix& input_matrix) final { + auto mean = input_matrix.mean(); + return (input_matrix.array() - mean).matrix(); + } +}; +REGISTER_CALCULATOR(SubtractMeanAcrossChannelsCalculator); + +// Calculator to divide all values in a Packet by the average value across all +// times and channels in the packet. This is useful for normalizing +// nonnegative quantities like power, but might cause unexpected results if used +// with Packets that can contain negative numbers. +// +// If mean is exactly zero, the output will be a matrix of all ones, because +// that's what happens in other cases where all values are equal. +// +// Options proto: None. 
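+// For example, a 2-channel, 2-sample packet
+//   [1  3]
+//   [2  6]
+// has an overall mean of 3, so the output is
+//   [1/3  1]
+//   [2/3  2].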
+class DivideByMeanAcrossChannelsCalculator
+    : public BasicTimeSeriesCalculatorBase {
+ protected:
+  Matrix ProcessMatrix(const Matrix& input_matrix) final {
+    auto mean = input_matrix.mean();
+
+    if (mean != 0) {
+      return input_matrix / mean;
+    } else {
+      // When used with nonnegative matrices, the mean will only be zero if the
+      // entire matrix is exactly zero. If mean is exactly zero, the output
+      // will be a matrix of all ones, because that's what happens in other
+      // cases where all values are equal.
+      return Matrix::Ones(input_matrix.rows(), input_matrix.cols());
+    }
+  }
+};
+REGISTER_CALCULATOR(DivideByMeanAcrossChannelsCalculator);
+
+// Calculator to calculate the mean for each channel.
+//
+// Options proto: None.
+class MeanCalculator : public BasicTimeSeriesCalculatorBase {
+ protected:
+  ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
+    output_header->set_num_samples(1);
+    output_header->set_sample_rate(output_header->packet_rate());
+    return ::mediapipe::OkStatus();
+  }
+
+  Matrix ProcessMatrix(const Matrix& input_matrix) final {
+    return input_matrix.rowwise().mean();
+  }
+};
+REGISTER_CALCULATOR(MeanCalculator);
+
+// Calculator to calculate the uncorrected sample standard deviation in each
+// channel, independently for each Packet. I.e. divide by the number of samples
+// in the Packet, not (number of samples - 1).
+//
+// Options proto: None.
+class StandardDeviationCalculator : public BasicTimeSeriesCalculatorBase {
+ protected:
+  ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
+    output_header->set_num_samples(1);
+    output_header->set_sample_rate(output_header->packet_rate());
+    return ::mediapipe::OkStatus();
+  }
+
+  Matrix ProcessMatrix(const Matrix& input_matrix) final {
+    Eigen::VectorXf mean = input_matrix.rowwise().mean();
+    return (input_matrix.colwise() - mean).rowwise().norm() /
+           sqrt(input_matrix.cols());
+  }
+};
+REGISTER_CALCULATOR(StandardDeviationCalculator);
+
+// Calculator to calculate the covariance matrix. If the input matrix
+// has N channels, the output matrix will be an N by N symmetric
+// matrix.
+//
+// Options proto: None.
+class CovarianceCalculator : public BasicTimeSeriesCalculatorBase {
+ protected:
+  ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
+    output_header->set_num_samples(output_header->num_channels());
+    return ::mediapipe::OkStatus();
+  }
+
+  Matrix ProcessMatrix(const Matrix& input_matrix) final {
+    auto mean = input_matrix.rowwise().mean();
+    auto zero_mean_input =
+        input_matrix - mean.replicate(1, input_matrix.cols());
+    return (zero_mean_input * zero_mean_input.transpose()) /
+           input_matrix.cols();
+  }
+};
+REGISTER_CALCULATOR(CovarianceCalculator);
+
+// Calculator to get the per column L2 norm of an input time series.
+//
+// Options proto: None.
+class L2NormCalculator : public BasicTimeSeriesCalculatorBase {
+ protected:
+  ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
+    output_header->set_num_channels(1);
+    return ::mediapipe::OkStatus();
+  }
+
+  Matrix ProcessMatrix(const Matrix& input_matrix) final {
+    return input_matrix.colwise().norm();
+  }
+};
+REGISTER_CALCULATOR(L2NormCalculator);
+
+// Calculator to convert each column of a matrix to a unit vector.
+//
+// Options proto: None.
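+// For example, a column holding (3, 4) has L2 norm 5 and becomes (0.6, 0.8);
+// every column is normalized independently of the others.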
+class L2NormalizeColumnCalculator : public BasicTimeSeriesCalculatorBase { + protected: + Matrix ProcessMatrix(const Matrix& input_matrix) final { + return input_matrix.colwise().normalized(); + } +}; +REGISTER_CALCULATOR(L2NormalizeColumnCalculator); + +// Calculator to apply L2 normalization to the input matrix. +// +// Returns the matrix as is if the RMS is <= 1E-8. +// Options proto: None. +class L2NormalizeCalculator : public BasicTimeSeriesCalculatorBase { + protected: + Matrix ProcessMatrix(const Matrix& input_matrix) final { + constexpr double kEpsilon = 1e-8; + double rms = std::sqrt(input_matrix.array().square().mean()); + if (rms <= kEpsilon) { + return input_matrix; + } + return input_matrix / rms; + } +}; +REGISTER_CALCULATOR(L2NormalizeCalculator); + +// Calculator to apply Peak normalization to the input matrix. +// +// Returns the matrix as is if the peak is <= 1E-8. +// Options proto: None. +class PeakNormalizeCalculator : public BasicTimeSeriesCalculatorBase { + protected: + Matrix ProcessMatrix(const Matrix& input_matrix) final { + constexpr double kEpsilon = 1e-8; + double max_pcm = input_matrix.cwiseAbs().maxCoeff(); + if (max_pcm <= kEpsilon) { + return input_matrix; + } + return input_matrix / max_pcm; + } +}; +REGISTER_CALCULATOR(PeakNormalizeCalculator); + +// Calculator to compute the elementwise square of an input time series. +// +// Options proto: None. +class ElementwiseSquareCalculator : public BasicTimeSeriesCalculatorBase { + protected: + Matrix ProcessMatrix(const Matrix& input_matrix) final { + return input_matrix.array().square(); + } +}; +REGISTER_CALCULATOR(ElementwiseSquareCalculator); + +// Calculator that outputs first floor(num_samples / 2) of the samples. +// +// Options proto: None. +class FirstHalfSlicerCalculator : public BasicTimeSeriesCalculatorBase { + protected: + ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + const int num_input_samples = output_header->num_samples(); + RET_CHECK(num_input_samples >= 0) + << "FirstHalfSlicerCalculator: num_input_samples < 0"; + output_header->set_num_samples(num_input_samples / 2); + return ::mediapipe::OkStatus(); + } + + Matrix ProcessMatrix(const Matrix& input_matrix) final { + return input_matrix.block(0, 0, input_matrix.rows(), + input_matrix.cols() / 2); + } +}; +REGISTER_CALCULATOR(FirstHalfSlicerCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/audio/basic_time_series_calculators.h b/mediapipe/calculators/audio/basic_time_series_calculators.h new file mode 100644 index 000000000..3727d66b0 --- /dev/null +++ b/mediapipe/calculators/audio/basic_time_series_calculators.h @@ -0,0 +1,48 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Abstract base class for basic MediaPipe calculators that operate on +// TimeSeries streams and don't require any Options protos. +// Subclasses must override ProcessMatrix, and optionally +// MutateHeader. 
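+//
+// As an illustration only (this calculator is hypothetical and not part of
+// the library), a minimal subclass that negates every sample would be:
+//
+//   class NegateTimeSeriesCalculator : public BasicTimeSeriesCalculatorBase {
+//    protected:
+//     Matrix ProcessMatrix(const Matrix& input_matrix) final {
+//       // Output has the same shape as the input, so no MutateHeader()
+//       // override is needed.
+//       return -input_matrix;
+//     }
+//   };
+//   REGISTER_CALCULATOR(NegateTimeSeriesCalculator);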
+ +#ifndef MEDIAPIPE_CALCULATORS_AUDIO_BASIC_TIME_SERIES_CALCULATORS_H_ +#define MEDIAPIPE_CALCULATORS_AUDIO_BASIC_TIME_SERIES_CALCULATORS_H_ + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/time_series_header.pb.h" + +namespace mediapipe { + +class BasicTimeSeriesCalculatorBase : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + protected: + // Open() calls this method to mutate the output stream header. The input + // to this function will contain a copy of the input stream header, so + // subclasses that do not need to mutate the header do not need to override + // it. + virtual ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header); + + // Process() calls this method on each packet to compute the output matrix. + virtual Matrix ProcessMatrix(const Matrix& input_matrix) = 0; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_AUDIO_BASIC_TIME_SERIES_CALCULATORS_H_ diff --git a/mediapipe/calculators/audio/basic_time_series_calculators_test.cc b/mediapipe/calculators/audio/basic_time_series_calculators_test.cc new file mode 100644 index 000000000..7211b83fe --- /dev/null +++ b/mediapipe/calculators/audio/basic_time_series_calculators_test.cc @@ -0,0 +1,515 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "Eigen/Core" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/util/time_series_test_util.h" + +namespace mediapipe { + +class SumTimeSeriesAcrossChannelsCalculatorTest + : public BasicTimeSeriesCalculatorTestBase { + protected: + void SetUp() override { + calculator_name_ = "SumTimeSeriesAcrossChannelsCalculator"; + } +}; + +TEST_F(SumTimeSeriesAcrossChannelsCalculatorTest, IsNoOpOnSingleChannelInputs) { + const TimeSeriesHeader header = ParseTextProtoOrDie( + "sample_rate: 8000.0 num_channels: 1 num_samples: 5"); + const Matrix input = + Matrix::Random(header.num_channels(), header.num_samples()); + + Test(header, {input}, header, {input}); +} + +TEST_F(SumTimeSeriesAcrossChannelsCalculatorTest, ConstantPacket) { + const TimeSeriesHeader header = ParseTextProtoOrDie( + "sample_rate: 8000.0 num_channels: 3 num_samples: 5"); + TimeSeriesHeader output_header(header); + output_header.set_num_channels(1); + + Test(header, + {Matrix::Constant(header.num_channels(), header.num_samples(), 1)}, + output_header, + {Matrix::Constant(1, header.num_samples(), header.num_channels())}); +} + +TEST_F(SumTimeSeriesAcrossChannelsCalculatorTest, MultiplePackets) { + const TimeSeriesHeader header = ParseTextProtoOrDie( + "sample_rate: 8000.0 num_channels: 3 num_samples: 5"); + Matrix in(header.num_channels(), header.num_samples()); + in << 10, -1, -1, 0, 0, 20, -2, 0, 1, 0, 30, -3, 1, 0, 12; + + TimeSeriesHeader output_header(header); + output_header.set_num_channels(1); + Matrix out(1, header.num_samples()); + out << 60, -6, 0, 1, 12; + + Test(header, {in, 2 * in, in + Matrix::Constant(in.rows(), in.cols(), 3.5f)}, + output_header, + {out, 2 * out, + out + Matrix::Constant(out.rows(), out.cols(), + 3.5 * header.num_channels())}); +} + +class AverageTimeSeriesAcrossChannelsCalculatorTest + : public BasicTimeSeriesCalculatorTestBase { + protected: + void SetUp() override { + calculator_name_ = "AverageTimeSeriesAcrossChannelsCalculator"; + } +}; + +TEST_F(AverageTimeSeriesAcrossChannelsCalculatorTest, + IsNoOpOnSingleChannelInputs) { + const TimeSeriesHeader header = ParseTextProtoOrDie( + "sample_rate: 8000.0 num_channels: 1 num_samples: 5"); + const Matrix input = + Matrix::Random(header.num_channels(), header.num_samples()); + + Test(header, {input}, header, {input}); +} + +TEST_F(AverageTimeSeriesAcrossChannelsCalculatorTest, ConstantPacket) { + const TimeSeriesHeader header = ParseTextProtoOrDie( + "sample_rate: 8000.0 num_channels: 3 num_samples: 5"); + TimeSeriesHeader output_header(header); + output_header.set_num_channels(1); + + Matrix input = + Matrix::Constant(header.num_channels(), header.num_samples(), 0.0); + input.row(0) = Matrix::Constant(1, header.num_samples(), 1.0); + + Test( + header, {input}, output_header, + {Matrix::Constant(1, header.num_samples(), 1.0 / header.num_channels())}); +} + +class SummarySaiToPitchogramCalculatorTest + : public BasicTimeSeriesCalculatorTestBase { + protected: + void SetUp() override { + calculator_name_ = "SummarySaiToPitchogramCalculator"; + } +}; + +TEST_F(SummarySaiToPitchogramCalculatorTest, SinglePacket) { + const 
TimeSeriesHeader input_header = ParseTextProtoOrDie( + "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 1 num_samples: 3"); + Matrix input(1, input_header.num_samples()); + input << 3, -9, 4; + + const TimeSeriesHeader output_header = ParseTextProtoOrDie( + "sample_rate: 5.0 packet_rate: 5.0 num_channels: 3 num_samples: 1"); + Matrix output(input_header.num_samples(), 1); + output << 3, -9, 4; + + Test(input_header, {input}, output_header, {output}); +} + +class ReverseChannelOrderCalculatorTest + : public BasicTimeSeriesCalculatorTestBase { + protected: + void SetUp() override { calculator_name_ = "ReverseChannelOrderCalculator"; } +}; + +TEST_F(ReverseChannelOrderCalculatorTest, IsNoOpOnSingleChannelInputs) { + const TimeSeriesHeader header = ParseTextProtoOrDie( + "sample_rate: 8000.0 num_channels: 1 num_samples: 5"); + const Matrix input = + Matrix::Random(header.num_channels(), header.num_samples()); + + Test(header, {input}, header, {input}); +} + +TEST_F(ReverseChannelOrderCalculatorTest, SinglePacket) { + const TimeSeriesHeader header = ParseTextProtoOrDie( + "sample_rate: 8000.0 num_channels: 5 num_samples: 2"); + Matrix input(header.num_channels(), header.num_samples()); + input.transpose() << 1, 2, 3, 4, 5, -1, -2, -3, -4, -5; + Matrix output(header.num_channels(), header.num_samples()); + output.transpose() << 5, 4, 3, 2, 1, -5, -4, -3, -2, -1; + + Test(header, {input}, header, {output}); +} + +class FlattenPacketCalculatorTest : public BasicTimeSeriesCalculatorTestBase { + protected: + void SetUp() override { calculator_name_ = "FlattenPacketCalculator"; } +}; + +TEST_F(FlattenPacketCalculatorTest, SinglePacket) { + const TimeSeriesHeader input_header = ParseTextProtoOrDie( + "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 2"); + Matrix input(input_header.num_channels(), input_header.num_samples()); + input.transpose() << 1, 2, 3, 4, 5, -1, -2, -3, -4, -5; + Matrix output(10, 1); + output << 1, 2, 3, 4, 5, -1, -2, -3, -4, -5; + + const TimeSeriesHeader output_header = ParseTextProtoOrDie( + "sample_rate: 10.0 packet_rate: 10.0 num_channels: 10 num_samples: 1"); + Test(input_header, {input}, output_header, {output}); +} + +class SubtractMeanCalculatorTest : public BasicTimeSeriesCalculatorTestBase { + protected: + void SetUp() override { calculator_name_ = "SubtractMeanCalculator"; } +}; + +TEST_F(SubtractMeanCalculatorTest, SinglePacket) { + const TimeSeriesHeader input_header = ParseTextProtoOrDie( + "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 2"); + Matrix input(input_header.num_channels(), input_header.num_samples()); + Matrix output(input_header.num_channels(), input_header.num_samples()); + + // clang-format off + input.transpose() << 1, 0, 3, 0, 1, + -1, -2, -3, 4, 7; + output.transpose() << 1, 1, 3, -2, -3, + -1, -1, -3, 2, 3; + // clang-format on + + const TimeSeriesHeader output_header = input_header; + Test(input_header, {input}, output_header, {output}); +} + +class SubtractMeanAcrossChannelsCalculatorTest + : public BasicTimeSeriesCalculatorTestBase { + protected: + void SetUp() override { + calculator_name_ = "SubtractMeanAcrossChannelsCalculator"; + } +}; + +TEST_F(SubtractMeanAcrossChannelsCalculatorTest, SinglePacket) { + const TimeSeriesHeader input_header = ParseTextProtoOrDie( + "sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2"); + TimeSeriesHeader output_header(input_header); + output_header.set_num_samples(2); + + Matrix input(input_header.num_channels(), input_header.num_samples()); + Matrix 
output(output_header.num_channels(), output_header.num_samples()); + + // clang-format off + input.transpose() << 1.0, 2.0, 3.0, + 4.0, 5.0, 6.0; + output.transpose() << 1.0 - 3.5, 2.0 - 3.5, 3.0 - 3.5, + 4.0 - 3.5, 5.0 - 3.5, 6.0 - 3.5; + // clang-format on + + Test(input_header, {input}, output_header, {output}); +} + +class DivideByMeanAcrossChannelsCalculatorTest + : public BasicTimeSeriesCalculatorTestBase { + protected: + void SetUp() override { + calculator_name_ = "DivideByMeanAcrossChannelsCalculator"; + } +}; + +TEST_F(DivideByMeanAcrossChannelsCalculatorTest, SinglePacket) { + const TimeSeriesHeader input_header = ParseTextProtoOrDie( + "sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2"); + Matrix input(input_header.num_channels(), input_header.num_samples()); + input.transpose() << 1.0, 2.0, 3.0, 4.0, 5.0, 6.0; + + TimeSeriesHeader output_header(input_header); + output_header.set_num_samples(2); + Matrix output(output_header.num_channels(), output_header.num_samples()); + output.transpose() << 1.0 / 3.5, 2.0 / 3.5, 3.0 / 3.5, 4.0 / 3.5, 5.0 / 3.5, + 6.0 / 3.5; + + Test(input_header, {input}, output_header, {output}); +} + +TEST_F(DivideByMeanAcrossChannelsCalculatorTest, ReturnsOneForZeroMean) { + const TimeSeriesHeader input_header = ParseTextProtoOrDie( + "sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2"); + Matrix input(input_header.num_channels(), input_header.num_samples()); + input.transpose() << -3.0, -2.0, -1.0, 1.0, 2.0, 3.0; + + TimeSeriesHeader output_header(input_header); + output_header.set_num_samples(2); + Matrix output(output_header.num_channels(), output_header.num_samples()); + output.transpose() << 1.0, 1.0, 1.0, 1.0, 1.0, 1.0; + + Test(input_header, {input}, output_header, {output}); +} + +class MeanCalculatorTest : public BasicTimeSeriesCalculatorTestBase { + protected: + void SetUp() override { calculator_name_ = "MeanCalculator"; } +}; + +TEST_F(MeanCalculatorTest, SinglePacket) { + const TimeSeriesHeader input_header = ParseTextProtoOrDie( + "sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2"); + Matrix input(input_header.num_channels(), input_header.num_samples()); + input.transpose() << 1.0, 2.0, 3.0, 4.0, 5.0, 6.0; + + TimeSeriesHeader output_header(input_header); + output_header.set_num_samples(1); + output_header.set_sample_rate(10.0); + Matrix output(output_header.num_channels(), output_header.num_samples()); + output << (1.0 + 4.0) / 2, (2.0 + 5.0) / 2, (3.0 + 6.0) / 2; + + Test(input_header, {input}, output_header, {output}); +} + +class StandardDeviationCalculatorTest + : public BasicTimeSeriesCalculatorTestBase { + protected: + void SetUp() override { calculator_name_ = "StandardDeviationCalculator"; } +}; + +TEST_F(StandardDeviationCalculatorTest, SinglePacket) { + const TimeSeriesHeader input_header = ParseTextProtoOrDie( + "sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2"); + Matrix input(input_header.num_channels(), input_header.num_samples()); + input.transpose() << 0.0, 2.0, 3.0, 4.0, 5.0, 8.0; + + TimeSeriesHeader output_header(input_header); + output_header.set_sample_rate(10.0); + output_header.set_num_samples(1); + Matrix output(output_header.num_channels(), output_header.num_samples()); + output << sqrt((pow(0.0 - 2.0, 2) + pow(4.0 - 2.0, 2)) / 2), + sqrt((pow(2.0 - 3.5, 2) + pow(5.0 - 3.5, 2)) / 2), + sqrt((pow(3.0 - 5.5, 2) + pow(8.0 - 5.5, 2)) / 2); + + Test(input_header, {input}, output_header, {output}); +} + +class CovarianceCalculatorTest : public 
BasicTimeSeriesCalculatorTestBase { + protected: + void SetUp() override { calculator_name_ = "CovarianceCalculator"; } +}; + +TEST_F(CovarianceCalculatorTest, SinglePacket) { + const TimeSeriesHeader input_header = ParseTextProtoOrDie( + "sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2"); + Matrix input(input_header.num_channels(), input_header.num_samples()); + + // We'll specify in transposed form so we can write one channel at a time. + input << 1.0, 3.0, 5.0, 9.0, -1.0, -3.0; + + TimeSeriesHeader output_header(input_header); + output_header.set_num_samples(output_header.num_channels()); + Matrix output(output_header.num_channels(), output_header.num_samples()); + output << 1, 2, -1, 2, 4, -2, -1, -2, 1; + Test(input_header, {input}, output_header, {output}); +} + +class L2NormCalculatorTest : public BasicTimeSeriesCalculatorTestBase { + protected: + void SetUp() override { calculator_name_ = "L2NormCalculator"; } +}; + +TEST_F(L2NormCalculatorTest, SinglePacket) { + const TimeSeriesHeader input_header = ParseTextProtoOrDie( + "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3"); + Matrix input(input_header.num_channels(), input_header.num_samples()); + input << 3, 5, 8, 4, 12, -15; + + const TimeSeriesHeader output_header = ParseTextProtoOrDie( + "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 1 num_samples: 3"); + Matrix output(output_header.num_channels(), output_header.num_samples()); + output << 5, 13, 17; + + Test(input_header, {input}, output_header, {output}); +} + +class L2NormalizeColumnCalculatorTest + : public BasicTimeSeriesCalculatorTestBase { + protected: + void SetUp() override { calculator_name_ = "L2NormalizeColumnCalculator"; } +}; + +TEST_F(L2NormalizeColumnCalculatorTest, SinglePacket) { + const TimeSeriesHeader input_header = ParseTextProtoOrDie( + "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3"); + Matrix input(input_header.num_channels(), input_header.num_samples()); + input << 0.3, 0.4, 0.8, 0.5, 0.9, 0.8; + + const TimeSeriesHeader output_header = ParseTextProtoOrDie( + "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3"); + Matrix output(output_header.num_channels(), output_header.num_samples()); + + // The values in output are column-wise L2 normalized + // e.g. + // |a| -> |a/sqrt(a^2 + b^2)| + // |b| |b/sqrt(a^2 + b^2)| + output << 0.51449579000473022, 0.40613847970962524, 0.70710676908493042, + 0.85749292373657227, 0.91381156444549561, 0.70710676908493042; + + Test(input_header, {input}, output_header, {output}); +} + +class L2NormalizeCalculatorTest : public BasicTimeSeriesCalculatorTestBase { + protected: + void SetUp() override { calculator_name_ = "L2NormalizeCalculator"; } +}; + +TEST_F(L2NormalizeCalculatorTest, SinglePacket) { + const TimeSeriesHeader input_header = ParseTextProtoOrDie( + "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3"); + Matrix input(input_header.num_channels(), input_header.num_samples()); + input << 0.3, 0.4, 0.8, 0.5, 0.9, 0.8; + + const TimeSeriesHeader output_header = ParseTextProtoOrDie( + "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3"); + Matrix output(output_header.num_channels(), output_header.num_samples()); + + // The values in output are L2 normalized + // a -> a/sqrt(a^2 + b^2 + c^2 + ...) 
* sqrt(matrix.cols()*matrix.rows()) + output << 0.45661166, 0.60881555, 1.21763109, 0.76101943, 1.36983498, + 1.21763109; + + Test(input_header, {input}, output_header, {output}); +} + +TEST_F(L2NormalizeCalculatorTest, UnitMatrixStaysUnchanged) { + const TimeSeriesHeader input_header = ParseTextProtoOrDie( + "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 3 num_samples: 5"); + Matrix input(input_header.num_channels(), input_header.num_samples()); + input << 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, + 1.0, -1.0, 1.0; + + Test(input_header, {input}, input_header, {input}); +} + +class PeakNormalizeCalculatorTest : public BasicTimeSeriesCalculatorTestBase { + protected: + void SetUp() override { calculator_name_ = "PeakNormalizeCalculator"; } +}; + +TEST_F(PeakNormalizeCalculatorTest, SinglePacket) { + const TimeSeriesHeader input_header = ParseTextProtoOrDie( + "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3"); + Matrix input(input_header.num_channels(), input_header.num_samples()); + input << 0.3, 0.4, 0.8, 0.5, 0.9, 0.8; + + const TimeSeriesHeader output_header = ParseTextProtoOrDie( + "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3"); + Matrix output(output_header.num_channels(), output_header.num_samples()); + output << 0.33333333, 0.44444444, 0.88888889, 0.55555556, 1.0, 0.88888889; + + Test(input_header, {input}, output_header, {output}); +} + +TEST_F(PeakNormalizeCalculatorTest, UnitMatrixStaysUnchanged) { + const TimeSeriesHeader input_header = ParseTextProtoOrDie( + "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 3 num_samples: 5"); + Matrix input(input_header.num_channels(), input_header.num_samples()); + input << 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, + 1.0, -1.0, 1.0; + + Test(input_header, {input}, input_header, {input}); +} + +class ElementwiseSquareCalculatorTest + : public BasicTimeSeriesCalculatorTestBase { + protected: + void SetUp() override { calculator_name_ = "ElementwiseSquareCalculator"; } +}; + +TEST_F(ElementwiseSquareCalculatorTest, SinglePacket) { + const TimeSeriesHeader input_header = ParseTextProtoOrDie( + "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3"); + Matrix input(input_header.num_channels(), input_header.num_samples()); + input << 3, 5, 8, 4, 12, -15; + + const TimeSeriesHeader output_header = ParseTextProtoOrDie( + "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3"); + Matrix output(output_header.num_channels(), output_header.num_samples()); + output << 9, 25, 64, 16, 144, 225; + + Test(input_header, {input}, output_header, {output}); +} + +class FirstHalfSlicerCalculatorTest : public BasicTimeSeriesCalculatorTestBase { + protected: + void SetUp() override { calculator_name_ = "FirstHalfSlicerCalculator"; } +}; + +TEST_F(FirstHalfSlicerCalculatorTest, SinglePacketEvenNumSamples) { + const TimeSeriesHeader input_header = ParseTextProtoOrDie( + "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 2"); + Matrix input(input_header.num_channels(), input_header.num_samples()); + // clang-format off + input.transpose() << 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9; + // clang-format on + + const TimeSeriesHeader output_header = ParseTextProtoOrDie( + "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 1"); + Matrix output(output_header.num_channels(), output_header.num_samples()); + output.transpose() << 0, 1, 2, 3, 4; + + Test(input_header, {input}, output_header, {output}); +} + 
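As a cross-check on the hard-coded expectations in the L2NormalizeCalculator and PeakNormalizeCalculator tests above, the constants can be reproduced with a few lines of Eigen. The sketch below is illustrative only (it is not part of the patch) and assumes the two calculators scale by the whole-matrix Frobenius norm times sqrt(rows*cols) and by the maximum absolute value, respectively, as the in-test comments suggest.

#include <cmath>
#include <iostream>

#include "Eigen/Core"

int main() {
  Eigen::MatrixXf m(2, 3);
  m << 0.3f, 0.4f, 0.8f,
       0.5f, 0.9f, 0.8f;
  // L2NormalizeCalculator expectation: x -> x / ||M||_F * sqrt(rows * cols).
  const Eigen::MatrixXf l2 =
      m / m.norm() * std::sqrt(static_cast<float>(m.rows() * m.cols()));
  // PeakNormalizeCalculator expectation: x -> x / max|x|.
  const Eigen::MatrixXf peak = m / m.cwiseAbs().maxCoeff();
  std::cout << l2 << "\n\n" << peak << "\n";  // Matches the expected matrices.
  return 0;
}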
+TEST_F(FirstHalfSlicerCalculatorTest, SinglePacketOddNumSamples) { + const TimeSeriesHeader input_header = ParseTextProtoOrDie( + "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 3"); + Matrix input(input_header.num_channels(), input_header.num_samples()); + // clang-format off + input.transpose() << 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, + 0, 0, 0, 0, 0; + // clang-format on + + const TimeSeriesHeader output_header = ParseTextProtoOrDie( + "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 1"); + Matrix output(output_header.num_channels(), output_header.num_samples()); + output.transpose() << 0, 1, 2, 3, 4; + + Test(input_header, {input}, output_header, {output}); +} + +TEST_F(FirstHalfSlicerCalculatorTest, MultiplePackets) { + const TimeSeriesHeader input_header = ParseTextProtoOrDie( + "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 2"); + Matrix input(input_header.num_channels(), input_header.num_samples()); + // clang-format off + input.transpose() << 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9; + // clang-format on + const TimeSeriesHeader output_header = ParseTextProtoOrDie( + "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 1"); + Matrix output(output_header.num_channels(), output_header.num_samples()); + output.transpose() << 0, 1, 2, 3, 4; + + Test(input_header, + {input, 2 * input, + input + Matrix::Constant(input.rows(), input.cols(), 3.5f)}, + output_header, + {output, 2 * output, + output + Matrix::Constant(output.rows(), output.cols(), 3.5f)}); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/audio/mfcc_mel_calculators.cc b/mediapipe/calculators/audio/mfcc_mel_calculators.cc new file mode 100644 index 000000000..3e6ebbe95 --- /dev/null +++ b/mediapipe/calculators/audio/mfcc_mel_calculators.cc @@ -0,0 +1,278 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// MediaPipe Calculator wrapper around audio/dsp/mfcc/ +// classes MelFilterbank (magnitude spectrograms warped to the Mel +// approximation of the auditory frequency scale) and Mfcc (Mel Frequency +// Cepstral Coefficients, the decorrelated transform of log-Mel-spectrum +// commonly used as acoustic features in speech and other audio tasks. +// Both calculators expect as input the SQUARED_MAGNITUDE-domain outputs +// from the MediaPipe SpectrogramCalculator object. 
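For orientation before the implementation: both calculators defined in this file are thin per-frame wrappers around audio_dsp::MelFilterbank and audio_dsp::Mfcc. A minimal standalone use of those classes might look like the following sketch; it is illustrative only, the numeric parameters are invented, and the call signatures are simply the ones exercised later in this file.

#include <vector>

#include "audio/dsp/mfcc/mel_filterbank.h"
#include "audio/dsp/mfcc/mfcc.h"

void SketchMelAndMfcc() {
  const int kNumSpectrogramBins = 129;      // e.g. N_FFT/2 + 1 bins per frame.
  const double kAudioSampleRate = 16000.0;  // Rate of the original waveform.
  std::vector<double> squared_magnitudes(kNumSpectrogramBins, 1.0);

  // Mel-warped (linear-magnitude) spectrum, as in MelSpectrumCalculator.
  audio_dsp::MelFilterbank filterbank;
  filterbank.Initialize(kNumSpectrogramBins, kAudioSampleRate,
                        /*num_output_channels=*/20,
                        /*min_frequency_hertz=*/125.0,
                        /*max_frequency_hertz=*/3800.0);
  std::vector<double> mel_energies;
  filterbank.Compute(squared_magnitudes, &mel_energies);

  // Mel Frequency Cepstral Coefficients, as in MfccCalculator.
  audio_dsp::Mfcc mfcc;
  mfcc.set_dct_coefficient_count(13);
  mfcc.set_filterbank_channel_count(20);
  mfcc.set_lower_frequency_limit(125.0);
  mfcc.set_upper_frequency_limit(3800.0);
  mfcc.Initialize(kNumSpectrogramBins, kAudioSampleRate);
  std::vector<double> coefficients;
  mfcc.Compute(squared_magnitudes, &coefficients);
}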
+#include +#include + +#include "Eigen/Core" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "audio/dsp/mfcc/mel_filterbank.h" +#include "audio/dsp/mfcc/mfcc.h" +#include "mediapipe/calculators/audio/mfcc_mel_calculators.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/util/time_series_util.h" + +namespace mediapipe { + +namespace { + +// Portable version of TimeSeriesHeader's DebugString. +std::string PortableDebugString(const TimeSeriesHeader& header) { + std::string unsubstituted_header_debug_str = R"( + sample_rate: $0 + num_channels: $1 + num_samples: $2 + packet_rate: $3 + audio_sample_rate: $4 + )"; + return absl::Substitute(unsubstituted_header_debug_str, header.sample_rate(), + header.num_channels(), header.num_samples(), + header.packet_rate(), header.audio_sample_rate()); +} + +} // namespace + +// Abstract base class for Calculators that transform feature vectors on a +// frame-by-frame basis. +// Subclasses must override pure virtual methods ConfigureTransform and +// TransformFrame. +// Input and output MediaPipe packets are matrices with one column per frame, +// and one row per feature dimension. Each input packet results in an +// output packet with the same number of columns (but differing numbers of +// rows corresponding to the new feature space). +class FramewiseTransformCalculatorBase : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set( + // Sequence of Matrices, each column describing a particular time frame, + // each row a feature dimension, with TimeSeriesHeader. + ); + cc->Outputs().Index(0).Set( + // Sequence of Matrices, each column describing a particular time frame, + // each row a feature dimension, with TimeSeriesHeader. + ); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + int num_output_channels(void) { return num_output_channels_; } + + void set_num_output_channels(int num_output_channels) { + num_output_channels_ = num_output_channels; + } + + private: + // Takes header and options, and sets up state including calling + // set_num_output_channels() on the base object. + virtual ::mediapipe::Status ConfigureTransform( + const TimeSeriesHeader& header, const CalculatorOptions& options) = 0; + + // Takes a vector corresponding to an input frame, and + // perform the specific transformation to produce an output frame. 
+ virtual void TransformFrame(const std::vector& input, + std::vector* output) const = 0; + + private: + int num_output_channels_; +}; + +::mediapipe::Status FramewiseTransformCalculatorBase::Open( + CalculatorContext* cc) { + TimeSeriesHeader input_header; + RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid( + cc->Inputs().Index(0).Header(), &input_header)); + + ::mediapipe::Status status = ConfigureTransform(input_header, cc->Options()); + + auto output_header = new TimeSeriesHeader(input_header); + output_header->set_num_channels(num_output_channels_); + cc->Outputs().Index(0).SetHeader(Adopt(output_header)); + + return status; +} + +::mediapipe::Status FramewiseTransformCalculatorBase::Process( + CalculatorContext* cc) { + const Matrix& input = cc->Inputs().Index(0).Get(); + const int num_frames = input.cols(); + std::unique_ptr output(new Matrix(num_output_channels_, num_frames)); + // The main work here is converting each column of the float Matrix + // into a vector of doubles, which is what our target functions from + // dsp_core consume, and doing the reverse with their output. + std::vector input_frame(input.rows()); + std::vector output_frame(num_output_channels_); + + for (int frame = 0; frame < num_frames; ++frame) { + // Copy input from Eigen::Matrix column to vector. + Eigen::Map input_frame_map(&input_frame[0], + input_frame.size(), 1); + input_frame_map = input.col(frame).cast(); + + // Perform the actual transformation. + TransformFrame(input_frame, &output_frame); + + // Copy output from vector to Eigen::Vector. + CHECK_EQ(output_frame.size(), num_output_channels_); + Eigen::Map output_frame_map(&output_frame[0], + output_frame.size(), 1); + output->col(frame) = output_frame_map.cast(); + } + cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + + return ::mediapipe::OkStatus(); +} + +// Calculator wrapper around the dsp/mfcc/mfcc.cc routine. +// Take frames of squared-magnitude spectra from the SpectrogramCalculator +// and convert them into Mel Frequency Cepstral Coefficients. +// +// Example config: +// node { +// calculator: "MfccCalculator" +// input_stream: "spectrogram_frames_stream" +// output_stream: "mfcc_frames_stream" +// options { +// [mediapipe.MfccCalculatorOptions.ext] { +// mel_spectrum_params { +// channel_count: 20 +// min_frequency_hertz: 125.0 +// max_frequency_hertz: 3800.0 +// } +// mfcc_count: 13 +// } +// } +// } +class MfccCalculator : public FramewiseTransformCalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + return FramewiseTransformCalculatorBase::GetContract(cc); + } + + private: + ::mediapipe::Status ConfigureTransform( + const TimeSeriesHeader& header, + const CalculatorOptions& options) override { + MfccCalculatorOptions mfcc_options; + time_series_util::FillOptionsExtensionOrDie(options, &mfcc_options); + mfcc_.reset(new audio_dsp::Mfcc()); + int input_length = header.num_channels(); + // Set up the parameters to the Mfcc object. + set_num_output_channels(mfcc_options.mfcc_count()); + mfcc_->set_dct_coefficient_count(num_output_channels()); + mfcc_->set_upper_frequency_limit( + mfcc_options.mel_spectrum_params().max_frequency_hertz()); + mfcc_->set_lower_frequency_limit( + mfcc_options.mel_spectrum_params().min_frequency_hertz()); + mfcc_->set_filterbank_channel_count( + mfcc_options.mel_spectrum_params().channel_count()); + // An upstream calculator (such as SpectrogramCalculator) must store + // the sample rate of its input audio waveform in the TimeSeries Header. 
+ // audio_dsp::MelFilterBank needs to know this to + // correctly interpret the spectrogram bins. + if (!header.has_audio_sample_rate()) { + return ::mediapipe::InvalidArgumentError( + absl::StrCat("No audio_sample_rate in input TimeSeriesHeader ", + PortableDebugString(header))); + } + // Now we can initialize the Mfcc object. + bool initialized = + mfcc_->Initialize(input_length, header.audio_sample_rate()); + + if (initialized) { + return ::mediapipe::OkStatus(); + } else { + return ::mediapipe::Status(mediapipe::StatusCode::kInternal, + "Mfcc::Initialize returned uninitialized"); + } + } + + void TransformFrame(const std::vector& input, + std::vector* output) const override { + mfcc_->Compute(input, output); + } + + private: + std::unique_ptr mfcc_; +}; +REGISTER_CALCULATOR(MfccCalculator); + +// Calculator wrapper around the dsp/mfcc/mel_filterbank.cc routine. +// Take frames of squared-magnitude spectra from the SpectrogramCalculator +// and convert them into Mel-warped (linear-magnitude) spectra. +// Note: This code computes a mel-frequency filterbank, using a simple +// algorithm that gives bad results (some mel channels that are always zero) +// if you ask for too many channels. +class MelSpectrumCalculator : public FramewiseTransformCalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + return FramewiseTransformCalculatorBase::GetContract(cc); + } + + private: + ::mediapipe::Status ConfigureTransform( + const TimeSeriesHeader& header, + const CalculatorOptions& options) override { + MelSpectrumCalculatorOptions mel_spectrum_options; + time_series_util::FillOptionsExtensionOrDie(options, &mel_spectrum_options); + mel_filterbank_.reset(new audio_dsp::MelFilterbank()); + int input_length = header.num_channels(); + set_num_output_channels(mel_spectrum_options.channel_count()); + // An upstream calculator (such as SpectrogramCalculator) must store + // the sample rate of its input audio waveform in the TimeSeries Header. + // audio_dsp::MelFilterBank needs to know this to + // correctly interpret the spectrogram bins. + if (!header.has_audio_sample_rate()) { + return ::mediapipe::InvalidArgumentError( + absl::StrCat("No audio_sample_rate in input TimeSeriesHeader ", + PortableDebugString(header))); + } + bool initialized = mel_filterbank_->Initialize( + input_length, header.audio_sample_rate(), num_output_channels(), + mel_spectrum_options.min_frequency_hertz(), + mel_spectrum_options.max_frequency_hertz()); + + if (initialized) { + return ::mediapipe::OkStatus(); + } else { + return ::mediapipe::Status(mediapipe::StatusCode::kInternal, + "mfcc::Initialize returned uninitialized"); + } + } + + void TransformFrame(const std::vector& input, + std::vector* output) const override { + mel_filterbank_->Compute(input, output); + } + + private: + std::unique_ptr mel_filterbank_; +}; +REGISTER_CALCULATOR(MelSpectrumCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/audio/mfcc_mel_calculators.proto b/mediapipe/calculators/audio/mfcc_mel_calculators.proto new file mode 100644 index 000000000..89af5eb41 --- /dev/null +++ b/mediapipe/calculators/audio/mfcc_mel_calculators.proto @@ -0,0 +1,50 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message MelSpectrumCalculatorOptions { + extend CalculatorOptions { + optional MelSpectrumCalculatorOptions ext = 78581812; + } + // The fields are to populate the config parameters in + // audio/dsp/mfcc/mel_filterbank.h + // but the names are chose to mirror + // audio/hearing/filterbanks/cochlea_gammatone_filterbank.proto + // and the default values match those in + // speech/greco3/frontend/filter_bank.proto . + + // Total number of frequency bands to use. + optional int32 channel_count = 1 [default = 20]; + // Lower edge of lowest triangular Mel band. + optional float min_frequency_hertz = 2 [default = 125.0]; + // Upper edge of highest triangular Mel band. + optional float max_frequency_hertz = 3 [default = 3800.0]; +} + +message MfccCalculatorOptions { + extend CalculatorOptions { + optional MfccCalculatorOptions ext = 78450441; + } + + // Specification of the underlying mel filterbank. + optional MelSpectrumCalculatorOptions mel_spectrum_params = 1; + + // How many MFCC coefficients to emit. + optional uint32 mfcc_count = 2 [default = 13]; +} diff --git a/mediapipe/calculators/audio/mfcc_mel_calculators_test.cc b/mediapipe/calculators/audio/mfcc_mel_calculators_test.cc new file mode 100644 index 000000000..38727e232 --- /dev/null +++ b/mediapipe/calculators/audio/mfcc_mel_calculators_test.cc @@ -0,0 +1,149 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#include "Eigen/Core" +#include "mediapipe/calculators/audio/mfcc_mel_calculators.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/util/time_series_test_util.h" + +namespace mediapipe { + +// Use a sample rate that is unlikely to be a default somewhere. +const float kAudioSampleRate = 8800.0; + +template +class FramewiseTransformCalculatorTest + : public TimeSeriesCalculatorTest { + protected: + void SetUp() override { + this->calculator_name_ = CalculatorName; + this->num_input_channels_ = 129; + // This is the frame rate coming out of the SpectrogramCalculator. + this->input_sample_rate_ = 100.0; + } + + // Returns the number of samples per packet. 
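+  // (With the constants used below -- kSecondsPerPacket = 0.2 and an input
+  // frame rate of 100.0 Hz -- every packet carries 20 columns of 129
+  // squared-magnitude bins, and consecutive packets are 200 ms apart.)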
+ int GenerateRandomNonnegInputStream(int num_packets) { + const double kSecondsPerPacket = 0.2; + const int num_samples_per_packet = + kSecondsPerPacket * this->input_sample_rate_; + for (int i = 0; i < num_packets; ++i) { + const int timestamp = + i * kSecondsPerPacket * Timestamp::kTimestampUnitsPerSecond; + // Mfcc, MelSpectrum expect squared-magnitude inputs, so make + // sure the input data has no negative values. + Matrix* sqdata = this->NewRandomMatrix(this->num_input_channels_, + num_samples_per_packet); + *sqdata = sqdata->array().square(); + this->AppendInputPacket(sqdata, timestamp); + } + return num_samples_per_packet; + } + + void CheckOutputPacketMetadata(int expected_num_channels, + int expected_num_samples_per_packet) { + int expected_timestamp = 0; + for (const auto& packet : this->output().packets) { + EXPECT_EQ(expected_timestamp, packet.Timestamp().Value()); + expected_timestamp += expected_num_samples_per_packet / + this->input_sample_rate_ * + Timestamp::kTimestampUnitsPerSecond; + + const Matrix& output_matrix = packet.template Get(); + + EXPECT_EQ(output_matrix.rows(), expected_num_channels); + EXPECT_EQ(output_matrix.cols(), expected_num_samples_per_packet); + } + } + + void SetupGraphAndHeader() { + this->InitializeGraph(); + this->FillInputHeader(); + } + + // Argument is the expected number of dimensions (channels, columns) in + // the output data from the Calculator under test, which the test should + // know. + void SetupRandomInputPackets() { + constexpr int kNumPackets = 5; + num_samples_per_packet_ = GenerateRandomNonnegInputStream(kNumPackets); + } + + ::mediapipe::Status Run() { return this->RunGraph(); } + + void CheckResults(int expected_num_channels) { + const auto& output_header = + this->output().header.template Get(); + EXPECT_EQ(this->input_sample_rate_, output_header.sample_rate()); + CheckOutputPacketMetadata(expected_num_channels, num_samples_per_packet_); + + // Sanity check that output packets have non-zero energy. + for (const auto& packet : this->output().packets) { + const Matrix& data = packet.template Get(); + EXPECT_GT(data.squaredNorm(), 0); + } + } + + // Allows SetupRandomInputPackets() to inform CheckResults() about how + // big the packets are supposed to be. + int num_samples_per_packet_; +}; + +constexpr char kMfccCalculator[] = "MfccCalculator"; +typedef FramewiseTransformCalculatorTest + MfccCalculatorTest; +TEST_F(MfccCalculatorTest, AudioSampleRateFromInputHeader) { + audio_sample_rate_ = kAudioSampleRate; + SetupGraphAndHeader(); + SetupRandomInputPackets(); + + MEDIAPIPE_EXPECT_OK(Run()); + + CheckResults(options_.mfcc_count()); +} +TEST_F(MfccCalculatorTest, NoAudioSampleRate) { + // Leave audio_sample_rate_ == kUnset, so it is not present in the + // input TimeSeriesHeader; expect failure. + SetupGraphAndHeader(); + SetupRandomInputPackets(); + + EXPECT_FALSE(Run().ok()); +} + +constexpr char kMelSpectrumCalculator[] = "MelSpectrumCalculator"; +typedef FramewiseTransformCalculatorTest + MelSpectrumCalculatorTest; +TEST_F(MelSpectrumCalculatorTest, AudioSampleRateFromInputHeader) { + audio_sample_rate_ = kAudioSampleRate; + SetupGraphAndHeader(); + SetupRandomInputPackets(); + + MEDIAPIPE_EXPECT_OK(Run()); + + CheckResults(options_.channel_count()); +} +TEST_F(MelSpectrumCalculatorTest, NoAudioSampleRate) { + // Leave audio_sample_rate_ == kUnset, so it is not present in the + // input TimeSeriesHeader; expect failure. 
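+  // (ConfigureTransform() returns InvalidArgumentError when audio_sample_rate
+  // is missing from the header, so the run is expected to fail.)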
+ SetupGraphAndHeader(); + SetupRandomInputPackets(); + + EXPECT_FALSE(Run().ok()); +} +} // namespace mediapipe diff --git a/mediapipe/calculators/audio/rational_factor_resample_calculator.cc b/mediapipe/calculators/audio/rational_factor_resample_calculator.cc new file mode 100644 index 000000000..3a966f8f8 --- /dev/null +++ b/mediapipe/calculators/audio/rational_factor_resample_calculator.cc @@ -0,0 +1,197 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Defines RationalFactorResampleCalculator. + +#include "mediapipe/calculators/audio/rational_factor_resample_calculator.h" + +#include "audio/dsp/resampler_rational_factor.h" + +using audio_dsp::DefaultResamplingKernel; +using audio_dsp::RationalFactorResampler; +using audio_dsp::Resampler; + +namespace mediapipe { +::mediapipe::Status RationalFactorResampleCalculator::Process( + CalculatorContext* cc) { + return ProcessInternal(cc->Inputs().Index(0).Get(), false, cc); +} + +::mediapipe::Status RationalFactorResampleCalculator::Close( + CalculatorContext* cc) { + if (initial_timestamp_ == Timestamp::Unstarted()) { + return ::mediapipe::OkStatus(); + } + Matrix empty_input_frame(num_channels_, 0); + return ProcessInternal(empty_input_frame, true, cc); +} + +namespace { +void CopyChannelToVector(const Matrix& matrix, int channel, + std::vector* vec) { + vec->clear(); + vec->reserve(matrix.cols()); + for (int sample = 0; sample < matrix.cols(); ++sample) { + vec->push_back(matrix(channel, sample)); + } +} + +void CopyVectorToChannel(const std::vector& vec, Matrix* matrix, + int channel) { + if (matrix->cols() == 0) { + matrix->resize(matrix->rows(), vec.size()); + } else { + CHECK_EQ(vec.size(), matrix->cols()); + CHECK_LT(channel, matrix->rows()); + } + for (int sample = 0; sample < matrix->cols(); ++sample) { + (*matrix)(channel, sample) = vec[sample]; + } +} + +} // namespace + +::mediapipe::Status RationalFactorResampleCalculator::Open( + CalculatorContext* cc) { + RationalFactorResampleCalculatorOptions resample_options; + time_series_util::FillOptionsExtensionOrDie(cc->Options(), &resample_options); + + if (!resample_options.has_target_sample_rate()) { + return tool::StatusInvalid( + "resample_options doesn't have target_sample_rate."); + } + target_sample_rate_ = resample_options.target_sample_rate(); + + TimeSeriesHeader input_header; + RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid( + cc->Inputs().Index(0).Header(), &input_header)); + + source_sample_rate_ = input_header.sample_rate(); + num_channels_ = input_header.num_channels(); + + // Don't create resamplers for pass-thru (sample rates are equal). 
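+  // Otherwise one independent Resampler is created per channel, so that each
+  // row of the input Matrix can be resampled in isolation in Resample().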
+ if (source_sample_rate_ != target_sample_rate_) { + resampler_.resize(num_channels_); + for (auto& r : resampler_) { + r = ResamplerFromOptions(source_sample_rate_, target_sample_rate_, + resample_options); + if (!r) { + LOG(ERROR) << "Failed to initialize resampler."; + return ::mediapipe::UnknownError("Failed to initialize resampler."); + } + } + } + + TimeSeriesHeader* output_header = new TimeSeriesHeader(input_header); + output_header->set_sample_rate(target_sample_rate_); + // The resampler doesn't make guarantees about how many samples will + // be in each packet. + output_header->clear_packet_rate(); + output_header->clear_num_samples(); + + cc->Outputs().Index(0).SetHeader(Adopt(output_header)); + cumulative_output_samples_ = 0; + cumulative_input_samples_ = 0; + initial_timestamp_ = Timestamp::Unstarted(); + check_inconsistent_timestamps_ = + resample_options.check_inconsistent_timestamps(); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status RationalFactorResampleCalculator::ProcessInternal( + const Matrix& input_frame, bool should_flush, CalculatorContext* cc) { + if (initial_timestamp_ == Timestamp::Unstarted()) { + initial_timestamp_ = cc->InputTimestamp(); + } + + if (check_inconsistent_timestamps_) { + time_series_util::LogWarningIfTimestampIsInconsistent( + cc->InputTimestamp(), initial_timestamp_, cumulative_input_samples_, + source_sample_rate_); + } + Timestamp output_timestamp = + initial_timestamp_ + ((cumulative_output_samples_ / target_sample_rate_) * + Timestamp::kTimestampUnitsPerSecond); + + cumulative_input_samples_ += input_frame.cols(); + std::unique_ptr output_frame(new Matrix(num_channels_, 0)); + if (resampler_.empty()) { + // Sample rates were same for input and output; pass-thru. + *output_frame = input_frame; + } else { + if (!Resample(input_frame, output_frame.get(), should_flush)) { + return ::mediapipe::UnknownError("Resample() failed."); + } + } + cumulative_output_samples_ += output_frame->cols(); + + if (output_frame->cols() > 0) { + cc->Outputs().Index(0).Add(output_frame.release(), output_timestamp); + } + return ::mediapipe::OkStatus(); +} + +bool RationalFactorResampleCalculator::Resample(const Matrix& input_frame, + Matrix* output_frame, + bool should_flush) { + std::vector input_vector; + std::vector output_vector; + for (int i = 0; i < input_frame.rows(); ++i) { + CopyChannelToVector(input_frame, i, &input_vector); + if (should_flush) { + resampler_[i]->Flush(&output_vector); + } else { + resampler_[i]->ProcessSamples(input_vector, &output_vector); + } + CopyVectorToChannel(output_vector, output_frame, i); + } + return true; +} + +// static +std::unique_ptr> +RationalFactorResampleCalculator::ResamplerFromOptions( + const double source_sample_rate, const double target_sample_rate, + const RationalFactorResampleCalculatorOptions& options) { + std::unique_ptr> resampler; + const auto& rational_factor_options = + options.resampler_rational_factor_options(); + std::unique_ptr kernel; + if (rational_factor_options.has_radius() && + rational_factor_options.has_cutoff() && + rational_factor_options.has_kaiser_beta()) { + kernel = absl::make_unique( + source_sample_rate, target_sample_rate, + rational_factor_options.radius(), rational_factor_options.cutoff(), + rational_factor_options.kaiser_beta()); + } else { + kernel = absl::make_unique(source_sample_rate, + target_sample_rate); + } + + // Set large enough so that the resampling factor between common sample + // rates (e.g. 
8kHz, 16kHz, 22.05kHz, 32kHz, 44.1kHz, 48kHz) is exact, and + // that any factor is represented with error less than 0.025%. + const int kMaxDenominator = 2000; + resampler = absl::make_unique>( + *kernel, kMaxDenominator); + if (resampler != nullptr && !resampler->Valid()) { + resampler = std::unique_ptr>(); + } + return resampler; +} + +REGISTER_CALCULATOR(RationalFactorResampleCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/audio/rational_factor_resample_calculator.h b/mediapipe/calculators/audio/rational_factor_resample_calculator.h new file mode 100644 index 000000000..745ac8f0d --- /dev/null +++ b/mediapipe/calculators/audio/rational_factor_resample_calculator.h @@ -0,0 +1,106 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_AUDIO_RATIONAL_FACTOR_RESAMPLE_CALCULATOR_H_ +#define MEDIAPIPE_CALCULATORS_AUDIO_RATIONAL_FACTOR_RESAMPLE_CALCULATOR_H_ + +#include +#include +#include + +#include "Eigen/Core" +#include "absl/strings/str_cat.h" +#include "audio/dsp/resampler.h" +#include "mediapipe/calculators/audio/rational_factor_resample_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/util/time_series_util.h" + +namespace mediapipe { +// MediaPipe Calculator for resampling a (vector-valued) +// input time series with a uniform sample rate. The output +// stream's sampling rate is specified by target_sample_rate in the +// RationalFactorResampleCalculatorOptions. The output time series may have +// a varying number of samples per frame. +class RationalFactorResampleCalculator : public CalculatorBase { + public: + struct TestAccess; + + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set( + // Single input stream with TimeSeriesHeader. + ); + cc->Outputs().Index(0).Set( + // Resampled stream with TimeSeriesHeader. + ); + return ::mediapipe::OkStatus(); + } + // Returns FAIL if the input stream header is invalid or if the + // resampler cannot be initialized. + ::mediapipe::Status Open(CalculatorContext* cc) override; + // Resamples a packet of TimeSeries data. Returns FAIL if the + // resampler state becomes inconsistent. + ::mediapipe::Status Process(CalculatorContext* cc) override; + // Flushes any remaining state. Returns FAIL if the resampler state + // becomes inconsistent. + ::mediapipe::Status Close(CalculatorContext* cc) override; + + protected: + typedef audio_dsp::Resampler ResamplerType; + + // Returns a Resampler implementation specified by the + // RationalFactorResampleCalculatorOptions proto. Returns null if the options + // specify an invalid resampler. 
+ static std::unique_ptr ResamplerFromOptions( + const double source_sample_rate, const double target_sample_rate, + const RationalFactorResampleCalculatorOptions& options); + + // Does Timestamp bookkeeping and resampling common to Process() and + // Close(). Returns FAIL if the resampler state becomes + // inconsistent. + ::mediapipe::Status ProcessInternal(const Matrix& input_frame, + bool should_flush, CalculatorContext* cc); + + // Uses the internal resampler_ objects to actually resample each + // row of the input TimeSeries. Returns false if the resampler + // state becomes inconsistent. + bool Resample(const Matrix& input_frame, Matrix* output_frame, + bool should_flush); + + double source_sample_rate_; + double target_sample_rate_; + int64 cumulative_input_samples_; + int64 cumulative_output_samples_; + Timestamp initial_timestamp_; + bool check_inconsistent_timestamps_; + int num_channels_; + std::vector> resampler_; +}; + +// Test-only access to RationalFactorResampleCalculator methods. +struct RationalFactorResampleCalculator::TestAccess { + static std::unique_ptr ResamplerFromOptions( + const double source_sample_rate, const double target_sample_rate, + const RationalFactorResampleCalculatorOptions& options) { + return RationalFactorResampleCalculator::ResamplerFromOptions( + source_sample_rate, target_sample_rate, options); + } +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_AUDIO_RATIONAL_FACTOR_RESAMPLE_CALCULATOR_H_ diff --git a/mediapipe/calculators/audio/rational_factor_resample_calculator.proto b/mediapipe/calculators/audio/rational_factor_resample_calculator.proto new file mode 100644 index 000000000..6eb36e672 --- /dev/null +++ b/mediapipe/calculators/audio/rational_factor_resample_calculator.proto @@ -0,0 +1,46 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message RationalFactorResampleCalculatorOptions { + extend CalculatorOptions { + optional RationalFactorResampleCalculatorOptions ext = 259760074; + } + + // target_sample_rate is the sample rate, in Hertz, of the output + // stream. Required. Must be greater than 0. + optional double target_sample_rate = 1; + + // Parameters for initializing the RationalFactorResampler. See + // RationalFactorResampler for more details. + message ResamplerRationalFactorOptions { + // Kernel radius in units of input samples. + optional double radius = 1; + // Anti-aliasing cutoff frequency in Hertz. A reasonable setting is + // 0.45 * min(input_sample_rate, output_sample_rate). + optional double cutoff = 2; + // The Kaiser beta parameter for the kernel window. + optional double kaiser_beta = 3 [default = 6.0]; + } + optional ResamplerRationalFactorOptions resampler_rational_factor_options = 2; + + // Set to false to disable checks for jitter in timestamp values. Useful with + // live audio input. 
+ optional bool check_inconsistent_timestamps = 3 [default = true]; +} diff --git a/mediapipe/calculators/audio/rational_factor_resample_calculator_test.cc b/mediapipe/calculators/audio/rational_factor_resample_calculator_test.cc new file mode 100644 index 000000000..38b947517 --- /dev/null +++ b/mediapipe/calculators/audio/rational_factor_resample_calculator_test.cc @@ -0,0 +1,247 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/audio/rational_factor_resample_calculator.h" + +#include + +#include +#include +#include + +#include "Eigen/Core" +#include "audio/dsp/signal_vector_util.h" +#include "mediapipe/calculators/audio/rational_factor_resample_calculator.pb.h" +#include "mediapipe/framework//tool/validate_type.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/util/time_series_test_util.h" + +namespace mediapipe { +namespace { + +const int kInitialTimestampOffsetMilliseconds = 4; + +class RationalFactorResampleCalculatorTest + : public TimeSeriesCalculatorTest { + protected: + void SetUp() override { + calculator_name_ = "RationalFactorResampleCalculator"; + input_sample_rate_ = 4000.0; + num_input_channels_ = 3; + } + + // Expects two vectors whose lengths are almost the same and whose + // elements are equal (for indices that are present in both). + // + // This is useful because the resampler doesn't make precise + // guarantees about its output size. + void ExpectVectorMostlyFloatEq(const std::vector& expected, + const std::vector& actual) { + // Lengths should be close, but don't have to be equal. + ASSERT_NEAR(expected.size(), actual.size(), 1); + for (int i = 0; i < std::min(expected.size(), actual.size()); ++i) { + EXPECT_FLOAT_EQ(expected[i], actual[i]) << " where i=" << i << "."; + } + } + + // Returns a float value with the sample, channel, and timestamp + // separated by a few orders of magnitude, for easy parsing by + // humans. + double TestValue(int sample, int channel, int timestamp_in_microseconds) { + return timestamp_in_microseconds * 100.0 + sample + channel / 10.0; + } + + // Caller takes ownership of the returned value. + Matrix* NewTestFrame(int num_channels, int num_samples, int timestamp) { + auto matrix = new Matrix(num_channels, num_samples); + for (int c = 0; c < num_channels; ++c) { + for (int i = 0; i < num_samples; ++i) { + (*matrix)(c, i) = TestValue(i, c, timestamp); + } + } + return matrix; + } + + // Initializes and runs the test graph. 
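+  // Five input packets of 10, 20, 30, 40 and 50 samples are appended, so a
+  // total of 150 samples is pushed through the graph and a reference copy is
+  // kept in concatenated_input_samples_ for later verification.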
+ ::mediapipe::Status Run(double output_sample_rate) { + options_.set_target_sample_rate(output_sample_rate); + InitializeGraph(); + + FillInputHeader(); + concatenated_input_samples_.resize(num_input_channels_, 0); + num_input_samples_ = 0; + for (int i = 0; i < 5; ++i) { + int packet_size = (i + 1) * 10; + int timestamp = kInitialTimestampOffsetMilliseconds + + num_input_samples_ / input_sample_rate_ * + Timestamp::kTimestampUnitsPerSecond; + Matrix* data_frame = + NewTestFrame(num_input_channels_, packet_size, timestamp); + + // Keep a reference copy of the input. + // + // conservativeResize() is needed here to preserve the existing + // data. Eigen's resize() resizes without preserving data. + concatenated_input_samples_.conservativeResize( + num_input_channels_, num_input_samples_ + packet_size); + concatenated_input_samples_.rightCols(packet_size) = *data_frame; + num_input_samples_ += packet_size; + + AppendInputPacket(data_frame, timestamp); + } + + return RunGraph(); + } + + void CheckOutputLength(double output_sample_rate) { + double factor = output_sample_rate / input_sample_rate_; + + int num_output_samples = 0; + for (const Packet& packet : output().packets) { + num_output_samples += packet.Get().cols(); + } + + // The exact number of expected samples may vary based on the implementation + // of the resampler since the exact value is not an integer. + // TODO: Reduce this offset to + 1 once cl/185829520 is submitted. + const double expected_num_output_samples = num_input_samples_ * factor; + EXPECT_LE(ceil(expected_num_output_samples), num_output_samples); + EXPECT_GE(ceil(expected_num_output_samples) + 11, num_output_samples); + } + + // Checks that output timestamps are consistent with the + // output_sample_rate and output packet sizes. + void CheckOutputPacketTimestamps(double output_sample_rate) { + int num_output_samples = 0; + for (const Packet& packet : output().packets) { + const int expected_timestamp = kInitialTimestampOffsetMilliseconds + + num_output_samples / output_sample_rate * + Timestamp::kTimestampUnitsPerSecond; + EXPECT_NEAR(expected_timestamp, packet.Timestamp().Value(), 1); + num_output_samples += packet.Get().cols(); + } + } + + // Checks that output values from the calculator (which resamples + // packet-by-packet) are consistent with resampling the entire + // signal at once. 
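+  // A fresh verification resampler is built per channel via
+  // TestAccess::ResamplerFromOptions, fed the concatenated input and flushed;
+  // its output is then compared against the concatenation of the output
+  // packets for that channel.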
+ void CheckOutputValues(double output_sample_rate) { + for (int i = 0; i < num_input_channels_; ++i) { + auto verification_resampler = + RationalFactorResampleCalculator::TestAccess::ResamplerFromOptions( + input_sample_rate_, output_sample_rate, options_); + + std::vector input_data; + for (int j = 0; j < num_input_samples_; ++j) { + input_data.push_back(concatenated_input_samples_(i, j)); + } + std::vector expected_resampled_data; + std::vector temp; + verification_resampler->ProcessSamples(input_data, &temp); + audio_dsp::VectorAppend(&expected_resampled_data, temp); + verification_resampler->Flush(&temp); + audio_dsp::VectorAppend(&expected_resampled_data, temp); + std::vector actual_resampled_data; + for (const Packet& packet : output().packets) { + Matrix output_frame_row = packet.Get().row(i); + actual_resampled_data.insert( + actual_resampled_data.end(), &output_frame_row(0), + &output_frame_row(0) + output_frame_row.cols()); + } + + ExpectVectorMostlyFloatEq(expected_resampled_data, actual_resampled_data); + } + } + + void CheckOutputHeaders(double output_sample_rate) { + const TimeSeriesHeader& output_header = + output().header.Get(); + TimeSeriesHeader expected_header; + expected_header.set_sample_rate(output_sample_rate); + expected_header.set_num_channels(num_input_channels_); + EXPECT_THAT(output_header, mediapipe::EqualsProto(expected_header)); + } + + void CheckOutput(double output_sample_rate) { + CheckOutputLength(output_sample_rate); + CheckOutputPacketTimestamps(output_sample_rate); + CheckOutputValues(output_sample_rate); + CheckOutputHeaders(output_sample_rate); + } + + void CheckOutputUnchanged() { + for (int i = 0; i < num_input_channels_; ++i) { + std::vector expected_resampled_data; + for (int j = 0; j < num_input_samples_; ++j) { + expected_resampled_data.push_back(concatenated_input_samples_(i, j)); + } + std::vector actual_resampled_data; + for (const Packet& packet : output().packets) { + Matrix output_frame_row = packet.Get().row(i); + actual_resampled_data.insert( + actual_resampled_data.end(), &output_frame_row(0), + &output_frame_row(0) + output_frame_row.cols()); + } + ExpectVectorMostlyFloatEq(expected_resampled_data, actual_resampled_data); + } + } + + int num_input_samples_; + Matrix concatenated_input_samples_; +}; + +TEST_F(RationalFactorResampleCalculatorTest, Upsample) { + const double kUpsampleRate = input_sample_rate_ * 1.9; + MEDIAPIPE_ASSERT_OK(Run(kUpsampleRate)); + CheckOutput(kUpsampleRate); +} + +TEST_F(RationalFactorResampleCalculatorTest, Downsample) { + const double kDownsampleRate = input_sample_rate_ / 1.9; + MEDIAPIPE_ASSERT_OK(Run(kDownsampleRate)); + CheckOutput(kDownsampleRate); +} + +TEST_F(RationalFactorResampleCalculatorTest, UsesRationalFactorResampler) { + const double kUpsampleRate = input_sample_rate_ * 2; + MEDIAPIPE_ASSERT_OK(Run(kUpsampleRate)); + CheckOutput(kUpsampleRate); +} + +TEST_F(RationalFactorResampleCalculatorTest, PassthroughIfSampleRateUnchanged) { + const double kUpsampleRate = input_sample_rate_; + MEDIAPIPE_ASSERT_OK(Run(kUpsampleRate)); + CheckOutputUnchanged(); +} + +TEST_F(RationalFactorResampleCalculatorTest, FailsOnBadTargetRate) { + ASSERT_FALSE(Run(-999.9).ok()); // Invalid output sample rate. 
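+  // There is no explicit positivity check in Open(); the failure is expected
+  // to come from ResamplerFromOptions() returning null for an unrealizable
+  // rate, which Open() reports as an initialization error.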
+} + +TEST_F(RationalFactorResampleCalculatorTest, DoesNotDieOnEmptyInput) { + options_.set_target_sample_rate(input_sample_rate_); + InitializeGraph(); + FillInputHeader(); + MEDIAPIPE_ASSERT_OK(RunGraph()); + EXPECT_TRUE(output().packets.empty()); +} + +} // anonymous namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/audio/spectrogram_calculator.cc b/mediapipe/calculators/audio/spectrogram_calculator.cc new file mode 100644 index 000000000..15f4c917e --- /dev/null +++ b/mediapipe/calculators/audio/spectrogram_calculator.cc @@ -0,0 +1,425 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Defines SpectrogramCalculator. +#include + +#include +#include +#include +#include + +#include "Eigen/Core" +#include "absl/strings/string_view.h" +#include "audio/dsp/spectrogram/spectrogram.h" +#include "audio/dsp/window_functions.h" +#include "mediapipe/calculators/audio/spectrogram_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/port/core_proto_inc.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/source_location.h" +#include "mediapipe/framework/port/status_builder.h" +#include "mediapipe/util/time_series_util.h" + +namespace mediapipe { + +// MediaPipe Calculator for computing the "spectrogram" (short-time Fourier +// transform squared-magnitude, by default) of a multichannel input +// time series, including optionally overlapping frames. Options are +// specified in SpectrogramCalculatorOptions proto (where names are chosen +// to mirror TimeSeriesFramerCalculator): +// +// Result is a MatrixData record (for single channel input and when the +// allow_multichannel_input flag is false), or a vector of MatrixData records, +// one for each channel (when the allow_multichannel_input flag is set). The +// rows of each spectrogram matrix correspond to the n_fft/2+1 unique complex +// values, or squared/linear/dB magnitudes, depending on the output_type option. +// Each input packet will result in zero or one output packets, each containing +// one Matrix for each channel of the input, where each Matrix has one or more +// columns of spectral values, one for each complete frame of input samples. If +// the input packet contains too few samples to trigger a new output frame, no +// output packet is generated (since zero-length packets are not legal since +// they would result in timestamps that were equal, not strictly increasing). +// +// Output packet Timestamps are set to the beginning of each frame. This is to +// allow calculators downstream from SpectrogramCalculator to have aligned +// Timestamps regardless of a packet's signal length. 
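+//
+// For example (illustrative numbers, not defaults): with 16 kHz input audio,
+// frame_duration_seconds = 0.025 and frame_overlap_seconds = 0.015 give
+// 400-sample analysis frames advancing by 160 samples, so one output column is
+// emitted per 10 ms of input and the output TimeSeriesHeader sample_rate is
+// 16000 / 160 = 100 Hz.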
+// +// Both frame_duration_seconds and frame_overlap_seconds will be +// rounded to the nearest integer number of samples. Conseqently, all output +// frames will be based on the same number of input samples, and each +// analysis frame will advance from its predecessor by the same time step. +class SpectrogramCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set( + // Input stream with TimeSeriesHeader. + ); + + SpectrogramCalculatorOptions spectrogram_options; + time_series_util::FillOptionsExtensionOrDie(cc->Options(), + &spectrogram_options); + + if (!spectrogram_options.allow_multichannel_input()) { + if (spectrogram_options.output_type() == + SpectrogramCalculatorOptions::COMPLEX) { + cc->Outputs().Index(0).Set( + // Complex spectrogram frames with TimeSeriesHeader. + ); + } else { + cc->Outputs().Index(0).Set( + // Spectrogram frames with TimeSeriesHeader. + ); + } + } else { + if (spectrogram_options.output_type() == + SpectrogramCalculatorOptions::COMPLEX) { + cc->Outputs().Index(0).Set>( + // Complex spectrogram frames with MultiStreamTimeSeriesHeader. + ); + } else { + cc->Outputs().Index(0).Set>( + // Spectrogram frames with MultiStreamTimeSeriesHeader. + ); + } + } + return ::mediapipe::OkStatus(); + } + + // Returns FAIL if the input stream header is invalid. + ::mediapipe::Status Open(CalculatorContext* cc) override; + + // Outputs at most one packet consisting of a single Matrix with one or + // more columns containing the spectral values from as many input frames + // as are completed by the input samples. Always returns OK. + ::mediapipe::Status Process(CalculatorContext* cc) override; + + // Performs zero-padding and processing of any remaining samples + // if pad_final_packet is set. + // Returns OK. + ::mediapipe::Status Close(CalculatorContext* cc) override; + + private: + Timestamp CurrentOutputTimestamp() { + // Current output timestamp is the *center* of the next frame to be + // emitted, hence delayed by half a window duration compared to relevant + // input timestamp. + return initial_input_timestamp_ + + round(cumulative_completed_frames_ * frame_step_samples() * + Timestamp::kTimestampUnitsPerSecond / input_sample_rate_); + } + + int frame_step_samples() const { + return frame_duration_samples_ - frame_overlap_samples_; + } + + // Take the next set of input samples, already translated into a + // vector and pass them to the spectrogram object. + // Convert the output of the spectrogram object into a Matrix (or an + // Eigen::MatrixXcf if complex-valued output is requested) and pass to + // MediaPipe output. + ::mediapipe::Status ProcessVector(const Matrix& input_stream, + CalculatorContext* cc); + + // Templated function to process either real- or complex-output spectrogram. + template + ::mediapipe::Status ProcessVectorToOutput( + const Matrix& input_stream, + const OutputMatrixType postprocess_output_fn(const OutputMatrixType&), + CalculatorContext* cc); + + double input_sample_rate_; + bool pad_final_packet_; + int frame_duration_samples_; + int frame_overlap_samples_; + // How many samples we've been passed, used for checking input time stamps. + int64 cumulative_input_samples_; + // How many frames we've emitted, used for calculating output time stamps. + int64 cumulative_completed_frames_; + Timestamp initial_input_timestamp_; + int num_input_channels_; + // How many frequency bins we emit (=N_FFT/2 + 1). + int num_output_channels_; + // Which output type? 
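+  // (One of SpectrogramCalculatorOptions::COMPLEX, SQUARED_MAGNITUDE,
+  // LINEAR_MAGNITUDE or DECIBELS, as dispatched in ProcessVector().)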
+ int output_type_; + // Output type: mono or multichannel. + bool allow_multichannel_input_; + // Vector of Spectrogram objects, one for each channel. + std::vector> spectrogram_generators_; + // Fixed scale factor applied to output values (regardless of type). + double output_scale_; + + static const float kLnPowerToDb; +}; +REGISTER_CALCULATOR(SpectrogramCalculator); + +// Factor to convert ln(magnitude_squared) to deciBels = 10.0/ln(10.0). +const float SpectrogramCalculator::kLnPowerToDb = 4.342944819032518; + +::mediapipe::Status SpectrogramCalculator::Open(CalculatorContext* cc) { + SpectrogramCalculatorOptions spectrogram_options; + time_series_util::FillOptionsExtensionOrDie(cc->Options(), + &spectrogram_options); + + if (spectrogram_options.frame_duration_seconds() <= 0.0) { + ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "Invalid or missing frame_duration_seconds.\n" + "frame_duration_seconds: " + << spectrogram_options.frame_overlap_seconds(); + } + if (spectrogram_options.frame_overlap_seconds() >= + spectrogram_options.frame_duration_seconds()) { + ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "Invalid frame_overlap_seconds.\nframe_overlap_seconds: " + << spectrogram_options.frame_overlap_seconds() + << "\nframe_duration_seconds: " + << spectrogram_options.frame_duration_seconds(); + } + if (spectrogram_options.frame_overlap_seconds() < 0.0) { + ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "Frame_overlap_seconds is < 0.0.\nframe_overlap_seconds: " + << spectrogram_options.frame_overlap_seconds(); + } + + TimeSeriesHeader input_header; + RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid( + cc->Inputs().Index(0).Header(), &input_header)); + + input_sample_rate_ = input_header.sample_rate(); + num_input_channels_ = input_header.num_channels(); + + if (!spectrogram_options.allow_multichannel_input() && + num_input_channels_ != 1) { + ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "The current setting only supports single-channel input. Please set " + "allow_multichannel_input.\n"; + } + + frame_duration_samples_ = + round(spectrogram_options.frame_duration_seconds() * input_sample_rate_); + frame_overlap_samples_ = + round(spectrogram_options.frame_overlap_seconds() * input_sample_rate_); + + pad_final_packet_ = spectrogram_options.pad_final_packet(); + output_type_ = spectrogram_options.output_type(); + allow_multichannel_input_ = spectrogram_options.allow_multichannel_input(); + + output_scale_ = spectrogram_options.output_scale(); + + std::vector window; + switch (spectrogram_options.window_type()) { + case SpectrogramCalculatorOptions::HANN: + audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_, + &window); + break; + case SpectrogramCalculatorOptions::HAMMING: + audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples_, + &window); + break; + } + + // Propagate settings down to the actual Spectrogram object. + spectrogram_generators_.clear(); + for (int i = 0; i < num_input_channels_; i++) { + spectrogram_generators_.push_back( + std::unique_ptr(new audio_dsp::Spectrogram())); + spectrogram_generators_[i]->Initialize(window, frame_step_samples()); + } + + num_output_channels_ = + spectrogram_generators_[0]->output_frequency_channels(); + std::unique_ptr output_header( + new TimeSeriesHeader(input_header)); + // Store the actual sample rate of the input audio in the TimeSeriesHeader + // so that subsequent calculators can figure out the frequency scale of + // our output. 
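+  // (With an N-point FFT the num_output_channels_ = N/2 + 1 rows cover
+  // 0 .. audio_sample_rate/2 Hz in steps of audio_sample_rate/N.)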
+ output_header->set_audio_sample_rate(input_sample_rate_); + // Setup rest of output header. + output_header->set_num_channels(num_output_channels_); + output_header->set_sample_rate(input_sample_rate_ / frame_step_samples()); + // Although we usually generate one output packet for each input + // packet, this might not be true for input packets whose size is smaller + // than the analysis window length. So we clear output_header.packet_rate + // because we can't guarantee a constant packet rate. Similarly, the number + // of output frames per packet depends on the input packet, so we also clear + // output_header.num_samples. + output_header->clear_packet_rate(); + output_header->clear_num_samples(); + if (!spectrogram_options.allow_multichannel_input()) { + cc->Outputs().Index(0).SetHeader(Adopt(output_header.release())); + } else { + std::unique_ptr multichannel_output_header( + new MultiStreamTimeSeriesHeader()); + *multichannel_output_header->mutable_time_series_header() = *output_header; + multichannel_output_header->set_num_streams(num_input_channels_); + cc->Outputs().Index(0).SetHeader( + Adopt(multichannel_output_header.release())); + } + cumulative_completed_frames_ = 0; + initial_input_timestamp_ = Timestamp::Unstarted(); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status SpectrogramCalculator::Process(CalculatorContext* cc) { + if (initial_input_timestamp_ == Timestamp::Unstarted()) { + initial_input_timestamp_ = cc->InputTimestamp(); + } + + const Matrix& input_stream = cc->Inputs().Index(0).Get(); + if (input_stream.rows() != num_input_channels_) { + ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "Number of input channels do not correspond to the number of rows " + << "in the input matrix: " << num_input_channels_ << "channels vs " + << input_stream.rows() << " rows"; + } + + cumulative_input_samples_ += input_stream.cols(); + + return ProcessVector(input_stream, cc); +} + +template +::mediapipe::Status SpectrogramCalculator::ProcessVectorToOutput( + const Matrix& input_stream, + const OutputMatrixType postprocess_output_fn(const OutputMatrixType&), + CalculatorContext* cc) { + std::unique_ptr> spectrogram_matrices( + new std::vector()); + std::vector> output_vectors; + + // Compute a spectrogram for each channel. + int num_output_time_frames; + for (int channel = 0; channel < input_stream.rows(); ++channel) { + output_vectors.clear(); + + // Copy one row (channel) of the input matrix into the std::vector. + std::vector input_vector(input_stream.cols()); + Eigen::Map(&input_vector[0], 1, input_vector.size()) = + input_stream.row(channel); + + if (!spectrogram_generators_[channel]->ComputeSpectrogram( + input_vector, &output_vectors)) { + return ::mediapipe::Status(mediapipe::StatusCode::kInternal, + "Spectrogram returned failure"); + } + if (channel == 0) { + // Record the number of time frames we expect from each channel. + num_output_time_frames = output_vectors.size(); + } else { + RET_CHECK_EQ(output_vectors.size(), num_output_time_frames) + << "Inconsistent spectrogram time frames for channel " << channel; + } + // Skip remaining processing if there are too few input samples to trigger + // any output frames. + if (!output_vectors.empty()) { + // Translate the returned values into a matrix of output frames. 
+ OutputMatrixType output_frames(num_output_channels_, + output_vectors.size()); + for (int frame = 0; frame < output_vectors.size(); ++frame) { + Eigen::Map frame_map( + &output_vectors[frame][0], output_vectors[frame].size(), 1); + // The underlying dsp object returns squared magnitudes; here + // we optionally translate to linear magnitude or dB. + output_frames.col(frame) = + output_scale_ * postprocess_output_fn(frame_map); + } + spectrogram_matrices->push_back(output_frames); + } + } + // If the input is very short, there may not be enough accumulated, + // unprocessed samples to cause any new frames to be generated by + // the spectrogram object. If so, we don't want to emit + // a packet at all. + if (!spectrogram_matrices->empty()) { + RET_CHECK_EQ(spectrogram_matrices->size(), input_stream.rows()) + << "Inconsistent number of spectrogram channels."; + if (allow_multichannel_input_) { + cc->Outputs().Index(0).Add(spectrogram_matrices.release(), + CurrentOutputTimestamp()); + } else { + cc->Outputs().Index(0).Add( + new OutputMatrixType(spectrogram_matrices->at(0)), + CurrentOutputTimestamp()); + } + cumulative_completed_frames_ += output_vectors.size(); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status SpectrogramCalculator::ProcessVector( + const Matrix& input_stream, CalculatorContext* cc) { + switch (output_type_) { + // These blocks deliberately ignore clang-format to preserve the + // "silhouette" of the different cases. + // clang-format off + case SpectrogramCalculatorOptions::COMPLEX: { + return ProcessVectorToOutput( + input_stream, + +[](const Eigen::MatrixXcf& col) -> const Eigen::MatrixXcf { + return col; + }, cc); + } + case SpectrogramCalculatorOptions::SQUARED_MAGNITUDE: { + return ProcessVectorToOutput( + input_stream, + +[](const Matrix& col) -> const Matrix { + return col; + }, cc); + } + case SpectrogramCalculatorOptions::LINEAR_MAGNITUDE: { + return ProcessVectorToOutput( + input_stream, + +[](const Matrix& col) -> const Matrix { + return col.array().sqrt().matrix(); + }, cc); + } + case SpectrogramCalculatorOptions::DECIBELS: { + return ProcessVectorToOutput( + input_stream, + +[](const Matrix& col) -> const Matrix { + return kLnPowerToDb * col.array().log().matrix(); + }, cc); + } + // clang-format on + default: { + return ::mediapipe::Status(mediapipe::StatusCode::kInvalidArgument, + "Unrecognized spectrogram output type."); + } + } +} + +::mediapipe::Status SpectrogramCalculator::Close(CalculatorContext* cc) { + if (cumulative_input_samples_ > 0 && pad_final_packet_) { + // We can flush any remaining samples by sending frame_step_samples - 1 + // zeros to the Process method, and letting it do its thing, + // UNLESS we have fewer than one window's worth of samples, in which case + // we pad to exactly one frame_duration_samples. + // Release the memory for the Spectrogram objects. 
+ int required_padding_samples = frame_step_samples() - 1; + if (cumulative_input_samples_ < frame_duration_samples_) { + required_padding_samples = + frame_duration_samples_ - cumulative_input_samples_; + } + return ProcessVector( + Matrix::Zero(num_input_channels_, required_padding_samples), cc); + } + + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/audio/spectrogram_calculator.proto b/mediapipe/calculators/audio/spectrogram_calculator.proto new file mode 100644 index 000000000..faef8d590 --- /dev/null +++ b/mediapipe/calculators/audio/spectrogram_calculator.proto @@ -0,0 +1,68 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message SpectrogramCalculatorOptions { + extend CalculatorOptions { + optional SpectrogramCalculatorOptions ext = 76186688; + } + + // Options mirror those of TimeSeriesFramerCalculator. + + // Analysis window duration in seconds. Required. Must be greater than 0. + // (Note: the spectrogram DFT length will be the smallest power-of-2 + // sample count that can hold this duration.) + optional double frame_duration_seconds = 1; + + // Duration of overlap between adjacent windows. + // Hence, frame_rate = 1/(frame_duration_seconds - frame_overlap_seconds). + // Required that 0 <= frame_overlap_seconds < frame_duration_seconds. + optional double frame_overlap_seconds = 2 [default = 0.0]; + + // Whether to pad the final packet with zeros. If true, guarantees that + // all input samples will output. If set to false, any partial packet + // at the end of the stream will be dropped. + optional bool pad_final_packet = 3 [default = true]; + + // Output value type can be squared-magnitude, linear-magnitude, + // deciBels (dB, = 20*log10(linear_magnitude)), or std::complex. + enum OutputType { + SQUARED_MAGNITUDE = 0; + LINEAR_MAGNITUDE = 1; + DECIBELS = 2; + COMPLEX = 3; + } + optional OutputType output_type = 4 [default = SQUARED_MAGNITUDE]; + + // If set to true then the output will be a vector of spectrograms, one for + // each channel and the stream will have a MultiStreamTimeSeriesHeader. + optional bool allow_multichannel_input = 5 [default = false]; + + // Which window to use when computing the FFT. + enum WindowType { + HANN = 0; + HAMMING = 1; + } + optional WindowType window_type = 6 [default = HANN]; + + // Support a fixed multiplicative scaling of the output. This is applied + // uniformly regardless of output type (i.e., even dBs are multiplied, not + // offset). + optional double output_scale = 7 [default = 1.0]; +} diff --git a/mediapipe/calculators/audio/spectrogram_calculator_test.cc b/mediapipe/calculators/audio/spectrogram_calculator_test.cc new file mode 100644 index 000000000..e783f04fa --- /dev/null +++ b/mediapipe/calculators/audio/spectrogram_calculator_test.cc @@ -0,0 +1,895 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include +#include +#include +#include +#include + +#include "Eigen/Core" +#include "audio/dsp/number_util.h" +#include "mediapipe/calculators/audio/spectrogram_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/port/benchmark.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/util/time_series_test_util.h" + +namespace mediapipe { +namespace { + +const int kInitialTimestampOffsetMicroseconds = 4; + +class SpectrogramCalculatorTest + : public TimeSeriesCalculatorTest { + protected: + void SetUp() override { + calculator_name_ = "SpectrogramCalculator"; + input_sample_rate_ = 4000.0; + num_input_channels_ = 1; + } + + // Initializes and runs the test graph. + ::mediapipe::Status Run() { + // Now that options are set, we can set up some internal constants. + frame_duration_samples_ = + round(options_.frame_duration_seconds() * input_sample_rate_); + frame_step_samples_ = + frame_duration_samples_ - + round(options_.frame_overlap_seconds() * input_sample_rate_); + // The magnitude of the 0th FFT bin (DC) should be sum(input.*window); + // for an input identically 1.0, this is just sum(window). The average + // value of our Hann window is 0.5, hence this is the expected squared- + // magnitude output value in the DC bin for constant input of 1.0. + expected_dc_squared_magnitude_ = + pow((static_cast(frame_duration_samples_) * 0.5), 2.0); + + return RunGraph(); + } + + // Creates test multichannel input with specified packet sizes and containing + // a constant-frequency sinusoid that maintains phase between adjacent + // packets. + void SetupCosineInputPackets(const std::vector& packet_sizes_samples, + float cosine_frequency_hz) { + int total_num_input_samples = 0; + for (int packet_size_samples : packet_sizes_samples) { + double packet_start_time_seconds = + kInitialTimestampOffsetMicroseconds * 1e-6 + + total_num_input_samples / input_sample_rate_; + double packet_end_time_seconds = + packet_start_time_seconds + packet_size_samples / input_sample_rate_; + double angular_freq = 2 * M_PI * cosine_frequency_hz; + Matrix* packet_data = + new Matrix(num_input_channels_, packet_size_samples); + // Use Eigen's vectorized cos() function to fill the vector with a + // sinusoid of appropriate frequency & phase. 
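+      // Starting the phase ramp at packet_start_time_seconds * angular_freq
+      // keeps the sinusoid (approximately) phase-continuous from one packet
+      // to the next.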
+ for (int i = 0; i < num_input_channels_; i++) { + packet_data->row(i) = + Eigen::ArrayXf::LinSpaced(packet_size_samples, + packet_start_time_seconds * angular_freq, + packet_end_time_seconds * angular_freq) + .cos() + .transpose(); + } + int64 input_timestamp = round(packet_start_time_seconds * + Timestamp::kTimestampUnitsPerSecond); + AppendInputPacket(packet_data, input_timestamp); + total_num_input_samples += packet_size_samples; + } + } + + // Setup a sequence of input packets of specified sizes, each filled + // with samples of 1.0. + void SetupConstantInputPackets(const std::vector& packet_sizes_samples) { + // A 0 Hz cosine is identically 1.0 for all samples. + SetupCosineInputPackets(packet_sizes_samples, 0.0); + } + + // Setup a sequence of input packets of specified sizes, each containing a + // single sample of 1.0 at a specified offset. + void SetupImpulseInputPackets( + const std::vector& packet_sizes_samples, + const std::vector& impulse_offsets_samples) { + int total_num_input_samples = 0; + for (int i = 0; i < packet_sizes_samples.size(); ++i) { + double packet_start_time_seconds = + kInitialTimestampOffsetMicroseconds * 1e-6 + + total_num_input_samples / input_sample_rate_; + int64 input_timestamp = round(packet_start_time_seconds * + Timestamp::kTimestampUnitsPerSecond); + std::unique_ptr impulse( + new Matrix(Matrix::Zero(1, packet_sizes_samples[i]))); + (*impulse)(0, impulse_offsets_samples[i]) = 1.0; + AppendInputPacket(impulse.release(), input_timestamp); + total_num_input_samples += packet_sizes_samples[i]; + } + } + + // Creates test multichannel input with specified packet sizes and containing + // constant input packets for the even channels and constant-frequency + // sinusoid that maintains phase between adjacent packets for the odd + // channels. + void SetupMultichannelInputPackets( + const std::vector& packet_sizes_samples, float cosine_frequency_hz) { + int total_num_input_samples = 0; + for (int packet_size_samples : packet_sizes_samples) { + double packet_start_time_seconds = + kInitialTimestampOffsetMicroseconds * 1e-6 + + total_num_input_samples / input_sample_rate_; + double packet_end_time_seconds = + packet_start_time_seconds + packet_size_samples / input_sample_rate_; + double angular_freq; + Matrix* packet_data = + new Matrix(num_input_channels_, packet_size_samples); + // Use Eigen's vectorized cos() function to fill the vector with a + // sinusoid of appropriate frequency & phase. + for (int i = 0; i < num_input_channels_; i++) { + if (i % 2 == 0) { + angular_freq = 0; + } else { + angular_freq = 2 * M_PI * cosine_frequency_hz; + } + packet_data->row(i) = + Eigen::ArrayXf::LinSpaced(packet_size_samples, + packet_start_time_seconds * angular_freq, + packet_end_time_seconds * angular_freq) + .cos() + .transpose(); + } + int64 input_timestamp = round(packet_start_time_seconds * + Timestamp::kTimestampUnitsPerSecond); + AppendInputPacket(packet_data, input_timestamp); + total_num_input_samples += packet_size_samples; + } + } + + // Return vector of the numbers of frames in each output packet. + std::vector OutputFramesPerPacket() { + std::vector frame_counts; + for (const Packet& packet : output().packets) { + const Matrix& matrix = packet.Get(); + frame_counts.push_back(matrix.cols()); + } + return frame_counts; + } + + // Checks output headers and Timestamps. 
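The helper below checks output header fields and packet timestamps against a little arithmetic. As a standalone illustration (a sketch, not part of the calculator or the tests), here is that arithmetic with the fixture defaults assumed throughout these tests: 4 kHz input, a 100-sample window with a 40-sample step, and MediaPipe's microsecond Timestamp units.

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const double input_sample_rate = 4000.0;           // fixture default
  const int frame_duration_samples = 100;            // 0.025 s window
  const int frame_step_samples = 40;                 // 100 - 60 samples overlap
  const int fft_size = 128;                          // next power of two >= 100
  const int64_t kTimestampUnitsPerSecond = 1000000;  // microseconds
  const int64_t initial_timestamp_us = 4;            // fixture offset

  // Expected output header values.
  std::printf("num_channels = %d\n", fft_size / 2 + 1);                // 65 bins
  std::printf("sample_rate  = %g Hz\n",
              input_sample_rate / frame_step_samples);                 // 100 Hz

  // Expected packet timestamps: one frame step of input time per frame.
  for (int frame = 0; frame < 3; ++frame) {
    const int64_t ts = initial_timestamp_us +
                       std::llround(frame * frame_step_samples *
                                    kTimestampUnitsPerSecond /
                                    input_sample_rate);
    std::printf("frame %d -> %lld us\n", frame, static_cast<long long>(ts));
  }
  // Prints 4, 10004, 20004: successive frames are 40 samples = 10 ms apart.
  return 0;
}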
+ void CheckOutputHeadersAndTimestamps() { + const int fft_size = audio_dsp::NextPowerOfTwo(frame_duration_samples_); + + TimeSeriesHeader expected_header = input().header.Get(); + expected_header.set_num_channels(fft_size / 2 + 1); + // The output header sample rate should depend on the output frame step. + expected_header.set_sample_rate(input_sample_rate_ / frame_step_samples_); + // SpectrogramCalculator stores the sample rate of the input in + // the TimeSeriesHeader. + expected_header.set_audio_sample_rate(input_sample_rate_); + // We expect the output header to have num_samples and packet_rate unset. + expected_header.clear_num_samples(); + expected_header.clear_packet_rate(); + if (!options_.allow_multichannel_input()) { + ExpectOutputHeaderEquals(expected_header); + } else { + EXPECT_THAT(output() + .header.template Get() + .time_series_header(), + mediapipe::EqualsProto(expected_header)); + EXPECT_THAT(output() + .header.template Get() + .num_streams(), + num_input_channels_); + } + + int cumulative_output_frames = 0; + // The timestamps coming out of the spectrogram correspond to the + // middle of the first frame's window, hence frame_duration_samples_/2 + // term. We use frame_duration_samples_ because that is how it is + // actually quantized inside spectrogram. + const double packet_timestamp_offset_seconds = + kInitialTimestampOffsetMicroseconds * 1e-6; + const double frame_step_seconds = frame_step_samples_ / input_sample_rate_; + + Timestamp initial_timestamp = Timestamp::Unstarted(); + + for (const Packet& packet : output().packets) { + // This is the timestamp we expect based on how the spectrogram should + // behave (advancing by one step's worth of input samples each frame). + const double expected_timestamp_seconds = + packet_timestamp_offset_seconds + + cumulative_output_frames * frame_step_seconds; + const int64 expected_timestamp_ticks = + expected_timestamp_seconds * Timestamp::kTimestampUnitsPerSecond; + EXPECT_EQ(expected_timestamp_ticks, packet.Timestamp().Value()); + // Accept the timestamp of the first packet as the baseline for checking + // the remainder. + if (initial_timestamp == Timestamp::Unstarted()) { + initial_timestamp = packet.Timestamp(); + } + // Also check that the timestamp is consistent with the sample_rate + // in the output stream's TimeSeriesHeader. + EXPECT_TRUE(time_series_util::LogWarningIfTimestampIsInconsistent( + packet.Timestamp(), initial_timestamp, cumulative_output_frames, + expected_header.sample_rate())); + if (!options_.allow_multichannel_input()) { + if (options_.output_type() == SpectrogramCalculatorOptions::COMPLEX) { + const Eigen::MatrixXcf& matrix = packet.Get(); + cumulative_output_frames += matrix.cols(); + } else { + const Matrix& matrix = packet.Get(); + cumulative_output_frames += matrix.cols(); + } + } else { + if (options_.output_type() == SpectrogramCalculatorOptions::COMPLEX) { + const Eigen::MatrixXcf& matrix = + packet.Get>().at(0); + cumulative_output_frames += matrix.cols(); + } else { + const Matrix& matrix = packet.Get>().at(0); + cumulative_output_frames += matrix.cols(); + } + } + } + } + + // Verify that the bin corresponding to the specified frequency + // is the largest one in one particular frame of a single packet. 
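+  // For example, with the fixture's 4 kHz sample rate and a 100-sample window
+  // (hence a 128-point FFT), a 440 Hz tone should peak in bin
+  // round((440 / 4000) * 128) = 14.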
+ void CheckPeakFrequencyInPacketFrame(const Packet& packet, int frame, + float frequency) { + const int fft_size = audio_dsp::NextPowerOfTwo(frame_duration_samples_); + const int target_bin = + round((frequency / input_sample_rate_) * static_cast(fft_size)); + + const Matrix& matrix = packet.Get(); + // Stop here if the requested frame is not in this packet. + ASSERT_GT(matrix.cols(), frame); + + int actual_largest_bin; + matrix.col(frame).maxCoeff(&actual_largest_bin); + EXPECT_EQ(actual_largest_bin, target_bin); + } + + // Verify that the bin corresponding to the specified frequency + // is the largest one in one particular frame of a single spectrogram Matrix. + void CheckPeakFrequencyInMatrix(const Matrix& matrix, int frame, + float frequency) { + const int fft_size = audio_dsp::NextPowerOfTwo(frame_duration_samples_); + const int target_bin = + round((frequency / input_sample_rate_) * static_cast(fft_size)); + + // Stop here if the requested frame is not in this packet. + ASSERT_GT(matrix.cols(), frame); + + int actual_largest_bin; + matrix.col(frame).maxCoeff(&actual_largest_bin); + EXPECT_EQ(actual_largest_bin, target_bin); + } + + int frame_duration_samples_; + int frame_step_samples_; + // Expected DC output for a window of pure 1.0, set when window length + // is set. + float expected_dc_squared_magnitude_; +}; + +TEST_F(SpectrogramCalculatorTest, IntegerFrameDurationNoOverlap) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(0.0 / input_sample_rate_); + options_.set_pad_final_packet(false); + const std::vector input_packet_sizes = {500, 200}; + const std::vector expected_output_packet_sizes = {5, 2}; + + InitializeGraph(); + FillInputHeader(); + SetupConstantInputPackets(input_packet_sizes); + + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes); +} + +TEST_F(SpectrogramCalculatorTest, IntegerFrameDurationSomeOverlap) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(60.0 / input_sample_rate_); + options_.set_pad_final_packet(false); + const std::vector input_packet_sizes = {500, 200}; + // complete_output_frames = 1 + floor((input_samples - window_length)/step) + // = 1 + floor((500 - 100)/40) = 1 + 10 = 11 for the first packet + // = 1 + floor((700 - 100)/40) = 1 + 15 = 16 for the whole stream + // so expect 16 - 11 = 5 in the second packet. + const std::vector expected_output_packet_sizes = {11, 5}; + + InitializeGraph(); + FillInputHeader(); + SetupConstantInputPackets(input_packet_sizes); + + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes); +} + +TEST_F(SpectrogramCalculatorTest, NonintegerFrameDurationAndOverlap) { + options_.set_frame_duration_seconds(98.5 / input_sample_rate_); + options_.set_frame_overlap_seconds(58.4 / input_sample_rate_); + options_.set_pad_final_packet(false); + const std::vector input_packet_sizes = {500, 200}; + // now frame_duration_samples will be 99 (rounded), and frame_step_samples + // will be (99-58) = 41, so the first packet of 500 samples will generate + // 1 + floor(500-99)/41 = 10 samples. 
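+  // Over the whole 700-sample stream: 1 + floor((700 - 99) / 41) = 15 frames,
+  // leaving 15 - 10 = 5 frames for the second packet.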
+ const std::vector expected_output_packet_sizes = {10, 5}; + + InitializeGraph(); + FillInputHeader(); + SetupConstantInputPackets(input_packet_sizes); + + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes); +} + +TEST_F(SpectrogramCalculatorTest, ShortInitialPacketNoOverlap) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(0.0 / input_sample_rate_); + options_.set_pad_final_packet(false); + const std::vector input_packet_sizes = {90, 100, 110}; + // The first input packet is too small to generate any frames, + // but zero-length packets would result in a timestamp monotonicity + // violation, so they are suppressed. Thus, only the second and third + // input packets generate output packets. + const std::vector expected_output_packet_sizes = {1, 2}; + + InitializeGraph(); + FillInputHeader(); + SetupConstantInputPackets(input_packet_sizes); + + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes); +} + +TEST_F(SpectrogramCalculatorTest, TrailingSamplesNoPad) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(60.0 / input_sample_rate_); + options_.set_pad_final_packet(false); + const std::vector input_packet_sizes = {140, 90}; + const std::vector expected_output_packet_sizes = {2, 2}; + + InitializeGraph(); + FillInputHeader(); + SetupConstantInputPackets(input_packet_sizes); + + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes); +} + +TEST_F(SpectrogramCalculatorTest, NoTrailingSamplesWithPad) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(60.0 / input_sample_rate_); + options_.set_pad_final_packet(true); + const std::vector input_packet_sizes = {140, 80}; + const std::vector expected_output_packet_sizes = {2, 2}; + + InitializeGraph(); + FillInputHeader(); + SetupConstantInputPackets(input_packet_sizes); + + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes); +} + +TEST_F(SpectrogramCalculatorTest, TrailingSamplesWithPad) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(60.0 / input_sample_rate_); + options_.set_pad_final_packet(true); + const std::vector input_packet_sizes = {140, 90}; + // In contrast to NoTrailingSamplesWithPad and TrailingSamplesNoPad, + // this time we get an extra frame in an extra final packet. 
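+  // The 230 input samples give 1 + floor((230 - 100) / 40) = 4 frames (2 + 2);
+  // padding with frame_step - 1 = 39 zeros brings the total to 269 samples,
+  // which yields a fifth frame, emitted in a final packet of its own.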
+ const std::vector expected_output_packet_sizes = {2, 2, 1}; + + InitializeGraph(); + FillInputHeader(); + SetupConstantInputPackets(input_packet_sizes); + + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes); +} + +TEST_F(SpectrogramCalculatorTest, VeryShortInputWillPad) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(60.0 / input_sample_rate_); + options_.set_pad_final_packet(true); + const std::vector input_packet_sizes = {30}; + const std::vector expected_output_packet_sizes = {1}; + + InitializeGraph(); + FillInputHeader(); + SetupConstantInputPackets(input_packet_sizes); + + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes); +} + +TEST_F(SpectrogramCalculatorTest, VeryShortInputZeroOutputFramesIfNoPad) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(60.0 / input_sample_rate_); + options_.set_pad_final_packet(false); + const std::vector input_packet_sizes = {90}; + const std::vector expected_output_packet_sizes = {}; + + InitializeGraph(); + FillInputHeader(); + SetupConstantInputPackets(input_packet_sizes); + + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes); +} + +TEST_F(SpectrogramCalculatorTest, DCSignalIsPeakBin) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(60.0 / input_sample_rate_); + const std::vector input_packet_sizes = {140}; // Gives 2 output frames. + + InitializeGraph(); + FillInputHeader(); + // Setup packets with DC input (non-zero constant value). + SetupConstantInputPackets(input_packet_sizes); + + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + const float dc_frequency_hz = 0.0; + CheckPeakFrequencyInPacketFrame(output().packets[0], 0, dc_frequency_hz); + CheckPeakFrequencyInPacketFrame(output().packets[0], 1, dc_frequency_hz); +} + +TEST_F(SpectrogramCalculatorTest, A440ToneIsPeakBin) { + const std::vector input_packet_sizes = { + 460}; // 100 + 9*40 for 10 frames. + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(60.0 / input_sample_rate_); + InitializeGraph(); + FillInputHeader(); + const float tone_frequency_hz = 440.0; + SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz); + + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + int num_output_frames = output().packets[0].Get().cols(); + for (int frame = 0; frame < num_output_frames; ++frame) { + CheckPeakFrequencyInPacketFrame(output().packets[0], frame, + tone_frequency_hz); + } +} + +TEST_F(SpectrogramCalculatorTest, SquaredMagnitudeOutputLooksRight) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(60.0 / input_sample_rate_); + options_.set_output_type(SpectrogramCalculatorOptions::SQUARED_MAGNITUDE); + const std::vector input_packet_sizes = {140}; + + InitializeGraph(); + FillInputHeader(); + // Setup packets with DC input (non-zero constant value). 
+ SetupConstantInputPackets(input_packet_sizes); + + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + EXPECT_FLOAT_EQ(output().packets[0].Get()(0, 0), + expected_dc_squared_magnitude_); +} + +TEST_F(SpectrogramCalculatorTest, DefaultOutputIsSquaredMagnitude) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(60.0 / input_sample_rate_); + // Let the output_type be its default + const std::vector input_packet_sizes = {140}; + + InitializeGraph(); + FillInputHeader(); + // Setup packets with DC input (non-zero constant value). + SetupConstantInputPackets(input_packet_sizes); + + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + EXPECT_FLOAT_EQ(output().packets[0].Get()(0, 0), + expected_dc_squared_magnitude_); +} + +TEST_F(SpectrogramCalculatorTest, LinearMagnitudeOutputLooksRight) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(60.0 / input_sample_rate_); + options_.set_output_type(SpectrogramCalculatorOptions::LINEAR_MAGNITUDE); + const std::vector input_packet_sizes = {140}; + + InitializeGraph(); + FillInputHeader(); + // Setup packets with DC input (non-zero constant value). + SetupConstantInputPackets(input_packet_sizes); + + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + EXPECT_FLOAT_EQ(output().packets[0].Get()(0, 0), + std::sqrt(expected_dc_squared_magnitude_)); +} + +TEST_F(SpectrogramCalculatorTest, DbMagnitudeOutputLooksRight) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(60.0 / input_sample_rate_); + options_.set_output_type(SpectrogramCalculatorOptions::DECIBELS); + const std::vector input_packet_sizes = {140}; + + InitializeGraph(); + FillInputHeader(); + // Setup packets with DC input (non-zero constant value). + SetupConstantInputPackets(input_packet_sizes); + + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + EXPECT_FLOAT_EQ(output().packets[0].Get()(0, 0), + 10.0 * std::log10(expected_dc_squared_magnitude_)); +} + +TEST_F(SpectrogramCalculatorTest, OutputScalingLooksRight) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(60.0 / input_sample_rate_); + options_.set_output_type(SpectrogramCalculatorOptions::DECIBELS); + double output_scale = 2.5; + options_.set_output_scale(output_scale); + const std::vector input_packet_sizes = {140}; + + InitializeGraph(); + FillInputHeader(); + // Setup packets with DC input (non-zero constant value). + SetupConstantInputPackets(input_packet_sizes); + + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + EXPECT_FLOAT_EQ( + output().packets[0].Get()(0, 0), + output_scale * 10.0 * std::log10(expected_dc_squared_magnitude_)); +} + +TEST_F(SpectrogramCalculatorTest, ComplexOutputLooksRight) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(60.0 / input_sample_rate_); + options_.set_output_type(SpectrogramCalculatorOptions::COMPLEX); + const std::vector input_packet_sizes = {140}; + + InitializeGraph(); + FillInputHeader(); + // Setup packets with DC input (non-zero constant value). 
+ SetupConstantInputPackets(input_packet_sizes); + + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + EXPECT_FLOAT_EQ(std::norm(output().packets[0].Get()(0, 0)), + expected_dc_squared_magnitude_); +} + +TEST_F(SpectrogramCalculatorTest, ComplexOutputLooksRightForImpulses) { + const int frame_size_samples = 100; + options_.set_frame_duration_seconds(frame_size_samples / input_sample_rate_); + options_.set_frame_overlap_seconds(0.0 / input_sample_rate_); + options_.set_pad_final_packet(false); + options_.set_output_type(SpectrogramCalculatorOptions::COMPLEX); + const std::vector input_packet_sizes = {frame_size_samples, + frame_size_samples}; + const std::vector input_packet_impulse_offsets = {49, 50}; + + InitializeGraph(); + FillInputHeader(); + + // Make two impulse packets offset one sample from each other + SetupImpulseInputPackets(input_packet_sizes, input_packet_impulse_offsets); + + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + const int num_buckets = + (audio_dsp::NextPowerOfTwo(frame_size_samples) / 2) + 1; + const float precision = 0.01f; + auto norm_fn = [](const std::complex& cf) { return std::norm(cf); }; + + // Both impulses should have (approximately) constant power across all + // frequency bins + EXPECT_TRUE(output() + .packets[0] + .Get() + .unaryExpr(norm_fn) + .isApproxToConstant(1.0f, precision)); + EXPECT_TRUE(output() + .packets[1] + .Get() + .unaryExpr(norm_fn) + .isApproxToConstant(1.0f, precision)); + + // Because the second Packet's impulse is delayed by exactly one sample with + // respect to the first Packet's impulse, the second impulse should have + // greater phase, and in the highest frequency bin, the real part should + // (approximately) flip sign from the first Packet to the second + EXPECT_LT(std::arg(output().packets[0].Get()(1, 0)), + std::arg(output().packets[1].Get()(1, 0))); + const float highest_bucket_real_ratio = + output().packets[0].Get()(num_buckets - 1, 0).real() / + output().packets[1].Get()(num_buckets - 1, 0).real(); + EXPECT_NEAR(highest_bucket_real_ratio, -1.0f, precision); +} + +TEST_F(SpectrogramCalculatorTest, SquaredMagnitudeOutputLooksRightForNonDC) { + const int frame_size_samples = 100; + options_.set_frame_duration_seconds(frame_size_samples / input_sample_rate_); + options_.set_frame_overlap_seconds(60.0 / input_sample_rate_); + options_.set_output_type(SpectrogramCalculatorOptions::SQUARED_MAGNITUDE); + const std::vector input_packet_sizes = {140}; + + InitializeGraph(); + FillInputHeader(); + // Make the tone have an integral number of cycles within the window + const int target_bin = 16; + const int fft_size = audio_dsp::NextPowerOfTwo(frame_size_samples); + const float tone_frequency_hz = target_bin * (input_sample_rate_ / fft_size); + SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz); + + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + // For a non-DC bin, the magnitude will be split between positive and + // negative frequency bins, so it should about be half-magnitude + // = quarter-power. + // It's not quite exact because of the interference from the hann(100) + // spread from the negative-frequency half. 
+ EXPECT_GT(output().packets[0].Get()(target_bin, 0), + 0.98 * expected_dc_squared_magnitude_ / 4.0); + EXPECT_LT(output().packets[0].Get()(target_bin, 0), + 1.02 * expected_dc_squared_magnitude_ / 4.0); +} + +TEST_F(SpectrogramCalculatorTest, ZeroOutputsForZeroInputsWithPaddingEnabled) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(60.0 / input_sample_rate_); + options_.set_pad_final_packet(true); + const std::vector input_packet_sizes = {}; + const std::vector expected_output_packet_sizes = {}; + + InitializeGraph(); + FillInputHeader(); + SetupConstantInputPackets(input_packet_sizes); + + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes); +} + +TEST_F(SpectrogramCalculatorTest, NumChannelsIsRight) { + const std::vector input_packet_sizes = {460}; + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(60.0 / input_sample_rate_); + options_.set_pad_final_packet(false); + options_.set_allow_multichannel_input(true); + num_input_channels_ = 3; + InitializeGraph(); + FillInputHeader(); + const float tone_frequency_hz = 440.0; + SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz); + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + EXPECT_EQ(output().packets[0].Get>().size(), + num_input_channels_); +} + +TEST_F(SpectrogramCalculatorTest, NumSamplesAndPacketRateAreCleared) { + num_input_samples_ = 500; + input_packet_rate_ = 1.0; + const std::vector input_packet_sizes = {num_input_samples_}; + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(0.0); + options_.set_pad_final_packet(false); + + InitializeGraph(); + FillInputHeader(); + SetupConstantInputPackets(input_packet_sizes); + + MEDIAPIPE_ASSERT_OK(Run()); + + const TimeSeriesHeader& output_header = + output().header.Get(); + EXPECT_FALSE(output_header.has_num_samples()); + EXPECT_FALSE(output_header.has_packet_rate()); +} + +TEST_F(SpectrogramCalculatorTest, MultichannelSpectrogramSizesAreRight) { + const std::vector input_packet_sizes = {420}; // less than 10 frames + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(60.0 / input_sample_rate_); + options_.set_pad_final_packet(false); + options_.set_allow_multichannel_input(true); + num_input_channels_ = 10; + InitializeGraph(); + FillInputHeader(); + const float tone_frequency_hz = 440.0; + SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz); + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + auto spectrograms = output().packets[0].Get>(); + EXPECT_FLOAT_EQ(spectrograms.size(), num_input_channels_); + int spectrogram_num_rows = spectrograms[0].rows(); + int spectrogram_num_cols = spectrograms[0].cols(); + for (int i = 1; i < num_input_channels_; i++) { + EXPECT_EQ(spectrogram_num_rows, spectrograms[i].rows()); + EXPECT_EQ(spectrogram_num_cols, spectrograms[i].cols()); + } +} + +TEST_F(SpectrogramCalculatorTest, MultichannelSpectrogramValuesAreRight) { + const std::vector input_packet_sizes = { + 460}; // 100 + 9*40 for 10 frames. 
+ options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(60.0 / input_sample_rate_); + options_.set_allow_multichannel_input(true); + num_input_channels_ = 10; + InitializeGraph(); + FillInputHeader(); + const float tone_frequency_hz = 440.0; + SetupMultichannelInputPackets(input_packet_sizes, tone_frequency_hz); + + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + auto spectrograms = output().packets[0].Get>(); + int num_output_frames = spectrograms[0].cols(); + for (int i = 0; i < num_input_channels_; i++) { + for (int frame = 0; frame < num_output_frames; ++frame) { + if (i % 2 == 0) { + CheckPeakFrequencyInMatrix(spectrograms[i], frame, 0); + } else { + CheckPeakFrequencyInMatrix(spectrograms[i], frame, tone_frequency_hz); + } + } + } +} + +TEST_F(SpectrogramCalculatorTest, MultichannelHandlesShortInitialPacket) { + // First packet is less than one frame, but second packet should trigger a + // complete frame from all channels. + const std::vector input_packet_sizes = {50, 50}; + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(60.0 / input_sample_rate_); + options_.set_pad_final_packet(false); + options_.set_allow_multichannel_input(true); + num_input_channels_ = 2; + InitializeGraph(); + FillInputHeader(); + const float tone_frequency_hz = 440.0; + SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz); + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + auto spectrograms = output().packets[0].Get>(); + EXPECT_FLOAT_EQ(spectrograms.size(), num_input_channels_); + int spectrogram_num_rows = spectrograms[0].rows(); + int spectrogram_num_cols = spectrograms[0].cols(); + for (int i = 1; i < num_input_channels_; i++) { + EXPECT_EQ(spectrogram_num_rows, spectrograms[i].rows()); + EXPECT_EQ(spectrogram_num_cols, spectrograms[i].cols()); + } +} + +TEST_F(SpectrogramCalculatorTest, + MultichannelComplexHandlesShortInitialPacket) { + // First packet is less than one frame, but second packet should trigger a + // complete frame from all channels, even for complex output. 
+ const std::vector input_packet_sizes = {50, 50}; + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(60.0 / input_sample_rate_); + options_.set_pad_final_packet(false); + options_.set_allow_multichannel_input(true); + options_.set_output_type(SpectrogramCalculatorOptions::COMPLEX); + num_input_channels_ = 2; + InitializeGraph(); + FillInputHeader(); + const float tone_frequency_hz = 440.0; + SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz); + MEDIAPIPE_ASSERT_OK(Run()); + + CheckOutputHeadersAndTimestamps(); + auto spectrograms = output().packets[0].Get>(); + EXPECT_FLOAT_EQ(spectrograms.size(), num_input_channels_); + int spectrogram_num_rows = spectrograms[0].rows(); + int spectrogram_num_cols = spectrograms[0].cols(); + for (int i = 1; i < num_input_channels_; i++) { + EXPECT_EQ(spectrogram_num_rows, spectrograms[i].rows()); + EXPECT_EQ(spectrogram_num_cols, spectrograms[i].cols()); + } +} + +void BM_ProcessDC(benchmark::State& state) { + CalculatorGraphConfig::Node node_config; + node_config.set_calculator("SpectrogramCalculator"); + node_config.add_input_stream("input_audio"); + node_config.add_output_stream("output_spectrogram"); + + SpectrogramCalculatorOptions* options = + node_config.mutable_options()->MutableExtension( + SpectrogramCalculatorOptions::ext); + options->set_frame_duration_seconds(0.010); + options->set_frame_overlap_seconds(0.0); + options->set_pad_final_packet(false); + *node_config.mutable_options()->MutableExtension( + SpectrogramCalculatorOptions::ext) = *options; + + int num_input_channels = 1; + int packet_size_samples = 1600000; + TimeSeriesHeader* header = new TimeSeriesHeader(); + header->set_sample_rate(16000.0); + header->set_num_channels(num_input_channels); + + CalculatorRunner runner(node_config); + runner.MutableInputs()->Index(0).header = Adopt(header); + + Matrix* payload = new Matrix( + Matrix::Constant(num_input_channels, packet_size_samples, 1.0)); + Timestamp timestamp = Timestamp(0); + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(payload).At(timestamp)); + + for (auto _ : state) { + ASSERT_TRUE(runner.Run().ok()); + } + + const CalculatorRunner::StreamContents& output = runner.Outputs().Index(0); + const Matrix& output_matrix = output.packets[0].Get(); + LOG(INFO) << "Output matrix=" << output_matrix.rows() << "x" + << output_matrix.cols(); + LOG(INFO) << "First values=" << output_matrix(0, 0) << ", " + << output_matrix(1, 0) << ", " << output_matrix(2, 0) << ", " + << output_matrix(3, 0); +} + +BENCHMARK(BM_ProcessDC); + +} // anonymous namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/audio/testdata/BUILD b/mediapipe/calculators/audio/testdata/BUILD new file mode 100644 index 000000000..64f6ccf63 --- /dev/null +++ b/mediapipe/calculators/audio/testdata/BUILD @@ -0,0 +1,27 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +licenses(["notice"]) # Apache 2.0 + +filegroup( + name = "test_audios", + srcs = [ + "sine_wave_1k_44100_mono_2_sec_wav.audio", + "sine_wave_1k_44100_stereo_2_sec_aac.audio", + "sine_wave_1k_44100_stereo_2_sec_mp3.audio", + "sine_wave_1k_48000_stereo_2_sec_wav.audio", + ], + visibility = ["//visibility:public"], +) diff --git a/mediapipe/calculators/audio/testdata/sine_wave_1k_44100_mono_2_sec_wav.audio b/mediapipe/calculators/audio/testdata/sine_wave_1k_44100_mono_2_sec_wav.audio new file mode 100644 index 000000000..bd04691d4 Binary files /dev/null and b/mediapipe/calculators/audio/testdata/sine_wave_1k_44100_mono_2_sec_wav.audio differ diff --git a/mediapipe/calculators/audio/testdata/sine_wave_1k_44100_stereo_2_sec_aac.audio b/mediapipe/calculators/audio/testdata/sine_wave_1k_44100_stereo_2_sec_aac.audio new file mode 100644 index 000000000..9b3b03a36 Binary files /dev/null and b/mediapipe/calculators/audio/testdata/sine_wave_1k_44100_stereo_2_sec_aac.audio differ diff --git a/mediapipe/calculators/audio/testdata/sine_wave_1k_44100_stereo_2_sec_mp3.audio b/mediapipe/calculators/audio/testdata/sine_wave_1k_44100_stereo_2_sec_mp3.audio new file mode 100644 index 000000000..fe11e1165 Binary files /dev/null and b/mediapipe/calculators/audio/testdata/sine_wave_1k_44100_stereo_2_sec_mp3.audio differ diff --git a/mediapipe/calculators/audio/testdata/sine_wave_1k_48000_stereo_2_sec_wav.audio b/mediapipe/calculators/audio/testdata/sine_wave_1k_48000_stereo_2_sec_wav.audio new file mode 100644 index 000000000..4258bc4c2 Binary files /dev/null and b/mediapipe/calculators/audio/testdata/sine_wave_1k_48000_stereo_2_sec_wav.audio differ diff --git a/mediapipe/calculators/audio/time_series_framer_calculator.cc b/mediapipe/calculators/audio/time_series_framer_calculator.cc new file mode 100644 index 000000000..60ab5d7b4 --- /dev/null +++ b/mediapipe/calculators/audio/time_series_framer_calculator.cc @@ -0,0 +1,289 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Defines TimeSeriesFramerCalculator. +#include + +#include +#include +#include + +#include "Eigen/Core" +#include "audio/dsp/window_functions.h" +#include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/util/time_series_util.h" + +namespace mediapipe { + +// MediaPipe Calculator for framing a (vector-valued) input time series, +// i.e. for breaking an input time series into fixed-size, possibly +// overlapping, frames. The output stream's frame duration is +// specified by frame_duration_seconds in the +// TimeSeriesFramerCalculatorOptions, and the output's overlap is +// specified by frame_overlap_seconds. 
+// +// This calculator assumes that the input timestamps refer to the +// first sample in each Matrix. The output timestamps follow this +// same convention. +// +// All output frames will have exactly the same number of samples: the number of +// samples that approximates frame_duration_seconds most closely. +// +// Similarly, frame overlap is by default the (fixed) number of samples +// approximating frame_overlap_seconds most closely. But if +// emulate_fractional_frame_overlap is set to true, frame overlap is a variable +// number of samples instead, such that the long-term average step between +// frames is the difference between the (nominal) frame_duration_seconds and +// frame_overlap_seconds. +// +// If pad_final_packet is true, all input samples will be emitted and the final +// packet will be zero padded as necessary. If pad_final_packet is false, some +// samples may be dropped at the end of the stream. +class TimeSeriesFramerCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set( + // Input stream with TimeSeriesHeader. + ); + cc->Outputs().Index(0).Set( + // Fixed length time series Packets with TimeSeriesHeader. + ); + return ::mediapipe::OkStatus(); + } + + // Returns FAIL if the input stream header is invalid. + ::mediapipe::Status Open(CalculatorContext* cc) override; + + // Outputs as many framed packets as possible given the accumulated + // input. Always returns OK. + ::mediapipe::Status Process(CalculatorContext* cc) override; + + // Flushes any remaining samples in a zero-padded packet. Always + // returns OK. + ::mediapipe::Status Close(CalculatorContext* cc) override; + + private: + // Adds input data to the internal buffer. + void EnqueueInput(CalculatorContext* cc); + // Constructs and emits framed output packets. + void FrameOutput(CalculatorContext* cc); + + Timestamp CurrentOutputTimestamp() { + return initial_input_timestamp_ + + round(cumulative_completed_samples_ / sample_rate_ * + Timestamp::kTimestampUnitsPerSecond); + } + + // The number of input samples to advance after the current output frame is + // emitted. + int next_frame_step_samples() const { + // All numbers are in input samples. + const int64 current_output_frame_start = static_cast( + round(cumulative_output_frames_ * average_frame_step_samples_)); + CHECK_EQ(current_output_frame_start, cumulative_completed_samples_); + const int64 next_output_frame_start = static_cast( + round((cumulative_output_frames_ + 1) * average_frame_step_samples_)); + return next_output_frame_start - current_output_frame_start; + } + + double sample_rate_; + bool pad_final_packet_; + int frame_duration_samples_; + // The advance, in input samples, between the start of successive output + // frames. This may be a non-integer average value if + // emulate_fractional_frame_overlap is true. + double average_frame_step_samples_; + int samples_still_to_drop_; + int64 cumulative_input_samples_; + int64 cumulative_output_frames_; + // "Completed" samples are samples that are no longer needed because + // the framer has completely stepped past them (taking into account + // any overlap). + int64 cumulative_completed_samples_; + Timestamp initial_input_timestamp_; + int num_channels_; + + // Each entry in this deque consists of a single sample, i.e. a + // single column vector. 
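+  // Samples are pushed onto the back as input arrives and popped from the
+  // front once the framer has stepped past them.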
+ std::deque sample_buffer_; + + bool use_window_; + Matrix window_; +}; +REGISTER_CALCULATOR(TimeSeriesFramerCalculator); + +void TimeSeriesFramerCalculator::EnqueueInput(CalculatorContext* cc) { + const Matrix& input_frame = cc->Inputs().Index(0).Get(); + + for (int i = 0; i < input_frame.cols(); ++i) { + sample_buffer_.emplace_back(input_frame.col(i)); + } + + cumulative_input_samples_ += input_frame.cols(); +} + +void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) { + while (sample_buffer_.size() >= + frame_duration_samples_ + samples_still_to_drop_) { + while (samples_still_to_drop_ > 0) { + sample_buffer_.pop_front(); + --samples_still_to_drop_; + } + const int frame_step_samples = next_frame_step_samples(); + std::unique_ptr output_frame( + new Matrix(num_channels_, frame_duration_samples_)); + for (int i = 0; i < std::min(frame_step_samples, frame_duration_samples_); + ++i) { + output_frame->col(i) = sample_buffer_.front(); + sample_buffer_.pop_front(); + } + const int frame_overlap_samples = + frame_duration_samples_ - frame_step_samples; + if (frame_overlap_samples > 0) { + for (int i = 0; i < frame_overlap_samples; ++i) { + output_frame->col(i + frame_step_samples) = sample_buffer_[i]; + } + } else { + samples_still_to_drop_ = -frame_overlap_samples; + } + + if (use_window_) { + *output_frame = (output_frame->array() * window_.array()).matrix(); + } + + cc->Outputs().Index(0).Add(output_frame.release(), + CurrentOutputTimestamp()); + ++cumulative_output_frames_; + cumulative_completed_samples_ += frame_step_samples; + } +} + +::mediapipe::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) { + if (initial_input_timestamp_ == Timestamp::Unstarted()) { + initial_input_timestamp_ = cc->InputTimestamp(); + } + + EnqueueInput(cc); + FrameOutput(cc); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TimeSeriesFramerCalculator::Close(CalculatorContext* cc) { + while (samples_still_to_drop_ > 0 && !sample_buffer_.empty()) { + sample_buffer_.pop_front(); + --samples_still_to_drop_; + } + if (!sample_buffer_.empty() && pad_final_packet_) { + std::unique_ptr output_frame(new Matrix); + output_frame->setZero(num_channels_, frame_duration_samples_); + for (int i = 0; i < sample_buffer_.size(); ++i) { + output_frame->col(i) = sample_buffer_[i]; + } + + cc->Outputs().Index(0).Add(output_frame.release(), + CurrentOutputTimestamp()); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) { + TimeSeriesFramerCalculatorOptions framer_options; + time_series_util::FillOptionsExtensionOrDie(cc->Options(), &framer_options); + + RET_CHECK_GT(framer_options.frame_duration_seconds(), 0.0) + << "Invalid or missing frame_duration_seconds. " + << "framer_duration_seconds: \n" + << framer_options.frame_duration_seconds(); + RET_CHECK_LT(framer_options.frame_overlap_seconds(), + framer_options.frame_duration_seconds()) + << "Invalid frame_overlap_seconds. 
framer_overlap_seconds: \n" + << framer_options.frame_overlap_seconds(); + + TimeSeriesHeader input_header; + RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid( + cc->Inputs().Index(0).Header(), &input_header)); + + sample_rate_ = input_header.sample_rate(); + num_channels_ = input_header.num_channels(); + frame_duration_samples_ = time_series_util::SecondsToSamples( + framer_options.frame_duration_seconds(), sample_rate_); + RET_CHECK_GT(frame_duration_samples_, 0) + << "Frame duration of " << framer_options.frame_duration_seconds() + << "s too small to cover a single sample at " << sample_rate_ << " Hz "; + if (framer_options.emulate_fractional_frame_overlap()) { + // Frame step may be fractional. + average_frame_step_samples_ = (framer_options.frame_duration_seconds() - + framer_options.frame_overlap_seconds()) * + sample_rate_; + } else { + // Frame step is an integer (stored in a double). + average_frame_step_samples_ = + frame_duration_samples_ - + time_series_util::SecondsToSamples( + framer_options.frame_overlap_seconds(), sample_rate_); + } + RET_CHECK_GE(average_frame_step_samples_, 1) + << "Frame step too small to cover a single sample at " << sample_rate_ + << " Hz."; + pad_final_packet_ = framer_options.pad_final_packet(); + + auto output_header = new TimeSeriesHeader(input_header); + output_header->set_num_samples(frame_duration_samples_); + if (round(average_frame_step_samples_) == average_frame_step_samples_) { + // Only set output packet rate if it is fixed. + output_header->set_packet_rate(sample_rate_ / average_frame_step_samples_); + } + cc->Outputs().Index(0).SetHeader(Adopt(output_header)); + cumulative_completed_samples_ = 0; + cumulative_input_samples_ = 0; + cumulative_output_frames_ = 0; + samples_still_to_drop_ = 0; + initial_input_timestamp_ = Timestamp::Unstarted(); + + std::vector window_vector; + use_window_ = false; + switch (framer_options.window_function()) { + case TimeSeriesFramerCalculatorOptions::HAMMING: + audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples_, + &window_vector); + use_window_ = true; + break; + case TimeSeriesFramerCalculatorOptions::HANN: + audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_, + &window_vector); + use_window_ = true; + break; + case TimeSeriesFramerCalculatorOptions::NONE: + break; + } + + if (use_window_) { + window_ = Matrix::Ones(num_channels_, 1) * + Eigen::Map(window_vector.data(), 1, + frame_duration_samples_) + .cast(); + } + + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/audio/time_series_framer_calculator.proto b/mediapipe/calculators/audio/time_series_framer_calculator.proto new file mode 100644 index 000000000..61be38da7 --- /dev/null +++ b/mediapipe/calculators/audio/time_series_framer_calculator.proto @@ -0,0 +1,65 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message TimeSeriesFramerCalculatorOptions { + extend CalculatorOptions { + optional TimeSeriesFramerCalculatorOptions ext = 50631621; + } + + // Frame duration in seconds. Required. Must be greater than 0. This is + // rounded to the nearest integer number of samples. + optional double frame_duration_seconds = 1; + + // Frame overlap in seconds. + // + // If emulate_fractional_frame_overlap is false (the default), then the frame + // overlap is rounded to the nearest integer number of samples, and the step + // from one frame to the next will be the difference between the number of + // samples in a frame and the number of samples in the overlap. + // + // If emulate_fractional_frame_overlap is true, then frame overlap will be a + // variable number of samples, such that the long-time average time step from + // one frame to the next will be the difference between the (nominal, not + // rounded) frame_duration_seconds and frame_overlap_seconds. This is useful + // where the desired time step is not an integral number of input samples. + // + // A negative frame_overlap_seconds corresponds to skipping some input samples + // between each frame of emitted samples. + // + // Required that frame_overlap_seconds < frame_duration_seconds. + optional double frame_overlap_seconds = 2 [default = 0.0]; + + // See frame_overlap_seconds for semantics. + optional bool emulate_fractional_frame_overlap = 5 [default = false]; + + // Whether to pad the final packet with zeros. If true, guarantees that all + // input samples (other than those that fall in gaps implied by negative + // frame_overlap_seconds) will be emitted. If set to false, any partial + // packet at the end of the stream will be dropped. + optional bool pad_final_packet = 3 [default = true]; + + // Optional windowing function. The default is NONE (no windowing function). + enum WindowFunction { + NONE = 0; + HAMMING = 1; + HANN = 2; + } + optional WindowFunction window_function = 4 [default = NONE]; +} diff --git a/mediapipe/calculators/audio/time_series_framer_calculator_test.cc b/mediapipe/calculators/audio/time_series_framer_calculator_test.cc new file mode 100644 index 000000000..ec64b2b1a --- /dev/null +++ b/mediapipe/calculators/audio/time_series_framer_calculator_test.cc @@ -0,0 +1,395 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
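As a usage sketch for these options, a TimeSeriesFramerCalculator node could be configured as below. The stream names and option values are illustrative only (they are not taken from any shipped graph); the ParseTextProtoOrDie pattern follows the one used elsewhere in this change.

    #include "mediapipe/framework/calculator_framework.h"
    #include "mediapipe/framework/port/parse_text_proto.h"

    namespace mediapipe {

    // Hypothetical helper: builds a one-node graph that frames an audio
    // Matrix stream into 25 ms windows with a 10 ms hop and a Hann window.
    CalculatorGraphConfig MakeFramerGraphConfig() {
      return ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
        input_stream: "audio_matrix"
        node {
          calculator: "TimeSeriesFramerCalculator"
          input_stream: "audio_matrix"
          output_stream: "framed_audio"
          options {
            [mediapipe.TimeSeriesFramerCalculatorOptions.ext] {
              frame_duration_seconds: 0.025  # 25 ms frames.
              frame_overlap_seconds: 0.015   # 10 ms hop.
              window_function: HANN
              pad_final_packet: true
            }
          }
        }
      )");
    }

    }  // namespace mediapipe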
+ +#include + +#include +#include +#include + +#include "Eigen/Core" +#include "audio/dsp/window_functions.h" +#include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/util/time_series_test_util.h" + +namespace mediapipe { +namespace { + +const int kInitialTimestampOffsetMicroseconds = 4; + +class TimeSeriesFramerCalculatorTest + : public TimeSeriesCalculatorTest { + protected: + void SetUp() override { + calculator_name_ = "TimeSeriesFramerCalculator"; + input_sample_rate_ = 4000.0; + num_input_channels_ = 3; + } + + // Returns a float value with the channel and timestamp separated by + // an order of magnitude, for easy parsing by humans. + float TestValue(int64 timestamp_in_microseconds, int channel) { + return timestamp_in_microseconds + channel / 10.0; + } + + // Caller takes ownership of the returned value. + Matrix* NewTestFrame(int num_channels, int num_samples, + double starting_timestamp_seconds) { + auto matrix = new Matrix(num_channels, num_samples); + for (int c = 0; c < num_channels; ++c) { + for (int i = 0; i < num_samples; ++i) { + int64 timestamp = time_series_util::SecondsToSamples( + starting_timestamp_seconds + i / input_sample_rate_, + Timestamp::kTimestampUnitsPerSecond); + (*matrix)(c, i) = TestValue(timestamp, c); + } + } + return matrix; + } + + // Initializes and runs the test graph. + ::mediapipe::Status Run() { + InitializeGraph(); + + FillInputHeader(); + InitializeInput(); + + return RunGraph(); + } + + // Creates test input and saves a reference copy. + void InitializeInput() { + concatenated_input_samples_.resize(0, num_input_channels_); + num_input_samples_ = 0; + for (int i = 0; i < 10; ++i) { + // This range of packet sizes was chosen such that some input + // packets will be smaller than the output packet size and other + // input packets will be larger. + int packet_size = (i + 1) * 20; + double timestamp_seconds = kInitialTimestampOffsetMicroseconds * 1.0e-6 + + num_input_samples_ / input_sample_rate_; + + Matrix* data_frame = + NewTestFrame(num_input_channels_, packet_size, timestamp_seconds); + + // Keep a reference copy of the input. + // + // conservativeResize() is needed here to preserve the existing + // data. Eigen's resize() resizes without preserving data. 
+ concatenated_input_samples_.conservativeResize( + num_input_channels_, num_input_samples_ + packet_size); + concatenated_input_samples_.rightCols(packet_size) = *data_frame; + num_input_samples_ += packet_size; + + AppendInputPacket(data_frame, round(timestamp_seconds * + Timestamp::kTimestampUnitsPerSecond)); + } + + const int frame_duration_samples = FrameDurationSamples(); + std::vector window_vector; + switch (options_.window_function()) { + case TimeSeriesFramerCalculatorOptions::HAMMING: + audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples, + &window_vector); + break; + case TimeSeriesFramerCalculatorOptions::HANN: + audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples, + &window_vector); + break; + case TimeSeriesFramerCalculatorOptions::NONE: + window_vector.assign(frame_duration_samples, 1.0f); + break; + } + + window_ = Matrix::Ones(num_input_channels_, 1) * + Eigen::Map(window_vector.data(), 1, + frame_duration_samples) + .cast(); + } + + int FrameDurationSamples() { + return time_series_util::SecondsToSamples(options_.frame_duration_seconds(), + input_sample_rate_); + } + + // Checks that the values in the framed output packets matches the + // appropriate values from the input. + void CheckOutputPacketValues(const Matrix& actual, int packet_num, + int frame_duration_samples, + double frame_step_samples, + int num_columns_to_check) { + ASSERT_EQ(frame_duration_samples, actual.cols()); + Matrix expected = (concatenated_input_samples_ + .block(0, round(frame_step_samples * packet_num), + num_input_channels_, num_columns_to_check) + .array() * + window_.leftCols(num_columns_to_check).array()) + .matrix(); + ExpectApproximatelyEqual(expected, actual.leftCols(num_columns_to_check)); + } + + // Checks output headers, Timestamps, and values. + void CheckOutput() { + const int frame_duration_samples = FrameDurationSamples(); + const double frame_step_samples = + options_.emulate_fractional_frame_overlap() + ? (options_.frame_duration_seconds() - + options_.frame_overlap_seconds()) * + input_sample_rate_ + : frame_duration_samples - + time_series_util::SecondsToSamples( + options_.frame_overlap_seconds(), input_sample_rate_); + + TimeSeriesHeader expected_header = input().header.Get(); + expected_header.set_num_samples(frame_duration_samples); + if (!options_.emulate_fractional_frame_overlap() || + frame_step_samples == round(frame_step_samples)) { + expected_header.set_packet_rate(input_sample_rate_ / frame_step_samples); + } + ExpectOutputHeaderEquals(expected_header); + + int num_full_packets = output().packets.size(); + if (options_.pad_final_packet()) { + num_full_packets -= 1; + } + + for (int packet_num = 0; packet_num < num_full_packets; ++packet_num) { + const Packet& packet = output().packets[packet_num]; + CheckOutputPacketValues(packet.Get(), packet_num, + frame_duration_samples, frame_step_samples, + frame_duration_samples); + } + + // What is the effective time index of the final sample emitted? + // This includes accounting for the gaps when overlap is negative. 
+ const int num_unique_output_samples = + round((output().packets.size() - 1) * frame_step_samples) + + frame_duration_samples; + LOG(INFO) << "packets.size()=" << output().packets.size() + << " frame_duration_samples=" << frame_duration_samples + << " frame_step_samples=" << frame_step_samples + << " num_input_samples_=" << num_input_samples_ + << " num_unique_output_samples=" << num_unique_output_samples; + const int num_padding_samples = + num_unique_output_samples - num_input_samples_; + if (options_.pad_final_packet()) { + EXPECT_LT(num_padding_samples, frame_duration_samples); + // If the input ended during the dropped samples between the end of + // the last emitted frame and where the next one would begin, there + // can be fewer unique output points than input points, even with + // padding. + const int max_dropped_samples = + static_cast(ceil(frame_step_samples - frame_duration_samples)); + EXPECT_GE(num_padding_samples, std::min(0, -max_dropped_samples)); + + if (num_padding_samples > 0) { + // Check the non-padded part of the final packet. + const Matrix& final_matrix = output().packets.back().Get(); + CheckOutputPacketValues(final_matrix, num_full_packets, + frame_duration_samples, frame_step_samples, + frame_duration_samples - num_padding_samples); + // Check the padded part of the final packet. + EXPECT_EQ( + Matrix::Zero(num_input_channels_, num_padding_samples), + final_matrix.block(0, frame_duration_samples - num_padding_samples, + num_input_channels_, num_padding_samples)); + } + } else { + EXPECT_GT(num_padding_samples, -frame_duration_samples); + EXPECT_LE(num_padding_samples, 0); + } + } + + int num_input_samples_; + Matrix concatenated_input_samples_; + Matrix window_; +}; + +TEST_F(TimeSeriesFramerCalculatorTest, IntegerSampleDurationNoOverlap) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + MEDIAPIPE_ASSERT_OK(Run()); + CheckOutput(); +} + +TEST_F(TimeSeriesFramerCalculatorTest, + IntegerSampleDurationNoOverlapHammingWindow) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_window_function(TimeSeriesFramerCalculatorOptions::HAMMING); + MEDIAPIPE_ASSERT_OK(Run()); + CheckOutput(); +} + +TEST_F(TimeSeriesFramerCalculatorTest, + IntegerSampleDurationNoOverlapHannWindow) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_window_function(TimeSeriesFramerCalculatorOptions::HANN); + MEDIAPIPE_ASSERT_OK(Run()); + CheckOutput(); +} + +TEST_F(TimeSeriesFramerCalculatorTest, IntegerSampleDurationAndOverlap) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(40.0 / input_sample_rate_); + MEDIAPIPE_ASSERT_OK(Run()); + CheckOutput(); +} + +TEST_F(TimeSeriesFramerCalculatorTest, NonintegerSampleDurationAndOverlap) { + options_.set_frame_duration_seconds(98.5 / input_sample_rate_); + options_.set_frame_overlap_seconds(38.4 / input_sample_rate_); + + MEDIAPIPE_ASSERT_OK(Run()); + CheckOutput(); +} + +TEST_F(TimeSeriesFramerCalculatorTest, NegativeOverlapExactFrames) { + // Negative overlap means to drop samples between frames. + // 100 samples per frame plus a skip of 10 samples will be 10 full frames in + // the 1100 input samples. 
+  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
+  options_.set_frame_overlap_seconds(-10.0 / input_sample_rate_);
+  MEDIAPIPE_ASSERT_OK(Run());
+  EXPECT_EQ(output().packets.size(), 10);
+  CheckOutput();
+}
+
+TEST_F(TimeSeriesFramerCalculatorTest, NegativeOverlapExactFramesLessSkip) {
+  // 100 samples per frame plus a skip of 100 samples will be 6 full frames in
+  // the 1100 input samples.
+  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
+  options_.set_frame_overlap_seconds(-100.0 / input_sample_rate_);
+  MEDIAPIPE_ASSERT_OK(Run());
+  EXPECT_EQ(output().packets.size(), 6);
+  CheckOutput();
+}
+
+TEST_F(TimeSeriesFramerCalculatorTest, NegativeOverlapWithPadding) {
+  // 150 samples per frame plus a skip of 50 samples will require some padding
+  // on the sixth and last frame given 1100 sample input.
+  options_.set_frame_duration_seconds(150.0 / input_sample_rate_);
+  options_.set_frame_overlap_seconds(-50.0 / input_sample_rate_);
+  MEDIAPIPE_ASSERT_OK(Run());
+  EXPECT_EQ(output().packets.size(), 6);
+  CheckOutput();
+}
+
+TEST_F(TimeSeriesFramerCalculatorTest, FixedFrameOverlap) {
+  // Frame of 30 samples with step of 11.4 samples (rounded to 11 samples)
+  // results in ceil((1100 - 30) / 11) + 1 = 99 packets.
+  options_.set_frame_duration_seconds(30 / input_sample_rate_);
+  options_.set_frame_overlap_seconds((30.0 - 11.4) / input_sample_rate_);
+  MEDIAPIPE_ASSERT_OK(Run());
+  EXPECT_EQ(output().packets.size(), 99);
+  CheckOutput();
+}
+
+TEST_F(TimeSeriesFramerCalculatorTest, VariableFrameOverlap) {
+  // Frame of 30 samples with step of 11.4 samples (not rounded)
+  // results in ceil((1100 - 30) / 11.4) + 1 = 95 packets.
+  options_.set_frame_duration_seconds(30 / input_sample_rate_);
+  options_.set_frame_overlap_seconds((30 - 11.4) / input_sample_rate_);
+  options_.set_emulate_fractional_frame_overlap(true);
+  MEDIAPIPE_ASSERT_OK(Run());
+  EXPECT_EQ(output().packets.size(), 95);
+  CheckOutput();
+}
+
+TEST_F(TimeSeriesFramerCalculatorTest, VariableFrameSkip) {
+  // Frame of 30 samples with step of 41.4 samples (not rounded)
+  // results in ceil((1100 - 30) / 41.4) + 1 = 27 packets.
+  options_.set_frame_duration_seconds(30 / input_sample_rate_);
+  options_.set_frame_overlap_seconds((30 - 41.4) / input_sample_rate_);
+  options_.set_emulate_fractional_frame_overlap(true);
+  MEDIAPIPE_ASSERT_OK(Run());
+  EXPECT_EQ(output().packets.size(), 27);
+  CheckOutput();
+}
+
+TEST_F(TimeSeriesFramerCalculatorTest, NoFinalPacketPadding) {
+  options_.set_frame_duration_seconds(98.5 / input_sample_rate_);
+  options_.set_pad_final_packet(false);
+
+  MEDIAPIPE_ASSERT_OK(Run());
+  CheckOutput();
+}
+
+TEST_F(TimeSeriesFramerCalculatorTest,
+       FrameRateHigherThanSampleRate_FrameDurationTooLow) {
+  // Try to produce a frame rate 10 times the input sample rate by using
+  // a frame duration that is too small and covers only 0.1 samples.
+  options_.set_frame_duration_seconds(1 / (10 * input_sample_rate_));
+  options_.set_frame_overlap_seconds(0.0);
+  EXPECT_FALSE(Run().ok());
+}
+
+TEST_F(TimeSeriesFramerCalculatorTest,
+       FrameRateHigherThanSampleRate_FrameStepTooLow) {
+  // Try to produce a frame rate 10 times the input sample rate by using
+  // a frame overlap that is too high and produces frame steps (difference
+  // between duration and overlap) of 0.1 samples.
+ options_.set_frame_duration_seconds(10.0 / input_sample_rate_); + options_.set_frame_overlap_seconds(9.9 / input_sample_rate_); + EXPECT_FALSE(Run().ok()); +} + +// A simple test class to do windowing sanity checks. Tests from this +// class input a single packet of all ones, and check the average +// value of the single output packet. This is useful as a sanity check +// that the correct windows are applied. +class TimeSeriesFramerCalculatorWindowingSanityTest + : public TimeSeriesFramerCalculatorTest { + protected: + void SetUp() override { + TimeSeriesFramerCalculatorTest::SetUp(); + num_input_channels_ = 1; + } + + void RunAndTestSinglePacketAverage(float expected_average) { + options_.set_frame_duration_seconds(100.0 / input_sample_rate_); + InitializeGraph(); + FillInputHeader(); + AppendInputPacket(new Matrix(Matrix::Ones(1, FrameDurationSamples())), + kInitialTimestampOffsetMicroseconds); + MEDIAPIPE_ASSERT_OK(RunGraph()); + ASSERT_EQ(1, output().packets.size()); + ASSERT_NEAR(expected_average * FrameDurationSamples(), + output().packets[0].Get().sum(), 1e-5); + } +}; + +TEST_F(TimeSeriesFramerCalculatorWindowingSanityTest, NoWindowSanityCheck) { + RunAndTestSinglePacketAverage(1.0f); +} + +TEST_F(TimeSeriesFramerCalculatorWindowingSanityTest, + HammingWindowSanityCheck) { + options_.set_window_function(TimeSeriesFramerCalculatorOptions::HAMMING); + RunAndTestSinglePacketAverage(0.54f); +} + +TEST_F(TimeSeriesFramerCalculatorWindowingSanityTest, HannWindowSanityCheck) { + options_.set_window_function(TimeSeriesFramerCalculatorOptions::HANN); + RunAndTestSinglePacketAverage(0.5f); +} + +} // anonymous namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index bdf655805..4ecf8fe7b 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -19,6 +19,20 @@ package(default_visibility = ["//visibility:private"]) load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +proto_library( + name = "concatenate_vector_calculator_proto", + srcs = ["concatenate_vector_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "packet_cloner_calculator_proto", + srcs = ["packet_cloner_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + proto_library( name = "packet_resampler_calculator_proto", srcs = ["packet_resampler_calculator.proto"], @@ -26,6 +40,46 @@ proto_library( deps = ["//mediapipe/framework:calculator_proto"], ) +proto_library( + name = "split_vector_calculator_proto", + srcs = ["split_vector_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "quantize_float_vector_calculator_proto", + srcs = ["quantize_float_vector_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "sequence_shift_calculator_proto", + srcs = ["sequence_shift_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + ], +) + +proto_library( + name = "gate_calculator_proto", + srcs = ["gate_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + ], +) + +mediapipe_cc_proto_library( + name = "packet_cloner_calculator_cc_proto", + srcs = 
["packet_cloner_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":packet_cloner_calculator_proto"], +) + mediapipe_cc_proto_library( name = "packet_resampler_calculator_cc_proto", srcs = ["packet_resampler_calculator.proto"], @@ -34,6 +88,115 @@ mediapipe_cc_proto_library( deps = [":packet_resampler_calculator_proto"], ) +mediapipe_cc_proto_library( + name = "split_vector_calculator_cc_proto", + srcs = ["split_vector_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":split_vector_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "concatenate_vector_calculator_cc_proto", + srcs = ["concatenate_vector_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//visibility:public"], + deps = [":concatenate_vector_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "quantize_float_vector_calculator_cc_proto", + srcs = ["quantize_float_vector_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":quantize_float_vector_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "sequence_shift_calculator_cc_proto", + srcs = ["sequence_shift_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":sequence_shift_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "gate_calculator_cc_proto", + srcs = ["gate_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":gate_calculator_proto"], +) + +cc_library( + name = "add_header_calculator", + srcs = ["add_header_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:logging", + ], + alwayslink = 1, +) + +cc_test( + name = "add_header_calculator_test", + size = "small", + srcs = ["add_header_calculator_test.cc"], + deps = [ + ":add_header_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:validate_type", + ], +) + +cc_library( + name = "concatenate_vector_calculator", + srcs = ["concatenate_vector_calculator.cc"], + hdrs = ["concatenate_vector_calculator.h"], + visibility = ["//visibility:public"], + deps = [ + ":concatenate_vector_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "@org_tensorflow//tensorflow/lite:framework", + ], + alwayslink = 1, +) + +cc_library( + name = "concatenate_detection_vector_calculator", + srcs = ["concatenate_detection_vector_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":concatenate_vector_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:detection_cc_proto", + ], + alwayslink = 1, +) + +cc_test( + name = "concatenate_vector_calculator_test", + srcs = ["concatenate_vector_calculator_test.cc"], + deps = [ + ":concatenate_vector_calculator", + "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + 
"//mediapipe/framework:calculator_runner", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "counting_source_calculator", srcs = ["counting_source_calculator.cc"], @@ -62,6 +225,38 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "matrix_multiply_calculator", + srcs = ["matrix_multiply_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/port:status", + "@eigen_archive//:eigen", + ], + alwayslink = 1, +) + +cc_library( + name = "matrix_subtract_calculator", + srcs = ["matrix_subtract_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/port:status", + "@eigen_archive//:eigen", + ], + alwayslink = 1, +) + cc_library( name = "mux_calculator", srcs = ["mux_calculator.cc"], @@ -83,12 +278,37 @@ cc_library( "//visibility:public", ], deps = [ + "//mediapipe/calculators/core:packet_cloner_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "@com_google_absl//absl/strings", ], alwayslink = 1, ) +cc_library( + name = "packet_inner_join_calculator", + srcs = ["packet_inner_join_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +cc_test( + name = "packet_inner_join_calculator_test", + srcs = ["packet_inner_join_calculator_test.cc"], + deps = [ + ":packet_inner_join_calculator", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:validate_type", + ], +) + cc_library( name = "pass_through_calculator", srcs = ["pass_through_calculator.cc"], @@ -145,8 +365,8 @@ cc_library( ) cc_library( - name = "real_time_flow_limiter_calculator", - srcs = ["real_time_flow_limiter_calculator.cc"], + name = "flow_limiter_calculator", + srcs = ["flow_limiter_calculator.cc"], visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", @@ -191,19 +411,16 @@ cc_library( "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", + "//mediapipe/framework/deps:mathutil", + "//mediapipe/framework/deps:random", "//mediapipe/framework/formats:video_stream_header", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", "//mediapipe/framework/tool:options_util", "@com_google_absl//absl/strings", - "//mediapipe/framework/deps:mathutil", - "//mediapipe/framework/port:status", - "//mediapipe/framework/port:ret_check", - "//mediapipe/framework/port:logging", - "//mediapipe/framework/port:integral_types", - ] + select({ - "//conditions:default": [ - "//mediapipe/framework/deps:random", - ], - }), + ], alwayslink = 1, ) @@ -245,10 +462,43 @@ cc_test( ) cc_test( - name = "real_time_flow_limiter_calculator_test", - srcs = ["real_time_flow_limiter_calculator_test.cc"], + name = 
"matrix_multiply_calculator_test", + srcs = ["matrix_multiply_calculator_test.cc"], + visibility = ["//visibility:private"], deps = [ - ":real_time_flow_limiter_calculator", + ":matrix_multiply_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/tool:validate_type", + "@eigen_archive//:eigen", + ], +) + +cc_test( + name = "matrix_subtract_calculator_test", + srcs = ["matrix_subtract_calculator_test.cc"], + visibility = ["//visibility:private"], + deps = [ + ":matrix_subtract_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/tool:validate_type", + "@eigen_archive//:eigen", + ], +) + +cc_test( + name = "flow_limiter_calculator_test", + srcs = ["flow_limiter_calculator_test.cc"], + deps = [ + ":flow_limiter_calculator", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_framework", @@ -264,3 +514,140 @@ cc_test( "@com_google_absl//absl/time", ], ) + +cc_library( + name = "split_vector_calculator", + srcs = ["split_vector_calculator.cc"], + hdrs = ["split_vector_calculator.h"], + visibility = ["//visibility:public"], + deps = [ + ":split_vector_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/util:resource_util", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], + alwayslink = 1, +) + +cc_test( + name = "split_vector_calculator_test", + srcs = ["split_vector_calculator_test.cc"], + deps = [ + ":split_vector_calculator", + ":split_vector_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/tool:validate_type", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +) + +cc_library( + name = "quantize_float_vector_calculator", + srcs = ["quantize_float_vector_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":quantize_float_vector_calculator_cc_proto", + "//mediapipe/framework:calculator_context", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_test( + name = "quantize_float_vector_calculator_test", + srcs = ["quantize_float_vector_calculator_test.cc"], + deps = [ + ":quantize_float_vector_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + ], +) + +cc_library( + name = "sequence_shift_calculator", + srcs = ["sequence_shift_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":sequence_shift_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + 
"//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_test( + name = "sequence_shift_calculator_test", + srcs = ["sequence_shift_calculator_test.cc"], + deps = [ + ":sequence_shift_calculator", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:validate_type", + ], +) + +cc_library( + name = "gate_calculator", + srcs = ["gate_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":gate_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/util:header_util", + ], + alwayslink = 1, +) + +cc_test( + name = "gate_calculator_test", + srcs = ["gate_calculator_test.cc"], + deps = [ + ":gate_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + ], +) + +cc_library( + name = "merge_calculator", + srcs = ["merge_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_test( + name = "merge_calculator_test", + srcs = ["merge_calculator_test.cc"], + deps = [ + ":merge_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + ], +) diff --git a/mediapipe/calculators/core/add_header_calculator.cc b/mediapipe/calculators/core/add_header_calculator.cc new file mode 100644 index 000000000..341813817 --- /dev/null +++ b/mediapipe/calculators/core/add_header_calculator.cc @@ -0,0 +1,53 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/logging.h" + +namespace mediapipe { + +// Attach the header from one stream to another stream. +// +// The header stream (tag HEADER) must not have any packets in it. +// +// Before using this calculator, please think about changing your +// calculator to not need a header or to accept a separate stream with +// a header, that would be more future proof. 
+// +class AddHeaderCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Tag("HEADER").SetNone(); + cc->Inputs().Tag("DATA").SetAny(); + cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Tag("DATA")); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + const Packet& header = cc->Inputs().Tag("HEADER").Header(); + if (!header.IsEmpty()) { + cc->Outputs().Index(0).SetHeader(header); + } + cc->SetOffset(TimestampDiff(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + cc->Outputs().Index(0).AddPacket(cc->Inputs().Tag("DATA").Value()); + return ::mediapipe::OkStatus(); + } +}; + +REGISTER_CALCULATOR(AddHeaderCalculator); +} // namespace mediapipe diff --git a/mediapipe/calculators/core/add_header_calculator_test.cc b/mediapipe/calculators/core/add_header_calculator_test.cc new file mode 100644 index 000000000..206974125 --- /dev/null +++ b/mediapipe/calculators/core/add_header_calculator_test.cc @@ -0,0 +1,99 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/framework/tool/validate_type.h" + +namespace mediapipe { + +class AddHeaderCalculatorTest : public ::testing::Test {}; + +TEST_F(AddHeaderCalculatorTest, Works) { + CalculatorGraphConfig::Node node; + node.set_calculator("AddHeaderCalculator"); + node.add_input_stream("HEADER:header_stream"); + node.add_input_stream("DATA:data_stream"); + node.add_output_stream("merged_stream"); + + CalculatorRunner runner(node); + + // Set header and add 5 packets. + runner.MutableInputs()->Tag("HEADER").header = + Adopt(new std::string("my_header")); + for (int i = 0; i < 5; ++i) { + Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000)); + runner.MutableInputs()->Tag("DATA").packets.push_back(packet); + } + + // Run calculator. + MEDIAPIPE_ASSERT_OK(runner.Run()); + + ASSERT_EQ(1, runner.Outputs().NumEntries()); + + // Test output. 
+ EXPECT_EQ(std::string("my_header"), + runner.Outputs().Index(0).header.Get()); + const std::vector& output_packets = runner.Outputs().Index(0).packets; + ASSERT_EQ(5, output_packets.size()); + for (int i = 0; i < 5; ++i) { + const int val = output_packets[i].Get(); + EXPECT_EQ(i, val); + EXPECT_EQ(Timestamp(i * 1000), output_packets[i].Timestamp()); + } +} + +TEST_F(AddHeaderCalculatorTest, HandlesEmptyHeaderStream) { + CalculatorGraphConfig::Node node; + node.set_calculator("AddHeaderCalculator"); + node.add_input_stream("HEADER:header_stream"); + node.add_input_stream("DATA:data_stream"); + node.add_output_stream("merged_stream"); + + CalculatorRunner runner(node); + + // No header and no packets. + // Run calculator. + MEDIAPIPE_ASSERT_OK(runner.Run()); + EXPECT_TRUE(runner.Outputs().Index(0).header.IsEmpty()); +} + +TEST_F(AddHeaderCalculatorTest, NoPacketsOnHeaderStream) { + CalculatorGraphConfig::Node node; + node.set_calculator("AddHeaderCalculator"); + node.add_input_stream("HEADER:header_stream"); + node.add_input_stream("DATA:data_stream"); + node.add_output_stream("merged_stream"); + + CalculatorRunner runner(node); + + // Set header and add 5 packets. + runner.MutableInputs()->Tag("HEADER").header = + Adopt(new std::string("my_header")); + runner.MutableInputs()->Tag("HEADER").packets.push_back( + Adopt(new std::string("not allowed"))); + for (int i = 0; i < 5; ++i) { + Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000)); + runner.MutableInputs()->Tag("DATA").packets.push_back(packet); + } + + // Run calculator. + ASSERT_FALSE(runner.Run().ok()); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/concatenate_detection_vector_calculator.cc b/mediapipe/calculators/core/concatenate_detection_vector_calculator.cc new file mode 100644 index 000000000..161a323cf --- /dev/null +++ b/mediapipe/calculators/core/concatenate_detection_vector_calculator.cc @@ -0,0 +1,33 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mediapipe/calculators/core/concatenate_vector_calculator.h" +#include "mediapipe/framework/formats/detection.pb.h" + +namespace mediapipe { + +// Example config: +// node { +// calculator: "ConcatenateDetectionVectorCalculator" +// input_stream: "detection_vector_1" +// input_stream: "detection_vector_2" +// output_stream: "concatenated_detection_vector" +// } +typedef ConcatenateVectorCalculator<::mediapipe::Detection> + ConcatenateDetectionVectorCalculator; +REGISTER_CALCULATOR(ConcatenateDetectionVectorCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/concatenate_vector_calculator.cc b/mediapipe/calculators/core/concatenate_vector_calculator.cc new file mode 100644 index 000000000..087a9b6fe --- /dev/null +++ b/mediapipe/calculators/core/concatenate_vector_calculator.cc @@ -0,0 +1,44 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/core/concatenate_vector_calculator.h" + +#include + +#include "tensorflow/lite/interpreter.h" + +namespace mediapipe { + +// Example config: +// node { +// calculator: "ConcatenateFloatVectorCalculator" +// input_stream: "float_vector_1" +// input_stream: "float_vector_2" +// output_stream: "concatenated_float_vector" +// } +typedef ConcatenateVectorCalculator ConcatenateFloatVectorCalculator; +REGISTER_CALCULATOR(ConcatenateFloatVectorCalculator); + +// Example config: +// node { +// calculator: "ConcatenateTfLiteTensorVectorCalculator" +// input_stream: "tflitetensor_vector_1" +// input_stream: "tflitetensor_vector_2" +// output_stream: "concatenated_tflitetensor_vector" +// } +typedef ConcatenateVectorCalculator + ConcatenateTfLiteTensorVectorCalculator; +REGISTER_CALCULATOR(ConcatenateTfLiteTensorVectorCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/concatenate_vector_calculator.h b/mediapipe/calculators/core/concatenate_vector_calculator.h new file mode 100644 index 000000000..b7ee24a9c --- /dev/null +++ b/mediapipe/calculators/core/concatenate_vector_calculator.h @@ -0,0 +1,78 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_ +#define MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_ + +#include + +#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { + +// Concatenates several std::vector following stream index order. This class +// assumes that every input stream contains the vector type. To use this +// class for a particular type T, regisiter a calculator using +// ConcatenateVectorCalculator. 
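+//
+// For example, an int-vector concatenator could be registered as follows
+// (the calculator name below is illustrative; the unit test registers a
+// typedef of exactly this form):
+//
+//   typedef ConcatenateVectorCalculator<int> ConcatenateIntVectorCalculator;
+//   REGISTER_CALCULATOR(ConcatenateIntVectorCalculator);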
+template +class ConcatenateVectorCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + RET_CHECK(cc->Inputs().NumEntries() != 0); + RET_CHECK(cc->Outputs().NumEntries() == 1); + + for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + cc->Inputs().Index(i).Set>(); + } + + cc->Outputs().Index(0).Set>(); + + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + cc->SetOffset(TimestampDiff(0)); + only_emit_if_all_present_ = + cc->Options<::mediapipe::ConcatenateVectorCalculatorOptions>() + .only_emit_if_all_present(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + if (only_emit_if_all_present_) { + for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + if (cc->Inputs().Index(i).IsEmpty()) return ::mediapipe::OkStatus(); + } + } + auto output = absl::make_unique>(); + for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + if (cc->Inputs().Index(i).IsEmpty()) continue; + const std::vector& input = cc->Inputs().Index(i).Get>(); + output->insert(output->end(), input.begin(), input.end()); + } + cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); + } + + private: + bool only_emit_if_all_present_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_ diff --git a/mediapipe/calculators/core/concatenate_vector_calculator.proto b/mediapipe/calculators/core/concatenate_vector_calculator.proto new file mode 100644 index 000000000..bddb8af95 --- /dev/null +++ b/mediapipe/calculators/core/concatenate_vector_calculator.proto @@ -0,0 +1,29 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message ConcatenateVectorCalculatorOptions { + extend CalculatorOptions { + optional ConcatenateVectorCalculatorOptions ext = 259397839; + } + + // If true, the calculator will only emit a packet at the given timestamp if + // all input streams have a non-empty packet (AND operation on streams). + optional bool only_emit_if_all_present = 1 [default = false]; +} diff --git a/mediapipe/calculators/core/concatenate_vector_calculator_test.cc b/mediapipe/calculators/core/concatenate_vector_calculator_test.cc new file mode 100644 index 000000000..89f6976c0 --- /dev/null +++ b/mediapipe/calculators/core/concatenate_vector_calculator_test.cc @@ -0,0 +1,238 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/core/concatenate_vector_calculator.h" + +#include +#include +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" // NOLINT + +namespace mediapipe { + +typedef ConcatenateVectorCalculator TestConcatenateIntVectorCalculator; +REGISTER_CALCULATOR(TestConcatenateIntVectorCalculator); + +void AddInputVectors(const std::vector>& inputs, + int64 timestamp, CalculatorRunner* runner) { + for (int i = 0; i < inputs.size(); ++i) { + runner->MutableInputs()->Index(i).packets.push_back( + MakePacket>(inputs[i]).At(Timestamp(timestamp))); + } +} + +TEST(TestConcatenateIntVectorCalculatorTest, EmptyVectorInputs) { + CalculatorRunner runner("TestConcatenateIntVectorCalculator", + /*options_string=*/"", /*num_inputs=*/3, + /*num_outputs=*/1, /*num_side_packets=*/0); + + std::vector> inputs = {{}, {}, {}}; + AddInputVectors(inputs, /*timestamp=*/1, &runner); + MEDIAPIPE_ASSERT_OK(runner.Run()); + + const std::vector& outputs = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, outputs.size()); + EXPECT_TRUE(outputs[0].Get>().empty()); + EXPECT_EQ(Timestamp(1), outputs[0].Timestamp()); +} + +TEST(TestConcatenateIntVectorCalculatorTest, OneTimestamp) { + CalculatorRunner runner("TestConcatenateIntVectorCalculator", + /*options_string=*/"", /*num_inputs=*/3, + /*num_outputs=*/1, /*num_side_packets=*/0); + + std::vector> inputs = {{1, 2, 3}, {4}, {5, 6}}; + AddInputVectors(inputs, /*timestamp=*/1, &runner); + MEDIAPIPE_ASSERT_OK(runner.Run()); + + const std::vector& outputs = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(Timestamp(1), outputs[0].Timestamp()); + std::vector expected_vector = {1, 2, 3, 4, 5, 6}; + EXPECT_EQ(expected_vector, outputs[0].Get>()); +} + +TEST(TestConcatenateIntVectorCalculatorTest, TwoInputsAtTwoTimestamps) { + CalculatorRunner runner("TestConcatenateIntVectorCalculator", + /*options_string=*/"", /*num_inputs=*/3, + /*num_outputs=*/1, /*num_side_packets=*/0); + { + std::vector> inputs = {{1, 2, 3}, {4}, {5, 6}}; + AddInputVectors(inputs, /*timestamp=*/1, &runner); + } + { + std::vector> inputs = {{0, 2}, {1}, {3, 5}}; + AddInputVectors(inputs, /*timestamp=*/2, &runner); + } + MEDIAPIPE_ASSERT_OK(runner.Run()); + + const std::vector& outputs = runner.Outputs().Index(0).packets; + EXPECT_EQ(2, outputs.size()); + { + EXPECT_EQ(6, outputs[0].Get>().size()); + EXPECT_EQ(Timestamp(1), outputs[0].Timestamp()); + std::vector expected_vector = {1, 2, 3, 4, 5, 6}; + EXPECT_EQ(expected_vector, outputs[0].Get>()); + } + { + EXPECT_EQ(5, outputs[1].Get>().size()); + EXPECT_EQ(Timestamp(2), outputs[1].Timestamp()); + std::vector expected_vector = {0, 2, 1, 3, 5}; + EXPECT_EQ(expected_vector, outputs[1].Get>()); + } +} + +TEST(TestConcatenateIntVectorCalculatorTest, OneEmptyStreamStillOutput) { + CalculatorRunner 
runner("TestConcatenateIntVectorCalculator", + /*options_string=*/"", /*num_inputs=*/2, + /*num_outputs=*/1, /*num_side_packets=*/0); + + std::vector> inputs = {{1, 2, 3}}; + AddInputVectors(inputs, /*timestamp=*/1, &runner); + MEDIAPIPE_ASSERT_OK(runner.Run()); + + const std::vector& outputs = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(Timestamp(1), outputs[0].Timestamp()); + std::vector expected_vector = {1, 2, 3}; + EXPECT_EQ(expected_vector, outputs[0].Get>()); +} + +TEST(TestConcatenateIntVectorCalculatorTest, OneEmptyStreamNoOutput) { + CalculatorRunner runner("TestConcatenateIntVectorCalculator", + /*options_string=*/ + "[mediapipe.ConcatenateVectorCalculatorOptions.ext]: " + "{only_emit_if_all_present: true}", + /*num_inputs=*/2, + /*num_outputs=*/1, /*num_side_packets=*/0); + + std::vector> inputs = {{1, 2, 3}}; + AddInputVectors(inputs, /*timestamp=*/1, &runner); + MEDIAPIPE_ASSERT_OK(runner.Run()); + + const std::vector& outputs = runner.Outputs().Index(0).packets; + EXPECT_EQ(0, outputs.size()); +} + +void AddInputVectors(const std::vector>& inputs, + int64 timestamp, CalculatorRunner* runner) { + for (int i = 0; i < inputs.size(); ++i) { + runner->MutableInputs()->Index(i).packets.push_back( + MakePacket>(inputs[i]).At(Timestamp(timestamp))); + } +} + +TEST(ConcatenateFloatVectorCalculatorTest, EmptyVectorInputs) { + CalculatorRunner runner("ConcatenateFloatVectorCalculator", + /*options_string=*/"", /*num_inputs=*/3, + /*num_outputs=*/1, /*num_side_packets=*/0); + + std::vector> inputs = {{}, {}, {}}; + AddInputVectors(inputs, /*timestamp=*/1, &runner); + MEDIAPIPE_ASSERT_OK(runner.Run()); + + const std::vector& outputs = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, outputs.size()); + EXPECT_TRUE(outputs[0].Get>().empty()); + EXPECT_EQ(Timestamp(1), outputs[0].Timestamp()); +} + +TEST(ConcatenateFloatVectorCalculatorTest, OneTimestamp) { + CalculatorRunner runner("ConcatenateFloatVectorCalculator", + /*options_string=*/"", /*num_inputs=*/3, + /*num_outputs=*/1, /*num_side_packets=*/0); + + std::vector> inputs = { + {1.0f, 2.0f, 3.0f}, {4.0f}, {5.0f, 6.0f}}; + AddInputVectors(inputs, /*timestamp=*/1, &runner); + MEDIAPIPE_ASSERT_OK(runner.Run()); + + const std::vector& outputs = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(Timestamp(1), outputs[0].Timestamp()); + std::vector expected_vector = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + EXPECT_EQ(expected_vector, outputs[0].Get>()); +} + +TEST(ConcatenateFloatVectorCalculatorTest, TwoInputsAtTwoTimestamps) { + CalculatorRunner runner("ConcatenateFloatVectorCalculator", + /*options_string=*/"", /*num_inputs=*/3, + /*num_outputs=*/1, /*num_side_packets=*/0); + { + std::vector> inputs = { + {1.0f, 2.0f, 3.0f}, {4.0f}, {5.0f, 6.0f}}; + AddInputVectors(inputs, /*timestamp=*/1, &runner); + } + { + std::vector> inputs = { + {0.0f, 2.0f}, {1.0f}, {3.0f, 5.0f}}; + AddInputVectors(inputs, /*timestamp=*/2, &runner); + } + MEDIAPIPE_ASSERT_OK(runner.Run()); + + const std::vector& outputs = runner.Outputs().Index(0).packets; + EXPECT_EQ(2, outputs.size()); + { + EXPECT_EQ(6, outputs[0].Get>().size()); + EXPECT_EQ(Timestamp(1), outputs[0].Timestamp()); + std::vector expected_vector = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + EXPECT_EQ(expected_vector, outputs[0].Get>()); + } + { + EXPECT_EQ(5, outputs[1].Get>().size()); + EXPECT_EQ(Timestamp(2), outputs[1].Timestamp()); + std::vector expected_vector = {0.0f, 2.0f, 1.0f, 3.0f, 5.0f}; + EXPECT_EQ(expected_vector, 
outputs[1].Get>()); + } +} + +TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamStillOutput) { + CalculatorRunner runner("ConcatenateFloatVectorCalculator", + /*options_string=*/"", /*num_inputs=*/2, + /*num_outputs=*/1, /*num_side_packets=*/0); + + std::vector> inputs = {{1.0f, 2.0f, 3.0f}}; + AddInputVectors(inputs, /*timestamp=*/1, &runner); + MEDIAPIPE_ASSERT_OK(runner.Run()); + + const std::vector& outputs = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(Timestamp(1), outputs[0].Timestamp()); + std::vector expected_vector = {1.0f, 2.0f, 3.0f}; + EXPECT_EQ(expected_vector, outputs[0].Get>()); +} + +TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamNoOutput) { + CalculatorRunner runner("ConcatenateFloatVectorCalculator", + /*options_string=*/ + "[mediapipe.ConcatenateVectorCalculatorOptions.ext]: " + "{only_emit_if_all_present: true}", + /*num_inputs=*/2, + /*num_outputs=*/1, /*num_side_packets=*/0); + + std::vector> inputs = {{1.0f, 2.0f, 3.0f}}; + AddInputVectors(inputs, /*timestamp=*/1, &runner); + MEDIAPIPE_ASSERT_OK(runner.Run()); + + const std::vector& outputs = runner.Outputs().Index(0).packets; + EXPECT_EQ(0, outputs.size()); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc b/mediapipe/calculators/core/flow_limiter_calculator.cc similarity index 89% rename from mediapipe/calculators/core/real_time_flow_limiter_calculator.cc rename to mediapipe/calculators/core/flow_limiter_calculator.cc index ea857dc43..6d595e6cd 100644 --- a/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc +++ b/mediapipe/calculators/core/flow_limiter_calculator.cc @@ -23,34 +23,34 @@ namespace mediapipe { -// RealTimeFlowLimiterCalculator is used to limit the number of pipelined -// processing operations in a section of the graph. +// FlowLimiterCalculator is used to limit the number of pipelined processing +// operations in a section of the graph. // // Typical topology: // -// in ->-[RTFLC]-[foo]-...-[bar]-+->- out -// ^____________________| -// FINISHED +// in ->-[FLC]-[foo]-...-[bar]-+->- out +// ^_____________________| +// FINISHED // // By connecting the output of the graph section to this calculator's FINISHED -// input with a backwards edge, this allows RTFLC to keep track of how many +// input with a backwards edge, this allows FLC to keep track of how many // timestamps are currently being processed. // // The limit defaults to 1, and can be overridden with the MAX_IN_FLIGHT side // packet. // // As long as the number of timestamps being processed ("in flight") is below -// the limit, RTFLC allows input to pass through. When the limit is reached, -// RTFLC starts dropping input packets, keeping only the most recent. When the +// the limit, FLC allows input to pass through. When the limit is reached, +// FLC starts dropping input packets, keeping only the most recent. When the // processing count decreases again, as signaled by the receipt of a packet on -// FINISHED, RTFLC allows packets to flow again, releasing the most recently +// FINISHED, FLC allows packets to flow again, releasing the most recently // queued packet, if any. // // If there are multiple input streams, packet dropping is synchronized. // -// IMPORTANT: for each timestamp where RTFLC forwards a packet (or a set of +// IMPORTANT: for each timestamp where FLC forwards a packet (or a set of // packets, if using multiple data streams), a packet must eventually arrive on -// the FINISHED stream. 
Dropping packets in the section between RTFLC and +// the FINISHED stream. Dropping packets in the section between FLC and // FINISHED will make the in-flight count incorrect. // // TODO: Remove this comment when graph-level ISH has been removed. @@ -61,7 +61,7 @@ namespace mediapipe { // // Example config: // node { -// calculator: "RealTimeFlowLimiterCalculator" +// calculator: "FlowLimiterCalculator" // input_stream: "raw_frames" // input_stream: "FINISHED:finished" // input_stream_info: { @@ -73,7 +73,7 @@ namespace mediapipe { // } // output_stream: "gated_frames" // } -class RealTimeFlowLimiterCalculator : public CalculatorBase { +class FlowLimiterCalculator : public CalculatorBase { public: static ::mediapipe::Status GetContract(CalculatorContract* cc) { int num_data_streams = cc->Inputs().NumEntries(""); @@ -194,6 +194,6 @@ class RealTimeFlowLimiterCalculator : public CalculatorBase { Timestamp allow_ctr_ts_; std::vector data_stream_bound_ts_; }; -REGISTER_CALCULATOR(RealTimeFlowLimiterCalculator); +REGISTER_CALCULATOR(FlowLimiterCalculator); } // namespace mediapipe diff --git a/mediapipe/calculators/core/real_time_flow_limiter_calculator_test.cc b/mediapipe/calculators/core/flow_limiter_calculator_test.cc similarity index 93% rename from mediapipe/calculators/core/real_time_flow_limiter_calculator_test.cc rename to mediapipe/calculators/core/flow_limiter_calculator_test.cc index 8c386b8ea..23d13ef4b 100644 --- a/mediapipe/calculators/core/real_time_flow_limiter_calculator_test.cc +++ b/mediapipe/calculators/core/flow_limiter_calculator_test.cc @@ -71,7 +71,7 @@ constexpr int kNumImageFrames = 5; constexpr int kNumFinished = 3; CalculatorGraphConfig::Node GetDefaultNode() { return ParseTextProtoOrDie(R"( - calculator: "RealTimeFlowLimiterCalculator" + calculator: "FlowLimiterCalculator" input_stream: "raw_frames" input_stream: "FINISHED:finished" input_stream_info: { tag_index: "FINISHED" back_edge: true } @@ -79,9 +79,9 @@ CalculatorGraphConfig::Node GetDefaultNode() { )"); } -// Simple test to make sure that the RealTimeFlowLimiterCalculator outputs -// just one packet when MAX_IN_FLIGHT is 1. -TEST(RealTimeFlowLimiterCalculator, OneOutputTest) { +// Simple test to make sure that the FlowLimiterCalculator outputs just one +// packet when MAX_IN_FLIGHT is 1. +TEST(FlowLimiterCalculator, OneOutputTest) { // Setup the calculator runner and add only ImageFrame packets. CalculatorRunner runner(GetDefaultNode()); for (int i = 0; i < kNumImageFrames; ++i) { @@ -98,9 +98,9 @@ TEST(RealTimeFlowLimiterCalculator, OneOutputTest) { EXPECT_EQ(frame_output_packets.size(), 1); } -// Simple test to make sure that the RealTimeFlowLimiterCalculator waits for all +// Simple test to make sure that the FlowLimiterCalculator waits for all // input streams to have at least one packet available before publishing. -TEST(RealTimeFlowLimiterCalculator, BasicTest) { +TEST(FlowLimiterCalculator, BasicTest) { // Setup the calculator runner and add both ImageFrame and finish packets. CalculatorRunner runner(GetDefaultNode()); for (int i = 0; i < kNumImageFrames; ++i) { @@ -171,13 +171,11 @@ class CloseCallbackCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(CloseCallbackCalculator); -// Tests demostrating an RealTimeFlowLimiterCalculator operating in a cyclic -// graph. +// Tests demostrating an FlowLimiterCalculator operating in a cyclic graph. // TODO: clean up these tests. 
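The back-edge bookkeeping described in the comments above boils down to a counter of timestamps currently in flight. The following is a minimal conceptual sketch only, not the actual FlowLimiterCalculator, which additionally handles multiple synchronized data streams, timestamp bounds, and the most recently queued packet:

    #include <algorithm>

    // Conceptual sketch of MAX_IN_FLIGHT limiting with a FINISHED back edge.
    class InFlightLimiter {
     public:
      explicit InFlightLimiter(int max_in_flight)
          : max_in_flight_(max_in_flight) {}

      // Called for each new input timestamp; true means "forward it".
      bool AllowInput() {
        if (in_flight_ < max_in_flight_) {
          ++in_flight_;
          return true;  // Forward the packet downstream.
        }
        return false;   // Drop it (the real calculator keeps only the newest).
      }

      // Called when a packet arrives on the FINISHED back edge.
      void Finished() { in_flight_ = std::max(0, in_flight_ - 1); }

     private:
      const int max_in_flight_;
      int in_flight_ = 0;
    };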
-class RealTimeFlowLimiterCalculatorTest : public testing::Test { +class FlowLimiterCalculatorTest : public testing::Test { public: - RealTimeFlowLimiterCalculatorTest() - : enter_semaphore_(0), exit_semaphore_(0) {} + FlowLimiterCalculatorTest() : enter_semaphore_(0), exit_semaphore_(0) {} void SetUp() override { graph_config_ = InflightGraphConfig(); @@ -215,7 +213,7 @@ class RealTimeFlowLimiterCalculatorTest : public testing::Test { input_name, MakePacket(value).At(Timestamp(value)))); } - // A calculator graph starting with an RealTimeFlowLimiterCalculator and + // A calculator graph starting with an FlowLimiterCalculator and // ending with a InFlightFinishCalculator. // Back-edge "finished" limits processing to one frame in-flight. // The two LambdaCalculators are used to keep certain packet sets in flight. @@ -224,7 +222,7 @@ class RealTimeFlowLimiterCalculatorTest : public testing::Test { input_stream: 'in_1' input_stream: 'in_2' node { - calculator: 'RealTimeFlowLimiterCalculator' + calculator: 'FlowLimiterCalculator' input_side_packet: 'MAX_IN_FLIGHT:max_in_flight' input_stream: 'in_1' input_stream: 'in_2' @@ -270,14 +268,14 @@ class RealTimeFlowLimiterCalculatorTest : public testing::Test { int close_count_ = 0; }; -// A test demonstrating an RealTimeFlowLimiterCalculator operating in a cyclic +// A test demonstrating an FlowLimiterCalculator operating in a cyclic // graph. This test shows that: // // (1) Timestamps are passed through unaltered. // (2) All output streams including the back_edge stream are closed when // the first input stream is closed. // -TEST_F(RealTimeFlowLimiterCalculatorTest, BackEdgeCloses) { +TEST_F(FlowLimiterCalculatorTest, BackEdgeCloses) { InitializeGraph(1); MEDIAPIPE_ASSERT_OK(graph_.StartRun({})); @@ -321,7 +319,7 @@ TEST_F(RealTimeFlowLimiterCalculatorTest, BackEdgeCloses) { // A test demonstrating that all output streams are closed when all // input streams are closed after the last input packet has been processed. 
-TEST_F(RealTimeFlowLimiterCalculatorTest, AllStreamsClose) { +TEST_F(FlowLimiterCalculatorTest, AllStreamsClose) { InitializeGraph(1); MEDIAPIPE_ASSERT_OK(graph_.StartRun({})); @@ -341,7 +339,7 @@ TEST_F(RealTimeFlowLimiterCalculatorTest, AllStreamsClose) { EXPECT_EQ(1, close_count_); } -TEST(RealTimeFlowLimiterCalculator, TwoStreams) { +TEST(FlowLimiterCalculator, TwoStreams) { std::vector a_passed; std::vector b_passed; CalculatorGraphConfig graph_config_ = @@ -351,7 +349,7 @@ TEST(RealTimeFlowLimiterCalculator, TwoStreams) { input_stream: 'finished' node { name: 'input_dropper' - calculator: 'RealTimeFlowLimiterCalculator' + calculator: 'FlowLimiterCalculator' input_side_packet: 'MAX_IN_FLIGHT:max_in_flight' input_stream: 'in_a' input_stream: 'in_b' @@ -440,7 +438,7 @@ TEST(RealTimeFlowLimiterCalculator, TwoStreams) { MEDIAPIPE_EXPECT_OK(graph_.WaitUntilDone()); } -TEST(RealTimeFlowLimiterCalculator, CanConsume) { +TEST(FlowLimiterCalculator, CanConsume) { std::vector in_sampled_packets_; CalculatorGraphConfig graph_config_ = ParseTextProtoOrDie(R"( @@ -448,7 +446,7 @@ TEST(RealTimeFlowLimiterCalculator, CanConsume) { input_stream: 'finished' node { name: 'input_dropper' - calculator: 'RealTimeFlowLimiterCalculator' + calculator: 'FlowLimiterCalculator' input_side_packet: 'MAX_IN_FLIGHT:max_in_flight' input_stream: 'in' input_stream: 'FINISHED:finished' diff --git a/mediapipe/calculators/core/gate_calculator.cc b/mediapipe/calculators/core/gate_calculator.cc new file mode 100644 index 000000000..aedd01b64 --- /dev/null +++ b/mediapipe/calculators/core/gate_calculator.cc @@ -0,0 +1,163 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/core/gate_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/util/header_util.h" + +namespace mediapipe { + +namespace { +enum GateState { + GATE_UNINITIALIZED, + GATE_ALLOW, + GATE_DISALLOW, +}; + +std::string ToString(GateState state) { + switch (state) { + case GATE_UNINITIALIZED: + return "UNINITIALIZED"; + case GATE_ALLOW: + return "ALLOW"; + case GATE_DISALLOW: + return "DISALLOW"; + } + DLOG(FATAL) << "Unknown GateState"; + return "UNKNOWN"; +} +} // namespace + +// Controls whether or not the input packets are passed further along the graph. +// Takes multiple data input streams and either an ALLOW or a DISALLOW control +// input stream. It outputs an output stream for each input stream that is not +// ALLOW or DISALLOW as well as an optional STATE_CHANGE stream which downstream +// calculators can use to respond to state-change events. +// +// If the current ALLOW packet is set to true, the input packets are passed to +// their corresponding output stream unchanged. If the ALLOW packet is set to +// false, the current input packet is NOT passed to the output stream. If using +// DISALLOW, the behavior is opposite of ALLOW. 
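+// For example, with an ALLOW stream, a packet carrying true at timestamp T
+// lets the data packets at T pass through, while a packet carrying false at T
+// drops them; with a DISALLOW stream the same values have the opposite effect.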
+// +// By default, an empty packet in the ALLOW or DISALLOW input stream indicates +// disallowing the corresponding packets in other input streams. The behavior +// can be inverted with a calculator option. +// +// Intended to be used with the default input stream handler, which synchronizes +// all data input streams with the ALLOW/DISALLOW control input stream. +// +// Example config: +// node { +// calculator: "GateCalculator" +// input_stream: "input_stream0" +// input_stream: "input_stream1" +// input_stream: "input_streamN" +// input_stream: "ALLOW:allow" or "DISALLOW:disallow" +// output_stream: "STATE_CHANGE:state_change" +// output_stream: "output_stream0" +// output_stream: "output_stream1" +// output_stream: "output_streamN" +// } +class GateCalculator : public CalculatorBase { + public: + GateCalculator() {} + + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + // Assume that input streams do not have a tag and that gating signal is + // tagged either ALLOW or DISALLOW. + RET_CHECK(cc->Inputs().HasTag("ALLOW") ^ cc->Inputs().HasTag("DISALLOW")); + const int num_data_streams = cc->Inputs().NumEntries(""); + RET_CHECK_GE(num_data_streams, 1); + RET_CHECK_EQ(cc->Outputs().NumEntries(""), num_data_streams) + << "Number of data output streams must match with data input streams."; + + for (int i = 0; i < num_data_streams; ++i) { + cc->Inputs().Get("", i).SetAny(); + cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i)); + } + if (cc->Inputs().HasTag("ALLOW")) { + cc->Inputs().Tag("ALLOW").Set<bool>(); + } else { + cc->Inputs().Tag("DISALLOW").Set<bool>(); + } + + if (cc->Outputs().HasTag("STATE_CHANGE")) { + cc->Outputs().Tag("STATE_CHANGE").Set<bool>(); + } + + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->SetOffset(TimestampDiff(0)); + num_data_streams_ = cc->Inputs().NumEntries(""); + last_gate_state_ = GATE_UNINITIALIZED; + RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &cc->Outputs())); + + const auto& options = cc->Options<::mediapipe::GateCalculatorOptions>(); + empty_packets_as_allow_ = options.empty_packets_as_allow(); + + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + bool allow = empty_packets_as_allow_; + if (cc->Inputs().HasTag("ALLOW") && !cc->Inputs().Tag("ALLOW").IsEmpty()) { + allow = cc->Inputs().Tag("ALLOW").Get<bool>(); + } + if (cc->Inputs().HasTag("DISALLOW") && + !cc->Inputs().Tag("DISALLOW").IsEmpty()) { + allow = !cc->Inputs().Tag("DISALLOW").Get<bool>(); + } + + const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW; + + if (cc->Outputs().HasTag("STATE_CHANGE")) { + if (last_gate_state_ != GATE_UNINITIALIZED && + last_gate_state_ != new_gate_state) { + VLOG(2) << "State transition in " << cc->NodeName() << " @ " + << cc->InputTimestamp().Value() << " from " + << ToString(last_gate_state_) << " to " + << ToString(new_gate_state); + cc->Outputs() + .Tag("STATE_CHANGE") + .AddPacket(MakePacket<bool>(allow).At(cc->InputTimestamp())); + } + } + last_gate_state_ = new_gate_state; + + if (!allow) { + return ::mediapipe::OkStatus(); + } + + // Process data streams.
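+    // Forward each non-empty data packet to the output stream with the same
+    // index, leaving the timestamp unchanged; empty inputs produce no output
+    // at this timestamp.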
+ for (int i = 0; i < num_data_streams_; ++i) { + if (!cc->Inputs().Get("", i).IsEmpty()) { + cc->Outputs().Get("", i).AddPacket(cc->Inputs().Get("", i).Value()); + } + } + + return ::mediapipe::OkStatus(); + } + + private: + GateState last_gate_state_ = GATE_UNINITIALIZED; + int num_data_streams_; + bool empty_packets_as_allow_; +}; +REGISTER_CALCULATOR(GateCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/gate_calculator.proto b/mediapipe/calculators/core/gate_calculator.proto new file mode 100644 index 000000000..0ef2c3e1c --- /dev/null +++ b/mediapipe/calculators/core/gate_calculator.proto @@ -0,0 +1,30 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message GateCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional GateCalculatorOptions ext = 261754847; + } + + // By default an empty packet in the ALLOW or DISALLOW input stream indicates + // disallowing the corresponding packets in the data input streams. Setting + // this option to true inverts that, allowing the data packets to go through. + optional bool empty_packets_as_allow = 1; +} diff --git a/mediapipe/calculators/core/gate_calculator_test.cc b/mediapipe/calculators/core/gate_calculator_test.cc new file mode 100644 index 000000000..e00038879 --- /dev/null +++ b/mediapipe/calculators/core/gate_calculator_test.cc @@ -0,0 +1,190 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +namespace { + +class GateCalculatorTest : public ::testing::Test { + protected: + void RunTimeStep(int64 timestamp, const std::string& control_tag, + bool control) { + runner_->MutableInputs()->Get("", 0).packets.push_back( + MakePacket(true).At(Timestamp(timestamp))); + runner_->MutableInputs() + ->Tag(control_tag) + .packets.push_back(MakePacket(control).At(Timestamp(timestamp))); + + MEDIAPIPE_ASSERT_OK(runner_->Run()) << "Calculator execution failed."; + } + + void SetRunner(const std::string& proto) { + runner_ = absl::make_unique( + ParseTextProtoOrDie(proto)); + } + + CalculatorRunner* runner() { return runner_.get(); } + + private: + std::unique_ptr runner_; +}; + +TEST_F(GateCalculatorTest, Allow) { + SetRunner(R"( + calculator: "GateCalculator" + input_stream: "test_input" + input_stream: "ALLOW:gating_stream" + output_stream: "test_output" + )"); + + constexpr int64 kTimestampValue0 = 42; + RunTimeStep(kTimestampValue0, "ALLOW", true); + constexpr int64 kTimestampValue1 = 43; + RunTimeStep(kTimestampValue1, "ALLOW", false); + constexpr int64 kTimestampValue2 = 44; + RunTimeStep(kTimestampValue2, "ALLOW", true); + constexpr int64 kTimestampValue3 = 45; + RunTimeStep(kTimestampValue3, "ALLOW", false); + + const std::vector& output = runner()->Outputs().Get("", 0).packets; + ASSERT_EQ(2, output.size()); + EXPECT_EQ(kTimestampValue0, output[0].Timestamp().Value()); + EXPECT_EQ(kTimestampValue2, output[1].Timestamp().Value()); + EXPECT_EQ(true, output[0].Get()); + EXPECT_EQ(true, output[1].Get()); +} + +TEST_F(GateCalculatorTest, Disallow) { + SetRunner(R"( + calculator: "GateCalculator" + input_stream: "test_input" + input_stream: "DISALLOW:gating_stream" + output_stream: "test_output" + )"); + + constexpr int64 kTimestampValue0 = 42; + RunTimeStep(kTimestampValue0, "DISALLOW", true); + constexpr int64 kTimestampValue1 = 43; + RunTimeStep(kTimestampValue1, "DISALLOW", false); + constexpr int64 kTimestampValue2 = 44; + RunTimeStep(kTimestampValue2, "DISALLOW", true); + constexpr int64 kTimestampValue3 = 45; + RunTimeStep(kTimestampValue3, "DISALLOW", false); + + const std::vector& output = runner()->Outputs().Get("", 0).packets; + ASSERT_EQ(2, output.size()); + EXPECT_EQ(kTimestampValue1, output[0].Timestamp().Value()); + EXPECT_EQ(kTimestampValue3, output[1].Timestamp().Value()); + EXPECT_EQ(true, output[0].Get()); + EXPECT_EQ(true, output[1].Get()); +} + +TEST_F(GateCalculatorTest, AllowWithStateChange) { + SetRunner(R"( + calculator: "GateCalculator" + input_stream: "test_input" + input_stream: "ALLOW:gating_stream" + output_stream: "test_output" + output_stream: "STATE_CHANGE:state_changed" + )"); + + constexpr int64 kTimestampValue0 = 42; + RunTimeStep(kTimestampValue0, "ALLOW", false); + constexpr int64 kTimestampValue1 = 43; + RunTimeStep(kTimestampValue1, "ALLOW", true); + constexpr int64 kTimestampValue2 = 44; + RunTimeStep(kTimestampValue2, "ALLOW", true); + constexpr int64 kTimestampValue3 = 45; + RunTimeStep(kTimestampValue3, "ALLOW", false); + + const std::vector& output = + runner()->Outputs().Get("STATE_CHANGE", 0).packets; + ASSERT_EQ(2, output.size()); + EXPECT_EQ(kTimestampValue1, output[0].Timestamp().Value()); + EXPECT_EQ(kTimestampValue3, output[1].Timestamp().Value()); + 
EXPECT_EQ(true, output[0].Get()); // Allow. + EXPECT_EQ(false, output[1].Get()); // Disallow. +} + +TEST_F(GateCalculatorTest, DisallowWithStateChange) { + SetRunner(R"( + calculator: "GateCalculator" + input_stream: "test_input" + input_stream: "DISALLOW:gating_stream" + output_stream: "test_output" + output_stream: "STATE_CHANGE:state_changed" + )"); + + constexpr int64 kTimestampValue0 = 42; + RunTimeStep(kTimestampValue0, "DISALLOW", true); + constexpr int64 kTimestampValue1 = 43; + RunTimeStep(kTimestampValue1, "DISALLOW", false); + constexpr int64 kTimestampValue2 = 44; + RunTimeStep(kTimestampValue2, "DISALLOW", false); + constexpr int64 kTimestampValue3 = 45; + RunTimeStep(kTimestampValue3, "DISALLOW", true); + + const std::vector& output = + runner()->Outputs().Get("STATE_CHANGE", 0).packets; + ASSERT_EQ(2, output.size()); + EXPECT_EQ(kTimestampValue1, output[0].Timestamp().Value()); + EXPECT_EQ(kTimestampValue3, output[1].Timestamp().Value()); + EXPECT_EQ(true, output[0].Get()); // Allow. + EXPECT_EQ(false, output[1].Get()); // Disallow. +} + +// Must not detect disallow value for first timestamp as a state change. +TEST_F(GateCalculatorTest, DisallowInitialNoStateTransition) { + SetRunner(R"( + calculator: "GateCalculator" + input_stream: "test_input" + input_stream: "DISALLOW:gating_stream" + output_stream: "test_output" + output_stream: "STATE_CHANGE:state_changed" + )"); + + constexpr int64 kTimestampValue0 = 42; + RunTimeStep(kTimestampValue0, "DISALLOW", false); + + const std::vector& output = + runner()->Outputs().Get("STATE_CHANGE", 0).packets; + ASSERT_EQ(0, output.size()); +} + +// Must not detect allow value for first timestamp as a state change. +TEST_F(GateCalculatorTest, AllowInitialNoStateTransition) { + SetRunner(R"( + calculator: "GateCalculator" + input_stream: "test_input" + input_stream: "ALLOW:gating_stream" + output_stream: "test_output" + output_stream: "STATE_CHANGE:state_changed" + )"); + + constexpr int64 kTimestampValue0 = 42; + RunTimeStep(kTimestampValue0, "ALLOW", true); + + const std::vector& output = + runner()->Outputs().Get("STATE_CHANGE", 0).packets; + ASSERT_EQ(0, output.size()); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/core/immediate_mux_calculator_test.cc b/mediapipe/calculators/core/immediate_mux_calculator_test.cc index 6fe318712..974113e75 100644 --- a/mediapipe/calculators/core/immediate_mux_calculator_test.cc +++ b/mediapipe/calculators/core/immediate_mux_calculator_test.cc @@ -146,7 +146,7 @@ class ImmediateMuxCalculatorTest : public ::testing::Test { ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"( input_stream: "input_packets_0" node { - calculator: 'RealTimeFlowLimiterCalculator' + calculator: 'FlowLimiterCalculator' input_stream_handler { input_stream_handler: 'ImmediateInputStreamHandler' } diff --git a/mediapipe/calculators/core/matrix_multiply_calculator.cc b/mediapipe/calculators/core/matrix_multiply_calculator.cc new file mode 100644 index 000000000..8dc60b763 --- /dev/null +++ b/mediapipe/calculators/core/matrix_multiply_calculator.cc @@ -0,0 +1,66 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "Eigen/Core" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { +// Performs a (left) matrix multiplication, i.e. (output = A * input), +// where A is the matrix provided as an input side packet. +// +// Example config: +// node { +// calculator: "MatrixMultiplyCalculator" +// input_stream: "samples" +// output_stream: "multiplied_samples" +// input_side_packet: "multiplication_matrix" +// } +class MatrixMultiplyCalculator : public CalculatorBase { + public: + MatrixMultiplyCalculator() {} + ~MatrixMultiplyCalculator() override {} + + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; +}; +REGISTER_CALCULATOR(MatrixMultiplyCalculator); + +// static +::mediapipe::Status MatrixMultiplyCalculator::GetContract( + CalculatorContract* cc) { + cc->Inputs().Index(0).Set<Matrix>(); + cc->Outputs().Index(0).Set<Matrix>(); + cc->InputSidePackets().Index(0).Set<Matrix>(); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status MatrixMultiplyCalculator::Open(CalculatorContext* cc) { + // The output is at the same timestamp as the input. + cc->SetOffset(TimestampDiff(0)); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status MatrixMultiplyCalculator::Process(CalculatorContext* cc) { + Matrix* multiplied = new Matrix(); + *multiplied = cc->InputSidePackets().Index(0).Get<Matrix>() * + cc->Inputs().Index(0).Get<Matrix>(); + cc->Outputs().Index(0).Add(multiplied, cc->InputTimestamp()); + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/matrix_multiply_calculator_test.cc b/mediapipe/calculators/core/matrix_multiply_calculator_test.cc new file mode 100644 index 000000000..7c519d444 --- /dev/null +++ b/mediapipe/calculators/core/matrix_multiply_calculator_test.cc @@ -0,0 +1,239 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include +#include + +#include "Eigen/Core" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/validate_type.h" + +namespace mediapipe { +namespace { + +// A 3x4 Matrix of random integers in [0,1000). +const char kMatrixText[] = + "rows: 3\n" + "cols: 4\n" + "packed_data: 387\n" + "packed_data: 940\n" + "packed_data: 815\n" + "packed_data: 825\n" + "packed_data: 997\n" + "packed_data: 884\n" + "packed_data: 419\n" + "packed_data: 763\n" + "packed_data: 123\n" + "packed_data: 30\n" + "packed_data: 825\n" + "packed_data: 299\n"; + +// A 4x20 Matrix of random integers in [0,10). +// Each column of this matrix is a sample. +const char kSamplesText[] = + "rows: 4\n" + "cols: 20\n" + "packed_data: 7\n" + "packed_data: 9\n" + "packed_data: 5\n" + "packed_data: 9\n" + "packed_data: 6\n" + "packed_data: 3\n" + "packed_data: 0\n" + "packed_data: 7\n" + "packed_data: 1\n" + "packed_data: 3\n" + "packed_data: 3\n" + "packed_data: 2\n" + "packed_data: 4\n" + "packed_data: 5\n" + "packed_data: 0\n" + "packed_data: 4\n" + "packed_data: 6\n" + "packed_data: 0\n" + "packed_data: 1\n" + "packed_data: 2\n" + "packed_data: 0\n" + "packed_data: 2\n" + "packed_data: 0\n" + "packed_data: 3\n" + "packed_data: 1\n" + "packed_data: 7\n" + "packed_data: 4\n" + "packed_data: 9\n" + "packed_data: 8\n" + "packed_data: 8\n" + "packed_data: 6\n" + "packed_data: 4\n" + "packed_data: 6\n" + "packed_data: 8\n" + "packed_data: 1\n" + "packed_data: 9\n" + "packed_data: 7\n" + "packed_data: 5\n" + "packed_data: 3\n" + "packed_data: 5\n" + "packed_data: 3\n" + "packed_data: 5\n" + "packed_data: 7\n" + "packed_data: 7\n" + "packed_data: 3\n" + "packed_data: 3\n" + "packed_data: 6\n" + "packed_data: 4\n" + "packed_data: 7\n" + "packed_data: 7\n" + "packed_data: 2\n" + "packed_data: 5\n" + "packed_data: 4\n" + "packed_data: 8\n" + "packed_data: 1\n" + "packed_data: 0\n" + "packed_data: 2\n" + "packed_data: 0\n" + "packed_data: 3\n" + "packed_data: 4\n" + "packed_data: 6\n" + "packed_data: 6\n" + "packed_data: 8\n" + "packed_data: 5\n" + "packed_data: 5\n" + "packed_data: 8\n" + "packed_data: 9\n" + "packed_data: 7\n" + "packed_data: 3\n" + "packed_data: 7\n" + "packed_data: 2\n" + "packed_data: 7\n" + "packed_data: 8\n" + "packed_data: 2\n" + "packed_data: 1\n" + "packed_data: 1\n" + "packed_data: 4\n" + "packed_data: 1\n" + "packed_data: 1\n" + "packed_data: 7\n"; + +// A 3x20 Matrix of expected values for the result of the matrix multiply +// computed using R. +// Each column of this matrix is an expected output. 
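+// For example, the first expected entry is the dot product of the first row of
+// the 3x4 matrix (387, 825, 419, 30) with the first sample column (7, 9, 5, 9):
+// 387*7 + 825*9 + 419*5 + 30*9 = 12499 (packed_data is stored column-major).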
+const char kExpectedText[] = + "rows: 3\n" + "cols: 20\n" + "packed_data: 12499\n" + "packed_data: 26793\n" + "packed_data: 16967\n" + "packed_data: 5007\n" + "packed_data: 14406\n" + "packed_data: 9635\n" + "packed_data: 4179\n" + "packed_data: 7870\n" + "packed_data: 4434\n" + "packed_data: 5793\n" + "packed_data: 12045\n" + "packed_data: 8876\n" + "packed_data: 2801\n" + "packed_data: 8053\n" + "packed_data: 5611\n" + "packed_data: 1740\n" + "packed_data: 4469\n" + "packed_data: 2665\n" + "packed_data: 8108\n" + "packed_data: 18396\n" + "packed_data: 10186\n" + "packed_data: 12330\n" + "packed_data: 23374\n" + "packed_data: 15526\n" + "packed_data: 9611\n" + "packed_data: 21804\n" + "packed_data: 14776\n" + "packed_data: 8241\n" + "packed_data: 17979\n" + "packed_data: 11989\n" + "packed_data: 8429\n" + "packed_data: 18921\n" + "packed_data: 9819\n" + "packed_data: 6270\n" + "packed_data: 13689\n" + "packed_data: 7031\n" + "packed_data: 9472\n" + "packed_data: 19210\n" + "packed_data: 13634\n" + "packed_data: 8567\n" + "packed_data: 12499\n" + "packed_data: 10455\n" + "packed_data: 2151\n" + "packed_data: 7469\n" + "packed_data: 3195\n" + "packed_data: 10774\n" + "packed_data: 21851\n" + "packed_data: 12673\n" + "packed_data: 12516\n" + "packed_data: 25318\n" + "packed_data: 14347\n" + "packed_data: 7984\n" + "packed_data: 17100\n" + "packed_data: 10972\n" + "packed_data: 5195\n" + "packed_data: 11102\n" + "packed_data: 8710\n" + "packed_data: 3002\n" + "packed_data: 11295\n" + "packed_data: 6360\n"; + +// Send a number of samples through the MatrixMultiplyCalculator. +TEST(MatrixMultiplyCalculatorTest, Multiply) { + CalculatorRunner runner("MatrixMultiplyCalculator", "", 1, 1, 1); + Matrix* matrix = new Matrix(); + MatrixFromTextProto(kMatrixText, matrix); + runner.MutableSidePackets()->Index(0) = Adopt(matrix); + + Matrix samples; + MatrixFromTextProto(kSamplesText, &samples); + Matrix expected; + MatrixFromTextProto(kExpectedText, &expected); + CHECK_EQ(samples.cols(), expected.cols()); + + for (int i = 0; i < samples.cols(); ++i) { + // Take a column from samples and produce a packet with just that + // column in it as an input sample for the calculator. + Eigen::MatrixXf* sample = new Eigen::MatrixXf(samples.block(0, i, 4, 1)); + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(sample).At(Timestamp(i))); + } + + MEDIAPIPE_ASSERT_OK(runner.Run()); + EXPECT_EQ(runner.MutableInputs()->Index(0).packets.size(), + runner.Outputs().Index(0).packets.size()); + + int i = 0; + for (const Packet& output : runner.Outputs().Index(0).packets) { + EXPECT_EQ(Timestamp(i), output.Timestamp()); + const Eigen::MatrixXf& result = output.Get(); + ASSERT_EQ(3, result.rows()); + EXPECT_NEAR((expected.block(0, i, 3, 1) - result).cwiseAbs().sum(), 0.0, + 1e-5); + ++i; + } + EXPECT_EQ(samples.cols(), i); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/core/matrix_subtract_calculator.cc b/mediapipe/calculators/core/matrix_subtract_calculator.cc new file mode 100644 index 000000000..af13a0d38 --- /dev/null +++ b/mediapipe/calculators/core/matrix_subtract_calculator.cc @@ -0,0 +1,123 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "Eigen/Core" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { + +// Subtracts the input matrix from the side input matrix, or vice versa. The +// matrices must have the same dimensions. +// Based on the tag (MINUEND vs SUBTRAHEND), the matrices in the input stream +// and input side packet can be either subtrahend or minuend. The output matrix +// is computed as (minuend matrix - subtrahend matrix). +// +// Example config: +// node { +// calculator: "MatrixSubtractCalculator" +// input_stream: "MINUEND:input_matrix" +// input_side_packet: "SUBTRAHEND:side_matrix" +// output_stream: "output_matrix" +// } +// +// or +// +// node { +// calculator: "MatrixSubtractCalculator" +// input_stream: "SUBTRAHEND:input_matrix" +// input_side_packet: "MINUEND:side_matrix" +// output_stream: "output_matrix" +// } +class MatrixSubtractCalculator : public CalculatorBase { + public: + MatrixSubtractCalculator() {} + ~MatrixSubtractCalculator() override {} + + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + bool subtract_from_input_ = false; +}; +REGISTER_CALCULATOR(MatrixSubtractCalculator); + +// static +::mediapipe::Status MatrixSubtractCalculator::GetContract( + CalculatorContract* cc) { + if (cc->Inputs().NumEntries() != 1 || + cc->InputSidePackets().NumEntries() != 1) { + return ::mediapipe::InvalidArgumentError( + "MatrixSubtractCalculator only accepts exactly one input stream and " + "one " + "input side packet"); + } + if (cc->Inputs().HasTag("MINUEND") && + cc->InputSidePackets().HasTag("SUBTRAHEND")) { + cc->Inputs().Tag("MINUEND").Set<Matrix>(); + cc->InputSidePackets().Tag("SUBTRAHEND").Set<Matrix>(); + } else if (cc->Inputs().HasTag("SUBTRAHEND") && + cc->InputSidePackets().HasTag("MINUEND")) { + cc->Inputs().Tag("SUBTRAHEND").Set<Matrix>(); + cc->InputSidePackets().Tag("MINUEND").Set<Matrix>(); + } else { + return ::mediapipe::InvalidArgumentError( + "Must specify exactly one minuend and one subtrahend."); + } + cc->Outputs().Index(0).Set<Matrix>(); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status MatrixSubtractCalculator::Open(CalculatorContext* cc) { + // The output is at the same timestamp as the input.
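+  // A zero offset tells the framework that each output timestamp equals the
+  // corresponding input timestamp, so timestamp bounds can be propagated
+  // downstream without waiting for the output packet.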
+ cc->SetOffset(TimestampDiff(0)); + if (cc->Inputs().HasTag("MINUEND")) { + subtract_from_input_ = true; + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status MatrixSubtractCalculator::Process(CalculatorContext* cc) { + Matrix* subtracted = new Matrix(); + if (subtract_from_input_) { + const Matrix& input_matrix = cc->Inputs().Tag("MINUEND").Get<Matrix>(); + const Matrix& side_input_matrix = + cc->InputSidePackets().Tag("SUBTRAHEND").Get<Matrix>(); + if (input_matrix.rows() != side_input_matrix.rows() || + input_matrix.cols() != side_input_matrix.cols()) { + return ::mediapipe::InvalidArgumentError( + "Input matrix and the input side matrix must have the same " + "dimension."); + } + *subtracted = input_matrix - side_input_matrix; + } else { + const Matrix& input_matrix = cc->Inputs().Tag("SUBTRAHEND").Get<Matrix>(); + const Matrix& side_input_matrix = + cc->InputSidePackets().Tag("MINUEND").Get<Matrix>(); + if (input_matrix.rows() != side_input_matrix.rows() || + input_matrix.cols() != side_input_matrix.cols()) { + return ::mediapipe::InvalidArgumentError( + "Input matrix and the input side matrix must have the same " + "dimension."); + } + *subtracted = side_input_matrix - input_matrix; + } + cc->Outputs().Index(0).Add(subtracted, cc->InputTimestamp()); + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/matrix_subtract_calculator_test.cc b/mediapipe/calculators/core/matrix_subtract_calculator_test.cc new file mode 100644 index 000000000..edba5f116 --- /dev/null +++ b/mediapipe/calculators/core/matrix_subtract_calculator_test.cc @@ -0,0 +1,157 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "Eigen/Core" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/validate_type.h" + +namespace mediapipe { +namespace { + +// A 3x4 Matrix of random integers in [0,1000).
+const char kMatrixText[] = + "rows: 3\n" + "cols: 4\n" + "packed_data: 387\n" + "packed_data: 940\n" + "packed_data: 815\n" + "packed_data: 825\n" + "packed_data: 997\n" + "packed_data: 884\n" + "packed_data: 419\n" + "packed_data: 763\n" + "packed_data: 123\n" + "packed_data: 30\n" + "packed_data: 825\n" + "packed_data: 299\n"; + +const char kMatrixText2[] = + "rows: 3\n" + "cols: 4\n" + "packed_data: 388\n" + "packed_data: 941\n" + "packed_data: 816\n" + "packed_data: 826\n" + "packed_data: 998\n" + "packed_data: 885\n" + "packed_data: 420\n" + "packed_data: 764\n" + "packed_data: 124\n" + "packed_data: 31\n" + "packed_data: 826\n" + "packed_data: 300\n"; + +TEST(MatrixSubtractCalculatorTest, WrongConfig) { + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie(R"( + calculator: "MatrixSubtractCalculator" + input_stream: "input_matrix" + input_side_packet: "SUBTRAHEND:side_matrix" + input_side_packet: "MINUEND:side_matrix2" + output_stream: "output_matrix" + )"); + CalculatorRunner runner(node_config); + auto status = runner.Run(); + EXPECT_THAT( + status.message(), + testing::HasSubstr( + "only accepts exactly one input stream and one input side packet")); +} + +TEST(MatrixSubtractCalculatorTest, WrongConfig2) { + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie(R"( + calculator: "MatrixSubtractCalculator" + input_side_packet: "SUBTRAHEND:side_matrix" + input_stream: "SUBTRAHEND:side_matrix2" + output_stream: "output_matrix" + )"); + CalculatorRunner runner(node_config); + auto status = runner.Run(); + EXPECT_THAT( + status.message(), + testing::HasSubstr("specify exactly one minuend and one subtrahend.")); +} + +TEST(MatrixSubtractCalculatorTest, SubtractFromInput) { + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie(R"( + calculator: "MatrixSubtractCalculator" + input_stream: "MINUEND:input_matrix" + input_side_packet: "SUBTRAHEND:side_matrix" + output_stream: "output_matrix" + )"); + CalculatorRunner runner(node_config); + Matrix* side_matrix = new Matrix(); + MatrixFromTextProto(kMatrixText, side_matrix); + runner.MutableSidePackets()->Tag("SUBTRAHEND") = Adopt(side_matrix); + + Matrix* input_matrix = new Matrix(); + MatrixFromTextProto(kMatrixText2, input_matrix); + runner.MutableInputs()->Tag("MINUEND").packets.push_back( + Adopt(input_matrix).At(Timestamp(0))); + + MEDIAPIPE_ASSERT_OK(runner.Run()); + EXPECT_EQ(1, runner.Outputs().Index(0).packets.size()); + + EXPECT_EQ(Timestamp(0), runner.Outputs().Index(0).packets[0].Timestamp()); + const Eigen::MatrixXf& result = + runner.Outputs().Index(0).packets[0].Get(); + ASSERT_EQ(3, result.rows()); + ASSERT_EQ(4, result.cols()); + EXPECT_NEAR(result.sum(), 12, 1e-5); +} + +TEST(MatrixSubtractCalculatorTest, SubtractFromSideMatrix) { + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie(R"( + calculator: "MatrixSubtractCalculator" + input_stream: "SUBTRAHEND:input_matrix" + input_side_packet: "MINUEND:side_matrix" + output_stream: "output_matrix" + )"); + CalculatorRunner runner(node_config); + Matrix* side_matrix = new Matrix(); + MatrixFromTextProto(kMatrixText, side_matrix); + runner.MutableSidePackets()->Tag("MINUEND") = Adopt(side_matrix); + + Matrix* input_matrix = new Matrix(); + MatrixFromTextProto(kMatrixText2, input_matrix); + runner.MutableInputs() + ->Tag("SUBTRAHEND") + .packets.push_back(Adopt(input_matrix).At(Timestamp(0))); + + MEDIAPIPE_ASSERT_OK(runner.Run()); + EXPECT_EQ(1, runner.Outputs().Index(0).packets.size()); + + EXPECT_EQ(Timestamp(0), 
runner.Outputs().Index(0).packets[0].Timestamp()); + const Eigen::MatrixXf& result = + runner.Outputs().Index(0).packets[0].Get(); + ASSERT_EQ(3, result.rows()); + ASSERT_EQ(4, result.cols()); + EXPECT_NEAR(result.sum(), -12, 1e-5); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/core/merge_calculator.cc b/mediapipe/calculators/core/merge_calculator.cc new file mode 100644 index 000000000..e85ae0c12 --- /dev/null +++ b/mediapipe/calculators/core/merge_calculator.cc @@ -0,0 +1,91 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { + +// This calculator takes a set of input streams and combines them into a single +// output stream. The packets from different streams do not need to contain the +// same type. If there are packets arriving at the same time from two or more +// input streams, the packet corresponding to the input stream with the smallest +// index is passed to the output and the rest are ignored. +// +// Example use-case: +// Suppose we have two (or more) different algorithms for detecting shot +// boundaries and we need to merge their packets into a single stream. The +// algorithms may emit shot boundaries at the same time and their output types +// may not be compatible. Subsequent calculators that process the merged stream +// may be interested only in the timestamps of the shot boundary packets and so +// it may not even need to inspect the values stored inside the packets. +// +// Example config: +// node { +// calculator: "MergeCalculator" +// input_stream: "shot_info1" +// input_stream: "shot_info2" +// input_stream: "shot_info3" +// output_stream: "merged_shot_infos" +// } +// +class MergeCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + RET_CHECK_GT(cc->Inputs().NumEntries(), 0) + << "Needs at least one input stream"; + RET_CHECK_EQ(cc->Outputs().NumEntries(), 1); + if (cc->Inputs().NumEntries() == 1) { + LOG(WARNING) + << "MergeCalculator expects multiple input streams to merge but is " + "receiving only one. Make sure the calculator is configured " + "correctly or consider removing this calculator to reduce " + "unnecessary overhead."; + } + + for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + cc->Inputs().Index(i).SetAny(); + } + cc->Outputs().Index(0).SetAny(); + + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + cc->SetOffset(TimestampDiff(0)); + + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + // Output the packet from the first input stream with a packet ready at this + // timestamp. 
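+    // Inputs are visited in index order, so if several streams have a packet
+    // at this timestamp only the lowest-index one is forwarded; the rest are
+    // ignored, as described in the header comment.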
+ for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { + if (!cc->Inputs().Index(i).IsEmpty()) { + cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(i).Value()); + return ::mediapipe::OkStatus(); + } + } + + LOG(WARNING) << "Empty input packets at timestamp " + << cc->InputTimestamp().Value(); + + return ::mediapipe::OkStatus(); + } +}; + +REGISTER_CALCULATOR(MergeCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/merge_calculator_test.cc b/mediapipe/calculators/core/merge_calculator_test.cc new file mode 100644 index 000000000..53185c0d8 --- /dev/null +++ b/mediapipe/calculators/core/merge_calculator_test.cc @@ -0,0 +1,139 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +// Checks that the calculator fails if no input streams are provided. +TEST(InvariantMergeInputStreamsCalculator, NoInputStreamsMustFail) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "MergeCalculator" + output_stream: "merged_output" + )")); + // Expect calculator to fail. + ASSERT_FALSE(runner.Run().ok()); +} + +// Checks that the calculator fails with an incorrect number of output streams. +TEST(InvariantMergeInputStreamsCalculator, ExpectExactlyOneOutputStream) { + CalculatorRunner runner1(ParseTextProtoOrDie(R"( + calculator: "MergeCalculator" + input_stream: "input1" + input_stream: "input2" + )")); + // Expect calculator to fail. + EXPECT_FALSE(runner1.Run().ok()); + + CalculatorRunner runner2(ParseTextProtoOrDie(R"( + calculator: "MergeCalculator" + input_stream: "input1" + input_stream: "input2" + output_stream: "output1" + output_stream: "output2" + )")); + // Expect calculator to fail. + ASSERT_FALSE(runner2.Run().ok()); +} + +// Ensures two streams with differing types can be merged correctly. +TEST(MediaPipeDetectionToSoapboxDetectionCalculatorTest, + TestMergingTwoStreams) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "MergeCalculator" + input_stream: "input1" + input_stream: "input2" + output_stream: "combined_output" + )")); + + // input1: integers 10, 20, 30, occurring at times 10, 20, 30. + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(new int(10)).At(Timestamp(10))); + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(new int(20)).At(Timestamp(20))); + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(new int(30)).At(Timestamp(30))); + // input2: floats 5.5, 35.5 at times 5, 35. 
+ runner.MutableInputs()->Index(1).packets.push_back( + Adopt(new float(5.5)).At(Timestamp(5))); + runner.MutableInputs()->Index(1).packets.push_back( + Adopt(new float(35.5)).At(Timestamp(35))); + + MEDIAPIPE_ASSERT_OK(runner.Run()); + + // Expected combined_output: 5.5, 10, 20, 30, 35.5 at times 5, 10, 20, 30, 35. + const std::vector& actual_output = runner.Outputs().Index(0).packets; + ASSERT_EQ(actual_output.size(), 5); + EXPECT_EQ(actual_output[0].Timestamp(), Timestamp(5)); + EXPECT_EQ(actual_output[0].Get(), 5.5); + + EXPECT_EQ(actual_output[1].Timestamp(), Timestamp(10)); + EXPECT_EQ(actual_output[1].Get(), 10); + + EXPECT_EQ(actual_output[2].Timestamp(), Timestamp(20)); + EXPECT_EQ(actual_output[2].Get(), 20); + + EXPECT_EQ(actual_output[3].Timestamp(), Timestamp(30)); + EXPECT_EQ(actual_output[3].Get(), 30); + + EXPECT_EQ(actual_output[4].Timestamp(), Timestamp(35)); + EXPECT_EQ(actual_output[4].Get(), 35.5); +} + +// Ensures three streams with differing types can be merged correctly. +TEST(MediaPipeDetectionToSoapboxDetectionCalculatorTest, + TestMergingThreeStreams) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "MergeCalculator" + input_stream: "input1" + input_stream: "input2" + input_stream: "input3" + output_stream: "combined_output" + )")); + + // input1: integer 30 occurring at time 30. + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(new int(30)).At(Timestamp(30))); + // input2: float 20.5 occurring at time 20. + runner.MutableInputs()->Index(1).packets.push_back( + Adopt(new float(20.5)).At(Timestamp(20))); + // input3: char 'c' occurring at time 10. + runner.MutableInputs()->Index(2).packets.push_back( + Adopt(new char('c')).At(Timestamp(10))); + + MEDIAPIPE_ASSERT_OK(runner.Run()); + + // Expected combined_output: 'c', 20.5, 30 at times 10, 20, 30. + const std::vector& actual_output = runner.Outputs().Index(0).packets; + ASSERT_EQ(actual_output.size(), 3); + EXPECT_EQ(actual_output[0].Timestamp(), Timestamp(10)); + EXPECT_EQ(actual_output[0].Get(), 'c'); + + EXPECT_EQ(actual_output[1].Timestamp(), Timestamp(20)); + EXPECT_EQ(actual_output[1].Get(), 20.5); + + EXPECT_EQ(actual_output[2].Timestamp(), Timestamp(30)); + EXPECT_EQ(actual_output[2].Get(), 30); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/core/mux_calculator.cc b/mediapipe/calculators/core/mux_calculator.cc index 7f96da760..1d1ae1904 100644 --- a/mediapipe/calculators/core/mux_calculator.cc +++ b/mediapipe/calculators/core/mux_calculator.cc @@ -51,7 +51,7 @@ class MuxCalculator : public CalculatorBase { data_input_base_ = cc->Inputs().GetId("INPUT", 0); num_data_inputs_ = cc->Inputs().NumEntries("INPUT"); output_ = cc->Outputs().GetId("OUTPUT", 0); - cc->SetOffset(mediapipe::TimestampDiff(0)); + cc->SetOffset(TimestampDiff(0)); return ::mediapipe::OkStatus(); } diff --git a/mediapipe/calculators/core/packet_cloner_calculator.cc b/mediapipe/calculators/core/packet_cloner_calculator.cc index 2750f1257..26044fc2c 100644 --- a/mediapipe/calculators/core/packet_cloner_calculator.cc +++ b/mediapipe/calculators/core/packet_cloner_calculator.cc @@ -19,6 +19,7 @@ #include #include "absl/strings/str_cat.h" +#include "mediapipe/calculators/core/packet_cloner_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" namespace mediapipe { @@ -39,6 +40,7 @@ namespace mediapipe { // } // // Related: +// packet_cloner_calculator.proto: Options for this calculator. // merge_input_streams_calculator.cc: One output stream. 
// packet_inner_join_calculator.cc: Don't output unless all inputs are new. class PacketClonerCalculator : public CalculatorBase { @@ -54,6 +56,13 @@ class PacketClonerCalculator : public CalculatorBase { } ::mediapipe::Status Open(CalculatorContext* cc) final { + // Load options. + const auto calculator_options = + cc->Options(); + output_only_when_all_inputs_received_ = + calculator_options.output_only_when_all_inputs_received(); + + // Parse input streams. tick_signal_index_ = cc->Inputs().NumEntries() - 1; current_.resize(tick_signal_index_); // Pass along the header for each stream if present. @@ -73,8 +82,17 @@ class PacketClonerCalculator : public CalculatorBase { } } - // Output if the tick signal is non-empty. + // Output according to the TICK signal. if (!cc->Inputs().Index(tick_signal_index_).Value().IsEmpty()) { + if (output_only_when_all_inputs_received_) { + // Return if one of the input is null. + for (int i = 0; i < tick_signal_index_; ++i) { + if (current_[i].IsEmpty()) { + return ::mediapipe::OkStatus(); + } + } + } + // Output each stream. for (int i = 0; i < tick_signal_index_; ++i) { if (!current_[i].IsEmpty()) { cc->Outputs().Index(i).AddPacket( @@ -91,6 +109,7 @@ class PacketClonerCalculator : public CalculatorBase { private: std::vector current_; int tick_signal_index_; + bool output_only_when_all_inputs_received_; }; REGISTER_CALCULATOR(PacketClonerCalculator); diff --git a/mediapipe/calculators/core/packet_cloner_calculator.proto b/mediapipe/calculators/core/packet_cloner_calculator.proto new file mode 100644 index 000000000..7abb16163 --- /dev/null +++ b/mediapipe/calculators/core/packet_cloner_calculator.proto @@ -0,0 +1,29 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message PacketClonerCalculatorOptions { + extend CalculatorOptions { + optional PacketClonerCalculatorOptions ext = 258872085; + } + + // When true, this calculator will drop received TICK packets if any input + // stream hasn't received a packet yet. + optional bool output_only_when_all_inputs_received = 1 [default = false]; +} diff --git a/mediapipe/calculators/core/packet_inner_join_calculator.cc b/mediapipe/calculators/core/packet_inner_join_calculator.cc new file mode 100644 index 000000000..2b93df3cf --- /dev/null +++ b/mediapipe/calculators/core/packet_inner_join_calculator.cc @@ -0,0 +1,78 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { + +// Calculator that acts like the SQL query: +// SELECT * +// FROM packets_on_stream1 AS packet1 +// INNER JOIN packets_on_stream2 AS packet2 +// ON packet1.timestamp = packet2.timestamp +// +// In other words, it only emits and forwards packets if all input streams are +// not empty. +// +// Intended for use with FixedSizeInputStreamHandler. +// +// Related: +// packet_cloner_calculator.cc: Repeats last-seen packets from empty inputs. +class PacketInnerJoinCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + int num_streams_; +}; + +REGISTER_CALCULATOR(PacketInnerJoinCalculator); + +::mediapipe::Status PacketInnerJoinCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(cc->Inputs().NumEntries() == cc->Outputs().NumEntries()) + << "The number of input and output streams must match."; + const int num_streams = cc->Inputs().NumEntries(); + for (int i = 0; i < num_streams; ++i) { + cc->Inputs().Index(i).SetAny(); + cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(i)); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status PacketInnerJoinCalculator::Open(CalculatorContext* cc) { + num_streams_ = cc->Inputs().NumEntries(); + cc->SetOffset(TimestampDiff(0)); + return mediapipe::OkStatus(); +} + +::mediapipe::Status PacketInnerJoinCalculator::Process(CalculatorContext* cc) { + for (int i = 0; i < num_streams_; ++i) { + if (cc->Inputs().Index(i).Value().IsEmpty()) { + return ::mediapipe::OkStatus(); + } + } + for (int i = 0; i < num_streams_; ++i) { + cc->Outputs().Index(i).AddPacket(cc->Inputs().Index(i).Value()); + } + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/packet_inner_join_calculator_test.cc b/mediapipe/calculators/core/packet_inner_join_calculator_test.cc new file mode 100644 index 000000000..8eed9a3e6 --- /dev/null +++ b/mediapipe/calculators/core/packet_inner_join_calculator_test.cc @@ -0,0 +1,101 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/validate_type.h" + +namespace mediapipe { +namespace { + +inline Packet PacketFrom(int i) { return Adopt(new int(i)).At(Timestamp(i)); } + +TEST(PacketInnerJoinCalculatorTest, AllMatching) { + // Test case. 
+ const std::vector packets_on_stream1 = {0, 1, 2, 3}; + const std::vector packets_on_stream2 = {0, 1, 2, 3}; + // Run. + CalculatorRunner runner("PacketInnerJoinCalculator", "", 2, 2, 0); + for (int packet_load : packets_on_stream1) { + runner.MutableInputs()->Index(0).packets.push_back(PacketFrom(packet_load)); + } + for (int packet_load : packets_on_stream2) { + runner.MutableInputs()->Index(1).packets.push_back(PacketFrom(packet_load)); + } + MEDIAPIPE_ASSERT_OK(runner.Run()); + // Check. + const std::vector expected = {0, 1, 2, 3}; + ASSERT_EQ(expected.size(), runner.Outputs().Index(0).packets.size()); + ASSERT_EQ(expected.size(), runner.Outputs().Index(1).packets.size()); + for (int i = 0; i < expected.size(); ++i) { + const Packet packet1 = runner.Outputs().Index(0).packets[i]; + EXPECT_EQ(expected[i], packet1.Get()); + EXPECT_EQ(expected[i], packet1.Timestamp().Value()); + const Packet packet2 = runner.Outputs().Index(1).packets[i]; + EXPECT_EQ(expected[i], packet2.Get()); + EXPECT_EQ(expected[i], packet2.Timestamp().Value()); + } +} + +TEST(PacketInnerJoinCalculatorTest, NoneMatching) { + // Test case. + const std::vector packets_on_stream1 = {0, 2}; + const std::vector packets_on_stream2 = {1, 3}; + // Run. + CalculatorRunner runner("PacketInnerJoinCalculator", "", 2, 2, 0); + for (int packet_load : packets_on_stream1) { + runner.MutableInputs()->Index(0).packets.push_back(PacketFrom(packet_load)); + } + for (int packet_load : packets_on_stream2) { + runner.MutableInputs()->Index(1).packets.push_back(PacketFrom(packet_load)); + } + MEDIAPIPE_ASSERT_OK(runner.Run()); + // Check. + EXPECT_TRUE(runner.Outputs().Index(0).packets.empty()); + EXPECT_TRUE(runner.Outputs().Index(1).packets.empty()); +} + +TEST(PacketInnerJoinCalculatorTest, SomeMatching) { + // Test case. + const std::vector packets_on_stream1 = {0, 1, 2, 3, 4, 6}; + const std::vector packets_on_stream2 = {0, 2, 4, 5, 6}; + // Run. + CalculatorRunner runner("PacketInnerJoinCalculator", "", 2, 2, 0); + for (int packet_load : packets_on_stream1) { + runner.MutableInputs()->Index(0).packets.push_back(PacketFrom(packet_load)); + } + for (int packet_load : packets_on_stream2) { + runner.MutableInputs()->Index(1).packets.push_back(PacketFrom(packet_load)); + } + MEDIAPIPE_ASSERT_OK(runner.Run()); + // Check. + const std::vector expected = {0, 2, 4, 6}; + ASSERT_EQ(expected.size(), runner.Outputs().Index(0).packets.size()); + ASSERT_EQ(expected.size(), runner.Outputs().Index(1).packets.size()); + for (int i = 0; i < expected.size(); ++i) { + const Packet packet1 = runner.Outputs().Index(0).packets[i]; + EXPECT_EQ(expected[i], packet1.Get()); + EXPECT_EQ(expected[i], packet1.Timestamp().Value()); + const Packet packet2 = runner.Outputs().Index(1).packets[i]; + EXPECT_EQ(expected[i], packet2.Get()); + EXPECT_EQ(expected[i], packet2.Timestamp().Value()); + } +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/core/quantize_float_vector_calculator.cc b/mediapipe/calculators/core/quantize_float_vector_calculator.cc new file mode 100644 index 000000000..76e635e5b --- /dev/null +++ b/mediapipe/calculators/core/quantize_float_vector_calculator.cc @@ -0,0 +1,102 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "mediapipe/calculators/core/quantize_float_vector_calculator.pb.h" +#include "mediapipe/framework/calculator_context.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/status.h" + +// Quantizes a vector of floats to a std::string so that each float becomes a +// byte in the [0, 255] range. Any value above max_quantized_value or below +// min_quantized_value will be saturated to '/xFF' or '/0'. +// +// Example config: +// node { +// calculator: "QuantizeFloatVectorCalculator" +// input_stream: "FLOAT_VECTOR:float_vector" +// output_stream: "ENCODED:encoded" +// options { +// [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: { +// max_quantized_value: 64 +// min_quantized_value: -64 +// } +// } +// } +namespace mediapipe { + +class QuantizeFloatVectorCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Tag("FLOAT_VECTOR").Set>(); + cc->Outputs().Tag("ENCODED").Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + const auto options = + cc->Options<::mediapipe::QuantizeFloatVectorCalculatorOptions>(); + if (!options.has_max_quantized_value() || + !options.has_min_quantized_value()) { + return ::mediapipe::InvalidArgumentError( + "Both max_quantized_value and min_quantized_value must be provided " + "in QuantizeFloatVectorCalculatorOptions."); + } + max_quantized_value_ = options.max_quantized_value(); + min_quantized_value_ = options.min_quantized_value(); + if (max_quantized_value_ < min_quantized_value_ + FLT_EPSILON) { + return ::mediapipe::InvalidArgumentError( + "max_quantized_value must be greater than min_quantized_value."); + } + range_ = max_quantized_value_ - min_quantized_value_; + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + const std::vector& float_vector = + cc->Inputs().Tag("FLOAT_VECTOR").Value().Get>(); + int feature_size = float_vector.size(); + std::string encoded_features; + encoded_features.reserve(feature_size); + for (int i = 0; i < feature_size; i++) { + float old_value = float_vector[i]; + if (old_value < min_quantized_value_) { + old_value = min_quantized_value_; + } + if (old_value > max_quantized_value_) { + old_value = max_quantized_value_; + } + unsigned char encoded = static_cast( + (old_value - min_quantized_value_) * (255.0 / range_)); + encoded_features += encoded; + } + cc->Outputs().Tag("ENCODED").AddPacket( + MakePacket(encoded_features).At(cc->InputTimestamp())); + return ::mediapipe::OkStatus(); + } + + private: + float max_quantized_value_; + float min_quantized_value_; + float range_; +}; + +REGISTER_CALCULATOR(QuantizeFloatVectorCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/quantize_float_vector_calculator.proto b/mediapipe/calculators/core/quantize_float_vector_calculator.proto new file mode 100644 index 000000000..3f6cfda21 --- /dev/null +++ 
b/mediapipe/calculators/core/quantize_float_vector_calculator.proto @@ -0,0 +1,28 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message QuantizeFloatVectorCalculatorOptions { + extend CalculatorOptions { + optional QuantizeFloatVectorCalculatorOptions ext = 259848061; + } + + optional float max_quantized_value = 1; + optional float min_quantized_value = 2; +} diff --git a/mediapipe/calculators/core/quantize_float_vector_calculator_test.cc b/mediapipe/calculators/core/quantize_float_vector_calculator_test.cc new file mode 100644 index 000000000..c5566297e --- /dev/null +++ b/mediapipe/calculators/core/quantize_float_vector_calculator_test.cc @@ -0,0 +1,204 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_matchers.h" // NOLINT + +namespace mediapipe { + +TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) { + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie(R"( + calculator: "QuantizeFloatVectorCalculator" + input_stream: "FLOAT_VECTOR:float_vector" + output_stream: "ENCODED:encoded" + options { + [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: { + min_quantized_value: 1 + } + } + )"); + CalculatorRunner runner(node_config); + std::vector empty_vector; + runner.MutableInputs() + ->Tag("FLOAT_VECTOR") + .packets.push_back( + MakePacket>(empty_vector).At(Timestamp(0))); + auto status = runner.Run(); + EXPECT_FALSE(status.ok()); + EXPECT_THAT( + status.message(), + testing::HasSubstr( + "Both max_quantized_value and min_quantized_value must be provided")); +} + +TEST(QuantizeFloatVectorCalculatorTest, WrongConfig2) { + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie(R"( + calculator: "QuantizeFloatVectorCalculator" + input_stream: "FLOAT_VECTOR:float_vector" + output_stream: "ENCODED:encoded" + options { + [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: { + max_quantized_value: -1 + min_quantized_value: 1 + } + } + )"); + CalculatorRunner runner(node_config); + std::vector empty_vector; + runner.MutableInputs() + ->Tag("FLOAT_VECTOR") + .packets.push_back( + MakePacket>(empty_vector).At(Timestamp(0))); + auto status = runner.Run(); + EXPECT_FALSE(status.ok()); + EXPECT_THAT( + status.message(), + testing::HasSubstr( + "max_quantized_value must be greater than min_quantized_value")); +} + +TEST(QuantizeFloatVectorCalculatorTest, WrongConfig3) { + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie(R"( + calculator: "QuantizeFloatVectorCalculator" + input_stream: "FLOAT_VECTOR:float_vector" + output_stream: "ENCODED:encoded" + options { + [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: { + max_quantized_value: 1 + min_quantized_value: 1 + } + } + )"); + CalculatorRunner runner(node_config); + std::vector empty_vector; + runner.MutableInputs() + ->Tag("FLOAT_VECTOR") + .packets.push_back( + MakePacket>(empty_vector).At(Timestamp(0))); + auto status = runner.Run(); + EXPECT_FALSE(status.ok()); + EXPECT_THAT( + status.message(), + testing::HasSubstr( + "max_quantized_value must be greater than min_quantized_value")); +} + +TEST(QuantizeFloatVectorCalculatorTest, TestEmptyVector) { + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie(R"( + calculator: "QuantizeFloatVectorCalculator" + input_stream: "FLOAT_VECTOR:float_vector" + output_stream: "ENCODED:encoded" + options { + [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: { + max_quantized_value: 1 + min_quantized_value: -1 + } + } + )"); + CalculatorRunner runner(node_config); + std::vector empty_vector; + runner.MutableInputs() + ->Tag("FLOAT_VECTOR") + .packets.push_back( + MakePacket>(empty_vector).At(Timestamp(0))); + MEDIAPIPE_ASSERT_OK(runner.Run()); + const std::vector& outputs = runner.Outputs().Tag("ENCODED").packets; + EXPECT_EQ(1, outputs.size()); + EXPECT_TRUE(outputs[0].Get().empty()); + EXPECT_EQ(Timestamp(0), outputs[0].Timestamp()); +} + +TEST(QuantizeFloatVectorCalculatorTest, 
TestNonEmptyVector) { + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie(R"( + calculator: "QuantizeFloatVectorCalculator" + input_stream: "FLOAT_VECTOR:float_vector" + output_stream: "ENCODED:encoded" + options { + [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: { + max_quantized_value: 64 + min_quantized_value: -64 + } + } + )"); + CalculatorRunner runner(node_config); + std::vector vector = {0.0f, -64.0f, 64.0f, -32.0f, 32.0f}; + runner.MutableInputs() + ->Tag("FLOAT_VECTOR") + .packets.push_back( + MakePacket>(vector).At(Timestamp(0))); + MEDIAPIPE_ASSERT_OK(runner.Run()); + const std::vector& outputs = runner.Outputs().Tag("ENCODED").packets; + EXPECT_EQ(1, outputs.size()); + const std::string& result = outputs[0].Get(); + ASSERT_FALSE(result.empty()); + EXPECT_EQ(5, result.size()); + // 127 + EXPECT_EQ('\x7F', result.c_str()[0]); + // 0 + EXPECT_EQ('\0', result.c_str()[1]); + // 255 + EXPECT_EQ('\xFF', result.c_str()[2]); + // 63 + EXPECT_EQ('\x3F', result.c_str()[3]); + // 191 + EXPECT_EQ('\xBF', result.c_str()[4]); + EXPECT_EQ(Timestamp(0), outputs[0].Timestamp()); +} + +TEST(QuantizeFloatVectorCalculatorTest, TestSaturation) { + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie(R"( + calculator: "QuantizeFloatVectorCalculator" + input_stream: "FLOAT_VECTOR:float_vector" + output_stream: "ENCODED:encoded" + options { + [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: { + max_quantized_value: 64 + min_quantized_value: -64 + } + } + )"); + CalculatorRunner runner(node_config); + std::vector vector = {-65.0f, 65.0f}; + runner.MutableInputs() + ->Tag("FLOAT_VECTOR") + .packets.push_back( + MakePacket>(vector).At(Timestamp(0))); + MEDIAPIPE_ASSERT_OK(runner.Run()); + const std::vector& outputs = runner.Outputs().Tag("ENCODED").packets; + EXPECT_EQ(1, outputs.size()); + const std::string& result = outputs[0].Get(); + ASSERT_FALSE(result.empty()); + EXPECT_EQ(2, result.size()); + // 0 + EXPECT_EQ('\0', result.c_str()[0]); + // 255 + EXPECT_EQ('\xFF', result.c_str()[1]); + EXPECT_EQ(Timestamp(0), outputs[0].Timestamp()); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/sequence_shift_calculator.cc b/mediapipe/calculators/core/sequence_shift_calculator.cc new file mode 100644 index 000000000..f2ab11025 --- /dev/null +++ b/mediapipe/calculators/core/sequence_shift_calculator.cc @@ -0,0 +1,114 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mediapipe/calculators/core/sequence_shift_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" + +namespace mediapipe { + +// A Calculator that shifts the timestamps of packets along a stream. Packets on +// the input stream are output with a timestamp of the packet given by packet +// offset. That is, packet[i] is output with the timestamp of +// packet[i + packet_offset]. Packet offset can be either positive or negative. +// If packet_offset is -n, the first n packets will be dropped. 
If packet offset +// is n, the final n packets will be dropped. For example, with a packet_offset +// of -1, the first packet on the stream will be dropped, the second will be +// output with the timestamp of the first, the third with the timestamp of the +// second, and so on. +class SequenceShiftCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); + return ::mediapipe::OkStatus(); + } + + // Reads from options to set cache_size_ and packet_offset_. + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + // A positive offset means we want a packet to be output with the timestamp of + // a later packet. Stores packets waiting for their output timestamps and + // outputs a single packet when the cache fills. + void ProcessPositiveOffset(CalculatorContext* cc); + + // A negative offset means we want a packet to be output with the timestamp of + // an earlier packet. Stores timestamps waiting for the corresponding input + // packet and outputs a single packet when the cache fills. + void ProcessNegativeOffset(CalculatorContext* cc); + + // Storage for packets waiting to be output when packet_offset > 0. When cache + // is full, oldest packet is output with current timestamp. + std::deque packet_cache_; + + // Storage for previous timestamps used when packet_offset < 0. When cache is + // full, oldest timestamp is used for current packet. + std::deque timestamp_cache_; + + // Copied from corresponding field in options. + int packet_offset_; + // The number of packets or timestamps we need to store to output packet[i] at + // the timestamp of packet[i + packet_offset]; equal to abs(packet_offset). + int cache_size_; +}; +REGISTER_CALCULATOR(SequenceShiftCalculator); + +::mediapipe::Status SequenceShiftCalculator::Open(CalculatorContext* cc) { + packet_offset_ = + cc->Options().packet_offset(); + cache_size_ = abs(packet_offset_); + // An offset of zero is a no-op, but someone might still request it. + if (packet_offset_ == 0) { + cc->Outputs().Index(0).SetOffset(0); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status SequenceShiftCalculator::Process(CalculatorContext* cc) { + if (packet_offset_ > 0) { + ProcessPositiveOffset(cc); + } else if (packet_offset_ < 0) { + ProcessNegativeOffset(cc); + } else { + cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); + } + return ::mediapipe::OkStatus(); +} + +void SequenceShiftCalculator::ProcessPositiveOffset(CalculatorContext* cc) { + if (packet_cache_.size() >= cache_size_) { + // Ready to output oldest packet with current timestamp. + cc->Outputs().Index(0).AddPacket( + packet_cache_.front().At(cc->InputTimestamp())); + packet_cache_.pop_front(); + } + // Store current packet for later output. + packet_cache_.push_back(cc->Inputs().Index(0).Value()); +} + +void SequenceShiftCalculator::ProcessNegativeOffset(CalculatorContext* cc) { + if (timestamp_cache_.size() >= cache_size_) { + // Ready to output current packet with oldest timestamp. + cc->Outputs().Index(0).AddPacket( + cc->Inputs().Index(0).Value().At(timestamp_cache_.front())); + timestamp_cache_.pop_front(); + } + // Store current timestamp for use by a future packet. 
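+  // Once cache_size_ timestamps are buffered, each incoming packet is emitted
+  // with the timestamp recorded cache_size_ packets earlier.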
+ timestamp_cache_.push_back(cc->InputTimestamp()); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/sequence_shift_calculator.proto b/mediapipe/calculators/core/sequence_shift_calculator.proto new file mode 100644 index 000000000..15b111d71 --- /dev/null +++ b/mediapipe/calculators/core/sequence_shift_calculator.proto @@ -0,0 +1,26 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message SequenceShiftCalculatorOptions { + extend CalculatorOptions { + optional SequenceShiftCalculatorOptions ext = 107633927; + } + optional int32 packet_offset = 1 [default = -1]; +} diff --git a/mediapipe/calculators/core/sequence_shift_calculator_test.cc b/mediapipe/calculators/core/sequence_shift_calculator_test.cc new file mode 100644 index 000000000..0466fe3b1 --- /dev/null +++ b/mediapipe/calculators/core/sequence_shift_calculator_test.cc @@ -0,0 +1,104 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/validate_type.h" + +namespace mediapipe { + +namespace { + +// Adds packets containing integers equal to their original timestamp. +void AddPackets(CalculatorRunner* runner) { + for (int i = 0; i < 10; ++i) { + runner->MutableInputs()->Index(0).packets.push_back( + Adopt(new int(i)).At(Timestamp(i))); + } +} + +// Zero shift is a no-op (output input[i] at timestamp[i]). Input and output +// streams should be identical. +TEST(SequenceShiftCalculatorTest, ZeroShift) { + CalculatorRunner runner( + "SequenceShiftCalculator", + "[mediapipe.SequenceShiftCalculatorOptions.ext]: { packet_offset: 0 }", 1, + 1, 0); + AddPackets(&runner); + MEDIAPIPE_ASSERT_OK(runner.Run()); + const std::vector& input_packets = + runner.MutableInputs()->Index(0).packets; + const std::vector& output_packets = runner.Outputs().Index(0).packets; + ASSERT_EQ(10, input_packets.size()); + ASSERT_EQ(input_packets.size(), output_packets.size()); + for (int i = 0; i < output_packets.size(); ++i) { + // Make sure the contents are as expected. 
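+    // With a zero offset, both the payload and the timestamp must be
+    // identical to the input.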
+ EXPECT_EQ(input_packets[i].Get(), output_packets[i].Get()); + EXPECT_EQ(input_packets[i].Timestamp(), output_packets[i].Timestamp()); + } +} + +// Tests shifting by three packets, i.e., output input[i] with the timestamp of +// input[i + 3]. +TEST(SequenceShiftCalculatorTest, PositiveShift) { + CalculatorRunner runner( + "SequenceShiftCalculator", + "[mediapipe.SequenceShiftCalculatorOptions.ext]: { packet_offset: 3 }", 1, + 1, 0); + AddPackets(&runner); + MEDIAPIPE_ASSERT_OK(runner.Run()); + const std::vector& input_packets = + runner.MutableInputs()->Index(0).packets; + const std::vector& output_packets = runner.Outputs().Index(0).packets; + ASSERT_EQ(10, input_packets.size()); + // input_packet[i] should be output with the timestamp of input_packet[i + 3]. + // The last 3 packets are dropped. + ASSERT_EQ(7, output_packets.size()); + for (int i = 0; i < output_packets.size(); ++i) { + // Make sure the contents are as expected. + EXPECT_EQ(input_packets[i].Get(), output_packets[i].Get()); + // Make sure the timestamps are shifted as expected. + EXPECT_EQ(input_packets[i + 3].Timestamp(), output_packets[i].Timestamp()); + } +} + +// Tests shifting by -2, i.e., output input[i] with timestamp[i - 2]. The first +// two packets should be dropped. +TEST(SequenceShiftCalculatorTest, NegativeShift) { + CalculatorRunner runner( + "SequenceShiftCalculator", + "[mediapipe.SequenceShiftCalculatorOptions.ext]: { packet_offset: -2 }", + 1, 1, 0); + AddPackets(&runner); + MEDIAPIPE_ASSERT_OK(runner.Run()); + const std::vector& input_packets = + runner.MutableInputs()->Index(0).packets; + const std::vector& output_packets = runner.Outputs().Index(0).packets; + ASSERT_EQ(10, input_packets.size()); + // Input packet[i] should be output with the timestamp of input packet[i - 2]. + // The first two packets are dropped. This means timestamps match between + // input and output packets, but the data in the output packets come from + // input_packets[i + 2]. + ASSERT_EQ(8, output_packets.size()); + for (int i = 0; i < output_packets.size(); ++i) { + EXPECT_EQ(input_packets[i].Timestamp(), output_packets[i].Timestamp()); + EXPECT_EQ(input_packets[i + 2].Get(), output_packets[i].Get()); + } +} + +} // namespace + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/split_vector_calculator.cc b/mediapipe/calculators/core/split_vector_calculator.cc new file mode 100644 index 000000000..2e18a570e --- /dev/null +++ b/mediapipe/calculators/core/split_vector_calculator.cc @@ -0,0 +1,40 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/calculators/core/split_vector_calculator.h" + +#include + +#include "tensorflow/lite/interpreter.h" + +namespace mediapipe { + +// Example config: +// node { +// calculator: "SplitTfLiteTensorVectorCalculator" +// input_stream: "tflitetensor_vector" +// output_stream: "tflitetensor_vector_range_0" +// output_stream: "tflitetensor_vector_range_1" +// options { +// [mediapipe.SplitVectorCalculatorOptions.ext] { +// ranges: { begin: 0 end: 1 } +// ranges: { begin: 1 end: 2 } +// element_only: false +// } +// } +// } +typedef SplitVectorCalculator SplitTfLiteTensorVectorCalculator; +REGISTER_CALCULATOR(SplitTfLiteTensorVectorCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/core/split_vector_calculator.h b/mediapipe/calculators/core/split_vector_calculator.h new file mode 100644 index 000000000..def156474 --- /dev/null +++ b/mediapipe/calculators/core/split_vector_calculator.h @@ -0,0 +1,125 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_CORE_SPLIT_VECTOR_CALCULATOR_H_ +#define MEDIAPIPE_CALCULATORS_CORE_SPLIT_VECTOR_CALCULATOR_H_ + +#include + +#include "mediapipe/calculators/core/split_vector_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/util/resource_util.h" +#include "tensorflow/lite/error_reporter.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" + +namespace mediapipe { + +// Splits an input packet with std::vector into multiple std::vector +// output packets using the [begin, end) ranges specified in +// SplitVectorCalculatorOptions. If the option "element_only" is set to true, +// all ranges should be of size 1 and all outputs will be elements of type T. If +// "element_only" is false, ranges can be non-zero in size and all outputs will +// be of type std::vector. +// To use this class for a particular type T, register a calculator using +// SplitVectorCalculator. +template +class SplitVectorCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + RET_CHECK(cc->Inputs().NumEntries() == 1); + RET_CHECK(cc->Outputs().NumEntries() != 0); + + cc->Inputs().Index(0).Set>(); + + const auto& options = + cc->Options<::mediapipe::SplitVectorCalculatorOptions>(); + + if (cc->Outputs().NumEntries() != options.ranges_size()) { + return ::mediapipe::InvalidArgumentError( + "The number of output streams should match the number of ranges " + "specified in the CalculatorOptions."); + } + + // Set the output types for each output stream. 
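+    // The ranges are also validated here so that misconfigurations are
+    // rejected at graph initialization time rather than at runtime.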
+ for (int i = 0; i < cc->Outputs().NumEntries(); ++i) { + if (options.ranges(i).begin() < 0 || options.ranges(i).end() < 0 || + options.ranges(i).begin() >= options.ranges(i).end()) { + return ::mediapipe::InvalidArgumentError( + "Indices should be non-negative and begin index should be less " + "than the end index."); + } + if (options.element_only()) { + if (options.ranges(i).end() - options.ranges(i).begin() != 1) { + return ::mediapipe::InvalidArgumentError( + "Since element_only is true, all ranges should be of size 1."); + } + cc->Outputs().Index(i).Set(); + } else { + cc->Outputs().Index(i).Set>(); + } + } + + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + cc->SetOffset(TimestampDiff(0)); + + const auto& options = + cc->Options<::mediapipe::SplitVectorCalculatorOptions>(); + + for (const auto& range : options.ranges()) { + ranges_.push_back({range.begin(), range.end()}); + max_range_end_ = std::max(max_range_end_, range.end()); + } + + element_only_ = options.element_only(); + + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + const auto& input = cc->Inputs().Index(0).Get>(); + RET_CHECK_GE(input.size(), max_range_end_); + + if (element_only_) { + for (int i = 0; i < ranges_.size(); ++i) { + cc->Outputs().Index(i).AddPacket( + MakePacket(input[ranges_[i].first]).At(cc->InputTimestamp())); + } + } else { + for (int i = 0; i < ranges_.size(); ++i) { + auto output = absl::make_unique>( + input.begin() + ranges_[i].first, + input.begin() + ranges_[i].second); + cc->Outputs().Index(i).Add(output.release(), cc->InputTimestamp()); + } + } + + return ::mediapipe::OkStatus(); + } + + private: + std::vector> ranges_; + int32 max_range_end_ = -1; + bool element_only_ = false; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_CORE_SPLIT_VECTOR_CALCULATOR_H_ diff --git a/mediapipe/calculators/core/split_vector_calculator.proto b/mediapipe/calculators/core/split_vector_calculator.proto new file mode 100644 index 000000000..3ef31475b --- /dev/null +++ b/mediapipe/calculators/core/split_vector_calculator.proto @@ -0,0 +1,40 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +// A Range {begin, end} specifies beginning ane ending indices to splice a +// vector. A vector v is spliced to have elements v[begin:(end-1)], i.e., with +// begin index inclusive and end index exclusive. +message Range { + optional int32 begin = 1; + optional int32 end = 2; +} + +message SplitVectorCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional SplitVectorCalculatorOptions ext = 259438222; + } + + repeated Range ranges = 1; + + // Specify if single element ranges should be outputted as std::vector or + // just element of type T. By default, if a range specifies only one element, + // it is outputted as an std::vector. 
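+  // When set to true, every range must contain exactly one element; otherwise
+  // graph initialization fails.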
+ optional bool element_only = 2 [default = false]; +} diff --git a/mediapipe/calculators/core/split_vector_calculator_test.cc b/mediapipe/calculators/core/split_vector_calculator_test.cc new file mode 100644 index 000000000..939783f3e --- /dev/null +++ b/mediapipe/calculators/core/split_vector_calculator_test.cc @@ -0,0 +1,321 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/core/split_vector_calculator.h" + +#include +#include +#include + +#include "mediapipe/calculators/core/split_vector_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" // NOLINT +#include "mediapipe/framework/tool/validate_type.h" +#include "tensorflow/lite/error_reporter.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" + +namespace mediapipe { + +using ::tflite::Interpreter; + +const int width = 1; +const int height = 1; +const int channels = 1; + +class SplitTfLiteTensorVectorCalculatorTest : public ::testing::Test { + protected: + void TearDown() { + // Note: Since the pointers contained in this vector will be cleaned up by + // the interpreter, only ensure that the vector is cleaned up for the next + // test. + input_buffers_.clear(); + } + + void PrepareTfLiteTensorVector(int vector_size) { + ASSERT_NE(interpreter_, nullptr); + + // Prepare input tensors. + std::vector indices(vector_size); + for (int i = 0; i < vector_size; ++i) { + indices[i] = i; + } + interpreter_->AddTensors(vector_size); + interpreter_->SetInputs(indices); + + input_vec_ = absl::make_unique>(); + for (int i = 0; i < vector_size; ++i) { + interpreter_->SetTensorParametersReadWrite(i, kTfLiteFloat32, "", {3}, + TfLiteQuantization()); + const int tensor_index = interpreter_->inputs()[i]; + interpreter_->ResizeInputTensor(tensor_index, {width, height, channels}); + } + + interpreter_->AllocateTensors(); + + // Save the tensor buffer pointers for comparison after the graph runs. 
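+    // The calculator copies TfLiteTensor structs but not the underlying
+    // buffers, so the output tensors are expected to alias these pointers.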
+ input_buffers_ = std::vector(vector_size); + for (int i = 0; i < vector_size; ++i) { + const int tensor_index = interpreter_->inputs()[i]; + TfLiteTensor* tensor = interpreter_->tensor(tensor_index); + float* tensor_buffer = tensor->data.f; + ASSERT_NE(tensor_buffer, nullptr); + for (int j = 0; j < width * height * channels; ++j) { + tensor_buffer[j] = i; + } + input_vec_->push_back(*tensor); + input_buffers_[i] = tensor_buffer; + } + } + + void ValidateVectorOutput(std::vector& output_packets, + int expected_elements, int input_begin_index) { + ASSERT_EQ(1, output_packets.size()); + const std::vector& output_vec = + output_packets[0].Get>(); + ASSERT_EQ(expected_elements, output_vec.size()); + + for (int i = 0; i < expected_elements; ++i) { + const int expected_value = input_begin_index + i; + const TfLiteTensor* result = &output_vec[i]; + float* result_buffer = result->data.f; + ASSERT_NE(result_buffer, nullptr); + ASSERT_EQ(result_buffer, input_buffers_[input_begin_index + i]); + for (int j = 0; j < width * height * channels; ++j) { + ASSERT_EQ(expected_value, result_buffer[j]); + } + } + } + + void ValidateElementOutput(std::vector& output_packets, + int input_begin_index) { + ASSERT_EQ(1, output_packets.size()); + + const TfLiteTensor& result = output_packets[0].Get(); + float* result_buffer = result.data.f; + ASSERT_NE(result_buffer, nullptr); + ASSERT_EQ(result_buffer, input_buffers_[input_begin_index]); + + const int expected_value = input_begin_index; + for (int j = 0; j < width * height * channels; ++j) { + ASSERT_EQ(expected_value, result_buffer[j]); + } + } + + std::unique_ptr interpreter_ = absl::make_unique(); + std::unique_ptr> input_vec_ = nullptr; + std::vector input_buffers_; + std::unique_ptr runner_ = nullptr; +}; + +TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTest) { + ASSERT_NE(interpreter_, nullptr); + + PrepareTfLiteTensorVector(/*vector_size=*/5); + ASSERT_NE(input_vec_, nullptr); + + // Prepare a graph to use the SplitTfLiteTensorVectorCalculator. + CalculatorGraphConfig graph_config = + ::mediapipe::ParseTextProtoOrDie( + R"( + input_stream: "tensor_in" + node { + calculator: "SplitTfLiteTensorVectorCalculator" + input_stream: "tensor_in" + output_stream: "range_0" + output_stream: "range_1" + output_stream: "range_2" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges: { begin: 0 end: 1 } + ranges: { begin: 1 end: 4 } + ranges: { begin: 4 end: 5 } + } + } + } + )"); + std::vector range_0_packets; + tool::AddVectorSink("range_0", &graph_config, &range_0_packets); + std::vector range_1_packets; + tool::AddVectorSink("range_1", &graph_config, &range_1_packets); + std::vector range_2_packets; + tool::AddVectorSink("range_2", &graph_config, &range_2_packets); + + // Run the graph. + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(graph_config)); + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream( + "tensor_in", Adopt(input_vec_.release()).At(Timestamp(0)))); + // Wait until the calculator finishes processing. + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + + ValidateVectorOutput(range_0_packets, /*expected_elements=*/1, + /*input_begin_index=*/0); + ValidateVectorOutput(range_1_packets, /*expected_elements=*/3, + /*input_begin_index=*/1); + ValidateVectorOutput(range_2_packets, /*expected_elements=*/1, + /*input_begin_index=*/4); + + // Fully close the graph at the end. 
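+  // Closing the input stream lets WaitUntilDone() return once all pending
+  // packets have been processed.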
+ MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("tensor_in")); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); +} + +TEST_F(SplitTfLiteTensorVectorCalculatorTest, InvalidRangeTest) { + ASSERT_NE(interpreter_, nullptr); + + // Prepare a graph to use the SplitTfLiteTensorVectorCalculator. + CalculatorGraphConfig graph_config = + ::mediapipe::ParseTextProtoOrDie( + R"( + input_stream: "tensor_in" + node { + calculator: "SplitTfLiteTensorVectorCalculator" + input_stream: "tensor_in" + output_stream: "range_0" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges: { begin: 0 end: 0 } + } + } + } + )"); + + // Run the graph. + CalculatorGraph graph; + // The graph should fail running because of an invalid range (begin == end). + ASSERT_FALSE(graph.Initialize(graph_config).ok()); +} + +TEST_F(SplitTfLiteTensorVectorCalculatorTest, InvalidOutputStreamCountTest) { + ASSERT_NE(interpreter_, nullptr); + + // Prepare a graph to use the SplitTfLiteTensorVectorCalculator. + CalculatorGraphConfig graph_config = + ::mediapipe::ParseTextProtoOrDie( + R"( + input_stream: "tensor_in" + node { + calculator: "SplitTfLiteTensorVectorCalculator" + input_stream: "tensor_in" + output_stream: "range_0" + output_stream: "range_1" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges: { begin: 0 end: 1 } + } + } + } + )"); + + // Run the graph. + CalculatorGraph graph; + // The graph should fail running because the number of output streams does not + // match the number of range elements in the options. + ASSERT_FALSE(graph.Initialize(graph_config).ok()); +} + +TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTestElementOnly) { + ASSERT_NE(interpreter_, nullptr); + + PrepareTfLiteTensorVector(/*vector_size=*/5); + ASSERT_NE(input_vec_, nullptr); + + // Prepare a graph to use the SplitTfLiteTensorVectorCalculator. + CalculatorGraphConfig graph_config = + ::mediapipe::ParseTextProtoOrDie( + R"( + input_stream: "tensor_in" + node { + calculator: "SplitTfLiteTensorVectorCalculator" + input_stream: "tensor_in" + output_stream: "range_0" + output_stream: "range_1" + output_stream: "range_2" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges: { begin: 0 end: 1 } + ranges: { begin: 2 end: 3 } + ranges: { begin: 4 end: 5 } + element_only: true + } + } + } + )"); + std::vector range_0_packets; + tool::AddVectorSink("range_0", &graph_config, &range_0_packets); + std::vector range_1_packets; + tool::AddVectorSink("range_1", &graph_config, &range_1_packets); + std::vector range_2_packets; + tool::AddVectorSink("range_2", &graph_config, &range_2_packets); + + // Run the graph. + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(graph_config)); + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream( + "tensor_in", Adopt(input_vec_.release()).At(Timestamp(0)))); + // Wait until the calculator finishes processing. + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + + ValidateElementOutput(range_0_packets, + /*input_begin_index=*/0); + ValidateElementOutput(range_1_packets, + /*input_begin_index=*/2); + ValidateElementOutput(range_2_packets, + /*input_begin_index=*/4); + + // Fully close the graph at the end. + MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("tensor_in")); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); +} + +TEST_F(SplitTfLiteTensorVectorCalculatorTest, + ElementOnlyDisablesVectorOutputs) { + // Prepare a graph to use the SplitTfLiteTensorVectorCalculator. 
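+  // The middle range below spans three elements, which is invalid when
+  // element_only is true, so graph initialization is expected to fail.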
+ CalculatorGraphConfig graph_config = + ::mediapipe::ParseTextProtoOrDie( + R"( + input_stream: "tensor_in" + node { + calculator: "SplitTfLiteTensorVectorCalculator" + input_stream: "tensor_in" + output_stream: "range_0" + output_stream: "range_1" + output_stream: "range_2" + options { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges: { begin: 0 end: 1 } + ranges: { begin: 1 end: 4 } + ranges: { begin: 4 end: 5 } + element_only: true + } + } + } + )"); + + // Run the graph. + CalculatorGraph graph; + ASSERT_FALSE(graph.Initialize(graph_config).ok()); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 2221c698d..070150c7f 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -19,6 +19,7 @@ package(default_visibility = ["//visibility:private"]) exports_files(["LICENSE"]) load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("@bazel_skylib//lib:selects.bzl", "selects") proto_library( name = "opencv_image_encoder_calculator_proto", @@ -46,6 +47,26 @@ proto_library( ], ) +proto_library( + name = "image_cropping_calculator_proto", + srcs = ["image_cropping_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + ], +) + +proto_library( + name = "bilateral_filter_calculator_proto", + srcs = ["bilateral_filter_calculator.proto"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/framework:calculator_proto", + ], +) + proto_library( name = "recolor_calculator_proto", srcs = ["recolor_calculator.proto"], @@ -93,6 +114,26 @@ mediapipe_cc_proto_library( deps = [":set_alpha_calculator_proto"], ) +mediapipe_cc_proto_library( + name = "image_cropping_calculator_cc_proto", + srcs = ["image_cropping_calculator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":image_cropping_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "bilateral_filter_calculator_cc_proto", + srcs = ["bilateral_filter_calculator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":bilateral_filter_calculator_proto"], +) + mediapipe_cc_proto_library( name = "recolor_calculator_cc_proto", srcs = ["recolor_calculator.proto"], @@ -185,6 +226,42 @@ cc_library( "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:status", "//mediapipe/framework/port:vector", + ] + select({ + "//mediapipe:android": [ + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gl_simple_shaders", + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:shader_util", + ], + "//mediapipe:ios": [ + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gl_simple_shaders", + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:shader_util", + ], + "//conditions:default": [], + }), + alwayslink = 1, +) + +cc_library( + name = "bilateral_filter_calculator", + srcs = ["bilateral_filter_calculator.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":bilateral_filter_calculator_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:opencv_core", + 
"//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:vector", ] + select({ "//mediapipe:android": [ "//mediapipe/gpu:gl_calculator_helper", @@ -221,6 +298,19 @@ mediapipe_cc_proto_library( cc_library( name = "image_transformation_calculator", srcs = ["image_transformation_calculator.cc"], + copts = select({ + "//mediapipe:apple": [ + "-x objective-c++", + ], + "//conditions:default": [], + }), + linkopts = select({ + "//mediapipe:apple": [ + "-framework CoreVideo", + "-framework MetalKit", + ], + "//conditions:default": [], + }), visibility = ["//visibility:public"], deps = [ ":image_transformation_calculator_cc_proto", @@ -232,8 +322,8 @@ cc_library( "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - ] + select({ - "//mediapipe:android": [ + ] + selects.with_or({ + ("//mediapipe:android", "//mediapipe:ios"): [ "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_simple_shaders", "//mediapipe/gpu:gl_quad_renderer", @@ -247,8 +337,24 @@ cc_library( cc_library( name = "image_cropping_calculator", srcs = ["image_cropping_calculator.cc"], - visibility = ["//visibility:public"], + copts = select({ + "//mediapipe:apple": [ + "-x objective-c++", + ], + "//conditions:default": [], + }), + linkopts = select({ + "//mediapipe:apple": [ + "-framework CoreVideo", + "-framework MetalKit", + ], + "//conditions:default": [], + }), + visibility = [ + "//visibility:public", + ], deps = [ + ":image_cropping_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", @@ -257,7 +363,15 @@ cc_library( "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - ], + ] + selects.with_or({ + ("//mediapipe:android", "//mediapipe:ios"): [ + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gl_simple_shaders", + "//mediapipe/gpu:gl_quad_renderer", + "//mediapipe/gpu:shader_util", + ], + "//conditions:default": [], + }), alwayslink = 1, ) @@ -307,6 +421,12 @@ cc_library( "//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:shader_util", ], + "//mediapipe:ios": [ + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gl_simple_shaders", + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:shader_util", + ], "//conditions:default": [], }), alwayslink = 1, @@ -357,6 +477,24 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "image_properties_calculator", + srcs = ["image_properties_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ] + selects.with_or({ + ("//mediapipe:android", "//mediapipe:ios"): [ + "//mediapipe/gpu:gpu_buffer", + ], + "//conditions:default": [], + }), + alwayslink = 1, +) + cc_test( name = "opencv_encoded_image_to_image_frame_calculator_test", srcs = ["opencv_encoded_image_to_image_frame_calculator_test.cc"], diff --git a/mediapipe/calculators/image/bilateral_filter_calculator.cc b/mediapipe/calculators/image/bilateral_filter_calculator.cc new file mode 100644 index 000000000..b8b6a88f5 --- /dev/null +++ b/mediapipe/calculators/image/bilateral_filter_calculator.cc @@ -0,0 +1,553 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "mediapipe/calculators/image/bilateral_filter_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_options.pb.h" +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/vector.h" + +#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gl_simple_shaders.h" +#include "mediapipe/gpu/shader_util.h" +#endif // __ANDROID__ || __EMSCRIPTEN__ + +namespace mediapipe { + +namespace { +constexpr char kInputFrameTag[] = "IMAGE"; +constexpr char kInputGuideTag[] = "GUIDE"; +constexpr char kOutputFrameTag[] = "IMAGE"; + +constexpr char kInputFrameTagGpu[] = "IMAGE_GPU"; +constexpr char kInputGuideTagGpu[] = "GUIDE_GPU"; +constexpr char kOutputFrameTagGpu[] = "IMAGE_GPU"; + +enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; +} // namespace + +// A calculator for applying a bilateral filter to an image, +// with an optional guide image (joint blateral). +// +// Inputs: +// One of the following two IMAGE tags: +// IMAGE: ImageFrame containing input image - Grayscale or RGB only. +// IMAGE_GPU: GpuBuffer containing input image - Grayscale, RGB or RGBA. +// +// GUIDE (optional): ImageFrame guide image used to filter IMAGE. (N/A). +// GUIDE_GPU (optional): GpuBuffer guide image used to filter IMAGE_GPU. +// +// Output: +// One of the following two tags: +// IMAGE: A filtered ImageFrame - Same as input. +// IMAGE_GPU: A filtered GpuBuffer - RGBA +// +// Options: +// sigma_space: Pixel radius: use (sigma_space*2+1)x(sigma_space*2+1) window. +// This should be set based on output image pixel space. +// sigma_color: Color variance: normalized [0-1] color difference allowed. +// +// Notes: +// * When GUIDE is present, the output image is same size as GUIDE image; +// otherwise, the output image is same size as input image. +// * On GPU the kernel window is subsampled by approximately sqrt(sigma_space) +// i.e. the step size is ~sqrt(sigma_space), +// prioritizing performance > quality. +// * TODO: Add CPU path for joint filter. +// +class BilateralFilterCalculator : public CalculatorBase { + public: + BilateralFilterCalculator() = default; + ~BilateralFilterCalculator() override = default; + + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + // From Calculator. 
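+  // Open reads the sigma options; Process dispatches to the GPU or CPU
+  // rendering path.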
+ ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + ::mediapipe::Status Close(CalculatorContext* cc) override; + + private: + ::mediapipe::Status RenderGpu(CalculatorContext* cc); + ::mediapipe::Status RenderCpu(CalculatorContext* cc); + + ::mediapipe::Status GlSetup(CalculatorContext* cc); + void GlRender(CalculatorContext* cc); + + mediapipe::BilateralFilterCalculatorOptions options_; + float sigma_color_ = -1.f; + float sigma_space_ = -1.f; + + bool use_gpu_ = false; + bool gpu_initialized_ = false; +#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) + mediapipe::GlCalculatorHelper gpu_helper_; + GLuint program_ = 0; + GLuint program_joint_ = 0; +#endif // __ANDROID__ || __EMSCRIPTEN__ +}; +REGISTER_CALCULATOR(BilateralFilterCalculator); + +::mediapipe::Status BilateralFilterCalculator::GetContract( + CalculatorContract* cc) { + CHECK_GE(cc->Inputs().NumEntries(), 1); + + if (cc->Inputs().HasTag(kInputFrameTag) && + cc->Inputs().HasTag(kInputFrameTagGpu)) { + return ::mediapipe::InternalError("Cannot have multiple input images."); + } + if (cc->Inputs().HasTag(kInputFrameTagGpu) != + cc->Outputs().HasTag(kOutputFrameTagGpu)) { + return ::mediapipe::InternalError("GPU output must have GPU input."); + } + + // Input image to filter. +#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) + if (cc->Inputs().HasTag(kInputFrameTagGpu)) { + cc->Inputs().Tag(kInputFrameTagGpu).Set(); + } +#endif // __ANDROID__ || __EMSCRIPTEN__ + if (cc->Inputs().HasTag(kInputFrameTag)) { + cc->Inputs().Tag(kInputFrameTag).Set(); + } + + // Input guide image mask (optional) + if (cc->Inputs().HasTag(kInputGuideTagGpu)) { +#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) + cc->Inputs().Tag(kInputGuideTagGpu).Set(); +#endif // __ANDROID__ || __EMSCRIPTEN__ + } + if (cc->Inputs().HasTag(kInputGuideTag)) { + cc->Inputs().Tag(kInputGuideTag).Set(); + } + + // Output image. 
+#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
+  if (cc->Outputs().HasTag(kOutputFrameTagGpu)) {
+    cc->Outputs().Tag(kOutputFrameTagGpu).Set<mediapipe::GpuBuffer>();
+  }
+#endif  // __ANDROID__ || __EMSCRIPTEN__
+  if (cc->Outputs().HasTag(kOutputFrameTag)) {
+    cc->Outputs().Tag(kOutputFrameTag).Set<ImageFrame>();
+  }
+
+#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
+  RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
+#endif  // __ANDROID__ || __EMSCRIPTEN__
+
+  return ::mediapipe::OkStatus();
+}
+
+::mediapipe::Status BilateralFilterCalculator::Open(CalculatorContext* cc) {
+  cc->SetOffset(TimestampDiff(0));
+
+  options_ = cc->Options<mediapipe::BilateralFilterCalculatorOptions>();
+
+  if (cc->Inputs().HasTag(kInputFrameTagGpu) &&
+      cc->Outputs().HasTag(kOutputFrameTagGpu)) {
+#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
+    use_gpu_ = true;
+#else
+    RET_CHECK_FAIL() << "GPU processing on non-Android not supported yet.";
+#endif  // __ANDROID__ || __EMSCRIPTEN__
+  }
+
+  sigma_color_ = options_.sigma_color();
+  sigma_space_ = options_.sigma_space();
+  CHECK_GE(sigma_color_, 0.0);
+  CHECK_GE(sigma_space_, 0.0);
+  if (!use_gpu_) sigma_color_ *= 255.0;
+
+  if (use_gpu_) {
+#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
+    RETURN_IF_ERROR(gpu_helper_.Open(cc));
+#endif
+  }
+
+  return ::mediapipe::OkStatus();
+}
+
+::mediapipe::Status BilateralFilterCalculator::Process(CalculatorContext* cc) {
+  if (use_gpu_) {
+#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
+    RETURN_IF_ERROR(
+        gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
+          if (!gpu_initialized_) {
+            RETURN_IF_ERROR(GlSetup(cc));
+            gpu_initialized_ = true;
+          }
+          RETURN_IF_ERROR(RenderGpu(cc));
+          return ::mediapipe::OkStatus();
+        }));
+#endif  // __ANDROID__ || __EMSCRIPTEN__
+  } else {
+    RETURN_IF_ERROR(RenderCpu(cc));
+  }
+
+  return ::mediapipe::OkStatus();
+}
+
+::mediapipe::Status BilateralFilterCalculator::Close(CalculatorContext* cc) {
+#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
+  gpu_helper_.RunInGlContext([this] {
+    if (program_) glDeleteProgram(program_);
+    program_ = 0;
+    if (program_joint_) glDeleteProgram(program_joint_);
+    program_joint_ = 0;
+  });
+#endif  // __ANDROID__ || __EMSCRIPTEN__
+
+  return ::mediapipe::OkStatus();
+}
+
+::mediapipe::Status BilateralFilterCalculator::RenderCpu(
+    CalculatorContext* cc) {
+  if (cc->Inputs().Tag(kInputFrameTag).IsEmpty()) {
+    return ::mediapipe::OkStatus();
+  }
+
+  const auto& input_frame = cc->Inputs().Tag(kInputFrameTag).Get<ImageFrame>();
+  auto input_mat = mediapipe::formats::MatView(&input_frame);
+
+  // Only 1 or 3 channel images supported by OpenCV.
+  if (!(input_mat.channels() == 1 || input_mat.channels() == 3)) {
+    return ::mediapipe::InternalError(
+        "CPU filtering supports only 1 or 3 channel input images.");
+  }
+
+  auto output_frame = absl::make_unique<ImageFrame>(
+      input_frame.Format(), input_mat.cols, input_mat.rows);
+  const bool has_guide_image = cc->Inputs().HasTag(kInputGuideTag) &&
+                               !cc->Inputs().Tag(kInputGuideTag).IsEmpty();
+
+  if (has_guide_image) {
+    // cv::jointBilateralFilter() is in contrib module 'ximgproc'.
+    return ::mediapipe::UnimplementedError(
+        "CPU joint filtering support is not implemented yet.");
+  } else {
+    auto output_mat = mediapipe::formats::MatView(output_frame.get());
+    // Prefer setting 'd = sigma_space * 2' to match GPU definition of radius.
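+    // Passing d explicitly (rather than letting OpenCV derive it from
+    // sigmaSpace when d <= 0) keeps the CPU neighborhood consistent with the
+    // GPU kernel radius.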
+ cv::bilateralFilter(input_mat, output_mat, /*d=*/sigma_space_ * 2.0, + sigma_color_, sigma_space_); + } + + cc->Outputs() + .Tag(kOutputFrameTag) + .Add(output_frame.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status BilateralFilterCalculator::RenderGpu( + CalculatorContext* cc) { + if (cc->Inputs().Tag(kInputFrameTagGpu).IsEmpty()) { + return ::mediapipe::OkStatus(); + } +#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) + const auto& input_frame = + cc->Inputs().Tag(kInputFrameTagGpu).Get(); + auto input_texture = gpu_helper_.CreateSourceTexture(input_frame); + + mediapipe::GlTexture output_texture; + const bool has_guide_image = cc->Inputs().HasTag(kInputGuideTagGpu) && + !cc->Inputs().Tag(kInputGuideTagGpu).IsEmpty(); + + // Setup textures and Update image in GPU shader. + if (has_guide_image) { + // joint bilateral filter + glUseProgram(program_joint_); + const auto& guide_image = + cc->Inputs().Tag(kInputGuideTagGpu).Get(); + auto guide_texture = gpu_helper_.CreateSourceTexture(guide_image); + glUniform2f(glGetUniformLocation(program_joint_, "texel_size_guide"), + 1.0 / guide_image.width(), 1.0 / guide_image.height()); + output_texture = gpu_helper_.CreateDestinationTexture( + guide_image.width(), guide_image.height(), + mediapipe::GpuBufferFormat::kBGRA32); + gpu_helper_.BindFramebuffer(output_texture); + glActiveTexture(GL_TEXTURE1); + glBindTexture(GL_TEXTURE_2D, input_texture.name()); + glActiveTexture(GL_TEXTURE2); + glBindTexture(GL_TEXTURE_2D, guide_texture.name()); + GlRender(cc); + glActiveTexture(GL_TEXTURE2); + glBindTexture(GL_TEXTURE_2D, 0); + glActiveTexture(GL_TEXTURE1); + glBindTexture(GL_TEXTURE_2D, 0); + guide_texture.Release(); + } else { + // regular bilateral filter + glUseProgram(program_); + glUniform2f(glGetUniformLocation(program_, "texel_size"), + 1.0 / input_frame.width(), 1.0 / input_frame.height()); + output_texture = gpu_helper_.CreateDestinationTexture( + input_frame.width(), input_frame.height(), + mediapipe::GpuBufferFormat::kBGRA32); + gpu_helper_.BindFramebuffer(output_texture); + glActiveTexture(GL_TEXTURE1); + glBindTexture(GL_TEXTURE_2D, input_texture.name()); + GlRender(cc); + glActiveTexture(GL_TEXTURE1); + glBindTexture(GL_TEXTURE_2D, 0); + } + glFlush(); + + // Send out image as GPU packet. 
+ auto output_frame = output_texture.GetFrame(); + cc->Outputs() + .Tag(kOutputFrameTagGpu) + .Add(output_frame.release(), cc->InputTimestamp()); + + // Cleanup + input_texture.Release(); + output_texture.Release(); +#endif // __ANDROID__ || __EMSCRIPTEN__ + + return ::mediapipe::OkStatus(); +} + +void BilateralFilterCalculator::GlRender(CalculatorContext* cc) { +#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) + static const GLfloat square_vertices[] = { + -1.0f, -1.0f, // bottom left + 1.0f, -1.0f, // bottom right + -1.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + static const GLfloat texture_vertices[] = { + 0.0f, 0.0f, // bottom left + 1.0f, 0.0f, // bottom right + 0.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + + // vertex storage + GLuint vbo[2]; + glGenBuffers(2, vbo); + GLuint vao; + glGenVertexArrays(1, &vao); + glBindVertexArray(vao); + + // vbo 0 + glBindBuffer(GL_ARRAY_BUFFER, vbo[0]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_VERTEX); + glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr); + + // vbo 1 + glBindBuffer(GL_ARRAY_BUFFER, vbo[1]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr); + + // draw + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + + // cleanup + glDisableVertexAttribArray(ATTRIB_VERTEX); + glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glBindBuffer(GL_ARRAY_BUFFER, 0); + glBindVertexArray(0); + glDeleteVertexArrays(1, &vao); + glDeleteBuffers(2, vbo); + +#endif // __ANDROID__ || __EMSCRIPTEN__ +} + +::mediapipe::Status BilateralFilterCalculator::GlSetup(CalculatorContext* cc) { +#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) + const GLint attr_location[NUM_ATTRIBUTES] = { + ATTRIB_VERTEX, + ATTRIB_TEXTURE_POSITION, + }; + const GLchar* attr_name[NUM_ATTRIBUTES] = { + "position", + "texture_coordinate", + }; + + // We bake our sigma values directly into the shader, so the GLSL compiler can + // optimize appropriately. + std::string sigma_options_string = + "const float sigma_space = " + std::to_string(sigma_space_) + + "; const float sigma_color = " + std::to_string(sigma_color_) + ";\n"; + + // Shader to do bilateral filtering on input image based on sigma space/color. + // Large kernel sizes are subsampled based on sqrt(sigma_space) window size, + // denoted as 'sparsity' below. + const std::string frag_src = GLES_VERSION_COMPAT + R"( + #if __VERSION__ < 130 + #define in varying + #endif // __VERSION__ < 130 + + #ifdef GL_ES + #define fragColor gl_FragColor + precision highp float; + #else + #define lowp + #define mediump + #define highp + #define texture2D texture + out vec4 fragColor; + #endif // defined(GL_ES) + + in vec2 sample_coordinate; + uniform sampler2D input_frame; +)" + sigma_options_string + R"( + uniform vec2 texel_size; + + const float kSparsityFactor = 0.66; // Higher is more sparse. + const float sparsity = max(1.0, sqrt(sigma_space) * kSparsityFactor); + const float step = sparsity; + const float radius = sigma_space; + const float offset = (step > 1.0) ? 
(step * 0.5) : (0.0); + + float gaussian(float x, float sigma) { + float coeff = -0.5 / (sigma * sigma * 4.0 + 1.0e-6); + return exp((x * x) * coeff); + } + + void main() { + vec2 center_uv = sample_coordinate; + vec3 center_val = texture2D(input_frame, center_uv).rgb; + vec3 new_val = vec3(0.0); + + float space_weight = 0.0; + float color_weight = 0.0; + float total_weight = 0.0; + + float sigma_texel = max(texel_size.x, texel_size.y) * sigma_space; + // Subsample kernel space. + for (float i = -radius+offset; i <= radius; i+=step) { + for (float j = -radius+offset; j <= radius; j+=step) { + vec2 shift = vec2(j, i) * texel_size; + vec2 uv = vec2(center_uv + shift); + vec3 val = texture2D(input_frame, uv).rgb; + + space_weight = gaussian(distance(center_uv, uv), sigma_texel); + color_weight = gaussian(distance(center_val, val), sigma_color); + total_weight += space_weight * color_weight; + + new_val += vec3(space_weight * color_weight) * val; + } + } + new_val /= vec3(total_weight); + + fragColor = vec4(new_val, 1.0); + } + )"; + + // Create shader program and set parameters. + mediapipe::GlhCreateProgram(mediapipe::kBasicVertexShader, frag_src.c_str(), + NUM_ATTRIBUTES, (const GLchar**)&attr_name[0], + attr_location, &program_); + RET_CHECK(program_) << "Problem initializing the program."; + glUseProgram(program_); + glUniform1i(glGetUniformLocation(program_, "input_frame"), 1); + + // Shader to do joint bilateral filtering on input image based on + // sigma space/color, and a Guide image. + // Large kernel sizes are subsampled based on sqrt(sigma_space) window size, + // denoted as 'sparsity' below. + const std::string joint_frag_src = GLES_VERSION_COMPAT + R"( + #if __VERSION__ < 130 + #define in varying + #endif // __VERSION__ < 130 + + #ifdef GL_ES + #define fragColor gl_FragColor + precision highp float; + #else + #define lowp + #define mediump + #define highp + #define texture2D texture + out vec4 fragColor; + #endif // defined(GL_ES) + + in vec2 sample_coordinate; + uniform sampler2D input_frame; + uniform sampler2D guide_frame; +)" + sigma_options_string + R"( + uniform vec2 texel_size_guide; // size of guide and resulting filtered image + + const float kSparsityFactor = 0.66; // Higher is more sparse. + const float sparsity = max(1.0, sqrt(sigma_space) * kSparsityFactor); + const float step = sparsity; + const float radius = sigma_space; + const float offset = (step > 1.0) ? (step * 0.5) : (0.0); + + float gaussian(float x, float sigma) { + float coeff = -0.5 / (sigma * sigma * 4.0 + 1.0e-6); + return exp((x * x) * coeff); + } + + void main() { + vec2 center_uv = sample_coordinate; + vec3 center_val = texture2D(guide_frame, center_uv).rgb; + vec3 new_val = vec3(0.0); + + float space_weight = 0.0; + float color_weight = 0.0; + float total_weight = 0.0; + + float sigma_texel = max(texel_size_guide.x, texel_size_guide.y) * sigma_space; + // Subsample kernel space. 
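+        // The spatial weight uses distances in guide-image UV space, the
+        // range weight compares guide-image colors, and the weighted sum
+        // accumulates samples from input_frame, so edges in the guide steer
+        // how the input is smoothed.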
+ for (float i = -radius+offset; i <= radius; i+=step) { + for (float j = -radius+offset; j <= radius; j+=step) { + vec2 shift = vec2(j, i) * texel_size_guide; + vec2 uv = vec2(center_uv + shift); + vec3 guide_val = texture2D(guide_frame, uv).rgb; + vec3 out_val = texture2D(input_frame, uv).rgb; + + space_weight = gaussian(distance(center_uv, uv), sigma_texel); + color_weight = gaussian(distance(center_val, guide_val), sigma_color); + total_weight += space_weight * color_weight; + + new_val += vec3(space_weight * color_weight) * out_val; + } + } + new_val /= vec3(total_weight); + + fragColor = vec4(new_val, 1.0); + } + )"; + + // Create shader program and set parameters. + mediapipe::GlhCreateProgram( + mediapipe::kBasicVertexShader, joint_frag_src.c_str(), NUM_ATTRIBUTES, + (const GLchar**)&attr_name[0], attr_location, &program_joint_); + RET_CHECK(program_joint_) << "Problem initializing the program."; + glUseProgram(program_joint_); + glUniform1i(glGetUniformLocation(program_joint_, "input_frame"), 1); + glUniform1i(glGetUniformLocation(program_joint_, "guide_frame"), 2); + +#endif // __ANDROID__ || __EMSCRIPTEN__ + + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/bilateral_filter_calculator.proto b/mediapipe/calculators/image/bilateral_filter_calculator.proto new file mode 100644 index 000000000..a787437dc --- /dev/null +++ b/mediapipe/calculators/image/bilateral_filter_calculator.proto @@ -0,0 +1,20 @@ +// Options for BilateralFilterCalculator +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message BilateralFilterCalculatorOptions { + extend CalculatorOptions { + optional BilateralFilterCalculatorOptions ext = 255670209; + } + + // Max variance in color allowed, based on normalized color values. + optional float sigma_color = 1; + + // Window radius. + // Results in a '(sigma_space*2+1) x (sigma_space*2+1)' size kernel. + // This should be set based on output image pixel space. + optional float sigma_space = 2; +} diff --git a/mediapipe/calculators/image/color_convert_calculator.cc b/mediapipe/calculators/image/color_convert_calculator.cc index 875323e9e..3ff1a738f 100644 --- a/mediapipe/calculators/image/color_convert_calculator.cc +++ b/mediapipe/calculators/image/color_convert_calculator.cc @@ -1,3 +1,17 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" diff --git a/mediapipe/calculators/image/image_cropping_calculator.cc b/mediapipe/calculators/image/image_cropping_calculator.cc index b9477749e..c0a7894ba 100644 --- a/mediapipe/calculators/image/image_cropping_calculator.cc +++ b/mediapipe/calculators/image/image_cropping_calculator.cc @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include + +#include "mediapipe/calculators/image/image_cropping_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" @@ -21,47 +24,98 @@ #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gl_simple_shaders.h" +#include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/gpu/shader_util.h" +#endif // __ANDROID__ or iOS + +namespace { +enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; +} // namespace + namespace mediapipe { +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) + +#endif // __ANDROID__ or iOS + // Crops the input texture to the given rectangle region. The rectangle can // be at arbitrary location on the image with rotation. If there's rotation, the // output texture will have the size of the input rectangle. The rotation should // be in radian, see rect.proto for detail. -// Currently it only works for CPU. // // Input: -// IMAGE: ImageFrame representing the input image. // One of the following two tags: +// IMAGE - ImageFrame representing the input image. +// IMAGE_GPU - GpuBuffer representing the input image. +// One of the following two tags (optional if WIDTH/HEIGHT is specified): // RECT - A Rect proto specifying the width/height and location of the // cropping rectangle. // NORM_RECT - A NormalizedRect proto specifying the width/height and location -// of the cropping rectangle in normalized coordinates. +// of the cropping rectangle in normalized coordinates. +// Alternative tags to RECT (optional if RECT/NORM_RECT is specified): +// WIDTH - The desired width of the output cropped image, +// based on image center +// HEIGHT - The desired height of the output cropped image, +// based on image center // // Output: -// IMAGE - Cropped frames. +// One of the following two tags: +// IMAGE - Cropped ImageFrame +// IMAGE_GPU - Cropped GpuBuffer. +// +// Note: input_stream values take precedence over options defined in the graph. +// class ImageCroppingCalculator : public CalculatorBase { public: ImageCroppingCalculator() = default; ~ImageCroppingCalculator() override = default; static ::mediapipe::Status GetContract(CalculatorContract* cc); + ::mediapipe::Status Open(CalculatorContext* cc) override; ::mediapipe::Status Process(CalculatorContext* cc) override; + ::mediapipe::Status Close(CalculatorContext* cc) override; private: ::mediapipe::Status RenderCpu(CalculatorContext* cc); ::mediapipe::Status RenderGpu(CalculatorContext* cc); + ::mediapipe::Status InitGpu(CalculatorContext* cc); + void GlRender(); + void GetOutputDimensions(CalculatorContext* cc, int src_width, int src_height, + int* dst_width, int* dst_height); - // TODO: Merge with GlCroppingCalculator to have GPU support. - bool use_gpu_{}; + mediapipe::ImageCroppingCalculatorOptions options_; + + bool use_gpu_ = false; + // Output texture corners (4) after transoformation in normalized coordinates. 
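+  // Stored as interleaved (x, y) pairs, one corner per pair, in the same
+  // order as the quad vertices used by GlRender().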
+ float transformed_points_[8]; +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) + bool gpu_initialized_ = false; + mediapipe::GlCalculatorHelper gpu_helper_; + GLuint program_ = 0; +#endif // __ANDROID__ or iOS }; REGISTER_CALCULATOR(ImageCroppingCalculator); ::mediapipe::Status ImageCroppingCalculator::GetContract( CalculatorContract* cc) { - RET_CHECK(cc->Inputs().HasTag("IMAGE")); - RET_CHECK(cc->Outputs().HasTag("IMAGE")); + RET_CHECK(cc->Inputs().HasTag("IMAGE") ^ cc->Inputs().HasTag("IMAGE_GPU")); + RET_CHECK(cc->Outputs().HasTag("IMAGE") ^ cc->Outputs().HasTag("IMAGE_GPU")); - cc->Inputs().Tag("IMAGE").Set(); + if (cc->Inputs().HasTag("IMAGE")) { + RET_CHECK(cc->Outputs().HasTag("IMAGE")); + cc->Inputs().Tag("IMAGE").Set(); + cc->Outputs().Tag("IMAGE").Set(); + } +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) + if (cc->Inputs().HasTag("IMAGE_GPU")) { + RET_CHECK(cc->Outputs().HasTag("IMAGE_GPU")); + cc->Inputs().Tag("IMAGE_GPU").Set(); + cc->Outputs().Tag("IMAGE_GPU").Set(); + } +#endif // __ANDROID__ or iOS if (cc->Inputs().HasTag("RECT")) { cc->Inputs().Tag("RECT").Set(); @@ -69,21 +123,71 @@ REGISTER_CALCULATOR(ImageCroppingCalculator); if (cc->Inputs().HasTag("NORM_RECT")) { cc->Inputs().Tag("NORM_RECT").Set(); } + if (cc->Inputs().HasTag("WIDTH")) { + cc->Inputs().Tag("WIDTH").Set(); + } + if (cc->Inputs().HasTag("HEIGHT")) { + cc->Inputs().Tag("HEIGHT").Set(); + } - cc->Outputs().Tag("IMAGE").Set(); +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) + RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#endif // __ANDROID__ or iOS + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status ImageCroppingCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + if (cc->Inputs().HasTag("IMAGE_GPU")) { + use_gpu_ = true; + } + + options_ = cc->Options(); + + if (use_gpu_) { +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) + RETURN_IF_ERROR(gpu_helper_.Open(cc)); +#else + RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; +#endif // __ANDROID__ or iOS + } return ::mediapipe::OkStatus(); } ::mediapipe::Status ImageCroppingCalculator::Process(CalculatorContext* cc) { if (use_gpu_) { - RETURN_IF_ERROR(RenderGpu(cc)); +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) + RETURN_IF_ERROR( + gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { + if (!gpu_initialized_) { + RETURN_IF_ERROR(InitGpu(cc)); + gpu_initialized_ = true; + } + RETURN_IF_ERROR(RenderGpu(cc)); + return ::mediapipe::OkStatus(); + })); +#endif // __ANDROID__ or iOS } else { RETURN_IF_ERROR(RenderCpu(cc)); } return ::mediapipe::OkStatus(); } +::mediapipe::Status ImageCroppingCalculator::Close(CalculatorContext* cc) { +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) + gpu_helper_.RunInGlContext([this] { + if (program_) glDeleteProgram(program_); + program_ = 0; + }); + gpu_initialized_ = false; +#endif // __ANDROID__ or iOS + + return ::mediapipe::OkStatus(); +} + ::mediapipe::Status ImageCroppingCalculator::RenderCpu(CalculatorContext* cc) { const auto& input_img = cc->Inputs().Tag("IMAGE").Get(); cv::Mat input_mat = formats::MatView(&input_img); @@ -97,43 +201,53 @@ REGISTER_CALCULATOR(ImageCroppingCalculator); const auto& rect = cc->Inputs().Tag("RECT").Get(); if (rect.width() > 0 && rect.height() > 0 && rect.x_center() >= 0 && rect.y_center() >= 0) { - rotation = rect.rotation(); rect_center_x = rect.x_center(); rect_center_y = 
rect.y_center(); target_width = rect.width(); target_height = rect.height(); + rotation = rect.rotation(); } } else if (cc->Inputs().HasTag("NORM_RECT")) { const auto& rect = cc->Inputs().Tag("NORM_RECT").Get(); if (rect.width() > 0.0 && rect.height() > 0.0 && rect.x_center() >= 0.0 && rect.y_center() >= 0.0) { - rotation = rect.rotation(); rect_center_x = std::round(rect.x_center() * input_img.Width()); rect_center_y = std::round(rect.y_center() * input_img.Height()); target_width = std::round(rect.width() * input_img.Width()); target_height = std::round(rect.height() * input_img.Height()); + rotation = rect.rotation(); } - } - - cv::Mat rotated_mat; - if (std::abs(rotation) > 1e-5) { - // TODO: Use open source common math library. - const float pi = 3.1415926f; - rotation = rotation * 180.0 / pi; - - // First rotation the image. - cv::Point2f src_center(rect_center_x, rect_center_y); - cv::Mat rotation_mat = cv::getRotationMatrix2D(src_center, rotation, 1.0); - cv::warpAffine(input_mat, rotated_mat, rotation_mat, input_mat.size()); } else { - input_mat.copyTo(rotated_mat); + if (cc->Inputs().HasTag("WIDTH") && cc->Inputs().HasTag("HEIGHT")) { + target_width = cc->Inputs().Tag("WIDTH").Get(); + target_height = cc->Inputs().Tag("HEIGHT").Get(); + } else if (options_.has_width() && options_.has_height()) { + target_width = options_.width(); + target_height = options_.height(); + } + rotation = options_.rotation(); } - // Then crop the requested area. - const cv::Rect cropping_rect(rect_center_x - target_width / 2, - rect_center_y - target_height / 2, target_width, - target_height); - cv::Mat cropped_image = cv::Mat(rotated_mat, cropping_rect); + const cv::RotatedRect min_rect(cv::Point2f(rect_center_x, rect_center_y), + cv::Size2f(target_width, target_height), + rotation * 180.f / M_PI); + cv::Mat src_points; + cv::boxPoints(min_rect, src_points); + + float dst_corners[8] = {0, + min_rect.size.height - 1, + 0, + 0, + min_rect.size.width - 1, + 0, + min_rect.size.width - 1, + min_rect.size.height - 1}; + cv::Mat dst_points = cv::Mat(4, 2, CV_32F, dst_corners); + cv::Mat projection_matrix = + cv::getPerspectiveTransform(src_points, dst_points); + cv::Mat cropped_image; + cv::warpPerspective(input_mat, cropped_image, projection_matrix, + cv::Size(min_rect.size.width, min_rect.size.height)); std::unique_ptr output_frame(new ImageFrame( input_img.Format(), cropped_image.cols, cropped_image.rows)); @@ -144,7 +258,220 @@ REGISTER_CALCULATOR(ImageCroppingCalculator); } ::mediapipe::Status ImageCroppingCalculator::RenderGpu(CalculatorContext* cc) { - return ::mediapipe::UnimplementedError("GPU support is not implemented yet."); + if (cc->Inputs().Tag("IMAGE_GPU").IsEmpty()) { + return ::mediapipe::OkStatus(); + } +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) + const Packet& input_packet = cc->Inputs().Tag("IMAGE_GPU").Value(); + const auto& input_buffer = input_packet.Get(); + auto src_tex = gpu_helper_.CreateSourceTexture(input_buffer); + + int out_width, out_height; + GetOutputDimensions(cc, src_tex.width(), src_tex.height(), &out_width, + &out_height); + auto dst_tex = gpu_helper_.CreateDestinationTexture(out_width, out_height); + + // Run cropping shader on GPU. + { + gpu_helper_.BindFramebuffer(dst_tex); // GL_TEXTURE0 + + glActiveTexture(GL_TEXTURE1); + glBindTexture(src_tex.target(), src_tex.name()); + + GlRender(); + + glActiveTexture(GL_TEXTURE2); + glBindTexture(GL_TEXTURE_2D, 0); + glFlush(); + } + + // Send result image in GPU packet. 
+ auto output = dst_tex.GetFrame(); + cc->Outputs().Tag("IMAGE_GPU").Add(output.release(), cc->InputTimestamp()); + + // Cleanup + src_tex.Release(); + dst_tex.Release(); +#endif // __ANDROID__ or iOS + + return ::mediapipe::OkStatus(); +} + +void ImageCroppingCalculator::GlRender() { +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) + static const GLfloat square_vertices[] = { + -1.0f, -1.0f, // bottom left + 1.0f, -1.0f, // bottom right + -1.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + const GLfloat* texture_vertices = &transformed_points_[0]; + + // program + glUseProgram(program_); + + // vertex storage + GLuint vbo[2]; + glGenBuffers(2, vbo); + GLuint vao; + glGenVertexArrays(1, &vao); + glBindVertexArray(vao); + + // vbo 0 + glBindBuffer(GL_ARRAY_BUFFER, vbo[0]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_VERTEX); + glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr); + + // vbo 1 + glBindBuffer(GL_ARRAY_BUFFER, vbo[1]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr); + + // draw + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + + // cleanup + glDisableVertexAttribArray(ATTRIB_VERTEX); + glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glBindBuffer(GL_ARRAY_BUFFER, 0); + glBindVertexArray(0); + glDeleteVertexArrays(1, &vao); + glDeleteBuffers(2, vbo); + +#endif // __ANDROID__ or iOS +} + +::mediapipe::Status ImageCroppingCalculator::InitGpu(CalculatorContext* cc) { +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) + const GLint attr_location[NUM_ATTRIBUTES] = { + ATTRIB_VERTEX, + ATTRIB_TEXTURE_POSITION, + }; + const GLchar* attr_name[NUM_ATTRIBUTES] = { + "position", + "texture_coordinate", + }; + + // Simple pass-through shader. + const GLchar* frag_src = GLES_VERSION_COMPAT + R"( + #if __VERSION__ < 130 + #define in varying + #endif // __VERSION__ < 130 + + #ifdef GL_ES + #define fragColor gl_FragColor + precision highp float; + #else + #define lowp + #define mediump + #define highp + #define texture2D texture + out vec4 fragColor; + #endif // defined(GL_ES) + + in vec2 sample_coordinate; + uniform sampler2D input_frame; + + void main() { + vec4 pix = texture2D(input_frame, sample_coordinate); + fragColor = pix; + } + )"; + + // Program + mediapipe::GlhCreateProgram(mediapipe::kBasicVertexShader, frag_src, + NUM_ATTRIBUTES, &attr_name[0], attr_location, + &program_); + RET_CHECK(program_) << "Problem initializing the program."; + + // Parameters + glUseProgram(program_); + glUniform1i(glGetUniformLocation(program_, "input_frame"), 1); +#endif // __ANDROID__ or iOS + + return ::mediapipe::OkStatus(); +} + +// For GPU only. +void ImageCroppingCalculator::GetOutputDimensions(CalculatorContext* cc, + int src_width, int src_height, + int* dst_width, + int* dst_height) { + // Get the size of the cropping box. + int crop_width = src_width; + int crop_height = src_height; + // Get the center of cropping box. Default is the at the center. + int x_center = src_width / 2; + int y_center = src_height / 2; + // Get the rotation of the cropping box. + float rotation = 0.0f; + if (cc->Inputs().HasTag("RECT")) { + const auto& rect = cc->Inputs().Tag("RECT").Get(); + // Only use the rect if it is valid. 
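+    // (Valid here means a positive width/height and a non-negative center.)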
+ if (rect.width() > 0 && rect.height() > 0 && rect.x_center() >= 0 && + rect.y_center() >= 0) { + x_center = rect.x_center(); + y_center = rect.y_center(); + crop_width = rect.width(); + crop_height = rect.height(); + rotation = rect.rotation(); + } + } else if (cc->Inputs().HasTag("NORM_RECT")) { + const auto& rect = cc->Inputs().Tag("NORM_RECT").Get(); + // Only use the rect if it is valid. + if (rect.width() > 0.0 && rect.height() > 0.0 && rect.x_center() >= 0.0 && + rect.y_center() >= 0.0) { + x_center = std::round(rect.x_center() * src_width); + y_center = std::round(rect.y_center() * src_height); + crop_width = std::round(rect.width() * src_width); + crop_height = std::round(rect.height() * src_height); + rotation = rect.rotation(); + } + } else { + if (cc->Inputs().HasTag("WIDTH") && cc->Inputs().HasTag("HEIGHT")) { + crop_width = cc->Inputs().Tag("WIDTH").Get(); + crop_height = cc->Inputs().Tag("HEIGHT").Get(); + } else if (options_.has_width() && options_.has_height()) { + crop_width = options_.width(); + crop_height = options_.height(); + } + rotation = options_.rotation(); + } + + const float half_width = crop_width / 2.0f; + const float half_height = crop_height / 2.0f; + const float corners[] = {-half_width, -half_height, half_width, -half_height, + -half_width, half_height, half_width, half_height}; + + for (int i = 0; i < 4; ++i) { + const float rotated_x = std::cos(rotation) * corners[i * 2] - + std::sin(rotation) * corners[i * 2 + 1]; + const float rotated_y = std::sin(rotation) * corners[i * 2] + + std::cos(rotation) * corners[i * 2 + 1]; + + transformed_points_[i * 2] = ((rotated_x + x_center) / src_width); + transformed_points_[i * 2 + 1] = ((rotated_y + y_center) / src_height); + } + + // Find the boundaries of the transformed rectangle. + float col_min = transformed_points_[0]; + float col_max = transformed_points_[0]; + float row_min = transformed_points_[1]; + float row_max = transformed_points_[1]; + for (int i = 1; i < 4; ++i) { + col_min = std::min(col_min, transformed_points_[i * 2]); + col_max = std::max(col_max, transformed_points_[i * 2]); + row_min = std::min(row_min, transformed_points_[i * 2 + 1]); + row_max = std::max(row_max, transformed_points_[i * 2 + 1]); + } + + *dst_width = std::round((col_max - col_min) * src_width); + *dst_height = std::round((row_max - row_min) * src_height); } } // namespace mediapipe diff --git a/mediapipe/calculators/image/image_cropping_calculator.proto b/mediapipe/calculators/image/image_cropping_calculator.proto new file mode 100644 index 000000000..70e271035 --- /dev/null +++ b/mediapipe/calculators/image/image_cropping_calculator.proto @@ -0,0 +1,33 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message ImageCroppingCalculatorOptions { + extend CalculatorOptions { + optional ImageCroppingCalculatorOptions ext = 262466399; + } + + // Output texture buffer dimensions. 
The values defined in the options will be + // overriden by the WIDTH and HEIGHT input streams if they exist. + optional int32 width = 1; + optional int32 height = 2; + + // Rotation angle is counter-clockwise in radian. + optional float rotation = 3 [default = 0.0]; +} diff --git a/mediapipe/calculators/image/image_properties_calculator.cc b/mediapipe/calculators/image/image_properties_calculator.cc new file mode 100644 index 000000000..70c49de61 --- /dev/null +++ b/mediapipe/calculators/image/image_properties_calculator.cc @@ -0,0 +1,93 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_frame.h" + +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) +#include "mediapipe/gpu/gpu_buffer.h" +#endif // __ANDROID__ or iOS + +namespace mediapipe { + +// Extracts image properties from the input image and outputs the properties. +// Currently only supports image size. +// Input: +// One of the following: +// IMAGE: An ImageFrame +// IMAGE_GPU: A GpuBuffer +// +// Output: +// SIZE: Size (as a std::pair) of the input image. +// +// Example usage: +// node { +// calculator: "ImagePropertiesCalculator" +// input_stream: "IMAGE:image" +// output_stream: "SIZE:size" +// } +class ImagePropertiesCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag("IMAGE") ^ cc->Inputs().HasTag("IMAGE_GPU")); + if (cc->Inputs().HasTag("IMAGE")) { + cc->Inputs().Tag("IMAGE").Set(); + } +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) + if (cc->Inputs().HasTag("IMAGE_GPU")) { + cc->Inputs().Tag("IMAGE_GPU").Set<::mediapipe::GpuBuffer>(); + } +#endif // __ANDROID__ or iOS + + if (cc->Outputs().HasTag("SIZE")) { + cc->Outputs().Tag("SIZE").Set>(); + } + + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + cc->SetOffset(TimestampDiff(0)); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + int width; + int height; + + if (cc->Inputs().HasTag("IMAGE") && !cc->Inputs().Tag("IMAGE").IsEmpty()) { + const auto& image = cc->Inputs().Tag("IMAGE").Get(); + width = image.Width(); + height = image.Height(); + } +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) + if (cc->Inputs().HasTag("IMAGE_GPU") && + !cc->Inputs().Tag("IMAGE_GPU").IsEmpty()) { + const auto& image = + cc->Inputs().Tag("IMAGE_GPU").Get(); + width = image.width(); + height = image.height(); + } +#endif // __ANDROID__ or iOS + + cc->Outputs().Tag("SIZE").AddPacket( + MakePacket>(width, height) + .At(cc->InputTimestamp())); + + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(ImagePropertiesCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/image_transformation_calculator.cc 
b/mediapipe/calculators/image/image_transformation_calculator.cc index 90e16781f..de33d4424 100644 --- a/mediapipe/calculators/image/image_transformation_calculator.cc +++ b/mediapipe/calculators/image/image_transformation_calculator.cc @@ -22,12 +22,12 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/gpu/scale_mode.pb.h" -#if defined(__ANDROID__) +#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_quad_renderer.h" #include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/shader_util.h" -#endif // __ANDROID__ +#endif // __ANDROID__ || iOS #if defined(__ANDROID__) // The size of Java arrays is dynamic, which makes it difficult to @@ -42,9 +42,9 @@ typedef int DimensionsPacketType[2]; namespace mediapipe { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX -#endif // __ANDROID__ +#endif // __ANDROID__ || iOS namespace { int RotationModeToDegrees(mediapipe::RotationMode_Mode rotation) { @@ -170,11 +170,12 @@ class ImageTransformationCalculator : public CalculatorBase { mediapipe::ScaleMode_Mode scale_mode_; bool use_gpu_ = false; -#if defined(__ANDROID__) +#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX GlCalculatorHelper helper_; std::unique_ptr rgb_renderer_; + std::unique_ptr yuv_renderer_; std::unique_ptr ext_rgb_renderer_; -#endif // __ANDROID__ +#endif // __ANDROID__ || iOS }; REGISTER_CALCULATOR(ImageTransformationCalculator); @@ -189,13 +190,13 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); cc->Inputs().Tag("IMAGE").Set(); cc->Outputs().Tag("IMAGE").Set(); } -#if defined(__ANDROID__) +#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX if (cc->Inputs().HasTag("IMAGE_GPU")) { RET_CHECK(cc->Outputs().HasTag("IMAGE_GPU")); cc->Inputs().Tag("IMAGE_GPU").Set(); cc->Outputs().Tag("IMAGE_GPU").Set(); } -#endif // __ANDROID__ +#endif // __ANDROID__ || iOS if (cc->Inputs().HasTag("ROTATION_DEGREES")) { cc->Inputs().Tag("ROTATION_DEGREES").Set(); } @@ -211,9 +212,9 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); cc->Outputs().Tag("LETTERBOX_PADDING").Set>(); } -#if defined(__ANDROID__) +#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc)); -#endif // __ANDROID__ +#endif // __ANDROID__ || iOS return ::mediapipe::OkStatus(); } @@ -221,7 +222,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); ::mediapipe::Status ImageTransformationCalculator::Open(CalculatorContext* cc) { // Inform the framework that we always output at the same timestamp // as we receive a packet at. - cc->SetOffset(mediapipe::TimestampDiff(0)); + cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); @@ -249,12 +250,12 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); scale_mode_ = ParseScaleMode(options_.scale_mode(), DEFAULT_SCALE_MODE); if (use_gpu_) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX // Let the helper access the GL context information. 
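+    // (Open() binds this calculator to the graph's shared GL context so that
+    // later RunInGlContext() calls run on a valid context.)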
RETURN_IF_ERROR(helper_.Open(cc)); #else - RET_CHECK_FAIL() << "GPU processing for non-Android not supported yet."; -#endif // __ANDROID__ + RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; +#endif // __ANDROID__ || iOS } return ::mediapipe::OkStatus(); @@ -263,10 +264,10 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); ::mediapipe::Status ImageTransformationCalculator::Process( CalculatorContext* cc) { if (use_gpu_) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX return helper_.RunInGlContext( [this, cc]() -> ::mediapipe::Status { return RenderGpu(cc); }); -#endif // __ANDROID__ +#endif // __ANDROID__ || iOS } else { return RenderCpu(cc); } @@ -276,10 +277,11 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); ::mediapipe::Status ImageTransformationCalculator::Close( CalculatorContext* cc) { if (use_gpu_) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX QuadRenderer* rgb_renderer = rgb_renderer_.release(); + QuadRenderer* yuv_renderer = yuv_renderer_.release(); QuadRenderer* ext_rgb_renderer = ext_rgb_renderer_.release(); - helper_.RunInGlContext([rgb_renderer, ext_rgb_renderer] { + helper_.RunInGlContext([rgb_renderer, yuv_renderer, ext_rgb_renderer] { if (rgb_renderer) { rgb_renderer->GlTeardown(); delete rgb_renderer; @@ -288,10 +290,13 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); ext_rgb_renderer->GlTeardown(); delete ext_rgb_renderer; } + if (yuv_renderer) { + yuv_renderer->GlTeardown(); + delete yuv_renderer; + } }); -#endif // __ANDROID__ +#endif // __ANDROID__ || iOS } - return ::mediapipe::OkStatus(); } @@ -366,7 +371,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); ::mediapipe::Status ImageTransformationCalculator::RenderGpu( CalculatorContext* cc) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX int input_width = cc->Inputs().Tag("IMAGE_GPU").Get().width(); int input_height = cc->Inputs().Tag("IMAGE_GPU").Get().height(); @@ -387,8 +392,23 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get(); QuadRenderer* renderer = nullptr; GlTexture src1; + +#if defined(__APPLE__) && !TARGET_OS_OSX + if (input.format() == GpuBufferFormat::kBiPlanar420YpCbCr8VideoRange || + input.format() == GpuBufferFormat::kBiPlanar420YpCbCr8FullRange) { + if (!yuv_renderer_) { + yuv_renderer_ = absl::make_unique(); + RETURN_IF_ERROR( + yuv_renderer_->GlSetup(::mediapipe::kYUV2TexToRGBFragmentShader, + {"video_frame_y", "video_frame_uv"})); + } + renderer = yuv_renderer_.get(); + src1 = helper_.CreateSourceTexture(input, 0); + } else // NOLINT(readability/braces) +#endif // iOS { src1 = helper_.CreateSourceTexture(input); +#if defined(__ANDROID__) if (src1.target() == GL_TEXTURE_EXTERNAL_OES) { if (!ext_rgb_renderer_) { ext_rgb_renderer_ = absl::make_unique(); @@ -396,7 +416,9 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); ::mediapipe::kBasicTexturedFragmentShaderOES, {"video_frame"})); } renderer = ext_rgb_renderer_.get(); - } else { + } else // NOLINT(readability/braces) +#endif // __ANDROID__ + { if (!rgb_renderer_) { rgb_renderer_ = absl::make_unique(); RETURN_IF_ERROR(rgb_renderer_->GlSetup()); @@ -438,7 +460,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator); auto output = dst.GetFrame(); cc->Outputs().Tag("IMAGE_GPU").Add(output.release(), cc->InputTimestamp()); -#endif // __ANDROID__ +#endif // __ANDROID__ || iOS return ::mediapipe::OkStatus(); } diff 
--git a/mediapipe/calculators/image/opencv_put_text_calculator.cc b/mediapipe/calculators/image/opencv_put_text_calculator.cc index ff336ff92..07f6f0dbf 100644 --- a/mediapipe/calculators/image/opencv_put_text_calculator.cc +++ b/mediapipe/calculators/image/opencv_put_text_calculator.cc @@ -45,11 +45,11 @@ class OpenCvPutTextCalculator : public CalculatorBase { ::mediapipe::Status OpenCvPutTextCalculator::Process(CalculatorContext* cc) { const std::string& text_content = cc->Inputs().Index(0).Get(); - cv::Mat mat = cv::Mat::zeros(640, 640, CV_8UC3); + cv::Mat mat = cv::Mat::zeros(640, 640, CV_8UC4); cv::putText(mat, text_content, cv::Point(15, 70), cv::FONT_HERSHEY_PLAIN, 3, - cv::Scalar(255, 255, 0), 4); + cv::Scalar(255, 255, 0, 255), 4); std::unique_ptr output_frame = absl::make_unique( - ImageFormat::SRGB, mat.size().width, mat.size().height); + ImageFormat::SRGBA, mat.size().width, mat.size().height); mat.copyTo(formats::MatView(output_frame.get())); cc->Outputs().Index(0).Add(output_frame.release(), cc->InputTimestamp()); return ::mediapipe::OkStatus(); diff --git a/mediapipe/calculators/image/recolor_calculator.cc b/mediapipe/calculators/image/recolor_calculator.cc index fb74dd45f..1c2a0fcda 100644 --- a/mediapipe/calculators/image/recolor_calculator.cc +++ b/mediapipe/calculators/image/recolor_calculator.cc @@ -21,12 +21,12 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/util/color.pb.h" -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/shader_util.h" -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS namespace { enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; @@ -95,10 +95,10 @@ class RecolorCalculator : public CalculatorBase { mediapipe::RecolorCalculatorOptions::MaskChannel mask_channel_; bool use_gpu_ = false; -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) mediapipe::GlCalculatorHelper gpu_helper_; GLuint program_ = 0; -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS }; REGISTER_CALCULATOR(RecolorCalculator); @@ -107,46 +107,48 @@ REGISTER_CALCULATOR(RecolorCalculator); RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) if (cc->Inputs().HasTag("IMAGE_GPU")) { cc->Inputs().Tag("IMAGE_GPU").Set(); } -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS if (cc->Inputs().HasTag("IMAGE")) { cc->Inputs().Tag("IMAGE").Set(); } -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) if (cc->Inputs().HasTag("MASK_GPU")) { cc->Inputs().Tag("MASK_GPU").Set(); } -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS if (cc->Inputs().HasTag("MASK")) { cc->Inputs().Tag("MASK").Set(); } -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) if (cc->Outputs().HasTag("IMAGE_GPU")) { cc->Outputs().Tag("IMAGE_GPU").Set(); } -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS if (cc->Outputs().HasTag("IMAGE")) { cc->Outputs().Tag("IMAGE").Set(); } -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS return ::mediapipe::OkStatus(); } ::mediapipe::Status 
RecolorCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + if (cc->Inputs().HasTag("IMAGE_GPU")) { use_gpu_ = true; -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) RETURN_IF_ERROR(gpu_helper_.Open(cc)); -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS } RETURN_IF_ERROR(LoadOptions(cc)); @@ -156,7 +158,7 @@ REGISTER_CALCULATOR(RecolorCalculator); ::mediapipe::Status RecolorCalculator::Process(CalculatorContext* cc) { if (use_gpu_) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) RETURN_IF_ERROR( gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { if (!initialized_) { @@ -166,7 +168,7 @@ REGISTER_CALCULATOR(RecolorCalculator); RETURN_IF_ERROR(RenderGpu(cc)); return ::mediapipe::OkStatus(); })); -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS } else { RETURN_IF_ERROR(RenderCpu(cc)); } @@ -174,12 +176,12 @@ REGISTER_CALCULATOR(RecolorCalculator); } ::mediapipe::Status RecolorCalculator::Close(CalculatorContext* cc) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) gpu_helper_.RunInGlContext([this] { if (program_) glDeleteProgram(program_); program_ = 0; }); -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS return ::mediapipe::OkStatus(); } @@ -192,7 +194,7 @@ REGISTER_CALCULATOR(RecolorCalculator); if (cc->Inputs().Tag("MASK_GPU").IsEmpty()) { return ::mediapipe::OkStatus(); } -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) // Get inputs and setup output. const Packet& input_packet = cc->Inputs().Tag("IMAGE_GPU").Value(); const Packet& mask_packet = cc->Inputs().Tag("MASK_GPU").Value(); @@ -231,13 +233,13 @@ REGISTER_CALCULATOR(RecolorCalculator); img_tex.Release(); mask_tex.Release(); dst_tex.Release(); -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS return ::mediapipe::OkStatus(); } void RecolorCalculator::GlRender() { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -285,7 +287,7 @@ void RecolorCalculator::GlRender() { glBindVertexArray(0); glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS } ::mediapipe::Status RecolorCalculator::LoadOptions(CalculatorContext* cc) { @@ -303,7 +305,7 @@ void RecolorCalculator::GlRender() { } ::mediapipe::Status RecolorCalculator::InitGpu(CalculatorContext* cc) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, @@ -372,7 +374,7 @@ void RecolorCalculator::GlRender() { glUniform1i(glGetUniformLocation(program_, "mask"), 2); glUniform3f(glGetUniformLocation(program_, "recolor"), color_[0], color_[1], color_[2]); -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS return ::mediapipe::OkStatus(); } diff --git a/mediapipe/calculators/image/set_alpha_calculator.cc b/mediapipe/calculators/image/set_alpha_calculator.cc index 9b8a0b416..4c3ab29ae 100644 --- a/mediapipe/calculators/image/set_alpha_calculator.cc +++ b/mediapipe/calculators/image/set_alpha_calculator.cc @@ -25,11 +25,12 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/vector.h" -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #include 
"mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_simple_shaders.h" +#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/shader_util.h" -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS namespace mediapipe { @@ -98,18 +99,18 @@ class SetAlphaCalculator : public CalculatorBase { ::mediapipe::Status RenderGpu(CalculatorContext* cc); ::mediapipe::Status RenderCpu(CalculatorContext* cc); - ::mediapipe::Status GlRender(CalculatorContext* cc); ::mediapipe::Status GlSetup(CalculatorContext* cc); + void GlRender(CalculatorContext* cc); mediapipe::SetAlphaCalculatorOptions options_; float alpha_value_ = -1.f; bool use_gpu_ = false; bool gpu_initialized_ = false; -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) mediapipe::GlCalculatorHelper gpu_helper_; GLuint program_ = 0; -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS }; REGISTER_CALCULATOR(SetAlphaCalculator); @@ -126,52 +127,54 @@ REGISTER_CALCULATOR(SetAlphaCalculator); } // Input image to add/edit alpha channel. -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) if (cc->Inputs().HasTag(kInputFrameTagGpu)) { cc->Inputs().Tag(kInputFrameTagGpu).Set(); } -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS if (cc->Inputs().HasTag(kInputFrameTag)) { cc->Inputs().Tag(kInputFrameTag).Set(); } // Input alpha image mask (optional) -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) if (cc->Inputs().HasTag(kInputAlphaTagGpu)) { cc->Inputs().Tag(kInputAlphaTagGpu).Set(); } -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS if (cc->Inputs().HasTag(kInputAlphaTag)) { cc->Inputs().Tag(kInputAlphaTag).Set(); } // RGBA output image. -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) if (cc->Outputs().HasTag(kOutputFrameTagGpu)) { cc->Outputs().Tag(kOutputFrameTagGpu).Set(); } -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS if (cc->Outputs().HasTag(kOutputFrameTag)) { cc->Outputs().Tag(kOutputFrameTag).Set(); } -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS return ::mediapipe::OkStatus(); } ::mediapipe::Status SetAlphaCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + options_ = cc->Options(); if (cc->Inputs().HasTag(kInputFrameTagGpu) && cc->Outputs().HasTag(kOutputFrameTagGpu)) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) use_gpu_ = true; #else - RET_CHECK_FAIL() << "GPU processing on non-Android not supported yet."; -#endif // __ANDROID__ + RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; +#endif // __ANDROID__ or iOS } // Get global value from options (-1 if not set). 
@@ -184,7 +187,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator); RET_CHECK_FAIL() << "Must use either image mask or options alpha value."; if (use_gpu_) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) RETURN_IF_ERROR(gpu_helper_.Open(cc)); #endif } @@ -194,7 +197,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator); ::mediapipe::Status SetAlphaCalculator::Process(CalculatorContext* cc) { if (use_gpu_) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) RETURN_IF_ERROR( gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { if (!gpu_initialized_) { @@ -204,7 +207,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator); RETURN_IF_ERROR(RenderGpu(cc)); return ::mediapipe::OkStatus(); })); -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS } else { RETURN_IF_ERROR(RenderCpu(cc)); } @@ -213,12 +216,12 @@ REGISTER_CALCULATOR(SetAlphaCalculator); } ::mediapipe::Status SetAlphaCalculator::Close(CalculatorContext* cc) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) gpu_helper_.RunInGlContext([this] { if (program_) glDeleteProgram(program_); program_ = 0; }); -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS return ::mediapipe::OkStatus(); } @@ -292,7 +295,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator); if (cc->Inputs().Tag(kInputFrameTagGpu).IsEmpty()) { return ::mediapipe::OkStatus(); } -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) // Setup source texture. const auto& input_frame = cc->Inputs().Tag(kInputFrameTagGpu).Get(); @@ -345,13 +348,13 @@ REGISTER_CALCULATOR(SetAlphaCalculator); // Cleanup input_texture.Release(); output_texture.Release(); -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS return ::mediapipe::OkStatus(); } -::mediapipe::Status SetAlphaCalculator::GlRender(CalculatorContext* cc) { -#if defined(__ANDROID__) +void SetAlphaCalculator::GlRender(CalculatorContext* cc) { +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -400,13 +403,11 @@ REGISTER_CALCULATOR(SetAlphaCalculator); glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); -#endif // __ANDROID__ - - return ::mediapipe::OkStatus(); +#endif // __ANDROID__ or iOS } ::mediapipe::Status SetAlphaCalculator::GlSetup(CalculatorContext* cc) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, @@ -417,6 +418,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator); }; // Shader to overlay a texture onto another when overlay is non-zero. + // TODO split into two shaders to handle alpha_value<0 separately const GLchar* frag_src = GLES_VERSION_COMPAT R"( #if __VERSION__ < 130 @@ -458,7 +460,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator); glUniform1i(glGetUniformLocation(program_, "alpha_mask"), 2); glUniform1f(glGetUniformLocation(program_, "alpha_value"), alpha_value_); -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS return ::mediapipe::OkStatus(); } diff --git a/mediapipe/calculators/image/set_alpha_calculator.proto b/mediapipe/calculators/image/set_alpha_calculator.proto index 0e2bc9732..3f07ddb31 100644 --- a/mediapipe/calculators/image/set_alpha_calculator.proto +++ b/mediapipe/calculators/image/set_alpha_calculator.proto @@ -1,4 +1,17 @@ -// Options for SetAlphaCalculator +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + syntax = "proto2"; package mediapipe; diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index 66493aba6..f88481cf7 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -417,6 +417,9 @@ cc_library( "//mediapipe:android": [ "@org_tensorflow//tensorflow/core:android_tensorflow_lib_lite_nortti_lite_protos", ], + "//mediapipe:ios": [ + "@org_tensorflow//tensorflow/core:ios_tensorflow_lib", + ], }), alwayslink = 1, ) @@ -435,6 +438,9 @@ cc_library( "//mediapipe:android": [ "@org_tensorflow//tensorflow/core:android_tensorflow_lib_lite_nortti_lite_protos", ], + "//mediapipe:ios": [ + "@org_tensorflow//tensorflow/core:ios_tensorflow_lib", + ], }), ) @@ -459,6 +465,10 @@ cc_library( "@org_tensorflow//tensorflow/core:android_tensorflow_lib_lite_nortti_lite_protos", "//mediapipe/android/file/base", ], + "//mediapipe:ios": [ + "@org_tensorflow//tensorflow/core:ios_tensorflow_lib", + "//mediapipe/android/file/base", + ], }), alwayslink = 1, ) @@ -783,6 +793,7 @@ cc_test( "@org_tensorflow//tensorflow/core/kernels:io", "@org_tensorflow//tensorflow/core/kernels:state", "@org_tensorflow//tensorflow/core/kernels:string", + "@org_tensorflow//tensorflow/core/kernels/data:tensor_dataset_op", ], ) @@ -813,6 +824,7 @@ cc_test( "@org_tensorflow//tensorflow/core/kernels:io", "@org_tensorflow//tensorflow/core/kernels:state", "@org_tensorflow//tensorflow/core/kernels:string", + "@org_tensorflow//tensorflow/core/kernels/data:tensor_dataset_op", ], ) @@ -949,6 +961,9 @@ cc_test( "//mediapipe:android": [ "@org_tensorflow//tensorflow/core:android_tensorflow_lib_with_ops_lite_proto_no_rtti_lib", ], + "//mediapipe:ios": [ + "@org_tensorflow//tensorflow/core:ios_tensorflow_test_lib", + ], }), ) diff --git a/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator.cc b/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator.cc index 31243e133..ca704b793 100644 --- a/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator.cc @@ -118,7 +118,7 @@ REGISTER_CALCULATOR(MatrixToTensorCalculator); // Inform the framework that we always output at the same timestamp // as we receive a packet at. 
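+  // SetOffset(TimestampDiff(0)) declares a fixed input-to-output timestamp
+  // offset, letting the framework propagate timestamp bounds downstream
+  // without waiting for this calculator's output packet.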
- cc->SetOffset(mediapipe::TimestampDiff(0)); + cc->SetOffset(TimestampDiff(0)); return ::mediapipe::OkStatus(); } diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc index 9a6c7c97d..7780d7850 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc @@ -37,6 +37,7 @@ const char kImageTag[] = "IMAGE"; const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_"; const char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED"; const char kBBoxTag[] = "BBOX"; +const char kKeypointsTag[] = "KEYPOINTS"; const char kSegmentationMaskTag[] = "CLASS_SEGMENTATION"; namespace tf = ::tensorflow; @@ -54,16 +55,11 @@ namespace mpms = ::mediapipe::mediasequence; // stores the encoded optical flow from the same calculator, "BBOX" which stores // bounding boxes from vector, and streams with the // "FLOAT_FEATURE_${NAME}" pattern, which stores the values from vector's -// associated with the name ${NAME}. Audio streams (i.e. Matrix with a -// TimeSeriesHeader) are given extra packing and unpacking support and are named -// similar to floats with the pattern "AUDIO_${NAME}". "IMAGE_${NAME}" and -// "BBOX_${NAME}" will also store prefixed versions of each stream, which allows -// for multiple image streams to be included. However, the default names are -// suppored by more tools. "ENCODED_MEDIA" stores a video encoding for the clip -// directly. The last packet on this stream is stored, and can be unpacked with -// the metadata generator. Because the media decoder always starts with -// timestamp zero, the "ENCODED_MEDIA_START_TIMESTAMP" should be recorded as -// well. Use the FirstTimestampCalculator to determine this value. +// associated with the name ${NAME}. "KEYPOINTS" stores a map of 2D keypoints +// from unordered_map>>. "IMAGE_${NAME}", +// "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store prefixed versions of +// each stream, which allows for multiple image streams to be included. However, +// the default names are suppored by more tools. // // Example config: // node { @@ -122,6 +118,21 @@ class PackMediaSequenceCalculator : public CalculatorBase { } cc->Inputs().Tag(tag).Set(); } + if (absl::StartsWith(tag, kKeypointsTag)) { + std::string key = ""; + if (tag != kKeypointsTag) { + int tag_length = sizeof(kKeypointsTag) / sizeof(*kKeypointsTag) - 1; + if (tag[tag_length] == '_') { + key = tag.substr(tag_length + 1); + } else { + continue; // Skip keys that don't match "(kKeypointsTag)_?" 
+ } + } + cc->Inputs() + .Tag(tag) + .Set>>>(); + } if (absl::StartsWith(tag, kBBoxTag)) { std::string key = ""; if (tag != kBBoxTag) { @@ -169,8 +180,7 @@ class PackMediaSequenceCalculator : public CalculatorBase { features_present_[tag] = false; } - if (cc->Options() - .GetExtension(PackMediaSequenceCalculatorOptions::ext) + if (cc->Options() .replace_data_instead_of_append()) { for (const auto& tag : cc->Inputs().GetTags()) { if (absl::StartsWith(tag, kImageTag)) { @@ -186,13 +196,19 @@ class PackMediaSequenceCalculator : public CalculatorBase { mpms::ClearImageEncoded(key, sequence_.get()); mpms::ClearImageTimestamp(key, sequence_.get()); } - } - if (cc->Inputs().HasTag(kForwardFlowEncodedTag)) { - mpms::ClearForwardFlowEncoded(sequence_.get()); - mpms::ClearForwardFlowTimestamp(sequence_.get()); - } - - for (const auto& tag : cc->Inputs().GetTags()) { + if (absl::StartsWith(tag, kBBoxTag)) { + std::string key = ""; + if (tag != kBBoxTag) { + int tag_length = sizeof(kBBoxTag) / sizeof(*kBBoxTag) - 1; + if (tag[tag_length] == '_') { + key = tag.substr(tag_length + 1); + } else { + continue; // Skip keys that don't match "(kBBoxTag)_?" + } + } + mpms::ClearBBox(key, sequence_.get()); + mpms::ClearBBoxTimestamp(key, sequence_.get()); + } if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) { std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) / sizeof(*kFloatFeaturePrefixTag) - @@ -200,6 +216,16 @@ class PackMediaSequenceCalculator : public CalculatorBase { mpms::ClearFeatureFloats(key, sequence_.get()); mpms::ClearFeatureTimestamp(key, sequence_.get()); } + if (absl::StartsWith(tag, kKeypointsTag)) { + std::string key = + tag.substr(sizeof(kKeypointsTag) / sizeof(*kKeypointsTag) - 1); + mpms::ClearBBoxPoint(key, sequence_.get()); + mpms::ClearBBoxTimestamp(key, sequence_.get()); + } + } + if (cc->Inputs().HasTag(kForwardFlowEncodedTag)) { + mpms::ClearForwardFlowEncoded(sequence_.get()); + mpms::ClearForwardFlowTimestamp(sequence_.get()); } } @@ -228,11 +254,11 @@ class PackMediaSequenceCalculator : public CalculatorBase { } ::mediapipe::Status Close(CalculatorContext* cc) override { - auto& options = - cc->Options().GetExtension(PackMediaSequenceCalculatorOptions::ext); + auto& options = cc->Options(); if (options.reconcile_metadata()) { - RET_CHECK_OK(mpms::ReconcileMetadata(options.reconcile_bbox_annotations(), - sequence_.get())); + RET_CHECK_OK(mpms::ReconcileMetadata( + options.reconcile_bbox_annotations(), + options.reconcile_region_annotations(), sequence_.get())); } if (options.output_only_if_all_present()) { @@ -260,6 +286,9 @@ class PackMediaSequenceCalculator : public CalculatorBase { ::mediapipe::Status Process(CalculatorContext* cc) override { for (const auto& tag : cc->Inputs().GetTags()) { + if (!cc->Inputs().Tag(tag).IsEmpty()) { + features_present_[tag] = true; + } if (absl::StartsWith(tag, kImageTag) && !cc->Inputs().Tag(tag).IsEmpty()) { std::string key = ""; @@ -281,23 +310,29 @@ class PackMediaSequenceCalculator : public CalculatorBase { sequence_.get()); mpms::AddImageEncoded(key, image.encoded_image(), sequence_.get()); } - } - if (cc->Inputs().HasTag(kForwardFlowEncodedTag) && - !cc->Inputs().Tag(kForwardFlowEncodedTag).IsEmpty()) { - const OpenCvImageEncoderCalculatorResults& forward_flow = - cc->Inputs() - .Tag(kForwardFlowEncodedTag) - .Get(); - if (!forward_flow.has_encoded_image()) { - return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) - << "No encoded forward flow"; + if (absl::StartsWith(tag, kKeypointsTag) && + 
!cc->Inputs().Tag(tag).IsEmpty()) { + std::string key = ""; + if (tag != kKeypointsTag) { + int tag_length = sizeof(kKeypointsTag) / sizeof(*kKeypointsTag) - 1; + if (tag[tag_length] == '_') { + key = tag.substr(tag_length + 1); + } else { + continue; // Skip keys that don't match "(kKeypointsTag)_?" + } + } + const auto& keypoints = + cc->Inputs() + .Tag(tag) + .Get>>>(); + for (const auto& pair : keypoints) { + mpms::AddBBoxTimestamp(mpms::merge_prefix(key, pair.first), + cc->InputTimestamp().Value(), sequence_.get()); + mpms::AddBBoxPoint(mpms::merge_prefix(key, pair.first), pair.second, + sequence_.get()); + } + } - mpms::AddForwardFlowTimestamp(cc->InputTimestamp().Value(), - sequence_.get()); - mpms::AddForwardFlowEncoded(forward_flow.encoded_image(), - sequence_.get()); - } - for (const auto& tag : cc->Inputs().GetTags()) { if (absl::StartsWith(tag, kFloatFeaturePrefixTag) && !cc->Inputs().Tag(tag).IsEmpty()) { std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) / @@ -309,8 +344,6 @@ class PackMediaSequenceCalculator : public CalculatorBase { cc->Inputs().Tag(tag).Get>(), sequence_.get()); } - } - for (const auto& tag : cc->Inputs().GetTags()) { if (absl::StartsWith(tag, kBBoxTag) && !cc->Inputs().Tag(tag).IsEmpty()) { std::string key = ""; if (tag != kBBoxTag) { @@ -349,15 +382,30 @@ class PackMediaSequenceCalculator : public CalculatorBase { mpms::AddBBoxTimestamp(key, cc->InputTimestamp().Value(), sequence_.get()); if (!predicted_class_strings.empty()) { - mpms::AddBBoxClassString(key, predicted_class_strings, + mpms::AddBBoxLabelString(key, predicted_class_strings, sequence_.get()); } if (!predicted_label_ids.empty()) { - mpms::AddBBoxClassIndex(key, predicted_label_ids, sequence_.get()); + mpms::AddBBoxLabelIndex(key, predicted_label_ids, sequence_.get()); } } } } + if (cc->Inputs().HasTag(kForwardFlowEncodedTag) && + !cc->Inputs().Tag(kForwardFlowEncodedTag).IsEmpty()) { + const OpenCvImageEncoderCalculatorResults& forward_flow = + cc->Inputs() + .Tag(kForwardFlowEncodedTag) + .Get(); + if (!forward_flow.has_encoded_image()) { + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "No encoded forward flow"; + } + mpms::AddForwardFlowTimestamp(cc->InputTimestamp().Value(), + sequence_.get()); + mpms::AddForwardFlowEncoded(forward_flow.encoded_image(), + sequence_.get()); + } if (cc->Inputs().HasTag(kSegmentationMaskTag) && !cc->Inputs().Tag(kSegmentationMaskTag).IsEmpty()) { bool already_has_mask = false; @@ -387,11 +435,6 @@ class PackMediaSequenceCalculator : public CalculatorBase { } } } - for (const auto& tag : cc->Inputs().GetTags()) { - if (!cc->Inputs().Tag(tag).IsEmpty()) { - features_present_[tag] = true; - } - } return ::mediapipe::OkStatus(); } diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.proto b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.proto index 53a6f73c2..1c5c559ee 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.proto +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.proto @@ -32,18 +32,29 @@ message PackMediaSequenceCalculatorOptions { // (e.g. fills in the image height, width, and number of frames.) optional bool reconcile_metadata = 2 [default = true]; - // If true, updates the metadata for sequences with bounding boxes. This will - align each bounding box annotation with the nearest frame and insert empty - annotations as needed to satisfy the frame rate. + // If true, updates the metadata for bounding box portions of sequences.
This + // will align each bounding box annotation with the nearest frame and insert + // empty annotations as needed to satisfy the frame rate. // NOTE: IF YOU DOWNSAMPLE IN TIME YOU WILL LOSE ANNOTATIONS. // If two or more annotations are closest to the same frame, then only // the closest annotation is saved. This matches the behavior of // downsampling images in time. - optional bool reconcile_bbox_annotations = 5 [default = true]; + optional bool reconcile_bbox_annotations = 5 [default = false]; + + // If true, updates the metadata for all regions annotations, regardless of + // prefix, in the sequence. This will align each region annotation with the + // nearest frame and insert empty annotations as needed to satisfy the frame + // rate. This is particularly useful for key point annotations that are + // represented as region points. This does not exclude bboxes. + // NOTE: IF YOU DOWNSAMPLE IN TIME YOU WILL LOSE ANNOTATIONS. + // If two or more annotations are closest to the same frame, then only + // the closest annotation is saved. This matches the behavior of + // downsampling images in time. + optional bool reconcile_region_annotations = 6 [default = true]; // If true, the SequenceExample is output only if all input streams are // present. - optional bool output_only_if_all_present = 3 [default = false]; + optional bool output_only_if_all_present = 3 [default = true]; // If true, will remove all data from a sequence example for a corresponding // input stream. E.g. if images are already present and an IMAGE stream is diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc index d8931cfa8..033271528 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc @@ -348,15 +348,51 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBBoxDetections) { ASSERT_NEAR(1.0, rect.ymax(), 0.001); } auto class_strings = - mpms::GetPredictedBBoxClassStringAt(output_sequence, i); + mpms::GetPredictedBBoxLabelStringAt(output_sequence, i); ASSERT_EQ("absolute bbox", class_strings[0]); ASSERT_EQ("relative bbox", class_strings[1]); - auto class_indices = mpms::GetPredictedBBoxClassIndexAt(output_sequence, i); + auto class_indices = mpms::GetPredictedBBoxLabelIndexAt(output_sequence, i); ASSERT_EQ(0, class_indices[0]); ASSERT_EQ(1, class_indices[1]); } } +TEST_F(PackMediaSequenceCalculatorTest, PacksTwoKeypoints) { + SetUpCalculator({"KEYPOINTS_TEST:keypoints"}, {}, false, true); + auto input_sequence = ::absl::make_unique(); + std::string test_video_id = "test_video_id"; + mpms::SetClipMediaId(test_video_id, input_sequence.get()); + + std::unordered_map>> points = + {{"HEAD", {{0.1, 0.2}, {0.3, 0.4}}}, {"TAIL", {{0.5, 0.6}}}}; + runner_->MutableInputs() + ->Tag("KEYPOINTS_TEST") + .packets.push_back(PointToForeign(&points).At(Timestamp(0))); + runner_->MutableInputs() + ->Tag("KEYPOINTS_TEST") + .packets.push_back(PointToForeign(&points).At(Timestamp(1))); + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = + Adopt(input_sequence.release()); + + MEDIAPIPE_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence)); + ASSERT_EQ(2, 
mpms::GetBBoxPointSize("TEST/HEAD", output_sequence)); + ASSERT_EQ(2, mpms::GetBBoxPointSize("TEST/TAIL", output_sequence)); + ASSERT_NEAR(0.2, + mpms::GetBBoxPointAt("TEST/HEAD", output_sequence, 0)[0].second, + 0.001); + ASSERT_NEAR(0.5, + mpms::GetBBoxPointAt("TEST/TAIL", output_sequence, 1)[0].first, + 0.001); +} + TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) { SetUpCalculator({"CLASS_SEGMENTATION:detections"}, {}, false, true); auto input_sequence = ::absl::make_unique(); @@ -395,7 +431,6 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) { ASSERT_EQ(1, output_packets.size()); const tf::SequenceExample& output_sequence = output_packets[0].Get(); - LOG(INFO) << "output_sequence: \n" << output_sequence.DebugString(); ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence)); ASSERT_EQ(height, mpms::GetImageHeight(output_sequence)); @@ -602,6 +637,10 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) { mpms::AddBBoxTimestamp(21, input_sequence.get()); mpms::AddBBoxTimestamp(22, input_sequence.get()); + mpms::AddBBoxTimestamp("PREFIX", 8, input_sequence.get()); + mpms::AddBBoxTimestamp("PREFIX", 9, input_sequence.get()); + mpms::AddBBoxTimestamp("PREFIX", 22, input_sequence.get()); + runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = Adopt(input_sequence.release()); MEDIAPIPE_ASSERT_OK(runner_->Run()); @@ -617,6 +656,13 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) { ASSERT_EQ(mpms::GetBBoxTimestampAt(output_sequence, 2), 30); ASSERT_EQ(mpms::GetBBoxTimestampAt(output_sequence, 3), 40); ASSERT_EQ(mpms::GetBBoxTimestampAt(output_sequence, 4), 50); + + ASSERT_EQ(mpms::GetBBoxTimestampSize("PREFIX", output_sequence), 5); + ASSERT_EQ(mpms::GetBBoxTimestampAt("PREFIX", output_sequence, 0), 10); + ASSERT_EQ(mpms::GetBBoxTimestampAt("PREFIX", output_sequence, 1), 20); + ASSERT_EQ(mpms::GetBBoxTimestampAt("PREFIX", output_sequence, 2), 30); + ASSERT_EQ(mpms::GetBBoxTimestampAt("PREFIX", output_sequence, 3), 40); + ASSERT_EQ(mpms::GetBBoxTimestampAt("PREFIX", output_sequence, 4), 50); } } // namespace diff --git a/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc index b061fe7b3..9c7f3458c 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc @@ -183,7 +183,7 @@ REGISTER_CALCULATOR(TensorToMatrixCalculator); header_ = *input_header; cc->Outputs().Tag(kMatrix).SetHeader(Adopt(input_header.release())); } - cc->SetOffset(mediapipe::TimestampDiff(0)); + cc->SetOffset(TimestampDiff(0)); return ::mediapipe::OkStatus(); } diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc index 196f91e2e..51493d7b6 100644 --- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc @@ -232,8 +232,7 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { current_timestamp_index_ = 0; // Determine the data path and output it. - const auto& options = - cc->Options().GetExtension(UnpackMediaSequenceCalculatorOptions::ext); + const auto& options = cc->Options(); const auto& sequence = cc->InputSidePackets() .Tag(kSequenceExampleTag) .Get(); @@ -338,7 +337,6 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { ? 
? Timestamp::PostStream() : Timestamp(map_kv.second[i]); - LOG(INFO) << "key: " << map_kv.first; if (absl::StrContains(map_kv.first, mpms::GetImageTimestampKey())) { std::vector pieces = absl::StrSplit(map_kv.first, '/'); std::string feature_key = ""; diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.proto b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.proto index 7088ff076..e6e839645 100644 --- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.proto +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.proto @@ -40,7 +40,7 @@ message UnpackMediaSequenceCalculatorOptions { optional float extra_padding_from_media_decoder = 5 [default = 0.0]; // Stores the packet resampler settings for the graph. The most accurate - // proceedure for sampling a range of frames is to request a padded time range + // procedure for sampling a range of frames is to request a padded time range // from the MediaDecoderCalculator and then trim it down to the proper time // range with the PacketResamplerCalculator. optional PacketResamplerCalculatorOptions base_packet_resampler_options = 6; diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc index 0d582fb2e..d8492a9dc 100644 --- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc @@ -460,6 +460,8 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromExample) { } TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) { + // TODO: Support proto3 proto.Any in CalculatorOptions. + // TODO: Avoid proto2 extensions in "RESAMPLER_OPTIONS".
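+  // For context on the accessor change earlier in this patch: a calculator can read
+  // its options either through the proto2 extension or through MediaPipe's templated
+  // accessor. A minimal sketch of both patterns (illustrative only; the template
+  // argument assumed here is this calculator's own options proto):
+  //
+  //   // Old style: explicit proto2 extension lookup on the generic CalculatorOptions.
+  //   const auto& old_options =
+  //       cc->Options().GetExtension(UnpackMediaSequenceCalculatorOptions::ext);
+  //   // New style used by this patch: typed, templated accessor.
+  //   const auto& options = cc->Options<UnpackMediaSequenceCalculatorOptions>();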
CalculatorOptions options; options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext) ->set_padding_before_label(1); diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index 374f0f11e..36d246eef 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -61,6 +61,13 @@ proto_library( deps = ["//mediapipe/framework:calculator_proto"], ) +proto_library( + name = "tflite_tensors_to_landmarks_calculator_proto", + srcs = ["tflite_tensors_to_landmarks_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + mediapipe_cc_proto_library( name = "ssd_anchors_calculator_cc_proto", srcs = ["ssd_anchors_calculator.proto"], @@ -109,6 +116,14 @@ mediapipe_cc_proto_library( deps = [":tflite_tensors_to_detections_calculator_proto"], ) +mediapipe_cc_proto_library( + name = "tflite_tensors_to_landmarks_calculator_cc_proto", + srcs = ["tflite_tensors_to_landmarks_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":tflite_tensors_to_landmarks_calculator_proto"], +) + cc_library( name = "ssd_anchors_calculator", srcs = ["ssd_anchors_calculator.cc"], @@ -168,6 +183,21 @@ cc_test( cc_library( name = "tflite_inference_calculator", srcs = ["tflite_inference_calculator.cc"], + copts = select({ + "//mediapipe:ios": [ + "-std=c++11", + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + linkopts = select({ + "//mediapipe:ios": [ + "-framework CoreVideo", + "-framework MetalKit", + ], + "//conditions:default": [], + }), visibility = ["//visibility:public"], deps = [ ":tflite_inference_calculator_cc_proto", @@ -179,12 +209,17 @@ cc_library( "//mediapipe/framework/port:ret_check", ] + select({ "//mediapipe:android": [ - "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", - "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program", - "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gpu_buffer", "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", + ], + "//mediapipe:ios": [ + "//mediapipe/gpu:MPPMetalHelper", + "//mediapipe/objc:mediapipe_framework_ios", + "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", ], "//conditions:default": [], }), @@ -194,12 +229,28 @@ cc_library( cc_library( name = "tflite_converter_calculator", srcs = ["tflite_converter_calculator.cc"], + copts = select({ + "//mediapipe:ios": [ + "-std=c++11", + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + linkopts = select({ + "//mediapipe:ios": [ + "-framework CoreVideo", + "-framework MetalKit", + ], + "//conditions:default": [], + }), visibility = ["//visibility:public"], deps = [ ":tflite_converter_calculator_cc_proto", "//mediapipe/util:resource_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:matrix", "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", "//mediapipe/framework/tool:status_util", "//mediapipe/framework/port:status", @@ -208,13 +259,19 @@ cc_library( "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", 
] + select({ "//mediapipe:android": [ - "//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gpu_buffer", "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", ], + "//mediapipe:ios": [ + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:MPPMetalHelper", + "//mediapipe/objc:mediapipe_framework_ios", + "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", + ], "//conditions:default": [], }), alwayslink = 1, @@ -229,6 +286,9 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:opencv_imgcodecs", + "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:ret_check", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", @@ -279,6 +339,32 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "tflite_tensors_to_landmarks_calculator", + srcs = ["tflite_tensors_to_landmarks_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":tflite_tensors_to_landmarks_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/port:ret_check", + "@org_tensorflow//tensorflow/lite:framework", + ], + alwayslink = 1, +) + +cc_library( + name = "tflite_tensors_to_floats_calculator", + srcs = ["tflite_tensors_to_floats_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "@org_tensorflow//tensorflow/lite:framework", + ], + alwayslink = 1, +) + cc_test( name = "tflite_inference_calculator_test", srcs = ["tflite_inference_calculator_test.cc"], @@ -298,3 +384,24 @@ cc_test( "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", ], ) + +cc_test( + name = "tflite_converter_calculator_test", + srcs = ["tflite_converter_calculator_test.cc"], + deps = [ + ":tflite_converter_calculator", + ":tflite_converter_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:validate_type", + "@com_google_absl//absl/memory", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +) diff --git a/mediapipe/calculators/tflite/ssd_anchors_calculator.cc b/mediapipe/calculators/tflite/ssd_anchors_calculator.cc index 2dc60b990..086636245 100644 --- a/mediapipe/calculators/tflite/ssd_anchors_calculator.cc +++ b/mediapipe/calculators/tflite/ssd_anchors_calculator.cc @@ -73,6 +73,8 @@ class SsdAnchorsCalculator : public CalculatorBase { } ::mediapipe::Status Open(CalculatorContext* cc) override { + cc->SetOffset(TimestampDiff(0)); + const SsdAnchorsCalculatorOptions& options = cc->Options(); diff --git a/mediapipe/calculators/tflite/tflite_converter_calculator.cc b/mediapipe/calculators/tflite/tflite_converter_calculator.cc index 7ef0e246b..6a6306df2 100644 --- 
a/mediapipe/calculators/tflite/tflite_converter_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_converter_calculator.cc @@ -12,9 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include + #include "mediapipe/calculators/tflite/tflite_converter_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/util/resource_util.h" #include "tensorflow/lite/error_reporter.h" @@ -22,18 +27,41 @@ #if defined(__ANDROID__) #include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gpu_buffer.h" #include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" #include "tensorflow/lite/delegates/gpu/gl/gl_program.h" #include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h" -#endif // ANDROID +#endif // __ANDROID__ + +#if defined(__APPLE__) && !TARGET_OS_OSX // iOS +#import +#import +#import + +#import "mediapipe/gpu/MPPMetalHelper.h" +#include "mediapipe/gpu/gpu_buffer.h" +#include "tensorflow/lite/delegates/gpu/metal_delegate.h" +#endif // iOS + +#if defined(__ANDROID__) +typedef ::tflite::gpu::gl::GlBuffer GpuTensor; +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS +typedef id GpuTensor; +#endif namespace { constexpr int kWorkgroupSize = 8; // Block size for GPU shader. // Commonly used to compute the number of blocks to launch in a kernel. -int RoundUp(const int size, const int multiple) { - return (size + multiple - 1) / multiple; +int NumGroups(const int size, const int group_size) { // NOLINT + return (size + group_size - 1) / group_size; } + +typedef Eigen::Matrix + RowMajorMatrixXf; +typedef Eigen::Matrix + ColMajorMatrixXf; + } // namespace namespace mediapipe { @@ -43,30 +71,38 @@ using ::tflite::gpu::gl::GlBuffer; using ::tflite::gpu::gl::GlProgram; using ::tflite::gpu::gl::GlShader; struct GPUData { - int width; - int height; - int channels; - GlBuffer ssbo; + int elements = 1; + GlBuffer buffer; GlShader shader; GlProgram program; }; -#endif // ANDROID +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS +struct GPUData { + int elements = 1; + id buffer; + id pipeline_state; +}; +#endif -// Calculator for normalizing and converting an ImageFrame or GpuBuffer -// into a TfLiteTensor (float 32) or tflite::gpu::GlBuffer, respetively. +// Calculator for normalizing and converting an ImageFrame or Matrix +// into a TfLiteTensor (float 32) or a GpuBuffer to a tflite::gpu::GlBuffer. // // This calculator is designed to be used with the TfLiteInferenceCalcualtor, // as a pre-processing step for calculator inputs. // -// Input data is normalized to [-1,1] (default) or [0,1], specified by options. +// IMAGE and IMAGE_GPU inputs are normalized to [-1,1] (default) or [0,1], +// specified by options (unless outputting a quantized tensor). // // Input: +// One of the following tags: // IMAGE - ImageFrame (assumed to be 8-bit or 32-bit data). -// IMAGE_GPU - GpuBuffer (assumed to be RGBA or RGB GL texture) +// IMAGE_GPU - GpuBuffer (assumed to be RGBA or RGB GL texture). +// MATRIX - Matrix. // // Output: -// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32 -// TENSORS_GPU - vector of GlBuffer +// One of the following tags: +// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32, or kTfLiteUint8. 
+// TENSORS_GPU - vector of GlBuffer. // // Example use: // node { @@ -83,7 +119,7 @@ struct GPUData { // IMPORTANT Notes: // No conversion between CPU/GPU is done. // Inputs/outputs must match type: CPU->CPU or GPU->GPU. -// GPU tensors are currently only supported on Android. +// GPU tensors are currently only supported on mobile platforms. // This calculator uses FixedSizeInputStreamHandler by default. // class TfLiteConverterCalculator : public CalculatorBase { @@ -101,43 +137,62 @@ class TfLiteConverterCalculator : public CalculatorBase { ::mediapipe::Status NormalizeImage(const ImageFrame& image_frame, bool zero_center, bool flip_vertically, float* tensor_buffer); + ::mediapipe::Status CopyMatrixToTensor(const Matrix& matrix, + float* tensor_buffer); + ::mediapipe::Status ProcessCPU(CalculatorContext* cc); + ::mediapipe::Status ProcessGPU(CalculatorContext* cc); std::unique_ptr interpreter_ = nullptr; #if defined(__ANDROID__) mediapipe::GlCalculatorHelper gpu_helper_; std::unique_ptr gpu_data_out_; +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS + MPPMetalHelper* gpu_helper_ = nullptr; + std::unique_ptr gpu_data_out_; #endif bool initialized_ = false; bool use_gpu_ = false; bool zero_center_ = true; // normalize range to [-1,1] | otherwise [0,1] bool flip_vertically_ = false; + bool row_major_matrix_ = false; + bool use_quantized_tensors_ = false; int max_num_channels_ = 3; }; REGISTER_CALCULATOR(TfLiteConverterCalculator); ::mediapipe::Status TfLiteConverterCalculator::GetContract( CalculatorContract* cc) { - RET_CHECK(cc->Inputs().HasTag("IMAGE") || cc->Inputs().HasTag("IMAGE_GPU")); - RET_CHECK(cc->Outputs().HasTag("TENSORS") || + const bool has_image_tag = cc->Inputs().HasTag("IMAGE"); + const bool has_image_gpu_tag = cc->Inputs().HasTag("IMAGE_GPU"); + const bool has_matrix_tag = cc->Inputs().HasTag("MATRIX"); + // Confirm only one of the input streams is present. + RET_CHECK(has_image_tag ^ has_image_gpu_tag ^ has_matrix_tag && + !(has_image_tag && has_image_gpu_tag && has_matrix_tag)); + + // Confirm only one of the output streams is present. + RET_CHECK(cc->Outputs().HasTag("TENSORS") ^ cc->Outputs().HasTag("TENSORS_GPU")); if (cc->Inputs().HasTag("IMAGE")) cc->Inputs().Tag("IMAGE").Set(); -#if defined(__ANDROID__) + if (cc->Inputs().HasTag("MATRIX")) cc->Inputs().Tag("MATRIX").Set(); +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) if (cc->Inputs().HasTag("IMAGE_GPU")) cc->Inputs().Tag("IMAGE_GPU").Set(); #endif if (cc->Outputs().HasTag("TENSORS")) cc->Outputs().Tag("TENSORS").Set>(); -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) if (cc->Outputs().HasTag("TENSORS_GPU")) - cc->Outputs().Tag("TENSORS_GPU").Set>(); + cc->Outputs().Tag("TENSORS_GPU").Set>(); #endif #if defined(__ANDROID__) RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS + RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); #endif // Assign this calculator's default InputStreamHandler. 
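The input check added to GetContract above combines XOR with a negated conjunction: for three booleans, a ^ b ^ c holds when exactly one or all three are set, and !(a && b && c) rules out the all-three case, so together they accept exactly one of IMAGE, IMAGE_GPU, and MATRIX. A small self-contained sketch of the same predicate (illustrative only, not part of the patch):

    #include <cassert>

    // True iff exactly one of the three flags is set.
    bool ExactlyOneOfThree(bool a, bool b, bool c) {
      return (a ^ b ^ c) && !(a && b && c);
    }

    int main() {
      assert(ExactlyOneOfThree(true, false, false));    // one input: accepted
      assert(!ExactlyOneOfThree(true, true, false));     // two inputs: rejected
      assert(!ExactlyOneOfThree(true, true, true));      // three inputs: rejected
      assert(!ExactlyOneOfThree(false, false, false));   // no input: rejected
      return 0;
    }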
@@ -147,14 +202,16 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator); } ::mediapipe::Status TfLiteConverterCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + RETURN_IF_ERROR(LoadOptions(cc)); if (cc->Inputs().HasTag("IMAGE_GPU") || cc->Outputs().HasTag("IMAGE_OUT_GPU")) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) use_gpu_ = true; #else - RET_CHECK_FAIL() << "GPU processing on non-Android not supported yet."; + RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; #endif } @@ -162,8 +219,13 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator); // Cannot mix CPU/GPU streams. RET_CHECK(cc->Inputs().HasTag("IMAGE_GPU") && cc->Outputs().HasTag("TENSORS_GPU")); + // Cannot use quantization. + use_quantized_tensors_ = false; #if defined(__ANDROID__) RETURN_IF_ERROR(gpu_helper_.Open(cc)); +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS + gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; + RET_CHECK(gpu_helper_); #endif } else { interpreter_ = absl::make_unique(); @@ -176,77 +238,59 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator); ::mediapipe::Status TfLiteConverterCalculator::Process(CalculatorContext* cc) { if (use_gpu_) { - // GpuBuffer to tflite::gpu::GlBuffer conversion. -#if defined(__ANDROID__) if (!initialized_) { RETURN_IF_ERROR(InitGpu(cc)); initialized_ = true; } - - const auto& input = - cc->Inputs().Tag("IMAGE_GPU").Get(); - RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this, &input]() -> ::mediapipe::Status { - // Convert GL texture into TfLite GlBuffer (SSBO). - auto src = gpu_helper_.CreateSourceTexture(input); - glActiveTexture(GL_TEXTURE0 + 0); - glBindTexture(GL_TEXTURE_2D, src.name()); - auto status = gpu_data_out_->ssbo.BindToIndex(1); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } - const tflite::gpu::uint3 workgroups = { - RoundUp(gpu_data_out_->width, kWorkgroupSize), - RoundUp(gpu_data_out_->height, kWorkgroupSize), 1}; - status = gpu_data_out_->program.Dispatch(workgroups); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } - glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); - glBindTexture(GL_TEXTURE_2D, 0); - src.Release(); - return ::mediapipe::OkStatus(); - })); - - auto output_tensors = absl::make_unique>(); - output_tensors->resize(1); - for (int i = 0; i < 1; ++i) { - GlBuffer& tensor = output_tensors->at(i); - using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; - auto status = CreateReadWriteShaderStorageBuffer( - gpu_data_out_->width * gpu_data_out_->height * - gpu_data_out_->channels, - &tensor); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } - tflite::gpu::gl::CopyBuffer(gpu_data_out_->ssbo, tensor); - } - cc->Outputs() - .Tag("TENSORS_GPU") - .Add(output_tensors.release(), cc->InputTimestamp()); -#else - RET_CHECK_FAIL() - << "GPU input on non-Android devices is not supported yet."; -#endif + // Convert to GPU tensors type. + RETURN_IF_ERROR(ProcessGPU(cc)); } else { + // Convert to CPU tensors or Matrix type. 
+ RETURN_IF_ERROR(ProcessCPU(cc)); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) { +#if defined(__ANDROID__) + gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); }); +#endif +#if defined(__APPLE__) && !TARGET_OS_OSX // iOS + gpu_data_out_.reset(); +#endif + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteConverterCalculator::ProcessCPU( + CalculatorContext* cc) { + if (cc->Inputs().HasTag("IMAGE")) { // CPU ImageFrame to TfLiteTensor conversion. const auto& image_frame = cc->Inputs().Tag("IMAGE").Get(); const int height = image_frame.Height(); const int width = image_frame.Width(); - const int channels_preserved = - std::min(image_frame.NumberOfChannels(), max_num_channels_); - - if (!(image_frame.Format() == mediapipe::ImageFormat::SRGBA || - image_frame.Format() == mediapipe::ImageFormat::SRGB || - image_frame.Format() == mediapipe::ImageFormat::GRAY8 || - image_frame.Format() == mediapipe::ImageFormat::VEC32F1)) - RET_CHECK_FAIL() << "Unsupported CPU input format."; + const int channels = image_frame.NumberOfChannels(); + const int channels_preserved = std::min(channels, max_num_channels_); if (!initialized_) { - interpreter_->SetTensorParametersReadWrite( - 0, kTfLiteFloat32, "", {channels_preserved}, TfLiteQuantization()); + if (!(image_frame.Format() == mediapipe::ImageFormat::SRGBA || + image_frame.Format() == mediapipe::ImageFormat::SRGB || + image_frame.Format() == mediapipe::ImageFormat::GRAY8 || + image_frame.Format() == mediapipe::ImageFormat::VEC32F1)) + RET_CHECK_FAIL() << "Unsupported CPU input format."; + TfLiteQuantization quant; + if (use_quantized_tensors_) { + RET_CHECK(image_frame.Format() != mediapipe::ImageFormat::VEC32F1) + << "Only 8-bit input images are supported for quantization."; + // Optional: Set 'quant' quantization params here if needed. + interpreter_->SetTensorParametersReadWrite(0, kTfLiteUInt8, "", + {channels_preserved}, quant); + } else { + // Default TfLiteQuantization used for no quantization. + interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", + {channels_preserved}, quant); + } initialized_ = true; } @@ -256,19 +300,66 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator); {height, width, channels_preserved}); interpreter_->AllocateTensors(); + // Copy image data into tensor. 
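+      // (In the quantized case below, pixel bytes are copied directly, keeping only the
+      // first channels_preserved channels of each pixel and skipping any per-row padding;
+      // float tensors instead go through NormalizeImage.)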
+ if (use_quantized_tensors_) { + const int width_padding = + image_frame.WidthStep() / image_frame.ByteDepth() - width * channels; + const uint8* image_buffer = + reinterpret_cast(image_frame.PixelData()); + uint8* tensor_buffer = tensor->data.uint8; + RET_CHECK(tensor_buffer); + for (int row = 0; row < height; ++row) { + for (int col = 0; col < width; ++col) { + for (int channel = 0; channel < channels_preserved; ++channel) { + *tensor_buffer++ = image_buffer[channel]; + } + image_buffer += channels; + } + image_buffer += width_padding; + } + } else { + float* tensor_buffer = tensor->data.f; + RET_CHECK(tensor_buffer); + if (image_frame.ByteDepth() == 1) { + RETURN_IF_ERROR(NormalizeImage(image_frame, zero_center_, + flip_vertically_, tensor_buffer)); + } else if (image_frame.ByteDepth() == 4) { + RETURN_IF_ERROR(NormalizeImage(image_frame, zero_center_, + flip_vertically_, tensor_buffer)); + } else { + return ::mediapipe::InternalError( + "Only byte-based (8 bit) and float (32 bit) images supported."); + } + } + + auto output_tensors = absl::make_unique>(); + output_tensors->emplace_back(*tensor); + cc->Outputs().Tag("TENSORS").Add(output_tensors.release(), + cc->InputTimestamp()); + } else if (cc->Inputs().HasTag("MATRIX")) { + // CPU Matrix to TfLiteTensor conversion. + + const auto& matrix = cc->Inputs().Tag("MATRIX").Get(); + const int height = matrix.rows(); + const int width = matrix.cols(); + const int channels = 1; + + if (!initialized_) { + interpreter_->SetTensorParametersReadWrite( + /*tensor_index=*/0, /*type=*/kTfLiteFloat32, /*name=*/"", + /*dims=*/{channels}, /*quantization=*/TfLiteQuantization()); + initialized_ = true; + } + + const int tensor_idx = interpreter_->inputs()[0]; + TfLiteTensor* tensor = interpreter_->tensor(tensor_idx); + interpreter_->ResizeInputTensor(tensor_idx, {height, width, channels}); + interpreter_->AllocateTensors(); + float* tensor_buffer = tensor->data.f; RET_CHECK(tensor_buffer); - if (image_frame.ByteDepth() == 1) { - RETURN_IF_ERROR(NormalizeImage(image_frame, zero_center_, - flip_vertically_, tensor_buffer)); - } else if (image_frame.ByteDepth() == 4) { - RETURN_IF_ERROR(NormalizeImage(image_frame, zero_center_, - flip_vertically_, tensor_buffer)); - } else { - return ::mediapipe::InternalError( - "Only byte-based (8 bit) and float (32 bit) images supported."); - } + RETURN_IF_ERROR(CopyMatrixToTensor(matrix, tensor_buffer)); auto output_tensors = absl::make_unique>(); output_tensors->emplace_back(*tensor); @@ -279,43 +370,132 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator); return ::mediapipe::OkStatus(); } -::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) { +::mediapipe::Status TfLiteConverterCalculator::ProcessGPU( + CalculatorContext* cc) { #if defined(__ANDROID__) - gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); }); -#endif // __ANDROID__ + // GpuBuffer to tflite::gpu::GlBuffer conversion. + const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get(); + RETURN_IF_ERROR( + gpu_helper_.RunInGlContext([this, &input]() -> ::mediapipe::Status { + // Convert GL texture into TfLite GlBuffer (SSBO). 
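+          // (The compute shader created in InitGpu samples this texture and writes the
+          // optionally zero-centered pixel values into the SSBO bound at index 1.)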
+ auto src = gpu_helper_.CreateSourceTexture(input); + glActiveTexture(GL_TEXTURE0 + 0); + glBindTexture(GL_TEXTURE_2D, src.name()); + auto status = gpu_data_out_->buffer.BindToIndex(1); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + const tflite::gpu::uint3 workgroups = { + NumGroups(input.width(), kWorkgroupSize), + NumGroups(input.height(), kWorkgroupSize), 1}; + status = gpu_data_out_->program.Dispatch(workgroups); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); + glBindTexture(GL_TEXTURE_2D, 0); + src.Release(); + return ::mediapipe::OkStatus(); + })); + + // Copy into outputs. + auto output_tensors = absl::make_unique>(); + output_tensors->resize(1); + { + GlBuffer& tensor = output_tensors->at(0); + using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; + auto status = CreateReadWriteShaderStorageBuffer( + gpu_data_out_->elements, &tensor); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + tflite::gpu::gl::CopyBuffer(gpu_data_out_->buffer, tensor); + } + cc->Outputs() + .Tag("TENSORS_GPU") + .Add(output_tensors.release(), cc->InputTimestamp()); +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS + // GpuBuffer to id conversion. + const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get(); + { + id src_texture = [gpu_helper_ metalTextureWithGpuBuffer:input]; + id command_buffer = [gpu_helper_ commandBuffer]; + command_buffer.label = @"TfLiteConverterCalculatorConvert"; + id compute_encoder = + [command_buffer computeCommandEncoder]; + [compute_encoder setComputePipelineState:gpu_data_out_->pipeline_state]; + [compute_encoder setTexture:src_texture atIndex:0]; + [compute_encoder setBuffer:gpu_data_out_->buffer offset:0 atIndex:1]; + MTLSize threads_per_group = MTLSizeMake(kWorkgroupSize, kWorkgroupSize, 1); + MTLSize threadgroups = + MTLSizeMake(NumGroups(input.width(), kWorkgroupSize), + NumGroups(input.height(), kWorkgroupSize), 1); + [compute_encoder dispatchThreadgroups:threadgroups + threadsPerThreadgroup:threads_per_group]; + [compute_encoder endEncoding]; + [command_buffer commit]; + [command_buffer waitUntilCompleted]; + } + + // Copy into outputs. + auto output_tensors = absl::make_unique>(); + { + id device = gpu_helper_.mtlDevice; + id command_buffer = [gpu_helper_ commandBuffer]; + command_buffer.label = @"TfLiteConverterCalculatorCopy"; + id tensor = + [device newBufferWithLength:gpu_data_out_->elements * sizeof(float) + options:MTLResourceStorageModeShared]; + id blit_command = + [command_buffer blitCommandEncoder]; + [blit_command copyFromBuffer:gpu_data_out_->buffer + sourceOffset:0 + toBuffer:tensor + destinationOffset:0 + size:gpu_data_out_->elements * sizeof(float)]; + [blit_command endEncoding]; + [command_buffer commit]; + [command_buffer waitUntilCompleted]; + + output_tensors->push_back(tensor); + } + + cc->Outputs() + .Tag("TENSORS_GPU") + .Add(output_tensors.release(), cc->InputTimestamp()); +#else + RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; +#endif + return ::mediapipe::OkStatus(); } ::mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) { -#if defined(__ANDROID__) - // Get input image sizes. +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) + // Configure inputs. 
const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get(); - mediapipe::ImageFormat::Format format = mediapipe::ImageFormatForGpuBufferFormat(input.format()); - gpu_data_out_ = absl::make_unique(); - gpu_data_out_->height = input.height(); - gpu_data_out_->width = input.width(); - gpu_data_out_->channels = max_num_channels_; // desired output channels - + gpu_data_out_->elements = input.height() * input.width() * max_num_channels_; const bool include_alpha = (max_num_channels_ == 4); - if (!(format == mediapipe::ImageFormat::SRGB || format == mediapipe::ImageFormat::SRGBA)) RET_CHECK_FAIL() << "Unsupported GPU input format."; - if (include_alpha && (format != mediapipe::ImageFormat::SRGBA)) RET_CHECK_FAIL() << "Num input channels is less than desired output."; +#endif - // Shader to convert GL Texture to Shader Storage Buffer Object (SSBO), - // with normalization to either: [0,1] or [-1,1]. +#if defined(__ANDROID__) + // Device memory. auto status = ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer( - gpu_data_out_->width * gpu_data_out_->height * gpu_data_out_->channels, - &gpu_data_out_->ssbo); + gpu_data_out_->elements, &gpu_data_out_->buffer); if (!status.ok()) { return ::mediapipe::InternalError(status.error_message()); } + + // Shader to convert GL Texture to Shader Storage Buffer Object (SSBO), + // with normalization to either: [0,1] or [-1,1]. const std::string shader_source = absl::Substitute( R"( #version 310 es layout(local_size_x = $0, local_size_y = $0) in; @@ -333,8 +513,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator); output_data.elements[linear_index + 2] = pixel.z; $6 // alpha channel })", - /*$0=*/kWorkgroupSize, /*$1=*/gpu_data_out_->width, - /*$2=*/gpu_data_out_->height, + /*$0=*/kWorkgroupSize, /*$1=*/input.width(), /*$2=*/input.height(), /*$3=*/zero_center_ ? "pixel = (pixel - 0.5) * 2.0;" : "", /*$4=*/flip_vertically_ ? "(width_height.y - 1 - gid.y)" : "gid.y", /*$5=*/ @@ -353,7 +532,65 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator); if (!status.ok()) { return ::mediapipe::InternalError(status.error_message()); } -#endif // ANDROID +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS + RET_CHECK(include_alpha) + << "iOS GPU inference currently accepts only RGBA input."; + + // Device memory. + id device = gpu_helper_.mtlDevice; + gpu_data_out_->buffer = + [device newBufferWithLength:gpu_data_out_->elements * sizeof(float) + options:MTLResourceStorageModeShared]; + + // Shader to convert GL Texture to Metal Buffer, + // with normalization to either: [0,1] or [-1,1]. + const std::string shader_source = absl::Substitute( + R"( + #include + + #include + + using namespace metal; + + kernel void convertKernel( + texture2d in_tex [[ texture(0) ]], + device float* out_buf [[ buffer(1) ]], + uint2 gid [[ thread_position_in_grid ]]) { + if (gid.x >= in_tex.get_width() || gid.y >= in_tex.get_height()) return; + constexpr sampler texture_sampler(coord::pixel, address::clamp_to_edge); + const float2 coord = float2(gid.x, gid.y); + $0 pixel = $0(in_tex.sample(texture_sampler, coord).$1); + $2 // normalize [-1,1] + const int linear_index = $4 * ($3 * in_tex.get_width() + gid.x); + out_buf[linear_index + 0] = pixel.x; + out_buf[linear_index + 1] = pixel.y; + out_buf[linear_index + 2] = pixel.z; + $5 // alpha channel + } + )", + /*$0=*/include_alpha ? "float4" : "float3", + /*$1=*/include_alpha ? "rgba" : "rgb", + /*$2=*/zero_center_ ? "pixel = (pixel - 0.5) * 2.0;" : "", + /*$3=*/flip_vertically_ ? "(in_tex.get_height() - 1 - gid.y)" : "gid.y", + /*$4=*/include_alpha ? 
4 : 3, + /*$5=*/include_alpha ? "out_buf[linear_index + 3] = pixel.w;" : ""); + + NSString* library_source = + [NSString stringWithUTF8String:shader_source.c_str()]; + NSError* error = nil; + id library = + [device newLibraryWithSource:library_source options:nullptr error:&error]; + RET_CHECK(library != nil) << "Couldn't create shader library " + << [[error localizedDescription] UTF8String]; + id kernel_func = nil; + kernel_func = [library newFunctionWithName:@"convertKernel"]; + RET_CHECK(kernel_func != nil) << "Couldn't create kernel function."; + gpu_data_out_->pipeline_state = + [device newComputePipelineStateWithFunction:kernel_func error:&error]; + RET_CHECK(gpu_data_out_->pipeline_state != nil) + << "Couldn't create pipeline state " + << [[error localizedDescription] UTF8String]; +#endif return ::mediapipe::OkStatus(); } @@ -370,11 +607,23 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator); // Get y-flip mode. flip_vertically_ = options.flip_vertically(); + // Get row_major_matrix mode. + row_major_matrix_ = options.row_major_matrix(); + // Get desired way to handle input channels. max_num_channels_ = options.max_num_channels(); // Currently only alpha channel toggling is suppored. CHECK_GE(max_num_channels_, 3); CHECK_LE(max_num_channels_, 4); +#if defined(__APPLE__) && !TARGET_OS_OSX // iOS + if (cc->Inputs().HasTag("IMAGE_GPU")) + // Currently on iOS, tflite gpu input tensor must be 4 channels, + // so input image must be 4 channels also (checked in InitGpu). + max_num_channels_ = 4; +#endif + + // Get tensor type, float or quantized. + use_quantized_tensors_ = options.use_quantized_tensors(); return ::mediapipe::OkStatus(); } @@ -415,4 +664,19 @@ template return ::mediapipe::OkStatus(); } +::mediapipe::Status TfLiteConverterCalculator::CopyMatrixToTensor( + const Matrix& matrix, float* tensor_buffer) { + if (row_major_matrix_) { + auto matrix_map = Eigen::Map(tensor_buffer, matrix.rows(), + matrix.cols()); + matrix_map = matrix; + } else { + auto matrix_map = Eigen::Map(tensor_buffer, matrix.rows(), + matrix.cols()); + matrix_map = matrix; + } + + return ::mediapipe::OkStatus(); +} + } // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_converter_calculator.proto b/mediapipe/calculators/tflite/tflite_converter_calculator.proto index f4c931c11..3be32b347 100644 --- a/mediapipe/calculators/tflite/tflite_converter_calculator.proto +++ b/mediapipe/calculators/tflite/tflite_converter_calculator.proto @@ -22,9 +22,10 @@ message TfLiteConverterCalculatorOptions { optional TfLiteConverterCalculatorOptions ext = 245817797; } - // Choose normalization mode for output: + // Choose normalization mode for output (not applied for Matrix inputs). // true = [-1,1] // false = [0,1] + // Ignored if using quantization. optional bool zero_center = 1 [default = true]; // Whether the input image should be flipped vertically (along the @@ -38,4 +39,12 @@ message TfLiteConverterCalculatorOptions { // tensor. Currently this only controls whether or not to ignore alpha // channel, so it must be 3 or 4. optional int32 max_num_channels = 3 [default = 3]; + + // The calculator expects Matrix inputs to be in column-major order. Set + // row_major_matrix to true if the inputs are in row-major order. + optional bool row_major_matrix = 4 [default = false]; + + // Quantization option (CPU only). + // When true, output kTfLiteUInt8 tensor instead of kTfLiteFloat32. 
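+  // 8-bit input images are then copied as-is, so the zero_center option above does not apply.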
+ optional bool use_quantized_tensors = 5 [default = false]; } diff --git a/mediapipe/calculators/tflite/tflite_converter_calculator_test.cc b/mediapipe/calculators/tflite/tflite_converter_calculator_test.cc new file mode 100644 index 000000000..3360586a2 --- /dev/null +++ b/mediapipe/calculators/tflite/tflite_converter_calculator_test.cc @@ -0,0 +1,199 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/memory/memory.h" +#include "mediapipe/calculators/tflite/tflite_converter_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" // NOLINT +#include "mediapipe/framework/tool/validate_type.h" +#include "tensorflow/lite/interpreter.h" + +namespace mediapipe { + +namespace { + +constexpr char kTransposeOptionsString[] = + "[mediapipe.TfLiteConverterCalculatorOptions.ext]: {" + "row_major_matrix: True}"; + +} // namespace + +using RandomEngine = std::mt19937_64; +const uint32 kSeed = 1234; +const int kNumSizes = 8; +const int sizes[kNumSizes][2] = {{1, 1}, {12, 1}, {1, 9}, {2, 2}, + {5, 3}, {7, 13}, {16, 32}, {101, 2}}; + +class TfLiteConverterCalculatorTest : public ::testing::Test { + protected: + // Adds a packet with a matrix filled with random values in [0,1]. + void AddRandomMatrix(int num_rows, int num_columns, uint32 seed, + bool row_major_matrix = false) { + RandomEngine random(kSeed); + std::uniform_real_distribution<> uniform_dist(0, 1.0); + auto matrix = ::absl::make_unique(); + matrix->resize(num_rows, num_columns); + if (row_major_matrix) { + for (int y = 0; y < num_rows; ++y) { + for (int x = 0; x < num_columns; ++x) { + float value = uniform_dist(random); + (*matrix)(y, x) = value; + } + } + } else { + for (int x = 0; x < num_columns; ++x) { + for (int y = 0; y < num_rows; ++y) { + float value = uniform_dist(random); + (*matrix)(y, x) = value; + } + } + } + MEDIAPIPE_ASSERT_OK(graph_->AddPacketToInputStream( + "matrix", Adopt(matrix.release()).At(Timestamp(0)))); + } + + std::unique_ptr graph_; +}; + +TEST_F(TfLiteConverterCalculatorTest, RandomMatrixColMajor) { + for (int size_index = 0; size_index < kNumSizes; ++size_index) { + const int num_rows = sizes[size_index][0]; + const int num_columns = sizes[size_index][1]; + + // Run the calculator and verify that one output is generated. 
+ CalculatorGraphConfig graph_config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "matrix" + node { + calculator: "TfLiteConverterCalculator" + input_stream: "MATRIX:matrix" + output_stream: "TENSORS:tensor" + options { + [mediapipe.TfLiteConverterCalculatorOptions.ext] { + row_major_matrix: false + } + } + } + )"); + std::vector output_packets; + tool::AddVectorSink("tensor", &graph_config, &output_packets); + + // Run the graph. + graph_ = absl::make_unique(); + MEDIAPIPE_ASSERT_OK(graph_->Initialize(graph_config)); + MEDIAPIPE_ASSERT_OK(graph_->StartRun({})); + + // Push the tensor into the graph. + AddRandomMatrix(num_rows, num_columns, kSeed, /*row_major_matrix=*/false); + + // Wait until the calculator done processing. + MEDIAPIPE_ASSERT_OK(graph_->WaitUntilIdle()); + EXPECT_EQ(1, output_packets.size()); + + // Get and process results. + const std::vector& tensor_vec = + output_packets[0].Get>(); + EXPECT_EQ(1, tensor_vec.size()); + + const TfLiteTensor* tensor = &tensor_vec[0]; + EXPECT_EQ(kTfLiteFloat32, tensor->type); + + // Verify that the data is correct. + RandomEngine random(kSeed); + std::uniform_real_distribution<> uniform_dist(0, 1.0); + const float* tensor_buffer = tensor->data.f; + for (int i = 0; i < num_rows * num_columns; ++i) { + const float expected = uniform_dist(random); + EXPECT_EQ(expected, tensor_buffer[i]) << "at i = " << i; + } + + // Fully close graph at end, otherwise calculator+tensors are destroyed + // after calling WaitUntilDone(). + MEDIAPIPE_ASSERT_OK(graph_->CloseInputStream("matrix")); + MEDIAPIPE_ASSERT_OK(graph_->WaitUntilDone()); + + graph_.reset(); + } +} + +TEST_F(TfLiteConverterCalculatorTest, RandomMatrixRowMajor) { + for (int size_index = 0; size_index < kNumSizes; ++size_index) { + const int num_rows = sizes[size_index][0]; + const int num_columns = sizes[size_index][1]; + + // Run the calculator and verify that one output is generated. + CalculatorGraphConfig graph_config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "matrix" + node { + calculator: "TfLiteConverterCalculator" + input_stream: "MATRIX:matrix" + output_stream: "TENSORS:tensor" + options { + [mediapipe.TfLiteConverterCalculatorOptions.ext] { + row_major_matrix: true + } + } + } + )"); + std::vector output_packets; + tool::AddVectorSink("tensor", &graph_config, &output_packets); + + // Run the graph. + graph_ = absl::make_unique(); + MEDIAPIPE_ASSERT_OK(graph_->Initialize(graph_config)); + MEDIAPIPE_ASSERT_OK(graph_->StartRun({})); + + // Push the tensor into the graph. + AddRandomMatrix(num_rows, num_columns, kSeed, /*row_major_matrix=*/true); + + // Wait until the calculator done processing. + MEDIAPIPE_ASSERT_OK(graph_->WaitUntilIdle()); + EXPECT_EQ(1, output_packets.size()); + + // Get and process results. + const std::vector& tensor_vec = + output_packets[0].Get>(); + EXPECT_EQ(1, tensor_vec.size()); + + const TfLiteTensor* tensor = &tensor_vec[0]; + EXPECT_EQ(kTfLiteFloat32, tensor->type); + + // Verify that the data is correct. + RandomEngine random(kSeed); + std::uniform_real_distribution<> uniform_dist(0, 1.0); + const float* tensor_buffer = tensor->data.f; + for (int i = 0; i < num_rows * num_columns; ++i) { + const float expected = uniform_dist(random); + EXPECT_EQ(expected, tensor_buffer[i]) << "at i = " << i; + } + + // Fully close graph at end, otherwise calculator+tensors are destroyed + // after calling WaitUntilDone(). 
+ MEDIAPIPE_ASSERT_OK(graph_->CloseInputStream("matrix")); + MEDIAPIPE_ASSERT_OK(graph_->WaitUntilDone()); + + graph_.reset(); + } +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc b/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc index 4628062e7..bce1b6076 100644 --- a/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc @@ -47,6 +47,8 @@ class TfLiteCustomOpResolverCalculator : public CalculatorBase { } ::mediapipe::Status Open(CalculatorContext* cc) override { + cc->SetOffset(TimestampDiff(0)); + const TfLiteCustomOpResolverCalculatorOptions& options = cc->Options(); diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.cc b/mediapipe/calculators/tflite/tflite_inference_calculator.cc index 392a6c853..33a24993b 100644 --- a/mediapipe/calculators/tflite/tflite_inference_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_inference_calculator.cc @@ -11,7 +11,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// #include #include @@ -32,17 +31,22 @@ #include "tensorflow/lite/delegates/gpu/gl/gl_program.h" #include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h" -#endif // ANDROID +#endif // __ANDROID__ #if defined(__APPLE__) && !TARGET_OS_OSX // iOS -#if defined(__OBJC__) #import +#import #import -#endif // OBJC -#import "mediapipe/framework/ios/NSError+util_status.h" -#import "mediapipe/gpu/MediaPipeMetalHelper.h" + +#import "mediapipe/gpu/MPPMetalHelper.h" #include "tensorflow/lite/delegates/gpu/metal_delegate.h" -#endif // APPLE && !TARGET_OS_OSX +#endif // iOS + +#if defined(__ANDROID__) +typedef ::tflite::gpu::gl::GlBuffer GpuTensor; +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS +typedef id GpuTensor; +#endif // TfLiteInferenceCalculator File Layout: // * Header @@ -56,11 +60,14 @@ using ::tflite::gpu::gl::GlProgram; using ::tflite::gpu::gl::GlShader; struct GPUData { int elements = 1; - GlBuffer ssbo; - GlShader shader; - GlProgram program; + GlBuffer buffer; }; -#endif // ANDROID +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS +struct GPUData { + int elements = 1; + id buffer; +}; +#endif // Calculator Header Section @@ -78,12 +85,12 @@ struct GPUData { // GPU. // // Input: -// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32 -// TENSORS_GPU - Vector of GlBuffer (assumed to be RGB image) +// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32 or kTfLiteUInt8 +// TENSORS_GPU - Vector of GlBuffer or MTLBuffer // // Output: -// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32 -// TENSORS_GPU - Vector of GlBuffer +// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32 or kTfLiteUInt8 +// TENSORS_GPU - Vector of GlBuffer or MTLBuffer // // Input side packet: // CUSTOM_OP_RESOLVER (optional) - Use a custom op resolver, @@ -107,7 +114,7 @@ struct GPUData { // Input tensors are assumed to be of the correct size and already normalized. // All output TfLiteTensors will be destroyed when the graph closes, // (i.e. after calling graph.WaitUntilDone()). -// GPU tensors are currently only supported on Android. +// GPU tensors are currently only supported on Android and iOS. // This calculator uses FixedSizeInputStreamHandler by default. 
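// A typical wiring (illustrative only; stream and file names here are made up)
// feeds this calculator from TfLiteConverterCalculator:
//
// node {
//   calculator: "TfLiteConverterCalculator"
//   input_stream: "IMAGE:input_image"
//   output_stream: "TENSORS:image_tensors"
// }
// node {
//   calculator: "TfLiteInferenceCalculator"
//   input_stream: "TENSORS:image_tensors"
//   output_stream: "TENSORS:output_tensors"
//   options {
//     [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
//       model_path: "model.tflite"
//     }
//   }
// }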
// class TfLiteInferenceCalculator : public CalculatorBase { @@ -131,40 +138,41 @@ class TfLiteInferenceCalculator : public CalculatorBase { mediapipe::GlCalculatorHelper gpu_helper_; std::unique_ptr gpu_data_in_; std::vector> gpu_data_out_; -#endif -#if defined(__APPLE__) && !TARGET_OS_OSX // iOS - MediaPipeMetalHelper* gpu_helper_ = nullptr; +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS + MPPMetalHelper* gpu_helper_ = nullptr; + std::unique_ptr gpu_data_in_; + std::vector> gpu_data_out_; #endif std::string model_path_ = ""; bool gpu_inference_ = false; bool gpu_input_ = false; bool gpu_output_ = false; -}; // TfLiteInferenceCalculator - + bool use_quantized_tensors_ = false; +}; REGISTER_CALCULATOR(TfLiteInferenceCalculator); // Calculator Core Section ::mediapipe::Status TfLiteInferenceCalculator::GetContract( CalculatorContract* cc) { - RET_CHECK(cc->Inputs().HasTag("TENSORS") || + RET_CHECK(cc->Inputs().HasTag("TENSORS") ^ cc->Inputs().HasTag("TENSORS_GPU")); - RET_CHECK(cc->Outputs().HasTag("TENSORS") || + RET_CHECK(cc->Outputs().HasTag("TENSORS") ^ cc->Outputs().HasTag("TENSORS_GPU")); if (cc->Inputs().HasTag("TENSORS")) cc->Inputs().Tag("TENSORS").Set>(); -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) if (cc->Inputs().HasTag("TENSORS_GPU")) - cc->Inputs().Tag("TENSORS_GPU").Set>(); + cc->Inputs().Tag("TENSORS_GPU").Set>(); #endif if (cc->Outputs().HasTag("TENSORS")) cc->Outputs().Tag("TENSORS").Set>(); -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) if (cc->Outputs().HasTag("TENSORS_GPU")) - cc->Outputs().Tag("TENSORS_GPU").Set>(); + cc->Outputs().Tag("TENSORS_GPU").Set>(); #endif if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) { @@ -176,7 +184,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); #if defined(__ANDROID__) RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS - RETURN_IF_ERROR([MediaPipeMetalHelper updateContract:cc]); + RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); #endif // Assign this calculator's default InputStreamHandler. @@ -186,26 +194,26 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); } ::mediapipe::Status TfLiteInferenceCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + RETURN_IF_ERROR(LoadOptions(cc)); if (cc->Inputs().HasTag("TENSORS_GPU")) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) gpu_input_ = true; gpu_inference_ = true; // Inference must be on GPU also. 
#else - RET_CHECK(!cc->Inputs().HasTag("TENSORS_GPU")) - << "GPU input for non-Android not supported yet."; + RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; #endif } if (cc->Outputs().HasTag("TENSORS_GPU")) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) gpu_output_ = true; RET_CHECK(cc->Inputs().HasTag("TENSORS_GPU")) << "GPU output must also have GPU Input."; #else - RET_CHECK(!cc->Inputs().HasTag("TENSORS_GPU")) - << "GPU output for non-Android not supported yet."; + RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; #endif } @@ -215,7 +223,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); #if defined(__ANDROID__) RETURN_IF_ERROR(gpu_helper_.Open(cc)); #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS - gpu_helper_ = [[MediaPipeMetalHelper alloc] initWithCalculatorContext:cc]; + gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; RET_CHECK(gpu_helper_); #endif @@ -226,24 +234,38 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); } ::mediapipe::Status TfLiteInferenceCalculator::Process(CalculatorContext* cc) { - // Receive pre-processed tensor inputs. + // 1. Receive pre-processed tensor inputs. if (gpu_input_) { // Read GPU input into SSBO. #if defined(__ANDROID__) const auto& input_tensors = - cc->Inputs().Tag("TENSORS_GPU").Get>(); + cc->Inputs().Tag("TENSORS_GPU").Get>(); RET_CHECK_EQ(input_tensors.size(), 1); RETURN_IF_ERROR(gpu_helper_.RunInGlContext( [this, &input_tensors]() -> ::mediapipe::Status { // Explicit copy input. - tflite::gpu::gl::CopyBuffer(input_tensors[0], gpu_data_in_->ssbo); - // Run inference. - RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + tflite::gpu::gl::CopyBuffer(input_tensors[0], gpu_data_in_->buffer); return ::mediapipe::OkStatus(); })); +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS + const auto& input_tensors = + cc->Inputs().Tag("TENSORS_GPU").Get>(); + RET_CHECK_EQ(input_tensors.size(), 1); + id command_buffer = [gpu_helper_ commandBuffer]; + command_buffer.label = @"TfLiteInferenceCalculatorInput"; + id blit_command = + [command_buffer blitCommandEncoder]; + // Explicit copy input. + [blit_command copyFromBuffer:input_tensors[0] + sourceOffset:0 + toBuffer:gpu_data_in_->buffer + destinationOffset:0 + size:gpu_data_in_->elements * sizeof(float)]; + [blit_command endEncoding]; + [command_buffer commit]; + [command_buffer waitUntilCompleted]; #else - RET_CHECK_FAIL() - << "GPU input on non-Android devices is not supported yet."; + RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; #endif } else { // Read CPU input into tensors. @@ -252,35 +274,38 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); RET_CHECK_GT(input_tensors.size(), 0); for (int i = 0; i < input_tensors.size(); ++i) { const TfLiteTensor* input_tensor = &input_tensors[i]; - const float* input_tensor_buffer = input_tensor->data.f; - RET_CHECK(input_tensor_buffer); - - float* local_tensor_buffer = interpreter_->typed_input_tensor(i); - RET_CHECK(local_tensor_buffer); - - memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor->bytes); - } - - // Run inference. 
- if (gpu_inference_) { -#if defined(__ANDROID__) - RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this]() -> ::mediapipe::Status { - RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); - return ::mediapipe::OkStatus(); - })); -#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS - RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); -#endif - } else { - RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + RET_CHECK(input_tensor->data.raw); + if (use_quantized_tensors_) { + const uint8* input_tensor_buffer = input_tensor->data.uint8; + uint8* local_tensor_buffer = interpreter_->typed_input_tensor(i); + memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor->bytes); + } else { + const float* input_tensor_buffer = input_tensor->data.f; + float* local_tensor_buffer = interpreter_->typed_input_tensor(i); + memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor->bytes); + } } } + // 2. Run inference. + if (gpu_inference_) { +#if defined(__ANDROID__) + RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> ::mediapipe::Status { + RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + return ::mediapipe::OkStatus(); + })); +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS + RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); +#endif + } else { + RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + } + + // 3. Output processed tensors. if (gpu_output_) { #if defined(__ANDROID__) // Output result tensors (GPU). - auto output_tensors = absl::make_unique>(); + auto output_tensors = absl::make_unique>(); output_tensors->resize(gpu_data_out_.size()); for (int i = 0; i < gpu_data_out_.size(); ++i) { GlBuffer& tensor = output_tensors->at(i); @@ -290,13 +315,39 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); if (!status.ok()) { return ::mediapipe::InternalError(status.error_message()); } - tflite::gpu::gl::CopyBuffer(gpu_data_out_[i]->ssbo, tensor); + tflite::gpu::gl::CopyBuffer(gpu_data_out_[i]->buffer, tensor); + } + cc->Outputs() + .Tag("TENSORS_GPU") + .Add(output_tensors.release(), cc->InputTimestamp()); +#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS + // Output result tensors (GPU). + auto output_tensors = absl::make_unique>(); + id device = gpu_helper_.mtlDevice; + id command_buffer = [gpu_helper_ commandBuffer]; + command_buffer.label = @"TfLiteInferenceCalculatorOutput"; + for (int i = 0; i < gpu_data_out_.size(); ++i) { + id tensor = + [device newBufferWithLength:gpu_data_out_[i]->elements * sizeof(float) + options:MTLResourceStorageModeShared]; + id blit_command = + [command_buffer blitCommandEncoder]; + // Explicit copy input. + [blit_command copyFromBuffer:gpu_data_out_[i]->buffer + sourceOffset:0 + toBuffer:tensor + destinationOffset:0 + size:gpu_data_out_[i]->elements * sizeof(float)]; + [blit_command endEncoding]; + [command_buffer commit]; + [command_buffer waitUntilCompleted]; + output_tensors->push_back(tensor); } cc->Outputs() .Tag("TENSORS_GPU") .Add(output_tensors.release(), cc->InputTimestamp()); #else - LOG(ERROR) << "GPU output on non-Android not supported yet."; + RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; #endif } else { // Output result tensors (CPU). 
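In step 1 above, the CPU path now copies either uint8 or float data depending on whether the model is quantized. A condensed, self-contained sketch of that copy; the function and parameter names are illustrative, not from this change:

#include <cstdint>
#include <cstring>
#include <vector>

#include "tensorflow/lite/interpreter.h"

// Copies CPU input tensors into the interpreter: the uint8 path is taken for
// affine-quantized models, the float path otherwise.
void CopyCpuInputsSketch(const std::vector<TfLiteTensor>& input_tensors,
                         bool use_quantized_tensors,
                         tflite::Interpreter* interpreter) {
  for (int i = 0; i < static_cast<int>(input_tensors.size()); ++i) {
    const TfLiteTensor* input_tensor = &input_tensors[i];
    if (use_quantized_tensors) {
      std::memcpy(interpreter->typed_input_tensor<uint8_t>(i),
                  input_tensor->data.uint8, input_tensor->bytes);
    } else {
      std::memcpy(interpreter->typed_input_tensor<float>(i),
                  input_tensor->data.f, input_tensor->bytes);
    }
  }
}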
@@ -325,8 +376,13 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); return ::mediapipe::OkStatus(); })); #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS - DeleteGpuDelegate(delegate_); + TFLGpuDelegateDelete(delegate_); + gpu_data_in_.reset(); + for (int i = 0; i < gpu_data_out_.size(); ++i) { + gpu_data_out_[i].reset(); + } #endif + delegate_ = nullptr; } return ::mediapipe::OkStatus(); } @@ -341,8 +397,9 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); // Get model name. if (!options.model_path().empty()) { - ASSIGN_OR_RETURN(model_path_, - mediapipe::PathToResourceAsFile(options.model_path())); + auto model_path = options.model_path(); + + ASSIGN_OR_RETURN(model_path_, mediapipe::PathToResourceAsFile(model_path)); } else { LOG(ERROR) << "Must specify path to TFLite model."; return ::mediapipe::Status(::mediapipe::StatusCode::kNotFound, @@ -360,11 +417,6 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); model_ = tflite::FlatBufferModel::BuildFromFile(model_path_.c_str()); RET_CHECK(model_); -#if !defined(__ANDROID__) && !(defined(__APPLE__) && !TARGET_OS_OSX) - LOG(WARNING) << "GPU only supported on mobile platforms. Using CPU fallback."; - gpu_inference_ = false; -#endif - if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) { const auto& op_resolver = cc->InputSidePackets() @@ -378,8 +430,13 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); RET_CHECK(interpreter_); - if (!gpu_output_) { + if (gpu_output_) { + use_quantized_tensors_ = false; + } else { RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + use_quantized_tensors_ = (interpreter_->tensor(0)->quantization.type == + kTfLiteAffineQuantization); + if (use_quantized_tensors_) gpu_inference_ = false; } return ::mediapipe::OkStatus(); @@ -388,12 +445,20 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); ::mediapipe::Status TfLiteInferenceCalculator::LoadDelegate( CalculatorContext* cc) { #if defined(__ANDROID__) - // Get input image sizes. + // Configure and create the delegate. + TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault(); + options.compile_options.precision_loss_allowed = 1; + options.compile_options.preferred_gl_object_type = + TFLITE_GL_OBJECT_TYPE_FASTEST; + options.compile_options.dynamic_batch_enabled = 0; + options.compile_options.inline_parameters = 1; + if (!delegate_) delegate_ = TfLiteGpuDelegateCreate(&options); + if (gpu_input_) { + // Get input image sizes. gpu_data_in_ = absl::make_unique(); const auto& input_indices = interpreter_->inputs(); - // TODO accept > 1. - RET_CHECK_EQ(input_indices.size(), 1); + RET_CHECK_EQ(input_indices.size(), 1); // TODO accept > 1. const TfLiteTensor* tensor = interpreter_->tensor(input_indices[0]); gpu_data_in_->elements = 1; for (int d = 0; d < tensor->dims->size; ++d) { @@ -402,9 +467,19 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); // Input to model can be either RGB/RGBA only. RET_CHECK_GE(tensor->dims->data[3], 3); RET_CHECK_LE(tensor->dims->data[3], 4); + // Create and bind input buffer. + auto status = ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer( + gpu_data_in_->elements, &gpu_data_in_->buffer); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + RET_CHECK_EQ(TfLiteGpuDelegateBindBufferToTensor( + delegate_, gpu_data_in_->buffer.id(), + interpreter_->inputs()[0]), // First tensor only + kTfLiteOk); } - // Get output image sizes. if (gpu_output_) { + // Get output image sizes. 
const auto& output_indices = interpreter_->outputs(); gpu_data_out_.resize(output_indices.size()); for (int i = 0; i < gpu_data_out_.size(); ++i) { @@ -416,55 +491,90 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator); gpu_data_out_[i]->elements *= tensor->dims->data[d]; } } - } - // Configure and create the delegate. - TfLiteGpuDelegateOptions options; - options.metadata = nullptr; - options.compile_options.precision_loss_allowed = 1; - options.compile_options.preferred_gl_object_type = - TFLITE_GL_OBJECT_TYPE_FASTEST; - options.compile_options.dynamic_batch_enabled = 0; - if (!delegate_) delegate_ = TfLiteGpuDelegateCreate(&options); - // Shader to convert GL texture to SSBO. - if (gpu_input_) { - auto status = ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer( - gpu_data_in_->elements, &gpu_data_in_->ssbo); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } - RET_CHECK_EQ(TfLiteGpuDelegateBindBufferToTensor( - delegate_, gpu_data_in_->ssbo.id(), - interpreter_->inputs()[0]), // First tensor only - kTfLiteOk); - } - // Create output SSBO buffers. - if (gpu_output_) { + // Create and bind output buffers. interpreter_->SetAllowBufferHandleOutput(true); - const auto& output_indices = interpreter_->outputs(); for (int i = 0; i < gpu_data_out_.size(); ++i) { using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; auto status = CreateReadWriteShaderStorageBuffer( - gpu_data_out_[i]->elements, &gpu_data_out_[i]->ssbo); + gpu_data_out_[i]->elements, &gpu_data_out_[i]->buffer); if (!status.ok()) { return ::mediapipe::InternalError(status.error_message()); } RET_CHECK_EQ( TfLiteGpuDelegateBindBufferToTensor( - delegate_, gpu_data_out_[i]->ssbo.id(), output_indices[i]), + delegate_, gpu_data_out_[i]->buffer.id(), output_indices[i]), kTfLiteOk); } } + // Must call this last. RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_), kTfLiteOk); - return ::mediapipe::OkStatus(); -#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS +#endif // __ANDROID__ + +#if defined(__APPLE__) && !TARGET_OS_OSX // iOS + // Configure and create the delegate. GpuDelegateOptions options; - options.allow_precision_loss = 1; - options.wait_type = GpuDelegateOptions::WaitType::kPassive; - if (!delegate_) delegate_ = NewGpuDelegate(&options); - RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_), kTfLiteOk); - return ::mediapipe::OkStatus(); -#endif // ANDROID or iOS + options.allow_precision_loss = false; // Must match converter, F=float/T=half + options.wait_type = GpuDelegateOptions::WaitType::kActive; + if (!delegate_) delegate_ = TFLGpuDelegateCreate(&options); + + if (gpu_input_) { + // Get input image sizes. + gpu_data_in_ = absl::make_unique(); + const auto& input_indices = interpreter_->inputs(); + RET_CHECK_EQ(input_indices.size(), 1); + const TfLiteTensor* tensor = interpreter_->tensor(input_indices[0]); + gpu_data_in_->elements = 1; + // On iOS GPU, input must be 4 channels, regardless of what model expects. + { + gpu_data_in_->elements *= tensor->dims->data[0]; // batch + gpu_data_in_->elements *= tensor->dims->data[1]; // height + gpu_data_in_->elements *= tensor->dims->data[2]; // width + gpu_data_in_->elements *= 4; // channels + } + // Input to model can be RGBA only. + if (tensor->dims->data[3] != 4) { + LOG(WARNING) << "Please ensure input GPU tensor is 4 channels."; + } + // Create and bind input buffer. 
+ id device = gpu_helper_.mtlDevice; + gpu_data_in_->buffer = + [device newBufferWithLength:gpu_data_in_->elements * sizeof(float) + options:MTLResourceStorageModeShared]; + // Must call this before TFLGpuDelegateBindMetalBufferToTensor. + RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_), kTfLiteOk); + RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor( + delegate_, + input_indices[0], // First tensor only + gpu_data_in_->buffer), + true); + } + if (gpu_output_) { + // Get output image sizes. + const auto& output_indices = interpreter_->outputs(); + gpu_data_out_.resize(output_indices.size()); + for (int i = 0; i < gpu_data_out_.size(); ++i) { + const TfLiteTensor* tensor = interpreter_->tensor(output_indices[i]); + gpu_data_out_[i] = absl::make_unique(); + gpu_data_out_[i]->elements = 1; + // TODO handle *2 properly on some dialated models + for (int d = 0; d < tensor->dims->size; ++d) { + gpu_data_out_[i]->elements *= tensor->dims->data[d]; + } + } + // Create and bind output buffers. + interpreter_->SetAllowBufferHandleOutput(true); + id device = gpu_helper_.mtlDevice; + for (int i = 0; i < gpu_data_out_.size(); ++i) { + gpu_data_out_[i]->buffer = + [device newBufferWithLength:gpu_data_out_[i]->elements * sizeof(float) + options:MTLResourceStorageModeShared]; + RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor( + delegate_, output_indices[i], gpu_data_out_[i]->buffer), + true); + } + } +#endif // iOS return ::mediapipe::OkStatus(); } diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc index 06e49124f..cb9e5cf14 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc @@ -120,11 +120,19 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase { ::mediapipe::Status Close(CalculatorContext* cc) override; private: + ::mediapipe::Status ProcessCPU(CalculatorContext* cc, + std::vector* output_detections); + ::mediapipe::Status ProcessGPU(CalculatorContext* cc, + std::vector* output_detections); + ::mediapipe::Status LoadOptions(CalculatorContext* cc); ::mediapipe::Status GlSetup(CalculatorContext* cc); ::mediapipe::Status DecodeBoxes(const float* raw_boxes, const std::vector& anchors, std::vector* boxes); + ::mediapipe::Status ConvertToDetections( + const float* detection_boxes, const float* detection_scores, + const int* detection_classes, std::vector* output_detections); Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax, float box_xmax, float score, int class_id, bool flip_vertically); @@ -136,6 +144,7 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase { ::mediapipe::TfLiteTensorsToDetectionsCalculatorOptions options_; std::vector anchors_; + bool side_packet_anchors_{}; #if defined(__ANDROID__) mediapipe::GlCalculatorHelper gpu_helper_; @@ -187,6 +196,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); ::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Open( CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + if (cc->Inputs().HasTag("TENSORS_GPU")) { gpu_input_ = true; #if defined(__ANDROID__) @@ -195,6 +206,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); } RETURN_IF_ERROR(LoadOptions(cc)); + side_packet_anchors_ = cc->InputSidePackets().HasTag("ANCHORS"); if (gpu_input_) { RETURN_IF_ERROR(GlSetup(cc)); @@ -210,72 +222,32 @@ 
REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); return ::mediapipe::OkStatus(); } - const bool side_packet_anchors = - cc->InputSidePackets().HasTag("ANCHORS") && - !cc->InputSidePackets().Tag("ANCHORS").IsEmpty(); auto output_detections = absl::make_unique>(); - std::vector boxes(num_boxes_ * num_coords_); - std::vector score_class_id_pairs(num_boxes_ * 2); - if (gpu_input_) { -#if defined(__ANDROID__) - const auto& input_tensors = - cc->Inputs().Tag("TENSORS_GPU").Get>(); - - // Copy inputs. - tflite::gpu::gl::CopyBuffer(input_tensors[0], *raw_boxes_buffer_.get()); - tflite::gpu::gl::CopyBuffer(input_tensors[1], *raw_scores_buffer_.get()); - if (!anchors_init_) { - if (side_packet_anchors) { - const auto& anchors = - cc->InputSidePackets().Tag("ANCHORS").Get>(); - std::vector raw_anchors(num_boxes_ * kNumCoordsPerBox); - ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors.data()); - raw_anchors_buffer_->Write(absl::MakeSpan(raw_anchors)); - } else { - CHECK_EQ(input_tensors.size(), 3); - tflite::gpu::gl::CopyBuffer(input_tensors[2], - *raw_anchors_buffer_.get()); - } - anchors_init_ = true; - } - - // Run shaders. - RETURN_IF_ERROR(gpu_helper_.RunInGlContext( - [this, &input_tensors]() -> ::mediapipe::Status { - // Decode boxes. - decoded_boxes_buffer_->BindToIndex(0); - raw_boxes_buffer_->BindToIndex(1); - raw_anchors_buffer_->BindToIndex(2); - const tflite::gpu::uint3 decode_workgroups = {num_boxes_, 1, 1}; - decode_program_->Dispatch(decode_workgroups); - - // Score boxes. - scored_boxes_buffer_->BindToIndex(0); - raw_scores_buffer_->BindToIndex(1); - const tflite::gpu::uint3 score_workgroups = {num_boxes_, 1, 1}; - score_program_->Dispatch(score_workgroups); - - return ::mediapipe::OkStatus(); - })); - - // Copy decoded boxes from GPU to CPU. - auto status = decoded_boxes_buffer_->Read(absl::MakeSpan(boxes)); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } - status = scored_boxes_buffer_->Read(absl::MakeSpan(score_class_id_pairs)); - if (!status.ok()) { - return ::mediapipe::InternalError(status.error_message()); - } -#else - LOG(ERROR) << "GPU input on non-Android not supported yet."; -#endif // defined(__ANDROID__) + RETURN_IF_ERROR(ProcessGPU(cc, output_detections.get())); } else { - const auto& input_tensors = - cc->Inputs().Tag("TENSORS").Get>(); + RETURN_IF_ERROR(ProcessCPU(cc, output_detections.get())); + } // if gpu_input_ + // Output + if (cc->Outputs().HasTag("DETECTIONS")) { + cc->Outputs() + .Tag("DETECTIONS") + .Add(output_detections.release(), cc->InputTimestamp()); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessCPU( + CalculatorContext* cc, std::vector* output_detections) { + const auto& input_tensors = + cc->Inputs().Tag("TENSORS").Get>(); + + if (input_tensors.size() == 2) { + // Postprocessing on CPU for model without postprocessing op. E.g. output + // raw score tensor and box tensor. Anchor decoding will be handled below. 
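    // Sketch of the layouts this branch appears to assume (inferred from the
    // shape checks in this function, not stated explicitly by this change):
    //   2 input tensors: raw boxes  ~ [1, num_boxes, num_coords]
    //                    raw scores ~ [1, num_boxes, num_classes]
    //     Boxes are decoded against anchors here, then scores are filtered.
    //   4 input tensors: the model already ran decoding/NMS and emits
    //     boxes [1, max_detections, num_coords], classes, scores, and a
    //     single-element tensor holding the number of valid detections.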
const TfLiteTensor* raw_box_tensor = &input_tensors[0]; const TfLiteTensor* raw_score_tensor = &input_tensors[1]; @@ -300,7 +272,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); CHECK_EQ(anchor_tensor->dims->data[1], kNumCoordsPerBox); const float* raw_anchors = anchor_tensor->data.f; ConvertRawValuesToAnchors(raw_anchors, num_boxes_, &anchors_); - } else if (side_packet_anchors) { + } else if (side_packet_anchors_) { + CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty()); anchors_ = cc->InputSidePackets().Tag("ANCHORS").Get>(); } else { @@ -308,8 +281,12 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); } anchors_init_ = true; } + std::vector boxes(num_boxes_ * num_coords_); RETURN_IF_ERROR(DecodeBoxes(raw_boxes, anchors_, &boxes)); + std::vector detection_scores(num_boxes_); + std::vector detection_classes(num_boxes_); + // Filter classes by scores. for (int i = 0; i < num_boxes_; ++i) { int class_id = -1; @@ -335,44 +312,119 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); } } } - score_class_id_pairs[i * 2 + 0] = max_score; - score_class_id_pairs[i * 2 + 1] = class_id; + detection_scores[i] = max_score; + detection_classes[i] = class_id; } - } // if gpu_input_ - // Convert to Detection. + RETURN_IF_ERROR(ConvertToDetections(boxes.data(), detection_scores.data(), + detection_classes.data(), + output_detections)); + } else { + // Postprocessing on CPU with postprocessing op (e.g. anchor decoding and + // non-maximum suppression) within the model. + RET_CHECK_EQ(input_tensors.size(), 4); + + const TfLiteTensor* detection_boxes_tensor = &input_tensors[0]; + const TfLiteTensor* detection_classes_tensor = &input_tensors[1]; + const TfLiteTensor* detection_scores_tensor = &input_tensors[2]; + const TfLiteTensor* num_boxes_tensor = &input_tensors[3]; + RET_CHECK_EQ(num_boxes_tensor->dims->size, 1); + RET_CHECK_EQ(num_boxes_tensor->dims->data[0], 1); + const float* num_boxes = num_boxes_tensor->data.f; + num_boxes_ = num_boxes[0]; + RET_CHECK_EQ(detection_boxes_tensor->dims->size, 3); + RET_CHECK_EQ(detection_boxes_tensor->dims->data[0], 1); + const int max_detections = detection_boxes_tensor->dims->data[1]; + RET_CHECK_EQ(detection_boxes_tensor->dims->data[2], num_coords_); + RET_CHECK_EQ(detection_classes_tensor->dims->size, 2); + RET_CHECK_EQ(detection_classes_tensor->dims->data[0], 1); + RET_CHECK_EQ(detection_classes_tensor->dims->data[1], max_detections); + RET_CHECK_EQ(detection_scores_tensor->dims->size, 2); + RET_CHECK_EQ(detection_scores_tensor->dims->data[0], 1); + RET_CHECK_EQ(detection_scores_tensor->dims->data[1], max_detections); + + const float* detection_boxes = detection_boxes_tensor->data.f; + const float* detection_scores = detection_scores_tensor->data.f; + std::vector detection_classes(num_boxes_); + for (int i = 0; i < num_boxes_; ++i) { + detection_classes[i] = + static_cast(detection_classes_tensor->data.f[i]); + } + RETURN_IF_ERROR(ConvertToDetections(detection_boxes, detection_scores, + detection_classes.data(), + output_detections)); + } + return ::mediapipe::OkStatus(); +} +::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU( + CalculatorContext* cc, std::vector* output_detections) { +#if defined(__ANDROID__) + const auto& input_tensors = + cc->Inputs().Tag("TENSORS_GPU").Get>(); + + // Copy inputs. 
+ tflite::gpu::gl::CopyBuffer(input_tensors[0], *raw_boxes_buffer_.get()); + tflite::gpu::gl::CopyBuffer(input_tensors[1], *raw_scores_buffer_.get()); + if (!anchors_init_) { + if (side_packet_anchors_) { + CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty()); + const auto& anchors = + cc->InputSidePackets().Tag("ANCHORS").Get>(); + std::vector raw_anchors(num_boxes_ * kNumCoordsPerBox); + ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors.data()); + raw_anchors_buffer_->Write(absl::MakeSpan(raw_anchors)); + } else { + CHECK_EQ(input_tensors.size(), 3); + tflite::gpu::gl::CopyBuffer(input_tensors[2], *raw_anchors_buffer_.get()); + } + anchors_init_ = true; + } + + // Run shaders. + RETURN_IF_ERROR(gpu_helper_.RunInGlContext( + [this, &input_tensors]() -> ::mediapipe::Status { + // Decode boxes. + decoded_boxes_buffer_->BindToIndex(0); + raw_boxes_buffer_->BindToIndex(1); + raw_anchors_buffer_->BindToIndex(2); + const tflite::gpu::uint3 decode_workgroups = {num_boxes_, 1, 1}; + decode_program_->Dispatch(decode_workgroups); + + // Score boxes. + scored_boxes_buffer_->BindToIndex(0); + raw_scores_buffer_->BindToIndex(1); + const tflite::gpu::uint3 score_workgroups = {num_boxes_, 1, 1}; + score_program_->Dispatch(score_workgroups); + + return ::mediapipe::OkStatus(); + })); + + // Copy decoded boxes from GPU to CPU. + std::vector boxes(num_boxes_ * num_coords_); + auto status = decoded_boxes_buffer_->Read(absl::MakeSpan(boxes)); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + std::vector score_class_id_pairs(num_boxes_ * 2); + status = scored_boxes_buffer_->Read(absl::MakeSpan(score_class_id_pairs)); + if (!status.ok()) { + return ::mediapipe::InternalError(status.error_message()); + } + + // TODO: b/138851969. Is it possible to output a float vector + // for score and an int vector for class so that we can avoid copying twice? + std::vector detection_scores(num_boxes_); + std::vector detection_classes(num_boxes_); for (int i = 0; i < num_boxes_; ++i) { - const float score = score_class_id_pairs[i * 2 + 0]; - const int class_id = score_class_id_pairs[i * 2 + 1]; - const int box_offset = i * num_coords_; - Detection detection = ConvertToDetection( - boxes[box_offset + 0], boxes[box_offset + 1], boxes[box_offset + 2], - boxes[box_offset + 3], score, class_id, options_.flip_vertically()); - // Add keypoints. - if (options_.num_keypoints() > 0) { - auto* location_data = detection.mutable_location_data(); - for (int kp_id = 0; kp_id < options_.num_keypoints() * - options_.num_values_per_keypoint(); - kp_id += options_.num_values_per_keypoint()) { - auto keypoint = location_data->add_relative_keypoints(); - const int keypoint_index = - box_offset + options_.keypoint_coord_offset() + kp_id; - keypoint->set_x(boxes[keypoint_index + 0]); - keypoint->set_y(options_.flip_vertically() - ? 
1.f - boxes[keypoint_index + 1] - : boxes[keypoint_index + 1]); - } - } - output_detections->emplace_back(detection); + detection_scores[i] = score_class_id_pairs[i * 2]; + detection_classes[i] = static_cast(score_class_id_pairs[i * 2 + 1]); } - - // Output - if (cc->Outputs().HasTag("DETECTIONS")) { - cc->Outputs() - .Tag("DETECTIONS") - .Add(output_detections.release(), cc->InputTimestamp()); - } - + RETURN_IF_ERROR(ConvertToDetections(boxes.data(), detection_scores.data(), + detection_classes.data(), + output_detections)); +#else + LOG(ERROR) << "GPU input on non-Android not supported yet."; +#endif // defined(__ANDROID__) return ::mediapipe::OkStatus(); } @@ -481,6 +533,39 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); return ::mediapipe::OkStatus(); } +::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ConvertToDetections( + const float* detection_boxes, const float* detection_scores, + const int* detection_classes, std::vector* output_detections) { + for (int i = 0; i < num_boxes_; ++i) { + if (options_.has_min_score_thresh() && + detection_scores[i] < options_.min_score_thresh()) { + continue; + } + const int box_offset = i * num_coords_; + Detection detection = ConvertToDetection( + detection_boxes[box_offset + 0], detection_boxes[box_offset + 1], + detection_boxes[box_offset + 2], detection_boxes[box_offset + 3], + detection_scores[i], detection_classes[i], options_.flip_vertically()); + // Add keypoints. + if (options_.num_keypoints() > 0) { + auto* location_data = detection.mutable_location_data(); + for (int kp_id = 0; kp_id < options_.num_keypoints() * + options_.num_values_per_keypoint(); + kp_id += options_.num_values_per_keypoint()) { + auto keypoint = location_data->add_relative_keypoints(); + const int keypoint_index = + box_offset + options_.keypoint_coord_offset() + kp_id; + keypoint->set_x(detection_boxes[keypoint_index + 0]); + keypoint->set_y(options_.flip_vertically() + ? 1.f - detection_boxes[keypoint_index + 1] + : detection_boxes[keypoint_index + 1]); + } + } + output_detections->emplace_back(detection); + } + return ::mediapipe::OkStatus(); +} + Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection( float box_ymin, float box_xmin, float box_ymax, float box_xmax, float score, int class_id, bool flip_vertically) { @@ -628,7 +713,7 @@ void main() { if (!status.ok()) { return ::mediapipe::InternalError(status.error_message()); } - size_t raw_anchors_length = num_boxes_ * num_coords_; + size_t raw_anchors_length = num_boxes_ * kNumCoordsPerBox; raw_anchors_buffer_ = absl::make_unique(); status = CreateReadWriteShaderStorageBuffer(raw_anchors_length, raw_anchors_buffer_.get()); diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.proto b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.proto index ca4688086..ef494c2cc 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.proto +++ b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.proto @@ -68,4 +68,7 @@ message TfLiteTensorsToDetectionsCalculatorOptions { // the origin is at the top-left corner, whereas the desired detection // representation has a bottom-left origin (e.g., in OpenGL). optional bool flip_vertically = 18 [default = false]; + + // Score threshold for perserving decoded detections. 
+ optional float min_score_thresh = 19; } diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_floats_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_floats_calculator.cc new file mode 100644 index 000000000..72dd60a0b --- /dev/null +++ b/mediapipe/calculators/tflite/tflite_tensors_to_floats_calculator.cc @@ -0,0 +1,103 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "tensorflow/lite/interpreter.h" + +namespace mediapipe { + +// A calculator for converting TFLite tensors to to a float or a float vector. +// +// Input: +// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32. Only the first +// tensor will be used. +// Output: +// FLOAT(optional) - Converted single float number. +// FLOATS(optional) - Converted float vector. +// +// Notes: To output FLOAT stream, the input TFLite tensor must have size 1, e.g. +// only 1 float number in the tensor. +// +// Usage example: +// node { +// calculator: "TfLiteTensorsToFloatsCalculator" +// input_stream: "TENSORS:tensors" +// output_stream: "FLOATS:floats" +// } +class TfLiteTensorsToFloatsCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + + ::mediapipe::Status Process(CalculatorContext* cc) override; +}; +REGISTER_CALCULATOR(TfLiteTensorsToFloatsCalculator); + +::mediapipe::Status TfLiteTensorsToFloatsCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag("TENSORS")); + RET_CHECK(cc->Outputs().HasTag("FLOATS") || cc->Outputs().HasTag("FLOAT")); + + cc->Inputs().Tag("TENSORS").Set>(); + if (cc->Outputs().HasTag("FLOATS")) { + cc->Outputs().Tag("FLOATS").Set>(); + } + if (cc->Outputs().HasTag("FLOAT")) { + cc->Outputs().Tag("FLOAT").Set(); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteTensorsToFloatsCalculator::Open( + CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteTensorsToFloatsCalculator::Process( + CalculatorContext* cc) { + RET_CHECK(!cc->Inputs().Tag("TENSORS").IsEmpty()); + + const auto& input_tensors = + cc->Inputs().Tag("TENSORS").Get>(); + // TODO: Add option to specify which tensor to take from. + const TfLiteTensor* raw_tensor = &input_tensors[0]; + const float* raw_floats = raw_tensor->data.f; + int num_values = 1; + for (int i = 0; i < raw_tensor->dims->size; ++i) { + RET_CHECK_GT(raw_tensor->dims->data[i], 0); + num_values *= raw_tensor->dims->data[i]; + } + + if (cc->Outputs().HasTag("FLOAT")) { + // TODO: Could add an index in the option to specifiy returning one + // value of a float array. 
+ RET_CHECK_EQ(num_values, 1); + cc->Outputs().Tag("FLOAT").AddPacket( + MakePacket(raw_floats[0]).At(cc->InputTimestamp())); + } + if (cc->Outputs().HasTag("FLOATS")) { + auto output_floats = absl::make_unique>( + raw_floats, raw_floats + num_values); + cc->Outputs().Tag("FLOATS").Add(output_floats.release(), + cc->InputTimestamp()); + } + + return ::mediapipe::OkStatus(); +} +} // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc new file mode 100644 index 000000000..d77e1514c --- /dev/null +++ b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc @@ -0,0 +1,188 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/port/ret_check.h" +#include "tensorflow/lite/interpreter.h" + +namespace mediapipe { + +// A calculator for converting TFLite tensors from regression models into +// landmarks. +// +// Input: +// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32. Only the first +// tensor will be used. The size of the values must be +// (num_dimension x num_landmarks). +// Output: +// LANDMARKS(optional) - Result MediaPipe landmarks. +// NORM_LANDMARKS(optional) - Result MediaPipe normalized landmarks. +// +// Notes: +// To output normalized landmarks, user must provide the original input image +// size to the model using calculator option input_image_width and +// input_image_height. 
+// Usage example: +// node { +// calculator: "TfLiteTensorsToLandmarksCalculator" +// input_stream: "TENSORS:landmark_tensors" +// output_stream: "LANDMARKS:landmarks" +// output_stream: "NORM_LANDMARKS:landmarks" +// options: { +// [mediapipe.TfLiteTensorsToLandmarksCalculatorOptions.ext] { +// num_landmarks: 21 +// +// input_image_width: 256 +// input_image_height: 256 +// } +// } +// } +class TfLiteTensorsToLandmarksCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + ::mediapipe::Status LoadOptions(CalculatorContext* cc); + int num_landmarks_ = 0; + + ::mediapipe::TfLiteTensorsToLandmarksCalculatorOptions options_; +}; +REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator); + +::mediapipe::Status TfLiteTensorsToLandmarksCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(!cc->Inputs().GetTags().empty()); + RET_CHECK(!cc->Outputs().GetTags().empty()); + + if (cc->Inputs().HasTag("TENSORS")) { + cc->Inputs().Tag("TENSORS").Set>(); + } + + if (cc->Outputs().HasTag("LANDMARKS")) { + cc->Outputs().Tag("LANDMARKS").Set>(); + } + + if (cc->Outputs().HasTag("NORM_LANDMARKS")) { + cc->Outputs().Tag("NORM_LANDMARKS").Set>(); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteTensorsToLandmarksCalculator::Open( + CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + RETURN_IF_ERROR(LoadOptions(cc)); + + if (cc->Outputs().HasTag("NORM_LANDMARKS")) { + RET_CHECK(options_.has_input_image_height() && + options_.has_input_image_width()) + << "Must provide input with/height for getting normalized landmarks."; + } + if (cc->Outputs().HasTag("LANDMARKS") && options_.flip_vertically()) { + RET_CHECK(options_.has_input_image_height() && + options_.has_input_image_width()) + << "Must provide input with/height for using flip_vertically option " + "when outputing landmarks in absolute coordinates."; + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteTensorsToLandmarksCalculator::Process( + CalculatorContext* cc) { + if (cc->Inputs().Tag("TENSORS").IsEmpty()) { + return ::mediapipe::OkStatus(); + } + + const auto& input_tensors = + cc->Inputs().Tag("TENSORS").Get>(); + + const TfLiteTensor* raw_tensor = &input_tensors[0]; + + int num_values = 1; + for (int i = 0; i < raw_tensor->dims->size; ++i) { + num_values *= raw_tensor->dims->data[i]; + } + const int num_dimensions = num_values / num_landmarks_; + // Landmarks must have less than 3 dimensions. Otherwise please consider + // using matrix. + CHECK_LE(num_dimensions, 3); + CHECK_GT(num_dimensions, 0); + + const float* raw_landmarks = raw_tensor->data.f; + + auto output_landmarks = absl::make_unique>(); + + for (int ld = 0; ld < num_landmarks_; ++ld) { + const int offset = ld * num_dimensions; + Landmark landmark; + landmark.set_x(raw_landmarks[offset]); + if (num_dimensions > 1) { + if (options_.flip_vertically()) { + landmark.set_y(options_.input_image_height() - + raw_landmarks[offset + 1]); + } else { + landmark.set_y(raw_landmarks[offset + 1]); + } + } + if (num_dimensions > 2) { + landmark.set_z(raw_landmarks[offset + 2]); + } + output_landmarks->push_back(landmark); + } + + // Output normalized landmarks if required. 
+ if (cc->Outputs().HasTag("NORM_LANDMARKS")) { + auto output_norm_landmarks = + absl::make_unique>(); + for (const auto& landmark : *output_landmarks) { + NormalizedLandmark norm_landmark; + norm_landmark.set_x(static_cast(landmark.x()) / + options_.input_image_width()); + norm_landmark.set_y(static_cast(landmark.y()) / + options_.input_image_height()); + norm_landmark.set_z(landmark.z() / options_.normalize_z()); + + output_norm_landmarks->push_back(norm_landmark); + } + cc->Outputs() + .Tag("NORM_LANDMARKS") + .Add(output_norm_landmarks.release(), cc->InputTimestamp()); + } + // Output absolute landmarks. + if (cc->Outputs().HasTag("LANDMARKS")) { + cc->Outputs() + .Tag("LANDMARKS") + .Add(output_landmarks.release(), cc->InputTimestamp()); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status TfLiteTensorsToLandmarksCalculator::LoadOptions( + CalculatorContext* cc) { + // Get calculator options specified in the graph. + options_ = + cc->Options<::mediapipe::TfLiteTensorsToLandmarksCalculatorOptions>(); + num_landmarks_ = options_.num_landmarks(); + + return ::mediapipe::OkStatus(); +} +} // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.proto b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.proto new file mode 100644 index 000000000..5f37e6238 --- /dev/null +++ b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.proto @@ -0,0 +1,45 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The option proto for the TfLiteTensorsToLandmarksCalculator. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message TfLiteTensorsToLandmarksCalculatorOptions { + extend .mediapipe.CalculatorOptions { + optional TfLiteTensorsToLandmarksCalculatorOptions ext = 257405002; + } + + // Number of landmarks from the output of the model. + required int32 num_landmarks = 1; + + // Size of the input image for the model. These options are used only when + // normalized landmarks is needed. + optional int32 input_image_width = 2; + optional int32 input_image_height = 3; + + // Whether the detection coordinates from the input tensors should be flipped + // vertically (along the y-direction). This is useful, for example, when the + // input tensors represent detections defined with a coordinate system where + // the origin is at the top-left corner, whereas the desired detection + // representation has a bottom-left origin (e.g., in OpenGL). + optional bool flip_vertically = 4 [default = false]; + + // A value that z values should be divided by. 
+ optional float normalize_z = 5 [default = 1.0]; +} diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc index 1542871b6..7b903f157 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc @@ -20,6 +20,9 @@ #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/util/resource_util.h" #include "tensorflow/lite/interpreter.h" @@ -39,8 +42,11 @@ namespace { constexpr int kWorkgroupSize = 8; // Block size for GPU shader. enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; // Commonly used to compute the number of blocks to launch in a kernel. -int RoundUp(const int size, const int multiple) { - return (size + multiple - 1) / multiple; +int NumGroups(const int size, const int group_size) { // NOLINT + return (size + group_size - 1) / group_size; +} +float Clamp(float val, float min, float max) { + return std::min(std::max(val, min), max); } } // namespace @@ -58,7 +64,7 @@ using ::tflite::gpu::gl::GlShader; // Performs optional upscale to REFERENCE_IMAGE dimensions if provided, // otherwise the mask is the same size as input tensor. // -// Note: This calculator is currently GPU only, so only *_GPU tags can be used. +// Produces result as an RGBA image, with the mask in both R & A channels. // // Inputs: // One of the following TENSORS tags: @@ -71,11 +77,11 @@ using ::tflite::gpu::gl::GlShader; // REFERENCE_IMAGE_GPU (optional): A GpuBuffer input image, // used only for output dimensions. // One of the following PREV_MASK tags: -// PREV_MASK (optional): An ImageFrame input mask, Gray, RGB or RGBA. -// PREV_MASK_GPU (optional): A GpuBuffer input mask, RGBA. +// PREV_MASK (optional): An ImageFrame input mask, Gray, RGB or RGBA, [0-255]. +// PREV_MASK_GPU (optional): A GpuBuffer input mask, RGBA, [0-1]. // Output: // One of the following MASK tags: -// MASK: An ImageFrame output mask, Gray, RGB or RGBA. +// MASK: An ImageFrame output mask, RGBA. // MASK_GPU: A GpuBuffer output mask, RGBA. // // Options: @@ -141,10 +147,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); cc->Inputs().Tag("TENSORS").Set>(); } if (cc->Inputs().HasTag("PREV_MASK")) { - cc->Inputs().Tag("PREV_MASK").Set(); + cc->Inputs().Tag("PREV_MASK").Set(); } if (cc->Inputs().HasTag("REFERENCE_IMAGE")) { - cc->Inputs().Tag("REFERENCE_IMAGE").Set(); + cc->Inputs().Tag("REFERENCE_IMAGE").Set(); } // Inputs GPU. @@ -162,7 +168,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); // Outputs. 
if (cc->Outputs().HasTag("MASK")) { - cc->Outputs().Tag("MASK").Set(); + cc->Outputs().Tag("MASK").Set(); } #if defined(__ANDROID__) if (cc->Outputs().HasTag("MASK_GPU")) { @@ -179,6 +185,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); ::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Open( CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + if (cc->Inputs().HasTag("TENSORS_GPU")) { use_gpu_ = true; #if defined(__ANDROID__) @@ -238,7 +246,107 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); ::mediapipe::Status TfLiteTensorsToSegmentationCalculator::ProcessCpu( CalculatorContext* cc) { - return ::mediapipe::UnimplementedError("CPU support is not implemented yet."); + if (cc->Inputs().Tag("TENSORS").IsEmpty()) { + return ::mediapipe::OkStatus(); + } + + // Get input streams. + const auto& input_tensors = + cc->Inputs().Tag("TENSORS").Get>(); + const bool has_prev_mask = cc->Inputs().HasTag("PREV_MASK") && + !cc->Inputs().Tag("PREV_MASK").IsEmpty(); + const ImageFrame placeholder; + const auto& input_mask = has_prev_mask + ? cc->Inputs().Tag("PREV_MASK").Get() + : placeholder; + int output_width = tensor_width_, output_height = tensor_height_; + if (cc->Inputs().HasTag("REFERENCE_IMAGE")) { + const auto& input_image = + cc->Inputs().Tag("REFERENCE_IMAGE").Get(); + output_width = input_image.Width(); + output_height = input_image.Height(); + } + RET_CHECK_EQ(input_tensors.size(), 1); + + // Create initial working mask. + cv::Mat small_mask_mat(cv::Size(tensor_width_, tensor_height_), CV_8UC4); + + // Get input previous mask. + cv::Mat input_mask_mat; + if (has_prev_mask) { + cv::Mat temp_mask_mat = formats::MatView(&input_mask); + if (temp_mask_mat.channels() != 4) { + cv::Mat converted_mat; + cv::cvtColor(temp_mask_mat, converted_mat, + temp_mask_mat.channels() == 1 ? cv::COLOR_GRAY2RGBA + : cv::COLOR_RGB2RGBA); + temp_mask_mat = converted_mat.clone(); + } + cv::resize(temp_mask_mat, input_mask_mat, small_mask_mat.size()); + } + + // Copy input tensor. + const TfLiteTensor* raw_input_tensor = &input_tensors[0]; + const float* raw_input_data = raw_input_tensor->data.f; + cv::Mat tensor_mat(cv::Size(tensor_width_, tensor_height_), + CV_MAKETYPE(CV_32F, tensor_channels_)); + float* tensor_mat_ptr = tensor_mat.ptr(); + memcpy(tensor_mat_ptr, raw_input_data, raw_input_tensor->bytes); + + // Process mask tensor. + // Run softmax over tensor output and blend with previous mask. + const int output_layer_index = options_.output_layer_index(); + const float combine_with_prev_ratio = options_.combine_with_previous_ratio(); + for (int i = 0; i < tensor_height_; ++i) { + for (int j = 0; j < tensor_width_; ++j) { + // Only two channel input tensor is supported. 
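      // The loop body below computes, in symbols (eps and ratio =
      // combine_with_previous_ratio come from this function):
      //   shift = max(x_0, x_1)                                 // stability
      //   p     = exp(x_k - shift) / (exp(x_0 - shift) + exp(x_1 - shift))
      //   alpha = clamp(1 + (p*log(p+eps) + (1-p)*log(1-p+eps)) / log 2, 0, 1)
      //   alpha = alpha * (2 - alpha)             // i.e. 1 - (1 - alpha)^2
      //   mixed = p * alpha + prev * (1 - alpha)
      //   new   = mixed * ratio + (1 - ratio) * p
      // so confident pixels (p near 0 or 1) keep the new value and uncertain
      // pixels (p near 0.5) fall back toward the previous mask.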
+ const cv::Vec2f input_pix = tensor_mat.at(i, j); + const float shift = std::max(input_pix[0], input_pix[1]); + const float softmax_denom = + std::exp(input_pix[0] - shift) + std::exp(input_pix[1] - shift); + float new_mask_value = + std::exp(input_pix[output_layer_index] - shift) / softmax_denom; + // Combine previous value with current using uncertainty^2 as mixing coeff + if (has_prev_mask) { + const float prev_mask_value = + input_mask_mat.at(i, j)[0] / 255.0f; + const float eps = 0.001; + float uncertainty_alpha = + 1.0 + + (new_mask_value * std::log(new_mask_value + eps) + + (1.0 - new_mask_value) * std::log(1.0 - new_mask_value + eps)) / + std::log(2.0f); + uncertainty_alpha = Clamp(uncertainty_alpha, 0.0f, 1.0f); + // Equivalent to: a = 1 - (1 - a) * (1 - a); (squaring the uncertainty) + uncertainty_alpha *= 2.0 - uncertainty_alpha; + const float mixed_mask_value = + new_mask_value * uncertainty_alpha + + prev_mask_value * (1.0f - uncertainty_alpha); + new_mask_value = mixed_mask_value * combine_with_prev_ratio + + (1.0f - combine_with_prev_ratio) * new_mask_value; + } + const uchar mask_value = static_cast(new_mask_value * 255); + // Set both R and A channels for convenience. + const cv::Vec4b out_value = {mask_value, 0, 0, mask_value}; + small_mask_mat.at(i, j) = out_value; + } + } + + if (options_.flip_vertically()) cv::flip(small_mask_mat, small_mask_mat, 0); + + // Upsample small mask into output. + cv::Mat large_mask_mat; + cv::resize(small_mask_mat, large_mask_mat, + cv::Size(output_width, output_height)); + + // Send out image as CPU packet. + std::unique_ptr output_mask = absl::make_unique( + ImageFormat::SRGBA, output_width, output_height); + cv::Mat output_mat = formats::MatView(output_mask.get()); + large_mask_mat.copyTo(output_mat); + cc->Outputs().Tag("MASK").Add(output_mask.release(), cc->InputTimestamp()); + + return ::mediapipe::OkStatus(); } // Steps: @@ -267,10 +375,9 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); output_width = input_image.width(); output_height = input_image.height(); } - RET_CHECK_EQ(input_tensors.size(), 1); - // Create initial output mask texture. + // Create initial working mask texture. ::tflite::gpu::gl::GlTexture small_mask_texture; ::tflite::gpu::gl::CreateReadWriteRgbaImageTexture( tflite::gpu::DataType::UINT8, // GL_RGBA8 @@ -285,6 +392,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); tflite::gpu::gl::CopyBuffer(input_tensors[0], *tensor_buffer_); // Run shader, process mask tensor. + // Run softmax over tensor output and blend with previous mask. { const int output_index = 0; glBindImageTexture(output_index, small_mask_texture.id(), 0, GL_FALSE, 0, @@ -292,8 +400,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); tensor_buffer_->BindToIndex(2); const tflite::gpu::uint3 workgroups = { - RoundUp(tensor_width_, kWorkgroupSize), - RoundUp(tensor_height_, kWorkgroupSize), 1}; + NumGroups(tensor_width_, kWorkgroupSize), + NumGroups(tensor_height_, kWorkgroupSize), 1}; if (!has_prev_mask) { mask_program_no_prev_->Dispatch(workgroups); @@ -330,8 +438,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); // Cleanup input_mask_texture.Release(); output_texture.Release(); - #endif // __ANDROID__ + return ::mediapipe::OkStatus(); } @@ -445,7 +553,7 @@ void main() { int linear_index = gid.y * out_width + gid.x; vec2 input_value = input_data.elements[linear_index]; - // Only two channel output is supported. + // Only two channel input tensor is supported. 
vec2 input_px = input_value.rg; float shift = max(input_px.r, input_px.g); float softmax_denom = exp(input_px.r - shift) + exp(input_px.g - shift); @@ -474,8 +582,8 @@ void main() { (1.0f - combine_with_previous_ratio) * new_mask_value; #endif // READ_PREVIOUS - // Texture coordinates are inverted on y axis. - ivec2 output_coordinate = ivec2(gid.x, out_height - gid.y - 1); + int y_coord = int($4); + ivec2 output_coordinate = ivec2(gid.x, y_coord); // Set both R and A channels for convenience. vec4 out_value = vec4(new_mask_value, 0.0, 0.0, new_mask_value); imageStore(output_texture, output_coordinate, out_value); @@ -483,10 +591,12 @@ void main() { const std::string shader_src_no_previous = absl::Substitute( shader_src_template, kWorkgroupSize, options_.output_layer_index(), - options_.combine_with_previous_ratio(), ""); + options_.combine_with_previous_ratio(), "", + options_.flip_vertically() ? "out_height - gid.y - 1" : "gid.y"); const std::string shader_src_with_previous = absl::Substitute( shader_src_template, kWorkgroupSize, options_.output_layer_index(), - options_.combine_with_previous_ratio(), "#define READ_PREVIOUS"); + options_.combine_with_previous_ratio(), "#define READ_PREVIOUS", + options_.flip_vertically() ? "out_height - gid.y - 1" : "gid.y"); auto status = ::tflite::gpu::OkStatus(); diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.proto b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.proto index 9694d2c5f..d04aa562b 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.proto +++ b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.proto @@ -34,4 +34,7 @@ message TfLiteTensorsToSegmentationCalculatorOptions { // Model specific: Channel to use for processing tensor. optional int32 output_layer_index = 5 [default = 1]; + + // Flip result image mask along y-axis. 
+ optional bool flip_vertically = 6; } diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 44eb4334a..8204ff384 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -241,6 +241,12 @@ cc_library( "//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:shader_util", ], + "//mediapipe:ios": [ + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gl_simple_shaders", + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:shader_util", + ], "//conditions:default": [], }), alwayslink = 1, @@ -261,6 +267,12 @@ cc_library( "//mediapipe:android": [ "//mediapipe/util/android/file/base", ], + "//mediapipe:apple": [ + "//mediapipe/util/android/file/base", + ], + "//mediapipe:macos": [ + "//mediapipe/framework/port:file_helpers", + ], "//conditions:default": [ "//mediapipe/framework/port:file_helpers", ], @@ -285,6 +297,147 @@ cc_library( alwayslink = 1, ) +mediapipe_cc_proto_library( + name = "thresholding_calculator_cc_proto", + srcs = ["thresholding_calculator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":thresholding_calculator_proto"], +) + +cc_library( + name = "thresholding_calculator", + srcs = ["thresholding_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":thresholding_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +mediapipe_cc_proto_library( + name = "landmarks_to_detection_calculator_cc_proto", + srcs = ["landmarks_to_detection_calculator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":landmarks_to_detection_calculator_proto"], +) + +cc_library( + name = "landmarks_to_detection_calculator", + srcs = ["landmarks_to_detection_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":landmarks_to_detection_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +mediapipe_cc_proto_library( + name = "detections_to_rects_calculator_cc_proto", + srcs = ["detections_to_rects_calculator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":detections_to_rects_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "landmark_projection_calculator_cc_proto", + srcs = ["landmark_projection_calculator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":landmark_projection_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "rect_transformation_calculator_cc_proto", + srcs = ["rect_transformation_calculator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":rect_transformation_calculator_proto"], +) + +cc_library( + name = "detections_to_rects_calculator", + srcs = ["detections_to_rects_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":detections_to_rects_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + 
"//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_library( + name = "rect_transformation_calculator", + srcs = ["rect_transformation_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":rect_transformation_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_test( + name = "detections_to_rects_calculator_test", + size = "small", + srcs = ["detections_to_rects_calculator_test.cc"], + deps = [ + ":detections_to_rects_calculator", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:packet", + "//mediapipe/framework/deps:message_matchers", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + ], +) + +proto_library( + name = "rect_to_render_data_calculator_proto", + srcs = ["rect_to_render_data_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + "//mediapipe/util:color_proto", + "//mediapipe/util:render_data_proto", + ], +) + proto_library( name = "detections_to_render_data_calculator_proto", srcs = ["detections_to_render_data_calculator.proto"], @@ -296,6 +449,78 @@ proto_library( ], ) +proto_library( + name = "landmarks_to_render_data_calculator_proto", + srcs = ["landmarks_to_render_data_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + "//mediapipe/util:color_proto", + "//mediapipe/util:render_data_proto", + ], +) + +proto_library( + name = "thresholding_calculator_proto", + srcs = ["thresholding_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + "//mediapipe/util:color_proto", + "//mediapipe/util:render_data_proto", + ], +) + +proto_library( + name = "detections_to_rects_calculator_proto", + srcs = ["detections_to_rects_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + ], +) + +proto_library( + name = "landmark_projection_calculator_proto", + srcs = ["landmark_projection_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + ], +) + +proto_library( + name = "rect_transformation_calculator_proto", + srcs = ["rect_transformation_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + ], +) + +proto_library( + name = "landmarks_to_detection_calculator_proto", + srcs = ["landmarks_to_detection_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + "//mediapipe/util:color_proto", + "//mediapipe/util:render_data_proto", + ], +) + +mediapipe_cc_proto_library( + name = "rect_to_render_data_calculator_cc_proto", + srcs = 
["rect_to_render_data_calculator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/util:color_cc_proto", + "//mediapipe/util:render_data_cc_proto", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":rect_to_render_data_calculator_proto"], +) + mediapipe_cc_proto_library( name = "detections_to_render_data_calculator_cc_proto", srcs = ["detections_to_render_data_calculator.proto"], @@ -327,6 +552,52 @@ cc_library( alwayslink = 1, ) +mediapipe_cc_proto_library( + name = "landmarks_to_render_data_calculator_cc_proto", + srcs = ["landmarks_to_render_data_calculator.proto"], + cc_deps = [ + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/util:color_cc_proto", + "//mediapipe/util:render_data_cc_proto", + ], + visibility = ["//mediapipe:__subpackages__"], + deps = [":landmarks_to_render_data_calculator_proto"], +) + +cc_library( + name = "landmarks_to_render_data_calculator", + srcs = ["landmarks_to_render_data_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":landmarks_to_render_data_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/util:color_cc_proto", + "//mediapipe/util:render_data_cc_proto", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +cc_library( + name = "rect_to_render_data_calculator", + srcs = ["rect_to_render_data_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":rect_to_render_data_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/util:color_cc_proto", + "//mediapipe/util:render_data_cc_proto", + ], + alwayslink = 1, +) + cc_test( name = "detections_to_render_data_calculator_test", size = "small", @@ -364,6 +635,35 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "landmark_letterbox_removal_calculator", + srcs = ["landmark_letterbox_removal_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:location", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_library( + name = "landmark_projection_calculator", + srcs = ["landmark_projection_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":landmark_projection_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + cc_test( name = "detection_letterbox_removal_calculator_test", srcs = ["detection_letterbox_removal_calculator_test.cc"], @@ -379,3 +679,18 @@ cc_test( "//mediapipe/framework/tool:validate_type", ], ) + +cc_test( + name = "landmark_letterbox_removal_calculator_test", + srcs = ["landmark_letterbox_removal_calculator_test.cc"], + deps = [ + ":landmark_letterbox_removal_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:landmark_cc_proto", + 
"//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/tool:validate_type", + ], +) diff --git a/mediapipe/calculators/util/annotation_overlay_calculator.cc b/mediapipe/calculators/util/annotation_overlay_calculator.cc index df3e3d891..5cd4e20e2 100644 --- a/mediapipe/calculators/util/annotation_overlay_calculator.cc +++ b/mediapipe/calculators/util/annotation_overlay_calculator.cc @@ -27,11 +27,12 @@ #include "mediapipe/util/annotation_renderer.h" #include "mediapipe/util/color.pb.h" -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_simple_shaders.h" +#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/shader_util.h" -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS namespace mediapipe { @@ -44,6 +45,14 @@ constexpr char kInputFrameTagGpu[] = "INPUT_FRAME_GPU"; constexpr char kOutputFrameTagGpu[] = "OUTPUT_FRAME_GPU"; enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; + +// Round up n to next multiple of m. +size_t RoundUp(size_t n, size_t m) { return ((n + m - 1) / m) * m; } // NOLINT + +// When using GPU, this color will become transparent when the calculator +// merges the annotation overlay with the image frame. As a result, drawing in +// this color is not supported and it should be set to something unlikely used. +constexpr int kAnnotationBackgroundColor[] = {100, 101, 102}; } // namespace // A calculator for rendering data on images. @@ -66,7 +75,8 @@ enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; // // For GPU input frames, only 4-channel images are supported. // -// Note: When using GPU, drawing with black color is not supported. +// Note: When using GPU, drawing with color kAnnotationBackgroundColor (defined +// above) is not supported. // // Example config (CPU): // node { @@ -136,11 +146,13 @@ class AnnotationOverlayCalculator : public CalculatorBase { bool use_gpu_ = false; bool gpu_initialized_ = false; -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) mediapipe::GlCalculatorHelper gpu_helper_; GLuint program_ = 0; GLuint image_mat_tex_ = 0; // Overlay drawing image for GPU. -#endif // __ANDROID__ + int width_ = 0; + int height_ = 0; +#endif // __ANDROID__ or iOS }; REGISTER_CALCULATOR(AnnotationOverlayCalculator); @@ -161,12 +173,12 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); int num_render_streams = cc->Inputs().NumEntries(); // Input image to render onto copy of. -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) if (cc->Inputs().HasTag(kInputFrameTagGpu)) { cc->Inputs().Tag(kInputFrameTagGpu).Set(); num_render_streams = cc->Inputs().NumEntries() - 1; } -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS if (cc->Inputs().HasTag(kInputFrameTag)) { cc->Inputs().Tag(kInputFrameTag).Set(); num_render_streams = cc->Inputs().NumEntries() - 1; @@ -178,32 +190,33 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); } // Rendered image. 
-#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) if (cc->Outputs().HasTag(kOutputFrameTagGpu)) { cc->Outputs().Tag(kOutputFrameTagGpu).Set(); } -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS if (cc->Outputs().HasTag(kOutputFrameTag)) { cc->Outputs().Tag(kOutputFrameTag).Set(); } -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS return ::mediapipe::OkStatus(); } ::mediapipe::Status AnnotationOverlayCalculator::Open(CalculatorContext* cc) { - options_ = cc->Options(); + cc->SetOffset(TimestampDiff(0)); + options_ = cc->Options(); if (cc->Inputs().HasTag(kInputFrameTagGpu) && cc->Outputs().HasTag(kOutputFrameTagGpu)) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) use_gpu_ = true; #else - RET_CHECK_FAIL() << "GPU processing on non-Android not supported yet."; -#endif // __ANDROID__ + RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; +#endif // __ANDROID__ or iOS } if (cc->Inputs().HasTag(kInputFrameTagGpu) || @@ -233,9 +246,9 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); } if (use_gpu_) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) RETURN_IF_ERROR(gpu_helper_.Open(cc)); -#endif +#endif // __ANDROID__ or iOS } return ::mediapipe::OkStatus(); @@ -247,6 +260,16 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); std::unique_ptr image_mat; ImageFormat::Format target_format; if (use_gpu_) { +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) + if (!gpu_initialized_) { + RETURN_IF_ERROR( + gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { + RETURN_IF_ERROR(GlSetup(cc)); + return ::mediapipe::OkStatus(); + })); + gpu_initialized_ = true; + } +#endif // __ANDROID__ or iOS RETURN_IF_ERROR(CreateRenderTargetGpu(cc, image_mat)); } else { RETURN_IF_ERROR(CreateRenderTargetCpu(cc, image_mat, &target_format)); @@ -265,21 +288,15 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); } if (use_gpu_) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) // Overlay rendered image in OpenGL, onto a copy of input. uchar* image_mat_ptr = image_mat->data; RETURN_IF_ERROR(gpu_helper_.RunInGlContext( [this, cc, image_mat_ptr]() -> ::mediapipe::Status { - if (!gpu_initialized_) { - RETURN_IF_ERROR(GlSetup(cc)); - gpu_initialized_ = true; - } - RETURN_IF_ERROR(RenderToGpu(cc, image_mat_ptr)); - return ::mediapipe::OkStatus(); })); -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS } else { // Copy the rendered image to output. 
uchar* image_mat_ptr = image_mat->data; @@ -290,26 +307,25 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); } ::mediapipe::Status AnnotationOverlayCalculator::Close(CalculatorContext* cc) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) gpu_helper_.RunInGlContext([this] { if (program_) glDeleteProgram(program_); program_ = 0; if (image_mat_tex_) glDeleteTextures(1, &image_mat_tex_); image_mat_tex_ = 0; }); -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS return ::mediapipe::OkStatus(); } ::mediapipe::Status AnnotationOverlayCalculator::RenderToCpu( - CalculatorContext* cc, const ImageFormat::Format& target_format, uchar* data_image) { auto output_frame = absl::make_unique( target_format, renderer_->GetImageWidth(), renderer_->GetImageHeight()); -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) output_frame->CopyPixelData(target_format, renderer_->GetImageWidth(), renderer_->GetImageHeight(), data_image, ImageFrame::kGlDefaultAlignmentBoundary); @@ -317,7 +333,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); output_frame->CopyPixelData(target_format, renderer_->GetImageWidth(), renderer_->GetImageHeight(), data_image, ImageFrame::kDefaultAlignmentBoundary); -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS cc->Outputs() .Tag(kOutputFrameTag) @@ -328,22 +344,21 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); ::mediapipe::Status AnnotationOverlayCalculator::RenderToGpu( CalculatorContext* cc, uchar* overlay_image) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) // Source and destination textures. const auto& input_frame = cc->Inputs().Tag(kInputFrameTagGpu).Get(); auto input_texture = gpu_helper_.CreateSourceTexture(input_frame); - const int width = input_frame.width(), height = input_frame.height(); auto output_texture = gpu_helper_.CreateDestinationTexture( - width, height, mediapipe::GpuBufferFormat::kBGRA32); + width_, height_, mediapipe::GpuBufferFormat::kBGRA32); // Upload render target to GPU. 
{ glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT); glBindTexture(GL_TEXTURE_2D, image_mat_tex_); - glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, width, height, GL_RGB, + glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, width_, height_, GL_RGB, GL_UNSIGNED_BYTE, overlay_image); glBindTexture(GL_TEXTURE_2D, 0); } @@ -375,7 +390,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); // Cleanup input_texture.Release(); output_texture.Release(); -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS return ::mediapipe::OkStatus(); } @@ -436,7 +451,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); ::mediapipe::Status AnnotationOverlayCalculator::CreateRenderTargetGpu( CalculatorContext* cc, std::unique_ptr& image_mat) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) if (image_frame_available_) { const auto& input_frame = cc->Inputs().Tag(kInputFrameTagGpu).Get(); @@ -446,23 +461,24 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); if (format != mediapipe::ImageFormat::SRGBA) RET_CHECK_FAIL() << "Unsupported GPU input format."; - image_mat = - absl::make_unique(input_frame.height(), input_frame.width(), - CV_8UC3, cv::Scalar(0, 0, 0, 0)); + image_mat = absl::make_unique( + height_, width_, CV_8UC3, + cv::Scalar(kAnnotationBackgroundColor[0], kAnnotationBackgroundColor[1], + kAnnotationBackgroundColor[2])); } else { image_mat = absl::make_unique( options_.canvas_height_px(), options_.canvas_width_px(), CV_8UC3, cv::Scalar(options_.canvas_color().r(), options_.canvas_color().g(), options_.canvas_color().b())); } -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS return ::mediapipe::OkStatus(); } ::mediapipe::Status AnnotationOverlayCalculator::GlRender( CalculatorContext* cc) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -510,14 +526,14 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); glBindVertexArray(0); glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS return ::mediapipe::OkStatus(); } ::mediapipe::Status AnnotationOverlayCalculator::GlSetup( CalculatorContext* cc) { -#if defined(__ANDROID__) +#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, @@ -548,13 +564,14 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); in vec2 sample_coordinate; uniform sampler2D input_frame; uniform sampler2D overlay; + uniform vec3 transparent_color; void main() { vec3 image_pix = texture2D(input_frame, sample_coordinate).rgb; vec3 overlay_pix = texture2D(overlay, sample_coordinate).rgb; vec3 out_pix = image_pix; - float mag = dot(overlay_pix.rgb, vec3(1.0)); - if (mag > 0.0) out_pix = overlay_pix; + float dist = distance(overlay_pix.rgb, transparent_color); + if (dist > 0.001) out_pix = overlay_pix; fragColor.rgb = out_pix; fragColor.a = 1.0; } @@ -568,15 +585,23 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); glUseProgram(program_); glUniform1i(glGetUniformLocation(program_, "input_frame"), 1); glUniform1i(glGetUniformLocation(program_, "overlay"), 2); + glUniform3f(glGetUniformLocation(program_, "transparent_color"), + kAnnotationBackgroundColor[0] / 255.0, + kAnnotationBackgroundColor[1] / 255.0, + kAnnotationBackgroundColor[2] / 255.0); // Init texture for opencv rendered frame. 
const auto& input_frame = cc->Inputs().Tag(kInputFrameTagGpu).Get(); - const int width = input_frame.width(), height = input_frame.height(); + // Ensure GPU texture is divisible by 4. See b/138751944 for more info. + width_ = + RoundUp(input_frame.width(), ImageFrame::kGlDefaultAlignmentBoundary); + height_ = + RoundUp(input_frame.height(), ImageFrame::kGlDefaultAlignmentBoundary); { glGenTextures(1, &image_mat_tex_); glBindTexture(GL_TEXTURE_2D, image_mat_tex_); - glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB8, width, height, 0, GL_RGB, + glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB8, width_, height_, 0, GL_RGB, GL_UNSIGNED_BYTE, nullptr); glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST); glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST); @@ -584,7 +609,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator); glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); glBindTexture(GL_TEXTURE_2D, 0); } -#endif // __ANDROID__ +#endif // __ANDROID__ or iOS return ::mediapipe::OkStatus(); } diff --git a/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc b/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc index 107e08148..1c33b45a9 100644 --- a/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc +++ b/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc @@ -67,6 +67,8 @@ REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator); ::mediapipe::Status DetectionLabelIdToTextCalculator::Open( CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + const auto& options = cc->Options<::mediapipe::DetectionLabelIdToTextCalculatorOptions>(); diff --git a/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc b/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc index 27dfc333d..cf3761010 100644 --- a/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc +++ b/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc @@ -83,6 +83,12 @@ class DetectionLetterboxRemovalCalculator : public CalculatorBase { return ::mediapipe::OkStatus(); } + ::mediapipe::Status Open(CalculatorContext* cc) override { + cc->SetOffset(TimestampDiff(0)); + + return ::mediapipe::OkStatus(); + } + ::mediapipe::Status Process(CalculatorContext* cc) override { // Only process if there's input detections. if (cc->Inputs().Tag(kDetectionsTag).IsEmpty()) { diff --git a/mediapipe/calculators/util/detections_to_rects_calculator.cc b/mediapipe/calculators/util/detections_to_rects_calculator.cc new file mode 100644 index 000000000..4dbf90054 --- /dev/null +++ b/mediapipe/calculators/util/detections_to_rects_calculator.cc @@ -0,0 +1,314 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
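+
+// Quick worked example of the bounding-box-to-rect conversion implemented
+// below (a sketch for illustration only; the numbers mirror the
+// DetectionToRect unit test added later in this change). Given a Detection
+// whose LocationData is a BOUNDING_BOX of
+//   {xmin: 100, ymin: 200, width: 300, height: 400},
+// the resulting Rect is
+//   x_center = 100 + 300 / 2 = 250
+//   y_center = 200 + 400 / 2 = 400
+//   width    = 300
+//   height   = 400.
+// The RELATIVE_BOUNDING_BOX variant follows the same arithmetic in
+// normalized [0, 1] coordinates.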
+#include + +#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_options.pb.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location_data.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { + +using mediapipe::DetectionsToRectsCalculatorOptions; + +namespace { + +constexpr char kDetectionTag[] = "DETECTION"; +constexpr char kDetectionsTag[] = "DETECTIONS"; +constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +constexpr char kRectTag[] = "RECT"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kRectsTag[] = "RECTS"; +constexpr char kNormRectsTag[] = "NORM_RECTS"; + +::mediapipe::Status DetectionToRect(const Detection& detection, Rect* rect) { + const LocationData location_data = detection.location_data(); + RET_CHECK(location_data.format() == LocationData::BOUNDING_BOX) + << "Only Detection with formats of BOUNDING_BOX can be converted to Rect"; + const LocationData::BoundingBox bounding_box = location_data.bounding_box(); + rect->set_x_center(bounding_box.xmin() + bounding_box.width() / 2); + rect->set_y_center(bounding_box.ymin() + bounding_box.height() / 2); + rect->set_width(bounding_box.width()); + rect->set_height(bounding_box.height()); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status DetectionToNormalizedRect(const Detection& detection, + NormalizedRect* rect) { + const LocationData location_data = detection.location_data(); + RET_CHECK(location_data.format() == LocationData::RELATIVE_BOUNDING_BOX) + << "Only Detection with formats of RELATIVE_BOUNDING_BOX can be " + "converted to NormalizedRect"; + const LocationData::RelativeBoundingBox bounding_box = + location_data.relative_bounding_box(); + rect->set_x_center(bounding_box.xmin() + bounding_box.width() / 2); + rect->set_y_center(bounding_box.ymin() + bounding_box.height() / 2); + rect->set_width(bounding_box.width()); + rect->set_height(bounding_box.height()); + return ::mediapipe::OkStatus(); +} + +// Wraps around an angle in radians to within -M_PI and M_PI. +inline float NormalizeRadians(float angle) { + return angle - 2 * M_PI * std::floor((angle - (-M_PI)) / (2 * M_PI)); +} + +} // namespace + +// A calculator that converts Detection proto to Rect proto. +// +// Detection is the format for encoding one or more detections in an image. +// The input can be a single Detection or std::vector. The output can +// be either a single Rect or NormalizedRect, or std::vector or +// std::vector. If Rect is used, the LocationData format is +// expected to be BOUNDING_BOX, and if NormalizedRect is used it is expected to +// be RELATIVE_BOUNDING_BOX. +// +// When the input is std::vector and the output is a Rect or +// NormalizedRect, only the first detection is converted. When the input is a +// single Detection and the output is a std::vector or +// std::vector, the output is a vector of size 1. +// +// Inputs: +// +// One of the following: +// DETECTION: A Detection proto. +// DETECTIONS: An std::vector. +// +// IMAGE_SIZE (optional): A std::pair represention image width and +// height. This is required only when rotation needs to be computed (see +// calculator options). +// +// Output: +// One of the following: +// RECT: A Rect proto. +// NORM_RECT: A NormalizedRect proto. +// RECTS: An std::vector. 
+// NORM_RECTS: An std::vector. +// +// Example config: +// node { +// calculator: "DetectionsToRectsCalculator" +// input_stream: "DETECTIONS:detections" +// input_stream: "IMAGE_SIZE:image_size" +// output_stream: "NORM_RECT:rect" +// options: { +// [mediapipe.DetectionsToRectCalculatorOptions.ext] { +// rotation_vector_start_keypoint_index: 0 +// rotation_vector_end_keypoint_index: 2 +// rotation_vector_target_angle_degrees: 90 +// output_zero_rect_for_empty_detections: true +// } +// } +// } +class DetectionsToRectsCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + float ComputeRotation(const Detection& detection, + const std::pair image_size); + + DetectionsToRectsCalculatorOptions options_; + int start_keypoint_index_; + int end_keypoint_index_; + float target_angle_; // In radians. + bool rotate_; + bool output_zero_rect_for_empty_detections_; +}; +REGISTER_CALCULATOR(DetectionsToRectsCalculator); + +::mediapipe::Status DetectionsToRectsCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag(kDetectionTag) ^ + cc->Inputs().HasTag(kDetectionsTag)) + << "Exactly one of DETECTION or DETECTIONS input stream should be " + "provided."; + RET_CHECK_EQ((cc->Outputs().HasTag(kNormRectTag) ? 1 : 0) + + (cc->Outputs().HasTag(kRectTag) ? 1 : 0) + + (cc->Outputs().HasTag(kNormRectsTag) ? 1 : 0) + + (cc->Outputs().HasTag(kRectsTag) ? 1 : 0), + 1) + << "Exactly one of NORM_RECT, RECT, NORM_RECTS or RECTS output stream " + "should be provided."; + + if (cc->Inputs().HasTag(kDetectionTag)) { + cc->Inputs().Tag(kDetectionTag).Set(); + } + if (cc->Inputs().HasTag(kDetectionsTag)) { + cc->Inputs().Tag(kDetectionsTag).Set>(); + } + if (cc->Inputs().HasTag(kImageSizeTag)) { + cc->Inputs().Tag(kImageSizeTag).Set>(); + } + + if (cc->Outputs().HasTag(kRectTag)) { + cc->Outputs().Tag(kRectTag).Set(); + } + if (cc->Outputs().HasTag(kNormRectTag)) { + cc->Outputs().Tag(kNormRectTag).Set(); + } + if (cc->Outputs().HasTag(kRectsTag)) { + cc->Outputs().Tag(kRectsTag).Set>(); + } + if (cc->Outputs().HasTag(kNormRectsTag)) { + cc->Outputs().Tag(kNormRectsTag).Set>(); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status DetectionsToRectsCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + options_ = cc->Options(); + + if (options_.has_rotation_vector_start_keypoint_index()) { + RET_CHECK(options_.has_rotation_vector_end_keypoint_index()); + RET_CHECK(options_.has_rotation_vector_target_angle() ^ + options_.has_rotation_vector_target_angle_degrees()); + RET_CHECK(cc->Inputs().HasTag(kImageSizeTag)); + + if (options_.has_rotation_vector_target_angle()) { + target_angle_ = options_.rotation_vector_target_angle(); + } else { + target_angle_ = + M_PI * options_.rotation_vector_target_angle_degrees() / 180.f; + } + start_keypoint_index_ = options_.rotation_vector_start_keypoint_index(); + end_keypoint_index_ = options_.rotation_vector_end_keypoint_index(); + rotate_ = true; + } + + output_zero_rect_for_empty_detections_ = + options_.output_zero_rect_for_empty_detections(); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status DetectionsToRectsCalculator::Process( + CalculatorContext* cc) { + if (cc->Inputs().HasTag(kDetectionTag) && + cc->Inputs().Tag(kDetectionTag).IsEmpty()) { + return ::mediapipe::OkStatus(); + } + if 
(cc->Inputs().HasTag(kDetectionsTag) && + cc->Inputs().Tag(kDetectionsTag).IsEmpty()) { + return ::mediapipe::OkStatus(); + } + + std::vector detections; + if (cc->Inputs().HasTag(kDetectionTag)) { + detections.push_back(cc->Inputs().Tag(kDetectionTag).Get()); + } + if (cc->Inputs().HasTag(kDetectionsTag)) { + detections = cc->Inputs().Tag(kDetectionsTag).Get>(); + if (detections.empty()) { + if (output_zero_rect_for_empty_detections_) { + if (cc->Outputs().HasTag(kRectTag)) { + cc->Outputs().Tag(kRectTag).AddPacket( + MakePacket().At(cc->InputTimestamp())); + } + if (cc->Outputs().HasTag(kNormRectTag)) { + cc->Outputs() + .Tag(kNormRectTag) + .AddPacket(MakePacket().At(cc->InputTimestamp())); + } + } + return ::mediapipe::OkStatus(); + } + } + + std::pair image_size; + if (rotate_) { + RET_CHECK(!cc->Inputs().Tag(kImageSizeTag).IsEmpty()); + image_size = cc->Inputs().Tag(kImageSizeTag).Get>(); + } + + if (cc->Outputs().HasTag(kRectTag)) { + auto output_rect = absl::make_unique(); + RETURN_IF_ERROR(DetectionToRect(detections[0], output_rect.get())); + if (rotate_) { + output_rect->set_rotation(ComputeRotation(detections[0], image_size)); + } + cc->Outputs().Tag(kRectTag).Add(output_rect.release(), + cc->InputTimestamp()); + } + if (cc->Outputs().HasTag(kNormRectTag)) { + auto output_rect = absl::make_unique(); + RETURN_IF_ERROR( + DetectionToNormalizedRect(detections[0], output_rect.get())); + if (rotate_) { + output_rect->set_rotation(ComputeRotation(detections[0], image_size)); + } + cc->Outputs() + .Tag(kNormRectTag) + .Add(output_rect.release(), cc->InputTimestamp()); + } + if (cc->Outputs().HasTag(kRectsTag)) { + auto output_rects = absl::make_unique>(detections.size()); + for (int i = 0; i < detections.size(); ++i) { + RETURN_IF_ERROR(DetectionToRect(detections[i], &(output_rects->at(i)))); + if (rotate_) { + output_rects->at(i).set_rotation( + ComputeRotation(detections[i], image_size)); + } + } + cc->Outputs().Tag(kRectsTag).Add(output_rects.release(), + cc->InputTimestamp()); + } + if (cc->Outputs().HasTag(kNormRectsTag)) { + auto output_rects = + absl::make_unique>(detections.size()); + for (int i = 0; i < detections.size(); ++i) { + RETURN_IF_ERROR( + DetectionToNormalizedRect(detections[i], &(output_rects->at(i)))); + if (rotate_) { + output_rects->at(i).set_rotation( + ComputeRotation(detections[i], image_size)); + } + } + cc->Outputs() + .Tag(kNormRectsTag) + .Add(output_rects.release(), cc->InputTimestamp()); + } + + return ::mediapipe::OkStatus(); +} + +float DetectionsToRectsCalculator::ComputeRotation( + const Detection& detection, const std::pair image_size) { + const auto& location_data = detection.location_data(); + const float x0 = location_data.relative_keypoints(start_keypoint_index_).x() * + image_size.first; + const float y0 = location_data.relative_keypoints(start_keypoint_index_).y() * + image_size.second; + const float x1 = location_data.relative_keypoints(end_keypoint_index_).x() * + image_size.first; + const float y1 = location_data.relative_keypoints(end_keypoint_index_).y() * + image_size.second; + + float rotation = target_angle_ - std::atan2(-(y1 - y0), x1 - x0); + + return NormalizeRadians(rotation); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/detections_to_rects_calculator.proto b/mediapipe/calculators/util/detections_to_rects_calculator.proto new file mode 100644 index 000000000..8d1a49a1e --- /dev/null +++ b/mediapipe/calculators/util/detections_to_rects_calculator.proto @@ -0,0 +1,38 @@ +// Copyright 2019 The MediaPipe 
Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message DetectionsToRectsCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional DetectionsToRectsCalculatorOptions ext = 262691807; + } + + // Specify the rotation angle of the output rect with a vector formed by + // connecting two keypoints in the detection, together with the target angle + // (can be in radians or in degrees) of that vector after rotation. The target + // angle is counter-clockwise starting from the positive x-axis. + optional int32 rotation_vector_start_keypoint_index = 1; + optional int32 rotation_vector_end_keypoint_index = 2; + optional float rotation_vector_target_angle = 3; // In radians. + optional float rotation_vector_target_angle_degrees = 4; // In degrees. + + // Whether to output a zero-rect (with origin and size both zero) when the + // input detection vector is empty. + optional bool output_zero_rect_for_empty_detections = 5; +} diff --git a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc new file mode 100644 index 000000000..8b2b8f166 --- /dev/null +++ b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc @@ -0,0 +1,312 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
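+
+// Worked example of the rotation options defined above (an illustrative
+// sketch only, assuming a square image so that scaling x by width and y by
+// height does not change the angle). With
+//   rotation_vector_start_keypoint_index: 0
+//   rotation_vector_end_keypoint_index: 1
+//   rotation_vector_target_angle_degrees: 90
+// and keypoint 0 at (0.5, 0.8), keypoint 1 at (0.5, 0.2), the keypoint
+// vector points straight "up" in image coordinates, so
+//   atan2(-(y1 - y0), x1 - x0) = atan2(0.6, 0) = pi / 2,
+// and the output rect rotation is target_angle - pi / 2 = 0 radians.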
+ +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/message_matchers.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location_data.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +Detection DetectionWithLocationData(int32 xmin, int32 ymin, int32 width, + int32 height) { + Detection detection; + LocationData* location_data = detection.mutable_location_data(); + location_data->set_format(LocationData::BOUNDING_BOX); + location_data->mutable_bounding_box()->set_xmin(xmin); + location_data->mutable_bounding_box()->set_ymin(ymin); + location_data->mutable_bounding_box()->set_width(width); + location_data->mutable_bounding_box()->set_height(height); + return detection; +} + +Detection DetectionWithRelativeLocationData(double xmin, double ymin, + double width, double height) { + Detection detection; + LocationData* location_data = detection.mutable_location_data(); + location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX); + location_data->mutable_relative_bounding_box()->set_xmin(xmin); + location_data->mutable_relative_bounding_box()->set_ymin(ymin); + location_data->mutable_relative_bounding_box()->set_width(width); + location_data->mutable_relative_bounding_box()->set_height(height); + return detection; +} + +TEST(DetectionsToRectsCalculatorTest, DetectionToRect) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTION:detection" + output_stream: "RECT:rect" + )")); + + auto detection = absl::make_unique( + DetectionWithLocationData(100, 200, 300, 400)); + + runner.MutableInputs() + ->Tag("DETECTION") + .packets.push_back( + Adopt(detection.release()).At(Timestamp::PostStream())); + + MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = runner.Outputs().Tag("RECT").packets; + ASSERT_EQ(1, output.size()); + const auto& rect = output[0].Get(); + EXPECT_EQ(rect.width(), 300); + EXPECT_EQ(rect.height(), 400); + EXPECT_EQ(rect.x_center(), 250); + EXPECT_EQ(rect.y_center(), 400); +} + +TEST(DetectionsToRectsCalculatorTest, DetectionToNormalizedRect) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTION:detection" + output_stream: "NORM_RECT:rect" + )")); + + auto detection = absl::make_unique( + DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4)); + + runner.MutableInputs() + ->Tag("DETECTION") + .packets.push_back( + Adopt(detection.release()).At(Timestamp::PostStream())); + + MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = runner.Outputs().Tag("NORM_RECT").packets; + ASSERT_EQ(1, output.size()); + const auto& rect = output[0].Get(); + EXPECT_FLOAT_EQ(rect.width(), 0.3); + EXPECT_FLOAT_EQ(rect.height(), 0.4); + EXPECT_FLOAT_EQ(rect.x_center(), 0.25); + EXPECT_FLOAT_EQ(rect.y_center(), 0.4); +} + +TEST(DetectionsToRectsCalculatorTest, DetectionsToRect) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTIONS:detections" + output_stream: 
"RECT:rect" + )")); + + auto detections(absl::make_unique>()); + detections->push_back(DetectionWithLocationData(100, 200, 300, 400)); + detections->push_back(DetectionWithLocationData(200, 300, 400, 500)); + + runner.MutableInputs() + ->Tag("DETECTIONS") + .packets.push_back( + Adopt(detections.release()).At(Timestamp::PostStream())); + + MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = runner.Outputs().Tag("RECT").packets; + ASSERT_EQ(1, output.size()); + const auto& rect = output[0].Get(); + EXPECT_EQ(rect.width(), 300); + EXPECT_EQ(rect.height(), 400); + EXPECT_EQ(rect.x_center(), 250); + EXPECT_EQ(rect.y_center(), 400); +} + +TEST(DetectionsToRectsCalculatorTest, DetectionsToNormalizedRect) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTIONS:detections" + output_stream: "NORM_RECT:rect" + )")); + + auto detections(absl::make_unique>()); + detections->push_back(DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4)); + detections->push_back(DetectionWithRelativeLocationData(0.2, 0.3, 0.4, 0.5)); + + runner.MutableInputs() + ->Tag("DETECTIONS") + .packets.push_back( + Adopt(detections.release()).At(Timestamp::PostStream())); + + MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = runner.Outputs().Tag("NORM_RECT").packets; + ASSERT_EQ(1, output.size()); + const auto& rect = output[0].Get(); + EXPECT_FLOAT_EQ(rect.width(), 0.3); + EXPECT_FLOAT_EQ(rect.height(), 0.4); + EXPECT_FLOAT_EQ(rect.x_center(), 0.25); + EXPECT_FLOAT_EQ(rect.y_center(), 0.4); +} + +TEST(DetectionsToRectsCalculatorTest, DetectionsToRects) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTIONS:detections" + output_stream: "RECTS:rect" + )")); + + auto detections(absl::make_unique>()); + detections->push_back(DetectionWithLocationData(100, 200, 300, 400)); + detections->push_back(DetectionWithLocationData(200, 300, 400, 500)); + + runner.MutableInputs() + ->Tag("DETECTIONS") + .packets.push_back( + Adopt(detections.release()).At(Timestamp::PostStream())); + + MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = runner.Outputs().Tag("RECTS").packets; + ASSERT_EQ(1, output.size()); + const auto& rects = output[0].Get>(); + EXPECT_EQ(rects.size(), 2); + EXPECT_EQ(rects[0].width(), 300); + EXPECT_EQ(rects[0].height(), 400); + EXPECT_EQ(rects[0].x_center(), 250); + EXPECT_EQ(rects[0].y_center(), 400); + EXPECT_EQ(rects[1].width(), 400); + EXPECT_EQ(rects[1].height(), 500); + EXPECT_EQ(rects[1].x_center(), 400); + EXPECT_EQ(rects[1].y_center(), 550); +} + +TEST(DetectionsToRectsCalculatorTest, DetectionsToNormalizedRects) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTIONS:detections" + output_stream: "NORM_RECTS:rect" + )")); + + auto detections(absl::make_unique>()); + detections->push_back(DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4)); + detections->push_back(DetectionWithRelativeLocationData(0.2, 0.3, 0.4, 0.5)); + + runner.MutableInputs() + ->Tag("DETECTIONS") + .packets.push_back( + Adopt(detections.release()).At(Timestamp::PostStream())); + + MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = + runner.Outputs().Tag("NORM_RECTS").packets; + ASSERT_EQ(1, output.size()); + const auto& rects = output[0].Get>(); + 
EXPECT_EQ(rects.size(), 2); + EXPECT_FLOAT_EQ(rects[0].width(), 0.3); + EXPECT_FLOAT_EQ(rects[0].height(), 0.4); + EXPECT_FLOAT_EQ(rects[0].x_center(), 0.25); + EXPECT_FLOAT_EQ(rects[0].y_center(), 0.4); + EXPECT_FLOAT_EQ(rects[1].width(), 0.4); + EXPECT_FLOAT_EQ(rects[1].height(), 0.5); + EXPECT_FLOAT_EQ(rects[1].x_center(), 0.4); + EXPECT_FLOAT_EQ(rects[1].y_center(), 0.55); +} + +TEST(DetectionsToRectsCalculatorTest, DetectionToRects) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTION:detection" + output_stream: "RECTS:rect" + )")); + + auto detection = absl::make_unique( + DetectionWithLocationData(100, 200, 300, 400)); + + runner.MutableInputs() + ->Tag("DETECTION") + .packets.push_back( + Adopt(detection.release()).At(Timestamp::PostStream())); + + MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = runner.Outputs().Tag("RECTS").packets; + ASSERT_EQ(1, output.size()); + const auto& rects = output[0].Get>(); + EXPECT_EQ(rects.size(), 1); + EXPECT_EQ(rects[0].width(), 300); + EXPECT_EQ(rects[0].height(), 400); + EXPECT_EQ(rects[0].x_center(), 250); + EXPECT_EQ(rects[0].y_center(), 400); +} + +TEST(DetectionsToRectsCalculatorTest, DetectionToNormalizedRects) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTION:detection" + output_stream: "NORM_RECTS:rect" + )")); + + auto detection = absl::make_unique( + DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4)); + + runner.MutableInputs() + ->Tag("DETECTION") + .packets.push_back( + Adopt(detection.release()).At(Timestamp::PostStream())); + + MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = + runner.Outputs().Tag("NORM_RECTS").packets; + ASSERT_EQ(1, output.size()); + const auto& rects = output[0].Get>(); + EXPECT_EQ(rects.size(), 1); + EXPECT_FLOAT_EQ(rects[0].width(), 0.3); + EXPECT_FLOAT_EQ(rects[0].height(), 0.4); + EXPECT_FLOAT_EQ(rects[0].x_center(), 0.25); + EXPECT_FLOAT_EQ(rects[0].y_center(), 0.4); +} + +TEST(DetectionsToRectsCalculatorTest, WrongInputToRect) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTIONS:detections" + output_stream: "RECT:rect" + )")); + + auto detections(absl::make_unique>()); + detections->push_back(DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4)); + + runner.MutableInputs() + ->Tag("DETECTIONS") + .packets.push_back( + Adopt(detections.release()).At(Timestamp::PostStream())); + + ASSERT_THAT( + runner.Run().message(), + testing::HasSubstr("Only Detection with formats of BOUNDING_BOX")); +} + +TEST(DetectionsToRectsCalculatorTest, WrongInputToNormalizedRect) { + CalculatorRunner runner(ParseTextProtoOrDie(R"( + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTIONS:detections" + output_stream: "NORM_RECT:rect" + )")); + + auto detections(absl::make_unique>()); + detections->push_back(DetectionWithLocationData(100, 200, 300, 400)); + + runner.MutableInputs() + ->Tag("DETECTIONS") + .packets.push_back( + Adopt(detections.release()).At(Timestamp::PostStream())); + + ASSERT_THAT(runner.Run().message(), + testing::HasSubstr( + "Only Detection with formats of RELATIVE_BOUNDING_BOX")); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/detections_to_render_data_calculator.cc b/mediapipe/calculators/util/detections_to_render_data_calculator.cc index 
aa4b35089..098334ec7 100644 --- a/mediapipe/calculators/util/detections_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/detections_to_render_data_calculator.cc @@ -27,8 +27,8 @@ namespace mediapipe { namespace { +constexpr char kDetectionsTag[] = "DETECTIONS"; constexpr char kDetectionListTag[] = "DETECTION_LIST"; -constexpr char kDetectionVectorTag[] = "DETECTION_VECTOR"; constexpr char kRenderDataTag[] = "RENDER_DATA"; constexpr char kSceneLabelLabel[] = "LABEL"; @@ -60,8 +60,8 @@ constexpr double kLabelToBoundingBoxRatio = 0.1; // Example config: // node { // calculator: "DetectionsToRenderDataCalculator" +// input_stream: "DETECTIONS:detections" // input_stream: "DETECTION_LIST:detection_list" -// input_stream: "DETECTION_VECTOR:detection_vector" // output_stream: "RENDER_DATA:render_data" // options { // [DetectionsToRenderDataCalculatorOptions.ext] { @@ -80,6 +80,8 @@ class DetectionsToRenderDataCalculator : public CalculatorBase { static ::mediapipe::Status GetContract(CalculatorContract* cc); + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; private: @@ -119,19 +121,26 @@ REGISTER_CALCULATOR(DetectionsToRenderDataCalculator); ::mediapipe::Status DetectionsToRenderDataCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kDetectionListTag) || - cc->Inputs().HasTag(kDetectionVectorTag)) + cc->Inputs().HasTag(kDetectionsTag)) << "None of the input streams are provided."; if (cc->Inputs().HasTag(kDetectionListTag)) { cc->Inputs().Tag(kDetectionListTag).Set(); } - if (cc->Inputs().HasTag(kDetectionVectorTag)) { - cc->Inputs().Tag(kDetectionVectorTag).Set>(); + if (cc->Inputs().HasTag(kDetectionsTag)) { + cc->Inputs().Tag(kDetectionsTag).Set>(); } cc->Outputs().Tag(kRenderDataTag).Set(); return ::mediapipe::OkStatus(); } +::mediapipe::Status DetectionsToRenderDataCalculator::Open( + CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + return ::mediapipe::OkStatus(); +} + ::mediapipe::Status DetectionsToRenderDataCalculator::Process( CalculatorContext* cc) { const auto& options = cc->Options(); @@ -142,11 +151,8 @@ REGISTER_CALCULATOR(DetectionsToRenderDataCalculator); .detection() .empty(); const bool has_detection_from_vector = - cc->Inputs().HasTag(kDetectionVectorTag) && - !cc->Inputs() - .Tag(kDetectionVectorTag) - .Get>() - .empty(); + cc->Inputs().HasTag(kDetectionsTag) && + !cc->Inputs().Tag(kDetectionsTag).Get>().empty(); if (!options.produce_empty_packet() && !has_detection_from_list && !has_detection_from_vector) { return ::mediapipe::OkStatus(); @@ -164,7 +170,7 @@ REGISTER_CALCULATOR(DetectionsToRenderDataCalculator); } if (has_detection_from_vector) { for (const auto& detection : - cc->Inputs().Tag(kDetectionVectorTag).Get>()) { + cc->Inputs().Tag(kDetectionsTag).Get>()) { AddDetectionToRenderData(detection, options, render_data.get()); } } diff --git a/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc b/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc index f15fec3d0..23a4d7874 100644 --- a/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc +++ b/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc @@ -121,7 +121,7 @@ TEST(DetectionsToRenderDataCalculatorTest, OnlyDetecctionList) { TEST(DetectionsToRenderDataCalculatorTest, OnlyDetecctionVector) { CalculatorRunner runner{ParseTextProtoOrDie(R"( calculator: "DetectionsToRenderDataCalculator" - input_stream: 
"DETECTION_VECTOR:detection_vector" + input_stream: "DETECTIONS:detections" output_stream: "RENDER_DATA:render_data" )")}; @@ -131,7 +131,7 @@ TEST(DetectionsToRenderDataCalculatorTest, OnlyDetecctionVector) { CreateDetection({"label1"}, {}, {0.3}, location_data, "feature_tag")); runner.MutableInputs() - ->Tag("DETECTION_VECTOR") + ->Tag("DETECTIONS") .packets.push_back( Adopt(detections.release()).At(Timestamp::PostStream())); @@ -156,7 +156,7 @@ TEST(DetectionsToRenderDataCalculatorTest, BothDetecctionListAndVector) { CalculatorRunner runner{ParseTextProtoOrDie(R"( calculator: "DetectionsToRenderDataCalculator" input_stream: "DETECTION_LIST:detection_list" - input_stream: "DETECTION_VECTOR:detection_vector" + input_stream: "DETECTIONS:detections" output_stream: "RENDER_DATA:render_data" )")}; @@ -170,13 +170,13 @@ TEST(DetectionsToRenderDataCalculatorTest, BothDetecctionListAndVector) { Adopt(detection_list.release()).At(Timestamp::PostStream())); LocationData location_data2 = CreateLocationData(600, 700, 800, 900); - auto detection_vector(absl::make_unique>()); - detection_vector->push_back( + auto detections(absl::make_unique>()); + detections->push_back( CreateDetection({"label2"}, {}, {0.6}, location_data2, "feature_tag2")); runner.MutableInputs() - ->Tag("DETECTION_VECTOR") + ->Tag("DETECTIONS") .packets.push_back( - Adopt(detection_vector.release()).At(Timestamp::PostStream())); + Adopt(detections.release()).At(Timestamp::PostStream())); MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed."; const std::vector& actual = @@ -197,7 +197,7 @@ TEST(DetectionsToRenderDataCalculatorTest, ProduceEmptyPacket) { CalculatorRunner runner1{ParseTextProtoOrDie(R"( calculator: "DetectionsToRenderDataCalculator" input_stream: "DETECTION_LIST:detection_list" - input_stream: "DETECTION_VECTOR:detection_vector" + input_stream: "DETECTIONS:detections" output_stream: "RENDER_DATA:render_data" options { [mediapipe.DetectionsToRenderDataCalculatorOptions.ext] { @@ -212,11 +212,11 @@ TEST(DetectionsToRenderDataCalculatorTest, ProduceEmptyPacket) { .packets.push_back( Adopt(detection_list1.release()).At(Timestamp::PostStream())); - auto detection_vector1(absl::make_unique>()); + auto detections1(absl::make_unique>()); runner1.MutableInputs() - ->Tag("DETECTION_VECTOR") + ->Tag("DETECTIONS") .packets.push_back( - Adopt(detection_vector1.release()).At(Timestamp::PostStream())); + Adopt(detections1.release()).At(Timestamp::PostStream())); MEDIAPIPE_ASSERT_OK(runner1.Run()) << "Calculator execution failed."; const std::vector& exact1 = @@ -227,7 +227,7 @@ TEST(DetectionsToRenderDataCalculatorTest, ProduceEmptyPacket) { CalculatorRunner runner2{ParseTextProtoOrDie(R"( calculator: "DetectionsToRenderDataCalculator" input_stream: "DETECTION_LIST:detection_list" - input_stream: "DETECTION_VECTOR:detection_vector" + input_stream: "DETECTIONS:detections" output_stream: "RENDER_DATA:render_data" options { [mediapipe.DetectionsToRenderDataCalculatorOptions.ext] { @@ -242,11 +242,11 @@ TEST(DetectionsToRenderDataCalculatorTest, ProduceEmptyPacket) { .packets.push_back( Adopt(detection_list2.release()).At(Timestamp::PostStream())); - auto detection_vector2(absl::make_unique>()); + auto detections2(absl::make_unique>()); runner2.MutableInputs() - ->Tag("DETECTION_VECTOR") + ->Tag("DETECTIONS") .packets.push_back( - Adopt(detection_vector2.release()).At(Timestamp::PostStream())); + Adopt(detections2.release()).At(Timestamp::PostStream())); MEDIAPIPE_ASSERT_OK(runner2.Run()) << "Calculator execution failed."; 
const std::vector& exact2 = diff --git a/mediapipe/calculators/util/landmark_letterbox_removal_calculator.cc b/mediapipe/calculators/util/landmark_letterbox_removal_calculator.cc new file mode 100644 index 000000000..fd22cf191 --- /dev/null +++ b/mediapipe/calculators/util/landmark_letterbox_removal_calculator.cc @@ -0,0 +1,130 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { + +namespace { + +constexpr char kLandmarksTag[] = "LANDMARKS"; +constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING"; + +} // namespace + +// Adjusts landmark locations on a letterboxed image to the corresponding +// locations on the same image with the letterbox removed. This is useful to map +// the landmarks inferred from a letterboxed image, for example, output of +// the ImageTransformationCalculator when the scale mode is FIT, back to the +// corresponding input image before letterboxing. +// +// Input: +// LANDMARKS: An std::vector representing landmarks on an +// letterboxed image. +// +// LETTERBOX_PADDING: An std::array representing the letterbox +// padding from the 4 sides ([left, top, right, bottom]) of the letterboxed +// image, normalized to [0.f, 1.f] by the letterboxed image dimensions. +// +// Output: +// LANDMARKS: An std::vector representing landmarks with +// their locations adjusted to the letterbox-removed (non-padded) image. 
+// +// Usage example: +// node { +// calculator: "LandmarkLetterboxRemovalCalculator" +// input_stream: "LANDMARKS:landmarks" +// input_stream: "LETTERBOX_PADDING:letterbox_padding" +// output_stream: "LANDMARKS:adjusted_landmarks" +// } +class LandmarkLetterboxRemovalCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag(kLandmarksTag) && + cc->Inputs().HasTag(kLetterboxPaddingTag)) + << "Missing one or more input streams."; + + cc->Inputs().Tag(kLandmarksTag).Set>(); + cc->Inputs().Tag(kLetterboxPaddingTag).Set>(); + + cc->Outputs().Tag(kLandmarksTag).Set>(); + + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + cc->SetOffset(TimestampDiff(0)); + + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + // Only process if there's input landmarks. + if (cc->Inputs().Tag(kLandmarksTag).IsEmpty()) { + return ::mediapipe::OkStatus(); + } + + const auto& input_landmarks = + cc->Inputs().Tag(kLandmarksTag).Get>(); + const auto& letterbox_padding = + cc->Inputs().Tag(kLetterboxPaddingTag).Get>(); + + const float left = letterbox_padding[0]; + const float top = letterbox_padding[1]; + const float left_and_right = letterbox_padding[0] + letterbox_padding[2]; + const float top_and_bottom = letterbox_padding[1] + letterbox_padding[3]; + + auto output_landmarks = + absl::make_unique>(); + for (const auto& landmark : input_landmarks) { + NormalizedLandmark new_landmark; + const float new_x = (landmark.x() - left) / (1.0f - left_and_right); + const float new_y = (landmark.y() - top) / (1.0f - top_and_bottom); + + new_landmark.set_x(new_x); + new_landmark.set_y(new_y); + // Keep z-coord as is. + new_landmark.set_z(landmark.z()); + + output_landmarks->emplace_back(new_landmark); + } + + cc->Outputs() + .Tag(kLandmarksTag) + .Add(output_landmarks.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(LandmarkLetterboxRemovalCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/landmark_letterbox_removal_calculator_test.cc b/mediapipe/calculators/util/landmark_letterbox_removal_calculator_test.cc new file mode 100644 index 000000000..8bea2c54d --- /dev/null +++ b/mediapipe/calculators/util/landmark_letterbox_removal_calculator_test.cc @@ -0,0 +1,111 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
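+
+// Worked example of the letterbox-removal mapping exercised by the tests
+// below (for illustration; the numbers mirror the PaddingLeftRight case).
+// With LETTERBOX_PADDING = [0.2, 0.0, 0.3, 0.0] (left, top, right, bottom),
+// a landmark at x = 0.5 on the letterboxed image maps to
+//   new_x = (0.5 - 0.2) / (1.0 - (0.2 + 0.3)) = 0.3 / 0.5 = 0.6,
+// while y is unchanged because there is no vertical padding.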
+ +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/validate_type.h" + +namespace mediapipe { + +NormalizedLandmark CreateLandmark(float x, float y) { + NormalizedLandmark landmark; + landmark.set_x(x); + landmark.set_y(y); + return landmark; +} + +CalculatorGraphConfig::Node GetDefaultNode() { + return ParseTextProtoOrDie(R"( + calculator: "LandmarkLetterboxRemovalCalculator" + input_stream: "LANDMARKS:landmarks" + input_stream: "LETTERBOX_PADDING:letterbox_padding" + output_stream: "LANDMARKS:adjusted_landmarks" + )"); +} + +TEST(LandmarkLetterboxRemovalCalculatorTest, PaddingLeftRight) { + CalculatorRunner runner(GetDefaultNode()); + + auto landmarks = absl::make_unique>(); + landmarks->push_back(CreateLandmark(0.5f, 0.5f)); + landmarks->push_back(CreateLandmark(0.2f, 0.2f)); + landmarks->push_back(CreateLandmark(0.7f, 0.7f)); + runner.MutableInputs() + ->Tag("LANDMARKS") + .packets.push_back( + Adopt(landmarks.release()).At(Timestamp::PostStream())); + + auto padding = absl::make_unique>( + std::array{0.2f, 0.f, 0.3f, 0.f}); + runner.MutableInputs() + ->Tag("LETTERBOX_PADDING") + .packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream())); + + MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = runner.Outputs().Tag("LANDMARKS").packets; + ASSERT_EQ(1, output.size()); + const auto& output_landmarks = + output[0].Get>(); + + EXPECT_EQ(output_landmarks.size(), 3); + + EXPECT_THAT(output_landmarks[0].x(), testing::FloatNear(0.6f, 1e-5)); + EXPECT_THAT(output_landmarks[0].y(), testing::FloatNear(0.5f, 1e-5)); + EXPECT_THAT(output_landmarks[1].x(), testing::FloatNear(0.0f, 1e-5)); + EXPECT_THAT(output_landmarks[1].y(), testing::FloatNear(0.2f, 1e-5)); + EXPECT_THAT(output_landmarks[2].x(), testing::FloatNear(1.0f, 1e-5)); + EXPECT_THAT(output_landmarks[2].y(), testing::FloatNear(0.7f, 1e-5)); +} + +TEST(LandmarkLetterboxRemovalCalculatorTest, PaddingTopBottom) { + CalculatorRunner runner(GetDefaultNode()); + + auto landmarks = absl::make_unique>(); + landmarks->push_back(CreateLandmark(0.5f, 0.5f)); + landmarks->push_back(CreateLandmark(0.2f, 0.2f)); + landmarks->push_back(CreateLandmark(0.7f, 0.7f)); + runner.MutableInputs() + ->Tag("LANDMARKS") + .packets.push_back( + Adopt(landmarks.release()).At(Timestamp::PostStream())); + + auto padding = absl::make_unique>( + std::array{0.0f, 0.2f, 0.0f, 0.3f}); + runner.MutableInputs() + ->Tag("LETTERBOX_PADDING") + .packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream())); + + MEDIAPIPE_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = runner.Outputs().Tag("LANDMARKS").packets; + ASSERT_EQ(1, output.size()); + const auto& output_landmarks = + output[0].Get>(); + + EXPECT_EQ(output_landmarks.size(), 3); + + EXPECT_THAT(output_landmarks[0].x(), testing::FloatNear(0.5f, 1e-5)); + EXPECT_THAT(output_landmarks[0].y(), testing::FloatNear(0.6f, 1e-5)); + EXPECT_THAT(output_landmarks[1].x(), testing::FloatNear(0.2f, 1e-5)); + EXPECT_THAT(output_landmarks[1].y(), testing::FloatNear(0.0f, 1e-5)); + EXPECT_THAT(output_landmarks[2].x(), 
testing::FloatNear(0.7f, 1e-5)); + EXPECT_THAT(output_landmarks[2].y(), testing::FloatNear(1.0f, 1e-5)); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/landmark_projection_calculator.cc b/mediapipe/calculators/util/landmark_projection_calculator.cc new file mode 100644 index 000000000..39ac61f2e --- /dev/null +++ b/mediapipe/calculators/util/landmark_projection_calculator.cc @@ -0,0 +1,129 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "mediapipe/calculators/util/landmark_projection_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { + +namespace { + +constexpr char kLandmarksTag[] = "NORM_LANDMARKS"; +constexpr char kRectTag[] = "NORM_RECT"; + +} // namespace + +// Projects normalized landmarks in a rectangle to its original coordinates. The +// rectangle must also be in normalized coordinates. +// Input: +// NORM_LANDMARKS: An std::vector representing landmarks +// in a normalized rectangle. +// NORM_RECT: An NormalizedRect representing a normalized rectangle in image +// coordinates. +// +// Output: +// NORM_LANDMARKS: An std::vector representing landmarks +// with their locations adjusted to the image. +// +// Usage example: +// node { +// calculator: "LandmarkProjectionCalculator" +// input_stream: "NORM_LANDMARKS:landmarks" +// input_stream: "NORM_RECT:rect" +// output_stream: "NORM_LANDMARKS:projected_landmarks" +// } +class LandmarkProjectionCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag(kLandmarksTag) && + cc->Inputs().HasTag(kRectTag)) + << "Missing one or more input streams."; + + cc->Inputs().Tag(kLandmarksTag).Set>(); + cc->Inputs().Tag(kRectTag).Set(); + + cc->Outputs().Tag(kLandmarksTag).Set>(); + + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) override { + cc->SetOffset(TimestampDiff(0)); + + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + const auto& options = + cc->Options<::mediapipe::LandmarkProjectionCalculatorOptions>(); + // Only process if there's input landmarks. 
+ if (cc->Inputs().Tag(kLandmarksTag).IsEmpty()) { + return ::mediapipe::OkStatus(); + } + + const auto& input_landmarks = + cc->Inputs().Tag(kLandmarksTag).Get>(); + const auto& input_rect = cc->Inputs().Tag(kRectTag).Get(); + + auto output_landmarks = + absl::make_unique>(); + for (const auto& landmark : input_landmarks) { + NormalizedLandmark new_landmark; + + const float x = landmark.x() - 0.5f; + const float y = landmark.y() - 0.5f; + const float angle = options.ignore_rotation() ? 0 : input_rect.rotation(); + float new_x = std::cos(angle) * x - std::sin(angle) * y; + float new_y = std::sin(angle) * x + std::cos(angle) * y; + + new_x = new_x * input_rect.width() + input_rect.x_center(); + new_y = new_y * input_rect.height() + input_rect.y_center(); + + new_landmark.set_x(new_x); + new_landmark.set_y(new_y); + // Keep z-coord as is. + new_landmark.set_z(landmark.z()); + + output_landmarks->emplace_back(new_landmark); + } + + cc->Outputs() + .Tag(kLandmarksTag) + .Add(output_landmarks.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(LandmarkProjectionCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/landmark_projection_calculator.proto b/mediapipe/calculators/util/landmark_projection_calculator.proto new file mode 100644 index 000000000..221adbe00 --- /dev/null +++ b/mediapipe/calculators/util/landmark_projection_calculator.proto @@ -0,0 +1,28 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message LandmarkProjectionCalculatorOptions { + extend CalculatorOptions { + optional LandmarkProjectionCalculatorOptions ext = 263371892; + } + + // Ignore the rotation field of rect proto for projection. + optional bool ignore_rotation = 1 [default = false]; +} diff --git a/mediapipe/calculators/util/landmarks_to_detection_calculator.cc b/mediapipe/calculators/util/landmarks_to_detection_calculator.cc new file mode 100644 index 000000000..ca71ac377 --- /dev/null +++ b/mediapipe/calculators/util/landmarks_to_detection_calculator.cc @@ -0,0 +1,141 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
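+
+// Worked example for the LandmarkProjectionCalculator above (a sketch,
+// assuming the rect has no rotation or ignore_rotation is set): for
+// NORM_RECT {x_center: 0.5 y_center: 0.5 width: 0.5 height: 0.5}, a rect-space
+// landmark at (0.25, 0.75) is re-centered to (-0.25, 0.25), scaled by the rect
+// size and shifted by the rect center, giving image-space coordinates
+//   x = -0.25 * 0.5 + 0.5 = 0.375,  y = 0.25 * 0.5 + 0.5 = 0.625.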
+
+#include
+
+#include "mediapipe/calculators/util/landmarks_to_detection_calculator.pb.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/formats/detection.pb.h"
+#include "mediapipe/framework/formats/landmark.pb.h"
+#include "mediapipe/framework/formats/location_data.pb.h"
+#include "mediapipe/framework/port/ret_check.h"
+
+namespace mediapipe {
+
+namespace {
+
+constexpr char kDetectionTag[] = "DETECTION";
+constexpr char kNormalizedLandmarksTag[] = "NORM_LANDMARKS";
+
+Detection ConvertLandmarksToDetection(
+    const std::vector<NormalizedLandmark>& landmarks) {
+  Detection detection;
+  LocationData* location_data = detection.mutable_location_data();
+
+  float x_min = std::numeric_limits<float>::max();
+  float x_max = std::numeric_limits<float>::min();
+  float y_min = std::numeric_limits<float>::max();
+  float y_max = std::numeric_limits<float>::min();
+  for (const auto& landmark : landmarks) {
+    x_min = std::min(x_min, landmark.x());
+    x_max = std::max(x_max, landmark.x());
+    y_min = std::min(y_min, landmark.y());
+    y_max = std::max(y_max, landmark.y());
+
+    auto keypoint = location_data->add_relative_keypoints();
+    keypoint->set_x(landmark.x());
+    keypoint->set_y(landmark.y());
+  }
+
+  location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX);
+  LocationData::RelativeBoundingBox* relative_bbox =
+      location_data->mutable_relative_bounding_box();
+
+  relative_bbox->set_xmin(x_min);
+  relative_bbox->set_ymin(y_min);
+  relative_bbox->set_width(x_max - x_min);
+  relative_bbox->set_height(y_max - y_min);
+
+  return detection;
+}
+
+}  // namespace
+
+// Converts NormalizedLandmark to Detection proto. A relative bounding box will
+// be created containing all landmarks exactly. A calculator option is provided
+// to specify a subset of landmarks for creating the detection.
+//
+// Input:
+//   NORM_LANDMARKS: A vector of NormalizedLandmark.
+//
+// Output:
+//   DETECTION: A Detection proto.
+//
+// Example config:
+// node {
+//   calculator: "LandmarksToDetectionCalculator"
+//   input_stream: "NORM_LANDMARKS:landmarks"
+//   output_stream: "DETECTION:detection"
+// }
+class LandmarksToDetectionCalculator : public CalculatorBase {
+ public:
+  static ::mediapipe::Status GetContract(CalculatorContract* cc);
+  ::mediapipe::Status Open(CalculatorContext* cc) override;
+
+  ::mediapipe::Status Process(CalculatorContext* cc) override;
+
+ private:
+  ::mediapipe::LandmarksToDetectionCalculatorOptions options_;
+};
+REGISTER_CALCULATOR(LandmarksToDetectionCalculator);
+
+::mediapipe::Status LandmarksToDetectionCalculator::GetContract(
+    CalculatorContract* cc) {
+  RET_CHECK(cc->Inputs().HasTag(kNormalizedLandmarksTag));
+  RET_CHECK(cc->Outputs().HasTag(kDetectionTag));
+  // TODO: Also support converting Landmark to Detection.
+ cc->Inputs() + .Tag(kNormalizedLandmarksTag) + .Set>(); + cc->Outputs().Tag(kDetectionTag).Set(); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status LandmarksToDetectionCalculator::Open( + CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + options_ = cc->Options<::mediapipe::LandmarksToDetectionCalculatorOptions>(); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status LandmarksToDetectionCalculator::Process( + CalculatorContext* cc) { + const auto& landmarks = cc->Inputs() + .Tag(kNormalizedLandmarksTag) + .Get>(); + RET_CHECK_GT(landmarks.size(), 0) << "Input landmark vector is empty."; + + auto detection = absl::make_unique(); + if (options_.selected_landmark_indices_size()) { + std::vector subset_landmarks( + options_.selected_landmark_indices_size()); + for (int i = 0; i < subset_landmarks.size(); ++i) { + RET_CHECK_LT(options_.selected_landmark_indices(i), landmarks.size()) + << "Index of landmark subset is out of range."; + subset_landmarks[i] = landmarks[options_.selected_landmark_indices(i)]; + } + *detection = ConvertLandmarksToDetection(subset_landmarks); + } else { + *detection = ConvertLandmarksToDetection(landmarks); + } + cc->Outputs() + .Tag(kDetectionTag) + .Add(detection.release(), cc->InputTimestamp()); + + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/landmarks_to_detection_calculator.proto b/mediapipe/calculators/util/landmarks_to_detection_calculator.proto new file mode 100644 index 000000000..b5a563669 --- /dev/null +++ b/mediapipe/calculators/util/landmarks_to_detection_calculator.proto @@ -0,0 +1,28 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message LandmarksToDetectionCalculatorOptions { + extend CalculatorOptions { + optional LandmarksToDetectionCalculatorOptions ext = 260199669; + } + + // A subset of indices to be included when creating the detection. + repeated int32 selected_landmark_indices = 1; +} diff --git a/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc b/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc new file mode 100644 index 000000000..25ffb67ef --- /dev/null +++ b/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc @@ -0,0 +1,318 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
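+
+// An illustrative node config for the LandmarksToDetectionCalculator defined
+// above, using the selected_landmark_indices option from its proto; the
+// indices here are arbitrary placeholders, and the output tag is DETECTION to
+// match kDetectionTag:
+//
+// node {
+//   calculator: "LandmarksToDetectionCalculator"
+//   input_stream: "NORM_LANDMARKS:landmarks"
+//   output_stream: "DETECTION:detection"
+//   options: {
+//     [mediapipe.LandmarksToDetectionCalculatorOptions.ext] {
+//       selected_landmark_indices: 0
+//       selected_landmark_indices: 4
+//       selected_landmark_indices: 8
+//     }
+//   }
+// }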
+ +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_options.pb.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/location_data.pb.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/util/color.pb.h" +#include "mediapipe/util/render_data.pb.h" +namespace mediapipe { + +namespace { + +constexpr char kLandmarksTag[] = "LANDMARKS"; +constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS"; +constexpr char kRenderDataTag[] = "RENDER_DATA"; +constexpr char kLandmarkLabel[] = "KEYPOINT"; +constexpr int kMaxLandmarkThickness = 18; + +using ::mediapipe::RenderAnnotation_Point; + +inline void SetColor(RenderAnnotation* annotation, const Color& color) { + annotation->mutable_color()->set_r(color.r()); + annotation->mutable_color()->set_g(color.g()); + annotation->mutable_color()->set_b(color.b()); +} + +// Remap x from range [lo hi] to range [0 1] then multiply by scale. +inline float Remap(float x, float lo, float hi, float scale) { + return (x - lo) / (hi - lo + 1e-6) * scale; +} + +template +inline void GetMinMaxZ(const std::vector& landmarks, float* z_min, + float* z_max) { + *z_min = std::numeric_limits::max(); + *z_max = std::numeric_limits::min(); + for (const auto& landmark : landmarks) { + *z_min = std::min(landmark.z(), *z_min); + *z_max = std::max(landmark.z(), *z_max); + } +} + +void SetColorSizeValueFromZ(float z, float z_min, float z_max, + RenderAnnotation* render_annotation) { + const int color_value = 255 - static_cast(Remap(z, z_min, z_max, 255)); + ::mediapipe::Color color; + color.set_r(color_value); + color.set_g(color_value); + color.set_b(color_value); + SetColor(render_annotation, color); + const int thickness = static_cast((1.f - Remap(z, z_min, z_max, 1)) * + kMaxLandmarkThickness); + render_annotation->set_thickness(thickness); +} + +} // namespace + +// A calculator that converts Landmark proto to RenderData proto for +// visualization. The input should be std::vector. It is also possible +// to specify the connections between landmarks. 
+// +// Example config: +// node { +// calculator: "LandmarksToRenderDataCalculator" +// input_stream: "NORM_LANDMARKS:landmarks" +// output_stream: "RENDER_DATA:render_data" +// options { +// [LandmarksToRenderDataCalculatorOptions.ext] { +// landmark_connections: [0, 1, 1, 2] +// landmark_color { r: 0 g: 255 b: 0 } +// connection_color { r: 0 g: 255 b: 0 } +// thickness: 4.0 +// } +// } +// } +class LandmarksToRenderDataCalculator : public CalculatorBase { + public: + LandmarksToRenderDataCalculator() {} + ~LandmarksToRenderDataCalculator() override {} + LandmarksToRenderDataCalculator(const LandmarksToRenderDataCalculator&) = + delete; + LandmarksToRenderDataCalculator& operator=( + const LandmarksToRenderDataCalculator&) = delete; + + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + static void AddConnectionToRenderData( + float start_x, float start_y, float end_x, float end_y, + const LandmarksToRenderDataCalculatorOptions& options, bool normalized, + RenderData* render_data); + static void SetRenderAnnotationColorThickness( + const LandmarksToRenderDataCalculatorOptions& options, + RenderAnnotation* render_annotation); + static RenderAnnotation* AddPointRenderData( + const LandmarksToRenderDataCalculatorOptions& options, + RenderData* render_data); + static void AddConnectionToRenderData( + float start_x, float start_y, float end_x, float end_y, + const LandmarksToRenderDataCalculatorOptions& options, bool normalized, + int gray_val1, int gray_val2, RenderData* render_data); + + template + void AddConnections(const std::vector& landmarks, + bool normalized, RenderData* render_data); + template + void AddConnectionsWithDepth(const std::vector& landmarks, + bool normalized, float min_z, float max_z, + RenderData* render_data); + + LandmarksToRenderDataCalculatorOptions options_; +}; +REGISTER_CALCULATOR(LandmarksToRenderDataCalculator); + +::mediapipe::Status LandmarksToRenderDataCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag(kLandmarksTag) || + cc->Inputs().HasTag(kNormLandmarksTag)) + << "None of the input streams are provided."; + RET_CHECK(!(cc->Inputs().HasTag(kLandmarksTag) && + cc->Inputs().HasTag(kNormLandmarksTag))) + << "Can only one type of landmark can be taken. Either absolute or " + "normalized landmarks."; + + if (cc->Inputs().HasTag(kLandmarksTag)) { + cc->Inputs().Tag(kLandmarksTag).Set>(); + } + if (cc->Inputs().HasTag(kNormLandmarksTag)) { + cc->Inputs().Tag(kNormLandmarksTag).Set>(); + } + cc->Outputs().Tag(kRenderDataTag).Set(); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status LandmarksToRenderDataCalculator::Open( + CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + options_ = cc->Options(); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status LandmarksToRenderDataCalculator::Process( + CalculatorContext* cc) { + auto render_data = absl::make_unique(); + bool visualize_depth = options_.visualize_landmark_depth(); + float z_min = 0.f; + float z_max = 0.f; + + if (cc->Inputs().HasTag(kLandmarksTag)) { + const auto& landmarks = + cc->Inputs().Tag(kLandmarksTag).Get>(); + RET_CHECK_EQ(options_.landmark_connections_size() % 2, 0) + << "Number of entries in landmark connections must be a multiple of 2"; + if (visualize_depth) { + GetMinMaxZ(landmarks, &z_min, &z_max); + } + // Only change rendering if there are actually z values other than 0. 
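+    // (A z-range below 1e-3 is treated as "no depth information", which is
+    // the case when every landmark leaves z at its default of 0.)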
+ visualize_depth &= ((z_max - z_min) > 1e-3); + for (const auto& landmark : landmarks) { + auto* landmark_data_render = + AddPointRenderData(options_, render_data.get()); + if (visualize_depth) { + SetColorSizeValueFromZ(landmark.z(), z_min, z_max, + landmark_data_render); + } + auto* landmark_data = landmark_data_render->mutable_point(); + landmark_data->set_normalized(false); + landmark_data->set_x(landmark.x()); + landmark_data->set_y(landmark.y()); + } + if (visualize_depth) { + AddConnectionsWithDepth(landmarks, /*normalized=*/false, z_min, z_max, + render_data.get()); + } else { + AddConnections(landmarks, /*normalized=*/false, render_data.get()); + } + } + + if (cc->Inputs().HasTag(kNormLandmarksTag)) { + const auto& landmarks = cc->Inputs() + .Tag(kNormLandmarksTag) + .Get>(); + RET_CHECK_EQ(options_.landmark_connections_size() % 2, 0) + << "Number of entries in landmark connections must be a multiple of 2"; + if (visualize_depth) { + GetMinMaxZ(landmarks, &z_min, &z_max); + } + // Only change rendering if there are actually z values other than 0. + visualize_depth &= ((z_max - z_min) > 1e-3); + for (const auto& landmark : landmarks) { + auto* landmark_data_render = + AddPointRenderData(options_, render_data.get()); + if (visualize_depth) { + SetColorSizeValueFromZ(landmark.z(), z_min, z_max, + landmark_data_render); + } + auto* landmark_data = landmark_data_render->mutable_point(); + landmark_data->set_normalized(true); + landmark_data->set_x(landmark.x()); + landmark_data->set_y(landmark.y()); + } + if (visualize_depth) { + AddConnectionsWithDepth(landmarks, /*normalized=*/true, z_min, z_max, + render_data.get()); + } else { + AddConnections(landmarks, /*normalized=*/true, render_data.get()); + } + } + + cc->Outputs() + .Tag(kRenderDataTag) + .Add(render_data.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); +} + +template +void LandmarksToRenderDataCalculator::AddConnectionsWithDepth( + const std::vector& landmarks, bool normalized, float min_z, + float max_z, RenderData* render_data) { + for (int i = 0; i < options_.landmark_connections_size(); i += 2) { + const auto& ld0 = landmarks[options_.landmark_connections(i)]; + const auto& ld1 = landmarks[options_.landmark_connections(i + 1)]; + const int gray_val1 = + 255 - static_cast(Remap(ld0.z(), min_z, max_z, 255)); + const int gray_val2 = + 255 - static_cast(Remap(ld1.z(), min_z, max_z, 255)); + AddConnectionToRenderData(ld0.x(), ld0.y(), ld1.x(), ld1.y(), options_, + normalized, gray_val1, gray_val2, render_data); + } +} + +void LandmarksToRenderDataCalculator::AddConnectionToRenderData( + float start_x, float start_y, float end_x, float end_y, + const LandmarksToRenderDataCalculatorOptions& options, bool normalized, + int gray_val1, int gray_val2, RenderData* render_data) { + auto* connection_annotation = render_data->add_render_annotations(); + RenderAnnotation::GradientLine* line = + connection_annotation->mutable_gradient_line(); + line->set_x_start(start_x); + line->set_y_start(start_y); + line->set_x_end(end_x); + line->set_y_end(end_y); + line->set_normalized(normalized); + line->mutable_color1()->set_r(gray_val1); + line->mutable_color1()->set_g(gray_val1); + line->mutable_color1()->set_b(gray_val1); + line->mutable_color2()->set_r(gray_val2); + line->mutable_color2()->set_g(gray_val2); + line->mutable_color2()->set_b(gray_val2); + connection_annotation->set_thickness(options.thickness()); +} + +template +void LandmarksToRenderDataCalculator::AddConnections( + const std::vector& landmarks, bool 
normalized, + RenderData* render_data) { + for (int i = 0; i < options_.landmark_connections_size(); i += 2) { + const auto& ld0 = landmarks[options_.landmark_connections(i)]; + const auto& ld1 = landmarks[options_.landmark_connections(i + 1)]; + AddConnectionToRenderData(ld0.x(), ld0.y(), ld1.x(), ld1.y(), options_, + normalized, render_data); + } +} + +void LandmarksToRenderDataCalculator::AddConnectionToRenderData( + float start_x, float start_y, float end_x, float end_y, + const LandmarksToRenderDataCalculatorOptions& options, bool normalized, + RenderData* render_data) { + auto* connection_annotation = render_data->add_render_annotations(); + RenderAnnotation::Line* line = connection_annotation->mutable_line(); + line->set_x_start(start_x); + line->set_y_start(start_y); + line->set_x_end(end_x); + line->set_y_end(end_y); + line->set_normalized(normalized); + SetColor(connection_annotation, options.connection_color()); + connection_annotation->set_thickness(options.thickness()); +} + +RenderAnnotation* LandmarksToRenderDataCalculator::AddPointRenderData( + const LandmarksToRenderDataCalculatorOptions& options, + RenderData* render_data) { + auto* landmark_data_annotation = render_data->add_render_annotations(); + landmark_data_annotation->set_scene_tag(kLandmarkLabel); + SetRenderAnnotationColorThickness(options, landmark_data_annotation); + return landmark_data_annotation; +} + +void LandmarksToRenderDataCalculator::SetRenderAnnotationColorThickness( + const LandmarksToRenderDataCalculatorOptions& options, + RenderAnnotation* render_annotation) { + SetColor(render_annotation, options.landmark_color()); + render_annotation->set_thickness(options.thickness()); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/landmarks_to_render_data_calculator.proto b/mediapipe/calculators/util/landmarks_to_render_data_calculator.proto new file mode 100644 index 000000000..1334fc1f1 --- /dev/null +++ b/mediapipe/calculators/util/landmarks_to_render_data_calculator.proto @@ -0,0 +1,43 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/util/color.proto"; + +message LandmarksToRenderDataCalculatorOptions { + extend CalculatorOptions { + optional LandmarksToRenderDataCalculatorOptions ext = 258435389; + } + + // Specifies the landmarks to be connected in the drawing. For example, the + // landmark_connections value of [0, 1, 1, 2] specifies two connections: one + // that connects landmarks with index 0 and 1, and another that connects + // landmarks with index 1 and 2. + repeated int32 landmark_connections = 1; + + // Color of the landmarks. + optional Color landmark_color = 2; + // Color of the connections. + optional Color connection_color = 3; + + // Thickness of the drawing of landmarks and connections. 
+ optional double thickness = 4 [default = 1.0]; + + // Change color and size of rendered landmarks based on its z value. + optional bool visualize_landmark_depth = 5 [default = true]; +} diff --git a/mediapipe/calculators/util/non_max_suppression_calculator.cc b/mediapipe/calculators/util/non_max_suppression_calculator.cc index c08c92656..5836a5a6a 100644 --- a/mediapipe/calculators/util/non_max_suppression_calculator.cc +++ b/mediapipe/calculators/util/non_max_suppression_calculator.cc @@ -167,6 +167,8 @@ class NonMaxSuppressionCalculator : public CalculatorBase { } ::mediapipe::Status Open(CalculatorContext* cc) override { + cc->SetOffset(TimestampDiff(0)); + options_ = cc->Options(); CHECK_GT(options_.num_detection_streams(), 0) << "At least one detection stream need to be specified."; diff --git a/mediapipe/calculators/util/rect_to_render_data_calculator.cc b/mediapipe/calculators/util/rect_to_render_data_calculator.cc new file mode 100644 index 000000000..12cce1fa2 --- /dev/null +++ b/mediapipe/calculators/util/rect_to_render_data_calculator.cc @@ -0,0 +1,144 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/util/rect_to_render_data_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/util/color.pb.h" +#include "mediapipe/util/render_data.pb.h" + +namespace mediapipe { + +namespace { + +constexpr char kNormalizedRectTag[] = "NORM_RECT"; +constexpr char kRectTag[] = "RECT"; +constexpr char kRenderDataTag[] = "RENDER_DATA"; + +void SetRect(bool normalized, double xmin, double ymin, double width, + double height, double rotation, + RenderAnnotation::Rectangle* rect) { + if (xmin + width < 0.0 || ymin + height < 0.0) return; + if (normalized) { + if (xmin > 1.0 || ymin > 1.0) return; + } + rect->set_normalized(normalized); + rect->set_left(normalized ? std::max(xmin, 0.0) : xmin); + rect->set_top(normalized ? std::max(ymin, 0.0) : ymin); + rect->set_right(normalized ? std::min(xmin + width, 1.0) : xmin + width); + rect->set_bottom(normalized ? std::min(ymin + height, 1.0) : ymin + height); + rect->set_rotation(rotation); +} + +} // namespace + +// Generates render data needed to render a rectangle in +// AnnotationOverlayCalculator. 
+// +// Input: +// One of the following: +// NORM_RECT: A NormalizedRect +// RECT: A Rect +// +// Output: +// RENDER_DATA: A RenderData +// +// Example config: +// node { +// calculator: "RectToRenderDataCalculator" +// input_stream: "NORM_RECT:rect" +// output_stream: "RENDER_DATA:rect_render_data" +// options: { +// [mediapipe.RectToRenderDataCalculatorOptions.ext] { +// filled: true +// color { r: 255 g: 0 b: 0 } +// thickness: 4.0 +// } +// } +// } +class RectToRenderDataCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + RectToRenderDataCalculatorOptions options_; +}; +REGISTER_CALCULATOR(RectToRenderDataCalculator); + +::mediapipe::Status RectToRenderDataCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag(kNormalizedRectTag) ^ + cc->Inputs().HasTag(kRectTag)); + RET_CHECK(cc->Outputs().HasTag(kRenderDataTag)); + + if (cc->Inputs().HasTag(kNormalizedRectTag)) { + cc->Inputs().Tag(kNormalizedRectTag).Set(); + } + if (cc->Inputs().HasTag(kRectTag)) { + cc->Inputs().Tag(kRectTag).Set(); + } + cc->Outputs().Tag(kRenderDataTag).Set(); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status RectToRenderDataCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + options_ = cc->Options(); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status RectToRenderDataCalculator::Process(CalculatorContext* cc) { + auto render_data = absl::make_unique(); + auto* annotation = render_data->add_render_annotations(); + annotation->mutable_color()->set_r(options_.color().r()); + annotation->mutable_color()->set_g(options_.color().g()); + annotation->mutable_color()->set_b(options_.color().b()); + annotation->set_thickness(options_.thickness()); + + auto* rectangle = + options_.filled() + ? annotation->mutable_filled_rectangle()->mutable_rectangle() + : annotation->mutable_rectangle(); + + if (cc->Inputs().HasTag(kNormalizedRectTag) && + !cc->Inputs().Tag(kNormalizedRectTag).IsEmpty()) { + const auto& rect = + cc->Inputs().Tag(kNormalizedRectTag).Get(); + SetRect(/*normalized=*/true, rect.x_center() - rect.width() / 2.f, + rect.y_center() - rect.height() / 2.f, rect.width(), rect.height(), + rect.rotation(), rectangle); + } + if (cc->Inputs().HasTag(kRectTag) && !cc->Inputs().Tag(kRectTag).IsEmpty()) { + const auto& rect = cc->Inputs().Tag(kRectTag).Get(); + SetRect(/*normalized=*/false, rect.x_center() - rect.width() / 2.f, + rect.y_center() - rect.height() / 2.f, rect.width(), rect.height(), + rect.rotation(), rectangle); + } + + cc->Outputs() + .Tag(kRenderDataTag) + .Add(render_data.release(), cc->InputTimestamp()); + + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/rect_to_render_data_calculator.proto b/mediapipe/calculators/util/rect_to_render_data_calculator.proto new file mode 100644 index 000000000..badc8df44 --- /dev/null +++ b/mediapipe/calculators/util/rect_to_render_data_calculator.proto @@ -0,0 +1,35 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto2";
+
+package mediapipe;
+
+import "mediapipe/framework/calculator.proto";
+import "mediapipe/util/color.proto";
+
+message RectToRenderDataCalculatorOptions {
+  extend CalculatorOptions {
+    optional RectToRenderDataCalculatorOptions ext = 262270380;
+  }
+
+  // Whether the rendered rectangle should be filled.
+  optional bool filled = 1;
+
+  // Line color or filled color of the rectangle.
+  optional Color color = 2;
+
+  // Thickness of the line (applicable when the rectangle is not filled).
+  optional double thickness = 3 [default = 1.0];
+}
diff --git a/mediapipe/calculators/util/rect_transformation_calculator.cc b/mediapipe/calculators/util/rect_transformation_calculator.cc
new file mode 100644
index 000000000..98a7da301
--- /dev/null
+++ b/mediapipe/calculators/util/rect_transformation_calculator.cc
@@ -0,0 +1,205 @@
+// Copyright 2019 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include
+
+#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/calculator_options.pb.h"
+#include "mediapipe/framework/formats/rect.pb.h"
+
+namespace mediapipe {
+
+namespace {
+
+constexpr char kNormRectTag[] = "NORM_RECT";
+constexpr char kRectTag[] = "RECT";
+constexpr char kImageSizeTag[] = "IMAGE_SIZE";
+
+// Wraps around an angle in radians to within -M_PI and M_PI.
+inline float NormalizeRadians(float angle) {
+  return angle - 2 * M_PI * std::floor((angle - (-M_PI)) / (2 * M_PI));
+}
+
+}  // namespace
+
+// Performs geometric transformation to the input Rect or NormalizedRect,
+// corresponding to input stream RECT or NORM_RECT respectively. When the input
+// is NORM_RECT, an additional input stream IMAGE_SIZE is required, which is a
+// std::pair<int, int> representing the image width and height.
+// +// Example config: +// node { +// calculator: "RectTransformationCalculator" +// input_stream: "NORM_RECT:rect" +// input_stream: "IMAGE_SIZE:image_size" +// output_stream: "output_rect" +// options: { +// [mediapipe.RectTransformationCalculatorOptions.ext] { +// scale_x: 2.6 +// scale_y: 2.6 +// shift_y: -0.5 +// square_long: true +// } +// } +// } +class RectTransformationCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + RectTransformationCalculatorOptions options_; + + float ComputeNewRotation(float rotation); + void TransformRect(Rect* rect); + void TransformNormalizedRect(NormalizedRect* rect, int image_width, + int image_height); +}; +REGISTER_CALCULATOR(RectTransformationCalculator); + +::mediapipe::Status RectTransformationCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag(kNormRectTag) ^ cc->Inputs().HasTag(kRectTag)); + if (cc->Inputs().HasTag(kRectTag)) { + cc->Inputs().Tag(kRectTag).Set(); + cc->Outputs().Index(0).Set(); + } + if (cc->Inputs().HasTag(kNormRectTag)) { + RET_CHECK(cc->Inputs().HasTag(kImageSizeTag)); + cc->Inputs().Tag(kNormRectTag).Set(); + cc->Inputs().Tag(kImageSizeTag).Set>(); + cc->Outputs().Index(0).Set(); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status RectTransformationCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + options_ = cc->Options(); + RET_CHECK(!(options_.has_rotation() && options_.has_rotation_degrees())); + RET_CHECK(!(options_.has_square_long() && options_.has_square_short())); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status RectTransformationCalculator::Process( + CalculatorContext* cc) { + if (cc->Inputs().HasTag(kRectTag) && !cc->Inputs().Tag(kRectTag).IsEmpty()) { + auto rect = cc->Inputs().Tag(kRectTag).Get(); + TransformRect(&rect); + cc->Outputs().Index(0).AddPacket( + MakePacket(rect).At(cc->InputTimestamp())); + } + + if (cc->Inputs().HasTag(kNormRectTag) && + !cc->Inputs().Tag(kNormRectTag).IsEmpty()) { + auto rect = cc->Inputs().Tag(kNormRectTag).Get(); + const auto& image_size = + cc->Inputs().Tag(kImageSizeTag).Get>(); + TransformNormalizedRect(&rect, image_size.first, image_size.second); + cc->Outputs().Index(0).AddPacket( + MakePacket(rect).At(cc->InputTimestamp())); + } + + return ::mediapipe::OkStatus(); +} + +float RectTransformationCalculator::ComputeNewRotation(float rotation) { + if (options_.has_rotation()) { + rotation += options_.rotation(); + } else if (options_.has_rotation_degrees()) { + rotation += M_PI * options_.rotation_degrees() / 180.f; + } + return NormalizeRadians(rotation); +} + +void RectTransformationCalculator::TransformRect(Rect* rect) { + float width = rect->width(); + float height = rect->height(); + float rotation = rect->rotation(); + + if (options_.has_rotation() || options_.has_rotation_degrees()) { + rotation = ComputeNewRotation(rotation); + } + if (rotation == 0.f) { + rect->set_x_center(rect->x_center() + width * options_.shift_x()); + rect->set_y_center(rect->y_center() + height * options_.shift_y()); + } else { + const float x_shift = width * options_.shift_x() * std::cos(rotation) - + height * options_.shift_y() * std::sin(rotation); + const float y_shift = width * options_.shift_x() * std::sin(rotation) + + height * options_.shift_y() * std::cos(rotation); + 
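+      // shift_x/shift_y are defined along the rect's own (pre-rotation) axes,
+      // so the shift vector is rotated into image axes above before being
+      // applied to the rect center.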
rect->set_x_center(rect->x_center() + x_shift); + rect->set_y_center(rect->y_center() + y_shift); + } + + if (options_.square_long()) { + const float long_side = std::max(width, height); + width = long_side; + height = long_side; + } else if (options_.square_short()) { + const float short_side = std::min(width, height); + width = short_side; + height = short_side; + } + rect->set_width(width * options_.scale_x()); + rect->set_height(height * options_.scale_y()); +} + +void RectTransformationCalculator::TransformNormalizedRect(NormalizedRect* rect, + int image_width, + int image_height) { + float width = rect->width(); + float height = rect->height(); + float rotation = rect->rotation(); + + if (options_.has_rotation() || options_.has_rotation_degrees()) { + rotation = ComputeNewRotation(rotation); + } + if (rotation == 0.f) { + rect->set_x_center(rect->x_center() + width * options_.shift_x()); + rect->set_y_center(rect->y_center() + height * options_.shift_y()); + } else { + const float x_shift = + (image_width * width * options_.shift_x() * std::cos(rotation) - + image_height * height * options_.shift_y() * std::sin(rotation)) / + image_width; + const float y_shift = + (image_width * width * options_.shift_x() * std::sin(rotation) + + image_height * height * options_.shift_y() * std::cos(rotation)) / + image_height; + rect->set_x_center(rect->x_center() + x_shift); + rect->set_y_center(rect->y_center() + y_shift); + } + + if (options_.square_long()) { + const float long_side = + std::max(width * image_width, height * image_height); + width = long_side / image_width; + height = long_side / image_height; + } else if (options_.square_short()) { + const float short_side = + std::min(width * image_width, height * image_height); + width = short_side / image_width; + height = short_side / image_height; + } + rect->set_width(width * options_.scale_x()); + rect->set_height(height * options_.scale_y()); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/rect_transformation_calculator.proto b/mediapipe/calculators/util/rect_transformation_calculator.proto new file mode 100644 index 000000000..44e781d4d --- /dev/null +++ b/mediapipe/calculators/util/rect_transformation_calculator.proto @@ -0,0 +1,49 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message RectTransformationCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional RectTransformationCalculatorOptions ext = 262226312; + } + + // Scaling factor along the side of a rotated rect that was aligned with the + // X and Y axis before rotation respectively. + optional float scale_x = 1 [default = 1.0]; + optional float scale_y = 2 [default = 1.0]; + + // Additional rotation (counter-clockwise) around the rect center either in + // radians or in degrees. 
+ optional float rotation = 3; + optional int32 rotation_degrees = 4; + + // Shift along the side of a rotated rect that was aligned with the X and Y + // axis before rotation respectively. The shift is relative to the length of + // corresponding side. For example, for a rect with size (0.4, 0.6), with + // shift_x = 0.5 and shift_y = -0.5 the rect is shifted along the two sides + // by 0.2 and -0.3 respectively. + optional float shift_x = 5; + optional float shift_y = 6; + + // Change the final transformed rect into a square that shares the same center + // and rotation with the rect, and with the side of the square equal to either + // the long or short side of the rect respectively. + optional bool square_long = 7; + optional bool square_short = 8; +} diff --git a/mediapipe/calculators/util/thresholding_calculator.cc b/mediapipe/calculators/util/thresholding_calculator.cc new file mode 100644 index 000000000..1d7b5476b --- /dev/null +++ b/mediapipe/calculators/util/thresholding_calculator.cc @@ -0,0 +1,137 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/util/thresholding_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" + +namespace mediapipe { + +// Applies a threshold on a stream of numeric values and outputs a flag and/or +// accept/reject stream. The threshold can be specified by one of the following: +// 1) Input stream. +// 2) Input side packet. +// 3) Calculator option. +// +// Input: +// FLOAT: A float, which will be cast to double to be compared with a +// threshold of double type. +// THRESHOLD(optional): A double specifying the threshold at current timestamp. +// +// Output: +// FLAG(optional): A boolean indicating if the input value is larger than the +// threshold. +// ACCEPT(optional): A packet will be sent if the value is larger than the +// threshold. +// REJECT(optional): A packet will be sent if the value is no larger than the +// threshold. 
+// +// Usage example: +// node { +// calculator: "ThresholdingCalculator" +// input_stream: "FLOAT:score" +// output_stream: "ACCEPT:accept" +// output_stream: "REJECT:reject" +// options: { +// [mediapipe.ThresholdingCalculatorOptions.ext] { +// threshold: 0.1 +// } +// } +// } +class ThresholdingCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + ::mediapipe::Status Open(CalculatorContext* cc) override; + + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + double threshold_{}; +}; +REGISTER_CALCULATOR(ThresholdingCalculator); + +::mediapipe::Status ThresholdingCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag("FLOAT")); + cc->Inputs().Tag("FLOAT").Set(); + + if (cc->Outputs().HasTag("FLAG")) { + cc->Outputs().Tag("FLAG").Set(); + } + if (cc->Outputs().HasTag("ACCEPT")) { + cc->Outputs().Tag("ACCEPT").Set(); + } + if (cc->Outputs().HasTag("REJECT")) { + cc->Outputs().Tag("REJECT").Set(); + } + if (cc->Inputs().HasTag("THRESHOLD")) { + cc->Inputs().Tag("THRESHOLD").Set(); + } + if (cc->InputSidePackets().HasTag("THRESHOLD")) { + cc->InputSidePackets().Tag("THRESHOLD").Set(); + RET_CHECK(!cc->Inputs().HasTag("THRESHOLD")) + << "Using both the threshold input side packet and input stream is not " + "supported."; + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status ThresholdingCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + const auto& options = + cc->Options<::mediapipe::ThresholdingCalculatorOptions>(); + if (options.has_threshold()) { + RET_CHECK(!cc->Inputs().HasTag("THRESHOLD")) + << "Using both the threshold option and input stream is not supported."; + RET_CHECK(!cc->InputSidePackets().HasTag("THRESHOLD")) + << "Using both the threshold option and input side packet is not " + "supported."; + threshold_ = options.threshold(); + } + + if (cc->InputSidePackets().HasTag("THRESHOLD")) { + threshold_ = cc->InputSidePackets().Tag("THRESHOLD").Get(); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status ThresholdingCalculator::Process(CalculatorContext* cc) { + if (cc->Inputs().HasTag("THRESHOLD") && + !cc->Inputs().Tag("THRESHOLD").IsEmpty()) { + threshold_ = cc->Inputs().Tag("THRESHOLD").Get(); + } + + bool accept = false; + RET_CHECK(!cc->Inputs().Tag("FLOAT").IsEmpty()); + accept = + static_cast(cc->Inputs().Tag("FLOAT").Get()) > threshold_; + + if (cc->Outputs().HasTag("FLAG")) { + cc->Outputs().Tag("FLAG").AddPacket( + MakePacket(accept).At(cc->InputTimestamp())); + } + + if (accept && cc->Outputs().HasTag("ACCEPT")) { + cc->Outputs().Tag("ACCEPT").AddPacket( + MakePacket(true).At(cc->InputTimestamp())); + } + if (!accept && cc->Outputs().HasTag("REJECT")) { + cc->Outputs().Tag("REJECT").AddPacket( + MakePacket(false).At(cc->InputTimestamp())); + } + + return ::mediapipe::OkStatus(); +} +} // namespace mediapipe diff --git a/mediapipe/calculators/util/thresholding_calculator.proto b/mediapipe/calculators/util/thresholding_calculator.proto new file mode 100644 index 000000000..b8d81ad5d --- /dev/null +++ b/mediapipe/calculators/util/thresholding_calculator.proto @@ -0,0 +1,27 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message ThresholdingCalculatorOptions { + extend CalculatorOptions { + optional ThresholdingCalculatorOptions ext = 259990498; + } + + optional double threshold = 1; +} diff --git a/mediapipe/calculators/video/BUILD b/mediapipe/calculators/video/BUILD index 7546e5443..2244ad0dc 100644 --- a/mediapipe/calculators/video/BUILD +++ b/mediapipe/calculators/video/BUILD @@ -113,6 +113,22 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "tvl1_optical_flow_calculator", + srcs = ["tvl1_optical_flow_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats/motion:optical_flow_field", + "//mediapipe/framework/port:opencv_video", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/synchronization", + ], + alwayslink = 1, +) + cc_test( name = "opencv_video_decoder_calculator_test", srcs = ["opencv_video_decoder_calculator_test.cc"], @@ -155,3 +171,22 @@ cc_test( "//mediapipe/framework/port:parse_text_proto", ], ) + +cc_test( + name = "tvl1_optical_flow_calculator_test", + srcs = ["tvl1_optical_flow_calculator_test.cc"], + data = ["//mediapipe/calculators/image/testdata:test_images"], + deps = [ + ":tvl1_optical_flow_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats/motion:optical_flow_field", + "//mediapipe/framework/port:file_helpers", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:opencv_imgcodecs", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:parse_text_proto", + ], +) diff --git a/mediapipe/calculators/video/opencv_video_encoder_calculator.cc b/mediapipe/calculators/video/opencv_video_encoder_calculator.cc index c34a30ade..6ac11d933 100644 --- a/mediapipe/calculators/video/opencv_video_encoder_calculator.cc +++ b/mediapipe/calculators/video/opencv_video_encoder_calculator.cc @@ -139,9 +139,9 @@ class OpenCvVideoEncoderCalculator : public CalculatorBase { << " in OpenCvVideoEncoderCalculator::Process()"; } if (format == ImageFormat::SRGB) { - cv::cvtColor(tmp_frame, frame, cv::COLOR_BGR2RGB); + cv::cvtColor(tmp_frame, frame, cv::COLOR_RGB2BGR); } else if (format == ImageFormat::SRGBA) { - cv::cvtColor(tmp_frame, frame, cv::COLOR_BGRA2RGBA); + cv::cvtColor(tmp_frame, frame, cv::COLOR_RGBA2BGR); } else { return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Unsupported image format: " << format; diff --git a/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc b/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc new file mode 100644 index 000000000..2d361a8e2 --- /dev/null +++ b/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc @@ -0,0 +1,191 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/base/macros.h" +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/motion/optical_flow_field.h" +#include "mediapipe/framework/port/opencv_video_inc.h" + +namespace mediapipe { +namespace { + +// Checks that img1 and img2 have the same dimensions. +bool ImageSizesMatch(const ImageFrame& img1, const ImageFrame& img2) { + return (img1.Width() == img2.Width()) && (img1.Height() == img2.Height()); +} + +// Converts an RGB image to grayscale. +cv::Mat ConvertToGrayscale(const cv::Mat& image) { + if (image.channels() == 1) { + return image; + } + cv::Mat gray; + cv::cvtColor(image, gray, cv::COLOR_RGB2GRAY); + return gray; +} + +} // namespace + +// Calls OpenCV's DenseOpticalFlow to compute the optical flow between a pair of +// image frames. The calculator can output forward flow fields (optical flow +// from the first frame to the second frame), backward flow fields (optical flow +// from the second frame to the first frame), or both, depending on the tag of +// the specified output streams. Note that the timestamp of the output optical +// flow is always tied to the input timestamp. Be aware of the different +// meanings of the timestamp between the forward and the backward optical flows +// if the calculator outputs both. +// +// If the "max_in_flight" field is set to any value greater than 1, it will +// enable the calculator to process multiple inputs in parallel. The output +// packets will be automatically ordered by timestamp before they are passed +// along to downstream calculators. +// +// Inputs: +// FIRST_FRAME: An ImageFrame in either SRGB or GRAY8 format. +// SECOND_FRAME: An ImageFrame in either SRGB or GRAY8 format. +// Outputs: +// FORWARD_FLOW: The OpticalFlowField from the first frame to the second +// frame, output at the input timestamp. +// BACKWARD_FLOW: The OpticalFlowField from the second frame to the first +// frame, output at the input timestamp. +// Example config: +// node { +// calculator: "Tvl1OpticalFlowCalculator" +// input_stream: "FIRST_FRAME:first_frames" +// input_stream: "SECOND_FRAME:second_frames" +// output_stream: "FORWARD_FLOW:forward_flow" +// output_stream: "BACKWARD_FLOW:backward_flow" +// max_in_flight: 10 +// } +// num_threads: 10 +class Tvl1OpticalFlowCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + + private: + ::mediapipe::Status CalculateOpticalFlow(const ImageFrame& current_frame, + const ImageFrame& next_frame, + OpticalFlowField* flow); + bool forward_requested_ = false; + bool backward_requested_ = false; + // Stores the idle DenseOpticalFlow objects. 
+ // cv::DenseOpticalFlow is not thread-safe. Invoking multiple + // DenseOpticalFlow::calc() in parallel may lead to memory corruption or + // memory leak. + std::list> tvl1_computers_ GUARDED_BY(mutex_); + absl::Mutex mutex_; +}; + +::mediapipe::Status Tvl1OpticalFlowCalculator::GetContract( + CalculatorContract* cc) { + if (!cc->Inputs().HasTag("FIRST_FRAME") || + !cc->Inputs().HasTag("SECOND_FRAME")) { + return ::mediapipe::InvalidArgumentError( + "Missing required input streams. Both FIRST_FRAME and SECOND_FRAME " + "must be specified."); + } + cc->Inputs().Tag("FIRST_FRAME").Set(); + cc->Inputs().Tag("SECOND_FRAME").Set(); + if (cc->Outputs().HasTag("FORWARD_FLOW")) { + cc->Outputs().Tag("FORWARD_FLOW").Set(); + } + if (cc->Outputs().HasTag("BACKWARD_FLOW")) { + cc->Outputs().Tag("BACKWARD_FLOW").Set(); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status Tvl1OpticalFlowCalculator::Open(CalculatorContext* cc) { + { + absl::MutexLock lock(&mutex_); + tvl1_computers_.emplace_back(cv::createOptFlow_DualTVL1()); + } + if (cc->Outputs().HasTag("FORWARD_FLOW")) { + forward_requested_ = true; + } + if (cc->Outputs().HasTag("BACKWARD_FLOW")) { + backward_requested_ = true; + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status Tvl1OpticalFlowCalculator::Process(CalculatorContext* cc) { + const ImageFrame& first_frame = + cc->Inputs().Tag("FIRST_FRAME").Value().Get(); + const ImageFrame& second_frame = + cc->Inputs().Tag("SECOND_FRAME").Value().Get(); + if (forward_requested_) { + auto forward_optical_flow_field = absl::make_unique(); + RETURN_IF_ERROR(CalculateOpticalFlow(first_frame, second_frame, + forward_optical_flow_field.get())); + cc->Outputs() + .Tag("FORWARD_FLOW") + .Add(forward_optical_flow_field.release(), cc->InputTimestamp()); + } + if (backward_requested_) { + auto backward_optical_flow_field = absl::make_unique(); + RETURN_IF_ERROR(CalculateOpticalFlow(second_frame, first_frame, + backward_optical_flow_field.get())); + cc->Outputs() + .Tag("BACKWARD_FLOW") + .Add(backward_optical_flow_field.release(), cc->InputTimestamp()); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status Tvl1OpticalFlowCalculator::CalculateOpticalFlow( + const ImageFrame& current_frame, const ImageFrame& next_frame, + OpticalFlowField* flow) { + CHECK(flow); + if (!ImageSizesMatch(current_frame, next_frame)) { + return tool::StatusInvalid("Images are different sizes."); + } + const cv::Mat& first = ConvertToGrayscale(formats::MatView(¤t_frame)); + const cv::Mat& second = ConvertToGrayscale(formats::MatView(&next_frame)); + + // Tries getting an idle DenseOpticalFlow object from the cache. If not, + // creates a new DenseOpticalFlow. + cv::Ptr tvl1_computer; + { + absl::MutexLock lock(&mutex_); + if (!tvl1_computers_.empty()) { + std::swap(tvl1_computer, tvl1_computers_.front()); + tvl1_computers_.pop_front(); + } + } + if (tvl1_computer.empty()) { + tvl1_computer = cv::createOptFlow_DualTVL1(); + } + + flow->Allocate(first.cols, first.rows); + cv::Mat cv_flow(flow->mutable_flow_data()); + tvl1_computer->calc(first, second, cv_flow); + CHECK_EQ(flow->mutable_flow_data().data, cv_flow.data); + // Inserts the idle DenseOpticalFlow object back to the cache for reuse. 
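+  // Returning the solver to the pool instead of destroying it lets later
+  // Process() calls reuse it, so the number of TV-L1 instances stays bounded
+  // by the number of Process() calls that run concurrently.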
+  {
+    absl::MutexLock lock(&mutex_);
+    tvl1_computers_.push_back(tvl1_computer);
+  }
+  return ::mediapipe::OkStatus();
+}
+
+REGISTER_CALCULATOR(Tvl1OpticalFlowCalculator);
+
+}  // namespace mediapipe
diff --git a/mediapipe/calculators/video/tvl1_optical_flow_calculator_test.cc b/mediapipe/calculators/video/tvl1_optical_flow_calculator_test.cc
new file mode 100644
index 000000000..e187d1aa9
--- /dev/null
+++ b/mediapipe/calculators/video/tvl1_optical_flow_calculator_test.cc
@@ -0,0 +1,127 @@
+// Copyright 2019 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/calculator_runner.h"
+#include "mediapipe/framework/deps/file_path.h"
+#include "mediapipe/framework/formats/image_frame_opencv.h"
+#include "mediapipe/framework/formats/motion/optical_flow_field.h"
+#include "mediapipe/framework/port/file_helpers.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
+#include "mediapipe/framework/port/opencv_imgproc_inc.h"
+#include "mediapipe/framework/port/parse_text_proto.h"
+#include "mediapipe/framework/port/status_matchers.h"
+
+namespace mediapipe {
+
+namespace {
+void AddInputPackets(int num_packets, CalculatorGraph* graph) {
+  int width = 127;
+  int height = 227;
+  Packet packet1 = MakePacket<ImageFrame>(ImageFormat::SRGB, width, height);
+  Packet packet2 = MakePacket<ImageFrame>(ImageFormat::SRGB, width, height);
+  cv::Mat mat1 = formats::MatView(&(packet1.Get<ImageFrame>()));
+  cv::Mat mat2 = formats::MatView(&(packet2.Get<ImageFrame>()));
+  for (int r = 0; r < mat1.rows; ++r) {
+    for (int c = 0; c < mat1.cols; ++c) {
+      cv::Vec3b& color1 = mat1.at<cv::Vec3b>(r, c);
+      color1[0] = r + 3;
+      color1[1] = r + 3;
+      color1[2] = 0;
+      cv::Vec3b& color2 = mat2.at<cv::Vec3b>(r, c);
+      color2[0] = r;
+      color2[1] = r;
+      color2[2] = 0;
+    }
+  }
+
+  for (int i = 0; i < num_packets; ++i) {
+    MEDIAPIPE_ASSERT_OK(graph->AddPacketToInputStream(
+        "first_frames", packet1.At(Timestamp(i))));
+    MEDIAPIPE_ASSERT_OK(graph->AddPacketToInputStream(
+        "second_frames", packet2.At(Timestamp(i))));
+  }
+  MEDIAPIPE_ASSERT_OK(graph->CloseAllInputStreams());
+}
+
+void RunTest(int num_input_packets, int max_in_flight) {
+  CalculatorGraphConfig config = ParseTextProtoOrDie<CalculatorGraphConfig>(
+      absl::Substitute(R"(
+        input_stream: "first_frames"
+        input_stream: "second_frames"
+        node {
+          calculator: "Tvl1OpticalFlowCalculator"
+          input_stream: "FIRST_FRAME:first_frames"
+          input_stream: "SECOND_FRAME:second_frames"
+          output_stream: "FORWARD_FLOW:forward_flow"
+          output_stream: "BACKWARD_FLOW:backward_flow"
+          max_in_flight: $0
+        }
+        num_threads: $0
+      )",
+                       max_in_flight));
+  CalculatorGraph graph;
+  MEDIAPIPE_ASSERT_OK(graph.Initialize(config));
+  StatusOrPoller status_or_poller1 =
+      graph.AddOutputStreamPoller("forward_flow");
+  ASSERT_TRUE(status_or_poller1.ok());
+  OutputStreamPoller poller1 = std::move(status_or_poller1.ValueOrDie());
+  StatusOrPoller status_or_poller2 =
+      graph.AddOutputStreamPoller("backward_flow");
+  ASSERT_TRUE(status_or_poller2.ok());
+  OutputStreamPoller poller2 = std::move(status_or_poller2.ValueOrDie());
+
+  MEDIAPIPE_ASSERT_OK(graph.StartRun({}));
+  AddInputPackets(num_input_packets, &graph);
+  Packet packet;
+  std::vector<Packet> forward_optical_flow_packets;
+  while (poller1.Next(&packet)) {
+    forward_optical_flow_packets.emplace_back(packet);
+  }
+  std::vector<Packet> backward_optical_flow_packets;
+  while (poller2.Next(&packet)) {
+    backward_optical_flow_packets.emplace_back(packet);
+  }
+  MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone());
+  EXPECT_EQ(num_input_packets, forward_optical_flow_packets.size());
+
+  int count = 0;
+  for (const Packet& packet : forward_optical_flow_packets) {
+    cv::Scalar average = cv::mean(packet.Get<OpticalFlowField>().flow_data());
+    EXPECT_NEAR(average[0], 0.0, 0.5) << "Actual mean_dx = " << average[0];
+    EXPECT_NEAR(average[1], 3.0, 0.5) << "Actual mean_dy = " << average[1];
+    EXPECT_EQ(count++, packet.Timestamp().Value());
+  }
+  EXPECT_EQ(num_input_packets, backward_optical_flow_packets.size());
+  count = 0;
+  for (const Packet& packet : backward_optical_flow_packets) {
+    cv::Scalar average = cv::mean(packet.Get<OpticalFlowField>().flow_data());
+    EXPECT_NEAR(average[0], 0.0, 0.5) << "Actual mean_dx = " << average[0];
+    EXPECT_NEAR(average[1], -3.0, 0.5) << "Actual mean_dy = " << average[1];
+    EXPECT_EQ(count++, packet.Timestamp().Value());
+  }
+}
+
+TEST(Tvl1OpticalFlowCalculatorTest, TestSequentialExecution) {
+  RunTest(/*num_input_packets=*/2, /*max_in_flight=*/1);
+}
+
+TEST(Tvl1OpticalFlowCalculatorTest, TestParallelExecution) {
+  RunTest(/*num_input_packets=*/20, /*max_in_flight=*/10);
+}
+
+}  // namespace
+}  // namespace mediapipe
diff --git a/mediapipe/docs/examples.md b/mediapipe/docs/examples.md
index e8d773e26..56b0ed51e 100644
--- a/mediapipe/docs/examples.md
+++ b/mediapipe/docs/examples.md
@@ -9,47 +9,72 @@ for Objective-C shortly.
 
 ### Hello World! on Android
 
 [Hello World! on Android](./hello_world_android.md) should be the first mobile
-example users go through in detail. It teaches the following:
+Android example users go through in detail. It teaches the following:
 
 *   Introduction of a simple MediaPipe graph running on mobile GPUs for
-    [Sobel edge detection].
+    [Sobel edge detection](https://en.wikipedia.org/wiki/Sobel_operator).
 *   Building a simple baseline Android application that displays "Hello
     World!".
 *   Adding camera preview support into the baseline application using the
     Android [CameraX] API.
 *   Incorporating the Sobel edge detection graph to process the live camera
     preview and display the processed video in real-time.
 
-### Object Detection with GPU on Android
+### Hello World! on iOS
 
-[Object Detection on GPU on Android](./object_detection_android_gpu.md)
-illustrates how to use MediaPipe with a TFLite model for object detection in a
-GPU-accelerated pipeline.
+[Hello World! on iOS](./hello_world_ios.md) is the iOS version of the Sobel
+edge detection example.
 
-### Object Detection with CPU on Android
+### Object Detection with GPU
 
-[Object Detection on CPU on Android](./object_detection_android_cpu.md)
-illustrates using the same TFLite model in a CPU-based pipeline. This example
-highlights how graphs can be easily adapted to run on CPU v.s. GPU.
- -### Face Detection on Android - -[Face Detection on Android](./face_detection_android_gpu.md) illustrates how to -use MediaPipe with a TFLite model for face detection in a GPU-accelerated +[Object Detection with GPU](./object_detection_mobile_gpu.md) illustrates how to +use MediaPipe with a TFLite model for object detection in a GPU-accelerated pipeline. -* The selfie face detection TFLite model is based on - ["BlazeFace: Sub-millisecond Neural Face Detection on Mobile GPUs"](https://sites.google.com/view/perception-cv4arvr/blazeface). -* [Model card](https://sites.google.com/corp/view/perception-cv4arvr/blazeface#h.p_21ojPZDx3cqq). +* [Android](./object_detection_mobile_gpu.md#android) +* [iOS](./object_detection_mobile_gpu.md#ios) -### Hair Segmentation on Android +### Object Detection with CPU -[Hair Segmentation on Android](./hair_segmentation_android_gpu.md) illustrates -how to use MediaPipe with a TFLite model for hair segmentation in a -GPU-accelerated pipeline. +[Object Detection with CPU](./object_detection_mobile_cpu.md) illustrates using +the same TFLite model in a CPU-based pipeline. This example highlights how +graphs can be easily adapted to run on CPU v.s. GPU. -* The selfie hair segmentation TFLite model is based on - ["Real-time Hair segmentation and recoloring on Mobile GPUs"](https://sites.google.com/view/perception-cv4arvr/hair-segmentation). -* [Model card](https://sites.google.com/corp/view/perception-cv4arvr/hair-segmentation#h.p_NimuO7PgHxlY). +### Face Detection with GPU + +[Face Detection with GPU](./face_detection_mobile_gpu.md) illustrates how to use +MediaPipe with a TFLite model for face detection in a GPU-accelerated pipeline. +The selfie face detection TFLite model is based on +["BlazeFace: Sub-millisecond Neural Face Detection on Mobile GPUs"](https://sites.google.com/view/perception-cv4arvr/blazeface). +[Model card](https://sites.google.com/corp/view/perception-cv4arvr/blazeface#h.p_21ojPZDx3cqq). + +* [Android](./face_detection_mobile_gpu.md#android) +* [iOS](./face_detection_mobile_gpu.md#ios) + +### Hand Detection with GPU + +[Hand Detection with GPU](./hand_detection_mobile_gpu.md) illustrates how to use +MediaPipe with a TFLite model for hand detection in a GPU-accelerated pipeline. + +* [Android](./hand_detection_mobile_gpu.md#android) +* [iOS](./hand_detection_mobile_gpu.md#ios) + +### Hand Tracking with GPU + +[Hand Tracking with GPU](./hand_tracking_mobile_gpu.md) illustrates how to use +MediaPipe with a TFLite model for hand tracking in a GPU-accelerated pipeline. + +* [Android](./hand_tracking_mobile_gpu.md#android) +* [iOS](./hand_tracking_mobile_gpu.md#ios) + +### Hair Segmentation with GPU + +[Hair Segmentation on GPU](./hair_segmentation_mobile_gpu.md) illustrates how to +use MediaPipe with a TFLite model for hair segmentation in a GPU-accelerated +pipeline. The selfie hair segmentation TFLite model is based on +["Real-time Hair segmentation and recoloring on Mobile GPUs"](https://sites.google.com/view/perception-cv4arvr/hair-segmentation). +[Model card](https://sites.google.com/corp/view/perception-cv4arvr/hair-segmentation#h.p_NimuO7PgHxlY). 
+ +* [Android](./hair_segmentation_mobile_gpu.md#android) ## Desktop diff --git a/mediapipe/docs/face_detection_android_gpu.md b/mediapipe/docs/face_detection_mobile_gpu.md similarity index 60% rename from mediapipe/docs/face_detection_android_gpu.md rename to mediapipe/docs/face_detection_mobile_gpu.md index a88a10818..4bf7d6f0f 100644 --- a/mediapipe/docs/face_detection_android_gpu.md +++ b/mediapipe/docs/face_detection_mobile_gpu.md @@ -1,17 +1,18 @@ -# Face Detection on Android +# Face Detection (GPU) -Please see [Hello World! in MediaPipe on Android](hello_world_android.md) for -general instructions to develop an Android application that uses MediaPipe. This -doc focuses on the -[example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/face_detection/face_detection_android_gpu.pbtxt) +This doc focuses on the +[example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/face_detection/face_detection_mobile_gpu.pbtxt) that performs face detection with TensorFlow Lite on GPU. ![face_detection_android_gpu_gif](images/mobile/face_detection_android_gpu.gif){width="300"} -## App +## Android + +Please see [Hello World! in MediaPipe on Android](hello_world_android.md) for +general instructions to develop an Android application that uses MediaPipe. The graph is used in the -[Face Detection GPU](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu) +[Face Detection GPU](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu) example app. To build the app, run: ```bash @@ -24,17 +25,33 @@ To further install the app on android device, run: adb install bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/facedetectiongpu.apk ``` +## iOS + +Please see [Hello World! in MediaPipe on iOS](hello_world_ios.md) for general +instructions to develop an iOS application that uses MediaPipe. The graph below +is used in the +[Face Detection GPU iOS example app](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/facedetectiongpu). + +To build the iOS app, please see the general +[MediaPipe iOS app building and setup instructions](./mediapipe_ios_setup.md). +Specifically, run: + +```bash +bazel build -c opt --config=ios_arm64 mediapipe/examples/ios/facedetectiongpu:FaceDetectionGpuApp +``` + ## Graph -![face_detection_android_gpu_graph](images/mobile/face_detection_android_gpu.png){width="400"} +![face_detection_mobile_gpu_graph](images/mobile/face_detection_mobile_gpu.png){width="400"} To visualize the graph as shown above, copy the text specification of the graph below and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). ```bash -# MediaPipe graph that performs object detection with TensorFlow Lite on GPU. +# MediaPipe graph that performs face detection with TensorFlow Lite on GPU. # Used in the example in -# mediapipie/examples/android/src/java/com/mediapipe/apps/objectdetectiongpu. +# mediapipie/examples/android/src/java/com/mediapipe/apps/facedetectiongpu and +# mediapipie/examples/ios/facedetectiongpu. # Images on GPU coming into and out of the graph. input_stream: "input_video" @@ -54,7 +71,7 @@ output_stream: "output_video" # TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy # processing previous inputs. 
node { - calculator: "RealTimeFlowLimiterCalculator" + calculator: "FlowLimiterCalculator" input_stream: "input_video" input_stream: "FINISHED:detections" input_stream_info: { @@ -64,40 +81,29 @@ node { output_stream: "throttled_input_video" } -# Transforms the input image on GPU to a 320x320 image. To scale the image, by -# default it uses the STRETCH scale mode that maps the entire input image to the -# entire transformed image. As a result, image aspect ratio may be changed and -# objects in the image may be deformed (stretched or squeezed), but the object -# detection model used in this graph is agnostic to that deformation. +# Transforms the input image on GPU to a 128x128 image. To scale the input +# image, the scale_mode option is set to FIT to preserve the aspect ratio, +# resulting in potential letterboxing in the transformed image. node: { calculator: "ImageTransformationCalculator" input_stream: "IMAGE_GPU:throttled_input_video" output_stream: "IMAGE_GPU:transformed_input_video" + output_stream: "LETTERBOX_PADDING:letterbox_padding" node_options: { [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { - output_width: 320 - output_height: 320 + output_width: 128 + output_height: 128 + scale_mode: FIT } } } -# Converts the transformed input image on GPU into an image tensor stored in -# tflite::gpu::GlBuffer. The zero_center option is set to true to normalize the -# pixel values to [-1.f, 1.f] as opposed to [0.f, 1.f]. The flip_vertically -# option is set to true to account for the descrepancy between the -# representation of the input image (origin at the bottom-left corner, the -# OpenGL convention) and what the model used in this graph is expecting (origin -# at the top-left corner). +# Converts the transformed input image on GPU into an image tensor stored as a +# TfLiteTensor. node { calculator: "TfLiteConverterCalculator" input_stream: "IMAGE_GPU:transformed_input_video" output_stream: "TENSORS_GPU:image_tensor" - node_options: { - [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] { - zero_center: true - flip_vertically: true - } - } } # Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a @@ -106,10 +112,10 @@ node { node { calculator: "TfLiteInferenceCalculator" input_stream: "TENSORS_GPU:image_tensor" - output_stream: "TENSORS_GPU:detection_tensors" + output_stream: "TENSORS:detection_tensors" node_options: { [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { - model_path: "ssdlite_object_detection.tflite" + model_path: "face_detection_front.tflite" } } } @@ -121,25 +127,19 @@ node { output_side_packet: "anchors" node_options: { [type.googleapis.com/mediapipe.SsdAnchorsCalculatorOptions] { - num_layers: 6 - min_scale: 0.2 - max_scale: 0.95 - input_size_height: 320 - input_size_width: 320 + num_layers: 4 + min_scale: 0.1484375 + max_scale: 0.75 + input_size_height: 128 + input_size_width: 128 anchor_offset_x: 0.5 anchor_offset_y: 0.5 + strides: 8 + strides: 16 + strides: 16 strides: 16 - strides: 32 - strides: 64 - strides: 128 - strides: 256 - strides: 512 aspect_ratios: 1.0 - aspect_ratios: 2.0 - aspect_ratios: 0.5 - aspect_ratios: 3.0 - aspect_ratios: 0.3333 - reduce_boxes_in_lowest_layer: true + fixed_anchor_size: true } } } @@ -149,22 +149,26 @@ node { # detections. Each detection describes a detected object. 
node { calculator: "TfLiteTensorsToDetectionsCalculator" - input_stream: "TENSORS_GPU:detection_tensors" + input_stream: "TENSORS:detection_tensors" input_side_packet: "ANCHORS:anchors" output_stream: "DETECTIONS:detections" node_options: { [type.googleapis.com/mediapipe.TfLiteTensorsToDetectionsCalculatorOptions] { - num_classes: 91 - num_boxes: 2034 - num_coords: 4 - ignore_classes: 0 + num_classes: 1 + num_boxes: 896 + num_coords: 16 + box_coord_offset: 0 + keypoint_coord_offset: 4 + num_keypoints: 6 + num_values_per_keypoint: 2 sigmoid_score: true - apply_exponential_on_box_size: true - x_scale: 10.0 - y_scale: 10.0 - h_scale: 5.0 - w_scale: 5.0 - flip_vertically: true + score_clipping_thresh: 100.0 + reverse_output_order: true + x_scale: 128.0 + y_scale: 128.0 + h_scale: 128.0 + w_scale: 128.0 + min_score_thresh: 0.75 } } } @@ -176,56 +180,58 @@ node { output_stream: "filtered_detections" node_options: { [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] { - min_suppression_threshold: 0.4 - min_score_threshold: 0.6 - max_num_detections: 3 + min_suppression_threshold: 0.3 overlap_type: INTERSECTION_OVER_UNION + algorithm: WEIGHTED + return_empty_detections: true } } } -# Maps detection label IDs to the corresponding label text. The label map is -# provided in the label_map_path option. +# Maps detection label IDs to the corresponding label text ("Face"). The label +# map is provided in the label_map_path option. node { calculator: "DetectionLabelIdToTextCalculator" input_stream: "filtered_detections" - output_stream: "output_detections" + output_stream: "labeled_detections" node_options: { [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { - label_map_path: "ssdlite_object_detection_labelmap.txt" + label_map_path: "face_detection_front_labelmap.txt" } } } +# Adjusts detection locations (already normalized to [0.f, 1.f]) on the +# letterboxed image (after image transformation with the FIT scale mode) to the +# corresponding locations on the same image with the letterbox removed (the +# input image to the graph before image transformation). +node { + calculator: "DetectionLetterboxRemovalCalculator" + input_stream: "DETECTIONS:labeled_detections" + input_stream: "LETTERBOX_PADDING:letterbox_padding" + output_stream: "DETECTIONS:output_detections" +} + # Converts the detections to drawing primitives for annotation overlay. node { calculator: "DetectionsToRenderDataCalculator" - input_stream: "DETECTION_VECTOR:output_detections" + input_stream: "DETECTIONS:output_detections" output_stream: "RENDER_DATA:render_data" node_options: { [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { - thickness: 4.0 + thickness: 10.0 color { r: 255 g: 0 b: 0 } } } } -# Draws annotations and overlays them on top of the original image coming into -# the graph. Annotation drawing is performed on CPU, and the result is -# transferred to GPU and overlaid on the input image. The calculator assumes -# that image origin is always at the top-left corner and renders text -# accordingly. However, the input image has its origin at the bottom-left corner -# (OpenGL convention) and the flip_text_vertically option is set to true to -# compensate that. +# Draws annotations and overlays them on top of a GPU copy of the original +# image coming into the graph. The calculator assumes that image origin is +# always at the top-left corner and renders text accordingly. 
node { calculator: "AnnotationOverlayCalculator" input_stream: "INPUT_FRAME_GPU:throttled_input_video" input_stream: "render_data" output_stream: "OUTPUT_FRAME_GPU:output_video" - node_options: { - [type.googleapis.com/mediapipe.AnnotationOverlayCalculatorOptions] { - flip_text_vertically: true - } - } } ``` diff --git a/mediapipe/docs/framework_concepts.md b/mediapipe/docs/framework_concepts.md index 5facb283c..86d91e7be 100644 --- a/mediapipe/docs/framework_concepts.md +++ b/mediapipe/docs/framework_concepts.md @@ -252,21 +252,23 @@ node { To modularize a `CalculatorGraphConfig` into sub-modules and assist with re-use of perception solutions, a MediaPipe graph can be defined as a `Subgraph`. The -public interface to a subgraph consists of a set of input and output streams -similar to the public interface of a calculator. The subgraph can then be +public interface of a subgraph consists of a set of input and output streams +similar to a calculator's public interface. The subgraph can then be included in an `CalculatorGraphConfig` as if it were a calculator. When a MediaPipe graph is loaded from a `CalculatorGraphConfig`, each subgraph node is replaced by the corresponding graph of calculators. As a result, the semantics and performance of the subgraph is identical to the corresponding graph of calculators. -Below is an example of how to create a subgraph named `TwoPassThroughSubgraph` +Below is an example of how to create a subgraph named `TwoPassThroughSubgraph`. -1. Defining the subgraph. +1. Defining the subgraph. ```proto # This subgraph is defined in two_pass_through_subgraph.pbtxt - # that is registered in the BUILD file as "TwoPassThroughSubgraph" + # and is registered as "TwoPassThroughSubgraph" + + type: "TwoPassThroughSubgraph" input_stream: "out1" output_stream: "out3" @@ -282,19 +284,20 @@ Below is an example of how to create a subgraph named `TwoPassThroughSubgraph` } ``` -The public interface to the graph that consist of: - * Graph input streams - * Graph output streams - * Graph input side packets - * Graph output side packets + The public interface to the subgraph consists of: -2. Register the subgraph using BUILD rule `mediapipe_simple_subgraph` - * The parameter `register_as` defines the component name for the new subgraph + * Graph input streams + * Graph output streams + * Graph input side packets + * Graph output side packets + +2. Register the subgraph using BUILD rule `mediapipe_simple_subgraph`. The + parameter `register_as` defines the component name for the new subgraph. ```proto # Small section of BUILD file for registering the "TwoPassThroughSubgraph" # subgraph for use by main graph main_pass_throughcals.pbtxt - # + mediapipe_simple_subgraph( name = "twopassthrough_subgraph", graph = "twopassthrough_subgraph.pbtxt", @@ -306,12 +309,12 @@ The public interface to the graph that consist of: ) ``` -3. Use the subgraph in the main graph +3. Use the subgraph in the main graph. 
```proto # This main graph is defined in main_pass_throughcals.pbtxt # using subgraph called "TwoPassThroughSubgraph" - # + input_stream: "in" node { calculator: "PassThroughCalculator" @@ -329,108 +332,3 @@ The public interface to the graph that consist of: output_stream: "out4" } ``` - - diff --git a/mediapipe/docs/hair_segmentation_android_gpu.md b/mediapipe/docs/hair_segmentation_mobile_gpu.md similarity index 83% rename from mediapipe/docs/hair_segmentation_android_gpu.md rename to mediapipe/docs/hair_segmentation_mobile_gpu.md index 02d513a24..08a370b40 100644 --- a/mediapipe/docs/hair_segmentation_android_gpu.md +++ b/mediapipe/docs/hair_segmentation_mobile_gpu.md @@ -1,14 +1,15 @@ -# Hair Segmentation on Android +# Hair Segmentation (GPU) -Please see [Hello World! in MediaPipe on Android](hello_world_android.md) for -general instructions to develop an Android application that uses MediaPipe. This -doc focuses on the -[example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hair_segmentation/hair_segmentation_android_gpu.pbtxt) +This doc focuses on the +[below example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hair_segmentation/hair_segmentation_android_gpu.pbtxt) that performs hair segmentation with TensorFlow Lite on GPU. ![hair_segmentation_android_gpu_gif](images/mobile/hair_segmentation_android_gpu.gif){width="300"} -## App +## Android + +Please see [Hello World! in MediaPipe on Android](hello_world_android.md) for +general instructions to develop an Android application that uses MediaPipe. The graph is used in the [Hair Segmentation GPU](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu) @@ -26,7 +27,7 @@ adb install bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/a ## Graph -![hair_segmentation_android_gpu_graph](images/mobile/hair_segmentation_android_gpu.png){width="600"} +![hair_segmentation_mobile_gpu_graph](images/mobile/hair_segmentation_mobile_gpu.png){width="600"} To visualize the graph as shown above, copy the text specification of the graph below and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). @@ -34,7 +35,7 @@ below and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). ```bash # MediaPipe graph that performs hair segmentation with TensorFlow Lite on GPU. # Used in the example in -# mediapipie/examples/android/src/java/com/mediapipe/apps/hairsegmentationgpu. +# mediapipie/examples/ios/hairsegmentationgpu. # Images on GPU coming into and out of the graph. input_stream: "input_video" @@ -54,7 +55,7 @@ output_stream: "output_video" # TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy # processing previous inputs. node { - calculator: "RealTimeFlowLimiterCalculator" + calculator: "FlowLimiterCalculator" input_stream: "input_video" input_stream: "FINISHED:hair_mask" input_stream_info: { @@ -111,12 +112,9 @@ node { # Converts the transformed input image on GPU into an image tensor stored in # tflite::gpu::GlBuffer. The zero_center option is set to false to normalize the -# pixel values to [0.f, 1.f] as opposed to [-1.f, 1.f]. The flip_vertically -# option is set to true to account for the descrepancy between the -# representation of the input image (origin at the bottom-left corner, the -# OpenGL convention) and what the model used in this graph is expecting (origin -# at the top-left corner). 
With the max_num_channels option set to 4, all 4 RGBA -# channels are contained in the image tensor. +# pixel values to [0.f, 1.f] as opposed to [-1.f, 1.f]. +# With the max_num_channels option set to 4, all 4 RGBA channels are contained +# in the image tensor. node { calculator: "TfLiteConverterCalculator" input_stream: "IMAGE_GPU:mask_embedded_input_video" @@ -124,7 +122,6 @@ node { node_options: { [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] { zero_center: false - flip_vertically: true max_num_channels: 4 } } @@ -148,7 +145,7 @@ node { node { calculator: "TfLiteInferenceCalculator" input_stream: "TENSORS_GPU:image_tensor" - output_stream: "TENSORS_GPU:segmentation_tensor" + output_stream: "TENSORS:segmentation_tensor" input_side_packet: "CUSTOM_OP_RESOLVER:op_resolver" node_options: { [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { @@ -158,15 +155,23 @@ node { } } +# The next step (tensors to segmentation) is not yet supported on iOS GPU. +# Convert the previous segmentation mask to CPU for processing. +node: { + calculator: "GpuBufferToImageFrameCalculator" + input_stream: "previous_hair_mask" + output_stream: "previous_hair_mask_cpu" +} + # Decodes the segmentation tensor generated by the TensorFlow Lite model into a -# mask of values in [0.f, 1.f], stored in the R channel of a GPU buffer. It also +# mask of values in [0.f, 1.f], stored in the R channel of a CPU buffer. It also # takes the mask generated previously as another input to improve the temporal # consistency. node { calculator: "TfLiteTensorsToSegmentationCalculator" - input_stream: "TENSORS_GPU:segmentation_tensor" - input_stream: "PREV_MASK_GPU:previous_hair_mask" - output_stream: "MASK_GPU:hair_mask" + input_stream: "TENSORS:segmentation_tensor" + input_stream: "PREV_MASK:previous_hair_mask_cpu" + output_stream: "MASK:hair_mask_cpu" node_options: { [type.googleapis.com/mediapipe.TfLiteTensorsToSegmentationCalculatorOptions] { tensor_width: 512 @@ -178,6 +183,13 @@ node { } } +# Send the current segmentation mask to GPU for the last step, blending. +node: { + calculator: "ImageFrameToGpuBufferCalculator" + input_stream: "hair_mask_cpu" + output_stream: "hair_mask" +} + # Colors the hair segmentation with the color specified in the option. node { calculator: "RecolorCalculator" diff --git a/mediapipe/docs/hand_detection_mobile_gpu.md b/mediapipe/docs/hand_detection_mobile_gpu.md new file mode 100644 index 000000000..2dcd4df70 --- /dev/null +++ b/mediapipe/docs/hand_detection_mobile_gpu.md @@ -0,0 +1,326 @@ +# Hand Detection (GPU) + +This doc focuses on the +[example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_detection_gpu.pbtxt) +that performs hand detection with TensorFlow Lite on GPU. This hand detection +example is related to +[hand tracking GPU example](./hand_tracking_mobile_gpu.md). Here is the +[model card](https://mediapipe.page.link/handmc) for hand detection. + +For overall context on hand detection and hand tracking, please read +[this Google AI blog post](https://mediapipe.page.link/handgoogleaiblog). + +![hand_detection_android_gpu_gif](images/mobile/hand_detection_android_gpu.gif){width="300"} + +## Android + +Please see [Hello World! in MediaPipe on Android](hello_world_android.md) for +general instructions to develop an Android application that uses MediaPipe. 
+ +The graph is used in the +[Hand Detection GPU](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handdetectiongpu) +example app. To build the app, run: + +```bash +bazel build -c opt --config=android_arm64 mediapipe/examples/android/src/java/com/google/mediapipe/apps/handdetectiongpu +``` + +To further install the app on android device, run: + +```bash +adb install bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handdetectiongpu/handdetectiongpu.apk +``` + +## iOS + +Please see [Hello World! in MediaPipe on iOS](hello_world_ios.md) for general +instructions to develop an iOS application that uses MediaPipe. The graph below +is used in the +[Hand Detection GPU iOS example app](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/handdetectiongpu) + +To build the iOS app, please see the general +[MediaPipe iOS app building and setup instructions](./mediapipe_ios_setup.md). +Specifically, run: + +```bash +bazel build -c opt --config=ios_arm64 mediapipe/examples/ios/handdetectiongpu:HandDetectionGpuApp +``` + +## Graph + +The hand detection graph is +[hand_detection_mobile.pbtxt](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_detection_mobile.pbtxt) +and it includes a [HandDetectionSubgraph](./framework_concepts.md#subgraph) with +filename +[hand_detection_gpu.pbtxt](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_detection_gpu.pbtxt) +shown as a box called `HandDetection` in purple + +For more information on how to visualize a graph that includes subgraphs, see +[subgraph documentation](./visualizer.md#visualizing-subgraphs) for Visualizer. + +![hand_detection_mobile_graph](images/mobile/hand_detection_mobile.png){width="500"} + +```bash +# MediaPipe graph that performs hand detection with TensorFlow Lite on GPU. +# Used in the example in +# mediapipie/examples/android/src/java/com/mediapipe/apps/handdetectiongpu. +# mediapipie/examples/ios/handdetectiongpu. + +# Images coming into and out of the graph. +input_stream: "input_video" +output_stream: "output_video" + +node { + calculator: "FlowLimiterCalculator" + input_stream: "input_video" + input_stream: "FINISHED:hand_rect_from_palm_detections" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_input_video" +} + +node { + calculator: "HandDetectionSubgraph" + input_stream: "throttled_input_video" + output_stream: "DETECTIONS:palm_detections" + output_stream: "NORM_RECT:hand_rect_from_palm_detections" +} + +# Converts detections to drawing primitives for annotation overlay. +node { + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTIONS:palm_detections" + output_stream: "RENDER_DATA:detection_render_data" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { + thickness: 4.0 + color { r: 0 g: 255 b: 0 } + } + } +} + +# Converts normalized rects to drawing primitives for annotation overlay. +node { + calculator: "RectToRenderDataCalculator" + input_stream: "NORM_RECT:hand_rect_from_palm_detections" + output_stream: "RENDER_DATA:rect_render_data" + node_options: { + [type.googleapis.com/mediapipe.RectToRenderDataCalculatorOptions] { + filled: false + color { r: 255 g: 0 b: 0 } + thickness: 4.0 + } + } +} + +# Draws annotations and overlays them on top of the input image into the graph. 
+node { + calculator: "AnnotationOverlayCalculator" + input_stream: "INPUT_FRAME_GPU:throttled_input_video" + input_stream: "detection_render_data" + input_stream: "rect_render_data" + output_stream: "OUTPUT_FRAME_GPU:output_video" +} +``` + +![hand_detection_gpu_subgraph](images/mobile/hand_detection_gpu_subgraph.png){width="500"} + +```bash +type: "HandDetectionSubgraph" + +input_stream: "input_video" +output_stream: "DETECTIONS:palm_detections" +output_stream: "NORM_RECT:hand_rect_from_palm_detections" + +# Transforms the input image on GPU to a 256x256 image. To scale the input +# image, the scale_mode option is set to FIT to preserve the aspect ratio, +# resulting in potential letterboxing in the transformed image. +node: { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE_GPU:input_video" + output_stream: "IMAGE_GPU:transformed_input_video" + output_stream: "LETTERBOX_PADDING:letterbox_padding" + node_options: { + [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { + output_width: 256 + output_height: 256 + scale_mode: FIT + } + } +} + +# Generates a single side packet containing a TensorFlow Lite op resolver that +# supports custom ops needed by the model used in this graph. +node { + calculator: "TfLiteCustomOpResolverCalculator" + output_side_packet: "opresolver" + node_options: { + [type.googleapis.com/mediapipe.TfLiteCustomOpResolverCalculatorOptions] { + use_gpu: true + } + } +} + +# Converts the transformed input image on GPU into an image tensor stored as a +# TfLiteTensor. +node { + calculator: "TfLiteConverterCalculator" + input_stream: "IMAGE_GPU:transformed_input_video" + output_stream: "TENSORS_GPU:image_tensor" +} + +# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a +# vector of tensors representing, for instance, detection boxes/keypoints and +# scores. +node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS_GPU:image_tensor" + output_stream: "TENSORS:detection_tensors" + input_side_packet: "CUSTOM_OP_RESOLVER:opresolver" + node_options: { + [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { + model_path: "palm_detection.tflite" + use_gpu: true + } + } +} + +# Generates a single side packet containing a vector of SSD anchors based on +# the specification in the options. +node { + calculator: "SsdAnchorsCalculator" + output_side_packet: "anchors" + node_options: { + [type.googleapis.com/mediapipe.SsdAnchorsCalculatorOptions] { + num_layers: 5 + min_scale: 0.1171875 + max_scale: 0.75 + input_size_height: 256 + input_size_width: 256 + anchor_offset_x: 0.5 + anchor_offset_y: 0.5 + strides: 8 + strides: 16 + strides: 32 + strides: 32 + strides: 32 + aspect_ratios: 1.0 + fixed_anchor_size: true + } + } +} + +# Decodes the detection tensors generated by the TensorFlow Lite model, based on +# the SSD anchors and the specification in the options, into a vector of +# detections. Each detection describes a detected object. 
+node { + calculator: "TfLiteTensorsToDetectionsCalculator" + input_stream: "TENSORS:detection_tensors" + input_side_packet: "ANCHORS:anchors" + output_stream: "DETECTIONS:detections" + node_options: { + [type.googleapis.com/mediapipe.TfLiteTensorsToDetectionsCalculatorOptions] { + num_classes: 1 + num_boxes: 2944 + num_coords: 18 + box_coord_offset: 0 + keypoint_coord_offset: 4 + num_keypoints: 7 + num_values_per_keypoint: 2 + sigmoid_score: true + score_clipping_thresh: 100.0 + reverse_output_order: true + + x_scale: 256.0 + y_scale: 256.0 + h_scale: 256.0 + w_scale: 256.0 + min_score_thresh: 0.7 + } + } +} + +# Performs non-max suppression to remove excessive detections. +node { + calculator: "NonMaxSuppressionCalculator" + input_stream: "detections" + output_stream: "filtered_detections" + node_options: { + [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] { + min_suppression_threshold: 0.3 + overlap_type: INTERSECTION_OVER_UNION + algorithm: WEIGHTED + return_empty_detections: true + } + } +} + +# Maps detection label IDs to the corresponding label text. The label map is +# provided in the label_map_path option. +node { + calculator: "DetectionLabelIdToTextCalculator" + input_stream: "filtered_detections" + output_stream: "labeled_detections" + node_options: { + [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { + label_map_path: "palm_detection_labelmap.txt" + } + } +} + +# Adjusts detection locations (already normalized to [0.f, 1.f]) on the +# letterboxed image (after image transformation with the FIT scale mode) to the +# corresponding locations on the same image with the letterbox removed (the +# input image to the graph before image transformation). +node { + calculator: "DetectionLetterboxRemovalCalculator" + input_stream: "DETECTIONS:labeled_detections" + input_stream: "LETTERBOX_PADDING:letterbox_padding" + output_stream: "DETECTIONS:palm_detections" +} + +# Extracts image size from the input images. +node { + calculator: "ImagePropertiesCalculator" + input_stream: "IMAGE_GPU:input_video" + output_stream: "SIZE:image_size" +} + +# Converts results of palm detection into a rectangle (normalized by image size) +# that encloses the palm and is rotated such that the line connecting center of +# the wrist and MCP of the middle finger is aligned with the Y-axis of the +# rectangle. +node { + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTIONS:palm_detections" + input_stream: "IMAGE_SIZE:image_size" + output_stream: "NORM_RECT:palm_rect" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRectsCalculatorOptions] { + rotation_vector_start_keypoint_index: 0 # Center of wrist. + rotation_vector_end_keypoint_index: 2 # MCP of middle finger. + rotation_vector_target_angle_degrees: 90 + output_zero_rect_for_empty_detections: true + } + } +} + +# Expands and shifts the rectangle that contains the palm so that it's likely +# to cover the entire hand. 
+node { + calculator: "RectTransformationCalculator" + input_stream: "NORM_RECT:palm_rect" + input_stream: "IMAGE_SIZE:image_size" + output_stream: "hand_rect_from_palm_detections" + node_options: { + [type.googleapis.com/mediapipe.RectTransformationCalculatorOptions] { + scale_x: 2.6 + scale_y: 2.6 + shift_y: -0.5 + square_long: true + } + } +} +``` diff --git a/mediapipe/docs/hand_tracking_mobile_gpu.md b/mediapipe/docs/hand_tracking_mobile_gpu.md new file mode 100644 index 000000000..ec6d833e4 --- /dev/null +++ b/mediapipe/docs/hand_tracking_mobile_gpu.md @@ -0,0 +1,640 @@ +# Hand Tracking (GPU) + +This doc focuses on the +[example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_detection_android_gpu.pbtxt) +that performs hand tracking with TensorFlow Lite on GPU. This hand tracking +example is related to +[hand detection GPU example](./hand_detection_mobile_gpu.md). We recommend users +to review the hand detection GPU example first. Here is the +[model card](https://mediapipe.page.link/handmc) for hand tracking. + +For overall context on hand detection and hand tracking, please read +[this Google AI blog post](https://mediapipe.page.link/handgoogleaiblog). + +![hand_tracking_android_gpu.gif](images/mobile/hand_tracking_android_gpu.gif){width="300"} + +## Android + +Please see [Hello World! in MediaPipe on Android](hello_world_android.md) for +general instructions to develop an Android application that uses MediaPipe. + +The graph is used in the +[Hand Tracking GPU](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu) +example app. To build the app, run: + +```bash +bazel build -c opt --config=android_arm64 mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu +``` + +To further install the app on android device, run: + +```bash +adb install bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/handtrackinggpu/handtrackinggpu.apk +``` + +## iOS + +Please see [Hello World! in MediaPipe on iOS](hello_world_ios.md) for general +instructions to develop an iOS application that uses MediaPipe. The graph below +is used in the +[Hand Tracking GPU iOS example app](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/handtrackinggpu) + +To build the iOS app, please see the general +[MediaPipe iOS app building and setup instructions](./mediapipe_ios_setup.md). +Specifically, run: + +```bash +bazel build -c opt --config=ios_arm64 mediapipe/examples/ios/handtrackinggpu:HandTrackingGpuApp +``` + +## Graph + +For more information on how to visualize a graph that includes subgraphs, see +[subgraph documentation](./visualizer.md#visualizing-subgraphs) for Visualizer. 
+ +The hand tracking graph is +[hand_tracking_mobile.pbtxt](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt) +and it includes 3 [subgraphs](./framework_concepts.md#subgraph): + +* [HandDetectionSubgraph - hand_detection_gpu.pbtxt](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_detection_gpu.pbtxt) + +* [HandLandmarkSubgraph - hand_landmark_gpu.pbtxt](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_landmark_gpu.pbtxt) + +* [RendererSubgraph - renderer_gpu.pbtxt](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/renderer_gpu.pbtxt) + +![hand_tracking_mobile_graph](images/mobile/hand_tracking_mobile.png){width="400"} + +```bash +# MediaPipe graph that performs hand tracking with TensorFlow Lite on GPU. +# Used in the example in +# mediapipie/examples/android/src/java/com/mediapipe/apps/handtrackinggpu. + +# Images coming into and out of the graph. +input_stream: "input_video" +output_stream: "output_video" + +node { + calculator: "FlowLimiterCalculator" + input_stream: "input_video" + input_stream: "FINISHED:hand_rect" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_input_video" +} + +node { + calculator: "PreviousLoopbackCalculator" + input_stream: "MAIN:throttled_input_video" + input_stream: "LOOP:hand_presence" + input_stream_info: { + tag_index: "LOOP" + back_edge: true + } + output_stream: "PREV_LOOP:prev_hand_presence" +} + +node { + calculator: "GateCalculator" + input_stream: "throttled_input_video" + input_stream: "DISALLOW:prev_hand_presence" + output_stream: "hand_detection_input_video" + + node_options: { + [type.googleapis.com/mediapipe.GateCalculatorOptions] { + empty_packets_as_allow: true + } + } +} + +node { + calculator: "HandDetectionSubgraph" + input_stream: "hand_detection_input_video" + output_stream: "DETECTIONS:palm_detections" + output_stream: "NORM_RECT:hand_rect_from_palm_detections" +} + +node { + calculator: "HandLandmarkSubgraph" + input_stream: "IMAGE:throttled_input_video" + input_stream: "NORM_RECT:hand_rect" + output_stream: "LANDMARKS:hand_landmarks" + output_stream: "NORM_RECT:hand_rect_from_landmarks" + output_stream: "PRESENCE:hand_presence" +} + +node { + calculator: "PreviousLoopbackCalculator" + input_stream: "MAIN:throttled_input_video" + input_stream: "LOOP:hand_rect_from_landmarks" + input_stream_info: { + tag_index: "LOOP" + back_edge: true + } + output_stream: "PREV_LOOP:prev_hand_rect_from_landmarks" +} + +node { + calculator: "MergeCalculator" + input_stream: "hand_rect_from_palm_detections" + input_stream: "prev_hand_rect_from_landmarks" + output_stream: "hand_rect" +} + +node { + calculator: "RendererSubgraph" + input_stream: "IMAGE:throttled_input_video" + input_stream: "LANDMARKS:hand_landmarks" + input_stream: "NORM_RECT:hand_rect" + input_stream: "DETECTIONS:palm_detections" + output_stream: "IMAGE:output_video" +} +``` + +![hand_detection_gpu_subgraph](images/mobile/hand_detection_gpu_subgraph.png){width="500"} + +```bash +type: "HandDetectionSubgraph" + +input_stream: "input_video" +output_stream: "DETECTIONS:palm_detections" +output_stream: "NORM_RECT:hand_rect_from_palm_detections" + +# Transforms the input image on GPU to a 256x256 image. To scale the input +# image, the scale_mode option is set to FIT to preserve the aspect ratio, +# resulting in potential letterboxing in the transformed image. 
+node: { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE_GPU:input_video" + output_stream: "IMAGE_GPU:transformed_input_video" + output_stream: "LETTERBOX_PADDING:letterbox_padding" + node_options: { + [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { + output_width: 256 + output_height: 256 + scale_mode: FIT + } + } +} + +# Generates a single side packet containing a TensorFlow Lite op resolver that +# supports custom ops needed by the model used in this graph. +node { + calculator: "TfLiteCustomOpResolverCalculator" + output_side_packet: "opresolver" + node_options: { + [type.googleapis.com/mediapipe.TfLiteCustomOpResolverCalculatorOptions] { + use_gpu: true + } + } +} + +# Converts the transformed input image on GPU into an image tensor stored as a +# TfLiteTensor. +node { + calculator: "TfLiteConverterCalculator" + input_stream: "IMAGE_GPU:transformed_input_video" + output_stream: "TENSORS_GPU:image_tensor" +} + +# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a +# vector of tensors representing, for instance, detection boxes/keypoints and +# scores. +node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS_GPU:image_tensor" + output_stream: "TENSORS:detection_tensors" + input_side_packet: "CUSTOM_OP_RESOLVER:opresolver" + node_options: { + [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { + model_path: "palm_detection.tflite" + use_gpu: true + } + } +} + +# Generates a single side packet containing a vector of SSD anchors based on +# the specification in the options. +node { + calculator: "SsdAnchorsCalculator" + output_side_packet: "anchors" + node_options: { + [type.googleapis.com/mediapipe.SsdAnchorsCalculatorOptions] { + num_layers: 5 + min_scale: 0.1171875 + max_scale: 0.75 + input_size_height: 256 + input_size_width: 256 + anchor_offset_x: 0.5 + anchor_offset_y: 0.5 + strides: 8 + strides: 16 + strides: 32 + strides: 32 + strides: 32 + aspect_ratios: 1.0 + fixed_anchor_size: true + } + } +} + +# Decodes the detection tensors generated by the TensorFlow Lite model, based on +# the SSD anchors and the specification in the options, into a vector of +# detections. Each detection describes a detected object. +node { + calculator: "TfLiteTensorsToDetectionsCalculator" + input_stream: "TENSORS:detection_tensors" + input_side_packet: "ANCHORS:anchors" + output_stream: "DETECTIONS:detections" + node_options: { + [type.googleapis.com/mediapipe.TfLiteTensorsToDetectionsCalculatorOptions] { + num_classes: 1 + num_boxes: 2944 + num_coords: 18 + box_coord_offset: 0 + keypoint_coord_offset: 4 + num_keypoints: 7 + num_values_per_keypoint: 2 + sigmoid_score: true + score_clipping_thresh: 100.0 + reverse_output_order: true + + x_scale: 256.0 + y_scale: 256.0 + h_scale: 256.0 + w_scale: 256.0 + min_score_thresh: 0.7 + } + } +} + +# Performs non-max suppression to remove excessive detections. +node { + calculator: "NonMaxSuppressionCalculator" + input_stream: "detections" + output_stream: "filtered_detections" + node_options: { + [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] { + min_suppression_threshold: 0.3 + overlap_type: INTERSECTION_OVER_UNION + algorithm: WEIGHTED + return_empty_detections: true + } + } +} + +# Maps detection label IDs to the corresponding label text. The label map is +# provided in the label_map_path option. 
+node { + calculator: "DetectionLabelIdToTextCalculator" + input_stream: "filtered_detections" + output_stream: "labeled_detections" + node_options: { + [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { + label_map_path: "palm_detection_labelmap.txt" + } + } +} + +# Adjusts detection locations (already normalized to [0.f, 1.f]) on the +# letterboxed image (after image transformation with the FIT scale mode) to the +# corresponding locations on the same image with the letterbox removed (the +# input image to the graph before image transformation). +node { + calculator: "DetectionLetterboxRemovalCalculator" + input_stream: "DETECTIONS:labeled_detections" + input_stream: "LETTERBOX_PADDING:letterbox_padding" + output_stream: "DETECTIONS:palm_detections" +} + +# Extracts image size from the input images. +node { + calculator: "ImagePropertiesCalculator" + input_stream: "IMAGE_GPU:input_video" + output_stream: "SIZE:image_size" +} + +# Converts results of palm detection into a rectangle (normalized by image size) +# that encloses the palm and is rotated such that the line connecting center of +# the wrist and MCP of the middle finger is aligned with the Y-axis of the +# rectangle. +node { + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTIONS:palm_detections" + input_stream: "IMAGE_SIZE:image_size" + output_stream: "NORM_RECT:palm_rect" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRectsCalculatorOptions] { + rotation_vector_start_keypoint_index: 0 # Center of wrist. + rotation_vector_end_keypoint_index: 2 # MCP of middle finger. + rotation_vector_target_angle_degrees: 90 + output_zero_rect_for_empty_detections: true + } + } +} + +# Expands and shifts the rectangle that contains the palm so that it's likely +# to cover the entire hand. +node { + calculator: "RectTransformationCalculator" + input_stream: "NORM_RECT:palm_rect" + input_stream: "IMAGE_SIZE:image_size" + output_stream: "hand_rect_from_palm_detections" + node_options: { + [type.googleapis.com/mediapipe.RectTransformationCalculatorOptions] { + scale_x: 2.6 + scale_y: 2.6 + shift_y: -0.5 + square_long: true + } + } +} +``` + +![hand_landmark_gpu_subgraph.pbtxt](images/mobile/hand_landmark_gpu_subgraph.png){width="400"} + +```bash +# MediaPipe hand landmark localization subgraph. + +type: "HandLandmarkSubgraph" + +input_stream: "IMAGE:input_video" +input_stream: "NORM_RECT:hand_rect" +output_stream: "LANDMARKS:hand_landmarks" +output_stream: "NORM_RECT:hand_rect_for_next_frame" +output_stream: "PRESENCE:hand_presence" + +# Crops the rectangle that contains a hand from the input image. +node { + calculator: "ImageCroppingCalculator" + input_stream: "IMAGE_GPU:input_video" + input_stream: "NORM_RECT:hand_rect" + output_stream: "IMAGE_GPU:hand_image" +} + +# Transforms the input image on GPU to a 256x256 image. To scale the input +# image, the scale_mode option is set to FIT to preserve the aspect ratio, +# resulting in potential letterboxing in the transformed image. +node: { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE_GPU:hand_image" + output_stream: "IMAGE_GPU:transformed_hand_image" + output_stream: "LETTERBOX_PADDING:letterbox_padding" + node_options: { + [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { + output_width: 256 + output_height: 256 + scale_mode: FIT + } + } +} + +# Converts the transformed input image on GPU into an image tensor stored as a +# TfLiteTensor. 
+node { + calculator: "TfLiteConverterCalculator" + input_stream: "IMAGE_GPU:transformed_hand_image" + output_stream: "TENSORS_GPU:image_tensor" +} + +# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a +# vector of tensors representing, for instance, detection boxes/keypoints and +# scores. +node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS_GPU:image_tensor" + output_stream: "TENSORS:output_tensors" + node_options: { + [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { + model_path: "hand_landmark.tflite" + use_gpu: true + } + } +} + +# Splits a vector of tensors into multiple vectors. +node { + calculator: "SplitTfLiteTensorVectorCalculator" + input_stream: "output_tensors" + output_stream: "landmark_tensors" + output_stream: "hand_flag_tensor" + node_options: { + [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { + ranges: { begin: 0 end: 1 } + ranges: { begin: 1 end: 2 } + } + } +} + +# Converts the hand-flag tensor into a float that represents the confidence +# score of hand presence. +node { + calculator: "TfLiteTensorsToFloatsCalculator" + input_stream: "TENSORS:hand_flag_tensor" + output_stream: "FLOAT:hand_presence_score" +} + +# Applies a threshold to the confidence score to determine whether a hand is +# present. +node { + calculator: "ThresholdingCalculator" + input_stream: "FLOAT:hand_presence_score" + output_stream: "FLAG:hand_presence" + node_options: { + [type.googleapis.com/mediapipe.ThresholdingCalculatorOptions] { + threshold: 0.1 + } + } +} + +# Decodes the landmark tensors into a vector of lanmarks, where the landmark +# coordinates are normalized by the size of the input image to the model. +node { + calculator: "TfLiteTensorsToLandmarksCalculator" + input_stream: "TENSORS:landmark_tensors" + output_stream: "NORM_LANDMARKS:landmarks" + node_options: { + [type.googleapis.com/mediapipe.TfLiteTensorsToLandmarksCalculatorOptions] { + num_landmarks: 21 + input_image_width: 256 + input_image_height: 256 + } + } +} + +# Adjusts landmarks (already normalized to [0.f, 1.f]) on the letterboxed hand +# image (after image transformation with the FIT scale mode) to the +# corresponding locations on the same image with the letterbox removed (hand +# image before image transformation). +node { + calculator: "LandmarkLetterboxRemovalCalculator" + input_stream: "LANDMARKS:landmarks" + input_stream: "LETTERBOX_PADDING:letterbox_padding" + output_stream: "LANDMARKS:scaled_landmarks" +} + +# Projects the landmarks from the cropped hand image to the corresponding +# locations on the full image before cropping (input to the graph). +node { + calculator: "LandmarkProjectionCalculator" + input_stream: "NORM_LANDMARKS:scaled_landmarks" + input_stream: "NORM_RECT:hand_rect" + output_stream: "NORM_LANDMARKS:hand_landmarks" +} + +# Extracts image size from the input images. +node { + calculator: "ImagePropertiesCalculator" + input_stream: "IMAGE_GPU:input_video" + output_stream: "SIZE:image_size" +} + +# Converts hand landmarks to a detection that tightly encloses all landmarks. +node { + calculator: "LandmarksToDetectionCalculator" + input_stream: "NORM_LANDMARKS:hand_landmarks" + output_stream: "DETECTION:hand_detection" +} + +# Converts the hand detection into a rectangle (normalized by image size) +# that encloses the hand and is rotated such that the line connecting center of +# the wrist and MCP of the middle finger is aligned with the Y-axis of the +# rectangle. 
+node { + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTION:hand_detection" + input_stream: "IMAGE_SIZE:image_size" + output_stream: "NORM_RECT:hand_rect_from_landmarks" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRectsCalculatorOptions] { + rotation_vector_start_keypoint_index: 0 # Center of wrist. + rotation_vector_end_keypoint_index: 9 # MCP of middle finger. + rotation_vector_target_angle_degrees: 90 + } + } +} + +# Expands the hand rectangle so that in the next video frame it's likely to +# still contain the hand even with some motion. +node { + calculator: "RectTransformationCalculator" + input_stream: "NORM_RECT:hand_rect_from_landmarks" + input_stream: "IMAGE_SIZE:image_size" + output_stream: "hand_rect_for_next_frame" + node_options: { + [type.googleapis.com/mediapipe.RectTransformationCalculatorOptions] { + scale_x: 1.6 + scale_y: 1.6 + square_long: true + } + } +} +``` + +![hand_renderer_gpu_subgraph.pbtxt](images/mobile/hand_renderer_gpu_subgraph.png){width="500"} + +```bash +# MediaPipe hand tracking rendering subgraph. + +type: "RendererSubgraph" + +input_stream: "IMAGE:input_image" +input_stream: "DETECTIONS:detections" +input_stream: "LANDMARKS:landmarks" +input_stream: "NORM_RECT:rect" +output_stream: "IMAGE:output_image" + +# Converts detections to drawing primitives for annotation overlay. +node { + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTIONS:detections" + output_stream: "RENDER_DATA:detection_render_data" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { + thickness: 4.0 + color { r: 0 g: 255 b: 0 } + } + } +} + +# Converts landmarks to drawing primitives for annotation overlay. +node { + calculator: "LandmarksToRenderDataCalculator" + input_stream: "NORM_LANDMARKS:landmarks" + output_stream: "RENDER_DATA:landmark_render_data" + node_options: { + [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { + landmark_connections: 0 + landmark_connections: 1 + landmark_connections: 1 + landmark_connections: 2 + landmark_connections: 2 + landmark_connections: 3 + landmark_connections: 3 + landmark_connections: 4 + landmark_connections: 0 + landmark_connections: 5 + landmark_connections: 5 + landmark_connections: 6 + landmark_connections: 6 + landmark_connections: 7 + landmark_connections: 7 + landmark_connections: 8 + landmark_connections: 5 + landmark_connections: 9 + landmark_connections: 9 + landmark_connections: 10 + landmark_connections: 10 + landmark_connections: 11 + landmark_connections: 11 + landmark_connections: 12 + landmark_connections: 9 + landmark_connections: 13 + landmark_connections: 13 + landmark_connections: 14 + landmark_connections: 14 + landmark_connections: 15 + landmark_connections: 15 + landmark_connections: 16 + landmark_connections: 13 + landmark_connections: 17 + landmark_connections: 0 + landmark_connections: 17 + landmark_connections: 17 + landmark_connections: 18 + landmark_connections: 18 + landmark_connections: 19 + landmark_connections: 19 + landmark_connections: 20 + landmark_color { r: 255 g: 0 b: 0 } + connection_color { r: 0 g: 255 b: 0 } + thickness: 5.0 + } + } +} + +# Converts normalized rects to drawing primitives for annotation overlay. 
+node { + calculator: "RectToRenderDataCalculator" + input_stream: "NORM_RECT:rect" + output_stream: "RENDER_DATA:rect_render_data" + node_options: { + [type.googleapis.com/mediapipe.RectToRenderDataCalculatorOptions] { + filled: false + color { r: 255 g: 0 b: 0 } + thickness: 4.0 + } + } +} + +# Draws annotations and overlays them on top of the input images. +node { + calculator: "AnnotationOverlayCalculator" + input_stream: "INPUT_FRAME_GPU:input_image" + input_stream: "detection_render_data" + input_stream: "landmark_render_data" + input_stream: "rect_render_data" + output_stream: "OUTPUT_FRAME_GPU:output_image" +} +``` diff --git a/mediapipe/docs/hello_world_android.md b/mediapipe/docs/hello_world_android.md index 6b0f7c69a..ef78f92ed 100644 --- a/mediapipe/docs/hello_world_android.md +++ b/mediapipe/docs/hello_world_android.md @@ -27,18 +27,26 @@ stream on an Android device. ## Graph for edge detection -We will be using the following graph, [`edge_detection_android_gpu.pbtxt`]: +We will be using the following graph, [`edge_detection_mobile_gpu.pbtxt`]: ``` +# MediaPipe graph that performs GPU Sobel edge detection on a live video stream. +# Used in the examples +# mediapipe/examples/android/src/java/com/mediapipe/apps/edgedetectiongpu. +# mediapipe/examples/ios/edgedetectiongpu. + +# Images coming into and out of the graph. input_stream: "input_video" output_stream: "output_video" +# Converts RGB images into luminance images, still stored in RGB format. node: { calculator: "LuminanceCalculator" input_stream: "input_video" output_stream: "luma_video" } +# Applies the Sobel filter to luminance images sotred in RGB format. node: { calculator: "SobelEdgesCalculator" input_stream: "luma_video" @@ -48,7 +56,7 @@ node: { A visualization of the graph is shown below: -![edge_detection_android_gpu_graph](images/mobile/edge_detection_android_graph_gpu.png){width="200"} +![edge_detection_mobile_gpu_graph](images/mobile/edge_detection_mobile_graph_gpu.png){width="200"} This graph has a single input stream named `input_video` for all incoming frames that will be provided by your device's camera. @@ -62,7 +70,7 @@ packets in the `luma_video` stream and outputs results in `output_video` output stream. Our Android application will display the output image frames of the -`sobel_video` stream. +`output_video` stream. 
## Initial minimal application setup @@ -582,7 +590,7 @@ First, add dependencies to all calculator code in the `libmediapipe_jni.so` build rule: ``` -"//mediapipe/graphs/edge_detection:android_calculators", +"//mediapipe/graphs/edge_detection:mobile_calculators", ``` MediaPipe graphs are `.pbtxt` files, but to use them in the application, we need @@ -594,7 +602,7 @@ graph: ``` genrule( name = "binary_graph", - srcs = ["//mediapipe/graphs/edge_detection:android_gpu_binary_graph"], + srcs = ["//mediapipe/graphs/edge_detection:mobile_gpu_binary_graph"], outs = ["edgedetectiongpu.binarypb"], cmd = "cp $< $@", ) @@ -712,7 +720,7 @@ If you ran into any issues, please see the full code of the tutorial [CameraX]:https://developer.android.com/training/camerax [`CameraXPreviewHelper`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/CameraXPreviewHelper.java [developer options]:https://developer.android.com/studio/debug/dev-options -[`edge_detection_android_gpu.pbtxt`]:https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_android_gpu.pbtxt +[`edge_detection_mobile_gpu.pbtxt`]:https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_mobile_gpu.pbtxt [`EdgeDetectionGPU` example]:https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/ [`EglManager`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/glutil/EglManager.java [`ExternalTextureConverter`]:https://github.com/google/mediapipe/tree/master/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java diff --git a/mediapipe/docs/hello_world_desktop.md b/mediapipe/docs/hello_world_desktop.md index b31994d2a..9f3617dfe 100644 --- a/mediapipe/docs/hello_world_desktop.md +++ b/mediapipe/docs/hello_world_desktop.md @@ -11,7 +11,7 @@ $ export GLOG_logtostderr=1 # Need bazel flag 'MEDIAPIPE_DISABLE_GPU=1' as desktop GPU is not supported currently. - $ bazel run --define 'MEDIAPIPE_DISABLE_GPU=1' \ + $ bazel run --define MEDIAPIPE_DISABLE_GPU=1 \ mediapipe/examples/desktop/hello_world:hello_world # It should print 10 rows of Hello World! diff --git a/mediapipe/docs/hello_world_ios.md b/mediapipe/docs/hello_world_ios.md new file mode 100644 index 000000000..a8a8791e7 --- /dev/null +++ b/mediapipe/docs/hello_world_ios.md @@ -0,0 +1,545 @@ +# Hello World! in MediaPipe on iOS + +## Introduction + +This codelab uses MediaPipe on an iOS device. + +### What you will learn + +How to develop an Android application that uses MediaPipe and run a MediaPipe +graph on iOS. + +### What you will build + +A simple camera app for real-time Sobel edge detection applied to a live video +stream on an iOS device. + +![edge_detection_ios_gpu_gif](images/mobile/edge_detection_ios_gpu.gif){width="300"} + +## Setup + +1. Install MediaPipe on your system, see [MediaPipe installation guide] for + details. +2. Setup your iOS device for development. +3. Setup [Bazel] on your system to build and deploy the iOS app. + +## Graph for edge detection + +We will be using the following graph, [`edge_detection_mobile_gpu.pbtxt`]: + +``` +# MediaPipe graph that performs GPU Sobel edge detection on a live video stream. +# Used in the examples +# mediapipe/examples/android/src/java/com/mediapipe/apps/edgedetectiongpu. +# mediapipe/examples/ios/edgedetectiongpu. + +# Images coming into and out of the graph. 
+input_stream: "input_video"
+output_stream: "output_video"
+
+# Converts RGB images into luminance images, still stored in RGB format.
+node: {
+  calculator: "LuminanceCalculator"
+  input_stream: "input_video"
+  output_stream: "luma_video"
+}
+
+# Applies the Sobel filter to luminance images stored in RGB format.
+node: {
+  calculator: "SobelEdgesCalculator"
+  input_stream: "luma_video"
+  output_stream: "output_video"
+}
+```
+
+A visualization of the graph is shown below:
+
+![edge_detection_mobile_gpu_graph](images/mobile/edge_detection_mobile_graph_gpu.png){width="200"}
+
+This graph has a single input stream named `input_video` for all incoming frames
+that will be provided by your device's camera.
+
+The first node in the graph, `LuminanceCalculator`, takes a single packet (image
+frame) and applies a change in luminance using an OpenGL shader. The resulting
+image frame is sent to the `luma_video` output stream.
+
+The second node, `SobelEdgesCalculator`, applies edge detection to incoming
+packets in the `luma_video` stream and outputs results in the `output_video`
+output stream.
+
+Our iOS application will display the output image frames of the `output_video`
+stream.
+
+## Initial minimal application setup
+
+We first start with a simple iOS application and demonstrate how to use `bazel`
+to build it.
+
+First, create an Xcode project via File > New > Single View App.
+
+Set the product name to "EdgeDetectionGpu", and use an appropriate organization
+identifier, such as `com.google.mediapipe`. The organization identifier
+along with the product name will be the `bundle_id` for the application, such as
+`com.google.mediapipe.EdgeDetectionGpu`.
+
+Set the language to Objective-C.
+
+Save the project to an appropriate location. Let's call this
+`$PROJECT_TEMPLATE_LOC`. So your project will be in the
+`$PROJECT_TEMPLATE_LOC/EdgeDetectionGpu` directory. This directory will contain
+another directory named `EdgeDetectionGpu` and an `EdgeDetectionGpu.xcodeproj` file.
+
+The `EdgeDetectionGpu.xcodeproj` will not be useful for this tutorial, as we will
+use bazel to build the iOS application. The content of the
+`$PROJECT_TEMPLATE_LOC/EdgeDetectionGpu/EdgeDetectionGpu` directory is listed below:
+
+1. `AppDelegate.h` and `AppDelegate.m`
+2. `ViewController.h` and `ViewController.m`
+3. `main.m`
+4. `Info.plist`
+5. `Main.storyboard` and `Launch.storyboard`
+6. `Assets.xcassets` directory.
+
+Copy these files to a directory named `EdgeDetectionGpu`, at a location that can
+access the MediaPipe source code. For example, the source code of the
+application that we will build in this tutorial is located in
+`mediapipe/examples/ios/EdgeDetectionGpu`. We will refer to this path as the
+`$APPLICATION_PATH` throughout the codelab.
+
+Note: MediaPipe provides Objective-C bindings for iOS. The edge detection
+application in this tutorial and all iOS examples using MediaPipe use
+Objective-C with C++ in `.mm` files.
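+
+For reference, the copy step described above might look like the following
+sketch (the exact commands are illustrative only; run them from the top of your
+MediaPipe checkout, with `$PROJECT_TEMPLATE_LOC` set to wherever you saved the
+Xcode template):
+
+```
+# Create the application directory inside the MediaPipe tree.
+mkdir -p mediapipe/examples/ios/EdgeDetectionGpu
+# Copy the template files produced by Xcode into it.
+cp -r "$PROJECT_TEMPLATE_LOC"/EdgeDetectionGpu/EdgeDetectionGpu/* \
+    mediapipe/examples/ios/EdgeDetectionGpu/
+```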
+ +Create a `BUILD` file in the `$APPLICATION_PATH` and add the following build +rules: + +``` +MIN_IOS_VERSION = "10.0" + +load( + "@build_bazel_rules_apple//apple:ios.bzl", + "ios_application", +) + +ios_application( + name = "EdgeDetectionGpuApp", + bundle_id = "com.google.mediapipe.EdgeDetectionGpu", + families = [ + "iphone", + "ipad", + ], + infoplists = ["Info.plist"], + minimum_os_version = MIN_IOS_VERSION, + provisioning_profile = "//mediapipe/examples/ios:developer_provisioning_profile", + deps = [":EdgeDetectionGpuAppLibrary"], +) + +objc_library( + name = "EdgeDetectionGpuAppLibrary", + srcs = [ + "AppDelegate.m", + "ViewController.m", + "main.m", + ], + hdrs = [ + "AppDelegate.h", + "ViewController.h", + ], + data = [ + "Base.lproj/LaunchScreen.storyboard", + "Base.lproj/Main.storyboard", + ], + sdk_frameworks = [ + "UIKit", + ], + deps = [], +) +``` + +The `objc_library` rule adds dependencies for the `AppDelegate` and +`ViewController` classes, `main.m` and the application storyboards. The +templated app depends only on the `UIKit` SDK. + +The `ios_application` rule uses the `EdgeDetectionGpuAppLibrary` Objective-C +library generated to build an iOS application for installation on your iOS +device. + +Note: You need to point to your own iOS developer provisioning profile to be +able to run the application on your iOS device. + +To build the app, use the following command in a terminal: + +``` +bazel build -c opt --config=ios_arm64 <$APPLICATION_PATH>:EdgeDetectionGpuApp' +``` + +For example, to build the `EdgeDetectionGpuApp` application in +`mediapipe/examples/ios/edgedetection`, use the following command: + +``` +bazel build -c opt --config=ios_arm64 mediapipe/examples/ios/edgedetection:EdgeDetectionGpuApp +``` + +Then, go back to XCode, open Window > Devices and Simulators, select your +device, and add the `.ipa` file generated by the command above to your device. + +Open the application on your device. Since it is empty, it should display a +blank white screen. + +## Use the camera for the live view feed + +In this tutorial, we will use the `MediaPipeCameraInputSource` class to access +and grab frames from the camera. This class uses the `AVCaptureSession` API to +get the frames from the camera. + +But before using this class, change the `Info.plist` file to support camera +usage in the app. + +In `ViewController.m`, add the following import line: + +``` +#import "mediapipe/objc/MediaPipeCameraInputSource.h" +``` + +Add the following to its implementation block to create an object +`_cameraSource`: + +``` +@implementation ViewController { + // Handles camera access via AVCaptureSession library. + MediaPipeCameraInputSource* _cameraSource; +} +``` + +Add the following code to `viewDidLoad()`: + +``` +-(void)viewDidLoad { + [super viewDidLoad]; + + _cameraSource = [[MediaPipeCameraInputSource alloc] init]; + _cameraSource.sessionPreset = AVCaptureSessionPresetHigh; + _cameraSource.cameraPosition = AVCaptureDevicePositionBack; + // The frame's native format is rotated with respect to the portrait orientation. + _cameraSource.orientation = AVCaptureVideoOrientationPortrait; +} +``` + +The code initializes `_cameraSource`, sets the capture session preset, and which +camera to use. + +We need to get frames from the `_cameraSource` into our application +`ViewController` to display them. `MediaPipeCameraInputSource` is a subclass of +`MediaPipeInputSource`, which provides a protocol for its delegates, namely the +`MediaPipeInputSourceDelegate`. 
So our application `ViewController` can be a +delegate of `_cameraSource`. + +To handle camera setup and process incoming frames, we should use a queue +different from the main queue. Add the following to the implementation block of +the `ViewController`: + +``` +// Process camera frames on this queue. +dispatch_queue_t _videoQueue; +``` + +In `viewDidLoad()`, add the following line after initializing the +`_cameraSource` object: + +``` +[_cameraSource setDelegate:self queue:_videoQueue]; +``` + +And add the following code to initialize the queue before setting up the +`_cameraSource` object: + +``` +dispatch_queue_attr_t qosAttribute = dispatch_queue_attr_make_with_qos_class( + DISPATCH_QUEUE_SERIAL, QOS_CLASS_USER_INTERACTIVE, /*relative_priority=*/0); +_videoQueue = dispatch_queue_create(kVideoQueueLabel, qosAttribute); +``` + +We will use a serial queue with the priority `QOS_CLASS_USER_INTERACTIVE` for +processing camera frames. + +Add the following line after the header imports at the top of the file, before +the interface/implementation of the `ViewController`: + +``` +static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue"; +``` + +Before implementing any method from `MediaPipeInputSourceDelegate` protocol, we +must first set up a way to display the camera frames. MediaPipe provides another +utility called `MediaPipeLayerRenderer` to display images on the screen. This +utility can be used to display `CVPixelBufferRef` objects, which is the type of +the images provided by `MediaPipeCameraInputSource` to its delegates. + +To display images of the screen, we need to add a new `UIView` object called +`_liveView` to the `ViewController`. + +Add the following lines to the implementation block of the `ViewController`: + +``` +// Display the camera preview frames. +IBOutlet UIView* _liveView; +// Render frames in a layer. +MediaPipeLayerRenderer* _renderer; +``` + +Go to `Main.storyboard`, add a `UIView` object from the object library to the +`View` of the `ViewController` class. Add a referencing outlet from this view to +the `_liveView` object you just added to the `ViewController` class. Resize the +view so that it is centered and covers the entire application screen. + +Go back to `ViewController.m` and add the following code to `viewDidLoad()` to +initialize the `_renderer` object: + +``` +_renderer = [[MediaPipeLayerRenderer alloc] init]; +_renderer.layer.frame = _liveView.layer.bounds; +[_liveView.layer addSublayer:_renderer.layer]; +_renderer.frameScaleMode = MediaPipeFrameScaleFillAndCrop; +``` + +To get frames from the camera, we will implement the following method: + +``` +// Must be invoked on _videoQueue. +- (void)processVideoFrame:(CVPixelBufferRef)imageBuffer + timestamp:(CMTime)timestamp + fromSource:(MediaPipeInputSource*)source { + if (source != _cameraSource) { + NSLog(@"Unknown source: %@", source); + return; + } + // Display the captured image on the screen. + CFRetain(imageBuffer); + dispatch_async(dispatch_get_main_queue(), ^{ + [_renderer renderPixelBuffer:imageBuffer]; + CFRelease(imageBuffer); + }); +} +``` + +This is a delegate method of `MediaPipeInputSource`. We first check that we are +getting frames from the right source, i.e. the `_cameraSource`. Then we display +the frame received from the camera via `_renderer` on the main queue. + +Now, we need to start the camera as soon as the view to display the frames is +about to appear. 
To do this, we will implement the
+`viewWillAppear:(BOOL)animated` function:
+
+```
+-(void)viewWillAppear:(BOOL)animated {
+  [super viewWillAppear:animated];
+}
+```
+
+Before we start running the camera, we need the user's permission to access it.
+`MediaPipeCameraInputSource` provides a function
+`requestCameraAccessWithCompletionHandler:(void (^_Nullable)(BOOL
+granted))handler` to request camera access and do some work when the user has
+responded. Add the following code to `viewWillAppear:animated`:
+
+```
+[_cameraSource requestCameraAccessWithCompletionHandler:^void(BOOL granted) {
+  if (granted) {
+    dispatch_async(_videoQueue, ^{
+      [_cameraSource start];
+    });
+  }
+}];
+```
+
+Before building the application, add the following dependencies to your `BUILD`
+file:
+
+```
+sdk_frameworks = [
+    "AVFoundation",
+    "CoreGraphics",
+    "CoreMedia",
+],
+deps = [
+    "//mediapipe/objc:mediapipe_framework_ios",
+    "//mediapipe/objc:mediapipe_input_sources_ios",
+    "//mediapipe/objc:mediapipe_layer_renderer",
+],
+```
+
+Now build and run the application on your iOS device. You should see a live
+camera view feed after accepting camera permissions.
+
+We are now ready to use camera frames in a MediaPipe graph.
+
+## Using a MediaPipe graph in iOS
+
+### Add relevant dependencies
+
+We already added the dependencies of the MediaPipe framework code which contains
+the iOS API to use a MediaPipe graph. To use a MediaPipe graph, we need to add a
+dependency on the graph we intend to use in our application. Add the following
+line to the `data` list in your `BUILD` file:
+
+```
+"//mediapipe/graphs/edge_detection:mobile_gpu_binary_graph",
+```
+
+Now add the dependency on the calculators used in this graph in the `deps` field
+in the `BUILD` file:
+
+```
+"//mediapipe/graphs/edge_detection:mobile_calculators",
+```
+
+Finally, rename the file `ViewController.m` to `ViewController.mm` to support
+Objective-C++.
+
+### Use the graph in `ViewController`
+
+Declare a static constant with the name of the graph, the input stream and the
+output stream:
+
+```
+static NSString* const kGraphName = @"mobile_gpu";
+
+static const char* kInputStream = "input_video";
+static const char* kOutputStream = "output_video";
+```
+
+Add the following property to the interface of the `ViewController`:
+
+```
+// The MediaPipe graph currently in use. Initialized in viewDidLoad, started in viewWillAppear: and
+// sent video frames on _videoQueue.
+@property(nonatomic) MediaPipeGraph* mediapipeGraph;
+```
+
+As explained in the comment above, we will initialize this graph in
+`viewDidLoad` first. To do so, we need to load the graph from the `.pbtxt` file
+using the following function:
+
+```
++ (MediaPipeGraph*)loadGraphFromResource:(NSString*)resource {
+  // Load the graph config resource.
+  NSError* configLoadError = nil;
+  NSBundle* bundle = [NSBundle bundleForClass:[self class]];
+  if (!resource || resource.length == 0) {
+    return nil;
+  }
+  NSURL* graphURL = [bundle URLForResource:resource withExtension:@"binarypb"];
+  NSData* data = [NSData dataWithContentsOfURL:graphURL options:0 error:&configLoadError];
+  if (!data) {
+    NSLog(@"Failed to load MediaPipe graph config: %@", configLoadError);
+    return nil;
+  }
+
+  // Parse the graph config resource into mediapipe::CalculatorGraphConfig proto object.
+  mediapipe::CalculatorGraphConfig config;
+  config.ParseFromArray(data.bytes, data.length);
+
+  // Create MediaPipe graph with mediapipe::CalculatorGraphConfig proto object.
+  MediaPipeGraph* newGraph = [[MediaPipeGraph alloc] initWithGraphConfig:config];
+  [newGraph addFrameOutputStream:kOutputStream outputPacketType:MediaPipePacketPixelBuffer];
+  return newGraph;
+}
+```
+
+Use this function to initialize the graph in `viewDidLoad` as follows:
+
+```
+self.mediapipeGraph = [[self class] loadGraphFromResource:kGraphName];
+```
+
+The graph should send the results of processing camera frames back to the
+`ViewController`. Add the following line after initializing the graph to set the
+`ViewController` as a delegate of the `mediapipeGraph` object:
+
+```
+self.mediapipeGraph.delegate = self;
+```
+
+To avoid memory contention while processing frames from the live video feed, add
+the following line:
+
+```
+// Set maxFramesInFlight to a small value to avoid memory contention for real-time processing.
+self.mediapipeGraph.maxFramesInFlight = 2;
+```
+
+Now, start the graph when the user has granted the permission to use the camera
+in our app:
+
+```
+[_cameraSource requestCameraAccessWithCompletionHandler:^void(BOOL granted) {
+  if (granted) {
+    // Start running self.mediapipeGraph.
+    NSError* error;
+    if (![self.mediapipeGraph startWithError:&error]) {
+      NSLog(@"Failed to start graph: %@", error);
+    }
+
+    dispatch_async(_videoQueue, ^{
+      [_cameraSource start];
+    });
+  }
+}];
+```
+
+Note: It is important to start the graph before starting the camera, so that
+the graph is ready to process frames as soon as the camera starts sending them.
+
+Earlier, when we received frames from the camera in the `processVideoFrame`
+function, we displayed them in the `_liveView` using the `_renderer`. Now, we
+need to send those frames to the graph and render the results instead. Modify
+this function's implementation to do the following:
+
+```
+- (void)processVideoFrame:(CVPixelBufferRef)imageBuffer
+                timestamp:(CMTime)timestamp
+               fromSource:(MediaPipeInputSource*)source {
+  if (source != _cameraSource) {
+    NSLog(@"Unknown source: %@", source);
+    return;
+  }
+  [self.mediapipeGraph sendPixelBuffer:imageBuffer
+                            intoStream:kInputStream
+                            packetType:MediaPipePacketPixelBuffer];
+}
+```
+
+We send the `imageBuffer` to `self.mediapipeGraph` as a packet of type
+`MediaPipePacketPixelBuffer` into the input stream `kInputStream`, i.e.
+"input_video".
+
+The graph will run with this input packet and output a result in
+`kOutputStream`, i.e. "output_video". We can implement the following delegate
+method to receive packets on this output stream and display them on the screen:
+
+```
+- (void)mediapipeGraph:(MediaPipeGraph*)graph
+    didOutputPixelBuffer:(CVPixelBufferRef)pixelBuffer
+              fromStream:(const std::string&)streamName {
+  if (streamName == kOutputStream) {
+    // Display the captured image on the screen.
+    CVPixelBufferRetain(pixelBuffer);
+    dispatch_async(dispatch_get_main_queue(), ^{
+      [_renderer renderPixelBuffer:pixelBuffer];
+      CVPixelBufferRelease(pixelBuffer);
+    });
+  }
+}
+```
+
+And that is all! Build and run the app on your iOS device. You should see the
+results of running the edge detection graph on a live video feed. Congrats!
+
+![edge_detection_ios_gpu_gif](images/mobile/edge_detection_ios_gpu.gif){width="300"}
+
+If you ran into any issues, please see the full code of the tutorial
+[here](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/edgedetectiongpu).
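+
+One detail the steps above leave implicit: for the delegate assignments to
+compile cleanly, `ViewController` also needs to declare that it conforms to the
+relevant delegate protocols. The exact protocol names are not shown in this
+tutorial; assuming they follow the naming used here (`MediaPipeGraphDelegate`
+for the graph, plus the `MediaPipeInputSourceDelegate` mentioned earlier for the
+camera source), a minimal sketch of the class extension in `ViewController.mm`
+could look like:
+
+```
+// Sketch only: the protocol names below are assumptions based on the class
+// names used in this tutorial; check the MediaPipe Objective-C headers for the
+// exact names in your checkout.
+@interface ViewController () <MediaPipeGraphDelegate, MediaPipeInputSourceDelegate>
+@end
+```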
+ +[Bazel]:https://bazel.build/ +[`edge_detection_mobile_gpu.pbtxt`]:https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_mobile_gpu.pbtxt +[MediaPipe installation guide]:./install.md diff --git a/mediapipe/docs/how_to_questions.md b/mediapipe/docs/how_to_questions.md index ddc6c80c6..076e5ab69 100644 --- a/mediapipe/docs/how_to_questions.md +++ b/mediapipe/docs/how_to_questions.md @@ -19,7 +19,7 @@ section [GpuBuffer to ImageFrame converters](./gpu.md). You can see an example in: - * [`object_detection_android_cpu.pbtxt`] +* [`object_detection_mobile_cpu.pbtxt`] ### How to visualize perception results @@ -29,7 +29,7 @@ the recognized objects. The results can be displayed in a diagnostic window when running on a workstation, or in a texture frame when running on device. You can see an example use of [`AnnotationOverlayCalculator`] in: - * [`face_detection_android_gpu.pbtxt`]. +* [`face_detection_mobile_gpu.pbtxt`]. ### How to run calculators in parallel @@ -106,7 +106,7 @@ continues as long as necessary. For online processing, it is often necessary to drop input packets in order to keep pace with the arrival of input data frames. When inputs arrive too frequently, the recommended technique for dropping packets is to use the MediaPipe calculators designed specifically for this -purpose such as [`RealTimeFlowLimiterCalculator`] and [`PacketClonerCalculator`]. +purpose such as [`FlowLimiterCalculator`] and [`PacketClonerCalculator`]. For online processing, it is also necessary to promptly determine when processing can proceed. MediaPipe supports this by propagating timestamp bounds between @@ -124,20 +124,19 @@ MacOS, Android, and iOS. The core of MediaPipe framework is a C++ library conforming to the C++11 standard, so it is relatively easy to port to additional platforms. 
-[`object_detection_android_cpu.pbtxt`]: https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_android_cpu.pbtxt - +[`object_detection_mobile_cpu.pbtxt`]: https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt [`ImageFrame`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/formats/image_frame.h [`GpuBuffer`]: https://github.com/google/mediapipe/tree/master/mediapipe/gpu/gpu_buffer.h [`GpuBufferToImageFrameCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/gpu/gpu_buffer_to_image_frame_calculator.cc [`ImageFrameToGpuBufferCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc [`AnnotationOverlayCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/calculators/util/annotation_overlay_calculator.cc -[`face_detection_android_gpu.pbtxt`]: https://github.com/google/mediapipe/tree/master/mediapipe/graphs/face_detection/face_detection_android_gpu.pbtxt +[`face_detection_mobile_gpu.pbtxt`]: https://github.com/google/mediapipe/tree/master/mediapipe/graphs/face_detection/face_detection_mobile_gpu.pbtxt [`CalculatorBase::Process`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator_base.h [`max_in_flight`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto [`RoundRobinDemuxCalculator`]: https://github.com/google/mediapipe/tree/master//mediapipe/calculators/core/round_robin_demux_calculator.cc [`ScaleImageCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/calculators/image/scale_image_calculator.cc [`ImmediateInputStreamHandler`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/stream_handler/immediate_input_stream_handler.cc [`CalculatorGraphConfig`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto -[`RealTimeFlowLimiterCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc +[`FlowLimiterCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core/flow_limiter_calculator.cc [`PacketClonerCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core/packet_cloner_calculator.cc [`MakePairCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core/make_pair_calculator.cc diff --git a/mediapipe/docs/images/add_ipa.png b/mediapipe/docs/images/add_ipa.png new file mode 100644 index 000000000..6fb793487 Binary files /dev/null and b/mediapipe/docs/images/add_ipa.png differ diff --git a/mediapipe/docs/images/app_ipa.png b/mediapipe/docs/images/app_ipa.png new file mode 100644 index 000000000..ebbe0ec87 Binary files /dev/null and b/mediapipe/docs/images/app_ipa.png differ diff --git a/mediapipe/docs/images/app_ipa_added.png b/mediapipe/docs/images/app_ipa_added.png new file mode 100644 index 000000000..e6b1efd1b Binary files /dev/null and b/mediapipe/docs/images/app_ipa_added.png differ diff --git a/mediapipe/docs/images/bazel_permission.png b/mediapipe/docs/images/bazel_permission.png new file mode 100644 index 000000000..e67dd72dc Binary files /dev/null and b/mediapipe/docs/images/bazel_permission.png differ diff --git a/mediapipe/docs/images/click_subgraph_handdetection.png b/mediapipe/docs/images/click_subgraph_handdetection.png new file mode 100644 index 000000000..32cf3a1da Binary files 
/dev/null and b/mediapipe/docs/images/click_subgraph_handdetection.png differ diff --git a/mediapipe/docs/images/device.png b/mediapipe/docs/images/device.png new file mode 100644 index 000000000..d911a24c2 Binary files /dev/null and b/mediapipe/docs/images/device.png differ diff --git a/mediapipe/docs/images/maingraph_visualizer.png b/mediapipe/docs/images/maingraph_visualizer.png new file mode 100644 index 000000000..d34865c41 Binary files /dev/null and b/mediapipe/docs/images/maingraph_visualizer.png differ diff --git a/mediapipe/docs/images/mobile/edge_detection_ios_gpu.gif b/mediapipe/docs/images/mobile/edge_detection_ios_gpu.gif new file mode 100644 index 000000000..7417fcbc3 Binary files /dev/null and b/mediapipe/docs/images/mobile/edge_detection_ios_gpu.gif differ diff --git a/mediapipe/docs/images/mobile/edge_detection_android_graph_gpu.png b/mediapipe/docs/images/mobile/edge_detection_mobile_graph_gpu.png similarity index 100% rename from mediapipe/docs/images/mobile/edge_detection_android_graph_gpu.png rename to mediapipe/docs/images/mobile/edge_detection_mobile_graph_gpu.png diff --git a/mediapipe/docs/images/mobile/face_detection_android_gpu.png b/mediapipe/docs/images/mobile/face_detection_android_gpu.png deleted file mode 100644 index e73f1b52a..000000000 Binary files a/mediapipe/docs/images/mobile/face_detection_android_gpu.png and /dev/null differ diff --git a/mediapipe/docs/images/mobile/face_detection_mobile_gpu.png b/mediapipe/docs/images/mobile/face_detection_mobile_gpu.png new file mode 100644 index 000000000..452b1a17f Binary files /dev/null and b/mediapipe/docs/images/mobile/face_detection_mobile_gpu.png differ diff --git a/mediapipe/docs/images/mobile/hair_segmentation_android_gpu.png b/mediapipe/docs/images/mobile/hair_segmentation_android_gpu.png deleted file mode 100644 index 461b0e3a9..000000000 Binary files a/mediapipe/docs/images/mobile/hair_segmentation_android_gpu.png and /dev/null differ diff --git a/mediapipe/docs/images/mobile/hair_segmentation_mobile_gpu.png b/mediapipe/docs/images/mobile/hair_segmentation_mobile_gpu.png new file mode 100644 index 000000000..465046816 Binary files /dev/null and b/mediapipe/docs/images/mobile/hair_segmentation_mobile_gpu.png differ diff --git a/mediapipe/docs/images/mobile/hand_detection_android_gpu.gif b/mediapipe/docs/images/mobile/hand_detection_android_gpu.gif new file mode 100644 index 000000000..86f6f91f8 Binary files /dev/null and b/mediapipe/docs/images/mobile/hand_detection_android_gpu.gif differ diff --git a/mediapipe/docs/images/mobile/hand_detection_gpu_subgraph.png b/mediapipe/docs/images/mobile/hand_detection_gpu_subgraph.png new file mode 100644 index 000000000..c3fbc2ee0 Binary files /dev/null and b/mediapipe/docs/images/mobile/hand_detection_gpu_subgraph.png differ diff --git a/mediapipe/docs/images/mobile/hand_detection_mobile.png b/mediapipe/docs/images/mobile/hand_detection_mobile.png new file mode 100644 index 000000000..a0a763285 Binary files /dev/null and b/mediapipe/docs/images/mobile/hand_detection_mobile.png differ diff --git a/mediapipe/docs/images/mobile/hand_landmark_gpu_subgraph.png b/mediapipe/docs/images/mobile/hand_landmark_gpu_subgraph.png new file mode 100644 index 000000000..e40a9169b Binary files /dev/null and b/mediapipe/docs/images/mobile/hand_landmark_gpu_subgraph.png differ diff --git a/mediapipe/docs/images/mobile/hand_renderer_gpu_subgraph.png b/mediapipe/docs/images/mobile/hand_renderer_gpu_subgraph.png new file mode 100644 index 000000000..7fd2f5589 Binary files /dev/null and 
b/mediapipe/docs/images/mobile/hand_renderer_gpu_subgraph.png differ diff --git a/mediapipe/docs/images/mobile/hand_tracking_android_gpu.gif b/mediapipe/docs/images/mobile/hand_tracking_android_gpu.gif new file mode 100644 index 000000000..675f15121 Binary files /dev/null and b/mediapipe/docs/images/mobile/hand_tracking_android_gpu.gif differ diff --git a/mediapipe/docs/images/mobile/hand_tracking_mobile.png b/mediapipe/docs/images/mobile/hand_tracking_mobile.png new file mode 100644 index 000000000..83850f507 Binary files /dev/null and b/mediapipe/docs/images/mobile/hand_tracking_mobile.png differ diff --git a/mediapipe/docs/images/mobile/object_detection_android_cpu.png b/mediapipe/docs/images/mobile/object_detection_android_cpu.png deleted file mode 100644 index 2efcdd9b1..000000000 Binary files a/mediapipe/docs/images/mobile/object_detection_android_cpu.png and /dev/null differ diff --git a/mediapipe/docs/images/mobile/object_detection_android_gpu.png b/mediapipe/docs/images/mobile/object_detection_android_gpu.png deleted file mode 100644 index 603d82dba..000000000 Binary files a/mediapipe/docs/images/mobile/object_detection_android_gpu.png and /dev/null differ diff --git a/mediapipe/docs/images/mobile/object_detection_desktop_tensorflow.png b/mediapipe/docs/images/mobile/object_detection_desktop_tensorflow.png new file mode 100644 index 000000000..50d7597f1 Binary files /dev/null and b/mediapipe/docs/images/mobile/object_detection_desktop_tensorflow.png differ diff --git a/mediapipe/docs/images/mobile/object_detection_desktop_tflite.png b/mediapipe/docs/images/mobile/object_detection_desktop_tflite.png new file mode 100644 index 000000000..b66ff2c09 Binary files /dev/null and b/mediapipe/docs/images/mobile/object_detection_desktop_tflite.png differ diff --git a/mediapipe/docs/images/mobile/object_detection_mobile_cpu.png b/mediapipe/docs/images/mobile/object_detection_mobile_cpu.png new file mode 100644 index 000000000..48d7fb88e Binary files /dev/null and b/mediapipe/docs/images/mobile/object_detection_mobile_cpu.png differ diff --git a/mediapipe/docs/images/mobile/object_detection_mobile_gpu.png b/mediapipe/docs/images/mobile/object_detection_mobile_gpu.png new file mode 100644 index 000000000..3f9ee6926 Binary files /dev/null and b/mediapipe/docs/images/mobile/object_detection_mobile_gpu.png differ diff --git a/mediapipe/docs/images/mobile/renderer_gpu.png b/mediapipe/docs/images/mobile/renderer_gpu.png new file mode 100644 index 000000000..9b062b9b1 Binary files /dev/null and b/mediapipe/docs/images/mobile/renderer_gpu.png differ diff --git a/mediapipe/docs/images/upload_2pbtxt.png b/mediapipe/docs/images/upload_2pbtxt.png new file mode 100644 index 000000000..02a079ae8 Binary files /dev/null and b/mediapipe/docs/images/upload_2pbtxt.png differ diff --git a/mediapipe/docs/images/upload_graph_button.png b/mediapipe/docs/images/upload_graph_button.png new file mode 100644 index 000000000..9cbf31a8e Binary files /dev/null and b/mediapipe/docs/images/upload_graph_button.png differ diff --git a/mediapipe/docs/install.md b/mediapipe/docs/install.md index f91beda1b..cd68f98da 100644 --- a/mediapipe/docs/install.md +++ b/mediapipe/docs/install.md @@ -2,15 +2,25 @@ Choose your operating system: +- [Prework](#prework) - [Dependences](#dependences) - [Installing on Debian and Ubuntu](#installing-on-debian-and-ubuntu) - [Installing on CentOS](#installing-on-centos) - [Installing on macOS](#installing-on-macos) - [Installing on Windows Subsystem for Linux 
(WSL)](#installing-on-windows-subsystem-for-linux-wsl) - [Installing using Docker](#installing-using-docker) +- [Setting up Android Studio with MediaPipe](#setting-up-android-studio-with-mediapipe) - [Setting up Android SDK and NDK](#setting-up-android-sdk-and-ndk) -### Dependences +### Prework + +* Install a package manager, e.g., Homebrew for macOS, and APT for Debian and Ubuntu + +* Install Xcode for the iOS apps (macOS only) + +* Install Android Studio for the Android apps + +### Dependencies Required libraries @@ -56,8 +66,8 @@ Required libraries Option 1. Use package manager tool to install the pre-compiled OpenCV libraries. - Note: Debian 9 and Ubuntu 16.04 provide OpenCV 2.4.9. You may want to - take option 2 or 3 to install OpenCV 3 or above. + Note: Debian 9 and Ubuntu 16.04 provide OpenCV 2.4.9. You may want to take + option 2 or 3 to install OpenCV 3 or above. ```bash $ sudo apt-get install libopencv-core-dev libopencv-highgui-dev \ @@ -71,11 +81,11 @@ Required libraries [documentation](https://docs.opencv.org/3.4.6/d7/d9f/tutorial_linux_install.html) to manually build OpenCV from source code. - Note: You may need to modify [`WORKSAPCE`] and [`opencv_linux.BUILD`] to point - MediaPipe to your own OpenCV libraries, e.g., if OpenCV 4 is - installed in "/usr/local/", you need to update the "linux_opencv" - new_local_repository rule in [`WORKSAPCE`] and "opencv" cc_library rule in - [`opencv_linux.BUILD`] like the following: + Note: You may need to modify [`WORKSAPCE`] and [`opencv_linux.BUILD`] to + point MediaPipe to your own OpenCV libraries, e.g., if OpenCV 4 is installed + in "/usr/local/", you need to update the "linux_opencv" new_local_repository + rule in [`WORKSAPCE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`] + like the following: ```bash new_local_repository( @@ -110,7 +120,7 @@ Required libraries ```bash $ export GLOG_logtostderr=1 # Need bazel flag 'MEDIAPIPE_DISABLE_GPU=1' as desktop GPU is currently not supported - $ bazel run --define 'MEDIAPIPE_DISABLE_GPU=1' \ + $ bazel run --define MEDIAPIPE_DISABLE_GPU=1 \ mediapipe/examples/desktop/hello_world:hello_world # Should print: @@ -156,11 +166,11 @@ Required libraries Option 2. Build OpenCV from source code. 
- Note: You may need to modify [`WORKSAPCE`] and [`opencv_linux.BUILD`] to point - MediaPipe to your own OpenCV libraries, e.g., if OpenCV 4 is - installed in "/usr/local/", you need to update the "linux_opencv" - new_local_repository rule in [`WORKSAPCE`] and "opencv" cc_library rule in - [`opencv_linux.BUILD`] like the following: + Note: You may need to modify [`WORKSAPCE`] and [`opencv_linux.BUILD`] to + point MediaPipe to your own OpenCV libraries, e.g., if OpenCV 4 is installed + in "/usr/local/", you need to update the "linux_opencv" new_local_repository + rule in [`WORKSAPCE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`] + like the following: ```bash new_local_repository( @@ -195,7 +205,7 @@ Required libraries ```bash $ export GLOG_logtostderr=1 # Need bazel flag 'MEDIAPIPE_DISABLE_GPU=1' as desktop GPU is currently not supported - $ bazel run --define 'MEDIAPIPE_DISABLE_GPU=1' \ + $ bazel run --define MEDIAPIPE_DISABLE_GPU=1 \ mediapipe/examples/desktop/hello_world:hello_world # Should print: @@ -284,7 +294,7 @@ Required libraries ```bash $ export GLOG_logtostderr=1 # Need bazel flag 'MEDIAPIPE_DISABLE_GPU=1' as desktop GPU is currently not supported - $ bazel run --define 'MEDIAPIPE_DISABLE_GPU=1' \ + $ bazel run --define MEDIAPIPE_DISABLE_GPU=1 \ mediapipe/examples/desktop/hello_world:hello_world # Should print: @@ -403,7 +413,7 @@ Required libraries username@DESKTOP-TMVLBJ1:~/mediapipe$ export GLOG_logtostderr=1 # Need bazel flag 'MEDIAPIPE_DISABLE_GPU=1' as desktop GPU is currently not supported - username@DESKTOP-TMVLBJ1:~/mediapipe$ bazel run --define 'MEDIAPIPE_DISABLE_GPU=1' \ + username@DESKTOP-TMVLBJ1:~/mediapipe$ bazel run --define MEDIAPIPE_DISABLE_GPU=1 \ mediapipe/examples/desktop/hello_world:hello_world # Should print: @@ -454,7 +464,7 @@ This will use a Docker image that will isolate mediapipe's installation from the ```bash $ docker run -it --name mediapipe mediapipe:latest - root@bca08b91ff63:/mediapipe# GLOG_logtostderr=1 bazel run --define 'MEDIAPIPE_DISABLE_GPU=1' mediapipe/examples/desktop/hello_world:hello_world + root@bca08b91ff63:/mediapipe# GLOG_logtostderr=1 bazel run --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/hello_world:hello_world # Should print: # Hello World! @@ -472,7 +482,7 @@ This will use a Docker image that will isolate mediapipe's installation from the - ### Setting up Android Studio with MediaPipe The steps below use Android Studio to build and install a MediaPipe demo app. @@ -491,13 +500,16 @@ The steps below use Android Studio to build and install a MediaPipe demo app. 2. Select `Configure` | `SDK Manager` | `SDK Platforms` - * verify that an Android SDK is installed - * note the Android SDK Location such as `/usr/local/home/Android/Sdk` + * Verify that Android SDK Platform API Level 28 or 29 is installed + * Note the Android SDK Location such as `/usr/local/home/Android/Sdk` 3. Select `Configure` | `SDK Manager` | `SDK Tools` - * verify that an Android NDK is installed - * note the Android NDK Location such as `/usr/local/home/Android/Sdk/ndk-bundle` + * Verify that Android SDK Build-Tools 28 or 29 is installed + * Verify that Android SDK Platform-Tools 28 or 29 is installed + * Verify that Android SDK Tools 26.1.1 is installed + * Verify that Android NDK 17c or above is installed + * Note the Android NDK Location such as `/usr/local/home/Android/Sdk/ndk-bundle` 4. Set environment variables `$ANDROID_HOME` and `$ANDROID_NDK_HOME` to point to the installed SDK and NDK. 
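
   For example (a sketch only, using the sample locations noted in the SDK
   Manager steps above; substitute the paths reported on your own machine):

   ```bash
   export ANDROID_HOME=/usr/local/home/Android/Sdk
   export ANDROID_NDK_HOME=/usr/local/home/Android/Sdk/ndk-bundle
   ```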
@@ -511,18 +523,18 @@ The steps below use Android Studio to build and install a MediaPipe demo app. 6. Select `Import Bazel Project` - * select `Workspace`: `/path/to/mediapipe` - * select `Generate from BUILD file`: `/path/to/mediapipe/BUILD` - * select `Finish` + * Select `Workspace`: `/path/to/mediapipe` + * Select `Generate from BUILD file`: `/path/to/mediapipe/BUILD` + * Select `Finish` 7. Connect an android device to the workstation. 8. Select `Run...` | `Edit Configurations...` - * enter Target Expression: + * Enter Target Expression: `//mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu` - * enter Bazel command: `mobile-install` - * enter Bazel flags: `-c opt --config=android_arm64` select `Run` + * Enter Bazel command: `mobile-install` + * Enter Bazel flags: `-c opt --config=android_arm64` select `Run` ### Setting up Android SDK and NDK @@ -534,8 +546,18 @@ export ANDROID_HOME= export ANDROID_NDK_HOME= ``` -Otherwise, please run [`setup_android_sdk_and_ndk.sh`] to download and setup -Android SDK and NDK for MediaPipe before building any Android demos. +Please verify all the necessary packages are installed + +* Android SDK Platform API Level 28 or 29 +* Android SDK Build-Tools 28 or 29 +* Android SDK Platform-Tools 28 or 29 +* Android SDK Tools 26.1.1 +* Android NDK 17c or above + +MediaPipe prefers to use the Android SDK and NDK from Android Studio. See +[the previous section](#setting-up-android-studio-with-mediapipe) for the +Android Studio setup. If you prefer to try MediaPipe without Android Studio, please run [`setup_android_sdk_and_ndk.sh`] to download and setup Android SDK and NDK for +MediaPipe before building any Android demos. [`WORKSAPCE`]: https://github.com/google/mediapipe/tree/master/WORKSPACE [`opencv_linux.BUILD`]: https://github.com/google/mediapipe/tree/master/third_party/opencv_linux.BUILD diff --git a/mediapipe/docs/media_sequence.md b/mediapipe/docs/media_sequence.md index adcd473f9..1a09832d9 100644 --- a/mediapipe/docs/media_sequence.md +++ b/mediapipe/docs/media_sequence.md @@ -13,7 +13,8 @@ tasks like video object detection, but very difficult to encode in TensorFlow.Examples. The goal of MediaSequence is to simplify working with SequenceExamples and to automate common preparation tasks. Much more information is available about the MediaSequence pipeline, including how to use it to -process new data sets, in the [documentation](https://github.com/google/mediapipe/tree/master/mediapipe/util/sequence/README.md). +process new data sets, in the documentation of +[MediaSequence](https://github.com/google/mediapipe/tree/master/mediapipe/util/sequence). ### Preparing an example data set @@ -27,20 +28,20 @@ process new data sets, in the [documentation](https://github.com/google/mediapip 1. Compile the MediaSequence demo C++ binary ```bash - bazel build -c opt mediapipe/examples/desktop/media_sequence:media_sequence_demo --define 'MEDIAPIPE_DISABLE_GPU=1' + bazel build -c opt mediapipe/examples/desktop/media_sequence:media_sequence_demo --define MEDIAPIPE_DISABLE_GPU=1 ``` MediaSequence uses C++ binaries to improve multimedia processing speed and encourage a strong separation between annotations and the image data or other features. The binary code is very general in that it reads from files into input side packets and writes output side packets to files when - completed, but it also links in all of the calculators for necessary for - the MediaPipe graphs preparing the Charades data set. 
+    completed, but it also links in all of the calculators necessary for the
+    MediaPipe graphs preparing the Charades data set.

1.  Download and prepare the data set through Python

    To run this step, you must have Python 2.7 or 3.5+ installed with the
-    TensorFlow 1.19+ package installed.
+    TensorFlow 1.14+ package installed.

    ```bash
    python -m mediapipe.examples.desktop.media_sequence.demo_dataset \
      --alsologtostderr \
@@ -56,10 +57,11 @@ process new data sets, in the [documentation](https://github.com/google/mediapip
    MediaPipe graphs during processing.

    Running this module
-    1. Downloads videos from the internet.
-    1. For each annotation in a CSV, creates a structured metadata file.
-    1. Runs MediaPipe to extract images as defined by the metadata.
-    1. Stores the results in numbered set of TFRecords files.
+
+    1. Downloads videos from the internet.
+    1. For each annotation in a CSV, creates a structured metadata file.
+    1. Runs MediaPipe to extract images as defined by the metadata.
+    1. Stores the results in numbered set of TFRecords files.

    MediaSequence uses SequenceExamples as the format of both inputs and
    outputs. Annotations are encoded as inputs in a SequenceExample of metadata
@@ -84,12 +86,16 @@ process new data sets, in the [documentation](https://github.com/google/mediapip
      demo_data_path = '/tmp/demo_data/'
      with tf.Graph().as_default():
        d = DemoDataset(demo_data_path)
-        dataset = d.as_dataset("test")
+        dataset = d.as_dataset('test')
        # implement additional processing and batching here
-        output = dataset.make_one_shot_iterator().get_next()
+        dataset_output = dataset.make_one_shot_iterator().get_next()
+        images = dataset_output['images']
+        labels = dataset_output['labels']
        with tf.Session() as sess:
-          output_ = sess.run(output)
+          images_, labels_ = sess.run([images, labels])
+          print('The shape of images_ is %s' % str(images_.shape))
+          print('The shape of labels_ is %s' % str(labels_.shape))
    ```

### Preparing a practical data set
@@ -104,9 +110,9 @@ The Charades data set is large (~150 GB), and will take considerable time to
download and process (4-8 hours).

```bash
-bazel build -c opt mediapipe/examples/desktop/media_sequence:media_sequence_demo --define 'MEDIAPIPE_DISABLE_GPU=1'
+bazel build -c opt mediapipe/examples/desktop/media_sequence:media_sequence_demo --define MEDIAPIPE_DISABLE_GPU=1

-python -m mediapipe.examples.desktop.media_sequence.demo_dataset \
+python -m mediapipe.examples.desktop.media_sequence.charades_dataset \
  --alsologtostderr \
  --path_to_charades_data=/tmp/demo_data/ \
  --path_to_mediapipe_binary=bazel-bin/mediapipe/examples/desktop/media_sequence/media_sequence_demo \
@@ -115,7 +121,7 @@ python -m mediapipe.examples.desktop.media_sequence.demo_dataset \

### Preparing your own data set

The process for preparing your own data set is described in the [MediaSequence
-documentation](https://github.com/google/mediapipe/blob/master/mediapipe/util/sequence/README.md).
+documentation](https://github.com/google/mediapipe/tree/master/mediapipe/util/sequence).
The Python code for Charades can easily be modified to process most
annotations, but the MediaPipe processing warrants further discussion.
MediaSequence uses MediaPipe graphs to extract features related to the metadata or previously @@ -145,7 +151,7 @@ node { output_side_packet: "DATA_PATH:input_video_path" output_side_packet: "RESAMPLER_OPTIONS:packet_resampler_options" options { - [mediapipe.UnpackMediaSequenceCalculatorOptions.ext]: { + [type.googleapis.com/mediapipe.UnpackMediaSequenceCalculatorOptions]: { base_packet_resampler_options { frame_rate: 24.0 base_timestamp: 0 diff --git a/mediapipe/docs/mediapipe_ios_setup.md b/mediapipe/docs/mediapipe_ios_setup.md new file mode 100644 index 000000000..ecddae1d3 --- /dev/null +++ b/mediapipe/docs/mediapipe_ios_setup.md @@ -0,0 +1,48 @@ +## Setting up MediaPipe for iOS + +1. Install [Xcode](https://developer.apple.com/xcode/). + + Follow Apple's instructions to obtain the required developemnt certificates + and provisioning profiles for your iOS device. + +2. Install [Bazel](https://bazel.build/). + + See their [instructions](https://docs.bazel.build/versions/master/install-os-x.html). + We recommend using [Homebrew](https://brew.sh/): + + ```bash + brew tap bazelbuild/tap + brew install bazelbuild/tap/bazel + ``` + +3. Clone the MediaPipe repository. + + ```bash + git clone https://github.com/google/mediapipe.git + ``` + +4. Symlink or copy your provisioning profile to `mediapipe/mediapipe/provisioning_profile.mobileprovision`. + + ```bash + cd mediapipe + ln -s ~/Downloads/MyProvisioningProfile.mobileprovision mediapipe/provisioning_profile.mobileprovision + ``` + +## Building an iOS app from the command line + +1. Build one of the example apps for iOS. We will be using the + [Face Detection GPU App example](./face_detection_mobile_gpu.md) + + ```bash + bazel build --config=ios_arm64 mediapipe/examples/ios/facedetectiongpu:FaceDetectionGpuApp + ``` + + You may see a permission request from `codesign` in order to sign the app. + +2. In Xcode, open the `Devices and Simulators` window (command-shift-2). + +3. Make sure your device is connected. You will see a list of installed apps. + Press the "+" button under the list, and select the `.ipa` file built by + Bazel. + +4. You can now run the app on your device. diff --git a/mediapipe/docs/object_detection_desktop.md b/mediapipe/docs/object_detection_desktop.md index 9222e903b..f69eab16e 100644 --- a/mediapipe/docs/object_detection_desktop.md +++ b/mediapipe/docs/object_detection_desktop.md @@ -23,8 +23,8 @@ To build and run the TensorFlow example on desktop, run: # Note that this command also builds TensorFlow targets from scratch, it may # take a long time (e.g., up to 30 mins) to build for the first time. 
$ bazel build -c opt \ - --define 'MEDIAPIPE_DISABLE_GPU=1' \ - --define 'no_aws_support=true' \ + --define MEDIAPIPE_DISABLE_GPU=1 \ + --define no_aws_support=true \ mediapipe/examples/desktop/object_detection:object_detection_tensorflow # It should print: @@ -189,7 +189,7 @@ node { To build and run the TensorFlow Lite example on desktop, run: ```bash -$ bazel build -c opt --define 'MEDIAPIPE_DISABLE_GPU=1' \ +$ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \ mediapipe/examples/desktop/object_detection:object_detection_tflite # It should print: diff --git a/mediapipe/docs/object_detection_android_cpu.md b/mediapipe/docs/object_detection_mobile_cpu.md similarity index 86% rename from mediapipe/docs/object_detection_android_cpu.md rename to mediapipe/docs/object_detection_mobile_cpu.md index 3d68f8d45..4ee0459c7 100644 --- a/mediapipe/docs/object_detection_android_cpu.md +++ b/mediapipe/docs/object_detection_mobile_cpu.md @@ -1,9 +1,9 @@ -# Object Detection on CPU on Android +# Object Detection (CPU) Please see [Hello World! in MediaPipe on Android](hello_world_android.md) for general instructions to develop an Android application that uses MediaPipe. This doc focuses on the -[example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_android_cpu.pbtxt) +[example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt) that performs object detection with TensorFlow Lite on CPU. This is very similar to the @@ -16,7 +16,7 @@ CPU. ![object_detection_android_cpu_gif](images/mobile/object_detection_android_cpu.gif){width="300"} -## App +## Android The graph is used in the [Object Detection CPU](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu) @@ -32,9 +32,24 @@ To further install the app on android device, run: adb install bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/objectdetectioncpu.apk ``` +## iOS + +Please see [Hello World! in MediaPipe on iOS](hello_world_ios.md) for general +instructions to develop an iOS application that uses MediaPipe. The graph below +is used in the +[Object Detection GPU iOS example app](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/objectdetectioncpu). + +To build the iOS app, please see the general +[MediaPipe iOS app building and setup instructions](./mediapipe_ios_setup.md). +Specifically, run: + +```bash +bazel build -c opt --config=ios_arm64 mediapipe/examples/ios/objectdetectioncpu:ObjectDetectionCpuApp +``` + ## Graph -![object_detection_android_cpu_graph](images/mobile/object_detection_android_cpu.png){width="400"} +![object_detection_mobile_cpu_graph](images/mobile/object_detection_mobile_cpu.png){width="400"} To visualize the graph as shown above, copy the text specification of the graph below and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). @@ -42,7 +57,8 @@ below and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). ```bash # MediaPipe graph that performs object detection with TensorFlow Lite on CPU. # Used in the example in -# mediapipie/examples/android/src/java/com/mediapipe/apps/objectdetectioncpu. +# mediapipie/examples/android/src/java/com/mediapipe/apps/objectdetectioncpu and +# mediapipie/examples/ios/objectdetectioncpu. # Images on GPU coming into and out of the graph. 
input_stream: "input_video" @@ -72,7 +88,7 @@ node: { # TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy # processing previous inputs. node { - calculator: "RealTimeFlowLimiterCalculator" + calculator: "FlowLimiterCalculator" input_stream: "input_video_cpu" input_stream: "FINISHED:detections" input_stream_info: { @@ -99,22 +115,12 @@ node: { } } -# Converts the transformed input image on CPU into an image tensor as a -# TfLiteTensor. The zero_center option is set to true to normalize the -# pixel values to [-1.f, 1.f] as opposed to [0.f, 1.f]. The flip_vertically -# option is set to true to account for the descrepancy between the -# representation of the input image (origin at the bottom-left corner) and what -# the model used in this graph is expecting (origin at the top-left corner). +# Converts the transformed input image on CPU into an image tensor stored as a +# TfLiteTensor. node { calculator: "TfLiteConverterCalculator" input_stream: "IMAGE:transformed_input_video_cpu" output_stream: "TENSORS:image_tensor" - node_options: { - [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] { - zero_center: true - flip_vertically: true - } - } } # Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a @@ -181,7 +187,7 @@ node { y_scale: 10.0 h_scale: 5.0 w_scale: 5.0 - flip_vertically: true + min_score_thresh: 0.6 } } } @@ -194,9 +200,9 @@ node { node_options: { [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] { min_suppression_threshold: 0.4 - min_score_threshold: 0.6 max_num_detections: 3 overlap_type: INTERSECTION_OVER_UNION + return_empty_detections: true } } } @@ -217,7 +223,7 @@ node { # Converts the detections to drawing primitives for annotation overlay. node { calculator: "DetectionsToRenderDataCalculator" - input_stream: "DETECTION_VECTOR:output_detections" + input_stream: "DETECTIONS:output_detections" output_stream: "RENDER_DATA:render_data" node_options: { [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { @@ -229,19 +235,12 @@ node { # Draws annotations and overlays them on top of the CPU copy of the original # image coming into the graph. The calculator assumes that image origin is -# always at the top-left corner and renders text accordingly. However, the input -# image has its origin at the bottom-left corner (OpenGL convention) and the -# flip_text_vertically option is set to true to compensate that. +# always at the top-left corner and renders text accordingly. node { calculator: "AnnotationOverlayCalculator" input_stream: "INPUT_FRAME:throttled_input_video_cpu" input_stream: "render_data" output_stream: "OUTPUT_FRAME:output_video_cpu" - node_options: { - [type.googleapis.com/mediapipe.AnnotationOverlayCalculatorOptions] { - flip_text_vertically: true - } - } } # Transfers the annotated image from CPU back to GPU memory, to be sent out of diff --git a/mediapipe/docs/object_detection_android_gpu.md b/mediapipe/docs/object_detection_mobile_gpu.md similarity index 78% rename from mediapipe/docs/object_detection_android_gpu.md rename to mediapipe/docs/object_detection_mobile_gpu.md index 5759258a5..917f7aa3d 100644 --- a/mediapipe/docs/object_detection_android_gpu.md +++ b/mediapipe/docs/object_detection_mobile_gpu.md @@ -1,16 +1,17 @@ -# Object Detection on GPU on Android +# Object Detection (GPU) -Please see [Hello World! in MediaPipe on Android](hello_world_android.md) for -general instructions to develop an Android application that uses MediaPipe. 
This -doc focuses on the -[example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_android_gpu.pbtxt) +This doc focuses on the +[below example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_mobile_gpu.pbtxt) that performs object detection with TensorFlow Lite on GPU. ![object_detection_android_gpu_gif](images/mobile/object_detection_android_gpu.gif){width="300"} -## App +## Android -The graph is used in the +Please see [Hello World! in MediaPipe on Android](hello_world_android.md) for +general instructions to develop an Android application that uses MediaPipe. + +The graph below is used in the [Object Detection GPU](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu) example app. To build the app, run: @@ -24,9 +25,24 @@ To further install the app on android device, run: adb install bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/objectdetectiongpu.apk ``` +## iOS + +Please see [Hello World! in MediaPipe on iOS](hello_world_ios.md) for general +instructions to develop an iOS application that uses MediaPipe. The graph below +is used in the +[Object Detection GPU iOS example app](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/objectdetectiongpu) + +To build the iOS app, please see the general +[MediaPipe iOS app building and setup instructions](./mediapipe_ios_setup.md). +Specifically, run: + +```bash +bazel build -c opt --config=ios_arm64 mediapipe/examples/ios/objectdetectiongpu:ObjectDetectionGpuApp +``` + ## Graph -![object_detection_android_gpu_graph](images/mobile/object_detection_android_gpu.png){width="400"} +![object_detection_mobile_gpu_graph](images/mobile/object_detection_mobile_gpu.png){width="400"} To visualize the graph as shown above, copy the text specification of the graph below and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). @@ -34,7 +50,8 @@ below and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/). ```bash # MediaPipe graph that performs object detection with TensorFlow Lite on GPU. # Used in the example in -# mediapipie/examples/android/src/java/com/mediapipe/apps/objectdetectiongpu. +# mediapipie/examples/android/src/java/com/mediapipe/apps/objectdetectiongpu and +# mediapipie/examples/ios/objectdetectiongpu. # Images on GPU coming into and out of the graph. input_stream: "input_video" @@ -54,7 +71,7 @@ output_stream: "output_video" # TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy # processing previous inputs. node { - calculator: "RealTimeFlowLimiterCalculator" + calculator: "FlowLimiterCalculator" input_stream: "input_video" input_stream: "FINISHED:detections" input_stream_info: { @@ -81,23 +98,12 @@ node: { } } -# Converts the transformed input image on GPU into an image tensor stored in -# tflite::gpu::GlBuffer. The zero_center option is set to true to normalize the -# pixel values to [-1.f, 1.f] as opposed to [0.f, 1.f]. The flip_vertically -# option is set to true to account for the descrepancy between the -# representation of the input image (origin at the bottom-left corner, the -# OpenGL convention) and what the model used in this graph is expecting (origin -# at the top-left corner). +# Converts the transformed input image on GPU into an image tensor stored as a +# TfLiteTensor. 
node { calculator: "TfLiteConverterCalculator" input_stream: "IMAGE_GPU:transformed_input_video" output_stream: "TENSORS_GPU:image_tensor" - node_options: { - [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] { - zero_center: true - flip_vertically: true - } - } } # Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a @@ -106,7 +112,7 @@ node { node { calculator: "TfLiteInferenceCalculator" input_stream: "TENSORS_GPU:image_tensor" - output_stream: "TENSORS_GPU:detection_tensors" + output_stream: "TENSORS:detection_tensors" node_options: { [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { model_path: "ssdlite_object_detection.tflite" @@ -149,7 +155,7 @@ node { # detections. Each detection describes a detected object. node { calculator: "TfLiteTensorsToDetectionsCalculator" - input_stream: "TENSORS_GPU:detection_tensors" + input_stream: "TENSORS:detection_tensors" input_side_packet: "ANCHORS:anchors" output_stream: "DETECTIONS:detections" node_options: { @@ -164,7 +170,7 @@ node { y_scale: 10.0 h_scale: 5.0 w_scale: 5.0 - flip_vertically: true + min_score_thresh: 0.6 } } } @@ -177,9 +183,9 @@ node { node_options: { [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] { min_suppression_threshold: 0.4 - min_score_threshold: 0.6 max_num_detections: 3 overlap_type: INTERSECTION_OVER_UNION + return_empty_detections: true } } } @@ -200,7 +206,7 @@ node { # Converts the detections to drawing primitives for annotation overlay. node { calculator: "DetectionsToRenderDataCalculator" - input_stream: "DETECTION_VECTOR:output_detections" + input_stream: "DETECTIONS:output_detections" output_stream: "RENDER_DATA:render_data" node_options: { [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { @@ -210,22 +216,13 @@ node { } } -# Draws annotations and overlays them on top of the original image coming into -# the graph. Annotation drawing is performed on CPU, and the result is -# transferred to GPU and overlaid on the input image. The calculator assumes -# that image origin is always at the top-left corner and renders text -# accordingly. However, the input image has its origin at the bottom-left corner -# (OpenGL convention) and the flip_text_vertically option is set to true to -# compensate that. +# Draws annotations and overlays them on top of a GPU copy of the original +# image coming into the graph. The calculator assumes that image origin is +# always at the top-left corner and renders text accordingly. node { calculator: "AnnotationOverlayCalculator" input_stream: "INPUT_FRAME_GPU:throttled_input_video" input_stream: "render_data" output_stream: "OUTPUT_FRAME_GPU:output_video" - node_options: { - [type.googleapis.com/mediapipe.AnnotationOverlayCalculatorOptions] { - flip_text_vertically: true - } - } } ``` diff --git a/mediapipe/docs/scheduling_sync.md b/mediapipe/docs/scheduling_sync.md index 382a9a326..7f77a989c 100644 --- a/mediapipe/docs/scheduling_sync.md +++ b/mediapipe/docs/scheduling_sync.md @@ -143,14 +143,13 @@ system that relaxes configured limits when needed. The second system consists of inserting special nodes which can drop packets according to real-time constraints (typically using custom input policies) -defined by [`RealTimeFlowLimiterCalculator`]. For example, a common pattern -places a flow-control node at the input of a subgraph, with a loopback -connection from the final output to the flow-control node. 
The flow-control node -is thus able to keep track of how many timestamps are being processed in the -downstream graph, and drop packets if this count hits a (configurable) limit; -and since packets are dropped upstream, we avoid the wasted work that would -result from partially processing a timestamp and then dropping packets between -intermediate stages. +defined by [`FlowLimiterCalculator`]. For example, a common pattern places a +flow-control node at the input of a subgraph, with a loopback connection from +the final output to the flow-control node. The flow-control node is thus able to +keep track of how many timestamps are being processed in the downstream graph, +and drop packets if this count hits a (configurable) limit; and since packets +are dropped upstream, we avoid the wasted work that would result from partially +processing a timestamp and then dropping packets between intermediate stages. This calculator-based approach gives the graph author control of where packets can be dropped, and allows flexibility in adapting and customizing the graph’s @@ -161,4 +160,4 @@ behavior depending on resource constraints. [`SyncSetInputStreamHandler`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/stream_handler/sync_set_input_stream_handler.h [`ImmediateInputStreamHandler`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/stream_handler/immediate_input_stream_handler.h [`CalculatorGraphConfig::max_queue_size`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto -[`RealTimeFlowLimiterCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc +[`FlowLimiterCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core/flow_limiter_calculator.cc diff --git a/mediapipe/docs/tracer.md b/mediapipe/docs/tracer.md new file mode 100644 index 000000000..e69de29bb diff --git a/mediapipe/docs/troubleshooting.md b/mediapipe/docs/troubleshooting.md index a3ad61e72..f1069a642 100644 --- a/mediapipe/docs/troubleshooting.md +++ b/mediapipe/docs/troubleshooting.md @@ -128,7 +128,7 @@ If some of the calculators in the graph cannot keep pace with the realtime input streams, then latency will continue to increase, and it becomes necessary to drop some input packets. The recommended technique is to use the MediaPipe calculators designed specifically for this purpose such as -[`RealTimeFlowLimiterCalculator`] as described in +[`FlowLimiterCalculator`] as described in [How to process realtime input streams](how_to_questions.md#how-to-process-realtime-input-streams). 
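The same pattern is easy to picture outside of MediaPipe. The sketch below is plain Python rather than the actual `FlowLimiterCalculator` API, and the class and method names are made up for illustration: a frame is admitted only while the number of timestamps still in flight downstream is below a limit, and the loopback "FINISHED" signal frees a slot, so excess frames are dropped before any downstream work is spent on them.

```python
# Illustrative sketch only -- not the MediaPipe API. Mimics the flow-limiting
# pattern: drop frames upstream while too many timestamps are still in flight.


class FlowLimiter:
    def __init__(self, max_in_flight=1):
        self.max_in_flight = max_in_flight
        self.in_flight = 0

    def offer(self, frame):
        """Returns the frame if it may enter the graph, or None to drop it."""
        if self.in_flight >= self.max_in_flight:
            return None  # Dropped before any downstream work is done.
        self.in_flight += 1
        return frame

    def finished(self):
        """Loopback from the end of the graph: one timestamp completed."""
        self.in_flight = max(0, self.in_flight - 1)


limiter = FlowLimiter(max_in_flight=1)
admitted = []
for t, frame in enumerate(["f0", "f1", "f2", "f3"]):
    if limiter.offer(frame) is not None:
        admitted.append(frame)
    if t == 2:  # Pretend the graph finished processing "f0" at this point.
        limiter.finished()
print(admitted)  # ['f0', 'f3'] -- "f1" and "f2" were dropped upstream.
```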
[`CalculatorGraphConfig`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto @@ -141,4 +141,4 @@ calculators designed specifically for this purpose such as [`CalculatorGraph::WaitUntilDone`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator_graph.h [`Timestamp::Done`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/timestamp.h [`CalculatorBase::Close`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator_base.h -[`RealTimeFlowLimiterCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc +[`FlowLimiterCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core/flow_limiter_calculator.cc diff --git a/mediapipe/docs/visualizer.md b/mediapipe/docs/visualizer.md index e35a0dc7c..f4216b65f 100644 --- a/mediapipe/docs/visualizer.md +++ b/mediapipe/docs/visualizer.md @@ -59,3 +59,24 @@ The visualizer graph shows the connections between calculator nodes. ![Special nodes](./images/special_nodes.png){width="350"} ![Special nodes](./images/special_nodes_code.png){width="350"} + +### Visualizing subgraphs + +The MediaPipe visualizer can display multiple graphs. If a graph has a name (designated by assigning a string to the "type" field in the top level of the graph's proto file) and that name is used as a calculator name in a separate graph, it is considered a subgraph and colored appropriately where it is used. Clicking on a subgraph will navigate to the corresponding tab which holds the subgraph's definition. In this example, for hand detection GPU we have 2 pbtxt files: +[hand_detection_mobile.pbtxt](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_detection_mobile.pbtxt) +and its associated [subgraph](./framework_concepts.md#subgraph) called +[hand_detection_gpu.pbtxt](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_detection_gpu.pbtxt) + +* In the default MediaPipe visualizer, click on upload graph button and select + the 2 pbtxt files to visualize (main graph and all its associated subgraphs) + + ![Upload graph button](./images/upload_button.png){width="250"} + + ![Choose the 2 files](./images/upload_2pbtxt.png){width="400"} + +* You will see 3 tabs. 
The main graph tab is `hand_detection_mobile.pbtxt` + ![hand_detection_mobile_gpu.pbtxt](./images/maingraph_visualizer.png){width="1500"} + +* Click on the subgraph block in purple `Hand Detection` and the + `hand_detection_gpu.pbtxt` tab will open + ![Hand detection subgraph](./images/clicksubgraph_handdetection.png){width="1500"} diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/README.md b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/README.md index 4a5a1cd33..2126761ec 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/README.md +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/README.md @@ -12,6 +12,9 @@ This directory contains MediaPipe Android example applications for different use | Face Detection on GPU | facedetectiongpu | | Object Detection on CPU | objectdetectioncpu | | Object Detection on GPU | objectdetectiongpu | +| Hair Segmentation on GPU | hairsegmentationgpu | +| Hand Detection on GPU | handdetectiongpu | +| Hand Tracking on GPU | handtrackinggpu | For instance, to build an example app for face detection on CPU, run: diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/BUILD index 21ee273a0..94a6101cd 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/BUILD @@ -21,7 +21,7 @@ cc_binary( linkshared = 1, linkstatic = 1, deps = [ - "//mediapipe/graphs/edge_detection:android_calculators", + "//mediapipe/graphs/edge_detection:mobile_calculators", "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", ], ) @@ -37,7 +37,7 @@ cc_library( # MainActivity.BINARY_GRAPH_NAME = "appname.binarypb". 
genrule( name = "binary_graph", - srcs = ["//mediapipe/graphs/edge_detection:android_gpu_binary_graph"], + srcs = ["//mediapipe/graphs/edge_detection:mobile_gpu_binary_graph"], outs = ["edgedetectiongpu.binarypb"], cmd = "cp $< $@", ) @@ -57,8 +57,8 @@ android_library( "//mediapipe/java/com/google/mediapipe/components:android_components", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/glutil", - "//third_party:android_constraint_layout", "//third_party:androidx_appcompat", + "//third_party:androidx_constraint_layout", "//third_party:opencv", "@androidx_concurrent_futures//jar", "@com_google_guava_android//jar", @@ -67,7 +67,6 @@ android_library( android_binary( name = "edgedetectiongpu", - aapt_version = "aapt2", manifest = "AndroidManifest.xml", manifest_values = {"applicationId": "com.google.mediapipe.apps.edgedetectiongpu"}, multidex = "native", diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/MainActivity.java index 7ee302137..a0623d420 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/MainActivity.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/MainActivity.java @@ -38,6 +38,12 @@ public class MainActivity extends AppCompatActivity { private static final String OUTPUT_VIDEO_STREAM_NAME = "output_video"; private static final CameraHelper.CameraFacing CAMERA_FACING = CameraHelper.CameraFacing.BACK; + // Flips the camera-preview frames vertically before sending them into FrameProcessor to be + // processed in a MediaPipe graph, and flips the processed frames back when they are displayed. + // This is needed because OpenGL represents images assuming the image origin is at the bottom-left + // corner, whereas MediaPipe in general assumes the image origin is at top-left. + private static final boolean FLIP_FRAMES_VERTICALLY = true; + static { // Load all native libraries needed by the app. 
System.loadLibrary("mediapipe_jni"); @@ -81,6 +87,7 @@ public class MainActivity extends AppCompatActivity { BINARY_GRAPH_NAME, INPUT_VIDEO_STREAM_NAME, OUTPUT_VIDEO_STREAM_NAME); + processor.getVideoSurfaceOutput().setFlipY(FLIP_FRAMES_VERTICALLY); PermissionHelper.checkAndRequestCameraPermissions(this); } @@ -89,6 +96,7 @@ public class MainActivity extends AppCompatActivity { protected void onResume() { super.onResume(); converter = new ExternalTextureConverter(eglManager.getContext()); + converter.setFlipY(FLIP_FRAMES_VERTICALLY); converter.setConsumer(processor); if (PermissionHelper.cameraPermissionsGranted(this)) { startCamera(); diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/res/layout/activity_main.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/res/layout/activity_main.xml index 22240a2d6..c19d7e628 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/res/layout/activity_main.xml +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/edgedetectiongpu/res/layout/activity_main.xml @@ -1,5 +1,5 @@ - - + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/BUILD index 0c1a5be65..699dc2f6d 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/BUILD @@ -21,7 +21,7 @@ cc_binary( linkshared = 1, linkstatic = 1, deps = [ - "//mediapipe/graphs/face_detection:android_calculators", + "//mediapipe/graphs/face_detection:mobile_calculators", "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", ], ) @@ -37,7 +37,7 @@ cc_library( # MainActivity.BINARY_GRAPH_NAME = "appname.binarypb". 
genrule( name = "binary_graph", - srcs = ["//mediapipe/graphs/face_detection:android_cpu_binary_graph"], + srcs = ["//mediapipe/graphs/face_detection:mobile_cpu_binary_graph"], outs = ["facedetectioncpu.binarypb"], cmd = "cp $< $@", ) @@ -47,8 +47,8 @@ android_library( srcs = glob(["*.java"]), assets = [ ":binary_graph", - "//mediapipe/models:facedetector_front.tflite", - "//mediapipe/models:facedetector_front_labelmap.txt", + "//mediapipe/models:face_detection_front.tflite", + "//mediapipe/models:face_detection_front_labelmap.txt", ], assets_dir = "", manifest = "AndroidManifest.xml", @@ -59,11 +59,12 @@ android_library( "//mediapipe/java/com/google/mediapipe/components:android_components", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/glutil", - "//third_party:android_constraint_layout", "//third_party:androidx_appcompat", + "//third_party:androidx_constraint_layout", + "//third_party:androidx_legacy_support_v4", + "//third_party:androidx_material", + "//third_party:androidx_recyclerview", "//third_party:opencv", - "@androidsdk//com.android.support:recyclerview-v7-25.0.0", - "@androidsdk//com.android.support:support-v4-25.0.0", "@androidx_concurrent_futures//jar", "@androidx_lifecycle//jar", "@com_google_code_findbugs//jar", @@ -73,7 +74,6 @@ android_library( android_binary( name = "facedetectioncpu", - aapt_version = "aapt2", manifest = "AndroidManifest.xml", manifest_values = {"applicationId": "com.google.mediapipe.apps.facedetectioncpu"}, multidex = "native", diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/MainActivity.java index a0dc964f3..15511dcd4 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/MainActivity.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/MainActivity.java @@ -39,6 +39,12 @@ public class MainActivity extends AppCompatActivity { private static final String OUTPUT_VIDEO_STREAM_NAME = "output_video"; private static final CameraHelper.CameraFacing CAMERA_FACING = CameraHelper.CameraFacing.FRONT; + // Flips the camera-preview frames vertically before sending them into FrameProcessor to be + // processed in a MediaPipe graph, and flips the processed frames back when they are displayed. + // This is needed because OpenGL represents images assuming the image origin is at the bottom-left + // corner, whereas MediaPipe in general assumes the image origin is at top-left. + private static final boolean FLIP_FRAMES_VERTICALLY = true; + static { // Load all native libraries needed by the app. 
System.loadLibrary("mediapipe_jni"); @@ -82,6 +88,7 @@ public class MainActivity extends AppCompatActivity { BINARY_GRAPH_NAME, INPUT_VIDEO_STREAM_NAME, OUTPUT_VIDEO_STREAM_NAME); + processor.getVideoSurfaceOutput().setFlipY(FLIP_FRAMES_VERTICALLY); PermissionHelper.checkAndRequestCameraPermissions(this); } @@ -90,6 +97,7 @@ public class MainActivity extends AppCompatActivity { protected void onResume() { super.onResume(); converter = new ExternalTextureConverter(eglManager.getContext()); + converter.setFlipY(FLIP_FRAMES_VERTICALLY); converter.setConsumer(processor); if (PermissionHelper.cameraPermissionsGranted(this)) { startCamera(); diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/res/layout/activity_main.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/res/layout/activity_main.xml index 22240a2d6..c19d7e628 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/res/layout/activity_main.xml +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/res/layout/activity_main.xml @@ -1,5 +1,5 @@ - - + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/BUILD index 7728b3bd7..9f1140fbe 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/BUILD @@ -21,7 +21,7 @@ cc_binary( linkshared = 1, linkstatic = 1, deps = [ - "//mediapipe/graphs/face_detection:android_calculators", + "//mediapipe/graphs/face_detection:mobile_calculators", "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", ], ) @@ -37,7 +37,7 @@ cc_library( # MainActivity.BINARY_GRAPH_NAME = "appname.binarypb". 
genrule( name = "binary_graph", - srcs = ["//mediapipe/graphs/face_detection:android_gpu_binary_graph"], + srcs = ["//mediapipe/graphs/face_detection:mobile_gpu_binary_graph"], outs = ["facedetectiongpu.binarypb"], cmd = "cp $< $@", ) @@ -47,8 +47,8 @@ android_library( srcs = glob(["*.java"]), assets = [ ":binary_graph", - "//mediapipe/models:facedetector_front.tflite", - "//mediapipe/models:facedetector_front_labelmap.txt", + "//mediapipe/models:face_detection_front.tflite", + "//mediapipe/models:face_detection_front_labelmap.txt", ], assets_dir = "", manifest = "AndroidManifest.xml", @@ -59,11 +59,12 @@ android_library( "//mediapipe/java/com/google/mediapipe/components:android_components", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/glutil", - "//third_party:android_constraint_layout", "//third_party:androidx_appcompat", + "//third_party:androidx_constraint_layout", + "//third_party:androidx_legacy_support_v4", + "//third_party:androidx_material", + "//third_party:androidx_recyclerview", "//third_party:opencv", - "@androidsdk//com.android.support:recyclerview-v7-25.0.0", - "@androidsdk//com.android.support:support-v4-25.0.0", "@androidx_concurrent_futures//jar", "@androidx_lifecycle//jar", "@com_google_code_findbugs//jar", @@ -73,7 +74,6 @@ android_library( android_binary( name = "facedetectiongpu", - aapt_version = "aapt2", manifest = "AndroidManifest.xml", manifest_values = {"applicationId": "com.google.mediapipe.apps.facedetectiongpu"}, multidex = "native", diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/MainActivity.java index d232992fb..9ee86d662 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/MainActivity.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/MainActivity.java @@ -39,6 +39,12 @@ public class MainActivity extends AppCompatActivity { private static final String OUTPUT_VIDEO_STREAM_NAME = "output_video"; private static final CameraHelper.CameraFacing CAMERA_FACING = CameraHelper.CameraFacing.FRONT; + // Flips the camera-preview frames vertically before sending them into FrameProcessor to be + // processed in a MediaPipe graph, and flips the processed frames back when they are displayed. + // This is needed because OpenGL represents images assuming the image origin is at the bottom-left + // corner, whereas MediaPipe in general assumes the image origin is at top-left. + private static final boolean FLIP_FRAMES_VERTICALLY = true; + static { // Load all native libraries needed by the app. 
System.loadLibrary("mediapipe_jni"); @@ -82,6 +88,7 @@ public class MainActivity extends AppCompatActivity { BINARY_GRAPH_NAME, INPUT_VIDEO_STREAM_NAME, OUTPUT_VIDEO_STREAM_NAME); + processor.getVideoSurfaceOutput().setFlipY(FLIP_FRAMES_VERTICALLY); PermissionHelper.checkAndRequestCameraPermissions(this); } @@ -90,6 +97,7 @@ public class MainActivity extends AppCompatActivity { protected void onResume() { super.onResume(); converter = new ExternalTextureConverter(eglManager.getContext()); + converter.setFlipY(FLIP_FRAMES_VERTICALLY); converter.setConsumer(processor); if (PermissionHelper.cameraPermissionsGranted(this)) { startCamera(); diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/res/layout/activity_main.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/res/layout/activity_main.xml index 22240a2d6..c19d7e628 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/res/layout/activity_main.xml +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/res/layout/activity_main.xml @@ -1,5 +1,5 @@ - - + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/BUILD index 071ccf986..6ec734e03 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/BUILD @@ -21,7 +21,7 @@ cc_binary( linkshared = 1, linkstatic = 1, deps = [ - "//mediapipe/graphs/hair_segmentation:android_calculators", + "//mediapipe/graphs/hair_segmentation:mobile_calculators", "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", ], ) @@ -37,7 +37,7 @@ cc_library( # MainActivity.BINARY_GRAPH_NAME = "appname.binarypb". 
genrule( name = "binary_graph", - srcs = ["//mediapipe/graphs/hair_segmentation:android_gpu_binary_graph"], + srcs = ["//mediapipe/graphs/hair_segmentation:mobile_gpu_binary_graph"], outs = ["hairsegmentationgpu.binarypb"], cmd = "cp $< $@", ) @@ -58,11 +58,12 @@ android_library( "//mediapipe/java/com/google/mediapipe/components:android_components", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/glutil", - "//third_party:android_constraint_layout", "//third_party:androidx_appcompat", + "//third_party:androidx_constraint_layout", + "//third_party:androidx_legacy_support_v4", + "//third_party:androidx_material", + "//third_party:androidx_recyclerview", "//third_party:opencv", - "@androidsdk//com.android.support:recyclerview-v7-25.0.0", - "@androidsdk//com.android.support:support-v4-25.0.0", "@androidx_concurrent_futures//jar", "@androidx_lifecycle//jar", "@com_google_code_findbugs//jar", @@ -72,7 +73,6 @@ android_library( android_binary( name = "hairsegmentationgpu", - aapt_version = "aapt2", manifest = "AndroidManifest.xml", manifest_values = {"applicationId": "com.google.mediapipe.apps.hairsegmentationgpu"}, multidex = "native", diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/MainActivity.java index c33311ffb..a7c80be4e 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/MainActivity.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/MainActivity.java @@ -39,6 +39,12 @@ public class MainActivity extends AppCompatActivity { private static final String OUTPUT_VIDEO_STREAM_NAME = "output_video"; private static final CameraHelper.CameraFacing CAMERA_FACING = CameraHelper.CameraFacing.FRONT; + // Flips the camera-preview frames vertically before sending them into FrameProcessor to be + // processed in a MediaPipe graph, and flips the processed frames back when they are displayed. + // This is needed because OpenGL represents images assuming the image origin is at the bottom-left + // corner, whereas MediaPipe in general assumes the image origin is at top-left. + private static final boolean FLIP_FRAMES_VERTICALLY = true; + static { // Load all native libraries needed by the app. 
System.loadLibrary("mediapipe_jni"); @@ -82,6 +88,7 @@ public class MainActivity extends AppCompatActivity { BINARY_GRAPH_NAME, INPUT_VIDEO_STREAM_NAME, OUTPUT_VIDEO_STREAM_NAME); + processor.getVideoSurfaceOutput().setFlipY(FLIP_FRAMES_VERTICALLY); PermissionHelper.checkAndRequestCameraPermissions(this); } @@ -90,6 +97,7 @@ public class MainActivity extends AppCompatActivity { protected void onResume() { super.onResume(); converter = new ExternalTextureConverter(eglManager.getContext()); + converter.setFlipY(FLIP_FRAMES_VERTICALLY); converter.setConsumer(processor); if (PermissionHelper.cameraPermissionsGranted(this)) { startCamera(); diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/res/layout/activity_main.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/res/layout/activity_main.xml index 22240a2d6..c19d7e628 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/res/layout/activity_main.xml +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/hairsegmentationgpu/res/layout/activity_main.xml @@ -1,5 +1,5 @@ - - + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/BUILD index 3cab05ff2..07757c45d 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/BUILD @@ -21,7 +21,7 @@ cc_binary( linkshared = 1, linkstatic = 1, deps = [ - "//mediapipe/graphs/object_detection:android_calculators", + "//mediapipe/graphs/object_detection:mobile_calculators", "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", ], ) @@ -37,7 +37,7 @@ cc_library( # MainActivity.BINARY_GRAPH_NAME = "appname.binarypb". 
genrule( name = "binary_graph", - srcs = ["//mediapipe/graphs/object_detection:android_cpu_binary_graph"], + srcs = ["//mediapipe/graphs/object_detection:mobile_cpu_binary_graph"], outs = ["objectdetectioncpu.binarypb"], cmd = "cp $< $@", ) @@ -59,11 +59,12 @@ android_library( "//mediapipe/java/com/google/mediapipe/components:android_components", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/glutil", - "//third_party:android_constraint_layout", "//third_party:androidx_appcompat", + "//third_party:androidx_constraint_layout", + "//third_party:androidx_legacy_support_v4", + "//third_party:androidx_material", + "//third_party:androidx_recyclerview", "//third_party:opencv", - "@androidsdk//com.android.support:recyclerview-v7-25.0.0", - "@androidsdk//com.android.support:support-v4-25.0.0", "@androidx_concurrent_futures//jar", "@androidx_lifecycle//jar", "@com_google_code_findbugs//jar", @@ -73,7 +74,6 @@ android_library( android_binary( name = "objectdetectioncpu", - aapt_version = "aapt2", manifest = "AndroidManifest.xml", manifest_values = {"applicationId": "com.google.mediapipe.apps.objectdetectioncpu"}, multidex = "native", diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/MainActivity.java index 2cbbe7cd5..9a113c9fd 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/MainActivity.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/MainActivity.java @@ -39,6 +39,12 @@ public class MainActivity extends AppCompatActivity { private static final String OUTPUT_VIDEO_STREAM_NAME = "output_video"; private static final CameraHelper.CameraFacing CAMERA_FACING = CameraHelper.CameraFacing.BACK; + // Flips the camera-preview frames vertically before sending them into FrameProcessor to be + // processed in a MediaPipe graph, and flips the processed frames back when they are displayed. + // This is needed because OpenGL represents images assuming the image origin is at the bottom-left + // corner, whereas MediaPipe in general assumes the image origin is at top-left. + private static final boolean FLIP_FRAMES_VERTICALLY = true; + static { // Load all native libraries needed by the app. 
System.loadLibrary("mediapipe_jni"); @@ -82,6 +88,7 @@ public class MainActivity extends AppCompatActivity { BINARY_GRAPH_NAME, INPUT_VIDEO_STREAM_NAME, OUTPUT_VIDEO_STREAM_NAME); + processor.getVideoSurfaceOutput().setFlipY(FLIP_FRAMES_VERTICALLY); PermissionHelper.checkAndRequestCameraPermissions(this); } @@ -90,6 +97,7 @@ public class MainActivity extends AppCompatActivity { protected void onResume() { super.onResume(); converter = new ExternalTextureConverter(eglManager.getContext()); + converter.setFlipY(FLIP_FRAMES_VERTICALLY); converter.setConsumer(processor); if (PermissionHelper.cameraPermissionsGranted(this)) { startCamera(); diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/res/layout/activity_main.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/res/layout/activity_main.xml index 22240a2d6..c19d7e628 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/res/layout/activity_main.xml +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectioncpu/res/layout/activity_main.xml @@ -1,5 +1,5 @@ - - + diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/BUILD index 39a3d1523..56bf2e1f2 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/BUILD @@ -21,7 +21,7 @@ cc_binary( linkshared = 1, linkstatic = 1, deps = [ - "//mediapipe/graphs/object_detection:android_calculators", + "//mediapipe/graphs/object_detection:mobile_calculators", "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", ], ) @@ -37,7 +37,7 @@ cc_library( # MainActivity.BINARY_GRAPH_NAME = "appname.binarypb". 
genrule( name = "binary_graph", - srcs = ["//mediapipe/graphs/object_detection:android_gpu_binary_graph"], + srcs = ["//mediapipe/graphs/object_detection:mobile_gpu_binary_graph"], outs = ["objectdetectiongpu.binarypb"], cmd = "cp $< $@", ) @@ -59,11 +59,12 @@ android_library( "//mediapipe/java/com/google/mediapipe/components:android_components", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/glutil", - "//third_party:android_constraint_layout", "//third_party:androidx_appcompat", + "//third_party:androidx_constraint_layout", + "//third_party:androidx_legacy_support_v4", + "//third_party:androidx_material", + "//third_party:androidx_recyclerview", "//third_party:opencv", - "@androidsdk//com.android.support:recyclerview-v7-25.0.0", - "@androidsdk//com.android.support:support-v4-25.0.0", "@androidx_concurrent_futures//jar", "@androidx_lifecycle//jar", "@com_google_code_findbugs//jar", @@ -73,7 +74,6 @@ android_library( android_binary( name = "objectdetectiongpu", - aapt_version = "aapt2", manifest = "AndroidManifest.xml", manifest_values = {"applicationId": "com.google.mediapipe.apps.objectdetectiongpu"}, multidex = "native", diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/MainActivity.java index 9d4324fde..8983e321a 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/MainActivity.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/MainActivity.java @@ -39,6 +39,12 @@ public class MainActivity extends AppCompatActivity { private static final String OUTPUT_VIDEO_STREAM_NAME = "output_video"; private static final CameraHelper.CameraFacing CAMERA_FACING = CameraHelper.CameraFacing.BACK; + // Flips the camera-preview frames vertically before sending them into FrameProcessor to be + // processed in a MediaPipe graph, and flips the processed frames back when they are displayed. + // This is needed because OpenGL represents images assuming the image origin is at the bottom-left + // corner, whereas MediaPipe in general assumes the image origin is at top-left. + private static final boolean FLIP_FRAMES_VERTICALLY = true; + static { // Load all native libraries needed by the app. 
System.loadLibrary("mediapipe_jni"); @@ -82,6 +88,7 @@ public class MainActivity extends AppCompatActivity { BINARY_GRAPH_NAME, INPUT_VIDEO_STREAM_NAME, OUTPUT_VIDEO_STREAM_NAME); + processor.getVideoSurfaceOutput().setFlipY(FLIP_FRAMES_VERTICALLY); PermissionHelper.checkAndRequestCameraPermissions(this); } @@ -90,6 +97,7 @@ public class MainActivity extends AppCompatActivity { protected void onResume() { super.onResume(); converter = new ExternalTextureConverter(eglManager.getContext()); + converter.setFlipY(FLIP_FRAMES_VERTICALLY); converter.setConsumer(processor); if (PermissionHelper.cameraPermissionsGranted(this)) { startCamera(); diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/res/layout/activity_main.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/res/layout/activity_main.xml index 22240a2d6..c19d7e628 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/res/layout/activity_main.xml +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetectiongpu/res/layout/activity_main.xml @@ -1,5 +1,5 @@ - - + diff --git a/mediapipe/examples/desktop/README.md b/mediapipe/examples/desktop/README.md index 21cb9b2a7..869d1efaa 100644 --- a/mediapipe/examples/desktop/README.md +++ b/mediapipe/examples/desktop/README.md @@ -14,10 +14,10 @@ bazel-bin/mediapipe/examples/desktop/hello_world/hello_world --logtostderr **TFlite Object Detection** -To build the objet detection demo using a TFLite model on desktop, use: +To build the object detection demo using a TFLite model on desktop, use: ``` -bazel build -c opt mediapipe/examples/desktop/object_detection:object_detection_tflite --define 'MEDIAPIPE_DISABLE_GPU=1' +bazel build -c opt mediapipe/examples/desktop/object_detection:object_detection_tflite --define MEDIAPIPE_DISABLE_GPU=1 ``` and run it using: @@ -35,7 +35,7 @@ To build the object detection demo using a TensorFlow model on desktop, use: ``` bazel build -c opt mediapipe/examples/desktop/object_detection:object_detection_tensorflow \ - --define 'MEDIAPIPE_DISABLE_GPU=1' + --define MEDIAPIPE_DISABLE_GPU=1 ``` and run it using: @@ -46,3 +46,54 @@ bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tensorflo --input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file --alsologtostderr ``` + +**TFlite Face Detection** + +To build the face detection demo using a TFLite model on desktop, use: + +``` +bazel build -c opt mediapipe/examples/desktop/face_detection:face_detection_tflite --define MEDIAPIPE_DISABLE_GPU=1 +``` + +and run it using: + +``` +bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_tflite \ + --calculator_graph_config_file=mediapipe/graphs/face_detection/face_detection_desktop_tflite_graph.pbtxt \ + --input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file \ + --alsologtostderr +``` + +**TFlite Hand Detection** + +To build the hand detection demo using a TFLite model on desktop, use: + +``` +bazel build -c opt mediapipe/examples/desktop/hand_tracking:hand_tracking_tflite --define MEDIAPIPE_DISABLE_GPU=1 +``` + +and run it using: + +``` +bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_tflite \ + --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_detection_desktop_tflite_graph.pbtxt \ + --input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file \ + --alsologtostderr 
+``` + +**TFlite Hand Tracking** + +To build the hand tracking demo using a TFLite model on desktop, use: + +``` +bazel build -c opt mediapipe/examples/desktop/hand_tracking:hand_tracking_tflite --define MEDIAPIPE_DISABLE_GPU=1 +``` + +and run it using: + +``` +bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_tflite \ + --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_with_flag_desktop_tflite_graph.pbtxt \ + --input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file \ + --alsologtostderr +``` diff --git a/mediapipe/examples/desktop/media_sequence/BUILD b/mediapipe/examples/desktop/media_sequence/BUILD index c2a39f758..30b37d82a 100644 --- a/mediapipe/examples/desktop/media_sequence/BUILD +++ b/mediapipe/examples/desktop/media_sequence/BUILD @@ -35,5 +35,6 @@ cc_binary( deps = [ ":run_graph_file_io_main", "//mediapipe/graphs/media_sequence:clipped_images_from_file_at_24fps_calculators", + "//mediapipe/graphs/media_sequence:tvl1_flow_and_rgb_from_file_calculators", ], ) diff --git a/mediapipe/examples/desktop/media_sequence/README.md b/mediapipe/examples/desktop/media_sequence/README.md index 6be9014db..828d152e4 100644 --- a/mediapipe/examples/desktop/media_sequence/README.md +++ b/mediapipe/examples/desktop/media_sequence/README.md @@ -10,13 +10,12 @@ the data from TensorFlow into a tf.data.Dataset. Both pipelines can be imported and support a simple call to as_dataset() to make the data available. ### Demo data set -To generate the demo dataset you must have Tensorflow [version >= 1.19] -installed. Then the media_sequence_demo binary must be built from the top -directory in the mediapipe repo and the command to build the data set must be -run from the same directory. +To generate the demo dataset you must have Tensorflow installed. Then the +media_sequence_demo binary must be built from the top directory in the mediapipe +repo and the command to build the data set must be run from the same directory. ``` -bazel -c opt mediapipe/examples/desktop/media_sequence:media_sequence_demo \ - --define=MEDIAPIPE_DISABLE_GPU=1 +bazel build -c opt mediapipe/examples/desktop/media_sequence:media_sequence_demo \ + --define MEDIAPIPE_DISABLE_GPU=1 python -m mediapipe.examples.desktop.media_sequence.demo_dataset \ --alsologtostderr \ @@ -33,14 +32,13 @@ models in TensorFlow. You may only use this script in ways that comply with the Allen Institute for Artificial Intelligence's [license for the Charades data set.](https://allenai.org/plato/charades/license.txt) -To generate the Charades dataset you must have Tensorflow [version >= 1.19] -installed. Then the media_sequence_demo binary must be built from the top -directory in the mediapipe repo and the command to build the data set must be -run from the same directory. +To generate the Charades dataset you must have Tensorflow installed. Then the +media_sequence_demo binary must be built from the top directory in the mediapipe +repo and the command to build the data set must be run from the same directory. 
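As noted above, the generated pipelines can be imported and read back with a single
`as_dataset()` call. Below is a minimal sketch for the Charades data, assuming it
has already been written to a placeholder directory `/tmp/ms/charades` and using
the TF 1.x session style these scripts target; the split name is illustrative, and
the accepted split values and output fields are listed in the `Charades` class
docstring.

```python
# Sketch only: reads back previously generated Charades TFRecords.
# "/tmp/ms/charades" and the "test" split name are placeholders.
import tensorflow as tf
from mediapipe.examples.desktop.media_sequence.charades_dataset import Charades

dataset = Charades("/tmp/ms/charades").as_dataset("test")
example = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    parsed = sess.run(example)
    print(sorted(parsed.keys()))  # Field names are listed in the class docstring.
```

The commands to build the binary and generate the data follow.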
``` -bazel -c opt mediapipe/examples/desktop/media_sequence:media_sequence_demo \ - --define=MEDIAPIPE_DISABLE_GPU=1 +bazel build -c opt mediapipe/examples/desktop/media_sequence:media_sequence_demo \ + --define MEDIAPIPE_DISABLE_GPU=1 python -m mediapipe.examples.desktop.media_sequence.charades_dataset \ --alsologtostderr \ @@ -49,3 +47,33 @@ python -m mediapipe.examples.desktop.media_sequence.charades_dataset \ media_sequence/media_sequence_demo \ --path_to_graph_directory=mediapipe/graphs/media_sequence/ ``` + +### Custom videos in the Kinetics format + +To produce data in the same format at the Kinetics data, use the kinetics.py +script. + +To generate the dataset you must have Tensorflow installed. Then the +media_sequence_demo binary must be built from the top directory in the mediapipe +repo and the command to build the data set must be run from the same directory. + +``` +echo "Credit for this video belongs to: ESA/Hubble; Music: Johan B. Monell" +wget https://cdn.spacetelescope.org/archives/videos/medium_podcast/heic1608c.mp4 -O /tmp/heic1608c.mp4 +CUSTOM_CSV=/tmp/custom_kinetics.csv +VIDEO_PATH=/tmp/heic1608c.mp4 +echo -e "video,time_start,time_end,split\n${VIDEO_PATH},0,10,custom" > ${CUSTOM_CSV} + +bazel build -c opt mediapipe/examples/desktop/media_sequence:media_sequence_demo \ + --define MEDIAPIPE_DISABLE_GPU=1 + +python -m mediapipe.examples.desktop.media_sequence.kinetics_dataset \ + --alsologtostderr \ + --splits_to_process=custom \ + --path_to_custom_csv=${CUSTOM_CSV} \ + --video_path_format_string={video} \ + --path_to_kinetics_data=/tmp/ms/kinetics/ \ + --path_to_mediapipe_binary=bazel-bin/mediapipe/examples/desktop/\ +media_sequence/media_sequence_demo \ + --path_to_graph_directory=mediapipe/graphs/media_sequence/ +``` diff --git a/mediapipe/examples/desktop/media_sequence/charades_dataset.py b/mediapipe/examples/desktop/media_sequence/charades_dataset.py index cb94e07eb..c0176c6cb 100644 --- a/mediapipe/examples/desktop/media_sequence/charades_dataset.py +++ b/mediapipe/examples/desktop/media_sequence/charades_dataset.py @@ -1,18 +1,18 @@ -r"""Copyright 2019 The MediaPipe Authors. +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -Code to download and parse the Charades dataset for TensorFlow models. +r"""Code to download and parse the Charades dataset for TensorFlow models. 
The [Charades data set](https://allenai.org/plato/charades/) is a data set of human action recognition collected with and maintained by the Allen Institute @@ -141,7 +141,6 @@ class Charades(object): shape [num_segments]. "num_segments": the number of segments in the example, shape []. "num_timesteps": the number of timesteps in the example, shape []. - "images": the [time, height, width, channels] tensor of images. """ def parse_fn(sequence_example): """Parses a Charades example.""" diff --git a/mediapipe/examples/desktop/media_sequence/demo_dataset.py b/mediapipe/examples/desktop/media_sequence/demo_dataset.py index 627149be3..79a6bacfe 100644 --- a/mediapipe/examples/desktop/media_sequence/demo_dataset.py +++ b/mediapipe/examples/desktop/media_sequence/demo_dataset.py @@ -1,18 +1,18 @@ -r"""Copyright 2019 The MediaPipe Authors. +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -A demo data set constructed with MediaSequence and MediaPipe. +r"""A demo data set constructed with MediaSequence and MediaPipe. This code demonstrates the steps for constructing a data set with MediaSequence. This code has two functions. First, it can be run as a module to download and diff --git a/mediapipe/examples/desktop/media_sequence/kinetics_dataset.py b/mediapipe/examples/desktop/media_sequence/kinetics_dataset.py new file mode 100644 index 000000000..83500a6f4 --- /dev/null +++ b/mediapipe/examples/desktop/media_sequence/kinetics_dataset.py @@ -0,0 +1,455 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Code to download and parse the Kinetics dataset for TensorFlow models. + +The [Kinetics data set]( +https://deepmind.com/research/open-source/open-source-datasets/kinetics/) +is a data set for human action recognition maintained by DeepMind and Google. +This script downloads the annotations and prepares data from similar annotations +if local video files are available. 
+ +This script does not provide any means of accessing YouTube videos. + +Running this code as a module generates the data set on disk. First, the +required files are downloaded (_download_data) which enables constructing the +label map. Then (in generate_examples), for each split in the data set, the +metadata is generated from the annotations for each example +(_generate_metadata), and MediaPipe is used to fill in the video frames +(_run_mediapipe). This script processes local video files defined in a custom +CSV in a comparable manner to the Kinetics data set for evaluating and +predicting values on your own data. The data set is written to disk as a set of +numbered TFRecord files. + +The custom CSV format must match the Kinetics data set format, with columns +corresponding to [[label_name], video, start, end, split] followed by lines with +those fields. (Label_name is optional.) These field names can be used to +construct the paths to the video files using the Python string formatting +specification and the video_path_format_string flag: + --video_path_format_string="/path/to/video/{video}.mp4" + +Generating the data on disk can take considerable time and disk space. +(Image compression quality is the primary determiner of disk usage. TVL1 flow +determines runtime.) + +Once the data is on disk, reading the data as a tf.data.Dataset is accomplished +with the following lines: + + kinetics = Kinetics("kinetics_data_path") + dataset = kinetics.as_dataset("custom") + # implement additional processing and batching here + images_and_labels = dataset.make_one_shot_iterator().get_next() + images = images_and_labels["images"] + labels = image_and_labels["labels"] + +This data is structured for per-clip action classification where images is +the sequence of images and labels are a one-hot encoded value. See +as_dataset() for more details. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import csv +import os +import random +import subprocess +import sys +import tempfile +import urllib +import zipfile +from absl import app +from absl import flags +from absl import logging +import tensorflow as tf +from mediapipe.util.sequence import media_sequence as ms + +CITATION = r"""@article{kay2017kinetics, + title={The kinetics human action video dataset}, + author={Kay, Will and Carreira, Joao and Simonyan, Karen and Zhang, Brian and Hillier, Chloe and Vijayanarasimhan, Sudheendra and Viola, Fabio and Green, Tim and Back, Trevor and Natsev, Paul and others}, + journal={arXiv preprint arXiv:1705.06950}, + year={2017}, + url = {https://deepmind.com/research/open-source/kinetics}, +}""" +ANNOTATION_URL = "https://storage.googleapis.com/deepmind-media/research/Kinetics_700.zip" +SECONDS_TO_MICROSECONDS = 1000000 +GRAPHS = ["tvl1_flow_and_rgb_from_file.pbtxt"] +FILEPATTERN = "kinetics_700_%s_25fps_rgb_flow" +SPLITS = { + "train": { + "shards": 1000, + "examples": 545317}, + "val": {"shards": 100, + "examples": 35000}, + "test": {"shards": 100, + "examples": 70000}, + "custom": {"csv": None, # Add a CSV for your own data here. + "shards": 1, # Change this number to increase sharding. + "examples": -1}, # Negative 1 allows any number of examples. 
+} +NUM_CLASSES = 700 + + +class Kinetics(object): + """Generates and loads the Kinetics data set.""" + + def __init__(self, path_to_data): + if not path_to_data: + raise ValueError("You must supply the path to the data directory.") + self.path_to_data = path_to_data + + def as_dataset(self, split, shuffle=False, repeat=False, + serialized_prefetch_size=32, decoded_prefetch_size=32): + """Returns Kinetics as a tf.data.Dataset. + + After running this function, calling padded_batch() on the Dataset object + will produce batches of data, but additional preprocessing may be desired. + If using padded_batch, the indicator_matrix output distinguishes valid + from padded frames. + + Args: + split: either "train" or "test" + shuffle: if true, shuffles both files and examples. + repeat: if true, repeats the data set forever. + serialized_prefetch_size: the buffer size for reading from disk. + decoded_prefetch_size: the buffer size after decoding. + Returns: + A tf.data.Dataset object with the following structure: { + "images": float tensor, shape [time, height, width, channels] + "flow": float tensor, shape [time, height, width, 2] + "labels": float32 tensor, shape [num_classes], one hot encoded + "num_frames": int32 tensor, shape [], number of frames in the sequence + """ + def parse_fn(sequence_example): + """Parses a Kinetics example.""" + context_features = { + ms.get_example_id_key(): ms.get_example_id_default_parser(), + ms.get_clip_label_string_key(): tf.FixedLenFeature((), tf.string), + ms.get_clip_label_index_key(): tf.FixedLenFeature((), tf.int64), + } + + sequence_features = { + ms.get_image_encoded_key(): ms.get_image_encoded_default_parser(), + ms.get_forward_flow_encoded_key(): + ms.get_forward_flow_encoded_default_parser(), + } + parsed_context, parsed_sequence = tf.io.parse_single_sequence_example( + sequence_example, context_features, sequence_features) + + target = tf.one_hot(parsed_context[ms.get_clip_label_index_key()], 700) + + images = tf.image.convert_image_dtype( + tf.map_fn(tf.image.decode_jpeg, + parsed_sequence[ms.get_image_encoded_key()], + back_prop=False, + dtype=tf.uint8), tf.float32) + num_frames = tf.shape(images)[0] + + flow = tf.image.convert_image_dtype( + tf.map_fn(tf.image.decode_jpeg, + parsed_sequence[ms.get_forward_flow_encoded_key()], + back_prop=False, + dtype=tf.uint8), tf.float32) + # The flow is quantized for storage in JPEGs by the FlowToImageCalculator. + # The quantization needs to be inverted. + flow = (flow[:, :, :, :2] - 0.5) * 2 * 20. 
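+      # Concretely, a stored value of 0.0 decodes to -20 px of displacement,
+      # 0.5 to 0 px, and 1.0 to +20 px; only the first two channels (the x and
+      # y flow components) are kept, giving the [time, height, width, 2] shape
+      # documented above.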
+
+      output_dict = {
+          "labels": target,
+          "images": images,
+          "flow": flow,
+          "num_frames": num_frames,
+      }
+      return output_dict
+
+    if split not in SPLITS:
+      raise ValueError("Split %s not in %s" % (split, str(SPLITS.keys())))
+    all_shards = tf.io.gfile.glob(
+        os.path.join(self.path_to_data, FILEPATTERN % split + "-*-of-*"))
+    random.shuffle(all_shards)
+    all_shards_dataset = tf.data.Dataset.from_tensor_slices(all_shards)
+    cycle_length = min(16, len(all_shards))
+    dataset = all_shards_dataset.apply(
+        tf.contrib.data.parallel_interleave(
+            tf.data.TFRecordDataset,
+            cycle_length=cycle_length,
+            block_length=1, sloppy=True,
+            buffer_output_elements=serialized_prefetch_size))
+    dataset = dataset.prefetch(serialized_prefetch_size)
+    if shuffle:
+      dataset = dataset.shuffle(serialized_prefetch_size)
+    if repeat:
+      dataset = dataset.repeat()
+    dataset = dataset.map(parse_fn)
+    dataset = dataset.prefetch(decoded_prefetch_size)
+    return dataset
+
+  def generate_examples(self, path_to_mediapipe_binary,
+                        path_to_graph_directory,
+                        only_generate_metadata=False,
+                        splits_to_process="train,val,test",
+                        video_path_format_string=None,
+                        download_labels_for_map=True):
+    """Downloads data and generates sharded TFRecords.
+
+    Downloads the data files, generates metadata, and processes the metadata
+    with MediaPipe to produce tf.SequenceExamples for training. The resulting
+    files can be read with as_dataset(). After running this function the
+    original data files can be deleted.
+
+    Args:
+      path_to_mediapipe_binary: Path to the compiled binary for the BUILD target
+        mediapipe/examples/desktop/demo:media_sequence_demo.
+      path_to_graph_directory: Path to the directory with MediaPipe graphs in
+        mediapipe/graphs/media_sequence/.
+      only_generate_metadata: If true, do not run MediaPipe; only write the
+        metadata to disk.
+      splits_to_process: csv string of which splits to process. A custom split
+        can be processed by providing its CSV via the --path_to_custom_csv
+        flag. The original data is still downloaded to generate the label_map.
+      video_path_format_string: The format string for the path to local files.
+      download_labels_for_map: If true, download the annotations to create the
+        label map.
+ """ + if not path_to_mediapipe_binary: + raise ValueError( + "You must supply the path to the MediaPipe binary for " + "mediapipe/examples/desktop/demo:media_sequence_demo.") + if not path_to_graph_directory: + raise ValueError( + "You must supply the path to the directory with MediaPipe graphs in " + "mediapipe/graphs/media_sequence/.") + logging.info("Downloading data.") + download_output = self._download_data(download_labels_for_map) + for key in splits_to_process.split(","): + logging.info("Generating metadata for split: %s", key) + all_metadata = list(self._generate_metadata( + key, download_output, video_path_format_string)) + logging.info("An example of the metadata: ") + logging.info(all_metadata[0]) + random.seed(47) + random.shuffle(all_metadata) + shards = SPLITS[key]["shards"] + shard_names = [os.path.join( + self.path_to_data, FILEPATTERN % key + "-%05d-of-%05d" % ( + i, shards)) for i in range(shards)] + writers = [tf.io.TFRecordWriter(shard_name) for shard_name in shard_names] + with _close_on_exit(writers) as writers: + for i, seq_ex in enumerate(all_metadata): + if not only_generate_metadata: + print("Processing example %d of %d (%d%%) \r" % ( + i, len(all_metadata), i * 100 / len(all_metadata)), end="") + for graph in GRAPHS: + graph_path = os.path.join(path_to_graph_directory, graph) + seq_ex = self._run_mediapipe( + path_to_mediapipe_binary, seq_ex, graph_path) + writers[i % len(writers)].write(seq_ex.SerializeToString()) + logging.info("Data extraction complete.") + + def _generate_metadata(self, key, download_output, + video_path_format_string=None): + """For each row in the annotation CSV, generates the corresponding metadata. + + Args: + key: which split to process. + download_output: the tuple output of _download_data containing + - annotations_files: dict of keys to CSV annotation paths. + - label_map: dict mapping from label strings to numeric indices. + video_path_format_string: The format string for the path to local files. + Yields: + Each tf.SequenceExample of metadata, ready to pass to MediaPipe. + """ + annotations_files, label_map = download_output + with open(annotations_files[key], "r") as annotations: + reader = csv.reader(annotations) + for i, csv_row in enumerate(reader): + if i == 0: # the first row is the header + continue + # rename the row with a constitent set of names. 
+ if len(csv_row) == 5: + row = dict(zip(["label_name", "video", "start", "end", "split"], + csv_row)) + else: + row = dict(zip(["video", "start", "end", "split"], + csv_row)) + metadata = tf.train.SequenceExample() + ms.set_example_id(bytes23(row["video"] + "_" + row["start"]), + metadata) + ms.set_clip_media_id(bytes23(row["video"]), metadata) + ms.set_clip_alternative_media_id(bytes23(row["split"]), metadata) + if video_path_format_string: + filepath = video_path_format_string.format(**row) + ms.set_clip_data_path(bytes23(filepath), metadata) + assert row["start"].isdigit(), "Invalid row: %s" % str(row) + assert row["end"].isdigit(), "Invalid row: %s" % str(row) + if "label_name" in row: + ms.set_clip_label_string([bytes23(row["label_name"])], metadata) + if label_map: + ms.set_clip_label_index([label_map[row["label_name"]]], metadata) + yield metadata + + def _download_data(self, download_labels_for_map): + """Downloads and extracts data if not already available.""" + if sys.version_info >= (3, 0): + urlretrieve = urllib.request.urlretrieve + else: + urlretrieve = urllib.urlretrieve + logging.info("Creating data directory.") + tf.io.gfile.makedirs(self.path_to_data) + logging.info("Downloading annotations.") + paths = {} + if download_labels_for_map: + zip_path = os.path.join(self.path_to_data, ANNOTATION_URL.split("/")[-1]) + if not tf.io.gfile.exists(zip_path): + urlretrieve(ANNOTATION_URL, zip_path) + with zipfile.ZipFile(zip_path) as annotations_zip: + annotations_zip.extractall(self.path_to_data) + for split in ["train", "test", "val"]: + zip_path = os.path.join(self.path_to_data, + "kinetics_700_%s.zip" % split) + csv_path = zip_path.replace(".zip", ".csv") + if not tf.io.gfile.exists(csv_path): + with zipfile.ZipFile(zip_path) as annotations_zip: + annotations_zip.extractall(self.path_to_data) + paths[split] = csv_path + for split, contents in SPLITS.items(): + if "csv" in contents and contents["csv"]: + paths[split] = contents["csv"] + label_map = (self.get_label_map_and_verify_example_counts(paths) if + download_labels_for_map else None) + return paths, label_map + + def _run_mediapipe(self, path_to_mediapipe_binary, sequence_example, graph): + """Runs MediaPipe over MediaSequence tf.train.SequenceExamples. + + Args: + path_to_mediapipe_binary: Path to the compiled binary for the BUILD target + mediapipe/examples/desktop/demo:media_sequence_demo. + sequence_example: The SequenceExample with metadata or partial data file. + graph: The path to the graph that extracts data to add to the + SequenceExample. + Returns: + A copy of the input SequenceExample with additional data fields added + by the MediaPipe graph. + Raises: + RuntimeError: if MediaPipe returns an error or fails to run the graph. 
+ """ + if not path_to_mediapipe_binary: + raise ValueError("--path_to_mediapipe_binary must be specified.") + input_fd, input_filename = tempfile.mkstemp() + output_fd, output_filename = tempfile.mkstemp() + cmd = [path_to_mediapipe_binary, + "--calculator_graph_config_file=%s" % graph, + "--input_side_packets=input_sequence_example=%s" % input_filename, + "--output_side_packets=output_sequence_example=%s" % output_filename] + with open(input_filename, "wb") as input_file: + input_file.write(sequence_example.SerializeToString()) + mediapipe_output = subprocess.check_output(cmd) + if b"Failed to run the graph" in mediapipe_output: + raise RuntimeError(mediapipe_output) + with open(output_filename, "rb") as output_file: + output_example = tf.train.SequenceExample() + output_example.ParseFromString(output_file.read()) + os.close(input_fd) + os.remove(input_filename) + os.close(output_fd) + os.remove(output_filename) + return output_example + + def get_label_map_and_verify_example_counts(self, paths): + """Verify the number of examples and labels have not changed.""" + for name, path in paths.items(): + with open(path, "r") as f: + lines = f.readlines() + # the header adds one line and one "key". + num_examples = len(lines) - 1 + keys = [l.split(",")[0] for l in lines] + label_map = None + if name == "train": + classes = sorted(list(set(keys[1:]))) + num_keys = len(set(keys)) - 1 + assert NUM_CLASSES == num_keys, ( + "Found %d labels for split: %s, should be %d" % ( + num_keys, name, NUM_CLASSES)) + label_map = dict(zip(classes, range(len(classes)))) + if SPLITS[name]["examples"] > 0: + assert SPLITS[name]["examples"] == num_examples, ( + "Found %d examples for split: %s, should be %d" % ( + num_examples, name, SPLITS[name]["examples"])) + return label_map + + +def bytes23(string): + """Creates a bytes string in either Python 2 or 3.""" + if sys.version_info >= (3, 0): + return bytes(string, "utf8") + else: + return bytes(string) + + +@contextlib.contextmanager +def _close_on_exit(writers): + """Call close on all writers on exit.""" + try: + yield writers + finally: + for writer in writers: + writer.close() + + +def main(argv): + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + if flags.FLAGS.path_to_custom_csv: + SPLITS["custom"]["csv"] = flags.FLAGS.path_to_custom_csv + Kinetics(flags.FLAGS.path_to_kinetics_data).generate_examples( + flags.FLAGS.path_to_mediapipe_binary, + flags.FLAGS.path_to_graph_directory, + flags.FLAGS.only_generate_metadata, + flags.FLAGS.splits_to_process, + flags.FLAGS.video_path_format_string, + flags.FLAGS.download_labels_for_map) + +if __name__ == "__main__": + flags.DEFINE_string("path_to_kinetics_data", + "", + "Path to directory to write data to.") + flags.DEFINE_string("path_to_mediapipe_binary", + "", + "Path to the MediaPipe run_graph_file_io_main binary.") + flags.DEFINE_string("path_to_graph_directory", + "", + "Path to directory containing the graph files.") + flags.DEFINE_boolean("only_generate_metadata", + False, + "If true, only generate the metadata files.") + flags.DEFINE_boolean("download_labels_for_map", + True, + "If true, download the annotations to construct the " + "label map.") + flags.DEFINE_string("splits_to_process", + "custom", + "Process these splits. Useful for custom data splits.") + flags.DEFINE_string("video_path_format_string", + None, + "The format string for the path to local video files. 
" + "Uses the Python string.format() syntax with possible " + "arguments of {video}, {start}, {end}, {label_name}, and " + "{split}, corresponding to columns of the data csvs.") + flags.DEFINE_string("path_to_custom_csv", + None, + "If present, processes this CSV as a custom split.") + app.run(main) diff --git a/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc b/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc index 31a6df5bf..bc2d911bf 100644 --- a/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc +++ b/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc @@ -37,7 +37,7 @@ DEFINE_string(output_side_packets, "", "side packets and paths to write to disk for the " "CalculatorGraph."); -::mediapipe::Status RunMediaPipeGraph() { +::mediapipe::Status RunMPPGraph() { std::string calculator_graph_config_contents; RETURN_IF_ERROR(mediapipe::file::GetContents( FLAGS_calculator_graph_config_file, &calculator_graph_config_contents)); @@ -84,7 +84,7 @@ DEFINE_string(output_side_packets, "", int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); gflags::ParseCommandLineFlags(&argc, &argv, true); - ::mediapipe::Status run_status = RunMediaPipeGraph(); + ::mediapipe::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); } else { diff --git a/mediapipe/examples/desktop/object_detection/BUILD b/mediapipe/examples/desktop/object_detection/BUILD index eda9e7023..4ce4fd900 100644 --- a/mediapipe/examples/desktop/object_detection/BUILD +++ b/mediapipe/examples/desktop/object_detection/BUILD @@ -33,6 +33,7 @@ cc_library( "@org_tensorflow//tensorflow/core/kernels:fused_batch_norm_op", "@org_tensorflow//tensorflow/core/kernels:gather_op", "@org_tensorflow//tensorflow/core/kernels:identity_op", + "@org_tensorflow//tensorflow/core/kernels:logging_ops", "@org_tensorflow//tensorflow/core/kernels:matmul_op", "@org_tensorflow//tensorflow/core/kernels:non_max_suppression_op", "@org_tensorflow//tensorflow/core/kernels:pack_op", @@ -49,7 +50,9 @@ cc_library( "@org_tensorflow//tensorflow/core/kernels:topk_op", "@org_tensorflow//tensorflow/core/kernels:transpose_op", "@org_tensorflow//tensorflow/core/kernels:unpack_op", + "@org_tensorflow//tensorflow/core/kernels/data:tensor_dataset_op", ], + alwayslink = 1, ) cc_binary( @@ -58,6 +61,7 @@ cc_binary( ":object_detection_tensorflow_deps", "//mediapipe/examples/desktop:simple_run_graph_main", "//mediapipe/graphs/object_detection:desktop_tensorflow_calculators", + "@org_tensorflow//tensorflow/core:direct_session", ], ) diff --git a/mediapipe/examples/desktop/simple_run_graph_main.cc b/mediapipe/examples/desktop/simple_run_graph_main.cc index 214ad1fdf..7fc93142c 100644 --- a/mediapipe/examples/desktop/simple_run_graph_main.cc +++ b/mediapipe/examples/desktop/simple_run_graph_main.cc @@ -31,7 +31,7 @@ DEFINE_string(input_side_packets, "", "for the CalculatorGraph. 
All values will be treated as the " "string type even if they represent doubles, floats, etc."); -::mediapipe::Status RunMediaPipeGraph() { +::mediapipe::Status RunMPPGraph() { std::string calculator_graph_config_contents; RETURN_IF_ERROR(mediapipe::file::GetContents( FLAGS_calculator_graph_config_file, &calculator_graph_config_contents)); @@ -60,7 +60,7 @@ DEFINE_string(input_side_packets, "", int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); gflags::ParseCommandLineFlags(&argc, &argv, true); - ::mediapipe::Status run_status = RunMediaPipeGraph(); + ::mediapipe::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); } else { diff --git a/mediapipe/examples/ios/BUILD b/mediapipe/examples/ios/BUILD new file mode 100644 index 000000000..3cf7f234b --- /dev/null +++ b/mediapipe/examples/ios/BUILD @@ -0,0 +1,22 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + +alias( + name = "provisioning_profile", + actual = "//mediapipe:provisioning_profile.mobileprovision", +) diff --git a/mediapipe/examples/ios/README.md b/mediapipe/examples/ios/README.md new file mode 100644 index 000000000..0a3f9b4bf --- /dev/null +++ b/mediapipe/examples/ios/README.md @@ -0,0 +1,18 @@ +This directory contains example MediaPipe applications on iOS. + +| Use Case | Directory | +|---------------------------------------|:-----------------------------------:| +| Edge Detection on GPU | edgedetection | +| Face Detection on CPU | facedetectioncpu | +| Face Detection on GPU | facedetectiongpu | +| Object Detection on CPU | objectdetectioncpu | +| Object Detection on GPU | objectdetectiongpu | +| Hand Detection on GPU | handdetectiongpu | +| Hand Tracking on GPU | handtrackinggpu | + +For instance, to build an example app for face detection on CPU, run: + +```bash +bazel build -c opt --config=ios_arm64 --xcode_version=$XCODE_VERSION --cxxopt='-std=c++14' mediapipe/examples/ios/facedetectioncpu:FaceDetectionCpuApp +``` +(Note: with your own $XCODE_VERSION) diff --git a/mediapipe/examples/ios/edgedetectiongpu/AppDelegate.h b/mediapipe/examples/ios/edgedetectiongpu/AppDelegate.h new file mode 100644 index 000000000..6b0377ef2 --- /dev/null +++ b/mediapipe/examples/ios/edgedetectiongpu/AppDelegate.h @@ -0,0 +1,21 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface AppDelegate : UIResponder + +@property(strong, nonatomic) UIWindow *window; + +@end diff --git a/mediapipe/examples/ios/edgedetectiongpu/AppDelegate.m b/mediapipe/examples/ios/edgedetectiongpu/AppDelegate.m new file mode 100644 index 000000000..9e1b7ff0e --- /dev/null +++ b/mediapipe/examples/ios/edgedetectiongpu/AppDelegate.m @@ -0,0 +1,59 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "AppDelegate.h" + +@interface AppDelegate () + +@end + +@implementation AppDelegate + +- (BOOL)application:(UIApplication *)application + didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { + // Override point for customization after application launch. + return YES; +} + +- (void)applicationWillResignActive:(UIApplication *)application { + // Sent when the application is about to move from active to inactive state. This can occur for + // certain types of temporary interruptions (such as an incoming phone call or SMS message) or + // when the user quits the application and it begins the transition to the background state. Use + // this method to pause ongoing tasks, disable timers, and invalidate graphics rendering + // callbacks. Games should use this method to pause the game. +} + +- (void)applicationDidEnterBackground:(UIApplication *)application { + // Use this method to release shared resources, save user data, invalidate timers, and store + // enough application state information to restore your application to its current state in case + // it is terminated later. If your application supports background execution, this method is + // called instead of applicationWillTerminate: when the user quits. +} + +- (void)applicationWillEnterForeground:(UIApplication *)application { + // Called as part of the transition from the background to the active state; here you can undo + // many of the changes made on entering the background. +} + +- (void)applicationDidBecomeActive:(UIApplication *)application { + // Restart any tasks that were paused (or not yet started) while the application was inactive. If + // the application was previously in the background, optionally refresh the user interface. +} + +- (void)applicationWillTerminate:(UIApplication *)application { + // Called when the application is about to terminate. Save data if appropriate. See also + // applicationDidEnterBackground:. 
+} + +@end diff --git a/mediapipe/examples/ios/edgedetectiongpu/Assets.xcassets/AppIcon.appiconset/Contents.json b/mediapipe/examples/ios/edgedetectiongpu/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 000000000..a1895a242 --- /dev/null +++ b/mediapipe/examples/ios/edgedetectiongpu/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1,99 @@ +{ + "images" : [ + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "3x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "83.5x83.5", + "scale" : "2x" + }, + { + "idiom" : "ios-marketing", + "size" : "1024x1024", + "scale" : "1x" + } + ], + "info" : { + "version" : 1, + "author" : "xcode" + } +} + diff --git a/mediapipe/examples/ios/edgedetectiongpu/Assets.xcassets/Contents.json b/mediapipe/examples/ios/edgedetectiongpu/Assets.xcassets/Contents.json new file mode 100644 index 000000000..7afcdfaf8 --- /dev/null +++ b/mediapipe/examples/ios/edgedetectiongpu/Assets.xcassets/Contents.json @@ -0,0 +1,7 @@ +{ + "info" : { + "version" : 1, + "author" : "xcode" + } +} + diff --git a/mediapipe/examples/ios/edgedetectiongpu/BUILD b/mediapipe/examples/ios/edgedetectiongpu/BUILD new file mode 100644 index 000000000..aa5f721c1 --- /dev/null +++ b/mediapipe/examples/ios/edgedetectiongpu/BUILD @@ -0,0 +1,65 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +licenses(["notice"]) # Apache 2.0 + +MIN_IOS_VERSION = "10.0" + +load( + "@build_bazel_rules_apple//apple:ios.bzl", + "ios_application", +) + +ios_application( + name = "EdgeDetectionGpuApp", + bundle_id = "com.google.mediapipe.EdgeDetectionGpu", + families = [ + "iphone", + "ipad", + ], + infoplists = ["Info.plist"], + minimum_os_version = MIN_IOS_VERSION, + provisioning_profile = "//mediapipe/examples/ios:provisioning_profile", + deps = [":EdgeDetectionGpuAppLibrary"], +) + +objc_library( + name = "EdgeDetectionGpuAppLibrary", + srcs = [ + "AppDelegate.m", + "ViewController.mm", + "main.m", + ], + hdrs = [ + "AppDelegate.h", + "ViewController.h", + ], + data = [ + "Base.lproj/LaunchScreen.storyboard", + "Base.lproj/Main.storyboard", + "//mediapipe/graphs/edge_detection:mobile_gpu_binary_graph", + ], + sdk_frameworks = [ + "AVFoundation", + "CoreGraphics", + "CoreMedia", + "UIKit", + ], + deps = [ + "//mediapipe/graphs/edge_detection:mobile_calculators", + "//mediapipe/objc:mediapipe_framework_ios", + "//mediapipe/objc:mediapipe_input_sources_ios", + "//mediapipe/objc:mediapipe_layer_renderer", + ], +) diff --git a/mediapipe/examples/ios/edgedetectiongpu/Base.lproj/LaunchScreen.storyboard b/mediapipe/examples/ios/edgedetectiongpu/Base.lproj/LaunchScreen.storyboard new file mode 100644 index 000000000..bfa361294 --- /dev/null +++ b/mediapipe/examples/ios/edgedetectiongpu/Base.lproj/LaunchScreen.storyboard @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/ios/edgedetectiongpu/Base.lproj/Main.storyboard b/mediapipe/examples/ios/edgedetectiongpu/Base.lproj/Main.storyboard new file mode 100644 index 000000000..e3bd912a4 --- /dev/null +++ b/mediapipe/examples/ios/edgedetectiongpu/Base.lproj/Main.storyboard @@ -0,0 +1,51 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/ios/edgedetectiongpu/Info.plist b/mediapipe/examples/ios/edgedetectiongpu/Info.plist new file mode 100644 index 000000000..c7f7ec816 --- /dev/null +++ b/mediapipe/examples/ios/edgedetectiongpu/Info.plist @@ -0,0 +1,43 @@ + + + + + NSCameraUsageDescription + This app uses the camera to demonstrate live video processing. + CFBundleDevelopmentRegion + en + CFBundleExecutable + $(EXECUTABLE_NAME) + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + $(PRODUCT_NAME) + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleVersion + 1 + LSRequiresIPhoneOS + + UILaunchStoryboardName + LaunchScreen + UIMainStoryboardFile + Main + UIRequiredDeviceCapabilities + + armv7 + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + UIInterfaceOrientationPortraitUpsideDown + + + diff --git a/mediapipe/examples/ios/edgedetectiongpu/ViewController.h b/mediapipe/examples/ios/edgedetectiongpu/ViewController.h new file mode 100644 index 000000000..e0a5a6367 --- /dev/null +++ b/mediapipe/examples/ios/edgedetectiongpu/ViewController.h @@ -0,0 +1,19 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface ViewController : UIViewController + +@end diff --git a/mediapipe/examples/ios/edgedetectiongpu/ViewController.mm b/mediapipe/examples/ios/edgedetectiongpu/ViewController.mm new file mode 100644 index 000000000..98ee7fd99 --- /dev/null +++ b/mediapipe/examples/ios/edgedetectiongpu/ViewController.mm @@ -0,0 +1,176 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "ViewController.h" + +#import "mediapipe/objc/MPPGraph.h" +#import "mediapipe/objc/MPPCameraInputSource.h" +#import "mediapipe/objc/MPPLayerRenderer.h" + +static NSString* const kGraphName = @"mobile_gpu"; + +static const char* kInputStream = "input_video"; +static const char* kOutputStream = "output_video"; +static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue"; + +@interface ViewController () + +// The MediaPipe graph currently in use. Initialized in viewDidLoad, started in viewWillAppear: and +// sent video frames on _videoQueue. +@property(nonatomic) MPPGraph* mediapipeGraph; + +@end + +@implementation ViewController { + /// Handles camera access via AVCaptureSession library. + MPPCameraInputSource* _cameraSource; + + /// Inform the user when camera is unavailable. + IBOutlet UILabel* _noCameraLabel; + /// Display the camera preview frames. + IBOutlet UIView* _liveView; + /// Render frames in a layer. + MPPLayerRenderer* _renderer; + + /// Process camera frames on this queue. + dispatch_queue_t _videoQueue; +} + +#pragma mark - Cleanup methods + +- (void)dealloc { + self.mediapipeGraph.delegate = nil; + [self.mediapipeGraph cancel]; + // Ignore errors since we're cleaning up. + [self.mediapipeGraph closeAllInputStreamsWithError:nil]; + [self.mediapipeGraph waitUntilDoneWithError:nil]; +} + +#pragma mark - MediaPipe graph methods + ++ (MPPGraph*)loadGraphFromResource:(NSString*)resource { + // Load the graph config resource. + NSError* configLoadError = nil; + NSBundle* bundle = [NSBundle bundleForClass:[self class]]; + if (!resource || resource.length == 0) { + return nil; + } + NSURL* graphURL = [bundle URLForResource:resource withExtension:@"binarypb"]; + NSData* data = [NSData dataWithContentsOfURL:graphURL options:0 error:&configLoadError]; + if (!data) { + NSLog(@"Failed to load MediaPipe graph config: %@", configLoadError); + return nil; + } + + // Parse the graph config resource into mediapipe::CalculatorGraphConfig proto object. 
+ mediapipe::CalculatorGraphConfig config; + config.ParseFromArray(data.bytes, data.length); + + // Create MediaPipe graph with mediapipe::CalculatorGraphConfig proto object. + MPPGraph* newGraph = [[MPPGraph alloc] initWithGraphConfig:config]; + [newGraph addFrameOutputStream:kOutputStream outputPacketType:MediaPipePacketPixelBuffer]; + return newGraph; +} + +#pragma mark - UIViewController methods + +- (void)viewDidLoad { + [super viewDidLoad]; + + _renderer = [[MPPLayerRenderer alloc] init]; + _renderer.layer.frame = _liveView.layer.bounds; + [_liveView.layer addSublayer:_renderer.layer]; + _renderer.frameScaleMode = MediaPipeFrameScaleFillAndCrop; + + dispatch_queue_attr_t qosAttribute = dispatch_queue_attr_make_with_qos_class( + DISPATCH_QUEUE_SERIAL, QOS_CLASS_USER_INTERACTIVE, /*relative_priority=*/0); + _videoQueue = dispatch_queue_create(kVideoQueueLabel, qosAttribute); + + _cameraSource = [[MPPCameraInputSource alloc] init]; + [_cameraSource setDelegate:self queue:_videoQueue]; + _cameraSource.sessionPreset = AVCaptureSessionPresetHigh; + _cameraSource.cameraPosition = AVCaptureDevicePositionBack; + // The frame's native format is rotated with respect to the portrait orientation. + _cameraSource.orientation = AVCaptureVideoOrientationPortrait; + + self.mediapipeGraph = [[self class] loadGraphFromResource:kGraphName]; + self.mediapipeGraph.delegate = self; + // Set maxFramesInFlight to a small value to avoid memory contention for real-time processing. + self.mediapipeGraph.maxFramesInFlight = 2; +} + +// In this application, there is only one ViewController which has no navigation to other view +// controllers, and there is only one View with live display showing the result of running the +// MediaPipe graph on the live video feed. If more view controllers are needed later, the graph +// setup/teardown and camera start/stop logic should be updated appropriately in response to the +// appearance/disappearance of this ViewController, as viewWillAppear: can be invoked multiple times +// depending on the application navigation flow in that case. +- (void)viewWillAppear:(BOOL)animated { + [super viewWillAppear:animated]; + + [_cameraSource requestCameraAccessWithCompletionHandler:^void(BOOL granted) { + if (granted) { + [self startGraphAndCamera]; + dispatch_async(dispatch_get_main_queue(), ^{ + [_noCameraLabel setHidden:YES]; + }); + } + }]; +} + +- (void)startGraphAndCamera { + // Start running self.mediapipeGraph. + NSError* error; + if (![self.mediapipeGraph startWithError:&error]) { + NSLog(@"Failed to start graph: %@", error); + } + + // Start fetching frames from the camera. + dispatch_async(_videoQueue, ^{ + [_cameraSource start]; + }); +} + +#pragma mark - MPPGraphDelegate methods + +// Receives CVPixelBufferRef from the MediaPipe graph. Invoked on a MediaPipe worker thread. +- (void)mediapipeGraph:(MPPGraph*)graph + didOutputPixelBuffer:(CVPixelBufferRef)pixelBuffer + fromStream:(const std::string&)streamName { + if (streamName == kOutputStream) { + // Display the captured image on the screen. + CVPixelBufferRetain(pixelBuffer); + dispatch_async(dispatch_get_main_queue(), ^{ + [_renderer renderPixelBuffer:pixelBuffer]; + CVPixelBufferRelease(pixelBuffer); + }); + } +} + +#pragma mark - MPPInputSourceDelegate methods + +// Must be invoked on _videoQueue. 
+- (void)processVideoFrame:(CVPixelBufferRef)imageBuffer + timestamp:(CMTime)timestamp + fromSource:(MPPInputSource*)source { + if (source != _cameraSource) { + NSLog(@"Unknown source: %@", source); + return; + } + [self.mediapipeGraph sendPixelBuffer:imageBuffer + intoStream:kInputStream + packetType:MediaPipePacketPixelBuffer]; +} + +@end diff --git a/mediapipe/examples/ios/edgedetectiongpu/main.m b/mediapipe/examples/ios/edgedetectiongpu/main.m new file mode 100644 index 000000000..7ffe5ea5d --- /dev/null +++ b/mediapipe/examples/ios/edgedetectiongpu/main.m @@ -0,0 +1,22 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "AppDelegate.h" + +int main(int argc, char * argv[]) { + @autoreleasepool { + return UIApplicationMain(argc, argv, nil, NSStringFromClass([AppDelegate class])); + } +} diff --git a/mediapipe/examples/ios/facedetectioncpu/AppDelegate.h b/mediapipe/examples/ios/facedetectioncpu/AppDelegate.h new file mode 100644 index 000000000..6b0377ef2 --- /dev/null +++ b/mediapipe/examples/ios/facedetectioncpu/AppDelegate.h @@ -0,0 +1,21 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface AppDelegate : UIResponder + +@property(strong, nonatomic) UIWindow *window; + +@end diff --git a/mediapipe/examples/ios/facedetectioncpu/AppDelegate.m b/mediapipe/examples/ios/facedetectioncpu/AppDelegate.m new file mode 100644 index 000000000..9e1b7ff0e --- /dev/null +++ b/mediapipe/examples/ios/facedetectioncpu/AppDelegate.m @@ -0,0 +1,59 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "AppDelegate.h" + +@interface AppDelegate () + +@end + +@implementation AppDelegate + +- (BOOL)application:(UIApplication *)application + didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { + // Override point for customization after application launch. 
+ return YES; +} + +- (void)applicationWillResignActive:(UIApplication *)application { + // Sent when the application is about to move from active to inactive state. This can occur for + // certain types of temporary interruptions (such as an incoming phone call or SMS message) or + // when the user quits the application and it begins the transition to the background state. Use + // this method to pause ongoing tasks, disable timers, and invalidate graphics rendering + // callbacks. Games should use this method to pause the game. +} + +- (void)applicationDidEnterBackground:(UIApplication *)application { + // Use this method to release shared resources, save user data, invalidate timers, and store + // enough application state information to restore your application to its current state in case + // it is terminated later. If your application supports background execution, this method is + // called instead of applicationWillTerminate: when the user quits. +} + +- (void)applicationWillEnterForeground:(UIApplication *)application { + // Called as part of the transition from the background to the active state; here you can undo + // many of the changes made on entering the background. +} + +- (void)applicationDidBecomeActive:(UIApplication *)application { + // Restart any tasks that were paused (or not yet started) while the application was inactive. If + // the application was previously in the background, optionally refresh the user interface. +} + +- (void)applicationWillTerminate:(UIApplication *)application { + // Called when the application is about to terminate. Save data if appropriate. See also + // applicationDidEnterBackground:. +} + +@end diff --git a/mediapipe/examples/ios/facedetectioncpu/Assets.xcassets/AppIcon.appiconset/Contents.json b/mediapipe/examples/ios/facedetectioncpu/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 000000000..a1895a242 --- /dev/null +++ b/mediapipe/examples/ios/facedetectioncpu/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1,99 @@ +{ + "images" : [ + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "3x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "83.5x83.5", + "scale" : "2x" + }, + { + "idiom" : "ios-marketing", + "size" : "1024x1024", + "scale" : "1x" + } + ], + "info" : { + "version" : 1, + "author" : "xcode" + } +} + diff --git a/mediapipe/examples/ios/facedetectioncpu/Assets.xcassets/Contents.json b/mediapipe/examples/ios/facedetectioncpu/Assets.xcassets/Contents.json new file mode 100644 index 000000000..7afcdfaf8 --- 
/dev/null +++ b/mediapipe/examples/ios/facedetectioncpu/Assets.xcassets/Contents.json @@ -0,0 +1,7 @@ +{ + "info" : { + "version" : 1, + "author" : "xcode" + } +} + diff --git a/mediapipe/examples/ios/facedetectioncpu/BUILD b/mediapipe/examples/ios/facedetectioncpu/BUILD new file mode 100644 index 000000000..cd97b42d8 --- /dev/null +++ b/mediapipe/examples/ios/facedetectioncpu/BUILD @@ -0,0 +1,75 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) # Apache 2.0 + +MIN_IOS_VERSION = "10.0" + +load( + "@build_bazel_rules_apple//apple:ios.bzl", + "ios_application", +) + +ios_application( + name = "FaceDetectionCpuApp", + bundle_id = "com.google.mediapipe.FaceDetectionCpu", + families = [ + "iphone", + "ipad", + ], + infoplists = ["Info.plist"], + minimum_os_version = MIN_IOS_VERSION, + provisioning_profile = "//mediapipe/examples/ios:provisioning_profile", + deps = [ + ":FaceDetectionCpuAppLibrary", + "@ios_opencv//:OpencvFramework", + ], +) + +objc_library( + name = "FaceDetectionCpuAppLibrary", + srcs = [ + "AppDelegate.m", + "ViewController.mm", + "main.m", + ], + hdrs = [ + "AppDelegate.h", + "ViewController.h", + ], + data = [ + "Base.lproj/LaunchScreen.storyboard", + "Base.lproj/Main.storyboard", + "//mediapipe/graphs/face_detection:mobile_cpu_binary_graph", + "//mediapipe/models:face_detection_front.tflite", + "//mediapipe/models:face_detection_front_labelmap.txt", + ], + sdk_frameworks = [ + "AVFoundation", + "CoreGraphics", + "CoreMedia", + "UIKit", + ], + deps = [ + "//mediapipe/objc:mediapipe_framework_ios", + "//mediapipe/objc:mediapipe_input_sources_ios", + "//mediapipe/objc:mediapipe_layer_renderer", + ] + select({ + "//mediapipe:ios_i386": [], + "//mediapipe:ios_x86_64": [], + "//conditions:default": [ + "//mediapipe/graphs/face_detection:mobile_calculators", + ], + }), +) diff --git a/mediapipe/examples/ios/facedetectioncpu/Base.lproj/LaunchScreen.storyboard b/mediapipe/examples/ios/facedetectioncpu/Base.lproj/LaunchScreen.storyboard new file mode 100644 index 000000000..bfa361294 --- /dev/null +++ b/mediapipe/examples/ios/facedetectioncpu/Base.lproj/LaunchScreen.storyboard @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/ios/facedetectioncpu/Base.lproj/Main.storyboard b/mediapipe/examples/ios/facedetectioncpu/Base.lproj/Main.storyboard new file mode 100644 index 000000000..1b47bc773 --- /dev/null +++ b/mediapipe/examples/ios/facedetectioncpu/Base.lproj/Main.storyboard @@ -0,0 +1,51 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/ios/facedetectioncpu/Info.plist b/mediapipe/examples/ios/facedetectioncpu/Info.plist new file mode 100644 index 000000000..30db14c62 --- /dev/null +++ b/mediapipe/examples/ios/facedetectioncpu/Info.plist @@ -0,0 +1,42 @@ + + + + + NSCameraUsageDescription + This app uses the camera to demonstrate live video processing. 
+ CFBundleDevelopmentRegion + en + CFBundleExecutable + $(EXECUTABLE_NAME) + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + $(PRODUCT_NAME) + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleVersion + 1 + LSRequiresIPhoneOS + + UILaunchStoryboardName + LaunchScreen + UIMainStoryboardFile + Main + UIRequiredDeviceCapabilities + + armv7 + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + + + diff --git a/mediapipe/examples/ios/facedetectioncpu/ViewController.h b/mediapipe/examples/ios/facedetectioncpu/ViewController.h new file mode 100644 index 000000000..e0a5a6367 --- /dev/null +++ b/mediapipe/examples/ios/facedetectioncpu/ViewController.h @@ -0,0 +1,19 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface ViewController : UIViewController + +@end diff --git a/mediapipe/examples/ios/facedetectioncpu/ViewController.mm b/mediapipe/examples/ios/facedetectioncpu/ViewController.mm new file mode 100644 index 000000000..99bd03e16 --- /dev/null +++ b/mediapipe/examples/ios/facedetectioncpu/ViewController.mm @@ -0,0 +1,178 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "ViewController.h" + +#import "mediapipe/objc/MPPGraph.h" +#import "mediapipe/objc/MPPCameraInputSource.h" +#import "mediapipe/objc/MPPLayerRenderer.h" + +static NSString* const kGraphName = @"mobile_cpu"; + +static const char* kInputStream = "input_video"; +static const char* kOutputStream = "output_video"; +static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue"; + +@interface ViewController () + +// The MediaPipe graph currently in use. Initialized in viewDidLoad, started in viewWillAppear: and +// sent video frames on _videoQueue. +@property(nonatomic) MPPGraph* mediapipeGraph; + +@end + +@implementation ViewController { + /// Handles camera access via AVCaptureSession library. + MPPCameraInputSource* _cameraSource; + + /// Inform the user when camera is unavailable. + IBOutlet UILabel* _noCameraLabel; + /// Display the camera preview frames. + IBOutlet UIView* _liveView; + /// Render frames in a layer. + MPPLayerRenderer* _renderer; + + /// Process camera frames on this queue. 
+ dispatch_queue_t _videoQueue; +} + +#pragma mark - Cleanup methods + +- (void)dealloc { + self.mediapipeGraph.delegate = nil; + [self.mediapipeGraph cancel]; + // Ignore errors since we're cleaning up. + [self.mediapipeGraph closeAllInputStreamsWithError:nil]; + [self.mediapipeGraph waitUntilDoneWithError:nil]; +} + +#pragma mark - MediaPipe graph methods + ++ (MPPGraph*)loadGraphFromResource:(NSString*)resource { + // Load the graph config resource. + NSError* configLoadError = nil; + NSBundle* bundle = [NSBundle bundleForClass:[self class]]; + if (!resource || resource.length == 0) { + return nil; + } + NSURL* graphURL = [bundle URLForResource:resource withExtension:@"binarypb"]; + NSData* data = [NSData dataWithContentsOfURL:graphURL options:0 error:&configLoadError]; + if (!data) { + NSLog(@"Failed to load MediaPipe graph config: %@", configLoadError); + return nil; + } + + // Parse the graph config resource into mediapipe::CalculatorGraphConfig proto object. + mediapipe::CalculatorGraphConfig config; + config.ParseFromArray(data.bytes, data.length); + + // Create MediaPipe graph with mediapipe::CalculatorGraphConfig proto object. + MPPGraph* newGraph = [[MPPGraph alloc] initWithGraphConfig:config]; + [newGraph addFrameOutputStream:kOutputStream outputPacketType:MediaPipePacketPixelBuffer]; + return newGraph; +} + +#pragma mark - UIViewController methods + +- (void)viewDidLoad { + [super viewDidLoad]; + + _renderer = [[MPPLayerRenderer alloc] init]; + _renderer.layer.frame = _liveView.layer.bounds; + [_liveView.layer addSublayer:_renderer.layer]; + _renderer.frameScaleMode = MediaPipeFrameScaleFillAndCrop; + // When using the front camera, mirror the input for a more natural look. + _renderer.mirrored = YES; + + dispatch_queue_attr_t qosAttribute = dispatch_queue_attr_make_with_qos_class( + DISPATCH_QUEUE_SERIAL, QOS_CLASS_USER_INTERACTIVE, /*relative_priority=*/0); + _videoQueue = dispatch_queue_create(kVideoQueueLabel, qosAttribute); + + _cameraSource = [[MPPCameraInputSource alloc] init]; + [_cameraSource setDelegate:self queue:_videoQueue]; + _cameraSource.sessionPreset = AVCaptureSessionPresetHigh; + _cameraSource.cameraPosition = AVCaptureDevicePositionFront; + // The frame's native format is rotated with respect to the portrait orientation. + _cameraSource.orientation = AVCaptureVideoOrientationPortrait; + + self.mediapipeGraph = [[self class] loadGraphFromResource:kGraphName]; + self.mediapipeGraph.delegate = self; + // Set maxFramesInFlight to a small value to avoid memory contention for real-time processing. + self.mediapipeGraph.maxFramesInFlight = 2; +} + +// In this application, there is only one ViewController which has no navigation to other view +// controllers, and there is only one View with live display showing the result of running the +// MediaPipe graph on the live video feed. If more view controllers are needed later, the graph +// setup/teardown and camera start/stop logic should be updated appropriately in response to the +// appearance/disappearance of this ViewController, as viewWillAppear: can be invoked multiple times +// depending on the application navigation flow in that case. +- (void)viewWillAppear:(BOOL)animated { + [super viewWillAppear:animated]; + + [_cameraSource requestCameraAccessWithCompletionHandler:^void(BOOL granted) { + if (granted) { + [self startGraphAndCamera]; + dispatch_async(dispatch_get_main_queue(), ^{ + [_noCameraLabel setHidden:YES]; + }); + } + }]; +} + +- (void)startGraphAndCamera { + // Start running self.mediapipeGraph. 
+ NSError* error; + if (![self.mediapipeGraph startWithError:&error]) { + NSLog(@"Failed to start graph: %@", error); + } + + // Start fetching frames from the camera. + dispatch_async(_videoQueue, ^{ + [_cameraSource start]; + }); +} + +#pragma mark - MPPGraphDelegate methods + +// Receives CVPixelBufferRef from the MediaPipe graph. Invoked on a MediaPipe worker thread. +- (void)mediapipeGraph:(MPPGraph*)graph + didOutputPixelBuffer:(CVPixelBufferRef)pixelBuffer + fromStream:(const std::string&)streamName { + if (streamName == kOutputStream) { + // Display the captured image on the screen. + CVPixelBufferRetain(pixelBuffer); + dispatch_async(dispatch_get_main_queue(), ^{ + [_renderer renderPixelBuffer:pixelBuffer]; + CVPixelBufferRelease(pixelBuffer); + }); + } +} + +#pragma mark - MPPInputSourceDelegate methods + +// Must be invoked on _videoQueue. +- (void)processVideoFrame:(CVPixelBufferRef)imageBuffer + timestamp:(CMTime)timestamp + fromSource:(MPPInputSource*)source { + if (source != _cameraSource) { + NSLog(@"Unknown source: %@", source); + return; + } + [self.mediapipeGraph sendPixelBuffer:imageBuffer + intoStream:kInputStream + packetType:MediaPipePacketPixelBuffer]; +} + +@end diff --git a/mediapipe/examples/ios/facedetectioncpu/main.m b/mediapipe/examples/ios/facedetectioncpu/main.m new file mode 100644 index 000000000..7ffe5ea5d --- /dev/null +++ b/mediapipe/examples/ios/facedetectioncpu/main.m @@ -0,0 +1,22 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "AppDelegate.h" + +int main(int argc, char * argv[]) { + @autoreleasepool { + return UIApplicationMain(argc, argv, nil, NSStringFromClass([AppDelegate class])); + } +} diff --git a/mediapipe/examples/ios/facedetectiongpu/AppDelegate.h b/mediapipe/examples/ios/facedetectiongpu/AppDelegate.h new file mode 100644 index 000000000..6b0377ef2 --- /dev/null +++ b/mediapipe/examples/ios/facedetectiongpu/AppDelegate.h @@ -0,0 +1,21 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface AppDelegate : UIResponder + +@property(strong, nonatomic) UIWindow *window; + +@end diff --git a/mediapipe/examples/ios/facedetectiongpu/AppDelegate.m b/mediapipe/examples/ios/facedetectiongpu/AppDelegate.m new file mode 100644 index 000000000..9e1b7ff0e --- /dev/null +++ b/mediapipe/examples/ios/facedetectiongpu/AppDelegate.m @@ -0,0 +1,59 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "AppDelegate.h" + +@interface AppDelegate () + +@end + +@implementation AppDelegate + +- (BOOL)application:(UIApplication *)application + didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { + // Override point for customization after application launch. + return YES; +} + +- (void)applicationWillResignActive:(UIApplication *)application { + // Sent when the application is about to move from active to inactive state. This can occur for + // certain types of temporary interruptions (such as an incoming phone call or SMS message) or + // when the user quits the application and it begins the transition to the background state. Use + // this method to pause ongoing tasks, disable timers, and invalidate graphics rendering + // callbacks. Games should use this method to pause the game. +} + +- (void)applicationDidEnterBackground:(UIApplication *)application { + // Use this method to release shared resources, save user data, invalidate timers, and store + // enough application state information to restore your application to its current state in case + // it is terminated later. If your application supports background execution, this method is + // called instead of applicationWillTerminate: when the user quits. +} + +- (void)applicationWillEnterForeground:(UIApplication *)application { + // Called as part of the transition from the background to the active state; here you can undo + // many of the changes made on entering the background. +} + +- (void)applicationDidBecomeActive:(UIApplication *)application { + // Restart any tasks that were paused (or not yet started) while the application was inactive. If + // the application was previously in the background, optionally refresh the user interface. +} + +- (void)applicationWillTerminate:(UIApplication *)application { + // Called when the application is about to terminate. Save data if appropriate. See also + // applicationDidEnterBackground:. 
+} + +@end diff --git a/mediapipe/examples/ios/facedetectiongpu/Assets.xcassets/AppIcon.appiconset/Contents.json b/mediapipe/examples/ios/facedetectiongpu/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 000000000..a1895a242 --- /dev/null +++ b/mediapipe/examples/ios/facedetectiongpu/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1,99 @@ +{ + "images" : [ + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "3x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "83.5x83.5", + "scale" : "2x" + }, + { + "idiom" : "ios-marketing", + "size" : "1024x1024", + "scale" : "1x" + } + ], + "info" : { + "version" : 1, + "author" : "xcode" + } +} + diff --git a/mediapipe/examples/ios/facedetectiongpu/Assets.xcassets/Contents.json b/mediapipe/examples/ios/facedetectiongpu/Assets.xcassets/Contents.json new file mode 100644 index 000000000..7afcdfaf8 --- /dev/null +++ b/mediapipe/examples/ios/facedetectiongpu/Assets.xcassets/Contents.json @@ -0,0 +1,7 @@ +{ + "info" : { + "version" : 1, + "author" : "xcode" + } +} + diff --git a/mediapipe/examples/ios/facedetectiongpu/BUILD b/mediapipe/examples/ios/facedetectiongpu/BUILD new file mode 100644 index 000000000..2e46f86b8 --- /dev/null +++ b/mediapipe/examples/ios/facedetectiongpu/BUILD @@ -0,0 +1,75 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +licenses(["notice"]) # Apache 2.0 + +MIN_IOS_VERSION = "10.0" + +load( + "@build_bazel_rules_apple//apple:ios.bzl", + "ios_application", +) + +ios_application( + name = "FaceDetectionGpuApp", + bundle_id = "com.google.mediapipe.FaceDetectionGpu", + families = [ + "iphone", + "ipad", + ], + infoplists = ["Info.plist"], + minimum_os_version = MIN_IOS_VERSION, + provisioning_profile = "//mediapipe/examples/ios:provisioning_profile", + deps = [ + ":FaceDetectionGpuAppLibrary", + "@ios_opencv//:OpencvFramework", + ], +) + +objc_library( + name = "FaceDetectionGpuAppLibrary", + srcs = [ + "AppDelegate.m", + "ViewController.mm", + "main.m", + ], + hdrs = [ + "AppDelegate.h", + "ViewController.h", + ], + data = [ + "Base.lproj/LaunchScreen.storyboard", + "Base.lproj/Main.storyboard", + "//mediapipe/graphs/face_detection:mobile_gpu_binary_graph", + "//mediapipe/models:face_detection_front.tflite", + "//mediapipe/models:face_detection_front_labelmap.txt", + ], + sdk_frameworks = [ + "AVFoundation", + "CoreGraphics", + "CoreMedia", + "UIKit", + ], + deps = [ + "//mediapipe/objc:mediapipe_framework_ios", + "//mediapipe/objc:mediapipe_input_sources_ios", + "//mediapipe/objc:mediapipe_layer_renderer", + ] + select({ + "//mediapipe:ios_i386": [], + "//mediapipe:ios_x86_64": [], + "//conditions:default": [ + "//mediapipe/graphs/face_detection:mobile_calculators", + ], + }), +) diff --git a/mediapipe/examples/ios/facedetectiongpu/Base.lproj/LaunchScreen.storyboard b/mediapipe/examples/ios/facedetectiongpu/Base.lproj/LaunchScreen.storyboard new file mode 100644 index 000000000..bfa361294 --- /dev/null +++ b/mediapipe/examples/ios/facedetectiongpu/Base.lproj/LaunchScreen.storyboard @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/ios/facedetectiongpu/Base.lproj/Main.storyboard b/mediapipe/examples/ios/facedetectiongpu/Base.lproj/Main.storyboard new file mode 100644 index 000000000..383fc4aa4 --- /dev/null +++ b/mediapipe/examples/ios/facedetectiongpu/Base.lproj/Main.storyboard @@ -0,0 +1,49 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/ios/facedetectiongpu/Info.plist b/mediapipe/examples/ios/facedetectiongpu/Info.plist new file mode 100644 index 000000000..30db14c62 --- /dev/null +++ b/mediapipe/examples/ios/facedetectiongpu/Info.plist @@ -0,0 +1,42 @@ + + + + + NSCameraUsageDescription + This app uses the camera to demonstrate live video processing. + CFBundleDevelopmentRegion + en + CFBundleExecutable + $(EXECUTABLE_NAME) + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + $(PRODUCT_NAME) + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleVersion + 1 + LSRequiresIPhoneOS + + UILaunchStoryboardName + LaunchScreen + UIMainStoryboardFile + Main + UIRequiredDeviceCapabilities + + armv7 + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + + + diff --git a/mediapipe/examples/ios/facedetectiongpu/ViewController.h b/mediapipe/examples/ios/facedetectiongpu/ViewController.h new file mode 100644 index 000000000..e0a5a6367 --- /dev/null +++ b/mediapipe/examples/ios/facedetectiongpu/ViewController.h @@ -0,0 +1,19 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface ViewController : UIViewController + +@end diff --git a/mediapipe/examples/ios/facedetectiongpu/ViewController.mm b/mediapipe/examples/ios/facedetectiongpu/ViewController.mm new file mode 100644 index 000000000..36293110b --- /dev/null +++ b/mediapipe/examples/ios/facedetectiongpu/ViewController.mm @@ -0,0 +1,178 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "ViewController.h" + +#import "mediapipe/objc/MPPGraph.h" +#import "mediapipe/objc/MPPCameraInputSource.h" +#import "mediapipe/objc/MPPLayerRenderer.h" + +static NSString* const kGraphName = @"mobile_gpu"; + +static const char* kInputStream = "input_video"; +static const char* kOutputStream = "output_video"; +static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue"; + +@interface ViewController () + +// The MediaPipe graph currently in use. Initialized in viewDidLoad, started in viewWillAppear: and +// sent video frames on _videoQueue. +@property(nonatomic) MPPGraph* mediapipeGraph; + +@end + +@implementation ViewController { + /// Handles camera access via AVCaptureSession library. + MPPCameraInputSource* _cameraSource; + + /// Inform the user when camera is unavailable. + IBOutlet UILabel* _noCameraLabel; + /// Display the camera preview frames. + IBOutlet UIView* _liveView; + /// Render frames in a layer. + MPPLayerRenderer* _renderer; + + /// Process camera frames on this queue. + dispatch_queue_t _videoQueue; +} + +#pragma mark - Cleanup methods + +- (void)dealloc { + self.mediapipeGraph.delegate = nil; + [self.mediapipeGraph cancel]; + // Ignore errors since we're cleaning up. + [self.mediapipeGraph closeAllInputStreamsWithError:nil]; + [self.mediapipeGraph waitUntilDoneWithError:nil]; +} + +#pragma mark - MediaPipe graph methods + ++ (MPPGraph*)loadGraphFromResource:(NSString*)resource { + // Load the graph config resource. + NSError* configLoadError = nil; + NSBundle* bundle = [NSBundle bundleForClass:[self class]]; + if (!resource || resource.length == 0) { + return nil; + } + NSURL* graphURL = [bundle URLForResource:resource withExtension:@"binarypb"]; + NSData* data = [NSData dataWithContentsOfURL:graphURL options:0 error:&configLoadError]; + if (!data) { + NSLog(@"Failed to load MediaPipe graph config: %@", configLoadError); + return nil; + } + + // Parse the graph config resource into mediapipe::CalculatorGraphConfig proto object. 
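+  // Illustrative note: ParseFromArray returns a BOOL; a more defensive version
+  // of this helper could check that return value and bail out on a malformed
+  // config. The example code below ignores it.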
+ mediapipe::CalculatorGraphConfig config; + config.ParseFromArray(data.bytes, data.length); + + // Create MediaPipe graph with mediapipe::CalculatorGraphConfig proto object. + MPPGraph* newGraph = [[MPPGraph alloc] initWithGraphConfig:config]; + [newGraph addFrameOutputStream:kOutputStream outputPacketType:MediaPipePacketPixelBuffer]; + return newGraph; +} + +#pragma mark - UIViewController methods + +- (void)viewDidLoad { + [super viewDidLoad]; + + _renderer = [[MPPLayerRenderer alloc] init]; + _renderer.layer.frame = _liveView.layer.bounds; + [_liveView.layer addSublayer:_renderer.layer]; + _renderer.frameScaleMode = MediaPipeFrameScaleFillAndCrop; + // When using the front camera, mirror the input for a more natural look. + _renderer.mirrored = YES; + + dispatch_queue_attr_t qosAttribute = dispatch_queue_attr_make_with_qos_class( + DISPATCH_QUEUE_SERIAL, QOS_CLASS_USER_INTERACTIVE, /*relative_priority=*/0); + _videoQueue = dispatch_queue_create(kVideoQueueLabel, qosAttribute); + + _cameraSource = [[MPPCameraInputSource alloc] init]; + [_cameraSource setDelegate:self queue:_videoQueue]; + _cameraSource.sessionPreset = AVCaptureSessionPresetHigh; + _cameraSource.cameraPosition = AVCaptureDevicePositionFront; + // The frame's native format is rotated with respect to the portrait orientation. + _cameraSource.orientation = AVCaptureVideoOrientationPortrait; + + self.mediapipeGraph = [[self class] loadGraphFromResource:kGraphName]; + self.mediapipeGraph.delegate = self; + // Set maxFramesInFlight to a small value to avoid memory contention for real-time processing. + self.mediapipeGraph.maxFramesInFlight = 2; +} + +// In this application, there is only one ViewController which has no navigation to other view +// controllers, and there is only one View with live display showing the result of running the +// MediaPipe graph on the live video feed. If more view controllers are needed later, the graph +// setup/teardown and camera start/stop logic should be updated appropriately in response to the +// appearance/disappearance of this ViewController, as viewWillAppear: can be invoked multiple times +// depending on the application navigation flow in that case. +- (void)viewWillAppear:(BOOL)animated { + [super viewWillAppear:animated]; + + [_cameraSource requestCameraAccessWithCompletionHandler:^void(BOOL granted) { + if (granted) { + [self startGraphAndCamera]; + dispatch_async(dispatch_get_main_queue(), ^{ + [_noCameraLabel setHidden:YES]; + }); + } + }]; +} + +- (void)startGraphAndCamera { + // Start running self.mediapipeGraph. + NSError* error; + if (![self.mediapipeGraph startWithError:&error]) { + NSLog(@"Failed to start graph: %@", error); + } + + // Start fetching frames from the camera. + dispatch_async(_videoQueue, ^{ + [_cameraSource start]; + }); +} + +#pragma mark - MPPGraphDelegate methods + +// Receives CVPixelBufferRef from the MediaPipe graph. Invoked on a MediaPipe worker thread. +- (void)mediapipeGraph:(MPPGraph*)graph + didOutputPixelBuffer:(CVPixelBufferRef)pixelBuffer + fromStream:(const std::string&)streamName { + if (streamName == kOutputStream) { + // Display the captured image on the screen. + CVPixelBufferRetain(pixelBuffer); + dispatch_async(dispatch_get_main_queue(), ^{ + [_renderer renderPixelBuffer:pixelBuffer]; + CVPixelBufferRelease(pixelBuffer); + }); + } +} + +#pragma mark - MPPInputSourceDelegate methods + +// Must be invoked on _videoQueue. 
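+// This MPPCameraInputSource delegate callback feeds each camera frame into the
+// graph's input stream; processed frames come back through the MPPGraphDelegate
+// method above and are rendered on the main queue.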
+- (void)processVideoFrame:(CVPixelBufferRef)imageBuffer + timestamp:(CMTime)timestamp + fromSource:(MPPInputSource*)source { + if (source != _cameraSource) { + NSLog(@"Unknown source: %@", source); + return; + } + [self.mediapipeGraph sendPixelBuffer:imageBuffer + intoStream:kInputStream + packetType:MediaPipePacketPixelBuffer]; +} + +@end diff --git a/mediapipe/examples/ios/facedetectiongpu/main.m b/mediapipe/examples/ios/facedetectiongpu/main.m new file mode 100644 index 000000000..7ffe5ea5d --- /dev/null +++ b/mediapipe/examples/ios/facedetectiongpu/main.m @@ -0,0 +1,22 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "AppDelegate.h" + +int main(int argc, char * argv[]) { + @autoreleasepool { + return UIApplicationMain(argc, argv, nil, NSStringFromClass([AppDelegate class])); + } +} diff --git a/mediapipe/examples/ios/objectdetectioncpu/AppDelegate.h b/mediapipe/examples/ios/objectdetectioncpu/AppDelegate.h new file mode 100644 index 000000000..6b0377ef2 --- /dev/null +++ b/mediapipe/examples/ios/objectdetectioncpu/AppDelegate.h @@ -0,0 +1,21 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface AppDelegate : UIResponder + +@property(strong, nonatomic) UIWindow *window; + +@end diff --git a/mediapipe/examples/ios/objectdetectioncpu/AppDelegate.m b/mediapipe/examples/ios/objectdetectioncpu/AppDelegate.m new file mode 100644 index 000000000..9e1b7ff0e --- /dev/null +++ b/mediapipe/examples/ios/objectdetectioncpu/AppDelegate.m @@ -0,0 +1,59 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "AppDelegate.h" + +@interface AppDelegate () + +@end + +@implementation AppDelegate + +- (BOOL)application:(UIApplication *)application + didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { + // Override point for customization after application launch. 
+ return YES; +} + +- (void)applicationWillResignActive:(UIApplication *)application { + // Sent when the application is about to move from active to inactive state. This can occur for + // certain types of temporary interruptions (such as an incoming phone call or SMS message) or + // when the user quits the application and it begins the transition to the background state. Use + // this method to pause ongoing tasks, disable timers, and invalidate graphics rendering + // callbacks. Games should use this method to pause the game. +} + +- (void)applicationDidEnterBackground:(UIApplication *)application { + // Use this method to release shared resources, save user data, invalidate timers, and store + // enough application state information to restore your application to its current state in case + // it is terminated later. If your application supports background execution, this method is + // called instead of applicationWillTerminate: when the user quits. +} + +- (void)applicationWillEnterForeground:(UIApplication *)application { + // Called as part of the transition from the background to the active state; here you can undo + // many of the changes made on entering the background. +} + +- (void)applicationDidBecomeActive:(UIApplication *)application { + // Restart any tasks that were paused (or not yet started) while the application was inactive. If + // the application was previously in the background, optionally refresh the user interface. +} + +- (void)applicationWillTerminate:(UIApplication *)application { + // Called when the application is about to terminate. Save data if appropriate. See also + // applicationDidEnterBackground:. +} + +@end diff --git a/mediapipe/examples/ios/objectdetectioncpu/Assets.xcassets/AppIcon.appiconset/Contents.json b/mediapipe/examples/ios/objectdetectioncpu/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 000000000..a1895a242 --- /dev/null +++ b/mediapipe/examples/ios/objectdetectioncpu/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1,99 @@ +{ + "images" : [ + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "3x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "83.5x83.5", + "scale" : "2x" + }, + { + "idiom" : "ios-marketing", + "size" : "1024x1024", + "scale" : "1x" + } + ], + "info" : { + "version" : 1, + "author" : "xcode" + } +} + diff --git a/mediapipe/examples/ios/objectdetectioncpu/Assets.xcassets/Contents.json b/mediapipe/examples/ios/objectdetectioncpu/Assets.xcassets/Contents.json new file mode 100644 index 
000000000..7afcdfaf8 --- /dev/null +++ b/mediapipe/examples/ios/objectdetectioncpu/Assets.xcassets/Contents.json @@ -0,0 +1,7 @@ +{ + "info" : { + "version" : 1, + "author" : "xcode" + } +} + diff --git a/mediapipe/examples/ios/objectdetectioncpu/BUILD b/mediapipe/examples/ios/objectdetectioncpu/BUILD new file mode 100644 index 000000000..37d316c99 --- /dev/null +++ b/mediapipe/examples/ios/objectdetectioncpu/BUILD @@ -0,0 +1,75 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) # Apache 2.0 + +MIN_IOS_VERSION = "10.0" + +load( + "@build_bazel_rules_apple//apple:ios.bzl", + "ios_application", +) + +ios_application( + name = "ObjectDetectionCpuApp", + bundle_id = "com.google.mediapipe.ObjectDetectionCpu", + families = [ + "iphone", + "ipad", + ], + infoplists = ["Info.plist"], + minimum_os_version = MIN_IOS_VERSION, + provisioning_profile = "//mediapipe/examples/ios:provisioning_profile", + deps = [ + ":ObjectDetectionCpuAppLibrary", + "@ios_opencv//:OpencvFramework", + ], +) + +objc_library( + name = "ObjectDetectionCpuAppLibrary", + srcs = [ + "AppDelegate.m", + "ViewController.mm", + "main.m", + ], + hdrs = [ + "AppDelegate.h", + "ViewController.h", + ], + data = [ + "Base.lproj/LaunchScreen.storyboard", + "Base.lproj/Main.storyboard", + "//mediapipe/graphs/object_detection:mobile_cpu_binary_graph", + "//mediapipe/models:ssdlite_object_detection.tflite", + "//mediapipe/models:ssdlite_object_detection_labelmap.txt", + ], + sdk_frameworks = [ + "AVFoundation", + "CoreGraphics", + "CoreMedia", + "UIKit", + ], + deps = [ + "//mediapipe/objc:mediapipe_framework_ios", + "//mediapipe/objc:mediapipe_input_sources_ios", + "//mediapipe/objc:mediapipe_layer_renderer", + ] + select({ + "//mediapipe:ios_i386": [], + "//mediapipe:ios_x86_64": [], + "//conditions:default": [ + "//mediapipe/graphs/object_detection:mobile_calculators", + ], + }), +) diff --git a/mediapipe/examples/ios/objectdetectioncpu/Base.lproj/LaunchScreen.storyboard b/mediapipe/examples/ios/objectdetectioncpu/Base.lproj/LaunchScreen.storyboard new file mode 100644 index 000000000..bfa361294 --- /dev/null +++ b/mediapipe/examples/ios/objectdetectioncpu/Base.lproj/LaunchScreen.storyboard @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/ios/objectdetectioncpu/Base.lproj/Main.storyboard b/mediapipe/examples/ios/objectdetectioncpu/Base.lproj/Main.storyboard new file mode 100644 index 000000000..9e74b0b72 --- /dev/null +++ b/mediapipe/examples/ios/objectdetectioncpu/Base.lproj/Main.storyboard @@ -0,0 +1,49 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/ios/objectdetectioncpu/Info.plist b/mediapipe/examples/ios/objectdetectioncpu/Info.plist new file mode 100644 index 000000000..30db14c62 --- /dev/null +++ b/mediapipe/examples/ios/objectdetectioncpu/Info.plist @@ -0,0 +1,42 @@ + + + + + NSCameraUsageDescription + This app uses the camera to demonstrate live 
video processing. + CFBundleDevelopmentRegion + en + CFBundleExecutable + $(EXECUTABLE_NAME) + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + $(PRODUCT_NAME) + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleVersion + 1 + LSRequiresIPhoneOS + + UILaunchStoryboardName + LaunchScreen + UIMainStoryboardFile + Main + UIRequiredDeviceCapabilities + + armv7 + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + + + diff --git a/mediapipe/examples/ios/objectdetectioncpu/ViewController.h b/mediapipe/examples/ios/objectdetectioncpu/ViewController.h new file mode 100644 index 000000000..e0a5a6367 --- /dev/null +++ b/mediapipe/examples/ios/objectdetectioncpu/ViewController.h @@ -0,0 +1,19 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface ViewController : UIViewController + +@end diff --git a/mediapipe/examples/ios/objectdetectioncpu/ViewController.mm b/mediapipe/examples/ios/objectdetectioncpu/ViewController.mm new file mode 100644 index 000000000..47edc7edc --- /dev/null +++ b/mediapipe/examples/ios/objectdetectioncpu/ViewController.mm @@ -0,0 +1,176 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "ViewController.h" + +#import "mediapipe/objc/MPPGraph.h" +#import "mediapipe/objc/MPPCameraInputSource.h" +#import "mediapipe/objc/MPPLayerRenderer.h" + +static NSString* const kGraphName = @"mobile_cpu"; + +static const char* kInputStream = "input_video"; +static const char* kOutputStream = "output_video"; +static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue"; + +@interface ViewController () + +// The MediaPipe graph currently in use. Initialized in viewDidLoad, started in viewWillAppear: and +// sent video frames on _videoQueue. +@property(nonatomic) MPPGraph* mediapipeGraph; + +@end + +@implementation ViewController { + /// Handles camera access via AVCaptureSession library. + MPPCameraInputSource* _cameraSource; + + /// Inform the user when camera is unavailable. + IBOutlet UILabel* _noCameraLabel; + /// Display the camera preview frames. + IBOutlet UIView* _liveView; + /// Render frames in a layer. + MPPLayerRenderer* _renderer; + + /// Process camera frames on this queue. 
+ dispatch_queue_t _videoQueue; +} + +#pragma mark - Cleanup methods + +- (void)dealloc { + self.mediapipeGraph.delegate = nil; + [self.mediapipeGraph cancel]; + // Ignore errors since we're cleaning up. + [self.mediapipeGraph closeAllInputStreamsWithError:nil]; + [self.mediapipeGraph waitUntilDoneWithError:nil]; +} + +#pragma mark - MediaPipe graph methods + ++ (MPPGraph*)loadGraphFromResource:(NSString*)resource { + // Load the graph config resource. + NSError* configLoadError = nil; + NSBundle* bundle = [NSBundle bundleForClass:[self class]]; + if (!resource || resource.length == 0) { + return nil; + } + NSURL* graphURL = [bundle URLForResource:resource withExtension:@"binarypb"]; + NSData* data = [NSData dataWithContentsOfURL:graphURL options:0 error:&configLoadError]; + if (!data) { + NSLog(@"Failed to load MediaPipe graph config: %@", configLoadError); + return nil; + } + + // Parse the graph config resource into mediapipe::CalculatorGraphConfig proto object. + mediapipe::CalculatorGraphConfig config; + config.ParseFromArray(data.bytes, data.length); + + // Create MediaPipe graph with mediapipe::CalculatorGraphConfig proto object. + MPPGraph* newGraph = [[MPPGraph alloc] initWithGraphConfig:config]; + [newGraph addFrameOutputStream:kOutputStream outputPacketType:MediaPipePacketPixelBuffer]; + return newGraph; +} + +#pragma mark - UIViewController methods + +- (void)viewDidLoad { + [super viewDidLoad]; + + _renderer = [[MPPLayerRenderer alloc] init]; + _renderer.layer.frame = _liveView.layer.bounds; + [_liveView.layer addSublayer:_renderer.layer]; + _renderer.frameScaleMode = MediaPipeFrameScaleFillAndCrop; + + dispatch_queue_attr_t qosAttribute = dispatch_queue_attr_make_with_qos_class( + DISPATCH_QUEUE_SERIAL, QOS_CLASS_USER_INTERACTIVE, /*relative_priority=*/0); + _videoQueue = dispatch_queue_create(kVideoQueueLabel, qosAttribute); + + _cameraSource = [[MPPCameraInputSource alloc] init]; + [_cameraSource setDelegate:self queue:_videoQueue]; + _cameraSource.sessionPreset = AVCaptureSessionPresetHigh; + _cameraSource.cameraPosition = AVCaptureDevicePositionBack; + // The frame's native format is rotated with respect to the portrait orientation. + _cameraSource.orientation = AVCaptureVideoOrientationPortrait; + + self.mediapipeGraph = [[self class] loadGraphFromResource:kGraphName]; + self.mediapipeGraph.delegate = self; + // Set maxFramesInFlight to a small value to avoid memory contention for real-time processing. + self.mediapipeGraph.maxFramesInFlight = 2; +} + +// In this application, there is only one ViewController which has no navigation to other view +// controllers, and there is only one View with live display showing the result of running the +// MediaPipe graph on the live video feed. If more view controllers are needed later, the graph +// setup/teardown and camera start/stop logic should be updated appropriately in response to the +// appearance/disappearance of this ViewController, as viewWillAppear: can be invoked multiple times +// depending on the application navigation flow in that case. +- (void)viewWillAppear:(BOOL)animated { + [super viewWillAppear:animated]; + + [_cameraSource requestCameraAccessWithCompletionHandler:^void(BOOL granted) { + if (granted) { + [self startGraphAndCamera]; + dispatch_async(dispatch_get_main_queue(), ^{ + _noCameraLabel.hidden = YES; + }); + } + }]; +} + +- (void)startGraphAndCamera { + // Start running self.mediapipeGraph. 
+ NSError* error; + if (![self.mediapipeGraph startWithError:&error]) { + NSLog(@"Failed to start graph: %@", error); + } + + // Start fetching frames from the camera. + dispatch_async(_videoQueue, ^{ + [_cameraSource start]; + }); +} + +#pragma mark - MPPGraphDelegate methods + +// Receives CVPixelBufferRef from the MediaPipe graph. Invoked on a MediaPipe worker thread. +- (void)mediapipeGraph:(MPPGraph*)graph + didOutputPixelBuffer:(CVPixelBufferRef)pixelBuffer + fromStream:(const std::string&)streamName { + if (streamName == kOutputStream) { + // Display the captured image on the screen. + CVPixelBufferRetain(pixelBuffer); + dispatch_async(dispatch_get_main_queue(), ^{ + [_renderer renderPixelBuffer:pixelBuffer]; + CVPixelBufferRelease(pixelBuffer); + }); + } +} + +#pragma mark - MPPInputSourceDelegate methods + +// Must be invoked on _videoQueue. +- (void)processVideoFrame:(CVPixelBufferRef)imageBuffer + timestamp:(CMTime)timestamp + fromSource:(MPPInputSource*)source { + if (source != _cameraSource) { + NSLog(@"Unknown source: %@", source); + return; + } + [self.mediapipeGraph sendPixelBuffer:imageBuffer + intoStream:kInputStream + packetType:MediaPipePacketPixelBuffer]; +} + +@end diff --git a/mediapipe/examples/ios/objectdetectioncpu/main.m b/mediapipe/examples/ios/objectdetectioncpu/main.m new file mode 100644 index 000000000..7ffe5ea5d --- /dev/null +++ b/mediapipe/examples/ios/objectdetectioncpu/main.m @@ -0,0 +1,22 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "AppDelegate.h" + +int main(int argc, char * argv[]) { + @autoreleasepool { + return UIApplicationMain(argc, argv, nil, NSStringFromClass([AppDelegate class])); + } +} diff --git a/mediapipe/examples/ios/objectdetectiongpu/AppDelegate.h b/mediapipe/examples/ios/objectdetectiongpu/AppDelegate.h new file mode 100644 index 000000000..6b0377ef2 --- /dev/null +++ b/mediapipe/examples/ios/objectdetectiongpu/AppDelegate.h @@ -0,0 +1,21 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface AppDelegate : UIResponder + +@property(strong, nonatomic) UIWindow *window; + +@end diff --git a/mediapipe/examples/ios/objectdetectiongpu/AppDelegate.m b/mediapipe/examples/ios/objectdetectiongpu/AppDelegate.m new file mode 100644 index 000000000..9e1b7ff0e --- /dev/null +++ b/mediapipe/examples/ios/objectdetectiongpu/AppDelegate.m @@ -0,0 +1,59 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "AppDelegate.h" + +@interface AppDelegate () + +@end + +@implementation AppDelegate + +- (BOOL)application:(UIApplication *)application + didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { + // Override point for customization after application launch. + return YES; +} + +- (void)applicationWillResignActive:(UIApplication *)application { + // Sent when the application is about to move from active to inactive state. This can occur for + // certain types of temporary interruptions (such as an incoming phone call or SMS message) or + // when the user quits the application and it begins the transition to the background state. Use + // this method to pause ongoing tasks, disable timers, and invalidate graphics rendering + // callbacks. Games should use this method to pause the game. +} + +- (void)applicationDidEnterBackground:(UIApplication *)application { + // Use this method to release shared resources, save user data, invalidate timers, and store + // enough application state information to restore your application to its current state in case + // it is terminated later. If your application supports background execution, this method is + // called instead of applicationWillTerminate: when the user quits. +} + +- (void)applicationWillEnterForeground:(UIApplication *)application { + // Called as part of the transition from the background to the active state; here you can undo + // many of the changes made on entering the background. +} + +- (void)applicationDidBecomeActive:(UIApplication *)application { + // Restart any tasks that were paused (or not yet started) while the application was inactive. If + // the application was previously in the background, optionally refresh the user interface. +} + +- (void)applicationWillTerminate:(UIApplication *)application { + // Called when the application is about to terminate. Save data if appropriate. See also + // applicationDidEnterBackground:. 
+} + +@end diff --git a/mediapipe/examples/ios/objectdetectiongpu/Assets.xcassets/AppIcon.appiconset/Contents.json b/mediapipe/examples/ios/objectdetectiongpu/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 000000000..a1895a242 --- /dev/null +++ b/mediapipe/examples/ios/objectdetectiongpu/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1,99 @@ +{ + "images" : [ + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "3x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "83.5x83.5", + "scale" : "2x" + }, + { + "idiom" : "ios-marketing", + "size" : "1024x1024", + "scale" : "1x" + } + ], + "info" : { + "version" : 1, + "author" : "xcode" + } +} + diff --git a/mediapipe/examples/ios/objectdetectiongpu/Assets.xcassets/Contents.json b/mediapipe/examples/ios/objectdetectiongpu/Assets.xcassets/Contents.json new file mode 100644 index 000000000..7afcdfaf8 --- /dev/null +++ b/mediapipe/examples/ios/objectdetectiongpu/Assets.xcassets/Contents.json @@ -0,0 +1,7 @@ +{ + "info" : { + "version" : 1, + "author" : "xcode" + } +} + diff --git a/mediapipe/examples/ios/objectdetectiongpu/BUILD b/mediapipe/examples/ios/objectdetectiongpu/BUILD new file mode 100644 index 000000000..307bc4a12 --- /dev/null +++ b/mediapipe/examples/ios/objectdetectiongpu/BUILD @@ -0,0 +1,75 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +licenses(["notice"]) # Apache 2.0 + +MIN_IOS_VERSION = "10.0" + +load( + "@build_bazel_rules_apple//apple:ios.bzl", + "ios_application", +) + +ios_application( + name = "ObjectDetectionGpuApp", + bundle_id = "com.google.mediapipe.ObjectDetectionGpu", + families = [ + "iphone", + "ipad", + ], + infoplists = ["Info.plist"], + minimum_os_version = MIN_IOS_VERSION, + provisioning_profile = "//mediapipe/examples/ios:provisioning_profile", + deps = [ + ":ObjectDetectionGpuAppLibrary", + "@ios_opencv//:OpencvFramework", + ], +) + +objc_library( + name = "ObjectDetectionGpuAppLibrary", + srcs = [ + "AppDelegate.m", + "ViewController.mm", + "main.m", + ], + hdrs = [ + "AppDelegate.h", + "ViewController.h", + ], + data = [ + "Base.lproj/LaunchScreen.storyboard", + "Base.lproj/Main.storyboard", + "//mediapipe/graphs/object_detection:mobile_gpu_binary_graph", + "//mediapipe/models:ssdlite_object_detection.tflite", + "//mediapipe/models:ssdlite_object_detection_labelmap.txt", + ], + sdk_frameworks = [ + "AVFoundation", + "CoreGraphics", + "CoreMedia", + "UIKit", + ], + deps = [ + "//mediapipe/objc:mediapipe_framework_ios", + "//mediapipe/objc:mediapipe_input_sources_ios", + "//mediapipe/objc:mediapipe_layer_renderer", + ] + select({ + "//mediapipe:ios_i386": [], + "//mediapipe:ios_x86_64": [], + "//conditions:default": [ + "//mediapipe/graphs/object_detection:mobile_calculators", + ], + }), +) diff --git a/mediapipe/examples/ios/objectdetectiongpu/Base.lproj/LaunchScreen.storyboard b/mediapipe/examples/ios/objectdetectiongpu/Base.lproj/LaunchScreen.storyboard new file mode 100644 index 000000000..bfa361294 --- /dev/null +++ b/mediapipe/examples/ios/objectdetectiongpu/Base.lproj/LaunchScreen.storyboard @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/ios/objectdetectiongpu/Base.lproj/Main.storyboard b/mediapipe/examples/ios/objectdetectiongpu/Base.lproj/Main.storyboard new file mode 100644 index 000000000..76dcb7823 --- /dev/null +++ b/mediapipe/examples/ios/objectdetectiongpu/Base.lproj/Main.storyboard @@ -0,0 +1,41 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/ios/objectdetectiongpu/Info.plist b/mediapipe/examples/ios/objectdetectiongpu/Info.plist new file mode 100644 index 000000000..30db14c62 --- /dev/null +++ b/mediapipe/examples/ios/objectdetectiongpu/Info.plist @@ -0,0 +1,42 @@ + + + + + NSCameraUsageDescription + This app uses the camera to demonstrate live video processing. + CFBundleDevelopmentRegion + en + CFBundleExecutable + $(EXECUTABLE_NAME) + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + $(PRODUCT_NAME) + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleVersion + 1 + LSRequiresIPhoneOS + + UILaunchStoryboardName + LaunchScreen + UIMainStoryboardFile + Main + UIRequiredDeviceCapabilities + + armv7 + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + + + diff --git a/mediapipe/examples/ios/objectdetectiongpu/ViewController.h b/mediapipe/examples/ios/objectdetectiongpu/ViewController.h new file mode 100644 index 000000000..e0a5a6367 --- /dev/null +++ b/mediapipe/examples/ios/objectdetectiongpu/ViewController.h @@ -0,0 +1,19 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface ViewController : UIViewController + +@end diff --git a/mediapipe/examples/ios/objectdetectiongpu/ViewController.mm b/mediapipe/examples/ios/objectdetectiongpu/ViewController.mm new file mode 100644 index 000000000..8b68d2cc4 --- /dev/null +++ b/mediapipe/examples/ios/objectdetectiongpu/ViewController.mm @@ -0,0 +1,176 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "ViewController.h" + +#import "mediapipe/objc/MPPGraph.h" +#import "mediapipe/objc/MPPCameraInputSource.h" +#import "mediapipe/objc/MPPLayerRenderer.h" + +static NSString* const kGraphName = @"mobile_gpu"; + +static const char* kInputStream = "input_video"; +static const char* kOutputStream = "output_video"; +static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue"; + +@interface ViewController () + +// The MediaPipe graph currently in use. Initialized in viewDidLoad, started in viewWillAppear: and +// sent video frames on _videoQueue. +@property(nonatomic) MPPGraph* mediapipeGraph; + +@end + +@implementation ViewController { + /// Handles camera access via AVCaptureSession library. + MPPCameraInputSource* _cameraSource; + + /// Inform the user when camera is unavailable. + IBOutlet UILabel* _noCameraLabel; + /// Display the camera preview frames. + IBOutlet UIView* _liveView; + /// Render frames in a layer. + MPPLayerRenderer* _renderer; + + /// Process camera frames on this queue. + dispatch_queue_t _videoQueue; +} + +#pragma mark - Cleanup methods + +- (void)dealloc { + self.mediapipeGraph.delegate = nil; + [self.mediapipeGraph cancel]; + // Ignore errors since we're cleaning up. + [self.mediapipeGraph closeAllInputStreamsWithError:nil]; + [self.mediapipeGraph waitUntilDoneWithError:nil]; +} + +#pragma mark - MediaPipe graph methods + ++ (MPPGraph*)loadGraphFromResource:(NSString*)resource { + // Load the graph config resource. + NSError* configLoadError = nil; + NSBundle* bundle = [NSBundle bundleForClass:[self class]]; + if (!resource || resource.length == 0) { + return nil; + } + NSURL* graphURL = [bundle URLForResource:resource withExtension:@"binarypb"]; + NSData* data = [NSData dataWithContentsOfURL:graphURL options:0 error:&configLoadError]; + if (!data) { + NSLog(@"Failed to load MediaPipe graph config: %@", configLoadError); + return nil; + } + + // Parse the graph config resource into mediapipe::CalculatorGraphConfig proto object. 
+ mediapipe::CalculatorGraphConfig config; + config.ParseFromArray(data.bytes, data.length); + + // Create MediaPipe graph with mediapipe::CalculatorGraphConfig proto object. + MPPGraph* newGraph = [[MPPGraph alloc] initWithGraphConfig:config]; + [newGraph addFrameOutputStream:kOutputStream outputPacketType:MediaPipePacketPixelBuffer]; + return newGraph; +} + +#pragma mark - UIViewController methods + +- (void)viewDidLoad { + [super viewDidLoad]; + + _renderer = [[MPPLayerRenderer alloc] init]; + _renderer.layer.frame = _liveView.layer.bounds; + [_liveView.layer addSublayer:_renderer.layer]; + _renderer.frameScaleMode = MediaPipeFrameScaleFillAndCrop; + + dispatch_queue_attr_t qosAttribute = dispatch_queue_attr_make_with_qos_class( + DISPATCH_QUEUE_SERIAL, QOS_CLASS_USER_INTERACTIVE, /*relative_priority=*/0); + _videoQueue = dispatch_queue_create(kVideoQueueLabel, qosAttribute); + + _cameraSource = [[MPPCameraInputSource alloc] init]; + [_cameraSource setDelegate:self queue:_videoQueue]; + _cameraSource.sessionPreset = AVCaptureSessionPresetHigh; + _cameraSource.cameraPosition = AVCaptureDevicePositionBack; + // The frame's native format is rotated with respect to the portrait orientation. + _cameraSource.orientation = AVCaptureVideoOrientationPortrait; + + self.mediapipeGraph = [[self class] loadGraphFromResource:kGraphName]; + self.mediapipeGraph.delegate = self; + // Set maxFramesInFlight to a small value to avoid memory contention for real-time processing. + self.mediapipeGraph.maxFramesInFlight = 2; +} + +// In this application, there is only one ViewController which has no navigation to other view +// controllers, and there is only one View with live display showing the result of running the +// MediaPipe graph on the live video feed. If more view controllers are needed later, the graph +// setup/teardown and camera start/stop logic should be updated appropriately in response to the +// appearance/disappearance of this ViewController, as viewWillAppear: can be invoked multiple times +// depending on the application navigation flow in that case. +- (void)viewWillAppear:(BOOL)animated { + [super viewWillAppear:animated]; + + [_cameraSource requestCameraAccessWithCompletionHandler:^void(BOOL granted) { + if (granted) { + [self startGraphAndCamera]; + dispatch_async(dispatch_get_main_queue(), ^{ + _noCameraLabel.hidden = YES; + }); + } + }]; +} + +- (void)startGraphAndCamera { + // Start running self.mediapipeGraph. + NSError* error; + if (![self.mediapipeGraph startWithError:&error]) { + NSLog(@"Failed to start graph: %@", error); + } + + // Start fetching frames from the camera. + dispatch_async(_videoQueue, ^{ + [_cameraSource start]; + }); +} + +#pragma mark - MPPGraphDelegate methods + +// Receives CVPixelBufferRef from the MediaPipe graph. Invoked on a MediaPipe worker thread. +- (void)mediapipeGraph:(MPPGraph*)graph + didOutputPixelBuffer:(CVPixelBufferRef)pixelBuffer + fromStream:(const std::string&)streamName { + if (streamName == kOutputStream) { + // Display the captured image on the screen. + CVPixelBufferRetain(pixelBuffer); + dispatch_async(dispatch_get_main_queue(), ^{ + [_renderer renderPixelBuffer:pixelBuffer]; + CVPixelBufferRelease(pixelBuffer); + }); + } +} + +#pragma mark - MPPInputSourceDelegate methods + +// Must be invoked on _videoQueue. 
+- (void)processVideoFrame:(CVPixelBufferRef)imageBuffer + timestamp:(CMTime)timestamp + fromSource:(MPPInputSource*)source { + if (source != _cameraSource) { + NSLog(@"Unknown source: %@", source); + return; + } + [self.mediapipeGraph sendPixelBuffer:imageBuffer + intoStream:kInputStream + packetType:MediaPipePacketPixelBuffer]; +} + +@end diff --git a/mediapipe/examples/ios/objectdetectiongpu/main.m b/mediapipe/examples/ios/objectdetectiongpu/main.m new file mode 100644 index 000000000..7ffe5ea5d --- /dev/null +++ b/mediapipe/examples/ios/objectdetectiongpu/main.m @@ -0,0 +1,22 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "AppDelegate.h" + +int main(int argc, char * argv[]) { + @autoreleasepool { + return UIApplicationMain(argc, argv, nil, NSStringFromClass([AppDelegate class])); + } +} diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 184687e61..2208925d6 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -45,7 +45,7 @@ proto_library( "//mediapipe/framework:packet_generator_proto", "//mediapipe/framework:status_handler_proto", "//mediapipe/framework:stream_handler_proto", - "@protobuf_archive//:any_proto", + "@com_google_protobuf//:any_proto", ], ) @@ -137,7 +137,7 @@ mediapipe_cc_proto_library( ":packet_generator_cc_proto", ":status_handler_cc_proto", ":stream_handler_cc_proto", - "@protobuf_archive//:cc_wkt_protos", + "@com_google_protobuf//:cc_wkt_protos", ], visibility = ["//mediapipe:__subpackages__"], deps = [":calculator_proto"], @@ -1083,6 +1083,13 @@ cc_library( "scheduler_queue.h", "scheduler_shared.h", ], + copts = select({ + "//conditions:default": [], + "//mediapipe:apple": [ + "-std=c++11", + "-ObjC++", + ], + }), visibility = [":mediapipe_internal"], deps = [ ":calculator_context", @@ -1482,6 +1489,26 @@ cc_test( ], ) +cc_test( + name = "calculator_graph_bounds_test", + size = "small", + srcs = [ + "calculator_graph_bounds_test.cc", + ], + visibility = ["//visibility:public"], + deps = [ + ":calculator_context", + ":calculator_framework", + ":timestamp", + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + ], +) + cc_test( name = "collection_test", size = "small", @@ -1636,3 +1663,18 @@ cc_test( "//mediapipe/framework/tool:template_parser", ], ) + +cc_test( + name = "subgraph_test", + srcs = ["subgraph_test.cc"], + deps = [ + ":calculator_framework", + ":subgraph", + ":test_calculators", + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:sink", + "//mediapipe/framework/tool/testdata:dub_quad_test_subgraph", + ], +) diff --git a/mediapipe/framework/calculator.proto 
b/mediapipe/framework/calculator.proto index 01fe859f9..5afa7433c 100644 --- a/mediapipe/framework/calculator.proto +++ b/mediapipe/framework/calculator.proto @@ -170,6 +170,7 @@ message ProfilerConfig { bool use_packet_timestamp_for_added_packet = 6; // The maximum number of trace events buffered in memory. + // The default value buffers up to 20000 events. int64 trace_log_capacity = 7; // Trace event types that are not logged. @@ -185,8 +186,7 @@ message ProfilerConfig { int32 trace_log_count = 10; // The interval in microseconds between trace log output. - // The value -1 specifies output only when the graph is closed. - // The default value specifies trace log output once every 1 sec. + // The default value specifies trace log output once every 0.5 sec. int64 trace_log_interval_usec = 11; // The interval in microseconds between TimeNow and the highest times @@ -194,8 +194,8 @@ message ProfilerConfig { // to be appended to the TraceBuffer. int64 trace_log_margin_usec = 12; - // True specifies an event for each calculator invocation. - // False specifies a separate event for each start and finish time. + // False specifies an event for each calculator invocation. + // True specifies a separate event for each start and finish time. bool trace_log_duration_events = 13; // The number of trace log intervals per file. The total log duration is: @@ -206,6 +206,9 @@ message ProfilerConfig { // An option to turn ON/OFF writing trace files to disk. Saving trace files to // disk is enabled by default. bool trace_log_disabled = 15; + + // If true, tracer timing events are recorded and reported. + bool trace_enabled = 16; } // Describes the topology and function of a MediaPipe Graph. The graph of diff --git a/mediapipe/framework/calculator_contract.h b/mediapipe/framework/calculator_contract.h index 692cf601f..2402c2525 100644 --- a/mediapipe/framework/calculator_contract.h +++ b/mediapipe/framework/calculator_contract.h @@ -50,10 +50,14 @@ class CalculatorContract { ::mediapipe::Status Initialize(const CalculatorGraphConfig::Node& node); ::mediapipe::Status Initialize(const PacketGeneratorConfig& node); ::mediapipe::Status Initialize(const StatusHandlerConfig& node); + void SetNodeName(const std::string& node_name) { node_name_ = node_name; } // Returns the options given to this node. const CalculatorOptions& Options() const { return node_config_->options(); } + // Returns the name given to this node. + const std::string& GetNodeName() { return node_name_; } + // Returns the options given to this calculator. Template argument T must // be the type of the protobuf extension message or the protobuf::Any // message containing the options. 
@@ -141,6 +145,7 @@ class CalculatorContract { std::unique_ptr output_side_packets_; std::string input_stream_handler_; MediaPipeOptions input_stream_handler_options_; + std::string node_name_; std::map service_requests_; }; diff --git a/mediapipe/framework/calculator_graph.h b/mediapipe/framework/calculator_graph.h index f8a1cb8a2..01959e61b 100644 --- a/mediapipe/framework/calculator_graph.h +++ b/mediapipe/framework/calculator_graph.h @@ -56,7 +56,7 @@ #ifndef MEDIAPIPE_DISABLE_GPU namespace mediapipe { class GpuResources; -class GpuSharedData; +struct GpuSharedData; } // namespace mediapipe #endif // !defined(MEDIAPIPE_DISABLE_GPU) @@ -77,8 +77,8 @@ typedef ::mediapipe::StatusOr StatusOrPoller; // #include "mediapipe/framework/calculator_framework.h" // // mediapipe::CalculatorGraphConfig config; -// RETURN_IF_ERROR(mediapipe::tool::ParseGraphFromString(THE_CONFIG, -// &config)); mediapipe::CalculatorGraph graph; +// RETURN_IF_ERROR(mediapipe::tool::ParseGraphFromString(kGraphStr, &config)); +// mediapipe::CalculatorGraph graph; // RETURN_IF_ERROR(graph.Initialize(config)); // // std::map extra_side_packets; @@ -135,7 +135,7 @@ class CalculatorGraph { // |input_templates|. Every subgraph must have its graph type specified in // CalclatorGraphConfig.type. A subgraph can be instantiated directly by // specifying its type in |graph_type|. A template graph can be instantiated - // directly by specifying its template arguments in |arguments|. + // directly by specifying its template arguments in |options|. ::mediapipe::Status Initialize( const std::vector& configs, const std::vector& templates, diff --git a/mediapipe/framework/calculator_graph_bounds_test.cc b/mediapipe/framework/calculator_graph_bounds_test.cc new file mode 100644 index 000000000..09441c027 --- /dev/null +++ b/mediapipe/framework/calculator_graph_bounds_test.cc @@ -0,0 +1,105 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
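+// Summary of this test: CustomBoundCalculator only advances the next-timestamp
+// bound (it never emits packets), and ImmediateInputStreamHandler lets the
+// downstream PassThroughCalculator make progress on those bound updates alone,
+// so all four input packets still reach the "output" stream.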
+ +#include "mediapipe/framework/calculator_context.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/timestamp.h" + +namespace mediapipe { +namespace { + +class CustomBoundCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + cc->Outputs().Index(0).Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Open(CalculatorContext* cc) final { + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + cc->Outputs().Index(0).SetNextTimestampBound(cc->InputTimestamp() + 1); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(CustomBoundCalculator); + +// Shows that ImmediateInputStreamHandler allows bounds propagation. +TEST(CalculatorGraphBounds, ImmediateHandlerBounds) { + // CustomBoundCalculator produces only timestamp bounds. + // The first PassThroughCalculator propagates bounds using SetOffset(0). + // The second PassthroughCalculator delivers an output packet whenever the + // first PassThroughCalculator delivers a timestamp bound. + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: 'input' + node { + calculator: 'CustomBoundCalculator' + input_stream: 'input' + output_stream: 'bounds' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'bounds' + output_stream: 'bounds_2' + input_stream_handler { + input_stream_handler: "ImmediateInputStreamHandler" + } + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'bounds_2' + input_stream: 'input' + output_stream: 'bounds_output' + output_stream: 'output' + } + )"); + CalculatorGraph graph; + std::vector output_packets; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.ObserveOutputStream("output", [&](const Packet& p) { + output_packets.push_back(p); + return ::mediapipe::OkStatus(); + })); + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + + // Add four packets into the graph. + for (int i = 0; i < 4; ++i) { + Packet p = MakePacket(33).At(Timestamp(i)); + MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream("input", p)); + } + + // Four packets arrive at the output only if timestamp bounds are propagated. + MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_EQ(output_packets.size(), 4); + + // Eventually four packets arrive. + MEDIAPIPE_ASSERT_OK(graph.CloseAllPacketSources()); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + EXPECT_EQ(output_packets.size(), 4); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/framework/deps/expand_template.bzl b/mediapipe/framework/deps/expand_template.bzl new file mode 100644 index 000000000..a6e90381d --- /dev/null +++ b/mediapipe/framework/deps/expand_template.bzl @@ -0,0 +1,45 @@ +"""Rule for simple expansion of template files. This performs a simple +search over the template file for the keys in substitutions, +and replaces them with the corresponding values. 
+ +Typical usage: + load("//tools/build_rules:expand_template.bzl", "expand_template") + expand_template( + name = "ExpandMyTemplate", + template = "my.template", + out = "my.txt", + substitutions = { + "$VAR1": "foo", + "$VAR2": "bar", + } + ) + +Args: + name: The name of the rule. + template: The template file to expand + out: The destination of the expanded file + substitutions: A dictionary mapping strings to their substitutions + is_executable: A boolean indicating whether the output file should be executable +""" + +def expand_template_impl(ctx): + ctx.actions.expand_template( + template = ctx.file.template, + output = ctx.outputs.out, + substitutions = { + k: ctx.expand_location(v, ctx.attr.data) + for k, v in ctx.attr.substitutions.items() + }, + is_executable = ctx.attr.is_executable, + ) + +expand_template = rule( + implementation = expand_template_impl, + attrs = { + "template": attr.label(mandatory = True, allow_single_file = True), + "substitutions": attr.string_dict(mandatory = True), + "out": attr.output(mandatory = True), + "is_executable": attr.bool(default = False, mandatory = False), + "data": attr.label_list(allow_files = True), + }, +) diff --git a/mediapipe/framework/encode_binary_proto.bzl b/mediapipe/framework/encode_binary_proto.bzl index aeb5a983e..74d59ced3 100644 --- a/mediapipe/framework/encode_binary_proto.bzl +++ b/mediapipe/framework/encode_binary_proto.bzl @@ -23,7 +23,7 @@ Args: output: The desired name of the output file. Optional. """ -PROTOC = "@protobuf_archive//:protoc" +PROTOC = "@com_google_protobuf//:protoc" def _canonicalize_proto_path_oss(all_protos, genfile_path): """For the protos from external repository, canonicalize the proto path and the file name. @@ -42,12 +42,32 @@ def _canonicalize_proto_path_oss(all_protos, genfile_path): proto_file_names.append(s.path) return ([" --proto_path=" + path for path in proto_paths], proto_file_names) +def _get_proto_provider(dep): + """Get the provider for protocol buffers from a dependency. + + Necessary because Bazel does not provide the .proto. provider but ProtoInfo + cannot be created from Starlark at the moment. + + Returns: + The provider containing information about protocol buffers.
+ """ + if ProtoInfo in dep: + return dep[ProtoInfo] + elif hasattr(dep, "proto"): + return dep.proto + else: + fail("cannot happen, rule definition requires .proto or ProtoInfo") + def _encode_binary_proto_impl(ctx): """Implementation of the encode_binary_proto rule.""" all_protos = depset() for dep in ctx.attr.deps: - if hasattr(dep, "proto"): - all_protos = depset([], transitive = [all_protos, dep.proto.transitive_sources]) + provider = _get_proto_provider(dep) + all_protos = depset( + direct = [], + transitive = [all_protos, provider.transitive_sources], + ) + textpb = ctx.file.input binarypb = ctx.outputs.output or ctx.actions.declare_file( textpb.basename.rsplit(".", 1)[0] + ".binarypb", @@ -84,7 +104,7 @@ encode_binary_proto = rule( cfg = "host", ), "deps": attr.label_list( - providers = ["proto"], + providers = [[ProtoInfo], ["proto"]], ), "input": attr.label( mandatory = True, @@ -100,9 +120,9 @@ encode_binary_proto = rule( def _generate_proto_descriptor_set_impl(ctx): """Implementation of the generate_proto_descriptor_set rule.""" all_protos = depset(transitive = [ - dep.proto.transitive_sources + _get_proto_provider(dep).transitive_sources for dep in ctx.attr.deps - if hasattr(dep, "proto") + if ProtoInfo in dep or hasattr(dep, "proto") ]) descriptor = ctx.outputs.output @@ -115,7 +135,6 @@ def _generate_proto_descriptor_set_impl(ctx): executable = ctx.executable._proto_compiler, arguments = [ "--descriptor_set_out=%s" % descriptor.path, - "--absolute_paths", "--proto_path=" + ctx.genfiles_dir.path, "--proto_path=.", ] + @@ -132,7 +151,7 @@ generate_proto_descriptor_set = rule( cfg = "host", ), "deps": attr.label_list( - providers = ["proto"], + providers = [[ProtoInfo], ["proto"]], ), }, outputs = {"output": "%{name}.proto.bin"}, diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index dde7f95c9..38e4e73af 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -186,7 +186,7 @@ cc_library( "//mediapipe/framework/port:statusor", "//mediapipe/framework/formats/annotation:rasterization_cc_proto", ] + select({ - "//conditions:default": ["@protobuf_archive//:protobuf"], + "//conditions:default": ["@com_google_protobuf//:protobuf"], }) + select({ "//conditions:default": [ "//mediapipe/framework/port:opencv_imgproc", @@ -195,6 +195,7 @@ cc_library( "//conditions:default": [ ], "//mediapipe:android": [], + "//mediapipe:apple": [], }), alwayslink = 1, ) @@ -241,6 +242,21 @@ proto_library( mediapipe_cc_proto_library( name = "rect_cc_proto", srcs = ["rect.proto"], - visibility = ["//mediapipe:__subpackages__"], + visibility = [ + "//mediapipe:__subpackages__", + ], deps = [":rect_proto"], ) + +proto_library( + name = "landmark_proto", + srcs = ["landmark.proto"], + visibility = ["//mediapipe:__subpackages__"], +) + +mediapipe_cc_proto_library( + name = "landmark_cc_proto", + srcs = ["landmark.proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":landmark_proto"], +) diff --git a/mediapipe/framework/formats/image_frame.h b/mediapipe/framework/formats/image_frame.h index 93e5769d9..ff877cd16 100644 --- a/mediapipe/framework/formats/image_frame.h +++ b/mediapipe/framework/formats/image_frame.h @@ -12,29 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. // -// -// Get a WImageView of the ImageFrame -// WImageView_b wimage_view = -// ::mediapipe::formats::WImageView(&const_frame); -// The mutable version. 
-// WImageView_b wimage_view = -// ::mediapipe::formats::MutableWImageView(&frame); -// -// Get an IplImage view of the ImageFrame (this is efficient): -// ::mediapipe::formats::IplView(&frame); -// // Get a cv::Mat view of the ImageFrame (this is efficient): // ::mediapipe::formats::MatView(&frame); // -// Make a constant colored ImageFrame: -// const uint8 kColor[] = {kRed, kGreen, kBlue, kAlpha}; -// ImageFrame frame(ImageFormat::SRGB, kWidth, kHeight); -// ::mediapipe::formats::MutableWImageView(&frame).Set(kColor); -// -// Copying image data from a WImage: -// ::mediapipe::formats::MutableWImageView(&gray8_image_frame) -// .CopyFrom(grayscale_wimage); -// // Copying data from raw data (stored contiguously): // frame.CopyPixelData(format, width, height, raw_data_ptr, // ImageFrame::kDefaultAlignmentBoundary); @@ -51,29 +31,6 @@ // cv::Mat destination = ::mediapipe::formats::MatView(&small_image); // cv::resize(::mediapipe::formats::MatView(&large_image), destination, // destination.size(), 0, 0, cv::INTER_LINEAR); -// -// Copy an ImageFrame into a RawImage: -// RawImage image; -// frame.CopyToResizeableImage(&image); -// -// Encoding a PNG image: -// WImageIO::EncodePNG(frame.Image(), &image_string); -// -// Encoding a JPEG image: -// WImageIO::EncodeJPEG(frame.Image(), 75 /* quality */, -// &image_string); -// -// Decoding a (RGB) JPEG/PNG/WebP image: -// auto wimage = gtl::MakeUnique(); -// auto* wimage_ptr = wimage.get(); -// WImageIO::DecodeImage(image_string, wimage.get()); -// auto frame = gtl::MakeUnique( -// /*format=*/ImageFormat::SRGB, /*width=*/wimage->Width(), -// /*height=*/wimage->Height(), -// /*width_step=*/wimage->WidthStep(), -// /*pixel_data=*/wimage->ImageData(), -// /*deleter=*/[wimage_ptr](uint8*) { delete wimage_ptr; }); -// wimage.release(); // wimage is owned by frame now. #ifndef MEDIAPIPE_FRAMEWORK_FORMATS_IMAGE_FRAME_H_ #define MEDIAPIPE_FRAMEWORK_FORMATS_IMAGE_FRAME_H_ diff --git a/mediapipe/framework/formats/landmark.proto b/mediapipe/framework/formats/landmark.proto new file mode 100644 index 000000000..cdc2ee151 --- /dev/null +++ b/mediapipe/framework/formats/landmark.proto @@ -0,0 +1,34 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +// A landmark that can have 1 to 3 dimensions. Use x for 1D points, (x, y) for +// 2D points and (x, y, z) for 3D points. For more dimensions, consider using +// matrix_data.proto. +message Landmark { + optional float x = 1; + optional float y = 2; + optional float z = 3; +} + +// A normalized version of above Landmark proto. All coordiates should be within +// [0, 1]. 
+message NormalizedLandmark { + optional float x = 1; + optional float y = 2; + optional float z = 3; +} diff --git a/mediapipe/framework/input_stream_manager.cc b/mediapipe/framework/input_stream_manager.cc index b2dcc4c86..edbe6a689 100644 --- a/mediapipe/framework/input_stream_manager.cc +++ b/mediapipe/framework/input_stream_manager.cc @@ -151,15 +151,15 @@ template // If the caller is MovePackets(), packet's underlying holder should be // transferred into queue_. Otherwise, queue_ keeps a copy of the packet. + ++num_packets_added_; + VLOG(2) << "Input stream:" << name_ + << " has added packet at time: " << packet.Timestamp(); if (std::is_const< typename std::remove_reference::type>::value) { queue_.emplace_back(packet); } else { queue_.emplace_back(std::move(packet)); } - ++num_packets_added_; - VLOG(2) << "Input stream:" << name_ - << " has added packet at time: " << packet.Timestamp(); } queue_became_full = (!was_queue_full && max_queue_size_ != -1 && queue_.size() >= max_queue_size_); diff --git a/mediapipe/framework/legacy_calculator_support.cc b/mediapipe/framework/legacy_calculator_support.cc index 3f2503bb4..b96d1e41f 100644 --- a/mediapipe/framework/legacy_calculator_support.cc +++ b/mediapipe/framework/legacy_calculator_support.cc @@ -16,8 +16,6 @@ namespace mediapipe { -// We only define this variable for two specializations of the template -// because it is only meant to be used for these two types. #if EMSCRIPTEN_WORKAROUND_FOR_B121216479 template <> CalculatorContext* diff --git a/mediapipe/framework/legacy_calculator_support.h b/mediapipe/framework/legacy_calculator_support.h index cdc5b2363..a78a21b91 100644 --- a/mediapipe/framework/legacy_calculator_support.h +++ b/mediapipe/framework/legacy_calculator_support.h @@ -99,6 +99,23 @@ class LegacyCalculatorSupport { }; }; +// We only declare this variable for two specializations of the template because +// it is only meant to be used for these two types. +#if EMSCRIPTEN_WORKAROUND_FOR_B121216479 +template <> +CalculatorContext* LegacyCalculatorSupport::Scoped::current_; +template <> +CalculatorContract* + LegacyCalculatorSupport::Scoped::current_; +#else +template <> +thread_local CalculatorContext* + LegacyCalculatorSupport::Scoped::current_; +template <> +thread_local CalculatorContract* + LegacyCalculatorSupport::Scoped::current_; +#endif // EMSCRIPTEN_WORKAROUND_FOR_B121216479 + } // namespace mediapipe #endif // MEDIAPIPE_FRAMEWORK_LEGACY_CALCULATOR_SUPPORT_H_ diff --git a/mediapipe/framework/packet.h b/mediapipe/framework/packet.h index f1b6895c8..048e2d2d8 100644 --- a/mediapipe/framework/packet.h +++ b/mediapipe/framework/packet.h @@ -46,12 +46,6 @@ class Packet; namespace packet_internal { class HolderBase; -// Defined in packet_serialization.cc -// TODO Remove once friend statements are unneeded. -::mediapipe::StatusOr SerializePacket(const Packet& packet); -::mediapipe::StatusOr SerializePacketContents( - const Packet& packet); - Packet Create(HolderBase* holder); Packet Create(HolderBase* holder, Timestamp timestamp); const HolderBase* GetHolder(const Packet& packet); @@ -70,9 +64,6 @@ const HolderBase* GetHolder(const Packet& packet); // PointToForeign allows a Packet to be constructed which does not // own it's data. // -// See packet_serialization.h for helper functions to serialize and -// deserialize packets. -// // This class is thread compatible. 
class Packet { public: @@ -200,13 +191,6 @@ class Packet { std::string DebugTypeName() const; private: - // TODO Change serialize_fn to take a Packet instead of a - // HolderBase, removing the need to friend these classes. - friend ::mediapipe::StatusOr SerializePacket( - const Packet& packet); - friend ::mediapipe::StatusOr SerializePacketContents( - const Packet& packet); - friend Packet packet_internal::Create(packet_internal::HolderBase* holder); friend Packet packet_internal::Create(packet_internal::HolderBase* holder, class Timestamp timestamp); diff --git a/mediapipe/framework/port.h b/mediapipe/framework/port.h index 749f6287a..c45a4546d 100644 --- a/mediapipe/framework/port.h +++ b/mediapipe/framework/port.h @@ -22,8 +22,9 @@ // For consistency, we now set MEDIAPIPE_MOBILE there too. However, for the sake // of projects that may want to build MediaPipe using alternative build systems, // we also try to set platform-specific defines in this header if missing. -#if !defined(MEDIAPIPE_MOBILE) && \ - (defined(__ANDROID__) || defined(__APPLE__) || defined(__EMSCRIPTEN__)) +#if !defined(MEDIAPIPE_MOBILE) && \ + (defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) || \ + defined(__EMSCRIPTEN__)) #define MEDIAPIPE_MOBILE #endif diff --git a/mediapipe/framework/port/BUILD b/mediapipe/framework/port/BUILD index bc4b610a5..188c22e5e 100644 --- a/mediapipe/framework/port/BUILD +++ b/mediapipe/framework/port/BUILD @@ -60,7 +60,7 @@ cc_library( ":core_proto", "//mediapipe/framework:port", ] + select({ - "//conditions:default": ["@protobuf_archive//:protobuf"], + "//conditions:default": ["@com_google_protobuf//:protobuf"], }), ) @@ -75,7 +75,7 @@ cc_library( ":core_proto", "//mediapipe/framework:port", ] + select({ - "//conditions:default": ["@protobuf_archive//:protobuf"], + "//conditions:default": ["@com_google_protobuf//:protobuf"], }), ) @@ -96,7 +96,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", - "@com_google_glog//:glog", + "@com_github_glog_glog//:glog", ], ) @@ -110,7 +110,7 @@ cc_library( deps = [ "//mediapipe/framework:port", ] + select({ - "//conditions:default": ["@protobuf_archive//:protobuf"], + "//conditions:default": ["@com_google_protobuf//:protobuf"], }), ) @@ -175,7 +175,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", - "@com_google_glog//:glog", + "@com_github_glog_glog//:glog", ], ) @@ -258,7 +258,7 @@ cc_library( ":logging", "//mediapipe/framework:port", ] + select({ - "//conditions:default": ["@protobuf_archive//:protobuf"], + "//conditions:default": ["@com_google_protobuf//:protobuf"], }), ) @@ -351,6 +351,7 @@ cc_library( deps = select({ "//conditions:default": [":threadpool_impl_default_to_google"], "//mediapipe:android": [":threadpool_impl_default_to_mediapipe"], + "//mediapipe:apple": [":threadpool_impl_default_to_mediapipe"], }), ) diff --git a/mediapipe/framework/port/advanced_proto_lite_inc.h b/mediapipe/framework/port/advanced_proto_lite_inc.h index 90ae62b92..a6627cbb9 100644 --- a/mediapipe/framework/port/advanced_proto_lite_inc.h +++ b/mediapipe/framework/port/advanced_proto_lite_inc.h @@ -20,10 +20,10 @@ #include "google/protobuf/io/zero_copy_stream_impl_lite.h" #include "google/protobuf/wire_format_lite.h" -#include "google/protobuf/wire_format_lite_inl.h" #include "mediapipe/framework/port.h" #include "mediapipe/framework/port/core_proto_inc.h" #include "mediapipe/framework/port/proto_ns.h" + namespace mediapipe { using proto_int64 = 
google::protobuf::int64; using proto_uint64 = google::protobuf::uint64; diff --git a/mediapipe/framework/port/build_config.bzl b/mediapipe/framework/port/build_config.bzl index 7f665842a..4699f10c4 100644 --- a/mediapipe/framework/port/build_config.bzl +++ b/mediapipe/framework/port/build_config.bzl @@ -3,8 +3,7 @@ """.bzl file for mediapipe open source build configs.""" -load("@protobuf_archive//:protobuf.bzl", "cc_proto_library") -load("@protobuf_archive//:protobuf.bzl", "py_proto_library") +load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library", "py_proto_library") def mediapipe_py_proto_library( name, @@ -26,9 +25,9 @@ def mediapipe_py_proto_library( name = name, srcs = srcs, visibility = visibility, - default_runtime = "@protobuf_archive//:protobuf_python", - protoc = "@protobuf_archive//:protoc", - deps = py_proto_deps + ["@protobuf_archive//:protobuf_python"], + default_runtime = "@com_google_protobuf//:protobuf_python", + protoc = "@com_google_protobuf//:protoc", + deps = py_proto_deps + ["@com_google_protobuf//:protobuf_python"], ) def mediapipe_cc_proto_library(name, srcs, visibility, deps = [], cc_deps = [], testonly = 0): @@ -48,7 +47,8 @@ def mediapipe_cc_proto_library(name, srcs, visibility, deps = [], cc_deps = [], visibility = visibility, deps = cc_deps, testonly = testonly, - cc_libs = ["@protobuf_archive//:protobuf"], - protoc = "@protobuf_archive//:protoc", - default_runtime = "@protobuf_archive//:protobuf", + cc_libs = ["@com_google_protobuf//:protobuf"], + protoc = "@com_google_protobuf//:protoc", + default_runtime = "@com_google_protobuf//:protobuf", + alwayslink = 1, ) diff --git a/mediapipe/framework/port/opencv_video_inc.h b/mediapipe/framework/port/opencv_video_inc.h index bf6f956da..80b2c6251 100644 --- a/mediapipe/framework/port/opencv_video_inc.h +++ b/mediapipe/framework/port/opencv_video_inc.h @@ -21,7 +21,7 @@ #ifdef CV_VERSION_EPOCH // for OpenCV 2.x #include - +#include // Copied from "opencv2/videoio.hpp" in OpenCV 4.0.1 namespace cv { enum VideoCaptureProperties { @@ -80,7 +80,19 @@ inline int fourcc(char c1, char c2, char c3, char c4) { } // namespace mediapipe #else +#include #include + +#if CV_VERSION_MAJOR == 4 +#include + +namespace cv { +inline Ptr createOptFlow_DualTVL1() { + return optflow::createOptFlow_DualTVL1(); +} +} // namespace cv +#endif + namespace mediapipe { inline int fourcc(char c1, char c2, char c3, char c4) { return cv::VideoWriter::fourcc(c1, c2, c3, c4); diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD index 92aee3c68..e7407b4da 100644 --- a/mediapipe/framework/profiler/BUILD +++ b/mediapipe/framework/profiler/BUILD @@ -18,7 +18,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//mediapipe/framework:__subpackages__"]) # This is used to enable the profiler on platforms where it is not on by default. -# To enable, pass --define=MEDIAPIPE_PROFILING=1 to bazel. +# To enable, pass --define MEDIAPIPE_PROFILING=1 to bazel. 
config_setting( name = "graph_profiler_enabled", values = { @@ -223,8 +223,8 @@ cc_test( ":graph_profiler", ":graph_tracer", ":test_context_builder", + "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/core:immediate_mux_calculator", - "//mediapipe/calculators/core:real_time_flow_limiter_calculator", "//mediapipe/calculators/core:round_robin_demux_calculator", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -268,9 +268,19 @@ cc_library( srcs = select({ "//conditions:default": ["profiler_resource_util.cc"], "//mediapipe:android": ["profiler_resource_util_android.cc"], + "//mediapipe:apple": ["profiler_resource_util_apple.cc"], + "//mediapipe:macos": ["profiler_resource_util.cc"], }), hdrs = ["profiler_resource_util.h"], # We use Objective-C++ on iOS. + copts = select({ + "//conditions:default": [], + "//mediapipe:apple": [ + "-std=c++11", + "-ObjC++", + ], + "//mediapipe:macos": [], + }), visibility = [ "//mediapipe/framework:mediapipe_internal", ], @@ -285,5 +295,6 @@ cc_library( "//mediapipe:android": [ "//mediapipe/java/com/google/mediapipe/framework/jni:jni_util", ], + "//mediapipe:apple": [], }), ) diff --git a/mediapipe/framework/profiler/graph_profiler.cc b/mediapipe/framework/profiler/graph_profiler.cc index 20fe156d9..3d3529d39 100644 --- a/mediapipe/framework/profiler/graph_profiler.cc +++ b/mediapipe/framework/profiler/graph_profiler.cc @@ -66,7 +66,7 @@ bool IsProfilerEnabled(const ProfilerConfig& profiler_config) { // Returns true if trace events are recorded. bool IsTracerEnabled(const ProfilerConfig& profiler_config) { - return profiler_config.trace_log_capacity() > 0; + return profiler_config.trace_enabled(); } // Returns true if trace events are written to a log file. @@ -586,7 +586,7 @@ void AssignNodeNames(GraphProfile* profile) { absl::Microseconds(profiler_config_.trace_log_margin_usec()); GraphProfile profile; GraphTrace* trace = profile.add_graph_trace(); - if (profiler_config_.trace_log_duration_events()) { + if (!profiler_config_.trace_log_duration_events()) { tracer()->GetTrace(previous_log_end_time_, end_time, trace); } else { tracer()->GetLog(previous_log_end_time_, end_time, trace); diff --git a/mediapipe/framework/profiler/graph_profiler_ios_test.mm b/mediapipe/framework/profiler/graph_profiler_ios_test.mm new file mode 100644 index 000000000..e64417580 --- /dev/null +++ b/mediapipe/framework/profiler/graph_profiler_ios_test.mm @@ -0,0 +1,74 @@ + +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
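With this change the tracer is gated on trace_enabled rather than on a non-zero trace_log_capacity (which now defaults to 20000 events in graph_tracer.cc below). A graph that wants trace events, such as the iOS test that follows, opts in through its profiler_config; a minimal sketch with assumed stream and calculator names:

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

namespace mediapipe {

// trace_enabled switches the tracer on; the remaining trace_log_* fields are
// optional and keep their defaults when unset.
CalculatorGraphConfig MakeTracedConfig() {
  return ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
    input_stream: 'in'
    node {
      calculator: 'PassThroughCalculator'
      input_stream: 'in'
      output_stream: 'out'
    }
    profiler_config {
      trace_enabled: true
      trace_log_interval_usec: 2500
    }
  )");
}

}  // namespace mediapipe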
+ +#import +#import + +#include "absl/memory/memory.h" +#include "mediapipe/framework/profiler/graph_profiler.h" +#include "mediapipe/framework/profiler/profiler_resource_util.h" +#include "mediapipe/objc/MPPGraph.h" +#include "mediapipe/objc/MPPGraphTestBase.h" + +static NSString* const kTraceFilename = @"mediapipe_trace_0.binarypb"; + +static const char* kOutputStream = "counter"; + +@interface GraphProfilerTest : MPPGraphTestBase +@end + +@implementation GraphProfilerTest + +- (void)mediapipeGraph:(MPPGraph*)graph + didOutputPacket:(const mediapipe::Packet&)packet + fromStream:(const std::string&)streamName { + XCTAssertTrue(streamName == kOutputStream); + NSLog(@"Received counter packet."); +} + +- (void)testDefaultTraceLogPathValueIsSet { + mediapipe::CalculatorGraphConfig graphConfig; + mediapipe::CalculatorGraphConfig::Node* node = graphConfig.add_node(); + node->set_calculator("SimpleCalculator"); + node->add_output_stream(kOutputStream); + + mediapipe::ProfilerConfig* profilerConfig = graphConfig.mutable_profiler_config(); + profilerConfig->set_trace_enabled(true); + profilerConfig->set_enable_profiler(true); + profilerConfig->set_trace_log_disabled(false); + + MPPGraph* graph = [[MPPGraph alloc] initWithGraphConfig:graphConfig]; + [graph addFrameOutputStream:kOutputStream outputPacketType:MediaPipePacketRaw]; + graph.delegate = self; + + NSError* error; + BOOL success = [graph startWithError:&error]; + XCTAssertTrue(success, @"%@", error.localizedDescription); + + // Shut down the graph. + success = [graph waitUntilDoneWithError:&error]; + XCTAssertTrue(success, @"%@", error.localizedDescription); + + mediapipe::StatusOr getTraceLogDir = mediapipe::GetDefaultTraceLogDirectory(); + XCTAssertTrue(getTraceLogDir.ok(), "GetDefaultTraceLogDirectory failed."); + + NSString* directoryPath = [NSString stringWithCString:(*getTraceLogDir).c_str() + encoding:[NSString defaultCStringEncoding]]; + NSString* traceLogPath = [directoryPath stringByAppendingPathComponent:kTraceFilename]; + BOOL traceLogFileExists = [[NSFileManager defaultManager] fileExistsAtPath:traceLogPath]; + XCTAssertTrue(traceLogFileExists, @"Trace log file not found at path: %@", traceLogPath); +} + +@end diff --git a/mediapipe/framework/profiler/graph_profiler_test.cc b/mediapipe/framework/profiler/graph_profiler_test.cc index 90b05f242..d6107a74c 100644 --- a/mediapipe/framework/profiler/graph_profiler_test.cc +++ b/mediapipe/framework/profiler/graph_profiler_test.cc @@ -342,7 +342,7 @@ TEST_F(GraphProfilerTestPeer, Initialize) { output_stream: "source_stream2" } node { - calculator: "RealTimeFlowLimiterCalculator" + calculator: "FlowLimiterCalculator" input_stream: "FINISHED:my_other_stream" input_stream: "source_stream2" input_stream_info: { @@ -378,7 +378,7 @@ TEST_F(GraphProfilerTestPeer, Initialize) { CheckHasProfilesWithInputStreamName("A_Normal_Calc", {"input_stream", "source_stream1"}); CheckHasProfilesWithInputStreamName("Another_Source_Calc", {}); - CheckHasProfilesWithInputStreamName("RealTimeFlowLimiterCalculator", + CheckHasProfilesWithInputStreamName("FlowLimiterCalculator", {"source_stream2", "my_other_stream"}); CheckHasProfilesWithInputStreamName("Another_Normal_Calc", {"my_stream", "gated_source_stream2"}); diff --git a/mediapipe/framework/profiler/graph_tracer.cc b/mediapipe/framework/profiler/graph_tracer.cc index 286844376..0f599879f 100644 --- a/mediapipe/framework/profiler/graph_tracer.cc +++ b/mediapipe/framework/profiler/graph_tracer.cc @@ -27,7 +27,7 @@ namespace mediapipe { namespace { -const 
absl::Duration kDefaultTraceLogInterval = absl::Milliseconds(100); +const absl::Duration kDefaultTraceLogInterval = absl::Milliseconds(500); // Returns a unique identifier for the current thread. inline int GetCurrentThreadId() { @@ -45,7 +45,9 @@ absl::Duration GraphTracer::GetTraceLogInterval() { } int64 GraphTracer::GetTraceLogCapacity() { - return profiler_config_.trace_log_capacity(); + return profiler_config_.trace_log_capacity() + ? profiler_config_.trace_log_capacity() + : 20000; } GraphTracer::GraphTracer(const ProfilerConfig& profiler_config) diff --git a/mediapipe/framework/profiler/graph_tracer_test.cc b/mediapipe/framework/profiler/graph_tracer_test.cc index 743a7f667..5031505af 100644 --- a/mediapipe/framework/profiler/graph_tracer_test.cc +++ b/mediapipe/framework/profiler/graph_tracer_test.cc @@ -70,6 +70,7 @@ class GraphTracerTest : public ::testing::Test { void SetUpGraphTracer(size_t size) { ProfilerConfig profiler_config; profiler_config.set_trace_log_capacity(size); + profiler_config.set_trace_enabled(true); tracer_ = absl::make_unique(profiler_config); } @@ -334,7 +335,7 @@ class GraphTracerE2ETest : public ::testing::Test { profiler_config { histogram_interval_size_usec: 1000 num_histogram_intervals: 100 - trace_log_capacity: 1000000 + trace_enabled: true } )", &graph_config_)); @@ -348,7 +349,7 @@ class GraphTracerE2ETest : public ::testing::Test { output_stream: "input_packets_0" } node { - calculator: 'RealTimeFlowLimiterCalculator' + calculator: 'FlowLimiterCalculator' input_stream_handler { input_stream_handler: 'ImmediateInputStreamHandler' } @@ -392,7 +393,7 @@ class GraphTracerE2ETest : public ::testing::Test { profiler_config { histogram_interval_size_usec: 1000 num_histogram_intervals: 100 - trace_log_capacity: 1000000 + trace_enabled: true } )", &graph_config_)); @@ -928,7 +929,7 @@ TEST_F(GraphTracerE2ETest, DemuxGraphLogFile) { RunDemuxInFlightGraph(); GraphProfile profile; ReadGraphProfile(absl::StrCat(log_path, 0, ".binarypb"), &profile); - EXPECT_EQ(117, profile.graph_trace(0).calculator_trace().size()); + EXPECT_EQ(89, profile.graph_trace(0).calculator_trace().size()); } TEST_F(GraphTracerE2ETest, DemuxGraphLogFiles) { @@ -951,7 +952,7 @@ TEST_F(GraphTracerE2ETest, DemuxGraphLogFiles) { event_counts.push_back(count); graph_profiles.push_back(profile); } - std::vector expected = {45, 50, 22, 0, 0, 0, 0}; + std::vector expected = {37, 42, 19, 0, 0, 0, 0}; EXPECT_EQ(event_counts, expected); GraphProfile& profile_2 = graph_profiles[2]; profile_2.clear_calculator_profiles(); @@ -966,7 +967,7 @@ TEST_F(GraphTracerE2ETest, DemuxGraphLogFiles) { base_time: 1544086800000000 base_timestamp: 0 calculator_name: "LambdaCalculator_1" - calculator_name: "RealTimeFlowLimiterCalculator" + calculator_name: "FlowLimiterCalculator" calculator_name: "RoundRobinDemuxCalculator" calculator_name: "LambdaCalculator_1" calculator_name: "LambdaCalculator" @@ -1080,20 +1081,14 @@ TEST_F(GraphTracerE2ETest, DemuxGraphLogFiles) { input_timestamp: 50000 event_type: PROCESS start_time: 65004 - input_trace { packet_timestamp: 50000 stream_id: 5 } - } - calculator_trace { - node_id: 5 - input_timestamp: 50000 - event_type: PROCESS finish_time: 65004 + input_trace { + start_time: 65004 + finish_time: 65004 + packet_timestamp: 50000 + stream_id: 5 + } output_trace { packet_timestamp: 50000 stream_id: 6 } - } - calculator_trace { - node_id: 5 - input_timestamp: 50000 - event_type: PROCESS - finish_time: 65004 output_trace { packet_timestamp: 50000 stream_id: 7 } } calculator_trace { @@ 
-1121,7 +1116,12 @@ TEST_F(GraphTracerE2ETest, DemuxGraphLogFiles) { input_timestamp: 50000 event_type: PROCESS start_time: 65004 - input_trace { packet_timestamp: 50000 stream_id: 7 } + input_trace { + start_time: 65004 + finish_time: 65004 + packet_timestamp: 50000 + stream_id: 7 + } } calculator_trace { node_id: 1 @@ -1176,13 +1176,13 @@ TEST_F(GraphTracerE2ETest, DemuxGraphLogFiles) { input_timestamp: 40000 event_type: PROCESS start_time: 70004 - input_trace { packet_timestamp: 40000 stream_id: 8 } - } - calculator_trace { - node_id: 5 - input_timestamp: 40000 - event_type: PROCESS finish_time: 70004 + input_trace { + start_time: 70004 + finish_time: 70004 + packet_timestamp: 40000 + stream_id: 8 + } output_trace { packet_timestamp: 50001 stream_id: 7 } } calculator_trace { @@ -1205,7 +1205,12 @@ TEST_F(GraphTracerE2ETest, DemuxGraphLogFiles) { input_timestamp: 50001 event_type: PROCESS start_time: 70004 - input_trace { packet_timestamp: 50001 stream_id: 7 } + input_trace { + start_time: 70004 + finish_time: 70004 + packet_timestamp: 50001 + stream_id: 7 + } } calculator_trace { node_id: 1 @@ -1234,8 +1239,8 @@ TEST_F(GraphTracerE2ETest, DemuxGraphLogFiles) { input_side_packet: "callback_2" } node { - name: "RealTimeFlowLimiterCalculator" - calculator: "RealTimeFlowLimiterCalculator" + name: "FlowLimiterCalculator" + calculator: "FlowLimiterCalculator" input_stream: "input_packets_0" input_stream: "FINISHED:finish_indicator" output_stream: "input_0_sampled" @@ -1281,10 +1286,10 @@ TEST_F(GraphTracerE2ETest, DemuxGraphLogFiles) { profiler_config { histogram_interval_size_usec: 1000 num_histogram_intervals: 100 - trace_log_capacity: 1000000 trace_log_count: 100 trace_log_interval_usec: 2500 trace_log_interval_count: 10 + trace_enabled: true } } )"))); diff --git a/mediapipe/framework/profiler/profiler_resource_util.h b/mediapipe/framework/profiler/profiler_resource_util.h index c1cb025e4..4b41dc509 100644 --- a/mediapipe/framework/profiler/profiler_resource_util.h +++ b/mediapipe/framework/profiler/profiler_resource_util.h @@ -17,6 +17,7 @@ #include +#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/statusor.h" namespace mediapipe { diff --git a/mediapipe/framework/profiler/profiler_resource_util_apple.cc b/mediapipe/framework/profiler/profiler_resource_util_apple.cc index 2ae366c95..10363b058 100644 --- a/mediapipe/framework/profiler/profiler_resource_util_apple.cc +++ b/mediapipe/framework/profiler/profiler_resource_util_apple.cc @@ -15,6 +15,7 @@ #import +#include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/profiler/profiler_resource_util.h" namespace mediapipe { @@ -24,10 +25,24 @@ StatusOr GetDefaultTraceLogDirectory() { NSURL* documents_directory_url = [[[NSFileManager defaultManager] URLsForDirectory:NSDocumentDirectory inDomains:NSUserDomainMask] lastObject]; - NSString* ns_documents_directory = [documents_directory_url absoluteString]; - std::string documents_directory = [ns_documents_directory UTF8String]; - return documents_directory; + // Note: "createDirectoryAtURL:..." method doesn't successfully create + // the directory, hence this code uses "createDirectoryAtPath:..". + NSString* ns_documents_directory = [documents_directory_url absoluteString]; + NSError* error; + BOOL success = [[NSFileManager defaultManager] + createDirectoryAtPath:ns_documents_directory + withIntermediateDirectories:YES + attributes:nil + error:&error]; + if (!success) { + // TODO: Use NSError+util_status to get status from NSError. 
+ return ::mediapipe::InternalError( + [[error localizedDescription] UTF8String]); + } + + std::string trace_log_directory = [ns_documents_directory UTF8String]; + return trace_log_directory; } } // namespace mediapipe diff --git a/mediapipe/framework/profiler/testing/BUILD b/mediapipe/framework/profiler/testing/BUILD new file mode 100644 index 000000000..9064473e3 --- /dev/null +++ b/mediapipe/framework/profiler/testing/BUILD @@ -0,0 +1,30 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = ["//mediapipe/framework:__subpackages__"], +) + +cc_library( + name = "simple_calculator", + srcs = ["simple_calculator.cc"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) diff --git a/mediapipe/framework/profiler/testing/simple_calculator.cc b/mediapipe/framework/profiler/testing/simple_calculator.cc new file mode 100644 index 000000000..8931f2379 --- /dev/null +++ b/mediapipe/framework/profiler/testing/simple_calculator.cc @@ -0,0 +1,49 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
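The SimpleCalculator added just below is a bounded test source; a sketch of driving it from a test, with the MAX_COUNT side packet limiting how many packets it emits (graph wiring and side-packet names are assumptions for illustration):

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {

// Runs the source until it returns tool::StatusStop() after max_count packets.
::mediapipe::Status RunSimpleCalculatorSource() {
  CalculatorGraphConfig config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
        node {
          calculator: 'SimpleCalculator'
          input_side_packet: 'MAX_COUNT:max_count'
          output_stream: 'counter'
        }
      )");
  CalculatorGraph graph;
  RETURN_IF_ERROR(graph.Initialize(config));
  RETURN_IF_ERROR(graph.ObserveOutputStream("counter", [](const Packet& p) {
    LOG(INFO) << "count " << p.Get<int>() << " @ " << p.Timestamp();
    return ::mediapipe::OkStatus();
  }));
  RETURN_IF_ERROR(graph.StartRun({{"max_count", MakePacket<int>(5)}}));
  return graph.WaitUntilDone();
}

}  // namespace mediapipe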
+ +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { + +class SimpleCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Outputs().Index(0).Set(); + if (cc->InputSidePackets().HasTag("MAX_COUNT")) { + cc->InputSidePackets().Tag("MAX_COUNT").Set(); + } + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) final { + LOG(WARNING) << "Simple Calculator Process called, count_: " << count_; + int max_count = 1; + if (cc->InputSidePackets().HasTag("MAX_COUNT")) { + max_count = cc->InputSidePackets().Tag("MAX_COUNT").Get(); + } + if (count_ >= max_count) { + return tool::StatusStop(); + } + cc->Outputs().Index(0).Add(new int(count_), Timestamp(count_)); + ++count_; + return ::mediapipe::OkStatus(); + } + + private: + int count_ = 0; +}; +REGISTER_CALCULATOR(SimpleCalculator); + +} // namespace mediapipe diff --git a/mediapipe/framework/stream_handler/immediate_input_stream_handler.cc b/mediapipe/framework/stream_handler/immediate_input_stream_handler.cc index 372b9462b..8dd5c53d2 100644 --- a/mediapipe/framework/stream_handler/immediate_input_stream_handler.cc +++ b/mediapipe/framework/stream_handler/immediate_input_stream_handler.cc @@ -61,9 +61,9 @@ ImmediateInputStreamHandler::ImmediateInputStreamHandler( timestamp_bounds_(std::move(tag_map)) {} NodeReadiness ImmediateInputStreamHandler::GetNodeReadiness( - Timestamp* input_timestamp) { - Timestamp min_stream_timestamp = Timestamp::Done(); - *input_timestamp = Timestamp::Done(); + Timestamp* min_stream_timestamp) { + *min_stream_timestamp = Timestamp::Done(); + Timestamp input_timestamp = Timestamp::Done(); bool stream_became_done = false; for (CollectionItemId i = input_stream_managers_.BeginId(); @@ -72,9 +72,9 @@ NodeReadiness ImmediateInputStreamHandler::GetNodeReadiness( bool empty; Timestamp stream_timestamp = stream->MinTimestampOrBound(&empty); if (!empty) { - *input_timestamp = std::min(*input_timestamp, stream_timestamp); + input_timestamp = std::min(input_timestamp, stream_timestamp); } - min_stream_timestamp = std::min(min_stream_timestamp, stream_timestamp); + *min_stream_timestamp = std::min(*min_stream_timestamp, stream_timestamp); if (stream_timestamp != timestamp_bounds_.Get(i)) { if (stream_timestamp == Timestamp::Done()) { stream_became_done = true; @@ -83,16 +83,17 @@ NodeReadiness ImmediateInputStreamHandler::GetNodeReadiness( } } - if (min_stream_timestamp == Timestamp::Done()) { + if (*min_stream_timestamp == Timestamp::Done()) { return NodeReadiness::kReadyForClose; } - if (*input_timestamp < Timestamp::Done()) { + if (input_timestamp < Timestamp::Done()) { + // On kReadyForProcess, the input_timestamp is returned. + *min_stream_timestamp = input_timestamp; return NodeReadiness::kReadyForProcess; } if (stream_became_done) { - *input_timestamp = min_stream_timestamp; return NodeReadiness::kReadyForProcess; } diff --git a/mediapipe/framework/stream_handler/timestamp_align_input_stream_handler_test.cc b/mediapipe/framework/stream_handler/timestamp_align_input_stream_handler_test.cc index db060edc8..e2c679d86 100644 --- a/mediapipe/framework/stream_handler/timestamp_align_input_stream_handler_test.cc +++ b/mediapipe/framework/stream_handler/timestamp_align_input_stream_handler_test.cc @@ -1,3 +1,17 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include #include "mediapipe/framework/calculator_framework.h" diff --git a/mediapipe/framework/subgraph_test.cc b/mediapipe/framework/subgraph_test.cc new file mode 100644 index 000000000..1b994c3eb --- /dev/null +++ b/mediapipe/framework/subgraph_test.cc @@ -0,0 +1,79 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/subgraph.h" + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" + +// Because of portability issues, we include this directly. +#include "mediapipe/framework/port/status_matchers.h" // NOLINT(build/deprecated) + +namespace mediapipe { +namespace { + +class SubgraphTest : public ::testing::Test { + protected: + void TestGraphEnclosing(const std::string& subgraph_type_name) { + EXPECT_TRUE(SubgraphRegistry::IsRegistered(subgraph_type_name)); + + CalculatorGraphConfig config; + config.add_input_stream("in"); + CalculatorGraphConfig::Node* node = config.add_node(); + node->set_calculator(subgraph_type_name); + node->add_input_stream("INTS:in"); + node->add_output_stream("DUBS:dubs_tmp"); + node->add_output_stream("QUADS:quads"); + node = config.add_node(); + node->set_calculator("PassThroughCalculator"); + node->add_input_stream("dubs_tmp"); + node->add_output_stream("dubs"); + + std::vector dubs; + tool::AddVectorSink("dubs", &config, &dubs); + + std::vector quads; + tool::AddVectorSink("quads", &config, &quads); + + CalculatorGraph graph; + MEDIAPIPE_ASSERT_OK(graph.Initialize(config)); + MEDIAPIPE_ASSERT_OK(graph.StartRun({})); + + constexpr int kCount = 5; + for (int i = 0; i < kCount; ++i) { + MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream( + "in", MakePacket(i).At(Timestamp(i)))); + } + + MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("in")); + MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone()); + + EXPECT_EQ(dubs.size(), kCount); + EXPECT_EQ(quads.size(), kCount); + for (int i = 0; i < kCount; ++i) { + EXPECT_EQ(i * 2, dubs[i].Get()); + EXPECT_EQ(i * 4, quads[i].Get()); + } + } +}; + +// Tests registration of subgraph named "DubQuadTestSubgraph" using target +// "dub_quad_test_subgraph" from macro "mediapipe_simple_subgraph". 
+TEST_F(SubgraphTest, LinkedSubgraph) { + TestGraphEnclosing("DubQuadTestSubgraph"); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/framework/testdata/BUILD b/mediapipe/framework/testdata/BUILD index 8f6ff59a8..75ee0802d 100644 --- a/mediapipe/framework/testdata/BUILD +++ b/mediapipe/framework/testdata/BUILD @@ -52,7 +52,7 @@ mediapipe_cc_proto_library( proto_library( name = "zoo_mutator_proto", srcs = ["zoo_mutator.proto"], - deps = ["@protobuf_archive//:any_proto"], + deps = ["@com_google_protobuf//:any_proto"], ) proto_library( diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 5a61200e5..bb7052fca 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -17,18 +17,57 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:private"]) -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load( + "//mediapipe/framework/port:build_config.bzl", + "mediapipe_cc_proto_library", +) +load( + "//mediapipe/framework/tool:mediapipe_graph.bzl", + "data_as_c_string", + "mediapipe_binary_graph", +) + +exports_files([ + "simple_subgraph_template.cc", +]) + +cc_library( + name = "text_to_binary_graph", + srcs = ["text_to_binary_graph.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework/port:advanced_proto", + "//mediapipe/framework/port:commandlineflags", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], +) proto_library( name = "calculator_graph_template_proto", srcs = ["calculator_graph_template.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/framework/deps:proto_descriptor_proto", ], ) +java_proto_library( + name = "calculator_graph_template_java_proto", + visibility = ["//visibility:public"], + deps = [":calculator_graph_template_proto"], +) + +java_lite_proto_library( + name = "calculator_graph_template_java_proto_lite", + strict_deps = 0, + visibility = ["//visibility:public"], + deps = [":calculator_graph_template_proto"], +) + proto_library( name = "source_proto", srcs = ["source.proto"], @@ -43,7 +82,10 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/deps:proto_descriptor_cc_proto", ], - visibility = ["//mediapipe/framework:__subpackages__"], + visibility = [ + "//mediapipe/framework:__subpackages__", + "//mediapipe/java/com/google/mediapipe/framework:__subpackages__", + ], deps = [":calculator_graph_template_proto"], ) @@ -55,6 +97,15 @@ mediapipe_cc_proto_library( deps = [":source_proto"], ) +cc_binary( + name = "encode_as_c_string", + srcs = ["encode_as_c_string.cc"], + visibility = ["//visibility:public"], + deps = [ + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "fill_packet_set", srcs = ["fill_packet_set.cc"], @@ -464,8 +515,8 @@ cc_test( deps = [ ":simulation_clock", ":simulation_clock_executor", + "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/core:immediate_mux_calculator", - "//mediapipe/calculators/core:real_time_flow_limiter_calculator", "//mediapipe/calculators/core:round_robin_demux_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:executor", @@ -483,3 +534,71 @@ cc_test( "@com_google_absl//absl/strings", ], ) + +mediapipe_binary_graph( + name = "test_binarypb", + graph = 
"//mediapipe/framework/tool/testdata:test_graph", + output_name = "test.binarypb", + visibility = ["//visibility:private"], +) + +data_as_c_string( + name = "test_binarypb_inc", + testonly = 1, + srcs = [":test_binarypb"], + outs = ["test_binarypb.inc"], +) + +proto_library( + name = "node_chain_subgraph_proto", + srcs = ["node_chain_subgraph.proto"], + deps = [ + "//mediapipe/framework:calculator_proto", + ], +) + +mediapipe_cc_proto_library( + name = "node_chain_subgraph_cc_proto", + srcs = ["node_chain_subgraph.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":node_chain_subgraph_proto"], +) + +cc_test( + name = "data_as_c_string_test", + srcs = [ + "data_as_c_string_test.cc", + ":test_binarypb_inc", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:gtest_main", + ], +) + +cc_test( + name = "subgraph_expansion_test", + size = "small", + srcs = ["subgraph_expansion_test.cc"], + deps = [ + ":node_chain_subgraph_cc_proto", + ":subgraph_expansion", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:mediapipe_options_cc_proto", + "//mediapipe/framework:packet_generator_cc_proto", + "//mediapipe/framework:packet_set", + "//mediapipe/framework:packet_type", + "//mediapipe/framework:status_handler", + "//mediapipe/framework:subgraph", + "//mediapipe/framework/deps:message_matchers", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/tool/testdata:dub_quad_test_subgraph", + "//mediapipe/framework/tool/testdata:nested_test_subgraph", + "@com_google_absl//absl/strings", + ], +) diff --git a/mediapipe/framework/tool/calculator_graph_template.proto b/mediapipe/framework/tool/calculator_graph_template.proto index 7c951f082..c2fc6aa86 100644 --- a/mediapipe/framework/tool/calculator_graph_template.proto +++ b/mediapipe/framework/tool/calculator_graph_template.proto @@ -5,6 +5,9 @@ package mediapipe; import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/deps/proto_descriptor.proto"; +option java_package = "com.google.mediapipe.proto"; +option java_outer_classname = "GraphTemplateProto"; + // A template rule or a template rule argument expression. message TemplateExpression { // A template parameter name or a literal value. diff --git a/mediapipe/framework/tool/data_as_c_string_test.cc b/mediapipe/framework/tool/data_as_c_string_test.cc new file mode 100644 index 000000000..3ed561635 --- /dev/null +++ b/mediapipe/framework/tool/data_as_c_string_test.cc @@ -0,0 +1,33 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" + +namespace mediapipe { +namespace { + +static const char my_graph[] = +#include "mediapipe/framework/tool/test_binarypb.inc" + ; // NOLINT(whitespace/semicolon) + +TEST(DataAsCString, CanDecodeCalculatorGraphConfig) { + CalculatorGraphConfig config; + bool success = config.ParseFromArray(my_graph, sizeof(my_graph) - 1); + EXPECT_TRUE(success); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/framework/tool/encode_as_c_string.cc b/mediapipe/framework/tool/encode_as_c_string.cc new file mode 100644 index 000000000..a202deb09 --- /dev/null +++ b/mediapipe/framework/tool/encode_as_c_string.cc @@ -0,0 +1,62 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This program takes one input file and encodes its contents as a C++ +// std::string, which can be included in a C++ source file. It is similar to +// filewrapper (and borrows some of its code), but simpler. + +#include +#include +#include + +#include "absl/strings/escaping.h" + +int main(int argc, char** argv) { + if (argc != 2) { + std::cerr << "usage: encode_as_c_string input_file\n"; + return 1; + } + const std::string input_name = argv[1]; + std::ifstream input_file(input_name, + std::ios_base::in | std::ios_base::binary); + if (!input_file.is_open()) { + std::cerr << "cannot open '" << input_name << "'\n"; + return 2; + } + constexpr int kBufSize = 4096; + std::unique_ptr buf(new char[kBufSize]); + std::cout << "\""; + int line_len = 1; + while (1) { + input_file.read(buf.get(), kBufSize); + int count = input_file.gcount(); + if (count == 0) break; + for (int i = 0; i < count; ++i) { + std::string out = absl::CEscape(absl::string_view(&buf[i], 1)); + if (line_len + out.size() > 79) { + std::cout << "\"\n\""; + line_len = 1; + } + std::cout << out; + line_len += out.size(); + } + } + input_file.close(); + if (!input_file.eof()) { + std::cerr << "error reading '" << input_name << "'\n"; + return 2; + } + std::cout << "\"\n"; + return 0; +} diff --git a/mediapipe/framework/tool/mediapipe_graph.bzl b/mediapipe/framework/tool/mediapipe_graph.bzl index 80346c757..d6e7c56a5 100644 --- a/mediapipe/framework/tool/mediapipe_graph.bzl +++ b/mediapipe/framework/tool/mediapipe_graph.bzl @@ -16,26 +16,11 @@ Example: """ load("//mediapipe/framework:encode_binary_proto.bzl", "encode_binary_proto", "generate_proto_descriptor_set") -load("//mediapipe/framework:transitive_protos.bzl", "transitive_proto_cc_libs", "transitive_proto_descriptor_sets", "transitive_protos") +load("//mediapipe/framework:transitive_protos.bzl", "transitive_protos") +load("//mediapipe/framework/deps:expand_template.bzl", "expand_template") -def mediapipe_binary_graph(name, graph = None, output_name = None, deps = [], testonly = None, **kwargs): - """Converts a graph from text format to binary format. 
- - Args: - name: the name of the encode_binary_proto rule generated by this macro. - graph: the BUILD label of a text-format MediaPipe graph. - output_name: the name of the file to which the binary serialization is - written. - deps: the BUILD labels of dependencies that provide any additional message - types used by the graph. The basic messages defined in calculator.proto - are always available, but any custom types (e.g. specific calculator - options) should be provided here. It is sufficient to provide targets - that depend on the required protos indirectly: this macro examines the - entire dependency tree, and does not build any dependencies except for - the protos it finds. - testonly: pass 1 if the graph is to be used only for tests. - **kwargs: any other arguments valid for encode_binary_proto. - """ +def mediapipe_binary_graph(name, graph = None, output_name = None, deps = [], testonly = False, **kwargs): + """Converts a graph from text format to binary format.""" if not graph: fail("No input graph file specified.") @@ -44,48 +29,124 @@ def mediapipe_binary_graph(name, graph = None, output_name = None, deps = [], te fail("Must specify the output_name.") transitive_protos( - name = name + "_gather_protos", + name = name + "_gather_cc_protos", deps = deps, testonly = testonly, ) - # This collects descriptor sets for tools that need them. - transitive_proto_descriptor_sets( - name = name + "_gather_proto_descriptor_sets", - deps = deps, - testonly = testonly, - ) - - # This collects the generated .a libraries for tools that need them. - transitive_proto_cc_libs( - name = name + "_gather_proto_libs", - deps = deps, - testonly = testonly, - ) - - # This generates a single descriptor set with a single invocation of the proto compiler. - # May be faster than using the descriptor sets from proto_library. - # We always pass at least the calculator proto since the proto compiler would fail - # if it were passed no protos. - generate_proto_descriptor_set( - name = name + "_proto_descriptor_set", + # Compile a simple proto parser binary using the deps. + native.cc_binary( + name = name + "_text_to_binary_graph", + visibility = ["//visibility:private"], deps = [ - name + "_gather_protos", - "//mediapipe/framework:calculator_proto", + "//mediapipe/framework/tool:text_to_binary_graph", + name + "_gather_cc_protos", ], + tags = ["manual"], testonly = testonly, ) - return encode_binary_proto( + # Invoke the proto parser binary. + native.genrule( name = name, - deps = [ - name + "_gather_protos", - "//mediapipe/framework:calculator_proto", - "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_proto", + srcs = [graph], + outs = [output_name], + cmd = ( + "$(location " + name + "_text_to_binary_graph" + ") " + + ("--proto_source=$(location %s) " % graph) + + ("--proto_output=\"$@\" ") + ), + tools = [name + "_text_to_binary_graph"], + testonly = testonly, + ) + +def data_as_c_string( + name, + srcs, + outs = None, + testonly = None): + """Encodes the data from a file as a C string literal. + + This produces a text file containing the quoted C string literal. It can be + included directly in a C++ source file. + + Args: + name: The name of the rule. + srcs: A list containing a single item, the file to encode. + outs: A list containing a single item, the name of the output text file. + Defaults to the rule name. + testonly: pass 1 if the graph is to be used only for tests. 
+ """ + if len(srcs) != 1: + fail("srcs must be a single-element list") + if outs == None: + outs = [name] + native.genrule( + name = name, + srcs = srcs, + outs = outs, + cmd = "$(location //mediapipe/framework/tool:encode_as_c_string) \"$<\" > \"$@\"", + tools = ["//mediapipe/framework/tool:encode_as_c_string"], + testonly = testonly, + ) + +def mediapipe_simple_subgraph( + name, + register_as, + graph, + deps = [], + visibility = None, + testonly = None, + **kwargs): + """Defines a registered subgraph for inclusion in other graphs. + + Args: + name: name of the subgraph target to define. + register_as: name used to invoke this graph in supergraphs. Should be in + CamelCase. + graph: the BUILD label of a text-format MediaPipe graph. + deps: any calculators or subgraphs used by this graph. + visibility: The list of packages the subgraph should be visible to. + testonly: pass 1 if the graph is to be used only for tests. + **kwargs: Remaining keyword args, forwarded to cc_library. + """ + graph_base_name = graph.replace(":", "/").split("/")[-1].rsplit(".", 1)[0] + mediapipe_binary_graph( + name = name + "_graph", + graph = graph, + output_name = graph_base_name + ".binarypb", + deps = deps, + testonly = testonly, + ) + data_as_c_string( + name = name + "_inc", + srcs = [graph_base_name + ".binarypb"], + outs = [graph_base_name + ".inc"], + ) + + # cc_library for a linked mediapipe graph. + expand_template( + name = name + "_linked_cc", + template = "//mediapipe/framework/tool:simple_subgraph_template.cc", + out = name + "_linked.cc", + substitutions = { + "{{SUBGRAPH_CLASS_NAME}}": register_as, + "{{SUBGRAPH_INC_FILE_PATH}}": native.package_name() + "/" + graph_base_name + ".inc", + }, + testonly = testonly, + ) + native.cc_library( + name = name, + srcs = [ + name + "_linked.cc", + graph_base_name + ".inc", ], - message_type = "mediapipe.CalculatorGraphConfig", - input = graph, - output = output_name, + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:subgraph", + ] + deps, + alwayslink = 1, + visibility = visibility, testonly = testonly, **kwargs ) diff --git a/mediapipe/framework/tool/node_chain_subgraph.proto b/mediapipe/framework/tool/node_chain_subgraph.proto new file mode 100644 index 000000000..c6e8980a2 --- /dev/null +++ b/mediapipe/framework/tool/node_chain_subgraph.proto @@ -0,0 +1,19 @@ +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +// Options for NodeChainSubgraph. +message NodeChainSubgraphOptions { + extend CalculatorOptions { + optional NodeChainSubgraphOptions ext = 167210579; + } + + // The type of the node. The node must have exactly one input stream and + // exactly one output stream. + optional string node_type = 1; + + // How many copies of the node should be chained in series. + optional int32 chain_length = 2; +} diff --git a/mediapipe/framework/tool/simple_subgraph_template.cc b/mediapipe/framework/tool/simple_subgraph_template.cc new file mode 100644 index 000000000..1978a2955 --- /dev/null +++ b/mediapipe/framework/tool/simple_subgraph_template.cc @@ -0,0 +1,43 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// This template is used by the mediapipe_simple_subgraph macro in
+// //mediapipe/framework/tool/mediapipe_graph.bzl
+
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/subgraph.h"
+
+namespace mediapipe {
+
+static const char binary_graph[] =
+#include "{{SUBGRAPH_INC_FILE_PATH}}"
+    ;  // NOLINT(whitespace/semicolon)
+
+class {{SUBGRAPH_CLASS_NAME}} : public Subgraph {
+ public:
+  ::mediapipe::StatusOr<CalculatorGraphConfig> GetConfig(
+      const SubgraphOptions& /*options*/) {
+    CalculatorGraphConfig config;
+    // Note: this is a binary protobuf serialization, and may include NUL
+    // bytes. The trailing NUL added to the std::string literal should be excluded.
+    if (config.ParseFromArray(binary_graph, sizeof(binary_graph) - 1)) {
+      return config;
+    } else {
+      return ::mediapipe::InternalError("Could not parse subgraph.");
+    }
+  }
+};
+REGISTER_MEDIAPIPE_GRAPH({{SUBGRAPH_CLASS_NAME}});
+
+}  // namespace mediapipe
diff --git a/mediapipe/framework/tool/simulation_clock_test.cc b/mediapipe/framework/tool/simulation_clock_test.cc
index 458401dc5..3aa291387 100644
--- a/mediapipe/framework/tool/simulation_clock_test.cc
+++ b/mediapipe/framework/tool/simulation_clock_test.cc
@@ -46,7 +46,7 @@ class SimulationClockTest : public ::testing::Test {
     graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
         input_stream: "input_packets_0"
         node {
-          calculator: 'RealTimeFlowLimiterCalculator'
+          calculator: 'FlowLimiterCalculator'
           input_stream_handler {
             input_stream_handler: 'ImmediateInputStreamHandler'
           }
diff --git a/mediapipe/framework/tool/subgraph_expansion_test.cc b/mediapipe/framework/tool/subgraph_expansion_test.cc
new file mode 100644
index 000000000..a469e702a
--- /dev/null
+++ b/mediapipe/framework/tool/subgraph_expansion_test.cc
@@ -0,0 +1,529 @@
+// Copyright 2019 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
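
For context, the macro and template above produce a registered Subgraph that a graph config can reference like any calculator, and the subgraph is expanded when the graph is initialized. A minimal C++ sketch, assuming a hypothetical subgraph registered as "TwoPassThroughSubgraph" (only the registered name and stream names are assumptions; the API calls mirror those used elsewhere in this patch):

    // Sketch: using a subgraph produced by mediapipe_simple_subgraph from C++.
    // "TwoPassThroughSubgraph" is a hypothetical register_as name.
    #include "mediapipe/framework/calculator_framework.h"
    #include "mediapipe/framework/port/parse_text_proto.h"

    ::mediapipe::Status BuildGraphWithSubgraph() {
      mediapipe::CalculatorGraphConfig config =
          mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"(
            input_stream: "in"
            output_stream: "out"
            node {
              # The registered subgraph is used like any calculator; it is
              # expanded into its constituent nodes during Initialize().
              calculator: "TwoPassThroughSubgraph"
              input_stream: "in"
              output_stream: "out"
            }
          )");
      mediapipe::CalculatorGraph graph;
      return graph.Initialize(config);
    }
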
+#include "mediapipe/framework/tool/subgraph_expansion.h" + +#include + +#include "absl/strings/str_cat.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/message_matchers.h" +#include "mediapipe/framework/packet_set.h" +#include "mediapipe/framework/packet_type.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/proto_ns.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/status_handler.h" +#include "mediapipe/framework/subgraph.h" +#include "mediapipe/framework/tool/node_chain_subgraph.pb.h" + +namespace mediapipe { + +namespace { + +class SimpleTestCalculator : public CalculatorBase { + public: + ::mediapipe::Status Process(CalculatorContext* cc) override { + return ::mediapipe::OkStatus(); + } + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + for (PacketType& type : cc->Inputs()) { + type.Set(); + } + for (PacketType& type : cc->Outputs()) { + type.Set(); + } + for (PacketType& type : cc->InputSidePackets()) { + type.Set(); + } + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(SimpleTestCalculator); +typedef SimpleTestCalculator SomeSourceCalculator; +typedef SimpleTestCalculator SomeSinkCalculator; +typedef SimpleTestCalculator SomeRegularCalculator; +typedef SimpleTestCalculator SomeAggregator; +REGISTER_CALCULATOR(SomeSourceCalculator); +REGISTER_CALCULATOR(SomeSinkCalculator); +REGISTER_CALCULATOR(SomeRegularCalculator); +REGISTER_CALCULATOR(SomeAggregator); + +class TestSubgraph : public Subgraph { + public: + ::mediapipe::StatusOr GetConfig( + const SubgraphOptions& /*options*/) override { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "DATA:input_1" + node { + name: "regular_node" + calculator: "SomeRegularCalculator" + input_stream: "input_1" + output_stream: "stream_a" + input_side_packet: "side" + } + node { + name: "simple_sink" + calculator: "SomeSinkCalculator" + input_stream: "stream_a" + } + packet_generator { + packet_generator: "SomePacketGenerator" + output_side_packet: "side" + } + )"); + return config; + } +}; +REGISTER_MEDIAPIPE_GRAPH(TestSubgraph); + +class PacketFactoryTestSubgraph : public Subgraph { + public: + ::mediapipe::StatusOr GetConfig( + const SubgraphOptions& /*options*/) override { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "DATA:input_1" + node { + name: "regular_node" + calculator: "SomeRegularCalculator" + input_stream: "input_1" + output_stream: "stream_a" + input_side_packet: "side" + } + node { + name: "simple_sink" + calculator: "SomeSinkCalculator" + input_stream: "stream_a" + } + packet_factory { + packet_factory: "SomePacketFactory" + output_side_packet: "side" + } + )"); + return config; + } +}; +REGISTER_MEDIAPIPE_GRAPH(PacketFactoryTestSubgraph); + +// This subgraph chains copies of the specified node in series. The node type +// and the number of copies of the node are specified in subgraph options. 
+class NodeChainSubgraph : public Subgraph { + public: + ::mediapipe::StatusOr GetConfig( + const SubgraphOptions& options) override { + const mediapipe::NodeChainSubgraphOptions& opts = + options.GetExtension(mediapipe::NodeChainSubgraphOptions::ext); + const ProtoString& node_type = opts.node_type(); + int chain_length = opts.chain_length(); + RET_CHECK(!node_type.empty()); + RET_CHECK_GT(chain_length, 0); + CalculatorGraphConfig config; + config.add_input_stream("INPUT:stream_0"); + config.add_output_stream(absl::StrCat("OUTPUT:stream_", chain_length)); + for (int i = 0; i < chain_length; ++i) { + CalculatorGraphConfig::Node* node = config.add_node(); + node->set_calculator(node_type); + node->add_input_stream(absl::StrCat("stream_", i)); + node->add_output_stream(absl::StrCat("stream_", i + 1)); + } + return config; + } +}; +REGISTER_MEDIAPIPE_GRAPH(NodeChainSubgraph); + +// A subgraph used in the ExecutorFieldOfNodeInSubgraphPreserved test. The +// subgraph contains a node with the executor field "custom_thread_pool". +class NodeWithExecutorSubgraph : public Subgraph { + public: + ::mediapipe::StatusOr GetConfig( + const SubgraphOptions& options) override { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "INPUT:foo" + output_stream: "OUTPUT:bar" + node { + calculator: "PassThroughCalculator" + input_stream: "foo" + output_stream: "bar" + executor: "custom_thread_pool" + } + )"); + return config; + } +}; +REGISTER_MEDIAPIPE_GRAPH(NodeWithExecutorSubgraph); + +// A subgraph used in the ExecutorFieldOfNodeInSubgraphPreserved test. The +// subgraph contains a NodeWithExecutorSubgraph. +class EnclosingSubgraph : public Subgraph { + public: + ::mediapipe::StatusOr GetConfig( + const SubgraphOptions& options) override { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "IN:in" + output_stream: "OUT:out" + node { + calculator: "NodeWithExecutorSubgraph" + input_stream: "INPUT:in" + output_stream: "OUTPUT:out" + } + )"); + return config; + } +}; +REGISTER_MEDIAPIPE_GRAPH(EnclosingSubgraph); + +TEST(SubgraphExpansionTest, TransformStreamNames) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: "SomeSinkCalculator" + input_stream: "input_1" + input_stream: "VIDEO:input_2" + input_stream: "AUDIO:0:input_3" + input_stream: "AUDIO:1:input_4" + } + )"); + CalculatorGraphConfig expected_config = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: "SomeSinkCalculator" + input_stream: "input_1_foo" + input_stream: "VIDEO:input_2_foo" + input_stream: "AUDIO:0:input_3_foo" + input_stream: "AUDIO:1:input_4_foo" + } + )"); + auto add_foo = [](absl::string_view s) { return absl::StrCat(s, "_foo"); }; + MEDIAPIPE_EXPECT_OK(tool::TransformStreamNames( + (*config.mutable_node())[0].mutable_input_stream(), add_foo)); + EXPECT_THAT(config, mediapipe::EqualsProto(expected_config)); +} + +TEST(SubgraphExpansionTest, TransformNames) { + CalculatorGraphConfig config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "input_1" + node { + calculator: "SomeRegularCalculator" + name: "bob" + input_stream: "input_1" + input_stream: "VIDEO:input_2" + input_stream: "AUDIO:0:input_3" + input_stream: "AUDIO:1:input_4" + output_stream: "output_1" + } + node { + calculator: "SomeRegularCalculator" + input_stream: "output_1" + output_stream: "output_2" + } + )"); + CalculatorGraphConfig expected_config = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "__sg0_input_1" + node { + 
calculator: "SomeRegularCalculator" + name: "__sg0_bob" + input_stream: "__sg0_input_1" + input_stream: "VIDEO:__sg0_input_2" + input_stream: "AUDIO:0:__sg0_input_3" + input_stream: "AUDIO:1:__sg0_input_4" + output_stream: "__sg0_output_1" + } + node { + calculator: "SomeRegularCalculator" + input_stream: "__sg0_output_1" + output_stream: "__sg0_output_2" + } + )"); + auto add_prefix = [](absl::string_view s) { + return absl::StrCat("__sg0_", s); + }; + MEDIAPIPE_EXPECT_OK(tool::TransformNames(&config, add_prefix)); + EXPECT_THAT(config, mediapipe::EqualsProto(expected_config)); +} + +TEST(SubgraphExpansionTest, FindCorrespondingStreams) { + CalculatorGraphConfig config1 = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "input_1" + input_stream: "VIDEO:input_2" + input_stream: "AUDIO:0:input_3" + input_stream: "AUDIO:1:input_4" + )"); + CalculatorGraphConfig config2 = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: "SomeSubgraph" + input_stream: "foo" + input_stream: "VIDEO:bar" + input_stream: "AUDIO:0:baz" + input_stream: "AUDIO:1:qux" + } + )"); + std::map stream_map; + MEDIAPIPE_EXPECT_OK(tool::FindCorrespondingStreams( + &stream_map, config1.input_stream(), config2.node()[0].input_stream())); + EXPECT_THAT(stream_map, + testing::UnorderedElementsAre(testing::Pair("input_1", "foo"), + testing::Pair("input_2", "bar"), + testing::Pair("input_3", "baz"), + testing::Pair("input_4", "qux"))); +} + +TEST(SubgraphExpansionTest, FindCorrespondingStreamsNonexistentTag) { + // The VIDEO tag does not exist in the subgraph. + CalculatorGraphConfig config1 = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "input_1" + input_stream: "AUDIO:0:input_3" + input_stream: "AUDIO:1:input_4" + )"); + CalculatorGraphConfig config2 = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: "SomeSubgraph" + input_stream: "foo" + input_stream: "VIDEO:bar" + input_stream: "AUDIO:0:baz" + input_stream: "AUDIO:1:qux" + } + )"); + std::map stream_map; + auto status = tool::FindCorrespondingStreams( + &stream_map, config1.input_stream(), config2.node()[0].input_stream()); + EXPECT_THAT(status.message(), + + testing::AllOf( + // Problematic tag. + testing::HasSubstr("VIDEO"), + // Error. + testing::HasSubstr("does not exist"))); +} + +TEST(SubgraphExpansionTest, FindCorrespondingStreamsTooFewIndexes) { + // The AUDIO tag has too few indexes in the subgraph (1 vs. 2). + CalculatorGraphConfig config1 = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "input_1" + input_stream: "VIDEO:input_2" + input_stream: "AUDIO:0:input_3" + )"); + CalculatorGraphConfig config2 = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: "SomeSubgraph" + input_stream: "foo" + input_stream: "VIDEO:bar" + input_stream: "AUDIO:0:baz" + input_stream: "AUDIO:1:qux" + } + )"); + std::map stream_map; + auto status = tool::FindCorrespondingStreams( + &stream_map, config1.input_stream(), config2.node()[0].input_stream()); + + EXPECT_THAT(status.message(), + testing::AllOf( + // Problematic tag. + testing::HasSubstr("AUDIO"), + // Error. 
+ testing::HasSubstr(" 2 "), testing::HasSubstr(" 1 "))); +} + +TEST(SubgraphExpansionTest, ConnectSubgraphStreams) { + CalculatorGraphConfig subgraph = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "A:input_1" + input_stream: "B:input_2" + output_stream: "O:output_2" + input_side_packet: "SI:side_input" + output_side_packet: "SO:side_output" + node { + calculator: "SomeRegularCalculator" + input_stream: "input_1" + input_stream: "VIDEO:input_2" + input_side_packet: "side_input" + output_stream: "output_1" + } + node { + calculator: "SomeRegularCalculator" + input_stream: "input_1" + input_stream: "output_1" + output_stream: "output_2" + } + packet_generator { + packet_generator: "SomeGenerator" + input_side_packet: "side_input" + output_side_packet: "side_output" + } + )"); + CalculatorGraphConfig supergraph = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + calculator: "SomeSubgraph" + input_stream: "A:foo" + input_stream: "B:bar" + output_stream: "O:foobar" + input_side_packet: "SI:flip" + output_side_packet: "SO:flop" + } + )"); + // Note: graph input streams, output streams, and side packets on the + // subgraph are not changed because they are going to be discarded anyway. + CalculatorGraphConfig expected_subgraph = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "A:input_1" + input_stream: "B:input_2" + output_stream: "O:output_2" + input_side_packet: "SI:side_input" + output_side_packet: "SO:side_output" + node { + calculator: "SomeRegularCalculator" + input_stream: "foo" + input_stream: "VIDEO:bar" + input_side_packet: "flip" + output_stream: "output_1" + } + node { + calculator: "SomeRegularCalculator" + input_stream: "foo" + input_stream: "output_1" + output_stream: "foobar" + } + packet_generator { + packet_generator: "SomeGenerator" + input_side_packet: "flip" + output_side_packet: "flop" + } + )"); + MEDIAPIPE_EXPECT_OK( + tool::ConnectSubgraphStreams(supergraph.node()[0], &subgraph)); + EXPECT_THAT(subgraph, mediapipe::EqualsProto(expected_subgraph)); +} + +TEST(SubgraphExpansionTest, ExpandSubgraphs) { + CalculatorGraphConfig supergraph = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + name: "simple_source" + calculator: "SomeSourceCalculator" + output_stream: "foo" + } + node { calculator: "TestSubgraph" input_stream: "DATA:foo" } + )"); + CalculatorGraphConfig expected_graph = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + name: "simple_source" + calculator: "SomeSourceCalculator" + output_stream: "foo" + } + node { + name: "__sg0_regular_node" + calculator: "SomeRegularCalculator" + input_stream: "foo" + output_stream: "__sg0_stream_a" + input_side_packet: "__sg0_side" + } + node { + name: "__sg0_simple_sink" + calculator: "SomeSinkCalculator" + input_stream: "__sg0_stream_a" + } + packet_generator { + packet_generator: "SomePacketGenerator" + output_side_packet: "__sg0_side" + } + )"); + MEDIAPIPE_EXPECT_OK(tool::ExpandSubgraphs(&supergraph)); + EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph)); +} + +TEST(SubgraphExpansionTest, ValidateSubgraphFields) { + CalculatorGraphConfig supergraph = + ::mediapipe::ParseTextProtoOrDie(R"( + node { + name: "simple_source" + calculator: "SomeSourceCalculator" + output_stream: "foo" + } + node { + name: "foo_subgraph" + calculator: "TestSubgraph" + input_stream: "DATA:foo" + buffer_size_hint: -1 # This field is only applicable to calculators. 
+ } + )"); + ::mediapipe::Status s1 = tool::ValidateSubgraphFields(supergraph.node(1)); + EXPECT_EQ(s1.code(), ::mediapipe::StatusCode::kInvalidArgument); + EXPECT_THAT(s1.message(), testing::HasSubstr("foo_subgraph")); + + ::mediapipe::Status s2 = tool::ExpandSubgraphs(&supergraph); + EXPECT_EQ(s2.code(), ::mediapipe::StatusCode::kInvalidArgument); + EXPECT_THAT(s2.message(), testing::HasSubstr("foo_subgraph")); +} + +// A test that captures the use case of CL 191001940. The "executor" field of +// a node inside a subgraph should be preserved, not mapped or mangled. This +// test will help us detect breakage of this use case when we implement +// subgraph executor support in the future. +TEST(SubgraphExpansionTest, ExecutorFieldOfNodeInSubgraphPreserved) { + CalculatorGraphConfig supergraph = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "input" + executor { + name: "custom_thread_pool" + type: "ThreadPoolExecutor" + options { + [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 4 } + } + } + node { + calculator: "EnclosingSubgraph" + input_stream: "IN:input" + output_stream: "OUT:output" + } + )"); + CalculatorGraphConfig expected_graph = + ::mediapipe::ParseTextProtoOrDie(R"( + input_stream: "input" + executor { + name: "custom_thread_pool" + type: "ThreadPoolExecutor" + options { + [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 4 } + } + } + node { + calculator: "PassThroughCalculator" + input_stream: "input" + output_stream: "output" + executor: "custom_thread_pool" + } + )"); + MEDIAPIPE_EXPECT_OK(tool::ExpandSubgraphs(&supergraph)); + EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph)); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/framework/tool/testdata/BUILD b/mediapipe/framework/tool/testdata/BUILD new file mode 100644 index 000000000..fd2c8ae5c --- /dev/null +++ b/mediapipe/framework/tool/testdata/BUILD @@ -0,0 +1,57 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//mediapipe:__subpackages__"]) + +load( + "//mediapipe/framework/tool:mediapipe_graph.bzl", + "mediapipe_simple_subgraph", +) + +filegroup( + name = "test_graph", + srcs = ["test.pbtxt"], +) + +exports_files([ + "test.pbtxt", + "dub_quad_test_subgraph.pbtxt", + "nested_test_subgraph.pbtxt", +]) + +mediapipe_simple_subgraph( + name = "dub_quad_test_subgraph", + testonly = 1, + graph = "dub_quad_test_subgraph.pbtxt", + register_as = "DubQuadTestSubgraph", + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:test_calculators", + ], +) + +mediapipe_simple_subgraph( + name = "nested_test_subgraph", + testonly = 1, + graph = "nested_test_subgraph.pbtxt", + register_as = "NestedTestSubgraph", + visibility = ["//visibility:public"], + deps = [ + ":dub_quad_test_subgraph", + "//mediapipe/framework:test_calculators", + ], +) diff --git a/mediapipe/framework/tool/testdata/dub_quad_test_subgraph.pbtxt b/mediapipe/framework/tool/testdata/dub_quad_test_subgraph.pbtxt new file mode 100644 index 000000000..d3a575411 --- /dev/null +++ b/mediapipe/framework/tool/testdata/dub_quad_test_subgraph.pbtxt @@ -0,0 +1,13 @@ +input_stream: "INTS:ints" +output_stream: "DUBS:doubled" +output_stream: "QUADS:quadrupled" +node { + calculator: "DoubleIntCalculator" + input_stream: "ints" + output_stream: "doubled" +} +node { + calculator: "DoubleIntCalculator" + input_stream: "doubled" + output_stream: "quadrupled" +} diff --git a/mediapipe/framework/tool/testdata/nested_test_subgraph.pbtxt b/mediapipe/framework/tool/testdata/nested_test_subgraph.pbtxt new file mode 100644 index 000000000..00e1bed35 --- /dev/null +++ b/mediapipe/framework/tool/testdata/nested_test_subgraph.pbtxt @@ -0,0 +1,20 @@ +input_stream: "INTS:ints" +output_stream: "DUBS:doubled" +output_stream: "QUADS:quadrupled" +output_stream: "OCTS:octupled" +node { + calculator: "DubQuadTestSubgraph" + input_stream: "INTS:ints" + output_stream: "DUBS:doubled" + output_stream: "QUADS:quadrupled" +} +node { + calculator: "DoubleIntCalculator" + input_stream: "quadrupled" + output_stream: "octupled" + # The following is to ensure we handle NULs correctly. + input_stream_info { + tag_index: ":0" # 'quadrupled' + back_edge: false # The false boolean value is encoded as a zero byte. + } +} diff --git a/mediapipe/framework/tool/testdata/test.pbtxt b/mediapipe/framework/tool/testdata/test.pbtxt new file mode 100644 index 000000000..4d3c96e46 --- /dev/null +++ b/mediapipe/framework/tool/testdata/test.pbtxt @@ -0,0 +1,5 @@ +node { + calculator: "PassThroughCalculator" + input_stream: "in" + output_stream: "out" +} diff --git a/mediapipe/framework/tool/text_to_binary_graph.cc b/mediapipe/framework/tool/text_to_binary_graph.cc new file mode 100644 index 000000000..b2c96d790 --- /dev/null +++ b/mediapipe/framework/tool/text_to_binary_graph.cc @@ -0,0 +1,111 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// A command line utility to parse a text proto and output a binary proto. + +#include + +#include +#include + +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/port/advanced_proto_inc.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/commandlineflags.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" + +DEFINE_string(proto_source, "", + "The template source file containing CalculatorGraphConfig " + "protobuf text with inline template params."); +DEFINE_string( + proto_output, "", + "An output template file in binary CalculatorGraphTemplate form."); + +#define EXIT_IF_ERROR(status) \ + if (!status.ok()) { \ + LOG(ERROR) << status; \ + return EXIT_FAILURE; \ + } + +namespace mediapipe { + +mediapipe::Status ReadProto(proto_ns::io::ZeroCopyInputStream* in, + bool read_text, const std::string& source, + proto_ns::Message* result) { + if (read_text) { + RET_CHECK(proto_ns::TextFormat::Parse(in, result)) + << "could not parse text proto: " << source; + } else { + RET_CHECK(result->ParseFromZeroCopyStream(in)) + << "could not parse binary proto: " << source; + } + return mediapipe::OkStatus(); +} + +mediapipe::Status WriteProto(const proto_ns::Message& message, bool write_text, + const std::string& dest, + proto_ns::io::ZeroCopyOutputStream* out) { + if (write_text) { + RET_CHECK(proto_ns::TextFormat::Print(message, out)) + << "could not write text proto to: " << dest; + } else { + RET_CHECK(message.SerializeToZeroCopyStream(out)) + << "could not write binary proto to: " << dest; + } + return mediapipe::OkStatus(); +} + +// Read a proto from a text or a binary file. +mediapipe::Status ReadFile(const std::string& proto_source, bool read_text, + proto_ns::Message* result) { + std::ifstream ifs(proto_source); + proto_ns::io::IstreamInputStream in(&ifs); + RETURN_IF_ERROR(ReadProto(&in, read_text, proto_source, result)); + return mediapipe::OkStatus(); +} + +// Write a proto to a text or a binary file. +mediapipe::Status WriteFile(const std::string& proto_output, bool write_text, + const proto_ns::Message& message) { + std::ofstream ofs(proto_output, std::ofstream::out | std::ofstream::trunc); + proto_ns::io::OstreamOutputStream out(&ofs); + RETURN_IF_ERROR(WriteProto(message, write_text, proto_output, &out)); + return mediapipe::OkStatus(); +} + +} // namespace mediapipe + +int main(int argc, char** argv) { + google::InitGoogleLogging(argv[0]); + gflags::ParseCommandLineFlags(&argc, &argv, true); + + // Validate command line options. + mediapipe::Status status; + if (FLAGS_proto_source.empty()) { + status.Update( + ::mediapipe::InvalidArgumentError("--proto_source must be specified")); + } + if (FLAGS_proto_output.empty()) { + status.Update( + ::mediapipe::InvalidArgumentError("--proto_output must be specified")); + } + if (!status.ok()) { + return EXIT_FAILURE; + } + mediapipe::CalculatorGraphConfig config; + EXIT_IF_ERROR(mediapipe::ReadFile(FLAGS_proto_source, true, &config)); + EXIT_IF_ERROR(mediapipe::WriteFile(FLAGS_proto_output, false, config)); + return EXIT_SUCCESS; +} diff --git a/mediapipe/framework/transitive_protos.bzl b/mediapipe/framework/transitive_protos.bzl index fe06b4df0..76c04c776 100644 --- a/mediapipe/framework/transitive_protos.bzl +++ b/mediapipe/framework/transitive_protos.bzl @@ -1,117 +1,37 @@ -"""This rule gathers all .proto files used by all of its dependencies. 
+"""Extract a cc_library compatible dependency with only the top level proto rules.""" -The entire dependency tree is searched. The search crosses through cc_library -rules and portable_proto_library rules to collect the transitive set of all -.proto dependencies. This is provided to other rules in the form of a "proto" -provider, using the transitive_sources field. +ProtoLibsInfo = provider(fields = ["targets", "out"]) -This rule uses aspects. For general information on the concept, see: -- go/bazel-aspects-ides-tools -- go/bazel-aspects +def _get_proto_rules(deps, proto_rules = None): + useful_deps = [dep for dep in deps if hasattr(dep, "proto_rules")] + if proto_rules == None: + proto_rules = [] + for dep in useful_deps: + proto_rules = proto_rules + dep.proto_rules + return proto_rules -The basic rule is transitive_protos. Example: +def _proto_rules_aspect_impl(target, ctx): + # Make sure the rule has a srcs attribute. + proto_rules = [] + found_cc_proto = False + if hasattr(ctx.rule.attr, "srcs") and len(ctx.rule.attr.srcs) == 1: + for f in ctx.rule.attr.srcs[0].files.to_list(): + if f.basename.endswith(".pb.cc"): + proto_rules = [target[CcInfo]] + found = True + break -proto_library( - name = "a_proto_library", - srcs = ["a.proto], -) - -proto_library( - name = "b_proto_library", - srcs = ["b.proto], -) - -cc_library( - name = "a_cc_library", - deps = ["b_proto_library], -) - -transitive_protos( - name = "all_my_protos", - deps = [ - "a_proto_library", - "a_cc_library", - ], -) - -all_my_protos will gather all proto files used in its dependency tree; in this -case, ["a.proto", "b.proto"]. These are provided as the default outputs of this -rule, so you can place the rule in any context that requires a list of files, -and also as a "proto" provider, for use by any rules that would normally depend -on proto_library. - -The dependency tree is explored using an aspect, transitive_protos_aspect. This -aspect propagates across two attributes, "deps" and "hdrs". The latter is used -for compatibility with portable_proto_library; see comments below and in that -file for more details. - -At each visited node in the tree, the aspect collects protos: -- direct_sources from the proto provider in the current node. This is filled in - by proto_library nodes, and also by piggyback_header nodes (see - portable_proto_build_defs.bzl). -- protos from the transitive_protos provider in dependency nodes, found from - both the "deps" and the "hdrs" aspect. -Then it puts all the protos in the protos field of the transitive_protos -provider which it generates. This is how each node sends its gathered protos up -the tree. 
-""" - -def _gather_transitive_protos_deps(deps, my_protos = [], my_descriptors = [], my_proto_libs = []): - useful_deps = [dep for dep in deps if hasattr(dep, "transitive_protos")] - protos = depset( - my_protos, - transitive = [dep.transitive_protos.protos for dep in useful_deps], - ) - proto_libs = depset( - my_proto_libs, - transitive = [dep.transitive_protos.proto_libs for dep in useful_deps], - ) - descriptors = depset( - my_descriptors, - transitive = [dep.transitive_protos.descriptors for dep in useful_deps], - ) + if not found_cc_proto: + deps = ctx.rule.attr.deps[:] if hasattr(ctx.rule.attr, "deps") else [] + proto_rules = _get_proto_rules(deps, proto_rules) return struct( - transitive_protos = struct( - protos = protos, - descriptors = descriptors, - proto_libs = proto_libs, - ), + proto_rules = proto_rules, ) -def _transitive_protos_aspect_impl(target, ctx): - """Implementation of the transitive_protos_aspect aspect. - - Args: - target: The current target. - ctx: The current rule context. - Returns: - A transitive_protos provider. - """ - protos = target.proto.direct_sources if hasattr(target, "proto") else [] - deps = ctx.rule.attr.deps[:] if hasattr(ctx.rule.attr, "deps") else [] - descriptors = [target.proto.direct_descriptor_set] if hasattr(target, "proto") and hasattr(target.proto, "direct_descriptor_set") else [] - - proto_libs = [] - if ctx.rule.kind == "proto_library": - proto_libs = [f for f in target.files.to_list() if f.extension == "a"] - - # Searching through the hdrs attribute is necessary because of - # portable_proto_library. In portable mode, that macro - # generates a cc_library that does not depend on any proto_libraries, so - # the .proto files do not appear in its dependency tree. - # portable_proto_library cannot add arbitrary providers or attributes to - # a cc_library rule, so instead it piggybacks the provider on a rule that - # generates a header, which occurs in the hdrs attribute of the cc_library. - if hasattr(ctx.rule.attr, "hdrs"): - deps += ctx.rule.attr.hdrs - result = _gather_transitive_protos_deps(deps, protos, descriptors, proto_libs) - return result - -transitive_protos_aspect = aspect( - implementation = _transitive_protos_aspect_impl, - attr_aspects = ["deps", "hdrs"], - attrs = {}, +proto_rules_aspect = aspect( + implementation = _proto_rules_aspect_impl, + attr_aspects = ["deps"], ) def _transitive_protos_impl(ctx): @@ -123,72 +43,19 @@ def _transitive_protos_impl(ctx): A proto provider (with transitive_sources and transitive_descriptor_sets filled in), and marks all transitive sources as default output. """ - gathered = _gather_transitive_protos_deps(ctx.attr.deps) - protos = gathered.transitive_protos.protos - descriptors = gathered.transitive_protos.descriptors - return struct( - proto = struct( - transitive_sources = protos, - transitive_descriptor_sets = descriptors, - ), - files = depset(protos), - ) + cc_infos = [] + for dep in ctx.attr.deps: + for dep_proto_rule in dep.proto_rules: + cc_infos.append(dep_proto_rule) + return [cc_common.merge_cc_infos(cc_infos = cc_infos)] transitive_protos = rule( implementation = _transitive_protos_impl, - attrs = { - "deps": attr.label_list( - aspects = [transitive_protos_aspect], - ), - }, -) - -def _transitive_proto_cc_libs_impl(ctx): - """Implementation of transitive_proto_cc_libs rule. - - NOTE: this only works on Bazel, not exobazel. - - Args: - ctx: The rule context. - - Returns: - All transitive proto C++ .a files as default output. 
- """ - gathered = _gather_transitive_protos_deps(ctx.attr.deps) - proto_libs = gathered.transitive_protos.proto_libs - return struct( - files = proto_libs, - ) - -transitive_proto_cc_libs = rule( - implementation = _transitive_proto_cc_libs_impl, - attrs = { - "deps": attr.label_list( - aspects = [transitive_protos_aspect], - ), - }, -) - -def _transitive_proto_descriptor_sets_impl(ctx): - """Implementation of transitive_proto_descriptor_sets rule. - - Args: - ctx: The rule context. - - Returns: - All transitive proto descriptor files as default output. - """ - gathered = _gather_transitive_protos_deps(ctx.attr.deps) - descriptors = gathered.transitive_protos.descriptors - return struct( - files = descriptors, - ) - -transitive_proto_descriptor_sets = rule( - implementation = _transitive_proto_descriptor_sets_impl, - attrs = { - "deps": attr.label_list( - aspects = [transitive_protos_aspect], - ), - }, + attrs = + { + "deps": attr.label_list( + aspects = [proto_rules_aspect], + ), + }, + provides = [CcInfo], ) diff --git a/mediapipe/framework/validated_graph_config.cc b/mediapipe/framework/validated_graph_config.cc index 24677b149..414ed87a2 100644 --- a/mediapipe/framework/validated_graph_config.cc +++ b/mediapipe/framework/validated_graph_config.cc @@ -218,6 +218,8 @@ std::string NodeTypeInfo::NodeTypeToString(NodeType node_type) { node_.type = NodeType::CALCULATOR; node_.index = node_index; RETURN_IF_ERROR(contract_.Initialize(node)); + contract_.SetNodeName( + CanonicalNodeName(validated_graph.Config(), node_index)); // Ensure input_stream_info field is well formed. if (!node.input_stream_info().empty()) { diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index eb9aea6ee..eea8f0ed5 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -16,6 +16,8 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) +load("//mediapipe/gpu:metal.bzl", "metal_library") +load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") # Disabling GPU support is sometimes useful on desktop Linux because SwiftShader can @@ -43,28 +45,61 @@ cc_library( deps = [":gpu_service"], ) +GL_BASE_LINK_OPTS = select({ + "//conditions:default": [], + "//mediapipe:android": [ + "-lGLESv2", + "-lEGL", + # Note: on Android, libGLESv3.so is normally a symlink to + # libGLESv2.so, so we don't need to link to it. In fact, we + # do not _want_ to link to it, or we would be unable to load + # on API level < 18, where the symlink is missing entirely. + # Note: if we ever find a strange version of Android where the + # GLESv3 library is not a symlink, we will have to load it at + # runtime. Weak GLESv3 symbols will still be resolved if we + # load it early enough. + ], + "//mediapipe:apple": [ + "-framework OpenGLES", + "-framework CoreVideo", + ], + "//mediapipe:macos": [ + "-framework OpenGL", + "-framework CoreVideo", + ], +}) + +# This is @unused internally. +GL_BASE_LINK_OPTS_OSS = GL_BASE_LINK_OPTS + select({ + "//conditions:default": [ + # Use GLES/EGL on linux. + # Requires support from graphics card driver (nvidia,mesa,etc..) + # and libraries to be installed. + # Ex: libegl1-mesa-dev libgles2-mesa-dev, or libegl1-nvidia libgles2-nvidia, etc... 
+ "-lGLESv2", + "-lEGL", + ], + "//mediapipe:android": [], + "//mediapipe:apple": [], + "//mediapipe:macos": [], + ":disable_gpu": [], +}) + cc_library( name = "gl_base", - features = ["-layering_check"], - linkopts = select({ - "//conditions:default": [], - "//mediapipe:android": [ - "-lGLESv2", - "-lEGL", - # Note: on Android, libGLESv3.so is normally a symlink to - # libGLESv2.so, so we don't need to link to it. In fact, we - # do not _want_ to link to it, or we would be unable to load - # on API level < 18, where the symlink is missing entirely. - # Note: if we ever find a strange version of Android where the - # GLESv3 library is not a symlink, we will have to load it at - # runtime. Weak GLESv3 symbols will still be resolved if we - # load it early enough. + defines = select({ + "//mediapipe:apple": [ + "GLES_SILENCE_DEPRECATION=1", ], + "//conditions:default": [], }), + features = ["-layering_check"], + linkopts = GL_BASE_LINK_OPTS_OSS, textual_hdrs = ["gl_base.h"], visibility = ["//visibility:public"], deps = [":gl_base_hdr"] + select({ "//mediapipe:android": [], + "//mediapipe:apple": [], "//conditions:default": [ ], }), @@ -75,9 +110,21 @@ cc_library( hdrs = ["gl_base.h"], features = ["-layering_check"], # Note: need the frameworks on Apple platforms to get the headers. + linkopts = select({ + "//conditions:default": [], + "//mediapipe:apple": [ + "-framework OpenGLES", + "-framework CoreVideo", + ], + "//mediapipe:macos": [ + "-framework OpenGL", + "-framework CoreVideo", + ], + }), visibility = ["//visibility:public"], deps = select({ "//mediapipe:android": [], + "//mediapipe:apple": [], "//conditions:default": [ ], }), @@ -101,15 +148,27 @@ cc_library( "//conditions:default": [ "gl_context_egl.cc", ], + "//mediapipe:apple": [ + "gl_context_eagl.cc", + ], + "//mediapipe:macos": [ + "gl_context_nsgl.cc", + ], }), hdrs = ["gl_context.h"], + copts = select({ + "//conditions:default": [], + "//mediapipe:apple": [ + "-x objective-c++", + "-std=c++11", + "-fobjc-arc", + ], + }), visibility = ["//visibility:public"], deps = [ ":gl_base", ":gl_thread_collector", "//mediapipe/framework:executor", - "//mediapipe/framework:mediapipe_profiling", - "//mediapipe/framework:timestamp", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -119,7 +178,14 @@ cc_library( "@com_google_absl//absl/debugging:leak_check", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", - ], + "//mediapipe/framework:mediapipe_profiling", + "//mediapipe/framework:timestamp", + ] + select({ + "//conditions:default": [], + "//mediapipe:apple": [ + "//mediapipe/objc:CFHolder", + ], + }), ) cc_library( @@ -150,6 +216,13 @@ cc_library( "//conditions:default": [ ":gl_texture_buffer", ], + "//mediapipe:apple": [ + "//mediapipe/objc:CFHolder", + ], + "//mediapipe:macos": [ + "//mediapipe/objc:CFHolder", + ":gl_texture_buffer", + ], }), ) @@ -167,6 +240,58 @@ cc_library( ], ) +objc_library( + name = "pixel_buffer_pool_util", + srcs = ["pixel_buffer_pool_util.mm"], + hdrs = ["pixel_buffer_pool_util.h"], + copts = [ + "-std=c++11", + "-Wno-shorten-64-to-32", + ], + sdk_frameworks = [ + "CoreVideo", + ], + visibility = ["//visibility:public"], +) + +objc_library( + name = "MPPGraphGPUData", + srcs = [ + "MPPGraphGPUData.mm", + "gpu_shared_data_internal.cc", + ], + hdrs = ["MPPGraphGPUData.h"], + copts = [ + "-x objective-c++", + "-std=c++11", + "-Wno-shorten-64-to-32", + ], + sdk_frameworks = [ + "CoreVideo", + "Metal", + ] + select({ + 
"//conditions:default": [ + "OpenGLES", + ], + "//mediapipe:macos": [ + "OpenGL", + "AppKit", + ], + }), + visibility = ["//visibility:public"], + deps = [ + ":gl_base", + ":gl_context", + ":gpu_buffer_multi_pool", + ":gpu_shared_data_header", + ":graph_support", + "//mediapipe/framework:calculator_context", + "//mediapipe/framework/port:ret_check", + "//mediapipe/gpu:gl_context_options_cc_proto", + "@google_toolbox_for_mac//:GTM_Defines", + ], +) + proto_library( name = "gl_context_options_proto", srcs = ["gl_context_options.proto"], @@ -188,7 +313,12 @@ cc_library( name = "gpu_shared_data_header", textual_hdrs = [ "gpu_shared_data_internal.h", - ], + ] + select({ + "//conditions:default": [], + "//mediapipe:apple": [ + "MPPGraphGPUData.h", + ], + }), visibility = ["//visibility:private"], deps = [ ":gl_base", @@ -196,23 +326,46 @@ cc_library( ], ) -cc_library( +alias( name = "gpu_shared_data_internal", + actual = select({ + "//conditions:default": ":gpu_shared_data_internal_actual", + ":disable_gpu": ":gpu_shared_data_internal_stub", + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "gpu_shared_data_internal_stub", + hdrs = [ + "gpu_shared_data_internal.h", + ], + defines = ["MEDIAPIPE_DISABLE_GPU"], + visibility = ["//visibility:private"], + deps = [ + ":graph_support", + "//mediapipe/framework:calculator_context", + "//mediapipe/framework:calculator_node", + "//mediapipe/framework:executor", + "//mediapipe/framework/deps:no_destructor", + "//mediapipe/framework/port:ret_check", + "//mediapipe/gpu:gl_context_options_cc_proto", + ], +) + +cc_library( + name = "gpu_shared_data_internal_actual", srcs = select({ "//conditions:default": [ "gpu_shared_data_internal.cc", ], - # iOS uses an Objective-C++ version of this, built in MediaPipeGraphGPUData. - ":disable_gpu": [], + # iOS uses an Objective-C++ version of this, built in MPPGraphGPUData. + "//mediapipe:apple": [], }), hdrs = [ "gpu_shared_data_internal.h", ], - defines = select({ - "//conditions:default": [], - ":disable_gpu": ["MEDIAPIPE_DISABLE_GPU"], - }), - visibility = ["//visibility:public"], + visibility = ["//visibility:private"], deps = [ "//mediapipe/gpu:gl_context_options_cc_proto", ":graph_support", @@ -221,17 +374,15 @@ cc_library( "//mediapipe/framework:calculator_node", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/deps:no_destructor", + ":gl_base", + ":gl_context", + ":gpu_buffer_multi_pool", + ":gpu_shared_data_header", ] + select({ - "//conditions:default": [ - ":gl_base", - ":gl_context", - ":gpu_buffer_multi_pool", - ":gpu_shared_data_header", - ], - ":disable_gpu": [], - }) + select({ "//conditions:default": [], - ":disable_gpu": [], + "//mediapipe:apple": [ + ":MPPGraphGPUData", + ], }), ) @@ -241,11 +392,24 @@ cc_library( "//conditions:default": [ "gl_texture_buffer_pool.cc", ], + "//mediapipe:apple": [], + "//mediapipe:macos": [ + "gl_texture_buffer_pool.cc", + ], }), hdrs = ["gpu_buffer_multi_pool.h"] + select({ "//conditions:default": [ "gl_texture_buffer_pool.h", ], + "//mediapipe:apple": [ + # The inclusions check does not see that this is provided by + # pixel_buffer_pool_util, so we include it here too. This is + # b/28066691. 
+ "pixel_buffer_pool_util.h", + ], + "//mediapipe:macos": [ + "gl_texture_buffer_pool.h", + ], }), visibility = ["//visibility:public"], deps = [ @@ -261,6 +425,14 @@ cc_library( "//conditions:default": [ ":gl_texture_buffer", ], + "//mediapipe:apple": [ + ":pixel_buffer_pool_util", + "//mediapipe/objc:CFHolder", + ], + "//mediapipe:macos": [ + ":pixel_buffer_pool_util", + ":gl_texture_buffer", + ], }), ) @@ -293,13 +465,37 @@ HELPER_COMMON_HDRS = [ "gl_calculator_helper_impl.h", ] +HELPER_IOS_SRCS = [ + "gl_calculator_helper_impl_ios.mm", + "gl_calculator_helper_impl_common.cc", +] + +HELPER_IOS_FRAMEWORKS = [ + "AVFoundation", + "CoreVideo", + "CoreGraphics", + "CoreMedia", + "GLKit", + "QuartzCore", +] + select({ + "//conditions:default": [ + "OpenGLES", + ], + "//mediapipe:macos": [ + "OpenGL", + "AppKit", + ], +}) + cc_library( name = "gl_calculator_helper", srcs = select({ "//conditions:default": HELPER_COMMON_SRCS + HELPER_ANDROID_SRCS, + "//mediapipe:apple": [], }), hdrs = HELPER_COMMON_HDRS + select({ "//conditions:default": HELPER_ANDROID_HDRS, + "//mediapipe:apple": [], }), visibility = ["//visibility:public"], deps = [ @@ -311,6 +507,7 @@ cc_library( ":gpu_service", ":graph_support", ":shader_util", + "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", @@ -332,9 +529,58 @@ cc_library( ] + select({ "//conditions:default": [ ], + "//mediapipe:apple": [ + ":gl_calculator_helper_ios", + "//mediapipe/objc:util", + ], }), ) +objc_library( + name = "gl_calculator_helper_ios", + srcs = HELPER_COMMON_SRCS + HELPER_IOS_SRCS, + hdrs = HELPER_COMMON_HDRS, + copts = [ + "-std=c++11", + "-Wno-shorten-64-to-32", + ], + sdk_frameworks = HELPER_IOS_FRAMEWORKS, + visibility = ["//visibility:public"], + deps = [ + ":gl_base", + ":gl_context", + ":gpu_buffer", + ":gpu_buffer_multi_pool", + ":gpu_service", + ":gpu_shared_data_internal", + ":shader_util", + "//mediapipe/framework:calculator_framework", + "//mediapipe/objc:mediapipe_framework_ios", + "//mediapipe/objc:util", + ], +) + +objc_library( + name = "MPPMetalHelper", + srcs = ["MPPMetalHelper.mm"], + hdrs = ["MPPMetalHelper.h"], + copts = [ + "-std=c++11", + "-Wno-shorten-64-to-32", + ], + sdk_frameworks = [ + "CoreVideo", + "Metal", + ], + visibility = ["//visibility:public"], + deps = [ + ":gpu_shared_data_internal", + ":graph_support", + "//mediapipe/objc:mediapipe_framework_ios", + "@google_toolbox_for_mac//:GTM_Defines", + ], +) + proto_library( name = "scale_mode_proto", srcs = ["scale_mode.proto"], @@ -400,7 +646,10 @@ cc_library( "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - ], + ] + select({ + "//conditions:default": [], + "//mediapipe:apple": ["//mediapipe/objc:util"], + }), alwayslink = 1, ) @@ -452,6 +701,7 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:options_util", "//mediapipe/gpu:gl_scaler_calculator_cc_proto", ], alwayslink = 1, @@ -493,3 +743,179 @@ mediapipe_cc_proto_library( visibility = ["//visibility:public"], deps = [":gl_surface_sink_calculator_proto"], ) + +### Metal calculators + +metal_library( + name = "simple_shaders_mtl", + srcs = ["simple_shaders.metal"], + hdrs = ["metal_shader_base.h"], +) + +# Only needed for cc_library depending on simple_shaders_mtl. 
+objc_library( + name = "simple_shaders_for_cc", + hdrs = ["metal_shader_base.h"], + deps = [":simple_shaders_mtl"], +) + +proto_library( + name = "copy_calculator_proto", + srcs = ["copy_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +objc_library( + name = "metal_copy_calculator", + srcs = ["MetalCopyCalculator.mm"], + copts = ["-std=c++11"], + sdk_frameworks = [ + "CoreVideo", + "Metal", + ], + visibility = ["//visibility:public"], + deps = [ + ":MPPMetalHelper", + ":simple_shaders_mtl", + "//mediapipe/gpu:copy_calculator_cc_proto", + "//mediapipe/objc:mediapipe_framework_ios", + ], + alwayslink = 1, +) + +objc_library( + name = "metal_rgb_weight_calculator", + srcs = ["MetalRgbWeightCalculator.mm"], + copts = ["-std=c++11"], + sdk_frameworks = [ + "CoreVideo", + "Metal", + ], + visibility = ["//visibility:public"], + deps = [ + ":MPPMetalHelper", + ":simple_shaders_mtl", + "//mediapipe/objc:mediapipe_framework_ios", + ], + alwayslink = 1, +) + +objc_library( + name = "metal_sobel_calculator", + srcs = ["MetalSobelCalculator.mm"], + copts = ["-std=c++11"], + sdk_frameworks = [ + "CoreVideo", + "Metal", + ], + visibility = ["//visibility:public"], + deps = [ + ":MPPMetalHelper", + ":simple_shaders_mtl", + "//mediapipe/objc:mediapipe_framework_ios", + ], + alwayslink = 1, +) + +objc_library( + name = "metal_sobel_compute_calculator", + srcs = ["MetalSobelComputeCalculator.mm"], + copts = ["-std=c++11"], + sdk_frameworks = [ + "CoreVideo", + "Metal", + ], + visibility = ["//visibility:public"], + deps = [ + ":MPPMetalHelper", + ":simple_shaders_mtl", + "//mediapipe/objc:mediapipe_framework_ios", + ], + alwayslink = 1, +) + +objc_library( + name = "mps_sobel_calculator", + srcs = ["MPSSobelCalculator.mm"], + copts = ["-std=c++11"], + sdk_frameworks = [ + "CoreVideo", + "Metal", + "MetalPerformanceShaders", + ], + visibility = ["//visibility:public"], + deps = [ + ":MPPMetalHelper", + "//mediapipe/objc:mediapipe_framework_ios", + ], + alwayslink = 1, +) + +### Tests + +cc_library( + name = "gpu_test_base", + testonly = 1, + hdrs = ["gpu_test_base.h"], + deps = [ + ":gl_calculator_helper", + ":gpu_shared_data_internal", + "//testing/base/public:gunit_for_library_testonly", + ], +) + +MIN_IOS_VERSION = "9.0" # For thread_local. 
+ +test_suite( + name = "ios", + tags = ["ios"], +) + +test_suite( + name = "metal", + tags = ["metal"], +) + +objc_library( + name = "gl_ios_test_lib", + testonly = 1, + srcs = [ + "MPPGraphGPUDataTests.mm", + "gl_ios_test.mm", + ], + copts = [ + "-std=c++11", + "-Wno-shorten-64-to-32", + ], + data = [ + "//mediapipe/objc:testdata/googlelogo_color_272x92dp.png", + ], + deps = [ + ":MPPGraphGPUData", + ":gl_scaler_calculator", + ":gpu_buffer_to_image_frame_calculator", + ":gpu_shared_data_internal", + ":image_frame_to_gpu_buffer_calculator", + "//mediapipe/objc:MPPGraphTestBase", + "//mediapipe/objc:mediapipe_framework_ios", + "//mediapipe/framework/tool:source", + "//mediapipe/framework/port:threadpool", + "@com_google_absl//absl/memory", + ] + select({ + "//mediapipe:ios_i386": [], + "//mediapipe:ios_x86_64": [], + "//conditions:default": [ + ":metal_rgb_weight_calculator", + ], + }), +) + +ios_unit_test( + name = "gl_ios_test", + minimum_os_version = MIN_IOS_VERSION, + tags = [ + "ios", + ], + deps = [":gl_ios_test_lib"], +) diff --git a/mediapipe/gpu/MPPGraphGPUData.h b/mediapipe/gpu/MPPGraphGPUData.h new file mode 100644 index 000000000..474502619 --- /dev/null +++ b/mediapipe/gpu/MPPGraphGPUData.h @@ -0,0 +1,71 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_GPU_DRISHTIGRAPHGPUDATA_H_ +#define MEDIAPIPE_GPU_DRISHTIGRAPHGPUDATA_H_ + +#import +#import +#import + +#import "mediapipe/gpu/gl_base.h" +#import "mediapipe/gpu/gl_context.h" + +namespace mediapipe { +class GlContext; +class GpuBufferMultiPool; +} // namespace mediapipe + +@interface MPPGraphGPUData : NSObject { + // Shared buffer pool for GPU calculators. + mediapipe::GpuBufferMultiPool* _gpuBufferPool; + mediapipe::GlContext* _glContext; +} + +- (instancetype)init NS_UNAVAILABLE; + +/// Initialize. The provided multipool pointer must remain valid throughout +/// this object's lifetime. +- (instancetype)initWithContext:(mediapipe::GlContext*)context + multiPool:(mediapipe::GpuBufferMultiPool*)pool NS_DESIGNATED_INITIALIZER; + +/// Shared texture pool for GPU calculators. +/// For internal use by GlCalculatorHelper. +@property(readonly) mediapipe::GpuBufferMultiPool* gpuBufferPool; + +/// Shared OpenGL context. +#if TARGET_OS_OSX +@property(readonly) NSOpenGLContext* glContext; +@property(readonly) NSOpenGLPixelFormat* glPixelFormat; +#else +@property(readonly) EAGLContext* glContext; +#endif // TARGET_OS_OSX + +/// Shared texture cache. +#if TARGET_OS_OSX +@property(readonly) CVOpenGLTextureCacheRef textureCache; +#else +@property(readonly) CVOpenGLESTextureCacheRef textureCache; +#endif // TARGET_OS_OSX + +/// Shared Metal resources. 
+@property(readonly) id mtlDevice; +@property(readonly) id mtlCommandQueue; +#if COREVIDEO_SUPPORTS_METAL +@property(readonly) CVMetalTextureCacheRef mtlTextureCache; +#endif + +@end + +#endif // MEDIAPIPE_GPU_DRISHTIGRAPHGPUDATA_H_ diff --git a/mediapipe/gpu/MPPGraphGPUData.mm b/mediapipe/gpu/MPPGraphGPUData.mm new file mode 100644 index 000000000..4bae8cdc7 --- /dev/null +++ b/mediapipe/gpu/MPPGraphGPUData.mm @@ -0,0 +1,123 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/gpu/MPPGraphGPUData.h" + +#import "GTMDefines.h" + +#include "mediapipe/gpu/gl_context.h" +#include "mediapipe/gpu/gpu_buffer_multi_pool.h" + +#if TARGET_OS_OSX +#import +#else +#import +#endif // TARGET_OS_OSX + +@implementation MPPGraphGPUData + +@synthesize textureCache = _textureCache; +@synthesize mtlDevice = _mtlDevice; +@synthesize mtlCommandQueue = _mtlCommandQueue; +#if COREVIDEO_SUPPORTS_METAL +@synthesize mtlTextureCache = _mtlTextureCache; +#endif + +#if TARGET_OS_OSX +typedef CVOpenGLTextureCacheRef CVTextureCacheType; +#else +typedef CVOpenGLESTextureCacheRef CVTextureCacheType; +#endif // TARGET_OS_OSX + +- (instancetype)initWithContext:(mediapipe::GlContext*)context + multiPool:(mediapipe::GpuBufferMultiPool*)pool { + self = [super init]; + if (self) { + _gpuBufferPool = pool; + _glContext = context; + } + return self; +} + +- (void)dealloc { + if (_textureCache) { + _textureCache = NULL; + } +#if COREVIDEO_SUPPORTS_METAL + if (_mtlTextureCache) { + CFRelease(_mtlTextureCache); + _mtlTextureCache = NULL; + } +#endif +} + +#if TARGET_OS_OSX +- (NSOpenGLContext *)glContext { + return _glContext->nsgl_context(); +} + +- (NSOpenGLPixelFormat *) glPixelFormat { + return _glContext->nsgl_pixel_format(); +} +#else +- (EAGLContext *)glContext { + return _glContext->eagl_context(); +} +#endif // TARGET_OS_OSX + +- (CVTextureCacheType)textureCache { + @synchronized(self) { + if (!_textureCache) { + _textureCache = _glContext->cv_texture_cache(); + } + } + return _textureCache; +} + +- (mediapipe::GpuBufferMultiPool *)gpuBufferPool { + return _gpuBufferPool; +} + +- (id)mtlDevice { + @synchronized(self) { + if (!_mtlDevice) { + _mtlDevice = MTLCreateSystemDefaultDevice(); + } + } + return _mtlDevice; +} + +- (id)mtlCommandQueue { + @synchronized(self) { + if (!_mtlCommandQueue) { + _mtlCommandQueue = [self.mtlDevice newCommandQueue]; + } + } + return _mtlCommandQueue; +} + +#if COREVIDEO_SUPPORTS_METAL +- (CVMetalTextureCacheRef)mtlTextureCache { + @synchronized(self) { + if (!_mtlTextureCache) { + CVReturn err = CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache); + NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d", err); + // TODO: register and flush metal caches too. 
+ } + } + return _mtlTextureCache; +} +#endif + +@end diff --git a/mediapipe/gpu/MPPGraphGPUDataTests.mm b/mediapipe/gpu/MPPGraphGPUDataTests.mm new file mode 100644 index 000000000..e8b50845b --- /dev/null +++ b/mediapipe/gpu/MPPGraphGPUDataTests.mm @@ -0,0 +1,86 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import + +#include + +#include "absl/memory/memory.h" +#include "mediapipe/framework/port/threadpool.h" + +#import "mediapipe/gpu/MPPGraphGPUData.h" +#import "mediapipe/gpu/gpu_shared_data_internal.h" + +@interface MPPGraphGPUDataTests : XCTestCase { +} +@end + +@implementation MPPGraphGPUDataTests + +// This test verifies that the internal Objective-C object is correctly +// released when the C++ wrapper is released. +- (void)testCorrectlyReleased { + __weak id gpuData = nil; + std::weak_ptr gpuRes; + @autoreleasepool { + mediapipe::GpuSharedData gpu_shared; + gpuRes = gpu_shared.gpu_resources; + gpuData = gpu_shared.gpu_resources->ios_gpu_data(); + XCTAssertNotEqual(gpuRes.lock(), nullptr); + XCTAssertNotNil(gpuData); + } + XCTAssertEqual(gpuRes.lock(), nullptr); + XCTAssertNil(gpuData); +} + +// This test verifies that the lazy initialization of the glContext instance +// variable is thread-safe. All threads should read the same value. +- (void)testGlContextThreadSafeLazyInitialization { + mediapipe::GpuSharedData gpu_shared; + constexpr int kNumThreads = 10; + EAGLContext* ogl_context[kNumThreads]; + auto pool = absl::make_unique(kNumThreads); + pool->StartWorkers(); + for (int i = 0; i < kNumThreads; ++i) { + pool->Schedule([&gpu_shared, &ogl_context, i] { + ogl_context[i] = gpu_shared.gpu_resources->ios_gpu_data().glContext; + }); + } + pool.reset(); + for (int i = 0; i < kNumThreads - 1; ++i) { + XCTAssertEqual(ogl_context[i], ogl_context[i + 1]); + } +} + +// This test verifies that the lazy initialization of the textureCache instance +// variable is thread-safe. All threads should read the same value. +- (void)testTextureCacheThreadSafeLazyInitialization { + mediapipe::GpuSharedData gpu_shared; + constexpr int kNumThreads = 10; + CFHolder texture_cache[kNumThreads]; + auto pool = absl::make_unique(kNumThreads); + pool->StartWorkers(); + for (int i = 0; i < kNumThreads; ++i) { + pool->Schedule([&gpu_shared, &texture_cache, i] { + texture_cache[i].reset(gpu_shared.gpu_resources->ios_gpu_data().textureCache); + }); + } + pool.reset(); + for (int i = 0; i < kNumThreads - 1; ++i) { + XCTAssertEqual(*texture_cache[i], *texture_cache[i + 1]); + } +} + +@end diff --git a/mediapipe/gpu/MPPMetalHelper.h b/mediapipe/gpu/MPPMetalHelper.h new file mode 100644 index 000000000..293e9acdc --- /dev/null +++ b/mediapipe/gpu/MPPMetalHelper.h @@ -0,0 +1,105 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_GPU_MEDIAPIPE_METAL_HELPER_H_ +#define MEDIAPIPE_GPU_MEDIAPIPE_METAL_HELPER_H_ + +#import +#import +#import + +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/packet_type.h" +#include "mediapipe/gpu/MPPGraphGPUData.h" +#include "mediapipe/gpu/gpu_shared_data_internal.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPMetalHelper : NSObject { + MPPGraphGPUData* _gpuShared; +} + +- (instancetype)init NS_UNAVAILABLE; + +/// Initialize. This initializer is recommended for calculators. +- (instancetype)initWithCalculatorContext:(mediapipe::CalculatorContext*)cc; + +/// Initialize. +- (instancetype)initWithGpuResources:(mediapipe::GpuResources*)gpuResources + NS_DESIGNATED_INITIALIZER; + +/// Configures a calculator's contract for accessing GPU resources. +/// Calculators should use this in GetContract. ++ (::mediapipe::Status)updateContract:(mediapipe::CalculatorContract*)cc; + +/// Deprecated initializer. +- (instancetype)initWithSidePackets:(const mediapipe::PacketSet&)inputSidePackets; + +/// Deprecated initializer. +- (instancetype)initWithGpuSharedData:(mediapipe::GpuSharedData*)gpuShared; + +/// Configures a calculator's side packets for accessing GPU resources. +/// Calculators should use this in FillExpectations. ++ (::mediapipe::Status)setupInputSidePackets:(mediapipe::PacketTypeSet*)inputSidePackets; + +/// Get a metal command buffer. +/// Calculators should use this method instead of getting a buffer from the +/// MTLCommandQueue directly, not just for convenience, but also because the +/// framework may want to add some custom hooks to the commandBuffers used by +/// calculators. +- (id)commandBuffer; + +/// Creates a CVMetalTextureRef linked to the provided GpuBuffer. +/// Ownership follows the copy rule, so the caller is responsible for +/// releasing the CVMetalTextureRef. +- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer; + +/// Creates a CVMetalTextureRef linked to the provided GpuBuffer given a specific plane. +/// Ownership follows the copy rule, so the caller is responsible for +/// releasing the CVMetalTextureRef. +- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer + plane:(size_t)plane; + +/// Returns a MTLTexture linked to the provided GpuBuffer. +/// A calculator can freely use it as a rendering source, but it should not +/// use it as a rendering target if the GpuBuffer was provided as an input. +- (id)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer; + +/// Returns a MTLTexture linked to the provided GpuBuffer given a specific plane. +/// A calculator can freely use it as a rendering source, but it should not +/// use it as a rendering target if the GpuBuffer was provided as an input. +- (id)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer + plane:(size_t)plane; + +/// Obtains a new GpuBuffer to be used as an output destination. +- (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width height:(int)height; + +/// Obtains a new GpuBuffer to be used as an output destination. 
+- (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width + height:(int)height + format:(mediapipe::GpuBufferFormat)format; + +/// Convenience method to load a Metal library stored as a bundle resource. +- (id)newLibraryWithResourceName:(NSString*)name error:(NSError* _Nullable*)error; + +/// Shared Metal resources. +@property(readonly) id mtlDevice; +@property(readonly) id mtlCommandQueue; +@property(readonly) CVMetalTextureCacheRef mtlTextureCache; + +@end + +NS_ASSUME_NONNULL_END + +#endif // MEDIAPIPE_GPU_MEDIAPIPE_METAL_HELPER_H_ diff --git a/mediapipe/gpu/MPPMetalHelper.mm b/mediapipe/gpu/MPPMetalHelper.mm new file mode 100644 index 000000000..cc4fbd6e7 --- /dev/null +++ b/mediapipe/gpu/MPPMetalHelper.mm @@ -0,0 +1,210 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/gpu/MPPMetalHelper.h" + +#import "mediapipe/gpu/graph_support.h" +#import "GTMDefines.h" + +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { + +// Using a C++ class so it can be declared as a friend of LegacyCalculatorSupport. +class MetalHelperLegacySupport { + public: + static CalculatorContract* GetCalculatorContract() { + return LegacyCalculatorSupport::Scoped::current(); + } + + static CalculatorContext* GetCalculatorContext() { + return LegacyCalculatorSupport::Scoped::current(); + } +}; + +} // namespace mediapipe + +@implementation MPPMetalHelper + +- (instancetype)initWithGpuResources:(mediapipe::GpuResources*)gpuResources { + self = [super init]; + if (self) { + _gpuShared = gpuResources->ios_gpu_data(); + } + return self; +} + +- (instancetype)initWithGpuSharedData:(mediapipe::GpuSharedData*)gpuShared { + return [self initWithGpuResources:gpuShared->gpu_resources.get()]; +} + +- (instancetype)initWithCalculatorContext:(mediapipe::CalculatorContext*)cc { + if (!cc) return nil; + return [self initWithGpuResources:&cc->Service(mediapipe::kGpuService).GetObject()]; +} + ++ (::mediapipe::Status)updateContract:(mediapipe::CalculatorContract*)cc { + cc->UseService(mediapipe::kGpuService); + // Allow the legacy side packet to be provided, too, for backwards + // compatibility with existing graphs. It will just be ignored. + auto& input_side_packets = cc->InputSidePackets(); + auto id = input_side_packets.GetId(mediapipe::kGpuSharedTagName, 0); + if (id.IsValid()) { + input_side_packets.Get(id).Set(); + } + return ::mediapipe::OkStatus(); +} + +// Legacy support. +- (instancetype)initWithSidePackets:(const mediapipe::PacketSet&)inputSidePackets { + auto cc = mediapipe::MetalHelperLegacySupport::GetCalculatorContext(); + if (cc) { + CHECK_EQ(&inputSidePackets, &cc->InputSidePackets()); + return [self initWithCalculatorContext:cc]; + } + + // TODO: remove when we can. + LOG(WARNING) + << "CalculatorContext not available. 
If this calculator uses " + "CalculatorBase, call initWithCalculatorContext instead."; + mediapipe::GpuSharedData* gpu_shared = + inputSidePackets.Tag(mediapipe::kGpuSharedTagName).Get(); + + return [self initWithGpuResources:gpu_shared->gpu_resources.get()]; +} + +// Legacy support. ++ (::mediapipe::Status)setupInputSidePackets:(mediapipe::PacketTypeSet*)inputSidePackets { + auto cc = mediapipe::MetalHelperLegacySupport::GetCalculatorContract(); + if (cc) { + CHECK_EQ(inputSidePackets, &cc->InputSidePackets()); + return [self updateContract:cc]; + } + + // TODO: remove when we can. + LOG(WARNING) + << "CalculatorContract not available. If you're calling this " + "from a GetContract method, call updateContract instead."; + auto id = inputSidePackets->GetId(mediapipe::kGpuSharedTagName, 0); + RET_CHECK(id.IsValid()) + << "A " << mediapipe::kGpuSharedTagName + << " input side packet is required here."; + inputSidePackets->Get(id).Set(); + return ::mediapipe::OkStatus(); +} + +- (id)mtlDevice { + return _gpuShared.mtlDevice; +} + +- (id)mtlCommandQueue { + return _gpuShared.mtlCommandQueue; +} + +- (CVMetalTextureCacheRef)mtlTextureCache { + return _gpuShared.mtlTextureCache; +} + +- (id)commandBuffer { + return [_gpuShared.mtlCommandQueue commandBuffer]; +} + +- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer + plane:(size_t)plane { + + CVPixelBufferRef pixel_buffer = gpuBuffer.GetCVPixelBufferRef(); + OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer); + + MTLPixelFormat metalPixelFormat = MTLPixelFormatInvalid; + int width = gpuBuffer.width(); + int height = gpuBuffer.height(); + + switch (pixel_format) { + case kCVPixelFormatType_32BGRA: + NSCAssert(plane == 0, @"Invalid plane number"); + metalPixelFormat = MTLPixelFormatBGRA8Unorm; + break; + case kCVPixelFormatType_64RGBAHalf: + NSCAssert(plane == 0, @"Invalid plane number"); + metalPixelFormat = MTLPixelFormatRGBA16Float; + break; + case kCVPixelFormatType_OneComponent8: + NSCAssert(plane == 0, @"Invalid plane number"); + metalPixelFormat = MTLPixelFormatR8Uint; + break; + case kCVPixelFormatType_420YpCbCr8BiPlanarVideoRange: + case kCVPixelFormatType_420YpCbCr8BiPlanarFullRange: + if (plane == 0) { + metalPixelFormat = MTLPixelFormatR8Unorm; + } else if (plane == 1) { + metalPixelFormat = MTLPixelFormatRG8Unorm; + } else { + NSCAssert(NO, @"Invalid plane number"); + } + width = CVPixelBufferGetWidthOfPlane(pixel_buffer, plane); + height = CVPixelBufferGetHeightOfPlane(pixel_buffer, plane); + break; + case kCVPixelFormatType_TwoComponent16Half: + metalPixelFormat = MTLPixelFormatRG16Float; + NSCAssert(plane == 0, @"Invalid plane number"); + break; + case kCVPixelFormatType_OneComponent32Float: + metalPixelFormat = MTLPixelFormatR32Float; + NSCAssert(plane == 0, @"Invalid plane number"); + break; + default: + NSCAssert(NO, @"Invalid pixel buffer format"); + break; + } + + CVMetalTextureRef texture; + CVReturn err = CVMetalTextureCacheCreateTextureFromImage( + NULL, _gpuShared.mtlTextureCache, gpuBuffer.GetCVPixelBufferRef(), NULL, + metalPixelFormat, width, height, plane, &texture); + CHECK_EQ(err, kCVReturnSuccess); + return texture; +} + +- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer { + return [self copyCVMetalTextureWithGpuBuffer:gpuBuffer plane:0]; +} + +- (id)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer { + return [self metalTextureWithGpuBuffer:gpuBuffer plane:0]; +} + +- 
(id)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer + plane:(size_t)plane { + CFHolder cvTexture; + cvTexture.adopt([self copyCVMetalTextureWithGpuBuffer:gpuBuffer plane:plane]); + return CVMetalTextureGetTexture(*cvTexture); +} + +- (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width height:(int)height { + return _gpuShared.gpuBufferPool->GetBuffer(width, height); +} + +- (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width + height:(int)height + format:(mediapipe::GpuBufferFormat)format { + return _gpuShared.gpuBufferPool->GetBuffer(width, height, format); +} + +- (id)newLibraryWithResourceName:(NSString*)name error:(NSError * _Nullable *)error { + return [_gpuShared.mtlDevice newLibraryWithFile:[[NSBundle bundleForClass:[self class]] + pathForResource:name ofType:@"metallib"] + error:error]; +} + +@end diff --git a/mediapipe/gpu/egl_surface_holder.h b/mediapipe/gpu/egl_surface_holder.h index 0402a482d..a690f832a 100644 --- a/mediapipe/gpu/egl_surface_holder.h +++ b/mediapipe/gpu/egl_surface_holder.h @@ -32,6 +32,9 @@ struct EglSurfaceHolder { EGLSurface surface GUARDED_BY(mutex) = EGL_NO_SURFACE; // True if MediaPipe created the surface and is responsible for destroying it. bool owned GUARDED_BY(mutex) = false; + // Vertical flip of the surface, useful for conversion between coordinate + // systems with top-left v.s. bottom-left origins. + bool flip_y = false; }; } // namespace mediapipe diff --git a/mediapipe/gpu/gl_base.h b/mediapipe/gpu/gl_base.h index 1e705fce3..3fd823388 100644 --- a/mediapipe/gpu/gl_base.h +++ b/mediapipe/gpu/gl_base.h @@ -70,7 +70,7 @@ #include // When using the Linux EGL headers, we may end up pulling a -// "#define Status int" from Xlib.h, which interferes with util::Status. +// "#define Status int" from Xlib.h, which interferes with mediapipe::Status. #undef Status // More crud from X diff --git a/mediapipe/gpu/gl_calculator_helper.h b/mediapipe/gpu/gl_calculator_helper.h index 9876294b3..171fe3be0 100644 --- a/mediapipe/gpu/gl_calculator_helper.h +++ b/mediapipe/gpu/gl_calculator_helper.h @@ -32,7 +32,7 @@ #ifdef __APPLE__ #include -#include "mediapipe/framework/ios/CFHolder.h" +#include "mediapipe/objc/CFHolder.h" #endif // __APPLE__ namespace mediapipe { @@ -40,7 +40,7 @@ namespace mediapipe { class GlCalculatorHelperImpl; class GlTexture; class GpuResources; -class GpuSharedData; +struct GpuSharedData; #ifdef __APPLE__ #if TARGET_OS_OSX diff --git a/mediapipe/gpu/gl_calculator_helper_impl_ios.mm b/mediapipe/gpu/gl_calculator_helper_impl_ios.mm new file mode 100644 index 000000000..d62a1f90d --- /dev/null +++ b/mediapipe/gpu/gl_calculator_helper_impl_ios.mm @@ -0,0 +1,192 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
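Stepping back to the MPPMetalHelper interface declared earlier in this patch: it is meant to be driven from a calculator, with updateContract: in GetContract, initWithCalculatorContext: in Open, and the command-buffer/texture/buffer accessors in Process. The sketch below illustrates that pattern under stated assumptions: the calculator name, stream indices, and the "example_shaders" library name are hypothetical, and the Metal encoding step is elided; only the MPPMetalHelper calls follow the declarations above.

#import <Metal/Metal.h>

#import "mediapipe/gpu/MPPMetalHelper.h"

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/gpu/gpu_buffer.h"

namespace mediapipe {

// Hypothetical pass-through calculator showing typical MPPMetalHelper usage.
class ExampleMetalPassThroughCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<GpuBuffer>();
    cc->Outputs().Index(0).Set<GpuBuffer>();
    // Declares the GPU service requirement (and tolerates the legacy side packet).
    return [MPPMetalHelper updateContract:cc];
  }

  ::mediapipe::Status Open(CalculatorContext* cc) override {
    helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
    RET_CHECK(helper_);
    // Optionally load a metal_library-produced .metallib bundled with the app.
    NSError* error = nil;
    library_ = [helper_ newLibraryWithResourceName:@"example_shaders" error:&error];
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    const auto& input = cc->Inputs().Index(0).Get<GpuBuffer>();
    // Output buffer from the shared pool, plus Metal textures for both buffers.
    GpuBuffer output =
        [helper_ mediapipeGpuBufferWithWidth:input.width() height:input.height()];
    id<MTLTexture> src = [helper_ metalTextureWithGpuBuffer:input];
    id<MTLTexture> dst = [helper_ metalTextureWithGpuBuffer:output];
    id<MTLCommandBuffer> command_buffer = [helper_ commandBuffer];
    // ... encode Metal work that reads src and writes dst here ...
    [command_buffer commit];
    // Waiting is the simplest correct choice; real calculators may defer instead.
    [command_buffer waitUntilCompleted];
    cc->Outputs().Index(0).AddPacket(
        MakePacket<GpuBuffer>(output).At(cc->InputTimestamp()));
    return ::mediapipe::OkStatus();
  }

 private:
  MPPMetalHelper* helper_ = nil;
  id<MTLLibrary> library_ = nil;
};
REGISTER_CALCULATOR(ExampleMetalPassThroughCalculator);

}  // namespace mediapipe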
+ +#include "mediapipe/gpu/gl_calculator_helper_impl.h" + +#if TARGET_OS_OSX +#import +#else +#import +#endif // TARGET_OS_OSX +#import + +#include "absl/memory/memory.h" +#include "mediapipe/gpu/gpu_buffer_multi_pool.h" +#include "mediapipe/gpu/pixel_buffer_pool_util.h" +#include "mediapipe/objc/util.h" + +namespace mediapipe { + +GlVersion GlCalculatorHelperImpl::GetGlVersion() { +#if TARGET_OS_OSX + return GlVersion::kGL; +#else + if (gl_context_->eagl_context().API == kEAGLRenderingAPIOpenGLES3) return GlVersion::kGLES3; + else return GlVersion::kGLES2; +#endif // TARGET_OS_OSX +} + +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +GlTexture GlCalculatorHelperImpl::CreateSourceTexture( + const mediapipe::ImageFrame& image_frame) { + GlTexture texture; + + texture.helper_impl_ = this; + texture.width_ = image_frame.Width(); + texture.height_ = image_frame.Height(); + auto format = GpuBufferFormatForImageFormat(image_frame.Format()); + + GlTextureInfo info = GlTextureInfoForGpuBufferFormat(format, 0, GetGlVersion()); + + glGenTextures(1, &texture.name_); + glBindTexture(GL_TEXTURE_2D, texture.name_); + glTexImage2D(GL_TEXTURE_2D, 0, info.gl_internal_format, texture.width_, + texture.height_, 0, info.gl_format, info.gl_type, + image_frame.PixelData()); + SetStandardTextureParams(GL_TEXTURE_2D); + return texture; +} + +GlTexture GlCalculatorHelperImpl::CreateSourceTexture( + const GpuBuffer& gpu_buffer) { + return MapGpuBuffer(gpu_buffer, 0); +} + +GlTexture GlCalculatorHelperImpl::CreateSourceTexture( + const GpuBuffer& gpu_buffer, int plane) { + return MapGpuBuffer(gpu_buffer, plane); +} + +GlTexture GlCalculatorHelperImpl::MapGpuBuffer( + const GpuBuffer& gpu_buffer, int plane) { + CVReturn err; + GlTexture texture; + texture.helper_impl_ = this; + texture.gpu_buffer_ = gpu_buffer; + texture.plane_ = plane; + + const GlTextureInfo info = + GlTextureInfoForGpuBufferFormat(gpu_buffer.format(), plane, GetGlVersion()); + // When scale is not 1, we still give the nominal size of the image. 
+ texture.width_ = gpu_buffer.width(); + texture.height_ = gpu_buffer.height(); + +#if TARGET_OS_OSX + CVOpenGLTextureRef cv_texture_temp; + err = CVOpenGLTextureCacheCreateTextureFromImage( + kCFAllocatorDefault, gl_context_->cv_texture_cache(), gpu_buffer.GetCVPixelBufferRef(), NULL, + &cv_texture_temp); + NSCAssert(cv_texture_temp && !err, + @"Error at CVOpenGLTextureCacheCreateTextureFromImage %d", err); + texture.cv_texture_.adopt(cv_texture_temp); + texture.target_ = CVOpenGLTextureGetTarget(*texture.cv_texture_); + texture.name_ = CVOpenGLTextureGetName(*texture.cv_texture_); +#else + CVOpenGLESTextureRef cv_texture_temp; + err = CVOpenGLESTextureCacheCreateTextureFromImage( + kCFAllocatorDefault, gl_context_->cv_texture_cache(), gpu_buffer.GetCVPixelBufferRef(), NULL, + GL_TEXTURE_2D, info.gl_internal_format, texture.width_ / info.downscale, + texture.height_ / info.downscale, info.gl_format, info.gl_type, plane, + &cv_texture_temp); + NSCAssert(cv_texture_temp && !err, + @"Error at CVOpenGLESTextureCacheCreateTextureFromImage %d", err); + texture.cv_texture_.adopt(cv_texture_temp); + texture.target_ = CVOpenGLESTextureGetTarget(*texture.cv_texture_); + texture.name_ = CVOpenGLESTextureGetName(*texture.cv_texture_); +#endif // TARGET_OS_OSX + + glBindTexture(texture.target(), texture.name()); + SetStandardTextureParams(texture.target()); + + return texture; +} +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + +template<> +std::unique_ptr GlTexture::GetFrame() const { +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + if (gpu_buffer_.GetCVPixelBufferRef()) { + return CreateImageFrameForCVPixelBuffer(gpu_buffer_.GetCVPixelBufferRef()); + } + + ImageFormat::Format image_format = + ImageFormatForGpuBufferFormat(gpu_buffer_.format()); + // TODO: handle gl version here. + GlTextureInfo info = GlTextureInfoForGpuBufferFormat( + gpu_buffer_.format(), plane_); + + auto output = absl::make_unique( + image_format, width_, height_); + + glReadPixels(0, 0, width_, height_, info.gl_format, info.gl_type, + output->MutablePixelData()); + return output; +#else + CHECK(gpu_buffer_.format() == GpuBufferFormat::kBGRA32); + auto output = + absl::make_unique(ImageFormat::SRGBA, width_, height_, + ImageFrame::kGlDefaultAlignmentBoundary); + + CHECK(helper_impl_); + helper_impl_->ReadTexture(*this, output->MutablePixelData(), output->PixelDataSize()); + + return output; +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +} + +template<> +std::unique_ptr GlTexture::GetFrame() const { + NSCAssert(gpu_buffer_, @"gpu_buffer_ must be valid"); +#if TARGET_IPHONE_SIMULATOR + CVPixelBufferRef pixel_buffer = gpu_buffer_.GetCVPixelBufferRef(); + CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0); + NSCAssert(err == kCVReturnSuccess, @"CVPixelBufferLockBaseAddress failed: %d", err); + OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer); + if (pixel_format == kCVPixelFormatType_32BGRA) { + // TODO: restore previous framebuffer? Move this to helper so we can + // use BindFramebuffer? 
+ glViewport(0, 0, width_, height_); + glFramebufferTexture2D( + GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, target_, name_, 0); + glReadPixels(0, 0, width_, height_, GL_BGRA, GL_UNSIGNED_BYTE, + CVPixelBufferGetBaseAddress(pixel_buffer)); + } else { + uint32_t format_big = CFSwapInt32HostToBig(pixel_format); + NSLog(@"unsupported pixel format: %.4s", (char*)&format_big); + } + err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0); + NSCAssert(err == kCVReturnSuccess, @"CVPixelBufferUnlockBaseAddress failed: %d", err); +#endif + return absl::make_unique(gpu_buffer_); +} + +void GlTexture::Release() { +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + if (*cv_texture_) { + cv_texture_.reset(NULL); + } else if (name_) { + // This is only needed because of the glGenTextures in + // CreateSourceTexture(ImageFrame)... change. + glDeleteTextures(1, &name_); + } +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + helper_impl_ = nullptr; + gpu_buffer_ = nullptr; + plane_ = 0; + name_ = 0; + width_ = 0; + height_ = 0; +} + +} // namespace mediapipe diff --git a/mediapipe/gpu/gl_context.h b/mediapipe/gpu/gl_context.h index 1ff07cf1b..cf48c9f9d 100644 --- a/mediapipe/gpu/gl_context.h +++ b/mediapipe/gpu/gl_context.h @@ -33,7 +33,7 @@ #ifdef __APPLE__ #include -#include "mediapipe/framework/ios/CFHolder.h" +#include "mediapipe/objc/CFHolder.h" #if TARGET_OS_OSX diff --git a/mediapipe/gpu/gl_context_eagl.cc b/mediapipe/gpu/gl_context_eagl.cc index 51c49128d..34b1d0a2f 100644 --- a/mediapipe/gpu/gl_context_eagl.cc +++ b/mediapipe/gpu/gl_context_eagl.cc @@ -75,7 +75,14 @@ GlContext::StatusOrGlContext GlContext::Create(EAGLSharegroup* sharegroup, return ::mediapipe::OkStatus(); } -void GlContext::DestroyContext() {} +void GlContext::DestroyContext() { + if (*texture_cache_) { + // The texture cache must be flushed on tear down, otherwise we potentially + // leak pixel buffers whose textures have pending GL operations after the + // CVOpenGLESTextureRef is released in GlTexture::Release. + CVOpenGLESTextureCacheFlush(*texture_cache_, 0); + } +} GlContext::ContextBinding GlContext::ThisContextBinding() { GlContext::ContextBinding result; diff --git a/mediapipe/gpu/gl_context_egl.cc b/mediapipe/gpu/gl_context_egl.cc index 947c1c36c..a8f12d7ec 100644 --- a/mediapipe/gpu/gl_context_egl.cc +++ b/mediapipe/gpu/gl_context_egl.cc @@ -34,8 +34,18 @@ static pthread_key_t egl_release_thread_key; static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT; static void EglThreadExitCallback(void* key_value) { +#if defined(__ANDROID__) eglMakeCurrent(EGL_NO_DISPLAY, EGL_NO_SURFACE, EGL_NO_SURFACE, EGL_NO_CONTEXT); +#else + // Some implementations have chosen to allow EGL_NO_DISPLAY as a valid display + // parameter for eglMakeCurrent. This behavior is not portable to all EGL + // implementations, and should be considered as an undocumented vendor + // extension. 
+ // https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml + eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE, + EGL_NO_SURFACE, EGL_NO_CONTEXT); +#endif eglReleaseThread(); } diff --git a/mediapipe/gpu/gl_context_nsgl.cc b/mediapipe/gpu/gl_context_nsgl.cc index 4602f5714..8a43415ae 100644 --- a/mediapipe/gpu/gl_context_nsgl.cc +++ b/mediapipe/gpu/gl_context_nsgl.cc @@ -102,7 +102,14 @@ GlContext::StatusOrGlContext GlContext::Create(NSOpenGLContext* share_context, return ::mediapipe::OkStatus(); } -void GlContext::DestroyContext() {} +void GlContext::DestroyContext() { + if (*texture_cache_) { + // The texture cache must be flushed on tear down, otherwise we potentially + // leak pixel buffers whose textures have pending GL operations after the + // CVOpenGLTextureRef is released in GlTexture::Release. + CVOpenGLTextureCacheFlush(*texture_cache_, 0); + } +} GlContext::ContextBinding GlContext::ThisContextBinding() { GlContext::ContextBinding result; diff --git a/mediapipe/gpu/gl_ios_test.mm b/mediapipe/gpu/gl_ios_test.mm new file mode 100644 index 000000000..f1bb132ec --- /dev/null +++ b/mediapipe/gpu/gl_ios_test.mm @@ -0,0 +1,252 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import +#import + +#import "mediapipe/framework/tool/source.pb.h" +#import "mediapipe/gpu/gpu_shared_data_internal.h" +#import "mediapipe/objc/MPPGraph.h" +#import "mediapipe/objc/MPPGraphTestBase.h" +#import "mediapipe/objc/util.h" + +#include "absl/memory/memory.h" +#import "mediapipe/framework/calculator_framework.h" +#import "mediapipe/gpu/gl_calculator_helper.h" + +@interface GLIOSTests : MPPGraphTestBase{ + UIImage* _sourceImage; + MPPGraph* _graph; +} +@end + +@implementation GLIOSTests + +- (void)setUp { + [super setUp]; + + _sourceImage = [self testImageNamed:@"googlelogo_color_272x92dp" extension:@"png"]; +} + +- (void)tearDown { + [super tearDown]; +} + +- (CVPixelBufferRef)redPixelBuffer:(CVPixelBufferRef)input { + return [self transformPixelBuffer:input + outputPixelFormat:kCVPixelFormatType_32BGRA + transformation:^(CVPixelBufferRef input, + CVPixelBufferRef output) { + vImage_Buffer vInput = vImageForCVPixelBuffer(input); + vImage_Buffer vRed = vImageForCVPixelBuffer(output); + + static const int16_t matrix[16] = { + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 256, 0, + 0, 0, 0, 256, + }; + vImage_Error vError = vImageMatrixMultiply_ARGB8888( + &vInput, &vRed, matrix, 256, NULL, NULL, 0); + XCTAssertEqual(vError, kvImageNoError); + }]; +} + +- (CVPixelBufferRef)luminancePixelBuffer:(CVPixelBufferRef)input { + return [self transformPixelBuffer:input + outputPixelFormat:kCVPixelFormatType_32BGRA + transformation:^(CVPixelBufferRef input, + CVPixelBufferRef output) { + vImage_Buffer vInput = vImageForCVPixelBuffer(input); + vImage_Buffer vLuminance = vImageForCVPixelBuffer(output); + + // sRGB weights: R 0.2125, G 0.7154, B 0.0721 + static const int16_t matrix[16] = { + 721, 721, 721, 0, + 7154, 7154, 7154, 0, + 2125, 2125, 2125, 0, + 0, 0, 0, 10000, + }; + vImage_Error vError = vImageMatrixMultiply_ARGB8888( + &vInput, &vLuminance, matrix, 10000, NULL, NULL, 0); + XCTAssertEqual(vError, kvImageNoError); + }]; +} + +- (CVPixelBufferRef)grayPixelBuffer:(CVPixelBufferRef)input { + return [self transformPixelBuffer:input + outputPixelFormat:kCVPixelFormatType_OneComponent8 + transformation:^(CVPixelBufferRef input, + CVPixelBufferRef output) { + vImage_Buffer vInput = vImageForCVPixelBuffer(input); + vImage_Buffer vGray = vImageForCVPixelBuffer(output); + vImage_Error vError = vImageBGRAToGray(&vInput, &vGray); + XCTAssertEqual(vError, kvImageNoError); + }]; +} + +- (void)testGlConverters { + CFHolder originalPixelBuffer; + ::mediapipe::Status status = + CreateCVPixelBufferFromCGImage([_sourceImage CGImage], &originalPixelBuffer); + XCTAssert(status.ok()); + + mediapipe::CalculatorGraphConfig config; + config.add_input_stream("input_frames"); + auto node = config.add_node(); + node->set_calculator("GpuBufferToImageFrameCalculator"); + node->add_input_stream("input_frames"); + node->add_output_stream("image_frames"); + auto node2 = config.add_node(); + node2->set_calculator("ImageFrameToGpuBufferCalculator"); + node2->add_input_stream("image_frames"); + node2->add_output_stream("output_frames"); + + _graph = [[MPPGraph alloc] initWithGraphConfig:config]; + [_graph addFrameOutputStream:"output_frames" + outputPacketType:MediaPipePacketPixelBuffer]; + [self testGraph:_graph input:*originalPixelBuffer expectedOutput:*originalPixelBuffer]; +} + +- (void)testGlConvertersNoOpInserted { + CFHolder originalPixelBuffer; + ::mediapipe::Status status = + CreateCVPixelBufferFromCGImage([_sourceImage CGImage], &originalPixelBuffer); + XCTAssert(status.ok()); + + mediapipe::CalculatorGraphConfig 
config; + config.add_input_stream("input_frames"); + auto node = config.add_node(); + node->set_calculator("GpuBufferToImageFrameCalculator"); + node->add_input_stream("input_frames"); + node->add_output_stream("image_frames"); + // This node should be a no-op, since its inputs are already ImageFrames. + auto no_op_node = config.add_node(); + no_op_node->set_calculator("GpuBufferToImageFrameCalculator"); + no_op_node->add_input_stream("image_frames"); + no_op_node->add_output_stream("still_image_frames"); + auto node2 = config.add_node(); + node2->set_calculator("ImageFrameToGpuBufferCalculator"); + node2->add_input_stream("still_image_frames"); + node2->add_output_stream("output_frames"); + + _graph = [[MPPGraph alloc] initWithGraphConfig:config]; + [_graph addFrameOutputStream:"output_frames" + outputPacketType:MediaPipePacketPixelBuffer]; + [self testGraph:_graph input:*originalPixelBuffer expectedOutput:*originalPixelBuffer]; +} + +- (void)testGlConvertersWithOptionalSidePackets { + CFHolder originalPixelBuffer; + ::mediapipe::Status status = + CreateCVPixelBufferFromCGImage([_sourceImage CGImage], &originalPixelBuffer); + XCTAssert(status.ok()); + + mediapipe::CalculatorGraphConfig config; + config.add_input_stream("input_frames"); + auto node = config.add_node(); + node->set_calculator("GpuBufferToImageFrameCalculator"); + node->add_input_stream("input_frames"); + node->add_output_stream("image_frames"); + auto node2 = config.add_node(); + node2->set_calculator("ImageFrameToGpuBufferCalculator"); + node2->add_input_stream("image_frames"); + node2->add_output_stream("output_frames"); + + _graph = [[MPPGraph alloc] initWithGraphConfig:config]; + [_graph addFrameOutputStream:"output_frames" + outputPacketType:MediaPipePacketPixelBuffer]; + [self testGraph:_graph input:*originalPixelBuffer expectedOutput:*originalPixelBuffer]; +} + +- (void)testDestinationSizes { + mediapipe::GpuSharedData gpuData; + mediapipe::GlCalculatorHelper helper; + helper.InitializeForTest(&gpuData); + + std::vector> sizes{ + {200, 300}, + {200, 299}, + {196, 300}, + {194, 300}, + {193, 300}, + }; + for (const auto& width_height : sizes) { + mediapipe::GlTexture texture = + helper.CreateDestinationTexture(width_height.first, width_height.second); + XCTAssertNotEqual(texture.name(), 0); + } +} + +- (void)testSimpleConversionFromFormat:(OSType)cvPixelFormat { + CFHolder originalPixelBuffer; + ::mediapipe::Status status = + CreateCVPixelBufferFromCGImage([_sourceImage CGImage], &originalPixelBuffer); + XCTAssert(status.ok()); + CVPixelBufferRef convertedPixelBuffer = + [self convertPixelBuffer:*originalPixelBuffer + toPixelFormat:cvPixelFormat]; + CVPixelBufferRef bgraPixelBuffer = + [self convertPixelBuffer:convertedPixelBuffer + toPixelFormat:kCVPixelFormatType_32BGRA]; + + mediapipe::CalculatorGraphConfig config; + config.add_input_stream("input_frames"); + auto node = config.add_node(); + node->set_calculator("GlScalerCalculator"); + node->add_input_stream("input_frames"); + node->add_output_stream("output_frames"); + + _graph = [[MPPGraph alloc] initWithGraphConfig:config]; + [_graph addFrameOutputStream:"output_frames" + outputPacketType:MediaPipePacketPixelBuffer]; + [self testGraph:_graph input:convertedPixelBuffer expectedOutput:bgraPixelBuffer]; + CFRelease(convertedPixelBuffer); + CFRelease(bgraPixelBuffer); +} + +- (void)testOneComponent8 { + [self testSimpleConversionFromFormat:kCVPixelFormatType_OneComponent8]; +} + +- (void)testMetalRgbWeight { +#if TARGET_IPHONE_SIMULATOR + NSLog(@"Metal tests 
skipped on Simulator."); +#else + CFHolder originalPixelBuffer; + ::mediapipe::Status status = + CreateCVPixelBufferFromCGImage([_sourceImage CGImage], &originalPixelBuffer); + XCTAssert(status.ok()); + CVPixelBufferRef redPixelBuffer = [self redPixelBuffer:*originalPixelBuffer]; + + mediapipe::CalculatorGraphConfig config; + config.add_input_stream("input_frames"); + auto node = config.add_node(); + node->set_calculator("MetalRgbWeightCalculator"); + node->add_input_stream("input_frames"); + node->add_output_stream("output_frames"); + node->add_input_side_packet("WEIGHTS:rgb_weights"); + + _graph = [[MPPGraph alloc] initWithGraphConfig:config]; + [_graph addFrameOutputStream:"output_frames" + outputPacketType:MediaPipePacketPixelBuffer]; + [_graph setSidePacket:(mediapipe::MakePacket(1.0, 0.0, 0.0)) + named:"rgb_weights"]; + + [self testGraph:_graph input:*originalPixelBuffer expectedOutput:redPixelBuffer]; + CFRelease(redPixelBuffer); +#endif // TARGET_IPHONE_SIMULATOR +} + +@end diff --git a/mediapipe/gpu/gl_scaler_calculator.cc b/mediapipe/gpu/gl_scaler_calculator.cc index 18674e71b..08c532150 100644 --- a/mediapipe/gpu/gl_scaler_calculator.cc +++ b/mediapipe/gpu/gl_scaler_calculator.cc @@ -15,6 +15,7 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/tool/options_util.h" #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_quad_renderer.h" #include "mediapipe/gpu/gl_scaler_calculator.pb.h" @@ -52,6 +53,8 @@ namespace mediapipe { // both having padding of 1 pixels. So the value of output stream is 1 / 5 = // 0.2. // Additional input side packets: +// OPTIONS: the GlScalerCalculatorOptions to use. Will replace or merge with +// existing calculator options, depending on field merge_fields. // OUTPUT_DIMENSIONS: the output width and height in pixels. // ROTATION: the counterclockwise rotation angle in degrees. // These can also be specified as options. 
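A short sketch of how this OPTIONS side packet might be supplied when the graph is built and run from C++; the stream and side-packet names are illustrative, and only output_width is set because that is the field read by the calculator below.

#include <map>
#include <string>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/gpu/gl_scaler_calculator.pb.h"

::mediapipe::Status RunScalerWithSidePacketOptions() {
  // One-node graph; the OPTIONS side packet is picked up via tool::RetrieveOptions.
  mediapipe::CalculatorGraphConfig config;
  config.add_input_stream("input_video");
  auto* node = config.add_node();
  node->set_calculator("GlScalerCalculator");
  node->add_input_stream("input_video");
  node->add_output_stream("output_video");
  node->add_input_side_packet("OPTIONS:scaler_options");

  // Options delivered at run time instead of (or merged with) the static options.
  mediapipe::GlScalerCalculatorOptions scaler_options;
  scaler_options.set_output_width(640);

  mediapipe::CalculatorGraph graph;
  // Depending on the platform, GPU resources may also need to be attached
  // (e.g. via graph.SetGpuResources) before StartRun.
  RETURN_IF_ERROR(graph.Initialize(config));
  std::map<std::string, mediapipe::Packet> side_packets;
  side_packets["scaler_options"] =
      mediapipe::MakePacket<mediapipe::GlScalerCalculatorOptions>(scaler_options);
  RETURN_IF_ERROR(graph.StartRun(side_packets));
  // ... feed frames with graph.AddPacketToInputStream("input_video", ...) ...
  RETURN_IF_ERROR(graph.CloseInputStream("input_video"));
  return graph.WaitUntilDone();
}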
@@ -101,6 +104,9 @@ REGISTER_CALCULATOR(GlScalerCalculator); } RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc)); + if (cc->InputSidePackets().HasTag("OPTIONS")) { + cc->InputSidePackets().Tag("OPTIONS").Set(); + } if (HasTagOrIndex(&cc->InputSidePackets(), "OUTPUT_DIMENSIONS", 1)) { TagOrIndex(&cc->InputSidePackets(), "OUTPUT_DIMENSIONS", 1) .Set(); @@ -127,7 +133,9 @@ REGISTER_CALCULATOR(GlScalerCalculator); RETURN_IF_ERROR(helper_.Open(cc)); int rotation_ccw = 0; - const auto& options = cc->Options(); + const auto& options = + tool::RetrieveOptions(cc->Options(), + cc->InputSidePackets(), "OPTIONS"); if (options.has_output_width()) { dst_width_ = options.output_width(); } diff --git a/mediapipe/gpu/gl_surface_sink_calculator.cc b/mediapipe/gpu/gl_surface_sink_calculator.cc index 0ac0f02b9..e80bd29f9 100644 --- a/mediapipe/gpu/gl_surface_sink_calculator.cc +++ b/mediapipe/gpu/gl_surface_sink_calculator.cc @@ -128,7 +128,7 @@ REGISTER_CALCULATOR(GlSurfaceSinkCalculator); renderer_->GlRender(src.width(), src.height(), dst_width, dst_height, scale_mode_, FrameRotation::kNone, /*flip_horizontal=*/false, /*flip_vertical=*/false, - /*flip_texture=*/false)); + /*flip_texture=*/surface_holder_->flip_y)); glBindTexture(src.target(), 0); diff --git a/mediapipe/gpu/gpu_buffer.h b/mediapipe/gpu/gpu_buffer.h index 8dc9efe63..b56e3b79b 100644 --- a/mediapipe/gpu/gpu_buffer.h +++ b/mediapipe/gpu/gpu_buffer.h @@ -23,7 +23,7 @@ #if defined(__APPLE__) #include -#include "mediapipe/framework/ios/CFHolder.h" +#include "mediapipe/objc/CFHolder.h" #if !TARGET_OS_OSX #define MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER 1 #endif // TARGET_OS_OSX diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index 8c0446a3d..a8555819c 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -22,7 +22,7 @@ #include "mediapipe/gpu/gpu_shared_data_internal.h" #ifdef __APPLE__ -#include "mediapipe/framework/ios/CFHolder.h" +#include "mediapipe/objc/CFHolder.h" #endif // __APPLE__ namespace mediapipe { @@ -129,7 +129,7 @@ GpuBufferMultiPool::~GpuBufferMultiPool() { #ifdef __APPLE__ void GpuBufferMultiPool::RegisterTextureCache(CVTextureCacheType cache) { - MutexLock lock(&mutex_); + absl::MutexLock lock(&mutex_); CHECK(std::find(texture_caches_.begin(), texture_caches_.end(), cache) == texture_caches_.end()) @@ -138,7 +138,7 @@ void GpuBufferMultiPool::RegisterTextureCache(CVTextureCacheType cache) { } void GpuBufferMultiPool::UnregisterTextureCache(CVTextureCacheType cache) { - MutexLock lock(&mutex_); + absl::MutexLock lock(&mutex_); auto it = std::find(texture_caches_.begin(), texture_caches_.end(), cache); CHECK(it != texture_caches_.end()) diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.h b/mediapipe/gpu/gpu_buffer_multi_pool.h index 019cf7a15..43362cdec 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.h +++ b/mediapipe/gpu/gpu_buffer_multi_pool.h @@ -39,7 +39,7 @@ namespace mediapipe { -class GpuSharedData; +struct GpuSharedData; struct BufferSpec { BufferSpec(int w, int h, GpuBufferFormat f) diff --git a/mediapipe/gpu/gpu_buffer_to_image_frame_calculator.cc b/mediapipe/gpu/gpu_buffer_to_image_frame_calculator.cc index 0f0b57f64..9732de5bd 100644 --- a/mediapipe/gpu/gpu_buffer_to_image_frame_calculator.cc +++ b/mediapipe/gpu/gpu_buffer_to_image_frame_calculator.cc @@ -19,7 +19,7 @@ #define HAVE_GPU_BUFFER #ifdef __APPLE__ -#include "mediapipe/framework/ios/util.h" +#include "mediapipe/objc/util.h" #endif #include 
"mediapipe/gpu/gl_calculator_helper.h" diff --git a/mediapipe/gpu/gpu_shared_data_internal.cc b/mediapipe/gpu/gpu_shared_data_internal.cc index e7c5bea48..7863b70d2 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.cc +++ b/mediapipe/gpu/gpu_shared_data_internal.cc @@ -20,7 +20,7 @@ #include "mediapipe/gpu/graph_support.h" #if __APPLE__ -#import "mediapipe/gpu/MediaPipeGraphGPUData.h" +#import "mediapipe/gpu/MPPGraphGPUData.h" #endif // __APPLE__ namespace mediapipe { @@ -85,9 +85,8 @@ GpuResources::GpuResources(std::shared_ptr gl_context) { std::make_shared(gl_context.get()); #if __APPLE__ gpu_buffer_pool().RegisterTextureCache(gl_context->cv_texture_cache()); - ios_gpu_data_ = - [[MediaPipeGraphGPUData alloc] initWithContext:gl_context.get() - multiPool:&gpu_buffer_pool_]; + ios_gpu_data_ = [[MPPGraphGPUData alloc] initWithContext:gl_context.get() + multiPool:&gpu_buffer_pool_]; #endif // __APPLE__ } @@ -170,7 +169,7 @@ const std::shared_ptr& GpuResources::gl_context( GpuSharedData::GpuSharedData() : GpuSharedData(kPlatformGlContextNone) {} #if __APPLE__ -MediaPipeGraphGPUData* GpuResources::ios_gpu_data() { return ios_gpu_data_; } +MPPGraphGPUData* GpuResources::ios_gpu_data() { return ios_gpu_data_; } #endif // __APPLE__ } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_shared_data_internal.h b/mediapipe/gpu/gpu_shared_data_internal.h index 2b659676a..65f9d8891 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.h +++ b/mediapipe/gpu/gpu_shared_data_internal.h @@ -31,9 +31,9 @@ #ifdef __APPLE__ #ifdef __OBJC__ -@class MediaPipeGraphGPUData; +@class MPPGraphGPUData; #else -struct MediaPipeGraphGPUData; +struct MPPGraphGPUData; #endif // __OBJC__ #endif // defined(__APPLE__) @@ -66,7 +66,7 @@ class GpuResources { GpuBufferMultiPool& gpu_buffer_pool() { return gpu_buffer_pool_; } #ifdef __APPLE__ - MediaPipeGraphGPUData* ios_gpu_data(); + MPPGraphGPUData* ios_gpu_data(); #endif // defined(__APPLE__) void PrepareGpuNode(CalculatorNode* node); @@ -93,7 +93,7 @@ class GpuResources { #ifdef __APPLE__ // Note that this is an Objective-C object. - MediaPipeGraphGPUData* ios_gpu_data_; + MPPGraphGPUData* ios_gpu_data_; #endif // defined(__APPLE__) std::map> named_executors_; diff --git a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc index 652f81e98..c3e4c13b7 100644 --- a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc +++ b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc @@ -18,7 +18,7 @@ #include "mediapipe/gpu/gl_calculator_helper.h" #ifdef __APPLE__ -#include "mediapipe/framework/ios/util.h" +#include "mediapipe/objc/util.h" #endif namespace mediapipe { diff --git a/mediapipe/gpu/metal.bzl b/mediapipe/gpu/metal.bzl new file mode 100644 index 000000000..2b45c6614 --- /dev/null +++ b/mediapipe/gpu/metal.bzl @@ -0,0 +1,181 @@ +"""Experimental Skylark rules for Apple's Metal. + +This creates a .metallib file containing compiled Metal shaders. +Note that the default behavior in Xcode is to put all metal shaders into a +single "default.metallib", which can be loaded using the method +newDefaultLibrary in MTLDevice. Meanwhile, the metal_library rule creates a +named .metallib, which can be loaded using newLibraryWithFile:error:. + +Example: + + metal_library( + name = "my_shaders", + srcs = ["my_shaders.metal"], + hdrs = ["my_shaders.h"], + ) + +This produces a "my_shaders.metallib". + +The metal_library target can be added to the deps attribute of an objc_library. 
+The dependent objc_library can then access the headers declared by the +metal_library, if any. + +The metal_library target can also be added to the resources attribute as a +simple data file, but in that case any declared headers are not visible to +dependent objc_library rules. +""" + +load("@build_bazel_apple_support//lib:apple_support.bzl", "apple_support") +load("@bazel_skylib//lib:dicts.bzl", "dicts") + +# This load statement is overriding the visibility of the internal implementation of rules_apple. +# This rule will be migrated to rules_apple in the future, hence the override. Please do not use +# this import anywhere else. +load( + "@build_bazel_rules_apple//apple/internal:resources.bzl", + "resources", +) + +def _metal_compiler_args(ctx, src, obj, minimum_os_version, copts, diagnostics, deps_dump): + """Returns arguments for metal compiler.""" + apple_fragment = ctx.fragments.apple + + platform = apple_fragment.single_arch_platform + + if not minimum_os_version: + minimum_os_version = ctx.attr._xcode_config[apple_common.XcodeVersionConfig].minimum_os_for_platform_type( + platform.platform_type, + ) + + args = copts + [ + "-arch", + "air64", # TODO: choose based on target device/cpu/platform? + "-emit-llvm", + "-c", + "-gline-tables-only", + "-isysroot", + apple_support.path_placeholders.sdkroot(), + "-ffast-math", + "-serialize-diagnostics", + diagnostics.path, + "-o", + obj.path, + "-mios-version-min=%s" % minimum_os_version, + "", + src.path, + "-MMD", + "-MT", + "dependencies", + "-MF", + deps_dump.path, + ] + return args + +def _metal_compiler_inputs(srcs, hdrs, deps = []): + """Determines the list of inputs required for a compile action.""" + objc_providers = [x.objc for x in deps if hasattr(x, "objc")] + + objc_files = depset() + for objc in objc_providers: + objc_files += objc.header + + return srcs + hdrs + objc_files.to_list() + +def _metal_library_impl(ctx): + """Implementation for metal_library Skylark rule.""" + + # A unique path for rule's outputs. 
+ objs_outputs_path = "{}.objs/".format(ctx.label.name) + + output_objs = [] + for src in ctx.files.srcs: + basename = src.basename + obj = ctx.actions.declare_file(objs_outputs_path + basename + ".air") + output_objs.append(obj) + diagnostics = ctx.actions.declare_file(objs_outputs_path + basename + ".dia") + deps_dump = ctx.actions.declare_file(objs_outputs_path + basename + ".dat") + + args = (["metal"] + + _metal_compiler_args(ctx, src, obj, ctx.attr.minimum_os_version, ctx.attr.copts, diagnostics, deps_dump)) + + apple_support.run( + ctx, + xcode_path_resolve_level = apple_support.xcode_path_resolve_level.args, + inputs = _metal_compiler_inputs(ctx.files.srcs, ctx.files.hdrs, ctx.attr.deps), + outputs = [obj, diagnostics, deps_dump], + mnemonic = "MetalCompile", + executable = "/usr/bin/xcrun", + arguments = args, + use_default_shell_env = False, + progress_message = ("Compiling Metal shader %s" % + (basename)), + ) + + output_lib = ctx.actions.declare_file(ctx.label.name + ".metallib") + args = [ + "metallib", + "-split-module", + "-o", + output_lib.path, + ] + [x.path for x in output_objs] + + apple_support.run( + ctx, + xcode_path_resolve_level = apple_support.xcode_path_resolve_level.args, + inputs = output_objs, + outputs = (output_lib,), + mnemonic = "MetalLink", + executable = "/usr/bin/xcrun", + arguments = args, + progress_message = ( + "Linking Metal library %s" % ctx.label.name + ), + ) + + # This ridiculous circumlocution is needed because new_objc_provider rejects + # an empty depset, with the error: + # "Value for key header must be a set of File, instead found set of unknown." + # It also rejects an explicit "None". + additional_params = {} + if ctx.files.hdrs: + additional_params["header"] = depset([f for f in ctx.files.hdrs]) + objc_provider = apple_common.new_objc_provider( + providers = [x.objc for x in ctx.attr.deps if hasattr(x, "objc")], + **additional_params + ) + + return [ + DefaultInfo( + files = depset([output_lib]), + ), + objc_provider, + # Return the provider for the new bundling logic of rules_apple. + resources.bucketize_typed([output_lib], "unprocessed"), + ] + +METAL_LIBRARY_ATTRS = dicts.add(apple_support.action_required_attrs(), { + "srcs": attr.label_list(allow_files = [".metal"], allow_empty = False), + "hdrs": attr.label_list(allow_files = [".h"]), + "deps": attr.label_list(providers = [["objc"]]), + "copts": attr.string_list(), + "minimum_os_version": attr.string(), +}) + +metal_library = rule( + implementation = _metal_library_impl, + attrs = METAL_LIBRARY_ATTRS, + fragments = ["apple", "objc", "swift"], + output_to_genfiles = True, +) +""" +Builds a Metal library. + +Args: + srcs: Metal shader sources. + hdrs: Header files used by the shader sources. + deps: objc_library targets whose headers should be visible to the shaders. + +The header files declared in this rule are also visible to any objc_library +rules that have it as a dependency, so that constants and typedefs can be +shared between Metal and Objective-C code. +""" diff --git a/mediapipe/gpu/pixel_buffer_pool_util.h b/mediapipe/gpu/pixel_buffer_pool_util.h new file mode 100644 index 000000000..f5ae07df1 --- /dev/null +++ b/mediapipe/gpu/pixel_buffer_pool_util.h @@ -0,0 +1,73 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_GPU_PIXEL_BUFFER_POOL_UTIL_H_ +#define MEDIAPIPE_GPU_PIXEL_BUFFER_POOL_UTIL_H_ + +#include +#include + +#include + +#ifndef __APPLE__ +#error gpu_pixel_buffer_pool_util is only for use on Apple platforms. +#endif // !defined(__APPLE__) + +namespace mediapipe { + +#if TARGET_OS_OSX +typedef CVOpenGLTextureCacheRef CVTextureCacheType; +#else +typedef CVOpenGLESTextureCacheRef CVTextureCacheType; +#endif // TARGET_OS_OSX + +// Create a CVPixelBufferPool. +CVPixelBufferPoolRef CreateCVPixelBufferPool(int width, int height, + OSType pixel_format, + int keep_count, + CFTimeInterval maxAge); + +// Preallocate the given number of pixel buffers. +OSStatus PreallocateCVPixelBufferPoolBuffers(CVPixelBufferPoolRef pool, + int count, + CFDictionaryRef auxAttributes); + +// Create a CVPixelBuffer using a pool. +// If the pool is full, will flush the provided texture cache before trying +// again. +CVReturn CreateCVPixelBufferWithPool(CVPixelBufferPoolRef pool, + CFDictionaryRef auxAttributes, + CVTextureCacheType textureCache, + CVPixelBufferRef* outBuffer); + +// Create a CVPixelBuffer using a pool. +// If the pool is full, will call the provided function before trying again. +CVReturn CreateCVPixelBufferWithPool(CVPixelBufferPoolRef pool, + CFDictionaryRef auxAttributes, + std::function flush, + CVPixelBufferRef* outBuffer); + +// Create an auxiliary attribute dictionary, which can be used with +// CVPixelBufferPool, specifying the given allocation threshold. +CFDictionaryRef CreateCVPixelBufferPoolAuxiliaryAttributesForThreshold( + int allocationThreshold); + +// Create a CVPixelBuffer without using a pool. +CVReturn CreateCVPixelBufferWithoutPool(int width, int height, + OSType pixelFormat, + CVPixelBufferRef* outBuffer); + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_PIXEL_BUFFER_POOL_UTIL_H_ diff --git a/mediapipe/gpu/pixel_buffer_pool_util.mm b/mediapipe/gpu/pixel_buffer_pool_util.mm new file mode 100644 index 000000000..5d0906003 --- /dev/null +++ b/mediapipe/gpu/pixel_buffer_pool_util.mm @@ -0,0 +1,163 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
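Before the implementation, a small sketch of how the pool utilities declared in pixel_buffer_pool_util.h above compose; the size, format, and threshold are illustrative, and the flush callback is where a real caller (such as GpuBufferMultiPool) would flush its CV texture caches.

#include "mediapipe/gpu/pixel_buffer_pool_util.h"

#import <CoreVideo/CoreVideo.h>

// Illustrative allocator: draws BGRA buffers from a pool, retrying through the
// flush callback when the allocation threshold would be exceeded.
CVPixelBufferRef CreateExampleBuffer(CVPixelBufferPoolRef pool,
                                     CFDictionaryRef aux_attributes) {
  CVPixelBufferRef buffer = NULL;
  CVReturn err = mediapipe::CreateCVPixelBufferWithPool(
      pool, aux_attributes,
      [] {
        // A real caller flushes its CVOpenGL(ES)/Metal texture caches here so
        // that buffers with no remaining GPU work can return to the pool.
      },
      &buffer);
  return err == kCVReturnSuccess ? buffer : NULL;
}

void ExampleUsage() {
  CVPixelBufferPoolRef pool = mediapipe::CreateCVPixelBufferPool(
      1280, 720, kCVPixelFormatType_32BGRA, /*keep_count=*/2, /*maxAge=*/1.0);
  CFDictionaryRef aux_attributes =
      mediapipe::CreateCVPixelBufferPoolAuxiliaryAttributesForThreshold(
          /*allocationThreshold=*/12);
  CVPixelBufferRef buffer = CreateExampleBuffer(pool, aux_attributes);
  if (buffer) {
    // ... render into the buffer, then release it when done ...
    CFRelease(buffer);
  }
  if (aux_attributes) CFRelease(aux_attributes);
  CVPixelBufferPoolRelease(pool);
}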
+ +#include "mediapipe/gpu/pixel_buffer_pool_util.h" + +#import + +#if !defined(ENABLE_MEDIAPIPE_GPU_BUFFER_THRESHOLD_CHECK) && !defined(NDEBUG) +#define ENABLE_MEDIAPIPE_GPU_BUFFER_THRESHOLD_CHECK 1 +#endif // defined(ENABLE_MEDIAPIPE_GPU_BUFFER_THRESHOLD_CHECK) + +namespace mediapipe { + +CVPixelBufferPoolRef CreateCVPixelBufferPool( + int width, int height, OSType pixelFormat, int keepCount, + CFTimeInterval maxAge) { + CVPixelBufferPoolRef pool = NULL; + + NSDictionary *sourcePixelBufferOptions = @{ + (id)kCVPixelBufferPixelFormatTypeKey : @(pixelFormat), + (id)kCVPixelBufferWidthKey : @(width), + (id)kCVPixelBufferHeightKey : @(height), +#if TARGET_OS_OSX + (id)kCVPixelFormatOpenGLCompatibility : @(YES), +#else + (id)kCVPixelFormatOpenGLESCompatibility : @(YES), +#endif // TARGET_OS_OSX + (id)kCVPixelBufferIOSurfacePropertiesKey : @{ /*empty dictionary*/ } + }; + + NSMutableDictionary *pixelBufferPoolOptions = [[NSMutableDictionary alloc] init]; + pixelBufferPoolOptions[(id)kCVPixelBufferPoolMinimumBufferCountKey] = @(keepCount); + if (maxAge > 0) { + pixelBufferPoolOptions[(id)kCVPixelBufferPoolMaximumBufferAgeKey] = @(maxAge); + } + + CVPixelBufferPoolCreate( + kCFAllocatorDefault, (__bridge CFDictionaryRef)pixelBufferPoolOptions, + (__bridge CFDictionaryRef)sourcePixelBufferOptions, &pool); + + return pool; +} + +OSStatus PreallocateCVPixelBufferPoolBuffers( + CVPixelBufferPoolRef pool, int count, CFDictionaryRef auxAttributes) { + CVReturn err = kCVReturnSuccess; + NSMutableArray *pixelBuffers = [[NSMutableArray alloc] init]; + for (int i = 0; i < count && err == kCVReturnSuccess; i++) { + CVPixelBufferRef pixelBuffer = NULL; + err = CVPixelBufferPoolCreatePixelBufferWithAuxAttributes( + kCFAllocatorDefault, pool, auxAttributes, &pixelBuffer); + if (err != kCVReturnSuccess) { + break; + } + + [pixelBuffers addObject:(__bridge id)pixelBuffer]; + CFRelease(pixelBuffer); + } + return err; +} + +CFDictionaryRef CreateCVPixelBufferPoolAuxiliaryAttributesForThreshold(int allocationThreshold) { + if (allocationThreshold > 0) { + return (CFDictionaryRef)CFBridgingRetain( + @{(id)kCVPixelBufferPoolAllocationThresholdKey: @(allocationThreshold)}); + } else { + return nil; + } +} + +CVReturn CreateCVPixelBufferWithPool( + CVPixelBufferPoolRef pool, CFDictionaryRef auxAttributes, + CVTextureCacheType textureCache, CVPixelBufferRef* outBuffer) { + return CreateCVPixelBufferWithPool(pool, auxAttributes, [textureCache](){ +#if TARGET_OS_OSX + CVOpenGLTextureCacheFlush(textureCache, 0); +#else + CVOpenGLESTextureCacheFlush(textureCache, 0); +#endif // TARGET_OS_OSX + }, outBuffer); +} + +CVReturn CreateCVPixelBufferWithPool( + CVPixelBufferPoolRef pool, CFDictionaryRef auxAttributes, + std::function flush, CVPixelBufferRef* outBuffer) { + CVReturn err = CVPixelBufferPoolCreatePixelBufferWithAuxAttributes( + kCFAllocatorDefault, pool, auxAttributes, outBuffer); + if (err == kCVReturnWouldExceedAllocationThreshold) { + if (flush) { + // Call the flush function to potentially release the retained buffers + // and try again to create a pixel buffer. + flush(); + err = CVPixelBufferPoolCreatePixelBufferWithAuxAttributes( + kCFAllocatorDefault, pool, auxAttributes, outBuffer); + } + if (err == kCVReturnWouldExceedAllocationThreshold) { + // TODO: allow the application to set the threshold. For now, disable it by + // default, since the threshold we are using is arbitrary and some graphs routinely cross it. +#ifdef ENABLE_MEDIAPIPE_GPU_BUFFER_THRESHOLD_CHECK + NSLog(@"Using more buffers than expected! 
This is a debug-only warning, " + "you can ignore it if your app works fine otherwise."); +#ifdef DEBUG + NSLog(@"Pool status: %@", ((__bridge NSObject *)pool).description); +#endif // DEBUG +#endif // defined(ENABLE_MEDIAPIPE_GPU_BUFFER_THRESHOLD_CHECK) + // Try again and ignore threshold. + // TODO drop a frame instead? + err = CVPixelBufferPoolCreatePixelBufferWithAuxAttributes( + kCFAllocatorDefault, pool, NULL, outBuffer); + } + } + return err; +} + +#if TARGET_IPHONE_SIMULATOR +static void FreeRefConReleaseCallback(void* refCon, const void* baseAddress) { + free(refCon); +} +#endif + +CVReturn CreateCVPixelBufferWithoutPool( + int width, int height, OSType pixelFormat, CVPixelBufferRef* outBuffer) { + NSDictionary *attributes = @{ +#if TARGET_OS_OSX + (id)kCVPixelFormatOpenGLCompatibility : @(YES), +#else + (id)kCVPixelFormatOpenGLESCompatibility : @(YES), +#endif // TARGET_OS_OSX + (id)kCVPixelBufferIOSurfacePropertiesKey : @{ /*empty dictionary*/ } + }; +#if TARGET_IPHONE_SIMULATOR + // On the simulator, syncing the texture with the pixelbuffer does not work, + // and we have to use glReadPixels. Since GL_UNPACK_ROW_LENGTH is not + // available in OpenGL ES 2, we should create the buffer so the pixels are + // contiguous. + // + // TODO: verify if we can use kIOSurfaceBytesPerRow to force + // CoreVideo to give us contiguous data. + size_t bytes_per_row = width * 4; + void* data = malloc(bytes_per_row * height); + return CVPixelBufferCreateWithBytes( + kCFAllocatorDefault, width, height, pixelFormat, data, bytes_per_row, + FreeRefConReleaseCallback, data, (__bridge CFDictionaryRef)attributes, + outBuffer); +#else + return CVPixelBufferCreate( + kCFAllocatorDefault, width, height, pixelFormat, + (__bridge CFDictionaryRef)attributes, outBuffer); +#endif +} + +} // namespace mediapipe diff --git a/mediapipe/graphs/edge_detection/BUILD b/mediapipe/graphs/edge_detection/BUILD index e25ad8efb..2f47a3dde 100644 --- a/mediapipe/graphs/edge_detection/BUILD +++ b/mediapipe/graphs/edge_detection/BUILD @@ -1,4 +1,4 @@ -# Copyright 2019 The MediaPipeOSS Authors. +# Copyright 2019 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,14 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) cc_library( - name = "android_calculators", + name = "mobile_calculators", deps = [ "//mediapipe/calculators/image:luminance_calculator", "//mediapipe/calculators/image:sobel_edges_calculator", @@ -31,7 +30,7 @@ load( ) mediapipe_binary_graph( - name = "android_gpu_binary_graph", - graph = "edge_detection_android_gpu.pbtxt", - output_name = "android_gpu.binarypb", + name = "mobile_gpu_binary_graph", + graph = "edge_detection_mobile_gpu.pbtxt", + output_name = "mobile_gpu.binarypb", ) diff --git a/mediapipe/graphs/edge_detection/edge_detection_android_gpu.pbtxt b/mediapipe/graphs/edge_detection/edge_detection_mobile_gpu.pbtxt similarity index 76% rename from mediapipe/graphs/edge_detection/edge_detection_android_gpu.pbtxt rename to mediapipe/graphs/edge_detection/edge_detection_mobile_gpu.pbtxt index 2195426d4..9b99debb9 100644 --- a/mediapipe/graphs/edge_detection/edge_detection_android_gpu.pbtxt +++ b/mediapipe/graphs/edge_detection/edge_detection_mobile_gpu.pbtxt @@ -1,6 +1,7 @@ -# MediaPipe graph that performs Sobel edge detection on a live video stream on -# GPU. Used in the example in -# mediapipe/examples/android/src/java/com/mediapipe/apps/edgedetectiongpu. +# MediaPipe graph that performs GPU Sobel edge detection on a live video stream. +# Used in the examples in +# mediapipe/examples/android/src/java/com/mediapipe/apps/edgedetectiongpu and +# mediapipe/examples/ios/edgedetectiongpu. # Images coming into and out of the graph. input_stream: "input_video" diff --git a/mediapipe/graphs/face_detection/BUILD b/mediapipe/graphs/face_detection/BUILD index 25e3d16aa..0281f3437 100644 --- a/mediapipe/graphs/face_detection/BUILD +++ b/mediapipe/graphs/face_detection/BUILD @@ -1,4 +1,4 @@ -# Copyright 2019 The MediaPipeOSS Authors. +# Copyright 2019 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,16 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) cc_library( - name = "android_calculators", + name = "mobile_calculators", deps = [ - "//mediapipe/calculators/core:real_time_flow_limiter_calculator", + "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/tflite:ssd_anchors_calculator", "//mediapipe/calculators/tflite:tflite_converter_calculator", @@ -42,37 +41,15 @@ load( ) mediapipe_binary_graph( - name = "android_cpu_binary_graph", - graph = "face_detection_android_cpu.pbtxt", - output_name = "android_cpu.binarypb", - deps = [ - "//mediapipe/calculators/image:image_transformation_calculator_proto", - "//mediapipe/calculators/image:scale_image_calculator_proto", - "//mediapipe/calculators/tflite:ssd_anchors_calculator_proto", - "//mediapipe/calculators/tflite:tflite_converter_calculator_proto", - "//mediapipe/calculators/tflite:tflite_inference_calculator_proto", - "//mediapipe/calculators/tflite:tflite_tensors_to_detections_calculator_proto", - "//mediapipe/calculators/util:annotation_overlay_calculator_proto", - "//mediapipe/calculators/util:detection_label_id_to_text_calculator_proto", - "//mediapipe/calculators/util:detections_to_render_data_calculator_proto", - "//mediapipe/calculators/util:non_max_suppression_calculator_proto", - ], + name = "mobile_cpu_binary_graph", + graph = "face_detection_mobile_cpu.pbtxt", + output_name = "mobile_cpu.binarypb", + deps = [":mobile_calculators"], ) mediapipe_binary_graph( - name = "android_gpu_binary_graph", - graph = "face_detection_android_gpu.pbtxt", - output_name = "android_gpu.binarypb", - deps = [ - "//mediapipe/calculators/image:image_transformation_calculator_proto", - "//mediapipe/calculators/image:scale_image_calculator_proto", - "//mediapipe/calculators/tflite:ssd_anchors_calculator_proto", - "//mediapipe/calculators/tflite:tflite_converter_calculator_proto", - "//mediapipe/calculators/tflite:tflite_inference_calculator_proto", - "//mediapipe/calculators/tflite:tflite_tensors_to_detections_calculator_proto", - "//mediapipe/calculators/util:annotation_overlay_calculator_proto", - "//mediapipe/calculators/util:detection_label_id_to_text_calculator_proto", - "//mediapipe/calculators/util:detections_to_render_data_calculator_proto", - "//mediapipe/calculators/util:non_max_suppression_calculator_proto", - ], + name = "mobile_gpu_binary_graph", + graph = "face_detection_mobile_gpu.pbtxt", + output_name = "mobile_gpu.binarypb", + deps = [":mobile_calculators"], ) diff --git a/mediapipe/graphs/face_detection/face_detection_android_cpu.pbtxt b/mediapipe/graphs/face_detection/face_detection_mobile_cpu.pbtxt similarity index 82% rename from mediapipe/graphs/face_detection/face_detection_android_cpu.pbtxt rename to mediapipe/graphs/face_detection/face_detection_mobile_cpu.pbtxt index 47c1d44c7..d52ef6c5c 100644 --- a/mediapipe/graphs/face_detection/face_detection_android_cpu.pbtxt +++ b/mediapipe/graphs/face_detection/face_detection_mobile_cpu.pbtxt @@ -1,6 +1,7 @@ # MediaPipe graph that performs face detection with TensorFlow Lite on CPU. -# Used in the example in -# mediapipie/examples/android/src/java/com/mediapipe/apps/facedetectioncpu. +# Used in the examples in +# mediapipie/examples/android/src/java/com/mediapipe/apps/facedetectioncpu and +# mediapipie/examples/ios/facedetectioncpu. # Images on GPU coming into and out of the graph. 
input_stream: "input_video" @@ -20,7 +21,7 @@ output_stream: "output_video" # TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy # processing previous inputs. node { - calculator: "RealTimeFlowLimiterCalculator" + calculator: "FlowLimiterCalculator" input_stream: "input_video" input_stream: "FINISHED:detections" input_stream_info: { @@ -57,22 +58,12 @@ node: { } } -# Converts the transformed input image on CPU into an image tensor as a -# TfLiteTensor. The zero_center option is set to true to normalize the -# pixel values to [-1.f, 1.f] as opposed to [0.f, 1.f]. The flip_vertically -# option is set to true to account for the descrepancy between the -# representation of the input image (origin at the bottom-left corner) and what -# the model used in this graph is expecting (origin at the top-left corner). +# Converts the transformed input image on CPU into an image tensor stored as a +# TfLiteTensor. node { calculator: "TfLiteConverterCalculator" input_stream: "IMAGE:transformed_input_video_cpu" output_stream: "TENSORS:image_tensor" - node_options: { - [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] { - zero_center: true - flip_vertically: true - } - } } # Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a @@ -84,7 +75,7 @@ node { output_stream: "TENSORS:detection_tensors" node_options: { [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { - model_path: "facedetector_front.tflite" + model_path: "face_detection_front.tflite" } } } @@ -137,7 +128,7 @@ node { y_scale: 128.0 h_scale: 128.0 w_scale: 128.0 - flip_vertically: true + min_score_thresh: 0.75 } } } @@ -150,9 +141,9 @@ node { node_options: { [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] { min_suppression_threshold: 0.3 - min_score_threshold: 0.75 overlap_type: INTERSECTION_OVER_UNION algorithm: WEIGHTED + return_empty_detections: true } } } @@ -165,7 +156,7 @@ node { output_stream: "labeled_detections" node_options: { [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { - label_map_path: "facedetector_front_labelmap.txt" + label_map_path: "face_detection_front_labelmap.txt" } } } @@ -184,7 +175,7 @@ node { # Converts the detections to drawing primitives for annotation overlay. node { calculator: "DetectionsToRenderDataCalculator" - input_stream: "DETECTION_VECTOR:output_detections" + input_stream: "DETECTIONS:output_detections" output_stream: "RENDER_DATA:render_data" node_options: { [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { @@ -194,21 +185,12 @@ node { } } -# Draws annotations and overlays them on top of the CPU copy of the original -# image coming into the graph. The calculator assumes that image origin is -# always at the top-left corner and renders text accordingly. However, the input -# image has its origin at the bottom-left corner (OpenGL convention) and the -# flip_text_vertically option is set to true to compensate that. +# Draws annotations and overlays them on top of the input images. 
node { calculator: "AnnotationOverlayCalculator" input_stream: "INPUT_FRAME:input_video_cpu" input_stream: "render_data" output_stream: "OUTPUT_FRAME:output_video_cpu" - node_options: { - [type.googleapis.com/mediapipe.AnnotationOverlayCalculatorOptions] { - flip_text_vertically: true - } - } } # Transfers the annotated image from CPU back to GPU memory, to be sent out of diff --git a/mediapipe/graphs/face_detection/face_detection_android_gpu.pbtxt b/mediapipe/graphs/face_detection/face_detection_mobile_gpu.pbtxt similarity index 78% rename from mediapipe/graphs/face_detection/face_detection_android_gpu.pbtxt rename to mediapipe/graphs/face_detection/face_detection_mobile_gpu.pbtxt index 067e10cb3..e12787d5b 100644 --- a/mediapipe/graphs/face_detection/face_detection_android_gpu.pbtxt +++ b/mediapipe/graphs/face_detection/face_detection_mobile_gpu.pbtxt @@ -1,6 +1,7 @@ # MediaPipe graph that performs face detection with TensorFlow Lite on GPU. -# Used in the example in -# mediapipie/examples/android/src/java/com/mediapipe/apps/facedetectiongpu. +# Used in the examples in +# mediapipie/examples/android/src/java/com/mediapipe/apps/facedetectiongpu and +# mediapipie/examples/ios/facedetectiongpu. # Images on GPU coming into and out of the graph. input_stream: "input_video" @@ -20,7 +21,7 @@ output_stream: "output_video" # TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy # processing previous inputs. node { - calculator: "RealTimeFlowLimiterCalculator" + calculator: "FlowLimiterCalculator" input_stream: "input_video" input_stream: "FINISHED:detections" input_stream_info: { @@ -47,23 +48,12 @@ node: { } } -# Converts the transformed input image on GPU into an image tensor stored in -# tflite::gpu::GlBuffer. The zero_center option is set to true to normalize the -# pixel values to [-1.f, 1.f] as opposed to [0.f, 1.f]. The flip_vertically -# option is set to true to account for the descrepancy between the -# representation of the input image (origin at the bottom-left corner, the -# OpenGL convention) and what the model used in this graph is expecting (origin -# at the top-left corner). +# Converts the transformed input image on GPU into an image tensor stored as a +# TfLiteTensor. node { calculator: "TfLiteConverterCalculator" input_stream: "IMAGE_GPU:transformed_input_video" output_stream: "TENSORS_GPU:image_tensor" - node_options: { - [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] { - zero_center: true - flip_vertically: true - } - } } # Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a @@ -72,10 +62,10 @@ node { node { calculator: "TfLiteInferenceCalculator" input_stream: "TENSORS_GPU:image_tensor" - output_stream: "TENSORS_GPU:detection_tensors" + output_stream: "TENSORS:detection_tensors" node_options: { [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { - model_path: "facedetector_front.tflite" + model_path: "face_detection_front.tflite" } } } @@ -109,7 +99,7 @@ node { # detections. Each detection describes a detected object. 
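Two steps follow below: the raw tensors are decoded into detections (now filtered by min_score_thresh at that stage), and overlapping detections are then pruned with IoU-based non-max suppression. A small, self-contained sketch of IoU and a greedy suppression pass; the graph's NonMaxSuppressionCalculator is configured with the WEIGHTED algorithm, which roughly speaking blends overlapping boxes rather than simply discarding them, so this shows only the basic idea:

def iou(a, b):
    """IoU of two boxes given as (xmin, ymin, xmax, ymax)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

def greedy_nms(boxes_scores, min_suppression_threshold=0.3):
    """Keeps the highest-scoring box of each overlapping cluster."""
    kept = []
    for box, score in sorted(boxes_scores, key=lambda bs: bs[1], reverse=True):
        if all(iou(box, k) <= min_suppression_threshold for k, _ in kept):
            kept.append((box, score))
    return kept

dets = [((0.10, 0.1, 0.40, 0.4), 0.9), ((0.12, 0.1, 0.42, 0.4), 0.8), ((0.6, 0.6, 0.9, 0.9), 0.7)]
print(greedy_nms(dets))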
node { calculator: "TfLiteTensorsToDetectionsCalculator" - input_stream: "TENSORS_GPU:detection_tensors" + input_stream: "TENSORS:detection_tensors" input_side_packet: "ANCHORS:anchors" output_stream: "DETECTIONS:detections" node_options: { @@ -128,7 +118,7 @@ node { y_scale: 128.0 h_scale: 128.0 w_scale: 128.0 - flip_vertically: true + min_score_thresh: 0.75 } } } @@ -141,9 +131,9 @@ node { node_options: { [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] { min_suppression_threshold: 0.3 - min_score_threshold: 0.75 overlap_type: INTERSECTION_OVER_UNION algorithm: WEIGHTED + return_empty_detections: true } } } @@ -156,7 +146,7 @@ node { output_stream: "labeled_detections" node_options: { [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { - label_map_path: "facedetector_front_labelmap.txt" + label_map_path: "face_detection_front_labelmap.txt" } } } @@ -175,31 +165,20 @@ node { # Converts the detections to drawing primitives for annotation overlay. node { calculator: "DetectionsToRenderDataCalculator" - input_stream: "DETECTION_VECTOR:output_detections" + input_stream: "DETECTIONS:output_detections" output_stream: "RENDER_DATA:render_data" node_options: { [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { - thickness: 8.0 + thickness: 10.0 color { r: 255 g: 0 b: 0 } } } } -# Draws annotations and overlays them on top of the original image coming into -# the graph. Annotation drawing is performed on CPU, and the result is -# transferred to GPU and overlaid on the input image. The calculator assumes -# that image origin is always at the top-left corner and renders text -# accordingly. However, the input image has its origin at the bottom-left corner -# (OpenGL convention) and the flip_text_vertically option is set to true to -# compensate that. +# Draws annotations and overlays them on top of the input images. node { calculator: "AnnotationOverlayCalculator" input_stream: "INPUT_FRAME_GPU:throttled_input_video" input_stream: "render_data" output_stream: "OUTPUT_FRAME_GPU:output_video" - node_options: { - [type.googleapis.com/mediapipe.AnnotationOverlayCalculatorOptions] { - flip_text_vertically: true - } - } } diff --git a/mediapipe/graphs/hair_segmentation/BUILD b/mediapipe/graphs/hair_segmentation/BUILD index d2abbaa9d..eec0732e3 100644 --- a/mediapipe/graphs/hair_segmentation/BUILD +++ b/mediapipe/graphs/hair_segmentation/BUILD @@ -1,4 +1,4 @@ -# Copyright 2019 The MediaPipeOSS Authors. +# Copyright 2019 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,17 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) cc_library( - name = "android_calculators", + name = "mobile_calculators", deps = [ + "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/core:previous_loopback_calculator", - "//mediapipe/calculators/core:real_time_flow_limiter_calculator", "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/image:recolor_calculator", "//mediapipe/calculators/image:set_alpha_calculator", @@ -29,6 +28,8 @@ cc_library( "//mediapipe/calculators/tflite:tflite_custom_op_resolver_calculator", "//mediapipe/calculators/tflite:tflite_inference_calculator", "//mediapipe/calculators/tflite:tflite_tensors_to_segmentation_calculator", + "//mediapipe/gpu:gpu_buffer_to_image_frame_calculator", + "//mediapipe/gpu:image_frame_to_gpu_buffer_calculator", ], ) @@ -38,16 +39,8 @@ load( ) mediapipe_binary_graph( - name = "android_gpu_binary_graph", - graph = "hair_segmentation_android_gpu.pbtxt", - output_name = "android_gpu.binarypb", - deps = [ - "//mediapipe/calculators/image:image_transformation_calculator_proto", - "//mediapipe/calculators/image:recolor_calculator_proto", - "//mediapipe/calculators/image:set_alpha_calculator_proto", - "//mediapipe/calculators/tflite:tflite_converter_calculator_proto", - "//mediapipe/calculators/tflite:tflite_custom_op_resolver_calculator_proto", - "//mediapipe/calculators/tflite:tflite_inference_calculator_proto", - "//mediapipe/calculators/tflite:tflite_tensors_to_segmentation_calculator_proto", - ], + name = "mobile_gpu_binary_graph", + graph = "hair_segmentation_mobile_gpu.pbtxt", + output_name = "mobile_gpu.binarypb", + deps = [":mobile_calculators"], ) diff --git a/mediapipe/graphs/hair_segmentation/hair_segmentation_android_gpu.pbtxt b/mediapipe/graphs/hair_segmentation/hair_segmentation_mobile_gpu.pbtxt similarity index 82% rename from mediapipe/graphs/hair_segmentation/hair_segmentation_android_gpu.pbtxt rename to mediapipe/graphs/hair_segmentation/hair_segmentation_mobile_gpu.pbtxt index 5b06e694c..ed5d0ada4 100644 --- a/mediapipe/graphs/hair_segmentation/hair_segmentation_android_gpu.pbtxt +++ b/mediapipe/graphs/hair_segmentation/hair_segmentation_mobile_gpu.pbtxt @@ -20,7 +20,7 @@ output_stream: "output_video" # TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy # processing previous inputs. node { - calculator: "RealTimeFlowLimiterCalculator" + calculator: "FlowLimiterCalculator" input_stream: "input_video" input_stream: "FINISHED:hair_mask" input_stream_info: { @@ -47,14 +47,11 @@ node: { } } -# Waits for a mask from the previous round of hair segmentation to be fed back -# as an input, and caches it. Upon the arrival of an input image, it checks if -# there is a mask cached, and sends out the mask with the timestamp replaced by -# that of the input image. This is needed so that the "current image" and the -# "previous mask" share the same timestamp, and as a result can be synchronized -# and combined in the subsequent calculator. Note that upon the arrival of the -# very first input frame, an empty packet is sent out to jump start the feedback -# loop. +# Caches a mask fed back from the previous round of hair segmentation, and upon +# the arrival of the next input image sends out the cached mask with the +# timestamp replaced by that of the input image, essentially generating a packet +# that carries the previous mask. 
Note that upon the arrival of the very first +# input image, an empty packet is sent out to jump start the feedback loop. node { calculator: "PreviousLoopbackCalculator" input_stream: "MAIN:throttled_input_video" @@ -77,12 +74,9 @@ node { # Converts the transformed input image on GPU into an image tensor stored in # tflite::gpu::GlBuffer. The zero_center option is set to false to normalize the -# pixel values to [0.f, 1.f] as opposed to [-1.f, 1.f]. The flip_vertically -# option is set to true to account for the descrepancy between the -# representation of the input image (origin at the bottom-left corner, the -# OpenGL convention) and what the model used in this graph is expecting (origin -# at the top-left corner). With the max_num_channels option set to 4, all 4 RGBA -# channels are contained in the image tensor. +# pixel values to [0.f, 1.f] as opposed to [-1.f, 1.f]. With the +# max_num_channels option set to 4, all 4 RGBA channels are contained in the +# image tensor. node { calculator: "TfLiteConverterCalculator" input_stream: "IMAGE_GPU:mask_embedded_input_video" @@ -90,7 +84,6 @@ node { node_options: { [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] { zero_center: false - flip_vertically: true max_num_channels: 4 } } diff --git a/mediapipe/graphs/hand_tracking/BUILD b/mediapipe/graphs/hand_tracking/BUILD new file mode 100644 index 000000000..73c5a6ce3 --- /dev/null +++ b/mediapipe/graphs/hand_tracking/BUILD @@ -0,0 +1,113 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + +load( + "//mediapipe/framework/tool:mediapipe_graph.bzl", + "mediapipe_binary_graph", + "mediapipe_simple_subgraph", +) + +mediapipe_simple_subgraph( + name = "hand_detection_gpu", + graph = "hand_detection_gpu.pbtxt", + register_as = "HandDetectionSubgraph", + deps = [ + "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/calculators/image:image_transformation_calculator", + "//mediapipe/calculators/tflite:ssd_anchors_calculator", + "//mediapipe/calculators/tflite:tflite_converter_calculator", + "//mediapipe/calculators/tflite:tflite_custom_op_resolver_calculator", + "//mediapipe/calculators/tflite:tflite_inference_calculator", + "//mediapipe/calculators/tflite:tflite_tensors_to_detections_calculator", + "//mediapipe/calculators/util:detection_label_id_to_text_calculator", + "//mediapipe/calculators/util:detection_letterbox_removal_calculator", + "//mediapipe/calculators/util:detections_to_rects_calculator", + "//mediapipe/calculators/util:non_max_suppression_calculator", + "//mediapipe/calculators/util:rect_transformation_calculator", + ], +) + +mediapipe_simple_subgraph( + name = "hand_landmark_gpu", + graph = "hand_landmark_gpu.pbtxt", + register_as = "HandLandmarkSubgraph", + deps = [ + "//mediapipe/calculators/core:split_vector_calculator", + "//mediapipe/calculators/image:image_cropping_calculator", + "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/calculators/image:image_transformation_calculator", + "//mediapipe/calculators/tflite:tflite_converter_calculator", + "//mediapipe/calculators/tflite:tflite_inference_calculator", + "//mediapipe/calculators/tflite:tflite_tensors_to_floats_calculator", + "//mediapipe/calculators/tflite:tflite_tensors_to_landmarks_calculator", + "//mediapipe/calculators/util:detections_to_rects_calculator", + "//mediapipe/calculators/util:landmark_letterbox_removal_calculator", + "//mediapipe/calculators/util:landmark_projection_calculator", + "//mediapipe/calculators/util:landmarks_to_detection_calculator", + "//mediapipe/calculators/util:rect_transformation_calculator", + "//mediapipe/calculators/util:thresholding_calculator", + ], +) + +mediapipe_simple_subgraph( + name = "renderer_gpu", + graph = "renderer_gpu.pbtxt", + register_as = "RendererSubgraph", + deps = [ + "//mediapipe/calculators/util:annotation_overlay_calculator", + "//mediapipe/calculators/util:detections_to_render_data_calculator", + "//mediapipe/calculators/util:landmarks_to_render_data_calculator", + "//mediapipe/calculators/util:rect_to_render_data_calculator", + ], +) + +cc_library( + name = "mobile_calculators", + deps = [ + ":hand_detection_gpu", + ":hand_landmark_gpu", + ":renderer_gpu", + "//mediapipe/calculators/core:flow_limiter_calculator", + "//mediapipe/calculators/core:gate_calculator", + "//mediapipe/calculators/core:merge_calculator", + "//mediapipe/calculators/core:previous_loopback_calculator", + ], +) + +mediapipe_binary_graph( + name = "hand_tracking_mobile_gpu_binary_graph", + graph = "hand_tracking_mobile.pbtxt", + output_name = "hand_tracking_mobile_gpu.binarypb", + deps = [":mobile_calculators"], +) + +cc_library( + name = "detection_mobile_calculators", + deps = [ + ":hand_detection_gpu", + ":renderer_gpu", + "//mediapipe/calculators/core:flow_limiter_calculator", + ], +) + +mediapipe_binary_graph( + name = "hand_detection_mobile_gpu_binary_graph", + graph = "hand_detection_mobile.pbtxt", + output_name = 
"hand_detection_mobile_gpu.binarypb", + deps = [":detection_mobile_calculators"], +) diff --git a/mediapipe/graphs/hand_tracking/hand_detection_gpu.pbtxt b/mediapipe/graphs/hand_tracking/hand_detection_gpu.pbtxt new file mode 100644 index 000000000..848bacb9f --- /dev/null +++ b/mediapipe/graphs/hand_tracking/hand_detection_gpu.pbtxt @@ -0,0 +1,197 @@ +# MediaPipe hand detection subgraph. + +type: "HandDetectionSubgraph" + +input_stream: "input_video" +output_stream: "DETECTIONS:palm_detections" +output_stream: "NORM_RECT:hand_rect_from_palm_detections" + +# Transforms the input image on GPU to a 256x256 image. To scale the input +# image, the scale_mode option is set to FIT to preserve the aspect ratio, +# resulting in potential letterboxing in the transformed image. +node: { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE_GPU:input_video" + output_stream: "IMAGE_GPU:transformed_input_video" + output_stream: "LETTERBOX_PADDING:letterbox_padding" + node_options: { + [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { + output_width: 256 + output_height: 256 + scale_mode: FIT + } + } +} + +# Generates a single side packet containing a TensorFlow Lite op resolver that +# supports custom ops needed by the model used in this graph. +node { + calculator: "TfLiteCustomOpResolverCalculator" + output_side_packet: "opresolver" + node_options: { + [type.googleapis.com/mediapipe.TfLiteCustomOpResolverCalculatorOptions] { + use_gpu: true + } + } +} + +# Converts the transformed input image on GPU into an image tensor stored as a +# TfLiteTensor. +node { + calculator: "TfLiteConverterCalculator" + input_stream: "IMAGE_GPU:transformed_input_video" + output_stream: "TENSORS_GPU:image_tensor" +} + +# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a +# vector of tensors representing, for instance, detection boxes/keypoints and +# scores. +node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS_GPU:image_tensor" + output_stream: "TENSORS:detection_tensors" + input_side_packet: "CUSTOM_OP_RESOLVER:opresolver" + node_options: { + [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { + model_path: "palm_detection.tflite" + use_gpu: true + } + } +} + +# Generates a single side packet containing a vector of SSD anchors based on +# the specification in the options. +node { + calculator: "SsdAnchorsCalculator" + output_side_packet: "anchors" + node_options: { + [type.googleapis.com/mediapipe.SsdAnchorsCalculatorOptions] { + num_layers: 5 + min_scale: 0.1171875 + max_scale: 0.75 + input_size_height: 256 + input_size_width: 256 + anchor_offset_x: 0.5 + anchor_offset_y: 0.5 + strides: 8 + strides: 16 + strides: 32 + strides: 32 + strides: 32 + aspect_ratios: 1.0 + fixed_anchor_size: true + } + } +} + +# Decodes the detection tensors generated by the TensorFlow Lite model, based on +# the SSD anchors and the specification in the options, into a vector of +# detections. Each detection describes a detected object. 
+node { + calculator: "TfLiteTensorsToDetectionsCalculator" + input_stream: "TENSORS:detection_tensors" + input_side_packet: "ANCHORS:anchors" + output_stream: "DETECTIONS:detections" + node_options: { + [type.googleapis.com/mediapipe.TfLiteTensorsToDetectionsCalculatorOptions] { + num_classes: 1 + num_boxes: 2944 + num_coords: 18 + box_coord_offset: 0 + keypoint_coord_offset: 4 + num_keypoints: 7 + num_values_per_keypoint: 2 + sigmoid_score: true + score_clipping_thresh: 100.0 + reverse_output_order: true + + x_scale: 256.0 + y_scale: 256.0 + h_scale: 256.0 + w_scale: 256.0 + min_score_thresh: 0.7 + } + } +} + +# Performs non-max suppression to remove excessive detections. +node { + calculator: "NonMaxSuppressionCalculator" + input_stream: "detections" + output_stream: "filtered_detections" + node_options: { + [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] { + min_suppression_threshold: 0.3 + overlap_type: INTERSECTION_OVER_UNION + algorithm: WEIGHTED + return_empty_detections: true + } + } +} + +# Maps detection label IDs to the corresponding label text ("Palm"). The label +# map is provided in the label_map_path option. +node { + calculator: "DetectionLabelIdToTextCalculator" + input_stream: "filtered_detections" + output_stream: "labeled_detections" + node_options: { + [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { + label_map_path: "palm_detection_labelmap.txt" + } + } +} + +# Adjusts detection locations (already normalized to [0.f, 1.f]) on the +# letterboxed image (after image transformation with the FIT scale mode) to the +# corresponding locations on the same image with the letterbox removed (the +# input image to the graph before image transformation). +node { + calculator: "DetectionLetterboxRemovalCalculator" + input_stream: "DETECTIONS:labeled_detections" + input_stream: "LETTERBOX_PADDING:letterbox_padding" + output_stream: "DETECTIONS:palm_detections" +} + +# Extracts image size from the input images. +node { + calculator: "ImagePropertiesCalculator" + input_stream: "IMAGE_GPU:input_video" + output_stream: "SIZE:image_size" +} + +# Converts results of palm detection into a rectangle (normalized by image size) +# that encloses the palm and is rotated such that the line connecting center of +# the wrist and MCP of the middle finger is aligned with the Y-axis of the +# rectangle. +node { + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTIONS:palm_detections" + input_stream: "IMAGE_SIZE:image_size" + output_stream: "NORM_RECT:palm_rect" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRectsCalculatorOptions] { + rotation_vector_start_keypoint_index: 0 # Center of wrist. + rotation_vector_end_keypoint_index: 2 # MCP of middle finger. + rotation_vector_target_angle_degrees: 90 + output_zero_rect_for_empty_detections: true + } + } +} + +# Expands and shifts the rectangle that contains the palm so that it's likely +# to cover the entire hand. 
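The node below only rescales and re-centers a normalized rectangle. A sketch of the intended geometry, under the assumption that shift_y is expressed as a fraction of the rect height and that square_long grows the shorter side to match the longer one; the real calculator also handles the rect's rotation, which is omitted here:

from dataclasses import dataclass

@dataclass
class NormRect:
    x_center: float
    y_center: float
    width: float
    height: float

def transform_rect(r: NormRect, scale_x=2.6, scale_y=2.6, shift_y=-0.5, square_long=True) -> NormRect:
    """Expands the palm rect so it is likely to cover the entire hand."""
    out = NormRect(r.x_center, r.y_center + shift_y * r.height,
                   r.width * scale_x, r.height * scale_y)
    if square_long:
        out.width = out.height = max(out.width, out.height)
    return out

print(transform_rect(NormRect(x_center=0.5, y_center=0.6, width=0.2, height=0.25)))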
+node { + calculator: "RectTransformationCalculator" + input_stream: "NORM_RECT:palm_rect" + input_stream: "IMAGE_SIZE:image_size" + output_stream: "hand_rect_from_palm_detections" + node_options: { + [type.googleapis.com/mediapipe.RectTransformationCalculatorOptions] { + scale_x: 2.6 + scale_y: 2.6 + shift_y: -0.5 + square_long: true + } + } +} diff --git a/mediapipe/graphs/hand_tracking/hand_detection_mobile.pbtxt b/mediapipe/graphs/hand_tracking/hand_detection_mobile.pbtxt new file mode 100644 index 000000000..f5431c89d --- /dev/null +++ b/mediapipe/graphs/hand_tracking/hand_detection_mobile.pbtxt @@ -0,0 +1,73 @@ +# MediaPipe graph that performs hand detection with TensorFlow Lite on GPU. +# Used in the examples in +# mediapipie/examples/android/src/java/com/mediapipe/apps/handdetectiongpu and +# mediapipie/examples/ios/handdetectiongpu. + +# Images coming into and out of the graph. +input_stream: "input_video" +output_stream: "output_video" + +# Throttles the images flowing downstream for flow control. It passes through +# the very first incoming image unaltered, and waits for HandDetectionSubgraph +# downstream in the graph to finish its tasks before it passes through another +# image. All images that come in while waiting are dropped, limiting the number +# of in-flight images in HandDetectionSubgraph to 1. This prevents the nodes in +# HandDetectionSubgraph from queuing up incoming images and data excessively, +# which leads to increased latency and memory usage, unwanted in real-time +# mobile applications. It also eliminates unnecessarily computation, e.g., the +# output produced by a node in the subgraph may get dropped downstream if the +# subsequent nodes are still busy processing previous inputs. +node { + calculator: "FlowLimiterCalculator" + input_stream: "input_video" + input_stream: "FINISHED:hand_rect_from_palm_detections" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_input_video" +} + +# Subgraph that detections hands (see hand_detection_gpu.pbtxt). +node { + calculator: "HandDetectionSubgraph" + input_stream: "throttled_input_video" + output_stream: "DETECTIONS:palm_detections" + output_stream: "NORM_RECT:hand_rect_from_palm_detections" +} + +# Converts detections to drawing primitives for annotation overlay. +node { + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTIONS:palm_detections" + output_stream: "RENDER_DATA:detection_render_data" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { + thickness: 4.0 + color { r: 0 g: 255 b: 0 } + } + } +} + +# Converts normalized rects to drawing primitives for annotation overlay. +node { + calculator: "RectToRenderDataCalculator" + input_stream: "NORM_RECT:hand_rect_from_palm_detections" + output_stream: "RENDER_DATA:rect_render_data" + node_options: { + [type.googleapis.com/mediapipe.RectToRenderDataCalculatorOptions] { + filled: false + color { r: 255 g: 0 b: 0 } + thickness: 4.0 + } + } +} + +# Draws annotations and overlays them on top of the input images. 
+node { + calculator: "AnnotationOverlayCalculator" + input_stream: "INPUT_FRAME_GPU:throttled_input_video" + input_stream: "detection_render_data" + input_stream: "rect_render_data" + output_stream: "OUTPUT_FRAME_GPU:output_video" +} diff --git a/mediapipe/graphs/hand_tracking/hand_landmark_gpu.pbtxt b/mediapipe/graphs/hand_tracking/hand_landmark_gpu.pbtxt new file mode 100644 index 000000000..467abd4c5 --- /dev/null +++ b/mediapipe/graphs/hand_tracking/hand_landmark_gpu.pbtxt @@ -0,0 +1,175 @@ +# MediaPipe hand landmark localization subgraph. + +type: "HandLandmarkSubgraph" + +input_stream: "IMAGE:input_video" +input_stream: "NORM_RECT:hand_rect" +output_stream: "LANDMARKS:hand_landmarks" +output_stream: "NORM_RECT:hand_rect_for_next_frame" +output_stream: "PRESENCE:hand_presence" + +# Crops the rectangle that contains a hand from the input image. +node { + calculator: "ImageCroppingCalculator" + input_stream: "IMAGE_GPU:input_video" + input_stream: "NORM_RECT:hand_rect" + output_stream: "IMAGE_GPU:hand_image" +} + +# Transforms the input image on GPU to a 256x256 image. To scale the input +# image, the scale_mode option is set to FIT to preserve the aspect ratio, +# resulting in potential letterboxing in the transformed image. +node: { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE_GPU:hand_image" + output_stream: "IMAGE_GPU:transformed_hand_image" + output_stream: "LETTERBOX_PADDING:letterbox_padding" + node_options: { + [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { + output_width: 256 + output_height: 256 + scale_mode: FIT + } + } +} + +# Converts the transformed input image on GPU into an image tensor stored as a +# TfLiteTensor. +node { + calculator: "TfLiteConverterCalculator" + input_stream: "IMAGE_GPU:transformed_hand_image" + output_stream: "TENSORS_GPU:image_tensor" +} + +# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a +# vector of tensors representing, for instance, detection boxes/keypoints and +# scores. +node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS_GPU:image_tensor" + output_stream: "TENSORS:output_tensors" + node_options: { + [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { + model_path: "hand_landmark.tflite" + use_gpu: true + } + } +} + +# Splits a vector of tensors into multiple vectors. +node { + calculator: "SplitTfLiteTensorVectorCalculator" + input_stream: "output_tensors" + output_stream: "landmark_tensors" + output_stream: "hand_flag_tensor" + node_options: { + [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { + ranges: { begin: 0 end: 1 } + ranges: { begin: 1 end: 2 } + } + } +} + +# Converts the hand-flag tensor into a float that represents the confidence +# score of hand presence. +node { + calculator: "TfLiteTensorsToFloatsCalculator" + input_stream: "TENSORS:hand_flag_tensor" + output_stream: "FLOAT:hand_presence_score" +} + +# Applies a threshold to the confidence score to determine whether a hand is +# present. +node { + calculator: "ThresholdingCalculator" + input_stream: "FLOAT:hand_presence_score" + output_stream: "FLAG:hand_presence" + node_options: { + [type.googleapis.com/mediapipe.ThresholdingCalculatorOptions] { + threshold: 0.1 + } + } +} + +# Decodes the landmark tensors into a vector of lanmarks, where the landmark +# coordinates are normalized by the size of the input image to the model. 
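A sketch of the two coordinate fix-ups performed below: raw landmark outputs are in pixels of the 256x256 model input and are divided by that size, and landmarks computed on the letterboxed crop are then re-mapped by removing the padding. The (left, top, right, bottom) padding-fraction convention is an assumption here; this is not the calculators' code:

def normalize_landmarks(raw_xy, input_width=256, input_height=256):
    """Scales pixel coordinates from the model input into the [0, 1] range."""
    return [(x / input_width, y / input_height) for x, y in raw_xy]

def remove_letterbox(landmarks, padding):
    """Undoes FIT-mode letterboxing; padding = (left, top, right, bottom) fractions."""
    left, top, right, bottom = padding
    scale_x = 1.0 - left - right
    scale_y = 1.0 - top - bottom
    return [((x - left) / scale_x, (y - top) / scale_y) for x, y in landmarks]

raw = [(128.0, 64.0), (200.0, 230.0)]  # hypothetical model outputs
print(remove_letterbox(normalize_landmarks(raw), padding=(0.0, 0.125, 0.0, 0.125)))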
+node { + calculator: "TfLiteTensorsToLandmarksCalculator" + input_stream: "TENSORS:landmark_tensors" + output_stream: "NORM_LANDMARKS:landmarks" + node_options: { + [type.googleapis.com/mediapipe.TfLiteTensorsToLandmarksCalculatorOptions] { + num_landmarks: 21 + input_image_width: 256 + input_image_height: 256 + } + } +} + +# Adjusts landmarks (already normalized to [0.f, 1.f]) on the letterboxed hand +# image (after image transformation with the FIT scale mode) to the +# corresponding locations on the same image with the letterbox removed (hand +# image before image transformation). +node { + calculator: "LandmarkLetterboxRemovalCalculator" + input_stream: "LANDMARKS:landmarks" + input_stream: "LETTERBOX_PADDING:letterbox_padding" + output_stream: "LANDMARKS:scaled_landmarks" +} + +# Projects the landmarks from the cropped hand image to the corresponding +# locations on the full image before cropping (input to the graph). +node { + calculator: "LandmarkProjectionCalculator" + input_stream: "NORM_LANDMARKS:scaled_landmarks" + input_stream: "NORM_RECT:hand_rect" + output_stream: "NORM_LANDMARKS:hand_landmarks" +} + +# Extracts image size from the input images. +node { + calculator: "ImagePropertiesCalculator" + input_stream: "IMAGE_GPU:input_video" + output_stream: "SIZE:image_size" +} + +# Converts hand landmarks to a detection that tightly encloses all landmarks. +node { + calculator: "LandmarksToDetectionCalculator" + input_stream: "NORM_LANDMARKS:hand_landmarks" + output_stream: "DETECTION:hand_detection" +} + +# Converts the hand detection into a rectangle (normalized by image size) +# that encloses the hand and is rotated such that the line connecting center of +# the wrist and MCP of the middle finger is aligned with the Y-axis of the +# rectangle. +node { + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTION:hand_detection" + input_stream: "IMAGE_SIZE:image_size" + output_stream: "NORM_RECT:hand_rect_from_landmarks" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRectsCalculatorOptions] { + rotation_vector_start_keypoint_index: 0 # Center of wrist. + rotation_vector_end_keypoint_index: 9 # MCP of middle finger. + rotation_vector_target_angle_degrees: 90 + } + } +} + +# Expands the hand rectangle so that in the next video frame it's likely to +# still contain the hand even with some motion. +node { + calculator: "RectTransformationCalculator" + input_stream: "NORM_RECT:hand_rect_from_landmarks" + input_stream: "IMAGE_SIZE:image_size" + output_stream: "hand_rect_for_next_frame" + node_options: { + [type.googleapis.com/mediapipe.RectTransformationCalculatorOptions] { + scale_x: 1.6 + scale_y: 1.6 + square_long: true + } + } +} diff --git a/mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt b/mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt new file mode 100644 index 000000000..fdc40d163 --- /dev/null +++ b/mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt @@ -0,0 +1,123 @@ +# MediaPipe graph that performs hand tracking with TensorFlow Lite on GPU. +# Used in the examples in +# mediapipie/examples/android/src/java/com/mediapipe/apps/handtrackinggpu and +# mediapipie/examples/ios/handtrackinggpu. + +# Images coming into and out of the graph. +input_stream: "input_video" +output_stream: "output_video" + +# Throttles the images flowing downstream for flow control. 
It passes through +# the very first incoming image unaltered, and waits for downstream nodes +# (calculators and subgraphs) in the graph to finish their tasks before it +# passes through another image. All images that come in while waiting are +# dropped, limiting the number of in-flight images in most part of the graph to +# 1. This prevents the downstream nodes from queuing up incoming images and data +# excessively, which leads to increased latency and memory usage, unwanted in +# real-time mobile applications. It also eliminates unnecessarily computation, +# e.g., the output produced by a node may get dropped downstream if the +# subsequent nodes are still busy processing previous inputs. +node { + calculator: "FlowLimiterCalculator" + input_stream: "input_video" + input_stream: "FINISHED:hand_rect" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_input_video" +} + +# Caches a hand-presence decision fed back from HandLandmarkSubgraph, and upon +# the arrival of the next input image sends out the cached decision with the +# timestamp replaced by that of the input image, essentially generating a packet +# that carries the previous hand-presence decision. Note that upon the arrival +# of the very first input image, an empty packet is sent out to jump start the +# feedback loop. +node { + calculator: "PreviousLoopbackCalculator" + input_stream: "MAIN:throttled_input_video" + input_stream: "LOOP:hand_presence" + input_stream_info: { + tag_index: "LOOP" + back_edge: true + } + output_stream: "PREV_LOOP:prev_hand_presence" +} + +# Drops the incoming image if HandLandmarkSubgraph was able to identify hand +# presence in the previous image. Otherwise, passes the incoming image through +# to trigger a new round of hand detection in HandDetectionSubgraph. +node { + calculator: "GateCalculator" + input_stream: "throttled_input_video" + input_stream: "DISALLOW:prev_hand_presence" + output_stream: "hand_detection_input_video" + + node_options: { + [type.googleapis.com/mediapipe.GateCalculatorOptions] { + empty_packets_as_allow: true + } + } +} + +# Subgraph that detections hands (see hand_detection_gpu.pbtxt). +node { + calculator: "HandDetectionSubgraph" + input_stream: "hand_detection_input_video" + output_stream: "DETECTIONS:palm_detections" + output_stream: "NORM_RECT:hand_rect_from_palm_detections" +} + +# Subgraph that localizes hand landmarks (see hand_landmark_gpu.pbtxt). +node { + calculator: "HandLandmarkSubgraph" + input_stream: "IMAGE:throttled_input_video" + input_stream: "NORM_RECT:hand_rect" + output_stream: "LANDMARKS:hand_landmarks" + output_stream: "NORM_RECT:hand_rect_from_landmarks" + output_stream: "PRESENCE:hand_presence" +} + +# Caches a hand rectangle fed back from HandLandmarkSubgraph, and upon the +# arrival of the next input image sends out the cached rectangle with the +# timestamp replaced by that of the input image, essentially generating a packet +# that carries the previous hand rectangle. Note that upon the arrival of the +# very first input image, an empty packet is sent out to jump start the +# feedback loop. 
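Before the remaining feedback nodes below, it may help to restate the policy that FlowLimiterCalculator, PreviousLoopbackCalculator, GateCalculator and MergeCalculator express as dataflow: palm detection runs only when the previous frame did not yield a hand; otherwise the hand is tracked from the previous landmark-derived rectangle. A plain-Python restatement of that control flow (the callables are hypothetical stand-ins for the two subgraphs):

def track_hands(frames, detect_palm, localize_landmarks):
    """detect_palm(frame) -> rect or None; localize_landmarks(frame, rect) -> (landmarks, rect, present)."""
    prev_present, prev_rect = False, None
    for frame in frames:
        rect = prev_rect if prev_present else detect_palm(frame)  # skip detection while tracking
        if rect is None:
            prev_present, prev_rect = False, None
            yield frame, None
            continue
        landmarks, next_rect, present = localize_landmarks(frame, rect)
        prev_present, prev_rect = present, next_rect
        yield frame, landmarks

demo = track_hands(["f0", "f1", "f2"],
                   detect_palm=lambda f: (0.4, 0.4, 0.2, 0.2),
                   localize_landmarks=lambda f, r: ([(0.5, 0.5)], r, True))
print(list(demo))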
+node { + calculator: "PreviousLoopbackCalculator" + input_stream: "MAIN:throttled_input_video" + input_stream: "LOOP:hand_rect_from_landmarks" + input_stream_info: { + tag_index: "LOOP" + back_edge: true + } + output_stream: "PREV_LOOP:prev_hand_rect_from_landmarks" +} + +# Merges a stream of hand rectangles generated by HandDetectionSubgraph and that +# generated by HandLandmarkSubgraph into a single output stream by selecting +# between one of the two streams. The former is selected if the incoming packet +# is not empty, i.e., hand detection is performed on the current image by +# HandDetectionSubgraph (because HandLandmarkSubgraph could not identify hand +# presence in the previous image). Otherwise, the latter is selected, which is +# never empty because HandLandmarkSubgraph processes all images (that went +# through FlowLimiterCalculator). +node { + calculator: "MergeCalculator" + input_stream: "hand_rect_from_palm_detections" + input_stream: "prev_hand_rect_from_landmarks" + output_stream: "hand_rect" +} + +# Subgraph that renders annotations and overlays them on top of the input +# images (see renderer_gpu.pbtxt). +node { + calculator: "RendererSubgraph" + input_stream: "IMAGE:throttled_input_video" + input_stream: "LANDMARKS:hand_landmarks" + input_stream: "NORM_RECT:hand_rect" + input_stream: "DETECTIONS:palm_detections" + output_stream: "IMAGE:output_video" +} diff --git a/mediapipe/graphs/hand_tracking/renderer_gpu.pbtxt b/mediapipe/graphs/hand_tracking/renderer_gpu.pbtxt new file mode 100644 index 000000000..6635958a1 --- /dev/null +++ b/mediapipe/graphs/hand_tracking/renderer_gpu.pbtxt @@ -0,0 +1,102 @@ +# MediaPipe hand tracking rendering subgraph. + +type: "RendererSubgraph" + +input_stream: "IMAGE:input_image" +input_stream: "DETECTIONS:detections" +input_stream: "LANDMARKS:landmarks" +input_stream: "NORM_RECT:rect" +output_stream: "IMAGE:output_image" + +# Converts detections to drawing primitives for annotation overlay. +node { + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTIONS:detections" + output_stream: "RENDER_DATA:detection_render_data" + node_options: { + [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { + thickness: 4.0 + color { r: 0 g: 255 b: 0 } + } + } +} + +# Converts landmarks to drawing primitives for annotation overlay.
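In the options below, landmark_connections is a flat list of landmark indices that is read in consecutive pairs; each pair is one edge of the hand skeleton. A tiny sketch of turning the flat list into drawable segments:

# Flat index list in the same format as the options below, truncated to the thumb chain.
THUMB_CONNECTIONS = [0, 1, 1, 2, 2, 3, 3, 4]

def connection_pairs(flat):
    """Groups a flat [a0, b0, a1, b1, ...] list into (a, b) edges."""
    return list(zip(flat[0::2], flat[1::2]))

def segments(landmarks, flat):
    """Returns ((x0, y0), (x1, y1)) line segments ready for drawing."""
    return [(landmarks[a], landmarks[b]) for a, b in connection_pairs(flat)]

print(connection_pairs(THUMB_CONNECTIONS))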
+node { + calculator: "LandmarksToRenderDataCalculator" + input_stream: "NORM_LANDMARKS:landmarks" + output_stream: "RENDER_DATA:landmark_render_data" + node_options: { + [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { + landmark_connections: 0 + landmark_connections: 1 + landmark_connections: 1 + landmark_connections: 2 + landmark_connections: 2 + landmark_connections: 3 + landmark_connections: 3 + landmark_connections: 4 + landmark_connections: 0 + landmark_connections: 5 + landmark_connections: 5 + landmark_connections: 6 + landmark_connections: 6 + landmark_connections: 7 + landmark_connections: 7 + landmark_connections: 8 + landmark_connections: 5 + landmark_connections: 9 + landmark_connections: 9 + landmark_connections: 10 + landmark_connections: 10 + landmark_connections: 11 + landmark_connections: 11 + landmark_connections: 12 + landmark_connections: 9 + landmark_connections: 13 + landmark_connections: 13 + landmark_connections: 14 + landmark_connections: 14 + landmark_connections: 15 + landmark_connections: 15 + landmark_connections: 16 + landmark_connections: 13 + landmark_connections: 17 + landmark_connections: 0 + landmark_connections: 17 + landmark_connections: 17 + landmark_connections: 18 + landmark_connections: 18 + landmark_connections: 19 + landmark_connections: 19 + landmark_connections: 20 + landmark_color { r: 255 g: 0 b: 0 } + connection_color { r: 0 g: 255 b: 0 } + thickness: 4.0 + } + } +} + +# Converts normalized rects to drawing primitives for annotation overlay. +node { + calculator: "RectToRenderDataCalculator" + input_stream: "NORM_RECT:rect" + output_stream: "RENDER_DATA:rect_render_data" + node_options: { + [type.googleapis.com/mediapipe.RectToRenderDataCalculatorOptions] { + filled: false + color { r: 255 g: 0 b: 0 } + thickness: 4.0 + } + } +} + +# Draws annotations and overlays them on top of the input images. +node { + calculator: "AnnotationOverlayCalculator" + input_stream: "INPUT_FRAME_GPU:input_image" + input_stream: "detection_render_data" + input_stream: "landmark_render_data" + input_stream: "rect_render_data" + output_stream: "OUTPUT_FRAME_GPU:output_image" +} diff --git a/mediapipe/graphs/media_sequence/BUILD b/mediapipe/graphs/media_sequence/BUILD index 3f3ddddf5..42af89b51 100644 --- a/mediapipe/graphs/media_sequence/BUILD +++ b/mediapipe/graphs/media_sequence/BUILD @@ -1,4 +1,4 @@ -# Copyright 2019 The MediaPipeOSS Authors. +# Copyright 2019 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# licenses(["notice"]) # Apache 2.0 @@ -29,3 +28,20 @@ cc_library( "//mediapipe/calculators/video:opencv_video_decoder_calculator", ], ) + +cc_library( + name = "tvl1_flow_and_rgb_from_file_calculators", + deps = [ + "//mediapipe/calculators/core:packet_inner_join_calculator", + "//mediapipe/calculators/core:packet_resampler_calculator", + "//mediapipe/calculators/core:sequence_shift_calculator", + "//mediapipe/calculators/image:opencv_image_encoder_calculator", + "//mediapipe/calculators/image:scale_image_calculator", + "//mediapipe/calculators/tensorflow:pack_media_sequence_calculator", + "//mediapipe/calculators/tensorflow:string_to_sequence_example_calculator", + "//mediapipe/calculators/tensorflow:unpack_media_sequence_calculator", + "//mediapipe/calculators/video:flow_to_image_calculator", + "//mediapipe/calculators/video:opencv_video_decoder_calculator", + "//mediapipe/calculators/video:tvl1_optical_flow_calculator", + ], +) diff --git a/mediapipe/graphs/media_sequence/clipped_images_from_file_at_24fps.pbtxt b/mediapipe/graphs/media_sequence/clipped_images_from_file_at_24fps.pbtxt index 787d99524..e3c6a5121 100644 --- a/mediapipe/graphs/media_sequence/clipped_images_from_file_at_24fps.pbtxt +++ b/mediapipe/graphs/media_sequence/clipped_images_from_file_at_24fps.pbtxt @@ -1,4 +1,4 @@ -# Copyright 2019 The MediaPipeOSS Authors. +# Copyright 2019 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,9 +25,9 @@ node { input_side_packet: "SEQUENCE_EXAMPLE:parsed_sequence_example" output_side_packet: "DATA_PATH:input_video_path" output_side_packet: "RESAMPLER_OPTIONS:packet_resampler_options" - options { - [mediapipe.UnpackMediaSequenceCalculatorOptions.ext]: { - base_packet_resampler_options { + node_options: { + [type.googleapis.com/mediapipe.UnpackMediaSequenceCalculatorOptions]: { + base_packet_resampler_options: { frame_rate: 24.0 base_timestamp: 0 } @@ -55,7 +55,7 @@ node { calculator: "OpenCvImageEncoderCalculator" input_stream: "sampled_frames" output_stream: "encoded_frames" - node_options { + node_options: { [type.googleapis.com/mediapipe.OpenCvImageEncoderCalculatorOptions]: { quality: 80 } diff --git a/mediapipe/graphs/media_sequence/tvl1_flow_and_rgb_from_file.pbtxt b/mediapipe/graphs/media_sequence/tvl1_flow_and_rgb_from_file.pbtxt new file mode 100644 index 000000000..032fc3659 --- /dev/null +++ b/mediapipe/graphs/media_sequence/tvl1_flow_and_rgb_from_file.pbtxt @@ -0,0 +1,153 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Convert the string input into a decoded SequenceExample. +node { + calculator: "StringToSequenceExampleCalculator" + input_side_packet: "STRING:input_sequence_example" + output_side_packet: "SEQUENCE_EXAMPLE:parsed_sequence_example" +} + +# Unpack the data path and clip timing from the SequenceExample. 
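The options unpacked by the node below drive the PacketResamplerCalculator further down: decoded frames are resampled onto a fixed 25 fps timestamp grid starting at base_timestamp 0. A simplified sketch of that idea, picking for each output tick the latest frame at or before it; the real calculator has more behavior (rounding, jitter options) than this:

def resample(frame_timestamps_us, frame_rate=25.0, base_timestamp_us=0):
    """Returns (output_tick, chosen_input_timestamp) pairs, both in microseconds."""
    if not frame_timestamps_us:
        return []
    period_us = int(1_000_000 / frame_rate)
    picked, idx, tick = [], 0, base_timestamp_us
    while tick <= frame_timestamps_us[-1]:
        while idx + 1 < len(frame_timestamps_us) and frame_timestamps_us[idx + 1] <= tick:
            idx += 1
        if frame_timestamps_us[idx] <= tick:
            picked.append((tick, frame_timestamps_us[idx]))
        tick += period_us
    return picked

print(resample([i * 33_333 for i in range(10)]))  # 30 fps input onto a 25 fps grid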
+node { + calculator: "UnpackMediaSequenceCalculator" + input_side_packet: "SEQUENCE_EXAMPLE:parsed_sequence_example" + output_side_packet: "DATA_PATH:input_video_path" + output_side_packet: "RESAMPLER_OPTIONS:packet_resampler_options" + node_options: { + [type.googleapis.com/mediapipe.UnpackMediaSequenceCalculatorOptions]: { + base_packet_resampler_options: { + frame_rate: 25.0 + base_timestamp: 0 + } + } + } +} + +# Decode the entire video. +node { + calculator: "OpenCvVideoDecoderCalculator" + input_side_packet: "INPUT_FILE_PATH:input_video_path" + output_stream: "VIDEO:decoded_frames" +} + +# Extract the subset of frames we want to keep. +node { + calculator: "PacketResamplerCalculator" + input_stream: "decoded_frames" + output_stream: "sampled_frames" + input_side_packet: "OPTIONS:packet_resampler_options" +} + +# Fit the images into the target size. +node: { + calculator: "ScaleImageCalculator" + input_stream: "sampled_frames" + output_stream: "scaled_frames" + node_options: { + [type.googleapis.com/mediapipe.ScaleImageCalculatorOptions]: { + target_height: 256 + preserve_aspect_ratio: true + } + } +} + +# Shift the the timestamps of packets along a stream. +# With a packet_offset of -1, the first packet will be dropped, the second will +# be output with the timestamp of the first, the third with the timestamp of +# the second, and so on. +node: { + calculator: "SequenceShiftCalculator" + input_stream: "scaled_frames" + output_stream: "shifted_scaled_frames" + node_options: { + [type.googleapis.com/mediapipe.SequenceShiftCalculatorOptions]: { + packet_offset: -1 + } + } +} + +# Join the original input stream and the one that is shifted by one packet. +node: { + calculator: "PacketInnerJoinCalculator" + input_stream: "scaled_frames" + input_stream: "shifted_scaled_frames" + output_stream: "first_frames" + output_stream: "second_frames" +} + +# Compute the forward optical flow. +node { + calculator: "Tvl1OpticalFlowCalculator" + input_stream: "FIRST_FRAME:first_frames" + input_stream: "SECOND_FRAME:second_frames" + output_stream: "FORWARD_FLOW:forward_flow" + max_in_flight: 32 +} + +# Convert an optical flow to be an image frame with 2 channels (v_x and v_y), +# each channel is quantized to 0-255. +node: { + calculator: "FlowToImageCalculator" + input_stream: "forward_flow" + output_stream: "flow_frames" + node_options: { + [type.googleapis.com/mediapipe.FlowToImageCalculatorOptions]: { + min_value: -20.0 + max_value: 20.0 + } + } +} + +# Encode the optical flow images to store in the SequenceExample. +node { + calculator: "OpenCvImageEncoderCalculator" + input_stream: "flow_frames" + output_stream: "encoded_flow_frames" + node_options: { + [type.googleapis.com/mediapipe.OpenCvImageEncoderCalculatorOptions]: { + quality: 100 + } + } +} + +# Encode the rgb images to store in the SequenceExample. +node { + calculator: "OpenCvImageEncoderCalculator" + input_stream: "scaled_frames" + output_stream: "encoded_frames" + node_options: { + [type.googleapis.com/mediapipe.OpenCvImageEncoderCalculatorOptions]: { + quality: 100 + } + } +} + +# Store the images in the SequenceExample. +node { + calculator: "PackMediaSequenceCalculator" + input_stream: "IMAGE:encoded_frames" + input_stream: "FORWARD_FLOW_ENCODED:encoded_flow_frames" + input_side_packet: "SEQUENCE_EXAMPLE:parsed_sequence_example" + output_side_packet: "SEQUENCE_EXAMPLE:sequence_example_to_serialize" +} + +# Serialize the SequenceExample to a string for storage. 
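Two details of the pipeline above are easy to miss: shifting the scaled frame stream by one packet and inner-joining it with the unshifted stream yields (frame t, frame t+1) pairs for the optical-flow node, and FlowToImageCalculator linearly quantizes each flow component from [-20, 20] into an 8-bit channel. A compact sketch of both, not the calculators' code:

def consecutive_pairs(frames):
    """SequenceShift(packet_offset=-1) plus an inner join == adjacent-frame pairs."""
    shifted = frames[1:]  # frame t+1 carried at the timestamp of frame t
    return list(zip(frames, shifted))

def quantize_flow(v, min_value=-20.0, max_value=20.0):
    """Maps one flow component (v_x or v_y) into a 0-255 channel value."""
    v = max(min_value, min(max_value, v))
    return int(round((v - min_value) / (max_value - min_value) * 255))

print(consecutive_pairs(["f0", "f1", "f2", "f3"]))
print(quantize_flow(-20.0), quantize_flow(0.0), quantize_flow(20.0))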
+node { + calculator: "StringToSequenceExampleCalculator" + input_side_packet: "SEQUENCE_EXAMPLE:sequence_example_to_serialize" + output_side_packet: "STRING:output_sequence_example" +} + +num_threads: 32 diff --git a/mediapipe/graphs/object_detection/BUILD b/mediapipe/graphs/object_detection/BUILD index 8db651b2e..cd1d1b6be 100644 --- a/mediapipe/graphs/object_detection/BUILD +++ b/mediapipe/graphs/object_detection/BUILD @@ -1,4 +1,4 @@ -# Copyright 2019 The MediaPipeOSS Authors. +# Copyright 2019 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,16 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) cc_library( - name = "android_calculators", + name = "mobile_calculators", deps = [ - "//mediapipe/calculators/core:real_time_flow_limiter_calculator", + "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/tflite:ssd_anchors_calculator", "//mediapipe/calculators/tflite:tflite_converter_calculator", @@ -77,35 +76,15 @@ load( ) mediapipe_binary_graph( - name = "android_cpu_binary_graph", - graph = "object_detection_android_cpu.pbtxt", - output_name = "android_cpu.binarypb", - deps = [ - "//mediapipe/calculators/image:image_transformation_calculator_proto", - "//mediapipe/calculators/tflite:ssd_anchors_calculator_proto", - "//mediapipe/calculators/tflite:tflite_converter_calculator_proto", - "//mediapipe/calculators/tflite:tflite_inference_calculator_proto", - "//mediapipe/calculators/tflite:tflite_tensors_to_detections_calculator_proto", - "//mediapipe/calculators/util:annotation_overlay_calculator_proto", - "//mediapipe/calculators/util:detection_label_id_to_text_calculator_proto", - "//mediapipe/calculators/util:detections_to_render_data_calculator_proto", - "//mediapipe/calculators/util:non_max_suppression_calculator_proto", - ], + name = "mobile_cpu_binary_graph", + graph = "object_detection_mobile_cpu.pbtxt", + output_name = "mobile_cpu.binarypb", + deps = [":mobile_calculators"], ) mediapipe_binary_graph( - name = "android_gpu_binary_graph", - graph = "object_detection_android_gpu.pbtxt", - output_name = "android_gpu.binarypb", - deps = [ - "//mediapipe/calculators/image:image_transformation_calculator_proto", - "//mediapipe/calculators/tflite:ssd_anchors_calculator_proto", - "//mediapipe/calculators/tflite:tflite_converter_calculator_proto", - "//mediapipe/calculators/tflite:tflite_inference_calculator_proto", - "//mediapipe/calculators/tflite:tflite_tensors_to_detections_calculator_proto", - "//mediapipe/calculators/util:annotation_overlay_calculator_proto", - "//mediapipe/calculators/util:detection_label_id_to_text_calculator_proto", - "//mediapipe/calculators/util:detections_to_render_data_calculator_proto", - "//mediapipe/calculators/util:non_max_suppression_calculator_proto", - ], + name = "mobile_gpu_binary_graph", + graph = "object_detection_mobile_gpu.pbtxt", + output_name = "mobile_gpu.binarypb", + deps = [":mobile_calculators"], ) diff --git a/mediapipe/graphs/object_detection/object_detection_desktop_tensorflow_graph.pbtxt b/mediapipe/graphs/object_detection/object_detection_desktop_tensorflow_graph.pbtxt index fa384f805..b55290c10 100644 --- 
a/mediapipe/graphs/object_detection/object_detection_desktop_tensorflow_graph.pbtxt +++ b/mediapipe/graphs/object_detection/object_detection_desktop_tensorflow_graph.pbtxt @@ -96,7 +96,7 @@ node { # Converts the detections to drawing primitives for annotation overlay. node { calculator: "DetectionsToRenderDataCalculator" - input_stream: "DETECTION_VECTOR:output_detections" + input_stream: "DETECTIONS:output_detections" output_stream: "RENDER_DATA:render_data" node_options: { [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { @@ -106,8 +106,7 @@ node { } } -# Draws annotations and overlays them on top of the original image coming into -# the graph. +# Draws annotations and overlays them on top of the input images. node { calculator: "AnnotationOverlayCalculator" input_stream: "INPUT_FRAME:input_video" diff --git a/mediapipe/graphs/object_detection/object_detection_desktop_tflite_graph.pbtxt b/mediapipe/graphs/object_detection/object_detection_desktop_tflite_graph.pbtxt index bd0f2b581..575d933a8 100644 --- a/mediapipe/graphs/object_detection/object_detection_desktop_tflite_graph.pbtxt +++ b/mediapipe/graphs/object_detection/object_detection_desktop_tflite_graph.pbtxt @@ -146,7 +146,7 @@ node { # Converts the detections to drawing primitives for annotation overlay. node { calculator: "DetectionsToRenderDataCalculator" - input_stream: "DETECTION_VECTOR:output_detections" + input_stream: "DETECTIONS:output_detections" output_stream: "RENDER_DATA:render_data" node_options: { [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { @@ -156,8 +156,7 @@ node { } } -# Draws annotations and overlays them on top of the original image coming into -# the graph. +# Draws annotations and overlays them on top of the input images. node { calculator: "AnnotationOverlayCalculator" input_stream: "INPUT_FRAME:input_video" diff --git a/mediapipe/graphs/object_detection/object_detection_android_cpu.pbtxt b/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt similarity index 83% rename from mediapipe/graphs/object_detection/object_detection_android_cpu.pbtxt rename to mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt index 92207d13e..4eb527a3c 100644 --- a/mediapipe/graphs/object_detection/object_detection_android_cpu.pbtxt +++ b/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt @@ -1,6 +1,7 @@ # MediaPipe graph that performs object detection with TensorFlow Lite on CPU. -# Used in the example in -# mediapipie/examples/android/src/java/com/mediapipe/apps/objectdetectioncpu. +# Used in the examples in +# mediapipie/examples/android/src/java/com/mediapipe/apps/objectdetectioncpu and +# mediapipie/examples/ios/objectdetectioncpu. # Images on GPU coming into and out of the graph. input_stream: "input_video" @@ -30,7 +31,7 @@ node: { # TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy # processing previous inputs. node { - calculator: "RealTimeFlowLimiterCalculator" + calculator: "FlowLimiterCalculator" input_stream: "input_video_cpu" input_stream: "FINISHED:detections" input_stream_info: { @@ -57,22 +58,12 @@ node: { } } -# Converts the transformed input image on CPU into an image tensor as a -# TfLiteTensor. The zero_center option is set to true to normalize the -# pixel values to [-1.f, 1.f] as opposed to [0.f, 1.f]. 
The flip_vertically -# option is set to true to account for the descrepancy between the -# representation of the input image (origin at the bottom-left corner) and what -# the model used in this graph is expecting (origin at the top-left corner). +# Converts the transformed input image on CPU into an image tensor stored as a +# TfLiteTensor. node { calculator: "TfLiteConverterCalculator" input_stream: "IMAGE:transformed_input_video_cpu" output_stream: "TENSORS:image_tensor" - node_options: { - [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] { - zero_center: true - flip_vertically: true - } - } } # Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a @@ -139,7 +130,7 @@ node { y_scale: 10.0 h_scale: 5.0 w_scale: 5.0 - flip_vertically: true + min_score_thresh: 0.6 } } } @@ -152,9 +143,9 @@ node { node_options: { [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] { min_suppression_threshold: 0.4 - min_score_threshold: 0.6 max_num_detections: 3 overlap_type: INTERSECTION_OVER_UNION + return_empty_detections: true } } } @@ -175,7 +166,7 @@ node { # Converts the detections to drawing primitives for annotation overlay. node { calculator: "DetectionsToRenderDataCalculator" - input_stream: "DETECTION_VECTOR:output_detections" + input_stream: "DETECTIONS:output_detections" output_stream: "RENDER_DATA:render_data" node_options: { [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { @@ -185,21 +176,12 @@ node { } } -# Draws annotations and overlays them on top of the CPU copy of the original -# image coming into the graph. The calculator assumes that image origin is -# always at the top-left corner and renders text accordingly. However, the input -# image has its origin at the bottom-left corner (OpenGL convention) and the -# flip_text_vertically option is set to true to compensate that. +# Draws annotations and overlays them on top of the input images. node { calculator: "AnnotationOverlayCalculator" input_stream: "INPUT_FRAME:throttled_input_video_cpu" input_stream: "render_data" output_stream: "OUTPUT_FRAME:output_video_cpu" - node_options: { - [type.googleapis.com/mediapipe.AnnotationOverlayCalculatorOptions] { - flip_text_vertically: true - } - } } # Transfers the annotated image from CPU back to GPU memory, to be sent out of diff --git a/mediapipe/graphs/object_detection/object_detection_android_gpu.pbtxt b/mediapipe/graphs/object_detection/object_detection_mobile_gpu.pbtxt similarity index 78% rename from mediapipe/graphs/object_detection/object_detection_android_gpu.pbtxt rename to mediapipe/graphs/object_detection/object_detection_mobile_gpu.pbtxt index 45b865a47..44bf61057 100644 --- a/mediapipe/graphs/object_detection/object_detection_android_gpu.pbtxt +++ b/mediapipe/graphs/object_detection/object_detection_mobile_gpu.pbtxt @@ -1,6 +1,7 @@ # MediaPipe graph that performs object detection with TensorFlow Lite on GPU. -# Used in the example in -# mediapipie/examples/android/src/java/com/mediapipe/apps/objectdetectiongpu. +# Used in the examples in +# mediapipie/examples/android/src/java/com/mediapipe/apps/objectdetectiongpu and +# mediapipie/examples/ios/objectdetectiongpu. # Images on GPU coming into and out of the graph. input_stream: "input_video" @@ -20,7 +21,7 @@ output_stream: "output_video" # TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy # processing previous inputs. 
node { - calculator: "RealTimeFlowLimiterCalculator" + calculator: "FlowLimiterCalculator" input_stream: "input_video" input_stream: "FINISHED:detections" input_stream_info: { @@ -47,23 +48,12 @@ node: { } } -# Converts the transformed input image on GPU into an image tensor stored in -# tflite::gpu::GlBuffer. The zero_center option is set to true to normalize the -# pixel values to [-1.f, 1.f] as opposed to [0.f, 1.f]. The flip_vertically -# option is set to true to account for the descrepancy between the -# representation of the input image (origin at the bottom-left corner, the -# OpenGL convention) and what the model used in this graph is expecting (origin -# at the top-left corner). +# Converts the transformed input image on GPU into an image tensor stored as a +# TfLiteTensor. node { calculator: "TfLiteConverterCalculator" input_stream: "IMAGE_GPU:transformed_input_video" output_stream: "TENSORS_GPU:image_tensor" - node_options: { - [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] { - zero_center: true - flip_vertically: true - } - } } # Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a @@ -72,7 +62,7 @@ node { node { calculator: "TfLiteInferenceCalculator" input_stream: "TENSORS_GPU:image_tensor" - output_stream: "TENSORS_GPU:detection_tensors" + output_stream: "TENSORS:detection_tensors" node_options: { [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { model_path: "ssdlite_object_detection.tflite" @@ -115,7 +105,7 @@ node { # detections. Each detection describes a detected object. node { calculator: "TfLiteTensorsToDetectionsCalculator" - input_stream: "TENSORS_GPU:detection_tensors" + input_stream: "TENSORS:detection_tensors" input_side_packet: "ANCHORS:anchors" output_stream: "DETECTIONS:detections" node_options: { @@ -130,7 +120,7 @@ node { y_scale: 10.0 h_scale: 5.0 w_scale: 5.0 - flip_vertically: true + min_score_thresh: 0.6 } } } @@ -143,9 +133,9 @@ node { node_options: { [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] { min_suppression_threshold: 0.4 - min_score_threshold: 0.6 max_num_detections: 3 overlap_type: INTERSECTION_OVER_UNION + return_empty_detections: true } } } @@ -166,7 +156,7 @@ node { # Converts the detections to drawing primitives for annotation overlay. node { calculator: "DetectionsToRenderDataCalculator" - input_stream: "DETECTION_VECTOR:output_detections" + input_stream: "DETECTIONS:output_detections" output_stream: "RENDER_DATA:render_data" node_options: { [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { @@ -176,21 +166,10 @@ node { } } -# Draws annotations and overlays them on top of the original image coming into -# the graph. Annotation drawing is performed on CPU, and the result is -# transferred to GPU and overlaid on the input image. The calculator assumes -# that image origin is always at the top-left corner and renders text -# accordingly. However, the input image has its origin at the bottom-left corner -# (OpenGL convention) and the flip_text_vertically option is set to true to -# compensate that. +# Draws annotations and overlays them on top of the input images. 
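Both mobile graphs stop flipping inside the graph at this point: the zero_center and flip_vertically options are gone from TfLiteConverterCalculator and flip_text_vertically is gone from AnnotationOverlayCalculator. Vertical flipping is instead left to the app layer through the setFlipY() hooks this patch adds to ExternalTextureConverter and SurfaceOutput further down. A minimal Java sketch of the input side, assuming the app already holds a glutil EglManager and the camera SurfaceTexture (eglManager, previewFrameTexture, width and height are placeholders, not part of this patch):

    import android.graphics.SurfaceTexture;
    import com.google.mediapipe.components.ExternalTextureConverter;
    import com.google.mediapipe.glutil.EglManager;

    void attachCameraTexture(EglManager eglManager, SurfaceTexture previewFrameTexture,
                             int width, int height) {
      ExternalTextureConverter converter =
          new ExternalTextureConverter(eglManager.getContext());
      // Flip here instead of in the graph; per the new javadoc this must be
      // called before setSurfaceTexture().
      converter.setFlipY(true);
      converter.setSurfaceTexture(previewFrameTexture, width, height);
    }

Keeping the flip out of the .pbtxt presumably lets the same graph file serve both the Android and iOS examples it now lists in its header comment.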
node { calculator: "AnnotationOverlayCalculator" input_stream: "INPUT_FRAME_GPU:throttled_input_video" input_stream: "render_data" output_stream: "OUTPUT_FRAME_GPU:output_video" - node_options: { - [type.googleapis.com/mediapipe.AnnotationOverlayCalculatorOptions] { - flip_text_vertically: true - } - } } diff --git a/mediapipe/java/com/google/mediapipe/components/BUILD b/mediapipe/java/com/google/mediapipe/components/BUILD index d3d884ccf..d3a19cd5b 100644 --- a/mediapipe/java/com/google/mediapipe/components/BUILD +++ b/mediapipe/java/com/google/mediapipe/components/BUILD @@ -28,8 +28,9 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/glutil", "//third_party:androidx_appcompat", - "@androidsdk//com.android.support:recyclerview-v7-25.0.0", - "@androidsdk//com.android.support:support-v4-25.0.0", + "//third_party:androidx_core", + "//third_party:androidx_legacy_support_v4", + "//third_party:androidx_recyclerview", "@com_google_code_findbugs//jar", "@com_google_guava_android//jar", ], @@ -46,12 +47,23 @@ android_library( visibility = ["//visibility:public"], deps = [ "//third_party:androidx_appcompat", + "//third_party:androidx_legacy_support_v4", "//third_party:camera2", "//third_party:camerax_core", - "@androidsdk//com.android.support:support-v4-25.0.0", "@androidx_concurrent_futures//jar", "@androidx_lifecycle//jar", "@com_google_code_findbugs//jar", "@com_google_guava_android//jar", ], ) + +android_library( + name = "android_microphone_helper", + srcs = [ + "MicrophoneHelper.java", + ], + visibility = ["//visibility:public"], + deps = [ + "//third_party/java/jsr305_annotations", + ], +) diff --git a/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java b/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java index cc74e724d..3dfc9cb10 100644 --- a/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java +++ b/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java @@ -43,12 +43,10 @@ public class ExternalTextureConverter implements TextureFrameProducer { /** * Creates the ExternalTextureConverter to create a working copy of each camera frame. * - * @param numBuffers The number of camera frames that can enter processing simultaneously. + * @param numBuffers the number of camera frames that can enter processing simultaneously. */ public ExternalTextureConverter(EGLContext parentContext, int numBuffers) { thread = new RenderThread(parentContext, numBuffers); - // Give the thread a consistent name so it can be whitelisted for use in TikTok apps - // (go/tiktok-tattletale). thread.setName(THREAD_NAME); thread.start(); try { @@ -66,6 +64,16 @@ public class ExternalTextureConverter implements TextureFrameProducer { } } + /** + * Sets vertical flipping of the texture, useful for conversion between coordinate systems with + * top-left v.s. bottom-left origins. This should be called before {@link + * #setSurfaceTexture(SurfaceTexture, int, int)} or {@link + * #setSurfaceTextureAndAttachToGLContext(SurfaceTexture, int, int)}. 
+ */ + public void setFlipY(boolean flip) { + thread.setFlipY(flip); + } + public ExternalTextureConverter(EGLContext parentContext) { this(parentContext, DEFAULT_NUM_BUFFERS); } @@ -154,6 +162,10 @@ public class ExternalTextureConverter implements TextureFrameProducer { consumers = new ArrayList<>(); } + public void setFlipY(boolean flip) { + renderer.setFlipY(flip); + } + public void setSurfaceTexture(SurfaceTexture texture, int width, int height) { if (surfaceTexture != null) { surfaceTexture.setOnFrameAvailableListener(null); diff --git a/mediapipe/java/com/google/mediapipe/components/PermissionHelper.java b/mediapipe/java/com/google/mediapipe/components/PermissionHelper.java index 37097f202..976200988 100644 --- a/mediapipe/java/com/google/mediapipe/components/PermissionHelper.java +++ b/mediapipe/java/com/google/mediapipe/components/PermissionHelper.java @@ -17,8 +17,8 @@ package com.google.mediapipe.components; import android.Manifest; import android.app.Activity; import android.content.pm.PackageManager; -import android.support.v4.app.ActivityCompat; -import android.support.v4.content.ContextCompat; +import androidx.core.app.ActivityCompat; +import androidx.core.content.ContextCompat; import android.util.Log; /** Manages camera permission request and handling. */ @@ -31,7 +31,7 @@ public class PermissionHelper { private static final int REQUEST_CODE = 0; - private static boolean permissionsGranted(Activity context, String[] permissions) { + public static boolean permissionsGranted(Activity context, String[] permissions) { for (String permission : permissions) { int permissionStatus = ContextCompat.checkSelfPermission(context, permission); if (permissionStatus != PackageManager.PERMISSION_GRANTED) { @@ -41,10 +41,9 @@ public class PermissionHelper { return true; } - private static void checkAndRequestPermissions( - Activity context, String[] permissions, int permissionCode) { + public static void checkAndRequestPermissions(Activity context, String[] permissions) { if (!permissionsGranted(context, permissions)) { - ActivityCompat.requestPermissions(context, permissions, permissionCode); + ActivityCompat.requestPermissions(context, permissions, REQUEST_CODE); } } @@ -58,7 +57,7 @@ public class PermissionHelper { */ public static void checkAndRequestCameraPermissions(Activity context) { Log.d(TAG, "checkAndRequestCameraPermissions"); - checkAndRequestPermissions(context, new String[] {CAMERA_PERMISSION}, REQUEST_CODE); + checkAndRequestPermissions(context, new String[] {CAMERA_PERMISSION}); } /** Called by context to check if audio permissions have been granted. */ @@ -69,22 +68,7 @@ public class PermissionHelper { /** Called by context to check if audio permissions have been granted and if not, request them. */ public static void checkAndRequestAudioPermissions(Activity context) { Log.d(TAG, "checkAndRequestAudioPermissions"); - checkAndRequestPermissions(context, new String[] {AUDIO_PERMISSION}, REQUEST_CODE); - } - - /** Called by context to check if audio and camera permissions have been granted. */ - public static boolean audioCameraPermissionsGranted(Activity context) { - return permissionsGranted(context, new String[] {AUDIO_PERMISSION, CAMERA_PERMISSION}); - } - - /** - * Called by context to check if audio and camera permissions have been granted and if not, - * request them. 
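With the combined audio-plus-camera helpers removed here, permissionsGranted() and checkAndRequestPermissions() are now public, so a caller that needs both permissions passes the array itself and the fixed REQUEST_CODE is applied internally. A short sketch of the replacement call site, assuming activity is the caller's Activity (placeholder, not part of this patch):

    import android.Manifest;
    import android.app.Activity;
    import com.google.mediapipe.components.PermissionHelper;

    void requestAudioAndCamera(Activity activity) {
      String[] permissions = {
        Manifest.permission.RECORD_AUDIO, Manifest.permission.CAMERA
      };
      // Replaces the removed checkAndRequestAudioCameraPermissions(activity).
      PermissionHelper.checkAndRequestPermissions(activity, permissions);
      if (PermissionHelper.permissionsGranted(activity, permissions)) {
        // Both permissions already granted; safe to start capture.
      }
    }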
- */ - public static void checkAndRequestAudioCameraPermissions(Activity context) { - Log.d(TAG, "checkAndRequestAudioCameraPermissions"); - checkAndRequestPermissions( - context, new String[] {AUDIO_PERMISSION, CAMERA_PERMISSION}, REQUEST_CODE); + checkAndRequestPermissions(context, new String[] {AUDIO_PERMISSION}); } /** Called by context when permissions request has been completed. */ diff --git a/mediapipe/java/com/google/mediapipe/framework/AssetCache.java b/mediapipe/java/com/google/mediapipe/framework/AssetCache.java index 15ed9c3eb..1702f4d4b 100644 --- a/mediapipe/java/com/google/mediapipe/framework/AssetCache.java +++ b/mediapipe/java/com/google/mediapipe/framework/AssetCache.java @@ -17,7 +17,7 @@ package com.google.mediapipe.framework; import android.content.Context; import android.content.pm.PackageManager.NameNotFoundException; import android.content.res.AssetManager; -import android.support.annotation.VisibleForTesting; +import androidx.annotation.VisibleForTesting; import android.text.TextUtils; import com.google.common.base.Preconditions; import com.google.common.flogger.FluentLogger; diff --git a/mediapipe/java/com/google/mediapipe/framework/BUILD b/mediapipe/java/com/google/mediapipe/framework/BUILD index a49924207..e6ad76ed9 100644 --- a/mediapipe/java/com/google/mediapipe/framework/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/BUILD @@ -48,7 +48,8 @@ android_library( ], deps = [ ":android_core", - "@androidsdk//com.android.support:support-v4-25.0.0", + "//third_party:androidx_annotation", + "//third_party:androidx_legacy_support_v4", "@com_google_code_findbugs//jar", "@com_google_common_flogger//jar", "@com_google_common_flogger_system_backend//jar", @@ -74,6 +75,7 @@ android_library( deps = [ "//mediapipe/framework:calculator_java_proto_lite", "//mediapipe/framework:calculator_profile_java_proto_lite", + "//mediapipe/framework/tool:calculator_graph_template_java_proto_lite", "@com_google_code_findbugs//jar", "@com_google_common_flogger//jar", "@com_google_common_flogger_system_backend//jar", diff --git a/mediapipe/java/com/google/mediapipe/framework/Graph.java b/mediapipe/java/com/google/mediapipe/framework/Graph.java index 3edac48b7..0feabf386 100644 --- a/mediapipe/java/com/google/mediapipe/framework/Graph.java +++ b/mediapipe/java/com/google/mediapipe/framework/Graph.java @@ -16,7 +16,9 @@ package com.google.mediapipe.framework; import com.google.common.base.Preconditions; import com.google.common.flogger.FluentLogger; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig; +import com.google.mediapipe.proto.GraphTemplateProto.CalculatorGraphTemplate; import com.google.protobuf.InvalidProtocolBufferException; import java.util.ArrayList; import java.util.HashMap; @@ -101,13 +103,28 @@ public class Graph { nativeLoadBinaryGraphBytes(nativeGraphHandle, data); } - /** Loads a binary mediapipe graph from a CalculatorGraphConfig. */ + /** Specifies a CalculatorGraphConfig for a mediapipe graph or subgraph. */ public synchronized void loadBinaryGraph(CalculatorGraphConfig config) { loadBinaryGraph(config.toByteArray()); } + /** Specifies a CalculatorGraphTemplate for a mediapipe graph or subgraph. */ + public synchronized void loadBinaryGraphTemplate(CalculatorGraphTemplate template) { + nativeLoadBinaryGraphTemplate(nativeGraphHandle, template.toByteArray()); + } + + /** Specifies the CalculatorGraphConfig::type of the top level graph. 
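Together with loadBinaryGraphTemplate() above, the new setGraphType() and setGraphOptions() entry points let a caller register one or more configs or templates, then choose the top-level graph type and its template arguments before the run; getCalculatorGraphConfig() afterwards returns the expanded config. A minimal sketch, where templateBytes and the "MyTopLevelGraph" type name are placeholders, not from this patch:

    import com.google.mediapipe.framework.Graph;
    import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
    import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
    import com.google.mediapipe.proto.GraphTemplateProto.CalculatorGraphTemplate;
    import com.google.protobuf.InvalidProtocolBufferException;

    CalculatorGraphConfig expandTemplate(Graph graph, byte[] templateBytes)
        throws InvalidProtocolBufferException {
      graph.loadBinaryGraphTemplate(CalculatorGraphTemplate.parseFrom(templateBytes));
      // Select which CalculatorGraphConfig::type is the top-level graph.
      graph.setGraphType("MyTopLevelGraph");
      // Template arguments travel as CalculatorOptions.
      graph.setGraphOptions(CalculatorOptions.getDefaultInstance());
      // Canonicalized config with subgraphs and graph templates expanded.
      return graph.getCalculatorGraphConfig();
    }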
*/ + public synchronized void setGraphType(String graphType) { + nativeSetGraphType(nativeGraphHandle, graphType); + } + + /** Specifies options such as template arguments for the graph. */ + public synchronized void setGraphOptions(CalculatorOptions options) { + nativeSetGraphOptions(nativeGraphHandle, options.toByteArray()); + } + /** - * Returns the CalculatorGraphConfig if a graph is loaded. + * Returns the canonicalized CalculatorGraphConfig with subgraphs and graph templates expanded. */ public synchronized CalculatorGraphConfig getCalculatorGraphConfig() { Preconditions.checkState( @@ -594,6 +611,12 @@ public class Graph { private native void nativeLoadBinaryGraphBytes(long context, byte[] data); + private native void nativeLoadBinaryGraphTemplate(long context, byte[] data); + + private native void nativeSetGraphType(long context, String graphType); + + private native void nativeSetGraphOptions(long context, byte[] data); + private native byte[] nativeGetCalculatorGraphConfig(long context); private native void nativeRunGraphUntilClose(long context, String[] streamNames, long[] packets); diff --git a/mediapipe/java/com/google/mediapipe/framework/SurfaceOutput.java b/mediapipe/java/com/google/mediapipe/framework/SurfaceOutput.java index e2edb4c82..454ff2c2c 100644 --- a/mediapipe/java/com/google/mediapipe/framework/SurfaceOutput.java +++ b/mediapipe/java/com/google/mediapipe/framework/SurfaceOutput.java @@ -30,6 +30,15 @@ public class SurfaceOutput { surfaceHolderPacket = holderPacket; } + /** + * Sets vertical flipping of the output surface, useful for conversion between coordinate systems + * with top-left v.s. bottom-left origins. This should be called before {@link + * #setSurface(Object)} or {@link #setEglSurface(long)}. + */ + public void setFlipY(boolean flip) { + nativeSetFlipY(surfaceHolderPacket.getNativeHandle(), flip); + } + /** * Connects an Android {@link Surface} to an output. 
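On the rendering side, SurfaceOutput gains the matching setFlipY() hook above. A short sketch, assuming the existing Graph#addSurfaceOutput helper and an app-provided display Surface (graph and displaySurface are placeholders, not part of this patch):

    // "output_video" matches the output stream of the mobile graphs in this patch.
    SurfaceOutput videoOutput = graph.addSurfaceOutput("output_video");
    // Per the javadoc, call before setSurface() or setEglSurface().
    videoOutput.setFlipY(true);
    videoOutput.setSurface(displaySurface);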
* @@ -61,6 +70,8 @@ public class SurfaceOutput { mediapipeGraph.getNativeHandle(), surfaceHolderPacket.getNativeHandle(), nativeEglSurface); } + private native void nativeSetFlipY(long nativePacket, boolean flip); + private native void nativeSetSurface( long nativeContext, long nativePacket, Object surface); private native void nativeSetEglSurface( diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD index b101e2ad6..a02cd2e33 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD @@ -88,6 +88,7 @@ cc_library( ":jni_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_profile_cc_proto", + "//mediapipe/framework/tool:calculator_graph_template_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:matrix_data_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto", diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc index d78ac095d..e26123c1c 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc @@ -114,9 +114,7 @@ class CallbackHandler { } // namespace internal Graph::Graph() - : graph_loaded_(false), - executor_stack_size_increased_(false), - global_java_packet_cls_(nullptr) {} + : executor_stack_size_increased_(false), global_java_packet_cls_(nullptr) {} Graph::~Graph() { if (running_graph_) { @@ -175,13 +173,14 @@ void Graph::EnsureMinimumExecutorStackSizeForJava() {} ::mediapipe::Status Graph::AddCallbackHandler(std::string output_stream_name, jobject java_callback) { - if (!graph_loaded_) { + if (!graph_config()) { return ::mediapipe::InternalError("Graph is not loaded!"); } std::unique_ptr handler( new internal::CallbackHandler(this, java_callback)); std::string side_packet_name; - tool::AddCallbackCalculator(output_stream_name, &graph_, &side_packet_name, + tool::AddCallbackCalculator(output_stream_name, graph_config(), + &side_packet_name, /* use_std_function = */ true); EnsureMinimumExecutorStackSizeForJava(); side_packets_callbacks_.emplace( @@ -193,14 +192,14 @@ void Graph::EnsureMinimumExecutorStackSizeForJava() {} ::mediapipe::Status Graph::AddCallbackWithHeaderHandler( std::string output_stream_name, jobject java_callback) { - if (!graph_loaded_) { + if (!graph_config()) { return ::mediapipe::InternalError("Graph is not loaded!"); } std::unique_ptr handler( new internal::CallbackHandler(this, java_callback)); std::string side_packet_name; tool::AddCallbackWithHeaderCalculator(output_stream_name, output_stream_name, - &graph_, &side_packet_name, + graph_config(), &side_packet_name, /* use_std_function = */ true); EnsureMinimumExecutorStackSizeForJava(); side_packets_callbacks_.emplace( @@ -212,7 +211,7 @@ void Graph::EnsureMinimumExecutorStackSizeForJava() {} } int64_t Graph::AddSurfaceOutput(const std::string& output_stream_name) { - if (!graph_loaded_) { + if (!graph_config()) { LOG(ERROR) << "Graph is not loaded!"; return 0; } @@ -220,9 +219,9 @@ int64_t Graph::AddSurfaceOutput(const std::string& output_stream_name) { #ifdef MEDIAPIPE_DISABLE_GPU LOG(FATAL) << "GPU support has been disabled in this build!"; #else - CalculatorGraphConfig::Node* sink_node = graph_.add_node(); + CalculatorGraphConfig::Node* sink_node = graph_config()->add_node(); 
sink_node->set_name(::mediapipe::tool::GetUnusedNodeName( - graph_, absl::StrCat("egl_surface_sink_", output_stream_name))); + *graph_config(), absl::StrCat("egl_surface_sink_", output_stream_name))); sink_node->set_calculator("GlSurfaceSinkCalculator"); sink_node->add_input_stream(output_stream_name); sink_node->add_input_side_packet( @@ -230,7 +229,7 @@ int64_t Graph::AddSurfaceOutput(const std::string& output_stream_name) { const std::string input_side_packet_name = ::mediapipe::tool::GetUnusedSidePacketName( - graph_, absl::StrCat(output_stream_name, "_surface")); + *graph_config(), absl::StrCat(output_stream_name, "_surface")); sink_node->add_input_side_packet( absl::StrCat("SURFACE:", input_side_packet_name)); @@ -249,24 +248,47 @@ int64_t Graph::AddSurfaceOutput(const std::string& output_stream_name) { if (!status.ok()) { return status; } - if (!graph_.ParseFromString(graph_config_string)) { - return ::mediapipe::InvalidArgumentError( - absl::StrCat("Failed to parse the graph: ", path_to_graph)); - } - graph_loaded_ = true; - return ::mediapipe::OkStatus(); + return LoadBinaryGraph(graph_config_string.c_str(), + graph_config_string.length()); } ::mediapipe::Status Graph::LoadBinaryGraph(const char* data, int size) { - if (!graph_.ParseFromArray(data, size)) { + CalculatorGraphConfig graph_config; + if (!graph_config.ParseFromArray(data, size)) { return ::mediapipe::InvalidArgumentError("Failed to parse the graph"); } - graph_loaded_ = true; + graph_configs_.push_back(graph_config); return ::mediapipe::OkStatus(); } -const CalculatorGraphConfig& Graph::GetCalculatorGraphConfig() { - return graph_; +::mediapipe::Status Graph::LoadBinaryGraphTemplate(const char* data, int size) { + CalculatorGraphTemplate graph_template; + if (!graph_template.ParseFromArray(data, size)) { + return ::mediapipe::InvalidArgumentError("Failed to parse the graph"); + } + graph_templates_.push_back(graph_template); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status Graph::SetGraphType(std::string graph_type) { + graph_type_ = graph_type; + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status Graph::SetGraphOptions(const char* data, int size) { + if (!graph_options_.ParseFromArray(data, size)) { + return ::mediapipe::InvalidArgumentError("Failed to parse the graph"); + } + return ::mediapipe::OkStatus(); +} + +CalculatorGraphConfig Graph::GetCalculatorGraphConfig() { + CalculatorGraph temp_graph; + ::mediapipe::Status status = InitializeGraph(&temp_graph); + if (!status.ok()) { + LOG(ERROR) << "GetCalculatorGraphConfig failed:\n" << status.message(); + } + return temp_graph.Config(); } void Graph::CallbackToJava(JNIEnv* env, jobject java_callback_obj, @@ -332,10 +354,15 @@ void Graph::SetPacketJavaClass(JNIEnv* env) { SetPacketJavaClass(env); // Running as a synchronized mode, the same Java thread is available through // out the run. - CalculatorGraph calculator_graph(graph_); + CalculatorGraph calculator_graph; + ::mediapipe::Status status = InitializeGraph(&calculator_graph); + if (!status.ok()) { + LOG(ERROR) << status.message(); + running_graph_.reset(nullptr); + return status; + } // TODO: gpu & services set up! - ::mediapipe::Status status = - calculator_graph.Run(CreateCombinedSidePackets()); + status = calculator_graph.Run(CreateCombinedSidePackets()); LOG(INFO) << "Graph run finished."; return status; @@ -354,8 +381,8 @@ void Graph::SetPacketJavaClass(JNIEnv* env) { // Set the mode for adding packets to graph input streams. 
running_graph_->SetGraphInputStreamAddMode(graph_input_stream_add_mode_); if (VLOG_IS_ON(2)) { - LOG(INFO) << "input side packet streams:"; - for (auto& name : graph_.input_stream()) { + LOG(INFO) << "input packet streams:"; + for (auto& name : graph_config()->input_stream()) { LOG(INFO) << name; } } @@ -379,7 +406,7 @@ void Graph::SetPacketJavaClass(JNIEnv* env) { } } - status = running_graph_->Initialize(graph_); + status = InitializeGraph(running_graph_.get()); if (!status.ok()) { LOG(ERROR) << status.message(); running_graph_.reset(nullptr); @@ -529,5 +556,45 @@ ProfilingContext* Graph::GetProfilingContext() { return nullptr; } +CalculatorGraphConfig* Graph::graph_config() { + // Return the last specified graph config with the required graph_type. + for (auto it = graph_configs_.rbegin(); it != graph_configs_.rend(); ++it) { + if (it->type() == graph_type()) { + return &*it; + } + } + for (auto it = graph_templates_.rbegin(); it != graph_templates_.rend(); + ++it) { + if (it->mutable_config()->type() == graph_type()) { + return it->mutable_config(); + } + } + return nullptr; +} + +std::string Graph::graph_type() { + // If a graph-type is specified, that type is used. Otherwise the + // graph-type of the last specified graph config is used. + if (graph_type_ != "") { + return graph_type_; + } + if (!graph_configs_.empty()) { + return graph_configs_.back().type(); + } + if (!graph_templates_.empty()) { + return graph_templates_.back().config().type(); + } + return ""; +} + +::mediapipe::Status Graph::InitializeGraph(CalculatorGraph* graph) { + if (graph_configs_.size() == 1 && graph_templates_.empty()) { + return graph->Initialize(*graph_config()); + } else { + return graph->Initialize(graph_configs_, graph_templates_, {}, graph_type(), + &graph_options_); + } +} + } // namespace android } // namespace mediapipe diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph.h b/mediapipe/java/com/google/mediapipe/framework/jni/graph.h index 183179c24..39bd91446 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph.h @@ -64,8 +64,15 @@ class Graph { ::mediapipe::Status LoadBinaryGraph(std::string path_to_graph); // Loads a binary graph from a buffer. ::mediapipe::Status LoadBinaryGraph(const char* data, int size); - // Gets the calculator graph config. - const CalculatorGraphConfig& GetCalculatorGraphConfig(); + // Loads a binary graph template from a buffer. + ::mediapipe::Status LoadBinaryGraphTemplate(const char* data, int size); + // Specifies the CalculatorGraphConfig::type of the top level graph. + ::mediapipe::Status SetGraphType(std::string graph_type); + // Specifies options such as template arguments for the graph. + ::mediapipe::Status SetGraphOptions(const char* data, int size); + + // Returns the expanded calculator graph config. + CalculatorGraphConfig GetCalculatorGraphConfig(); // Runs the graph until it closes. // Mainly is used for writing tests. @@ -170,9 +177,24 @@ class Graph { void EnsureMinimumExecutorStackSizeForJava(); void SetPacketJavaClass(JNIEnv* env); std::map CreateCombinedSidePackets(); + // Returns the top-level CalculatorGraphConfig, or nullptr if the top-level + // CalculatorGraphConfig is not yet defined. + CalculatorGraphConfig* graph_config(); + // Returns the top-level CalculatorGraphConfig::type, or "" if the top-level + // CalculatorGraphConfig::type is not yet defined. 
+ std::string graph_type(); + // Initializes CalculatorGraph |graph| using the loaded graph-configs. + ::mediapipe::Status InitializeGraph(CalculatorGraph* graph); + + // CalculatorGraphConfigs for the calculator graph and subgraphs. + std::vector graph_configs_; + // CalculatorGraphTemplates for the calculator graph and subgraphs. + std::vector graph_templates_; + // Options such as template arguments for the top-level calculator graph. + CalculatorOptions graph_options_; + // The CalculatorGraphConfig::type of the top-level calculator graph. + std::string graph_type_ = ""; - CalculatorGraphConfig graph_; - bool graph_loaded_; // Used by EnsureMinimumExecutorStackSizeForJava() to ensure that the // default executor's stack size is increased only once. bool executor_stack_size_increased_; diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc index 49a93b8eb..d968ff5d0 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc @@ -133,6 +133,45 @@ JNIEXPORT void JNICALL GRAPH_METHOD(nativeLoadBinaryGraphBytes)( ThrowIfError(env, status); } +JNIEXPORT void JNICALL GRAPH_METHOD(nativeLoadBinaryGraphTemplate)( + JNIEnv* env, jobject thiz, jlong context, jbyteArray data) { + mediapipe::android::Graph* mediapipe_graph = + reinterpret_cast(context); + jbyte* data_ptr = env->GetByteArrayElements(data, nullptr); + int size = env->GetArrayLength(data); + mediapipe::Status status = mediapipe_graph->LoadBinaryGraphTemplate( + reinterpret_cast(data_ptr), size); + env->ReleaseByteArrayElements(data, data_ptr, JNI_ABORT); + ThrowIfError(env, status); +} + +JNIEXPORT void JNICALL GRAPH_METHOD(nativeSetGraphType)(JNIEnv* env, + jobject thiz, + jlong context, + jstring graph_type) { + mediapipe::android::Graph* mediapipe_graph = + reinterpret_cast(context); + const char* graph_type_ref = env->GetStringUTFChars(graph_type, nullptr); + // Make a copy of the std::string and release the jni reference. 
+ std::string graph_type_string(graph_type_ref); + env->ReleaseStringUTFChars(graph_type, graph_type_ref); + ThrowIfError(env, mediapipe_graph->SetGraphType(graph_type_string)); +} + +JNIEXPORT void JNICALL GRAPH_METHOD(nativeSetGraphOptions)(JNIEnv* env, + jobject thiz, + jlong context, + jbyteArray data) { + mediapipe::android::Graph* mediapipe_graph = + reinterpret_cast(context); + jbyte* data_ptr = env->GetByteArrayElements(data, nullptr); + int size = env->GetArrayLength(data); + mediapipe::Status status = + mediapipe_graph->SetGraphOptions(reinterpret_cast(data_ptr), size); + env->ReleaseByteArrayElements(data, data_ptr, JNI_ABORT); + ThrowIfError(env, status); +} + JNIEXPORT jbyteArray JNICALL GRAPH_METHOD(nativeGetCalculatorGraphConfig)( JNIEnv* env, jobject thiz, jlong context) { mediapipe::android::Graph* mediapipe_graph = diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.h index 508213081..e08e36f4d 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.h @@ -42,6 +42,19 @@ JNIEXPORT void JNICALL GRAPH_METHOD(nativeLoadBinaryGraph)(JNIEnv* env, JNIEXPORT void JNICALL GRAPH_METHOD(nativeLoadBinaryGraphBytes)( JNIEnv* env, jobject thiz, jlong context, jbyteArray data); +JNIEXPORT void JNICALL GRAPH_METHOD(nativeLoadBinaryGraphTemplate)( + JNIEnv* env, jobject thiz, jlong context, jbyteArray data); + +JNIEXPORT void JNICALL GRAPH_METHOD(nativeSetGraphType)(JNIEnv* env, + jobject thiz, + jlong context, + jstring graph_type); + +JNIEXPORT void JNICALL GRAPH_METHOD(nativeSetGraphOptions)(JNIEnv* env, + jobject thiz, + jlong context, + jbyteArray data); + JNIEXPORT jbyteArray JNICALL GRAPH_METHOD(nativeGetCalculatorGraphConfig)( JNIEnv* env, jobject thiz, jlong context); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/surface_output_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/surface_output_jni.cc index 827f62d6f..29646a474 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/surface_output_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/surface_output_jni.cc @@ -41,6 +41,12 @@ mediapipe::GlContext* GetGlContext(jlong context) { } } // namespace +JNIEXPORT void JNICALL MEDIAPIPE_SURFACE_OUTPUT_METHOD(nativeSetFlipY)( + JNIEnv* env, jobject thiz, jlong packet, jboolean flip) { + mediapipe::EglSurfaceHolder* surface_holder = GetSurfaceHolder(packet); + surface_holder->flip_y = flip; +} + JNIEXPORT void JNICALL MEDIAPIPE_SURFACE_OUTPUT_METHOD(nativeSetSurface)( JNIEnv* env, jobject thiz, jlong context, jlong packet, jobject surface) { #ifdef __ANDROID__ diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/surface_output_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/surface_output_jni.h index bcc56f573..d3c59f921 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/surface_output_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/surface_output_jni.h @@ -24,6 +24,9 @@ extern "C" { #define MEDIAPIPE_SURFACE_OUTPUT_METHOD(METHOD_NAME) \ Java_com_google_mediapipe_framework_SurfaceOutput_##METHOD_NAME +JNIEXPORT void JNICALL MEDIAPIPE_SURFACE_OUTPUT_METHOD(nativeSetFlipY)( + JNIEnv* env, jobject thiz, jlong packet, jboolean flip); + #ifdef __ANDROID__ JNIEXPORT void JNICALL MEDIAPIPE_SURFACE_OUTPUT_METHOD(nativeSetSurface)( JNIEnv* env, jobject thiz, jlong context, jlong packet, jobject surface); diff --git 
a/mediapipe/java/com/google/mediapipe/framework/proguard.pgcfg b/mediapipe/java/com/google/mediapipe/framework/proguard.pgcfg index 808371afa..699d36eee 100644 --- a/mediapipe/java/com/google/mediapipe/framework/proguard.pgcfg +++ b/mediapipe/java/com/google/mediapipe/framework/proguard.pgcfg @@ -11,6 +11,8 @@ # This method is invoked by native code. -keep public class com.google.mediapipe.framework.Packet { public static *** create(***); + public long getNativeHandle(); + public void release(); } # This method is invoked by native code. diff --git a/mediapipe/java/com/google/mediapipe/glutil/ExternalTextureRenderer.java b/mediapipe/java/com/google/mediapipe/glutil/ExternalTextureRenderer.java index 17d830244..e03bf409d 100644 --- a/mediapipe/java/com/google/mediapipe/glutil/ExternalTextureRenderer.java +++ b/mediapipe/java/com/google/mediapipe/glutil/ExternalTextureRenderer.java @@ -37,6 +37,14 @@ public class ExternalTextureRenderer { 1.0f, 1.0f // top right ); + private static final FloatBuffer FLIPPED_TEXTURE_VERTICES = + ShaderUtil.floatBuffer( + 0.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + 0.0f, 0.0f, // bottom left + 1.0f, 0.0f // bottom right + ); + private static final String TAG = "ExternalTextureRend"; // Max length of a tag is 23. private static final int ATTRIB_POSITION = 1; private static final int ATTRIB_TEXTURE_COORDINATE = 2; @@ -45,6 +53,7 @@ public class ExternalTextureRenderer { private int frameUniform; private int textureTransformUniform; private float[] textureTransformMatrix = new float[16]; + private boolean flipY; /** Call this to setup the shader program before rendering. */ public void setup() { @@ -62,11 +71,20 @@ public class ExternalTextureRenderer { } /** - * Renders the surfaceTexture to the framebuffer. + * Flips rendering output vertically, useful for conversion between coordinate systems with + * top-left v.s. bottom-left origins. Effective in subsequent {@link #render(SurfaceTexture)} + * calls. + */ + public void setFlipY(boolean flip) { + flipY = flip; + } + + /** + * Renders the surfaceTexture to the framebuffer with optional vertical flip. * *

<p>Before calling this, {@link #setup} must have been called. * - * NOTE: Calls {@link SurfaceTexture#updateTexImage()} on passed surface texture. + * <p>
NOTE: Calls {@link SurfaceTexture#updateTexImage()} on passed surface texture. */ public void render(SurfaceTexture surfaceTexture) { GLES20.glClear(GLES20.GL_COLOR_BUFFER_BIT); @@ -97,7 +115,12 @@ public class ExternalTextureRenderer { GLES20.glEnableVertexAttribArray(ATTRIB_TEXTURE_COORDINATE); GLES20.glVertexAttribPointer( - ATTRIB_TEXTURE_COORDINATE, 2, GLES20.GL_FLOAT, false, 0, TEXTURE_VERTICES); + ATTRIB_TEXTURE_COORDINATE, + 2, + GLES20.GL_FLOAT, + false, + 0, + flipY ? FLIPPED_TEXTURE_VERTICES : TEXTURE_VERTICES); ShaderUtil.checkGlError("program setup"); GLES20.glDrawArrays(GLES20.GL_TRIANGLE_STRIP, 0, 4); diff --git a/mediapipe/models/README.md b/mediapipe/models/README.md index f307c5178..80aca046a 100644 --- a/mediapipe/models/README.md +++ b/mediapipe/models/README.md @@ -10,7 +10,7 @@ For details on the models, see [here](object_detection_saved_model/README.md). * [Object Detection on CPU on Android](../docs/object_detection_android_cpu.md) ### BlazeFace face detection model - * [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/facedetector_front.tflite) + * [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/face_detection_front.tflite) * Paper: ["BlazeFace: Sub-millisecond Neural Face Detection on Mobile GPUs"](https://sites.google.com/corp/view/perception-cv4arvr/blazeface) * Model card: [BlazeFace model card](https://sites.google.com/corp/view/perception-cv4arvr/blazeface#h.p_21ojPZDx3cqq) diff --git a/mediapipe/models/facedetector_front.tflite b/mediapipe/models/face_detection_front.tflite similarity index 100% rename from mediapipe/models/facedetector_front.tflite rename to mediapipe/models/face_detection_front.tflite diff --git a/mediapipe/models/facedetector_front_labelmap.txt b/mediapipe/models/face_detection_front_labelmap.txt similarity index 100% rename from mediapipe/models/facedetector_front_labelmap.txt rename to mediapipe/models/face_detection_front_labelmap.txt diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD new file mode 100644 index 000000000..70258687d --- /dev/null +++ b/mediapipe/objc/BUILD @@ -0,0 +1,252 @@ +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "CFHolder", + # Header is excluded on non-ios so you can still build :all. 
+ hdrs = select({ + "//mediapipe:apple": ["CFHolder.h"], + "//conditions:default": [], + }), + visibility = ["//mediapipe/framework:mediapipe_internal"], +) + +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + visibility = ["//mediapipe/framework:mediapipe_internal"], + deps = [ + ":CFHolder", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/port:status", + ], +) + +objc_library( + name = "Weakify", + hdrs = ["Weakify.h"], + visibility = ["//mediapipe/framework:mediapipe_internal"], +) + +MEDIAPIPE_IOS_SRCS = [ + "MPPGraph.mm", + "MPPTimestampConverter.mm", + "NSError+util_status.mm", +] + +MEDIAPIPE_IOS_HDRS = [ + "MPPGraph.h", + "MPPTimestampConverter.h", + "NSError+util_status.h", +] + +MEDIAPIPE_IOS_CC_DEPS = [ + ":CFHolder", + ":util", + "//mediapipe/framework/port:map_util", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:source_location", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + "//mediapipe/framework/port:threadpool", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:mediapipe_profiling", + "//mediapipe/gpu:MPPGraphGPUData", + "//mediapipe/gpu:pixel_buffer_pool_util", + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:gl_base", + "//mediapipe/gpu:gpu_shared_data_internal", + "//mediapipe/gpu:graph_support", + # Other deps + "//mediapipe/util:cpu_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", +] + +objc_library( + name = "mediapipe_framework_ios", + srcs = MEDIAPIPE_IOS_SRCS, + hdrs = MEDIAPIPE_IOS_HDRS, + copts = [ + "-Wno-shorten-64-to-32", + ], + sdk_frameworks = [ + # Needed for OpenCV. + "Accelerate", + ], + visibility = ["//mediapipe/framework:mediapipe_internal"], + deps = MEDIAPIPE_IOS_CC_DEPS + [ + # These are objc_library deps. 
+ "@google_toolbox_for_mac//:GTM_Defines", + ], +) + +objc_library( + name = "mediapipe_input_sources_ios", + srcs = [ + "MPPCameraInputSource.m", + "MPPDisplayLinkWeakTarget.m", + "MPPInputSource.m", + "MPPPlayerInputSource.m", + ], + hdrs = [ + "MPPCameraInputSource.h", + "MPPDisplayLinkWeakTarget.h", + "MPPInputSource.h", + "MPPPlayerInputSource.h", + ], + visibility = ["//mediapipe/framework:mediapipe_internal"], +) + +objc_library( + name = "mediapipe_gl_view_renderer", + srcs = [ + "MPPGLViewRenderer.mm", + ], + hdrs = [ + "MPPGLViewRenderer.h", + ], + copts = [ + "-Wno-shorten-64-to-32", + ], + visibility = ["//mediapipe/framework:mediapipe_internal"], + deps = [ + ":mediapipe_framework_ios", + "//mediapipe/gpu:gl_calculator_helper_ios", + "//mediapipe/gpu:gl_quad_renderer", + "//mediapipe/gpu:gl_simple_shaders", + ], +) + +objc_library( + name = "mediapipe_layer_renderer", + srcs = [ + "MPPLayerRenderer.m", + ], + hdrs = [ + "MPPLayerRenderer.h", + ], + copts = [ + "-Wno-shorten-64-to-32", + ], + visibility = ["//mediapipe/framework:mediapipe_internal"], + deps = [ + ":mediapipe_framework_ios", + ":mediapipe_gl_view_renderer", + "//mediapipe/gpu:gl_calculator_helper_ios", + ], +) + +objc_library( + name = "CGImageRefUtils", + srcs = [ + "CGImageRefUtils.mm", + ], + hdrs = [ + "CGImageRefUtils.h", + ], + copts = [ + "-Wno-shorten-64-to-32", + ], + sdk_frameworks = [ + "CoreVideo", + ], + visibility = ["//mediapipe/framework:mediapipe_internal"], + deps = [ + ":mediapipe_framework_ios", + "@com_google_absl//absl/strings", + ], +) + +objc_library( + name = "MPPGraphTestBase", + testonly = 1, + srcs = [ + "MPPGraphTestBase.mm", + ], + hdrs = [ + "MPPGraphTestBase.h", + ], + copts = [ + "-Wno-shorten-64-to-32", + ], + sdk_frameworks = [ + "Accelerate", + "AVFoundation", + "CoreVideo", + "CoreGraphics", + "CoreMedia", + "GLKit", + "OpenGLES", + "QuartzCore", + "UIKit", + ], + visibility = ["//mediapipe/framework:mediapipe_internal"], + deps = [ + ":CGImageRefUtils", + ":Weakify", + ":mediapipe_framework_ios", + "//mediapipe/framework:calculator_framework", + ], +) + +objc_library( + name = "mediapipe_framework_ios_testLib", + testonly = 1, + srcs = [ + "CFHolderTests.mm", + "MPPGraphTests.mm", + ], + copts = [ + "-Wno-shorten-64-to-32", + ], + data = [ + "testdata/googlelogo_color_272x92dp.png", + ], + sdk_frameworks = [ + "Accelerate", + "AVFoundation", + "CoreVideo", + "CoreGraphics", + "CoreMedia", + "GLKit", + "QuartzCore", + "UIKit", + ], + visibility = ["//mediapipe/framework:mediapipe_internal"], + deps = [ + ":CGImageRefUtils", + ":MPPGraphTestBase", + ":Weakify", + ":mediapipe_framework_ios", + "//mediapipe/calculators/core:pass_through_calculator", + ], +) + +load("//mediapipe/framework/tool:mediapipe_graph.bzl", "mediapipe_binary_graph") + +[ + mediapipe_binary_graph( + name = graph.split("/")[-1].rsplit(".", 1)[0] + "_graph", + graph = graph, + output_name = "%s.binarypb" % graph.split("/")[-1].rsplit(".", 1)[0], + visibility = ["//mediapipe/framework:mediapipe_internal"], + ) + for graph in glob(["testdata/*.pbtxt"]) +] + +exports_files( + [ + "testdata/googlelogo_color_272x92dp.png", + "testdata/googlelogo_color_272x92dp_luminance.png", + ], + visibility = [ + "//mediapipe/feature_extraction/video/video_effects:__pkg__", + "//mediapipe/gpu:__pkg__", + ], +) diff --git a/mediapipe/objc/CFHolder.h b/mediapipe/objc/CFHolder.h new file mode 100644 index 000000000..b16e044e3 --- /dev/null +++ b/mediapipe/objc/CFHolder.h @@ -0,0 +1,109 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_OBJC_CFHOLDER_H_ +#define MEDIAPIPE_OBJC_CFHOLDER_H_ + +#import + +/// Manages ownership of a CoreFoundation type (any type that can be passed +/// to CFRetain/CFRelease). +template +class CFHolder { + public: + /// Default constructor gives a NULL ref. + CFHolder() : _object(NULL) {} + + /// Constructor with the basic ref type. Retains it. + explicit CFHolder(T object) : _object(RetainIfNotNull(object)) {} + + /// Copy constructor. + CFHolder(const CFHolder& other) : _object(RetainIfNotNull(*other)) {} + + /// Move constructor. + CFHolder(CFHolder&& other) : _object(*other) { other._object = NULL; } + + /// Destructor releases the held object. + ~CFHolder() { ReleaseIfNotNull(_object); } + + /// Dereference to access the held object. + T operator*() const { return _object; } + + /// Assigning from another CFHolder adds a reference. + CFHolder& operator=(const CFHolder& other) { return reset(*other); } + + /// Move assignment does not add a reference. + CFHolder& operator=(CFHolder&& other) { + // C++11 allows its library implementation to assume that rvalue reference + // arguments are not aliased. See 17.6.4.9 in the standard document. + ReleaseIfNotNull(_object); + _object = other._object; + other._object = NULL; + return *this; + } + + /// Equality and inequality operators. + bool operator==(const CFHolder& other) const { + return _object == other._object; + } + bool operator!=(const CFHolder& other) const { return !operator==(other); } + bool operator==(T object) const { return _object == object; } + bool operator!=(T object) const { return !operator==(object); } + + /// Sets the managed object. + CFHolder& reset(T object) { + T old = _object; + _object = RetainIfNotNull(object); + ReleaseIfNotNull(old); + return *this; + } + + /// Takes ownership of the object. Does not retain. + CFHolder& adopt(T object) { + ReleaseIfNotNull(_object); + _object = object; + return *this; + } + + private: + static inline T RetainIfNotNull(T object) { + if (object) CFRetain(object); + return object; + } + + static inline void ReleaseIfNotNull(T object) { + if (object) CFRelease(object); + } + T _object; +}; + +/// Using these functions allows template argument deduction (i.e. you do not +/// need to specify the type of object the holder holds, it is inferred from +/// the argument. +template +CFHolder* NewCFHolder(T object) { + return new CFHolder(object); +} + +template +CFHolder MakeCFHolder(T object) { + return CFHolder(object); +} + +template +CFHolder MakeCFHolderAdopting(T object) { + return CFHolder().adopt(object); +} + +#endif // MEDIAPIPE_OBJC_CFHOLDER_H_ diff --git a/mediapipe/objc/CFHolderTests.mm b/mediapipe/objc/CFHolderTests.mm new file mode 100644 index 000000000..a952c3eaf --- /dev/null +++ b/mediapipe/objc/CFHolderTests.mm @@ -0,0 +1,175 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import + +#import "mediapipe/objc/CFHolder.h" + +#include + +@interface CFHolderTests : XCTestCase { + UInt8 _bytes[4]; + CFDataRef _data; +} +@end + +@implementation CFHolderTests + +- (void)setUp { + _data = CFDataCreate(NULL, _bytes, sizeof(_bytes)); +} + +- (void)tearDown { + CFRelease(_data); +} + +- (void)testCreateAndDestroy { + XCTAssertEqual(CFGetRetainCount(_data), 1); + { + CFHolder holder(_data); + XCTAssertEqual(CFGetRetainCount(_data), 2); + } + XCTAssertEqual(CFGetRetainCount(_data), 1); +} + +- (void)testDereference { + CFHolder holder(_data); + XCTAssertEqual(*holder, _data); +} + +- (void)testCopy { + XCTAssertEqual(CFGetRetainCount(_data), 1); + { + CFHolder holder(_data); + XCTAssertEqual(CFGetRetainCount(_data), 2); + { + CFHolder holder2(holder); + XCTAssertEqual(CFGetRetainCount(_data), 3); + { + CFHolder holder3 = holder; + XCTAssertEqual(CFGetRetainCount(_data), 4); + } + XCTAssertEqual(CFGetRetainCount(_data), 3); + } + XCTAssertEqual(CFGetRetainCount(_data), 2); + } + XCTAssertEqual(CFGetRetainCount(_data), 1); +} + +- (void)testOverwriteWithNull { + XCTAssertEqual(CFGetRetainCount(_data), 1); + { + // Copy assignment. + CFHolder holder(_data); + XCTAssertEqual(CFGetRetainCount(_data), 2); + CFHolder holder2; + holder = holder2; + XCTAssertEqual(CFGetRetainCount(_data), 1); + } + { + // Move assignment. + CFHolder holder(_data); + XCTAssertEqual(CFGetRetainCount(_data), 2); + holder = CFHolder(); + XCTAssertEqual(CFGetRetainCount(_data), 1); + } +} + +- (void)testOverwriteWithOther { + CFDataRef data2 = CFDataCreate(NULL, _bytes, sizeof(_bytes)); + XCTAssertEqual(CFGetRetainCount(_data), 1); + XCTAssertEqual(CFGetRetainCount(data2), 1); + { + // Copy assignment. + CFHolder holder(_data); + XCTAssertEqual(CFGetRetainCount(_data), 2); + CFHolder holder2(data2); + XCTAssertEqual(CFGetRetainCount(data2), 2); + holder = holder2; + XCTAssertEqual(CFGetRetainCount(_data), 1); + XCTAssertEqual(CFGetRetainCount(data2), 3); + } + { + // Move assignment. + CFHolder holder(_data); + XCTAssertEqual(CFGetRetainCount(_data), 2); + holder = CFHolder(data2); + XCTAssertEqual(CFGetRetainCount(_data), 1); + XCTAssertEqual(CFGetRetainCount(data2), 2); + } + CFRelease(data2); +} + +- (void)testCompare { + CFDataRef data2 = CFDataCreate(NULL, _bytes, sizeof(_bytes)); + CFHolder holder(_data); + CFHolder holdersame(_data); + CFHolder holderother(data2); + CFHolder empty; + // Compare with other holder. + XCTAssertEqual(holder, holder); + XCTAssertEqual(holder, holdersame); + XCTAssertNotEqual(holder, holderother); + XCTAssertNotEqual(holder, empty); + // Compare with held type. 
+ XCTAssertEqual(holder, _data); + XCTAssertNotEqual(holder, data2); + XCTAssertNotEqual(holder, nil); + XCTAssertEqual(empty, nil); + XCTAssertNotEqual(empty, _data); + + CFRelease(data2); +} + +- (void)testReset { + XCTAssertEqual(CFGetRetainCount(_data), 1); + { + CFHolder holder; + holder.reset(_data); + XCTAssertEqual(CFGetRetainCount(_data), 2); + CFDataRef data2 = CFDataCreate(NULL, _bytes, sizeof(_bytes)); + holder.reset(data2); + XCTAssertEqual(CFGetRetainCount(_data), 1); + XCTAssertEqual(CFGetRetainCount(data2), 2); + CFRelease(data2); + } + XCTAssertEqual(CFGetRetainCount(_data), 1); +} + +- (void)testAdopt { + CFRetain(_data); + XCTAssertEqual(CFGetRetainCount(_data), 2); + { + CFHolder holder; + holder.adopt(_data); + XCTAssertEqual(CFGetRetainCount(_data), 2); + } + XCTAssertEqual(CFGetRetainCount(_data), 1); +} + +- (void)testMove { + XCTAssertEqual(CFGetRetainCount(_data), 1); + { + CFHolder holder(_data); + XCTAssertEqual(CFGetRetainCount(_data), 2); + CFHolder holder2(std::move(holder)); + XCTAssertEqual(CFGetRetainCount(_data), 2); + CFHolder holder3(std::move(holder2)); + XCTAssertEqual(CFGetRetainCount(_data), 2); + } + XCTAssertEqual(CFGetRetainCount(_data), 1); +} + +@end diff --git a/mediapipe/objc/CGImageRefUtils.h b/mediapipe/objc/CGImageRefUtils.h new file mode 100644 index 000000000..7b76deb77 --- /dev/null +++ b/mediapipe/objc/CGImageRefUtils.h @@ -0,0 +1,43 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_OBJC_CGIMAGEREFUTILS_H_ +#define MEDIAPIPE_OBJC_CGIMAGEREFUTILS_H_ + +#import + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +extern NSString *const kCGImageRefUtilsErrorDomain; + +// TODO: Get rid of this library or make it a wrapper around util.h +// versions so that it can be used in pure Objective-C code. + +/// Creates a CGImage with a copy of the contents of the CVPixelBuffer. Returns nil on error, if +/// the |error| argument is not nil, *error is set to an NSError describing the failure. Caller +/// is responsible for releasing the CGImage by calling CGImageRelease(). +CGImageRef CreateCGImageFromCVPixelBuffer(CVPixelBufferRef imageBuffer, NSError **error); + +/// Creates a CVPixelBuffer with a copy of the contents of the CGImage. Returns nil on error, if +/// the |error| argument is not nil, *error is set to an NSError describing the failure. Caller +/// is responsible for releasing the CVPixelBuffer by calling CVPixelBufferRelease. +CVPixelBufferRef CreateCVPixelBufferFromCGImage(CGImageRef image, NSError **error); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // MEDIAPIPE_OBJC_CGIMAGEREFUTILS_H_ diff --git a/mediapipe/objc/CGImageRefUtils.mm b/mediapipe/objc/CGImageRefUtils.mm new file mode 100644 index 000000000..4d7f47325 --- /dev/null +++ b/mediapipe/objc/CGImageRefUtils.mm @@ -0,0 +1,46 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "CGImageRefUtils.h" + +#import +#import "mediapipe/objc/CFHolder.h" +#import "mediapipe/objc/NSError+util_status.h" +#import "mediapipe/objc/util.h" + +#include "mediapipe/framework/port/status.h" + +CGImageRef CreateCGImageFromCVPixelBuffer(CVPixelBufferRef imageBuffer, NSError **error) { + CFHolder cg_image_holder; + ::mediapipe::Status status = CreateCGImageFromCVPixelBuffer(imageBuffer, &cg_image_holder); + if (!status.ok()) { + *error = [NSError gus_errorWithStatus:status]; + return nil; + } + CGImageRef cg_image = *cg_image_holder; + CGImageRetain(cg_image); + return cg_image; +} + +CVPixelBufferRef CreateCVPixelBufferFromCGImage(CGImageRef image, NSError **error) { + CFHolder pixel_buffer_holder; + ::mediapipe::Status status = CreateCVPixelBufferFromCGImage(image, &pixel_buffer_holder); + if (!status.ok()) { + *error = [NSError gus_errorWithStatus:status]; + return nil; + } + CVPixelBufferRef pixel_buffer = *pixel_buffer_holder; + CVPixelBufferRetain(pixel_buffer); + return pixel_buffer; +} diff --git a/mediapipe/objc/MPPCameraInputSource.h b/mediapipe/objc/MPPCameraInputSource.h new file mode 100644 index 000000000..9ce2514ea --- /dev/null +++ b/mediapipe/objc/MPPCameraInputSource.h @@ -0,0 +1,52 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "mediapipe/objc/MPPInputSource.h" + +/// A source that obtains video frames from the camera. +@interface MPPCameraInputSource : MPPInputSource + +/// Whether we are allowed to use the camera. +@property(nonatomic, getter=isAuthorized, readonly) BOOL authorized; + +/// Session preset to use for capturing. +@property(nonatomic) NSString *sessionPreset; + +/// Which camera on an iOS device to use, assuming iOS device with more than one camera. +@property(nonatomic) AVCaptureDevicePosition cameraPosition; + +// Whether to use depth data or not +@property(nonatomic) BOOL useDepth; + +/// Whether to rotate video buffers with device rotation. +@property(nonatomic) BOOL autoRotateBuffers; + +/// The capture session. +@property(nonatomic, readonly) AVCaptureSession *session; + +/// The capture video preview layer. +@property(nonatomic, readonly) AVCaptureVideoPreviewLayer *videoPreviewLayer; + +/// The orientation of camera frame buffers. +@property(nonatomic) AVCaptureVideoOrientation orientation; + +/// Prompts the user to grant camera access and provides the result as a BOOL to a completion +/// handler. 
Should be called after [MPPCameraInputSource init] and before +/// [MPPCameraInputSource start]. If the user has previously granted or denied permission, this +/// method simply returns the saved response to the permission request. +- (void)requestCameraAccessWithCompletionHandler:(void (^_Nullable)(BOOL granted))handler; + +@end diff --git a/mediapipe/objc/MPPCameraInputSource.m b/mediapipe/objc/MPPCameraInputSource.m new file mode 100644 index 000000000..3fdaa00be --- /dev/null +++ b/mediapipe/objc/MPPCameraInputSource.m @@ -0,0 +1,298 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "MPPCameraInputSource.h" + +#import + +@interface MPPCameraInputSource () +@end + +@implementation MPPCameraInputSource { + AVCaptureSession* _session; + AVCaptureDeviceInput* _videoDeviceInput; + AVCaptureVideoDataOutput* _videoDataOutput; + AVCaptureDepthDataOutput* _depthDataOutput; + AVCaptureDevice *_currentDevice; + + OSType _pixelFormatType; + BOOL _autoRotateBuffers; + BOOL _setupDone; + BOOL _useDepth; + BOOL _useCustomOrientation; +} + +- (instancetype)init { + self = [super init]; + if (self) { + _cameraPosition = AVCaptureDevicePositionBack; + _session = [[AVCaptureSession alloc] init]; + _pixelFormatType = kCVPixelFormatType_32BGRA; + + AVAuthorizationStatus status = + [AVCaptureDevice authorizationStatusForMediaType:AVMediaTypeVideo]; + _authorized = status == AVAuthorizationStatusAuthorized; + } + return self; +} + +- (void)setDelegate:(id)delegate + queue:(dispatch_queue_t)queue { + [super setDelegate:delegate queue:queue]; + // Note that _depthDataOutput and _videoDataOutput may not have been created yet. In that case, + // this message to nil is ignored, and the delegate will be set later by setupCamera. 
+ [_videoDataOutput setSampleBufferDelegate:self queue:queue]; + [_depthDataOutput setDelegate:self callbackQueue:queue]; +} + +- (void)start { + if (!_setupDone) [self setupCamera]; + if (_autoRotateBuffers) { + [self enableAutoRotateBufferObserver:YES]; + } + [_session startRunning]; +} + +- (void)stop { + if (_autoRotateBuffers) { + [self enableAutoRotateBufferObserver:NO]; + } + [_session stopRunning]; +} + +- (BOOL)isRunning { + return _session.isRunning; +} + +- (void)setCameraPosition:(AVCaptureDevicePosition)cameraPosition { + BOOL wasRunning = [self isRunning]; + if (wasRunning) { + [self stop]; + } + _cameraPosition = cameraPosition; + _setupDone = NO; + if (wasRunning) { + [self start]; + } +} + +- (void)setUseDepth:(BOOL)useDepth { + if (useDepth == _useDepth) { + return; + } + + BOOL wasRunning = [self isRunning]; + if (wasRunning) { + [self stop]; + } + _useDepth = useDepth; + _setupDone = NO; + if (wasRunning) { + [self start]; + } +} + +- (void)setOrientation:(AVCaptureVideoOrientation)orientation { + if (orientation == _orientation) { + return; + } + + BOOL wasRunning = [self isRunning]; + if (wasRunning) { + [self stop]; + } + + _orientation = orientation; + _useCustomOrientation = YES; + _setupDone = NO; + if (wasRunning) { + [self start]; + } +} + +- (void)setAutoRotateBuffers:(BOOL)autoRotateBuffers { + if (autoRotateBuffers == _autoRotateBuffers) { + return; // State has not changed. + } + _autoRotateBuffers = autoRotateBuffers; + if ([self isRunning]) { + // Enable or disable observer this settings changes while this input source is running. + [self enableAutoRotateBufferObserver:_autoRotateBuffers]; + } +} + +- (void)enableAutoRotateBufferObserver:(BOOL)enable { + if (enable) { + [[NSNotificationCenter defaultCenter] addObserver:self + selector:@selector(deviceOrientationChanged) + name:UIDeviceOrientationDidChangeNotification + object:nil]; + // Trigger a device orientation change instead of waiting for the first change. + [self deviceOrientationChanged]; + } else { + [[NSNotificationCenter defaultCenter] removeObserver:self + name:UIDeviceOrientationDidChangeNotification + object:nil]; + } +} + +- (OSType)pixelFormatType { + return _pixelFormatType; +} + +- (void)setPixelFormatType:(OSType)pixelFormatType { + _pixelFormatType = pixelFormatType; + if ([self isRunning]) { + _videoDataOutput.videoSettings = @{ + (id)kCVPixelBufferPixelFormatTypeKey : @(_pixelFormatType) + }; + } +} + +#pragma mark - Camera-specific methods + +- (NSString*)sessionPreset { + return _session.sessionPreset; +} + +- (void)setSessionPreset:(NSString*)sessionPreset { + _session.sessionPreset = sessionPreset; +} + +- (void)setupCamera { + NSError* error = nil; + + if (_videoDeviceInput) { + [_session removeInput:_videoDeviceInput]; + } + + AVCaptureDeviceDiscoverySession* deviceDiscoverySession = [AVCaptureDeviceDiscoverySession + discoverySessionWithDeviceTypes:@[ + _cameraPosition == AVCaptureDevicePositionFront && _useDepth ? + AVCaptureDeviceTypeBuiltInTrueDepthCamera : + AVCaptureDeviceTypeBuiltInWideAngleCamera] + mediaType:AVMediaTypeVideo + position:_cameraPosition]; + AVCaptureDevice* videoDevice = + [deviceDiscoverySession devices] + ? 
[deviceDiscoverySession devices].firstObject + : [AVCaptureDevice defaultDeviceWithMediaType:AVMediaTypeVideo]; + _videoDeviceInput = [AVCaptureDeviceInput deviceInputWithDevice:videoDevice error:&error]; + if (error) { + NSLog(@"%@", error); + return; + } + [_session addInput:_videoDeviceInput]; + + if (!_videoDataOutput) { + _videoDataOutput = [[AVCaptureVideoDataOutput alloc] init]; + [_session addOutput:_videoDataOutput]; + + // Set this when we have a handler. + if (self.delegateQueue) + [_videoDataOutput setSampleBufferDelegate:self queue:self.delegateQueue]; + _videoDataOutput.alwaysDiscardsLateVideoFrames = YES; + + // Only a few pixel formats are available for capture output: + // kCVPixelFormatType_420YpCbCr8BiPlanarVideoRange, + // kCVPixelFormatType_420YpCbCr8BiPlanarFullRange, + // kCVPixelFormatType_32BGRA. + _videoDataOutput.videoSettings = @{ + (id)kCVPixelBufferPixelFormatTypeKey : @(_pixelFormatType) + }; + } + + // Remove Old Depth Depth + if (_depthDataOutput) { + [_session removeOutput:_depthDataOutput]; + } + + if (_useDepth) { + // Add Depth Output + _depthDataOutput = [[AVCaptureDepthDataOutput alloc] init]; + _depthDataOutput.alwaysDiscardsLateDepthData = YES; + if ([_session canAddOutput:_depthDataOutput]) { + [_session addOutput:_depthDataOutput]; + + AVCaptureConnection* connection = + [_depthDataOutput connectionWithMediaType:AVMediaTypeDepthData]; + + if (connection != nil) + connection.enabled = true; + + // Set this when we have a handler. + if (self.delegateQueue) { + [_depthDataOutput setDelegate:self callbackQueue:self.delegateQueue]; + } + } + else + _depthDataOutput = nil; + } + + if (_useCustomOrientation) { + AVCaptureConnection* connection = [_videoDataOutput connectionWithMediaType:AVMediaTypeVideo]; + connection.videoOrientation = _orientation; + } + + _setupDone = YES; +} + +- (void)requestCameraAccessWithCompletionHandler:(void (^)(BOOL))handler { + [AVCaptureDevice requestAccessForMediaType:AVMediaTypeVideo + completionHandler:^(BOOL granted) { + _authorized = granted; + if (handler) { + handler(granted); + } + }]; +} + +#pragma mark - AVCaptureVideoDataOutputSampleBufferDelegate methods + +// Receives frames from the camera. Invoked on self.frameHandlerQueue. +- (void)captureOutput:(AVCaptureOutput*)captureOutput +didOutputSampleBuffer:(CMSampleBufferRef)sampleBuffer + fromConnection:(AVCaptureConnection*)connection { + CVPixelBufferRef imageBuffer = CMSampleBufferGetImageBuffer(sampleBuffer); + CMTime timestamp = CMSampleBufferGetPresentationTimeStamp(sampleBuffer); + if ([self.delegate respondsToSelector:@selector(processVideoFrame:timestamp:fromSource:)]) { + [self.delegate processVideoFrame:imageBuffer timestamp:timestamp fromSource:self]; + } else if ([self.delegate respondsToSelector:@selector(processVideoFrame:fromSource:)]) { + [self.delegate processVideoFrame:imageBuffer fromSource:self]; + } +} + +#pragma mark - AVCaptureDepthDataOutputDelegate methods + +// Receives depth frames from the camera. Invoked on self.frameHandlerQueue. 
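+// Depth data is converted to kCVPixelFormatType_DepthFloat32 (if needed) before being forwarded,
+// so delegates can assume one 32-bit float per pixel.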
+- (void)depthDataOutput:(AVCaptureDepthDataOutput *)output + didOutputDepthData:(AVDepthData *)depthData + timestamp:(CMTime)timestamp + connection:(AVCaptureConnection *)connection { + if (depthData.depthDataType != kCVPixelFormatType_DepthFloat32) { + depthData = [depthData depthDataByConvertingToDepthDataType:kCVPixelFormatType_DepthFloat32]; + } + CVPixelBufferRef depthBuffer = depthData.depthDataMap; + [self.delegate processDepthData:depthData timestamp:timestamp fromSource:self]; +} + +#pragma mark - NSNotificationCenter event handlers + +- (void)deviceOrientationChanged { + AVCaptureConnection* connection = [_videoDataOutput connectionWithMediaType:AVMediaTypeVideo]; + connection.videoOrientation = (AVCaptureVideoOrientation)[UIDevice currentDevice].orientation; +} + +@end diff --git a/mediapipe/objc/MPPDisplayLinkWeakTarget.h b/mediapipe/objc/MPPDisplayLinkWeakTarget.h new file mode 100644 index 000000000..eb8fb077d --- /dev/null +++ b/mediapipe/objc/MPPDisplayLinkWeakTarget.h @@ -0,0 +1,26 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import + +/// A generic target/callback holder. Useful for indirectly using DisplayLink and allowing the +/// complete deletion of displaylink reference holders. +@interface MPPDisplayLinkWeakTarget : NSObject + +- (instancetype)initWithTarget:(id)target selector:(SEL)sel; + +- (void)displayLinkCallback:(CADisplayLink *)sender; + +@end diff --git a/mediapipe/objc/MPPDisplayLinkWeakTarget.m b/mediapipe/objc/MPPDisplayLinkWeakTarget.m new file mode 100644 index 000000000..c5922a473 --- /dev/null +++ b/mediapipe/objc/MPPDisplayLinkWeakTarget.m @@ -0,0 +1,40 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
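+//
+// MPPDisplayLinkWeakTarget forwards CADisplayLink callbacks to a weakly held target, so the
+// object that owns the display link can be deallocated without the link retaining it.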
+ +#import "mediapipe/objc/MPPDisplayLinkWeakTarget.h" + +@implementation MPPDisplayLinkWeakTarget { + __weak id _target; + SEL _selector; +} + +#pragma mark - Init + +- (instancetype)initWithTarget:(id)target selector:(SEL)sel { + self = [super init]; + if (self) { + _target = target; + _selector = sel; + } + return self; +} + +#pragma mark - Public + +- (void)displayLinkCallback:(CADisplayLink *)sender { + void (*display)(id, SEL, CADisplayLink *) = (void *)[_target methodForSelector:_selector]; + display(_target, _selector, sender); +} + +@end diff --git a/mediapipe/objc/MPPGLViewRenderer.h b/mediapipe/objc/MPPGLViewRenderer.h new file mode 100644 index 000000000..3cb0bbc42 --- /dev/null +++ b/mediapipe/objc/MPPGLViewRenderer.h @@ -0,0 +1,68 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import + +/// Modes of clockwise rotation for input frames. +typedef enum { + MediaPipeFrameRotationNone, + MediaPipeFrameRotation90, + MediaPipeFrameRotation180, + MediaPipeFrameRotation270 +} MediaPipeFrameRotationMode; + +typedef enum { + // Scale the frame up to fit the drawing area, preserving aspect ratio; may letterbox. + MediaPipeFrameScaleFit, + // Scale the frame up to fill the drawing area, preserving aspect ratio; may crop. + MediaPipeFrameScaleFillAndCrop, +} MediaPipeFrameScaleMode; + +/// Renders frames in a GLKView. +@interface MPPGLViewRenderer : NSObject + +/// Rendering context for display. +@property(nonatomic) EAGLContext *glContext; + +/// The frame to be rendered next. This pixel buffer must be unlocked, and +/// should not be modified after handing it to the renderer. +@property(atomic, retain) __attribute__((NSObject)) CVPixelBufferRef nextPixelBufferToRender; + +/// When YES, the last drawn pixel buffer is retained by this object after it is drawn in the GLView +/// for which it is a delegate. Otherwise it is released after it has been rendered. +/// Set this property to YES when your GLView can be redrawn with the same pixel buffer, such as +/// during an animation. +@property(nonatomic, assign) BOOL retainsLastPixelBuffer; + +/// Sets which way to rotate input frames before rendering them. +/// Default value is MediaPipeFrameRotationNone. +/// Note that changing the transform property of a GLKView once rendering has +/// started causes problems inside GLKView. Instead, we perform the rotation +/// in our rendering code. +@property(nonatomic) MediaPipeFrameRotationMode frameRotationMode; + +/// Sets how to scale the frame within the view. +/// Default value is MediaPipeFrameScaleScaleToFit. +@property(nonatomic) MediaPipeFrameScaleMode frameScaleMode; + +/// If YES, swap left and right. Useful for the front camera. +@property(nonatomic) BOOL mirrored; + +/// Draws a pixel buffer to its context with the specified view size. 
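+///
+/// A typical call site (sketch; `renderer`, `buffer`, and the GLKView `view` are assumed to be
+/// defined by the caller):
+///   [renderer drawPixelBuffer:buffer width:view.drawableWidth height:view.drawableHeight];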
+- (void)drawPixelBuffer:(CVPixelBufferRef)pixelBuffer + width:(GLfloat)viewWidth + height:(GLfloat)viewHeight; + +@end diff --git a/mediapipe/objc/MPPGLViewRenderer.mm b/mediapipe/objc/MPPGLViewRenderer.mm new file mode 100644 index 000000000..458be934e --- /dev/null +++ b/mediapipe/objc/MPPGLViewRenderer.mm @@ -0,0 +1,189 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "MPPGLViewRenderer.h" + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "mediapipe/gpu/gl_quad_renderer.h" +#include "mediapipe/gpu/gl_simple_shaders.h" +#include "mediapipe/gpu/shader_util.h" + +#import "mediapipe/objc/NSError+util_status.h" +#import "GTMDefines.h" + +@implementation MPPGLViewRenderer { + /// Used to create textures for the pixel buffers to be rendered. + /// The use of this class allows the GPU to access the buffer's memory directly, + /// without a memory copy. + CVOpenGLESTextureCacheRef _textureCache; + + /// Internal renderer. + std::unique_ptr renderer_; + + /// Used to synchronize access to _nextPixelBufferToRender. + OSSpinLock _bufferLock; + volatile CVPixelBufferRef _nextPixelBufferToRender; +} + +- (instancetype)init { + self = [super init]; + if (self) { + _glContext = [[EAGLContext alloc] initWithAPI:kEAGLRenderingAPIOpenGLES2]; + _bufferLock = OS_SPINLOCK_INIT; + _frameRotationMode = MediaPipeFrameRotationNone; + _frameScaleMode = MediaPipeFrameScaleFit; + } + return self; +} + +- (void)dealloc { + if (_textureCache) CFRelease(_textureCache); + CVPixelBufferRelease(_nextPixelBufferToRender); + // Fixes crash during dealloc that only happens in iOS 9. More info at b/67095363. 
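+  // If this renderer's context is still current on the calling thread, clear it so the thread is
+  // not left with a dangling current context.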
+ if ([EAGLContext currentContext] == _glContext) { + [EAGLContext setCurrentContext:nil]; + } +} + +- (CVPixelBufferRef)nextPixelBufferToRender { + OSSpinLockLock(&_bufferLock); + CVPixelBufferRef buffer = _nextPixelBufferToRender; + OSSpinLockUnlock(&_bufferLock); + return buffer; +} + +- (void)setNextPixelBufferToRender:(CVPixelBufferRef)buffer { + OSSpinLockLock(&_bufferLock); + if (_nextPixelBufferToRender != buffer) { + CVPixelBufferRelease(_nextPixelBufferToRender); + _nextPixelBufferToRender = buffer; + CVPixelBufferRetain(_nextPixelBufferToRender); + } + OSSpinLockUnlock(&_bufferLock); +} + +- (void)setupGL { + CVReturn err; + err = CVOpenGLESTextureCacheCreate(kCFAllocatorDefault, NULL, _glContext, NULL, &_textureCache); + _GTMDevAssert(err == kCVReturnSuccess, + @"CVOpenGLESTextureCacheCreate failed: %d", err); + + renderer_ = absl::make_unique(); + auto status = renderer_->GlSetup(); + _GTMDevAssert(status.ok(), + @"renderer setup failed: %@", [NSError gus_errorWithStatus:status]); +} + +mediapipe::FrameScaleMode InternalScaleMode(MediaPipeFrameScaleMode mode) { + switch (mode) { + case MediaPipeFrameScaleFit: + return mediapipe::FrameScaleMode::kFit; + case MediaPipeFrameScaleFillAndCrop: + return mediapipe::FrameScaleMode::kFillAndCrop; + } +} + +mediapipe::FrameRotation InternalRotationMode(MediaPipeFrameRotationMode rot) { + switch (rot) { + case MediaPipeFrameRotationNone: + return mediapipe::FrameRotation::kNone; + case MediaPipeFrameRotation90: + return mediapipe::FrameRotation::k90; + case MediaPipeFrameRotation180: + return mediapipe::FrameRotation::k180; + case MediaPipeFrameRotation270: + return mediapipe::FrameRotation::k270; + } +} + +- (void)drawPixelBuffer:(CVPixelBufferRef)pixelBuffer + width:(GLfloat)viewWidth + height:(GLfloat)viewHeight { + if (!_textureCache) [self setupGL]; + + size_t frameWidth = CVPixelBufferGetWidth(pixelBuffer); + size_t frameHeight = CVPixelBufferGetHeight(pixelBuffer); + CVOpenGLESTextureRef texture = NULL; + OSType pixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer); + CVReturn error; + if (pixelFormat == kCVPixelFormatType_OneComponent8) { + error = CVOpenGLESTextureCacheCreateTextureFromImage( + kCFAllocatorDefault, _textureCache, pixelBuffer, NULL, + GL_TEXTURE_2D, GL_LUMINANCE, (GLsizei)frameWidth, (GLsizei)frameHeight, + GL_LUMINANCE, GL_UNSIGNED_BYTE, 0, &texture); + } else if (pixelFormat == kCVPixelFormatType_OneComponent32Float) { + error = CVOpenGLESTextureCacheCreateTextureFromImage( + kCFAllocatorDefault, _textureCache, pixelBuffer, NULL, + GL_TEXTURE_2D, GL_LUMINANCE, (GLsizei)frameWidth, (GLsizei)frameHeight, + GL_LUMINANCE, GL_FLOAT, 0, &texture); + } else { + error = CVOpenGLESTextureCacheCreateTextureFromImage( + kCFAllocatorDefault, _textureCache, pixelBuffer, NULL, + GL_TEXTURE_2D, GL_RGBA, (GLsizei)frameWidth, (GLsizei)frameHeight, + GL_BGRA, GL_UNSIGNED_BYTE, 0, &texture); + } + _GTMDevAssert(error == kCVReturnSuccess, + @"CVOpenGLESTextureCacheCreateTextureFromImage failed: %d", error); + + glClear(GL_COLOR_BUFFER_BIT); + + glActiveTexture(GL_TEXTURE1); + glBindTexture(CVOpenGLESTextureGetTarget(texture), CVOpenGLESTextureGetName(texture)); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); + + // Note: we perform a vertical flip by swapping the top and bottom coordinates. 
+ // CVPixelBuffers have a top left origin and OpenGL has a bottom left origin. + auto status = renderer_->GlRender( + frameWidth, frameHeight, viewWidth, viewHeight, + InternalScaleMode(_frameScaleMode), + InternalRotationMode(_frameRotationMode), + _mirrored, /*flip_vertical=*/false, /*flip_texture=*/true); + _GTMDevAssert(status.ok(), + @"render failed: %@", [NSError gus_errorWithStatus:status]); + + glBindTexture(CVOpenGLESTextureGetTarget(texture), 0); + glBindTexture(GL_TEXTURE_2D, 0 ); + CFRelease(texture); + CVOpenGLESTextureCacheFlush(_textureCache, 0); +} + +#pragma mark - GLKViewDelegate + +- (void)glkView:(GLKView *)view drawInRect:(CGRect)rect { + CVPixelBufferRef pixelBuffer = NULL; + + OSSpinLockLock(&_bufferLock); + pixelBuffer = _nextPixelBufferToRender; + if (_retainsLastPixelBuffer) { + CVPixelBufferRetain(pixelBuffer); + } else { + _nextPixelBufferToRender = NULL; + } + OSSpinLockUnlock(&_bufferLock); + + if (!pixelBuffer) return; + + [self drawPixelBuffer:pixelBuffer width:view.drawableWidth height:view.drawableHeight]; + CVPixelBufferRelease(pixelBuffer); +} + +@end diff --git a/mediapipe/objc/MPPGpuSimpleTest.mm b/mediapipe/objc/MPPGpuSimpleTest.mm new file mode 100644 index 000000000..8557e9c5f --- /dev/null +++ b/mediapipe/objc/MPPGpuSimpleTest.mm @@ -0,0 +1,81 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "mediapipe/objc/MPPGraph.h" +#import "mediapipe/objc/MPPGraphTestBase.h" + +@interface MPPGpuSimpleTest : MPPGraphTestBase +@end + +@implementation MPPGpuSimpleTest{ + CFHolder _inputPixelBuffer; + CFHolder _referencePixelBuffer; + CFHolder _outputPixelBuffer; +} +- (void)setUp { + [super setUp]; + UIImage* image = [self testImageNamed:@"sergey" extension:@"png"]; + XCTAssertTrue(CreateCVPixelBufferFromCGImage(image.CGImage, &_inputPixelBuffer).ok()); + image = [self testImageNamed:@"sobel_reference" extension:@"png"]; + XCTAssertTrue(CreateCVPixelBufferFromCGImage(image.CGImage, &_referencePixelBuffer).ok()); +} + +// This delegate method receives output. +- (void)mediapipeGraph:(MPPGraph*)graph + didOutputPixelBuffer:(CVPixelBufferRef)pixelBuffer + fromStream:(const std::string&)streamName + timestamp:(const mediapipe::Timestamp&)timestamp { + NSLog(@"CALLBACK INVOKED"); + _outputPixelBuffer.reset(pixelBuffer); +} + +- (void)testSimpleGpuGraph { + // Graph setup. + NSData* configData = [self testDataNamed:@"test_sobel.binarypb" extension:nil]; + mediapipe::CalculatorGraphConfig config; + XCTAssertTrue(config.ParseFromArray([configData bytes], [configData length])); + MPPGraph* mediapipeGraph = [[MPPGraph alloc] initWithGraphConfig:config]; + // We receive output by setting ourselves as the delegate. + mediapipeGraph.delegate = self; + [mediapipeGraph addFrameOutputStream:"output_video" outputPacketType:MediaPipePacketPixelBuffer]; + + // Start running the graph. 
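+  // startWithError: initializes the underlying CalculatorGraph with the config and starts the
+  // run; on failure it returns NO and wraps the ::mediapipe::Status in the returned NSError.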
+ NSError *error; + BOOL success = [mediapipeGraph startWithError:&error]; + XCTAssertTrue(success, @"%@", error.localizedDescription); + + // Send a frame. + XCTAssertTrue([mediapipeGraph sendPixelBuffer:*_inputPixelBuffer + intoStream:"input_video" + packetType:MediaPipePacketPixelBuffer + timestamp:mediapipe::Timestamp(0)]); + + // Shut down the graph. + success = [mediapipeGraph closeAllInputStreamsWithError:&error]; + XCTAssertTrue(success, @"%@", error.localizedDescription); + success = [mediapipeGraph waitUntilDoneWithError:&error]; + XCTAssertTrue(success, @"%@", error.localizedDescription); + + // Check output. + XCTAssertTrue(_outputPixelBuffer != nullptr); + [self savePixelBufferToSponge:*_outputPixelBuffer + withSubpath:@"sobel.png"]; + XCTAssertTrue([self pixelBuffer:*_outputPixelBuffer + isCloseTo:*_referencePixelBuffer + maxLocalDifference:5 + maxAverageDifference:FLT_MAX]); +} +@end diff --git a/mediapipe/objc/MPPGraph.h b/mediapipe/objc/MPPGraph.h new file mode 100644 index 000000000..e1e6b72c6 --- /dev/null +++ b/mediapipe/objc/MPPGraph.h @@ -0,0 +1,214 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import + +#ifndef __cplusplus +#error This header can only be included by an Objective-C++ file. +#endif + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/objc/util.h" + +@class MPPGraph; + +namespace mediapipe { +struct GpuSharedData; +} // namespace mediapipe + +/// A delegate that can receive frames from a MediaPipe graph. +@protocol MPPGraphDelegate + +/// Provides the delegate with a new video frame. +@optional +- (void)mediapipeGraph:(MPPGraph*)graph + didOutputPixelBuffer:(CVPixelBufferRef)pixelBuffer + fromStream:(const std::string&)streamName; + +/// Provides the delegate with a new video frame and time stamp. +@optional +- (void)mediapipeGraph:(MPPGraph*)graph + didOutputPixelBuffer:(CVPixelBufferRef)pixelBuffer + fromStream:(const std::string&)streamName + timestamp:(const mediapipe::Timestamp&)timestamp; + +/// Provides the delegate with a raw packet. +@optional +- (void)mediapipeGraph:(MPPGraph*)graph + didOutputPacket:(const mediapipe::Packet&)packet + fromStream:(const std::string&)streamName; + +@end + +/// Chooses the packet type used by MPPGraph to send and receive packets +/// from the graph. +typedef NS_ENUM(int, MediaPipePacketType) { + /// Any packet type. + /// Calls mediapipeGraph:didOutputPacket:fromStream: + MediaPipePacketRaw, + + /// CFHolder. + /// Calls mediapipeGraph:didOutputPixelBuffer:fromStream: + /// Use this packet type to pass GPU frames to calculators. + MediaPipePacketPixelBuffer, + + /// ImageFrame. + /// Calls mediapipeGraph:didOutputPixelBuffer:fromStream: + MediaPipePacketImageFrame, + + /// RGBA ImageFrame, but do not swap the channels if the input pixel buffer + /// is BGRA. This is useful when the graph needs RGBA ImageFrames, but the + /// calculators do not care about the order of the channels, so BGRA data can + /// be used as-is. 
+ /// Calls mediapipeGraph:didOutputPixelBuffer:fromStream: + MediaPipePacketImageFrameBGRANoSwap, +}; + +/// This class is an Objective-C wrapper around a MediaPipe graph object, and +/// helps interface it with iOS technologies such as AVFoundation. +@interface MPPGraph : NSObject + +/// The delegate, which receives output frames. +@property(weak) id delegate; + +/// If the graph is already processing more than this number of frames, drop any +/// new incoming frames. Used to avoid swamping slower devices when processing +/// cannot keep up with the speed of video input. +/// This works as long as frames are sent or received using these methods: +/// - sendPixelBuffer:intoStream:packetType:[timestamp:] +/// - addFrameOutputStream:outputPacketType: +/// Set to 0 (the default) for no limit. +@property(nonatomic) int maxFramesInFlight; + +/// Determines whether adding a packet to an input stream whose queue is full +/// should fail or wait. +@property mediapipe::CalculatorGraph::GraphInputStreamAddMode packetAddMode; + +- (instancetype)init NS_UNAVAILABLE; + +/// Copies the config and initializes the graph. +/// @param config The configuration describing the graph. +- (instancetype)initWithGraphConfig:(const mediapipe::CalculatorGraphConfig&)config + NS_DESIGNATED_INITIALIZER; + +- (mediapipe::ProfilingContext*)getProfiler; + +/// Sets a stream header. If the header was already set, it is overwritten. +/// @param packet The header. +/// @param streamName The name of the stream. +- (void)setHeaderPacket:(const mediapipe::Packet&)packet forStream:(const std::string&)streamName; + +/// Sets a side packet. If it was already set, it is overwritten. +/// Must be called before the graph is started. +/// @param packet The packet to be associated with the input side packet. +/// @param name The name of the input side packet. +- (void)setSidePacket:(const mediapipe::Packet&)packet named:(const std::string&)name; + +/// Adds input side packets from a map. Any inputs that were already set are +/// left unchanged. +/// Must be called before the graph is started. +/// @param extraInputSidePackets The input side packets to be added. +- (void)addSidePackets:(const std::map&)extraSidePackets; + +// TODO: rename to addDelegateOutputStream:packetType: +/// Add an output stream in the graph from which the delegate wants to receive +/// output. The delegate method called depends on the provided packetType. +/// @param outputStreamName The name of the output stream from which +/// the delegate will receive frames. +/// @param packetType The type of packet provided by the output streams. +- (void)addFrameOutputStream:(const std::string&)outputStreamName + outputPacketType:(MediaPipePacketType)packetType; + +/// Starts running the graph. +/// @return YES if successful. +- (BOOL)startWithError:(NSError**)error; + +/// Sends a generic packet into a graph input stream. +/// The graph must have been started before calling this. +/// Returns YES if the packet was successfully sent. +- (BOOL)sendPacket:(const mediapipe::Packet&)packet + intoStream:(const std::string&)streamName + error:(NSError**)error; + +- (BOOL)movePacket:(mediapipe::Packet&&)packet + intoStream:(const std::string&)streamName + error:(NSError**)error; + +/// Sets the maximum queue size for a stream. Experimental feature, currently +/// only supported for graph input streams. Should be called before starting the +/// graph. 
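+/// For example (sketch; `graph`, `error`, and the "input_video" stream name are assumptions):
+///   [graph setMaxQueueSize:2 forStream:"input_video" error:&error];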
+- (BOOL)setMaxQueueSize:(int)maxQueueSize + forStream:(const std::string&)streamName + error:(NSError**)error; + +/// Creates a MediaPipe packet wrapping the given pixelBuffer; +- (mediapipe::Packet)packetWithPixelBuffer:(CVPixelBufferRef)pixelBuffer + packetType:(MediaPipePacketType)packetType; + +/// Sends a pixel buffer into a graph input stream, using the specified packet +/// type. The graph must have been started before calling this. Drops frames and +/// returns NO if maxFramesInFlight is exceeded. If allowOverwrite is set to YES, +/// allows MediaPipe to overwrite the packet contents on successful sending for +/// possibly increased efficiency. Returns YES if the packet was successfully sent. +- (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer + intoStream:(const std::string&)inputName + packetType:(MediaPipePacketType)packetType + timestamp:(const mediapipe::Timestamp&)timestamp + allowOverwrite:(BOOL)allowOverwrite; + +/// Sends a pixel buffer into a graph input stream, using the specified packet +/// type. The graph must have been started before calling this. Drops frames and +/// returns NO if maxFramesInFlight is exceeded. Returns YES if the packet was +/// successfully sent. +- (BOOL)sendPixelBuffer:(CVPixelBufferRef)pixelBuffer + intoStream:(const std::string&)inputName + packetType:(MediaPipePacketType)packetType + timestamp:(const mediapipe::Timestamp&)timestamp; + +/// Sends a pixel buffer into a graph input stream, using the specified packet +/// type. The graph must have been started before calling this. The timestamp is +/// automatically incremented from the last timestamp used by this method. Drops +/// frames and returns NO if maxFramesInFlight is exceeded. Returns YES if the +/// packet was successfully sent. +- (BOOL)sendPixelBuffer:(CVPixelBufferRef)pixelBuffer + intoStream:(const std::string&)inputName + packetType:(MediaPipePacketType)packetType; + +/// Cancels a graph run. You must still call waitUntilDoneWithError: after this. +- (void)cancel; + +/// Check if the graph contains this input stream +- (BOOL)hasInputStream:(const std::string&)inputName; + +/// Closes an input stream. +/// You must close all graph input streams before stopping the graph. +/// @return YES if successful. +- (BOOL)closeInputStream:(const std::string&)inputName error:(NSError**)error; + +/// Closes all graph input streams. +/// @return YES if successful. +- (BOOL)closeAllInputStreamsWithError:(NSError**)error; + +/// Stops running the graph. +/// Call this before releasing this object. All input streams must have been +/// closed. This call does not time out, so you should not call it from the main +/// thread. +/// @return YES if successful. +- (BOOL)waitUntilDoneWithError:(NSError**)error; + +/// Waits for the graph to become idle. +- (BOOL)waitUntilIdleWithError:(NSError**)error; + +@end diff --git a/mediapipe/objc/MPPGraph.mm b/mediapipe/objc/MPPGraph.mm new file mode 100644 index 000000000..2d305fdbd --- /dev/null +++ b/mediapipe/objc/MPPGraph.mm @@ -0,0 +1,373 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/objc/MPPGraph.h"
+
+#import <AVFoundation/AVFoundation.h>
+#import <Accelerate/Accelerate.h>
+
+#include <atomic>
+
+#include "absl/memory/memory.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/formats/image_frame.h"
+#include "mediapipe/gpu/MPPGraphGPUData.h"
+#include "mediapipe/gpu/gl_base.h"
+#include "mediapipe/gpu/gpu_shared_data_internal.h"
+
+#import "mediapipe/objc/NSError+util_status.h"
+#import "GTMDefines.h"
+
+@implementation MPPGraph {
+  // Graph is wrapped in a unique_ptr because it was generating 39+KB of unnecessary ObjC runtime
+  // information. See https://medium.com/@dmaclach/objective-c-encoding-and-you-866624cc02de
+  // for details.
+  std::unique_ptr<mediapipe::CalculatorGraph> _graph;
+  /// Input side packets that will be added to the graph when it is started.
+  std::map<std::string, mediapipe::Packet> _inputSidePackets;
+  /// Packet headers that will be added to the graph when it is started.
+  std::map<std::string, mediapipe::Packet> _streamHeaders;
+
+  /// Number of frames currently being processed by the graph.
+  std::atomic<int32_t> _framesInFlight;
+  /// Used as a sequential timestamp for MediaPipe.
+  mediapipe::Timestamp _frameTimestamp;
+  int64 _frameNumber;
+
+  // Graph config modified to expose requested output streams.
+  mediapipe::CalculatorGraphConfig _config;
+
+  // Tracks whether the graph has been started and is currently running.
+  BOOL _started;
+}
+
+- (instancetype)initWithGraphConfig:(const mediapipe::CalculatorGraphConfig&)config {
+  self = [super init];
+  if (self) {
+    // Turn on Cocoa multithreading, since MediaPipe uses threads.
+    // Not needed on iOS, but we may want to have OS X clients in the future.
+    [[[NSThread alloc] init] start];
+    _graph = absl::make_unique<mediapipe::CalculatorGraph>();
+    _config = config;
+  }
+  return self;
+}
+
+- (mediapipe::ProfilingContext*)getProfiler {
+  return _graph->profiler();
+}
+
+- (mediapipe::CalculatorGraph::GraphInputStreamAddMode)packetAddMode {
+  return _graph->GetGraphInputStreamAddMode();
+}
+
+- (void)setPacketAddMode:(mediapipe::CalculatorGraph::GraphInputStreamAddMode)mode {
+  _graph->SetGraphInputStreamAddMode(mode);
+}
+
+- (void)addFrameOutputStream:(const std::string&)outputStreamName
+            outputPacketType:(MediaPipePacketType)packetType {
+  std::string callbackInputName;
+  mediapipe::tool::AddCallbackCalculator(outputStreamName, &_config, &callbackInputName,
+                                         /*use_std_function=*/true);
+  // No matter what ownership qualifiers are put on the pointer, NewPermanentCallback will
+  // still end up with a strong pointer to MPPGraph*. That is why we use void* instead.
+  void* wrapperVoid = (__bridge void*)self;
+  _inputSidePackets[callbackInputName] =
+      mediapipe::MakePacket<std::function<void(const mediapipe::Packet&)>>(
+          [wrapperVoid, outputStreamName, packetType](const mediapipe::Packet& packet) {
+            CallFrameDelegate(wrapperVoid, outputStreamName, packetType, packet);
+          });
+}
+
+- (NSString *)description {
+  return [NSString stringWithFormat:@"<%@: %p; framesInFlight = %d>", [self class], self,
+                                    _framesInFlight.load(std::memory_order_relaxed)];
+}
+
+/// This is the function that gets called by the CallbackCalculator that
+/// receives the graph's output.
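+/// The std::function packet registered by addFrameOutputStream:outputPacketType: invokes it for
+/// every output packet. Depending on the packet type, the packet is delivered to the delegate as
+/// a raw packet, converted from an ImageFrame into a newly created CVPixelBuffer, or passed
+/// through as the CVPixelBuffer backing a GpuBuffer.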
+void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, + MediaPipePacketType packetType, const mediapipe::Packet& packet) { + MPPGraph* wrapper = (__bridge MPPGraph*)wrapperVoid; + @autoreleasepool { + if (packetType == MediaPipePacketRaw) { + [wrapper.delegate mediapipeGraph:wrapper + didOutputPacket:packet + fromStream:streamName]; + } else if (packetType == MediaPipePacketImageFrame) { + const auto& frame = packet.Get(); + mediapipe::ImageFormat::Format format = frame.Format(); + + if (format == mediapipe::ImageFormat::SRGBA || + format == mediapipe::ImageFormat::GRAY8) { + CVPixelBufferRef pixelBuffer; + // To ensure compatibility with CVOpenGLESTextureCache, this attribute should be present. + NSDictionary* attributes = @{ + (id)kCVPixelBufferIOSurfacePropertiesKey : @{}, + }; + // If kCVPixelFormatType_32RGBA does not work, it returns kCVReturnInvalidPixelFormat. + CVReturn error = CVPixelBufferCreate( + NULL, frame.Width(), frame.Height(), kCVPixelFormatType_32BGRA, + (__bridge CFDictionaryRef)attributes, &pixelBuffer); + _GTMDevAssert(error == kCVReturnSuccess, @"CVPixelBufferCreate failed: %d", error); + error = CVPixelBufferLockBaseAddress(pixelBuffer, 0); + _GTMDevAssert(error == kCVReturnSuccess, @"CVPixelBufferLockBaseAddress failed: %d", error); + + vImage_Buffer vDestination = vImageForCVPixelBuffer(pixelBuffer); + // Note: we have to throw away const here, but we should not overwrite + // the packet data. + vImage_Buffer vSource = vImageForImageFrame(frame); + if (format == mediapipe::ImageFormat::SRGBA) { + // Swap R and B channels. + const uint8_t permuteMap[4] = {2, 1, 0, 3}; + vImage_Error vError = vImagePermuteChannels_ARGB8888( + &vSource, &vDestination, permuteMap, kvImageNoFlags); + _GTMDevAssert(vError == kvImageNoError, @"vImagePermuteChannels failed: %zd", vError); + } else { + // Convert grayscale back to BGRA + vImage_Error vError = vImageGrayToBGRA(&vSource, &vDestination); + _GTMDevAssert(vError == kvImageNoError, @"vImageGrayToBGRA failed: %zd", vError); + } + + error = CVPixelBufferUnlockBaseAddress(pixelBuffer, 0); + _GTMDevAssert(error == kCVReturnSuccess, + @"CVPixelBufferUnlockBaseAddress failed: %d", error); + + if ([wrapper.delegate respondsToSelector:@selector + (mediapipeGraph:didOutputPixelBuffer:fromStream:timestamp:)]) { + [wrapper.delegate mediapipeGraph:wrapper + didOutputPixelBuffer:pixelBuffer + fromStream:streamName + timestamp:packet.Timestamp()]; + } else if ([wrapper.delegate respondsToSelector:@selector + (mediapipeGraph:didOutputPixelBuffer:fromStream:)]) { + [wrapper.delegate mediapipeGraph:wrapper + didOutputPixelBuffer:pixelBuffer + fromStream:streamName]; + } + CVPixelBufferRelease(pixelBuffer); + } else { + _GTMDevLog(@"unsupported ImageFormat: %d", format); + } +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + } else if (packetType == MediaPipePacketPixelBuffer) { + CVPixelBufferRef pixelBuffer = packet.Get().GetCVPixelBufferRef(); + if ([wrapper.delegate + respondsToSelector:@selector + (mediapipeGraph:didOutputPixelBuffer:fromStream:timestamp:)]) { + [wrapper.delegate mediapipeGraph:wrapper + didOutputPixelBuffer:pixelBuffer + fromStream:streamName + timestamp:packet.Timestamp()]; + } else if ([wrapper.delegate + respondsToSelector:@selector + (mediapipeGraph:didOutputPixelBuffer:fromStream:)]) { + [wrapper.delegate mediapipeGraph:wrapper + didOutputPixelBuffer:pixelBuffer + fromStream:streamName]; + } +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + } else { + _GTMDevLog(@"unsupported packet type"); + } + 
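+    // Every delivered output decrements the in-flight counter, balancing the increment made in
+    // sendPixelBuffer:intoStream:packetType:timestamp:allowOverwrite: so that maxFramesInFlight
+    // throttling stays accurate.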
+ wrapper->_framesInFlight--; + } +} + +- (void)setHeaderPacket:(const mediapipe::Packet&)packet forStream:(const std::string&)streamName { + _GTMDevAssert(!_started, @"%@ must be called before the graph is started", + NSStringFromSelector(_cmd)); + _streamHeaders[streamName] = packet; +} + +- (void)setSidePacket:(const mediapipe::Packet&)packet named:(const std::string&)name { + _GTMDevAssert(!_started, @"%@ must be called before the graph is started", + NSStringFromSelector(_cmd)); + _inputSidePackets[name] = packet; +} + +- (void)addSidePackets:(const std::map&)extraSidePackets { + _GTMDevAssert(!_started, @"%@ must be called before the graph is started", + NSStringFromSelector(_cmd)); + _inputSidePackets.insert(extraSidePackets.begin(), extraSidePackets.end()); +} + +- (BOOL)startWithError:(NSError**)error { + ::mediapipe::Status status = _graph->Initialize(_config); + if (status.ok()) { + status = _graph->StartRun(_inputSidePackets, _streamHeaders); + if (status.ok()) { + _started = YES; + return YES; + } + } + if (error) { + *error = [NSError gus_errorWithStatus:status]; + } + return NO; +} + +- (void)cancel { + _graph->Cancel(); +} + +- (BOOL)hasInputStream:(const std::string&)inputName { + return _graph->HasInputStream(inputName); +} + +- (BOOL)closeInputStream:(const std::string&)inputName error:(NSError**)error { + ::mediapipe::Status status = _graph->CloseInputStream(inputName); + if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status]; + return status.ok(); +} + +- (BOOL)closeAllInputStreamsWithError:(NSError**)error { + ::mediapipe::Status status = _graph->CloseAllInputStreams(); + if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status]; + return status.ok(); +} + +- (BOOL)waitUntilDoneWithError:(NSError**)error { + // Since this method blocks with no timeout, it should not be called in the main thread in + // an app. However, it's fine to allow that in a test. + // TODO: is this too heavy-handed? Maybe a warning would be fine. 
+ _GTMDevAssert(![NSThread isMainThread] || (NSClassFromString(@"XCTest")), + @"waitUntilDoneWithError: should not be called on the main thread"); + ::mediapipe::Status status = _graph->WaitUntilDone(); + _started = NO; + if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status]; + return status.ok(); +} + +- (BOOL)waitUntilIdleWithError:(NSError**)error { + ::mediapipe::Status status = _graph->WaitUntilIdle(); + if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status]; + return status.ok(); +} + +- (BOOL)movePacket:(mediapipe::Packet&&)packet + intoStream:(const std::string&)streamName + error:(NSError**)error { + ::mediapipe::Status status = _graph->AddPacketToInputStream(streamName, std::move(packet)); + if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status]; + return status.ok(); +} + +- (BOOL)sendPacket:(const mediapipe::Packet&)packet + intoStream:(const std::string&)streamName + error:(NSError**)error { + ::mediapipe::Status status = _graph->AddPacketToInputStream(streamName, packet); + if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status]; + return status.ok(); +} + +- (BOOL)setMaxQueueSize:(int)maxQueueSize + forStream:(const std::string&)streamName + error:(NSError**)error { + ::mediapipe::Status status = _graph->SetInputStreamMaxQueueSize(streamName, maxQueueSize); + if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status]; + return status.ok(); +} + +- (mediapipe::Packet)packetWithPixelBuffer:(CVPixelBufferRef)imageBuffer + packetType:(MediaPipePacketType)packetType { + mediapipe::Packet packet; + if (packetType == MediaPipePacketImageFrame || packetType == MediaPipePacketImageFrameBGRANoSwap) { + auto frame = CreateImageFrameForCVPixelBuffer( + imageBuffer, /* canOverwrite = */ false, + /* bgrAsRgb = */ packetType == MediaPipePacketImageFrameBGRANoSwap); + packet = mediapipe::Adopt(frame.release()); +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + } else if (packetType == MediaPipePacketPixelBuffer) { + packet = mediapipe::MakePacket(imageBuffer); +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + } else { + _GTMDevLog(@"unsupported packet type: %d", packetType); + } + return packet; +} + +- (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer + intoStream:(const std::string&)inputName + packetType:(MediaPipePacketType)packetType + timestamp:(const mediapipe::Timestamp&)timestamp + allowOverwrite:(BOOL)allowOverwrite { + if (_maxFramesInFlight && _framesInFlight >= _maxFramesInFlight) return NO; + mediapipe::Packet packet = [self packetWithPixelBuffer:imageBuffer packetType:packetType]; + NSError* error; + BOOL success; + if (allowOverwrite) { + packet = std::move(packet).At(timestamp); + success = [self movePacket:std::move(packet) + intoStream:inputName + error:&error]; + } else { + success = [self sendPacket:packet.At(timestamp) + intoStream:inputName + error:&error]; + } + if (success) _framesInFlight++; + else _GTMDevLog(@"failed to send packet: %@", error); + return success; +} + +- (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer + intoStream:(const std::string&)inputName + packetType:(MediaPipePacketType)packetType + timestamp:(const mediapipe::Timestamp&)timestamp { + return [self sendPixelBuffer:imageBuffer + intoStream:inputName + packetType:packetType + timestamp:timestamp + allowOverwrite:NO]; +} + +- (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer + intoStream:(const std::string&)inputName + packetType:(MediaPipePacketType)packetType { + _GTMDevAssert(_frameTimestamp < 
mediapipe::Timestamp::Done(), + @"Trying to send frame after stream is done."); + if (_frameTimestamp < mediapipe::Timestamp::Min()) { + _frameTimestamp = mediapipe::Timestamp::Min(); + } else { + _frameTimestamp++; + } + return [self sendPixelBuffer:imageBuffer + intoStream:inputName + packetType:packetType + timestamp:_frameTimestamp]; +} + +- (void)debugPrintGlInfo { + std::shared_ptr gpu_resources = _graph->GetGpuResources(); + if (!gpu_resources) { + NSLog(@"GPU not set up."); + return; + } + + NSString* extensionString; + (void)gpu_resources->gl_context()->Run([&extensionString]{ + extensionString = [NSString stringWithUTF8String:(char*)glGetString(GL_EXTENSIONS)]; + return ::mediapipe::OkStatus(); + }); + + NSArray* extensions = [extensionString componentsSeparatedByCharactersInSet: + [NSCharacterSet whitespaceCharacterSet]]; + for (NSString* oneExtension in extensions) + NSLog(@"%@", oneExtension); +} + +@end diff --git a/mediapipe/objc/MPPGraphTestBase.h b/mediapipe/objc/MPPGraphTestBase.h new file mode 100644 index 000000000..ecec1a1b2 --- /dev/null +++ b/mediapipe/objc/MPPGraphTestBase.h @@ -0,0 +1,141 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import + +#import "mediapipe/objc/MPPGraph.h" +#import "mediapipe/objc/NSError+util_status.h" +#import "mediapipe/objc/util.h" + +/// This XCTestCase subclass provides common convenience methods for testing +/// with MPPGraph. +@interface MPPGraphTestBase : XCTestCase { + /// This block is used to respond to mediapipeGraph:didOutputPixelBuffer:fromStream:. + /// runGraph:withPixelBuffer:packetType: uses this internally, but you can reuse it + /// if you need to run a graph directly and want a MPPGraphTestBase object to + /// act as the delegate. + void (^_pixelBufferOutputBlock)(MPPGraph* graph, CVPixelBufferRef imageBuffer, + const std::string& streamName); + + /// This block is used to respond to mediapipeGraph:didOutputPacket:fromStream:. + /// You can use it if you need to run a graph directly and want a MPPGraphTestBase + /// object to act as the delegate. + void (^_packetOutputBlock)(MPPGraph* graph, const mediapipe::Packet& packet, + const std::string& streamName); +} + +/// Runs a single frame through a simple graph. The graph is expected to have an +/// input stream named "input_frames" and an output stream named +/// "output_frames". This function runs the graph, sends inputBuffer into +/// input_frames (at timestamp=1), receives an output buffer from output_frames, +/// completes the run, and returns the output frame. +- (CVPixelBufferRef)runGraph:(MPPGraph*)graph + withPixelBuffer:(CVPixelBufferRef)inputBuffer + packetType:(MediaPipePacketType)inputPacketType; + +/// Runs a simple graph, providing a single frame to zero or more inputs. Input images are wrapped +/// in packets each with timestamp mediapipe::Timestamp(1). Those packets are added to the +/// designated streams (named by the keys of withInputPixelBuffers). 
When a packet arrives on the +/// output stream, the graph run is done and the output frame is returned. +- (CVPixelBufferRef)runGraph:(MPPGraph*)graph + withInputPixelBuffers: + (const std::unordered_map>&)inputBuffers + outputStream:(const std::string&)output + packetType:(MediaPipePacketType)inputPacketType; + +/// Loads a data file from the test bundle. +- (NSData*)testDataNamed:(NSString*)name extension:(NSString*)extension; + +/// Loads an image from the test bundle. +- (UIImage*)testImageNamed:(NSString*)name extension:(NSString*)extension; + +/// Loads an image from the test bundle in subpath. +- (UIImage*)testImageNamed:(NSString*)name + extension:(NSString*)extension + subdirectory:(NSString*)subdirectory; + +/// Compares two pixel buffers for strict equality. +/// Returns true iff the two buffers have the same size, format, and pixel data. +- (BOOL)pixelBuffer:(CVPixelBufferRef)a isEqualTo:(CVPixelBufferRef)b; + +/// Compares two pixel buffers with some leniency. +/// Returns true iff the two buffers have the same size and format, and: +/// - the difference between each pixel of A and the corresponding pixel of B does +/// not exceed maxLocalDiff, and +/// - the average difference between corresponding pixels of A and B does not +/// exceed maxAvgDiff. +- (BOOL)pixelBuffer:(CVPixelBufferRef)a + isCloseTo:(CVPixelBufferRef)b + maxLocalDifference:(int)maxLocalDiff + maxAverageDifference:(float)maxAvgDiff; + +/// Utility function for making a copy of a pixel buffer with a different pixel +/// format. +- (CVPixelBufferRef)convertPixelBuffer:(CVPixelBufferRef)input toPixelFormat:(OSType)pixelFormat; + +/// Makes a scaled copy of a BGRA pixel buffer. +- (CVPixelBufferRef)scaleBGRAPixelBuffer:(CVPixelBufferRef)input toSize:(CGSize)size; + +/// Utility function for transforming a pixel buffer. +/// It creates a new pixel buffer with the same dimensions as the original, in the +/// desired pixel format, and invokes a block with the input and output buffers. +/// The buffers are locked before the block and unlocked after, so the block can read +/// from the input buffer and write to the output buffer without further preparation. +- (CVPixelBufferRef)transformPixelBuffer:(CVPixelBufferRef)input + outputPixelFormat:(OSType)pixelFormat + transformation:(void (^)(CVPixelBufferRef input, + CVPixelBufferRef output))transformation; + +/// Computes a difference image from two input images. Useful for debugging. +- (UIImage*)differenceOfImage:(UIImage*)inputA image:(UIImage*)inputB; + +/// Tests a graph by sending in the provided input pixel buffer and comparing the +/// output with the provided expected output. Uses runGraph:withPixelBuffer:packetType: +/// internally, so the streams are supposed to be named "input_frames" and "output_frames". +/// The actual and expected outputs are compared fuzzily. +- (void)testGraph:(MPPGraph*)graph + input:(CVPixelBufferRef)inputBuffer + expectedOutput:(CVPixelBufferRef)expectedBuffer; + +/// Tests a graph by sending the provided image files as pixelBuffer inputs to the +/// corresponding streams, and comparing the single frame output by the given output stream +/// with the contents of the given output file. +/// @param config Graph config. +/// @param fileInputs Dictionary mapping input stream names to image file paths. +/// @param packetInputs Map of input stream names to additional input packets. +/// @param sidePackets Map of input side packet stream names to packets. +/// @param outputStream Name of the output stream where the output is produced. 
+/// @param expectedPath Path to an image file containing the expected output. +/// @param maxAverageDifference The maximum allowable average pixel difference +/// between the +/// expected output and computed output. +/// TODO: Use NSDictionary instead of std::map for sidePackets. +- (void)testGraphConfig:(const mediapipe::CalculatorGraphConfig&)config + inputStreamsAndFiles:(NSDictionary*)fileInputs + inputStreamsAndPackets:(const std::map&)packetInputs + sidePackets:(std::map)sidePackets + timestamp:(mediapipe::Timestamp)timestamp + outputStream:(NSString*)outputStream + expectedOutputFile:(NSString*)expectedPath + maxAverageDifference:(float)maxAverageDifference; + +/// Calls the above testGraphConfig: method with a default maxAverageDifference +/// of 1.f and timestamp of 1. +- (void)testGraphConfig:(const mediapipe::CalculatorGraphConfig&)config + inputStreamsAndFiles:(NSDictionary*)inputs + outputStream:(NSString*)outputStream + expectedOutputFile:(NSString*)expectedPath; + +@end diff --git a/mediapipe/objc/MPPGraphTestBase.mm b/mediapipe/objc/MPPGraphTestBase.mm new file mode 100644 index 000000000..8bb65c354 --- /dev/null +++ b/mediapipe/objc/MPPGraphTestBase.mm @@ -0,0 +1,446 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/objc/MPPGraphTestBase.h" +#import "mediapipe/objc/Weakify.h" + +#include "absl/memory/memory.h" + +static UIImage* UIImageWithPixelBuffer(CVPixelBufferRef pixelBuffer) { + CFHolder cgImage; + ::mediapipe::Status status = CreateCGImageFromCVPixelBuffer(pixelBuffer, &cgImage); + if (!status.ok()) { + return nil; + } + UIImage *uiImage = [UIImage imageWithCGImage:*cgImage + scale:1.0 + orientation:UIImageOrientationUp]; + return uiImage; +} + +static void EnsureOutputDirFor(NSString *outputFile) { + NSFileManager *fileManager = [NSFileManager defaultManager]; + NSError *error = nil; + BOOL result = [fileManager createDirectoryAtPath:[outputFile stringByDeletingLastPathComponent] + withIntermediateDirectories:YES + attributes:nil + error:&error]; + // TODO: Log the error for clarity. The file-write will fail later + // but it would be nice to see this error. However, 'error' is still testing + // false and result is true even on an unwritable path-- not sure what's up. +} + +@implementation MPPGraphTestBase + +- (NSData*)testDataNamed:(NSString*)name extension:(NSString*)extension { + NSBundle* testBundle = [NSBundle bundleForClass:[self class]]; + NSURL* resourceURL = [testBundle URLForResource:name withExtension:extension]; + XCTAssertNotNil(resourceURL, + @"Unable to find data with name: %@. 
Did you add it to your resources?", name); + NSError* error; + NSData* data = [NSData dataWithContentsOfURL:resourceURL options:0 error:&error]; + XCTAssertNotNil(data, @"%@: %@", resourceURL.path, error); + return data; +} + +- (UIImage*)testImageNamed:(NSString*)name extension:(NSString*)extension { + return [self testImageNamed:name extension:extension subdirectory:nil]; +} + +- (UIImage*)testImageNamed:(NSString*)name + extension:(NSString*)extension + subdirectory:(NSString *)subdirectory { + // imageNamed does not work in our test bundle + NSBundle* testBundle = [NSBundle bundleForClass:[self class]]; + NSURL* imageURL = subdirectory ? + [testBundle URLForResource:name withExtension:extension subdirectory:subdirectory] : + [testBundle URLForResource:name withExtension:extension]; + XCTAssertNotNil(imageURL, + @"Unable to find image with name: %@. Did you add it to your resources?", name); + NSError* error; + NSData* imageData = [NSData dataWithContentsOfURL:imageURL options:0 error:&error]; + UIImage* image = [UIImage imageWithData:imageData]; + XCTAssertNotNil(image, @"%@: %@", imageURL.path, error); + return image; +} + +- (CVPixelBufferRef)runGraph:(MPPGraph*)graph + withInputPixelBuffers: + (const std::unordered_map>&)inputBuffers + inputPackets:(const std::map&)inputPackets + timestamp:(mediapipe::Timestamp)timestamp + outputStream:(const std::string&)outputStream + packetType:(MediaPipePacketType)inputPacketType { + __block CVPixelBufferRef output; + graph.delegate = self; + + // The XCTAssert macros contain references to self, which causes a retain cycle, + // since the block retains self and self retains the block. The cycle is broken + // at the end of this method, with _pixelBufferOutputBlock = nil, but Clang does + // not realize that and outputs a warning. WEAKIFY and STRONGIFY, though not + // strictly necessary, are used here to avoid the warning. + WEAKIFY(self); + if (!_pixelBufferOutputBlock) { + XCTestExpectation* outputReceived = [self expectationWithDescription:@"output received"]; + _pixelBufferOutputBlock = ^(MPPGraph* outputGraph, CVPixelBufferRef outputBuffer, + const std::string& outputStreamName) { + STRONGIFY(self); + XCTAssertEqualObjects(outputGraph, graph); + XCTAssertEqual(outputStreamName, outputStream); + CFRetain(outputBuffer); + output = outputBuffer; + [outputReceived fulfill]; + }; + } + + NSError *error; + BOOL success = [graph startWithError:&error]; + // Normally we continue after failures, but there is no sense in waiting for an + // output if the graph didn't even start. 
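+  // Temporarily disable continueAfterFailure so a failed start aborts this test immediately
+  // rather than timing out on the output expectation; the previous setting is restored below.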
+ BOOL savedContinue = self.continueAfterFailure; + self.continueAfterFailure = NO; + XCTAssert(success, @"%@", error.localizedDescription); + self.continueAfterFailure = savedContinue; + for (const auto& stream_buffer : inputBuffers) { + [graph sendPixelBuffer:*stream_buffer.second + intoStream:stream_buffer.first + packetType:inputPacketType + timestamp:timestamp]; + success = [graph closeInputStream:stream_buffer.first error:&error]; + XCTAssert(success, @"%@", error.localizedDescription); + } + for (const auto& stream_packet : inputPackets) { + [graph sendPacket:stream_packet.second + intoStream:stream_packet.first + error:&error]; + success = [graph closeInputStream:stream_packet.first error:&error]; + XCTAssert(success, @"%@", error.localizedDescription); + } + + XCTestExpectation* graphDone = [self expectationWithDescription:@"graph done"]; + dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{ + NSError *error; + BOOL success = [graph waitUntilDoneWithError:&error]; + XCTAssert(success, @"%@", error.localizedDescription); + [graphDone fulfill]; + }); + + [self waitForExpectationsWithTimeout:8.0 handler:NULL]; + _pixelBufferOutputBlock = nil; + return output; +} + +- (CVPixelBufferRef)runGraph:(MPPGraph*)graph + withPixelBuffer:(CVPixelBufferRef)inputBuffer + packetType:(MediaPipePacketType)inputPacketType { + return [self runGraph:graph + withInputPixelBuffers:{{"input_frames", MakeCFHolder(inputBuffer)}} + inputPackets:{} + timestamp:mediapipe::Timestamp(1) + outputStream:"output_frames" + packetType:inputPacketType]; +} + +- (CVPixelBufferRef)runGraph:(MPPGraph*)graph + withInputPixelBuffers: + (const std::unordered_map>&)inputBuffers + outputStream:(const std::string&)output + packetType:(MediaPipePacketType)inputPacketType { + return [self runGraph:graph + withInputPixelBuffers:inputBuffers + inputPackets:{} + timestamp:mediapipe::Timestamp(1) + outputStream:output + packetType:inputPacketType]; +} + +// By using a block to handle the delegate message, we can change the +// implementation for each test. +- (void)mediapipeGraph:(MPPGraph*)graph + didOutputPixelBuffer:(CVPixelBufferRef)imageBuffer + fromStream:(const std::string&)streamName { + _pixelBufferOutputBlock(graph, imageBuffer, streamName); +} + +- (void)mediapipeGraph:(MPPGraph*)graph + didOutputPacket:(const mediapipe::Packet&)packet + fromStream:(const std::string&)streamName { + _packetOutputBlock(graph, packet, streamName); +} + +- (BOOL)pixelBuffer:(CVPixelBufferRef)a isEqualTo:(CVPixelBufferRef)b { + return [self pixelBuffer:a isCloseTo:b maxLocalDifference:0 maxAverageDifference:0]; +} + +- (BOOL)pixelBuffer:(CVPixelBufferRef)a isCloseTo:(CVPixelBufferRef)b + maxLocalDifference:(int)maxLocalDiff maxAverageDifference:(float)maxAvgDiff { + size_t aBytesPerRow = CVPixelBufferGetBytesPerRow(a); + size_t aWidth = CVPixelBufferGetWidth(a); + size_t aHeight = CVPixelBufferGetHeight(a); + OSType aPixelFormat = CVPixelBufferGetPixelFormatType(a); + XCTAssertFalse(CVPixelBufferIsPlanar(a), @"planar buffers not supported"); + + size_t bBytesPerRow = CVPixelBufferGetBytesPerRow(b); + size_t bWidth = CVPixelBufferGetWidth(b); + size_t bHeight = CVPixelBufferGetHeight(b); + OSType bPixelFormat = CVPixelBufferGetPixelFormatType(b); + XCTAssertFalse(CVPixelBufferIsPlanar(b), @"planar buffers not supported"); + + if (aPixelFormat != bPixelFormat || + aWidth != bWidth || + aHeight != bHeight) return NO; + + size_t bytesPerPixel; // is there a generic way to get this from a pixel buffer? 
+ switch (aPixelFormat) { + case kCVPixelFormatType_32BGRA: + bytesPerPixel = 4; + break; + case kCVPixelFormatType_OneComponent8: + bytesPerPixel = 1; + break; + default: + XCTFail(@"unsupported pixel format"); + } + + CVReturn err; + err = CVPixelBufferLockBaseAddress(a, kCVPixelBufferLock_ReadOnly); + XCTAssertEqual(err, kCVReturnSuccess); + err = CVPixelBufferLockBaseAddress(b, kCVPixelBufferLock_ReadOnly); + XCTAssertEqual(err, kCVReturnSuccess); + const uint8_t* aData = static_cast(CVPixelBufferGetBaseAddress(a)); + const uint8_t* bData = static_cast(CVPixelBufferGetBaseAddress(b)); + + // Let's not assume identical bytesPerRow. Also, the padding may not be equal + // even if bytesPerRow match. + size_t usedRowWidth = aWidth * bytesPerPixel; + BOOL equal = YES; + float averageDiff = 0; + float count = 0; + for (int i = aHeight; i > 0 && equal; --i) { + if (maxLocalDiff == 0) { + // If we can, use memcmp for speed. + equal = memcmp(aData, bData, usedRowWidth) == 0; + } else { + for (int j = 0; j < usedRowWidth; j++) { + int diff = abs(aData[j] - bData[j]); + if (diff > maxLocalDiff) { + equal = NO; + break; + } + // We use Welford's algorithm for computing a sample mean. This has better + // numerical stability than the naive method, as noted in TAoCP. Not that it + // particularly matters here. + // Welford: http://www.jstor.org/stable/1266577 + // Knuth: The Art of Computer Programming Vol 2, section 4.2.2 + averageDiff += (diff - averageDiff) / ++count; + } + } + aData += aBytesPerRow; + bData += bBytesPerRow; + } + if (averageDiff > maxAvgDiff) equal = NO; + + err = CVPixelBufferUnlockBaseAddress(b, kCVPixelBufferLock_ReadOnly); + XCTAssertEqual(err, kCVReturnSuccess); + err = CVPixelBufferUnlockBaseAddress(a, kCVPixelBufferLock_ReadOnly); + XCTAssertEqual(err, kCVReturnSuccess); + return equal; +} + +- (CVPixelBufferRef)convertPixelBuffer:(CVPixelBufferRef)input + toPixelFormat:(OSType)pixelFormat { + size_t width = CVPixelBufferGetWidth(input); + size_t height = CVPixelBufferGetHeight(input); + CVPixelBufferRef output; + CVReturn status = CVPixelBufferCreate( + kCFAllocatorDefault, width, height, pixelFormat, + GetCVPixelBufferAttributesForGlCompatibility(), &output); + XCTAssertEqual(status, kCVReturnSuccess); + + status = CVPixelBufferLockBaseAddress(input, kCVPixelBufferLock_ReadOnly); + XCTAssertEqual(status, kCVReturnSuccess); + + status = CVPixelBufferLockBaseAddress(output, 0); + XCTAssertEqual(status, kCVReturnSuccess); + + status = vImageConvertCVPixelBuffers(input, output); + XCTAssertEqual(status, kvImageNoError); + + status = CVPixelBufferUnlockBaseAddress(output, 0); + XCTAssertEqual(status, kCVReturnSuccess); + + status = CVPixelBufferUnlockBaseAddress(input, kCVPixelBufferLock_ReadOnly); + XCTAssertEqual(status, kCVReturnSuccess); + + return output; +} + +- (CVPixelBufferRef)scaleBGRAPixelBuffer:(CVPixelBufferRef)input + toSize:(CGSize)size { + CVPixelBufferRef output; + CVReturn status = CVPixelBufferCreate( + kCFAllocatorDefault, size.width, size.height, kCVPixelFormatType_32BGRA, + GetCVPixelBufferAttributesForGlCompatibility(), &output); + XCTAssertEqual(status, kCVReturnSuccess); + + status = CVPixelBufferLockBaseAddress(input, kCVPixelBufferLock_ReadOnly); + XCTAssertEqual(status, kCVReturnSuccess); + + status = CVPixelBufferLockBaseAddress(output, 0); + XCTAssertEqual(status, kCVReturnSuccess); + + vImage_Buffer src = vImageForCVPixelBuffer(input); + vImage_Buffer dst = vImageForCVPixelBuffer(output); + status = vImageScale_ARGB8888(&src, &dst, NULL, 
kvImageNoFlags); + XCTAssertEqual(status, kvImageNoError); + + status = CVPixelBufferUnlockBaseAddress(output, 0); + XCTAssertEqual(status, kCVReturnSuccess); + + status = CVPixelBufferUnlockBaseAddress(input, kCVPixelBufferLock_ReadOnly); + XCTAssertEqual(status, kCVReturnSuccess); + + return output; +} + +- (CVPixelBufferRef)transformPixelBuffer:(CVPixelBufferRef)input + outputPixelFormat:(OSType)pixelFormat + transformation:(void(^)(CVPixelBufferRef input, + CVPixelBufferRef output))transformation { + size_t width = CVPixelBufferGetWidth(input); + size_t height = CVPixelBufferGetHeight(input); + CVPixelBufferRef output; + CVReturn status = CVPixelBufferCreate( + kCFAllocatorDefault, width, height, + pixelFormat, NULL, &output); + XCTAssertEqual(status, kCVReturnSuccess); + + status = CVPixelBufferLockBaseAddress(input, kCVPixelBufferLock_ReadOnly); + XCTAssertEqual(status, kCVReturnSuccess); + + status = CVPixelBufferLockBaseAddress(output, 0); + XCTAssertEqual(status, kCVReturnSuccess); + + transformation(input, output); + + status = CVPixelBufferUnlockBaseAddress(output, 0); + XCTAssertEqual(status, kCVReturnSuccess); + + status = CVPixelBufferUnlockBaseAddress(input, kCVPixelBufferLock_ReadOnly); + XCTAssertEqual(status, kCVReturnSuccess); + + return output; +} + +- (UIImage*)differenceOfImage:(UIImage*)inputA image:(UIImage*)inputB { + UIGraphicsBeginImageContextWithOptions(inputA.size, YES, 1.0); + CGRect imageRect = CGRectMake(0, 0, inputA.size.width, inputA.size.height); + + [inputA drawInRect:imageRect blendMode:kCGBlendModeNormal alpha:1.0]; + [inputB drawInRect:imageRect blendMode:kCGBlendModeDifference alpha:1.0]; + + UIImage *differenceImage = UIGraphicsGetImageFromCurrentImageContext(); + UIGraphicsEndImageContext(); + + return differenceImage; +} + +- (void)testGraph:(MPPGraph*)graph + input:(CVPixelBufferRef)inputBuffer + expectedOutput:(CVPixelBufferRef)expectedBuffer +{ + CVPixelBufferRef outputBuffer = [self runGraph:graph + withPixelBuffer:inputBuffer + packetType:MediaPipePacketPixelBuffer]; +#if DEBUG + // Xcode can display UIImage objects right in the debugger. It is handy to + // have these variables defined if the test fails. + UIImage* output = UIImageWithPixelBuffer(outputBuffer); + XCTAssertNotNil(output); + UIImage* expected = UIImageWithPixelBuffer(expectedBuffer); + XCTAssertNotNil(expected); + UIImage* diff = [self differenceOfImage:output image:expected]; + (void)diff; // Suppress unused variable warning. 
+#endif + XCTAssert([self pixelBuffer:outputBuffer isCloseTo:expectedBuffer + maxLocalDifference:INT_MAX maxAverageDifference:1]); + CFRelease(outputBuffer); +} + +- (void)testGraphConfig:(const mediapipe::CalculatorGraphConfig&)config + inputStreamsAndFiles:(NSDictionary*)inputs + outputStream:(NSString*)outputStream + expectedOutputFile:(NSString*)expectedPath { + [self testGraphConfig:config + inputStreamsAndFiles:inputs + inputStreamsAndPackets:{} + sidePackets:{} + timestamp:mediapipe::Timestamp(1) + outputStream:outputStream + expectedOutputFile:expectedPath + maxAverageDifference:1.f]; +} + +- (void)testGraphConfig:(const mediapipe::CalculatorGraphConfig&)config + inputStreamsAndFiles:(NSDictionary*)fileInputs + inputStreamsAndPackets:(const std::map&)packetInputs + sidePackets:(std::map)sidePackets + timestamp:(mediapipe::Timestamp)timestamp + outputStream:(NSString*)outputStream + expectedOutputFile:(NSString*)expectedPath + maxAverageDifference:(float)maxAverageDifference { + NSBundle* testBundle = [NSBundle bundleForClass:[self class]]; + chdir([testBundle.resourcePath fileSystemRepresentation]); + MPPGraph* graph = [[MPPGraph alloc] initWithGraphConfig:config]; + [graph addSidePackets:sidePackets]; + [graph addFrameOutputStream:outputStream.UTF8String + outputPacketType:MediaPipePacketPixelBuffer]; + + std::unordered_map> inputBuffers; + for (NSString* inputStream in fileInputs) { + UIImage* inputImage = [self testImageNamed:fileInputs[inputStream] extension:nil]; + XCTAssertNotNil(inputImage); + ::mediapipe::Status status = + CreateCVPixelBufferFromCGImage(inputImage.CGImage, &inputBuffers[inputStream.UTF8String]); + XCTAssert(status.ok()); + } + + UIImage* expectedImage = [self testImageNamed:expectedPath extension:nil]; + XCTAssertNotNil(expectedImage); + CFHolder expectedBuffer; + ::mediapipe::Status status = + CreateCVPixelBufferFromCGImage(expectedImage.CGImage, &expectedBuffer); + XCTAssert(status.ok()); + + CVPixelBufferRef outputBuffer = [self runGraph:graph + withInputPixelBuffers:inputBuffers + inputPackets:packetInputs + timestamp:timestamp + outputStream:outputStream.UTF8String + packetType:MediaPipePacketPixelBuffer]; + + UIImage* output = UIImageWithPixelBuffer(outputBuffer); + XCTAssertNotNil(output); + + UIImage* expected = UIImageWithPixelBuffer(*expectedBuffer); + XCTAssertNotNil(expected); + UIImage* diff = [self differenceOfImage:output image:expected]; + + XCTAssert([self pixelBuffer:outputBuffer isCloseTo:*expectedBuffer + maxLocalDifference:INT_MAX maxAverageDifference:maxAverageDifference]); + + CFRelease(outputBuffer); +} + +@end diff --git a/mediapipe/objc/MPPGraphTests.mm b/mediapipe/objc/MPPGraphTests.mm new file mode 100644 index 000000000..94c39108c --- /dev/null +++ b/mediapipe/objc/MPPGraphTests.mm @@ -0,0 +1,342 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
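As context for the tests that follow: a test case subclasses MPPGraphTestBase, builds a mediapipe::CalculatorGraphConfig, and passes the convenience overload a dictionary mapping input stream names to bundled image files, plus the expected reference image. A minimal sketch, using the testdata files added later in this change (the test class name is hypothetical):

  @interface SobelGraphTest : MPPGraphTestBase
  @end

  @implementation SobelGraphTest
  - (void)testSobelGraph {
    mediapipe::CalculatorGraphConfig config;
    // ... populate config, e.g. by parsing the bundled test_sobel.pbtxt ...
    [self testGraphConfig:config
      inputStreamsAndFiles:@{ @"input_video" : @"sergey.png" }
              outputStream:@"output_video"
        expectedOutputFile:@"sobel_reference.png"];
  }
  @end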
+ +#import +#import + +#include "absl/memory/memory.h" +#import "mediapipe/objc/MPPGraph.h" +#import "mediapipe/objc/MPPGraphTestBase.h" +#import "mediapipe/objc/NSError+util_status.h" +#import "mediapipe/objc/Weakify.h" +#import "mediapipe/objc/util.h" + +static const char* kExpectedError = "Expected error."; + +namespace mediapipe { + +class GrayscaleCalculator : public Calculator { + public: + static ::mediapipe::Status FillExpectations(const CalculatorOptions& options, + PacketTypeSet* inputs, PacketTypeSet* outputs, + PacketTypeSet* input_side_packets) { + inputs->Index(0).Set(); + outputs->Index(0).Set(); + return ::util::OkStatus(); + } + + ::mediapipe::Status Process() final { + const auto& input = Input()->Get(); + int w = input.Width(); + int h = input.Height(); + + auto output = absl::make_unique(ImageFormat::GRAY8, w, h); + + vImage_Buffer src = vImageForImageFrame(input); + vImage_Buffer dst = vImageForImageFrame(*output); + vImage_Error vErr = vImageRGBAToGray(&src, &dst); + NSCAssert(vErr == kvImageNoError, @"vImageRGBAToGray failed: %zd", vErr); + + Output()->Add(output.release(), InputTimestamp()); + return ::util::OkStatus(); + } +}; +REGISTER_CALCULATOR(GrayscaleCalculator); + +// For testing that video header is present. Open() will have a failure status +// if the video header is not present in the input stream. +class VideoHeaderCalculator : public Calculator { + public: + static ::mediapipe::Status FillExpectations(const CalculatorOptions& options, + PacketTypeSet* inputs, PacketTypeSet* outputs, + PacketTypeSet* input_side_packets) { + inputs->Index(0).Set(); + outputs->Index(0).Set(); + return ::util::OkStatus(); + } + + ::mediapipe::Status Open() final { + if (Input()->Header().IsEmpty()) { + return ::util::UnknownError("No video header present."); + } + return ::util::OkStatus(); + } + + ::mediapipe::Status Process() final { + Output()->AddPacket(Input()->Value()); + return ::util::OkStatus(); + } +}; +REGISTER_CALCULATOR(VideoHeaderCalculator); + +class ErrorCalculator : public Calculator { + public: + static ::mediapipe::Status FillExpectations(const CalculatorOptions& options, + PacketTypeSet* inputs, PacketTypeSet* outputs, + PacketTypeSet* input_side_packets) { + inputs->Index(0).SetAny(); + outputs->Index(0).SetSameAs(&inputs->Index(0)); + return ::util::OkStatus(); + } + + ::mediapipe::Status Process() final { + return ::mediapipe::Status(absl::StatusCode::kUnknown, kExpectedError); + } +}; +REGISTER_CALCULATOR(ErrorCalculator); + +} // namespace mediapipe + +@interface MPPGraphTests : MPPGraphTestBase{ + UIImage* _sourceImage; + MPPGraph* _graph; +} +@end + +@implementation MPPGraphTests + +- (void)setUp { + [super setUp]; + + _sourceImage = [self testImageNamed:@"googlelogo_color_272x92dp" extension:@"png"]; +} + +- (void)tearDown { + [super tearDown]; +} + +- (void)testPassThrough { + mediapipe::CalculatorGraphConfig config; + config.add_input_stream("input_frames"); + auto node = config.add_node(); + node->set_calculator("PassThroughCalculator"); + node->add_input_stream("input_frames"); + node->add_output_stream("output_frames"); + + _graph = [[MPPGraph alloc] initWithGraphConfig:config]; + [_graph addFrameOutputStream:"output_frames" outputPacketType:MediaPipePacketPixelBuffer]; + CFHolder inputBuffer; + ::mediapipe::Status status = CreateCVPixelBufferFromCGImage(_sourceImage.CGImage, &inputBuffer); + XCTAssert(status.ok()); + CVPixelBufferRef outputBuffer = [self runGraph:_graph + withPixelBuffer:*inputBuffer + 
packetType:MediaPipePacketPixelBuffer]; + XCTAssert([self pixelBuffer:outputBuffer isEqualTo:*inputBuffer]); +} + +- (UIImage*)grayImage:(UIImage*)inputImage { + UIGraphicsBeginImageContextWithOptions(inputImage.size, YES, 1.0); + CGRect imageRect = CGRectMake(0, 0, inputImage.size.width, inputImage.size.height); + + // Draw the image with the luminosity blend mode. + // On top of a white background, this will give a black and white image. + [inputImage drawInRect:imageRect blendMode:kCGBlendModeLuminosity alpha:1.0]; + + UIImage *filteredImage = UIGraphicsGetImageFromCurrentImageContext(); + UIGraphicsEndImageContext(); + + return filteredImage; +} + +- (void)testMultipleOutputs { + mediapipe::CalculatorGraphConfig config; + config.add_input_stream("input_frames"); + auto passThroughNode = config.add_node(); + passThroughNode->set_calculator("PassThroughCalculator"); + passThroughNode->add_input_stream("input_frames"); + passThroughNode->add_output_stream("pass_frames"); + auto grayNode = config.add_node(); + grayNode->set_calculator("GrayscaleCalculator"); + grayNode->add_input_stream("input_frames"); + grayNode->add_output_stream("gray_frames"); + + _graph = [[MPPGraph alloc] initWithGraphConfig:config]; + [_graph addFrameOutputStream:"pass_frames" outputPacketType:MediaPipePacketImageFrame]; + [_graph addFrameOutputStream:"gray_frames" outputPacketType:MediaPipePacketImageFrame]; + + CFHolder inputBuffer; + ::mediapipe::Status status = CreateCVPixelBufferFromCGImage(_sourceImage.CGImage, &inputBuffer); + XCTAssert(status.ok()); + + WEAKIFY(self); + XCTestExpectation* passFrameReceive = + [self expectationWithDescription:@"pass through output received"]; + XCTestExpectation* grayFrameReceive = + [self expectationWithDescription:@"grayscale output received"]; + _pixelBufferOutputBlock = ^(MPPGraph* outputGraph, CVPixelBufferRef outputBuffer, + const std::string& outputStreamName) { + STRONGIFY(self); + XCTAssertEqualObjects(outputGraph, self->_graph); + if (outputStreamName == "pass_frames") { + [passFrameReceive fulfill]; + } else if (outputStreamName == "gray_frames") { + [grayFrameReceive fulfill]; + } + }; + + [self runGraph:_graph withPixelBuffer:*inputBuffer packetType:MediaPipePacketImageFrame]; +} + +- (void)testGrayscaleOutput { + // When a calculator outputs a grayscale ImageFrame, it is returned to the + // application as a BGRA pixel buffer. To test it, let's feed a grayscale + // image into the graph and make sure it comes out unscathed. + UIImage* grayImage = [self grayImage:_sourceImage]; + + mediapipe::CalculatorGraphConfig config; + config.add_input_stream("input_frames"); + auto node = config.add_node(); + node->set_calculator("GrayscaleCalculator"); + node->add_input_stream("input_frames"); + node->add_output_stream("output_frames"); + + _graph = [[MPPGraph alloc] initWithGraphConfig:config]; + [_graph addFrameOutputStream:"output_frames" outputPacketType:MediaPipePacketImageFrame]; + CFHolder inputBuffer; + ::mediapipe::Status status = CreateCVPixelBufferFromCGImage(grayImage.CGImage, &inputBuffer); + XCTAssert(status.ok()); + CVPixelBufferRef outputBuffer = [self runGraph:_graph + withPixelBuffer:*inputBuffer + packetType:MediaPipePacketImageFrame]; + // We accept a small difference due to gamma correction and whatnot. 
+ XCTAssert([self pixelBuffer:outputBuffer isCloseTo:*inputBuffer + maxLocalDifference:5 maxAverageDifference:FLT_MAX]); +} + +- (void)testGraphError { + mediapipe::CalculatorGraphConfig config; + config.add_input_stream("input_frames"); + auto node = config.add_node(); + node->set_calculator("ErrorCalculator"); + node->add_input_stream("input_frames"); + node->add_output_stream("output_frames"); + CFHolder srcPixelBuffer; + ::mediapipe::Status status = + CreateCVPixelBufferFromCGImage(_sourceImage.CGImage, &srcPixelBuffer); + XCTAssert(status.ok()); + _graph = [[MPPGraph alloc] initWithGraphConfig:config]; + [_graph addFrameOutputStream:"output_frames" outputPacketType:MediaPipePacketImageFrame]; + _graph.delegate = self; + + XCTAssert([_graph startWithError:nil]); + [_graph sendPixelBuffer:*srcPixelBuffer + intoStream:"input_frames" + packetType:MediaPipePacketImageFrame]; + XCTAssert([_graph closeInputStream:"input_frames" error:nil]); + + __block NSError* error = nil; + XCTestExpectation* graphDone = [self expectationWithDescription:@"graph done"]; + dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{ + XCTAssertFalse([_graph waitUntilDoneWithError:&error]); + [graphDone fulfill]; + }); + + [self waitForExpectationsWithTimeout:3.0 handler:NULL]; + XCTAssertNotNil(error); + status = error.gus_status; + XCTAssertNotEqual(status.error_message().find(kExpectedError), std::string::npos, + @"Missing expected std::string '%s' from error messge '%s'", kExpectedError, + status.error_message().c_str()); +} + +- (void)testSetStreamHeader { + mediapipe::CalculatorGraphConfig config; + config.add_input_stream("input_frames"); + auto node = config.add_node(); + node->set_calculator("VideoHeaderCalculator"); + node->add_input_stream("input_frames"); + node->add_output_stream("output_frames"); + + _graph = [[MPPGraph alloc] initWithGraphConfig:config]; + [_graph addFrameOutputStream:"output_frames" outputPacketType:MediaPipePacketImageFrame]; + + // We're no longer using video headers, let's just use an int as the header. + auto header_packet = mediapipe::MakePacket(0xDEADBEEF); + [_graph setHeaderPacket:header_packet forStream:"input_frames"]; + + // Verify that Open() on calculator succeeded. + XCTAssert([_graph startWithError:nil]); + + // Tear down graph. 
+ XCTAssert([_graph closeInputStream:"input_frames" error:nil]); + XCTestExpectation* graphDone = [self expectationWithDescription:@"graph done"]; + dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{ + XCTAssert([_graph waitUntilDoneWithError:nil]); + [graphDone fulfill]; + }); + + [self waitForExpectationsWithTimeout:3.0 handler:NULL]; +} + +- (void)testGraphIsDeallocated { + mediapipe::CalculatorGraphConfig config; + config.add_input_stream("input_frames"); + auto node = config.add_node(); + node->set_calculator("PassThroughCalculator"); + node->add_input_stream("input_frames"); + node->add_output_stream("output_frames"); + + _graph = [[MPPGraph alloc] initWithGraphConfig:config]; + [_graph addFrameOutputStream:"output_frames" outputPacketType:MediaPipePacketPixelBuffer]; + CFHolder inputBuffer; + ::mediapipe::Status status = CreateCVPixelBufferFromCGImage(_sourceImage.CGImage, &inputBuffer); + XCTAssert(status.ok()); + CVPixelBufferRef outputBuffer = [self runGraph:_graph + withPixelBuffer:*inputBuffer + packetType:MediaPipePacketPixelBuffer]; + __weak MPPGraph* weakGraph = _graph; + _graph = nil; + XCTAssertNil(weakGraph); +} + +- (void)testRawPacketOutput { + mediapipe::CalculatorGraphConfig config; + config.add_input_stream("input_ints"); + auto node = config.add_node(); + node->set_calculator("PassThroughCalculator"); + node->add_input_stream("input_ints"); + node->add_output_stream("output_ints"); + + const int kTestValue = 10; + + _graph = [[MPPGraph alloc] initWithGraphConfig:config]; + [_graph addFrameOutputStream:"output_ints" outputPacketType:MediaPipePacketRaw]; + _graph.delegate = self; + + WEAKIFY(self); + XCTestExpectation* outputReceived = [self expectationWithDescription:@"output received"]; + _packetOutputBlock = ^(MPPGraph* outputGraph, const mediapipe::Packet& packet, + const std::string& outputStreamName) { + STRONGIFY(self); + XCTAssertEqualObjects(outputGraph, _graph); + XCTAssertEqual(outputStreamName, "output_ints"); + XCTAssertEqual(packet.Get(), kTestValue); + [outputReceived fulfill]; + }; + + XCTAssert([_graph startWithError:nil]); + XCTAssert([_graph sendPacket:mediapipe::MakePacket(kTestValue).At(mediapipe::Timestamp(1)) + intoStream:"input_ints" + error:nil]); + XCTAssert([_graph closeInputStream:"input_ints" error:nil]); + XCTestExpectation* graphDone = [self expectationWithDescription:@"graph done"]; + dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{ + XCTAssert([_graph waitUntilDoneWithError:nil]); + [graphDone fulfill]; + }); + + [self waitForExpectationsWithTimeout:3.0 handler:NULL]; +} + +@end diff --git a/mediapipe/objc/MPPInputSource.h b/mediapipe/objc/MPPInputSource.h new file mode 100644 index 000000000..aa86b4ffc --- /dev/null +++ b/mediapipe/objc/MPPInputSource.h @@ -0,0 +1,70 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import + +@class MPPInputSource; + +/// A delegate that can receive frames from a source. 
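/// An adopter implements whichever of the callbacks below it needs. As an
/// illustrative sketch (the class name is hypothetical):
///
///   @interface FrameConsumer : NSObject <MPPInputSourceDelegate>
///   @end
///
///   @implementation FrameConsumer
///   - (void)processVideoFrame:(CVPixelBufferRef)imageBuffer
///                   timestamp:(CMTime)timestamp
///                  fromSource:(MPPInputSource*)source {
///     // Sources such as MPPPlayerInputSource release the buffer once this
///     // callback returns, so CFRetain it if it must outlive the call.
///   }
///   @end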
+@protocol MPPInputSourceDelegate + +/// Provides the delegate with a new video frame. +@optional +- (void)processVideoFrame:(CVPixelBufferRef)imageBuffer + fromSource:(MPPInputSource*)source __deprecated; + +/// Provides the delegate with a new video frame. +@optional +- (void)processVideoFrame:(CVPixelBufferRef)imageBuffer + timestamp:(CMTime)timestamp + fromSource:(MPPInputSource*)source; + +// Provides the delegate with a new depth frame data +@optional +- (void)processDepthData:(AVDepthData*)depthData + timestamp:(CMTime)timestamp + fromSource:(MPPInputSource*)source; + +@optional +- (void)videoDidPlayToEnd:(CMTime)timestamp; + +@end + +/// Abstract class for a video source. +@interface MPPInputSource : NSObject + +/// The delegate that receives the frames. +@property(weak, nonatomic, readonly) id delegate; + +/// The dispatch queue on which to schedule the delegate callback. +@property(nonatomic, readonly) dispatch_queue_t delegateQueue; + +/// Whether the source is currently running. +@property(nonatomic, getter=isRunning, readonly) BOOL running; + +/// Sets the delegate and the queue on which its callback should be invoked. +- (void)setDelegate:(id)delegate queue:(dispatch_queue_t)queue; + +/// CoreVideo pixel format for the video frames. Defaults to +/// kCVPixelFormatType_32BGRA. +@property(nonatomic) OSType pixelFormatType; + +/// Starts the source. +- (void)start; + +/// Stops the source. +- (void)stop; + +@end diff --git a/mediapipe/objc/MPPInputSource.m b/mediapipe/objc/MPPInputSource.m new file mode 100644 index 000000000..92fab436a --- /dev/null +++ b/mediapipe/objc/MPPInputSource.m @@ -0,0 +1,39 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "MPPInputSource.h" + +#define ABSTRACT_METHOD \ + @throw [NSException exceptionWithName:NSInternalInconsistencyException \ + reason:[NSString stringWithFormat:@"You must override %@ in a subclass", \ + NSStringFromSelector(_cmd)] \ + userInfo:nil]; + +@implementation MPPInputSource + +- (void)setDelegate:(id)delegate + queue:(dispatch_queue_t)queue { + _delegate = delegate; + _delegateQueue = queue; +} + +- (void)start { + ABSTRACT_METHOD +} + +- (void)stop { + ABSTRACT_METHOD +} + +@end diff --git a/mediapipe/objc/MPPLayerRenderer.h b/mediapipe/objc/MPPLayerRenderer.h new file mode 100644 index 000000000..c7f2a4338 --- /dev/null +++ b/mediapipe/objc/MPPLayerRenderer.h @@ -0,0 +1,39 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import + +#import "mediapipe/objc/MPPGLViewRenderer.h" + +/// Renders frames in a Core Animation layer. +@interface MPPLayerRenderer : NSObject + +@property(nonatomic, readonly) CAEAGLLayer *layer; + +/// Updates the layer with a new pixel buffer. +- (void)renderPixelBuffer:(CVPixelBufferRef)pixelBuffer; + +/// Sets which way to rotate input frames before rendering them. +/// Default value is MediaPipeFrameRotationNone. +@property(nonatomic) MediaPipeFrameRotationMode frameRotationMode; + +/// Sets how to scale the frame within the layer. +/// Default value is MediaPipeFrameScaleScaleToFit. +@property(nonatomic) MediaPipeFrameScaleMode frameScaleMode; + +/// If YES, swap left and right. Useful for the front camera. +@property(nonatomic) BOOL mirrored; + +@end diff --git a/mediapipe/objc/MPPLayerRenderer.m b/mediapipe/objc/MPPLayerRenderer.m new file mode 100644 index 000000000..de482c397 --- /dev/null +++ b/mediapipe/objc/MPPLayerRenderer.m @@ -0,0 +1,100 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "MPPLayerRenderer.h" + +@implementation MPPLayerRenderer { + /// Used for rendering to a CAEAGLLayer. + MPPGLViewRenderer *_glRenderer; + + /// Used for rendering. + GLuint framebuffer_; + GLuint renderbuffer_; +} + +- (instancetype)init { + self = [super init]; + if (self) { + _layer = [[CAEAGLLayer alloc] init]; + // The default drawable properties are ok. + _layer.opaque = YES; + // Synchronizes presentation with CoreAnimation transactions. + // Avoids desync between GL and CA updates. + // TODO: should this be on by default? It's not on in GLKView, + // but if we add support for AVSampleBufferDisplayLayer we may want to + // make sure the behavior is similar. 
+ if ([_layer respondsToSelector:@selector(setPresentsWithTransaction:)]) { + _layer.presentsWithTransaction = YES; + } + _layer.contentsScale = [[UIScreen mainScreen] scale]; + _glRenderer = [[MPPGLViewRenderer alloc] init]; + } + return self; +} + +- (void)dealloc { +} + +- (void)setupFrameBuffer { + [EAGLContext setCurrentContext:_glRenderer.glContext]; + glDisable(GL_DEPTH_TEST); + glGenFramebuffers(1, &framebuffer_); + glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); + glGenRenderbuffers(1, &renderbuffer_); + glBindRenderbuffer(GL_RENDERBUFFER, renderbuffer_); + glFramebufferRenderbuffer(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_RENDERBUFFER, renderbuffer_); + [_glRenderer.glContext renderbufferStorage:GL_RENDERBUFFER fromDrawable:_layer]; + GLenum status = glCheckFramebufferStatus(GL_FRAMEBUFFER); + NSAssert(status == GL_FRAMEBUFFER_COMPLETE, + @"failed to make complete framebuffer object %x", status); +} + +- (void)renderPixelBuffer:(CVPixelBufferRef)pixelBuffer { + if (!framebuffer_) { + [self setupFrameBuffer]; + } + GLfloat drawWidth = _layer.bounds.size.width * _layer.contentsScale; + GLfloat drawHeight = _layer.bounds.size.height * _layer.contentsScale; + [EAGLContext setCurrentContext:_glRenderer.glContext]; + glViewport(0, 0, drawWidth, drawHeight); + [_glRenderer drawPixelBuffer:pixelBuffer width:drawWidth height:drawHeight]; + BOOL success = [_glRenderer.glContext presentRenderbuffer:GL_RENDERBUFFER]; + if (!success) NSLog(@"presentRenderbuffer failed"); +} + +- (MediaPipeFrameRotationMode)frameRotationMode { + return _glRenderer.frameRotationMode; +} + +- (void)setFrameRotationMode:(MediaPipeFrameRotationMode)frameRotationMode { + _glRenderer.frameRotationMode = frameRotationMode; +} + +- (MediaPipeFrameScaleMode)frameScaleMode { + return _glRenderer.frameScaleMode; +} + +- (void)setFrameScaleMode:(MediaPipeFrameScaleMode)frameScaleMode { + _glRenderer.frameScaleMode = frameScaleMode; +} + +- (BOOL)mirrored { + return _glRenderer.mirrored; +} + +- (void)setMirrored:(BOOL)mirrored { + _glRenderer.mirrored = mirrored; +} + +@end diff --git a/mediapipe/objc/MPPPlayerInputSource.h b/mediapipe/objc/MPPPlayerInputSource.h new file mode 100644 index 000000000..e1516abe9 --- /dev/null +++ b/mediapipe/objc/MPPPlayerInputSource.h @@ -0,0 +1,33 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "MPPInputSource.h" + +/// A source that outputs frames from a video, played in real time. +/// Not meant for batch processing of video. +@interface MPPPlayerInputSource : MPPInputSource + +/// Designated initializer. +- (instancetype)initWithAVAsset:(AVAsset*)video; + +/// Skip into video @c time from beginning (time 0), within error of +/- tolerance to closest time. +- (void)seekToTime:(CMTime)time tolerance:(CMTime)tolerance; + +/// Set time into video at which to end playback. +- (void)setPlaybackEndTime:(CMTime)time; + +/// Returns the current video's timestamp. 
+- (CMTime)currentPlayerTime; + +@end diff --git a/mediapipe/objc/MPPPlayerInputSource.m b/mediapipe/objc/MPPPlayerInputSource.m new file mode 100644 index 000000000..8aaaa7e29 --- /dev/null +++ b/mediapipe/objc/MPPPlayerInputSource.m @@ -0,0 +1,127 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "MPPPlayerInputSource.h" +#import "mediapipe/objc/MPPDisplayLinkWeakTarget.h" + +@implementation MPPPlayerInputSource { + AVAsset* _video; + AVPlayerItem* _videoItem; + AVPlayer* _videoPlayer; + AVPlayerItemVideoOutput* _videoOutput; + CADisplayLink* _videoDisplayLink; + MPPDisplayLinkWeakTarget* _displayLinkWeakTarget; + id _videoEndObserver; +} + +- (instancetype)initWithAVAsset:(AVAsset*)video { + self = [super init]; + if (self) { + _video = video; + _videoItem = [AVPlayerItem playerItemWithAsset:_video]; + // Necessary to ensure the video's preferred transform is respected. + _videoItem.videoComposition = [AVVideoComposition videoCompositionWithPropertiesOfAsset:_video]; + + _videoOutput = [[AVPlayerItemVideoOutput alloc] initWithPixelBufferAttributes:@{ + (id)kCVPixelBufferPixelFormatTypeKey : @(kCVPixelFormatType_32BGRA), + (id)kCVPixelBufferIOSurfacePropertiesKey : [NSDictionary dictionary] + }]; + _videoOutput.suppressesPlayerRendering = YES; + [_videoItem addOutput:_videoOutput]; + + _displayLinkWeakTarget = + [[MPPDisplayLinkWeakTarget alloc] initWithTarget:self selector:@selector(videoUpdate:)]; + + _videoDisplayLink = [CADisplayLink displayLinkWithTarget:_displayLinkWeakTarget + selector:@selector(displayLinkCallback:)]; + _videoDisplayLink.paused = YES; + [_videoDisplayLink addToRunLoop:[NSRunLoop mainRunLoop] forMode:NSRunLoopCommonModes]; + + _videoPlayer = [AVPlayer playerWithPlayerItem:_videoItem]; + _videoPlayer.actionAtItemEnd = AVPlayerActionAtItemEndNone; + NSNotificationCenter* center = [NSNotificationCenter defaultCenter]; + _videoEndObserver = [center addObserverForName:AVPlayerItemDidPlayToEndTimeNotification + object:_videoItem + queue:nil + usingBlock:^(NSNotification* note) { + [self playerItemDidPlayToEnd:note]; + }]; + } + return self; +} + +- (void)start { + [_videoPlayer play]; + _videoDisplayLink.paused = NO; +} + +- (void)stop { + _videoDisplayLink.paused = YES; + [_videoPlayer pause]; +} + +- (BOOL)isRunning { + return _videoPlayer.rate != 0.0; +} + +- (void)videoUpdate:(CADisplayLink*)sender { + CMTime timestamp = [_videoItem currentTime]; + if ([_videoOutput hasNewPixelBufferForItemTime:timestamp]) { + CVPixelBufferRef pixelBuffer = + [_videoOutput copyPixelBufferForItemTime:timestamp itemTimeForDisplay:nil]; + if (pixelBuffer) + dispatch_async(self.delegateQueue, ^{ + if ([self.delegate respondsToSelector:@selector(processVideoFrame:timestamp:fromSource:)]) { + [self.delegate processVideoFrame:pixelBuffer timestamp:timestamp fromSource:self]; + } else if ([self.delegate respondsToSelector:@selector(processVideoFrame:fromSource:)]) { + [self.delegate processVideoFrame:pixelBuffer fromSource:self]; 
+ } + CFRelease(pixelBuffer); + }); + } +} + +- (void)dealloc { + [[NSNotificationCenter defaultCenter] removeObserver:self]; + [_videoDisplayLink invalidate]; + _videoPlayer = nil; +} + +#pragma mark - NSNotificationCenter / observer + +- (void)playerItemDidPlayToEnd:(NSNotification*)notification { + CMTime timestamp = [_videoItem currentTime]; + dispatch_async(self.delegateQueue, ^{ + if ([self.delegate respondsToSelector:@selector(videoDidPlayToEnd:)]) { + [self.delegate videoDidPlayToEnd:timestamp]; + } else { + // Default to loop if no delegate handler set. + [_videoPlayer seekToTime:kCMTimeZero]; + } + }); +} + +- (void)seekToTime:(CMTime)time tolerance:(CMTime)tolerance { + [_videoPlayer seekToTime:time toleranceBefore:tolerance toleranceAfter:tolerance]; +} + +- (void)setPlaybackEndTime:(CMTime)time { + _videoPlayer.currentItem.forwardPlaybackEndTime = time; +} + +- (CMTime)currentPlayerTime { + return _videoPlayer.currentTime; +} + +@end diff --git a/mediapipe/objc/MPPTimestampConverter.h b/mediapipe/objc/MPPTimestampConverter.h new file mode 100644 index 000000000..44a6fc63f --- /dev/null +++ b/mediapipe/objc/MPPTimestampConverter.h @@ -0,0 +1,43 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/objc/util.h" + +/// Helps convert a CMTime to a MediaPipe timestamp. +@interface MPPTimestampConverter : NSObject + +/// The last timestamp returned by timestampForMediaTime:. +@property(nonatomic, readonly) mediapipe::Timestamp lastTimestamp; + +/// Initializer. +- (instancetype)init NS_DESIGNATED_INITIALIZER; + +/// Resets the object. After this method is called, we can return timestamps +/// that are lower than previously returned timestamps. +- (void)reset; + +/// Converts a CMTime to a MediaPipe timestamp. This ensures that MediaPipe +/// timestamps +/// are always increasing: if the provided CMTime has gone backwards (e.g. if +/// it's from a +/// looping video), we shift all timestamps from that point on to keep the +/// output increasing. +/// This state is erased when reset is called. +- (mediapipe::Timestamp)timestampForMediaTime:(CMTime)mediaTime; + +@end diff --git a/mediapipe/objc/MPPTimestampConverter.mm b/mediapipe/objc/MPPTimestampConverter.mm new file mode 100644 index 000000000..c7b66bd5a --- /dev/null +++ b/mediapipe/objc/MPPTimestampConverter.mm @@ -0,0 +1,50 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#import "MPPTimestampConverter.h" + +@implementation MPPTimestampConverter { + mediapipe::Timestamp _mediapipeTimestamp; + mediapipe::Timestamp _lastTimestamp; + mediapipe::TimestampDiff _timestampOffset; +} + +- (instancetype)init +{ + self = [super init]; + if (self) { + [self reset]; + } + return self; +} + +- (void)reset { + _mediapipeTimestamp = mediapipe::Timestamp::Min(); + _lastTimestamp = _mediapipeTimestamp; + _timestampOffset = 0; +} + +- (mediapipe::Timestamp)timestampForMediaTime:(CMTime)mediaTime { + float sampleSeconds = CMTIME_IS_VALID(mediaTime) ? CMTimeGetSeconds(mediaTime) : 0; + const int64 sampleUsec = sampleSeconds * mediapipe::Timestamp::kTimestampUnitsPerSecond; + _mediapipeTimestamp = mediapipe::Timestamp(sampleUsec) + _timestampOffset; + if (_mediapipeTimestamp <= _lastTimestamp) { + _timestampOffset = _timestampOffset + _lastTimestamp + 1 - _mediapipeTimestamp; + _mediapipeTimestamp = _lastTimestamp + 1; + } + _lastTimestamp = _mediapipeTimestamp; + return _mediapipeTimestamp; +} + +@end diff --git a/mediapipe/objc/NSError+util_status.h b/mediapipe/objc/NSError+util_status.h new file mode 100644 index 000000000..ebbc6fb6e --- /dev/null +++ b/mediapipe/objc/NSError+util_status.h @@ -0,0 +1,48 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "Foundation/Foundation.h" + +#include "mediapipe/framework/port/status.h" + +/// Error domain for ::mediapipe::Status errors. +extern NSString *const kGUSGoogleUtilStatusErrorDomain; + +/// Key for the ::mediapipe::Status wrapper in an NSError's user info dictionary. +extern NSString *const kGUSGoogleUtilStatusErrorKey; + +/// This just wraps ::mediapipe::Status into an Objective-C object. +@interface GUSUtilStatusWrapper : NSObject + +@property(nonatomic)::mediapipe::Status status; + ++ (instancetype)wrapStatus:(const ::mediapipe::Status &)status; + +@end + +/// This category adds methods for generating NSError objects from ::mediapipe::Status +/// objects, and vice versa. +@interface NSError (GUSGoogleUtilStatus) + +/// Generates an NSError representing a ::mediapipe::Status. Note that NSError always +/// represents an error, so this should not be called with ::mediapipe::Status::OK. ++ (NSError *)gus_errorWithStatus:(const ::mediapipe::Status &)status; + +/// Returns a ::mediapipe::Status object representing an NSError. If the NSError was +/// generated from a ::mediapipe::Status, the ::mediapipe::Status returned is identical to +/// the original. Otherwise, this returns a status with code ::util::error::UNKNOWN +/// and a message extracted from the NSError. 
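/// A typical round trip looks like this (illustrative):
///
///   ::mediapipe::Status failed(mediapipe::StatusCode::kUnknown, "something failed");
///   NSError *error = [NSError gus_errorWithStatus:failed];
///   ::mediapipe::Status restored = error.gus_status;  // same code and message as `failed`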
+@property(nonatomic, readonly)::mediapipe::Status gus_status; // NOLINT(identifier-naming) + +@end diff --git a/mediapipe/objc/NSError+util_status.mm b/mediapipe/objc/NSError+util_status.mm new file mode 100644 index 000000000..715be2674 --- /dev/null +++ b/mediapipe/objc/NSError+util_status.mm @@ -0,0 +1,69 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/objc/NSError+util_status.h" + +@implementation GUSUtilStatusWrapper + ++ (instancetype)wrapStatus:(const ::mediapipe::Status &)status { + return [[self alloc] initWithStatus:status]; +} + +- (instancetype)initWithStatus:(const ::mediapipe::Status &)status { + self = [super init]; + if (self) { + _status = status; + } + return self; +} + +- (NSString *)description { + return [NSString stringWithFormat:@"<%@: %p; status = %s>", + [self class], self, _status.error_message().c_str()]; +} + +@end + +@implementation NSError (GUSGoogleUtilStatus) + +NSString *const kGUSGoogleUtilStatusErrorDomain = @"GoogleUtilStatusErrorDomain"; +NSString *const kGUSGoogleUtilStatusErrorKey = @"GUSGoogleUtilStatusErrorKey"; + ++ (NSError *)gus_errorWithStatus:(const ::mediapipe::Status &)status { + NSDictionary *userInfo = @{ + NSLocalizedDescriptionKey : @(status.error_message().c_str()), + kGUSGoogleUtilStatusErrorKey : [GUSUtilStatusWrapper wrapStatus:status], + }; + NSError *error = [NSError errorWithDomain:kGUSGoogleUtilStatusErrorDomain + code:static_cast(status.code()) + userInfo:userInfo]; + return error; +} + +- (::mediapipe::Status)gus_status { + NSString *domain = self.domain; + if ([domain isEqual:kGUSGoogleUtilStatusErrorDomain]) { + GUSUtilStatusWrapper *wrapper = self.userInfo[kGUSGoogleUtilStatusErrorKey]; + if (wrapper) return wrapper.status; +#if 0 + // Unfortunately, util/task/posixerrorspace.h is not in portable status yet. + // TODO: fix that. + } else if ([domain isEqual:NSPOSIXErrorDomain]) { + return ::util::PosixErrorToStatus(self.code, self.localizedDescription.UTF8String); +#endif + } + return ::mediapipe::Status(mediapipe::StatusCode::kUnknown, self.localizedDescription.UTF8String); +} + +@end diff --git a/mediapipe/objc/Weakify.h b/mediapipe/objc/Weakify.h new file mode 100644 index 000000000..e507325e9 --- /dev/null +++ b/mediapipe/objc/Weakify.h @@ -0,0 +1,54 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// OBJC_LINTER + +// TODO: check license. This was forked from an internal header. 
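// As an illustration of the intended pattern (it is used this way in
// MPPGraphTestBase.mm above), a block that needs `self` without creating a
// retain cycle is written roughly as follows; the property and method names
// here are hypothetical:
//
//   WEAKIFY(self);
//   self.callback = ^{
//     STRONGIFY(self);
//     [self handleCallback];
//   };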
+ +/** + * __WEAKNAME_ is a private macro used to generate a local variable name related + * to the argument variable name. This generated local variable name is + * intentionally stable across multiple invocations. + */ +#define __WEAKNAME_(variable) variable##_weak_ + +/** + * WEAKIFY defines a new local variable that is a weak reference to the argument + * variable. + * + * This macro is generally used to capture a weak reference to be captured by an + * Objective-C block to avoid unintentionally extending an object's lifetime or + * avoid causing a retain cycle. + * + * The new local variable's name will be based on the name of the target + * variable and is stable across multiple invocations of WEAKIFY. In general, + * you should not need to invoke WEAKIFY multiple times on the same variable. + */ +#define WEAKIFY(variable) \ + __weak __typeof__(variable) __WEAKNAME_(variable) = (variable) + +/** + * STRONGIFY defines a new shadow local variable with the same name as the + * argument variable and initialize it with a resolved weak reference based on a + * weak reference created previously using the WEAKIFY macro. + * + * @note: + * This macro must be called within each block scope to prevent nested blocks + * from capturing a strong reference from an outer block. + */ +#define STRONGIFY(variable) \ + _Pragma("clang diagnostic push") \ + _Pragma("clang diagnostic ignored \"-Wshadow\"") \ + __strong __typeof__(variable) variable = \ + __WEAKNAME_(variable) _Pragma("clang diagnostic pop") diff --git a/mediapipe/objc/testdata/googlelogo_color_272x92dp.png b/mediapipe/objc/testdata/googlelogo_color_272x92dp.png new file mode 100644 index 000000000..ed8841bd0 Binary files /dev/null and b/mediapipe/objc/testdata/googlelogo_color_272x92dp.png differ diff --git a/mediapipe/objc/testdata/googlelogo_color_272x92dp_luminance.png b/mediapipe/objc/testdata/googlelogo_color_272x92dp_luminance.png new file mode 100644 index 000000000..f34a8eca3 Binary files /dev/null and b/mediapipe/objc/testdata/googlelogo_color_272x92dp_luminance.png differ diff --git a/mediapipe/objc/testdata/sergey.png b/mediapipe/objc/testdata/sergey.png new file mode 100644 index 000000000..44231bd99 Binary files /dev/null and b/mediapipe/objc/testdata/sergey.png differ diff --git a/mediapipe/objc/testdata/sobel_reference.png b/mediapipe/objc/testdata/sobel_reference.png new file mode 100644 index 000000000..8e262817d Binary files /dev/null and b/mediapipe/objc/testdata/sobel_reference.png differ diff --git a/mediapipe/objc/testdata/test_sobel.pbtxt b/mediapipe/objc/testdata/test_sobel.pbtxt new file mode 100644 index 000000000..eed090390 --- /dev/null +++ b/mediapipe/objc/testdata/test_sobel.pbtxt @@ -0,0 +1,12 @@ +input_stream: "input_video" + +node: { + calculator: "GlLuminanceCalculator" + input_stream: "input_video" + output_stream: "luma_video" +} +node: { + calculator: "GlSobelCalculator" + input_stream: "luma_video" + output_stream: "output_video" +} diff --git a/mediapipe/objc/util.cc b/mediapipe/objc/util.cc new file mode 100644 index 000000000..2316043ae --- /dev/null +++ b/mediapipe/objc/util.cc @@ -0,0 +1,573 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/objc/util.h" + +#include "absl/base/macros.h" +#include "absl/memory/memory.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/source_location.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_builder.h" + +namespace { + +// NOTE: you must release the colorspace returned by this function, unless +// it's null. +// Returns an invalid format (all fields 0) if the requested format is +// unsupported. +vImage_CGImageFormat vImageFormatForCVPixelFormat(OSType pixel_format) { + switch (pixel_format) { + case kCVPixelFormatType_OneComponent8: + return { + .bitsPerComponent = 8, + .bitsPerPixel = 8, + .colorSpace = CGColorSpaceCreateDeviceGray(), + .bitmapInfo = kCGImageAlphaNone | kCGBitmapByteOrderDefault, + }; + + case kCVPixelFormatType_32BGRA: + return { + .bitsPerComponent = 8, + .bitsPerPixel = 32, + .colorSpace = NULL, + .bitmapInfo = kCGImageAlphaFirst | kCGBitmapByteOrder32Little, + }; + + case kCVPixelFormatType_32RGBA: + return { + .bitsPerComponent = 8, + .bitsPerPixel = 32, + .colorSpace = NULL, + .bitmapInfo = kCGImageAlphaLast | kCGBitmapByteOrderDefault, + }; + + default: + return {}; + } +} + +CGColorSpaceRef CreateConversionCGColorSpaceForPixelFormat( + OSType pixel_format) { + // According to vImage documentation, YUV formats require the RGB colorspace + // in which the RGB conversion should be interpreted. sRGB is suggested. + // We cannot just pass sRGB all the time, though, since it breaks with + // monochrome. + switch (pixel_format) { + case kCVPixelFormatType_422YpCbCr8: + case kCVPixelFormatType_4444YpCbCrA8: + case kCVPixelFormatType_4444YpCbCrA8R: + case kCVPixelFormatType_4444AYpCbCr8: + case kCVPixelFormatType_4444AYpCbCr16: + case kCVPixelFormatType_444YpCbCr8: + case kCVPixelFormatType_422YpCbCr16: + case kCVPixelFormatType_422YpCbCr10: + case kCVPixelFormatType_444YpCbCr10: + case kCVPixelFormatType_420YpCbCr8Planar: + case kCVPixelFormatType_420YpCbCr8PlanarFullRange: + case kCVPixelFormatType_422YpCbCr_4A_8BiPlanar: + case kCVPixelFormatType_420YpCbCr8BiPlanarVideoRange: + case kCVPixelFormatType_420YpCbCr8BiPlanarFullRange: + case kCVPixelFormatType_422YpCbCr8_yuvs: + case kCVPixelFormatType_422YpCbCr8FullRange: + return CGColorSpaceCreateWithName(kCGColorSpaceSRGB); + + default: + return NULL; + } +} + +vImageConverterRef vImageConverterForCVPixelFormats(OSType src_pixel_format, + OSType dst_pixel_format, + vImage_Error* error) { + static CGFloat default_background[3] = {1.0, 1.0, 1.0}; + vImageConverterRef converter = NULL; + + vImage_CGImageFormat src_cg_format = + vImageFormatForCVPixelFormat(src_pixel_format); + vImage_CGImageFormat dst_cg_format = + vImageFormatForCVPixelFormat(dst_pixel_format); + + // Use CV format functions if available (introduced in iOS 8). + // Weak-linked symbols are NULL when not available. 
+ if (&vImageConverter_CreateForCGToCVImageFormat != NULL) { + // Strangely, there is no function to convert between two + // vImageCVImageFormat, so one side has to use a vImage_CGImageFormat + // that we have to find ourselves. + if (src_cg_format.bitsPerComponent > 0) { + // We can handle source using a CGImageFormat. + // TODO: check the final alpha hint parameter + CGColorSpaceRef cv_color_space = + CreateConversionCGColorSpaceForPixelFormat(dst_pixel_format); + vImageCVImageFormatRef dst_cv_format = vImageCVImageFormat_Create( + dst_pixel_format, kvImage_ARGBToYpCbCrMatrix_ITU_R_709_2, + kCVImageBufferChromaLocation_Center, cv_color_space, 1); + CGColorSpaceRelease(cv_color_space); + + converter = vImageConverter_CreateForCGToCVImageFormat( + &src_cg_format, dst_cv_format, default_background, + kvImagePrintDiagnosticsToConsole, error); + vImageCVImageFormat_Release(dst_cv_format); + } else if (dst_cg_format.bitsPerComponent > 0) { + // We can use a CGImageFormat for the destination. + CGColorSpaceRef cv_color_space = + CreateConversionCGColorSpaceForPixelFormat(src_pixel_format); + vImageCVImageFormatRef src_cv_format = vImageCVImageFormat_Create( + src_pixel_format, kvImage_ARGBToYpCbCrMatrix_ITU_R_709_2, + kCVImageBufferChromaLocation_Center, cv_color_space, 1); + CGColorSpaceRelease(cv_color_space); + + converter = vImageConverter_CreateForCVToCGImageFormat( + src_cv_format, &dst_cg_format, default_background, + kvImagePrintDiagnosticsToConsole, error); + vImageCVImageFormat_Release(src_cv_format); + } + } + + if (!converter) { + // Try a CG to CG conversion. + if (src_cg_format.bitsPerComponent > 0 && + dst_cg_format.bitsPerComponent > 0) { + converter = vImageConverter_CreateWithCGImageFormat( + &src_cg_format, &dst_cg_format, default_background, kvImageNoFlags, + error); + } + } + + CGColorSpaceRelease(src_cg_format.colorSpace); + CGColorSpaceRelease(dst_cg_format.colorSpace); + return converter; +} + +} // unnamed namespace + +vImage_Error vImageGrayToBGRA(const vImage_Buffer* src, vImage_Buffer* dst) { + static vImageConverterRef converter = NULL; + if (!converter) { + converter = vImageConverterForCVPixelFormats( + kCVPixelFormatType_OneComponent8, kCVPixelFormatType_32BGRA, NULL); + } + return vImageConvert_AnyToAny(converter, src, dst, NULL, kvImageNoFlags); +} + +vImage_Error vImageBGRAToGray(const vImage_Buffer* src, vImage_Buffer* dst) { + static vImageConverterRef converter = NULL; + if (!converter) { + converter = vImageConverterForCVPixelFormats( + kCVPixelFormatType_32BGRA, kCVPixelFormatType_OneComponent8, NULL); + } + return vImageConvert_AnyToAny(converter, src, dst, NULL, kvImageNoFlags); +} + +vImage_Error vImageRGBAToGray(const vImage_Buffer* src, vImage_Buffer* dst) { + static vImageConverterRef converter = NULL; + if (!converter) { + converter = vImageConverterForCVPixelFormats( + kCVPixelFormatType_32RGBA, kCVPixelFormatType_OneComponent8, NULL); + } + return vImageConvert_AnyToAny(converter, src, dst, NULL, kvImageNoFlags); +} + +vImage_Error vImageConvertCVPixelBuffers(CVPixelBufferRef src, + CVPixelBufferRef dst) { + // CGColorSpaceRef srgb_color_space = + // CGColorSpaceCreateWithName(kCGColorSpaceSRGB); + vImage_Error error; + vImageConverterRef converter = vImageConverterForCVPixelFormats( + CVPixelBufferGetPixelFormatType(src), + CVPixelBufferGetPixelFormatType(dst), &error); + if (!converter) { + return error; + } + + int src_buffer_count = vImageConverter_GetNumberOfSourceBuffers(converter); + int dst_buffer_count = + 
vImageConverter_GetNumberOfDestinationBuffers(converter); + vImage_Buffer buffers[8]; + if (src_buffer_count + dst_buffer_count > ABSL_ARRAYSIZE(buffers)) { + vImageConverter_Release(converter); + return kvImageMemoryAllocationError; + } + vImage_Buffer* src_bufs = buffers; + vImage_Buffer* dst_bufs = buffers + src_buffer_count; + + // vImageBuffer_InitForCopyToCVPixelBuffer can be used only if the converter + // was created by vImageConverter_CreateForCGToCVImageFormat. + // vImageBuffer_InitForCopyFromCVPixelBuffer can be used only if the converter + // was created by vImageConverter_CreateForCVToCGImageFormat. + // There does not seem to be a way to ask the converter for its type; however, + // it is documented that all multi-planar formats are CV formats, so we use + // these calls when there are multiple buffers. + + if (src_buffer_count > 1) { + error = vImageBuffer_InitForCopyFromCVPixelBuffer( + src_bufs, converter, src, + kvImageNoAllocate | kvImagePrintDiagnosticsToConsole); + if (error != kvImageNoError) { + vImageConverter_Release(converter); + return error; + } + } else { + *src_bufs = vImageForCVPixelBuffer(src); + } + + if (dst_buffer_count > 1) { + error = vImageBuffer_InitForCopyToCVPixelBuffer( + dst_bufs, converter, dst, + kvImageNoAllocate | kvImagePrintDiagnosticsToConsole); + if (error != kvImageNoError) { + vImageConverter_Release(converter); + return error; + } + } else { + *dst_bufs = vImageForCVPixelBuffer(dst); + } + + error = vImageConvert_AnyToAny(converter, src_bufs, dst_bufs, NULL, + kvImageNoFlags); + vImageConverter_Release(converter); + return error; +} + +void ReleaseMediaPipePacket(void* refcon, const void* base_address) { + auto packet = (mediapipe::Packet*)refcon; + delete packet; +} + +CVPixelBufferRef CreateCVPixelBufferForImageFramePacket( + const mediapipe::Packet& image_frame_packet) { + CFHolder buffer; + ::mediapipe::Status status = + CreateCVPixelBufferForImageFramePacket(image_frame_packet, &buffer); + MEDIAPIPE_CHECK_OK(status) << "Failed to create CVPixelBufferRef"; + return (CVPixelBufferRef)CFRetain(*buffer); +} + +::mediapipe::Status CreateCVPixelBufferForImageFramePacket( + const mediapipe::Packet& image_frame_packet, + CFHolder* out_buffer) { + return CreateCVPixelBufferForImageFramePacket(image_frame_packet, false, + out_buffer); +} + +::mediapipe::Status CreateCVPixelBufferForImageFramePacket( + const mediapipe::Packet& image_frame_packet, bool can_overwrite, + CFHolder* out_buffer) { + if (!out_buffer) { + return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "out_buffer cannot be NULL"; + } + CFHolder pixel_buffer; + + auto packet_copy = absl::make_unique(image_frame_packet); + const auto& frame = packet_copy->Get(); + void* frame_data = + const_cast(reinterpret_cast(frame.PixelData())); + + mediapipe::ImageFormat::Format image_format = frame.Format(); + OSType pixel_format = 0; + CVReturn status; + switch (image_format) { + case mediapipe::ImageFormat::SRGBA: { + pixel_format = kCVPixelFormatType_32BGRA; + // Swap R and B channels. 
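The "swap R and B channels" steps in this file rely on vImagePermuteChannels_ARGB8888, where entry i of the permute map names the source channel copied into destination channel i, so {2, 1, 0, 3} exchanges the first and third byte of every pixel. The same mapping shown on a single pixel as a self-contained sketch (SwapRAndB is an illustrative name):

#include <array>
#include <cstdint>

// Applies the {2, 1, 0, 3} permute map to one 4-byte pixel: R,G,B,A -> B,G,R,A
// (and, being its own inverse, B,G,R,A -> R,G,B,A).
std::array<uint8_t, 4> SwapRAndB(const std::array<uint8_t, 4>& src) {
  const uint8_t permute_map[4] = {2, 1, 0, 3};
  std::array<uint8_t, 4> dst;
  for (int i = 0; i < 4; ++i) dst[i] = src[permute_map[i]];
  return dst;
}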
+ vImage_Buffer v_image = vImageForImageFrame(frame); + vImage_Buffer v_dest; + if (can_overwrite) { + v_dest = v_image; + } else { + CVPixelBufferRef pixel_buffer_temp; + status = CVPixelBufferCreate( + kCFAllocatorDefault, frame.Width(), frame.Height(), pixel_format, + GetCVPixelBufferAttributesForGlCompatibility(), &pixel_buffer_temp); + RET_CHECK(status == kCVReturnSuccess) + << "CVPixelBufferCreate failed: " << status; + pixel_buffer.adopt(pixel_buffer_temp); + status = CVPixelBufferLockBaseAddress(*pixel_buffer, + kCVPixelBufferLock_ReadOnly); + RET_CHECK(status == kCVReturnSuccess) + << "CVPixelBufferLockBaseAddress failed: " << status; + v_dest = vImageForCVPixelBuffer(*pixel_buffer); + } + const uint8_t permute_map[4] = {2, 1, 0, 3}; + vImage_Error vError = vImagePermuteChannels_ARGB8888( + &v_image, &v_dest, permute_map, kvImageNoFlags); + RET_CHECK(vError == kvImageNoError) + << "vImagePermuteChannels failed: " << vError; + } break; + + case mediapipe::ImageFormat::GRAY8: + pixel_format = kCVPixelFormatType_OneComponent8; + break; + + default: + return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) + << "unsupported ImageFrame format: " << image_format; + } + + if (*pixel_buffer) { + status = CVPixelBufferUnlockBaseAddress(*pixel_buffer, + kCVPixelBufferLock_ReadOnly); + RET_CHECK(status == kCVReturnSuccess) + << "CVPixelBufferUnlockBaseAddress failed: " << status; + } else { + CVPixelBufferRef pixel_buffer_temp; + status = CVPixelBufferCreateWithBytes( + NULL, frame.Width(), frame.Height(), pixel_format, frame_data, + frame.WidthStep(), ReleaseMediaPipePacket, packet_copy.release(), + GetCVPixelBufferAttributesForGlCompatibility(), &pixel_buffer_temp); + RET_CHECK(status == kCVReturnSuccess) + << "failed to create pixel buffer: " << status; + pixel_buffer.adopt(pixel_buffer_temp); + } + + *out_buffer = pixel_buffer; + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CreateCGImageFromCVPixelBuffer( + CVPixelBufferRef image_buffer, CFHolder* image) { + CVReturn status = + CVPixelBufferLockBaseAddress(image_buffer, kCVPixelBufferLock_ReadOnly); + RET_CHECK(status == kCVReturnSuccess) + << "CVPixelBufferLockBaseAddress failed: " << status; + + void* base_address = CVPixelBufferGetBaseAddress(image_buffer); + size_t bytes_per_row = CVPixelBufferGetBytesPerRow(image_buffer); + size_t width = CVPixelBufferGetWidth(image_buffer); + size_t height = CVPixelBufferGetHeight(image_buffer); + OSType pixel_format = CVPixelBufferGetPixelFormatType(image_buffer); + + CGColorSpaceRef color_space = nullptr; + uint32_t bitmap_info = 0; + switch (pixel_format) { + case kCVPixelFormatType_32BGRA: + color_space = CGColorSpaceCreateDeviceRGB(); + bitmap_info = + kCGBitmapByteOrder32Little | kCGImageAlphaPremultipliedFirst; + break; + + case kCVPixelFormatType_OneComponent8: + color_space = CGColorSpaceCreateDeviceGray(); + bitmap_info = kCGImageAlphaNone; + break; + + default: + LOG(FATAL) << "Unsupported pixelFormat " << pixel_format; + break; + } + + CGContextRef src_context = CGBitmapContextCreate( + base_address, width, height, 8, bytes_per_row, color_space, bitmap_info); + + CGImageRef quartz_image = CGBitmapContextCreateImage(src_context); + CGContextRelease(src_context); + CGColorSpaceRelease(color_space); + CFHolder cg_image_holder = MakeCFHolderAdopting(quartz_image); + status = + CVPixelBufferUnlockBaseAddress(image_buffer, kCVPixelBufferLock_ReadOnly); + RET_CHECK(status == kCVReturnSuccess) + << "CVPixelBufferUnlockBaseAddress failed: " << status; + + *image = cg_image_holder; + 
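The two CGImage conversions compose naturally with the CVPixelBuffer helpers declared in util.h. A hedged round-trip sketch; RoundTripThroughCGImage is an illustrative name and the CFHolder template arguments are inferred from the surrounding declarations rather than stated in them:

#include "mediapipe/framework/port/status.h"
#include "mediapipe/objc/CFHolder.h"
#include "mediapipe/objc/util.h"

// Wraps a BGRA CVPixelBuffer in a CGImage copy, then copies it back into a
// fresh kCVPixelFormatType_32BGRA buffer.
::mediapipe::Status RoundTripThroughCGImage(CVPixelBufferRef bgra_buffer) {
  CFHolder<CGImageRef> image;
  RETURN_IF_ERROR(CreateCGImageFromCVPixelBuffer(bgra_buffer, &image));

  CFHolder<CVPixelBufferRef> copy;
  RETURN_IF_ERROR(CreateCVPixelBufferFromCGImage(*image, &copy));
  // *copy now refers to a new BGRA buffer holding the same pixels.
  return ::mediapipe::OkStatus();
}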
return ::mediapipe::OkStatus(); +} + +::mediapipe::Status CreateCVPixelBufferFromCGImage( + CGImageRef image, CFHolder* out_buffer) { + size_t width = CGImageGetWidth(image); + size_t height = CGImageGetHeight(image); + CFHolder pixel_buffer; + + CVPixelBufferRef pixel_buffer_temp; + CVReturn status = CVPixelBufferCreate( + kCFAllocatorDefault, width, height, kCVPixelFormatType_32BGRA, + GetCVPixelBufferAttributesForGlCompatibility(), &pixel_buffer_temp); + RET_CHECK(status == kCVReturnSuccess) + << "failed to create pixel buffer: " << status; + pixel_buffer.adopt(pixel_buffer_temp); + + status = CVPixelBufferLockBaseAddress(*pixel_buffer, 0); + RET_CHECK(status == kCVReturnSuccess) + << "CVPixelBufferLockBaseAddress failed: " << status; + + void* base_address = CVPixelBufferGetBaseAddress(*pixel_buffer); + CGColorSpaceRef color_space = CGColorSpaceCreateDeviceRGB(); + size_t bytes_per_row = CVPixelBufferGetBytesPerRow(*pixel_buffer); + CGContextRef context = CGBitmapContextCreate( + base_address, width, height, 8, bytes_per_row, color_space, + kCGBitmapByteOrder32Little | kCGImageAlphaPremultipliedFirst); + CGRect rect = CGRectMake(0, 0, width, height); + CGContextClearRect(context, rect); + CGContextDrawImage(context, rect, image); + + CGContextRelease(context); + CGColorSpaceRelease(color_space); + status = CVPixelBufferUnlockBaseAddress(*pixel_buffer, 0); + RET_CHECK(status == kCVReturnSuccess) + << "CVPixelBufferUnlockBaseAddress failed: " << status; + + *out_buffer = pixel_buffer; + return ::mediapipe::OkStatus(); +} + +std::unique_ptr CreateImageFrameForCVPixelBuffer( + CVPixelBufferRef image_buffer) { + return CreateImageFrameForCVPixelBuffer(image_buffer, false, false); +} + +std::unique_ptr CreateImageFrameForCVPixelBuffer( + CVPixelBufferRef image_buffer, bool can_overwrite, bool bgr_as_rgb) { + CVReturn status = + CVPixelBufferLockBaseAddress(image_buffer, kCVPixelBufferLock_ReadOnly); + CHECK_EQ(status, kCVReturnSuccess) + << "CVPixelBufferLockBaseAddress failed: " << status; + + void* base_address = CVPixelBufferGetBaseAddress(image_buffer); + size_t bytes_per_row = CVPixelBufferGetBytesPerRow(image_buffer); + size_t width = CVPixelBufferGetWidth(image_buffer); + size_t height = CVPixelBufferGetHeight(image_buffer); + std::unique_ptr frame; + + CVPixelBufferRetain(image_buffer); + + OSType pixel_format = CVPixelBufferGetPixelFormatType(image_buffer); + mediapipe::ImageFormat::Format image_format = mediapipe::ImageFormat::UNKNOWN; + switch (pixel_format) { + case kCVPixelFormatType_32BGRA: { + image_format = mediapipe::ImageFormat::SRGBA; + if (!bgr_as_rgb) { + // Swap R and B channels. 
+ vImage_Buffer v_image = vImageForCVPixelBuffer(image_buffer); + vImage_Buffer v_dest; + if (can_overwrite) { + v_dest = v_image; + } else { + frame = absl::make_unique(image_format, width, + height); + v_dest = vImageForImageFrame(*frame); + } + const uint8_t permute_map[4] = {2, 1, 0, 3}; + vImage_Error vError = vImagePermuteChannels_ARGB8888( + &v_image, &v_dest, permute_map, kvImageNoFlags); + CHECK(vError == kvImageNoError) + << "vImagePermuteChannels failed: " << vError; + } + } break; + + case kCVPixelFormatType_32RGBA: + image_format = mediapipe::ImageFormat::SRGBA; + break; + + case kCVPixelFormatType_24RGB: + image_format = mediapipe::ImageFormat::SRGB; + break; + + case kCVPixelFormatType_OneComponent8: + image_format = mediapipe::ImageFormat::GRAY8; + break; + + default: { + char format_str[5] = {static_cast(pixel_format >> 24 & 0xFF), + static_cast(pixel_format >> 16 & 0xFF), + static_cast(pixel_format >> 8 & 0xFF), + static_cast(pixel_format & 0xFF), 0}; + LOG(FATAL) << "unsupported pixel format: " << format_str; + } break; + } + + if (frame) { + // We have already created a new frame that does not reference the buffer. + status = CVPixelBufferUnlockBaseAddress(image_buffer, + kCVPixelBufferLock_ReadOnly); + CHECK_EQ(status, kCVReturnSuccess) + << "CVPixelBufferUnlockBaseAddress failed: " << status; + CVPixelBufferRelease(image_buffer); + } else { + frame = absl::make_unique( + image_format, width, height, bytes_per_row, + reinterpret_cast(base_address), [image_buffer](uint8* x) { + CVPixelBufferUnlockBaseAddress(image_buffer, + kCVPixelBufferLock_ReadOnly); + CVPixelBufferRelease(image_buffer); + }); + } + return frame; +} + +CFDictionaryRef GetCVPixelBufferAttributesForGlCompatibility() { + static CFDictionaryRef attrs = NULL; + if (!attrs) { + CFDictionaryRef empty_dict = CFDictionaryCreate( + kCFAllocatorDefault, NULL, NULL, 0, &kCFTypeDictionaryKeyCallBacks, + &kCFTypeDictionaryValueCallBacks); + // To ensure compatibility with CVOpenGLESTextureCache, these attributes + // should be present. 
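Callers are expected to pass the dictionary assembled below directly to CVPixelBufferCreate so the resulting buffer is IOSurface-backed and usable with the GL/GLES texture caches, mirroring the CVPixelBufferCreate calls earlier in this file. A minimal caller-side sketch (CreateGlCompatibleBgraBuffer is an illustrative name):

#include "mediapipe/objc/util.h"

// Creates a GL/GLES-compatible BGRA pixel buffer using the shared attributes.
CVPixelBufferRef CreateGlCompatibleBgraBuffer(size_t width, size_t height) {
  CVPixelBufferRef buffer = nullptr;
  CVReturn err = CVPixelBufferCreate(
      kCFAllocatorDefault, width, height, kCVPixelFormatType_32BGRA,
      GetCVPixelBufferAttributesForGlCompatibility(), &buffer);
  return err == kCVReturnSuccess ? buffer : nullptr;  // Caller releases.
}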
+ const void* keys[] = { + kCVPixelBufferIOSurfacePropertiesKey, +#if TARGET_OS_OSX + kCVPixelFormatOpenGLCompatibility, +#else + kCVPixelFormatOpenGLESCompatibility, +#endif // TARGET_OS_OSX + }; + const void* values[] = {empty_dict, kCFBooleanTrue}; + attrs = CFDictionaryCreate( + kCFAllocatorDefault, keys, values, ABSL_ARRAYSIZE(values), + &kCFTypeDictionaryKeyCallBacks, &kCFTypeDictionaryValueCallBacks); + CFRelease(empty_dict); + } + return attrs; +} + +void DumpCVPixelFormats() { + CFArrayRef pf_descs = + CVPixelFormatDescriptionArrayCreateWithAllPixelFormatTypes( + kCFAllocatorDefault); + CFIndex count = CFArrayGetCount(pf_descs); + CFIndex i; + + printf("Core Video Supported Pixel Format Types:\n"); + + for (i = 0; i < count; i++) { + CFNumberRef pf_num = (CFNumberRef)CFArrayGetValueAtIndex(pf_descs, i); + if (!pf_num) continue; + + int pf; + CFNumberGetValue(pf_num, kCFNumberSInt32Type, &pf); + + if (pf <= 0x28) { + printf("\nCore Video Pixel Format Type: %d\n", pf); + } else { + printf("\nCore Video Pixel Format Type (FourCC): %c%c%c%c\n", + static_cast(pf >> 24), static_cast(pf >> 16), + static_cast(pf >> 8), static_cast(pf)); + } + + CFDictionaryRef desc = CVPixelFormatDescriptionCreateWithPixelFormatType( + kCFAllocatorDefault, pf); + CFDictionaryApplyFunction( + desc, + [](const void* key, const void* value, void* context) { + CFStringRef s = CFStringCreateWithFormat( + kCFAllocatorDefault, nullptr, CFSTR(" %@: %@"), key, value); + CFShow(s); + CFRelease(s); + }, + nullptr); + CFRelease(desc); + } + CFRelease(pf_descs); +} diff --git a/mediapipe/objc/util.h b/mediapipe/objc/util.h new file mode 100644 index 000000000..2bb07500b --- /dev/null +++ b/mediapipe/objc/util.h @@ -0,0 +1,122 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_OBJC_UTIL_H_ +#define MEDIAPIPE_OBJC_UTIL_H_ + +#import +#import +#import + +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/objc/CFHolder.h" + +// TODO: namespace and/or prefix these. Split up the file. + +/// Returns a vImage_Buffer describing the data inside the pixel_buffer. +/// NOTE: the pixel buffer's base address must have been locked before this +/// call, and it must stay locked as long as the vImage_Buffer is in use. +inline vImage_Buffer vImageForCVPixelBuffer(CVPixelBufferRef pixel_buffer) { + return {.data = CVPixelBufferGetBaseAddress(pixel_buffer), + .width = CVPixelBufferGetWidth(pixel_buffer), + .height = CVPixelBufferGetHeight(pixel_buffer), + .rowBytes = CVPixelBufferGetBytesPerRow(pixel_buffer)}; +} + +/// Returns a vImage_Buffer describing the data inside the ImageFrame. 
+inline vImage_Buffer vImageForImageFrame(const mediapipe::ImageFrame& frame) {
+  return {.data = (void*)frame.PixelData(),
+          .width = static_cast<vImagePixelCount>(frame.Width()),
+          .height = static_cast<vImagePixelCount>(frame.Height()),
+          .rowBytes = static_cast<size_t>(frame.WidthStep())};
+}
+
+/// Converts a grayscale image without alpha to BGRA format.
+vImage_Error vImageGrayToBGRA(const vImage_Buffer* src, vImage_Buffer* dst);
+
+/// Converts a BGRA image to grayscale without alpha.
+vImage_Error vImageBGRAToGray(const vImage_Buffer* src, vImage_Buffer* dst);
+
+/// Converts an RGBA image to grayscale without alpha.
+vImage_Error vImageRGBAToGray(const vImage_Buffer* src, vImage_Buffer* dst);
+
+/// Copy from a pixel buffer to another, converting pixel format.
+/// Both pixel buffers should be locked before calling this.
+vImage_Error vImageConvertCVPixelBuffers(CVPixelBufferRef src,
+                                         CVPixelBufferRef dst);
+
+/// When storing a mediapipe::Packet* in a CVPixelBuffer's refcon, this can be
+/// used as a CVPixelBufferReleaseBytesCallback. This keeps the packet's data
+/// alive while the CVPixelBuffer is in use.
+void ReleaseMediaPipePacket(void* refcon, const void* base_address);
+
+/// Returns a CVPixelBuffer that references the data inside the packet. The
+/// packet must contain an ImageFrame. The CVPixelBuffer manages a copy of
+/// the packet, so that the packet's data is kept alive as long as the
+/// CVPixelBuffer is in use.
+///
+/// For formats which are not supported by both image types, it may be
+/// necessary to convert the data. This is done by creating a new buffer.
+/// If the optional can_overwrite parameter is true, the old buffer may be
+/// modified instead.
+::mediapipe::Status CreateCVPixelBufferForImageFramePacket(
+    const mediapipe::Packet& image_frame_packet,
+    CFHolder<CVPixelBufferRef>* out_buffer);
+::mediapipe::Status CreateCVPixelBufferForImageFramePacket(
+    const mediapipe::Packet& image_frame_packet, bool can_overwrite,
+    CFHolder<CVPixelBufferRef>* out_buffer);
+
+/// Creates a CVPixelBuffer with a copy of the contents of the CGImage.
+::mediapipe::Status CreateCVPixelBufferFromCGImage(
+    CGImageRef image, CFHolder<CVPixelBufferRef>* out_buffer);
+
+/// Creates a CGImage with a copy of the contents of the CVPixelBuffer.
+::mediapipe::Status CreateCGImageFromCVPixelBuffer(
+    CVPixelBufferRef image_buffer, CFHolder<CGImageRef>* image);
+
+/// DEPRECATED: use the version that returns ::mediapipe::Status instead.
+CVPixelBufferRef CreateCVPixelBufferForImageFramePacket(
+    const mediapipe::Packet& image_frame_packet);
+
+/// Returns an ImageFrame that references the data inside the pixel_buffer.
+/// The ImageFrame retains the pixel_buffer and keeps it locked as long as it
+/// is in use.
+///
+/// For formats which are not supported by both image types, it may be
+/// necessary to convert the data. This is done by creating a new buffer.
+/// If the optional can_overwrite parameter is true, the old buffer may be
+/// modified instead.
+///
+/// ImageFrame does not have a format for BGRA data, so we normally swap the
+/// channels to produce RGBA. But many graphs do not care about the order of
+/// the channels; in those cases, setting the optional bgr_as_rgb parameter
+/// to true skips the channel swap.
+std::unique_ptr CreateImageFrameForCVPixelBuffer( + CVPixelBufferRef pixel_buffer); +std::unique_ptr CreateImageFrameForCVPixelBuffer( + CVPixelBufferRef pixel_buffer, bool can_overwrite, bool bgr_as_rgb); + +/// Returns a CFDictionaryRef that can be passed to CVPixelBufferCreate to +/// ensure that the pixel buffer is compatible with OpenGL ES and with +/// CVOpenGLESTextureCacheCreateTextureFromImage. +/// The returned object is persistent and should not be released. +CFDictionaryRef GetCVPixelBufferAttributesForGlCompatibility(); + +/// Prints debug information about available CoreVideo pixel formats. +/// This prints to stdout. +void DumpCVPixelFormats(); + +#endif // MEDIAPIPE_OBJC_UTIL_H_ diff --git a/mediapipe/util/BUILD b/mediapipe/util/BUILD index 9a7d0fd76..c0602d4e9 100644 --- a/mediapipe/util/BUILD +++ b/mediapipe/util/BUILD @@ -19,6 +19,15 @@ package(default_visibility = ["//visibility:private"]) load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +proto_library( + name = "audio_decoder_proto", + srcs = ["audio_decoder.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_proto", + ], +) + proto_library( name = "color_proto", srcs = ["color.proto"], @@ -47,6 +56,40 @@ mediapipe_cc_proto_library( deps = [":render_data_proto"], ) +mediapipe_cc_proto_library( + name = "audio_decoder_cc_proto", + srcs = ["audio_decoder.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//mediapipe:__subpackages__"], + deps = [":audio_decoder_proto"], +) + +cc_library( + name = "audio_decoder", + srcs = ["audio_decoder.cc"], + hdrs = ["audio_decoder.h"], + visibility = ["//mediapipe:__subpackages__"], + deps = [ + ":audio_decoder_cc_proto", + "//mediapipe/framework:packet", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/deps:cleanup", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework/port:commandlineflags", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:map_util", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:status_util", + "//third_party:libffmpeg", + "@com_google_absl//absl/base:endian", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@eigen_archive//:eigen", + ], +) + cc_library( name = "cpu_util", srcs = ["cpu_util.cc"], @@ -130,9 +173,19 @@ cc_library( srcs = select({ "//conditions:default": ["resource_util.cc"], "//mediapipe:android": ["resource_util_android.cc"], + "//mediapipe:apple": ["resource_util_apple.cc"], + "//mediapipe:macos": ["resource_util.cc"], }), hdrs = ["resource_util.h"], # We use Objective-C++ on iOS. 
+ copts = select({ + "//conditions:default": [], + "//mediapipe:apple": [ + "-std=c++11", + "-ObjC++", + ], + "//mediapipe:macos": [], + }), visibility = [ "//mediapipe/framework:mediapipe_internal", ], @@ -149,6 +202,10 @@ cc_library( "//mediapipe/util/android:asset_manager_util", "//mediapipe/util/android/file/base", ], + "//mediapipe:apple": [], + "//mediapipe:macos": [ + "//mediapipe/framework/port:file_helpers", + ], }), ) @@ -179,3 +236,59 @@ cc_library( ], }), ) + +cc_library( + name = "time_series_util", + srcs = ["time_series_util.cc"], + hdrs = ["time_series_util.h"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "time_series_test_util", + testonly = 1, + hdrs = ["time_series_test_util.h"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":time_series_util", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/deps:message_matchers", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/strings", + "@eigen_archive//:eigen", + ], +) + +cc_test( + name = "time_series_util_test", + size = "small", + srcs = ["time_series_util_test.cc"], + deps = [ + ":time_series_util", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/deps:message_matchers", + "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework/port:gtest_main", + "@eigen_archive//:eigen", + ], +) diff --git a/mediapipe/util/annotation_renderer.cc b/mediapipe/util/annotation_renderer.cc index 51cabf09f..85ca2e6b7 100644 --- a/mediapipe/util/annotation_renderer.cc +++ b/mediapipe/util/annotation_renderer.cc @@ -31,6 +31,7 @@ using FilledRectangle = RenderAnnotation::FilledRectangle; using FilledRoundedRectangle = RenderAnnotation::FilledRoundedRectangle; using Point = RenderAnnotation::Point; using Line = RenderAnnotation::Line; +using GradientLine = RenderAnnotation::GradientLine; using Oval = RenderAnnotation::Oval; using Rectangle = RenderAnnotation::Rectangle; using RoundedRectangle = RenderAnnotation::RoundedRectangle; @@ -56,10 +57,28 @@ bool NormalizedtoPixelCoordinates(double normalized_x, double normalized_y, } cv::Scalar MediapipeColorToOpenCVColor(const Color& color) { - return cv::Scalar(static_cast(color.r() * 255.0f), - static_cast(color.g() * 255.0f), - static_cast(color.b() * 255.0f)); + return cv::Scalar(color.r(), color.g(), color.b()); } + +cv::RotatedRect RectangleToOpenCVRotatedRect(int left, int top, int right, + int bottom, double rotation) { + return cv::RotatedRect( + cv::Point2f((left + right) / 2.f, (top + bottom) / 2.f), + cv::Size2f(right - left, bottom - top), rotation / M_PI * 180.f); +} + +void cv_line2(cv::Mat& img, const cv::Point& start, const cv::Point& end, + const cv::Scalar& color1, const cv::Scalar& color2, + int thickness) { + cv::LineIterator iter(img, start, end, /*cv::LINE_4=*/4); + for (int i = 0; i < iter.count; 
i++, iter++) { + const double alpha = static_cast(i) / iter.count; + const cv::Scalar new_color(color1 * (1.0 - alpha) + color2 * alpha); + const cv::Rect rect(iter.pos(), cv::Size(thickness, thickness)); + cv::rectangle(img, rect, new_color, /*cv::FILLED=*/-1, /*cv::LINE_4=*/4); + } +} + } // namespace void AnnotationRenderer::RenderDataOnImage(const RenderData& render_data) { @@ -83,6 +102,8 @@ void AnnotationRenderer::RenderDataOnImage(const RenderData& render_data) { DrawPoint(annotation); } else if (annotation.data_case() == RenderAnnotation::kLine) { DrawLine(annotation); + } else if (annotation.data_case() == RenderAnnotation::kGradientLine) { + DrawGradientLine(annotation); } else if (annotation.data_case() == RenderAnnotation::kArrow) { DrawArrow(annotation); } else { @@ -126,10 +147,23 @@ void AnnotationRenderer::DrawRectangle(const RenderAnnotation& annotation) { bottom = static_cast(rectangle.bottom()); } - cv::Rect rect(left, top, right - left, bottom - top); const cv::Scalar color = MediapipeColorToOpenCVColor(annotation.color()); const int thickness = annotation.thickness(); - cv::rectangle(mat_image_, rect, color, thickness); + + if (rectangle.rotation() != 0.0) { + const auto& rect = RectangleToOpenCVRotatedRect(left, top, right, bottom, + rectangle.rotation()); + const int kNumVertices = 4; + cv::Point2f vertices[kNumVertices]; + rect.points(vertices); + for (int i = 0; i < kNumVertices; i++) { + cv::line(mat_image_, vertices[i], vertices[(i + 1) % kNumVertices], color, + thickness); + } + } else { + cv::Rect rect(left, top, right - left, bottom - top); + cv::rectangle(mat_image_, rect, color, thickness); + } } void AnnotationRenderer::DrawFilledRectangle( @@ -154,9 +188,24 @@ void AnnotationRenderer::DrawFilledRectangle( bottom = static_cast(rectangle.bottom()); } - cv::Rect rect(left, top, right - left, bottom - top); const cv::Scalar color = MediapipeColorToOpenCVColor(annotation.color()); - cv::rectangle(mat_image_, rect, color, -1); + + if (rectangle.rotation() != 0.0) { + const auto& rect = RectangleToOpenCVRotatedRect(left, top, right, bottom, + rectangle.rotation()); + const int kNumVertices = 4; + cv::Point2f vertices2f[kNumVertices]; + rect.points(vertices2f); + // Convert cv::Point2f[] to cv::Point[]. 
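cv_line2 blends the two endpoint colors linearly along the segment: at fraction alpha of the way from start to end the drawn color is (1 - alpha) * color1 + alpha * color2. The same per-channel blend without the OpenCV plumbing (BlendGradientColor is an illustrative name):

#include <array>

// Two-tone blend used for gradient lines: alpha = 0 gives color1, alpha = 1
// gives color2, values in between interpolate per channel.
std::array<double, 3> BlendGradientColor(const std::array<double, 3>& color1,
                                         const std::array<double, 3>& color2,
                                         double alpha) {
  std::array<double, 3> out;
  for (int c = 0; c < 3; ++c) {
    out[c] = color1[c] * (1.0 - alpha) + color2[c] * alpha;
  }
  return out;
}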
+ cv::Point vertices[kNumVertices]; + for (int i = 0; i < kNumVertices; ++i) { + vertices[i] = vertices2f[i]; + } + cv::fillConvexPoly(mat_image_, vertices, kNumVertices, color); + } else { + cv::Rect rect(left, top, right - left, bottom - top); + cv::rectangle(mat_image_, rect, color, -1); + } } void AnnotationRenderer::DrawRoundedRectangle( @@ -405,6 +454,33 @@ void AnnotationRenderer::DrawLine(const RenderAnnotation& annotation) { cv::line(mat_image_, start, end, color, thickness); } +void AnnotationRenderer::DrawGradientLine(const RenderAnnotation& annotation) { + int x_start = -1; + int y_start = -1; + int x_end = -1; + int y_end = -1; + + const auto& line = annotation.gradient_line(); + if (line.normalized()) { + CHECK(NormalizedtoPixelCoordinates(line.x_start(), line.y_start(), + image_width_, image_height_, &x_start, + &y_start)); + CHECK(NormalizedtoPixelCoordinates(line.x_end(), line.y_end(), image_width_, + image_height_, &x_end, &y_end)); + } else { + x_start = static_cast(line.x_start()); + y_start = static_cast(line.y_start()); + x_end = static_cast(line.x_end()); + y_end = static_cast(line.y_end()); + } + const cv::Point start(x_start, y_start); + const cv::Point end(x_end, y_end); + const int thickness = annotation.thickness(); + const cv::Scalar color1 = MediapipeColorToOpenCVColor(line.color1()); + const cv::Scalar color2 = MediapipeColorToOpenCVColor(line.color2()); + cv_line2(mat_image_, start, end, color1, color2, thickness); +} + void AnnotationRenderer::DrawText(const RenderAnnotation& annotation) { int left = -1; int baseline = -1; diff --git a/mediapipe/util/annotation_renderer.h b/mediapipe/util/annotation_renderer.h index 21cd3cda3..60dfb8594 100644 --- a/mediapipe/util/annotation_renderer.h +++ b/mediapipe/util/annotation_renderer.h @@ -93,6 +93,9 @@ class AnnotationRenderer { // Draws a line segment on the image as described in the annotation. void DrawLine(const RenderAnnotation& annotation); + // Draws a 2-tone line segment on the image as described in the annotation. + void DrawGradientLine(const RenderAnnotation& annotation); + // Draws a text on the image as described in the annotation. void DrawText(const RenderAnnotation& annotation); diff --git a/mediapipe/util/audio_decoder.cc b/mediapipe/util/audio_decoder.cc new file mode 100644 index 000000000..345f0c40d --- /dev/null +++ b/mediapipe/util/audio_decoder.cc @@ -0,0 +1,829 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
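Before the implementation, a sketch of how the new AudioDecoder is intended to be driven end to end. DecodeFirstAudioStream is an illustrative name, and the loop's termination is hedged on GetData returning a non-OK status (tool::StatusStop()) once every stream has been flushed:

#include <string>

#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/audio_decoder.h"
#include "mediapipe/util/audio_decoder.pb.h"

// Decodes the first audio stream of a media file into Matrix packets.
::mediapipe::Status DecodeFirstAudioStream(const std::string& path) {
  mediapipe::AudioDecoderOptions options;
  options.add_audio_stream()->set_stream_index(0);

  mediapipe::AudioDecoder decoder;
  RETURN_IF_ERROR(decoder.Initialize(path, options));

  mediapipe::TimeSeriesHeader header;
  RETURN_IF_ERROR(decoder.FillAudioHeader(options.audio_stream(0), &header));
  LOG(INFO) << "Decoding " << header.num_channels() << " channels at "
            << header.sample_rate() << " Hz";

  int options_index = -1;
  mediapipe::Packet data;
  while (decoder.GetData(&options_index, &data).ok()) {
    if (data.IsEmpty()) continue;
    const mediapipe::Matrix& samples = data.Get<mediapipe::Matrix>();
    VLOG(1) << samples.rows() << " channels x " << samples.cols()
            << " samples at " << data.Timestamp();
  }
  return decoder.Close();  // GetData closes on EOF; Close() tolerates repeats.
}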
+ +#include "mediapipe/util/audio_decoder.h" + +#include +#include // required by avutil.h +#include +#include +#include + +#include "Eigen/Core" +#include "absl/base/internal/endian.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/substitute.h" +#include "absl/time/time.h" +#include "mediapipe/framework/deps/cleanup.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/port/map_util.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/tool/status_util.h" + +extern "C" { +#include "libavcodec/avcodec.h" +#include "libavformat/avformat.h" +#include "libavutil/avutil.h" +#include "libavutil/mem.h" +#include "libavutil/samplefmt.h" +} + +DEFINE_int64(media_decoder_allowed_audio_gap_merge, 5, + "The time gap forwards or backwards in the audio to ignore. " + "Timestamps in media files are restricted by the container format " + "and stream codec and are invariably not accurate to exact sample " + "numbers. If the discrepency between time based on counting " + "samples and based on the container timestamps grows beyond this " + "value it will be reset to the value in the audio stream and " + "counting based on samples will resume."); + +namespace mediapipe { + +// MPEG PTS max value + 1, used to correct for PTS rollover. Unit is PTS ticks. +const int64 kMpegPtsEpoch = 1LL << 33; +// Maximum PTS change between frames. Larger changes are considered to indicate +// the MPEG PTS has rolled over. Unit is PTS ticks. +const int64 kMpegPtsMaxDelta = kMpegPtsEpoch / 2; + +// BasePacketProcessor +namespace { + +inline std::string TimestampToString(int64 timestamp) { + if (timestamp == AV_NOPTS_VALUE) { + return "NOPTS"; + } + return absl::StrCat(timestamp); +} + +float Uint32ToFloat(uint32 raw_value) { + float value; + memcpy(&value, &raw_value, 4); + return value; +} + +std::string AvErrorToString(int error) { + if (error >= 0) { + return absl::StrCat("Not an error (", error, ")"); + } + + switch (error) { + case AVERROR(EINVAL): + return "AVERROR(EINVAL) - unknown error or invalid data"; + case AVERROR(EIO): + return "AVERROR(EIO) - I/O error"; + case AVERROR(EDOM): + return "AVERROR(EDOM) - Number syntax expected in filename."; + case AVERROR(ENOMEM): + return "AVERROR(ENOMEM) - not enough memory"; + case AVERROR(EILSEQ): + return "AVERROR(EILSEQ) - unknown format"; + case AVERROR(ENOSYS): + return "AVERROR(ENOSYS) - Operation not supported."; + case AVERROR(ENOENT): + return "AVERROR(ENOENT) - No such file or directory."; + case AVERROR(EPIPE): + return "AVERROR(EPIPE) - End of file."; + case AVERROR_BSF_NOT_FOUND: + return "AVERROR_BSF_NOT_FOUND - Bitstream filter not found."; + case AVERROR_BUG: + return "AVERROR_BUG - Internal bug, should not have happened."; + case AVERROR_BUG2: + return "AVERROR_BUG2 - Internal bug, should not have happened."; + case AVERROR_BUFFER_TOO_SMALL: + return "AVERROR_BUFFER_TOO_SMALL - Buffer too small."; + case AVERROR_DECODER_NOT_FOUND: + return "AVERROR_DECODER_NOT_FOUND - Decoder not found."; + case AVERROR_DEMUXER_NOT_FOUND: + return "AVERROR_DEMUXER_NOT_FOUND - Demuxer not found."; + case AVERROR_ENCODER_NOT_FOUND: + return "AVERROR_ENCODER_NOT_FOUND - Encoder not found."; + case AVERROR_EOF: + return "AVERROR_EOF - End of file."; + case AVERROR_EXIT: + return "AVERROR_EXIT - Immediate exit was requested."; + case AVERROR_EXTERNAL: + return "AVERROR_EXTERNAL - Generic error in an external library."; + case 
AVERROR_FILTER_NOT_FOUND: + return "AVERROR_FILTER_NOT_FOUND - Filter not found."; + case AVERROR_INVALIDDATA: + return "AVERROR_INVALIDDATA - Invalid data found when processing input."; + case AVERROR_MUXER_NOT_FOUND: + return "AVERROR_MUXER_NOT_FOUND - Muxer not found."; + case AVERROR_OPTION_NOT_FOUND: + return "AVERROR_OPTION_NOT_FOUND - Option not found."; + case AVERROR_PATCHWELCOME: + return "AVERROR_PATCHWELCOME - Not yet implemented in FFmpeg, " + "patches welcome."; + case AVERROR_PROTOCOL_NOT_FOUND: + return "AVERROR_PROTOCOL_NOT_FOUND - Protocol not found."; + case AVERROR_STREAM_NOT_FOUND: + return "AVERROR_STREAM_NOT_FOUND - Stream not found."; + case AVERROR_EXPERIMENTAL: + return "AVERROR_EXPERIMENTAL - Requested feature is flagged " + "experimental."; + case AVERROR_INPUT_CHANGED: + return "AVERROR_INPUT_CHANGED - Input changed between calls."; + case AVERROR_OUTPUT_CHANGED: + return "AVERROR_OUTPUT_CHANGED - Output changed between calls."; + default: + // FALLTHRU + {} + } + + char buf[AV_ERROR_MAX_STRING_SIZE]; + if (av_strerror(error, buf, sizeof(buf)) == 0) { + return absl::StrCat("AVERROR(", error, ") - ", buf); + } + + return absl::StrCat("Unknown AVERROR number ", error); +} + +// Send a packet to the decoder. +mediapipe::Status SendPacket(const AVPacket& packet, + AVCodecContext* avcodec_ctx) { + const int error = avcodec_send_packet(avcodec_ctx, &packet); + if (error != 0 && error != AVERROR_EOF) { + // Not consider AVERROR_EOF as an error because it can happen when more + // than 1 flush packet is sent. + return UnknownError(absl::StrCat("Failed to send packet: error=", error, + " (", AvErrorToString(error), + "). Packet size: ", packet.size)); + } + return mediapipe::OkStatus(); +} + +// Receive a decoded frame from the decoder. +mediapipe::Status ReceiveFrame(AVCodecContext* avcodec_ctx, AVFrame* frame, + bool* received) { + const int error = avcodec_receive_frame(avcodec_ctx, frame); + *received = error == 0; + if (error != 0 && error != AVERROR_EOF && error != AVERROR(EAGAIN)) { + // Not consider AVERROR_EOF as an error because it can happen after a + // flush, and AVERROR(EAGAIN) because it happens when there's no (more) + // frame to be received from this packet. + return UnknownError(absl::StrCat(" Failed to receive frame: error=", error, + " (", AvErrorToString(error), ").")); + } + return mediapipe::OkStatus(); +} + +mediapipe::Status LogStatus(const mediapipe::Status& status, + const AVCodecContext& avcodec_ctx, + const AVPacket& packet, + bool always_return_ok_status) { + if (status.ok()) { + return status; + } + + VLOG(3) << "Failed to process packet:" + << " media_type:" + << (avcodec_ctx.codec_type == AVMEDIA_TYPE_VIDEO ? "video" : "audio") + << " codec_id:" << avcodec_ctx.codec_id + << " frame_number:" << avcodec_ctx.frame_number + << " pts:" << TimestampToString(packet.pts) + << " dts:" << TimestampToString(packet.dts) << " size:" << packet.size + << (packet.flags & AV_PKT_FLAG_KEY ? " Key Frame." 
: ""); + + if (always_return_ok_status) { + LOG(WARNING) << status.message(); + return mediapipe::OkStatus(); + } else { + return status; + } +} + +class AVPacketDeleter { + public: + void operator()(void* x) const { + AVPacket* packet = static_cast(x); + if (packet) { + av_free_packet(packet); + delete packet; + } + } +}; + +} // namespace + +BasePacketProcessor::BasePacketProcessor() + : decoded_frame_(av_frame_alloc()), + source_time_base_{0, 0}, + output_time_base_{1, 1000000}, + source_frame_rate_{0, 0} {} + +BasePacketProcessor::~BasePacketProcessor() { Close(); } + +bool BasePacketProcessor::HasData() { return !buffer_.empty(); } + +mediapipe::Status BasePacketProcessor::GetData(Packet* packet) { + CHECK(packet); + CHECK(!buffer_.empty()); + *packet = buffer_.front(); + buffer_.pop_front(); + + return mediapipe::OkStatus(); +} + +mediapipe::Status BasePacketProcessor::Flush() { + int64 last_num_frames_processed; + do { + std::unique_ptr av_packet(new AVPacket()); + av_init_packet(av_packet.get()); + av_packet->size = 0; + av_packet->data = nullptr; + av_packet->stream_index = id_; + + last_num_frames_processed = num_frames_processed_; + // ProcessPacket increments num_frames_processed_ if it is able to + // decode a frame. Not being able to decode a frame while being + // flushed signals that the codec is completely done. + RETURN_IF_ERROR(ProcessPacket(av_packet.get())); + } while (last_num_frames_processed != num_frames_processed_); + + flushed_ = true; + return mediapipe::OkStatus(); +} + +void BasePacketProcessor::Close() { + if (avcodec_ctx_) { + if (avcodec_ctx_->codec) { + avcodec_close(avcodec_ctx_); + av_free(avcodec_ctx_); + } + avcodec_ctx_ = nullptr; + } + if (avcodec_opts_) { + av_dict_free(&avcodec_opts_); + } + if (decoded_frame_) { + av_frame_free(&decoded_frame_); + } +} + +mediapipe::Status BasePacketProcessor::Decode(const AVPacket& packet, + bool ignore_decode_failures) { + RETURN_IF_ERROR(LogStatus(SendPacket(packet, avcodec_ctx_), *avcodec_ctx_, + packet, ignore_decode_failures)); + while (true) { + bool received; + RETURN_IF_ERROR( + LogStatus(ReceiveFrame(avcodec_ctx_, decoded_frame_, &received), + *avcodec_ctx_, packet, ignore_decode_failures)); + if (received) { + // Successfully decoded a frame (i.e., received it from the decoder). Now + // further process it. + RETURN_IF_ERROR(ProcessDecodedFrame(packet)); + } else { + break; + } + } + return mediapipe::OkStatus(); +} + +int64 BasePacketProcessor::CorrectPtsForRollover(int64 media_pts) { + const int64 rollover_pts_media_bits = kMpegPtsEpoch - 1; + // Ensure PTS in range 0 ... kMpegPtsEpoch. This avoids errors from post + // decode PTS corrections that overflow the epoch range (while still yielding + // the correct result as long as the corrections do not exceed + // kMpegPtsMaxDelta). + media_pts &= rollover_pts_media_bits; + if (rollover_corrected_last_pts_ == AV_NOPTS_VALUE) { + // First seen PTS. + rollover_corrected_last_pts_ = media_pts; + } else { + int64 prev_media_pts = + rollover_corrected_last_pts_ & rollover_pts_media_bits; + int64 pts_step = media_pts - prev_media_pts; + if (pts_step > kMpegPtsMaxDelta) { + pts_step = pts_step - kMpegPtsEpoch; + } else if (pts_step < -kMpegPtsMaxDelta) { + pts_step = kMpegPtsEpoch + pts_step; + } + rollover_corrected_last_pts_ = + std::max((int64)0, rollover_corrected_last_pts_ + pts_step); + } + return rollover_corrected_last_pts_; +} + +// AudioPacketProcessor +namespace { + +// Converts a PCM_S16LE-encoded input sample to float between -1 and 1. 
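The PCM conversion helpers below scale raw 16-bit samples by 1/2^15 so that full-scale signed samples land in [-1, 1); for a host-order signed sample, 16384 maps to 0.5f and -32768 maps to -1.0f. The same arithmetic on an already sign-extended sample (Int16SampleToFloat is an illustrative name):

#include <cstdint>

// Mirrors the intended PCM_S16 -> float scaling: divide by 2^15.
inline float Int16SampleToFloat(int16_t sample) {
  return sample * (1.0f / (1 << 15));
}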
+inline float PcmEncodedSampleToFloat(const char* data) { + static const float kMultiplier = 1.f / (1 << 15); + return absl::little_endian::Load16(data) * kMultiplier; +} + +// Converts a PCM_S32LE-encoded input sample to float between -1 and 1. +inline float PcmEncodedSampleInt32ToFloat(const char* data) { + static constexpr float kMultiplier = 1.f / (1u << 31); + return absl::little_endian::Load32(data) * kMultiplier; +} + +} // namespace + +AudioPacketProcessor::AudioPacketProcessor(const AudioStreamOptions& options) + : sample_time_base_{0, 0}, options_(options) { + DCHECK(absl::little_endian::IsLittleEndian()); +} + +mediapipe::Status AudioPacketProcessor::Open(int id, + + AVStream* stream) { + id_ = id; + avcodec_ = avcodec_find_decoder(stream->codecpar->codec_id); + if (!avcodec_) { + return ::mediapipe::InvalidArgumentError("Failed to find codec"); + } + avcodec_ctx_ = avcodec_alloc_context3(avcodec_); + avcodec_parameters_to_context(avcodec_ctx_, stream->codecpar); + if (avcodec_open2(avcodec_ctx_, avcodec_, &avcodec_opts_) < 0) { + return UnknownError("avcodec_open() failed."); + } + CHECK(avcodec_ctx_->codec); + + source_time_base_ = stream->time_base; + source_frame_rate_ = stream->r_frame_rate; + last_frame_time_regression_detected_ = false; + + RETURN_IF_ERROR(ValidateSampleFormat()); + bytes_per_sample_ = av_get_bytes_per_sample(avcodec_ctx_->sample_fmt); + num_channels_ = avcodec_ctx_->channels; + sample_rate_ = avcodec_ctx_->sample_rate; + + if (num_channels_ <= 0) { + return UnknownError("num_channels must be strictly positive."); + } + if (sample_rate_ <= 0) { + return UnknownError("sample_rate must be strictly positive."); + } + + sample_time_base_ = {1, static_cast(sample_rate_)}; + + VLOG(0) << absl::Substitute( + "Opened audio stream (id: $0, channels: $1, sample rate: $2, time base: " + "$3/$4).", + id_, num_channels_, sample_rate_, source_time_base_.num, + source_time_base_.den); + + return mediapipe::OkStatus(); +} + +mediapipe::Status AudioPacketProcessor::ValidateSampleFormat() { + switch (avcodec_ctx_->sample_fmt) { + case AV_SAMPLE_FMT_S16: + case AV_SAMPLE_FMT_S16P: + case AV_SAMPLE_FMT_S32: + case AV_SAMPLE_FMT_FLT: + case AV_SAMPLE_FMT_FLTP: + return mediapipe::OkStatus(); + default: + return mediapipe::UnimplementedErrorBuilder(MEDIAPIPE_LOC) + << "sample_fmt = " << avcodec_ctx_->sample_fmt; + } +} + +int64 AudioPacketProcessor::SampleNumberToTimestamp(const int64 sample_number) { + return av_rescale_q(sample_number, sample_time_base_, source_time_base_); +} + +int64 AudioPacketProcessor::TimestampToSampleNumber(const int64 timestamp) { + return av_rescale_q(timestamp, source_time_base_, sample_time_base_); +} + +int64 AudioPacketProcessor::TimestampToMicroseconds(const int64 timestamp) { + return av_rescale_q(timestamp, source_time_base_, {1, 1000000}); +} + +int64 AudioPacketProcessor::SampleNumberToMicroseconds( + const int64 sample_number) { + return av_rescale_q(sample_number, sample_time_base_, {1, 1000000}); +} + +mediapipe::Status AudioPacketProcessor::ProcessPacket(AVPacket* packet) { + CHECK(packet); + if (flushed_) { + return UnknownError( + "ProcessPacket was called, but AudioPacketProcessor is already " + "finished."); + } + RET_CHECK_EQ(packet->stream_index, id_); + + decoded_frame_->nb_samples = 0; + return Decode(*packet, options_.ignore_decode_failures()); +} + +mediapipe::Status AudioPacketProcessor::ProcessDecodedFrame( + const AVPacket& packet) { + RET_CHECK_EQ(decoded_frame_->channels, num_channels_); + int buf_size_bytes = 
av_samples_get_buffer_size(nullptr, num_channels_, + decoded_frame_->nb_samples, + avcodec_ctx_->sample_fmt, 1); + VLOG(3) << "Audio packet " << avcodec_ctx_->frame_number + << " pts: " << TimestampToString(packet.pts) + << " frame.pts:" << TimestampToString(decoded_frame_->pts) + << " pkt_dts:" << TimestampToString(decoded_frame_->pkt_dts) + << " dts:" << TimestampToString(packet.dts) << " size:" << packet.size + << " decoded:" << buf_size_bytes; + uint8* const* data_ptr = decoded_frame_->data; + if (!data_ptr[0]) { + return UnknownError("No data in audio frame."); + } + if (decoded_frame_->pts != AV_NOPTS_VALUE) { + int64 pts = MaybeCorrectPtsForRollover(decoded_frame_->pts); + if (num_frames_processed_ == 0) { + expected_sample_number_ = TimestampToSampleNumber(pts); + } + + const int64 expected_us = + SampleNumberToMicroseconds(expected_sample_number_); + const int64 actual_us = TimestampToMicroseconds(pts); + if (absl::Microseconds(std::abs(expected_us - actual_us)) > + absl::Seconds(FLAGS_media_decoder_allowed_audio_gap_merge)) { + LOG(ERROR) << "The expected time based on how many samples we have seen (" + << expected_us + << " microseconds) no longer matches the time based " + "on what the audio stream is telling us (" + << actual_us + << " microseconds). The difference is more than " + "--media_decoder_allowed_audio_gap_merge (" + << absl::FormatDuration(absl::Seconds( + FLAGS_media_decoder_allowed_audio_gap_merge)) + << " microseconds). Resetting the timestamps to track what " + "the audio stream is telling us."; + expected_sample_number_ = TimestampToSampleNumber(pts); + } + } + + RETURN_IF_ERROR(AddAudioDataToBuffer( + Timestamp(av_rescale_q(expected_sample_number_, sample_time_base_, + output_time_base_)), + data_ptr, buf_size_bytes)); + + ++num_frames_processed_; + return mediapipe::OkStatus(); +} + +mediapipe::Status AudioPacketProcessor::AddAudioDataToBuffer( + const Timestamp output_timestamp, uint8* const* raw_audio, + int buf_size_bytes) { + if (buf_size_bytes == 0) { + return mediapipe::OkStatus(); + } + + if (buf_size_bytes % (num_channels_ * bytes_per_sample_) != 0) { + return UnknownError("Buffer is not an integral number of samples."); + } + + const int64 num_samples = buf_size_bytes / bytes_per_sample_ / num_channels_; + VLOG(3) << "Adding " << num_samples << " audio samples in " << num_channels_ + << " channels to output."; + auto current_frame = absl::make_unique(num_channels_, num_samples); + + const char* sample_ptr = nullptr; + switch (avcodec_ctx_->sample_fmt) { + case AV_SAMPLE_FMT_S16: + sample_ptr = reinterpret_cast(raw_audio[0]); + for (int64 sample_index = 0; sample_index < num_samples; ++sample_index) { + for (int channel = 0; channel < num_channels_; ++channel) { + (*current_frame)(channel, sample_index) = + PcmEncodedSampleToFloat(sample_ptr); + sample_ptr += bytes_per_sample_; + } + } + break; + case AV_SAMPLE_FMT_S32: + sample_ptr = reinterpret_cast(raw_audio[0]); + for (int64 sample_index = 0; sample_index < num_samples; ++sample_index) { + for (int channel = 0; channel < num_channels_; ++channel) { + (*current_frame)(channel, sample_index) = + PcmEncodedSampleInt32ToFloat(sample_ptr); + sample_ptr += bytes_per_sample_; + } + } + break; + case AV_SAMPLE_FMT_FLT: + sample_ptr = reinterpret_cast(raw_audio[0]); + for (int64 sample_index = 0; sample_index < num_samples; ++sample_index) { + for (int channel = 0; channel < num_channels_; ++channel) { + (*current_frame)(channel, sample_index) = + Uint32ToFloat(absl::little_endian::Load32(sample_ptr)); + 
sample_ptr += bytes_per_sample_; + } + } + break; + case AV_SAMPLE_FMT_S16P: + for (int channel = 0; channel < num_channels_; ++channel) { + sample_ptr = reinterpret_cast(raw_audio[channel]); + for (int64 sample_index = 0; sample_index < num_samples; + ++sample_index) { + (*current_frame)(channel, sample_index) = + PcmEncodedSampleToFloat(sample_ptr); + sample_ptr += bytes_per_sample_; + } + } + break; + case AV_SAMPLE_FMT_FLTP: + for (int channel = 0; channel < num_channels_; ++channel) { + sample_ptr = reinterpret_cast(raw_audio[channel]); + for (int64 sample_index = 0; sample_index < num_samples; + ++sample_index) { + (*current_frame)(channel, sample_index) = + Uint32ToFloat(absl::little_endian::Load32(sample_ptr)); + sample_ptr += bytes_per_sample_; + } + } + break; + default: + return mediapipe::UnimplementedErrorBuilder(MEDIAPIPE_LOC) + << "sample_fmt = " << avcodec_ctx_->sample_fmt; + } + + if (options_.output_regressing_timestamps() || + last_timestamp_ == Timestamp::Unset() || + output_timestamp > last_timestamp_) { + buffer_.push_back(Adopt(current_frame.release()).At(output_timestamp)); + last_timestamp_ = output_timestamp; + if (last_frame_time_regression_detected_) { + last_frame_time_regression_detected_ = false; + LOG(INFO) << "Processor " << this << " resumed audio packet processing."; + } + } else if (!last_frame_time_regression_detected_) { + last_frame_time_regression_detected_ = true; + LOG(ERROR) << "Processor " << this + << " is dropping an audio packet because the timestamps " + "regressed. Was " + << last_timestamp_ << " but got " << output_timestamp; + } + expected_sample_number_ += num_samples; + + return mediapipe::OkStatus(); +} + +mediapipe::Status AudioPacketProcessor::FillHeader( + TimeSeriesHeader* header) const { + CHECK(header); + header->set_sample_rate(sample_rate_); + header->set_num_channels(num_channels_); + return mediapipe::OkStatus(); +} + +int64 AudioPacketProcessor::MaybeCorrectPtsForRollover(int64 media_pts) { + return options_.correct_pts_for_rollover() ? 
CorrectPtsForRollover(media_pts) + : media_pts; +} + +// AudioDecoder +AudioDecoder::AudioDecoder() { av_register_all(); } + +AudioDecoder::~AudioDecoder() { + ::mediapipe::Status status = Close(); + if (!status.ok()) { + LOG(ERROR) << "Encountered error while closing media file: " + << status.message(); + } +} + +::mediapipe::Status AudioDecoder::Initialize( + const std::string& input_file, + const mediapipe::AudioDecoderOptions options) { + if (options.audio_stream().empty()) { + return ::mediapipe::InvalidArgumentError( + "At least one audio_stream must be defined in AudioDecoderOptions"); + } + std::map stream_index_to_audio_options_index; + int options_index = 0; + for (const auto& audio_stream : options.audio_stream()) { + InsertIfNotPresent(&stream_index_to_audio_options_index, + audio_stream.stream_index(), options_index); + ++options_index; + } + + Cleanup> decoder_closer([this]() { + ::mediapipe::Status status = Close(); + if (!status.ok()) { + LOG(ERROR) << "Encountered error while closing media file: " + << status.message(); + } + }); + + avformat_ctx_ = avformat_alloc_context(); + if (avformat_open_input(&avformat_ctx_, input_file.c_str(), NULL, NULL) < 0) { + return ::mediapipe::InvalidArgumentError( + absl::StrCat("Could not open file: ", input_file)); + } + + if (avformat_find_stream_info(avformat_ctx_, NULL) < 0) { + return ::mediapipe::InvalidArgumentError(absl::StrCat( + "Could not find stream information of file: ", input_file)); + } + + std::map audio_options_index_to_stream_id; + for (int current_audio_index = 0, stream_id = 0; + stream_id < avformat_ctx_->nb_streams; ++stream_id) { + AVStream* stream = avformat_ctx_->streams[stream_id]; + AVCodecParameters* dec_param = stream->codecpar; + switch (dec_param->codec_type) { + case AVMEDIA_TYPE_AUDIO: { + const int* options_index_ptr = FindOrNull( + stream_index_to_audio_options_index, current_audio_index); + if (options_index_ptr) { + std::unique_ptr processor = + absl::make_unique( + options.audio_stream(*options_index_ptr)); + if (!ContainsKey(audio_processor_, stream_id)) { + LOG(INFO) << "Created audio processor " << processor.get() + << " for file \"" << input_file << "\""; + } else { + LOG(ERROR) << "Stream " << stream_id + << " already mapped to audio processor " + << audio_processor_[stream_id].get(); + } + + RETURN_IF_ERROR(processor->Open(stream_id, stream)); + audio_processor_.emplace(stream_id, std::move(processor)); + CHECK(InsertIfNotPresent(&stream_id_to_audio_options_index_, + stream_id, *options_index_ptr)); + CHECK(InsertIfNotPresent(&audio_options_index_to_stream_id, + *options_index_ptr, stream_id)); + } + ++current_audio_index; + break; + } + default: { + // Ignore other stream types. 
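The options consulted throughout this file correspond to AudioDecoderOptions fields; a configuration sketch in which the setter names are assumed to follow standard generated-proto naming for the fields read here (start_time, end_time, stream_index, allow_missing, ignore_decode_failures, correct_pts_for_rollover):

#include "mediapipe/util/audio_decoder.pb.h"

// Requests the first two audio streams between 2s and 12s, tolerating a
// missing second stream and decode errors on the first.
mediapipe::AudioDecoderOptions MakeOptions() {
  mediapipe::AudioDecoderOptions options;
  options.set_start_time(2.0);  // seconds
  options.set_end_time(12.0);   // seconds

  auto* first = options.add_audio_stream();
  first->set_stream_index(0);
  first->set_ignore_decode_failures(true);

  auto* second = options.add_audio_stream();
  second->set_stream_index(1);
  second->set_allow_missing(true);
  second->set_correct_pts_for_rollover(true);
  return options;
}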
+ } + } + } + for (int i = 0; i < options.audio_stream_size(); ++i) { + RET_CHECK(ContainsKey(audio_options_index_to_stream_id, i) || + options.audio_stream(i).allow_missing()) + << absl::StrCat("Could not find audio stream with index ", i, + " in file ", input_file); + } + + if (options.has_start_time()) { + start_time_ = Timestamp::FromSeconds(options.start_time()); + } + if (options.has_end_time()) { + end_time_ = Timestamp::FromSeconds(options.end_time()); + } + is_first_packet_.resize(avformat_ctx_->nb_streams, true); + + decoder_closer.release(); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status AudioDecoder::GetData(int* options_index, Packet* data) { + while (true) { + for (auto& item : audio_processor_) { + while (item.second && item.second->HasData()) { + bool is_first_packet = is_first_packet_[item.first]; + is_first_packet_[item.first] = false; + *options_index = + FindOrDie(stream_id_to_audio_options_index_, item.first); + ::mediapipe::Status status = item.second->GetData(data); + // Ignore packets which are out of the requested timestamp range. + if (start_time_ != Timestamp::Unset()) { + if (is_first_packet && data->Timestamp() > start_time_) { + LOG(ERROR) << "First packet in audio stream " << *options_index + << " has timestamp " << data->Timestamp() + << " which is after start time of " << start_time_ + << "."; + } + if (data->Timestamp() < start_time_) { + VLOG(1) << "Skipping audio frame with timestamp " + << data->Timestamp() << " before start time " + << start_time_; + *data = Packet(); + continue; + } + } + if (end_time_ != Timestamp::Unset() && data->Timestamp() > end_time_) { + VLOG(1) << "Skipping audio frame with timestamp " << data->Timestamp() + << " after end time " << end_time_; + // We are past the last timestamp we care about, close the + // packet processor. We cannot remove the element from + // audio_processor_ right now, because we need to continue + // iterating through it. + item.second->Close(); + item.second.reset(nullptr); + *data = Packet(); + continue; + } + return status; + } + } + if (flushed_) { + RETURN_IF_ERROR(Close()); + return tool::StatusStop(); + } + RETURN_IF_ERROR(ProcessPacket()); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status AudioDecoder::Close() { + for (auto& item : audio_processor_) { + if (item.second) { + item.second->Close(); + item.second.reset(nullptr); + } + } + // Free the context. + if (avformat_ctx_) { + avformat_close_input(&avformat_ctx_); + } + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status AudioDecoder::FillAudioHeader( + const AudioStreamOptions& stream_option, TimeSeriesHeader* header) const { + const std::unique_ptr* processor_ptr_ = + FindOrNull(audio_processor_, stream_option.stream_index()); + + RET_CHECK(processor_ptr_ && *processor_ptr_) << "audio stream is not open."; + RETURN_IF_ERROR((*processor_ptr_)->FillHeader(header)); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status AudioDecoder::ProcessPacket() { + std::unique_ptr av_packet(new AVPacket()); + av_init_packet(av_packet.get()); + av_packet->size = 0; + av_packet->data = nullptr; + int ret = av_read_frame(avformat_ctx_, av_packet.get()); + if (ret >= 0) { + CHECK(av_packet->data) << "AVPacket does not include any data but " + "av_read_frame was successful."; + const int stream_id = av_packet->stream_index; + auto audio_iterator = audio_processor_.find(stream_id); + if (audio_iterator != audio_processor_.end()) { + // This stream_id is belongs to an audio stream we care about. 
+ if (audio_iterator->second) { + RETURN_IF_ERROR(audio_iterator->second->ProcessPacket(av_packet.get())); + } else { + VLOG(3) << "processor for stream " << stream_id << " is nullptr."; + } + } else { + VLOG(3) << "Ignoring packet for stream " << stream_id; + } + return ::mediapipe::OkStatus(); + } + VLOG(1) << "Demuxing returned error (or EOF): " << AvErrorToString(ret); + if (ret == AVERROR(EAGAIN)) { + // EAGAIN is used to signify that the av_packet should be skipped + // (maybe the demuxer is trying to re-sync). This definitely + // occurs in the FLV and MpegT demuxers. + return ::mediapipe::OkStatus(); + } + + // Unrecoverable demuxing error with details in avformat_ctx_->pb->error. + int demuxing_error = + avformat_ctx_->pb ? avformat_ctx_->pb->error : 0 /* no error */; + if (ret == AVERROR_EOF && !demuxing_error) { + VLOG(1) << "Reached EOF."; + return Flush(); + } + + RET_CHECK(!demuxing_error) << absl::Substitute( + "Failed to read a frame: retval = $0 ($1), avformat_ctx_->pb->error = " + "$2 ($3)", + ret, AvErrorToString(ret), demuxing_error, + AvErrorToString(demuxing_error)); + + if (is_first_packet_[av_packet->stream_index]) { + RET_CHECK_FAIL() << "Couldn't even read the first frame; maybe a partial " + "file with only metadata?"; + } + + // Unrecoverable demuxing error without details. + RET_CHECK_FAIL() << absl::Substitute( + "Failed to read a frame: retval = $0 ($1)", ret, AvErrorToString(ret)); +} + +::mediapipe::Status AudioDecoder::Flush() { + std::vector<::mediapipe::Status> statuses; + for (auto& item : audio_processor_) { + if (item.second) { + statuses.push_back(item.second->Flush()); + } + } + flushed_ = true; + return tool::CombinedStatus("Error while flushing codecs: ", statuses); +} + +} // namespace mediapipe diff --git a/mediapipe/util/audio_decoder.h b/mediapipe/util/audio_decoder.h new file mode 100644 index 000000000..c2e7d77e6 --- /dev/null +++ b/mediapipe/util/audio_decoder.h @@ -0,0 +1,227 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_UTIL_AUDIO_DECODER_H_ +#define MEDIAPIPE_UTIL_AUDIO_DECODER_H_ + +#include // required by avutil.h +#include +#include +#include + +#include "absl/time/time.h" +#include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/commandlineflags.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/util/audio_decoder.pb.h" + +extern "C" { +#include "libavcodec/avcodec.h" +#include "libavformat/avformat.h" +#include "libavutil/avutil.h" +#include "libavutil/dict.h" +#include "mediapipe/util/audio_decoder.pb.h" +} + +namespace mediapipe { + +using mediapipe::AudioStreamOptions; +using mediapipe::TimeSeriesHeader; + +// The base helper class for a processor which handles decoding of a single +// stream. 
+class BasePacketProcessor {
+ public:
+  BasePacketProcessor();
+  virtual ~BasePacketProcessor();
+
+  // Opens the codec.
+  virtual mediapipe::Status Open(int id, AVStream* stream) = 0;
+
+  // Processes a packet of data. Caller retains ownership of packet.
+  virtual mediapipe::Status ProcessPacket(AVPacket* packet) = 0;
+
+  // Returns true if the processor has data immediately available
+  // (without providing more data with ProcessPacket()).
+  bool HasData();
+
+  // Fills packet with the next frame of data. Returns an empty packet
+  // if there is nothing to return.
+  mediapipe::Status GetData(Packet* packet);
+
+  // Once no more AVPackets are available in the file, each stream must
+  // be flushed to get any remaining frames which the codec is buffering.
+  mediapipe::Status Flush();
+
+  // Closes the Processor; this does not close the file. You may not
+  // call ProcessPacket() after calling Close(). Close() may be called
+  // repeatedly.
+  void Close();
+
+ protected:
+  // Decodes frames in a packet.
+  virtual mediapipe::Status Decode(const AVPacket& packet,
+                                   bool ignore_decode_failures);
+
+  // Processes a decoded frame.
+  virtual mediapipe::Status ProcessDecodedFrame(const AVPacket& packet) = 0;
+
+  // Corrects the given PTS for MPEG PTS rollover. Assumed to be called with
+  // the PTS of each frame in decode order. We detect a rollover whenever the
+  // PTS timestamp changes by more than 2^33/2 (half the timestamp space). For
+  // video this means every 26.5h with 1 PTS tick = 1/90000 of a second.
+  // Example timeline:
+  // CorrectPtsForRollover(0) -> 0
+  // CorrectPtsForRollover(42) -> 42
+  // CorrectPtsForRollover(2^33 - 1) -> 2^33 - 1
+  // CorrectPtsForRollover(0) -> 2^33  // PTS in media rolls over, corrected.
+  // CorrectPtsForRollover(1) -> 2^33 + 1
+  int64 CorrectPtsForRollover(int64 media_pts);
+
+  AVCodecContext* avcodec_ctx_ = nullptr;
+  const AVCodec* avcodec_ = nullptr;
+  AVDictionary* avcodec_opts_ = nullptr;
+  AVFrame* decoded_frame_ = nullptr;
+
+  // Stream ID this object processes.
+  int id_ = -1;
+
+  // Set to true if the stream has been flushed and no more AVPackets
+  // will be processed with it.
+  bool flushed_ = false;
+
+  // The source time base.
+  AVRational source_time_base_;
+  // The output time base.
+  const AVRational output_time_base_;
+
+  // The source frame rate (estimated from header information).
+  AVRational source_frame_rate_;
+
+  // The number of frames that were successfully processed.
+  int64 num_frames_processed_ = 0;
+
+  int bytes_per_sample_ = 0;
+
+  // True if a time regression was detected for the last frame.
+  bool last_frame_time_regression_detected_ = false;
+
+  // The last rollover corrected PTS returned by CorrectPtsForRollover.
+  int64 rollover_corrected_last_pts_ = AV_NOPTS_VALUE;
+
+  // The buffer of current frames.
+  std::deque<Packet> buffer_;
+};
+
+// Class which decodes packets from a single audio stream.
+class AudioPacketProcessor : public BasePacketProcessor {
+ public:
+  explicit AudioPacketProcessor(const AudioStreamOptions& options);
+
+  mediapipe::Status Open(int id, AVStream* stream) override;
+
+  mediapipe::Status ProcessPacket(AVPacket* packet) override;
+
+  mediapipe::Status FillHeader(TimeSeriesHeader* header) const;
+
+ private:
+  // Appends audio in buffer(s) to the output buffer (buffer_).
+  mediapipe::Status AddAudioDataToBuffer(const Timestamp output_timestamp,
+                                         uint8* const* raw_audio,
+                                         int buf_size_bytes);
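The definition of CorrectPtsForRollover() lives in audio_decoder.cc and is not part of this excerpt. Purely as a sketch consistent with the timeline documented above (an assumption, not the actual code), a rollover-correcting implementation could look like this:

    // Sketch only: a backward PTS jump of more than half the 33-bit range is
    // treated as a rollover, and the accumulated offset is carried forward.
    int64 CorrectPtsForRollover(int64 media_pts) {
      const int64 kPtsRange = int64{1} << 33;  // MPEG PTS wraps at 2^33.
      if (rollover_corrected_last_pts_ == AV_NOPTS_VALUE) {
        return rollover_corrected_last_pts_ = media_pts;  // First frame.
      }
      const int64 last_raw_pts = rollover_corrected_last_pts_ % kPtsRange;
      int64 offset = rollover_corrected_last_pts_ - last_raw_pts;
      if (media_pts - last_raw_pts < -kPtsRange / 2) {
        offset += kPtsRange;  // PTS dropped sharply: the media rolled over.
      }
      return rollover_corrected_last_pts_ = media_pts + offset;
    }

Note that only a large backward jump triggers the correction, which is why the third call in the documented timeline (42 to 2^33 - 1) passes through unchanged.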
+  // Converts a number of samples into an approximate stream timestamp value.
+  int64 SampleNumberToTimestamp(const int64 sample_number);
+  int64 TimestampToSampleNumber(const int64 timestamp);
+
+  // Converts a timestamp/sample number to microseconds.
+  int64 TimestampToMicroseconds(const int64 timestamp);
+  int64 SampleNumberToMicroseconds(const int64 sample_number);
+
+  // Returns an error if the sample format in avformat_ctx_.sample_format
+  // is not supported.
+  mediapipe::Status ValidateSampleFormat();
+
+  // Processes a decoded audio frame. audio_frame_ must have been filled
+  // with the frame before calling this function.
+  mediapipe::Status ProcessDecodedFrame(const AVPacket& packet) override;
+
+  // Corrects PTS for rollover if correction is enabled.
+  int64 MaybeCorrectPtsForRollover(int64 media_pts);
+
+  // Number of channels to output. This value might be different from
+  // the actual number of channels for the current AVPacket, found in
+  // avcodec_ctx_->channels.
+  int num_channels_ = -1;
+
+  // Sample rate of the data to output. This value might be different
+  // from the actual sample rate for the current AVPacket, found in
+  // avcodec_ctx_->sample_rate.
+  int64 sample_rate_ = -1;
+
+  // The time base of audio samples (i.e. the reciprocal of the sample rate).
+  AVRational sample_time_base_;
+
+  // The timestamp of the last packet added to the buffer.
+  Timestamp last_timestamp_;
+
+  // The expected sample number based on counting samples.
+  int64 expected_sample_number_ = 0;
+
+  // Options for the processor.
+  AudioStreamOptions options_;
+};
+
+// Decode the audio streams of a media file. The AudioDecoder is responsible
+// for demuxing the audio streams in the container format, whereas decoding of
+// the content is delegated to AudioPacketProcessor.
+class AudioDecoder {
+ public:
+  AudioDecoder();
+  ~AudioDecoder();
+
+  ::mediapipe::Status Initialize(const std::string& input_file,
+                                 const mediapipe::AudioDecoderOptions options);
+
+  ::mediapipe::Status GetData(int* options_index, Packet* data);
+
+  ::mediapipe::Status Close();
+
+  ::mediapipe::Status FillAudioHeader(const AudioStreamOptions& stream_option,
+                                      TimeSeriesHeader* header) const;
+
+ private:
+  ::mediapipe::Status ProcessPacket();
+  ::mediapipe::Status Flush();
+
+  std::map<int, int> stream_id_to_audio_options_index_;
+  std::map<int, std::unique_ptr<AudioPacketProcessor>> audio_processor_;
+
+  // Indexed by container stream index, true if the stream has not seen
+  // a packet (whether returned or not), and false otherwise.
+  std::vector<bool> is_first_packet_;
+  bool flushed_ = false;
+
+  Timestamp start_time_ = Timestamp::Unset();
+  Timestamp end_time_ = Timestamp::Unset();
+
+  AVFormatContext* avformat_ctx_ = nullptr;
+};
+
+} // namespace mediapipe
+
+#endif // MEDIAPIPE_UTIL_AUDIO_DECODER_H_
diff --git a/mediapipe/util/audio_decoder.proto b/mediapipe/util/audio_decoder.proto
new file mode 100644
index 000000000..67c54a209
--- /dev/null
+++ b/mediapipe/util/audio_decoder.proto
@@ -0,0 +1,53 @@
+// Copyright 2019 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message AudioStreamOptions { + // The stream to decode. Stream indexes start from 0 (audio and video + // are handled separately). + optional int64 stream_index = 1 [default = 0]; + + // Process the file despite this stream not being present. + optional bool allow_missing = 2 [default = false]; + + // If true, failures to decode a frame of data will be ignored. + optional bool ignore_decode_failures = 3 [default = false]; + + // Output packets with regressing timestamps. By default those packets are + // dropped. + optional bool output_regressing_timestamps = 4 [default = false]; + + // MPEG PTS timestamps roll over back to 0 after 26.5h. If this flag is set + // we detect any rollover and continue incrementing timestamps past this + // point. Set this flag if you want non-regressing timestamps for MPEG + // content where the PTS may roll over. + optional bool correct_pts_for_rollover = 5; +} + +message AudioDecoderOptions { + extend CalculatorOptions { + optional AudioDecoderOptions ext = 263370674; + } + repeated AudioStreamOptions audio_stream = 1; + + // The start time in seconds to decode. + optional double start_time = 2; + // The end time in seconds to decode (inclusive). + optional double end_time = 3; +} diff --git a/mediapipe/util/render_data.proto b/mediapipe/util/render_data.proto index 8b95afcde..e98df974c 100644 --- a/mediapipe/util/render_data.proto +++ b/mediapipe/util/render_data.proto @@ -45,6 +45,7 @@ message RenderAnnotation { optional double right = 3; optional double bottom = 4; optional bool normalized = 5 [default = false]; + optional double rotation = 6; // Rotation in radians. } message FilledRectangle { @@ -114,6 +115,17 @@ message RenderAnnotation { optional LineType line_type = 6 [default = SOLID]; } + message GradientLine { + optional double x_start = 1; + optional double y_start = 2; + optional double x_end = 3; + optional double y_end = 4; + optional bool normalized = 5 [default = false]; + // Linearly interpolate between color1 and color2 along the line. + optional Color color1 = 6; + optional Color color2 = 7; + } + message Arrow { // The arrow head will be drawn at (x_end, y_end). optional double x_start = 1; @@ -162,6 +174,7 @@ message RenderAnnotation { Text text = 8; RoundedRectangle rounded_rectangle = 9; FilledRoundedRectangle filled_rounded_rectangle = 10; + GradientLine gradient_line = 14; } // Thickness for drawing the annotation. diff --git a/mediapipe/util/resource_util_apple.cc b/mediapipe/util/resource_util_apple.cc index d196faadd..bcd90fe1a 100644 --- a/mediapipe/util/resource_util_apple.cc +++ b/mediapipe/util/resource_util_apple.cc @@ -30,7 +30,7 @@ namespace mediapipe { } NSString* ns_path = [NSString stringWithUTF8String:path.c_str()]; - Class mediapipeGraphClass = NSClassFromString(@"MediaPipeGraph"); + Class mediapipeGraphClass = NSClassFromString(@"MPPGraph"); NSString* resource_dir = [[NSBundle bundleForClass:mediapipeGraphClass] resourcePath]; NSString* resolved_ns_path = diff --git a/mediapipe/util/sequence/media_sequence.cc b/mediapipe/util/sequence/media_sequence.cc index f54ec388e..d50055da6 100644 --- a/mediapipe/util/sequence/media_sequence.cc +++ b/mediapipe/util/sequence/media_sequence.cc @@ -191,29 +191,30 @@ float TimestampsToRate(int64 first_timestamp, int64 second_timestamp) { // the closest annotation is saved. This matches the behavior of downsampling // images streams in time. 
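As a concrete illustration of that matching rule (using the values exercised by the new ReconcileMetadataRegionAnnotations test later in this change): with image timestamps {10, 20, 30} and annotation timestamps {9, 21, 22}, the annotation at 9 snaps to frame 10 and the one at 21 snaps to frame 20, while the annotation at 22 is dropped because frame 20 already has a closer match; frame 30 is simply left marked as unannotated.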
::mediapipe::Status ReconcileMetadataBoxAnnotations( - tensorflow::SequenceExample* sequence) { - int num_bboxes = GetBBoxTimestampSize(*sequence); + const std::string& prefix, tensorflow::SequenceExample* sequence) { + int num_bboxes = GetBBoxTimestampSize(prefix, *sequence); int num_frames = GetImageTimestampSize(*sequence); if (num_bboxes && num_frames) { // If no one has indicated which frames are annotated, assume annotations // are dense. - if (GetBBoxIsAnnotatedSize(*sequence) == 0) { + if (GetBBoxIsAnnotatedSize(prefix, *sequence) == 0) { for (int i = 0; i < num_bboxes; ++i) { - AddBBoxIsAnnotated(true, sequence); + AddBBoxIsAnnotated(prefix, true, sequence); } } - RET_CHECK_EQ(num_bboxes, GetBBoxIsAnnotatedSize(*sequence)) + RET_CHECK_EQ(num_bboxes, GetBBoxIsAnnotatedSize(prefix, *sequence)) << "Expected number of BBox timestamps and annotation marks to match."; // Update num_bboxes. - if (GetBBoxSize(*sequence) > 0) { - auto* bbox_feature_list = - MutableFeatureList(kRegionBBoxXMinKey, sequence); + if (GetBBoxSize(prefix, *sequence) > 0) { + std::string xmin_key = merge_prefix(prefix, kRegionBBoxXMinKey); + auto* bbox_feature_list = MutableFeatureList(xmin_key, sequence); RET_CHECK_EQ(num_bboxes, bbox_feature_list->feature_size()) << "Expected number of BBox timestamps and boxes to match."; - ClearBBoxNumRegions(sequence); + ClearBBoxNumRegions(prefix, sequence); for (int i = 0; i < num_bboxes; ++i) { AddBBoxNumRegions( - bbox_feature_list->feature(i).float_list().value_size(), sequence); + prefix, bbox_feature_list->feature(i).float_list().value_size(), + sequence); } } // Collect which timestamps currently match to which indices in timestamps. @@ -221,15 +222,16 @@ float TimestampsToRate(int64 first_timestamp, int64 second_timestamp) { // Requires sorted indices. ::std::vector box_timestamps(num_bboxes); int bbox_index = 0; - for (auto& feature : - GetFeatureList(*sequence, kRegionTimestampKey).feature()) { + std::string timestamp_key = merge_prefix(prefix, kRegionTimestampKey); + for (auto& feature : GetFeatureList(*sequence, timestamp_key).feature()) { box_timestamps[bbox_index] = feature.int64_list().value(0); ++bbox_index; } ::std::vector box_is_annotated(num_bboxes); bbox_index = 0; + std::string is_annotated_key = merge_prefix(prefix, kRegionIsAnnotatedKey); for (auto& feature : - GetFeatureList(*sequence, kRegionIsAnnotatedKey).feature()) { + GetFeatureList(*sequence, is_annotated_key).feature()) { box_is_annotated[bbox_index] = feature.int64_list().value(0); ++bbox_index; } @@ -270,62 +272,87 @@ float TimestampsToRate(int64 first_timestamp, int64 second_timestamp) { } // Only update unmodified bbox timestamp if it doesn't exist to prevent // overwriting with modified values. 
- if (!GetUnmodifiedBBoxTimestampSize(*sequence)) { - for (int i = 0; i < num_bboxes; ++i) { - if (GetBBoxIsAnnotatedAt(*sequence, i)) { - AddUnmodifiedBBoxTimestamp(box_timestamps[i], sequence); + if (!GetUnmodifiedBBoxTimestampSize(prefix, *sequence)) { + for (int i = 0; i < num_frames; ++i) { + if (bbox_index_if_annotated[i] >= 0 && + GetBBoxIsAnnotatedAt(prefix, *sequence, i)) { + AddUnmodifiedBBoxTimestamp( + prefix, box_timestamps[bbox_index_if_annotated[i]], sequence); } } } // store some new feature_lists in a temporary sequence + std::string expected_prefix = merge_prefix(prefix, "region/"); ::tensorflow::SequenceExample tmp_seq; for (const auto& key_value : sequence->feature_lists().feature_list()) { const std::string& key = key_value.first; - if (::absl::StartsWith(key, "region/")) { + if (::absl::StartsWith(key, expected_prefix)) { // create a new set of values and swap them in. tmp_seq.Clear(); auto* old_feature_list = MutableFeatureList(key, sequence); - if (key != kUnmodifiedRegionTimestampKey) { + auto* new_feature_list = MutableFeatureList(key, &tmp_seq); + if (key != merge_prefix(prefix, kUnmodifiedRegionTimestampKey)) { RET_CHECK_EQ(num_bboxes, old_feature_list->feature().size()) << "Expected number of BBox timestamps to match number of " "entries " << "in " << key; - } - auto* new_feature_list = MutableFeatureList(key, &tmp_seq); - for (int i = 0; i < num_frames; ++i) { - if (bbox_index_if_annotated[i] >= 0) { - if (key == kRegionTimestampKey) { - new_feature_list->add_feature()->mutable_int64_list()->add_value( - image_timestamps[i]); + for (int i = 0; i < num_frames; ++i) { + if (bbox_index_if_annotated[i] >= 0) { + if (key == merge_prefix(prefix, kRegionTimestampKey)) { + new_feature_list->add_feature() + ->mutable_int64_list() + ->add_value(image_timestamps[i]); + } else { + *new_feature_list->add_feature() = + old_feature_list->feature(bbox_index_if_annotated[i]); + } } else { - *new_feature_list->add_feature() = - old_feature_list->feature(bbox_index_if_annotated[i]); - } - } else { - // Add either a default value or an empty. - if (key == kRegionIsAnnotatedKey) { - new_feature_list->add_feature()->mutable_int64_list()->add_value( - 0); - } else if (key == kRegionNumRegionsKey) { - new_feature_list->add_feature()->mutable_int64_list()->add_value( - 0); - } else if (key == kRegionTimestampKey) { - new_feature_list->add_feature()->mutable_int64_list()->add_value( - image_timestamps[i]); - } else if (key == kUnmodifiedRegionTimestampKey) { - // Do not add an unmodified timestamp when - // is_annotated == false. - } else { - new_feature_list->add_feature(); // Adds an empty. + // Add either a default value or an empty. + if (key == merge_prefix(prefix, kRegionIsAnnotatedKey)) { + new_feature_list->add_feature() + ->mutable_int64_list() + ->add_value(0); + } else if (key == merge_prefix(prefix, kRegionNumRegionsKey)) { + new_feature_list->add_feature() + ->mutable_int64_list() + ->add_value(0); + } else if (key == merge_prefix(prefix, kRegionTimestampKey)) { + new_feature_list->add_feature() + ->mutable_int64_list() + ->add_value(image_timestamps[i]); + } else { + new_feature_list->add_feature(); // Adds an empty. + } } } + *old_feature_list = *new_feature_list; } - *old_feature_list = *new_feature_list; } } } return ::mediapipe::OkStatus(); } + +::mediapipe::Status ReconcileMetadataRegionAnnotations( + tensorflow::SequenceExample* sequence) { + // Copy keys for fixed iteration order while updating feature_lists. 
+ std::vector key_ptrs; + for (const auto& key_value : sequence->feature_lists().feature_list()) { + key_ptrs.push_back(&key_value.first); + } + for (const std::string* key_ptr : key_ptrs) { + const std::string& key = *key_ptr; + if (::absl::StrContains(key, kRegionTimestampKey)) { + std::string prefix = + key.substr(0, key.size() - sizeof(kRegionTimestampKey)); + if (key == kRegionTimestampKey) { + prefix = ""; + } + RET_CHECK_OK(ReconcileMetadataBoxAnnotations(prefix, sequence)); + } + } + return ::mediapipe::OkStatus(); +} } // namespace int GetBBoxSize(const std::string& prefix, @@ -368,6 +395,14 @@ void AddBBox(const std::string& prefix, AddBBoxYMax(prefix, ymaxs, sequence); } +void ClearBBox(const std::string& prefix, + tensorflow::SequenceExample* sequence) { + ClearBBoxXMin(prefix, sequence); + ClearBBoxYMin(prefix, sequence); + ClearBBoxXMax(prefix, sequence); + ClearBBoxYMax(prefix, sequence); +} + int GetPointSize(const std::string& prefix, const tensorflow::SequenceExample& sequence) { return GetBBoxPointXSize(prefix, sequence); @@ -399,6 +434,12 @@ void AddPoint(const std::string& prefix, AddBBoxPointX(prefix, xs, sequence); } +void ClearPoint(const std::string& prefix, + tensorflow::SequenceExample* sequence) { + ClearBBoxPointY(prefix, sequence); + ClearBBoxPointX(prefix, sequence); +} + std::unique_ptr GetAudioFromFeatureAt( const std::string& prefix, const tensorflow::SequenceExample& sequence, int index) { @@ -431,6 +472,7 @@ void AddAudioAsFeature(const std::string& prefix, } ::mediapipe::Status ReconcileMetadata(bool reconcile_bbox_annotations, + bool reconcile_region_annotations, tensorflow::SequenceExample* sequence) { RET_CHECK_OK(ReconcileAnnotationIndicesByImageTimestamps(sequence)); RET_CHECK_OK(ReconcileMetadataImages("", sequence)); @@ -439,7 +481,10 @@ void AddAudioAsFeature(const std::string& prefix, RET_CHECK_OK(ReconcileMetadataImages(kInstanceSegmentationPrefix, sequence)); RET_CHECK_OK(ReconcileMetadataFeatureFloats(sequence)); if (reconcile_bbox_annotations) { - RET_CHECK_OK(ReconcileMetadataBoxAnnotations(sequence)); + RET_CHECK_OK(ReconcileMetadataBoxAnnotations("", sequence)); + } + if (reconcile_region_annotations) { + RET_CHECK_OK(ReconcileMetadataRegionAnnotations(sequence)); } // audio is always reconciled in the framework. 
return ::mediapipe::OkStatus(); diff --git a/mediapipe/util/sequence/media_sequence.h b/mediapipe/util/sequence/media_sequence.h index c95aaf700..9d93da60d 100644 --- a/mediapipe/util/sequence/media_sequence.h +++ b/mediapipe/util/sequence/media_sequence.h @@ -112,7 +112,7 @@ // tensorflow::SequenceExample example; // SetClipLabelString({"run", "jump"}, &example); // if (HasClipLabelString(example)) { -// ::std::vector values = GetClipLabelString(example); +// std::vector values = GetClipLabelString(example); // ClearClipLabelString(&example); // } // @@ -140,7 +140,7 @@ // AddBBoxLabelString({"run", "fall"}, &example); // if (HasBBoxLabelString(example)) { // for (int i = 0; i < GetBBoxLabelStringSize(); ++i) { -// ::std::vector labels = GetBBoxLabelStringAt(example, i); +// std::vector labels = GetBBoxLabelStringAt(example, i); // } // ClearBBoxLabelString(&example); // } @@ -315,17 +315,21 @@ std::vector<::mediapipe::Location> GetBBoxAt( void AddBBox(const std::string& prefix, const std::vector<::mediapipe::Location>& bboxes, tensorflow::SequenceExample* sequence); +void ClearBBox(const std::string& prefix, + tensorflow::SequenceExample* sequence); // The input and output format is a pair of coordinates to match the // order of bounding box coordinates. int GetPointSize(const std::string& prefix, const tensorflow::SequenceExample& sequence); -::std::vector<::std::pair> GetPointAt( +std::vector> GetPointAt( const std::string& prefix, const tensorflow::SequenceExample& sequence, int index); void AddPoint(const std::string& prefix, - const ::std::vector<::std::pair>& points, + const std::vector>& points, tensorflow::SequenceExample* sequence); +void ClearPoint(const std::string& prefix, + tensorflow::SequenceExample* sequence); #define FIXED_PREFIX_BBOX_ACCESSORS(identifier, prefix) \ inline int CONCAT_STR3(Get, identifier, \ @@ -341,19 +345,47 @@ void AddPoint(const std::string& prefix, tensorflow::SequenceExample* sequence) { \ return AddBBox(prefix, bboxes, sequence); \ } \ + inline void CONCAT_STR2( \ + Clear, identifier)(tensorflow::SequenceExample * sequence) { \ + return ClearBBox(prefix, sequence); \ + } \ inline int CONCAT_STR3(Get, identifier, PointSize)( \ const tensorflow::SequenceExample& sequence) { \ return GetPointSize(prefix, sequence); \ } \ - inline ::std::vector<::std::pair> CONCAT_STR3( \ + inline int CONCAT_STR3(Get, identifier, PointSize)( \ + const std::string& name, const tensorflow::SequenceExample& sequence) { \ + return GetPointSize(name, sequence); \ + } \ + inline std::vector> CONCAT_STR3( \ Get, identifier, PointAt)(const tensorflow::SequenceExample& sequence, \ int index) { \ return GetPointAt(prefix, sequence, index); \ } \ + inline std::vector> CONCAT_STR3( \ + Get, identifier, PointAt)(const std::string& name, \ + const tensorflow::SequenceExample& sequence, \ + int index) { \ + return GetPointAt(name, sequence, index); \ + } \ inline void CONCAT_STR3(Add, identifier, Point)( \ - const ::std::vector<::std::pair>& points, \ + const std::vector>& points, \ tensorflow::SequenceExample* sequence) { \ return AddPoint(prefix, points, sequence); \ + } \ + inline void CONCAT_STR3(Add, identifier, Point)( \ + const std::string& name, \ + const std::vector>& points, \ + tensorflow::SequenceExample* sequence) { \ + return AddPoint(name, points, sequence); \ + } \ + inline void CONCAT_STR3(Clear, identifier, \ + Point)(tensorflow::SequenceExample * sequence) { \ + return ClearPoint(prefix, sequence); \ + } \ + inline void CONCAT_STR3(Clear, identifier, Point)( \ + 
std::string name, tensorflow::SequenceExample * sequence) { \ + return ClearPoint(name, sequence); \ } #define PREFIXED_BBOX(identifier, prefix) \ @@ -579,6 +611,7 @@ PREFIXED_FLOAT_CONTEXT_FEATURE(FeatureAudioSampleRate, // Reconciling bounding box annotations is optional because will remove // annotations if the sequence rate is lower than the annotation rate. ::mediapipe::Status ReconcileMetadata(bool reconcile_bbox_annotations, + bool reconcile_region_annotations, tensorflow::SequenceExample* sequence); } // namespace mediasequence } // namespace mediapipe diff --git a/mediapipe/util/sequence/media_sequence.py b/mediapipe/util/sequence/media_sequence.py index 321eee703..1e2572dcf 100644 --- a/mediapipe/util/sequence/media_sequence.py +++ b/mediapipe/util/sequence/media_sequence.py @@ -191,7 +191,8 @@ msu.create_bytes_context_feature( msu.create_bytes_context_feature( "clip_media_id", CLIP_MEDIA_ID_KEY, module_dict=globals()) msu.create_bytes_context_feature( - "clip_alternative_media_id", CLIP_MEDIA_ID_KEY, module_dict=globals()) + "clip_alternative_media_id", ALTERNATIVE_CLIP_MEDIA_ID_KEY, + module_dict=globals()) msu.create_bytes_context_feature( "clip_encoded_media_bytes", CLIP_ENCODED_MEDIA_BYTES_KEY, module_dict=globals()) diff --git a/mediapipe/util/sequence/media_sequence_test.cc b/mediapipe/util/sequence/media_sequence_test.cc index e815f52cf..bcd29c4c8 100644 --- a/mediapipe/util/sequence/media_sequence_test.cc +++ b/mediapipe/util/sequence/media_sequence_test.cc @@ -584,7 +584,7 @@ TEST(MediaSequenceTest, RoundTripOpticalFlowTimestamp) { TEST(MediaSequenceTest, ReconcileMetadataOnEmptySequence) { tensorflow::SequenceExample sequence; - MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, &sequence)); + MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, false, &sequence)); } TEST(MediaSequenceTest, ReconcileMetadataImagestoLabels) { @@ -600,7 +600,7 @@ TEST(MediaSequenceTest, ReconcileMetadataImagestoLabels) { AddImageTimestamp(4, &sequence); AddImageTimestamp(5, &sequence); - MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, &sequence)); + MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, false, &sequence)); ASSERT_THAT(GetSegmentStartIndex(sequence), testing::ElementsAreArray({2, 3})); ASSERT_THAT(GetSegmentEndIndex(sequence), testing::ElementsAreArray({3, 4})); @@ -617,7 +617,7 @@ TEST(MediaSequenceTest, ReconcileMetadataImages) { AddImageTimestamp(1000000, &sequence); AddImageTimestamp(2000000, &sequence); - MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, &sequence)); + MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, false, &sequence)); ASSERT_EQ(GetContext(sequence, kImageFormatKey).bytes_list().value(0), "JPEG"); ASSERT_EQ(GetContext(sequence, kImageChannelsKey).int64_list().value(0), 3); @@ -638,7 +638,7 @@ TEST(MediaSequenceTest, ReconcileMetadataImagesPNG) { AddImageTimestamp(1000000, &sequence); AddImageTimestamp(2000000, &sequence); - MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, &sequence)); + MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, false, &sequence)); ASSERT_EQ(GetContext(sequence, kImageFormatKey).bytes_list().value(0), "PNG"); ASSERT_EQ(GetContext(sequence, kImageChannelsKey).int64_list().value(0), 3); ASSERT_EQ(GetContext(sequence, kImageWidthKey).int64_list().value(0), 3); @@ -659,7 +659,7 @@ TEST(MediaSequenceTest, ReconcileMetadataFlowEncoded) { AddForwardFlowTimestamp(1000000, &sequence); AddForwardFlowTimestamp(2000000, &sequence); - MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, &sequence)); + MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, false, &sequence)); 
ASSERT_EQ(GetForwardFlowFormat(sequence), "JPEG"); ASSERT_EQ(GetForwardFlowChannels(sequence), 3); ASSERT_EQ(GetForwardFlowWidth(sequence), 3); @@ -676,7 +676,7 @@ TEST(MediaSequenceTest, ReconcileMetadataFloats) { AddFeatureTimestamp(feature_name, 1000000, &sequence); AddFeatureTimestamp(feature_name, 2000000, &sequence); - MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, &sequence)); + MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, false, &sequence)); ASSERT_EQ(GetFeatureDimensions(feature_name, sequence).size(), 1); ASSERT_EQ(GetFeatureDimensions(feature_name, sequence)[0], 3); ASSERT_EQ(GetFeatureRate(feature_name, sequence), 1.0); @@ -692,7 +692,7 @@ TEST(MediaSequenceTest, ReconcileMetadataFloatsDoesntOverwrite) { AddFeatureTimestamp(feature_name, 1000000, &sequence); AddFeatureTimestamp(feature_name, 2000000, &sequence); - MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, &sequence)); + MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, false, &sequence)); ASSERT_EQ(GetFeatureDimensions(feature_name, sequence).size(), 3); ASSERT_EQ(GetFeatureDimensions(feature_name, sequence)[0], 1); ASSERT_EQ(GetFeatureDimensions(feature_name, sequence)[1], 3); @@ -710,7 +710,7 @@ TEST(MediaSequenceTest, ReconcileMetadataFloatsFindsMismatch) { AddFeatureTimestamp(feature_name, 1000000, &sequence); AddFeatureTimestamp(feature_name, 2000000, &sequence); - ASSERT_FALSE(ReconcileMetadata(true, &sequence).ok()); + ASSERT_FALSE(ReconcileMetadata(true, false, &sequence).ok()); } TEST(MediaSequenceTest, @@ -735,7 +735,7 @@ TEST(MediaSequenceTest, AddBBox(bboxes[1], &sequence); AddBBox(bboxes[2], &sequence); - MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, &sequence)); + MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, false, &sequence)); ASSERT_EQ(GetBBoxTimestampSize(sequence), 3); ASSERT_EQ(GetBBoxTimestampAt(sequence, 0), 10); @@ -753,7 +753,7 @@ TEST(MediaSequenceTest, ASSERT_EQ(GetUnmodifiedBBoxTimestampAt(sequence, 1), 21); // A second reconciliation should not corrupt unmodified bbox timestamps. 
- MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, &sequence)); + MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, false, &sequence)); ASSERT_EQ(GetBBoxTimestampSize(sequence), 3); ASSERT_EQ(GetBBoxTimestampAt(sequence, 0), 10); @@ -788,7 +788,7 @@ TEST(MediaSequenceTest, ReconcileMetadataBoxAnnotationsFillsMissing) { AddBBox(bboxes[1], &sequence); AddBBox(bboxes[2], &sequence); - MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, &sequence)); + MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, false, &sequence)); ASSERT_EQ(GetBBoxTimestampSize(sequence), 5); ASSERT_EQ(GetBBoxIsAnnotatedSize(sequence), 5); @@ -857,7 +857,7 @@ TEST(MediaSequenceTest, ReconcileMetadataBoxAnnotationsUpdatesAllFeatures) { AddBBox(bboxes[0], &sequence); AddBBox(bboxes[1], &sequence); - MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, &sequence)); + MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, false, &sequence)); ASSERT_EQ(GetBBoxTimestampSize(sequence), 5); ASSERT_EQ(GetBBoxIsAnnotatedSize(sequence), 5); @@ -987,7 +987,7 @@ TEST(MediaSequenceTest, ReconcileMetadataBoxAnnotationsDoesNotAddFields) { AddBBox(bboxes[1], &sequence); AddBBox(bboxes[2], &sequence); - MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, &sequence)); + MEDIAPIPE_ASSERT_OK(ReconcileMetadata(true, false, &sequence)); ASSERT_EQ(GetBBoxTimestampSize(sequence), 5); ASSERT_EQ(GetBBoxIsAnnotatedSize(sequence), 5); ASSERT_FALSE(HasBBoxClassIndex(sequence)); @@ -996,6 +996,45 @@ TEST(MediaSequenceTest, ReconcileMetadataBoxAnnotationsDoesNotAddFields) { ASSERT_FALSE(HasBBoxClassString(sequence)); ASSERT_FALSE(HasBBoxTrackString(sequence)); } + +TEST(MediaSequenceTest, ReconcileMetadataRegionAnnotations) { + // Need image timestamps and label timestamps. + tensorflow::SequenceExample sequence; + + // Skip 0, so the indices are the (timestamp - 10) / 10 + AddImageTimestamp(10, &sequence); + AddImageTimestamp(20, &sequence); + AddImageTimestamp(30, &sequence); + + AddBBoxTimestamp(9, &sequence); + AddBBoxTimestamp(21, &sequence); + AddBBoxTimestamp(22, &sequence); // Will be dropped in the output. + + AddBBoxTimestamp("PREFIX", 8, &sequence); // Will be dropped in the output. + AddBBoxTimestamp("PREFIX", 9, &sequence); + AddBBoxTimestamp("PREFIX", 22, &sequence); + + // Expect both the default and "PREFIX"-ed keys to be reconciled. 
+ MEDIAPIPE_ASSERT_OK(ReconcileMetadata(false, true, &sequence)); + ASSERT_EQ(GetBBoxTimestampSize(sequence), 3); + ASSERT_EQ(GetBBoxIsAnnotatedSize(sequence), 3); + ASSERT_EQ(GetBBoxTimestampSize("PREFIX", sequence), 3); + ASSERT_EQ(GetBBoxIsAnnotatedSize("PREFIX", sequence), 3); + + ASSERT_EQ(GetBBoxTimestampAt(sequence, 0), 10); + ASSERT_EQ(GetBBoxTimestampAt(sequence, 1), 20); + ASSERT_EQ(GetBBoxTimestampAt(sequence, 2), 30); + ASSERT_EQ(GetUnmodifiedBBoxTimestampSize(sequence), 2); + ASSERT_EQ(GetUnmodifiedBBoxTimestampAt(sequence, 0), 9); + ASSERT_EQ(GetUnmodifiedBBoxTimestampAt(sequence, 1), 21); + + ASSERT_EQ(GetBBoxTimestampAt("PREFIX", sequence, 0), 10); + ASSERT_EQ(GetBBoxTimestampAt("PREFIX", sequence, 1), 20); + ASSERT_EQ(GetBBoxTimestampAt("PREFIX", sequence, 2), 30); + ASSERT_EQ(GetUnmodifiedBBoxTimestampSize("PREFIX", sequence), 2); + ASSERT_EQ(GetUnmodifiedBBoxTimestampAt("PREFIX", sequence, 0), 9); + ASSERT_EQ(GetUnmodifiedBBoxTimestampAt("PREFIX", sequence, 1), 22); +} } // namespace } // namespace mediasequence } // namespace mediapipe diff --git a/mediapipe/util/sequence/media_sequence_util.h b/mediapipe/util/sequence/media_sequence_util.h index 0d0cd3e72..531605aba 100644 --- a/mediapipe/util/sequence/media_sequence_util.h +++ b/mediapipe/util/sequence/media_sequence_util.h @@ -125,7 +125,8 @@ inline const tensorflow::Feature& GetContext( // print the missing key when it check-fails. const auto it = sequence.context().feature().find(key); CHECK(it != sequence.context().feature().end()) - << "Could not find context key " << key; + << "Could not find context key " << key << ". Sequence: \n" + << sequence.DebugString(); return it->second; } @@ -219,7 +220,8 @@ inline const proto_ns::RepeatedField& GetFloatsAt( const tensorflow::SequenceExample& sequence, const std::string& key, const int index) { const tensorflow::FeatureList& fl = GetFeatureList(sequence, key); - CHECK_LT(index, fl.feature_size()); + CHECK_LT(index, fl.feature_size()) + << "Sequence: \n " << sequence.DebugString(); return fl.feature().Get(index).float_list().value(); } @@ -229,7 +231,8 @@ inline const proto_ns::RepeatedField& GetInt64sAt( const tensorflow::SequenceExample& sequence, const std::string& key, const int index) { const tensorflow::FeatureList& fl = GetFeatureList(sequence, key); - CHECK_LT(index, fl.feature_size()); + CHECK_LT(index, fl.feature_size()) + << "Sequence: \n " << sequence.DebugString(); return fl.feature().Get(index).int64_list().value(); } @@ -239,7 +242,8 @@ inline const proto_ns::RepeatedPtrField& GetBytesAt( const tensorflow::SequenceExample& sequence, const std::string& key, const int index) { const tensorflow::FeatureList& fl = GetFeatureList(sequence, key); - CHECK_LT(index, fl.feature_size()); + CHECK_LT(index, fl.feature_size()) + << "Sequence: \n " << sequence.DebugString(); return fl.feature().Get(index).bytes_list().value(); } diff --git a/mediapipe/util/time_series_test_util.h b/mediapipe/util/time_series_test_util.h new file mode 100644 index 000000000..0371ec42b --- /dev/null +++ b/mediapipe/util/time_series_test_util.h @@ -0,0 +1,522 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_UTIL_TIME_SERIES_TEST_UTIL_H_ +#define MEDIAPIPE_UTIL_TIME_SERIES_TEST_UTIL_H_ + +#include +#include +#include + +#include "Eigen/Core" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/message_matchers.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/util/time_series_util.h" + +namespace mediapipe { + +// Base class for testing Calculators that operate on TimeSeries inputs. +// Subclasses that do not need a special options proto should inherit from +// the subclass BasicTimeSeriesCalculatorTestBase. +// +// This class handles calculators that accept one or more input streams +// specified either by indices or by tags and produce one or more output +// streams, again either specified by indices or tags. +// The default is to use one input stream and one output stream, specified by +// index. To use more streams by index, set num_input_streams_ or +// num_output_streams for the number of input or output streams, respectively. +// These have to be set before calling InitializeGraph(). To use one or more +// streams by tag, set input_stream_tags_ or output_stream_tags_ before calling +// InitializeGraph(), for example: +// input_stream_tags_ = {"MATRIX", "FRAMES"}; +// output_stream_tags_ = {"MATRIX"}; +// InitializeGraph(); +// These options are exclusive since mediapipe requires calculators to use +// either indices or tags, but not both. +template +class TimeSeriesCalculatorTest : public ::testing::Test { + protected: + // Sentinal value which can be used to tell methods like + // PopulateHeader to ignore certain fields. + static constexpr int kUnset = 0; + + TimeSeriesCalculatorTest() + : num_side_packets_(0), + num_input_streams_(1), + num_output_streams_(1), + input_packet_rate_(kUnset), + num_input_samples_(kUnset), + audio_sample_rate_(kUnset) {} + + // Makes the input stream names used in CalculatorRunner runner_. + // If tags are used, that is, input_stream_tags_ is not empty, it returns + // names of the form: + // :_, + // :_, etc. + // For the index format, returns names of the form _0, + // _1, etc. + std::vector MakeInputStreamNames(const std::string& base_name) { + if (!input_stream_tags_.empty()) { + return MakeNames(base_name, input_stream_tags_); + } else { + return MakeNames(base_name, num_input_streams_); + } + return std::vector(); + } + + // Same as MakeInputStreamNames, but for output streams. 
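As a worked example of the naming scheme (derived from the MakeNames() helpers below): with input_stream_tags_ = {"MATRIX", "FRAMES"} and base name "input_stream", the generated stream names are "MATRIX:input_stream_matrix" and "FRAMES:input_stream_frames"; with num_input_streams_ = 2 and no tags, they are "input_stream_0" and "input_stream_1".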
+ std::vector MakeOutputStreamNames(const std::string& base_name) { + if (!output_stream_tags_.empty()) { + return MakeNames(base_name, output_stream_tags_); + } else { + return MakeNames(base_name, num_output_streams_); + } + return std::vector(); + } + + // Makes names used in CalculatorRunner runner_ that use the tag format. Tags + // must be capitalized. Returns names of the form + // :_, + // :_, etc. + std::vector MakeNames(const std::string& base_name, + const std::vector& tags) { + std::vector base_names; + std::vector ids; + for (const std::string& tag : tags) { + const std::string tagged_base_name = absl::StrCat(tag, ":", base_name); + base_names.push_back(tagged_base_name); + + std::string id; + id.reserve(tag.size()); + for (std::size_t i = 0; i < tag.size(); ++i) { + id += std::tolower(tag[i]); + } + ids.push_back(id); + } + const std::vector names = MakeNames(base_names, ids); + return names; + } + + // Makes names used in CalculatorRunner runner_ that use the index format. + // Total is the number of names to create. Returns names of the form + // _0, _1, ..., _. + std::vector MakeNames(const std::string& base_name, + const int total) { + std::vector base_names; + std::vector ids; + for (int i = 0; i < total; ++i) { + const std::string id = absl::StrCat(i); + ids.push_back(id); + base_names.push_back(base_name); + } + const std::vector names = MakeNames(base_names, ids); + return names; + } + + // Makes names used in CalculatorRunner runner_. Returns names of the form + // _, _, etc. + std::vector MakeNames(const std::vector& base_names, + const std::vector& ids) { + CHECK_EQ(base_names.size(), ids.size()); + std::vector names; + for (int i = 0; i < base_names.size(); ++i) { + const std::string name_template = R"($0_$1)"; + const std::string& base_name = base_names[i]; + const std::string& id = ids[i]; + const std::string name = absl::Substitute(name_template, base_name, id); + names.push_back(name); + } + return names; + } + + // Makes the CalculatorGraphConfig used to initialize CalculatorRunner + // runner_. If no options are needed, pass the empty std::string for options. 
+ CalculatorGraphConfig::Node MakeNodeConfig(const std::string& calculator_name, + const int num_side_packets, + const CalculatorOptions& options) { + CalculatorGraphConfig::Node node_config; + node_config.set_calculator(calculator_name); + CalculatorOptions* node_config_options = node_config.mutable_options(); + *node_config_options = options; + + const std::string input_stream_base_name = "input_stream"; + const std::vector input_stream_names = + MakeInputStreamNames(input_stream_base_name); + for (const std::string& input_stream_name : input_stream_names) { + node_config.add_input_stream(input_stream_name); + } + + const std::string input_side_packet_base_name = "input_side_packet"; + const std::vector input_side_packet_names = + MakeNames(input_side_packet_base_name, num_side_packets); + for (const std::string& input_side_packet_name : input_side_packet_names) { + node_config.add_input_side_packet(input_side_packet_name); + } + + const std::string output_stream_base_name = "output_stream"; + const std::vector output_stream_names = + MakeOutputStreamNames(output_stream_base_name); + for (const std::string& output_stream_name : output_stream_names) { + node_config.add_output_stream(output_stream_name); + } + return node_config; + } + + void InitializeGraph(const CalculatorOptions& options) { + if (num_external_inputs_ != -1) { + LOG(WARNING) << "Use num_side_packets_ instead of num_external_inputs_."; + num_side_packets_ = num_external_inputs_; + } + + if (!input_stream_tags_.empty()) { + num_input_streams_ = input_stream_tags_.size(); + } + + if (!output_stream_tags_.empty()) { + num_output_streams_ = output_stream_tags_.size(); + } + + const CalculatorGraphConfig::Node node_config = + MakeNodeConfig(calculator_name_, num_side_packets_, options); + runner_.reset(new CalculatorRunner(node_config)); + } + + void InitializeGraph() { + CalculatorOptions options; + FillOptionsExtension(&options); + InitializeGraph(options); + } + + // Provide an alternative to InitializeGraph for tests that want the + // options not to be set. + void InitializeGraphWithoutOptions() { + CalculatorOptions options; // Left empty. + InitializeGraph(options); + } + + // This is broken out into a separate function to facilitate the + // NoOptions specialization defined below. 
+ void FillOptionsExtension(CalculatorOptions* options) { + options->MutableExtension(OptionsClass::ext)->MergeFrom(options_); + } + + void PopulateHeader(TimeSeriesHeader* header) { + header->set_num_channels(num_input_channels_); + header->set_sample_rate(input_sample_rate_); + if (num_input_samples_ != kUnset) { + header->set_num_samples(num_input_samples_); + } + if (input_packet_rate_ != kUnset) { + header->set_packet_rate(input_packet_rate_); + } + if (audio_sample_rate_ != kUnset) { + header->set_audio_sample_rate(audio_sample_rate_); + } + } + + std::unique_ptr CreateInputHeader() { + std::unique_ptr header(new TimeSeriesHeader); + PopulateHeader(header.get()); + return header; + } + + void FillInputHeader(const size_t input_index = 0) { + runner_->MutableInputs()->Index(input_index).header = + Adopt(CreateInputHeader().release()); + } + + void FillInputHeader(const std::string& input_tag) { + runner_->MutableInputs()->Tag(input_tag).header = + Adopt(CreateInputHeader().release()); + } + + template + std::unique_ptr + CreateInputHeaderWithExtension( + const TimeSeriesHeaderExtensionClass& extension) { + auto header = CreateInputHeader(); + time_series_util::SetExtensionInHeader(extension, header.get()); + return header; + } + + template + void FillInputHeaderWithExtension( + const TimeSeriesHeaderExtensionClass& extension, + const size_t input_index = 0) { + auto header = CreateInputHeaderWithExtension(extension); + runner_->MutableInputs()->Index(input_index).header = + Adopt(header.release()); + } + + template + void FillInputHeaderWithExtension( + const TimeSeriesHeaderExtensionClass& extension, + const std::string& input_tag) { + auto header = CreateInputHeaderWithExtension(extension); + runner_->MutableInputs()->Tag(input_tag).header = Adopt(header.release()); + } + + // Takes ownership of payload. 
+ template + void AppendInputPacket(const T* payload, const Timestamp timestamp, + const size_t input_index = 0) { + runner_->MutableInputs() + ->Index(input_index) + .packets.push_back(Adopt(payload).At(timestamp)); + } + + // Overload to allow explicit conversion from int64 to Timestamp + template + void AppendInputPacket(const T* payload, const int64 timestamp, + const size_t input_index = 0) { + AppendInputPacket(payload, Timestamp(timestamp), input_index); + } + + template + void AppendInputPacket(const T* payload, const Timestamp timestamp, + const std::string& input_tag) { + runner_->MutableInputs()->Tag(input_tag).packets.push_back( + Adopt(payload).At(timestamp)); + } + + template + void AppendInputPacket(const T* payload, const int64 timestamp, + const std::string& input_tag) { + AppendInputPacket(payload, Timestamp(timestamp), input_tag); + } + + ::mediapipe::Status RunGraph() { return runner_->Run(); } + + bool HasInputHeader(const size_t input_index = 0) const { + return input(input_index) + .header.template ValidateAsType() + .ok(); + } + + bool HasOutputHeader() const { + return output().header.template ValidateAsType().ok(); + } + + template + void ExpectOutputHeaderEquals(const Proto& expected, + const size_t output_index = 0) const { + EXPECT_THAT(output(output_index).header.template Get(), + mediapipe::EqualsProto(expected)); + } + + void ExpectOutputHeaderEqualsInputHeader( + const size_t input_index = 0, const size_t output_index = 0) const { + EXPECT_THAT( + output(output_index).header.template Get(), + mediapipe::EqualsProto( + input(input_index).header.template Get())); + } + + void ExpectOutputHeaderEqualsInputHeader( + const std::string& input_tag, const size_t output_index = 0) const { + EXPECT_THAT(output(output_index).header.template Get(), + mediapipe::EqualsProto( + input(input_tag).header.template Get())); + } + + void ExpectOutputHeaderEqualsInputHeader( + const size_t input_index, const std::string& output_tag) const { + EXPECT_THAT( + output(output_tag).header.template Get(), + mediapipe::EqualsProto( + input(input_index).header.template Get())); + } + + void ExpectOutputHeaderEqualsInputHeader( + const std::string& input_tag, const std::string& output_tag) const { + EXPECT_THAT(output(output_tag).header.template Get(), + mediapipe::EqualsProto( + input(input_tag).header.template Get())); + } + + void ExpectApproximatelyEqual(const Matrix& expected, + const Matrix& actual) const { + const float kPrecision = 1e-6; + EXPECT_TRUE(actual.isApprox(expected, kPrecision)) + << "Expected: " << expected << ", but got: " << actual; + } + + const CalculatorRunner::StreamContents& input( + const size_t input_index = 0) const { + return runner_->MutableInputs()->Index(input_index); + } + + const CalculatorRunner::StreamContents& input( + const std::string& input_tag) const { + return runner_->MutableInputs()->Tag(input_tag); + } + + const CalculatorRunner::StreamContents& output( + const size_t output_index = 0) const { + return runner_->Outputs().Index(output_index); + } + + const CalculatorRunner::StreamContents& output( + const std::string& output_tag) const { + return runner_->Outputs().Tag(output_tag); + } + + // Caller takes ownership of the return value. + static Matrix* NewRandomMatrix(int num_channels, int num_samples) { + // TODO: Fix a consistent lack of random seed setting in tests. 
+ auto matrix = new Matrix; + matrix->setRandom(num_channels, num_samples); + return matrix; + } + + std::string calculator_name_; + OptionsClass options_; + int num_side_packets_; + int num_input_streams_; + std::vector input_stream_tags_; + int num_output_streams_; + std::vector output_stream_tags_; + // TODO For backwards compatibility, remove after all clients + // are updated. + int num_external_inputs_ = -1; + int num_input_channels_; + double input_sample_rate_; + // If this is non-zero, it will be used to set the packet_rate field + // of the header proto. + double input_packet_rate_; + // If this is non-zero, it will be used to set the num_samples field + // of the header proto. + int num_input_samples_; + // If this is non-zero, it will be used to set the audio_sample_rate field + // of the header proto. + double audio_sample_rate_; + + std::unique_ptr runner_; +}; + +template +class MultiStreamTimeSeriesCalculatorTest + : public TimeSeriesCalculatorTest { + protected: + void FillInputHeader() { + std::unique_ptr header( + new MultiStreamTimeSeriesHeader); + PopulateHeader(header.get()); + this->runner_->MutableInputs()->Index(0).header = Adopt(header.release()); + } + + template + void FillInputHeaderWithExtension( + const TimeSeriesHeaderExtensionClass& extension) { + std::unique_ptr header( + new MultiStreamTimeSeriesHeader); + PopulateHeader(header.get()); + time_series_util::SetExtensionInHeader( + extension, header->mutable_time_series_header()); + this->runner_->MutableInputs()->Index(0).header = Adopt(header.release()); + } + + // Takes ownership of input_vector. + void AppendInputPacket(const std::vector* input_vector, + const Timestamp timestamp) { + this->runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(input_vector).At(timestamp)); + } + + // Overload to allow explicit conversion from int64 to Timestamp + void AppendInputPacket(const std::vector* input_vector, + const int64 timestamp) { + AppendInputPacket(input_vector, Timestamp(timestamp)); + } + + template + void ExpectOutputHeaderEquals(const StringOrProto& expected) const { + EXPECT_THAT( + this->output().header.template Get(), + mediapipe::EqualsProto(expected)); + } + + void ExpectOutputHeaderEqualsInputHeader() const { + ExpectOutputHeaderEquals( + this->input().header.template Get()); + } + + int num_input_streams_; + + private: + void PopulateHeader(MultiStreamTimeSeriesHeader* header) { + TimeSeriesCalculatorTest::PopulateHeader( + header->mutable_time_series_header()); + header->set_num_streams(num_input_streams_); + } +}; + +struct NoOptions {}; + +template <> +void TimeSeriesCalculatorTest::FillOptionsExtension( + CalculatorOptions* options) {} + +// Base class for testing basic time series calculators, which are calculators +// that take no options. 
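Putting the fixture's pieces together, a typical calculator test built on this base might look roughly like the sketch below. MyCalculator and MyCalculatorOptions are placeholders rather than names introduced by this change, and the options proto is assumed to define the usual ext extension of CalculatorOptions:

    class MyCalculatorTest : public TimeSeriesCalculatorTest<MyCalculatorOptions> {
     protected:
      void SetUp() override {
        calculator_name_ = "MyCalculator";  // Hypothetical calculator under test.
        num_input_channels_ = 2;
        input_sample_rate_ = 16000.0;
      }
    };

    TEST_F(MyCalculatorTest, PassesHeaderThrough) {
      InitializeGraph();
      FillInputHeader();
      // The fixture takes ownership of the matrix passed to AppendInputPacket.
      AppendInputPacket(NewRandomMatrix(2, 10), Timestamp(0));
      MEDIAPIPE_ASSERT_OK(RunGraph());
      ExpectOutputHeaderEqualsInputHeader();
    }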
+class BasicTimeSeriesCalculatorTestBase + : public TimeSeriesCalculatorTest { + protected: + TimeSeriesHeader ParseTextFormat(const std::string& text_format) { + TimeSeriesHeader header = + ParseTextProtoOrDie(text_format); + return header; + } + + void Test(const TimeSeriesHeader& input_header, + const std::vector& input_packets, + const TimeSeriesHeader& expected_output_header, + const std::vector& expected_output_packets) { + InitializeGraph(); + runner_->MutableInputs()->Index(0).header = + Adopt(new TimeSeriesHeader(input_header)); + for (int i = 0; i < input_packets.size(); ++i) { + const Timestamp timestamp(i * Timestamp::kTimestampUnitsPerSecond); + AppendInputPacket(new Matrix(input_packets[i]), timestamp); + } + + MEDIAPIPE_ASSERT_OK(RunGraph()); + + ExpectOutputHeaderEquals(expected_output_header); + EXPECT_EQ(input().packets.size(), output().packets.size()); + ASSERT_EQ(output().packets.size(), expected_output_packets.size()); + for (int i = 0; i < output().packets.size(); ++i) { + EXPECT_EQ(input().packets[i].Timestamp(), + output().packets[i].Timestamp()); + ExpectApproximatelyEqual(expected_output_packets[i], + output().packets[i].Get()); + } + } +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_UTIL_TIME_SERIES_TEST_UTIL_H_ diff --git a/mediapipe/util/time_series_util.cc b/mediapipe/util/time_series_util.cc new file mode 100644 index 000000000..69f5d2587 --- /dev/null +++ b/mediapipe/util/time_series_util.cc @@ -0,0 +1,135 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/util/time_series_util.h" + +#include + +#include +#include + +#include "absl/strings/str_cat.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/port/logging.h" + +namespace mediapipe { +namespace time_series_util { + +bool LogWarningIfTimestampIsInconsistent(const Timestamp& current_timestamp, + const Timestamp& initial_timestamp, + int64 cumulative_samples, + double sample_rate) { + // Ignore the "special" timestamp value Done(). + if (current_timestamp == Timestamp::Done()) return true; + // Don't accept other special timestamp values. We may need to change this + // depending on how they're used in practice. + if (!current_timestamp.IsRangeValue()) { + LOG(WARNING) << "Unexpected special timestamp: " + << current_timestamp.DebugString(); + return false; + } + + // For non-special timestamp values, check whether the number of + // samples that have been processed is consistent with amount of + // time that has elapsed. 
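For instance, with an initial timestamp of 1 s, a sample rate of 10000 Hz, and 10000 cumulative samples, the expected timestamp is 1 + 10000 / 10000 = 2 s, so a current timestamp of 2 s is consistent; with 10001 samples the expectation moves to 2.0001 s, and the 0.1 ms discrepancy exceeds the half-sample-period tolerance of 0.05 ms, so the warning fires. These are exactly the two cases covered by the new unit test at the end of this change.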
+ double expected_timestamp_seconds = + initial_timestamp.Seconds() + cumulative_samples / sample_rate; + if (fabs(current_timestamp.Seconds() - expected_timestamp_seconds) > + 0.5 / sample_rate) { + LOG_EVERY_N(WARNING, 20) + << std::fixed << "Timestamp " << current_timestamp.Seconds() + << " not consistent with number of samples " << cumulative_samples + << " and initial timestamp " << initial_timestamp + << ". Expected timestamp: " << expected_timestamp_seconds + << " Timestamp difference: " + << current_timestamp.Seconds() - expected_timestamp_seconds + << " sample_rate: " << sample_rate; + return false; + } else { + return true; + } +} + +::mediapipe::Status IsTimeSeriesHeaderValid(const TimeSeriesHeader& header) { + if (header.has_sample_rate() && header.sample_rate() >= 0 && + header.has_num_channels() && header.num_channels() >= 0) { + return ::mediapipe::OkStatus(); + } else { + std::string error_message = + "TimeSeriesHeader is missing necessary fields: " + "sample_rate or num_channels, or one of their values is negative. "; +#ifndef MEDIAPIPE_MOBILE + absl::StrAppend(&error_message, "Got header:\n", header.ShortDebugString()); +#endif + return tool::StatusInvalid(error_message); + } +} + +::mediapipe::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, + TimeSeriesHeader* header) { + CHECK(header); + if (header_packet.IsEmpty()) { + return tool::StatusFail("No header found."); + } + if (!header_packet.ValidateAsType().ok()) { + return tool::StatusFail("Packet does not contain TimeSeriesHeader."); + } + *header = header_packet.Get(); + return IsTimeSeriesHeaderValid(*header); +} + +::mediapipe::Status FillMultiStreamTimeSeriesHeaderIfValid( + const Packet& header_packet, MultiStreamTimeSeriesHeader* header) { + CHECK(header); + if (header_packet.IsEmpty()) { + return tool::StatusFail("No header found."); + } + if (!header_packet.ValidateAsType().ok()) { + return tool::StatusFail( + "Packet does not contain MultiStreamTimeSeriesHeader."); + } + *header = header_packet.Get(); + if (!header->has_time_series_header()) { + return tool::StatusFail("No time series header found."); + } + return IsTimeSeriesHeaderValid(header->time_series_header()); +} + +::mediapipe::Status IsMatrixShapeConsistentWithHeader( + const Matrix& matrix, const TimeSeriesHeader& header) { + if (header.has_num_samples() && matrix.cols() != header.num_samples()) { + return tool::StatusInvalid(absl::StrCat( + "Matrix size is inconsistent with header. Expected ", + header.num_samples(), " columns, but found ", matrix.cols())); + } + if (header.has_num_channels() && matrix.rows() != header.num_channels()) { + return tool::StatusInvalid(absl::StrCat( + "Matrix size is inconsistent with header. Expected ", + header.num_channels(), " rows, but found ", matrix.rows())); + } + return ::mediapipe::OkStatus(); +} + +int64 SecondsToSamples(double time_in_seconds, double sample_rate) { + return round(time_in_seconds * sample_rate); +} + +double SamplesToSeconds(int64 num_samples, double sample_rate) { + DCHECK_NE(sample_rate, 0.0); + return (num_samples / sample_rate); +} + +} // namespace time_series_util +} // namespace mediapipe diff --git a/mediapipe/util/time_series_util.h b/mediapipe/util/time_series_util.h new file mode 100644 index 000000000..f881dd144 --- /dev/null +++ b/mediapipe/util/time_series_util.h @@ -0,0 +1,122 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Utility functions for MediaPipe time series streams. + +#ifndef MEDIAPIPE_UTIL_TIME_SERIES_UTIL_H_ +#define MEDIAPIPE_UTIL_TIME_SERIES_UTIL_H_ + +#include +#include + +#include "absl/strings/str_cat.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { +namespace time_series_util { + +// Logs a warning and returns false if the current_timestamp is +// inconsistent with the cumulative_samples that have been processed +// so far, assuming a constant sample_rate and an offset of +// initial_timestamp. +// +// "Special" timestamps are not considered consistent by this +// function. +bool LogWarningIfTimestampIsInconsistent(const Timestamp& current_timestamp, + const Timestamp& initial_timestamp, + int64 cumulative_samples, + double sample_rate); + +// Returns mediapipe::status::OK if the header is valid. Otherwise, returns a +// Status object with an error message. +::mediapipe::Status IsTimeSeriesHeaderValid(const TimeSeriesHeader& header); + +// Fills header and returns mediapipe::status::OK if the header is non-empty and +// valid. Otherwise, returns a Status object with an error message. +::mediapipe::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, + TimeSeriesHeader* header); + +// Fills header and returns mediapipe::status::OK if the header contains a +// non-empty and valid TimeSeriesHeader. Otherwise, returns a Status object with +// an error message. +::mediapipe::Status FillMultiStreamTimeSeriesHeaderIfValid( + const Packet& header_packet, MultiStreamTimeSeriesHeader* header); + +// Returns::mediapipe::Status::OK iff options contains an extension of type +// OptionsClass. +template +::mediapipe::Status HasOptionsExtension(const CalculatorOptions& options) { + if (options.HasExtension(OptionsClass::ext)) { + return ::mediapipe::OkStatus(); + } + std::string error_message = "Options proto does not contain extension "; + absl::StrAppend(&error_message, + MediaPipeTypeStringOrDemangled()); +#ifndef MEDIAPIPE_MOBILE + // Avoid lite proto APIs on mobile targets. + absl::StrAppend(&error_message, " : ", options.DebugString()); +#endif + return ::mediapipe::InvalidArgumentError(error_message); +} + +// Returns::mediapipe::Status::OK if the shape of 'matrix' is consistent +// with the num_samples and num_channels fields present in 'header'. +// The corresponding matrix dimensions of unset header fields are +// ignored, so e.g. an empty header (which is not valid according to +// FillTimeSeriesHeaderIfValid) is considered consistent with any matrix. 
+::mediapipe::Status IsMatrixShapeConsistentWithHeader(
+    const Matrix& matrix, const TimeSeriesHeader& header);
+
+template <typename OptionsClass>
+void FillOptionsExtensionOrDie(const CalculatorOptions& options,
+                               OptionsClass* extension) {
+  MEDIAPIPE_CHECK_OK(HasOptionsExtension<OptionsClass>(options));
+  extension->CopyFrom(options.GetExtension(OptionsClass::ext));
+}
+
+template <typename TimeSeriesHeaderExtensionClass>
+bool FillExtensionFromHeader(const TimeSeriesHeader& header,
+                             TimeSeriesHeaderExtensionClass* extension) {
+  if (header.HasExtension(TimeSeriesHeaderExtensionClass::time_series_ext)) {
+    extension->CopyFrom(
+        header.GetExtension(TimeSeriesHeaderExtensionClass::time_series_ext));
+    return true;
+  } else {
+    return false;
+  }
+}
+
+template <typename TimeSeriesHeaderExtensionClass>
+void SetExtensionInHeader(const TimeSeriesHeaderExtensionClass& extension,
+                          TimeSeriesHeader* header) {
+  header->MutableExtension(TimeSeriesHeaderExtensionClass::time_series_ext)
+      ->CopyFrom(extension);
+}
+
+// Converts from a time_in_seconds to an integer number of samples.
+int64 SecondsToSamples(double time_in_seconds, double sample_rate);
+
+// Converts from an integer number of samples to a time duration in seconds
+// spanned by the samples.
+double SamplesToSeconds(int64 num_samples, double sample_rate);
+
+}  // namespace time_series_util
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_UTIL_TIME_SERIES_UTIL_H_
diff --git a/mediapipe/util/time_series_util_test.cc b/mediapipe/util/time_series_util_test.cc
new file mode 100644
index 000000000..5f3e119b9
--- /dev/null
+++ b/mediapipe/util/time_series_util_test.cc
@@ -0,0 +1,198 @@
+// Copyright 2019 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mediapipe/util/time_series_util.h"
+
+#include "Eigen/Core"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/deps/message_matchers.h"
+#include "mediapipe/framework/formats/time_series_header.pb.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/status_matchers.h"
+
+namespace mediapipe {
+namespace time_series_util {
+namespace {
+
+TEST(TimeSeriesUtilTest, LogWarningIfTimestampIsInconsistent) {
+  // "Special" timestamps aren't considered consistent.
+  EXPECT_FALSE(LogWarningIfTimestampIsInconsistent(Timestamp::Unset(),
+                                                   Timestamp(0), 0, 1));
+
+  EXPECT_TRUE(LogWarningIfTimestampIsInconsistent(
+      Timestamp(2 * Timestamp::kTimestampUnitsPerSecond),
+      Timestamp(1 * Timestamp::kTimestampUnitsPerSecond), 10000, 10000.0));
+  EXPECT_FALSE(LogWarningIfTimestampIsInconsistent(
+      Timestamp(2 * Timestamp::kTimestampUnitsPerSecond),
+      Timestamp(1 * Timestamp::kTimestampUnitsPerSecond), 10001, 10000.0));
+}
+
+TEST(TimeSeriesUtilTest, FillTimeSeriesHeaderIfValid) {
+  {
+    Packet empty_packet;
+    TimeSeriesHeader header;
+    EXPECT_FALSE(FillTimeSeriesHeaderIfValid(empty_packet, &header).ok());
+  }
+
+  {
+    std::unique_ptr<TimeSeriesHeader> valid_header(new TimeSeriesHeader);
+    valid_header->set_sample_rate(1234.5);
+    valid_header->set_num_channels(3);
+    Packet valid_packet = Adopt(valid_header.release());
+    TimeSeriesHeader packet_header;
+    MEDIAPIPE_EXPECT_OK(
+        FillTimeSeriesHeaderIfValid(valid_packet, &packet_header));
+    EXPECT_EQ(packet_header.sample_rate(), 1234.5);
+    EXPECT_EQ(packet_header.num_channels(), 3);
+  }
+
+  {
+    std::unique_ptr<TimeSeriesHeader> missing_num_channels_header(
+        new TimeSeriesHeader);
+    missing_num_channels_header->set_sample_rate(1234.5);
+    Packet packet_with_missing_num_channel =
+        Adopt(missing_num_channels_header.release());
+    TimeSeriesHeader header;
+    EXPECT_FALSE(
+        FillTimeSeriesHeaderIfValid(packet_with_missing_num_channel, &header)
+            .ok());
+  }
+
+  {
+    std::unique_ptr<TimeSeriesHeader> missing_sample_rate_header(
+        new TimeSeriesHeader);
+    missing_sample_rate_header->set_num_channels(3);
+    Packet packet_with_missing_sample_rate =
+        Adopt(missing_sample_rate_header.release());
+    TimeSeriesHeader header;
+    EXPECT_FALSE(
+        FillTimeSeriesHeaderIfValid(packet_with_missing_sample_rate, &header)
+            .ok());
+  }
+}
+
+TEST(TimeSeriesUtilTest, FillMultiStreamTimeSeriesHeaderIfValid) {
+  {
+    Packet empty_packet;
+    MultiStreamTimeSeriesHeader header;
+    EXPECT_FALSE(
+        FillMultiStreamTimeSeriesHeaderIfValid(empty_packet, &header).ok());
+  }
+
+  {
+    Packet packet_with_empty_header = Adopt(new MultiStreamTimeSeriesHeader());
+    MultiStreamTimeSeriesHeader header;
+    EXPECT_FALSE(FillMultiStreamTimeSeriesHeaderIfValid(
+                     packet_with_empty_header, &header)
+                     .ok());
+  }
+
+  {
+    std::unique_ptr<MultiStreamTimeSeriesHeader> valid_header(
+        new MultiStreamTimeSeriesHeader);
+    valid_header->mutable_time_series_header()->set_sample_rate(1234.5);
+    valid_header->mutable_time_series_header()->set_num_channels(3);
+    Packet valid_packet = Adopt(valid_header.release());
+    MultiStreamTimeSeriesHeader packet_header;
+    MEDIAPIPE_EXPECT_OK(
+        FillMultiStreamTimeSeriesHeaderIfValid(valid_packet, &packet_header));
+    EXPECT_EQ(packet_header.time_series_header().sample_rate(), 1234.5);
+    EXPECT_EQ(packet_header.time_series_header().num_channels(), 3);
+  }
+
+  {
+    TimeSeriesHeader missing_num_channels_header;
+    std::unique_ptr<MultiStreamTimeSeriesHeader>
+        header_with_invalid_time_series_header(new MultiStreamTimeSeriesHeader);
+    header_with_invalid_time_series_header->mutable_time_series_header()
+        ->set_sample_rate(1234.5);
+    Packet packet_with_invalid_time_series_header =
+        Adopt(header_with_invalid_time_series_header.release());
+    MultiStreamTimeSeriesHeader header;
+    EXPECT_FALSE(FillMultiStreamTimeSeriesHeaderIfValid(
+                     packet_with_invalid_time_series_header, &header)
+                     .ok());
+  }
+}
+
+TEST(IsMatrixShapeConsistentWithHeaderTest, BasicOperation) {
+  TimeSeriesHeader header;
+  header.set_num_samples(2);
+  header.set_num_channels(3);
+
+  EXPECT_TRUE(
+      IsMatrixShapeConsistentWithHeader(Matrix::Zero(3, 2), header).ok());
+  EXPECT_FALSE(
+
IsMatrixShapeConsistentWithHeader(Matrix::Zero(0, 0), header).ok()); + // Transposed Matrix. + EXPECT_FALSE( + IsMatrixShapeConsistentWithHeader(Matrix::Zero(2, 3), header).ok()); +} + +TEST(IsMatrixShapeConsistentWithHeaderTest, + EmptyHeaderConsistentWithAnyMatrix) { + TimeSeriesHeader empty_header; + EXPECT_TRUE( + IsMatrixShapeConsistentWithHeader(Matrix::Zero(0, 0), empty_header).ok()); + EXPECT_TRUE( + IsMatrixShapeConsistentWithHeader(Matrix::Zero(3, 2), empty_header).ok()); +} + +TEST(IsMatrixShapeConsistentWithHeaderTest, + NumChannelsUnsetConsistentWithAnyNumRows) { + TimeSeriesHeader header; + header.set_num_channels(2); + for (int num_cols : {1, 2, 5, 9}) { + EXPECT_TRUE( + IsMatrixShapeConsistentWithHeader(Matrix::Zero(2, num_cols), header) + .ok()); + } +} + +TEST(IsMatrixShapeConsistentWithHeaderTest, + NumSamplesUnsetConsistentWithAnyNumColumns) { + TimeSeriesHeader header; + header.set_num_samples(2); + for (int num_rows : {1, 2, 5, 9}) { + EXPECT_TRUE( + IsMatrixShapeConsistentWithHeader(Matrix::Zero(num_rows, 2), header) + .ok()); + } +} + +TEST(TimeSeriesUtilTest, SecondsToSamples) { + // If the time is an integer multiple of the sampling period, we + // should get an exact result. + double sample_rate = 10.0; + double integer_multiple_time = 5; + EXPECT_EQ(integer_multiple_time * sample_rate, + SecondsToSamples(integer_multiple_time, sample_rate)); + + // Otherwise we should be within one sample. + double arbitrary_time = 5.01; + EXPECT_NEAR(arbitrary_time * sample_rate, + SecondsToSamples(arbitrary_time, sample_rate), 1); +} + +TEST(TimeSeriesUtilTest, SamplesToSeconds) { + double sample_rate = 32.5; + int64 num_samples = 128; + EXPECT_EQ(num_samples / sample_rate, + SamplesToSeconds(num_samples, sample_rate)); +} + +} // namespace +} // namespace time_series_util +} // namespace mediapipe diff --git a/setup_android_sdk_and_ndk.sh b/setup_android_sdk_and_ndk.sh index 319bd797d..0d6efadec 100644 --- a/setup_android_sdk_and_ndk.sh +++ b/setup_android_sdk_and_ndk.sh @@ -47,8 +47,8 @@ fi if [ -z $2 ] then - echo "Warning: android_ndk_path (argument 2) is not specified. Fallback to ~/Android/Ndk/android-ndk-/" - android_ndk_path=$HOME"/Android/Ndk" + echo "Warning: android_ndk_path (argument 2) is not specified. Fallback to ~/Android/Sdk/ndk-bundle/android-ndk-/" + android_ndk_path=$HOME"/Android/Sdk/ndk-bundle" fi if [ -z $3 ] @@ -67,7 +67,7 @@ else unzip /tmp/android_sdk/android_sdk.zip -d /tmp/android_sdk/ mkdir -p $android_sdk_path /tmp/android_sdk/tools/bin/sdkmanager --update - /tmp/android_sdk/tools/bin/sdkmanager "build-tools;28.0.3" "platform-tools" "platforms;android-28" "extras;android;m2repository" --sdk_root=${android_sdk_path} + /tmp/android_sdk/tools/bin/sdkmanager "build-tools;29.0.1" "platform-tools" "platforms;android-29" --sdk_root=${android_sdk_path} rm -rf /tmp/android_sdk/ echo "Android SDK is now installed. Consider setting \$ANDROID_HOME environment variable to be ${android_sdk_path}" fi diff --git a/setup_opencv.sh b/setup_opencv.sh index 63635851d..fcf1be7a5 100644 --- a/setup_opencv.sh +++ b/setup_opencv.sh @@ -33,11 +33,22 @@ sudo apt install cmake ffmpeg libavformat-dev libdc1394-22-dev libgtk2.0-dev \ rm -rf /tmp/build_opencv mkdir /tmp/build_opencv cd /tmp/build_opencv +git clone https://github.com/opencv/opencv_contrib.git git clone https://github.com/opencv/opencv.git mkdir opencv/release cd opencv/release cmake .. 
-DCMAKE_BUILD_TYPE=RELEASE -DCMAKE_INSTALL_PREFIX=/usr/local \ - -DBUILD_TESTS=OFF -DBUILD_PERF_TESTS=OFF -DBUILD_opencv_ts=OFF + -DBUILD_TESTS=OFF -DBUILD_PERF_TESTS=OFF -DBUILD_opencv_ts=OFF \ + -DOPENCV_EXTRA_MODULES_PATH=/tmp/build_opencv/opencv_contrib/modules \ + -DBUILD_opencv_aruco=OFF -DBUILD_opencv_bgsegm=OFF -DBUILD_opencv_bioinspired=OFF \ + -DBUILD_opencv_ccalib=OFF -DBUILD_opencv_datasets=OFF -DBUILD_opencv_dnn=OFF \ + -DBUILD_opencv_dnn_objdetect=OFF -DBUILD_opencv_dpm=OFF -DBUILD_opencv_face=OFF \ + -DBUILD_opencv_fuzzy=OFF -DBUILD_opencv_hfs=OFF -DBUILD_opencv_img_hash=OFF \ + -DBUILD_opencv_js=OFF -DBUILD_opencv_line_descriptor=OFF -DBUILD_opencv_phase_unwrapping=OFF \ + -DBUILD_opencv_plot=OFF -DBUILD_opencv_quality=OFF -DBUILD_opencv_reg=OFF \ + -DBUILD_opencv_rgbd=OFF -DBUILD_opencv_saliency=OFF -DBUILD_opencv_shape=OFF \ + -DBUILD_opencv_structured_light=OFF -DBUILD_opencv_surface_matching=OFF \ + -DBUILD_opencv_world=OFF -DBUILD_opencv_xobjdetect=OFF -DBUILD_opencv_xphoto=OFF make -j sudo make install rm -rf /tmp/build_opencv @@ -46,6 +57,7 @@ echo "OpenCV has been built. You can find the header files and libraries in /usr # Modify the build file. echo "Modifying MediaPipe opencv config" sed -i "s/lib\/x86_64-linux-gnu/local\/lib/g" $opencv_build_file +sed -i "20i \ \"local/lib/libopencv_optflow.so*\"," $opencv_build_file sed -i "/includes =/d" $opencv_build_file sed -i "/hdrs =/d" $opencv_build_file line_to_insert=$(grep -n 'linkstatic =' $opencv_build_file | awk -F ":" '{print $1}')'i' diff --git a/third_party/BUILD b/third_party/BUILD index 2821a0f95..65fda9d22 100644 --- a/third_party/BUILD +++ b/third_party/BUILD @@ -38,6 +38,9 @@ cc_library( "//mediapipe:android_arm64": [ "@android_opencv//:libopencv_arm64-v8a", ], + "//mediapipe:ios": [ + "@ios_opencv//:opencv", + ], "//mediapipe:macos": [ "@macos_opencv//:opencv", ], @@ -47,10 +50,29 @@ cc_library( }), ) +cc_library( + name = "libffmpeg", + visibility = ["//visibility:public"], + deps = select({ + "//mediapipe:android_x86": [], + "//mediapipe:android_x86_64": [], + "//mediapipe:android_armeabi": [], + "//mediapipe:android_arm": [], + "//mediapipe:android_arm64": [], + "//mediapipe:ios": [], + "//mediapipe:macos": [ + "@macos_ffmpeg//:libffmpeg", + ], + "//conditions:default": [ + "@linux_ffmpeg//:libffmpeg", + ], + }), +) + android_library( - name = "android_constraint_layout", + name = "androidx_annotation", exports = [ - "@maven//:com_android_support_constraint_constraint_layout", + "@maven//:androidx_annotation_annotation", ], ) @@ -61,6 +83,41 @@ android_library( ], ) +android_library( + name = "androidx_constraint_layout", + exports = [ + "@maven//:androidx_constraintlayout_constraintlayout", + ], +) + +android_library( + name = "androidx_core", + exports = [ + "@maven//:androidx_core_core", + ], +) + +android_library( + name = "androidx_legacy_support_v4", + exports = [ + "@maven//:androidx_legacy_legacy_support_v4", + ], +) + +android_library( + name = "androidx_material", + exports = [ + "@maven//:com_google_android_material_material", + ], +) + +android_library( + name = "androidx_recyclerview", + exports = [ + "@maven//:androidx_recyclerview_recyclerview", + ], +) + # TODO: Get the AARs from Google's Maven Repository. aar_import( name = "camerax_core", diff --git a/third_party/ffmpeg_linux.BUILD b/third_party/ffmpeg_linux.BUILD new file mode 100644 index 000000000..66179df6d --- /dev/null +++ b/third_party/ffmpeg_linux.BUILD @@ -0,0 +1,35 @@ +# Copyright 2019 The MediaPipe Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) # LGPL + +exports_files(["LICENSE"]) + +cc_library( + name = "libffmpeg", + srcs = glob( + [ + "lib/x86_64-linux-gnu/libav*.so*", + ], + ), + hdrs = glob(["include/x86_64-linux-gnu/libav*/*.h"]), + includes = ["include"], + linkopts = [ + "-lavcodec", + "-lavformat", + "-lavutil", + ], + linkstatic = 1, + visibility = ["//visibility:public"], +) diff --git a/third_party/ffmpeg_macos.BUILD b/third_party/ffmpeg_macos.BUILD new file mode 100644 index 000000000..6e6f94aa5 --- /dev/null +++ b/third_party/ffmpeg_macos.BUILD @@ -0,0 +1,35 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) # LGPL + +exports_files(["LICENSE"]) + +cc_library( + name = "libffmpeg", + srcs = glob( + [ + "local/opt/ffmpeg/lib/libav*.dylib", + ], + ), + hdrs = glob(["local/opt/ffmpeg/include/libav*/*.h"]), + includes = ["local/opt/ffmpeg/include/"], + linkopts = [ + "-lavcodec", + "-lavformat", + "-lavutil", + ], + linkstatic = 1, + visibility = ["//visibility:public"], +) diff --git a/third_party/glog.BUILD b/third_party/glog.BUILD index fad57fbad..ea92ff38d 100644 --- a/third_party/glog.BUILD +++ b/third_party/glog.BUILD @@ -17,9 +17,41 @@ licenses(["notice"]) exports_files(["LICENSE"]) config_setting( - name = "android", + name = "android_arm", values = { - "crosstool_top": "//external:android/crosstool", + "cpu": "armeabi-v7a", + }, + visibility = ["//visibility:private"], +) + +config_setting( + name = "android_arm64", + values = { + "cpu": "arm64-v8a", + }, + visibility = ["//visibility:private"], +) + +config_setting( + name = "ios_armv7", + values = { + "cpu": "ios_armv7", + }, + visibility = ["//visibility:private"], +) + +config_setting( + name = "ios_arm64", + values = { + "cpu": "ios_arm64", + }, + visibility = ["//visibility:private"], +) + +config_setting( + name = "ios_arm64e", + values = { + "cpu": "ios_arm64e", }, visibility = ["//visibility:private"], ) @@ -28,6 +60,7 @@ config_setting( name = "libunwind", values = { "define": "libunwind=true", + "cpu": "k8", }, visibility = ["//visibility:private"], ) @@ -67,12 +100,20 @@ cc_library( "//conditions:default": [], }) + select({ "//conditions:default": ["-lpthread"], - ":android": [], + ":android_arm": [], + ":android_arm64": [], + ":ios_armv7": [], + ":ios_arm64": [], + ":ios_arm64e": [], }), visibility = ["//visibility:public"], deps = select({ "//conditions:default": ["@com_github_gflags_gflags//:gflags"], - ":android": [], + 
":android_arm": [], + ":android_arm64": [], + ":ios_armv7": [], + ":ios_arm64": [], + ":ios_arm64e": [], }), ) @@ -115,7 +156,11 @@ genrule( name = "config_h", srcs = select({ "//conditions:default": ["config.h.tmp"], - ":android": ["config.h.android"], + ":android_arm": ["config.h.android_arm"], + ":android_arm64": ["config.h.android_arm"], + ":ios_armv7": ["config.h.ios_arm"], + ":ios_arm64": ["config.h.ios_arm"], + ":ios_arm64e": ["config.h.ios_arm"], }), outs = ["config.h"], cmd = "echo select $< to be the glog config file. && cp $< $@", @@ -125,16 +170,17 @@ genrule( name = "logging_h", srcs = select({ "//conditions:default": ["src/glog/logging.h.tmp"], - ":android": ["src/glog/logging.h.android"], + ":android_arm": ["src/glog/logging.h.android_arm"], + ":android_arm64": ["src/glog/logging.h.android_arm"], }), outs = ["src/glog/logging.h"], cmd = "echo select $< to be the glog logging.h file. && cp $< $@", ) -# Hardcoded android config header for glog library. +# Hardcoded android arm config header for glog library. # TODO: This is a temporary workaround. We should generate the config # header by running the configure script with the right target toolchain. -ANDROID_CONFIG = """ +ANDROID_ARM_CONFIG = """ /* Define if glog does not use RTTI */ /* #undef DISABLE_RTTI */ @@ -319,15 +365,15 @@ your system. */ """ genrule( - name = "gen_android_config", - outs = ["config.h.android"], - cmd = ("echo '%s' > $(location config.h.android)" % ANDROID_CONFIG), + name = "gen_android_arm_config", + outs = ["config.h.android_arm"], + cmd = ("echo '%s' > $(location config.h.android_arm)" % ANDROID_ARM_CONFIG), ) genrule( - name = "generate_android_glog_logging_h", + name = "generate_android_arm_glog_logging_h", srcs = ["src/glog/logging.h.in"], - outs = ["src/glog/logging.h.android"], + outs = ["src/glog/logging.h.android_arm"], cmd = ("sed -e 's/@ac_cv___attribute___noinline@/__attribute__((__noinline__))/g'" + " -e 's/@ac_cv___attribute___noreturn@/__attribute__((__noreturn__))/g'" + " -e 's/@ac_cv_have___builtin_expect@/1/g'" + @@ -344,3 +390,193 @@ genrule( " -e 's/@ac_google_start_namespace@/namespace google {/g'" + " $< > $@"), ) + +# Hardcoded ios arm config header for glog library. +# TODO: This is a temporary workaround. We should generate the config +# header by running the configure script with the right target toolchain. +IOS_ARM_CONFIG = """ +/* define if glog doesnt use RTTI */ +/* #undef DISABLE_RTTI */ + +/* Namespace for Google classes */ +#define GOOGLE_NAMESPACE google + +/* Define if you have the 'dladdr' function */ +#define HAVE_DLADDR 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_DLFCN_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_EXECINFO_H 1 + +/* Define if you have the 'fcntl' function */ +#define HAVE_FCNTL 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_GLOB_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_INTTYPES_H 1 + +/* Define to 1 if you have the 'pthread' library (-lpthread). */ +#define HAVE_LIBPTHREAD 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_LIBUNWIND_H 1 + +/* define if you have google gflags library */ +/* #undef HAVE_LIB_GFLAGS */ + +/* define if you have google gmock library */ +/* #undef HAVE_LIB_GMOCK */ + +/* define if you have google gtest library */ +/* #undef HAVE_LIB_GTEST */ + +/* define if you have libunwind */ +/* #undef HAVE_LIB_UNWIND */ + +/* Define to 1 if you have the header file. 
*/ +#define HAVE_MEMORY_H 1 + +/* define if the compiler implements namespaces */ +#define HAVE_NAMESPACES 1 + +/* Define if you have the 'pread' function */ +#define HAVE_PREAD 1 + +/* Define if you have POSIX threads libraries and header files. */ +#define HAVE_PTHREAD 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_PWD_H 1 + +/* Define if you have the 'pwrite' function */ +#define HAVE_PWRITE 1 + +/* define if the compiler implements pthread_rwlock_* */ +#define HAVE_RWLOCK 1 + +/* Define if you have the 'sigaction' function */ +#define HAVE_SIGACTION 1 + +/* Define if you have the 'sigaltstack' function */ +#define HAVE_SIGALTSTACK 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_STDINT_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_STDLIB_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_STRINGS_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_STRING_H 1 + +/* Define to 1 if you have the header file. */ +/* #undef HAVE_SYSCALL_H */ + +/* Define to 1 if you have the header file. */ +#define HAVE_SYSLOG_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_SYS_STAT_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_SYS_SYSCALL_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_SYS_TIME_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_SYS_TYPES_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_SYS_UCONTEXT_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_SYS_UTSNAME_H 1 + +/* Define to 1 if you have the header file. */ +/* #undef HAVE_UCONTEXT_H */ + +/* Define to 1 if you have the header file. */ +#define HAVE_UNISTD_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_UNWIND_H 1 + +/* define if the compiler supports using expression for operator */ +#define HAVE_USING_OPERATOR 1 + +/* define if your compiler has __attribute__ */ +#define HAVE___ATTRIBUTE__ 1 + +/* define if your compiler has __builtin_expect */ +#define HAVE___BUILTIN_EXPECT 1 + +/* define if your compiler has __sync_val_compare_and_swap */ +#define HAVE___SYNC_VAL_COMPARE_AND_SWAP 1 + +/* Define to the sub-directory in which libtool stores uninstalled libraries. + */ +#define LT_OBJDIR ".libs/" + +/* Name of package */ +#define PACKAGE "glog" + +/* Define to the address where bug reports for this package should be sent. */ +#define PACKAGE_BUGREPORT "opensource@google.com" + +/* Define to the full name of this package. */ +#define PACKAGE_NAME "glog" + +/* Define to the full name and version of this package. */ +#define PACKAGE_STRING "glog 0.3.5" + +/* Define to the one symbol short name of this package. */ +#define PACKAGE_TARNAME "glog" + +/* Define to the home page for this package. */ +#define PACKAGE_URL "" + +/* Define to the version of this package. */ +#define PACKAGE_VERSION "0.3.5" + +/* How to access the PC from a struct ucontext */ +/* #undef PC_FROM_UCONTEXT */ + +/* Define to necessary symbol if this constant uses a non-standard name on + your system. */ +/* #undef PTHREAD_CREATE_JOINABLE */ + +/* The size of 'void *', as computed by sizeof. */ +#define SIZEOF_VOID_P 8 + +/* Define to 1 if you have the ANSI C header files. 
*/ +/* #undef STDC_HEADERS */ + +/* the namespace where STL code like vector<> is defined */ +#define STL_NAMESPACE std + +/* location of source code */ +#define TEST_SRC_DIR "external/com_google_glog" + +/* Version number of package */ +#define VERSION "0.3.5" + +/* Stops putting the code inside the Google namespace */ +#define _END_GOOGLE_NAMESPACE_ } + +/* Puts following code inside the Google namespace */ +#define _START_GOOGLE_NAMESPACE_ namespace google { +""" + +genrule( + name = "gen_ios_arm_config", + outs = ["config.h.ios_arm"], + cmd = ("echo '%s' > $(location config.h.ios_arm)" % IOS_ARM_CONFIG), +) diff --git a/third_party/google_toolbox_for_mac.BUILD b/third_party/google_toolbox_for_mac.BUILD new file mode 100644 index 000000000..06497d29c --- /dev/null +++ b/third_party/google_toolbox_for_mac.BUILD @@ -0,0 +1,509 @@ +# Description: +# A collection of source from different Google projects that may be of use to +# developers working other Mac projects. + +# To run all the test from the command line: +# bazel test \ +# --build_tests_only \ +# @google_toolbox_for_mac///... + +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +exports_files( + ["UnitTest-Info.plist"], + visibility = ["//visibility:public"], +) + +objc_library( + name = "GTM_Defines", + hdrs = ["GTMDefines.h"], + includes = ["."], + visibility = ["//visibility:public"], +) + +objc_library( + name = "GTM_TypeCasting", + hdrs = [ + "DebugUtils/GTMTypeCasting.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_LocalizedString", + hdrs = [ + "Foundation/GTMLocalizedString.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_NSStringURLArguments", + srcs = [ + "Foundation/GTMNSString+URLArguments.m", + ], + hdrs = [ + "Foundation/GTMNSString+URLArguments.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_GeometryUtils", + srcs = [ + "Foundation/GTMGeometryUtils.m", + ], + hdrs = [ + "Foundation/GTMGeometryUtils.h", + ], + sdk_frameworks = ["CoreGraphics"], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], +) + +# Since this is just .h files, it is ok to not divide this into sub targets as it +# doesn't cause any extra code to be linked in when some just wants a subset of +# it. 
+objc_library( + name = "GTM_DebugUtils", + hdrs = [ + "DebugUtils/GTMDebugSelectorValidation.h", + "DebugUtils/GTMDebugThreadValidation.h", + "DebugUtils/GTMMethodCheck.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_SynchronizationAsserts", + srcs = [ + "DebugUtils/GTMSynchronizationAsserts.m", + ], + hdrs = [ + "DebugUtils/GTMSynchronizationAsserts.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_KVO", + hdrs = [ + "Foundation/GTMNSObject+KeyValueObserving.h", + ], + non_arc_srcs = [ + "Foundation/GTMNSObject+KeyValueObserving.m", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_DebugUtils", + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_Regex", + hdrs = [ + "Foundation/GTMRegex.h", + ], + non_arc_srcs = [ + "Foundation/GTMRegex.m", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_StringEncoding", + hdrs = [ + "Foundation/GTMStringEncoding.h", + ], + non_arc_srcs = [ + "Foundation/GTMStringEncoding.m", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_NSScannerJSON", + srcs = [ + "Foundation/GTMNSScanner+JSON.m", + ], + hdrs = [ + "Foundation/GTMNSScanner+JSON.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_NSStringHTML", + hdrs = [ + "Foundation/GTMNSString+HTML.h", + ], + non_arc_srcs = [ + "Foundation/GTMNSString+HTML.m", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_NSStringXML", + srcs = [ + "Foundation/GTMNSString+XML.m", + ], + hdrs = [ + "Foundation/GTMNSString+XML.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_NSThreadBlocks", + hdrs = [ + "Foundation/GTMNSThread+Blocks.h", + ], + non_arc_srcs = [ + "Foundation/GTMNSThread+Blocks.m", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_TimeUtils", + hdrs = [ + "Foundation/GTMTimeUtils.h", + ], + non_arc_srcs = [ + "Foundation/GTMTimeUtils.m", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_SQLite", + hdrs = [ + "Foundation/GTMSQLite.h", + ], + non_arc_srcs = [ + "Foundation/GTMSQLite.m", + ], + sdk_dylibs = ["libsqlite3"], + visibility = ["//visibility:public"], + deps = [ + ":GTM_DebugUtils", + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_SystemVersion", + hdrs = [ + "Foundation/GTMSystemVersion.h", + ], + non_arc_srcs = [ + "Foundation/GTMSystemVersion.m", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_GTMURLBuilder", + hdrs = [ + "Foundation/GTMURLBuilder.h", + ], + non_arc_srcs = [ + "Foundation/GTMURLBuilder.m", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Logger", + ":GTM_NSDictionaryURLArguments", + ], +) + +objc_library( + name = "GTM_NSDictionaryURLArguments", + srcs = [ + "Foundation/GTMNSDictionary+URLArguments.m", + ], + hdrs = [ + "Foundation/GTMNSDictionary+URLArguments.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_DebugUtils", + ":GTM_NSStringURLArguments", + ], +) + +objc_library( + name = "GTM_StackTrace", + hdrs = [ + "Foundation/GTMStackTrace.h", + ], + non_arc_srcs = [ + 
"Foundation/GTMStackTrace.m", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_NSDataZlib", + srcs = [ + "Foundation/GTMNSData+zlib.m", + ], + hdrs = [ + "Foundation/GTMNSData+zlib.h", + ], + sdk_dylibs = [ + "libz", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_NSFileHandleUniqueName", + hdrs = [ + "Foundation/GTMNSFileHandle+UniqueName.h", + ], + non_arc_srcs = [ + "Foundation/GTMNSFileHandle+UniqueName.m", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_UIFontLineHeight", + srcs = [ + "iPhone/GTMUIFont+LineHeight.m", + ], + hdrs = [ + "iPhone/GTMUIFont+LineHeight.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_RoundedRectPath", + srcs = [ + "iPhone/GTMRoundedRectPath.m", + ], + hdrs = [ + "iPhone/GTMRoundedRectPath.h", + ], + sdk_frameworks = ["CoreGraphics"], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_UIImageResize", + srcs = [ + "iPhone/GTMUIImage+Resize.m", + ], + hdrs = [ + "iPhone/GTMUIImage+Resize.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_FadeTruncatingLabel", + hdrs = [ + "iPhone/GTMFadeTruncatingLabel.h", + ], + non_arc_srcs = [ + "iPhone/GTMFadeTruncatingLabel.m", + ], + visibility = ["//visibility:public"], +) + +objc_library( + name = "GTM_UILocalizer", + hdrs = select({ + "//mediapipe:macos": ["AppKit/GTMUILocalizer.h"], + "//conditions:default": ["iPhone/GTMUILocalizer.h"], + }), + non_arc_srcs = select({ + "//mediapipe:macos": ["AppKit/GTMUILocalizer.m"], + "//conditions:default": ["iPhone/GTMUILocalizer.m"], + }), + sdk_frameworks = select({ + "//mediapipe:macos": ["AppKit"], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], + # On MacOS, mark alwayslink in case this is referenced only from a XIB and + # would otherwise be stripped. + alwayslink = select({ + "//mediapipe:ios": 0, + "//conditions:default": 1, + }), +) + +# NOTE: This target is only available for MacOS, not iPhone. +objc_library( + name = "GTM_UILocalizerAndLayoutTweaker", + hdrs = ["AppKit/GTMUILocalizerAndLayoutTweaker.h"], + non_arc_srcs = ["AppKit/GTMUILocalizerAndLayoutTweaker.m"], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ":GTM_UILocalizer", + ], + # Mark alwayslink in case this is referenced only from a XIB and would + # otherwise be stripped. 
+ alwayslink = 1, +) + +GTM_UNITTESTING_HDRS = [ + "UnitTesting/GTMFoundationUnitTestingUtilities.h", + "UnitTesting/GTMSenTestCase.h", + "UnitTesting/GTMTestTimer.h", +] + +GTM_UNITTESTING_NON_ARC_SRCS = [ + "UnitTesting/GTMFoundationUnitTestingUtilities.m", + "UnitTesting/GTMSenTestCase.m", +] + +GTM_UNITTESTING_SDK_FRAMEWORKS = [ + "CoreGraphics", + "QuartzCore", +] + +GTM_UNITTESTING_DEPS = [ + ":GTM_Regex", + ":GTM_SystemVersion", +] + +objc_library( + name = "GTM_UnitTesting", + testonly = 1, + hdrs = GTM_UNITTESTING_HDRS, + non_arc_srcs = GTM_UNITTESTING_NON_ARC_SRCS, + sdk_frameworks = GTM_UNITTESTING_SDK_FRAMEWORKS, + visibility = ["//visibility:public"], + deps = GTM_UNITTESTING_DEPS, +) + +objc_library( + name = "GTM_UnitTesting_GTM_USING_XCTEST", + testonly = 1, + hdrs = GTM_UNITTESTING_HDRS, + defines = ["GTM_USING_XCTEST=1"], + non_arc_srcs = GTM_UNITTESTING_NON_ARC_SRCS, + sdk_frameworks = GTM_UNITTESTING_SDK_FRAMEWORKS, + visibility = ["//visibility:public"], + deps = GTM_UNITTESTING_DEPS, +) + +objc_library( + name = "GTM_UnitTestingAppLib", + testonly = 1, + hdrs = [ + "UnitTesting/GTMCodeCoverageApp.h", + "UnitTesting/GTMIPhoneUnitTestDelegate.h", + ], + non_arc_srcs = [ + "UnitTesting/GTMIPhoneUnitTestDelegate.m", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_UnitTesting", + ], +) + +# No Test for GTM_UnitTestingAppLib, use a build test. +objc_library( + name = "GTM_Logger", + hdrs = [ + "Foundation/GTMLogger.h", + ], + non_arc_srcs = [ + "Foundation/GTMLogger.m", + ], + visibility = ["//visibility:public"], + deps = [ + ":GTM_Defines", + ], +) + +objc_library( + name = "GTM_Logger_ASL", + hdrs = ["Foundation/GTMLogger+ASL.h"], + non_arc_srcs = ["Foundation/GTMLogger+ASL.m"], + visibility = ["//visibility:public"], + deps = [":GTM_Logger"], +) + +objc_library( + name = "GTM_LoggerRingBufferWriter", + hdrs = ["Foundation/GTMLoggerRingBufferWriter.h"], + non_arc_srcs = ["Foundation/GTMLoggerRingBufferWriter.m"], + visibility = ["//visibility:public"], + deps = [":GTM_Logger"], +) diff --git a/third_party/opencv_ios.BUILD b/third_party/opencv_ios.BUILD new file mode 100644 index 000000000..c9f112075 --- /dev/null +++ b/third_party/opencv_ios.BUILD @@ -0,0 +1,48 @@ +# Description: +# OpenCV libraries for video/image processing on iOS + +licenses(["notice"]) # BSD license + +exports_files(["LICENSE"]) + +load( + "@build_bazel_rules_apple//apple:apple.bzl", + "apple_static_framework_import", +) + +apple_static_framework_import( + name = "OpencvFramework", + framework_imports = glob(["opencv2.framework/**"]), + visibility = ["//visibility:public"], +) + +objc_library( + name = "opencv_objc_lib", + deps = [":OpencvFramework"], +) + +cc_library( + name = "opencv", + hdrs = glob([ + "opencv2.framework/Versions/A/Headers/**/*.h*", + ]), + copts = [ + "-std=c++11", + "-x objective-c++", + ], + include_prefix = "opencv2", + linkopts = [ + "-framework AssetsLibrary", + "-framework CoreFoundation", + "-framework CoreGraphics", + "-framework CoreMedia", + "-framework Accelerate", + "-framework CoreImage", + "-framework AVFoundation", + "-framework CoreVideo", + "-framework QuartzCore", + ], + strip_include_prefix = "opencv2.framework/Versions/A/Headers", + visibility = ["//visibility:public"], + deps = [":opencv_objc_lib"], +) diff --git a/third_party/opencv_linux.BUILD b/third_party/opencv_linux.BUILD index 0267a9987..48e8d9af9 100644 --- a/third_party/opencv_linux.BUILD +++ b/third_party/opencv_linux.BUILD @@ -21,7 +21,7 @@ cc_library( 
"lib/x86_64-linux-gnu/libopencv_videoio.so*", ], ), - hdrs = glob(["include/opencv/*.h*"]), + hdrs = glob(["include/opencv2/**/*.h*"]), includes = ["include"], linkstatic = 1, visibility = ["//visibility:public"], diff --git a/third_party/opencv_macos.BUILD b/third_party/opencv_macos.BUILD index b76c146c6..362e61c60 100644 --- a/third_party/opencv_macos.BUILD +++ b/third_party/opencv_macos.BUILD @@ -16,6 +16,7 @@ cc_library( "local/opt/opencv/lib/libopencv_highgui.dylib", "local/opt/opencv/lib/libopencv_imgcodecs.dylib", "local/opt/opencv/lib/libopencv_imgproc.dylib", + "local/opt/opencv/lib/libopencv_optflow.dylib", "local/opt/opencv/lib/libopencv_video.dylib", "local/opt/opencv/lib/libopencv_videoio.dylib", ], diff --git a/third_party/rules_apple_c0863d0596ae6b769a29fa3fb72ff036444fd249.diff b/third_party/rules_apple_c0863d0596ae6b769a29fa3fb72ff036444fd249.diff new file mode 100644 index 000000000..736292cfa --- /dev/null +++ b/third_party/rules_apple_c0863d0596ae6b769a29fa3fb72ff036444fd249.diff @@ -0,0 +1,25 @@ +commit c0863d0596ae6b769a29fa3fb72ff036444fd249 (HEAD -> py3) +Author: Camillo Lugaresi +Date: Fri Aug 16 00:13:16 2019 -0700 + + Fix codesigningtool.py py3 compatibility. + + In recent versions of plistlib, binary data entries are returned as instances of the built-in bytes class, and plistlib.Data is deprecated. + Since this script was expecting a plistlib.Data, it would fail with the error "AttributeError: 'bytes' object has no attribute 'data'". + This change makes it compatible with both new and old versions of plistlib. + +diff --git a/tools/codesigningtool/codesigningtool.py b/tools/codesigningtool/codesigningtool.py +index 59f3841..40cdcf3 100644 +--- a/tools/codesigningtool/codesigningtool.py ++++ b/tools/codesigningtool/codesigningtool.py +@@ -102,7 +102,9 @@ def _certificate_fingerprint(identity): + def _get_identities_from_provisioning_profile(mpf): + """Iterates through all the identities in a provisioning profile, lazily.""" + for identity in mpf["DeveloperCertificates"]: +- yield _certificate_fingerprint(identity.data) ++ if not _PY3: ++ identity = identity.data ++ yield _certificate_fingerprint(identity) + + + def _find_codesign_identities(identity=None): diff --git a/third_party/tensorflow_065c20bf79253257c87bd4614bb9a7fdef015cbb.diff b/third_party/tensorflow_065c20bf79253257c87bd4614bb9a7fdef015cbb.diff new file mode 100644 index 000000000..16f9c3265 --- /dev/null +++ b/third_party/tensorflow_065c20bf79253257c87bd4614bb9a7fdef015cbb.diff @@ -0,0 +1,22 @@ +commit 065c20bf79253257c87bd4614bb9a7fdef015cbb +Author: Camillo Lugaresi +Date: Thu Aug 15 18:34:41 2019 -0700 + + Use python3 if available to run gen_git_source.py. + + gen_git_source.py fails with an "ImportError: No module named builtins" on a default installation of Python 2 (at least, the one that comes with macOS). This can be worked around by installing the "future" package from pip. However, instead of requiring users to go through this extra step, we can simply run the script using Python 3 if it's installed. The script works on a default installation of Python 3, without requiring extra packages. 
+ +diff --git a/third_party/git/git_configure.bzl b/third_party/git/git_configure.bzl +index fc18fdb988..3ce64242af 100644 +--- a/third_party/git/git_configure.bzl ++++ b/third_party/git/git_configure.bzl +@@ -18,6 +18,9 @@ def _get_python_bin(repository_ctx): + python_bin = repository_ctx.os.environ.get(_PYTHON_BIN_PATH) + if python_bin != None: + return python_bin ++ python_bin_path = repository_ctx.which("python3") ++ if python_bin_path != None: ++ return str(python_bin_path) + python_bin_path = repository_ctx.which("python") + if python_bin_path != None: + return str(python_bin_path) diff --git a/third_party/tensorflow_f67fcbefce906cd419e4657f0d41e21019b71abd.diff b/third_party/tensorflow_f67fcbefce906cd419e4657f0d41e21019b71abd.diff new file mode 100644 index 000000000..080e4dc12 --- /dev/null +++ b/third_party/tensorflow_f67fcbefce906cd419e4657f0d41e21019b71abd.diff @@ -0,0 +1,24 @@ +commit f67fcbefce906cd419e4657f0d41e21019b71abd (HEAD -> formediapipe) +Author: Camillo Lugaresi +Date: Fri Aug 16 12:24:58 2019 -0700 + + elementwise requires C++14 + + This file fails to compile when using C++11, which is the default. This can be worked around by passing --cxxopt='-std=c++14' as a global build option to Bazel, but it is more convenient for users if we just configure this cc_library to be built with C++14 by default. + + The authors may also want to change it to be compatible with C++11, but that's out of scope for this change. + +diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD +index 17e59e70eb..4302a1f644 100644 +--- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD ++++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD +@@ -197,6 +197,9 @@ cc_library( + name = "elementwise", + srcs = ["elementwise.cc"], + hdrs = ["elementwise.h"], ++ copts = [ ++ "-std=c++14", ++ ], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations",