Project import generated by Copybara.

PiperOrigin-RevId: 263889205
This commit is contained in:
MediaPipe Team 2019-08-16 18:49:25 -07:00 committed by jqtang
parent dc40414468
commit 294687295d
443 changed files with 33160 additions and 2011 deletions

View File

@ -34,3 +34,30 @@ build:android_arm --fat_apk_cpu=armeabi-v7a
build:android_arm64 --config=android
build:android_arm64 --cpu=arm64-v8a
build:android_arm64 --fat_apk_cpu=arm64-v8a
# iOS configs.
build:ios --apple_platform_type=ios
build:ios_i386 --config=ios
build:ios_i386 --cpu=ios_i386
build:ios_i386 --watchos_cpus=i386
build:ios_x86_64 --config=ios
build:ios_x86_64 --cpu=ios_x86_64
build:ios_x86_64 --watchos_cpus=i386
build:ios_armv7 --config=ios
build:ios_armv7 --cpu=ios_armv7
build:ios_armv7 --watchos_cpus=armv7k
build:ios_arm64 --config=ios
build:ios_arm64 --cpu=ios_arm64
build:ios_arm64 --watchos_cpus=armv7k
build:ios_arm64e --config=ios
build:ios_arm64e --cpu=ios_arm64e
build:ios_arm64e --watchos_cpus=armv7k
build:ios_fat --config=ios
build:ios_fat --ios_multi_cpus=armv7,arm64
build:ios_fat --watchos_cpus=armv7k

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
mediapipe/provisioning_profile.mobileprovision

2
BUILD
View File

@ -1,4 +1,4 @@
# Copyright 2019 The MediaPipeOSS Authors.
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -28,14 +28,20 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
wget \
unzip \
python \
python-pip \
libopencv-core-dev \
libopencv-highgui-dev \
libopencv-imgproc-dev \
libopencv-video-dev \
&& \
software-properties-common && \
add-apt-repository -y ppa:openjdk-r/ppa && \
apt-get update && apt-get install -y openjdk-11-jdk && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
RUN pip install --upgrade setuptools
RUN pip install future
# Install bazel
ARG BAZEL_VERSION=0.26.1
RUN mkdir /bazel && \
@ -49,4 +55,4 @@ azel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \
COPY . /mediapipe/
# If we want the docker image to contain the pre-built object_detection_offline_demo binary, do the following
# RUN bazel build -c opt --define 'MEDIAPIPE_DISABLE_GPU=1' mediapipe/examples/desktop/demo:object_detection_tensorflow_demo
# RUN bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/demo:object_detection_tensorflow_demo

121
WORKSPACE
View File

@ -2,11 +2,12 @@ workspace(name = "mediapipe")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
skylib_version = "0.8.0"
http_archive(
name = "bazel_skylib",
sha256 = "bbccf674aa441c266df9894182d80de104cabd19be98be002f6d478aaa31574d",
strip_prefix = "bazel-skylib-2169ae1c374aab4a09aa90e65efe1a3aad4e279b",
urls = ["https://github.com/bazelbuild/bazel-skylib/archive/2169ae1c374aab4a09aa90e65efe1a3aad4e279b.tar.gz"],
type = "tar.gz",
url = "https://github.com/bazelbuild/bazel-skylib/releases/download/{}/bazel-skylib.{}.tar.gz".format (skylib_version, skylib_version),
sha256 = "2ef429f5d7ce7111263289644d233707dba35e39696377ebab8b0bc701f7818e",
)
load("@bazel_skylib//lib:versions.bzl", "versions")
versions.check(minimum_bazel_version = "0.23.0")
@ -52,7 +53,7 @@ http_archive(
# glog
http_archive(
name = "com_google_glog",
name = "com_github_glog_glog",
url = "https://github.com/google/glog/archive/v0.3.5.zip",
sha256 = "267103f8a1e9578978aa1dc256001e6529ef593e5aea38193d31c2872ee025e8",
strip_prefix = "glog-0.3.5",
@ -73,6 +74,12 @@ http_archive(
urls = ["https://github.com/google/protobuf/archive/384989534b2246d413dbcd750744faab2607b516.zip"],
)
http_archive(
name = "com_google_audio_tools",
strip_prefix = "multichannel-audio-tools-master",
urls = ["https://github.com/google/multichannel-audio-tools/archive/master.zip"],
)
# Needed by TensorFlow
http_archive(
name = "io_bazel_rules_closure",
@ -84,12 +91,24 @@ http_archive(
],
)
# TensorFlow r1.14-rc0
# 2019-08-15
_TENSORFLOW_GIT_COMMIT = "67def62936e28f97c16182dfcc467d8d1cae02b4"
_TENSORFLOW_SHA256= "ddd4e3c056e7c0ff2ef29133b30fa62781dfbf8a903e99efb91a02d292fa9562"
http_archive(
name = "org_tensorflow",
strip_prefix = "tensorflow-1.14.0-rc0",
sha256 = "76404a6157a45e8d7a07e4f5690275256260130145924c2a7c73f6eda2a3de10",
urls = ["https://github.com/tensorflow/tensorflow/archive/v1.14.0-rc0.zip"],
urls = [
"https://mirror.bazel.build/github.com/tensorflow/tensorflow/archive/%s.tar.gz" % _TENSORFLOW_GIT_COMMIT,
"https://github.com/tensorflow/tensorflow/archive/%s.tar.gz" % _TENSORFLOW_GIT_COMMIT,
],
strip_prefix = "tensorflow-%s" % _TENSORFLOW_GIT_COMMIT,
sha256 = _TENSORFLOW_SHA256,
patches = [
"@//third_party:tensorflow_065c20bf79253257c87bd4614bb9a7fdef015cbb.diff",
"@//third_party:tensorflow_f67fcbefce906cd419e4657f0d41e21019b71abd.diff",
],
patch_args = [
"-p1",
],
)
load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace")
@ -102,6 +121,12 @@ new_local_repository(
path = "/usr",
)
new_local_repository(
name = "linux_ffmpeg",
build_file = "@//third_party:ffmpeg_linux.BUILD",
path = "/usr"
)
# Please run $ brew install opencv
new_local_repository(
name = "macos_opencv",
@ -109,15 +134,33 @@ new_local_repository(
path = "/usr",
)
new_local_repository(
name = "macos_ffmpeg",
build_file = "@//third_party:ffmpeg_macos.BUILD",
path = "/usr",
)
http_archive(
name = "android_opencv",
sha256="056b849842e4fa8751d09edbb64530cfa7a63c84ccd232d0ace330e27ba55d0b",
sha256 = "056b849842e4fa8751d09edbb64530cfa7a63c84ccd232d0ace330e27ba55d0b",
build_file = "@//third_party:opencv_android.BUILD",
strip_prefix = "OpenCV-android-sdk",
type = "zip",
url = "https://github.com/opencv/opencv/releases/download/4.1.0/opencv-4.1.0-android-sdk.zip",
)
# After OpenCV 3.2.0, the pre-compiled opencv2.framework has google protobuf symbols, which will
# trigger duplicate symbol errors in the linking stage of building a mediapipe ios app.
# To get a higher version of OpenCV for iOS, opencv2.framework needs to be built from source with
# '-DBUILD_PROTOBUF=OFF -DBUILD_opencv_dnn=OFF'.
http_archive(
name = "ios_opencv",
sha256 = "7dd536d06f59e6e1156b546bd581523d8df92ce83440002885ec5abc06558de2",
build_file = "@//third_party:opencv_ios.BUILD",
type = "zip",
url = "https://github.com/opencv/opencv/releases/download/3.2.0/opencv-3.2.0-ios-framework.zip",
)
RULES_JVM_EXTERNAL_TAG = "2.2"
RULES_JVM_EXTERNAL_SHA = "f1203ce04e232ab6fdd81897cf0ff76f2c04c0741424d192f28e65ae752ce2d6"
@ -132,12 +175,15 @@ load("@rules_jvm_external//:defs.bzl", "maven_install")
maven_install(
artifacts = [
"com.android.support.constraint:constraint-layout:aar:1.0.2",
"androidx.appcompat:appcompat:aar:1.0.2",
],
repositories = [
"https://dl.google.com/dl/android/maven2",
"androidx.annotation:annotation:aar:1.1.0",
"androidx.appcompat:appcompat:aar:1.1.0-rc01",
"androidx.constraintlayout:constraintlayout:aar:1.1.3",
"androidx.core:core:aar:1.1.0-rc03",
"androidx.legacy:legacy-support-v4:aar:1.0.0",
"androidx.recyclerview:recyclerview:aar:1.1.0-beta02",
"com.google.android.material:material:aar:1.0.0-rc01",
],
repositories = ["https://dl.google.com/dl/android/maven2"],
)
maven_server(
@ -191,3 +237,50 @@ android_ndk_repository(
android_sdk_repository(
name = "androidsdk",
)
# iOS basic build deps.
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
git_repository(
name = "build_bazel_rules_apple",
remote = "https://github.com/bazelbuild/rules_apple.git",
tag = "0.18.0",
patches = [
"@//third_party:rules_apple_c0863d0596ae6b769a29fa3fb72ff036444fd249.diff",
],
patch_args = [
"-p1",
],
)
load(
"@build_bazel_rules_apple//apple:repositories.bzl",
"apple_rules_dependencies",
)
apple_rules_dependencies()
load(
"@build_bazel_rules_swift//swift:repositories.bzl",
"swift_rules_dependencies",
)
swift_rules_dependencies()
load(
"@build_bazel_apple_support//lib:repositories.bzl",
"apple_support_dependencies",
)
apple_support_dependencies()
# More iOS deps.
http_archive(
name = "google_toolbox_for_mac",
url = "https://github.com/google/google-toolbox-for-mac/archive/v2.2.1.zip",
sha256 = "e3ac053813c989a88703556df4dc4466e424e30d32108433ed6beaec76ba4fdc",
strip_prefix = "google-toolbox-for-mac-2.2.1",
build_file = "@//third_party:google_toolbox_for_mac.BUILD",
)

View File

@ -65,11 +65,73 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
# Note: this cannot just match "apple_platform_type": "macos" because that option
# defaults to "macos" even when building on Linux!
alias(
name = "macos",
actual = select({
":macos_i386": ":macos_i386",
":macos_x86_64": ":macos_x86_64",
"//conditions:default": ":macos_i386", # Arbitrarily chosen from above.
}),
visibility = ["//visibility:public"],
)
# Note: this also matches on crosstool_top so that it does not produce ambiguous
# selectors when used together with "android".
config_setting(
name = "ios",
values = {
"crosstool_top": "@bazel_tools//tools/cpp:toolchain",
"apple_platform_type": "ios",
},
visibility = ["//visibility:public"],
)
alias(
name = "apple",
actual = select({
":macos": ":macos",
":ios": ":ios",
"//conditions:default": ":ios", # Arbitrarily chosen from above.
}),
visibility = ["//visibility:public"],
)
config_setting(
name = "macos_i386",
values = {
"apple_platform_type": "macos",
"cpu": "darwin",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "macos_x86_64",
values = {
"apple_platform_type": "macos",
"cpu": "darwin_x86_64",
},
visibility = ["//visibility:public"],
)
[
config_setting(
name = arch,
values = {"cpu": arch},
visibility = ["//visibility:public"],
)
for arch in [
"ios_i386",
"ios_x86_64",
"ios_armv7",
"ios_arm64",
"ios_arm64e",
]
]
exports_files(
["provisioning_profile.mobileprovision"],
visibility = ["//visibility:public"],
)

View File

@ -0,0 +1,305 @@
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:private"])
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
proto_library(
name = "mfcc_mel_calculators_proto",
srcs = ["mfcc_mel_calculators.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
],
)
mediapipe_cc_proto_library(
name = "mfcc_mel_calculators_cc_proto",
srcs = ["mfcc_mel_calculators.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
deps = [":mfcc_mel_calculators_proto"],
)
proto_library(
name = "rational_factor_resample_calculator_proto",
srcs = ["rational_factor_resample_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
],
)
mediapipe_cc_proto_library(
name = "rational_factor_resample_calculator_cc_proto",
srcs = ["rational_factor_resample_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
deps = [":rational_factor_resample_calculator_proto"],
)
proto_library(
name = "spectrogram_calculator_proto",
srcs = ["spectrogram_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
mediapipe_cc_proto_library(
name = "spectrogram_calculator_cc_proto",
srcs = ["spectrogram_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
deps = [":spectrogram_calculator_proto"],
)
proto_library(
name = "time_series_framer_calculator_proto",
srcs = ["time_series_framer_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
],
)
mediapipe_cc_proto_library(
name = "time_series_framer_calculator_cc_proto",
srcs = ["time_series_framer_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
deps = [":time_series_framer_calculator_proto"],
)
cc_library(
name = "audio_decoder_calculator",
srcs = ["audio_decoder_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:status",
"//mediapipe/util:audio_decoder",
"//mediapipe/util:audio_decoder_cc_proto",
],
alwayslink = 1,
)
cc_library(
name = "basic_time_series_calculators",
srcs = ["basic_time_series_calculators.cc"],
hdrs = ["basic_time_series_calculators.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/util:time_series_util",
"@com_google_absl//absl/strings",
"@eigen_archive//:eigen",
],
alwayslink = 1,
)
cc_library(
name = "mfcc_mel_calculators",
srcs = ["mfcc_mel_calculators.cc"],
visibility = ["//visibility:public"],
deps = [
":mfcc_mel_calculators_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:status",
"//mediapipe/util:time_series_util",
"@com_google_absl//absl/strings",
"@com_google_audio_tools//audio/dsp/mfcc",
"@eigen_archive//:eigen",
],
alwayslink = 1,
)
cc_library(
name = "rational_factor_resample_calculator",
srcs = ["rational_factor_resample_calculator.cc"],
hdrs = ["rational_factor_resample_calculator.h"],
visibility = ["//visibility:public"],
deps = [
":rational_factor_resample_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/util:time_series_util",
"@com_google_absl//absl/strings",
"@com_google_audio_tools//audio/dsp:resampler",
"@com_google_audio_tools//audio/dsp:resampler_rational_factor",
"@eigen_archive//:eigen",
],
alwayslink = 1,
)
cc_library(
name = "spectrogram_calculator",
srcs = ["spectrogram_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":spectrogram_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:source_location",
"//mediapipe/framework/port:status",
"//mediapipe/util:time_series_util",
"@com_google_absl//absl/strings",
"@com_google_audio_tools//audio/dsp:window_functions",
"@com_google_audio_tools//audio/dsp/spectrogram",
"@eigen_archive//:eigen",
],
alwayslink = 1,
)
cc_library(
name = "time_series_framer_calculator",
srcs = ["time_series_framer_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":time_series_framer_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/util:time_series_util",
"@com_google_audio_tools//audio/dsp:window_functions",
"@eigen_archive//:eigen",
],
alwayslink = 1,
)
cc_test(
name = "audio_decoder_calculator_test",
srcs = ["audio_decoder_calculator_test.cc"],
data = ["//mediapipe/calculators/audio/testdata:test_audios"],
deps = [
":audio_decoder_calculator",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
],
)
cc_test(
name = "basic_time_series_calculators_test",
srcs = ["basic_time_series_calculators_test.cc"],
deps = [
":basic_time_series_calculators",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/util:time_series_test_util",
"@eigen_archive//:eigen",
],
)
cc_test(
name = "mfcc_mel_calculators_test",
srcs = ["mfcc_mel_calculators_test.cc"],
deps = [
":mfcc_mel_calculators",
":mfcc_mel_calculators_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status",
"//mediapipe/util:time_series_test_util",
"@eigen_archive//:eigen",
],
)
cc_test(
name = "spectrogram_calculator_test",
srcs = ["spectrogram_calculator_test.cc"],
deps = [
":spectrogram_calculator",
":spectrogram_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:benchmark",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:status",
"//mediapipe/util:time_series_test_util",
"@com_google_audio_tools//audio/dsp:number_util",
"@eigen_archive//:eigen",
],
)
cc_test(
name = "time_series_framer_calculator_test",
srcs = ["time_series_framer_calculator_test.cc"],
deps = [
":time_series_framer_calculator",
":time_series_framer_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:status",
"//mediapipe/util:time_series_test_util",
"@com_google_audio_tools//audio/dsp:window_functions",
"@eigen_archive//:eigen",
],
)
cc_test(
name = "rational_factor_resample_calculator_test",
srcs = ["rational_factor_resample_calculator_test.cc"],
deps = [
":rational_factor_resample_calculator",
":rational_factor_resample_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:validate_type",
"//mediapipe/util:time_series_test_util",
"@com_google_audio_tools//audio/dsp:signal_vector_util",
"@eigen_archive//:eigen",
],
)

View File

@ -0,0 +1,106 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/audio_decoder.h"
#include "mediapipe/util/audio_decoder.pb.h"
namespace mediapipe {
// The AudioDecoderCalculator decodes an audio stream of the media file. It
// produces two output streams containing the audio packets and the header
// information.
//
// Output Streams:
// AUDIO: Output audio frames (Matrix).
// AUDIO_HEADER:
// Optional audio header information output
// Input Side Packets:
// INPUT_FILE_PATH: The input file path.
//
// Example config:
// node {
// calculator: "AudioDecoderCalculator"
// input_side_packet: "INPUT_FILE_PATH:input_file_path"
// output_stream: "AUDIO:audio"
// output_stream: "AUDIO_HEADER:audio_header"
// node_options {
// [type.googleapis.com/mediapipe.AudioDecoderOptions]: {
// audio_stream { stream_index: 0 }
// start_time: 0
// end_time: 1
// }
// }
//
// TODO: support decoding multiple streams.
// Decodes audio from a media file using the project's AudioDecoder helper.
// See the file-level comment above for streams, side packets, and an example.
class AudioDecoderCalculator : public CalculatorBase {
 public:
  // Declares the INPUT_FILE_PATH side packet, the AUDIO output stream, and
  // the optional AUDIO_HEADER output stream.
  static ::mediapipe::Status GetContract(CalculatorContract* cc);

  // Initializes the decoder from the side packet and emits the stream header.
  ::mediapipe::Status Open(CalculatorContext* cc) override;

  // Emits one decoded audio buffer per invocation.
  ::mediapipe::Status Process(CalculatorContext* cc) override;

  // Shuts down the underlying decoder.
  ::mediapipe::Status Close(CalculatorContext* cc) override;

 private:
  // Owns the decoder for the lifetime of the calculator (created in Open()).
  std::unique_ptr<AudioDecoder> decoder_;
};
// Declares the calculator's I/O contract: a string side packet holding the
// media path, a required AUDIO stream of Matrix packets, and an optional
// AUDIO_HEADER stream carrying a TimeSeriesHeader.
::mediapipe::Status AudioDecoderCalculator::GetContract(
    CalculatorContract* cc) {
  auto& outputs = cc->Outputs();
  cc->InputSidePackets().Tag("INPUT_FILE_PATH").Set<std::string>();
  outputs.Tag("AUDIO").Set<Matrix>();
  const bool has_header_output = outputs.HasTag("AUDIO_HEADER");
  if (has_header_output) {
    outputs.Tag("AUDIO_HEADER").Set<mediapipe::TimeSeriesHeader>();
  }
  return ::mediapipe::OkStatus();
}
// Creates and initializes the decoder from the INPUT_FILE_PATH side packet,
// then publishes (and closes) the AUDIO_HEADER stream if it is connected.
// Returns a non-OK status if decoder initialization fails.
::mediapipe::Status AudioDecoderCalculator::Open(CalculatorContext* cc) {
  const std::string& input_file_path =
      cc->InputSidePackets().Tag("INPUT_FILE_PATH").Get<std::string>();
  const auto& decoder_options = cc->Options<mediapipe::AudioDecoderOptions>();
  decoder_ = absl::make_unique<AudioDecoder>();
  RETURN_IF_ERROR(decoder_->Initialize(input_file_path, decoder_options));
  // AUDIO_HEADER is declared optional in GetContract(); only touch the tag
  // when the stream is actually connected.
  if (cc->Outputs().HasTag("AUDIO_HEADER")) {
    std::unique_ptr<mediapipe::TimeSeriesHeader> header =
        absl::make_unique<mediapipe::TimeSeriesHeader>();
    if (decoder_->FillAudioHeader(decoder_options.audio_stream(0),
                                  header.get())
            .ok()) {
      // Only pass on a header if the decoder could actually produce one;
      // otherwise the stream header stays empty.
      cc->Outputs().Tag("AUDIO_HEADER").SetHeader(Adopt(header.release()));
    }
    // The header stream carries no packets, so close it immediately.
    cc->Outputs().Tag("AUDIO_HEADER").Close();
  }
  return ::mediapipe::OkStatus();
}
// Pulls the next decoded buffer from the decoder and forwards it on the
// AUDIO stream. A non-OK status from the decoder (including end-of-stream)
// is propagated to the framework unchanged.
::mediapipe::Status AudioDecoderCalculator::Process(CalculatorContext* cc) {
  Packet audio_packet;
  int stream_option_index = -1;
  const auto decode_status = decoder_->GetData(&stream_option_index, &audio_packet);
  if (decode_status.ok()) {
    cc->Outputs().Tag("AUDIO").AddPacket(audio_packet);
  }
  return decode_status;
}
// Tears down the decoder and propagates its final status.
::mediapipe::Status AudioDecoderCalculator::Close(CalculatorContext* cc) {
  const ::mediapipe::Status shutdown_status = decoder_->Close();
  return shutdown_status;
}
REGISTER_CALCULATOR(AudioDecoderCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,153 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
// Decodes a 2-second, 44.1 kHz mono WAV file and verifies the decoded header
// plus a lower bound on the number of emitted audio packets.
TEST(AudioDecoderCalculatorTest, TestWAV) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "AudioDecoderCalculator"
        input_side_packet: "INPUT_FILE_PATH:input_file_path"
        output_stream: "AUDIO:audio"
        output_stream: "AUDIO_HEADER:audio_header"
        node_options {
          [type.googleapis.com/mediapipe.AudioDecoderOptions]: {
            audio_stream { stream_index: 0 }
          }
        })");
  CalculatorRunner runner(node_config);
  runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
      file::JoinPath("./",
                     "/mediapipe/calculators/audio/"
                     "testdata/sine_wave_1k_44100_mono_2_sec_wav.audio"));
  MEDIAPIPE_ASSERT_OK(runner.Run());
  // The stream header must be present and typed as TimeSeriesHeader.
  MEDIAPIPE_EXPECT_OK(
      runner.Outputs()
          .Tag("AUDIO_HEADER")
          .header.ValidateAsType<mediapipe::TimeSeriesHeader>());
  const mediapipe::TimeSeriesHeader& header =
      runner.Outputs()
          .Tag("AUDIO_HEADER")
          .header.Get<mediapipe::TimeSeriesHeader>();
  EXPECT_EQ(44100, header.sample_rate());
  EXPECT_EQ(1, header.num_channels());
  // 2 s of 44.1 kHz audio; 2048 is presumably the decoder's max samples per
  // packet — TODO confirm against AudioDecoder.
  EXPECT_TRUE(runner.Outputs().Tag("AUDIO").packets.size() >=
              std::ceil(44100.0 * 2 / 2048));
}
// Decodes a 2-second, 48 kHz stereo WAV file and verifies the decoded header
// plus a lower bound on the number of emitted audio packets.
TEST(AudioDecoderCalculatorTest, Test48KWAV) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "AudioDecoderCalculator"
        input_side_packet: "INPUT_FILE_PATH:input_file_path"
        output_stream: "AUDIO:audio"
        output_stream: "AUDIO_HEADER:audio_header"
        node_options {
          [type.googleapis.com/mediapipe.AudioDecoderOptions]: {
            audio_stream { stream_index: 0 }
          }
        })");
  CalculatorRunner runner(node_config);
  runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
      file::JoinPath("./",
                     "/mediapipe/calculators/audio/"
                     "testdata/sine_wave_1k_48000_stereo_2_sec_wav.audio"));
  MEDIAPIPE_ASSERT_OK(runner.Run());
  // The stream header must be present and typed as TimeSeriesHeader.
  MEDIAPIPE_EXPECT_OK(
      runner.Outputs()
          .Tag("AUDIO_HEADER")
          .header.ValidateAsType<mediapipe::TimeSeriesHeader>());
  const mediapipe::TimeSeriesHeader& header =
      runner.Outputs()
          .Tag("AUDIO_HEADER")
          .header.Get<mediapipe::TimeSeriesHeader>();
  EXPECT_EQ(48000, header.sample_rate());
  EXPECT_EQ(2, header.num_channels());
  // 2 s of 48 kHz audio; 1024 is presumably the per-packet sample count for
  // this input — TODO confirm against AudioDecoder.
  EXPECT_TRUE(runner.Outputs().Tag("AUDIO").packets.size() >=
              std::ceil(48000.0 * 2 / 1024));
}
// Decodes a 2-second, 44.1 kHz stereo MP3 file and verifies the decoded
// header plus a lower bound on the number of emitted audio packets.
TEST(AudioDecoderCalculatorTest, TestMP3) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "AudioDecoderCalculator"
        input_side_packet: "INPUT_FILE_PATH:input_file_path"
        output_stream: "AUDIO:audio"
        output_stream: "AUDIO_HEADER:audio_header"
        node_options {
          [type.googleapis.com/mediapipe.AudioDecoderOptions]: {
            audio_stream { stream_index: 0 }
          }
        })");
  CalculatorRunner runner(node_config);
  runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
      file::JoinPath("./",
                     "/mediapipe/calculators/audio/"
                     "testdata/sine_wave_1k_44100_stereo_2_sec_mp3.audio"));
  MEDIAPIPE_ASSERT_OK(runner.Run());
  // The stream header must be present and typed as TimeSeriesHeader.
  MEDIAPIPE_EXPECT_OK(
      runner.Outputs()
          .Tag("AUDIO_HEADER")
          .header.ValidateAsType<mediapipe::TimeSeriesHeader>());
  const mediapipe::TimeSeriesHeader& header =
      runner.Outputs()
          .Tag("AUDIO_HEADER")
          .header.Get<mediapipe::TimeSeriesHeader>();
  EXPECT_EQ(44100, header.sample_rate());
  EXPECT_EQ(2, header.num_channels());
  // 1152 matches the MP3 frame size in samples — presumably one packet per
  // frame; TODO confirm against AudioDecoder.
  EXPECT_TRUE(runner.Outputs().Tag("AUDIO").packets.size() >=
              std::ceil(44100.0 * 2 / 1152));
}
// Decodes a 2-second, 44.1 kHz stereo AAC file and verifies the decoded
// header plus a lower bound on the number of emitted audio packets.
TEST(AudioDecoderCalculatorTest, TestAAC) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "AudioDecoderCalculator"
        input_side_packet: "INPUT_FILE_PATH:input_file_path"
        output_stream: "AUDIO:audio"
        output_stream: "AUDIO_HEADER:audio_header"
        node_options {
          [type.googleapis.com/mediapipe.AudioDecoderOptions]: {
            audio_stream { stream_index: 0 }
          }
        })");
  CalculatorRunner runner(node_config);
  runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
      file::JoinPath("./",
                     "/mediapipe/calculators/audio/"
                     "testdata/sine_wave_1k_44100_stereo_2_sec_aac.audio"));
  MEDIAPIPE_ASSERT_OK(runner.Run());
  // The stream header must be present and typed as TimeSeriesHeader.
  MEDIAPIPE_EXPECT_OK(
      runner.Outputs()
          .Tag("AUDIO_HEADER")
          .header.ValidateAsType<mediapipe::TimeSeriesHeader>());
  const mediapipe::TimeSeriesHeader& header =
      runner.Outputs()
          .Tag("AUDIO_HEADER")
          .header.Get<mediapipe::TimeSeriesHeader>();
  EXPECT_EQ(44100, header.sample_rate());
  EXPECT_EQ(2, header.num_channels());
  // 1024 matches the AAC frame size in samples — presumably one packet per
  // frame; TODO confirm against AudioDecoder.
  EXPECT_TRUE(runner.Outputs().Tag("AUDIO").packets.size() >=
              std::ceil(44100.0 * 2 / 1024));
}
} // namespace mediapipe

View File

@ -0,0 +1,403 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Basic Calculators that operate on TimeSeries streams.
#include "mediapipe/calculators/audio/basic_time_series_calculators.h"

#include <climits>
#include <cmath>
#include <cstdint>
#include <memory>

#include "Eigen/Core"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/time_series_util.h"
namespace mediapipe {
namespace {
// Returns true and stores x * y in *result when the product fits in an int;
// returns false on overflow. `result` may be null when the caller only needs
// the overflow check.
static bool SafeMultiply(int x, int y, int* result) {
  static_assert(sizeof(int64_t) >= 2 * sizeof(int),
                "Unable to detect overflow after multiplication");
  const int64_t big = static_cast<int64_t>(x) * static_cast<int64_t>(y);
  // Inclusive bounds: products exactly equal to INT_MIN or INT_MAX are
  // representable and must not be reported as overflow (the original strict
  // comparisons rejected e.g. INT_MAX * 1).
  if (big >= static_cast<int64_t>(INT_MIN) &&
      big <= static_cast<int64_t>(INT_MAX)) {
    if (result != nullptr) *result = static_cast<int>(big);
    return true;
  } else {
    return false;
  }
}
}  // namespace
// Declares the contract shared by all basic time-series calculators:
// one Matrix input stream and one Matrix output stream, each carrying a
// TimeSeriesHeader.
::mediapipe::Status BasicTimeSeriesCalculatorBase::GetContract(
    CalculatorContract* cc) {
  // Input stream with TimeSeriesHeader.
  cc->Inputs().Index(0).Set<Matrix>();
  // Output stream with TimeSeriesHeader.
  cc->Outputs().Index(0).Set<Matrix>();
  return ::mediapipe::OkStatus();
}
// Validates the input stream's TimeSeriesHeader, lets the subclass adjust a
// copy of it via MutateHeader(), and installs the result as the output
// stream's header. Fails if the input header is missing or invalid.
::mediapipe::Status BasicTimeSeriesCalculatorBase::Open(CalculatorContext* cc) {
  TimeSeriesHeader input_header;
  RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
      cc->Inputs().Index(0).Header(), &input_header));
  // Raw new is intentional: Adopt() below takes ownership of the header.
  auto output_header = new TimeSeriesHeader(input_header);
  RETURN_IF_ERROR(MutateHeader(output_header));
  cc->Outputs().Index(0).SetHeader(Adopt(output_header));
  return ::mediapipe::OkStatus();
}
// Checks the incoming matrix against the input header, transforms it via the
// subclass's ProcessMatrix(), verifies the result against the output header,
// and emits it at the input timestamp.
::mediapipe::Status BasicTimeSeriesCalculatorBase::Process(
    CalculatorContext* cc) {
  const Matrix& input = cc->Inputs().Index(0).Get<Matrix>();
  // Shape must match the header promised in Open() — both on input...
  RETURN_IF_ERROR(time_series_util::IsMatrixShapeConsistentWithHeader(
      input, cc->Inputs().Index(0).Header().Get<TimeSeriesHeader>()));
  std::unique_ptr<Matrix> output(new Matrix(ProcessMatrix(input)));
  // ...and on output, so a buggy subclass is caught per packet.
  RETURN_IF_ERROR(time_series_util::IsMatrixShapeConsistentWithHeader(
      *output, cc->Outputs().Index(0).Header().Get<TimeSeriesHeader>()));
  cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
  return ::mediapipe::OkStatus();
}
// Default implementation: the output header is an unmodified copy of the
// input header. Subclasses override this to change channel/sample counts.
::mediapipe::Status BasicTimeSeriesCalculatorBase::MutateHeader(
    TimeSeriesHeader* output_header) {
  return ::mediapipe::OkStatus();
}
// Sums an input time series across channels, producing a single-channel
// stream. This is useful for e.g. computing 'summary SAI' pitchogram
// features.
//
// Options proto: None.
class SumTimeSeriesAcrossChannelsCalculator
    : public BasicTimeSeriesCalculatorBase {
 protected:
  ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
    // The column-wise sum collapses all channels into one.
    output_header->set_num_channels(1);
    return ::mediapipe::OkStatus();
  }

  Matrix ProcessMatrix(const Matrix& input) final {
    return input.colwise().sum();
  }
};
REGISTER_CALCULATOR(SumTimeSeriesAcrossChannelsCalculator);
// Averages an input time series across channels, producing a single-channel
// stream. This is useful for e.g. converting stereo or multi-channel files
// to mono.
//
// Options proto: None.
class AverageTimeSeriesAcrossChannelsCalculator
    : public BasicTimeSeriesCalculatorBase {
 protected:
  ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
    // The column-wise mean collapses all channels into one.
    output_header->set_num_channels(1);
    return ::mediapipe::OkStatus();
  }

  Matrix ProcessMatrix(const Matrix& input) final {
    return input.colwise().mean();
  }
};
REGISTER_CALCULATOR(AverageTimeSeriesAcrossChannelsCalculator);
// Calculator that turns a (temporal) summary SAI stream -- the
// single-channel stream produced by SumTimeSeriesAcrossChannelsCalculator --
// into pitchogram frames by transposing each input packet, i.e. swapping the
// time and channel axes.
//
// Options proto: None.
class SummarySaiToPitchogramCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
    // At this point the header is still a copy of the input header, so
    // num_channels() describes the input stream.
    const int input_num_channels = output_header->num_channels();
    if (input_num_channels != 1) {
      return tool::StatusInvalid(absl::StrCat(
          "Expected single-channel input, got ", input_num_channels));
    }
    // After transposition every input sample becomes an output channel, and
    // each packet holds one sample emitted at the packet rate.
    output_header->set_num_channels(output_header->num_samples());
    output_header->set_num_samples(1);
    output_header->set_sample_rate(output_header->packet_rate());
    return ::mediapipe::OkStatus();
  }

  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    return input_matrix.transpose();
  }
};
REGISTER_CALCULATOR(SummarySaiToPitchogramCalculator);
// Calculator that flips the channel (row) order of TimeSeries packets.
// This is useful for e.g. interfacing with the speech pipeline, which uses
// the opposite channel convention to the hearing filterbanks.
//
// Options proto: None.
class ReverseChannelOrderCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    // Reverse each column in place, i.e. reverse the row (channel) order.
    Matrix flipped = input_matrix.colwise().reverse();
    return flipped;
  }
};
REGISTER_CALCULATOR(ReverseChannelOrderCalculator);
// Calculator that flattens every sample in a TimeSeries packet into one
// 'sample' column vector. This is useful for e.g. stacking several frames
// of features into a single feature vector.
//
// Options proto: None.
class FlattenPacketCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
    // The header is a copy of the input header here, so these still describe
    // the input stream.
    const int num_input_channels = output_header->num_channels();
    const int num_input_samples = output_header->num_samples();
    RET_CHECK(num_input_channels >= 0)
        << "FlattenPacketCalculator: num_input_channels < 0";
    RET_CHECK(num_input_samples >= 0)
        << "FlattenPacketCalculator: num_input_samples < 0";
    // Guard against integer overflow when computing channels * samples.
    int output_num_channels;
    RET_CHECK(SafeMultiply(num_input_channels, num_input_samples,
                           &output_num_channels))
        << "FlattenPacketCalculator: Multiplication failed.";
    // Output: one flattened sample per packet, emitted at the packet rate.
    output_header->set_num_channels(output_num_channels);
    output_header->set_num_samples(1);
    output_header->set_sample_rate(output_header->packet_rate());
    return ::mediapipe::OkStatus();
  }

  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    // Flatten by stacking full samples (columns) on top of each other, so
    // channels stay interleaved rather than grouping each channel's samples.
    const int channels_per_sample = input_matrix.rows();
    Matrix flattened(input_matrix.size(), 1);
    int row_offset = 0;
    for (int sample = 0; sample < input_matrix.cols(); ++sample) {
      flattened.middleRows(row_offset, channels_per_sample) =
          input_matrix.col(sample);
      row_offset += channels_per_sample;
    }
    return flattened;
  }
};
REGISTER_CALCULATOR(FlattenPacketCalculator);
// Calculator that removes the within-packet mean of each channel from that
// channel's samples.
//
// Options proto: None.
class SubtractMeanCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    // Broadcast the per-channel (row) means across every sample (column).
    Eigen::VectorXf channel_means = input_matrix.rowwise().mean();
    return input_matrix.colwise() - channel_means;
  }
};
REGISTER_CALCULATOR(SubtractMeanCalculator);
// Calculator that subtracts the mean over all values in a Packet (across
// all times and all channels) from each value in that Packet.
//
// Options proto: None.
class SubtractMeanAcrossChannelsCalculator
    : public BasicTimeSeriesCalculatorBase {
 protected:
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    // A single scalar mean over the whole matrix.
    const float global_mean = input_matrix.mean();
    return (input_matrix.array() - global_mean).matrix();
  }
};
REGISTER_CALCULATOR(SubtractMeanAcrossChannelsCalculator);
// Calculator to divide all values in a Packet by the average value across all
// times and channels in the packet. This is useful for normalizing
// nonnegative quantities like power, but might cause unexpected results if used
// with Packets that can contain negative numbers.
//
// If mean is exactly zero, the output will be a matrix of all ones, because
// that's what happens in other cases where all values are equal.
//
// Options proto: None.
class DivideByMeanAcrossChannelsCalculator
    : public BasicTimeSeriesCalculatorBase {
 protected:
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    auto mean = input_matrix.mean();
    if (mean != 0) {
      return input_matrix / mean;
    } else {
      // When used with nonnegative matrices, the mean will only be zero if
      // the entire matrix is exactly zero. In that case return a matrix of
      // all ones, because that's what happens in other cases where all
      // values are equal.
      return Matrix::Ones(input_matrix.rows(), input_matrix.cols());
    }
  }
};
REGISTER_CALCULATOR(DivideByMeanAcrossChannelsCalculator);
// Calculator that reduces each packet to the per-channel mean, producing a
// single-sample output per input packet.
//
// Options proto: None.
class MeanCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  // One output sample per packet, emitted at the packet rate.
  ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
    output_header->set_num_samples(1);
    output_header->set_sample_rate(output_header->packet_rate());
    return ::mediapipe::OkStatus();
  }

  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    // Average across samples (columns) for each channel (row).
    Matrix per_channel_mean = input_matrix.rowwise().mean();
    return per_channel_mean;
  }
};
REGISTER_CALCULATOR(MeanCalculator);
// Calculator that computes, per channel and per Packet, the uncorrected
// sample standard deviation: the divisor is the number of samples in the
// Packet, not (<number of samples> - 1).
//
// Options proto: None.
class StandardDeviationCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  // One output sample per packet, emitted at the packet rate.
  ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
    output_header->set_num_samples(1);
    output_header->set_sample_rate(output_header->packet_rate());
    return ::mediapipe::OkStatus();
  }

  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    // Center each channel, take the per-row L2 norm, and divide by sqrt(N)
    // to get the population (uncorrected) standard deviation.
    Eigen::VectorXf channel_means = input_matrix.rowwise().mean();
    Matrix centered = input_matrix.colwise() - channel_means;
    return centered.rowwise().norm() / std::sqrt(input_matrix.cols());
  }
};
REGISTER_CALCULATOR(StandardDeviationCalculator);
// Calculator that computes the covariance matrix of a packet. For an
// N-channel input the output is an N-by-N symmetric matrix.
//
// Options proto: None.
class CovarianceCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  // Output shape is num_channels x num_channels.
  ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
    output_header->set_num_samples(output_header->num_channels());
    return ::mediapipe::OkStatus();
  }

  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    // Remove the per-channel mean, then average the outer products over the
    // samples (uncorrected estimator: divide by N, not N - 1).
    const int num_samples = input_matrix.cols();
    Matrix centered = input_matrix -
                      input_matrix.rowwise().mean().replicate(1, num_samples);
    return (centered * centered.transpose()) / num_samples;
  }
};
REGISTER_CALCULATOR(CovarianceCalculator);
// Calculator that reduces each column (time step) of the input to its L2
// norm, producing a single-channel output stream.
//
// Options proto: None.
class L2NormCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  // One channel out: the per-sample norm.
  ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
    output_header->set_num_channels(1);
    return ::mediapipe::OkStatus();
  }

  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    Matrix column_norms = input_matrix.colwise().norm();
    return column_norms;
  }
};
REGISTER_CALCULATOR(L2NormCalculator);
// Calculator that rescales each column of the input to unit L2 norm.
//
// Options proto: None.
class L2NormalizeColumnCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    // Normalize every sample (column) independently.
    Matrix unit_columns = input_matrix.colwise().normalized();
    return unit_columns;
  }
};
REGISTER_CALCULATOR(L2NormalizeColumnCalculator);
// Calculator that divides the whole input matrix by its root-mean-square
// value.
//
// Returns the matrix as is if the RMS is <= 1E-8.
// Options proto: None.
class L2NormalizeCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    constexpr double kEpsilon = 1e-8;
    const double root_mean_square =
        std::sqrt(input_matrix.array().square().mean());
    // Guard against division by (near-)zero: pass tiny signals through.
    if (root_mean_square <= kEpsilon) {
      return input_matrix;
    }
    return input_matrix / root_mean_square;
  }
};
REGISTER_CALCULATOR(L2NormalizeCalculator);
// Calculator that divides the whole input matrix by its largest absolute
// value, so the peak magnitude of the output is 1.
//
// Returns the matrix as is if the peak is <= 1E-8.
// Options proto: None.
class PeakNormalizeCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    constexpr double kEpsilon = 1e-8;
    const double peak_magnitude = input_matrix.cwiseAbs().maxCoeff();
    // Guard against division by (near-)zero: pass tiny signals through.
    if (peak_magnitude <= kEpsilon) {
      return input_matrix;
    }
    return input_matrix / peak_magnitude;
  }
};
REGISTER_CALCULATOR(PeakNormalizeCalculator);
// Calculator that squares every element of the input time series.
//
// Options proto: None.
class ElementwiseSquareCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    // x -> x * x, applied coefficient-wise.
    return input_matrix.cwiseProduct(input_matrix);
  }
};
REGISTER_CALCULATOR(ElementwiseSquareCalculator);
// Calculator that keeps only the first floor(num_samples / 2) samples of
// each packet and drops the rest.
//
// Options proto: None.
class FirstHalfSlicerCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final {
    // The header is still a copy of the input header here.
    const int num_input_samples = output_header->num_samples();
    RET_CHECK(num_input_samples >= 0)
        << "FirstHalfSlicerCalculator: num_input_samples < 0";
    output_header->set_num_samples(num_input_samples / 2);
    return ::mediapipe::OkStatus();
  }

  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    // All channels, first half of the samples (columns).
    return input_matrix.leftCols(input_matrix.cols() / 2);
  }
};
REGISTER_CALCULATOR(FirstHalfSlicerCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,48 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Abstract base class for basic MediaPipe calculators that operate on
// TimeSeries streams and don't require any Options protos.
// Subclasses must override ProcessMatrix, and optionally
// MutateHeader.
#ifndef MEDIAPIPE_CALCULATORS_AUDIO_BASIC_TIME_SERIES_CALCULATORS_H_
#define MEDIAPIPE_CALCULATORS_AUDIO_BASIC_TIME_SERIES_CALCULATORS_H_
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
namespace mediapipe {
// Base class for simple one-in/one-out TimeSeries calculators that need no
// Options proto. The framework plumbing (contract, header propagation,
// shape checks, packet emission) lives here; subclasses supply only the
// per-packet Matrix transform via ProcessMatrix(), and optionally adjust the
// output stream header via MutateHeader().
class BasicTimeSeriesCalculatorBase : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc);
  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;

 protected:
  // Open() calls this method to mutate the output stream header. The input
  // to this function will contain a copy of the input stream header, so
  // subclasses that do not need to mutate the header do not need to override
  // it.
  virtual ::mediapipe::Status MutateHeader(TimeSeriesHeader* output_header);

  // Process() calls this method on each packet to compute the output matrix.
  virtual Matrix ProcessMatrix(const Matrix& input_matrix) = 0;
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_AUDIO_BASIC_TIME_SERIES_CALCULATORS_H_

View File

@ -0,0 +1,515 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <vector>
#include "Eigen/Core"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/util/time_series_test_util.h"
namespace mediapipe {
// Tests for SumTimeSeriesAcrossChannelsCalculator.
class SumTimeSeriesAcrossChannelsCalculatorTest
    : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override {
    calculator_name_ = "SumTimeSeriesAcrossChannelsCalculator";
  }
};

// With one channel there is nothing to sum: header and data pass through.
TEST_F(SumTimeSeriesAcrossChannelsCalculatorTest, IsNoOpOnSingleChannelInputs) {
  const TimeSeriesHeader header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 num_channels: 1 num_samples: 5");
  const Matrix input =
      Matrix::Random(header.num_channels(), header.num_samples());

  Test(header, {input}, header, {input});
}

// An all-ones input should sum to num_channels in the single output channel.
TEST_F(SumTimeSeriesAcrossChannelsCalculatorTest, ConstantPacket) {
  const TimeSeriesHeader header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 num_channels: 3 num_samples: 5");
  TimeSeriesHeader output_header(header);
  output_header.set_num_channels(1);

  Test(header,
       {Matrix::Constant(header.num_channels(), header.num_samples(), 1)},
       output_header,
       {Matrix::Constant(1, header.num_samples(), header.num_channels())});
}

// Summing is linear, so scaled/shifted input packets map to correspondingly
// scaled/shifted output packets.
TEST_F(SumTimeSeriesAcrossChannelsCalculatorTest, MultiplePackets) {
  const TimeSeriesHeader header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 num_channels: 3 num_samples: 5");
  Matrix in(header.num_channels(), header.num_samples());
  in << 10, -1, -1, 0, 0, 20, -2, 0, 1, 0, 30, -3, 1, 0, 12;
  TimeSeriesHeader output_header(header);
  output_header.set_num_channels(1);
  Matrix out(1, header.num_samples());
  out << 60, -6, 0, 1, 12;

  Test(header, {in, 2 * in, in + Matrix::Constant(in.rows(), in.cols(), 3.5f)},
       output_header,
       {out, 2 * out,
        out + Matrix::Constant(out.rows(), out.cols(),
                               3.5 * header.num_channels())});
}

// Tests for AverageTimeSeriesAcrossChannelsCalculator.
class AverageTimeSeriesAcrossChannelsCalculatorTest
    : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override {
    calculator_name_ = "AverageTimeSeriesAcrossChannelsCalculator";
  }
};

// With one channel the average equals the input: header and data pass through.
TEST_F(AverageTimeSeriesAcrossChannelsCalculatorTest,
       IsNoOpOnSingleChannelInputs) {
  const TimeSeriesHeader header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 num_channels: 1 num_samples: 5");
  const Matrix input =
      Matrix::Random(header.num_channels(), header.num_samples());

  Test(header, {input}, header, {input});
}

// Only one channel is nonzero (all ones), so the average is 1 / num_channels.
TEST_F(AverageTimeSeriesAcrossChannelsCalculatorTest, ConstantPacket) {
  const TimeSeriesHeader header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 num_channels: 3 num_samples: 5");
  TimeSeriesHeader output_header(header);
  output_header.set_num_channels(1);
  Matrix input =
      Matrix::Constant(header.num_channels(), header.num_samples(), 0.0);
  input.row(0) = Matrix::Constant(1, header.num_samples(), 1.0);

  Test(
      header, {input}, output_header,
      {Matrix::Constant(1, header.num_samples(), 1.0 / header.num_channels())});
}
// Tests for SummarySaiToPitchogramCalculator.
class SummarySaiToPitchogramCalculatorTest
    : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override {
    calculator_name_ = "SummarySaiToPitchogramCalculator";
  }
};

// A 1x3 summary-SAI packet comes out transposed (3x1), with the time and
// channel axes (and the sample rate) swapped in the header.
TEST_F(SummarySaiToPitchogramCalculatorTest, SinglePacket) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 1 num_samples: 3");
  Matrix input(1, input_header.num_samples());
  input << 3, -9, 4;
  const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 5.0 packet_rate: 5.0 num_channels: 3 num_samples: 1");
  Matrix output(input_header.num_samples(), 1);
  output << 3, -9, 4;
  Test(input_header, {input}, output_header, {output});
}

// Tests for ReverseChannelOrderCalculator.
class ReverseChannelOrderCalculatorTest
    : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override { calculator_name_ = "ReverseChannelOrderCalculator"; }
};

// With one channel there is nothing to reverse: header and data pass through.
TEST_F(ReverseChannelOrderCalculatorTest, IsNoOpOnSingleChannelInputs) {
  const TimeSeriesHeader header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 num_channels: 1 num_samples: 5");
  const Matrix input =
      Matrix::Random(header.num_channels(), header.num_samples());

  Test(header, {input}, header, {input});
}

// Channels (rows) come out in reverse order; the header is unchanged.
TEST_F(ReverseChannelOrderCalculatorTest, SinglePacket) {
  const TimeSeriesHeader header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 num_channels: 5 num_samples: 2");
  Matrix input(header.num_channels(), header.num_samples());
  input.transpose() << 1, 2, 3, 4, 5, -1, -2, -3, -4, -5;
  Matrix output(header.num_channels(), header.num_samples());
  output.transpose() << 5, 4, 3, 2, 1, -5, -4, -3, -2, -1;
  Test(header, {input}, header, {output});
}

// Tests for FlattenPacketCalculator.
class FlattenPacketCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override { calculator_name_ = "FlattenPacketCalculator"; }
};

// A 5x2 packet flattens to a 10x1 column with full samples stacked in time
// order; the output header has one sample at the packet rate.
TEST_F(FlattenPacketCalculatorTest, SinglePacket) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 2");
  Matrix input(input_header.num_channels(), input_header.num_samples());
  input.transpose() << 1, 2, 3, 4, 5, -1, -2, -3, -4, -5;
  Matrix output(10, 1);
  output << 1, 2, 3, 4, 5, -1, -2, -3, -4, -5;
  const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 10.0 packet_rate: 10.0 num_channels: 10 num_samples: 1");
  Test(input_header, {input}, output_header, {output});
}
// Tests for SubtractMeanCalculator.
class SubtractMeanCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override { calculator_name_ = "SubtractMeanCalculator"; }
};

// Each channel (row) has its own mean removed; the header is unchanged.
TEST_F(SubtractMeanCalculatorTest, SinglePacket) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 2");
  Matrix input(input_header.num_channels(), input_header.num_samples());
  Matrix output(input_header.num_channels(), input_header.num_samples());
  // clang-format off
  input.transpose() << 1,  0,  3, 0, 1,
                      -1, -2, -3, 4, 7;
  output.transpose() << 1,  1,  3, -2, -3,
                       -1, -1, -3,  2,  3;
  // clang-format on
  const TimeSeriesHeader output_header = input_header;
  Test(input_header, {input}, output_header, {output});
}

// Tests for SubtractMeanAcrossChannelsCalculator.
class SubtractMeanAcrossChannelsCalculatorTest
    : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override {
    calculator_name_ = "SubtractMeanAcrossChannelsCalculator";
  }
};

// The single global mean (3.5 here) is subtracted from every element.
TEST_F(SubtractMeanAcrossChannelsCalculatorTest, SinglePacket) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
  TimeSeriesHeader output_header(input_header);
  output_header.set_num_samples(2);

  Matrix input(input_header.num_channels(), input_header.num_samples());
  Matrix output(output_header.num_channels(), output_header.num_samples());
  // clang-format off
  input.transpose() << 1.0, 2.0, 3.0,
                       4.0, 5.0, 6.0;
  output.transpose() << 1.0 - 3.5, 2.0 - 3.5, 3.0 - 3.5,
                        4.0 - 3.5, 5.0 - 3.5, 6.0 - 3.5;
  // clang-format on

  Test(input_header, {input}, output_header, {output});
}

// Tests for DivideByMeanAcrossChannelsCalculator.
class DivideByMeanAcrossChannelsCalculatorTest
    : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override {
    calculator_name_ = "DivideByMeanAcrossChannelsCalculator";
  }
};

// Every element is divided by the single global mean (3.5 here).
TEST_F(DivideByMeanAcrossChannelsCalculatorTest, SinglePacket) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
  Matrix input(input_header.num_channels(), input_header.num_samples());
  input.transpose() << 1.0, 2.0, 3.0, 4.0, 5.0, 6.0;

  TimeSeriesHeader output_header(input_header);
  output_header.set_num_samples(2);
  Matrix output(output_header.num_channels(), output_header.num_samples());
  output.transpose() << 1.0 / 3.5, 2.0 / 3.5, 3.0 / 3.5, 4.0 / 3.5, 5.0 / 3.5,
      6.0 / 3.5;

  Test(input_header, {input}, output_header, {output});
}

// A zero-mean input (symmetric values) triggers the all-ones fallback.
TEST_F(DivideByMeanAcrossChannelsCalculatorTest, ReturnsOneForZeroMean) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
  Matrix input(input_header.num_channels(), input_header.num_samples());
  input.transpose() << -3.0, -2.0, -1.0, 1.0, 2.0, 3.0;

  TimeSeriesHeader output_header(input_header);
  output_header.set_num_samples(2);
  Matrix output(output_header.num_channels(), output_header.num_samples());
  output.transpose() << 1.0, 1.0, 1.0, 1.0, 1.0, 1.0;

  Test(input_header, {input}, output_header, {output});
}
// Tests for MeanCalculator.
class MeanCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override { calculator_name_ = "MeanCalculator"; }
};

// Each channel is reduced to its mean; the output has one sample per packet
// at the packet rate.
TEST_F(MeanCalculatorTest, SinglePacket) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
  Matrix input(input_header.num_channels(), input_header.num_samples());
  input.transpose() << 1.0, 2.0, 3.0, 4.0, 5.0, 6.0;

  TimeSeriesHeader output_header(input_header);
  output_header.set_num_samples(1);
  output_header.set_sample_rate(10.0);
  Matrix output(output_header.num_channels(), output_header.num_samples());
  output << (1.0 + 4.0) / 2, (2.0 + 5.0) / 2, (3.0 + 6.0) / 2;

  Test(input_header, {input}, output_header, {output});
}

// Tests for StandardDeviationCalculator.
class StandardDeviationCalculatorTest
    : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override { calculator_name_ = "StandardDeviationCalculator"; }
};

// Expected values use the uncorrected (divide-by-N) standard deviation.
TEST_F(StandardDeviationCalculatorTest, SinglePacket) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
  Matrix input(input_header.num_channels(), input_header.num_samples());
  input.transpose() << 0.0, 2.0, 3.0, 4.0, 5.0, 8.0;

  TimeSeriesHeader output_header(input_header);
  output_header.set_sample_rate(10.0);
  output_header.set_num_samples(1);
  Matrix output(output_header.num_channels(), output_header.num_samples());
  output << sqrt((pow(0.0 - 2.0, 2) + pow(4.0 - 2.0, 2)) / 2),
      sqrt((pow(2.0 - 3.5, 2) + pow(5.0 - 3.5, 2)) / 2),
      sqrt((pow(3.0 - 5.5, 2) + pow(8.0 - 5.5, 2)) / 2);

  Test(input_header, {input}, output_header, {output});
}

// Tests for CovarianceCalculator.
class CovarianceCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override { calculator_name_ = "CovarianceCalculator"; }
};

// A 3-channel input yields a 3x3 (symmetric) covariance matrix.
TEST_F(CovarianceCalculatorTest, SinglePacket) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
  Matrix input(input_header.num_channels(), input_header.num_samples());
  // Values are listed in row-major (channel-at-a-time) order.
  input << 1.0, 3.0, 5.0, 9.0, -1.0, -3.0;

  TimeSeriesHeader output_header(input_header);
  output_header.set_num_samples(output_header.num_channels());
  Matrix output(output_header.num_channels(), output_header.num_samples());
  output << 1, 2, -1, 2, 4, -2, -1, -2, 1;

  Test(input_header, {input}, output_header, {output});
}
// Tests for L2NormCalculator.
class L2NormCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override { calculator_name_ = "L2NormCalculator"; }
};

// Columns are Pythagorean triples, so the per-column norms are integers.
TEST_F(L2NormCalculatorTest, SinglePacket) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
  Matrix input(input_header.num_channels(), input_header.num_samples());
  input << 3, 5, 8, 4, 12, -15;

  const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 1 num_samples: 3");
  Matrix output(output_header.num_channels(), output_header.num_samples());
  output << 5, 13, 17;

  Test(input_header, {input}, output_header, {output});
}

// Tests for L2NormalizeColumnCalculator.
class L2NormalizeColumnCalculatorTest
    : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override { calculator_name_ = "L2NormalizeColumnCalculator"; }
};

TEST_F(L2NormalizeColumnCalculatorTest, SinglePacket) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
  Matrix input(input_header.num_channels(), input_header.num_samples());
  input << 0.3, 0.4, 0.8, 0.5, 0.9, 0.8;

  const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
  Matrix output(output_header.num_channels(), output_header.num_samples());
  // The values in output are column-wise L2 normalized
  // e.g.
  // |a| -> |a/sqrt(a^2 + b^2)|
  // |b|    |b/sqrt(a^2 + b^2)|
  output << 0.51449579000473022, 0.40613847970962524, 0.70710676908493042,
      0.85749292373657227, 0.91381156444549561, 0.70710676908493042;

  Test(input_header, {input}, output_header, {output});
}

// Tests for L2NormalizeCalculator.
class L2NormalizeCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override { calculator_name_ = "L2NormalizeCalculator"; }
};

TEST_F(L2NormalizeCalculatorTest, SinglePacket) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
  Matrix input(input_header.num_channels(), input_header.num_samples());
  input << 0.3, 0.4, 0.8, 0.5, 0.9, 0.8;

  const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
  Matrix output(output_header.num_channels(), output_header.num_samples());
  // The values in output are L2 normalized
  // a -> a/sqrt(a^2 + b^2 + c^2 + ...) * sqrt(matrix.cols()*matrix.rows())
  output << 0.45661166, 0.60881555, 1.21763109, 0.76101943, 1.36983498,
      1.21763109;

  Test(input_header, {input}, output_header, {output});
}

// A +/-1 matrix already has RMS 1, so normalization leaves it unchanged.
TEST_F(L2NormalizeCalculatorTest, UnitMatrixStaysUnchanged) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 3 num_samples: 5");
  Matrix input(input_header.num_channels(), input_header.num_samples());
  input << 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
      1.0, -1.0, 1.0;

  Test(input_header, {input}, input_header, {input});
}

// Tests for PeakNormalizeCalculator.
class PeakNormalizeCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override { calculator_name_ = "PeakNormalizeCalculator"; }
};

// All values are divided by the peak magnitude (0.9 here).
TEST_F(PeakNormalizeCalculatorTest, SinglePacket) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
  Matrix input(input_header.num_channels(), input_header.num_samples());
  input << 0.3, 0.4, 0.8, 0.5, 0.9, 0.8;

  const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
  Matrix output(output_header.num_channels(), output_header.num_samples());
  output << 0.33333333, 0.44444444, 0.88888889, 0.55555556, 1.0, 0.88888889;

  Test(input_header, {input}, output_header, {output});
}

// A +/-1 matrix already has peak magnitude 1, so it passes through unchanged.
TEST_F(PeakNormalizeCalculatorTest, UnitMatrixStaysUnchanged) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 3 num_samples: 5");
  Matrix input(input_header.num_channels(), input_header.num_samples());
  input << 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
      1.0, -1.0, 1.0;

  Test(input_header, {input}, input_header, {input});
}
// Tests for ElementwiseSquareCalculator.
class ElementwiseSquareCalculatorTest
    : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override { calculator_name_ = "ElementwiseSquareCalculator"; }
};

// Every element is squared independently; the header is unchanged.
TEST_F(ElementwiseSquareCalculatorTest, SinglePacket) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
  Matrix input(input_header.num_channels(), input_header.num_samples());
  input << 3, 5, 8, 4, 12, -15;

  const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
  Matrix output(output_header.num_channels(), output_header.num_samples());
  output << 9, 25, 64, 16, 144, 225;

  Test(input_header, {input}, output_header, {output});
}

// Tests for FirstHalfSlicerCalculator.
class FirstHalfSlicerCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
 protected:
  void SetUp() override { calculator_name_ = "FirstHalfSlicerCalculator"; }
};

// With an even number of samples, exactly the first half is kept.
TEST_F(FirstHalfSlicerCalculatorTest, SinglePacketEvenNumSamples) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 2");
  Matrix input(input_header.num_channels(), input_header.num_samples());
  // clang-format off
  input.transpose() << 0, 1, 2, 3, 4,
                       5, 6, 7, 8, 9;
  // clang-format on
  const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 1");
  Matrix output(output_header.num_channels(), output_header.num_samples());
  output.transpose() << 0, 1, 2, 3, 4;

  Test(input_header, {input}, output_header, {output});
}

// With an odd number of samples, floor(num_samples / 2) samples are kept:
// the trailing (middle) sample is dropped.
TEST_F(FirstHalfSlicerCalculatorTest, SinglePacketOddNumSamples) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 3");
  Matrix input(input_header.num_channels(), input_header.num_samples());
  // clang-format off
  input.transpose() << 0, 1, 2, 3, 4,
                       5, 6, 7, 8, 9,
                       0, 0, 0, 0, 0;
  // clang-format on
  const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 1");
  Matrix output(output_header.num_channels(), output_header.num_samples());
  output.transpose() << 0, 1, 2, 3, 4;

  Test(input_header, {input}, output_header, {output});
}

// Slicing is applied per packet; scaled/shifted inputs yield correspondingly
// scaled/shifted sliced outputs.
TEST_F(FirstHalfSlicerCalculatorTest, MultiplePackets) {
  const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 2");
  Matrix input(input_header.num_channels(), input_header.num_samples());
  // clang-format off
  input.transpose() << 0, 1, 2, 3, 4,
                       5, 6, 7, 8, 9;
  // clang-format on
  const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 1");
  Matrix output(output_header.num_channels(), output_header.num_samples());
  output.transpose() << 0, 1, 2, 3, 4;

  Test(input_header,
       {input, 2 * input,
        input + Matrix::Constant(input.rows(), input.cols(), 3.5f)},
       output_header,
       {output, 2 * output,
        output + Matrix::Constant(output.rows(), output.cols(), 3.5f)});
}
} // namespace mediapipe

View File

@ -0,0 +1,278 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// MediaPipe Calculator wrapper around audio/dsp/mfcc/
// classes MelFilterbank (magnitude spectrograms warped to the Mel
// approximation of the auditory frequency scale) and Mfcc (Mel Frequency
// Cepstral Coefficients, the decorrelated transform of log-Mel-spectrum
// commonly used as acoustic features in speech and other audio tasks.
// Both calculators expect as input the SQUARED_MAGNITUDE-domain outputs
// from the MediaPipe SpectrogramCalculator object.
#include <memory>
#include <vector>
#include "Eigen/Core"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "audio/dsp/mfcc/mel_filterbank.h"
#include "audio/dsp/mfcc/mfcc.h"
#include "mediapipe/calculators/audio/mfcc_mel_calculators.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/time_series_util.h"
namespace mediapipe {
namespace {
// Portable version of TimeSeriesHeader's DebugString.
std::string PortableDebugString(const TimeSeriesHeader& header) {
  // Formats the five fields of interest through absl::Substitute rather than
  // proto DebugString(), which is unavailable in lite/portable proto builds.
  const std::string kHeaderTemplate = R"(
      sample_rate: $0
      num_channels: $1
      num_samples: $2
      packet_rate: $3
      audio_sample_rate: $4
    )";
  return absl::Substitute(kHeaderTemplate, header.sample_rate(),
                          header.num_channels(), header.num_samples(),
                          header.packet_rate(), header.audio_sample_rate());
}
} // namespace
// Abstract base class for Calculators that transform feature vectors on a
// frame-by-frame basis.
// Subclasses must override pure virtual methods ConfigureTransform and
// TransformFrame.
// Input and output MediaPipe packets are matrices with one column per frame,
// and one row per feature dimension.  Each input packet results in an
// output packet with the same number of columns (but differing numbers of
// rows corresponding to the new feature space).
class FramewiseTransformCalculatorBase : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<Matrix>(
        // Sequence of Matrices, each column describing a particular time
        // frame, each row a feature dimension, with TimeSeriesHeader.
    );
    cc->Outputs().Index(0).Set<Matrix>(
        // Sequence of Matrices, each column describing a particular time
        // frame, each row a feature dimension, with TimeSeriesHeader.
    );
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;

  // Number of rows each output packet's Matrix will have.  Valid only after
  // the subclass's ConfigureTransform() has called the setter below.
  int num_output_channels(void) { return num_output_channels_; }

  // Called by subclasses from ConfigureTransform() to declare the output
  // dimensionality.
  void set_num_output_channels(int num_output_channels) {
    num_output_channels_ = num_output_channels;
  }

 private:
  // Takes header and options, and sets up state including calling
  // set_num_output_channels() on the base object.
  virtual ::mediapipe::Status ConfigureTransform(
      const TimeSeriesHeader& header, const CalculatorOptions& options) = 0;

  // Takes a vector<double> corresponding to an input frame, and
  // perform the specific transformation to produce an output frame.
  virtual void TransformFrame(const std::vector<double>& input,
                              std::vector<double>* output) const = 0;

 private:
  int num_output_channels_;
};
// Validates the input header, lets the subclass configure itself, and then
// publishes the output TimeSeriesHeader with the transformed channel count.
::mediapipe::Status FramewiseTransformCalculatorBase::Open(
    CalculatorContext* cc) {
  TimeSeriesHeader input_header;
  RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
      cc->Inputs().Index(0).Header(), &input_header));

  // Configure the subclass (which sets num_output_channels_) BEFORE
  // publishing the output header.  The previous code set the header even
  // when configuration failed, which would publish an uninitialized
  // channel count; propagate the error immediately instead.
  RETURN_IF_ERROR(ConfigureTransform(input_header, cc->Options()));

  auto output_header = new TimeSeriesHeader(input_header);
  output_header->set_num_channels(num_output_channels_);
  cc->Outputs().Index(0).SetHeader(Adopt(output_header));

  return ::mediapipe::OkStatus();
}
// Transforms each column (frame) of the input Matrix independently and emits
// a single output Matrix with the same number of columns at the input
// packet's timestamp.
::mediapipe::Status FramewiseTransformCalculatorBase::Process(
    CalculatorContext* cc) {
  const Matrix& input = cc->Inputs().Index(0).Get<Matrix>();
  const int num_frames = input.cols();
  std::unique_ptr<Matrix> output(new Matrix(num_output_channels_, num_frames));

  // The main work here is converting each column of the float Matrix
  // into a vector of doubles, which is what our target functions from
  // dsp_core consume, and doing the reverse with their output.
  std::vector<double> input_frame(input.rows());
  std::vector<double> output_frame(num_output_channels_);

  for (int frame = 0; frame < num_frames; ++frame) {
    // Copy input from Eigen::Matrix column to vector<double>.
    // (Mapping the vector's storage lets Eigen do the cast/copy in one
    // assignment.)
    Eigen::Map<Eigen::MatrixXd> input_frame_map(&input_frame[0],
                                                input_frame.size(), 1);
    input_frame_map = input.col(frame).cast<double>();

    // Perform the actual transformation.
    TransformFrame(input_frame, &output_frame);

    // Copy output from vector<double> back into a float Eigen column.
    CHECK_EQ(output_frame.size(), num_output_channels_);
    Eigen::Map<const Eigen::MatrixXd> output_frame_map(&output_frame[0],
                                                       output_frame.size(), 1);
    output->col(frame) = output_frame_map.cast<float>();
  }

  cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
  return ::mediapipe::OkStatus();
}
// Calculator wrapper around the dsp/mfcc/mfcc.cc routine.
// Take frames of squared-magnitude spectra from the SpectrogramCalculator
// and convert them into Mel Frequency Cepstral Coefficients.
//
// Example config:
// node {
//   calculator: "MfccCalculator"
//   input_stream: "spectrogram_frames_stream"
//   output_stream: "mfcc_frames_stream"
//   options {
//     [mediapipe.MfccCalculatorOptions.ext] {
//       mel_spectrum_params {
//         channel_count: 20
//         min_frequency_hertz: 125.0
//         max_frequency_hertz: 3800.0
//       }
//       mfcc_count: 13
//     }
//   }
// }
class MfccCalculator : public FramewiseTransformCalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    return FramewiseTransformCalculatorBase::GetContract(cc);
  }

 private:
  // Reads MfccCalculatorOptions, configures the wrapped audio_dsp::Mfcc
  // object, and reports the output dimensionality (mfcc_count) to the base
  // class.  Fails if the input header lacks audio_sample_rate or if
  // Mfcc::Initialize rejects the configuration.
  ::mediapipe::Status ConfigureTransform(
      const TimeSeriesHeader& header,
      const CalculatorOptions& options) override {
    MfccCalculatorOptions mfcc_options;
    time_series_util::FillOptionsExtensionOrDie(options, &mfcc_options);
    mfcc_.reset(new audio_dsp::Mfcc());
    // Each input frame is one spectrogram column; its row count is the
    // number of spectrogram bins the Mfcc object must expect.
    int input_length = header.num_channels();
    // Set up the parameters to the Mfcc object.
    set_num_output_channels(mfcc_options.mfcc_count());
    mfcc_->set_dct_coefficient_count(num_output_channels());
    mfcc_->set_upper_frequency_limit(
        mfcc_options.mel_spectrum_params().max_frequency_hertz());
    mfcc_->set_lower_frequency_limit(
        mfcc_options.mel_spectrum_params().min_frequency_hertz());
    mfcc_->set_filterbank_channel_count(
        mfcc_options.mel_spectrum_params().channel_count());
    // An upstream calculator (such as SpectrogramCalculator) must store
    // the sample rate of its input audio waveform in the TimeSeries Header.
    // audio_dsp::MelFilterBank needs to know this to
    // correctly interpret the spectrogram bins.
    if (!header.has_audio_sample_rate()) {
      return ::mediapipe::InvalidArgumentError(
          absl::StrCat("No audio_sample_rate in input TimeSeriesHeader ",
                       PortableDebugString(header)));
    }
    // Now we can initialize the Mfcc object.
    bool initialized =
        mfcc_->Initialize(input_length, header.audio_sample_rate());

    if (initialized) {
      return ::mediapipe::OkStatus();
    } else {
      return ::mediapipe::Status(mediapipe::StatusCode::kInternal,
                                 "Mfcc::Initialize returned uninitialized");
    }
  }

  // Applies the MFCC transform to a single frame (delegates to audio_dsp).
  void TransformFrame(const std::vector<double>& input,
                      std::vector<double>* output) const override {
    mfcc_->Compute(input, output);
  }

 private:
  std::unique_ptr<audio_dsp::Mfcc> mfcc_;
};
REGISTER_CALCULATOR(MfccCalculator);
// Calculator wrapper around the dsp/mfcc/mel_filterbank.cc routine.
// Take frames of squared-magnitude spectra from the SpectrogramCalculator
// and convert them into Mel-warped (linear-magnitude) spectra.
// Note: This code computes a mel-frequency filterbank, using a simple
// algorithm that gives bad results (some mel channels that are always zero)
// if you ask for too many channels.
class MelSpectrumCalculator : public FramewiseTransformCalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    return FramewiseTransformCalculatorBase::GetContract(cc);
  }

 private:
  // Reads MelSpectrumCalculatorOptions, configures the wrapped
  // audio_dsp::MelFilterbank object, and reports the output dimensionality
  // (channel_count) to the base class.  Fails if the input header lacks
  // audio_sample_rate or if MelFilterbank::Initialize rejects the
  // configuration.
  ::mediapipe::Status ConfigureTransform(
      const TimeSeriesHeader& header,
      const CalculatorOptions& options) override {
    MelSpectrumCalculatorOptions mel_spectrum_options;
    time_series_util::FillOptionsExtensionOrDie(options, &mel_spectrum_options);
    mel_filterbank_.reset(new audio_dsp::MelFilterbank());
    // Each input frame is one spectrogram column; its row count is the
    // number of spectrogram bins the filterbank must expect.
    int input_length = header.num_channels();
    set_num_output_channels(mel_spectrum_options.channel_count());

    // An upstream calculator (such as SpectrogramCalculator) must store
    // the sample rate of its input audio waveform in the TimeSeries Header.
    // audio_dsp::MelFilterBank needs to know this to
    // correctly interpret the spectrogram bins.
    if (!header.has_audio_sample_rate()) {
      return ::mediapipe::InvalidArgumentError(
          absl::StrCat("No audio_sample_rate in input TimeSeriesHeader ",
                       PortableDebugString(header)));
    }
    bool initialized = mel_filterbank_->Initialize(
        input_length, header.audio_sample_rate(), num_output_channels(),
        mel_spectrum_options.min_frequency_hertz(),
        mel_spectrum_options.max_frequency_hertz());

    if (initialized) {
      return ::mediapipe::OkStatus();
    } else {
      // The error names MelFilterbank, the class actually wrapped here
      // (the previous message said "mfcc::", copied from MfccCalculator).
      return ::mediapipe::Status(
          mediapipe::StatusCode::kInternal,
          "MelFilterbank::Initialize returned uninitialized");
    }
  }

  // Applies the mel filterbank to a single frame (delegates to audio_dsp).
  void TransformFrame(const std::vector<double>& input,
                      std::vector<double>* output) const override {
    mel_filterbank_->Compute(input, output);
  }

 private:
  std::unique_ptr<audio_dsp::MelFilterbank> mel_filterbank_;
};
REGISTER_CALCULATOR(MelSpectrumCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,50 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
// Options for MelSpectrumCalculator, which warps squared-magnitude
// spectrogram frames onto the Mel auditory frequency scale.
message MelSpectrumCalculatorOptions {
  extend CalculatorOptions {
    optional MelSpectrumCalculatorOptions ext = 78581812;
  }
  // The fields are to populate the config parameters in
  // audio/dsp/mfcc/mel_filterbank.h
  // but the names are chosen to mirror
  // audio/hearing/filterbanks/cochlea_gammatone_filterbank.proto
  // and the default values match those in
  // speech/greco3/frontend/filter_bank.proto .

  // Total number of frequency bands to use.
  optional int32 channel_count = 1 [default = 20];

  // Lower edge of lowest triangular Mel band, in Hertz.
  optional float min_frequency_hertz = 2 [default = 125.0];

  // Upper edge of highest triangular Mel band, in Hertz.
  optional float max_frequency_hertz = 3 [default = 3800.0];
}
// Options for MfccCalculator, which converts squared-magnitude spectrogram
// frames into Mel Frequency Cepstral Coefficients.
message MfccCalculatorOptions {
  extend CalculatorOptions {
    optional MfccCalculatorOptions ext = 78450441;
  }
  // Specification of the underlying mel filterbank.
  optional MelSpectrumCalculatorOptions mel_spectrum_params = 1;

  // How many MFCC coefficients to emit.
  optional uint32 mfcc_count = 2 [default = 13];
}

View File

@ -0,0 +1,149 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "Eigen/Core"
#include "mediapipe/calculators/audio/mfcc_mel_calculators.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/util/time_series_test_util.h"
namespace mediapipe {
// Use a sample rate that is unlikely to be a default somewhere.
const float kAudioSampleRate = 8800.0;
// Parameterized fixture shared by the MfccCalculator and
// MelSpectrumCalculator tests: OptionsType selects the options proto and
// CalculatorName the registered calculator under test.
template <typename OptionsType, const char* CalculatorName>
class FramewiseTransformCalculatorTest
    : public TimeSeriesCalculatorTest<OptionsType> {
 protected:
  void SetUp() override {
    this->calculator_name_ = CalculatorName;
    this->num_input_channels_ = 129;
    // This is the frame rate coming out of the SpectrogramCalculator.
    this->input_sample_rate_ = 100.0;
  }

  // Queues num_packets packets of random nonnegative data.
  // Returns the number of samples per packet.
  int GenerateRandomNonnegInputStream(int num_packets) {
    const double kSecondsPerPacket = 0.2;
    const int num_samples_per_packet =
        kSecondsPerPacket * this->input_sample_rate_;

    for (int i = 0; i < num_packets; ++i) {
      const int timestamp =
          i * kSecondsPerPacket * Timestamp::kTimestampUnitsPerSecond;
      // Mfcc, MelSpectrum expect squared-magnitude inputs, so make
      // sure the input data has no negative values.
      Matrix* sqdata = this->NewRandomMatrix(this->num_input_channels_,
                                             num_samples_per_packet);
      *sqdata = sqdata->array().square();
      this->AppendInputPacket(sqdata, timestamp);
    }
    return num_samples_per_packet;
  }

  // Verifies the shape and timestamp of every output packet.
  void CheckOutputPacketMetadata(int expected_num_channels,
                                 int expected_num_samples_per_packet) {
    int expected_timestamp = 0;
    for (const auto& packet : this->output().packets) {
      EXPECT_EQ(expected_timestamp, packet.Timestamp().Value());
      expected_timestamp += expected_num_samples_per_packet /
                            this->input_sample_rate_ *
                            Timestamp::kTimestampUnitsPerSecond;

      const Matrix& output_matrix = packet.template Get<Matrix>();

      EXPECT_EQ(output_matrix.rows(), expected_num_channels);
      EXPECT_EQ(output_matrix.cols(), expected_num_samples_per_packet);
    }
  }

  void SetupGraphAndHeader() {
    this->InitializeGraph();
    this->FillInputHeader();
  }

  // Queues a fixed number of random input packets and records the packet
  // size for later verification by CheckResults().
  void SetupRandomInputPackets() {
    constexpr int kNumPackets = 5;
    num_samples_per_packet_ = GenerateRandomNonnegInputStream(kNumPackets);
  }

  ::mediapipe::Status Run() { return this->RunGraph(); }

  // Argument is the expected number of dimensions (channels, rows) in
  // the output data from the Calculator under test, which the test should
  // know.
  void CheckResults(int expected_num_channels) {
    const auto& output_header =
        this->output().header.template Get<TimeSeriesHeader>();
    EXPECT_EQ(this->input_sample_rate_, output_header.sample_rate());
    CheckOutputPacketMetadata(expected_num_channels, num_samples_per_packet_);

    // Sanity check that output packets have non-zero energy.
    for (const auto& packet : this->output().packets) {
      const Matrix& data = packet.template Get<Matrix>();
      EXPECT_GT(data.squaredNorm(), 0);
    }
  }

  // Allows SetupRandomInputPackets() to inform CheckResults() about how
  // big the packets are supposed to be.
  int num_samples_per_packet_;
};
// Instantiate the shared fixture for MfccCalculator.
constexpr char kMfccCalculator[] = "MfccCalculator";
typedef FramewiseTransformCalculatorTest<MfccCalculatorOptions, kMfccCalculator>
    MfccCalculatorTest;

TEST_F(MfccCalculatorTest, AudioSampleRateFromInputHeader) {
  // With a valid audio_sample_rate in the input header the graph should run
  // and emit mfcc_count rows per output packet.
  audio_sample_rate_ = kAudioSampleRate;
  SetupGraphAndHeader();
  SetupRandomInputPackets();

  MEDIAPIPE_EXPECT_OK(Run());

  CheckResults(options_.mfcc_count());
}

TEST_F(MfccCalculatorTest, NoAudioSampleRate) {
  // Leave audio_sample_rate_ == kUnset, so it is not present in the
  // input TimeSeriesHeader; expect failure.
  SetupGraphAndHeader();
  SetupRandomInputPackets();

  EXPECT_FALSE(Run().ok());
}
// Instantiate the shared fixture for MelSpectrumCalculator.
constexpr char kMelSpectrumCalculator[] = "MelSpectrumCalculator";
typedef FramewiseTransformCalculatorTest<MelSpectrumCalculatorOptions,
                                         kMelSpectrumCalculator>
    MelSpectrumCalculatorTest;

TEST_F(MelSpectrumCalculatorTest, AudioSampleRateFromInputHeader) {
  // With a valid audio_sample_rate in the input header the graph should run
  // and emit channel_count rows per output packet.
  audio_sample_rate_ = kAudioSampleRate;
  SetupGraphAndHeader();
  SetupRandomInputPackets();

  MEDIAPIPE_EXPECT_OK(Run());

  CheckResults(options_.channel_count());
}

TEST_F(MelSpectrumCalculatorTest, NoAudioSampleRate) {
  // Leave audio_sample_rate_ == kUnset, so it is not present in the
  // input TimeSeriesHeader; expect failure.
  SetupGraphAndHeader();
  SetupRandomInputPackets();

  EXPECT_FALSE(Run().ok());
}
} // namespace mediapipe

View File

@ -0,0 +1,197 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Defines RationalFactorResampleCalculator.
#include "mediapipe/calculators/audio/rational_factor_resample_calculator.h"
#include "audio/dsp/resampler_rational_factor.h"
using audio_dsp::DefaultResamplingKernel;
using audio_dsp::RationalFactorResampler;
using audio_dsp::Resampler;
namespace mediapipe {
// Streams one input packet through the resampler (should_flush = false;
// the resamplers are only drained in Close()).
::mediapipe::Status RationalFactorResampleCalculator::Process(
    CalculatorContext* cc) {
  return ProcessInternal(cc->Inputs().Index(0).Get<Matrix>(), false, cc);
}
// Drains any samples still buffered in the per-channel resamplers by
// pushing an empty frame with should_flush = true.
::mediapipe::Status RationalFactorResampleCalculator::Close(
    CalculatorContext* cc) {
  // If no packet was ever processed there is nothing to flush.
  if (initial_timestamp_ != Timestamp::Unstarted()) {
    Matrix empty_input_frame(num_channels_, 0);
    return ProcessInternal(empty_input_frame, true, cc);
  }
  return ::mediapipe::OkStatus();
}
namespace {
// Copies row `channel` of `matrix` into *vec, replacing its contents.
void CopyChannelToVector(const Matrix& matrix, int channel,
                         std::vector<float>* vec) {
  const int num_samples = matrix.cols();
  vec->clear();
  vec->reserve(num_samples);
  for (int i = 0; i < num_samples; ++i) {
    vec->push_back(matrix(channel, i));
  }
}
// Writes `vec` into row `channel` of *matrix.  A zero-column matrix is first
// sized to match the vector length; otherwise the shapes must already agree.
void CopyVectorToChannel(const std::vector<float>& vec, Matrix* matrix,
                         int channel) {
  if (matrix->cols() == 0) {
    matrix->resize(matrix->rows(), vec.size());
  } else {
    CHECK_EQ(vec.size(), matrix->cols());
    CHECK_LT(channel, matrix->rows());
  }
  const int num_samples = matrix->cols();
  for (int i = 0; i < num_samples; ++i) {
    (*matrix)(channel, i) = vec[i];
  }
}
} // namespace
// Validates options and the input header, builds one resampler per channel
// (unless the rates match), publishes the output header, and resets the
// timestamp bookkeeping.
::mediapipe::Status RationalFactorResampleCalculator::Open(
    CalculatorContext* cc) {
  RationalFactorResampleCalculatorOptions resample_options;
  time_series_util::FillOptionsExtensionOrDie(cc->Options(), &resample_options);

  // target_sample_rate is the only required option.
  if (!resample_options.has_target_sample_rate()) {
    return tool::StatusInvalid(
        "resample_options doesn't have target_sample_rate.");
  }
  target_sample_rate_ = resample_options.target_sample_rate();

  TimeSeriesHeader input_header;
  RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
      cc->Inputs().Index(0).Header(), &input_header));

  source_sample_rate_ = input_header.sample_rate();
  num_channels_ = input_header.num_channels();

  // Don't create resamplers for pass-thru (sample rates are equal).
  if (source_sample_rate_ != target_sample_rate_) {
    // One independently-stateful resampler per channel, configured
    // identically.
    resampler_.resize(num_channels_);
    for (auto& r : resampler_) {
      r = ResamplerFromOptions(source_sample_rate_, target_sample_rate_,
                               resample_options);
      if (!r) {
        LOG(ERROR) << "Failed to initialize resampler.";
        return ::mediapipe::UnknownError("Failed to initialize resampler.");
      }
    }
  }

  TimeSeriesHeader* output_header = new TimeSeriesHeader(input_header);
  output_header->set_sample_rate(target_sample_rate_);
  // The resampler doesn't make guarantees about how many samples will
  // be in each packet.
  output_header->clear_packet_rate();
  output_header->clear_num_samples();

  cc->Outputs().Index(0).SetHeader(Adopt(output_header));
  // Reset the per-stream bookkeeping used by ProcessInternal() to derive
  // output packet timestamps.
  cumulative_output_samples_ = 0;
  cumulative_input_samples_ = 0;
  initial_timestamp_ = Timestamp::Unstarted();
  check_inconsistent_timestamps_ =
      resample_options.check_inconsistent_timestamps();
  return ::mediapipe::OkStatus();
}
// Shared implementation for Process() and Close().  Derives the output
// timestamp from the first input timestamp plus the number of output samples
// emitted so far, then resamples (or passes through) input_frame.
::mediapipe::Status RationalFactorResampleCalculator::ProcessInternal(
    const Matrix& input_frame, bool should_flush, CalculatorContext* cc) {
  // Latch the timestamp of the first packet; all later output timestamps
  // are offsets from it.
  if (initial_timestamp_ == Timestamp::Unstarted()) {
    initial_timestamp_ = cc->InputTimestamp();
  }

  if (check_inconsistent_timestamps_) {
    time_series_util::LogWarningIfTimestampIsInconsistent(
        cc->InputTimestamp(), initial_timestamp_, cumulative_input_samples_,
        source_sample_rate_);
  }
  Timestamp output_timestamp =
      initial_timestamp_ + ((cumulative_output_samples_ / target_sample_rate_) *
                            Timestamp::kTimestampUnitsPerSecond);

  cumulative_input_samples_ += input_frame.cols();
  std::unique_ptr<Matrix> output_frame(new Matrix(num_channels_, 0));
  if (resampler_.empty()) {
    // Sample rates were same for input and output; pass-thru.
    *output_frame = input_frame;
  } else {
    if (!Resample(input_frame, output_frame.get(), should_flush)) {
      return ::mediapipe::UnknownError("Resample() failed.");
    }
  }
  cumulative_output_samples_ += output_frame->cols();

  // The resampler may produce no samples for a given input; only emit a
  // packet when there is data to send.
  if (output_frame->cols() > 0) {
    cc->Outputs().Index(0).Add(output_frame.release(), output_timestamp);
  }
  return ::mediapipe::OkStatus();
}
// Resamples each channel (row) of input_frame independently into
// output_frame.  When should_flush is true the per-channel resamplers are
// drained instead of being fed new samples.
bool RationalFactorResampleCalculator::Resample(const Matrix& input_frame,
                                                Matrix* output_frame,
                                                bool should_flush) {
  std::vector<float> channel_input;
  std::vector<float> channel_output;
  const int num_rows = input_frame.rows();
  for (int row = 0; row < num_rows; ++row) {
    CopyChannelToVector(input_frame, row, &channel_input);
    if (should_flush) {
      resampler_[row]->Flush(&channel_output);
    } else {
      resampler_[row]->ProcessSamples(channel_input, &channel_output);
    }
    CopyVectorToChannel(channel_output, output_frame, row);
  }
  return true;
}
// static
std::unique_ptr<Resampler<float>>
RationalFactorResampleCalculator::ResamplerFromOptions(
const double source_sample_rate, const double target_sample_rate,
const RationalFactorResampleCalculatorOptions& options) {
std::unique_ptr<Resampler<float>> resampler;
const auto& rational_factor_options =
options.resampler_rational_factor_options();
std::unique_ptr<DefaultResamplingKernel> kernel;
if (rational_factor_options.has_radius() &&
rational_factor_options.has_cutoff() &&
rational_factor_options.has_kaiser_beta()) {
kernel = absl::make_unique<DefaultResamplingKernel>(
source_sample_rate, target_sample_rate,
rational_factor_options.radius(), rational_factor_options.cutoff(),
rational_factor_options.kaiser_beta());
} else {
kernel = absl::make_unique<DefaultResamplingKernel>(source_sample_rate,
target_sample_rate);
}
// Set large enough so that the resampling factor between common sample
// rates (e.g. 8kHz, 16kHz, 22.05kHz, 32kHz, 44.1kHz, 48kHz) is exact, and
// that any factor is represented with error less than 0.025%.
const int kMaxDenominator = 2000;
resampler = absl::make_unique<RationalFactorResampler<float>>(
*kernel, kMaxDenominator);
if (resampler != nullptr && !resampler->Valid()) {
resampler = std::unique_ptr<Resampler<float>>();
}
return resampler;
}
REGISTER_CALCULATOR(RationalFactorResampleCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,106 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_AUDIO_RATIONAL_FACTOR_RESAMPLE_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_AUDIO_RATIONAL_FACTOR_RESAMPLE_CALCULATOR_H_
#include <algorithm>
#include <memory>
#include <vector>
#include "Eigen/Core"
#include "absl/strings/str_cat.h"
#include "audio/dsp/resampler.h"
#include "mediapipe/calculators/audio/rational_factor_resample_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/util/time_series_util.h"
namespace mediapipe {
// MediaPipe Calculator for resampling a (vector-valued)
// input time series with a uniform sample rate.  The output
// stream's sampling rate is specified by target_sample_rate in the
// RationalFactorResampleCalculatorOptions.  The output time series may have
// a varying number of samples per frame.
class RationalFactorResampleCalculator : public CalculatorBase {
 public:
  // Grants tests access to protected members (see bottom of this header).
  struct TestAccess;

  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<Matrix>(
        // Single input stream with TimeSeriesHeader.
    );
    cc->Outputs().Index(0).Set<Matrix>(
        // Resampled stream with TimeSeriesHeader.
    );
    return ::mediapipe::OkStatus();
  }
  // Returns FAIL if the input stream header is invalid or if the
  // resampler cannot be initialized.
  ::mediapipe::Status Open(CalculatorContext* cc) override;
  // Resamples a packet of TimeSeries data.  Returns FAIL if the
  // resampler state becomes inconsistent.
  ::mediapipe::Status Process(CalculatorContext* cc) override;
  // Flushes any remaining state.  Returns FAIL if the resampler state
  // becomes inconsistent.
  ::mediapipe::Status Close(CalculatorContext* cc) override;

 protected:
  typedef audio_dsp::Resampler<float> ResamplerType;

  // Returns a Resampler<float> implementation specified by the
  // RationalFactorResampleCalculatorOptions proto.  Returns null if the
  // options specify an invalid resampler.
  static std::unique_ptr<ResamplerType> ResamplerFromOptions(
      const double source_sample_rate, const double target_sample_rate,
      const RationalFactorResampleCalculatorOptions& options);

  // Does Timestamp bookkeeping and resampling common to Process() and
  // Close().  Returns FAIL if the resampler state becomes
  // inconsistent.
  ::mediapipe::Status ProcessInternal(const Matrix& input_frame,
                                      bool should_flush, CalculatorContext* cc);

  // Uses the internal resampler_ objects to actually resample each
  // row of the input TimeSeries.  Returns false if the resampler
  // state becomes inconsistent.
  bool Resample(const Matrix& input_frame, Matrix* output_frame,
                bool should_flush);

  // Sample rates in Hz: source from the input header, target from options.
  double source_sample_rate_;
  double target_sample_rate_;
  // Running sample counts used to derive output packet timestamps.
  int64 cumulative_input_samples_;
  int64 cumulative_output_samples_;
  // Timestamp of the first input packet; Timestamp::Unstarted() before then.
  Timestamp initial_timestamp_;
  bool check_inconsistent_timestamps_;
  int num_channels_;
  // One resampler per input channel; left empty when source and target
  // sample rates are equal (pass-through mode).
  std::vector<std::unique_ptr<ResamplerType>> resampler_;
};
// Test-only access to RationalFactorResampleCalculator methods.
struct RationalFactorResampleCalculator::TestAccess {
  // Forwards to the protected static factory so unit tests can build a
  // reference resampler with the same options as the calculator.
  static std::unique_ptr<ResamplerType> ResamplerFromOptions(
      const double source_sample_rate, const double target_sample_rate,
      const RationalFactorResampleCalculatorOptions& options) {
    return RationalFactorResampleCalculator::ResamplerFromOptions(
        source_sample_rate, target_sample_rate, options);
  }
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_AUDIO_RATIONAL_FACTOR_RESAMPLE_CALCULATOR_H_

View File

@ -0,0 +1,46 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
// Options for RationalFactorResampleCalculator, which resamples a uniform
// time series to a new sample rate.
message RationalFactorResampleCalculatorOptions {
  extend CalculatorOptions {
    optional RationalFactorResampleCalculatorOptions ext = 259760074;
  }

  // target_sample_rate is the sample rate, in Hertz, of the output
  // stream.  Required.  Must be greater than 0.
  optional double target_sample_rate = 1;

  // Parameters for initializing the RationalFactorResampler.  See
  // RationalFactorResampler for more details.
  message ResamplerRationalFactorOptions {
    // Kernel radius in units of input samples.
    optional double radius = 1;
    // Anti-aliasing cutoff frequency in Hertz.  A reasonable setting is
    // 0.45 * min(input_sample_rate, output_sample_rate).
    optional double cutoff = 2;
    // The Kaiser beta parameter for the kernel window.
    optional double kaiser_beta = 3 [default = 6.0];
  }
  // All three kernel parameters must be set to take effect; otherwise the
  // calculator falls back to default kernel parameters for the rate pair.
  optional ResamplerRationalFactorOptions resampler_rational_factor_options = 2;

  // Set to false to disable checks for jitter in timestamp values.  Useful
  // with live audio input.
  optional bool check_inconsistent_timestamps = 3 [default = true];
}

View File

@ -0,0 +1,247 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/audio/rational_factor_resample_calculator.h"
#include <math.h>
#include <algorithm>
#include <string>
#include <vector>
#include "Eigen/Core"
#include "audio/dsp/signal_vector_util.h"
#include "mediapipe/calculators/audio/rational_factor_resample_calculator.pb.h"
#include "mediapipe/framework//tool/validate_type.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/time_series_test_util.h"
namespace mediapipe {
namespace {
const int kInitialTimestampOffsetMilliseconds = 4;
class RationalFactorResampleCalculatorTest
: public TimeSeriesCalculatorTest<RationalFactorResampleCalculatorOptions> {
protected:
  // Configures the fixture: calculator under test plus a 4 kHz, 3-channel
  // input stream.
  void SetUp() override {
    calculator_name_ = "RationalFactorResampleCalculator";
    input_sample_rate_ = 4000.0;
    num_input_channels_ = 3;
  }
  // Expects two vectors whose lengths are almost the same and whose
  // elements are equal (for indices that are present in both).
  //
  // This is useful because the resampler doesn't make precise
  // guarantees about its output size.
  void ExpectVectorMostlyFloatEq(const std::vector<float>& expected,
                                 const std::vector<float>& actual) {
    // Lengths should be close, but don't have to be equal.
    ASSERT_NEAR(expected.size(), actual.size(), 1);
    // Compare only the overlapping prefix of the two vectors.
    for (int i = 0; i < std::min(expected.size(), actual.size()); ++i) {
      EXPECT_FLOAT_EQ(expected[i], actual[i]) << " where i=" << i << ".";
    }
  }
// Returns a float value with the sample, channel, and timestamp
// separated by a few orders of magnitude, for easy parsing by
// humans.
double TestValue(int sample, int channel, int timestamp_in_microseconds) {
return timestamp_in_microseconds * 100.0 + sample + channel / 10.0;
}
// Caller takes ownership of the returned value.
Matrix* NewTestFrame(int num_channels, int num_samples, int timestamp) {
auto matrix = new Matrix(num_channels, num_samples);
for (int c = 0; c < num_channels; ++c) {
for (int i = 0; i < num_samples; ++i) {
(*matrix)(c, i) = TestValue(i, c, timestamp);
}
}
return matrix;
}
  // Initializes and runs the test graph.
  ::mediapipe::Status Run(double output_sample_rate) {
    options_.set_target_sample_rate(output_sample_rate);
    InitializeGraph();

    FillInputHeader();
    concatenated_input_samples_.resize(num_input_channels_, 0);
    num_input_samples_ = 0;
    // Feed five packets of increasing size (10, 20, ..., 50 samples).
    for (int i = 0; i < 5; ++i) {
      int packet_size = (i + 1) * 10;
      int timestamp = kInitialTimestampOffsetMilliseconds +
                      num_input_samples_ / input_sample_rate_ *
                          Timestamp::kTimestampUnitsPerSecond;
      Matrix* data_frame =
          NewTestFrame(num_input_channels_, packet_size, timestamp);

      // Keep a reference copy of the input.
      //
      // conservativeResize() is needed here to preserve the existing
      // data.  Eigen's resize() resizes without preserving data.
      concatenated_input_samples_.conservativeResize(
          num_input_channels_, num_input_samples_ + packet_size);
      concatenated_input_samples_.rightCols(packet_size) = *data_frame;
      num_input_samples_ += packet_size;

      AppendInputPacket(data_frame, timestamp);
    }

    return RunGraph();
  }
  // Checks that the total number of output samples across all packets is
  // consistent with resampling num_input_samples_ by
  // output_sample_rate / input_sample_rate_, within the resampler's slack.
  void CheckOutputLength(double output_sample_rate) {
    double factor = output_sample_rate / input_sample_rate_;
    int num_output_samples = 0;
    for (const Packet& packet : output().packets) {
      num_output_samples += packet.Get<Matrix>().cols();
    }
    // The exact number of expected samples may vary based on the implementation
    // of the resampler since the exact value is not an integer.
    // TODO: Reduce this offset to + 1 once cl/185829520 is submitted.
    const double expected_num_output_samples = num_input_samples_ * factor;
    EXPECT_LE(ceil(expected_num_output_samples), num_output_samples);
    EXPECT_GE(ceil(expected_num_output_samples) + 11, num_output_samples);
  }
  // Checks that output timestamps are consistent with the
  // output_sample_rate and output packet sizes.
  void CheckOutputPacketTimestamps(double output_sample_rate) {
    int num_output_samples = 0;
    for (const Packet& packet : output().packets) {
      // Each packet should be stamped at the time of its first sample:
      // the initial offset plus the samples emitted so far, converted
      // to timestamp ticks.
      const int expected_timestamp = kInitialTimestampOffsetMilliseconds +
                                     num_output_samples / output_sample_rate *
                                         Timestamp::kTimestampUnitsPerSecond;
      // Allow one tick of rounding slop.
      EXPECT_NEAR(expected_timestamp, packet.Timestamp().Value(), 1);
      num_output_samples += packet.Get<Matrix>().cols();
    }
  }
  // Checks that output values from the calculator (which resamples
  // packet-by-packet) are consistent with resampling the entire
  // signal at once.
  void CheckOutputValues(double output_sample_rate) {
    for (int i = 0; i < num_input_channels_; ++i) {
      // Build an independent reference resampler from the same options.
      auto verification_resampler =
          RationalFactorResampleCalculator::TestAccess::ResamplerFromOptions(
              input_sample_rate_, output_sample_rate, options_);
      std::vector<float> input_data;
      for (int j = 0; j < num_input_samples_; ++j) {
        input_data.push_back(concatenated_input_samples_(i, j));
      }
      // Resample the whole channel in one shot, including the final Flush.
      std::vector<float> expected_resampled_data;
      std::vector<float> temp;
      verification_resampler->ProcessSamples(input_data, &temp);
      audio_dsp::VectorAppend(&expected_resampled_data, temp);
      verification_resampler->Flush(&temp);
      audio_dsp::VectorAppend(&expected_resampled_data, temp);
      // Concatenate this channel's row across all output packets.
      std::vector<float> actual_resampled_data;
      for (const Packet& packet : output().packets) {
        Matrix output_frame_row = packet.Get<Matrix>().row(i);
        actual_resampled_data.insert(
            actual_resampled_data.end(), &output_frame_row(0),
            &output_frame_row(0) + output_frame_row.cols());
      }
      ExpectVectorMostlyFloatEq(expected_resampled_data, actual_resampled_data);
    }
  }
void CheckOutputHeaders(double output_sample_rate) {
const TimeSeriesHeader& output_header =
output().header.Get<TimeSeriesHeader>();
TimeSeriesHeader expected_header;
expected_header.set_sample_rate(output_sample_rate);
expected_header.set_num_channels(num_input_channels_);
EXPECT_THAT(output_header, mediapipe::EqualsProto(expected_header));
}
  // Runs the full battery of output checks (length, timestamps, values,
  // headers) against a completed graph run at the given rate.
  void CheckOutput(double output_sample_rate) {
    CheckOutputLength(output_sample_rate);
    CheckOutputPacketTimestamps(output_sample_rate);
    CheckOutputValues(output_sample_rate);
    CheckOutputHeaders(output_sample_rate);
  }
  // Checks that the output equals the input unchanged, used when the target
  // rate matches the input rate and the calculator should pass data through.
  void CheckOutputUnchanged() {
    for (int i = 0; i < num_input_channels_; ++i) {
      // For passthrough the "expected" signal is just the input itself.
      std::vector<float> expected_resampled_data;
      for (int j = 0; j < num_input_samples_; ++j) {
        expected_resampled_data.push_back(concatenated_input_samples_(i, j));
      }
      // Concatenate this channel's row across all output packets.
      std::vector<float> actual_resampled_data;
      for (const Packet& packet : output().packets) {
        Matrix output_frame_row = packet.Get<Matrix>().row(i);
        actual_resampled_data.insert(
            actual_resampled_data.end(), &output_frame_row(0),
            &output_frame_row(0) + output_frame_row.cols());
      }
      ExpectVectorMostlyFloatEq(expected_resampled_data, actual_resampled_data);
    }
  }
int num_input_samples_;
Matrix concatenated_input_samples_;
};
// Upsampling by a non-integer factor (1.9x) should satisfy all output
// checks (length, timestamps, values, headers).
TEST_F(RationalFactorResampleCalculatorTest, Upsample) {
  const double kUpsampleRate = input_sample_rate_ * 1.9;
  MEDIAPIPE_ASSERT_OK(Run(kUpsampleRate));
  CheckOutput(kUpsampleRate);
}
// Downsampling by a non-integer factor (1/1.9) should satisfy all output
// checks.
TEST_F(RationalFactorResampleCalculatorTest, Downsample) {
  const double kDownsampleRate = input_sample_rate_ / 1.9;
  MEDIAPIPE_ASSERT_OK(Run(kDownsampleRate));
  CheckOutput(kDownsampleRate);
}
// Exercises an exact integer (2x) rate ratio.
TEST_F(RationalFactorResampleCalculatorTest, UsesRationalFactorResampler) {
  const double kUpsampleRate = input_sample_rate_ * 2;
  MEDIAPIPE_ASSERT_OK(Run(kUpsampleRate));
  CheckOutput(kUpsampleRate);
}
// When the target rate equals the input rate the data should come through
// unmodified.
TEST_F(RationalFactorResampleCalculatorTest, PassthroughIfSampleRateUnchanged) {
  const double kUpsampleRate = input_sample_rate_;
  MEDIAPIPE_ASSERT_OK(Run(kUpsampleRate));
  CheckOutputUnchanged();
}
// A negative target sample rate must make the graph run fail.
TEST_F(RationalFactorResampleCalculatorTest, FailsOnBadTargetRate) {
  ASSERT_FALSE(Run(-999.9).ok());  // Invalid output sample rate.
}
// Running the graph with zero input packets should succeed and produce
// zero output packets.
TEST_F(RationalFactorResampleCalculatorTest, DoesNotDieOnEmptyInput) {
  options_.set_target_sample_rate(input_sample_rate_);
  InitializeGraph();
  FillInputHeader();
  MEDIAPIPE_ASSERT_OK(RunGraph());
  EXPECT_TRUE(output().packets.empty());
}
} // anonymous namespace
} // namespace mediapipe

View File

@ -0,0 +1,425 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Defines SpectrogramCalculator.
#include <math.h>
#include <complex>
#include <deque>
#include <memory>
#include <string>
#include "Eigen/Core"
#include "absl/strings/string_view.h"
#include "audio/dsp/spectrogram/spectrogram.h"
#include "audio/dsp/window_functions.h"
#include "mediapipe/calculators/audio/spectrogram_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/core_proto_inc.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/source_location.h"
#include "mediapipe/framework/port/status_builder.h"
#include "mediapipe/util/time_series_util.h"
namespace mediapipe {
// MediaPipe Calculator for computing the "spectrogram" (short-time Fourier
// transform squared-magnitude, by default) of a multichannel input
// time series, including optionally overlapping frames. Options are
// specified in SpectrogramCalculatorOptions proto (where names are chosen
// to mirror TimeSeriesFramerCalculator):
//
// Result is a MatrixData record (for single channel input and when the
// allow_multichannel_input flag is false), or a vector of MatrixData records,
// one for each channel (when the allow_multichannel_input flag is set). The
// rows of each spectrogram matrix correspond to the n_fft/2+1 unique complex
// values, or squared/linear/dB magnitudes, depending on the output_type option.
// Each input packet will result in zero or one output packets, each containing
// one Matrix for each channel of the input, where each Matrix has one or more
// columns of spectral values, one for each complete frame of input samples. If
// the input packet contains too few samples to trigger a new output frame, no
// output packet is generated (since zero-length packets are not legal since
// they would result in timestamps that were equal, not strictly increasing).
//
// Output packet Timestamps are set to the beginning of each frame. This is to
// allow calculators downstream from SpectrogramCalculator to have aligned
// Timestamps regardless of a packet's signal length.
//
// Both frame_duration_seconds and frame_overlap_seconds will be
// rounded to the nearest integer number of samples. Conseqently, all output
// frames will be based on the same number of input samples, and each
// analysis frame will advance from its predecessor by the same time step.
class SpectrogramCalculator : public CalculatorBase {
 public:
  // Declares the input (Matrix time series) and the output stream type,
  // which depends on output_type (real Matrix vs. Eigen::MatrixXcf) and on
  // allow_multichannel_input (bare matrix vs. std::vector of per-channel
  // matrices).
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<Matrix>(
        // Input stream with TimeSeriesHeader.
    );
    SpectrogramCalculatorOptions spectrogram_options;
    time_series_util::FillOptionsExtensionOrDie(cc->Options(),
                                                &spectrogram_options);
    if (!spectrogram_options.allow_multichannel_input()) {
      if (spectrogram_options.output_type() ==
          SpectrogramCalculatorOptions::COMPLEX) {
        cc->Outputs().Index(0).Set<Eigen::MatrixXcf>(
            // Complex spectrogram frames with TimeSeriesHeader.
        );
      } else {
        cc->Outputs().Index(0).Set<Matrix>(
            // Spectrogram frames with TimeSeriesHeader.
        );
      }
    } else {
      if (spectrogram_options.output_type() ==
          SpectrogramCalculatorOptions::COMPLEX) {
        cc->Outputs().Index(0).Set<std::vector<Eigen::MatrixXcf>>(
            // Complex spectrogram frames with MultiStreamTimeSeriesHeader.
        );
      } else {
        cc->Outputs().Index(0).Set<std::vector<Matrix>>(
            // Spectrogram frames with MultiStreamTimeSeriesHeader.
        );
      }
    }
    return ::mediapipe::OkStatus();
  }
  // Returns FAIL if the input stream header is invalid.
  ::mediapipe::Status Open(CalculatorContext* cc) override;
  // Outputs at most one packet consisting of a single Matrix with one or
  // more columns containing the spectral values from as many input frames
  // as are completed by the input samples. Always returns OK.
  ::mediapipe::Status Process(CalculatorContext* cc) override;
  // Performs zero-padding and processing of any remaining samples
  // if pad_final_packet is set.
  // Returns OK.
  ::mediapipe::Status Close(CalculatorContext* cc) override;
 private:
  Timestamp CurrentOutputTimestamp() {
    // Timestamp for the next frame to be emitted: the first input
    // timestamp advanced by the completed frame count times the frame
    // step, converted to timestamp ticks.
    // NOTE(review): an earlier comment here claimed this is the *center*
    // of the frame, but no half-window offset is applied below; this is
    // the frame's start time, consistent with the file-level comment.
    return initial_input_timestamp_ +
           round(cumulative_completed_frames_ * frame_step_samples() *
                 Timestamp::kTimestampUnitsPerSecond / input_sample_rate_);
  }
  // Number of samples the analysis window advances per output frame.
  int frame_step_samples() const {
    return frame_duration_samples_ - frame_overlap_samples_;
  }
  // Take the next set of input samples, already translated into a
  // vector<float> and pass them to the spectrogram object.
  // Convert the output of the spectrogram object into a Matrix (or an
  // Eigen::MatrixXcf if complex-valued output is requested) and pass to
  // MediaPipe output.
  ::mediapipe::Status ProcessVector(const Matrix& input_stream,
                                    CalculatorContext* cc);
  // Templated function to process either real- or complex-output spectrogram.
  template <class OutputMatrixType>
  ::mediapipe::Status ProcessVectorToOutput(
      const Matrix& input_stream,
      const OutputMatrixType postprocess_output_fn(const OutputMatrixType&),
      CalculatorContext* cc);
  // Sample rate of the input stream, from its TimeSeriesHeader.
  double input_sample_rate_;
  // Whether to zero-pad and flush the final partial frame in Close().
  bool pad_final_packet_;
  int frame_duration_samples_;
  int frame_overlap_samples_;
  // How many samples we've been passed, used for checking input time stamps.
  int64 cumulative_input_samples_;
  // How many frames we've emitted, used for calculating output time stamps.
  int64 cumulative_completed_frames_;
  // Timestamp of the first input packet; base for output timestamps.
  Timestamp initial_input_timestamp_;
  int num_input_channels_;
  // How many frequency bins we emit (=N_FFT/2 + 1).
  int num_output_channels_;
  // Which output type?
  int output_type_;
  // Output type: mono or multichannel.
  bool allow_multichannel_input_;
  // Vector of Spectrogram objects, one for each channel.
  std::vector<std::unique_ptr<audio_dsp::Spectrogram>> spectrogram_generators_;
  // Fixed scale factor applied to output values (regardless of type).
  double output_scale_;
  // Conversion factor from ln(power) to decibels; defined below.
  static const float kLnPowerToDb;
};
// Makes the calculator constructible by name in CalculatorGraphConfig.
REGISTER_CALCULATOR(SpectrogramCalculator);
// Factor to convert ln(magnitude_squared) to deciBels = 10.0/ln(10.0).
const float SpectrogramCalculator::kLnPowerToDb = 4.342944819032518;
// Validates the options and input header, builds the per-channel
// Spectrogram objects, and publishes the output stream header.
// Returns an InvalidArgument status for bad options or input headers.
::mediapipe::Status SpectrogramCalculator::Open(CalculatorContext* cc) {
  SpectrogramCalculatorOptions spectrogram_options;
  time_series_util::FillOptionsExtensionOrDie(cc->Options(),
                                              &spectrogram_options);
  if (spectrogram_options.frame_duration_seconds() <= 0.0) {
    // BUG FIX: these status builders were previously constructed and
    // discarded without `return`, so validation failures were silently
    // ignored. Each check now actually fails Open(). Also report
    // frame_duration_seconds here (it previously printed
    // frame_overlap_seconds in a frame_duration error).
    return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
           << "Invalid or missing frame_duration_seconds.\n"
              "frame_duration_seconds: "
           << spectrogram_options.frame_duration_seconds();
  }
  if (spectrogram_options.frame_overlap_seconds() >=
      spectrogram_options.frame_duration_seconds()) {
    return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
           << "Invalid frame_overlap_seconds.\nframe_overlap_seconds: "
           << spectrogram_options.frame_overlap_seconds()
           << "\nframe_duration_seconds: "
           << spectrogram_options.frame_duration_seconds();
  }
  if (spectrogram_options.frame_overlap_seconds() < 0.0) {
    return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
           << "Frame_overlap_seconds is < 0.0.\nframe_overlap_seconds: "
           << spectrogram_options.frame_overlap_seconds();
  }
  TimeSeriesHeader input_header;
  RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
      cc->Inputs().Index(0).Header(), &input_header));
  input_sample_rate_ = input_header.sample_rate();
  num_input_channels_ = input_header.num_channels();
  if (!spectrogram_options.allow_multichannel_input() &&
      num_input_channels_ != 1) {
    return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
           << "The current setting only supports single-channel input. Please "
              "set allow_multichannel_input.\n";
  }
  frame_duration_samples_ =
      round(spectrogram_options.frame_duration_seconds() * input_sample_rate_);
  frame_overlap_samples_ =
      round(spectrogram_options.frame_overlap_seconds() * input_sample_rate_);
  pad_final_packet_ = spectrogram_options.pad_final_packet();
  output_type_ = spectrogram_options.output_type();
  allow_multichannel_input_ = spectrogram_options.allow_multichannel_input();
  output_scale_ = spectrogram_options.output_scale();
  std::vector<double> window;
  switch (spectrogram_options.window_type()) {
    case SpectrogramCalculatorOptions::HANN:
      audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_,
                                                 &window);
      break;
    case SpectrogramCalculatorOptions::HAMMING:
      audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples_,
                                                    &window);
      break;
  }
  // Propagate settings down to the actual Spectrogram object.
  spectrogram_generators_.clear();
  for (int i = 0; i < num_input_channels_; i++) {
    spectrogram_generators_.push_back(
        std::unique_ptr<audio_dsp::Spectrogram>(new audio_dsp::Spectrogram()));
    spectrogram_generators_[i]->Initialize(window, frame_step_samples());
  }
  num_output_channels_ =
      spectrogram_generators_[0]->output_frequency_channels();
  std::unique_ptr<TimeSeriesHeader> output_header(
      new TimeSeriesHeader(input_header));
  // Store the actual sample rate of the input audio in the TimeSeriesHeader
  // so that subsequent calculators can figure out the frequency scale of
  // our output.
  output_header->set_audio_sample_rate(input_sample_rate_);
  // Setup rest of output header.
  output_header->set_num_channels(num_output_channels_);
  output_header->set_sample_rate(input_sample_rate_ / frame_step_samples());
  // Although we usually generate one output packet for each input
  // packet, this might not be true for input packets whose size is smaller
  // than the analysis window length. So we clear output_header.packet_rate
  // because we can't guarantee a constant packet rate. Similarly, the number
  // of output frames per packet depends on the input packet, so we also clear
  // output_header.num_samples.
  output_header->clear_packet_rate();
  output_header->clear_num_samples();
  if (!spectrogram_options.allow_multichannel_input()) {
    cc->Outputs().Index(0).SetHeader(Adopt(output_header.release()));
  } else {
    std::unique_ptr<MultiStreamTimeSeriesHeader> multichannel_output_header(
        new MultiStreamTimeSeriesHeader());
    *multichannel_output_header->mutable_time_series_header() = *output_header;
    multichannel_output_header->set_num_streams(num_input_channels_);
    cc->Outputs().Index(0).SetHeader(
        Adopt(multichannel_output_header.release()));
  }
  // BUG FIX: cumulative_input_samples_ was never initialized, so Process()
  // and Close() accumulated onto an indeterminate value.
  cumulative_input_samples_ = 0;
  cumulative_completed_frames_ = 0;
  initial_input_timestamp_ = Timestamp::Unstarted();
  return ::mediapipe::OkStatus();
}
// Records the first input timestamp, validates the channel count of the
// incoming matrix, and forwards the samples to ProcessVector().
::mediapipe::Status SpectrogramCalculator::Process(CalculatorContext* cc) {
  if (initial_input_timestamp_ == Timestamp::Unstarted()) {
    initial_input_timestamp_ = cc->InputTimestamp();
  }
  const Matrix& input_stream = cc->Inputs().Index(0).Get<Matrix>();
  if (input_stream.rows() != num_input_channels_) {
    // BUG FIX: this status builder was previously constructed but never
    // returned, so a channel-count mismatch was silently ignored. Also
    // added the missing space before "channels" in the message.
    return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
           << "Number of input channels do not correspond to the number of rows "
           << "in the input matrix: " << num_input_channels_ << " channels vs "
           << input_stream.rows() << " rows";
  }
  cumulative_input_samples_ += input_stream.cols();
  return ProcessVector(input_stream, cc);
}
// Runs each channel of input_stream through its Spectrogram generator,
// applies postprocess_output_fn and output_scale_ to every frame, and
// emits one packet (per-channel vector, or a single matrix in
// single-channel mode) if at least one complete frame was produced.
template <class OutputMatrixType>
::mediapipe::Status SpectrogramCalculator::ProcessVectorToOutput(
    const Matrix& input_stream,
    const OutputMatrixType postprocess_output_fn(const OutputMatrixType&),
    CalculatorContext* cc) {
  std::unique_ptr<std::vector<OutputMatrixType>> spectrogram_matrices(
      new std::vector<OutputMatrixType>());
  std::vector<std::vector<typename OutputMatrixType::Scalar>> output_vectors;
  // Compute a spectrogram for each channel.
  int num_output_time_frames;
  for (int channel = 0; channel < input_stream.rows(); ++channel) {
    output_vectors.clear();
    // Copy one row (channel) of the input matrix into the std::vector.
    std::vector<float> input_vector(input_stream.cols());
    Eigen::Map<Matrix>(&input_vector[0], 1, input_vector.size()) =
        input_stream.row(channel);
    if (!spectrogram_generators_[channel]->ComputeSpectrogram(
            input_vector, &output_vectors)) {
      return ::mediapipe::Status(mediapipe::StatusCode::kInternal,
                                 "Spectrogram returned failure");
    }
    if (channel == 0) {
      // Record the number of time frames we expect from each channel.
      num_output_time_frames = output_vectors.size();
    } else {
      RET_CHECK_EQ(output_vectors.size(), num_output_time_frames)
          << "Inconsistent spectrogram time frames for channel " << channel;
    }
    // Skip remaining processing if there are too few input samples to trigger
    // any output frames.
    if (!output_vectors.empty()) {
      // Translate the returned values into a matrix of output frames.
      OutputMatrixType output_frames(num_output_channels_,
                                     output_vectors.size());
      for (int frame = 0; frame < output_vectors.size(); ++frame) {
        Eigen::Map<const OutputMatrixType> frame_map(
            &output_vectors[frame][0], output_vectors[frame].size(), 1);
        // The underlying dsp object returns squared magnitudes; here
        // we optionally translate to linear magnitude or dB.
        output_frames.col(frame) =
            output_scale_ * postprocess_output_fn(frame_map);
      }
      spectrogram_matrices->push_back(output_frames);
    }
  }
  // If the input is very short, there may not be enough accumulated,
  // unprocessed samples to cause any new frames to be generated by
  // the spectrogram object. If so, we don't want to emit
  // a packet at all.
  if (!spectrogram_matrices->empty()) {
    RET_CHECK_EQ(spectrogram_matrices->size(), input_stream.rows())
        << "Inconsistent number of spectrogram channels.";
    if (allow_multichannel_input_) {
      cc->Outputs().Index(0).Add(spectrogram_matrices.release(),
                                 CurrentOutputTimestamp());
    } else {
      cc->Outputs().Index(0).Add(
          new OutputMatrixType(spectrogram_matrices->at(0)),
          CurrentOutputTimestamp());
    }
    // All channels produced the same frame count (checked above), so the
    // last channel's output_vectors.size() is the per-channel frame count.
    cumulative_completed_frames_ += output_vectors.size();
  }
  return ::mediapipe::OkStatus();
}
// Dispatches to ProcessVectorToOutput with the postprocessing lambda
// matching output_type_: identity for COMPLEX and SQUARED_MAGNITUDE,
// element-wise sqrt for LINEAR_MAGNITUDE, log scaling for DECIBELS.
::mediapipe::Status SpectrogramCalculator::ProcessVector(
    const Matrix& input_stream, CalculatorContext* cc) {
  switch (output_type_) {
    // These blocks deliberately ignore clang-format to preserve the
    // "silhouette" of the different cases.
    // clang-format off
    case SpectrogramCalculatorOptions::COMPLEX: {
      return ProcessVectorToOutput(
          input_stream,
          +[](const Eigen::MatrixXcf& col) -> const Eigen::MatrixXcf {
            return col;
          }, cc);
    }
    case SpectrogramCalculatorOptions::SQUARED_MAGNITUDE: {
      return ProcessVectorToOutput(
          input_stream,
          +[](const Matrix& col) -> const Matrix {
            return col;
          }, cc);
    }
    case SpectrogramCalculatorOptions::LINEAR_MAGNITUDE: {
      return ProcessVectorToOutput(
          input_stream,
          +[](const Matrix& col) -> const Matrix {
            return col.array().sqrt().matrix();
          }, cc);
    }
    case SpectrogramCalculatorOptions::DECIBELS: {
      // dB = 10*log10(power) = kLnPowerToDb * ln(power).
      return ProcessVectorToOutput(
          input_stream,
          +[](const Matrix& col) -> const Matrix {
            return kLnPowerToDb * col.array().log().matrix();
          }, cc);
    }
    // clang-format on
    default: {
      return ::mediapipe::Status(mediapipe::StatusCode::kInvalidArgument,
                                 "Unrecognized spectrogram output type.");
    }
  }
}
// Flushes the final partial frame (if any) by feeding zero padding through
// ProcessVector when pad_final_packet is enabled.
::mediapipe::Status SpectrogramCalculator::Close(CalculatorContext* cc) {
  if (cumulative_input_samples_ > 0 && pad_final_packet_) {
    // We can flush any remaining samples by sending frame_step_samples - 1
    // zeros to the Process method, and letting it do its thing,
    // UNLESS we have fewer than one window's worth of samples, in which case
    // we pad to exactly one frame_duration_samples.
    // (A stray comment about releasing Spectrogram objects was removed;
    // nothing is freed here — the unique_ptrs clean up at destruction.)
    int required_padding_samples = frame_step_samples() - 1;
    if (cumulative_input_samples_ < frame_duration_samples_) {
      required_padding_samples =
          frame_duration_samples_ - cumulative_input_samples_;
    }
    return ProcessVector(
        Matrix::Zero(num_input_channels_, required_padding_samples), cc);
  }
  return ::mediapipe::OkStatus();
}
} // namespace mediapipe

View File

@ -0,0 +1,68 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
// Options controlling SpectrogramCalculator's framing, windowing, and
// output format.
message SpectrogramCalculatorOptions {
  extend CalculatorOptions {
    // Unique extension field number; must never be changed or reused.
    optional SpectrogramCalculatorOptions ext = 76186688;
  }
  // Options mirror those of TimeSeriesFramerCalculator.
  // Analysis window duration in seconds. Required. Must be greater than 0.
  // (Note: the spectrogram DFT length will be the smallest power-of-2
  // sample count that can hold this duration.)
  optional double frame_duration_seconds = 1;
  // Duration of overlap between adjacent windows.
  // Hence, frame_rate = 1/(frame_duration_seconds - frame_overlap_seconds).
  // Required that 0 <= frame_overlap_seconds < frame_duration_seconds.
  optional double frame_overlap_seconds = 2 [default = 0.0];
  // Whether to pad the final packet with zeros. If true, guarantees that
  // all input samples will output. If set to false, any partial packet
  // at the end of the stream will be dropped.
  optional bool pad_final_packet = 3 [default = true];
  // Output value type can be squared-magnitude, linear-magnitude,
  // deciBels (dB, = 20*log10(linear_magnitude)), or std::complex.
  enum OutputType {
    SQUARED_MAGNITUDE = 0;
    LINEAR_MAGNITUDE = 1;
    DECIBELS = 2;
    COMPLEX = 3;
  }
  optional OutputType output_type = 4 [default = SQUARED_MAGNITUDE];
  // If set to true then the output will be a vector of spectrograms, one for
  // each channel and the stream will have a MultiStreamTimeSeriesHeader.
  optional bool allow_multichannel_input = 5 [default = false];
  // Which window to use when computing the FFT.
  enum WindowType {
    HANN = 0;
    HAMMING = 1;
  }
  optional WindowType window_type = 6 [default = HANN];
  // Support a fixed multiplicative scaling of the output. This is applied
  // uniformly regardless of output type (i.e., even dBs are multiplied, not
  // offset).
  optional double output_scale = 7 [default = 1.0];
}

View File

@ -0,0 +1,895 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h>
#include <cmath>
#include <complex>
#include <memory>
#include <numeric>
#include <string>
#include <vector>
#include "Eigen/Core"
#include "audio/dsp/number_util.h"
#include "mediapipe/calculators/audio/spectrogram_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/benchmark.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/time_series_test_util.h"
namespace mediapipe {
namespace {
const int kInitialTimestampOffsetMicroseconds = 4;
class SpectrogramCalculatorTest
: public TimeSeriesCalculatorTest<SpectrogramCalculatorOptions> {
protected:
  // Configures the calculator under test: a SpectrogramCalculator fed
  // single-channel time series at 4 kHz.
  void SetUp() override {
    calculator_name_ = "SpectrogramCalculator";
    input_sample_rate_ = 4000.0;
    num_input_channels_ = 1;
  }
  // Initializes and runs the test graph.
  //
  // Also derives the frame/step sizes in samples and the expected DC-bin
  // magnitude from options_, which must be fully populated before calling.
  ::mediapipe::Status Run() {
    // Now that options are set, we can set up some internal constants.
    frame_duration_samples_ =
        round(options_.frame_duration_seconds() * input_sample_rate_);
    frame_step_samples_ =
        frame_duration_samples_ -
        round(options_.frame_overlap_seconds() * input_sample_rate_);
    // The magnitude of the 0th FFT bin (DC) should be sum(input.*window);
    // for an input identically 1.0, this is just sum(window). The average
    // value of our Hann window is 0.5, hence this is the expected squared-
    // magnitude output value in the DC bin for constant input of 1.0.
    expected_dc_squared_magnitude_ =
        pow((static_cast<float>(frame_duration_samples_) * 0.5), 2.0);
    return RunGraph();
  }
  // Creates test multichannel input with specified packet sizes and containing
  // a constant-frequency sinusoid that maintains phase between adjacent
  // packets.
  void SetupCosineInputPackets(const std::vector<int>& packet_sizes_samples,
                               float cosine_frequency_hz) {
    int total_num_input_samples = 0;
    for (int packet_size_samples : packet_sizes_samples) {
      // Start/end times of this packet within the overall signal; using the
      // running sample count keeps the cosine phase-continuous across
      // packet boundaries.
      double packet_start_time_seconds =
          kInitialTimestampOffsetMicroseconds * 1e-6 +
          total_num_input_samples / input_sample_rate_;
      double packet_end_time_seconds =
          packet_start_time_seconds + packet_size_samples / input_sample_rate_;
      double angular_freq = 2 * M_PI * cosine_frequency_hz;
      // Ownership is transferred to AppendInputPacket below.
      Matrix* packet_data =
          new Matrix(num_input_channels_, packet_size_samples);
      // Use Eigen's vectorized cos() function to fill the vector with a
      // sinusoid of appropriate frequency & phase.
      for (int i = 0; i < num_input_channels_; i++) {
        packet_data->row(i) =
            Eigen::ArrayXf::LinSpaced(packet_size_samples,
                                      packet_start_time_seconds * angular_freq,
                                      packet_end_time_seconds * angular_freq)
                .cos()
                .transpose();
      }
      int64 input_timestamp = round(packet_start_time_seconds *
                                    Timestamp::kTimestampUnitsPerSecond);
      AppendInputPacket(packet_data, input_timestamp);
      total_num_input_samples += packet_size_samples;
    }
  }
  // Setup a sequence of input packets of specified sizes, each filled
  // with samples of 1.0.
  void SetupConstantInputPackets(const std::vector<int>& packet_sizes_samples) {
    // A 0 Hz cosine is identically 1.0 for all samples, so reuse the
    // phase-continuous cosine generator.
    SetupCosineInputPackets(packet_sizes_samples, 0.0);
  }
  // Setup a sequence of input packets of specified sizes, each containing a
  // single sample of 1.0 at a specified offset.
  //
  // packet_sizes_samples and impulse_offsets_samples are indexed together,
  // so they must have the same length and each offset must be smaller than
  // the corresponding packet size.
  void SetupImpulseInputPackets(
      const std::vector<int>& packet_sizes_samples,
      const std::vector<int>& impulse_offsets_samples) {
    int total_num_input_samples = 0;
    for (int i = 0; i < packet_sizes_samples.size(); ++i) {
      double packet_start_time_seconds =
          kInitialTimestampOffsetMicroseconds * 1e-6 +
          total_num_input_samples / input_sample_rate_;
      int64 input_timestamp = round(packet_start_time_seconds *
                                    Timestamp::kTimestampUnitsPerSecond);
      // Zero-filled single-channel packet with one unit impulse.
      std::unique_ptr<Matrix> impulse(
          new Matrix(Matrix::Zero(1, packet_sizes_samples[i])));
      (*impulse)(0, impulse_offsets_samples[i]) = 1.0;
      AppendInputPacket(impulse.release(), input_timestamp);
      total_num_input_samples += packet_sizes_samples[i];
    }
  }
  // Creates test multichannel input with specified packet sizes and containing
  // constant input packets for the even channels and constant-frequency
  // sinusoid that maintains phase between adjacent packets for the odd
  // channels.
  void SetupMultichannelInputPackets(
      const std::vector<int>& packet_sizes_samples, float cosine_frequency_hz) {
    int total_num_input_samples = 0;
    for (int packet_size_samples : packet_sizes_samples) {
      // Start/end times of this packet, derived from the running sample
      // count to keep each channel's sinusoid phase-continuous.
      double packet_start_time_seconds =
          kInitialTimestampOffsetMicroseconds * 1e-6 +
          total_num_input_samples / input_sample_rate_;
      double packet_end_time_seconds =
          packet_start_time_seconds + packet_size_samples / input_sample_rate_;
      double angular_freq;
      // Ownership is transferred to AppendInputPacket below.
      Matrix* packet_data =
          new Matrix(num_input_channels_, packet_size_samples);
      // Use Eigen's vectorized cos() function to fill the vector with a
      // sinusoid of appropriate frequency & phase.
      for (int i = 0; i < num_input_channels_; i++) {
        // Even channels: 0 Hz (constant 1.0); odd channels: the requested
        // cosine frequency.
        if (i % 2 == 0) {
          angular_freq = 0;
        } else {
          angular_freq = 2 * M_PI * cosine_frequency_hz;
        }
        packet_data->row(i) =
            Eigen::ArrayXf::LinSpaced(packet_size_samples,
                                      packet_start_time_seconds * angular_freq,
                                      packet_end_time_seconds * angular_freq)
                .cos()
                .transpose();
      }
      int64 input_timestamp = round(packet_start_time_seconds *
                                    Timestamp::kTimestampUnitsPerSecond);
      AppendInputPacket(packet_data, input_timestamp);
      total_num_input_samples += packet_size_samples;
    }
  }
// Return vector of the numbers of frames in each output packet.
std::vector<int> OutputFramesPerPacket() {
std::vector<int> frame_counts;
for (const Packet& packet : output().packets) {
const Matrix& matrix = packet.Get<Matrix>();
frame_counts.push_back(matrix.cols());
}
return frame_counts;
}
  // Checks output headers and Timestamps.
  //
  // Verifies (a) the output TimeSeriesHeader (or the wrapped header inside
  // MultiStreamTimeSeriesHeader in multichannel mode), and (b) that every
  // packet's timestamp equals the start time of its first frame.
  void CheckOutputHeadersAndTimestamps() {
    const int fft_size = audio_dsp::NextPowerOfTwo(frame_duration_samples_);
    TimeSeriesHeader expected_header = input().header.Get<TimeSeriesHeader>();
    expected_header.set_num_channels(fft_size / 2 + 1);
    // The output header sample rate should depend on the output frame step.
    expected_header.set_sample_rate(input_sample_rate_ / frame_step_samples_);
    // SpectrogramCalculator stores the sample rate of the input in
    // the TimeSeriesHeader.
    expected_header.set_audio_sample_rate(input_sample_rate_);
    // We expect the output header to have num_samples and packet_rate unset.
    expected_header.clear_num_samples();
    expected_header.clear_packet_rate();
    if (!options_.allow_multichannel_input()) {
      ExpectOutputHeaderEquals(expected_header);
    } else {
      // Multichannel mode wraps the per-channel header in a
      // MultiStreamTimeSeriesHeader with one stream per input channel.
      EXPECT_THAT(output()
                      .header.template Get<MultiStreamTimeSeriesHeader>()
                      .time_series_header(),
                  mediapipe::EqualsProto(expected_header));
      EXPECT_THAT(output()
                      .header.template Get<MultiStreamTimeSeriesHeader>()
                      .num_streams(),
                  num_input_channels_);
    }
    int cumulative_output_frames = 0;
    // NOTE(review): an earlier comment here referred to a
    // frame_duration_samples_/2 (window-center) term, but no such term
    // appears below; the expected timestamps are frame *start* times,
    // offset only by the initial input timestamp.
    const double packet_timestamp_offset_seconds =
        kInitialTimestampOffsetMicroseconds * 1e-6;
    const double frame_step_seconds = frame_step_samples_ / input_sample_rate_;
    Timestamp initial_timestamp = Timestamp::Unstarted();
    for (const Packet& packet : output().packets) {
      // This is the timestamp we expect based on how the spectrogram should
      // behave (advancing by one step's worth of input samples each frame).
      const double expected_timestamp_seconds =
          packet_timestamp_offset_seconds +
          cumulative_output_frames * frame_step_seconds;
      const int64 expected_timestamp_ticks =
          expected_timestamp_seconds * Timestamp::kTimestampUnitsPerSecond;
      EXPECT_EQ(expected_timestamp_ticks, packet.Timestamp().Value());
      // Accept the timestamp of the first packet as the baseline for checking
      // the remainder.
      if (initial_timestamp == Timestamp::Unstarted()) {
        initial_timestamp = packet.Timestamp();
      }
      // Also check that the timestamp is consistent with the sample_rate
      // in the output stream's TimeSeriesHeader.
      EXPECT_TRUE(time_series_util::LogWarningIfTimestampIsInconsistent(
          packet.Timestamp(), initial_timestamp, cumulative_output_frames,
          expected_header.sample_rate()));
      // Count this packet's frames by unpacking whichever payload type the
      // current options produce.
      if (!options_.allow_multichannel_input()) {
        if (options_.output_type() == SpectrogramCalculatorOptions::COMPLEX) {
          const Eigen::MatrixXcf& matrix = packet.Get<Eigen::MatrixXcf>();
          cumulative_output_frames += matrix.cols();
        } else {
          const Matrix& matrix = packet.Get<Matrix>();
          cumulative_output_frames += matrix.cols();
        }
      } else {
        if (options_.output_type() == SpectrogramCalculatorOptions::COMPLEX) {
          const Eigen::MatrixXcf& matrix =
              packet.Get<std::vector<Eigen::MatrixXcf>>().at(0);
          cumulative_output_frames += matrix.cols();
        } else {
          const Matrix& matrix = packet.Get<std::vector<Matrix>>().at(0);
          cumulative_output_frames += matrix.cols();
        }
      }
    }
  }
// Verify that the bin corresponding to the specified frequency
// is the largest one in one particular frame of a single packet.
void CheckPeakFrequencyInPacketFrame(const Packet& packet, int frame,
float frequency) {
const int fft_size = audio_dsp::NextPowerOfTwo(frame_duration_samples_);
const int target_bin =
round((frequency / input_sample_rate_) * static_cast<float>(fft_size));
const Matrix& matrix = packet.Get<Matrix>();
// Stop here if the requested frame is not in this packet.
ASSERT_GT(matrix.cols(), frame);
int actual_largest_bin;
matrix.col(frame).maxCoeff(&actual_largest_bin);
EXPECT_EQ(actual_largest_bin, target_bin);
}
// Verifies that the spectrogram bin nearest to `frequency` holds the
// largest value in column `frame` of a single spectrogram Matrix.
void CheckPeakFrequencyInMatrix(const Matrix& matrix, int frame,
                                float frequency) {
  const int fft_size = audio_dsp::NextPowerOfTwo(frame_duration_samples_);
  const int expected_bin =
      round((frequency / input_sample_rate_) * static_cast<float>(fft_size));
  // Bail out early if the requested frame is not present in this matrix.
  ASSERT_GT(matrix.cols(), frame);
  int peak_bin;
  matrix.col(frame).maxCoeff(&peak_bin);
  EXPECT_EQ(peak_bin, expected_bin);
}
// Analysis window length, in input samples (set alongside the options).
int frame_duration_samples_;
// Hop between successive frame starts, in input samples.
int frame_step_samples_;
// Expected DC output for a window of pure 1.0, set when window length
// is set.
float expected_dc_squared_magnitude_;
};
TEST_F(SpectrogramCalculatorTest, IntegerFrameDurationNoOverlap) {
  // 100-sample frames with no overlap: 500 and 200 input samples yield
  // exactly 5 and 2 complete frames, respectively.
  const std::vector<int> packet_sizes = {500, 200};
  const std::vector<int> expected_frames = {5, 2};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(0.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames);
}
TEST_F(SpectrogramCalculatorTest, IntegerFrameDurationSomeOverlap) {
  // 100-sample frames stepping by 40 samples (60 samples of overlap).
  const std::vector<int> packet_sizes = {500, 200};
  // complete_output_frames = 1 + floor((input_samples - window_length)/step)
  //   = 1 + floor((500 - 100)/40) = 1 + 10 = 11 for the first packet
  //   = 1 + floor((700 - 100)/40) = 1 + 15 = 16 for the whole stream,
  // so expect 16 - 11 = 5 frames in the second packet.
  const std::vector<int> expected_frames = {11, 5};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames);
}
TEST_F(SpectrogramCalculatorTest, NonintegerFrameDurationAndOverlap) {
  // Fractional durations are rounded to whole samples: frame_duration_samples
  // becomes 99 and frame_step_samples becomes (99 - 58) = 41, so the first
  // packet of 500 samples yields 1 + floor((500 - 99)/41) = 10 frames.
  const std::vector<int> packet_sizes = {500, 200};
  const std::vector<int> expected_frames = {10, 5};
  options_.set_frame_duration_seconds(98.5 / input_sample_rate_);
  options_.set_frame_overlap_seconds(58.4 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames);
}
TEST_F(SpectrogramCalculatorTest, ShortInitialPacketNoOverlap) {
  // The first input packet (90 samples) is too small to produce a frame,
  // and zero-length output packets would violate timestamp monotonicity,
  // so they are suppressed: only the second and third input packets
  // produce output packets.
  const std::vector<int> packet_sizes = {90, 100, 110};
  const std::vector<int> expected_frames = {1, 2};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(0.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames);
}
TEST_F(SpectrogramCalculatorTest, TrailingSamplesNoPad) {
  // With padding disabled, leftover samples at the end of the stream that
  // do not fill a whole frame are silently dropped.
  const std::vector<int> packet_sizes = {140, 90};
  const std::vector<int> expected_frames = {2, 2};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames);
}
TEST_F(SpectrogramCalculatorTest, NoTrailingSamplesWithPad) {
  // 140 + 80 samples divide exactly into frames, so enabling padding adds
  // no extra final packet.
  const std::vector<int> packet_sizes = {140, 80};
  const std::vector<int> expected_frames = {2, 2};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(true);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames);
}
TEST_F(SpectrogramCalculatorTest, TrailingSamplesWithPad) {
  // In contrast to NoTrailingSamplesWithPad and TrailingSamplesNoPad, the
  // leftover samples plus padding produce one extra frame in an extra
  // final packet.
  const std::vector<int> packet_sizes = {140, 90};
  const std::vector<int> expected_frames = {2, 2, 1};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(true);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames);
}
TEST_F(SpectrogramCalculatorTest, VeryShortInputWillPad) {
  // An input shorter than a single frame still yields one zero-padded
  // frame when padding is enabled.
  const std::vector<int> packet_sizes = {30};
  const std::vector<int> expected_frames = {1};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(true);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames);
}
TEST_F(SpectrogramCalculatorTest, VeryShortInputZeroOutputFramesIfNoPad) {
  // An input shorter than a single frame produces no output at all when
  // padding is disabled.
  const std::vector<int> packet_sizes = {90};
  const std::vector<int> expected_frames = {};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames);
}
TEST_F(SpectrogramCalculatorTest, DCSignalIsPeakBin) {
  // A constant (DC) input should concentrate its energy in bin 0.
  const std::vector<int> packet_sizes = {140};  // Gives 2 output frames.
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  InitializeGraph();
  FillInputHeader();
  // Non-zero constant input packets.
  SetupConstantInputPackets(packet_sizes);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  const float kDcHz = 0.0;
  CheckPeakFrequencyInPacketFrame(output().packets[0], 0, kDcHz);
  CheckPeakFrequencyInPacketFrame(output().packets[0], 1, kDcHz);
}
TEST_F(SpectrogramCalculatorTest, A440ToneIsPeakBin) {
  // A 440 Hz cosine should peak in the bin nearest 440 Hz in every frame.
  const std::vector<int> packet_sizes = {460};  // 100 + 9*40 for 10 frames.
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  InitializeGraph();
  FillInputHeader();
  const float kToneHz = 440.0;
  SetupCosineInputPackets(packet_sizes, kToneHz);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  const int frame_count = output().packets[0].Get<Matrix>().cols();
  for (int frame = 0; frame < frame_count; ++frame) {
    CheckPeakFrequencyInPacketFrame(output().packets[0], frame, kToneHz);
  }
}
TEST_F(SpectrogramCalculatorTest, SquaredMagnitudeOutputLooksRight) {
  // With SQUARED_MAGNITUDE output, the DC bin of a constant input should
  // equal the precomputed expected squared magnitude.
  const std::vector<int> packet_sizes = {140};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_output_type(SpectrogramCalculatorOptions::SQUARED_MAGNITUDE);
  InitializeGraph();
  FillInputHeader();
  // Non-zero constant input packets.
  SetupConstantInputPackets(packet_sizes);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  EXPECT_FLOAT_EQ(output().packets[0].Get<Matrix>()(0, 0),
                  expected_dc_squared_magnitude_);
}
TEST_F(SpectrogramCalculatorTest, DefaultOutputIsSquaredMagnitude) {
  // Leaving output_type at its default should behave exactly like
  // SQUARED_MAGNITUDE.
  const std::vector<int> packet_sizes = {140};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  InitializeGraph();
  FillInputHeader();
  // Non-zero constant input packets.
  SetupConstantInputPackets(packet_sizes);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  EXPECT_FLOAT_EQ(output().packets[0].Get<Matrix>()(0, 0),
                  expected_dc_squared_magnitude_);
}
TEST_F(SpectrogramCalculatorTest, LinearMagnitudeOutputLooksRight) {
  // LINEAR_MAGNITUDE output should be the square root of the squared
  // magnitude.
  const std::vector<int> packet_sizes = {140};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_output_type(SpectrogramCalculatorOptions::LINEAR_MAGNITUDE);
  InitializeGraph();
  FillInputHeader();
  // Non-zero constant input packets.
  SetupConstantInputPackets(packet_sizes);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  EXPECT_FLOAT_EQ(output().packets[0].Get<Matrix>()(0, 0),
                  std::sqrt(expected_dc_squared_magnitude_));
}
TEST_F(SpectrogramCalculatorTest, DbMagnitudeOutputLooksRight) {
  // DECIBELS output should be 10*log10 of the squared magnitude.
  const std::vector<int> packet_sizes = {140};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_output_type(SpectrogramCalculatorOptions::DECIBELS);
  InitializeGraph();
  FillInputHeader();
  // Non-zero constant input packets.
  SetupConstantInputPackets(packet_sizes);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  EXPECT_FLOAT_EQ(output().packets[0].Get<Matrix>()(0, 0),
                  10.0 * std::log10(expected_dc_squared_magnitude_));
}
TEST_F(SpectrogramCalculatorTest, OutputScalingLooksRight) {
  // output_scale should multiply the (here, decibel-valued) output.
  const double kOutputScale = 2.5;
  const std::vector<int> packet_sizes = {140};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_output_type(SpectrogramCalculatorOptions::DECIBELS);
  options_.set_output_scale(kOutputScale);
  InitializeGraph();
  FillInputHeader();
  // Non-zero constant input packets.
  SetupConstantInputPackets(packet_sizes);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  EXPECT_FLOAT_EQ(
      output().packets[0].Get<Matrix>()(0, 0),
      kOutputScale * 10.0 * std::log10(expected_dc_squared_magnitude_));
}
TEST_F(SpectrogramCalculatorTest, ComplexOutputLooksRight) {
  // The squared norm of the complex DC bin should match the expected
  // squared magnitude.
  const std::vector<int> packet_sizes = {140};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_output_type(SpectrogramCalculatorOptions::COMPLEX);
  InitializeGraph();
  FillInputHeader();
  // Non-zero constant input packets.
  SetupConstantInputPackets(packet_sizes);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  EXPECT_FLOAT_EQ(std::norm(output().packets[0].Get<Eigen::MatrixXcf>()(0, 0)),
                  expected_dc_squared_magnitude_);
}
TEST_F(SpectrogramCalculatorTest, ComplexOutputLooksRightForImpulses) {
  // Feeds two single-impulse frames, offset by one sample from each other,
  // and checks both the (flat) magnitude spectrum and the relative phase of
  // the resulting complex spectra.
  const int frame_size_samples = 100;
  options_.set_frame_duration_seconds(frame_size_samples / input_sample_rate_);
  options_.set_frame_overlap_seconds(0.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  options_.set_output_type(SpectrogramCalculatorOptions::COMPLEX);
  const std::vector<int> input_packet_sizes = {frame_size_samples,
                                               frame_size_samples};
  const std::vector<int> input_packet_impulse_offsets = {49, 50};
  InitializeGraph();
  FillInputHeader();
  // Make two impulse packets offset one sample from each other
  SetupImpulseInputPackets(input_packet_sizes, input_packet_impulse_offsets);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  // Number of non-negative-frequency FFT bins.
  const int num_buckets =
      (audio_dsp::NextPowerOfTwo(frame_size_samples) / 2) + 1;
  const float precision = 0.01f;
  auto norm_fn = [](const std::complex<float>& cf) { return std::norm(cf); };
  // Both impulses should have (approximately) constant power across all
  // frequency bins
  EXPECT_TRUE(output()
                  .packets[0]
                  .Get<Eigen::MatrixXcf>()
                  .unaryExpr(norm_fn)
                  .isApproxToConstant(1.0f, precision));
  EXPECT_TRUE(output()
                  .packets[1]
                  .Get<Eigen::MatrixXcf>()
                  .unaryExpr(norm_fn)
                  .isApproxToConstant(1.0f, precision));
  // Because the second Packet's impulse is delayed by exactly one sample with
  // respect to the first Packet's impulse, the second impulse should have
  // greater phase, and in the highest frequency bin, the real part should
  // (approximately) flip sign from the first Packet to the second
  EXPECT_LT(std::arg(output().packets[0].Get<Eigen::MatrixXcf>()(1, 0)),
            std::arg(output().packets[1].Get<Eigen::MatrixXcf>()(1, 0)));
  const float highest_bucket_real_ratio =
      output().packets[0].Get<Eigen::MatrixXcf>()(num_buckets - 1, 0).real() /
      output().packets[1].Get<Eigen::MatrixXcf>()(num_buckets - 1, 0).real();
  EXPECT_NEAR(highest_bucket_real_ratio, -1.0f, precision);
}
TEST_F(SpectrogramCalculatorTest, SquaredMagnitudeOutputLooksRightForNonDC) {
  const int kFrameSamples = 100;
  const std::vector<int> packet_sizes = {140};
  options_.set_frame_duration_seconds(kFrameSamples / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_output_type(SpectrogramCalculatorOptions::SQUARED_MAGNITUDE);
  InitializeGraph();
  FillInputHeader();
  // Choose a tone with an integral number of cycles per window so it lands
  // exactly on one FFT bin.
  const int kTargetBin = 16;
  const int fft_size = audio_dsp::NextPowerOfTwo(kFrameSamples);
  const float tone_hz = kTargetBin * (input_sample_rate_ / fft_size);
  SetupCosineInputPackets(packet_sizes, tone_hz);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  // For a non-DC bin, the magnitude is split between the positive- and
  // negative-frequency bins, so expect about half-magnitude
  // (= quarter-power). It's not quite exact because of interference from
  // the hann(100) spread of the negative-frequency half.
  EXPECT_GT(output().packets[0].Get<Matrix>()(kTargetBin, 0),
            0.98 * expected_dc_squared_magnitude_ / 4.0);
  EXPECT_LT(output().packets[0].Get<Matrix>()(kTargetBin, 0),
            1.02 * expected_dc_squared_magnitude_ / 4.0);
}
TEST_F(SpectrogramCalculatorTest, ZeroOutputsForZeroInputsWithPaddingEnabled) {
  // With no input packets at all, padding must not invent any output.
  const std::vector<int> packet_sizes = {};
  const std::vector<int> expected_frames = {};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(true);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames);
}
TEST_F(SpectrogramCalculatorTest, NumChannelsIsRight) {
  // Multichannel output should contain one spectrogram per input channel.
  const std::vector<int> packet_sizes = {460};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  options_.set_allow_multichannel_input(true);
  num_input_channels_ = 3;
  InitializeGraph();
  FillInputHeader();
  const float kToneHz = 440.0;
  SetupCosineInputPackets(packet_sizes, kToneHz);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(output().packets[0].Get<std::vector<Matrix>>().size(),
            num_input_channels_);
}
TEST_F(SpectrogramCalculatorTest, NumSamplesAndPacketRateAreCleared) {
  // num_samples and packet_rate from the input header must not propagate
  // to the output header.
  num_input_samples_ = 500;
  input_packet_rate_ = 1.0;
  const std::vector<int> packet_sizes = {num_input_samples_};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(0.0);
  options_.set_pad_final_packet(false);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MEDIAPIPE_ASSERT_OK(Run());
  const TimeSeriesHeader& header = output().header.Get<TimeSeriesHeader>();
  EXPECT_FALSE(header.has_num_samples());
  EXPECT_FALSE(header.has_packet_rate());
}
TEST_F(SpectrogramCalculatorTest, MultichannelSpectrogramSizesAreRight) {
  // All per-channel spectrograms must agree in row and column counts.
  const std::vector<int> input_packet_sizes = {420};  // less than 10 frames
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  options_.set_allow_multichannel_input(true);
  num_input_channels_ = 10;
  InitializeGraph();
  FillInputHeader();
  const float tone_frequency_hz = 440.0;
  SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  auto spectrograms = output().packets[0].Get<std::vector<Matrix>>();
  // Compare integer sizes exactly; EXPECT_FLOAT_EQ was only masking the
  // size_t/int signedness mismatch.
  EXPECT_EQ(spectrograms.size(), static_cast<size_t>(num_input_channels_));
  int spectrogram_num_rows = spectrograms[0].rows();
  int spectrogram_num_cols = spectrograms[0].cols();
  for (int i = 1; i < num_input_channels_; i++) {
    EXPECT_EQ(spectrogram_num_rows, spectrograms[i].rows());
    EXPECT_EQ(spectrogram_num_cols, spectrograms[i].cols());
  }
}
TEST_F(SpectrogramCalculatorTest, MultichannelSpectrogramValuesAreRight) {
  // SetupMultichannelInputPackets alternates channel content; even channels
  // should peak at DC and odd channels at the tone frequency.
  const std::vector<int> packet_sizes = {460};  // 100 + 9*40 for 10 frames.
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_allow_multichannel_input(true);
  num_input_channels_ = 10;
  InitializeGraph();
  FillInputHeader();
  const float kToneHz = 440.0;
  SetupMultichannelInputPackets(packet_sizes, kToneHz);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  auto spectrograms = output().packets[0].Get<std::vector<Matrix>>();
  const int frame_count = spectrograms[0].cols();
  for (int channel = 0; channel < num_input_channels_; channel++) {
    const float expected_peak_hz = (channel % 2 == 0) ? 0 : kToneHz;
    for (int frame = 0; frame < frame_count; ++frame) {
      CheckPeakFrequencyInMatrix(spectrograms[channel], frame,
                                 expected_peak_hz);
    }
  }
}
TEST_F(SpectrogramCalculatorTest, MultichannelHandlesShortInitialPacket) {
  // First packet is less than one frame, but second packet should trigger a
  // complete frame from all channels.
  const std::vector<int> input_packet_sizes = {50, 50};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  options_.set_allow_multichannel_input(true);
  num_input_channels_ = 2;
  InitializeGraph();
  FillInputHeader();
  const float tone_frequency_hz = 440.0;
  SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  auto spectrograms = output().packets[0].Get<std::vector<Matrix>>();
  // Compare integer sizes exactly; EXPECT_FLOAT_EQ was only masking the
  // size_t/int signedness mismatch.
  EXPECT_EQ(spectrograms.size(), static_cast<size_t>(num_input_channels_));
  int spectrogram_num_rows = spectrograms[0].rows();
  int spectrogram_num_cols = spectrograms[0].cols();
  for (int i = 1; i < num_input_channels_; i++) {
    EXPECT_EQ(spectrogram_num_rows, spectrograms[i].rows());
    EXPECT_EQ(spectrogram_num_cols, spectrograms[i].cols());
  }
}
TEST_F(SpectrogramCalculatorTest,
       MultichannelComplexHandlesShortInitialPacket) {
  // First packet is less than one frame, but second packet should trigger a
  // complete frame from all channels, even for complex output.
  const std::vector<int> input_packet_sizes = {50, 50};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  options_.set_allow_multichannel_input(true);
  options_.set_output_type(SpectrogramCalculatorOptions::COMPLEX);
  num_input_channels_ = 2;
  InitializeGraph();
  FillInputHeader();
  const float tone_frequency_hz = 440.0;
  SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  auto spectrograms = output().packets[0].Get<std::vector<Eigen::MatrixXcf>>();
  // Compare integer sizes exactly; EXPECT_FLOAT_EQ was only masking the
  // size_t/int signedness mismatch.
  EXPECT_EQ(spectrograms.size(), static_cast<size_t>(num_input_channels_));
  int spectrogram_num_rows = spectrograms[0].rows();
  int spectrogram_num_cols = spectrograms[0].cols();
  for (int i = 1; i < num_input_channels_; i++) {
    EXPECT_EQ(spectrogram_num_rows, spectrograms[i].rows());
    EXPECT_EQ(spectrogram_num_cols, spectrograms[i].cols());
  }
}
// Benchmarks SpectrogramCalculator on 100 s of constant (DC) input at
// 16 kHz, framed into 10 ms non-overlapping windows.
void BM_ProcessDC(benchmark::State& state) {
  CalculatorGraphConfig::Node node_config;
  node_config.set_calculator("SpectrogramCalculator");
  node_config.add_input_stream("input_audio");
  node_config.add_output_stream("output_spectrogram");
  // MutableExtension returns a pointer into node_config, so mutating the
  // options in place is sufficient; the original copy-back of *options onto
  // itself was a redundant self-assignment and has been removed.
  SpectrogramCalculatorOptions* options =
      node_config.mutable_options()->MutableExtension(
          SpectrogramCalculatorOptions::ext);
  options->set_frame_duration_seconds(0.010);
  options->set_frame_overlap_seconds(0.0);
  options->set_pad_final_packet(false);
  const int num_input_channels = 1;
  const int packet_size_samples = 1600000;  // 100 s at 16 kHz.
  // Ownership of header and payload is transferred to the packets via
  // Adopt() below.
  TimeSeriesHeader* header = new TimeSeriesHeader();
  header->set_sample_rate(16000.0);
  header->set_num_channels(num_input_channels);
  CalculatorRunner runner(node_config);
  runner.MutableInputs()->Index(0).header = Adopt(header);
  Matrix* payload = new Matrix(
      Matrix::Constant(num_input_channels, packet_size_samples, 1.0));
  Timestamp timestamp = Timestamp(0);
  runner.MutableInputs()->Index(0).packets.push_back(
      Adopt(payload).At(timestamp));
  // Only the graph run itself is timed.
  for (auto _ : state) {
    ASSERT_TRUE(runner.Run().ok());
  }
  const CalculatorRunner::StreamContents& output = runner.Outputs().Index(0);
  const Matrix& output_matrix = output.packets[0].Get<Matrix>();
  LOG(INFO) << "Output matrix=" << output_matrix.rows() << "x"
            << output_matrix.cols();
  LOG(INFO) << "First values=" << output_matrix(0, 0) << ", "
            << output_matrix(1, 0) << ", " << output_matrix(2, 0) << ", "
            << output_matrix(3, 0);
}
BENCHMARK(BM_ProcessDC);
} // anonymous namespace
} // namespace mediapipe

View File

@ -0,0 +1,27 @@
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
licenses(["notice"]) # Apache 2.0
# Sine-wave audio fixtures (various codecs, sample rates, and channel
# counts) consumed by MediaPipe audio calculator tests.
filegroup(
    name = "test_audios",
    srcs = [
        "sine_wave_1k_44100_mono_2_sec_wav.audio",
        "sine_wave_1k_44100_stereo_2_sec_aac.audio",
        "sine_wave_1k_44100_stereo_2_sec_mp3.audio",
        "sine_wave_1k_48000_stereo_2_sec_wav.audio",
    ],
    visibility = ["//visibility:public"],
)

View File

@ -0,0 +1,289 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Defines TimeSeriesFramerCalculator.
#include <math.h>
#include <deque>
#include <memory>
#include <string>
#include "Eigen/Core"
#include "audio/dsp/window_functions.h"
#include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/time_series_util.h"
namespace mediapipe {
// MediaPipe Calculator for framing a (vector-valued) input time series,
// i.e. for breaking an input time series into fixed-size, possibly
// overlapping, frames. The output stream's frame duration is
// specified by frame_duration_seconds in the
// TimeSeriesFramerCalculatorOptions, and the output's overlap is
// specified by frame_overlap_seconds.
//
// This calculator assumes that the input timestamps refer to the
// first sample in each Matrix. The output timestamps follow this
// same convention.
//
// All output frames will have exactly the same number of samples: the number of
// samples that approximates frame_duration_seconds most closely.
//
// Similarly, frame overlap is by default the (fixed) number of samples
// approximating frame_overlap_seconds most closely. But if
// emulate_fractional_frame_overlap is set to true, frame overlap is a variable
// number of samples instead, such that the long-term average step between
// frames is the difference between the (nominal) frame_duration_seconds and
// frame_overlap_seconds.
//
// If pad_final_packet is true, all input samples will be emitted and the final
// packet will be zero padded as necessary. If pad_final_packet is false, some
// samples may be dropped at the end of the stream.
class TimeSeriesFramerCalculator : public CalculatorBase {
 public:
  // Declares one Matrix input stream and one Matrix output stream, both
  // expected to carry a TimeSeriesHeader.
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<Matrix>(
        // Input stream with TimeSeriesHeader.
    );
    cc->Outputs().Index(0).Set<Matrix>(
        // Fixed length time series Packets with TimeSeriesHeader.
    );
    return ::mediapipe::OkStatus();
  }
  // Returns FAIL if the input stream header is invalid.
  ::mediapipe::Status Open(CalculatorContext* cc) override;
  // Outputs as many framed packets as possible given the accumulated
  // input. Always returns OK.
  ::mediapipe::Status Process(CalculatorContext* cc) override;
  // Flushes any remaining samples in a zero-padded packet. Always
  // returns OK.
  ::mediapipe::Status Close(CalculatorContext* cc) override;

 private:
  // Adds input data to the internal buffer.
  void EnqueueInput(CalculatorContext* cc);
  // Constructs and emits framed output packets.
  void FrameOutput(CalculatorContext* cc);
  // Timestamp of the next output frame: the initial input timestamp
  // advanced by the number of samples stepped past so far, converted to
  // timestamp ticks via the input sample rate.
  Timestamp CurrentOutputTimestamp() {
    return initial_input_timestamp_ +
           round(cumulative_completed_samples_ / sample_rate_ *
                 Timestamp::kTimestampUnitsPerSecond);
  }
  // The number of input samples to advance after the current output frame is
  // emitted.
  int next_frame_step_samples() const {
    // All numbers are in input samples.
    const int64 current_output_frame_start = static_cast<int64>(
        round(cumulative_output_frames_ * average_frame_step_samples_));
    CHECK_EQ(current_output_frame_start, cumulative_completed_samples_);
    const int64 next_output_frame_start = static_cast<int64>(
        round((cumulative_output_frames_ + 1) * average_frame_step_samples_));
    return next_output_frame_start - current_output_frame_start;
  }

  // Input sample rate, in Hz (from the input TimeSeriesHeader).
  double sample_rate_;
  // Whether to zero-pad and emit a final partial frame at Close().
  bool pad_final_packet_;
  // Output frame length, in input samples.
  int frame_duration_samples_;
  // The advance, in input samples, between the start of successive output
  // frames. This may be a non-integer average value if
  // emulate_fractional_frame_overlap is true.
  double average_frame_step_samples_;
  // Samples still to be discarded before the next frame starts (non-zero
  // only when the frame step exceeds the frame duration).
  int samples_still_to_drop_;
  int64 cumulative_input_samples_;
  int64 cumulative_output_frames_;
  // "Completed" samples are samples that are no longer needed because
  // the framer has completely stepped past them (taking into account
  // any overlap).
  int64 cumulative_completed_samples_;
  // Timestamp of the first input packet; baseline for output timestamps.
  Timestamp initial_input_timestamp_;
  int num_channels_;
  // Each entry in this deque consists of a single sample, i.e. a
  // single column vector.
  std::deque<Matrix> sample_buffer_;
  // Optional per-frame window applied elementwise in FrameOutput() when
  // use_window_ is true. Presumably initialized in Open() (not visible
  // here) — confirm which window function is used.
  bool use_window_;
  Matrix window_;
};
REGISTER_CALCULATOR(TimeSeriesFramerCalculator);
// Appends each column (one multichannel sample) of the incoming Matrix to
// the internal sample buffer and updates the running input-sample count.
void TimeSeriesFramerCalculator::EnqueueInput(CalculatorContext* cc) {
  const Matrix& input = cc->Inputs().Index(0).Get<Matrix>();
  const int num_samples = input.cols();
  for (int sample = 0; sample < num_samples; ++sample) {
    sample_buffer_.emplace_back(input.col(sample));
  }
  cumulative_input_samples_ += num_samples;
}
// Constructs and emits framed output packets.
//
// Repeatedly pulls one frame's worth of samples from the front of
// sample_buffer_: the first frame_step_samples columns are consumed, any
// overlapped tail is copied but left in the buffer for the next frame, the
// optional window is applied, and the frame is emitted at
// CurrentOutputTimestamp().
void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) {
  while (sample_buffer_.size() >=
         frame_duration_samples_ + samples_still_to_drop_) {
    // Discard samples the previous frame already stepped past (only
    // non-zero when the step exceeds the frame duration, i.e. negative
    // overlap).
    while (samples_still_to_drop_ > 0) {
      sample_buffer_.pop_front();
      --samples_still_to_drop_;
    }
    const int frame_step_samples = next_frame_step_samples();
    std::unique_ptr<Matrix> output_frame(
        new Matrix(num_channels_, frame_duration_samples_));
    // Consume the stepped-past samples: they cannot appear in later frames.
    for (int i = 0; i < std::min(frame_step_samples, frame_duration_samples_);
         ++i) {
      output_frame->col(i) = sample_buffer_.front();
      sample_buffer_.pop_front();
    }
    const int frame_overlap_samples =
        frame_duration_samples_ - frame_step_samples;
    if (frame_overlap_samples > 0) {
      // Copy the overlapped tail without consuming it, so the next frame
      // can reuse these samples.
      for (int i = 0; i < frame_overlap_samples; ++i) {
        output_frame->col(i + frame_step_samples) = sample_buffer_[i];
      }
    } else {
      // Negative overlap: schedule the gap samples for removal before the
      // next frame is formed.
      samples_still_to_drop_ = -frame_overlap_samples;
    }
    if (use_window_) {
      // Elementwise windowing of the frame.
      *output_frame = (output_frame->array() * window_.array()).matrix();
    }
    cc->Outputs().Index(0).Add(output_frame.release(),
                               CurrentOutputTimestamp());
    ++cumulative_output_frames_;
    cumulative_completed_samples_ += frame_step_samples;
  }
}
// Latches the timestamp of the very first input packet (the origin from
// which output timestamps are derived), buffers the new samples, and emits
// any frames that are now complete.
::mediapipe::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) {
  if (initial_input_timestamp_ == Timestamp::Unstarted()) {
    initial_input_timestamp_ = cc->InputTimestamp();
  }
  EnqueueInput(cc);
  FrameOutput(cc);
  return ::mediapipe::OkStatus();
}
// Flushes the framer at end of stream.  Samples that fall inside a pending
// inter-frame gap (from a negative overlap) are discarded first; any
// remaining partial frame is emitted zero-padded when pad_final_packet_ is
// set, and silently dropped otherwise.
::mediapipe::Status TimeSeriesFramerCalculator::Close(CalculatorContext* cc) {
  while (samples_still_to_drop_ > 0 && !sample_buffer_.empty()) {
    sample_buffer_.pop_front();
    --samples_still_to_drop_;
  }
  if (!sample_buffer_.empty() && pad_final_packet_) {
    std::unique_ptr<Matrix> output_frame(new Matrix);
    // setZero both sizes the matrix and provides the zero padding for the
    // columns beyond the buffered samples.
    output_frame->setZero(num_channels_, frame_duration_samples_);
    // Cache the size as int to avoid a signed/unsigned comparison between
    // the loop index and deque::size_type.
    const int num_remaining_samples = static_cast<int>(sample_buffer_.size());
    for (int i = 0; i < num_remaining_samples; ++i) {
      output_frame->col(i) = sample_buffer_[i];
    }
    cc->Outputs().Index(0).Add(output_frame.release(),
                               CurrentOutputTimestamp());
  }
  return ::mediapipe::OkStatus();
}
// Reads the calculator options and input stream header, validates the
// framing parameters, precomputes the (possibly fractional) frame step and
// the optional window, and sets the output stream header.
::mediapipe::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) {
  TimeSeriesFramerCalculatorOptions framer_options;
  time_series_util::FillOptionsExtensionOrDie(cc->Options(), &framer_options);
  // The messages below previously printed "framer_duration_seconds" /
  // "framer_overlap_seconds", which are not the names of the option fields;
  // they now name the actual fields so users can find the bad setting.
  RET_CHECK_GT(framer_options.frame_duration_seconds(), 0.0)
      << "Invalid or missing frame_duration_seconds. "
      << "frame_duration_seconds: \n"
      << framer_options.frame_duration_seconds();
  RET_CHECK_LT(framer_options.frame_overlap_seconds(),
               framer_options.frame_duration_seconds())
      << "Invalid frame_overlap_seconds. frame_overlap_seconds: \n"
      << framer_options.frame_overlap_seconds();
  TimeSeriesHeader input_header;
  RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
      cc->Inputs().Index(0).Header(), &input_header));
  sample_rate_ = input_header.sample_rate();
  num_channels_ = input_header.num_channels();
  frame_duration_samples_ = time_series_util::SecondsToSamples(
      framer_options.frame_duration_seconds(), sample_rate_);
  RET_CHECK_GT(frame_duration_samples_, 0)
      << "Frame duration of " << framer_options.frame_duration_seconds()
      << "s too small to cover a single sample at " << sample_rate_ << " Hz.";
  if (framer_options.emulate_fractional_frame_overlap()) {
    // Frame step may be fractional.
    average_frame_step_samples_ = (framer_options.frame_duration_seconds() -
                                   framer_options.frame_overlap_seconds()) *
                                  sample_rate_;
  } else {
    // Frame step is an integer (stored in a double).
    average_frame_step_samples_ =
        frame_duration_samples_ -
        time_series_util::SecondsToSamples(
            framer_options.frame_overlap_seconds(), sample_rate_);
  }
  RET_CHECK_GE(average_frame_step_samples_, 1)
      << "Frame step too small to cover a single sample at " << sample_rate_
      << " Hz.";
  pad_final_packet_ = framer_options.pad_final_packet();
  auto output_header = new TimeSeriesHeader(input_header);
  output_header->set_num_samples(frame_duration_samples_);
  if (round(average_frame_step_samples_) == average_frame_step_samples_) {
    // Only set output packet rate if it is fixed.
    output_header->set_packet_rate(sample_rate_ / average_frame_step_samples_);
  }
  cc->Outputs().Index(0).SetHeader(Adopt(output_header));
  cumulative_completed_samples_ = 0;
  cumulative_input_samples_ = 0;
  cumulative_output_frames_ = 0;
  samples_still_to_drop_ = 0;
  initial_input_timestamp_ = Timestamp::Unstarted();
  // Precompute the window, replicated across all channels, if one is
  // requested; the double samples from audio_dsp are cast down to float.
  std::vector<double> window_vector;
  use_window_ = false;
  switch (framer_options.window_function()) {
    case TimeSeriesFramerCalculatorOptions::HAMMING:
      audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples_,
                                                    &window_vector);
      use_window_ = true;
      break;
    case TimeSeriesFramerCalculatorOptions::HANN:
      audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_,
                                                 &window_vector);
      use_window_ = true;
      break;
    case TimeSeriesFramerCalculatorOptions::NONE:
      break;
  }
  if (use_window_) {
    window_ = Matrix::Ones(num_channels_, 1) *
              Eigen::Map<Eigen::MatrixXd>(window_vector.data(), 1,
                                          frame_duration_samples_)
                  .cast<float>();
  }
  return ::mediapipe::OkStatus();
}
} // namespace mediapipe

View File

@ -0,0 +1,65 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
// Options for TimeSeriesFramerCalculator, which chops a time-series stream
// into fixed-duration, optionally overlapping and windowed frames.
message TimeSeriesFramerCalculatorOptions {
  // Registers these options as an extension of the generic CalculatorOptions.
  extend CalculatorOptions {
    optional TimeSeriesFramerCalculatorOptions ext = 50631621;
  }
  // Frame duration in seconds. Required. Must be greater than 0. This is
  // rounded to the nearest integer number of samples.
  optional double frame_duration_seconds = 1;
  // Frame overlap in seconds.
  //
  // If emulate_fractional_frame_overlap is false (the default), then the frame
  // overlap is rounded to the nearest integer number of samples, and the step
  // from one frame to the next will be the difference between the number of
  // samples in a frame and the number of samples in the overlap.
  //
  // If emulate_fractional_frame_overlap is true, then frame overlap will be a
  // variable number of samples, such that the long-time average time step from
  // one frame to the next will be the difference between the (nominal, not
  // rounded) frame_duration_seconds and frame_overlap_seconds. This is useful
  // where the desired time step is not an integral number of input samples.
  //
  // A negative frame_overlap_seconds corresponds to skipping some input samples
  // between each frame of emitted samples.
  //
  // Required that frame_overlap_seconds < frame_duration_seconds.
  optional double frame_overlap_seconds = 2 [default = 0.0];
  // See frame_overlap_seconds for semantics.
  optional bool emulate_fractional_frame_overlap = 5 [default = false];
  // Whether to pad the final packet with zeros. If true, guarantees that all
  // input samples (other than those that fall in gaps implied by negative
  // frame_overlap_seconds) will be emitted. If set to false, any partial
  // packet at the end of the stream will be dropped.
  optional bool pad_final_packet = 3 [default = true];
  // Optional windowing function. The default is NONE (no windowing function).
  enum WindowFunction {
    NONE = 0;
    HAMMING = 1;
    HANN = 2;
  }
  optional WindowFunction window_function = 4 [default = NONE];
}

View File

@ -0,0 +1,395 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h>
#include <memory>
#include <string>
#include <vector>
#include "Eigen/Core"
#include "audio/dsp/window_functions.h"
#include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/time_series_test_util.h"
namespace mediapipe {
namespace {
const int kInitialTimestampOffsetMicroseconds = 4;
// Fixture that feeds a known sequence of variable-size packets through
// TimeSeriesFramerCalculator and verifies headers, values and padding of
// the framed output against a concatenated reference copy of the input.
class TimeSeriesFramerCalculatorTest
    : public TimeSeriesCalculatorTest<TimeSeriesFramerCalculatorOptions> {
 protected:
  void SetUp() override {
    calculator_name_ = "TimeSeriesFramerCalculator";
    input_sample_rate_ = 4000.0;
    num_input_channels_ = 3;
  }
  // Returns a float value with the channel and timestamp separated by
  // an order of magnitude, for easy parsing by humans.
  float TestValue(int64 timestamp_in_microseconds, int channel) {
    return timestamp_in_microseconds + channel / 10.0;
  }
  // Caller takes ownership of the returned value.
  Matrix* NewTestFrame(int num_channels, int num_samples,
                       double starting_timestamp_seconds) {
    auto matrix = new Matrix(num_channels, num_samples);
    for (int c = 0; c < num_channels; ++c) {
      for (int i = 0; i < num_samples; ++i) {
        int64 timestamp = time_series_util::SecondsToSamples(
            starting_timestamp_seconds + i / input_sample_rate_,
            Timestamp::kTimestampUnitsPerSecond);
        (*matrix)(c, i) = TestValue(timestamp, c);
      }
    }
    return matrix;
  }
  // Initializes and runs the test graph.
  ::mediapipe::Status Run() {
    InitializeGraph();
    FillInputHeader();
    InitializeInput();
    return RunGraph();
  }
  // Creates test input and saves a reference copy.
  void InitializeInput() {
    concatenated_input_samples_.resize(0, num_input_channels_);
    num_input_samples_ = 0;
    // Ten packets of 20, 40, ..., 200 samples: 1100 samples in total.
    for (int i = 0; i < 10; ++i) {
      // This range of packet sizes was chosen such that some input
      // packets will be smaller than the output packet size and other
      // input packets will be larger.
      int packet_size = (i + 1) * 20;
      double timestamp_seconds = kInitialTimestampOffsetMicroseconds * 1.0e-6 +
                                 num_input_samples_ / input_sample_rate_;
      Matrix* data_frame =
          NewTestFrame(num_input_channels_, packet_size, timestamp_seconds);
      // Keep a reference copy of the input.
      //
      // conservativeResize() is needed here to preserve the existing
      // data. Eigen's resize() resizes without preserving data.
      concatenated_input_samples_.conservativeResize(
          num_input_channels_, num_input_samples_ + packet_size);
      concatenated_input_samples_.rightCols(packet_size) = *data_frame;
      num_input_samples_ += packet_size;
      AppendInputPacket(data_frame, round(timestamp_seconds *
                                          Timestamp::kTimestampUnitsPerSecond));
    }
    // Build the same window the calculator is expected to apply (all ones
    // when no window function is selected).
    const int frame_duration_samples = FrameDurationSamples();
    std::vector<double> window_vector;
    switch (options_.window_function()) {
      case TimeSeriesFramerCalculatorOptions::HAMMING:
        audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples,
                                                      &window_vector);
        break;
      case TimeSeriesFramerCalculatorOptions::HANN:
        audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples,
                                                   &window_vector);
        break;
      case TimeSeriesFramerCalculatorOptions::NONE:
        window_vector.assign(frame_duration_samples, 1.0f);
        break;
    }
    window_ = Matrix::Ones(num_input_channels_, 1) *
              Eigen::Map<Eigen::MatrixXd>(window_vector.data(), 1,
                                          frame_duration_samples)
                  .cast<float>();
  }
  // Frame duration from the options, rounded to whole input samples.
  int FrameDurationSamples() {
    return time_series_util::SecondsToSamples(options_.frame_duration_seconds(),
                                              input_sample_rate_);
  }
  // Checks that the values in the framed output packets matches the
  // appropriate values from the input.
  void CheckOutputPacketValues(const Matrix& actual, int packet_num,
                               int frame_duration_samples,
                               double frame_step_samples,
                               int num_columns_to_check) {
    ASSERT_EQ(frame_duration_samples, actual.cols());
    Matrix expected = (concatenated_input_samples_
                           .block(0, round(frame_step_samples * packet_num),
                                  num_input_channels_, num_columns_to_check)
                           .array() *
                       window_.leftCols(num_columns_to_check).array())
                          .matrix();
    ExpectApproximatelyEqual(expected, actual.leftCols(num_columns_to_check));
  }
  // Checks output headers, Timestamps, and values.
  void CheckOutput() {
    const int frame_duration_samples = FrameDurationSamples();
    // Mirror the calculator's step computation: fractional when emulation
    // is on, otherwise rounded to whole samples.
    const double frame_step_samples =
        options_.emulate_fractional_frame_overlap()
            ? (options_.frame_duration_seconds() -
               options_.frame_overlap_seconds()) *
                  input_sample_rate_
            : frame_duration_samples -
                  time_series_util::SecondsToSamples(
                      options_.frame_overlap_seconds(), input_sample_rate_);
    TimeSeriesHeader expected_header = input().header.Get<TimeSeriesHeader>();
    expected_header.set_num_samples(frame_duration_samples);
    // packet_rate is only set on the header when the step is fixed.
    if (!options_.emulate_fractional_frame_overlap() ||
        frame_step_samples == round(frame_step_samples)) {
      expected_header.set_packet_rate(input_sample_rate_ / frame_step_samples);
    }
    ExpectOutputHeaderEquals(expected_header);
    int num_full_packets = output().packets.size();
    if (options_.pad_final_packet()) {
      num_full_packets -= 1;
    }
    for (int packet_num = 0; packet_num < num_full_packets; ++packet_num) {
      const Packet& packet = output().packets[packet_num];
      CheckOutputPacketValues(packet.Get<Matrix>(), packet_num,
                              frame_duration_samples, frame_step_samples,
                              frame_duration_samples);
    }
    // What is the effective time index of the final sample emitted?
    // This includes accounting for the gaps when overlap is negative.
    const int num_unique_output_samples =
        round((output().packets.size() - 1) * frame_step_samples) +
        frame_duration_samples;
    LOG(INFO) << "packets.size()=" << output().packets.size()
              << " frame_duration_samples=" << frame_duration_samples
              << " frame_step_samples=" << frame_step_samples
              << " num_input_samples_=" << num_input_samples_
              << " num_unique_output_samples=" << num_unique_output_samples;
    const int num_padding_samples =
        num_unique_output_samples - num_input_samples_;
    if (options_.pad_final_packet()) {
      EXPECT_LT(num_padding_samples, frame_duration_samples);
      // If the input ended during the dropped samples between the end of
      // the last emitted frame and where the next one would begin, there
      // can be fewer unique output points than input points, even with
      // padding.
      const int max_dropped_samples =
          static_cast<int>(ceil(frame_step_samples - frame_duration_samples));
      EXPECT_GE(num_padding_samples, std::min(0, -max_dropped_samples));
      if (num_padding_samples > 0) {
        // Check the non-padded part of the final packet.
        const Matrix& final_matrix = output().packets.back().Get<Matrix>();
        CheckOutputPacketValues(final_matrix, num_full_packets,
                                frame_duration_samples, frame_step_samples,
                                frame_duration_samples - num_padding_samples);
        // Check the padded part of the final packet.
        EXPECT_EQ(
            Matrix::Zero(num_input_channels_, num_padding_samples),
            final_matrix.block(0, frame_duration_samples - num_padding_samples,
                               num_input_channels_, num_padding_samples));
      }
    } else {
      EXPECT_GT(num_padding_samples, -frame_duration_samples);
      EXPECT_LE(num_padding_samples, 0);
    }
  }
  // Total number of samples appended across all input packets.
  int num_input_samples_;
  // Column-wise concatenation of every input packet, kept as reference.
  Matrix concatenated_input_samples_;
  // Expected window, replicated across channels (ones when NONE).
  Matrix window_;
};
// Frame duration of exactly 100 input samples, zero overlap, no window.
TEST_F(TimeSeriesFramerCalculatorTest, IntegerSampleDurationNoOverlap) {
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutput();
}
TEST_F(TimeSeriesFramerCalculatorTest,
       IntegerSampleDurationNoOverlapHammingWindow) {
  // Same framing as IntegerSampleDurationNoOverlap, but each emitted frame
  // is multiplied by a periodic Hamming window.
  options_.set_window_function(TimeSeriesFramerCalculatorOptions::HAMMING);
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutput();
}
TEST_F(TimeSeriesFramerCalculatorTest,
       IntegerSampleDurationNoOverlapHannWindow) {
  // Same framing as IntegerSampleDurationNoOverlap, but each emitted frame
  // is multiplied by a periodic Hann window.
  options_.set_window_function(TimeSeriesFramerCalculatorOptions::HANN);
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutput();
}
TEST_F(TimeSeriesFramerCalculatorTest, IntegerSampleDurationAndOverlap) {
  // 100-sample frames overlapping by 40 samples (an integer step of 60).
  options_.set_frame_overlap_seconds(40.0 / input_sample_rate_);
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutput();
}
TEST_F(TimeSeriesFramerCalculatorTest, NonintegerSampleDurationAndOverlap) {
  // Duration and overlap that do not land on whole samples; both are
  // rounded by the calculator.
  options_.set_frame_overlap_seconds(38.4 / input_sample_rate_);
  options_.set_frame_duration_seconds(98.5 / input_sample_rate_);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutput();
}
TEST_F(TimeSeriesFramerCalculatorTest, NegativeOverlapExactFrames) {
  // Negative overlap means to drop samples between frames: 100-sample
  // frames with a 10-sample skip (step of 110) divide the 1100 input
  // samples into exactly 10 full frames.
  options_.set_frame_overlap_seconds(-10.0 / input_sample_rate_);
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  MEDIAPIPE_ASSERT_OK(Run());
  EXPECT_EQ(output().packets.size(), 10);
  CheckOutput();
}
TEST_F(TimeSeriesFramerCalculatorTest, NegativeOverlapExactFramesLessSkip) {
  // 100-sample frames with a 100-sample skip (step of 200) fit exactly
  // 6 full frames into the 1100 input samples.
  options_.set_frame_overlap_seconds(-100.0 / input_sample_rate_);
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  MEDIAPIPE_ASSERT_OK(Run());
  EXPECT_EQ(output().packets.size(), 6);
  CheckOutput();
}
TEST_F(TimeSeriesFramerCalculatorTest, NegativeOverlapWithPadding) {
  // 150 samples per frame plus a skip of 50 samples will require some padding
  // on the sixth and last frame given 1100 sample input.
  //
  // NOTE: the values previously set here (100 / -100) duplicated
  // NegativeOverlapExactFramesLessSkip and never exercised padding.  They
  // now match the comment: with a step of 200, the sixth frame covers
  // samples [1000, 1150) and therefore needs 50 zero-padded samples.
  options_.set_frame_duration_seconds(150.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(-50.0 / input_sample_rate_);
  MEDIAPIPE_ASSERT_OK(Run());
  EXPECT_EQ(output().packets.size(), 6);
  CheckOutput();
}
TEST_F(TimeSeriesFramerCalculatorTest, FixedFrameOverlap) {
  // With emulate_fractional_frame_overlap left at its default (false), a
  // requested step of 11.4 samples is rounded to 11, giving
  // ceil((1100 - 30) / 11) + 1 = 99 packets of 30-sample frames.
  options_.set_frame_overlap_seconds((30.0 - 11.4) / input_sample_rate_);
  options_.set_frame_duration_seconds(30 / input_sample_rate_);
  MEDIAPIPE_ASSERT_OK(Run());
  EXPECT_EQ(output().packets.size(), 99);
  CheckOutput();
}
TEST_F(TimeSeriesFramerCalculatorTest, VariableFrameOverlap) {
  // With fractional-overlap emulation, the 11.4-sample step is kept exact,
  // giving ceil((1100 - 30) / 11.4) + 1 = 95 packets of 30-sample frames.
  options_.set_emulate_fractional_frame_overlap(true);
  options_.set_frame_overlap_seconds((30 - 11.4) / input_sample_rate_);
  options_.set_frame_duration_seconds(30 / input_sample_rate_);
  MEDIAPIPE_ASSERT_OK(Run());
  EXPECT_EQ(output().packets.size(), 95);
  CheckOutput();
}
TEST_F(TimeSeriesFramerCalculatorTest, VariableFrameSkip) {
  // A fractional step larger than the frame (41.4 samples, i.e. a negative
  // overlap) gives ceil((1100 - 30) / 41.4) + 1 = 27 packets.
  options_.set_emulate_fractional_frame_overlap(true);
  options_.set_frame_overlap_seconds((30 - 41.4) / input_sample_rate_);
  options_.set_frame_duration_seconds(30 / input_sample_rate_);
  MEDIAPIPE_ASSERT_OK(Run());
  EXPECT_EQ(output().packets.size(), 27);
  CheckOutput();
}
TEST_F(TimeSeriesFramerCalculatorTest, NoFinalPacketPadding) {
  // With padding disabled, the trailing partial frame is simply dropped.
  options_.set_pad_final_packet(false);
  options_.set_frame_duration_seconds(98.5 / input_sample_rate_);
  MEDIAPIPE_ASSERT_OK(Run());
  CheckOutput();
}
TEST_F(TimeSeriesFramerCalculatorTest,
       FrameRateHigherThanSampleRate_FrameDurationTooLow) {
  // A frame duration covering only 0.1 samples rounds to zero samples,
  // which the calculator must reject at Open() time.
  options_.set_frame_overlap_seconds(0.0);
  options_.set_frame_duration_seconds(1 / (10 * input_sample_rate_));
  EXPECT_FALSE(Run().ok());
}
TEST_F(TimeSeriesFramerCalculatorTest,
       FrameRateHigherThanSampleRate_FrameStepTooLow) {
  // Duration minus overlap leaves a step of only 0.1 samples, which the
  // calculator must reject at Open() time.
  options_.set_frame_overlap_seconds(9.9 / input_sample_rate_);
  options_.set_frame_duration_seconds(10.0 / input_sample_rate_);
  EXPECT_FALSE(Run().ok());
}
// A simple test class to do windowing sanity checks. Tests from this
// class input a single packet of all ones, and check the average
// value of the single output packet. This is useful as a sanity check
// that the correct windows are applied.
class TimeSeriesFramerCalculatorWindowingSanityTest
    : public TimeSeriesFramerCalculatorTest {
 protected:
  void SetUp() override {
    TimeSeriesFramerCalculatorTest::SetUp();
    // A single channel keeps the expected average independent of channel.
    num_input_channels_ = 1;
  }
  // Feeds one all-ones packet spanning exactly one frame and checks that
  // the mean of the single output packet equals expected_average (for an
  // all-ones input this is the mean value of the applied window).
  void RunAndTestSinglePacketAverage(float expected_average) {
    options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
    InitializeGraph();
    FillInputHeader();
    AppendInputPacket(new Matrix(Matrix::Ones(1, FrameDurationSamples())),
                      kInitialTimestampOffsetMicroseconds);
    MEDIAPIPE_ASSERT_OK(RunGraph());
    ASSERT_EQ(1, output().packets.size());
    ASSERT_NEAR(expected_average * FrameDurationSamples(),
                output().packets[0].Get<Matrix>().sum(), 1e-5);
  }
};
// With no window, an all-ones frame should average to exactly 1.0.
TEST_F(TimeSeriesFramerCalculatorWindowingSanityTest, NoWindowSanityCheck) {
  RunAndTestSinglePacketAverage(1.0f);
}
// The periodic Hamming window averages to 0.54 over a full period.
TEST_F(TimeSeriesFramerCalculatorWindowingSanityTest,
       HammingWindowSanityCheck) {
  options_.set_window_function(TimeSeriesFramerCalculatorOptions::HAMMING);
  RunAndTestSinglePacketAverage(0.54f);
}
// The periodic Hann window averages to 0.5 over a full period.
TEST_F(TimeSeriesFramerCalculatorWindowingSanityTest, HannWindowSanityCheck) {
  options_.set_window_function(TimeSeriesFramerCalculatorOptions::HANN);
  RunAndTestSinglePacketAverage(0.5f);
}
} // anonymous namespace
} // namespace mediapipe

View File

@ -19,6 +19,20 @@ package(default_visibility = ["//visibility:private"])
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
proto_library(
name = "concatenate_vector_calculator_proto",
srcs = ["concatenate_vector_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "packet_cloner_calculator_proto",
srcs = ["packet_cloner_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "packet_resampler_calculator_proto",
srcs = ["packet_resampler_calculator.proto"],
@ -26,6 +40,46 @@ proto_library(
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "split_vector_calculator_proto",
srcs = ["split_vector_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "quantize_float_vector_calculator_proto",
srcs = ["quantize_float_vector_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "sequence_shift_calculator_proto",
srcs = ["sequence_shift_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
],
)
proto_library(
name = "gate_calculator_proto",
srcs = ["gate_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
],
)
mediapipe_cc_proto_library(
name = "packet_cloner_calculator_cc_proto",
srcs = ["packet_cloner_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//mediapipe:__subpackages__"],
deps = [":packet_cloner_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "packet_resampler_calculator_cc_proto",
srcs = ["packet_resampler_calculator.proto"],
@ -34,6 +88,115 @@ mediapipe_cc_proto_library(
deps = [":packet_resampler_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "split_vector_calculator_cc_proto",
srcs = ["split_vector_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//mediapipe:__subpackages__"],
deps = [":split_vector_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "concatenate_vector_calculator_cc_proto",
srcs = ["concatenate_vector_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
deps = [":concatenate_vector_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "quantize_float_vector_calculator_cc_proto",
srcs = ["quantize_float_vector_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//mediapipe:__subpackages__"],
deps = [":quantize_float_vector_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "sequence_shift_calculator_cc_proto",
srcs = ["sequence_shift_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//mediapipe:__subpackages__"],
deps = [":sequence_shift_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "gate_calculator_cc_proto",
srcs = ["gate_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//mediapipe:__subpackages__"],
deps = [":gate_calculator_proto"],
)
cc_library(
name = "add_header_calculator",
srcs = ["add_header_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:logging",
],
alwayslink = 1,
)
cc_test(
name = "add_header_calculator_test",
size = "small",
srcs = ["add_header_calculator_test.cc"],
deps = [
":add_header_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:validate_type",
],
)
cc_library(
name = "concatenate_vector_calculator",
srcs = ["concatenate_vector_calculator.cc"],
hdrs = ["concatenate_vector_calculator.h"],
visibility = ["//visibility:public"],
deps = [
":concatenate_vector_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@org_tensorflow//tensorflow/lite:framework",
],
alwayslink = 1,
)
cc_library(
name = "concatenate_detection_vector_calculator",
srcs = ["concatenate_detection_vector_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":concatenate_vector_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:detection_cc_proto",
],
alwayslink = 1,
)
cc_test(
name = "concatenate_vector_calculator_test",
srcs = ["concatenate_vector_calculator_test.cc"],
deps = [
":concatenate_vector_calculator",
"//mediapipe/calculators/core:packet_resampler_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "counting_source_calculator",
srcs = ["counting_source_calculator.cc"],
@ -62,6 +225,38 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "matrix_multiply_calculator",
srcs = ["matrix_multiply_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:status",
"@eigen_archive//:eigen",
],
alwayslink = 1,
)
cc_library(
name = "matrix_subtract_calculator",
srcs = ["matrix_subtract_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:status",
"@eigen_archive//:eigen",
],
alwayslink = 1,
)
cc_library(
name = "mux_calculator",
srcs = ["mux_calculator.cc"],
@ -83,12 +278,37 @@ cc_library(
"//visibility:public",
],
deps = [
"//mediapipe/calculators/core:packet_cloner_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_library(
name = "packet_inner_join_calculator",
srcs = ["packet_inner_join_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_test(
name = "packet_inner_join_calculator_test",
srcs = ["packet_inner_join_calculator_test.cc"],
deps = [
":packet_inner_join_calculator",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:validate_type",
],
)
cc_library(
name = "pass_through_calculator",
srcs = ["pass_through_calculator.cc"],
@ -145,8 +365,8 @@ cc_library(
)
cc_library(
name = "real_time_flow_limiter_calculator",
srcs = ["real_time_flow_limiter_calculator.cc"],
name = "flow_limiter_calculator",
srcs = ["flow_limiter_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
@ -191,19 +411,16 @@ cc_library(
"//mediapipe/calculators/core:packet_resampler_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework/deps:mathutil",
"//mediapipe/framework/deps:random",
"//mediapipe/framework/formats:video_stream_header",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:options_util",
"@com_google_absl//absl/strings",
"//mediapipe/framework/deps:mathutil",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:integral_types",
] + select({
"//conditions:default": [
"//mediapipe/framework/deps:random",
],
}),
],
alwayslink = 1,
)
@ -245,10 +462,43 @@ cc_test(
)
cc_test(
name = "real_time_flow_limiter_calculator_test",
srcs = ["real_time_flow_limiter_calculator_test.cc"],
name = "matrix_multiply_calculator_test",
srcs = ["matrix_multiply_calculator_test.cc"],
visibility = ["//visibility:private"],
deps = [
":real_time_flow_limiter_calculator",
":matrix_multiply_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/tool:validate_type",
"@eigen_archive//:eigen",
],
)
cc_test(
name = "matrix_subtract_calculator_test",
srcs = ["matrix_subtract_calculator_test.cc"],
visibility = ["//visibility:private"],
deps = [
":matrix_subtract_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:validate_type",
"@eigen_archive//:eigen",
],
)
cc_test(
name = "flow_limiter_calculator_test",
srcs = ["flow_limiter_calculator_test.cc"],
deps = [
":flow_limiter_calculator",
"//mediapipe/calculators/core:counting_source_calculator",
"//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/framework:calculator_framework",
@ -264,3 +514,140 @@ cc_test(
"@com_google_absl//absl/time",
],
)
# Calculator library for splitting an input vector into sub-vectors
# (declared in split_vector_calculator.h); links TensorFlow Lite.
cc_library(
    name = "split_vector_calculator",
    srcs = ["split_vector_calculator.cc"],
    hdrs = ["split_vector_calculator.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":split_vector_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/util:resource_util",
        "@org_tensorflow//tensorflow/lite:framework",
        "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
    ],
    # alwayslink keeps the calculator's static registration symbols even
    # though nothing references them directly.
    alwayslink = 1,
)
# Unit test for split_vector_calculator.
cc_test(
    name = "split_vector_calculator_test",
    srcs = ["split_vector_calculator_test.cc"],
    deps = [
        ":split_vector_calculator",
        ":split_vector_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework/deps:file_path",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:integral_types",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/tool:validate_type",
        "@org_tensorflow//tensorflow/lite:framework",
        "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
    ],
)
# Calculator library that quantizes float vectors (options in the
# companion _cc_proto target).
cc_library(
    name = "quantize_float_vector_calculator",
    srcs = ["quantize_float_vector_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        ":quantize_float_vector_calculator_cc_proto",
        "//mediapipe/framework:calculator_context",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:status",
    ],
    # Required for static calculator registration.
    alwayslink = 1,
)
# Unit test for quantize_float_vector_calculator.
cc_test(
    name = "quantize_float_vector_calculator_test",
    srcs = ["quantize_float_vector_calculator_test.cc"],
    deps = [
        ":quantize_float_vector_calculator",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/port:status",
    ],
)
# Calculator library that shifts packet timestamps within a sequence.
cc_library(
    name = "sequence_shift_calculator",
    srcs = ["sequence_shift_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        ":sequence_shift_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:status",
    ],
    # Required for static calculator registration.
    alwayslink = 1,
)
# Unit test for sequence_shift_calculator.
cc_test(
    name = "sequence_shift_calculator_test",
    srcs = ["sequence_shift_calculator_test.cc"],
    deps = [
        ":sequence_shift_calculator",
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/tool:validate_type",
    ],
)
# Calculator library gating packet flow via ALLOW/DISALLOW control streams.
cc_library(
    name = "gate_calculator",
    srcs = ["gate_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        ":gate_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/util:header_util",
    ],
    # Required for static calculator registration.
    alwayslink = 1,
)
# Unit test for gate_calculator.
cc_test(
    name = "gate_calculator_test",
    srcs = ["gate_calculator_test.cc"],
    deps = [
        ":gate_calculator",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/port:status",
    ],
)
# Calculator library that merges several input streams into one.
cc_library(
    name = "merge_calculator",
    srcs = ["merge_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
    ],
    # Required for static calculator registration.
    alwayslink = 1,
)
# Unit test for merge_calculator.
cc_test(
    name = "merge_calculator_test",
    srcs = ["merge_calculator_test.cc"],
    deps = [
        ":merge_calculator",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/port:status",
    ],
)

View File

@ -0,0 +1,53 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/logging.h"
namespace mediapipe {
// Attach the header from one stream to another stream.
//
// The header stream (tag HEADER) must not have any packets in it.
//
// Before using this calculator, please think about changing your
// calculator to not need a header or to accept a separate stream with
// a header, that would be more future proof.
//
// Copies the stream header of the HEADER input onto the single output
// stream, and forwards every DATA packet through unchanged. The HEADER
// stream must never carry packets — GetContract declares it SetNone().
class AddHeaderCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    // HEADER contributes only its stream header; DATA packets are untyped
    // here, and the output inherits whatever type DATA has.
    cc->Inputs().Tag("HEADER").SetNone();
    cc->Inputs().Tag("DATA").SetAny();
    cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Tag("DATA"));
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) override {
    const Packet& header_packet = cc->Inputs().Tag("HEADER").Header();
    if (!header_packet.IsEmpty()) {
      // Propagate the header to the output stream before any packets flow.
      cc->Outputs().Index(0).SetHeader(header_packet);
    }
    // Output timestamps exactly mirror input timestamps.
    cc->SetOffset(TimestampDiff(0));
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    // Pass each DATA packet straight through.
    cc->Outputs().Index(0).AddPacket(cc->Inputs().Tag("DATA").Value());
    return ::mediapipe::OkStatus();
  }
};
REGISTER_CALCULATOR(AddHeaderCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,99 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/framework/tool/validate_type.h"
namespace mediapipe {
// Empty fixture; exists only to group the AddHeaderCalculator tests.
class AddHeaderCalculatorTest : public ::testing::Test {};
// The HEADER stream header must be copied to the output, and all DATA
// packets must pass through with values and timestamps intact.
TEST_F(AddHeaderCalculatorTest, Works) {
  CalculatorGraphConfig::Node node;
  node.set_calculator("AddHeaderCalculator");
  node.add_input_stream("HEADER:header_stream");
  node.add_input_stream("DATA:data_stream");
  node.add_output_stream("merged_stream");
  CalculatorRunner runner(node);

  // Set header and add 5 packets.
  runner.MutableInputs()->Tag("HEADER").header =
      Adopt(new std::string("my_header"));
  for (int i = 0; i < 5; ++i) {
    Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
    runner.MutableInputs()->Tag("DATA").packets.push_back(packet);
  }

  // Run calculator.
  MEDIAPIPE_ASSERT_OK(runner.Run());

  ASSERT_EQ(1, runner.Outputs().NumEntries());

  // Test output: header first, then the five pass-through packets.
  EXPECT_EQ(std::string("my_header"),
            runner.Outputs().Index(0).header.Get<std::string>());
  const std::vector<Packet>& output_packets = runner.Outputs().Index(0).packets;
  ASSERT_EQ(5, output_packets.size());
  for (int i = 0; i < 5; ++i) {
    const int val = output_packets[i].Get<int>();
    EXPECT_EQ(i, val);
    EXPECT_EQ(Timestamp(i * 1000), output_packets[i].Timestamp());
  }
}
// With no header set on HEADER, the output stream header stays empty.
TEST_F(AddHeaderCalculatorTest, HandlesEmptyHeaderStream) {
  CalculatorGraphConfig::Node node;
  node.set_calculator("AddHeaderCalculator");
  node.add_input_stream("HEADER:header_stream");
  node.add_input_stream("DATA:data_stream");
  node.add_output_stream("merged_stream");
  CalculatorRunner runner(node);

  // No header and no packets.

  // Run calculator.
  MEDIAPIPE_ASSERT_OK(runner.Run());

  EXPECT_TRUE(runner.Outputs().Index(0).header.IsEmpty());
}
// Sending an actual packet (not just a header) on the HEADER stream is a
// contract violation (HEADER is SetNone()), so the run must fail.
TEST_F(AddHeaderCalculatorTest, NoPacketsOnHeaderStream) {
  CalculatorGraphConfig::Node node;
  node.set_calculator("AddHeaderCalculator");
  node.add_input_stream("HEADER:header_stream");
  node.add_input_stream("DATA:data_stream");
  node.add_output_stream("merged_stream");
  CalculatorRunner runner(node);

  // Set header and add 5 packets.
  runner.MutableInputs()->Tag("HEADER").header =
      Adopt(new std::string("my_header"));
  runner.MutableInputs()->Tag("HEADER").packets.push_back(
      Adopt(new std::string("not allowed")));
  for (int i = 0; i < 5; ++i) {
    Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
    runner.MutableInputs()->Tag("DATA").packets.push_back(packet);
  }

  // Run calculator.
  ASSERT_FALSE(runner.Run().ok());
}
} // namespace mediapipe

View File

@ -0,0 +1,33 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "mediapipe/calculators/core/concatenate_vector_calculator.h"
#include "mediapipe/framework/formats/detection.pb.h"
namespace mediapipe {
// Example config:
// node {
// calculator: "ConcatenateDetectionVectorCalculator"
// input_stream: "detection_vector_1"
// input_stream: "detection_vector_2"
// output_stream: "concatenated_detection_vector"
// }
// Instantiation of ConcatenateVectorCalculator for mediapipe::Detection.
typedef ConcatenateVectorCalculator<::mediapipe::Detection>
    ConcatenateDetectionVectorCalculator;
REGISTER_CALCULATOR(ConcatenateDetectionVectorCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,44 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/concatenate_vector_calculator.h"
#include <vector>
#include "tensorflow/lite/interpreter.h"
namespace mediapipe {
// Example config:
// node {
// calculator: "ConcatenateFloatVectorCalculator"
// input_stream: "float_vector_1"
// input_stream: "float_vector_2"
// output_stream: "concatenated_float_vector"
// }
// Instantiation of ConcatenateVectorCalculator for float.
typedef ConcatenateVectorCalculator<float> ConcatenateFloatVectorCalculator;
REGISTER_CALCULATOR(ConcatenateFloatVectorCalculator);

// Instantiation of ConcatenateVectorCalculator for TfLiteTensor.
// Example config:
// node {
//   calculator: "ConcatenateTfLiteTensorVectorCalculator"
//   input_stream: "tflitetensor_vector_1"
//   input_stream: "tflitetensor_vector_2"
//   output_stream: "concatenated_tflitetensor_vector"
// }
typedef ConcatenateVectorCalculator<TfLiteTensor>
    ConcatenateTfLiteTensorVectorCalculator;
REGISTER_CALCULATOR(ConcatenateTfLiteTensorVectorCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,78 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_
#include <vector>
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// Concatenates several std::vector<T> following stream index order. This class
// assumes that every input stream contains the vector<T> type. To use this
// class for a particular type T, register a calculator using
// ConcatenateVectorCalculator<T>.
// Concatenates the std::vector<T> packets of all input streams, in stream
// index order, into one output std::vector<T> per timestamp. Empty input
// streams are skipped unless only_emit_if_all_present is set, in which
// case a single missing packet suppresses the output entirely.
template <typename T>
class ConcatenateVectorCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    RET_CHECK(cc->Inputs().NumEntries() != 0);
    RET_CHECK(cc->Outputs().NumEntries() == 1);

    const int num_streams = cc->Inputs().NumEntries();
    for (int stream = 0; stream < num_streams; ++stream) {
      cc->Inputs().Index(stream).Set<std::vector<T>>();
    }
    cc->Outputs().Index(0).Set<std::vector<T>>();
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) override {
    // Output timestamps mirror input timestamps exactly.
    cc->SetOffset(TimestampDiff(0));
    const auto& options =
        cc->Options<::mediapipe::ConcatenateVectorCalculatorOptions>();
    only_emit_if_all_present_ = options.only_emit_if_all_present();
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    const int num_streams = cc->Inputs().NumEntries();
    if (only_emit_if_all_present_) {
      // AND semantics: emit nothing unless every stream has a packet now.
      for (int stream = 0; stream < num_streams; ++stream) {
        if (cc->Inputs().Index(stream).IsEmpty()) {
          return ::mediapipe::OkStatus();
        }
      }
    }

    auto concatenated = absl::make_unique<std::vector<T>>();
    for (int stream = 0; stream < num_streams; ++stream) {
      if (cc->Inputs().Index(stream).IsEmpty()) continue;
      const std::vector<T>& part =
          cc->Inputs().Index(stream).Get<std::vector<T>>();
      concatenated->insert(concatenated->end(), part.begin(), part.end());
    }
    cc->Outputs().Index(0).Add(concatenated.release(), cc->InputTimestamp());
    return ::mediapipe::OkStatus();
  }

 private:
  // Cached copy of ConcatenateVectorCalculatorOptions.only_emit_if_all_present.
  bool only_emit_if_all_present_;
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_

View File

@ -0,0 +1,29 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
// Options for ConcatenateVectorCalculator (and its type-specialized
// registrations such as ConcatenateFloatVectorCalculator).
message ConcatenateVectorCalculatorOptions {
  extend CalculatorOptions {
    // Globally-unique extension field number; must never be changed.
    optional ConcatenateVectorCalculatorOptions ext = 259397839;
  }

  // If true, the calculator will only emit a packet at the given timestamp if
  // all input streams have a non-empty packet (AND operation on streams).
  optional bool only_emit_if_all_present = 1 [default = false];
}

View File

@ -0,0 +1,238 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/concatenate_vector_calculator.h"
#include <memory>
#include <string>
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
namespace mediapipe {
// int instantiation of ConcatenateVectorCalculator, registered for use by
// the tests in this file only.
typedef ConcatenateVectorCalculator<int> TestConcatenateIntVectorCalculator;
REGISTER_CALCULATOR(TestConcatenateIntVectorCalculator);
// Pushes one packet per input stream at |timestamp|, where packet i wraps
// inputs[i] as a std::vector<int>.
// Fix: compare the loop index against an explicitly-cast size to avoid the
// signed/unsigned comparison between int and std::vector::size_type.
void AddInputVectors(const std::vector<std::vector<int>>& inputs,
                     int64 timestamp, CalculatorRunner* runner) {
  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
    runner->MutableInputs()->Index(i).packets.push_back(
        MakePacket<std::vector<int>>(inputs[i]).At(Timestamp(timestamp)));
  }
}
// Three empty input vectors still produce one (empty) output packet.
TEST(TestConcatenateIntVectorCalculatorTest, EmptyVectorInputs) {
  CalculatorRunner runner("TestConcatenateIntVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/3,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  std::vector<std::vector<int>> inputs = {{}, {}, {}};
  AddInputVectors(inputs, /*timestamp=*/1, &runner);
  MEDIAPIPE_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  EXPECT_EQ(1, outputs.size());
  EXPECT_TRUE(outputs[0].Get<std::vector<int>>().empty());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
}
// Vectors from all three streams are concatenated in stream-index order.
TEST(TestConcatenateIntVectorCalculatorTest, OneTimestamp) {
  CalculatorRunner runner("TestConcatenateIntVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/3,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  std::vector<std::vector<int>> inputs = {{1, 2, 3}, {4}, {5, 6}};
  AddInputVectors(inputs, /*timestamp=*/1, &runner);
  MEDIAPIPE_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  EXPECT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  std::vector<int> expected_vector = {1, 2, 3, 4, 5, 6};
  EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
}
// Each timestamp is concatenated independently and emitted as its own
// output packet, preserving both timestamps and stream-index order.
TEST(TestConcatenateIntVectorCalculatorTest, TwoInputsAtTwoTimestamps) {
  CalculatorRunner runner("TestConcatenateIntVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/3,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  {
    std::vector<std::vector<int>> inputs = {{1, 2, 3}, {4}, {5, 6}};
    AddInputVectors(inputs, /*timestamp=*/1, &runner);
  }
  {
    std::vector<std::vector<int>> inputs = {{0, 2}, {1}, {3, 5}};
    AddInputVectors(inputs, /*timestamp=*/2, &runner);
  }
  MEDIAPIPE_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  EXPECT_EQ(2, outputs.size());
  {
    EXPECT_EQ(6, outputs[0].Get<std::vector<int>>().size());
    EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
    std::vector<int> expected_vector = {1, 2, 3, 4, 5, 6};
    EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
  }
  {
    EXPECT_EQ(5, outputs[1].Get<std::vector<int>>().size());
    EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
    std::vector<int> expected_vector = {0, 2, 1, 3, 5};
    EXPECT_EQ(expected_vector, outputs[1].Get<std::vector<int>>());
  }
}
// Default behavior: a stream with no packet is simply skipped; the other
// stream's vector is still emitted.
TEST(TestConcatenateIntVectorCalculatorTest, OneEmptyStreamStillOutput) {
  CalculatorRunner runner("TestConcatenateIntVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/2,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  std::vector<std::vector<int>> inputs = {{1, 2, 3}};
  AddInputVectors(inputs, /*timestamp=*/1, &runner);
  MEDIAPIPE_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  EXPECT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  std::vector<int> expected_vector = {1, 2, 3};
  EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
}
// With only_emit_if_all_present, a single missing stream suppresses the
// output packet entirely.
TEST(TestConcatenateIntVectorCalculatorTest, OneEmptyStreamNoOutput) {
  CalculatorRunner runner("TestConcatenateIntVectorCalculator",
                          /*options_string=*/
                          "[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
                          "{only_emit_if_all_present: true}",
                          /*num_inputs=*/2,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  std::vector<std::vector<int>> inputs = {{1, 2, 3}};
  AddInputVectors(inputs, /*timestamp=*/1, &runner);
  MEDIAPIPE_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  EXPECT_EQ(0, outputs.size());
}
// Pushes one packet per input stream at |timestamp|, where packet i wraps
// inputs[i] as a std::vector<float>.
// Fix: compare the loop index against an explicitly-cast size to avoid the
// signed/unsigned comparison between int and std::vector::size_type.
void AddInputVectors(const std::vector<std::vector<float>>& inputs,
                     int64 timestamp, CalculatorRunner* runner) {
  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
    runner->MutableInputs()->Index(i).packets.push_back(
        MakePacket<std::vector<float>>(inputs[i]).At(Timestamp(timestamp)));
  }
}
// Three empty float vectors still produce one (empty) output packet.
TEST(ConcatenateFloatVectorCalculatorTest, EmptyVectorInputs) {
  CalculatorRunner runner("ConcatenateFloatVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/3,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  std::vector<std::vector<float>> inputs = {{}, {}, {}};
  AddInputVectors(inputs, /*timestamp=*/1, &runner);
  MEDIAPIPE_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  EXPECT_EQ(1, outputs.size());
  EXPECT_TRUE(outputs[0].Get<std::vector<float>>().empty());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
}
// Float vectors from all three streams concatenate in stream-index order.
TEST(ConcatenateFloatVectorCalculatorTest, OneTimestamp) {
  CalculatorRunner runner("ConcatenateFloatVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/3,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  std::vector<std::vector<float>> inputs = {
      {1.0f, 2.0f, 3.0f}, {4.0f}, {5.0f, 6.0f}};
  AddInputVectors(inputs, /*timestamp=*/1, &runner);
  MEDIAPIPE_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  EXPECT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  std::vector<float> expected_vector = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<float>>());
}
// Each timestamp is concatenated independently into its own output packet,
// preserving both timestamps and stream-index order.
TEST(ConcatenateFloatVectorCalculatorTest, TwoInputsAtTwoTimestamps) {
  CalculatorRunner runner("ConcatenateFloatVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/3,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  {
    std::vector<std::vector<float>> inputs = {
        {1.0f, 2.0f, 3.0f}, {4.0f}, {5.0f, 6.0f}};
    AddInputVectors(inputs, /*timestamp=*/1, &runner);
  }
  {
    std::vector<std::vector<float>> inputs = {
        {0.0f, 2.0f}, {1.0f}, {3.0f, 5.0f}};
    AddInputVectors(inputs, /*timestamp=*/2, &runner);
  }
  MEDIAPIPE_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  EXPECT_EQ(2, outputs.size());
  {
    EXPECT_EQ(6, outputs[0].Get<std::vector<float>>().size());
    EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
    std::vector<float> expected_vector = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
    EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<float>>());
  }
  {
    EXPECT_EQ(5, outputs[1].Get<std::vector<float>>().size());
    EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
    std::vector<float> expected_vector = {0.0f, 2.0f, 1.0f, 3.0f, 5.0f};
    EXPECT_EQ(expected_vector, outputs[1].Get<std::vector<float>>());
  }
}
// Default behavior: a stream with no packet is simply skipped; the other
// stream's vector is still emitted.
TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamStillOutput) {
  CalculatorRunner runner("ConcatenateFloatVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/2,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  std::vector<std::vector<float>> inputs = {{1.0f, 2.0f, 3.0f}};
  AddInputVectors(inputs, /*timestamp=*/1, &runner);
  MEDIAPIPE_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  EXPECT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  std::vector<float> expected_vector = {1.0f, 2.0f, 3.0f};
  EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<float>>());
}
// With only_emit_if_all_present, a single missing stream suppresses the
// output packet entirely.
TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamNoOutput) {
  CalculatorRunner runner("ConcatenateFloatVectorCalculator",
                          /*options_string=*/
                          "[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
                          "{only_emit_if_all_present: true}",
                          /*num_inputs=*/2,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  std::vector<std::vector<float>> inputs = {{1.0f, 2.0f, 3.0f}};
  AddInputVectors(inputs, /*timestamp=*/1, &runner);
  MEDIAPIPE_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  EXPECT_EQ(0, outputs.size());
}
} // namespace mediapipe

View File

@ -23,34 +23,34 @@
namespace mediapipe {
// RealTimeFlowLimiterCalculator is used to limit the number of pipelined
// processing operations in a section of the graph.
// FlowLimiterCalculator is used to limit the number of pipelined processing
// operations in a section of the graph.
//
// Typical topology:
//
// in ->-[RTFLC]-[foo]-...-[bar]-+->- out
// ^____________________|
// FINISHED
// in ->-[FLC]-[foo]-...-[bar]-+->- out
// ^_____________________|
// FINISHED
//
// By connecting the output of the graph section to this calculator's FINISHED
// input with a backwards edge, this allows RTFLC to keep track of how many
// input with a backwards edge, this allows FLC to keep track of how many
// timestamps are currently being processed.
//
// The limit defaults to 1, and can be overridden with the MAX_IN_FLIGHT side
// packet.
//
// As long as the number of timestamps being processed ("in flight") is below
// the limit, RTFLC allows input to pass through. When the limit is reached,
// RTFLC starts dropping input packets, keeping only the most recent. When the
// the limit, FLC allows input to pass through. When the limit is reached,
// FLC starts dropping input packets, keeping only the most recent. When the
// processing count decreases again, as signaled by the receipt of a packet on
// FINISHED, RTFLC allows packets to flow again, releasing the most recently
// FINISHED, FLC allows packets to flow again, releasing the most recently
// queued packet, if any.
//
// If there are multiple input streams, packet dropping is synchronized.
//
// IMPORTANT: for each timestamp where RTFLC forwards a packet (or a set of
// IMPORTANT: for each timestamp where FLC forwards a packet (or a set of
// packets, if using multiple data streams), a packet must eventually arrive on
// the FINISHED stream. Dropping packets in the section between RTFLC and
// the FINISHED stream. Dropping packets in the section between FLC and
// FINISHED will make the in-flight count incorrect.
//
// TODO: Remove this comment when graph-level ISH has been removed.
@ -61,7 +61,7 @@ namespace mediapipe {
//
// Example config:
// node {
// calculator: "RealTimeFlowLimiterCalculator"
// calculator: "FlowLimiterCalculator"
// input_stream: "raw_frames"
// input_stream: "FINISHED:finished"
// input_stream_info: {
@ -73,7 +73,7 @@ namespace mediapipe {
// }
// output_stream: "gated_frames"
// }
class RealTimeFlowLimiterCalculator : public CalculatorBase {
class FlowLimiterCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
int num_data_streams = cc->Inputs().NumEntries("");
@ -194,6 +194,6 @@ class RealTimeFlowLimiterCalculator : public CalculatorBase {
Timestamp allow_ctr_ts_;
std::vector<Timestamp> data_stream_bound_ts_;
};
REGISTER_CALCULATOR(RealTimeFlowLimiterCalculator);
REGISTER_CALCULATOR(FlowLimiterCalculator);
} // namespace mediapipe

View File

@ -71,7 +71,7 @@ constexpr int kNumImageFrames = 5;
constexpr int kNumFinished = 3;
CalculatorGraphConfig::Node GetDefaultNode() {
return ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "RealTimeFlowLimiterCalculator"
calculator: "FlowLimiterCalculator"
input_stream: "raw_frames"
input_stream: "FINISHED:finished"
input_stream_info: { tag_index: "FINISHED" back_edge: true }
@ -79,9 +79,9 @@ CalculatorGraphConfig::Node GetDefaultNode() {
)");
}
// Simple test to make sure that the RealTimeFlowLimiterCalculator outputs
// just one packet when MAX_IN_FLIGHT is 1.
TEST(RealTimeFlowLimiterCalculator, OneOutputTest) {
// Simple test to make sure that the FlowLimiterCalculator outputs just one
// packet when MAX_IN_FLIGHT is 1.
TEST(FlowLimiterCalculator, OneOutputTest) {
// Setup the calculator runner and add only ImageFrame packets.
CalculatorRunner runner(GetDefaultNode());
for (int i = 0; i < kNumImageFrames; ++i) {
@ -98,9 +98,9 @@ TEST(RealTimeFlowLimiterCalculator, OneOutputTest) {
EXPECT_EQ(frame_output_packets.size(), 1);
}
// Simple test to make sure that the RealTimeFlowLimiterCalculator waits for all
// Simple test to make sure that the FlowLimiterCalculator waits for all
// input streams to have at least one packet available before publishing.
TEST(RealTimeFlowLimiterCalculator, BasicTest) {
TEST(FlowLimiterCalculator, BasicTest) {
// Setup the calculator runner and add both ImageFrame and finish packets.
CalculatorRunner runner(GetDefaultNode());
for (int i = 0; i < kNumImageFrames; ++i) {
@ -171,13 +171,11 @@ class CloseCallbackCalculator : public CalculatorBase {
};
REGISTER_CALCULATOR(CloseCallbackCalculator);
// Tests demostrating an RealTimeFlowLimiterCalculator operating in a cyclic
// graph.
// Tests demonstrating a FlowLimiterCalculator operating in a cyclic graph.
// TODO: clean up these tests.
class RealTimeFlowLimiterCalculatorTest : public testing::Test {
class FlowLimiterCalculatorTest : public testing::Test {
public:
RealTimeFlowLimiterCalculatorTest()
: enter_semaphore_(0), exit_semaphore_(0) {}
FlowLimiterCalculatorTest() : enter_semaphore_(0), exit_semaphore_(0) {}
void SetUp() override {
graph_config_ = InflightGraphConfig();
@ -215,7 +213,7 @@ class RealTimeFlowLimiterCalculatorTest : public testing::Test {
input_name, MakePacket<int>(value).At(Timestamp(value))));
}
// A calculator graph starting with an RealTimeFlowLimiterCalculator and
// A calculator graph starting with an FlowLimiterCalculator and
// ending with a InFlightFinishCalculator.
// Back-edge "finished" limits processing to one frame in-flight.
// The two LambdaCalculators are used to keep certain packet sets in flight.
@ -224,7 +222,7 @@ class RealTimeFlowLimiterCalculatorTest : public testing::Test {
input_stream: 'in_1'
input_stream: 'in_2'
node {
calculator: 'RealTimeFlowLimiterCalculator'
calculator: 'FlowLimiterCalculator'
input_side_packet: 'MAX_IN_FLIGHT:max_in_flight'
input_stream: 'in_1'
input_stream: 'in_2'
@ -270,14 +268,14 @@ class RealTimeFlowLimiterCalculatorTest : public testing::Test {
int close_count_ = 0;
};
// A test demonstrating an RealTimeFlowLimiterCalculator operating in a cyclic
// A test demonstrating an FlowLimiterCalculator operating in a cyclic
// graph. This test shows that:
//
// (1) Timestamps are passed through unaltered.
// (2) All output streams including the back_edge stream are closed when
// the first input stream is closed.
//
TEST_F(RealTimeFlowLimiterCalculatorTest, BackEdgeCloses) {
TEST_F(FlowLimiterCalculatorTest, BackEdgeCloses) {
InitializeGraph(1);
MEDIAPIPE_ASSERT_OK(graph_.StartRun({}));
@ -321,7 +319,7 @@ TEST_F(RealTimeFlowLimiterCalculatorTest, BackEdgeCloses) {
// A test demonstrating that all output streams are closed when all
// input streams are closed after the last input packet has been processed.
TEST_F(RealTimeFlowLimiterCalculatorTest, AllStreamsClose) {
TEST_F(FlowLimiterCalculatorTest, AllStreamsClose) {
InitializeGraph(1);
MEDIAPIPE_ASSERT_OK(graph_.StartRun({}));
@ -341,7 +339,7 @@ TEST_F(RealTimeFlowLimiterCalculatorTest, AllStreamsClose) {
EXPECT_EQ(1, close_count_);
}
TEST(RealTimeFlowLimiterCalculator, TwoStreams) {
TEST(FlowLimiterCalculator, TwoStreams) {
std::vector<Packet> a_passed;
std::vector<Packet> b_passed;
CalculatorGraphConfig graph_config_ =
@ -351,7 +349,7 @@ TEST(RealTimeFlowLimiterCalculator, TwoStreams) {
input_stream: 'finished'
node {
name: 'input_dropper'
calculator: 'RealTimeFlowLimiterCalculator'
calculator: 'FlowLimiterCalculator'
input_side_packet: 'MAX_IN_FLIGHT:max_in_flight'
input_stream: 'in_a'
input_stream: 'in_b'
@ -440,7 +438,7 @@ TEST(RealTimeFlowLimiterCalculator, TwoStreams) {
MEDIAPIPE_EXPECT_OK(graph_.WaitUntilDone());
}
TEST(RealTimeFlowLimiterCalculator, CanConsume) {
TEST(FlowLimiterCalculator, CanConsume) {
std::vector<Packet> in_sampled_packets_;
CalculatorGraphConfig graph_config_ =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
@ -448,7 +446,7 @@ TEST(RealTimeFlowLimiterCalculator, CanConsume) {
input_stream: 'finished'
node {
name: 'input_dropper'
calculator: 'RealTimeFlowLimiterCalculator'
calculator: 'FlowLimiterCalculator'
input_side_packet: 'MAX_IN_FLIGHT:max_in_flight'
input_stream: 'in'
input_stream: 'FINISHED:finished'

View File

@ -0,0 +1,163 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/gate_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/header_util.h"
namespace mediapipe {
namespace {
// Tracks the gate's most recent decision so that Process() can detect
// ALLOW <-> DISALLOW transitions for the optional STATE_CHANGE output.
enum GateState {
  GATE_UNINITIALIZED,  // No control packet has been evaluated yet.
  GATE_ALLOW,          // Last decision let data packets through.
  GATE_DISALLOW,       // Last decision dropped data packets.
};
// Returns a human-readable name for `state`; used in VLOG messages.
std::string ToString(GateState state) {
  if (state == GATE_UNINITIALIZED) return "UNINITIALIZED";
  if (state == GATE_ALLOW) return "ALLOW";
  if (state == GATE_DISALLOW) return "DISALLOW";
  DLOG(FATAL) << "Unknown GateState";
  return "UNKNOWN";
}
} // namespace
// Controls whether or not the input packets are passed further along the graph.
// Takes multiple data input streams and either an ALLOW or a DISALLOW control
// input stream. It outputs an output stream for each input stream that is not
// ALLOW or DISALLOW as well as an optional STATE_CHANGE stream which downstream
// calculators can use to respond to state-change events.
//
// If the current ALLOW packet is set to true, the input packets are passed to
// their corresponding output stream unchanged. If the ALLOW packet is set to
// false, the current input packet is NOT passed to the output stream. If using
// DISALLOW, the behavior is opposite of ALLOW.
//
// By default, an empty packet in the ALLOW or DISALLOW input stream indicates
// disallowing the corresponding packets in other input streams. The behavior
// can be inverted with a calculator option.
//
// Intended to be used with the default input stream handler, which synchronizes
// all data input streams with the ALLOW/DISALLOW control input stream.
//
// Example config:
// node {
// calculator: "GateCalculator"
// input_stream: "input_stream0"
// input_stream: "input_stream1"
// input_stream: "input_streamN"
// input_stream: "ALLOW:allow" or "DISALLOW:disallow"
// output_stream: "STATE_CHANGE:state_change"
// output_stream: "output_stream0"
// output_stream: "output_stream1"
// output_stream: "output_streamN"
// }
class GateCalculator : public CalculatorBase {
 public:
  GateCalculator() {}

  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    // Exactly one of ALLOW / DISALLOW must carry the gating signal; all
    // untagged streams carry the data being gated.
    RET_CHECK(cc->Inputs().HasTag("ALLOW") ^ cc->Inputs().HasTag("DISALLOW"));
    const int data_stream_count = cc->Inputs().NumEntries("");
    RET_CHECK_GE(data_stream_count, 1);
    RET_CHECK_EQ(cc->Outputs().NumEntries(""), data_stream_count)
        << "Number of data output streams must match with data input streams.";
    // Each output stream adopts the packet type of its paired input stream.
    for (int i = 0; i < data_stream_count; ++i) {
      cc->Inputs().Get("", i).SetAny();
      cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i));
    }
    if (cc->Inputs().HasTag("ALLOW")) {
      cc->Inputs().Tag("ALLOW").Set<bool>();
    } else {
      cc->Inputs().Tag("DISALLOW").Set<bool>();
    }
    if (cc->Outputs().HasTag("STATE_CHANGE")) {
      cc->Outputs().Tag("STATE_CHANGE").Set<bool>();
    }
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) final {
    // Outputs are emitted at the same timestamp as their inputs.
    cc->SetOffset(TimestampDiff(0));
    num_data_streams_ = cc->Inputs().NumEntries("");
    last_gate_state_ = GATE_UNINITIALIZED;
    RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &cc->Outputs()));

    empty_packets_as_allow_ = cc->Options<::mediapipe::GateCalculatorOptions>()
                                  .empty_packets_as_allow();
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) final {
    // Decide whether this timestamp's packets pass the gate. An empty control
    // packet falls back to the configured default. GetContract() guarantees
    // that at most one of the two control tags is present.
    bool allow = empty_packets_as_allow_;
    if (cc->Inputs().HasTag("ALLOW") && !cc->Inputs().Tag("ALLOW").IsEmpty()) {
      allow = cc->Inputs().Tag("ALLOW").Get<bool>();
    } else if (cc->Inputs().HasTag("DISALLOW") &&
               !cc->Inputs().Tag("DISALLOW").IsEmpty()) {
      allow = !cc->Inputs().Tag("DISALLOW").Get<bool>();
    }
    const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW;

    if (cc->Outputs().HasTag("STATE_CHANGE")) {
      // Report only genuine transitions; the very first decision after
      // Open() is not a transition.
      const bool transitioned = last_gate_state_ != GATE_UNINITIALIZED &&
                                last_gate_state_ != new_gate_state;
      if (transitioned) {
        VLOG(2) << "State transition in " << cc->NodeName() << " @ "
                << cc->InputTimestamp().Value() << " from "
                << ToString(last_gate_state_) << " to "
                << ToString(new_gate_state);
        cc->Outputs()
            .Tag("STATE_CHANGE")
            .AddPacket(MakePacket<bool>(allow).At(cc->InputTimestamp()));
      }
    }
    last_gate_state_ = new_gate_state;

    if (!allow) {
      // Gate closed: drop every data packet at this timestamp.
      return ::mediapipe::OkStatus();
    }

    // Gate open: forward each non-empty data packet unchanged.
    for (int i = 0; i < num_data_streams_; ++i) {
      if (!cc->Inputs().Get("", i).IsEmpty()) {
        cc->Outputs().Get("", i).AddPacket(cc->Inputs().Get("", i).Value());
      }
    }
    return ::mediapipe::OkStatus();
  }

 private:
  GateState last_gate_state_ = GATE_UNINITIALIZED;  // For STATE_CHANGE detection.
  int num_data_streams_;                            // Set in Open().
  bool empty_packets_as_allow_;                     // From calculator options.
};
REGISTER_CALCULATOR(GateCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,30 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
// Options for GateCalculator, which passes or drops data packets based on an
// ALLOW or DISALLOW control stream.
message GateCalculatorOptions {
  extend mediapipe.CalculatorOptions {
    optional GateCalculatorOptions ext = 261754847;
  }

  // By default an empty packet in the ALLOW or DISALLOW input stream indicates
  // disallowing the corresponding packets in the data input streams. Setting
  // this option to true inverts that, allowing the data packets to go through.
  optional bool empty_packets_as_allow = 1;
}

View File

@ -0,0 +1,190 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
// Fixture that owns a CalculatorRunner and drives the GateCalculator one
// timestamp at a time.
class GateCalculatorTest : public ::testing::Test {
 protected:
  // Pushes one data packet (always `true`) and one control packet onto the
  // runner at `timestamp`, then executes the calculator once.
  void RunTimeStep(int64 timestamp, const std::string& control_tag,
                   bool control) {
    const Timestamp ts(timestamp);
    runner_->MutableInputs()->Get("", 0).packets.push_back(
        MakePacket<bool>(true).At(ts));
    runner_->MutableInputs()->Tag(control_tag).packets.push_back(
        MakePacket<bool>(control).At(ts));
    MEDIAPIPE_ASSERT_OK(runner_->Run()) << "Calculator execution failed.";
  }

  // Replaces the runner with one built from the given node-config text proto.
  void SetRunner(const std::string& proto) {
    runner_ = absl::make_unique<CalculatorRunner>(
        ParseTextProtoOrDie<CalculatorGraphConfig::Node>(proto));
  }

  CalculatorRunner* runner() { return runner_.get(); }

 private:
  std::unique_ptr<CalculatorRunner> runner_;
};
// With an ALLOW control stream, only timestamps gated `true` pass through.
TEST_F(GateCalculatorTest, Allow) {
  SetRunner(R"(
        calculator: "GateCalculator"
        input_stream: "test_input"
        input_stream: "ALLOW:gating_stream"
        output_stream: "test_output"
  )");

  const int64 kTimestamps[] = {42, 43, 44, 45};
  const bool kAllowSignal[] = {true, false, true, false};
  for (int i = 0; i < 4; ++i) {
    RunTimeStep(kTimestamps[i], "ALLOW", kAllowSignal[i]);
  }

  const std::vector<Packet>& out = runner()->Outputs().Get("", 0).packets;
  ASSERT_EQ(2, out.size());
  EXPECT_EQ(kTimestamps[0], out[0].Timestamp().Value());
  EXPECT_EQ(kTimestamps[2], out[1].Timestamp().Value());
  EXPECT_EQ(true, out[0].Get<bool>());
  EXPECT_EQ(true, out[1].Get<bool>());
}
// With a DISALLOW control stream, only timestamps gated `false` pass through.
TEST_F(GateCalculatorTest, Disallow) {
  SetRunner(R"(
        calculator: "GateCalculator"
        input_stream: "test_input"
        input_stream: "DISALLOW:gating_stream"
        output_stream: "test_output"
  )");

  const int64 kTimestamps[] = {42, 43, 44, 45};
  const bool kDisallowSignal[] = {true, false, true, false};
  for (int i = 0; i < 4; ++i) {
    RunTimeStep(kTimestamps[i], "DISALLOW", kDisallowSignal[i]);
  }

  const std::vector<Packet>& out = runner()->Outputs().Get("", 0).packets;
  ASSERT_EQ(2, out.size());
  EXPECT_EQ(kTimestamps[1], out[0].Timestamp().Value());
  EXPECT_EQ(kTimestamps[3], out[1].Timestamp().Value());
  EXPECT_EQ(true, out[0].Get<bool>());
  EXPECT_EQ(true, out[1].Get<bool>());
}
// STATE_CHANGE fires exactly when the ALLOW decision flips.
TEST_F(GateCalculatorTest, AllowWithStateChange) {
  SetRunner(R"(
        calculator: "GateCalculator"
        input_stream: "test_input"
        input_stream: "ALLOW:gating_stream"
        output_stream: "test_output"
        output_stream: "STATE_CHANGE:state_changed"
  )");

  // false -> true at t=43 and true -> false at t=45 are the two transitions.
  const int64 kTimestamps[] = {42, 43, 44, 45};
  const bool kAllowSignal[] = {false, true, true, false};
  for (int i = 0; i < 4; ++i) {
    RunTimeStep(kTimestamps[i], "ALLOW", kAllowSignal[i]);
  }

  const std::vector<Packet>& state_changes =
      runner()->Outputs().Get("STATE_CHANGE", 0).packets;
  ASSERT_EQ(2, state_changes.size());
  EXPECT_EQ(kTimestamps[1], state_changes[0].Timestamp().Value());
  EXPECT_EQ(kTimestamps[3], state_changes[1].Timestamp().Value());
  EXPECT_EQ(true, state_changes[0].Get<bool>());   // Allow.
  EXPECT_EQ(false, state_changes[1].Get<bool>());  // Disallow.
}
// STATE_CHANGE fires exactly when the DISALLOW decision flips.
TEST_F(GateCalculatorTest, DisallowWithStateChange) {
  SetRunner(R"(
        calculator: "GateCalculator"
        input_stream: "test_input"
        input_stream: "DISALLOW:gating_stream"
        output_stream: "test_output"
        output_stream: "STATE_CHANGE:state_changed"
  )");

  // DISALLOW true -> false at t=43 (gate opens) and false -> true at t=45
  // (gate closes) are the two transitions.
  const int64 kTimestamps[] = {42, 43, 44, 45};
  const bool kDisallowSignal[] = {true, false, false, true};
  for (int i = 0; i < 4; ++i) {
    RunTimeStep(kTimestamps[i], "DISALLOW", kDisallowSignal[i]);
  }

  const std::vector<Packet>& state_changes =
      runner()->Outputs().Get("STATE_CHANGE", 0).packets;
  ASSERT_EQ(2, state_changes.size());
  EXPECT_EQ(kTimestamps[1], state_changes[0].Timestamp().Value());
  EXPECT_EQ(kTimestamps[3], state_changes[1].Timestamp().Value());
  EXPECT_EQ(true, state_changes[0].Get<bool>());   // Allow.
  EXPECT_EQ(false, state_changes[1].Get<bool>());  // Disallow.
}
// Must not detect disallow value for first timestamp as a state change: the
// UNINITIALIZED -> ALLOW step after Open() is not a transition.
TEST_F(GateCalculatorTest, DisallowInitialNoStateTransition) {
  SetRunner(R"(
        calculator: "GateCalculator"
        input_stream: "test_input"
        input_stream: "DISALLOW:gating_stream"
        output_stream: "test_output"
        output_stream: "STATE_CHANGE:state_changed"
  )");

  RunTimeStep(42, "DISALLOW", false);

  const std::vector<Packet>& state_changes =
      runner()->Outputs().Get("STATE_CHANGE", 0).packets;
  ASSERT_EQ(0, state_changes.size());
}
// Must not detect allow value for first timestamp as a state change: the
// UNINITIALIZED -> ALLOW step after Open() is not a transition.
TEST_F(GateCalculatorTest, AllowInitialNoStateTransition) {
  SetRunner(R"(
        calculator: "GateCalculator"
        input_stream: "test_input"
        input_stream: "ALLOW:gating_stream"
        output_stream: "test_output"
        output_stream: "STATE_CHANGE:state_changed"
  )");

  RunTimeStep(42, "ALLOW", true);

  const std::vector<Packet>& state_changes =
      runner()->Outputs().Get("STATE_CHANGE", 0).packets;
  ASSERT_EQ(0, state_changes.size());
}
} // namespace
} // namespace mediapipe

View File

@ -146,7 +146,7 @@ class ImmediateMuxCalculatorTest : public ::testing::Test {
ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"(
input_stream: "input_packets_0"
node {
calculator: 'RealTimeFlowLimiterCalculator'
calculator: 'FlowLimiterCalculator'
input_stream_handler {
input_stream_handler: 'ImmediateInputStreamHandler'
}

View File

@ -0,0 +1,66 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "Eigen/Core"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// Perform a (left) matrix multiply. Meaning (output = A * input)
// where A is the matrix which is provided as an input side packet.
//
// Example config:
// node {
// calculator: "MatrixMultiplyCalculator"
// input_stream: "samples"
// output_stream: "multiplied_samples"
// input_side_packet: "multiplication_matrix"
// }
// Performs a left matrix multiply (output = A * input), where A is supplied
// once as an input side packet and each input-stream packet is multiplied by
// it. See the file-level comment for an example config.
class MatrixMultiplyCalculator : public CalculatorBase {
 public:
  MatrixMultiplyCalculator() {}
  ~MatrixMultiplyCalculator() override {}

  static ::mediapipe::Status GetContract(CalculatorContract* cc);
  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;
};
REGISTER_CALCULATOR(MatrixMultiplyCalculator);
// static
::mediapipe::Status MatrixMultiplyCalculator::GetContract(
    CalculatorContract* cc) {
  // Side packet 0 carries the multiplier A; stream 0 carries the samples.
  cc->InputSidePackets().Index(0).Set<Matrix>();
  cc->Inputs().Index(0).Set<Matrix>();
  cc->Outputs().Index(0).Set<Matrix>();
  return ::mediapipe::OkStatus();
}
::mediapipe::Status MatrixMultiplyCalculator::Open(CalculatorContext* cc) {
  // A zero offset declares that each output carries its input's timestamp,
  // letting the framework propagate timestamp bounds immediately.
  const TimestampDiff kNoOffset(0);
  cc->SetOffset(kNoOffset);
  return ::mediapipe::OkStatus();
}
::mediapipe::Status MatrixMultiplyCalculator::Process(CalculatorContext* cc) {
  // output = A * input; the output stream takes ownership of the allocation.
  const Matrix& multiplier = cc->InputSidePackets().Index(0).Get<Matrix>();
  const Matrix& sample = cc->Inputs().Index(0).Get<Matrix>();
  Matrix* product = new Matrix(multiplier * sample);
  cc->Outputs().Index(0).Add(product, cc->InputTimestamp());
  return ::mediapipe::OkStatus();
}
} // namespace mediapipe

View File

@ -0,0 +1,239 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <vector>
#include "Eigen/Core"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/validate_type.h"
namespace mediapipe {
namespace {
// A 3x4 Matrix of random integers in [0,1000). This is the multiplier A
// supplied to the calculator as an input side packet.
const char kMatrixText[] =
    "rows: 3\n"
    "cols: 4\n"
    "packed_data: 387\n"
    "packed_data: 940\n"
    "packed_data: 815\n"
    "packed_data: 825\n"
    "packed_data: 997\n"
    "packed_data: 884\n"
    "packed_data: 419\n"
    "packed_data: 763\n"
    "packed_data: 123\n"
    "packed_data: 30\n"
    "packed_data: 825\n"
    "packed_data: 299\n";

// A 4x20 Matrix of random integers in [0,10).
// Each column of this matrix is a sample (one packet on the input stream).
const char kSamplesText[] =
    "rows: 4\n"
    "cols: 20\n"
    "packed_data: 7\n"
    "packed_data: 9\n"
    "packed_data: 5\n"
    "packed_data: 9\n"
    "packed_data: 6\n"
    "packed_data: 3\n"
    "packed_data: 0\n"
    "packed_data: 7\n"
    "packed_data: 1\n"
    "packed_data: 3\n"
    "packed_data: 3\n"
    "packed_data: 2\n"
    "packed_data: 4\n"
    "packed_data: 5\n"
    "packed_data: 0\n"
    "packed_data: 4\n"
    "packed_data: 6\n"
    "packed_data: 0\n"
    "packed_data: 1\n"
    "packed_data: 2\n"
    "packed_data: 0\n"
    "packed_data: 2\n"
    "packed_data: 0\n"
    "packed_data: 3\n"
    "packed_data: 1\n"
    "packed_data: 7\n"
    "packed_data: 4\n"
    "packed_data: 9\n"
    "packed_data: 8\n"
    "packed_data: 8\n"
    "packed_data: 6\n"
    "packed_data: 4\n"
    "packed_data: 6\n"
    "packed_data: 8\n"
    "packed_data: 1\n"
    "packed_data: 9\n"
    "packed_data: 7\n"
    "packed_data: 5\n"
    "packed_data: 3\n"
    "packed_data: 5\n"
    "packed_data: 3\n"
    "packed_data: 5\n"
    "packed_data: 7\n"
    "packed_data: 7\n"
    "packed_data: 3\n"
    "packed_data: 3\n"
    "packed_data: 6\n"
    "packed_data: 4\n"
    "packed_data: 7\n"
    "packed_data: 7\n"
    "packed_data: 2\n"
    "packed_data: 5\n"
    "packed_data: 4\n"
    "packed_data: 8\n"
    "packed_data: 1\n"
    "packed_data: 0\n"
    "packed_data: 2\n"
    "packed_data: 0\n"
    "packed_data: 3\n"
    "packed_data: 4\n"
    "packed_data: 6\n"
    "packed_data: 6\n"
    "packed_data: 8\n"
    "packed_data: 5\n"
    "packed_data: 5\n"
    "packed_data: 8\n"
    "packed_data: 9\n"
    "packed_data: 7\n"
    "packed_data: 3\n"
    "packed_data: 7\n"
    "packed_data: 2\n"
    "packed_data: 7\n"
    "packed_data: 8\n"
    "packed_data: 2\n"
    "packed_data: 1\n"
    "packed_data: 1\n"
    "packed_data: 4\n"
    "packed_data: 1\n"
    "packed_data: 1\n"
    "packed_data: 7\n";

// A 3x20 Matrix of expected values for the result of the matrix multiply
// computed using R.
// Each column of this matrix is an expected output (A * sample column).
const char kExpectedText[] =
    "rows: 3\n"
    "cols: 20\n"
    "packed_data: 12499\n"
    "packed_data: 26793\n"
    "packed_data: 16967\n"
    "packed_data: 5007\n"
    "packed_data: 14406\n"
    "packed_data: 9635\n"
    "packed_data: 4179\n"
    "packed_data: 7870\n"
    "packed_data: 4434\n"
    "packed_data: 5793\n"
    "packed_data: 12045\n"
    "packed_data: 8876\n"
    "packed_data: 2801\n"
    "packed_data: 8053\n"
    "packed_data: 5611\n"
    "packed_data: 1740\n"
    "packed_data: 4469\n"
    "packed_data: 2665\n"
    "packed_data: 8108\n"
    "packed_data: 18396\n"
    "packed_data: 10186\n"
    "packed_data: 12330\n"
    "packed_data: 23374\n"
    "packed_data: 15526\n"
    "packed_data: 9611\n"
    "packed_data: 21804\n"
    "packed_data: 14776\n"
    "packed_data: 8241\n"
    "packed_data: 17979\n"
    "packed_data: 11989\n"
    "packed_data: 8429\n"
    "packed_data: 18921\n"
    "packed_data: 9819\n"
    "packed_data: 6270\n"
    "packed_data: 13689\n"
    "packed_data: 7031\n"
    "packed_data: 9472\n"
    "packed_data: 19210\n"
    "packed_data: 13634\n"
    "packed_data: 8567\n"
    "packed_data: 12499\n"
    "packed_data: 10455\n"
    "packed_data: 2151\n"
    "packed_data: 7469\n"
    "packed_data: 3195\n"
    "packed_data: 10774\n"
    "packed_data: 21851\n"
    "packed_data: 12673\n"
    "packed_data: 12516\n"
    "packed_data: 25318\n"
    "packed_data: 14347\n"
    "packed_data: 7984\n"
    "packed_data: 17100\n"
    "packed_data: 10972\n"
    "packed_data: 5195\n"
    "packed_data: 11102\n"
    "packed_data: 8710\n"
    "packed_data: 3002\n"
    "packed_data: 11295\n"
    "packed_data: 6360\n";
// Send a number of samples through the MatrixMultiplyCalculator: one packet
// per column of kSamplesText, verified against the matching column of
// kExpectedText.
TEST(MatrixMultiplyCalculatorTest, Multiply) {
  CalculatorRunner runner("MatrixMultiplyCalculator", "", 1, 1, 1);

  // The multiplier A is adopted by the side packet.
  Matrix* multiplier = new Matrix();
  MatrixFromTextProto(kMatrixText, multiplier);
  runner.MutableSidePackets()->Index(0) = Adopt(multiplier);

  Matrix samples;
  MatrixFromTextProto(kSamplesText, &samples);
  Matrix expected;
  MatrixFromTextProto(kExpectedText, &expected);
  CHECK_EQ(samples.cols(), expected.cols());

  // One input packet per sample column, timestamped by column index.
  for (int col = 0; col < samples.cols(); ++col) {
    Eigen::MatrixXf* sample = new Eigen::MatrixXf(samples.block(0, col, 4, 1));
    runner.MutableInputs()->Index(0).packets.push_back(
        Adopt(sample).At(Timestamp(col)));
  }
  MEDIAPIPE_ASSERT_OK(runner.Run());
  EXPECT_EQ(runner.MutableInputs()->Index(0).packets.size(),
            runner.Outputs().Index(0).packets.size());

  int col = 0;
  for (const Packet& output : runner.Outputs().Index(0).packets) {
    EXPECT_EQ(Timestamp(col), output.Timestamp());
    const Eigen::MatrixXf& result = output.Get<Matrix>();
    ASSERT_EQ(3, result.rows());
    EXPECT_NEAR((expected.block(0, col, 3, 1) - result).cwiseAbs().sum(), 0.0,
                1e-5);
    ++col;
  }
  EXPECT_EQ(samples.cols(), col);
}
} // namespace
} // namespace mediapipe

View File

@ -0,0 +1,123 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "Eigen/Core"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// Subtract input matrix from the side input matrix and vice versa. The matrices
// must have the same dimension.
// Based on the tag (MINUEND vs SUBTRAHEND), the matrices in the input stream
// and input side packet can be either subtrahend or minuend. The output matrix
// is generated by performing minuend matrix - subtrahend matrix.
//
// Example config:
// node {
// calculator: "MatrixSubtractCalculator"
// input_stream: "MINUEND:input_matrix"
// input_side_packet: "SUBTRAHEND:side_matrix"
// output_stream: "output_matrix"
// }
//
// or
//
// node {
// calculator: "MatrixSubtractCalculator"
// input_stream: "SUBTRAHEND:input_matrix"
// input_side_packet: "MINUEND:side_matrix"
// output_stream: "output_matrix"
// }
// Computes minuend - subtrahend, where one operand arrives on the input
// stream and the other as an input side packet, chosen by the MINUEND /
// SUBTRAHEND tags. See the file-level comment for example configs.
class MatrixSubtractCalculator : public CalculatorBase {
 public:
  MatrixSubtractCalculator() {}
  ~MatrixSubtractCalculator() override {}

  static ::mediapipe::Status GetContract(CalculatorContract* cc);
  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;

 private:
  // True when the input stream carries the minuend (so the side packet is
  // the subtrahend); set once in Open().
  bool subtract_from_input_ = false;
};
REGISTER_CALCULATOR(MatrixSubtractCalculator);
// static
::mediapipe::Status MatrixSubtractCalculator::GetContract(
    CalculatorContract* cc) {
  const bool single_operand_each = cc->Inputs().NumEntries() == 1 &&
                                   cc->InputSidePackets().NumEntries() == 1;
  if (!single_operand_each) {
    return ::mediapipe::InvalidArgumentError(
        "MatrixSubtractCalculator only accepts exactly one input stream and "
        "one input side packet");
  }
  // The MINUEND/SUBTRAHEND tags must be split one-per-source, in either
  // arrangement.
  const bool minuend_on_stream = cc->Inputs().HasTag("MINUEND") &&
                                 cc->InputSidePackets().HasTag("SUBTRAHEND");
  const bool subtrahend_on_stream = cc->Inputs().HasTag("SUBTRAHEND") &&
                                    cc->InputSidePackets().HasTag("MINUEND");
  if (minuend_on_stream) {
    cc->Inputs().Tag("MINUEND").Set<Matrix>();
    cc->InputSidePackets().Tag("SUBTRAHEND").Set<Matrix>();
  } else if (subtrahend_on_stream) {
    cc->Inputs().Tag("SUBTRAHEND").Set<Matrix>();
    cc->InputSidePackets().Tag("MINUEND").Set<Matrix>();
  } else {
    return ::mediapipe::InvalidArgumentError(
        "Must specify exactly one minuend and one subtrahend.");
  }
  cc->Outputs().Index(0).Set<Matrix>();
  return ::mediapipe::OkStatus();
}
::mediapipe::Status MatrixSubtractCalculator::Open(CalculatorContext* cc) {
  // The output is at the same timestamp as the input.
  cc->SetOffset(TimestampDiff(0));
  // The MINUEND tag on the input stream means we subtract the side packet
  // from the streamed matrix; otherwise the roles are reversed.
  subtract_from_input_ = cc->Inputs().HasTag("MINUEND");
  return ::mediapipe::OkStatus();
}
// Emits minuend - subtrahend at the input timestamp.
//
// The original implementation duplicated the dimension check and subtraction
// in both branches; here the operands are resolved once and the arithmetic is
// written a single time. Returns InvalidArgumentError when the two matrices
// do not share the same dimensions.
::mediapipe::Status MatrixSubtractCalculator::Process(CalculatorContext* cc) {
  // Resolve which operand arrives on the input stream vs. the side packet;
  // the subtraction itself is always minuend - subtrahend.
  const Matrix& minuend =
      subtract_from_input_
          ? cc->Inputs().Tag("MINUEND").Get<Matrix>()
          : cc->InputSidePackets().Tag("MINUEND").Get<Matrix>();
  const Matrix& subtrahend =
      subtract_from_input_
          ? cc->InputSidePackets().Tag("SUBTRAHEND").Get<Matrix>()
          : cc->Inputs().Tag("SUBTRAHEND").Get<Matrix>();
  if (minuend.rows() != subtrahend.rows() ||
      minuend.cols() != subtrahend.cols()) {
    return ::mediapipe::InvalidArgumentError(
        "Input matrix and the input side matrix must have the same "
        "dimension.");
  }
  // Construct the result directly (no default-construct-then-assign); the
  // output stream takes ownership of the allocation.
  Matrix* subtracted = new Matrix(minuend - subtrahend);
  cc->Outputs().Index(0).Add(subtracted, cc->InputTimestamp());
  return ::mediapipe::OkStatus();
}
} // namespace mediapipe

View File

@ -0,0 +1,157 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <vector>
#include "Eigen/Core"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/validate_type.h"
namespace mediapipe {
namespace {
// A 3x4 Matrix of random integers in [0,1000).
const char kMatrixText[] =
    "rows: 3\n"
    "cols: 4\n"
    "packed_data: 387\n"
    "packed_data: 940\n"
    "packed_data: 815\n"
    "packed_data: 825\n"
    "packed_data: 997\n"
    "packed_data: 884\n"
    "packed_data: 419\n"
    "packed_data: 763\n"
    "packed_data: 123\n"
    "packed_data: 30\n"
    "packed_data: 825\n"
    "packed_data: 299\n";

// Same shape as kMatrixText with every entry incremented by 1, so the
// element-wise difference between the two matrices is exactly +/-1.
const char kMatrixText2[] =
    "rows: 3\n"
    "cols: 4\n"
    "packed_data: 388\n"
    "packed_data: 941\n"
    "packed_data: 816\n"
    "packed_data: 826\n"
    "packed_data: 998\n"
    "packed_data: 885\n"
    "packed_data: 420\n"
    "packed_data: 764\n"
    "packed_data: 124\n"
    "packed_data: 31\n"
    "packed_data: 826\n"
    "packed_data: 300\n";
// Two side packets plus an untagged input stream must be rejected.
TEST(MatrixSubtractCalculatorTest, WrongConfig) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "MatrixSubtractCalculator"
        input_stream: "input_matrix"
        input_side_packet: "SUBTRAHEND:side_matrix"
        input_side_packet: "MINUEND:side_matrix2"
        output_stream: "output_matrix"
      )");
  CalculatorRunner runner(node_config);
  const auto status = runner.Run();
  EXPECT_THAT(
      status.message(),
      testing::HasSubstr(
          "only accepts exactly one input stream and one input side packet"));
}
// SUBTRAHEND on both sources (no MINUEND anywhere) must be rejected.
TEST(MatrixSubtractCalculatorTest, WrongConfig2) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "MatrixSubtractCalculator"
        input_side_packet: "SUBTRAHEND:side_matrix"
        input_stream: "SUBTRAHEND:side_matrix2"
        output_stream: "output_matrix"
      )");
  CalculatorRunner runner(node_config);
  const auto status = runner.Run();
  EXPECT_THAT(
      status.message(),
      testing::HasSubstr("specify exactly one minuend and one subtrahend."));
}
// Streamed MINUEND minus side-packet SUBTRAHEND: kMatrixText2 - kMatrixText.
// Every entry differs by +1, so the 3x4 result sums to 12.
//
// Fix: the packet count is now checked with ASSERT_EQ rather than EXPECT_EQ
// before indexing packets[0] — with EXPECT_EQ an empty output would crash the
// test binary instead of failing cleanly.
TEST(MatrixSubtractCalculatorTest, SubtractFromInput) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "MatrixSubtractCalculator"
        input_stream: "MINUEND:input_matrix"
        input_side_packet: "SUBTRAHEND:side_matrix"
        output_stream: "output_matrix"
      )");
  CalculatorRunner runner(node_config);
  Matrix* side_matrix = new Matrix();
  MatrixFromTextProto(kMatrixText, side_matrix);
  runner.MutableSidePackets()->Tag("SUBTRAHEND") = Adopt(side_matrix);

  Matrix* input_matrix = new Matrix();
  MatrixFromTextProto(kMatrixText2, input_matrix);
  runner.MutableInputs()->Tag("MINUEND").packets.push_back(
      Adopt(input_matrix).At(Timestamp(0)));

  MEDIAPIPE_ASSERT_OK(runner.Run());
  ASSERT_EQ(1, runner.Outputs().Index(0).packets.size());
  EXPECT_EQ(Timestamp(0), runner.Outputs().Index(0).packets[0].Timestamp());
  const Eigen::MatrixXf& result =
      runner.Outputs().Index(0).packets[0].Get<Matrix>();
  ASSERT_EQ(3, result.rows());
  ASSERT_EQ(4, result.cols());
  EXPECT_NEAR(result.sum(), 12, 1e-5);
}
// Side-packet MINUEND minus streamed SUBTRAHEND: kMatrixText - kMatrixText2.
// Every entry differs by -1, so the 3x4 result sums to -12.
//
// Fix: the packet count is now checked with ASSERT_EQ rather than EXPECT_EQ
// before indexing packets[0] — with EXPECT_EQ an empty output would crash the
// test binary instead of failing cleanly.
TEST(MatrixSubtractCalculatorTest, SubtractFromSideMatrix) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "MatrixSubtractCalculator"
        input_stream: "SUBTRAHEND:input_matrix"
        input_side_packet: "MINUEND:side_matrix"
        output_stream: "output_matrix"
      )");
  CalculatorRunner runner(node_config);
  Matrix* side_matrix = new Matrix();
  MatrixFromTextProto(kMatrixText, side_matrix);
  runner.MutableSidePackets()->Tag("MINUEND") = Adopt(side_matrix);

  Matrix* input_matrix = new Matrix();
  MatrixFromTextProto(kMatrixText2, input_matrix);
  runner.MutableInputs()
      ->Tag("SUBTRAHEND")
      .packets.push_back(Adopt(input_matrix).At(Timestamp(0)));

  MEDIAPIPE_ASSERT_OK(runner.Run());
  ASSERT_EQ(1, runner.Outputs().Index(0).packets.size());
  EXPECT_EQ(Timestamp(0), runner.Outputs().Index(0).packets[0].Timestamp());
  const Eigen::MatrixXf& result =
      runner.Outputs().Index(0).packets[0].Get<Matrix>();
  ASSERT_EQ(3, result.rows());
  ASSERT_EQ(4, result.cols());
  EXPECT_NEAR(result.sum(), -12, 1e-5);
}
} // namespace
} // namespace mediapipe

View File

@ -0,0 +1,91 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// This calculator takes a set of input streams and combines them into a single
// output stream. The packets from different streams do not need to contain the
// same type. If there are packets arriving at the same time from two or more
// input streams, the packet corresponding to the input stream with the smallest
// index is passed to the output and the rest are ignored.
//
// Example use-case:
// Suppose we have two (or more) different algorithms for detecting shot
// boundaries and we need to merge their packets into a single stream. The
// algorithms may emit shot boundaries at the same time and their output types
// may not be compatible. Subsequent calculators that process the merged stream
// may be interested only in the timestamps of the shot boundary packets and so
// it may not even need to inspect the values stored inside the packets.
//
// Example config:
// node {
// calculator: "MergeCalculator"
// input_stream: "shot_info1"
// input_stream: "shot_info2"
// input_stream: "shot_info3"
// output_stream: "merged_shot_infos"
// }
//
class MergeCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    const int num_inputs = cc->Inputs().NumEntries();
    RET_CHECK_GT(num_inputs, 0) << "Needs at least one input stream";
    RET_CHECK_EQ(cc->Outputs().NumEntries(), 1);
    if (num_inputs == 1) {
      LOG(WARNING)
          << "MergeCalculator expects multiple input streams to merge but is "
             "receiving only one. Make sure the calculator is configured "
             "correctly or consider removing this calculator to reduce "
             "unnecessary overhead.";
    }
    // Any packet type is accepted on every input; the output is untyped too.
    for (int id = 0; id < num_inputs; ++id) {
      cc->Inputs().Index(id).SetAny();
    }
    cc->Outputs().Index(0).SetAny();
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) final {
    // Output timestamps always equal input timestamps.
    cc->SetOffset(TimestampDiff(0));
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) final {
    // Emit the packet from the lowest-indexed stream that has one at this
    // timestamp; packets arriving simultaneously on higher-indexed streams
    // are dropped.
    const int num_inputs = cc->Inputs().NumEntries();
    for (int id = 0; id < num_inputs; ++id) {
      const auto& input = cc->Inputs().Index(id);
      if (!input.IsEmpty()) {
        cc->Outputs().Index(0).AddPacket(input.Value());
        return ::mediapipe::OkStatus();
      }
    }
    LOG(WARNING) << "Empty input packets at timestamp "
                 << cc->InputTimestamp().Value();
    return ::mediapipe::OkStatus();
  }
};
REGISTER_CALCULATOR(MergeCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,139 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
// Checks that the calculator fails if no input streams are provided.
// A MergeCalculator node with zero input streams must fail validation.
TEST(InvariantMergeInputStreamsCalculator, NoInputStreamsMustFail) {
  CalculatorRunner no_input_runner(
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "MergeCalculator"
        output_stream: "merged_output"
      )"));
  ASSERT_FALSE(no_input_runner.Run().ok());
}
// Checks that the calculator fails with an incorrect number of output streams.
// Anything other than exactly one output stream must fail validation.
TEST(InvariantMergeInputStreamsCalculator, ExpectExactlyOneOutputStream) {
  // Zero output streams: must fail.
  CalculatorRunner no_output_runner(
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "MergeCalculator"
        input_stream: "input1"
        input_stream: "input2"
      )"));
  EXPECT_FALSE(no_output_runner.Run().ok());
  // Two output streams: must also fail.
  CalculatorRunner two_output_runner(
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "MergeCalculator"
        input_stream: "input1"
        input_stream: "input2"
        output_stream: "output1"
        output_stream: "output2"
      )"));
  ASSERT_FALSE(two_output_runner.Run().ok());
}
// Ensures two streams with differing types can be merged correctly.
// Ensures two streams with differing types can be merged correctly.
// NOTE(review): suite renamed from the copy-pasted, unrelated
// "MediaPipeDetectionToSoapboxDetectionCalculatorTest" so the test is grouped
// with the calculator it actually exercises.
TEST(MergeCalculatorTest, TestMergingTwoStreams) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
    calculator: "MergeCalculator"
    input_stream: "input1"
    input_stream: "input2"
    output_stream: "combined_output"
  )"));
  // input1: integers 10, 20, 30, occurring at times 10, 20, 30.
  runner.MutableInputs()->Index(0).packets.push_back(
      Adopt(new int(10)).At(Timestamp(10)));
  runner.MutableInputs()->Index(0).packets.push_back(
      Adopt(new int(20)).At(Timestamp(20)));
  runner.MutableInputs()->Index(0).packets.push_back(
      Adopt(new int(30)).At(Timestamp(30)));
  // input2: floats 5.5, 35.5 at times 5, 35.
  runner.MutableInputs()->Index(1).packets.push_back(
      Adopt(new float(5.5)).At(Timestamp(5)));
  runner.MutableInputs()->Index(1).packets.push_back(
      Adopt(new float(35.5)).At(Timestamp(35)));
  MEDIAPIPE_ASSERT_OK(runner.Run());
  // Expected combined_output: 5.5, 10, 20, 30, 35.5 at times 5, 10, 20, 30, 35.
  const std::vector<Packet>& actual_output = runner.Outputs().Index(0).packets;
  ASSERT_EQ(actual_output.size(), 5);
  EXPECT_EQ(actual_output[0].Timestamp(), Timestamp(5));
  EXPECT_EQ(actual_output[0].Get<float>(), 5.5);
  EXPECT_EQ(actual_output[1].Timestamp(), Timestamp(10));
  EXPECT_EQ(actual_output[1].Get<int>(), 10);
  EXPECT_EQ(actual_output[2].Timestamp(), Timestamp(20));
  EXPECT_EQ(actual_output[2].Get<int>(), 20);
  EXPECT_EQ(actual_output[3].Timestamp(), Timestamp(30));
  EXPECT_EQ(actual_output[3].Get<int>(), 30);
  EXPECT_EQ(actual_output[4].Timestamp(), Timestamp(35));
  EXPECT_EQ(actual_output[4].Get<float>(), 35.5);
}
// Ensures three streams with differing types can be merged correctly.
// Ensures three streams with differing types can be merged correctly.
// NOTE(review): suite renamed from the copy-pasted, unrelated
// "MediaPipeDetectionToSoapboxDetectionCalculatorTest" so the test is grouped
// with the calculator it actually exercises.
TEST(MergeCalculatorTest, TestMergingThreeStreams) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
    calculator: "MergeCalculator"
    input_stream: "input1"
    input_stream: "input2"
    input_stream: "input3"
    output_stream: "combined_output"
  )"));
  // input1: integer 30 occurring at time 30.
  runner.MutableInputs()->Index(0).packets.push_back(
      Adopt(new int(30)).At(Timestamp(30)));
  // input2: float 20.5 occurring at time 20.
  runner.MutableInputs()->Index(1).packets.push_back(
      Adopt(new float(20.5)).At(Timestamp(20)));
  // input3: char 'c' occurring at time 10.
  runner.MutableInputs()->Index(2).packets.push_back(
      Adopt(new char('c')).At(Timestamp(10)));
  MEDIAPIPE_ASSERT_OK(runner.Run());
  // Expected combined_output: 'c', 20.5, 30 at times 10, 20, 30.
  const std::vector<Packet>& actual_output = runner.Outputs().Index(0).packets;
  ASSERT_EQ(actual_output.size(), 3);
  EXPECT_EQ(actual_output[0].Timestamp(), Timestamp(10));
  EXPECT_EQ(actual_output[0].Get<char>(), 'c');
  EXPECT_EQ(actual_output[1].Timestamp(), Timestamp(20));
  EXPECT_EQ(actual_output[1].Get<float>(), 20.5);
  EXPECT_EQ(actual_output[2].Timestamp(), Timestamp(30));
  EXPECT_EQ(actual_output[2].Get<int>(), 30);
}
} // namespace
} // namespace mediapipe

View File

@ -51,7 +51,7 @@ class MuxCalculator : public CalculatorBase {
data_input_base_ = cc->Inputs().GetId("INPUT", 0);
num_data_inputs_ = cc->Inputs().NumEntries("INPUT");
output_ = cc->Outputs().GetId("OUTPUT", 0);
cc->SetOffset(mediapipe::TimestampDiff(0));
cc->SetOffset(TimestampDiff(0));
return ::mediapipe::OkStatus();
}

View File

@ -19,6 +19,7 @@
#include <vector>
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/core/packet_cloner_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
namespace mediapipe {
@ -39,6 +40,7 @@ namespace mediapipe {
// }
//
// Related:
// packet_cloner_calculator.proto: Options for this calculator.
// merge_input_streams_calculator.cc: One output stream.
// packet_inner_join_calculator.cc: Don't output unless all inputs are new.
class PacketClonerCalculator : public CalculatorBase {
@ -54,6 +56,13 @@ class PacketClonerCalculator : public CalculatorBase {
}
::mediapipe::Status Open(CalculatorContext* cc) final {
// Load options.
const auto calculator_options =
cc->Options<mediapipe::PacketClonerCalculatorOptions>();
output_only_when_all_inputs_received_ =
calculator_options.output_only_when_all_inputs_received();
// Parse input streams.
tick_signal_index_ = cc->Inputs().NumEntries() - 1;
current_.resize(tick_signal_index_);
// Pass along the header for each stream if present.
@ -73,8 +82,17 @@ class PacketClonerCalculator : public CalculatorBase {
}
}
// Output if the tick signal is non-empty.
// Output according to the TICK signal.
if (!cc->Inputs().Index(tick_signal_index_).Value().IsEmpty()) {
if (output_only_when_all_inputs_received_) {
// Return if one of the input is null.
for (int i = 0; i < tick_signal_index_; ++i) {
if (current_[i].IsEmpty()) {
return ::mediapipe::OkStatus();
}
}
}
// Output each stream.
for (int i = 0; i < tick_signal_index_; ++i) {
if (!current_[i].IsEmpty()) {
cc->Outputs().Index(i).AddPacket(
@ -91,6 +109,7 @@ class PacketClonerCalculator : public CalculatorBase {
private:
std::vector<Packet> current_;
int tick_signal_index_;
bool output_only_when_all_inputs_received_;
};
REGISTER_CALCULATOR(PacketClonerCalculator);

View File

@ -0,0 +1,29 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
// Options for PacketClonerCalculator.
message PacketClonerCalculatorOptions {
  extend CalculatorOptions {
    optional PacketClonerCalculatorOptions ext = 258872085;
  }

  // When true, this calculator will drop received TICK packets if any input
  // stream hasn't received a packet yet.
  optional bool output_only_when_all_inputs_received = 1 [default = false];
}

View File

@ -0,0 +1,78 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
namespace mediapipe {
// Calculator that acts like the SQL query:
// SELECT *
// FROM packets_on_stream1 AS packet1
// INNER JOIN packets_on_stream2 AS packet2
// ON packet1.timestamp = packet2.timestamp
//
// In other words, it only emits and forwards packets if all input streams are
// not empty.
//
// Intended for use with FixedSizeInputStreamHandler.
//
// Related:
// packet_cloner_calculator.cc: Repeats last-seen packets from empty inputs.
class PacketInnerJoinCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc);
  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;

 private:
  // Number of input (and output) streams; cached from the contract in Open().
  int num_streams_;
};
REGISTER_CALCULATOR(PacketInnerJoinCalculator);
::mediapipe::Status PacketInnerJoinCalculator::GetContract(
    CalculatorContract* cc) {
  // Streams are joined pairwise, so the counts must line up exactly.
  RET_CHECK(cc->Inputs().NumEntries() == cc->Outputs().NumEntries())
      << "The number of input and output streams must match.";
  // Each output simply forwards its input, whatever the type.
  for (int stream = 0; stream < cc->Inputs().NumEntries(); ++stream) {
    cc->Inputs().Index(stream).SetAny();
    cc->Outputs().Index(stream).SetSameAs(&cc->Inputs().Index(stream));
  }
  return ::mediapipe::OkStatus();
}
::mediapipe::Status PacketInnerJoinCalculator::Open(CalculatorContext* cc) {
  // Cache the stream count and declare that output timestamps equal input
  // timestamps.
  num_streams_ = cc->Inputs().NumEntries();
  cc->SetOffset(TimestampDiff(0));
  // Fully qualify OkStatus for consistency with the rest of this file.
  return ::mediapipe::OkStatus();
}
::mediapipe::Status PacketInnerJoinCalculator::Process(CalculatorContext* cc) {
  // Emit nothing unless every input stream has a packet at this timestamp
  // (the "inner join" condition).
  bool all_streams_present = true;
  for (int stream = 0; stream < num_streams_; ++stream) {
    if (cc->Inputs().Index(stream).Value().IsEmpty()) {
      all_streams_present = false;
      break;
    }
  }
  if (all_streams_present) {
    for (int stream = 0; stream < num_streams_; ++stream) {
      cc->Outputs().Index(stream).AddPacket(cc->Inputs().Index(stream).Value());
    }
  }
  return ::mediapipe::OkStatus();
}
} // namespace mediapipe

View File

@ -0,0 +1,101 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/validate_type.h"
namespace mediapipe {
namespace {
// Builds an int packet whose payload and timestamp are both |i|.
inline Packet PacketFrom(int i) {
  return Adopt(new int(i)).At(Timestamp(i));
}
// When every timestamp appears on both streams, every packet passes through.
TEST(PacketInnerJoinCalculatorTest, AllMatching) {
  // Test case.
  const std::vector<int> packets_on_stream1 = {0, 1, 2, 3};
  const std::vector<int> packets_on_stream2 = {0, 1, 2, 3};
  // Run.
  CalculatorRunner runner("PacketInnerJoinCalculator", "", 2, 2, 0);
  for (int packet_load : packets_on_stream1) {
    runner.MutableInputs()->Index(0).packets.push_back(PacketFrom(packet_load));
  }
  for (int packet_load : packets_on_stream2) {
    runner.MutableInputs()->Index(1).packets.push_back(PacketFrom(packet_load));
  }
  MEDIAPIPE_ASSERT_OK(runner.Run());
  // Check.
  const std::vector<int> expected = {0, 1, 2, 3};
  ASSERT_EQ(expected.size(), runner.Outputs().Index(0).packets.size());
  ASSERT_EQ(expected.size(), runner.Outputs().Index(1).packets.size());
  // size_t index avoids the signed/unsigned comparison against .size().
  for (size_t i = 0; i < expected.size(); ++i) {
    const Packet packet1 = runner.Outputs().Index(0).packets[i];
    EXPECT_EQ(expected[i], packet1.Get<int>());
    EXPECT_EQ(expected[i], packet1.Timestamp().Value());
    const Packet packet2 = runner.Outputs().Index(1).packets[i];
    EXPECT_EQ(expected[i], packet2.Get<int>());
    EXPECT_EQ(expected[i], packet2.Timestamp().Value());
  }
}
// Disjoint timestamp sets mean the join is empty: nothing is emitted.
TEST(PacketInnerJoinCalculatorTest, NoneMatching) {
  // Test case.
  const std::vector<int> stream1_values = {0, 2};
  const std::vector<int> stream2_values = {1, 3};
  // Run.
  CalculatorRunner runner("PacketInnerJoinCalculator", "", 2, 2, 0);
  for (int value : stream1_values) {
    runner.MutableInputs()->Index(0).packets.push_back(PacketFrom(value));
  }
  for (int value : stream2_values) {
    runner.MutableInputs()->Index(1).packets.push_back(PacketFrom(value));
  }
  MEDIAPIPE_ASSERT_OK(runner.Run());
  // Check: no timestamp matched, so both outputs stay empty.
  EXPECT_TRUE(runner.Outputs().Index(0).packets.empty());
  EXPECT_TRUE(runner.Outputs().Index(1).packets.empty());
}
// Only timestamps present on BOTH streams survive the join.
TEST(PacketInnerJoinCalculatorTest, SomeMatching) {
  // Test case.
  const std::vector<int> packets_on_stream1 = {0, 1, 2, 3, 4, 6};
  const std::vector<int> packets_on_stream2 = {0, 2, 4, 5, 6};
  // Run.
  CalculatorRunner runner("PacketInnerJoinCalculator", "", 2, 2, 0);
  for (int packet_load : packets_on_stream1) {
    runner.MutableInputs()->Index(0).packets.push_back(PacketFrom(packet_load));
  }
  for (int packet_load : packets_on_stream2) {
    runner.MutableInputs()->Index(1).packets.push_back(PacketFrom(packet_load));
  }
  MEDIAPIPE_ASSERT_OK(runner.Run());
  // Check: the intersection of the two timestamp sets.
  const std::vector<int> expected = {0, 2, 4, 6};
  ASSERT_EQ(expected.size(), runner.Outputs().Index(0).packets.size());
  ASSERT_EQ(expected.size(), runner.Outputs().Index(1).packets.size());
  // size_t index avoids the signed/unsigned comparison against .size().
  for (size_t i = 0; i < expected.size(); ++i) {
    const Packet packet1 = runner.Outputs().Index(0).packets[i];
    EXPECT_EQ(expected[i], packet1.Get<int>());
    EXPECT_EQ(expected[i], packet1.Timestamp().Value());
    const Packet packet2 = runner.Outputs().Index(1).packets[i];
    EXPECT_EQ(expected[i], packet2.Get<int>());
    EXPECT_EQ(expected[i], packet2.Timestamp().Value());
  }
}
} // namespace
} // namespace mediapipe

View File

@ -0,0 +1,102 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cfloat>
#include <memory>
#include <string>
#include <vector>
#include "mediapipe/calculators/core/quantize_float_vector_calculator.pb.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/status.h"
// Quantizes a vector of floats to a std::string so that each float becomes a
// byte in the [0, 255] range. Any value above max_quantized_value or below
// min_quantized_value will be saturated to '/xFF' or '/0'.
//
// Example config:
// node {
// calculator: "QuantizeFloatVectorCalculator"
// input_stream: "FLOAT_VECTOR:float_vector"
// output_stream: "ENCODED:encoded"
// options {
// [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
// max_quantized_value: 64
// min_quantized_value: -64
// }
// }
// }
namespace mediapipe {
class QuantizeFloatVectorCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Tag("FLOAT_VECTOR").Set<std::vector<float>>();
    cc->Outputs().Tag("ENCODED").Set<std::string>();
    return ::mediapipe::OkStatus();
  }

  // Validates the quantization range from the options. Both bounds are
  // required and max must be strictly greater than min (within FLT_EPSILON).
  ::mediapipe::Status Open(CalculatorContext* cc) final {
    // Bind by const reference: copying the options proto here is wasted work.
    const auto& options =
        cc->Options<::mediapipe::QuantizeFloatVectorCalculatorOptions>();
    if (!options.has_max_quantized_value() ||
        !options.has_min_quantized_value()) {
      return ::mediapipe::InvalidArgumentError(
          "Both max_quantized_value and min_quantized_value must be provided "
          "in QuantizeFloatVectorCalculatorOptions.");
    }
    max_quantized_value_ = options.max_quantized_value();
    min_quantized_value_ = options.min_quantized_value();
    // Rejects inverted or degenerate (equal-within-epsilon) ranges.
    if (max_quantized_value_ < min_quantized_value_ + FLT_EPSILON) {
      return ::mediapipe::InvalidArgumentError(
          "max_quantized_value must be greater than min_quantized_value.");
    }
    range_ = max_quantized_value_ - min_quantized_value_;
    return ::mediapipe::OkStatus();
  }

  // Maps each float into a byte: values are clamped to
  // [min_quantized_value_, max_quantized_value_], then scaled linearly onto
  // [0, 255].
  ::mediapipe::Status Process(CalculatorContext* cc) final {
    const std::vector<float>& float_vector =
        cc->Inputs().Tag("FLOAT_VECTOR").Value().Get<std::vector<float>>();
    int feature_size = float_vector.size();
    std::string encoded_features;
    encoded_features.reserve(feature_size);
    for (int i = 0; i < feature_size; i++) {
      // Saturate out-of-range values before scaling.
      float old_value = float_vector[i];
      if (old_value < min_quantized_value_) {
        old_value = min_quantized_value_;
      }
      if (old_value > max_quantized_value_) {
        old_value = max_quantized_value_;
      }
      unsigned char encoded = static_cast<unsigned char>(
          (old_value - min_quantized_value_) * (255.0 / range_));
      encoded_features += encoded;
    }
    cc->Outputs().Tag("ENCODED").AddPacket(
        MakePacket<std::string>(encoded_features).At(cc->InputTimestamp()));
    return ::mediapipe::OkStatus();
  }

 private:
  float max_quantized_value_;
  float min_quantized_value_;
  float range_;  // max_quantized_value_ - min_quantized_value_, always > 0.
};
REGISTER_CALCULATOR(QuantizeFloatVectorCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,28 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
// Options for QuantizeFloatVectorCalculator.
message QuantizeFloatVectorCalculatorOptions {
  extend CalculatorOptions {
    optional QuantizeFloatVectorCalculatorOptions ext = 259848061;
  }

  // Upper bound of the quantization range. Required. Input values above it
  // saturate to 255 (0xFF).
  optional float max_quantized_value = 1;
  // Lower bound of the quantization range. Required; must be strictly less
  // than max_quantized_value. Input values below it saturate to 0.
  optional float min_quantized_value = 2;
}

View File

@ -0,0 +1,204 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
namespace mediapipe {
// Omitting max_quantized_value must abort the run with a clear error.
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "QuantizeFloatVectorCalculator"
        input_stream: "FLOAT_VECTOR:float_vector"
        output_stream: "ENCODED:encoded"
        options {
          [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
            min_quantized_value: 1
          }
        }
      )");
  CalculatorRunner test_runner(node_config);
  std::vector<float> no_values;
  test_runner.MutableInputs()
      ->Tag("FLOAT_VECTOR")
      .packets.push_back(
          MakePacket<std::vector<float>>(no_values).At(Timestamp(0)));
  const auto run_status = test_runner.Run();
  EXPECT_FALSE(run_status.ok());
  EXPECT_THAT(
      run_status.message(),
      testing::HasSubstr(
          "Both max_quantized_value and min_quantized_value must be provided"));
}
// An inverted range (max < min) must abort the run with a clear error.
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig2) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "QuantizeFloatVectorCalculator"
        input_stream: "FLOAT_VECTOR:float_vector"
        output_stream: "ENCODED:encoded"
        options {
          [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
            max_quantized_value: -1
            min_quantized_value: 1
          }
        }
      )");
  CalculatorRunner test_runner(node_config);
  std::vector<float> no_values;
  test_runner.MutableInputs()
      ->Tag("FLOAT_VECTOR")
      .packets.push_back(
          MakePacket<std::vector<float>>(no_values).At(Timestamp(0)));
  const auto run_status = test_runner.Run();
  EXPECT_FALSE(run_status.ok());
  EXPECT_THAT(
      run_status.message(),
      testing::HasSubstr(
          "max_quantized_value must be greater than min_quantized_value"));
}
// A degenerate range (max == min) must abort the run with a clear error.
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig3) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "QuantizeFloatVectorCalculator"
        input_stream: "FLOAT_VECTOR:float_vector"
        output_stream: "ENCODED:encoded"
        options {
          [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
            max_quantized_value: 1
            min_quantized_value: 1
          }
        }
      )");
  CalculatorRunner test_runner(node_config);
  std::vector<float> no_values;
  test_runner.MutableInputs()
      ->Tag("FLOAT_VECTOR")
      .packets.push_back(
          MakePacket<std::vector<float>>(no_values).At(Timestamp(0)));
  const auto run_status = test_runner.Run();
  EXPECT_FALSE(run_status.ok());
  EXPECT_THAT(
      run_status.message(),
      testing::HasSubstr(
          "max_quantized_value must be greater than min_quantized_value"));
}
// An empty input vector quantizes to an empty string at the same timestamp.
TEST(QuantizeFloatVectorCalculatorTest, TestEmptyVector) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "QuantizeFloatVectorCalculator"
        input_stream: "FLOAT_VECTOR:float_vector"
        output_stream: "ENCODED:encoded"
        options {
          [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
            max_quantized_value: 1
            min_quantized_value: -1
          }
        }
      )");
  CalculatorRunner test_runner(node_config);
  std::vector<float> no_values;
  test_runner.MutableInputs()
      ->Tag("FLOAT_VECTOR")
      .packets.push_back(
          MakePacket<std::vector<float>>(no_values).At(Timestamp(0)));
  MEDIAPIPE_ASSERT_OK(test_runner.Run());
  const std::vector<Packet>& outputs =
      test_runner.Outputs().Tag("ENCODED").packets;
  EXPECT_EQ(1, outputs.size());
  EXPECT_TRUE(outputs[0].Get<std::string>().empty());
  EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
}
// In-range values map linearly from [-64, 64] onto bytes [0, 255].
TEST(QuantizeFloatVectorCalculatorTest, TestNonEmptyVector) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "QuantizeFloatVectorCalculator"
        input_stream: "FLOAT_VECTOR:float_vector"
        output_stream: "ENCODED:encoded"
        options {
          [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
            max_quantized_value: 64
            min_quantized_value: -64
          }
        }
      )");
  CalculatorRunner test_runner(node_config);
  const std::vector<float> input = {0.0f, -64.0f, 64.0f, -32.0f, 32.0f};
  test_runner.MutableInputs()
      ->Tag("FLOAT_VECTOR")
      .packets.push_back(MakePacket<std::vector<float>>(input).At(Timestamp(0)));
  MEDIAPIPE_ASSERT_OK(test_runner.Run());
  const std::vector<Packet>& outputs =
      test_runner.Outputs().Tag("ENCODED").packets;
  EXPECT_EQ(1, outputs.size());
  const std::string& result = outputs[0].Get<std::string>();
  ASSERT_FALSE(result.empty());
  EXPECT_EQ(5, result.size());
  EXPECT_EQ('\x7F', result[0]);  // 0.0f   -> 127
  EXPECT_EQ('\0', result[1]);    // -64.0f -> 0
  EXPECT_EQ('\xFF', result[2]);  // 64.0f  -> 255
  EXPECT_EQ('\x3F', result[3]);  // -32.0f -> 63
  EXPECT_EQ('\xBF', result[4]);  // 32.0f  -> 191
  EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
}
// Out-of-range values clamp to the byte extremes 0 and 255.
TEST(QuantizeFloatVectorCalculatorTest, TestSaturation) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "QuantizeFloatVectorCalculator"
        input_stream: "FLOAT_VECTOR:float_vector"
        output_stream: "ENCODED:encoded"
        options {
          [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
            max_quantized_value: 64
            min_quantized_value: -64
          }
        }
      )");
  CalculatorRunner test_runner(node_config);
  const std::vector<float> input = {-65.0f, 65.0f};
  test_runner.MutableInputs()
      ->Tag("FLOAT_VECTOR")
      .packets.push_back(MakePacket<std::vector<float>>(input).At(Timestamp(0)));
  MEDIAPIPE_ASSERT_OK(test_runner.Run());
  const std::vector<Packet>& outputs =
      test_runner.Outputs().Tag("ENCODED").packets;
  EXPECT_EQ(1, outputs.size());
  const std::string& result = outputs[0].Get<std::string>();
  ASSERT_FALSE(result.empty());
  EXPECT_EQ(2, result.size());
  EXPECT_EQ('\0', result[0]);    // -65.0f saturates below -64 -> 0
  EXPECT_EQ('\xFF', result[1]);  // 65.0f saturates above 64 -> 255
  EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
}
} // namespace mediapipe

View File

@ -0,0 +1,114 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstddef>
#include <cstdlib>
#include <deque>

#include "mediapipe/calculators/core/sequence_shift_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
namespace mediapipe {
// A Calculator that shifts the timestamps of packets along a stream. Packets on
// the input stream are output with a timestamp of the packet given by packet
// offset. That is, packet[i] is output with the timestamp of
// packet[i + packet_offset]. Packet offset can be either positive or negative.
// If packet_offset is -n, the first n packets will be dropped. If packet offset
// is n, the final n packets will be dropped. For example, with a packet_offset
// of -1, the first packet on the stream will be dropped, the second will be
// output with the timestamp of the first, the third with the timestamp of the
// second, and so on.
class SequenceShiftCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    // Accepts any packet type; the output carries the same type as the input.
    cc->Inputs().Index(0).SetAny();
    cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
    return ::mediapipe::OkStatus();
  }

  // Reads from options to set cache_size_ and packet_offset_.
  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;

 private:
  // A positive offset means we want a packet to be output with the timestamp of
  // a later packet. Stores packets waiting for their output timestamps and
  // outputs a single packet when the cache fills.
  void ProcessPositiveOffset(CalculatorContext* cc);

  // A negative offset means we want a packet to be output with the timestamp of
  // an earlier packet. Stores timestamps waiting for the corresponding input
  // packet and outputs a single packet when the cache fills.
  void ProcessNegativeOffset(CalculatorContext* cc);

  // Storage for packets waiting to be output when packet_offset > 0. When cache
  // is full, oldest packet is output with current timestamp.
  std::deque<Packet> packet_cache_;

  // Storage for previous timestamps used when packet_offset < 0. When cache is
  // full, oldest timestamp is used for current packet.
  std::deque<Timestamp> timestamp_cache_;

  // Copied from corresponding field in options.
  int packet_offset_;

  // The number of packets or timestamps we need to store to output packet[i] at
  // the timestamp of packet[i + packet_offset]; equal to abs(packet_offset).
  int cache_size_;
};
REGISTER_CALCULATOR(SequenceShiftCalculator);
::mediapipe::Status SequenceShiftCalculator::Open(CalculatorContext* cc) {
  packet_offset_ =
      cc->Options<mediapipe::SequenceShiftCalculatorOptions>().packet_offset();
  // std::abs (from <cstdlib>) instead of unqualified abs, which relied on a
  // transitive include and global-namespace pollution.
  cache_size_ = std::abs(packet_offset_);
  // An offset of zero is a no-op, but someone might still request it.
  if (packet_offset_ == 0) {
    cc->Outputs().Index(0).SetOffset(0);
  }
  return ::mediapipe::OkStatus();
}
::mediapipe::Status SequenceShiftCalculator::Process(CalculatorContext* cc) {
  // Dispatch on the sign of the configured offset.
  if (packet_offset_ == 0) {
    // Zero shift: forward the packet untouched.
    cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
  } else if (packet_offset_ > 0) {
    ProcessPositiveOffset(cc);
  } else {
    ProcessNegativeOffset(cc);
  }
  return ::mediapipe::OkStatus();
}
void SequenceShiftCalculator::ProcessPositiveOffset(CalculatorContext* cc) {
  // cache_size_ is non-negative (abs of the offset); the explicit cast avoids
  // a signed/unsigned comparison against deque::size().
  if (packet_cache_.size() >= static_cast<size_t>(cache_size_)) {
    // Ready to output oldest packet with current timestamp.
    cc->Outputs().Index(0).AddPacket(
        packet_cache_.front().At(cc->InputTimestamp()));
    packet_cache_.pop_front();
  }
  // Store current packet for later output.
  packet_cache_.push_back(cc->Inputs().Index(0).Value());
}
void SequenceShiftCalculator::ProcessNegativeOffset(CalculatorContext* cc) {
  // cache_size_ is non-negative (abs of the offset); the explicit cast avoids
  // a signed/unsigned comparison against deque::size().
  if (timestamp_cache_.size() >= static_cast<size_t>(cache_size_)) {
    // Ready to output current packet with oldest timestamp.
    cc->Outputs().Index(0).AddPacket(
        cc->Inputs().Index(0).Value().At(timestamp_cache_.front()));
    timestamp_cache_.pop_front();
  }
  // Store current timestamp for use by a future packet.
  timestamp_cache_.push_back(cc->InputTimestamp());
}
} // namespace mediapipe

View File

@ -0,0 +1,26 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
// Options for SequenceShiftCalculator, which emits packet[i] with the
// timestamp of packet[i + packet_offset].
message SequenceShiftCalculatorOptions {
  extend CalculatorOptions {
    optional SequenceShiftCalculatorOptions ext = 107633927;
  }
  // Offset, in packets, between an output packet's payload and its timestamp.
  // Positive: packet[i] is emitted at the timestamp of packet[i + offset] and
  // the final |offset| packets are dropped. Negative: packet[i] is emitted at
  // the timestamp of packet[i - |offset|] and the first |offset| packets are
  // dropped. Zero is a pass-through.
  optional int32 packet_offset = 1 [default = -1];
}

View File

@ -0,0 +1,104 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/validate_type.h"
namespace mediapipe {
namespace {
// Fills the runner's single input stream with ten packets whose int payload
// equals their timestamp (0..9).
void AddPackets(CalculatorRunner* runner) {
  for (int ts = 0; ts < 10; ++ts) {
    runner->MutableInputs()->Index(0).packets.push_back(
        Adopt(new int(ts)).At(Timestamp(ts)));
  }
}
// A zero offset is a pass-through: the output stream must be identical to the
// input stream in both payload and timestamp.
TEST(SequenceShiftCalculatorTest, ZeroShift) {
  CalculatorRunner runner(
      "SequenceShiftCalculator",
      "[mediapipe.SequenceShiftCalculatorOptions.ext]: { packet_offset: 0 }", 1,
      1, 0);
  AddPackets(&runner);
  MEDIAPIPE_ASSERT_OK(runner.Run());
  const std::vector<Packet>& inputs = runner.MutableInputs()->Index(0).packets;
  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  ASSERT_EQ(10, inputs.size());
  ASSERT_EQ(inputs.size(), outputs.size());
  for (int i = 0; i < outputs.size(); ++i) {
    // Payload and timestamp must both be unchanged.
    EXPECT_EQ(inputs[i].Get<int>(), outputs[i].Get<int>());
    EXPECT_EQ(inputs[i].Timestamp(), outputs[i].Timestamp());
  }
}
// Shifting by +3 outputs input[i] with the timestamp of input[i + 3]; the
// final three packets have no timestamp to borrow and are dropped.
TEST(SequenceShiftCalculatorTest, PositiveShift) {
  CalculatorRunner runner(
      "SequenceShiftCalculator",
      "[mediapipe.SequenceShiftCalculatorOptions.ext]: { packet_offset: 3 }", 1,
      1, 0);
  AddPackets(&runner);
  MEDIAPIPE_ASSERT_OK(runner.Run());
  const std::vector<Packet>& inputs = runner.MutableInputs()->Index(0).packets;
  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  ASSERT_EQ(10, inputs.size());
  // Only the first 7 packets have a packet three positions later.
  ASSERT_EQ(7, outputs.size());
  for (int i = 0; i < outputs.size(); ++i) {
    // Payload comes from packet i; timestamp comes from packet i + 3.
    EXPECT_EQ(inputs[i].Get<int>(), outputs[i].Get<int>());
    EXPECT_EQ(inputs[i + 3].Timestamp(), outputs[i].Timestamp());
  }
}
// Shifting by -2 outputs input[i] with the timestamp of input[i - 2]; the
// first two packets have no earlier timestamp and are dropped.
TEST(SequenceShiftCalculatorTest, NegativeShift) {
  CalculatorRunner runner(
      "SequenceShiftCalculator",
      "[mediapipe.SequenceShiftCalculatorOptions.ext]: { packet_offset: -2 }",
      1, 1, 0);
  AddPackets(&runner);
  MEDIAPIPE_ASSERT_OK(runner.Run());
  const std::vector<Packet>& inputs = runner.MutableInputs()->Index(0).packets;
  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  ASSERT_EQ(10, inputs.size());
  // Dropping the first two packets leaves eight; output timestamps line up
  // with input timestamps, but output[i] carries the payload of input[i + 2].
  ASSERT_EQ(8, outputs.size());
  for (int i = 0; i < outputs.size(); ++i) {
    EXPECT_EQ(inputs[i].Timestamp(), outputs[i].Timestamp());
    EXPECT_EQ(inputs[i + 2].Get<int>(), outputs[i].Get<int>());
  }
}
} // namespace
} // namespace mediapipe

View File

@ -0,0 +1,40 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/split_vector_calculator.h"
#include <vector>
#include "tensorflow/lite/interpreter.h"
namespace mediapipe {
// Example config:
// node {
// calculator: "SplitTfLiteTensorVectorCalculator"
// input_stream: "tflitetensor_vector"
// output_stream: "tflitetensor_vector_range_0"
// output_stream: "tflitetensor_vector_range_1"
// options {
// [mediapipe.SplitVectorCalculatorOptions.ext] {
// ranges: { begin: 0 end: 1 }
// ranges: { begin: 1 end: 2 }
// element_only: false
// }
// }
// }
// Concrete instantiation of SplitVectorCalculator for vectors of
// TfLiteTensor, registered under the name used in graph configs above.
typedef SplitVectorCalculator<TfLiteTensor> SplitTfLiteTensorVectorCalculator;
REGISTER_CALCULATOR(SplitTfLiteTensorVectorCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,125 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_CORE_SPLIT_VECTOR_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_SPLIT_VECTOR_CALCULATOR_H_
#include <vector>
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/resource_util.h"
#include "tensorflow/lite/error_reporter.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
namespace mediapipe {
// Splits an input packet with std::vector<T> into multiple std::vector<T>
// output packets using the [begin, end) ranges specified in
// SplitVectorCalculatorOptions. If the option "element_only" is set to true,
// all ranges should be of size 1 and all outputs will be elements of type T. If
// "element_only" is false, ranges can be non-zero in size and all outputs will
// be of type std::vector<T>.
// To use this class for a particular type T, register a calculator using
// SplitVectorCalculator<T>.
template <typename T>
class SplitVectorCalculator : public CalculatorBase {
 public:
  // Validates the node configuration: exactly one input stream, at least one
  // output stream, one output stream per configured range, and well-formed
  // (non-negative, non-empty) ranges.
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    RET_CHECK(cc->Inputs().NumEntries() == 1);
    RET_CHECK(cc->Outputs().NumEntries() != 0);

    cc->Inputs().Index(0).Set<std::vector<T>>();

    const auto& options =
        cc->Options<::mediapipe::SplitVectorCalculatorOptions>();

    if (cc->Outputs().NumEntries() != options.ranges_size()) {
      return ::mediapipe::InvalidArgumentError(
          "The number of output streams should match the number of ranges "
          "specified in the CalculatorOptions.");
    }

    // Set the output types for each output stream.
    for (int i = 0; i < cc->Outputs().NumEntries(); ++i) {
      if (options.ranges(i).begin() < 0 || options.ranges(i).end() < 0 ||
          options.ranges(i).begin() >= options.ranges(i).end()) {
        return ::mediapipe::InvalidArgumentError(
            "Indices should be non-negative and begin index should be less "
            "than the end index.");
      }
      if (options.element_only()) {
        if (options.ranges(i).end() - options.ranges(i).begin() != 1) {
          return ::mediapipe::InvalidArgumentError(
              "Since element_only is true, all ranges should be of size 1.");
        }
        cc->Outputs().Index(i).Set<T>();
      } else {
        cc->Outputs().Index(i).Set<std::vector<T>>();
      }
    }

    return ::mediapipe::OkStatus();
  }

  // Caches the ranges and the element_only flag from the options.
  ::mediapipe::Status Open(CalculatorContext* cc) override {
    cc->SetOffset(TimestampDiff(0));

    const auto& options =
        cc->Options<::mediapipe::SplitVectorCalculatorOptions>();

    for (const auto& range : options.ranges()) {
      ranges_.push_back({range.begin(), range.end()});
      max_range_end_ = std::max(max_range_end_, range.end());
    }
    element_only_ = options.element_only();

    return ::mediapipe::OkStatus();
  }

  // Copies each configured [begin, end) slice of the input vector to its
  // output stream; when element_only is set, each (size-1) slice is emitted
  // as a bare T instead of a std::vector<T>.
  ::mediapipe::Status Process(CalculatorContext* cc) override {
    const auto& input = cc->Inputs().Index(0).Get<std::vector<T>>();
    // Cast avoids an implicit signed/unsigned comparison; GetContract
    // guarantees at least one valid range, so max_range_end_ >= 1 here.
    RET_CHECK_GE(static_cast<int>(input.size()), max_range_end_);

    if (element_only_) {
      for (int i = 0; i < ranges_.size(); ++i) {
        cc->Outputs().Index(i).AddPacket(
            MakePacket<T>(input[ranges_[i].first]).At(cc->InputTimestamp()));
      }
    } else {
      for (int i = 0; i < ranges_.size(); ++i) {
        auto output = absl::make_unique<std::vector<T>>(
            input.begin() + ranges_[i].first,
            input.begin() + ranges_[i].second);
        cc->Outputs().Index(i).Add(output.release(), cc->InputTimestamp());
      }
    }

    return ::mediapipe::OkStatus();
  }

 private:
  // {begin, end} index pairs, one per output stream, in output order.
  std::vector<std::pair<int32, int32>> ranges_;
  // Largest end index over all ranges; inputs must be at least this long.
  int32 max_range_end_ = -1;
  // If true, size-1 ranges are emitted as T instead of std::vector<T>.
  bool element_only_ = false;
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_CORE_SPLIT_VECTOR_CALCULATOR_H_

View File

@ -0,0 +1,40 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
// A Range {begin, end} specifies beginning and ending indices to splice a
// vector. A vector v is spliced to have elements v[begin:(end-1)], i.e., with
// begin index inclusive and end index exclusive.
message Range {
  // First element index of the range (inclusive). Must be non-negative.
  optional int32 begin = 1;
  // One past the last element index (exclusive). Must exceed begin.
  optional int32 end = 2;
}
// Options for SplitVectorCalculator; one output stream is produced per range.
message SplitVectorCalculatorOptions {
  extend mediapipe.CalculatorOptions {
    optional SplitVectorCalculatorOptions ext = 259438222;
  }
  // [begin, end) slices of the input vector, one per output stream, in
  // output-stream order.
  repeated Range ranges = 1;
  // Specify if single element ranges should be outputted as std::vector<T> or
  // just element of type T. By default, if a range specifies only one element,
  // it is outputted as an std::vector<T>. When set, every range must be of
  // size exactly 1.
  optional bool element_only = 2 [default = false];
}

View File

@ -0,0 +1,321 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/split_vector_calculator.h"
#include <memory>
#include <string>
#include <vector>
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
#include "mediapipe/framework/tool/validate_type.h"
#include "tensorflow/lite/error_reporter.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
namespace mediapipe {
using ::tflite::Interpreter;
const int width = 1;
const int height = 1;
const int channels = 1;
// Test fixture that allocates a vector of float TfLiteTensors through a
// tflite::Interpreter, fills tensor i with the constant value i, and provides
// helpers to check that calculator outputs alias the original tensor buffers.
class SplitTfLiteTensorVectorCalculatorTest : public ::testing::Test {
 protected:
  void TearDown() {
    // Note: Since the pointers contained in this vector will be cleaned up by
    // the interpreter, only ensure that the vector is cleaned up for the next
    // test.
    input_buffers_.clear();
  }

  // Builds |vector_size| input tensors of shape {width, height, channels},
  // fills tensor i with the value i, and records the tensors in input_vec_
  // and their raw float buffers in input_buffers_.
  void PrepareTfLiteTensorVector(int vector_size) {
    ASSERT_NE(interpreter_, nullptr);

    // Prepare input tensors.
    std::vector<int> indices(vector_size);
    for (int i = 0; i < vector_size; ++i) {
      indices[i] = i;
    }
    interpreter_->AddTensors(vector_size);
    interpreter_->SetInputs(indices);

    input_vec_ = absl::make_unique<std::vector<TfLiteTensor>>();
    for (int i = 0; i < vector_size; ++i) {
      interpreter_->SetTensorParametersReadWrite(i, kTfLiteFloat32, "", {3},
                                                 TfLiteQuantization());
      const int tensor_index = interpreter_->inputs()[i];
      interpreter_->ResizeInputTensor(tensor_index, {width, height, channels});
    }
    interpreter_->AllocateTensors();

    // Save the tensor buffer pointers for comparison after the graph runs.
    input_buffers_ = std::vector<float*>(vector_size);
    for (int i = 0; i < vector_size; ++i) {
      const int tensor_index = interpreter_->inputs()[i];
      TfLiteTensor* tensor = interpreter_->tensor(tensor_index);
      float* tensor_buffer = tensor->data.f;
      ASSERT_NE(tensor_buffer, nullptr);
      for (int j = 0; j < width * height * channels; ++j) {
        tensor_buffer[j] = i;
      }
      input_vec_->push_back(*tensor);
      input_buffers_[i] = tensor_buffer;
    }
  }

  // Checks that a single output packet holds a std::vector<TfLiteTensor> of
  // |expected_elements| tensors whose buffers alias
  // input_buffers_[input_begin_index + i] and hold the expected values.
  void ValidateVectorOutput(std::vector<Packet>& output_packets,
                            int expected_elements, int input_begin_index) {
    ASSERT_EQ(1, output_packets.size());
    const std::vector<TfLiteTensor>& output_vec =
        output_packets[0].Get<std::vector<TfLiteTensor>>();
    ASSERT_EQ(expected_elements, output_vec.size());

    for (int i = 0; i < expected_elements; ++i) {
      const int expected_value = input_begin_index + i;
      const TfLiteTensor* result = &output_vec[i];
      float* result_buffer = result->data.f;
      ASSERT_NE(result_buffer, nullptr);
      // The split must not copy tensor data: buffers alias the inputs.
      ASSERT_EQ(result_buffer, input_buffers_[input_begin_index + i]);
      for (int j = 0; j < width * height * channels; ++j) {
        ASSERT_EQ(expected_value, result_buffer[j]);
      }
    }
  }

  // Checks that a single output packet holds a bare TfLiteTensor whose buffer
  // aliases input_buffers_[input_begin_index] and holds the expected value.
  void ValidateElementOutput(std::vector<Packet>& output_packets,
                             int input_begin_index) {
    ASSERT_EQ(1, output_packets.size());

    const TfLiteTensor& result = output_packets[0].Get<TfLiteTensor>();
    float* result_buffer = result.data.f;
    ASSERT_NE(result_buffer, nullptr);
    // The split must not copy tensor data: buffer aliases the input.
    ASSERT_EQ(result_buffer, input_buffers_[input_begin_index]);

    const int expected_value = input_begin_index;
    for (int j = 0; j < width * height * channels; ++j) {
      ASSERT_EQ(expected_value, result_buffer[j]);
    }
  }

  std::unique_ptr<Interpreter> interpreter_ = absl::make_unique<Interpreter>();
  std::unique_ptr<std::vector<TfLiteTensor>> input_vec_ = nullptr;
  // Non-owning buffer pointers; storage is owned by interpreter_.
  std::vector<float*> input_buffers_;
  std::unique_ptr<CalculatorRunner> runner_ = nullptr;
};
// End-to-end run: a 5-tensor input split into vectors [0,1), [1,4), [4,5).
// Verifies element values and that output buffers alias the input buffers.
TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTest) {
  ASSERT_NE(interpreter_, nullptr);

  PrepareTfLiteTensorVector(/*vector_size=*/5);
  ASSERT_NE(input_vec_, nullptr);

  // Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
  CalculatorGraphConfig graph_config =
      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"(
            input_stream: "tensor_in"
            node {
              calculator: "SplitTfLiteTensorVectorCalculator"
              input_stream: "tensor_in"
              output_stream: "range_0"
              output_stream: "range_1"
              output_stream: "range_2"
              options {
                [mediapipe.SplitVectorCalculatorOptions.ext] {
                  ranges: { begin: 0 end: 1 }
                  ranges: { begin: 1 end: 4 }
                  ranges: { begin: 4 end: 5 }
                }
              }
            }
          )");
  std::vector<Packet> range_0_packets;
  tool::AddVectorSink("range_0", &graph_config, &range_0_packets);
  std::vector<Packet> range_1_packets;
  tool::AddVectorSink("range_1", &graph_config, &range_1_packets);
  std::vector<Packet> range_2_packets;
  tool::AddVectorSink("range_2", &graph_config, &range_2_packets);

  // Run the graph.
  CalculatorGraph graph;
  MEDIAPIPE_ASSERT_OK(graph.Initialize(graph_config));
  MEDIAPIPE_ASSERT_OK(graph.StartRun({}));
  MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream(
      "tensor_in", Adopt(input_vec_.release()).At(Timestamp(0))));
  // Wait until the calculator finishes processing.
  MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle());

  // Each sink should have received exactly one packet with the right slice.
  ValidateVectorOutput(range_0_packets, /*expected_elements=*/1,
                       /*input_begin_index=*/0);
  ValidateVectorOutput(range_1_packets, /*expected_elements=*/3,
                       /*input_begin_index=*/1);
  ValidateVectorOutput(range_2_packets, /*expected_elements=*/1,
                       /*input_begin_index=*/4);

  // Fully close the graph at the end.
  MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("tensor_in"));
  MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone());
}
// An empty range (begin == end) must be rejected at graph initialization by
// SplitVectorCalculator::GetContract.
TEST_F(SplitTfLiteTensorVectorCalculatorTest, InvalidRangeTest) {
  ASSERT_NE(interpreter_, nullptr);

  // Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
  CalculatorGraphConfig graph_config =
      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"(
            input_stream: "tensor_in"
            node {
              calculator: "SplitTfLiteTensorVectorCalculator"
              input_stream: "tensor_in"
              output_stream: "range_0"
              options {
                [mediapipe.SplitVectorCalculatorOptions.ext] {
                  ranges: { begin: 0 end: 0 }
                }
              }
            }
          )");

  // Run the graph.
  CalculatorGraph graph;
  // The graph should fail running because of an invalid range (begin == end).
  ASSERT_FALSE(graph.Initialize(graph_config).ok());
}
// Two output streams with only one configured range must be rejected at graph
// initialization: GetContract requires one output stream per range.
TEST_F(SplitTfLiteTensorVectorCalculatorTest, InvalidOutputStreamCountTest) {
  ASSERT_NE(interpreter_, nullptr);

  // Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
  CalculatorGraphConfig graph_config =
      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"(
            input_stream: "tensor_in"
            node {
              calculator: "SplitTfLiteTensorVectorCalculator"
              input_stream: "tensor_in"
              output_stream: "range_0"
              output_stream: "range_1"
              options {
                [mediapipe.SplitVectorCalculatorOptions.ext] {
                  ranges: { begin: 0 end: 1 }
                }
              }
            }
          )");

  // Run the graph.
  CalculatorGraph graph;
  // The graph should fail running because the number of output streams does not
  // match the number of range elements in the options.
  ASSERT_FALSE(graph.Initialize(graph_config).ok());
}
// With element_only: true and size-1 ranges, each output stream carries a
// bare TfLiteTensor (not a vector) aliasing the corresponding input buffer.
TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTestElementOnly) {
  ASSERT_NE(interpreter_, nullptr);

  PrepareTfLiteTensorVector(/*vector_size=*/5);
  ASSERT_NE(input_vec_, nullptr);

  // Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
  CalculatorGraphConfig graph_config =
      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"(
            input_stream: "tensor_in"
            node {
              calculator: "SplitTfLiteTensorVectorCalculator"
              input_stream: "tensor_in"
              output_stream: "range_0"
              output_stream: "range_1"
              output_stream: "range_2"
              options {
                [mediapipe.SplitVectorCalculatorOptions.ext] {
                  ranges: { begin: 0 end: 1 }
                  ranges: { begin: 2 end: 3 }
                  ranges: { begin: 4 end: 5 }
                  element_only: true
                }
              }
            }
          )");
  std::vector<Packet> range_0_packets;
  tool::AddVectorSink("range_0", &graph_config, &range_0_packets);
  std::vector<Packet> range_1_packets;
  tool::AddVectorSink("range_1", &graph_config, &range_1_packets);
  std::vector<Packet> range_2_packets;
  tool::AddVectorSink("range_2", &graph_config, &range_2_packets);

  // Run the graph.
  CalculatorGraph graph;
  MEDIAPIPE_ASSERT_OK(graph.Initialize(graph_config));
  MEDIAPIPE_ASSERT_OK(graph.StartRun({}));
  MEDIAPIPE_ASSERT_OK(graph.AddPacketToInputStream(
      "tensor_in", Adopt(input_vec_.release()).At(Timestamp(0))));
  // Wait until the calculator finishes processing.
  MEDIAPIPE_ASSERT_OK(graph.WaitUntilIdle());

  // Each sink should hold the single element selected by its range.
  ValidateElementOutput(range_0_packets,
                        /*input_begin_index=*/0);
  ValidateElementOutput(range_1_packets,
                        /*input_begin_index=*/2);
  ValidateElementOutput(range_2_packets,
                        /*input_begin_index=*/4);

  // Fully close the graph at the end.
  MEDIAPIPE_ASSERT_OK(graph.CloseInputStream("tensor_in"));
  MEDIAPIPE_ASSERT_OK(graph.WaitUntilDone());
}
// element_only: true requires every range to have size exactly 1; the
// size-3 range {1, 4} must be rejected at graph initialization.
TEST_F(SplitTfLiteTensorVectorCalculatorTest,
       ElementOnlyDisablesVectorOutputs) {
  // Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
  CalculatorGraphConfig graph_config =
      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"(
            input_stream: "tensor_in"
            node {
              calculator: "SplitTfLiteTensorVectorCalculator"
              input_stream: "tensor_in"
              output_stream: "range_0"
              output_stream: "range_1"
              output_stream: "range_2"
              options {
                [mediapipe.SplitVectorCalculatorOptions.ext] {
                  ranges: { begin: 0 end: 1 }
                  ranges: { begin: 1 end: 4 }
                  ranges: { begin: 4 end: 5 }
                  element_only: true
                }
              }
            }
          )");

  // Run the graph.
  CalculatorGraph graph;
  ASSERT_FALSE(graph.Initialize(graph_config).ok());
}
} // namespace mediapipe

View File

@ -19,6 +19,7 @@ package(default_visibility = ["//visibility:private"])
exports_files(["LICENSE"])
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
load("@bazel_skylib//lib:selects.bzl", "selects")
proto_library(
name = "opencv_image_encoder_calculator_proto",
@ -46,6 +47,26 @@ proto_library(
],
)
proto_library(
name = "image_cropping_calculator_proto",
srcs = ["image_cropping_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
],
)
proto_library(
name = "bilateral_filter_calculator_proto",
srcs = ["bilateral_filter_calculator.proto"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:calculator_proto",
],
)
proto_library(
name = "recolor_calculator_proto",
srcs = ["recolor_calculator.proto"],
@ -93,6 +114,26 @@ mediapipe_cc_proto_library(
deps = [":set_alpha_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "image_cropping_calculator_cc_proto",
srcs = ["image_cropping_calculator.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
],
visibility = ["//mediapipe:__subpackages__"],
deps = [":image_cropping_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "bilateral_filter_calculator_cc_proto",
srcs = ["bilateral_filter_calculator.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
],
visibility = ["//mediapipe:__subpackages__"],
deps = [":bilateral_filter_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "recolor_calculator_cc_proto",
srcs = ["recolor_calculator.proto"],
@ -185,6 +226,42 @@ cc_library(
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:vector",
] + select({
"//mediapipe:android": [
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:shader_util",
],
"//mediapipe:ios": [
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:shader_util",
],
"//conditions:default": [],
}),
alwayslink = 1,
)
cc_library(
name = "bilateral_filter_calculator",
srcs = ["bilateral_filter_calculator.cc"],
visibility = [
"//visibility:public",
],
deps = [
":bilateral_filter_calculator_cc_proto",
"//mediapipe/framework:calculator_options_cc_proto",
"//mediapipe/framework/formats:image_format_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:vector",
] + select({
"//mediapipe:android": [
"//mediapipe/gpu:gl_calculator_helper",
@ -221,6 +298,19 @@ mediapipe_cc_proto_library(
cc_library(
name = "image_transformation_calculator",
srcs = ["image_transformation_calculator.cc"],
copts = select({
"//mediapipe:apple": [
"-x objective-c++",
],
"//conditions:default": [],
}),
linkopts = select({
"//mediapipe:apple": [
"-framework CoreVideo",
"-framework MetalKit",
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":image_transformation_calculator_cc_proto",
@ -232,8 +322,8 @@ cc_library(
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
] + select({
"//mediapipe:android": [
] + selects.with_or({
("//mediapipe:android", "//mediapipe:ios"): [
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:gl_quad_renderer",
@ -247,8 +337,24 @@ cc_library(
cc_library(
name = "image_cropping_calculator",
srcs = ["image_cropping_calculator.cc"],
visibility = ["//visibility:public"],
copts = select({
"//mediapipe:apple": [
"-x objective-c++",
],
"//conditions:default": [],
}),
linkopts = select({
"//mediapipe:apple": [
"-framework CoreVideo",
"-framework MetalKit",
],
"//conditions:default": [],
}),
visibility = [
"//visibility:public",
],
deps = [
":image_cropping_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
@ -257,7 +363,15 @@ cc_library(
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
],
] + selects.with_or({
("//mediapipe:android", "//mediapipe:ios"): [
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:gl_quad_renderer",
"//mediapipe/gpu:shader_util",
],
"//conditions:default": [],
}),
alwayslink = 1,
)
@ -307,6 +421,12 @@ cc_library(
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:shader_util",
],
"//mediapipe:ios": [
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:shader_util",
],
"//conditions:default": [],
}),
alwayslink = 1,
@ -357,6 +477,24 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "image_properties_calculator",
srcs = ["image_properties_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
] + selects.with_or({
("//mediapipe:android", "//mediapipe:ios"): [
"//mediapipe/gpu:gpu_buffer",
],
"//conditions:default": [],
}),
alwayslink = 1,
)
cc_test(
name = "opencv_encoded_image_to_image_frame_calculator_test",
srcs = ["opencv_encoded_image_to_image_frame_calculator_test.cc"],

View File

@ -0,0 +1,553 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include "mediapipe/calculators/image/bilateral_filter_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_options.pb.h"
#include "mediapipe/framework/formats/image_format.pb.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/vector.h"
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/shader_util.h"
#endif // __ANDROID__ || __EMSCRIPTEN__
namespace mediapipe {
namespace {
constexpr char kInputFrameTag[] = "IMAGE";
constexpr char kInputGuideTag[] = "GUIDE";
constexpr char kOutputFrameTag[] = "IMAGE";
constexpr char kInputFrameTagGpu[] = "IMAGE_GPU";
constexpr char kInputGuideTagGpu[] = "GUIDE_GPU";
constexpr char kOutputFrameTagGpu[] = "IMAGE_GPU";
enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
} // namespace
// A calculator for applying a bilateral filter to an image,
// with an optional guide image (joint bilateral).
//
// Inputs:
// One of the following two IMAGE tags:
// IMAGE: ImageFrame containing input image - Grayscale or RGB only.
// IMAGE_GPU: GpuBuffer containing input image - Grayscale, RGB or RGBA.
//
// GUIDE (optional): ImageFrame guide image used to filter IMAGE. (N/A).
// GUIDE_GPU (optional): GpuBuffer guide image used to filter IMAGE_GPU.
//
// Output:
// One of the following two tags:
// IMAGE: A filtered ImageFrame - Same as input.
// IMAGE_GPU: A filtered GpuBuffer - RGBA
//
// Options:
// sigma_space: Pixel radius: use (sigma_space*2+1)x(sigma_space*2+1) window.
// This should be set based on output image pixel space.
// sigma_color: Color variance: normalized [0-1] color difference allowed.
//
// Notes:
// * When GUIDE is present, the output image is same size as GUIDE image;
// otherwise, the output image is same size as input image.
// * On GPU the kernel window is subsampled by approximately sqrt(sigma_space)
// i.e. the step size is ~sqrt(sigma_space),
// prioritizing performance > quality.
// * TODO: Add CPU path for joint filter.
//
// Applies a bilateral (optionally joint/guided) filter to an image on CPU or
// GPU; see the file comment above for the full stream contract.
class BilateralFilterCalculator : public CalculatorBase {
 public:
  BilateralFilterCalculator() = default;
  ~BilateralFilterCalculator() override = default;

  static ::mediapipe::Status GetContract(CalculatorContract* cc);

  // From Calculator.
  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;
  ::mediapipe::Status Close(CalculatorContext* cc) override;

 private:
  // CPU/GPU render paths, selected by use_gpu_.
  ::mediapipe::Status RenderGpu(CalculatorContext* cc);
  ::mediapipe::Status RenderCpu(CalculatorContext* cc);

  // One-time GL program compilation; GlRender issues the draw call.
  ::mediapipe::Status GlSetup(CalculatorContext* cc);
  void GlRender(CalculatorContext* cc);

  mediapipe::BilateralFilterCalculatorOptions options_;
  // Filter parameters; -1 marks "not yet read from options_" (set in Open).
  float sigma_color_ = -1.f;
  float sigma_space_ = -1.f;

  bool use_gpu_ = false;
  // Deferred GL setup flag; programs are built on the first Process call.
  bool gpu_initialized_ = false;
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
  mediapipe::GlCalculatorHelper gpu_helper_;
  GLuint program_ = 0;        // Plain bilateral filter shader.
  GLuint program_joint_ = 0;  // Joint/guided bilateral filter shader.
#endif  // __ANDROID__ || __EMSCRIPTEN__
};
REGISTER_CALCULATOR(BilateralFilterCalculator);
// Validates the stream configuration: at least one input, not both CPU and
// GPU input images, and a GPU output only when there is a GPU input. Sets
// packet types for all recognized tags.
::mediapipe::Status BilateralFilterCalculator::GetContract(
    CalculatorContract* cc) {
  CHECK_GE(cc->Inputs().NumEntries(), 1);

  if (cc->Inputs().HasTag(kInputFrameTag) &&
      cc->Inputs().HasTag(kInputFrameTagGpu)) {
    return ::mediapipe::InternalError("Cannot have multiple input images.");
  }
  if (cc->Inputs().HasTag(kInputFrameTagGpu) !=
      cc->Outputs().HasTag(kOutputFrameTagGpu)) {
    return ::mediapipe::InternalError("GPU output must have GPU input.");
  }

  // Input image to filter.
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
  if (cc->Inputs().HasTag(kInputFrameTagGpu)) {
    cc->Inputs().Tag(kInputFrameTagGpu).Set<mediapipe::GpuBuffer>();
  }
#endif  // __ANDROID__ || __EMSCRIPTEN__
  if (cc->Inputs().HasTag(kInputFrameTag)) {
    cc->Inputs().Tag(kInputFrameTag).Set<ImageFrame>();
  }

  // Input guide image mask (optional)
  if (cc->Inputs().HasTag(kInputGuideTagGpu)) {
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
    cc->Inputs().Tag(kInputGuideTagGpu).Set<mediapipe::GpuBuffer>();
#endif  // __ANDROID__ || __EMSCRIPTEN__
  }
  if (cc->Inputs().HasTag(kInputGuideTag)) {
    cc->Inputs().Tag(kInputGuideTag).Set<ImageFrame>();
  }

  // Output image.
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
  if (cc->Outputs().HasTag(kOutputFrameTagGpu)) {
    cc->Outputs().Tag(kOutputFrameTagGpu).Set<mediapipe::GpuBuffer>();
  }
#endif  // __ANDROID__ || __EMSCRIPTEN__
  if (cc->Outputs().HasTag(kOutputFrameTag)) {
    cc->Outputs().Tag(kOutputFrameTag).Set<ImageFrame>();
  }

#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
  RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#endif  // __ANDROID__ || __EMSCRIPTEN__

  return ::mediapipe::OkStatus();
}
// Reads options, decides CPU vs. GPU execution, validates the sigmas, and
// (on GPU builds) opens the GL helper.
::mediapipe::Status BilateralFilterCalculator::Open(CalculatorContext* cc) {
  // No timestamp offset: outputs are emitted at the input timestamp.
  cc->SetOffset(TimestampDiff(0));

  options_ = cc->Options<mediapipe::BilateralFilterCalculatorOptions>();

  if (cc->Inputs().HasTag(kInputFrameTagGpu) &&
      cc->Outputs().HasTag(kOutputFrameTagGpu)) {
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
    use_gpu_ = true;
#else
    RET_CHECK_FAIL() << "GPU processing on non-Android not supported yet.";
#endif  // __ANDROID__ || __EMSCRIPTEN__
  }

  sigma_color_ = options_.sigma_color();
  sigma_space_ = options_.sigma_space();
  CHECK_GE(sigma_color_, 0.0);
  CHECK_GE(sigma_space_, 0.0);
  // Options express sigma_color in normalized [0,1] units (used as-is by the
  // GPU shader); OpenCV's CPU filter expects 0-255 pixel units, so rescale.
  if (!use_gpu_) sigma_color_ *= 255.0;

  if (use_gpu_) {
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
    RETURN_IF_ERROR(gpu_helper_.Open(cc));
#endif
  }

  return ::mediapipe::OkStatus();
}
// Dispatches each input frame to the GPU or CPU render path. GPU work
// (including lazy one-time GlSetup) must run inside the GL context.
::mediapipe::Status BilateralFilterCalculator::Process(CalculatorContext* cc) {
  if (use_gpu_) {
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
    RETURN_IF_ERROR(
        gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
          // Compile shaders on first use, inside the GL context.
          if (!gpu_initialized_) {
            RETURN_IF_ERROR(GlSetup(cc));
            gpu_initialized_ = true;
          }
          RETURN_IF_ERROR(RenderGpu(cc));
          return ::mediapipe::OkStatus();
        }));
#endif  // __ANDROID__ || __EMSCRIPTEN__
  } else {
    RETURN_IF_ERROR(RenderCpu(cc));
  }
  return ::mediapipe::OkStatus();
}
// Releases the GL shader programs. Deletion must happen inside the GL
// context; a 0 program handle means "never created / already deleted".
::mediapipe::Status BilateralFilterCalculator::Close(CalculatorContext* cc) {
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
  gpu_helper_.RunInGlContext([this] {
    if (program_) glDeleteProgram(program_);
    program_ = 0;
    if (program_joint_) glDeleteProgram(program_joint_);
    program_joint_ = 0;
  });
#endif  // __ANDROID__ || __EMSCRIPTEN__
  return ::mediapipe::OkStatus();
}
// CPU render path: applies cv::bilateralFilter to the input ImageFrame and
// emits the filtered frame at the input timestamp. Joint (guided) filtering
// is GPU-only for now.
//
// Bug fix: the channel-count guard was inverted — it returned the
// "only 1 or 3 channel" error exactly when the image HAD 1 or 3 channels
// (the supported cases) and let unsupported channel counts through to
// OpenCV. The condition is now negated.
::mediapipe::Status BilateralFilterCalculator::RenderCpu(
    CalculatorContext* cc) {
  if (cc->Inputs().Tag(kInputFrameTag).IsEmpty()) {
    return ::mediapipe::OkStatus();
  }

  const auto& input_frame = cc->Inputs().Tag(kInputFrameTag).Get<ImageFrame>();
  auto input_mat = mediapipe::formats::MatView(&input_frame);
  // Only 1 or 3 channel images supported by OpenCV.
  if (!(input_mat.channels() == 1 || input_mat.channels() == 3)) {
    return ::mediapipe::InternalError(
        "CPU filtering supports only 1 or 3 channel input images.");
  }

  auto output_frame = absl::make_unique<ImageFrame>(
      input_frame.Format(), input_mat.cols, input_mat.rows);

  const bool has_guide_image = cc->Inputs().HasTag(kInputGuideTag) &&
                               !cc->Inputs().Tag(kInputGuideTag).IsEmpty();
  if (has_guide_image) {
    // cv::jointBilateralFilter() is in contrib module 'ximgproc'.
    return ::mediapipe::UnimplementedError(
        "CPU joint filtering support is not implemented yet.");
  } else {
    auto output_mat = mediapipe::formats::MatView(output_frame.get());
    // Prefer setting 'd = sigma_space * 2' to match GPU definition of radius.
    cv::bilateralFilter(input_mat, output_mat, /*d=*/sigma_space_ * 2.0,
                        sigma_color_, sigma_space_);
  }

  cc->Outputs()
      .Tag(kOutputFrameTag)
      .Add(output_frame.release(), cc->InputTimestamp());
  return ::mediapipe::OkStatus();
}
// GPU render path. Binds the input (and optional guide) textures, runs the
// appropriate shader program over a full-screen quad, and emits the result
// as a GpuBuffer packet. Must be called from within the GL context.
// Texture-unit convention (matches the sampler uniforms set in GlSetup):
// unit 1 = input frame, unit 2 = guide frame.
::mediapipe::Status BilateralFilterCalculator::RenderGpu(
    CalculatorContext* cc) {
  if (cc->Inputs().Tag(kInputFrameTagGpu).IsEmpty()) {
    return ::mediapipe::OkStatus();
  }
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
  const auto& input_frame =
      cc->Inputs().Tag(kInputFrameTagGpu).Get<mediapipe::GpuBuffer>();
  auto input_texture = gpu_helper_.CreateSourceTexture(input_frame);

  mediapipe::GlTexture output_texture;
  const bool has_guide_image = cc->Inputs().HasTag(kInputGuideTagGpu) &&
                               !cc->Inputs().Tag(kInputGuideTagGpu).IsEmpty();

  // Setup textures and Update image in GPU shader.
  if (has_guide_image) {
    // joint bilateral filter — the output is sized to the *guide* image,
    // which also drives the range (color) weights in the shader.
    glUseProgram(program_joint_);
    const auto& guide_image =
        cc->Inputs().Tag(kInputGuideTagGpu).Get<mediapipe::GpuBuffer>();
    auto guide_texture = gpu_helper_.CreateSourceTexture(guide_image);
    glUniform2f(glGetUniformLocation(program_joint_, "texel_size_guide"),
                1.0 / guide_image.width(), 1.0 / guide_image.height());
    output_texture = gpu_helper_.CreateDestinationTexture(
        guide_image.width(), guide_image.height(),
        mediapipe::GpuBufferFormat::kBGRA32);
    gpu_helper_.BindFramebuffer(output_texture);
    glActiveTexture(GL_TEXTURE1);
    glBindTexture(GL_TEXTURE_2D, input_texture.name());
    glActiveTexture(GL_TEXTURE2);
    glBindTexture(GL_TEXTURE_2D, guide_texture.name());
    GlRender(cc);
    // Unbind both texture units before releasing the source textures.
    glActiveTexture(GL_TEXTURE2);
    glBindTexture(GL_TEXTURE_2D, 0);
    glActiveTexture(GL_TEXTURE1);
    glBindTexture(GL_TEXTURE_2D, 0);
    guide_texture.Release();
  } else {
    // regular bilateral filter — output matches the input dimensions.
    glUseProgram(program_);
    glUniform2f(glGetUniformLocation(program_, "texel_size"),
                1.0 / input_frame.width(), 1.0 / input_frame.height());
    output_texture = gpu_helper_.CreateDestinationTexture(
        input_frame.width(), input_frame.height(),
        mediapipe::GpuBufferFormat::kBGRA32);
    gpu_helper_.BindFramebuffer(output_texture);
    glActiveTexture(GL_TEXTURE1);
    glBindTexture(GL_TEXTURE_2D, input_texture.name());
    GlRender(cc);
    glActiveTexture(GL_TEXTURE1);
    glBindTexture(GL_TEXTURE_2D, 0);
  }
  glFlush();

  // Send out image as GPU packet.
  auto output_frame = output_texture.GetFrame<mediapipe::GpuBuffer>();
  cc->Outputs()
      .Tag(kOutputFrameTagGpu)
      .Add(output_frame.release(), cc->InputTimestamp());

  // Cleanup
  input_texture.Release();
  output_texture.Release();
#endif  // __ANDROID__ || __EMSCRIPTEN__

  return ::mediapipe::OkStatus();
}
// Draws a full-screen quad (two triangles via GL_TRIANGLE_STRIP) so the
// currently bound fragment shader runs over every output pixel. Creates and
// destroys its own VAO/VBOs each call; assumes a program and framebuffer are
// already bound by the caller (RenderGpu).
void BilateralFilterCalculator::GlRender(CalculatorContext* cc) {
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
  // Quad covering clip space [-1,1]^2.
  static const GLfloat square_vertices[] = {
      -1.0f, -1.0f,  // bottom left
      1.0f,  -1.0f,  // bottom right
      -1.0f, 1.0f,   // top left
      1.0f,  1.0f,   // top right
  };
  // Matching texture coordinates in [0,1]^2.
  static const GLfloat texture_vertices[] = {
      0.0f, 0.0f,  // bottom left
      1.0f, 0.0f,  // bottom right
      0.0f, 1.0f,  // top left
      1.0f, 1.0f,  // top right
  };

  // vertex storage
  GLuint vbo[2];
  glGenBuffers(2, vbo);
  GLuint vao;
  glGenVertexArrays(1, &vao);
  glBindVertexArray(vao);

  // vbo 0: clip-space positions.
  glBindBuffer(GL_ARRAY_BUFFER, vbo[0]);
  glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices,
               GL_STATIC_DRAW);
  glEnableVertexAttribArray(ATTRIB_VERTEX);
  glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr);

  // vbo 1: texture coordinates.
  glBindBuffer(GL_ARRAY_BUFFER, vbo[1]);
  glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices,
               GL_STATIC_DRAW);
  glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION);
  glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr);

  // draw
  glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);

  // cleanup
  glDisableVertexAttribArray(ATTRIB_VERTEX);
  glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION);
  glBindBuffer(GL_ARRAY_BUFFER, 0);
  glBindVertexArray(0);
  glDeleteVertexArrays(1, &vao);
  glDeleteBuffers(2, vbo);
#endif  // __ANDROID__ || __EMSCRIPTEN__
}
// Compiles the two fragment-shader programs (regular and joint bilateral
// filter) and sets their sampler uniforms. Called once, lazily, from inside
// the GL context. The sigma values are baked into the shader source as
// constants so the GLSL compiler can fold/unroll with them.
::mediapipe::Status BilateralFilterCalculator::GlSetup(CalculatorContext* cc) {
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
  const GLint attr_location[NUM_ATTRIBUTES] = {
      ATTRIB_VERTEX,
      ATTRIB_TEXTURE_POSITION,
  };
  const GLchar* attr_name[NUM_ATTRIBUTES] = {
      "position",
      "texture_coordinate",
  };

  // We bake our sigma values directly into the shader, so the GLSL compiler can
  // optimize appropriately.
  std::string sigma_options_string =
      "const float sigma_space = " + std::to_string(sigma_space_) +
      "; const float sigma_color = " + std::to_string(sigma_color_) + ";\n";

  // Shader to do bilateral filtering on input image based on sigma space/color.
  // Large kernel sizes are subsampled based on sqrt(sigma_space) window size,
  // denoted as 'sparsity' below.
  const std::string frag_src = GLES_VERSION_COMPAT
      R"(
  #if __VERSION__ < 130
    #define in varying
  #endif  // __VERSION__ < 130

  #ifdef GL_ES
    #define fragColor gl_FragColor
    precision highp float;
  #else
    #define lowp
    #define mediump
    #define highp
    #define texture2D texture
    out vec4 fragColor;
  #endif  // defined(GL_ES)

    in vec2 sample_coordinate;
    uniform sampler2D input_frame;
)" + sigma_options_string + R"(
    uniform vec2 texel_size;

    const float kSparsityFactor = 0.66;  // Higher is more sparse.
    const float sparsity = max(1.0, sqrt(sigma_space) * kSparsityFactor);
    const float step = sparsity;
    const float radius = sigma_space;
    const float offset = (step > 1.0) ? (step * 0.5) : (0.0);

    float gaussian(float x, float sigma) {
      float coeff = -0.5 / (sigma * sigma * 4.0 + 1.0e-6);
      return exp((x * x) * coeff);
    }

    void main() {
      vec2 center_uv = sample_coordinate;
      vec3 center_val = texture2D(input_frame, center_uv).rgb;
      vec3 new_val = vec3(0.0);

      float space_weight = 0.0;
      float color_weight = 0.0;
      float total_weight = 0.0;

      float sigma_texel = max(texel_size.x, texel_size.y) * sigma_space;
      // Subsample kernel space.
      for (float i = -radius+offset; i <= radius; i+=step) {
        for (float j = -radius+offset; j <= radius; j+=step) {
          vec2 shift = vec2(j, i) * texel_size;
          vec2 uv = vec2(center_uv + shift);
          vec3 val = texture2D(input_frame, uv).rgb;

          space_weight = gaussian(distance(center_uv, uv), sigma_texel);
          color_weight = gaussian(distance(center_val, val), sigma_color);
          total_weight += space_weight * color_weight;

          new_val += vec3(space_weight * color_weight) * val;
        }
      }
      new_val /= vec3(total_weight);

      fragColor = vec4(new_val, 1.0);
    }
  )";

  // Create shader program and set parameters.
  mediapipe::GlhCreateProgram(mediapipe::kBasicVertexShader, frag_src.c_str(),
                              NUM_ATTRIBUTES, (const GLchar**)&attr_name[0],
                              attr_location, &program_);
  RET_CHECK(program_) << "Problem initializing the program.";
  glUseProgram(program_);
  // Input frame is sampled from texture unit 1 (bound in RenderGpu).
  glUniform1i(glGetUniformLocation(program_, "input_frame"), 1);

  // Shader to do joint bilateral filtering on input image based on
  // sigma space/color, and a Guide image.
  // Large kernel sizes are subsampled based on sqrt(sigma_space) window size,
  // denoted as 'sparsity' below.
  const std::string joint_frag_src = GLES_VERSION_COMPAT
      R"(
  #if __VERSION__ < 130
    #define in varying
  #endif  // __VERSION__ < 130

  #ifdef GL_ES
    #define fragColor gl_FragColor
    precision highp float;
  #else
    #define lowp
    #define mediump
    #define highp
    #define texture2D texture
    out vec4 fragColor;
  #endif  // defined(GL_ES)

    in vec2 sample_coordinate;
    uniform sampler2D input_frame;
    uniform sampler2D guide_frame;
)" + sigma_options_string + R"(
    uniform vec2 texel_size_guide; // size of guide and resulting filtered image

    const float kSparsityFactor = 0.66;  // Higher is more sparse.
    const float sparsity = max(1.0, sqrt(sigma_space) * kSparsityFactor);
    const float step = sparsity;
    const float radius = sigma_space;
    const float offset = (step > 1.0) ? (step * 0.5) : (0.0);

    float gaussian(float x, float sigma) {
      float coeff = -0.5 / (sigma * sigma * 4.0 + 1.0e-6);
      return exp((x * x) * coeff);
    }

    void main() {
      vec2 center_uv = sample_coordinate;
      vec3 center_val = texture2D(guide_frame, center_uv).rgb;
      vec3 new_val = vec3(0.0);

      float space_weight = 0.0;
      float color_weight = 0.0;
      float total_weight = 0.0;

      float sigma_texel = max(texel_size_guide.x, texel_size_guide.y) * sigma_space;
      // Subsample kernel space.
      for (float i = -radius+offset; i <= radius; i+=step) {
        for (float j = -radius+offset; j <= radius; j+=step) {
          vec2 shift = vec2(j, i) * texel_size_guide;
          vec2 uv = vec2(center_uv + shift);
          vec3 guide_val = texture2D(guide_frame, uv).rgb;
          vec3 out_val = texture2D(input_frame, uv).rgb;

          space_weight = gaussian(distance(center_uv, uv), sigma_texel);
          color_weight = gaussian(distance(center_val, guide_val), sigma_color);
          total_weight += space_weight * color_weight;

          new_val += vec3(space_weight * color_weight) * out_val;
        }
      }
      new_val /= vec3(total_weight);

      fragColor = vec4(new_val, 1.0);
    }
  )";

  // Create shader program and set parameters.
  mediapipe::GlhCreateProgram(
      mediapipe::kBasicVertexShader, joint_frag_src.c_str(), NUM_ATTRIBUTES,
      (const GLchar**)&attr_name[0], attr_location, &program_joint_);
  RET_CHECK(program_joint_) << "Problem initializing the program.";
  glUseProgram(program_joint_);
  // Texture units must match the bindings made in RenderGpu: 1 = input,
  // 2 = guide.
  glUniform1i(glGetUniformLocation(program_joint_, "input_frame"), 1);
  glUniform1i(glGetUniformLocation(program_joint_, "guide_frame"), 2);
#endif  // __ANDROID__ || __EMSCRIPTEN__

  return ::mediapipe::OkStatus();
}
} // namespace mediapipe

View File

@ -0,0 +1,20 @@
// Options for BilateralFilterCalculator
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
// Filter parameters for BilateralFilterCalculator.
message BilateralFilterCalculatorOptions {
  extend CalculatorOptions {
    optional BilateralFilterCalculatorOptions ext = 255670209;
  }

  // Max variance in color allowed, based on normalized color values.
  // (The calculator rescales this to 0-255 units internally on the CPU path.)
  optional float sigma_color = 1;

  // Window radius.
  // Results in a '(sigma_space*2+1) x (sigma_space*2+1)' size kernel.
  // This should be set based on output image pixel space.
  optional float sigma_space = 2;
}

View File

@ -1,3 +1,17 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"

View File

@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cmath>
#include "mediapipe/calculators/image/image_cropping_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
@ -21,47 +24,98 @@
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "mediapipe/gpu/shader_util.h"
#endif // __ANDROID__ or iOS
namespace {
enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
} // namespace
namespace mediapipe {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
#endif // __ANDROID__ or iOS
// Crops the input texture to the given rectangle region. The rectangle can
// be at arbitrary location on the image with rotation. If there's rotation, the
// output texture will have the size of the input rectangle. The rotation should
// be in radian, see rect.proto for detail.
// Currently it only works for CPU.
//
// Input:
// IMAGE: ImageFrame representing the input image.
// One of the following two tags:
// IMAGE - ImageFrame representing the input image.
// IMAGE_GPU - GpuBuffer representing the input image.
// One of the following two tags (optional if WIDTH/HEIGHT is specified):
// RECT - A Rect proto specifying the width/height and location of the
// cropping rectangle.
// NORM_RECT - A NormalizedRect proto specifying the width/height and location
// of the cropping rectangle in normalized coordinates.
// of the cropping rectangle in normalized coordinates.
// Alternative tags to RECT (optional if RECT/NORM_RECT is specified):
// WIDTH - The desired width of the output cropped image,
// based on image center
// HEIGHT - The desired height of the output cropped image,
// based on image center
//
// Output:
// IMAGE - Cropped frames.
// One of the following two tags:
// IMAGE - Cropped ImageFrame
// IMAGE_GPU - Cropped GpuBuffer.
//
// Note: input_stream values take precedence over options defined in the graph.
//
// Calculator that crops an ImageFrame (CPU) or GpuBuffer (GPU) to a possibly
// rotated rectangle; see the usage comment above for stream tags.
class ImageCroppingCalculator : public CalculatorBase {
 public:
  ImageCroppingCalculator() = default;
  ~ImageCroppingCalculator() override = default;

  static ::mediapipe::Status GetContract(CalculatorContract* cc);
  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;
  ::mediapipe::Status Close(CalculatorContext* cc) override;

 private:
  ::mediapipe::Status RenderCpu(CalculatorContext* cc);
  ::mediapipe::Status RenderGpu(CalculatorContext* cc);
  ::mediapipe::Status InitGpu(CalculatorContext* cc);
  void GlRender();
  // Computes the GPU-path output size and fills transformed_points_ with the
  // crop rectangle's corner texture coordinates.
  void GetOutputDimensions(CalculatorContext* cc, int src_width, int src_height,
                           int* dst_width, int* dst_height);

  // TODO: Merge with GlCroppingCalculator to have GPU support.
  bool use_gpu_{};
  mediapipe::ImageCroppingCalculatorOptions options_;
  // NOTE(review): `use_gpu_` is declared twice in this class (here and three
  // lines above) — this will not compile and looks like merge/diff residue;
  // one of the two declarations should be removed.
  bool use_gpu_ = false;
  // Output texture corners (4) after transformation in normalized coordinates.
  float transformed_points_[8];
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
  bool gpu_initialized_ = false;
  mediapipe::GlCalculatorHelper gpu_helper_;
  GLuint program_ = 0;
#endif  // __ANDROID__ or iOS
};
REGISTER_CALCULATOR(ImageCroppingCalculator);
::mediapipe::Status ImageCroppingCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag("IMAGE"));
RET_CHECK(cc->Outputs().HasTag("IMAGE"));
RET_CHECK(cc->Inputs().HasTag("IMAGE") ^ cc->Inputs().HasTag("IMAGE_GPU"));
RET_CHECK(cc->Outputs().HasTag("IMAGE") ^ cc->Outputs().HasTag("IMAGE_GPU"));
cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
if (cc->Inputs().HasTag("IMAGE")) {
RET_CHECK(cc->Outputs().HasTag("IMAGE"));
cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
cc->Outputs().Tag("IMAGE").Set<ImageFrame>();
}
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
if (cc->Inputs().HasTag("IMAGE_GPU")) {
RET_CHECK(cc->Outputs().HasTag("IMAGE_GPU"));
cc->Inputs().Tag("IMAGE_GPU").Set<GpuBuffer>();
cc->Outputs().Tag("IMAGE_GPU").Set<GpuBuffer>();
}
#endif // __ANDROID__ or iOS
if (cc->Inputs().HasTag("RECT")) {
cc->Inputs().Tag("RECT").Set<Rect>();
@ -69,21 +123,71 @@ REGISTER_CALCULATOR(ImageCroppingCalculator);
if (cc->Inputs().HasTag("NORM_RECT")) {
cc->Inputs().Tag("NORM_RECT").Set<NormalizedRect>();
}
if (cc->Inputs().HasTag("WIDTH")) {
cc->Inputs().Tag("WIDTH").Set<int>();
}
if (cc->Inputs().HasTag("HEIGHT")) {
cc->Inputs().Tag("HEIGHT").Set<int>();
}
cc->Outputs().Tag("IMAGE").Set<ImageFrame>();
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#endif // __ANDROID__ or iOS
return ::mediapipe::OkStatus();
}
// Selects the CPU/GPU path from the wired input tags, loads options, and
// (on GPU builds) opens the GL helper.
::mediapipe::Status ImageCroppingCalculator::Open(CalculatorContext* cc) {
  // No timestamp offset: outputs are emitted at the input timestamp.
  cc->SetOffset(TimestampDiff(0));

  if (cc->Inputs().HasTag("IMAGE_GPU")) {
    use_gpu_ = true;
  }
  options_ = cc->Options<mediapipe::ImageCroppingCalculatorOptions>();

  if (use_gpu_) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
    RETURN_IF_ERROR(gpu_helper_.Open(cc));
#else
    RET_CHECK_FAIL() << "GPU processing is for Android and iOS only.";
#endif  // __ANDROID__ or iOS
  }

  return ::mediapipe::OkStatus();
}
// Dispatches each frame to the GPU or CPU render path. GPU work (including
// lazy one-time InitGpu) must run inside the GL context.
::mediapipe::Status ImageCroppingCalculator::Process(CalculatorContext* cc) {
  if (use_gpu_) {
    // NOTE(review): this unconditional RenderGpu() call looks like stale
    // diff residue — on GPU builds RenderGpu() is invoked again inside the
    // GL context below (and on non-GPU builds use_gpu_ can never be true).
    // Confirm and remove one of the two calls.
    RETURN_IF_ERROR(RenderGpu(cc));
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
    RETURN_IF_ERROR(
        gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
          // Compile the shader program on first use, inside the GL context.
          if (!gpu_initialized_) {
            RETURN_IF_ERROR(InitGpu(cc));
            gpu_initialized_ = true;
          }
          RETURN_IF_ERROR(RenderGpu(cc));
          return ::mediapipe::OkStatus();
        }));
#endif  // __ANDROID__ or iOS
  } else {
    RETURN_IF_ERROR(RenderCpu(cc));
  }
  return ::mediapipe::OkStatus();
}
// Releases the GL shader program (inside the GL context) and resets the
// lazy-init flag so a reopened calculator would recompile it.
::mediapipe::Status ImageCroppingCalculator::Close(CalculatorContext* cc) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
  gpu_helper_.RunInGlContext([this] {
    if (program_) glDeleteProgram(program_);
    program_ = 0;
  });
  gpu_initialized_ = false;
#endif  // __ANDROID__ or iOS
  return ::mediapipe::OkStatus();
}
::mediapipe::Status ImageCroppingCalculator::RenderCpu(CalculatorContext* cc) {
const auto& input_img = cc->Inputs().Tag("IMAGE").Get<ImageFrame>();
cv::Mat input_mat = formats::MatView(&input_img);
@ -97,43 +201,53 @@ REGISTER_CALCULATOR(ImageCroppingCalculator);
const auto& rect = cc->Inputs().Tag("RECT").Get<Rect>();
if (rect.width() > 0 && rect.height() > 0 && rect.x_center() >= 0 &&
rect.y_center() >= 0) {
rotation = rect.rotation();
rect_center_x = rect.x_center();
rect_center_y = rect.y_center();
target_width = rect.width();
target_height = rect.height();
rotation = rect.rotation();
}
} else if (cc->Inputs().HasTag("NORM_RECT")) {
const auto& rect = cc->Inputs().Tag("NORM_RECT").Get<NormalizedRect>();
if (rect.width() > 0.0 && rect.height() > 0.0 && rect.x_center() >= 0.0 &&
rect.y_center() >= 0.0) {
rotation = rect.rotation();
rect_center_x = std::round(rect.x_center() * input_img.Width());
rect_center_y = std::round(rect.y_center() * input_img.Height());
target_width = std::round(rect.width() * input_img.Width());
target_height = std::round(rect.height() * input_img.Height());
rotation = rect.rotation();
}
}
cv::Mat rotated_mat;
if (std::abs(rotation) > 1e-5) {
// TODO: Use open source common math library.
const float pi = 3.1415926f;
rotation = rotation * 180.0 / pi;
// First rotation the image.
cv::Point2f src_center(rect_center_x, rect_center_y);
cv::Mat rotation_mat = cv::getRotationMatrix2D(src_center, rotation, 1.0);
cv::warpAffine(input_mat, rotated_mat, rotation_mat, input_mat.size());
} else {
input_mat.copyTo(rotated_mat);
if (cc->Inputs().HasTag("WIDTH") && cc->Inputs().HasTag("HEIGHT")) {
target_width = cc->Inputs().Tag("WIDTH").Get<int>();
target_height = cc->Inputs().Tag("HEIGHT").Get<int>();
} else if (options_.has_width() && options_.has_height()) {
target_width = options_.width();
target_height = options_.height();
}
rotation = options_.rotation();
}
// Then crop the requested area.
const cv::Rect cropping_rect(rect_center_x - target_width / 2,
rect_center_y - target_height / 2, target_width,
target_height);
cv::Mat cropped_image = cv::Mat(rotated_mat, cropping_rect);
const cv::RotatedRect min_rect(cv::Point2f(rect_center_x, rect_center_y),
cv::Size2f(target_width, target_height),
rotation * 180.f / M_PI);
cv::Mat src_points;
cv::boxPoints(min_rect, src_points);
float dst_corners[8] = {0,
min_rect.size.height - 1,
0,
0,
min_rect.size.width - 1,
0,
min_rect.size.width - 1,
min_rect.size.height - 1};
cv::Mat dst_points = cv::Mat(4, 2, CV_32F, dst_corners);
cv::Mat projection_matrix =
cv::getPerspectiveTransform(src_points, dst_points);
cv::Mat cropped_image;
cv::warpPerspective(input_mat, cropped_image, projection_matrix,
cv::Size(min_rect.size.width, min_rect.size.height));
std::unique_ptr<ImageFrame> output_frame(new ImageFrame(
input_img.Format(), cropped_image.cols, cropped_image.rows));
@ -144,7 +258,220 @@ REGISTER_CALCULATOR(ImageCroppingCalculator);
}
// GPU render path: crops the input GpuBuffer by drawing it through a
// pass-through shader with the transformed corner texture coordinates
// computed in GetOutputDimensions().
::mediapipe::Status ImageCroppingCalculator::RenderGpu(CalculatorContext* cc) {
  // NOTE(review): this unconditional early return appears to be diff residue
  // from the pre-GPU version of this file; it makes the implementation below
  // unreachable. Confirm and remove it.
  return ::mediapipe::UnimplementedError("GPU support is not implemented yet.");
  if (cc->Inputs().Tag("IMAGE_GPU").IsEmpty()) {
    return ::mediapipe::OkStatus();
  }
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
  const Packet& input_packet = cc->Inputs().Tag("IMAGE_GPU").Value();
  const auto& input_buffer = input_packet.Get<mediapipe::GpuBuffer>();
  auto src_tex = gpu_helper_.CreateSourceTexture(input_buffer);

  // Size the destination from the (possibly rotated) crop rect; this also
  // fills transformed_points_ used by GlRender().
  int out_width, out_height;
  GetOutputDimensions(cc, src_tex.width(), src_tex.height(), &out_width,
                      &out_height);
  auto dst_tex = gpu_helper_.CreateDestinationTexture(out_width, out_height);

  // Run cropping shader on GPU.
  {
    gpu_helper_.BindFramebuffer(dst_tex);  // GL_TEXTURE0

    glActiveTexture(GL_TEXTURE1);
    glBindTexture(src_tex.target(), src_tex.name());

    GlRender();

    glActiveTexture(GL_TEXTURE2);
    glBindTexture(GL_TEXTURE_2D, 0);
    glFlush();
  }

  // Send result image in GPU packet.
  auto output = dst_tex.GetFrame<mediapipe::GpuBuffer>();
  cc->Outputs().Tag("IMAGE_GPU").Add(output.release(), cc->InputTimestamp());

  // Cleanup
  src_tex.Release();
  dst_tex.Release();
#endif  // __ANDROID__ or iOS

  return ::mediapipe::OkStatus();
}
// Draws a full-screen quad whose texture coordinates are the transformed
// crop-rect corners (transformed_points_), sampling the source texture so
// the output framebuffer receives exactly the cropped/rotated region.
void ImageCroppingCalculator::GlRender() {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
  // Quad covering clip space [-1,1]^2.
  static const GLfloat square_vertices[] = {
      -1.0f, -1.0f,  // bottom left
      1.0f,  -1.0f,  // bottom right
      -1.0f, 1.0f,   // top left
      1.0f,  1.0f,   // top right
  };
  // Texture coordinates come from GetOutputDimensions(): the crop rect's
  // corners in normalized source-image coordinates.
  const GLfloat* texture_vertices = &transformed_points_[0];

  // program
  glUseProgram(program_);

  // vertex storage
  GLuint vbo[2];
  glGenBuffers(2, vbo);
  GLuint vao;
  glGenVertexArrays(1, &vao);
  glBindVertexArray(vao);

  // vbo 0: clip-space positions.
  glBindBuffer(GL_ARRAY_BUFFER, vbo[0]);
  glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices,
               GL_STATIC_DRAW);
  glEnableVertexAttribArray(ATTRIB_VERTEX);
  glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr);

  // vbo 1: crop-rect texture coordinates.
  glBindBuffer(GL_ARRAY_BUFFER, vbo[1]);
  glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices,
               GL_STATIC_DRAW);
  glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION);
  glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr);

  // draw
  glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);

  // cleanup
  glDisableVertexAttribArray(ATTRIB_VERTEX);
  glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION);
  glBindBuffer(GL_ARRAY_BUFFER, 0);
  glBindVertexArray(0);
  glDeleteVertexArrays(1, &vao);
  glDeleteBuffers(2, vbo);
#endif  // __ANDROID__ or iOS
}
// Compiles the pass-through shader program used by GlRender() and binds its
// sampler to texture unit 1. Called once, lazily, inside the GL context.
::mediapipe::Status ImageCroppingCalculator::InitGpu(CalculatorContext* cc) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
  const GLint attr_location[NUM_ATTRIBUTES] = {
      ATTRIB_VERTEX,
      ATTRIB_TEXTURE_POSITION,
  };
  const GLchar* attr_name[NUM_ATTRIBUTES] = {
      "position",
      "texture_coordinate",
  };

  // Simple pass-through shader; cropping is achieved entirely through the
  // texture coordinates supplied by GlRender().
  const GLchar* frag_src = GLES_VERSION_COMPAT
      R"(
  #if __VERSION__ < 130
    #define in varying
  #endif  // __VERSION__ < 130

  #ifdef GL_ES
    #define fragColor gl_FragColor
    precision highp float;
  #else
    #define lowp
    #define mediump
    #define highp
    #define texture2D texture
    out vec4 fragColor;
  #endif  // defined(GL_ES)

    in vec2 sample_coordinate;
    uniform sampler2D input_frame;

    void main() {
      vec4 pix = texture2D(input_frame, sample_coordinate);
      fragColor = pix;
    }
  )";

  // Program
  mediapipe::GlhCreateProgram(mediapipe::kBasicVertexShader, frag_src,
                              NUM_ATTRIBUTES, &attr_name[0], attr_location,
                              &program_);
  RET_CHECK(program_) << "Problem initializing the program.";

  // Parameters: input is sampled from texture unit 1 (bound in RenderGpu).
  glUseProgram(program_);
  glUniform1i(glGetUniformLocation(program_, "input_frame"), 1);
#endif  // __ANDROID__ or iOS

  return ::mediapipe::OkStatus();
}
// For GPU only.
// Determines the crop rectangle (from RECT / NORM_RECT streams, the
// WIDTH/HEIGHT streams, or the calculator options, in that priority order),
// writes its rotated corners into transformed_points_ as normalized source
// coordinates, and returns the axis-aligned bounding size of those corners
// as the destination texture dimensions.
void ImageCroppingCalculator::GetOutputDimensions(CalculatorContext* cc,
                                                  int src_width, int src_height,
                                                  int* dst_width,
                                                  int* dst_height) {
  // Defaults: crop the full frame about its center with no rotation.
  int crop_width = src_width;
  int crop_height = src_height;
  int x_center = src_width / 2;
  int y_center = src_height / 2;
  float rotation = 0.0f;

  if (cc->Inputs().HasTag("RECT")) {
    const auto& rect = cc->Inputs().Tag("RECT").Get<Rect>();
    // Only use the rect if it is valid.
    const bool rect_is_valid = rect.width() > 0 && rect.height() > 0 &&
                               rect.x_center() >= 0 && rect.y_center() >= 0;
    if (rect_is_valid) {
      x_center = rect.x_center();
      y_center = rect.y_center();
      crop_width = rect.width();
      crop_height = rect.height();
      rotation = rect.rotation();
    }
  } else if (cc->Inputs().HasTag("NORM_RECT")) {
    const auto& rect = cc->Inputs().Tag("NORM_RECT").Get<NormalizedRect>();
    // Only use the rect if it is valid; scale normalized values to pixels.
    const bool rect_is_valid = rect.width() > 0.0 && rect.height() > 0.0 &&
                               rect.x_center() >= 0.0 && rect.y_center() >= 0.0;
    if (rect_is_valid) {
      x_center = std::round(rect.x_center() * src_width);
      y_center = std::round(rect.y_center() * src_height);
      crop_width = std::round(rect.width() * src_width);
      crop_height = std::round(rect.height() * src_height);
      rotation = rect.rotation();
    }
  } else {
    // No rect stream: size comes from WIDTH/HEIGHT streams or options,
    // rotation always from options.
    if (cc->Inputs().HasTag("WIDTH") && cc->Inputs().HasTag("HEIGHT")) {
      crop_width = cc->Inputs().Tag("WIDTH").Get<int>();
      crop_height = cc->Inputs().Tag("HEIGHT").Get<int>();
    } else if (options_.has_width() && options_.has_height()) {
      crop_width = options_.width();
      crop_height = options_.height();
    }
    rotation = options_.rotation();
  }

  // Rotate the rect's corners about its center and normalize into [0,1]
  // source-image coordinates. Order: BL, BR, TL, TR (matches GlRender).
  const float half_width = crop_width / 2.0f;
  const float half_height = crop_height / 2.0f;
  const float cos_r = std::cos(rotation);
  const float sin_r = std::sin(rotation);
  const float local_corners[] = {-half_width, -half_height,  //
                                 half_width,  -half_height,  //
                                 -half_width, half_height,   //
                                 half_width,  half_height};
  for (int i = 0; i < 4; ++i) {
    const float cx = local_corners[i * 2];
    const float cy = local_corners[i * 2 + 1];
    const float rotated_x = cos_r * cx - sin_r * cy;
    const float rotated_y = sin_r * cx + cos_r * cy;
    transformed_points_[i * 2] = ((rotated_x + x_center) / src_width);
    transformed_points_[i * 2 + 1] = ((rotated_y + y_center) / src_height);
  }

  // The destination size is the axis-aligned bounding box of the rotated
  // corners, scaled back to pixels.
  float col_min = transformed_points_[0];
  float col_max = transformed_points_[0];
  float row_min = transformed_points_[1];
  float row_max = transformed_points_[1];
  for (int i = 1; i < 4; ++i) {
    col_min = std::min(col_min, transformed_points_[i * 2]);
    col_max = std::max(col_max, transformed_points_[i * 2]);
    row_min = std::min(row_min, transformed_points_[i * 2 + 1]);
    row_max = std::max(row_max, transformed_points_[i * 2 + 1]);
  }
  *dst_width = std::round((col_max - col_min) * src_width);
  *dst_height = std::round((row_max - row_min) * src_height);
}
} // namespace mediapipe

View File

@ -0,0 +1,33 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
// Crop parameters for ImageCroppingCalculator; input streams (RECT,
// NORM_RECT, WIDTH, HEIGHT) take precedence over these options.
message ImageCroppingCalculatorOptions {
  extend CalculatorOptions {
    optional ImageCroppingCalculatorOptions ext = 262466399;
  }

  // Output texture buffer dimensions. The values defined in the options will
  // be overridden by the WIDTH and HEIGHT input streams if they exist.
  optional int32 width = 1;
  optional int32 height = 2;

  // Rotation angle is counter-clockwise in radian.
  optional float rotation = 3 [default = 0.0];
}

View File

@ -0,0 +1,93 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
#include "mediapipe/gpu/gpu_buffer.h"
#endif // __ANDROID__ or iOS
namespace mediapipe {
// Extracts image properties from the input image and outputs the properties.
// Currently only supports image size.
// Input:
// One of the following:
// IMAGE: An ImageFrame
// IMAGE_GPU: A GpuBuffer
//
// Output:
// SIZE: Size (as a std::pair<int, int>) of the input image.
//
// Example usage:
// node {
// calculator: "ImagePropertiesCalculator"
// input_stream: "IMAGE:image"
// output_stream: "SIZE:size"
// }
// Extracts image properties (currently only the size) from the input image.
//
// Inputs (exactly one of):
//   IMAGE:     An ImageFrame (CPU path).
//   IMAGE_GPU: A GpuBuffer (Android/iOS GPU path only).
// Outputs:
//   SIZE: std::pair<int, int> of (width, height), emitted at the input
//         timestamp. The output stream is optional.
class ImagePropertiesCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    // Exactly one of the CPU or GPU input streams must be connected.
    RET_CHECK(cc->Inputs().HasTag("IMAGE") ^ cc->Inputs().HasTag("IMAGE_GPU"));
    if (cc->Inputs().HasTag("IMAGE")) {
      cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
    }
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
    if (cc->Inputs().HasTag("IMAGE_GPU")) {
      cc->Inputs().Tag("IMAGE_GPU").Set<::mediapipe::GpuBuffer>();
    }
#endif  // __ANDROID__ or iOS
    if (cc->Outputs().HasTag("SIZE")) {
      cc->Outputs().Tag("SIZE").Set<std::pair<int, int>>();
    }
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) override {
    // Output packets carry the same timestamp as the input packet.
    cc->SetOffset(TimestampDiff(0));
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    // Initialize so we never emit indeterminate values when neither branch
    // below fires (e.g. IMAGE_GPU input on a platform where the GPU path is
    // compiled out, or an empty input packet).
    int width = 0;
    int height = 0;
    if (cc->Inputs().HasTag("IMAGE") && !cc->Inputs().Tag("IMAGE").IsEmpty()) {
      const auto& image = cc->Inputs().Tag("IMAGE").Get<ImageFrame>();
      width = image.Width();
      height = image.Height();
    }
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
    if (cc->Inputs().HasTag("IMAGE_GPU") &&
        !cc->Inputs().Tag("IMAGE_GPU").IsEmpty()) {
      const auto& image =
          cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>();
      width = image.width();
      height = image.height();
    }
#endif  // __ANDROID__ or iOS
    // Guard the output: GetContract() declares SIZE as optional, so Process()
    // must not dereference the tag unconditionally.
    if (cc->Outputs().HasTag("SIZE")) {
      cc->Outputs().Tag("SIZE").AddPacket(
          MakePacket<std::pair<int, int>>(width, height)
              .At(cc->InputTimestamp()));
    }
    return ::mediapipe::OkStatus();
  }
};
REGISTER_CALCULATOR(ImagePropertiesCalculator);
} // namespace mediapipe

View File

@ -22,12 +22,12 @@
#include "mediapipe/framework/port/status.h"
#include "mediapipe/gpu/scale_mode.pb.h"
#if defined(__ANDROID__)
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_quad_renderer.h"
#include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/shader_util.h"
#endif // __ANDROID__
#endif // __ANDROID__ || iOS
#if defined(__ANDROID__)
// The size of Java arrays is dynamic, which makes it difficult to
@ -42,9 +42,9 @@ typedef int DimensionsPacketType[2];
namespace mediapipe {
#if defined(__ANDROID__)
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX
#endif // __ANDROID__
#endif // __ANDROID__ || iOS
namespace {
int RotationModeToDegrees(mediapipe::RotationMode_Mode rotation) {
@ -170,11 +170,12 @@ class ImageTransformationCalculator : public CalculatorBase {
mediapipe::ScaleMode_Mode scale_mode_;
bool use_gpu_ = false;
#if defined(__ANDROID__)
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX
GlCalculatorHelper helper_;
std::unique_ptr<QuadRenderer> rgb_renderer_;
std::unique_ptr<QuadRenderer> yuv_renderer_;
std::unique_ptr<QuadRenderer> ext_rgb_renderer_;
#endif // __ANDROID__
#endif // __ANDROID__ || iOS
};
REGISTER_CALCULATOR(ImageTransformationCalculator);
@ -189,13 +190,13 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
cc->Outputs().Tag("IMAGE").Set<ImageFrame>();
}
#if defined(__ANDROID__)
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX
if (cc->Inputs().HasTag("IMAGE_GPU")) {
RET_CHECK(cc->Outputs().HasTag("IMAGE_GPU"));
cc->Inputs().Tag("IMAGE_GPU").Set<GpuBuffer>();
cc->Outputs().Tag("IMAGE_GPU").Set<GpuBuffer>();
}
#endif // __ANDROID__
#endif // __ANDROID__ || iOS
if (cc->Inputs().HasTag("ROTATION_DEGREES")) {
cc->Inputs().Tag("ROTATION_DEGREES").Set<int>();
}
@ -211,9 +212,9 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
cc->Outputs().Tag("LETTERBOX_PADDING").Set<std::array<float, 4>>();
}
#if defined(__ANDROID__)
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX
RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc));
#endif // __ANDROID__
#endif // __ANDROID__ || iOS
return ::mediapipe::OkStatus();
}
@ -221,7 +222,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
::mediapipe::Status ImageTransformationCalculator::Open(CalculatorContext* cc) {
// Inform the framework that we always output at the same timestamp
// as we receive a packet at.
cc->SetOffset(mediapipe::TimestampDiff(0));
cc->SetOffset(TimestampDiff(0));
options_ = cc->Options<ImageTransformationCalculatorOptions>();
@ -249,12 +250,12 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
scale_mode_ = ParseScaleMode(options_.scale_mode(), DEFAULT_SCALE_MODE);
if (use_gpu_) {
#if defined(__ANDROID__)
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX
// Let the helper access the GL context information.
RETURN_IF_ERROR(helper_.Open(cc));
#else
RET_CHECK_FAIL() << "GPU processing for non-Android not supported yet.";
#endif // __ANDROID__
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only.";
#endif // __ANDROID__ || iOS
}
return ::mediapipe::OkStatus();
@ -263,10 +264,10 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
::mediapipe::Status ImageTransformationCalculator::Process(
CalculatorContext* cc) {
if (use_gpu_) {
#if defined(__ANDROID__)
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX
return helper_.RunInGlContext(
[this, cc]() -> ::mediapipe::Status { return RenderGpu(cc); });
#endif // __ANDROID__
#endif // __ANDROID__ || iOS
} else {
return RenderCpu(cc);
}
@ -276,10 +277,11 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
::mediapipe::Status ImageTransformationCalculator::Close(
CalculatorContext* cc) {
if (use_gpu_) {
#if defined(__ANDROID__)
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX
QuadRenderer* rgb_renderer = rgb_renderer_.release();
QuadRenderer* yuv_renderer = yuv_renderer_.release();
QuadRenderer* ext_rgb_renderer = ext_rgb_renderer_.release();
helper_.RunInGlContext([rgb_renderer, ext_rgb_renderer] {
helper_.RunInGlContext([rgb_renderer, yuv_renderer, ext_rgb_renderer] {
if (rgb_renderer) {
rgb_renderer->GlTeardown();
delete rgb_renderer;
@ -288,10 +290,13 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
ext_rgb_renderer->GlTeardown();
delete ext_rgb_renderer;
}
if (yuv_renderer) {
yuv_renderer->GlTeardown();
delete yuv_renderer;
}
});
#endif // __ANDROID__
#endif // __ANDROID__ || iOS
}
return ::mediapipe::OkStatus();
}
@ -366,7 +371,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
::mediapipe::Status ImageTransformationCalculator::RenderGpu(
CalculatorContext* cc) {
#if defined(__ANDROID__)
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX
int input_width = cc->Inputs().Tag("IMAGE_GPU").Get<GpuBuffer>().width();
int input_height = cc->Inputs().Tag("IMAGE_GPU").Get<GpuBuffer>().height();
@ -387,8 +392,23 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<GpuBuffer>();
QuadRenderer* renderer = nullptr;
GlTexture src1;
#if defined(__APPLE__) && !TARGET_OS_OSX
if (input.format() == GpuBufferFormat::kBiPlanar420YpCbCr8VideoRange ||
input.format() == GpuBufferFormat::kBiPlanar420YpCbCr8FullRange) {
if (!yuv_renderer_) {
yuv_renderer_ = absl::make_unique<QuadRenderer>();
RETURN_IF_ERROR(
yuv_renderer_->GlSetup(::mediapipe::kYUV2TexToRGBFragmentShader,
{"video_frame_y", "video_frame_uv"}));
}
renderer = yuv_renderer_.get();
src1 = helper_.CreateSourceTexture(input, 0);
} else // NOLINT(readability/braces)
#endif // iOS
{
src1 = helper_.CreateSourceTexture(input);
#if defined(__ANDROID__)
if (src1.target() == GL_TEXTURE_EXTERNAL_OES) {
if (!ext_rgb_renderer_) {
ext_rgb_renderer_ = absl::make_unique<QuadRenderer>();
@ -396,7 +416,9 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
::mediapipe::kBasicTexturedFragmentShaderOES, {"video_frame"}));
}
renderer = ext_rgb_renderer_.get();
} else {
} else // NOLINT(readability/braces)
#endif // __ANDROID__
{
if (!rgb_renderer_) {
rgb_renderer_ = absl::make_unique<QuadRenderer>();
RETURN_IF_ERROR(rgb_renderer_->GlSetup());
@ -438,7 +460,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
auto output = dst.GetFrame<GpuBuffer>();
cc->Outputs().Tag("IMAGE_GPU").Add(output.release(), cc->InputTimestamp());
#endif // __ANDROID__
#endif // __ANDROID__ || iOS
return ::mediapipe::OkStatus();
}

View File

@ -45,11 +45,11 @@ class OpenCvPutTextCalculator : public CalculatorBase {
::mediapipe::Status OpenCvPutTextCalculator::Process(CalculatorContext* cc) {
const std::string& text_content = cc->Inputs().Index(0).Get<std::string>();
cv::Mat mat = cv::Mat::zeros(640, 640, CV_8UC3);
cv::Mat mat = cv::Mat::zeros(640, 640, CV_8UC4);
cv::putText(mat, text_content, cv::Point(15, 70), cv::FONT_HERSHEY_PLAIN, 3,
cv::Scalar(255, 255, 0), 4);
cv::Scalar(255, 255, 0, 255), 4);
std::unique_ptr<ImageFrame> output_frame = absl::make_unique<ImageFrame>(
ImageFormat::SRGB, mat.size().width, mat.size().height);
ImageFormat::SRGBA, mat.size().width, mat.size().height);
mat.copyTo(formats::MatView(output_frame.get()));
cc->Outputs().Index(0).Add(output_frame.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();

View File

@ -21,12 +21,12 @@
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/color.pb.h"
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "mediapipe/gpu/shader_util.h"
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
namespace {
enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
@ -95,10 +95,10 @@ class RecolorCalculator : public CalculatorBase {
mediapipe::RecolorCalculatorOptions::MaskChannel mask_channel_;
bool use_gpu_ = false;
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
mediapipe::GlCalculatorHelper gpu_helper_;
GLuint program_ = 0;
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
};
REGISTER_CALCULATOR(RecolorCalculator);
@ -107,46 +107,48 @@ REGISTER_CALCULATOR(RecolorCalculator);
RET_CHECK(!cc->Inputs().GetTags().empty());
RET_CHECK(!cc->Outputs().GetTags().empty());
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
if (cc->Inputs().HasTag("IMAGE_GPU")) {
cc->Inputs().Tag("IMAGE_GPU").Set<mediapipe::GpuBuffer>();
}
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
if (cc->Inputs().HasTag("IMAGE")) {
cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
}
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
if (cc->Inputs().HasTag("MASK_GPU")) {
cc->Inputs().Tag("MASK_GPU").Set<mediapipe::GpuBuffer>();
}
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
if (cc->Inputs().HasTag("MASK")) {
cc->Inputs().Tag("MASK").Set<ImageFrame>();
}
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
if (cc->Outputs().HasTag("IMAGE_GPU")) {
cc->Outputs().Tag("IMAGE_GPU").Set<mediapipe::GpuBuffer>();
}
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
if (cc->Outputs().HasTag("IMAGE")) {
cc->Outputs().Tag("IMAGE").Set<ImageFrame>();
}
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
return ::mediapipe::OkStatus();
}
::mediapipe::Status RecolorCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
if (cc->Inputs().HasTag("IMAGE_GPU")) {
use_gpu_ = true;
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
RETURN_IF_ERROR(gpu_helper_.Open(cc));
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
}
RETURN_IF_ERROR(LoadOptions(cc));
@ -156,7 +158,7 @@ REGISTER_CALCULATOR(RecolorCalculator);
::mediapipe::Status RecolorCalculator::Process(CalculatorContext* cc) {
if (use_gpu_) {
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
if (!initialized_) {
@ -166,7 +168,7 @@ REGISTER_CALCULATOR(RecolorCalculator);
RETURN_IF_ERROR(RenderGpu(cc));
return ::mediapipe::OkStatus();
}));
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
} else {
RETURN_IF_ERROR(RenderCpu(cc));
}
@ -174,12 +176,12 @@ REGISTER_CALCULATOR(RecolorCalculator);
}
::mediapipe::Status RecolorCalculator::Close(CalculatorContext* cc) {
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
gpu_helper_.RunInGlContext([this] {
if (program_) glDeleteProgram(program_);
program_ = 0;
});
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
return ::mediapipe::OkStatus();
}
@ -192,7 +194,7 @@ REGISTER_CALCULATOR(RecolorCalculator);
if (cc->Inputs().Tag("MASK_GPU").IsEmpty()) {
return ::mediapipe::OkStatus();
}
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
// Get inputs and setup output.
const Packet& input_packet = cc->Inputs().Tag("IMAGE_GPU").Value();
const Packet& mask_packet = cc->Inputs().Tag("MASK_GPU").Value();
@ -231,13 +233,13 @@ REGISTER_CALCULATOR(RecolorCalculator);
img_tex.Release();
mask_tex.Release();
dst_tex.Release();
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
return ::mediapipe::OkStatus();
}
void RecolorCalculator::GlRender() {
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
static const GLfloat square_vertices[] = {
-1.0f, -1.0f, // bottom left
1.0f, -1.0f, // bottom right
@ -285,7 +287,7 @@ void RecolorCalculator::GlRender() {
glBindVertexArray(0);
glDeleteVertexArrays(1, &vao);
glDeleteBuffers(2, vbo);
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
}
::mediapipe::Status RecolorCalculator::LoadOptions(CalculatorContext* cc) {
@ -303,7 +305,7 @@ void RecolorCalculator::GlRender() {
}
::mediapipe::Status RecolorCalculator::InitGpu(CalculatorContext* cc) {
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
const GLint attr_location[NUM_ATTRIBUTES] = {
ATTRIB_VERTEX,
ATTRIB_TEXTURE_POSITION,
@ -372,7 +374,7 @@ void RecolorCalculator::GlRender() {
glUniform1i(glGetUniformLocation(program_, "mask"), 2);
glUniform3f(glGetUniformLocation(program_, "recolor"), color_[0], color_[1],
color_[2]);
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
return ::mediapipe::OkStatus();
}

View File

@ -25,11 +25,12 @@
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/vector.h"
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "mediapipe/gpu/shader_util.h"
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
namespace mediapipe {
@ -98,18 +99,18 @@ class SetAlphaCalculator : public CalculatorBase {
::mediapipe::Status RenderGpu(CalculatorContext* cc);
::mediapipe::Status RenderCpu(CalculatorContext* cc);
::mediapipe::Status GlRender(CalculatorContext* cc);
::mediapipe::Status GlSetup(CalculatorContext* cc);
void GlRender(CalculatorContext* cc);
mediapipe::SetAlphaCalculatorOptions options_;
float alpha_value_ = -1.f;
bool use_gpu_ = false;
bool gpu_initialized_ = false;
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
mediapipe::GlCalculatorHelper gpu_helper_;
GLuint program_ = 0;
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
};
REGISTER_CALCULATOR(SetAlphaCalculator);
@ -126,52 +127,54 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
}
// Input image to add/edit alpha channel.
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
if (cc->Inputs().HasTag(kInputFrameTagGpu)) {
cc->Inputs().Tag(kInputFrameTagGpu).Set<mediapipe::GpuBuffer>();
}
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
if (cc->Inputs().HasTag(kInputFrameTag)) {
cc->Inputs().Tag(kInputFrameTag).Set<ImageFrame>();
}
// Input alpha image mask (optional)
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
if (cc->Inputs().HasTag(kInputAlphaTagGpu)) {
cc->Inputs().Tag(kInputAlphaTagGpu).Set<mediapipe::GpuBuffer>();
}
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
if (cc->Inputs().HasTag(kInputAlphaTag)) {
cc->Inputs().Tag(kInputAlphaTag).Set<ImageFrame>();
}
// RGBA output image.
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
if (cc->Outputs().HasTag(kOutputFrameTagGpu)) {
cc->Outputs().Tag(kOutputFrameTagGpu).Set<mediapipe::GpuBuffer>();
}
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
if (cc->Outputs().HasTag(kOutputFrameTag)) {
cc->Outputs().Tag(kOutputFrameTag).Set<ImageFrame>();
}
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
return ::mediapipe::OkStatus();
}
::mediapipe::Status SetAlphaCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
options_ = cc->Options<mediapipe::SetAlphaCalculatorOptions>();
if (cc->Inputs().HasTag(kInputFrameTagGpu) &&
cc->Outputs().HasTag(kOutputFrameTagGpu)) {
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
use_gpu_ = true;
#else
RET_CHECK_FAIL() << "GPU processing on non-Android not supported yet.";
#endif // __ANDROID__
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only.";
#endif // __ANDROID__ or iOS
}
// Get global value from options (-1 if not set).
@ -184,7 +187,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
RET_CHECK_FAIL() << "Must use either image mask or options alpha value.";
if (use_gpu_) {
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
RETURN_IF_ERROR(gpu_helper_.Open(cc));
#endif
}
@ -194,7 +197,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
::mediapipe::Status SetAlphaCalculator::Process(CalculatorContext* cc) {
if (use_gpu_) {
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
if (!gpu_initialized_) {
@ -204,7 +207,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
RETURN_IF_ERROR(RenderGpu(cc));
return ::mediapipe::OkStatus();
}));
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
} else {
RETURN_IF_ERROR(RenderCpu(cc));
}
@ -213,12 +216,12 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
}
::mediapipe::Status SetAlphaCalculator::Close(CalculatorContext* cc) {
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
gpu_helper_.RunInGlContext([this] {
if (program_) glDeleteProgram(program_);
program_ = 0;
});
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
return ::mediapipe::OkStatus();
}
@ -292,7 +295,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
if (cc->Inputs().Tag(kInputFrameTagGpu).IsEmpty()) {
return ::mediapipe::OkStatus();
}
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
// Setup source texture.
const auto& input_frame =
cc->Inputs().Tag(kInputFrameTagGpu).Get<mediapipe::GpuBuffer>();
@ -345,13 +348,13 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
// Cleanup
input_texture.Release();
output_texture.Release();
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
return ::mediapipe::OkStatus();
}
::mediapipe::Status SetAlphaCalculator::GlRender(CalculatorContext* cc) {
#if defined(__ANDROID__)
void SetAlphaCalculator::GlRender(CalculatorContext* cc) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
static const GLfloat square_vertices[] = {
-1.0f, -1.0f, // bottom left
1.0f, -1.0f, // bottom right
@ -400,13 +403,11 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
glDeleteVertexArrays(1, &vao);
glDeleteBuffers(2, vbo);
#endif // __ANDROID__
return ::mediapipe::OkStatus();
#endif // __ANDROID__ or iOS
}
::mediapipe::Status SetAlphaCalculator::GlSetup(CalculatorContext* cc) {
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
const GLint attr_location[NUM_ATTRIBUTES] = {
ATTRIB_VERTEX,
ATTRIB_TEXTURE_POSITION,
@ -417,6 +418,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
};
// Shader to overlay a texture onto another when overlay is non-zero.
// TODO split into two shaders to handle alpha_value<0 separately
const GLchar* frag_src = GLES_VERSION_COMPAT
R"(
#if __VERSION__ < 130
@ -458,7 +460,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
glUniform1i(glGetUniformLocation(program_, "alpha_mask"), 2);
glUniform1f(glGetUniformLocation(program_, "alpha_value"), alpha_value_);
#endif // __ANDROID__
#endif // __ANDROID__ or iOS
return ::mediapipe::OkStatus();
}

View File

@ -1,4 +1,17 @@
// Options for SetAlphaCalculator
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;

View File

@ -417,6 +417,9 @@ cc_library(
"//mediapipe:android": [
"@org_tensorflow//tensorflow/core:android_tensorflow_lib_lite_nortti_lite_protos",
],
"//mediapipe:ios": [
"@org_tensorflow//tensorflow/core:ios_tensorflow_lib",
],
}),
alwayslink = 1,
)
@ -435,6 +438,9 @@ cc_library(
"//mediapipe:android": [
"@org_tensorflow//tensorflow/core:android_tensorflow_lib_lite_nortti_lite_protos",
],
"//mediapipe:ios": [
"@org_tensorflow//tensorflow/core:ios_tensorflow_lib",
],
}),
)
@ -459,6 +465,10 @@ cc_library(
"@org_tensorflow//tensorflow/core:android_tensorflow_lib_lite_nortti_lite_protos",
"//mediapipe/android/file/base",
],
"//mediapipe:ios": [
"@org_tensorflow//tensorflow/core:ios_tensorflow_lib",
"//mediapipe/android/file/base",
],
}),
alwayslink = 1,
)
@ -783,6 +793,7 @@ cc_test(
"@org_tensorflow//tensorflow/core/kernels:io",
"@org_tensorflow//tensorflow/core/kernels:state",
"@org_tensorflow//tensorflow/core/kernels:string",
"@org_tensorflow//tensorflow/core/kernels/data:tensor_dataset_op",
],
)
@ -813,6 +824,7 @@ cc_test(
"@org_tensorflow//tensorflow/core/kernels:io",
"@org_tensorflow//tensorflow/core/kernels:state",
"@org_tensorflow//tensorflow/core/kernels:string",
"@org_tensorflow//tensorflow/core/kernels/data:tensor_dataset_op",
],
)
@ -949,6 +961,9 @@ cc_test(
"//mediapipe:android": [
"@org_tensorflow//tensorflow/core:android_tensorflow_lib_with_ops_lite_proto_no_rtti_lib",
],
"//mediapipe:ios": [
"@org_tensorflow//tensorflow/core:ios_tensorflow_test_lib",
],
}),
)

View File

@ -118,7 +118,7 @@ REGISTER_CALCULATOR(MatrixToTensorCalculator);
// Inform the framework that we always output at the same timestamp
// as we receive a packet at.
cc->SetOffset(mediapipe::TimestampDiff(0));
cc->SetOffset(TimestampDiff(0));
return ::mediapipe::OkStatus();
}

View File

@ -37,6 +37,7 @@ const char kImageTag[] = "IMAGE";
const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_";
const char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED";
const char kBBoxTag[] = "BBOX";
const char kKeypointsTag[] = "KEYPOINTS";
const char kSegmentationMaskTag[] = "CLASS_SEGMENTATION";
namespace tf = ::tensorflow;
@ -54,16 +55,11 @@ namespace mpms = ::mediapipe::mediasequence;
// stores the encoded optical flow from the same calculator, "BBOX" which stores
// bounding boxes from vector<Detections>, and streams with the
// "FLOAT_FEATURE_${NAME}" pattern, which stores the values from vector<float>'s
// associated with the name ${NAME}. Audio streams (i.e. Matrix with a
// TimeSeriesHeader) are given extra packing and unpacking support and are named
// similar to floats with the pattern "AUDIO_${NAME}". "IMAGE_${NAME}" and
// "BBOX_${NAME}" will also store prefixed versions of each stream, which allows
// for multiple image streams to be included. However, the default names are
// suppored by more tools. "ENCODED_MEDIA" stores a video encoding for the clip
// directly. The last packet on this stream is stored, and can be unpacked with
// the metadata generator. Because the media decoder always starts with
// timestamp zero, the "ENCODED_MEDIA_START_TIMESTAMP" should be recorded as
// well. Use the FirstTimestampCalculator to determine this value.
// associated with the name ${NAME}. "KEYPOINTS" stores a map of 2D keypoints
// from unordered_map<std::string, vector<pair<float, float>>>. "IMAGE_${NAME}",
// "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store prefixed versions of
// each stream, which allows for multiple image streams to be included. However,
// the default names are supported by more tools.
//
// Example config:
// node {
@ -122,6 +118,21 @@ class PackMediaSequenceCalculator : public CalculatorBase {
}
cc->Inputs().Tag(tag).Set<OpenCvImageEncoderCalculatorResults>();
}
if (absl::StartsWith(tag, kKeypointsTag)) {
std::string key = "";
if (tag != kKeypointsTag) {
int tag_length = sizeof(kKeypointsTag) / sizeof(*kKeypointsTag) - 1;
if (tag[tag_length] == '_') {
key = tag.substr(tag_length + 1);
} else {
continue; // Skip keys that don't match "(kKeypointsTag)_?"
}
}
cc->Inputs()
.Tag(tag)
.Set<std::unordered_map<std::string,
std::vector<std::pair<float, float>>>>();
}
if (absl::StartsWith(tag, kBBoxTag)) {
std::string key = "";
if (tag != kBBoxTag) {
@ -169,8 +180,7 @@ class PackMediaSequenceCalculator : public CalculatorBase {
features_present_[tag] = false;
}
if (cc->Options()
.GetExtension(PackMediaSequenceCalculatorOptions::ext)
if (cc->Options<PackMediaSequenceCalculatorOptions>()
.replace_data_instead_of_append()) {
for (const auto& tag : cc->Inputs().GetTags()) {
if (absl::StartsWith(tag, kImageTag)) {
@ -186,13 +196,19 @@ class PackMediaSequenceCalculator : public CalculatorBase {
mpms::ClearImageEncoded(key, sequence_.get());
mpms::ClearImageTimestamp(key, sequence_.get());
}
}
if (cc->Inputs().HasTag(kForwardFlowEncodedTag)) {
mpms::ClearForwardFlowEncoded(sequence_.get());
mpms::ClearForwardFlowTimestamp(sequence_.get());
}
for (const auto& tag : cc->Inputs().GetTags()) {
if (absl::StartsWith(tag, kBBoxTag)) {
std::string key = "";
if (tag != kBBoxTag) {
int tag_length = sizeof(kBBoxTag) / sizeof(*kBBoxTag) - 1;
if (tag[tag_length] == '_') {
key = tag.substr(tag_length + 1);
} else {
continue; // Skip keys that don't match "(kBBoxTag)_?"
}
}
mpms::ClearBBox(key, sequence_.get());
mpms::ClearBBoxTimestamp(key, sequence_.get());
}
if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) {
std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) /
sizeof(*kFloatFeaturePrefixTag) -
@ -200,6 +216,16 @@ class PackMediaSequenceCalculator : public CalculatorBase {
mpms::ClearFeatureFloats(key, sequence_.get());
mpms::ClearFeatureTimestamp(key, sequence_.get());
}
if (absl::StartsWith(tag, kKeypointsTag)) {
std::string key =
tag.substr(sizeof(kKeypointsTag) / sizeof(*kKeypointsTag) - 1);
mpms::ClearBBoxPoint(key, sequence_.get());
mpms::ClearBBoxTimestamp(key, sequence_.get());
}
}
if (cc->Inputs().HasTag(kForwardFlowEncodedTag)) {
mpms::ClearForwardFlowEncoded(sequence_.get());
mpms::ClearForwardFlowTimestamp(sequence_.get());
}
}
@ -228,11 +254,11 @@ class PackMediaSequenceCalculator : public CalculatorBase {
}
::mediapipe::Status Close(CalculatorContext* cc) override {
auto& options =
cc->Options().GetExtension(PackMediaSequenceCalculatorOptions::ext);
auto& options = cc->Options<PackMediaSequenceCalculatorOptions>();
if (options.reconcile_metadata()) {
RET_CHECK_OK(mpms::ReconcileMetadata(options.reconcile_bbox_annotations(),
sequence_.get()));
RET_CHECK_OK(mpms::ReconcileMetadata(
options.reconcile_bbox_annotations(),
options.reconcile_region_annotations(), sequence_.get()));
}
if (options.output_only_if_all_present()) {
@ -260,6 +286,9 @@ class PackMediaSequenceCalculator : public CalculatorBase {
::mediapipe::Status Process(CalculatorContext* cc) override {
for (const auto& tag : cc->Inputs().GetTags()) {
if (!cc->Inputs().Tag(tag).IsEmpty()) {
features_present_[tag] = true;
}
if (absl::StartsWith(tag, kImageTag) &&
!cc->Inputs().Tag(tag).IsEmpty()) {
std::string key = "";
@ -281,23 +310,29 @@ class PackMediaSequenceCalculator : public CalculatorBase {
sequence_.get());
mpms::AddImageEncoded(key, image.encoded_image(), sequence_.get());
}
}
if (cc->Inputs().HasTag(kForwardFlowEncodedTag) &&
!cc->Inputs().Tag(kForwardFlowEncodedTag).IsEmpty()) {
const OpenCvImageEncoderCalculatorResults& forward_flow =
cc->Inputs()
.Tag(kForwardFlowEncodedTag)
.Get<OpenCvImageEncoderCalculatorResults>();
if (!forward_flow.has_encoded_image()) {
return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "No encoded forward flow";
if (absl::StartsWith(tag, kKeypointsTag) &&
!cc->Inputs().Tag(tag).IsEmpty()) {
std::string key = "";
if (tag != kImageTag) {
int tag_length = sizeof(kKeypointsTag) / sizeof(*kKeypointsTag) - 1;
if (tag[tag_length] == '_') {
key = tag.substr(tag_length + 1);
} else {
continue; // Skip keys that don't match "(kKeypointsTag)_?"
}
}
const auto& keypoints =
cc->Inputs()
.Tag(tag)
.Get<std::unordered_map<
std::string, std::vector<std::pair<float, float>>>>();
for (const auto& pair : keypoints) {
mpms::AddBBoxTimestamp(mpms::merge_prefix(key, pair.first),
cc->InputTimestamp().Value(), sequence_.get());
mpms::AddBBoxPoint(mpms::merge_prefix(key, pair.first), pair.second,
sequence_.get());
}
}
mpms::AddForwardFlowTimestamp(cc->InputTimestamp().Value(),
sequence_.get());
mpms::AddForwardFlowEncoded(forward_flow.encoded_image(),
sequence_.get());
}
for (const auto& tag : cc->Inputs().GetTags()) {
if (absl::StartsWith(tag, kFloatFeaturePrefixTag) &&
!cc->Inputs().Tag(tag).IsEmpty()) {
std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) /
@ -309,8 +344,6 @@ class PackMediaSequenceCalculator : public CalculatorBase {
cc->Inputs().Tag(tag).Get<std::vector<float>>(),
sequence_.get());
}
}
for (const auto& tag : cc->Inputs().GetTags()) {
if (absl::StartsWith(tag, kBBoxTag) && !cc->Inputs().Tag(tag).IsEmpty()) {
std::string key = "";
if (tag != kBBoxTag) {
@ -349,15 +382,30 @@ class PackMediaSequenceCalculator : public CalculatorBase {
mpms::AddBBoxTimestamp(key, cc->InputTimestamp().Value(),
sequence_.get());
if (!predicted_class_strings.empty()) {
mpms::AddBBoxClassString(key, predicted_class_strings,
mpms::AddBBoxLabelString(key, predicted_class_strings,
sequence_.get());
}
if (!predicted_label_ids.empty()) {
mpms::AddBBoxClassIndex(key, predicted_label_ids, sequence_.get());
mpms::AddBBoxLabelIndex(key, predicted_label_ids, sequence_.get());
}
}
}
}
if (cc->Inputs().HasTag(kForwardFlowEncodedTag) &&
!cc->Inputs().Tag(kForwardFlowEncodedTag).IsEmpty()) {
const OpenCvImageEncoderCalculatorResults& forward_flow =
cc->Inputs()
.Tag(kForwardFlowEncodedTag)
.Get<OpenCvImageEncoderCalculatorResults>();
if (!forward_flow.has_encoded_image()) {
return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "No encoded forward flow";
}
mpms::AddForwardFlowTimestamp(cc->InputTimestamp().Value(),
sequence_.get());
mpms::AddForwardFlowEncoded(forward_flow.encoded_image(),
sequence_.get());
}
if (cc->Inputs().HasTag(kSegmentationMaskTag) &&
!cc->Inputs().Tag(kSegmentationMaskTag).IsEmpty()) {
bool already_has_mask = false;
@ -387,11 +435,6 @@ class PackMediaSequenceCalculator : public CalculatorBase {
}
}
}
for (const auto& tag : cc->Inputs().GetTags()) {
if (!cc->Inputs().Tag(tag).IsEmpty()) {
features_present_[tag] = true;
}
}
return ::mediapipe::OkStatus();
}

View File

@ -32,18 +32,29 @@ message PackMediaSequenceCalculatorOptions {
// (e.g. fills in the image height, width, and number of frames.)
optional bool reconcile_metadata = 2 [default = true];
// If true, updates the metadata for sequences with bounding boxes. This will
// align each bounding box annotation with the nearest frame and insert empty
// annotations as needed to satisfy the frame rate.
// If true, updates the metadata for bounding box portions of sequences. This
// will align each bounding box annotation with the nearest frame and insert
// empty annotations as needed to satisfy the frame rate.
// NOTE: IF YOU DOWNSAMPLE IN TIME YOU WILL LOSE ANNOTATIONS.
// If two or more annotations are closest to the same frame, then only
// the closest annotation is saved. This matches the behavior of
// downsampling images in time.
optional bool reconcile_bbox_annotations = 5 [default = true];
optional bool reconcile_bbox_annotations = 5 [default = false];
// If true, updates the metadata for all regions annotations, regardless of
// prefix, in the sequence. This will align each region annotation with the
// nearest frame and insert empty annotations as needed to satisfy the frame
// rate. This is particularly useful for key point annotations that are
// represented as region points. This does not exclude bboxes.
// NOTE: IF YOU DOWNSAMPLE IN TIME YOU WILL LOSE ANNOTATIONS.
// If two or more annotations are closest to the same frame, then only
// the closest annotation is saved. This matches the behavior of
// downsampling images in time.
optional bool reconcile_region_annotations = 6 [default = true];
// If true, the SequenceExample is output only if all input streams are
// present.
optional bool output_only_if_all_present = 3 [default = false];
optional bool output_only_if_all_present = 3 [default = true];
// If true, will remove all data from a sequence example for a corresponding
// input stream. E.g. if images are already present and an IMAGE stream is

View File

@ -348,15 +348,51 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBBoxDetections) {
ASSERT_NEAR(1.0, rect.ymax(), 0.001);
}
auto class_strings =
mpms::GetPredictedBBoxClassStringAt(output_sequence, i);
mpms::GetPredictedBBoxLabelStringAt(output_sequence, i);
ASSERT_EQ("absolute bbox", class_strings[0]);
ASSERT_EQ("relative bbox", class_strings[1]);
auto class_indices = mpms::GetPredictedBBoxClassIndexAt(output_sequence, i);
auto class_indices = mpms::GetPredictedBBoxLabelIndexAt(output_sequence, i);
ASSERT_EQ(0, class_indices[0]);
ASSERT_EQ(1, class_indices[1]);
}
}
TEST_F(PackMediaSequenceCalculatorTest, PacksTwoKeypoints) {
SetUpCalculator({"KEYPOINTS_TEST:keypoints"}, {}, false, true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
std::string test_video_id = "test_video_id";
mpms::SetClipMediaId(test_video_id, input_sequence.get());
std::unordered_map<std::string, std::vector<std::pair<float, float>>> points =
{{"HEAD", {{0.1, 0.2}, {0.3, 0.4}}}, {"TAIL", {{0.5, 0.6}}}};
runner_->MutableInputs()
->Tag("KEYPOINTS_TEST")
.packets.push_back(PointToForeign(&points).At(Timestamp(0)));
runner_->MutableInputs()
->Tag("KEYPOINTS_TEST")
.packets.push_back(PointToForeign(&points).At(Timestamp(1)));
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(input_sequence.release());
MEDIAPIPE_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence));
ASSERT_EQ(2, mpms::GetBBoxPointSize("TEST/HEAD", output_sequence));
ASSERT_EQ(2, mpms::GetBBoxPointSize("TEST/TAIL", output_sequence));
ASSERT_NEAR(0.2,
mpms::GetBBoxPointAt("TEST/HEAD", output_sequence, 0)[0].second,
0.001);
ASSERT_NEAR(0.5,
mpms::GetBBoxPointAt("TEST/TAIL", output_sequence, 1)[0].first,
0.001);
}
TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) {
SetUpCalculator({"CLASS_SEGMENTATION:detections"}, {}, false, true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
@ -395,7 +431,6 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) {
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
LOG(INFO) << "output_sequence: \n" << output_sequence.DebugString();
ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence));
ASSERT_EQ(height, mpms::GetImageHeight(output_sequence));
@ -602,6 +637,10 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) {
mpms::AddBBoxTimestamp(21, input_sequence.get());
mpms::AddBBoxTimestamp(22, input_sequence.get());
mpms::AddBBoxTimestamp("PREFIX", 8, input_sequence.get());
mpms::AddBBoxTimestamp("PREFIX", 9, input_sequence.get());
mpms::AddBBoxTimestamp("PREFIX", 22, input_sequence.get());
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(input_sequence.release());
MEDIAPIPE_ASSERT_OK(runner_->Run());
@ -617,6 +656,13 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) {
ASSERT_EQ(mpms::GetBBoxTimestampAt(output_sequence, 2), 30);
ASSERT_EQ(mpms::GetBBoxTimestampAt(output_sequence, 3), 40);
ASSERT_EQ(mpms::GetBBoxTimestampAt(output_sequence, 4), 50);
ASSERT_EQ(mpms::GetBBoxTimestampSize("PREFIX", output_sequence), 5);
ASSERT_EQ(mpms::GetBBoxTimestampAt("PREFIX", output_sequence, 0), 10);
ASSERT_EQ(mpms::GetBBoxTimestampAt("PREFIX", output_sequence, 1), 20);
ASSERT_EQ(mpms::GetBBoxTimestampAt("PREFIX", output_sequence, 2), 30);
ASSERT_EQ(mpms::GetBBoxTimestampAt("PREFIX", output_sequence, 3), 40);
ASSERT_EQ(mpms::GetBBoxTimestampAt("PREFIX", output_sequence, 4), 50);
}
} // namespace

View File

@ -183,7 +183,7 @@ REGISTER_CALCULATOR(TensorToMatrixCalculator);
header_ = *input_header;
cc->Outputs().Tag(kMatrix).SetHeader(Adopt(input_header.release()));
}
cc->SetOffset(mediapipe::TimestampDiff(0));
cc->SetOffset(TimestampDiff(0));
return ::mediapipe::OkStatus();
}

View File

@ -232,8 +232,7 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
current_timestamp_index_ = 0;
// Determine the data path and output it.
const auto& options =
cc->Options().GetExtension(UnpackMediaSequenceCalculatorOptions::ext);
const auto& options = cc->Options<UnpackMediaSequenceCalculatorOptions>();
const auto& sequence = cc->InputSidePackets()
.Tag(kSequenceExampleTag)
.Get<tensorflow::SequenceExample>();
@ -338,7 +337,6 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
? Timestamp::PostStream()
: Timestamp(map_kv.second[i]);
LOG(INFO) << "key: " << map_kv.first;
if (absl::StrContains(map_kv.first, mpms::GetImageTimestampKey())) {
std::vector<std::string> pieces = absl::StrSplit(map_kv.first, '/');
std::string feature_key = "";

View File

@ -40,7 +40,7 @@ message UnpackMediaSequenceCalculatorOptions {
optional float extra_padding_from_media_decoder = 5 [default = 0.0];
// Stores the packet resampler settings for the graph. The most accurate
// proceedure for sampling a range of frames is to request a padded time range
// procedure for sampling a range of frames is to request a padded time range
// from the MediaDecoderCalculator and then trim it down to the proper time
// range with the PacketResamplerCalculator.
optional PacketResamplerCalculatorOptions base_packet_resampler_options = 6;

View File

@ -460,6 +460,8 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromExample) {
}
TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) {
// TODO: Suport proto3 proto.Any in CalculatorOptions.
// TODO: Avoid proto2 extensions in "RESAMPLER_OPTIONS".
CalculatorOptions options;
options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext)
->set_padding_before_label(1);

View File

@ -61,6 +61,13 @@ proto_library(
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "tflite_tensors_to_landmarks_calculator_proto",
srcs = ["tflite_tensors_to_landmarks_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
mediapipe_cc_proto_library(
name = "ssd_anchors_calculator_cc_proto",
srcs = ["ssd_anchors_calculator.proto"],
@ -109,6 +116,14 @@ mediapipe_cc_proto_library(
deps = [":tflite_tensors_to_detections_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "tflite_tensors_to_landmarks_calculator_cc_proto",
srcs = ["tflite_tensors_to_landmarks_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//mediapipe:__subpackages__"],
deps = [":tflite_tensors_to_landmarks_calculator_proto"],
)
cc_library(
name = "ssd_anchors_calculator",
srcs = ["ssd_anchors_calculator.cc"],
@ -168,6 +183,21 @@ cc_test(
cc_library(
name = "tflite_inference_calculator",
srcs = ["tflite_inference_calculator.cc"],
copts = select({
"//mediapipe:ios": [
"-std=c++11",
"-x objective-c++",
"-fobjc-arc", # enable reference-counting
],
"//conditions:default": [],
}),
linkopts = select({
"//mediapipe:ios": [
"-framework CoreVideo",
"-framework MetalKit",
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":tflite_inference_calculator_cc_proto",
@ -179,12 +209,17 @@ cc_library(
"//mediapipe/framework/port:ret_check",
] + select({
"//mediapipe:android": [
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader",
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gpu_buffer",
"@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader",
],
"//mediapipe:ios": [
"//mediapipe/gpu:MPPMetalHelper",
"//mediapipe/objc:mediapipe_framework_ios",
"@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate",
],
"//conditions:default": [],
}),
@ -194,12 +229,28 @@ cc_library(
cc_library(
name = "tflite_converter_calculator",
srcs = ["tflite_converter_calculator.cc"],
copts = select({
"//mediapipe:ios": [
"-std=c++11",
"-x objective-c++",
"-fobjc-arc", # enable reference-counting
],
"//conditions:default": [],
}),
linkopts = select({
"//mediapipe:ios": [
"-framework CoreVideo",
"-framework MetalKit",
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":tflite_converter_calculator_cc_proto",
"//mediapipe/util:resource_util",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/framework/tool:status_util",
"//mediapipe/framework/port:status",
@ -208,13 +259,19 @@ cc_library(
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
] + select({
"//mediapipe:android": [
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gpu_buffer",
"@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader",
],
"//mediapipe:ios": [
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:MPPMetalHelper",
"//mediapipe/objc:mediapipe_framework_ios",
"@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate",
],
"//conditions:default": [],
}),
alwayslink = 1,
@ -229,6 +286,9 @@ cc_library(
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:opencv_imgcodecs",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_framework",
@ -279,6 +339,32 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "tflite_tensors_to_landmarks_calculator",
srcs = ["tflite_tensors_to_landmarks_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":tflite_tensors_to_landmarks_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:ret_check",
"@org_tensorflow//tensorflow/lite:framework",
],
alwayslink = 1,
)
cc_library(
name = "tflite_tensors_to_floats_calculator",
srcs = ["tflite_tensors_to_floats_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"@org_tensorflow//tensorflow/lite:framework",
],
alwayslink = 1,
)
cc_test(
name = "tflite_inference_calculator_test",
srcs = ["tflite_inference_calculator_test.cc"],
@ -298,3 +384,24 @@ cc_test(
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
)
cc_test(
name = "tflite_converter_calculator_test",
srcs = ["tflite_converter_calculator_test.cc"],
deps = [
":tflite_converter_calculator",
":tflite_converter_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:validate_type",
"@com_google_absl//absl/memory",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
)

View File

@ -73,6 +73,8 @@ class SsdAnchorsCalculator : public CalculatorBase {
}
::mediapipe::Status Open(CalculatorContext* cc) override {
cc->SetOffset(TimestampDiff(0));
const SsdAnchorsCalculatorOptions& options =
cc->Options<SsdAnchorsCalculatorOptions>();

View File

@ -12,9 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "mediapipe/calculators/tflite/tflite_converter_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/resource_util.h"
#include "tensorflow/lite/error_reporter.h"
@ -22,18 +27,41 @@
#if defined(__ANDROID__)
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"
#endif // ANDROID
#endif // __ANDROID__
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS
#import <CoreVideo/CoreVideo.h>
#import <Metal/Metal.h>
#import <MetalKit/MetalKit.h>
#import "mediapipe/gpu/MPPMetalHelper.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
#endif // iOS
#if defined(__ANDROID__)
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
typedef id<MTLBuffer> GpuTensor;
#endif
namespace {
constexpr int kWorkgroupSize = 8; // Block size for GPU shader.
// Commonly used to compute the number of blocks to launch in a kernel.
int RoundUp(const int size, const int multiple) {
return (size + multiple - 1) / multiple;
int NumGroups(const int size, const int group_size) { // NOLINT
return (size + group_size - 1) / group_size;
}
typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
RowMajorMatrixXf;
typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>
ColMajorMatrixXf;
} // namespace
namespace mediapipe {
@ -43,30 +71,38 @@ using ::tflite::gpu::gl::GlBuffer;
using ::tflite::gpu::gl::GlProgram;
using ::tflite::gpu::gl::GlShader;
struct GPUData {
int width;
int height;
int channels;
GlBuffer ssbo;
int elements = 1;
GlBuffer buffer;
GlShader shader;
GlProgram program;
};
#endif // ANDROID
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
struct GPUData {
int elements = 1;
id<MTLBuffer> buffer;
id<MTLComputePipelineState> pipeline_state;
};
#endif
// Calculator for normalizing and converting an ImageFrame or GpuBuffer
// into a TfLiteTensor (float 32) or tflite::gpu::GlBuffer, respetively.
// Calculator for normalizing and converting an ImageFrame or Matrix
// into a TfLiteTensor (float 32) or a GpuBuffer to a tflite::gpu::GlBuffer.
//
// This calculator is designed to be used with the TfLiteInferenceCalcualtor,
// as a pre-processing step for calculator inputs.
//
// Input data is normalized to [-1,1] (default) or [0,1], specified by options.
// IMAGE and IMAGE_GPU inputs are normalized to [-1,1] (default) or [0,1],
// specified by options (unless outputting a quantized tensor).
//
// Input:
// One of the following tags:
// IMAGE - ImageFrame (assumed to be 8-bit or 32-bit data).
// IMAGE_GPU - GpuBuffer (assumed to be RGBA or RGB GL texture)
// IMAGE_GPU - GpuBuffer (assumed to be RGBA or RGB GL texture).
// MATRIX - Matrix.
//
// Output:
// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32
// TENSORS_GPU - vector of GlBuffer
// One of the following tags:
// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32, or kTfLiteUint8.
// TENSORS_GPU - vector of GlBuffer.
//
// Example use:
// node {
@ -83,7 +119,7 @@ struct GPUData {
// IMPORTANT Notes:
// No conversion between CPU/GPU is done.
// Inputs/outputs must match type: CPU->CPU or GPU->GPU.
// GPU tensors are currently only supported on Android.
// GPU tensors are currently only supported on mobile platforms.
// This calculator uses FixedSizeInputStreamHandler by default.
//
class TfLiteConverterCalculator : public CalculatorBase {
@ -101,43 +137,62 @@ class TfLiteConverterCalculator : public CalculatorBase {
::mediapipe::Status NormalizeImage(const ImageFrame& image_frame,
bool zero_center, bool flip_vertically,
float* tensor_buffer);
::mediapipe::Status CopyMatrixToTensor(const Matrix& matrix,
float* tensor_buffer);
::mediapipe::Status ProcessCPU(CalculatorContext* cc);
::mediapipe::Status ProcessGPU(CalculatorContext* cc);
std::unique_ptr<tflite::Interpreter> interpreter_ = nullptr;
#if defined(__ANDROID__)
mediapipe::GlCalculatorHelper gpu_helper_;
std::unique_ptr<GPUData> gpu_data_out_;
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
MPPMetalHelper* gpu_helper_ = nullptr;
std::unique_ptr<GPUData> gpu_data_out_;
#endif
bool initialized_ = false;
bool use_gpu_ = false;
bool zero_center_ = true; // normalize range to [-1,1] | otherwise [0,1]
bool flip_vertically_ = false;
bool row_major_matrix_ = false;
bool use_quantized_tensors_ = false;
int max_num_channels_ = 3;
};
REGISTER_CALCULATOR(TfLiteConverterCalculator);
::mediapipe::Status TfLiteConverterCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag("IMAGE") || cc->Inputs().HasTag("IMAGE_GPU"));
RET_CHECK(cc->Outputs().HasTag("TENSORS") ||
const bool has_image_tag = cc->Inputs().HasTag("IMAGE");
const bool has_image_gpu_tag = cc->Inputs().HasTag("IMAGE_GPU");
const bool has_matrix_tag = cc->Inputs().HasTag("MATRIX");
// Confirm only one of the input streams is present.
RET_CHECK(has_image_tag ^ has_image_gpu_tag ^ has_matrix_tag &&
!(has_image_tag && has_image_gpu_tag && has_matrix_tag));
// Confirm only one of the output streams is present.
RET_CHECK(cc->Outputs().HasTag("TENSORS") ^
cc->Outputs().HasTag("TENSORS_GPU"));
if (cc->Inputs().HasTag("IMAGE")) cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
#if defined(__ANDROID__)
if (cc->Inputs().HasTag("MATRIX")) cc->Inputs().Tag("MATRIX").Set<Matrix>();
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
if (cc->Inputs().HasTag("IMAGE_GPU"))
cc->Inputs().Tag("IMAGE_GPU").Set<mediapipe::GpuBuffer>();
#endif
if (cc->Outputs().HasTag("TENSORS"))
cc->Outputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
if (cc->Outputs().HasTag("TENSORS_GPU"))
cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GlBuffer>>();
cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
#endif
#if defined(__ANDROID__)
RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
#endif
// Assign this calculator's default InputStreamHandler.
@ -147,14 +202,16 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
}
::mediapipe::Status TfLiteConverterCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
RETURN_IF_ERROR(LoadOptions(cc));
if (cc->Inputs().HasTag("IMAGE_GPU") ||
cc->Outputs().HasTag("IMAGE_OUT_GPU")) {
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
use_gpu_ = true;
#else
RET_CHECK_FAIL() << "GPU processing on non-Android not supported yet.";
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only.";
#endif
}
@ -162,8 +219,13 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
// Cannot mix CPU/GPU streams.
RET_CHECK(cc->Inputs().HasTag("IMAGE_GPU") &&
cc->Outputs().HasTag("TENSORS_GPU"));
// Cannot use quantization.
use_quantized_tensors_ = false;
#if defined(__ANDROID__)
RETURN_IF_ERROR(gpu_helper_.Open(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
RET_CHECK(gpu_helper_);
#endif
} else {
interpreter_ = absl::make_unique<tflite::Interpreter>();
@ -176,77 +238,59 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
::mediapipe::Status TfLiteConverterCalculator::Process(CalculatorContext* cc) {
if (use_gpu_) {
// GpuBuffer to tflite::gpu::GlBuffer conversion.
#if defined(__ANDROID__)
if (!initialized_) {
RETURN_IF_ERROR(InitGpu(cc));
initialized_ = true;
}
const auto& input =
cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>();
RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, &input]() -> ::mediapipe::Status {
// Convert GL texture into TfLite GlBuffer (SSBO).
auto src = gpu_helper_.CreateSourceTexture(input);
glActiveTexture(GL_TEXTURE0 + 0);
glBindTexture(GL_TEXTURE_2D, src.name());
auto status = gpu_data_out_->ssbo.BindToIndex(1);
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
const tflite::gpu::uint3 workgroups = {
RoundUp(gpu_data_out_->width, kWorkgroupSize),
RoundUp(gpu_data_out_->height, kWorkgroupSize), 1};
status = gpu_data_out_->program.Dispatch(workgroups);
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0);
glBindTexture(GL_TEXTURE_2D, 0);
src.Release();
return ::mediapipe::OkStatus();
}));
auto output_tensors = absl::make_unique<std::vector<GlBuffer>>();
output_tensors->resize(1);
for (int i = 0; i < 1; ++i) {
GlBuffer& tensor = output_tensors->at(i);
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
auto status = CreateReadWriteShaderStorageBuffer<float>(
gpu_data_out_->width * gpu_data_out_->height *
gpu_data_out_->channels,
&tensor);
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
tflite::gpu::gl::CopyBuffer(gpu_data_out_->ssbo, tensor);
}
cc->Outputs()
.Tag("TENSORS_GPU")
.Add(output_tensors.release(), cc->InputTimestamp());
#else
RET_CHECK_FAIL()
<< "GPU input on non-Android devices is not supported yet.";
#endif
// Convert to GPU tensors type.
RETURN_IF_ERROR(ProcessGPU(cc));
} else {
// Convert to CPU tensors or Matrix type.
RETURN_IF_ERROR(ProcessCPU(cc));
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) {
#if defined(__ANDROID__)
gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); });
#endif
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS
gpu_data_out_.reset();
#endif
return ::mediapipe::OkStatus();
}
::mediapipe::Status TfLiteConverterCalculator::ProcessCPU(
CalculatorContext* cc) {
if (cc->Inputs().HasTag("IMAGE")) {
// CPU ImageFrame to TfLiteTensor conversion.
const auto& image_frame = cc->Inputs().Tag("IMAGE").Get<ImageFrame>();
const int height = image_frame.Height();
const int width = image_frame.Width();
const int channels_preserved =
std::min(image_frame.NumberOfChannels(), max_num_channels_);
if (!(image_frame.Format() == mediapipe::ImageFormat::SRGBA ||
image_frame.Format() == mediapipe::ImageFormat::SRGB ||
image_frame.Format() == mediapipe::ImageFormat::GRAY8 ||
image_frame.Format() == mediapipe::ImageFormat::VEC32F1))
RET_CHECK_FAIL() << "Unsupported CPU input format.";
const int channels = image_frame.NumberOfChannels();
const int channels_preserved = std::min(channels, max_num_channels_);
if (!initialized_) {
interpreter_->SetTensorParametersReadWrite(
0, kTfLiteFloat32, "", {channels_preserved}, TfLiteQuantization());
if (!(image_frame.Format() == mediapipe::ImageFormat::SRGBA ||
image_frame.Format() == mediapipe::ImageFormat::SRGB ||
image_frame.Format() == mediapipe::ImageFormat::GRAY8 ||
image_frame.Format() == mediapipe::ImageFormat::VEC32F1))
RET_CHECK_FAIL() << "Unsupported CPU input format.";
TfLiteQuantization quant;
if (use_quantized_tensors_) {
RET_CHECK(image_frame.Format() != mediapipe::ImageFormat::VEC32F1)
<< "Only 8-bit input images are supported for quantization.";
// Optional: Set 'quant' quantization params here if needed.
interpreter_->SetTensorParametersReadWrite(0, kTfLiteUInt8, "",
{channels_preserved}, quant);
} else {
// Default TfLiteQuantization used for no quantization.
interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "",
{channels_preserved}, quant);
}
initialized_ = true;
}
@ -256,19 +300,66 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
{height, width, channels_preserved});
interpreter_->AllocateTensors();
// Copy image data into tensor.
if (use_quantized_tensors_) {
const int width_padding =
image_frame.WidthStep() / image_frame.ByteDepth() - width * channels;
const uint8* image_buffer =
reinterpret_cast<const uint8*>(image_frame.PixelData());
uint8* tensor_buffer = tensor->data.uint8;
RET_CHECK(tensor_buffer);
for (int row = 0; row < height; ++row) {
for (int col = 0; col < width; ++col) {
for (int channel = 0; channel < channels_preserved; ++channel) {
*tensor_buffer++ = image_buffer[channel];
}
image_buffer += channels;
}
image_buffer += width_padding;
}
} else {
float* tensor_buffer = tensor->data.f;
RET_CHECK(tensor_buffer);
if (image_frame.ByteDepth() == 1) {
RETURN_IF_ERROR(NormalizeImage<uint8>(image_frame, zero_center_,
flip_vertically_, tensor_buffer));
} else if (image_frame.ByteDepth() == 4) {
RETURN_IF_ERROR(NormalizeImage<float>(image_frame, zero_center_,
flip_vertically_, tensor_buffer));
} else {
return ::mediapipe::InternalError(
"Only byte-based (8 bit) and float (32 bit) images supported.");
}
}
auto output_tensors = absl::make_unique<std::vector<TfLiteTensor>>();
output_tensors->emplace_back(*tensor);
cc->Outputs().Tag("TENSORS").Add(output_tensors.release(),
cc->InputTimestamp());
} else if (cc->Inputs().HasTag("MATRIX")) {
// CPU Matrix to TfLiteTensor conversion.
const auto& matrix = cc->Inputs().Tag("MATRIX").Get<Matrix>();
const int height = matrix.rows();
const int width = matrix.cols();
const int channels = 1;
if (!initialized_) {
interpreter_->SetTensorParametersReadWrite(
/*tensor_index=*/0, /*type=*/kTfLiteFloat32, /*name=*/"",
/*dims=*/{channels}, /*quantization=*/TfLiteQuantization());
initialized_ = true;
}
const int tensor_idx = interpreter_->inputs()[0];
TfLiteTensor* tensor = interpreter_->tensor(tensor_idx);
interpreter_->ResizeInputTensor(tensor_idx, {height, width, channels});
interpreter_->AllocateTensors();
float* tensor_buffer = tensor->data.f;
RET_CHECK(tensor_buffer);
if (image_frame.ByteDepth() == 1) {
RETURN_IF_ERROR(NormalizeImage<uint8>(image_frame, zero_center_,
flip_vertically_, tensor_buffer));
} else if (image_frame.ByteDepth() == 4) {
RETURN_IF_ERROR(NormalizeImage<float>(image_frame, zero_center_,
flip_vertically_, tensor_buffer));
} else {
return ::mediapipe::InternalError(
"Only byte-based (8 bit) and float (32 bit) images supported.");
}
RETURN_IF_ERROR(CopyMatrixToTensor(matrix, tensor_buffer));
auto output_tensors = absl::make_unique<std::vector<TfLiteTensor>>();
output_tensors->emplace_back(*tensor);
@ -279,43 +370,132 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
return ::mediapipe::OkStatus();
}
::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) {
::mediapipe::Status TfLiteConverterCalculator::ProcessGPU(
CalculatorContext* cc) {
#if defined(__ANDROID__)
gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); });
#endif // __ANDROID__
// GpuBuffer to tflite::gpu::GlBuffer conversion.
const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>();
RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, &input]() -> ::mediapipe::Status {
// Convert GL texture into TfLite GlBuffer (SSBO).
auto src = gpu_helper_.CreateSourceTexture(input);
glActiveTexture(GL_TEXTURE0 + 0);
glBindTexture(GL_TEXTURE_2D, src.name());
auto status = gpu_data_out_->buffer.BindToIndex(1);
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
const tflite::gpu::uint3 workgroups = {
NumGroups(input.width(), kWorkgroupSize),
NumGroups(input.height(), kWorkgroupSize), 1};
status = gpu_data_out_->program.Dispatch(workgroups);
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0);
glBindTexture(GL_TEXTURE_2D, 0);
src.Release();
return ::mediapipe::OkStatus();
}));
// Copy into outputs.
auto output_tensors = absl::make_unique<std::vector<GpuTensor>>();
output_tensors->resize(1);
{
GlBuffer& tensor = output_tensors->at(0);
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
auto status = CreateReadWriteShaderStorageBuffer<float>(
gpu_data_out_->elements, &tensor);
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
tflite::gpu::gl::CopyBuffer(gpu_data_out_->buffer, tensor);
}
cc->Outputs()
.Tag("TENSORS_GPU")
.Add(output_tensors.release(), cc->InputTimestamp());
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
// GpuBuffer to id<MTLBuffer> conversion.
const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>();
{
id<MTLTexture> src_texture = [gpu_helper_ metalTextureWithGpuBuffer:input];
id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
command_buffer.label = @"TfLiteConverterCalculatorConvert";
id<MTLComputeCommandEncoder> compute_encoder =
[command_buffer computeCommandEncoder];
[compute_encoder setComputePipelineState:gpu_data_out_->pipeline_state];
[compute_encoder setTexture:src_texture atIndex:0];
[compute_encoder setBuffer:gpu_data_out_->buffer offset:0 atIndex:1];
MTLSize threads_per_group = MTLSizeMake(kWorkgroupSize, kWorkgroupSize, 1);
MTLSize threadgroups =
MTLSizeMake(NumGroups(input.width(), kWorkgroupSize),
NumGroups(input.height(), kWorkgroupSize), 1);
[compute_encoder dispatchThreadgroups:threadgroups
threadsPerThreadgroup:threads_per_group];
[compute_encoder endEncoding];
[command_buffer commit];
[command_buffer waitUntilCompleted];
}
// Copy into outputs.
auto output_tensors = absl::make_unique<std::vector<GpuTensor>>();
{
id<MTLDevice> device = gpu_helper_.mtlDevice;
id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
command_buffer.label = @"TfLiteConverterCalculatorCopy";
id<MTLBuffer> tensor =
[device newBufferWithLength:gpu_data_out_->elements * sizeof(float)
options:MTLResourceStorageModeShared];
id<MTLBlitCommandEncoder> blit_command =
[command_buffer blitCommandEncoder];
[blit_command copyFromBuffer:gpu_data_out_->buffer
sourceOffset:0
toBuffer:tensor
destinationOffset:0
size:gpu_data_out_->elements * sizeof(float)];
[blit_command endEncoding];
[command_buffer commit];
[command_buffer waitUntilCompleted];
output_tensors->push_back(tensor);
}
cc->Outputs()
.Tag("TENSORS_GPU")
.Add(output_tensors.release(), cc->InputTimestamp());
#else
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only.";
#endif
return ::mediapipe::OkStatus();
}
::mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
#if defined(__ANDROID__)
// Get input image sizes.
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
// Configure inputs.
const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>();
mediapipe::ImageFormat::Format format =
mediapipe::ImageFormatForGpuBufferFormat(input.format());
gpu_data_out_ = absl::make_unique<GPUData>();
gpu_data_out_->height = input.height();
gpu_data_out_->width = input.width();
gpu_data_out_->channels = max_num_channels_; // desired output channels
gpu_data_out_->elements = input.height() * input.width() * max_num_channels_;
const bool include_alpha = (max_num_channels_ == 4);
if (!(format == mediapipe::ImageFormat::SRGB ||
format == mediapipe::ImageFormat::SRGBA))
RET_CHECK_FAIL() << "Unsupported GPU input format.";
if (include_alpha && (format != mediapipe::ImageFormat::SRGBA))
RET_CHECK_FAIL() << "Num input channels is less than desired output.";
#endif
// Shader to convert GL Texture to Shader Storage Buffer Object (SSBO),
// with normalization to either: [0,1] or [-1,1].
#if defined(__ANDROID__)
// Device memory.
auto status = ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer<float>(
gpu_data_out_->width * gpu_data_out_->height * gpu_data_out_->channels,
&gpu_data_out_->ssbo);
gpu_data_out_->elements, &gpu_data_out_->buffer);
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
// Shader to convert GL Texture to Shader Storage Buffer Object (SSBO),
// with normalization to either: [0,1] or [-1,1].
const std::string shader_source = absl::Substitute(
R"( #version 310 es
layout(local_size_x = $0, local_size_y = $0) in;
@ -333,8 +513,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
output_data.elements[linear_index + 2] = pixel.z;
$6 // alpha channel
})",
/*$0=*/kWorkgroupSize, /*$1=*/gpu_data_out_->width,
/*$2=*/gpu_data_out_->height,
/*$0=*/kWorkgroupSize, /*$1=*/input.width(), /*$2=*/input.height(),
/*$3=*/zero_center_ ? "pixel = (pixel - 0.5) * 2.0;" : "",
/*$4=*/flip_vertically_ ? "(width_height.y - 1 - gid.y)" : "gid.y",
/*$5=*/
@ -353,7 +532,65 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
#endif // ANDROID
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
RET_CHECK(include_alpha)
<< "iOS GPU inference currently accepts only RGBA input.";
// Device memory.
id<MTLDevice> device = gpu_helper_.mtlDevice;
gpu_data_out_->buffer =
[device newBufferWithLength:gpu_data_out_->elements * sizeof(float)
options:MTLResourceStorageModeShared];
// Shader to convert GL Texture to Metal Buffer,
// with normalization to either: [0,1] or [-1,1].
const std::string shader_source = absl::Substitute(
R"(
#include <simd/simd.h>
#include <metal_stdlib>
using namespace metal;
kernel void convertKernel(
texture2d<half, access::sample> in_tex [[ texture(0) ]],
device float* out_buf [[ buffer(1) ]],
uint2 gid [[ thread_position_in_grid ]]) {
if (gid.x >= in_tex.get_width() || gid.y >= in_tex.get_height()) return;
constexpr sampler texture_sampler(coord::pixel, address::clamp_to_edge);
const float2 coord = float2(gid.x, gid.y);
$0 pixel = $0(in_tex.sample(texture_sampler, coord).$1);
$2 // normalize [-1,1]
const int linear_index = $4 * ($3 * in_tex.get_width() + gid.x);
out_buf[linear_index + 0] = pixel.x;
out_buf[linear_index + 1] = pixel.y;
out_buf[linear_index + 2] = pixel.z;
$5 // alpha channel
}
)",
/*$0=*/include_alpha ? "float4" : "float3",
/*$1=*/include_alpha ? "rgba" : "rgb",
/*$2=*/zero_center_ ? "pixel = (pixel - 0.5) * 2.0;" : "",
/*$3=*/flip_vertically_ ? "(in_tex.get_height() - 1 - gid.y)" : "gid.y",
/*$4=*/include_alpha ? 4 : 3,
/*$5=*/include_alpha ? "out_buf[linear_index + 3] = pixel.w;" : "");
NSString* library_source =
[NSString stringWithUTF8String:shader_source.c_str()];
NSError* error = nil;
id<MTLLibrary> library =
[device newLibraryWithSource:library_source options:nullptr error:&error];
RET_CHECK(library != nil) << "Couldn't create shader library "
<< [[error localizedDescription] UTF8String];
id<MTLFunction> kernel_func = nil;
kernel_func = [library newFunctionWithName:@"convertKernel"];
RET_CHECK(kernel_func != nil) << "Couldn't create kernel function.";
gpu_data_out_->pipeline_state =
[device newComputePipelineStateWithFunction:kernel_func error:&error];
RET_CHECK(gpu_data_out_->pipeline_state != nil)
<< "Couldn't create pipeline state "
<< [[error localizedDescription] UTF8String];
#endif
return ::mediapipe::OkStatus();
}
@ -370,11 +607,23 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
// Get y-flip mode.
flip_vertically_ = options.flip_vertically();
// Get row_major_matrix mode.
row_major_matrix_ = options.row_major_matrix();
// Get desired way to handle input channels.
max_num_channels_ = options.max_num_channels();
// Currently only alpha channel toggling is supported.
CHECK_GE(max_num_channels_, 3);
CHECK_LE(max_num_channels_, 4);
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS
if (cc->Inputs().HasTag("IMAGE_GPU"))
// Currently on iOS, tflite gpu input tensor must be 4 channels,
// so input image must be 4 channels also (checked in InitGpu).
max_num_channels_ = 4;
#endif
// Get tensor type, float or quantized.
use_quantized_tensors_ = options.use_quantized_tensors();
return ::mediapipe::OkStatus();
}
@ -415,4 +664,19 @@ template <class T>
return ::mediapipe::OkStatus();
}
// Copies `matrix` into `tensor_buffer`, honoring the configured memory
// layout: when row_major_matrix_ is set the destination is written in
// row-major order, otherwise in Eigen's default column-major order.
// `tensor_buffer` must hold at least rows*cols floats.
::mediapipe::Status TfLiteConverterCalculator::CopyMatrixToTensor(
    const Matrix& matrix, float* tensor_buffer) {
  const int num_rows = matrix.rows();
  const int num_cols = matrix.cols();
  if (row_major_matrix_) {
    // Assigning through the map transposes storage order as needed.
    Eigen::Map<RowMajorMatrixXf>(tensor_buffer, num_rows, num_cols) = matrix;
  } else {
    Eigen::Map<ColMajorMatrixXf>(tensor_buffer, num_rows, num_cols) = matrix;
  }
  return ::mediapipe::OkStatus();
}
} // namespace mediapipe

View File

@ -22,9 +22,10 @@ message TfLiteConverterCalculatorOptions {
optional TfLiteConverterCalculatorOptions ext = 245817797;
}
// Choose normalization mode for output:
// Choose normalization mode for output (not applied for Matrix inputs).
// true = [-1,1]
// false = [0,1]
// Ignored if using quantization.
optional bool zero_center = 1 [default = true];
// Whether the input image should be flipped vertically (along the
@ -38,4 +39,12 @@ message TfLiteConverterCalculatorOptions {
// tensor. Currently this only controls whether or not to ignore alpha
// channel, so it must be 3 or 4.
optional int32 max_num_channels = 3 [default = 3];
// The calculator expects Matrix inputs to be in column-major order. Set
// row_major_matrix to true if the inputs are in row-major order.
optional bool row_major_matrix = 4 [default = false];
// Quantization option (CPU only).
// When true, output kTfLiteUInt8 tensor instead of kTfLiteFloat32.
optional bool use_quantized_tensors = 5 [default = false];
}

View File

@ -0,0 +1,199 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <random>
#include <vector>
#include "absl/memory/memory.h"
#include "mediapipe/calculators/tflite/tflite_converter_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
#include "mediapipe/framework/tool/validate_type.h"
#include "tensorflow/lite/interpreter.h"
namespace mediapipe {
namespace {
constexpr char kTransposeOptionsString[] =
"[mediapipe.TfLiteConverterCalculatorOptions.ext]: {"
"row_major_matrix: True}";
} // namespace
using RandomEngine = std::mt19937_64;
const uint32 kSeed = 1234;
const int kNumSizes = 8;
const int sizes[kNumSizes][2] = {{1, 1}, {12, 1}, {1, 9}, {2, 2},
{5, 3}, {7, 13}, {16, 32}, {101, 2}};
// Test fixture for TfLiteConverterCalculator matrix-input tests.
// Owns the CalculatorGraph so individual tests can construct, run, and
// fully close it (tensors are destroyed when the graph closes).
class TfLiteConverterCalculatorTest : public ::testing::Test {
 protected:
  // Adds a packet with a matrix filled with random values in [0,1].
  //
  // `seed` initializes the RNG; tests regenerate the same sequence with the
  // same seed to verify the converter's output element-by-element.
  // `row_major_matrix` selects the order in which values are generated and
  // written, matching the calculator's row_major_matrix option.
  void AddRandomMatrix(int num_rows, int num_columns, uint32 seed,
                       bool row_major_matrix = false) {
    // Fix: use the `seed` parameter; previously kSeed was used
    // unconditionally, silently ignoring the argument.
    RandomEngine random(seed);
    std::uniform_real_distribution<> uniform_dist(0, 1.0);
    auto matrix = ::absl::make_unique<Matrix>();
    matrix->resize(num_rows, num_columns);
    if (row_major_matrix) {
      for (int y = 0; y < num_rows; ++y) {
        for (int x = 0; x < num_columns; ++x) {
          float value = uniform_dist(random);
          (*matrix)(y, x) = value;
        }
      }
    } else {
      for (int x = 0; x < num_columns; ++x) {
        for (int y = 0; y < num_rows; ++y) {
          float value = uniform_dist(random);
          (*matrix)(y, x) = value;
        }
      }
    }
    MEDIAPIPE_ASSERT_OK(graph_->AddPacketToInputStream(
        "matrix", Adopt(matrix.release()).At(Timestamp(0))));
  }

  std::unique_ptr<CalculatorGraph> graph_;
};
// Feeds column-major matrices of various sizes through the converter and
// verifies the output tensor reproduces the values in generation order.
TEST_F(TfLiteConverterCalculatorTest, RandomMatrixColMajor) {
  for (int size_index = 0; size_index < kNumSizes; ++size_index) {
    const int num_rows = sizes[size_index][0];
    const int num_columns = sizes[size_index][1];

    // Run the calculator and verify that one output is generated.
    CalculatorGraphConfig graph_config =
        ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
        input_stream: "matrix"
        node {
          calculator: "TfLiteConverterCalculator"
          input_stream: "MATRIX:matrix"
          output_stream: "TENSORS:tensor"
          options {
            [mediapipe.TfLiteConverterCalculatorOptions.ext] {
              row_major_matrix: false
            }
          }
        }
        )");
    std::vector<Packet> output_packets;
    tool::AddVectorSink("tensor", &graph_config, &output_packets);

    // Run the graph.
    graph_ = absl::make_unique<CalculatorGraph>();
    MEDIAPIPE_ASSERT_OK(graph_->Initialize(graph_config));
    MEDIAPIPE_ASSERT_OK(graph_->StartRun({}));

    // Push the tensor into the graph.
    AddRandomMatrix(num_rows, num_columns, kSeed, /*row_major_matrix=*/false);

    // Wait until the calculator done processing.
    MEDIAPIPE_ASSERT_OK(graph_->WaitUntilIdle());
    // ASSERT (not EXPECT): the [0] indexing below would be out-of-bounds if
    // no packet was produced.
    ASSERT_EQ(1, output_packets.size());

    // Get and process results.
    const std::vector<TfLiteTensor>& tensor_vec =
        output_packets[0].Get<std::vector<TfLiteTensor>>();
    ASSERT_EQ(1, tensor_vec.size());
    const TfLiteTensor* tensor = &tensor_vec[0];
    EXPECT_EQ(kTfLiteFloat32, tensor->type);

    // Verify that the data is correct by regenerating the same random
    // sequence and comparing element-by-element.
    RandomEngine random(kSeed);
    std::uniform_real_distribution<> uniform_dist(0, 1.0);
    const float* tensor_buffer = tensor->data.f;
    for (int i = 0; i < num_rows * num_columns; ++i) {
      const float expected = uniform_dist(random);
      EXPECT_EQ(expected, tensor_buffer[i]) << "at i = " << i;
    }

    // Fully close graph at end, otherwise calculator+tensors are destroyed
    // after calling WaitUntilDone().
    MEDIAPIPE_ASSERT_OK(graph_->CloseInputStream("matrix"));
    MEDIAPIPE_ASSERT_OK(graph_->WaitUntilDone());
    graph_.reset();
  }
}
// Feeds row-major matrices of various sizes through the converter (with
// row_major_matrix enabled) and verifies the output tensor reproduces the
// values in generation order.
TEST_F(TfLiteConverterCalculatorTest, RandomMatrixRowMajor) {
  for (int size_index = 0; size_index < kNumSizes; ++size_index) {
    const int num_rows = sizes[size_index][0];
    const int num_columns = sizes[size_index][1];

    // Run the calculator and verify that one output is generated.
    CalculatorGraphConfig graph_config =
        ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
        input_stream: "matrix"
        node {
          calculator: "TfLiteConverterCalculator"
          input_stream: "MATRIX:matrix"
          output_stream: "TENSORS:tensor"
          options {
            [mediapipe.TfLiteConverterCalculatorOptions.ext] {
              row_major_matrix: true
            }
          }
        }
        )");
    std::vector<Packet> output_packets;
    tool::AddVectorSink("tensor", &graph_config, &output_packets);

    // Run the graph.
    graph_ = absl::make_unique<CalculatorGraph>();
    MEDIAPIPE_ASSERT_OK(graph_->Initialize(graph_config));
    MEDIAPIPE_ASSERT_OK(graph_->StartRun({}));

    // Push the tensor into the graph.
    AddRandomMatrix(num_rows, num_columns, kSeed, /*row_major_matrix=*/true);

    // Wait until the calculator done processing.
    MEDIAPIPE_ASSERT_OK(graph_->WaitUntilIdle());
    // ASSERT (not EXPECT): the [0] indexing below would be out-of-bounds if
    // no packet was produced.
    ASSERT_EQ(1, output_packets.size());

    // Get and process results.
    const std::vector<TfLiteTensor>& tensor_vec =
        output_packets[0].Get<std::vector<TfLiteTensor>>();
    ASSERT_EQ(1, tensor_vec.size());
    const TfLiteTensor* tensor = &tensor_vec[0];
    EXPECT_EQ(kTfLiteFloat32, tensor->type);

    // Verify that the data is correct by regenerating the same random
    // sequence and comparing element-by-element.
    RandomEngine random(kSeed);
    std::uniform_real_distribution<> uniform_dist(0, 1.0);
    const float* tensor_buffer = tensor->data.f;
    for (int i = 0; i < num_rows * num_columns; ++i) {
      const float expected = uniform_dist(random);
      EXPECT_EQ(expected, tensor_buffer[i]) << "at i = " << i;
    }

    // Fully close graph at end, otherwise calculator+tensors are destroyed
    // after calling WaitUntilDone().
    MEDIAPIPE_ASSERT_OK(graph_->CloseInputStream("matrix"));
    MEDIAPIPE_ASSERT_OK(graph_->WaitUntilDone());
    graph_.reset();
  }
}
} // namespace mediapipe

View File

@ -47,6 +47,8 @@ class TfLiteCustomOpResolverCalculator : public CalculatorBase {
}
::mediapipe::Status Open(CalculatorContext* cc) override {
cc->SetOffset(TimestampDiff(0));
const TfLiteCustomOpResolverCalculatorOptions& options =
cc->Options<TfLiteCustomOpResolverCalculatorOptions>();

View File

@ -11,7 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include <string>
#include <vector>
@ -32,17 +31,22 @@
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"
#endif // ANDROID
#endif // __ANDROID__
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS
#if defined(__OBJC__)
#import <CoreVideo/CoreVideo.h>
#import <Metal/Metal.h>
#import <MetalKit/MetalKit.h>
#endif // OBJC
#import "mediapipe/framework/ios/NSError+util_status.h"
#import "mediapipe/gpu/MediaPipeMetalHelper.h"
#import "mediapipe/gpu/MPPMetalHelper.h"
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
#endif // APPLE && !TARGET_OS_OSX
#endif // iOS
#if defined(__ANDROID__)
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
typedef id<MTLBuffer> GpuTensor;
#endif
// TfLiteInferenceCalculator File Layout:
// * Header
@ -56,11 +60,14 @@ using ::tflite::gpu::gl::GlProgram;
using ::tflite::gpu::gl::GlShader;
struct GPUData {
int elements = 1;
GlBuffer ssbo;
GlShader shader;
GlProgram program;
GlBuffer buffer;
};
#endif // ANDROID
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
struct GPUData {
int elements = 1;
id<MTLBuffer> buffer;
};
#endif
// Calculator Header Section
@ -78,12 +85,12 @@ struct GPUData {
// GPU.
//
// Input:
// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32
// TENSORS_GPU - Vector of GlBuffer (assumed to be RGB image)
// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32 or kTfLiteUInt8
// TENSORS_GPU - Vector of GlBuffer or MTLBuffer
//
// Output:
// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32
// TENSORS_GPU - Vector of GlBuffer
// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32 or kTfLiteUInt8
// TENSORS_GPU - Vector of GlBuffer or MTLBuffer
//
// Input side packet:
// CUSTOM_OP_RESOLVER (optional) - Use a custom op resolver,
@ -107,7 +114,7 @@ struct GPUData {
// Input tensors are assumed to be of the correct size and already normalized.
// All output TfLiteTensors will be destroyed when the graph closes,
// (i.e. after calling graph.WaitUntilDone()).
// GPU tensors are currently only supported on Android.
// GPU tensors are currently only supported on Android and iOS.
// This calculator uses FixedSizeInputStreamHandler by default.
//
class TfLiteInferenceCalculator : public CalculatorBase {
@ -131,40 +138,41 @@ class TfLiteInferenceCalculator : public CalculatorBase {
mediapipe::GlCalculatorHelper gpu_helper_;
std::unique_ptr<GPUData> gpu_data_in_;
std::vector<std::unique_ptr<GPUData>> gpu_data_out_;
#endif
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS
MediaPipeMetalHelper* gpu_helper_ = nullptr;
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
MPPMetalHelper* gpu_helper_ = nullptr;
std::unique_ptr<GPUData> gpu_data_in_;
std::vector<std::unique_ptr<GPUData>> gpu_data_out_;
#endif
std::string model_path_ = "";
bool gpu_inference_ = false;
bool gpu_input_ = false;
bool gpu_output_ = false;
}; // TfLiteInferenceCalculator
bool use_quantized_tensors_ = false;
};
REGISTER_CALCULATOR(TfLiteInferenceCalculator);
// Calculator Core Section
::mediapipe::Status TfLiteInferenceCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag("TENSORS") ||
RET_CHECK(cc->Inputs().HasTag("TENSORS") ^
cc->Inputs().HasTag("TENSORS_GPU"));
RET_CHECK(cc->Outputs().HasTag("TENSORS") ||
RET_CHECK(cc->Outputs().HasTag("TENSORS") ^
cc->Outputs().HasTag("TENSORS_GPU"));
if (cc->Inputs().HasTag("TENSORS"))
cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
if (cc->Inputs().HasTag("TENSORS_GPU"))
cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GlBuffer>>();
cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
#endif
if (cc->Outputs().HasTag("TENSORS"))
cc->Outputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
if (cc->Outputs().HasTag("TENSORS_GPU"))
cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GlBuffer>>();
cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
#endif
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
@ -176,7 +184,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
#if defined(__ANDROID__)
RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
RETURN_IF_ERROR([MediaPipeMetalHelper updateContract:cc]);
RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
#endif
// Assign this calculator's default InputStreamHandler.
@ -186,26 +194,26 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
}
::mediapipe::Status TfLiteInferenceCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
RETURN_IF_ERROR(LoadOptions(cc));
if (cc->Inputs().HasTag("TENSORS_GPU")) {
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
gpu_input_ = true;
gpu_inference_ = true; // Inference must be on GPU also.
#else
RET_CHECK(!cc->Inputs().HasTag("TENSORS_GPU"))
<< "GPU input for non-Android not supported yet.";
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only.";
#endif
}
if (cc->Outputs().HasTag("TENSORS_GPU")) {
#if defined(__ANDROID__)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
gpu_output_ = true;
RET_CHECK(cc->Inputs().HasTag("TENSORS_GPU"))
<< "GPU output must also have GPU Input.";
#else
RET_CHECK(!cc->Inputs().HasTag("TENSORS_GPU"))
<< "GPU output for non-Android not supported yet.";
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only.";
#endif
}
@ -215,7 +223,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
#if defined(__ANDROID__)
RETURN_IF_ERROR(gpu_helper_.Open(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
gpu_helper_ = [[MediaPipeMetalHelper alloc] initWithCalculatorContext:cc];
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
RET_CHECK(gpu_helper_);
#endif
@ -226,24 +234,38 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
}
::mediapipe::Status TfLiteInferenceCalculator::Process(CalculatorContext* cc) {
// Receive pre-processed tensor inputs.
// 1. Receive pre-processed tensor inputs.
if (gpu_input_) {
// Read GPU input into SSBO.
#if defined(__ANDROID__)
const auto& input_tensors =
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GlBuffer>>();
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GpuTensor>>();
RET_CHECK_EQ(input_tensors.size(), 1);
RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &input_tensors]() -> ::mediapipe::Status {
// Explicit copy input.
tflite::gpu::gl::CopyBuffer(input_tensors[0], gpu_data_in_->ssbo);
// Run inference.
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
tflite::gpu::gl::CopyBuffer(input_tensors[0], gpu_data_in_->buffer);
return ::mediapipe::OkStatus();
}));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
const auto& input_tensors =
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GpuTensor>>();
RET_CHECK_EQ(input_tensors.size(), 1);
id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
command_buffer.label = @"TfLiteInferenceCalculatorInput";
id<MTLBlitCommandEncoder> blit_command =
[command_buffer blitCommandEncoder];
// Explicit copy input.
[blit_command copyFromBuffer:input_tensors[0]
sourceOffset:0
toBuffer:gpu_data_in_->buffer
destinationOffset:0
size:gpu_data_in_->elements * sizeof(float)];
[blit_command endEncoding];
[command_buffer commit];
[command_buffer waitUntilCompleted];
#else
RET_CHECK_FAIL()
<< "GPU input on non-Android devices is not supported yet.";
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only.";
#endif
} else {
// Read CPU input into tensors.
@ -252,35 +274,38 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
RET_CHECK_GT(input_tensors.size(), 0);
for (int i = 0; i < input_tensors.size(); ++i) {
const TfLiteTensor* input_tensor = &input_tensors[i];
const float* input_tensor_buffer = input_tensor->data.f;
RET_CHECK(input_tensor_buffer);
float* local_tensor_buffer = interpreter_->typed_input_tensor<float>(i);
RET_CHECK(local_tensor_buffer);
memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor->bytes);
}
// Run inference.
if (gpu_inference_) {
#if defined(__ANDROID__)
RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this]() -> ::mediapipe::Status {
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
return ::mediapipe::OkStatus();
}));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
#endif
} else {
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
RET_CHECK(input_tensor->data.raw);
if (use_quantized_tensors_) {
const uint8* input_tensor_buffer = input_tensor->data.uint8;
uint8* local_tensor_buffer = interpreter_->typed_input_tensor<uint8>(i);
memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor->bytes);
} else {
const float* input_tensor_buffer = input_tensor->data.f;
float* local_tensor_buffer = interpreter_->typed_input_tensor<float>(i);
memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor->bytes);
}
}
}
// 2. Run inference.
if (gpu_inference_) {
#if defined(__ANDROID__)
RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> ::mediapipe::Status {
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
return ::mediapipe::OkStatus();
}));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
#endif
} else {
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
}
// 3. Output processed tensors.
if (gpu_output_) {
#if defined(__ANDROID__)
// Output result tensors (GPU).
auto output_tensors = absl::make_unique<std::vector<GlBuffer>>();
auto output_tensors = absl::make_unique<std::vector<GpuTensor>>();
output_tensors->resize(gpu_data_out_.size());
for (int i = 0; i < gpu_data_out_.size(); ++i) {
GlBuffer& tensor = output_tensors->at(i);
@ -290,13 +315,39 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
tflite::gpu::gl::CopyBuffer(gpu_data_out_[i]->ssbo, tensor);
tflite::gpu::gl::CopyBuffer(gpu_data_out_[i]->buffer, tensor);
}
cc->Outputs()
.Tag("TENSORS_GPU")
.Add(output_tensors.release(), cc->InputTimestamp());
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
// Output result tensors (GPU).
auto output_tensors = absl::make_unique<std::vector<GpuTensor>>();
id<MTLDevice> device = gpu_helper_.mtlDevice;
id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
command_buffer.label = @"TfLiteInferenceCalculatorOutput";
for (int i = 0; i < gpu_data_out_.size(); ++i) {
id<MTLBuffer> tensor =
[device newBufferWithLength:gpu_data_out_[i]->elements * sizeof(float)
options:MTLResourceStorageModeShared];
id<MTLBlitCommandEncoder> blit_command =
[command_buffer blitCommandEncoder];
// Explicit copy input.
[blit_command copyFromBuffer:gpu_data_out_[i]->buffer
sourceOffset:0
toBuffer:tensor
destinationOffset:0
size:gpu_data_out_[i]->elements * sizeof(float)];
[blit_command endEncoding];
[command_buffer commit];
[command_buffer waitUntilCompleted];
output_tensors->push_back(tensor);
}
cc->Outputs()
.Tag("TENSORS_GPU")
.Add(output_tensors.release(), cc->InputTimestamp());
#else
LOG(ERROR) << "GPU output on non-Android not supported yet.";
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only.";
#endif
} else {
// Output result tensors (CPU).
@ -325,8 +376,13 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
return ::mediapipe::OkStatus();
}));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
DeleteGpuDelegate(delegate_);
TFLGpuDelegateDelete(delegate_);
gpu_data_in_.reset();
for (int i = 0; i < gpu_data_out_.size(); ++i) {
gpu_data_out_[i].reset();
}
#endif
delegate_ = nullptr;
}
return ::mediapipe::OkStatus();
}
@ -341,8 +397,9 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
// Get model name.
if (!options.model_path().empty()) {
ASSIGN_OR_RETURN(model_path_,
mediapipe::PathToResourceAsFile(options.model_path()));
auto model_path = options.model_path();
ASSIGN_OR_RETURN(model_path_, mediapipe::PathToResourceAsFile(model_path));
} else {
LOG(ERROR) << "Must specify path to TFLite model.";
return ::mediapipe::Status(::mediapipe::StatusCode::kNotFound,
@ -360,11 +417,6 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
model_ = tflite::FlatBufferModel::BuildFromFile(model_path_.c_str());
RET_CHECK(model_);
#if !defined(__ANDROID__) && !(defined(__APPLE__) && !TARGET_OS_OSX)
LOG(WARNING) << "GPU only supported on mobile platforms. Using CPU fallback.";
gpu_inference_ = false;
#endif
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
const auto& op_resolver =
cc->InputSidePackets()
@ -378,8 +430,13 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
RET_CHECK(interpreter_);
if (!gpu_output_) {
if (gpu_output_) {
use_quantized_tensors_ = false;
} else {
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
use_quantized_tensors_ = (interpreter_->tensor(0)->quantization.type ==
kTfLiteAffineQuantization);
if (use_quantized_tensors_) gpu_inference_ = false;
}
return ::mediapipe::OkStatus();
@ -388,12 +445,20 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
::mediapipe::Status TfLiteInferenceCalculator::LoadDelegate(
CalculatorContext* cc) {
#if defined(__ANDROID__)
// Get input image sizes.
// Configure and create the delegate.
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
options.compile_options.precision_loss_allowed = 1;
options.compile_options.preferred_gl_object_type =
TFLITE_GL_OBJECT_TYPE_FASTEST;
options.compile_options.dynamic_batch_enabled = 0;
options.compile_options.inline_parameters = 1;
if (!delegate_) delegate_ = TfLiteGpuDelegateCreate(&options);
if (gpu_input_) {
// Get input image sizes.
gpu_data_in_ = absl::make_unique<GPUData>();
const auto& input_indices = interpreter_->inputs();
// TODO accept > 1.
RET_CHECK_EQ(input_indices.size(), 1);
RET_CHECK_EQ(input_indices.size(), 1); // TODO accept > 1.
const TfLiteTensor* tensor = interpreter_->tensor(input_indices[0]);
gpu_data_in_->elements = 1;
for (int d = 0; d < tensor->dims->size; ++d) {
@ -402,9 +467,19 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
// Input to model can be either RGB/RGBA only.
RET_CHECK_GE(tensor->dims->data[3], 3);
RET_CHECK_LE(tensor->dims->data[3], 4);
// Create and bind input buffer.
auto status = ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer<float>(
gpu_data_in_->elements, &gpu_data_in_->buffer);
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
RET_CHECK_EQ(TfLiteGpuDelegateBindBufferToTensor(
delegate_, gpu_data_in_->buffer.id(),
interpreter_->inputs()[0]), // First tensor only
kTfLiteOk);
}
// Get output image sizes.
if (gpu_output_) {
// Get output image sizes.
const auto& output_indices = interpreter_->outputs();
gpu_data_out_.resize(output_indices.size());
for (int i = 0; i < gpu_data_out_.size(); ++i) {
@ -416,55 +491,90 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
gpu_data_out_[i]->elements *= tensor->dims->data[d];
}
}
}
// Configure and create the delegate.
TfLiteGpuDelegateOptions options;
options.metadata = nullptr;
options.compile_options.precision_loss_allowed = 1;
options.compile_options.preferred_gl_object_type =
TFLITE_GL_OBJECT_TYPE_FASTEST;
options.compile_options.dynamic_batch_enabled = 0;
if (!delegate_) delegate_ = TfLiteGpuDelegateCreate(&options);
// Shader to convert GL texture to SSBO.
if (gpu_input_) {
auto status = ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer<float>(
gpu_data_in_->elements, &gpu_data_in_->ssbo);
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
RET_CHECK_EQ(TfLiteGpuDelegateBindBufferToTensor(
delegate_, gpu_data_in_->ssbo.id(),
interpreter_->inputs()[0]), // First tensor only
kTfLiteOk);
}
// Create output SSBO buffers.
if (gpu_output_) {
// Create and bind output buffers.
interpreter_->SetAllowBufferHandleOutput(true);
const auto& output_indices = interpreter_->outputs();
for (int i = 0; i < gpu_data_out_.size(); ++i) {
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
auto status = CreateReadWriteShaderStorageBuffer<float>(
gpu_data_out_[i]->elements, &gpu_data_out_[i]->ssbo);
gpu_data_out_[i]->elements, &gpu_data_out_[i]->buffer);
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
RET_CHECK_EQ(
TfLiteGpuDelegateBindBufferToTensor(
delegate_, gpu_data_out_[i]->ssbo.id(), output_indices[i]),
delegate_, gpu_data_out_[i]->buffer.id(), output_indices[i]),
kTfLiteOk);
}
}
// Must call this last.
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_), kTfLiteOk);
return ::mediapipe::OkStatus();
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
#endif // __ANDROID__
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS
// Configure and create the delegate.
GpuDelegateOptions options;
options.allow_precision_loss = 1;
options.wait_type = GpuDelegateOptions::WaitType::kPassive;
if (!delegate_) delegate_ = NewGpuDelegate(&options);
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_), kTfLiteOk);
return ::mediapipe::OkStatus();
#endif // ANDROID or iOS
options.allow_precision_loss = false; // Must match converter, F=float/T=half
options.wait_type = GpuDelegateOptions::WaitType::kActive;
if (!delegate_) delegate_ = TFLGpuDelegateCreate(&options);
if (gpu_input_) {
// Get input image sizes.
gpu_data_in_ = absl::make_unique<GPUData>();
const auto& input_indices = interpreter_->inputs();
RET_CHECK_EQ(input_indices.size(), 1);
const TfLiteTensor* tensor = interpreter_->tensor(input_indices[0]);
gpu_data_in_->elements = 1;
// On iOS GPU, input must be 4 channels, regardless of what model expects.
{
gpu_data_in_->elements *= tensor->dims->data[0]; // batch
gpu_data_in_->elements *= tensor->dims->data[1]; // height
gpu_data_in_->elements *= tensor->dims->data[2]; // width
gpu_data_in_->elements *= 4; // channels
}
// Input to model can be RGBA only.
if (tensor->dims->data[3] != 4) {
LOG(WARNING) << "Please ensure input GPU tensor is 4 channels.";
}
// Create and bind input buffer.
id<MTLDevice> device = gpu_helper_.mtlDevice;
gpu_data_in_->buffer =
[device newBufferWithLength:gpu_data_in_->elements * sizeof(float)
options:MTLResourceStorageModeShared];
// Must call this before TFLGpuDelegateBindMetalBufferToTensor.
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_), kTfLiteOk);
RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor(
delegate_,
input_indices[0], // First tensor only
gpu_data_in_->buffer),
true);
}
if (gpu_output_) {
// Get output image sizes.
const auto& output_indices = interpreter_->outputs();
gpu_data_out_.resize(output_indices.size());
for (int i = 0; i < gpu_data_out_.size(); ++i) {
const TfLiteTensor* tensor = interpreter_->tensor(output_indices[i]);
gpu_data_out_[i] = absl::make_unique<GPUData>();
gpu_data_out_[i]->elements = 1;
// TODO handle *2 properly on some dialated models
for (int d = 0; d < tensor->dims->size; ++d) {
gpu_data_out_[i]->elements *= tensor->dims->data[d];
}
}
// Create and bind output buffers.
interpreter_->SetAllowBufferHandleOutput(true);
id<MTLDevice> device = gpu_helper_.mtlDevice;
for (int i = 0; i < gpu_data_out_.size(); ++i) {
gpu_data_out_[i]->buffer =
[device newBufferWithLength:gpu_data_out_[i]->elements * sizeof(float)
options:MTLResourceStorageModeShared];
RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor(
delegate_, output_indices[i], gpu_data_out_[i]->buffer),
true);
}
}
#endif // iOS
return ::mediapipe::OkStatus();
}

View File

@ -120,11 +120,19 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase {
::mediapipe::Status Close(CalculatorContext* cc) override;
private:
::mediapipe::Status ProcessCPU(CalculatorContext* cc,
std::vector<Detection>* output_detections);
::mediapipe::Status ProcessGPU(CalculatorContext* cc,
std::vector<Detection>* output_detections);
::mediapipe::Status LoadOptions(CalculatorContext* cc);
::mediapipe::Status GlSetup(CalculatorContext* cc);
::mediapipe::Status DecodeBoxes(const float* raw_boxes,
const std::vector<Anchor>& anchors,
std::vector<float>* boxes);
::mediapipe::Status ConvertToDetections(
const float* detection_boxes, const float* detection_scores,
const int* detection_classes, std::vector<Detection>* output_detections);
Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax,
float box_xmax, float score, int class_id,
bool flip_vertically);
@ -136,6 +144,7 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase {
::mediapipe::TfLiteTensorsToDetectionsCalculatorOptions options_;
std::vector<Anchor> anchors_;
bool side_packet_anchors_{};
#if defined(__ANDROID__)
mediapipe::GlCalculatorHelper gpu_helper_;
@ -187,6 +196,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Open(
CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
if (cc->Inputs().HasTag("TENSORS_GPU")) {
gpu_input_ = true;
#if defined(__ANDROID__)
@ -195,6 +206,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
}
RETURN_IF_ERROR(LoadOptions(cc));
side_packet_anchors_ = cc->InputSidePackets().HasTag("ANCHORS");
if (gpu_input_) {
RETURN_IF_ERROR(GlSetup(cc));
@ -210,72 +222,32 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
return ::mediapipe::OkStatus();
}
const bool side_packet_anchors =
cc->InputSidePackets().HasTag("ANCHORS") &&
!cc->InputSidePackets().Tag("ANCHORS").IsEmpty();
auto output_detections = absl::make_unique<std::vector<Detection>>();
std::vector<float> boxes(num_boxes_ * num_coords_);
std::vector<float> score_class_id_pairs(num_boxes_ * 2);
if (gpu_input_) {
#if defined(__ANDROID__)
const auto& input_tensors =
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GlBuffer>>();
// Copy inputs.
tflite::gpu::gl::CopyBuffer(input_tensors[0], *raw_boxes_buffer_.get());
tflite::gpu::gl::CopyBuffer(input_tensors[1], *raw_scores_buffer_.get());
if (!anchors_init_) {
if (side_packet_anchors) {
const auto& anchors =
cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>();
std::vector<float> raw_anchors(num_boxes_ * kNumCoordsPerBox);
ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors.data());
raw_anchors_buffer_->Write<float>(absl::MakeSpan(raw_anchors));
} else {
CHECK_EQ(input_tensors.size(), 3);
tflite::gpu::gl::CopyBuffer(input_tensors[2],
*raw_anchors_buffer_.get());
}
anchors_init_ = true;
}
// Run shaders.
RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &input_tensors]() -> ::mediapipe::Status {
// Decode boxes.
decoded_boxes_buffer_->BindToIndex(0);
raw_boxes_buffer_->BindToIndex(1);
raw_anchors_buffer_->BindToIndex(2);
const tflite::gpu::uint3 decode_workgroups = {num_boxes_, 1, 1};
decode_program_->Dispatch(decode_workgroups);
// Score boxes.
scored_boxes_buffer_->BindToIndex(0);
raw_scores_buffer_->BindToIndex(1);
const tflite::gpu::uint3 score_workgroups = {num_boxes_, 1, 1};
score_program_->Dispatch(score_workgroups);
return ::mediapipe::OkStatus();
}));
// Copy decoded boxes from GPU to CPU.
auto status = decoded_boxes_buffer_->Read(absl::MakeSpan(boxes));
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
status = scored_boxes_buffer_->Read(absl::MakeSpan(score_class_id_pairs));
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
#else
LOG(ERROR) << "GPU input on non-Android not supported yet.";
#endif // defined(__ANDROID__)
RETURN_IF_ERROR(ProcessGPU(cc, output_detections.get()));
} else {
const auto& input_tensors =
cc->Inputs().Tag("TENSORS").Get<std::vector<TfLiteTensor>>();
RETURN_IF_ERROR(ProcessCPU(cc, output_detections.get()));
} // if gpu_input_
// Output
if (cc->Outputs().HasTag("DETECTIONS")) {
cc->Outputs()
.Tag("DETECTIONS")
.Add(output_detections.release(), cc->InputTimestamp());
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessCPU(
CalculatorContext* cc, std::vector<Detection>* output_detections) {
const auto& input_tensors =
cc->Inputs().Tag("TENSORS").Get<std::vector<TfLiteTensor>>();
if (input_tensors.size() == 2) {
// Postprocessing on CPU for model without postprocessing op. E.g. output
// raw score tensor and box tensor. Anchor decoding will be handled below.
const TfLiteTensor* raw_box_tensor = &input_tensors[0];
const TfLiteTensor* raw_score_tensor = &input_tensors[1];
@ -300,7 +272,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
CHECK_EQ(anchor_tensor->dims->data[1], kNumCoordsPerBox);
const float* raw_anchors = anchor_tensor->data.f;
ConvertRawValuesToAnchors(raw_anchors, num_boxes_, &anchors_);
} else if (side_packet_anchors) {
} else if (side_packet_anchors_) {
CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty());
anchors_ =
cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>();
} else {
@ -308,8 +281,12 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
}
anchors_init_ = true;
}
std::vector<float> boxes(num_boxes_ * num_coords_);
RETURN_IF_ERROR(DecodeBoxes(raw_boxes, anchors_, &boxes));
std::vector<float> detection_scores(num_boxes_);
std::vector<int> detection_classes(num_boxes_);
// Filter classes by scores.
for (int i = 0; i < num_boxes_; ++i) {
int class_id = -1;
@ -335,44 +312,119 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
}
}
}
score_class_id_pairs[i * 2 + 0] = max_score;
score_class_id_pairs[i * 2 + 1] = class_id;
detection_scores[i] = max_score;
detection_classes[i] = class_id;
}
} // if gpu_input_
// Convert to Detection.
RETURN_IF_ERROR(ConvertToDetections(boxes.data(), detection_scores.data(),
detection_classes.data(),
output_detections));
} else {
// Postprocessing on CPU with postprocessing op (e.g. anchor decoding and
// non-maximum suppression) within the model.
RET_CHECK_EQ(input_tensors.size(), 4);
const TfLiteTensor* detection_boxes_tensor = &input_tensors[0];
const TfLiteTensor* detection_classes_tensor = &input_tensors[1];
const TfLiteTensor* detection_scores_tensor = &input_tensors[2];
const TfLiteTensor* num_boxes_tensor = &input_tensors[3];
RET_CHECK_EQ(num_boxes_tensor->dims->size, 1);
RET_CHECK_EQ(num_boxes_tensor->dims->data[0], 1);
const float* num_boxes = num_boxes_tensor->data.f;
num_boxes_ = num_boxes[0];
RET_CHECK_EQ(detection_boxes_tensor->dims->size, 3);
RET_CHECK_EQ(detection_boxes_tensor->dims->data[0], 1);
const int max_detections = detection_boxes_tensor->dims->data[1];
RET_CHECK_EQ(detection_boxes_tensor->dims->data[2], num_coords_);
RET_CHECK_EQ(detection_classes_tensor->dims->size, 2);
RET_CHECK_EQ(detection_classes_tensor->dims->data[0], 1);
RET_CHECK_EQ(detection_classes_tensor->dims->data[1], max_detections);
RET_CHECK_EQ(detection_scores_tensor->dims->size, 2);
RET_CHECK_EQ(detection_scores_tensor->dims->data[0], 1);
RET_CHECK_EQ(detection_scores_tensor->dims->data[1], max_detections);
const float* detection_boxes = detection_boxes_tensor->data.f;
const float* detection_scores = detection_scores_tensor->data.f;
std::vector<int> detection_classes(num_boxes_);
for (int i = 0; i < num_boxes_; ++i) {
detection_classes[i] =
static_cast<int>(detection_classes_tensor->data.f[i]);
}
RETURN_IF_ERROR(ConvertToDetections(detection_boxes, detection_scores,
detection_classes.data(),
output_detections));
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU(
CalculatorContext* cc, std::vector<Detection>* output_detections) {
#if defined(__ANDROID__)
const auto& input_tensors =
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GlBuffer>>();
// Copy inputs.
tflite::gpu::gl::CopyBuffer(input_tensors[0], *raw_boxes_buffer_.get());
tflite::gpu::gl::CopyBuffer(input_tensors[1], *raw_scores_buffer_.get());
if (!anchors_init_) {
if (side_packet_anchors_) {
CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty());
const auto& anchors =
cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>();
std::vector<float> raw_anchors(num_boxes_ * kNumCoordsPerBox);
ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors.data());
raw_anchors_buffer_->Write<float>(absl::MakeSpan(raw_anchors));
} else {
CHECK_EQ(input_tensors.size(), 3);
tflite::gpu::gl::CopyBuffer(input_tensors[2], *raw_anchors_buffer_.get());
}
anchors_init_ = true;
}
// Run shaders.
RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &input_tensors]() -> ::mediapipe::Status {
// Decode boxes.
decoded_boxes_buffer_->BindToIndex(0);
raw_boxes_buffer_->BindToIndex(1);
raw_anchors_buffer_->BindToIndex(2);
const tflite::gpu::uint3 decode_workgroups = {num_boxes_, 1, 1};
decode_program_->Dispatch(decode_workgroups);
// Score boxes.
scored_boxes_buffer_->BindToIndex(0);
raw_scores_buffer_->BindToIndex(1);
const tflite::gpu::uint3 score_workgroups = {num_boxes_, 1, 1};
score_program_->Dispatch(score_workgroups);
return ::mediapipe::OkStatus();
}));
// Copy decoded boxes from GPU to CPU.
std::vector<float> boxes(num_boxes_ * num_coords_);
auto status = decoded_boxes_buffer_->Read(absl::MakeSpan(boxes));
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
std::vector<float> score_class_id_pairs(num_boxes_ * 2);
status = scored_boxes_buffer_->Read(absl::MakeSpan(score_class_id_pairs));
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
// TODO: b/138851969. Is it possible to output a float vector
// for score and an int vector for class so that we can avoid copying twice?
std::vector<float> detection_scores(num_boxes_);
std::vector<int> detection_classes(num_boxes_);
for (int i = 0; i < num_boxes_; ++i) {
const float score = score_class_id_pairs[i * 2 + 0];
const int class_id = score_class_id_pairs[i * 2 + 1];
const int box_offset = i * num_coords_;
Detection detection = ConvertToDetection(
boxes[box_offset + 0], boxes[box_offset + 1], boxes[box_offset + 2],
boxes[box_offset + 3], score, class_id, options_.flip_vertically());
// Add keypoints.
if (options_.num_keypoints() > 0) {
auto* location_data = detection.mutable_location_data();
for (int kp_id = 0; kp_id < options_.num_keypoints() *
options_.num_values_per_keypoint();
kp_id += options_.num_values_per_keypoint()) {
auto keypoint = location_data->add_relative_keypoints();
const int keypoint_index =
box_offset + options_.keypoint_coord_offset() + kp_id;
keypoint->set_x(boxes[keypoint_index + 0]);
keypoint->set_y(options_.flip_vertically()
? 1.f - boxes[keypoint_index + 1]
: boxes[keypoint_index + 1]);
}
}
output_detections->emplace_back(detection);
detection_scores[i] = score_class_id_pairs[i * 2];
detection_classes[i] = static_cast<int>(score_class_id_pairs[i * 2 + 1]);
}
// Output
if (cc->Outputs().HasTag("DETECTIONS")) {
cc->Outputs()
.Tag("DETECTIONS")
.Add(output_detections.release(), cc->InputTimestamp());
}
RETURN_IF_ERROR(ConvertToDetections(boxes.data(), detection_scores.data(),
detection_classes.data(),
output_detections));
#else
LOG(ERROR) << "GPU input on non-Android not supported yet.";
#endif // defined(__ANDROID__)
return ::mediapipe::OkStatus();
}
@ -481,6 +533,39 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
return ::mediapipe::OkStatus();
}
// Converts flat box / score / class arrays into Detection protos, dropping
// boxes whose score falls below the (optional) min_score_thresh option.
// Keypoints stored after the box coordinates are appended to each detection.
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ConvertToDetections(
    const float* detection_boxes, const float* detection_scores,
    const int* detection_classes, std::vector<Detection>* output_detections) {
  const bool has_score_thresh = options_.has_min_score_thresh();
  const int values_per_keypoint = options_.num_values_per_keypoint();
  for (int box = 0; box < num_boxes_; ++box) {
    const float score = detection_scores[box];
    // Skip low-confidence detections when a threshold is configured.
    if (has_score_thresh && score < options_.min_score_thresh()) {
      continue;
    }
    const int box_offset = box * num_coords_;
    Detection detection = ConvertToDetection(
        detection_boxes[box_offset + 0], detection_boxes[box_offset + 1],
        detection_boxes[box_offset + 2], detection_boxes[box_offset + 3],
        score, detection_classes[box], options_.flip_vertically());
    if (options_.num_keypoints() > 0) {
      auto* location_data = detection.mutable_location_data();
      const int total_kp_values =
          options_.num_keypoints() * values_per_keypoint;
      for (int kp = 0; kp < total_kp_values; kp += values_per_keypoint) {
        const int kp_index = box_offset + options_.keypoint_coord_offset() + kp;
        auto* keypoint = location_data->add_relative_keypoints();
        keypoint->set_x(detection_boxes[kp_index + 0]);
        const float y = detection_boxes[kp_index + 1];
        // Flip y to a bottom-left origin when requested (e.g. for OpenGL).
        keypoint->set_y(options_.flip_vertically() ? 1.f - y : y);
      }
    }
    output_detections->emplace_back(detection);
  }
  return ::mediapipe::OkStatus();
}
Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection(
float box_ymin, float box_xmin, float box_ymax, float box_xmax, float score,
int class_id, bool flip_vertically) {
@ -628,7 +713,7 @@ void main() {
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
size_t raw_anchors_length = num_boxes_ * num_coords_;
size_t raw_anchors_length = num_boxes_ * kNumCoordsPerBox;
raw_anchors_buffer_ = absl::make_unique<GlBuffer>();
status = CreateReadWriteShaderStorageBuffer<float>(raw_anchors_length,
raw_anchors_buffer_.get());

View File

@ -68,4 +68,7 @@ message TfLiteTensorsToDetectionsCalculatorOptions {
// the origin is at the top-left corner, whereas the desired detection
// representation has a bottom-left origin (e.g., in OpenGL).
optional bool flip_vertically = 18 [default = false];
// Score threshold for preserving decoded detections.
optional float min_score_thresh = 19;
}

View File

@ -0,0 +1,103 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/interpreter.h"
namespace mediapipe {
// A calculator for converting TFLite tensors to a float or a float vector.
//
// Input:
// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32. Only the first
// tensor will be used.
// Output:
// FLOAT(optional) - Converted single float number.
// FLOATS(optional) - Converted float vector.
//
// Notes: To output FLOAT stream, the input TFLite tensor must have size 1, e.g.
// only 1 float number in the tensor.
//
// Usage example:
// node {
// calculator: "TfLiteTensorsToFloatsCalculator"
// input_stream: "TENSORS:tensors"
// output_stream: "FLOATS:floats"
// }
// Calculator implementation; see the usage comment above for the I/O contract
// (TENSORS input; FLOAT and/or FLOATS outputs).
class TfLiteTensorsToFloatsCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc);
  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;
};
REGISTER_CALCULATOR(TfLiteTensorsToFloatsCalculator);
// Declares the calculator's I/O contract: a TENSORS input is required, and at
// least one of the FLOAT / FLOATS outputs must be connected.
::mediapipe::Status TfLiteTensorsToFloatsCalculator::GetContract(
    CalculatorContract* cc) {
  auto& inputs = cc->Inputs();
  auto& outputs = cc->Outputs();

  RET_CHECK(inputs.HasTag("TENSORS"));
  RET_CHECK(outputs.HasTag("FLOATS") || outputs.HasTag("FLOAT"));

  inputs.Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
  if (outputs.HasTag("FLOAT")) {
    outputs.Tag("FLOAT").Set<float>();
  }
  if (outputs.HasTag("FLOATS")) {
    outputs.Tag("FLOATS").Set<std::vector<float>>();
  }
  return ::mediapipe::OkStatus();
}
::mediapipe::Status TfLiteTensorsToFloatsCalculator::Open(
    CalculatorContext* cc) {
  // Zero offset: each output packet carries its input packet's timestamp.
  cc->SetOffset(TimestampDiff(0));

  return ::mediapipe::OkStatus();
}
// Converts the first input tensor into a single float (FLOAT output, only
// valid when the tensor holds exactly one value) and/or a float vector
// (FLOATS output). Assumes the tensor is kTfLiteFloat32.
::mediapipe::Status TfLiteTensorsToFloatsCalculator::Process(
    CalculatorContext* cc) {
  RET_CHECK(!cc->Inputs().Tag("TENSORS").IsEmpty());

  const auto& input_tensors =
      cc->Inputs().Tag("TENSORS").Get<std::vector<TfLiteTensor>>();
  // Guard against an empty tensor vector before dereferencing element 0;
  // the RET_CHECK above only verifies the packet itself is non-empty.
  RET_CHECK(!input_tensors.empty());
  // TODO: Add option to specify which tensor to take from.
  const TfLiteTensor* raw_tensor = &input_tensors[0];
  const float* raw_floats = raw_tensor->data.f;
  int num_values = 1;
  for (int i = 0; i < raw_tensor->dims->size; ++i) {
    RET_CHECK_GT(raw_tensor->dims->data[i], 0);
    num_values *= raw_tensor->dims->data[i];
  }

  if (cc->Outputs().HasTag("FLOAT")) {
    // TODO: Could add an index in the option to specify returning one
    // value of a float array.
    RET_CHECK_EQ(num_values, 1);
    cc->Outputs().Tag("FLOAT").AddPacket(
        MakePacket<float>(raw_floats[0]).At(cc->InputTimestamp()));
  }
  if (cc->Outputs().HasTag("FLOATS")) {
    auto output_floats = absl::make_unique<std::vector<float>>(
        raw_floats, raw_floats + num_values);
    cc->Outputs().Tag("FLOATS").Add(output_floats.release(),
                                    cc->InputTimestamp());
  }

  return ::mediapipe::OkStatus();
}
} // namespace mediapipe

View File

@ -0,0 +1,188 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/interpreter.h"
namespace mediapipe {
// A calculator for converting TFLite tensors from regression models into
// landmarks.
//
// Input:
// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32. Only the first
// tensor will be used. The size of the values must be
// (num_dimension x num_landmarks).
// Output:
// LANDMARKS(optional) - Result MediaPipe landmarks.
// NORM_LANDMARKS(optional) - Result MediaPipe normalized landmarks.
//
// Notes:
// To output normalized landmarks, user must provide the original input image
// size to the model using calculator option input_image_width and
// input_image_height.
// Usage example:
// node {
// calculator: "TfLiteTensorsToLandmarksCalculator"
// input_stream: "TENSORS:landmark_tensors"
// output_stream: "LANDMARKS:landmarks"
// output_stream: "NORM_LANDMARKS:landmarks"
// options: {
// [mediapipe.TfLiteTensorsToLandmarksCalculatorOptions.ext] {
// num_landmarks: 21
//
// input_image_width: 256
// input_image_height: 256
// }
// }
// }
// Calculator implementation; see the usage comment above for the I/O contract
// (TENSORS input; LANDMARKS and/or NORM_LANDMARKS outputs).
class TfLiteTensorsToLandmarksCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc);

  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;

 private:
  // Reads the calculator options from the graph config into options_.
  ::mediapipe::Status LoadOptions(CalculatorContext* cc);

  // Number of landmarks expected in the input tensor; set from options_.
  int num_landmarks_ = 0;

  ::mediapipe::TfLiteTensorsToLandmarksCalculatorOptions options_;
};
REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator);
// Declares the calculator contract: both input and output sides must carry at
// least one tag; TENSORS, LANDMARKS and NORM_LANDMARKS are each optional.
::mediapipe::Status TfLiteTensorsToLandmarksCalculator::GetContract(
    CalculatorContract* cc) {
  auto& inputs = cc->Inputs();
  auto& outputs = cc->Outputs();

  RET_CHECK(!inputs.GetTags().empty());
  RET_CHECK(!outputs.GetTags().empty());

  if (inputs.HasTag("TENSORS")) {
    inputs.Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
  }
  if (outputs.HasTag("NORM_LANDMARKS")) {
    outputs.Tag("NORM_LANDMARKS").Set<std::vector<NormalizedLandmark>>();
  }
  if (outputs.HasTag("LANDMARKS")) {
    outputs.Tag("LANDMARKS").Set<std::vector<Landmark>>();
  }
  return ::mediapipe::OkStatus();
}
// Loads options and validates that the image dimensions needed for normalized
// output or vertical flipping were provided.
// Fix: the RET_CHECK error messages said "input with/height" and "outputing";
// corrected to "input width/height" and "outputting".
::mediapipe::Status TfLiteTensorsToLandmarksCalculator::Open(
    CalculatorContext* cc) {
  // Zero offset: each output packet carries its input packet's timestamp.
  cc->SetOffset(TimestampDiff(0));

  RETURN_IF_ERROR(LoadOptions(cc));

  if (cc->Outputs().HasTag("NORM_LANDMARKS")) {
    // Normalization divides by the input image size, so it must be known.
    RET_CHECK(options_.has_input_image_height() &&
              options_.has_input_image_width())
        << "Must provide input width/height for getting normalized landmarks.";
  }
  if (cc->Outputs().HasTag("LANDMARKS") && options_.flip_vertically()) {
    // Absolute-coordinate flipping computes height - y, so size is required.
    RET_CHECK(options_.has_input_image_height() &&
              options_.has_input_image_width())
        << "Must provide input width/height for using flip_vertically option "
           "when outputting landmarks in absolute coordinates.";
  }

  return ::mediapipe::OkStatus();
}
// Converts the first input tensor into absolute landmarks (LANDMARKS) and/or
// normalized landmarks (NORM_LANDMARKS). The tensor is interpreted as
// num_landmarks_ rows of num_dimensions values (x, optionally y, optionally z).
// Fix: guard against an empty tensor vector before indexing [0], and against
// num_landmarks_ == 0, which would cause an integer division by zero below.
::mediapipe::Status TfLiteTensorsToLandmarksCalculator::Process(
    CalculatorContext* cc) {
  // Passthrough on an empty input: emit nothing for this timestamp.
  if (cc->Inputs().Tag("TENSORS").IsEmpty()) {
    return ::mediapipe::OkStatus();
  }

  const auto& input_tensors =
      cc->Inputs().Tag("TENSORS").Get<std::vector<TfLiteTensor>>();
  RET_CHECK(!input_tensors.empty());
  RET_CHECK_GT(num_landmarks_, 0);

  const TfLiteTensor* raw_tensor = &input_tensors[0];

  int num_values = 1;
  for (int i = 0; i < raw_tensor->dims->size; ++i) {
    num_values *= raw_tensor->dims->data[i];
  }
  const int num_dimensions = num_values / num_landmarks_;
  // Landmarks must have at most 3 dimensions (x, y, z). Otherwise please
  // consider using matrix.
  CHECK_LE(num_dimensions, 3);
  CHECK_GT(num_dimensions, 0);

  const float* raw_landmarks = raw_tensor->data.f;

  auto output_landmarks = absl::make_unique<std::vector<Landmark>>();

  for (int ld = 0; ld < num_landmarks_; ++ld) {
    const int offset = ld * num_dimensions;
    Landmark landmark;
    landmark.set_x(raw_landmarks[offset]);
    if (num_dimensions > 1) {
      if (options_.flip_vertically()) {
        // Flip to a bottom-left origin using the known input image height.
        landmark.set_y(options_.input_image_height() -
                       raw_landmarks[offset + 1]);
      } else {
        landmark.set_y(raw_landmarks[offset + 1]);
      }
    }
    if (num_dimensions > 2) {
      landmark.set_z(raw_landmarks[offset + 2]);
    }
    output_landmarks->push_back(landmark);
  }

  // Output normalized landmarks if required.
  if (cc->Outputs().HasTag("NORM_LANDMARKS")) {
    auto output_norm_landmarks =
        absl::make_unique<std::vector<NormalizedLandmark>>();
    for (const auto& landmark : *output_landmarks) {
      NormalizedLandmark norm_landmark;
      norm_landmark.set_x(static_cast<float>(landmark.x()) /
                          options_.input_image_width());
      norm_landmark.set_y(static_cast<float>(landmark.y()) /
                          options_.input_image_height());
      norm_landmark.set_z(landmark.z() / options_.normalize_z());
      output_norm_landmarks->push_back(norm_landmark);
    }
    cc->Outputs()
        .Tag("NORM_LANDMARKS")
        .Add(output_norm_landmarks.release(), cc->InputTimestamp());
  }

  // Output absolute landmarks.
  if (cc->Outputs().HasTag("LANDMARKS")) {
    cc->Outputs()
        .Tag("LANDMARKS")
        .Add(output_landmarks.release(), cc->InputTimestamp());
  }

  return ::mediapipe::OkStatus();
}
// Caches the graph-specified options and the expected landmark count so that
// Process() does not re-read them on every packet.
::mediapipe::Status TfLiteTensorsToLandmarksCalculator::LoadOptions(
    CalculatorContext* cc) {
  // Get calculator options specified in the graph.
  options_ =
      cc->Options<::mediapipe::TfLiteTensorsToLandmarksCalculatorOptions>();
  num_landmarks_ = options_.num_landmarks();

  return ::mediapipe::OkStatus();
}
} // namespace mediapipe

View File

@ -0,0 +1,45 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The option proto for the TfLiteTensorsToLandmarksCalculator.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message TfLiteTensorsToLandmarksCalculatorOptions {
  extend .mediapipe.CalculatorOptions {
    optional TfLiteTensorsToLandmarksCalculatorOptions ext = 257405002;
  }

  // Number of landmarks from the output of the model.
  required int32 num_landmarks = 1;

  // Size of the input image for the model. These options are used only when
  // normalized landmarks are needed.
  optional int32 input_image_width = 2;
  optional int32 input_image_height = 3;

  // Whether the landmark coordinates from the input tensors should be flipped
  // vertically (along the y-direction). This is useful, for example, when the
  // input tensors represent landmarks defined with a coordinate system where
  // the origin is at the top-left corner, whereas the desired landmark
  // representation has a bottom-left origin (e.g., in OpenGL).
  optional bool flip_vertically = 4 [default = false];

  // A value that z values should be divided by.
  optional float normalize_z = 5 [default = 1.0];
}

View File

@ -20,6 +20,9 @@
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/resource_util.h"
#include "tensorflow/lite/interpreter.h"
@ -39,8 +42,11 @@ namespace {
constexpr int kWorkgroupSize = 8; // Block size for GPU shader.
enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
// Commonly used to compute the number of blocks to launch in a kernel.
int RoundUp(const int size, const int multiple) {
return (size + multiple - 1) / multiple;
int NumGroups(const int size, const int group_size) { // NOLINT
return (size + group_size - 1) / group_size;
}
// Restricts `val` to the inclusive range [min, max].
float Clamp(float val, float min, float max) {
  if (val < min) return min;
  if (val > max) return max;
  return val;
}
} // namespace
@ -58,7 +64,7 @@ using ::tflite::gpu::gl::GlShader;
// Performs optional upscale to REFERENCE_IMAGE dimensions if provided,
// otherwise the mask is the same size as input tensor.
//
// Note: This calculator is currently GPU only, so only *_GPU tags can be used.
// Produces result as an RGBA image, with the mask in both R & A channels.
//
// Inputs:
// One of the following TENSORS tags:
@ -71,11 +77,11 @@ using ::tflite::gpu::gl::GlShader;
// REFERENCE_IMAGE_GPU (optional): A GpuBuffer input image,
// used only for output dimensions.
// One of the following PREV_MASK tags:
// PREV_MASK (optional): An ImageFrame input mask, Gray, RGB or RGBA.
// PREV_MASK_GPU (optional): A GpuBuffer input mask, RGBA.
// PREV_MASK (optional): An ImageFrame input mask, Gray, RGB or RGBA, [0-255].
// PREV_MASK_GPU (optional): A GpuBuffer input mask, RGBA, [0-1].
// Output:
// One of the following MASK tags:
// MASK: An ImageFrame output mask, Gray, RGB or RGBA.
// MASK: An ImageFrame output mask, RGBA.
// MASK_GPU: A GpuBuffer output mask, RGBA.
//
// Options:
@ -141,10 +147,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
}
if (cc->Inputs().HasTag("PREV_MASK")) {
cc->Inputs().Tag("PREV_MASK").Set<mediapipe::ImageFrame>();
cc->Inputs().Tag("PREV_MASK").Set<ImageFrame>();
}
if (cc->Inputs().HasTag("REFERENCE_IMAGE")) {
cc->Inputs().Tag("REFERENCE_IMAGE").Set<mediapipe::ImageFrame>();
cc->Inputs().Tag("REFERENCE_IMAGE").Set<ImageFrame>();
}
// Inputs GPU.
@ -162,7 +168,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
// Outputs.
if (cc->Outputs().HasTag("MASK")) {
cc->Outputs().Tag("MASK").Set<mediapipe::ImageFrame>();
cc->Outputs().Tag("MASK").Set<ImageFrame>();
}
#if defined(__ANDROID__)
if (cc->Outputs().HasTag("MASK_GPU")) {
@ -179,6 +185,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Open(
CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
if (cc->Inputs().HasTag("TENSORS_GPU")) {
use_gpu_ = true;
#if defined(__ANDROID__)
@ -238,7 +246,107 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
::mediapipe::Status TfLiteTensorsToSegmentationCalculator::ProcessCpu(
CalculatorContext* cc) {
return ::mediapipe::UnimplementedError("CPU support is not implemented yet.");
if (cc->Inputs().Tag("TENSORS").IsEmpty()) {
return ::mediapipe::OkStatus();
}
// Get input streams.
const auto& input_tensors =
cc->Inputs().Tag("TENSORS").Get<std::vector<TfLiteTensor>>();
const bool has_prev_mask = cc->Inputs().HasTag("PREV_MASK") &&
!cc->Inputs().Tag("PREV_MASK").IsEmpty();
const ImageFrame placeholder;
const auto& input_mask = has_prev_mask
? cc->Inputs().Tag("PREV_MASK").Get<ImageFrame>()
: placeholder;
int output_width = tensor_width_, output_height = tensor_height_;
if (cc->Inputs().HasTag("REFERENCE_IMAGE")) {
const auto& input_image =
cc->Inputs().Tag("REFERENCE_IMAGE").Get<ImageFrame>();
output_width = input_image.Width();
output_height = input_image.Height();
}
RET_CHECK_EQ(input_tensors.size(), 1);
// Create initial working mask.
cv::Mat small_mask_mat(cv::Size(tensor_width_, tensor_height_), CV_8UC4);
// Get input previous mask.
cv::Mat input_mask_mat;
if (has_prev_mask) {
cv::Mat temp_mask_mat = formats::MatView(&input_mask);
if (temp_mask_mat.channels() != 4) {
cv::Mat converted_mat;
cv::cvtColor(temp_mask_mat, converted_mat,
temp_mask_mat.channels() == 1 ? cv::COLOR_GRAY2RGBA
: cv::COLOR_RGB2RGBA);
temp_mask_mat = converted_mat.clone();
}
cv::resize(temp_mask_mat, input_mask_mat, small_mask_mat.size());
}
// Copy input tensor.
const TfLiteTensor* raw_input_tensor = &input_tensors[0];
const float* raw_input_data = raw_input_tensor->data.f;
cv::Mat tensor_mat(cv::Size(tensor_width_, tensor_height_),
CV_MAKETYPE(CV_32F, tensor_channels_));
float* tensor_mat_ptr = tensor_mat.ptr<float>();
memcpy(tensor_mat_ptr, raw_input_data, raw_input_tensor->bytes);
// Process mask tensor.
// Run softmax over tensor output and blend with previous mask.
const int output_layer_index = options_.output_layer_index();
const float combine_with_prev_ratio = options_.combine_with_previous_ratio();
for (int i = 0; i < tensor_height_; ++i) {
for (int j = 0; j < tensor_width_; ++j) {
// Only two channel input tensor is supported.
const cv::Vec2f input_pix = tensor_mat.at<cv::Vec2f>(i, j);
const float shift = std::max(input_pix[0], input_pix[1]);
const float softmax_denom =
std::exp(input_pix[0] - shift) + std::exp(input_pix[1] - shift);
float new_mask_value =
std::exp(input_pix[output_layer_index] - shift) / softmax_denom;
// Combine previous value with current using uncertainty^2 as mixing coeff
if (has_prev_mask) {
const float prev_mask_value =
input_mask_mat.at<cv::Vec4b>(i, j)[0] / 255.0f;
const float eps = 0.001;
float uncertainty_alpha =
1.0 +
(new_mask_value * std::log(new_mask_value + eps) +
(1.0 - new_mask_value) * std::log(1.0 - new_mask_value + eps)) /
std::log(2.0f);
uncertainty_alpha = Clamp(uncertainty_alpha, 0.0f, 1.0f);
// Equivalent to: a = 1 - (1 - a) * (1 - a); (squaring the uncertainty)
uncertainty_alpha *= 2.0 - uncertainty_alpha;
const float mixed_mask_value =
new_mask_value * uncertainty_alpha +
prev_mask_value * (1.0f - uncertainty_alpha);
new_mask_value = mixed_mask_value * combine_with_prev_ratio +
(1.0f - combine_with_prev_ratio) * new_mask_value;
}
const uchar mask_value = static_cast<uchar>(new_mask_value * 255);
// Set both R and A channels for convenience.
const cv::Vec4b out_value = {mask_value, 0, 0, mask_value};
small_mask_mat.at<cv::Vec4b>(i, j) = out_value;
}
}
if (options_.flip_vertically()) cv::flip(small_mask_mat, small_mask_mat, 0);
// Upsample small mask into output.
cv::Mat large_mask_mat;
cv::resize(small_mask_mat, large_mask_mat,
cv::Size(output_width, output_height));
// Send out image as CPU packet.
std::unique_ptr<ImageFrame> output_mask = absl::make_unique<ImageFrame>(
ImageFormat::SRGBA, output_width, output_height);
cv::Mat output_mat = formats::MatView(output_mask.get());
large_mask_mat.copyTo(output_mat);
cc->Outputs().Tag("MASK").Add(output_mask.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
// Steps:
@ -267,10 +375,9 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
output_width = input_image.width();
output_height = input_image.height();
}
RET_CHECK_EQ(input_tensors.size(), 1);
// Create initial output mask texture.
// Create initial working mask texture.
::tflite::gpu::gl::GlTexture small_mask_texture;
::tflite::gpu::gl::CreateReadWriteRgbaImageTexture(
tflite::gpu::DataType::UINT8, // GL_RGBA8
@ -285,6 +392,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
tflite::gpu::gl::CopyBuffer(input_tensors[0], *tensor_buffer_);
// Run shader, process mask tensor.
// Run softmax over tensor output and blend with previous mask.
{
const int output_index = 0;
glBindImageTexture(output_index, small_mask_texture.id(), 0, GL_FALSE, 0,
@ -292,8 +400,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
tensor_buffer_->BindToIndex(2);
const tflite::gpu::uint3 workgroups = {
RoundUp(tensor_width_, kWorkgroupSize),
RoundUp(tensor_height_, kWorkgroupSize), 1};
NumGroups(tensor_width_, kWorkgroupSize),
NumGroups(tensor_height_, kWorkgroupSize), 1};
if (!has_prev_mask) {
mask_program_no_prev_->Dispatch(workgroups);
@ -330,8 +438,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
// Cleanup
input_mask_texture.Release();
output_texture.Release();
#endif // __ANDROID__
return ::mediapipe::OkStatus();
}
@ -445,7 +553,7 @@ void main() {
int linear_index = gid.y * out_width + gid.x;
vec2 input_value = input_data.elements[linear_index];
// Only two channel output is supported.
// Only two channel input tensor is supported.
vec2 input_px = input_value.rg;
float shift = max(input_px.r, input_px.g);
float softmax_denom = exp(input_px.r - shift) + exp(input_px.g - shift);
@ -474,8 +582,8 @@ void main() {
(1.0f - combine_with_previous_ratio) * new_mask_value;
#endif // READ_PREVIOUS
// Texture coordinates are inverted on y axis.
ivec2 output_coordinate = ivec2(gid.x, out_height - gid.y - 1);
int y_coord = int($4);
ivec2 output_coordinate = ivec2(gid.x, y_coord);
// Set both R and A channels for convenience.
vec4 out_value = vec4(new_mask_value, 0.0, 0.0, new_mask_value);
imageStore(output_texture, output_coordinate, out_value);
@ -483,10 +591,12 @@ void main() {
const std::string shader_src_no_previous = absl::Substitute(
shader_src_template, kWorkgroupSize, options_.output_layer_index(),
options_.combine_with_previous_ratio(), "");
options_.combine_with_previous_ratio(), "",
options_.flip_vertically() ? "out_height - gid.y - 1" : "gid.y");
const std::string shader_src_with_previous = absl::Substitute(
shader_src_template, kWorkgroupSize, options_.output_layer_index(),
options_.combine_with_previous_ratio(), "#define READ_PREVIOUS");
options_.combine_with_previous_ratio(), "#define READ_PREVIOUS",
options_.flip_vertically() ? "out_height - gid.y - 1" : "gid.y");
auto status = ::tflite::gpu::OkStatus();

View File

@ -34,4 +34,7 @@ message TfLiteTensorsToSegmentationCalculatorOptions {
// Model specific: Channel to use for processing tensor.
optional int32 output_layer_index = 5 [default = 1];
// Flip result image mask along y-axis.
optional bool flip_vertically = 6;
}

Some files were not shown because too many files have changed in this diff Show More