Project import generated by Copybara.
PiperOrigin-RevId: 263889205
This commit is contained in:
parent
dc40414468
commit
294687295d
27
.bazelrc
27
.bazelrc
|
@ -34,3 +34,30 @@ build:android_arm --fat_apk_cpu=armeabi-v7a
|
|||
build:android_arm64 --config=android
|
||||
build:android_arm64 --cpu=arm64-v8a
|
||||
build:android_arm64 --fat_apk_cpu=arm64-v8a
|
||||
|
||||
# iOS configs.
|
||||
build:ios --apple_platform_type=ios
|
||||
|
||||
build:ios_i386 --config=ios
|
||||
build:ios_i386 --cpu=ios_i386
|
||||
build:ios_i386 --watchos_cpus=i386
|
||||
|
||||
build:ios_x86_64 --config=ios
|
||||
build:ios_x86_64 --cpu=ios_x86_64
|
||||
build:ios_x86_64 --watchos_cpus=i386
|
||||
|
||||
build:ios_armv7 --config=ios
|
||||
build:ios_armv7 --cpu=ios_armv7
|
||||
build:ios_armv7 --watchos_cpus=armv7k
|
||||
|
||||
build:ios_arm64 --config=ios
|
||||
build:ios_arm64 --cpu=ios_arm64
|
||||
build:ios_arm64 --watchos_cpus=armv7k
|
||||
|
||||
build:ios_arm64e --config=ios
|
||||
build:ios_arm64e --cpu=ios_arm64e
|
||||
build:ios_arm64e --watchos_cpus=armv7k
|
||||
|
||||
build:ios_fat --config=ios
|
||||
build:ios_fat --ios_multi_cpus=armv7,arm64
|
||||
build:ios_fat --watchos_cpus=armv7k
|
||||
|
|
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
mediapipe/provisioning_profile.mobileprovision
|
2
BUILD
2
BUILD
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2019 The MediaPipeOSS Authors.
|
||||
# Copyright 2019 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
|
10
Dockerfile
10
Dockerfile
|
@ -28,14 +28,20 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||
wget \
|
||||
unzip \
|
||||
python \
|
||||
python-pip \
|
||||
libopencv-core-dev \
|
||||
libopencv-highgui-dev \
|
||||
libopencv-imgproc-dev \
|
||||
libopencv-video-dev \
|
||||
&& \
|
||||
software-properties-common && \
|
||||
add-apt-repository -y ppa:openjdk-r/ppa && \
|
||||
apt-get update && apt-get install -y openjdk-11-jdk && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN pip install --upgrade setuptools
|
||||
RUN pip install future
|
||||
|
||||
# Install bazel
|
||||
ARG BAZEL_VERSION=0.26.1
|
||||
RUN mkdir /bazel && \
|
||||
|
@ -49,4 +55,4 @@ azel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \
|
|||
COPY . /mediapipe/
|
||||
|
||||
# If we want the docker image to contain the pre-built object_detection_offline_demo binary, do the following
|
||||
# RUN bazel build -c opt --define 'MEDIAPIPE_DISABLE_GPU=1' mediapipe/examples/desktop/demo:object_detection_tensorflow_demo
|
||||
# RUN bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/demo:object_detection_tensorflow_demo
|
||||
|
|
119
WORKSPACE
119
WORKSPACE
|
@ -2,11 +2,12 @@ workspace(name = "mediapipe")
|
|||
|
||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||
|
||||
skylib_version = "0.8.0"
|
||||
http_archive(
|
||||
name = "bazel_skylib",
|
||||
sha256 = "bbccf674aa441c266df9894182d80de104cabd19be98be002f6d478aaa31574d",
|
||||
strip_prefix = "bazel-skylib-2169ae1c374aab4a09aa90e65efe1a3aad4e279b",
|
||||
urls = ["https://github.com/bazelbuild/bazel-skylib/archive/2169ae1c374aab4a09aa90e65efe1a3aad4e279b.tar.gz"],
|
||||
type = "tar.gz",
|
||||
url = "https://github.com/bazelbuild/bazel-skylib/releases/download/{}/bazel-skylib.{}.tar.gz".format (skylib_version, skylib_version),
|
||||
sha256 = "2ef429f5d7ce7111263289644d233707dba35e39696377ebab8b0bc701f7818e",
|
||||
)
|
||||
load("@bazel_skylib//lib:versions.bzl", "versions")
|
||||
versions.check(minimum_bazel_version = "0.23.0")
|
||||
|
@ -52,7 +53,7 @@ http_archive(
|
|||
|
||||
# glog
|
||||
http_archive(
|
||||
name = "com_google_glog",
|
||||
name = "com_github_glog_glog",
|
||||
url = "https://github.com/google/glog/archive/v0.3.5.zip",
|
||||
sha256 = "267103f8a1e9578978aa1dc256001e6529ef593e5aea38193d31c2872ee025e8",
|
||||
strip_prefix = "glog-0.3.5",
|
||||
|
@ -73,6 +74,12 @@ http_archive(
|
|||
urls = ["https://github.com/google/protobuf/archive/384989534b2246d413dbcd750744faab2607b516.zip"],
|
||||
)
|
||||
|
||||
http_archive(
|
||||
name = "com_google_audio_tools",
|
||||
strip_prefix = "multichannel-audio-tools-master",
|
||||
urls = ["https://github.com/google/multichannel-audio-tools/archive/master.zip"],
|
||||
)
|
||||
|
||||
# Needed by TensorFlow
|
||||
http_archive(
|
||||
name = "io_bazel_rules_closure",
|
||||
|
@ -84,12 +91,24 @@ http_archive(
|
|||
],
|
||||
)
|
||||
|
||||
# TensorFlow r1.14-rc0
|
||||
# 2019-08-15
|
||||
_TENSORFLOW_GIT_COMMIT = "67def62936e28f97c16182dfcc467d8d1cae02b4"
|
||||
_TENSORFLOW_SHA256= "ddd4e3c056e7c0ff2ef29133b30fa62781dfbf8a903e99efb91a02d292fa9562"
|
||||
http_archive(
|
||||
name = "org_tensorflow",
|
||||
strip_prefix = "tensorflow-1.14.0-rc0",
|
||||
sha256 = "76404a6157a45e8d7a07e4f5690275256260130145924c2a7c73f6eda2a3de10",
|
||||
urls = ["https://github.com/tensorflow/tensorflow/archive/v1.14.0-rc0.zip"],
|
||||
urls = [
|
||||
"https://mirror.bazel.build/github.com/tensorflow/tensorflow/archive/%s.tar.gz" % _TENSORFLOW_GIT_COMMIT,
|
||||
"https://github.com/tensorflow/tensorflow/archive/%s.tar.gz" % _TENSORFLOW_GIT_COMMIT,
|
||||
],
|
||||
strip_prefix = "tensorflow-%s" % _TENSORFLOW_GIT_COMMIT,
|
||||
sha256 = _TENSORFLOW_SHA256,
|
||||
patches = [
|
||||
"@//third_party:tensorflow_065c20bf79253257c87bd4614bb9a7fdef015cbb.diff",
|
||||
"@//third_party:tensorflow_f67fcbefce906cd419e4657f0d41e21019b71abd.diff",
|
||||
],
|
||||
patch_args = [
|
||||
"-p1",
|
||||
],
|
||||
)
|
||||
|
||||
load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace")
|
||||
|
@ -102,6 +121,12 @@ new_local_repository(
|
|||
path = "/usr",
|
||||
)
|
||||
|
||||
new_local_repository(
|
||||
name = "linux_ffmpeg",
|
||||
build_file = "@//third_party:ffmpeg_linux.BUILD",
|
||||
path = "/usr"
|
||||
)
|
||||
|
||||
# Please run $ brew install opencv
|
||||
new_local_repository(
|
||||
name = "macos_opencv",
|
||||
|
@ -109,6 +134,12 @@ new_local_repository(
|
|||
path = "/usr",
|
||||
)
|
||||
|
||||
new_local_repository(
|
||||
name = "macos_ffmpeg",
|
||||
build_file = "@//third_party:ffmpeg_macos.BUILD",
|
||||
path = "/usr",
|
||||
)
|
||||
|
||||
http_archive(
|
||||
name = "android_opencv",
|
||||
sha256 = "056b849842e4fa8751d09edbb64530cfa7a63c84ccd232d0ace330e27ba55d0b",
|
||||
|
@ -118,6 +149,18 @@ http_archive(
|
|||
url = "https://github.com/opencv/opencv/releases/download/4.1.0/opencv-4.1.0-android-sdk.zip",
|
||||
)
|
||||
|
||||
# After OpenCV 3.2.0, the pre-compiled opencv2.framework has google protobuf symbols, which will
|
||||
# trigger duplicate symbol errors in the linking stage of building a mediapipe ios app.
|
||||
# To get a higher version of OpenCV for iOS, opencv2.framework needs to be built from source with
|
||||
# '-DBUILD_PROTOBUF=OFF -DBUILD_opencv_dnn=OFF'.
|
||||
http_archive(
|
||||
name = "ios_opencv",
|
||||
sha256 = "7dd536d06f59e6e1156b546bd581523d8df92ce83440002885ec5abc06558de2",
|
||||
build_file = "@//third_party:opencv_ios.BUILD",
|
||||
type = "zip",
|
||||
url = "https://github.com/opencv/opencv/releases/download/3.2.0/opencv-3.2.0-ios-framework.zip",
|
||||
)
|
||||
|
||||
RULES_JVM_EXTERNAL_TAG = "2.2"
|
||||
RULES_JVM_EXTERNAL_SHA = "f1203ce04e232ab6fdd81897cf0ff76f2c04c0741424d192f28e65ae752ce2d6"
|
||||
|
||||
|
@ -132,12 +175,15 @@ load("@rules_jvm_external//:defs.bzl", "maven_install")
|
|||
|
||||
maven_install(
|
||||
artifacts = [
|
||||
"com.android.support.constraint:constraint-layout:aar:1.0.2",
|
||||
"androidx.appcompat:appcompat:aar:1.0.2",
|
||||
],
|
||||
repositories = [
|
||||
"https://dl.google.com/dl/android/maven2",
|
||||
"androidx.annotation:annotation:aar:1.1.0",
|
||||
"androidx.appcompat:appcompat:aar:1.1.0-rc01",
|
||||
"androidx.constraintlayout:constraintlayout:aar:1.1.3",
|
||||
"androidx.core:core:aar:1.1.0-rc03",
|
||||
"androidx.legacy:legacy-support-v4:aar:1.0.0",
|
||||
"androidx.recyclerview:recyclerview:aar:1.1.0-beta02",
|
||||
"com.google.android.material:material:aar:1.0.0-rc01",
|
||||
],
|
||||
repositories = ["https://dl.google.com/dl/android/maven2"],
|
||||
)
|
||||
|
||||
maven_server(
|
||||
|
@ -191,3 +237,50 @@ android_ndk_repository(
|
|||
android_sdk_repository(
|
||||
name = "androidsdk",
|
||||
)
|
||||
|
||||
# iOS basic build deps.
|
||||
|
||||
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
|
||||
|
||||
git_repository(
|
||||
name = "build_bazel_rules_apple",
|
||||
remote = "https://github.com/bazelbuild/rules_apple.git",
|
||||
tag = "0.18.0",
|
||||
patches = [
|
||||
"@//third_party:rules_apple_c0863d0596ae6b769a29fa3fb72ff036444fd249.diff",
|
||||
],
|
||||
patch_args = [
|
||||
"-p1",
|
||||
],
|
||||
)
|
||||
|
||||
load(
|
||||
"@build_bazel_rules_apple//apple:repositories.bzl",
|
||||
"apple_rules_dependencies",
|
||||
)
|
||||
|
||||
apple_rules_dependencies()
|
||||
|
||||
load(
|
||||
"@build_bazel_rules_swift//swift:repositories.bzl",
|
||||
"swift_rules_dependencies",
|
||||
)
|
||||
|
||||
swift_rules_dependencies()
|
||||
|
||||
load(
|
||||
"@build_bazel_apple_support//lib:repositories.bzl",
|
||||
"apple_support_dependencies",
|
||||
)
|
||||
|
||||
apple_support_dependencies()
|
||||
|
||||
# More iOS deps.
|
||||
|
||||
http_archive(
|
||||
name = "google_toolbox_for_mac",
|
||||
url = "https://github.com/google/google-toolbox-for-mac/archive/v2.2.1.zip",
|
||||
sha256 = "e3ac053813c989a88703556df4dc4466e424e30d32108433ed6beaec76ba4fdc",
|
||||
strip_prefix = "google-toolbox-for-mac-2.2.1",
|
||||
build_file = "@//third_party:google_toolbox_for_mac.BUILD",
|
||||
)
|
||||
|
|
|
@ -65,11 +65,73 @@ config_setting(
|
|||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
# Note: this cannot just match "apple_platform_type": "macos" because that option
|
||||
# defaults to "macos" even when building on Linux!
|
||||
alias(
|
||||
name = "macos",
|
||||
actual = select({
|
||||
":macos_i386": ":macos_i386",
|
||||
":macos_x86_64": ":macos_x86_64",
|
||||
"//conditions:default": ":macos_i386", # Arbitrarily chosen from above.
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# Note: this also matches on crosstool_top so that it does not produce ambiguous
|
||||
# selectors when used together with "android".
|
||||
config_setting(
|
||||
name = "ios",
|
||||
values = {
|
||||
"crosstool_top": "@bazel_tools//tools/cpp:toolchain",
|
||||
"apple_platform_type": "ios",
|
||||
},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "apple",
|
||||
actual = select({
|
||||
":macos": ":macos",
|
||||
":ios": ":ios",
|
||||
"//conditions:default": ":ios", # Arbitrarily chosen from above.
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "macos_i386",
|
||||
values = {
|
||||
"apple_platform_type": "macos",
|
||||
"cpu": "darwin",
|
||||
},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "macos_x86_64",
|
||||
values = {
|
||||
"apple_platform_type": "macos",
|
||||
"cpu": "darwin_x86_64",
|
||||
},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
[
|
||||
config_setting(
|
||||
name = arch,
|
||||
values = {"cpu": arch},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
for arch in [
|
||||
"ios_i386",
|
||||
"ios_x86_64",
|
||||
"ios_armv7",
|
||||
"ios_arm64",
|
||||
"ios_arm64e",
|
||||
]
|
||||
]
|
||||
|
||||
exports_files(
|
||||
["provisioning_profile.mobileprovision"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
|
305
mediapipe/calculators/audio/BUILD
Normal file
305
mediapipe/calculators/audio/BUILD
Normal file
|
@ -0,0 +1,305 @@
|
|||
# Copyright 2019 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(default_visibility = ["//visibility:private"])
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
|
||||
|
||||
proto_library(
|
||||
name = "mfcc_mel_calculators_proto",
|
||||
srcs = ["mfcc_mel_calculators.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "mfcc_mel_calculators_cc_proto",
|
||||
srcs = ["mfcc_mel_calculators.proto"],
|
||||
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":mfcc_mel_calculators_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "rational_factor_resample_calculator_proto",
|
||||
srcs = ["rational_factor_resample_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "rational_factor_resample_calculator_cc_proto",
|
||||
srcs = ["rational_factor_resample_calculator.proto"],
|
||||
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":rational_factor_resample_calculator_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "spectrogram_calculator_proto",
|
||||
srcs = ["spectrogram_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//mediapipe/framework:calculator_proto"],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "spectrogram_calculator_cc_proto",
|
||||
srcs = ["spectrogram_calculator.proto"],
|
||||
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":spectrogram_calculator_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "time_series_framer_calculator_proto",
|
||||
srcs = ["time_series_framer_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "time_series_framer_calculator_cc_proto",
|
||||
srcs = ["time_series_framer_calculator.proto"],
|
||||
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":time_series_framer_calculator_proto"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "audio_decoder_calculator",
|
||||
srcs = ["audio_decoder_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:audio_decoder",
|
||||
"//mediapipe/util:audio_decoder_cc_proto",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "basic_time_series_calculators",
|
||||
srcs = ["basic_time_series_calculators.cc"],
|
||||
hdrs = ["basic_time_series_calculators.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/util:time_series_util",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@eigen_archive//:eigen",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mfcc_mel_calculators",
|
||||
srcs = ["mfcc_mel_calculators.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":mfcc_mel_calculators_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:time_series_util",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_audio_tools//audio/dsp/mfcc",
|
||||
"@eigen_archive//:eigen",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "rational_factor_resample_calculator",
|
||||
srcs = ["rational_factor_resample_calculator.cc"],
|
||||
hdrs = ["rational_factor_resample_calculator.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":rational_factor_resample_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/util:time_series_util",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_audio_tools//audio/dsp:resampler",
|
||||
"@com_google_audio_tools//audio/dsp:resampler_rational_factor",
|
||||
"@eigen_archive//:eigen",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "spectrogram_calculator",
|
||||
srcs = ["spectrogram_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":spectrogram_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:core_proto",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:source_location",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:time_series_util",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_audio_tools//audio/dsp:window_functions",
|
||||
"@com_google_audio_tools//audio/dsp/spectrogram",
|
||||
"@eigen_archive//:eigen",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "time_series_framer_calculator",
|
||||
srcs = ["time_series_framer_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":time_series_framer_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:time_series_util",
|
||||
"@com_google_audio_tools//audio/dsp:window_functions",
|
||||
"@eigen_archive//:eigen",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "audio_decoder_calculator_test",
|
||||
srcs = ["audio_decoder_calculator_test.cc"],
|
||||
data = ["//mediapipe/calculators/audio/testdata:test_audios"],
|
||||
deps = [
|
||||
":audio_decoder_calculator",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "basic_time_series_calculators_test",
|
||||
srcs = ["basic_time_series_calculators_test.cc"],
|
||||
deps = [
|
||||
":basic_time_series_calculators",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/util:time_series_test_util",
|
||||
"@eigen_archive//:eigen",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "mfcc_mel_calculators_test",
|
||||
srcs = ["mfcc_mel_calculators_test.cc"],
|
||||
deps = [
|
||||
":mfcc_mel_calculators",
|
||||
":mfcc_mel_calculators_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:time_series_test_util",
|
||||
"@eigen_archive//:eigen",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "spectrogram_calculator_test",
|
||||
srcs = ["spectrogram_calculator_test.cc"],
|
||||
deps = [
|
||||
":spectrogram_calculator",
|
||||
":spectrogram_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:benchmark",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:time_series_test_util",
|
||||
"@com_google_audio_tools//audio/dsp:number_util",
|
||||
"@eigen_archive//:eigen",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "time_series_framer_calculator_test",
|
||||
srcs = ["time_series_framer_calculator_test.cc"],
|
||||
deps = [
|
||||
":time_series_framer_calculator",
|
||||
":time_series_framer_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:time_series_test_util",
|
||||
"@com_google_audio_tools//audio/dsp:window_functions",
|
||||
"@eigen_archive//:eigen",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "rational_factor_resample_calculator_test",
|
||||
srcs = ["rational_factor_resample_calculator_test.cc"],
|
||||
deps = [
|
||||
":rational_factor_resample_calculator",
|
||||
":rational_factor_resample_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/tool:validate_type",
|
||||
"//mediapipe/util:time_series_test_util",
|
||||
"@com_google_audio_tools//audio/dsp:signal_vector_util",
|
||||
"@eigen_archive//:eigen",
|
||||
],
|
||||
)
|
106
mediapipe/calculators/audio/audio_decoder_calculator.cc
Normal file
106
mediapipe/calculators/audio/audio_decoder_calculator.cc
Normal file
|
@ -0,0 +1,106 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/util/audio_decoder.h"
|
||||
#include "mediapipe/util/audio_decoder.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// The AudioDecoderCalculator decodes an audio stream of the media file. It
|
||||
// produces two output streams contain audio packets and the header infomation.
|
||||
//
|
||||
// Output Streams:
|
||||
// AUDIO: Output audio frames (Matrix).
|
||||
// AUDIO_HEADER:
|
||||
// Optional audio header information output
|
||||
// Input Side Packets:
|
||||
// INPUT_FILE_PATH: The input file path.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "AudioDecoderCalculator"
|
||||
// input_side_packet: "INPUT_FILE_PATH:input_file_path"
|
||||
// output_stream: "AUDIO:audio"
|
||||
// output_stream: "AUDIO_HEADER:audio_header"
|
||||
// node_options {
|
||||
// [type.googleapis.com/mediapipe.AudioDecoderOptions]: {
|
||||
// audio_stream { stream_index: 0 }
|
||||
// start_time: 0
|
||||
// end_time: 1
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// TODO: support decoding multiple streams.
|
||||
class AudioDecoderCalculator : public CalculatorBase {
|
||||
public:
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc);
|
||||
|
||||
::mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
::mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
::mediapipe::Status Close(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<AudioDecoder> decoder_;
|
||||
};
|
||||
|
||||
::mediapipe::Status AudioDecoderCalculator::GetContract(
|
||||
CalculatorContract* cc) {
|
||||
cc->InputSidePackets().Tag("INPUT_FILE_PATH").Set<std::string>();
|
||||
|
||||
cc->Outputs().Tag("AUDIO").Set<Matrix>();
|
||||
if (cc->Outputs().HasTag("AUDIO_HEADER")) {
|
||||
cc->Outputs().Tag("AUDIO_HEADER").Set<mediapipe::TimeSeriesHeader>();
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status AudioDecoderCalculator::Open(CalculatorContext* cc) {
|
||||
const std::string& input_file_path =
|
||||
cc->InputSidePackets().Tag("INPUT_FILE_PATH").Get<std::string>();
|
||||
const auto& decoder_options = cc->Options<mediapipe::AudioDecoderOptions>();
|
||||
decoder_ = absl::make_unique<AudioDecoder>();
|
||||
RETURN_IF_ERROR(decoder_->Initialize(input_file_path, decoder_options));
|
||||
std::unique_ptr<mediapipe::TimeSeriesHeader> header =
|
||||
absl::make_unique<mediapipe::TimeSeriesHeader>();
|
||||
if (decoder_->FillAudioHeader(decoder_options.audio_stream(0), header.get())
|
||||
.ok()) {
|
||||
// Only pass on a header if the decoder could actually produce one.
|
||||
// otherwise, the header will be empty.
|
||||
cc->Outputs().Tag("AUDIO_HEADER").SetHeader(Adopt(header.release()));
|
||||
}
|
||||
cc->Outputs().Tag("AUDIO_HEADER").Close();
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status AudioDecoderCalculator::Process(CalculatorContext* cc) {
|
||||
Packet data;
|
||||
int options_index = -1;
|
||||
auto status = decoder_->GetData(&options_index, &data);
|
||||
if (status.ok()) {
|
||||
cc->Outputs().Tag("AUDIO").AddPacket(data);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
::mediapipe::Status AudioDecoderCalculator::Close(CalculatorContext* cc) {
|
||||
return decoder_->Close();
|
||||
}
|
||||
|
||||
REGISTER_CALCULATOR(AudioDecoderCalculator);
|
||||
|
||||
} // namespace mediapipe
|
153
mediapipe/calculators/audio/audio_decoder_calculator_test.cc
Normal file
153
mediapipe/calculators/audio/audio_decoder_calculator_test.cc
Normal file
|
@ -0,0 +1,153 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
TEST(AudioDecoderCalculatorTest, TestWAV) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
|
||||
calculator: "AudioDecoderCalculator"
|
||||
input_side_packet: "INPUT_FILE_PATH:input_file_path"
|
||||
output_stream: "AUDIO:audio"
|
||||
output_stream: "AUDIO_HEADER:audio_header"
|
||||
node_options {
|
||||
[type.googleapis.com/mediapipe.AudioDecoderOptions]: {
|
||||
audio_stream { stream_index: 0 }
|
||||
}
|
||||
})");
|
||||
CalculatorRunner runner(node_config);
|
||||
runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
|
||||
file::JoinPath("./",
|
||||
"/mediapipe/calculators/audio/"
|
||||
"testdata/sine_wave_1k_44100_mono_2_sec_wav.audio"));
|
||||
MEDIAPIPE_ASSERT_OK(runner.Run());
|
||||
MEDIAPIPE_EXPECT_OK(
|
||||
runner.Outputs()
|
||||
.Tag("AUDIO_HEADER")
|
||||
.header.ValidateAsType<mediapipe::TimeSeriesHeader>());
|
||||
const mediapipe::TimeSeriesHeader& header =
|
||||
runner.Outputs()
|
||||
.Tag("AUDIO_HEADER")
|
||||
.header.Get<mediapipe::TimeSeriesHeader>();
|
||||
EXPECT_EQ(44100, header.sample_rate());
|
||||
EXPECT_EQ(1, header.num_channels());
|
||||
EXPECT_TRUE(runner.Outputs().Tag("AUDIO").packets.size() >=
|
||||
std::ceil(44100.0 * 2 / 2048));
|
||||
}
|
||||
|
||||
// Decodes a 2-second, 48 kHz, stereo WAV test file and checks that the
// decoded TimeSeriesHeader reports the expected sample rate and channel
// count, and that at least the expected minimum number of audio packets
// (total samples / packet size of 1024) is produced.
TEST(AudioDecoderCalculatorTest, Test48KWAV) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "AudioDecoderCalculator"
        input_side_packet: "INPUT_FILE_PATH:input_file_path"
        output_stream: "AUDIO:audio"
        output_stream: "AUDIO_HEADER:audio_header"
        node_options {
          [type.googleapis.com/mediapipe.AudioDecoderOptions]: {
            audio_stream { stream_index: 0 }
          }
        })");
  CalculatorRunner runner(node_config);
  runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
      file::JoinPath("./",
                     "/mediapipe/calculators/audio/"
                     "testdata/sine_wave_1k_48000_stereo_2_sec_wav.audio"));
  MEDIAPIPE_ASSERT_OK(runner.Run());
  // The AUDIO_HEADER stream header must be a valid TimeSeriesHeader.
  MEDIAPIPE_EXPECT_OK(
      runner.Outputs()
          .Tag("AUDIO_HEADER")
          .header.ValidateAsType<mediapipe::TimeSeriesHeader>());
  const mediapipe::TimeSeriesHeader& header =
      runner.Outputs()
          .Tag("AUDIO_HEADER")
          .header.Get<mediapipe::TimeSeriesHeader>();
  EXPECT_EQ(48000, header.sample_rate());
  EXPECT_EQ(2, header.num_channels());
  // 2 seconds of audio at 48 kHz, delivered in packets of up to 1024 samples.
  EXPECT_TRUE(runner.Outputs().Tag("AUDIO").packets.size() >=
              std::ceil(48000.0 * 2 / 1024));
}
|
||||
|
||||
// Decodes a 2-second, 44.1 kHz, stereo MP3 test file and checks the decoded
// header fields plus a lower bound on the packet count (MP3 frames carry
// 1152 samples each).
TEST(AudioDecoderCalculatorTest, TestMP3) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "AudioDecoderCalculator"
        input_side_packet: "INPUT_FILE_PATH:input_file_path"
        output_stream: "AUDIO:audio"
        output_stream: "AUDIO_HEADER:audio_header"
        node_options {
          [type.googleapis.com/mediapipe.AudioDecoderOptions]: {
            audio_stream { stream_index: 0 }
          }
        })");
  CalculatorRunner runner(node_config);
  runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
      file::JoinPath("./",
                     "/mediapipe/calculators/audio/"
                     "testdata/sine_wave_1k_44100_stereo_2_sec_mp3.audio"));
  MEDIAPIPE_ASSERT_OK(runner.Run());
  // The AUDIO_HEADER stream header must be a valid TimeSeriesHeader.
  MEDIAPIPE_EXPECT_OK(
      runner.Outputs()
          .Tag("AUDIO_HEADER")
          .header.ValidateAsType<mediapipe::TimeSeriesHeader>());
  const mediapipe::TimeSeriesHeader& header =
      runner.Outputs()
          .Tag("AUDIO_HEADER")
          .header.Get<mediapipe::TimeSeriesHeader>();
  EXPECT_EQ(44100, header.sample_rate());
  EXPECT_EQ(2, header.num_channels());
  // 2 seconds at 44.1 kHz, in MP3 frames of 1152 samples.
  EXPECT_TRUE(runner.Outputs().Tag("AUDIO").packets.size() >=
              std::ceil(44100.0 * 2 / 1152));
}
|
||||
|
||||
// Decodes a 2-second, 44.1 kHz, stereo AAC test file and checks the decoded
// header fields plus a lower bound on the packet count (AAC frames carry
// 1024 samples each).
TEST(AudioDecoderCalculatorTest, TestAAC) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "AudioDecoderCalculator"
        input_side_packet: "INPUT_FILE_PATH:input_file_path"
        output_stream: "AUDIO:audio"
        output_stream: "AUDIO_HEADER:audio_header"
        node_options {
          [type.googleapis.com/mediapipe.AudioDecoderOptions]: {
            audio_stream { stream_index: 0 }
          }
        })");
  CalculatorRunner runner(node_config);
  runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
      file::JoinPath("./",
                     "/mediapipe/calculators/audio/"
                     "testdata/sine_wave_1k_44100_stereo_2_sec_aac.audio"));
  MEDIAPIPE_ASSERT_OK(runner.Run());
  // The AUDIO_HEADER stream header must be a valid TimeSeriesHeader.
  MEDIAPIPE_EXPECT_OK(
      runner.Outputs()
          .Tag("AUDIO_HEADER")
          .header.ValidateAsType<mediapipe::TimeSeriesHeader>());
  const mediapipe::TimeSeriesHeader& header =
      runner.Outputs()
          .Tag("AUDIO_HEADER")
          .header.Get<mediapipe::TimeSeriesHeader>();
  EXPECT_EQ(44100, header.sample_rate());
  EXPECT_EQ(2, header.num_channels());
  // 2 seconds at 44.1 kHz, in AAC frames of 1024 samples.
  EXPECT_TRUE(runner.Outputs().Tag("AUDIO").packets.size() >=
              std::ceil(44100.0 * 2 / 1024));
}
|
||||
|
||||
} // namespace mediapipe
|
403
mediapipe/calculators/audio/basic_time_series_calculators.cc
Normal file
403
mediapipe/calculators/audio/basic_time_series_calculators.cc
Normal file
|
@ -0,0 +1,403 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Basic Calculators that operate on TimeSeries streams.
|
||||
#include "mediapipe/calculators/audio/basic_time_series_calculators.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/util/time_series_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
static bool SafeMultiply(int x, int y, int* result) {
|
||||
static_assert(sizeof(int64) >= 2 * sizeof(int),
|
||||
"Unable to detect overflow after multiplication");
|
||||
const int64 big = static_cast<int64>(x) * static_cast<int64>(y);
|
||||
if (big > static_cast<int64>(INT_MIN) && big < static_cast<int64>(INT_MAX)) {
|
||||
if (result != nullptr) *result = static_cast<int>(big);
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Contract shared by all basic time-series calculators: exactly one Matrix
// input stream and one Matrix output stream, each carrying a
// TimeSeriesHeader as its stream header.
::mediapipe::Status BasicTimeSeriesCalculatorBase::GetContract(
    CalculatorContract* cc) {
  cc->Inputs().Index(0).Set<Matrix>(
      // Input stream with TimeSeriesHeader.
  );
  cc->Outputs().Index(0).Set<Matrix>(
      // Output stream with TimeSeriesHeader.
  );
  return ::mediapipe::OkStatus();
}
|
||||
|
||||
// Validates the input stream's TimeSeriesHeader, gives the subclass a chance
// to mutate a copy of it via MutateHeader(), and installs the (possibly
// mutated) copy as the output stream's header.
::mediapipe::Status BasicTimeSeriesCalculatorBase::Open(CalculatorContext* cc) {
  TimeSeriesHeader input_header;
  RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
      cc->Inputs().Index(0).Header(), &input_header));

  // Ownership of output_header is transferred to the framework via Adopt().
  auto output_header = new TimeSeriesHeader(input_header);
  RETURN_IF_ERROR(MutateHeader(output_header));
  cc->Outputs().Index(0).SetHeader(Adopt(output_header));
  return ::mediapipe::OkStatus();
}
|
||||
|
||||
// Per-packet processing: checks the input matrix shape against the input
// header, delegates the actual computation to the subclass's
// ProcessMatrix(), checks the result against the output header, and emits
// it at the input timestamp.
::mediapipe::Status BasicTimeSeriesCalculatorBase::Process(
    CalculatorContext* cc) {
  const Matrix& input = cc->Inputs().Index(0).Get<Matrix>();
  RETURN_IF_ERROR(time_series_util::IsMatrixShapeConsistentWithHeader(
      input, cc->Inputs().Index(0).Header().Get<TimeSeriesHeader>()));

  std::unique_ptr<Matrix> output(new Matrix(ProcessMatrix(input)));
  RETURN_IF_ERROR(time_series_util::IsMatrixShapeConsistentWithHeader(
      *output, cc->Outputs().Index(0).Header().Get<TimeSeriesHeader>()));

  // release() hands ownership of the matrix to the output stream.
  cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
  return ::mediapipe::OkStatus();
}
|
||||
|
||||
// Default implementation: leave the output header identical to the input
// header. Subclasses that change the stream shape override this.
::mediapipe::Status BasicTimeSeriesCalculatorBase::MutateHeader(
    TimeSeriesHeader* output_header) {
  return ::mediapipe::OkStatus();
}
|
||||
|
||||
// Calculator to sum an input time series across channels. This is
|
||||
// useful for e.g. computing 'summary SAI' pitchogram features.
|
||||
//
|
||||
// Options proto: None.
|
||||
class SumTimeSeriesAcrossChannelsCalculator
|
||||
: public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
|
||||
output_header->set_num_channels(1);
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
return input_matrix.colwise().sum();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(SumTimeSeriesAcrossChannelsCalculator);
|
||||
|
||||
// Calculator to average an input time series across channels. This is
|
||||
// useful for e.g. converting stereo or multi-channel files to mono.
|
||||
//
|
||||
// Options proto: None.
|
||||
class AverageTimeSeriesAcrossChannelsCalculator
|
||||
: public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
|
||||
output_header->set_num_channels(1);
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
return input_matrix.colwise().mean();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(AverageTimeSeriesAcrossChannelsCalculator);
|
||||
|
||||
// Calculator to convert a (temporal) summary SAI stream (a single-channel
// stream output by SumTimeSeriesAcrossChannelsCalculator) into pitchogram
// frames by transposing the input packets, swapping the time and channel axes.
//
// Options proto: None.
class SummarySaiToPitchogramCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
    // output_header starts as a copy of the input header, so this check
    // rejects multi-channel *input* streams.
    if (output_header->num_channels() != 1) {
      return tool::StatusInvalid(
          absl::StrCat("Expected single-channel input, got ",
                       output_header->num_channels()));
    }
    // Swap the time and channel axes: each input packet of N samples
    // becomes one output sample with N channels, so the output sample rate
    // equals the packet rate.
    output_header->set_num_channels(output_header->num_samples());
    output_header->set_num_samples(1);
    output_header->set_sample_rate(output_header->packet_rate());
    return ::mediapipe::OkStatus();
  }

  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    return input_matrix.transpose();
  }
};
REGISTER_CALCULATOR(SummarySaiToPitchogramCalculator);
|
||||
|
||||
// Calculator to reverse the order of channels in TimeSeries packets.
|
||||
// This is useful for e.g. interfacing with the speech pipeline which uses the
|
||||
// opposite convention to the hearing filterbanks.
|
||||
//
|
||||
// Options proto: None.
|
||||
class ReverseChannelOrderCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
return input_matrix.colwise().reverse();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(ReverseChannelOrderCalculator);
|
||||
|
||||
// Calculator to flatten all samples in a TimeSeries packet down into
// a single 'sample' vector. This is useful for e.g. stacking several
// frames of features into a single feature vector.
//
// Options proto: None.
class FlattenPacketCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
    const int num_input_channels = output_header->num_channels();
    const int num_input_samples = output_header->num_samples();
    RET_CHECK(num_input_channels >= 0)
        << "FlattenPacketCalculator: num_input_channels < 0";
    RET_CHECK(num_input_samples >= 0)
        << "FlattenPacketCalculator: num_input_samples < 0";
    // channels * samples can overflow int, so use the checked multiply.
    int output_num_channels;
    RET_CHECK(SafeMultiply(num_input_channels, num_input_samples,
                           &output_num_channels))
        << "FlattenPacketCalculator: Multiplication failed.";
    // The flattened packet is one 'sample' whose channel count is the total
    // element count; one output packet per input packet, hence the sample
    // rate becomes the packet rate.
    output_header->set_num_channels(output_num_channels);
    output_header->set_num_samples(1);
    output_header->set_sample_rate(output_header->packet_rate());
    return ::mediapipe::OkStatus();
  }

  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    // Flatten by interleaving channels so that full samples are
    // stacked on top of each other instead of interleaving samples
    // from the same channel.
    Matrix output(input_matrix.size(), 1);
    for (int sample = 0; sample < input_matrix.cols(); ++sample) {
      // Copy one full sample (column) into the next rows-sized slice.
      output.middleRows(sample * input_matrix.rows(), input_matrix.rows()) =
          input_matrix.col(sample);
    }
    return output;
  }
};
REGISTER_CALCULATOR(FlattenPacketCalculator);
|
||||
|
||||
// Calculator to subtract the within-packet mean for each channel from each
|
||||
// corresponding channel.
|
||||
//
|
||||
// Options proto: None.
|
||||
class SubtractMeanCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
Matrix mean = input_matrix.rowwise().mean();
|
||||
return input_matrix - mean.replicate(1, input_matrix.cols());
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(SubtractMeanCalculator);
|
||||
|
||||
// Calculator to subtract the mean over all values (across all times and
|
||||
// channels) in a Packet from the values in that Packet.
|
||||
//
|
||||
// Options proto: None.
|
||||
class SubtractMeanAcrossChannelsCalculator
|
||||
: public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
auto mean = input_matrix.mean();
|
||||
return (input_matrix.array() - mean).matrix();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(SubtractMeanAcrossChannelsCalculator);
|
||||
|
||||
// Calculator to divide all values in a Packet by the average value across all
|
||||
// times and channels in the packet. This is useful for normalizing
|
||||
// nonnegative quantities like power, but might cause unexpected results if used
|
||||
// with Packets that can contain negative numbers.
|
||||
//
|
||||
// If mean is exactly zero, the output will be a matrix of all ones, because
|
||||
// that's what happens in other cases where all values are equal.
|
||||
//
|
||||
// Options proto: None.
|
||||
class DivideByMeanAcrossChannelsCalculator
|
||||
: public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
auto mean = input_matrix.mean();
|
||||
|
||||
if (mean != 0) {
|
||||
return input_matrix / mean;
|
||||
|
||||
// When used with nonnegative matrices, the mean will only be zero if the
|
||||
// entire matrix is exactly zero. If mean is exactly zero, the output will
|
||||
// be a matrix of all ones, because that's what happens in other cases
|
||||
// where
|
||||
// all values are equal.
|
||||
} else {
|
||||
return Matrix::Ones(input_matrix.rows(), input_matrix.cols());
|
||||
}
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(DivideByMeanAcrossChannelsCalculator);
|
||||
|
||||
// Calculator to calculate the mean for each channel.
|
||||
//
|
||||
// Options proto: None.
|
||||
class MeanCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
|
||||
output_header->set_num_samples(1);
|
||||
output_header->set_sample_rate(output_header->packet_rate());
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
return input_matrix.rowwise().mean();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(MeanCalculator);
|
||||
|
||||
// Calculator to calculate the uncorrected sample standard deviation in each
// channel, independently for each Packet. I.e. divide by the number of samples
// in the Packet, not (<number of samples> - 1).
//
// Options proto: None.
class StandardDeviationCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
    // One output sample (the per-channel deviation) per input packet.
    output_header->set_num_samples(1);
    output_header->set_sample_rate(output_header->packet_rate());
    return ::mediapipe::OkStatus();
  }

  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    Eigen::VectorXf mean = input_matrix.rowwise().mean();
    // norm() of each centered row is sqrt(sum of squared deviations);
    // dividing by sqrt(N) gives sqrt(mean of squares) — the uncorrected
    // (population) standard deviation.
    return (input_matrix.colwise() - mean).rowwise().norm() /
           sqrt(input_matrix.cols());
  }
};
REGISTER_CALCULATOR(StandardDeviationCalculator);
|
||||
|
||||
// Calculator to calculate the covariance matrix. If the input matrix
// has N channels, the output matrix will be an N by N symmetric
// matrix.
//
// Options proto: None.
class CovarianceCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
    // Output is square: num_channels x num_channels.
    output_header->set_num_samples(output_header->num_channels());
    return ::mediapipe::OkStatus();
  }

  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    // Center each channel on its mean, then form the (biased, divide-by-N)
    // covariance as centered * centered^T / N.
    auto mean = input_matrix.rowwise().mean();
    auto zero_mean_input =
        input_matrix - mean.replicate(1, input_matrix.cols());
    return (zero_mean_input * zero_mean_input.transpose()) /
           input_matrix.cols();
  }
};
REGISTER_CALCULATOR(CovarianceCalculator);
|
||||
|
||||
// Calculator to get the per column L2 norm of an input time series.
|
||||
//
|
||||
// Options proto: None.
|
||||
class L2NormCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
|
||||
output_header->set_num_channels(1);
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
return input_matrix.colwise().norm();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(L2NormCalculator);
|
||||
|
||||
// Calculator to convert each column of a matrix to a unit vector.
|
||||
//
|
||||
// Options proto: None.
|
||||
class L2NormalizeColumnCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
return input_matrix.colwise().normalized();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(L2NormalizeColumnCalculator);
|
||||
|
||||
// Calculator to apply L2 normalization to the input matrix.
|
||||
//
|
||||
// Returns the matrix as is if the RMS is <= 1E-8.
|
||||
// Options proto: None.
|
||||
class L2NormalizeCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
constexpr double kEpsilon = 1e-8;
|
||||
double rms = std::sqrt(input_matrix.array().square().mean());
|
||||
if (rms <= kEpsilon) {
|
||||
return input_matrix;
|
||||
}
|
||||
return input_matrix / rms;
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(L2NormalizeCalculator);
|
||||
|
||||
// Calculator to apply Peak normalization to the input matrix.
|
||||
//
|
||||
// Returns the matrix as is if the peak is <= 1E-8.
|
||||
// Options proto: None.
|
||||
class PeakNormalizeCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
constexpr double kEpsilon = 1e-8;
|
||||
double max_pcm = input_matrix.cwiseAbs().maxCoeff();
|
||||
if (max_pcm <= kEpsilon) {
|
||||
return input_matrix;
|
||||
}
|
||||
return input_matrix / max_pcm;
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(PeakNormalizeCalculator);
|
||||
|
||||
// Calculator to compute the elementwise square of an input time series.
|
||||
//
|
||||
// Options proto: None.
|
||||
class ElementwiseSquareCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
return input_matrix.array().square();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(ElementwiseSquareCalculator);
|
||||
|
||||
// Calculator that outputs first floor(num_samples / 2) of the samples.
|
||||
//
|
||||
// Options proto: None.
|
||||
class FirstHalfSlicerCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
|
||||
const int num_input_samples = output_header->num_samples();
|
||||
RET_CHECK(num_input_samples >= 0)
|
||||
<< "FirstHalfSlicerCalculator: num_input_samples < 0";
|
||||
output_header->set_num_samples(num_input_samples / 2);
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
return input_matrix.block(0, 0, input_matrix.rows(),
|
||||
input_matrix.cols() / 2);
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(FirstHalfSlicerCalculator);
|
||||
|
||||
} // namespace mediapipe
|
48
mediapipe/calculators/audio/basic_time_series_calculators.h
Normal file
48
mediapipe/calculators/audio/basic_time_series_calculators.h
Normal file
|
@ -0,0 +1,48 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Abstract base class for basic MediaPipe calculators that operate on
|
||||
// TimeSeries streams and don't require any Options protos.
|
||||
// Subclasses must override ProcessMatrix, and optionally
|
||||
// MutateHeader.
|
||||
|
||||
#ifndef MEDIAPIPE_CALCULATORS_AUDIO_BASIC_TIME_SERIES_CALCULATORS_H_
|
||||
#define MEDIAPIPE_CALCULATORS_AUDIO_BASIC_TIME_SERIES_CALCULATORS_H_
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Base class implementing the calculator contract, header propagation and
// per-packet plumbing shared by all basic time-series calculators.
// Subclasses only implement the matrix math (ProcessMatrix) and, when they
// change the stream shape, MutateHeader.
class BasicTimeSeriesCalculatorBase : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc);
  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;

 protected:
  // Open() calls this method to mutate the output stream header. The input
  // to this function will contain a copy of the input stream header, so
  // subclasses that do not need to mutate the header do not need to override
  // it.
  virtual ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header);

  // Process() calls this method on each packet to compute the output matrix.
  virtual Matrix ProcessMatrix(const Matrix& input_matrix) = 0;
};
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_CALCULATORS_AUDIO_BASIC_TIME_SERIES_CALCULATORS_H_
|
|
@ -0,0 +1,515 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/util/time_series_test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Fixture for SumTimeSeriesAcrossChannelsCalculator; the base class drives
// the calculator named in SetUp().
class SumTimeSeriesAcrossChannelsCalculatorTest
    : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override {
    calculator_name_ = "SumTimeSeriesAcrossChannelsCalculator";
  }
};

// With a single channel there is nothing to sum: input passes through.
TEST_F(SumTimeSeriesAcrossChannelsCalculatorTest, IsNoOpOnSingleChannelInputs) {
  const TimeSeriesHeader header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 num_channels: 1 num_samples: 5");
  const Matrix input =
      Matrix::Random(header.num_channels(), header.num_samples());

  Test(header, {input}, header, {input});
}

// Summing an all-ones matrix across channels gives num_channels per sample.
TEST_F(SumTimeSeriesAcrossChannelsCalculatorTest, ConstantPacket) {
  const TimeSeriesHeader header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 num_channels: 3 num_samples: 5");
  TimeSeriesHeader output_header(header);
  output_header.set_num_channels(1);

  Test(header,
       {Matrix::Constant(header.num_channels(), header.num_samples(), 1)},
       output_header,
       {Matrix::Constant(1, header.num_samples(), header.num_channels())});
}

// Linearity across several packets: scaling or offsetting the input
// scales/offsets the per-sample sums accordingly.
TEST_F(SumTimeSeriesAcrossChannelsCalculatorTest, MultiplePackets) {
  const TimeSeriesHeader header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 num_channels: 3 num_samples: 5");
  Matrix in(header.num_channels(), header.num_samples());
  in << 10, -1, -1, 0, 0, 20, -2, 0, 1, 0, 30, -3, 1, 0, 12;

  TimeSeriesHeader output_header(header);
  output_header.set_num_channels(1);
  Matrix out(1, header.num_samples());
  out << 60, -6, 0, 1, 12;

  Test(header, {in, 2 * in, in + Matrix::Constant(in.rows(), in.cols(), 3.5f)},
       output_header,
       {out, 2 * out,
        out + Matrix::Constant(out.rows(), out.cols(),
                               3.5 * header.num_channels())});
}
|
||||
|
||||
// Fixture for AverageTimeSeriesAcrossChannelsCalculator.
class AverageTimeSeriesAcrossChannelsCalculatorTest
    : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override {
    calculator_name_ = "AverageTimeSeriesAcrossChannelsCalculator";
  }
};

// With a single channel the mean is the value itself: input passes through.
TEST_F(AverageTimeSeriesAcrossChannelsCalculatorTest,
       IsNoOpOnSingleChannelInputs) {
  const TimeSeriesHeader header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 num_channels: 1 num_samples: 5");
  const Matrix input =
      Matrix::Random(header.num_channels(), header.num_samples());

  Test(header, {input}, header, {input});
}

// One channel of ones among zeros averages to 1 / num_channels per sample.
TEST_F(AverageTimeSeriesAcrossChannelsCalculatorTest, ConstantPacket) {
  const TimeSeriesHeader header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 num_channels: 3 num_samples: 5");
  TimeSeriesHeader output_header(header);
  output_header.set_num_channels(1);

  Matrix input =
      Matrix::Constant(header.num_channels(), header.num_samples(), 0.0);
  input.row(0) = Matrix::Constant(1, header.num_samples(), 1.0);

  Test(
      header, {input}, output_header,
      {Matrix::Constant(1, header.num_samples(), 1.0 / header.num_channels())});
}
|
||||
|
||||
// Fixture for SummarySaiToPitchogramCalculator.
class SummarySaiToPitchogramCalculatorTest
    : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override {
    calculator_name_ = "SummarySaiToPitchogramCalculator";
  }
};

// A 1x3 packet transposes to 3x1; time and channel axes (and their header
// fields) are swapped, with the sample rate becoming the packet rate.
TEST_F(SummarySaiToPitchogramCalculatorTest, SinglePacket) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 1 num_samples: 3");
  Matrix input(1, input_header.num_samples());
  input << 3, -9, 4;

  const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 5.0 packet_rate: 5.0 num_channels: 3 num_samples: 1");
  Matrix output(input_header.num_samples(), 1);
  output << 3, -9, 4;

  Test(input_header, {input}, output_header, {output});
}
|
||||
|
||||
// Fixture for ReverseChannelOrderCalculator.
class ReverseChannelOrderCalculatorTest
    : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override { calculator_name_ = "ReverseChannelOrderCalculator"; }
};

// Reversing a single channel changes nothing: input passes through.
TEST_F(ReverseChannelOrderCalculatorTest, IsNoOpOnSingleChannelInputs) {
  const TimeSeriesHeader header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 num_channels: 1 num_samples: 5");
  const Matrix input =
      Matrix::Random(header.num_channels(), header.num_samples());

  Test(header, {input}, header, {input});
}

// Channels 1..5 / -1..-5 come out as 5..1 / -5..-1 within each sample.
TEST_F(ReverseChannelOrderCalculatorTest, SinglePacket) {
  const TimeSeriesHeader header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 num_channels: 5 num_samples: 2");
  Matrix input(header.num_channels(), header.num_samples());
  input.transpose() << 1, 2, 3, 4, 5, -1, -2, -3, -4, -5;
  Matrix output(header.num_channels(), header.num_samples());
  output.transpose() << 5, 4, 3, 2, 1, -5, -4, -3, -2, -1;

  Test(header, {input}, header, {output});
}
|
||||
|
||||
// Fixture for FlattenPacketCalculator.
class FlattenPacketCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override { calculator_name_ = "FlattenPacketCalculator"; }
};

// A 5-channel x 2-sample packet flattens to a single 10-channel sample,
// with whole samples stacked (not channel-interleaved).
TEST_F(FlattenPacketCalculatorTest, SinglePacket) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 2");
  Matrix input(input_header.num_channels(), input_header.num_samples());
  input.transpose() << 1, 2, 3, 4, 5, -1, -2, -3, -4, -5;
  Matrix output(10, 1);
  output << 1, 2, 3, 4, 5, -1, -2, -3, -4, -5;

  const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 10.0 packet_rate: 10.0 num_channels: 10 num_samples: 1");
  Test(input_header, {input}, output_header, {output});
}
|
||||
|
||||
// Fixture for SubtractMeanCalculator.
class SubtractMeanCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override { calculator_name_ = "SubtractMeanCalculator"; }
};

// Each channel (row) is centered on its own per-packet mean; header is
// unchanged.
TEST_F(SubtractMeanCalculatorTest, SinglePacket) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 2");
  Matrix input(input_header.num_channels(), input_header.num_samples());
  Matrix output(input_header.num_channels(), input_header.num_samples());

  // clang-format off
  input.transpose() << 1,  0,  3, 0, 1,
                      -1, -2, -3, 4, 7;
  output.transpose() << 1,  1,  3, -2, -3,
                       -1, -1, -3,  2,  3;
  // clang-format on

  const TimeSeriesHeader output_header = input_header;
  Test(input_header, {input}, output_header, {output});
}
|
||||
|
||||
// Fixture for SubtractMeanAcrossChannelsCalculator.
class SubtractMeanAcrossChannelsCalculatorTest
    : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override {
    calculator_name_ = "SubtractMeanAcrossChannelsCalculator";
  }
};

// Values 1..6 have grand mean 3.5; every element is shifted by -3.5.
TEST_F(SubtractMeanAcrossChannelsCalculatorTest, SinglePacket) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
  TimeSeriesHeader output_header(input_header);
  output_header.set_num_samples(2);

  Matrix input(input_header.num_channels(), input_header.num_samples());
  Matrix output(output_header.num_channels(), output_header.num_samples());

  // clang-format off
  input.transpose() << 1.0, 2.0, 3.0,
                       4.0, 5.0, 6.0;
  output.transpose() << 1.0 - 3.5, 2.0 - 3.5, 3.0 - 3.5,
                        4.0 - 3.5, 5.0 - 3.5, 6.0 - 3.5;
  // clang-format on

  Test(input_header, {input}, output_header, {output});
}
|
||||
|
||||
// Fixture for DivideByMeanAcrossChannelsCalculator.
class DivideByMeanAcrossChannelsCalculatorTest
    : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override {
    calculator_name_ = "DivideByMeanAcrossChannelsCalculator";
  }
};

// Values 1..6 have grand mean 3.5; every element is divided by 3.5.
TEST_F(DivideByMeanAcrossChannelsCalculatorTest, SinglePacket) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
  Matrix input(input_header.num_channels(), input_header.num_samples());
  input.transpose() << 1.0, 2.0, 3.0, 4.0, 5.0, 6.0;

  TimeSeriesHeader output_header(input_header);
  output_header.set_num_samples(2);
  Matrix output(output_header.num_channels(), output_header.num_samples());
  output.transpose() << 1.0 / 3.5, 2.0 / 3.5, 3.0 / 3.5, 4.0 / 3.5, 5.0 / 3.5,
      6.0 / 3.5;

  Test(input_header, {input}, output_header, {output});
}

// A symmetric matrix with grand mean exactly zero triggers the documented
// all-ones fallback instead of a division by zero.
TEST_F(DivideByMeanAcrossChannelsCalculatorTest, ReturnsOneForZeroMean) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
  Matrix input(input_header.num_channels(), input_header.num_samples());
  input.transpose() << -3.0, -2.0, -1.0, 1.0, 2.0, 3.0;

  TimeSeriesHeader output_header(input_header);
  output_header.set_num_samples(2);
  Matrix output(output_header.num_channels(), output_header.num_samples());
  output.transpose() << 1.0, 1.0, 1.0, 1.0, 1.0, 1.0;

  Test(input_header, {input}, output_header, {output});
}
|
||||
|
||||
class MeanCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "MeanCalculator"; }
|
||||
};
|
||||
|
||||
TEST_F(MeanCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
input.transpose() << 1.0, 2.0, 3.0, 4.0, 5.0, 6.0;
|
||||
|
||||
TimeSeriesHeader output_header(input_header);
|
||||
output_header.set_num_samples(1);
|
||||
output_header.set_sample_rate(10.0);
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
output << (1.0 + 4.0) / 2, (2.0 + 5.0) / 2, (3.0 + 6.0) / 2;
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
class StandardDeviationCalculatorTest
|
||||
: public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "StandardDeviationCalculator"; }
|
||||
};
|
||||
|
||||
TEST_F(StandardDeviationCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
input.transpose() << 0.0, 2.0, 3.0, 4.0, 5.0, 8.0;
|
||||
|
||||
TimeSeriesHeader output_header(input_header);
|
||||
output_header.set_sample_rate(10.0);
|
||||
output_header.set_num_samples(1);
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
output << sqrt((pow(0.0 - 2.0, 2) + pow(4.0 - 2.0, 2)) / 2),
|
||||
sqrt((pow(2.0 - 3.5, 2) + pow(5.0 - 3.5, 2)) / 2),
|
||||
sqrt((pow(3.0 - 5.5, 2) + pow(8.0 - 5.5, 2)) / 2);
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
class CovarianceCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "CovarianceCalculator"; }
|
||||
};
|
||||
|
||||
TEST_F(CovarianceCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
|
||||
// We'll specify in transposed form so we can write one channel at a time.
|
||||
input << 1.0, 3.0, 5.0, 9.0, -1.0, -3.0;
|
||||
|
||||
TimeSeriesHeader output_header(input_header);
|
||||
output_header.set_num_samples(output_header.num_channels());
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
output << 1, 2, -1, 2, 4, -2, -1, -2, 1;
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
class L2NormCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "L2NormCalculator"; }
|
||||
};
|
||||
|
||||
TEST_F(L2NormCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
input << 3, 5, 8, 4, 12, -15;
|
||||
|
||||
const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 1 num_samples: 3");
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
output << 5, 13, 17;
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
class L2NormalizeColumnCalculatorTest
|
||||
: public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "L2NormalizeColumnCalculator"; }
|
||||
};
|
||||
|
||||
TEST_F(L2NormalizeColumnCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
input << 0.3, 0.4, 0.8, 0.5, 0.9, 0.8;
|
||||
|
||||
const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
|
||||
// The values in output are column-wise L2 normalized
|
||||
// e.g.
|
||||
// |a| -> |a/sqrt(a^2 + b^2)|
|
||||
// |b| |b/sqrt(a^2 + b^2)|
|
||||
output << 0.51449579000473022, 0.40613847970962524, 0.70710676908493042,
|
||||
0.85749292373657227, 0.91381156444549561, 0.70710676908493042;
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
class L2NormalizeCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "L2NormalizeCalculator"; }
|
||||
};
|
||||
|
||||
TEST_F(L2NormalizeCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
input << 0.3, 0.4, 0.8, 0.5, 0.9, 0.8;
|
||||
|
||||
const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
|
||||
// The values in output are L2 normalized
|
||||
// a -> a/sqrt(a^2 + b^2 + c^2 + ...) * sqrt(matrix.cols()*matrix.rows())
|
||||
output << 0.45661166, 0.60881555, 1.21763109, 0.76101943, 1.36983498,
|
||||
1.21763109;
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
TEST_F(L2NormalizeCalculatorTest, UnitMatrixStaysUnchanged) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 3 num_samples: 5");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
input << 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
|
||||
1.0, -1.0, 1.0;
|
||||
|
||||
Test(input_header, {input}, input_header, {input});
|
||||
}
|
||||
|
||||
class PeakNormalizeCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "PeakNormalizeCalculator"; }
|
||||
};
|
||||
|
||||
TEST_F(PeakNormalizeCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
input << 0.3, 0.4, 0.8, 0.5, 0.9, 0.8;
|
||||
|
||||
const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
output << 0.33333333, 0.44444444, 0.88888889, 0.55555556, 1.0, 0.88888889;
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
TEST_F(PeakNormalizeCalculatorTest, UnitMatrixStaysUnchanged) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 3 num_samples: 5");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
input << 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
|
||||
1.0, -1.0, 1.0;
|
||||
|
||||
Test(input_header, {input}, input_header, {input});
|
||||
}
|
||||
|
||||
class ElementwiseSquareCalculatorTest
|
||||
: public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "ElementwiseSquareCalculator"; }
|
||||
};
|
||||
|
||||
TEST_F(ElementwiseSquareCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
input << 3, 5, 8, 4, 12, -15;
|
||||
|
||||
const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
output << 9, 25, 64, 16, 144, 225;
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
class FirstHalfSlicerCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "FirstHalfSlicerCalculator"; }
|
||||
};
|
||||
|
||||
TEST_F(FirstHalfSlicerCalculatorTest, SinglePacketEvenNumSamples) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 2");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
// clang-format off
|
||||
input.transpose() << 0, 1, 2, 3, 4,
|
||||
5, 6, 7, 8, 9;
|
||||
// clang-format on
|
||||
|
||||
const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 1");
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
output.transpose() << 0, 1, 2, 3, 4;
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
TEST_F(FirstHalfSlicerCalculatorTest, SinglePacketOddNumSamples) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 3");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
// clang-format off
|
||||
input.transpose() << 0, 1, 2, 3, 4,
|
||||
5, 6, 7, 8, 9,
|
||||
0, 0, 0, 0, 0;
|
||||
// clang-format on
|
||||
|
||||
const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 1");
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
output.transpose() << 0, 1, 2, 3, 4;
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
TEST_F(FirstHalfSlicerCalculatorTest, MultiplePackets) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 2");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
// clang-format off
|
||||
input.transpose() << 0, 1, 2, 3, 4,
|
||||
5, 6, 7, 8, 9;
|
||||
// clang-format on
|
||||
const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 1");
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
output.transpose() << 0, 1, 2, 3, 4;
|
||||
|
||||
Test(input_header,
|
||||
{input, 2 * input,
|
||||
input + Matrix::Constant(input.rows(), input.cols(), 3.5f)},
|
||||
output_header,
|
||||
{output, 2 * output,
|
||||
output + Matrix::Constant(output.rows(), output.cols(), 3.5f)});
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
278
mediapipe/calculators/audio/mfcc_mel_calculators.cc
Normal file
278
mediapipe/calculators/audio/mfcc_mel_calculators.cc
Normal file
|
@ -0,0 +1,278 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// MediaPipe Calculator wrapper around audio/dsp/mfcc/
|
||||
// classes MelFilterbank (magnitude spectrograms warped to the Mel
|
||||
// approximation of the auditory frequency scale) and Mfcc (Mel Frequency
|
||||
// Cepstral Coefficients, the decorrelated transform of log-Mel-spectrum
|
||||
// commonly used as acoustic features in speech and other audio tasks.
|
||||
// Both calculators expect as input the SQUARED_MAGNITUDE-domain outputs
|
||||
// from the MediaPipe SpectrogramCalculator object.
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "audio/dsp/mfcc/mel_filterbank.h"
|
||||
#include "audio/dsp/mfcc/mfcc.h"
|
||||
#include "mediapipe/calculators/audio/mfcc_mel_calculators.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/util/time_series_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
// Portable version of TimeSeriesHeader's DebugString.
|
||||
std::string PortableDebugString(const TimeSeriesHeader& header) {
|
||||
std::string unsubstituted_header_debug_str = R"(
|
||||
sample_rate: $0
|
||||
num_channels: $1
|
||||
num_samples: $2
|
||||
packet_rate: $3
|
||||
audio_sample_rate: $4
|
||||
)";
|
||||
return absl::Substitute(unsubstituted_header_debug_str, header.sample_rate(),
|
||||
header.num_channels(), header.num_samples(),
|
||||
header.packet_rate(), header.audio_sample_rate());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Abstract base class for Calculators that transform feature vectors on a
|
||||
// frame-by-frame basis.
|
||||
// Subclasses must override pure virtual methods ConfigureTransform and
|
||||
// TransformFrame.
|
||||
// Input and output MediaPipe packets are matrices with one column per frame,
|
||||
// and one row per feature dimension. Each input packet results in an
|
||||
// output packet with the same number of columns (but differing numbers of
|
||||
// rows corresponding to the new feature space).
|
||||
class FramewiseTransformCalculatorBase : public CalculatorBase {
|
||||
public:
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).Set<Matrix>(
|
||||
// Sequence of Matrices, each column describing a particular time frame,
|
||||
// each row a feature dimension, with TimeSeriesHeader.
|
||||
);
|
||||
cc->Outputs().Index(0).Set<Matrix>(
|
||||
// Sequence of Matrices, each column describing a particular time frame,
|
||||
// each row a feature dimension, with TimeSeriesHeader.
|
||||
);
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
::mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
int num_output_channels(void) { return num_output_channels_; }
|
||||
|
||||
void set_num_output_channels(int num_output_channels) {
|
||||
num_output_channels_ = num_output_channels;
|
||||
}
|
||||
|
||||
private:
|
||||
// Takes header and options, and sets up state including calling
|
||||
// set_num_output_channels() on the base object.
|
||||
virtual ::mediapipe::Status ConfigureTransform(
|
||||
const TimeSeriesHeader& header, const CalculatorOptions& options) = 0;
|
||||
|
||||
// Takes a vector<double> corresponding to an input frame, and
|
||||
// perform the specific transformation to produce an output frame.
|
||||
virtual void TransformFrame(const std::vector<double>& input,
|
||||
std::vector<double>* output) const = 0;
|
||||
|
||||
private:
|
||||
int num_output_channels_;
|
||||
};
|
||||
|
||||
::mediapipe::Status FramewiseTransformCalculatorBase::Open(
|
||||
CalculatorContext* cc) {
|
||||
TimeSeriesHeader input_header;
|
||||
RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
|
||||
cc->Inputs().Index(0).Header(), &input_header));
|
||||
|
||||
::mediapipe::Status status = ConfigureTransform(input_header, cc->Options());
|
||||
|
||||
auto output_header = new TimeSeriesHeader(input_header);
|
||||
output_header->set_num_channels(num_output_channels_);
|
||||
cc->Outputs().Index(0).SetHeader(Adopt(output_header));
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
::mediapipe::Status FramewiseTransformCalculatorBase::Process(
|
||||
CalculatorContext* cc) {
|
||||
const Matrix& input = cc->Inputs().Index(0).Get<Matrix>();
|
||||
const int num_frames = input.cols();
|
||||
std::unique_ptr<Matrix> output(new Matrix(num_output_channels_, num_frames));
|
||||
// The main work here is converting each column of the float Matrix
|
||||
// into a vector of doubles, which is what our target functions from
|
||||
// dsp_core consume, and doing the reverse with their output.
|
||||
std::vector<double> input_frame(input.rows());
|
||||
std::vector<double> output_frame(num_output_channels_);
|
||||
|
||||
for (int frame = 0; frame < num_frames; ++frame) {
|
||||
// Copy input from Eigen::Matrix column to vector<float>.
|
||||
Eigen::Map<Eigen::MatrixXd> input_frame_map(&input_frame[0],
|
||||
input_frame.size(), 1);
|
||||
input_frame_map = input.col(frame).cast<double>();
|
||||
|
||||
// Perform the actual transformation.
|
||||
TransformFrame(input_frame, &output_frame);
|
||||
|
||||
// Copy output from vector<float> to Eigen::Vector.
|
||||
CHECK_EQ(output_frame.size(), num_output_channels_);
|
||||
Eigen::Map<const Eigen::MatrixXd> output_frame_map(&output_frame[0],
|
||||
output_frame.size(), 1);
|
||||
output->col(frame) = output_frame_map.cast<float>();
|
||||
}
|
||||
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
// Calculator wrapper around the dsp/mfcc/mfcc.cc routine.
|
||||
// Take frames of squared-magnitude spectra from the SpectrogramCalculator
|
||||
// and convert them into Mel Frequency Cepstral Coefficients.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "MfccCalculator"
|
||||
// input_stream: "spectrogram_frames_stream"
|
||||
// output_stream: "mfcc_frames_stream"
|
||||
// options {
|
||||
// [mediapipe.MfccCalculatorOptions.ext] {
|
||||
// mel_spectrum_params {
|
||||
// channel_count: 20
|
||||
// min_frequency_hertz: 125.0
|
||||
// max_frequency_hertz: 3800.0
|
||||
// }
|
||||
// mfcc_count: 13
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class MfccCalculator : public FramewiseTransformCalculatorBase {
|
||||
public:
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
return FramewiseTransformCalculatorBase::GetContract(cc);
|
||||
}
|
||||
|
||||
private:
|
||||
::mediapipe::Status ConfigureTransform(
|
||||
const TimeSeriesHeader& header,
|
||||
const CalculatorOptions& options) override {
|
||||
MfccCalculatorOptions mfcc_options;
|
||||
time_series_util::FillOptionsExtensionOrDie(options, &mfcc_options);
|
||||
mfcc_.reset(new audio_dsp::Mfcc());
|
||||
int input_length = header.num_channels();
|
||||
// Set up the parameters to the Mfcc object.
|
||||
set_num_output_channels(mfcc_options.mfcc_count());
|
||||
mfcc_->set_dct_coefficient_count(num_output_channels());
|
||||
mfcc_->set_upper_frequency_limit(
|
||||
mfcc_options.mel_spectrum_params().max_frequency_hertz());
|
||||
mfcc_->set_lower_frequency_limit(
|
||||
mfcc_options.mel_spectrum_params().min_frequency_hertz());
|
||||
mfcc_->set_filterbank_channel_count(
|
||||
mfcc_options.mel_spectrum_params().channel_count());
|
||||
// An upstream calculator (such as SpectrogramCalculator) must store
|
||||
// the sample rate of its input audio waveform in the TimeSeries Header.
|
||||
// audio_dsp::MelFilterBank needs to know this to
|
||||
// correctly interpret the spectrogram bins.
|
||||
if (!header.has_audio_sample_rate()) {
|
||||
return ::mediapipe::InvalidArgumentError(
|
||||
absl::StrCat("No audio_sample_rate in input TimeSeriesHeader ",
|
||||
PortableDebugString(header)));
|
||||
}
|
||||
// Now we can initialize the Mfcc object.
|
||||
bool initialized =
|
||||
mfcc_->Initialize(input_length, header.audio_sample_rate());
|
||||
|
||||
if (initialized) {
|
||||
return ::mediapipe::OkStatus();
|
||||
} else {
|
||||
return ::mediapipe::Status(mediapipe::StatusCode::kInternal,
|
||||
"Mfcc::Initialize returned uninitialized");
|
||||
}
|
||||
}
|
||||
|
||||
void TransformFrame(const std::vector<double>& input,
|
||||
std::vector<double>* output) const override {
|
||||
mfcc_->Compute(input, output);
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<audio_dsp::Mfcc> mfcc_;
|
||||
};
|
||||
REGISTER_CALCULATOR(MfccCalculator);
|
||||
|
||||
// Calculator wrapper around the dsp/mfcc/mel_filterbank.cc routine.
|
||||
// Take frames of squared-magnitude spectra from the SpectrogramCalculator
|
||||
// and convert them into Mel-warped (linear-magnitude) spectra.
|
||||
// Note: This code computes a mel-frequency filterbank, using a simple
|
||||
// algorithm that gives bad results (some mel channels that are always zero)
|
||||
// if you ask for too many channels.
|
||||
class MelSpectrumCalculator : public FramewiseTransformCalculatorBase {
|
||||
public:
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
return FramewiseTransformCalculatorBase::GetContract(cc);
|
||||
}
|
||||
|
||||
private:
|
||||
::mediapipe::Status ConfigureTransform(
|
||||
const TimeSeriesHeader& header,
|
||||
const CalculatorOptions& options) override {
|
||||
MelSpectrumCalculatorOptions mel_spectrum_options;
|
||||
time_series_util::FillOptionsExtensionOrDie(options, &mel_spectrum_options);
|
||||
mel_filterbank_.reset(new audio_dsp::MelFilterbank());
|
||||
int input_length = header.num_channels();
|
||||
set_num_output_channels(mel_spectrum_options.channel_count());
|
||||
// An upstream calculator (such as SpectrogramCalculator) must store
|
||||
// the sample rate of its input audio waveform in the TimeSeries Header.
|
||||
// audio_dsp::MelFilterBank needs to know this to
|
||||
// correctly interpret the spectrogram bins.
|
||||
if (!header.has_audio_sample_rate()) {
|
||||
return ::mediapipe::InvalidArgumentError(
|
||||
absl::StrCat("No audio_sample_rate in input TimeSeriesHeader ",
|
||||
PortableDebugString(header)));
|
||||
}
|
||||
bool initialized = mel_filterbank_->Initialize(
|
||||
input_length, header.audio_sample_rate(), num_output_channels(),
|
||||
mel_spectrum_options.min_frequency_hertz(),
|
||||
mel_spectrum_options.max_frequency_hertz());
|
||||
|
||||
if (initialized) {
|
||||
return ::mediapipe::OkStatus();
|
||||
} else {
|
||||
return ::mediapipe::Status(mediapipe::StatusCode::kInternal,
|
||||
"mfcc::Initialize returned uninitialized");
|
||||
}
|
||||
}
|
||||
|
||||
void TransformFrame(const std::vector<double>& input,
|
||||
std::vector<double>* output) const override {
|
||||
mel_filterbank_->Compute(input, output);
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<audio_dsp::MelFilterbank> mel_filterbank_;
|
||||
};
|
||||
REGISTER_CALCULATOR(MelSpectrumCalculator);
|
||||
|
||||
} // namespace mediapipe
|
50
mediapipe/calculators/audio/mfcc_mel_calculators.proto
Normal file
50
mediapipe/calculators/audio/mfcc_mel_calculators.proto
Normal file
|
@ -0,0 +1,50 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
message MelSpectrumCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional MelSpectrumCalculatorOptions ext = 78581812;
|
||||
}
|
||||
// The fields are to populate the config parameters in
|
||||
// audio/dsp/mfcc/mel_filterbank.h
|
||||
// but the names are chose to mirror
|
||||
// audio/hearing/filterbanks/cochlea_gammatone_filterbank.proto
|
||||
// and the default values match those in
|
||||
// speech/greco3/frontend/filter_bank.proto .
|
||||
|
||||
// Total number of frequency bands to use.
|
||||
optional int32 channel_count = 1 [default = 20];
|
||||
// Lower edge of lowest triangular Mel band.
|
||||
optional float min_frequency_hertz = 2 [default = 125.0];
|
||||
// Upper edge of highest triangular Mel band.
|
||||
optional float max_frequency_hertz = 3 [default = 3800.0];
|
||||
}
|
||||
|
||||
message MfccCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional MfccCalculatorOptions ext = 78450441;
|
||||
}
|
||||
|
||||
// Specification of the underlying mel filterbank.
|
||||
optional MelSpectrumCalculatorOptions mel_spectrum_params = 1;
|
||||
|
||||
// How many MFCC coefficients to emit.
|
||||
optional uint32 mfcc_count = 2 [default = 13];
|
||||
}
|
149
mediapipe/calculators/audio/mfcc_mel_calculators_test.cc
Normal file
149
mediapipe/calculators/audio/mfcc_mel_calculators_test.cc
Normal file
|
@ -0,0 +1,149 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "mediapipe/calculators/audio/mfcc_mel_calculators.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/util/time_series_test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Use a sample rate that is unlikely to be a default somewhere.
|
||||
const float kAudioSampleRate = 8800.0;
|
||||
|
||||
template <typename OptionsType, const char* CalculatorName>
|
||||
class FramewiseTransformCalculatorTest
|
||||
: public TimeSeriesCalculatorTest<OptionsType> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
this->calculator_name_ = CalculatorName;
|
||||
this->num_input_channels_ = 129;
|
||||
// This is the frame rate coming out of the SpectrogramCalculator.
|
||||
this->input_sample_rate_ = 100.0;
|
||||
}
|
||||
|
||||
// Returns the number of samples per packet.
|
||||
int GenerateRandomNonnegInputStream(int num_packets) {
|
||||
const double kSecondsPerPacket = 0.2;
|
||||
const int num_samples_per_packet =
|
||||
kSecondsPerPacket * this->input_sample_rate_;
|
||||
for (int i = 0; i < num_packets; ++i) {
|
||||
const int timestamp =
|
||||
i * kSecondsPerPacket * Timestamp::kTimestampUnitsPerSecond;
|
||||
// Mfcc, MelSpectrum expect squared-magnitude inputs, so make
|
||||
// sure the input data has no negative values.
|
||||
Matrix* sqdata = this->NewRandomMatrix(this->num_input_channels_,
|
||||
num_samples_per_packet);
|
||||
*sqdata = sqdata->array().square();
|
||||
this->AppendInputPacket(sqdata, timestamp);
|
||||
}
|
||||
return num_samples_per_packet;
|
||||
}
|
||||
|
||||
void CheckOutputPacketMetadata(int expected_num_channels,
|
||||
int expected_num_samples_per_packet) {
|
||||
int expected_timestamp = 0;
|
||||
for (const auto& packet : this->output().packets) {
|
||||
EXPECT_EQ(expected_timestamp, packet.Timestamp().Value());
|
||||
expected_timestamp += expected_num_samples_per_packet /
|
||||
this->input_sample_rate_ *
|
||||
Timestamp::kTimestampUnitsPerSecond;
|
||||
|
||||
const Matrix& output_matrix = packet.template Get<Matrix>();
|
||||
|
||||
EXPECT_EQ(output_matrix.rows(), expected_num_channels);
|
||||
EXPECT_EQ(output_matrix.cols(), expected_num_samples_per_packet);
|
||||
}
|
||||
}
|
||||
|
||||
void SetupGraphAndHeader() {
|
||||
this->InitializeGraph();
|
||||
this->FillInputHeader();
|
||||
}
|
||||
|
||||
// Argument is the expected number of dimensions (channels, columns) in
|
||||
// the output data from the Calculator under test, which the test should
|
||||
// know.
|
||||
void SetupRandomInputPackets() {
|
||||
constexpr int kNumPackets = 5;
|
||||
num_samples_per_packet_ = GenerateRandomNonnegInputStream(kNumPackets);
|
||||
}
|
||||
|
||||
::mediapipe::Status Run() { return this->RunGraph(); }
|
||||
|
||||
void CheckResults(int expected_num_channels) {
|
||||
const auto& output_header =
|
||||
this->output().header.template Get<TimeSeriesHeader>();
|
||||
EXPECT_EQ(this->input_sample_rate_, output_header.sample_rate());
|
||||
CheckOutputPacketMetadata(expected_num_channels, num_samples_per_packet_);
|
||||
|
||||
// Sanity check that output packets have non-zero energy.
|
||||
for (const auto& packet : this->output().packets) {
|
||||
const Matrix& data = packet.template Get<Matrix>();
|
||||
EXPECT_GT(data.squaredNorm(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Allows SetupRandomInputPackets() to inform CheckResults() about how
|
||||
// big the packets are supposed to be.
|
||||
int num_samples_per_packet_;
|
||||
};
|
||||
|
||||
constexpr char kMfccCalculator[] = "MfccCalculator";
|
||||
typedef FramewiseTransformCalculatorTest<MfccCalculatorOptions, kMfccCalculator>
|
||||
MfccCalculatorTest;
|
||||
TEST_F(MfccCalculatorTest, AudioSampleRateFromInputHeader) {
|
||||
audio_sample_rate_ = kAudioSampleRate;
|
||||
SetupGraphAndHeader();
|
||||
SetupRandomInputPackets();
|
||||
|
||||
MEDIAPIPE_EXPECT_OK(Run());
|
||||
|
||||
CheckResults(options_.mfcc_count());
|
||||
}
|
||||
TEST_F(MfccCalculatorTest, NoAudioSampleRate) {
|
||||
// Leave audio_sample_rate_ == kUnset, so it is not present in the
|
||||
// input TimeSeriesHeader; expect failure.
|
||||
SetupGraphAndHeader();
|
||||
SetupRandomInputPackets();
|
||||
|
||||
EXPECT_FALSE(Run().ok());
|
||||
}
|
||||
|
||||
constexpr char kMelSpectrumCalculator[] = "MelSpectrumCalculator";
|
||||
typedef FramewiseTransformCalculatorTest<MelSpectrumCalculatorOptions,
|
||||
kMelSpectrumCalculator>
|
||||
MelSpectrumCalculatorTest;
|
||||
TEST_F(MelSpectrumCalculatorTest, AudioSampleRateFromInputHeader) {
|
||||
audio_sample_rate_ = kAudioSampleRate;
|
||||
SetupGraphAndHeader();
|
||||
SetupRandomInputPackets();
|
||||
|
||||
MEDIAPIPE_EXPECT_OK(Run());
|
||||
|
||||
CheckResults(options_.channel_count());
|
||||
}
|
||||
TEST_F(MelSpectrumCalculatorTest, NoAudioSampleRate) {
|
||||
// Leave audio_sample_rate_ == kUnset, so it is not present in the
|
||||
// input TimeSeriesHeader; expect failure.
|
||||
SetupGraphAndHeader();
|
||||
SetupRandomInputPackets();
|
||||
|
||||
EXPECT_FALSE(Run().ok());
|
||||
}
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,197 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Defines RationalFactorResampleCalculator.
|
||||
|
||||
#include "mediapipe/calculators/audio/rational_factor_resample_calculator.h"
|
||||
|
||||
#include "audio/dsp/resampler_rational_factor.h"
|
||||
|
||||
using audio_dsp::DefaultResamplingKernel;
|
||||
using audio_dsp::RationalFactorResampler;
|
||||
using audio_dsp::Resampler;
|
||||
|
||||
namespace mediapipe {
|
||||
::mediapipe::Status RationalFactorResampleCalculator::Process(
|
||||
CalculatorContext* cc) {
|
||||
return ProcessInternal(cc->Inputs().Index(0).Get<Matrix>(), false, cc);
|
||||
}
|
||||
|
||||
::mediapipe::Status RationalFactorResampleCalculator::Close(
|
||||
CalculatorContext* cc) {
|
||||
if (initial_timestamp_ == Timestamp::Unstarted()) {
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
Matrix empty_input_frame(num_channels_, 0);
|
||||
return ProcessInternal(empty_input_frame, true, cc);
|
||||
}
|
||||
|
||||
namespace {
|
||||
void CopyChannelToVector(const Matrix& matrix, int channel,
|
||||
std::vector<float>* vec) {
|
||||
vec->clear();
|
||||
vec->reserve(matrix.cols());
|
||||
for (int sample = 0; sample < matrix.cols(); ++sample) {
|
||||
vec->push_back(matrix(channel, sample));
|
||||
}
|
||||
}
|
||||
|
||||
void CopyVectorToChannel(const std::vector<float>& vec, Matrix* matrix,
|
||||
int channel) {
|
||||
if (matrix->cols() == 0) {
|
||||
matrix->resize(matrix->rows(), vec.size());
|
||||
} else {
|
||||
CHECK_EQ(vec.size(), matrix->cols());
|
||||
CHECK_LT(channel, matrix->rows());
|
||||
}
|
||||
for (int sample = 0; sample < matrix->cols(); ++sample) {
|
||||
(*matrix)(channel, sample) = vec[sample];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
::mediapipe::Status RationalFactorResampleCalculator::Open(
|
||||
CalculatorContext* cc) {
|
||||
RationalFactorResampleCalculatorOptions resample_options;
|
||||
time_series_util::FillOptionsExtensionOrDie(cc->Options(), &resample_options);
|
||||
|
||||
if (!resample_options.has_target_sample_rate()) {
|
||||
return tool::StatusInvalid(
|
||||
"resample_options doesn't have target_sample_rate.");
|
||||
}
|
||||
target_sample_rate_ = resample_options.target_sample_rate();
|
||||
|
||||
TimeSeriesHeader input_header;
|
||||
RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
|
||||
cc->Inputs().Index(0).Header(), &input_header));
|
||||
|
||||
source_sample_rate_ = input_header.sample_rate();
|
||||
num_channels_ = input_header.num_channels();
|
||||
|
||||
// Don't create resamplers for pass-thru (sample rates are equal).
|
||||
if (source_sample_rate_ != target_sample_rate_) {
|
||||
resampler_.resize(num_channels_);
|
||||
for (auto& r : resampler_) {
|
||||
r = ResamplerFromOptions(source_sample_rate_, target_sample_rate_,
|
||||
resample_options);
|
||||
if (!r) {
|
||||
LOG(ERROR) << "Failed to initialize resampler.";
|
||||
return ::mediapipe::UnknownError("Failed to initialize resampler.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TimeSeriesHeader* output_header = new TimeSeriesHeader(input_header);
|
||||
output_header->set_sample_rate(target_sample_rate_);
|
||||
// The resampler doesn't make guarantees about how many samples will
|
||||
// be in each packet.
|
||||
output_header->clear_packet_rate();
|
||||
output_header->clear_num_samples();
|
||||
|
||||
cc->Outputs().Index(0).SetHeader(Adopt(output_header));
|
||||
cumulative_output_samples_ = 0;
|
||||
cumulative_input_samples_ = 0;
|
||||
initial_timestamp_ = Timestamp::Unstarted();
|
||||
check_inconsistent_timestamps_ =
|
||||
resample_options.check_inconsistent_timestamps();
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status RationalFactorResampleCalculator::ProcessInternal(
|
||||
const Matrix& input_frame, bool should_flush, CalculatorContext* cc) {
|
||||
if (initial_timestamp_ == Timestamp::Unstarted()) {
|
||||
initial_timestamp_ = cc->InputTimestamp();
|
||||
}
|
||||
|
||||
if (check_inconsistent_timestamps_) {
|
||||
time_series_util::LogWarningIfTimestampIsInconsistent(
|
||||
cc->InputTimestamp(), initial_timestamp_, cumulative_input_samples_,
|
||||
source_sample_rate_);
|
||||
}
|
||||
Timestamp output_timestamp =
|
||||
initial_timestamp_ + ((cumulative_output_samples_ / target_sample_rate_) *
|
||||
Timestamp::kTimestampUnitsPerSecond);
|
||||
|
||||
cumulative_input_samples_ += input_frame.cols();
|
||||
std::unique_ptr<Matrix> output_frame(new Matrix(num_channels_, 0));
|
||||
if (resampler_.empty()) {
|
||||
// Sample rates were same for input and output; pass-thru.
|
||||
*output_frame = input_frame;
|
||||
} else {
|
||||
if (!Resample(input_frame, output_frame.get(), should_flush)) {
|
||||
return ::mediapipe::UnknownError("Resample() failed.");
|
||||
}
|
||||
}
|
||||
cumulative_output_samples_ += output_frame->cols();
|
||||
|
||||
if (output_frame->cols() > 0) {
|
||||
cc->Outputs().Index(0).Add(output_frame.release(), output_timestamp);
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
bool RationalFactorResampleCalculator::Resample(const Matrix& input_frame,
|
||||
Matrix* output_frame,
|
||||
bool should_flush) {
|
||||
std::vector<float> input_vector;
|
||||
std::vector<float> output_vector;
|
||||
for (int i = 0; i < input_frame.rows(); ++i) {
|
||||
CopyChannelToVector(input_frame, i, &input_vector);
|
||||
if (should_flush) {
|
||||
resampler_[i]->Flush(&output_vector);
|
||||
} else {
|
||||
resampler_[i]->ProcessSamples(input_vector, &output_vector);
|
||||
}
|
||||
CopyVectorToChannel(output_vector, output_frame, i);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// static
|
||||
std::unique_ptr<Resampler<float>>
|
||||
RationalFactorResampleCalculator::ResamplerFromOptions(
|
||||
const double source_sample_rate, const double target_sample_rate,
|
||||
const RationalFactorResampleCalculatorOptions& options) {
|
||||
std::unique_ptr<Resampler<float>> resampler;
|
||||
const auto& rational_factor_options =
|
||||
options.resampler_rational_factor_options();
|
||||
std::unique_ptr<DefaultResamplingKernel> kernel;
|
||||
if (rational_factor_options.has_radius() &&
|
||||
rational_factor_options.has_cutoff() &&
|
||||
rational_factor_options.has_kaiser_beta()) {
|
||||
kernel = absl::make_unique<DefaultResamplingKernel>(
|
||||
source_sample_rate, target_sample_rate,
|
||||
rational_factor_options.radius(), rational_factor_options.cutoff(),
|
||||
rational_factor_options.kaiser_beta());
|
||||
} else {
|
||||
kernel = absl::make_unique<DefaultResamplingKernel>(source_sample_rate,
|
||||
target_sample_rate);
|
||||
}
|
||||
|
||||
// Set large enough so that the resampling factor between common sample
|
||||
// rates (e.g. 8kHz, 16kHz, 22.05kHz, 32kHz, 44.1kHz, 48kHz) is exact, and
|
||||
// that any factor is represented with error less than 0.025%.
|
||||
const int kMaxDenominator = 2000;
|
||||
resampler = absl::make_unique<RationalFactorResampler<float>>(
|
||||
*kernel, kMaxDenominator);
|
||||
if (resampler != nullptr && !resampler->Valid()) {
|
||||
resampler = std::unique_ptr<Resampler<float>>();
|
||||
}
|
||||
return resampler;
|
||||
}
|
||||
|
||||
REGISTER_CALCULATOR(RationalFactorResampleCalculator);
|
||||
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,106 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef MEDIAPIPE_CALCULATORS_AUDIO_RATIONAL_FACTOR_RESAMPLE_CALCULATOR_H_
|
||||
#define MEDIAPIPE_CALCULATORS_AUDIO_RATIONAL_FACTOR_RESAMPLE_CALCULATOR_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "audio/dsp/resampler.h"
|
||||
#include "mediapipe/calculators/audio/rational_factor_resample_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/util/time_series_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
// MediaPipe Calculator for resampling a (vector-valued)
|
||||
// input time series with a uniform sample rate. The output
|
||||
// stream's sampling rate is specified by target_sample_rate in the
|
||||
// RationalFactorResampleCalculatorOptions. The output time series may have
|
||||
// a varying number of samples per frame.
|
||||
class RationalFactorResampleCalculator : public CalculatorBase {
|
||||
public:
|
||||
struct TestAccess;
|
||||
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).Set<Matrix>(
|
||||
// Single input stream with TimeSeriesHeader.
|
||||
);
|
||||
cc->Outputs().Index(0).Set<Matrix>(
|
||||
// Resampled stream with TimeSeriesHeader.
|
||||
);
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
// Returns FAIL if the input stream header is invalid or if the
|
||||
// resampler cannot be initialized.
|
||||
::mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
// Resamples a packet of TimeSeries data. Returns FAIL if the
|
||||
// resampler state becomes inconsistent.
|
||||
::mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
// Flushes any remaining state. Returns FAIL if the resampler state
|
||||
// becomes inconsistent.
|
||||
::mediapipe::Status Close(CalculatorContext* cc) override;
|
||||
|
||||
protected:
|
||||
typedef audio_dsp::Resampler<float> ResamplerType;
|
||||
|
||||
// Returns a Resampler<float> implementation specified by the
|
||||
// RationalFactorResampleCalculatorOptions proto. Returns null if the options
|
||||
// specify an invalid resampler.
|
||||
static std::unique_ptr<ResamplerType> ResamplerFromOptions(
|
||||
const double source_sample_rate, const double target_sample_rate,
|
||||
const RationalFactorResampleCalculatorOptions& options);
|
||||
|
||||
// Does Timestamp bookkeeping and resampling common to Process() and
|
||||
// Close(). Returns FAIL if the resampler state becomes
|
||||
// inconsistent.
|
||||
::mediapipe::Status ProcessInternal(const Matrix& input_frame,
|
||||
bool should_flush, CalculatorContext* cc);
|
||||
|
||||
// Uses the internal resampler_ objects to actually resample each
|
||||
// row of the input TimeSeries. Returns false if the resampler
|
||||
// state becomes inconsistent.
|
||||
bool Resample(const Matrix& input_frame, Matrix* output_frame,
|
||||
bool should_flush);
|
||||
|
||||
double source_sample_rate_;
|
||||
double target_sample_rate_;
|
||||
int64 cumulative_input_samples_;
|
||||
int64 cumulative_output_samples_;
|
||||
Timestamp initial_timestamp_;
|
||||
bool check_inconsistent_timestamps_;
|
||||
int num_channels_;
|
||||
std::vector<std::unique_ptr<ResamplerType>> resampler_;
|
||||
};
|
||||
|
||||
// Test-only access to RationalFactorResampleCalculator methods.
|
||||
struct RationalFactorResampleCalculator::TestAccess {
|
||||
static std::unique_ptr<ResamplerType> ResamplerFromOptions(
|
||||
const double source_sample_rate, const double target_sample_rate,
|
||||
const RationalFactorResampleCalculatorOptions& options) {
|
||||
return RationalFactorResampleCalculator::ResamplerFromOptions(
|
||||
source_sample_rate, target_sample_rate, options);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_CALCULATORS_AUDIO_RATIONAL_FACTOR_RESAMPLE_CALCULATOR_H_
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
message RationalFactorResampleCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional RationalFactorResampleCalculatorOptions ext = 259760074;
|
||||
}
|
||||
|
||||
// target_sample_rate is the sample rate, in Hertz, of the output
|
||||
// stream. Required. Must be greater than 0.
|
||||
optional double target_sample_rate = 1;
|
||||
|
||||
// Parameters for initializing the RationalFactorResampler. See
|
||||
// RationalFactorResampler for more details.
|
||||
message ResamplerRationalFactorOptions {
|
||||
// Kernel radius in units of input samples.
|
||||
optional double radius = 1;
|
||||
// Anti-aliasing cutoff frequency in Hertz. A reasonable setting is
|
||||
// 0.45 * min(input_sample_rate, output_sample_rate).
|
||||
optional double cutoff = 2;
|
||||
// The Kaiser beta parameter for the kernel window.
|
||||
optional double kaiser_beta = 3 [default = 6.0];
|
||||
}
|
||||
optional ResamplerRationalFactorOptions resampler_rational_factor_options = 2;
|
||||
|
||||
// Set to false to disable checks for jitter in timestamp values. Useful with
|
||||
// live audio input.
|
||||
optional bool check_inconsistent_timestamps = 3 [default = true];
|
||||
}
|
|
@ -0,0 +1,247 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/audio/rational_factor_resample_calculator.h"
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "audio/dsp/signal_vector_util.h"
|
||||
#include "mediapipe/calculators/audio/rational_factor_resample_calculator.pb.h"
|
||||
#include "mediapipe/framework//tool/validate_type.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/util/time_series_test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
const int kInitialTimestampOffsetMilliseconds = 4;
|
||||
|
||||
class RationalFactorResampleCalculatorTest
|
||||
: public TimeSeriesCalculatorTest<RationalFactorResampleCalculatorOptions> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
calculator_name_ = "RationalFactorResampleCalculator";
|
||||
input_sample_rate_ = 4000.0;
|
||||
num_input_channels_ = 3;
|
||||
}
|
||||
|
||||
// Expects two vectors whose lengths are almost the same and whose
|
||||
// elements are equal (for indices that are present in both).
|
||||
//
|
||||
// This is useful because the resampler doesn't make precise
|
||||
// guarantees about its output size.
|
||||
void ExpectVectorMostlyFloatEq(const std::vector<float>& expected,
|
||||
const std::vector<float>& actual) {
|
||||
// Lengths should be close, but don't have to be equal.
|
||||
ASSERT_NEAR(expected.size(), actual.size(), 1);
|
||||
for (int i = 0; i < std::min(expected.size(), actual.size()); ++i) {
|
||||
EXPECT_FLOAT_EQ(expected[i], actual[i]) << " where i=" << i << ".";
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a float value with the sample, channel, and timestamp
|
||||
// separated by a few orders of magnitude, for easy parsing by
|
||||
// humans.
|
||||
double TestValue(int sample, int channel, int timestamp_in_microseconds) {
|
||||
return timestamp_in_microseconds * 100.0 + sample + channel / 10.0;
|
||||
}
|
||||
|
||||
// Caller takes ownership of the returned value.
|
||||
Matrix* NewTestFrame(int num_channels, int num_samples, int timestamp) {
|
||||
auto matrix = new Matrix(num_channels, num_samples);
|
||||
for (int c = 0; c < num_channels; ++c) {
|
||||
for (int i = 0; i < num_samples; ++i) {
|
||||
(*matrix)(c, i) = TestValue(i, c, timestamp);
|
||||
}
|
||||
}
|
||||
return matrix;
|
||||
}
|
||||
|
||||
// Initializes and runs the test graph.
|
||||
::mediapipe::Status Run(double output_sample_rate) {
|
||||
options_.set_target_sample_rate(output_sample_rate);
|
||||
InitializeGraph();
|
||||
|
||||
FillInputHeader();
|
||||
concatenated_input_samples_.resize(num_input_channels_, 0);
|
||||
num_input_samples_ = 0;
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
int packet_size = (i + 1) * 10;
|
||||
int timestamp = kInitialTimestampOffsetMilliseconds +
|
||||
num_input_samples_ / input_sample_rate_ *
|
||||
Timestamp::kTimestampUnitsPerSecond;
|
||||
Matrix* data_frame =
|
||||
NewTestFrame(num_input_channels_, packet_size, timestamp);
|
||||
|
||||
// Keep a reference copy of the input.
|
||||
//
|
||||
// conservativeResize() is needed here to preserve the existing
|
||||
// data. Eigen's resize() resizes without preserving data.
|
||||
concatenated_input_samples_.conservativeResize(
|
||||
num_input_channels_, num_input_samples_ + packet_size);
|
||||
concatenated_input_samples_.rightCols(packet_size) = *data_frame;
|
||||
num_input_samples_ += packet_size;
|
||||
|
||||
AppendInputPacket(data_frame, timestamp);
|
||||
}
|
||||
|
||||
return RunGraph();
|
||||
}
|
||||
|
||||
void CheckOutputLength(double output_sample_rate) {
|
||||
double factor = output_sample_rate / input_sample_rate_;
|
||||
|
||||
int num_output_samples = 0;
|
||||
for (const Packet& packet : output().packets) {
|
||||
num_output_samples += packet.Get<Matrix>().cols();
|
||||
}
|
||||
|
||||
// The exact number of expected samples may vary based on the implementation
|
||||
// of the resampler since the exact value is not an integer.
|
||||
// TODO: Reduce this offset to + 1 once cl/185829520 is submitted.
|
||||
const double expected_num_output_samples = num_input_samples_ * factor;
|
||||
EXPECT_LE(ceil(expected_num_output_samples), num_output_samples);
|
||||
EXPECT_GE(ceil(expected_num_output_samples) + 11, num_output_samples);
|
||||
}
|
||||
|
||||
// Checks that output timestamps are consistent with the
|
||||
// output_sample_rate and output packet sizes.
|
||||
void CheckOutputPacketTimestamps(double output_sample_rate) {
|
||||
int num_output_samples = 0;
|
||||
for (const Packet& packet : output().packets) {
|
||||
const int expected_timestamp = kInitialTimestampOffsetMilliseconds +
|
||||
num_output_samples / output_sample_rate *
|
||||
Timestamp::kTimestampUnitsPerSecond;
|
||||
EXPECT_NEAR(expected_timestamp, packet.Timestamp().Value(), 1);
|
||||
num_output_samples += packet.Get<Matrix>().cols();
|
||||
}
|
||||
}
|
||||
|
||||
// Checks that output values from the calculator (which resamples
|
||||
// packet-by-packet) are consistent with resampling the entire
|
||||
// signal at once.
|
||||
void CheckOutputValues(double output_sample_rate) {
|
||||
for (int i = 0; i < num_input_channels_; ++i) {
|
||||
auto verification_resampler =
|
||||
RationalFactorResampleCalculator::TestAccess::ResamplerFromOptions(
|
||||
input_sample_rate_, output_sample_rate, options_);
|
||||
|
||||
std::vector<float> input_data;
|
||||
for (int j = 0; j < num_input_samples_; ++j) {
|
||||
input_data.push_back(concatenated_input_samples_(i, j));
|
||||
}
|
||||
std::vector<float> expected_resampled_data;
|
||||
std::vector<float> temp;
|
||||
verification_resampler->ProcessSamples(input_data, &temp);
|
||||
audio_dsp::VectorAppend(&expected_resampled_data, temp);
|
||||
verification_resampler->Flush(&temp);
|
||||
audio_dsp::VectorAppend(&expected_resampled_data, temp);
|
||||
std::vector<float> actual_resampled_data;
|
||||
for (const Packet& packet : output().packets) {
|
||||
Matrix output_frame_row = packet.Get<Matrix>().row(i);
|
||||
actual_resampled_data.insert(
|
||||
actual_resampled_data.end(), &output_frame_row(0),
|
||||
&output_frame_row(0) + output_frame_row.cols());
|
||||
}
|
||||
|
||||
ExpectVectorMostlyFloatEq(expected_resampled_data, actual_resampled_data);
|
||||
}
|
||||
}
|
||||
|
||||
void CheckOutputHeaders(double output_sample_rate) {
|
||||
const TimeSeriesHeader& output_header =
|
||||
output().header.Get<TimeSeriesHeader>();
|
||||
TimeSeriesHeader expected_header;
|
||||
expected_header.set_sample_rate(output_sample_rate);
|
||||
expected_header.set_num_channels(num_input_channels_);
|
||||
EXPECT_THAT(output_header, mediapipe::EqualsProto(expected_header));
|
||||
}
|
||||
|
||||
void CheckOutput(double output_sample_rate) {
|
||||
CheckOutputLength(output_sample_rate);
|
||||
CheckOutputPacketTimestamps(output_sample_rate);
|
||||
CheckOutputValues(output_sample_rate);
|
||||
CheckOutputHeaders(output_sample_rate);
|
||||
}
|
||||
|
||||
void CheckOutputUnchanged() {
|
||||
for (int i = 0; i < num_input_channels_; ++i) {
|
||||
std::vector<float> expected_resampled_data;
|
||||
for (int j = 0; j < num_input_samples_; ++j) {
|
||||
expected_resampled_data.push_back(concatenated_input_samples_(i, j));
|
||||
}
|
||||
std::vector<float> actual_resampled_data;
|
||||
for (const Packet& packet : output().packets) {
|
||||
Matrix output_frame_row = packet.Get<Matrix>().row(i);
|
||||
actual_resampled_data.insert(
|
||||
actual_resampled_data.end(), &output_frame_row(0),
|
||||
&output_frame_row(0) + output_frame_row.cols());
|
||||
}
|
||||
ExpectVectorMostlyFloatEq(expected_resampled_data, actual_resampled_data);
|
||||
}
|
||||
}
|
||||
|
||||
int num_input_samples_;
|
||||
Matrix concatenated_input_samples_;
|
||||
};
|
||||
|
||||
TEST_F(RationalFactorResampleCalculatorTest, Upsample) {
|
||||
const double kUpsampleRate = input_sample_rate_ * 1.9;
|
||||
MEDIAPIPE_ASSERT_OK(Run(kUpsampleRate));
|
||||
CheckOutput(kUpsampleRate);
|
||||
}
|
||||
|
||||
TEST_F(RationalFactorResampleCalculatorTest, Downsample) {
|
||||
const double kDownsampleRate = input_sample_rate_ / 1.9;
|
||||
MEDIAPIPE_ASSERT_OK(Run(kDownsampleRate));
|
||||
CheckOutput(kDownsampleRate);
|
||||
}
|
||||
|
||||
TEST_F(RationalFactorResampleCalculatorTest, UsesRationalFactorResampler) {
|
||||
const double kUpsampleRate = input_sample_rate_ * 2;
|
||||
MEDIAPIPE_ASSERT_OK(Run(kUpsampleRate));
|
||||
CheckOutput(kUpsampleRate);
|
||||
}
|
||||
|
||||
TEST_F(RationalFactorResampleCalculatorTest, PassthroughIfSampleRateUnchanged) {
|
||||
const double kUpsampleRate = input_sample_rate_;
|
||||
MEDIAPIPE_ASSERT_OK(Run(kUpsampleRate));
|
||||
CheckOutputUnchanged();
|
||||
}
|
||||
|
||||
TEST_F(RationalFactorResampleCalculatorTest, FailsOnBadTargetRate) {
|
||||
ASSERT_FALSE(Run(-999.9).ok()); // Invalid output sample rate.
|
||||
}
|
||||
|
||||
TEST_F(RationalFactorResampleCalculatorTest, DoesNotDieOnEmptyInput) {
|
||||
options_.set_target_sample_rate(input_sample_rate_);
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
MEDIAPIPE_ASSERT_OK(RunGraph());
|
||||
EXPECT_TRUE(output().packets.empty());
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace mediapipe
|
425
mediapipe/calculators/audio/spectrogram_calculator.cc
Normal file
425
mediapipe/calculators/audio/spectrogram_calculator.cc
Normal file
|
@ -0,0 +1,425 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Defines SpectrogramCalculator.
|
||||
#include <math.h>
|
||||
|
||||
#include <complex>
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "audio/dsp/spectrogram/spectrogram.h"
|
||||
#include "audio/dsp/window_functions.h"
|
||||
#include "mediapipe/calculators/audio/spectrogram_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/core_proto_inc.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/source_location.h"
|
||||
#include "mediapipe/framework/port/status_builder.h"
|
||||
#include "mediapipe/util/time_series_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// MediaPipe Calculator for computing the "spectrogram" (short-time Fourier
|
||||
// transform squared-magnitude, by default) of a multichannel input
|
||||
// time series, including optionally overlapping frames. Options are
|
||||
// specified in SpectrogramCalculatorOptions proto (where names are chosen
|
||||
// to mirror TimeSeriesFramerCalculator):
|
||||
//
|
||||
// Result is a MatrixData record (for single channel input and when the
|
||||
// allow_multichannel_input flag is false), or a vector of MatrixData records,
|
||||
// one for each channel (when the allow_multichannel_input flag is set). The
|
||||
// rows of each spectrogram matrix correspond to the n_fft/2+1 unique complex
|
||||
// values, or squared/linear/dB magnitudes, depending on the output_type option.
|
||||
// Each input packet will result in zero or one output packets, each containing
|
||||
// one Matrix for each channel of the input, where each Matrix has one or more
|
||||
// columns of spectral values, one for each complete frame of input samples. If
|
||||
// the input packet contains too few samples to trigger a new output frame, no
|
||||
// output packet is generated (since zero-length packets are not legal since
|
||||
// they would result in timestamps that were equal, not strictly increasing).
|
||||
//
|
||||
// Output packet Timestamps are set to the beginning of each frame. This is to
|
||||
// allow calculators downstream from SpectrogramCalculator to have aligned
|
||||
// Timestamps regardless of a packet's signal length.
|
||||
//
|
||||
// Both frame_duration_seconds and frame_overlap_seconds will be
|
||||
// rounded to the nearest integer number of samples. Conseqently, all output
|
||||
// frames will be based on the same number of input samples, and each
|
||||
// analysis frame will advance from its predecessor by the same time step.
|
||||
// Calculator that converts a stream of audio sample Matrices into a stream
// of spectrogram frames (one Matrix column of spectral values per completed
// analysis frame).
class SpectrogramCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<Matrix>(
        // Input stream with TimeSeriesHeader.
    );

    // The output packet type depends on two option bits, so the options must
    // be parsed already at contract time.
    SpectrogramCalculatorOptions spectrogram_options;
    time_series_util::FillOptionsExtensionOrDie(cc->Options(),
                                                &spectrogram_options);

    if (!spectrogram_options.allow_multichannel_input()) {
      if (spectrogram_options.output_type() ==
          SpectrogramCalculatorOptions::COMPLEX) {
        cc->Outputs().Index(0).Set<Eigen::MatrixXcf>(
            // Complex spectrogram frames with TimeSeriesHeader.
        );
      } else {
        cc->Outputs().Index(0).Set<Matrix>(
            // Spectrogram frames with TimeSeriesHeader.
        );
      }
    } else {
      if (spectrogram_options.output_type() ==
          SpectrogramCalculatorOptions::COMPLEX) {
        cc->Outputs().Index(0).Set<std::vector<Eigen::MatrixXcf>>(
            // Complex spectrogram frames with MultiStreamTimeSeriesHeader.
        );
      } else {
        cc->Outputs().Index(0).Set<std::vector<Matrix>>(
            // Spectrogram frames with MultiStreamTimeSeriesHeader.
        );
      }
    }
    return ::mediapipe::OkStatus();
  }

  // Returns FAIL if the input stream header is invalid.
  ::mediapipe::Status Open(CalculatorContext* cc) override;

  // Outputs at most one packet consisting of a single Matrix with one or
  // more columns containing the spectral values from as many input frames
  // as are completed by the input samples.  Always returns OK.
  ::mediapipe::Status Process(CalculatorContext* cc) override;

  // Performs zero-padding and processing of any remaining samples
  // if pad_final_packet is set.
  // Returns OK.
  ::mediapipe::Status Close(CalculatorContext* cc) override;

 private:
  Timestamp CurrentOutputTimestamp() {
    // The output timestamp marks the *beginning* of the next frame to be
    // emitted (see the file-level comment): the initial input timestamp
    // advanced by one frame step per frame already emitted.
    return initial_input_timestamp_ +
           round(cumulative_completed_frames_ * frame_step_samples() *
                 Timestamp::kTimestampUnitsPerSecond / input_sample_rate_);
  }

  // Number of input samples each successive analysis frame advances by.
  int frame_step_samples() const {
    return frame_duration_samples_ - frame_overlap_samples_;
  }

  // Take the next set of input samples, already translated into a
  // vector<float> and pass them to the spectrogram object.
  // Convert the output of the spectrogram object into a Matrix (or an
  // Eigen::MatrixXcf if complex-valued output is requested) and pass to
  // MediaPipe output.
  ::mediapipe::Status ProcessVector(const Matrix& input_stream,
                                    CalculatorContext* cc);

  // Templated function to process either real- or complex-output spectrogram.
  template <class OutputMatrixType>
  ::mediapipe::Status ProcessVectorToOutput(
      const Matrix& input_stream,
      const OutputMatrixType postprocess_output_fn(const OutputMatrixType&),
      CalculatorContext* cc);

  double input_sample_rate_;
  bool pad_final_packet_;
  int frame_duration_samples_;
  int frame_overlap_samples_;
  // How many samples we've been passed, used for checking input time stamps.
  int64 cumulative_input_samples_;
  // How many frames we've emitted, used for calculating output time stamps.
  int64 cumulative_completed_frames_;
  Timestamp initial_input_timestamp_;
  int num_input_channels_;
  // How many frequency bins we emit (=N_FFT/2 + 1).
  int num_output_channels_;
  // Which output type?
  int output_type_;
  // Output type: mono or multichannel.
  bool allow_multichannel_input_;
  // Vector of Spectrogram objects, one for each channel.
  std::vector<std::unique_ptr<audio_dsp::Spectrogram>> spectrogram_generators_;
  // Fixed scale factor applied to output values (regardless of type).
  double output_scale_;

  static const float kLnPowerToDb;
};
|
||||
REGISTER_CALCULATOR(SpectrogramCalculator);

// Factor to convert ln(magnitude_squared) to deciBels = 10.0/ln(10.0).
const float SpectrogramCalculator::kLnPowerToDb = 4.342944819032518;
|
||||
|
||||
// Parses options, validates framing parameters and the input header, builds
// one Spectrogram generator per input channel, and publishes the output
// stream header.  Returns an InvalidArgument status for bad options or a
// multichannel input without allow_multichannel_input.
::mediapipe::Status SpectrogramCalculator::Open(CalculatorContext* cc) {
  SpectrogramCalculatorOptions spectrogram_options;
  time_series_util::FillOptionsExtensionOrDie(cc->Options(),
                                              &spectrogram_options);

  // Validate framing options.  These error builders must be *returned*;
  // previously they were constructed and discarded, which silently ignored
  // the failures.
  if (spectrogram_options.frame_duration_seconds() <= 0.0) {
    // Fixed: this message previously printed frame_overlap_seconds().
    return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
           << "Invalid or missing frame_duration_seconds.\n"
              "frame_duration_seconds: "
           << spectrogram_options.frame_duration_seconds();
  }
  if (spectrogram_options.frame_overlap_seconds() >=
      spectrogram_options.frame_duration_seconds()) {
    return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
           << "Invalid frame_overlap_seconds.\nframe_overlap_seconds: "
           << spectrogram_options.frame_overlap_seconds()
           << "\nframe_duration_seconds: "
           << spectrogram_options.frame_duration_seconds();
  }
  if (spectrogram_options.frame_overlap_seconds() < 0.0) {
    return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
           << "Frame_overlap_seconds is < 0.0.\nframe_overlap_seconds: "
           << spectrogram_options.frame_overlap_seconds();
  }

  TimeSeriesHeader input_header;
  RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
      cc->Inputs().Index(0).Header(), &input_header));

  input_sample_rate_ = input_header.sample_rate();
  num_input_channels_ = input_header.num_channels();

  if (!spectrogram_options.allow_multichannel_input() &&
      num_input_channels_ != 1) {
    return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
           << "The current setting only supports single-channel input. "
              "Please set allow_multichannel_input.\n";
  }

  // Frame geometry is quantized to whole samples; see the file-level comment.
  frame_duration_samples_ =
      round(spectrogram_options.frame_duration_seconds() * input_sample_rate_);
  frame_overlap_samples_ =
      round(spectrogram_options.frame_overlap_seconds() * input_sample_rate_);

  pad_final_packet_ = spectrogram_options.pad_final_packet();
  output_type_ = spectrogram_options.output_type();
  allow_multichannel_input_ = spectrogram_options.allow_multichannel_input();

  output_scale_ = spectrogram_options.output_scale();

  std::vector<double> window;
  switch (spectrogram_options.window_type()) {
    case SpectrogramCalculatorOptions::HANN:
      audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_,
                                                 &window);
      break;
    case SpectrogramCalculatorOptions::HAMMING:
      audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples_,
                                                    &window);
      break;
  }

  // Propagate settings down to the actual Spectrogram objects, one per
  // channel so each channel keeps its own overlap state across packets.
  spectrogram_generators_.clear();
  for (int i = 0; i < num_input_channels_; i++) {
    spectrogram_generators_.push_back(
        std::unique_ptr<audio_dsp::Spectrogram>(new audio_dsp::Spectrogram()));
    spectrogram_generators_[i]->Initialize(window, frame_step_samples());
  }

  num_output_channels_ =
      spectrogram_generators_[0]->output_frequency_channels();
  std::unique_ptr<TimeSeriesHeader> output_header(
      new TimeSeriesHeader(input_header));
  // Store the actual sample rate of the input audio in the TimeSeriesHeader
  // so that subsequent calculators can figure out the frequency scale of
  // our output.
  output_header->set_audio_sample_rate(input_sample_rate_);
  // Setup rest of output header.
  output_header->set_num_channels(num_output_channels_);
  output_header->set_sample_rate(input_sample_rate_ / frame_step_samples());
  // Although we usually generate one output packet for each input
  // packet, this might not be true for input packets whose size is smaller
  // than the analysis window length. So we clear output_header.packet_rate
  // because we can't guarantee a constant packet rate. Similarly, the number
  // of output frames per packet depends on the input packet, so we also clear
  // output_header.num_samples.
  output_header->clear_packet_rate();
  output_header->clear_num_samples();
  if (!spectrogram_options.allow_multichannel_input()) {
    cc->Outputs().Index(0).SetHeader(Adopt(output_header.release()));
  } else {
    std::unique_ptr<MultiStreamTimeSeriesHeader> multichannel_output_header(
        new MultiStreamTimeSeriesHeader());
    *multichannel_output_header->mutable_time_series_header() = *output_header;
    multichannel_output_header->set_num_streams(num_input_channels_);
    cc->Outputs().Index(0).SetHeader(
        Adopt(multichannel_output_header.release()));
  }
  // Reset all accumulators.  cumulative_input_samples_ was previously never
  // initialized, leaving Close()'s padding decision dependent on garbage.
  cumulative_input_samples_ = 0;
  cumulative_completed_frames_ = 0;
  initial_input_timestamp_ = Timestamp::Unstarted();
  return ::mediapipe::OkStatus();
}
|
||||
|
||||
// Consumes one packet of input samples.  Records the first packet's
// timestamp as the stream origin, validates the channel count, accumulates
// the sample count (used by Close() for final padding), and delegates the
// spectrogram work to ProcessVector().
::mediapipe::Status SpectrogramCalculator::Process(CalculatorContext* cc) {
  if (initial_input_timestamp_ == Timestamp::Unstarted()) {
    initial_input_timestamp_ = cc->InputTimestamp();
  }

  const Matrix& input_stream = cc->Inputs().Index(0).Get<Matrix>();
  if (input_stream.rows() != num_input_channels_) {
    // The error must be returned; previously the builder was constructed and
    // discarded, so mismatched inputs were silently processed.  Also fixed
    // the missing space before "channels".
    return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
           << "Number of input channels do not correspond to the number of "
           << "rows in the input matrix: " << num_input_channels_
           << " channels vs " << input_stream.rows() << " rows";
  }

  cumulative_input_samples_ += input_stream.cols();

  return ProcessVector(input_stream, cc);
}
|
||||
|
||||
template <class OutputMatrixType>
|
||||
::mediapipe::Status SpectrogramCalculator::ProcessVectorToOutput(
|
||||
const Matrix& input_stream,
|
||||
const OutputMatrixType postprocess_output_fn(const OutputMatrixType&),
|
||||
CalculatorContext* cc) {
|
||||
std::unique_ptr<std::vector<OutputMatrixType>> spectrogram_matrices(
|
||||
new std::vector<OutputMatrixType>());
|
||||
std::vector<std::vector<typename OutputMatrixType::Scalar>> output_vectors;
|
||||
|
||||
// Compute a spectrogram for each channel.
|
||||
int num_output_time_frames;
|
||||
for (int channel = 0; channel < input_stream.rows(); ++channel) {
|
||||
output_vectors.clear();
|
||||
|
||||
// Copy one row (channel) of the input matrix into the std::vector.
|
||||
std::vector<float> input_vector(input_stream.cols());
|
||||
Eigen::Map<Matrix>(&input_vector[0], 1, input_vector.size()) =
|
||||
input_stream.row(channel);
|
||||
|
||||
if (!spectrogram_generators_[channel]->ComputeSpectrogram(
|
||||
input_vector, &output_vectors)) {
|
||||
return ::mediapipe::Status(mediapipe::StatusCode::kInternal,
|
||||
"Spectrogram returned failure");
|
||||
}
|
||||
if (channel == 0) {
|
||||
// Record the number of time frames we expect from each channel.
|
||||
num_output_time_frames = output_vectors.size();
|
||||
} else {
|
||||
RET_CHECK_EQ(output_vectors.size(), num_output_time_frames)
|
||||
<< "Inconsistent spectrogram time frames for channel " << channel;
|
||||
}
|
||||
// Skip remaining processing if there are too few input samples to trigger
|
||||
// any output frames.
|
||||
if (!output_vectors.empty()) {
|
||||
// Translate the returned values into a matrix of output frames.
|
||||
OutputMatrixType output_frames(num_output_channels_,
|
||||
output_vectors.size());
|
||||
for (int frame = 0; frame < output_vectors.size(); ++frame) {
|
||||
Eigen::Map<const OutputMatrixType> frame_map(
|
||||
&output_vectors[frame][0], output_vectors[frame].size(), 1);
|
||||
// The underlying dsp object returns squared magnitudes; here
|
||||
// we optionally translate to linear magnitude or dB.
|
||||
output_frames.col(frame) =
|
||||
output_scale_ * postprocess_output_fn(frame_map);
|
||||
}
|
||||
spectrogram_matrices->push_back(output_frames);
|
||||
}
|
||||
}
|
||||
// If the input is very short, there may not be enough accumulated,
|
||||
// unprocessed samples to cause any new frames to be generated by
|
||||
// the spectrogram object. If so, we don't want to emit
|
||||
// a packet at all.
|
||||
if (!spectrogram_matrices->empty()) {
|
||||
RET_CHECK_EQ(spectrogram_matrices->size(), input_stream.rows())
|
||||
<< "Inconsistent number of spectrogram channels.";
|
||||
if (allow_multichannel_input_) {
|
||||
cc->Outputs().Index(0).Add(spectrogram_matrices.release(),
|
||||
CurrentOutputTimestamp());
|
||||
} else {
|
||||
cc->Outputs().Index(0).Add(
|
||||
new OutputMatrixType(spectrogram_matrices->at(0)),
|
||||
CurrentOutputTimestamp());
|
||||
}
|
||||
cumulative_completed_frames_ += output_vectors.size();
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
// Dispatches on output_type_ to the appropriately-typed instantiation of
// ProcessVectorToOutput, supplying the per-frame postprocessing function.
// The unary `+` on each capture-less lambda forces conversion to the plain
// function pointer that ProcessVectorToOutput's parameter type requires.
::mediapipe::Status SpectrogramCalculator::ProcessVector(
    const Matrix& input_stream, CalculatorContext* cc) {
  switch (output_type_) {
    // These blocks deliberately ignore clang-format to preserve the
    // "silhouette" of the different cases.
    // clang-format off
    case SpectrogramCalculatorOptions::COMPLEX: {
      // Raw complex FFT output, passed through unchanged.
      return ProcessVectorToOutput(
          input_stream,
          +[](const Eigen::MatrixXcf& col) -> const Eigen::MatrixXcf {
            return col;
          }, cc);
    }
    case SpectrogramCalculatorOptions::SQUARED_MAGNITUDE: {
      // The dsp object already produces squared magnitudes; identity.
      return ProcessVectorToOutput(
          input_stream,
          +[](const Matrix& col) -> const Matrix {
            return col;
          }, cc);
    }
    case SpectrogramCalculatorOptions::LINEAR_MAGNITUDE: {
      // sqrt(power) -> linear magnitude.
      return ProcessVectorToOutput(
          input_stream,
          +[](const Matrix& col) -> const Matrix {
            return col.array().sqrt().matrix();
          }, cc);
    }
    case SpectrogramCalculatorOptions::DECIBELS: {
      // 10*log10(power) via kLnPowerToDb * ln(power).
      return ProcessVectorToOutput(
          input_stream,
          +[](const Matrix& col) -> const Matrix {
            return kLnPowerToDb * col.array().log().matrix();
          }, cc);
    }
    // clang-format on
    default: {
      return ::mediapipe::Status(mediapipe::StatusCode::kInvalidArgument,
                                 "Unrecognized spectrogram output type.");
    }
  }
}
|
||||
|
||||
// Flushes the stream: if pad_final_packet is set and any samples were seen,
// feeds enough trailing zeros through ProcessVector() to complete the final
// frame(s).
::mediapipe::Status SpectrogramCalculator::Close(CalculatorContext* cc) {
  if (cumulative_input_samples_ > 0 && pad_final_packet_) {
    // We can flush any remaining samples by sending frame_step_samples - 1
    // zeros to the Process method, and letting it do its thing,
    // UNLESS we have fewer than one window's worth of samples, in which case
    // we pad to exactly one frame_duration_samples.
    int required_padding_samples = frame_step_samples() - 1;
    if (cumulative_input_samples_ < frame_duration_samples_) {
      required_padding_samples =
          frame_duration_samples_ - cumulative_input_samples_;
    }
    return ProcessVector(
        Matrix::Zero(num_input_channels_, required_padding_samples), cc);
  }

  return ::mediapipe::OkStatus();
}
|
||||
|
||||
} // namespace mediapipe
|
68
mediapipe/calculators/audio/spectrogram_calculator.proto
Normal file
68
mediapipe/calculators/audio/spectrogram_calculator.proto
Normal file
|
@ -0,0 +1,68 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
// Options for SpectrogramCalculator.
message SpectrogramCalculatorOptions {
  extend CalculatorOptions {
    optional SpectrogramCalculatorOptions ext = 76186688;
  }

  // Options mirror those of TimeSeriesFramerCalculator.

  // Analysis window duration in seconds. Required. Must be greater than 0.
  // (Note: the spectrogram DFT length will be the smallest power-of-2
  // sample count that can hold this duration.)
  optional double frame_duration_seconds = 1;

  // Duration of overlap between adjacent windows.
  // Hence, frame_rate = 1/(frame_duration_seconds - frame_overlap_seconds).
  // Required that 0 <= frame_overlap_seconds < frame_duration_seconds.
  optional double frame_overlap_seconds = 2 [default = 0.0];

  // Whether to pad the final packet with zeros. If true, guarantees that
  // every input sample contributes to some output frame. If set to false,
  // any partial frame at the end of the stream will be dropped.
  optional bool pad_final_packet = 3 [default = true];

  // Output value type can be squared-magnitude, linear-magnitude,
  // deciBels (dB, = 20*log10(linear_magnitude)), or std::complex.
  enum OutputType {
    SQUARED_MAGNITUDE = 0;
    LINEAR_MAGNITUDE = 1;
    DECIBELS = 2;
    COMPLEX = 3;
  }
  optional OutputType output_type = 4 [default = SQUARED_MAGNITUDE];

  // If set to true then the output will be a vector of spectrograms, one for
  // each channel and the stream will have a MultiStreamTimeSeriesHeader.
  optional bool allow_multichannel_input = 5 [default = false];

  // Which window to use when computing the FFT.
  enum WindowType {
    HANN = 0;
    HAMMING = 1;
  }
  optional WindowType window_type = 6 [default = HANN];

  // Support a fixed multiplicative scaling of the output. This is applied
  // uniformly regardless of output type (i.e., even dBs are multiplied, not
  // offset).
  optional double output_scale = 7 [default = 1.0];
}
|
895
mediapipe/calculators/audio/spectrogram_calculator_test.cc
Normal file
895
mediapipe/calculators/audio/spectrogram_calculator_test.cc
Normal file
|
@ -0,0 +1,895 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "audio/dsp/number_util.h"
|
||||
#include "mediapipe/calculators/audio/spectrogram_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/benchmark.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/util/time_series_test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
const int kInitialTimestampOffsetMicroseconds = 4;
|
||||
|
||||
// Test fixture for SpectrogramCalculator: helpers build deterministic audio
// input packets (constant, sinusoid, impulse, multichannel) and verify the
// output stream's header, timestamps, and peak frequency bins.
class SpectrogramCalculatorTest
    : public TimeSeriesCalculatorTest<SpectrogramCalculatorOptions> {
 protected:
  void SetUp() override {
    calculator_name_ = "SpectrogramCalculator";
    input_sample_rate_ = 4000.0;
    num_input_channels_ = 1;
  }

  // Initializes and runs the test graph.
  ::mediapipe::Status Run() {
    // Now that options are set, we can set up some internal constants.
    frame_duration_samples_ =
        round(options_.frame_duration_seconds() * input_sample_rate_);
    frame_step_samples_ =
        frame_duration_samples_ -
        round(options_.frame_overlap_seconds() * input_sample_rate_);
    // The magnitude of the 0th FFT bin (DC) should be sum(input.*window);
    // for an input identically 1.0, this is just sum(window). The average
    // value of our Hann window is 0.5, hence this is the expected squared-
    // magnitude output value in the DC bin for constant input of 1.0.
    expected_dc_squared_magnitude_ =
        pow((static_cast<float>(frame_duration_samples_) * 0.5), 2.0);

    return RunGraph();
  }

  // Creates test multichannel input with specified packet sizes and containing
  // a constant-frequency sinusoid that maintains phase between adjacent
  // packets.
  void SetupCosineInputPackets(const std::vector<int>& packet_sizes_samples,
                               float cosine_frequency_hz) {
    int total_num_input_samples = 0;
    for (int packet_size_samples : packet_sizes_samples) {
      // Each packet starts where the previous one ended, so phase is
      // continuous across packet boundaries.
      double packet_start_time_seconds =
          kInitialTimestampOffsetMicroseconds * 1e-6 +
          total_num_input_samples / input_sample_rate_;
      double packet_end_time_seconds =
          packet_start_time_seconds + packet_size_samples / input_sample_rate_;
      double angular_freq = 2 * M_PI * cosine_frequency_hz;
      Matrix* packet_data =
          new Matrix(num_input_channels_, packet_size_samples);
      // Use Eigen's vectorized cos() function to fill the vector with a
      // sinusoid of appropriate frequency & phase.
      for (int i = 0; i < num_input_channels_; i++) {
        packet_data->row(i) =
            Eigen::ArrayXf::LinSpaced(packet_size_samples,
                                      packet_start_time_seconds * angular_freq,
                                      packet_end_time_seconds * angular_freq)
                .cos()
                .transpose();
      }
      int64 input_timestamp = round(packet_start_time_seconds *
                                    Timestamp::kTimestampUnitsPerSecond);
      AppendInputPacket(packet_data, input_timestamp);
      total_num_input_samples += packet_size_samples;
    }
  }

  // Setup a sequence of input packets of specified sizes, each filled
  // with samples of 1.0.
  void SetupConstantInputPackets(const std::vector<int>& packet_sizes_samples) {
    // A 0 Hz cosine is identically 1.0 for all samples.
    SetupCosineInputPackets(packet_sizes_samples, 0.0);
  }

  // Setup a sequence of input packets of specified sizes, each containing a
  // single sample of 1.0 at a specified offset.
  void SetupImpulseInputPackets(
      const std::vector<int>& packet_sizes_samples,
      const std::vector<int>& impulse_offsets_samples) {
    int total_num_input_samples = 0;
    for (int i = 0; i < packet_sizes_samples.size(); ++i) {
      double packet_start_time_seconds =
          kInitialTimestampOffsetMicroseconds * 1e-6 +
          total_num_input_samples / input_sample_rate_;
      int64 input_timestamp = round(packet_start_time_seconds *
                                    Timestamp::kTimestampUnitsPerSecond);
      std::unique_ptr<Matrix> impulse(
          new Matrix(Matrix::Zero(1, packet_sizes_samples[i])));
      (*impulse)(0, impulse_offsets_samples[i]) = 1.0;
      AppendInputPacket(impulse.release(), input_timestamp);
      total_num_input_samples += packet_sizes_samples[i];
    }
  }

  // Creates test multichannel input with specified packet sizes and containing
  // constant input packets for the even channels and constant-frequency
  // sinusoid that maintains phase between adjacent packets for the odd
  // channels.
  void SetupMultichannelInputPackets(
      const std::vector<int>& packet_sizes_samples, float cosine_frequency_hz) {
    int total_num_input_samples = 0;
    for (int packet_size_samples : packet_sizes_samples) {
      double packet_start_time_seconds =
          kInitialTimestampOffsetMicroseconds * 1e-6 +
          total_num_input_samples / input_sample_rate_;
      double packet_end_time_seconds =
          packet_start_time_seconds + packet_size_samples / input_sample_rate_;
      double angular_freq;
      Matrix* packet_data =
          new Matrix(num_input_channels_, packet_size_samples);
      // Use Eigen's vectorized cos() function to fill the vector with a
      // sinusoid of appropriate frequency & phase.
      for (int i = 0; i < num_input_channels_; i++) {
        // Even channels get DC (0 Hz); odd channels get the test frequency.
        if (i % 2 == 0) {
          angular_freq = 0;
        } else {
          angular_freq = 2 * M_PI * cosine_frequency_hz;
        }
        packet_data->row(i) =
            Eigen::ArrayXf::LinSpaced(packet_size_samples,
                                      packet_start_time_seconds * angular_freq,
                                      packet_end_time_seconds * angular_freq)
                .cos()
                .transpose();
      }
      int64 input_timestamp = round(packet_start_time_seconds *
                                    Timestamp::kTimestampUnitsPerSecond);
      AppendInputPacket(packet_data, input_timestamp);
      total_num_input_samples += packet_size_samples;
    }
  }

  // Return vector of the numbers of frames in each output packet.
  std::vector<int> OutputFramesPerPacket() {
    std::vector<int> frame_counts;
    for (const Packet& packet : output().packets) {
      const Matrix& matrix = packet.Get<Matrix>();
      frame_counts.push_back(matrix.cols());
    }
    return frame_counts;
  }

  // Checks output headers and Timestamps.
  void CheckOutputHeadersAndTimestamps() {
    const int fft_size = audio_dsp::NextPowerOfTwo(frame_duration_samples_);

    TimeSeriesHeader expected_header = input().header.Get<TimeSeriesHeader>();
    expected_header.set_num_channels(fft_size / 2 + 1);
    // The output header sample rate should depend on the output frame step.
    expected_header.set_sample_rate(input_sample_rate_ / frame_step_samples_);
    // SpectrogramCalculator stores the sample rate of the input in
    // the TimeSeriesHeader.
    expected_header.set_audio_sample_rate(input_sample_rate_);
    // We expect the output header to have num_samples and packet_rate unset.
    expected_header.clear_num_samples();
    expected_header.clear_packet_rate();
    if (!options_.allow_multichannel_input()) {
      ExpectOutputHeaderEquals(expected_header);
    } else {
      EXPECT_THAT(output()
                      .header.template Get<MultiStreamTimeSeriesHeader>()
                      .time_series_header(),
                  mediapipe::EqualsProto(expected_header));
      EXPECT_THAT(output()
                      .header.template Get<MultiStreamTimeSeriesHeader>()
                      .num_streams(),
                  num_input_channels_);
    }

    int cumulative_output_frames = 0;
    // Expected output timestamps start at the input stream's initial offset
    // and advance by one frame step's worth of input samples per emitted
    // frame (the calculator stamps the *beginning* of each frame).
    const double packet_timestamp_offset_seconds =
        kInitialTimestampOffsetMicroseconds * 1e-6;
    const double frame_step_seconds = frame_step_samples_ / input_sample_rate_;

    Timestamp initial_timestamp = Timestamp::Unstarted();

    for (const Packet& packet : output().packets) {
      // This is the timestamp we expect based on how the spectrogram should
      // behave (advancing by one step's worth of input samples each frame).
      const double expected_timestamp_seconds =
          packet_timestamp_offset_seconds +
          cumulative_output_frames * frame_step_seconds;
      const int64 expected_timestamp_ticks =
          expected_timestamp_seconds * Timestamp::kTimestampUnitsPerSecond;
      EXPECT_EQ(expected_timestamp_ticks, packet.Timestamp().Value());
      // Accept the timestamp of the first packet as the baseline for checking
      // the remainder.
      if (initial_timestamp == Timestamp::Unstarted()) {
        initial_timestamp = packet.Timestamp();
      }
      // Also check that the timestamp is consistent with the sample_rate
      // in the output stream's TimeSeriesHeader.
      EXPECT_TRUE(time_series_util::LogWarningIfTimestampIsInconsistent(
          packet.Timestamp(), initial_timestamp, cumulative_output_frames,
          expected_header.sample_rate()));
      if (!options_.allow_multichannel_input()) {
        if (options_.output_type() == SpectrogramCalculatorOptions::COMPLEX) {
          const Eigen::MatrixXcf& matrix = packet.Get<Eigen::MatrixXcf>();
          cumulative_output_frames += matrix.cols();
        } else {
          const Matrix& matrix = packet.Get<Matrix>();
          cumulative_output_frames += matrix.cols();
        }
      } else {
        if (options_.output_type() == SpectrogramCalculatorOptions::COMPLEX) {
          const Eigen::MatrixXcf& matrix =
              packet.Get<std::vector<Eigen::MatrixXcf>>().at(0);
          cumulative_output_frames += matrix.cols();
        } else {
          const Matrix& matrix = packet.Get<std::vector<Matrix>>().at(0);
          cumulative_output_frames += matrix.cols();
        }
      }
    }
  }

  // Verify that the bin corresponding to the specified frequency
  // is the largest one in one particular frame of a single packet.
  void CheckPeakFrequencyInPacketFrame(const Packet& packet, int frame,
                                       float frequency) {
    const int fft_size = audio_dsp::NextPowerOfTwo(frame_duration_samples_);
    const int target_bin =
        round((frequency / input_sample_rate_) * static_cast<float>(fft_size));

    const Matrix& matrix = packet.Get<Matrix>();
    // Stop here if the requested frame is not in this packet.
    ASSERT_GT(matrix.cols(), frame);

    int actual_largest_bin;
    matrix.col(frame).maxCoeff(&actual_largest_bin);
    EXPECT_EQ(actual_largest_bin, target_bin);
  }

  // Verify that the bin corresponding to the specified frequency
  // is the largest one in one particular frame of a single spectrogram Matrix.
  void CheckPeakFrequencyInMatrix(const Matrix& matrix, int frame,
                                  float frequency) {
    const int fft_size = audio_dsp::NextPowerOfTwo(frame_duration_samples_);
    const int target_bin =
        round((frequency / input_sample_rate_) * static_cast<float>(fft_size));

    // Stop here if the requested frame is not in this packet.
    ASSERT_GT(matrix.cols(), frame);

    int actual_largest_bin;
    matrix.col(frame).maxCoeff(&actual_largest_bin);
    EXPECT_EQ(actual_largest_bin, target_bin);
  }

  // Analysis window length in samples, derived from options in Run().
  int frame_duration_samples_;
  // Hop between successive frames in samples, derived from options in Run().
  int frame_step_samples_;
  // Expected DC output for a window of pure 1.0, set when window length
  // is set.
  float expected_dc_squared_magnitude_;
};
|
||||
|
||||
// With a 100-sample window, zero overlap, and no final-packet padding, the
// 500- and 200-sample input packets must yield exactly 5 and 2 full frames.
TEST_F(SpectrogramCalculatorTest, IntegerFrameDurationNoOverlap) {
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(0.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);

  InitializeGraph();
  FillInputHeader();
  const std::vector<int> packet_sizes_samples = {500, 200};
  SetupConstantInputPackets(packet_sizes_samples);

  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  const std::vector<int> expected_frames_per_packet = {5, 2};
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames_per_packet);
}
|
||||
|
||||
// Overlapping windows: 100-sample frames advancing by 40 samples.
// complete_output_frames = 1 + floor((input_samples - window_length)/step)
//   first packet:  1 + floor((500 - 100)/40) = 11
//   whole stream:  1 + floor((700 - 100)/40) = 16
// so the second packet should carry 16 - 11 = 5 frames.
TEST_F(SpectrogramCalculatorTest, IntegerFrameDurationSomeOverlap) {
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);

  InitializeGraph();
  FillInputHeader();
  const std::vector<int> packet_sizes_samples = {500, 200};
  SetupConstantInputPackets(packet_sizes_samples);

  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  const std::vector<int> expected_frames_per_packet = {11, 5};
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames_per_packet);
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, NonintegerFrameDurationAndOverlap) {
  // Fractional durations: frame_duration_samples rounds to 99, so the step
  // is 99 - 58 = 41 samples, and the first packet of 500 samples yields
  // 1 + floor((500 - 99)/41) = 10 frames.
  const std::vector<int> input_packet_sizes = {500, 200};
  const std::vector<int> expected_output_packet_sizes = {10, 5};
  options_.set_frame_duration_seconds(98.5 / input_sample_rate_);
  options_.set_frame_overlap_seconds(58.4 / input_sample_rate_);
  options_.set_pad_final_packet(false);

  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(input_packet_sizes);

  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes);
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, ShortInitialPacketNoOverlap) {
  // The first input packet is too small to produce a frame. Zero-length
  // output packets would violate timestamp monotonicity, so the calculator
  // suppresses them: only the second and third inputs produce output.
  const std::vector<int> input_packet_sizes = {90, 100, 110};
  const std::vector<int> expected_output_packet_sizes = {1, 2};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(0.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);

  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(input_packet_sizes);

  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes);
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, TrailingSamplesNoPad) {
  // With padding disabled, leftover trailing samples are simply dropped.
  const std::vector<int> input_packet_sizes = {140, 90};
  const std::vector<int> expected_output_packet_sizes = {2, 2};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);

  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(input_packet_sizes);

  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes);
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, NoTrailingSamplesWithPad) {
  // Padding is enabled, but the input divides evenly into frames, so no
  // extra padded packet appears.
  const std::vector<int> input_packet_sizes = {140, 80};
  const std::vector<int> expected_output_packet_sizes = {2, 2};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(true);

  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(input_packet_sizes);

  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes);
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, TrailingSamplesWithPad) {
  // Unlike NoTrailingSamplesWithPad and TrailingSamplesNoPad, the leftover
  // samples here are flushed as one extra frame in an extra final packet.
  const std::vector<int> input_packet_sizes = {140, 90};
  const std::vector<int> expected_output_packet_sizes = {2, 2, 1};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(true);

  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(input_packet_sizes);

  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes);
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, VeryShortInputWillPad) {
  // Even an input shorter than one frame produces a single padded frame
  // when pad_final_packet is enabled.
  const std::vector<int> input_packet_sizes = {30};
  const std::vector<int> expected_output_packet_sizes = {1};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(true);

  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(input_packet_sizes);

  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes);
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, VeryShortInputZeroOutputFramesIfNoPad) {
  // With padding disabled, an input shorter than one frame produces no
  // output packets at all.
  const std::vector<int> input_packet_sizes = {90};
  const std::vector<int> expected_output_packet_sizes = {};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);

  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(input_packet_sizes);

  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes);
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, DCSignalIsPeakBin) {
  // A constant (DC) input should concentrate its energy in bin 0 (0 Hz).
  const std::vector<int> input_packet_sizes = {140};  // Gives 2 output frames.
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);

  InitializeGraph();
  FillInputHeader();
  // Non-zero constant-valued input packets.
  SetupConstantInputPackets(input_packet_sizes);

  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  const float dc_frequency_hz = 0.0;
  CheckPeakFrequencyInPacketFrame(output().packets[0], 0, dc_frequency_hz);
  CheckPeakFrequencyInPacketFrame(output().packets[0], 1, dc_frequency_hz);
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, A440ToneIsPeakBin) {
  // A pure 440 Hz cosine should peak in the 440 Hz bin of every frame.
  const std::vector<int> input_packet_sizes = {
      460};  // 100 + 9*40 for 10 frames.
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  InitializeGraph();
  FillInputHeader();
  const float tone_frequency_hz = 440.0;
  SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz);

  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  const int num_output_frames = output().packets[0].Get<Matrix>().cols();
  for (int frame = 0; frame < num_output_frames; ++frame) {
    CheckPeakFrequencyInPacketFrame(output().packets[0], frame,
                                    tone_frequency_hz);
  }
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, SquaredMagnitudeOutputLooksRight) {
  // With SQUARED_MAGNITUDE output, a DC input's bin-0 value must equal the
  // squared DC magnitude precomputed by the fixture.
  const std::vector<int> input_packet_sizes = {140};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_output_type(SpectrogramCalculatorOptions::SQUARED_MAGNITUDE);

  InitializeGraph();
  FillInputHeader();
  // Non-zero constant-valued (DC) input packets.
  SetupConstantInputPackets(input_packet_sizes);

  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_FLOAT_EQ(output().packets[0].Get<Matrix>()(0, 0),
                  expected_dc_squared_magnitude_);
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, DefaultOutputIsSquaredMagnitude) {
  // output_type is deliberately left unset; the default must behave the
  // same as SQUARED_MAGNITUDE.
  const std::vector<int> input_packet_sizes = {140};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);

  InitializeGraph();
  FillInputHeader();
  // Non-zero constant-valued (DC) input packets.
  SetupConstantInputPackets(input_packet_sizes);

  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_FLOAT_EQ(output().packets[0].Get<Matrix>()(0, 0),
                  expected_dc_squared_magnitude_);
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, LinearMagnitudeOutputLooksRight) {
  // LINEAR_MAGNITUDE output is the square root of the squared magnitude.
  const std::vector<int> input_packet_sizes = {140};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_output_type(SpectrogramCalculatorOptions::LINEAR_MAGNITUDE);

  InitializeGraph();
  FillInputHeader();
  // Non-zero constant-valued (DC) input packets.
  SetupConstantInputPackets(input_packet_sizes);

  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_FLOAT_EQ(output().packets[0].Get<Matrix>()(0, 0),
                  std::sqrt(expected_dc_squared_magnitude_));
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, DbMagnitudeOutputLooksRight) {
  // DECIBELS output is 10*log10 of the squared magnitude.
  const std::vector<int> input_packet_sizes = {140};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_output_type(SpectrogramCalculatorOptions::DECIBELS);

  InitializeGraph();
  FillInputHeader();
  // Non-zero constant-valued (DC) input packets.
  SetupConstantInputPackets(input_packet_sizes);

  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_FLOAT_EQ(output().packets[0].Get<Matrix>()(0, 0),
                  10.0 * std::log10(expected_dc_squared_magnitude_));
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, OutputScalingLooksRight) {
  // output_scale should multiply the (decibel) output values directly.
  const double output_scale = 2.5;
  const std::vector<int> input_packet_sizes = {140};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_output_type(SpectrogramCalculatorOptions::DECIBELS);
  options_.set_output_scale(output_scale);

  InitializeGraph();
  FillInputHeader();
  // Non-zero constant-valued (DC) input packets.
  SetupConstantInputPackets(input_packet_sizes);

  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_FLOAT_EQ(
      output().packets[0].Get<Matrix>()(0, 0),
      output_scale * 10.0 * std::log10(expected_dc_squared_magnitude_));
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, ComplexOutputLooksRight) {
  // COMPLEX output: the squared norm of the DC bin must match the expected
  // squared DC magnitude.
  const std::vector<int> input_packet_sizes = {140};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_output_type(SpectrogramCalculatorOptions::COMPLEX);

  InitializeGraph();
  FillInputHeader();
  // Non-zero constant-valued (DC) input packets.
  SetupConstantInputPackets(input_packet_sizes);

  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_FLOAT_EQ(std::norm(output().packets[0].Get<Eigen::MatrixXcf>()(0, 0)),
                  expected_dc_squared_magnitude_);
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, ComplexOutputLooksRightForImpulses) {
  // Two one-frame packets, each containing a single impulse; the impulses
  // are offset from each other by exactly one sample.
  const int frame_size_samples = 100;
  const std::vector<int> input_packet_sizes = {frame_size_samples,
                                               frame_size_samples};
  const std::vector<int> input_packet_impulse_offsets = {49, 50};
  options_.set_frame_duration_seconds(frame_size_samples / input_sample_rate_);
  options_.set_frame_overlap_seconds(0.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  options_.set_output_type(SpectrogramCalculatorOptions::COMPLEX);

  InitializeGraph();
  FillInputHeader();
  SetupImpulseInputPackets(input_packet_sizes, input_packet_impulse_offsets);

  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  const int num_buckets =
      (audio_dsp::NextPowerOfTwo(frame_size_samples) / 2) + 1;
  const float precision = 0.01f;
  auto norm_fn = [](const std::complex<float>& cf) { return std::norm(cf); };
  const Eigen::MatrixXcf& first = output().packets[0].Get<Eigen::MatrixXcf>();
  const Eigen::MatrixXcf& second = output().packets[1].Get<Eigen::MatrixXcf>();

  // An impulse has (approximately) flat power across all frequency bins.
  EXPECT_TRUE(first.unaryExpr(norm_fn).isApproxToConstant(1.0f, precision));
  EXPECT_TRUE(second.unaryExpr(norm_fn).isApproxToConstant(1.0f, precision));

  // The second impulse is delayed by one sample relative to the first, so
  // its phase is greater; in the highest frequency bin the real part should
  // (approximately) flip sign between the two packets.
  EXPECT_LT(std::arg(first(1, 0)), std::arg(second(1, 0)));
  const float highest_bucket_real_ratio =
      first(num_buckets - 1, 0).real() / second(num_buckets - 1, 0).real();
  EXPECT_NEAR(highest_bucket_real_ratio, -1.0f, precision);
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, SquaredMagnitudeOutputLooksRightForNonDC) {
  const int frame_size_samples = 100;
  const std::vector<int> input_packet_sizes = {140};
  options_.set_frame_duration_seconds(frame_size_samples / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_output_type(SpectrogramCalculatorOptions::SQUARED_MAGNITUDE);

  InitializeGraph();
  FillInputHeader();
  // Choose a tone with an integral number of cycles per window, so its
  // energy lands squarely in a single bin.
  const int target_bin = 16;
  const int fft_size = audio_dsp::NextPowerOfTwo(frame_size_samples);
  const float tone_frequency_hz = target_bin * (input_sample_rate_ / fft_size);
  SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz);

  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  // A non-DC tone splits its magnitude between the positive- and
  // negative-frequency bins, so the bin holds about half the magnitude,
  // i.e. a quarter of the power. It is not exact because the hann(100)
  // window spreads some energy in from the negative-frequency half.
  EXPECT_GT(output().packets[0].Get<Matrix>()(target_bin, 0),
            0.98 * expected_dc_squared_magnitude_ / 4.0);
  EXPECT_LT(output().packets[0].Get<Matrix>()(target_bin, 0),
            1.02 * expected_dc_squared_magnitude_ / 4.0);
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, ZeroOutputsForZeroInputsWithPaddingEnabled) {
  // No input packets at all: padding must not conjure an output packet.
  const std::vector<int> input_packet_sizes = {};
  const std::vector<int> expected_output_packet_sizes = {};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(true);

  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(input_packet_sizes);

  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes);
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, NumChannelsIsRight) {
  // With multichannel input allowed, the output packet should contain one
  // spectrogram Matrix per input channel.
  const std::vector<int> input_packet_sizes = {460};
  num_input_channels_ = 3;
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  options_.set_allow_multichannel_input(true);
  InitializeGraph();
  FillInputHeader();
  const float tone_frequency_hz = 440.0;
  SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz);
  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(output().packets[0].Get<std::vector<Matrix>>().size(),
            num_input_channels_);
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, NumSamplesAndPacketRateAreCleared) {
  // Even when the input header carries num_samples and packet_rate, the
  // output header must not, since framing changes both.
  num_input_samples_ = 500;
  input_packet_rate_ = 1.0;
  const std::vector<int> input_packet_sizes = {num_input_samples_};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(0.0);
  options_.set_pad_final_packet(false);

  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(input_packet_sizes);

  MEDIAPIPE_ASSERT_OK(Run());

  const TimeSeriesHeader& output_header =
      output().header.Get<TimeSeriesHeader>();
  EXPECT_FALSE(output_header.has_num_samples());
  EXPECT_FALSE(output_header.has_packet_rate());
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, MultichannelSpectrogramSizesAreRight) {
  // The multichannel output packet must hold one spectrogram per input
  // channel, all with identical dimensions.
  const std::vector<int> input_packet_sizes = {420};  // less than 10 frames
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  options_.set_allow_multichannel_input(true);
  num_input_channels_ = 10;
  InitializeGraph();
  FillInputHeader();
  const float tone_frequency_hz = 440.0;
  SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz);
  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  // Bind by const reference: Get<> returns a reference, and copying a
  // std::vector<Matrix> of 10 spectrograms is needless work.
  const auto& spectrograms = output().packets[0].Get<std::vector<Matrix>>();
  // Sizes are integers, so compare exactly; EXPECT_FLOAT_EQ (previously
  // used here) compared them after a lossy conversion to float.
  EXPECT_EQ(spectrograms.size(), num_input_channels_);
  const int spectrogram_num_rows = spectrograms[0].rows();
  const int spectrogram_num_cols = spectrograms[0].cols();
  for (int i = 1; i < num_input_channels_; i++) {
    EXPECT_EQ(spectrogram_num_rows, spectrograms[i].rows());
    EXPECT_EQ(spectrogram_num_cols, spectrograms[i].cols());
  }
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, MultichannelSpectrogramValuesAreRight) {
  // SetupMultichannelInputPackets alternates channel content; even channels
  // should peak at DC and odd channels at the tone frequency.
  const std::vector<int> input_packet_sizes = {
      460};  // 100 + 9*40 for 10 frames.
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_allow_multichannel_input(true);
  num_input_channels_ = 10;
  InitializeGraph();
  FillInputHeader();
  const float tone_frequency_hz = 440.0;
  SetupMultichannelInputPackets(input_packet_sizes, tone_frequency_hz);

  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  // Bind by const reference: Get<> returns a reference, and the previous
  // `auto` copy duplicated all 10 spectrogram matrices.
  const auto& spectrograms = output().packets[0].Get<std::vector<Matrix>>();
  const int num_output_frames = spectrograms[0].cols();
  for (int i = 0; i < num_input_channels_; i++) {
    for (int frame = 0; frame < num_output_frames; ++frame) {
      if (i % 2 == 0) {
        CheckPeakFrequencyInMatrix(spectrograms[i], frame, 0);
      } else {
        CheckPeakFrequencyInMatrix(spectrograms[i], frame, tone_frequency_hz);
      }
    }
  }
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, MultichannelHandlesShortInitialPacket) {
  // First packet is less than one frame, but the second packet should
  // trigger a complete frame from all channels.
  const std::vector<int> input_packet_sizes = {50, 50};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  options_.set_allow_multichannel_input(true);
  num_input_channels_ = 2;
  InitializeGraph();
  FillInputHeader();
  const float tone_frequency_hz = 440.0;
  SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz);
  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  // Bind by const reference to avoid copying the vector of matrices.
  const auto& spectrograms = output().packets[0].Get<std::vector<Matrix>>();
  // Integer comparison: EXPECT_EQ, not the previous EXPECT_FLOAT_EQ, which
  // compared the sizes after a lossy conversion to float.
  EXPECT_EQ(spectrograms.size(), num_input_channels_);
  const int spectrogram_num_rows = spectrograms[0].rows();
  const int spectrogram_num_cols = spectrograms[0].cols();
  for (int i = 1; i < num_input_channels_; i++) {
    EXPECT_EQ(spectrogram_num_rows, spectrograms[i].rows());
    EXPECT_EQ(spectrogram_num_cols, spectrograms[i].cols());
  }
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest,
       MultichannelComplexHandlesShortInitialPacket) {
  // First packet is less than one frame, but the second packet should
  // trigger a complete frame from all channels, even for complex output.
  const std::vector<int> input_packet_sizes = {50, 50};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  options_.set_allow_multichannel_input(true);
  options_.set_output_type(SpectrogramCalculatorOptions::COMPLEX);
  num_input_channels_ = 2;
  InitializeGraph();
  FillInputHeader();
  const float tone_frequency_hz = 440.0;
  SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz);
  MEDIAPIPE_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  // Bind by const reference to avoid copying the vector of matrices.
  const auto& spectrograms =
      output().packets[0].Get<std::vector<Eigen::MatrixXcf>>();
  // Integer comparison: EXPECT_EQ, not the previous EXPECT_FLOAT_EQ, which
  // compared the sizes after a lossy conversion to float.
  EXPECT_EQ(spectrograms.size(), num_input_channels_);
  const int spectrogram_num_rows = spectrograms[0].rows();
  const int spectrogram_num_cols = spectrograms[0].cols();
  for (int i = 1; i < num_input_channels_; i++) {
    EXPECT_EQ(spectrogram_num_rows, spectrograms[i].rows());
    EXPECT_EQ(spectrogram_num_cols, spectrograms[i].cols());
  }
}
|
||||
|
||||
// Benchmarks SpectrogramCalculator throughput on a long constant (DC) input.
void BM_ProcessDC(benchmark::State& state) {
  CalculatorGraphConfig::Node node_config;
  node_config.set_calculator("SpectrogramCalculator");
  node_config.add_input_stream("input_audio");
  node_config.add_output_stream("output_spectrogram");

  // Configure the options extension in place. (The original code also
  // copied the extension back onto itself afterwards — a redundant
  // self-assignment that has been removed.)
  SpectrogramCalculatorOptions* options =
      node_config.mutable_options()->MutableExtension(
          SpectrogramCalculatorOptions::ext);
  options->set_frame_duration_seconds(0.010);
  options->set_frame_overlap_seconds(0.0);
  options->set_pad_final_packet(false);

  const int num_input_channels = 1;
  const int packet_size_samples = 1600000;
  // Adopt() below takes ownership of these raw allocations.
  TimeSeriesHeader* header = new TimeSeriesHeader();
  header->set_sample_rate(16000.0);
  header->set_num_channels(num_input_channels);

  CalculatorRunner runner(node_config);
  runner.MutableInputs()->Index(0).header = Adopt(header);

  Matrix* payload = new Matrix(
      Matrix::Constant(num_input_channels, packet_size_samples, 1.0));
  Timestamp timestamp = Timestamp(0);
  runner.MutableInputs()->Index(0).packets.push_back(
      Adopt(payload).At(timestamp));

  // Timed region: each iteration runs the whole graph over the input.
  for (auto _ : state) {
    ASSERT_TRUE(runner.Run().ok());
  }

  // Log a few output values so the result is observable (and not elided).
  const CalculatorRunner::StreamContents& output = runner.Outputs().Index(0);
  const Matrix& output_matrix = output.packets[0].Get<Matrix>();
  LOG(INFO) << "Output matrix=" << output_matrix.rows() << "x"
            << output_matrix.cols();
  LOG(INFO) << "First values=" << output_matrix(0, 0) << ", "
            << output_matrix(1, 0) << ", " << output_matrix(2, 0) << ", "
            << output_matrix(3, 0);
}
|
||||
|
||||
BENCHMARK(BM_ProcessDC);
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace mediapipe
|
27
mediapipe/calculators/audio/testdata/BUILD
vendored
Normal file
27
mediapipe/calculators/audio/testdata/BUILD
vendored
Normal file
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2019 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
licenses(["notice"])  # Apache 2.0

# Audio fixtures used by the audio calculator tests.
filegroup(
    name = "test_audios",
    srcs = [
        "sine_wave_1k_44100_mono_2_sec_wav.audio",
        "sine_wave_1k_44100_stereo_2_sec_aac.audio",
        "sine_wave_1k_44100_stereo_2_sec_mp3.audio",
        "sine_wave_1k_48000_stereo_2_sec_wav.audio",
    ],
    visibility = ["//visibility:public"],
)
|
BIN
mediapipe/calculators/audio/testdata/sine_wave_1k_44100_mono_2_sec_wav.audio
vendored
Normal file
BIN
mediapipe/calculators/audio/testdata/sine_wave_1k_44100_mono_2_sec_wav.audio
vendored
Normal file
Binary file not shown.
BIN
mediapipe/calculators/audio/testdata/sine_wave_1k_44100_stereo_2_sec_aac.audio
vendored
Normal file
BIN
mediapipe/calculators/audio/testdata/sine_wave_1k_44100_stereo_2_sec_aac.audio
vendored
Normal file
Binary file not shown.
BIN
mediapipe/calculators/audio/testdata/sine_wave_1k_44100_stereo_2_sec_mp3.audio
vendored
Normal file
BIN
mediapipe/calculators/audio/testdata/sine_wave_1k_44100_stereo_2_sec_mp3.audio
vendored
Normal file
Binary file not shown.
BIN
mediapipe/calculators/audio/testdata/sine_wave_1k_48000_stereo_2_sec_wav.audio
vendored
Normal file
BIN
mediapipe/calculators/audio/testdata/sine_wave_1k_48000_stereo_2_sec_wav.audio
vendored
Normal file
Binary file not shown.
289
mediapipe/calculators/audio/time_series_framer_calculator.cc
Normal file
289
mediapipe/calculators/audio/time_series_framer_calculator.cc
Normal file
|
@ -0,0 +1,289 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Defines TimeSeriesFramerCalculator.
|
||||
#include <math.h>
|
||||
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "audio/dsp/window_functions.h"
|
||||
#include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/util/time_series_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// MediaPipe Calculator for framing a (vector-valued) input time series,
|
||||
// i.e. for breaking an input time series into fixed-size, possibly
|
||||
// overlapping, frames. The output stream's frame duration is
|
||||
// specified by frame_duration_seconds in the
|
||||
// TimeSeriesFramerCalculatorOptions, and the output's overlap is
|
||||
// specified by frame_overlap_seconds.
|
||||
//
|
||||
// This calculator assumes that the input timestamps refer to the
|
||||
// first sample in each Matrix. The output timestamps follow this
|
||||
// same convention.
|
||||
//
|
||||
// All output frames will have exactly the same number of samples: the number of
|
||||
// samples that approximates frame_duration_seconds most closely.
|
||||
//
|
||||
// Similarly, frame overlap is by default the (fixed) number of samples
|
||||
// approximating frame_overlap_seconds most closely. But if
|
||||
// emulate_fractional_frame_overlap is set to true, frame overlap is a variable
|
||||
// number of samples instead, such that the long-term average step between
|
||||
// frames is the difference between the (nominal) frame_duration_seconds and
|
||||
// frame_overlap_seconds.
|
||||
//
|
||||
// If pad_final_packet is true, all input samples will be emitted and the final
|
||||
// packet will be zero padded as necessary. If pad_final_packet is false, some
|
||||
// samples may be dropped at the end of the stream.
|
||||
class TimeSeriesFramerCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<Matrix>(
        // Input stream with TimeSeriesHeader.
    );
    cc->Outputs().Index(0).Set<Matrix>(
        // Fixed length time series Packets with TimeSeriesHeader.
    );
    return ::mediapipe::OkStatus();
  }

  // Returns FAIL if the input stream header is invalid.
  ::mediapipe::Status Open(CalculatorContext* cc) override;

  // Outputs as many framed packets as possible given the accumulated
  // input.  Always returns OK.
  ::mediapipe::Status Process(CalculatorContext* cc) override;

  // Flushes any remaining samples in a zero-padded packet.  Always
  // returns OK.
  ::mediapipe::Status Close(CalculatorContext* cc) override;

 private:
  // Adds input data to the internal buffer.
  void EnqueueInput(CalculatorContext* cc);
  // Constructs and emits framed output packets.
  void FrameOutput(CalculatorContext* cc);

  // Timestamp of the next frame to emit: the first input timestamp plus
  // the duration of all samples the framer has already stepped past.
  Timestamp CurrentOutputTimestamp() {
    return initial_input_timestamp_ +
           round(cumulative_completed_samples_ / sample_rate_ *
                 Timestamp::kTimestampUnitsPerSecond);
  }

  // The number of input samples to advance after the current output frame is
  // emitted.  When emulate_fractional_frame_overlap is set this varies from
  // frame to frame so that the long-run average step equals
  // average_frame_step_samples_.
  int next_frame_step_samples() const {
    // All numbers are in input samples.
    const int64 current_output_frame_start = static_cast<int64>(
        round(cumulative_output_frames_ * average_frame_step_samples_));
    // Invariant: the completed-sample counter always sits exactly at the
    // start of the frame about to be emitted.
    CHECK_EQ(current_output_frame_start, cumulative_completed_samples_);
    const int64 next_output_frame_start = static_cast<int64>(
        round((cumulative_output_frames_ + 1) * average_frame_step_samples_));
    return next_output_frame_start - current_output_frame_start;
  }

  // Input sampling rate in Hz, taken from the input TimeSeriesHeader.
  double sample_rate_;
  // Whether Close() emits a final zero-padded frame for leftover samples.
  bool pad_final_packet_;
  // Length of each emitted frame, in input samples.
  int frame_duration_samples_;
  // The advance, in input samples, between the start of successive output
  // frames. This may be a non-integer average value if
  // emulate_fractional_frame_overlap is true.
  double average_frame_step_samples_;
  // Samples still to discard before the next frame starts (nonzero only
  // when the frame step exceeds the frame duration, i.e. negative overlap).
  int samples_still_to_drop_;
  int64 cumulative_input_samples_;
  int64 cumulative_output_frames_;
  // "Completed" samples are samples that are no longer needed because
  // the framer has completely stepped past them (taking into account
  // any overlap).
  int64 cumulative_completed_samples_;
  // Timestamp of the first input packet; Unstarted() until Process() runs.
  Timestamp initial_input_timestamp_;
  int num_channels_;

  // Each entry in this deque consists of a single sample, i.e. a
  // single column vector.
  std::deque<Matrix> sample_buffer_;

  // Whether a window function is applied to each frame before emission.
  bool use_window_;
  // Window replicated across channels (num_channels_ x
  // frame_duration_samples_); valid only when use_window_ is true.
  Matrix window_;
};
|
||||
REGISTER_CALCULATOR(TimeSeriesFramerCalculator);
|
||||
|
||||
void TimeSeriesFramerCalculator::EnqueueInput(CalculatorContext* cc) {
  // Append every column of the incoming packet (one column == one sample
  // across all channels) to the internal sample buffer.
  const Matrix& input_frame = cc->Inputs().Index(0).Get<Matrix>();
  const int num_samples = input_frame.cols();
  for (int sample = 0; sample < num_samples; ++sample) {
    sample_buffer_.emplace_back(input_frame.col(sample));
  }
  cumulative_input_samples_ += num_samples;
}
|
||||
|
||||
void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) {
  // Emit frames for as long as the buffer holds one full frame plus any
  // samples still owed to a previous negative-overlap gap.
  while (sample_buffer_.size() >=
         frame_duration_samples_ + samples_still_to_drop_) {
    // First discard the gap samples carried over from the previous frame.
    while (samples_still_to_drop_ > 0) {
      sample_buffer_.pop_front();
      --samples_still_to_drop_;
    }
    const int frame_step_samples = next_frame_step_samples();
    std::unique_ptr<Matrix> output_frame(
        new Matrix(num_channels_, frame_duration_samples_));
    // Samples that the framer steps past are consumed (popped); samples
    // shared with the next frame (positive overlap) are copied below but
    // kept in the buffer.
    for (int i = 0; i < std::min(frame_step_samples, frame_duration_samples_);
         ++i) {
      output_frame->col(i) = sample_buffer_.front();
      sample_buffer_.pop_front();
    }
    const int frame_overlap_samples =
        frame_duration_samples_ - frame_step_samples;
    if (frame_overlap_samples > 0) {
      // Positive overlap: fill the tail of the frame from samples that stay
      // buffered for the next frame.
      for (int i = 0; i < frame_overlap_samples; ++i) {
        output_frame->col(i + frame_step_samples) = sample_buffer_[i];
      }
    } else {
      // Negative overlap: remember how many samples to skip before the next
      // frame is assembled.
      samples_still_to_drop_ = -frame_overlap_samples;
    }

    if (use_window_) {
      // Element-wise application of the precomputed window.
      *output_frame = (output_frame->array() * window_.array()).matrix();
    }

    cc->Outputs().Index(0).Add(output_frame.release(),
                               CurrentOutputTimestamp());
    ++cumulative_output_frames_;
    cumulative_completed_samples_ += frame_step_samples;
  }
}
|
||||
|
||||
::mediapipe::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) {
  // Latch the timestamp of the very first packet; every output timestamp is
  // derived from it in CurrentOutputTimestamp().
  const bool is_first_packet =
      (initial_input_timestamp_ == Timestamp::Unstarted());
  if (is_first_packet) {
    initial_input_timestamp_ = cc->InputTimestamp();
  }

  // Buffer the new samples, then emit every complete frame now available.
  EnqueueInput(cc);
  FrameOutput(cc);

  return ::mediapipe::OkStatus();
}
|
||||
|
||||
::mediapipe::Status TimeSeriesFramerCalculator::Close(CalculatorContext* cc) {
  // Discard any leading samples already owed to a negative-overlap gap.
  while (samples_still_to_drop_ > 0 && !sample_buffer_.empty()) {
    sample_buffer_.pop_front();
    --samples_still_to_drop_;
  }

  // If padding is enabled and real samples remain, emit one final frame:
  // leftover samples first, zeros for the rest.
  if (!sample_buffer_.empty() && pad_final_packet_) {
    std::unique_ptr<Matrix> final_frame(new Matrix);
    final_frame->setZero(num_channels_, frame_duration_samples_);
    for (int col = 0; col < sample_buffer_.size(); ++col) {
      final_frame->col(col) = sample_buffer_[col];
    }
    cc->Outputs().Index(0).Add(final_frame.release(),
                               CurrentOutputTimestamp());
  }

  return ::mediapipe::OkStatus();
}
|
||||
|
||||
::mediapipe::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) {
  TimeSeriesFramerCalculatorOptions framer_options;
  time_series_util::FillOptionsExtensionOrDie(cc->Options(), &framer_options);

  // Validate user-supplied durations before converting them to samples.
  RET_CHECK_GT(framer_options.frame_duration_seconds(), 0.0)
      << "Invalid or missing frame_duration_seconds. "
      << "framer_duration_seconds: \n"
      << framer_options.frame_duration_seconds();
  RET_CHECK_LT(framer_options.frame_overlap_seconds(),
               framer_options.frame_duration_seconds())
      << "Invalid frame_overlap_seconds. framer_overlap_seconds: \n"
      << framer_options.frame_overlap_seconds();

  TimeSeriesHeader input_header;
  RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
      cc->Inputs().Index(0).Header(), &input_header));

  sample_rate_ = input_header.sample_rate();
  num_channels_ = input_header.num_channels();
  // Round the requested frame duration to a whole number of input samples.
  frame_duration_samples_ = time_series_util::SecondsToSamples(
      framer_options.frame_duration_seconds(), sample_rate_);
  RET_CHECK_GT(frame_duration_samples_, 0)
      << "Frame duration of " << framer_options.frame_duration_seconds()
      << "s too small to cover a single sample at " << sample_rate_ << " Hz ";
  if (framer_options.emulate_fractional_frame_overlap()) {
    // Frame step may be fractional.
    average_frame_step_samples_ = (framer_options.frame_duration_seconds() -
                                   framer_options.frame_overlap_seconds()) *
                                  sample_rate_;
  } else {
    // Frame step is an integer (stored in a double).
    average_frame_step_samples_ =
        frame_duration_samples_ -
        time_series_util::SecondsToSamples(
            framer_options.frame_overlap_seconds(), sample_rate_);
  }
  RET_CHECK_GE(average_frame_step_samples_, 1)
      << "Frame step too small to cover a single sample at " << sample_rate_
      << " Hz.";
  pad_final_packet_ = framer_options.pad_final_packet();

  // Propagate the input header, adjusted for fixed-length output frames.
  auto output_header = new TimeSeriesHeader(input_header);
  output_header->set_num_samples(frame_duration_samples_);
  if (round(average_frame_step_samples_) == average_frame_step_samples_) {
    // Only set output packet rate if it is fixed.
    output_header->set_packet_rate(sample_rate_ / average_frame_step_samples_);
  }
  cc->Outputs().Index(0).SetHeader(Adopt(output_header));

  // Reset all per-stream bookkeeping.
  cumulative_completed_samples_ = 0;
  cumulative_input_samples_ = 0;
  cumulative_output_frames_ = 0;
  samples_still_to_drop_ = 0;
  initial_input_timestamp_ = Timestamp::Unstarted();

  // Precompute the window samples for the configured window function, if any.
  std::vector<double> window_vector;
  use_window_ = false;
  switch (framer_options.window_function()) {
    case TimeSeriesFramerCalculatorOptions::HAMMING:
      audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples_,
                                                    &window_vector);
      use_window_ = true;
      break;
    case TimeSeriesFramerCalculatorOptions::HANN:
      audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_,
                                                 &window_vector);
      use_window_ = true;
      break;
    case TimeSeriesFramerCalculatorOptions::NONE:
      break;
  }

  if (use_window_) {
    // Replicate the 1-D window across all channels so FrameOutput() can
    // apply it with a single element-wise product.
    window_ = Matrix::Ones(num_channels_, 1) *
              Eigen::Map<Eigen::MatrixXd>(window_vector.data(), 1,
                                          frame_duration_samples_)
                  .cast<float>();
  }

  return ::mediapipe::OkStatus();
}
|
||||
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,65 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
// Options for TimeSeriesFramerCalculator, which re-blocks a stream of
// time-series packets into fixed-length, optionally overlapping, optionally
// windowed frames.
message TimeSeriesFramerCalculatorOptions {
  extend CalculatorOptions {
    optional TimeSeriesFramerCalculatorOptions ext = 50631621;
  }

  // Frame duration in seconds. Required.  Must be greater than 0.  This is
  // rounded to the nearest integer number of samples.
  optional double frame_duration_seconds = 1;

  // Frame overlap in seconds.
  //
  // If emulate_fractional_frame_overlap is false (the default), then the frame
  // overlap is rounded to the nearest integer number of samples, and the step
  // from one frame to the next will be the difference between the number of
  // samples in a frame and the number of samples in the overlap.
  //
  // If emulate_fractional_frame_overlap is true, then frame overlap will be a
  // variable number of samples, such that the long-time average time step from
  // one frame to the next will be the difference between the (nominal, not
  // rounded) frame_duration_seconds and frame_overlap_seconds.  This is useful
  // where the desired time step is not an integral number of input samples.
  //
  // A negative frame_overlap_seconds corresponds to skipping some input samples
  // between each frame of emitted samples.
  //
  // Required that frame_overlap_seconds < frame_duration_seconds.
  optional double frame_overlap_seconds = 2 [default = 0.0];

  // See frame_overlap_seconds for semantics.
  // (Field numbers are not sequential; do not renumber existing fields.)
  optional bool emulate_fractional_frame_overlap = 5 [default = false];

  // Whether to pad the final packet with zeros.  If true, guarantees that
  // all input samples (other than those that fall in gaps implied by negative
  // frame_overlap_seconds) will be emitted.  If set to false, any partial
  // packet at the end of the stream will be dropped.
  optional bool pad_final_packet = 3 [default = true];

  // Optional windowing function.  The default is NONE (no windowing function).
  enum WindowFunction {
    NONE = 0;
    HAMMING = 1;
    HANN = 2;
  }
  optional WindowFunction window_function = 4 [default = NONE];
}
|
|
@ -0,0 +1,395 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "audio/dsp/window_functions.h"
|
||||
#include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/util/time_series_test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
const int kInitialTimestampOffsetMicroseconds = 4;
|
||||
|
||||
// Test fixture: builds a 10-packet input stream of increasing packet sizes
// (1100 samples total), keeps a concatenated reference copy, and provides
// helpers to verify the framer's output headers, timestamps, and values
// against that reference.
class TimeSeriesFramerCalculatorTest
    : public TimeSeriesCalculatorTest<TimeSeriesFramerCalculatorOptions> {
 protected:
  void SetUp() override {
    calculator_name_ = "TimeSeriesFramerCalculator";
    input_sample_rate_ = 4000.0;
    num_input_channels_ = 3;
  }

  // Returns a float value with the channel and timestamp separated by
  // an order of magnitude, for easy parsing by humans.
  float TestValue(int64 timestamp_in_microseconds, int channel) {
    return timestamp_in_microseconds + channel / 10.0;
  }

  // Caller takes ownership of the returned value.
  Matrix* NewTestFrame(int num_channels, int num_samples,
                       double starting_timestamp_seconds) {
    auto matrix = new Matrix(num_channels, num_samples);
    for (int c = 0; c < num_channels; ++c) {
      for (int i = 0; i < num_samples; ++i) {
        // Encode each sample's nominal timestamp (in timestamp units) into
        // its value so output frames can be traced back to input positions.
        int64 timestamp = time_series_util::SecondsToSamples(
            starting_timestamp_seconds + i / input_sample_rate_,
            Timestamp::kTimestampUnitsPerSecond);
        (*matrix)(c, i) = TestValue(timestamp, c);
      }
    }
    return matrix;
  }

  // Initializes and runs the test graph.
  ::mediapipe::Status Run() {
    InitializeGraph();

    FillInputHeader();
    InitializeInput();

    return RunGraph();
  }

  // Creates test input and saves a reference copy.
  void InitializeInput() {
    concatenated_input_samples_.resize(0, num_input_channels_);
    num_input_samples_ = 0;
    for (int i = 0; i < 10; ++i) {
      // This range of packet sizes was chosen such that some input
      // packets will be smaller than the output packet size and other
      // input packets will be larger.
      int packet_size = (i + 1) * 20;
      double timestamp_seconds = kInitialTimestampOffsetMicroseconds * 1.0e-6 +
                                 num_input_samples_ / input_sample_rate_;

      Matrix* data_frame =
          NewTestFrame(num_input_channels_, packet_size, timestamp_seconds);

      // Keep a reference copy of the input.
      //
      // conservativeResize() is needed here to preserve the existing
      // data.  Eigen's resize() resizes without preserving data.
      concatenated_input_samples_.conservativeResize(
          num_input_channels_, num_input_samples_ + packet_size);
      concatenated_input_samples_.rightCols(packet_size) = *data_frame;
      num_input_samples_ += packet_size;

      AppendInputPacket(data_frame, round(timestamp_seconds *
                                          Timestamp::kTimestampUnitsPerSecond));
    }

    // Precompute the window the calculator is expected to apply, so that
    // CheckOutputPacketValues() can reproduce the expected output.  NONE maps
    // to an all-ones window.
    const int frame_duration_samples = FrameDurationSamples();
    std::vector<double> window_vector;
    switch (options_.window_function()) {
      case TimeSeriesFramerCalculatorOptions::HAMMING:
        audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples,
                                                      &window_vector);
        break;
      case TimeSeriesFramerCalculatorOptions::HANN:
        audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples,
                                                   &window_vector);
        break;
      case TimeSeriesFramerCalculatorOptions::NONE:
        window_vector.assign(frame_duration_samples, 1.0f);
        break;
    }

    window_ = Matrix::Ones(num_input_channels_, 1) *
              Eigen::Map<Eigen::MatrixXd>(window_vector.data(), 1,
                                          frame_duration_samples)
                  .cast<float>();
  }

  // The configured frame duration, rounded to whole input samples.
  int FrameDurationSamples() {
    return time_series_util::SecondsToSamples(options_.frame_duration_seconds(),
                                              input_sample_rate_);
  }

  // Checks that the values in the framed output packets matches the
  // appropriate values from the input.
  void CheckOutputPacketValues(const Matrix& actual, int packet_num,
                               int frame_duration_samples,
                               double frame_step_samples,
                               int num_columns_to_check) {
    ASSERT_EQ(frame_duration_samples, actual.cols());
    // Expected frame = the corresponding input slice times the window.
    Matrix expected = (concatenated_input_samples_
                           .block(0, round(frame_step_samples * packet_num),
                                  num_input_channels_, num_columns_to_check)
                           .array() *
                       window_.leftCols(num_columns_to_check).array())
                          .matrix();
    ExpectApproximatelyEqual(expected, actual.leftCols(num_columns_to_check));
  }

  // Checks output headers, Timestamps, and values.
  void CheckOutput() {
    const int frame_duration_samples = FrameDurationSamples();
    // Mirror the calculator's frame-step computation (fractional vs rounded).
    const double frame_step_samples =
        options_.emulate_fractional_frame_overlap()
            ? (options_.frame_duration_seconds() -
               options_.frame_overlap_seconds()) *
                  input_sample_rate_
            : frame_duration_samples -
                  time_series_util::SecondsToSamples(
                      options_.frame_overlap_seconds(), input_sample_rate_);

    TimeSeriesHeader expected_header = input().header.Get<TimeSeriesHeader>();
    expected_header.set_num_samples(frame_duration_samples);
    if (!options_.emulate_fractional_frame_overlap() ||
        frame_step_samples == round(frame_step_samples)) {
      // packet_rate is only defined when the step is a fixed sample count.
      expected_header.set_packet_rate(input_sample_rate_ / frame_step_samples);
    }
    ExpectOutputHeaderEquals(expected_header);

    int num_full_packets = output().packets.size();
    if (options_.pad_final_packet()) {
      num_full_packets -= 1;
    }

    for (int packet_num = 0; packet_num < num_full_packets; ++packet_num) {
      const Packet& packet = output().packets[packet_num];
      CheckOutputPacketValues(packet.Get<Matrix>(), packet_num,
                              frame_duration_samples, frame_step_samples,
                              frame_duration_samples);
    }

    // What is the effective time index of the final sample emitted?
    // This includes accounting for the gaps when overlap is negative.
    const int num_unique_output_samples =
        round((output().packets.size() - 1) * frame_step_samples) +
        frame_duration_samples;
    LOG(INFO) << "packets.size()=" << output().packets.size()
              << " frame_duration_samples=" << frame_duration_samples
              << " frame_step_samples=" << frame_step_samples
              << " num_input_samples_=" << num_input_samples_
              << " num_unique_output_samples=" << num_unique_output_samples;
    const int num_padding_samples =
        num_unique_output_samples - num_input_samples_;
    if (options_.pad_final_packet()) {
      EXPECT_LT(num_padding_samples, frame_duration_samples);
      // If the input ended during the dropped samples between the end of
      // the last emitted frame and where the next one would begin, there
      // can be fewer unique output points than input points, even with
      // padding.
      const int max_dropped_samples =
          static_cast<int>(ceil(frame_step_samples - frame_duration_samples));
      EXPECT_GE(num_padding_samples, std::min(0, -max_dropped_samples));

      if (num_padding_samples > 0) {
        // Check the non-padded part of the final packet.
        const Matrix& final_matrix = output().packets.back().Get<Matrix>();
        CheckOutputPacketValues(final_matrix, num_full_packets,
                                frame_duration_samples, frame_step_samples,
                                frame_duration_samples - num_padding_samples);
        // Check the padded part of the final packet.
        EXPECT_EQ(
            Matrix::Zero(num_input_channels_, num_padding_samples),
            final_matrix.block(0, frame_duration_samples - num_padding_samples,
                               num_input_channels_, num_padding_samples));
      }
    } else {
      EXPECT_GT(num_padding_samples, -frame_duration_samples);
      EXPECT_LE(num_padding_samples, 0);
    }
  }

  // Total number of input samples appended across all packets.
  int num_input_samples_;
  // Reference copy of every input sample, in order.
  Matrix concatenated_input_samples_;
  // The window the calculator is expected to apply (ones when NONE).
  Matrix window_;
};
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest, IntegerSampleDurationNoOverlap) {
  // 100-sample frames with no overlap: frames tile the input back to back.
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutput();
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest,
       IntegerSampleDurationNoOverlapHammingWindow) {
  // Same as IntegerSampleDurationNoOverlap, but each frame should have a
  // Hamming window applied.
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_window_function(TimeSeriesFramerCalculatorOptions::HAMMING);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutput();
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest,
       IntegerSampleDurationNoOverlapHannWindow) {
  // Same as IntegerSampleDurationNoOverlap, but each frame should have a
  // Hann window applied.
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_window_function(TimeSeriesFramerCalculatorOptions::HANN);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutput();
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest, IntegerSampleDurationAndOverlap) {
  // 100-sample frames with a 40-sample overlap (60-sample step).
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(40.0 / input_sample_rate_);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutput();
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest, NonintegerSampleDurationAndOverlap) {
  // Non-integral sample counts are rounded to the nearest whole sample
  // (98.5 -> 99-sample frames, 38.4 -> 38-sample overlap).
  options_.set_frame_duration_seconds(98.5 / input_sample_rate_);
  options_.set_frame_overlap_seconds(38.4 / input_sample_rate_);

  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutput();
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest, NegativeOverlapExactFrames) {
  // Negative overlap means to drop samples between frames.
  // 100 samples per frame plus a skip of 10 samples will be 10 full frames in
  // the 1100 input samples (110-sample step), with nothing left to pad.
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(-10.0 / input_sample_rate_);
  MEDIAPIPE_ASSERT_OK(Run());
  EXPECT_EQ(output().packets.size(), 10);
  CheckOutput();
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest, NegativeOverlapExactFramesLessSkip) {
  // 100 samples per frame plus a skip of 100 samples will be 6 full frames in
  // the 1100 input samples (frames start at 0, 200, ..., 1000).
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(-100.0 / input_sample_rate_);
  MEDIAPIPE_ASSERT_OK(Run());
  EXPECT_EQ(output().packets.size(), 6);
  CheckOutput();
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest, NegativeOverlapWithPadding) {
  // 150 samples per frame plus a skip of 50 samples will require some padding
  // on the sixth and last frame given the 1100 sample input: frames start at
  // multiples of 200, so the frame starting at sample 1000 has only 100 real
  // samples and must be zero-padded out to 150.
  //
  // NOTE: this previously duplicated NegativeOverlapExactFramesLessSkip's
  // settings (100 / -100), which never produces padding and contradicted the
  // comment above; the options now match the test's stated intent.
  options_.set_frame_duration_seconds(150.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(-50.0 / input_sample_rate_);
  MEDIAPIPE_ASSERT_OK(Run());
  EXPECT_EQ(output().packets.size(), 6);
  CheckOutput();
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest, FixedFrameOverlap) {
  // Frame of 30 samples with step of 11.4 samples (rounded to 11 samples)
  // results in ceil((1100 - 30) / 11) + 1 = 99 packets.
  // (emulate_fractional_frame_overlap is off, so the step is rounded once.)
  options_.set_frame_duration_seconds(30 / input_sample_rate_);
  options_.set_frame_overlap_seconds((30.0 - 11.4) / input_sample_rate_);
  MEDIAPIPE_ASSERT_OK(Run());
  EXPECT_EQ(output().packets.size(), 99);
  CheckOutput();
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest, VariableFrameOverlap) {
  // Frame of 30 samples with step of 11.4 samples (not rounded)
  // results in ceil((1100 - 30) / 11.4) + 1 = 95 packets.
  // Contrast with FixedFrameOverlap, where the same nominal step rounds to 11.
  options_.set_frame_duration_seconds(30 / input_sample_rate_);
  options_.set_frame_overlap_seconds((30 - 11.4) / input_sample_rate_);
  options_.set_emulate_fractional_frame_overlap(true);
  MEDIAPIPE_ASSERT_OK(Run());
  EXPECT_EQ(output().packets.size(), 95);
  CheckOutput();
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest, VariableFrameSkip) {
  // Frame of 30 samples with step of 41.4 samples (not rounded)
  // results in ceil((1100 - 30) / 41.4) + 1 = 27 packets.
  // The step exceeds the frame duration, so samples are skipped between
  // frames (negative fractional overlap).
  options_.set_frame_duration_seconds(30 / input_sample_rate_);
  options_.set_frame_overlap_seconds((30 - 41.4) / input_sample_rate_);
  options_.set_emulate_fractional_frame_overlap(true);
  MEDIAPIPE_ASSERT_OK(Run());
  EXPECT_EQ(output().packets.size(), 27);
  CheckOutput();
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest, NoFinalPacketPadding) {
  // With pad_final_packet disabled, the trailing partial frame is dropped
  // instead of zero-padded; CheckOutput() verifies the no-padding branch.
  options_.set_frame_duration_seconds(98.5 / input_sample_rate_);
  options_.set_pad_final_packet(false);

  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutput();
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest,
       FrameRateHigherThanSampleRate_FrameDurationTooLow) {
  // Try to produce a frame rate 10 times the input sample rate by using a
  // frame duration that is too small and covers only 0.1 samples; Open()
  // should reject it (frame_duration_samples_ rounds to 0).
  options_.set_frame_duration_seconds(1 / (10 * input_sample_rate_));
  options_.set_frame_overlap_seconds(0.0);
  EXPECT_FALSE(Run().ok());
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest,
       FrameRateHigherThanSampleRate_FrameStepTooLow) {
  // Try to produce a frame rate 10 times the input sample rate by using
  // a frame overlap that is too high and produces frame steps (difference
  // between duration and overlap) of 0.1 samples; Open() should reject it
  // (average step rounds below 1 sample).
  options_.set_frame_duration_seconds(10.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(9.9 / input_sample_rate_);
  EXPECT_FALSE(Run().ok());
}
|
||||
|
||||
// A simple test class to do windowing sanity checks. Tests from this
// class input a single packet of all ones, and check the average
// value of the single output packet. This is useful as a sanity check
// that the correct windows are applied.
class TimeSeriesFramerCalculatorWindowingSanityTest
    : public TimeSeriesFramerCalculatorTest {
 protected:
  void SetUp() override {
    TimeSeriesFramerCalculatorTest::SetUp();
    // A single channel suffices for checking the window's average value.
    num_input_channels_ = 1;
  }

  // Feeds one frame's worth of all-ones samples and asserts that the mean of
  // the lone output packet equals expected_average (the window's mean, since
  // windowing all-ones input yields the window itself).
  void RunAndTestSinglePacketAverage(float expected_average) {
    options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
    InitializeGraph();
    FillInputHeader();
    AppendInputPacket(new Matrix(Matrix::Ones(1, FrameDurationSamples())),
                      kInitialTimestampOffsetMicroseconds);
    MEDIAPIPE_ASSERT_OK(RunGraph());
    ASSERT_EQ(1, output().packets.size());
    ASSERT_NEAR(expected_average * FrameDurationSamples(),
                output().packets[0].Get<Matrix>().sum(), 1e-5);
  }
};
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorWindowingSanityTest, NoWindowSanityCheck) {
  // No window: all-ones input passes through unchanged, so the mean is 1.0.
  RunAndTestSinglePacketAverage(1.0f);
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorWindowingSanityTest,
       HammingWindowSanityCheck) {
  // The mean of a periodic Hamming window is its 0.54 DC offset.
  options_.set_window_function(TimeSeriesFramerCalculatorOptions::HAMMING);
  RunAndTestSinglePacketAverage(0.54f);
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorWindowingSanityTest, HannWindowSanityCheck) {
  // The mean of a periodic Hann window is its 0.5 DC offset.
  options_.set_window_function(TimeSeriesFramerCalculatorOptions::HANN);
  RunAndTestSinglePacketAverage(0.5f);
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace mediapipe
|
|
@ -19,6 +19,20 @@ package(default_visibility = ["//visibility:private"])
|
|||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
|
||||
|
||||
proto_library(
|
||||
name = "concatenate_vector_calculator_proto",
|
||||
srcs = ["concatenate_vector_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//mediapipe/framework:calculator_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "packet_cloner_calculator_proto",
|
||||
srcs = ["packet_cloner_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//mediapipe/framework:calculator_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "packet_resampler_calculator_proto",
|
||||
srcs = ["packet_resampler_calculator.proto"],
|
||||
|
@ -26,6 +40,46 @@ proto_library(
|
|||
deps = ["//mediapipe/framework:calculator_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "split_vector_calculator_proto",
|
||||
srcs = ["split_vector_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//mediapipe/framework:calculator_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "quantize_float_vector_calculator_proto",
|
||||
srcs = ["quantize_float_vector_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//mediapipe/framework:calculator_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "sequence_shift_calculator_proto",
|
||||
srcs = ["sequence_shift_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "gate_calculator_proto",
|
||||
srcs = ["gate_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "packet_cloner_calculator_cc_proto",
|
||||
srcs = ["packet_cloner_calculator.proto"],
|
||||
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
|
||||
visibility = ["//mediapipe:__subpackages__"],
|
||||
deps = [":packet_cloner_calculator_proto"],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "packet_resampler_calculator_cc_proto",
|
||||
srcs = ["packet_resampler_calculator.proto"],
|
||||
|
@ -34,6 +88,115 @@ mediapipe_cc_proto_library(
|
|||
deps = [":packet_resampler_calculator_proto"],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "split_vector_calculator_cc_proto",
|
||||
srcs = ["split_vector_calculator.proto"],
|
||||
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
|
||||
visibility = ["//mediapipe:__subpackages__"],
|
||||
deps = [":split_vector_calculator_proto"],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "concatenate_vector_calculator_cc_proto",
|
||||
srcs = ["concatenate_vector_calculator.proto"],
|
||||
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":concatenate_vector_calculator_proto"],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "quantize_float_vector_calculator_cc_proto",
|
||||
srcs = ["quantize_float_vector_calculator.proto"],
|
||||
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
|
||||
visibility = ["//mediapipe:__subpackages__"],
|
||||
deps = [":quantize_float_vector_calculator_proto"],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "sequence_shift_calculator_cc_proto",
|
||||
srcs = ["sequence_shift_calculator.proto"],
|
||||
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
|
||||
visibility = ["//mediapipe:__subpackages__"],
|
||||
deps = [":sequence_shift_calculator_proto"],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "gate_calculator_cc_proto",
|
||||
srcs = ["gate_calculator.proto"],
|
||||
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
|
||||
visibility = ["//mediapipe:__subpackages__"],
|
||||
deps = [":gate_calculator_proto"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "add_header_calculator",
|
||||
srcs = ["add_header_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:logging",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "add_header_calculator_test",
|
||||
size = "small",
|
||||
srcs = ["add_header_calculator_test.cc"],
|
||||
deps = [
|
||||
":add_header_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework:timestamp",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/tool:validate_type",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concatenate_vector_calculator",
|
||||
srcs = ["concatenate_vector_calculator.cc"],
|
||||
hdrs = ["concatenate_vector_calculator.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":concatenate_vector_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"@org_tensorflow//tensorflow/lite:framework",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concatenate_detection_vector_calculator",
|
||||
srcs = ["concatenate_detection_vector_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":concatenate_vector_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "concatenate_vector_calculator_test",
|
||||
srcs = ["concatenate_vector_calculator_test.cc"],
|
||||
deps = [
|
||||
":concatenate_vector_calculator",
|
||||
"//mediapipe/calculators/core:packet_resampler_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework:timestamp",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "counting_source_calculator",
|
||||
srcs = ["counting_source_calculator.cc"],
|
||||
|
@ -62,6 +225,38 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "matrix_multiply_calculator",
|
||||
srcs = ["matrix_multiply_calculator.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:timestamp",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/port:status",
|
||||
"@eigen_archive//:eigen",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "matrix_subtract_calculator",
|
||||
srcs = ["matrix_subtract_calculator.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:timestamp",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/port:status",
|
||||
"@eigen_archive//:eigen",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mux_calculator",
|
||||
srcs = ["mux_calculator.cc"],
|
||||
|
@ -83,12 +278,37 @@ cc_library(
|
|||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:packet_cloner_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "packet_inner_join_calculator",
|
||||
srcs = ["packet_inner_join_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "packet_inner_join_calculator_test",
|
||||
srcs = ["packet_inner_join_calculator_test.cc"],
|
||||
deps = [
|
||||
":packet_inner_join_calculator",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/tool:validate_type",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "pass_through_calculator",
|
||||
srcs = ["pass_through_calculator.cc"],
|
||||
|
@ -145,8 +365,8 @@ cc_library(
|
|||
)
|
||||
|
||||
cc_library(
|
||||
name = "real_time_flow_limiter_calculator",
|
||||
srcs = ["real_time_flow_limiter_calculator.cc"],
|
||||
name = "flow_limiter_calculator",
|
||||
srcs = ["flow_limiter_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -191,19 +411,16 @@ cc_library(
|
|||
"//mediapipe/calculators/core:packet_resampler_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:collection_item_id",
|
||||
"//mediapipe/framework/deps:mathutil",
|
||||
"//mediapipe/framework/deps:random",
|
||||
"//mediapipe/framework/formats:video_stream_header",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/tool:options_util",
|
||||
"@com_google_absl//absl/strings",
|
||||
"//mediapipe/framework/deps:mathutil",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
] + select({
|
||||
"//conditions:default": [
|
||||
"//mediapipe/framework/deps:random",
|
||||
],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
|
@ -245,10 +462,43 @@ cc_test(
|
|||
)
|
||||
|
||||
cc_test(
|
||||
name = "real_time_flow_limiter_calculator_test",
|
||||
srcs = ["real_time_flow_limiter_calculator_test.cc"],
|
||||
name = "matrix_multiply_calculator_test",
|
||||
srcs = ["matrix_multiply_calculator_test.cc"],
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":real_time_flow_limiter_calculator",
|
||||
":matrix_multiply_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/tool:validate_type",
|
||||
"@eigen_archive//:eigen",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "matrix_subtract_calculator_test",
|
||||
srcs = ["matrix_subtract_calculator_test.cc"],
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":matrix_subtract_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/tool:validate_type",
|
||||
"@eigen_archive//:eigen",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "flow_limiter_calculator_test",
|
||||
srcs = ["flow_limiter_calculator_test.cc"],
|
||||
deps = [
|
||||
":flow_limiter_calculator",
|
||||
"//mediapipe/calculators/core:counting_source_calculator",
|
||||
"//mediapipe/calculators/core:pass_through_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -264,3 +514,140 @@ cc_test(
|
|||
"@com_google_absl//absl/time",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "split_vector_calculator",
|
||||
srcs = ["split_vector_calculator.cc"],
|
||||
hdrs = ["split_vector_calculator.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":split_vector_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:resource_util",
|
||||
"@org_tensorflow//tensorflow/lite:framework",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "split_vector_calculator_test",
|
||||
srcs = ["split_vector_calculator_test.cc"],
|
||||
deps = [
|
||||
":split_vector_calculator",
|
||||
":split_vector_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/tool:validate_type",
|
||||
"@org_tensorflow//tensorflow/lite:framework",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "quantize_float_vector_calculator",
|
||||
srcs = ["quantize_float_vector_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":quantize_float_vector_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_context",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "quantize_float_vector_calculator_test",
|
||||
srcs = ["quantize_float_vector_calculator_test.cc"],
|
||||
deps = [
|
||||
":quantize_float_vector_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "sequence_shift_calculator",
|
||||
srcs = ["sequence_shift_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":sequence_shift_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "sequence_shift_calculator_test",
|
||||
srcs = ["sequence_shift_calculator_test.cc"],
|
||||
deps = [
|
||||
":sequence_shift_calculator",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/tool:validate_type",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gate_calculator",
|
||||
srcs = ["gate_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":gate_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:header_util",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "gate_calculator_test",
|
||||
srcs = ["gate_calculator_test.cc"],
|
||||
deps = [
|
||||
":gate_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "merge_calculator",
|
||||
srcs = ["merge_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "merge_calculator_test",
|
||||
srcs = ["merge_calculator_test.cc"],
|
||||
deps = [
|
||||
":merge_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
)
|
||||
|
|
53
mediapipe/calculators/core/add_header_calculator.cc
Normal file
53
mediapipe/calculators/core/add_header_calculator.cc
Normal file
|
@ -0,0 +1,53 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Attach the header from one stream to another stream.
|
||||
//
|
||||
// The header stream (tag HEADER) must not have any packets in it.
|
||||
//
|
||||
// Before using this calculator, please think about changing your
|
||||
// calculator to not need a header or to accept a separate stream with
|
||||
// a header, that would be more future proof.
|
||||
//
|
||||
class AddHeaderCalculator : public CalculatorBase {
|
||||
public:
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Tag("HEADER").SetNone();
|
||||
cc->Inputs().Tag("DATA").SetAny();
|
||||
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Tag("DATA"));
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Open(CalculatorContext* cc) override {
|
||||
const Packet& header = cc->Inputs().Tag("HEADER").Header();
|
||||
if (!header.IsEmpty()) {
|
||||
cc->Outputs().Index(0).SetHeader(header);
|
||||
}
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Process(CalculatorContext* cc) override {
|
||||
cc->Outputs().Index(0).AddPacket(cc->Inputs().Tag("DATA").Value());
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(AddHeaderCalculator);
|
||||
} // namespace mediapipe
|
99
mediapipe/calculators/core/add_header_calculator_test.cc
Normal file
99
mediapipe/calculators/core/add_header_calculator_test.cc
Normal file
|
@ -0,0 +1,99 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/framework/tool/validate_type.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
class AddHeaderCalculatorTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(AddHeaderCalculatorTest, Works) {
|
||||
CalculatorGraphConfig::Node node;
|
||||
node.set_calculator("AddHeaderCalculator");
|
||||
node.add_input_stream("HEADER:header_stream");
|
||||
node.add_input_stream("DATA:data_stream");
|
||||
node.add_output_stream("merged_stream");
|
||||
|
||||
CalculatorRunner runner(node);
|
||||
|
||||
// Set header and add 5 packets.
|
||||
runner.MutableInputs()->Tag("HEADER").header =
|
||||
Adopt(new std::string("my_header"));
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
|
||||
runner.MutableInputs()->Tag("DATA").packets.push_back(packet);
|
||||
}
|
||||
|
||||
// Run calculator.
|
||||
MEDIAPIPE_ASSERT_OK(runner.Run());
|
||||
|
||||
ASSERT_EQ(1, runner.Outputs().NumEntries());
|
||||
|
||||
// Test output.
|
||||
EXPECT_EQ(std::string("my_header"),
|
||||
runner.Outputs().Index(0).header.Get<std::string>());
|
||||
const std::vector<Packet>& output_packets = runner.Outputs().Index(0).packets;
|
||||
ASSERT_EQ(5, output_packets.size());
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
const int val = output_packets[i].Get<int>();
|
||||
EXPECT_EQ(i, val);
|
||||
EXPECT_EQ(Timestamp(i * 1000), output_packets[i].Timestamp());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(AddHeaderCalculatorTest, HandlesEmptyHeaderStream) {
|
||||
CalculatorGraphConfig::Node node;
|
||||
node.set_calculator("AddHeaderCalculator");
|
||||
node.add_input_stream("HEADER:header_stream");
|
||||
node.add_input_stream("DATA:data_stream");
|
||||
node.add_output_stream("merged_stream");
|
||||
|
||||
CalculatorRunner runner(node);
|
||||
|
||||
// No header and no packets.
|
||||
// Run calculator.
|
||||
MEDIAPIPE_ASSERT_OK(runner.Run());
|
||||
EXPECT_TRUE(runner.Outputs().Index(0).header.IsEmpty());
|
||||
}
|
||||
|
||||
TEST_F(AddHeaderCalculatorTest, NoPacketsOnHeaderStream) {
|
||||
CalculatorGraphConfig::Node node;
|
||||
node.set_calculator("AddHeaderCalculator");
|
||||
node.add_input_stream("HEADER:header_stream");
|
||||
node.add_input_stream("DATA:data_stream");
|
||||
node.add_output_stream("merged_stream");
|
||||
|
||||
CalculatorRunner runner(node);
|
||||
|
||||
// Set header and add 5 packets.
|
||||
runner.MutableInputs()->Tag("HEADER").header =
|
||||
Adopt(new std::string("my_header"));
|
||||
runner.MutableInputs()->Tag("HEADER").packets.push_back(
|
||||
Adopt(new std::string("not allowed")));
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
|
||||
runner.MutableInputs()->Tag("DATA").packets.push_back(packet);
|
||||
}
|
||||
|
||||
// Run calculator.
|
||||
ASSERT_FALSE(runner.Run().ok());
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,33 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/calculators/core/concatenate_vector_calculator.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "ConcatenateDetectionVectorCalculator"
|
||||
// input_stream: "detection_vector_1"
|
||||
// input_stream: "detection_vector_2"
|
||||
// output_stream: "concatenated_detection_vector"
|
||||
// }
|
||||
typedef ConcatenateVectorCalculator<::mediapipe::Detection>
|
||||
ConcatenateDetectionVectorCalculator;
|
||||
REGISTER_CALCULATOR(ConcatenateDetectionVectorCalculator);
|
||||
|
||||
} // namespace mediapipe
|
44
mediapipe/calculators/core/concatenate_vector_calculator.cc
Normal file
44
mediapipe/calculators/core/concatenate_vector_calculator.cc
Normal file
|
@ -0,0 +1,44 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/core/concatenate_vector_calculator.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "ConcatenateFloatVectorCalculator"
|
||||
// input_stream: "float_vector_1"
|
||||
// input_stream: "float_vector_2"
|
||||
// output_stream: "concatenated_float_vector"
|
||||
// }
|
||||
typedef ConcatenateVectorCalculator<float> ConcatenateFloatVectorCalculator;
|
||||
REGISTER_CALCULATOR(ConcatenateFloatVectorCalculator);
|
||||
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "ConcatenateTfLiteTensorVectorCalculator"
|
||||
// input_stream: "tflitetensor_vector_1"
|
||||
// input_stream: "tflitetensor_vector_2"
|
||||
// output_stream: "concatenated_tflitetensor_vector"
|
||||
// }
|
||||
typedef ConcatenateVectorCalculator<TfLiteTensor>
|
||||
ConcatenateTfLiteTensorVectorCalculator;
|
||||
REGISTER_CALCULATOR(ConcatenateTfLiteTensorVectorCalculator);
|
||||
|
||||
} // namespace mediapipe
|
78
mediapipe/calculators/core/concatenate_vector_calculator.h
Normal file
78
mediapipe/calculators/core/concatenate_vector_calculator.h
Normal file
|
@ -0,0 +1,78 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_
|
||||
#define MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Concatenates several std::vector<T> following stream index order. This class
|
||||
// assumes that every input stream contains the vector<T> type. To use this
|
||||
// class for a particular type T, regisiter a calculator using
|
||||
// ConcatenateVectorCalculator<T>.
|
||||
template <typename T>
|
||||
class ConcatenateVectorCalculator : public CalculatorBase {
|
||||
public:
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().NumEntries() != 0);
|
||||
RET_CHECK(cc->Outputs().NumEntries() == 1);
|
||||
|
||||
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
|
||||
cc->Inputs().Index(i).Set<std::vector<T>>();
|
||||
}
|
||||
|
||||
cc->Outputs().Index(0).Set<std::vector<T>>();
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Open(CalculatorContext* cc) override {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
only_emit_if_all_present_ =
|
||||
cc->Options<::mediapipe::ConcatenateVectorCalculatorOptions>()
|
||||
.only_emit_if_all_present();
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Process(CalculatorContext* cc) override {
|
||||
if (only_emit_if_all_present_) {
|
||||
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
|
||||
if (cc->Inputs().Index(i).IsEmpty()) return ::mediapipe::OkStatus();
|
||||
}
|
||||
}
|
||||
auto output = absl::make_unique<std::vector<T>>();
|
||||
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
|
||||
if (cc->Inputs().Index(i).IsEmpty()) continue;
|
||||
const std::vector<T>& input = cc->Inputs().Index(i).Get<std::vector<T>>();
|
||||
output->insert(output->end(), input.begin(), input.end());
|
||||
}
|
||||
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
bool only_emit_if_all_present_;
|
||||
};
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_
|
|
@ -0,0 +1,29 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
message ConcatenateVectorCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional ConcatenateVectorCalculatorOptions ext = 259397839;
|
||||
}
|
||||
|
||||
// If true, the calculator will only emit a packet at the given timestamp if
|
||||
// all input streams have a non-empty packet (AND operation on streams).
|
||||
optional bool only_emit_if_all_present = 1 [default = false];
|
||||
}
|
238
mediapipe/calculators/core/concatenate_vector_calculator_test.cc
Normal file
238
mediapipe/calculators/core/concatenate_vector_calculator_test.cc
Normal file
|
@ -0,0 +1,238 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/core/concatenate_vector_calculator.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
typedef ConcatenateVectorCalculator<int> TestConcatenateIntVectorCalculator;
|
||||
REGISTER_CALCULATOR(TestConcatenateIntVectorCalculator);
|
||||
|
||||
void AddInputVectors(const std::vector<std::vector<int>>& inputs,
|
||||
int64 timestamp, CalculatorRunner* runner) {
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
runner->MutableInputs()->Index(i).packets.push_back(
|
||||
MakePacket<std::vector<int>>(inputs[i]).At(Timestamp(timestamp)));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TestConcatenateIntVectorCalculatorTest, EmptyVectorInputs) {
|
||||
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
std::vector<std::vector<int>> inputs = {{}, {}, {}};
|
||||
AddInputVectors(inputs, /*timestamp=*/1, &runner);
|
||||
MEDIAPIPE_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_TRUE(outputs[0].Get<std::vector<int>>().empty());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
}
|
||||
|
||||
TEST(TestConcatenateIntVectorCalculatorTest, OneTimestamp) {
|
||||
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
std::vector<std::vector<int>> inputs = {{1, 2, 3}, {4}, {5, 6}};
|
||||
AddInputVectors(inputs, /*timestamp=*/1, &runner);
|
||||
MEDIAPIPE_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
std::vector<int> expected_vector = {1, 2, 3, 4, 5, 6};
|
||||
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
|
||||
}
|
||||
|
||||
TEST(TestConcatenateIntVectorCalculatorTest, TwoInputsAtTwoTimestamps) {
|
||||
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
{
|
||||
std::vector<std::vector<int>> inputs = {{1, 2, 3}, {4}, {5, 6}};
|
||||
AddInputVectors(inputs, /*timestamp=*/1, &runner);
|
||||
}
|
||||
{
|
||||
std::vector<std::vector<int>> inputs = {{0, 2}, {1}, {3, 5}};
|
||||
AddInputVectors(inputs, /*timestamp=*/2, &runner);
|
||||
}
|
||||
MEDIAPIPE_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(2, outputs.size());
|
||||
{
|
||||
EXPECT_EQ(6, outputs[0].Get<std::vector<int>>().size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
std::vector<int> expected_vector = {1, 2, 3, 4, 5, 6};
|
||||
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
|
||||
}
|
||||
{
|
||||
EXPECT_EQ(5, outputs[1].Get<std::vector<int>>().size());
|
||||
EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
|
||||
std::vector<int> expected_vector = {0, 2, 1, 3, 5};
|
||||
EXPECT_EQ(expected_vector, outputs[1].Get<std::vector<int>>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TestConcatenateIntVectorCalculatorTest, OneEmptyStreamStillOutput) {
|
||||
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/2,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
std::vector<std::vector<int>> inputs = {{1, 2, 3}};
|
||||
AddInputVectors(inputs, /*timestamp=*/1, &runner);
|
||||
MEDIAPIPE_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
std::vector<int> expected_vector = {1, 2, 3};
|
||||
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
|
||||
}
|
||||
|
||||
TEST(TestConcatenateIntVectorCalculatorTest, OneEmptyStreamNoOutput) {
|
||||
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
|
||||
/*options_string=*/
|
||||
"[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
|
||||
"{only_emit_if_all_present: true}",
|
||||
/*num_inputs=*/2,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
std::vector<std::vector<int>> inputs = {{1, 2, 3}};
|
||||
AddInputVectors(inputs, /*timestamp=*/1, &runner);
|
||||
MEDIAPIPE_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(0, outputs.size());
|
||||
}
|
||||
|
||||
void AddInputVectors(const std::vector<std::vector<float>>& inputs,
|
||||
int64 timestamp, CalculatorRunner* runner) {
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
runner->MutableInputs()->Index(i).packets.push_back(
|
||||
MakePacket<std::vector<float>>(inputs[i]).At(Timestamp(timestamp)));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ConcatenateFloatVectorCalculatorTest, EmptyVectorInputs) {
|
||||
CalculatorRunner runner("ConcatenateFloatVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
std::vector<std::vector<float>> inputs = {{}, {}, {}};
|
||||
AddInputVectors(inputs, /*timestamp=*/1, &runner);
|
||||
MEDIAPIPE_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_TRUE(outputs[0].Get<std::vector<float>>().empty());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
}
|
||||
|
||||
TEST(ConcatenateFloatVectorCalculatorTest, OneTimestamp) {
|
||||
CalculatorRunner runner("ConcatenateFloatVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
std::vector<std::vector<float>> inputs = {
|
||||
{1.0f, 2.0f, 3.0f}, {4.0f}, {5.0f, 6.0f}};
|
||||
AddInputVectors(inputs, /*timestamp=*/1, &runner);
|
||||
MEDIAPIPE_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
std::vector<float> expected_vector = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
||||
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<float>>());
|
||||
}
|
||||
|
||||
TEST(ConcatenateFloatVectorCalculatorTest, TwoInputsAtTwoTimestamps) {
|
||||
CalculatorRunner runner("ConcatenateFloatVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
{
|
||||
std::vector<std::vector<float>> inputs = {
|
||||
{1.0f, 2.0f, 3.0f}, {4.0f}, {5.0f, 6.0f}};
|
||||
AddInputVectors(inputs, /*timestamp=*/1, &runner);
|
||||
}
|
||||
{
|
||||
std::vector<std::vector<float>> inputs = {
|
||||
{0.0f, 2.0f}, {1.0f}, {3.0f, 5.0f}};
|
||||
AddInputVectors(inputs, /*timestamp=*/2, &runner);
|
||||
}
|
||||
MEDIAPIPE_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(2, outputs.size());
|
||||
{
|
||||
EXPECT_EQ(6, outputs[0].Get<std::vector<float>>().size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
std::vector<float> expected_vector = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
||||
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<float>>());
|
||||
}
|
||||
{
|
||||
EXPECT_EQ(5, outputs[1].Get<std::vector<float>>().size());
|
||||
EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
|
||||
std::vector<float> expected_vector = {0.0f, 2.0f, 1.0f, 3.0f, 5.0f};
|
||||
EXPECT_EQ(expected_vector, outputs[1].Get<std::vector<float>>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamStillOutput) {
|
||||
CalculatorRunner runner("ConcatenateFloatVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/2,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
std::vector<std::vector<float>> inputs = {{1.0f, 2.0f, 3.0f}};
|
||||
AddInputVectors(inputs, /*timestamp=*/1, &runner);
|
||||
MEDIAPIPE_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
std::vector<float> expected_vector = {1.0f, 2.0f, 3.0f};
|
||||
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<float>>());
|
||||
}
|
||||
|
||||
TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamNoOutput) {
|
||||
CalculatorRunner runner("ConcatenateFloatVectorCalculator",
|
||||
/*options_string=*/
|
||||
"[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
|
||||
"{only_emit_if_all_present: true}",
|
||||
/*num_inputs=*/2,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
std::vector<std::vector<float>> inputs = {{1.0f, 2.0f, 3.0f}};
|
||||
AddInputVectors(inputs, /*timestamp=*/1, &runner);
|
||||
MEDIAPIPE_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(0, outputs.size());
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
|
@ -23,34 +23,34 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
// RealTimeFlowLimiterCalculator is used to limit the number of pipelined
|
||||
// processing operations in a section of the graph.
|
||||
// FlowLimiterCalculator is used to limit the number of pipelined processing
|
||||
// operations in a section of the graph.
|
||||
//
|
||||
// Typical topology:
|
||||
//
|
||||
// in ->-[RTFLC]-[foo]-...-[bar]-+->- out
|
||||
// ^____________________|
|
||||
// in ->-[FLC]-[foo]-...-[bar]-+->- out
|
||||
// ^_____________________|
|
||||
// FINISHED
|
||||
//
|
||||
// By connecting the output of the graph section to this calculator's FINISHED
|
||||
// input with a backwards edge, this allows RTFLC to keep track of how many
|
||||
// input with a backwards edge, this allows FLC to keep track of how many
|
||||
// timestamps are currently being processed.
|
||||
//
|
||||
// The limit defaults to 1, and can be overridden with the MAX_IN_FLIGHT side
|
||||
// packet.
|
||||
//
|
||||
// As long as the number of timestamps being processed ("in flight") is below
|
||||
// the limit, RTFLC allows input to pass through. When the limit is reached,
|
||||
// RTFLC starts dropping input packets, keeping only the most recent. When the
|
||||
// the limit, FLC allows input to pass through. When the limit is reached,
|
||||
// FLC starts dropping input packets, keeping only the most recent. When the
|
||||
// processing count decreases again, as signaled by the receipt of a packet on
|
||||
// FINISHED, RTFLC allows packets to flow again, releasing the most recently
|
||||
// FINISHED, FLC allows packets to flow again, releasing the most recently
|
||||
// queued packet, if any.
|
||||
//
|
||||
// If there are multiple input streams, packet dropping is synchronized.
|
||||
//
|
||||
// IMPORTANT: for each timestamp where RTFLC forwards a packet (or a set of
|
||||
// IMPORTANT: for each timestamp where FLC forwards a packet (or a set of
|
||||
// packets, if using multiple data streams), a packet must eventually arrive on
|
||||
// the FINISHED stream. Dropping packets in the section between RTFLC and
|
||||
// the FINISHED stream. Dropping packets in the section between FLC and
|
||||
// FINISHED will make the in-flight count incorrect.
|
||||
//
|
||||
// TODO: Remove this comment when graph-level ISH has been removed.
|
||||
|
@ -61,7 +61,7 @@ namespace mediapipe {
|
|||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "RealTimeFlowLimiterCalculator"
|
||||
// calculator: "FlowLimiterCalculator"
|
||||
// input_stream: "raw_frames"
|
||||
// input_stream: "FINISHED:finished"
|
||||
// input_stream_info: {
|
||||
|
@ -73,7 +73,7 @@ namespace mediapipe {
|
|||
// }
|
||||
// output_stream: "gated_frames"
|
||||
// }
|
||||
class RealTimeFlowLimiterCalculator : public CalculatorBase {
|
||||
class FlowLimiterCalculator : public CalculatorBase {
|
||||
public:
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
int num_data_streams = cc->Inputs().NumEntries("");
|
||||
|
@ -194,6 +194,6 @@ class RealTimeFlowLimiterCalculator : public CalculatorBase {
|
|||
Timestamp allow_ctr_ts_;
|
||||
std::vector<Timestamp> data_stream_bound_ts_;
|
||||
};
|
||||
REGISTER_CALCULATOR(RealTimeFlowLimiterCalculator);
|
||||
REGISTER_CALCULATOR(FlowLimiterCalculator);
|
||||
|
||||
} // namespace mediapipe
|
|
@ -71,7 +71,7 @@ constexpr int kNumImageFrames = 5;
|
|||
constexpr int kNumFinished = 3;
|
||||
CalculatorGraphConfig::Node GetDefaultNode() {
|
||||
return ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
|
||||
calculator: "RealTimeFlowLimiterCalculator"
|
||||
calculator: "FlowLimiterCalculator"
|
||||
input_stream: "raw_frames"
|
||||
input_stream: "FINISHED:finished"
|
||||
input_stream_info: { tag_index: "FINISHED" back_edge: true }
|
||||
|
@ -79,9 +79,9 @@ CalculatorGraphConfig::Node GetDefaultNode() {
|
|||
)");
|
||||
}
|
||||
|
||||
// Simple test to make sure that the RealTimeFlowLimiterCalculator outputs
|
||||
// just one packet when MAX_IN_FLIGHT is 1.
|
||||
TEST(RealTimeFlowLimiterCalculator, OneOutputTest) {
|
||||
// Simple test to make sure that the FlowLimiterCalculator outputs just one
|
||||
// packet when MAX_IN_FLIGHT is 1.
|
||||
TEST(FlowLimiterCalculator, OneOutputTest) {
|
||||
// Setup the calculator runner and add only ImageFrame packets.
|
||||
CalculatorRunner runner(GetDefaultNode());
|
||||
for (int i = 0; i < kNumImageFrames; ++i) {
|
||||
|
@ -98,9 +98,9 @@ TEST(RealTimeFlowLimiterCalculator, OneOutputTest) {
|
|||
EXPECT_EQ(frame_output_packets.size(), 1);
|
||||
}
|
||||
|
||||
// Simple test to make sure that the RealTimeFlowLimiterCalculator waits for all
|
||||
// Simple test to make sure that the FlowLimiterCalculator waits for all
|
||||
// input streams to have at least one packet available before publishing.
|
||||
TEST(RealTimeFlowLimiterCalculator, BasicTest) {
|
||||
TEST(FlowLimiterCalculator, BasicTest) {
|
||||
// Setup the calculator runner and add both ImageFrame and finish packets.
|
||||
CalculatorRunner runner(GetDefaultNode());
|
||||
for (int i = 0; i < kNumImageFrames; ++i) {
|
||||
|
@ -171,13 +171,11 @@ class CloseCallbackCalculator : public CalculatorBase {
|
|||
};
|
||||
REGISTER_CALCULATOR(CloseCallbackCalculator);
|
||||
|
||||
// Tests demostrating an RealTimeFlowLimiterCalculator operating in a cyclic
|
||||
// graph.
|
||||
// Tests demostrating an FlowLimiterCalculator operating in a cyclic graph.
|
||||
// TODO: clean up these tests.
|
||||
class RealTimeFlowLimiterCalculatorTest : public testing::Test {
|
||||
class FlowLimiterCalculatorTest : public testing::Test {
|
||||
public:
|
||||
RealTimeFlowLimiterCalculatorTest()
|
||||
: enter_semaphore_(0), exit_semaphore_(0) {}
|
||||
FlowLimiterCalculatorTest() : enter_semaphore_(0), exit_semaphore_(0) {}
|
||||
|
||||
void SetUp() override {
|
||||
graph_config_ = InflightGraphConfig();
|
||||
|
@ -215,7 +213,7 @@ class RealTimeFlowLimiterCalculatorTest : public testing::Test {
|
|||
input_name, MakePacket<int>(value).At(Timestamp(value))));
|
||||
}
|
||||
|
||||
// A calculator graph starting with an RealTimeFlowLimiterCalculator and
|
||||
// A calculator graph starting with an FlowLimiterCalculator and
|
||||
// ending with a InFlightFinishCalculator.
|
||||
// Back-edge "finished" limits processing to one frame in-flight.
|
||||
// The two LambdaCalculators are used to keep certain packet sets in flight.
|
||||
|
@ -224,7 +222,7 @@ class RealTimeFlowLimiterCalculatorTest : public testing::Test {
|
|||
input_stream: 'in_1'
|
||||
input_stream: 'in_2'
|
||||
node {
|
||||
calculator: 'RealTimeFlowLimiterCalculator'
|
||||
calculator: 'FlowLimiterCalculator'
|
||||
input_side_packet: 'MAX_IN_FLIGHT:max_in_flight'
|
||||
input_stream: 'in_1'
|
||||
input_stream: 'in_2'
|
||||
|
@ -270,14 +268,14 @@ class RealTimeFlowLimiterCalculatorTest : public testing::Test {
|
|||
int close_count_ = 0;
|
||||
};
|
||||
|
||||
// A test demonstrating an RealTimeFlowLimiterCalculator operating in a cyclic
|
||||
// A test demonstrating an FlowLimiterCalculator operating in a cyclic
|
||||
// graph. This test shows that:
|
||||
//
|
||||
// (1) Timestamps are passed through unaltered.
|
||||
// (2) All output streams including the back_edge stream are closed when
|
||||
// the first input stream is closed.
|
||||
//
|
||||
TEST_F(RealTimeFlowLimiterCalculatorTest, BackEdgeCloses) {
|
||||
TEST_F(FlowLimiterCalculatorTest, BackEdgeCloses) {
|
||||
InitializeGraph(1);
|
||||
MEDIAPIPE_ASSERT_OK(graph_.StartRun({}));
|
||||
|
||||
|
@ -321,7 +319,7 @@ TEST_F(RealTimeFlowLimiterCalculatorTest, BackEdgeCloses) {
|
|||
|
||||
// A test demonstrating that all output streams are closed when all
|
||||
// input streams are closed after the last input packet has been processed.
|
||||
TEST_F(RealTimeFlowLimiterCalculatorTest, AllStreamsClose) {
|
||||
TEST_F(FlowLimiterCalculatorTest, AllStreamsClose) {
|
||||
InitializeGraph(1);
|
||||
MEDIAPIPE_ASSERT_OK(graph_.StartRun({}));
|
||||
|
||||
|
@ -341,7 +339,7 @@ TEST_F(RealTimeFlowLimiterCalculatorTest, AllStreamsClose) {
|
|||
EXPECT_EQ(1, close_count_);
|
||||
}
|
||||
|
||||
TEST(RealTimeFlowLimiterCalculator, TwoStreams) {
|
||||
TEST(FlowLimiterCalculator, TwoStreams) {
|
||||
std::vector<Packet> a_passed;
|
||||
std::vector<Packet> b_passed;
|
||||
CalculatorGraphConfig graph_config_ =
|
||||
|
@ -351,7 +349,7 @@ TEST(RealTimeFlowLimiterCalculator, TwoStreams) {
|
|||
input_stream: 'finished'
|
||||
node {
|
||||
name: 'input_dropper'
|
||||
calculator: 'RealTimeFlowLimiterCalculator'
|
||||
calculator: 'FlowLimiterCalculator'
|
||||
input_side_packet: 'MAX_IN_FLIGHT:max_in_flight'
|
||||
input_stream: 'in_a'
|
||||
input_stream: 'in_b'
|
||||
|
@ -440,7 +438,7 @@ TEST(RealTimeFlowLimiterCalculator, TwoStreams) {
|
|||
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST(RealTimeFlowLimiterCalculator, CanConsume) {
|
||||
TEST(FlowLimiterCalculator, CanConsume) {
|
||||
std::vector<Packet> in_sampled_packets_;
|
||||
CalculatorGraphConfig graph_config_ =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
|
@ -448,7 +446,7 @@ TEST(RealTimeFlowLimiterCalculator, CanConsume) {
|
|||
input_stream: 'finished'
|
||||
node {
|
||||
name: 'input_dropper'
|
||||
calculator: 'RealTimeFlowLimiterCalculator'
|
||||
calculator: 'FlowLimiterCalculator'
|
||||
input_side_packet: 'MAX_IN_FLIGHT:max_in_flight'
|
||||
input_stream: 'in'
|
||||
input_stream: 'FINISHED:finished'
|
163
mediapipe/calculators/core/gate_calculator.cc
Normal file
163
mediapipe/calculators/core/gate_calculator.cc
Normal file
|
@ -0,0 +1,163 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/core/gate_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/util/header_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
enum GateState {
|
||||
GATE_UNINITIALIZED,
|
||||
GATE_ALLOW,
|
||||
GATE_DISALLOW,
|
||||
};
|
||||
|
||||
std::string ToString(GateState state) {
|
||||
switch (state) {
|
||||
case GATE_UNINITIALIZED:
|
||||
return "UNINITIALIZED";
|
||||
case GATE_ALLOW:
|
||||
return "ALLOW";
|
||||
case GATE_DISALLOW:
|
||||
return "DISALLOW";
|
||||
}
|
||||
DLOG(FATAL) << "Unknown GateState";
|
||||
return "UNKNOWN";
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Controls whether or not the input packets are passed further along the graph.
|
||||
// Takes multiple data input streams and either an ALLOW or a DISALLOW control
|
||||
// input stream. It outputs an output stream for each input stream that is not
|
||||
// ALLOW or DISALLOW as well as an optional STATE_CHANGE stream which downstream
|
||||
// calculators can use to respond to state-change events.
|
||||
//
|
||||
// If the current ALLOW packet is set to true, the input packets are passed to
|
||||
// their corresponding output stream unchanged. If the ALLOW packet is set to
|
||||
// false, the current input packet is NOT passed to the output stream. If using
|
||||
// DISALLOW, the behavior is opposite of ALLOW.
|
||||
//
|
||||
// By default, an empty packet in the ALLOW or DISALLOW input stream indicates
|
||||
// disallowing the corresponding packets in other input streams. The behavior
|
||||
// can be inverted with a calculator option.
|
||||
//
|
||||
// Intended to be used with the default input stream handler, which synchronizes
|
||||
// all data input streams with the ALLOW/DISALLOW control input stream.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "GateCalculator"
|
||||
// input_stream: "input_stream0"
|
||||
// input_stream: "input_stream1"
|
||||
// input_stream: "input_streamN"
|
||||
// input_stream: "ALLOW:allow" or "DISALLOW:disallow"
|
||||
// output_stream: "STATE_CHANGE:state_change"
|
||||
// output_stream: "output_stream0"
|
||||
// output_stream: "output_stream1"
|
||||
// output_stream: "output_streamN"
|
||||
// }
|
||||
class GateCalculator : public CalculatorBase {
|
||||
public:
|
||||
GateCalculator() {}
|
||||
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
// Assume that input streams do not have a tag and that gating signal is
|
||||
// tagged either ALLOW or DISALLOW.
|
||||
RET_CHECK(cc->Inputs().HasTag("ALLOW") ^ cc->Inputs().HasTag("DISALLOW"));
|
||||
const int num_data_streams = cc->Inputs().NumEntries("");
|
||||
RET_CHECK_GE(num_data_streams, 1);
|
||||
RET_CHECK_EQ(cc->Outputs().NumEntries(""), num_data_streams)
|
||||
<< "Number of data output streams must match with data input streams.";
|
||||
|
||||
for (int i = 0; i < num_data_streams; ++i) {
|
||||
cc->Inputs().Get("", i).SetAny();
|
||||
cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i));
|
||||
}
|
||||
if (cc->Inputs().HasTag("ALLOW")) {
|
||||
cc->Inputs().Tag("ALLOW").Set<bool>();
|
||||
} else {
|
||||
cc->Inputs().Tag("DISALLOW").Set<bool>();
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("STATE_CHANGE")) {
|
||||
cc->Outputs().Tag("STATE_CHANGE").Set<bool>();
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Open(CalculatorContext* cc) final {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
num_data_streams_ = cc->Inputs().NumEntries("");
|
||||
last_gate_state_ = GATE_UNINITIALIZED;
|
||||
RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &cc->Outputs()));
|
||||
|
||||
const auto& options = cc->Options<::mediapipe::GateCalculatorOptions>();
|
||||
empty_packets_as_allow_ = options.empty_packets_as_allow();
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Process(CalculatorContext* cc) final {
|
||||
bool allow = empty_packets_as_allow_;
|
||||
if (cc->Inputs().HasTag("ALLOW") && !cc->Inputs().Tag("ALLOW").IsEmpty()) {
|
||||
allow = cc->Inputs().Tag("ALLOW").Get<bool>();
|
||||
}
|
||||
if (cc->Inputs().HasTag("DISALLOW") &&
|
||||
!cc->Inputs().Tag("DISALLOW").IsEmpty()) {
|
||||
allow = !cc->Inputs().Tag("DISALLOW").Get<bool>();
|
||||
}
|
||||
|
||||
const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW;
|
||||
|
||||
if (cc->Outputs().HasTag("STATE_CHANGE")) {
|
||||
if (last_gate_state_ != GATE_UNINITIALIZED &&
|
||||
last_gate_state_ != new_gate_state) {
|
||||
VLOG(2) << "State transition in " << cc->NodeName() << " @ "
|
||||
<< cc->InputTimestamp().Value() << " from "
|
||||
<< ToString(last_gate_state_) << " to "
|
||||
<< ToString(new_gate_state);
|
||||
cc->Outputs()
|
||||
.Tag("STATE_CHANGE")
|
||||
.AddPacket(MakePacket<bool>(allow).At(cc->InputTimestamp()));
|
||||
}
|
||||
}
|
||||
last_gate_state_ = new_gate_state;
|
||||
|
||||
if (!allow) {
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
// Process data streams.
|
||||
for (int i = 0; i < num_data_streams_; ++i) {
|
||||
if (!cc->Inputs().Get("", i).IsEmpty()) {
|
||||
cc->Outputs().Get("", i).AddPacket(cc->Inputs().Get("", i).Value());
|
||||
}
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
GateState last_gate_state_ = GATE_UNINITIALIZED;
|
||||
int num_data_streams_;
|
||||
bool empty_packets_as_allow_;
|
||||
};
|
||||
REGISTER_CALCULATOR(GateCalculator);
|
||||
|
||||
} // namespace mediapipe
|
30
mediapipe/calculators/core/gate_calculator.proto
Normal file
30
mediapipe/calculators/core/gate_calculator.proto
Normal file
|
@ -0,0 +1,30 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
message GateCalculatorOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional GateCalculatorOptions ext = 261754847;
|
||||
}
|
||||
|
||||
// By default an empty packet in the ALLOW or DISALLOW input stream indicates
|
||||
// disallowing the corresponding packets in the data input streams. Setting
|
||||
// this option to true inverts that, allowing the data packets to go through.
|
||||
optional bool empty_packets_as_allow = 1;
|
||||
}
|
190
mediapipe/calculators/core/gate_calculator_test.cc
Normal file
190
mediapipe/calculators/core/gate_calculator_test.cc
Normal file
|
@ -0,0 +1,190 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
class GateCalculatorTest : public ::testing::Test {
|
||||
protected:
|
||||
void RunTimeStep(int64 timestamp, const std::string& control_tag,
|
||||
bool control) {
|
||||
runner_->MutableInputs()->Get("", 0).packets.push_back(
|
||||
MakePacket<bool>(true).At(Timestamp(timestamp)));
|
||||
runner_->MutableInputs()
|
||||
->Tag(control_tag)
|
||||
.packets.push_back(MakePacket<bool>(control).At(Timestamp(timestamp)));
|
||||
|
||||
MEDIAPIPE_ASSERT_OK(runner_->Run()) << "Calculator execution failed.";
|
||||
}
|
||||
|
||||
void SetRunner(const std::string& proto) {
|
||||
runner_ = absl::make_unique<CalculatorRunner>(
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(proto));
|
||||
}
|
||||
|
||||
CalculatorRunner* runner() { return runner_.get(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<CalculatorRunner> runner_;
|
||||
};
|
||||
|
||||
TEST_F(GateCalculatorTest, Allow) {
|
||||
SetRunner(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "test_input"
|
||||
input_stream: "ALLOW:gating_stream"
|
||||
output_stream: "test_output"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "ALLOW", true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, "ALLOW", false);
|
||||
constexpr int64 kTimestampValue2 = 44;
|
||||
RunTimeStep(kTimestampValue2, "ALLOW", true);
|
||||
constexpr int64 kTimestampValue3 = 45;
|
||||
RunTimeStep(kTimestampValue3, "ALLOW", false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
ASSERT_EQ(2, output.size());
|
||||
EXPECT_EQ(kTimestampValue0, output[0].Timestamp().Value());
|
||||
EXPECT_EQ(kTimestampValue2, output[1].Timestamp().Value());
|
||||
EXPECT_EQ(true, output[0].Get<bool>());
|
||||
EXPECT_EQ(true, output[1].Get<bool>());
|
||||
}
|
||||
|
||||
TEST_F(GateCalculatorTest, Disallow) {
|
||||
SetRunner(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "test_input"
|
||||
input_stream: "DISALLOW:gating_stream"
|
||||
output_stream: "test_output"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "DISALLOW", true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, "DISALLOW", false);
|
||||
constexpr int64 kTimestampValue2 = 44;
|
||||
RunTimeStep(kTimestampValue2, "DISALLOW", true);
|
||||
constexpr int64 kTimestampValue3 = 45;
|
||||
RunTimeStep(kTimestampValue3, "DISALLOW", false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
ASSERT_EQ(2, output.size());
|
||||
EXPECT_EQ(kTimestampValue1, output[0].Timestamp().Value());
|
||||
EXPECT_EQ(kTimestampValue3, output[1].Timestamp().Value());
|
||||
EXPECT_EQ(true, output[0].Get<bool>());
|
||||
EXPECT_EQ(true, output[1].Get<bool>());
|
||||
}
|
||||
|
||||
TEST_F(GateCalculatorTest, AllowWithStateChange) {
|
||||
SetRunner(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "test_input"
|
||||
input_stream: "ALLOW:gating_stream"
|
||||
output_stream: "test_output"
|
||||
output_stream: "STATE_CHANGE:state_changed"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "ALLOW", false);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, "ALLOW", true);
|
||||
constexpr int64 kTimestampValue2 = 44;
|
||||
RunTimeStep(kTimestampValue2, "ALLOW", true);
|
||||
constexpr int64 kTimestampValue3 = 45;
|
||||
RunTimeStep(kTimestampValue3, "ALLOW", false);
|
||||
|
||||
const std::vector<Packet>& output =
|
||||
runner()->Outputs().Get("STATE_CHANGE", 0).packets;
|
||||
ASSERT_EQ(2, output.size());
|
||||
EXPECT_EQ(kTimestampValue1, output[0].Timestamp().Value());
|
||||
EXPECT_EQ(kTimestampValue3, output[1].Timestamp().Value());
|
||||
EXPECT_EQ(true, output[0].Get<bool>()); // Allow.
|
||||
EXPECT_EQ(false, output[1].Get<bool>()); // Disallow.
|
||||
}
|
||||
|
||||
TEST_F(GateCalculatorTest, DisallowWithStateChange) {
|
||||
SetRunner(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "test_input"
|
||||
input_stream: "DISALLOW:gating_stream"
|
||||
output_stream: "test_output"
|
||||
output_stream: "STATE_CHANGE:state_changed"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "DISALLOW", true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, "DISALLOW", false);
|
||||
constexpr int64 kTimestampValue2 = 44;
|
||||
RunTimeStep(kTimestampValue2, "DISALLOW", false);
|
||||
constexpr int64 kTimestampValue3 = 45;
|
||||
RunTimeStep(kTimestampValue3, "DISALLOW", true);
|
||||
|
||||
const std::vector<Packet>& output =
|
||||
runner()->Outputs().Get("STATE_CHANGE", 0).packets;
|
||||
ASSERT_EQ(2, output.size());
|
||||
EXPECT_EQ(kTimestampValue1, output[0].Timestamp().Value());
|
||||
EXPECT_EQ(kTimestampValue3, output[1].Timestamp().Value());
|
||||
EXPECT_EQ(true, output[0].Get<bool>()); // Allow.
|
||||
EXPECT_EQ(false, output[1].Get<bool>()); // Disallow.
|
||||
}
|
||||
|
||||
// Must not detect disallow value for first timestamp as a state change.
|
||||
TEST_F(GateCalculatorTest, DisallowInitialNoStateTransition) {
|
||||
SetRunner(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "test_input"
|
||||
input_stream: "DISALLOW:gating_stream"
|
||||
output_stream: "test_output"
|
||||
output_stream: "STATE_CHANGE:state_changed"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "DISALLOW", false);
|
||||
|
||||
const std::vector<Packet>& output =
|
||||
runner()->Outputs().Get("STATE_CHANGE", 0).packets;
|
||||
ASSERT_EQ(0, output.size());
|
||||
}
|
||||
|
||||
// Must not detect allow value for first timestamp as a state change.
|
||||
TEST_F(GateCalculatorTest, AllowInitialNoStateTransition) {
|
||||
SetRunner(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "test_input"
|
||||
input_stream: "ALLOW:gating_stream"
|
||||
output_stream: "test_output"
|
||||
output_stream: "STATE_CHANGE:state_changed"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "ALLOW", true);
|
||||
|
||||
const std::vector<Packet>& output =
|
||||
runner()->Outputs().Get("STATE_CHANGE", 0).packets;
|
||||
ASSERT_EQ(0, output.size());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
|
@ -146,7 +146,7 @@ class ImmediateMuxCalculatorTest : public ::testing::Test {
|
|||
ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"(
|
||||
input_stream: "input_packets_0"
|
||||
node {
|
||||
calculator: 'RealTimeFlowLimiterCalculator'
|
||||
calculator: 'FlowLimiterCalculator'
|
||||
input_stream_handler {
|
||||
input_stream_handler: 'ImmediateInputStreamHandler'
|
||||
}
|
||||
|
|
66
mediapipe/calculators/core/matrix_multiply_calculator.cc
Normal file
66
mediapipe/calculators/core/matrix_multiply_calculator.cc
Normal file
|
@ -0,0 +1,66 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
// Perform a (left) matrix multiply. Meaning (output = A * input)
|
||||
// where A is the matrix which is provided as an input side packet.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "MatrixMultiplyCalculator"
|
||||
// input_stream: "samples"
|
||||
// output_stream: "multiplied_samples"
|
||||
// input_side_packet: "multiplication_matrix"
|
||||
// }
|
||||
class MatrixMultiplyCalculator : public CalculatorBase {
|
||||
public:
|
||||
MatrixMultiplyCalculator() {}
|
||||
~MatrixMultiplyCalculator() override {}
|
||||
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc);
|
||||
|
||||
::mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
::mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
};
|
||||
REGISTER_CALCULATOR(MatrixMultiplyCalculator);
|
||||
|
||||
// static
|
||||
::mediapipe::Status MatrixMultiplyCalculator::GetContract(
|
||||
CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).Set<Matrix>();
|
||||
cc->Outputs().Index(0).Set<Matrix>();
|
||||
cc->InputSidePackets().Index(0).Set<Matrix>();
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status MatrixMultiplyCalculator::Open(CalculatorContext* cc) {
|
||||
// The output is at the same timestamp as the input.
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status MatrixMultiplyCalculator::Process(CalculatorContext* cc) {
|
||||
Matrix* multiplied = new Matrix();
|
||||
*multiplied = cc->InputSidePackets().Index(0).Get<Matrix>() *
|
||||
cc->Inputs().Index(0).Get<Matrix>();
|
||||
cc->Outputs().Index(0).Add(multiplied, cc->InputTimestamp());
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
239
mediapipe/calculators/core/matrix_multiply_calculator_test.cc
Normal file
239
mediapipe/calculators/core/matrix_multiply_calculator_test.cc
Normal file
|
@ -0,0 +1,239 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/tool/validate_type.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
// A 3x4 Matrix of random integers in [0,1000).
// Text-proto form consumed by MatrixFromTextProto; packed_data is stored in
// column-major order.
const char kMatrixText[] =
    "rows: 3\n"
    "cols: 4\n"
    "packed_data: 387\n"
    "packed_data: 940\n"
    "packed_data: 815\n"
    "packed_data: 825\n"
    "packed_data: 997\n"
    "packed_data: 884\n"
    "packed_data: 419\n"
    "packed_data: 763\n"
    "packed_data: 123\n"
    "packed_data: 30\n"
    "packed_data: 825\n"
    "packed_data: 299\n";

// A 4x20 Matrix of random integers in [0,10).
// Each column of this matrix is a sample.
const char kSamplesText[] =
    "rows: 4\n"
    "cols: 20\n"
    "packed_data: 7\n"
    "packed_data: 9\n"
    "packed_data: 5\n"
    "packed_data: 9\n"
    "packed_data: 6\n"
    "packed_data: 3\n"
    "packed_data: 0\n"
    "packed_data: 7\n"
    "packed_data: 1\n"
    "packed_data: 3\n"
    "packed_data: 3\n"
    "packed_data: 2\n"
    "packed_data: 4\n"
    "packed_data: 5\n"
    "packed_data: 0\n"
    "packed_data: 4\n"
    "packed_data: 6\n"
    "packed_data: 0\n"
    "packed_data: 1\n"
    "packed_data: 2\n"
    "packed_data: 0\n"
    "packed_data: 2\n"
    "packed_data: 0\n"
    "packed_data: 3\n"
    "packed_data: 1\n"
    "packed_data: 7\n"
    "packed_data: 4\n"
    "packed_data: 9\n"
    "packed_data: 8\n"
    "packed_data: 8\n"
    "packed_data: 6\n"
    "packed_data: 4\n"
    "packed_data: 6\n"
    "packed_data: 8\n"
    "packed_data: 1\n"
    "packed_data: 9\n"
    "packed_data: 7\n"
    "packed_data: 5\n"
    "packed_data: 3\n"
    "packed_data: 5\n"
    "packed_data: 3\n"
    "packed_data: 5\n"
    "packed_data: 7\n"
    "packed_data: 7\n"
    "packed_data: 3\n"
    "packed_data: 3\n"
    "packed_data: 6\n"
    "packed_data: 4\n"
    "packed_data: 7\n"
    "packed_data: 7\n"
    "packed_data: 2\n"
    "packed_data: 5\n"
    "packed_data: 4\n"
    "packed_data: 8\n"
    "packed_data: 1\n"
    "packed_data: 0\n"
    "packed_data: 2\n"
    "packed_data: 0\n"
    "packed_data: 3\n"
    "packed_data: 4\n"
    "packed_data: 6\n"
    "packed_data: 6\n"
    "packed_data: 8\n"
    "packed_data: 5\n"
    "packed_data: 5\n"
    "packed_data: 8\n"
    "packed_data: 9\n"
    "packed_data: 7\n"
    "packed_data: 3\n"
    "packed_data: 7\n"
    "packed_data: 2\n"
    "packed_data: 7\n"
    "packed_data: 8\n"
    "packed_data: 2\n"
    "packed_data: 1\n"
    "packed_data: 1\n"
    "packed_data: 4\n"
    "packed_data: 1\n"
    "packed_data: 1\n"
    "packed_data: 7\n";

// A 3x20 Matrix of expected values for the result of the matrix multiply
// computed using R.
// Each column of this matrix is an expected output.
const char kExpectedText[] =
    "rows: 3\n"
    "cols: 20\n"
    "packed_data: 12499\n"
    "packed_data: 26793\n"
    "packed_data: 16967\n"
    "packed_data: 5007\n"
    "packed_data: 14406\n"
    "packed_data: 9635\n"
    "packed_data: 4179\n"
    "packed_data: 7870\n"
    "packed_data: 4434\n"
    "packed_data: 5793\n"
    "packed_data: 12045\n"
    "packed_data: 8876\n"
    "packed_data: 2801\n"
    "packed_data: 8053\n"
    "packed_data: 5611\n"
    "packed_data: 1740\n"
    "packed_data: 4469\n"
    "packed_data: 2665\n"
    "packed_data: 8108\n"
    "packed_data: 18396\n"
    "packed_data: 10186\n"
    "packed_data: 12330\n"
    "packed_data: 23374\n"
    "packed_data: 15526\n"
    "packed_data: 9611\n"
    "packed_data: 21804\n"
    "packed_data: 14776\n"
    "packed_data: 8241\n"
    "packed_data: 17979\n"
    "packed_data: 11989\n"
    "packed_data: 8429\n"
    "packed_data: 18921\n"
    "packed_data: 9819\n"
    "packed_data: 6270\n"
    "packed_data: 13689\n"
    "packed_data: 7031\n"
    "packed_data: 9472\n"
    "packed_data: 19210\n"
    "packed_data: 13634\n"
    "packed_data: 8567\n"
    "packed_data: 12499\n"
    "packed_data: 10455\n"
    "packed_data: 2151\n"
    "packed_data: 7469\n"
    "packed_data: 3195\n"
    "packed_data: 10774\n"
    "packed_data: 21851\n"
    "packed_data: 12673\n"
    "packed_data: 12516\n"
    "packed_data: 25318\n"
    "packed_data: 14347\n"
    "packed_data: 7984\n"
    "packed_data: 17100\n"
    "packed_data: 10972\n"
    "packed_data: 5195\n"
    "packed_data: 11102\n"
    "packed_data: 8710\n"
    "packed_data: 3002\n"
    "packed_data: 11295\n"
    "packed_data: 6360\n";
|
||||
|
||||
// Send a number of samples through the MatrixMultiplyCalculator.
|
||||
TEST(MatrixMultiplyCalculatorTest, Multiply) {
|
||||
CalculatorRunner runner("MatrixMultiplyCalculator", "", 1, 1, 1);
|
||||
Matrix* matrix = new Matrix();
|
||||
MatrixFromTextProto(kMatrixText, matrix);
|
||||
runner.MutableSidePackets()->Index(0) = Adopt(matrix);
|
||||
|
||||
Matrix samples;
|
||||
MatrixFromTextProto(kSamplesText, &samples);
|
||||
Matrix expected;
|
||||
MatrixFromTextProto(kExpectedText, &expected);
|
||||
CHECK_EQ(samples.cols(), expected.cols());
|
||||
|
||||
for (int i = 0; i < samples.cols(); ++i) {
|
||||
// Take a column from samples and produce a packet with just that
|
||||
// column in it as an input sample for the calculator.
|
||||
Eigen::MatrixXf* sample = new Eigen::MatrixXf(samples.block(0, i, 4, 1));
|
||||
runner.MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(sample).At(Timestamp(i)));
|
||||
}
|
||||
|
||||
MEDIAPIPE_ASSERT_OK(runner.Run());
|
||||
EXPECT_EQ(runner.MutableInputs()->Index(0).packets.size(),
|
||||
runner.Outputs().Index(0).packets.size());
|
||||
|
||||
int i = 0;
|
||||
for (const Packet& output : runner.Outputs().Index(0).packets) {
|
||||
EXPECT_EQ(Timestamp(i), output.Timestamp());
|
||||
const Eigen::MatrixXf& result = output.Get<Matrix>();
|
||||
ASSERT_EQ(3, result.rows());
|
||||
EXPECT_NEAR((expected.block(0, i, 3, 1) - result).cwiseAbs().sum(), 0.0,
|
||||
1e-5);
|
||||
++i;
|
||||
}
|
||||
EXPECT_EQ(samples.cols(), i);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
123
mediapipe/calculators/core/matrix_subtract_calculator.cc
Normal file
123
mediapipe/calculators/core/matrix_subtract_calculator.cc
Normal file
|
@ -0,0 +1,123 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Subtract input matrix from the side input matrix and vice versa. The matrices
// must have the same dimension.
// Based on the tag (MINUEND vs SUBTRAHEND), the matrices in the input stream
// and input side packet can be either subtrahend or minuend. The output matrix
// is generated by performing minuend matrix - subtrahend matrix.
//
// Example config:
// node {
//   calculator: "MatrixSubtractCalculator"
//   input_stream: "MINUEND:input_matrix"
//   input_side_packet: "SUBTRAHEND:side_matrix"
//   output_stream: "output_matrix"
// }
//
// or
//
// node {
//   calculator: "MatrixSubtractCalculator"
//   input_stream: "SUBTRAHEND:input_matrix"
//   input_side_packet: "MINUEND:side_matrix"
//   output_stream: "output_matrix"
// }
class MatrixSubtractCalculator : public CalculatorBase {
 public:
  MatrixSubtractCalculator() {}
  ~MatrixSubtractCalculator() override {}

  // Requires exactly one input stream and one input side packet, tagged as
  // one MINUEND / one SUBTRAHEND pair (in either assignment).
  static ::mediapipe::Status GetContract(CalculatorContract* cc);

  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;

 private:
  // True when the input stream carries the minuend (tag "MINUEND") and the
  // side packet carries the subtrahend; decided once in Open().
  bool subtract_from_input_ = false;
};
REGISTER_CALCULATOR(MatrixSubtractCalculator);
|
||||
|
||||
// static
::mediapipe::Status MatrixSubtractCalculator::GetContract(
    CalculatorContract* cc) {
  // Exactly one input stream and one input side packet are required.
  if (cc->Inputs().NumEntries() != 1 ||
      cc->InputSidePackets().NumEntries() != 1) {
    return ::mediapipe::InvalidArgumentError(
        "MatrixSubtractCalculator only accepts exactly one input stream and "
        "one "
        "input side packet");
  }
  // The MINUEND/SUBTRAHEND tags must be assigned one to the stream and one
  // to the side packet, in either direction.
  if (cc->Inputs().HasTag("MINUEND") &&
      cc->InputSidePackets().HasTag("SUBTRAHEND")) {
    cc->Inputs().Tag("MINUEND").Set<Matrix>();
    cc->InputSidePackets().Tag("SUBTRAHEND").Set<Matrix>();
  } else if (cc->Inputs().HasTag("SUBTRAHEND") &&
             cc->InputSidePackets().HasTag("MINUEND")) {
    cc->Inputs().Tag("SUBTRAHEND").Set<Matrix>();
    cc->InputSidePackets().Tag("MINUEND").Set<Matrix>();
  } else {
    return ::mediapipe::InvalidArgumentError(
        "Must specify exactly one minuend and one subtrahend.");
  }
  cc->Outputs().Index(0).Set<Matrix>();
  return ::mediapipe::OkStatus();
}
|
||||
|
||||
::mediapipe::Status MatrixSubtractCalculator::Open(CalculatorContext* cc) {
  // The output is at the same timestamp as the input.
  cc->SetOffset(TimestampDiff(0));
  // Record which operand arrives on the input stream; GetContract already
  // validated that the tags form a consistent MINUEND/SUBTRAHEND pair.
  if (cc->Inputs().HasTag("MINUEND")) {
    subtract_from_input_ = true;
  }
  return ::mediapipe::OkStatus();
}
|
||||
|
||||
// Emits (minuend - subtrahend) at the input timestamp. Which operand comes
// from the stream versus the side packet was fixed in Open().
::mediapipe::Status MatrixSubtractCalculator::Process(CalculatorContext* cc) {
  // Resolve both operands once instead of duplicating the whole
  // check-and-subtract sequence in each branch.
  const Matrix& input_matrix =
      subtract_from_input_ ? cc->Inputs().Tag("MINUEND").Get<Matrix>()
                           : cc->Inputs().Tag("SUBTRAHEND").Get<Matrix>();
  const Matrix& side_input_matrix =
      subtract_from_input_
          ? cc->InputSidePackets().Tag("SUBTRAHEND").Get<Matrix>()
          : cc->InputSidePackets().Tag("MINUEND").Get<Matrix>();
  // Both operands must have identical dimensions.
  if (input_matrix.rows() != side_input_matrix.rows() ||
      input_matrix.cols() != side_input_matrix.cols()) {
    return ::mediapipe::InvalidArgumentError(
        "Input matrix and the input side matrix must have the same "
        "dimension.");
  }
  // Ownership of the result passes to the output stream via Add().
  Matrix* subtracted = new Matrix();
  if (subtract_from_input_) {
    *subtracted = input_matrix - side_input_matrix;
  } else {
    *subtracted = side_input_matrix - input_matrix;
  }
  cc->Outputs().Index(0).Add(subtracted, cc->InputTimestamp());
  return ::mediapipe::OkStatus();
}
|
||||
|
||||
} // namespace mediapipe
|
157
mediapipe/calculators/core/matrix_subtract_calculator_test.cc
Normal file
157
mediapipe/calculators/core/matrix_subtract_calculator_test.cc
Normal file
|
@ -0,0 +1,157 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/tool/validate_type.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
// A 3x4 Matrix of random integers in [0,1000).
const char kMatrixText[] =
    "rows: 3\n"
    "cols: 4\n"
    "packed_data: 387\n"
    "packed_data: 940\n"
    "packed_data: 815\n"
    "packed_data: 825\n"
    "packed_data: 997\n"
    "packed_data: 884\n"
    "packed_data: 419\n"
    "packed_data: 763\n"
    "packed_data: 123\n"
    "packed_data: 30\n"
    "packed_data: 825\n"
    "packed_data: 299\n";

// kMatrixText with every entry incremented by 1, so the element-wise
// difference between the two matrices is a matrix of all ones.
const char kMatrixText2[] =
    "rows: 3\n"
    "cols: 4\n"
    "packed_data: 388\n"
    "packed_data: 941\n"
    "packed_data: 816\n"
    "packed_data: 826\n"
    "packed_data: 998\n"
    "packed_data: 885\n"
    "packed_data: 420\n"
    "packed_data: 764\n"
    "packed_data: 124\n"
    "packed_data: 31\n"
    "packed_data: 826\n"
    "packed_data: 300\n";
|
||||
|
||||
// A node with an untagged input stream plus two input side packets violates
// the one-stream/one-side-packet contract and must be rejected.
TEST(MatrixSubtractCalculatorTest, WrongConfig) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
    calculator: "MatrixSubtractCalculator"
    input_stream: "input_matrix"
    input_side_packet: "SUBTRAHEND:side_matrix"
    input_side_packet: "MINUEND:side_matrix2"
    output_stream: "output_matrix"
  )"));
  const auto run_status = runner.Run();
  EXPECT_THAT(
      run_status.message(),
      testing::HasSubstr(
          "only accepts exactly one input stream and one input side packet"));
}
|
||||
|
||||
// A node that tags both the stream and the side packet SUBTRAHEND leaves no
// minuend, so GetContract must reject it.
TEST(MatrixSubtractCalculatorTest, WrongConfig2) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
    calculator: "MatrixSubtractCalculator"
    input_side_packet: "SUBTRAHEND:side_matrix"
    input_stream: "SUBTRAHEND:side_matrix2"
    output_stream: "output_matrix"
  )"));
  const auto run_status = runner.Run();
  EXPECT_THAT(
      run_status.message(),
      testing::HasSubstr("specify exactly one minuend and one subtrahend."));
}
|
||||
|
||||
// Stream is the minuend, side packet is the subtrahend:
// output = input_matrix - side_matrix.
TEST(MatrixSubtractCalculatorTest, SubtractFromInput) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "MatrixSubtractCalculator"
        input_stream: "MINUEND:input_matrix"
        input_side_packet: "SUBTRAHEND:side_matrix"
        output_stream: "output_matrix"
      )");
  CalculatorRunner runner(node_config);
  // Adopt() transfers ownership of the heap matrices to the packets.
  Matrix* side_matrix = new Matrix();
  MatrixFromTextProto(kMatrixText, side_matrix);
  runner.MutableSidePackets()->Tag("SUBTRAHEND") = Adopt(side_matrix);

  Matrix* input_matrix = new Matrix();
  MatrixFromTextProto(kMatrixText2, input_matrix);
  runner.MutableInputs()->Tag("MINUEND").packets.push_back(
      Adopt(input_matrix).At(Timestamp(0)));

  MEDIAPIPE_ASSERT_OK(runner.Run());
  EXPECT_EQ(1, runner.Outputs().Index(0).packets.size());

  EXPECT_EQ(Timestamp(0), runner.Outputs().Index(0).packets[0].Timestamp());
  const Eigen::MatrixXf& result =
      runner.Outputs().Index(0).packets[0].Get<Matrix>();
  ASSERT_EQ(3, result.rows());
  ASSERT_EQ(4, result.cols());
  // kMatrixText2 is kMatrixText + 1 element-wise, so the difference is a
  // 3x4 matrix of ones; its sum is 12.
  EXPECT_NEAR(result.sum(), 12, 1e-5);
}
|
||||
|
||||
// Side packet is the minuend, stream is the subtrahend:
// output = side_matrix - input_matrix.
TEST(MatrixSubtractCalculatorTest, SubtractFromSideMatrix) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "MatrixSubtractCalculator"
        input_stream: "SUBTRAHEND:input_matrix"
        input_side_packet: "MINUEND:side_matrix"
        output_stream: "output_matrix"
      )");
  CalculatorRunner runner(node_config);
  // Adopt() transfers ownership of the heap matrices to the packets.
  Matrix* side_matrix = new Matrix();
  MatrixFromTextProto(kMatrixText, side_matrix);
  runner.MutableSidePackets()->Tag("MINUEND") = Adopt(side_matrix);

  Matrix* input_matrix = new Matrix();
  MatrixFromTextProto(kMatrixText2, input_matrix);
  runner.MutableInputs()
      ->Tag("SUBTRAHEND")
      .packets.push_back(Adopt(input_matrix).At(Timestamp(0)));

  MEDIAPIPE_ASSERT_OK(runner.Run());
  EXPECT_EQ(1, runner.Outputs().Index(0).packets.size());

  EXPECT_EQ(Timestamp(0), runner.Outputs().Index(0).packets[0].Timestamp());
  const Eigen::MatrixXf& result =
      runner.Outputs().Index(0).packets[0].Get<Matrix>();
  ASSERT_EQ(3, result.rows());
  ASSERT_EQ(4, result.cols());
  // The subtrahend exceeds the minuend by 1 element-wise, so the 3x4 result
  // is all -1 and sums to -12.
  EXPECT_NEAR(result.sum(), -12, 1e-5);
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
91
mediapipe/calculators/core/merge_calculator.cc
Normal file
91
mediapipe/calculators/core/merge_calculator.cc
Normal file
|
@ -0,0 +1,91 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// This calculator takes a set of input streams and combines them into a single
// output stream. The packets from different streams do not need to contain the
// same type. If there are packets arriving at the same time from two or more
// input streams, the packet corresponding to the input stream with the smallest
// index is passed to the output and the rest are ignored.
//
// Example use-case:
// Suppose we have two (or more) different algorithms for detecting shot
// boundaries and we need to merge their packets into a single stream. The
// algorithms may emit shot boundaries at the same time and their output types
// may not be compatible. Subsequent calculators that process the merged stream
// may be interested only in the timestamps of the shot boundary packets and so
// it may not even need to inspect the values stored inside the packets.
//
// Example config:
// node {
//   calculator: "MergeCalculator"
//   input_stream: "shot_info1"
//   input_stream: "shot_info2"
//   input_stream: "shot_info3"
//   output_stream: "merged_shot_infos"
// }
//
class MergeCalculator : public CalculatorBase {
 public:
  // Accepts one or more input streams of any type and exactly one output
  // stream of any type.
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    RET_CHECK_GT(cc->Inputs().NumEntries(), 0)
        << "Needs at least one input stream";
    RET_CHECK_EQ(cc->Outputs().NumEntries(), 1);
    // A single input makes the merge a no-op pass-through; warn so the graph
    // author can simplify.
    if (cc->Inputs().NumEntries() == 1) {
      LOG(WARNING)
          << "MergeCalculator expects multiple input streams to merge but is "
             "receiving only one. Make sure the calculator is configured "
             "correctly or consider removing this calculator to reduce "
             "unnecessary overhead.";
    }

    for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
      cc->Inputs().Index(i).SetAny();
    }
    cc->Outputs().Index(0).SetAny();

    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) final {
    // The output is at the same timestamp as the input.
    cc->SetOffset(TimestampDiff(0));

    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) final {
    // Output the packet from the first input stream with a packet ready at this
    // timestamp.
    for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
      if (!cc->Inputs().Index(i).IsEmpty()) {
        cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(i).Value());
        return ::mediapipe::OkStatus();
      }
    }

    // No input had a packet at this timestamp; nothing is emitted.
    LOG(WARNING) << "Empty input packets at timestamp "
                 << cc->InputTimestamp().Value();

    return ::mediapipe::OkStatus();
  }
};

REGISTER_CALCULATOR(MergeCalculator);
|
||||
|
||||
} // namespace mediapipe
|
139
mediapipe/calculators/core/merge_calculator_test.cc
Normal file
139
mediapipe/calculators/core/merge_calculator_test.cc
Normal file
|
@ -0,0 +1,139 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
// Checks that the calculator fails if no input streams are provided.
|
||||
TEST(InvariantMergeInputStreamsCalculator, NoInputStreamsMustFail) {
|
||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
|
||||
calculator: "MergeCalculator"
|
||||
output_stream: "merged_output"
|
||||
)"));
|
||||
// Expect calculator to fail.
|
||||
ASSERT_FALSE(runner.Run().ok());
|
||||
}
|
||||
|
||||
// Checks that the calculator fails with an incorrect number of output streams.
TEST(InvariantMergeInputStreamsCalculator, ExpectExactlyOneOutputStream) {
  // Zero output streams: rejected.
  CalculatorRunner runner1(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
    calculator: "MergeCalculator"
    input_stream: "input1"
    input_stream: "input2"
  )"));
  // Expect calculator to fail.
  EXPECT_FALSE(runner1.Run().ok());

  // Two output streams: also rejected.
  CalculatorRunner runner2(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
    calculator: "MergeCalculator"
    input_stream: "input1"
    input_stream: "input2"
    output_stream: "output1"
    output_stream: "output2"
  )"));
  // Expect calculator to fail.
  ASSERT_FALSE(runner2.Run().ok());
}
|
||||
|
||||
// Ensures two streams with differing types can be merged correctly.
|
||||
TEST(MediaPipeDetectionToSoapboxDetectionCalculatorTest,
|
||||
TestMergingTwoStreams) {
|
||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
|
||||
calculator: "MergeCalculator"
|
||||
input_stream: "input1"
|
||||
input_stream: "input2"
|
||||
output_stream: "combined_output"
|
||||
)"));
|
||||
|
||||
// input1: integers 10, 20, 30, occurring at times 10, 20, 30.
|
||||
runner.MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(new int(10)).At(Timestamp(10)));
|
||||
runner.MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(new int(20)).At(Timestamp(20)));
|
||||
runner.MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(new int(30)).At(Timestamp(30)));
|
||||
// input2: floats 5.5, 35.5 at times 5, 35.
|
||||
runner.MutableInputs()->Index(1).packets.push_back(
|
||||
Adopt(new float(5.5)).At(Timestamp(5)));
|
||||
runner.MutableInputs()->Index(1).packets.push_back(
|
||||
Adopt(new float(35.5)).At(Timestamp(35)));
|
||||
|
||||
MEDIAPIPE_ASSERT_OK(runner.Run());
|
||||
|
||||
// Expected combined_output: 5.5, 10, 20, 30, 35.5 at times 5, 10, 20, 30, 35.
|
||||
const std::vector<Packet>& actual_output = runner.Outputs().Index(0).packets;
|
||||
ASSERT_EQ(actual_output.size(), 5);
|
||||
EXPECT_EQ(actual_output[0].Timestamp(), Timestamp(5));
|
||||
EXPECT_EQ(actual_output[0].Get<float>(), 5.5);
|
||||
|
||||
EXPECT_EQ(actual_output[1].Timestamp(), Timestamp(10));
|
||||
EXPECT_EQ(actual_output[1].Get<int>(), 10);
|
||||
|
||||
EXPECT_EQ(actual_output[2].Timestamp(), Timestamp(20));
|
||||
EXPECT_EQ(actual_output[2].Get<int>(), 20);
|
||||
|
||||
EXPECT_EQ(actual_output[3].Timestamp(), Timestamp(30));
|
||||
EXPECT_EQ(actual_output[3].Get<int>(), 30);
|
||||
|
||||
EXPECT_EQ(actual_output[4].Timestamp(), Timestamp(35));
|
||||
EXPECT_EQ(actual_output[4].Get<float>(), 35.5);
|
||||
}
|
||||
|
||||
// Ensures three streams with differing types can be merged correctly.
|
||||
TEST(MediaPipeDetectionToSoapboxDetectionCalculatorTest,
|
||||
TestMergingThreeStreams) {
|
||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
|
||||
calculator: "MergeCalculator"
|
||||
input_stream: "input1"
|
||||
input_stream: "input2"
|
||||
input_stream: "input3"
|
||||
output_stream: "combined_output"
|
||||
)"));
|
||||
|
||||
// input1: integer 30 occurring at time 30.
|
||||
runner.MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(new int(30)).At(Timestamp(30)));
|
||||
// input2: float 20.5 occurring at time 20.
|
||||
runner.MutableInputs()->Index(1).packets.push_back(
|
||||
Adopt(new float(20.5)).At(Timestamp(20)));
|
||||
// input3: char 'c' occurring at time 10.
|
||||
runner.MutableInputs()->Index(2).packets.push_back(
|
||||
Adopt(new char('c')).At(Timestamp(10)));
|
||||
|
||||
MEDIAPIPE_ASSERT_OK(runner.Run());
|
||||
|
||||
// Expected combined_output: 'c', 20.5, 30 at times 10, 20, 30.
|
||||
const std::vector<Packet>& actual_output = runner.Outputs().Index(0).packets;
|
||||
ASSERT_EQ(actual_output.size(), 3);
|
||||
EXPECT_EQ(actual_output[0].Timestamp(), Timestamp(10));
|
||||
EXPECT_EQ(actual_output[0].Get<char>(), 'c');
|
||||
|
||||
EXPECT_EQ(actual_output[1].Timestamp(), Timestamp(20));
|
||||
EXPECT_EQ(actual_output[1].Get<float>(), 20.5);
|
||||
|
||||
EXPECT_EQ(actual_output[2].Timestamp(), Timestamp(30));
|
||||
EXPECT_EQ(actual_output[2].Get<int>(), 30);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
|
@ -51,7 +51,7 @@ class MuxCalculator : public CalculatorBase {
|
|||
data_input_base_ = cc->Inputs().GetId("INPUT", 0);
|
||||
num_data_inputs_ = cc->Inputs().NumEntries("INPUT");
|
||||
output_ = cc->Outputs().GetId("OUTPUT", 0);
|
||||
cc->SetOffset(mediapipe::TimestampDiff(0));
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/calculators/core/packet_cloner_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
@ -39,6 +40,7 @@ namespace mediapipe {
|
|||
// }
|
||||
//
|
||||
// Related:
|
||||
// packet_cloner_calculator.proto: Options for this calculator.
|
||||
// merge_input_streams_calculator.cc: One output stream.
|
||||
// packet_inner_join_calculator.cc: Don't output unless all inputs are new.
|
||||
class PacketClonerCalculator : public CalculatorBase {
|
||||
|
@ -54,6 +56,13 @@ class PacketClonerCalculator : public CalculatorBase {
|
|||
}
|
||||
|
||||
::mediapipe::Status Open(CalculatorContext* cc) final {
|
||||
// Load options.
|
||||
const auto calculator_options =
|
||||
cc->Options<mediapipe::PacketClonerCalculatorOptions>();
|
||||
output_only_when_all_inputs_received_ =
|
||||
calculator_options.output_only_when_all_inputs_received();
|
||||
|
||||
// Parse input streams.
|
||||
tick_signal_index_ = cc->Inputs().NumEntries() - 1;
|
||||
current_.resize(tick_signal_index_);
|
||||
// Pass along the header for each stream if present.
|
||||
|
@ -73,8 +82,17 @@ class PacketClonerCalculator : public CalculatorBase {
|
|||
}
|
||||
}
|
||||
|
||||
// Output if the tick signal is non-empty.
|
||||
// Output according to the TICK signal.
|
||||
if (!cc->Inputs().Index(tick_signal_index_).Value().IsEmpty()) {
|
||||
if (output_only_when_all_inputs_received_) {
|
||||
// Return if one of the input is null.
|
||||
for (int i = 0; i < tick_signal_index_; ++i) {
|
||||
if (current_[i].IsEmpty()) {
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
}
|
||||
}
|
||||
// Output each stream.
|
||||
for (int i = 0; i < tick_signal_index_; ++i) {
|
||||
if (!current_[i].IsEmpty()) {
|
||||
cc->Outputs().Index(i).AddPacket(
|
||||
|
@ -91,6 +109,7 @@ class PacketClonerCalculator : public CalculatorBase {
|
|||
private:
|
||||
std::vector<Packet> current_;
|
||||
int tick_signal_index_;
|
||||
bool output_only_when_all_inputs_received_;
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(PacketClonerCalculator);
|
||||
|
|
29
mediapipe/calculators/core/packet_cloner_calculator.proto
Normal file
29
mediapipe/calculators/core/packet_cloner_calculator.proto
Normal file
|
@ -0,0 +1,29 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
// Options for PacketClonerCalculator, which repeats the most recent packet
// seen on each data stream whenever a TICK packet arrives.
message PacketClonerCalculatorOptions {
  extend CalculatorOptions {
    optional PacketClonerCalculatorOptions ext = 258872085;
  }

  // When true, this calculator will drop received TICK packets if any input
  // stream hasn't received a packet yet.
  optional bool output_only_when_all_inputs_received = 1 [default = false];
}
|
78
mediapipe/calculators/core/packet_inner_join_calculator.cc
Normal file
78
mediapipe/calculators/core/packet_inner_join_calculator.cc
Normal file
|
@ -0,0 +1,78 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Calculator that acts like the SQL query:
|
||||
// SELECT *
|
||||
// FROM packets_on_stream1 AS packet1
|
||||
// INNER JOIN packets_on_stream2 AS packet2
|
||||
// ON packet1.timestamp = packet2.timestamp
|
||||
//
|
||||
// In other words, it only emits and forwards packets if all input streams are
|
||||
// not empty.
|
||||
//
|
||||
// Intended for use with FixedSizeInputStreamHandler.
|
||||
//
|
||||
// Related:
|
||||
// packet_cloner_calculator.cc: Repeats last-seen packets from empty inputs.
|
||||
class PacketInnerJoinCalculator : public CalculatorBase {
|
||||
public:
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc);
|
||||
|
||||
::mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
::mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
int num_streams_;
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(PacketInnerJoinCalculator);
|
||||
|
||||
::mediapipe::Status PacketInnerJoinCalculator::GetContract(
|
||||
CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().NumEntries() == cc->Outputs().NumEntries())
|
||||
<< "The number of input and output streams must match.";
|
||||
const int num_streams = cc->Inputs().NumEntries();
|
||||
for (int i = 0; i < num_streams; ++i) {
|
||||
cc->Inputs().Index(i).SetAny();
|
||||
cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(i));
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status PacketInnerJoinCalculator::Open(CalculatorContext* cc) {
|
||||
num_streams_ = cc->Inputs().NumEntries();
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status PacketInnerJoinCalculator::Process(CalculatorContext* cc) {
|
||||
for (int i = 0; i < num_streams_; ++i) {
|
||||
if (cc->Inputs().Index(i).Value().IsEmpty()) {
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < num_streams_; ++i) {
|
||||
cc->Outputs().Index(i).AddPacket(cc->Inputs().Index(i).Value());
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
101
mediapipe/calculators/core/packet_inner_join_calculator_test.cc
Normal file
101
mediapipe/calculators/core/packet_inner_join_calculator_test.cc
Normal file
|
@ -0,0 +1,101 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/tool/validate_type.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
// Builds an int packet whose payload equals its timestamp value.
inline Packet PacketFrom(int value) {
  Packet payload = Adopt(new int(value));
  return payload.At(Timestamp(value));
}
|
||||
|
||||
TEST(PacketInnerJoinCalculatorTest, AllMatching) {
  // Both streams carry packets at timestamps {0, 1, 2, 3}: all join.
  CalculatorRunner runner("PacketInnerJoinCalculator", "", 2, 2, 0);
  auto feed = [&runner](int stream, const std::vector<int>& values) {
    for (const int value : values) {
      runner.MutableInputs()->Index(stream).packets.push_back(
          PacketFrom(value));
    }
  };
  feed(0, {0, 1, 2, 3});
  feed(1, {0, 1, 2, 3});
  MEDIAPIPE_ASSERT_OK(runner.Run());
  // Every timestamp matches, so each packet is forwarded on both outputs
  // with payload and timestamp unchanged.
  const std::vector<int> expected = {0, 1, 2, 3};
  ASSERT_EQ(expected.size(), runner.Outputs().Index(0).packets.size());
  ASSERT_EQ(expected.size(), runner.Outputs().Index(1).packets.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    for (int stream = 0; stream < 2; ++stream) {
      const Packet& packet = runner.Outputs().Index(stream).packets[i];
      EXPECT_EQ(expected[i], packet.Get<int>());
      EXPECT_EQ(expected[i], packet.Timestamp().Value());
    }
  }
}
|
||||
|
||||
TEST(PacketInnerJoinCalculatorTest, NoneMatching) {
  // Timestamps on the two streams never coincide, so the join is empty.
  CalculatorRunner runner("PacketInnerJoinCalculator", "", 2, 2, 0);
  auto feed = [&runner](int stream, const std::vector<int>& values) {
    for (const int value : values) {
      runner.MutableInputs()->Index(stream).packets.push_back(
          PacketFrom(value));
    }
  };
  feed(0, {0, 2});
  feed(1, {1, 3});
  MEDIAPIPE_ASSERT_OK(runner.Run());
  // No matching timestamps -> nothing is emitted on either output.
  EXPECT_TRUE(runner.Outputs().Index(0).packets.empty());
  EXPECT_TRUE(runner.Outputs().Index(1).packets.empty());
}
|
||||
|
||||
TEST(PacketInnerJoinCalculatorTest, SomeMatching) {
  // Only timestamps {0, 2, 4, 6} appear on BOTH streams.
  CalculatorRunner runner("PacketInnerJoinCalculator", "", 2, 2, 0);
  auto feed = [&runner](int stream, const std::vector<int>& values) {
    for (const int value : values) {
      runner.MutableInputs()->Index(stream).packets.push_back(
          PacketFrom(value));
    }
  };
  feed(0, {0, 1, 2, 3, 4, 6});
  feed(1, {0, 2, 4, 5, 6});
  MEDIAPIPE_ASSERT_OK(runner.Run());
  // The intersection of the two timestamp sets survives the join.
  const std::vector<int> expected = {0, 2, 4, 6};
  ASSERT_EQ(expected.size(), runner.Outputs().Index(0).packets.size());
  ASSERT_EQ(expected.size(), runner.Outputs().Index(1).packets.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    for (int stream = 0; stream < 2; ++stream) {
      const Packet& packet = runner.Outputs().Index(stream).packets[i];
      EXPECT_EQ(expected[i], packet.Get<int>());
      EXPECT_EQ(expected[i], packet.Timestamp().Value());
    }
  }
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
102
mediapipe/calculators/core/quantize_float_vector_calculator.cc
Normal file
102
mediapipe/calculators/core/quantize_float_vector_calculator.cc
Normal file
|
@ -0,0 +1,102 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cfloat>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/calculators/core/quantize_float_vector_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_context.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
// Quantizes a vector of floats to a std::string so that each float becomes a
|
||||
// byte in the [0, 255] range. Any value above max_quantized_value or below
|
||||
// min_quantized_value will be saturated to '/xFF' or '/0'.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "QuantizeFloatVectorCalculator"
|
||||
// input_stream: "FLOAT_VECTOR:float_vector"
|
||||
// output_stream: "ENCODED:encoded"
|
||||
// options {
|
||||
// [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
|
||||
// max_quantized_value: 64
|
||||
// min_quantized_value: -64
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
namespace mediapipe {
|
||||
|
||||
class QuantizeFloatVectorCalculator : public CalculatorBase {
|
||||
public:
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Tag("FLOAT_VECTOR").Set<std::vector<float>>();
|
||||
cc->Outputs().Tag("ENCODED").Set<std::string>();
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Open(CalculatorContext* cc) final {
|
||||
const auto options =
|
||||
cc->Options<::mediapipe::QuantizeFloatVectorCalculatorOptions>();
|
||||
if (!options.has_max_quantized_value() ||
|
||||
!options.has_min_quantized_value()) {
|
||||
return ::mediapipe::InvalidArgumentError(
|
||||
"Both max_quantized_value and min_quantized_value must be provided "
|
||||
"in QuantizeFloatVectorCalculatorOptions.");
|
||||
}
|
||||
max_quantized_value_ = options.max_quantized_value();
|
||||
min_quantized_value_ = options.min_quantized_value();
|
||||
if (max_quantized_value_ < min_quantized_value_ + FLT_EPSILON) {
|
||||
return ::mediapipe::InvalidArgumentError(
|
||||
"max_quantized_value must be greater than min_quantized_value.");
|
||||
}
|
||||
range_ = max_quantized_value_ - min_quantized_value_;
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Process(CalculatorContext* cc) final {
|
||||
const std::vector<float>& float_vector =
|
||||
cc->Inputs().Tag("FLOAT_VECTOR").Value().Get<std::vector<float>>();
|
||||
int feature_size = float_vector.size();
|
||||
std::string encoded_features;
|
||||
encoded_features.reserve(feature_size);
|
||||
for (int i = 0; i < feature_size; i++) {
|
||||
float old_value = float_vector[i];
|
||||
if (old_value < min_quantized_value_) {
|
||||
old_value = min_quantized_value_;
|
||||
}
|
||||
if (old_value > max_quantized_value_) {
|
||||
old_value = max_quantized_value_;
|
||||
}
|
||||
unsigned char encoded = static_cast<unsigned char>(
|
||||
(old_value - min_quantized_value_) * (255.0 / range_));
|
||||
encoded_features += encoded;
|
||||
}
|
||||
cc->Outputs().Tag("ENCODED").AddPacket(
|
||||
MakePacket<std::string>(encoded_features).At(cc->InputTimestamp()));
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
float max_quantized_value_;
|
||||
float min_quantized_value_;
|
||||
float range_;
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(QuantizeFloatVectorCalculator);
|
||||
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,28 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
// Options for QuantizeFloatVectorCalculator. Both fields are required: the
// calculator fails in Open() if either is missing, or if max_quantized_value
// is not greater than min_quantized_value.
message QuantizeFloatVectorCalculatorOptions {
  extend CalculatorOptions {
    optional QuantizeFloatVectorCalculatorOptions ext = 259848061;
  }

  // Upper bound of the quantization range; larger inputs saturate to 0xFF.
  optional float max_quantized_value = 1;
  // Lower bound of the quantization range; smaller inputs saturate to 0x00.
  optional float min_quantized_value = 2;
}
|
|
@ -0,0 +1,204 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Omitting max_quantized_value must make the graph fail with a clear error.
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "QuantizeFloatVectorCalculator"
        input_stream: "FLOAT_VECTOR:float_vector"
        output_stream: "ENCODED:encoded"
        options {
          [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
            min_quantized_value: 1
          }
        }
      )");
  CalculatorRunner runner(node_config);
  std::vector<float> empty_vector;
  runner.MutableInputs()
      ->Tag("FLOAT_VECTOR")
      .packets.push_back(
          MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
  auto status = runner.Run();
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(
      status.message(),
      testing::HasSubstr(
          "Both max_quantized_value and min_quantized_value must be provided"));
}
||||
|
||||
// max < min must be rejected in Open().
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig2) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "QuantizeFloatVectorCalculator"
        input_stream: "FLOAT_VECTOR:float_vector"
        output_stream: "ENCODED:encoded"
        options {
          [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
            max_quantized_value: -1
            min_quantized_value: 1
          }
        }
      )");
  CalculatorRunner runner(node_config);
  std::vector<float> empty_vector;
  runner.MutableInputs()
      ->Tag("FLOAT_VECTOR")
      .packets.push_back(
          MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
  auto status = runner.Run();
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(
      status.message(),
      testing::HasSubstr(
          "max_quantized_value must be greater than min_quantized_value"));
}
|
||||
|
||||
// max == min (zero-width range) must also be rejected in Open().
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig3) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "QuantizeFloatVectorCalculator"
        input_stream: "FLOAT_VECTOR:float_vector"
        output_stream: "ENCODED:encoded"
        options {
          [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
            max_quantized_value: 1
            min_quantized_value: 1
          }
        }
      )");
  CalculatorRunner runner(node_config);
  std::vector<float> empty_vector;
  runner.MutableInputs()
      ->Tag("FLOAT_VECTOR")
      .packets.push_back(
          MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
  auto status = runner.Run();
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(
      status.message(),
      testing::HasSubstr(
          "max_quantized_value must be greater than min_quantized_value"));
}
|
||||
|
||||
// An empty input vector must produce exactly one packet with an empty string.
TEST(QuantizeFloatVectorCalculatorTest, TestEmptyVector) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "QuantizeFloatVectorCalculator"
        input_stream: "FLOAT_VECTOR:float_vector"
        output_stream: "ENCODED:encoded"
        options {
          [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
            max_quantized_value: 1
            min_quantized_value: -1
          }
        }
      )");
  CalculatorRunner runner(node_config);
  std::vector<float> empty_vector;
  runner.MutableInputs()
      ->Tag("FLOAT_VECTOR")
      .packets.push_back(
          MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
  MEDIAPIPE_ASSERT_OK(runner.Run());
  const std::vector<Packet>& outputs = runner.Outputs().Tag("ENCODED").packets;
  EXPECT_EQ(1, outputs.size());
  EXPECT_TRUE(outputs[0].Get<std::string>().empty());
  // The output timestamp must mirror the input timestamp.
  EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
}
|
||||
|
||||
// In-range values must map linearly onto [0, 255]: with range [-64, 64],
// -64 -> 0, 0 -> 127, 64 -> 255, -32 -> 63, 32 -> 191.
TEST(QuantizeFloatVectorCalculatorTest, TestNonEmptyVector) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "QuantizeFloatVectorCalculator"
        input_stream: "FLOAT_VECTOR:float_vector"
        output_stream: "ENCODED:encoded"
        options {
          [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
            max_quantized_value: 64
            min_quantized_value: -64
          }
        }
      )");
  CalculatorRunner runner(node_config);
  std::vector<float> vector = {0.0f, -64.0f, 64.0f, -32.0f, 32.0f};
  runner.MutableInputs()
      ->Tag("FLOAT_VECTOR")
      .packets.push_back(
          MakePacket<std::vector<float>>(vector).At(Timestamp(0)));
  MEDIAPIPE_ASSERT_OK(runner.Run());
  const std::vector<Packet>& outputs = runner.Outputs().Tag("ENCODED").packets;
  EXPECT_EQ(1, outputs.size());
  const std::string& result = outputs[0].Get<std::string>();
  ASSERT_FALSE(result.empty());
  EXPECT_EQ(5, result.size());
  // 127
  EXPECT_EQ('\x7F', result.c_str()[0]);
  // 0
  EXPECT_EQ('\0', result.c_str()[1]);
  // 255
  EXPECT_EQ('\xFF', result.c_str()[2]);
  // 63
  EXPECT_EQ('\x3F', result.c_str()[3]);
  // 191
  EXPECT_EQ('\xBF', result.c_str()[4]);
  EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
}
|
||||
|
||||
// Out-of-range values must saturate: below min -> 0x00, above max -> 0xFF.
TEST(QuantizeFloatVectorCalculatorTest, TestSaturation) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "QuantizeFloatVectorCalculator"
        input_stream: "FLOAT_VECTOR:float_vector"
        output_stream: "ENCODED:encoded"
        options {
          [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
            max_quantized_value: 64
            min_quantized_value: -64
          }
        }
      )");
  CalculatorRunner runner(node_config);
  std::vector<float> vector = {-65.0f, 65.0f};
  runner.MutableInputs()
      ->Tag("FLOAT_VECTOR")
      .packets.push_back(
          MakePacket<std::vector<float>>(vector).At(Timestamp(0)));
  MEDIAPIPE_ASSERT_OK(runner.Run());
  const std::vector<Packet>& outputs = runner.Outputs().Tag("ENCODED").packets;
  EXPECT_EQ(1, outputs.size());
  const std::string& result = outputs[0].Get<std::string>();
  ASSERT_FALSE(result.empty());
  EXPECT_EQ(2, result.size());
  // 0
  EXPECT_EQ('\0', result.c_str()[0]);
  // 255
  EXPECT_EQ('\xFF', result.c_str()[1]);
  EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
}
|
||||
|
||||
} // namespace mediapipe
|
114
mediapipe/calculators/core/sequence_shift_calculator.cc
Normal file
114
mediapipe/calculators/core/sequence_shift_calculator.cc
Normal file
|
@ -0,0 +1,114 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <deque>
|
||||
|
||||
#include "mediapipe/calculators/core/sequence_shift_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// A Calculator that shifts the timestamps of packets along a stream. Packets on
|
||||
// the input stream are output with a timestamp of the packet given by packet
|
||||
// offset. That is, packet[i] is output with the timestamp of
|
||||
// packet[i + packet_offset]. Packet offset can be either positive or negative.
|
||||
// If packet_offset is -n, the first n packets will be dropped. If packet offset
|
||||
// is n, the final n packets will be dropped. For example, with a packet_offset
|
||||
// of -1, the first packet on the stream will be dropped, the second will be
|
||||
// output with the timestamp of the first, the third with the timestamp of the
|
||||
// second, and so on.
|
||||
// Shifts packet timestamps by packet_offset positions along the stream; see
// the file-level comment above for full semantics (edge packets are dropped).
class SequenceShiftCalculator : public CalculatorBase {
 public:
  // Accepts a single input stream of any type; the output carries the same
  // packet type.
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).SetAny();
    cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
    return ::mediapipe::OkStatus();
  }

  // Reads from options to set cache_size_ and packet_offset_.
  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;

 private:
  // A positive offset means we want a packet to be output with the timestamp of
  // a later packet. Stores packets waiting for their output timestamps and
  // outputs a single packet when the cache fills.
  void ProcessPositiveOffset(CalculatorContext* cc);

  // A negative offset means we want a packet to be output with the timestamp of
  // an earlier packet. Stores timestamps waiting for the corresponding input
  // packet and outputs a single packet when the cache fills.
  void ProcessNegativeOffset(CalculatorContext* cc);

  // Storage for packets waiting to be output when packet_offset > 0. When cache
  // is full, oldest packet is output with current timestamp.
  std::deque<Packet> packet_cache_;

  // Storage for previous timestamps used when packet_offset < 0. When cache is
  // full, oldest timestamp is used for current packet.
  std::deque<Timestamp> timestamp_cache_;

  // Copied from corresponding field in options.
  int packet_offset_;
  // The number of packets or timestamps we need to store to output packet[i] at
  // the timestamp of packet[i + packet_offset]; equal to abs(packet_offset).
  int cache_size_;
};
REGISTER_CALCULATOR(SequenceShiftCalculator);

::mediapipe::Status SequenceShiftCalculator::Open(CalculatorContext* cc) {
  packet_offset_ =
      cc->Options<mediapipe::SequenceShiftCalculatorOptions>().packet_offset();
  cache_size_ = abs(packet_offset_);
  // An offset of zero is a no-op, but someone might still request it.
  if (packet_offset_ == 0) {
    cc->Outputs().Index(0).SetOffset(0);
  }
  return ::mediapipe::OkStatus();
}

// Dispatches to the positive/negative helper; a zero offset forwards the
// packet unchanged.
::mediapipe::Status SequenceShiftCalculator::Process(CalculatorContext* cc) {
  if (packet_offset_ > 0) {
    ProcessPositiveOffset(cc);
  } else if (packet_offset_ < 0) {
    ProcessNegativeOffset(cc);
  } else {
    cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
  }
  return ::mediapipe::OkStatus();
}

void SequenceShiftCalculator::ProcessPositiveOffset(CalculatorContext* cc) {
  // NOTE(review): cache_size_ = abs(packet_offset_) >= 1 here, so the
  // signed/unsigned comparison against size() is safe.
  if (packet_cache_.size() >= cache_size_) {
    // Ready to output oldest packet with current timestamp.
    cc->Outputs().Index(0).AddPacket(
        packet_cache_.front().At(cc->InputTimestamp()));
    packet_cache_.pop_front();
  }
  // Store current packet for later output.
  packet_cache_.push_back(cc->Inputs().Index(0).Value());
}

void SequenceShiftCalculator::ProcessNegativeOffset(CalculatorContext* cc) {
  if (timestamp_cache_.size() >= cache_size_) {
    // Ready to output current packet with oldest timestamp.
    cc->Outputs().Index(0).AddPacket(
        cc->Inputs().Index(0).Value().At(timestamp_cache_.front()));
    timestamp_cache_.pop_front();
  }
  // Store current timestamp for use by a future packet.
  timestamp_cache_.push_back(cc->InputTimestamp());
}
|
||||
|
||||
} // namespace mediapipe
|
26
mediapipe/calculators/core/sequence_shift_calculator.proto
Normal file
26
mediapipe/calculators/core/sequence_shift_calculator.proto
Normal file
|
@ -0,0 +1,26 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
// Options for SequenceShiftCalculator.
message SequenceShiftCalculatorOptions {
  extend CalculatorOptions {
    optional SequenceShiftCalculatorOptions ext = 107633927;
  }
  // Number of positions to shift: packet[i] is emitted with the timestamp of
  // packet[i + packet_offset]. Negative drops the first |offset| packets;
  // positive drops the last |offset| packets; zero is a pass-through.
  optional int32 packet_offset = 1 [default = -1];
}
|
104
mediapipe/calculators/core/sequence_shift_calculator_test.cc
Normal file
104
mediapipe/calculators/core/sequence_shift_calculator_test.cc
Normal file
|
@ -0,0 +1,104 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/tool/validate_type.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
// Adds packets containing integers equal to their original timestamp.
|
||||
// Feeds ten int packets whose payload equals their timestamp (0 through 9).
void AddPackets(CalculatorRunner* runner) {
  for (int ts = 0; ts < 10; ++ts) {
    Packet packet = Adopt(new int(ts)).At(Timestamp(ts));
    runner->MutableInputs()->Index(0).packets.push_back(packet);
  }
}
|
||||
|
||||
// Zero shift is a no-op (output input[i] at timestamp[i]). Input and output
|
||||
// streams should be identical.
|
||||
TEST(SequenceShiftCalculatorTest, ZeroShift) {
  CalculatorRunner runner(
      "SequenceShiftCalculator",
      "[mediapipe.SequenceShiftCalculatorOptions.ext]: { packet_offset: 0 }", 1,
      1, 0);
  AddPackets(&runner);
  MEDIAPIPE_ASSERT_OK(runner.Run());
  const std::vector<Packet>& input_packets =
      runner.MutableInputs()->Index(0).packets;
  const std::vector<Packet>& output_packets = runner.Outputs().Index(0).packets;
  ASSERT_EQ(10, input_packets.size());
  // With offset 0 nothing is dropped: output count equals input count.
  ASSERT_EQ(input_packets.size(), output_packets.size());
  for (int i = 0; i < output_packets.size(); ++i) {
    // Make sure the contents are as expected.
    EXPECT_EQ(input_packets[i].Get<int>(), output_packets[i].Get<int>());
    EXPECT_EQ(input_packets[i].Timestamp(), output_packets[i].Timestamp());
  }
}
|
||||
|
||||
// Tests shifting by three packets, i.e., output input[i] with the timestamp of
|
||||
// input[i + 3].
|
||||
TEST(SequenceShiftCalculatorTest, PositiveShift) {
  CalculatorRunner runner(
      "SequenceShiftCalculator",
      "[mediapipe.SequenceShiftCalculatorOptions.ext]: { packet_offset: 3 }", 1,
      1, 0);
  AddPackets(&runner);
  MEDIAPIPE_ASSERT_OK(runner.Run());
  const std::vector<Packet>& input_packets =
      runner.MutableInputs()->Index(0).packets;
  const std::vector<Packet>& output_packets = runner.Outputs().Index(0).packets;
  ASSERT_EQ(10, input_packets.size());
  // input_packet[i] should be output with the timestamp of input_packet[i + 3].
  // The last 3 packets are dropped.
  ASSERT_EQ(7, output_packets.size());
  for (int i = 0; i < output_packets.size(); ++i) {
    // Make sure the contents are as expected.
    EXPECT_EQ(input_packets[i].Get<int>(), output_packets[i].Get<int>());
    // Make sure the timestamps are shifted as expected.
    EXPECT_EQ(input_packets[i + 3].Timestamp(), output_packets[i].Timestamp());
  }
}
|
||||
|
||||
// Tests shifting by -2, i.e., output input[i] with timestamp[i - 2]. The first
|
||||
// two packets should be dropped.
|
||||
TEST(SequenceShiftCalculatorTest, NegativeShift) {
|
||||
CalculatorRunner runner(
|
||||
"SequenceShiftCalculator",
|
||||
"[mediapipe.SequenceShiftCalculatorOptions.ext]: { packet_offset: -2 }",
|
||||
1, 1, 0);
|
||||
AddPackets(&runner);
|
||||
MEDIAPIPE_ASSERT_OK(runner.Run());
|
||||
const std::vector<Packet>& input_packets =
|
||||
runner.MutableInputs()->Index(0).packets;
|
||||
const std::vector<Packet>& output_packets = runner.Outputs().Index(0).packets;
|
||||
ASSERT_EQ(10, input_packets.size());
|
||||
// Input packet[i] should be output with the timestamp of input packet[i - 2].
|
||||
// The first two packets are dropped. This means timestamps match between
|
||||
// input and output packets, but the data in the output packets come from
|
||||
// input_packets[i + 2].
|
||||
ASSERT_EQ(8, output_packets.size());
|
||||
for (int i = 0; i < output_packets.size(); ++i) {
|
||||
EXPECT_EQ(input_packets[i].Timestamp(), output_packets[i].Timestamp());
|
||||
EXPECT_EQ(input_packets[i + 2].Get<int>(), output_packets[i].Get<int>());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mediapipe
|
40
mediapipe/calculators/core/split_vector_calculator.cc
Normal file
40
mediapipe/calculators/core/split_vector_calculator.cc
Normal file
|
@ -0,0 +1,40 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/core/split_vector_calculator.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "SplitTfLiteTensorVectorCalculator"
|
||||
// input_stream: "tflitetensor_vector"
|
||||
// output_stream: "tflitetensor_vector_range_0"
|
||||
// output_stream: "tflitetensor_vector_range_1"
|
||||
// options {
|
||||
// [mediapipe.SplitVectorCalculatorOptions.ext] {
|
||||
// ranges: { begin: 0 end: 1 }
|
||||
// ranges: { begin: 1 end: 2 }
|
||||
// element_only: false
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// Instantiates the generic SplitVectorCalculator for TfLiteTensor elements and
// registers it under the name "SplitTfLiteTensorVectorCalculator".
typedef SplitVectorCalculator<TfLiteTensor> SplitTfLiteTensorVectorCalculator;
REGISTER_CALCULATOR(SplitTfLiteTensorVectorCalculator);
|
||||
|
||||
} // namespace mediapipe
|
125
mediapipe/calculators/core/split_vector_calculator.h
Normal file
125
mediapipe/calculators/core/split_vector_calculator.h
Normal file
|
@ -0,0 +1,125 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef MEDIAPIPE_CALCULATORS_CORE_SPLIT_VECTOR_CALCULATOR_H_
|
||||
#define MEDIAPIPE_CALCULATORS_CORE_SPLIT_VECTOR_CALCULATOR_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/util/resource_util.h"
|
||||
#include "tensorflow/lite/error_reporter.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Splits an input packet with std::vector<T> into multiple std::vector<T>
|
||||
// output packets using the [begin, end) ranges specified in
|
||||
// SplitVectorCalculatorOptions. If the option "element_only" is set to true,
|
||||
// all ranges should be of size 1 and all outputs will be elements of type T. If
|
||||
// "element_only" is false, ranges can be non-zero in size and all outputs will
|
||||
// be of type std::vector<T>.
|
||||
// To use this class for a particular type T, register a calculator using
|
||||
// SplitVectorCalculator<T>.
|
||||
template <typename T>
|
||||
class SplitVectorCalculator : public CalculatorBase {
|
||||
public:
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().NumEntries() == 1);
|
||||
RET_CHECK(cc->Outputs().NumEntries() != 0);
|
||||
|
||||
cc->Inputs().Index(0).Set<std::vector<T>>();
|
||||
|
||||
const auto& options =
|
||||
cc->Options<::mediapipe::SplitVectorCalculatorOptions>();
|
||||
|
||||
if (cc->Outputs().NumEntries() != options.ranges_size()) {
|
||||
return ::mediapipe::InvalidArgumentError(
|
||||
"The number of output streams should match the number of ranges "
|
||||
"specified in the CalculatorOptions.");
|
||||
}
|
||||
|
||||
// Set the output types for each output stream.
|
||||
for (int i = 0; i < cc->Outputs().NumEntries(); ++i) {
|
||||
if (options.ranges(i).begin() < 0 || options.ranges(i).end() < 0 ||
|
||||
options.ranges(i).begin() >= options.ranges(i).end()) {
|
||||
return ::mediapipe::InvalidArgumentError(
|
||||
"Indices should be non-negative and begin index should be less "
|
||||
"than the end index.");
|
||||
}
|
||||
if (options.element_only()) {
|
||||
if (options.ranges(i).end() - options.ranges(i).begin() != 1) {
|
||||
return ::mediapipe::InvalidArgumentError(
|
||||
"Since element_only is true, all ranges should be of size 1.");
|
||||
}
|
||||
cc->Outputs().Index(i).Set<T>();
|
||||
} else {
|
||||
cc->Outputs().Index(i).Set<std::vector<T>>();
|
||||
}
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Open(CalculatorContext* cc) override {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
const auto& options =
|
||||
cc->Options<::mediapipe::SplitVectorCalculatorOptions>();
|
||||
|
||||
for (const auto& range : options.ranges()) {
|
||||
ranges_.push_back({range.begin(), range.end()});
|
||||
max_range_end_ = std::max(max_range_end_, range.end());
|
||||
}
|
||||
|
||||
element_only_ = options.element_only();
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Process(CalculatorContext* cc) override {
|
||||
const auto& input = cc->Inputs().Index(0).Get<std::vector<T>>();
|
||||
RET_CHECK_GE(input.size(), max_range_end_);
|
||||
|
||||
if (element_only_) {
|
||||
for (int i = 0; i < ranges_.size(); ++i) {
|
||||
cc->Outputs().Index(i).AddPacket(
|
||||
MakePacket<T>(input[ranges_[i].first]).At(cc->InputTimestamp()));
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < ranges_.size(); ++i) {
|
||||
auto output = absl::make_unique<std::vector<T>>(
|
||||
input.begin() + ranges_[i].first,
|
||||
input.begin() + ranges_[i].second);
|
||||
cc->Outputs().Index(i).Add(output.release(), cc->InputTimestamp());
|
||||
}
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::pair<int32, int32>> ranges_;
|
||||
int32 max_range_end_ = -1;
|
||||
bool element_only_ = false;
|
||||
};
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_CALCULATORS_CORE_SPLIT_VECTOR_CALCULATOR_H_
|
40
mediapipe/calculators/core/split_vector_calculator.proto
Normal file
40
mediapipe/calculators/core/split_vector_calculator.proto
Normal file
|
@ -0,0 +1,40 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
// A Range {begin, end} specifies beginning and ending indices to splice a
|
||||
// vector. A vector v is spliced to have elements v[begin:(end-1)], i.e., with
|
||||
// begin index inclusive and end index exclusive.
|
||||
message Range {
  // First element index of the slice (inclusive). Must be >= 0 and < end.
  optional int32 begin = 1;
  // One past the last element index of the slice (exclusive).
  optional int32 end = 2;
}
|
||||
|
||||
// Options for SplitVectorCalculator: which [begin, end) slices of the input
// vector to emit, and whether single-element slices are emitted as bare
// elements instead of vectors.
message SplitVectorCalculatorOptions {
  extend mediapipe.CalculatorOptions {
    optional SplitVectorCalculatorOptions ext = 259438222;
  }

  // Slices to extract; the i-th output stream receives elements
  // [ranges(i).begin, ranges(i).end) of the input vector. The number of
  // output streams must equal the number of ranges.
  repeated Range ranges = 1;

  // Specify if single element ranges should be outputted as std::vector<T> or
  // just element of type T. By default, if a range specifies only one element,
  // it is outputted as an std::vector<T>. When set to true, every range must
  // have size exactly 1.
  optional bool element_only = 2 [default = false];
}
|
321
mediapipe/calculators/core/split_vector_calculator_test.cc
Normal file
321
mediapipe/calculators/core/split_vector_calculator_test.cc
Normal file
|
@ -0,0 +1,321 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/core/split_vector_calculator.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
|
||||
#include "mediapipe/framework/tool/validate_type.h"
|
||||
#include "tensorflow/lite/error_reporter.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
using ::tflite::Interpreter;
|
||||
|
||||
const int width = 1;
|
||||
const int height = 1;
|
||||
const int channels = 1;
|
||||
|
||||
// Test fixture that builds a vector of TfLiteTensor inputs backed by a tflite
// Interpreter, and provides helpers to validate vector and element outputs of
// the SplitTfLiteTensorVectorCalculator.
class SplitTfLiteTensorVectorCalculatorTest : public ::testing::Test {
 protected:
  void TearDown() {
    // Note: Since the pointers contained in this vector will be cleaned up by
    // the interpreter, only ensure that the vector is cleaned up for the next
    // test.
    input_buffers_.clear();
  }

  // Builds |vector_size| float32 input tensors of shape
  // {width, height, channels}; tensor i is filled with the value i. Populates
  // input_vec_ (the packet payload) and input_buffers_ (raw buffer pointers
  // used later to verify the outputs alias the same memory).
  void PrepareTfLiteTensorVector(int vector_size) {
    ASSERT_NE(interpreter_, nullptr);

    // Prepare input tensors.
    std::vector<int> indices(vector_size);
    for (int i = 0; i < vector_size; ++i) {
      indices[i] = i;
    }
    interpreter_->AddTensors(vector_size);
    interpreter_->SetInputs(indices);

    input_vec_ = absl::make_unique<std::vector<TfLiteTensor>>();
    for (int i = 0; i < vector_size; ++i) {
      interpreter_->SetTensorParametersReadWrite(i, kTfLiteFloat32, "", {3},
                                                 TfLiteQuantization());
      const int tensor_index = interpreter_->inputs()[i];
      interpreter_->ResizeInputTensor(tensor_index, {width, height, channels});
    }

    // Allocation must happen after all tensors are configured.
    interpreter_->AllocateTensors();

    // Save the tensor buffer pointers for comparison after the graph runs.
    input_buffers_ = std::vector<float*>(vector_size);
    for (int i = 0; i < vector_size; ++i) {
      const int tensor_index = interpreter_->inputs()[i];
      TfLiteTensor* tensor = interpreter_->tensor(tensor_index);
      float* tensor_buffer = tensor->data.f;
      ASSERT_NE(tensor_buffer, nullptr);
      // Fill tensor i with the constant value i so outputs are identifiable.
      for (int j = 0; j < width * height * channels; ++j) {
        tensor_buffer[j] = i;
      }
      input_vec_->push_back(*tensor);
      input_buffers_[i] = tensor_buffer;
    }
  }

  // Asserts that |output_packets| holds exactly one std::vector<TfLiteTensor>
  // packet of |expected_elements| tensors whose buffers alias (and whose
  // values match) the input tensors starting at |input_begin_index|.
  void ValidateVectorOutput(std::vector<Packet>& output_packets,
                            int expected_elements, int input_begin_index) {
    ASSERT_EQ(1, output_packets.size());
    const std::vector<TfLiteTensor>& output_vec =
        output_packets[0].Get<std::vector<TfLiteTensor>>();
    ASSERT_EQ(expected_elements, output_vec.size());

    for (int i = 0; i < expected_elements; ++i) {
      const int expected_value = input_begin_index + i;
      const TfLiteTensor* result = &output_vec[i];
      float* result_buffer = result->data.f;
      ASSERT_NE(result_buffer, nullptr);
      // The split must not copy tensor data: same buffer as the input.
      ASSERT_EQ(result_buffer, input_buffers_[input_begin_index + i]);
      for (int j = 0; j < width * height * channels; ++j) {
        ASSERT_EQ(expected_value, result_buffer[j]);
      }
    }
  }

  // Asserts that |output_packets| holds exactly one bare TfLiteTensor packet
  // (element_only mode) aliasing the input tensor at |input_begin_index|.
  void ValidateElementOutput(std::vector<Packet>& output_packets,
                             int input_begin_index) {
    ASSERT_EQ(1, output_packets.size());

    const TfLiteTensor& result = output_packets[0].Get<TfLiteTensor>();
    float* result_buffer = result.data.f;
    ASSERT_NE(result_buffer, nullptr);
    ASSERT_EQ(result_buffer, input_buffers_[input_begin_index]);

    const int expected_value = input_begin_index;
    for (int j = 0; j < width * height * channels; ++j) {
      ASSERT_EQ(expected_value, result_buffer[j]);
    }
  }

  // Owns the tensor storage for the lifetime of each test.
  std::unique_ptr<Interpreter> interpreter_ = absl::make_unique<Interpreter>();
  // Input payload handed to the graph; released into a packet by tests.
  std::unique_ptr<std::vector<TfLiteTensor>> input_vec_ = nullptr;
  // Raw buffer pointers of the prepared input tensors (non-owning).
  std::vector<float*> input_buffers_;
  std::unique_ptr<CalculatorRunner> runner_ = nullptr;
};
|
||||
|
||||
// Splits a 5-tensor vector into three vector outputs covering the ranges
// [0,1), [1,4) and [4,5), and verifies each output slice.
TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTest) {
  ASSERT_NE(interpreter_, nullptr);

  PrepareTfLiteTensorVector(/*vector_size=*/5);
  ASSERT_NE(input_vec_, nullptr);

  // Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
  CalculatorGraphConfig graph_config =
      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"(
            input_stream: "tensor_in"
            node {
              calculator: "SplitTfLiteTensorVectorCalculator"
              input_stream: "tensor_in"
              output_stream: "range_0"
              output_stream: "range_1"
              output_stream: "range_2"
              options {
                [mediapipe.SplitVectorCalculatorOptions.ext] {
                  ranges: { begin: 0 end: 1 }
                  ranges: { begin: 1 end: 4 }
                  ranges: { begin: 4 end: 5 }
                }
              }
            }
          )");
  // Attach a sink to each output stream to collect its packets.
  std::vector<Packet> range_0_packets;
  tool::AddVectorSink("range_0", &graph_config, &range_0_packets);
  std::vector<Packet> range_1_packets;
  tool::AddVectorSink("range_1", &graph_config, &range_1_packets);
  std::vector<Packet> range_2_packets;
  tool::AddVectorSink("range_2", &graph_config, &range_2_packets);

  // Run the graph.
  CalculatorGraph graph;
  MEDIAPIPE_ASSERT_OK(graph.Initialize(graph_config));
  MEDIAPIPE_ASSERT_OK(graph.StartRun({}));
  MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream(
      "tensor_in", Adopt(input_vec_.release()).At(Timestamp(0))));
  // Wait until the calculator finishes processing.
  MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle());

  ValidateVectorOutput(range_0_packets, /*expected_elements=*/1,
                       /*input_begin_index=*/0);
  ValidateVectorOutput(range_1_packets, /*expected_elements=*/3,
                       /*input_begin_index=*/1);
  ValidateVectorOutput(range_2_packets, /*expected_elements=*/1,
                       /*input_begin_index=*/4);

  // Fully close the graph at the end.
  MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("tensor_in"));
  MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone());
}
|
||||
|
||||
// An empty range (begin == end) must be rejected at graph-initialization time.
TEST_F(SplitTfLiteTensorVectorCalculatorTest, InvalidRangeTest) {
  ASSERT_NE(interpreter_, nullptr);

  // Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
  CalculatorGraphConfig graph_config =
      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"(
            input_stream: "tensor_in"
            node {
              calculator: "SplitTfLiteTensorVectorCalculator"
              input_stream: "tensor_in"
              output_stream: "range_0"
              options {
                [mediapipe.SplitVectorCalculatorOptions.ext] {
                  ranges: { begin: 0 end: 0 }
                }
              }
            }
          )");

  // Run the graph.
  CalculatorGraph graph;
  // The graph should fail running because of an invalid range (begin == end).
  ASSERT_FALSE(graph.Initialize(graph_config).ok());
}
|
||||
|
||||
// A mismatch between the number of output streams (2) and the number of
// configured ranges (1) must be rejected at graph-initialization time.
TEST_F(SplitTfLiteTensorVectorCalculatorTest, InvalidOutputStreamCountTest) {
  ASSERT_NE(interpreter_, nullptr);

  // Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
  CalculatorGraphConfig graph_config =
      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"(
            input_stream: "tensor_in"
            node {
              calculator: "SplitTfLiteTensorVectorCalculator"
              input_stream: "tensor_in"
              output_stream: "range_0"
              output_stream: "range_1"
              options {
                [mediapipe.SplitVectorCalculatorOptions.ext] {
                  ranges: { begin: 0 end: 1 }
                }
              }
            }
          )");

  // Run the graph.
  CalculatorGraph graph;
  // The graph should fail running because the number of output streams does not
  // match the number of range elements in the options.
  ASSERT_FALSE(graph.Initialize(graph_config).ok());
}
|
||||
|
||||
// With element_only: true and three size-1 ranges, each output stream carries
// a bare TfLiteTensor (not a vector) aliasing the selected input tensor.
TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTestElementOnly) {
  ASSERT_NE(interpreter_, nullptr);

  PrepareTfLiteTensorVector(/*vector_size=*/5);
  ASSERT_NE(input_vec_, nullptr);

  // Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
  CalculatorGraphConfig graph_config =
      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"(
            input_stream: "tensor_in"
            node {
              calculator: "SplitTfLiteTensorVectorCalculator"
              input_stream: "tensor_in"
              output_stream: "range_0"
              output_stream: "range_1"
              output_stream: "range_2"
              options {
                [mediapipe.SplitVectorCalculatorOptions.ext] {
                  ranges: { begin: 0 end: 1 }
                  ranges: { begin: 2 end: 3 }
                  ranges: { begin: 4 end: 5 }
                  element_only: true
                }
              }
            }
          )");
  // Attach a sink to each output stream to collect its packets.
  std::vector<Packet> range_0_packets;
  tool::AddVectorSink("range_0", &graph_config, &range_0_packets);
  std::vector<Packet> range_1_packets;
  tool::AddVectorSink("range_1", &graph_config, &range_1_packets);
  std::vector<Packet> range_2_packets;
  tool::AddVectorSink("range_2", &graph_config, &range_2_packets);

  // Run the graph.
  CalculatorGraph graph;
  MEDIAPIPE_ASSERT_OK(graph.Initialize(graph_config));
  MEDIAPIPE_ASSERT_OK(graph.StartRun({}));
  MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream(
      "tensor_in", Adopt(input_vec_.release()).At(Timestamp(0))));
  // Wait until the calculator finishes processing.
  MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle());

  ValidateElementOutput(range_0_packets,
                        /*input_begin_index=*/0);
  ValidateElementOutput(range_1_packets,
                        /*input_begin_index=*/2);
  ValidateElementOutput(range_2_packets,
                        /*input_begin_index=*/4);

  // Fully close the graph at the end.
  MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("tensor_in"));
  MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone());
}
|
||||
|
||||
// element_only: true combined with a multi-element range ([1,4)) must be
// rejected at graph-initialization time.
TEST_F(SplitTfLiteTensorVectorCalculatorTest,
       ElementOnlyDisablesVectorOutputs) {
  // Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
  CalculatorGraphConfig graph_config =
      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"(
            input_stream: "tensor_in"
            node {
              calculator: "SplitTfLiteTensorVectorCalculator"
              input_stream: "tensor_in"
              output_stream: "range_0"
              output_stream: "range_1"
              output_stream: "range_2"
              options {
                [mediapipe.SplitVectorCalculatorOptions.ext] {
                  ranges: { begin: 0 end: 1 }
                  ranges: { begin: 1 end: 4 }
                  ranges: { begin: 4 end: 5 }
                  element_only: true
                }
              }
            }
          )");

  // Run the graph.
  CalculatorGraph graph;
  // Initialization fails: element_only requires every range to have size 1.
  ASSERT_FALSE(graph.Initialize(graph_config).ok());
}
|
||||
|
||||
} // namespace mediapipe
|
|
@ -19,6 +19,7 @@ package(default_visibility = ["//visibility:private"])
|
|||
exports_files(["LICENSE"])
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
|
||||
load("@bazel_skylib//lib:selects.bzl", "selects")
|
||||
|
||||
proto_library(
|
||||
name = "opencv_image_encoder_calculator_proto",
|
||||
|
@ -46,6 +47,26 @@ proto_library(
|
|||
],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "image_cropping_calculator_proto",
|
||||
srcs = ["image_cropping_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "bilateral_filter_calculator_proto",
|
||||
srcs = ["bilateral_filter_calculator.proto"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "recolor_calculator_proto",
|
||||
srcs = ["recolor_calculator.proto"],
|
||||
|
@ -93,6 +114,26 @@ mediapipe_cc_proto_library(
|
|||
deps = [":set_alpha_calculator_proto"],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "image_cropping_calculator_cc_proto",
|
||||
srcs = ["image_cropping_calculator.proto"],
|
||||
cc_deps = [
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
],
|
||||
visibility = ["//mediapipe:__subpackages__"],
|
||||
deps = [":image_cropping_calculator_proto"],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "bilateral_filter_calculator_cc_proto",
|
||||
srcs = ["bilateral_filter_calculator.proto"],
|
||||
cc_deps = [
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
],
|
||||
visibility = ["//mediapipe:__subpackages__"],
|
||||
deps = [":bilateral_filter_calculator_proto"],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "recolor_calculator_cc_proto",
|
||||
srcs = ["recolor_calculator.proto"],
|
||||
|
@ -185,6 +226,42 @@ cc_library(
|
|||
"//mediapipe/framework/port:opencv_core",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/port:vector",
|
||||
] + select({
|
||||
"//mediapipe:android": [
|
||||
"//mediapipe/gpu:gl_calculator_helper",
|
||||
"//mediapipe/gpu:gl_simple_shaders",
|
||||
"//mediapipe/gpu:gpu_buffer",
|
||||
"//mediapipe/gpu:shader_util",
|
||||
],
|
||||
"//mediapipe:ios": [
|
||||
"//mediapipe/gpu:gl_calculator_helper",
|
||||
"//mediapipe/gpu:gl_simple_shaders",
|
||||
"//mediapipe/gpu:gpu_buffer",
|
||||
"//mediapipe/gpu:shader_util",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "bilateral_filter_calculator",
|
||||
srcs = ["bilateral_filter_calculator.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
":bilateral_filter_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_options_cc_proto",
|
||||
"//mediapipe/framework/formats:image_format_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/formats:image_frame_opencv",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:opencv_core",
|
||||
"//mediapipe/framework/port:opencv_imgproc",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/port:vector",
|
||||
] + select({
|
||||
"//mediapipe:android": [
|
||||
"//mediapipe/gpu:gl_calculator_helper",
|
||||
|
@ -221,6 +298,19 @@ mediapipe_cc_proto_library(
|
|||
cc_library(
|
||||
name = "image_transformation_calculator",
|
||||
srcs = ["image_transformation_calculator.cc"],
|
||||
copts = select({
|
||||
"//mediapipe:apple": [
|
||||
"-x objective-c++",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
linkopts = select({
|
||||
"//mediapipe:apple": [
|
||||
"-framework CoreVideo",
|
||||
"-framework MetalKit",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":image_transformation_calculator_cc_proto",
|
||||
|
@ -232,8 +322,8 @@ cc_library(
|
|||
"//mediapipe/framework/port:opencv_imgproc",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
] + select({
|
||||
"//mediapipe:android": [
|
||||
] + selects.with_or({
|
||||
("//mediapipe:android", "//mediapipe:ios"): [
|
||||
"//mediapipe/gpu:gl_calculator_helper",
|
||||
"//mediapipe/gpu:gl_simple_shaders",
|
||||
"//mediapipe/gpu:gl_quad_renderer",
|
||||
|
@ -247,8 +337,24 @@ cc_library(
|
|||
cc_library(
|
||||
name = "image_cropping_calculator",
|
||||
srcs = ["image_cropping_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
copts = select({
|
||||
"//mediapipe:apple": [
|
||||
"-x objective-c++",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
linkopts = select({
|
||||
"//mediapipe:apple": [
|
||||
"-framework CoreVideo",
|
||||
"-framework MetalKit",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
":image_cropping_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/formats:image_frame_opencv",
|
||||
|
@ -257,7 +363,15 @@ cc_library(
|
|||
"//mediapipe/framework/port:opencv_imgproc",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
] + selects.with_or({
|
||||
("//mediapipe:android", "//mediapipe:ios"): [
|
||||
"//mediapipe/gpu:gl_calculator_helper",
|
||||
"//mediapipe/gpu:gl_simple_shaders",
|
||||
"//mediapipe/gpu:gl_quad_renderer",
|
||||
"//mediapipe/gpu:shader_util",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
|
@ -307,6 +421,12 @@ cc_library(
|
|||
"//mediapipe/gpu:gpu_buffer",
|
||||
"//mediapipe/gpu:shader_util",
|
||||
],
|
||||
"//mediapipe:ios": [
|
||||
"//mediapipe/gpu:gl_calculator_helper",
|
||||
"//mediapipe/gpu:gl_simple_shaders",
|
||||
"//mediapipe/gpu:gpu_buffer",
|
||||
"//mediapipe/gpu:shader_util",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
|
@ -357,6 +477,24 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "image_properties_calculator",
|
||||
srcs = ["image_properties_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
] + selects.with_or({
|
||||
("//mediapipe:android", "//mediapipe:ios"): [
|
||||
"//mediapipe/gpu:gpu_buffer",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "opencv_encoded_image_to_image_frame_calculator_test",
|
||||
srcs = ["opencv_encoded_image_to_image_frame_calculator_test.cc"],
|
||||
|
|
553
mediapipe/calculators/image/bilateral_filter_calculator.cc
Normal file
553
mediapipe/calculators/image/bilateral_filter_calculator.cc
Normal file
|
@ -0,0 +1,553 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "mediapipe/calculators/image/bilateral_filter_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_options.pb.h"
|
||||
#include "mediapipe/framework/formats/image_format.pb.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/opencv_core_inc.h"
|
||||
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/vector.h"
|
||||
|
||||
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
|
||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||
#include "mediapipe/gpu/gl_simple_shaders.h"
|
||||
#include "mediapipe/gpu/shader_util.h"
|
||||
#endif // __ANDROID__ || __EMSCRIPTEN__
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
// Stream tags for the CPU (ImageFrame) path.
constexpr char kInputFrameTag[] = "IMAGE";
constexpr char kInputGuideTag[] = "GUIDE";
constexpr char kOutputFrameTag[] = "IMAGE";

// Stream tags for the GPU (GpuBuffer) path.
constexpr char kInputFrameTagGpu[] = "IMAGE_GPU";
constexpr char kInputGuideTagGpu[] = "GUIDE_GPU";
constexpr char kOutputFrameTagGpu[] = "IMAGE_GPU";

// Vertex-attribute slot indices for the GPU shader program.
enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
}  // namespace
|
||||
|
||||
// A calculator for applying a bilateral filter to an image,
|
||||
// with an optional guide image (joint bilateral).
|
||||
//
|
||||
// Inputs:
|
||||
// One of the following two IMAGE tags:
|
||||
// IMAGE: ImageFrame containing input image - Grayscale or RGB only.
|
||||
// IMAGE_GPU: GpuBuffer containing input image - Grayscale, RGB or RGBA.
|
||||
//
|
||||
// GUIDE (optional): ImageFrame guide image used to filter IMAGE. (N/A).
|
||||
// GUIDE_GPU (optional): GpuBuffer guide image used to filter IMAGE_GPU.
|
||||
//
|
||||
// Output:
|
||||
// One of the following two tags:
|
||||
// IMAGE: A filtered ImageFrame - Same as input.
|
||||
// IMAGE_GPU: A filtered GpuBuffer - RGBA
|
||||
//
|
||||
// Options:
|
||||
// sigma_space: Pixel radius: use (sigma_space*2+1)x(sigma_space*2+1) window.
|
||||
// This should be set based on output image pixel space.
|
||||
// sigma_color: Color variance: normalized [0-1] color difference allowed.
|
||||
//
|
||||
// Notes:
|
||||
// * When GUIDE is present, the output image is same size as GUIDE image;
|
||||
// otherwise, the output image is same size as input image.
|
||||
// * On GPU the kernel window is subsampled by approximately sqrt(sigma_space)
|
||||
// i.e. the step size is ~sqrt(sigma_space),
|
||||
// prioritizing performance > quality.
|
||||
// * TODO: Add CPU path for joint filter.
|
||||
//
|
||||
// Applies a (optionally joint/guided) bilateral filter to an image on CPU
// (OpenCV) or GPU (GLSL). See file-level comment above for stream tags.
class BilateralFilterCalculator : public CalculatorBase {
 public:
  BilateralFilterCalculator() = default;
  ~BilateralFilterCalculator() override = default;

  static ::mediapipe::Status GetContract(CalculatorContract* cc);

  // From Calculator.
  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;
  ::mediapipe::Status Close(CalculatorContext* cc) override;

 private:
  // One render path is chosen in Open() and used for the calculator lifetime.
  ::mediapipe::Status RenderGpu(CalculatorContext* cc);
  ::mediapipe::Status RenderCpu(CalculatorContext* cc);

  // Compiles the two fragment shaders (plain and joint filter). GPU-only.
  ::mediapipe::Status GlSetup(CalculatorContext* cc);
  // Draws a full-screen quad through the currently bound program. GPU-only.
  void GlRender(CalculatorContext* cc);

  mediapipe::BilateralFilterCalculatorOptions options_;
  // Sigmas start at an invalid sentinel (-1) until Open() reads the options.
  // Note: on the CPU path sigma_color_ is rescaled to [0-255] in Open().
  float sigma_color_ = -1.f;
  float sigma_space_ = -1.f;

  bool use_gpu_ = false;
  bool gpu_initialized_ = false;
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
  mediapipe::GlCalculatorHelper gpu_helper_;
  GLuint program_ = 0;        // Plain bilateral filter shader program.
  GLuint program_joint_ = 0;  // Joint (guided) bilateral filter program.
#endif  // __ANDROID__ || __EMSCRIPTEN__
};
REGISTER_CALCULATOR(BilateralFilterCalculator);
|
||||
|
||||
// Declares the stream contract: exactly one of the CPU/GPU image inputs,
// an optional guide input, and a matching (CPU or GPU) image output.
::mediapipe::Status BilateralFilterCalculator::GetContract(
    CalculatorContract* cc) {
  CHECK_GE(cc->Inputs().NumEntries(), 1);

  // CPU and GPU input images are mutually exclusive.
  if (cc->Inputs().HasTag(kInputFrameTag) &&
      cc->Inputs().HasTag(kInputFrameTagGpu)) {
    return ::mediapipe::InternalError("Cannot have multiple input images.");
  }
  // The GPU-ness of input and output must agree (GPU in <-> GPU out).
  if (cc->Inputs().HasTag(kInputFrameTagGpu) !=
      cc->Outputs().HasTag(kOutputFrameTagGpu)) {
    return ::mediapipe::InternalError("GPU output must have GPU input.");
  }

  // Input image to filter.
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
  if (cc->Inputs().HasTag(kInputFrameTagGpu)) {
    cc->Inputs().Tag(kInputFrameTagGpu).Set<mediapipe::GpuBuffer>();
  }
#endif  // __ANDROID__ || __EMSCRIPTEN__
  if (cc->Inputs().HasTag(kInputFrameTag)) {
    cc->Inputs().Tag(kInputFrameTag).Set<ImageFrame>();
  }

  // Input guide image mask (optional).
  if (cc->Inputs().HasTag(kInputGuideTagGpu)) {
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
    cc->Inputs().Tag(kInputGuideTagGpu).Set<mediapipe::GpuBuffer>();
#endif  // __ANDROID__ || __EMSCRIPTEN__
  }
  if (cc->Inputs().HasTag(kInputGuideTag)) {
    cc->Inputs().Tag(kInputGuideTag).Set<ImageFrame>();
  }

  // Output image.
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
  if (cc->Outputs().HasTag(kOutputFrameTagGpu)) {
    cc->Outputs().Tag(kOutputFrameTagGpu).Set<mediapipe::GpuBuffer>();
  }
#endif  // __ANDROID__ || __EMSCRIPTEN__
  if (cc->Outputs().HasTag(kOutputFrameTag)) {
    cc->Outputs().Tag(kOutputFrameTag).Set<ImageFrame>();
  }

#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
  RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#endif  // __ANDROID__ || __EMSCRIPTEN__

  return ::mediapipe::OkStatus();
}
|
||||
|
||||
// Reads options, selects the CPU or GPU path, and validates/rescales sigmas.
::mediapipe::Status BilateralFilterCalculator::Open(CalculatorContext* cc) {
  // Output packets carry the same timestamps as their inputs.
  cc->SetOffset(TimestampDiff(0));

  options_ = cc->Options<mediapipe::BilateralFilterCalculatorOptions>();

  if (cc->Inputs().HasTag(kInputFrameTagGpu) &&
      cc->Outputs().HasTag(kOutputFrameTagGpu)) {
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
    use_gpu_ = true;
#else
    RET_CHECK_FAIL() << "GPU processing on non-Android not supported yet.";
#endif  // __ANDROID__ || __EMSCRIPTEN__
  }

  sigma_color_ = options_.sigma_color();
  sigma_space_ = options_.sigma_space();
  CHECK_GE(sigma_color_, 0.0);
  CHECK_GE(sigma_space_, 0.0);
  // Options express sigma_color in normalized [0-1] units; the OpenCV CPU
  // path operates on 8-bit [0-255] pixel values, so rescale accordingly.
  if (!use_gpu_) sigma_color_ *= 255.0;

  if (use_gpu_) {
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
    RETURN_IF_ERROR(gpu_helper_.Open(cc));
#endif
  }

  return ::mediapipe::OkStatus();
}
|
||||
|
||||
// Dispatches each packet to the GPU or CPU render path chosen in Open().
// GPU shader programs are compiled lazily on the first packet, inside the
// GL context.
::mediapipe::Status BilateralFilterCalculator::Process(CalculatorContext* cc) {
  if (use_gpu_) {
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
    RETURN_IF_ERROR(
        gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
          if (!gpu_initialized_) {
            // One-time shader compilation; must happen in the GL context.
            RETURN_IF_ERROR(GlSetup(cc));
            gpu_initialized_ = true;
          }
          RETURN_IF_ERROR(RenderGpu(cc));
          return ::mediapipe::OkStatus();
        }));
#endif  // __ANDROID__ || __EMSCRIPTEN__
  } else {
    RETURN_IF_ERROR(RenderCpu(cc));
  }

  return ::mediapipe::OkStatus();
}
|
||||
|
||||
// Releases the GL shader programs inside the GL context.
// NOTE(review): this runs unconditionally on GPU builds even when the CPU
// path was used; deleting the zero-initialized programs is then a no-op.
::mediapipe::Status BilateralFilterCalculator::Close(CalculatorContext* cc) {
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
  gpu_helper_.RunInGlContext([this] {
    if (program_) glDeleteProgram(program_);
    program_ = 0;
    if (program_joint_) glDeleteProgram(program_joint_);
    program_joint_ = 0;
  });
#endif  // __ANDROID__ || __EMSCRIPTEN__

  return ::mediapipe::OkStatus();
}
|
||||
|
||||
// CPU path: runs cv::bilateralFilter on the input ImageFrame and emits the
// result on the IMAGE output at the input timestamp.
// Joint (guided) filtering is not implemented on CPU.
::mediapipe::Status BilateralFilterCalculator::RenderCpu(
    CalculatorContext* cc) {
  if (cc->Inputs().Tag(kInputFrameTag).IsEmpty()) {
    return ::mediapipe::OkStatus();
  }

  const auto& input_frame = cc->Inputs().Tag(kInputFrameTag).Get<ImageFrame>();
  auto input_mat = mediapipe::formats::MatView(&input_frame);

  // Only 1 or 3 channel images supported by OpenCV.
  // BUG FIX: the check was inverted — it rejected exactly the supported
  // 1/3-channel formats and passed unsupported channel counts through to
  // cv::bilateralFilter (which requires 1- or 3-channel 8-bit/float input).
  if (!(input_mat.channels() == 1 || input_mat.channels() == 3)) {
    return ::mediapipe::InternalError(
        "CPU filtering supports only 1 or 3 channel input images.");
  }

  auto output_frame = absl::make_unique<ImageFrame>(
      input_frame.Format(), input_mat.cols, input_mat.rows);
  const bool has_guide_image = cc->Inputs().HasTag(kInputGuideTag) &&
                               !cc->Inputs().Tag(kInputGuideTag).IsEmpty();

  if (has_guide_image) {
    // cv::jointBilateralFilter() is in contrib module 'ximgproc'.
    return ::mediapipe::UnimplementedError(
        "CPU joint filtering support is not implemented yet.");
  } else {
    auto output_mat = mediapipe::formats::MatView(output_frame.get());
    // Prefer setting 'd = sigma_space * 2' to match GPU definition of radius.
    cv::bilateralFilter(input_mat, output_mat, /*d=*/sigma_space_ * 2.0,
                        sigma_color_, sigma_space_);
  }

  cc->Outputs()
      .Tag(kOutputFrameTag)
      .Add(output_frame.release(), cc->InputTimestamp());
  return ::mediapipe::OkStatus();
}
|
||||
|
||||
// GPU path: runs the bilateral (or joint bilateral, if a guide image is
// present) fragment shader over the input texture and emits the result as a
// GpuBuffer. Must be called inside the GL context (see Process()).
::mediapipe::Status BilateralFilterCalculator::RenderGpu(
    CalculatorContext* cc) {
  if (cc->Inputs().Tag(kInputFrameTagGpu).IsEmpty()) {
    return ::mediapipe::OkStatus();
  }
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
  const auto& input_frame =
      cc->Inputs().Tag(kInputFrameTagGpu).Get<mediapipe::GpuBuffer>();
  auto input_texture = gpu_helper_.CreateSourceTexture(input_frame);

  mediapipe::GlTexture output_texture;
  const bool has_guide_image = cc->Inputs().HasTag(kInputGuideTagGpu) &&
                               !cc->Inputs().Tag(kInputGuideTagGpu).IsEmpty();

  // Setup textures and Update image in GPU shader.
  // Texture units match the sampler uniforms set in GlSetup():
  // unit 1 = input_frame, unit 2 = guide_frame.
  if (has_guide_image) {
    // Joint bilateral filter: output is sized to the guide image.
    glUseProgram(program_joint_);
    const auto& guide_image =
        cc->Inputs().Tag(kInputGuideTagGpu).Get<mediapipe::GpuBuffer>();
    auto guide_texture = gpu_helper_.CreateSourceTexture(guide_image);
    glUniform2f(glGetUniformLocation(program_joint_, "texel_size_guide"),
                1.0 / guide_image.width(), 1.0 / guide_image.height());
    output_texture = gpu_helper_.CreateDestinationTexture(
        guide_image.width(), guide_image.height(),
        mediapipe::GpuBufferFormat::kBGRA32);
    gpu_helper_.BindFramebuffer(output_texture);
    glActiveTexture(GL_TEXTURE1);
    glBindTexture(GL_TEXTURE_2D, input_texture.name());
    glActiveTexture(GL_TEXTURE2);
    glBindTexture(GL_TEXTURE_2D, guide_texture.name());
    GlRender(cc);
    // Unbind both texture units before releasing the textures.
    glActiveTexture(GL_TEXTURE2);
    glBindTexture(GL_TEXTURE_2D, 0);
    glActiveTexture(GL_TEXTURE1);
    glBindTexture(GL_TEXTURE_2D, 0);
    guide_texture.Release();
  } else {
    // Regular bilateral filter: output is sized to the input image.
    glUseProgram(program_);
    glUniform2f(glGetUniformLocation(program_, "texel_size"),
                1.0 / input_frame.width(), 1.0 / input_frame.height());
    output_texture = gpu_helper_.CreateDestinationTexture(
        input_frame.width(), input_frame.height(),
        mediapipe::GpuBufferFormat::kBGRA32);
    gpu_helper_.BindFramebuffer(output_texture);
    glActiveTexture(GL_TEXTURE1);
    glBindTexture(GL_TEXTURE_2D, input_texture.name());
    GlRender(cc);
    glActiveTexture(GL_TEXTURE1);
    glBindTexture(GL_TEXTURE_2D, 0);
  }
  glFlush();

  // Send out image as GPU packet.
  auto output_frame = output_texture.GetFrame<mediapipe::GpuBuffer>();
  cc->Outputs()
      .Tag(kOutputFrameTagGpu)
      .Add(output_frame.release(), cc->InputTimestamp());

  // Cleanup
  input_texture.Release();
  output_texture.Release();
#endif  // __ANDROID__ || __EMSCRIPTEN__

  return ::mediapipe::OkStatus();
}
|
||||
|
||||
// Draws a full-screen quad (two triangles via GL_TRIANGLE_STRIP) through the
// currently bound shader program. VBOs/VAO are created and destroyed per
// call. Caller is responsible for binding the program, framebuffer, and
// source textures beforehand.
void BilateralFilterCalculator::GlRender(CalculatorContext* cc) {
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
  // Quad covering the full clip space.
  static const GLfloat square_vertices[] = {
      -1.0f, -1.0f,  // bottom left
      1.0f,  -1.0f,  // bottom right
      -1.0f, 1.0f,   // top left
      1.0f,  1.0f,   // top right
  };
  // Texture coordinates mapping the quad to the whole source texture.
  static const GLfloat texture_vertices[] = {
      0.0f, 0.0f,  // bottom left
      1.0f, 0.0f,  // bottom right
      0.0f, 1.0f,  // top left
      1.0f, 1.0f,  // top right
  };

  // vertex storage
  GLuint vbo[2];
  glGenBuffers(2, vbo);
  GLuint vao;
  glGenVertexArrays(1, &vao);
  glBindVertexArray(vao);

  // vbo 0: vertex positions.
  glBindBuffer(GL_ARRAY_BUFFER, vbo[0]);
  glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices,
               GL_STATIC_DRAW);
  glEnableVertexAttribArray(ATTRIB_VERTEX);
  glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr);

  // vbo 1: texture coordinates.
  glBindBuffer(GL_ARRAY_BUFFER, vbo[1]);
  glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices,
               GL_STATIC_DRAW);
  glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION);
  glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr);

  // draw
  glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);

  // cleanup
  glDisableVertexAttribArray(ATTRIB_VERTEX);
  glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION);
  glBindBuffer(GL_ARRAY_BUFFER, 0);
  glBindVertexArray(0);
  glDeleteVertexArrays(1, &vao);
  glDeleteBuffers(2, vbo);

#endif  // __ANDROID__ || __EMSCRIPTEN__
}
|
||||
|
||||
// Compiles the two fragment shader programs (plain bilateral filter and
// joint/guided bilateral filter) and binds their sampler uniforms to texture
// units 1 (input) and 2 (guide). Sigma values are baked into the GLSL source
// as constants so the GLSL compiler can optimize the kernel loop bounds.
// Must be called inside the GL context; called once from Process().
::mediapipe::Status BilateralFilterCalculator::GlSetup(CalculatorContext* cc) {
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
  const GLint attr_location[NUM_ATTRIBUTES] = {
      ATTRIB_VERTEX,
      ATTRIB_TEXTURE_POSITION,
  };
  const GLchar* attr_name[NUM_ATTRIBUTES] = {
      "position",
      "texture_coordinate",
  };

  // We bake our sigma values directly into the shader, so the GLSL compiler can
  // optimize appropriately.
  std::string sigma_options_string =
      "const float sigma_space = " + std::to_string(sigma_space_) +
      "; const float sigma_color = " + std::to_string(sigma_color_) + ";\n";

  // Shader to do bilateral filtering on input image based on sigma space/color.
  // Large kernel sizes are subsampled based on sqrt(sigma_space) window size,
  // denoted as 'sparsity' below.
  const std::string frag_src = GLES_VERSION_COMPAT
      R"(
  #if __VERSION__ < 130
    #define in varying
  #endif  // __VERSION__ < 130

  #ifdef GL_ES
    #define fragColor gl_FragColor
    precision highp float;
  #else
    #define lowp
    #define mediump
    #define highp
    #define texture2D texture
    out vec4 fragColor;
  #endif  // defined(GL_ES)

    in vec2 sample_coordinate;
    uniform sampler2D input_frame;
  )" + sigma_options_string + R"(
    uniform vec2 texel_size;

    const float kSparsityFactor = 0.66;  // Higher is more sparse.
    const float sparsity = max(1.0, sqrt(sigma_space) * kSparsityFactor);
    const float step = sparsity;
    const float radius = sigma_space;
    const float offset = (step > 1.0) ? (step * 0.5) : (0.0);

    float gaussian(float x, float sigma) {
      float coeff = -0.5 / (sigma * sigma * 4.0 + 1.0e-6);
      return exp((x * x) * coeff);
    }

    void main() {
      vec2 center_uv = sample_coordinate;
      vec3 center_val = texture2D(input_frame, center_uv).rgb;
      vec3 new_val = vec3(0.0);

      float space_weight = 0.0;
      float color_weight = 0.0;
      float total_weight = 0.0;

      float sigma_texel = max(texel_size.x, texel_size.y) * sigma_space;
      // Subsample kernel space.
      for (float i = -radius+offset; i <= radius; i+=step) {
        for (float j = -radius+offset; j <= radius; j+=step) {
          vec2 shift = vec2(j, i) * texel_size;
          vec2 uv = vec2(center_uv + shift);
          vec3 val = texture2D(input_frame, uv).rgb;

          space_weight = gaussian(distance(center_uv, uv), sigma_texel);
          color_weight = gaussian(distance(center_val, val), sigma_color);
          total_weight += space_weight * color_weight;

          new_val += vec3(space_weight * color_weight) * val;
        }
      }
      new_val /= vec3(total_weight);

      fragColor = vec4(new_val, 1.0);
    }
  )";

  // Create shader program and set parameters.
  mediapipe::GlhCreateProgram(mediapipe::kBasicVertexShader, frag_src.c_str(),
                              NUM_ATTRIBUTES, (const GLchar**)&attr_name[0],
                              attr_location, &program_);
  RET_CHECK(program_) << "Problem initializing the program.";
  glUseProgram(program_);
  // Input frame is served on texture unit 1 (see RenderGpu()).
  glUniform1i(glGetUniformLocation(program_, "input_frame"), 1);

  // Shader to do joint bilateral filtering on input image based on
  // sigma space/color, and a Guide image.
  // Large kernel sizes are subsampled based on sqrt(sigma_space) window size,
  // denoted as 'sparsity' below.
  const std::string joint_frag_src = GLES_VERSION_COMPAT
      R"(
  #if __VERSION__ < 130
    #define in varying
  #endif  // __VERSION__ < 130

  #ifdef GL_ES
    #define fragColor gl_FragColor
    precision highp float;
  #else
    #define lowp
    #define mediump
    #define highp
    #define texture2D texture
    out vec4 fragColor;
  #endif  // defined(GL_ES)

    in vec2 sample_coordinate;
    uniform sampler2D input_frame;
    uniform sampler2D guide_frame;
  )" + sigma_options_string + R"(
    uniform vec2 texel_size_guide; // size of guide and resulting filtered image

    const float kSparsityFactor = 0.66;  // Higher is more sparse.
    const float sparsity = max(1.0, sqrt(sigma_space) * kSparsityFactor);
    const float step = sparsity;
    const float radius = sigma_space;
    const float offset = (step > 1.0) ? (step * 0.5) : (0.0);

    float gaussian(float x, float sigma) {
      float coeff = -0.5 / (sigma * sigma * 4.0 + 1.0e-6);
      return exp((x * x) * coeff);
    }

    void main() {
      vec2 center_uv = sample_coordinate;
      vec3 center_val = texture2D(guide_frame, center_uv).rgb;
      vec3 new_val = vec3(0.0);

      float space_weight = 0.0;
      float color_weight = 0.0;
      float total_weight = 0.0;

      float sigma_texel = max(texel_size_guide.x, texel_size_guide.y) * sigma_space;
      // Subsample kernel space.
      for (float i = -radius+offset; i <= radius; i+=step) {
        for (float j = -radius+offset; j <= radius; j+=step) {
          vec2 shift = vec2(j, i) * texel_size_guide;
          vec2 uv = vec2(center_uv + shift);
          vec3 guide_val = texture2D(guide_frame, uv).rgb;
          vec3 out_val = texture2D(input_frame, uv).rgb;

          space_weight = gaussian(distance(center_uv, uv), sigma_texel);
          color_weight = gaussian(distance(center_val, guide_val), sigma_color);
          total_weight += space_weight * color_weight;

          new_val += vec3(space_weight * color_weight) * out_val;
        }
      }
      new_val /= vec3(total_weight);

      fragColor = vec4(new_val, 1.0);
    }
  )";

  // Create shader program and set parameters.
  mediapipe::GlhCreateProgram(
      mediapipe::kBasicVertexShader, joint_frag_src.c_str(), NUM_ATTRIBUTES,
      (const GLchar**)&attr_name[0], attr_location, &program_joint_);
  RET_CHECK(program_joint_) << "Problem initializing the program.";
  glUseProgram(program_joint_);
  // Texture-unit assignments must match the bindings made in RenderGpu().
  glUniform1i(glGetUniformLocation(program_joint_, "input_frame"), 1);
  glUniform1i(glGetUniformLocation(program_joint_, "guide_frame"), 2);

#endif  // __ANDROID__ || __EMSCRIPTEN__

  return ::mediapipe::OkStatus();
}
|
||||
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,20 @@
|
|||
// Options for BilateralFilterCalculator
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
// Options for BilateralFilterCalculator (see the calculator's file-level
// comment for how these map to the CPU/GPU filter parameters).
message BilateralFilterCalculatorOptions {
  extend CalculatorOptions {
    optional BilateralFilterCalculatorOptions ext = 255670209;
  }

  // Max variance in color allowed, based on normalized [0-1] color values.
  // (The CPU path internally rescales this to the 0-255 pixel range.)
  optional float sigma_color = 1;

  // Window radius.
  // Results in a '(sigma_space*2+1) x (sigma_space*2+1)' size kernel.
  // This should be set based on output image pixel space.
  optional float sigma_space = 2;
}
|
|
@ -1,3 +1,17 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||
|
|
|
@ -12,6 +12,9 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "mediapipe/calculators/image/image_cropping_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||
|
@ -21,47 +24,98 @@
|
|||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||
#include "mediapipe/gpu/gl_simple_shaders.h"
|
||||
#include "mediapipe/gpu/gpu_buffer.h"
|
||||
#include "mediapipe/gpu/shader_util.h"
|
||||
#endif // __ANDROID__ or iOS
|
||||
|
||||
namespace {
|
||||
enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
|
||||
} // namespace
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
|
||||
#endif // __ANDROID__ or iOS
|
||||
|
||||
// Crops the input texture to the given rectangle region. The rectangle can
|
||||
// be at arbitrary location on the image with rotation. If there's rotation, the
|
||||
// output texture will have the size of the input rectangle. The rotation should
|
||||
// be in radian, see rect.proto for detail.
|
||||
// Currently it only works for CPU.
|
||||
//
|
||||
// Input:
|
||||
// IMAGE: ImageFrame representing the input image.
|
||||
// One of the following two tags:
|
||||
// IMAGE - ImageFrame representing the input image.
|
||||
// IMAGE_GPU - GpuBuffer representing the input image.
|
||||
// One of the following two tags (optional if WIDTH/HEIGHT is specified):
|
||||
// RECT - A Rect proto specifying the width/height and location of the
|
||||
// cropping rectangle.
|
||||
// NORM_RECT - A NormalizedRect proto specifying the width/height and location
|
||||
// of the cropping rectangle in normalized coordinates.
|
||||
// Alternative tags to RECT (optional if RECT/NORM_RECT is specified):
|
||||
// WIDTH - The desired width of the output cropped image,
|
||||
// based on image center
|
||||
// HEIGHT - The desired height of the output cropped image,
|
||||
// based on image center
|
||||
//
|
||||
// Output:
|
||||
// IMAGE - Cropped frames.
|
||||
// One of the following two tags:
|
||||
// IMAGE - Cropped ImageFrame
|
||||
// IMAGE_GPU - Cropped GpuBuffer.
|
||||
//
|
||||
// Note: input_stream values take precedence over options defined in the graph.
|
||||
//
|
||||
class ImageCroppingCalculator : public CalculatorBase {
|
||||
public:
|
||||
ImageCroppingCalculator() = default;
|
||||
~ImageCroppingCalculator() override = default;
|
||||
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc);
|
||||
::mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
::mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
::mediapipe::Status Close(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
::mediapipe::Status RenderCpu(CalculatorContext* cc);
|
||||
::mediapipe::Status RenderGpu(CalculatorContext* cc);
|
||||
::mediapipe::Status InitGpu(CalculatorContext* cc);
|
||||
void GlRender();
|
||||
void GetOutputDimensions(CalculatorContext* cc, int src_width, int src_height,
|
||||
int* dst_width, int* dst_height);
|
||||
|
||||
// TODO: Merge with GlCroppingCalculator to have GPU support.
|
||||
bool use_gpu_{};
|
||||
mediapipe::ImageCroppingCalculatorOptions options_;
|
||||
|
||||
bool use_gpu_ = false;
|
||||
  // Output texture corners (4) after transformation in normalized coordinates.
|
||||
float transformed_points_[8];
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
bool gpu_initialized_ = false;
|
||||
mediapipe::GlCalculatorHelper gpu_helper_;
|
||||
GLuint program_ = 0;
|
||||
#endif // __ANDROID__ or iOS
|
||||
};
|
||||
REGISTER_CALCULATOR(ImageCroppingCalculator);
|
||||
|
||||
::mediapipe::Status ImageCroppingCalculator::GetContract(
|
||||
CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().HasTag("IMAGE"));
|
||||
RET_CHECK(cc->Outputs().HasTag("IMAGE"));
|
||||
RET_CHECK(cc->Inputs().HasTag("IMAGE") ^ cc->Inputs().HasTag("IMAGE_GPU"));
|
||||
RET_CHECK(cc->Outputs().HasTag("IMAGE") ^ cc->Outputs().HasTag("IMAGE_GPU"));
|
||||
|
||||
if (cc->Inputs().HasTag("IMAGE")) {
|
||||
RET_CHECK(cc->Outputs().HasTag("IMAGE"));
|
||||
cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
|
||||
cc->Outputs().Tag("IMAGE").Set<ImageFrame>();
|
||||
}
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
if (cc->Inputs().HasTag("IMAGE_GPU")) {
|
||||
RET_CHECK(cc->Outputs().HasTag("IMAGE_GPU"));
|
||||
cc->Inputs().Tag("IMAGE_GPU").Set<GpuBuffer>();
|
||||
cc->Outputs().Tag("IMAGE_GPU").Set<GpuBuffer>();
|
||||
}
|
||||
#endif // __ANDROID__ or iOS
|
||||
|
||||
if (cc->Inputs().HasTag("RECT")) {
|
||||
cc->Inputs().Tag("RECT").Set<Rect>();
|
||||
|
@ -69,21 +123,71 @@ REGISTER_CALCULATOR(ImageCroppingCalculator);
|
|||
if (cc->Inputs().HasTag("NORM_RECT")) {
|
||||
cc->Inputs().Tag("NORM_RECT").Set<NormalizedRect>();
|
||||
}
|
||||
if (cc->Inputs().HasTag("WIDTH")) {
|
||||
cc->Inputs().Tag("WIDTH").Set<int>();
|
||||
}
|
||||
if (cc->Inputs().HasTag("HEIGHT")) {
|
||||
cc->Inputs().Tag("HEIGHT").Set<int>();
|
||||
}
|
||||
|
||||
cc->Outputs().Tag("IMAGE").Set<ImageFrame>();
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
||||
#endif // __ANDROID__ or iOS
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status ImageCroppingCalculator::Open(CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
if (cc->Inputs().HasTag("IMAGE_GPU")) {
|
||||
use_gpu_ = true;
|
||||
}
|
||||
|
||||
options_ = cc->Options<mediapipe::ImageCroppingCalculatorOptions>();
|
||||
|
||||
if (use_gpu_) {
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
||||
#else
|
||||
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only.";
|
||||
#endif // __ANDROID__ or iOS
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status ImageCroppingCalculator::Process(CalculatorContext* cc) {
|
||||
if (use_gpu_) {
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
RETURN_IF_ERROR(
|
||||
gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
|
||||
if (!gpu_initialized_) {
|
||||
RETURN_IF_ERROR(InitGpu(cc));
|
||||
gpu_initialized_ = true;
|
||||
}
|
||||
RETURN_IF_ERROR(RenderGpu(cc));
|
||||
return ::mediapipe::OkStatus();
|
||||
}));
|
||||
#endif // __ANDROID__ or iOS
|
||||
} else {
|
||||
RETURN_IF_ERROR(RenderCpu(cc));
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status ImageCroppingCalculator::Close(CalculatorContext* cc) {
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
gpu_helper_.RunInGlContext([this] {
|
||||
if (program_) glDeleteProgram(program_);
|
||||
program_ = 0;
|
||||
});
|
||||
gpu_initialized_ = false;
|
||||
#endif // __ANDROID__ or iOS
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status ImageCroppingCalculator::RenderCpu(CalculatorContext* cc) {
|
||||
const auto& input_img = cc->Inputs().Tag("IMAGE").Get<ImageFrame>();
|
||||
cv::Mat input_mat = formats::MatView(&input_img);
|
||||
|
@ -97,43 +201,53 @@ REGISTER_CALCULATOR(ImageCroppingCalculator);
|
|||
const auto& rect = cc->Inputs().Tag("RECT").Get<Rect>();
|
||||
if (rect.width() > 0 && rect.height() > 0 && rect.x_center() >= 0 &&
|
||||
rect.y_center() >= 0) {
|
||||
rotation = rect.rotation();
|
||||
rect_center_x = rect.x_center();
|
||||
rect_center_y = rect.y_center();
|
||||
target_width = rect.width();
|
||||
target_height = rect.height();
|
||||
rotation = rect.rotation();
|
||||
}
|
||||
} else if (cc->Inputs().HasTag("NORM_RECT")) {
|
||||
const auto& rect = cc->Inputs().Tag("NORM_RECT").Get<NormalizedRect>();
|
||||
if (rect.width() > 0.0 && rect.height() > 0.0 && rect.x_center() >= 0.0 &&
|
||||
rect.y_center() >= 0.0) {
|
||||
rotation = rect.rotation();
|
||||
rect_center_x = std::round(rect.x_center() * input_img.Width());
|
||||
rect_center_y = std::round(rect.y_center() * input_img.Height());
|
||||
target_width = std::round(rect.width() * input_img.Width());
|
||||
target_height = std::round(rect.height() * input_img.Height());
|
||||
rotation = rect.rotation();
|
||||
}
|
||||
}
|
||||
|
||||
cv::Mat rotated_mat;
|
||||
if (std::abs(rotation) > 1e-5) {
|
||||
// TODO: Use open source common math library.
|
||||
const float pi = 3.1415926f;
|
||||
rotation = rotation * 180.0 / pi;
|
||||
|
||||
// First rotation the image.
|
||||
cv::Point2f src_center(rect_center_x, rect_center_y);
|
||||
cv::Mat rotation_mat = cv::getRotationMatrix2D(src_center, rotation, 1.0);
|
||||
cv::warpAffine(input_mat, rotated_mat, rotation_mat, input_mat.size());
|
||||
} else {
|
||||
input_mat.copyTo(rotated_mat);
|
||||
if (cc->Inputs().HasTag("WIDTH") && cc->Inputs().HasTag("HEIGHT")) {
|
||||
target_width = cc->Inputs().Tag("WIDTH").Get<int>();
|
||||
target_height = cc->Inputs().Tag("HEIGHT").Get<int>();
|
||||
} else if (options_.has_width() && options_.has_height()) {
|
||||
target_width = options_.width();
|
||||
target_height = options_.height();
|
||||
}
|
||||
rotation = options_.rotation();
|
||||
}
|
||||
|
||||
// Then crop the requested area.
|
||||
const cv::Rect cropping_rect(rect_center_x - target_width / 2,
|
||||
rect_center_y - target_height / 2, target_width,
|
||||
target_height);
|
||||
cv::Mat cropped_image = cv::Mat(rotated_mat, cropping_rect);
|
||||
const cv::RotatedRect min_rect(cv::Point2f(rect_center_x, rect_center_y),
|
||||
cv::Size2f(target_width, target_height),
|
||||
rotation * 180.f / M_PI);
|
||||
cv::Mat src_points;
|
||||
cv::boxPoints(min_rect, src_points);
|
||||
|
||||
float dst_corners[8] = {0,
|
||||
min_rect.size.height - 1,
|
||||
0,
|
||||
0,
|
||||
min_rect.size.width - 1,
|
||||
0,
|
||||
min_rect.size.width - 1,
|
||||
min_rect.size.height - 1};
|
||||
cv::Mat dst_points = cv::Mat(4, 2, CV_32F, dst_corners);
|
||||
cv::Mat projection_matrix =
|
||||
cv::getPerspectiveTransform(src_points, dst_points);
|
||||
cv::Mat cropped_image;
|
||||
cv::warpPerspective(input_mat, cropped_image, projection_matrix,
|
||||
cv::Size(min_rect.size.width, min_rect.size.height));
|
||||
|
||||
std::unique_ptr<ImageFrame> output_frame(new ImageFrame(
|
||||
input_img.Format(), cropped_image.cols, cropped_image.rows));
|
||||
|
@ -144,7 +258,220 @@ REGISTER_CALCULATOR(ImageCroppingCalculator);
|
|||
}
|
||||
|
||||
::mediapipe::Status ImageCroppingCalculator::RenderGpu(CalculatorContext* cc) {
|
||||
return ::mediapipe::UnimplementedError("GPU support is not implemented yet.");
|
||||
if (cc->Inputs().Tag("IMAGE_GPU").IsEmpty()) {
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
const Packet& input_packet = cc->Inputs().Tag("IMAGE_GPU").Value();
|
||||
const auto& input_buffer = input_packet.Get<mediapipe::GpuBuffer>();
|
||||
auto src_tex = gpu_helper_.CreateSourceTexture(input_buffer);
|
||||
|
||||
int out_width, out_height;
|
||||
GetOutputDimensions(cc, src_tex.width(), src_tex.height(), &out_width,
|
||||
&out_height);
|
||||
auto dst_tex = gpu_helper_.CreateDestinationTexture(out_width, out_height);
|
||||
|
||||
// Run cropping shader on GPU.
|
||||
{
|
||||
gpu_helper_.BindFramebuffer(dst_tex); // GL_TEXTURE0
|
||||
|
||||
glActiveTexture(GL_TEXTURE1);
|
||||
glBindTexture(src_tex.target(), src_tex.name());
|
||||
|
||||
GlRender();
|
||||
|
||||
glActiveTexture(GL_TEXTURE2);
|
||||
glBindTexture(GL_TEXTURE_2D, 0);
|
||||
glFlush();
|
||||
}
|
||||
|
||||
// Send result image in GPU packet.
|
||||
auto output = dst_tex.GetFrame<mediapipe::GpuBuffer>();
|
||||
cc->Outputs().Tag("IMAGE_GPU").Add(output.release(), cc->InputTimestamp());
|
||||
|
||||
// Cleanup
|
||||
src_tex.Release();
|
||||
dst_tex.Release();
|
||||
#endif // __ANDROID__ or iOS
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
void ImageCroppingCalculator::GlRender() {
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
static const GLfloat square_vertices[] = {
|
||||
-1.0f, -1.0f, // bottom left
|
||||
1.0f, -1.0f, // bottom right
|
||||
-1.0f, 1.0f, // top left
|
||||
1.0f, 1.0f, // top right
|
||||
};
|
||||
const GLfloat* texture_vertices = &transformed_points_[0];
|
||||
|
||||
// program
|
||||
glUseProgram(program_);
|
||||
|
||||
// vertex storage
|
||||
GLuint vbo[2];
|
||||
glGenBuffers(2, vbo);
|
||||
GLuint vao;
|
||||
glGenVertexArrays(1, &vao);
|
||||
glBindVertexArray(vao);
|
||||
|
||||
// vbo 0
|
||||
glBindBuffer(GL_ARRAY_BUFFER, vbo[0]);
|
||||
glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices,
|
||||
GL_STATIC_DRAW);
|
||||
glEnableVertexAttribArray(ATTRIB_VERTEX);
|
||||
glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr);
|
||||
|
||||
// vbo 1
|
||||
glBindBuffer(GL_ARRAY_BUFFER, vbo[1]);
|
||||
glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices,
|
||||
GL_STATIC_DRAW);
|
||||
glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION);
|
||||
glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr);
|
||||
|
||||
// draw
|
||||
glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
|
||||
|
||||
// cleanup
|
||||
glDisableVertexAttribArray(ATTRIB_VERTEX);
|
||||
glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION);
|
||||
glBindBuffer(GL_ARRAY_BUFFER, 0);
|
||||
glBindVertexArray(0);
|
||||
glDeleteVertexArrays(1, &vao);
|
||||
glDeleteBuffers(2, vbo);
|
||||
|
||||
#endif // __ANDROID__ or iOS
|
||||
}
|
||||
|
||||
::mediapipe::Status ImageCroppingCalculator::InitGpu(CalculatorContext* cc) {
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
const GLint attr_location[NUM_ATTRIBUTES] = {
|
||||
ATTRIB_VERTEX,
|
||||
ATTRIB_TEXTURE_POSITION,
|
||||
};
|
||||
const GLchar* attr_name[NUM_ATTRIBUTES] = {
|
||||
"position",
|
||||
"texture_coordinate",
|
||||
};
|
||||
|
||||
// Simple pass-through shader.
|
||||
const GLchar* frag_src = GLES_VERSION_COMPAT
|
||||
R"(
|
||||
#if __VERSION__ < 130
|
||||
#define in varying
|
||||
#endif // __VERSION__ < 130
|
||||
|
||||
#ifdef GL_ES
|
||||
#define fragColor gl_FragColor
|
||||
precision highp float;
|
||||
#else
|
||||
#define lowp
|
||||
#define mediump
|
||||
#define highp
|
||||
#define texture2D texture
|
||||
out vec4 fragColor;
|
||||
#endif // defined(GL_ES)
|
||||
|
||||
in vec2 sample_coordinate;
|
||||
uniform sampler2D input_frame;
|
||||
|
||||
void main() {
|
||||
vec4 pix = texture2D(input_frame, sample_coordinate);
|
||||
fragColor = pix;
|
||||
}
|
||||
)";
|
||||
|
||||
// Program
|
||||
mediapipe::GlhCreateProgram(mediapipe::kBasicVertexShader, frag_src,
|
||||
NUM_ATTRIBUTES, &attr_name[0], attr_location,
|
||||
&program_);
|
||||
RET_CHECK(program_) << "Problem initializing the program.";
|
||||
|
||||
// Parameters
|
||||
glUseProgram(program_);
|
||||
glUniform1i(glGetUniformLocation(program_, "input_frame"), 1);
|
||||
#endif // __ANDROID__ or iOS
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
// For GPU only.
|
||||
void ImageCroppingCalculator::GetOutputDimensions(CalculatorContext* cc,
|
||||
int src_width, int src_height,
|
||||
int* dst_width,
|
||||
int* dst_height) {
|
||||
// Get the size of the cropping box.
|
||||
int crop_width = src_width;
|
||||
int crop_height = src_height;
|
||||
// Get the center of cropping box. Default is the at the center.
|
||||
int x_center = src_width / 2;
|
||||
int y_center = src_height / 2;
|
||||
// Get the rotation of the cropping box.
|
||||
float rotation = 0.0f;
|
||||
if (cc->Inputs().HasTag("RECT")) {
|
||||
const auto& rect = cc->Inputs().Tag("RECT").Get<Rect>();
|
||||
// Only use the rect if it is valid.
|
||||
if (rect.width() > 0 && rect.height() > 0 && rect.x_center() >= 0 &&
|
||||
rect.y_center() >= 0) {
|
||||
x_center = rect.x_center();
|
||||
y_center = rect.y_center();
|
||||
crop_width = rect.width();
|
||||
crop_height = rect.height();
|
||||
rotation = rect.rotation();
|
||||
}
|
||||
} else if (cc->Inputs().HasTag("NORM_RECT")) {
|
||||
const auto& rect = cc->Inputs().Tag("NORM_RECT").Get<NormalizedRect>();
|
||||
// Only use the rect if it is valid.
|
||||
if (rect.width() > 0.0 && rect.height() > 0.0 && rect.x_center() >= 0.0 &&
|
||||
rect.y_center() >= 0.0) {
|
||||
x_center = std::round(rect.x_center() * src_width);
|
||||
y_center = std::round(rect.y_center() * src_height);
|
||||
crop_width = std::round(rect.width() * src_width);
|
||||
crop_height = std::round(rect.height() * src_height);
|
||||
rotation = rect.rotation();
|
||||
}
|
||||
} else {
|
||||
if (cc->Inputs().HasTag("WIDTH") && cc->Inputs().HasTag("HEIGHT")) {
|
||||
crop_width = cc->Inputs().Tag("WIDTH").Get<int>();
|
||||
crop_height = cc->Inputs().Tag("HEIGHT").Get<int>();
|
||||
} else if (options_.has_width() && options_.has_height()) {
|
||||
crop_width = options_.width();
|
||||
crop_height = options_.height();
|
||||
}
|
||||
rotation = options_.rotation();
|
||||
}
|
||||
|
||||
const float half_width = crop_width / 2.0f;
|
||||
const float half_height = crop_height / 2.0f;
|
||||
const float corners[] = {-half_width, -half_height, half_width, -half_height,
|
||||
-half_width, half_height, half_width, half_height};
|
||||
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
const float rotated_x = std::cos(rotation) * corners[i * 2] -
|
||||
std::sin(rotation) * corners[i * 2 + 1];
|
||||
const float rotated_y = std::sin(rotation) * corners[i * 2] +
|
||||
std::cos(rotation) * corners[i * 2 + 1];
|
||||
|
||||
transformed_points_[i * 2] = ((rotated_x + x_center) / src_width);
|
||||
transformed_points_[i * 2 + 1] = ((rotated_y + y_center) / src_height);
|
||||
}
|
||||
|
||||
// Find the boundaries of the transformed rectangle.
|
||||
float col_min = transformed_points_[0];
|
||||
float col_max = transformed_points_[0];
|
||||
float row_min = transformed_points_[1];
|
||||
float row_max = transformed_points_[1];
|
||||
for (int i = 1; i < 4; ++i) {
|
||||
col_min = std::min(col_min, transformed_points_[i * 2]);
|
||||
col_max = std::max(col_max, transformed_points_[i * 2]);
|
||||
row_min = std::min(row_min, transformed_points_[i * 2 + 1]);
|
||||
row_max = std::max(row_max, transformed_points_[i * 2 + 1]);
|
||||
}
|
||||
|
||||
*dst_width = std::round((col_max - col_min) * src_width);
|
||||
*dst_height = std::round((row_max - row_min) * src_height);
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
33
mediapipe/calculators/image/image_cropping_calculator.proto
Normal file
33
mediapipe/calculators/image/image_cropping_calculator.proto
Normal file
|
@ -0,0 +1,33 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
message ImageCroppingCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional ImageCroppingCalculatorOptions ext = 262466399;
|
||||
}
|
||||
|
||||
// Output texture buffer dimensions. The values defined in the options will be
|
||||
// overriden by the WIDTH and HEIGHT input streams if they exist.
|
||||
optional int32 width = 1;
|
||||
optional int32 height = 2;
|
||||
|
||||
// Rotation angle is counter-clockwise in radian.
|
||||
optional float rotation = 3 [default = 0.0];
|
||||
}
|
93
mediapipe/calculators/image/image_properties_calculator.cc
Normal file
93
mediapipe/calculators/image/image_properties_calculator.cc
Normal file
|
@ -0,0 +1,93 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
#include "mediapipe/gpu/gpu_buffer.h"
|
||||
#endif // __ANDROID__ or iOS
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Extracts image properties from the input image and outputs the properties.
|
||||
// Currently only supports image size.
|
||||
// Input:
|
||||
// One of the following:
|
||||
// IMAGE: An ImageFrame
|
||||
// IMAGE_GPU: A GpuBuffer
|
||||
//
|
||||
// Output:
|
||||
// SIZE: Size (as a std::pair<int, int>) of the input image.
|
||||
//
|
||||
// Example usage:
|
||||
// node {
|
||||
// calculator: "ImagePropertiesCalculator"
|
||||
// input_stream: "IMAGE:image"
|
||||
// output_stream: "SIZE:size"
|
||||
// }
|
||||
class ImagePropertiesCalculator : public CalculatorBase {
|
||||
public:
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().HasTag("IMAGE") ^ cc->Inputs().HasTag("IMAGE_GPU"));
|
||||
if (cc->Inputs().HasTag("IMAGE")) {
|
||||
cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
|
||||
}
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
if (cc->Inputs().HasTag("IMAGE_GPU")) {
|
||||
cc->Inputs().Tag("IMAGE_GPU").Set<::mediapipe::GpuBuffer>();
|
||||
}
|
||||
#endif // __ANDROID__ or iOS
|
||||
|
||||
if (cc->Outputs().HasTag("SIZE")) {
|
||||
cc->Outputs().Tag("SIZE").Set<std::pair<int, int>>();
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Open(CalculatorContext* cc) override {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Process(CalculatorContext* cc) override {
|
||||
int width;
|
||||
int height;
|
||||
|
||||
if (cc->Inputs().HasTag("IMAGE") && !cc->Inputs().Tag("IMAGE").IsEmpty()) {
|
||||
const auto& image = cc->Inputs().Tag("IMAGE").Get<ImageFrame>();
|
||||
width = image.Width();
|
||||
height = image.Height();
|
||||
}
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
if (cc->Inputs().HasTag("IMAGE_GPU") &&
|
||||
!cc->Inputs().Tag("IMAGE_GPU").IsEmpty()) {
|
||||
const auto& image =
|
||||
cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>();
|
||||
width = image.width();
|
||||
height = image.height();
|
||||
}
|
||||
#endif // __ANDROID__ or iOS
|
||||
|
||||
cc->Outputs().Tag("SIZE").AddPacket(
|
||||
MakePacket<std::pair<int, int>>(width, height)
|
||||
.At(cc->InputTimestamp()));
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(ImagePropertiesCalculator);
|
||||
|
||||
} // namespace mediapipe
|
|
@ -22,12 +22,12 @@
|
|||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/gpu/scale_mode.pb.h"
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX
|
||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||
#include "mediapipe/gpu/gl_quad_renderer.h"
|
||||
#include "mediapipe/gpu/gl_simple_shaders.h"
|
||||
#include "mediapipe/gpu/shader_util.h"
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ || iOS
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
// The size of Java arrays is dynamic, which makes it difficult to
|
||||
|
@ -42,9 +42,9 @@ typedef int DimensionsPacketType[2];
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX
|
||||
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ || iOS
|
||||
|
||||
namespace {
|
||||
int RotationModeToDegrees(mediapipe::RotationMode_Mode rotation) {
|
||||
|
@ -170,11 +170,12 @@ class ImageTransformationCalculator : public CalculatorBase {
|
|||
mediapipe::ScaleMode_Mode scale_mode_;
|
||||
|
||||
bool use_gpu_ = false;
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX
|
||||
GlCalculatorHelper helper_;
|
||||
std::unique_ptr<QuadRenderer> rgb_renderer_;
|
||||
std::unique_ptr<QuadRenderer> yuv_renderer_;
|
||||
std::unique_ptr<QuadRenderer> ext_rgb_renderer_;
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ || iOS
|
||||
};
|
||||
REGISTER_CALCULATOR(ImageTransformationCalculator);
|
||||
|
||||
|
@ -189,13 +190,13 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
|
|||
cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
|
||||
cc->Outputs().Tag("IMAGE").Set<ImageFrame>();
|
||||
}
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX
|
||||
if (cc->Inputs().HasTag("IMAGE_GPU")) {
|
||||
RET_CHECK(cc->Outputs().HasTag("IMAGE_GPU"));
|
||||
cc->Inputs().Tag("IMAGE_GPU").Set<GpuBuffer>();
|
||||
cc->Outputs().Tag("IMAGE_GPU").Set<GpuBuffer>();
|
||||
}
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ || iOS
|
||||
if (cc->Inputs().HasTag("ROTATION_DEGREES")) {
|
||||
cc->Inputs().Tag("ROTATION_DEGREES").Set<int>();
|
||||
}
|
||||
|
@ -211,9 +212,9 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
|
|||
cc->Outputs().Tag("LETTERBOX_PADDING").Set<std::array<float, 4>>();
|
||||
}
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX
|
||||
RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc));
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ || iOS
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
@ -221,7 +222,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
|
|||
::mediapipe::Status ImageTransformationCalculator::Open(CalculatorContext* cc) {
|
||||
// Inform the framework that we always output at the same timestamp
|
||||
// as we receive a packet at.
|
||||
cc->SetOffset(mediapipe::TimestampDiff(0));
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
options_ = cc->Options<ImageTransformationCalculatorOptions>();
|
||||
|
||||
|
@ -249,12 +250,12 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
|
|||
scale_mode_ = ParseScaleMode(options_.scale_mode(), DEFAULT_SCALE_MODE);
|
||||
|
||||
if (use_gpu_) {
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX
|
||||
// Let the helper access the GL context information.
|
||||
RETURN_IF_ERROR(helper_.Open(cc));
|
||||
#else
|
||||
RET_CHECK_FAIL() << "GPU processing for non-Android not supported yet.";
|
||||
#endif // __ANDROID__
|
||||
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only.";
|
||||
#endif // __ANDROID__ || iOS
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
|
@ -263,10 +264,10 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
|
|||
::mediapipe::Status ImageTransformationCalculator::Process(
|
||||
CalculatorContext* cc) {
|
||||
if (use_gpu_) {
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX
|
||||
return helper_.RunInGlContext(
|
||||
[this, cc]() -> ::mediapipe::Status { return RenderGpu(cc); });
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ || iOS
|
||||
} else {
|
||||
return RenderCpu(cc);
|
||||
}
|
||||
|
@ -276,10 +277,11 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
|
|||
::mediapipe::Status ImageTransformationCalculator::Close(
|
||||
CalculatorContext* cc) {
|
||||
if (use_gpu_) {
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX
|
||||
QuadRenderer* rgb_renderer = rgb_renderer_.release();
|
||||
QuadRenderer* yuv_renderer = yuv_renderer_.release();
|
||||
QuadRenderer* ext_rgb_renderer = ext_rgb_renderer_.release();
|
||||
helper_.RunInGlContext([rgb_renderer, ext_rgb_renderer] {
|
||||
helper_.RunInGlContext([rgb_renderer, yuv_renderer, ext_rgb_renderer] {
|
||||
if (rgb_renderer) {
|
||||
rgb_renderer->GlTeardown();
|
||||
delete rgb_renderer;
|
||||
|
@ -288,10 +290,13 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
|
|||
ext_rgb_renderer->GlTeardown();
|
||||
delete ext_rgb_renderer;
|
||||
}
|
||||
});
|
||||
#endif // __ANDROID__
|
||||
if (yuv_renderer) {
|
||||
yuv_renderer->GlTeardown();
|
||||
delete yuv_renderer;
|
||||
}
|
||||
});
|
||||
#endif // __ANDROID__ || iOS
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -366,7 +371,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
|
|||
|
||||
::mediapipe::Status ImageTransformationCalculator::RenderGpu(
|
||||
CalculatorContext* cc) {
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX
|
||||
int input_width = cc->Inputs().Tag("IMAGE_GPU").Get<GpuBuffer>().width();
|
||||
int input_height = cc->Inputs().Tag("IMAGE_GPU").Get<GpuBuffer>().height();
|
||||
|
||||
|
@ -387,8 +392,23 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
|
|||
const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<GpuBuffer>();
|
||||
QuadRenderer* renderer = nullptr;
|
||||
GlTexture src1;
|
||||
|
||||
#if defined(__APPLE__) && !TARGET_OS_OSX
|
||||
if (input.format() == GpuBufferFormat::kBiPlanar420YpCbCr8VideoRange ||
|
||||
input.format() == GpuBufferFormat::kBiPlanar420YpCbCr8FullRange) {
|
||||
if (!yuv_renderer_) {
|
||||
yuv_renderer_ = absl::make_unique<QuadRenderer>();
|
||||
RETURN_IF_ERROR(
|
||||
yuv_renderer_->GlSetup(::mediapipe::kYUV2TexToRGBFragmentShader,
|
||||
{"video_frame_y", "video_frame_uv"}));
|
||||
}
|
||||
renderer = yuv_renderer_.get();
|
||||
src1 = helper_.CreateSourceTexture(input, 0);
|
||||
} else // NOLINT(readability/braces)
|
||||
#endif // iOS
|
||||
{
|
||||
src1 = helper_.CreateSourceTexture(input);
|
||||
#if defined(__ANDROID__)
|
||||
if (src1.target() == GL_TEXTURE_EXTERNAL_OES) {
|
||||
if (!ext_rgb_renderer_) {
|
||||
ext_rgb_renderer_ = absl::make_unique<QuadRenderer>();
|
||||
|
@ -396,7 +416,9 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
|
|||
::mediapipe::kBasicTexturedFragmentShaderOES, {"video_frame"}));
|
||||
}
|
||||
renderer = ext_rgb_renderer_.get();
|
||||
} else {
|
||||
} else // NOLINT(readability/braces)
|
||||
#endif // __ANDROID__
|
||||
{
|
||||
if (!rgb_renderer_) {
|
||||
rgb_renderer_ = absl::make_unique<QuadRenderer>();
|
||||
RETURN_IF_ERROR(rgb_renderer_->GlSetup());
|
||||
|
@ -438,7 +460,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
|
|||
auto output = dst.GetFrame<GpuBuffer>();
|
||||
cc->Outputs().Tag("IMAGE_GPU").Add(output.release(), cc->InputTimestamp());
|
||||
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ || iOS
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
|
|
@ -45,11 +45,11 @@ class OpenCvPutTextCalculator : public CalculatorBase {
|
|||
|
||||
::mediapipe::Status OpenCvPutTextCalculator::Process(CalculatorContext* cc) {
|
||||
const std::string& text_content = cc->Inputs().Index(0).Get<std::string>();
|
||||
cv::Mat mat = cv::Mat::zeros(640, 640, CV_8UC3);
|
||||
cv::Mat mat = cv::Mat::zeros(640, 640, CV_8UC4);
|
||||
cv::putText(mat, text_content, cv::Point(15, 70), cv::FONT_HERSHEY_PLAIN, 3,
|
||||
cv::Scalar(255, 255, 0), 4);
|
||||
cv::Scalar(255, 255, 0, 255), 4);
|
||||
std::unique_ptr<ImageFrame> output_frame = absl::make_unique<ImageFrame>(
|
||||
ImageFormat::SRGB, mat.size().width, mat.size().height);
|
||||
ImageFormat::SRGBA, mat.size().width, mat.size().height);
|
||||
mat.copyTo(formats::MatView(output_frame.get()));
|
||||
cc->Outputs().Index(0).Add(output_frame.release(), cc->InputTimestamp());
|
||||
return ::mediapipe::OkStatus();
|
||||
|
|
|
@ -21,12 +21,12 @@
|
|||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/util/color.pb.h"
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||
#include "mediapipe/gpu/gl_simple_shaders.h"
|
||||
#include "mediapipe/gpu/gpu_buffer.h"
|
||||
#include "mediapipe/gpu/shader_util.h"
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
|
||||
namespace {
|
||||
enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
|
||||
|
@ -95,10 +95,10 @@ class RecolorCalculator : public CalculatorBase {
|
|||
mediapipe::RecolorCalculatorOptions::MaskChannel mask_channel_;
|
||||
|
||||
bool use_gpu_ = false;
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
mediapipe::GlCalculatorHelper gpu_helper_;
|
||||
GLuint program_ = 0;
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
};
|
||||
REGISTER_CALCULATOR(RecolorCalculator);
|
||||
|
||||
|
@ -107,46 +107,48 @@ REGISTER_CALCULATOR(RecolorCalculator);
|
|||
RET_CHECK(!cc->Inputs().GetTags().empty());
|
||||
RET_CHECK(!cc->Outputs().GetTags().empty());
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
if (cc->Inputs().HasTag("IMAGE_GPU")) {
|
||||
cc->Inputs().Tag("IMAGE_GPU").Set<mediapipe::GpuBuffer>();
|
||||
}
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
if (cc->Inputs().HasTag("IMAGE")) {
|
||||
cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
|
||||
}
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
if (cc->Inputs().HasTag("MASK_GPU")) {
|
||||
cc->Inputs().Tag("MASK_GPU").Set<mediapipe::GpuBuffer>();
|
||||
}
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
if (cc->Inputs().HasTag("MASK")) {
|
||||
cc->Inputs().Tag("MASK").Set<ImageFrame>();
|
||||
}
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
if (cc->Outputs().HasTag("IMAGE_GPU")) {
|
||||
cc->Outputs().Tag("IMAGE_GPU").Set<mediapipe::GpuBuffer>();
|
||||
}
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
if (cc->Outputs().HasTag("IMAGE")) {
|
||||
cc->Outputs().Tag("IMAGE").Set<ImageFrame>();
|
||||
}
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status RecolorCalculator::Open(CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
if (cc->Inputs().HasTag("IMAGE_GPU")) {
|
||||
use_gpu_ = true;
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
}
|
||||
|
||||
RETURN_IF_ERROR(LoadOptions(cc));
|
||||
|
@ -156,7 +158,7 @@ REGISTER_CALCULATOR(RecolorCalculator);
|
|||
|
||||
::mediapipe::Status RecolorCalculator::Process(CalculatorContext* cc) {
|
||||
if (use_gpu_) {
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
RETURN_IF_ERROR(
|
||||
gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
|
||||
if (!initialized_) {
|
||||
|
@ -166,7 +168,7 @@ REGISTER_CALCULATOR(RecolorCalculator);
|
|||
RETURN_IF_ERROR(RenderGpu(cc));
|
||||
return ::mediapipe::OkStatus();
|
||||
}));
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
} else {
|
||||
RETURN_IF_ERROR(RenderCpu(cc));
|
||||
}
|
||||
|
@ -174,12 +176,12 @@ REGISTER_CALCULATOR(RecolorCalculator);
|
|||
}
|
||||
|
||||
::mediapipe::Status RecolorCalculator::Close(CalculatorContext* cc) {
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
gpu_helper_.RunInGlContext([this] {
|
||||
if (program_) glDeleteProgram(program_);
|
||||
program_ = 0;
|
||||
});
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
@ -192,7 +194,7 @@ REGISTER_CALCULATOR(RecolorCalculator);
|
|||
if (cc->Inputs().Tag("MASK_GPU").IsEmpty()) {
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
// Get inputs and setup output.
|
||||
const Packet& input_packet = cc->Inputs().Tag("IMAGE_GPU").Value();
|
||||
const Packet& mask_packet = cc->Inputs().Tag("MASK_GPU").Value();
|
||||
|
@ -231,13 +233,13 @@ REGISTER_CALCULATOR(RecolorCalculator);
|
|||
img_tex.Release();
|
||||
mask_tex.Release();
|
||||
dst_tex.Release();
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
void RecolorCalculator::GlRender() {
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
static const GLfloat square_vertices[] = {
|
||||
-1.0f, -1.0f, // bottom left
|
||||
1.0f, -1.0f, // bottom right
|
||||
|
@ -285,7 +287,7 @@ void RecolorCalculator::GlRender() {
|
|||
glBindVertexArray(0);
|
||||
glDeleteVertexArrays(1, &vao);
|
||||
glDeleteBuffers(2, vbo);
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
}
|
||||
|
||||
::mediapipe::Status RecolorCalculator::LoadOptions(CalculatorContext* cc) {
|
||||
|
@ -303,7 +305,7 @@ void RecolorCalculator::GlRender() {
|
|||
}
|
||||
|
||||
::mediapipe::Status RecolorCalculator::InitGpu(CalculatorContext* cc) {
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
const GLint attr_location[NUM_ATTRIBUTES] = {
|
||||
ATTRIB_VERTEX,
|
||||
ATTRIB_TEXTURE_POSITION,
|
||||
|
@ -372,7 +374,7 @@ void RecolorCalculator::GlRender() {
|
|||
glUniform1i(glGetUniformLocation(program_, "mask"), 2);
|
||||
glUniform3f(glGetUniformLocation(program_, "recolor"), color_[0], color_[1],
|
||||
color_[2]);
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
|
|
@ -25,11 +25,12 @@
|
|||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/vector.h"
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||
#include "mediapipe/gpu/gl_simple_shaders.h"
|
||||
#include "mediapipe/gpu/gpu_buffer.h"
|
||||
#include "mediapipe/gpu/shader_util.h"
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
|
@ -98,18 +99,18 @@ class SetAlphaCalculator : public CalculatorBase {
|
|||
::mediapipe::Status RenderGpu(CalculatorContext* cc);
|
||||
::mediapipe::Status RenderCpu(CalculatorContext* cc);
|
||||
|
||||
::mediapipe::Status GlRender(CalculatorContext* cc);
|
||||
::mediapipe::Status GlSetup(CalculatorContext* cc);
|
||||
void GlRender(CalculatorContext* cc);
|
||||
|
||||
mediapipe::SetAlphaCalculatorOptions options_;
|
||||
float alpha_value_ = -1.f;
|
||||
|
||||
bool use_gpu_ = false;
|
||||
bool gpu_initialized_ = false;
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
mediapipe::GlCalculatorHelper gpu_helper_;
|
||||
GLuint program_ = 0;
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
};
|
||||
REGISTER_CALCULATOR(SetAlphaCalculator);
|
||||
|
||||
|
@ -126,52 +127,54 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
|
|||
}
|
||||
|
||||
// Input image to add/edit alpha channel.
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
if (cc->Inputs().HasTag(kInputFrameTagGpu)) {
|
||||
cc->Inputs().Tag(kInputFrameTagGpu).Set<mediapipe::GpuBuffer>();
|
||||
}
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
if (cc->Inputs().HasTag(kInputFrameTag)) {
|
||||
cc->Inputs().Tag(kInputFrameTag).Set<ImageFrame>();
|
||||
}
|
||||
|
||||
// Input alpha image mask (optional)
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
if (cc->Inputs().HasTag(kInputAlphaTagGpu)) {
|
||||
cc->Inputs().Tag(kInputAlphaTagGpu).Set<mediapipe::GpuBuffer>();
|
||||
}
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
if (cc->Inputs().HasTag(kInputAlphaTag)) {
|
||||
cc->Inputs().Tag(kInputAlphaTag).Set<ImageFrame>();
|
||||
}
|
||||
|
||||
// RGBA output image.
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
if (cc->Outputs().HasTag(kOutputFrameTagGpu)) {
|
||||
cc->Outputs().Tag(kOutputFrameTagGpu).Set<mediapipe::GpuBuffer>();
|
||||
}
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
if (cc->Outputs().HasTag(kOutputFrameTag)) {
|
||||
cc->Outputs().Tag(kOutputFrameTag).Set<ImageFrame>();
|
||||
}
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status SetAlphaCalculator::Open(CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
options_ = cc->Options<mediapipe::SetAlphaCalculatorOptions>();
|
||||
|
||||
if (cc->Inputs().HasTag(kInputFrameTagGpu) &&
|
||||
cc->Outputs().HasTag(kOutputFrameTagGpu)) {
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
use_gpu_ = true;
|
||||
#else
|
||||
RET_CHECK_FAIL() << "GPU processing on non-Android not supported yet.";
|
||||
#endif // __ANDROID__
|
||||
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only.";
|
||||
#endif // __ANDROID__ or iOS
|
||||
}
|
||||
|
||||
// Get global value from options (-1 if not set).
|
||||
|
@ -184,7 +187,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
|
|||
RET_CHECK_FAIL() << "Must use either image mask or options alpha value.";
|
||||
|
||||
if (use_gpu_) {
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
||||
#endif
|
||||
}
|
||||
|
@ -194,7 +197,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
|
|||
|
||||
::mediapipe::Status SetAlphaCalculator::Process(CalculatorContext* cc) {
|
||||
if (use_gpu_) {
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
RETURN_IF_ERROR(
|
||||
gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
|
||||
if (!gpu_initialized_) {
|
||||
|
@ -204,7 +207,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
|
|||
RETURN_IF_ERROR(RenderGpu(cc));
|
||||
return ::mediapipe::OkStatus();
|
||||
}));
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
} else {
|
||||
RETURN_IF_ERROR(RenderCpu(cc));
|
||||
}
|
||||
|
@ -213,12 +216,12 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
|
|||
}
|
||||
|
||||
::mediapipe::Status SetAlphaCalculator::Close(CalculatorContext* cc) {
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
gpu_helper_.RunInGlContext([this] {
|
||||
if (program_) glDeleteProgram(program_);
|
||||
program_ = 0;
|
||||
});
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
@ -292,7 +295,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
|
|||
if (cc->Inputs().Tag(kInputFrameTagGpu).IsEmpty()) {
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
// Setup source texture.
|
||||
const auto& input_frame =
|
||||
cc->Inputs().Tag(kInputFrameTagGpu).Get<mediapipe::GpuBuffer>();
|
||||
|
@ -345,13 +348,13 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
|
|||
// Cleanup
|
||||
input_texture.Release();
|
||||
output_texture.Release();
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status SetAlphaCalculator::GlRender(CalculatorContext* cc) {
|
||||
#if defined(__ANDROID__)
|
||||
void SetAlphaCalculator::GlRender(CalculatorContext* cc) {
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
static const GLfloat square_vertices[] = {
|
||||
-1.0f, -1.0f, // bottom left
|
||||
1.0f, -1.0f, // bottom right
|
||||
|
@ -400,13 +403,11 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
|
|||
glDeleteVertexArrays(1, &vao);
|
||||
glDeleteBuffers(2, vbo);
|
||||
|
||||
#endif // __ANDROID__
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
#endif // __ANDROID__ or iOS
|
||||
}
|
||||
|
||||
::mediapipe::Status SetAlphaCalculator::GlSetup(CalculatorContext* cc) {
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
const GLint attr_location[NUM_ATTRIBUTES] = {
|
||||
ATTRIB_VERTEX,
|
||||
ATTRIB_TEXTURE_POSITION,
|
||||
|
@ -417,6 +418,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
|
|||
};
|
||||
|
||||
// Shader to overlay a texture onto another when overlay is non-zero.
|
||||
// TODO split into two shaders to handle alpha_value<0 separately
|
||||
const GLchar* frag_src = GLES_VERSION_COMPAT
|
||||
R"(
|
||||
#if __VERSION__ < 130
|
||||
|
@ -458,7 +460,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
|
|||
glUniform1i(glGetUniformLocation(program_, "alpha_mask"), 2);
|
||||
glUniform1f(glGetUniformLocation(program_, "alpha_value"), alpha_value_);
|
||||
|
||||
#endif // __ANDROID__
|
||||
#endif // __ANDROID__ or iOS
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
|
|
@ -1,4 +1,17 @@
|
|||
// Options for SetAlphaCalculator
|
||||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
|
|
@ -417,6 +417,9 @@ cc_library(
|
|||
"//mediapipe:android": [
|
||||
"@org_tensorflow//tensorflow/core:android_tensorflow_lib_lite_nortti_lite_protos",
|
||||
],
|
||||
"//mediapipe:ios": [
|
||||
"@org_tensorflow//tensorflow/core:ios_tensorflow_lib",
|
||||
],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -435,6 +438,9 @@ cc_library(
|
|||
"//mediapipe:android": [
|
||||
"@org_tensorflow//tensorflow/core:android_tensorflow_lib_lite_nortti_lite_protos",
|
||||
],
|
||||
"//mediapipe:ios": [
|
||||
"@org_tensorflow//tensorflow/core:ios_tensorflow_lib",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
|
@ -459,6 +465,10 @@ cc_library(
|
|||
"@org_tensorflow//tensorflow/core:android_tensorflow_lib_lite_nortti_lite_protos",
|
||||
"//mediapipe/android/file/base",
|
||||
],
|
||||
"//mediapipe:ios": [
|
||||
"@org_tensorflow//tensorflow/core:ios_tensorflow_lib",
|
||||
"//mediapipe/android/file/base",
|
||||
],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -783,6 +793,7 @@ cc_test(
|
|||
"@org_tensorflow//tensorflow/core/kernels:io",
|
||||
"@org_tensorflow//tensorflow/core/kernels:state",
|
||||
"@org_tensorflow//tensorflow/core/kernels:string",
|
||||
"@org_tensorflow//tensorflow/core/kernels/data:tensor_dataset_op",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -813,6 +824,7 @@ cc_test(
|
|||
"@org_tensorflow//tensorflow/core/kernels:io",
|
||||
"@org_tensorflow//tensorflow/core/kernels:state",
|
||||
"@org_tensorflow//tensorflow/core/kernels:string",
|
||||
"@org_tensorflow//tensorflow/core/kernels/data:tensor_dataset_op",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -949,6 +961,9 @@ cc_test(
|
|||
"//mediapipe:android": [
|
||||
"@org_tensorflow//tensorflow/core:android_tensorflow_lib_with_ops_lite_proto_no_rtti_lib",
|
||||
],
|
||||
"//mediapipe:ios": [
|
||||
"@org_tensorflow//tensorflow/core:ios_tensorflow_test_lib",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
|
|
|
@ -118,7 +118,7 @@ REGISTER_CALCULATOR(MatrixToTensorCalculator);
|
|||
|
||||
// Inform the framework that we always output at the same timestamp
|
||||
// as we receive a packet at.
|
||||
cc->SetOffset(mediapipe::TimestampDiff(0));
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
@ -37,6 +37,7 @@ const char kImageTag[] = "IMAGE";
|
|||
const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_";
|
||||
const char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED";
|
||||
const char kBBoxTag[] = "BBOX";
|
||||
const char kKeypointsTag[] = "KEYPOINTS";
|
||||
const char kSegmentationMaskTag[] = "CLASS_SEGMENTATION";
|
||||
|
||||
namespace tf = ::tensorflow;
|
||||
|
@ -54,16 +55,11 @@ namespace mpms = ::mediapipe::mediasequence;
|
|||
// stores the encoded optical flow from the same calculator, "BBOX" which stores
|
||||
// bounding boxes from vector<Detections>, and streams with the
|
||||
// "FLOAT_FEATURE_${NAME}" pattern, which stores the values from vector<float>'s
|
||||
// associated with the name ${NAME}. Audio streams (i.e. Matrix with a
|
||||
// TimeSeriesHeader) are given extra packing and unpacking support and are named
|
||||
// similar to floats with the pattern "AUDIO_${NAME}". "IMAGE_${NAME}" and
|
||||
// "BBOX_${NAME}" will also store prefixed versions of each stream, which allows
|
||||
// for multiple image streams to be included. However, the default names are
|
||||
// suppored by more tools. "ENCODED_MEDIA" stores a video encoding for the clip
|
||||
// directly. The last packet on this stream is stored, and can be unpacked with
|
||||
// the metadata generator. Because the media decoder always starts with
|
||||
// timestamp zero, the "ENCODED_MEDIA_START_TIMESTAMP" should be recorded as
|
||||
// well. Use the FirstTimestampCalculator to determine this value.
|
||||
// associated with the name ${NAME}. "KEYPOINTS" stores a map of 2D keypoints
|
||||
// from unordered_map<std::string, vector<pair<float, float>>>. "IMAGE_${NAME}",
|
||||
// "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store prefixed versions of
|
||||
// each stream, which allows for multiple image streams to be included. However,
|
||||
// the default names are suppored by more tools.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
|
@ -122,6 +118,21 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
|||
}
|
||||
cc->Inputs().Tag(tag).Set<OpenCvImageEncoderCalculatorResults>();
|
||||
}
|
||||
if (absl::StartsWith(tag, kKeypointsTag)) {
|
||||
std::string key = "";
|
||||
if (tag != kKeypointsTag) {
|
||||
int tag_length = sizeof(kKeypointsTag) / sizeof(*kKeypointsTag) - 1;
|
||||
if (tag[tag_length] == '_') {
|
||||
key = tag.substr(tag_length + 1);
|
||||
} else {
|
||||
continue; // Skip keys that don't match "(kKeypointsTag)_?"
|
||||
}
|
||||
}
|
||||
cc->Inputs()
|
||||
.Tag(tag)
|
||||
.Set<std::unordered_map<std::string,
|
||||
std::vector<std::pair<float, float>>>>();
|
||||
}
|
||||
if (absl::StartsWith(tag, kBBoxTag)) {
|
||||
std::string key = "";
|
||||
if (tag != kBBoxTag) {
|
||||
|
@ -169,8 +180,7 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
|||
features_present_[tag] = false;
|
||||
}
|
||||
|
||||
if (cc->Options()
|
||||
.GetExtension(PackMediaSequenceCalculatorOptions::ext)
|
||||
if (cc->Options<PackMediaSequenceCalculatorOptions>()
|
||||
.replace_data_instead_of_append()) {
|
||||
for (const auto& tag : cc->Inputs().GetTags()) {
|
||||
if (absl::StartsWith(tag, kImageTag)) {
|
||||
|
@ -186,13 +196,19 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
|||
mpms::ClearImageEncoded(key, sequence_.get());
|
||||
mpms::ClearImageTimestamp(key, sequence_.get());
|
||||
}
|
||||
if (absl::StartsWith(tag, kBBoxTag)) {
|
||||
std::string key = "";
|
||||
if (tag != kBBoxTag) {
|
||||
int tag_length = sizeof(kBBoxTag) / sizeof(*kBBoxTag) - 1;
|
||||
if (tag[tag_length] == '_') {
|
||||
key = tag.substr(tag_length + 1);
|
||||
} else {
|
||||
continue; // Skip keys that don't match "(kBBoxTag)_?"
|
||||
}
|
||||
if (cc->Inputs().HasTag(kForwardFlowEncodedTag)) {
|
||||
mpms::ClearForwardFlowEncoded(sequence_.get());
|
||||
mpms::ClearForwardFlowTimestamp(sequence_.get());
|
||||
}
|
||||
|
||||
for (const auto& tag : cc->Inputs().GetTags()) {
|
||||
mpms::ClearBBox(key, sequence_.get());
|
||||
mpms::ClearBBoxTimestamp(key, sequence_.get());
|
||||
}
|
||||
if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) {
|
||||
std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) /
|
||||
sizeof(*kFloatFeaturePrefixTag) -
|
||||
|
@ -200,6 +216,16 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
|||
mpms::ClearFeatureFloats(key, sequence_.get());
|
||||
mpms::ClearFeatureTimestamp(key, sequence_.get());
|
||||
}
|
||||
if (absl::StartsWith(tag, kKeypointsTag)) {
|
||||
std::string key =
|
||||
tag.substr(sizeof(kKeypointsTag) / sizeof(*kKeypointsTag) - 1);
|
||||
mpms::ClearBBoxPoint(key, sequence_.get());
|
||||
mpms::ClearBBoxTimestamp(key, sequence_.get());
|
||||
}
|
||||
}
|
||||
if (cc->Inputs().HasTag(kForwardFlowEncodedTag)) {
|
||||
mpms::ClearForwardFlowEncoded(sequence_.get());
|
||||
mpms::ClearForwardFlowTimestamp(sequence_.get());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -228,11 +254,11 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
|||
}
|
||||
|
||||
::mediapipe::Status Close(CalculatorContext* cc) override {
|
||||
auto& options =
|
||||
cc->Options().GetExtension(PackMediaSequenceCalculatorOptions::ext);
|
||||
auto& options = cc->Options<PackMediaSequenceCalculatorOptions>();
|
||||
if (options.reconcile_metadata()) {
|
||||
RET_CHECK_OK(mpms::ReconcileMetadata(options.reconcile_bbox_annotations(),
|
||||
sequence_.get()));
|
||||
RET_CHECK_OK(mpms::ReconcileMetadata(
|
||||
options.reconcile_bbox_annotations(),
|
||||
options.reconcile_region_annotations(), sequence_.get()));
|
||||
}
|
||||
|
||||
if (options.output_only_if_all_present()) {
|
||||
|
@ -260,6 +286,9 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
|||
|
||||
::mediapipe::Status Process(CalculatorContext* cc) override {
|
||||
for (const auto& tag : cc->Inputs().GetTags()) {
|
||||
if (!cc->Inputs().Tag(tag).IsEmpty()) {
|
||||
features_present_[tag] = true;
|
||||
}
|
||||
if (absl::StartsWith(tag, kImageTag) &&
|
||||
!cc->Inputs().Tag(tag).IsEmpty()) {
|
||||
std::string key = "";
|
||||
|
@ -281,23 +310,29 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
|||
sequence_.get());
|
||||
mpms::AddImageEncoded(key, image.encoded_image(), sequence_.get());
|
||||
}
|
||||
if (absl::StartsWith(tag, kKeypointsTag) &&
|
||||
!cc->Inputs().Tag(tag).IsEmpty()) {
|
||||
std::string key = "";
|
||||
if (tag != kImageTag) {
|
||||
int tag_length = sizeof(kKeypointsTag) / sizeof(*kKeypointsTag) - 1;
|
||||
if (tag[tag_length] == '_') {
|
||||
key = tag.substr(tag_length + 1);
|
||||
} else {
|
||||
continue; // Skip keys that don't match "(kKeypointsTag)_?"
|
||||
}
|
||||
if (cc->Inputs().HasTag(kForwardFlowEncodedTag) &&
|
||||
!cc->Inputs().Tag(kForwardFlowEncodedTag).IsEmpty()) {
|
||||
const OpenCvImageEncoderCalculatorResults& forward_flow =
|
||||
}
|
||||
const auto& keypoints =
|
||||
cc->Inputs()
|
||||
.Tag(kForwardFlowEncodedTag)
|
||||
.Get<OpenCvImageEncoderCalculatorResults>();
|
||||
if (!forward_flow.has_encoded_image()) {
|
||||
return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
||||
<< "No encoded forward flow";
|
||||
}
|
||||
mpms::AddForwardFlowTimestamp(cc->InputTimestamp().Value(),
|
||||
sequence_.get());
|
||||
mpms::AddForwardFlowEncoded(forward_flow.encoded_image(),
|
||||
.Tag(tag)
|
||||
.Get<std::unordered_map<
|
||||
std::string, std::vector<std::pair<float, float>>>>();
|
||||
for (const auto& pair : keypoints) {
|
||||
mpms::AddBBoxTimestamp(mpms::merge_prefix(key, pair.first),
|
||||
cc->InputTimestamp().Value(), sequence_.get());
|
||||
mpms::AddBBoxPoint(mpms::merge_prefix(key, pair.first), pair.second,
|
||||
sequence_.get());
|
||||
}
|
||||
for (const auto& tag : cc->Inputs().GetTags()) {
|
||||
}
|
||||
if (absl::StartsWith(tag, kFloatFeaturePrefixTag) &&
|
||||
!cc->Inputs().Tag(tag).IsEmpty()) {
|
||||
std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) /
|
||||
|
@ -309,8 +344,6 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
|||
cc->Inputs().Tag(tag).Get<std::vector<float>>(),
|
||||
sequence_.get());
|
||||
}
|
||||
}
|
||||
for (const auto& tag : cc->Inputs().GetTags()) {
|
||||
if (absl::StartsWith(tag, kBBoxTag) && !cc->Inputs().Tag(tag).IsEmpty()) {
|
||||
std::string key = "";
|
||||
if (tag != kBBoxTag) {
|
||||
|
@ -349,15 +382,30 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
|||
mpms::AddBBoxTimestamp(key, cc->InputTimestamp().Value(),
|
||||
sequence_.get());
|
||||
if (!predicted_class_strings.empty()) {
|
||||
mpms::AddBBoxClassString(key, predicted_class_strings,
|
||||
mpms::AddBBoxLabelString(key, predicted_class_strings,
|
||||
sequence_.get());
|
||||
}
|
||||
if (!predicted_label_ids.empty()) {
|
||||
mpms::AddBBoxClassIndex(key, predicted_label_ids, sequence_.get());
|
||||
mpms::AddBBoxLabelIndex(key, predicted_label_ids, sequence_.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (cc->Inputs().HasTag(kForwardFlowEncodedTag) &&
|
||||
!cc->Inputs().Tag(kForwardFlowEncodedTag).IsEmpty()) {
|
||||
const OpenCvImageEncoderCalculatorResults& forward_flow =
|
||||
cc->Inputs()
|
||||
.Tag(kForwardFlowEncodedTag)
|
||||
.Get<OpenCvImageEncoderCalculatorResults>();
|
||||
if (!forward_flow.has_encoded_image()) {
|
||||
return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
||||
<< "No encoded forward flow";
|
||||
}
|
||||
mpms::AddForwardFlowTimestamp(cc->InputTimestamp().Value(),
|
||||
sequence_.get());
|
||||
mpms::AddForwardFlowEncoded(forward_flow.encoded_image(),
|
||||
sequence_.get());
|
||||
}
|
||||
if (cc->Inputs().HasTag(kSegmentationMaskTag) &&
|
||||
!cc->Inputs().Tag(kSegmentationMaskTag).IsEmpty()) {
|
||||
bool already_has_mask = false;
|
||||
|
@ -387,11 +435,6 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
|||
}
|
||||
}
|
||||
}
|
||||
for (const auto& tag : cc->Inputs().GetTags()) {
|
||||
if (!cc->Inputs().Tag(tag).IsEmpty()) {
|
||||
features_present_[tag] = true;
|
||||
}
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
@ -32,18 +32,29 @@ message PackMediaSequenceCalculatorOptions {
|
|||
// (e.g. fills in the image height, width, and number of frames.)
|
||||
optional bool reconcile_metadata = 2 [default = true];
|
||||
|
||||
// If true, updates the metadata for sequences with bounding boxes. This will
|
||||
// align each bounding box annotation with the nearest frame and insert empty
|
||||
// annotations as needed to satisfy the frame rate.
|
||||
// If true, updates the metadata for bounding box portions of sequences. This
|
||||
// will align each bounding box annotation with the nearest frame and insert
|
||||
// empty annotations as needed to satisfy the frame rate.
|
||||
// NOTE: IF YOU DOWNSAMPLE IN TIME YOU WILL LOSE ANNOTATIONS.
|
||||
// If two or more annotations are closest to the same frame, then only
|
||||
// the closest annotation is saved. This matches the behavior of
|
||||
// downsampling images in time.
|
||||
optional bool reconcile_bbox_annotations = 5 [default = true];
|
||||
optional bool reconcile_bbox_annotations = 5 [default = false];
|
||||
|
||||
// If true, updates the metadata for all regions annotations, regardless of
|
||||
// prefix, in the sequence. This will align each region annotation with the
|
||||
// nearest frame and insert empty annotations as needed to satisfy the frame
|
||||
// rate. This is particularly useful for key point annotations that are
|
||||
// represented as region points. This does not exclude bboxes.
|
||||
// NOTE: IF YOU DOWNSAMPLE IN TIME YOU WILL LOSE ANNOTATIONS.
|
||||
// If two or more annotations are closest to the same frame, then only
|
||||
// the closest annotation is saved. This matches the behavior of
|
||||
// downsampling images in time.
|
||||
optional bool reconcile_region_annotations = 6 [default = true];
|
||||
|
||||
// If true, the SequenceExample is output only if all input streams are
|
||||
// present.
|
||||
optional bool output_only_if_all_present = 3 [default = false];
|
||||
optional bool output_only_if_all_present = 3 [default = true];
|
||||
|
||||
// If true, will remove all data from a sequence example for a corresponding
|
||||
// input stream. E.g. if images are already present and an IMAGE stream is
|
||||
|
|
|
@ -348,15 +348,51 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBBoxDetections) {
|
|||
ASSERT_NEAR(1.0, rect.ymax(), 0.001);
|
||||
}
|
||||
auto class_strings =
|
||||
mpms::GetPredictedBBoxClassStringAt(output_sequence, i);
|
||||
mpms::GetPredictedBBoxLabelStringAt(output_sequence, i);
|
||||
ASSERT_EQ("absolute bbox", class_strings[0]);
|
||||
ASSERT_EQ("relative bbox", class_strings[1]);
|
||||
auto class_indices = mpms::GetPredictedBBoxClassIndexAt(output_sequence, i);
|
||||
auto class_indices = mpms::GetPredictedBBoxLabelIndexAt(output_sequence, i);
|
||||
ASSERT_EQ(0, class_indices[0]);
|
||||
ASSERT_EQ(1, class_indices[1]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(PackMediaSequenceCalculatorTest, PacksTwoKeypoints) {
|
||||
SetUpCalculator({"KEYPOINTS_TEST:keypoints"}, {}, false, true);
|
||||
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
|
||||
std::string test_video_id = "test_video_id";
|
||||
mpms::SetClipMediaId(test_video_id, input_sequence.get());
|
||||
|
||||
std::unordered_map<std::string, std::vector<std::pair<float, float>>> points =
|
||||
{{"HEAD", {{0.1, 0.2}, {0.3, 0.4}}}, {"TAIL", {{0.5, 0.6}}}};
|
||||
runner_->MutableInputs()
|
||||
->Tag("KEYPOINTS_TEST")
|
||||
.packets.push_back(PointToForeign(&points).At(Timestamp(0)));
|
||||
runner_->MutableInputs()
|
||||
->Tag("KEYPOINTS_TEST")
|
||||
.packets.push_back(PointToForeign(&points).At(Timestamp(1)));
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MEDIAPIPE_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
const tf::SequenceExample& output_sequence =
|
||||
output_packets[0].Get<tf::SequenceExample>();
|
||||
|
||||
ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence));
|
||||
ASSERT_EQ(2, mpms::GetBBoxPointSize("TEST/HEAD", output_sequence));
|
||||
ASSERT_EQ(2, mpms::GetBBoxPointSize("TEST/TAIL", output_sequence));
|
||||
ASSERT_NEAR(0.2,
|
||||
mpms::GetBBoxPointAt("TEST/HEAD", output_sequence, 0)[0].second,
|
||||
0.001);
|
||||
ASSERT_NEAR(0.5,
|
||||
mpms::GetBBoxPointAt("TEST/TAIL", output_sequence, 1)[0].first,
|
||||
0.001);
|
||||
}
|
||||
|
||||
TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) {
|
||||
SetUpCalculator({"CLASS_SEGMENTATION:detections"}, {}, false, true);
|
||||
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
|
||||
|
@ -395,7 +431,6 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) {
|
|||
ASSERT_EQ(1, output_packets.size());
|
||||
const tf::SequenceExample& output_sequence =
|
||||
output_packets[0].Get<tf::SequenceExample>();
|
||||
LOG(INFO) << "output_sequence: \n" << output_sequence.DebugString();
|
||||
|
||||
ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence));
|
||||
ASSERT_EQ(height, mpms::GetImageHeight(output_sequence));
|
||||
|
@ -602,6 +637,10 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) {
|
|||
mpms::AddBBoxTimestamp(21, input_sequence.get());
|
||||
mpms::AddBBoxTimestamp(22, input_sequence.get());
|
||||
|
||||
mpms::AddBBoxTimestamp("PREFIX", 8, input_sequence.get());
|
||||
mpms::AddBBoxTimestamp("PREFIX", 9, input_sequence.get());
|
||||
mpms::AddBBoxTimestamp("PREFIX", 22, input_sequence.get());
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
Adopt(input_sequence.release());
|
||||
MEDIAPIPE_ASSERT_OK(runner_->Run());
|
||||
|
@ -617,6 +656,13 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) {
|
|||
ASSERT_EQ(mpms::GetBBoxTimestampAt(output_sequence, 2), 30);
|
||||
ASSERT_EQ(mpms::GetBBoxTimestampAt(output_sequence, 3), 40);
|
||||
ASSERT_EQ(mpms::GetBBoxTimestampAt(output_sequence, 4), 50);
|
||||
|
||||
ASSERT_EQ(mpms::GetBBoxTimestampSize("PREFIX", output_sequence), 5);
|
||||
ASSERT_EQ(mpms::GetBBoxTimestampAt("PREFIX", output_sequence, 0), 10);
|
||||
ASSERT_EQ(mpms::GetBBoxTimestampAt("PREFIX", output_sequence, 1), 20);
|
||||
ASSERT_EQ(mpms::GetBBoxTimestampAt("PREFIX", output_sequence, 2), 30);
|
||||
ASSERT_EQ(mpms::GetBBoxTimestampAt("PREFIX", output_sequence, 3), 40);
|
||||
ASSERT_EQ(mpms::GetBBoxTimestampAt("PREFIX", output_sequence, 4), 50);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
|
@ -183,7 +183,7 @@ REGISTER_CALCULATOR(TensorToMatrixCalculator);
|
|||
header_ = *input_header;
|
||||
cc->Outputs().Tag(kMatrix).SetHeader(Adopt(input_header.release()));
|
||||
}
|
||||
cc->SetOffset(mediapipe::TimestampDiff(0));
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
@ -232,8 +232,7 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
|
|||
current_timestamp_index_ = 0;
|
||||
|
||||
// Determine the data path and output it.
|
||||
const auto& options =
|
||||
cc->Options().GetExtension(UnpackMediaSequenceCalculatorOptions::ext);
|
||||
const auto& options = cc->Options<UnpackMediaSequenceCalculatorOptions>();
|
||||
const auto& sequence = cc->InputSidePackets()
|
||||
.Tag(kSequenceExampleTag)
|
||||
.Get<tensorflow::SequenceExample>();
|
||||
|
@ -338,7 +337,6 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
|
|||
? Timestamp::PostStream()
|
||||
: Timestamp(map_kv.second[i]);
|
||||
|
||||
LOG(INFO) << "key: " << map_kv.first;
|
||||
if (absl::StrContains(map_kv.first, mpms::GetImageTimestampKey())) {
|
||||
std::vector<std::string> pieces = absl::StrSplit(map_kv.first, '/');
|
||||
std::string feature_key = "";
|
||||
|
|
|
@ -40,7 +40,7 @@ message UnpackMediaSequenceCalculatorOptions {
|
|||
optional float extra_padding_from_media_decoder = 5 [default = 0.0];
|
||||
|
||||
// Stores the packet resampler settings for the graph. The most accurate
|
||||
// proceedure for sampling a range of frames is to request a padded time range
|
||||
// procedure for sampling a range of frames is to request a padded time range
|
||||
// from the MediaDecoderCalculator and then trim it down to the proper time
|
||||
// range with the PacketResamplerCalculator.
|
||||
optional PacketResamplerCalculatorOptions base_packet_resampler_options = 6;
|
||||
|
|
|
@ -460,6 +460,8 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromExample) {
|
|||
}
|
||||
|
||||
TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) {
|
||||
// TODO: Suport proto3 proto.Any in CalculatorOptions.
|
||||
// TODO: Avoid proto2 extensions in "RESAMPLER_OPTIONS".
|
||||
CalculatorOptions options;
|
||||
options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext)
|
||||
->set_padding_before_label(1);
|
||||
|
|
|
@ -61,6 +61,13 @@ proto_library(
|
|||
deps = ["//mediapipe/framework:calculator_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "tflite_tensors_to_landmarks_calculator_proto",
|
||||
srcs = ["tflite_tensors_to_landmarks_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//mediapipe/framework:calculator_proto"],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "ssd_anchors_calculator_cc_proto",
|
||||
srcs = ["ssd_anchors_calculator.proto"],
|
||||
|
@ -109,6 +116,14 @@ mediapipe_cc_proto_library(
|
|||
deps = [":tflite_tensors_to_detections_calculator_proto"],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "tflite_tensors_to_landmarks_calculator_cc_proto",
|
||||
srcs = ["tflite_tensors_to_landmarks_calculator.proto"],
|
||||
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
|
||||
visibility = ["//mediapipe:__subpackages__"],
|
||||
deps = [":tflite_tensors_to_landmarks_calculator_proto"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ssd_anchors_calculator",
|
||||
srcs = ["ssd_anchors_calculator.cc"],
|
||||
|
@ -168,6 +183,21 @@ cc_test(
|
|||
cc_library(
|
||||
name = "tflite_inference_calculator",
|
||||
srcs = ["tflite_inference_calculator.cc"],
|
||||
copts = select({
|
||||
"//mediapipe:ios": [
|
||||
"-std=c++11",
|
||||
"-x objective-c++",
|
||||
"-fobjc-arc", # enable reference-counting
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
linkopts = select({
|
||||
"//mediapipe:ios": [
|
||||
"-framework CoreVideo",
|
||||
"-framework MetalKit",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":tflite_inference_calculator_cc_proto",
|
||||
|
@ -179,12 +209,17 @@ cc_library(
|
|||
"//mediapipe/framework/port:ret_check",
|
||||
] + select({
|
||||
"//mediapipe:android": [
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader",
|
||||
"//mediapipe/gpu:gl_calculator_helper",
|
||||
"//mediapipe/gpu:gpu_buffer",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader",
|
||||
],
|
||||
"//mediapipe:ios": [
|
||||
"//mediapipe/gpu:MPPMetalHelper",
|
||||
"//mediapipe/objc:mediapipe_framework_ios",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
|
@ -194,12 +229,28 @@ cc_library(
|
|||
cc_library(
|
||||
name = "tflite_converter_calculator",
|
||||
srcs = ["tflite_converter_calculator.cc"],
|
||||
copts = select({
|
||||
"//mediapipe:ios": [
|
||||
"-std=c++11",
|
||||
"-x objective-c++",
|
||||
"-fobjc-arc", # enable reference-counting
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
linkopts = select({
|
||||
"//mediapipe:ios": [
|
||||
"-framework CoreVideo",
|
||||
"-framework MetalKit",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":tflite_converter_calculator_cc_proto",
|
||||
"//mediapipe/util:resource_util",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
|
||||
"//mediapipe/framework/tool:status_util",
|
||||
"//mediapipe/framework/port:status",
|
||||
|
@ -208,13 +259,19 @@ cc_library(
|
|||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
] + select({
|
||||
"//mediapipe:android": [
|
||||
"//mediapipe/gpu:gpu_buffer",
|
||||
"//mediapipe/gpu:gl_calculator_helper",
|
||||
"//mediapipe/gpu:gpu_buffer",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader",
|
||||
],
|
||||
"//mediapipe:ios": [
|
||||
"//mediapipe/gpu:gpu_buffer",
|
||||
"//mediapipe/gpu:MPPMetalHelper",
|
||||
"//mediapipe/objc:mediapipe_framework_ios",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
|
@ -229,6 +286,9 @@ cc_library(
|
|||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/formats:image_frame_opencv",
|
||||
"//mediapipe/framework/port:opencv_imgcodecs",
|
||||
"//mediapipe/framework/port:opencv_imgproc",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework:calculator_context",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -279,6 +339,32 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tflite_tensors_to_landmarks_calculator",
|
||||
srcs = ["tflite_tensors_to_landmarks_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":tflite_tensors_to_landmarks_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"@org_tensorflow//tensorflow/lite:framework",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tflite_tensors_to_floats_calculator",
|
||||
srcs = ["tflite_tensors_to_floats_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"@org_tensorflow//tensorflow/lite:framework",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "tflite_inference_calculator_test",
|
||||
srcs = ["tflite_inference_calculator_test.cc"],
|
||||
|
@ -298,3 +384,24 @@ cc_test(
|
|||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "tflite_converter_calculator_test",
|
||||
srcs = ["tflite_converter_calculator_test.cc"],
|
||||
deps = [
|
||||
":tflite_converter_calculator",
|
||||
":tflite_converter_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/tool:validate_type",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@org_tensorflow//tensorflow/lite:framework",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -73,6 +73,8 @@ class SsdAnchorsCalculator : public CalculatorBase {
|
|||
}
|
||||
|
||||
::mediapipe::Status Open(CalculatorContext* cc) override {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
const SsdAnchorsCalculatorOptions& options =
|
||||
cc->Options<SsdAnchorsCalculatorOptions>();
|
||||
|
||||
|
|
|
@ -12,9 +12,14 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/calculators/tflite/tflite_converter_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/util/resource_util.h"
|
||||
#include "tensorflow/lite/error_reporter.h"
|
||||
|
@ -22,18 +27,41 @@
|
|||
|
||||
#if defined(__ANDROID__)
|
||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||
#include "mediapipe/gpu/gpu_buffer.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"
|
||||
#endif // ANDROID
|
||||
#endif // __ANDROID__
|
||||
|
||||
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
#import <CoreVideo/CoreVideo.h>
|
||||
#import <Metal/Metal.h>
|
||||
#import <MetalKit/MetalKit.h>
|
||||
|
||||
#import "mediapipe/gpu/MPPMetalHelper.h"
|
||||
#include "mediapipe/gpu/gpu_buffer.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
|
||||
#endif // iOS
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
typedef id<MTLBuffer> GpuTensor;
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
constexpr int kWorkgroupSize = 8; // Block size for GPU shader.
|
||||
// Commonly used to compute the number of blocks to launch in a kernel.
|
||||
int RoundUp(const int size, const int multiple) {
|
||||
return (size + multiple - 1) / multiple;
|
||||
int NumGroups(const int size, const int group_size) { // NOLINT
|
||||
return (size + group_size - 1) / group_size;
|
||||
}
|
||||
|
||||
typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
|
||||
RowMajorMatrixXf;
|
||||
typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>
|
||||
ColMajorMatrixXf;
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace mediapipe {
|
||||
|
@ -43,30 +71,38 @@ using ::tflite::gpu::gl::GlBuffer;
|
|||
using ::tflite::gpu::gl::GlProgram;
|
||||
using ::tflite::gpu::gl::GlShader;
|
||||
struct GPUData {
|
||||
int width;
|
||||
int height;
|
||||
int channels;
|
||||
GlBuffer ssbo;
|
||||
int elements = 1;
|
||||
GlBuffer buffer;
|
||||
GlShader shader;
|
||||
GlProgram program;
|
||||
};
|
||||
#endif // ANDROID
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
struct GPUData {
|
||||
int elements = 1;
|
||||
id<MTLBuffer> buffer;
|
||||
id<MTLComputePipelineState> pipeline_state;
|
||||
};
|
||||
#endif
|
||||
|
||||
// Calculator for normalizing and converting an ImageFrame or GpuBuffer
|
||||
// into a TfLiteTensor (float 32) or tflite::gpu::GlBuffer, respetively.
|
||||
// Calculator for normalizing and converting an ImageFrame or Matrix
|
||||
// into a TfLiteTensor (float 32) or a GpuBuffer to a tflite::gpu::GlBuffer.
|
||||
//
|
||||
// This calculator is designed to be used with the TfLiteInferenceCalcualtor,
|
||||
// as a pre-processing step for calculator inputs.
|
||||
//
|
||||
// Input data is normalized to [-1,1] (default) or [0,1], specified by options.
|
||||
// IMAGE and IMAGE_GPU inputs are normalized to [-1,1] (default) or [0,1],
|
||||
// specified by options (unless outputting a quantized tensor).
|
||||
//
|
||||
// Input:
|
||||
// One of the following tags:
|
||||
// IMAGE - ImageFrame (assumed to be 8-bit or 32-bit data).
|
||||
// IMAGE_GPU - GpuBuffer (assumed to be RGBA or RGB GL texture)
|
||||
// IMAGE_GPU - GpuBuffer (assumed to be RGBA or RGB GL texture).
|
||||
// MATRIX - Matrix.
|
||||
//
|
||||
// Output:
|
||||
// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32
|
||||
// TENSORS_GPU - vector of GlBuffer
|
||||
// One of the following tags:
|
||||
// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32, or kTfLiteUint8.
|
||||
// TENSORS_GPU - vector of GlBuffer.
|
||||
//
|
||||
// Example use:
|
||||
// node {
|
||||
|
@ -83,7 +119,7 @@ struct GPUData {
|
|||
// IMPORTANT Notes:
|
||||
// No conversion between CPU/GPU is done.
|
||||
// Inputs/outputs must match type: CPU->CPU or GPU->GPU.
|
||||
// GPU tensors are currently only supported on Android.
|
||||
// GPU tensors are currently only supported on mobile platforms.
|
||||
// This calculator uses FixedSizeInputStreamHandler by default.
|
||||
//
|
||||
class TfLiteConverterCalculator : public CalculatorBase {
|
||||
|
@ -101,43 +137,62 @@ class TfLiteConverterCalculator : public CalculatorBase {
|
|||
::mediapipe::Status NormalizeImage(const ImageFrame& image_frame,
|
||||
bool zero_center, bool flip_vertically,
|
||||
float* tensor_buffer);
|
||||
::mediapipe::Status CopyMatrixToTensor(const Matrix& matrix,
|
||||
float* tensor_buffer);
|
||||
::mediapipe::Status ProcessCPU(CalculatorContext* cc);
|
||||
::mediapipe::Status ProcessGPU(CalculatorContext* cc);
|
||||
|
||||
std::unique_ptr<tflite::Interpreter> interpreter_ = nullptr;
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
mediapipe::GlCalculatorHelper gpu_helper_;
|
||||
std::unique_ptr<GPUData> gpu_data_out_;
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
MPPMetalHelper* gpu_helper_ = nullptr;
|
||||
std::unique_ptr<GPUData> gpu_data_out_;
|
||||
#endif
|
||||
|
||||
bool initialized_ = false;
|
||||
bool use_gpu_ = false;
|
||||
bool zero_center_ = true; // normalize range to [-1,1] | otherwise [0,1]
|
||||
bool flip_vertically_ = false;
|
||||
bool row_major_matrix_ = false;
|
||||
bool use_quantized_tensors_ = false;
|
||||
int max_num_channels_ = 3;
|
||||
};
|
||||
REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
||||
|
||||
::mediapipe::Status TfLiteConverterCalculator::GetContract(
|
||||
CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().HasTag("IMAGE") || cc->Inputs().HasTag("IMAGE_GPU"));
|
||||
RET_CHECK(cc->Outputs().HasTag("TENSORS") ||
|
||||
const bool has_image_tag = cc->Inputs().HasTag("IMAGE");
|
||||
const bool has_image_gpu_tag = cc->Inputs().HasTag("IMAGE_GPU");
|
||||
const bool has_matrix_tag = cc->Inputs().HasTag("MATRIX");
|
||||
// Confirm only one of the input streams is present.
|
||||
RET_CHECK(has_image_tag ^ has_image_gpu_tag ^ has_matrix_tag &&
|
||||
!(has_image_tag && has_image_gpu_tag && has_matrix_tag));
|
||||
|
||||
// Confirm only one of the output streams is present.
|
||||
RET_CHECK(cc->Outputs().HasTag("TENSORS") ^
|
||||
cc->Outputs().HasTag("TENSORS_GPU"));
|
||||
|
||||
if (cc->Inputs().HasTag("IMAGE")) cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
|
||||
#if defined(__ANDROID__)
|
||||
if (cc->Inputs().HasTag("MATRIX")) cc->Inputs().Tag("MATRIX").Set<Matrix>();
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
if (cc->Inputs().HasTag("IMAGE_GPU"))
|
||||
cc->Inputs().Tag("IMAGE_GPU").Set<mediapipe::GpuBuffer>();
|
||||
#endif
|
||||
|
||||
if (cc->Outputs().HasTag("TENSORS"))
|
||||
cc->Outputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
if (cc->Outputs().HasTag("TENSORS_GPU"))
|
||||
cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GlBuffer>>();
|
||||
cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
|
||||
#endif
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
|
||||
#endif
|
||||
|
||||
// Assign this calculator's default InputStreamHandler.
|
||||
|
@ -147,14 +202,16 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
}
|
||||
|
||||
::mediapipe::Status TfLiteConverterCalculator::Open(CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
RETURN_IF_ERROR(LoadOptions(cc));
|
||||
|
||||
if (cc->Inputs().HasTag("IMAGE_GPU") ||
|
||||
cc->Outputs().HasTag("IMAGE_OUT_GPU")) {
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
use_gpu_ = true;
|
||||
#else
|
||||
RET_CHECK_FAIL() << "GPU processing on non-Android not supported yet.";
|
||||
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only.";
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -162,8 +219,13 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
// Cannot mix CPU/GPU streams.
|
||||
RET_CHECK(cc->Inputs().HasTag("IMAGE_GPU") &&
|
||||
cc->Outputs().HasTag("TENSORS_GPU"));
|
||||
// Cannot use quantization.
|
||||
use_quantized_tensors_ = false;
|
||||
#if defined(__ANDROID__)
|
||||
RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
|
||||
RET_CHECK(gpu_helper_);
|
||||
#endif
|
||||
} else {
|
||||
interpreter_ = absl::make_unique<tflite::Interpreter>();
|
||||
|
@ -176,77 +238,59 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
|
||||
::mediapipe::Status TfLiteConverterCalculator::Process(CalculatorContext* cc) {
|
||||
if (use_gpu_) {
|
||||
// GpuBuffer to tflite::gpu::GlBuffer conversion.
|
||||
#if defined(__ANDROID__)
|
||||
if (!initialized_) {
|
||||
RETURN_IF_ERROR(InitGpu(cc));
|
||||
initialized_ = true;
|
||||
}
|
||||
|
||||
const auto& input =
|
||||
cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>();
|
||||
RETURN_IF_ERROR(
|
||||
gpu_helper_.RunInGlContext([this, &input]() -> ::mediapipe::Status {
|
||||
// Convert GL texture into TfLite GlBuffer (SSBO).
|
||||
auto src = gpu_helper_.CreateSourceTexture(input);
|
||||
glActiveTexture(GL_TEXTURE0 + 0);
|
||||
glBindTexture(GL_TEXTURE_2D, src.name());
|
||||
auto status = gpu_data_out_->ssbo.BindToIndex(1);
|
||||
if (!status.ok()) {
|
||||
return ::mediapipe::InternalError(status.error_message());
|
||||
}
|
||||
const tflite::gpu::uint3 workgroups = {
|
||||
RoundUp(gpu_data_out_->width, kWorkgroupSize),
|
||||
RoundUp(gpu_data_out_->height, kWorkgroupSize), 1};
|
||||
status = gpu_data_out_->program.Dispatch(workgroups);
|
||||
if (!status.ok()) {
|
||||
return ::mediapipe::InternalError(status.error_message());
|
||||
}
|
||||
glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0);
|
||||
glBindTexture(GL_TEXTURE_2D, 0);
|
||||
src.Release();
|
||||
return ::mediapipe::OkStatus();
|
||||
}));
|
||||
|
||||
auto output_tensors = absl::make_unique<std::vector<GlBuffer>>();
|
||||
output_tensors->resize(1);
|
||||
for (int i = 0; i < 1; ++i) {
|
||||
GlBuffer& tensor = output_tensors->at(i);
|
||||
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
|
||||
auto status = CreateReadWriteShaderStorageBuffer<float>(
|
||||
gpu_data_out_->width * gpu_data_out_->height *
|
||||
gpu_data_out_->channels,
|
||||
&tensor);
|
||||
if (!status.ok()) {
|
||||
return ::mediapipe::InternalError(status.error_message());
|
||||
}
|
||||
tflite::gpu::gl::CopyBuffer(gpu_data_out_->ssbo, tensor);
|
||||
}
|
||||
cc->Outputs()
|
||||
.Tag("TENSORS_GPU")
|
||||
.Add(output_tensors.release(), cc->InputTimestamp());
|
||||
#else
|
||||
RET_CHECK_FAIL()
|
||||
<< "GPU input on non-Android devices is not supported yet.";
|
||||
#endif
|
||||
// Convert to GPU tensors type.
|
||||
RETURN_IF_ERROR(ProcessGPU(cc));
|
||||
} else {
|
||||
// Convert to CPU tensors or Matrix type.
|
||||
RETURN_IF_ERROR(ProcessCPU(cc));
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) {
|
||||
#if defined(__ANDROID__)
|
||||
gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); });
|
||||
#endif
|
||||
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
gpu_data_out_.reset();
|
||||
#endif
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status TfLiteConverterCalculator::ProcessCPU(
|
||||
CalculatorContext* cc) {
|
||||
if (cc->Inputs().HasTag("IMAGE")) {
|
||||
// CPU ImageFrame to TfLiteTensor conversion.
|
||||
|
||||
const auto& image_frame = cc->Inputs().Tag("IMAGE").Get<ImageFrame>();
|
||||
const int height = image_frame.Height();
|
||||
const int width = image_frame.Width();
|
||||
const int channels_preserved =
|
||||
std::min(image_frame.NumberOfChannels(), max_num_channels_);
|
||||
const int channels = image_frame.NumberOfChannels();
|
||||
const int channels_preserved = std::min(channels, max_num_channels_);
|
||||
|
||||
if (!initialized_) {
|
||||
if (!(image_frame.Format() == mediapipe::ImageFormat::SRGBA ||
|
||||
image_frame.Format() == mediapipe::ImageFormat::SRGB ||
|
||||
image_frame.Format() == mediapipe::ImageFormat::GRAY8 ||
|
||||
image_frame.Format() == mediapipe::ImageFormat::VEC32F1))
|
||||
RET_CHECK_FAIL() << "Unsupported CPU input format.";
|
||||
|
||||
if (!initialized_) {
|
||||
interpreter_->SetTensorParametersReadWrite(
|
||||
0, kTfLiteFloat32, "", {channels_preserved}, TfLiteQuantization());
|
||||
TfLiteQuantization quant;
|
||||
if (use_quantized_tensors_) {
|
||||
RET_CHECK(image_frame.Format() != mediapipe::ImageFormat::VEC32F1)
|
||||
<< "Only 8-bit input images are supported for quantization.";
|
||||
// Optional: Set 'quant' quantization params here if needed.
|
||||
interpreter_->SetTensorParametersReadWrite(0, kTfLiteUInt8, "",
|
||||
{channels_preserved}, quant);
|
||||
} else {
|
||||
// Default TfLiteQuantization used for no quantization.
|
||||
interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "",
|
||||
{channels_preserved}, quant);
|
||||
}
|
||||
initialized_ = true;
|
||||
}
|
||||
|
||||
|
@ -256,9 +300,26 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
{height, width, channels_preserved});
|
||||
interpreter_->AllocateTensors();
|
||||
|
||||
// Copy image data into tensor.
|
||||
if (use_quantized_tensors_) {
|
||||
const int width_padding =
|
||||
image_frame.WidthStep() / image_frame.ByteDepth() - width * channels;
|
||||
const uint8* image_buffer =
|
||||
reinterpret_cast<const uint8*>(image_frame.PixelData());
|
||||
uint8* tensor_buffer = tensor->data.uint8;
|
||||
RET_CHECK(tensor_buffer);
|
||||
for (int row = 0; row < height; ++row) {
|
||||
for (int col = 0; col < width; ++col) {
|
||||
for (int channel = 0; channel < channels_preserved; ++channel) {
|
||||
*tensor_buffer++ = image_buffer[channel];
|
||||
}
|
||||
image_buffer += channels;
|
||||
}
|
||||
image_buffer += width_padding;
|
||||
}
|
||||
} else {
|
||||
float* tensor_buffer = tensor->data.f;
|
||||
RET_CHECK(tensor_buffer);
|
||||
|
||||
if (image_frame.ByteDepth() == 1) {
|
||||
RETURN_IF_ERROR(NormalizeImage<uint8>(image_frame, zero_center_,
|
||||
flip_vertically_, tensor_buffer));
|
||||
|
@ -269,6 +330,36 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
return ::mediapipe::InternalError(
|
||||
"Only byte-based (8 bit) and float (32 bit) images supported.");
|
||||
}
|
||||
}
|
||||
|
||||
auto output_tensors = absl::make_unique<std::vector<TfLiteTensor>>();
|
||||
output_tensors->emplace_back(*tensor);
|
||||
cc->Outputs().Tag("TENSORS").Add(output_tensors.release(),
|
||||
cc->InputTimestamp());
|
||||
} else if (cc->Inputs().HasTag("MATRIX")) {
|
||||
// CPU Matrix to TfLiteTensor conversion.
|
||||
|
||||
const auto& matrix = cc->Inputs().Tag("MATRIX").Get<Matrix>();
|
||||
const int height = matrix.rows();
|
||||
const int width = matrix.cols();
|
||||
const int channels = 1;
|
||||
|
||||
if (!initialized_) {
|
||||
interpreter_->SetTensorParametersReadWrite(
|
||||
/*tensor_index=*/0, /*type=*/kTfLiteFloat32, /*name=*/"",
|
||||
/*dims=*/{channels}, /*quantization=*/TfLiteQuantization());
|
||||
initialized_ = true;
|
||||
}
|
||||
|
||||
const int tensor_idx = interpreter_->inputs()[0];
|
||||
TfLiteTensor* tensor = interpreter_->tensor(tensor_idx);
|
||||
interpreter_->ResizeInputTensor(tensor_idx, {height, width, channels});
|
||||
interpreter_->AllocateTensors();
|
||||
|
||||
float* tensor_buffer = tensor->data.f;
|
||||
RET_CHECK(tensor_buffer);
|
||||
|
||||
RETURN_IF_ERROR(CopyMatrixToTensor(matrix, tensor_buffer));
|
||||
|
||||
auto output_tensors = absl::make_unique<std::vector<TfLiteTensor>>();
|
||||
output_tensors->emplace_back(*tensor);
|
||||
|
@ -279,43 +370,132 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) {
|
||||
::mediapipe::Status TfLiteConverterCalculator::ProcessGPU(
|
||||
CalculatorContext* cc) {
|
||||
#if defined(__ANDROID__)
|
||||
gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); });
|
||||
#endif // __ANDROID__
|
||||
// GpuBuffer to tflite::gpu::GlBuffer conversion.
|
||||
const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>();
|
||||
RETURN_IF_ERROR(
|
||||
gpu_helper_.RunInGlContext([this, &input]() -> ::mediapipe::Status {
|
||||
// Convert GL texture into TfLite GlBuffer (SSBO).
|
||||
auto src = gpu_helper_.CreateSourceTexture(input);
|
||||
glActiveTexture(GL_TEXTURE0 + 0);
|
||||
glBindTexture(GL_TEXTURE_2D, src.name());
|
||||
auto status = gpu_data_out_->buffer.BindToIndex(1);
|
||||
if (!status.ok()) {
|
||||
return ::mediapipe::InternalError(status.error_message());
|
||||
}
|
||||
const tflite::gpu::uint3 workgroups = {
|
||||
NumGroups(input.width(), kWorkgroupSize),
|
||||
NumGroups(input.height(), kWorkgroupSize), 1};
|
||||
status = gpu_data_out_->program.Dispatch(workgroups);
|
||||
if (!status.ok()) {
|
||||
return ::mediapipe::InternalError(status.error_message());
|
||||
}
|
||||
glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0);
|
||||
glBindTexture(GL_TEXTURE_2D, 0);
|
||||
src.Release();
|
||||
return ::mediapipe::OkStatus();
|
||||
}));
|
||||
|
||||
// Copy into outputs.
|
||||
auto output_tensors = absl::make_unique<std::vector<GpuTensor>>();
|
||||
output_tensors->resize(1);
|
||||
{
|
||||
GlBuffer& tensor = output_tensors->at(0);
|
||||
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
|
||||
auto status = CreateReadWriteShaderStorageBuffer<float>(
|
||||
gpu_data_out_->elements, &tensor);
|
||||
if (!status.ok()) {
|
||||
return ::mediapipe::InternalError(status.error_message());
|
||||
}
|
||||
tflite::gpu::gl::CopyBuffer(gpu_data_out_->buffer, tensor);
|
||||
}
|
||||
cc->Outputs()
|
||||
.Tag("TENSORS_GPU")
|
||||
.Add(output_tensors.release(), cc->InputTimestamp());
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
// GpuBuffer to id<MTLBuffer> conversion.
|
||||
const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>();
|
||||
{
|
||||
id<MTLTexture> src_texture = [gpu_helper_ metalTextureWithGpuBuffer:input];
|
||||
id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
|
||||
command_buffer.label = @"TfLiteConverterCalculatorConvert";
|
||||
id<MTLComputeCommandEncoder> compute_encoder =
|
||||
[command_buffer computeCommandEncoder];
|
||||
[compute_encoder setComputePipelineState:gpu_data_out_->pipeline_state];
|
||||
[compute_encoder setTexture:src_texture atIndex:0];
|
||||
[compute_encoder setBuffer:gpu_data_out_->buffer offset:0 atIndex:1];
|
||||
MTLSize threads_per_group = MTLSizeMake(kWorkgroupSize, kWorkgroupSize, 1);
|
||||
MTLSize threadgroups =
|
||||
MTLSizeMake(NumGroups(input.width(), kWorkgroupSize),
|
||||
NumGroups(input.height(), kWorkgroupSize), 1);
|
||||
[compute_encoder dispatchThreadgroups:threadgroups
|
||||
threadsPerThreadgroup:threads_per_group];
|
||||
[compute_encoder endEncoding];
|
||||
[command_buffer commit];
|
||||
[command_buffer waitUntilCompleted];
|
||||
}
|
||||
|
||||
// Copy into outputs.
|
||||
auto output_tensors = absl::make_unique<std::vector<GpuTensor>>();
|
||||
{
|
||||
id<MTLDevice> device = gpu_helper_.mtlDevice;
|
||||
id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
|
||||
command_buffer.label = @"TfLiteConverterCalculatorCopy";
|
||||
id<MTLBuffer> tensor =
|
||||
[device newBufferWithLength:gpu_data_out_->elements * sizeof(float)
|
||||
options:MTLResourceStorageModeShared];
|
||||
id<MTLBlitCommandEncoder> blit_command =
|
||||
[command_buffer blitCommandEncoder];
|
||||
[blit_command copyFromBuffer:gpu_data_out_->buffer
|
||||
sourceOffset:0
|
||||
toBuffer:tensor
|
||||
destinationOffset:0
|
||||
size:gpu_data_out_->elements * sizeof(float)];
|
||||
[blit_command endEncoding];
|
||||
[command_buffer commit];
|
||||
[command_buffer waitUntilCompleted];
|
||||
|
||||
output_tensors->push_back(tensor);
|
||||
}
|
||||
|
||||
cc->Outputs()
|
||||
.Tag("TENSORS_GPU")
|
||||
.Add(output_tensors.release(), cc->InputTimestamp());
|
||||
#else
|
||||
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only.";
|
||||
#endif
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
|
||||
#if defined(__ANDROID__)
|
||||
// Get input image sizes.
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
// Configure inputs.
|
||||
const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>();
|
||||
|
||||
mediapipe::ImageFormat::Format format =
|
||||
mediapipe::ImageFormatForGpuBufferFormat(input.format());
|
||||
|
||||
gpu_data_out_ = absl::make_unique<GPUData>();
|
||||
gpu_data_out_->height = input.height();
|
||||
gpu_data_out_->width = input.width();
|
||||
gpu_data_out_->channels = max_num_channels_; // desired output channels
|
||||
|
||||
gpu_data_out_->elements = input.height() * input.width() * max_num_channels_;
|
||||
const bool include_alpha = (max_num_channels_ == 4);
|
||||
|
||||
if (!(format == mediapipe::ImageFormat::SRGB ||
|
||||
format == mediapipe::ImageFormat::SRGBA))
|
||||
RET_CHECK_FAIL() << "Unsupported GPU input format.";
|
||||
|
||||
if (include_alpha && (format != mediapipe::ImageFormat::SRGBA))
|
||||
RET_CHECK_FAIL() << "Num input channels is less than desired output.";
|
||||
#endif
|
||||
|
||||
// Shader to convert GL Texture to Shader Storage Buffer Object (SSBO),
|
||||
// with normalization to either: [0,1] or [-1,1].
|
||||
#if defined(__ANDROID__)
|
||||
// Device memory.
|
||||
auto status = ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer<float>(
|
||||
gpu_data_out_->width * gpu_data_out_->height * gpu_data_out_->channels,
|
||||
&gpu_data_out_->ssbo);
|
||||
gpu_data_out_->elements, &gpu_data_out_->buffer);
|
||||
if (!status.ok()) {
|
||||
return ::mediapipe::InternalError(status.error_message());
|
||||
}
|
||||
|
||||
// Shader to convert GL Texture to Shader Storage Buffer Object (SSBO),
|
||||
// with normalization to either: [0,1] or [-1,1].
|
||||
const std::string shader_source = absl::Substitute(
|
||||
R"( #version 310 es
|
||||
layout(local_size_x = $0, local_size_y = $0) in;
|
||||
|
@ -333,8 +513,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
output_data.elements[linear_index + 2] = pixel.z;
|
||||
$6 // alpha channel
|
||||
})",
|
||||
/*$0=*/kWorkgroupSize, /*$1=*/gpu_data_out_->width,
|
||||
/*$2=*/gpu_data_out_->height,
|
||||
/*$0=*/kWorkgroupSize, /*$1=*/input.width(), /*$2=*/input.height(),
|
||||
/*$3=*/zero_center_ ? "pixel = (pixel - 0.5) * 2.0;" : "",
|
||||
/*$4=*/flip_vertically_ ? "(width_height.y - 1 - gid.y)" : "gid.y",
|
||||
/*$5=*/
|
||||
|
@ -353,7 +532,65 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
if (!status.ok()) {
|
||||
return ::mediapipe::InternalError(status.error_message());
|
||||
}
|
||||
#endif // ANDROID
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
RET_CHECK(include_alpha)
|
||||
<< "iOS GPU inference currently accepts only RGBA input.";
|
||||
|
||||
// Device memory.
|
||||
id<MTLDevice> device = gpu_helper_.mtlDevice;
|
||||
gpu_data_out_->buffer =
|
||||
[device newBufferWithLength:gpu_data_out_->elements * sizeof(float)
|
||||
options:MTLResourceStorageModeShared];
|
||||
|
||||
// Shader to convert GL Texture to Metal Buffer,
|
||||
// with normalization to either: [0,1] or [-1,1].
|
||||
const std::string shader_source = absl::Substitute(
|
||||
R"(
|
||||
#include <simd/simd.h>
|
||||
|
||||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
kernel void convertKernel(
|
||||
texture2d<half, access::sample> in_tex [[ texture(0) ]],
|
||||
device float* out_buf [[ buffer(1) ]],
|
||||
uint2 gid [[ thread_position_in_grid ]]) {
|
||||
if (gid.x >= in_tex.get_width() || gid.y >= in_tex.get_height()) return;
|
||||
constexpr sampler texture_sampler(coord::pixel, address::clamp_to_edge);
|
||||
const float2 coord = float2(gid.x, gid.y);
|
||||
$0 pixel = $0(in_tex.sample(texture_sampler, coord).$1);
|
||||
$2 // normalize [-1,1]
|
||||
const int linear_index = $4 * ($3 * in_tex.get_width() + gid.x);
|
||||
out_buf[linear_index + 0] = pixel.x;
|
||||
out_buf[linear_index + 1] = pixel.y;
|
||||
out_buf[linear_index + 2] = pixel.z;
|
||||
$5 // alpha channel
|
||||
}
|
||||
)",
|
||||
/*$0=*/include_alpha ? "float4" : "float3",
|
||||
/*$1=*/include_alpha ? "rgba" : "rgb",
|
||||
/*$2=*/zero_center_ ? "pixel = (pixel - 0.5) * 2.0;" : "",
|
||||
/*$3=*/flip_vertically_ ? "(in_tex.get_height() - 1 - gid.y)" : "gid.y",
|
||||
/*$4=*/include_alpha ? 4 : 3,
|
||||
/*$5=*/include_alpha ? "out_buf[linear_index + 3] = pixel.w;" : "");
|
||||
|
||||
NSString* library_source =
|
||||
[NSString stringWithUTF8String:shader_source.c_str()];
|
||||
NSError* error = nil;
|
||||
id<MTLLibrary> library =
|
||||
[device newLibraryWithSource:library_source options:nullptr error:&error];
|
||||
RET_CHECK(library != nil) << "Couldn't create shader library "
|
||||
<< [[error localizedDescription] UTF8String];
|
||||
id<MTLFunction> kernel_func = nil;
|
||||
kernel_func = [library newFunctionWithName:@"convertKernel"];
|
||||
RET_CHECK(kernel_func != nil) << "Couldn't create kernel function.";
|
||||
gpu_data_out_->pipeline_state =
|
||||
[device newComputePipelineStateWithFunction:kernel_func error:&error];
|
||||
RET_CHECK(gpu_data_out_->pipeline_state != nil)
|
||||
<< "Couldn't create pipeline state "
|
||||
<< [[error localizedDescription] UTF8String];
|
||||
#endif
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
@ -370,11 +607,23 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
// Get y-flip mode.
|
||||
flip_vertically_ = options.flip_vertically();
|
||||
|
||||
// Get row_major_matrix mode.
|
||||
row_major_matrix_ = options.row_major_matrix();
|
||||
|
||||
// Get desired way to handle input channels.
|
||||
max_num_channels_ = options.max_num_channels();
|
||||
// Currently only alpha channel toggling is suppored.
|
||||
CHECK_GE(max_num_channels_, 3);
|
||||
CHECK_LE(max_num_channels_, 4);
|
||||
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
if (cc->Inputs().HasTag("IMAGE_GPU"))
|
||||
// Currently on iOS, tflite gpu input tensor must be 4 channels,
|
||||
// so input image must be 4 channels also (checked in InitGpu).
|
||||
max_num_channels_ = 4;
|
||||
#endif
|
||||
|
||||
// Get tensor type, float or quantized.
|
||||
use_quantized_tensors_ = options.use_quantized_tensors();
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
@ -415,4 +664,19 @@ template <class T>
|
|||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status TfLiteConverterCalculator::CopyMatrixToTensor(
|
||||
const Matrix& matrix, float* tensor_buffer) {
|
||||
if (row_major_matrix_) {
|
||||
auto matrix_map = Eigen::Map<RowMajorMatrixXf>(tensor_buffer, matrix.rows(),
|
||||
matrix.cols());
|
||||
matrix_map = matrix;
|
||||
} else {
|
||||
auto matrix_map = Eigen::Map<ColMajorMatrixXf>(tensor_buffer, matrix.rows(),
|
||||
matrix.cols());
|
||||
matrix_map = matrix;
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -22,9 +22,10 @@ message TfLiteConverterCalculatorOptions {
|
|||
optional TfLiteConverterCalculatorOptions ext = 245817797;
|
||||
}
|
||||
|
||||
// Choose normalization mode for output:
|
||||
// Choose normalization mode for output (not applied for Matrix inputs).
|
||||
// true = [-1,1]
|
||||
// false = [0,1]
|
||||
// Ignored if using quantization.
|
||||
optional bool zero_center = 1 [default = true];
|
||||
|
||||
// Whether the input image should be flipped vertically (along the
|
||||
|
@ -38,4 +39,12 @@ message TfLiteConverterCalculatorOptions {
|
|||
// tensor. Currently this only controls whether or not to ignore alpha
|
||||
// channel, so it must be 3 or 4.
|
||||
optional int32 max_num_channels = 3 [default = 3];
|
||||
|
||||
// The calculator expects Matrix inputs to be in column-major order. Set
|
||||
// row_major_matrix to true if the inputs are in row-major order.
|
||||
optional bool row_major_matrix = 4 [default = false];
|
||||
|
||||
// Quantization option (CPU only).
|
||||
// When true, output kTfLiteUInt8 tensor instead of kTfLiteFloat32.
|
||||
optional bool use_quantized_tensors = 5 [default = false];
|
||||
}
|
||||
|
|
199
mediapipe/calculators/tflite/tflite_converter_calculator_test.cc
Normal file
199
mediapipe/calculators/tflite/tflite_converter_calculator_test.cc
Normal file
|
@ -0,0 +1,199 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "mediapipe/calculators/tflite/tflite_converter_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
|
||||
#include "mediapipe/framework/tool/validate_type.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kTransposeOptionsString[] =
|
||||
"[mediapipe.TfLiteConverterCalculatorOptions.ext]: {"
|
||||
"row_major_matrix: True}";
|
||||
|
||||
} // namespace
|
||||
|
||||
using RandomEngine = std::mt19937_64;
|
||||
const uint32 kSeed = 1234;
|
||||
const int kNumSizes = 8;
|
||||
const int sizes[kNumSizes][2] = {{1, 1}, {12, 1}, {1, 9}, {2, 2},
|
||||
{5, 3}, {7, 13}, {16, 32}, {101, 2}};
|
||||
|
||||
class TfLiteConverterCalculatorTest : public ::testing::Test {
|
||||
protected:
|
||||
// Adds a packet with a matrix filled with random values in [0,1].
|
||||
void AddRandomMatrix(int num_rows, int num_columns, uint32 seed,
|
||||
bool row_major_matrix = false) {
|
||||
RandomEngine random(kSeed);
|
||||
std::uniform_real_distribution<> uniform_dist(0, 1.0);
|
||||
auto matrix = ::absl::make_unique<Matrix>();
|
||||
matrix->resize(num_rows, num_columns);
|
||||
if (row_major_matrix) {
|
||||
for (int y = 0; y < num_rows; ++y) {
|
||||
for (int x = 0; x < num_columns; ++x) {
|
||||
float value = uniform_dist(random);
|
||||
(*matrix)(y, x) = value;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int x = 0; x < num_columns; ++x) {
|
||||
for (int y = 0; y < num_rows; ++y) {
|
||||
float value = uniform_dist(random);
|
||||
(*matrix)(y, x) = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
MEDIAPIPE_ASSERT_OK(graph_->AddPacketToInputStream(
|
||||
"matrix", Adopt(matrix.release()).At(Timestamp(0))));
|
||||
}
|
||||
|
||||
std::unique_ptr<CalculatorGraph> graph_;
|
||||
};
|
||||
|
||||
TEST_F(TfLiteConverterCalculatorTest, RandomMatrixColMajor) {
|
||||
for (int size_index = 0; size_index < kNumSizes; ++size_index) {
|
||||
const int num_rows = sizes[size_index][0];
|
||||
const int num_columns = sizes[size_index][1];
|
||||
|
||||
// Run the calculator and verify that one output is generated.
|
||||
CalculatorGraphConfig graph_config =
|
||||
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
input_stream: "matrix"
|
||||
node {
|
||||
calculator: "TfLiteConverterCalculator"
|
||||
input_stream: "MATRIX:matrix"
|
||||
output_stream: "TENSORS:tensor"
|
||||
options {
|
||||
[mediapipe.TfLiteConverterCalculatorOptions.ext] {
|
||||
row_major_matrix: false
|
||||
}
|
||||
}
|
||||
}
|
||||
)");
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("tensor", &graph_config, &output_packets);
|
||||
|
||||
// Run the graph.
|
||||
graph_ = absl::make_unique<CalculatorGraph>();
|
||||
MEDIAPIPE_ASSERT_OK(graph_->Initialize(graph_config));
|
||||
MEDIAPIPE_ASSERT_OK(graph_->StartRun({}));
|
||||
|
||||
// Push the tensor into the graph.
|
||||
AddRandomMatrix(num_rows, num_columns, kSeed, /*row_major_matrix=*/false);
|
||||
|
||||
// Wait until the calculator done processing.
|
||||
MEDIAPIPE_ASSERT_OK(graph_->WaitUntilIdle());
|
||||
EXPECT_EQ(1, output_packets.size());
|
||||
|
||||
// Get and process results.
|
||||
const std::vector<TfLiteTensor>& tensor_vec =
|
||||
output_packets[0].Get<std::vector<TfLiteTensor>>();
|
||||
EXPECT_EQ(1, tensor_vec.size());
|
||||
|
||||
const TfLiteTensor* tensor = &tensor_vec[0];
|
||||
EXPECT_EQ(kTfLiteFloat32, tensor->type);
|
||||
|
||||
// Verify that the data is correct.
|
||||
RandomEngine random(kSeed);
|
||||
std::uniform_real_distribution<> uniform_dist(0, 1.0);
|
||||
const float* tensor_buffer = tensor->data.f;
|
||||
for (int i = 0; i < num_rows * num_columns; ++i) {
|
||||
const float expected = uniform_dist(random);
|
||||
EXPECT_EQ(expected, tensor_buffer[i]) << "at i = " << i;
|
||||
}
|
||||
|
||||
// Fully close graph at end, otherwise calculator+tensors are destroyed
|
||||
// after calling WaitUntilDone().
|
||||
MEDIAPIPE_ASSERT_OK(graph_->CloseInputStream("matrix"));
|
||||
MEDIAPIPE_ASSERT_OK(graph_->WaitUntilDone());
|
||||
|
||||
graph_.reset();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TfLiteConverterCalculatorTest, RandomMatrixRowMajor) {
|
||||
for (int size_index = 0; size_index < kNumSizes; ++size_index) {
|
||||
const int num_rows = sizes[size_index][0];
|
||||
const int num_columns = sizes[size_index][1];
|
||||
|
||||
// Run the calculator and verify that one output is generated.
|
||||
CalculatorGraphConfig graph_config =
|
||||
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
input_stream: "matrix"
|
||||
node {
|
||||
calculator: "TfLiteConverterCalculator"
|
||||
input_stream: "MATRIX:matrix"
|
||||
output_stream: "TENSORS:tensor"
|
||||
options {
|
||||
[mediapipe.TfLiteConverterCalculatorOptions.ext] {
|
||||
row_major_matrix: true
|
||||
}
|
||||
}
|
||||
}
|
||||
)");
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("tensor", &graph_config, &output_packets);
|
||||
|
||||
// Run the graph.
|
||||
graph_ = absl::make_unique<CalculatorGraph>();
|
||||
MEDIAPIPE_ASSERT_OK(graph_->Initialize(graph_config));
|
||||
MEDIAPIPE_ASSERT_OK(graph_->StartRun({}));
|
||||
|
||||
// Push the tensor into the graph.
|
||||
AddRandomMatrix(num_rows, num_columns, kSeed, /*row_major_matrix=*/true);
|
||||
|
||||
// Wait until the calculator done processing.
|
||||
MEDIAPIPE_ASSERT_OK(graph_->WaitUntilIdle());
|
||||
EXPECT_EQ(1, output_packets.size());
|
||||
|
||||
// Get and process results.
|
||||
const std::vector<TfLiteTensor>& tensor_vec =
|
||||
output_packets[0].Get<std::vector<TfLiteTensor>>();
|
||||
EXPECT_EQ(1, tensor_vec.size());
|
||||
|
||||
const TfLiteTensor* tensor = &tensor_vec[0];
|
||||
EXPECT_EQ(kTfLiteFloat32, tensor->type);
|
||||
|
||||
// Verify that the data is correct.
|
||||
RandomEngine random(kSeed);
|
||||
std::uniform_real_distribution<> uniform_dist(0, 1.0);
|
||||
const float* tensor_buffer = tensor->data.f;
|
||||
for (int i = 0; i < num_rows * num_columns; ++i) {
|
||||
const float expected = uniform_dist(random);
|
||||
EXPECT_EQ(expected, tensor_buffer[i]) << "at i = " << i;
|
||||
}
|
||||
|
||||
// Fully close graph at end, otherwise calculator+tensors are destroyed
|
||||
// after calling WaitUntilDone().
|
||||
MEDIAPIPE_ASSERT_OK(graph_->CloseInputStream("matrix"));
|
||||
MEDIAPIPE_ASSERT_OK(graph_->WaitUntilDone());
|
||||
|
||||
graph_.reset();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
|
@ -47,6 +47,8 @@ class TfLiteCustomOpResolverCalculator : public CalculatorBase {
|
|||
}
|
||||
|
||||
::mediapipe::Status Open(CalculatorContext* cc) override {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
const TfLiteCustomOpResolverCalculatorOptions& options =
|
||||
cc->Options<TfLiteCustomOpResolverCalculatorOptions>();
|
||||
|
||||
|
|
|
@ -11,7 +11,6 @@
|
|||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
@ -32,17 +31,22 @@
|
|||
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"
|
||||
#endif // ANDROID
|
||||
#endif // __ANDROID__
|
||||
|
||||
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
#if defined(__OBJC__)
|
||||
#import <CoreVideo/CoreVideo.h>
|
||||
#import <Metal/Metal.h>
|
||||
#import <MetalKit/MetalKit.h>
|
||||
#endif // OBJC
|
||||
#import "mediapipe/framework/ios/NSError+util_status.h"
|
||||
#import "mediapipe/gpu/MediaPipeMetalHelper.h"
|
||||
|
||||
#import "mediapipe/gpu/MPPMetalHelper.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
|
||||
#endif // APPLE && !TARGET_OS_OSX
|
||||
#endif // iOS
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
typedef id<MTLBuffer> GpuTensor;
|
||||
#endif
|
||||
|
||||
// TfLiteInferenceCalculator File Layout:
|
||||
// * Header
|
||||
|
@ -56,11 +60,14 @@ using ::tflite::gpu::gl::GlProgram;
|
|||
using ::tflite::gpu::gl::GlShader;
|
||||
struct GPUData {
|
||||
int elements = 1;
|
||||
GlBuffer ssbo;
|
||||
GlShader shader;
|
||||
GlProgram program;
|
||||
GlBuffer buffer;
|
||||
};
|
||||
#endif // ANDROID
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
struct GPUData {
|
||||
int elements = 1;
|
||||
id<MTLBuffer> buffer;
|
||||
};
|
||||
#endif
|
||||
|
||||
// Calculator Header Section
|
||||
|
||||
|
@ -78,12 +85,12 @@ struct GPUData {
|
|||
// GPU.
|
||||
//
|
||||
// Input:
|
||||
// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32
|
||||
// TENSORS_GPU - Vector of GlBuffer (assumed to be RGB image)
|
||||
// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32 or kTfLiteUInt8
|
||||
// TENSORS_GPU - Vector of GlBuffer or MTLBuffer
|
||||
//
|
||||
// Output:
|
||||
// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32
|
||||
// TENSORS_GPU - Vector of GlBuffer
|
||||
// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32 or kTfLiteUInt8
|
||||
// TENSORS_GPU - Vector of GlBuffer or MTLBuffer
|
||||
//
|
||||
// Input side packet:
|
||||
// CUSTOM_OP_RESOLVER (optional) - Use a custom op resolver,
|
||||
|
@ -107,7 +114,7 @@ struct GPUData {
|
|||
// Input tensors are assumed to be of the correct size and already normalized.
|
||||
// All output TfLiteTensors will be destroyed when the graph closes,
|
||||
// (i.e. after calling graph.WaitUntilDone()).
|
||||
// GPU tensors are currently only supported on Android.
|
||||
// GPU tensors are currently only supported on Android and iOS.
|
||||
// This calculator uses FixedSizeInputStreamHandler by default.
|
||||
//
|
||||
class TfLiteInferenceCalculator : public CalculatorBase {
|
||||
|
@ -131,40 +138,41 @@ class TfLiteInferenceCalculator : public CalculatorBase {
|
|||
mediapipe::GlCalculatorHelper gpu_helper_;
|
||||
std::unique_ptr<GPUData> gpu_data_in_;
|
||||
std::vector<std::unique_ptr<GPUData>> gpu_data_out_;
|
||||
#endif
|
||||
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
MediaPipeMetalHelper* gpu_helper_ = nullptr;
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
MPPMetalHelper* gpu_helper_ = nullptr;
|
||||
std::unique_ptr<GPUData> gpu_data_in_;
|
||||
std::vector<std::unique_ptr<GPUData>> gpu_data_out_;
|
||||
#endif
|
||||
|
||||
std::string model_path_ = "";
|
||||
bool gpu_inference_ = false;
|
||||
bool gpu_input_ = false;
|
||||
bool gpu_output_ = false;
|
||||
}; // TfLiteInferenceCalculator
|
||||
|
||||
bool use_quantized_tensors_ = false;
|
||||
};
|
||||
REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
||||
|
||||
// Calculator Core Section
|
||||
|
||||
::mediapipe::Status TfLiteInferenceCalculator::GetContract(
|
||||
CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().HasTag("TENSORS") ||
|
||||
RET_CHECK(cc->Inputs().HasTag("TENSORS") ^
|
||||
cc->Inputs().HasTag("TENSORS_GPU"));
|
||||
RET_CHECK(cc->Outputs().HasTag("TENSORS") ||
|
||||
RET_CHECK(cc->Outputs().HasTag("TENSORS") ^
|
||||
cc->Outputs().HasTag("TENSORS_GPU"));
|
||||
|
||||
if (cc->Inputs().HasTag("TENSORS"))
|
||||
cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
if (cc->Inputs().HasTag("TENSORS_GPU"))
|
||||
cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GlBuffer>>();
|
||||
cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
|
||||
#endif
|
||||
|
||||
if (cc->Outputs().HasTag("TENSORS"))
|
||||
cc->Outputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
if (cc->Outputs().HasTag("TENSORS_GPU"))
|
||||
cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GlBuffer>>();
|
||||
cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
|
||||
#endif
|
||||
|
||||
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
|
||||
|
@ -176,7 +184,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
#if defined(__ANDROID__)
|
||||
RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
RETURN_IF_ERROR([MediaPipeMetalHelper updateContract:cc]);
|
||||
RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
|
||||
#endif
|
||||
|
||||
// Assign this calculator's default InputStreamHandler.
|
||||
|
@ -186,26 +194,26 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
}
|
||||
|
||||
::mediapipe::Status TfLiteInferenceCalculator::Open(CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
RETURN_IF_ERROR(LoadOptions(cc));
|
||||
|
||||
if (cc->Inputs().HasTag("TENSORS_GPU")) {
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
gpu_input_ = true;
|
||||
gpu_inference_ = true; // Inference must be on GPU also.
|
||||
#else
|
||||
RET_CHECK(!cc->Inputs().HasTag("TENSORS_GPU"))
|
||||
<< "GPU input for non-Android not supported yet.";
|
||||
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only.";
|
||||
#endif
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("TENSORS_GPU")) {
|
||||
#if defined(__ANDROID__)
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
gpu_output_ = true;
|
||||
RET_CHECK(cc->Inputs().HasTag("TENSORS_GPU"))
|
||||
<< "GPU output must also have GPU Input.";
|
||||
#else
|
||||
RET_CHECK(!cc->Inputs().HasTag("TENSORS_GPU"))
|
||||
<< "GPU output for non-Android not supported yet.";
|
||||
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only.";
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -215,7 +223,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
#if defined(__ANDROID__)
|
||||
RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
gpu_helper_ = [[MediaPipeMetalHelper alloc] initWithCalculatorContext:cc];
|
||||
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
|
||||
RET_CHECK(gpu_helper_);
|
||||
#endif
|
||||
|
||||
|
@ -226,24 +234,38 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
}
|
||||
|
||||
::mediapipe::Status TfLiteInferenceCalculator::Process(CalculatorContext* cc) {
|
||||
// Receive pre-processed tensor inputs.
|
||||
// 1. Receive pre-processed tensor inputs.
|
||||
if (gpu_input_) {
|
||||
// Read GPU input into SSBO.
|
||||
#if defined(__ANDROID__)
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GlBuffer>>();
|
||||
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GpuTensor>>();
|
||||
RET_CHECK_EQ(input_tensors.size(), 1);
|
||||
RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
|
||||
[this, &input_tensors]() -> ::mediapipe::Status {
|
||||
// Explicit copy input.
|
||||
tflite::gpu::gl::CopyBuffer(input_tensors[0], gpu_data_in_->ssbo);
|
||||
// Run inference.
|
||||
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||
tflite::gpu::gl::CopyBuffer(input_tensors[0], gpu_data_in_->buffer);
|
||||
return ::mediapipe::OkStatus();
|
||||
}));
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GpuTensor>>();
|
||||
RET_CHECK_EQ(input_tensors.size(), 1);
|
||||
id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
|
||||
command_buffer.label = @"TfLiteInferenceCalculatorInput";
|
||||
id<MTLBlitCommandEncoder> blit_command =
|
||||
[command_buffer blitCommandEncoder];
|
||||
// Explicit copy input.
|
||||
[blit_command copyFromBuffer:input_tensors[0]
|
||||
sourceOffset:0
|
||||
toBuffer:gpu_data_in_->buffer
|
||||
destinationOffset:0
|
||||
size:gpu_data_in_->elements * sizeof(float)];
|
||||
[blit_command endEncoding];
|
||||
[command_buffer commit];
|
||||
[command_buffer waitUntilCompleted];
|
||||
#else
|
||||
RET_CHECK_FAIL()
|
||||
<< "GPU input on non-Android devices is not supported yet.";
|
||||
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only.";
|
||||
#endif
|
||||
} else {
|
||||
// Read CPU input into tensors.
|
||||
|
@ -252,20 +274,23 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
RET_CHECK_GT(input_tensors.size(), 0);
|
||||
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||
const TfLiteTensor* input_tensor = &input_tensors[i];
|
||||
RET_CHECK(input_tensor->data.raw);
|
||||
if (use_quantized_tensors_) {
|
||||
const uint8* input_tensor_buffer = input_tensor->data.uint8;
|
||||
uint8* local_tensor_buffer = interpreter_->typed_input_tensor<uint8>(i);
|
||||
memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor->bytes);
|
||||
} else {
|
||||
const float* input_tensor_buffer = input_tensor->data.f;
|
||||
RET_CHECK(input_tensor_buffer);
|
||||
|
||||
float* local_tensor_buffer = interpreter_->typed_input_tensor<float>(i);
|
||||
RET_CHECK(local_tensor_buffer);
|
||||
|
||||
memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor->bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run inference.
|
||||
// 2. Run inference.
|
||||
if (gpu_inference_) {
|
||||
#if defined(__ANDROID__)
|
||||
RETURN_IF_ERROR(
|
||||
gpu_helper_.RunInGlContext([this]() -> ::mediapipe::Status {
|
||||
RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> ::mediapipe::Status {
|
||||
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||
return ::mediapipe::OkStatus();
|
||||
}));
|
||||
|
@ -275,12 +300,12 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
} else {
|
||||
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Output processed tensors.
|
||||
if (gpu_output_) {
|
||||
#if defined(__ANDROID__)
|
||||
// Output result tensors (GPU).
|
||||
auto output_tensors = absl::make_unique<std::vector<GlBuffer>>();
|
||||
auto output_tensors = absl::make_unique<std::vector<GpuTensor>>();
|
||||
output_tensors->resize(gpu_data_out_.size());
|
||||
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
||||
GlBuffer& tensor = output_tensors->at(i);
|
||||
|
@ -290,13 +315,39 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
if (!status.ok()) {
|
||||
return ::mediapipe::InternalError(status.error_message());
|
||||
}
|
||||
tflite::gpu::gl::CopyBuffer(gpu_data_out_[i]->ssbo, tensor);
|
||||
tflite::gpu::gl::CopyBuffer(gpu_data_out_[i]->buffer, tensor);
|
||||
}
|
||||
cc->Outputs()
|
||||
.Tag("TENSORS_GPU")
|
||||
.Add(output_tensors.release(), cc->InputTimestamp());
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
// Output result tensors (GPU).
|
||||
auto output_tensors = absl::make_unique<std::vector<GpuTensor>>();
|
||||
id<MTLDevice> device = gpu_helper_.mtlDevice;
|
||||
id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
|
||||
command_buffer.label = @"TfLiteInferenceCalculatorOutput";
|
||||
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
||||
id<MTLBuffer> tensor =
|
||||
[device newBufferWithLength:gpu_data_out_[i]->elements * sizeof(float)
|
||||
options:MTLResourceStorageModeShared];
|
||||
id<MTLBlitCommandEncoder> blit_command =
|
||||
[command_buffer blitCommandEncoder];
|
||||
// Explicit copy input.
|
||||
[blit_command copyFromBuffer:gpu_data_out_[i]->buffer
|
||||
sourceOffset:0
|
||||
toBuffer:tensor
|
||||
destinationOffset:0
|
||||
size:gpu_data_out_[i]->elements * sizeof(float)];
|
||||
[blit_command endEncoding];
|
||||
[command_buffer commit];
|
||||
[command_buffer waitUntilCompleted];
|
||||
output_tensors->push_back(tensor);
|
||||
}
|
||||
cc->Outputs()
|
||||
.Tag("TENSORS_GPU")
|
||||
.Add(output_tensors.release(), cc->InputTimestamp());
|
||||
#else
|
||||
LOG(ERROR) << "GPU output on non-Android not supported yet.";
|
||||
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only.";
|
||||
#endif
|
||||
} else {
|
||||
// Output result tensors (CPU).
|
||||
|
@ -325,8 +376,13 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
return ::mediapipe::OkStatus();
|
||||
}));
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
DeleteGpuDelegate(delegate_);
|
||||
TFLGpuDelegateDelete(delegate_);
|
||||
gpu_data_in_.reset();
|
||||
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
||||
gpu_data_out_[i].reset();
|
||||
}
|
||||
#endif
|
||||
delegate_ = nullptr;
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
@ -341,8 +397,9 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
|
||||
// Get model name.
|
||||
if (!options.model_path().empty()) {
|
||||
ASSIGN_OR_RETURN(model_path_,
|
||||
mediapipe::PathToResourceAsFile(options.model_path()));
|
||||
auto model_path = options.model_path();
|
||||
|
||||
ASSIGN_OR_RETURN(model_path_, mediapipe::PathToResourceAsFile(model_path));
|
||||
} else {
|
||||
LOG(ERROR) << "Must specify path to TFLite model.";
|
||||
return ::mediapipe::Status(::mediapipe::StatusCode::kNotFound,
|
||||
|
@ -360,11 +417,6 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
model_ = tflite::FlatBufferModel::BuildFromFile(model_path_.c_str());
|
||||
RET_CHECK(model_);
|
||||
|
||||
#if !defined(__ANDROID__) && !(defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
LOG(WARNING) << "GPU only supported on mobile platforms. Using CPU fallback.";
|
||||
gpu_inference_ = false;
|
||||
#endif
|
||||
|
||||
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
|
||||
const auto& op_resolver =
|
||||
cc->InputSidePackets()
|
||||
|
@ -378,8 +430,13 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
|
||||
RET_CHECK(interpreter_);
|
||||
|
||||
if (!gpu_output_) {
|
||||
if (gpu_output_) {
|
||||
use_quantized_tensors_ = false;
|
||||
} else {
|
||||
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
|
||||
use_quantized_tensors_ = (interpreter_->tensor(0)->quantization.type ==
|
||||
kTfLiteAffineQuantization);
|
||||
if (use_quantized_tensors_) gpu_inference_ = false;
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
|
@ -388,12 +445,20 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
::mediapipe::Status TfLiteInferenceCalculator::LoadDelegate(
|
||||
CalculatorContext* cc) {
|
||||
#if defined(__ANDROID__)
|
||||
// Get input image sizes.
|
||||
// Configure and create the delegate.
|
||||
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
|
||||
options.compile_options.precision_loss_allowed = 1;
|
||||
options.compile_options.preferred_gl_object_type =
|
||||
TFLITE_GL_OBJECT_TYPE_FASTEST;
|
||||
options.compile_options.dynamic_batch_enabled = 0;
|
||||
options.compile_options.inline_parameters = 1;
|
||||
if (!delegate_) delegate_ = TfLiteGpuDelegateCreate(&options);
|
||||
|
||||
if (gpu_input_) {
|
||||
// Get input image sizes.
|
||||
gpu_data_in_ = absl::make_unique<GPUData>();
|
||||
const auto& input_indices = interpreter_->inputs();
|
||||
// TODO accept > 1.
|
||||
RET_CHECK_EQ(input_indices.size(), 1);
|
||||
RET_CHECK_EQ(input_indices.size(), 1); // TODO accept > 1.
|
||||
const TfLiteTensor* tensor = interpreter_->tensor(input_indices[0]);
|
||||
gpu_data_in_->elements = 1;
|
||||
for (int d = 0; d < tensor->dims->size; ++d) {
|
||||
|
@ -402,9 +467,19 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
// Input to model can be either RGB/RGBA only.
|
||||
RET_CHECK_GE(tensor->dims->data[3], 3);
|
||||
RET_CHECK_LE(tensor->dims->data[3], 4);
|
||||
// Create and bind input buffer.
|
||||
auto status = ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer<float>(
|
||||
gpu_data_in_->elements, &gpu_data_in_->buffer);
|
||||
if (!status.ok()) {
|
||||
return ::mediapipe::InternalError(status.error_message());
|
||||
}
|
||||
RET_CHECK_EQ(TfLiteGpuDelegateBindBufferToTensor(
|
||||
delegate_, gpu_data_in_->buffer.id(),
|
||||
interpreter_->inputs()[0]), // First tensor only
|
||||
kTfLiteOk);
|
||||
}
|
||||
// Get output image sizes.
|
||||
if (gpu_output_) {
|
||||
// Get output image sizes.
|
||||
const auto& output_indices = interpreter_->outputs();
|
||||
gpu_data_out_.resize(output_indices.size());
|
||||
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
||||
|
@ -416,55 +491,90 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
gpu_data_out_[i]->elements *= tensor->dims->data[d];
|
||||
}
|
||||
}
|
||||
}
|
||||
// Configure and create the delegate.
|
||||
TfLiteGpuDelegateOptions options;
|
||||
options.metadata = nullptr;
|
||||
options.compile_options.precision_loss_allowed = 1;
|
||||
options.compile_options.preferred_gl_object_type =
|
||||
TFLITE_GL_OBJECT_TYPE_FASTEST;
|
||||
options.compile_options.dynamic_batch_enabled = 0;
|
||||
if (!delegate_) delegate_ = TfLiteGpuDelegateCreate(&options);
|
||||
// Shader to convert GL texture to SSBO.
|
||||
if (gpu_input_) {
|
||||
auto status = ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer<float>(
|
||||
gpu_data_in_->elements, &gpu_data_in_->ssbo);
|
||||
if (!status.ok()) {
|
||||
return ::mediapipe::InternalError(status.error_message());
|
||||
}
|
||||
RET_CHECK_EQ(TfLiteGpuDelegateBindBufferToTensor(
|
||||
delegate_, gpu_data_in_->ssbo.id(),
|
||||
interpreter_->inputs()[0]), // First tensor only
|
||||
kTfLiteOk);
|
||||
}
|
||||
// Create output SSBO buffers.
|
||||
if (gpu_output_) {
|
||||
// Create and bind output buffers.
|
||||
interpreter_->SetAllowBufferHandleOutput(true);
|
||||
const auto& output_indices = interpreter_->outputs();
|
||||
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
||||
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
|
||||
auto status = CreateReadWriteShaderStorageBuffer<float>(
|
||||
gpu_data_out_[i]->elements, &gpu_data_out_[i]->ssbo);
|
||||
gpu_data_out_[i]->elements, &gpu_data_out_[i]->buffer);
|
||||
if (!status.ok()) {
|
||||
return ::mediapipe::InternalError(status.error_message());
|
||||
}
|
||||
RET_CHECK_EQ(
|
||||
TfLiteGpuDelegateBindBufferToTensor(
|
||||
delegate_, gpu_data_out_[i]->ssbo.id(), output_indices[i]),
|
||||
delegate_, gpu_data_out_[i]->buffer.id(), output_indices[i]),
|
||||
kTfLiteOk);
|
||||
}
|
||||
}
|
||||
|
||||
// Must call this last.
|
||||
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_), kTfLiteOk);
|
||||
return ::mediapipe::OkStatus();
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
#endif // __ANDROID__
|
||||
|
||||
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
// Configure and create the delegate.
|
||||
GpuDelegateOptions options;
|
||||
options.allow_precision_loss = 1;
|
||||
options.wait_type = GpuDelegateOptions::WaitType::kPassive;
|
||||
if (!delegate_) delegate_ = NewGpuDelegate(&options);
|
||||
options.allow_precision_loss = false; // Must match converter, F=float/T=half
|
||||
options.wait_type = GpuDelegateOptions::WaitType::kActive;
|
||||
if (!delegate_) delegate_ = TFLGpuDelegateCreate(&options);
|
||||
|
||||
if (gpu_input_) {
|
||||
// Get input image sizes.
|
||||
gpu_data_in_ = absl::make_unique<GPUData>();
|
||||
const auto& input_indices = interpreter_->inputs();
|
||||
RET_CHECK_EQ(input_indices.size(), 1);
|
||||
const TfLiteTensor* tensor = interpreter_->tensor(input_indices[0]);
|
||||
gpu_data_in_->elements = 1;
|
||||
// On iOS GPU, input must be 4 channels, regardless of what model expects.
|
||||
{
|
||||
gpu_data_in_->elements *= tensor->dims->data[0]; // batch
|
||||
gpu_data_in_->elements *= tensor->dims->data[1]; // height
|
||||
gpu_data_in_->elements *= tensor->dims->data[2]; // width
|
||||
gpu_data_in_->elements *= 4; // channels
|
||||
}
|
||||
// Input to model can be RGBA only.
|
||||
if (tensor->dims->data[3] != 4) {
|
||||
LOG(WARNING) << "Please ensure input GPU tensor is 4 channels.";
|
||||
}
|
||||
// Create and bind input buffer.
|
||||
id<MTLDevice> device = gpu_helper_.mtlDevice;
|
||||
gpu_data_in_->buffer =
|
||||
[device newBufferWithLength:gpu_data_in_->elements * sizeof(float)
|
||||
options:MTLResourceStorageModeShared];
|
||||
// Must call this before TFLGpuDelegateBindMetalBufferToTensor.
|
||||
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_), kTfLiteOk);
|
||||
return ::mediapipe::OkStatus();
|
||||
#endif // ANDROID or iOS
|
||||
RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor(
|
||||
delegate_,
|
||||
input_indices[0], // First tensor only
|
||||
gpu_data_in_->buffer),
|
||||
true);
|
||||
}
|
||||
if (gpu_output_) {
|
||||
// Get output image sizes.
|
||||
const auto& output_indices = interpreter_->outputs();
|
||||
gpu_data_out_.resize(output_indices.size());
|
||||
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
||||
const TfLiteTensor* tensor = interpreter_->tensor(output_indices[i]);
|
||||
gpu_data_out_[i] = absl::make_unique<GPUData>();
|
||||
gpu_data_out_[i]->elements = 1;
|
||||
// TODO handle *2 properly on some dialated models
|
||||
for (int d = 0; d < tensor->dims->size; ++d) {
|
||||
gpu_data_out_[i]->elements *= tensor->dims->data[d];
|
||||
}
|
||||
}
|
||||
// Create and bind output buffers.
|
||||
interpreter_->SetAllowBufferHandleOutput(true);
|
||||
id<MTLDevice> device = gpu_helper_.mtlDevice;
|
||||
for (int i = 0; i < gpu_data_out_.size(); ++i) {
|
||||
gpu_data_out_[i]->buffer =
|
||||
[device newBufferWithLength:gpu_data_out_[i]->elements * sizeof(float)
|
||||
options:MTLResourceStorageModeShared];
|
||||
RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor(
|
||||
delegate_, output_indices[i], gpu_data_out_[i]->buffer),
|
||||
true);
|
||||
}
|
||||
}
|
||||
#endif // iOS
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
|
|
@ -120,11 +120,19 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase {
|
|||
::mediapipe::Status Close(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
::mediapipe::Status ProcessCPU(CalculatorContext* cc,
|
||||
std::vector<Detection>* output_detections);
|
||||
::mediapipe::Status ProcessGPU(CalculatorContext* cc,
|
||||
std::vector<Detection>* output_detections);
|
||||
|
||||
::mediapipe::Status LoadOptions(CalculatorContext* cc);
|
||||
::mediapipe::Status GlSetup(CalculatorContext* cc);
|
||||
::mediapipe::Status DecodeBoxes(const float* raw_boxes,
|
||||
const std::vector<Anchor>& anchors,
|
||||
std::vector<float>* boxes);
|
||||
::mediapipe::Status ConvertToDetections(
|
||||
const float* detection_boxes, const float* detection_scores,
|
||||
const int* detection_classes, std::vector<Detection>* output_detections);
|
||||
Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax,
|
||||
float box_xmax, float score, int class_id,
|
||||
bool flip_vertically);
|
||||
|
@ -136,6 +144,7 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase {
|
|||
|
||||
::mediapipe::TfLiteTensorsToDetectionsCalculatorOptions options_;
|
||||
std::vector<Anchor> anchors_;
|
||||
bool side_packet_anchors_{};
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
mediapipe::GlCalculatorHelper gpu_helper_;
|
||||
|
@ -187,6 +196,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
|
||||
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Open(
|
||||
CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
if (cc->Inputs().HasTag("TENSORS_GPU")) {
|
||||
gpu_input_ = true;
|
||||
#if defined(__ANDROID__)
|
||||
|
@ -195,6 +206,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
}
|
||||
|
||||
RETURN_IF_ERROR(LoadOptions(cc));
|
||||
side_packet_anchors_ = cc->InputSidePackets().HasTag("ANCHORS");
|
||||
|
||||
if (gpu_input_) {
|
||||
RETURN_IF_ERROR(GlSetup(cc));
|
||||
|
@ -210,72 +222,32 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
const bool side_packet_anchors =
|
||||
cc->InputSidePackets().HasTag("ANCHORS") &&
|
||||
!cc->InputSidePackets().Tag("ANCHORS").IsEmpty();
|
||||
auto output_detections = absl::make_unique<std::vector<Detection>>();
|
||||
|
||||
std::vector<float> boxes(num_boxes_ * num_coords_);
|
||||
std::vector<float> score_class_id_pairs(num_boxes_ * 2);
|
||||
|
||||
if (gpu_input_) {
|
||||
#if defined(__ANDROID__)
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GlBuffer>>();
|
||||
|
||||
// Copy inputs.
|
||||
tflite::gpu::gl::CopyBuffer(input_tensors[0], *raw_boxes_buffer_.get());
|
||||
tflite::gpu::gl::CopyBuffer(input_tensors[1], *raw_scores_buffer_.get());
|
||||
if (!anchors_init_) {
|
||||
if (side_packet_anchors) {
|
||||
const auto& anchors =
|
||||
cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>();
|
||||
std::vector<float> raw_anchors(num_boxes_ * kNumCoordsPerBox);
|
||||
ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors.data());
|
||||
raw_anchors_buffer_->Write<float>(absl::MakeSpan(raw_anchors));
|
||||
RETURN_IF_ERROR(ProcessGPU(cc, output_detections.get()));
|
||||
} else {
|
||||
CHECK_EQ(input_tensors.size(), 3);
|
||||
tflite::gpu::gl::CopyBuffer(input_tensors[2],
|
||||
*raw_anchors_buffer_.get());
|
||||
}
|
||||
anchors_init_ = true;
|
||||
}
|
||||
RETURN_IF_ERROR(ProcessCPU(cc, output_detections.get()));
|
||||
} // if gpu_input_
|
||||
|
||||
// Run shaders.
|
||||
RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
|
||||
[this, &input_tensors]() -> ::mediapipe::Status {
|
||||
// Decode boxes.
|
||||
decoded_boxes_buffer_->BindToIndex(0);
|
||||
raw_boxes_buffer_->BindToIndex(1);
|
||||
raw_anchors_buffer_->BindToIndex(2);
|
||||
const tflite::gpu::uint3 decode_workgroups = {num_boxes_, 1, 1};
|
||||
decode_program_->Dispatch(decode_workgroups);
|
||||
|
||||
// Score boxes.
|
||||
scored_boxes_buffer_->BindToIndex(0);
|
||||
raw_scores_buffer_->BindToIndex(1);
|
||||
const tflite::gpu::uint3 score_workgroups = {num_boxes_, 1, 1};
|
||||
score_program_->Dispatch(score_workgroups);
|
||||
// Output
|
||||
if (cc->Outputs().HasTag("DETECTIONS")) {
|
||||
cc->Outputs()
|
||||
.Tag("DETECTIONS")
|
||||
.Add(output_detections.release(), cc->InputTimestamp());
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}));
|
||||
}
|
||||
|
||||
// Copy decoded boxes from GPU to CPU.
|
||||
auto status = decoded_boxes_buffer_->Read(absl::MakeSpan(boxes));
|
||||
if (!status.ok()) {
|
||||
return ::mediapipe::InternalError(status.error_message());
|
||||
}
|
||||
status = scored_boxes_buffer_->Read(absl::MakeSpan(score_class_id_pairs));
|
||||
if (!status.ok()) {
|
||||
return ::mediapipe::InternalError(status.error_message());
|
||||
}
|
||||
#else
|
||||
LOG(ERROR) << "GPU input on non-Android not supported yet.";
|
||||
#endif // defined(__ANDROID__)
|
||||
} else {
|
||||
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessCPU(
|
||||
CalculatorContext* cc, std::vector<Detection>* output_detections) {
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag("TENSORS").Get<std::vector<TfLiteTensor>>();
|
||||
|
||||
if (input_tensors.size() == 2) {
|
||||
// Postprocessing on CPU for model without postprocessing op. E.g. output
|
||||
// raw score tensor and box tensor. Anchor decoding will be handled below.
|
||||
const TfLiteTensor* raw_box_tensor = &input_tensors[0];
|
||||
const TfLiteTensor* raw_score_tensor = &input_tensors[1];
|
||||
|
||||
|
@ -300,7 +272,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
CHECK_EQ(anchor_tensor->dims->data[1], kNumCoordsPerBox);
|
||||
const float* raw_anchors = anchor_tensor->data.f;
|
||||
ConvertRawValuesToAnchors(raw_anchors, num_boxes_, &anchors_);
|
||||
} else if (side_packet_anchors) {
|
||||
} else if (side_packet_anchors_) {
|
||||
CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty());
|
||||
anchors_ =
|
||||
cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>();
|
||||
} else {
|
||||
|
@ -308,8 +281,12 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
}
|
||||
anchors_init_ = true;
|
||||
}
|
||||
std::vector<float> boxes(num_boxes_ * num_coords_);
|
||||
RETURN_IF_ERROR(DecodeBoxes(raw_boxes, anchors_, &boxes));
|
||||
|
||||
std::vector<float> detection_scores(num_boxes_);
|
||||
std::vector<int> detection_classes(num_boxes_);
|
||||
|
||||
// Filter classes by scores.
|
||||
for (int i = 0; i < num_boxes_; ++i) {
|
||||
int class_id = -1;
|
||||
|
@ -335,44 +312,119 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
}
|
||||
}
|
||||
}
|
||||
score_class_id_pairs[i * 2 + 0] = max_score;
|
||||
score_class_id_pairs[i * 2 + 1] = class_id;
|
||||
detection_scores[i] = max_score;
|
||||
detection_classes[i] = class_id;
|
||||
}
|
||||
} // if gpu_input_
|
||||
|
||||
// Convert to Detection.
|
||||
RETURN_IF_ERROR(ConvertToDetections(boxes.data(), detection_scores.data(),
|
||||
detection_classes.data(),
|
||||
output_detections));
|
||||
} else {
|
||||
// Postprocessing on CPU with postprocessing op (e.g. anchor decoding and
|
||||
// non-maximum suppression) within the model.
|
||||
RET_CHECK_EQ(input_tensors.size(), 4);
|
||||
|
||||
const TfLiteTensor* detection_boxes_tensor = &input_tensors[0];
|
||||
const TfLiteTensor* detection_classes_tensor = &input_tensors[1];
|
||||
const TfLiteTensor* detection_scores_tensor = &input_tensors[2];
|
||||
const TfLiteTensor* num_boxes_tensor = &input_tensors[3];
|
||||
RET_CHECK_EQ(num_boxes_tensor->dims->size, 1);
|
||||
RET_CHECK_EQ(num_boxes_tensor->dims->data[0], 1);
|
||||
const float* num_boxes = num_boxes_tensor->data.f;
|
||||
num_boxes_ = num_boxes[0];
|
||||
RET_CHECK_EQ(detection_boxes_tensor->dims->size, 3);
|
||||
RET_CHECK_EQ(detection_boxes_tensor->dims->data[0], 1);
|
||||
const int max_detections = detection_boxes_tensor->dims->data[1];
|
||||
RET_CHECK_EQ(detection_boxes_tensor->dims->data[2], num_coords_);
|
||||
RET_CHECK_EQ(detection_classes_tensor->dims->size, 2);
|
||||
RET_CHECK_EQ(detection_classes_tensor->dims->data[0], 1);
|
||||
RET_CHECK_EQ(detection_classes_tensor->dims->data[1], max_detections);
|
||||
RET_CHECK_EQ(detection_scores_tensor->dims->size, 2);
|
||||
RET_CHECK_EQ(detection_scores_tensor->dims->data[0], 1);
|
||||
RET_CHECK_EQ(detection_scores_tensor->dims->data[1], max_detections);
|
||||
|
||||
const float* detection_boxes = detection_boxes_tensor->data.f;
|
||||
const float* detection_scores = detection_scores_tensor->data.f;
|
||||
std::vector<int> detection_classes(num_boxes_);
|
||||
for (int i = 0; i < num_boxes_; ++i) {
|
||||
const float score = score_class_id_pairs[i * 2 + 0];
|
||||
const int class_id = score_class_id_pairs[i * 2 + 1];
|
||||
const int box_offset = i * num_coords_;
|
||||
Detection detection = ConvertToDetection(
|
||||
boxes[box_offset + 0], boxes[box_offset + 1], boxes[box_offset + 2],
|
||||
boxes[box_offset + 3], score, class_id, options_.flip_vertically());
|
||||
// Add keypoints.
|
||||
if (options_.num_keypoints() > 0) {
|
||||
auto* location_data = detection.mutable_location_data();
|
||||
for (int kp_id = 0; kp_id < options_.num_keypoints() *
|
||||
options_.num_values_per_keypoint();
|
||||
kp_id += options_.num_values_per_keypoint()) {
|
||||
auto keypoint = location_data->add_relative_keypoints();
|
||||
const int keypoint_index =
|
||||
box_offset + options_.keypoint_coord_offset() + kp_id;
|
||||
keypoint->set_x(boxes[keypoint_index + 0]);
|
||||
keypoint->set_y(options_.flip_vertically()
|
||||
? 1.f - boxes[keypoint_index + 1]
|
||||
: boxes[keypoint_index + 1]);
|
||||
detection_classes[i] =
|
||||
static_cast<int>(detection_classes_tensor->data.f[i]);
|
||||
}
|
||||
RETURN_IF_ERROR(ConvertToDetections(detection_boxes, detection_scores,
|
||||
detection_classes.data(),
|
||||
output_detections));
|
||||
}
|
||||
output_detections->emplace_back(detection);
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU(
|
||||
CalculatorContext* cc, std::vector<Detection>* output_detections) {
|
||||
#if defined(__ANDROID__)
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GlBuffer>>();
|
||||
|
||||
// Copy inputs.
|
||||
tflite::gpu::gl::CopyBuffer(input_tensors[0], *raw_boxes_buffer_.get());
|
||||
tflite::gpu::gl::CopyBuffer(input_tensors[1], *raw_scores_buffer_.get());
|
||||
if (!anchors_init_) {
|
||||
if (side_packet_anchors_) {
|
||||
CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty());
|
||||
const auto& anchors =
|
||||
cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>();
|
||||
std::vector<float> raw_anchors(num_boxes_ * kNumCoordsPerBox);
|
||||
ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors.data());
|
||||
raw_anchors_buffer_->Write<float>(absl::MakeSpan(raw_anchors));
|
||||
} else {
|
||||
CHECK_EQ(input_tensors.size(), 3);
|
||||
tflite::gpu::gl::CopyBuffer(input_tensors[2], *raw_anchors_buffer_.get());
|
||||
}
|
||||
anchors_init_ = true;
|
||||
}
|
||||
|
||||
// Output
|
||||
if (cc->Outputs().HasTag("DETECTIONS")) {
|
||||
cc->Outputs()
|
||||
.Tag("DETECTIONS")
|
||||
.Add(output_detections.release(), cc->InputTimestamp());
|
||||
// Run shaders.
|
||||
RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
|
||||
[this, &input_tensors]() -> ::mediapipe::Status {
|
||||
// Decode boxes.
|
||||
decoded_boxes_buffer_->BindToIndex(0);
|
||||
raw_boxes_buffer_->BindToIndex(1);
|
||||
raw_anchors_buffer_->BindToIndex(2);
|
||||
const tflite::gpu::uint3 decode_workgroups = {num_boxes_, 1, 1};
|
||||
decode_program_->Dispatch(decode_workgroups);
|
||||
|
||||
// Score boxes.
|
||||
scored_boxes_buffer_->BindToIndex(0);
|
||||
raw_scores_buffer_->BindToIndex(1);
|
||||
const tflite::gpu::uint3 score_workgroups = {num_boxes_, 1, 1};
|
||||
score_program_->Dispatch(score_workgroups);
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}));
|
||||
|
||||
// Copy decoded boxes from GPU to CPU.
|
||||
std::vector<float> boxes(num_boxes_ * num_coords_);
|
||||
auto status = decoded_boxes_buffer_->Read(absl::MakeSpan(boxes));
|
||||
if (!status.ok()) {
|
||||
return ::mediapipe::InternalError(status.error_message());
|
||||
}
|
||||
std::vector<float> score_class_id_pairs(num_boxes_ * 2);
|
||||
status = scored_boxes_buffer_->Read(absl::MakeSpan(score_class_id_pairs));
|
||||
if (!status.ok()) {
|
||||
return ::mediapipe::InternalError(status.error_message());
|
||||
}
|
||||
|
||||
// TODO: b/138851969. Is it possible to output a float vector
|
||||
// for score and an int vector for class so that we can avoid copying twice?
|
||||
std::vector<float> detection_scores(num_boxes_);
|
||||
std::vector<int> detection_classes(num_boxes_);
|
||||
for (int i = 0; i < num_boxes_; ++i) {
|
||||
detection_scores[i] = score_class_id_pairs[i * 2];
|
||||
detection_classes[i] = static_cast<int>(score_class_id_pairs[i * 2 + 1]);
|
||||
}
|
||||
RETURN_IF_ERROR(ConvertToDetections(boxes.data(), detection_scores.data(),
|
||||
detection_classes.data(),
|
||||
output_detections));
|
||||
#else
|
||||
LOG(ERROR) << "GPU input on non-Android not supported yet.";
|
||||
#endif // defined(__ANDROID__)
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -481,6 +533,39 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ConvertToDetections(
|
||||
const float* detection_boxes, const float* detection_scores,
|
||||
const int* detection_classes, std::vector<Detection>* output_detections) {
|
||||
for (int i = 0; i < num_boxes_; ++i) {
|
||||
if (options_.has_min_score_thresh() &&
|
||||
detection_scores[i] < options_.min_score_thresh()) {
|
||||
continue;
|
||||
}
|
||||
const int box_offset = i * num_coords_;
|
||||
Detection detection = ConvertToDetection(
|
||||
detection_boxes[box_offset + 0], detection_boxes[box_offset + 1],
|
||||
detection_boxes[box_offset + 2], detection_boxes[box_offset + 3],
|
||||
detection_scores[i], detection_classes[i], options_.flip_vertically());
|
||||
// Add keypoints.
|
||||
if (options_.num_keypoints() > 0) {
|
||||
auto* location_data = detection.mutable_location_data();
|
||||
for (int kp_id = 0; kp_id < options_.num_keypoints() *
|
||||
options_.num_values_per_keypoint();
|
||||
kp_id += options_.num_values_per_keypoint()) {
|
||||
auto keypoint = location_data->add_relative_keypoints();
|
||||
const int keypoint_index =
|
||||
box_offset + options_.keypoint_coord_offset() + kp_id;
|
||||
keypoint->set_x(detection_boxes[keypoint_index + 0]);
|
||||
keypoint->set_y(options_.flip_vertically()
|
||||
? 1.f - detection_boxes[keypoint_index + 1]
|
||||
: detection_boxes[keypoint_index + 1]);
|
||||
}
|
||||
}
|
||||
output_detections->emplace_back(detection);
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection(
|
||||
float box_ymin, float box_xmin, float box_ymax, float box_xmax, float score,
|
||||
int class_id, bool flip_vertically) {
|
||||
|
@ -628,7 +713,7 @@ void main() {
|
|||
if (!status.ok()) {
|
||||
return ::mediapipe::InternalError(status.error_message());
|
||||
}
|
||||
size_t raw_anchors_length = num_boxes_ * num_coords_;
|
||||
size_t raw_anchors_length = num_boxes_ * kNumCoordsPerBox;
|
||||
raw_anchors_buffer_ = absl::make_unique<GlBuffer>();
|
||||
status = CreateReadWriteShaderStorageBuffer<float>(raw_anchors_length,
|
||||
raw_anchors_buffer_.get());
|
||||
|
|
|
@ -68,4 +68,7 @@ message TfLiteTensorsToDetectionsCalculatorOptions {
|
|||
// the origin is at the top-left corner, whereas the desired detection
|
||||
// representation has a bottom-left origin (e.g., in OpenGL).
|
||||
optional bool flip_vertically = 18 [default = false];
|
||||
|
||||
// Score threshold for perserving decoded detections.
|
||||
optional float min_score_thresh = 19;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// A calculator for converting TFLite tensors to to a float or a float vector.
|
||||
//
|
||||
// Input:
|
||||
// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32. Only the first
|
||||
// tensor will be used.
|
||||
// Output:
|
||||
// FLOAT(optional) - Converted single float number.
|
||||
// FLOATS(optional) - Converted float vector.
|
||||
//
|
||||
// Notes: To output FLOAT stream, the input TFLite tensor must have size 1, e.g.
|
||||
// only 1 float number in the tensor.
|
||||
//
|
||||
// Usage example:
|
||||
// node {
|
||||
// calculator: "TfLiteTensorsToFloatsCalculator"
|
||||
// input_stream: "TENSORS:tensors"
|
||||
// output_stream: "FLOATS:floats"
|
||||
// }
|
||||
class TfLiteTensorsToFloatsCalculator : public CalculatorBase {
|
||||
public:
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc);
|
||||
|
||||
::mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
|
||||
::mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
};
|
||||
REGISTER_CALCULATOR(TfLiteTensorsToFloatsCalculator);
|
||||
|
||||
::mediapipe::Status TfLiteTensorsToFloatsCalculator::GetContract(
|
||||
CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().HasTag("TENSORS"));
|
||||
RET_CHECK(cc->Outputs().HasTag("FLOATS") || cc->Outputs().HasTag("FLOAT"));
|
||||
|
||||
cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
|
||||
if (cc->Outputs().HasTag("FLOATS")) {
|
||||
cc->Outputs().Tag("FLOATS").Set<std::vector<float>>();
|
||||
}
|
||||
if (cc->Outputs().HasTag("FLOAT")) {
|
||||
cc->Outputs().Tag("FLOAT").Set<float>();
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status TfLiteTensorsToFloatsCalculator::Open(
|
||||
CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status TfLiteTensorsToFloatsCalculator::Process(
|
||||
CalculatorContext* cc) {
|
||||
RET_CHECK(!cc->Inputs().Tag("TENSORS").IsEmpty());
|
||||
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag("TENSORS").Get<std::vector<TfLiteTensor>>();
|
||||
// TODO: Add option to specify which tensor to take from.
|
||||
const TfLiteTensor* raw_tensor = &input_tensors[0];
|
||||
const float* raw_floats = raw_tensor->data.f;
|
||||
int num_values = 1;
|
||||
for (int i = 0; i < raw_tensor->dims->size; ++i) {
|
||||
RET_CHECK_GT(raw_tensor->dims->data[i], 0);
|
||||
num_values *= raw_tensor->dims->data[i];
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("FLOAT")) {
|
||||
// TODO: Could add an index in the option to specifiy returning one
|
||||
// value of a float array.
|
||||
RET_CHECK_EQ(num_values, 1);
|
||||
cc->Outputs().Tag("FLOAT").AddPacket(
|
||||
MakePacket<float>(raw_floats[0]).At(cc->InputTimestamp()));
|
||||
}
|
||||
if (cc->Outputs().HasTag("FLOATS")) {
|
||||
auto output_floats = absl::make_unique<std::vector<float>>(
|
||||
raw_floats, raw_floats + num_values);
|
||||
cc->Outputs().Tag("FLOATS").Add(output_floats.release(),
|
||||
cc->InputTimestamp());
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,188 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// A calculator for converting TFLite tensors from regression models into
|
||||
// landmarks.
|
||||
//
|
||||
// Input:
|
||||
// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32. Only the first
|
||||
// tensor will be used. The size of the values must be
|
||||
// (num_dimension x num_landmarks).
|
||||
// Output:
|
||||
// LANDMARKS(optional) - Result MediaPipe landmarks.
|
||||
// NORM_LANDMARKS(optional) - Result MediaPipe normalized landmarks.
|
||||
//
|
||||
// Notes:
|
||||
// To output normalized landmarks, user must provide the original input image
|
||||
// size to the model using calculator option input_image_width and
|
||||
// input_image_height.
|
||||
// Usage example:
|
||||
// node {
|
||||
// calculator: "TfLiteTensorsToLandmarksCalculator"
|
||||
// input_stream: "TENSORS:landmark_tensors"
|
||||
// output_stream: "LANDMARKS:landmarks"
|
||||
// output_stream: "NORM_LANDMARKS:landmarks"
|
||||
// options: {
|
||||
// [mediapipe.TfLiteTensorsToLandmarksCalculatorOptions.ext] {
|
||||
// num_landmarks: 21
|
||||
//
|
||||
// input_image_width: 256
|
||||
// input_image_height: 256
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class TfLiteTensorsToLandmarksCalculator : public CalculatorBase {
|
||||
public:
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc);
|
||||
|
||||
::mediapipe::Status Open(CalculatorContext* cc) override;
|
||||
::mediapipe::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
::mediapipe::Status LoadOptions(CalculatorContext* cc);
|
||||
int num_landmarks_ = 0;
|
||||
|
||||
::mediapipe::TfLiteTensorsToLandmarksCalculatorOptions options_;
|
||||
};
|
||||
REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator);
|
||||
|
||||
::mediapipe::Status TfLiteTensorsToLandmarksCalculator::GetContract(
|
||||
CalculatorContract* cc) {
|
||||
RET_CHECK(!cc->Inputs().GetTags().empty());
|
||||
RET_CHECK(!cc->Outputs().GetTags().empty());
|
||||
|
||||
if (cc->Inputs().HasTag("TENSORS")) {
|
||||
cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("LANDMARKS")) {
|
||||
cc->Outputs().Tag("LANDMARKS").Set<std::vector<Landmark>>();
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("NORM_LANDMARKS")) {
|
||||
cc->Outputs().Tag("NORM_LANDMARKS").Set<std::vector<NormalizedLandmark>>();
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status TfLiteTensorsToLandmarksCalculator::Open(
|
||||
CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
RETURN_IF_ERROR(LoadOptions(cc));
|
||||
|
||||
if (cc->Outputs().HasTag("NORM_LANDMARKS")) {
|
||||
RET_CHECK(options_.has_input_image_height() &&
|
||||
options_.has_input_image_width())
|
||||
<< "Must provide input with/height for getting normalized landmarks.";
|
||||
}
|
||||
if (cc->Outputs().HasTag("LANDMARKS") && options_.flip_vertically()) {
|
||||
RET_CHECK(options_.has_input_image_height() &&
|
||||
options_.has_input_image_width())
|
||||
<< "Must provide input with/height for using flip_vertically option "
|
||||
"when outputing landmarks in absolute coordinates.";
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status TfLiteTensorsToLandmarksCalculator::Process(
|
||||
CalculatorContext* cc) {
|
||||
if (cc->Inputs().Tag("TENSORS").IsEmpty()) {
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag("TENSORS").Get<std::vector<TfLiteTensor>>();
|
||||
|
||||
const TfLiteTensor* raw_tensor = &input_tensors[0];
|
||||
|
||||
int num_values = 1;
|
||||
for (int i = 0; i < raw_tensor->dims->size; ++i) {
|
||||
num_values *= raw_tensor->dims->data[i];
|
||||
}
|
||||
const int num_dimensions = num_values / num_landmarks_;
|
||||
// Landmarks must have less than 3 dimensions. Otherwise please consider
|
||||
// using matrix.
|
||||
CHECK_LE(num_dimensions, 3);
|
||||
CHECK_GT(num_dimensions, 0);
|
||||
|
||||
const float* raw_landmarks = raw_tensor->data.f;
|
||||
|
||||
auto output_landmarks = absl::make_unique<std::vector<Landmark>>();
|
||||
|
||||
for (int ld = 0; ld < num_landmarks_; ++ld) {
|
||||
const int offset = ld * num_dimensions;
|
||||
Landmark landmark;
|
||||
landmark.set_x(raw_landmarks[offset]);
|
||||
if (num_dimensions > 1) {
|
||||
if (options_.flip_vertically()) {
|
||||
landmark.set_y(options_.input_image_height() -
|
||||
raw_landmarks[offset + 1]);
|
||||
} else {
|
||||
landmark.set_y(raw_landmarks[offset + 1]);
|
||||
}
|
||||
}
|
||||
if (num_dimensions > 2) {
|
||||
landmark.set_z(raw_landmarks[offset + 2]);
|
||||
}
|
||||
output_landmarks->push_back(landmark);
|
||||
}
|
||||
|
||||
// Output normalized landmarks if required.
|
||||
if (cc->Outputs().HasTag("NORM_LANDMARKS")) {
|
||||
auto output_norm_landmarks =
|
||||
absl::make_unique<std::vector<NormalizedLandmark>>();
|
||||
for (const auto& landmark : *output_landmarks) {
|
||||
NormalizedLandmark norm_landmark;
|
||||
norm_landmark.set_x(static_cast<float>(landmark.x()) /
|
||||
options_.input_image_width());
|
||||
norm_landmark.set_y(static_cast<float>(landmark.y()) /
|
||||
options_.input_image_height());
|
||||
norm_landmark.set_z(landmark.z() / options_.normalize_z());
|
||||
|
||||
output_norm_landmarks->push_back(norm_landmark);
|
||||
}
|
||||
cc->Outputs()
|
||||
.Tag("NORM_LANDMARKS")
|
||||
.Add(output_norm_landmarks.release(), cc->InputTimestamp());
|
||||
}
|
||||
// Output absolute landmarks.
|
||||
if (cc->Outputs().HasTag("LANDMARKS")) {
|
||||
cc->Outputs()
|
||||
.Tag("LANDMARKS")
|
||||
.Add(output_landmarks.release(), cc->InputTimestamp());
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status TfLiteTensorsToLandmarksCalculator::LoadOptions(
|
||||
CalculatorContext* cc) {
|
||||
// Get calculator options specified in the graph.
|
||||
options_ =
|
||||
cc->Options<::mediapipe::TfLiteTensorsToLandmarksCalculatorOptions>();
|
||||
num_landmarks_ = options_.num_landmarks();
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// The option proto for the TfLiteTensorsToLandmarksCalculator.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
message TfLiteTensorsToLandmarksCalculatorOptions {
|
||||
extend .mediapipe.CalculatorOptions {
|
||||
optional TfLiteTensorsToLandmarksCalculatorOptions ext = 257405002;
|
||||
}
|
||||
|
||||
// Number of landmarks from the output of the model.
|
||||
required int32 num_landmarks = 1;
|
||||
|
||||
// Size of the input image for the model. These options are used only when
|
||||
// normalized landmarks is needed.
|
||||
optional int32 input_image_width = 2;
|
||||
optional int32 input_image_height = 3;
|
||||
|
||||
// Whether the detection coordinates from the input tensors should be flipped
|
||||
// vertically (along the y-direction). This is useful, for example, when the
|
||||
// input tensors represent detections defined with a coordinate system where
|
||||
// the origin is at the top-left corner, whereas the desired detection
|
||||
// representation has a bottom-left origin (e.g., in OpenGL).
|
||||
optional bool flip_vertically = 4 [default = false];
|
||||
|
||||
// A value that z values should be divided by.
|
||||
optional float normalize_z = 5 [default = 1.0];
|
||||
}
|
|
@ -20,6 +20,9 @@
|
|||
#include "mediapipe/framework/calculator_context.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
|
||||
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/util/resource_util.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
|
@ -39,8 +42,11 @@ namespace {
|
|||
constexpr int kWorkgroupSize = 8; // Block size for GPU shader.
|
||||
enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
|
||||
// Commonly used to compute the number of blocks to launch in a kernel.
|
||||
int RoundUp(const int size, const int multiple) {
|
||||
return (size + multiple - 1) / multiple;
|
||||
int NumGroups(const int size, const int group_size) { // NOLINT
|
||||
return (size + group_size - 1) / group_size;
|
||||
}
|
||||
float Clamp(float val, float min, float max) {
|
||||
return std::min(std::max(val, min), max);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -58,7 +64,7 @@ using ::tflite::gpu::gl::GlShader;
|
|||
// Performs optional upscale to REFERENCE_IMAGE dimensions if provided,
|
||||
// otherwise the mask is the same size as input tensor.
|
||||
//
|
||||
// Note: This calculator is currently GPU only, so only *_GPU tags can be used.
|
||||
// Produces result as an RGBA image, with the mask in both R & A channels.
|
||||
//
|
||||
// Inputs:
|
||||
// One of the following TENSORS tags:
|
||||
|
@ -71,11 +77,11 @@ using ::tflite::gpu::gl::GlShader;
|
|||
// REFERENCE_IMAGE_GPU (optional): A GpuBuffer input image,
|
||||
// used only for output dimensions.
|
||||
// One of the following PREV_MASK tags:
|
||||
// PREV_MASK (optional): An ImageFrame input mask, Gray, RGB or RGBA.
|
||||
// PREV_MASK_GPU (optional): A GpuBuffer input mask, RGBA.
|
||||
// PREV_MASK (optional): An ImageFrame input mask, Gray, RGB or RGBA, [0-255].
|
||||
// PREV_MASK_GPU (optional): A GpuBuffer input mask, RGBA, [0-1].
|
||||
// Output:
|
||||
// One of the following MASK tags:
|
||||
// MASK: An ImageFrame output mask, Gray, RGB or RGBA.
|
||||
// MASK: An ImageFrame output mask, RGBA.
|
||||
// MASK_GPU: A GpuBuffer output mask, RGBA.
|
||||
//
|
||||
// Options:
|
||||
|
@ -141,10 +147,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
|
|||
cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
|
||||
}
|
||||
if (cc->Inputs().HasTag("PREV_MASK")) {
|
||||
cc->Inputs().Tag("PREV_MASK").Set<mediapipe::ImageFrame>();
|
||||
cc->Inputs().Tag("PREV_MASK").Set<ImageFrame>();
|
||||
}
|
||||
if (cc->Inputs().HasTag("REFERENCE_IMAGE")) {
|
||||
cc->Inputs().Tag("REFERENCE_IMAGE").Set<mediapipe::ImageFrame>();
|
||||
cc->Inputs().Tag("REFERENCE_IMAGE").Set<ImageFrame>();
|
||||
}
|
||||
|
||||
// Inputs GPU.
|
||||
|
@ -162,7 +168,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
|
|||
|
||||
// Outputs.
|
||||
if (cc->Outputs().HasTag("MASK")) {
|
||||
cc->Outputs().Tag("MASK").Set<mediapipe::ImageFrame>();
|
||||
cc->Outputs().Tag("MASK").Set<ImageFrame>();
|
||||
}
|
||||
#if defined(__ANDROID__)
|
||||
if (cc->Outputs().HasTag("MASK_GPU")) {
|
||||
|
@ -179,6 +185,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
|
|||
|
||||
::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Open(
|
||||
CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
if (cc->Inputs().HasTag("TENSORS_GPU")) {
|
||||
use_gpu_ = true;
|
||||
#if defined(__ANDROID__)
|
||||
|
@ -238,7 +246,107 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
|
|||
|
||||
::mediapipe::Status TfLiteTensorsToSegmentationCalculator::ProcessCpu(
|
||||
CalculatorContext* cc) {
|
||||
return ::mediapipe::UnimplementedError("CPU support is not implemented yet.");
|
||||
if (cc->Inputs().Tag("TENSORS").IsEmpty()) {
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
// Get input streams.
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag("TENSORS").Get<std::vector<TfLiteTensor>>();
|
||||
const bool has_prev_mask = cc->Inputs().HasTag("PREV_MASK") &&
|
||||
!cc->Inputs().Tag("PREV_MASK").IsEmpty();
|
||||
const ImageFrame placeholder;
|
||||
const auto& input_mask = has_prev_mask
|
||||
? cc->Inputs().Tag("PREV_MASK").Get<ImageFrame>()
|
||||
: placeholder;
|
||||
int output_width = tensor_width_, output_height = tensor_height_;
|
||||
if (cc->Inputs().HasTag("REFERENCE_IMAGE")) {
|
||||
const auto& input_image =
|
||||
cc->Inputs().Tag("REFERENCE_IMAGE").Get<ImageFrame>();
|
||||
output_width = input_image.Width();
|
||||
output_height = input_image.Height();
|
||||
}
|
||||
RET_CHECK_EQ(input_tensors.size(), 1);
|
||||
|
||||
// Create initial working mask.
|
||||
cv::Mat small_mask_mat(cv::Size(tensor_width_, tensor_height_), CV_8UC4);
|
||||
|
||||
// Get input previous mask.
|
||||
cv::Mat input_mask_mat;
|
||||
if (has_prev_mask) {
|
||||
cv::Mat temp_mask_mat = formats::MatView(&input_mask);
|
||||
if (temp_mask_mat.channels() != 4) {
|
||||
cv::Mat converted_mat;
|
||||
cv::cvtColor(temp_mask_mat, converted_mat,
|
||||
temp_mask_mat.channels() == 1 ? cv::COLOR_GRAY2RGBA
|
||||
: cv::COLOR_RGB2RGBA);
|
||||
temp_mask_mat = converted_mat.clone();
|
||||
}
|
||||
cv::resize(temp_mask_mat, input_mask_mat, small_mask_mat.size());
|
||||
}
|
||||
|
||||
// Copy input tensor.
|
||||
const TfLiteTensor* raw_input_tensor = &input_tensors[0];
|
||||
const float* raw_input_data = raw_input_tensor->data.f;
|
||||
cv::Mat tensor_mat(cv::Size(tensor_width_, tensor_height_),
|
||||
CV_MAKETYPE(CV_32F, tensor_channels_));
|
||||
float* tensor_mat_ptr = tensor_mat.ptr<float>();
|
||||
memcpy(tensor_mat_ptr, raw_input_data, raw_input_tensor->bytes);
|
||||
|
||||
// Process mask tensor.
|
||||
// Run softmax over tensor output and blend with previous mask.
|
||||
const int output_layer_index = options_.output_layer_index();
|
||||
const float combine_with_prev_ratio = options_.combine_with_previous_ratio();
|
||||
for (int i = 0; i < tensor_height_; ++i) {
|
||||
for (int j = 0; j < tensor_width_; ++j) {
|
||||
// Only two channel input tensor is supported.
|
||||
const cv::Vec2f input_pix = tensor_mat.at<cv::Vec2f>(i, j);
|
||||
const float shift = std::max(input_pix[0], input_pix[1]);
|
||||
const float softmax_denom =
|
||||
std::exp(input_pix[0] - shift) + std::exp(input_pix[1] - shift);
|
||||
float new_mask_value =
|
||||
std::exp(input_pix[output_layer_index] - shift) / softmax_denom;
|
||||
// Combine previous value with current using uncertainty^2 as mixing coeff
|
||||
if (has_prev_mask) {
|
||||
const float prev_mask_value =
|
||||
input_mask_mat.at<cv::Vec4b>(i, j)[0] / 255.0f;
|
||||
const float eps = 0.001;
|
||||
float uncertainty_alpha =
|
||||
1.0 +
|
||||
(new_mask_value * std::log(new_mask_value + eps) +
|
||||
(1.0 - new_mask_value) * std::log(1.0 - new_mask_value + eps)) /
|
||||
std::log(2.0f);
|
||||
uncertainty_alpha = Clamp(uncertainty_alpha, 0.0f, 1.0f);
|
||||
// Equivalent to: a = 1 - (1 - a) * (1 - a); (squaring the uncertainty)
|
||||
uncertainty_alpha *= 2.0 - uncertainty_alpha;
|
||||
const float mixed_mask_value =
|
||||
new_mask_value * uncertainty_alpha +
|
||||
prev_mask_value * (1.0f - uncertainty_alpha);
|
||||
new_mask_value = mixed_mask_value * combine_with_prev_ratio +
|
||||
(1.0f - combine_with_prev_ratio) * new_mask_value;
|
||||
}
|
||||
const uchar mask_value = static_cast<uchar>(new_mask_value * 255);
|
||||
// Set both R and A channels for convenience.
|
||||
const cv::Vec4b out_value = {mask_value, 0, 0, mask_value};
|
||||
small_mask_mat.at<cv::Vec4b>(i, j) = out_value;
|
||||
}
|
||||
}
|
||||
|
||||
if (options_.flip_vertically()) cv::flip(small_mask_mat, small_mask_mat, 0);
|
||||
|
||||
// Upsample small mask into output.
|
||||
cv::Mat large_mask_mat;
|
||||
cv::resize(small_mask_mat, large_mask_mat,
|
||||
cv::Size(output_width, output_height));
|
||||
|
||||
// Send out image as CPU packet.
|
||||
std::unique_ptr<ImageFrame> output_mask = absl::make_unique<ImageFrame>(
|
||||
ImageFormat::SRGBA, output_width, output_height);
|
||||
cv::Mat output_mat = formats::MatView(output_mask.get());
|
||||
large_mask_mat.copyTo(output_mat);
|
||||
cc->Outputs().Tag("MASK").Add(output_mask.release(), cc->InputTimestamp());
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
// Steps:
|
||||
|
@ -267,10 +375,9 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
|
|||
output_width = input_image.width();
|
||||
output_height = input_image.height();
|
||||
}
|
||||
|
||||
RET_CHECK_EQ(input_tensors.size(), 1);
|
||||
|
||||
// Create initial output mask texture.
|
||||
// Create initial working mask texture.
|
||||
::tflite::gpu::gl::GlTexture small_mask_texture;
|
||||
::tflite::gpu::gl::CreateReadWriteRgbaImageTexture(
|
||||
tflite::gpu::DataType::UINT8, // GL_RGBA8
|
||||
|
@ -285,6 +392,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
|
|||
tflite::gpu::gl::CopyBuffer(input_tensors[0], *tensor_buffer_);
|
||||
|
||||
// Run shader, process mask tensor.
|
||||
// Run softmax over tensor output and blend with previous mask.
|
||||
{
|
||||
const int output_index = 0;
|
||||
glBindImageTexture(output_index, small_mask_texture.id(), 0, GL_FALSE, 0,
|
||||
|
@ -292,8 +400,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
|
|||
tensor_buffer_->BindToIndex(2);
|
||||
|
||||
const tflite::gpu::uint3 workgroups = {
|
||||
RoundUp(tensor_width_, kWorkgroupSize),
|
||||
RoundUp(tensor_height_, kWorkgroupSize), 1};
|
||||
NumGroups(tensor_width_, kWorkgroupSize),
|
||||
NumGroups(tensor_height_, kWorkgroupSize), 1};
|
||||
|
||||
if (!has_prev_mask) {
|
||||
mask_program_no_prev_->Dispatch(workgroups);
|
||||
|
@ -330,8 +438,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
|
|||
// Cleanup
|
||||
input_mask_texture.Release();
|
||||
output_texture.Release();
|
||||
|
||||
#endif // __ANDROID__
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -445,7 +553,7 @@ void main() {
|
|||
int linear_index = gid.y * out_width + gid.x;
|
||||
vec2 input_value = input_data.elements[linear_index];
|
||||
|
||||
// Only two channel output is supported.
|
||||
// Only two channel input tensor is supported.
|
||||
vec2 input_px = input_value.rg;
|
||||
float shift = max(input_px.r, input_px.g);
|
||||
float softmax_denom = exp(input_px.r - shift) + exp(input_px.g - shift);
|
||||
|
@ -474,8 +582,8 @@ void main() {
|
|||
(1.0f - combine_with_previous_ratio) * new_mask_value;
|
||||
#endif // READ_PREVIOUS
|
||||
|
||||
// Texture coordinates are inverted on y axis.
|
||||
ivec2 output_coordinate = ivec2(gid.x, out_height - gid.y - 1);
|
||||
int y_coord = int($4);
|
||||
ivec2 output_coordinate = ivec2(gid.x, y_coord);
|
||||
// Set both R and A channels for convenience.
|
||||
vec4 out_value = vec4(new_mask_value, 0.0, 0.0, new_mask_value);
|
||||
imageStore(output_texture, output_coordinate, out_value);
|
||||
|
@ -483,10 +591,12 @@ void main() {
|
|||
|
||||
const std::string shader_src_no_previous = absl::Substitute(
|
||||
shader_src_template, kWorkgroupSize, options_.output_layer_index(),
|
||||
options_.combine_with_previous_ratio(), "");
|
||||
options_.combine_with_previous_ratio(), "",
|
||||
options_.flip_vertically() ? "out_height - gid.y - 1" : "gid.y");
|
||||
const std::string shader_src_with_previous = absl::Substitute(
|
||||
shader_src_template, kWorkgroupSize, options_.output_layer_index(),
|
||||
options_.combine_with_previous_ratio(), "#define READ_PREVIOUS");
|
||||
options_.combine_with_previous_ratio(), "#define READ_PREVIOUS",
|
||||
options_.flip_vertically() ? "out_height - gid.y - 1" : "gid.y");
|
||||
|
||||
auto status = ::tflite::gpu::OkStatus();
|
||||
|
||||
|
|
|
@ -34,4 +34,7 @@ message TfLiteTensorsToSegmentationCalculatorOptions {
|
|||
|
||||
// Model specific: Channel to use for processing tensor.
|
||||
optional int32 output_layer_index = 5 [default = 1];
|
||||
|
||||
// Flip result image mask along y-axis.
|
||||
optional bool flip_vertically = 6;
|
||||
}
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user