This commit is contained in:
AK391 2021-06-10 23:00:24 +00:00
parent ba15087099
commit 7c22caf9d6
1940 changed files with 32 additions and 370573 deletions

View File

@ -1,5 +1,13 @@
import mediapipe as mp
import gradio as gr
import cv2
import torch
# Images
torch.hub.download_url_to_file('https://artbreeder.b-cdn.net/imgs/c789e54661bfb432c5522a36553f.jpeg', 'face1.jpg')
torch.hub.download_url_to_file('https://artbreeder.b-cdn.net/imgs/c86622e8cb58d490e35b01cb9996.jpeg', 'face2.jpg')
mp_face_mesh = mp.solutions.face_mesh
# Prepare DrawingSpec for drawing the face landmarks later.
@ -16,16 +24,28 @@ def inference(image):
# Convert the BGR image to RGB and process it with MediaPipe Face Mesh.
results = face_mesh.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
# Draw face landmarks of each face.
print(f'Face landmarks of {name}:')
if not results.multi_face_landmarks:
continue
annotated_image = image.copy()
for face_landmarks in results.multi_face_landmarks:
mp_drawing.draw_landmarks(
image=annotated_image,
landmark_list=face_landmarks,
connections=mp_face_mesh.FACE_CONNECTIONS,
landmark_drawing_spec=drawing_spec,
connection_drawing_spec=drawing_spec)
return annotated_image
mp_drawing.draw_landmarks(
image=annotated_image,
landmark_list=face_landmarks,
connections=mp_face_mesh.FACE_CONNECTIONS,
landmark_drawing_spec=drawing_spec,
connection_drawing_spec=drawing_spec)
return annotated_image
title = "Face Mesh"
description = "demo for Face Mesh. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2006.10962'>Attention Mesh: High-fidelity Face Mesh Prediction in Real-time</a> | <a href='https://github.com/google/mediapipe'>Github Repo</a></p>"
gr.Interface(
inference,
[gr.inputs.Image(label="Input")],
gr.outputs.Image(type="pil", label="Output"),
title=title,
description=description,
article=article,
examples=[
["face1.jpg"],
["face2.jpg"]
]).launch(debug=True)

View File

@ -1,145 +0,0 @@
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"]) # Apache 2.0
# Note: yes, these need to use "//external:android/crosstool", not
# @androidndk//:default_crosstool.
config_setting(
name = "android",
values = {"crosstool_top": "//external:android/crosstool"},
visibility = ["//visibility:public"],
)
config_setting(
name = "android_x86",
values = {
"crosstool_top": "//external:android/crosstool",
"cpu": "x86",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "android_x86_64",
values = {
"crosstool_top": "//external:android/crosstool",
"cpu": "x86_64",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "android_armeabi",
values = {
"crosstool_top": "//external:android/crosstool",
"cpu": "armeabi",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "android_arm",
values = {
"crosstool_top": "//external:android/crosstool",
"cpu": "armeabi-v7a",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "android_arm64",
values = {
"crosstool_top": "//external:android/crosstool",
"cpu": "arm64-v8a",
},
visibility = ["//visibility:public"],
)
# Note: this cannot just match "apple_platform_type": "macos" because that option
# defaults to "macos" even when building on Linux!
alias(
name = "macos",
actual = select({
":macos_i386": ":macos_i386",
":macos_x86_64": ":macos_x86_64",
"//conditions:default": ":macos_i386", # Arbitrarily chosen from above.
}),
visibility = ["//visibility:public"],
)
# Note: this also matches on crosstool_top so that it does not produce ambiguous
# selectors when used together with "android".
config_setting(
name = "ios",
values = {
"crosstool_top": "@bazel_tools//tools/cpp:toolchain",
"apple_platform_type": "ios",
},
visibility = ["//visibility:public"],
)
alias(
name = "apple",
actual = select({
":macos": ":macos",
":ios": ":ios",
"//conditions:default": ":ios", # Arbitrarily chosen from above.
}),
visibility = ["//visibility:public"],
)
config_setting(
name = "macos_i386",
values = {
"apple_platform_type": "macos",
"cpu": "darwin",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "macos_x86_64",
values = {
"apple_platform_type": "macos",
"cpu": "darwin_x86_64",
},
visibility = ["//visibility:public"],
)
[
config_setting(
name = arch,
values = {"cpu": arch},
visibility = ["//visibility:public"],
)
for arch in [
"ios_i386",
"ios_x86_64",
"ios_armv7",
"ios_arm64",
"ios_arm64e",
]
]
config_setting(
name = "windows",
values = {"cpu": "x64_windows"},
)
exports_files(
["provisioning_profile.mobileprovision"],
visibility = ["//visibility:public"],
)

View File

@ -1,137 +0,0 @@
{
"additionalFilePaths" : [
"/BUILD",
"mediapipe/BUILD",
"mediapipe/examples/ios/common/BUILD",
"mediapipe/examples/ios/facedetectioncpu/BUILD",
"mediapipe/examples/ios/facedetectiongpu/BUILD",
"mediapipe/examples/ios/faceeffect/BUILD",
"mediapipe/examples/ios/facemeshgpu/BUILD",
"mediapipe/examples/ios/handdetectiongpu/BUILD",
"mediapipe/examples/ios/handtrackinggpu/BUILD",
"mediapipe/examples/ios/helloworld/BUILD",
"mediapipe/examples/ios/holistictrackinggpu/BUILD",
"mediapipe/examples/ios/iristrackinggpu/BUILD",
"mediapipe/examples/ios/objectdetectioncpu/BUILD",
"mediapipe/examples/ios/objectdetectiongpu/BUILD",
"mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD",
"mediapipe/examples/ios/posetrackinggpu/BUILD",
"mediapipe/examples/ios/selfiesegmentationgpu/BUILD",
"mediapipe/framework/BUILD",
"mediapipe/gpu/BUILD",
"mediapipe/objc/BUILD",
"mediapipe/objc/testing/app/BUILD"
],
"buildTargets" : [
"//mediapipe/examples/ios/facedetectioncpu:FaceDetectionCpuApp",
"//mediapipe/examples/ios/facedetectiongpu:FaceDetectionGpuApp",
"//mediapipe/examples/ios/faceeffect:FaceEffectApp",
"//mediapipe/examples/ios/facemeshgpu:FaceMeshGpuApp",
"//mediapipe/examples/ios/handdetectiongpu:HandDetectionGpuApp",
"//mediapipe/examples/ios/handtrackinggpu:HandTrackingGpuApp",
"//mediapipe/examples/ios/helloworld:HelloWorldApp",
"//mediapipe/examples/ios/holistictrackinggpu:HolisticTrackingGpuApp",
"//mediapipe/examples/ios/iristrackinggpu:IrisTrackingGpuApp",
"//mediapipe/examples/ios/objectdetectioncpu:ObjectDetectionCpuApp",
"//mediapipe/examples/ios/objectdetectiongpu:ObjectDetectionGpuApp",
"//mediapipe/examples/ios/objectdetectiontrackinggpu:ObjectDetectionTrackingGpuApp",
"//mediapipe/examples/ios/posetrackinggpu:PoseTrackingGpuApp",
"//mediapipe/examples/ios/selfiesegmentationgpu:SelfieSegmentationGpuApp",
"//mediapipe/objc:mediapipe_framework_ios"
],
"optionSet" : {
"BazelBuildOptionsDebug" : {
"p" : "$(inherited)"
},
"BazelBuildOptionsRelease" : {
"p" : "$(inherited)"
},
"BazelBuildStartupOptionsDebug" : {
"p" : "$(inherited)"
},
"BazelBuildStartupOptionsRelease" : {
"p" : "$(inherited)"
},
"BuildActionPostActionScript" : {
"p" : "$(inherited)"
},
"BuildActionPreActionScript" : {
"p" : "$(inherited)"
},
"CommandlineArguments" : {
"p" : "$(inherited)"
},
"EnvironmentVariables" : {
"p" : "$(inherited)"
},
"LaunchActionPostActionScript" : {
"p" : "$(inherited)"
},
"LaunchActionPreActionScript" : {
"p" : "$(inherited)"
},
"ProjectGenerationBazelStartupOptions" : {
"p" : "$(inherited)"
},
"TestActionPostActionScript" : {
"p" : "$(inherited)"
},
"TestActionPreActionScript" : {
"p" : "$(inherited)"
}
},
"projectName" : "Mediapipe",
"sourceFilters" : [
"mediapipe",
"mediapipe/calculators",
"mediapipe/calculators/core",
"mediapipe/calculators/image",
"mediapipe/calculators/internal",
"mediapipe/calculators/tflite",
"mediapipe/calculators/util",
"mediapipe/examples",
"mediapipe/examples/ios",
"mediapipe/examples/ios/common",
"mediapipe/examples/ios/common/Base.lproj",
"mediapipe/examples/ios/facedetectioncpu",
"mediapipe/examples/ios/facedetectiongpu",
"mediapipe/examples/ios/faceeffect",
"mediapipe/examples/ios/faceeffect/Base.lproj",
"mediapipe/examples/ios/handdetectiongpu",
"mediapipe/examples/ios/handtrackinggpu",
"mediapipe/examples/ios/helloworld",
"mediapipe/examples/ios/holistictrackinggpu",
"mediapipe/examples/ios/iristrackinggpu",
"mediapipe/examples/ios/objectdetectioncpu",
"mediapipe/examples/ios/objectdetectiongpu",
"mediapipe/examples/ios/posetrackinggpu",
"mediapipe/examples/ios/selfiesegmentationgpu",
"mediapipe/framework",
"mediapipe/framework/deps",
"mediapipe/framework/formats",
"mediapipe/framework/formats/annotation",
"mediapipe/framework/formats/object_detection",
"mediapipe/framework/port",
"mediapipe/framework/profiler",
"mediapipe/framework/stream_handler",
"mediapipe/framework/tool",
"mediapipe/gpu",
"mediapipe/graphs",
"mediapipe/graphs/edge_detection",
"mediapipe/graphs/face_detection",
"mediapipe/graphs/face_geometry",
"mediapipe/graphs/hand_tracking",
"mediapipe/graphs/object_detection",
"mediapipe/graphs/pose_tracking",
"mediapipe/graphs/selfie_segmentation",
"mediapipe/models",
"mediapipe/modules",
"mediapipe/objc",
"mediapipe/util",
"mediapipe/util/android",
"mediapipe/util/android/file",
"mediapipe/util/android/file/base",
"mediapipe/util/tflite",
"mediapipe/util/tflite/operations"
]
}

View File

@ -1,30 +0,0 @@
{
"configDefaults" : {
"optionSet" : {
"CLANG_CXX_LANGUAGE_STANDARD" : {
"p" : "c++14"
}
}
},
"packages" : [
"",
"mediapipe",
"mediapipe/examples/ios",
"mediapipe/examples/ios/facedetectioncpu",
"mediapipe/examples/ios/facedetectiongpu",
"mediapipe/examples/ios/faceeffect",
"mediapipe/examples/ios/facemeshgpu",
"mediapipe/examples/ios/handdetectiongpu",
"mediapipe/examples/ios/handtrackinggpu",
"mediapipe/examples/ios/holistictrackinggpu",
"mediapipe/examples/ios/iristrackinggpu",
"mediapipe/examples/ios/objectdetectioncpu",
"mediapipe/examples/ios/objectdetectiongpu",
"mediapipe/examples/ios/objectdetectiontrackinggpu",
"mediapipe/examples/ios/posetrackinggpu",
"mediapipe/examples/ios/selfiesegmentationgpu",
"mediapipe/objc"
],
"projectName" : "Mediapipe",
"workspaceRoot" : "../.."
}

View File

@ -1,14 +0,0 @@
"""Copyright 2019 - 2020 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

View File

@ -1,357 +0,0 @@
# Copyright 2019, 2021 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:private"])
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
proto_library(
name = "mfcc_mel_calculators_proto",
srcs = ["mfcc_mel_calculators.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
],
)
mediapipe_cc_proto_library(
name = "mfcc_mel_calculators_cc_proto",
srcs = ["mfcc_mel_calculators.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
deps = [":mfcc_mel_calculators_proto"],
)
proto_library(
name = "rational_factor_resample_calculator_proto",
srcs = ["rational_factor_resample_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
],
)
mediapipe_cc_proto_library(
name = "rational_factor_resample_calculator_cc_proto",
srcs = ["rational_factor_resample_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
deps = [":rational_factor_resample_calculator_proto"],
)
proto_library(
name = "spectrogram_calculator_proto",
srcs = ["spectrogram_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
mediapipe_cc_proto_library(
name = "spectrogram_calculator_cc_proto",
srcs = ["spectrogram_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
deps = [":spectrogram_calculator_proto"],
)
proto_library(
name = "stabilized_log_calculator_proto",
srcs = ["stabilized_log_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
],
)
mediapipe_cc_proto_library(
name = "stabilized_log_calculator_cc_proto",
srcs = ["stabilized_log_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
deps = [":stabilized_log_calculator_proto"],
)
proto_library(
name = "time_series_framer_calculator_proto",
srcs = ["time_series_framer_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
],
)
mediapipe_cc_proto_library(
name = "time_series_framer_calculator_cc_proto",
srcs = ["time_series_framer_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
deps = [":time_series_framer_calculator_proto"],
)
cc_library(
name = "audio_decoder_calculator",
srcs = ["audio_decoder_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:status",
"//mediapipe/util:audio_decoder",
"//mediapipe/util:audio_decoder_cc_proto",
],
alwayslink = 1,
)
cc_library(
name = "basic_time_series_calculators",
srcs = ["basic_time_series_calculators.cc"],
hdrs = ["basic_time_series_calculators.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/util:time_series_util",
"@com_google_absl//absl/strings",
"@eigen_archive//:eigen3",
],
alwayslink = 1,
)
cc_library(
name = "mfcc_mel_calculators",
srcs = ["mfcc_mel_calculators.cc"],
visibility = ["//visibility:public"],
deps = [
":mfcc_mel_calculators_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:status",
"//mediapipe/util:time_series_util",
"@com_google_absl//absl/strings",
"@com_google_audio_tools//audio/dsp/mfcc",
"@eigen_archive//:eigen3",
],
alwayslink = 1,
)
cc_library(
name = "rational_factor_resample_calculator",
srcs = ["rational_factor_resample_calculator.cc"],
hdrs = ["rational_factor_resample_calculator.h"],
visibility = ["//visibility:public"],
deps = [
":rational_factor_resample_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/util:time_series_util",
"@com_google_absl//absl/strings",
"@com_google_audio_tools//audio/dsp:resampler",
"@com_google_audio_tools//audio/dsp:resampler_q",
"@eigen_archive//:eigen3",
],
alwayslink = 1,
)
cc_library(
name = "stabilized_log_calculator",
srcs = ["stabilized_log_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":stabilized_log_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:status",
"//mediapipe/util:time_series_util",
],
alwayslink = 1,
)
cc_library(
name = "spectrogram_calculator",
srcs = ["spectrogram_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":spectrogram_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:source_location",
"//mediapipe/framework/port:status",
"//mediapipe/util:time_series_util",
"@com_google_absl//absl/strings",
"@com_google_audio_tools//audio/dsp:window_functions",
"@com_google_audio_tools//audio/dsp/spectrogram",
"@eigen_archive//:eigen3",
],
alwayslink = 1,
)
cc_library(
name = "time_series_framer_calculator",
srcs = ["time_series_framer_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":time_series_framer_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/util:time_series_util",
"@com_google_audio_tools//audio/dsp:window_functions",
"@eigen_archive//:eigen3",
],
alwayslink = 1,
)
cc_test(
name = "audio_decoder_calculator_test",
srcs = ["audio_decoder_calculator_test.cc"],
data = ["//mediapipe/calculators/audio/testdata:test_audios"],
deps = [
":audio_decoder_calculator",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"@com_google_absl//absl/flags:flag",
],
)
cc_test(
name = "basic_time_series_calculators_test",
srcs = ["basic_time_series_calculators_test.cc"],
deps = [
":basic_time_series_calculators",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/util:time_series_test_util",
"@eigen_archive//:eigen3",
],
)
cc_test(
name = "mfcc_mel_calculators_test",
srcs = ["mfcc_mel_calculators_test.cc"],
deps = [
":mfcc_mel_calculators",
":mfcc_mel_calculators_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status",
"//mediapipe/util:time_series_test_util",
"@eigen_archive//:eigen3",
],
)
cc_test(
name = "spectrogram_calculator_test",
srcs = ["spectrogram_calculator_test.cc"],
deps = [
":spectrogram_calculator",
":spectrogram_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:benchmark",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:status",
"//mediapipe/util:time_series_test_util",
"@com_google_audio_tools//audio/dsp:number_util",
"@eigen_archive//:eigen3",
],
)
cc_test(
name = "stabilized_log_calculator_test",
srcs = ["stabilized_log_calculator_test.cc"],
deps = [
":stabilized_log_calculator",
":stabilized_log_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:status",
"//mediapipe/util:time_series_test_util",
"@eigen_archive//:eigen3",
],
)
cc_test(
name = "time_series_framer_calculator_test",
srcs = ["time_series_framer_calculator_test.cc"],
deps = [
":time_series_framer_calculator",
":time_series_framer_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:status",
"//mediapipe/util:time_series_test_util",
"@com_google_audio_tools//audio/dsp:window_functions",
"@eigen_archive//:eigen3",
],
)
cc_test(
name = "rational_factor_resample_calculator_test",
srcs = ["rational_factor_resample_calculator_test.cc"],
deps = [
":rational_factor_resample_calculator",
":rational_factor_resample_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:validate_type",
"//mediapipe/util:time_series_test_util",
"@com_google_audio_tools//audio/dsp:signal_vector_util",
"@eigen_archive//:eigen3",
],
)

View File

@ -1,109 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/audio_decoder.h"
#include "mediapipe/util/audio_decoder.pb.h"
namespace mediapipe {
// The AudioDecoderCalculator decodes an audio stream of the media file. It
// produces two output streams containing the audio packets and the header
// information.
//
// Output Streams:
// AUDIO: Output audio frames (Matrix).
// AUDIO_HEADER:
// Optional audio header information output
// Input Side Packets:
// INPUT_FILE_PATH: The input file path.
//
// Example config:
// node {
// calculator: "AudioDecoderCalculator"
// input_side_packet: "INPUT_FILE_PATH:input_file_path"
// output_stream: "AUDIO:audio"
// output_stream: "AUDIO_HEADER:audio_header"
// node_options {
// [type.googleapis.com/mediapipe.AudioDecoderOptions]: {
// audio_stream { stream_index: 0 }
// start_time: 0
// end_time: 1
// }
// }
// }
//
// TODO: support decoding multiple streams.
class AudioDecoderCalculator : public CalculatorBase {
public:
// Declares the INPUT_FILE_PATH / OPTIONS side packets and the AUDIO /
// AUDIO_HEADER output streams (OPTIONS and AUDIO_HEADER only when present).
static absl::Status GetContract(CalculatorContract* cc);
// Creates and initializes the decoder for the file named by INPUT_FILE_PATH.
absl::Status Open(CalculatorContext* cc) override;
// Fetches one packet from the decoder and adds it to the AUDIO stream.
absl::Status Process(CalculatorContext* cc) override;
// Closes the decoder and returns its status.
absl::Status Close(CalculatorContext* cc) override;
private:
// Assigned in Open(); dereferenced by Process() and Close().
std::unique_ptr<AudioDecoder> decoder_;
};
// Declares the calculator's ports. INPUT_FILE_PATH (side packet) and AUDIO
// (output stream) are always configured; the OPTIONS side packet and the
// AUDIO_HEADER output stream are configured only when the node supplies them.
absl::Status AudioDecoderCalculator::GetContract(CalculatorContract* cc) {
  auto& side_packets = cc->InputSidePackets();
  side_packets.Tag("INPUT_FILE_PATH").Set<std::string>();
  if (side_packets.HasTag("OPTIONS")) {
    side_packets.Tag("OPTIONS").Set<mediapipe::AudioDecoderOptions>();
  }

  auto& output_streams = cc->Outputs();
  output_streams.Tag("AUDIO").Set<Matrix>();
  if (output_streams.HasTag("AUDIO_HEADER")) {
    // Open() only sets a header on this stream and then closes it.
    output_streams.Tag("AUDIO_HEADER").SetNone();
  }

  return absl::OkStatus();
}
// Creates the decoder, initializes it for the file named by the
// INPUT_FILE_PATH side packet, and, when the AUDIO_HEADER output is connected,
// publishes the audio header on it and closes that stream.
//
// Returns the error from AudioDecoder::Initialize() if initialization fails,
// OK otherwise.
absl::Status AudioDecoderCalculator::Open(CalculatorContext* cc) {
  const std::string& input_file_path =
      cc->InputSidePackets().Tag("INPUT_FILE_PATH").Get<std::string>();
  // RetrieveOptions is handed both the node options and the side packets
  // (tag "OPTIONS"); its result configures the decoder.
  const auto& decoder_options =
      tool::RetrieveOptions(cc->Options<mediapipe::AudioDecoderOptions>(),
                            cc->InputSidePackets(), "OPTIONS");

  decoder_ = absl::make_unique<AudioDecoder>();
  MP_RETURN_IF_ERROR(decoder_->Initialize(input_file_path, decoder_options));

  // GetContract() treats AUDIO_HEADER as optional, so the stream may only be
  // accessed when the node actually declares it.
  if (cc->Outputs().HasTag("AUDIO_HEADER")) {
    std::unique_ptr<mediapipe::TimeSeriesHeader> header =
        absl::make_unique<mediapipe::TimeSeriesHeader>();
    if (decoder_->FillAudioHeader(decoder_options.audio_stream(0), header.get())
            .ok()) {
      // Only pass on a header if the decoder could actually produce one.
      // Otherwise, the header will be empty.
      cc->Outputs().Tag("AUDIO_HEADER").SetHeader(Adopt(header.release()));
    }
    // No packets are ever sent on the header stream.
    cc->Outputs().Tag("AUDIO_HEADER").Close();
  }

  return absl::OkStatus();
}
// Asks the decoder for its next packet. On success the packet is added to the
// AUDIO output stream; in every case the decoder's status is returned as-is.
absl::Status AudioDecoderCalculator::Process(CalculatorContext* cc) {
  int options_idx = -1;
  Packet decoded;
  auto decode_status = decoder_->GetData(&options_idx, &decoded);
  if (!decode_status.ok()) {
    // Nothing is emitted when the decoder reports a non-OK status.
    return decode_status;
  }
  cc->Outputs().Tag("AUDIO").AddPacket(decoded);
  return decode_status;
}
// Closes the decoder and forwards whatever status it reports.
absl::Status AudioDecoderCalculator::Close(CalculatorContext* cc) {
  auto close_status = decoder_->Close();
  return close_status;
}
REGISTER_CALCULATOR(AudioDecoderCalculator);
} // namespace mediapipe

View File

@ -1,150 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/flags/flag.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
// Runs the calculator on the WAV test file (per its name: 44100 Hz, mono,
// 2 seconds) and checks the published header plus a lower bound on the number
// of AUDIO packets.
TEST(AudioDecoderCalculatorTest, TestWAV) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "AudioDecoderCalculator"
input_side_packet: "INPUT_FILE_PATH:input_file_path"
output_stream: "AUDIO:audio"
output_stream: "AUDIO_HEADER:audio_header"
node_options {
[type.googleapis.com/mediapipe.AudioDecoderOptions]: {
audio_stream { stream_index: 0 }
}
})pb");
  CalculatorRunner runner(node_config);
  runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
      file::JoinPath("./",
                     "/mediapipe/calculators/audio/"
                     "testdata/sine_wave_1k_44100_mono_2_sec_wav.audio"));
  MP_ASSERT_OK(runner.Run());
  // Assert rather than expect: the Get<>() below must not run when the header
  // is missing or has the wrong type.
  MP_ASSERT_OK(runner.Outputs()
                   .Tag("AUDIO_HEADER")
                   .header.ValidateAsType<mediapipe::TimeSeriesHeader>());
  const mediapipe::TimeSeriesHeader& header =
      runner.Outputs()
          .Tag("AUDIO_HEADER")
          .header.Get<mediapipe::TimeSeriesHeader>();
  EXPECT_EQ(44100, header.sample_rate());
  EXPECT_EQ(1, header.num_channels());
  // EXPECT_GE reports both operands on failure.
  // NOTE(review): 2048 is presumably the samples-per-packet size - confirm.
  EXPECT_GE(runner.Outputs().Tag("AUDIO").packets.size(),
            std::ceil(44100.0 * 2 / 2048));
}
// Runs the calculator on the WAV test file (per its name: 48000 Hz, stereo,
// 2 seconds) and checks the published header plus a lower bound on the number
// of AUDIO packets.
TEST(AudioDecoderCalculatorTest, Test48KWAV) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "AudioDecoderCalculator"
input_side_packet: "INPUT_FILE_PATH:input_file_path"
output_stream: "AUDIO:audio"
output_stream: "AUDIO_HEADER:audio_header"
node_options {
[type.googleapis.com/mediapipe.AudioDecoderOptions]: {
audio_stream { stream_index: 0 }
}
})pb");
  CalculatorRunner runner(node_config);
  runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
      file::JoinPath("./",
                     "/mediapipe/calculators/audio/"
                     "testdata/sine_wave_1k_48000_stereo_2_sec_wav.audio"));
  MP_ASSERT_OK(runner.Run());
  // Assert rather than expect: the Get<>() below must not run when the header
  // is missing or has the wrong type.
  MP_ASSERT_OK(runner.Outputs()
                   .Tag("AUDIO_HEADER")
                   .header.ValidateAsType<mediapipe::TimeSeriesHeader>());
  const mediapipe::TimeSeriesHeader& header =
      runner.Outputs()
          .Tag("AUDIO_HEADER")
          .header.Get<mediapipe::TimeSeriesHeader>();
  EXPECT_EQ(48000, header.sample_rate());
  EXPECT_EQ(2, header.num_channels());
  // EXPECT_GE reports both operands on failure.
  // NOTE(review): 1024 is presumably the samples-per-packet size - confirm.
  EXPECT_GE(runner.Outputs().Tag("AUDIO").packets.size(),
            std::ceil(48000.0 * 2 / 1024));
}
// Runs the calculator on the MP3 test file (per its name: 44100 Hz, stereo,
// 2 seconds) and checks the published header plus a lower bound on the number
// of AUDIO packets.
TEST(AudioDecoderCalculatorTest, TestMP3) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "AudioDecoderCalculator"
input_side_packet: "INPUT_FILE_PATH:input_file_path"
output_stream: "AUDIO:audio"
output_stream: "AUDIO_HEADER:audio_header"
node_options {
[type.googleapis.com/mediapipe.AudioDecoderOptions]: {
audio_stream { stream_index: 0 }
}
})pb");
  CalculatorRunner runner(node_config);
  runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
      file::JoinPath("./",
                     "/mediapipe/calculators/audio/"
                     "testdata/sine_wave_1k_44100_stereo_2_sec_mp3.audio"));
  MP_ASSERT_OK(runner.Run());
  // Assert rather than expect: the Get<>() below must not run when the header
  // is missing or has the wrong type.
  MP_ASSERT_OK(runner.Outputs()
                   .Tag("AUDIO_HEADER")
                   .header.ValidateAsType<mediapipe::TimeSeriesHeader>());
  const mediapipe::TimeSeriesHeader& header =
      runner.Outputs()
          .Tag("AUDIO_HEADER")
          .header.Get<mediapipe::TimeSeriesHeader>();
  EXPECT_EQ(44100, header.sample_rate());
  EXPECT_EQ(2, header.num_channels());
  // EXPECT_GE reports both operands on failure.
  // NOTE(review): 1152 is presumably the samples-per-packet size - confirm.
  EXPECT_GE(runner.Outputs().Tag("AUDIO").packets.size(),
            std::ceil(44100.0 * 2 / 1152));
}
// Runs the calculator on the AAC test file (per its name: 44100 Hz, stereo,
// 2 seconds) and checks the published header plus a lower bound on the number
// of AUDIO packets.
TEST(AudioDecoderCalculatorTest, TestAAC) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "AudioDecoderCalculator"
input_side_packet: "INPUT_FILE_PATH:input_file_path"
output_stream: "AUDIO:audio"
output_stream: "AUDIO_HEADER:audio_header"
node_options {
[type.googleapis.com/mediapipe.AudioDecoderOptions]: {
audio_stream { stream_index: 0 }
}
})pb");
  CalculatorRunner runner(node_config);
  runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
      file::JoinPath("./",
                     "/mediapipe/calculators/audio/"
                     "testdata/sine_wave_1k_44100_stereo_2_sec_aac.audio"));
  MP_ASSERT_OK(runner.Run());
  // Assert rather than expect: the Get<>() below must not run when the header
  // is missing or has the wrong type.
  MP_ASSERT_OK(runner.Outputs()
                   .Tag("AUDIO_HEADER")
                   .header.ValidateAsType<mediapipe::TimeSeriesHeader>());
  const mediapipe::TimeSeriesHeader& header =
      runner.Outputs()
          .Tag("AUDIO_HEADER")
          .header.Get<mediapipe::TimeSeriesHeader>();
  EXPECT_EQ(44100, header.sample_rate());
  EXPECT_EQ(2, header.num_channels());
  // EXPECT_GE reports both operands on failure.
  // NOTE(review): 1024 is presumably the samples-per-packet size - confirm.
  EXPECT_GE(runner.Outputs().Tag("AUDIO").packets.size(),
            std::ceil(44100.0 * 2 / 1024));
}
} // namespace mediapipe

View File

@ -1,405 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Basic Calculators that operate on TimeSeries streams.
#include "mediapipe/calculators/audio/basic_time_series_calculators.h"
#include <climits>
#include <cmath>
#include <cstdint>
#include <memory>
#include "Eigen/Core"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/time_series_util.h"
namespace mediapipe {
namespace {
// Multiplies x by y with overflow detection.
//
// Returns true if the product is representable as an int (INT_MIN and INT_MAX
// included) and, when result is non-null, stores it in *result. Returns false
// and leaves *result untouched if the product does not fit in an int.
static bool SafeMultiply(int x, int y, int* result) {
  // The product is computed in a type at least twice as wide as int, so the
  // widened multiplication itself cannot overflow.
  static_assert(sizeof(long long) >= 2 * sizeof(int),
                "Unable to detect overflow after multiplication");
  const long long big = static_cast<long long>(x) * static_cast<long long>(y);
  // Bounds are inclusive: INT_MIN and INT_MAX are valid int values.
  if (big < static_cast<long long>(INT_MIN) ||
      big > static_cast<long long>(INT_MAX)) {
    return false;
  }
  if (result != nullptr) *result = static_cast<int>(big);
  return true;
}
} // namespace
// Declares one Matrix input stream and one Matrix output stream, both at
// index 0. Each stream is expected to carry a TimeSeriesHeader.
absl::Status BasicTimeSeriesCalculatorBase::GetContract(
    CalculatorContract* cc) {
  // Input stream with TimeSeriesHeader.
  cc->Inputs().Index(0).Set<Matrix>();
  // Output stream with TimeSeriesHeader.
  cc->Outputs().Index(0).Set<Matrix>();
  return absl::OkStatus();
}
// Copies the input stream's TimeSeriesHeader, lets the subclass adjust the
// copy through MutateHeader(), and installs the result as the output stream
// header. Returns early with the error if FillTimeSeriesHeaderIfValid() or
// MutateHeader() fails.
absl::Status BasicTimeSeriesCalculatorBase::Open(CalculatorContext* cc) {
  TimeSeriesHeader input_header;
  MP_RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
      cc->Inputs().Index(0).Header(), &input_header));

  // Held by a unique_ptr so the copy is freed when MutateHeader() fails and
  // MP_RETURN_IF_ERROR returns before the header is handed off.
  std::unique_ptr<TimeSeriesHeader> output_header(
      new TimeSeriesHeader(input_header));
  MP_RETURN_IF_ERROR(MutateHeader(output_header.get()));

  // Ownership of the header passes to the packet created by Adopt().
  cc->Outputs().Index(0).SetHeader(Adopt(output_header.release()));
  cc->SetOffset(0);
  return absl::OkStatus();
}
// Runs ProcessMatrix() on the current input packet and emits the result at
// the input timestamp. Both the input and the produced matrix are checked
// with IsMatrixShapeConsistentWithHeader() against their stream's header; a
// failed check is returned as the error.
absl::Status BasicTimeSeriesCalculatorBase::Process(CalculatorContext* cc) {
  const Matrix& in_matrix = cc->Inputs().Index(0).Get<Matrix>();
  MP_RETURN_IF_ERROR(time_series_util::IsMatrixShapeConsistentWithHeader(
      in_matrix, cc->Inputs().Index(0).Header().Get<TimeSeriesHeader>()));

  std::unique_ptr<Matrix> result(new Matrix(ProcessMatrix(in_matrix)));
  MP_RETURN_IF_ERROR(time_series_util::IsMatrixShapeConsistentWithHeader(
      *result, cc->Outputs().Index(0).Header().Get<TimeSeriesHeader>()));

  // The output stream receives the released raw pointer.
  cc->Outputs().Index(0).Add(result.release(), cc->InputTimestamp());
  return absl::OkStatus();
}
// Default implementation: leaves *output_header untouched and reports
// success. Open() passes in a copy of the input header, so subclasses whose
// output header equals the input header need not override this.
absl::Status BasicTimeSeriesCalculatorBase::MutateHeader(
TimeSeriesHeader* output_header) {
return absl::OkStatus();
}
// Collapses a multi-channel time series into a single channel by summing the
// channels. Useful for e.g. computing 'summary SAI' pitchogram features.
//
// Options proto: None.
class SumTimeSeriesAcrossChannelsCalculator
    : public BasicTimeSeriesCalculatorBase {
 protected:
  // Produces one row holding the sum of each input column.
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    return input_matrix.colwise().sum();
  }

  // The output header always advertises exactly one channel.
  absl::Status MutateHeader(TimeSeriesHeader* output_header) final {
    output_header->set_num_channels(1);
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(SumTimeSeriesAcrossChannelsCalculator);
// Collapses a multi-channel time series into a single channel by averaging
// the channels. Useful for e.g. converting stereo or multi-channel files to
// mono.
//
// Options proto: None.
class AverageTimeSeriesAcrossChannelsCalculator
    : public BasicTimeSeriesCalculatorBase {
 protected:
  // Produces one row holding the mean of each input column.
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    return input_matrix.colwise().mean();
  }

  // The output header always advertises exactly one channel.
  absl::Status MutateHeader(TimeSeriesHeader* output_header) final {
    output_header->set_num_channels(1);
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(AverageTimeSeriesAcrossChannelsCalculator);
// Turns a (temporal) summary SAI stream, i.e. a single-channel stream such as
// the output of SumTimeSeriesAcrossChannelsCalculator, into pitchogram frames.
// Each packet is transposed, which swaps the time and channel axes.
//
// Options proto: None.
class SummarySaiToPitchogramCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    return input_matrix.transpose();
  }

  // Rejects multi-channel input. Otherwise the former sample count becomes
  // the channel count, each packet holds one sample, and the sample rate
  // becomes the packet rate.
  absl::Status MutateHeader(TimeSeriesHeader* output_header) final {
    const auto channel_count = output_header->num_channels();
    if (channel_count != 1) {
      return tool::StatusInvalid(
          absl::StrCat("Expected single-channel input, got ", channel_count));
    }
    // num_samples must be read before it is overwritten below.
    const auto sample_count = output_header->num_samples();
    output_header->set_num_channels(sample_count);
    output_header->set_num_samples(1);
    output_header->set_sample_rate(output_header->packet_rate());
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(SummarySaiToPitchogramCalculator);
// Flips the channel order of every TimeSeries packet. Useful for e.g.
// interfacing with the speech pipeline, which uses the opposite convention to
// the hearing filterbanks.
//
// Options proto: None.
class ReverseChannelOrderCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  // Reverses each column top to bottom; the header is left as is.
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    return input_matrix.colwise().reverse();
  }
};
REGISTER_CALCULATOR(ReverseChannelOrderCalculator);
// Flattens every sample of a TimeSeries packet into one single 'sample'
// vector. Useful for e.g. stacking several frames of features into a single
// feature vector.
//
// Options proto: None.
class FlattenPacketCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  // The output has num_channels * num_samples channels and one sample per
  // packet; the sample rate becomes the packet rate. Fails if either input
  // dimension is negative or the product does not fit in an int.
  absl::Status MutateHeader(TimeSeriesHeader* output_header) final {
    const int num_input_channels = output_header->num_channels();
    RET_CHECK(num_input_channels >= 0)
        << "FlattenPacketCalculator: num_input_channels < 0";
    const int num_input_samples = output_header->num_samples();
    RET_CHECK(num_input_samples >= 0)
        << "FlattenPacketCalculator: num_input_samples < 0";

    int output_num_channels;
    RET_CHECK(SafeMultiply(num_input_channels, num_input_samples,
                           &output_num_channels))
        << "FlattenPacketCalculator: Multiplication failed.";

    output_header->set_num_channels(output_num_channels);
    output_header->set_num_samples(1);
    output_header->set_sample_rate(output_header->packet_rate());
    return absl::OkStatus();
  }

  // Stacks the input columns on top of each other: column 0 first, then
  // column 1, and so on. Whole samples therefore stay contiguous instead of
  // the values of one channel being grouped together.
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    const auto rows_per_sample = input_matrix.rows();
    Matrix flattened(input_matrix.size(), 1);
    for (int col = 0; col < input_matrix.cols(); ++col) {
      flattened.middleRows(col * rows_per_sample, rows_per_sample) =
          input_matrix.col(col);
    }
    return flattened;
  }
};
REGISTER_CALCULATOR(FlattenPacketCalculator);
// Removes the within-packet mean of each channel from that channel.
//
// Options proto: None.
class SubtractMeanCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    // One mean per row, repeated across all columns before subtracting.
    const Matrix channel_means = input_matrix.rowwise().mean();
    return input_matrix - channel_means.replicate(1, input_matrix.cols());
  }
};
REGISTER_CALCULATOR(SubtractMeanCalculator);
// Subtracts from every value in a Packet the mean taken over all values
// (all times and all channels) of that Packet.
//
// Options proto: None.
class SubtractMeanAcrossChannelsCalculator
    : public BasicTimeSeriesCalculatorBase {
 protected:
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    const auto packet_mean = input_matrix.mean();
    // Array view so the scalar is subtracted element by element.
    return (input_matrix.array() - packet_mean).matrix();
  }
};
REGISTER_CALCULATOR(SubtractMeanAcrossChannelsCalculator);
// Divides every value in a Packet by the mean taken over all times and
// channels of that Packet. This is useful for normalizing nonnegative
// quantities like power, but might cause unexpected results if used with
// Packets that can contain negative numbers.
//
// If the mean is exactly zero, the output is a matrix of all ones, because
// that's what happens in other cases where all values are equal.
//
// Options proto: None.
class DivideByMeanAcrossChannelsCalculator
    : public BasicTimeSeriesCalculatorBase {
 protected:
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    const auto packet_mean = input_matrix.mean();
    if (packet_mean == 0) {
      // For nonnegative matrices the mean is zero only when the whole matrix
      // is exactly zero. All ones is what any other matrix of equal values
      // would produce.
      return Matrix::Ones(input_matrix.rows(), input_matrix.cols());
    }
    return input_matrix / packet_mean;
  }
};
REGISTER_CALCULATOR(DivideByMeanAcrossChannelsCalculator);
// Reduces each packet to the mean of every channel.
//
// Options proto: None.
class MeanCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  // One value per row: the mean over that row's columns.
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    return input_matrix.rowwise().mean();
  }

  // Each output packet holds a single sample, so the sample rate becomes the
  // packet rate.
  absl::Status MutateHeader(TimeSeriesHeader* output_header) final {
    output_header->set_sample_rate(output_header->packet_rate());
    output_header->set_num_samples(1);
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(MeanCalculator);
// Computes the uncorrected sample standard deviation of each channel,
// independently for each Packet. "Uncorrected" means the divisor is the
// number of samples in the Packet, not (<number of samples> - 1).
//
// Options proto: None.
class StandardDeviationCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  // Each output packet holds a single sample, so the sample rate becomes the
  // packet rate.
  absl::Status MutateHeader(TimeSeriesHeader* output_header) final {
    output_header->set_sample_rate(output_header->packet_rate());
    output_header->set_num_samples(1);
    return absl::OkStatus();
  }

  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    const Eigen::VectorXf channel_means = input_matrix.rowwise().mean();
    // norm / sqrt(N) equals sqrt(sum of squared deviations / N).
    const auto root_sample_count = sqrt(input_matrix.cols());
    return (input_matrix.colwise() - channel_means).rowwise().norm() /
           root_sample_count;
  }
};
REGISTER_CALCULATOR(StandardDeviationCalculator);
// Calculator to calculate the covariance matrix of each packet. If the input
// matrix has N channels (rows), the output matrix is an N by N symmetric
// matrix, normalized by the number of samples (columns) in the packet.
//
// Options proto: None.
class CovarianceCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  // The output packet is square: num_samples is set to num_channels.
  absl::Status MutateHeader(TimeSeriesHeader* output_header) final {
    output_header->set_num_samples(output_header->num_channels());
    return absl::OkStatus();
  }

  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    // Stored as concrete Matrix objects rather than `auto`: `auto` would keep
    // lazy Eigen expression objects, and the centered input, which is used
    // twice in the product below, would be recomputed for each use.
    const Matrix mean = input_matrix.rowwise().mean();
    const Matrix zero_mean_input =
        input_matrix - mean.replicate(1, input_matrix.cols());
    return (zero_mean_input * zero_mean_input.transpose()) /
           input_matrix.cols();
  }
};
REGISTER_CALCULATOR(CovarianceCalculator);
// Reduces each column of the input time series to its L2 norm.
//
// Options proto: None.
class L2NormCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  // One row holding the Euclidean norm of every input column.
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    return input_matrix.colwise().norm();
  }

  // The output header always advertises exactly one channel.
  absl::Status MutateHeader(TimeSeriesHeader* output_header) final {
    output_header->set_num_channels(1);
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(L2NormCalculator);
// Rescales every column of the input matrix to a unit vector.
//
// Options proto: None.
class L2NormalizeColumnCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  // Shape is unchanged, so the header is passed through untouched.
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    return input_matrix.colwise().normalized();
  }
};
REGISTER_CALCULATOR(L2NormalizeColumnCalculator);
// Applies L2 normalization to the input matrix: every value is divided by
// the root-mean-square over the whole matrix.
//
// Returns the matrix as is if the RMS is <= 1E-8.
// Options proto: None.
class L2NormalizeCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    const auto mean_square = input_matrix.array().square().mean();
    const double rms = std::sqrt(mean_square);
    // Below this RMS the matrix is returned unscaled.
    constexpr double kEpsilon = 1e-8;
    if (rms <= kEpsilon) {
      return input_matrix;
    }
    return input_matrix / rms;
  }
};
REGISTER_CALCULATOR(L2NormalizeCalculator);
// Applies peak normalization to the input matrix: every value is divided by
// the largest absolute value found in the matrix.
//
// Returns the matrix as is if the peak is <= 1E-8.
// Options proto: None.
class PeakNormalizeCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    const double peak = input_matrix.cwiseAbs().maxCoeff();
    // Below this peak the matrix is returned unscaled.
    constexpr double kEpsilon = 1e-8;
    if (peak <= kEpsilon) {
      return input_matrix;
    }
    return input_matrix / peak;
  }
};
REGISTER_CALCULATOR(PeakNormalizeCalculator);
// Squares every element of the input time series.
//
// Options proto: None.
class ElementwiseSquareCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  // Array view so the square is taken per coefficient; shape is unchanged.
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    return input_matrix.array().square();
  }
};
REGISTER_CALCULATOR(ElementwiseSquareCalculator);
// Keeps only the first floor(num_samples / 2) samples of each packet.
//
// Options proto: None.
class FirstHalfSlicerCalculator : public BasicTimeSeriesCalculatorBase {
 protected:
  // Takes all rows and the leading half (rounded down) of the columns.
  Matrix ProcessMatrix(const Matrix& input_matrix) final {
    const auto kept_cols = input_matrix.cols() / 2;
    return input_matrix.block(0, 0, input_matrix.rows(), kept_cols);
  }

  // Halves num_samples (integer division); fails on a negative sample count.
  absl::Status MutateHeader(TimeSeriesHeader* output_header) final {
    const int num_input_samples = output_header->num_samples();
    RET_CHECK(num_input_samples >= 0)
        << "FirstHalfSlicerCalculator: num_input_samples < 0";
    output_header->set_num_samples(num_input_samples / 2);
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(FirstHalfSlicerCalculator);
} // namespace mediapipe

View File

@ -1,48 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Abstract base class for basic MediaPipe calculators that operate on
// TimeSeries streams and don't require any Options protos.
// Subclasses must override ProcessMatrix, and optionally
// MutateHeader.
#ifndef MEDIAPIPE_CALCULATORS_AUDIO_BASIC_TIME_SERIES_CALCULATORS_H_
#define MEDIAPIPE_CALCULATORS_AUDIO_BASIC_TIME_SERIES_CALCULATORS_H_
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
namespace mediapipe {
// Base class for calculators that turn each Matrix packet of a single input
// stream into one Matrix packet on a single output stream. Subclasses supply
// the per-packet computation in ProcessMatrix() and may adjust the output
// header in MutateHeader().
class BasicTimeSeriesCalculatorBase : public CalculatorBase {
public:
// Declares the calculator's stream contract.
static absl::Status GetContract(CalculatorContract* cc);
// Open() and Process() are final: subclasses customize behavior only through
// the two protected hooks below.
absl::Status Open(CalculatorContext* cc) final;
absl::Status Process(CalculatorContext* cc) final;
protected:
// Open() calls this method to mutate the output stream header. The input
// to this function will contain a copy of the input stream header, so
// subclasses that do not need to mutate the header do not need to override
// it.
virtual absl::Status MutateHeader(TimeSeriesHeader* output_header);
// Process() calls this method on each packet to compute the output matrix.
virtual Matrix ProcessMatrix(const Matrix& input_matrix) = 0;
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_AUDIO_BASIC_TIME_SERIES_CALCULATORS_H_

View File

@ -1,515 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <vector>
#include "Eigen/Core"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/util/time_series_test_util.h"
namespace mediapipe {
// Fixture for SumTimeSeriesAcrossChannelsCalculator tests. SetUp() stores the
// calculator's registered name in calculator_name_, a member inherited from
// BasicTimeSeriesCalculatorTestBase (defined outside this file).
class SumTimeSeriesAcrossChannelsCalculatorTest
: public BasicTimeSeriesCalculatorTestBase {
protected:
void SetUp() override {
calculator_name_ = "SumTimeSeriesAcrossChannelsCalculator";
}
};
// With only one channel there is nothing to add up, so the packet and the
// header must come out unchanged.
TEST_F(SumTimeSeriesAcrossChannelsCalculatorTest, IsNoOpOnSingleChannelInputs) {
  const TimeSeriesHeader mono_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 num_channels: 1 num_samples: 5");
  const Matrix packet =
      Matrix::Random(mono_header.num_channels(), mono_header.num_samples());
  Test(mono_header, {packet}, mono_header, {packet});
}
// Summing a packet of all ones over its channels yields num_channels at
// every sample.
TEST_F(SumTimeSeriesAcrossChannelsCalculatorTest, ConstantPacket) {
  const TimeSeriesHeader in_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 num_channels: 3 num_samples: 5");
  TimeSeriesHeader out_header(in_header);
  out_header.set_num_channels(1);

  const Matrix all_ones =
      Matrix::Constant(in_header.num_channels(), in_header.num_samples(), 1);
  const Matrix expected =
      Matrix::Constant(1, in_header.num_samples(), in_header.num_channels());
  Test(in_header, {all_ones}, out_header, {expected});
}
// Feeds three packets derived from one base packet (as is, doubled, and
// shifted by a constant) and checks each per-sample channel sum.
TEST_F(SumTimeSeriesAcrossChannelsCalculatorTest, MultiplePackets) {
  const TimeSeriesHeader in_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 num_channels: 3 num_samples: 5");
  TimeSeriesHeader out_header(in_header);
  out_header.set_num_channels(1);

  Matrix packet(in_header.num_channels(), in_header.num_samples());
  packet << 10, -1, -1, 0, 0, 20, -2, 0, 1, 0, 30, -3, 1, 0, 12;
  // Column sums of the base packet.
  Matrix summed(1, in_header.num_samples());
  summed << 60, -6, 0, 1, 12;

  // Adding 3.5 to every input value adds 3.5 * num_channels to every sum.
  Test(in_header,
       {packet, 2 * packet,
        packet + Matrix::Constant(packet.rows(), packet.cols(), 3.5f)},
       out_header,
       {summed, 2 * summed,
        summed + Matrix::Constant(summed.rows(), summed.cols(),
                                  3.5 * in_header.num_channels())});
}
// Fixture for AverageTimeSeriesAcrossChannelsCalculator tests. SetUp() stores
// the calculator's registered name in calculator_name_, a member inherited
// from BasicTimeSeriesCalculatorTestBase (defined outside this file).
class AverageTimeSeriesAcrossChannelsCalculatorTest
: public BasicTimeSeriesCalculatorTestBase {
protected:
void SetUp() override {
calculator_name_ = "AverageTimeSeriesAcrossChannelsCalculator";
}
};
// The average over a single channel is that channel, so the packet and the
// header must come out unchanged.
TEST_F(AverageTimeSeriesAcrossChannelsCalculatorTest,
       IsNoOpOnSingleChannelInputs) {
  const TimeSeriesHeader mono_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 num_channels: 1 num_samples: 5");
  const Matrix packet =
      Matrix::Random(mono_header.num_channels(), mono_header.num_samples());
  Test(mono_header, {packet}, mono_header, {packet});
}
// One channel of ones plus otherwise all-zero channels averages to
// 1 / num_channels at every sample.
TEST_F(AverageTimeSeriesAcrossChannelsCalculatorTest, ConstantPacket) {
  const TimeSeriesHeader in_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 num_channels: 3 num_samples: 5");

  // Channel 0 is all ones; every other channel is all zeros.
  Matrix packet =
      Matrix::Constant(in_header.num_channels(), in_header.num_samples(), 0.0);
  packet.row(0) = Matrix::Constant(1, in_header.num_samples(), 1.0);

  TimeSeriesHeader out_header(in_header);
  out_header.set_num_channels(1);
  Test(in_header, {packet}, out_header,
       {Matrix::Constant(1, in_header.num_samples(),
                         1.0 / in_header.num_channels())});
}
// Fixture for SummarySaiToPitchogramCalculator tests. SetUp() stores the
// calculator's registered name in calculator_name_, a member inherited from
// BasicTimeSeriesCalculatorTestBase (defined outside this file).
class SummarySaiToPitchogramCalculatorTest
: public BasicTimeSeriesCalculatorTestBase {
protected:
void SetUp() override {
calculator_name_ = "SummarySaiToPitchogramCalculator";
}
};
// A 1 x 3 packet must come out as a 3 x 1 packet, with the header's channel
// and sample counts swapped and the sample rate set to the packet rate.
TEST_F(SummarySaiToPitchogramCalculatorTest, SinglePacket) {
  const TimeSeriesHeader sai_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 1 num_samples: 3");
  const TimeSeriesHeader pitchogram_header =
      ParseTextProtoOrDie<TimeSeriesHeader>(
          "sample_rate: 5.0 packet_rate: 5.0 num_channels: 3 num_samples: 1");

  Matrix row_packet(1, sai_header.num_samples());
  row_packet << 3, -9, 4;
  Matrix column_packet(sai_header.num_samples(), 1);
  column_packet << 3, -9, 4;

  Test(sai_header, {row_packet}, pitchogram_header, {column_packet});
}
// Fixture for ReverseChannelOrderCalculator tests. SetUp() stores the
// calculator's registered name in calculator_name_, a member inherited from
// BasicTimeSeriesCalculatorTestBase (defined outside this file).
class ReverseChannelOrderCalculatorTest
: public BasicTimeSeriesCalculatorTestBase {
protected:
void SetUp() override { calculator_name_ = "ReverseChannelOrderCalculator"; }
};
// Reversing a single channel changes nothing, so the packet and the header
// must come out unchanged.
TEST_F(ReverseChannelOrderCalculatorTest, IsNoOpOnSingleChannelInputs) {
  const TimeSeriesHeader mono_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 num_channels: 1 num_samples: 5");
  const Matrix packet =
      Matrix::Random(mono_header.num_channels(), mono_header.num_samples());
  Test(mono_header, {packet}, mono_header, {packet});
}
// Each sample's five channel values must come out in reverse order.
TEST_F(ReverseChannelOrderCalculatorTest, SinglePacket) {
  const TimeSeriesHeader stream_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 num_channels: 5 num_samples: 2");

  // Filled through the transpose so that each run of five values is one
  // sample (one column of the packet).
  Matrix packet(stream_header.num_channels(), stream_header.num_samples());
  packet.transpose() << 1, 2, 3, 4, 5, -1, -2, -3, -4, -5;
  Matrix reversed(stream_header.num_channels(), stream_header.num_samples());
  reversed.transpose() << 5, 4, 3, 2, 1, -5, -4, -3, -2, -1;

  Test(stream_header, {packet}, stream_header, {reversed});
}
// Fixture for FlattenPacketCalculator tests. SetUp() stores the calculator's
// registered name in calculator_name_, a member inherited from
// BasicTimeSeriesCalculatorTestBase (defined outside this file).
class FlattenPacketCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
protected:
void SetUp() override { calculator_name_ = "FlattenPacketCalculator"; }
};
// A 5-channel, 2-sample packet must become one 10-channel sample with the
// first sample's values ahead of the second's.
TEST_F(FlattenPacketCalculatorTest, SinglePacket) {
  const TimeSeriesHeader in_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 2");
  const TimeSeriesHeader out_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 10.0 packet_rate: 10.0 num_channels: 10 num_samples: 1");

  // Filled through the transpose: each run of five values is one sample.
  Matrix packet(in_header.num_channels(), in_header.num_samples());
  packet.transpose() << 1, 2, 3, 4, 5, -1, -2, -3, -4, -5;
  Matrix flattened(10, 1);
  flattened << 1, 2, 3, 4, 5, -1, -2, -3, -4, -5;

  Test(in_header, {packet}, out_header, {flattened});
}
// Fixture for SubtractMeanCalculator tests. SetUp() stores the calculator's
// registered name in calculator_name_, a member inherited from
// BasicTimeSeriesCalculatorTestBase (defined outside this file).
class SubtractMeanCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
protected:
void SetUp() override { calculator_name_ = "SubtractMeanCalculator"; }
};
// Every channel must have its own two-sample mean removed; the header is
// unchanged.
TEST_F(SubtractMeanCalculatorTest, SinglePacket) {
  const TimeSeriesHeader in_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 2");
  const TimeSeriesHeader out_header = in_header;

  // Filled through the transpose: each line below is one sample.
  Matrix packet(in_header.num_channels(), in_header.num_samples());
  // clang-format off
  packet.transpose() << 1, 0, 3, 0, 1,
                        -1, -2, -3, 4, 7;
  // clang-format on
  Matrix centered(in_header.num_channels(), in_header.num_samples());
  // clang-format off
  centered.transpose() << 1, 1, 3, -2, -3,
                          -1, -1, -3, 2, 3;
  // clang-format on

  Test(in_header, {packet}, out_header, {centered});
}
// Fixture for SubtractMeanAcrossChannelsCalculator tests. SetUp() stores the
// calculator's registered name in calculator_name_, a member inherited from
// BasicTimeSeriesCalculatorTestBase (defined outside this file).
class SubtractMeanAcrossChannelsCalculatorTest
: public BasicTimeSeriesCalculatorTestBase {
protected:
void SetUp() override {
calculator_name_ = "SubtractMeanAcrossChannelsCalculator";
}
};
// The values 1..6 have an overall mean of 3.5, which must be removed from
// every element.
TEST_F(SubtractMeanAcrossChannelsCalculatorTest, SinglePacket) {
  const TimeSeriesHeader in_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
  TimeSeriesHeader out_header(in_header);
  out_header.set_num_samples(2);

  // Filled through the transpose: each line below is one sample.
  Matrix packet(in_header.num_channels(), in_header.num_samples());
  // clang-format off
  packet.transpose() << 1.0, 2.0, 3.0,
                        4.0, 5.0, 6.0;
  // clang-format on
  Matrix centered(out_header.num_channels(), out_header.num_samples());
  // clang-format off
  centered.transpose() << 1.0 - 3.5, 2.0 - 3.5, 3.0 - 3.5,
                          4.0 - 3.5, 5.0 - 3.5, 6.0 - 3.5;
  // clang-format on

  Test(in_header, {packet}, out_header, {centered});
}
// Fixture for DivideByMeanAcrossChannelsCalculator tests. SetUp() stores the
// calculator's registered name in calculator_name_, a member inherited from
// BasicTimeSeriesCalculatorTestBase (defined outside this file).
class DivideByMeanAcrossChannelsCalculatorTest
: public BasicTimeSeriesCalculatorTestBase {
protected:
void SetUp() override {
calculator_name_ = "DivideByMeanAcrossChannelsCalculator";
}
};
// The values 1..6 have an overall mean of 3.5; every element must be divided
// by it.
TEST_F(DivideByMeanAcrossChannelsCalculatorTest, SinglePacket) {
  const TimeSeriesHeader in_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
  TimeSeriesHeader out_header(in_header);
  out_header.set_num_samples(2);

  Matrix packet(in_header.num_channels(), in_header.num_samples());
  packet.transpose() << 1.0, 2.0, 3.0, 4.0, 5.0, 6.0;
  Matrix scaled(out_header.num_channels(), out_header.num_samples());
  scaled.transpose() << 1.0 / 3.5, 2.0 / 3.5, 3.0 / 3.5, 4.0 / 3.5, 5.0 / 3.5,
      6.0 / 3.5;

  Test(in_header, {packet}, out_header, {scaled});
}
// The input values cancel out to a mean of exactly zero, in which case the
// calculator must emit all ones instead of dividing.
TEST_F(DivideByMeanAcrossChannelsCalculatorTest, ReturnsOneForZeroMean) {
  const TimeSeriesHeader in_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
  TimeSeriesHeader out_header(in_header);
  out_header.set_num_samples(2);

  Matrix packet(in_header.num_channels(), in_header.num_samples());
  packet.transpose() << -3.0, -2.0, -1.0, 1.0, 2.0, 3.0;
  Matrix all_ones(out_header.num_channels(), out_header.num_samples());
  all_ones.transpose() << 1.0, 1.0, 1.0, 1.0, 1.0, 1.0;

  Test(in_header, {packet}, out_header, {all_ones});
}
// Fixture for MeanCalculator tests. SetUp() stores the calculator's
// registered name in calculator_name_, a member inherited from
// BasicTimeSeriesCalculatorTestBase (defined outside this file).
class MeanCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
protected:
void SetUp() override { calculator_name_ = "MeanCalculator"; }
};
// Each channel's two samples must be averaged into one output sample, and
// the output sample rate must equal the packet rate.
TEST_F(MeanCalculatorTest, SinglePacket) {
  const TimeSeriesHeader in_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
  TimeSeriesHeader out_header(in_header);
  out_header.set_sample_rate(10.0);
  out_header.set_num_samples(1);

  // Filled through the transpose: samples are (1, 2, 3) and (4, 5, 6).
  Matrix packet(in_header.num_channels(), in_header.num_samples());
  packet.transpose() << 1.0, 2.0, 3.0, 4.0, 5.0, 6.0;
  Matrix channel_means(out_header.num_channels(), out_header.num_samples());
  channel_means << (1.0 + 4.0) / 2, (2.0 + 5.0) / 2, (3.0 + 6.0) / 2;

  Test(in_header, {packet}, out_header, {channel_means});
}
// Fixture for StandardDeviationCalculator tests. SetUp() stores the
// calculator's registered name in calculator_name_, a member inherited from
// BasicTimeSeriesCalculatorTestBase (defined outside this file).
class StandardDeviationCalculatorTest
: public BasicTimeSeriesCalculatorTestBase {
protected:
void SetUp() override { calculator_name_ = "StandardDeviationCalculator"; }
};
// Checks the uncorrected (divide-by-N) standard deviation of each channel
// over a two-sample packet.
TEST_F(StandardDeviationCalculatorTest, SinglePacket) {
  const TimeSeriesHeader in_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
  TimeSeriesHeader out_header(in_header);
  out_header.set_num_samples(1);
  out_header.set_sample_rate(10.0);

  // Filled through the transpose: samples are (0, 2, 3) and (4, 5, 8), so
  // the channel means are 2.0, 3.5 and 5.5.
  Matrix packet(in_header.num_channels(), in_header.num_samples());
  packet.transpose() << 0.0, 2.0, 3.0, 4.0, 5.0, 8.0;
  Matrix deviations(out_header.num_channels(), out_header.num_samples());
  deviations << sqrt((pow(0.0 - 2.0, 2) + pow(4.0 - 2.0, 2)) / 2),
      sqrt((pow(2.0 - 3.5, 2) + pow(5.0 - 3.5, 2)) / 2),
      sqrt((pow(3.0 - 5.5, 2) + pow(8.0 - 5.5, 2)) / 2);

  Test(in_header, {packet}, out_header, {deviations});
}
// Fixture for CovarianceCalculator tests. SetUp() stores the calculator's
// registered name in calculator_name_, a member inherited from
// BasicTimeSeriesCalculatorTestBase (defined outside this file).
class CovarianceCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
protected:
void SetUp() override { calculator_name_ = "CovarianceCalculator"; }
};
// Three channels with two samples each must yield the 3 x 3 covariance
// matrix of the channels, normalized by the sample count.
TEST_F(CovarianceCalculatorTest, SinglePacket) {
  const TimeSeriesHeader in_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
  TimeSeriesHeader out_header(in_header);
  out_header.set_num_samples(out_header.num_channels());

  // The comma initializer fills row by row, so each pair of values below is
  // one channel: (1, 3), (5, 9) and (-1, -3).
  Matrix packet(in_header.num_channels(), in_header.num_samples());
  packet << 1.0, 3.0, 5.0, 9.0, -1.0, -3.0;
  Matrix covariance(out_header.num_channels(), out_header.num_samples());
  covariance << 1, 2, -1, 2, 4, -2, -1, -2, 1;

  Test(in_header, {packet}, out_header, {covariance});
}
// Fixture for L2NormCalculator tests. SetUp() stores the calculator's
// registered name in calculator_name_, a member inherited from
// BasicTimeSeriesCalculatorTestBase (defined outside this file).
class L2NormCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
protected:
void SetUp() override { calculator_name_ = "L2NormCalculator"; }
};
// The columns (3, 4), (5, 12) and (8, -15) have norms 5, 13 and 17.
TEST_F(L2NormCalculatorTest, SinglePacket) {
  const TimeSeriesHeader in_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
  const TimeSeriesHeader out_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 1 num_samples: 3");

  Matrix packet(in_header.num_channels(), in_header.num_samples());
  packet << 3, 5, 8, 4, 12, -15;
  Matrix norms(out_header.num_channels(), out_header.num_samples());
  norms << 5, 13, 17;

  Test(in_header, {packet}, out_header, {norms});
}
// Fixture for L2NormalizeColumnCalculator tests. SetUp() stores the
// calculator's registered name in calculator_name_, a member inherited from
// BasicTimeSeriesCalculatorTestBase (defined outside this file).
class L2NormalizeColumnCalculatorTest
: public BasicTimeSeriesCalculatorTestBase {
protected:
void SetUp() override { calculator_name_ = "L2NormalizeColumnCalculator"; }
};
// Every column of the packet must be rescaled to unit length; the header is
// unchanged.
TEST_F(L2NormalizeColumnCalculatorTest, SinglePacket) {
  const TimeSeriesHeader in_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
  const TimeSeriesHeader out_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");

  Matrix packet(in_header.num_channels(), in_header.num_samples());
  packet << 0.3, 0.4, 0.8, 0.5, 0.9, 0.8;

  // Expected values are column-wise L2 normalized, e.g.
  // |a| -> |a/sqrt(a^2 + b^2)|
  // |b|    |b/sqrt(a^2 + b^2)|
  Matrix normalized(out_header.num_channels(), out_header.num_samples());
  normalized << 0.51449579000473022, 0.40613847970962524, 0.70710676908493042,
      0.85749292373657227, 0.91381156444549561, 0.70710676908493042;

  Test(in_header, {packet}, out_header, {normalized});
}
// Fixture for L2NormalizeCalculator tests. SetUp() stores the calculator's
// registered name in calculator_name_, a member inherited from
// BasicTimeSeriesCalculatorTestBase (defined outside this file).
class L2NormalizeCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
protected:
void SetUp() override { calculator_name_ = "L2NormalizeCalculator"; }
};
// Every element must be divided by the RMS taken over the whole packet.
TEST_F(L2NormalizeCalculatorTest, SinglePacket) {
  const TimeSeriesHeader in_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
  const TimeSeriesHeader out_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");

  Matrix packet(in_header.num_channels(), in_header.num_samples());
  packet << 0.3, 0.4, 0.8, 0.5, 0.9, 0.8;

  // Expected values are L2 normalized:
  // a -> a/sqrt(a^2 + b^2 + c^2 + ...) * sqrt(matrix.cols()*matrix.rows())
  Matrix normalized(out_header.num_channels(), out_header.num_samples());
  normalized << 0.45661166, 0.60881555, 1.21763109, 0.76101943, 1.36983498,
      1.21763109;

  Test(in_header, {packet}, out_header, {normalized});
}
// A packet whose entries are all +1 or -1 has an RMS of 1, so normalization
// must leave it unchanged.
TEST_F(L2NormalizeCalculatorTest, UnitMatrixStaysUnchanged) {
  const TimeSeriesHeader stream_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 3 num_samples: 5");
  Matrix packet(stream_header.num_channels(), stream_header.num_samples());
  packet << 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
      1.0, -1.0, 1.0;
  Test(stream_header, {packet}, stream_header, {packet});
}
// Fixture for PeakNormalizeCalculator tests. SetUp() stores the calculator's
// registered name in calculator_name_, a member inherited from
// BasicTimeSeriesCalculatorTestBase (defined outside this file).
class PeakNormalizeCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
protected:
void SetUp() override { calculator_name_ = "PeakNormalizeCalculator"; }
};
// The largest absolute value in the packet is 0.9; every element must be
// divided by it.
TEST_F(PeakNormalizeCalculatorTest, SinglePacket) {
  const TimeSeriesHeader in_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
  const TimeSeriesHeader out_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");

  Matrix packet(in_header.num_channels(), in_header.num_samples());
  packet << 0.3, 0.4, 0.8, 0.5, 0.9, 0.8;
  Matrix normalized(out_header.num_channels(), out_header.num_samples());
  normalized << 0.33333333, 0.44444444, 0.88888889, 0.55555556, 1.0,
      0.88888889;

  Test(in_header, {packet}, out_header, {normalized});
}
// A packet whose entries are all +1 or -1 already has a peak of 1, so
// normalization must leave it unchanged.
TEST_F(PeakNormalizeCalculatorTest, UnitMatrixStaysUnchanged) {
  const TimeSeriesHeader stream_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 3 num_samples: 5");
  Matrix packet(stream_header.num_channels(), stream_header.num_samples());
  packet << 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
      1.0, -1.0, 1.0;
  Test(stream_header, {packet}, stream_header, {packet});
}
// Fixture for ElementwiseSquareCalculator tests. SetUp() stores the
// calculator's registered name in calculator_name_, a member inherited from
// BasicTimeSeriesCalculatorTestBase (defined outside this file).
class ElementwiseSquareCalculatorTest
: public BasicTimeSeriesCalculatorTestBase {
protected:
void SetUp() override { calculator_name_ = "ElementwiseSquareCalculator"; }
};
// Every element of the packet must be squared; the header is unchanged.
TEST_F(ElementwiseSquareCalculatorTest, SinglePacket) {
  const TimeSeriesHeader in_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
  const TimeSeriesHeader out_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");

  Matrix packet(in_header.num_channels(), in_header.num_samples());
  packet << 3, 5, 8, 4, 12, -15;
  Matrix squared(out_header.num_channels(), out_header.num_samples());
  squared << 9, 25, 64, 16, 144, 225;

  Test(in_header, {packet}, out_header, {squared});
}
// Fixture for FirstHalfSlicerCalculator tests. SetUp() stores the
// calculator's registered name in calculator_name_, a member inherited from
// BasicTimeSeriesCalculatorTestBase (defined outside this file).
class FirstHalfSlicerCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
protected:
void SetUp() override { calculator_name_ = "FirstHalfSlicerCalculator"; }
};
// With two samples per packet, only the first sample must be kept.
TEST_F(FirstHalfSlicerCalculatorTest, SinglePacketEvenNumSamples) {
  const TimeSeriesHeader in_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 2");
  const TimeSeriesHeader out_header = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 1");

  // Filled through the transpose: each line below is one sample.
  Matrix packet(in_header.num_channels(), in_header.num_samples());
  // clang-format off
  packet.transpose() << 0, 1, 2, 3, 4,
                        5, 6, 7, 8, 9;
  // clang-format on
  Matrix first_half(out_header.num_channels(), out_header.num_samples());
  first_half.transpose() << 0, 1, 2, 3, 4;

  Test(in_header, {packet}, out_header, {first_half});
}
// A packet with three samples (columns) is still expected to be reduced to
// its first sample only.
TEST_F(FirstHalfSlicerCalculatorTest, SinglePacketOddNumSamples) {
  const TimeSeriesHeader header_in = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 3");
  const TimeSeriesHeader header_out = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 1");

  // Each column holds one sample across the five channels; the last is zero.
  Matrix packet(header_in.num_channels(), header_in.num_samples());
  packet.col(0) << 0, 1, 2, 3, 4;
  packet.col(1) << 5, 6, 7, 8, 9;
  packet.col(2).setZero();

  Matrix first_half(header_out.num_channels(), header_out.num_samples());
  first_half.col(0) << 0, 1, 2, 3, 4;

  Test(header_in, {packet}, header_out, {first_half});
}
// Three packets derived from one base packet (as-is, doubled, shifted by 3.5)
// must each be sliced to their own first sample.
TEST_F(FirstHalfSlicerCalculatorTest, MultiplePackets) {
  const TimeSeriesHeader header_in = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 2");
  const TimeSeriesHeader header_out = ParseTextProtoOrDie<TimeSeriesHeader>(
      "sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 1");

  // Each column holds one sample across the five channels.
  Matrix packet(header_in.num_channels(), header_in.num_samples());
  packet.col(0) << 0, 1, 2, 3, 4;
  packet.col(1) << 5, 6, 7, 8, 9;

  Matrix first_half(header_out.num_channels(), header_out.num_samples());
  first_half.col(0) << 0, 1, 2, 3, 4;

  const Matrix shift_in = Matrix::Constant(packet.rows(), packet.cols(), 3.5f);
  const Matrix shift_out =
      Matrix::Constant(first_half.rows(), first_half.cols(), 3.5f);

  Test(header_in, {packet, 2 * packet, packet + shift_in}, header_out,
       {first_half, 2 * first_half, first_half + shift_out});
}
} // namespace mediapipe

View File

@ -1,275 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// MediaPipe Calculator wrapper around audio/dsp/mfcc/
// classes MelFilterbank (magnitude spectrograms warped to the Mel
// approximation of the auditory frequency scale) and Mfcc (Mel Frequency
// Cepstral Coefficients, the decorrelated transform of log-Mel-spectrum
// commonly used as acoustic features in speech and other audio tasks).
// Both calculators expect as input the SQUARED_MAGNITUDE-domain outputs
// from the MediaPipe SpectrogramCalculator object.
#include <memory>
#include <vector>
#include "Eigen/Core"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "audio/dsp/mfcc/mel_filterbank.h"
#include "audio/dsp/mfcc/mfcc.h"
#include "mediapipe/calculators/audio/mfcc_mel_calculators.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/time_series_util.h"
namespace mediapipe {
namespace {
// Portable version of TimeSeriesHeader's DebugString.
// Renders the sample_rate, num_channels, num_samples, packet_rate and
// audio_sample_rate fields of `header` as a multi-line string, one
// "name: value" pair per line, without relying on the proto DebugString.
std::string PortableDebugString(const TimeSeriesHeader& header) {
  return absl::Substitute(R"(
sample_rate: $0
num_channels: $1
num_samples: $2
packet_rate: $3
audio_sample_rate: $4
)",
                          header.sample_rate(), header.num_channels(),
                          header.num_samples(), header.packet_rate(),
                          header.audio_sample_rate());
}
} // namespace
// Abstract base class for Calculators that transform feature vectors on a
// frame-by-frame basis.
// Subclasses must override pure virtual methods ConfigureTransform and
// TransformFrame.
// Input and output MediaPipe packets are matrices with one column per frame,
// and one row per feature dimension. Each input packet results in an
// output packet with the same number of columns (but differing numbers of
// rows corresponding to the new feature space).
class FramewiseTransformCalculatorBase : public CalculatorBase {
 public:
  // Declares one Matrix input stream and one Matrix output stream.
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<Matrix>(
        // Sequence of Matrices, each column describing a particular time frame,
        // each row a feature dimension, with TimeSeriesHeader.
    );
    cc->Outputs().Index(0).Set<Matrix>(
        // Sequence of Matrices, each column describing a particular time frame,
        // each row a feature dimension, with TimeSeriesHeader.
    );
    return absl::OkStatus();
  }

  // Reads the input TimeSeriesHeader, lets the subclass configure itself via
  // ConfigureTransform(), and publishes the output header.
  absl::Status Open(CalculatorContext* cc) final;

  // Applies TransformFrame() to every column of the input packet.
  absl::Status Process(CalculatorContext* cc) final;

  // Number of rows in every output matrix. Zero until a subclass calls
  // set_num_output_channels().
  int num_output_channels() const { return num_output_channels_; }
  void set_num_output_channels(int num_output_channels) {
    num_output_channels_ = num_output_channels;
  }

 private:
  // Takes header and options, and sets up state including calling
  // set_num_output_channels() on the base object.
  virtual absl::Status ConfigureTransform(const TimeSeriesHeader& header,
                                          CalculatorContext* cc) = 0;

  // Takes a vector<double> corresponding to an input frame, and
  // perform the specific transformation to produce an output frame.
  virtual void TransformFrame(const std::vector<double>& input,
                              std::vector<double>* output) const = 0;

  // Initialized so that a subclass that forgets to call
  // set_num_output_channels() yields a deterministic value instead of an
  // indeterminate one.
  int num_output_channels_ = 0;
};
// Validates the input stream's TimeSeriesHeader, configures the subclass
// transform, and publishes an output header that is a copy of the input
// header with num_channels replaced by the transform's output size.
// Returns the first non-OK status from header validation or from
// ConfigureTransform().
absl::Status FramewiseTransformCalculatorBase::Open(CalculatorContext* cc) {
  TimeSeriesHeader input_header;
  MP_RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
      cc->Inputs().Index(0).Header(), &input_header));
  // Stop here on a configuration failure: publishing an output header derived
  // from a transform that failed to configure would describe a stream that
  // will never be produced.
  MP_RETURN_IF_ERROR(ConfigureTransform(input_header, cc));
  // Adopt() takes ownership of the raw pointer.
  auto output_header = new TimeSeriesHeader(input_header);
  output_header->set_num_channels(num_output_channels_);
  cc->Outputs().Index(0).SetHeader(Adopt(output_header));
  cc->SetOffset(0);
  return absl::OkStatus();
}
// Transforms one input packet column by column. The output matrix has
// num_output_channels_ rows and as many columns as the input, and is emitted
// at the input timestamp.
absl::Status FramewiseTransformCalculatorBase::Process(CalculatorContext* cc) {
  const Matrix& input = cc->Inputs().Index(0).Get<Matrix>();
  const int num_frames = input.cols();
  std::unique_ptr<Matrix> output(new Matrix(num_output_channels_, num_frames));
  // The main work here is converting each column of the float Matrix
  // into a vector of doubles, which is what our target functions from
  // dsp_core consume, and doing the reverse with their output.
  std::vector<double> input_frame(input.rows());
  std::vector<double> output_frame(num_output_channels_);
  // TransformFrame() receives input_frame by const reference, so its buffer
  // is stable and the view can be built once. data() (unlike &v[0]) is also
  // well-defined when the vector is empty.
  Eigen::Map<Eigen::MatrixXd> input_frame_map(input_frame.data(),
                                              input_frame.size(), 1);
  for (int frame = 0; frame < num_frames; ++frame) {
    // Copy input from Eigen::Matrix column to vector<double>.
    input_frame_map = input.col(frame).cast<double>();
    // Perform the actual transformation.
    TransformFrame(input_frame, &output_frame);
    // Copy output from vector<double> to the output Matrix column.
    CHECK_EQ(output_frame.size(), static_cast<size_t>(num_output_channels_));
    // Rebuilt every frame because TransformFrame() gets a mutable pointer to
    // output_frame and may reallocate it.
    Eigen::Map<const Eigen::MatrixXd> output_frame_map(output_frame.data(),
                                                       output_frame.size(), 1);
    output->col(frame) = output_frame_map.cast<float>();
  }
  // Add() takes ownership of the released pointer.
  cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
  return absl::OkStatus();
}
// Calculator that wraps audio_dsp::Mfcc (dsp/mfcc/mfcc.cc).
// Consumes frames of squared-magnitude spectra from the SpectrogramCalculator
// and emits Mel Frequency Cepstral Coefficients, one output column per input
// column.
//
// Example config:
// node {
//   calculator: "MfccCalculator"
//   input_stream: "spectrogram_frames_stream"
//   output_stream: "mfcc_frames_stream"
//   options {
//     [mediapipe.MfccCalculatorOptions.ext] {
//       mel_spectrum_params {
//         channel_count: 20
//         min_frequency_hertz: 125.0
//         max_frequency_hertz: 3800.0
//       }
//       mfcc_count: 13
//     }
//   }
// }
class MfccCalculator : public FramewiseTransformCalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    return FramewiseTransformCalculatorBase::GetContract(cc);
  }

 private:
  // Builds and initializes the Mfcc object from the calculator options and
  // the input header. Fails if the header carries no audio_sample_rate or if
  // Mfcc::Initialize() reports failure.
  absl::Status ConfigureTransform(const TimeSeriesHeader& header,
                                  CalculatorContext* cc) override {
    const MfccCalculatorOptions options = cc->Options<MfccCalculatorOptions>();
    const auto& mel_params = options.mel_spectrum_params();
    mfcc_.reset(new audio_dsp::Mfcc());
    const int spectrum_length = header.num_channels();

    // Push the configured parameters into the Mfcc object.
    set_num_output_channels(options.mfcc_count());
    mfcc_->set_dct_coefficient_count(num_output_channels());
    mfcc_->set_upper_frequency_limit(mel_params.max_frequency_hertz());
    mfcc_->set_lower_frequency_limit(mel_params.min_frequency_hertz());
    mfcc_->set_filterbank_channel_count(mel_params.channel_count());

    // The upstream calculator (e.g. SpectrogramCalculator) has to record the
    // sample rate of its audio input in the TimeSeriesHeader;
    // audio_dsp::MelFilterBank needs it to interpret the spectrogram bins.
    if (!header.has_audio_sample_rate()) {
      return absl::InvalidArgumentError(
          absl::StrCat("No audio_sample_rate in input TimeSeriesHeader ",
                       PortableDebugString(header)));
    }

    // Everything is in place; initialize the Mfcc object.
    if (!mfcc_->Initialize(spectrum_length, header.audio_sample_rate())) {
      return absl::Status(absl::StatusCode::kInternal,
                          "Mfcc::Initialize returned uninitialized");
    }
    return absl::OkStatus();
  }

  // Delegates one frame to Mfcc::Compute().
  void TransformFrame(const std::vector<double>& input,
                      std::vector<double>* output) const override {
    mfcc_->Compute(input, output);
  }

  std::unique_ptr<audio_dsp::Mfcc> mfcc_;
};
REGISTER_CALCULATOR(MfccCalculator);
// Calculator wrapper around the dsp/mfcc/mel_filterbank.cc routine.
// Take frames of squared-magnitude spectra from the SpectrogramCalculator
// and convert them into Mel-warped (linear-magnitude) spectra.
// Note: This code computes a mel-frequency filterbank, using a simple
// algorithm that gives bad results (some mel channels that are always zero)
// if you ask for too many channels.
class MelSpectrumCalculator : public FramewiseTransformCalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    return FramewiseTransformCalculatorBase::GetContract(cc);
  }

 private:
  // Builds and initializes the MelFilterbank from the calculator options and
  // the input header. Fails if the header carries no audio_sample_rate or if
  // MelFilterbank::Initialize() reports failure.
  absl::Status ConfigureTransform(const TimeSeriesHeader& header,
                                  CalculatorContext* cc) override {
    MelSpectrumCalculatorOptions mel_spectrum_options =
        cc->Options<MelSpectrumCalculatorOptions>();
    mel_filterbank_.reset(new audio_dsp::MelFilterbank());
    int input_length = header.num_channels();
    set_num_output_channels(mel_spectrum_options.channel_count());
    // An upstream calculator (such as SpectrogramCalculator) must store
    // the sample rate of its input audio waveform in the TimeSeries Header.
    // audio_dsp::MelFilterBank needs to know this to
    // correctly interpret the spectrogram bins.
    if (!header.has_audio_sample_rate()) {
      return absl::InvalidArgumentError(
          absl::StrCat("No audio_sample_rate in input TimeSeriesHeader ",
                       PortableDebugString(header)));
    }
    bool initialized = mel_filterbank_->Initialize(
        input_length, header.audio_sample_rate(), num_output_channels(),
        mel_spectrum_options.min_frequency_hertz(),
        mel_spectrum_options.max_frequency_hertz());
    if (initialized) {
      return absl::OkStatus();
    } else {
      // Name the component that actually failed; the previous text was
      // copied from MfccCalculator and pointed at the wrong class.
      return absl::Status(absl::StatusCode::kInternal,
                          "MelFilterbank::Initialize returned uninitialized");
    }
  }

  // Delegates one frame to MelFilterbank::Compute().
  void TransformFrame(const std::vector<double>& input,
                      std::vector<double>* output) const override {
    mel_filterbank_->Compute(input, output);
  }

  std::unique_ptr<audio_dsp::MelFilterbank> mel_filterbank_;
};
REGISTER_CALCULATOR(MelSpectrumCalculator);
} // namespace mediapipe

View File

@ -1,50 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
// Options for MelSpectrumCalculator. Also embedded in MfccCalculatorOptions
// to describe the mel filterbank underlying the MFCC computation.
message MelSpectrumCalculatorOptions {
extend CalculatorOptions {
optional MelSpectrumCalculatorOptions ext = 78581812;
}
// The fields are to populate the config parameters in
// audio/dsp/mfcc/mel_filterbank.h
// but the names are chosen to mirror
// audio/hearing/filterbanks/cochlea_gammatone_filterbank.proto
// and the default values match those in
// speech/greco3/frontend/filter_bank.proto .
// Total number of frequency bands to use.
optional int32 channel_count = 1 [default = 20];
// Lower edge of lowest triangular Mel band.
optional float min_frequency_hertz = 2 [default = 125.0];
// Upper edge of highest triangular Mel band.
optional float max_frequency_hertz = 3 [default = 3800.0];
}
// Options for MfccCalculator.
message MfccCalculatorOptions {
extend CalculatorOptions {
optional MfccCalculatorOptions ext = 78450441;
}
// Specification of the underlying mel filterbank.
optional MelSpectrumCalculatorOptions mel_spectrum_params = 1;
// How many MFCC coefficients to emit. MfccCalculator uses this value as the
// number of output channels (rows) of every output matrix.
optional uint32 mfcc_count = 2 [default = 13];
}

View File

@ -1,149 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "Eigen/Core"
#include "mediapipe/calculators/audio/mfcc_mel_calculators.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/util/time_series_test_util.h"
namespace mediapipe {
// Deliberately unusual sample rate, so that a test cannot pass by accident
// because some component defaults to the same value.
constexpr float kAudioSampleRate = 8800.0f;
// Test fixture shared by the MfccCalculator and MelSpectrumCalculator tests.
// OptionsType is the calculator's options proto and CalculatorName its
// registered name.
template <typename OptionsType, const char* CalculatorName>
class FramewiseTransformCalculatorTest
    : public TimeSeriesCalculatorTest<OptionsType> {
 protected:
  void SetUp() override {
    this->calculator_name_ = CalculatorName;
    this->num_input_channels_ = 129;
    // This is the frame rate coming out of the SpectrogramCalculator.
    this->input_sample_rate_ = 100.0;
  }

  // Appends `num_packets` packets of random, non-negative data to the input
  // stream, spaced kSecondsPerPacket apart.
  // Returns the number of samples per packet.
  int GenerateRandomNonnegInputStream(int num_packets) {
    const double kSecondsPerPacket = 0.2;
    const int num_samples_per_packet =
        kSecondsPerPacket * this->input_sample_rate_;
    for (int i = 0; i < num_packets; ++i) {
      const int timestamp =
          i * kSecondsPerPacket * Timestamp::kTimestampUnitsPerSecond;
      // Mfcc, MelSpectrum expect squared-magnitude inputs, so make
      // sure the input data has no negative values.
      Matrix* sqdata = this->NewRandomMatrix(this->num_input_channels_,
                                             num_samples_per_packet);
      *sqdata = sqdata->array().square();
      this->AppendInputPacket(sqdata, timestamp);
    }
    return num_samples_per_packet;
  }

  // Verifies that every output packet has the expected shape and that the
  // packet timestamps advance by one packet duration each.
  void CheckOutputPacketMetadata(int expected_num_channels,
                                 int expected_num_samples_per_packet) {
    int expected_timestamp = 0;
    for (const auto& packet : this->output().packets) {
      EXPECT_EQ(expected_timestamp, packet.Timestamp().Value());
      expected_timestamp += expected_num_samples_per_packet /
                            this->input_sample_rate_ *
                            Timestamp::kTimestampUnitsPerSecond;
      const Matrix& output_matrix = packet.template Get<Matrix>();
      EXPECT_EQ(output_matrix.rows(), expected_num_channels);
      EXPECT_EQ(output_matrix.cols(), expected_num_samples_per_packet);
    }
  }

  void SetupGraphAndHeader() {
    this->InitializeGraph();
    this->FillInputHeader();
  }

  // Generates the random input packets and records their size in
  // num_samples_per_packet_ for CheckResults().
  void SetupRandomInputPackets() {
    constexpr int kNumPackets = 5;
    num_samples_per_packet_ = GenerateRandomNonnegInputStream(kNumPackets);
  }

  absl::Status Run() { return this->RunGraph(); }

  // `expected_num_channels` is the expected number of channels (matrix rows)
  // in the output data from the Calculator under test, which the test should
  // know.
  void CheckResults(int expected_num_channels) {
    const auto& output_header =
        this->output().header.template Get<TimeSeriesHeader>();
    EXPECT_EQ(this->input_sample_rate_, output_header.sample_rate());
    CheckOutputPacketMetadata(expected_num_channels, num_samples_per_packet_);
    // Sanity check that output packets have non-zero energy.
    for (const auto& packet : this->output().packets) {
      const Matrix& data = packet.template Get<Matrix>();
      EXPECT_GT(data.squaredNorm(), 0);
    }
  }

  // Allows SetupRandomInputPackets() to inform CheckResults() about how
  // big the packets are supposed to be. Zero until packets are generated.
  int num_samples_per_packet_ = 0;
};
// Registered name of the calculator exercised by MfccCalculatorTest.
constexpr char kMfccCalculator[] = "MfccCalculator";
using MfccCalculatorTest =
    FramewiseTransformCalculatorTest<MfccCalculatorOptions, kMfccCalculator>;
// With audio_sample_rate present in the input header the graph must run and
// produce mfcc_count output channels.
TEST_F(MfccCalculatorTest, AudioSampleRateFromInputHeader) {
  audio_sample_rate_ = kAudioSampleRate;
  SetupGraphAndHeader();
  SetupRandomInputPackets();
  const absl::Status run_status = Run();
  MP_EXPECT_OK(run_status);
  const int expected_channels = options_.mfcc_count();
  CheckResults(expected_channels);
}
// audio_sample_rate_ is deliberately left at kUnset so the input
// TimeSeriesHeader carries no audio sample rate; the run must then fail.
TEST_F(MfccCalculatorTest, NoAudioSampleRate) {
  SetupGraphAndHeader();
  SetupRandomInputPackets();
  const absl::Status run_status = Run();
  EXPECT_FALSE(run_status.ok());
}
// Registered name of the calculator exercised by MelSpectrumCalculatorTest.
constexpr char kMelSpectrumCalculator[] = "MelSpectrumCalculator";
using MelSpectrumCalculatorTest =
    FramewiseTransformCalculatorTest<MelSpectrumCalculatorOptions,
                                     kMelSpectrumCalculator>;
// With audio_sample_rate present in the input header the graph must run and
// produce channel_count output channels.
TEST_F(MelSpectrumCalculatorTest, AudioSampleRateFromInputHeader) {
  audio_sample_rate_ = kAudioSampleRate;
  SetupGraphAndHeader();
  SetupRandomInputPackets();
  const absl::Status run_status = Run();
  MP_EXPECT_OK(run_status);
  const int expected_channels = options_.channel_count();
  CheckResults(expected_channels);
}
// audio_sample_rate_ is deliberately left at kUnset so the input
// TimeSeriesHeader carries no audio sample rate; the run must then fail.
TEST_F(MelSpectrumCalculatorTest, NoAudioSampleRate) {
  SetupGraphAndHeader();
  SetupRandomInputPackets();
  const absl::Status run_status = Run();
  EXPECT_FALSE(run_status.ok());
}
} // namespace mediapipe

View File

@ -1,190 +0,0 @@
// Copyright 2019, 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Defines RationalFactorResampleCalculator.
#include "mediapipe/calculators/audio/rational_factor_resample_calculator.h"
#include "audio/dsp/resampler_q.h"
using audio_dsp::Resampler;
namespace mediapipe {
// Resamples the Matrix carried by the current input packet; no flush is
// requested.
absl::Status RationalFactorResampleCalculator::Process(CalculatorContext* cc) {
  constexpr bool kShouldFlush = false;
  const Matrix& packet_samples = cc->Inputs().Index(0).Get<Matrix>();
  return ProcessInternal(packet_samples, kShouldFlush, cc);
}
// Flushes the resamplers by processing an empty frame with should_flush set.
// Does nothing when no packet was ever processed (initial timestamp still
// Unstarted).
absl::Status RationalFactorResampleCalculator::Close(CalculatorContext* cc) {
  const bool never_started = initial_timestamp_ == Timestamp::Unstarted();
  if (!never_started) {
    const Matrix no_samples(num_channels_, 0);
    return ProcessInternal(no_samples, true, cc);
  }
  return absl::OkStatus();
}
namespace {
void CopyChannelToVector(const Matrix& matrix, int channel,
std::vector<float>* vec) {
vec->resize(matrix.cols());
Eigen::Map<Eigen::ArrayXf>(vec->data(), vec->size()) = matrix.row(channel);
}
void CopyVectorToChannel(const std::vector<float>& vec, Matrix* matrix,
int channel) {
if (matrix->cols() == 0) {
matrix->resize(matrix->rows(), vec.size());
} else {
CHECK_EQ(vec.size(), matrix->cols());
}
CHECK_LT(channel, matrix->rows());
matrix->row(channel) =
Eigen::Map<const Eigen::ArrayXf>(vec.data(), vec.size());
}
} // namespace
// Reads the options and the input TimeSeriesHeader, creates one resampler per
// channel (none when source and target rates are equal), publishes the output
// header and resets the sample/timestamp bookkeeping.
// Returns a non-OK status when target_sample_rate is missing or not positive,
// when the input header is invalid, or when a resampler cannot be created.
absl::Status RationalFactorResampleCalculator::Open(CalculatorContext* cc) {
  RationalFactorResampleCalculatorOptions resample_options =
      cc->Options<RationalFactorResampleCalculatorOptions>();
  if (!resample_options.has_target_sample_rate()) {
    return tool::StatusInvalid(
        "resample_options doesn't have target_sample_rate.");
  }
  target_sample_rate_ = resample_options.target_sample_rate();
  // The rate is later used as a divisor when computing output timestamps.
  // The negated comparison also rejects NaN.
  if (!(target_sample_rate_ > 0)) {
    return tool::StatusInvalid("target_sample_rate must be greater than 0.");
  }

  TimeSeriesHeader input_header;
  MP_RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
      cc->Inputs().Index(0).Header(), &input_header));
  source_sample_rate_ = input_header.sample_rate();
  num_channels_ = input_header.num_channels();

  // Don't create resamplers for pass-thru (sample rates are equal).
  if (source_sample_rate_ != target_sample_rate_) {
    resampler_.resize(num_channels_);
    for (auto& r : resampler_) {
      r = ResamplerFromOptions(source_sample_rate_, target_sample_rate_,
                               resample_options);
      if (!r) {
        LOG(ERROR) << "Failed to initialize resampler.";
        return absl::UnknownError("Failed to initialize resampler.");
      }
    }
  }

  // Adopt() takes ownership of the raw pointer.
  TimeSeriesHeader* output_header = new TimeSeriesHeader(input_header);
  output_header->set_sample_rate(target_sample_rate_);
  // The resampler doesn't make guarantees about how many samples will
  // be in each packet.
  output_header->clear_packet_rate();
  output_header->clear_num_samples();
  cc->Outputs().Index(0).SetHeader(Adopt(output_header));

  cumulative_output_samples_ = 0;
  cumulative_input_samples_ = 0;
  initial_timestamp_ = Timestamp::Unstarted();
  check_inconsistent_timestamps_ =
      resample_options.check_inconsistent_timestamps();
  return absl::OkStatus();
}
// Shared body of Process() and Close(): records the first input timestamp,
// optionally warns about timestamp jitter, resamples (or passes through) the
// frame and emits it when it contains at least one sample. The output
// timestamp is derived from the number of samples emitted so far.
absl::Status RationalFactorResampleCalculator::ProcessInternal(
    const Matrix& input_frame, bool should_flush, CalculatorContext* cc) {
  const bool first_call = initial_timestamp_ == Timestamp::Unstarted();
  if (first_call) {
    initial_timestamp_ = cc->InputTimestamp();
  }

  if (check_inconsistent_timestamps_) {
    time_series_util::LogWarningIfTimestampIsInconsistent(
        cc->InputTimestamp(), initial_timestamp_, cumulative_input_samples_,
        source_sample_rate_);
  }

  // Must be computed before the output counter is advanced below.
  const double emitted_seconds =
      cumulative_output_samples_ / target_sample_rate_;
  Timestamp output_timestamp =
      initial_timestamp_ +
      (emitted_seconds * Timestamp::kTimestampUnitsPerSecond);

  cumulative_input_samples_ += input_frame.cols();

  std::unique_ptr<Matrix> resampled(new Matrix(num_channels_, 0));
  const bool pass_through = resampler_.empty();
  if (!pass_through) {
    if (!Resample(input_frame, resampled.get(), should_flush)) {
      return absl::UnknownError("Resample() failed.");
    }
  } else {
    // Sample rates were same for input and output; pass-thru.
    *resampled = input_frame;
  }

  cumulative_output_samples_ += resampled->cols();
  if (resampled->cols() > 0) {
    cc->Outputs().Index(0).Add(resampled.release(), output_timestamp);
  }
  return absl::OkStatus();
}
bool RationalFactorResampleCalculator::Resample(const Matrix& input_frame,
Matrix* output_frame,
bool should_flush) {
std::vector<float> input_vector;
std::vector<float> output_vector;
for (int i = 0; i < input_frame.rows(); ++i) {
CopyChannelToVector(input_frame, i, &input_vector);
if (should_flush) {
resampler_[i]->Flush(&output_vector);
} else {
resampler_[i]->ProcessSamples(input_vector, &output_vector);
}
CopyVectorToChannel(output_vector, output_frame, i);
}
return true;
}
// static
std::unique_ptr<Resampler<float>>
RationalFactorResampleCalculator::ResamplerFromOptions(
const double source_sample_rate, const double target_sample_rate,
const RationalFactorResampleCalculatorOptions& options) {
std::unique_ptr<Resampler<float>> resampler;
const auto& rational_factor_options =
options.resampler_rational_factor_options();
audio_dsp::QResamplerParams params;
if (rational_factor_options.has_radius() &&
rational_factor_options.has_cutoff() &&
rational_factor_options.has_kaiser_beta()) {
// Convert RationalFactorResampler kernel parameters to QResampler
// settings.
params.filter_radius_factor =
rational_factor_options.radius() *
std::min(1.0, target_sample_rate / source_sample_rate);
params.cutoff_proportion = 2 * rational_factor_options.cutoff() /
std::min(source_sample_rate, target_sample_rate);
params.kaiser_beta = rational_factor_options.kaiser_beta();
}
// Set large enough so that the resampling factor between common sample
// rates (e.g. 8kHz, 16kHz, 22.05kHz, 32kHz, 44.1kHz, 48kHz) is exact, and
// that any factor is represented with error less than 0.025%.
params.max_denominator = 2000;
// NOTE: QResampler supports multichannel resampling, so the code might be
// simplified using a single instance rather than one per channel.
resampler = absl::make_unique<audio_dsp::QResampler<float>>(
source_sample_rate, target_sample_rate, /*num_channels=*/1, params);
if (resampler != nullptr && !resampler->Valid()) {
resampler = std::unique_ptr<Resampler<float>>();
}
return resampler;
}
REGISTER_CALCULATOR(RationalFactorResampleCalculator);
} // namespace mediapipe

View File

@ -1,109 +0,0 @@
// Copyright 2019, 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_AUDIO_RATIONAL_FACTOR_RESAMPLE_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_AUDIO_RATIONAL_FACTOR_RESAMPLE_CALCULATOR_H_
#include <algorithm>
#include <memory>
#include <vector>
#include "Eigen/Core"
#include "absl/strings/str_cat.h"
#include "audio/dsp/resampler.h"
#include "mediapipe/calculators/audio/rational_factor_resample_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/util/time_series_util.h"
namespace mediapipe {
// MediaPipe Calculator for resampling a (vector-valued)
// input time series with a uniform sample rate. The output
// stream's sampling rate is specified by target_sample_rate in the
// RationalFactorResampleCalculatorOptions. The output time series may have
// a varying number of samples per frame.
//
// NOTE: This calculator uses QResampler, despite the name, which supersedes
// RationalFactorResampler.
class RationalFactorResampleCalculator : public CalculatorBase {
 public:
  struct TestAccess;

  // Declares one Matrix input stream and one Matrix output stream.
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<Matrix>(
        // Single input stream with TimeSeriesHeader.
    );
    cc->Outputs().Index(0).Set<Matrix>(
        // Resampled stream with TimeSeriesHeader.
    );
    return absl::OkStatus();
  }

  // Returns a non-OK status if the input stream header is invalid or if the
  // resampler cannot be initialized.
  absl::Status Open(CalculatorContext* cc) override;

  // Resamples a packet of TimeSeries data. Returns a non-OK status if
  // resampling fails.
  absl::Status Process(CalculatorContext* cc) override;

  // Flushes any remaining state. Returns a non-OK status if resampling
  // fails.
  absl::Status Close(CalculatorContext* cc) override;

 protected:
  using ResamplerType = audio_dsp::Resampler<float>;

  // Returns a Resampler<float> implementation specified by the
  // RationalFactorResampleCalculatorOptions proto. Returns null if the options
  // specify an invalid resampler.
  static std::unique_ptr<ResamplerType> ResamplerFromOptions(
      const double source_sample_rate, const double target_sample_rate,
      const RationalFactorResampleCalculatorOptions& options);

  // Does Timestamp bookkeeping and resampling common to Process() and
  // Close(). Returns a non-OK status if resampling fails.
  absl::Status ProcessInternal(const Matrix& input_frame, bool should_flush,
                               CalculatorContext* cc);

  // Uses the internal resampler_ objects to actually resample each
  // row of the input TimeSeries. Returns false if the resampler
  // state becomes inconsistent.
  bool Resample(const Matrix& input_frame, Matrix* output_frame,
                bool should_flush);

  // All state below is (re)assigned in Open(); the initializers only
  // guarantee deterministic values before that.
  double source_sample_rate_ = 0.0;
  double target_sample_rate_ = 0.0;
  int64 cumulative_input_samples_ = 0;
  int64 cumulative_output_samples_ = 0;
  // Unstarted() is the sentinel Close() and ProcessInternal() test for
  // "no packet processed yet".
  Timestamp initial_timestamp_ = Timestamp::Unstarted();
  bool check_inconsistent_timestamps_ = false;
  int num_channels_ = 0;
  // One resampler per channel; empty when input and output rates are equal.
  std::vector<std::unique_ptr<ResamplerType>> resampler_;
};
// Test-only access to RationalFactorResampleCalculator methods.
// As a nested struct it may call the calculator's protected static factory,
// which it re-exports unchanged.
struct RationalFactorResampleCalculator::TestAccess {
// Forwards to RationalFactorResampleCalculator::ResamplerFromOptions() with
// the same arguments and returns its result.
static std::unique_ptr<ResamplerType> ResamplerFromOptions(
const double source_sample_rate, const double target_sample_rate,
const RationalFactorResampleCalculatorOptions& options) {
return RationalFactorResampleCalculator::ResamplerFromOptions(
source_sample_rate, target_sample_rate, options);
}
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_AUDIO_RATIONAL_FACTOR_RESAMPLE_CALCULATOR_H_

View File

@ -1,47 +0,0 @@
// Copyright 2019, 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
// Options for RationalFactorResampleCalculator.
// NOTE: This calculator uses QResampler, despite the name, which supersedes
// RationalFactorResampler.
message RationalFactorResampleCalculatorOptions {
extend CalculatorOptions {
optional RationalFactorResampleCalculatorOptions ext = 259760074;
}
// target_sample_rate is the sample rate, in Hertz, of the output
// stream. Required. Must be greater than 0.
optional double target_sample_rate = 1;
// Parameters for initializing QResampler. See QResampler for more details.
// The calculator applies these kernel settings only when radius, cutoff and
// kaiser_beta are all explicitly set; if any of them is absent, all three
// are ignored and QResampler's own default parameters are used.
message ResamplerRationalFactorOptions {
// Kernel radius in units of input samples.
optional double radius = 1;
// Anti-aliasing cutoff frequency in Hertz. A reasonable setting is
// 0.45 * min(input_sample_rate, output_sample_rate).
optional double cutoff = 2;
// The Kaiser beta parameter for the kernel window.
optional double kaiser_beta = 3 [default = 6.0];
}
optional ResamplerRationalFactorOptions resampler_rational_factor_options = 2;
// Set to false to disable checks for jitter in timestamp values. Useful with
// live audio input.
optional bool check_inconsistent_timestamps = 3 [default = true];
}

View File

@ -1,246 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/audio/rational_factor_resample_calculator.h"
#include <math.h>
#include <algorithm>
#include <string>
#include <vector>
#include "Eigen/Core"
#include "audio/dsp/signal_vector_util.h"
#include "mediapipe/calculators/audio/rational_factor_resample_calculator.pb.h"
#include "mediapipe/framework//tool/validate_type.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/time_series_test_util.h"
namespace mediapipe {
namespace {
// Timestamp of the first input packet fed to the graph.
// NOTE(review): despite the "Milliseconds" suffix, this value is added
// directly to quantities scaled by Timestamp::kTimestampUnitsPerSecond, i.e.
// it is expressed in raw timestamp units - confirm and consider renaming.
const int kInitialTimestampOffsetMilliseconds = 4;
class RationalFactorResampleCalculatorTest
: public TimeSeriesCalculatorTest<RationalFactorResampleCalculatorOptions> {
protected:
// Configures the fixture: three input channels sampled at 4 kHz, routed
// through the calculator under test.
void SetUp() override {
  num_input_channels_ = 3;
  input_sample_rate_ = 4000.0;
  calculator_name_ = "RationalFactorResampleCalculator";
}
// Expects two vectors whose lengths differ by at most one element and whose
// elements are float-equal at every index present in both.
//
// This is useful because the resampler doesn't make precise guarantees about
// its output size.
//
// expected: reference samples.
// actual:   samples produced by the code under test.
void ExpectVectorMostlyFloatEq(const std::vector<float>& expected,
                               const std::vector<float>& actual) {
  // Lengths should be close, but don't have to be equal.
  ASSERT_NEAR(expected.size(), actual.size(), 1);
  // Index with the container's own unsigned size type so the loop bound is
  // not compared against a signed int.
  using Index = std::vector<float>::size_type;
  const Index common_size = std::min(expected.size(), actual.size());
  for (Index i = 0; i < common_size; ++i) {
    EXPECT_FLOAT_EQ(expected[i], actual[i]) << " where i=" << i << ".";
  }
}
// Encodes (timestamp, sample, channel) into a single double so a human can
// read each component off the value: the timestamp is scaled by 100, the
// sample index occupies the units, and the channel lands in the first
// decimal place.
double TestValue(int sample, int channel, int timestamp_in_microseconds) {
  const double timestamp_part = timestamp_in_microseconds * 100.0;
  const double channel_part = channel / 10.0;
  return timestamp_part + sample + channel_part;
}
// Allocates a num_channels x num_samples matrix in which every entry is
// TestValue(sample, channel, timestamp). The caller takes ownership of the
// returned pointer.
Matrix* NewTestFrame(int num_channels, int num_samples, int timestamp) {
  Matrix* frame = new Matrix(num_channels, num_samples);
  for (int sample = 0; sample < num_samples; ++sample) {
    for (int channel = 0; channel < num_channels; ++channel) {
      (*frame)(channel, sample) = TestValue(sample, channel, timestamp);
    }
  }
  return frame;
}
// Sets the target sample rate, feeds five input packets of 10, 20, ..., 50
// samples into the graph and runs it. Every packet is also appended to
// concatenated_input_samples_ so the checks can compare against the full
// input signal. Returns the status of RunGraph().
absl::Status Run(double output_sample_rate) {
  options_.set_target_sample_rate(output_sample_rate);
  InitializeGraph();
  FillInputHeader();

  // Start from an empty reference signal.
  concatenated_input_samples_.resize(num_input_channels_, 0);
  num_input_samples_ = 0;

  const int kNumPackets = 5;
  for (int packet_number = 1; packet_number <= kNumPackets; ++packet_number) {
    const int packet_size = packet_number * 10;
    const int timestamp = kInitialTimestampOffsetMilliseconds +
                          num_input_samples_ / input_sample_rate_ *
                              Timestamp::kTimestampUnitsPerSecond;
    Matrix* data_frame =
        NewTestFrame(num_input_channels_, packet_size, timestamp);

    // Grow the reference copy. conservativeResize() keeps the columns that
    // are already there, whereas Eigen's plain resize() would discard them.
    const int grown_size = num_input_samples_ + packet_size;
    concatenated_input_samples_.conservativeResize(num_input_channels_,
                                                   grown_size);
    concatenated_input_samples_.rightCols(packet_size) = *data_frame;
    num_input_samples_ = grown_size;

    AppendInputPacket(data_frame, timestamp);
  }
  return RunGraph();
}
void CheckOutputLength(double output_sample_rate) {
double factor = output_sample_rate / input_sample_rate_;
int num_output_samples = 0;
for (const Packet& packet : output().packets) {
num_output_samples += packet.Get<Matrix>().cols();
}
// The exact number of expected samples may vary based on the implementation
// of the resampler since the exact value is not an integer.
const double expected_num_output_samples = num_input_samples_ * factor;
EXPECT_LE(ceil(expected_num_output_samples), num_output_samples);
EXPECT_GE(ceil(expected_num_output_samples) + 11, num_output_samples);
}
// Verifies that each output packet's timestamp equals (within one timestamp
// unit) the initial offset plus the duration of all samples emitted before
// that packet at output_sample_rate.
void CheckOutputPacketTimestamps(double output_sample_rate) {
  int samples_before_packet = 0;
  for (const Packet& packet : output().packets) {
    const double elapsed_units = samples_before_packet / output_sample_rate *
                                 Timestamp::kTimestampUnitsPerSecond;
    const int expected_timestamp =
        kInitialTimestampOffsetMilliseconds + elapsed_units;
    EXPECT_NEAR(expected_timestamp, packet.Timestamp().Value(), 1);
    samples_before_packet += packet.Get<Matrix>().cols();
  }
}
// Verifies, channel by channel, that the packet-by-packet output of the
// calculator matches what a fresh resampler built from the same options
// produces when it is given the whole input signal at once.
void CheckOutputValues(double output_sample_rate) {
  for (int channel = 0; channel < num_input_channels_; ++channel) {
    // Gather this channel's complete input signal.
    std::vector<float> channel_input;
    for (int sample = 0; sample < num_input_samples_; ++sample) {
      channel_input.push_back(concatenated_input_samples_(channel, sample));
    }

    // Reference result: process everything in one call, then flush.
    auto reference_resampler =
        RationalFactorResampleCalculator::TestAccess::ResamplerFromOptions(
            input_sample_rate_, output_sample_rate, options_);
    std::vector<float> expected_samples;
    std::vector<float> chunk;
    reference_resampler->ProcessSamples(channel_input, &chunk);
    audio_dsp::VectorAppend(&expected_samples, chunk);
    reference_resampler->Flush(&chunk);
    audio_dsp::VectorAppend(&expected_samples, chunk);

    // Actual result: this channel's row from every output packet, in order.
    std::vector<float> actual_samples;
    for (const Packet& packet : output().packets) {
      Matrix channel_row = packet.Get<Matrix>().row(channel);
      const auto* row_begin = &channel_row(0);
      actual_samples.insert(actual_samples.end(), row_begin,
                            row_begin + channel_row.cols());
    }

    ExpectVectorMostlyFloatEq(expected_samples, actual_samples);
  }
}
void CheckOutputHeaders(double output_sample_rate) {
const TimeSeriesHeader& output_header =
output().header.Get<TimeSeriesHeader>();
TimeSeriesHeader expected_header;
expected_header.set_sample_rate(output_sample_rate);
expected_header.set_num_channels(num_input_channels_);
EXPECT_THAT(output_header, mediapipe::EqualsProto(expected_header));
}
// Runs every output check (length, packet timestamps, sample values, header)
// against the given output sample rate.
void CheckOutput(double output_sample_rate) {
  CheckOutputLength(output_sample_rate);
  CheckOutputPacketTimestamps(output_sample_rate);
  CheckOutputValues(output_sample_rate);
  CheckOutputHeaders(output_sample_rate);
}
// Checks, channel by channel, that the concatenated output packets reproduce
// the concatenated input samples (lengths may differ by at most one).
void CheckOutputUnchanged() {
  for (int channel = 0; channel < num_input_channels_; ++channel) {
    // Expected: the input signal itself.
    std::vector<float> expected_samples;
    for (int sample = 0; sample < num_input_samples_; ++sample) {
      expected_samples.push_back(concatenated_input_samples_(channel, sample));
    }

    // Actual: this channel's row from every output packet, in order.
    std::vector<float> actual_samples;
    for (const Packet& packet : output().packets) {
      Matrix channel_row = packet.Get<Matrix>().row(channel);
      const auto* row_begin = &channel_row(0);
      actual_samples.insert(actual_samples.end(), row_begin,
                            row_begin + channel_row.cols());
    }

    ExpectVectorMostlyFloatEq(expected_samples, actual_samples);
  }
}
// Number of input samples (columns) appended so far by Run().
int num_input_samples_;
// Reference copy of every input packet, concatenated column-wise by Run();
// one row per input channel.
Matrix concatenated_input_samples_;
};
// Resampling to 1.9x the input rate succeeds and passes every output check.
TEST_F(RationalFactorResampleCalculatorTest, Upsample) {
  const double target_rate = input_sample_rate_ * 1.9;
  MP_ASSERT_OK(Run(target_rate));
  CheckOutput(target_rate);
}
// Resampling to the input rate divided by 1.9 succeeds and passes every
// output check.
TEST_F(RationalFactorResampleCalculatorTest, Downsample) {
  const double target_rate = input_sample_rate_ / 1.9;
  MP_ASSERT_OK(Run(target_rate));
  CheckOutput(target_rate);
}
// Resampling by an exact integer factor (2x) succeeds and passes every
// output check.
TEST_F(RationalFactorResampleCalculatorTest, UsesRationalFactorResampler) {
  const double doubled_rate = input_sample_rate_ * 2;
  MP_ASSERT_OK(Run(doubled_rate));
  CheckOutput(doubled_rate);
}
// When the target rate equals the input rate, the output samples must match
// the input samples.
TEST_F(RationalFactorResampleCalculatorTest, PassthroughIfSampleRateUnchanged) {
  const double unchanged_rate = input_sample_rate_;
  MP_ASSERT_OK(Run(unchanged_rate));
  CheckOutputUnchanged();
}
// A negative target sample rate is invalid and must make the run fail.
TEST_F(RationalFactorResampleCalculatorTest, FailsOnBadTargetRate) {
  const double invalid_target_rate = -999.9;
  ASSERT_FALSE(Run(invalid_target_rate).ok());
}
// Running the graph with a header but no input packets must succeed and
// produce no output packets.
TEST_F(RationalFactorResampleCalculatorTest, DoesNotDieOnEmptyInput) {
  options_.set_target_sample_rate(input_sample_rate_);
  InitializeGraph();
  FillInputHeader();

  // No packets are appended before the run.
  MP_ASSERT_OK(RunGraph());

  EXPECT_TRUE(output().packets.empty());
}
} // anonymous namespace
} // namespace mediapipe

View File

@ -1,452 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Defines SpectrogramCalculator.
#include <math.h>
#include <complex>
#include <deque>
#include <memory>
#include <string>
#include "Eigen/Core"
#include "absl/strings/string_view.h"
#include "audio/dsp/spectrogram/spectrogram.h"
#include "audio/dsp/window_functions.h"
#include "mediapipe/calculators/audio/spectrogram_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/core_proto_inc.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/source_location.h"
#include "mediapipe/framework/port/status_builder.h"
#include "mediapipe/util/time_series_util.h"
namespace mediapipe {
// MediaPipe Calculator for computing the "spectrogram" (short-time Fourier
// transform squared-magnitude, by default) of a multichannel input
// time series, including optionally overlapping frames. Options are
// specified in SpectrogramCalculatorOptions proto (where names are chosen
// to mirror TimeSeriesFramerCalculator):
//
// Result is a MatrixData record (for single channel input and when the
// allow_multichannel_input flag is false), or a vector of MatrixData records,
// one for each channel (when the allow_multichannel_input flag is set). The
// rows of each spectrogram matrix correspond to the n_fft/2+1 unique complex
// values, or squared/linear/dB magnitudes, depending on the output_type option.
// Each input packet will result in zero or one output packets, each containing
// one Matrix for each channel of the input, where each Matrix has one or more
// columns of spectral values, one for each complete frame of input samples. If
// the input packet contains too few samples to trigger a new output frame, no
// output packet is generated (since zero-length packets are not legal since
// they would result in timestamps that were equal, not strictly increasing).
//
// Output packet Timestamps are set to the beginning of each frame. This is to
// allow calculators downstream from SpectrogramCalculator to have aligned
// Timestamps regardless of a packet's signal length.
//
// Both frame_duration_seconds and frame_overlap_seconds will be
// rounded to the nearest integer number of samples. Consequently, all output
// frames will be based on the same number of input samples, and each
// analysis frame will advance from its predecessor by the same time step.
class SpectrogramCalculator : public CalculatorBase {
public:
// Declares the packet types of the single input and single output stream.
//
// Input 0 is always a Matrix (stream with TimeSeriesHeader). Output 0 depends
// on the options:
//   - single-channel, real output:    Matrix
//   - single-channel, COMPLEX output: Eigen::MatrixXcf
//   - multichannel, real output:      std::vector<Matrix>
//   - multichannel, COMPLEX output:   std::vector<Eigen::MatrixXcf>
// The multichannel variants are emitted on a stream with a
// MultiStreamTimeSeriesHeader, the others with a TimeSeriesHeader.
static absl::Status GetContract(CalculatorContract* cc) {
  cc->Inputs().Index(0).Set<Matrix>();

  const SpectrogramCalculatorOptions& options =
      cc->Options<SpectrogramCalculatorOptions>();
  const bool complex_output =
      options.output_type() == SpectrogramCalculatorOptions::COMPLEX;

  if (options.allow_multichannel_input()) {
    if (complex_output) {
      cc->Outputs().Index(0).Set<std::vector<Eigen::MatrixXcf>>();
    } else {
      cc->Outputs().Index(0).Set<std::vector<Matrix>>();
    }
  } else {
    if (complex_output) {
      cc->Outputs().Index(0).Set<Eigen::MatrixXcf>();
    } else {
      cc->Outputs().Index(0).Set<Matrix>();
    }
  }
  return absl::OkStatus();
}
// Reads the options and the input stream header and sets the output header.
// Returns an error if the input stream header is invalid.
absl::Status Open(CalculatorContext* cc) override;
// Outputs at most one packet containing the spectral values from as many
// input frames as are completed by the input samples: a single matrix with
// one or more columns, or a vector of such matrices (one per channel) when
// allow_multichannel_input is set. Returns OK on success, and an error if
// the spectrogram computation fails, the channels yield inconsistent frame
// counts, or the output type is not recognized.
absl::Status Process(CalculatorContext* cc) override;
// Performs zero-padding and processing of any remaining samples
// if pad_final_packet is set.
// Returns OK unless processing the padded samples fails.
absl::Status Close(CalculatorContext* cc) override;
private:
// Returns the timestamp to attach to the packet being emitted.
//
// Without use_local_timestamp_ this is the cumulative timestamp. With it,
// the current input timestamp is used (and remembered); when the input
// timestamp is Timestamp::Done(), as during Close(), an estimate is derived
// from the last remembered timestamp plus the duration of the frames emitted
// by the previous packet.
Timestamp CurrentOutputTimestamp(CalculatorContext* cc) {
  if (!use_local_timestamp_) {
    return CumulativeOutputTimestamp();
  }
  const Timestamp input_timestamp = cc->InputTimestamp();
  if (input_timestamp == Timestamp::Done()) {
    // No real input timestamp is available, so extrapolate from the last one.
    const double elapsed_units =
        round(last_completed_frames_ * frame_step_samples() *
              Timestamp::kTimestampUnitsPerSecond / input_sample_rate_);
    return last_local_output_timestamp_ + elapsed_units;
  }
  last_local_output_timestamp_ = input_timestamp;
  return input_timestamp;
}
// Returns the timestamp of the next frame to be emitted when timestamps are
// derived from sample counts: the first input timestamp plus the rounded
// duration of all frame steps completed so far.
Timestamp CumulativeOutputTimestamp() {
  const double elapsed_units =
      round(cumulative_completed_frames_ * frame_step_samples() *
            Timestamp::kTimestampUnitsPerSecond / input_sample_rate_);
  return initial_input_timestamp_ + elapsed_units;
}
// Number of samples by which consecutive analysis frames advance: the frame
// length minus the overlap between adjacent frames.
int frame_step_samples() const {
  const int step = frame_duration_samples_ - frame_overlap_samples_;
  return step;
}
// Dispatches on output_type_ and forwards the input samples (one row per
// channel) to ProcessVectorToOutput with the matching post-processing
// function. The result is emitted as a Matrix (or an Eigen::MatrixXcf if
// complex-valued output is requested). Returns an error for an unrecognized
// output type.
absl::Status ProcessVector(const Matrix& input_stream, CalculatorContext* cc);
// Templated function to process either real- or complex-output spectrogram.
// postprocess_output_fn is applied to every frame before it is scaled by
// output_scale_ and stored.
template <class OutputMatrixType>
absl::Status ProcessVectorToOutput(
const Matrix& input_stream,
const OutputMatrixType postprocess_output_fn(const OutputMatrixType&),
CalculatorContext* cc);
// Use the MediaPipe timestamp instead of the estimated one. Useful when the
// data is intermittent.
bool use_local_timestamp_;
// Most recent input timestamp seen while use_local_timestamp_ is true.
Timestamp last_local_output_timestamp_;
// Sample rate of the input stream, taken from its header.
double input_sample_rate_;
// Whether Close() zero-pads and processes the remaining samples.
bool pad_final_packet_;
// Analysis frame length and overlap, rounded to whole samples.
int frame_duration_samples_;
int frame_overlap_samples_;
// How many samples (columns) Process() has been passed; Close() uses it to
// decide how much zero padding to append.
int64 cumulative_input_samples_;
// How many frames we've emitted, used for calculating output time stamps.
int64 cumulative_completed_frames_;
// How many frames were emitted last, used for estimating the timestamp on
// Close when use_local_timestamp_ is true.
int64 last_completed_frames_;
// Timestamp of the first input packet; Timestamp::Unstarted() until then.
Timestamp initial_input_timestamp_;
// Number of input channels, taken from the input stream header.
int num_input_channels_;
// How many frequency bins we emit (=N_FFT/2 + 1).
int num_output_channels_;
// Requested output type (a SpectrogramCalculatorOptions::OutputType value).
int output_type_;
// If true, the output packet is a vector with one matrix per channel;
// otherwise it is a single matrix.
bool allow_multichannel_input_;
// Vector of Spectrogram objects, one for each channel.
std::vector<std::unique_ptr<audio_dsp::Spectrogram>> spectrogram_generators_;
// Fixed scale factor applied to output values (regardless of type).
double output_scale_;
// Factor converting ln(squared magnitude) to decibels.
static const float kLnPowerToDb;
};
// Registers the calculator with the framework.
REGISTER_CALCULATOR(SpectrogramCalculator);
// Factor to convert ln(magnitude_squared) to deciBels = 10.0/ln(10.0).
// The double literal is narrowed to float on initialization.
const float SpectrogramCalculator::kLnPowerToDb = 4.342944819032518;
// Validates the options and the input stream header, creates one Spectrogram
// generator per input channel, publishes the output stream header and resets
// all running counters.
//
// Returns kInvalidArgument when the framing options are missing or
// inconsistent, when the input header declares no channels, when
// multichannel input arrives without allow_multichannel_input, or when the
// frame length or frame step rounds to less than one sample. Errors from
// parsing the input TimeSeriesHeader are propagated unchanged.
absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) {
  SpectrogramCalculatorOptions spectrogram_options =
      cc->Options<SpectrogramCalculatorOptions>();
  use_local_timestamp_ = spectrogram_options.use_local_timestamp();

  if (spectrogram_options.frame_duration_seconds() <= 0.0) {
    return absl::Status(
        absl::StatusCode::kInvalidArgument,
        "frame_duration_seconds must be set and greater than 0.");
  }
  if (spectrogram_options.frame_overlap_seconds() >=
      spectrogram_options.frame_duration_seconds()) {
    return absl::Status(
        absl::StatusCode::kInvalidArgument,
        "frame_overlap_seconds must be less than frame_duration_seconds.");
  }
  if (spectrogram_options.frame_overlap_seconds() < 0.0) {
    return absl::Status(absl::StatusCode::kInvalidArgument,
                        "frame_overlap_seconds must not be negative.");
  }

  TimeSeriesHeader input_header;
  MP_RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
      cc->Inputs().Index(0).Header(), &input_header));

  input_sample_rate_ = input_header.sample_rate();
  num_input_channels_ = input_header.num_channels();

  // spectrogram_generators_[0] is dereferenced below, so at least one
  // channel is required.
  if (num_input_channels_ < 1) {
    return absl::Status(
        absl::StatusCode::kInvalidArgument,
        "The input stream header must declare at least one channel.");
  }
  if (!spectrogram_options.allow_multichannel_input() &&
      num_input_channels_ != 1) {
    return absl::Status(
        absl::StatusCode::kInvalidArgument,
        "The input stream has more than one channel but "
        "allow_multichannel_input is not set.");
  }

  frame_duration_samples_ =
      round(spectrogram_options.frame_duration_seconds() * input_sample_rate_);
  frame_overlap_samples_ =
      round(spectrogram_options.frame_overlap_seconds() * input_sample_rate_);
  // Options that are valid in seconds can still collapse after rounding to
  // whole samples; a zero step would also make the output sample rate
  // computed below meaningless.
  if (frame_duration_samples_ < 1 || frame_step_samples() < 1) {
    return absl::Status(
        absl::StatusCode::kInvalidArgument,
        "The frame duration and the frame step must each be at least one "
        "sample at the input sample rate.");
  }

  pad_final_packet_ = spectrogram_options.pad_final_packet();
  output_type_ = spectrogram_options.output_type();
  allow_multichannel_input_ = spectrogram_options.allow_multichannel_input();
  output_scale_ = spectrogram_options.output_scale();

  std::vector<double> window;
  switch (spectrogram_options.window_type()) {
    case SpectrogramCalculatorOptions::COSINE:
      audio_dsp::CosineWindow().GetPeriodicSamples(frame_duration_samples_,
                                                   &window);
      break;
    case SpectrogramCalculatorOptions::HANN:
      audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_,
                                                 &window);
      break;
    case SpectrogramCalculatorOptions::HAMMING:
      audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples_,
                                                    &window);
      break;
  }

  // Propagate settings down to the actual Spectrogram objects, one per
  // channel.
  spectrogram_generators_.clear();
  for (int i = 0; i < num_input_channels_; i++) {
    spectrogram_generators_.push_back(
        std::unique_ptr<audio_dsp::Spectrogram>(new audio_dsp::Spectrogram()));
    spectrogram_generators_[i]->Initialize(window, frame_step_samples());
  }

  num_output_channels_ =
      spectrogram_generators_[0]->output_frequency_channels();
  std::unique_ptr<TimeSeriesHeader> output_header(
      new TimeSeriesHeader(input_header));
  // Store the actual sample rate of the input audio in the TimeSeriesHeader
  // so that subsequent calculators can figure out the frequency scale of
  // our output.
  output_header->set_audio_sample_rate(input_sample_rate_);
  // Setup rest of output header.
  output_header->set_num_channels(num_output_channels_);
  output_header->set_sample_rate(input_sample_rate_ / frame_step_samples());
  // Although we usually generate one output packet for each input
  // packet, this might not be true for input packets whose size is smaller
  // than the analysis window length. So we clear output_header.packet_rate
  // because we can't guarantee a constant packet rate. Similarly, the number
  // of output frames per packet depends on the input packet, so we also clear
  // output_header.num_samples.
  output_header->clear_packet_rate();
  output_header->clear_num_samples();
  if (!spectrogram_options.allow_multichannel_input()) {
    cc->Outputs().Index(0).SetHeader(Adopt(output_header.release()));
  } else {
    std::unique_ptr<MultiStreamTimeSeriesHeader> multichannel_output_header(
        new MultiStreamTimeSeriesHeader());
    *multichannel_output_header->mutable_time_series_header() = *output_header;
    multichannel_output_header->set_num_streams(num_input_channels_);
    cc->Outputs().Index(0).SetHeader(
        Adopt(multichannel_output_header.release()));
  }

  // Reset every running counter. cumulative_input_samples_ is incremented in
  // Process() and read in Close(), so it must start from a defined value.
  cumulative_input_samples_ = 0;
  cumulative_completed_frames_ = 0;
  last_completed_frames_ = 0;
  initial_input_timestamp_ = Timestamp::Unstarted();
  if (use_local_timestamp_) {
    // Inform the framework that the calculator will output packets at the same
    // timestamps as input packets to enable packet queueing optimizations. The
    // final packet (emitted from Close()) does not follow this rule but it's
    // sufficient that its timestamp is strictly greater than the timestamp of
    // the previous packet.
    cc->SetOffset(0);
  }
  return absl::OkStatus();
}
// Consumes one input packet (one row per channel, one column per sample) and
// emits any spectrogram frames it completes.
//
// Records the first input timestamp, which anchors the cumulative output
// timestamps. Returns kInvalidArgument if the packet's row count differs from
// the channel count declared in the input stream header; otherwise returns
// the status of ProcessVector().
absl::Status SpectrogramCalculator::Process(CalculatorContext* cc) {
  if (initial_input_timestamp_ == Timestamp::Unstarted()) {
    initial_input_timestamp_ = cc->InputTimestamp();
  }

  const Matrix& input_stream = cc->Inputs().Index(0).Get<Matrix>();
  // There is exactly one Spectrogram generator per declared channel, so a
  // packet with a different number of rows cannot be processed safely.
  if (input_stream.rows() != num_input_channels_) {
    return absl::Status(
        absl::StatusCode::kInvalidArgument,
        "The number of rows in the input packet does not match the number of "
        "channels declared in the input stream header.");
  }

  cumulative_input_samples_ += input_stream.cols();

  return ProcessVector(input_stream, cc);
}
// Feeds each row (channel) of input_stream to that channel's Spectrogram
// generator, post-processes and scales every completed frame, and emits at
// most one output packet.
//
// input_stream:          samples, one row per channel.
// postprocess_output_fn: applied to each frame before it is multiplied by
//                        output_scale_.
// cc:                    context used to emit the packet and the next
//                        timestamp bound.
//
// Returns kInternal if a generator reports failure, a RET_CHECK error if the
// channels disagree on the number of completed frames, and OK otherwise
// (including when no frame was completed and nothing is emitted).
template <class OutputMatrixType>
absl::Status SpectrogramCalculator::ProcessVectorToOutput(
    const Matrix& input_stream,
    const OutputMatrixType postprocess_output_fn(const OutputMatrixType&),
    CalculatorContext* cc) {
  std::unique_ptr<std::vector<OutputMatrixType>> spectrogram_matrices(
      new std::vector<OutputMatrixType>());
  std::vector<std::vector<typename OutputMatrixType::Scalar>> output_vectors;

  // Compute a spectrogram for each channel.
  // Frame count produced by channel 0; every other channel must match it.
  int num_output_time_frames = 0;
  for (int channel = 0; channel < input_stream.rows(); ++channel) {
    output_vectors.clear();

    // Copy one row (channel) of the input matrix into the std::vector.
    // data() is used rather than &v[0] because the input may have zero
    // columns (e.g. the padding built by Close() when the frame step is 1),
    // and indexing an empty vector is undefined.
    std::vector<float> input_vector(input_stream.cols());
    Eigen::Map<Matrix>(input_vector.data(), 1, input_vector.size()) =
        input_stream.row(channel);

    if (!spectrogram_generators_[channel]->ComputeSpectrogram(
            input_vector, &output_vectors)) {
      return absl::Status(absl::StatusCode::kInternal,
                          "Spectrogram returned failure");
    }
    if (channel == 0) {
      // Record the number of time frames we expect from each channel.
      num_output_time_frames = output_vectors.size();
    } else {
      RET_CHECK_EQ(output_vectors.size(), num_output_time_frames)
          << "Inconsistent spectrogram time frames for channel " << channel;
    }
    // Skip remaining processing if there are too few input samples to trigger
    // any output frames.
    if (!output_vectors.empty()) {
      // Translate the returned values into a matrix of output frames.
      const int num_frames = static_cast<int>(output_vectors.size());
      OutputMatrixType output_frames(num_output_channels_, num_frames);
      for (int frame = 0; frame < num_frames; ++frame) {
        Eigen::Map<const OutputMatrixType> frame_map(
            output_vectors[frame].data(), output_vectors[frame].size(), 1);
        // The underlying dsp object returns squared magnitudes; here
        // we optionally translate to linear magnitude or dB.
        output_frames.col(frame) =
            output_scale_ * postprocess_output_fn(frame_map);
      }
      spectrogram_matrices->push_back(output_frames);
    }
  }
  // If the input is very short, there may not be enough accumulated,
  // unprocessed samples to cause any new frames to be generated by
  // the spectrogram object. If so, we don't want to emit
  // a packet at all.
  if (!spectrogram_matrices->empty()) {
    RET_CHECK_EQ(spectrogram_matrices->size(), input_stream.rows())
        << "Inconsistent number of spectrogram channels.";
    // The timestamp is computed before the frame counters are advanced.
    if (allow_multichannel_input_) {
      cc->Outputs().Index(0).Add(spectrogram_matrices.release(),
                                 CurrentOutputTimestamp(cc));
    } else {
      cc->Outputs().Index(0).Add(
          new OutputMatrixType(spectrogram_matrices->at(0)),
          CurrentOutputTimestamp(cc));
    }
    cumulative_completed_frames_ += output_vectors.size();
    last_completed_frames_ = output_vectors.size();
    if (!use_local_timestamp_) {
      // In non-local timestamp mode the timestamp of the next packet will be
      // equal to CumulativeOutputTimestamp(). Inform the framework about this
      // fact to enable packet queueing optimizations.
      cc->Outputs().Index(0).SetNextTimestampBound(CumulativeOutputTimestamp());
    }
  }
  return absl::OkStatus();
}
// Selects the per-frame post-processing that matches output_type_ and hands
// the samples to ProcessVectorToOutput:
//   COMPLEX            - complex values, unchanged
//   SQUARED_MAGNITUDE  - values unchanged
//   LINEAR_MAGNITUDE   - element-wise square root
//   DECIBELS           - kLnPowerToDb * element-wise natural log
// Returns kInvalidArgument for any other output type.
absl::Status SpectrogramCalculator::ProcessVector(const Matrix& input_stream,
                                                  CalculatorContext* cc) {
  // The unary plus turns each captureless lambda into a plain function
  // pointer, from which ProcessVectorToOutput deduces its matrix type.
  const auto keep_complex =
      +[](const Eigen::MatrixXcf& col) -> const Eigen::MatrixXcf {
    return col;
  };
  const auto keep_squared = +[](const Matrix& col) -> const Matrix {
    return col;
  };
  const auto to_linear = +[](const Matrix& col) -> const Matrix {
    return col.array().sqrt().matrix();
  };
  const auto to_decibels = +[](const Matrix& col) -> const Matrix {
    return kLnPowerToDb * col.array().log().matrix();
  };

  switch (output_type_) {
    case SpectrogramCalculatorOptions::COMPLEX:
      return ProcessVectorToOutput(input_stream, keep_complex, cc);
    case SpectrogramCalculatorOptions::SQUARED_MAGNITUDE:
      return ProcessVectorToOutput(input_stream, keep_squared, cc);
    case SpectrogramCalculatorOptions::LINEAR_MAGNITUDE:
      return ProcessVectorToOutput(input_stream, to_linear, cc);
    case SpectrogramCalculatorOptions::DECIBELS:
      return ProcessVectorToOutput(input_stream, to_decibels, cc);
    default:
      return absl::Status(absl::StatusCode::kInvalidArgument,
                          "Unrecognized spectrogram output type.");
  }
}
// Flushes the samples still buffered when the stream ends.
//
// Does nothing unless pad_final_packet_ is set and at least one sample was
// received. Otherwise a block of zeros is pushed through ProcessVector():
// enough to complete exactly one frame if less than one window of input was
// ever seen, or frame_step_samples() - 1 zeros otherwise. Returns the status
// of ProcessVector(), or OK when nothing is flushed.
absl::Status SpectrogramCalculator::Close(CalculatorContext* cc) {
  if (!pad_final_packet_ || cumulative_input_samples_ <= 0) {
    return absl::OkStatus();
  }

  const bool less_than_one_window =
      cumulative_input_samples_ < frame_duration_samples_;
  const int required_padding_samples =
      less_than_one_window
          ? frame_duration_samples_ - cumulative_input_samples_
          : frame_step_samples() - 1;

  return ProcessVector(
      Matrix::Zero(num_input_channels_, required_padding_samples), cc);
}
} // namespace mediapipe

View File

@ -1,76 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
// Options for SpectrogramCalculator: framing (window duration and overlap),
// window shape, output value type and scaling, multichannel handling, final
// padding, and the timestamping mode.
message SpectrogramCalculatorOptions {
// Registers this message as an extension of CalculatorOptions.
extend CalculatorOptions {
optional SpectrogramCalculatorOptions ext = 76186688;
}
// Options mirror those of TimeSeriesFramerCalculator.
// Analysis window duration in seconds. Required. Must be greater than 0.
// (Note: the spectrogram DFT length will be the smallest power-of-2
// sample count that can hold this duration.)
optional double frame_duration_seconds = 1;
// Duration of overlap between adjacent windows.
// Hence, frame_rate = 1/(frame_duration_seconds - frame_overlap_seconds).
// Required that 0 <= frame_overlap_seconds < frame_duration_seconds.
optional double frame_overlap_seconds = 2 [default = 0.0];
// Whether to pad the final packet with zeros. If true, guarantees that
// all input samples will be output. If set to false, any partial packet
// at the end of the stream will be dropped.
optional bool pad_final_packet = 3 [default = true];
// Output value type can be squared-magnitude, linear-magnitude,
// deciBels (dB, = 20*log10(linear_magnitude)), or std::complex.
enum OutputType {
SQUARED_MAGNITUDE = 0;
LINEAR_MAGNITUDE = 1;
DECIBELS = 2;
COMPLEX = 3;
}
optional OutputType output_type = 4 [default = SQUARED_MAGNITUDE];
// If set to true then the output will be a vector of spectrograms, one for
// each channel and the stream will have a MultiStreamTimeSeriesHeader.
optional bool allow_multichannel_input = 5 [default = false];
// Which window to use when computing the FFT.
enum WindowType {
HANN = 0;
HAMMING = 1;
COSINE = 2;
}
optional WindowType window_type = 6 [default = HANN];
// Support a fixed multiplicative scaling of the output. This is applied
// uniformly regardless of output type (i.e., even dBs are multiplied, not
// offset).
optional double output_scale = 7 [default = 1.0];
// If use_local_timestamp is true, the output packet's timestamp is based on
// the last sample of the packet and is inferred from the latest input
// packet's timestamp. If false, the output packet's timestamp is based on
// cumulative timestamping, which is inferred from the initial input
// timestamp and the cumulative number of samples.
optional bool use_local_timestamp = 8 [default = false];
}

View File

@ -1,895 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h>
#include <cmath>
#include <complex>
#include <memory>
#include <numeric>
#include <string>
#include <vector>
#include "Eigen/Core"
#include "audio/dsp/number_util.h"
#include "mediapipe/calculators/audio/spectrogram_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/benchmark.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/time_series_test_util.h"
namespace mediapipe {
namespace {
// Start time of the first input packet, in microseconds; the packet builders
// multiply it by 1e-6 to obtain the start time in seconds.
const int kInitialTimestampOffsetMicroseconds = 4;
class SpectrogramCalculatorTest
: public TimeSeriesCalculatorTest<SpectrogramCalculatorOptions> {
protected:
// Configures the fixture: a single input channel sampled at 4 kHz, routed
// through the calculator under test.
void SetUp() override {
  num_input_channels_ = 1;
  input_sample_rate_ = 4000.0;
  calculator_name_ = "SpectrogramCalculator";
}
// Derives the framing constants the checks rely on from the current options,
// then runs the test graph and returns its status.
absl::Status Run() {
  // Options are final at this point, so the frame length and the step
  // between frames can be computed in samples.
  const auto overlap_samples =
      round(options_.frame_overlap_seconds() * input_sample_rate_);
  frame_duration_samples_ =
      round(options_.frame_duration_seconds() * input_sample_rate_);
  frame_step_samples_ = frame_duration_samples_ - overlap_samples;

  // The magnitude of the 0th FFT bin (DC) is sum(input .* window). For an
  // input that is identically 1.0 this reduces to sum(window), and a Hann
  // window averages 0.5, so the expected squared magnitude in the DC bin is
  // (frame length * 0.5) squared.
  const double window_sum = static_cast<float>(frame_duration_samples_) * 0.5;
  expected_dc_squared_magnitude_ = pow(window_sum, 2.0);

  return RunGraph();
}
// Creates test multichannel input with specified packet sizes and containing
// a constant-frequency sinusoid that maintains phase between adjacent
// packets.
void SetupCosineInputPackets(const std::vector<int>& packet_sizes_samples,
float cosine_frequency_hz) {
int total_num_input_samples = 0;
for (int packet_size_samples : packet_sizes_samples) {
double packet_start_time_seconds =
kInitialTimestampOffsetMicroseconds * 1e-6 +
total_num_input_samples / input_sample_rate_;
double packet_end_time_seconds =
packet_start_time_seconds + packet_size_samples / input_sample_rate_;
double angular_freq = 2 * M_PI * cosine_frequency_hz;
Matrix* packet_data =
new Matrix(num_input_channels_, packet_size_samples);
// Use Eigen's vectorized cos() function to fill the vector with a
// sinusoid of appropriate frequency & phase.
for (int i = 0; i < num_input_channels_; i++) {
packet_data->row(i) =
Eigen::ArrayXf::LinSpaced(packet_size_samples,
packet_start_time_seconds * angular_freq,
packet_end_time_seconds * angular_freq)
.cos()
.transpose();
}
int64 input_timestamp = round(packet_start_time_seconds *
Timestamp::kTimestampUnitsPerSecond);
AppendInputPacket(packet_data, input_timestamp);
total_num_input_samples += packet_size_samples;
}
}
// Appends input packets of the given sizes in which every sample is 1.0.
// Implemented as a cosine of frequency 0 Hz, which is identically 1.0.
void SetupConstantInputPackets(const std::vector<int>& packet_sizes_samples) {
  const float kConstantSignalFrequencyHz = 0.0;
  SetupCosineInputPackets(packet_sizes_samples, kConstantSignalFrequencyHz);
}
// Setup a sequence of input packets of specified sizes, each containing a
// single sample of 1.0 at a specified offset.
void SetupImpulseInputPackets(
const std::vector<int>& packet_sizes_samples,
const std::vector<int>& impulse_offsets_samples) {
int total_num_input_samples = 0;
for (int i = 0; i < packet_sizes_samples.size(); ++i) {
double packet_start_time_seconds =
kInitialTimestampOffsetMicroseconds * 1e-6 +
total_num_input_samples / input_sample_rate_;
int64 input_timestamp = round(packet_start_time_seconds *
Timestamp::kTimestampUnitsPerSecond);
std::unique_ptr<Matrix> impulse(
new Matrix(Matrix::Zero(1, packet_sizes_samples[i])));
(*impulse)(0, impulse_offsets_samples[i]) = 1.0;
AppendInputPacket(impulse.release(), input_timestamp);
total_num_input_samples += packet_sizes_samples[i];
}
}
// Creates test multichannel input with specified packet sizes and containing
// constant input packets for the even channels and constant-frequency
// sinusoid that maintains phase between adjacent packets for the odd
// channels.
void SetupMultichannelInputPackets(
const std::vector<int>& packet_sizes_samples, float cosine_frequency_hz) {
int total_num_input_samples = 0;
for (int packet_size_samples : packet_sizes_samples) {
double packet_start_time_seconds =
kInitialTimestampOffsetMicroseconds * 1e-6 +
total_num_input_samples / input_sample_rate_;
double packet_end_time_seconds =
packet_start_time_seconds + packet_size_samples / input_sample_rate_;
double angular_freq;
Matrix* packet_data =
new Matrix(num_input_channels_, packet_size_samples);
// Use Eigen's vectorized cos() function to fill the vector with a
// sinusoid of appropriate frequency & phase.
for (int i = 0; i < num_input_channels_; i++) {
if (i % 2 == 0) {
angular_freq = 0;
} else {
angular_freq = 2 * M_PI * cosine_frequency_hz;
}
packet_data->row(i) =
Eigen::ArrayXf::LinSpaced(packet_size_samples,
packet_start_time_seconds * angular_freq,
packet_end_time_seconds * angular_freq)
.cos()
.transpose();
}
int64 input_timestamp = round(packet_start_time_seconds *
Timestamp::kTimestampUnitsPerSecond);
AppendInputPacket(packet_data, input_timestamp);
total_num_input_samples += packet_size_samples;
}
}
// Return vector of the numbers of frames in each output packet.
std::vector<int> OutputFramesPerPacket() {
std::vector<int> frame_counts;
for (const Packet& packet : output().packets) {
const Matrix& matrix = packet.Get<Matrix>();
frame_counts.push_back(matrix.cols());
}
return frame_counts;
}
// Checks output headers and Timestamps.
void CheckOutputHeadersAndTimestamps() {
const int fft_size = audio_dsp::NextPowerOfTwo(frame_duration_samples_);
TimeSeriesHeader expected_header = input().header.Get<TimeSeriesHeader>();
expected_header.set_num_channels(fft_size / 2 + 1);
// The output header sample rate should depend on the output frame step.
expected_header.set_sample_rate(input_sample_rate_ / frame_step_samples_);
// SpectrogramCalculator stores the sample rate of the input in
// the TimeSeriesHeader.
expected_header.set_audio_sample_rate(input_sample_rate_);
// We expect the output header to have num_samples and packet_rate unset.
expected_header.clear_num_samples();
expected_header.clear_packet_rate();
if (!options_.allow_multichannel_input()) {
ExpectOutputHeaderEquals(expected_header);
} else {
EXPECT_THAT(output()
.header.template Get<MultiStreamTimeSeriesHeader>()
.time_series_header(),
mediapipe::EqualsProto(expected_header));
EXPECT_THAT(output()
.header.template Get<MultiStreamTimeSeriesHeader>()
.num_streams(),
num_input_channels_);
}
int cumulative_output_frames = 0;
// The timestamps coming out of the spectrogram correspond to the
// middle of the first frame's window, hence frame_duration_samples_/2
// term. We use frame_duration_samples_ because that is how it is
// actually quantized inside spectrogram.
const double packet_timestamp_offset_seconds =
kInitialTimestampOffsetMicroseconds * 1e-6;
const double frame_step_seconds = frame_step_samples_ / input_sample_rate_;
Timestamp initial_timestamp = Timestamp::Unstarted();
for (const Packet& packet : output().packets) {
// This is the timestamp we expect based on how the spectrogram should
// behave (advancing by one step's worth of input samples each frame).
const double expected_timestamp_seconds =
packet_timestamp_offset_seconds +
cumulative_output_frames * frame_step_seconds;
const int64 expected_timestamp_ticks =
expected_timestamp_seconds * Timestamp::kTimestampUnitsPerSecond;
EXPECT_EQ(expected_timestamp_ticks, packet.Timestamp().Value());
// Accept the timestamp of the first packet as the baseline for checking
// the remainder.
if (initial_timestamp == Timestamp::Unstarted()) {
initial_timestamp = packet.Timestamp();
}
// Also check that the timestamp is consistent with the sample_rate
// in the output stream's TimeSeriesHeader.
EXPECT_TRUE(time_series_util::LogWarningIfTimestampIsInconsistent(
packet.Timestamp(), initial_timestamp, cumulative_output_frames,
expected_header.sample_rate()));
if (!options_.allow_multichannel_input()) {
if (options_.output_type() == SpectrogramCalculatorOptions::COMPLEX) {
const Eigen::MatrixXcf& matrix = packet.Get<Eigen::MatrixXcf>();
cumulative_output_frames += matrix.cols();
} else {
const Matrix& matrix = packet.Get<Matrix>();
cumulative_output_frames += matrix.cols();
}
} else {
if (options_.output_type() == SpectrogramCalculatorOptions::COMPLEX) {
const Eigen::MatrixXcf& matrix =
packet.Get<std::vector<Eigen::MatrixXcf>>().at(0);
cumulative_output_frames += matrix.cols();
} else {
const Matrix& matrix = packet.Get<std::vector<Matrix>>().at(0);
cumulative_output_frames += matrix.cols();
}
}
}
}
// Verify that the bin corresponding to the specified frequency
// is the largest one in one particular frame of a single packet.
void CheckPeakFrequencyInPacketFrame(const Packet& packet, int frame,
float frequency) {
const int fft_size = audio_dsp::NextPowerOfTwo(frame_duration_samples_);
const int target_bin =
round((frequency / input_sample_rate_) * static_cast<float>(fft_size));
const Matrix& matrix = packet.Get<Matrix>();
// Stop here if the requested frame is not in this packet.
ASSERT_GT(matrix.cols(), frame);
int actual_largest_bin;
matrix.col(frame).maxCoeff(&actual_largest_bin);
EXPECT_EQ(actual_largest_bin, target_bin);
}
// Verifies that, in column `frame` of the spectrogram `matrix`, the row with
// the largest value is the bin that `frequency` maps to.
void CheckPeakFrequencyInMatrix(const Matrix& matrix, int frame,
                                float frequency) {
  // Map the frequency to a bin index: the fraction of the sample rate times
  // the FFT size, rounded to the nearest bin.
  const int fft_size = audio_dsp::NextPowerOfTwo(frame_duration_samples_);
  const auto fraction_of_sample_rate = frequency / input_sample_rate_;
  const int expected_bin =
      round(fraction_of_sample_rate * static_cast<float>(fft_size));
  // Give up (fatal failure) if the matrix has no such frame.
  ASSERT_GT(matrix.cols(), frame);
  int peak_bin;
  matrix.col(frame).maxCoeff(&peak_bin);
  EXPECT_EQ(peak_bin, expected_bin);
}
// Frame (window) length in samples. The checks in this fixture use its next
// power of two as the FFT size.
int frame_duration_samples_;
// Step between successive output frames, in input samples. The expected
// output sample rate is input_sample_rate_ / frame_step_samples_.
int frame_step_samples_;
// Expected DC output for a window of pure 1.0, set when window length
// is set.
float expected_dc_squared_magnitude_;
};
TEST_F(SpectrogramCalculatorTest, IntegerFrameDurationNoOverlap) {
  // 100-sample frames with no overlap: 500 samples give 5 frames and the
  // following 200 samples give 2 more.
  const std::vector<int> packet_sizes = {500, 200};
  const std::vector<int> expected_frames_per_packet = {5, 2};

  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(0.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MP_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames_per_packet);
}
TEST_F(SpectrogramCalculatorTest, IntegerFrameDurationSomeOverlap) {
  // Window of 100 samples overlapping by 60, i.e. a step of 40 samples.
  // Complete frames = 1 + floor((input_samples - window_length) / step):
  //   first packet:  1 + floor((500 - 100) / 40) = 11
  //   whole stream:  1 + floor((700 - 100) / 40) = 16
  // which leaves 16 - 11 = 5 frames for the second packet.
  const std::vector<int> packet_sizes = {500, 200};
  const std::vector<int> expected_frames_per_packet = {11, 5};

  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MP_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames_per_packet);
}
TEST_F(SpectrogramCalculatorTest, NonintegerFrameDurationAndOverlap) {
  // The frame duration is expected to round to 99 samples and the frame step
  // to come out as 99 - 58 = 41 samples, so:
  //   first packet:  1 + floor((500 - 99) / 41) = 10 frames
  //   whole stream:  1 + floor((700 - 99) / 41) = 15 frames
  // which leaves 5 frames for the second packet.
  const std::vector<int> packet_sizes = {500, 200};
  const std::vector<int> expected_frames_per_packet = {10, 5};

  options_.set_frame_duration_seconds(98.5 / input_sample_rate_);
  options_.set_frame_overlap_seconds(58.4 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MP_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames_per_packet);
}
TEST_F(SpectrogramCalculatorTest, ShortInitialPacketNoOverlap) {
  // With 100-sample frames the first 90-sample packet cannot complete a
  // frame. An empty output packet would violate timestamp monotonicity, so
  // none is expected for it; only the second and third input packets
  // produce output packets.
  const std::vector<int> packet_sizes = {90, 100, 110};
  const std::vector<int> expected_frames_per_packet = {1, 2};

  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(0.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MP_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames_per_packet);
}
TEST_F(SpectrogramCalculatorTest, TrailingSamplesNoPad) {
  // Window 100, step 40, 230 samples in total: 4 complete frames fit and the
  // samples left over are not padded into an extra frame.
  const std::vector<int> packet_sizes = {140, 90};
  const std::vector<int> expected_frames_per_packet = {2, 2};

  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MP_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames_per_packet);
}
TEST_F(SpectrogramCalculatorTest, NoTrailingSamplesWithPad) {
  // Window 100, step 40, 220 samples in total: the fourth frame ends exactly
  // on the last sample, so enabling padding adds no extra frame.
  const std::vector<int> packet_sizes = {140, 80};
  const std::vector<int> expected_frames_per_packet = {2, 2};

  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(true);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MP_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames_per_packet);
}
TEST_F(SpectrogramCalculatorTest, TrailingSamplesWithPad) {
  // Unlike NoTrailingSamplesWithPad and TrailingSamplesNoPad, there are
  // leftover samples AND padding is enabled, so one extra frame is expected
  // in an additional final packet.
  const std::vector<int> packet_sizes = {140, 90};
  const std::vector<int> expected_frames_per_packet = {2, 2, 1};

  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(true);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MP_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames_per_packet);
}
TEST_F(SpectrogramCalculatorTest, VeryShortInputWillPad) {
  // 30 samples are fewer than one 100-sample window, but with padding
  // enabled a single frame is still expected.
  const std::vector<int> packet_sizes = {30};
  const std::vector<int> expected_frames_per_packet = {1};

  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(true);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MP_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames_per_packet);
}
TEST_F(SpectrogramCalculatorTest, VeryShortInputZeroOutputFramesIfNoPad) {
  // 90 samples are fewer than one 100-sample window and padding is off, so
  // no output packets are expected at all.
  const std::vector<int> packet_sizes = {90};
  const std::vector<int> expected_frames_per_packet = {};

  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MP_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames_per_packet);
}
// A constant (DC) input must put the spectral peak in bin 0 of every frame.
TEST_F(SpectrogramCalculatorTest, DCSignalIsPeakBin) {
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  const std::vector<int> input_packet_sizes = {140};  // Gives 2 output frames.
  InitializeGraph();
  FillInputHeader();
  // Setup packets with DC input (non-zero constant value).
  SetupConstantInputPackets(input_packet_sizes);
  MP_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  // Fail cleanly, rather than index an empty vector, if nothing was emitted.
  ASSERT_FALSE(output().packets.empty());
  const float dc_frequency_hz = 0.0;
  CheckPeakFrequencyInPacketFrame(output().packets[0], 0, dc_frequency_hz);
  CheckPeakFrequencyInPacketFrame(output().packets[0], 1, dc_frequency_hz);
}
// A 440 Hz cosine must put the spectral peak in the 440 Hz bin of every
// frame of the first output packet.
TEST_F(SpectrogramCalculatorTest, A440ToneIsPeakBin) {
  const std::vector<int> input_packet_sizes = {
      460};  // 100 + 9*40 for 10 frames.
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  InitializeGraph();
  FillInputHeader();
  const float tone_frequency_hz = 440.0;
  SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz);
  MP_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  // Fail cleanly, rather than index an empty vector, if nothing was emitted.
  ASSERT_FALSE(output().packets.empty());
  const int num_output_frames = output().packets[0].Get<Matrix>().cols();
  for (int frame = 0; frame < num_output_frames; ++frame) {
    CheckPeakFrequencyInPacketFrame(output().packets[0], frame,
                                    tone_frequency_hz);
  }
}
// With SQUARED_MAGNITUDE output, the DC bin of the first frame of a constant
// input must equal the expected DC squared magnitude.
TEST_F(SpectrogramCalculatorTest, SquaredMagnitudeOutputLooksRight) {
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_output_type(SpectrogramCalculatorOptions::SQUARED_MAGNITUDE);
  const std::vector<int> input_packet_sizes = {140};
  InitializeGraph();
  FillInputHeader();
  // Setup packets with DC input (non-zero constant value).
  SetupConstantInputPackets(input_packet_sizes);
  MP_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  // Fail cleanly, rather than index an empty vector, if nothing was emitted.
  ASSERT_FALSE(output().packets.empty());
  EXPECT_FLOAT_EQ(output().packets[0].Get<Matrix>()(0, 0),
                  expected_dc_squared_magnitude_);
}
// Leaving output_type unset must behave like SQUARED_MAGNITUDE: the DC bin
// of a constant input equals the expected DC squared magnitude.
TEST_F(SpectrogramCalculatorTest, DefaultOutputIsSquaredMagnitude) {
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  // Let the output_type be its default
  const std::vector<int> input_packet_sizes = {140};
  InitializeGraph();
  FillInputHeader();
  // Setup packets with DC input (non-zero constant value).
  SetupConstantInputPackets(input_packet_sizes);
  MP_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  // Fail cleanly, rather than index an empty vector, if nothing was emitted.
  ASSERT_FALSE(output().packets.empty());
  EXPECT_FLOAT_EQ(output().packets[0].Get<Matrix>()(0, 0),
                  expected_dc_squared_magnitude_);
}
// With LINEAR_MAGNITUDE output, the DC bin of a constant input must equal
// the square root of the expected DC squared magnitude.
TEST_F(SpectrogramCalculatorTest, LinearMagnitudeOutputLooksRight) {
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_output_type(SpectrogramCalculatorOptions::LINEAR_MAGNITUDE);
  const std::vector<int> input_packet_sizes = {140};
  InitializeGraph();
  FillInputHeader();
  // Setup packets with DC input (non-zero constant value).
  SetupConstantInputPackets(input_packet_sizes);
  MP_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  // Fail cleanly, rather than index an empty vector, if nothing was emitted.
  ASSERT_FALSE(output().packets.empty());
  EXPECT_FLOAT_EQ(output().packets[0].Get<Matrix>()(0, 0),
                  std::sqrt(expected_dc_squared_magnitude_));
}
// With DECIBELS output, the DC bin of a constant input must equal
// 10 * log10 of the expected DC squared magnitude.
TEST_F(SpectrogramCalculatorTest, DbMagnitudeOutputLooksRight) {
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_output_type(SpectrogramCalculatorOptions::DECIBELS);
  const std::vector<int> input_packet_sizes = {140};
  InitializeGraph();
  FillInputHeader();
  // Setup packets with DC input (non-zero constant value).
  SetupConstantInputPackets(input_packet_sizes);
  MP_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  // Fail cleanly, rather than index an empty vector, if nothing was emitted.
  ASSERT_FALSE(output().packets.empty());
  EXPECT_FLOAT_EQ(output().packets[0].Get<Matrix>()(0, 0),
                  10.0 * std::log10(expected_dc_squared_magnitude_));
}
// The output_scale option must multiply the output: with DECIBELS output the
// DC bin of a constant input equals output_scale * 10 * log10 of the
// expected DC squared magnitude.
TEST_F(SpectrogramCalculatorTest, OutputScalingLooksRight) {
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_output_type(SpectrogramCalculatorOptions::DECIBELS);
  const double output_scale = 2.5;
  options_.set_output_scale(output_scale);
  const std::vector<int> input_packet_sizes = {140};
  InitializeGraph();
  FillInputHeader();
  // Setup packets with DC input (non-zero constant value).
  SetupConstantInputPackets(input_packet_sizes);
  MP_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  // Fail cleanly, rather than index an empty vector, if nothing was emitted.
  ASSERT_FALSE(output().packets.empty());
  EXPECT_FLOAT_EQ(
      output().packets[0].Get<Matrix>()(0, 0),
      output_scale * 10.0 * std::log10(expected_dc_squared_magnitude_));
}
// With COMPLEX output, the squared modulus (std::norm) of the DC bin of a
// constant input must equal the expected DC squared magnitude.
TEST_F(SpectrogramCalculatorTest, ComplexOutputLooksRight) {
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_output_type(SpectrogramCalculatorOptions::COMPLEX);
  const std::vector<int> input_packet_sizes = {140};
  InitializeGraph();
  FillInputHeader();
  // Setup packets with DC input (non-zero constant value).
  SetupConstantInputPackets(input_packet_sizes);
  MP_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  // Fail cleanly, rather than index an empty vector, if nothing was emitted.
  ASSERT_FALSE(output().packets.empty());
  EXPECT_FLOAT_EQ(std::norm(output().packets[0].Get<Eigen::MatrixXcf>()(0, 0)),
                  expected_dc_squared_magnitude_);
}
// Feeds two one-frame packets, each holding a single unit impulse, with the
// second impulse one sample later than the first, and checks the power and
// phase of the complex output.
TEST_F(SpectrogramCalculatorTest, ComplexOutputLooksRightForImpulses) {
  const int frame_size_samples = 100;
  options_.set_frame_duration_seconds(frame_size_samples / input_sample_rate_);
  options_.set_frame_overlap_seconds(0.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  options_.set_output_type(SpectrogramCalculatorOptions::COMPLEX);
  const std::vector<int> input_packet_sizes = {frame_size_samples,
                                               frame_size_samples};
  const std::vector<int> input_packet_impulse_offsets = {49, 50};
  InitializeGraph();
  FillInputHeader();
  // Make two impulse packets offset one sample from each other
  SetupImpulseInputPackets(input_packet_sizes, input_packet_impulse_offsets);
  MP_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  // Both output packets are indexed below; fail cleanly if either is
  // missing instead of reading past the end of the vector.
  ASSERT_GE(output().packets.size(), 2u);
  const int num_buckets =
      (audio_dsp::NextPowerOfTwo(frame_size_samples) / 2) + 1;
  const float precision = 0.01f;
  auto norm_fn = [](const std::complex<float>& cf) { return std::norm(cf); };
  // Both impulses should have (approximately) constant power across all
  // frequency bins
  EXPECT_TRUE(output()
                  .packets[0]
                  .Get<Eigen::MatrixXcf>()
                  .unaryExpr(norm_fn)
                  .isApproxToConstant(1.0f, precision));
  EXPECT_TRUE(output()
                  .packets[1]
                  .Get<Eigen::MatrixXcf>()
                  .unaryExpr(norm_fn)
                  .isApproxToConstant(1.0f, precision));
  // Because the second Packet's impulse is delayed by exactly one sample with
  // respect to the first Packet's impulse, the second impulse should have
  // greater phase, and in the highest frequency bin, the real part should
  // (approximately) flip sign from the first Packet to the second
  EXPECT_LT(std::arg(output().packets[0].Get<Eigen::MatrixXcf>()(1, 0)),
            std::arg(output().packets[1].Get<Eigen::MatrixXcf>()(1, 0)));
  const float highest_bucket_real_ratio =
      output().packets[0].Get<Eigen::MatrixXcf>()(num_buckets - 1, 0).real() /
      output().packets[1].Get<Eigen::MatrixXcf>()(num_buckets - 1, 0).real();
  EXPECT_NEAR(highest_bucket_real_ratio, -1.0f, precision);
}
// A tone centred on a non-DC bin must show roughly a quarter of the DC
// power in that bin.
TEST_F(SpectrogramCalculatorTest, SquaredMagnitudeOutputLooksRightForNonDC) {
  const int frame_size_samples = 100;
  options_.set_frame_duration_seconds(frame_size_samples / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_output_type(SpectrogramCalculatorOptions::SQUARED_MAGNITUDE);
  const std::vector<int> input_packet_sizes = {140};
  InitializeGraph();
  FillInputHeader();
  // Make the tone have an integral number of cycles within the window
  const int target_bin = 16;
  const int fft_size = audio_dsp::NextPowerOfTwo(frame_size_samples);
  const float tone_frequency_hz = target_bin * (input_sample_rate_ / fft_size);
  SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz);
  MP_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  // Fail cleanly, rather than index an empty vector, if nothing was emitted.
  ASSERT_FALSE(output().packets.empty());
  // For a non-DC bin, the magnitude will be split between positive and
  // negative frequency bins, so it should about be half-magnitude
  // = quarter-power.
  // It's not quite exact because of the interference from the hann(100)
  // spread from the negative-frequency half.
  EXPECT_GT(output().packets[0].Get<Matrix>()(target_bin, 0),
            0.98 * expected_dc_squared_magnitude_ / 4.0);
  EXPECT_LT(output().packets[0].Get<Matrix>()(target_bin, 0),
            1.02 * expected_dc_squared_magnitude_ / 4.0);
}
TEST_F(SpectrogramCalculatorTest, ZeroOutputsForZeroInputsWithPaddingEnabled) {
  // No input packets at all: even with padding enabled, no output packets
  // are expected.
  const std::vector<int> packet_sizes = {};
  const std::vector<int> expected_frames_per_packet = {};

  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(true);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MP_ASSERT_OK(Run());

  CheckOutputHeadersAndTimestamps();
  EXPECT_EQ(OutputFramesPerPacket(), expected_frames_per_packet);
}
// With multichannel input allowed, the first output packet must hold one
// spectrogram per input channel.
TEST_F(SpectrogramCalculatorTest, NumChannelsIsRight) {
  const std::vector<int> input_packet_sizes = {460};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  options_.set_allow_multichannel_input(true);
  num_input_channels_ = 3;
  InitializeGraph();
  FillInputHeader();
  const float tone_frequency_hz = 440.0;
  SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz);
  MP_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  // Fail cleanly, rather than index an empty vector, if nothing was emitted.
  ASSERT_FALSE(output().packets.empty());
  EXPECT_EQ(output().packets[0].Get<std::vector<Matrix>>().size(),
            num_input_channels_);
}
TEST_F(SpectrogramCalculatorTest, NumSamplesAndPacketRateAreCleared) {
  // Set the fixture's num_input_samples_ and input_packet_rate_ before the
  // input header is filled in; the output header must carry neither field.
  num_input_samples_ = 500;
  input_packet_rate_ = 1.0;
  const std::vector<int> packet_sizes = {num_input_samples_};

  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(0.0);
  options_.set_pad_final_packet(false);
  InitializeGraph();
  FillInputHeader();
  SetupConstantInputPackets(packet_sizes);
  MP_ASSERT_OK(Run());

  const TimeSeriesHeader& header_out = output().header.Get<TimeSeriesHeader>();
  EXPECT_FALSE(header_out.has_num_samples());
  EXPECT_FALSE(header_out.has_packet_rate());
}
// With multichannel input, every channel's spectrogram in the first output
// packet must have the same number of rows and columns.
TEST_F(SpectrogramCalculatorTest, MultichannelSpectrogramSizesAreRight) {
  const std::vector<int> input_packet_sizes = {420};  // less than 10 frames
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  options_.set_allow_multichannel_input(true);
  num_input_channels_ = 10;
  InitializeGraph();
  FillInputHeader();
  const float tone_frequency_hz = 440.0;
  SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz);
  MP_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  // Fail cleanly, rather than index an empty vector, if nothing was emitted.
  ASSERT_FALSE(output().packets.empty());
  // Bind by reference: copying the vector of matrices is unnecessary.
  const auto& spectrograms = output().packets[0].Get<std::vector<Matrix>>();
  // The channel count is an integer, so compare it exactly. The check is
  // fatal because the loop below indexes spectrograms by channel.
  ASSERT_EQ(spectrograms.size(), num_input_channels_);
  const int spectrogram_num_rows = spectrograms[0].rows();
  const int spectrogram_num_cols = spectrograms[0].cols();
  for (int i = 1; i < num_input_channels_; i++) {
    EXPECT_EQ(spectrogram_num_rows, spectrograms[i].rows());
    EXPECT_EQ(spectrogram_num_cols, spectrograms[i].cols());
  }
}
// Feeds constant signals on even channels and a 440 Hz tone on odd channels,
// then checks that each channel's spectrogram peaks in the matching bin in
// every frame.
TEST_F(SpectrogramCalculatorTest, MultichannelSpectrogramValuesAreRight) {
  const std::vector<int> input_packet_sizes = {
      460};  // 100 + 9*40 for 10 frames.
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_allow_multichannel_input(true);
  num_input_channels_ = 10;
  InitializeGraph();
  FillInputHeader();
  const float tone_frequency_hz = 440.0;
  SetupMultichannelInputPackets(input_packet_sizes, tone_frequency_hz);
  MP_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  // Fail cleanly, rather than index an empty vector, if nothing was emitted.
  ASSERT_FALSE(output().packets.empty());
  // Bind by reference: copying the vector of matrices is unnecessary.
  const auto& spectrograms = output().packets[0].Get<std::vector<Matrix>>();
  // The loop below indexes one spectrogram per input channel.
  ASSERT_EQ(spectrograms.size(), num_input_channels_);
  const int num_output_frames = spectrograms[0].cols();
  for (int i = 0; i < num_input_channels_; i++) {
    // Even channels were fed DC, odd channels the tone.
    const float expected_peak_hz = (i % 2 == 0) ? 0 : tone_frequency_hz;
    for (int frame = 0; frame < num_output_frames; ++frame) {
      CheckPeakFrequencyInMatrix(spectrograms[i], frame, expected_peak_hz);
    }
  }
}
TEST_F(SpectrogramCalculatorTest, MultichannelHandlesShortInitialPacket) {
  // First packet is less than one frame, but second packet should trigger a
  // complete frame from all channels.
  const std::vector<int> input_packet_sizes = {50, 50};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  options_.set_allow_multichannel_input(true);
  num_input_channels_ = 2;
  InitializeGraph();
  FillInputHeader();
  const float tone_frequency_hz = 440.0;
  SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz);
  MP_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  // Fail cleanly, rather than index an empty vector, if nothing was emitted.
  ASSERT_FALSE(output().packets.empty());
  // Bind by reference: copying the vector of matrices is unnecessary.
  const auto& spectrograms = output().packets[0].Get<std::vector<Matrix>>();
  // The channel count is an integer, so compare it exactly. The check is
  // fatal because the loop below indexes spectrograms by channel.
  ASSERT_EQ(spectrograms.size(), num_input_channels_);
  const int spectrogram_num_rows = spectrograms[0].rows();
  const int spectrogram_num_cols = spectrograms[0].cols();
  for (int i = 1; i < num_input_channels_; i++) {
    EXPECT_EQ(spectrogram_num_rows, spectrograms[i].rows());
    EXPECT_EQ(spectrogram_num_cols, spectrograms[i].cols());
  }
}
TEST_F(SpectrogramCalculatorTest,
       MultichannelComplexHandlesShortInitialPacket) {
  // First packet is less than one frame, but second packet should trigger a
  // complete frame from all channels, even for complex output.
  const std::vector<int> input_packet_sizes = {50, 50};
  options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
  options_.set_pad_final_packet(false);
  options_.set_allow_multichannel_input(true);
  options_.set_output_type(SpectrogramCalculatorOptions::COMPLEX);
  num_input_channels_ = 2;
  InitializeGraph();
  FillInputHeader();
  const float tone_frequency_hz = 440.0;
  SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz);
  MP_ASSERT_OK(Run());
  CheckOutputHeadersAndTimestamps();
  // Fail cleanly, rather than index an empty vector, if nothing was emitted.
  ASSERT_FALSE(output().packets.empty());
  // Bind by reference: copying the vector of matrices is unnecessary.
  const auto& spectrograms =
      output().packets[0].Get<std::vector<Eigen::MatrixXcf>>();
  // The channel count is an integer, so compare it exactly. The check is
  // fatal because the loop below indexes spectrograms by channel.
  ASSERT_EQ(spectrograms.size(), num_input_channels_);
  const int spectrogram_num_rows = spectrograms[0].rows();
  const int spectrogram_num_cols = spectrograms[0].cols();
  for (int i = 1; i < num_input_channels_; i++) {
    EXPECT_EQ(spectrogram_num_rows, spectrograms[i].rows());
    EXPECT_EQ(spectrogram_num_cols, spectrograms[i].cols());
  }
}
// Benchmarks SpectrogramCalculator on one single-channel packet of constant
// 1.0 samples: 1,600,000 samples at 16 kHz (100 s of audio), cut into
// non-overlapping 10 ms frames. The shape of the resulting spectrogram and
// the first few values of its first column are logged afterwards.
void BM_ProcessDC(benchmark::State& state) {
  CalculatorGraphConfig::Node node_config;
  node_config.set_calculator("SpectrogramCalculator");
  node_config.add_input_stream("input_audio");
  node_config.add_output_stream("output_spectrogram");
  // MutableExtension returns a pointer into node_config, so the setters
  // below configure the node in place; no copy back is required.
  SpectrogramCalculatorOptions* options =
      node_config.mutable_options()->MutableExtension(
          SpectrogramCalculatorOptions::ext);
  options->set_frame_duration_seconds(0.010);
  options->set_frame_overlap_seconds(0.0);
  options->set_pad_final_packet(false);

  const int num_input_channels = 1;
  const int packet_size_samples = 1600000;
  TimeSeriesHeader* header = new TimeSeriesHeader();
  header->set_sample_rate(16000.0);
  header->set_num_channels(num_input_channels);

  CalculatorRunner runner(node_config);
  runner.MutableInputs()->Index(0).header = Adopt(header);
  Matrix* payload = new Matrix(
      Matrix::Constant(num_input_channels, packet_size_samples, 1.0));
  Timestamp timestamp = Timestamp(0);
  runner.MutableInputs()->Index(0).packets.push_back(
      Adopt(payload).At(timestamp));

  for (auto _ : state) {
    ASSERT_TRUE(runner.Run().ok());
  }

  const CalculatorRunner::StreamContents& output = runner.Outputs().Index(0);
  // Fail cleanly, rather than index an empty vector, if nothing was emitted.
  ASSERT_FALSE(output.packets.empty());
  const Matrix& output_matrix = output.packets[0].Get<Matrix>();
  LOG(INFO) << "Output matrix=" << output_matrix.rows() << "x"
            << output_matrix.cols();
  LOG(INFO) << "First values=" << output_matrix(0, 0) << ", "
            << output_matrix(1, 0) << ", " << output_matrix(2, 0) << ", "
            << output_matrix(3, 0);
}
BENCHMARK(BM_ProcessDC);
} // anonymous namespace
} // namespace mediapipe

View File

@ -1,99 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Defines StabilizedLogCalculator.
#include <cmath>
#include <memory>
#include <string>
#include "mediapipe/calculators/audio/stabilized_log_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/util/time_series_util.h"
namespace mediapipe {
// Example config:
// node {
// calculator: "StabilizedLogCalculator"
// input_stream: "input_time_series"
// output_stream: "stabilized_log_time_series"
// options {
// [mediapipe.StabilizedLogCalculatorOptions.ext] {
// stabilizer: .00001
// check_nonnegativity: true
// }
// }
// }
// Computes output_scale * log(x + stabilizer) element-wise on every input
// Matrix and emits the result at the input timestamp. If the input stream
// has a header, it is propagated to the output stream.
class StabilizedLogCalculator : public CalculatorBase {
 public:
  // Declares one Matrix input stream and one Matrix output stream.
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<Matrix>(
        // Input stream with TimeSeriesHeader.
    );
    cc->Outputs().Index(0).Set<Matrix>(
        // Output stabilized log stream with TimeSeriesHeader.
    );
    return absl::OkStatus();
  }

  // Reads the options and, when present, copies the input header to the
  // output. Returns an error if the input header is not a valid
  // TimeSeriesHeader; CHECK-fails if stabilizer is negative.
  absl::Status Open(CalculatorContext* cc) override {
    // Bind by reference: the options proto does not need to be copied.
    const StabilizedLogCalculatorOptions& options =
        cc->Options<StabilizedLogCalculatorOptions>();
    stabilizer_ = options.stabilizer();
    output_scale_ = options.output_scale();
    check_nonnegativity_ = options.check_nonnegativity();
    CHECK_GE(stabilizer_, 0.0)
        << "stabilizer must be >= 0.0, received a value of " << stabilizer_;
    // If the input packets have a header, propagate the header to the output.
    if (!cc->Inputs().Index(0).Header().IsEmpty()) {
      TimeSeriesHeader input_header;
      MP_RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
          cc->Inputs().Index(0).Header(), &input_header));
      cc->Outputs().Index(0).SetHeader(
          Adopt(new TimeSeriesHeader(input_header)));
    }
    return absl::OkStatus();
  }

  // Emits output_scale * log(input + stabilizer) for the current packet.
  // Returns InvalidArgumentError if the input contains NaN and, when
  // check_nonnegativity is set, OutOfRangeError if it contains a negative
  // value.
  absl::Status Process(CalculatorContext* cc) override {
    // Bind by reference so the input matrix is not copied for every packet.
    const Matrix& input_matrix = cc->Inputs().Index(0).Get<Matrix>();
    if (input_matrix.array().isNaN().any()) {
      return absl::InvalidArgumentError("NaN input to log operation.");
    }
    // minCoeff() is not meaningful on an empty matrix, so only evaluate it
    // when there is at least one coefficient.
    if (check_nonnegativity_ && input_matrix.size() > 0 &&
        input_matrix.minCoeff() < 0.0) {
      return absl::OutOfRangeError("Negative input to log operation.");
    }
    std::unique_ptr<Matrix> output_frame(new Matrix(
        output_scale_ * (input_matrix.array() + stabilizer_).log().matrix()));
    cc->Outputs().Index(0).Add(output_frame.release(), cc->InputTimestamp());
    return absl::OkStatus();
  }

 private:
  // All three are overwritten from the calculator options in Open(); the
  // initializers only keep the members from ever being indeterminate.
  float stabilizer_ = 0.0f;
  bool check_nonnegativity_ = true;
  double output_scale_ = 1.0;
};
REGISTER_CALCULATOR(StabilizedLogCalculator);
} // namespace mediapipe

View File

@ -1,37 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
// Options for StabilizedLogCalculator.
message StabilizedLogCalculatorOptions {
extend CalculatorOptions {
optional StabilizedLogCalculatorOptions ext = 101978339;
}
// The calculator computes log(x + stabilizer). stabilizer must be >=
// 0, with 0 indicating a lack of stabilization.
optional float stabilizer = 1 [default = .00001];
// If true, the calculator verifies that all input values are >= 0 and
// fails with an error when it finds a negative one. If false, the
// code will take the log of the potentially negative input values
// plus the stabilizer.
optional bool check_nonnegativity = 2 [default = true];
// Support a fixed multiplicative scaling of the output.
optional double output_scale = 3 [default = 1.0];
}

View File

@ -1,141 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cmath>
#include "Eigen/Core"
#include "mediapipe/calculators/audio/stabilized_log_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/util/time_series_test_util.h"
namespace mediapipe {
const float kStabilizer = 0.1;
const int kNumChannels = 3;
const int kNumSamples = 10;
// Fixture that points the shared time-series test harness at
// StabilizedLogCalculator and fixes the stabilizer and input stream shape
// used by every test below.
class StabilizedLogCalculatorTest
    : public TimeSeriesCalculatorTest<StabilizedLogCalculatorOptions> {
 protected:
  void SetUp() override {
    // Shape and rate of the input stream.
    num_input_samples_ = kNumSamples;
    num_input_channels_ = kNumChannels;
    input_sample_rate_ = 8000.0;
    // Calculator under test and its options.
    options_.set_stabilizer(kStabilizer);
    calculator_name_ = "StabilizedLogCalculator";
  }

  // Runs the graph and fails the current test if it does not succeed.
  void RunGraphNoReturn() { MP_ASSERT_OK(RunGraph()); }
};
// Sends several random nonnegative packets through the calculator and expects
// each output to equal log(input + kStabilizer), with the input header
// forwarded unchanged.
TEST_F(StabilizedLogCalculatorTest, BasicOperation) {
  const int kNumPackets = 5;
  InitializeGraph();
  FillInputHeader();

  std::vector<Matrix> sent_matrices;
  for (int index = 0; index < kNumPackets; ++index) {
    const int64 timestamp = index * Timestamp::kTimestampUnitsPerSecond;
    sent_matrices.emplace_back(
        Matrix::Random(kNumChannels, kNumSamples).array().abs());
    AppendInputPacket(new Matrix(sent_matrices.back()), timestamp);
  }

  MP_ASSERT_OK(RunGraph());
  ExpectOutputHeaderEqualsInputHeader();

  const auto& received_packets = runner_->Outputs().Index(0).packets;
  for (int index = 0; index < kNumPackets; ++index) {
    ExpectApproximatelyEqual(
        (sent_matrices[index].array() + kStabilizer).log(),
        received_packets[index].Get<Matrix>());
  }
}
// Same as the basic test, but with output_scale set: each output must equal
// scale * log(input + kStabilizer).
TEST_F(StabilizedLogCalculatorTest, OutputScaleWorks) {
  const int kNumPackets = 5;
  double scale = 2.5;
  options_.set_output_scale(scale);
  InitializeGraph();
  FillInputHeader();

  std::vector<Matrix> sent_matrices;
  for (int index = 0; index < kNumPackets; ++index) {
    const int64 timestamp = index * Timestamp::kTimestampUnitsPerSecond;
    sent_matrices.emplace_back(
        Matrix::Random(kNumChannels, kNumSamples).array().abs());
    AppendInputPacket(new Matrix(sent_matrices.back()), timestamp);
  }

  MP_ASSERT_OK(RunGraph());
  ExpectOutputHeaderEqualsInputHeader();

  const auto& received_packets = runner_->Outputs().Index(0).packets;
  for (int index = 0; index < kNumPackets; ++index) {
    ExpectApproximatelyEqual(
        scale * ((sent_matrices[index].array() + kStabilizer).log()),
        received_packets[index].Get<Matrix>());
  }
}
// An all-zero input must yield log(kStabilizer) in every coefficient.
TEST_F(StabilizedLogCalculatorTest, ZerosAreStabilized) {
  InitializeGraph();
  FillInputHeader();
  Matrix* all_zeros = new Matrix(Matrix::Zero(kNumChannels, kNumSamples));
  AppendInputPacket(all_zeros, 0 /* timestamp */);

  MP_ASSERT_OK(RunGraph());
  ExpectOutputHeaderEqualsInputHeader();

  const auto& first_output = runner_->Outputs().Index(0).packets[0];
  ExpectApproximatelyEqual(
      Matrix::Constant(kNumChannels, kNumSamples, kStabilizer).array().log(),
      first_output.Get<Matrix>());
}
// A packet filled with NaN must make the graph run fail.
TEST_F(StabilizedLogCalculatorTest, NanValuesReturnError) {
  InitializeGraph();
  FillInputHeader();
  Matrix* all_nan =
      new Matrix(Matrix::Constant(kNumChannels, kNumSamples, std::nanf("")));
  AppendInputPacket(all_nan, 0 /* timestamp */);
  const absl::Status run_status = RunGraph();
  ASSERT_FALSE(run_status.ok());
}
// With the default options, a packet of negative values must make the graph
// run fail.
TEST_F(StabilizedLogCalculatorTest, NegativeValuesReturnError) {
  InitializeGraph();
  FillInputHeader();
  Matrix* all_negative =
      new Matrix(Matrix::Constant(kNumChannels, kNumSamples, -1.0));
  AppendInputPacket(all_negative, 0 /* timestamp */);
  const absl::Status run_status = RunGraph();
  ASSERT_FALSE(run_status.ok());
}
// With check_nonnegativity disabled, negative input must not fail the run.
// The output values themselves are undefined and are not inspected.
TEST_F(StabilizedLogCalculatorTest, NegativeValuesDoNotCheckFailIfCheckIsOff) {
  options_.set_check_nonnegativity(false);
  InitializeGraph();
  FillInputHeader();
  Matrix* all_negative =
      new Matrix(Matrix::Constant(kNumChannels, kNumSamples, -1.0));
  AppendInputPacket(all_negative, 0 /* timestamp */);
  MP_ASSERT_OK(RunGraph());
}
} // namespace mediapipe

View File

@ -1,27 +0,0 @@
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
licenses(["notice"])
# Bundles the audio fixture files below under a single publicly visible
# target so that other packages can depend on them as data.
# NOTE(review): the file names suggest 2-second 1 kHz sine waves in various
# encodings and sample rates -- confirm against the actual files.
filegroup(
name = "test_audios",
srcs = [
"sine_wave_1k_44100_mono_2_sec_wav.audio",
"sine_wave_1k_44100_stereo_2_sec_aac.audio",
"sine_wave_1k_44100_stereo_2_sec_mp3.audio",
"sine_wave_1k_48000_stereo_2_sec_wav.audio",
],
visibility = ["//visibility:public"],
)

View File

@ -1,325 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Defines TimeSeriesFramerCalculator.
#include <math.h>
#include <deque>
#include <memory>
#include <string>
#include "Eigen/Core"
#include "audio/dsp/window_functions.h"
#include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/time_series_util.h"
namespace mediapipe {
// MediaPipe Calculator for framing a (vector-valued) input time series,
// i.e. for breaking an input time series into fixed-size, possibly
// overlapping, frames. The output stream's frame duration is
// specified by frame_duration_seconds in the
// TimeSeriesFramerCalculatorOptions, and the output's overlap is
// specified by frame_overlap_seconds.
//
// This calculator assumes that the input timestamps refer to the
// first sample in each Matrix. The output timestamps follow this
// same convention.
//
// All output frames will have exactly the same number of samples: the number of
// samples that approximates frame_duration_seconds most closely.
//
// Similarly, frame overlap is by default the (fixed) number of samples
// approximating frame_overlap_seconds most closely. But if
// emulate_fractional_frame_overlap is set to true, frame overlap is a variable
// number of samples instead, such that the long-term average step between
// frames is the difference between the (nominal) frame_duration_seconds and
// frame_overlap_seconds.
//
// If pad_final_packet is true, all input samples will be emitted and the final
// packet will be zero padded as necessary. If pad_final_packet is false, some
// samples may be dropped at the end of the stream.
//
// If use_local_timestamp is true, the output packet's timestamp is based on the
// last sample of the packet. The timestamp of this sample is inferred by
// input_packet_timestamp + local_sample_index / sample_rate_. If false, the
// output packet's timestamp is based on the cumulative timestamping, which is
// done by adopting the timestamp of the first sample of the packet and this
// sample's timestamp is inferred by initial_input_timestamp_ +
// cumulative_completed_samples / sample_rate_.
class TimeSeriesFramerCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<Matrix>(
// Input stream with TimeSeriesHeader.
);
cc->Outputs().Index(0).Set<Matrix>(
// Fixed length time series Packets with TimeSeriesHeader.
);
return absl::OkStatus();
}
// Returns FAIL if the input stream header is invalid.
absl::Status Open(CalculatorContext* cc) override;
// Outputs as many framed packets as possible given the accumulated
// input. Always returns OK.
absl::Status Process(CalculatorContext* cc) override;
// Flushes any remaining samples in a zero-padded packet. Always
// returns OK.
absl::Status Close(CalculatorContext* cc) override;
private:
// Adds input data to the internal buffer.
void EnqueueInput(CalculatorContext* cc);
// Constructs and emits framed output packets.
void FrameOutput(CalculatorContext* cc);
Timestamp CurrentOutputTimestamp() {
if (use_local_timestamp_) {
return current_timestamp_;
}
return CumulativeOutputTimestamp();
}
Timestamp CumulativeOutputTimestamp() {
return initial_input_timestamp_ +
round(cumulative_completed_samples_ / sample_rate_ *
Timestamp::kTimestampUnitsPerSecond);
}
// Returns the timestamp of a sample on a base, which is usually the time
// stamp of a packet.
Timestamp CurrentSampleTimestamp(const Timestamp& timestamp_base,
int64 number_of_samples) {
return timestamp_base + round(number_of_samples / sample_rate_ *
Timestamp::kTimestampUnitsPerSecond);
}
// The number of input samples to advance after the current output frame is
// emitted.
int next_frame_step_samples() const {
// All numbers are in input samples.
const int64 current_output_frame_start = static_cast<int64>(
round(cumulative_output_frames_ * average_frame_step_samples_));
CHECK_EQ(current_output_frame_start, cumulative_completed_samples_);
const int64 next_output_frame_start = static_cast<int64>(
round((cumulative_output_frames_ + 1) * average_frame_step_samples_));
return next_output_frame_start - current_output_frame_start;
}
double sample_rate_;
bool pad_final_packet_;
int frame_duration_samples_;
// The advance, in input samples, between the start of successive output
// frames. This may be a non-integer average value if
// emulate_fractional_frame_overlap is true.
double average_frame_step_samples_;
int samples_still_to_drop_;
int64 cumulative_output_frames_;
// "Completed" samples are samples that are no longer needed because
// the framer has completely stepped past them (taking into account
// any overlap).
int64 cumulative_completed_samples_;
Timestamp initial_input_timestamp_;
// The current timestamp is updated along with the incoming packets.
Timestamp current_timestamp_;
int num_channels_;
// Each entry in this deque consists of a single sample, i.e. a
// single column vector, and its timestamp.
std::deque<std::pair<Matrix, Timestamp>> sample_buffer_;
bool use_window_;
Matrix window_;
bool use_local_timestamp_;
};
REGISTER_CALCULATOR(TimeSeriesFramerCalculator);
// Splits the current input Matrix into its columns and appends each one to
// sample_buffer_, paired with the timestamp derived from the packet timestamp
// and the column index.
void TimeSeriesFramerCalculator::EnqueueInput(CalculatorContext* cc) {
  const Matrix& packet_samples = cc->Inputs().Index(0).Get<Matrix>();
  for (int column = 0; column < packet_samples.cols(); ++column) {
    sample_buffer_.emplace_back(
        packet_samples.col(column),
        CurrentSampleTimestamp(cc->InputTimestamp(), column));
  }
}
// Emits as many complete frames as the buffered samples allow. Each frame
// consumes its step's worth of samples from the front of the buffer, peeks at
// (without consuming) any overlapping samples, optionally applies the window,
// and is added to output stream 0.
void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) {
// Continue while the buffer holds one full frame plus any samples that must
// still be skipped before that frame starts.
while (sample_buffer_.size() >=
frame_duration_samples_ + samples_still_to_drop_) {
// Discard the gap samples left over from the previous frame (negative
// overlap).
while (samples_still_to_drop_ > 0) {
sample_buffer_.pop_front();
--samples_still_to_drop_;
}
const int frame_step_samples = next_frame_step_samples();
std::unique_ptr<Matrix> output_frame(
new Matrix(num_channels_, frame_duration_samples_));
// Samples that no later frame needs are copied and removed from the buffer.
for (int i = 0; i < std::min(frame_step_samples, frame_duration_samples_);
++i) {
output_frame->col(i) = sample_buffer_.front().first;
current_timestamp_ = sample_buffer_.front().second;
sample_buffer_.pop_front();
}
const int frame_overlap_samples =
frame_duration_samples_ - frame_step_samples;
if (frame_overlap_samples > 0) {
// Overlapping samples are copied but stay buffered for the next frame.
for (int i = 0; i < frame_overlap_samples; ++i) {
output_frame->col(i + frame_step_samples) = sample_buffer_[i].first;
current_timestamp_ = sample_buffer_[i].second;
}
} else {
// The step exceeds the frame length: remember how many samples to skip
// before the next frame begins.
samples_still_to_drop_ = -frame_overlap_samples;
}
if (use_window_) {
// Coefficient-wise product with the precomputed window.
*output_frame = (output_frame->array() * window_.array()).matrix();
}
// The timestamp is computed before the cumulative counters advance.
cc->Outputs().Index(0).Add(output_frame.release(),
CurrentOutputTimestamp());
++cumulative_output_frames_;
cumulative_completed_samples_ += frame_step_samples;
}
if (!use_local_timestamp_) {
// In non-local timestamp mode the timestamp of the next packet will be
// equal to CumulativeOutputTimestamp(). Inform the framework about this
// fact to enable packet queueing optimizations.
cc->Outputs().Index(0).SetNextTimestampBound(CumulativeOutputTimestamp());
}
}
// Records the timestamp of the very first input packet, then buffers the
// packet's samples and emits every frame that is now complete.
absl::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) {
  const bool is_first_packet =
      (initial_input_timestamp_ == Timestamp::Unstarted());
  if (is_first_packet) {
    initial_input_timestamp_ = cc->InputTimestamp();
    current_timestamp_ = initial_input_timestamp_;
  }

  EnqueueInput(cc);
  FrameOutput(cc);
  return absl::OkStatus();
}
// Skips any samples still scheduled to be dropped and, when padding is
// enabled and samples remain, emits them as one final zero-padded frame.
absl::Status TimeSeriesFramerCalculator::Close(CalculatorContext* cc) {
  for (; samples_still_to_drop_ > 0 && !sample_buffer_.empty();
       --samples_still_to_drop_) {
    sample_buffer_.pop_front();
  }

  if (sample_buffer_.empty() || !pad_final_packet_) {
    return absl::OkStatus();
  }

  // Start from an all-zero frame so columns past the buffered samples stay
  // zero.
  std::unique_ptr<Matrix> padded_frame(new Matrix);
  padded_frame->setZero(num_channels_, frame_duration_samples_);
  for (int column = 0; column < sample_buffer_.size(); ++column) {
    const auto& buffered = sample_buffer_[column];
    padded_frame->col(column) = buffered.first;
    current_timestamp_ = buffered.second;
  }
  cc->Outputs().Index(0).Add(padded_frame.release(),
                             CurrentOutputTimestamp());
  return absl::OkStatus();
}
// Validates the options and the input header, derives the frame length and
// frame step in samples, publishes the output header, resets the running
// state and precomputes the optional window. Returns an error status if a
// check fails or the input header cannot be read.
absl::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) {
TimeSeriesFramerCalculatorOptions framer_options =
cc->Options<TimeSeriesFramerCalculatorOptions>();
RET_CHECK_GT(framer_options.frame_duration_seconds(), 0.0)
<< "Invalid or missing frame_duration_seconds. "
<< "framer_duration_seconds: \n"
<< framer_options.frame_duration_seconds();
RET_CHECK_LT(framer_options.frame_overlap_seconds(),
framer_options.frame_duration_seconds())
<< "Invalid frame_overlap_seconds. framer_overlap_seconds: \n"
<< framer_options.frame_overlap_seconds();
TimeSeriesHeader input_header;
MP_RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
cc->Inputs().Index(0).Header(), &input_header));
sample_rate_ = input_header.sample_rate();
num_channels_ = input_header.num_channels();
frame_duration_samples_ = time_series_util::SecondsToSamples(
framer_options.frame_duration_seconds(), sample_rate_);
RET_CHECK_GT(frame_duration_samples_, 0)
<< "Frame duration of " << framer_options.frame_duration_seconds()
<< "s too small to cover a single sample at " << sample_rate_ << " Hz ";
if (framer_options.emulate_fractional_frame_overlap()) {
// Frame step may be fractional.
average_frame_step_samples_ = (framer_options.frame_duration_seconds() -
framer_options.frame_overlap_seconds()) *
sample_rate_;
} else {
// Frame step is an integer (stored in a double).
average_frame_step_samples_ =
frame_duration_samples_ -
time_series_util::SecondsToSamples(
framer_options.frame_overlap_seconds(), sample_rate_);
}
RET_CHECK_GE(average_frame_step_samples_, 1)
<< "Frame step too small to cover a single sample at " << sample_rate_
<< " Hz.";
pad_final_packet_ = framer_options.pad_final_packet();
// The output header starts as a copy of the input header; only the number
// of samples per packet and, when fixed, the packet rate are overridden.
auto output_header = new TimeSeriesHeader(input_header);
output_header->set_num_samples(frame_duration_samples_);
if (round(average_frame_step_samples_) == average_frame_step_samples_) {
// Only set output packet rate if it is fixed.
output_header->set_packet_rate(sample_rate_ / average_frame_step_samples_);
}
cc->Outputs().Index(0).SetHeader(Adopt(output_header));
// Reset the running state; the timestamps stay Unstarted until the first
// Process() call.
cumulative_completed_samples_ = 0;
cumulative_output_frames_ = 0;
samples_still_to_drop_ = 0;
initial_input_timestamp_ = Timestamp::Unstarted();
current_timestamp_ = Timestamp::Unstarted();
std::vector<double> window_vector;
use_window_ = false;
switch (framer_options.window_function()) {
case TimeSeriesFramerCalculatorOptions::HAMMING:
audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples_,
&window_vector);
use_window_ = true;
break;
case TimeSeriesFramerCalculatorOptions::HANN:
audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_,
&window_vector);
use_window_ = true;
break;
case TimeSeriesFramerCalculatorOptions::NONE:
break;
}
if (use_window_) {
// Replicate the 1 x frame_duration_samples_ window across every channel so
// it can be applied coefficient-wise to a whole frame.
window_ = Matrix::Ones(num_channels_, 1) *
Eigen::Map<Eigen::MatrixXd>(window_vector.data(), 1,
frame_duration_samples_)
.cast<float>();
}
use_local_timestamp_ = framer_options.use_local_timestamp();
return absl::OkStatus();
}
} // namespace mediapipe

View File

@ -1,72 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
// Options for TimeSeriesFramerCalculator, registered as an extension of
// CalculatorOptions.
message TimeSeriesFramerCalculatorOptions {
extend CalculatorOptions {
optional TimeSeriesFramerCalculatorOptions ext = 50631621;
}
// Frame duration in seconds. Required. Must be greater than 0. This is
// rounded to the nearest integer number of samples.
optional double frame_duration_seconds = 1;
// Frame overlap in seconds.
//
// If emulate_fractional_frame_overlap is false (the default), then the frame
// overlap is rounded to the nearest integer number of samples, and the step
// from one frame to the next will be the difference between the number of
// samples in a frame and the number of samples in the overlap.
//
// If emulate_fractional_frame_overlap is true, then frame overlap will be a
// variable number of samples, such that the long-time average time step from
// one frame to the next will be the difference between the (nominal, not
// rounded) frame_duration_seconds and frame_overlap_seconds. This is useful
// where the desired time step is not an integral number of input samples.
//
// A negative frame_overlap_seconds corresponds to skipping some input samples
// between each frame of emitted samples.
//
// Required that frame_overlap_seconds < frame_duration_seconds.
optional double frame_overlap_seconds = 2 [default = 0.0];
// See frame_overlap_seconds for semantics.
optional bool emulate_fractional_frame_overlap = 5 [default = false];
// Whether to pad the final packet with zeros. If true, guarantees that all
// input samples (other than those that fall in gaps implied by negative
// frame_overlap_seconds) will be emitted. If set to false, any partial
// packet at the end of the stream will be dropped.
optional bool pad_final_packet = 3 [default = true];
// Optional windowing function. The default is NONE (no windowing function).
enum WindowFunction {
NONE = 0;
HAMMING = 1;
HANN = 2;
}
optional WindowFunction window_function = 4 [default = NONE];
// If use_local_timestamp is true, the output packet's timestamp is based on
// the last sample of the packet, and is inferred from the latest input
// packet's timestamp. If false, the output packet's timestamp is based on
// cumulative timestamping, which is inferred from the initial input
// timestamp and the cumulative number of samples.
optional bool use_local_timestamp = 6 [default = false];
}

View File

@ -1,485 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h>
#include <memory>
#include <string>
#include <vector>
#include "Eigen/Core"
#include "audio/dsp/window_functions.h"
#include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/time_series_test_util.h"
namespace mediapipe {
namespace {
const int kInitialTimestampOffsetMicroseconds = 4;
const int kGapBetweenPacketsInSeconds = 1;
const int kUniversalInputPacketSize = 50;
// Fixture for TimeSeriesFramerCalculator tests. It generates input packets of
// varying size, keeps a concatenated reference copy of all input samples, and
// checks the framed output (header, values, padding) against that copy.
class TimeSeriesFramerCalculatorTest
: public TimeSeriesCalculatorTest<TimeSeriesFramerCalculatorOptions> {
protected:
void SetUp() override {
calculator_name_ = "TimeSeriesFramerCalculator";
input_sample_rate_ = 4000.0;
num_input_channels_ = 3;
}
// Returns a float value with the channel and timestamp separated by
// an order of magnitude, for easy parsing by humans.
float TestValue(int64 timestamp_in_microseconds, int channel) {
return timestamp_in_microseconds + channel / 10.0;
}
// Builds a num_channels x num_samples matrix whose entries are
// TestValue(sample timestamp, channel).
// Caller takes ownership of the returned value.
Matrix* NewTestFrame(int num_channels, int num_samples,
double starting_timestamp_seconds) {
auto matrix = new Matrix(num_channels, num_samples);
for (int c = 0; c < num_channels; ++c) {
for (int i = 0; i < num_samples; ++i) {
// SecondsToSamples is used here to convert seconds into timestamp units.
int64 timestamp = time_series_util::SecondsToSamples(
starting_timestamp_seconds + i / input_sample_rate_,
Timestamp::kTimestampUnitsPerSecond);
(*matrix)(c, i) = TestValue(timestamp, c);
}
}
return matrix;
}
// Initializes and runs the test graph.
absl::Status Run() {
InitializeGraph();
FillInputHeader();
InitializeInput();
return RunGraph();
}
// Creates test input and saves a reference copy.
void InitializeInput() {
concatenated_input_samples_.resize(0, num_input_channels_);
num_input_samples_ = 0;
for (int i = 0; i < 10; ++i) {
// This range of packet sizes was chosen such that some input
// packets will be smaller than the output packet size and other
// input packets will be larger.
int packet_size = (i + 1) * 20;
double timestamp_seconds = kInitialTimestampOffsetMicroseconds * 1.0e-6 +
num_input_samples_ / input_sample_rate_;
Matrix* data_frame =
NewTestFrame(num_input_channels_, packet_size, timestamp_seconds);
// Keep a reference copy of the input.
//
// conservativeResize() is needed here to preserve the existing
// data. Eigen's resize() resizes without preserving data.
concatenated_input_samples_.conservativeResize(
num_input_channels_, num_input_samples_ + packet_size);
concatenated_input_samples_.rightCols(packet_size) = *data_frame;
num_input_samples_ += packet_size;
AppendInputPacket(data_frame, round(timestamp_seconds *
Timestamp::kTimestampUnitsPerSecond));
}
// Build the expected window for the configured window function; NONE is
// modelled as a window of all ones.
const int frame_duration_samples = FrameDurationSamples();
std::vector<double> window_vector;
switch (options_.window_function()) {
case TimeSeriesFramerCalculatorOptions::HAMMING:
audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples,
&window_vector);
break;
case TimeSeriesFramerCalculatorOptions::HANN:
audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples,
&window_vector);
break;
case TimeSeriesFramerCalculatorOptions::NONE:
window_vector.assign(frame_duration_samples, 1.0f);
break;
}
window_ = Matrix::Ones(num_input_channels_, 1) *
Eigen::Map<Eigen::MatrixXd>(window_vector.data(), 1,
frame_duration_samples)
.cast<float>();
}
// Frame length in samples implied by the configured frame duration and the
// input sample rate.
int FrameDurationSamples() {
return time_series_util::SecondsToSamples(options_.frame_duration_seconds(),
input_sample_rate_);
}
// Checks that the values in the framed output packets match the
// appropriate values from the input.
void CheckOutputPacketValues(const Matrix& actual, int packet_num,
int frame_duration_samples,
double frame_step_samples,
int num_columns_to_check) {
ASSERT_EQ(frame_duration_samples, actual.cols());
// Expected values: the input columns starting at this packet's frame start,
// multiplied coefficient-wise by the window.
Matrix expected = (concatenated_input_samples_
.block(0, round(frame_step_samples * packet_num),
num_input_channels_, num_columns_to_check)
.array() *
window_.leftCols(num_columns_to_check).array())
.matrix();
ExpectApproximatelyEqual(expected, actual.leftCols(num_columns_to_check));
}
// Checks output headers, Timestamps, and values.
void CheckOutput() {
const int frame_duration_samples = FrameDurationSamples();
// Frame step in samples: fractional when fractional overlap is emulated,
// otherwise frame length minus the rounded overlap.
const double frame_step_samples =
options_.emulate_fractional_frame_overlap()
? (options_.frame_duration_seconds() -
options_.frame_overlap_seconds()) *
input_sample_rate_
: frame_duration_samples -
time_series_util::SecondsToSamples(
options_.frame_overlap_seconds(), input_sample_rate_);
TimeSeriesHeader expected_header = input().header.Get<TimeSeriesHeader>();
expected_header.set_num_samples(frame_duration_samples);
// A packet rate is only expected when the frame step is a whole number of
// samples.
if (!options_.emulate_fractional_frame_overlap() ||
frame_step_samples == round(frame_step_samples)) {
expected_header.set_packet_rate(input_sample_rate_ / frame_step_samples);
}
ExpectOutputHeaderEquals(expected_header);
// With padding enabled the last packet is checked separately below.
int num_full_packets = output().packets.size();
if (options_.pad_final_packet()) {
num_full_packets -= 1;
}
for (int packet_num = 0; packet_num < num_full_packets; ++packet_num) {
const Packet& packet = output().packets[packet_num];
CheckOutputPacketValues(packet.Get<Matrix>(), packet_num,
frame_duration_samples, frame_step_samples,
frame_duration_samples);
}
// What is the effective time index of the final sample emitted?
// This includes accounting for the gaps when overlap is negative.
const int num_unique_output_samples =
round((output().packets.size() - 1) * frame_step_samples) +
frame_duration_samples;
LOG(INFO) << "packets.size()=" << output().packets.size()
<< " frame_duration_samples=" << frame_duration_samples
<< " frame_step_samples=" << frame_step_samples
<< " num_input_samples_=" << num_input_samples_
<< " num_unique_output_samples=" << num_unique_output_samples;
const int num_padding_samples =
num_unique_output_samples - num_input_samples_;
if (options_.pad_final_packet()) {
EXPECT_LT(num_padding_samples, frame_duration_samples);
// If the input ended during the dropped samples between the end of
// the last emitted frame and where the next one would begin, there
// can be fewer unique output points than input points, even with
// padding.
const int max_dropped_samples =
static_cast<int>(ceil(frame_step_samples - frame_duration_samples));
EXPECT_GE(num_padding_samples, std::min(0, -max_dropped_samples));
if (num_padding_samples > 0) {
// Check the non-padded part of the final packet.
const Matrix& final_matrix = output().packets.back().Get<Matrix>();
CheckOutputPacketValues(final_matrix, num_full_packets,
frame_duration_samples, frame_step_samples,
frame_duration_samples - num_padding_samples);
// Check the padded part of the final packet.
EXPECT_EQ(
Matrix::Zero(num_input_channels_, num_padding_samples),
final_matrix.block(0, frame_duration_samples - num_padding_samples,
num_input_channels_, num_padding_samples));
}
} else {
// Without padding, at most one partial frame of input may go unemitted.
EXPECT_GT(num_padding_samples, -frame_duration_samples);
EXPECT_LE(num_padding_samples, 0);
}
}
// Total number of input samples generated by InitializeInput().
int num_input_samples_;
// Reference copy of all input samples, one column per sample.
Matrix concatenated_input_samples_;
// Expected window, replicated across channels.
Matrix window_;
};
// 100-sample frames with the default overlap and no window.
TEST_F(TimeSeriesFramerCalculatorTest, IntegerSampleDurationNoOverlap) {
  const double frame_duration_seconds = 100.0 / input_sample_rate_;
  options_.set_frame_duration_seconds(frame_duration_seconds);
  MP_ASSERT_OK(Run());
  CheckOutput();
}
// 100-sample frames with the default overlap, Hamming-windowed.
TEST_F(TimeSeriesFramerCalculatorTest,
       IntegerSampleDurationNoOverlapHammingWindow) {
  const double frame_duration_seconds = 100.0 / input_sample_rate_;
  options_.set_window_function(TimeSeriesFramerCalculatorOptions::HAMMING);
  options_.set_frame_duration_seconds(frame_duration_seconds);
  MP_ASSERT_OK(Run());
  CheckOutput();
}
// 100-sample frames with the default overlap, Hann-windowed.
TEST_F(TimeSeriesFramerCalculatorTest,
       IntegerSampleDurationNoOverlapHannWindow) {
  const double frame_duration_seconds = 100.0 / input_sample_rate_;
  options_.set_window_function(TimeSeriesFramerCalculatorOptions::HANN);
  options_.set_frame_duration_seconds(frame_duration_seconds);
  MP_ASSERT_OK(Run());
  CheckOutput();
}
// 100-sample frames overlapping by 40 samples.
TEST_F(TimeSeriesFramerCalculatorTest, IntegerSampleDurationAndOverlap) {
  const double frame_duration_seconds = 100.0 / input_sample_rate_;
  const double frame_overlap_seconds = 40.0 / input_sample_rate_;
  options_.set_frame_overlap_seconds(frame_overlap_seconds);
  options_.set_frame_duration_seconds(frame_duration_seconds);
  MP_ASSERT_OK(Run());
  CheckOutput();
}
// Duration and overlap that are not whole numbers of samples (98.5 and 38.4).
TEST_F(TimeSeriesFramerCalculatorTest, NonintegerSampleDurationAndOverlap) {
  const double frame_duration_seconds = 98.5 / input_sample_rate_;
  const double frame_overlap_seconds = 38.4 / input_sample_rate_;
  options_.set_frame_overlap_seconds(frame_overlap_seconds);
  options_.set_frame_duration_seconds(frame_duration_seconds);
  MP_ASSERT_OK(Run());
  CheckOutput();
}
// A negative overlap drops samples between frames. 100-sample frames followed
// by a 10-sample skip fit exactly 10 times into the 1100 input samples.
TEST_F(TimeSeriesFramerCalculatorTest, NegativeOverlapExactFrames) {
  const double frame_duration_seconds = 100.0 / input_sample_rate_;
  const double frame_overlap_seconds = -10.0 / input_sample_rate_;
  options_.set_frame_overlap_seconds(frame_overlap_seconds);
  options_.set_frame_duration_seconds(frame_duration_seconds);
  MP_ASSERT_OK(Run());
  EXPECT_EQ(output().packets.size(), 10);
  CheckOutput();
}
// 100-sample frames followed by a 100-sample skip: the 1100 input samples
// hold 6 full frames.
TEST_F(TimeSeriesFramerCalculatorTest, NegativeOverlapExactFramesLessSkip) {
  const double frame_duration_seconds = 100.0 / input_sample_rate_;
  const double frame_overlap_seconds = -100.0 / input_sample_rate_;
  options_.set_frame_overlap_seconds(frame_overlap_seconds);
  options_.set_frame_duration_seconds(frame_duration_seconds);
  MP_ASSERT_OK(Run());
  EXPECT_EQ(output().packets.size(), 6);
  CheckOutput();
}
// 150 samples per frame plus a skip of 50 samples gives a frame step of 200
// samples. With 1100 input samples the sixth and last frame starts at sample
// 1000, so it holds only 100 real samples and needs 50 samples of padding.
//
// The options previously repeated the 100 / -100 configuration of
// NegativeOverlapExactFramesLessSkip, which never produced a padded frame.
TEST_F(TimeSeriesFramerCalculatorTest, NegativeOverlapWithPadding) {
  options_.set_frame_duration_seconds(150.0 / input_sample_rate_);
  options_.set_frame_overlap_seconds(-50.0 / input_sample_rate_);
  MP_ASSERT_OK(Run());
  EXPECT_EQ(output().packets.size(), 6);
  CheckOutput();
}
// 30-sample frames with a nominal step of 11.4 samples, which is rounded to
// 11 when fractional overlap is not emulated:
// ceil((1100 - 30) / 11) + 1 = 99 packets.
TEST_F(TimeSeriesFramerCalculatorTest, FixedFrameOverlap) {
  const double frame_duration_seconds = 30 / input_sample_rate_;
  const double frame_overlap_seconds = (30.0 - 11.4) / input_sample_rate_;
  options_.set_frame_overlap_seconds(frame_overlap_seconds);
  options_.set_frame_duration_seconds(frame_duration_seconds);
  MP_ASSERT_OK(Run());
  EXPECT_EQ(output().packets.size(), 99);
  CheckOutput();
}
// 30-sample frames with an unrounded average step of 11.4 samples:
// ceil((1100 - 30) / 11.4) + 1 = 95 packets.
TEST_F(TimeSeriesFramerCalculatorTest, VariableFrameOverlap) {
  const double frame_duration_seconds = 30 / input_sample_rate_;
  const double frame_overlap_seconds = (30 - 11.4) / input_sample_rate_;
  options_.set_emulate_fractional_frame_overlap(true);
  options_.set_frame_overlap_seconds(frame_overlap_seconds);
  options_.set_frame_duration_seconds(frame_duration_seconds);
  MP_ASSERT_OK(Run());
  EXPECT_EQ(output().packets.size(), 95);
  CheckOutput();
}
// 30-sample frames with an unrounded average step of 41.4 samples (samples
// are skipped between frames):
// ceil((1100 - 30) / 41.4) + 1 = 27 packets.
TEST_F(TimeSeriesFramerCalculatorTest, VariableFrameSkip) {
  const double frame_duration_seconds = 30 / input_sample_rate_;
  const double frame_overlap_seconds = (30 - 41.4) / input_sample_rate_;
  options_.set_emulate_fractional_frame_overlap(true);
  options_.set_frame_overlap_seconds(frame_overlap_seconds);
  options_.set_frame_duration_seconds(frame_duration_seconds);
  MP_ASSERT_OK(Run());
  EXPECT_EQ(output().packets.size(), 27);
  CheckOutput();
}
TEST_F(TimeSeriesFramerCalculatorTest, NoFinalPacketPadding) {
  // Runs with a non-integer frame length (98.5 samples) and with padding of
  // the final packet switched off, then validates the produced frames.
  const auto frame_duration_seconds = 98.5 / input_sample_rate_;
  options_.set_frame_duration_seconds(frame_duration_seconds);
  options_.set_pad_final_packet(false);

  MP_ASSERT_OK(Run());

  CheckOutput();
}
TEST_F(TimeSeriesFramerCalculatorTest,
       FrameRateHigherThanSampleRate_FrameDurationTooLow) {
  // Ask for a frame rate 10 times the input sample rate by choosing a frame
  // duration that covers only 0.1 samples; running must fail.
  const auto too_short_frame_seconds = 1 / (10 * input_sample_rate_);
  options_.set_frame_duration_seconds(too_short_frame_seconds);
  options_.set_frame_overlap_seconds(0.0);

  const auto run_status = Run();
  EXPECT_FALSE(run_status.ok());
}
TEST_F(TimeSeriesFramerCalculatorTest,
       FrameRateHigherThanSampleRate_FrameStepTooLow) {
  // Ask for a frame rate 10 times the input sample rate by choosing an
  // overlap so large that the frame step (duration minus overlap) is only
  // 0.1 samples; running must fail.
  const auto frame_duration_seconds = 10.0 / input_sample_rate_;
  const auto frame_overlap_seconds = 9.9 / input_sample_rate_;
  options_.set_frame_duration_seconds(frame_duration_seconds);
  options_.set_frame_overlap_seconds(frame_overlap_seconds);

  const auto run_status = Run();
  EXPECT_FALSE(run_status.ok());
}
// A simple test class to do windowing sanity checks. Tests from this
// class input a single packet of all ones, and check the average
// value of the single output packet. This is useful as a sanity check
// that the correct windows are applied.
class TimeSeriesFramerCalculatorWindowingSanityTest
: public TimeSeriesFramerCalculatorTest {
protected:
void SetUp() override {
TimeSeriesFramerCalculatorTest::SetUp();
num_input_channels_ = 1;
}
void RunAndTestSinglePacketAverage(float expected_average) {
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
InitializeGraph();
FillInputHeader();
AppendInputPacket(new Matrix(Matrix::Ones(1, FrameDurationSamples())),
kInitialTimestampOffsetMicroseconds);
MP_ASSERT_OK(RunGraph());
ASSERT_EQ(1, output().packets.size());
ASSERT_NEAR(expected_average * FrameDurationSamples(),
output().packets[0].Get<Matrix>().sum(), 1e-5);
}
};
TEST_F(TimeSeriesFramerCalculatorWindowingSanityTest, NoWindowSanityCheck) {
  // With no window function selected, the all-ones input keeps an average
  // of 1.
  constexpr float kExpectedAverage = 1.0f;
  RunAndTestSinglePacketAverage(kExpectedAverage);
}
TEST_F(TimeSeriesFramerCalculatorWindowingSanityTest,
       HammingWindowSanityCheck) {
  // The test expects an average value of 0.54 when the Hamming window is
  // applied to an all-ones frame.
  constexpr float kExpectedAverage = 0.54f;
  options_.set_window_function(TimeSeriesFramerCalculatorOptions::HAMMING);
  RunAndTestSinglePacketAverage(kExpectedAverage);
}
TEST_F(TimeSeriesFramerCalculatorWindowingSanityTest, HannWindowSanityCheck) {
  // The test expects an average value of 0.5 when the Hann window is applied
  // to an all-ones frame.
  constexpr float kExpectedAverage = 0.5f;
  options_.set_window_function(TimeSeriesFramerCalculatorOptions::HANN);
  RunAndTestSinglePacketAverage(kExpectedAverage);
}
// A simple test class that checks the local packet time stamp. This class
// generate a series of packets with and without gaps between packets and tests
// the behavior with cumulative timestamping and local packet timestamping.
class TimeSeriesFramerCalculatorTimestampingTest
: public TimeSeriesFramerCalculatorTest {
protected:
// Creates test input and saves a reference copy.
void InitializeInputForTimeStampingTest() {
concatenated_input_samples_.resize(0, num_input_channels_);
num_input_samples_ = 0;
for (int i = 0; i < 10; ++i) {
// This range of packet sizes was chosen such that some input
// packets will be smaller than the output packet size and other
// input packets will be larger.
int packet_size = kUniversalInputPacketSize;
double timestamp_seconds = kInitialTimestampOffsetMicroseconds * 1.0e-6 +
num_input_samples_ / input_sample_rate_;
if (options_.use_local_timestamp()) {
timestamp_seconds += kGapBetweenPacketsInSeconds * i;
}
Matrix* data_frame =
NewTestFrame(num_input_channels_, packet_size, timestamp_seconds);
AppendInputPacket(data_frame, round(timestamp_seconds *
Timestamp::kTimestampUnitsPerSecond));
num_input_samples_ += packet_size;
}
}
void CheckOutputTimestamps() {
int num_full_packets = output().packets.size();
if (options_.pad_final_packet()) {
num_full_packets -= 1;
}
int64 num_samples = 0;
for (int packet_num = 0; packet_num < num_full_packets; ++packet_num) {
const Packet& packet = output().packets[packet_num];
num_samples += FrameDurationSamples();
double expected_timestamp =
options_.use_local_timestamp()
? GetExpectedLocalTimestampForSample(num_samples - 1)
: GetExpectedCumulativeTimestamp(num_samples - 1);
ASSERT_NEAR(packet.Timestamp().Seconds(), expected_timestamp, 1e-10);
}
}
absl::Status RunTimestampTest() {
InitializeGraph();
InitializeInputForTimeStampingTest();
FillInputHeader();
return RunGraph();
}
private:
// Returns the timestamp in seconds based on local timestamping.
double GetExpectedLocalTimestampForSample(int sample_index) {
return kInitialTimestampOffsetMicroseconds * 1.0e-6 +
sample_index / input_sample_rate_ +
(sample_index / kUniversalInputPacketSize) *
kGapBetweenPacketsInSeconds;
}
// Returns the timestamp inseconds based on cumulative timestamping.
double GetExpectedCumulativeTimestamp(int sample_index) {
return kInitialTimestampOffsetMicroseconds * 1.0e-6 +
sample_index / FrameDurationSamples() * FrameDurationSamples() /
input_sample_rate_;
}
};
TEST_F(TimeSeriesFramerCalculatorTimestampingTest, UseLocalTimeStamp) {
  // Frames of 100 samples, timestamped with local packet timestamping.
  const auto frame_duration_seconds = 100.0 / input_sample_rate_;
  options_.set_frame_duration_seconds(frame_duration_seconds);
  options_.set_use_local_timestamp(true);

  MP_ASSERT_OK(RunTimestampTest());

  CheckOutputTimestamps();
}
TEST_F(TimeSeriesFramerCalculatorTimestampingTest, UseCumulativeTimeStamp) {
  // Frames of 100 samples, timestamped with cumulative timestamping.
  const auto frame_duration_seconds = 100.0 / input_sample_rate_;
  options_.set_frame_duration_seconds(frame_duration_seconds);
  options_.set_use_local_timestamp(false);

  MP_ASSERT_OK(RunTimestampTest());

  CheckOutputTimestamps();
}
} // namespace
} // namespace mediapipe

File diff suppressed because it is too large Load Diff

View File

@ -1,84 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/logging.h"
namespace mediapipe {
namespace api2 {
// Attach the header from a stream or side input to another stream.
//
// The header stream (tag HEADER) must not have any packets in it.
//
// Before using this calculator, please think about changing your
// calculator to not need a header or to accept a separate stream with
// a header, that would be more future proof.
//
// Example usage 1:
// node {
// calculator: "AddHeaderCalculator"
// input_stream: "DATA:audio"
// input_stream: "HEADER:audio_header"
// output_stream: "audio_with_header"
// }
//
// Example usage 2:
// node {
// calculator: "AddHeaderCalculator"
// input_stream: "DATA:audio"
// input_side_packet: "HEADER:audio_header"
// output_stream: "audio_with_header"
// }
//
class AddHeaderCalculator : public Node {
 public:
  // Optional stream whose header is attached to the output.
  static constexpr Input<NoneType>::Optional kHeader{"HEADER"};
  // Optional side packet used as the output header.
  static constexpr SideInput<AnyType>::Optional kHeaderSide{"HEADER"};
  // Stream whose packets are forwarded to the output.
  static constexpr Input<AnyType> kData{"DATA"};
  // Untagged output stream with the same type as the DATA stream.
  static constexpr Output<SameType<kData>> kOut{""};

  MEDIAPIPE_NODE_CONTRACT(kHeader, kHeaderSide, kData, kOut);

  // Rejects configurations in which the header is connected through both the
  // input stream and the side packet, or through neither of them.
  static absl::Status UpdateContract(CalculatorContract* cc) {
    const bool has_header_stream = kHeader(cc).IsConnected();
    const bool has_header_side_packet = kHeaderSide(cc).IsConnected();
    if (has_header_stream != has_header_side_packet) {
      return absl::OkStatus();
    }
    return absl::InvalidArgumentError(
        "Header must be provided via exactly one of side input and input "
        "stream");
  }

  // Copies the header from whichever source is connected onto the output
  // stream, unless that header is empty.
  absl::Status Open(CalculatorContext* cc) override {
    const PacketBase& source_header =
        kHeader(cc).IsConnected() ? kHeader(cc).Header() : kHeaderSide(cc);
    const bool header_is_present = !source_header.IsEmpty();
    if (header_is_present) {
      kOut(cc).SetHeader(source_header);
    }
    cc->SetOffset(0);
    return absl::OkStatus();
  }

  // Forwards the current DATA packet to the output stream.
  absl::Status Process(CalculatorContext* cc) override {
    kOut(cc).Send(kData(cc).packet());
    return absl::OkStatus();
  }
};

MEDIAPIPE_REGISTER_NODE(AddHeaderCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -1,159 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/framework/tool/validate_type.h"
namespace mediapipe {
// Empty fixture that groups the AddHeaderCalculator tests below under a
// single test suite name; it adds no shared state or set-up.
class AddHeaderCalculatorTest : public ::testing::Test {};
TEST_F(AddHeaderCalculatorTest, HeaderStream) {
  constexpr int kNumPackets = 5;

  CalculatorGraphConfig::Node node;
  node.set_calculator("AddHeaderCalculator");
  node.add_input_stream("HEADER:header_stream");
  node.add_input_stream("DATA:data_stream");
  node.add_output_stream("merged_stream");
  CalculatorRunner runner(node);

  // Provide the header through the HEADER stream and queue the data packets.
  runner.MutableInputs()->Tag("HEADER").header =
      Adopt(new std::string("my_header"));
  auto& data_packets = runner.MutableInputs()->Tag("DATA").packets;
  for (int i = 0; i < kNumPackets; ++i) {
    data_packets.push_back(Adopt(new int(i)).At(Timestamp(i * 1000)));
  }

  MP_ASSERT_OK(runner.Run());
  ASSERT_EQ(1, runner.Outputs().NumEntries());

  // The output carries the header, and every data packet passes through with
  // its value and timestamp unchanged.
  const auto& merged_stream = runner.Outputs().Index(0);
  EXPECT_EQ(std::string("my_header"), merged_stream.header.Get<std::string>());
  const std::vector<Packet>& forwarded = merged_stream.packets;
  ASSERT_EQ(kNumPackets, forwarded.size());
  for (int i = 0; i < kNumPackets; ++i) {
    const int forwarded_value = forwarded[i].Get<int>();
    EXPECT_EQ(i, forwarded_value);
    EXPECT_EQ(Timestamp(i * 1000), forwarded[i].Timestamp());
  }
}
TEST_F(AddHeaderCalculatorTest, HandlesEmptyHeaderStream) {
  CalculatorGraphConfig::Node node;
  node.set_calculator("AddHeaderCalculator");
  node.add_input_stream("HEADER:header_stream");
  node.add_input_stream("DATA:data_stream");
  node.add_output_stream("merged_stream");
  CalculatorRunner runner(node);

  // Neither a header nor any packets are provided; the run must still
  // succeed and leave the output header empty.
  MP_ASSERT_OK(runner.Run());

  const auto& merged_stream = runner.Outputs().Index(0);
  EXPECT_TRUE(merged_stream.header.IsEmpty());
}
TEST_F(AddHeaderCalculatorTest, NoPacketsOnHeaderStream) {
  constexpr int kNumPackets = 5;

  CalculatorGraphConfig::Node node;
  node.set_calculator("AddHeaderCalculator");
  node.add_input_stream("HEADER:header_stream");
  node.add_input_stream("DATA:data_stream");
  node.add_output_stream("merged_stream");
  CalculatorRunner runner(node);

  // Set the header, then also put a packet on the HEADER stream, which is
  // not allowed.
  auto& header_stream = runner.MutableInputs()->Tag("HEADER");
  header_stream.header = Adopt(new std::string("my_header"));
  header_stream.packets.push_back(Adopt(new std::string("not allowed")));

  // Queue the regular data packets.
  auto& data_packets = runner.MutableInputs()->Tag("DATA").packets;
  for (int i = 0; i < kNumPackets; ++i) {
    data_packets.push_back(Adopt(new int(i)).At(Timestamp(i * 1000)));
  }

  // The run must fail because of the packet on the HEADER stream.
  ASSERT_FALSE(runner.Run().ok());
}
TEST_F(AddHeaderCalculatorTest, InputSidePacket) {
  constexpr int kNumPackets = 5;

  CalculatorGraphConfig::Node node;
  node.set_calculator("AddHeaderCalculator");
  node.add_input_stream("DATA:data_stream");
  node.add_output_stream("merged_stream");
  node.add_input_side_packet("HEADER:header");
  CalculatorRunner runner(node);

  // Provide the header through the HEADER side packet and queue the data
  // packets.
  runner.MutableSidePackets()->Tag("HEADER") =
      Adopt(new std::string("my_header"));
  auto& data_packets = runner.MutableInputs()->Tag("DATA").packets;
  for (int i = 0; i < kNumPackets; ++i) {
    data_packets.push_back(Adopt(new int(i)).At(Timestamp(i * 1000)));
  }

  MP_ASSERT_OK(runner.Run());
  ASSERT_EQ(1, runner.Outputs().NumEntries());

  // The output carries the header, and every data packet passes through with
  // its value and timestamp unchanged.
  const auto& merged_stream = runner.Outputs().Index(0);
  EXPECT_EQ(std::string("my_header"), merged_stream.header.Get<std::string>());
  const std::vector<Packet>& forwarded = merged_stream.packets;
  ASSERT_EQ(kNumPackets, forwarded.size());
  for (int i = 0; i < kNumPackets; ++i) {
    const int forwarded_value = forwarded[i].Get<int>();
    EXPECT_EQ(i, forwarded_value);
    EXPECT_EQ(Timestamp(i * 1000), forwarded[i].Timestamp());
  }
}
TEST_F(AddHeaderCalculatorTest, UsingBothSideInputAndStream) {
  CalculatorGraphConfig::Node node;
  node.set_calculator("AddHeaderCalculator");
  node.add_input_stream("HEADER:header_stream");
  node.add_input_stream("DATA:data_stream");
  node.add_output_stream("merged_stream");
  node.add_input_side_packet("HEADER:header");
  CalculatorRunner runner(node);
  // Set both headers (stream header and side packet) and add 5 packets.
  // Previously the side packet was assigned twice and the stream header was
  // never set, so the "both sources provided" scenario was only partially
  // populated.
  runner.MutableInputs()->Tag("HEADER").header =
      Adopt(new std::string("my_header"));
  runner.MutableSidePackets()->Tag("HEADER") =
      Adopt(new std::string("my_header"));
  for (int i = 0; i < 5; ++i) {
    Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
    runner.MutableInputs()->Tag("DATA").packets.push_back(packet);
  }
  // Run should fail because header can only be provided one way.
  EXPECT_EQ(runner.Run().code(), absl::InvalidArgumentError("").code());
}
} // namespace mediapipe

View File

@ -1,448 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "absl/memory/memory.h"
#include "mediapipe/calculators/core/begin_loop_calculator.h"
#include "mediapipe/calculators/core/end_loop_calculator.h"
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
namespace mediapipe {
namespace {
// Matches a packet holding a std::vector<int> whose timestamp equals
// `timestamp` and whose elements equal those of `value`, in order.
MATCHER_P2(PacketOfIntsEq, timestamp, value, "") {
  Timestamp packet_timestamp = arg.Timestamp();
  const auto& packet_ints = arg.template Get<std::vector<int>>();
  if (!testing::Value(packet_timestamp, testing::Eq(timestamp))) {
    return false;
  }
  return testing::Value(packet_ints, testing::ElementsAreArray(value));
}
// BeginLoopCalculator specialised for std::vector<int>, registered for use in
// the test graphs below.
using BeginLoopIntegerCalculator = BeginLoopCalculator<std::vector<int>>;
REGISTER_CALCULATOR(BeginLoopIntegerCalculator);
// Test calculator that outputs its integer input plus one, at the input
// timestamp.
class IncrementCalculator : public CalculatorBase {
 public:
  // Declares a single int input stream and a single int output stream.
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<int>();
    cc->Outputs().Index(0).Set<int>();
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) override {
    const TimestampDiff zero_offset(0);
    cc->SetOffset(zero_offset);
    return absl::OkStatus();
  }

  // Emits (input + 1) at the timestamp of the input packet.
  absl::Status Process(CalculatorContext* cc) override {
    const int& received = cc->Inputs().Index(0).Get<int>();
    cc->Outputs().Index(0).Add(absl::make_unique<int>(received + 1).release(),
                               cc->InputTimestamp());
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(IncrementCalculator);
// EndLoopCalculator specialised for std::vector<int>, registered for use in
// the test graphs below.
using EndLoopIntegersCalculator = EndLoopCalculator<std::vector<int>>;
REGISTER_CALCULATOR(EndLoopIntegersCalculator);
// Fixture running a graph that adds one to every element of each input
// vector: BeginLoop -> IncrementCalculator -> EndLoop. Output packets are
// collected in output_packets_.
class BeginEndLoopCalculatorGraphTest : public ::testing::Test {
 protected:
  // Parses the graph config, attaches a vector sink to "ints_plus_one", and
  // starts the graph run.
  void SetUp() override {
    CalculatorGraphConfig graph_config =
        ParseTextProtoOrDie<CalculatorGraphConfig>(
            R"pb(
num_threads: 4
input_stream: "ints"
node {
calculator: "BeginLoopIntegerCalculator"
input_stream: "ITERABLE:ints"
output_stream: "ITEM:int"
output_stream: "BATCH_END:timestamp"
}
node {
calculator: "IncrementCalculator"
input_stream: "int"
output_stream: "int_plus_one"
}
node {
calculator: "EndLoopIntegersCalculator"
input_stream: "ITEM:int_plus_one"
input_stream: "BATCH_END:timestamp"
output_stream: "ITERABLE:ints_plus_one"
}
)pb");
    tool::AddVectorSink("ints_plus_one", &graph_config, &output_packets_);
    MP_ASSERT_OK(graph_.Initialize(graph_config));
    MP_ASSERT_OK(graph_.StartRun({}));
  }

  // Sends `ints` as one packet on the "ints" input stream at `timestamp`.
  void SendPacketOfInts(Timestamp timestamp, std::vector<int> ints) {
    Packet ints_packet =
        MakePacket<std::vector<int>>(std::move(ints)).At(timestamp);
    MP_ASSERT_OK(
        graph_.AddPacketToInputStream("ints", std::move(ints_packet)));
  }

  CalculatorGraph graph_;
  std::vector<Packet> output_packets_;
};
TEST_F(BeginEndLoopCalculatorGraphTest, InputStreamForIterableIsEmpty) {
  // Nothing is sent on the input stream.
  MP_ASSERT_OK(graph_.WaitUntilIdle());

  // With no packets to process, EndLoopCalc only forwards the timestamp
  // bound, so no output packet is produced.
  const auto num_output_packets = output_packets_.size();
  ASSERT_EQ(0, num_output_packets);

  MP_ASSERT_OK(graph_.CloseAllPacketSources());
  MP_ASSERT_OK(graph_.WaitUntilDone());
}
TEST_F(BeginEndLoopCalculatorGraphTest, SingleEmptyVector) {
  const Timestamp input_timestamp(0);
  SendPacketOfInts(input_timestamp, std::vector<int>{});
  MP_ASSERT_OK(graph_.WaitUntilIdle());

  // The collection has no elements to output, so EndLoopCalc only forwards
  // the timestamp bound and no packet is produced.
  EXPECT_TRUE(output_packets_.empty());

  MP_ASSERT_OK(graph_.CloseAllPacketSources());
  MP_ASSERT_OK(graph_.WaitUntilDone());
}
TEST_F(BeginEndLoopCalculatorGraphTest, SingleNonEmptyVector) {
  const Timestamp input_timestamp(0);
  SendPacketOfInts(input_timestamp, std::vector<int>{0, 1, 2});
  MP_ASSERT_OK(graph_.WaitUntilIdle());

  // Every element is incremented and the result is emitted at the input
  // timestamp.
  const std::vector<int> expected_ints{1, 2, 3};
  EXPECT_THAT(output_packets_, testing::ElementsAre(PacketOfIntsEq(
                                   input_timestamp, expected_ints)));

  MP_ASSERT_OK(graph_.CloseAllPacketSources());
  MP_ASSERT_OK(graph_.WaitUntilDone());
}
TEST_F(BeginEndLoopCalculatorGraphTest, MultipleVectors) {
  const Timestamp first_timestamp(0);
  const Timestamp empty_timestamp(1);
  const Timestamp last_timestamp(2);

  SendPacketOfInts(first_timestamp, std::vector<int>{0, 1});
  SendPacketOfInts(empty_timestamp, std::vector<int>{});
  SendPacketOfInts(last_timestamp, std::vector<int>{2, 3});
  MP_ASSERT_OK(graph_.CloseAllPacketSources());
  MP_ASSERT_OK(graph_.WaitUntilDone());

  // The empty vector sent at empty_timestamp has no elements to process, so
  // EndLoopCalc only forwards a timestamp bound for it and no packet appears
  // at that timestamp.
  const std::vector<int> expected_first{1, 2};
  const std::vector<int> expected_last{3, 4};
  EXPECT_THAT(output_packets_,
              testing::ElementsAre(
                  PacketOfIntsEq(first_timestamp, expected_first),
                  PacketOfIntsEq(last_timestamp, expected_last)));
}
// Forwards every input packet unchanged. When Process() is invoked without an
// input packet (a timestamp bound update), it emits an empty vector at the
// input timestamp instead.
class PassThroughOrEmptyVectorCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    // Request Process() calls for timestamp bound updates as well.
    cc->SetProcessTimestampBounds(true);
    cc->Inputs().Index(0).Set<std::vector<int>>();
    cc->Outputs().Index(0).Set<std::vector<int>>();
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) override {
    const TimestampDiff zero_offset(0);
    cc->SetOffset(zero_offset);
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    const auto& input = cc->Inputs().Index(0);
    auto& output = cc->Outputs().Index(0);
    if (input.IsEmpty()) {
      // No packet at this timestamp: substitute an empty vector.
      output.AddPacket(MakePacket<std::vector<int>>(std::vector<int>())
                           .At(cc->InputTimestamp()));
      return absl::OkStatus();
    }
    output.AddPacket(input.Value());
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(PassThroughOrEmptyVectorCalculator);
// Fixture running the increment loop behind a GateCalculator, so that empty
// input vectors reach the loop only as timestamp bound updates. A
// PassThroughOrEmptyVectorCalculator after the loop turns those bound updates
// back into empty vectors, which are collected in output_packets_.
class BeginEndLoopCalculatorGraphProcessingEmptyPacketsTest
    : public ::testing::Test {
 protected:
  // Parses the graph config, attaches a vector sink to
  // "ints_plus_one_passed_through", and starts the graph run.
  void SetUp() override {
    CalculatorGraphConfig graph_config =
        ParseTextProtoOrDie<CalculatorGraphConfig>(
            R"pb(
num_threads: 4
input_stream: "ints"
input_stream: "force_ints_to_be_timestamp_bound_update"
node {
calculator: "GateCalculator"
input_stream: "ints"
input_stream: "DISALLOW:force_ints_to_be_timestamp_bound_update"
output_stream: "ints_passed_through"
}
node {
calculator: "BeginLoopIntegerCalculator"
input_stream: "ITERABLE:ints_passed_through"
output_stream: "ITEM:int"
output_stream: "BATCH_END:timestamp"
}
node {
calculator: "IncrementCalculator"
input_stream: "int"
output_stream: "int_plus_one"
}
node {
calculator: "EndLoopIntegersCalculator"
input_stream: "ITEM:int_plus_one"
input_stream: "BATCH_END:timestamp"
output_stream: "ITERABLE:ints_plus_one"
}
node {
calculator: "PassThroughOrEmptyVectorCalculator"
input_stream: "ints_plus_one"
output_stream: "ints_plus_one_passed_through"
}
)pb");
    tool::AddVectorSink("ints_plus_one_passed_through", &graph_config,
                        &output_packets_);
    MP_ASSERT_OK(graph_.Initialize(graph_config));
    MP_ASSERT_OK(graph_.StartRun({}));
  }

  // Sends `ints` at `timestamp`. If `ints` is empty, the gate is told to
  // disallow the packet, so the begin loop calculator sees only a timestamp
  // bound update at that timestamp.
  void SendPacketOfIntsOrBound(Timestamp timestamp, std::vector<int> ints) {
    // Must be read before `ints` is moved into the packet below.
    const bool disallow_ints = ints.empty();
    Packet gate_packet = MakePacket<bool>(disallow_ints).At(timestamp);
    MP_ASSERT_OK(graph_.AddPacketToInputStream(
        "force_ints_to_be_timestamp_bound_update", std::move(gate_packet)));

    Packet ints_packet =
        MakePacket<std::vector<int>>(std::move(ints)).At(timestamp);
    MP_ASSERT_OK(
        graph_.AddPacketToInputStream("ints", std::move(ints_packet)));
  }

  CalculatorGraph graph_;
  std::vector<Packet> output_packets_;
};
TEST_F(BeginEndLoopCalculatorGraphProcessingEmptyPacketsTest,
       SingleEmptyVector) {
  const Timestamp input_timestamp(0);
  SendPacketOfIntsOrBound(input_timestamp, std::vector<int>{});
  MP_ASSERT_OK(graph_.WaitUntilIdle());

  // The empty input comes out of the graph as an empty vector at the same
  // timestamp.
  const std::vector<int> expected_ints{};
  EXPECT_THAT(output_packets_, testing::ElementsAre(PacketOfIntsEq(
                                   input_timestamp, expected_ints)));

  MP_ASSERT_OK(graph_.CloseAllPacketSources());
  MP_ASSERT_OK(graph_.WaitUntilDone());
}
TEST_F(BeginEndLoopCalculatorGraphProcessingEmptyPacketsTest,
       SingleNonEmptyVector) {
  const Timestamp input_timestamp(0);
  SendPacketOfIntsOrBound(input_timestamp, std::vector<int>{0, 1, 2});
  MP_ASSERT_OK(graph_.WaitUntilIdle());

  // Every element is incremented and the result is emitted at the input
  // timestamp.
  const std::vector<int> expected_ints{1, 2, 3};
  EXPECT_THAT(output_packets_, testing::ElementsAre(PacketOfIntsEq(
                                   input_timestamp, expected_ints)));

  MP_ASSERT_OK(graph_.CloseAllPacketSources());
  MP_ASSERT_OK(graph_.WaitUntilDone());
}
TEST_F(BeginEndLoopCalculatorGraphProcessingEmptyPacketsTest, MultipleVectors) {
  // In this graph config a timestamp bound update happens only when the input
  // is an empty vector. After each empty vector the test waits until the
  // graph is idle, which guarantees that every bound update is processed
  // individually.
  SendPacketOfIntsOrBound(Timestamp(0), std::vector<int>{});
  MP_ASSERT_OK(graph_.WaitUntilIdle());

  SendPacketOfIntsOrBound(Timestamp(1), std::vector<int>{0, 1});
  SendPacketOfIntsOrBound(Timestamp(2), std::vector<int>{});
  MP_ASSERT_OK(graph_.WaitUntilIdle());

  SendPacketOfIntsOrBound(Timestamp(3), std::vector<int>{2, 3});
  SendPacketOfIntsOrBound(Timestamp(4), std::vector<int>{});
  MP_ASSERT_OK(graph_.WaitUntilIdle());

  MP_ASSERT_OK(graph_.CloseAllPacketSources());
  MP_ASSERT_OK(graph_.WaitUntilDone());

  // Empty inputs produce empty vectors; non-empty inputs are incremented.
  const std::vector<int> no_ints{};
  const std::vector<int> expected_at_1{1, 2};
  const std::vector<int> expected_at_3{3, 4};
  EXPECT_THAT(output_packets_,
              testing::ElementsAre(PacketOfIntsEq(Timestamp(0), no_ints),
                                   PacketOfIntsEq(Timestamp(1), expected_at_1),
                                   PacketOfIntsEq(Timestamp(2), no_ints),
                                   PacketOfIntsEq(Timestamp(3), expected_at_3),
                                   PacketOfIntsEq(Timestamp(4), no_ints)));
}
// Test calculator that outputs the product of its two integer inputs at the
// input timestamp.
class MultiplierCalculator : public CalculatorBase {
 public:
  // Declares two int input streams (value and multiplier) and one int output
  // stream.
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<int>();
    cc->Inputs().Index(1).Set<int>();
    cc->Outputs().Index(0).Set<int>();
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) override {
    const TimestampDiff zero_offset(0);
    cc->SetOffset(zero_offset);
    return absl::OkStatus();
  }

  // Emits input(0) * input(1) at the timestamp of the input packets.
  absl::Status Process(CalculatorContext* cc) override {
    const int& value = cc->Inputs().Index(0).Get<int>();
    const int& factor = cc->Inputs().Index(1).Get<int>();
    cc->Outputs().Index(0).Add(absl::make_unique<int>(value * factor).release(),
                               cc->InputTimestamp());
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(MultiplierCalculator);
// Fixture running a loop in which each element of the "ints" vector is
// multiplied by the value from the "multiplier" stream, which is passed into
// the loop through the CLONE input of the begin loop calculator. Output
// packets are collected in output_packets_.
class BeginEndLoopCalculatorGraphWithClonedInputsTest : public ::testing::Test {
 protected:
  // Parses the graph config, attaches a vector sink to "multiplied_ints", and
  // starts the graph run.
  void SetUp() override {
    CalculatorGraphConfig graph_config =
        ParseTextProtoOrDie<CalculatorGraphConfig>(
            R"pb(
num_threads: 4
input_stream: "ints"
input_stream: "multiplier"
node {
calculator: "BeginLoopIntegerCalculator"
input_stream: "ITERABLE:ints"
input_stream: "CLONE:multiplier"
output_stream: "ITEM:int_at_loop"
output_stream: "CLONE:multiplier_cloned_at_loop"
output_stream: "BATCH_END:timestamp"
}
node {
calculator: "MultiplierCalculator"
input_stream: "int_at_loop"
input_stream: "multiplier_cloned_at_loop"
output_stream: "multiplied_int_at_loop"
}
node {
calculator: "EndLoopIntegersCalculator"
input_stream: "ITEM:multiplied_int_at_loop"
input_stream: "BATCH_END:timestamp"
output_stream: "ITERABLE:multiplied_ints"
}
)pb");
    tool::AddVectorSink("multiplied_ints", &graph_config, &output_packets_);
    MP_ASSERT_OK(graph_.Initialize(graph_config));
    MP_ASSERT_OK(graph_.StartRun({}));
  }

  // Sends `ints` and then `multiplier`, both at `timestamp`.
  void SendPackets(Timestamp timestamp, int multiplier, std::vector<int> ints) {
    Packet ints_packet =
        MakePacket<std::vector<int>>(std::move(ints)).At(timestamp);
    MP_ASSERT_OK(
        graph_.AddPacketToInputStream("ints", std::move(ints_packet)));
    SendMultiplier(timestamp, multiplier);
  }

  // Sends only `multiplier` on the "multiplier" stream at `timestamp`.
  void SendMultiplier(Timestamp timestamp, int multiplier) {
    Packet multiplier_packet = MakePacket<int>(multiplier).At(timestamp);
    MP_ASSERT_OK(graph_.AddPacketToInputStream("multiplier",
                                               std::move(multiplier_packet)));
  }

  CalculatorGraph graph_;
  std::vector<Packet> output_packets_;
};
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest,
       InputStreamForIterableIsEmpty) {
  // Only the multiplier is sent; the iterable stream receives nothing.
  const Timestamp input_timestamp(42);
  SendMultiplier(input_timestamp, /*multiplier=*/2);
  MP_ASSERT_OK(graph_.WaitUntilIdle());

  // With no packets to process, EndLoopCalc only forwards the timestamp
  // bound, so no output packet is produced.
  const auto num_output_packets = output_packets_.size();
  ASSERT_EQ(0, num_output_packets);

  MP_ASSERT_OK(graph_.CloseAllPacketSources());
  MP_ASSERT_OK(graph_.WaitUntilDone());
}
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, SingleEmptyVector) {
  const Timestamp input_timestamp(0);
  SendPackets(input_timestamp, /*multiplier=*/2, /*ints=*/std::vector<int>{});
  MP_ASSERT_OK(graph_.WaitUntilIdle());

  // The collection has no elements to output, so EndLoopCalc only forwards
  // the timestamp bound and no packet is produced.
  EXPECT_TRUE(output_packets_.empty());

  MP_ASSERT_OK(graph_.CloseAllPacketSources());
  MP_ASSERT_OK(graph_.WaitUntilDone());
}
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, SingleNonEmptyVector) {
  const Timestamp input_timestamp(42);
  SendPackets(input_timestamp, /*multiplier=*/2,
              /*ints=*/std::vector<int>{0, 1, 2});
  MP_ASSERT_OK(graph_.WaitUntilIdle());

  // Every element is multiplied by 2 and the result is emitted at the input
  // timestamp.
  const std::vector<int> expected_ints{0, 2, 4};
  EXPECT_THAT(output_packets_, testing::ElementsAre(PacketOfIntsEq(
                                   input_timestamp, expected_ints)));

  MP_ASSERT_OK(graph_.CloseAllPacketSources());
  MP_ASSERT_OK(graph_.WaitUntilDone());
}
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, MultipleVectors) {
  const Timestamp first_timestamp(42);
  const Timestamp empty_timestamp(43);
  const Timestamp last_timestamp(44);

  SendPackets(first_timestamp, /*multiplier=*/2,
              /*ints=*/std::vector<int>{0, 1});
  SendPackets(empty_timestamp, /*multiplier=*/2, /*ints=*/std::vector<int>{});
  SendPackets(last_timestamp, /*multiplier=*/3,
              /*ints=*/std::vector<int>{2, 3});
  MP_ASSERT_OK(graph_.CloseAllPacketSources());
  MP_ASSERT_OK(graph_.WaitUntilDone());

  // The empty vector sent at empty_timestamp has no elements to process, so
  // EndLoopCalc only forwards a timestamp bound for it and no packet appears
  // at that timestamp.
  const std::vector<int> expected_first{0, 2};
  const std::vector<int> expected_last{6, 9};
  EXPECT_THAT(output_packets_,
              testing::ElementsAre(
                  PacketOfIntsEq(first_timestamp, expected_first),
                  PacketOfIntsEq(last_timestamp, expected_last)));
}
} // namespace
} // namespace mediapipe

View File

@ -1,45 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/begin_loop_calculator.h"
#include <vector>
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/rect.pb.h"
namespace mediapipe {
// Begin-loop calculator iterating over std::vector<NormalizedLandmarkList>.
using BeginLoopNormalizedLandmarkListVectorCalculator =
    BeginLoopCalculator<std::vector<::mediapipe::NormalizedLandmarkList>>;
REGISTER_CALCULATOR(BeginLoopNormalizedLandmarkListVectorCalculator);

// Begin-loop calculator iterating over std::vector<NormalizedRect>.
using BeginLoopNormalizedRectCalculator =
    BeginLoopCalculator<std::vector<::mediapipe::NormalizedRect>>;
REGISTER_CALCULATOR(BeginLoopNormalizedRectCalculator);

// Begin-loop calculator iterating over std::vector<Detection>.
using BeginLoopDetectionCalculator =
    BeginLoopCalculator<std::vector<::mediapipe::Detection>>;
REGISTER_CALCULATOR(BeginLoopDetectionCalculator);

// Begin-loop calculator iterating over std::vector<Matrix>.
using BeginLoopMatrixCalculator = BeginLoopCalculator<std::vector<Matrix>>;
REGISTER_CALCULATOR(BeginLoopMatrixCalculator);
} // namespace mediapipe

View File

@ -1,165 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_
#include "absl/memory/memory.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// Calculator for implementing loops on iterable collections inside a MediaPipe
// graph.
//
// It is designed to be used like:
//
// node {
// calculator: "BeginLoopWithIterableCalculator"
// input_stream: "ITERABLE:input_iterable" # IterableT @ext_ts
// output_stream: "ITEM:input_element" # ItemT @loop_internal_ts
// output_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts
// }
//
// node {
// calculator: "ElementToBlaConverterSubgraph"
// input_stream: "ITEM:input_to_loop_body" # ItemT @loop_internal_ts
// output_stream: "BLA:output_of_loop_body" # ItemU @loop_internal_ts
// }
//
// node {
// calculator: "EndLoopWithOutputCalculator"
// input_stream: "ITEM:output_of_loop_body" # ItemU @loop_internal_ts
// input_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts
// output_stream: "OUTPUT:aggregated_result" # IterableU @ext_ts
// }
//
// Input streams tagged with "CLONE" are cloned to the corresponding output
// streams at loop timestamps. This ensures that a MediaPipe graph or sub-graph
// can run multiple times, once per element in the "ITERABLE" for each packet
// clone of the packets in the "CLONE" input streams.
template <typename IterableT>
class BeginLoopCalculator : public CalculatorBase {
  using ItemT = typename IterableT::value_type;

 public:
  // Declares the calculator's streams: required "ITERABLE" input, required
  // "ITEM" and "BATCH_END" outputs, optional "TICK" input, and any number of
  // paired "CLONE" inputs/outputs.
  static absl::Status GetContract(CalculatorContract* cc) {
    // Have Process() invoked for timestamp bound updates too, not only for
    // real packets. The companion EndLoopCalculator depends on this for
    // correct timestamp propagation: Process() still runs when the upstream
    // calculator only advances the bound of ITERABLE without sending a value.
    cc->SetProcessTimestampBounds(true);

    // Optional "TICK" input: a non-empty packet on it wakes up the calculator.
    // DEPRECATED, because timestamp bound updates are already processed by
    // this calculator.
    if (cc->Inputs().HasTag("TICK")) {
      cc->Inputs().Tag("TICK").SetAny();
    }

    // The collection to iterate over.
    RET_CHECK(cc->Inputs().HasTag("ITERABLE"));
    cc->Inputs().Tag("ITERABLE").Set<IterableT>();

    // One element of the collection per output packet.
    RET_CHECK(cc->Outputs().HasTag("ITEM"));
    cc->Outputs().Tag("ITEM").Set<ItemT>();

    // Flush signal for the corresponding EndLoopCalculator: the packet carries
    // the timestamp at which the aggregated result must be emitted.
    RET_CHECK(cc->Outputs().HasTag("BATCH_END"));
    cc->Outputs().Tag("BATCH_END").Set<Timestamp>();

    // Every "CLONE" input is mirrored onto the "CLONE" output with the same
    // index, at loop timestamps, so both sides must have the same count.
    RET_CHECK(cc->Inputs().NumEntries("CLONE") ==
              cc->Outputs().NumEntries("CLONE"));
    const auto clone_count = cc->Inputs().NumEntries("CLONE");
    for (int i = 0; i < clone_count; ++i) {
      cc->Inputs().Get("CLONE", i).SetAny();
      cc->Outputs().Get("CLONE", i).SetSameAs(&cc->Inputs().Get("CLONE", i));
    }
    return absl::OkStatus();
  }

  // Emits one ITEM packet (plus CLONE copies) per element of the incoming
  // collection, each at its own internal timestamp, followed by a BATCH_END
  // packet that carries the external input timestamp.
  absl::Status Process(CalculatorContext* cc) final {
    // Remember where this batch starts so an empty batch can be detected.
    const Timestamp batch_start = loop_internal_timestamp_;

    if (!cc->Inputs().Tag("ITERABLE").IsEmpty()) {
      const IterableT& collection =
          cc->Inputs().Tag("ITERABLE").template Get<IterableT>();
      for (const auto& element : collection) {
        cc->Outputs().Tag("ITEM").AddPacket(
            MakePacket<ItemT>(element).At(loop_internal_timestamp_));
        ForwardClonePackets(cc, loop_internal_timestamp_);
        ++loop_internal_timestamp_;
      }
    }

    const bool nothing_emitted = (batch_start == loop_internal_timestamp_);
    if (nothing_emitted) {
      // No element was sent (missing or empty collection). The current
      // internal timestamp is still consumed, and all outputs are told not to
      // expect anything before the next one.
      ++loop_internal_timestamp_;
      for (auto it = cc->Outputs().begin(); it < cc->Outputs().end(); ++it) {
        it->SetNextTimestampBound(loop_internal_timestamp_);
      }
    }

    // loop_internal_timestamp_ already points one past the last timestamp
    // used above, so step back by one to place BATCH_END alongside the last
    // non-BATCH_END packet.
    const Timestamp batch_end_time = Timestamp(loop_internal_timestamp_ - 1);
    cc->Outputs()
        .Tag("BATCH_END")
        .AddPacket(
            MakePacket<Timestamp>(cc->InputTimestamp()).At(batch_end_time));
    return absl::OkStatus();
  }

 private:
  // Re-emits every non-empty "CLONE" input packet on the "CLONE" output with
  // the same index, re-stamped with `output_timestamp`.
  void ForwardClonePackets(CalculatorContext* cc, Timestamp output_timestamp) {
    const auto clone_count = cc->Inputs().NumEntries("CLONE");
    for (int i = 0; i < clone_count; ++i) {
      if (cc->Inputs().Get("CLONE", i).IsEmpty()) {
        continue;
      }
      auto clone_packet = cc->Inputs().Get("CLONE", i).Value();
      cc->Outputs()
          .Get("CLONE", i)
          .AddPacket(std::move(clone_packet).At(output_timestamp));
    }
  }

  // Fake timestamps generated per element in collection.
  Timestamp loop_internal_timestamp_ = Timestamp(0);
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_

View File

@ -1,26 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "mediapipe/calculators/core/clip_vector_size_calculator.h"
#include "mediapipe/framework/formats/detection.pb.h"
namespace mediapipe {
// Instantiates and registers the vector-clipping calculator for
// std::vector<::mediapipe::Detection>.
using ClipDetectionVectorSizeCalculator =
    ClipVectorSizeCalculator<::mediapipe::Detection>;
REGISTER_CALCULATOR(ClipDetectionVectorSizeCalculator);
} // namespace mediapipe

View File

@ -1,33 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/clip_vector_size_calculator.h"
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
namespace mediapipe {
// Vector-clipping calculator for std::vector<::mediapipe::NormalizedRect>.
using ClipNormalizedRectVectorSizeCalculator =
    ClipVectorSizeCalculator<::mediapipe::NormalizedRect>;
REGISTER_CALCULATOR(ClipNormalizedRectVectorSizeCalculator);

// Vector-clipping calculator for std::vector<::mediapipe::Detection>.
using ClipDetectionVectorSizeCalculator =
    ClipVectorSizeCalculator<::mediapipe::Detection>;
REGISTER_CALCULATOR(ClipDetectionVectorSizeCalculator);
} // namespace mediapipe

View File

@ -1,147 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_CORE_CLIP_VECTOR_SIZE_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_CLIP_VECTOR_SIZE_CALCULATOR_H_
#include <type_traits>
#include <vector>
#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// Clips the size of the input vector of type T to a specified max_vec_size.
// In a graph it will be used as:
// node {
// calculator: "ClipIntVectorSizeCalculator"
// input_stream: "input_vector"
// output_stream: "output_vector"
// options {
// [mediapipe.ClipVectorSizeCalculatorOptions.ext] {
// max_vec_size: 5
// }
// }
// }
// Optionally, you can pass in a side packet that will override `max_vec_size`
// that is specified in the options.
template <typename T>
class ClipVectorSizeCalculator : public CalculatorBase {
 public:
  // Requires exactly one input and one output stream, both carrying
  // std::vector<T>, and an optional int side packet at index 0. Fails if the
  // `max_vec_size` option is below 1.
  static absl::Status GetContract(CalculatorContract* cc) {
    RET_CHECK(cc->Inputs().NumEntries() == 1);
    RET_CHECK(cc->Outputs().NumEntries() == 1);

    if (cc->Options<::mediapipe::ClipVectorSizeCalculatorOptions>()
            .max_vec_size() < 1) {
      return absl::InternalError(
          "max_vec_size should be greater than or equal to 1.");
    }

    cc->Inputs().Index(0).Set<std::vector<T>>();
    cc->Outputs().Index(0).Set<std::vector<T>>();
    // Optional input side packet that determines `max_vec_size`.
    if (cc->InputSidePackets().NumEntries() > 0) {
      cc->InputSidePackets().Index(0).Set<int>();
    }

    return absl::OkStatus();
  }

  // Reads the size limit from the options, letting a non-empty side packet
  // override it.
  absl::Status Open(CalculatorContext* cc) override {
    cc->SetOffset(TimestampDiff(0));
    max_vec_size_ = cc->Options<::mediapipe::ClipVectorSizeCalculatorOptions>()
                        .max_vec_size();
    // Override `max_vec_size` if passed as side packet.
    if (cc->InputSidePackets().NumEntries() > 0 &&
        !cc->InputSidePackets().Index(0).IsEmpty()) {
      max_vec_size_ = cc->InputSidePackets().Index(0).Get<int>();
    }
    return absl::OkStatus();
  }

  // Emits the clipped vector at the input timestamp. Emits nothing when the
  // input packet is empty. Returns an error when the effective limit is below
  // 1 (GetContract validates the option but not a side-packet override).
  absl::Status Process(CalculatorContext* cc) override {
    if (max_vec_size_ < 1) {
      return absl::InternalError(
          "max_vec_size should be greater than or equal to 1.");
    }
    if (cc->Inputs().Index(0).IsEmpty()) {
      return absl::OkStatus();
    }

    // Dispatch on copyability: copy when possible, otherwise try to take
    // ownership of the input vector and move its elements.
    return ClipVectorSize<T>(std::is_copy_constructible<T>(), cc);
  }

  // Copy path: outputs a copy of the first ClippedSize() input elements.
  template <typename U>
  absl::Status ClipVectorSize(std::true_type, CalculatorContext* cc) {
    const std::vector<U>& input_vector =
        cc->Inputs().Index(0).Get<std::vector<U>>();
    const SizeType kept = ClippedSize(input_vector.size());
    // Range construction allocates once instead of growing per push_back.
    auto output = absl::make_unique<std::vector<U>>(
        input_vector.begin(), input_vector.begin() + kept);
    cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
    return absl::OkStatus();
  }

  // Non-copyable element type: fall through to the consume/move path.
  template <typename U>
  absl::Status ClipVectorSize(std::false_type, CalculatorContext* cc) {
    return ConsumeAndClipVectorSize<U>(std::is_move_constructible<U>(), cc);
  }

  // Move path: consumes the input packet and moves its first ClippedSize()
  // elements into the output. Returns the Consume() error status if the
  // packet's payload cannot be consumed.
  template <typename U>
  absl::Status ConsumeAndClipVectorSize(std::true_type, CalculatorContext* cc) {
    absl::StatusOr<std::unique_ptr<std::vector<U>>> input_status =
        cc->Inputs().Index(0).Value().Consume<std::vector<U>>();
    if (!input_status.ok()) {
      return input_status.status();
    }
    std::unique_ptr<std::vector<U>> input_vector =
        std::move(input_status).value();
    const SizeType kept = ClippedSize(input_vector->size());
    auto output = absl::make_unique<std::vector<U>>(
        std::make_move_iterator(input_vector->begin()),
        std::make_move_iterator(input_vector->begin() + kept));
    cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
    return absl::OkStatus();
  }

  // Element type can be neither copied nor moved: nothing can be done.
  template <typename U>
  absl::Status ConsumeAndClipVectorSize(std::false_type,
                                        CalculatorContext* cc) {
    return absl::InternalError(
        "Cannot copy or move input vectors and clip their size.");
  }

 private:
  using SizeType = typename std::vector<T>::size_type;

  // Number of elements to keep from an input holding `input_size` elements.
  // Only called from the Process() paths, after max_vec_size_ >= 1 has been
  // verified, so the conversion to an unsigned type cannot wrap.
  SizeType ClippedSize(SizeType input_size) const {
    const SizeType limit = static_cast<SizeType>(max_vec_size_);
    return input_size < limit ? input_size : limit;
  }

  // Effective maximum output length; set in Open().
  int max_vec_size_ = 0;
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_CORE_CLIP_VECTOR_SIZE_CALCULATOR_H_

View File

@ -1,30 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
option objc_class_prefix = "MediaPipe";
// Options read by ClipVectorSizeCalculator (see
// mediapipe/calculators/core/clip_vector_size_calculator.h).
message ClipVectorSizeCalculatorOptions {
// Makes these options settable on a node through
// `[mediapipe.ClipVectorSizeCalculatorOptions.ext] { ... }`.
extend CalculatorOptions {
optional ClipVectorSizeCalculatorOptions ext = 274674998;
}
// Maximum size of output vector.
// ClipVectorSizeCalculator rejects values below 1 with an error; an input
// side packet, when provided, overrides this value.
optional int32 max_vec_size = 1 [default = 1];
}

View File

@ -1,206 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/clip_vector_size_calculator.h"
#include <memory>
#include <string>
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
namespace mediapipe {
// int instantiation of the calculator under test, registered under a
// test-only name so graph/node configs below can refer to it.
using TestClipIntVectorSizeCalculator = ClipVectorSizeCalculator<int>;
REGISTER_CALCULATOR(TestClipIntVectorSizeCalculator);
void AddInputVector(const std::vector<int>& input, int64 timestamp,
CalculatorRunner* runner) {
runner->MutableInputs()->Index(0).packets.push_back(
MakePacket<std::vector<int>>(input).At(Timestamp(timestamp)));
}
// An empty input vector must still produce exactly one (empty) output vector
// at the input timestamp.
TEST(TestClipIntVectorSizeCalculatorTest, EmptyVectorInput) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "TestClipIntVectorSizeCalculator"
input_stream: "input_vector"
output_stream: "output_vector"
options {
[mediapipe.ClipVectorSizeCalculatorOptions.ext] { max_vec_size: 1 }
}
)pb");
  CalculatorRunner runner(node_config);

  std::vector<int> input = {};
  AddInputVector(input, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  // Fatal: outputs[0] is read below, so stop here if the count is wrong.
  ASSERT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  EXPECT_TRUE(outputs[0].Get<std::vector<int>>().empty());
}
// A four-element input with max_vec_size 2 must be clipped to its first two
// elements.
TEST(TestClipIntVectorSizeCalculatorTest, OneTimestamp) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "TestClipIntVectorSizeCalculator"
input_stream: "input_vector"
output_stream: "output_vector"
options {
[mediapipe.ClipVectorSizeCalculatorOptions.ext] { max_vec_size: 2 }
}
)pb");
  CalculatorRunner runner(node_config);

  std::vector<int> input = {0, 1, 2, 3};
  AddInputVector(input, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  // Fatal: outputs[0] is read below, so stop here if the count is wrong.
  ASSERT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  const std::vector<int>& output = outputs[0].Get<std::vector<int>>();
  EXPECT_EQ(2, output.size());
  std::vector<int> expected_vector = {0, 1};
  EXPECT_EQ(expected_vector, output);
}
// Two inputs at consecutive timestamps must each be clipped independently to
// max_vec_size 3 and emitted at their own timestamps.
TEST(TestClipIntVectorSizeCalculatorTest, TwoInputsAtTwoTimestamps) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "TestClipIntVectorSizeCalculator"
input_stream: "input_vector"
output_stream: "output_vector"
options {
[mediapipe.ClipVectorSizeCalculatorOptions.ext] { max_vec_size: 3 }
}
)pb");
  CalculatorRunner runner(node_config);

  {
    std::vector<int> input = {0, 1, 2, 3};
    AddInputVector(input, /*timestamp=*/1, &runner);
  }
  {
    std::vector<int> input = {2, 3, 4, 5};
    AddInputVector(input, /*timestamp=*/2, &runner);
  }
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  // Fatal: outputs[0] and outputs[1] are read below.
  ASSERT_EQ(2, outputs.size());
  {
    EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
    const std::vector<int>& output = outputs[0].Get<std::vector<int>>();
    EXPECT_EQ(3, output.size());
    std::vector<int> expected_vector = {0, 1, 2};
    EXPECT_EQ(expected_vector, output);
  }
  {
    EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
    const std::vector<int>& output = outputs[1].Get<std::vector<int>>();
    EXPECT_EQ(3, output.size());
    std::vector<int> expected_vector = {2, 3, 4};
    EXPECT_EQ(expected_vector, output);
  }
}
// Move-only instantiation of the calculator under test; std::unique_ptr<int>
// is not copy-constructible, so this exercises the consume-and-move path.
using TestClipUniqueIntPtrVectorSizeCalculator =
    ClipVectorSizeCalculator<std::unique_ptr<int>>;
REGISTER_CALCULATOR(TestClipUniqueIntPtrVectorSizeCalculator);
// A vector of six move-only elements with max_vec_size 3 must come out as its
// first three elements, moved rather than copied.
TEST(TestClipUniqueIntPtrVectorSizeCalculatorTest, ConsumeOneTimestamp) {
  /* Note: We don't use CalculatorRunner for this test because it keeps copies
   * of input packets, so packets sent to the graph don't have sole ownership.
   * The test needs to send packets that own the data.
   */
  CalculatorGraphConfig graph_config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input_vector"
node {
calculator: "TestClipUniqueIntPtrVectorSizeCalculator"
input_stream: "input_vector"
output_stream: "output_vector"
options {
[mediapipe.ClipVectorSizeCalculatorOptions.ext] { max_vec_size: 3 }
}
}
)pb");

  std::vector<Packet> outputs;
  tool::AddVectorSink("output_vector", &graph_config, &outputs);

  CalculatorGraph graph;
  // Fatal: nothing below is meaningful if the graph failed to start.
  MP_ASSERT_OK(graph.Initialize(graph_config));
  MP_ASSERT_OK(graph.StartRun({}));

  // input1 : {0, 1, 2, 3, 4, 5}
  auto input_vector = absl::make_unique<std::vector<std::unique_ptr<int>>>(6);
  for (int i = 0; i < 6; ++i) {
    input_vector->at(i) = absl::make_unique<int>(i);
  }

  MP_EXPECT_OK(graph.AddPacketToInputStream(
      "input_vector", Adopt(input_vector.release()).At(Timestamp(1))));

  MP_EXPECT_OK(graph.WaitUntilIdle());
  MP_EXPECT_OK(graph.CloseAllPacketSources());
  MP_EXPECT_OK(graph.WaitUntilDone());

  // Fatal: outputs[0] is read below.
  ASSERT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  const std::vector<std::unique_ptr<int>>& result =
      outputs[0].Get<std::vector<std::unique_ptr<int>>>();
  // Fatal: result[0..2] are read below.
  ASSERT_EQ(3, result.size());
  for (int i = 0; i < 3; ++i) {
    const std::unique_ptr<int>& v = result[i];
    // Guard the dereference so a null element fails the test cleanly.
    ASSERT_TRUE(v != nullptr);
    EXPECT_EQ(i, *v);
  }
}
// A max_vec_size supplied through the input side packet must take precedence
// over the value in the node options.
TEST(TestClipIntVectorSizeCalculatorTest, SidePacket) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "TestClipIntVectorSizeCalculator"
input_stream: "input_vector"
input_side_packet: "max_vec_size"
output_stream: "output_vector"
options {
[mediapipe.ClipVectorSizeCalculatorOptions.ext] { max_vec_size: 1 }
}
)pb");
  CalculatorRunner runner(node_config);
  // This should override the default of 1 set in the options.
  runner.MutableSidePackets()->Index(0) = Adopt(new int(2));

  std::vector<int> input = {0, 1, 2, 3};
  AddInputVector(input, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  // Fatal: outputs[0] is read below, so stop here if the count is wrong.
  ASSERT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  const std::vector<int>& output = outputs[0].Get<std::vector<int>>();
  EXPECT_EQ(2, output.size());
  std::vector<int> expected_vector = {0, 1};
  EXPECT_EQ(expected_vector, output);
}
} // namespace mediapipe

View File

@ -1,35 +0,0 @@
// Copyright 2019-2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "mediapipe/calculators/core/concatenate_vector_calculator.h"
#include "mediapipe/framework/formats/detection.pb.h"
namespace mediapipe {
// Registers the vector-concatenation calculator for
// std::vector<::mediapipe::Detection>.
//
// Example config:
//
// node {
//   calculator: "ConcatenateDetectionVectorCalculator"
//   input_stream: "detection_vector_1"
//   input_stream: "detection_vector_2"
//   output_stream: "concatenated_detection_vector"
// }
//
using ConcatenateDetectionVectorCalculator =
    ConcatenateVectorCalculator<::mediapipe::Detection>;
MEDIAPIPE_REGISTER_NODE(ConcatenateDetectionVectorCalculator);
} // namespace mediapipe

View File

@ -1,79 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_NORMALIZED_LIST_CALCULATOR_H_ // NOLINT
#define MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_NORMALIZED_LIST_CALCULATOR_H_ // NOLINT
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace api2 {
// Concatenates several NormalizedLandmarkList protos following stream index
// order. This class assumes that every input stream contains a
// NormalizedLandmarkList proto object.
//
// When the `only_emit_if_all_present` option of
// ConcatenateVectorCalculatorOptions is set, nothing is emitted for a
// timestamp at which any input stream is empty; otherwise empty inputs are
// simply skipped.
class ConcatenateNormalizedLandmarkListCalculator : public Node {
 public:
  static constexpr Input<NormalizedLandmarkList>::Multiple kIn{""};
  static constexpr Output<NormalizedLandmarkList> kOut{""};

  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);

  // Requires at least one input stream.
  static absl::Status UpdateContract(CalculatorContract* cc) {
    RET_CHECK_GE(kIn(cc).Count(), 1);
    return absl::OkStatus();
  }

  // Caches the `only_emit_if_all_present` option.
  absl::Status Open(CalculatorContext* cc) override {
    only_emit_if_all_present_ =
        cc->Options<::mediapipe::ConcatenateVectorCalculatorOptions>()
            .only_emit_if_all_present();
    return absl::OkStatus();
  }

  // Appends the landmarks of every non-empty input, in stream index order,
  // into a single list and sends it.
  absl::Status Process(CalculatorContext* cc) override {
    if (only_emit_if_all_present_) {
      for (const auto& input : kIn(cc)) {
        if (input.IsEmpty()) return absl::OkStatus();
      }
    }

    NormalizedLandmarkList output;
    for (const auto& input : kIn(cc)) {
      if (input.IsEmpty()) continue;
      const NormalizedLandmarkList& list = *input;
      for (const auto& landmark : list.landmark()) {
        *output.add_landmark() = landmark;
      }
    }
    kOut(cc).Send(std::move(output));
    return absl::OkStatus();
  }

 private:
  // Overwritten in Open(); initialized so it never holds an indeterminate
  // value.
  bool only_emit_if_all_present_ = false;
};
MEDIAPIPE_REGISTER_NODE(ConcatenateNormalizedLandmarkListCalculator);
} // namespace api2
} // namespace mediapipe
// NOLINTNEXTLINE
#endif // MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_NORMALIZED_LIST_CALCULATOR_H_

View File

@ -1,184 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
namespace mediapipe {
constexpr float kLocationValue = 3;

// Builds a list of `landmarks_size` landmarks whose x, y and z are all
// `value_multiplier * kLocationValue`.
NormalizedLandmarkList GenerateLandmarks(int landmarks_size,
                                         int value_multiplier) {
  NormalizedLandmarkList generated;
  int remaining = landmarks_size;
  while (remaining > 0) {
    NormalizedLandmark* point = generated.add_landmark();
    point->set_x(value_multiplier * kLocationValue);
    point->set_y(value_multiplier * kLocationValue);
    point->set_z(value_multiplier * kLocationValue);
    --remaining;
  }
  return generated;
}
void ValidateCombinedLandmarks(
const std::vector<NormalizedLandmarkList>& inputs,
const NormalizedLandmarkList& result) {
int element_id = 0;
int expected_size = 0;
for (int i = 0; i < inputs.size(); ++i) {
const NormalizedLandmarkList& landmarks_i = inputs[i];
expected_size += landmarks_i.landmark_size();
for (int j = 0; j < landmarks_i.landmark_size(); ++j) {
const NormalizedLandmark& expected = landmarks_i.landmark(j);
const NormalizedLandmark& got = result.landmark(element_id);
EXPECT_FLOAT_EQ(expected.x(), got.x());
EXPECT_FLOAT_EQ(expected.y(), got.y());
EXPECT_FLOAT_EQ(expected.z(), got.z());
++element_id;
}
}
EXPECT_EQ(expected_size, result.landmark_size());
}
// Queues input_landmarks_vec[i] as a packet at `timestamp` on the runner's
// input stream i, for every i.
void AddInputLandmarkLists(
    const std::vector<NormalizedLandmarkList>& input_landmarks_vec,
    int64 timestamp, CalculatorRunner* runner) {
  // Convert once so the loop compares int with int rather than int with the
  // vector's unsigned size type.
  const int num_lists = static_cast<int>(input_landmarks_vec.size());
  for (int i = 0; i < num_lists; ++i) {
    runner->MutableInputs()->Index(i).packets.push_back(
        MakePacket<NormalizedLandmarkList>(input_landmarks_vec[i])
            .At(Timestamp(timestamp)));
  }
}
// Three empty landmark lists must concatenate into one empty list at the
// input timestamp.
TEST(ConcatenateNormalizedLandmarkListCalculatorTest, EmptyVectorInputs) {
  CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator",
                          /*options_string=*/"", /*num_inputs=*/3,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  NormalizedLandmarkList empty_list;
  std::vector<NormalizedLandmarkList> inputs = {empty_list, empty_list,
                                                empty_list};
  AddInputLandmarkLists(inputs, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  // Fatal: outputs[0] is read below, so stop here if the count is wrong.
  ASSERT_EQ(1, outputs.size());
  EXPECT_EQ(0, outputs[0].Get<NormalizedLandmarkList>().landmark_size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
}
// Three non-empty lists at one timestamp must be concatenated in stream
// index order.
TEST(ConcatenateNormalizedLandmarkListCalculatorTest, OneTimestamp) {
  CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator",
                          /*options_string=*/"", /*num_inputs=*/3,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  NormalizedLandmarkList input_0 =
      GenerateLandmarks(/*landmarks_size=*/3, /*value_multiplier=*/0);
  NormalizedLandmarkList input_1 =
      GenerateLandmarks(/*landmarks_size=*/1, /*value_multiplier=*/1);
  NormalizedLandmarkList input_2 =
      GenerateLandmarks(/*landmarks_size=*/2, /*value_multiplier=*/2);
  std::vector<NormalizedLandmarkList> inputs = {input_0, input_1, input_2};
  AddInputLandmarkLists(inputs, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  // Fatal: outputs[0] is read below, so stop here if the count is wrong.
  ASSERT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  const NormalizedLandmarkList& result =
      outputs[0].Get<NormalizedLandmarkList>();
  ValidateCombinedLandmarks(inputs, result);
}
// The same three lists sent at two timestamps must yield one concatenated
// list per timestamp.
TEST(ConcatenateNormalizedLandmarkListCalculatorTest,
     TwoInputsAtTwoTimestamps) {
  CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator",
                          /*options_string=*/"", /*num_inputs=*/3,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  NormalizedLandmarkList input_0 =
      GenerateLandmarks(/*landmarks_size=*/3, /*value_multiplier=*/0);
  NormalizedLandmarkList input_1 =
      GenerateLandmarks(/*landmarks_size=*/1, /*value_multiplier=*/1);
  NormalizedLandmarkList input_2 =
      GenerateLandmarks(/*landmarks_size=*/2, /*value_multiplier=*/2);
  std::vector<NormalizedLandmarkList> inputs = {input_0, input_1, input_2};
  { AddInputLandmarkLists(inputs, /*timestamp=*/1, &runner); }
  { AddInputLandmarkLists(inputs, /*timestamp=*/2, &runner); }
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  // Fatal: outputs[0] and outputs[1] are read below.
  ASSERT_EQ(2, outputs.size());
  {
    EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
    const NormalizedLandmarkList& result =
        outputs[0].Get<NormalizedLandmarkList>();
    ValidateCombinedLandmarks(inputs, result);
  }
  {
    EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
    const NormalizedLandmarkList& result =
        outputs[1].Get<NormalizedLandmarkList>();
    ValidateCombinedLandmarks(inputs, result);
  }
}
// With default options, a packet on only one of the two input streams must
// still produce an output containing that stream's landmarks.
TEST(ConcatenateNormalizedLandmarkListCalculatorTest,
     OneEmptyStreamStillOutput) {
  CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator",
                          /*options_string=*/"", /*num_inputs=*/2,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  NormalizedLandmarkList input_0 =
      GenerateLandmarks(/*landmarks_size=*/3, /*value_multiplier=*/0);
  // Only stream 0 receives a packet; stream 1 stays empty.
  std::vector<NormalizedLandmarkList> inputs = {input_0};
  AddInputLandmarkLists(inputs, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  // Fatal: outputs[0] is read below, so stop here if the count is wrong.
  ASSERT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  const NormalizedLandmarkList& result =
      outputs[0].Get<NormalizedLandmarkList>();
  ValidateCombinedLandmarks(inputs, result);
}
// With only_emit_if_all_present set, a packet on just one of the two input
// streams must produce no output at all.
TEST(ConcatenateNormalizedLandmarkListCalculatorTest, OneEmptyStreamNoOutput) {
  const std::string emit_only_when_complete =
      "[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
      "{only_emit_if_all_present: true}";
  CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator",
                          /*options_string=*/emit_only_when_complete,
                          /*num_inputs=*/2,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  // Feed stream 0 only; stream 1 is left without a packet.
  std::vector<NormalizedLandmarkList> fed_lists = {
      GenerateLandmarks(/*landmarks_size=*/3, /*value_multiplier=*/0)};
  AddInputLandmarkLists(fed_lists, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& emitted = runner.Outputs().Index(0).packets;
  EXPECT_EQ(0, emitted.size());
}
} // namespace mediapipe

View File

@ -1,94 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/concatenate_vector_calculator.h"
#include <vector>
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/util/render_data.pb.h"
#include "tensorflow/lite/interpreter.h"
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
namespace mediapipe {
// Each alias below instantiates ConcatenateVectorCalculator for one element
// type and registers it under the alias name.

// Example config:
// node {
//   calculator: "ConcatenateFloatVectorCalculator"
//   input_stream: "float_vector_1"
//   input_stream: "float_vector_2"
//   output_stream: "concatenated_float_vector"
// }
using ConcatenateFloatVectorCalculator = ConcatenateVectorCalculator<float>;
MEDIAPIPE_REGISTER_NODE(ConcatenateFloatVectorCalculator);

// Example config:
// node {
//   calculator: "ConcatenateInt32VectorCalculator"
//   input_stream: "int32_vector_1"
//   input_stream: "int32_vector_2"
//   output_stream: "concatenated_int32_vector"
// }
using ConcatenateInt32VectorCalculator = ConcatenateVectorCalculator<int32>;
MEDIAPIPE_REGISTER_NODE(ConcatenateInt32VectorCalculator);

using ConcatenateUInt64VectorCalculator = ConcatenateVectorCalculator<uint64>;
MEDIAPIPE_REGISTER_NODE(ConcatenateUInt64VectorCalculator);

using ConcatenateBoolVectorCalculator = ConcatenateVectorCalculator<bool>;
MEDIAPIPE_REGISTER_NODE(ConcatenateBoolVectorCalculator);

// Example config:
// node {
//   calculator: "ConcatenateTfLiteTensorVectorCalculator"
//   input_stream: "tflitetensor_vector_1"
//   input_stream: "tflitetensor_vector_2"
//   output_stream: "concatenated_tflitetensor_vector"
// }
using ConcatenateTfLiteTensorVectorCalculator =
    ConcatenateVectorCalculator<TfLiteTensor>;
MEDIAPIPE_REGISTER_NODE(ConcatenateTfLiteTensorVectorCalculator);

using ConcatenateTensorVectorCalculator = ConcatenateVectorCalculator<Tensor>;
MEDIAPIPE_REGISTER_NODE(ConcatenateTensorVectorCalculator);

using ConcatenateLandmarkVectorCalculator =
    ConcatenateVectorCalculator<::mediapipe::NormalizedLandmark>;
MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkVectorCalculator);

// NOTE(review): "Landmar" (no 'k') is the spelling of the registered name;
// it is kept as-is because renaming it would presumably break any config
// that refers to this calculator.
using ConcatenateLandmarListVectorCalculator =
    ConcatenateVectorCalculator<::mediapipe::NormalizedLandmarkList>;
MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarListVectorCalculator);

using ConcatenateClassificationListVectorCalculator =
    ConcatenateVectorCalculator<mediapipe::ClassificationList>;
MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListVectorCalculator);

// GlBuffer vectors are only available when GL compute is enabled.
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
using ConcatenateGlBufferVectorCalculator =
    ConcatenateVectorCalculator<::tflite::gpu::gl::GlBuffer>;
MEDIAPIPE_REGISTER_NODE(ConcatenateGlBufferVectorCalculator);
#endif

using ConcatenateRenderDataVectorCalculator =
    ConcatenateVectorCalculator<mediapipe::RenderData>;
MEDIAPIPE_REGISTER_NODE(ConcatenateRenderDataVectorCalculator);
} // namespace mediapipe

View File

@ -1,122 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_
#include <string>
#include <type_traits>
#include <vector>
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// Note: since this is a calculator template that can be included by other
// source files, we do not place this in namespace api2 directly, but qualify
// the api2 names below, to avoid changing the visible name of the class.
// We cannot simply write "using mediapipe::api2" since it's a header file.
// This distinction will go away once api2 is finalized.
// Concatenates several objects of type T or std::vector<T> following stream
// index order. This class assumes that every input stream contains either T or
// vector<T> type. To use this class for a particular type T, register a
// calculator using ConcatenateVectorCalculator<T>.
template <typename T>
class ConcatenateVectorCalculator : public api2::Node {
 public:
  // Repeated input port declared with an empty tag; each connected stream may
  // carry either a single T or a std::vector<T>.
  static constexpr
      typename api2::Input<api2::OneOf<T, std::vector<T>>>::Multiple kIn{""};
  // Output port (empty tag) carrying the concatenated std::vector<T>.
  static constexpr api2::Output<std::vector<T>> kOut{""};

  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);

  // Rejects node configurations that connect no input stream at all.
  static absl::Status UpdateContract(CalculatorContract* cc) {
    RET_CHECK_GE(kIn(cc).Count(), 1);
    return absl::OkStatus();
  }

  // Caches the `only_emit_if_all_present` option for use in Process().
  absl::Status Open(CalculatorContext* cc) override {
    only_emit_if_all_present_ =
        cc->Options<::mediapipe::ConcatenateVectorCalculatorOptions>()
            .only_emit_if_all_present();
    return absl::OkStatus();
  }

  // Concatenates the current packets of all input streams, in stream order.
  // With `only_emit_if_all_present` set, returns OK without emitting anything
  // as soon as one input is empty. Otherwise the work is tag-dispatched on
  // whether T is copy-constructible (copy path) or not (consume/move path).
  absl::Status Process(CalculatorContext* cc) override {
    if (only_emit_if_all_present_) {
      for (const auto& input : kIn(cc)) {
        if (input.IsEmpty()) return ::absl::OkStatus();
      }
    }
    return ConcatenateVectors<T>(std::is_copy_constructible<T>(), cc);
  }

  // Copy path (T is copy-constructible): appends a copy of every non-empty
  // input, element or vector, to the output and sends it.
  template <typename U>
  absl::Status ConcatenateVectors(std::true_type, CalculatorContext* cc) {
    auto output = std::vector<U>();
    for (const auto& input : kIn(cc)) {
      if (input.IsEmpty()) continue;
      input.Visit([&output](const U& value) { output.push_back(value); },
                  [&output](const std::vector<U>& value) {
                    output.insert(output.end(), value.begin(), value.end());
                  });
    }
    kOut(cc).Send(std::move(output));
    return absl::OkStatus();
  }

  // Non-copyable T: defers to the consume path, dispatched on whether the
  // element type is at least move-constructible.
  template <typename U>
  absl::Status ConcatenateVectors(std::false_type, CalculatorContext* cc) {
    return ConsumeAndConcatenateVectors<T>(std::is_move_constructible<U>(), cc);
  }

  // Consume path (T is move-only): takes ownership of each non-empty input
  // payload and moves its element(s) into the output. Any error returned by
  // ConsumeAndVisit is propagated to the caller.
  template <typename U>
  absl::Status ConsumeAndConcatenateVectors(std::true_type,
                                            CalculatorContext* cc) {
    auto output = std::vector<U>();
    // NOTE(review): iterated by value, unlike the const-ref loop in the copy
    // path; presumably ConsumeAndVisit needs a non-const handle -- confirm.
    for (auto input : kIn(cc)) {
      if (input.IsEmpty()) continue;
      MP_RETURN_IF_ERROR(input.ConsumeAndVisit(
          [&output](std::unique_ptr<U> value) {
            output.push_back(std::move(*value));
          },
          [&output](std::unique_ptr<std::vector<U>> value) {
            output.insert(output.end(), std::make_move_iterator(value->begin()),
                          std::make_move_iterator(value->end()));
          }));
    }
    kOut(cc).Send(std::move(output));
    return absl::OkStatus();
  }

  // T can be neither copied nor moved: concatenation is impossible, so this
  // reports an internal error. The context argument is intentionally unused.
  template <typename U>
  absl::Status ConsumeAndConcatenateVectors(std::false_type,
                                            CalculatorContext* /*cc*/) {
    return absl::InternalError(
        "Cannot copy or move inputs to concatenate them");
  }

 private:
  // Mirrors ConcatenateVectorCalculatorOptions.only_emit_if_all_present.
  // Initialized to the option's default so the flag never holds an
  // indeterminate value before Open() assigns it.
  bool only_emit_if_all_present_ = false;
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_

View File

@ -1,31 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
option objc_class_prefix = "MediaPipe";
// Options read by ConcatenateVectorCalculator<T> instantiations in Open().
message ConcatenateVectorCalculatorOptions {
// Makes these options settable on a node as
// [mediapipe.ConcatenateVectorCalculatorOptions.ext].
extend CalculatorOptions {
optional ConcatenateVectorCalculatorOptions ext = 259397839;
}
// If true, the calculator will only emit a packet at the given timestamp if
// all input streams have a non-empty packet (AND operation on streams).
// If false (the default), empty inputs are skipped and the remaining inputs
// are still concatenated and emitted.
optional bool only_emit_if_all_present = 1 [default = false];
}

View File

@ -1,548 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/concatenate_vector_calculator.h"
#include <memory>
#include <string>
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
namespace mediapipe {
// Test-only instantiation over plain int. It is registered so the tests can
// create it through CalculatorRunner by the name
// "TestConcatenateIntVectorCalculator".
typedef ConcatenateVectorCalculator<int> TestConcatenateIntVectorCalculator;
MEDIAPIPE_REGISTER_NODE(TestConcatenateIntVectorCalculator);
// Queues `input` as a std::vector<int> packet, stamped with `timestamp`, on
// the runner's input stream number `index`.
void AddInputVector(int index, const std::vector<int>& input, int64 timestamp,
                    CalculatorRunner* runner) {
  auto& stream_packets = runner->MutableInputs()->Index(index).packets;
  stream_packets.push_back(
      MakePacket<std::vector<int>>(input).At(Timestamp(timestamp)));
}
// Queues inputs[i] as a vector packet on input stream i, for every i, all at
// the same `timestamp`.
void AddInputVectors(const std::vector<std::vector<int>>& inputs,
                     int64 timestamp, CalculatorRunner* runner) {
  // Stream indices are ints (see AddInputVector), so convert the size once
  // instead of comparing a signed counter with an unsigned size() each pass.
  const int num_streams = static_cast<int>(inputs.size());
  for (int i = 0; i < num_streams; ++i) {
    AddInputVector(i, inputs[i], timestamp, runner);
  }
}
// Queues the single int `input` as a packet, stamped with `timestamp`, on the
// runner's input stream number `index`.
void AddInputItem(int index, int input, int64 timestamp,
                  CalculatorRunner* runner) {
  auto& stream_packets = runner->MutableInputs()->Index(index).packets;
  stream_packets.push_back(MakePacket<int>(input).At(Timestamp(timestamp)));
}
// Queues inputs[i] as a single-int packet on input stream i, for every i, all
// at the same `timestamp`.
void AddInputItems(const std::vector<int>& inputs, int64 timestamp,
                   CalculatorRunner* runner) {
  // Stream indices are ints (see AddInputItem), so convert the size once
  // instead of comparing a signed counter with an unsigned size() each pass.
  const int num_streams = static_cast<int>(inputs.size());
  for (int i = 0; i < num_streams; ++i) {
    AddInputItem(i, inputs[i], timestamp, runner);
  }
}
// Three streams that each carry an empty vector must still produce exactly
// one (empty) output vector at the same timestamp.
TEST(TestConcatenateIntVectorCalculatorTest, EmptyVectorInputs) {
  CalculatorRunner runner("TestConcatenateIntVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/3,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  std::vector<std::vector<int>> inputs = {{}, {}, {}};
  AddInputVectors(inputs, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  // Fatal on purpose: outputs[0] is indexed below and would be out of bounds
  // if no packet had been produced.
  ASSERT_EQ(1, outputs.size());
  EXPECT_TRUE(outputs[0].Get<std::vector<int>>().empty());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
}
// Vectors arriving on three streams at one timestamp are concatenated in
// stream-index order.
TEST(TestConcatenateIntVectorCalculatorTest, OneTimestamp) {
  CalculatorRunner runner("TestConcatenateIntVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/3,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  std::vector<std::vector<int>> inputs = {{1, 2, 3}, {4}, {5, 6}};
  AddInputVectors(inputs, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  // Fatal on purpose: outputs[0] is indexed below.
  ASSERT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  std::vector<int> expected_vector = {1, 2, 3, 4, 5, 6};
  EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
}
// Two consecutive timestamps each produce their own concatenated output.
TEST(TestConcatenateIntVectorCalculatorTest, TwoInputsAtTwoTimestamps) {
  CalculatorRunner runner("TestConcatenateIntVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/3,
                          /*num_outputs=*/1, /*num_side_packets=*/0);
  {
    std::vector<std::vector<int>> inputs = {{1, 2, 3}, {4}, {5, 6}};
    AddInputVectors(inputs, /*timestamp=*/1, &runner);
  }
  {
    std::vector<std::vector<int>> inputs = {{0, 2}, {1}, {3, 5}};
    AddInputVectors(inputs, /*timestamp=*/2, &runner);
  }
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  // Fatal on purpose: outputs[0] and outputs[1] are indexed below.
  ASSERT_EQ(2, outputs.size());
  {
    EXPECT_EQ(6, outputs[0].Get<std::vector<int>>().size());
    EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
    std::vector<int> expected_vector = {1, 2, 3, 4, 5, 6};
    EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
  }
  {
    EXPECT_EQ(5, outputs[1].Get<std::vector<int>>().size());
    EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
    std::vector<int> expected_vector = {0, 2, 1, 3, 5};
    EXPECT_EQ(expected_vector, outputs[1].Get<std::vector<int>>());
  }
}
// With default options, a stream that receives no packet is skipped and the
// remaining input is still emitted.
TEST(TestConcatenateIntVectorCalculatorTest, OneEmptyStreamStillOutput) {
  CalculatorRunner runner("TestConcatenateIntVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/2,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  // Only the first of the two input streams gets a packet.
  std::vector<std::vector<int>> inputs = {{1, 2, 3}};
  AddInputVectors(inputs, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  // Fatal on purpose: outputs[0] is indexed below.
  ASSERT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  std::vector<int> expected_vector = {1, 2, 3};
  EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
}
// With only_emit_if_all_present enabled, a missing packet on one stream
// suppresses the output entirely.
TEST(TestConcatenateIntVectorCalculatorTest, OneEmptyStreamNoOutput) {
  CalculatorRunner runner(
      "TestConcatenateIntVectorCalculator",
      /*options_string=*/
      "[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
      "{only_emit_if_all_present: true}",
      /*num_inputs=*/2, /*num_outputs=*/1, /*num_side_packets=*/0);

  // Feed the first stream only; the second one stays empty.
  const std::vector<std::vector<int>> first_stream_only = {{1, 2, 3}};
  AddInputVectors(first_stream_only, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& emitted = runner.Outputs().Index(0).packets;
  EXPECT_EQ(0, emitted.size());
}
// Single-int packets on three streams are gathered into one vector in
// stream-index order.
TEST(TestConcatenateIntVectorCalculatorTest, ItemsOneTimestamp) {
  CalculatorRunner runner("TestConcatenateIntVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/3,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  std::vector<int> inputs = {1, 2, 3};
  AddInputItems(inputs, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  // Fatal on purpose: outputs[0] is indexed below.
  ASSERT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  std::vector<int> expected_vector = {1, 2, 3};
  EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
}
// Single-int packets at two consecutive timestamps yield two separate
// concatenated outputs.
TEST(TestConcatenateIntVectorCalculatorTest, ItemsTwoInputsAtTwoTimestamps) {
  CalculatorRunner runner("TestConcatenateIntVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/3,
                          /*num_outputs=*/1, /*num_side_packets=*/0);
  {
    std::vector<int> inputs = {1, 2, 3};
    AddInputItems(inputs, /*timestamp=*/1, &runner);
  }
  {
    std::vector<int> inputs = {4, 5, 6};
    AddInputItems(inputs, /*timestamp=*/2, &runner);
  }
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  // Fatal on purpose: outputs[0] and outputs[1] are indexed below.
  ASSERT_EQ(2, outputs.size());
  {
    EXPECT_EQ(3, outputs[0].Get<std::vector<int>>().size());
    EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
    std::vector<int> expected_vector = {1, 2, 3};
    EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
  }
  {
    EXPECT_EQ(3, outputs[1].Get<std::vector<int>>().size());
    EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
    std::vector<int> expected_vector = {4, 5, 6};
    EXPECT_EQ(expected_vector, outputs[1].Get<std::vector<int>>());
  }
}
// With default options, an item stream without a packet is skipped and the
// other items are still emitted.
TEST(TestConcatenateIntVectorCalculatorTest, ItemsOneEmptyStreamStillOutput) {
  CalculatorRunner runner("TestConcatenateIntVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/3,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  // No third input item.
  std::vector<int> inputs = {1, 2};
  AddInputItems(inputs, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  // Fatal on purpose: outputs[0] is indexed below.
  ASSERT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  std::vector<int> expected_vector = {1, 2};
  EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
}
// With only_emit_if_all_present enabled, a missing item on one stream
// suppresses the output entirely.
TEST(TestConcatenateIntVectorCalculatorTest, ItemsOneEmptyStreamNoOutput) {
  CalculatorRunner runner(
      "TestConcatenateIntVectorCalculator",
      /*options_string=*/
      "[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
      "{only_emit_if_all_present: true}",
      /*num_inputs=*/3, /*num_outputs=*/1, /*num_side_packets=*/0);

  // Two items for three streams: the third stream stays empty.
  const std::vector<int> two_of_three = {1, 2};
  AddInputItems(two_of_three, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& emitted = runner.Outputs().Index(0).packets;
  EXPECT_EQ(0, emitted.size());
}
// Vector packets (streams 0-1) and single-item packets (streams 2-3) can be
// mixed; the result follows stream-index order.
TEST(TestConcatenateIntVectorCalculatorTest, MixedVectorsAndItems) {
  CalculatorRunner runner("TestConcatenateIntVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/4,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  std::vector<int> vector_0 = {1, 2};
  std::vector<int> vector_1 = {3, 4, 5};
  int item_0 = 6;
  int item_1 = 7;
  AddInputVector(/*index*/ 0, vector_0, /*timestamp=*/1, &runner);
  AddInputVector(/*index*/ 1, vector_1, /*timestamp=*/1, &runner);
  AddInputItem(/*index*/ 2, item_0, /*timestamp=*/1, &runner);
  AddInputItem(/*index*/ 3, item_1, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  // Fatal on purpose: outputs[0] is indexed below.
  ASSERT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  std::vector<int> expected_vector = {1, 2, 3, 4, 5, 6, 7};
  EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
}
// Same as MixedVectorsAndItems, but with the item streams on the outside
// (streams 0 and 3) and the vector streams in between.
TEST(TestConcatenateIntVectorCalculatorTest, MixedVectorsAndItemsAnother) {
  CalculatorRunner runner("TestConcatenateIntVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/4,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  int item_0 = 1;
  std::vector<int> vector_0 = {2, 3};
  std::vector<int> vector_1 = {4, 5, 6};
  int item_1 = 7;
  AddInputItem(/*index*/ 0, item_0, /*timestamp=*/1, &runner);
  AddInputVector(/*index*/ 1, vector_0, /*timestamp=*/1, &runner);
  AddInputVector(/*index*/ 2, vector_1, /*timestamp=*/1, &runner);
  AddInputItem(/*index*/ 3, item_1, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  // Fatal on purpose: outputs[0] is indexed below.
  ASSERT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  std::vector<int> expected_vector = {1, 2, 3, 4, 5, 6, 7};
  EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
}
// Float overload: queues inputs[i] as a std::vector<float> packet on input
// stream i, for every i, all at the same `timestamp`.
void AddInputVectors(const std::vector<std::vector<float>>& inputs,
                     int64 timestamp, CalculatorRunner* runner) {
  // Convert the size once instead of comparing a signed counter with an
  // unsigned size() on every iteration.
  const int num_streams = static_cast<int>(inputs.size());
  for (int i = 0; i < num_streams; ++i) {
    runner->MutableInputs()->Index(i).packets.push_back(
        MakePacket<std::vector<float>>(inputs[i]).At(Timestamp(timestamp)));
  }
}
// Three streams that each carry an empty float vector must still produce
// exactly one (empty) output vector at the same timestamp.
TEST(ConcatenateFloatVectorCalculatorTest, EmptyVectorInputs) {
  CalculatorRunner runner("ConcatenateFloatVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/3,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  std::vector<std::vector<float>> inputs = {{}, {}, {}};
  AddInputVectors(inputs, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  // Fatal on purpose: outputs[0] is indexed below.
  ASSERT_EQ(1, outputs.size());
  EXPECT_TRUE(outputs[0].Get<std::vector<float>>().empty());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
}
// Float vectors arriving on three streams at one timestamp are concatenated
// in stream-index order.
TEST(ConcatenateFloatVectorCalculatorTest, OneTimestamp) {
  CalculatorRunner runner("ConcatenateFloatVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/3,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  std::vector<std::vector<float>> inputs = {
      {1.0f, 2.0f, 3.0f}, {4.0f}, {5.0f, 6.0f}};
  AddInputVectors(inputs, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  // Fatal on purpose: outputs[0] is indexed below.
  ASSERT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  std::vector<float> expected_vector = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<float>>());
}
// Two consecutive timestamps each produce their own concatenated float
// output.
TEST(ConcatenateFloatVectorCalculatorTest, TwoInputsAtTwoTimestamps) {
  CalculatorRunner runner("ConcatenateFloatVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/3,
                          /*num_outputs=*/1, /*num_side_packets=*/0);
  {
    std::vector<std::vector<float>> inputs = {
        {1.0f, 2.0f, 3.0f}, {4.0f}, {5.0f, 6.0f}};
    AddInputVectors(inputs, /*timestamp=*/1, &runner);
  }
  {
    std::vector<std::vector<float>> inputs = {
        {0.0f, 2.0f}, {1.0f}, {3.0f, 5.0f}};
    AddInputVectors(inputs, /*timestamp=*/2, &runner);
  }
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  // Fatal on purpose: outputs[0] and outputs[1] are indexed below.
  ASSERT_EQ(2, outputs.size());
  {
    EXPECT_EQ(6, outputs[0].Get<std::vector<float>>().size());
    EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
    std::vector<float> expected_vector = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
    EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<float>>());
  }
  {
    EXPECT_EQ(5, outputs[1].Get<std::vector<float>>().size());
    EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
    std::vector<float> expected_vector = {0.0f, 2.0f, 1.0f, 3.0f, 5.0f};
    EXPECT_EQ(expected_vector, outputs[1].Get<std::vector<float>>());
  }
}
// With default options, a stream that receives no packet is skipped and the
// remaining float input is still emitted.
TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamStillOutput) {
  CalculatorRunner runner("ConcatenateFloatVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/2,
                          /*num_outputs=*/1, /*num_side_packets=*/0);

  // Only the first of the two input streams gets a packet.
  std::vector<std::vector<float>> inputs = {{1.0f, 2.0f, 3.0f}};
  AddInputVectors(inputs, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  // Fatal on purpose: outputs[0] is indexed below.
  ASSERT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  std::vector<float> expected_vector = {1.0f, 2.0f, 3.0f};
  EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<float>>());
}
// With only_emit_if_all_present enabled, a missing packet on one stream
// suppresses the float output entirely.
TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamNoOutput) {
  CalculatorRunner runner(
      "ConcatenateFloatVectorCalculator",
      /*options_string=*/
      "[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
      "{only_emit_if_all_present: true}",
      /*num_inputs=*/2, /*num_outputs=*/1, /*num_side_packets=*/0);

  // Feed the first stream only; the second one stays empty.
  const std::vector<std::vector<float>> first_stream_only = {
      {1.0f, 2.0f, 3.0f}};
  AddInputVectors(first_stream_only, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& emitted = runner.Outputs().Index(0).packets;
  EXPECT_EQ(0, emitted.size());
}
// Test-only instantiation over std::unique_ptr<int>. The element type is not
// copy-constructible, so ConcatenateVectorCalculator dispatches to its
// consume/move path instead of the copy path. Graph configs refer to it as
// "TestConcatenateUniqueIntPtrCalculator".
typedef ConcatenateVectorCalculator<std::unique_ptr<int>>
TestConcatenateUniqueIntPtrCalculator;
MEDIAPIPE_REGISTER_NODE(TestConcatenateUniqueIntPtrCalculator);
// Move-only elements from three streams are consumed and concatenated in
// stream order: {0, 1, 2} + {3} + {4, 5} -> {0, 1, 2, 3, 4, 5}.
TEST(TestConcatenateUniqueIntVectorCalculatorTest, ConsumeOneTimestamp) {
  /* Note: We don't use CalculatorRunner for this test because it keeps copies
   * of input packets, so packets sent to the graph don't have sole ownership.
   * The test needs to send packets that own the data.
   */
  CalculatorGraphConfig graph_config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "in_1"
input_stream: "in_2"
input_stream: "in_3"
node {
calculator: "TestConcatenateUniqueIntPtrCalculator"
input_stream: "in_1"
input_stream: "in_2"
input_stream: "in_3"
output_stream: "out"
}
)pb");
  std::vector<Packet> outputs;
  tool::AddVectorSink("out", &graph_config, &outputs);
  CalculatorGraph graph;
  MP_EXPECT_OK(graph.Initialize(graph_config));
  MP_EXPECT_OK(graph.StartRun({}));

  // input1 : {0, 1, 2}
  std::unique_ptr<std::vector<std::unique_ptr<int>>> input_1 =
      absl::make_unique<std::vector<std::unique_ptr<int>>>(3);
  for (int i = 0; i < 3; ++i) {
    input_1->at(i) = absl::make_unique<int>(i);
  }
  // input2: {3}
  std::unique_ptr<std::vector<std::unique_ptr<int>>> input_2 =
      absl::make_unique<std::vector<std::unique_ptr<int>>>(1);
  input_2->at(0) = absl::make_unique<int>(3);
  // input3: {4, 5}
  std::unique_ptr<std::vector<std::unique_ptr<int>>> input_3 =
      absl::make_unique<std::vector<std::unique_ptr<int>>>(2);
  input_3->at(0) = absl::make_unique<int>(4);
  input_3->at(1) = absl::make_unique<int>(5);

  // Adopt() hands ownership of each vector to its packet.
  MP_EXPECT_OK(graph.AddPacketToInputStream(
      "in_1", Adopt(input_1.release()).At(Timestamp(1))));
  MP_EXPECT_OK(graph.AddPacketToInputStream(
      "in_2", Adopt(input_2.release()).At(Timestamp(1))));
  MP_EXPECT_OK(graph.AddPacketToInputStream(
      "in_3", Adopt(input_3.release()).At(Timestamp(1))));
  MP_EXPECT_OK(graph.WaitUntilIdle());
  MP_EXPECT_OK(graph.CloseAllPacketSources());
  MP_EXPECT_OK(graph.WaitUntilDone());

  // Fatal on purpose: outputs[0] is indexed below.
  ASSERT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  const std::vector<std::unique_ptr<int>>& result =
      outputs[0].Get<std::vector<std::unique_ptr<int>>>();
  // Fatal on purpose: the loop below indexes result[0..5].
  ASSERT_EQ(6, result.size());
  for (int i = 0; i < 6; ++i) {
    const std::unique_ptr<int>& v = result[i];
    EXPECT_EQ(i, *v);
  }
}
// With default options, the consume path also skips a stream that received no
// packet and still emits the elements of the other stream.
TEST(TestConcatenateUniqueIntVectorCalculatorTest, OneEmptyStreamStillOutput) {
  /* Note: We don't use CalculatorRunner for this test because it keeps copies
   * of input packets, so packets sent to the graph don't have sole ownership.
   * The test needs to send packets that own the data.
   */
  CalculatorGraphConfig graph_config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "in_1"
input_stream: "in_2"
node {
calculator: "TestConcatenateUniqueIntPtrCalculator"
input_stream: "in_1"
input_stream: "in_2"
output_stream: "out"
}
)pb");
  std::vector<Packet> outputs;
  tool::AddVectorSink("out", &graph_config, &outputs);
  CalculatorGraph graph;
  MP_EXPECT_OK(graph.Initialize(graph_config));
  MP_EXPECT_OK(graph.StartRun({}));

  // input1 : {0, 1, 2}
  std::unique_ptr<std::vector<std::unique_ptr<int>>> input_1 =
      absl::make_unique<std::vector<std::unique_ptr<int>>>(3);
  for (int i = 0; i < 3; ++i) {
    input_1->at(i) = absl::make_unique<int>(i);
  }

  // Nothing is sent on "in_2".
  MP_EXPECT_OK(graph.AddPacketToInputStream(
      "in_1", Adopt(input_1.release()).At(Timestamp(1))));
  MP_EXPECT_OK(graph.WaitUntilIdle());
  MP_EXPECT_OK(graph.CloseAllPacketSources());
  MP_EXPECT_OK(graph.WaitUntilDone());

  // Fatal on purpose: outputs[0] is indexed below.
  ASSERT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  const std::vector<std::unique_ptr<int>>& result =
      outputs[0].Get<std::vector<std::unique_ptr<int>>>();
  // Fatal on purpose: the loop below indexes result[0..2].
  ASSERT_EQ(3, result.size());
  for (int i = 0; i < 3; ++i) {
    const std::unique_ptr<int>& v = result[i];
    EXPECT_EQ(i, *v);
  }
}
// With only_emit_if_all_present enabled, the calculator emits nothing when
// one of its two streams received no packet.
TEST(TestConcatenateUniqueIntVectorCalculatorTest, OneEmptyStreamNoOutput) {
  // CalculatorRunner is deliberately not used: it keeps copies of the input
  // packets, so packets sent to the graph would not have sole ownership of
  // their data, which this test requires.
  CalculatorGraphConfig config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "in_1"
input_stream: "in_2"
node {
calculator: "TestConcatenateUniqueIntPtrCalculator"
input_stream: "in_1"
input_stream: "in_2"
output_stream: "out"
options {
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
only_emit_if_all_present: true
}
}
}
)pb");
  std::vector<Packet> sink;
  tool::AddVectorSink("out", &config, &sink);

  CalculatorGraph graph;
  MP_EXPECT_OK(graph.Initialize(config));
  MP_EXPECT_OK(graph.StartRun({}));

  // Payload for "in_1": {0, 1, 2}. "in_2" is left without a packet.
  auto first_input = absl::make_unique<std::vector<std::unique_ptr<int>>>(3);
  int next_value = 0;
  for (std::unique_ptr<int>& slot : *first_input) {
    slot = absl::make_unique<int>(next_value++);
  }

  MP_EXPECT_OK(graph.AddPacketToInputStream(
      "in_1", Adopt(first_input.release()).At(Timestamp(1))));
  MP_EXPECT_OK(graph.WaitUntilIdle());
  MP_EXPECT_OK(graph.CloseAllPacketSources());
  MP_EXPECT_OK(graph.WaitUntilDone());

  EXPECT_EQ(0, sink.size());
}
} // namespace mediapipe

View File

@ -1,129 +0,0 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "mediapipe/calculators/core/constant_side_packet_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace {} // namespace
// Generates an output side packet or multiple output side packets according to
// the specified options.
//
// Example configs:
// node {
// calculator: "ConstantSidePacketCalculator"
// output_side_packet: "PACKET:packet"
// options: {
// [mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
// packet { int_value: 2 }
// }
// }
// }
//
// node {
// calculator: "ConstantSidePacketCalculator"
// output_side_packet: "PACKET:0:int_packet"
// output_side_packet: "PACKET:1:bool_packet"
// options: {
// [mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
// packet { int_value: 2 }
// packet { bool_value: true }
// }
// }
// }
class ConstantSidePacketCalculator : public CalculatorBase {
 public:
  // Checks that the options hold exactly one `packet` entry per "PACKET"
  // output side packet, then declares each side packet's type from whichever
  // value field is set in the entry at the same position.
  static absl::Status GetContract(CalculatorContract* cc) {
    const auto& calculator_options =
        cc->Options<::mediapipe::ConstantSidePacketCalculatorOptions>();
    RET_CHECK_EQ(cc->OutputSidePackets().NumEntries(kPacketTag),
                 calculator_options.packet_size())
        << "Number of output side packets has to be same as number of packets "
           "configured in options.";

    // Walk the tagged side packets and the option entries in lockstep.
    int entry_index = 0;
    CollectionItemId id = cc->OutputSidePackets().BeginId(kPacketTag);
    while (id != cc->OutputSidePackets().EndId(kPacketTag)) {
      const auto& entry = calculator_options.packet(entry_index);
      auto& side_packet = cc->OutputSidePackets().Get(id);
      if (entry.has_int_value()) {
        side_packet.Set<int>();
      } else if (entry.has_float_value()) {
        side_packet.Set<float>();
      } else if (entry.has_bool_value()) {
        side_packet.Set<bool>();
      } else if (entry.has_string_value()) {
        side_packet.Set<std::string>();
      } else if (entry.has_uint64_value()) {
        side_packet.Set<uint64>();
      } else if (entry.has_classification_list_value()) {
        side_packet.Set<ClassificationList>();
      } else {
        return absl::InvalidArgumentError(
            "None of supported values were specified in options.");
      }
      ++id;
      ++entry_index;
    }
    return absl::OkStatus();
  }

  // Fills every "PACKET" output side packet with the constant configured in
  // the option entry at the same position.
  absl::Status Open(CalculatorContext* cc) override {
    const auto& calculator_options =
        cc->Options<::mediapipe::ConstantSidePacketCalculatorOptions>();

    int entry_index = 0;
    CollectionItemId id = cc->OutputSidePackets().BeginId(kPacketTag);
    while (id != cc->OutputSidePackets().EndId(kPacketTag)) {
      auto& side_packet = cc->OutputSidePackets().Get(id);
      const auto& entry = calculator_options.packet(entry_index);
      if (entry.has_int_value()) {
        side_packet.Set(MakePacket<int>(entry.int_value()));
      } else if (entry.has_float_value()) {
        side_packet.Set(MakePacket<float>(entry.float_value()));
      } else if (entry.has_bool_value()) {
        side_packet.Set(MakePacket<bool>(entry.bool_value()));
      } else if (entry.has_string_value()) {
        side_packet.Set(MakePacket<std::string>(entry.string_value()));
      } else if (entry.has_uint64_value()) {
        side_packet.Set(MakePacket<uint64>(entry.uint64_value()));
      } else if (entry.has_classification_list_value()) {
        side_packet.Set(
            MakePacket<ClassificationList>(entry.classification_list_value()));
      } else {
        return absl::InvalidArgumentError(
            "None of supported values were specified in options.");
      }
      ++id;
      ++entry_index;
    }
    return absl::OkStatus();
  }

  // Nothing to do per timestamp: all outputs are produced in Open().
  absl::Status Process(CalculatorContext* cc) override {
    return absl::OkStatus();
  }

 private:
  // Tag shared by all output side packets of this calculator.
  static constexpr const char* kPacketTag = "PACKET";
};
REGISTER_CALCULATOR(ConstantSidePacketCalculator);
} // namespace mediapipe

View File

@ -1,41 +0,0 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/formats/classification.proto";
option objc_class_prefix = "MediaPipe";
// Options for ConstantSidePacketCalculator: the list of constant values to
// publish as output side packets.
message ConstantSidePacketCalculatorOptions {
// Makes these options settable on a node as
// [mediapipe.ConstantSidePacketCalculatorOptions.ext].
extend CalculatorOptions {
optional ConstantSidePacketCalculatorOptions ext = 291214597;
}
// One constant. The field of the oneof that is set selects both the type and
// the value of the resulting side packet; the calculator rejects an entry in
// which none of them is set.
message ConstantSidePacket {
oneof value {
int32 int_value = 1;
float float_value = 2;
bool bool_value = 3;
string string_value = 4;
uint64 uint64_value = 5;
ClassificationList classification_list_value = 6;
}
}
// One entry per "PACKET" output side packet, matched by position. The
// calculator requires the number of entries to equal the number of "PACKET"
// output side packets.
repeated ConstantSidePacket packet = 1;
}

View File

@ -1,189 +0,0 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
// Runs a one-node graph whose ConstantSidePacketCalculator is configured with
// the single `packet_spec` entry (text proto, e.g. "{ int_value: 2 }") and
// expects the resulting "packet" side packet to hold `expected_value` as a T.
template <typename T>
void DoTestSingleSidePacket(absl::string_view packet_spec,
                            const T& expected_value) {
  // $0 is replaced with `packet_spec` by absl::Substitute.
  static constexpr absl::string_view graph_config_template = R"(
node {
calculator: "ConstantSidePacketCalculator"
output_side_packet: "PACKET:packet"
options: {
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
packet $0
}
}
}
)";
  const std::string config_text =
      absl::Substitute(graph_config_template, packet_spec);
  CalculatorGraphConfig graph_config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_text);

  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(graph_config));
  MP_ASSERT_OK(graph.StartRun({}));
  MP_ASSERT_OK(graph.WaitUntilIdle());

  // Check that the side packet exists before reading its payload.
  MP_ASSERT_OK(graph.GetOutputSidePacket("packet"));
  auto actual_value =
      graph.GetOutputSidePacket("packet").value().template Get<T>();
  EXPECT_EQ(actual_value, expected_value);
}
// Emits one packet of each of the int, float, bool and string value kinds and
// checks the value that comes out of the graph.
// NOTE(review): the options proto also declares uint64_value and
// classification_list_value; they are not exercised here - confirm whether
// that is intentional.
TEST(ConstantSidePacketCalculatorTest, EveryPossibleType) {
  DoTestSingleSidePacket<int>("{ int_value: 2 }", 2);
  DoTestSingleSidePacket<float>("{ float_value: 6.5f }", 6.5f);
  DoTestSingleSidePacket<bool>("{ bool_value: true }", true);
  DoTestSingleSidePacket<std::string>(R"({ string_value: "str" })", "str");
}
// A single calculator node with six indexed PACKET outputs must emit the six
// configured constants, matched to outputs by index.
TEST(ConstantSidePacketCalculatorTest, MultiplePackets) {
  CalculatorGraphConfig graph_config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "ConstantSidePacketCalculator"
output_side_packet: "PACKET:0:int_packet"
output_side_packet: "PACKET:1:float_packet"
output_side_packet: "PACKET:2:bool_packet"
output_side_packet: "PACKET:3:string_packet"
output_side_packet: "PACKET:4:another_string_packet"
output_side_packet: "PACKET:5:another_int_packet"
options: {
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
packet { int_value: 256 }
packet { float_value: 0.5f }
packet { bool_value: false }
packet { string_value: "string" }
packet { string_value: "another string" }
packet { int_value: 128 }
}
}
}
)pb");

  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(graph_config));
  MP_ASSERT_OK(graph.StartRun({}));
  MP_ASSERT_OK(graph.WaitUntilIdle());

  // Each lookup is asserted OK before its value is read.
  const auto int_packet = graph.GetOutputSidePacket("int_packet");
  MP_ASSERT_OK(int_packet);
  EXPECT_EQ(int_packet.value().Get<int>(), 256);

  const auto float_packet = graph.GetOutputSidePacket("float_packet");
  MP_ASSERT_OK(float_packet);
  EXPECT_EQ(float_packet.value().Get<float>(), 0.5f);

  const auto bool_packet = graph.GetOutputSidePacket("bool_packet");
  MP_ASSERT_OK(bool_packet);
  EXPECT_FALSE(bool_packet.value().Get<bool>());

  const auto string_packet = graph.GetOutputSidePacket("string_packet");
  MP_ASSERT_OK(string_packet);
  EXPECT_EQ(string_packet.value().Get<std::string>(), "string");

  const auto another_string_packet =
      graph.GetOutputSidePacket("another_string_packet");
  MP_ASSERT_OK(another_string_packet);
  EXPECT_EQ(another_string_packet.value().Get<std::string>(),
            "another string");

  const auto another_int_packet =
      graph.GetOutputSidePacket("another_int_packet");
  MP_ASSERT_OK(another_int_packet);
  EXPECT_EQ(another_int_packet.value().Get<int>(), 128);
}
// Output side packets that are untagged or carry a tag other than PACKET are
// interleaved with the PACKET outputs; the four configured constants must
// still arrive on the four PACKET-tagged outputs.
TEST(ConstantSidePacketCalculatorTest, ProcessingPacketsWithCorrectTagOnly) {
  CalculatorGraphConfig graph_config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "ConstantSidePacketCalculator"
output_side_packet: "PACKET:0:int_packet"
output_side_packet: "no_tag0"
output_side_packet: "PACKET:1:float_packet"
output_side_packet: "INCORRECT_TAG:0:name1"
output_side_packet: "PACKET:2:bool_packet"
output_side_packet: "PACKET:3:string_packet"
output_side_packet: "no_tag2"
output_side_packet: "INCORRECT_TAG:1:name2"
options: {
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
packet { int_value: 256 }
packet { float_value: 0.5f }
packet { bool_value: false }
packet { string_value: "string" }
}
}
}
)pb");

  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(graph_config));
  MP_ASSERT_OK(graph.StartRun({}));
  MP_ASSERT_OK(graph.WaitUntilIdle());

  // Only the PACKET-tagged outputs are inspected.
  const auto int_packet = graph.GetOutputSidePacket("int_packet");
  MP_ASSERT_OK(int_packet);
  EXPECT_EQ(int_packet.value().Get<int>(), 256);

  const auto float_packet = graph.GetOutputSidePacket("float_packet");
  MP_ASSERT_OK(float_packet);
  EXPECT_EQ(float_packet.value().Get<float>(), 0.5f);

  const auto bool_packet = graph.GetOutputSidePacket("bool_packet");
  MP_ASSERT_OK(bool_packet);
  EXPECT_FALSE(bool_packet.value().Get<bool>());

  const auto string_packet = graph.GetOutputSidePacket("string_packet");
  MP_ASSERT_OK(string_packet);
  EXPECT_EQ(string_packet.value().Get<std::string>(), "string");
}
// Two `packet` options but only one PACKET output: graph initialization is
// expected to fail.
TEST(ConstantSidePacketCalculatorTest, IncorrectConfig_MoreOptionsThanPackets) {
  CalculatorGraphConfig graph_config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "ConstantSidePacketCalculator"
output_side_packet: "PACKET:int_packet"
options: {
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
packet { int_value: 256 }
packet { float_value: 0.5f }
}
}
}
)pb");

  CalculatorGraph graph;
  const absl::Status init_status = graph.Initialize(graph_config);
  EXPECT_FALSE(init_status.ok());
}
// Two PACKET outputs but only one `packet` option: graph initialization is
// expected to fail.
TEST(ConstantSidePacketCalculatorTest, IncorrectConfig_MorePacketsThanOptions) {
  CalculatorGraphConfig graph_config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "ConstantSidePacketCalculator"
output_side_packet: "PACKET:0:int_packet"
output_side_packet: "PACKET:1:float_packet"
options: {
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
packet { int_value: 256 }
}
}
}
)pb");

  CalculatorGraph graph;
  const absl::Status init_status = graph.Initialize(graph_config);
  EXPECT_FALSE(init_status.ok());
}
} // namespace mediapipe

View File

@ -1,114 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/strings/string_view.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
namespace mediapipe {
// Source calculator that produces MAX_COUNT*BATCH_SIZE int packets of
// sequential numbers from INITIAL_VALUE (default 0) with a common
// difference of INCREMENT (default 1) between successive numbers (with
// timestamps corresponding to the sequence numbers). The packets are
// produced in BATCH_SIZE sized batches with each call to Process(). An
// error will be returned after ERROR_COUNT batches. An error will be
// produced in Open() if ERROR_ON_OPEN is true. Either MAX_COUNT or
// ERROR_COUNT must be provided and non-negative. If BATCH_SIZE is not
// provided, then batches are of size 1.
// Source calculator emitting int packets on output index 0. Each packet's
// value and timestamp are both the current counter, which advances by
// `increment_` after every packet. All configuration arrives through
// optional input side packets (MAX_COUNT, ERROR_COUNT, BATCH_SIZE,
// INITIAL_VALUE, INCREMENT, ERROR_ON_OPEN).
class CountingSourceCalculator : public CalculatorBase {
 public:
  // Declares the int output and the type of every side packet that is
  // present. Fails unless at least one of MAX_COUNT / ERROR_COUNT is present.
  static absl::Status GetContract(CalculatorContract* cc) {
    auto& side_inputs = cc->InputSidePackets();
    cc->Outputs().Index(0).Set<int>();
    if (side_inputs.HasTag("ERROR_ON_OPEN")) {
      side_inputs.Tag("ERROR_ON_OPEN").Set<bool>();
    }
    RET_CHECK(cc->InputSidePackets().HasTag("MAX_COUNT") ||
              cc->InputSidePackets().HasTag("ERROR_COUNT"));
    // Every remaining side packet is an optional int.
    constexpr const char* kOptionalIntTags[] = {
        "MAX_COUNT", "ERROR_COUNT", "BATCH_SIZE", "INITIAL_VALUE",
        "INCREMENT"};
    for (const char* tag : kOptionalIntTags) {
      if (side_inputs.HasTag(tag)) {
        side_inputs.Tag(tag).Set<int>();
      }
    }
    return absl::OkStatus();
  }

  // Reads the side packets into the member fields and validates them.
  // Returns NotFoundError("expected error") when ERROR_ON_OPEN is true.
  absl::Status Open(CalculatorContext* cc) override {
    const auto& side_inputs = cc->InputSidePackets();
    const bool fail_on_open = side_inputs.HasTag("ERROR_ON_OPEN") &&
                              side_inputs.Tag("ERROR_ON_OPEN").Get<bool>();
    if (fail_on_open) {
      return absl::NotFoundError("expected error");
    }
    if (side_inputs.HasTag("ERROR_COUNT")) {
      error_count_ = side_inputs.Tag("ERROR_COUNT").Get<int>();
      RET_CHECK_LE(0, error_count_);
    }
    if (side_inputs.HasTag("MAX_COUNT")) {
      max_count_ = side_inputs.Tag("MAX_COUNT").Get<int>();
      RET_CHECK_LE(0, max_count_);
    }
    if (side_inputs.HasTag("BATCH_SIZE")) {
      batch_size_ = side_inputs.Tag("BATCH_SIZE").Get<int>();
      RET_CHECK_LT(0, batch_size_);
    }
    if (side_inputs.HasTag("INITIAL_VALUE")) {
      counter_ = side_inputs.Tag("INITIAL_VALUE").Get<int>();
    }
    if (side_inputs.HasTag("INCREMENT")) {
      increment_ = side_inputs.Tag("INCREMENT").Get<int>();
      RET_CHECK_LT(0, increment_);
    }
    // At least one stopping condition must have been configured.
    RET_CHECK(error_count_ >= 0 || max_count_ >= 0);
    return absl::OkStatus();
  }

  // Emits one batch of `batch_size_` packets. The error limit is checked
  // before the stop limit, so when both are reached the error wins.
  absl::Status Process(CalculatorContext* cc) override {
    const bool error_limit_hit =
        error_count_ >= 0 && batch_counter_ >= error_count_;
    if (error_limit_hit) {
      return absl::InternalError("expected error");
    }
    const bool stop_limit_hit = max_count_ >= 0 && batch_counter_ >= max_count_;
    if (stop_limit_hit) {
      return tool::StatusStop();
    }
    auto& output = cc->Outputs().Index(0);
    for (int remaining = batch_size_; remaining > 0; --remaining) {
      // The stream takes ownership of the heap-allocated int.
      output.Add(new int(counter_), Timestamp(counter_));
      counter_ += increment_;
    }
    ++batch_counter_;
    return absl::OkStatus();
  }

 private:
  // Number of batches after which Process() stops; -1 means "not set".
  int max_count_ = -1;
  // Number of batches after which Process() fails; -1 means "not set".
  int error_count_ = -1;
  // Packets emitted per Process() call.
  int batch_size_ = 1;
  // Batches emitted so far.
  int batch_counter_ = 0;
  // Value (and timestamp) of the next packet.
  int counter_ = 0;
  // Step added to `counter_` after each packet.
  int increment_ = 1;
};
REGISTER_CALCULATOR(CountingSourceCalculator);
} // namespace mediapipe

View File

@ -1,104 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace {
constexpr char kOptionalValueTag[] = "OPTIONAL_VALUE";
constexpr char kDefaultValueTag[] = "DEFAULT_VALUE";
constexpr char kValueTag[] = "VALUE";
} // namespace
// Outputs side packet default value if optional value is not provided.
//
// This calculator relies on the fact that MediaPipe automatically removes
// unspecified optional side packets (i.e. OPTIONAL_VALUE) from the calculator
// configuration. If OPTIONAL_VALUE has been removed, the calculator outputs
// the default value; otherwise it outputs the optional value.
//
// Input:
// OPTIONAL_VALUE (optional) - AnyType (but same type as DEFAULT_VALUE)
// Optional side packet value that is outputted by the calculator as is if
// provided.
//
// DEFAULT_VALUE - AnyType
// Default side packet value that is outputted by the calculator if
// OPTIONAL_VALUE is not provided.
//
// Output:
// VALUE - AnyType (but same type as DEFAULT_VALUE)
// Either OPTIONAL_VALUE (if provided) or DEFAULT_VALUE (otherwise).
//
// Usage example:
// node {
// calculator: "DefaultSidePacketCalculator"
// input_side_packet: "OPTIONAL_VALUE:segmentation_mask_enabled_optional"
// input_side_packet: "DEFAULT_VALUE:segmentation_mask_enabled_default"
// output_side_packet: "VALUE:segmentation_mask_enabled"
// }
// Side-packet-only calculator: all of its work happens in Open(), and
// Process() does nothing.
class DefaultSidePacketCalculator : public CalculatorBase {
public:
// Requires the DEFAULT_VALUE input and VALUE output side packets; declares
// OPTIONAL_VALUE (when present) and VALUE with the same type as DEFAULT_VALUE.
static absl::Status GetContract(CalculatorContract* cc);
// Sets VALUE to OPTIONAL_VALUE when that tag is present, else DEFAULT_VALUE.
absl::Status Open(CalculatorContext* cc) override;
// Returns OK without doing any work.
absl::Status Process(CalculatorContext* cc) override;
};
REGISTER_CALCULATOR(DefaultSidePacketCalculator);
// Declares the calculator's side packets.
//   DEFAULT_VALUE  - required input, any type.
//   OPTIONAL_VALUE - optional input, same type as DEFAULT_VALUE.
//   VALUE          - required output, same type as DEFAULT_VALUE.
// Returns an error status if DEFAULT_VALUE or VALUE is missing.
absl::Status DefaultSidePacketCalculator::GetContract(CalculatorContract* cc) {
  RET_CHECK(cc->InputSidePackets().HasTag(kDefaultValueTag))
      << "Default value must be provided";
  auto& side_inputs = cc->InputSidePackets();
  auto& default_value_type = side_inputs.Tag(kDefaultValueTag);
  default_value_type.SetAny();

  // The optional input may be left unspecified in the graph, in which case
  // MediaPipe removes it from the calculator config and the tag is absent.
  if (side_inputs.HasTag(kOptionalValueTag)) {
    auto& optional_value_type = side_inputs.Tag(kOptionalValueTag);
    optional_value_type.SetSameAs(&default_value_type).Optional();
  }

  RET_CHECK(cc->OutputSidePackets().HasTag(kValueTag));
  cc->OutputSidePackets().Tag(kValueTag).SetSameAs(&default_value_type);
  return absl::OkStatus();
}
// Publishes the VALUE output side packet: the OPTIONAL_VALUE packet when that
// tag is present on this node, otherwise the DEFAULT_VALUE packet.
absl::Status DefaultSidePacketCalculator::Open(CalculatorContext* cc) {
  const auto& side_inputs = cc->InputSidePackets();
  // Pick which input to forward; both are forwarded the same way.
  const char* source_tag = side_inputs.HasTag(kOptionalValueTag)
                               ? kOptionalValueTag
                               : kDefaultValueTag;
  cc->OutputSidePackets().Tag(kValueTag).Set(side_inputs.Tag(source_tag));
  return absl::OkStatus();
}
// Does no work and always returns OK; the VALUE output side packet is set in
// Open().
absl::Status DefaultSidePacketCalculator::Process(CalculatorContext* cc) {
return absl::OkStatus();
}
} // namespace mediapipe

View File

@ -1,90 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cfloat>
#include <string>
#include <utility>
#include <vector>
#include "mediapipe/calculators/core/dequantize_byte_array_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/status.h"
// Dequantizes a byte array to a vector of floats.
//
// Example config:
// node {
// calculator: "DequantizeByteArrayCalculator"
// input_stream: "ENCODED:encoded"
// output_stream: "FLOAT_VECTOR:float_vector"
// options {
// [mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
// max_quantized_value: 2
// min_quantized_value: -2
// }
// }
// }
namespace mediapipe {
// Maps every byte b (read as unsigned, 0..255) of the ENCODED string to the
// float `b * scalar_ + bias_` and emits the results as one
// std::vector<float> on FLOAT_VECTOR at the input timestamp.
//
// With range = max_quantized_value - min_quantized_value:
//   scalar_ = range / 255
//   bias_   = range / 512 + min_quantized_value
class DequantizeByteArrayCalculator : public CalculatorBase {
 public:
  // Declares ENCODED as std::string and FLOAT_VECTOR as std::vector<float>.
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Tag("ENCODED").Set<std::string>();
    cc->Outputs().Tag("FLOAT_VECTOR").Set<std::vector<float>>();
    return absl::OkStatus();
  }

  // Validates the options and precomputes the scale and bias.
  // Returns InvalidArgumentError when either bound is unset or when
  // max_quantized_value < min_quantized_value + FLT_EPSILON.
  absl::Status Open(CalculatorContext* cc) final {
    // Bind by reference: the options proto does not need to be copied.
    const auto& options =
        cc->Options<::mediapipe::DequantizeByteArrayCalculatorOptions>();
    if (!options.has_max_quantized_value() ||
        !options.has_min_quantized_value()) {
      return absl::InvalidArgumentError(
          "Both max_quantized_value and min_quantized_value must be provided "
          "in DequantizeByteArrayCalculatorOptions.");
    }
    const float max_quantized_value = options.max_quantized_value();
    const float min_quantized_value = options.min_quantized_value();
    if (max_quantized_value < min_quantized_value + FLT_EPSILON) {
      return absl::InvalidArgumentError(
          "max_quantized_value must be greater than min_quantized_value.");
    }
    const float range = max_quantized_value - min_quantized_value;
    scalar_ = range / 255.0;
    bias_ = (range / 512.0) + min_quantized_value;
    return absl::OkStatus();
  }

  // Dequantizes the current ENCODED packet and emits the float vector.
  absl::Status Process(CalculatorContext* cc) final {
    const std::string& encoded =
        cc->Inputs().Tag("ENCODED").Value().Get<std::string>();
    std::vector<float> float_vector;
    float_vector.reserve(encoded.size());
    for (const char byte : encoded) {
      // Go through unsigned char so bytes >= 0x80 map to 128..255 rather
      // than to negative values where char is signed.
      float_vector.push_back(static_cast<unsigned char>(byte) * scalar_ +
                             bias_);
    }
    // Move the buffer into the packet instead of copying it.
    cc->Outputs()
        .Tag("FLOAT_VECTOR")
        .AddPacket(MakePacket<std::vector<float>>(std::move(float_vector))
                       .At(cc->InputTimestamp()));
    return absl::OkStatus();
  }

 private:
  // Multiplier applied to each byte; assigned in Open().
  float scalar_ = 0.0f;
  // Offset added to each scaled byte; assigned in Open().
  float bias_ = 0.0f;
};
REGISTER_CALCULATOR(DequantizeByteArrayCalculator);
} // namespace mediapipe

View File

@ -1,30 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
option objc_class_prefix = "MediaPipe";
// Options for DequantizeByteArrayCalculator. The calculator's Open() rejects
// a config in which either field is unset, or in which
// max_quantized_value < min_quantized_value + FLT_EPSILON.
message DequantizeByteArrayCalculatorOptions {
extend CalculatorOptions {
optional DequantizeByteArrayCalculatorOptions ext = 272316343;
}
// Upper bound of the value range; the calculator uses (max - min) as the
// range from which its per-byte scale and bias are derived.
optional float max_quantized_value = 1;
// Lower bound of the value range; also added into the calculator's bias.
optional float min_quantized_value = 2;
}

View File

@ -1,137 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
namespace mediapipe {
// Options that set only max_quantized_value must make the run fail with the
// "both bounds must be provided" message.
// NOTE(review): the suite name says "QuantizeFloatVector" although the node
// under test is DequantizeByteArrayCalculator - looks like a copy/paste
// leftover; confirm before renaming.
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "DequantizeByteArrayCalculator"
input_stream: "ENCODED:encoded"
output_stream: "FLOAT_VECTOR:float_vector"
options {
[mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
max_quantized_value: 2
}
}
)pb");
  CalculatorRunner runner(node_config);

  // Feed a single empty string at timestamp 0.
  auto& encoded_packets = runner.MutableInputs()->Tag("ENCODED").packets;
  encoded_packets.push_back(MakePacket<std::string>().At(Timestamp(0)));

  const absl::Status run_status = runner.Run();
  EXPECT_FALSE(run_status.ok());
  EXPECT_THAT(
      run_status.message(),
      testing::HasSubstr(
          "Both max_quantized_value and min_quantized_value must be provided"));
}
// max_quantized_value (-2) below min_quantized_value (2) must make the run
// fail with the "max must be greater than min" message.
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig2) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "DequantizeByteArrayCalculator"
input_stream: "ENCODED:encoded"
output_stream: "FLOAT_VECTOR:float_vector"
options {
[mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
max_quantized_value: -2
min_quantized_value: 2
}
}
)pb");
  CalculatorRunner runner(node_config);

  // Feed a single empty string at timestamp 0.
  auto& encoded_packets = runner.MutableInputs()->Tag("ENCODED").packets;
  encoded_packets.push_back(MakePacket<std::string>().At(Timestamp(0)));

  const absl::Status run_status = runner.Run();
  EXPECT_FALSE(run_status.ok());
  EXPECT_THAT(
      run_status.message(),
      testing::HasSubstr(
          "max_quantized_value must be greater than min_quantized_value"));
}
// Equal bounds (max == min == 1) must make the run fail with the "max must be
// greater than min" message.
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig3) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "DequantizeByteArrayCalculator"
input_stream: "ENCODED:encoded"
output_stream: "FLOAT_VECTOR:float_vector"
options {
[mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
max_quantized_value: 1
min_quantized_value: 1
}
}
)pb");
  CalculatorRunner runner(node_config);

  // Feed a single empty string at timestamp 0.
  auto& encoded_packets = runner.MutableInputs()->Tag("ENCODED").packets;
  encoded_packets.push_back(MakePacket<std::string>().At(Timestamp(0)));

  const absl::Status run_status = runner.Run();
  EXPECT_FALSE(run_status.ok());
  EXPECT_THAT(
      run_status.message(),
      testing::HasSubstr(
          "max_quantized_value must be greater than min_quantized_value"));
}
// Dequantizes the bytes {0x7F, 0xFF, 0x00, 0x01} over the range [-2, 2] and
// checks the four resulting floats (to within 0.01) and the output timestamp.
TEST(DequantizeByteArrayCalculatorTest, TestDequantization) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "DequantizeByteArrayCalculator"
input_stream: "ENCODED:encoded"
output_stream: "FLOAT_VECTOR:float_vector"
options {
[mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
max_quantized_value: 2
min_quantized_value: -2
}
}
)pb");
  CalculatorRunner runner(node_config);
  unsigned char input[4] = {0x7F, 0xFF, 0x00, 0x01};
  runner.MutableInputs()->Tag("ENCODED").packets.push_back(
      MakePacket<std::string>(
          std::string(reinterpret_cast<char const*>(input), 4))
          .At(Timestamp(0)));

  // Run the graph exactly once; the status is checked directly.
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs =
      runner.Outputs().Tag("FLOAT_VECTOR").packets;
  // ASSERT (not EXPECT): outputs[0] is dereferenced below.
  ASSERT_EQ(1u, outputs.size());
  const std::vector<float>& result = outputs[0].Get<std::vector<float>>();
  // ASSERT (not EXPECT): result[0..3] are indexed below.
  ASSERT_EQ(4u, result.size());
  EXPECT_NEAR(0, result[0], 0.01);
  EXPECT_NEAR(2, result[1], 0.01);
  EXPECT_NEAR(-2, result[2], 0.01);
  EXPECT_NEAR(-1.976, result[3], 0.01);
  EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
}
} // namespace mediapipe

View File

@ -1,53 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/end_loop_calculator.h"
#include <vector>
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/util/render_data.pb.h"
#include "tensorflow/lite/interpreter.h"
namespace mediapipe {
// Concrete EndLoopCalculator instantiations, one per collection type, each
// registered under its alias name so it can be referenced from graph configs.

// Collects NormalizedRect items.
using EndLoopNormalizedRectCalculator =
    EndLoopCalculator<std::vector<::mediapipe::NormalizedRect>>;
REGISTER_CALCULATOR(EndLoopNormalizedRectCalculator);

// Collects LandmarkList items.
using EndLoopLandmarkListVectorCalculator =
    EndLoopCalculator<std::vector<::mediapipe::LandmarkList>>;
REGISTER_CALCULATOR(EndLoopLandmarkListVectorCalculator);

// Collects NormalizedLandmarkList items.
using EndLoopNormalizedLandmarkListVectorCalculator =
    EndLoopCalculator<std::vector<::mediapipe::NormalizedLandmarkList>>;
REGISTER_CALCULATOR(EndLoopNormalizedLandmarkListVectorCalculator);

// Collects bool items.
using EndLoopBooleanCalculator = EndLoopCalculator<std::vector<bool>>;
REGISTER_CALCULATOR(EndLoopBooleanCalculator);

// Collects RenderData items.
using EndLoopRenderDataCalculator =
    EndLoopCalculator<std::vector<::mediapipe::RenderData>>;
REGISTER_CALCULATOR(EndLoopRenderDataCalculator);

// Collects ClassificationList items.
using EndLoopClassificationListCalculator =
    EndLoopCalculator<std::vector<::mediapipe::ClassificationList>>;
REGISTER_CALCULATOR(EndLoopClassificationListCalculator);

// Collects TfLiteTensor items.
using EndLoopTensorCalculator = EndLoopCalculator<std::vector<TfLiteTensor>>;
REGISTER_CALCULATOR(EndLoopTensorCalculator);
} // namespace mediapipe

View File

@ -1,106 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// Calculator for completing the processing of loops on iterable collections
// inside a MediaPipe graph. The EndLoopCalculator collects all input packets
// from ITEM input_stream into a collection and upon receiving the flush signal
// from the "BATCH_END" tagged input stream, it emits the aggregated results
// at the original timestamp contained in the "BATCH_END" input stream.
//
// It is designed to be used like:
//
// node {
// calculator: "BeginLoopWithIterableCalculator"
// input_stream: "ITERABLE:input_iterable" # IterableT @ext_ts
// output_stream: "ITEM:input_element" # ItemT @loop_internal_ts
// output_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts
// }
//
// node {
// calculator: "ElementToBlaConverterSubgraph"
// input_stream: "ITEM:input_to_loop_body" # ItemT @loop_internal_ts
// output_stream: "BLA:output_of_loop_body" # ItemU @loop_internal_ts
// }
//
// node {
// calculator: "EndLoopWithOutputCalculator"
// input_stream: "ITEM:output_of_loop_body" # ItemU @loop_internal_ts
// input_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts
// output_stream: "ITERABLE:aggregated_result" # IterableU @ext_ts
// }
// Accumulates ITEM packets into an IterableT and emits the collection on
// ITERABLE when a BATCH_END packet arrives, stamped with the Timestamp value
// carried by that BATCH_END packet.
//
// IterableT must expose `value_type` and `push_back` (e.g. std::vector<T>).
template <typename IterableT>
class EndLoopCalculator : public CalculatorBase {
  using ItemT = typename IterableT::value_type;

 public:
  // Requires the BATCH_END and ITEM inputs and the ITERABLE output, and
  // declares their packet types.
  static absl::Status GetContract(CalculatorContract* cc) {
    RET_CHECK(cc->Inputs().HasTag("BATCH_END"))
        << "Missing BATCH_END tagged input_stream.";
    cc->Inputs().Tag("BATCH_END").Set<Timestamp>();

    RET_CHECK(cc->Inputs().HasTag("ITEM"));
    cc->Inputs().Tag("ITEM").Set<ItemT>();

    RET_CHECK(cc->Outputs().HasTag("ITERABLE"));
    cc->Outputs().Tag("ITERABLE").Set<IterableT>();
    return absl::OkStatus();
  }

  // Appends the current ITEM (if any) to the pending collection, then, if a
  // BATCH_END packet is present, flushes the collection downstream.
  absl::Status Process(CalculatorContext* cc) override {
    const auto& item_stream = cc->Inputs().Tag("ITEM");
    if (!item_stream.IsEmpty()) {
      // The collection is created lazily on the first item of a batch.
      if (!input_stream_collection_) {
        input_stream_collection_.reset(new IterableT);
      }
      input_stream_collection_->push_back(item_stream.template Get<ItemT>());
    }

    const auto& batch_end_stream = cc->Inputs().Tag("BATCH_END");
    if (batch_end_stream.Value().IsEmpty()) {
      // No flush signal at this timestamp.
      return absl::OkStatus();
    }

    const Timestamp loop_control_ts =
        batch_end_stream.template Get<Timestamp>();
    auto& iterable_output = cc->Outputs().Tag("ITERABLE");
    if (input_stream_collection_) {
      // Ownership of the collection passes to the output stream.
      iterable_output.Add(input_stream_collection_.release(), loop_control_ts);
    } else {
      // Nothing was collected: advance the timestamp bound so downstream
      // calculators do not wait for a packet at this timestamp.
      iterable_output.SetNextTimestampBound(
          Timestamp(loop_control_ts.Value() + 1));
    }
    return absl::OkStatus();
  }

 private:
  // Items gathered since the last flush; null while no item is pending.
  std::unique_ptr<IterableT> input_stream_collection_;
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_

View File

@ -1,229 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <utility>
#include <vector>
#include "mediapipe/calculators/core/flow_limiter_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/header_util.h"
namespace mediapipe {
// FlowLimiterCalculator is used to limit the number of frames in flight
// by dropping input frames when necessary.
//
// The input stream "FINISHED" is used to signal the FlowLimiterCalculator
// when a frame is finished processing. Either a non-empty "FINISHED" packet
// or a timestamp bound should be received for each processed frame.
//
// The combination of `max_in_flight: 1` and `max_in_queue: 1` generally gives
// the best throughput/latency balance. Throughput is nearly optimal as the
// graph is never idle as there is always something in the queue. Latency is
// nearly optimal as the queue always stores the latest available frame.
//
// Increasing `max_in_flight` to 2 or more can yield better throughput
// when the graph exhibits a high degree of pipeline parallelism. Decreasing
// `max_in_flight` to 0 can yield a better average latency, but at the cost of
// lower throughput (lower framerate) due to the time during which the graph
// is idle awaiting the next input frame.
//
// Example config:
// node {
// calculator: "FlowLimiterCalculator"
// input_stream: "raw_frames"
// input_stream: "FINISHED:finished"
// input_stream_info: {
// tag_index: 'FINISHED'
// back_edge: true
// }
// output_stream: "sampled_frames"
// output_stream: "ALLOW:allowed_timestamps"
// }
//
// The "ALLOW" stream indicates the transition between accepting frames and
// dropping frames. "ALLOW = true" indicates the start of accepting frames
// including the current timestamp, and "ALLOW = false" indicates the start of
// dropping frames including the current timestamp.
//
// FlowLimiterCalculator provides limited support for multiple input streams.
// The first input stream is treated as the main input stream and successive
// input streams are treated as auxiliary input streams. The auxiliary input
// streams are limited to timestamps passed on the main input stream.
//
class FlowLimiterCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
auto& side_inputs = cc->InputSidePackets();
side_inputs.Tag("OPTIONS").Set<FlowLimiterCalculatorOptions>().Optional();
cc->Inputs().Tag("OPTIONS").Set<FlowLimiterCalculatorOptions>().Optional();
RET_CHECK_GE(cc->Inputs().NumEntries(""), 1);
for (int i = 0; i < cc->Inputs().NumEntries(""); ++i) {
cc->Inputs().Get("", i).SetAny();
cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i)));
}
cc->Inputs().Get("FINISHED", 0).SetAny();
cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Set<int>().Optional();
cc->Outputs().Tag("ALLOW").Set<bool>().Optional();
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
cc->SetProcessTimestampBounds(true);
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) final {
options_ = cc->Options<FlowLimiterCalculatorOptions>();
options_ = tool::RetrieveOptions(options_, cc->InputSidePackets());
if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) {
options_.set_max_in_flight(
cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Get<int>());
}
input_queues_.resize(cc->Inputs().NumEntries(""));
RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs())));
return absl::OkStatus();
}
// Returns true if an additional frame can be released for processing.
// The "ALLOW" output stream indicates this condition at each input frame.
bool ProcessingAllowed() {
return frames_in_flight_.size() < options_.max_in_flight();
}
// Outputs a packet indicating whether a frame was sent or dropped.
void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) {
if (cc->Outputs().HasTag("ALLOW")) {
cc->Outputs().Tag("ALLOW").AddPacket(MakePacket<bool>(allow).At(ts));
}
}
// Sets the timestamp bound or closes an output stream.
void SetNextTimestampBound(Timestamp bound, OutputStream* stream) {
if (bound > Timestamp::Max()) {
stream->Close();
} else {
stream->SetNextTimestampBound(bound);
}
}
// Returns true if a certain timestamp is being processed.
bool IsInFlight(Timestamp timestamp) {
return std::find(frames_in_flight_.begin(), frames_in_flight_.end(),
timestamp) != frames_in_flight_.end();
}
// Handles every auxiliary frame stream (untagged inputs 1..N-1).
//
// Queued packets older than the timestamp bound of the main output (index 0)
// are settled: each one is removed from its queue and forwarded only if its
// timestamp belongs to a frame in flight. Afterwards the output's timestamp
// bound is advanced to the oldest packet still queued, or, when the queue is
// empty, to the next timestamp allowed after the stream's current input.
void ProcessAuxiliaryInputs(CalculatorContext* cc) {
  const Timestamp settled_bound =
      cc->Outputs().Get("", 0).NextTimestampBound();
  const int num_streams = cc->Inputs().NumEntries("");
  for (int i = 1; i < num_streams; ++i) {
    auto& queue = input_queues_[i];
    auto& output = cc->Outputs().Get("", i);

    // Settle everything strictly below the main stream's bound.
    while (!queue.empty() && queue.front().Timestamp() < settled_bound) {
      Packet settled = queue.front();
      queue.pop_front();
      if (IsInFlight(settled.Timestamp())) {
        output.AddPacket(settled);
      }
    }

    // Propagate this stream's timestamp bound.
    const Timestamp bound =
        queue.empty()
            ? cc->Inputs().Get("", i).Value().Timestamp().NextAllowedInStream()
            : queue.front().Timestamp();
    SetNextTimestampBound(bound, &output);
  }
}
// Releases input packets as permitted by the max_in_flight constraint.
//
// Steps, in order:
//   1. refresh the options from the input streams;
//   2. retire frames in flight up to a FINISHED timestamp;
//   3. enqueue the packets that arrived on the frame streams;
//   4. abandon frames in flight older than in_flight_timeout;
//   5. release queued main-stream frames while processing is allowed;
//   6. drop queued frames beyond max_in_queue (reported as ALLOW=false);
//   7. propagate timestamp bounds, then service the auxiliary streams.
absl::Status Process(CalculatorContext* cc) final {
  options_ = tool::RetrieveOptions(options_, cc->Inputs());

  // Retire every frame in flight that the FINISHED stream has caught up with.
  // Only acted upon when FINISHED is the input that triggered this call.
  const Packet finished = cc->Inputs().Tag("FINISHED").Value();
  if (finished.Timestamp() == cc->InputTimestamp()) {
    const Timestamp finished_ts = finished.Timestamp();
    while (!frames_in_flight_.empty() &&
           frames_in_flight_.front() <= finished_ts) {
      frames_in_flight_.pop_front();
    }
  }

  // Enqueue whatever arrived on the frame input streams.
  const int num_streams = cc->Inputs().NumEntries("");
  for (int i = 0; i < num_streams; ++i) {
    const Packet& incoming = cc->Inputs().Get("", i).Value();
    if (incoming.IsEmpty()) {
      continue;
    }
    input_queues_[i].push_back(incoming);
  }

  // Abandon expired frames in flight. Expiry is measured in stream
  // timestamps: old frames are abandoned when much newer frame timestamps
  // arrive, regardless of elapsed time.
  const TimestampDiff timeout = options_.in_flight_timeout();
  const Timestamp latest_ts = cc->Inputs().Get("", 0).Value().Timestamp();
  const bool main_frame_arrived =
      latest_ts == cc->InputTimestamp() && latest_ts < Timestamp::Max();
  if (timeout > 0 && main_frame_arrived) {
    while (!frames_in_flight_.empty() &&
           (latest_ts - frames_in_flight_.front()) > timeout) {
      frames_in_flight_.pop_front();
    }
  }

  // Release allowed frames from the main input queue.
  auto& main_queue = input_queues_[0];
  auto& main_output = cc->Outputs().Get("", 0);
  while (ProcessingAllowed() && !main_queue.empty()) {
    const Packet released = main_queue.front();
    main_queue.pop_front();
    main_output.AddPacket(released);
    SendAllow(true, released.Timestamp(), cc);
    frames_in_flight_.push_back(released.Timestamp());
  }

  // Limit the number of queued frames, oldest first.
  // Frames can be dropped after frames are released because frame-packets
  // and FINISH-packets never arrive in the same Process call.
  while (main_queue.size() > options_.max_in_queue()) {
    const Packet dropped = main_queue.front();
    main_queue.pop_front();
    SendAllow(false, dropped.Timestamp(), cc);
  }

  // Propagate the input timestamp bound. The "ALLOW" bound only advances
  // when nothing is left waiting for a decision.
  if (main_queue.empty()) {
    const Timestamp bound = latest_ts.NextAllowedInStream();
    SetNextTimestampBound(bound, &main_output);
    if (cc->Outputs().HasTag("ALLOW")) {
      SetNextTimestampBound(bound, &cc->Outputs().Tag("ALLOW"));
    }
  } else {
    SetNextTimestampBound(main_queue.front().Timestamp(), &main_output);
  }

  ProcessAuxiliaryInputs(cc);
  return absl::OkStatus();
}
private:
// Effective options; assigned in Open() and refreshed from the input streams
// at the start of every Process() call.
FlowLimiterCalculatorOptions options_;
// One FIFO of pending packets per untagged input stream. Index 0 is the main
// frame stream; the remaining indices are the auxiliary streams.
std::vector<std::deque<Packet>> input_queues_;
// Timestamps of frames released on output 0 that have neither been reported
// on FINISHED nor timed out yet, oldest first.
std::deque<Timestamp> frames_in_flight_;
};
REGISTER_CALCULATOR(FlowLimiterCalculator);
} // namespace mediapipe

View File

@ -1,40 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
option objc_class_prefix = "MediaPipe";
// Options for FlowLimiterCalculator.
message FlowLimiterCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional FlowLimiterCalculatorOptions ext = 326963320;
}
// The maximum number of frames released for processing at one time.
// The default value limits to 1 frame processing at a time.
optional int32 max_in_flight = 1 [default = 1];
// The maximum number of frames queued waiting for processing.
// The default value 0 keeps no frame waiting: a frame that cannot be
// released immediately is dropped. Set to 1 to keep one frame awaiting
// processing.
optional int32 max_in_queue = 2 [default = 0];
// The maximum time in microseconds to wait for a frame to finish processing.
// The calculator measures this against the timestamps of newly arriving
// frames rather than against elapsed wall time.
// The default value stops waiting after 1 sec.
// The value 0 specifies no timeout.
optional int64 in_flight_timeout = 3 [default = 1000000];
}

View File

@ -1,760 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "mediapipe/calculators/core/flow_limiter_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/framework/tool/simulation_clock.h"
#include "mediapipe/framework/tool/simulation_clock_executor.h"
#include "mediapipe/framework/tool/sink.h"
namespace mediapipe {
namespace {
// A minimal spinning semaphore used to synchronize test threads.
// Acquire() busy-waits until the requested amount is available.
class AtomicSemaphore {
 public:
  AtomicSemaphore(int64_t supply) : supply_(supply) {}

  // Takes `amount` units from the supply, retrying until the supply does not
  // go negative. A failed attempt gives its units back before retrying.
  void Acquire(int64_t amount) {
    for (;;) {
      const int64_t remaining = supply_.fetch_sub(amount) - amount;
      if (remaining >= 0) {
        return;
      }
      Release(amount);
    }
  }

  // Returns `amount` units to the supply.
  void Release(int64_t amount) { supply_.fetch_add(amount); }

 private:
  std::atomic<int64_t> supply_;
};
// Returns the timestamp values for a vector of Packets.
std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
std::vector<int64> result;
for (const Packet& packet : packets) {
result.push_back(packet.Timestamp().Value());
}
return result;
}
// Returns the packet values for a vector of Packets.
template <typename T>
std::vector<T> PacketValues(const std::vector<Packet>& packets) {
std::vector<T> result;
for (const Packet& packet : packets) {
result.push_back(packet.Get<T>());
}
return result;
}
// Signature of a Calculator::Process callback: reads the input stream shards
// and writes to the output stream shards.
using ProcessFunction = std::function<absl::Status(
    const InputStreamShardSet&, OutputStreamShardSet*)>;
// A testing callback that forwards every non-empty input packet to the
// output stream with the same index.
absl::Status PassthroughFunction(const InputStreamShardSet& inputs,
                                 OutputStreamShardSet* outputs) {
  const int num_streams = inputs.NumEntries();
  for (int i = 0; i < num_streams; ++i) {
    const Packet& packet = inputs.Index(i).Value();
    if (packet.IsEmpty()) {
      continue;
    }
    outputs->Index(i).AddPacket(packet);
  }
  return absl::OkStatus();
}
// Test fixture demonstrating a FlowLimiterCalculator operating in a cyclic
// graph. A LambdaCalculator whose callback blocks on a semaphore keeps frames
// in flight until the test body releases them.
class FlowLimiterCalculatorSemaphoreTest : public testing::Test {
public:
// The semaphore starts with no supply, so the callback blocks until the test
// calls exit_semaphore_.Release().
FlowLimiterCalculatorSemaphoreTest() : exit_semaphore_(0) {}
// Builds the graph config and attaches a sink that collects the packets of
// stream "out_1" into out_1_packets_.
void SetUp() override {
graph_config_ = InflightGraphConfig();
tool::AddVectorSink("out_1", &graph_config_, &out_1_packets_);
}
// Initializes graph_ with the given max_in_flight and max_in_queue = 1, and
// creates a poller on the "allow" stream. The options and the callback are
// handed to the graph as side packets.
void InitializeGraph(int max_in_flight) {
// Waits for one unit of the semaphore, then passes all packets through.
ProcessFunction semaphore_1_func = [&](const InputStreamShardSet& inputs,
OutputStreamShardSet* outputs) {
exit_semaphore_.Acquire(1);
return PassthroughFunction(inputs, outputs);
};
FlowLimiterCalculatorOptions options;
options.set_max_in_flight(max_in_flight);
options.set_max_in_queue(1);
MP_ASSERT_OK(graph_.Initialize(
graph_config_, {
{"limiter_options", Adopt(new auto(options))},
{"callback_1", Adopt(new auto(semaphore_1_func))},
}));
allow_poller_.reset(
new OutputStreamPoller(graph_.AddOutputStreamPoller("allow").value()));
}
// Adds an int packet to a graph input stream; the value doubles as the
// packet timestamp.
void AddPacket(const std::string& input_name, int value) {
MP_EXPECT_OK(graph_.AddPacketToInputStream(
input_name, MakePacket<int>(value).At(Timestamp(value))));
}
// A calculator graph starting with a FlowLimiterCalculator and
// ending with a LambdaCalculator.
// The back edge "FINISHED:out_1" tells the limiter when a frame is done.
// The LambdaCalculator is used to keep certain frames in flight.
CalculatorGraphConfig InflightGraphConfig() {
return ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in_1'
node {
calculator: 'FlowLimiterCalculator'
input_side_packet: 'OPTIONS:limiter_options'
input_stream: 'in_1'
input_stream: 'FINISHED:out_1'
input_stream_info: { tag_index: 'FINISHED' back_edge: true }
output_stream: 'in_1_sampled'
output_stream: 'ALLOW:allow'
}
node {
calculator: 'LambdaCalculator'
input_side_packet: 'callback_1'
input_stream: 'in_1_sampled'
output_stream: 'out_1'
}
)pb");
}
protected:
// Graph configuration built in SetUp().
CalculatorGraphConfig graph_config_;
// The graph under test.
CalculatorGraph graph_;
// Gates the LambdaCalculator callback; one unit lets one Process call run.
AtomicSemaphore exit_semaphore_;
// Packets observed on stream "out_1".
std::vector<Packet> out_1_packets_;
// Poller for the limiter's "allow" stream, created by InitializeGraph().
std::unique_ptr<OutputStreamPoller> allow_poller_;
};
// A test demonstrating a FlowLimiterCalculator operating in a cyclic
// graph. This test shows that:
//
// (1) Frames exceeding the queue size are dropped.
// (2) The "ALLOW" signal is produced.
// (3) Timestamps are passed through unaltered.
//
TEST_F(FlowLimiterCalculatorSemaphoreTest, FramesDropped) {
InitializeGraph(1);
MP_ASSERT_OK(graph_.StartRun({}));
// Sends an int64 packet whose value equals its timestamp.
auto send_packet = [this](const std::string& input_name, int64 n) {
MP_EXPECT_OK(graph_.AddPacketToInputStream(
input_name, MakePacket<int64>(n).At(Timestamp(n))));
};
Packet allow_packet;
send_packet("in_1", 0);
for (int i = 0; i < 9; i++) {
// The frame sent at timestamp i * 10 is released (ALLOW == true).
EXPECT_TRUE(allow_poller_->Next(&allow_packet));
EXPECT_TRUE(allow_packet.Get<bool>());
// This input should wait in the limiter input queue.
send_packet("in_1", i * 10 + 5);
// This input should drop the previous input.
send_packet("in_1", i * 10 + 10);
EXPECT_TRUE(allow_poller_->Next(&allow_packet));
EXPECT_FALSE(allow_packet.Get<bool>());
// Let the LambdaCalculator finish the frame that is in flight.
exit_semaphore_.Release(1);
}
// Let the last released frame (timestamp 90) finish as well.
exit_semaphore_.Release(1);
MP_EXPECT_OK(graph_.CloseInputStream("in_1"));
MP_EXPECT_OK(graph_.WaitUntilIdle());
// All output streams are closed and all output packets are delivered,
// with stream "in_1" closed.
EXPECT_EQ(10, out_1_packets_.size());
// Timestamps have not been altered.
EXPECT_EQ(PacketValues<int64>(out_1_packets_),
TimestampValues(out_1_packets_));
// Extra inputs on in_1 have been dropped.
EXPECT_EQ(TimestampValues(out_1_packets_),
(std::vector<int64>{0, 10, 20, 30, 40, 50, 60, 70, 80, 90}));
}
// A pass-through calculator that sleeps on the supplied clock during Process.
// The first packet sleeps for WARMUP_TIME microseconds, every later packet
// for SLEEP_TIME microseconds.
class SleepCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Tag("PACKET").SetAny();
    cc->Outputs().Tag("PACKET").SetSameAs(&cc->Inputs().Tag("PACKET"));
    auto& side_packets = cc->InputSidePackets();
    side_packets.Tag("SLEEP_TIME").Set<int64>();
    side_packets.Tag("WARMUP_TIME").Set<int64>();
    side_packets.Tag("CLOCK").Set<mediapipe::Clock*>();
    cc->SetTimestampOffset(0);
    return absl::OkStatus();
  }

  // Remembers the clock that Process() sleeps on.
  absl::Status Open(CalculatorContext* cc) final {
    clock_ = cc->InputSidePackets().Tag("CLOCK").Get<mediapipe::Clock*>();
    return absl::OkStatus();
  }

  // Sleeps, then forwards the input packet unchanged.
  absl::Status Process(CalculatorContext* cc) final {
    ++packet_count;
    const bool warming_up = (packet_count == 1);
    const auto& delay_packet =
        warming_up ? cc->InputSidePackets().Tag("WARMUP_TIME")
                   : cc->InputSidePackets().Tag("SLEEP_TIME");
    clock_->Sleep(absl::Microseconds(delay_packet.Get<int64>()));
    cc->Outputs().Tag("PACKET").AddPacket(cc->Inputs().Tag("PACKET").Value());
    return absl::OkStatus();
  }

 private:
  ::mediapipe::Clock* clock_ = nullptr;
  // Number of Process() calls so far.
  int packet_count = 0;
};
REGISTER_CALCULATOR(SleepCalculator);
// A calculator that loses one packet.
// Drops the 3rd packet and, when DROP_TIMESTAMPS is true, also withholds the
// corresponding timestamp bound.
class DropCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Tag("PACKET").SetAny();
    cc->Outputs().Tag("PACKET").SetSameAs(&cc->Inputs().Tag("PACKET"));
    cc->InputSidePackets().Tag("DROP_TIMESTAMPS").Set<bool>();
    cc->SetProcessTimestampBounds(true);
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) final {
    const Packet& input = cc->Inputs().Tag("PACKET").Value();
    const bool has_packet = !input.IsEmpty();
    if (has_packet) {
      ++packet_count;
    }
    // Stays true for bound-only calls until the 4th packet arrives.
    const bool drop = (packet_count == 3);
    if (has_packet && !drop) {
      cc->Outputs().Tag("PACKET").AddPacket(input);
    }
    const bool withhold_bound =
        drop && cc->InputSidePackets().Tag("DROP_TIMESTAMPS").Get<bool>();
    if (!withhold_bound) {
      cc->Outputs().Tag("PACKET").SetNextTimestampBound(
          cc->InputTimestamp().NextAllowedInStream());
    }
    return absl::OkStatus();
  }

 private:
  // Number of non-empty input packets seen so far.
  int packet_count = 0;
};
REGISTER_CALCULATOR(DropCalculator);
// Tests demonstrating an FlowLimiterCalculator processing FINISHED timestamps.
class FlowLimiterCalculatorTest : public testing::Test {
protected:
CalculatorGraphConfig InflightGraphConfig() {
return ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in_1'
node {
calculator: 'FlowLimiterCalculator'
input_side_packet: 'OPTIONS:limiter_options'
input_stream: 'in_1'
input_stream: 'FINISHED:out_1'
input_stream_info: { tag_index: 'FINISHED' back_edge: true }
output_stream: 'in_1_sampled'
output_stream: 'ALLOW:allow'
}
node {
calculator: 'SleepCalculator'
input_side_packet: 'WARMUP_TIME:warmup_time'
input_side_packet: 'SLEEP_TIME:sleep_time'
input_side_packet: 'CLOCK:clock'
input_stream: 'PACKET:in_1_sampled'
output_stream: 'PACKET:out_1_sampled'
}
node {
calculator: 'DropCalculator'
input_side_packet: "DROP_TIMESTAMPS:drop_timesamps"
input_stream: 'PACKET:out_1_sampled'
output_stream: 'PACKET:out_1'
}
)pb");
}
// Parse an absl::Time from RFC3339 format.
absl::Time ParseTime(const std::string& date_time_str) {
absl::Time result;
absl::ParseTime(absl::RFC3339_sec, date_time_str, &result, nullptr);
return result;
}
// The point in simulated time when the test starts.
absl::Time StartTime() { return ParseTime("2020-11-03T20:00:00Z"); }
// Initialize the test clock to follow simulated time.
void SetUpSimulationClock() {
auto executor = std::make_shared<SimulationClockExecutor>(8);
simulation_clock_ = executor->GetClock();
clock_ = simulation_clock_.get();
simulation_clock_->ThreadStart();
clock_->SleepUntil(StartTime());
simulation_clock_->ThreadFinish();
MP_ASSERT_OK(graph_.SetExecutor("", executor));
}
// Initialize the test clock to follow wall time.
void SetUpRealClock() { clock_ = mediapipe::Clock::RealClock(); }
// Create a few mediapipe input Packets holding ints.
void SetUpInputData() {
for (int i = 0; i < 100; ++i) {
input_packets_.push_back(MakePacket<int>(i).At(Timestamp(i * 10000)));
}
}
protected:
CalculatorGraph graph_;
mediapipe::Clock* clock_;
std::shared_ptr<SimulationClock> simulation_clock_;
std::vector<Packet> input_packets_;
std::vector<Packet> out_1_packets_;
std::vector<Packet> allow_packets_;
};
// Shows that "FINISHED" can be indicated with either a packet or a timestamp
// bound. DropCalculator periodically drops one packet but always propagates
// the timestamp bound. Input packets are released or dropped promptly after
// each "FINISH" packet or a timestamp bound arrives.
TEST_F(FlowLimiterCalculatorTest, FinishedTimestamps) {
// Configure the test.
SetUpInputData();
SetUpSimulationClock();
CalculatorGraphConfig graph_config = InflightGraphConfig();
auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(R"pb(
max_in_flight: 1
max_in_queue: 1
)pb");
// Every frame takes 22 ms to process; DropCalculator keeps propagating
// timestamp bounds (drop_timesamps == false).
std::map<std::string, Packet> side_packets = {
{"limiter_options",
MakePacket<FlowLimiterCalculatorOptions>(limiter_options)},
{"warmup_time", MakePacket<int64>(22000)},
{"sleep_time", MakePacket<int64>(22000)},
{"drop_timesamps", MakePacket<bool>(false)},
{"clock", MakePacket<mediapipe::Clock*>(clock_)},
};
// Start the graph.
MP_ASSERT_OK(graph_.Initialize(graph_config));
MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) {
out_1_packets_.push_back(p);
return absl::OkStatus();
}));
MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) {
allow_packets_.push_back(p);
return absl::OkStatus();
}));
simulation_clock_->ThreadStart();
MP_ASSERT_OK(graph_.StartRun(side_packets));
// Add 9 input packets.
// 1. packet-0 is released,
// 2. packet-1 is queued,
// 3. packet-2 is queued and packet-1 is dropped,
// 4. packet-2 is released, and so forth.
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[0]));
clock_->Sleep(absl::Microseconds(1));
EXPECT_EQ(allow_packets_.size(), 1);
EXPECT_EQ(allow_packets_.back().Get<bool>(), true);
clock_->Sleep(absl::Microseconds(10000));
for (int i = 1; i < 8; i += 2) {
// Packet i is queued: no new ALLOW decision is expected yet.
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i]));
clock_->Sleep(absl::Microseconds(10000));
EXPECT_EQ(allow_packets_.size(), i);
// Packet i + 1 arrives and a drop decision (ALLOW == false) is expected.
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i + 1]));
clock_->Sleep(absl::Microseconds(1));
EXPECT_EQ(allow_packets_.size(), i + 1);
EXPECT_EQ(allow_packets_.back().Get<bool>(), false);
// After a further 10 ms a release decision (ALLOW == true) is expected.
clock_->Sleep(absl::Microseconds(10000));
EXPECT_EQ(allow_packets_.size(), i + 2);
EXPECT_EQ(allow_packets_.back().Get<bool>(), true);
}
// Finish the graph.
MP_EXPECT_OK(graph_.CloseAllPacketSources());
clock_->Sleep(absl::Microseconds(40000));
MP_EXPECT_OK(graph_.WaitUntilDone());
simulation_clock_->ThreadFinish();
// Validate the output.
// input_packets_[4] is dropped by the DropCalculator.
std::vector<Packet> expected_output = {input_packets_[0], input_packets_[2],
input_packets_[6], input_packets_[8]};
EXPECT_EQ(out_1_packets_, expected_output);
}
// Shows that an output packet can be lost completely, and the
// FlowLimiterCalculator will stop waiting for it after in_flight_timeout.
// DropCalculator completely loses one packet including its timestamp bound.
// FlowLimiterCalculator waits 100 ms, and then starts releasing packets again.
TEST_F(FlowLimiterCalculatorTest, FinishedLost) {
// Configure the test.
SetUpInputData();
SetUpSimulationClock();
CalculatorGraphConfig graph_config = InflightGraphConfig();
auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(R"pb(
max_in_flight: 1
max_in_queue: 1
in_flight_timeout: 100000 # 100 ms
)pb");
// Every frame takes 22 ms to process; DropCalculator also withholds the
// timestamp bound of the packet it drops (drop_timesamps == true).
std::map<std::string, Packet> side_packets = {
{"limiter_options",
MakePacket<FlowLimiterCalculatorOptions>(limiter_options)},
{"warmup_time", MakePacket<int64>(22000)},
{"sleep_time", MakePacket<int64>(22000)},
{"drop_timesamps", MakePacket<bool>(true)},
{"clock", MakePacket<mediapipe::Clock*>(clock_)},
};
// Start the graph.
MP_ASSERT_OK(graph_.Initialize(graph_config));
MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) {
out_1_packets_.push_back(p);
return absl::OkStatus();
}));
MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) {
allow_packets_.push_back(p);
return absl::OkStatus();
}));
simulation_clock_->ThreadStart();
MP_ASSERT_OK(graph_.StartRun(side_packets));
// Add 21 input packets.
// 1. packet-0 is released, packet-1 queued and dropped, and so forth.
// 2. packet-4 is lost by DropCalculator.
// 3. packet-5 through 13 are dropped while waiting for packet-4.
// 4. packet-4 expires and queued packet-14 is released.
// 5. packet-17, 19, and 20 are released on time.
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[0]));
clock_->Sleep(absl::Microseconds(10000));
// One packet every 10 ms of simulated time.
for (int i = 1; i < 21; ++i) {
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i]));
clock_->Sleep(absl::Microseconds(10000));
}
// Finish the graph.
MP_EXPECT_OK(graph_.CloseAllPacketSources());
clock_->Sleep(absl::Microseconds(40000));
MP_EXPECT_OK(graph_.WaitUntilDone());
simulation_clock_->ThreadFinish();
// Validate the output.
// input_packets_[4] is lost by the DropCalculator.
std::vector<Packet> expected_output = {
input_packets_[0], input_packets_[2], input_packets_[14],
input_packets_[17], input_packets_[19], input_packets_[20],
};
EXPECT_EQ(out_1_packets_, expected_output);
}
// Shows what happens when a finish packet is delayed beyond in_flight_timeout.
// After in_flight_timeout, FlowLimiterCalculator continues releasing packets.
// Temporarily, more than max_in_flight frames are in flight.
// Eventually, the number of frames in flight returns to max_in_flight.
TEST_F(FlowLimiterCalculatorTest, FinishedDelayed) {
// Configure the test.
SetUpInputData();
SetUpSimulationClock();
CalculatorGraphConfig graph_config = InflightGraphConfig();
auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(R"pb(
max_in_flight: 1
max_in_queue: 1
in_flight_timeout: 100000 # 100 ms
)pb");
// The first frame takes 500 ms (warmup), later frames 22 ms each;
// DropCalculator keeps propagating timestamp bounds (drop_timesamps == false).
std::map<std::string, Packet> side_packets = {
{"limiter_options",
MakePacket<FlowLimiterCalculatorOptions>(limiter_options)},
{"warmup_time", MakePacket<int64>(500000)},
{"sleep_time", MakePacket<int64>(22000)},
{"drop_timesamps", MakePacket<bool>(false)},
{"clock", MakePacket<mediapipe::Clock*>(clock_)},
};
// Start the graph.
MP_ASSERT_OK(graph_.Initialize(graph_config));
MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) {
out_1_packets_.push_back(p);
return absl::OkStatus();
}));
MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) {
allow_packets_.push_back(p);
return absl::OkStatus();
}));
simulation_clock_->ThreadStart();
MP_ASSERT_OK(graph_.StartRun(side_packets));
// Add 71 input packets.
// 1. During the 500 ms WARMUP_TIME, the in_flight_timeout releases
// packets 0, 10, 20, 30, 40, 50, which are queued at the SleepCalculator.
// 2. During the next 120 ms, these 6 packets are processed.
// 3. After the graph is finally finished with warmup and the backlog packets,
// packets 60 through 70 are released and processed on time.
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[0]));
clock_->Sleep(absl::Microseconds(10000));
// One packet every 10 ms of simulated time.
for (int i = 1; i < 71; ++i) {
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i]));
clock_->Sleep(absl::Microseconds(10000));
}
// Finish the graph.
MP_EXPECT_OK(graph_.CloseAllPacketSources());
clock_->Sleep(absl::Microseconds(40000));
MP_EXPECT_OK(graph_.WaitUntilDone());
simulation_clock_->ThreadFinish();
// Validate the output.
// The graph is warming up or backlogged until packet 60.
// NOTE(review): input_packets_[20] is absent below, presumably because it is
// the 3rd packet to reach DropCalculator, which drops it - confirm.
std::vector<Packet> expected_output = {
input_packets_[0], input_packets_[10], input_packets_[30],
input_packets_[40], input_packets_[50], input_packets_[60],
input_packets_[63], input_packets_[65], input_packets_[67],
input_packets_[69], input_packets_[70],
};
EXPECT_EQ(out_1_packets_, expected_output);
}
// Shows that packets on auxiliary input streams are released for the same
// timestamps as the main input stream, whether the auxiliary packets arrive
// early or late.
TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) {
// Configure the test.
SetUpInputData();
SetUpSimulationClock();
// Same graph as InflightGraphConfig(), plus the auxiliary stream pair
// in_2 -> in_2_sampled on the limiter.
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in_1'
input_stream: 'in_2'
node {
calculator: 'FlowLimiterCalculator'
input_side_packet: 'OPTIONS:limiter_options'
input_stream: 'in_1'
input_stream: 'in_2'
input_stream: 'FINISHED:out_1'
input_stream_info: { tag_index: 'FINISHED' back_edge: true }
output_stream: 'in_1_sampled'
output_stream: 'in_2_sampled'
output_stream: 'ALLOW:allow'
}
node {
calculator: 'SleepCalculator'
input_side_packet: 'WARMUP_TIME:warmup_time'
input_side_packet: 'SLEEP_TIME:sleep_time'
input_side_packet: 'CLOCK:clock'
input_stream: 'PACKET:in_1_sampled'
output_stream: 'PACKET:out_1_sampled'
}
node {
calculator: 'DropCalculator'
input_side_packet: "DROP_TIMESTAMPS:drop_timesamps"
input_stream: 'PACKET:out_1_sampled'
output_stream: 'PACKET:out_1'
}
)pb");
auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(R"pb(
max_in_flight: 1
max_in_queue: 1
in_flight_timeout: 100000 # 100 ms
)pb");
// Every frame takes 22 ms to process; DropCalculator also withholds the
// timestamp bound of the packet it drops (drop_timesamps == true).
std::map<std::string, Packet> side_packets = {
{"limiter_options",
MakePacket<FlowLimiterCalculatorOptions>(limiter_options)},
{"warmup_time", MakePacket<int64>(22000)},
{"sleep_time", MakePacket<int64>(22000)},
{"drop_timesamps", MakePacket<bool>(true)},
{"clock", MakePacket<mediapipe::Clock*>(clock_)},
};
// Start the graph.
MP_ASSERT_OK(graph_.Initialize(graph_config));
MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) {
out_1_packets_.push_back(p);
return absl::OkStatus();
}));
// Packets released by the limiter on the auxiliary stream.
std::vector<Packet> out_2_packets;
MP_EXPECT_OK(graph_.ObserveOutputStream("in_2_sampled", [&](Packet p) {
out_2_packets.push_back(p);
return absl::OkStatus();
}));
MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) {
allow_packets_.push_back(p);
return absl::OkStatus();
}));
simulation_clock_->ThreadStart();
MP_ASSERT_OK(graph_.StartRun(side_packets));
// Add packets 0..9 to stream in_1, and packets 0..8 to stream in_2.
// Each in_2 packet arrives one step later than the in_1 packet with the
// same timestamp.
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[0]));
clock_->Sleep(absl::Microseconds(10000));
for (int i = 1; i < 10; ++i) {
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i]));
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_2", input_packets_[i - 1]));
clock_->Sleep(absl::Microseconds(10000));
}
// Add packets 10..20 to stream in_1, and packets 11..21 to stream in_2.
// Each in_2 packet now arrives one step earlier than the in_1 packet with
// the same timestamp. Packets 9 and 10 are never sent on in_2.
for (int i = 10; i < 21; ++i) {
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_2", input_packets_[i + 1]));
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i]));
clock_->Sleep(absl::Microseconds(10000));
}
// Finish the graph run.
MP_EXPECT_OK(graph_.CloseAllPacketSources());
clock_->Sleep(absl::Microseconds(40000));
MP_EXPECT_OK(graph_.WaitUntilDone());
simulation_clock_->ThreadFinish();
// Validate the output.
// Packet input_packets_[4] is lost by the DropCalculator.
std::vector<Packet> expected_output = {
input_packets_[0], input_packets_[2], input_packets_[14],
input_packets_[17], input_packets_[19], input_packets_[20],
};
EXPECT_EQ(out_1_packets_, expected_output);
// Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled.
std::vector<Packet> expected_output_2 = {
input_packets_[0], input_packets_[2], input_packets_[4],
input_packets_[14], input_packets_[17], input_packets_[19],
input_packets_[20],
};
EXPECT_EQ(out_2_packets, expected_output_2);
}
// Shows how FlowLimiterCalculator releases packets with max_in_queue 0.
// Shows how auxiliary input streams still work with max_in_queue 0.
// The processing time "sleep_time" is reduced from 22ms to 12ms to create
// the same frame rate as FlowLimiterCalculatorTest::TwoInputStreams.
TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
// Configure the test.
SetUpInputData();
SetUpSimulationClock();
// Same graph as InflightGraphConfig(), plus the auxiliary stream pair
// in_2 -> in_2_sampled on the limiter.
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in_1'
input_stream: 'in_2'
node {
calculator: 'FlowLimiterCalculator'
input_side_packet: 'OPTIONS:limiter_options'
input_stream: 'in_1'
input_stream: 'in_2'
input_stream: 'FINISHED:out_1'
input_stream_info: { tag_index: 'FINISHED' back_edge: true }
output_stream: 'in_1_sampled'
output_stream: 'in_2_sampled'
output_stream: 'ALLOW:allow'
}
node {
calculator: 'SleepCalculator'
input_side_packet: 'WARMUP_TIME:warmup_time'
input_side_packet: 'SLEEP_TIME:sleep_time'
input_side_packet: 'CLOCK:clock'
input_stream: 'PACKET:in_1_sampled'
output_stream: 'PACKET:out_1_sampled'
}
node {
calculator: 'DropCalculator'
input_side_packet: "DROP_TIMESTAMPS:drop_timesamps"
input_stream: 'PACKET:out_1_sampled'
output_stream: 'PACKET:out_1'
}
)pb");
auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(R"pb(
max_in_flight: 1
max_in_queue: 0
in_flight_timeout: 100000 # 100 ms
)pb");
// Every frame takes 12 ms to process; DropCalculator also withholds the
// timestamp bound of the packet it drops (drop_timesamps == true).
std::map<std::string, Packet> side_packets = {
{"limiter_options",
MakePacket<FlowLimiterCalculatorOptions>(limiter_options)},
{"warmup_time", MakePacket<int64>(12000)},
{"sleep_time", MakePacket<int64>(12000)},
{"drop_timesamps", MakePacket<bool>(true)},
{"clock", MakePacket<mediapipe::Clock*>(clock_)},
};
// Start the graph.
MP_ASSERT_OK(graph_.Initialize(graph_config));
MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) {
out_1_packets_.push_back(p);
return absl::OkStatus();
}));
// Packets released by the limiter on the auxiliary stream.
std::vector<Packet> out_2_packets;
MP_EXPECT_OK(graph_.ObserveOutputStream("in_2_sampled", [&](Packet p) {
out_2_packets.push_back(p);
return absl::OkStatus();
}));
MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) {
allow_packets_.push_back(p);
return absl::OkStatus();
}));
simulation_clock_->ThreadStart();
MP_ASSERT_OK(graph_.StartRun(side_packets));
// Add packets 0..9 to stream in_1, and packets 0..8 to stream in_2.
// Each in_2 packet arrives one step later than the in_1 packet with the
// same timestamp.
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[0]));
clock_->Sleep(absl::Microseconds(10000));
for (int i = 1; i < 10; ++i) {
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i]));
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_2", input_packets_[i - 1]));
clock_->Sleep(absl::Microseconds(10000));
}
// Add packets 10..20 to stream in_1, and packets 11..21 to stream in_2.
// Each in_2 packet now arrives one step earlier than the in_1 packet with
// the same timestamp. Packets 9 and 10 are never sent on in_2.
for (int i = 10; i < 21; ++i) {
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_2", input_packets_[i + 1]));
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i]));
clock_->Sleep(absl::Microseconds(10000));
}
// Finish the graph run.
MP_EXPECT_OK(graph_.CloseAllPacketSources());
clock_->Sleep(absl::Microseconds(40000));
MP_EXPECT_OK(graph_.WaitUntilDone());
simulation_clock_->ThreadFinish();
// Validate the output.
// Packet input_packets_[4] is lost by the DropCalculator.
std::vector<Packet> expected_output = {
input_packets_[0], input_packets_[2], input_packets_[15],
input_packets_[17], input_packets_[19],
};
EXPECT_EQ(out_1_packets_, expected_output);
// Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled.
std::vector<Packet> expected_output_2 = {
input_packets_[0], input_packets_[2], input_packets_[4],
input_packets_[15], input_packets_[17], input_packets_[19],
};
EXPECT_EQ(out_2_packets, expected_output_2);
}
} // anonymous namespace
} // namespace mediapipe

View File

@ -1,219 +0,0 @@
// Copyright 2019-2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/gate_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/header_util.h"
namespace mediapipe {
namespace {
// Gate states; ToString() below renders each value as text.
// NOTE(review): the meanings noted on the enumerators are inferred from
// their names - confirm against the calculator body.
enum GateState {
// Presumably the state before any ALLOW/DISALLOW value has been seen.
GATE_UNINITIALIZED,
// Presumably: input packets are passed through.
GATE_ALLOW,
// Presumably: input packets are held back.
GATE_DISALLOW,
};
// Returns the display name of a GateState ("UNINITIALIZED", "ALLOW" or
// "DISALLOW"). Any other value is logged with DLOG(FATAL) and rendered as
// "UNKNOWN".
std::string ToString(GateState state) {
  if (state == GATE_UNINITIALIZED) {
    return "UNINITIALIZED";
  }
  if (state == GATE_ALLOW) {
    return "ALLOW";
  }
  if (state == GATE_DISALLOW) {
    return "DISALLOW";
  }
  DLOG(FATAL) << "Unknown GateState";
  return "UNKNOWN";
}
} // namespace
// Controls whether or not the input packets are passed further along the graph.
// Takes multiple data input streams and either an ALLOW or a DISALLOW control
// input stream. It outputs an output stream for each input stream that is not
// ALLOW or DISALLOW as well as an optional STATE_CHANGE stream which downstream
// calculators can use to respond to state-change events.
//
// If the current ALLOW packet is set to true, the input packets are passed to
// their corresponding output stream unchanged. If the ALLOW packet is set to
// false, the current input packet is NOT passed to the output stream. If using
// DISALLOW, the behavior is opposite of ALLOW.
//
// By default, an empty packet in the ALLOW or DISALLOW input stream indicates
// disallowing the corresponding packets in other input streams. The behavior
// can be inverted with a calculator option.
//
// ALLOW or DISALLOW can also be specified as an input side packet. The rules
// for evaluation remain the same as above.
//
// ALLOW/DISALLOW inputs must be specified either using input stream or
// via input side packet but not both.
//
// Intended to be used with the default input stream handler, which synchronizes
// all data input streams with the ALLOW/DISALLOW control input stream.
//
// Example config:
// node {
// calculator: "GateCalculator"
// input_side_packet: "ALLOW:allow" or "DISALLOW:disallow"
// input_stream: "input_stream0"
// input_stream: "input_stream1"
// input_stream: "input_streamN"
// input_stream: "ALLOW:allow" or "DISALLOW:disallow"
// output_stream: "STATE_CHANGE:state_change"
// output_stream: "output_stream0"
// output_stream: "output_stream1"
// output_stream: "output_streamN"
// }
class GateCalculator : public CalculatorBase {
 public:
  GateCalculator() = default;

  // Validates that exactly one ALLOW/DISALLOW control input is present --
  // either as an input side packet or as an input stream, never both and
  // never both tags at once -- and declares that input as carrying a bool.
  static absl::Status CheckAndInitAllowDisallowInputs(CalculatorContract* cc) {
    bool input_via_side_packet = cc->InputSidePackets().HasTag("ALLOW") ||
                                 cc->InputSidePackets().HasTag("DISALLOW");
    bool input_via_stream =
        cc->Inputs().HasTag("ALLOW") || cc->Inputs().HasTag("DISALLOW");
    // Only one of input_side_packet or input_stream may specify ALLOW/DISALLOW
    // input.
    RET_CHECK(input_via_side_packet ^ input_via_stream);

    if (input_via_side_packet) {
      // ALLOW and DISALLOW are mutually exclusive.
      RET_CHECK(cc->InputSidePackets().HasTag("ALLOW") ^
                cc->InputSidePackets().HasTag("DISALLOW"));
      if (cc->InputSidePackets().HasTag("ALLOW")) {
        cc->InputSidePackets().Tag("ALLOW").Set<bool>();
      } else {
        cc->InputSidePackets().Tag("DISALLOW").Set<bool>();
      }
    } else {
      // ALLOW and DISALLOW are mutually exclusive.
      RET_CHECK(cc->Inputs().HasTag("ALLOW") ^ cc->Inputs().HasTag("DISALLOW"));
      if (cc->Inputs().HasTag("ALLOW")) {
        cc->Inputs().Tag("ALLOW").Set<bool>();
      } else {
        cc->Inputs().Tag("DISALLOW").Set<bool>();
      }
    }
    return absl::OkStatus();
  }

  // Declares the calculator's streams: the control input (see above), at
  // least one untagged data input, an equal number of untagged data outputs
  // (each typed like its input), and an optional bool STATE_CHANGE output.
  static absl::Status GetContract(CalculatorContract* cc) {
    RET_CHECK_OK(CheckAndInitAllowDisallowInputs(cc));

    const int num_data_streams = cc->Inputs().NumEntries("");
    RET_CHECK_GE(num_data_streams, 1);
    RET_CHECK_EQ(cc->Outputs().NumEntries(""), num_data_streams)
        << "Number of data output streams must match with data input streams.";

    for (int i = 0; i < num_data_streams; ++i) {
      cc->Inputs().Get("", i).SetAny();
      cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i));
    }

    if (cc->Outputs().HasTag("STATE_CHANGE")) {
      cc->Outputs().Tag("STATE_CHANGE").Set<bool>();
    }

    return absl::OkStatus();
  }

  // Reads the side-packet decision (if the control input is a side packet),
  // resets the remembered gate state, copies input headers to the outputs and
  // caches the `empty_packets_as_allow` option.
  absl::Status Open(CalculatorContext* cc) final {
    use_side_packet_for_allow_disallow_ = false;
    if (cc->InputSidePackets().HasTag("ALLOW")) {
      use_side_packet_for_allow_disallow_ = true;
      allow_by_side_packet_decision_ =
          cc->InputSidePackets().Tag("ALLOW").Get<bool>();
    } else if (cc->InputSidePackets().HasTag("DISALLOW")) {
      use_side_packet_for_allow_disallow_ = true;
      // DISALLOW carries the inverse of the allow decision.
      allow_by_side_packet_decision_ =
          !cc->InputSidePackets().Tag("DISALLOW").Get<bool>();
    }

    cc->SetOffset(TimestampDiff(0));
    num_data_streams_ = cc->Inputs().NumEntries("");
    last_gate_state_ = GATE_UNINITIALIZED;
    RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &cc->Outputs()));

    const auto& options = cc->Options<::mediapipe::GateCalculatorOptions>();
    empty_packets_as_allow_ = options.empty_packets_as_allow();

    return absl::OkStatus();
  }

  // Evaluates the gate for the current input timestamp. When the gate is
  // open, every non-empty data packet is forwarded unchanged to the output
  // with the same index; when it is closed nothing is forwarded. If a
  // STATE_CHANGE output exists, a bool packet holding the new decision is
  // emitted whenever the decision differs from the previous call's.
  absl::Status Process(CalculatorContext* cc) final {
    // Fallback used when the control stream has no packet at this timestamp.
    bool allow = empty_packets_as_allow_;
    if (use_side_packet_for_allow_disallow_) {
      allow = allow_by_side_packet_decision_;
    } else {
      if (cc->Inputs().HasTag("ALLOW") &&
          !cc->Inputs().Tag("ALLOW").IsEmpty()) {
        allow = cc->Inputs().Tag("ALLOW").Get<bool>();
      }
      if (cc->Inputs().HasTag("DISALLOW") &&
          !cc->Inputs().Tag("DISALLOW").IsEmpty()) {
        allow = !cc->Inputs().Tag("DISALLOW").Get<bool>();
      }
    }

    const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW;

    if (cc->Outputs().HasTag("STATE_CHANGE")) {
      // The very first decision is not reported as a transition.
      if (last_gate_state_ != GATE_UNINITIALIZED &&
          last_gate_state_ != new_gate_state) {
        VLOG(2) << "State transition in " << cc->NodeName() << " @ "
                << cc->InputTimestamp().Value() << " from "
                << ToString(last_gate_state_) << " to "
                << ToString(new_gate_state);
        cc->Outputs()
            .Tag("STATE_CHANGE")
            .AddPacket(MakePacket<bool>(allow).At(cc->InputTimestamp()));
      }
    }
    last_gate_state_ = new_gate_state;

    if (!allow) {
      // Close the output streams if the gate will be permanently closed.
      // Prevents buffering in calculators whose parents do not use SetOffset.
      // Only a side-packet decision is treated as permanent; a stream-driven
      // gate may reopen later, so its outputs stay open.
      for (int i = 0; i < num_data_streams_; ++i) {
        if (!cc->Outputs().Get("", i).IsClosed() &&
            use_side_packet_for_allow_disallow_) {
          cc->Outputs().Get("", i).Close();
        }
      }
      return absl::OkStatus();
    }

    // Process data streams.
    for (int i = 0; i < num_data_streams_; ++i) {
      if (!cc->Inputs().Get("", i).IsEmpty()) {
        cc->Outputs().Get("", i).AddPacket(cc->Inputs().Get("", i).Value());
      }
    }

    return absl::OkStatus();
  }

 private:
  // Decision made by the previous Process() call.
  GateState last_gate_state_ = GATE_UNINITIALIZED;
  // Number of untagged data input (and output) streams; set in Open().
  int num_data_streams_ = 0;
  // Decision used when the control stream is empty; set from options in Open().
  bool empty_packets_as_allow_ = false;
  // True when ALLOW/DISALLOW arrived as an input side packet.
  bool use_side_packet_for_allow_disallow_ = false;
  // Allow decision derived from the side packet; meaningful only when
  // use_side_packet_for_allow_disallow_ is true.
  bool allow_by_side_packet_decision_ = false;
};
REGISTER_CALCULATOR(GateCalculator);
} // namespace mediapipe

View File

@ -1,32 +0,0 @@
// Copyright 2019-2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
option objc_class_prefix = "MediaPipe";
// Options for GateCalculator, attached to a node through the
// mediapipe.CalculatorOptions extension declared below.
message GateCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional GateCalculatorOptions ext = 261754847;
}
// By default an empty packet in the ALLOW or DISALLOW input stream indicates
// disallowing the corresponding packets in the data input streams. Setting
// this option to true inverts that, allowing the data packets to go through.
// No explicit default is declared, so an unset field reads as false (proto2).
optional bool empty_packets_as_allow = 1;
}

View File

@ -1,334 +0,0 @@
// Copyright 2019-2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
// Fixture for the GateCalculator tests. It owns one CalculatorRunner built
// from a text-format node config and provides helpers that append the packets
// for a single timestamp and then run the calculator.
class GateCalculatorTest : public ::testing::Test {
 protected:
  // Builds a temporary runner from `proto`, runs it, and returns the status.
  static absl::Status RunGraph(const std::string& proto) {
    const auto node_config =
        ParseTextProtoOrDie<CalculatorGraphConfig::Node>(proto);
    auto graph_runner = absl::make_unique<CalculatorRunner>(node_config);
    return graph_runner->Run();
  }

  // Side-packet variant: appends one bool data packet holding
  // `stream_payload` at `timestamp` to data stream 0, then runs.
  void RunTimeStep(int64 timestamp, bool stream_payload) {
    const Timestamp packet_time(timestamp);
    auto& data_packets = runner_->MutableInputs()->Get("", 0).packets;
    data_packets.push_back(MakePacket<bool>(stream_payload).At(packet_time));
    MP_ASSERT_OK(runner_->Run()) << "Calculator execution failed.";
  }

  // Control-stream variant: appends a `true` data packet to data stream 0 and
  // a `control` packet to the stream tagged `control_tag`, both at
  // `timestamp`, then runs.
  void RunTimeStep(int64 timestamp, const std::string& control_tag,
                   bool control) {
    const Timestamp packet_time(timestamp);
    auto& data_packets = runner_->MutableInputs()->Get("", 0).packets;
    data_packets.push_back(MakePacket<bool>(true).At(packet_time));
    auto& control_packets = runner_->MutableInputs()->Tag(control_tag).packets;
    control_packets.push_back(MakePacket<bool>(control).At(packet_time));
    MP_ASSERT_OK(runner_->Run()) << "Calculator execution failed.";
  }

  // Replaces the fixture's runner with one built from `proto`.
  void SetRunner(const std::string& proto) {
    const auto node_config =
        ParseTextProtoOrDie<CalculatorGraphConfig::Node>(proto);
    runner_ = absl::make_unique<CalculatorRunner>(node_config);
  }

  // Non-owning access to the current runner (null before SetRunner()).
  CalculatorRunner* runner() { return runner_.get(); }

 private:
  std::unique_ptr<CalculatorRunner> runner_;
};
// Every node config that supplies more than one ALLOW/DISALLOW control input
// must fail with an Internal error.
TEST_F(GateCalculatorTest, InvalidInputs) {
  // True when running the node described by `config` yields an Internal error.
  const auto fails_as_internal = [](const std::string& config) {
    return absl::IsInternal(GateCalculatorTest::RunGraph(config));
  };

  // ALLOW and DISALLOW both as input streams.
  EXPECT_TRUE(fails_as_internal(R"(
calculator: "GateCalculator"
input_stream: "test_input"
input_stream: "ALLOW:gating_stream"
input_stream: "DISALLOW:gating_stream"
output_stream: "test_output"
)"));

  // ALLOW and DISALLOW both as input side packets.
  EXPECT_TRUE(fails_as_internal(R"(
calculator: "GateCalculator"
input_stream: "test_input"
input_side_packet: "ALLOW:gating_stream"
input_side_packet: "DISALLOW:gating_stream"
output_stream: "test_output"
)"));

  // ALLOW as both stream and side packet.
  EXPECT_TRUE(fails_as_internal(R"(
calculator: "GateCalculator"
input_stream: "test_input"
input_stream: "ALLOW:gating_stream"
input_side_packet: "ALLOW:gating_stream"
output_stream: "test_output"
)"));

  // DISALLOW as both stream and side packet.
  EXPECT_TRUE(fails_as_internal(R"(
calculator: "GateCalculator"
input_stream: "test_input"
input_stream: "DISALLOW:gating_stream"
input_side_packet: "DISALLOW:gating_stream"
output_stream: "test_output"
)"));

  // ALLOW stream combined with DISALLOW side packet.
  EXPECT_TRUE(fails_as_internal(R"(
calculator: "GateCalculator"
input_stream: "test_input"
input_stream: "ALLOW:gating_stream"
input_side_packet: "DISALLOW:gating_stream"
output_stream: "test_output"
)"));

  // DISALLOW stream combined with ALLOW side packet.
  EXPECT_TRUE(fails_as_internal(R"(
calculator: "GateCalculator"
input_stream: "test_input"
input_stream: "DISALLOW:gating_stream"
input_side_packet: "ALLOW:gating_stream"
output_stream: "test_output"
)"));
}
// An ALLOW side packet set to true lets every data packet through unchanged.
TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) {
  SetRunner(R"(
calculator: "GateCalculator"
input_side_packet: "ALLOW:gating_stream"
input_stream: "test_input"
output_stream: "test_output"
)");
  runner()->MutableSidePackets()->Tag("ALLOW") = Adopt(new bool(true));

  constexpr int64 kFirstTimestamp = 42;
  constexpr int64 kSecondTimestamp = 43;
  RunTimeStep(kFirstTimestamp, true);
  RunTimeStep(kSecondTimestamp, false);

  const auto& forwarded = runner()->Outputs().Get("", 0).packets;
  ASSERT_EQ(2, forwarded.size());
  EXPECT_EQ(kFirstTimestamp, forwarded[0].Timestamp().Value());
  EXPECT_EQ(kSecondTimestamp, forwarded[1].Timestamp().Value());
  EXPECT_TRUE(forwarded[0].Get<bool>());
  EXPECT_FALSE(forwarded[1].Get<bool>());
}
// A DISALLOW side packet set to false lets every data packet through
// unchanged.
TEST_F(GateCalculatorTest, AllowByDisallowSidePacketSetToFalse) {
  SetRunner(R"(
calculator: "GateCalculator"
input_side_packet: "DISALLOW:gating_stream"
input_stream: "test_input"
output_stream: "test_output"
)");
  runner()->MutableSidePackets()->Tag("DISALLOW") = Adopt(new bool(false));

  constexpr int64 kFirstTimestamp = 42;
  constexpr int64 kSecondTimestamp = 43;
  RunTimeStep(kFirstTimestamp, true);
  RunTimeStep(kSecondTimestamp, false);

  const auto& forwarded = runner()->Outputs().Get("", 0).packets;
  ASSERT_EQ(2, forwarded.size());
  EXPECT_EQ(kFirstTimestamp, forwarded[0].Timestamp().Value());
  EXPECT_EQ(kSecondTimestamp, forwarded[1].Timestamp().Value());
  EXPECT_TRUE(forwarded[0].Get<bool>());
  EXPECT_FALSE(forwarded[1].Get<bool>());
}
// An ALLOW side packet set to false blocks every data packet.
TEST_F(GateCalculatorTest, DisallowByALLOWSidePacketSetToFalse) {
  SetRunner(R"(
calculator: "GateCalculator"
input_side_packet: "ALLOW:gating_stream"
input_stream: "test_input"
output_stream: "test_output"
)");
  runner()->MutableSidePackets()->Tag("ALLOW") = Adopt(new bool(false));

  constexpr int64 kFirstTimestamp = 42;
  constexpr int64 kSecondTimestamp = 43;
  RunTimeStep(kFirstTimestamp, true);
  RunTimeStep(kSecondTimestamp, false);

  const auto& forwarded = runner()->Outputs().Get("", 0).packets;
  ASSERT_EQ(0, forwarded.size());
}
// A DISALLOW side packet set to true blocks every data packet.
TEST_F(GateCalculatorTest, DisallowByDISALLOWSidePacketSetToTrue) {
  SetRunner(R"(
calculator: "GateCalculator"
input_side_packet: "DISALLOW:gating_stream"
input_stream: "test_input"
output_stream: "test_output"
)");
  runner()->MutableSidePackets()->Tag("DISALLOW") = Adopt(new bool(true));

  constexpr int64 kFirstTimestamp = 42;
  constexpr int64 kSecondTimestamp = 43;
  RunTimeStep(kFirstTimestamp, true);
  RunTimeStep(kSecondTimestamp, false);

  const auto& forwarded = runner()->Outputs().Get("", 0).packets;
  ASSERT_EQ(0, forwarded.size());
}
// With an ALLOW control stream, only the timestamps whose control packet is
// true appear on the data output.
TEST_F(GateCalculatorTest, Allow) {
  SetRunner(R"(
calculator: "GateCalculator"
input_stream: "test_input"
input_stream: "ALLOW:gating_stream"
output_stream: "test_output"
)");

  constexpr int64 kOpenFirst = 42;
  constexpr int64 kClosedFirst = 43;
  constexpr int64 kOpenSecond = 44;
  constexpr int64 kClosedSecond = 45;
  RunTimeStep(kOpenFirst, "ALLOW", true);
  RunTimeStep(kClosedFirst, "ALLOW", false);
  RunTimeStep(kOpenSecond, "ALLOW", true);
  RunTimeStep(kClosedSecond, "ALLOW", false);

  const auto& forwarded = runner()->Outputs().Get("", 0).packets;
  ASSERT_EQ(2, forwarded.size());
  EXPECT_EQ(kOpenFirst, forwarded[0].Timestamp().Value());
  EXPECT_EQ(kOpenSecond, forwarded[1].Timestamp().Value());
  EXPECT_TRUE(forwarded[0].Get<bool>());
  EXPECT_TRUE(forwarded[1].Get<bool>());
}
// With a DISALLOW control stream, only the timestamps whose control packet is
// false appear on the data output.
TEST_F(GateCalculatorTest, Disallow) {
  SetRunner(R"(
calculator: "GateCalculator"
input_stream: "test_input"
input_stream: "DISALLOW:gating_stream"
output_stream: "test_output"
)");

  constexpr int64 kClosedFirst = 42;
  constexpr int64 kOpenFirst = 43;
  constexpr int64 kClosedSecond = 44;
  constexpr int64 kOpenSecond = 45;
  RunTimeStep(kClosedFirst, "DISALLOW", true);
  RunTimeStep(kOpenFirst, "DISALLOW", false);
  RunTimeStep(kClosedSecond, "DISALLOW", true);
  RunTimeStep(kOpenSecond, "DISALLOW", false);

  const auto& forwarded = runner()->Outputs().Get("", 0).packets;
  ASSERT_EQ(2, forwarded.size());
  EXPECT_EQ(kOpenFirst, forwarded[0].Timestamp().Value());
  EXPECT_EQ(kOpenSecond, forwarded[1].Timestamp().Value());
  EXPECT_TRUE(forwarded[0].Get<bool>());
  EXPECT_TRUE(forwarded[1].Get<bool>());
}
// STATE_CHANGE emits one packet per change of the ALLOW decision
// (false -> true at 43, true -> false at 45) and nothing for repeats.
TEST_F(GateCalculatorTest, AllowWithStateChange) {
  SetRunner(R"(
calculator: "GateCalculator"
input_stream: "test_input"
input_stream: "ALLOW:gating_stream"
output_stream: "test_output"
output_stream: "STATE_CHANGE:state_changed"
)");

  constexpr int64 kInitiallyClosed = 42;
  constexpr int64 kOpens = 43;
  constexpr int64 kStaysOpen = 44;
  constexpr int64 kCloses = 45;
  RunTimeStep(kInitiallyClosed, "ALLOW", false);
  RunTimeStep(kOpens, "ALLOW", true);
  RunTimeStep(kStaysOpen, "ALLOW", true);
  RunTimeStep(kCloses, "ALLOW", false);

  const auto& transitions =
      runner()->Outputs().Get("STATE_CHANGE", 0).packets;
  ASSERT_EQ(2, transitions.size());
  EXPECT_EQ(kOpens, transitions[0].Timestamp().Value());
  EXPECT_EQ(kCloses, transitions[1].Timestamp().Value());
  EXPECT_TRUE(transitions[0].Get<bool>());   // Allow.
  EXPECT_FALSE(transitions[1].Get<bool>());  // Disallow.
}
// STATE_CHANGE emits one packet per change of the DISALLOW-driven decision
// (gate opens at 43, closes at 45) and nothing for repeats.
TEST_F(GateCalculatorTest, DisallowWithStateChange) {
  SetRunner(R"(
calculator: "GateCalculator"
input_stream: "test_input"
input_stream: "DISALLOW:gating_stream"
output_stream: "test_output"
output_stream: "STATE_CHANGE:state_changed"
)");

  constexpr int64 kInitiallyClosed = 42;
  constexpr int64 kOpens = 43;
  constexpr int64 kStaysOpen = 44;
  constexpr int64 kCloses = 45;
  RunTimeStep(kInitiallyClosed, "DISALLOW", true);
  RunTimeStep(kOpens, "DISALLOW", false);
  RunTimeStep(kStaysOpen, "DISALLOW", false);
  RunTimeStep(kCloses, "DISALLOW", true);

  const auto& transitions =
      runner()->Outputs().Get("STATE_CHANGE", 0).packets;
  ASSERT_EQ(2, transitions.size());
  EXPECT_EQ(kOpens, transitions[0].Timestamp().Value());
  EXPECT_EQ(kCloses, transitions[1].Timestamp().Value());
  EXPECT_TRUE(transitions[0].Get<bool>());   // Allow.
  EXPECT_FALSE(transitions[1].Get<bool>());  // Disallow.
}
// The first DISALLOW value seen must not be reported as a state change.
TEST_F(GateCalculatorTest, DisallowInitialNoStateTransition) {
  SetRunner(R"(
calculator: "GateCalculator"
input_stream: "test_input"
input_stream: "DISALLOW:gating_stream"
output_stream: "test_output"
output_stream: "STATE_CHANGE:state_changed"
)");

  constexpr int64 kOnlyTimestamp = 42;
  RunTimeStep(kOnlyTimestamp, "DISALLOW", false);

  const auto& transitions =
      runner()->Outputs().Get("STATE_CHANGE", 0).packets;
  ASSERT_EQ(0, transitions.size());
}
// The first ALLOW value seen must not be reported as a state change.
TEST_F(GateCalculatorTest, AllowInitialNoStateTransition) {
  SetRunner(R"(
calculator: "GateCalculator"
input_stream: "test_input"
input_stream: "ALLOW:gating_stream"
output_stream: "test_output"
output_stream: "STATE_CHANGE:state_changed"
)");

  constexpr int64 kOnlyTimestamp = 42;
  RunTimeStep(kOnlyTimestamp, "ALLOW", true);

  const auto& transitions =
      runner()->Outputs().Get("STATE_CHANGE", 0).packets;
  ASSERT_EQ(0, transitions.size());
}
} // namespace
} // namespace mediapipe

View File

@ -1,94 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// This Calculator multiplexes several input streams into a single
// output stream, dropping input packets with timestamps older than the
// last output packet. In case two packets arrive with the same timestamp,
// the packet with the lower stream index will be output and the rest will
// be dropped.
//
// This Calculator optionally produces a finish indicator as its second
// output stream. One indicator packet is produced for each input packet
// received.
//
// This Calculator can be used with an ImmediateInputStreamHandler or with the
// default ISH.
//
// This Calculator is designed to work with a Demux calculator such as
// the RoundRobinDemuxCalculator. Therefore, packets from different
// input streams are normally not expected to have the same timestamp.
//
// NOTE: this calculator can drop packets non-deterministically, depending on
// how fast the input streams are fed. In most cases, MuxCalculator should be
// preferred. In particular, dropping packets can interfere with rate limiting
// mechanisms.
class ImmediateMuxCalculator : public CalculatorBase {
 public:
  // Accepts any number of input streams and one or two output streams. Every
  // input stream must carry the same packet type as the first output stream.
  static absl::Status GetContract(CalculatorContract* cc);

  // Sets the timestamp offset to zero.
  absl::Status Open(CalculatorContext* cc) override;

  // Forwards each arriving input packet to the output right away, except for
  // packets whose timestamp lies below the output stream's next timestamp
  // bound; those are dropped.
  absl::Status Process(CalculatorContext* cc) override;
};
REGISTER_CALCULATOR(ImmediateMuxCalculator);
// Declares the stream types: output 0 accepts any type, the optional output 1
// carries bool, and every input is typed like output 0. Fails unless there
// are one or two outputs.
absl::Status ImmediateMuxCalculator::GetContract(CalculatorContract* cc) {
  const int output_count = cc->Outputs().NumEntries();
  RET_CHECK(output_count >= 1 && output_count <= 2)
      << "This calculator produces only one or two output streams.";

  cc->Outputs().Index(0).SetAny();
  const bool has_finish_output = output_count >= 2;
  if (has_finish_output) {
    cc->Outputs().Index(1).Set<bool>();
  }

  const int input_count = cc->Inputs().NumEntries();
  for (int index = 0; index < input_count; ++index) {
    cc->Inputs().Index(index).SetSameAs(&cc->Outputs().Index(0));
  }
  return absl::OkStatus();
}
// Sets a timestamp offset of zero on the calculator context.
absl::Status ImmediateMuxCalculator::Open(CalculatorContext* cc) {
  const TimestampDiff zero_offset(0);
  cc->SetOffset(zero_offset);
  return absl::OkStatus();
}
// Visits the input streams in index order. Each non-empty packet is either
// forwarded to output 0 or, if its timestamp is below that output's next
// timestamp bound, dropped with a rate-limited warning. When a second output
// exists, one `true` packet is added to it per non-empty input packet,
// whether that packet was forwarded or dropped.
absl::Status ImmediateMuxCalculator::Process(CalculatorContext* cc) {
  const int input_count = cc->Inputs().NumEntries();
  for (int index = 0; index < input_count; ++index) {
    const Packet& candidate = cc->Inputs().Index(index).Value();
    if (candidate.IsEmpty()) {
      continue;
    }

    // Pass along the packet, unless it has been superseded.
    if (candidate.Timestamp() < cc->Outputs().Index(0).NextTimestampBound()) {
      LOG_FIRST_N(WARNING, 5)
          << "Dropping a packet with timestamp " << candidate.Timestamp();
    } else {
      cc->Outputs().Index(0).AddPacket(candidate);
    }

    if (cc->Outputs().NumEntries() >= 2) {
      // The indicator timestamp must not fall below the indicator stream's
      // own next timestamp bound.
      const Timestamp indicator_timestamp = std::max(
          cc->InputTimestamp(), cc->Outputs().Index(1).NextTimestampBound());
      cc->Outputs().Index(1).Add(new bool(true), indicator_timestamp);
    }
  }
  return absl::OkStatus();
}
} // namespace mediapipe

View File

@ -1,373 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdint.h>
#include <atomic>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/port/threadpool.h"
#include "mediapipe/framework/tool/sink.h"
// Tests for ImmediateMuxCalculator. These tests show how parallel output
// packets are handled when they arrive in various orders.
using testing::ElementsAre;
namespace mediapipe {
namespace {
// A simple spinning semaphore for synchronizing test threads.
//
// The available supply is a single atomic counter. Acquire() busy-waits (no
// sleeping or yielding) until it manages to subtract `amount` without the
// counter going negative.
class AtomicSemaphore {
 public:
  // Creates a semaphore holding `supply` units. Explicit so that an integer
  // is never silently converted into a semaphore.
  explicit AtomicSemaphore(int64_t supply) : supply_(supply) {}

  // Takes `amount` units, spinning until enough supply is available.
  void Acquire(int64_t amount) {
    // Optimistically subtract; if that overdrew the counter, give the units
    // back and retry.
    while (supply_.fetch_sub(amount) - amount < 0) {
      Release(amount);
    }
  }

  // Returns `amount` units to the supply.
  void Release(int64_t amount) { supply_.fetch_add(amount); }

 private:
  std::atomic<int64_t> supply_;
};
// A mediapipe::Executor that signals the start and finish of each task.
// Provides 4 worker threads.
//
// `start_callback` runs on the thread that calls Schedule(), before the task
// is queued; `finish_callback` runs on the worker thread right after the task
// returns.
class CountingExecutor : public Executor {
 public:
  CountingExecutor(std::function<void()> start_callback,
                   std::function<void()> finish_callback)
      : start_callback_(std::move(start_callback)),
        finish_callback_(std::move(finish_callback)),
        thread_pool_(4) {
    thread_pool_.StartWorkers();
  }

  // Signals the start callback, then queues `task` followed by the finish
  // callback on the thread pool.
  void Schedule(std::function<void()> task) override {
    start_callback_();
    // Move the task into the closure instead of copying it a second time.
    thread_pool_.Schedule([this, task = std::move(task)] {
      task();
      finish_callback_();
    });
  }

 private:
  std::function<void()> start_callback_;
  std::function<void()> finish_callback_;
  // Declared last on purpose: members are destroyed in reverse declaration
  // order, so the pool goes away before the callbacks its queued closures
  // still call through `this`.
  // NOTE(review): this relies on ThreadPool's destructor waiting for
  // outstanding tasks -- confirm against threadpool.h.
  ThreadPool thread_pool_;
};
// Returns a new mediapipe::Executor with 4 worker threads.
std::shared_ptr<Executor> MakeExecutor(std::function<void()> start_callback,
std::function<void()> finish_callback) {
return std::make_shared<CountingExecutor>(start_callback, finish_callback);
}
// Tests showing ImmediateMuxCalculator dropping packets in various sequences.
class ImmediateMuxCalculatorTest : public ::testing::Test {
protected:
void SetUpMuxGraph() {
ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"(
input_stream: "input_packets_0"
input_stream: "input_packets_1"
node {
calculator: "ImmediateMuxCalculator"
input_stream_handler {
input_stream_handler: "ImmediateInputStreamHandler"
}
input_stream: "input_packets_0"
input_stream: "input_packets_1"
output_stream: "output_packets_0"
}
)",
&graph_config_));
}
void SetUpDemuxGraph() {
ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"(
input_stream: "input_packets_0"
node {
calculator: "RoundRobinDemuxCalculator"
input_stream: "input_packets_0"
output_stream: "OUTPUT:0:input_0"
output_stream: "OUTPUT:1:input_1"
}
node {
calculator: "LambdaCalculator"
input_side_packet: 'callback_0'
input_stream: "input_0"
output_stream: "output_0"
}
node {
calculator: "LambdaCalculator"
input_side_packet: 'callback_1'
input_stream: "input_1"
output_stream: "output_1"
}
node {
calculator: "ImmediateMuxCalculator"
input_stream_handler {
input_stream_handler: "ImmediateInputStreamHandler"
}
input_stream: "output_0"
input_stream: "output_1"
output_stream: "output_packets_0"
}
)",
&graph_config_));
}
void SetUpDemuxInFlightGraph() {
ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"(
input_stream: "input_packets_0"
node {
calculator: 'FlowLimiterCalculator'
input_stream_handler {
input_stream_handler: 'ImmediateInputStreamHandler'
}
input_side_packet: 'MAX_IN_FLIGHT:max_in_flight'
input_stream: 'input_packets_0'
input_stream: 'FINISHED:finish_indicator'
input_stream_info: {
tag_index: 'FINISHED'
back_edge: true
}
output_stream: 'input_0_sampled'
}
node {
calculator: "RoundRobinDemuxCalculator"
input_stream: "input_0_sampled"
output_stream: "OUTPUT:0:input_0"
output_stream: "OUTPUT:1:input_1"
}
node {
calculator: "LambdaCalculator"
input_side_packet: 'callback_0'
input_stream: "input_0"
output_stream: "output_0"
}
node {
calculator: "LambdaCalculator"
input_side_packet: 'callback_1'
input_stream: "input_1"
output_stream: "output_1"
}
node {
calculator: "ImmediateMuxCalculator"
input_stream_handler {
input_stream_handler: "ImmediateInputStreamHandler"
}
input_stream: "output_0"
input_stream: "output_1"
output_stream: 'output_packets_0'
output_stream: 'finish_indicator'
}
)",
&graph_config_));
}
static Packet PacketAt(int64 ts) {
return Adopt(new int64(999)).At(Timestamp(ts));
}
static Packet None() { return Packet().At(Timestamp::OneOverPostStream()); }
static bool IsNone(const Packet& packet) {
return packet.Timestamp() == Timestamp::OneOverPostStream();
}
// Return the values of the timestamps of a vector of Packets.
static std::vector<int64> TimestampValues(
const std::vector<Packet>& packets) {
std::vector<int64> result;
for (const Packet& p : packets) {
result.push_back(p.Timestamp().Value());
}
return result;
}
// Runs a CalculatorGraph with a series of packet sets.
// Returns a vector of packets from each graph output stream.
void RunGraph(const std::vector<std::vector<Packet>>& input_sets,
std::vector<Packet>* output_packets) {
// Register output packet observers.
tool::AddVectorSink("output_packets_0", &graph_config_, output_packets);
// Start running the graph.
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config_));
MP_ASSERT_OK(graph.StartRun({}));
// Send each packet to the graph in the specified order.
for (int t = 0; t < input_sets.size(); t++) {
const std::vector<Packet>& input_set = input_sets[t];
MP_EXPECT_OK(graph.WaitUntilIdle());
for (int i = 0; i < input_set.size(); i++) {
const Packet& packet = input_set[i];
if (!IsNone(packet)) {
MP_EXPECT_OK(graph.AddPacketToInputStream(
absl::StrCat("input_packets_", i), packet));
}
}
}
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
CalculatorGraphConfig graph_config_;
};
// Packets that arrive in strictly increasing timestamp order are all
// forwarded, whichever input stream they come from.
TEST_F(ImmediateMuxCalculatorTest, IncreasingTimestamps) {
  SetUpMuxGraph();

  // One row per step; column i is the packet for input stream i.
  const std::vector<std::vector<Packet>> packet_schedule = {
      {PacketAt(10000), None()},  //
      {PacketAt(20000), None()},  //
      {None(), PacketAt(30000)},  //
      {None(), PacketAt(40000)},
  };
  std::vector<Packet> received;
  RunGraph(packet_schedule, &received);

  // Validate the output packets.
  EXPECT_THAT(TimestampValues(received),
              ElementsAre(10000, 20000, 30000, 40000));
}
// A packet that arrives after a later-timestamped packet has already been
// output is dropped.
TEST_F(ImmediateMuxCalculatorTest, SupersededTimestamp) {
  SetUpMuxGraph();

  // One row per step; column i is the packet for input stream i.
  const std::vector<std::vector<Packet>> packet_schedule = {
      {PacketAt(10000), None()},  //
      {PacketAt(30000), None()},  //
      {None(), PacketAt(20000)},  //
      {None(), PacketAt(40000)},
  };
  std::vector<Packet> received;
  RunGraph(packet_schedule, &received);

  // Output packet 20000 is superseded and dropped.
  EXPECT_THAT(TimestampValues(received), ElementsAre(10000, 30000, 40000));
}
// Packets sent to both inputs in the same step: this test expects only
// timestamps 10000 and 40000 on the output.
TEST_F(ImmediateMuxCalculatorTest, SimultaneousTimestamps) {
  SetUpMuxGraph();

  // One row per step; column i is the packet for input stream i.
  const std::vector<std::vector<Packet>> packet_schedule = {
      {PacketAt(10000), None()},           //
      {PacketAt(40000), PacketAt(20000)},  //
      {None(), PacketAt(30000)},
  };
  std::vector<Packet> received;
  RunGraph(packet_schedule, &received);

  // Output packet 20000 is superseded and dropped.
  EXPECT_THAT(TimestampValues(received), ElementsAre(10000, 40000));
}
// Signature of a Calculator::Process-style callback: reads from the input
// shard set, writes to the output shard set, and returns a status.
using ProcessFunction =
    std::function<absl::Status(const InputStreamShardSet&,
                               OutputStreamShardSet*)>;
// Testing callback that copies every non-empty input packet to the output
// stream with the same index and always returns OK.
absl::Status PassThrough(const InputStreamShardSet& inputs,
                         OutputStreamShardSet* outputs) {
  const int stream_count = inputs.NumEntries();
  for (int index = 0; index < stream_count; ++index) {
    const Packet& packet = inputs.Index(index).Value();
    if (packet.IsEmpty()) {
      continue;
    }
    outputs->Index(index).AddPacket(packet);
  }
  return absl::OkStatus();
}
// Runs the demux graph with two pass-through branches that each block on
// their own semaphore, then releases the semaphores in a chosen order so the
// branch outputs reach ImmediateMuxCalculator out of timestamp order. Packets
// that arrive after a later timestamp has been output are expected to be
// dropped.
TEST_F(ImmediateMuxCalculatorTest, Demux) {
// Semaphores to sequence the parallel Process outputs.
AtomicSemaphore semaphore_0(0);
AtomicSemaphore semaphore_1(0);
// Each branch callback waits for one unit of its semaphore before forwarding
// its packets, so the test thread decides when each branch may proceed.
ProcessFunction wait_0 = [&semaphore_0](const InputStreamShardSet& inputs,
OutputStreamShardSet* outputs) {
semaphore_0.Acquire(1);
return PassThrough(inputs, outputs);
};
ProcessFunction wait_1 = [&semaphore_1](const InputStreamShardSet& inputs,
OutputStreamShardSet* outputs) {
semaphore_1.Acquire(1);
return PassThrough(inputs, outputs);
};
// A callback to await and capture output packets.
// out_mutex guards out_packets: the observer appends under the lock.
std::vector<Packet> out_packets;
absl::Mutex out_mutex;
auto out_cb = [&](const Packet& p) {
absl::MutexLock lock(&out_mutex);
out_packets.push_back(p);
return absl::OkStatus();
};
// Blocks the test thread until `cond` holds, waiting on out_mutex.
auto wait_for = [&](std::function<bool()> cond) {
absl::MutexLock lock(&out_mutex);
out_mutex.Await(absl::Condition(&cond));
};
SetUpDemuxGraph();
// Start the graph and add five input packets.
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config_,
{
{"callback_0", Adopt(new auto(wait_0))},
{"callback_1", Adopt(new auto(wait_1))},
}));
MP_ASSERT_OK(graph.ObserveOutputStream("output_packets_0", out_cb));
MP_ASSERT_OK(graph.StartRun({}));
MP_EXPECT_OK(
graph.AddPacketToInputStream("input_packets_0", PacketAt(10000)));
MP_EXPECT_OK(
graph.AddPacketToInputStream("input_packets_0", PacketAt(20000)));
MP_EXPECT_OK(
graph.AddPacketToInputStream("input_packets_0", PacketAt(30000)));
MP_EXPECT_OK(
graph.AddPacketToInputStream("input_packets_0", PacketAt(40000)));
MP_EXPECT_OK(
graph.AddPacketToInputStream("input_packets_0", PacketAt(50000)));
// Release the outputs in order 20000, 10000, 30000, 50000, 40000.
// Per the timestamps noted on each line, branch 0 (semaphore_0) handles
// 10000/30000/50000 and branch 1 (semaphore_1) handles 20000/40000. The
// wait_for() calls hold back the next release until the expected number of
// output packets has been observed.
semaphore_1.Release(1); // 20000
wait_for([&] { return !out_packets.empty(); });
semaphore_0.Release(1); // 10000
semaphore_0.Release(1); // 30000
wait_for([&] { return out_packets.size() >= 2; });
semaphore_0.Release(1); // 50000
wait_for([&] { return out_packets.size() >= 3; });
semaphore_1.Release(1); // 40000
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
// Output packets 10000 and 40000 are superseded and dropped.
EXPECT_THAT(TimestampValues(out_packets), ElementsAre(20000, 30000, 50000));
}
} // namespace
} // namespace mediapipe

View File

@ -1,58 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <utility>
#include <vector>
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace api2 {
// Given two input streams (A, B), output a single stream containing a pair<A,
// B>.
//
// Example config:
// node {
// calculator: "MakePairCalculator"
// input_stream: "packet_a"
// input_stream: "packet_b"
// output_stream: "output_pair_a_b"
// }
// Node that bundles the packets of its two untagged input streams into one
// std::pair and emits that pair on its single untagged output stream.
class MakePairCalculator : public Node {
public:
// Untagged input streams of any type; UpdateContract below requires exactly
// two of them.
static constexpr Input<AnyType>::Multiple kIn{""};
// Note that currently api2::Packet is a different type from mediapipe::Packet
// Untagged output stream whose payload is the pair of input packets.
static constexpr Output<std::pair<mediapipe::Packet, mediapipe::Packet>>
kPair{""};
MEDIAPIPE_NODE_CONTRACT(kIn, kPair);
// Rejects any node configuration whose number of untagged inputs is not 2.
static absl::Status UpdateContract(CalculatorContract* cc) {
RET_CHECK_EQ(kIn(cc).Count(), 2);
return absl::OkStatus();
}
// Sends {packet on input 0, packet on input 1} and returns OK. No emptiness
// check is made on either input here.
// NOTE(review): presumably the input stream handler decides whether one side
// can be empty at a given timestamp - confirm.
absl::Status Process(CalculatorContext* cc) override {
kPair(cc).Send({kIn(cc)[0].packet(), kIn(cc)[1].packet()});
return absl::OkStatus();
}
};
MEDIAPIPE_REGISTER_NODE(MakePairCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -1,51 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "Eigen/Core"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace api2 {
// Performs a left matrix multiply, i.e. output = A * input, where A is the
// matrix provided as an input side packet.
//
// Example config:
// node {
// calculator: "MatrixMultiplyCalculator"
// input_stream: "samples"
// output_stream: "multiplied_samples"
// input_side_packet: "multiplication_matrix"
// }
// Node declaration: one Matrix input stream, one Matrix output stream and one
// Matrix input side packet, all untagged. Process is defined out of line
// below the registration.
class MatrixMultiplyCalculator : public Node {
public:
// Input stream of matrices to be multiplied (the right-hand operand).
static constexpr Input<Matrix> kIn{""};
// Output stream carrying the product.
static constexpr Output<Matrix> kOut{""};
// Side packet holding the left-hand operand of the multiplication.
static constexpr SideInput<Matrix> kSide{""};
MEDIAPIPE_NODE_CONTRACT(kIn, kOut, kSide);
// Emits (side matrix * input matrix) for the current input packet.
absl::Status Process(CalculatorContext* cc) override;
};
MEDIAPIPE_REGISTER_NODE(MatrixMultiplyCalculator);
// Multiplies the side-packet matrix (left operand) by the matrix on the input
// stream (right operand) and sends the product on the output stream.
absl::Status MatrixMultiplyCalculator::Process(CalculatorContext* cc) {
  const Matrix& multiplier = *kSide(cc);
  const Matrix& samples = *kIn(cc);
  kOut(cc).Send(multiplier * samples);
  return absl::OkStatus();
}
} // namespace api2
} // namespace mediapipe

View File

@ -1,239 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <vector>
#include "Eigen/Core"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/validate_type.h"
namespace mediapipe {
namespace {
// Text-format description of a 3x4 matrix of random integers in [0,1000),
// used as the multiplication matrix (side packet) in the test below.
// The 12 packed_data entries are consistent with column-major order: with
// that reading, the first expected output 12499 equals
// 387*7 + 825*9 + 419*5 + 30*9 (first matrix row times first sample column).
const char kMatrixText[] =
"rows: 3\n"
"cols: 4\n"
"packed_data: 387\n"
"packed_data: 940\n"
"packed_data: 815\n"
"packed_data: 825\n"
"packed_data: 997\n"
"packed_data: 884\n"
"packed_data: 419\n"
"packed_data: 763\n"
"packed_data: 123\n"
"packed_data: 30\n"
"packed_data: 825\n"
"packed_data: 299\n";
// Text-format description of a 4x20 matrix of random integers in [0,10).
// Each column of this matrix is a sample; the test sends every column as its
// own input packet. The 80 packed_data entries list the columns one after
// another (four values per sample).
const char kSamplesText[] =
"rows: 4\n"
"cols: 20\n"
"packed_data: 7\n"
"packed_data: 9\n"
"packed_data: 5\n"
"packed_data: 9\n"
"packed_data: 6\n"
"packed_data: 3\n"
"packed_data: 0\n"
"packed_data: 7\n"
"packed_data: 1\n"
"packed_data: 3\n"
"packed_data: 3\n"
"packed_data: 2\n"
"packed_data: 4\n"
"packed_data: 5\n"
"packed_data: 0\n"
"packed_data: 4\n"
"packed_data: 6\n"
"packed_data: 0\n"
"packed_data: 1\n"
"packed_data: 2\n"
"packed_data: 0\n"
"packed_data: 2\n"
"packed_data: 0\n"
"packed_data: 3\n"
"packed_data: 1\n"
"packed_data: 7\n"
"packed_data: 4\n"
"packed_data: 9\n"
"packed_data: 8\n"
"packed_data: 8\n"
"packed_data: 6\n"
"packed_data: 4\n"
"packed_data: 6\n"
"packed_data: 8\n"
"packed_data: 1\n"
"packed_data: 9\n"
"packed_data: 7\n"
"packed_data: 5\n"
"packed_data: 3\n"
"packed_data: 5\n"
"packed_data: 3\n"
"packed_data: 5\n"
"packed_data: 7\n"
"packed_data: 7\n"
"packed_data: 3\n"
"packed_data: 3\n"
"packed_data: 6\n"
"packed_data: 4\n"
"packed_data: 7\n"
"packed_data: 7\n"
"packed_data: 2\n"
"packed_data: 5\n"
"packed_data: 4\n"
"packed_data: 8\n"
"packed_data: 1\n"
"packed_data: 0\n"
"packed_data: 2\n"
"packed_data: 0\n"
"packed_data: 3\n"
"packed_data: 4\n"
"packed_data: 6\n"
"packed_data: 6\n"
"packed_data: 8\n"
"packed_data: 5\n"
"packed_data: 5\n"
"packed_data: 8\n"
"packed_data: 9\n"
"packed_data: 7\n"
"packed_data: 3\n"
"packed_data: 7\n"
"packed_data: 2\n"
"packed_data: 7\n"
"packed_data: 8\n"
"packed_data: 2\n"
"packed_data: 1\n"
"packed_data: 1\n"
"packed_data: 4\n"
"packed_data: 1\n"
"packed_data: 1\n"
"packed_data: 7\n";
// Text-format description of the 3x20 matrix of expected results, i.e. the
// product of the kMatrixText matrix and the kSamplesText matrix. According to
// the original note the values were precomputed using R.
// Each column of this matrix is the expected output for the sample in the
// same column of kSamplesText (three values per output).
const char kExpectedText[] =
"rows: 3\n"
"cols: 20\n"
"packed_data: 12499\n"
"packed_data: 26793\n"
"packed_data: 16967\n"
"packed_data: 5007\n"
"packed_data: 14406\n"
"packed_data: 9635\n"
"packed_data: 4179\n"
"packed_data: 7870\n"
"packed_data: 4434\n"
"packed_data: 5793\n"
"packed_data: 12045\n"
"packed_data: 8876\n"
"packed_data: 2801\n"
"packed_data: 8053\n"
"packed_data: 5611\n"
"packed_data: 1740\n"
"packed_data: 4469\n"
"packed_data: 2665\n"
"packed_data: 8108\n"
"packed_data: 18396\n"
"packed_data: 10186\n"
"packed_data: 12330\n"
"packed_data: 23374\n"
"packed_data: 15526\n"
"packed_data: 9611\n"
"packed_data: 21804\n"
"packed_data: 14776\n"
"packed_data: 8241\n"
"packed_data: 17979\n"
"packed_data: 11989\n"
"packed_data: 8429\n"
"packed_data: 18921\n"
"packed_data: 9819\n"
"packed_data: 6270\n"
"packed_data: 13689\n"
"packed_data: 7031\n"
"packed_data: 9472\n"
"packed_data: 19210\n"
"packed_data: 13634\n"
"packed_data: 8567\n"
"packed_data: 12499\n"
"packed_data: 10455\n"
"packed_data: 2151\n"
"packed_data: 7469\n"
"packed_data: 3195\n"
"packed_data: 10774\n"
"packed_data: 21851\n"
"packed_data: 12673\n"
"packed_data: 12516\n"
"packed_data: 25318\n"
"packed_data: 14347\n"
"packed_data: 7984\n"
"packed_data: 17100\n"
"packed_data: 10972\n"
"packed_data: 5195\n"
"packed_data: 11102\n"
"packed_data: 8710\n"
"packed_data: 3002\n"
"packed_data: 11295\n"
"packed_data: 6360\n";
// Sends every column of kSamplesText through MatrixMultiplyCalculator as its
// own packet (timestamp == column index) and checks each output against the
// matching column of kExpectedText.
TEST(MatrixMultiplyCalculatorTest, Multiply) {
  CalculatorRunner runner("MatrixMultiplyCalculator", "", 1, 1, 1);
  // The multiplication matrix goes in as the only input side packet; Adopt
  // takes ownership of the heap-allocated matrix.
  Matrix* matrix = new Matrix();
  MatrixFromTextProto(kMatrixText, matrix);
  runner.MutableSidePackets()->Index(0) = Adopt(matrix);

  Matrix samples;
  MatrixFromTextProto(kSamplesText, &samples);
  Matrix expected;
  MatrixFromTextProto(kExpectedText, &expected);
  // A fixture mismatch should fail this test rather than abort the whole test
  // binary, hence ASSERT_EQ instead of CHECK_EQ.
  ASSERT_EQ(samples.cols(), expected.cols());

  for (int i = 0; i < samples.cols(); ++i) {
    // Take a column from samples and produce a packet with just that
    // column in it as an input sample for the calculator.
    Eigen::MatrixXf* sample = new Eigen::MatrixXf(samples.col(i));
    runner.MutableInputs()->Index(0).packets.push_back(
        Adopt(sample).At(Timestamp(i)));
  }
  MP_ASSERT_OK(runner.Run());

  // Must be fatal: the loop below indexes `expected` by output position, so
  // extra output packets would read past its last column.
  ASSERT_EQ(runner.MutableInputs()->Index(0).packets.size(),
            runner.Outputs().Index(0).packets.size());

  int i = 0;
  for (const Packet& output : runner.Outputs().Index(0).packets) {
    EXPECT_EQ(Timestamp(i), output.Timestamp());
    const Eigen::MatrixXf& result = output.Get<Matrix>();
    // The shape is verified before subtracting so that a wrong-sized output
    // is reported as a test failure.
    ASSERT_EQ(3, result.rows());
    ASSERT_EQ(1, result.cols());
    EXPECT_NEAR((expected.col(i) - result).cwiseAbs().sum(), 0.0, 1e-5);
    ++i;
  }
  EXPECT_EQ(samples.cols(), i);
}
} // namespace
} // namespace mediapipe

View File

@ -1,81 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "Eigen/Core"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace api2 {
// Subtracts the input stream matrix from the side packet matrix, or vice
// versa. The two matrices must have the same dimensions.
// The tags (MINUEND vs SUBTRAHEND) decide which role the input stream and the
// input side packet play. The output matrix is computed as
// minuend matrix - subtrahend matrix.
//
// Example config:
// node {
// calculator: "MatrixSubtractCalculator"
// input_stream: "MINUEND:input_matrix"
// input_side_packet: "SUBTRAHEND:side_matrix"
// output_stream: "output_matrix"
// }
//
// or
//
// node {
// calculator: "MatrixSubtractCalculator"
// input_stream: "SUBTRAHEND:input_matrix"
// input_side_packet: "MINUEND:side_matrix"
// output_stream: "output_matrix"
// }
// Node declaration for matrix subtraction. Both operands are declared as
// SideFallback inputs; UpdateContract (defined below the registration) requires
// that exactly one of them is connected as a stream.
class MatrixSubtractCalculator : public Node {
public:
// Left operand of the subtraction, tagged MINUEND.
static constexpr Input<Matrix>::SideFallback kMinuend{"MINUEND"};
// Right operand of the subtraction, tagged SUBTRAHEND.
static constexpr Input<Matrix>::SideFallback kSubtrahend{"SUBTRAHEND"};
// Untagged output stream carrying minuend - subtrahend.
static constexpr Output<Matrix> kOut{""};
MEDIAPIPE_NODE_CONTRACT(kMinuend, kSubtrahend, kOut);
// Validates the stream / side-packet split of the two operands.
static absl::Status UpdateContract(CalculatorContract* cc);
// Emits the difference, or an InvalidArgument error on a shape mismatch.
absl::Status Process(CalculatorContext* cc) override;
};
MEDIAPIPE_REGISTER_NODE(MatrixSubtractCalculator);
// static
// Contract check: succeeds only when exactly one of MINUEND / SUBTRAHEND is
// connected as an input stream (the XOR below); otherwise returns the
// RET_CHECK failure carrying the message streamed into it.
absl::Status MatrixSubtractCalculator::UpdateContract(CalculatorContract* cc) {
// TODO: the next restriction could be relaxed.
RET_CHECK(kMinuend(cc).IsStream() ^ kSubtrahend(cc).IsStream())
<< "MatrixSubtractCalculator only accepts exactly one input stream and "
"one input side packet";
return absl::OkStatus();
}
// Sends (minuend - subtrahend) on the output stream. Returns
// InvalidArgumentError without sending anything when the two operands differ
// in row or column count.
absl::Status MatrixSubtractCalculator::Process(CalculatorContext* cc) {
  const Matrix& lhs = *kMinuend(cc);
  const Matrix& rhs = *kSubtrahend(cc);

  const bool same_shape =
      lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols();
  if (!same_shape) {
    return absl::InvalidArgumentError(
        "Minuend and subtrahend must have the same dimensions.");
  }

  kOut(cc).Send(lhs - rhs);
  return absl::OkStatus();
}
} // namespace api2
} // namespace mediapipe

View File

@ -1,156 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <vector>
#include "Eigen/Core"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/validate_type.h"
namespace mediapipe {
namespace {
// Text-format description of a 3x4 matrix of random integers in [0,1000).
// The tests below load it into the input side packet.
const char kMatrixText[] =
"rows: 3\n"
"cols: 4\n"
"packed_data: 387\n"
"packed_data: 940\n"
"packed_data: 815\n"
"packed_data: 825\n"
"packed_data: 997\n"
"packed_data: 884\n"
"packed_data: 419\n"
"packed_data: 763\n"
"packed_data: 123\n"
"packed_data: 30\n"
"packed_data: 825\n"
"packed_data: 299\n";
// Text-format description of a second 3x4 matrix: every entry is the matching
// kMatrixText entry plus one, so (kMatrixText2 - kMatrixText) sums to 12 over
// its 12 entries. The tests below load it into the input stream and rely on
// that sum (+12 / -12).
const char kMatrixText2[] =
"rows: 3\n"
"cols: 4\n"
"packed_data: 388\n"
"packed_data: 941\n"
"packed_data: 816\n"
"packed_data: 826\n"
"packed_data: 998\n"
"packed_data: 885\n"
"packed_data: 420\n"
"packed_data: 764\n"
"packed_data: 124\n"
"packed_data: 31\n"
"packed_data: 826\n"
"packed_data: 300\n";
// Both operands are wired as side packets (plus an untagged input stream), so
// the run must report the "exactly one stream, one side packet" contract error.
TEST(MatrixSubtractCalculatorTest, WrongConfig) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "MatrixSubtractCalculator"
input_stream: "input_matrix"
input_side_packet: "SUBTRAHEND:side_matrix"
input_side_packet: "MINUEND:side_matrix2"
output_stream: "output_matrix"
)pb"));

  const auto run_status = runner.Run();
  EXPECT_THAT(
      run_status.message(),
      testing::HasSubstr(
          "only accepts exactly one input stream and one input side packet"));
}
// SUBTRAHEND is wired both as a side packet and as a stream. The resulting
// error message is expected to mention "must be connected" and "not both".
TEST(MatrixSubtractCalculatorTest, WrongConfig2) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "MatrixSubtractCalculator"
input_side_packet: "SUBTRAHEND:side_matrix"
input_stream: "SUBTRAHEND:side_matrix2"
output_stream: "output_matrix"
)pb"));

  const auto run_status = runner.Run();
  EXPECT_THAT(run_status.message(), testing::HasSubstr("must be connected"));
  EXPECT_THAT(run_status.message(), testing::HasSubstr("not both"));
}
// Stream = minuend (kMatrixText2), side packet = subtrahend (kMatrixText).
// Every entry of the difference is +1, so the 3x4 result sums to 12.
TEST(MatrixSubtractCalculatorTest, SubtractFromInput) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "MatrixSubtractCalculator"
input_stream: "MINUEND:input_matrix"
input_side_packet: "SUBTRAHEND:side_matrix"
output_stream: "output_matrix"
)pb"));

  // Adopt takes ownership of each heap-allocated matrix.
  auto* subtrahend = new Matrix();
  MatrixFromTextProto(kMatrixText, subtrahend);
  runner.MutableSidePackets()->Tag("SUBTRAHEND") = Adopt(subtrahend);

  auto* minuend = new Matrix();
  MatrixFromTextProto(kMatrixText2, minuend);
  runner.MutableInputs()->Tag("MINUEND").packets.push_back(
      Adopt(minuend).At(Timestamp(0)));

  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& produced = runner.Outputs().Index(0).packets;
  EXPECT_EQ(1, produced.size());
  EXPECT_EQ(Timestamp(0), produced[0].Timestamp());
  const Eigen::MatrixXf& result = produced[0].Get<Matrix>();
  ASSERT_EQ(3, result.rows());
  ASSERT_EQ(4, result.cols());
  EXPECT_NEAR(result.sum(), 12, 1e-5);
}
// Side packet = minuend (kMatrixText), stream = subtrahend (kMatrixText2).
// Every entry of the difference is -1, so the 3x4 result sums to -12.
TEST(MatrixSubtractCalculatorTest, SubtractFromSideMatrix) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "MatrixSubtractCalculator"
input_stream: "SUBTRAHEND:input_matrix"
input_side_packet: "MINUEND:side_matrix"
output_stream: "output_matrix"
)pb"));

  // Adopt takes ownership of each heap-allocated matrix.
  auto* minuend = new Matrix();
  MatrixFromTextProto(kMatrixText, minuend);
  runner.MutableSidePackets()->Tag("MINUEND") = Adopt(minuend);

  auto* subtrahend = new Matrix();
  MatrixFromTextProto(kMatrixText2, subtrahend);
  runner.MutableInputs()->Tag("SUBTRAHEND").packets.push_back(
      Adopt(subtrahend).At(Timestamp(0)));

  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& produced = runner.Outputs().Index(0).packets;
  EXPECT_EQ(1, produced.size());
  EXPECT_EQ(Timestamp(0), produced[0].Timestamp());
  const Eigen::MatrixXf& result = produced[0].Get<Matrix>();
  ASSERT_EQ(3, result.rows());
  ASSERT_EQ(4, result.cols());
  EXPECT_NEAR(result.sum(), -12, 1e-5);
}
} // namespace
} // namespace mediapipe

View File

@ -1,80 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Defines MatrixToVectorCalculator.
#include <math.h>
#include <deque>
#include <memory>
#include <string>
#include "Eigen/Core"
#include "absl/memory/memory.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/tool/status_util.h"
#include "mediapipe/util/time_series_util.h"
namespace mediapipe {
namespace api2 {
// A calculator that converts a Matrix M to a vector containing all the
// entries of M in column-major order.
//
// Example config:
// node {
// calculator: "MatrixToVectorCalculator"
// input_stream: "input_matrix"
// output_stream: "column_major_vector"
// }
// Node declaration: one untagged Matrix input stream and one untagged
// std::vector<float> output stream. Open and Process are defined out of line
// below the registration.
class MatrixToVectorCalculator : public Node {
public:
// Input stream of matrices to flatten.
static constexpr Input<Matrix> kIn{""};
// Output stream carrying the flattened entries of each input matrix.
static constexpr Output<std::vector<float>> kOut{""};
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
// Sets the timestamp offset to 0 (see the definition).
absl::Status Open(CalculatorContext* cc) override;
// Outputs a packet containing a vector for each input packet.
absl::Status Process(CalculatorContext* cc) override;
};
MEDIAPIPE_REGISTER_NODE(MatrixToVectorCalculator);
// Calls cc->SetOffset(0) and reports success.
// NOTE(review): presumably a zero offset tells the framework that output
// timestamps equal input timestamps - confirm against CalculatorContext docs.
absl::Status MatrixToVectorCalculator::Open(CalculatorContext* cc) {
  cc->SetOffset(0);
  // absl::OkStatus() matches the declared absl::Status return type and the
  // spelling used by Process in this file.
  return absl::OkStatus();
}
// Copies every entry of the input matrix into a freshly allocated
// std::vector<float> and sends that vector on the output stream.
absl::Status MatrixToVectorCalculator::Process(CalculatorContext* cc) {
  const Matrix& matrix = *kIn(cc);

  auto flattened = absl::make_unique<std::vector<float>>();
  flattened->resize(matrix.rows() * matrix.cols());

  // Viewing the vector's storage as a Matrix of the same shape and assigning
  // to it yields column-major output, because Matrix is an Eigen::MatrixXf and
  // Eigen uses column-major layout by default.
  Eigen::Map<Matrix> storage_view(flattened->data(), matrix.rows(),
                                  matrix.cols());
  storage_view = matrix;

  kOut(cc).Send(std::move(flattened));
  return absl::OkStatus();
}
} // namespace api2
} // namespace mediapipe

View File

@ -1,88 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <vector>
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/validate_type.h"
#include "mediapipe/util/time_series_test_util.h"
#include "mediapipe/util/time_series_util.h"
namespace mediapipe {
namespace {
// Fixture for MatrixToVectorCalculator tests, built on the time-series test
// base class. It adds helpers to describe the input stream, append one matrix
// packet from column-major data, and compare an output packet with a vector.
class MatrixToVectorCalculatorTest
    : public mediapipe::TimeSeriesCalculatorTest<mediapipe::NoOptions> {
 protected:
  void SetUp() override { calculator_name_ = "MatrixToVectorCalculator"; }

  // Appends one input packet at `timestamp` holding a matrix of
  // num_input_channels_ rows by num_input_samples_ columns, filled from
  // `column_major_data`. Fails the test if the element count does not match.
  void AppendInput(const std::vector<float>& column_major_data,
                   int64 timestamp) {
    ASSERT_EQ(num_input_samples_ * num_input_channels_,
              column_major_data.size());
    // data() is well defined for an empty vector, unlike &v[0].
    Eigen::Map<const Matrix> data_map(column_major_data.data(),
                                      num_input_channels_, num_input_samples_);
    AppendInputPacket(new Matrix(data_map), timestamp);
  }

  // Records the input stream shape and fixes the sample rate at 100 and the
  // packet rate at 20.0.
  void SetInputStreamParameters(int num_channels, int num_samples) {
    num_input_channels_ = num_channels;
    num_input_samples_ = num_samples;
    input_sample_rate_ = 100;
    input_packet_rate_ = 20.0;
  }

  // Sets the stream parameters and then calls FillInputHeader().
  void SetInputHeader(int num_channels, int num_samples) {
    SetInputStreamParameters(num_channels, num_samples);
    FillInputHeader();
  }

  // Expects output packet number `packet` to hold exactly `expected_vector`.
  // The vector is taken by const reference so that it is not copied per call.
  void CheckOutputPacket(int packet,
                         const std::vector<float>& expected_vector) {
    const auto& actual_vector =
        runner_->Outputs().Index(0).packets[packet].Get<std::vector<float>>();
    EXPECT_THAT(actual_vector, testing::ContainerEq(expected_vector));
  }
};
// A 1x4 matrix must come out as the same four values in the same order.
TEST_F(MatrixToVectorCalculatorTest, SingleRow) {
  InitializeGraph();
  SetInputHeader(/*num_channels=*/1, /*num_samples=*/4);

  const std::vector<float> row_values = {1.0, 2.0, 3.0, 4.0};
  AppendInput(row_values, 0);

  MP_ASSERT_OK(RunGraph());
  CheckOutputPacket(0, row_values);
}
// A 4x2 matrix built from column-major data must flatten back to that data.
TEST_F(MatrixToVectorCalculatorTest, RegularMatrix) {
  InitializeGraph();
  SetInputHeader(/*num_channels=*/4, /*num_samples=*/2);

  // Each line below is one column of the matrix, i.e. the actual data matrix
  // is the transpose of how the values appear here.
  const std::vector<float> column_major_values = {1.0, 2.0, 3.0, 4.0,
                                                  5.0, 6.0, 7.0, 8.0};
  AppendInput(column_major_values, 0);

  MP_ASSERT_OK(RunGraph());
  CheckOutputPacket(0, column_major_values);
}
} // namespace
} // namespace mediapipe

View File

@ -1,85 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace api2 {
// This calculator takes a set of input streams and combines them into a single
// output stream. The packets from different streams do not need to contain the
// same type. If there are packets arriving at the same time from two or more
// input streams, the packet corresponding to the input stream with the smallest
// index is passed to the output and the rest are ignored.
//
// Example use-case:
// Suppose we have two (or more) different algorithms for detecting shot
// boundaries and we need to merge their packets into a single stream. The
// algorithms may emit shot boundaries at the same time and their output types
// may not be compatible. Subsequent calculators that process the merged stream
// may be interested only in the timestamps of the shot boundary packets and so
// it may not even need to inspect the values stored inside the packets.
//
// Example config:
// node {
// calculator: "MergeCalculator"
// input_stream: "shot_info1"
// input_stream: "shot_info2"
// input_stream: "shot_info3"
// output_stream: "merged_shot_infos"
// }
//
// Node that forwards, at each timestamp, the packet of the first non-empty
// untagged input stream to its single untagged output stream.
class MergeCalculator : public Node {
 public:
  static constexpr Input<AnyType>::Multiple kIn{""};
  static constexpr Output<AnyType> kOut{""};
  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);

  // Requires at least one input stream. A single input is accepted, but a
  // warning is logged because there is nothing to merge in that case.
  static absl::Status UpdateContract(CalculatorContract* cc) {
    RET_CHECK_GT(kIn(cc).Count(), 0) << "Needs at least one input stream";
    if (kIn(cc).Count() != 1) {
      return absl::OkStatus();
    }
    LOG(WARNING)
        << "MergeCalculator expects multiple input streams to merge but is "
           "receiving only one. Make sure the calculator is configured "
           "correctly or consider removing this calculator to reduce "
           "unnecessary overhead.";
    return absl::OkStatus();
  }

  // Walks the inputs in iteration order and forwards the first packet found.
  // If every input is empty, only a warning is logged; OK is returned either
  // way.
  absl::Status Process(CalculatorContext* cc) final {
    for (const auto& candidate : kIn(cc)) {
      if (candidate.IsEmpty()) {
        continue;
      }
      kOut(cc).Send(candidate.packet());
      return absl::OkStatus();
    }
    LOG(WARNING) << "Empty input packets at timestamp "
                 << cc->InputTimestamp().Value();
    return absl::OkStatus();
  }
};
MEDIAPIPE_REGISTER_NODE(MergeCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -1,141 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
// A MergeCalculator node without any input stream must fail to run.
TEST(InvariantMergeInputStreamsCalculator, NoInputStreamsMustFail) {
  const auto node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "MergeCalculator"
output_stream: "merged_output"
)pb");
  CalculatorRunner runner(node_config);

  const auto run_status = runner.Run();
  ASSERT_FALSE(run_status.ok());
}
// The node must fail to run when it has no output stream, and also when it
// has two of them.
TEST(InvariantMergeInputStreamsCalculator, ExpectExactlyOneOutputStream) {
  // Case 1: two inputs, no output stream.
  const auto no_output_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "MergeCalculator"
input_stream: "input1"
input_stream: "input2"
)pb");
  CalculatorRunner runner1(no_output_config);
  const auto no_output_status = runner1.Run();
  EXPECT_FALSE(no_output_status.ok());

  // Case 2: two inputs, two output streams.
  const auto two_outputs_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "MergeCalculator"
input_stream: "input1"
input_stream: "input2"
output_stream: "output1"
output_stream: "output2"
)pb");
  CalculatorRunner runner2(two_outputs_config);
  const auto two_outputs_status = runner2.Run();
  ASSERT_FALSE(two_outputs_status.ok());
}
// Merges an int stream and a float stream whose timestamps never coincide
// and checks that all five packets come out ordered by timestamp.
TEST(MediaPipeDetectionToSoapboxDetectionCalculatorTest,
     TestMergingTwoStreams) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "MergeCalculator"
input_stream: "input1"
input_stream: "input2"
output_stream: "combined_output"
)pb"));

  // input1: integers 10, 20, 30, each at the timestamp equal to its value.
  auto& int_packets = runner.MutableInputs()->Index(0).packets;
  for (const int value : {10, 20, 30}) {
    int_packets.push_back(Adopt(new int(value)).At(Timestamp(value)));
  }

  // input2: floats 5.5 and 35.5 at timestamps 5 and 35.
  auto& float_packets = runner.MutableInputs()->Index(1).packets;
  float_packets.push_back(Adopt(new float(5.5)).At(Timestamp(5)));
  float_packets.push_back(Adopt(new float(35.5)).At(Timestamp(35)));

  MP_ASSERT_OK(runner.Run());

  // Expected combined_output: 5.5, 10, 20, 30, 35.5 at times 5, 10, 20, 30, 35.
  const std::vector<Packet>& merged = runner.Outputs().Index(0).packets;
  ASSERT_EQ(merged.size(), 5);

  EXPECT_EQ(merged[0].Timestamp(), Timestamp(5));
  EXPECT_EQ(merged[0].Get<float>(), 5.5);

  EXPECT_EQ(merged[1].Timestamp(), Timestamp(10));
  EXPECT_EQ(merged[1].Get<int>(), 10);

  EXPECT_EQ(merged[2].Timestamp(), Timestamp(20));
  EXPECT_EQ(merged[2].Get<int>(), 20);

  EXPECT_EQ(merged[3].Timestamp(), Timestamp(30));
  EXPECT_EQ(merged[3].Get<int>(), 30);

  EXPECT_EQ(merged[4].Timestamp(), Timestamp(35));
  EXPECT_EQ(merged[4].Get<float>(), 35.5);
}
// Merges an int, a float and a char stream carrying one packet each and
// checks that the three packets come out ordered by timestamp.
TEST(MediaPipeDetectionToSoapboxDetectionCalculatorTest,
     TestMergingThreeStreams) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "MergeCalculator"
input_stream: "input1"
input_stream: "input2"
input_stream: "input3"
output_stream: "combined_output"
)pb"));

  auto* inputs = runner.MutableInputs();
  // input1: integer 30 occurring at time 30.
  inputs->Index(0).packets.push_back(Adopt(new int(30)).At(Timestamp(30)));
  // input2: float 20.5 occurring at time 20.
  inputs->Index(1).packets.push_back(Adopt(new float(20.5)).At(Timestamp(20)));
  // input3: char 'c' occurring at time 10.
  inputs->Index(2).packets.push_back(Adopt(new char('c')).At(Timestamp(10)));

  MP_ASSERT_OK(runner.Run());

  // Expected combined_output: 'c', 20.5, 30 at times 10, 20, 30.
  const std::vector<Packet>& merged = runner.Outputs().Index(0).packets;
  ASSERT_EQ(merged.size(), 3);

  EXPECT_EQ(merged[0].Timestamp(), Timestamp(10));
  EXPECT_EQ(merged[0].Get<char>(), 'c');

  EXPECT_EQ(merged[1].Timestamp(), Timestamp(20));
  EXPECT_EQ(merged[1].Get<float>(), 20.5);

  EXPECT_EQ(merged[2].Timestamp(), Timestamp(30));
  EXPECT_EQ(merged[2].Get<int>(), 30);
}
} // namespace
} // namespace mediapipe

View File

@ -1,56 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
namespace mediapipe {
namespace api2 {
// A Calculator that selects an input stream from "INPUT:0", "INPUT:1", ...,
// using the integer value (0, 1, ...) in the packet on the "SELECT" input
// stream, and passes the packet on the selected input stream to the "OUTPUT"
// output stream.
//
// Note that this calculator defaults to use MuxInputStreamHandler, which is
// required for this calculator. However, it can be overridden to work with
// other InputStreamHandlers. Check out the unit tests for an example usage
// with DefaultInputStreamHandler.
// TODO: why would you need to use DefaultISH? Perhaps b/167596925?
// Node that forwards the packet of the input stream chosen by SELECT to the
// OUTPUT stream. SELECT may be connected either as a stream or as a side
// packet (SideFallback).
class MuxCalculator : public Node {
 public:
  // Index of the INPUT stream to forward.
  static constexpr Input<int>::SideFallback kSelect{"SELECT"};
  // TODO: this currently sets them all to Any independently, instead
  // of the first being Any and the others being SameAs.
  static constexpr Input<AnyType>::Multiple kIn{"INPUT"};
  static constexpr Output<SameType<kIn>> kOut{"OUTPUT"};
  MEDIAPIPE_NODE_CONTRACT(kSelect, kIn, kOut,
                          StreamHandler("MuxInputStreamHandler"));

  // Forwards the selected input's packet, if any, and returns OK. Fails the
  // RET_CHECK when the selected index is outside [0, number of INPUT streams).
  absl::Status Process(CalculatorContext* cc) final {
    // When the stream handler is overridden, SELECT (as a stream) may carry
    // no packet at this timestamp. There is then nothing to select with, so
    // skip the timestamp instead of dereferencing an empty packet.
    if (kSelect(cc).IsStream() && kSelect(cc).IsEmpty()) {
      return absl::OkStatus();
    }
    int select = *kSelect(cc);
    RET_CHECK(0 <= select && select < kIn(cc).Count());
    if (!kIn(cc)[select].IsEmpty()) {
      kOut(cc).Send(kIn(cc)[select].packet());
    }
    return absl::OkStatus();
  }
};
MEDIAPIPE_REGISTER_NODE(MuxCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -1,304 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "mediapipe/calculators/core/split_vector_calculator.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
// Concrete SplitVectorCalculator<int, false> instantiation, registered under
// the name that the test graph configs in this file refer to.
using SplitIntVectorCalculator = SplitVectorCalculator<int, false>;
REGISTER_CALCULATOR(SplitIntVectorCalculator);
namespace {
// Graph with default input stream handler, and the input selection is driven
// by an input stream. All MuxCalculator inputs are present at each timestamp.
//
// SplitIntVectorCalculator splits the graph input into four single elements
// (ranges [0,1), [1,2), [2,3) and [3,4)): the first three feed the
// MuxCalculator "INPUT" streams and the fourth feeds its "SELECT" stream.
constexpr char kTestGraphConfig1[] = R"pb(
input_stream: "input"
output_stream: "test_output"
node {
calculator: "SplitIntVectorCalculator"
input_stream: "input"
output_stream: "stream0"
output_stream: "stream1"
output_stream: "stream2"
output_stream: "input_select"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges: { begin: 0 end: 1 }
ranges: { begin: 1 end: 2 }
ranges: { begin: 2 end: 3 }
ranges: { begin: 3 end: 4 }
element_only: true
}
}
}
node {
calculator: "MuxCalculator"
input_stream: "INPUT:0:stream0"
input_stream: "INPUT:1:stream1"
input_stream: "INPUT:2:stream2"
input_stream: "SELECT:input_select"
output_stream: "OUTPUT:test_output"
input_stream_handler { input_stream_handler: "DefaultInputStreamHandler" }
}
)pb";
// Graph with default input stream handler, and the input selection is driven
// by an input side packet. All MuxCalculator inputs are present at each
// timestamp.
//
// SplitIntVectorCalculator splits the graph input into three single elements
// (ranges [0,1), [1,2) and [2,3)) that feed the MuxCalculator "INPUT"
// streams; "SELECT" comes from the "input_selector" side packet.
constexpr char kTestGraphConfig2[] = R"pb(
input_side_packet: "input_selector"
input_stream: "input"
output_stream: "test_output"
node {
calculator: "SplitIntVectorCalculator"
input_stream: "input"
output_stream: "stream0"
output_stream: "stream1"
output_stream: "stream2"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges: { begin: 0 end: 1 }
ranges: { begin: 1 end: 2 }
ranges: { begin: 2 end: 3 }
element_only: true
}
}
}
node {
calculator: "MuxCalculator"
input_stream: "INPUT:0:stream0"
input_stream: "INPUT:1:stream1"
input_stream: "INPUT:2:stream2"
input_side_packet: "SELECT:input_selector"
output_stream: "OUTPUT:test_output"
input_stream_handler { input_stream_handler: "DefaultInputStreamHandler" }
}
)pb";
// Graph with mux input stream handler, and the input selection is driven
// by an input stream. Only one MuxCalculator input is present at each
// timestamp.
//
// A RoundRobinDemuxCalculator produces both the three data streams and the
// "SELECT" stream. The MuxCalculator node carries no input_stream_handler
// override, so it runs with whatever handler the calculator itself requests.
constexpr char kTestGraphConfig3[] = R"pb(
input_stream: "input"
output_stream: "test_output"
node {
calculator: "RoundRobinDemuxCalculator"
input_stream: "input"
output_stream: "OUTPUT:0:stream0"
output_stream: "OUTPUT:1:stream1"
output_stream: "OUTPUT:2:stream2"
output_stream: "SELECT:input_select"
}
node {
calculator: "MuxCalculator"
input_stream: "INPUT:0:stream0"
input_stream: "INPUT:1:stream1"
input_stream: "INPUT:2:stream2"
input_stream: "SELECT:input_select"
output_stream: "OUTPUT:test_output"
}
)pb";
// Stream and side-packet names used by the numbered test graph configs.
// Name of the graph output stream observed by the tests.
constexpr char kOutputName[] = "test_output";
// Name of the graph input stream the tests feed packets into.
constexpr char kInputName[] = "input";
// Name of the input side packet that carries the SELECT value.
constexpr char kInputSelector[] = "input_selector";
// Helper to run a graph with the given inputs and generate outputs, asserting
// each step along the way.
// Inputs:
// graph_config_proto - graph config protobuf
// extra_side_packets - input side packets name to value map
// input_stream_name - name of the input
void RunGraph(const std::string& graph_config_proto,
const std::map<std::string, Packet>& extra_side_packets,
const std::string& input_stream_name, int num_input_packets,
std::function<Packet(int)> input_fn,
const std::string& output_stream_name,
std::function<absl::Status(const Packet&)> output_fn) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(graph_config_proto);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.ObserveOutputStream(output_stream_name, output_fn));
MP_ASSERT_OK(graph.StartRun(extra_side_packets));
for (int i = 0; i < num_input_packets; ++i) {
MP_ASSERT_OK(graph.AddPacketToInputStream(input_stream_name, input_fn(i)));
}
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST(MuxCalculatorTest, InputStreamSelector_DefaultInputStreamHandler) {
  // Every input vector has four elements; kTestGraphConfig1 routes the first
  // three to the mux inputs and the fourth to its SELECT stream.
  const std::vector<std::vector<int>> input_packets = {
      {1, 1, 2, 1},         {3, 5, 8, 2},       {13, 21, 34, 0},
      {55, 89, 144, 2},     {233, 377, 610, 0}, {987, 1597, 2584, 1},
      {4181, 6765, 10946, 2},
  };

  // Hands out the i-th input vector, stamped with consecutive timestamps
  // starting at 22.
  int next_timestamp = 22;
  auto input_fn = [&next_timestamp, &input_packets](int i) -> Packet {
    const Timestamp stamp(next_timestamp);
    ++next_timestamp;
    return MakePacket<std::vector<int>>(input_packets[i]).At(stamp);
  };

  // Records every int that shows up on the observed output stream.
  std::vector<int> output;
  auto output_fn = [&output](const Packet& p) -> absl::Status {
    output.push_back(p.Get<int>());
    return absl::OkStatus();
  };

  RunGraph(kTestGraphConfig1, {}, kInputName, input_packets.size(), input_fn,
           kOutputName, output_fn);

  EXPECT_THAT(output, testing::ElementsAre(1, 8, 13, 144, 233, 1597, 10946));
}
TEST(MuxCalculatorTest, InputSidePacketSelector_DefaultInputStreamHandler) {
  // Every input vector has three elements, one per mux input stream.
  const std::vector<std::vector<int>> input_packets = {
      {1, 1, 2},       {3, 5, 8},          {13, 21, 34},        {55, 89, 144},
      {233, 377, 610}, {987, 1597, 2584},  {4181, 6765, 10946},
  };

  // Hands out the i-th input vector. The timestamp counter is shared by all
  // runs below, so it keeps increasing from one run to the next.
  int next_timestamp = 22;
  auto input_fn = [&next_timestamp, &input_packets](int i) -> Packet {
    const Timestamp stamp(next_timestamp);
    ++next_timestamp;
    return MakePacket<std::vector<int>>(input_packets[i]).At(stamp);
  };

  // Runs kTestGraphConfig2 with the given SELECT side packet value and returns
  // the ints observed on the output stream during that run.
  auto run_with_selector = [&](int selector) {
    std::vector<int> collected;
    auto output_fn = [&collected](const Packet& p) -> absl::Status {
      collected.push_back(p.Get<int>());
      return absl::OkStatus();
    };
    RunGraph(kTestGraphConfig2, {{kInputSelector, MakePacket<int>(selector)}},
             kInputName, input_packets.size(), input_fn, kOutputName,
             output_fn);
    return collected;
  };

  EXPECT_THAT(run_with_selector(0),
              testing::ElementsAre(1, 3, 13, 55, 233, 987, 4181));
  EXPECT_THAT(run_with_selector(1),
              testing::ElementsAre(1, 5, 21, 89, 377, 1597, 6765));
  EXPECT_THAT(run_with_selector(2),
              testing::ElementsAre(2, 8, 34, 144, 610, 2584, 10946));
}
TEST(MuxCalculatorTest, InputStreamSelector_MuxInputStreamHandler) {
  // Plain int packets; kTestGraphConfig3 demuxes and then re-muxes them, so
  // the output is expected to equal the input sequence.
  const std::vector<int> input_packets = {1,   1,   2,    3,    5,    8,    13,
                                          21,  34,  55,   89,   144,  233,  377,
                                          610, 987, 1597, 2584, 4181, 6765, 10946};

  // Hands out the i-th int, stamped with consecutive timestamps from 22.
  int next_timestamp = 22;
  auto input_fn = [&next_timestamp, &input_packets](int i) -> Packet {
    const Timestamp stamp(next_timestamp);
    ++next_timestamp;
    return MakePacket<int>(input_packets[i]).At(stamp);
  };

  // Records every int that shows up on the observed output stream.
  std::vector<int> output;
  auto output_fn = [&output](const Packet& p) -> absl::Status {
    output.push_back(p.Get<int>());
    return absl::OkStatus();
  };

  RunGraph(kTestGraphConfig3, {}, kInputName, input_packets.size(), input_fn,
           kOutputName, output_fn);

  EXPECT_EQ(output, input_packets);
}
// Graph consisting of a single MuxCalculator whose two data inputs and SELECT
// input are all graph-level input streams, so a test can feed each of them
// directly. The node carries no input_stream_handler override.
constexpr char kDualInputGraphConfig[] = R"pb(
input_stream: "input_0"
input_stream: "input_1"
input_stream: "input_select"
output_stream: "test_output"
node {
calculator: "MuxCalculator"
input_stream: "INPUT:0:input_0"
input_stream: "INPUT:1:input_1"
input_stream: "SELECT:input_select"
output_stream: "OUTPUT:test_output"
}
)pb";
// Checks that a packet queued on an input that is NOT selected at its
// timestamp is released, while the selected packet is forwarded. Payloads are
// shared_ptr<int> so that weak_ptrs can observe when a payload is destroyed.
TEST(MuxCalculatorTest, DiscardSkippedInputs_MuxInputStreamHandler) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
kDualInputGraphConfig);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
// Holds the payload of the most recent packet seen on the output stream.
std::shared_ptr<int> output;
MP_ASSERT_OK(
graph.ObserveOutputStream("test_output", [&output](const Packet& p) {
output = p.Get<std::shared_ptr<int>>();
return absl::OkStatus();
}));
MP_ASSERT_OK(graph.StartRun({}));
auto one = std::make_shared<int>(1);
auto two = std::make_shared<int>(2);
auto three = std::make_shared<int>(3);
// Non-owning observers used below to tell whether a payload is still alive.
std::weak_ptr<int> one_weak = one;
std::weak_ptr<int> two_weak = two;
// Queue `one` on input_0 and `two` on input_1 at timestamp 0, then `three`
// on input_1 at timestamp 1. Ownership is moved into the packets.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_0",
MakePacket<std::shared_ptr<int>>(std::move(one)).At(Timestamp(0))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_1",
MakePacket<std::shared_ptr<int>>(std::move(two)).At(Timestamp(0))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_1",
MakePacket<std::shared_ptr<int>>(std::move(three)).At(Timestamp(1))));
// The local shared_ptrs were moved from, so the test itself no longer keeps
// any of the payloads alive.
EXPECT_EQ(one, nullptr);
EXPECT_EQ(two, nullptr);
EXPECT_EQ(three, nullptr);
// Select input 0 at timestamp 0: `one` must reach the output ...
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_select", MakePacket<int>(0).At(Timestamp(0))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_EQ(*output, 1);
// ... and stays alive because `output` now holds it, whereas the skipped
// payload `two` must already have been destroyed.
EXPECT_NE(one_weak.lock(), nullptr);
EXPECT_EQ(two_weak.lock(), nullptr);
// Select input 1 at timestamp 1: `three` must reach the output.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_select", MakePacket<int>(1).At(Timestamp(1))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_EQ(*output, 3);
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
} // namespace
} // namespace mediapipe

View File

@ -1,54 +0,0 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
namespace mediapipe {
namespace api2 {
// A Calculator that returns 0 if INPUT is 0, and 1 otherwise.
// The result is sent as an int on "OUTPUT" and as a bool on "OUTPUT_BOOL";
// both outputs are optional and only the connected ones are written.
class NonZeroCalculator : public Node {
 public:
  static constexpr Input<int>::SideFallback kIn{"INPUT"};
  static constexpr Output<int>::Optional kOut{"OUTPUT"};
  static constexpr Output<bool>::Optional kBooleanOut{"OUTPUT_BOOL"};

  MEDIAPIPE_NODE_CONTRACT(kIn, kOut, kBooleanOut);

  // Rejects configurations in which neither output is connected.
  // NOTE(review): this is declared as a non-static member; confirm that the
  // framework actually invokes it in this form.
  absl::Status UpdateContract(CalculatorContract* cc) {
    RET_CHECK(kOut(cc).IsConnected() || kBooleanOut(cc).IsConnected())
        << "At least one output stream is expected.";
    return absl::OkStatus();
  }

  // Emits nothing when INPUT is empty; otherwise writes the int and/or bool
  // form of "INPUT != 0" to whichever outputs are connected.
  absl::Status Process(CalculatorContext* cc) final {
    if (kIn(cc).IsEmpty()) {
      return absl::OkStatus();
    }
    const bool is_non_zero = (*kIn(cc) != 0);
    if (kOut(cc).IsConnected()) {
      kOut(cc).Send(std::make_unique<int>(is_non_zero ? 1 : 0));
    }
    if (kBooleanOut(cc).IsConnected()) {
      kBooleanOut(cc).Send(std::make_unique<bool>(is_non_zero));
    }
    return absl::OkStatus();
  }
};
MEDIAPIPE_REGISTER_NODE(NonZeroCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -1,93 +0,0 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/framework/tool/validate_type.h"
namespace mediapipe {
// Test fixture that runs a single NonZeroCalculator node with both its
// "OUTPUT" and "OUTPUT_BOOL" streams connected.
class NonZeroCalculatorTest : public ::testing::Test {
 protected:
  NonZeroCalculatorTest()
      : runner_(
            R"pb(
calculator: "NonZeroCalculator"
input_stream: "INPUT:input"
output_stream: "OUTPUT:output"
output_stream: "OUTPUT_BOOL:output_bool"
)pb") {}

  // Queues one int packet per element of `inputs` on the "INPUT" stream, at
  // timestamps 0, 1, 2, ...
  void SetInput(const std::vector<int>& inputs) {
    int timestamp = 0;
    for (const int input : inputs) {
      runner_.MutableInputs()
          ->Get("INPUT", 0)
          .packets.push_back(MakePacket<int>(input).At(Timestamp(timestamp++)));
    }
  }

  // Returns the int payloads of all packets produced on "OUTPUT".
  std::vector<int> GetOutput() {
    std::vector<int> result;
    // Bind by reference: there is no need to copy each Packet just to read it.
    for (const auto& output : runner_.Outputs().Get("OUTPUT", 0).packets) {
      result.push_back(output.Get<int>());
    }
    return result;
  }

  // Returns the bool payloads of all packets produced on "OUTPUT_BOOL".
  std::vector<bool> GetOutputBool() {
    std::vector<bool> result;
    for (const auto& output : runner_.Outputs().Get("OUTPUT_BOOL", 0).packets) {
      result.push_back(output.Get<bool>());
    }
    return result;
  }

  CalculatorRunner runner_;
};
// A single zero input must produce 0 on OUTPUT and false on OUTPUT_BOOL.
TEST_F(NonZeroCalculatorTest, ProducesZeroOutputForZeroInput) {
  const std::vector<int> inputs = {0};
  SetInput(inputs);

  MP_ASSERT_OK(runner_.Run());

  EXPECT_THAT(GetOutput(), ::testing::ElementsAre(0));
  EXPECT_THAT(GetOutputBool(), ::testing::ElementsAre(false));
}
// Every non-zero input, including a negative one, must produce 1 / true.
TEST_F(NonZeroCalculatorTest, ProducesNonZeroOutputForNonZeroInput) {
  const std::vector<int> inputs = {1, 2, 3, -4, 5};
  SetInput(inputs);

  MP_ASSERT_OK(runner_.Run());

  EXPECT_THAT(GetOutput(), ::testing::ElementsAre(1, 1, 1, 1, 1));
  EXPECT_THAT(GetOutputBool(),
              ::testing::ElementsAre(true, true, true, true, true));
}
// Alternating non-zero and zero inputs must produce alternating outputs, in
// the same order as the inputs.
TEST_F(NonZeroCalculatorTest, SwitchesBetweenNonZeroAndZeroOutput) {
  const std::vector<int> inputs = {1, 0, 3, 0, 5};
  SetInput(inputs);

  MP_ASSERT_OK(runner_.Run());

  EXPECT_THAT(GetOutput(), ::testing::ElementsAre(1, 0, 1, 0, 1));
  EXPECT_THAT(GetOutputBool(),
              ::testing::ElementsAre(true, false, true, false, true));
}
} // namespace mediapipe

View File

@ -1,117 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This takes packets from N+1 streams, A_1, A_2, ..., A_N, B.
// For every packet that appears in B, outputs the most recent packet from each
// of the A_i on a separate stream.
#include <vector>
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/core/packet_cloner_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
namespace mediapipe {
// For every packet received on the last stream, output the latest packet
// obtained on all other streams. Therefore, if the last stream outputs at a
// higher rate than the others, this effectively clones the packets from the
// other streams to match the last.
//
// Example config:
// node {
//   calculator: "PacketClonerCalculator"
//   input_stream: "first_base_signal"
//   input_stream: "second_base_signal"
//   input_stream: "tick_signal"
//   output_stream: "cloned_first_base_signal"
//   output_stream: "cloned_second_base_signal"
// }
//
// Related:
//   packet_cloner_calculator.proto: Options for this calculator.
//   merge_input_streams_calculator.cc: One output stream.
//   packet_inner_join_calculator.cc: Don't output unless all inputs are new.
class PacketClonerCalculator : public CalculatorBase {
 public:
  // Declares every input as Any and gives output i the same type as input i.
  // The last input is the tick signal and has no matching output. Returns an
  // error, rather than indexing a stream that does not exist, when the
  // stream counts do not fit that layout.
  static absl::Status GetContract(CalculatorContract* cc) {
    const int num_inputs = cc->Inputs().NumEntries();
    if (num_inputs < 1) {
      return absl::InvalidArgumentError(
          "PacketClonerCalculator requires at least the tick input stream.");
    }
    const int tick_signal_index = num_inputs - 1;
    if (cc->Outputs().NumEntries() != tick_signal_index) {
      return absl::InvalidArgumentError(
          "PacketClonerCalculator requires exactly one output stream per "
          "non-tick input stream.");
    }
    for (int i = 0; i < tick_signal_index; ++i) {
      cc->Inputs().Index(i).SetAny();
      cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(i));
    }
    cc->Inputs().Index(tick_signal_index).SetAny();
    return absl::OkStatus();
  }

  // Reads the options, sizes the per-stream packet cache and forwards any
  // input stream headers to the matching outputs.
  absl::Status Open(CalculatorContext* cc) final {
    // Load options. Bound by reference to avoid copying the options proto.
    const auto& calculator_options =
        cc->Options<mediapipe::PacketClonerCalculatorOptions>();
    output_only_when_all_inputs_received_ =
        calculator_options.output_only_when_all_inputs_received();

    // Parse input streams.
    tick_signal_index_ = cc->Inputs().NumEntries() - 1;
    current_.resize(tick_signal_index_);

    // Pass along the header for each stream if present.
    for (int i = 0; i < tick_signal_index_; ++i) {
      if (!cc->Inputs().Index(i).Header().IsEmpty()) {
        cc->Outputs().Index(i).SetHeader(cc->Inputs().Index(i).Header());
      }
    }
    return absl::OkStatus();
  }

  // Caches the newest packet of every base stream and, when a tick packet is
  // present, re-emits the cached packets at the tick's timestamp.
  absl::Status Process(CalculatorContext* cc) final {
    // Store input signals.
    for (int i = 0; i < tick_signal_index_; ++i) {
      if (!cc->Inputs().Index(i).Value().IsEmpty()) {
        current_[i] = cc->Inputs().Index(i).Value();
      }
    }

    // Output according to the TICK signal.
    if (!cc->Inputs().Index(tick_signal_index_).Value().IsEmpty()) {
      if (output_only_when_all_inputs_received_) {
        // Return if one of the input is null.
        for (int i = 0; i < tick_signal_index_; ++i) {
          if (current_[i].IsEmpty()) {
            return absl::OkStatus();
          }
        }
      }
      // Output each stream.
      for (int i = 0; i < tick_signal_index_; ++i) {
        if (!current_[i].IsEmpty()) {
          cc->Outputs().Index(i).AddPacket(
              current_[i].At(cc->InputTimestamp()));
        } else {
          // Nothing cached yet for this stream: only advance its timestamp
          // bound past the tick's timestamp.
          cc->Outputs().Index(i).SetNextTimestampBound(
              cc->InputTimestamp().NextAllowedInStream());
        }
      }
    }
    return absl::OkStatus();
  }

 private:
  // Most recent packet seen on each base (non-tick) input stream.
  std::vector<Packet> current_;
  // Index of the tick input stream, i.e. the last input; assigned in Open().
  int tick_signal_index_ = 0;
  // Mirrors the option of the same name; assigned in Open().
  bool output_only_when_all_inputs_received_ = false;
};
REGISTER_CALCULATOR(PacketClonerCalculator);
} // namespace mediapipe

View File

@ -1,31 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
option objc_class_prefix = "MediaPipe";
// Options for PacketClonerCalculator.
message PacketClonerCalculatorOptions {
// Makes these options available as an extension of CalculatorOptions.
extend CalculatorOptions {
optional PacketClonerCalculatorOptions ext = 258872085;
}
// When true, this calculator will drop received TICK packets if any input
// stream hasn't received a packet yet.
optional bool output_only_when_all_inputs_received = 1 [default = false];
}

View File

@ -1,77 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
namespace mediapipe {
// Calculator that acts like the SQL query:
//   SELECT *
//   FROM packets_on_stream1 AS packet1
//   INNER JOIN packets_on_stream2 AS packet2
//   ON packet1.timestamp = packet2.timestamp
//
// In other words, it only emits and forwards packets if all input streams are
// not empty.
//
// Intended for use with FixedSizeInputStreamHandler.
//
// Related:
//   packet_cloner_calculator.cc: Repeats last-seen packets from empty inputs.
class PacketInnerJoinCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc);

  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;

 private:
  // Number of input streams; assigned in Open(). Initialized here so the
  // member never holds an indeterminate value.
  int num_streams_ = 0;
};
REGISTER_CALCULATOR(PacketInnerJoinCalculator);
// Requires as many outputs as inputs; every input accepts any type and the
// output with the same index carries the same type as its input.
absl::Status PacketInnerJoinCalculator::GetContract(CalculatorContract* cc) {
  RET_CHECK(cc->Inputs().NumEntries() == cc->Outputs().NumEntries())
      << "The number of input and output streams must match.";

  const int stream_count = cc->Inputs().NumEntries();
  int index = 0;
  while (index < stream_count) {
    cc->Inputs().Index(index).SetAny();
    cc->Outputs().Index(index).SetSameAs(&cc->Inputs().Index(index));
    ++index;
  }
  return absl::OkStatus();
}
// Sets a timestamp offset of zero and remembers how many input streams
// Process() has to inspect.
absl::Status PacketInnerJoinCalculator::Open(CalculatorContext* cc) {
  cc->SetOffset(TimestampDiff(0));
  num_streams_ = cc->Inputs().NumEntries();
  return absl::OkStatus();
}
// Forwards the current packet of every input to the output with the same
// index, but only when none of the inputs is empty; otherwise emits nothing.
absl::Status PacketInnerJoinCalculator::Process(CalculatorContext* cc) {
  bool every_input_present = true;
  for (int stream = 0; stream < num_streams_ && every_input_present;
       ++stream) {
    every_input_present = !cc->Inputs().Index(stream).Value().IsEmpty();
  }

  if (every_input_present) {
    for (int stream = 0; stream < num_streams_; ++stream) {
      cc->Outputs().Index(stream).AddPacket(
          cc->Inputs().Index(stream).Value());
    }
  }
  return absl::OkStatus();
}
} // namespace mediapipe

View File

@ -1,101 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/validate_type.h"
namespace mediapipe {
namespace {
// Creates an int packet whose payload and timestamp are both `i`.
// Uses MakePacket rather than Adopt(new ...) so no raw `new` appears here.
inline Packet PacketFrom(int i) { return MakePacket<int>(i).At(Timestamp(i)); }
// Both streams carry packets at identical timestamps, so every packet must be
// forwarded on both outputs.
TEST(PacketInnerJoinCalculatorTest, AllMatching) {
  // Test case.
  const std::vector<int> packets_on_stream1 = {0, 1, 2, 3};
  const std::vector<int> packets_on_stream2 = {0, 1, 2, 3};

  // Run.
  CalculatorRunner runner("PacketInnerJoinCalculator", "", 2, 2, 0);
  for (int packet_load : packets_on_stream1) {
    runner.MutableInputs()->Index(0).packets.push_back(PacketFrom(packet_load));
  }
  for (int packet_load : packets_on_stream2) {
    runner.MutableInputs()->Index(1).packets.push_back(PacketFrom(packet_load));
  }
  MP_ASSERT_OK(runner.Run());

  // Check.
  const std::vector<int> expected = {0, 1, 2, 3};
  ASSERT_EQ(expected.size(), runner.Outputs().Index(0).packets.size());
  ASSERT_EQ(expected.size(), runner.Outputs().Index(1).packets.size());
  // Unsigned index to match vector::size(); packets are bound by reference
  // instead of being copied on every iteration.
  for (size_t i = 0; i < expected.size(); ++i) {
    const Packet& packet1 = runner.Outputs().Index(0).packets[i];
    EXPECT_EQ(expected[i], packet1.Get<int>());
    EXPECT_EQ(expected[i], packet1.Timestamp().Value());
    const Packet& packet2 = runner.Outputs().Index(1).packets[i];
    EXPECT_EQ(expected[i], packet2.Get<int>());
    EXPECT_EQ(expected[i], packet2.Timestamp().Value());
  }
}
// The two streams never share a timestamp, so nothing may be forwarded.
TEST(PacketInnerJoinCalculatorTest, NoneMatching) {
  // Test case.
  const std::vector<int> packets_on_stream1 = {0, 2};
  const std::vector<int> packets_on_stream2 = {1, 3};

  // Run.
  CalculatorRunner runner("PacketInnerJoinCalculator", "", 2, 2, 0);
  auto& first_input = runner.MutableInputs()->Index(0).packets;
  for (const int value : packets_on_stream1) {
    first_input.push_back(PacketFrom(value));
  }
  auto& second_input = runner.MutableInputs()->Index(1).packets;
  for (const int value : packets_on_stream2) {
    second_input.push_back(PacketFrom(value));
  }
  MP_ASSERT_OK(runner.Run());

  // Check.
  EXPECT_TRUE(runner.Outputs().Index(0).packets.empty());
  EXPECT_TRUE(runner.Outputs().Index(1).packets.empty());
}
// Only the timestamps present on both streams (0, 2, 4, 6) may be forwarded.
TEST(PacketInnerJoinCalculatorTest, SomeMatching) {
  // Test case.
  const std::vector<int> packets_on_stream1 = {0, 1, 2, 3, 4, 6};
  const std::vector<int> packets_on_stream2 = {0, 2, 4, 5, 6};

  // Run.
  CalculatorRunner runner("PacketInnerJoinCalculator", "", 2, 2, 0);
  for (int packet_load : packets_on_stream1) {
    runner.MutableInputs()->Index(0).packets.push_back(PacketFrom(packet_load));
  }
  for (int packet_load : packets_on_stream2) {
    runner.MutableInputs()->Index(1).packets.push_back(PacketFrom(packet_load));
  }
  MP_ASSERT_OK(runner.Run());

  // Check.
  const std::vector<int> expected = {0, 2, 4, 6};
  ASSERT_EQ(expected.size(), runner.Outputs().Index(0).packets.size());
  ASSERT_EQ(expected.size(), runner.Outputs().Index(1).packets.size());
  // Unsigned index to match vector::size(); packets are bound by reference
  // instead of being copied on every iteration.
  for (size_t i = 0; i < expected.size(); ++i) {
    const Packet& packet1 = runner.Outputs().Index(0).packets[i];
    EXPECT_EQ(expected[i], packet1.Get<int>());
    EXPECT_EQ(expected[i], packet1.Timestamp().Value());
    const Packet& packet2 = runner.Outputs().Index(1).packets[i];
    EXPECT_EQ(expected[i], packet2.Get<int>());
    EXPECT_EQ(expected[i], packet2.Timestamp().Value());
  }
}
} // namespace
} // namespace mediapipe

View File

@ -1,84 +0,0 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// For each non-empty input packet, emits a single output packet containing
// the boolean value "true"; in response to empty packets (a.k.a. timestamp
// bound updates) it emits "false" instead. This can be used to "flag" the
// presence of an arbitrary packet type as input into a downstream calculator.
//
// Inputs:
//   PACKET - any type.
//
// Outputs:
//   PRESENCE - bool.
//     "true" if packet is not empty, "false" if there's timestamp bound update
//     instead.
//
// Examples:
// node: {
//   calculator: "PacketPresenceCalculator"
//   input_stream: "PACKET:packet"
//   output_stream: "PRESENCE:presence"
// }
//
// This calculator can be used in conjunction with GateCalculator in order to
// allow/disallow processing. For instance:
// node: {
//   calculator: "PacketPresenceCalculator"
//   input_stream: "PACKET:value"
//   output_stream: "PRESENCE:disallow_if_present"
// }
// node {
//   calculator: "GateCalculator"
//   input_stream: "image"
//   input_stream: "DISALLOW:disallow_if_present"
//   output_stream: "image_for_processing"
//   options: {
//     [mediapipe.GateCalculatorOptions.ext] {
//       empty_packets_as_allow: true
//     }
//   }
// }
class PacketPresenceCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Tag("PACKET").SetAny();
    cc->Outputs().Tag("PRESENCE").Set<bool>();
    // Process() function is invoked in response to input stream timestamp
    // bound updates.
    cc->SetProcessTimestampBounds(true);
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) override {
    cc->SetOffset(TimestampDiff(0));
    return absl::OkStatus();
  }

  // Emits, at the input timestamp, whether the PACKET input holds a packet.
  absl::Status Process(CalculatorContext* cc) final {
    const bool packet_present = !cc->Inputs().Tag("PACKET").IsEmpty();
    cc->Outputs().Tag("PRESENCE").AddPacket(
        MakePacket<bool>(packet_present).At(cc->InputTimestamp()));
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(PacketPresenceCalculator);
} // namespace mediapipe

View File

@ -1,85 +0,0 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <functional>
#include <string>
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/framework/tool/sink.h"
namespace mediapipe {
using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::Value;
namespace {
// Matches a packet that holds the bool `value` and is stamped with
// `timestamp`.
MATCHER_P2(BoolPacket, value, timestamp, "") {
  if (!Value(arg.template Get<bool>(), Eq(value))) {
    return false;
  }
  return Value(arg.Timestamp(), Eq(timestamp));
}
// Feeds a GateCalculator followed by a PacketPresenceCalculator and checks
// that "presence" reports false/true at the timestamps of the gated inputs.
// NOTE(review): the suite is named PreviousLoopbackCalculator, but the graph
// below exercises PacketPresenceCalculator.
TEST(PreviousLoopbackCalculator, CorrectTimestamps) {
  CalculatorGraphConfig graph_config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'allow'
input_stream: 'value'
node {
calculator: "GateCalculator"
input_stream: 'value'
input_stream: 'ALLOW:allow'
output_stream: 'gated_value'
}
node {
calculator: 'PacketPresenceCalculator'
input_stream: 'PACKET:gated_value'
output_stream: 'PRESENCE:presence'
}
)pb");

  // Collects every packet produced on the "presence" stream.
  std::vector<Packet> output_packets;
  tool::AddVectorSink("presence", &graph_config, &output_packets);

  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
  MP_ASSERT_OK(graph.StartRun({}));

  // Adds one "value" packet and one "allow" packet at the same timestamp.
  auto send_packet = [&graph](int value, bool allow, Timestamp timestamp) {
    Packet value_packet = MakePacket<int>(value).At(timestamp);
    MP_ASSERT_OK(graph.AddPacketToInputStream("value", value_packet));
    Packet allow_packet = MakePacket<bool>(allow).At(timestamp);
    MP_ASSERT_OK(graph.AddPacketToInputStream("allow", allow_packet));
  };

  // Gate closed: the value is dropped, so presence must be false.
  send_packet(10, false, Timestamp(10));
  MP_EXPECT_OK(graph.WaitUntilIdle());
  EXPECT_THAT(output_packets, ElementsAre(BoolPacket(false, Timestamp(10))));

  // Gate open: the value passes, so presence must be true.
  output_packets.clear();
  send_packet(20, true, Timestamp(11));
  MP_EXPECT_OK(graph.WaitUntilIdle());
  EXPECT_THAT(output_packets, ElementsAre(BoolPacket(true, Timestamp(11))));

  MP_EXPECT_OK(graph.CloseAllInputStreams());
  MP_EXPECT_OK(graph.WaitUntilDone());
}
} // namespace
} // namespace mediapipe

View File

@ -1,696 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/packet_resampler_calculator.h"
#include <memory>
namespace {
// Reflects `ts` once against the edges of the half-open interval
// [ts_min, ts_max). Values below ts_min are mirrored about the lower edge,
// values at or above ts_max are mirrored about the upper edge, and values
// already inside the interval are returned unchanged.
int64 ReflectBetween(int64 ts, int64 ts_min, int64 ts_max) {
  if (ts >= ts_min && ts < ts_max) {
    return ts;
  }
  // Outside the interval: pick the edge that was crossed. The lower edge
  // takes precedence, matching a "check lower bound first" evaluation.
  const int64 edge = (ts < ts_min) ? ts_min : ts_max;
  return 2 * edge - ts - 1;
}
// Factory for the secure random number generator used by the jitter
// strategies. This implementation does not construct a generator and always
// returns a null pointer; the strategies' Open() methods report an error when
// they receive null, which effectively disables the jitter option in order to
// maintain consistent security and consistent random seeding.
// `seed` is currently unused.
std::unique_ptr<RandomBase> CreateSecureRandom(const std::string& seed) {
  return std::unique_ptr<RandomBase>();
}
} // namespace
namespace mediapipe {
REGISTER_CALCULATOR(PacketResamplerCalculator);
namespace {
// Converts a duration in seconds into a TimestampDiff. The value is scaled by
// Timestamp::kTimestampUnitsPerSecond (timestamps are assumed to be in
// microseconds) and rounded with MathUtil::SafeRound.
TimestampDiff TimestampDiffFromSeconds(double seconds) {
  const auto units = seconds * Timestamp::kTimestampUnitsPerSecond;
  return TimestampDiff(MathUtil::SafeRound<int64, double>(units));
}
} // namespace
// Declares the calculator's streams and side packets.
//
// The data stream is the one tagged "DATA" or, when that tag is absent, the
// untagged stream at index 0; the output data stream has the same type as the
// input. Optional "VIDEO_HEADER" streams carry a VideoHeader and the optional
// "OPTIONS" side packet carries CalculatorOptions. When the node options set
// a non-zero jitter it must lie in (0, 1] and a std::string "SEED" side packet
// is required.
absl::Status PacketResamplerCalculator::GetContract(CalculatorContract* cc) {
  const auto& resampler_options =
      cc->Options<PacketResamplerCalculatorOptions>();

  // Picks the "DATA"-tagged entry, falling back to the untagged index 0.
  const auto find_data_id = [](const auto& collection) {
    CollectionItemId id = collection.GetId("DATA", 0);
    return id.IsValid() ? id : collection.GetId("", 0);
  };

  if (cc->InputSidePackets().HasTag("OPTIONS")) {
    cc->InputSidePackets().Tag("OPTIONS").Set<CalculatorOptions>();
  }

  // Input side: any payload type on the data stream.
  const CollectionItemId input_data_id = find_data_id(cc->Inputs());
  cc->Inputs().Get(input_data_id).SetAny();
  if (cc->Inputs().HasTag("VIDEO_HEADER")) {
    cc->Inputs().Tag("VIDEO_HEADER").Set<VideoHeader>();
  }

  // Output side: same payload type as the input data stream.
  const CollectionItemId output_data_id = find_data_id(cc->Outputs());
  cc->Outputs().Get(output_data_id).SetSameAs(&cc->Inputs().Get(input_data_id));
  if (cc->Outputs().HasTag("VIDEO_HEADER")) {
    cc->Outputs().Tag("VIDEO_HEADER").Set<VideoHeader>();
  }

  // Without jitter there is nothing more to declare.
  if (resampler_options.jitter() == 0.0) {
    return absl::OkStatus();
  }

  // Jitter needs a valid fraction and a seed for the random generator.
  RET_CHECK_GT(resampler_options.jitter(), 0.0);
  RET_CHECK_LE(resampler_options.jitter(), 1.0);
  RET_CHECK(cc->InputSidePackets().HasTag("SEED"));
  cc->InputSidePackets().Tag("SEED").Set<std::string>();
  return absl::OkStatus();
}
// Reads the effective options (node options merged with the "OPTIONS" side
// packet), caches them in member fields, validates the frame rate, forwards
// or updates the stream header as requested, and finally creates and opens
// the sampling strategy.
absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) {
  const auto resampler_options =
      tool::RetrieveOptions(cc->Options<PacketResamplerCalculatorOptions>(),
                            cc->InputSidePackets(), "OPTIONS");

  flush_last_packet_ = resampler_options.flush_last_packet();
  jitter_ = resampler_options.jitter();

  // Resolve the data streams: the "DATA" tag if present, else untagged
  // index 0.
  const auto find_data_id = [](const auto& collection) {
    CollectionItemId id = collection.GetId("DATA", 0);
    return id.IsValid() ? id : collection.GetId("", 0);
  };
  input_data_id_ = find_data_id(cc->Inputs());
  output_data_id_ = find_data_id(cc->Outputs());

  frame_rate_ = resampler_options.frame_rate();

  // Unset limits default to the widest possible range.
  if (resampler_options.has_start_time()) {
    start_time_ = Timestamp(resampler_options.start_time());
  } else {
    start_time_ = Timestamp::Min();
  }
  if (resampler_options.has_end_time()) {
    end_time_ = Timestamp(resampler_options.end_time());
  } else {
    end_time_ = Timestamp::Max();
  }
  round_limits_ = resampler_options.round_limits();

  // The frame_rate has a default value of -1.0, so the user must set it!
  RET_CHECK_LT(0, frame_rate_)
      << "The output frame rate must be greater than zero";
  RET_CHECK_LE(frame_rate_, Timestamp::kTimestampUnitsPerSecond)
      << "The output frame rate must be smaller than "
      << Timestamp::kTimestampUnitsPerSecond;

  // Period and jitter expressed in microseconds (truncated).
  frame_time_usec_ = static_cast<int64>(1000000.0 / frame_rate_);
  jitter_usec_ = static_cast<int64>(1000000.0 * jitter_ / frame_rate_);
  RET_CHECK_LE(jitter_usec_, frame_time_usec_);

  video_header_.frame_rate = frame_rate_;

  // Propagate the input header only if requested and actually present.
  const auto header_mode = resampler_options.output_header();
  if (header_mode != PacketResamplerCalculatorOptions::NONE &&
      !cc->Inputs().Get(input_data_id_).Header().IsEmpty()) {
    if (header_mode == PacketResamplerCalculatorOptions::UPDATE_VIDEO_HEADER) {
      // Copy the incoming VideoHeader but advertise the resampled rate.
      video_header_ =
          cc->Inputs().Get(input_data_id_).Header().Get<VideoHeader>();
      video_header_.frame_rate = frame_rate_;
      cc->Outputs()
          .Get(output_data_id_)
          .SetHeader(Adopt(new VideoHeader(video_header_)));
    } else {
      // Pass the input header through untouched.
      cc->Outputs()
          .Get(output_data_id_)
          .SetHeader(cc->Inputs().Get(input_data_id_).Header());
    }
  }

  strategy_ = GetSamplingStrategy(resampler_options);
  return strategy_->Open(cc);
}
// Handles one input set. A VideoHeader arriving on "VIDEO_HEADER" at
// Timestamp::PreStream() is cached (with frame_rate overwritten by the target
// rate); if no data packet accompanies it, processing stops there. Otherwise
// the packet is delegated to the sampling strategy and then remembered as
// last_packet_.
absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) {
  const bool at_prestream = cc->InputTimestamp() == Timestamp::PreStream();
  if (at_prestream && cc->Inputs().UsesTags() &&
      cc->Inputs().HasTag("VIDEO_HEADER") &&
      !cc->Inputs().Tag("VIDEO_HEADER").IsEmpty()) {
    video_header_ = cc->Inputs().Tag("VIDEO_HEADER").Get<VideoHeader>();
    video_header_.frame_rate = frame_rate_;
    const bool header_only = cc->Inputs().Get(input_data_id_).IsEmpty();
    if (header_only) {
      return absl::OkStatus();
    }
  }

  // Avoid MP_RETURN_IF_ERROR macro for external release.
  const absl::Status strategy_status = strategy_->Process(cc);
  if (!strategy_status.ok()) {
    return strategy_status;
  }

  // Only updated after the strategy ran, so strategies see the previous
  // packet in last_packet_ while processing the current one.
  last_packet_ = cc->Inputs().Get(input_data_id_).Value();
  return absl::OkStatus();
}
// Delegates to the strategy's Close() when the graph finished without error;
// on a failed graph nothing is flushed and OK is returned.
absl::Status PacketResamplerCalculator::Close(CalculatorContext* cc) {
  if (cc->GraphStatus().ok()) {
    return strategy_->Close(cc);
  }
  return absl::OkStatus();
}
// Chooses the sampling strategy for the given options:
//   - reproducible_sampling           -> ReproducibleJitterWithReflection
//     (a warning is logged if jitter_with_reflection is off, as the setting
//     is ignored in this mode)
//   - jitter == 0                     -> NoJitter
//   - jitter with jitter_with_reflection -> LegacyJitterWithReflection
//   - jitter without reflection       -> JitterWithoutReflection
std::unique_ptr<PacketResamplerStrategy>
PacketResamplerCalculator::GetSamplingStrategy(
    const PacketResamplerCalculatorOptions& options) {
  const bool with_reflection = options.jitter_with_reflection();

  if (!options.reproducible_sampling()) {
    if (options.jitter() == 0) {
      return absl::make_unique<NoJitterStrategy>(this);
    }
    if (with_reflection) {
      return absl::make_unique<LegacyJitterWithReflectionStrategy>(this);
    }
    // With jitter and no reflection.
    return absl::make_unique<JitterWithoutReflectionStrategy>(this);
  }

  if (!with_reflection) {
    LOG(WARNING)
        << "reproducible_sampling enabled w/ jitter_with_reflection "
           "disabled. "
        << "reproducible_sampling always uses jitter with reflection, "
        << "Ignoring jitter_with_reflection setting.";
  }
  return absl::make_unique<ReproducibleJitterWithReflectionStrategy>(this);
}
// Maps a period index to its output timestamp:
// first_timestamp_ + index / frame_rate_ seconds. Requires jitter to be
// disabled and first_timestamp_ to be set (both are CHECKed).
Timestamp PacketResamplerCalculator::PeriodIndexToTimestamp(int64 index) const {
  CHECK_EQ(jitter_, 0.0);
  CHECK_NE(first_timestamp_, Timestamp::Unset());
  const double offset_seconds = index / frame_rate_;
  return first_timestamp_ + TimestampDiffFromSeconds(offset_seconds);
}
// Maps a timestamp to a period index by rounding the number of periods
// elapsed since first_timestamp_. Requires jitter to be disabled and
// first_timestamp_ to be set (both are CHECKed).
int64 PacketResamplerCalculator::TimestampToPeriodIndex(
    Timestamp timestamp) const {
  CHECK_EQ(jitter_, 0.0);
  CHECK_NE(first_timestamp_, Timestamp::Unset());
  const double elapsed_periods =
      (timestamp - first_timestamp_).Seconds() * frame_rate_;
  return MathUtil::SafeRound<int64, double>(elapsed_periods);
}
// Emits `packet` on the data output stream if its timestamp lies in
// [start_time_ - margin, end_time_ + margin). The margin is half a frame
// period when round_limits_ is set and zero otherwise. Packets outside that
// range are silently dropped.
void PacketResamplerCalculator::OutputWithinLimits(CalculatorContext* cc,
                                                   const Packet& packet) const {
  TimestampDiff margin((round_limits_) ? frame_time_usec_ / 2 : 0);

  // Lower limit is checked first; the upper limit is only evaluated for
  // packets that pass it.
  if (packet.Timestamp() < start_time_ - margin) {
    return;
  }
  if (packet.Timestamp() >= end_time_ + margin) {
    return;
  }
  cc->Outputs().Get(output_data_id_).AddPacket(packet);
}
// Prepares the legacy jitter-with-reflection strategy: logs warnings for
// options that interact poorly with jitter, then builds the jitter random
// generator and a second, identically seeded generator that backs the packet
// reservoir. Fails with InvalidArgument when no generator can be created.
absl::Status LegacyJitterWithReflectionStrategy::Open(CalculatorContext* cc) {
  const auto resampler_options =
      tool::RetrieveOptions(cc->Options<PacketResamplerCalculatorOptions>(),
                            cc->InputSidePackets(), "OPTIONS");

  const bool emits_header = resampler_options.output_header() !=
                            PacketResamplerCalculatorOptions::NONE;
  if (emits_header) {
    LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not "
                    "the actual value.";
  }
  if (calculator_->flush_last_packet_) {
    LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is "
                    "ignored, because we are adding jitter.";
  }

  const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
  random_ = CreateSecureRandom(seed);
  if (!random_) {
    return absl::InvalidArgumentError(
        "SecureRandom is not available.  With \"jitter\" specified, "
        "PacketResamplerCalculator processing cannot proceed.");
  }

  // The reservoir borrows its generator; the strategy keeps ownership.
  packet_reservoir_random_ = CreateSecureRandom(seed);
  packet_reservoir_ =
      std::make_unique<PacketReservoir>(packet_reservoir_random_.get());
  return absl::OkStatus();
}
// Flushes the reservoir at end of stream: if any sample was collected since
// the last emission, the sampled packet is output (subject to the
// calculator's start/end limits).
absl::Status LegacyJitterWithReflectionStrategy::Close(CalculatorContext* cc) {
  if (packet_reservoir_->IsEmpty()) {
    return absl::OkStatus();
  }
  LOG(INFO) << "Emitting pack from reservoir.";
  calculator_->OutputWithinLimits(cc, packet_reservoir_->GetSample());
  return absl::OkStatus();
}
// Handles one input packet for the legacy jitter-with-reflection strategy.
//
// The packet is first offered to the reservoir (used by Close() to emit a
// sample for a trailing partial period). The first packet of the stream only
// initializes the timestamps. Every later packet triggers emission of one
// output packet per pending next_output_timestamp_ that is at or before the
// input timestamp, choosing between the previous packet
// (calculator_->last_packet_) and the current one.
absl::Status LegacyJitterWithReflectionStrategy::Process(
CalculatorContext* cc) {
RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());
// Sample the packet into the reservoir while it is enabled and the packet is
// either the very first one or not older than the current period's lower
// bound.
if (packet_reservoir_->IsEnabled() &&
(first_timestamp_ == Timestamp::Unset() ||
(cc->InputTimestamp() - next_output_timestamp_min_).Value() >= 0)) {
auto curr_packet = cc->Inputs().Get(calculator_->input_data_id_).Value();
packet_reservoir_->AddSample(curr_packet);
}
// First packet: record the stream start and draw the first output
// timestamp. The packet itself is emitted only if the drawn timestamp
// coincides with it.
if (first_timestamp_ == Timestamp::Unset()) {
first_timestamp_ = cc->InputTimestamp();
InitializeNextOutputTimestampWithJitter();
if (first_timestamp_ == next_output_timestamp_) {
calculator_->OutputWithinLimits(cc, cc->Inputs()
.Get(calculator_->input_data_id_)
.Value()
.At(next_output_timestamp_));
UpdateNextOutputTimestampWithJitter();
}
return absl::OkStatus();
}
// A gap between consecutive inputs larger than one output period means the
// stream is being upsampled.
if (calculator_->frame_time_usec_ <
(cc->InputTimestamp() - calculator_->last_packet_.Timestamp()).Value()) {
LOG_FIRST_N(WARNING, 2)
<< "Adding jitter is not very useful when upsampling.";
}
// Emit for every pending output timestamp that the current input has
// reached or passed.
while (true) {
// Distance from the previous packet to the pending output timestamp; it
// must be strictly positive.
const int64 last_diff =
(next_output_timestamp_ - calculator_->last_packet_.Timestamp())
.Value();
RET_CHECK_GT(last_diff, 0);
// Positive while the pending output timestamp is still ahead of the input.
const int64 curr_diff =
(next_output_timestamp_ - cc->InputTimestamp()).Value();
if (curr_diff > 0) {
break;
}
// Emit whichever packet is closer to the output timestamp; on a tie the
// current packet is used. The packet is re-stamped to the output timestamp.
calculator_->OutputWithinLimits(
cc, (std::abs(curr_diff) > last_diff
? calculator_->last_packet_
: cc->Inputs().Get(calculator_->input_data_id_).Value())
.At(next_output_timestamp_));
UpdateNextOutputTimestampWithJitter();
// From now on every time a packet is emitted the timestamp of the next
// packet becomes known; that timestamp is stored in next_output_timestamp_.
// The only exception to this rule is the packet emitted from Close() which
// can only happen when jitter_with_reflection is enabled but in this case
// next_output_timestamp_min_ is a non-decreasing lower bound of any
// subsequent packet.
const Timestamp timestamp_bound = next_output_timestamp_min_;
cc->Outputs()
.Get(calculator_->output_data_id_)
.SetNextTimestampBound(timestamp_bound);
}
return absl::OkStatus();
}
// Sets up the first sampling period: its lower bound is the first input
// timestamp, and the first output timestamp is that bound plus a random
// offset drawn with UnbiasedUniform64(frame_time_usec_).
void LegacyJitterWithReflectionStrategy::
    InitializeNextOutputTimestampWithJitter() {
  const auto initial_offset =
      random_->UnbiasedUniform64(calculator_->frame_time_usec_);
  next_output_timestamp_min_ = first_timestamp_;
  next_output_timestamp_ = first_timestamp_ + initial_offset;
}
void LegacyJitterWithReflectionStrategy::UpdateNextOutputTimestampWithJitter() {
packet_reservoir_->Clear();
next_output_timestamp_min_ += calculator_->frame_time_usec_;
Timestamp next_output_timestamp_max_ =
next_output_timestamp_min_ + calculator_->frame_time_usec_;
next_output_timestamp_ +=
calculator_->frame_time_usec_ +
random_->UnbiasedUniform64(2 * calculator_->jitter_usec_ + 1) -
calculator_->jitter_usec_;
next_output_timestamp_ = Timestamp(ReflectBetween(
next_output_timestamp_.Value(), next_output_timestamp_min_.Value(),
next_output_timestamp_max_.Value()));
CHECK_GE(next_output_timestamp_, next_output_timestamp_min_);
CHECK_LT(next_output_timestamp_, next_output_timestamp_max_);
}
// Prepares the reproducible jitter strategy: logs warnings for options that
// interact poorly with jitter and creates the random generator from the
// "SEED" side packet. Fails with InvalidArgument when no generator can be
// created.
absl::Status ReproducibleJitterWithReflectionStrategy::Open(
    CalculatorContext* cc) {
  const auto resampler_options =
      tool::RetrieveOptions(cc->Options<PacketResamplerCalculatorOptions>(),
                            cc->InputSidePackets(), "OPTIONS");

  const bool emits_header = resampler_options.output_header() !=
                            PacketResamplerCalculatorOptions::NONE;
  if (emits_header) {
    LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not "
                    "the actual value.";
  }
  if (calculator_->flush_last_packet_) {
    LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is "
                    "ignored, because we are adding jitter.";
  }

  const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
  random_ = CreateSecureRandom(seed);
  if (random_) {
    return absl::OkStatus();
  }
  return absl::InvalidArgumentError(
      "SecureRandom is not available.  With \"jitter\" specified, "
      "PacketResamplerCalculator processing cannot proceed.");
}
// End-of-stream flush: when at least one packet was received and nothing has
// been emitted for the current period, the last received packet is output at
// next_output_timestamp_ (subject to the calculator's start/end limits).
absl::Status ReproducibleJitterWithReflectionStrategy::Close(
    CalculatorContext* cc) {
  if (calculator_->last_packet_.IsEmpty()) {
    return absl::OkStatus();
  }
  if (packet_emitted_this_period_) {
    return absl::OkStatus();
  }
  calculator_->OutputWithinLimits(
      cc, calculator_->last_packet_.At(next_output_timestamp_));
  return absl::OkStatus();
}
// Handles one input packet for the reproducible jitter strategy.
//
// The first packet initializes the output timestamp for the period the packet
// falls in and may emit at most one packet. Every later packet advances the
// period bookkeeping and emits, for each period that has not produced output
// yet and whose output timestamp has been reached, whichever of the previous
// packet (calculator_->last_packet_) and the current packet is closer.
// Finally the output stream's timestamp bound is raised to
// next_output_timestamp_ if it is still below it.
absl::Status ReproducibleJitterWithReflectionStrategy::Process(
CalculatorContext* cc) {
RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());
Packet current_packet = cc->Inputs().Get(calculator_->input_data_id_).Value();
if (calculator_->last_packet_.IsEmpty()) {
// last_packet is empty, this is the first packet of the stream.
InitializeNextOutputTimestamp(current_packet.Timestamp());
// If next_output_timestamp_ happens to fall before current_packet, emit
// current packet. Only a single packet can be emitted at the beginning
// of the stream.
if (next_output_timestamp_ < current_packet.Timestamp()) {
calculator_->OutputWithinLimits(
cc, current_packet.At(next_output_timestamp_));
packet_emitted_this_period_ = true;
}
return absl::OkStatus();
}
// Last packet is set, so we are mid-stream.
if (calculator_->frame_time_usec_ <
(current_packet.Timestamp() - calculator_->last_packet_.Timestamp())
.Value()) {
// Note, if the stream is upsampling, this could lead to the same packet
// being emitted twice. Upsampling and jitter doesn't make much sense
// but does technically work.
LOG_FIRST_N(WARNING, 2)
<< "Adding jitter is not very useful when upsampling.";
}
// Since we may be upsampling, we need to iteratively advance the
// next_output_timestamp_ one period at a time until it reaches the period
// current_packet is in. During this process, last_packet and/or
// current_packet may be repeatedly emitted.
UpdateNextOutputTimestamp(current_packet.Timestamp());
while (!packet_emitted_this_period_ &&
next_output_timestamp_ <= current_packet.Timestamp()) {
// last_packet < next_output_timestamp_ <= current_packet,
// so emit the closest packet. On a tie (equal distances) the strict `<`
// below selects last_packet_.
Packet packet_to_emit =
current_packet.Timestamp() - next_output_timestamp_ <
next_output_timestamp_ - calculator_->last_packet_.Timestamp()
? current_packet
: calculator_->last_packet_;
calculator_->OutputWithinLimits(cc,
packet_to_emit.At(next_output_timestamp_));
packet_emitted_this_period_ = true;
// If we are upsampling, packet_emitted_this_period_ can be reset by
// the following UpdateNext and the loop will iterate.
UpdateNextOutputTimestamp(current_packet.Timestamp());
}
// Set the bounds on the output stream. Note, if we emitted a packet
// above, it will already be set at next_output_timestamp_ + 1, in which
// case we have to skip setting it.
if (cc->Outputs().Get(calculator_->output_data_id_).NextTimestampBound() <
next_output_timestamp_) {
cc->Outputs()
.Get(calculator_->output_data_id_)
.SetNextTimestampBound(next_output_timestamp_);
}
return absl::OkStatus();
}
// One-time initialization of the period bookkeeping. Sampling always starts
// from period zero (lower bound Timestamp(0)) with a random first output
// timestamp, and is then advanced period by period until the period that
// contains `current_timestamp` is reached. Stepping through every
// intermediate period consumes the same random draws regardless of where the
// stream starts. Does nothing if already initialized.
void ReproducibleJitterWithReflectionStrategy::InitializeNextOutputTimestamp(
    Timestamp current_timestamp) {
  const bool already_initialized =
      next_output_timestamp_min_ != Timestamp::Unset();
  if (already_initialized) {
    return;
  }

  next_output_timestamp_min_ = Timestamp(0);
  next_output_timestamp_ =
      Timestamp(GetNextRandom(calculator_->frame_time_usec_));

  // While the current timestamp is ahead of the max (i.e. min + frame_time),
  // fast-forward.
  for (;;) {
    const bool in_current_period =
        current_timestamp <
        next_output_timestamp_min_ + calculator_->frame_time_usec_;
    if (in_current_period) {
      break;
    }
    // Pretend the period produced output so the update actually advances.
    packet_emitted_this_period_ = true;
    UpdateNextOutputTimestamp(current_timestamp);
  }
}
// Advances to the next sampling period, but only when the current period has
// already produced output and `current_timestamp` lies at or beyond the end
// of the current period. The new output timestamp is the old one plus one
// frame time plus a jitter drawn with GetNextRandom(2 * jitter_usec_ + 1) and
// shifted down by jitter_usec_, reflected back into the new period. The
// "emitted" flag is cleared for the new period.
void ReproducibleJitterWithReflectionStrategy::UpdateNextOutputTimestamp(
    Timestamp current_timestamp) {
  if (!packet_emitted_this_period_) {
    return;
  }
  if (current_timestamp <
      next_output_timestamp_min_ + calculator_->frame_time_usec_) {
    return;
  }

  next_output_timestamp_min_ += calculator_->frame_time_usec_;
  const Timestamp period_end =
      next_output_timestamp_min_ + calculator_->frame_time_usec_;

  const auto jitter_draw = GetNextRandom(2 * calculator_->jitter_usec_ + 1);
  next_output_timestamp_ +=
      calculator_->frame_time_usec_ + jitter_draw - calculator_->jitter_usec_;

  const int64 reflected = ReflectBetween(next_output_timestamp_.Value(),
                                         next_output_timestamp_min_.Value(),
                                         period_end.Value());
  next_output_timestamp_ = Timestamp(reflected);
  packet_emitted_this_period_ = false;
}
// Prepares the jitter-without-reflection strategy: logs warnings for options
// that interact poorly with jitter, then builds the jitter random generator
// and a second, identically seeded generator that backs the packet reservoir.
// Fails with InvalidArgument when no generator can be created.
absl::Status JitterWithoutReflectionStrategy::Open(CalculatorContext* cc) {
  const auto resampler_options =
      tool::RetrieveOptions(cc->Options<PacketResamplerCalculatorOptions>(),
                            cc->InputSidePackets(), "OPTIONS");

  const bool emits_header = resampler_options.output_header() !=
                            PacketResamplerCalculatorOptions::NONE;
  if (emits_header) {
    LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not "
                    "the actual value.";
  }
  if (calculator_->flush_last_packet_) {
    LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is "
                    "ignored, because we are adding jitter.";
  }

  const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
  random_ = CreateSecureRandom(seed);
  if (!random_) {
    return absl::InvalidArgumentError(
        "SecureRandom is not available.  With \"jitter\" specified, "
        "PacketResamplerCalculator processing cannot proceed.");
  }

  // The reservoir borrows its generator; the strategy keeps ownership.
  packet_reservoir_random_ = CreateSecureRandom(seed);
  packet_reservoir_ =
      absl::make_unique<PacketReservoir>(packet_reservoir_random_.get());
  return absl::OkStatus();
}
// Flushes the reservoir at end of stream: if it holds a sample, that packet
// is output (subject to the calculator's start/end limits).
absl::Status JitterWithoutReflectionStrategy::Close(CalculatorContext* cc) {
  const bool has_sample = !packet_reservoir_->IsEmpty();
  if (has_sample) {
    calculator_->OutputWithinLimits(cc, packet_reservoir_->GetSample());
  }
  return absl::OkStatus();
}
// Handles one input packet for the jitter-without-reflection strategy.
//
// The packet is first offered to the reservoir. The first packet of the
// stream only initializes the timestamps (stored on the calculator). Every
// later packet triggers emission of one output packet per pending
// next_output_timestamp_ that is at or before the input timestamp, choosing
// between the previous packet (calculator_->last_packet_) and the current one.
absl::Status JitterWithoutReflectionStrategy::Process(CalculatorContext* cc) {
RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());
// Packet reservoir is used to make sure there's an output for every period,
// e.g. partial period at the end of the stream.
if (packet_reservoir_->IsEnabled() &&
(calculator_->first_timestamp_ == Timestamp::Unset() ||
(cc->InputTimestamp() - next_output_timestamp_min_).Value() >= 0)) {
auto curr_packet = cc->Inputs().Get(calculator_->input_data_id_).Value();
packet_reservoir_->AddSample(curr_packet);
}
// First packet: record the stream start and draw the first output
// timestamp. The packet itself is emitted only if the drawn timestamp
// coincides with it.
if (calculator_->first_timestamp_ == Timestamp::Unset()) {
calculator_->first_timestamp_ = cc->InputTimestamp();
InitializeNextOutputTimestamp();
if (calculator_->first_timestamp_ == next_output_timestamp_) {
calculator_->OutputWithinLimits(cc, cc->Inputs()
.Get(calculator_->input_data_id_)
.Value()
.At(next_output_timestamp_));
UpdateNextOutputTimestamp();
}
return absl::OkStatus();
}
// A gap between consecutive inputs larger than one output period means the
// stream is being upsampled.
if (calculator_->frame_time_usec_ <
(cc->InputTimestamp() - calculator_->last_packet_.Timestamp()).Value()) {
LOG_FIRST_N(WARNING, 2)
<< "Adding jitter is not very useful when upsampling.";
}
// Emit for every pending output timestamp that the current input has
// reached or passed.
while (true) {
// Distance from the previous packet to the pending output timestamp; it
// must be strictly positive.
const int64 last_diff =
(next_output_timestamp_ - calculator_->last_packet_.Timestamp())
.Value();
RET_CHECK_GT(last_diff, 0);
// Positive while the pending output timestamp is still ahead of the input.
const int64 curr_diff =
(next_output_timestamp_ - cc->InputTimestamp()).Value();
if (curr_diff > 0) {
break;
}
// Emit whichever packet is closer to the output timestamp; on a tie the
// current packet is used. The packet is re-stamped to the output timestamp.
calculator_->OutputWithinLimits(
cc, (std::abs(curr_diff) > last_diff
? calculator_->last_packet_
: cc->Inputs().Get(calculator_->input_data_id_).Value())
.At(next_output_timestamp_));
UpdateNextOutputTimestamp();
// The next output timestamp is now known, so it is a valid bound for the
// output stream.
cc->Outputs()
.Get(calculator_->output_data_id_)
.SetNextTimestampBound(next_output_timestamp_);
}
return absl::OkStatus();
}
// Sets up the first sampling period: its lower bound is the calculator's
// first input timestamp, and the first output timestamp is that bound plus
// frame_time_usec_ scaled by a RandFloat() draw.
void JitterWithoutReflectionStrategy::InitializeNextOutputTimestamp() {
  const auto initial_offset =
      calculator_->frame_time_usec_ * random_->RandFloat();
  next_output_timestamp_min_ = calculator_->first_timestamp_;
  next_output_timestamp_ = calculator_->first_timestamp_ + initial_offset;
}
// Moves next_output_timestamp_ forward by one frame time scaled by a factor
// of (1 - jitter) + 2 * jitter * RandFloat(). The reservoir is cleared and
// disabled, so it no longer accepts samples once this has run.
void JitterWithoutReflectionStrategy::UpdateNextOutputTimestamp() {
  packet_reservoir_->Clear();
  packet_reservoir_->Disable();

  const auto period_scale =
      ((1.0 - calculator_->jitter_) +
       2.0 * calculator_->jitter_ * random_->RandFloat());
  next_output_timestamp_ += calculator_->frame_time_usec_ * period_scale;
}
// Reads the optional base_timestamp option (Timestamp::Unset() when absent)
// and resets the period counter.
absl::Status NoJitterStrategy::Open(CalculatorContext* cc) {
  const auto resampler_options =
      tool::RetrieveOptions(cc->Options<PacketResamplerCalculatorOptions>(),
                            cc->InputSidePackets(), "OPTIONS");

  if (resampler_options.has_base_timestamp()) {
    base_timestamp_ = Timestamp(resampler_options.base_timestamp());
  } else {
    base_timestamp_ = Timestamp::Unset();
  }
  period_count_ = 0;
  return absl::OkStatus();
}
// End-of-stream flush. When flush_last_packet is enabled, at least one packet
// was received, and the last received packet belongs to the period that has
// not produced output yet (period_count_), that packet is emitted at the
// period's timestamp.
absl::Status NoJitterStrategy::Close(CalculatorContext* cc) {
  if (calculator_->first_timestamp_ == Timestamp::Unset()) {
    return absl::OkStatus();
  }
  if (!calculator_->flush_last_packet_) {
    return absl::OkStatus();
  }

  const int64 last_packet_period = calculator_->TimestampToPeriodIndex(
      calculator_->last_packet_.Timestamp());
  if (last_packet_period == period_count_) {
    const Timestamp period_timestamp =
        calculator_->PeriodIndexToTimestamp(period_count_);
    calculator_->OutputWithinLimits(
        cc, calculator_->last_packet_.At(period_timestamp));
  }
  return absl::OkStatus();
}
// Handles one input packet for the jitter-free strategy.
//
// The first packet fixes calculator_->first_timestamp_ (optionally aligned to
// base_timestamp_) and, if a "VIDEO_HEADER" output exists, emits the video
// header at Timestamp::PreStream(). Each packet is then mapped to a period
// index; empty periods before it are filled with copies of the previous
// packet, and if the packet is at or past the current period's target
// timestamp, either it or the previous packet (whichever is closer) is
// emitted for that period. Packets belonging to an already completed period
// are ignored.
absl::Status NoJitterStrategy::Process(CalculatorContext* cc) {
RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());
if (calculator_->first_timestamp_ == Timestamp::Unset()) {
// This is the first packet, initialize the first_timestamp_.
if (base_timestamp_ == Timestamp::Unset()) {
// Initialize first_timestamp_ with exactly the first packet timestamp.
calculator_->first_timestamp_ = cc->InputTimestamp();
} else {
// Initialize first_timestamp_ with the first packet timestamp
// aligned to the base_timestamp_.
int64 first_index = MathUtil::SafeRound<int64, double>(
(cc->InputTimestamp() - base_timestamp_).Seconds() *
calculator_->frame_rate_);
calculator_->first_timestamp_ =
base_timestamp_ +
TimestampDiffFromSeconds(first_index / calculator_->frame_rate_);
}
// The video header is emitted once, together with the first packet.
if (cc->Outputs().UsesTags() && cc->Outputs().HasTag("VIDEO_HEADER")) {
cc->Outputs()
.Tag("VIDEO_HEADER")
.Add(new VideoHeader(calculator_->video_header_),
Timestamp::PreStream());
}
}
const Timestamp received_timestamp = cc->InputTimestamp();
const int64 received_timestamp_idx =
calculator_->TimestampToPeriodIndex(received_timestamp);
// Only consider the received packet if it belongs to the current period
// (== period_count_) or to a newer one (> period_count_).
if (received_timestamp_idx >= period_count_) {
// Fill the empty periods until we are in the same index as the received
// packet.
while (received_timestamp_idx > period_count_) {
calculator_->OutputWithinLimits(
cc, calculator_->last_packet_.At(
calculator_->PeriodIndexToTimestamp(period_count_)));
++period_count_;
}
// Now, if the received packet has a timestamp larger than the middle of
// the current period, we can send a packet without waiting. We send the
// one closer to the middle.
Timestamp target_timestamp =
calculator_->PeriodIndexToTimestamp(period_count_);
if (received_timestamp >= target_timestamp) {
bool have_last_packet =
(calculator_->last_packet_.Timestamp() != Timestamp::Unset());
// The current packet wins when there is no previous packet or when it is
// at least as close to the target as the previous one (ties go to the
// current packet).
bool send_current =
!have_last_packet ||
(received_timestamp - target_timestamp <=
target_timestamp - calculator_->last_packet_.Timestamp());
if (send_current) {
calculator_->OutputWithinLimits(cc,
cc->Inputs()
.Get(calculator_->input_data_id_)
.Value()
.At(target_timestamp));
} else {
calculator_->OutputWithinLimits(
cc, calculator_->last_packet_.At(target_timestamp));
}
++period_count_;
}
// TODO: Add a mechanism to the framework to allow these packets
// to be output earlier (without waiting for a much later packet to
// arrive)
// Update the bound for the next packet.
cc->Outputs()
.Get(calculator_->output_data_id_)
.SetNextTimestampBound(
calculator_->PeriodIndexToTimestamp(period_count_));
}
return absl::OkStatus();
}
} // namespace mediapipe

View File

@ -1,392 +0,0 @@
#ifndef MEDIAPIPE_CALCULATORS_CORE_PACKET_RESAMPLER_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_PACKET_RESAMPLER_CALCULATOR_H_
#include <cstdlib>
#include <memory>
#include <string>
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/core/packet_resampler_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/deps/mathutil.h"
#include "mediapipe/framework/deps/random_base.h"
#include "mediapipe/framework/formats/video_stream_header.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/tool/options_util.h"
namespace mediapipe {
// Single-slot reservoir sampler over a sequence of packets.
//
// Each call to AddSample() may replace the stored packet so that, after k
// samples since the last Clear(), the stored packet is intended to be a
// uniformly random pick among them. The random generator is borrowed, not
// owned, and must outlive the reservoir. A reservoir built with a null
// generator is never enabled and ignores all samples.
class PacketReservoir {
 public:
  PacketReservoir(RandomBase* rng) : rng_(rng) {}

  // Replace candidate with current packet with 1/count_ probability.
  // Does nothing when no random generator was supplied, mirroring the null
  // check in IsEnabled() instead of dereferencing a null pointer.
  void AddSample(Packet sample) {
    if (rng_ == nullptr) {
      return;
    }
    if (rng_->UnbiasedUniform(++count_) == 0) {
      reservoir_ = sample;
    }
  }

  // True while a random generator is available and Disable() was not called.
  bool IsEnabled() const { return rng_ != nullptr && enabled_; }

  // Permanently turns the reservoir off (there is no way to re-enable it).
  void Disable() { enabled_ = false; }

  // Forgets how many samples were seen. The stored packet itself is left in
  // place until a later AddSample() replaces it.
  void Clear() { count_ = 0; }

  // True when no sample has been added since construction or the last Clear().
  bool IsEmpty() const { return count_ == 0; }

  // Returns a copy of the currently stored packet.
  Packet GetSample() const { return reservoir_; }

 private:
  // Borrowed random generator; may be null.
  RandomBase* rng_;
  bool enabled_ = true;
  // Number of samples offered since the last Clear().
  int32 count_ = 0;
  // The currently selected packet.
  Packet reservoir_;
};
// This calculator is used to normalize the frequency of the packets
// out of a stream. Given a desired frame rate, packets are going to be
// removed or added to achieve it.
//
// If jitter_ is specified:
// - The first packet is chosen randomly (uniform distribution) among frames
// that correspond to timestamps [0, 1/frame_rate). Let the chosen packet
// correspond to timestamp t.
// - The next packet is chosen randomly (uniform distribution) among frames
// that correspond to [t+(1-jitter)/frame_rate, t+(1+jitter)/frame_rate].
// - if jitter_with_reflection is true, the timestamp will be reflected
// against the boundaries of [t_0 + (k-1)/frame_rate, t_0 + k/frame_rate)
// so that its marginal distribution is uniform within this interval.
// In the formula, t_0 is the timestamp of the first sampled
// packet, and the k is the packet index.
// See paper (https://arxiv.org/abs/2002.01147) for details.
// - t is updated and the process is repeated.
// - Note that seed is specified as input side packet for reproducibility of
// the resampling. For Cloud ML Video Intelligence API, the hash of the
// input video should serve this purpose. For YouTube, either video ID or
// content hex ID of the input video should do.
// - If reproducible_sampling is true, care is taken to allow reproducible
// "mid-stream" sampling. The calculator can be executed on a stream that
// doesn't start at the first period. For instance, if the calculator
// is run on a 10 second stream it will produce the same set of samples
// as two runs of the calculator, the first with 3 seconds of input starting
// at time 0 and the second with 7 seconds of input starting at time +3s.
// - In order to guarantee the exact same samples, 1) the inputs must be
// aligned with the sampling period. For instance, if the sampling rate
// is 2 frames per second, streams should be aligned on 0.5 second
// boundaries, and 2) the stream must include at least one extra packet
// before and after the second aligned sampling period.
//
// If jitter_ is not specified:
// - The first packet defines the first_timestamp of the output stream,
// so it is always emitted.
// - If more packets are emitted, they will have timestamp equal to
// round(first_timestamp + k * period) , where k is a positive
// integer and the period is defined by the frame rate.
// Example: first_timestamp=0, fps=30, then the output stream
// will have timestamps: 0, 33333, 66667, 100000, etc...
// - The packets selected for the output stream are the ones closer
// to the exact middle point (33333.33, 66666.67 in our previous
// example). In case of ties, later packets are chosen.
// - 'Empty' periods happen when there are no packets for a long time
// (greater than a period). In this case, we send a copy of the last
// packet received before the empty period.
// The jitter feature is disabled by default. To enable it, you need to
// implement CreateSecureRandom(const std::string&).
//
// The data stream may be either specified as the only stream (by index)
// or as the stream with tag "DATA".
//
// The input and output streams may be accompanied by a VIDEO_HEADER
// stream. This stream includes a VideoHeader at Timestamp::PreStream().
// The input VideoHeader on the VIDEO_HEADER stream will always be updated
// with the resampler frame rate no matter what the options value for
// output_header is before being output on the output VIDEO_HEADER stream.
// If the input VideoHeader is not available, then only the frame rate
// value will be set in the output.
//
// Related:
// packet_downsampler_calculator.cc: skips packets regardless of timestamps.
class PacketResamplerCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc);
  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Close(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;

  // Given the current count of periods that have passed, this returns
  // the next valid timestamp of the middle point of the next period:
  //    if count is 0, it returns the first_timestamp_.
  //    if count is 1, it returns the first_timestamp_ + period (corresponding
  //       to the first tick using exact fps)
  // e.g. for frame_rate=30 and first_timestamp_=0:
  //    0: 0
  //    1: 33333
  //    2: 66667
  //    3: 100000
  //
  // Can only be used if jitter_ equals zero.
  Timestamp PeriodIndexToTimestamp(int64 index) const;

  // Given a Timestamp, finds the closest sync Timestamp based on
  // first_timestamp_ and the desired fps.
  //
  // Can only be used if jitter_ equals zero.
  int64 TimestampToPeriodIndex(Timestamp timestamp) const;

  // Outputs a packet if its timestamp is in the half-open range
  // [start_time_ - margin, end_time_ + margin), where margin is half a frame
  // period when round_limits_ is set and zero otherwise.
  void OutputWithinLimits(CalculatorContext* cc, const Packet& packet) const;

 protected:
  // Returns Sampling Strategy to use.
  //
  // Virtual to allow injection of testing strategies.
  virtual std::unique_ptr<class PacketResamplerStrategy> GetSamplingStrategy(
      const mediapipe::PacketResamplerCalculatorOptions& options);

 private:
  // Strategy selected in Open(); all per-packet work is delegated to it.
  std::unique_ptr<class PacketResamplerStrategy> strategy_;

  // The timestamp of the first packet received.
  Timestamp first_timestamp_;

  // Number of frames per second (desired output frequency). Overwritten from
  // the options in Open(); a non-positive value is rejected there.
  double frame_rate_ = -1.0;

  // Duration of one output period in microseconds, i.e. 1e6 / frame_rate_
  // truncated to an integer. Set in Open().
  int64 frame_time_usec_ = 0;

  // Header emitted on the VIDEO_HEADER output; its frame_rate is always the
  // resampler frame rate.
  VideoHeader video_header_;

  // The "DATA" input stream.
  CollectionItemId input_data_id_;
  // The "DATA" output stream.
  CollectionItemId output_data_id_;

  // Indicator whether to flush last packet even if its timestamp is greater
  // than the final stream timestamp.
  bool flush_last_packet_ = false;

  // Jitter as a fraction of the frame period, and the same amount expressed
  // in microseconds (jitter_ * 1e6 / frame_rate_, truncated). Set in Open().
  double jitter_ = 0.0;
  int64 jitter_usec_ = 0;

  // The last packet that was received.
  Packet last_packet_;

  // If specified, only outputs at/after start_time are included.
  Timestamp start_time_;

  // If specified, only outputs before end_time are included.
  Timestamp end_time_;

  // If set, the output timestamps nearest to start_time and end_time
  // are included in the output, even if the nearest timestamp is not
  // between start_time and end_time.
  bool round_limits_ = false;

  // Allow strategies access to all internal calculator state.
  //
  // The calculator and strategies are intimately tied together so this should
  // not break encapsulation.
  friend class LegacyJitterWithReflectionStrategy;
  friend class ReproducibleJitterWithReflectionStrategy;
  friend class JitterWithoutReflectionStrategy;
  friend class NoJitterStrategy;
};
// Abstract class encapsulating sampling strategy.
//
// These are used solely by PacketResamplerCalculator, but are exposed here
// to facilitate tests.
class PacketResamplerStrategy {
 public:
  // `calculator` is the calculator running this strategy. It is stored as a
  // raw, non-owning pointer. The constructor is explicit so that a calculator
  // pointer never converts to a strategy implicitly.
  explicit PacketResamplerStrategy(PacketResamplerCalculator* calculator)
      : calculator_(calculator) {}
  virtual ~PacketResamplerStrategy() = default;

  // Delegate for CalculatorBase::Open.  See CalculatorBase for relevant
  // implementation considerations.
  virtual absl::Status Open(CalculatorContext* cc) = 0;
  // Delegate for CalculatorBase::Close.  See CalculatorBase for relevant
  // implementation considerations.
  virtual absl::Status Close(CalculatorContext* cc) = 0;
  // Delegate for CalculatorBase::Process.  See CalculatorBase for relevant
  // implementation considerations.
  virtual absl::Status Process(CalculatorContext* cc) = 0;

 protected:
  // Calculator running strategy.
  PacketResamplerCalculator* calculator_;
};
// Strategy that applies Jitter with reflection based sampling.
//
// Used by PacketResamplerCalculator when both Jitter and reflection are
// enabled.
//
// This applies the legacy jitter with reflection which doesn't allow
// for reproducibility of sampling when starting mid-stream. This is maintained
// for backward compatibility.
class LegacyJitterWithReflectionStrategy : public PacketResamplerStrategy {
 public:
  // `calculator` is forwarded to PacketResamplerStrategy, which stores it.
  // Explicit to rule out implicit pointer-to-strategy conversions.
  explicit LegacyJitterWithReflectionStrategy(
      PacketResamplerCalculator* calculator)
      : PacketResamplerStrategy(calculator) {}

  // PacketResamplerStrategy overrides; see the base class for the contract.
  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Close(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;

 private:
  // Helpers that set up and then advance next_output_timestamp_.
  void InitializeNextOutputTimestampWithJitter();
  void UpdateNextOutputTimestampWithJitter();

  // Jitter-related variables.
  std::unique_ptr<RandomBase> random_;

  // The timestamp of the first packet received.
  Timestamp first_timestamp_;

  // Next packet to be emitted. Since packets may not align perfectly with
  // next_output_timestamp_, the closest packet will be emitted.
  Timestamp next_output_timestamp_;

  // Lower bound for next timestamp.
  //
  // next_output_timestamp_ will be kept within the interval
  // [next_output_timestamp_min_, next_output_timestamp_min_ + frame_time_usec_)
  Timestamp next_output_timestamp_min_ = Timestamp::Unset();

  // Packet reservoir used for sampling a random packet out of a partial
  // period when jitter is enabled.
  std::unique_ptr<PacketReservoir> packet_reservoir_;

  // Random number generator used in packet_reservoir_.
  std::unique_ptr<RandomBase> packet_reservoir_random_;
};
// Strategy that applies reproducible jitter with reflection based sampling.
//
// Used by PacketResamplerCalculator when both Jitter and reflection are
// enabled.
class ReproducibleJitterWithReflectionStrategy
    : public PacketResamplerStrategy {
 public:
  // `calculator` is forwarded to PacketResamplerStrategy, which stores it.
  // Explicit to rule out implicit pointer-to-strategy conversions.
  explicit ReproducibleJitterWithReflectionStrategy(
      PacketResamplerCalculator* calculator)
      : PacketResamplerStrategy(calculator) {}

  // PacketResamplerStrategy overrides; see the base class for the contract.
  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Close(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;

 protected:
  // Returns the next random value by delegating to
  // random_->UnbiasedUniform64(n).
  //
  // NOTE(review): the result is presumably in the half-open range [0, n);
  // this comment previously said "(0,n]", which does not match a generator
  // that can be emulated with `value % n`. Confirm against RandomBase.
  //
  // Exposed as virtual function for testing Jitter with reflection.
  // This is the only way random_ is accessed.
  virtual uint64 GetNextRandom(uint64 n) {
    return random_->UnbiasedUniform64(n);
  }

 private:
  // Initializes Jitter with reflection.
  //
  // This will fast-forward to the period containing current_timestamp.
  // next_output_timestamp_ is guaranteed to be in current_timestamp's period
  // and packet_emitted_this_period_ will be set to false.
  void InitializeNextOutputTimestamp(Timestamp current_timestamp);

  // Potentially advances next_output_timestamp_ a single period.
  //
  // next_output_timestamp_ will only be advanced if packet_emitted_this_period_
  // is false.  next_output_timestamp_ will never be advanced beyond
  // current_timestamp's period.
  //
  // However, next_output_timestamp_ could fall before current_timestamp's
  // period since only a single period can be advanced at a time.
  void UpdateNextOutputTimestamp(Timestamp current_timestamp);

  // Jitter-related variables.
  std::unique_ptr<RandomBase> random_;

  // Next packet to be emitted. Since packets may not align perfectly with
  // next_output_timestamp_, the closest packet will be emitted.
  Timestamp next_output_timestamp_;

  // Lower bound for next timestamp.
  //
  // next_output_timestamp_ will be kept within the interval
  // [next_output_timestamp_min_, next_output_timestamp_min_ + frame_time_usec_)
  Timestamp next_output_timestamp_min_ = Timestamp::Unset();

  // Indicates a packet was emitted for the current period (i.e. the period
  // next_output_timestamp_ falls in).
  bool packet_emitted_this_period_ = false;
};
// Strategy that applies Jitter without reflection based sampling.
//
// Used by PacketResamplerCalculator when Jitter is enabled and reflection is
// not enabled.
class JitterWithoutReflectionStrategy : public PacketResamplerStrategy {
 public:
  // `calculator` is forwarded to PacketResamplerStrategy, which stores it.
  // Explicit to rule out implicit pointer-to-strategy conversions.
  explicit JitterWithoutReflectionStrategy(
      PacketResamplerCalculator* calculator)
      : PacketResamplerStrategy(calculator) {}

  // PacketResamplerStrategy overrides; see the base class for the contract.
  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Close(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;

 private:
  // Calculates the first sampled timestamp that incorporates a jittering
  // offset.
  void InitializeNextOutputTimestamp();

  // Calculates the next sampled timestamp that incorporates a jittering offset.
  void UpdateNextOutputTimestamp();

  // Jitter-related variables.
  std::unique_ptr<RandomBase> random_;

  // Next packet to be emitted. Since packets may not align perfectly with
  // next_output_timestamp_, the closest packet will be emitted.
  Timestamp next_output_timestamp_;

  // Lower bound for next timestamp.
  //
  // next_output_timestamp_ will be kept within the interval
  // [next_output_timestamp_min_, next_output_timestamp_min_ + frame_time_usec_)
  Timestamp next_output_timestamp_min_ = Timestamp::Unset();

  // Packet reservoir used for sampling a random packet out of a partial
  // period.
  std::unique_ptr<PacketReservoir> packet_reservoir_;

  // Random number generator used in packet_reservoir_.
  std::unique_ptr<RandomBase> packet_reservoir_random_;
};
// Strategy that applies sampling without any jitter.
//
// Used by PacketResamplerCalculator when jitter is not enabled.
class NoJitterStrategy : public PacketResamplerStrategy {
 public:
  // `calculator` is forwarded to PacketResamplerStrategy, which stores it.
  // Explicit to rule out implicit pointer-to-strategy conversions.
  explicit NoJitterStrategy(PacketResamplerCalculator* calculator)
      : PacketResamplerStrategy(calculator) {}

  // PacketResamplerStrategy overrides; see the base class for the contract.
  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Close(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;

 private:
  // Number of periods that have passed (= #packets sent to the output).
  //
  // Zero-initialized so the counter never holds an indeterminate value, even
  // if it is read before Open() has run.
  int64 period_count_ = 0;

  // If specified, output timestamps are aligned with base_timestamp.
  // Otherwise, they are aligned with the first input timestamp.
  Timestamp base_timestamp_;
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_CORE_PACKET_RESAMPLER_CALCULATOR_H_

View File

@ -1,113 +0,0 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
option objc_class_prefix = "MediaPipe";
message PacketResamplerCalculatorOptions {
// Options for PacketResamplerCalculator, registered below as an extension of
// CalculatorOptions.
extend CalculatorOptions {
optional PacketResamplerCalculatorOptions ext = 95743844;
}
// The output frame rate measured in frames per second.
//
// The closest packet in time in each period will be chosen. If there
// is no packet in the period then the most recent packet will be chosen
// (not the closest in time).
//
// The declared default is -1.0.
optional double frame_rate = 1 [default = -1.0];
// How the header of the output stream is derived from the input header.
enum OutputHeader {
// Do not output a header, even if the input contained one.
NONE = 0;
// Pass the header, if the input contained one.
PASS_HEADER = 1;
// Update the frame rate in the header, which must be of type VideoHeader.
UPDATE_VIDEO_HEADER = 2;
}
// Whether and what kind of header to place on the output stream.
// Note, this is about the actual header, not the VIDEO_HEADER stream.
// If this option is set to UPDATE_VIDEO_HEADER then the header will
// also be parsed (updated) and passed along to the VIDEO_HEADER stream.
optional OutputHeader output_header = 2 [default = NONE];
// Flush last packet even if its timestamp is greater than the final stream
// timestamp.
optional bool flush_last_packet = 3 [default = true];
// Adds jitter to resampling if set, so that Google's sampling is not
// externally deterministic.
//
// When set, the randomizer will be initialized with a seed. Then, the first
// sample is chosen randomly (uniform distribution) among frames that
// correspond to timestamps [0, 1/frame_rate). Let the chosen frame
// correspond to timestamp t. The next frame is chosen randomly (uniform
// distribution) among frames that correspond to [t+(1-jitter)/frame_rate,
// t+(1+jitter)/frame_rate]. t is updated and the process is repeated.
//
// Valid values are in the range of [0.0, 1.0] with the default being 0.0 (no
// jitter). A typical value would be a value in the range of 0.1-0.25.
//
// Note that this does NOT guarantee the desired frame rate, but if the
// pseudo-random number generator does its job and the number of frames is
// sufficiently large, the average frame rate will be close to this value.
optional double jitter = 4;
// Enables reflection when applying jitter.
//
// This option is ignored when reproducible_sampling is true, in which case
// reflection will be used.
//
// New use cases should use reproducible_sampling = true, as
// jitter_with_reflection is deprecated and will be removed at some point.
optional bool jitter_with_reflection = 9 [default = false];
// If set, enables reproducible sampling, allowing frames to be sampled
// without regard to where the stream starts. See
// packet_resampler_calculator.h for details.
//
// This enables reflection (ignoring jitter_with_reflection setting).
optional bool reproducible_sampling = 10 [default = false];
// If specified, output timestamps are aligned with base_timestamp.
// Otherwise, they are aligned with the first input timestamp.
//
// In order to ensure that the output timestamps are reproducible,
// with round_limits = false, the bounds for input timestamps must include:
// [start_time - period / 2, end_time + period / 2],
// with round_limits = true, the bounds for input timestamps must include:
// [start_time - period, end_time + period],
// where period = 1 / frame_rate.
//
// For example, in PacketResamplerCalculatorOptions specify
// "start_time: 3000000", and in MediaDecoderOptions specify
// "start_time: 2999950".
optional int64 base_timestamp = 5;
// If specified, only outputs at/after start_time are included.
optional int64 start_time = 6;
// If specified, only outputs before end_time are included.
optional int64 end_time = 7;
// If set, the output timestamps nearest to start_time and end_time
// are included in the output, even if the nearest timestamp is not
// between start_time and end_time.
optional bool round_limits = 8 [default = false];
}

View File

@ -1,752 +0,0 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/packet_resampler_calculator.h"
#include <memory>
#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/core/packet_resampler_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/video_stream_header.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
using ::testing::ElementsAre;
namespace {
// A simple version of CalculatorRunner with built-in convenience
// methods for setting inputs from a vector and checking outputs
// against expected outputs (both timestamps and contents).
class SimpleRunner : public CalculatorRunner {
public:
explicit SimpleRunner(const std::string& options_string)
: CalculatorRunner("PacketResamplerCalculator", options_string, 1, 1, 0) {
}
explicit SimpleRunner(const CalculatorGraphConfig::Node& node_config)
: CalculatorRunner(node_config) {}
virtual ~SimpleRunner() {}
void SetInput(const std::vector<int64>& timestamp_list) {
MutableInputs()->Index(0).packets.clear();
for (const int64 ts : timestamp_list) {
MutableInputs()->Index(0).packets.push_back(
Adopt(new std::string(absl::StrCat("Frame #", ts)))
.At(Timestamp(ts)));
}
}
void SetVideoHeader(const double frame_rate) {
video_header_.width = static_count_;
video_header_.height = static_count_ * 10;
video_header_.frame_rate = frame_rate;
video_header_.duration = static_count_ * 100.0;
video_header_.format = static_cast<ImageFormat::Format>(
static_count_ % ImageFormat::Format_ARRAYSIZE);
MutableInputs()->Index(0).header = Adopt(new VideoHeader(video_header_));
++static_count_;
}
void CheckOutputTimestamps(
const std::vector<int64>& expected_frames,
const std::vector<int64>& expected_timestamps) const {
EXPECT_EQ(expected_frames.size(), Outputs().Index(0).packets.size());
EXPECT_EQ(expected_timestamps.size(), Outputs().Index(0).packets.size());
int count = 0;
for (const Packet& packet : Outputs().Index(0).packets) {
EXPECT_EQ(Timestamp(expected_timestamps[count]), packet.Timestamp());
const std::string& packet_contents = packet.Get<std::string>();
EXPECT_EQ(std::string(absl::StrCat("Frame #", expected_frames[count])),
packet_contents);
++count;
}
}
void CheckVideoHeader(const double expected_frame_rate) const {
ASSERT_FALSE(Outputs().Index(0).header.IsEmpty());
const VideoHeader& header = Outputs().Index(0).header.Get<VideoHeader>();
const double frame_rate = header.frame_rate;
EXPECT_EQ(video_header_.width, header.width);
EXPECT_EQ(video_header_.height, header.height);
EXPECT_DOUBLE_EQ(expected_frame_rate, frame_rate);
EXPECT_FLOAT_EQ(video_header_.duration, header.duration);
EXPECT_EQ(video_header_.format, header.format);
}
private:
VideoHeader video_header_;
static int static_count_;
};
// Matcher for Packets with int64 payload, comparing arg packet's
// timestamp and int64 payload.
MATCHER_P2(PacketAtTimestamp, payload, timestamp,
           absl::StrCat(negation ? "isn't" : "is", " a packet with payload ",
                        payload, " @ time ", timestamp)) {
  // The timestamp is compared first; on a mismatch the payload is not read.
  const auto observed_timestamp = arg.Timestamp().Value();
  if (timestamp != observed_timestamp) {
    *result_listener << "at incorrect timestamp = " << observed_timestamp;
    return false;
  }

  // The payload is read as int64 and compared with the expected value.
  const int64 observed_payload = arg.template Get<int64>();
  const bool payload_matches = !(observed_payload != payload);
  if (!payload_matches) {
    *result_listener << "with incorrect payload = " << observed_payload;
  }
  return payload_matches;
}
// ReproducibleJitterWithReflectionStrategy child class which injects a
// specified stream
// of "random" numbers.
//
// Calculators are created through factory methods, making testing and injection
// tricky. This class utilizes a static variable, random_sequence, to pass
// the desired random sequence into the calculator.
class ReproducibleJitterWithReflectionStrategyForTesting
: public ReproducibleJitterWithReflectionStrategy {
public:
ReproducibleJitterWithReflectionStrategyForTesting(
PacketResamplerCalculator* calculator)
: ReproducibleJitterWithReflectionStrategy(calculator) {}
// Statically accessed random sequence to use for jitter with reflection.
//
// An EXPECT will fail if sequence is less than the number requested during
// processing.
static std::vector<uint64> random_sequence;
protected:
virtual uint64 GetNextRandom(uint64 n) {
EXPECT_LT(sequence_index_, random_sequence.size());
return random_sequence[sequence_index_++] % n;
}
private:
int32 sequence_index_ = 0;
};
std::vector<uint64>
ReproducibleJitterWithReflectionStrategyForTesting::random_sequence;
// PacketResamplerCalculator child class which injects a specified stream
// of "random" numbers.
//
// Calculators are created through factory methods, making testing and injection
// tricky. This class utilizes a static variable, random_sequence, to pass
// the desired random sequence into the calculator.
class ReproducibleResamplerCalculatorForTesting
: public PacketResamplerCalculator {
public:
static absl::Status GetContract(CalculatorContract* cc) {
return PacketResamplerCalculator::GetContract(cc);
}
protected:
std::unique_ptr<class PacketResamplerStrategy> GetSamplingStrategy(
const mediapipe::PacketResamplerCalculatorOptions& Options) {
return absl::make_unique<
ReproducibleJitterWithReflectionStrategyForTesting>(this);
}
};
REGISTER_CALCULATOR(ReproducibleResamplerCalculatorForTesting);
int SimpleRunner::static_count_ = 0;
TEST(PacketResamplerCalculatorTest, NoPacketsInStream) {
  // With no input packets at all, the only requirement is a successful run.
  const std::string options =
      "[mediapipe.PacketResamplerCalculatorOptions.ext]: {frame_rate:30}";
  const std::vector<int64> no_timestamps;

  SimpleRunner runner(options);
  runner.SetInput(no_timestamps);
  MP_ASSERT_OK(runner.Run());
}
TEST(PacketResamplerCalculatorTest, SinglePacketInStream) {
  const std::string options =
      "[mediapipe.PacketResamplerCalculatorOptions.ext]: {frame_rate:30}";

  // Each entry is a stream consisting of one packet / one period. In every
  // case that packet is expected on the output with its original timestamp.
  const std::vector<int64> lone_timestamps = {
      0,      // packet at the origin.
      1000,   // 0 < packet timestamp < first limit.
      16668,  // packet timestamp > first limit.
  };

  for (const int64 ts : lone_timestamps) {
    const std::vector<int64> stream(1, ts);
    SimpleRunner runner(options);
    runner.SetInput(stream);
    MP_ASSERT_OK(runner.Run());
    runner.CheckOutputTimestamps(stream, stream);
  }
}
TEST(PacketResamplerCalculatorTest, TwoPacketsInStream) {
  // One row per two-packet stream: the input timestamps, the frames expected
  // on the output, and the timestamps those frames are expected to carry.
  struct Scenario {
    std::vector<int64> input;
    std::vector<int64> expected_frames;
    std::vector<int64> expected_timestamps;
  };
  const Scenario scenarios[] = {
      // 2 packets / 1 period.
      {{0, 16666}, {0}, {0}},
      // 2 packets / 2 periods (left extreme for second period).
      {{0, 16667}, {0, 16667}, {0, 33333}},
      // 2 packets / 2 periods (right extreme for second period).
      {{0, 49999}, {0, 49999}, {0, 33333}},
      // 2 packets / 3 periods (filling 1 in the middle).
      {{0, 50000}, {0, 0, 50000}, {0, 33333, 66667}},
      // 2 packets / 4 periods (filling 2 in the middle).
      {{2000, 118666},
       {2000, 2000, 2000, 118666},
       {2000, 35333, 68667, 102000}},
  };
  const std::string options =
      "[mediapipe.PacketResamplerCalculatorOptions.ext]: {frame_rate:30}";

  for (const Scenario& scenario : scenarios) {
    SimpleRunner runner(options);
    runner.SetInput(scenario.input);
    MP_ASSERT_OK(runner.Run());
    runner.CheckOutputTimestamps(scenario.expected_frames,
                                 scenario.expected_timestamps);
  }
}
TEST(PacketResamplerCalculatorTest, InputAtExactFrequencyMiddlepoints) {
  // The expected output repeats the input exactly: same frames, same
  // timestamps.
  const std::vector<int64> timestamps = {0,      33333,  66667, 100000,
                                         133333, 166667, 200000};
  const std::string options =
      "[mediapipe.PacketResamplerCalculatorOptions.ext]: {frame_rate:30}";

  SimpleRunner runner(options);
  runner.SetInput(timestamps);
  MP_ASSERT_OK(runner.Run());
  runner.CheckOutputTimestamps(timestamps, timestamps);
}
// When there are several candidates for a period, the one closer to the center
// should be sent to the output.
TEST(PacketResamplerCalculatorTest, MultiplePacketsForPeriods) {
  // Several inputs compete for each period; one frame per period is expected.
  const std::vector<int64> input = {0,     16666, 16667, 20000,
                                    33300, 49999, 50000, 66600};
  const std::vector<int64> expected_frames = {0, 33300, 66600};
  const std::vector<int64> expected_timestamps = {0, 33333, 66667};
  const std::string options =
      "[mediapipe.PacketResamplerCalculatorOptions.ext]: {frame_rate:30}";

  SimpleRunner runner(options);
  runner.SetInput(input);
  MP_ASSERT_OK(runner.Run());
  runner.CheckOutputTimestamps(expected_frames, expected_timestamps);
}
// When a period must be filled, we use the latest packet received (not
// necessarily the same as the one stored for the best in the previous period).
TEST(PacketResamplerCalculatorTest, FillPeriodsWithLatestPacket) {
  // Each row has gaps in the input; the expected frames show which earlier
  // packet is repeated for the periods that received no packet of their own.
  struct Scenario {
    std::vector<int64> input;
    std::vector<int64> expected_frames;
    std::vector<int64> expected_timestamps;
  };
  const Scenario scenarios[] = {
      {{0, 5000, 16666, 83334},
       {0, 16666, 16666, 83334},
       {0, 33333, 66667, 100000}},
      {{0, 16666, 16667, 25000, 33000, 35000, 135000},
       {0, 33000, 35000, 35000, 135000},
       {0, 33333, 66667, 100000, 133333}},
      {{0, 15000, 32000, 49999, 150000},
       {0, 32000, 49999, 49999, 49999, 150000},
       {0, 33333, 66667, 100000, 133333, 166667}},
  };
  const std::string options =
      "[mediapipe.PacketResamplerCalculatorOptions.ext]: {frame_rate:30}";

  for (const Scenario& scenario : scenarios) {
    SimpleRunner runner(options);
    runner.SetInput(scenario.input);
    MP_ASSERT_OK(runner.Run());
    runner.CheckOutputTimestamps(scenario.expected_frames,
                                 scenario.expected_timestamps);
  }
}
TEST(PacketResamplerCalculatorTest, SuperHighFrameRate) {
  // The same three packets are resampled at two very high frame rates.
  const std::vector<int64> input = {0, 10, 13};

  struct Scenario {
    std::string options;
    std::vector<int64> expected_frames;
    std::vector<int64> expected_timestamps;
  };
  const Scenario scenarios[] = {
      // frame rate == 500000 (a packet will have to be sent every 2 ticks).
      {"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
       "{frame_rate:500000}",
       {0, 0, 0, 0, 0, 10, 10, 13},
       {0, 2, 4, 6, 8, 10, 12, 14}},
      // frame rate == 1000000 (a packet will have to be sent in each tick).
      {"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
       "{frame_rate:1000000}",
       {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 10, 10, 13},
       {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}},
  };

  for (const Scenario& scenario : scenarios) {
    SimpleRunner runner(scenario.options);
    runner.SetInput(input);
    MP_ASSERT_OK(runner.Run());
    runner.CheckOutputTimestamps(scenario.expected_frames,
                                 scenario.expected_timestamps);
  }
}
TEST(PacketResamplerCalculatorTest, NegativeTimestampTest) {
  // Streams whose first packet has a negative timestamp.
  struct Scenario {
    std::vector<int64> input;
    std::vector<int64> expected_frames;
    std::vector<int64> expected_timestamps;
  };
  const Scenario scenarios[] = {
      // Negative timestamps / 1 period.
      {{-200, -20, 16466}, {-200}, {-200}},
      // Negative timestamps / 2 periods.
      {{-200, -20, 16467}, {-200, 16467}, {-200, 33133}},
      // Negative timestamps and filling an empty period.
      {{-500, 66667}, {-500, -500, 66667}, {-500, 32833, 66167}},
      // Negative timestamps and initial packet < -period.
      {{-50000, -33334, 33334},
       {-50000, -33334, -33334, 33334},
       {-50000, -16667, 16667, 50000}},
  };
  const std::string options =
      "[mediapipe.PacketResamplerCalculatorOptions.ext]: {frame_rate:30}";

  for (const Scenario& scenario : scenarios) {
    SimpleRunner runner(options);
    runner.SetInput(scenario.input);
    MP_ASSERT_OK(runner.Run());
    runner.CheckOutputTimestamps(scenario.expected_frames,
                                 scenario.expected_timestamps);
  }
}
TEST(PacketResamplerCalculatorTest, ExactFramesPerSecond) {
  // frame_rate=50 makes a period of exactly 20000 microseconds.
  const std::string options =
      "[mediapipe.PacketResamplerCalculatorOptions.ext]: {frame_rate:50}";

  struct Scenario {
    std::vector<int64> input;
    std::vector<int64> expected_frames;
    std::vector<int64> expected_timestamps;
  };
  const Scenario scenarios[] = {
      // Two periods, each with a packet of its own.
      {{0, 9999, 29999}, {0, 29999}, {0, 20000}},
      // Test filling empty periods.
      {{0, 10000, 50000}, {0, 10000, 10000, 50000}, {0, 20000, 40000, 60000}},
  };

  for (const Scenario& scenario : scenarios) {
    SimpleRunner runner(options);
    runner.SetInput(scenario.input);
    MP_ASSERT_OK(runner.Run());
    runner.CheckOutputTimestamps(scenario.expected_frames,
                                 scenario.expected_timestamps);
  }
}
TEST(PacketResamplerCalculatorTest, FrameRateTest) {
  const std::string update_header_options =
      "[mediapipe.PacketResamplerCalculatorOptions.ext]: "
      "{frame_rate:50, output_header:UPDATE_VIDEO_HEADER}";
  const std::string pass_header_options =
      "[mediapipe.PacketResamplerCalculatorOptions.ext]: "
      "{frame_rate:50, output_header:PASS_HEADER}";

  // One row per header scenario: the options, the input timestamps, the frame
  // rate written into the input header, the expected output frames and
  // timestamps, and the frame rate expected in the output header.
  struct Scenario {
    std::string options;
    std::vector<int64> input;
    double input_frame_rate;
    std::vector<int64> expected_frames;
    std::vector<int64> expected_timestamps;
    double expected_frame_rate;
  };
  const Scenario scenarios[] = {
      // Test changing Frame Rate to the same initial value.
      {update_header_options,
       {0, 10000, 30000, 50000, 60000},
       50.0,
       {0, 10000, 30000, 60000},
       {0, 20000, 40000, 60000},
       50.0},
      // Test changing Frame Rate to new value.
      {update_header_options,
       {0, 5000, 10010, 15001, 19990},
       200.0,
       {0, 19990},
       {0, 20000},
       50.0},
      // Test that the frame rate is left untouched when the header is only
      // passed through (output_header: PASS_HEADER).
      {pass_header_options,
       {0, 5000, 10010, 15001, 19990},
       200.0,
       {0, 19990},
       {0, 20000},
       200.0},
  };

  for (const Scenario& scenario : scenarios) {
    SimpleRunner runner(scenario.options);
    runner.SetInput(scenario.input);
    runner.SetVideoHeader(scenario.input_frame_rate);
    MP_ASSERT_OK(runner.Run());
    runner.CheckOutputTimestamps(scenario.expected_frames,
                                 scenario.expected_timestamps);
    runner.CheckVideoHeader(scenario.expected_frame_rate);
  }
}
TEST(PacketResamplerCalculatorTest, SetVideoHeader) {
  // Node with tagged DATA and VIDEO_HEADER streams on both sides.
  const CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "PacketResamplerCalculator"
input_stream: "DATA:in_data"
input_stream: "VIDEO_HEADER:in_video_header"
output_stream: "DATA:out_data"
output_stream: "VIDEO_HEADER:out_video_header"
options {
[mediapipe.PacketResamplerCalculatorOptions.ext] { frame_rate: 50.0 }
}
)pb");
  CalculatorRunner runner(node_config);

  // Data packets, each carrying "Frame #<ts>" at timestamp <ts>.
  for (const int64 frame_ts : {0, 5000, 10010, 15001, 19990}) {
    runner.MutableInputs()->Tag("DATA").packets.push_back(
        Adopt(new std::string(absl::StrCat("Frame #", frame_ts)))
            .At(Timestamp(frame_ts)));
  }

  // Input header, delivered as a single pre-stream packet. Its frame rate
  // (1.0) differs from the configured output rate (50.0).
  VideoHeader input_header;
  input_header.width = 10;
  input_header.height = 100;
  input_header.frame_rate = 1.0;
  input_header.duration = 1.0;
  input_header.format = ImageFormat::SRGB;
  runner.MutableInputs()
      ->Tag("VIDEO_HEADER")
      .packets.push_back(
          Adopt(new VideoHeader(input_header)).At(Timestamp::PreStream()));

  MP_ASSERT_OK(runner.Run());

  // Exactly one header packet is expected, again at pre-stream.
  ASSERT_EQ(1, runner.Outputs().Tag("VIDEO_HEADER").packets.size());
  EXPECT_EQ(Timestamp::PreStream(),
            runner.Outputs().Tag("VIDEO_HEADER").packets[0].Timestamp());

  // Only the frame rate is expected to change; all other fields are copied.
  const VideoHeader& output_header =
      runner.Outputs().Tag("VIDEO_HEADER").packets[0].Get<VideoHeader>();
  EXPECT_EQ(input_header.width, output_header.width);
  EXPECT_EQ(input_header.height, output_header.height);
  EXPECT_DOUBLE_EQ(50.0, output_header.frame_rate);
  EXPECT_FLOAT_EQ(input_header.duration, output_header.duration);
  EXPECT_EQ(input_header.format, output_header.format);
}
TEST(PacketResamplerCalculatorTest, FlushLastPacketWithoutRound) {
  // flush_last_packet is not set here, so its declared default applies.
  const std::string options = R"(
[mediapipe.PacketResamplerCalculatorOptions.ext] {
frame_rate: 1
})";
  const std::vector<int64> input = {0, 333333, 666667, 1000000, 1333333};
  const std::vector<int64> expected_frames = {0, 1000000};
  const std::vector<int64> expected_timestamps = {0, 1000000};

  SimpleRunner runner(options);
  runner.SetInput(input);
  MP_ASSERT_OK(runner.Run());
  // 1333333 is not emitted as 2000000, because it does not round to 2000000.
  runner.CheckOutputTimestamps(expected_frames, expected_timestamps);
}
TEST(PacketResamplerCalculatorTest, FlushLastPacketWithRound) {
  // flush_last_packet is not set here, so its declared default applies.
  const std::string options = R"(
[mediapipe.PacketResamplerCalculatorOptions.ext] {
frame_rate: 1
})";
  const std::vector<int64> input = {0,       333333,  666667,
                                    1000000, 1333333, 1666667};
  const std::vector<int64> expected_frames = {0, 1000000, 1666667};
  const std::vector<int64> expected_timestamps = {0, 1000000, 2000000};

  SimpleRunner runner(options);
  runner.SetInput(input);
  MP_ASSERT_OK(runner.Run());
  // 1666667 is emitted as 2000000, because it rounds to 2000000.
  runner.CheckOutputTimestamps(expected_frames, expected_timestamps);
}
TEST(PacketResamplerCalculatorTest, DoNotFlushLastPacketWithoutRound) {
  // Flushing of the last packet is switched off explicitly.
  const std::string options = R"(
[mediapipe.PacketResamplerCalculatorOptions.ext] {
frame_rate: 1
flush_last_packet: false
})";
  const std::vector<int64> input = {0, 333333, 666667, 1000000, 1333333};
  const std::vector<int64> expected_frames = {0, 1000000};
  const std::vector<int64> expected_timestamps = {0, 1000000};

  SimpleRunner runner(options);
  runner.SetInput(input);
  MP_ASSERT_OK(runner.Run());
  // 1333333 is not emitted no matter what; see FlushLastPacketWithoutRound.
  runner.CheckOutputTimestamps(expected_frames, expected_timestamps);
}
TEST(PacketResamplerCalculatorTest, DoNotFlushLastPacketWithRound) {
  // Flushing of the last packet is switched off explicitly.
  const std::string options = R"(
[mediapipe.PacketResamplerCalculatorOptions.ext] {
frame_rate: 1
flush_last_packet: false
})";
  const std::vector<int64> input = {0,       333333,  666667,
                                    1000000, 1333333, 1666667};
  const std::vector<int64> expected_frames = {0, 1000000};
  const std::vector<int64> expected_timestamps = {0, 1000000};

  SimpleRunner runner(options);
  runner.SetInput(input);
  MP_ASSERT_OK(runner.Run());
  // 1666667 is not emitted due to flush_last_packet: false.
  runner.CheckOutputTimestamps(expected_frames, expected_timestamps);
}
// When base_timestamp is specified, output timestamps are aligned with it.
TEST(PacketResamplerCalculatorTest, InputAtExactFrequencyMiddlepointsAligned) {
  // The same input is run with and without base_timestamp; every input frame
  // is expected on the output in both runs, only the timestamps differ.
  const std::vector<int64> input = {33111,  66667,  100000,
                                    133333, 166667, 200000};

  struct Scenario {
    std::string options;
    std::vector<int64> expected_timestamps;
  };
  const Scenario scenarios[] = {
      // Without base_timestamp, outputs are aligned with the first input
      // timestamp, (33333 - 222).
      {"[mediapipe.PacketResamplerCalculatorOptions.ext]: {frame_rate:30}",
       {33111, 66444, 99778, 133111, 166444, 199778}},
      // With base_timestamp, outputs are aligned with base_timestamp, 0.
      {"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
       "{frame_rate:30 base_timestamp:0}",
       {33333, 66666, 100000, 133333, 166666, 200000}},
  };

  for (const Scenario& scenario : scenarios) {
    SimpleRunner runner(scenario.options);
    runner.SetInput(input);
    MP_ASSERT_OK(runner.Run());
    runner.CheckOutputTimestamps(input, scenario.expected_timestamps);
  }
}
// When base_timestamp is specified, output timestamps are aligned with it.
TEST(PacketResamplerCalculatorTest, MultiplePacketsForPeriodsAligned) {
  // One row per alignment scenario: options, input timestamps, expected
  // frames, and the timestamps those frames are expected to carry.
  struct Scenario {
    std::string options;
    std::vector<int64> input;
    std::vector<int64> expected_frames;
    std::vector<int64> expected_timestamps;
  };
  const Scenario scenarios[] = {
      // Without base_timestamp, outputs are aligned with the first input, -222.
      {"[mediapipe.PacketResamplerCalculatorOptions.ext]: {frame_rate:30}",
       {-222, 16666, 16667, 20000, 33300, 49999, 50000, 66600},
       {-222, 33300, 66600},
       {-222, 33111, 66445}},
      // With base_timestamp, outputs are aligned with base_timestamp, 900011.
      {"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
       "{frame_rate:30 base_timestamp:900011}",
       {-222, 16666, 16667, 20000, 33300, 49999, 50000, 66600},
       {-222, 33300, 66600},
       {11, 33344, 66678}},
      // With base_timestamp, outputs still approximate input timestamps,
      // while aligned to base_timestamp, 11.
      {"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
       "{frame_rate:30 base_timestamp:11}",
       {899888, 916666, 916667, 920000, 933300, 949999, 950000, 966600},
       {899888, 933300, 966600},
       {900011, 933344, 966678}},
  };

  for (const Scenario& scenario : scenarios) {
    SimpleRunner runner(scenario.options);
    runner.SetInput(scenario.input);
    MP_ASSERT_OK(runner.Run());
    runner.CheckOutputTimestamps(scenario.expected_frames,
                                 scenario.expected_timestamps);
  }
}
// When a period must be filled, we use the latest packet received.
// When base_timestamp is specified, output timestamps are aligned with it.
TEST(PacketResamplerCalculatorTest, FillPeriodsWithLatestPacketAligned) {
  // Both runs share the input and the expected frames (49999 is repeated for
  // the periods without a packet of their own); only the timestamps differ.
  const std::vector<int64> input = {-222, 15000, 32000, 49999, 150000};
  const std::vector<int64> expected_frames = {-222,  32000, 49999,
                                              49999, 49999, 150000};

  struct Scenario {
    std::string options;
    std::vector<int64> expected_timestamps;
  };
  const Scenario scenarios[] = {
      // Without base_timestamp, outputs are aligned with the first input, -222.
      {"[mediapipe.PacketResamplerCalculatorOptions.ext]: {frame_rate:30}",
       {-222, 33111, 66445, 99778, 133111, 166445}},
      // With base_timestamp, outputs are aligned with base_timestamp, 0.
      {"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
       "{frame_rate:30 base_timestamp:0}",
       {0, 33333, 66667, 100000, 133333, 166667}},
  };

  for (const Scenario& scenario : scenarios) {
    SimpleRunner runner(scenario.options);
    runner.SetInput(input);
    MP_ASSERT_OK(runner.Run());
    runner.CheckOutputTimestamps(expected_frames,
                                 scenario.expected_timestamps);
  }
}
// When base_timestamp is specified, output timestamps are aligned with it.
// The first packet is included, because we assume that the input includes the
// whole first sampling interval.
TEST(PacketResamplerCalculatorTest, FirstInputAfterMiddlepointAligned) {
  const std::string options =
      "[mediapipe.PacketResamplerCalculatorOptions.ext]: "
      "{frame_rate:30 base_timestamp:0}";

  // In both rows every input frame is expected on the output, so only the
  // input and the expected timestamps are listed.
  struct Scenario {
    std::vector<int64> input;
    std::vector<int64> expected_timestamps;
  };
  const Scenario scenarios[] = {
      // Stream starting at 66667: packet 100020 is expected on the output,
      // re-stamped to 100000.
      {{66667, 100020, 133333, 166667}, {66667, 100000, 133334, 166667}},
      // If we seek to packet 100020, packet 100020 is included in
      // the output sequence, because we assume that the input includes the
      // whole first sampling interval.
      //
      // We assume that the input includes whole sampling intervals
      // in order to produce "reproducible timestamps", which are timestamps
      // from the series of timestamps starting at 0.
      {{100020, 133333, 166667}, {100000, 133333, 166667}},
  };

  for (const Scenario& scenario : scenarios) {
    SimpleRunner runner(options);
    runner.SetInput(scenario.input);
    MP_ASSERT_OK(runner.Run());
    runner.CheckOutputTimestamps(scenario.input, scenario.expected_timestamps);
  }
}
// Exercises the start_time / end_time / round_limits options together with
// base_timestamp. All three cases use the same input timestamps and
// frame_rate 30; they differ only in which outputs are expected to survive.
TEST(PacketResamplerCalculatorTest, OutputTimestampRangeAligned) {
{
// With base_timestamp, outputs are aligned with base_timestamp, 0.
// Baseline: no start_time/end_time, six outputs expected.
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30 "
"base_timestamp:0}");
runner.SetInput({-222, 15000, 32000, 49999, 150000});
MP_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({-222, 32000, 49999, 49999, 49999, 150000},
{0, 33333, 66667, 100000, 133333, 166667});
}
{
// With start_time, end_time, outputs are filtered.
// Of the baseline outputs only 66667, 100000 and 133333 are expected, i.e.
// those lying between start_time (40000) and end_time (160000).
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30 "
"base_timestamp:0 "
"start_time:40000 "
"end_time:160000}");
runner.SetInput({-222, 15000, 32000, 49999, 150000});
MP_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({49999, 49999, 49999},
{66667, 100000, 133333});
}
{
// With start_time, end_time, round_limits, outputs are filtered,
// rounding to the nearest limit.
// Compared with the previous case, 33333 and 166667 are also expected.
SimpleRunner runner(
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
"{frame_rate:30 "
"base_timestamp:0 "
"start_time:40000 "
"end_time:160000 "
"round_limits:true}");
runner.SetInput({-222, 15000, 32000, 49999, 150000});
MP_ASSERT_OK(runner.Run());
runner.CheckOutputTimestamps({32000, 49999, 49999, 49999, 150000},
{33333, 66667, 100000, 133333, 166667});
}
}
// Checks that calculator options supplied through the "OPTIONS" input side
// packet are applied. The node itself is configured with frame_rate 60, while
// each side packet below specifies frame_rate 30; both cases expect six output
// packets for the five inputs.
TEST(PacketResamplerCalculatorTest, OptionsSidePacket) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "PacketResamplerCalculator"
input_side_packet: "OPTIONS:options"
input_stream: "input"
output_stream: "output"
options {
[mediapipe.PacketResamplerCalculatorOptions.ext] {
frame_rate: 60
base_timestamp: 0
}
})pb");
{
// Side packet sets only frame_rate; base_timestamp is given only by the
// node options.
SimpleRunner runner(node_config);
// Ownership of the heap-allocated options is handed to the packet by Adopt().
auto options =
new CalculatorOptions(ParseTextProtoOrDie<CalculatorOptions>(
R"pb(
[mediapipe.PacketResamplerCalculatorOptions.ext] {
frame_rate: 30
})pb"));
runner.MutableSidePackets()->Tag("OPTIONS") = Adopt(options);
runner.SetInput({-222, 15000, 32000, 49999, 150000});
MP_ASSERT_OK(runner.Run());
EXPECT_EQ(6, runner.Outputs().Index(0).packets.size());
}
{
// Side packet sets merge_fields to false and therefore spells out both
// frame_rate and base_timestamp itself.
// NOTE(review): presumably merge_fields: false makes the side packet
// replace the node options instead of merging into them -- confirm against
// the options utilities, which are not visible here.
SimpleRunner runner(node_config);
auto options =
new CalculatorOptions(ParseTextProtoOrDie<CalculatorOptions>(R"pb(
merge_fields: false
[mediapipe.PacketResamplerCalculatorOptions.ext] {
frame_rate: 30
base_timestamp: 0
})pb"));
runner.MutableSidePackets()->Tag("OPTIONS") = Adopt(options);
runner.SetInput({-222, 15000, 32000, 49999, 150000});
MP_ASSERT_OK(runner.Run());
EXPECT_EQ(6, runner.Outputs().Index(0).packets.size());
}
}
} // namespace
} // namespace mediapipe

View File

@ -1,317 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Declaration of PacketThinnerCalculator.
#include <cmath> // for ceil
#include <memory>
#include "mediapipe/calculators/core/packet_thinner_calculator.pb.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/video_stream_header.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/options_util.h"
namespace mediapipe {
namespace {

// Number of timestamp ticks (microseconds) in one second.
constexpr double kTimebaseUs = 1e6;

// Tags of the optional input side packets.
constexpr const char* kOptionsTag = "OPTIONS";
constexpr const char* kPeriodTag = "PERIOD";

}  // namespace
// PacketThinnerCalculator thins an input stream of Packets, e.g. to sample
// decoded video frames at a coarser temporal resolution. Unless stated
// otherwise, every timestamp below is expressed in microseconds.
//
// Two thinning strategies are available:
//
// 1) Asynchronous thinning ("async"):
//    No master clock is involved; the only parameter is the period. After a
//    packet has been emitted, every packet arriving during the following
//    period is discarded (comparable to a refractory period in which
//    emission is suppressed). Packets arriving before start_time, or at or
//    after end_time, are discarded as well.
//
// 2) Synchronous thinning ("sync"):
//    Parameterized by start_time and period. As for async, packets before
//    start_time or at/after end_time are discarded. Otherwise at most one
//    packet is emitted per period, where periods are centered at
//        start_time + i * period   [i a non-negative integer]
//    Within each period the packet closest to the center is emitted (the
//    latest one in the case of ties). Two variants exist that select exactly
//    the same packets but stamp them differently:
//      - sync_output_timestamps = true: the packet is output at the center
//        timestamp of its period;
//      - otherwise: the packet is output at its original timestamp.
//
// The thinning period comes either from the calculator options or from an
// input side packet tagged "PERIOD".
//
// Calculator options may additionally be supplied through an input side
// packet tagged "OPTIONS". They are merged with the node options: singular
// fields of the side packet overwrite the node's values and repeated fields
// are concatenated.
//
// Example config:
// node {
//   calculator: "PacketThinnerCalculator"
//   input_side_packet: "OPTIONS:calculator_options"
//   input_stream: "signal"
//   output_stream: "output"
//   options {
//     [mediapipe.PacketThinnerCalculatorOptions.ext] {
//       thinner_type: SYNC
//       period: 10
//       sync_output_timestamps: true
//       update_frame_rate: false
//     }
//   }
// }
class PacketThinnerCalculator : public CalculatorBase {
 public:
  PacketThinnerCalculator() {}
  ~PacketThinnerCalculator() override {}

  // Declares one input stream of any type, one output stream of the same
  // type, and the optional "OPTIONS" / "PERIOD" input side packets.
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).SetAny();
    cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));

    if (cc->InputSidePackets().HasTag(kOptionsTag)) {
      cc->InputSidePackets().Tag(kOptionsTag).Set<CalculatorOptions>();
    }
    if (cc->InputSidePackets().HasTag(kPeriodTag)) {
      cc->InputSidePackets().Tag(kPeriodTag).Set<int64>();
    }
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Close(CalculatorContext* cc) override;

  // Applies the [start_time_, end_time_) window and hands the remaining
  // packets to the thinner selected by thinner_type_.
  absl::Status Process(CalculatorContext* cc) override {
    // Packets ahead of start_time_ are silently dropped.
    if (cc->InputTimestamp() < start_time_) {
      return absl::OkStatus();
    }

    // Nothing is output at or after end_time_, so the output stream is
    // closed the first time such a packet is seen.
    if (cc->InputTimestamp() >= end_time_) {
      if (!cc->Outputs().Index(0).IsClosed()) {
        cc->Outputs().Index(0).Close();
      }
      return absl::OkStatus();
    }

    if (thinner_type_ == PacketThinnerCalculatorOptions::ASYNC) {
      return AsyncThinnerProcess(cc);
    }
    return SyncThinnerProcess(cc);
  }

 private:
  // The ASYNC and SYNC flavors of the thinning algorithm.
  absl::Status AsyncThinnerProcess(CalculatorContext* cc);
  absl::Status SyncThinnerProcess(CalculatorContext* cc);

  // Thinner flavor, cached from the options.
  PacketThinnerCalculatorOptions::ThinnerType thinner_type_;

  // Returns the sync Timestamp (derived from start_time_ and period_) that
  // lies closest to `now`. The result may precede or follow `now`, but is
  // never more than half a period_ away from it.
  Timestamp NearestSyncTimestamp(Timestamp now) const;

  // Options shared by both thinner flavors.
  TimestampDiff period_;  // Interval in which at most one packet is emitted.
  Timestamp start_time_;  // Cached option - default Timestamp::Min()
  Timestamp end_time_;    // Cached option - default Timestamp::Max()

  // State of the async thinner only.
  Timestamp next_valid_timestamp_;  // Packets before this are suppressed.

  // State of the sync thinner only.
  Packet saved_packet_;          // Best candidate that is not yet emitted.
  bool sync_output_timestamps_;  // Cached option.
};
REGISTER_CALCULATOR(PacketThinnerCalculator);
namespace {

// Returns the magnitude of a timestamp difference.
TimestampDiff abs(TimestampDiff t) {
  if (t < 0) {
    return -t;
  }
  return t;
}

}  // namespace
// Caches the calculator options and prepares the output stream.
//
// Options are the node options merged with the optional "OPTIONS" side
// packet. The period is taken from the "PERIOD" side packet when present and
// from the options otherwise. If the input stream carries a header it is
// forwarded to the output, optionally (update_frame_rate) with a frame rate
// adjusted for the thinning.
//
// CHECK-fails when the thinner type is unsupported, when the period is not
// positive, or when start_time is not earlier than end_time.
absl::Status PacketThinnerCalculator::Open(CalculatorContext* cc) {
  PacketThinnerCalculatorOptions options = mediapipe::tool::RetrieveOptions(
      cc->Options<PacketThinnerCalculatorOptions>(), cc->InputSidePackets(),
      kOptionsTag);

  thinner_type_ = options.thinner_type();
  // This check enables us to assume only two thinner types exist in Process()
  CHECK(thinner_type_ == PacketThinnerCalculatorOptions::ASYNC ||
        thinner_type_ == PacketThinnerCalculatorOptions::SYNC)
      << "Unsupported thinner type.";

  if (thinner_type_ == PacketThinnerCalculatorOptions::ASYNC) {
    // ASYNC thinner outputs packets with the same timestamp as their input so
    // its safe to SetOffset(0). SYNC thinner manipulates timestamps of its
    // output so we don't do this for that case.
    cc->SetOffset(0);
  }

  // The PERIOD side packet, when connected, takes precedence over the option.
  if (cc->InputSidePackets().HasTag(kPeriodTag)) {
    period_ =
        TimestampDiff(cc->InputSidePackets().Tag(kPeriodTag).Get<int64>());
  } else {
    period_ = TimestampDiff(options.period());
  }
  CHECK_LT(TimestampDiff(0), period_) << "Specified period must be positive.";

  // Without an explicit start_time the ASYNC thinner accepts every packet,
  // whereas the SYNC thinner anchors its periods at timestamp 0.
  if (options.has_start_time()) {
    start_time_ = Timestamp(options.start_time());
  } else if (thinner_type_ == PacketThinnerCalculatorOptions::ASYNC) {
    start_time_ = Timestamp::Min();
  } else {
    start_time_ = Timestamp(0);
  }
  end_time_ =
      options.has_end_time() ? Timestamp(options.end_time()) : Timestamp::Max();
  CHECK_LT(start_time_, end_time_)
      << "Invalid PacketThinner: start_time must be earlier than end_time";

  sync_output_timestamps_ = options.sync_output_timestamps();
  next_valid_timestamp_ = start_time_;
  // Drop packets until this time.
  cc->Outputs().Index(0).SetNextTimestampBound(start_time_);

  if (!cc->Inputs().Index(0).Header().IsEmpty()) {
    if (options.update_frame_rate()) {
      const VideoHeader& video_header =
          cc->Inputs().Index(0).Header().Get<VideoHeader>();
      // Use the effective period_ rather than options.period(): the period
      // may have been supplied through the PERIOD side packet, in which case
      // options.period() only holds the (unrelated) option value or default.
      const double period_us = static_cast<double>(period_.Value());
      double new_frame_rate;
      if (thinner_type_ == PacketThinnerCalculatorOptions::ASYNC) {
        // One packet is emitted every ceil(frame_rate * period) input frames.
        new_frame_rate =
            video_header.frame_rate /
            std::ceil(video_header.frame_rate * period_us / kTimebaseUs);
      } else {
        // The SYNC thinner emits at most one packet per period, so the output
        // rate is capped by both the input rate and the sampling rate.
        const double sampling_rate = kTimebaseUs / period_us;
        new_frame_rate = video_header.frame_rate < sampling_rate
                             ? video_header.frame_rate
                             : sampling_rate;
      }
      std::unique_ptr<VideoHeader> header(new VideoHeader);
      header->format = video_header.format;
      header->width = video_header.width;
      header->height = video_header.height;
      header->frame_rate = new_frame_rate;
      cc->Outputs().Index(0).SetHeader(Adopt(header.release()));
    } else {
      // Frame rate update not requested: forward the header untouched.
      cc->Outputs().Index(0).SetHeader(cc->Inputs().Index(0).Header());
    }
  }
  return absl::OkStatus();
}
// Flushes the packet that the sync thinner may still be holding back.
absl::Status PacketThinnerCalculator::Close(CalculatorContext* cc) {
  if (saved_packet_.IsEmpty()) {
    // Nothing pending.
    return absl::OkStatus();
  }

  // A pending packet can only originate from the sync thinner.
  CHECK_EQ(PacketThinnerCalculatorOptions::SYNC, thinner_type_);
  if (!sync_output_timestamps_) {
    // Emit at the packet's original timestamp.
    cc->Outputs().Index(0).AddPacket(saved_packet_);
  } else {
    // Emit at the center of the period the packet belongs to.
    cc->Outputs().Index(0).AddPacket(
        saved_packet_.At(NearestSyncTimestamp(saved_packet_.Timestamp())));
  }
  return absl::OkStatus();
}
// Async thinning: forwards the current packet unless it falls inside the
// refractory period started by the previously emitted packet.
absl::Status PacketThinnerCalculator::AsyncThinnerProcess(
    CalculatorContext* cc) {
  if (cc->InputTimestamp() < next_valid_timestamp_) {
    // Still inside the refractory period: suppress the packet.
    return absl::OkStatus();
  }

  // Forward the packet unchanged and start a new refractory period.
  cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
  next_valid_timestamp_ = cc->InputTimestamp() + period_;
  // No packet will be emitted before the refractory period is over.
  cc->Outputs().Index(0).SetNextTimestampBound(next_valid_timestamp_);
  return absl::OkStatus();
}
// Sync thinning: keeps the best candidate of the current period in
// saved_packet_ and emits it once a packet of a later period shows up.
absl::Status PacketThinnerCalculator::SyncThinnerProcess(
    CalculatorContext* cc) {
  if (saved_packet_.IsEmpty()) {
    // First candidate seen so far: just remember it.
    saved_packet_ = cc->Inputs().Index(0).Value();
    if (sync_output_timestamps_) {
      cc->Outputs().Index(0).SetNextTimestampBound(
          NearestSyncTimestamp(cc->InputTimestamp()));
    } else {
      cc->Outputs().Index(0).SetNextTimestampBound(cc->InputTimestamp());
    }
    return absl::OkStatus();
  }

  // A candidate is pending: either refine it or emit it.
  const Timestamp pending_ts = saved_packet_.Timestamp();
  const Timestamp pending_center = NearestSyncTimestamp(pending_ts);
  const Timestamp current_ts = cc->InputTimestamp();
  const Timestamp current_center = NearestSyncTimestamp(current_ts);
  CHECK_LE(pending_center, current_center);

  if (pending_center != current_center) {
    // The current packet opens a new period, so the pending packet is the
    // best one of its (earlier) period: emit it.
    if (sync_output_timestamps_) {
      cc->Outputs().Index(0).AddPacket(saved_packet_.At(pending_center));
      cc->Outputs().Index(0).SetNextTimestampBound(current_center);
    } else {
      cc->Outputs().Index(0).AddPacket(saved_packet_);
      cc->Outputs().Index(0).SetNextTimestampBound(current_ts);
    }
    // The current packet becomes the first candidate of the new period.
    saved_packet_ = cc->Inputs().Index(0).Value();
    return absl::OkStatus();
  }

  // Both packets belong to the same period. Keep whichever is at least as
  // close to the period center; ties go to the fresher (current) packet.
  if (abs(current_ts - current_center) <= abs(pending_ts - pending_center)) {
    saved_packet_ = cc->Inputs().Index(0).Value();
  }
  return absl::OkStatus();
}
// Rounds `now` to the closest timestamp of the form
// start_time_ + i * period_ and returns it.
Timestamp PacketThinnerCalculator::NearestSyncTimestamp(Timestamp now) const {
  CHECK_NE(start_time_, Timestamp::Unset())
      << "Method only valid for sync thinner calculator.";

  // Timestamps offer neither division nor multiplication, so the rounding is
  // carried out on the raw int64 values.
  const int64 now64 = now.Value();
  const int64 start64 = start_time_.Value();
  const int64 period64 = period_.Value();
  CHECK_LE(0, period64);

  // Shift by half a period so that the integer division rounds to the
  // nearest period index, then map that index back to a timestamp.
  const int64 shifted64 = now64 - start64 + period64 / 2;
  int64 sync64 = shifted64 / period64 * period64 + start64;

  CHECK_LE(abs(now64 - sync64), period64 / 2)
      << "start64: " << start64 << "; now64: " << now64
      << "; sync64: " << sync64;
  return Timestamp(sync64);
}
} // namespace mediapipe

View File

@ -1,68 +0,0 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
option objc_class_prefix = "MediaPipe";
// Options for PacketThinnerCalculator, which emits at most one packet of its
// input stream per period.
message PacketThinnerCalculatorOptions {
extend CalculatorOptions {
optional PacketThinnerCalculatorOptions ext = 288533508;
}
// Selects the thinning algorithm; see the description of `period` below.
enum ThinnerType {
ASYNC = 1; // Asynchronous thinner, described below [default].
SYNC = 2; // Synchronous thinner, also described below.
}
optional ThinnerType thinner_type = 1 [default = ASYNC];
// The period (in microseconds) specifies the temporal interval during which
// only a single packet is emitted in the output stream. Has subtly different
// semantics depending on the thinner type, as follows.
//
// Async thinner: this option is a refractory period -- once a packet is
// emitted, we guarantee that no packets will be emitted for period ticks.
//
// Sync thinner: the period specifies a temporal interval during which
// only one packet is emitted. The emitted packet is guaranteed to be
// the one closest to the center of the temporal interval (on ties,
// PacketThinnerCalculator currently keeps the later packet). More
// specifically,
// intervals are centered at start_time + i * period
// (for non-negative integers i).
// Thus, each interval extends period/2 ticks before and after its center.
// Additionally, in the sync thinner any packets earlier than start_time
// are discarded and the thinner calls Close() once timestamp equals or
// exceeds end_time.
optional int64 period = 2 [default = 1];
// Packets before start_time and at/after end_time are discarded.
// Additionally, for a sync thinner, start time specifies the center of
// time intervals as described above and therefore should be set explicitly.
optional int64 start_time = 3; // If not specified, set to 0 for SYNC type,
// and set to Timestamp::Min() for ASYNC type.
optional int64 end_time = 4; // Set to Timestamp::Max() if not specified.
// Whether the timestamps of packets emitted by sync thinner should
// correspond to the center of their corresponding temporal interval.
// If false, packets are emitted using their original timestamp (as in the
// async thinner).
optional bool sync_output_timestamps = 5 [default = true];
// If true, update the frame rate in the header, if it's available, to an
// estimated frame rate due to the sampling.
optional bool update_frame_rate = 6 [default = false];
}

View File

@ -1,357 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/core/packet_thinner_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/video_stream_header.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
// A simple version of CalculatorRunner with built-in convenience methods for
// setting inputs from a vector and checking outputs against a vector of
// expected outputs.
class SimpleRunner : public CalculatorRunner {
public:
explicit SimpleRunner(const CalculatorOptions& options)
: CalculatorRunner("PacketThinnerCalculator", options) {
SetNumInputs(1);
SetNumOutputs(1);
SetNumInputSidePackets(0);
}
explicit SimpleRunner(const CalculatorGraphConfig::Node& node)
: CalculatorRunner(node) {}
void SetInput(const std::vector<int>& timestamp_list) {
MutableInputs()->Index(0).packets.clear();
for (const int ts : timestamp_list) {
MutableInputs()->Index(0).packets.push_back(
MakePacket<std::string>(absl::StrCat("Frame #", ts))
.At(Timestamp(ts)));
}
}
void SetFrameRate(const double frame_rate) {
auto video_header = absl::make_unique<VideoHeader>();
video_header->frame_rate = frame_rate;
MutableInputs()->Index(0).header = Adopt(video_header.release());
}
std::vector<int64> GetOutputTimestamps() const {
std::vector<int64> timestamps;
for (const Packet& packet : Outputs().Index(0).packets) {
timestamps.emplace_back(packet.Timestamp().Value());
}
return timestamps;
}
double GetFrameRate() const {
CHECK(!Outputs().Index(0).header.IsEmpty());
return Outputs().Index(0).header.Get<VideoHeader>().frame_rate;
}
};
// Verifies that the start_time and end_time options are honored.
// A single thinner type suffices here, since the start/end handling lives in
// Process() and is shared by both thinner types.
TEST(PacketThinnerCalculatorTest, StartAndEndTimeTest) {
  CalculatorOptions calculator_options;
  PacketThinnerCalculatorOptions* thinner_options =
      calculator_options.MutableExtension(PacketThinnerCalculatorOptions::ext);
  thinner_options->set_start_time(4);
  thinner_options->set_end_time(12);
  thinner_options->set_period(5);
  thinner_options->set_thinner_type(PacketThinnerCalculatorOptions::ASYNC);

  SimpleRunner runner(calculator_options);
  runner.SetInput({2, 3, 5, 7, 11, 13, 17, 19, 23, 29});
  MP_ASSERT_OK(runner.Run());

  // 2 and 3 precede start_time, 7 falls into the period following 5, and
  // everything from 13 onwards is at or after end_time.
  const std::vector<int64> want = {5, 11};
  EXPECT_EQ(want, runner.GetOutputTimestamps());
}
// ASYNC thinner on a uniformly spaced stream (spacing 2, period 5).
TEST(PacketThinnerCalculatorTest, AsyncUniformStreamThinningTest) {
  CalculatorOptions calculator_options;
  PacketThinnerCalculatorOptions* thinner_options =
      calculator_options.MutableExtension(PacketThinnerCalculatorOptions::ext);
  thinner_options->set_period(5);
  thinner_options->set_thinner_type(PacketThinnerCalculatorOptions::ASYNC);

  SimpleRunner runner(calculator_options);
  runner.SetInput({2, 4, 6, 8, 10, 12, 14});
  MP_ASSERT_OK(runner.Run());

  // After each emitted packet, the packets of the next 5 ticks are dropped.
  const std::vector<int64> want = {2, 8, 14};
  EXPECT_EQ(want, runner.GetOutputTimestamps());
}
// ASYNC thinner on a uniformly spaced stream, with the period (5) delivered
// through the "PERIOD" input side packet instead of the options.
TEST(PacketThinnerCalculatorTest, ASyncUniformStreamThinningTestBySidePacket) {
  // Note: async runner; packets keep their *original* timestamps.
  CalculatorGraphConfig::Node node;
  node.set_calculator("PacketThinnerCalculator");
  node.add_input_side_packet("PERIOD:period");
  node.add_input_stream("input_stream");
  node.add_output_stream("output_stream");

  PacketThinnerCalculatorOptions* thinner_options =
      node.mutable_options()->MutableExtension(
          PacketThinnerCalculatorOptions::ext);
  thinner_options->set_thinner_type(PacketThinnerCalculatorOptions::ASYNC);
  thinner_options->set_start_time(0);
  thinner_options->set_sync_output_timestamps(false);

  SimpleRunner runner(node);
  runner.MutableSidePackets()->Tag("PERIOD") = MakePacket<int64>(5);
  runner.SetInput({2, 4, 6, 8, 10, 12, 14});
  MP_ASSERT_OK(runner.Run());

  const std::vector<int64> want = {2, 8, 14};
  EXPECT_EQ(want, runner.GetOutputTimestamps());
}
// SYNC thinner on a uniformly spaced stream, emitting the selected packets at
// their *original* timestamps.
TEST(PacketThinnerCalculatorTest, SyncUniformStreamThinningTest1) {
  CalculatorOptions calculator_options;
  PacketThinnerCalculatorOptions* thinner_options =
      calculator_options.MutableExtension(PacketThinnerCalculatorOptions::ext);
  thinner_options->set_thinner_type(PacketThinnerCalculatorOptions::SYNC);
  thinner_options->set_sync_output_timestamps(false);
  thinner_options->set_start_time(0);
  thinner_options->set_period(5);

  SimpleRunner runner(calculator_options);
  runner.SetInput({2, 4, 6, 8, 10, 12, 14});
  MP_ASSERT_OK(runner.Run());

  // One packet per period centered at 0, 5, 10 and 15; 4 and 6 are equally
  // close to 5 and the later one (6) is kept.
  const std::vector<int64> want = {2, 6, 10, 14};
  EXPECT_EQ(want, runner.GetOutputTimestamps());
}
// SYNC thinner emitting *original* timestamps, with the period (5) delivered
// through the "PERIOD" input side packet instead of the options.
TEST(PacketThinnerCalculatorTest, SyncUniformStreamThinningTestBySidePacket1) {
  CalculatorGraphConfig::Node node;
  node.set_calculator("PacketThinnerCalculator");
  node.add_input_side_packet("PERIOD:period");
  node.add_input_stream("input_stream");
  node.add_output_stream("output_stream");

  PacketThinnerCalculatorOptions* thinner_options =
      node.mutable_options()->MutableExtension(
          PacketThinnerCalculatorOptions::ext);
  thinner_options->set_thinner_type(PacketThinnerCalculatorOptions::SYNC);
  thinner_options->set_start_time(0);
  thinner_options->set_sync_output_timestamps(false);

  SimpleRunner runner(node);
  runner.MutableSidePackets()->Tag("PERIOD") = MakePacket<int64>(5);
  runner.SetInput({2, 4, 6, 8, 10, 12, 14});
  MP_ASSERT_OK(runner.Run());

  const std::vector<int64> want = {2, 6, 10, 14};
  EXPECT_EQ(want, runner.GetOutputTimestamps());
}
// Same stream and SYNC settings as SyncUniformStreamThinningTest1, but the
// selected packets are emitted at the synced (period center) timestamps.
TEST(PacketThinnerCalculatorTest, SyncUniformStreamThinningTest2) {
  CalculatorOptions calculator_options;
  PacketThinnerCalculatorOptions* thinner_options =
      calculator_options.MutableExtension(PacketThinnerCalculatorOptions::ext);
  thinner_options->set_thinner_type(PacketThinnerCalculatorOptions::SYNC);
  thinner_options->set_sync_output_timestamps(true);
  thinner_options->set_start_time(0);
  thinner_options->set_period(5);

  SimpleRunner runner(calculator_options);
  runner.SetInput({2, 4, 6, 8, 10, 12, 14});
  MP_ASSERT_OK(runner.Run());

  // Outputs sit on the period centers start_time + i * period.
  const std::vector<int64> want = {0, 5, 10, 15};
  EXPECT_EQ(want, runner.GetOutputTimestamps());
}
// Feeds a stream whose timestamps are the first ten prime numbers through the
// ASYNC thinner with a period of 5 and checks the timestamps of the thinned
// stream.
TEST(PacketThinnerCalculatorTest, PrimeStreamThinningTest1) {
  CalculatorOptions calculator_options;
  PacketThinnerCalculatorOptions* thinner_options =
      calculator_options.MutableExtension(PacketThinnerCalculatorOptions::ext);
  thinner_options->set_period(5);
  thinner_options->set_thinner_type(PacketThinnerCalculatorOptions::ASYNC);

  SimpleRunner runner(calculator_options);
  runner.SetInput({2, 3, 5, 7, 11, 13, 17, 19, 23, 29});
  MP_ASSERT_OK(runner.Run());

  // Each emitted packet suppresses the packets of the following 5 ticks.
  const std::vector<int64> want = {2, 7, 13, 19, 29};
  EXPECT_EQ(want, runner.GetOutputTimestamps());
}
// Prime-number stream through the SYNC thinner (period 5, start_time 0),
// emitting the selected packets at their original timestamps.
TEST(PacketThinnerCalculatorTest, PrimeStreamThinningTest2) {
  CalculatorOptions calculator_options;
  PacketThinnerCalculatorOptions* thinner_options =
      calculator_options.MutableExtension(PacketThinnerCalculatorOptions::ext);
  thinner_options->set_thinner_type(PacketThinnerCalculatorOptions::SYNC);
  thinner_options->set_sync_output_timestamps(false);
  thinner_options->set_start_time(0);
  thinner_options->set_period(5);

  SimpleRunner runner(calculator_options);
  runner.SetInput({2, 3, 5, 7, 11, 13, 17, 19, 23, 29});
  MP_ASSERT_OK(runner.Run());

  // 3, 5 and 7 share the period centered at 5 (5 wins); 13 and 17 are equally
  // far from 15 and the later one (17) is kept.
  const std::vector<int64> want = {2, 5, 11, 17, 19, 23, 29};
  EXPECT_EQ(want, runner.GetOutputTimestamps());
}
// Boundary case for the SYNC thinner: odd period, negative start_time.
TEST(PacketThinnerCalculatorTest, BoundaryTimestampTest1) {
  CalculatorOptions calculator_options;
  PacketThinnerCalculatorOptions* thinner_options =
      calculator_options.MutableExtension(PacketThinnerCalculatorOptions::ext);
  thinner_options->set_thinner_type(PacketThinnerCalculatorOptions::SYNC);
  thinner_options->set_sync_output_timestamps(true);
  thinner_options->set_start_time(-10);
  thinner_options->set_period(5);

  SimpleRunner runner(calculator_options);
  // The two timestamps lie on opposite sides of a period boundary.
  runner.SetInput({2, 3});
  MP_ASSERT_OK(runner.Run());

  // Hence each packet is emitted, at the center of its own period.
  const std::vector<int64> want = {0, 5};
  EXPECT_EQ(want, runner.GetOutputTimestamps());
}
// Boundary case for the SYNC thinner: even period, negative start_time and
// negative packet timestamps.
TEST(PacketThinnerCalculatorTest, BoundaryTimestampTest2) {
  CalculatorOptions calculator_options;
  PacketThinnerCalculatorOptions* thinner_options =
      calculator_options.MutableExtension(PacketThinnerCalculatorOptions::ext);
  thinner_options->set_thinner_type(PacketThinnerCalculatorOptions::SYNC);
  thinner_options->set_sync_output_timestamps(true);
  thinner_options->set_start_time(-144);
  thinner_options->set_period(6);

  SimpleRunner runner(calculator_options);
  // Two pairs of timestamps, each pair straddling a period boundary.
  runner.SetInput({-4, -3, 8, 9});
  MP_ASSERT_OK(runner.Run());

  const std::vector<int64> want = {-6, 0, 6, 12};
  EXPECT_EQ(want, runner.GetOutputTimestamps());
}
// ASYNC thinner with update_frame_rate: input frames are 2 ticks apart and the
// period is 5, so the output header must report one frame every 6 ticks.
TEST(PacketThinnerCalculatorTest, FrameRateTest1) {
  CalculatorOptions calculator_options;
  PacketThinnerCalculatorOptions* thinner_options =
      calculator_options.MutableExtension(PacketThinnerCalculatorOptions::ext);
  thinner_options->set_update_frame_rate(true);
  thinner_options->set_period(5);
  thinner_options->set_thinner_type(PacketThinnerCalculatorOptions::ASYNC);

  SimpleRunner runner(calculator_options);
  runner.SetFrameRate(1000000.0 / 2);
  runner.SetInput({2, 4, 6, 8, 10, 12, 14});
  MP_ASSERT_OK(runner.Run());

  const std::vector<int64> want = {2, 8, 14};
  EXPECT_EQ(want, runner.GetOutputTimestamps());
  // The true sampling period is 6.
  EXPECT_DOUBLE_EQ(1000000.0 / 6, runner.GetFrameRate());
}
// ASYNC thinner with update_frame_rate: input frames are 8 ticks apart, which
// exceeds the period of 5, so nothing is dropped and the frame rate is kept.
TEST(PacketThinnerCalculatorTest, FrameRateTest2) {
  CalculatorOptions calculator_options;
  PacketThinnerCalculatorOptions* thinner_options =
      calculator_options.MutableExtension(PacketThinnerCalculatorOptions::ext);
  thinner_options->set_update_frame_rate(true);
  thinner_options->set_period(5);
  thinner_options->set_thinner_type(PacketThinnerCalculatorOptions::ASYNC);

  SimpleRunner runner(calculator_options);
  runner.SetFrameRate(1000000.0 / 8);
  runner.SetInput({8, 16, 24, 32, 40, 48, 56});
  MP_ASSERT_OK(runner.Run());

  const std::vector<int64> want = {8, 16, 24, 32, 40, 48, 56};
  EXPECT_EQ(want, runner.GetOutputTimestamps());
  // The true sampling period is still 8.
  EXPECT_DOUBLE_EQ(1000000.0 / 8, runner.GetFrameRate());
}
// SYNC thinner with update_frame_rate, emitting *original* timestamps: the
// output header must report the sampling rate implied by the period.
TEST(PacketThinnerCalculatorTest, FrameRateTest3) {
  CalculatorOptions calculator_options;
  PacketThinnerCalculatorOptions* thinner_options =
      calculator_options.MutableExtension(PacketThinnerCalculatorOptions::ext);
  thinner_options->set_thinner_type(PacketThinnerCalculatorOptions::SYNC);
  thinner_options->set_sync_output_timestamps(false);
  thinner_options->set_update_frame_rate(true);
  thinner_options->set_start_time(0);
  thinner_options->set_period(5);

  SimpleRunner runner(calculator_options);
  runner.SetFrameRate(1000000.0 / 2);
  runner.SetInput({2, 4, 6, 8, 10, 12, 14});
  MP_ASSERT_OK(runner.Run());

  const std::vector<int64> want = {2, 6, 10, 14};
  EXPECT_EQ(want, runner.GetOutputTimestamps());
  // The true (long-run) sampling period is 5.
  EXPECT_DOUBLE_EQ(1000000.0 / 5, runner.GetFrameRate());
}
// Same as FrameRateTest3, but the packets are emitted at synced timestamps.
TEST(PacketThinnerCalculatorTest, FrameRateTest4) {
  CalculatorOptions calculator_options;
  PacketThinnerCalculatorOptions* thinner_options =
      calculator_options.MutableExtension(PacketThinnerCalculatorOptions::ext);
  thinner_options->set_thinner_type(PacketThinnerCalculatorOptions::SYNC);
  thinner_options->set_sync_output_timestamps(true);
  thinner_options->set_update_frame_rate(true);
  thinner_options->set_start_time(0);
  thinner_options->set_period(5);

  SimpleRunner runner(calculator_options);
  runner.SetFrameRate(1000000.0 / 2);
  runner.SetInput({2, 4, 6, 8, 10, 12, 14});
  MP_ASSERT_OK(runner.Run());

  const std::vector<int64> want = {0, 5, 10, 15};
  EXPECT_EQ(want, runner.GetOutputTimestamps());
  // The true (long-run) sampling period is 5.
  EXPECT_DOUBLE_EQ(1000000.0 / 5, runner.GetFrameRate());
}
// SYNC thinner with update_frame_rate and synced timestamps on a stream that
// is sparser (8 ticks apart) than the period (5): every packet is emitted and
// the input frame rate is preserved.
TEST(PacketThinnerCalculatorTest, FrameRateTest5) {
  CalculatorOptions calculator_options;
  PacketThinnerCalculatorOptions* thinner_options =
      calculator_options.MutableExtension(PacketThinnerCalculatorOptions::ext);
  thinner_options->set_thinner_type(PacketThinnerCalculatorOptions::SYNC);
  thinner_options->set_sync_output_timestamps(true);
  thinner_options->set_update_frame_rate(true);
  thinner_options->set_start_time(0);
  thinner_options->set_period(5);

  SimpleRunner runner(calculator_options);
  runner.SetFrameRate(1000000.0 / 8);
  runner.SetInput({8, 16, 24, 32, 40, 48, 56});
  MP_ASSERT_OK(runner.Run());

  // Each input lands in its own period and is moved to that period's center.
  const std::vector<int64> want = {10, 15, 25, 30, 40, 50, 55};
  EXPECT_EQ(want, runner.GetOutputTimestamps());
  // The true (long-run) sampling period is 8.
  EXPECT_DOUBLE_EQ(1000000.0 / 8, runner.GetFrameRate());
}
} // namespace
} // namespace mediapipe

View File

@ -1,98 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
namespace mediapipe {
// PassThroughCalculator forwards every input stream packet (and stream
// header) to the output stream with the same tag/index, unmodified. Output
// streams must mirror the input streams exactly. Input side packets are
// optional; when output side packets are declared they must mirror the input
// side packets and are forwarded unmodified as well, otherwise the input side
// packets are accepted and ignored (which lets this calculator be used to
// test internal behavior). Options of any kind are accepted and ignored.
class PassThroughCalculator : public CalculatorBase {
 public:
  // Declares "any" types for all inputs and ties each output type to the
  // matching input. Fails when the tag maps of inputs and outputs differ.
  static absl::Status GetContract(CalculatorContract* cc) {
    auto& inputs = cc->Inputs();
    auto& outputs = cc->Outputs();
    const bool streams_match = inputs.TagMap()->SameAs(*outputs.TagMap());
    if (!streams_match) {
      return absl::InvalidArgumentError(
          "Input and output streams to PassThroughCalculator must use "
          "matching tags and indexes.");
    }
    for (CollectionItemId id = inputs.BeginId(); id < inputs.EndId(); ++id) {
      auto& input_type = inputs.Get(id);
      input_type.SetAny();
      outputs.Get(id).SetSameAs(&input_type);
    }

    auto& side_inputs = cc->InputSidePackets();
    for (CollectionItemId id = side_inputs.BeginId();
         id < side_inputs.EndId(); ++id) {
      side_inputs.Get(id).SetAny();
    }

    auto& side_outputs = cc->OutputSidePackets();
    if (side_outputs.NumEntries() == 0) {
      // No output side packets declared: input side packets are ignored.
      return absl::OkStatus();
    }
    if (!side_inputs.TagMap()->SameAs(*side_outputs.TagMap())) {
      return absl::InvalidArgumentError(
          "Input and output side packets to PassThroughCalculator must use "
          "matching tags and indexes.");
    }
    for (CollectionItemId id = side_inputs.BeginId();
         id < side_inputs.EndId(); ++id) {
      side_outputs.Get(id).SetSameAs(&side_inputs.Get(id));
    }
    return absl::OkStatus();
  }

  // Copies non-empty stream headers and (if declared) side packets to the
  // outputs, then sets a zero timestamp offset.
  absl::Status Open(CalculatorContext* cc) final {
    auto& inputs = cc->Inputs();
    for (CollectionItemId id = inputs.BeginId(); id < inputs.EndId(); ++id) {
      const auto& header = inputs.Get(id).Header();
      if (header.IsEmpty()) {
        continue;
      }
      cc->Outputs().Get(id).SetHeader(header);
    }
    if (cc->OutputSidePackets().NumEntries() != 0) {
      const auto& side_inputs = cc->InputSidePackets();
      for (CollectionItemId id = side_inputs.BeginId();
           id < side_inputs.EndId(); ++id) {
        cc->OutputSidePackets().Get(id).Set(side_inputs.Get(id));
      }
    }
    cc->SetOffset(TimestampDiff(0));
    return absl::OkStatus();
  }

  // Forwards every non-empty input packet to the output with the same id.
  // Returns tool::StatusStop() when the node has no input streams at all.
  absl::Status Process(CalculatorContext* cc) final {
    cc->GetCounter("PassThrough")->Increment();
    auto& inputs = cc->Inputs();
    if (inputs.NumEntries() == 0) {
      return tool::StatusStop();
    }
    for (CollectionItemId id = inputs.BeginId(); id < inputs.EndId(); ++id) {
      auto& input = inputs.Get(id);
      if (input.IsEmpty()) {
        continue;
      }
      auto& output = cc->Outputs().Get(id);
      VLOG(3) << "Passing " << input.Name() << " to " << output.Name()
              << " at " << cc->InputTimestamp().DebugString();
      output.AddPacket(input.Value());
    }
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(PassThroughCalculator);
} // namespace mediapipe

View File

@ -1,177 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <deque>
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/timestamp.h"
namespace mediapipe {
namespace api2 {
// PreviousLoopbackCalculator is useful when a graph needs to process an input
// together with some previous output.
//
// For the first packet that arrives on the MAIN input, the timestamp bound is
// advanced on the PREV_LOOP. Downstream calculators will see this as an empty
// packet. This way they are not kept waiting for the previous output, which
// for the first iteration does not exist.
//
// Thereafter,
// - Each non-empty MAIN packet results in:
// a) a PREV_LOOP packet with contents of the LOOP packet received at the
// timestamp of the previous non-empty MAIN packet
// b) or in a PREV_LOOP timestamp bound update if the LOOP packet was empty.
// - Each empty MAIN packet indicating timestamp bound update results in a
// PREV_LOOP timestamp bound update.
//
// Example config:
// node {
// calculator: "PreviousLoopbackCalculator"
// input_stream: "MAIN:input"
// input_stream: "LOOP:output"
// input_stream_info: { tag_index: 'LOOP' back_edge: true }
// output_stream: "PREV_LOOP:prev_output"
// }
// node {
// calculator: "FaceTracker"
// input_stream: "VIDEO:input"
// input_stream: "PREV_TRACK:prev_output"
// output_stream: "TRACK:output"
// }
// Node implementation: queues MAIN "specs" and LOOP packets as they arrive and
// pairs them up in timestamp order, emitting the result on PREV_LOOP.
class PreviousLoopbackCalculator : public Node {
public:
// Stream declarations. PREV_LOOP is declared with the same type as LOOP.
static constexpr Input<AnyType> kMain{"MAIN"};
static constexpr Input<AnyType> kLoop{"LOOP"};
static constexpr Output<SameType<kLoop>> kPrevLoop{"PREV_LOOP"};
// TODO: an optional PREV_TIMESTAMP output could be added to
// carry the original timestamp of the packet on PREV_LOOP.
MEDIAPIPE_NODE_CONTRACT(kMain, kLoop, kPrevLoop,
StreamHandler("ImmediateInputStreamHandler"),
TimestampChange::Arbitrary());
// Requests that Process() also be called for timestamp bound updates.
static absl::Status UpdateContract(CalculatorContract* cc) {
// Process() function is invoked in response to MAIN/LOOP stream timestamp
// bound updates.
cc->SetProcessTimestampBounds(true);
return absl::OkStatus();
}
// Forwards the LOOP stream header to PREV_LOOP.
absl::Status Open(CalculatorContext* cc) final {
kPrevLoop(cc).SetHeader(kLoop(cc).Header());
return absl::OkStatus();
}
// Records any new MAIN/LOOP input, then drains the two queues for as long as
// both are non-empty, producing either a packet or a timestamp bound update
// on PREV_LOOP for each MAIN spec that gets resolved.
absl::Status Process(CalculatorContext* cc) final {
// Non-empty packets and empty packets indicating timestamp bound updates
// are guaranteed to have timestamps greater than timestamps of previous
// packets within the same stream. Calculator tracks and operates on such
// packets.
const PacketBase& main_packet = kMain(cc).packet();
// Only a MAIN timestamp strictly newer than the last recorded one creates a
// new spec.
if (prev_main_ts_ < main_packet.timestamp()) {
Timestamp loop_timestamp;
if (!main_packet.IsEmpty()) {
// A non-empty MAIN packet expects the LOOP packet stamped with the previous
// non-empty MAIN timestamp.
loop_timestamp = prev_non_empty_main_ts_;
prev_non_empty_main_ts_ = main_packet.timestamp();
} else {
// Calculator advances PREV_LOOP timestamp bound in response to empty
// MAIN packet, hence not caring about corresponding loop packet.
loop_timestamp = Timestamp::Unset();
}
main_packet_specs_.push_back({main_packet.timestamp(), loop_timestamp});
prev_main_ts_ = main_packet.timestamp();
}
const PacketBase& loop_packet = kLoop(cc).packet();
// Likewise, only a LOOP timestamp strictly newer than the last one is queued.
if (prev_loop_ts_ < loop_packet.timestamp()) {
loop_packets_.push_back(loop_packet);
prev_loop_ts_ = loop_packet.timestamp();
}
while (!main_packet_specs_.empty() && !loop_packets_.empty()) {
// The earliest MAIN packet.
// Copied by value: main_spec.timestamp is still read after pop_front() below.
MainPacketSpec main_spec = main_packet_specs_.front();
// The earliest LOOP packet.
const PacketBase& loop_candidate = loop_packets_.front();
// Match LOOP and MAIN packets.
if (main_spec.loop_timestamp < loop_candidate.timestamp()) {
// No LOOP packet can match the MAIN packet under review.
// NOTE(review): specs created for empty MAIN packets carry
// Timestamp::Unset() and presumably always land in this branch - confirm
// against Timestamp ordering.
kPrevLoop(cc).SetNextTimestampBound(main_spec.timestamp + 1);
main_packet_specs_.pop_front();
} else if (main_spec.loop_timestamp > loop_candidate.timestamp()) {
// No MAIN packet can match the LOOP packet under review.
loop_packets_.pop_front();
} else {
// Exact match found.
if (loop_candidate.IsEmpty()) {
// However, LOOP packet is empty.
kPrevLoop(cc).SetNextTimestampBound(main_spec.timestamp + 1);
} else {
// Re-stamp the LOOP packet with the MAIN timestamp before sending.
kPrevLoop(cc).Send(loop_candidate.At(main_spec.timestamp));
}
loop_packets_.pop_front();
main_packet_specs_.pop_front();
}
// We can close PREV_LOOP output stream as soon as we processed last
// possible MAIN packet. That can happen in two cases:
// a) Non-empty MAIN packet has been received with Timestamp::Max()
// b) Empty MAIN packet has been received with Timestamp::Max() indicating
// MAIN is done.
if (main_spec.timestamp == Timestamp::Done().PreviousAllowedInStream()) {
kPrevLoop(cc).Close();
}
}
return absl::OkStatus();
}
private:
// Describes one pending MAIN timestamp and which LOOP timestamp resolves it.
struct MainPacketSpec {
Timestamp timestamp;
// Expected timestamp of the packet from LOOP stream that corresponds to the
// packet from MAIN stream described by this spec.
Timestamp loop_timestamp;
};
// Contains specs for MAIN packets which only can be:
// - non-empty packets
// - empty packets indicating timestamp bound updates
//
// Sorted according to packet timestamps.
std::deque<MainPacketSpec> main_packet_specs_;
// Timestamp of the most recent MAIN packet (empty or not) that was recorded.
Timestamp prev_main_ts_ = Timestamp::Unstarted();
// Timestamp of the most recent non-empty MAIN packet that was recorded.
Timestamp prev_non_empty_main_ts_ = Timestamp::Unstarted();
// Contains LOOP packets which only can be:
// - the very first empty packet
// - non empty packets
// - empty packets indicating timestamp bound updates
//
// Sorted according to packet timestamps.
std::deque<PacketBase> loop_packets_;
// Using "Timestamp::Unset" instead of "Timestamp::Unstarted" in order to
// allow addition of the very first empty packet (which doesn't indicate
// timestamp bound change necessarily).
Timestamp prev_loop_ts_ = Timestamp::Unset();
};
MEDIAPIPE_REGISTER_NODE(PreviousLoopbackCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -1,867 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/framework/tool/sink.h"
namespace mediapipe {
using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::Pair;
using ::testing::Value;
namespace {
// Returns the timestamp values for a vector of Packets.
// TODO: puth this kind of test util in a common place.
std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
std::vector<int64> result;
for (const Packet& packet : packets) {
result.push_back(packet.Timestamp().Value());
}
return result;
}
// Matches any packet-like argument whose IsEmpty() returns true.
MATCHER(EmptyPacket, negation ? "isn't empty" : "is empty") {
  return arg.IsEmpty();
}
// Matches a packet whose int payload is equal to `value`.
MATCHER_P(IntPacket, value, "") {
  const int& payload = arg.template Get<int>();
  return Value(payload, Eq(value));
}
// Matches a packet stamped with `timestamp` whose std::pair<Packet, Packet>
// payload satisfies the `pair` matcher.
MATCHER_P2(PairPacket, timestamp, pair, "") {
  const Timestamp packet_time = arg.Timestamp();
  // The payload is fetched before the timestamp is checked.
  const auto& payload = arg.template Get<std::pair<Packet, Packet>>();
  if (!Value(packet_time, Eq(timestamp))) {
    return false;
  }
  return Value(payload, pair);
}
// Sends packets at timestamps 1, 2, 5 and 15 and checks that each output pair
// is stamped with the current timestamp and combines the current input with
// the output produced for the previous input (empty for the first one).
TEST(PreviousLoopbackCalculator, CorrectTimestamps) {
  std::vector<Packet> in_prev;
  CalculatorGraphConfig graph_config_ =
      ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in'
node {
calculator: 'PreviousLoopbackCalculator'
input_stream: 'MAIN:in'
input_stream: 'LOOP:out'
input_stream_info: { tag_index: 'LOOP' back_edge: true }
output_stream: 'PREV_LOOP:previous'
}
# This calculator synchronizes its inputs as normal, so it is used
# to check that both "in" and "previous" are ready.
node {
calculator: 'PassThroughCalculator'
input_stream: 'in'
input_stream: 'previous'
output_stream: 'out'
output_stream: 'previous2'
}
node {
calculator: 'MakePairCalculator'
input_stream: 'out'
input_stream: 'previous2'
output_stream: 'pair'
}
)pb");
  tool::AddVectorSink("pair", &graph_config_, &in_prev);
  CalculatorGraph graph_;
  MP_ASSERT_OK(graph_.Initialize(graph_config_, {}));
  MP_ASSERT_OK(graph_.StartRun({}));
  // Sends the int `n` to `input_name` at Timestamp(n).
  auto send_packet = [&graph_](const std::string& input_name, int n) {
    MP_EXPECT_OK(graph_.AddPacketToInputStream(
        input_name, MakePacket<int>(n).At(Timestamp(n))));
  };

  // The timestamp-list checks are fatal (ASSERT_THAT): in_prev.back() right
  // after them is undefined behavior if the sink is unexpectedly empty.
  send_packet("in", 1);
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  ASSERT_THAT(TimestampValues(in_prev), ElementsAre(1));
  EXPECT_THAT(in_prev.back(),
              PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())));

  send_packet("in", 2);
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  ASSERT_THAT(TimestampValues(in_prev), ElementsAre(1, 2));
  EXPECT_THAT(in_prev.back(),
              PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))));

  send_packet("in", 5);
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  ASSERT_THAT(TimestampValues(in_prev), ElementsAre(1, 2, 5));
  EXPECT_THAT(in_prev.back(),
              PairPacket(Timestamp(5), Pair(IntPacket(5), IntPacket(2))));

  send_packet("in", 15);
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  ASSERT_THAT(TimestampValues(in_prev), ElementsAre(1, 2, 5, 15));
  EXPECT_THAT(in_prev.back(),
              PairPacket(Timestamp(15), Pair(IntPacket(15), IntPacket(5))));

  MP_EXPECT_OK(graph_.CloseAllInputStreams());
  MP_EXPECT_OK(graph_.WaitUntilDone());
}
// A Calculator that forwards each int input unchanged and, in
// CalculatorBase::Close(), emits one extra summary packet at Timestamp::Max()
// holding the sum of all inputs it processed.
class PacketOnCloseCalculator : public CalculatorBase {
 public:
  // One int input stream, one int output stream (both at index 0).
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<int>();
    cc->Outputs().Index(0).Set<int>();
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) final {
    cc->SetOffset(TimestampDiff(0));
    return absl::OkStatus();
  }

  // Accumulates the incoming value and passes the packet through.
  absl::Status Process(CalculatorContext* cc) final {
    const auto& incoming = cc->Inputs().Index(0).Value();
    sum_ += incoming.Get<int>();
    cc->Outputs().Index(0).AddPacket(incoming);
    return absl::OkStatus();
  }

  // Emits the running sum as the final packet.
  absl::Status Close(CalculatorContext* cc) final {
    const auto summary = MakePacket<int>(sum_).At(Timestamp::Max());
    cc->Outputs().Index(0).AddPacket(summary);
    return absl::OkStatus();
  }

 private:
  // Sum of every value seen in Process().
  int sum_ = 0;
};
REGISTER_CALCULATOR(PacketOnCloseCalculator);
// Demonstrates that all output and input streams in PreviousLoopbackCalculator
// close as expected once every graph input stream has been closed: the
// downstream PacketOnCloseCalculator then emits its Timestamp::Max() packet.
TEST(PreviousLoopbackCalculator, ClosesCorrectly) {
  std::vector<Packet> sink_packets;
  CalculatorGraphConfig config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in'
node {
calculator: 'PreviousLoopbackCalculator'
input_stream: 'MAIN:in'
input_stream: 'LOOP:out'
input_stream_info: { tag_index: 'LOOP' back_edge: true }
output_stream: 'PREV_LOOP:previous'
}
# This calculator synchronizes its inputs as normal, so it is used
# to check that both "in" and "previous" are ready.
node {
calculator: 'PassThroughCalculator'
input_stream: 'in'
input_stream: 'previous'
output_stream: 'out'
output_stream: 'previous2'
}
node {
calculator: 'PacketOnCloseCalculator'
input_stream: 'out'
output_stream: 'close_out'
}
)pb");
  tool::AddVectorSink("close_out", &config, &sink_packets);

  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config, {}));
  MP_ASSERT_OK(graph.StartRun({}));

  // Feeds the int `n` into `stream` at Timestamp(n).
  auto feed = [&graph](const std::string& stream, int n) {
    const Timestamp when(n);
    MP_EXPECT_OK(
        graph.AddPacketToInputStream(stream, MakePacket<int>(n).At(when)));
  };

  feed("in", 1);
  MP_EXPECT_OK(graph.WaitUntilIdle());
  EXPECT_THAT(TimestampValues(sink_packets), ElementsAre(1));

  feed("in", 2);
  MP_EXPECT_OK(graph.WaitUntilIdle());
  EXPECT_THAT(TimestampValues(sink_packets), ElementsAre(1, 2));

  feed("in", 5);
  MP_EXPECT_OK(graph.WaitUntilIdle());
  EXPECT_THAT(TimestampValues(sink_packets), ElementsAre(1, 2, 5));

  feed("in", 15);
  MP_EXPECT_OK(graph.WaitUntilIdle());
  EXPECT_THAT(TimestampValues(sink_packets), ElementsAre(1, 2, 5, 15));

  // Closing the inputs must let the summary packet through.
  MP_EXPECT_OK(graph.CloseAllInputStreams());
  MP_EXPECT_OK(graph.WaitUntilIdle());
  EXPECT_THAT(TimestampValues(sink_packets),
              ElementsAre(1, 2, 5, 15, Timestamp::Max().Value()));
  MP_EXPECT_OK(graph.WaitUntilDone());
}
// A single MAIN packet at Timestamp::Max() must be processed and paired with
// an empty "previous" packet, and the graph must still shut down cleanly.
TEST(PreviousLoopbackCalculator, ProcessesMaxTimestamp) {
  std::vector<Packet> pairs;
  CalculatorGraphConfig config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in'
node {
calculator: 'PreviousLoopbackCalculator'
input_stream: 'MAIN:in'
input_stream: 'LOOP:out'
input_stream_info: { tag_index: 'LOOP' back_edge: true }
output_stream: 'PREV_LOOP:previous'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'in'
input_stream: 'previous'
output_stream: 'out'
output_stream: 'previous2'
}
node {
calculator: 'MakePairCalculator'
input_stream: 'out'
input_stream: 'previous'
output_stream: 'out_and_previous'
}
)pb");
  tool::AddVectorSink("out_and_previous", &config, &pairs);

  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config, {}));
  MP_ASSERT_OK(graph.StartRun({}));

  const Timestamp last_ts = Timestamp::Max();
  MP_EXPECT_OK(
      graph.AddPacketToInputStream("in", MakePacket<int>(1).At(last_ts)));
  MP_EXPECT_OK(graph.WaitUntilIdle());
  EXPECT_THAT(pairs, ElementsAre(PairPacket(
                         last_ts, Pair(IntPacket(1), EmptyPacket()))));

  MP_EXPECT_OK(graph.CloseAllInputStreams());
  MP_EXPECT_OK(graph.WaitUntilIdle());
  MP_EXPECT_OK(graph.WaitUntilDone());
}
// Sends MAIN packets at Timestamp::Min() and Timestamp::Max() back to back and
// checks that the second one is paired with the output of the first.
TEST(PreviousLoopbackCalculator, ProcessesMaxTimestampNonEmptyPrevious) {
  std::vector<Packet> pairs;
  CalculatorGraphConfig config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in'
node {
calculator: 'PreviousLoopbackCalculator'
input_stream: 'MAIN:in'
input_stream: 'LOOP:out'
input_stream_info: { tag_index: 'LOOP' back_edge: true }
output_stream: 'PREV_LOOP:previous'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'in'
input_stream: 'previous'
output_stream: 'out'
output_stream: 'previous2'
}
node {
calculator: 'MakePairCalculator'
input_stream: 'out'
input_stream: 'previous'
output_stream: 'out_and_previous'
}
)pb");
  tool::AddVectorSink("out_and_previous", &config, &pairs);

  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config, {}));
  MP_ASSERT_OK(graph.StartRun({}));

  const Timestamp first_ts = Timestamp::Min();
  const Timestamp last_ts = Timestamp::Max();
  // Both packets are queued before waiting for the graph to go idle.
  MP_EXPECT_OK(
      graph.AddPacketToInputStream("in", MakePacket<int>(1).At(first_ts)));
  MP_EXPECT_OK(
      graph.AddPacketToInputStream("in", MakePacket<int>(2).At(last_ts)));
  MP_EXPECT_OK(graph.WaitUntilIdle());
  EXPECT_THAT(
      pairs,
      ElementsAre(PairPacket(first_ts, Pair(IntPacket(1), EmptyPacket())),
                  PairPacket(last_ts, Pair(IntPacket(2), IntPacket(1)))));

  MP_EXPECT_OK(graph.CloseAllInputStreams());
  MP_EXPECT_OK(graph.WaitUntilIdle());
  MP_EXPECT_OK(graph.WaitUntilDone());
}
// Demonstrates that downstream calculators won't be blocked by
// always-empty-LOOP-stream: LOOP is wired to PREV_LOOP itself, and every MAIN
// packet sent must still reach the sink.
TEST(PreviousLoopbackCalculator, EmptyLoopForever) {
  std::vector<Packet> outputs;
  CalculatorGraphConfig graph_config_ =
      ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in'
node {
calculator: 'PreviousLoopbackCalculator'
input_stream: 'MAIN:in'
input_stream: 'LOOP:previous'
input_stream_info: { tag_index: 'LOOP' back_edge: true }
output_stream: 'PREV_LOOP:previous'
}
# This calculator synchronizes its inputs as normal, so it is used
# to check that both "in" and "previous" are ready.
node {
calculator: 'PassThroughCalculator'
input_stream: 'in'
input_stream: 'previous'
output_stream: 'out'
output_stream: 'previous2'
}
node {
calculator: 'PacketOnCloseCalculator'
input_stream: 'out'
output_stream: 'close_out'
}
)pb");
  tool::AddVectorSink("close_out", &graph_config_, &outputs);
  CalculatorGraph graph_;
  MP_ASSERT_OK(graph_.Initialize(graph_config_, {}));
  MP_ASSERT_OK(graph_.StartRun({}));
  // Sends the int `n` to `input_name` at Timestamp(n).
  auto send_packet = [&graph_](const std::string& input_name, int n) {
    MP_EXPECT_OK(graph_.AddPacketToInputStream(
        input_name, MakePacket<int>(n).At(Timestamp(n))));
  };
  for (int main_ts = 0; main_ts < 50; ++main_ts) {
    send_packet("in", main_ts);
    MP_EXPECT_OK(graph_.WaitUntilIdle());
    const std::vector<int64> ts_values = TimestampValues(outputs);
    // Fatal size check: a non-fatal EXPECT here would let the loop below
    // index past the end of ts_values when packets are missing.
    ASSERT_EQ(ts_values.size(), static_cast<size_t>(main_ts) + 1);
    for (size_t j = 0; j < ts_values.size(); ++j) {
      EXPECT_EQ(ts_values[j], static_cast<int64>(j));
    }
  }
  MP_EXPECT_OK(graph_.CloseAllInputStreams());
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  MP_EXPECT_OK(graph_.WaitUntilDone());
}
class PreviousLoopbackCalculatorProcessingTimestampsTest
: public testing::Test {
protected:
void SetUp() override {
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'input'
input_stream: 'force_main_empty'
input_stream: 'force_loop_empty'
# Used to indicate "main" timestamp bound updates.
node {
calculator: 'GateCalculator'
input_stream: 'input'
input_stream: 'DISALLOW:force_main_empty'
output_stream: 'main'
}
node {
calculator: 'PreviousLoopbackCalculator'
input_stream: 'MAIN:main'
input_stream: 'LOOP:loop'
input_stream_info: { tag_index: 'LOOP' back_edge: true }
output_stream: 'PREV_LOOP:prev_loop'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'input'
input_stream: 'prev_loop'
output_stream: 'passed_through_input'
output_stream: 'passed_through_prev_loop'
}
# Used to indicate "loop" timestamp bound updates.
node {
calculator: 'GateCalculator'
input_stream: 'input'
input_stream: 'DISALLOW:force_loop_empty'
output_stream: 'loop'
}
node {
calculator: 'MakePairCalculator'
input_stream: 'passed_through_input'
input_stream: 'passed_through_prev_loop'
output_stream: 'passed_through_input_and_prev_loop'
}
)pb");
tool::AddVectorSink("passed_through_input_and_prev_loop", &graph_config,
&output_packets_);
MP_ASSERT_OK(graph_.Initialize(graph_config, {}));
MP_ASSERT_OK(graph_.StartRun({}));
}
void SendPackets(int timestamp, int input, bool force_main_empty,
bool force_loop_empty) {
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"input", MakePacket<int>(input).At(Timestamp(timestamp))));
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"force_main_empty",
MakePacket<bool>(force_main_empty).At(Timestamp(timestamp))));
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"force_loop_empty",
MakePacket<bool>(force_loop_empty).At(Timestamp(timestamp))));
}
CalculatorGraph graph_;
std::vector<Packet> output_packets_;
};
// MAIN is forced empty at every timestamp while LOOP is let through. Every
// output pair is expected to carry the input and an empty "previous" packet.
TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest,
MultiplePacketsEmptyMainNonEmptyLoop) {
// Timestamp 1.
SendPackets(/*timestamp=*/1, /*input=*/1, /*force_main_empty=*/true,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
// Timestamp 2.
SendPackets(/*timestamp=*/2, /*input=*/2, /*force_main_empty=*/true,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket()))));
// Timestamp 3.
SendPackets(/*timestamp=*/3, /*input=*/3, /*force_main_empty=*/true,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket()))));
// Timestamp 5 (gap in timestamps).
SendPackets(/*timestamp=*/5, /*input=*/5, /*force_main_empty=*/true,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket()))));
// Timestamp 15 (larger gap).
SendPackets(/*timestamp=*/15, /*input=*/15,
/*force_main_empty=*/true,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket())),
PairPacket(Timestamp(15), Pair(IntPacket(15), EmptyPacket()))));
// Shut the graph down.
MP_EXPECT_OK(graph_.CloseAllInputStreams());
MP_EXPECT_OK(graph_.WaitUntilDone());
}
// MAIN is let through at every timestamp while LOOP is forced empty. Every
// output pair is expected to carry the input and an empty "previous" packet.
TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest,
MultiplePacketsNonEmptyMainEmptyLoop) {
// Timestamp 1.
SendPackets(/*timestamp=*/1, /*input=*/1,
/*force_main_empty=*/false,
/*force_loop_empty=*/true);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
// Timestamp 2.
SendPackets(/*timestamp=*/2, /*input=*/2,
/*force_main_empty=*/false,
/*force_loop_empty=*/true);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket()))));
// Timestamp 3.
SendPackets(/*timestamp=*/3, /*input=*/3,
/*force_main_empty=*/false,
/*force_loop_empty=*/true);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket()))));
// Timestamp 5 (gap in timestamps).
SendPackets(/*timestamp=*/5, /*input=*/5,
/*force_main_empty=*/false,
/*force_loop_empty=*/true);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket()))));
// Timestamp 15 (larger gap).
SendPackets(/*timestamp=*/15, /*input=*/15,
/*force_main_empty=*/false,
/*force_loop_empty=*/true);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket())),
PairPacket(Timestamp(15), Pair(IntPacket(15), EmptyPacket()))));
// Shut the graph down.
MP_EXPECT_OK(graph_.CloseAllInputStreams());
MP_EXPECT_OK(graph_.WaitUntilDone());
}
// LOOP is always let through; MAIN is forced empty at timestamps 2 and 15.
// Pairs for empty-MAIN timestamps are expected to have an empty "previous"
// packet, and pairs for non-empty MAIN timestamps are expected to carry the
// value from the previous non-empty MAIN timestamp.
TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest,
MultiplePacketsAlteringMainNonEmptyLoop) {
// Timestamp 1: first non-empty MAIN, nothing previous yet.
SendPackets(/*timestamp=*/1, /*input=*/1,
/*force_main_empty=*/false,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
// Timestamp 2: MAIN forced empty.
SendPackets(/*timestamp=*/2, /*input=*/2, /*force_main_empty=*/true,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket()))));
// Timestamp 3: previous non-empty MAIN was at timestamp 1.
SendPackets(/*timestamp=*/3, /*input=*/3,
/*force_main_empty=*/false,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
PairPacket(Timestamp(3), Pair(IntPacket(3), IntPacket(1)))));
// Timestamp 5: previous non-empty MAIN was at timestamp 3.
SendPackets(/*timestamp=*/5, /*input=*/5,
/*force_main_empty=*/false,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
PairPacket(Timestamp(3), Pair(IntPacket(3), IntPacket(1))),
PairPacket(Timestamp(5), Pair(IntPacket(5), IntPacket(3)))));
// Timestamp 15: MAIN forced empty again.
SendPackets(/*timestamp=*/15, /*input=*/15,
/*force_main_empty=*/true,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
PairPacket(Timestamp(3), Pair(IntPacket(3), IntPacket(1))),
PairPacket(Timestamp(5), Pair(IntPacket(5), IntPacket(3))),
PairPacket(Timestamp(15), Pair(IntPacket(15), EmptyPacket()))));
// Shut the graph down.
MP_EXPECT_OK(graph_.CloseAllInputStreams());
MP_EXPECT_OK(graph_.WaitUntilDone());
}
// MAIN is always let through; LOOP is forced empty at timestamps 2 and 5.
// The "previous" half of each pair is expected to be the LOOP value from the
// preceding timestamp, or empty when LOOP was forced empty there.
TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest,
MultiplePacketsNonEmptyMainAlteringLoop) {
// Timestamp 1: nothing previous yet.
SendPackets(/*timestamp=*/1, /*input=*/1,
/*force_main_empty=*/false,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
// Timestamp 2: LOOP at timestamp 1 was non-empty, so previous == 1.
SendPackets(/*timestamp=*/2, /*input=*/2,
/*force_main_empty=*/false,
/*force_loop_empty=*/true);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1)))));
// Timestamp 3: LOOP at timestamp 2 was forced empty, so previous is empty.
SendPackets(/*timestamp=*/3, /*input=*/3,
/*force_main_empty=*/false,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket()))));
// Timestamp 5: LOOP at timestamp 3 was non-empty, so previous == 3.
SendPackets(/*timestamp=*/5, /*input=*/5,
/*force_main_empty=*/false,
/*force_loop_empty=*/true);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
PairPacket(Timestamp(5), Pair(IntPacket(5), IntPacket(3)))));
// Timestamp 15: LOOP at timestamp 5 was forced empty, so previous is empty.
SendPackets(/*timestamp=*/15, /*input=*/15,
/*force_main_empty=*/false,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
PairPacket(Timestamp(5), Pair(IntPacket(5), IntPacket(3))),
PairPacket(Timestamp(15), Pair(IntPacket(15), EmptyPacket()))));
// Shut the graph down.
MP_EXPECT_OK(graph_.CloseAllInputStreams());
MP_EXPECT_OK(graph_.WaitUntilDone());
}
// Floods the graph with 1000 packets whose MAIN/LOOP emptiness follows fixed
// patterns (every 3rd / every 2nd), then sends two fully non-empty packets and
// checks only the last output pair.
TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest,
       MultiplePacketsCheckIfLastCorrectAlteringMainAlteringLoop) {
  constexpr int kNumPackets = 1000;
  for (int ts = 1; ts <= kNumPackets; ++ts) {
    const int index = ts - 1;
    const bool drop_main = (index % 3 == 0);
    const bool drop_loop = (index % 2 == 0);
    SendPackets(/*timestamp=*/ts, /*input=*/ts, drop_main, drop_loop);
  }

  const int second_to_last = kNumPackets + 1;
  const int last = kNumPackets + 2;
  SendPackets(/*timestamp=*/second_to_last, /*input=*/second_to_last,
              /*force_main_empty=*/false, /*force_loop_empty=*/false);
  SendPackets(/*timestamp=*/last, /*input=*/last,
              /*force_main_empty=*/false, /*force_loop_empty=*/false);
  MP_EXPECT_OK(graph_.WaitUntilIdle());

  // back() below requires at least one captured packet.
  ASSERT_FALSE(output_packets_.empty());
  EXPECT_THAT(output_packets_.back(),
              PairPacket(Timestamp(last),
                         Pair(IntPacket(last), IntPacket(second_to_last))));

  MP_EXPECT_OK(graph_.CloseAllInputStreams());
  MP_EXPECT_OK(graph_.WaitUntilDone());
}
// Similar to GateCalculator, but it doesn't propagate timestamp bound updates:
// a packet is forwarded only when the data input is non-empty and the
// DISALLOW flag is false; otherwise nothing is emitted.
class DroppingGateCalculator : public CalculatorBase {
 public:
  // Input 0: any type. "DISALLOW": bool. Output 0: same type as input 0.
  static absl::Status GetContract(CalculatorContract* cc) {
    auto& data_in = cc->Inputs().Index(0);
    data_in.SetAny();
    cc->Inputs().Tag("DISALLOW").Set<bool>();
    cc->Outputs().Index(0).SetSameAs(&data_in);
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) final {
    auto& data_in = cc->Inputs().Index(0);
    if (data_in.IsEmpty()) {
      return absl::OkStatus();
    }
    // The flag is read only once the data input is known to be non-empty.
    const bool disallow = cc->Inputs().Tag("DISALLOW").Get<bool>();
    if (disallow) {
      return absl::OkStatus();
    }
    cc->Outputs().Index(0).AddPacket(data_in.Value());
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(DroppingGateCalculator);
// Tests PreviousLoopbackCalculator in cases when there are no "LOOP" timestamp
// bound updates and non-empty packets for a while and the aforementioned start
// to arrive at some point. So, "PREV_LOOP" is delayed for a couple of inputs.
class PreviousLoopbackCalculatorDelayBehaviorTest : public testing::Test {
protected:
void SetUp() override {
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'input'
# Drops "loop" when set to "true", delaying output of prev_loop, hence
# delaying output of the graph.
input_stream: 'delay_next_output'
node {
calculator: 'PreviousLoopbackCalculator'
input_stream: 'MAIN:input'
input_stream: 'LOOP:loop'
input_stream_info: { tag_index: 'LOOP' back_edge: true }
output_stream: 'PREV_LOOP:prev_loop'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'input'
input_stream: 'prev_loop'
output_stream: 'passed_through_input'
output_stream: 'passed_through_prev_loop'
}
node {
calculator: 'DroppingGateCalculator'
input_stream: 'input'
input_stream: 'DISALLOW:delay_next_output'
output_stream: 'loop'
}
node {
calculator: 'MakePairCalculator'
input_stream: 'passed_through_input'
input_stream: 'passed_through_prev_loop'
output_stream: 'passed_through_input_and_prev_loop'
}
)pb");
tool::AddVectorSink("passed_through_input_and_prev_loop", &graph_config,
&output_packets_);
MP_ASSERT_OK(graph_.Initialize(graph_config, {}));
MP_ASSERT_OK(graph_.StartRun({}));
}
void SendPackets(int timestamp, int input, bool delay_next_output) {
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"input", MakePacket<int>(input).At(Timestamp(timestamp))));
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"delay_next_output",
MakePacket<bool>(delay_next_output).At(Timestamp(timestamp))));
}
CalculatorGraph graph_;
std::vector<Packet> output_packets_;
};
// The first three inputs all request a delay; only the first one produces an
// output right away. Once a non-delayed input arrives the held-back outputs
// are released, all with an empty previous-loop value.
TEST_F(PreviousLoopbackCalculatorDelayBehaviorTest, MultipleDelayedOutputs) {
  // Expected outputs, named after the timestamp they carry.
  const auto out1 =
      PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()));
  const auto out2 =
      PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket()));
  const auto out3 =
      PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket()));
  const auto out5 =
      PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket()));
  const auto out15 =
      PairPacket(Timestamp(15), Pair(IntPacket(15), IntPacket(5)));

  SendPackets(/*timestamp=*/1, /*input=*/1, /*delay_next_output=*/true);
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  EXPECT_THAT(output_packets_, ElementsAre(out1));

  // Further delayed inputs do not yield any new output yet.
  SendPackets(/*timestamp=*/2, /*input=*/2, /*delay_next_output=*/true);
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  EXPECT_THAT(output_packets_, ElementsAre(out1));

  SendPackets(/*timestamp=*/3, /*input=*/3, /*delay_next_output=*/true);
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  EXPECT_THAT(output_packets_, ElementsAre(out1));

  // The first non-delayed input releases everything that was pending.
  SendPackets(/*timestamp=*/5, /*input=*/5, /*delay_next_output=*/false);
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  EXPECT_THAT(output_packets_, ElementsAre(out1, out2, out3, out5));

  SendPackets(/*timestamp=*/15, /*input=*/15, /*delay_next_output=*/false);
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  EXPECT_THAT(output_packets_, ElementsAre(out1, out2, out3, out5, out15));

  MP_EXPECT_OK(graph_.CloseAllInputStreams());
  MP_EXPECT_OK(graph_.WaitUntilDone());
}
// A non-delayed first input lets the second output appear immediately (with
// the first input as its previous-loop value); the following delayed inputs
// are then held back until another non-delayed input arrives.
TEST_F(PreviousLoopbackCalculatorDelayBehaviorTest,
       NonDelayedOutputFollowedByMultipleDelayedOutputs) {
  // Expected outputs, named after the timestamp they carry.
  const auto out1 =
      PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()));
  const auto out2 =
      PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1)));
  const auto out3 =
      PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket()));
  const auto out5 =
      PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket()));
  const auto out15 =
      PairPacket(Timestamp(15), Pair(IntPacket(15), IntPacket(5)));

  SendPackets(/*timestamp=*/1, /*input=*/1, /*delay_next_output=*/false);
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  EXPECT_THAT(output_packets_, ElementsAre(out1));

  SendPackets(/*timestamp=*/2, /*input=*/2, /*delay_next_output=*/true);
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  EXPECT_THAT(output_packets_, ElementsAre(out1, out2));

  // A second delayed input in a row produces no new output yet.
  SendPackets(/*timestamp=*/3, /*input=*/3, /*delay_next_output=*/true);
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  EXPECT_THAT(output_packets_, ElementsAre(out1, out2));

  // The next non-delayed input releases the pending output as well.
  SendPackets(/*timestamp=*/5, /*input=*/5, /*delay_next_output=*/false);
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  EXPECT_THAT(output_packets_, ElementsAre(out1, out2, out3, out5));

  SendPackets(/*timestamp=*/15, /*input=*/15, /*delay_next_output=*/false);
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  EXPECT_THAT(output_packets_, ElementsAre(out1, out2, out3, out5, out15));

  MP_EXPECT_OK(graph_.CloseAllInputStreams());
  MP_EXPECT_OK(graph_.WaitUntilDone());
}
} // anonymous namespace
} // namespace mediapipe

View File

@ -1,102 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cfloat>
#include <memory>
#include <string>
#include <vector>
#include "mediapipe/calculators/core/quantize_float_vector_calculator.pb.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/status.h"
// Quantizes a vector of floats to a std::string so that each float becomes a
// byte in the [0, 255] range. Any value above max_quantized_value or below
// min_quantized_value will be saturated to '\xFF' or '\0', respectively.
//
// Example config:
// node {
// calculator: "QuantizeFloatVectorCalculator"
// input_stream: "FLOAT_VECTOR:float_vector"
// output_stream: "ENCODED:encoded"
// options {
// [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
// max_quantized_value: 64
// min_quantized_value: -64
// }
// }
// }
namespace mediapipe {
// Converts each element of a std::vector<float> into one byte of a
// std::string by linearly mapping [min_quantized_value, max_quantized_value]
// onto [0, 255]. Elements outside that interval are saturated to the nearest
// end; NaN elements are encoded as '\0'.
class QuantizeFloatVectorCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Tag("FLOAT_VECTOR").Set<std::vector<float>>();
    cc->Outputs().Tag("ENCODED").Set<std::string>();
    return absl::OkStatus();
  }

  // Reads and validates the quantization bounds from the options.
  // Returns InvalidArgumentError when either bound is missing or when the
  // bounds do not span a range of at least FLT_EPSILON.
  absl::Status Open(CalculatorContext* cc) final {
    // Bind by reference: there is no need to copy the options proto.
    const auto& options =
        cc->Options<::mediapipe::QuantizeFloatVectorCalculatorOptions>();
    if (!options.has_max_quantized_value() ||
        !options.has_min_quantized_value()) {
      return absl::InvalidArgumentError(
          "Both max_quantized_value and min_quantized_value must be provided "
          "in QuantizeFloatVectorCalculatorOptions.");
    }
    max_quantized_value_ = options.max_quantized_value();
    min_quantized_value_ = options.min_quantized_value();
    range_ = max_quantized_value_ - min_quantized_value_;
    // Validate the difference rather than `max < min + FLT_EPSILON`: for
    // |min| >= 2 the sum `min + FLT_EPSILON` rounds back to `min`, which let
    // max == min through and led to a division by zero in Process(). The
    // negated form also rejects NaN bounds.
    if (!(range_ >= FLT_EPSILON)) {
      return absl::InvalidArgumentError(
          "max_quantized_value must be greater than min_quantized_value.");
    }
    return absl::OkStatus();
  }

  // Emits one std::string packet, at the input timestamp, holding one byte
  // per input float.
  absl::Status Process(CalculatorContext* cc) final {
    const std::vector<float>& float_vector =
        cc->Inputs().Tag("FLOAT_VECTOR").Value().Get<std::vector<float>>();
    std::string encoded_features;
    encoded_features.reserve(float_vector.size());
    // Loop-invariant scale, kept in double precision as before.
    const double scale = 255.0 / range_;
    for (const float raw_value : float_vector) {
      float clamped = raw_value;
      if (!(clamped > min_quantized_value_)) {
        // Also taken for NaN, which would otherwise reach the integer cast
        // below (undefined behavior); NaN is therefore encoded as '\0'.
        clamped = min_quantized_value_;
      } else if (clamped > max_quantized_value_) {
        clamped = max_quantized_value_;
      }
      // Truncating conversion, e.g. 127.5 -> 127.
      const unsigned char encoded = static_cast<unsigned char>(
          (clamped - min_quantized_value_) * scale);
      encoded_features.push_back(static_cast<char>(encoded));
    }
    cc->Outputs().Tag("ENCODED").AddPacket(
        MakePacket<std::string>(encoded_features).At(cc->InputTimestamp()));
    return absl::OkStatus();
  }

 private:
  float max_quantized_value_;
  float min_quantized_value_;
  // max_quantized_value_ - min_quantized_value_; at least FLT_EPSILON once
  // Open() has succeeded.
  float range_;
};
REGISTER_CALCULATOR(QuantizeFloatVectorCalculator);
} // namespace mediapipe

View File

@ -1,30 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
option objc_class_prefix = "MediaPipe";
// Options for QuantizeFloatVectorCalculator. Both fields must be set, and
// max_quantized_value must be greater than min_quantized_value; otherwise the
// calculator fails to open.
message QuantizeFloatVectorCalculatorOptions {
extend CalculatorOptions {
optional QuantizeFloatVectorCalculatorOptions ext = 259848061;
}
// Upper end of the quantized interval. Input values at or above it are
// encoded as byte 255.
optional float max_quantized_value = 1;
// Lower end of the quantized interval. Input values at or below it are
// encoded as byte 0.
optional float min_quantized_value = 2;
}

View File

@ -1,204 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
namespace mediapipe {
// Only min_quantized_value is configured: the run must fail and the error
// must state that both bounds are required.
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
  CalculatorGraphConfig::Node node =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "QuantizeFloatVectorCalculator"
input_stream: "FLOAT_VECTOR:float_vector"
output_stream: "ENCODED:encoded"
options {
[mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
min_quantized_value: 1
}
}
)pb");
  CalculatorRunner runner(node);

  const std::vector<float> no_values;
  auto& input_packets = runner.MutableInputs()->Tag("FLOAT_VECTOR").packets;
  input_packets.push_back(
      MakePacket<std::vector<float>>(no_values).At(Timestamp(0)));

  const auto run_status = runner.Run();
  EXPECT_FALSE(run_status.ok());
  EXPECT_THAT(
      run_status.message(),
      testing::HasSubstr(
          "Both max_quantized_value and min_quantized_value must be provided"));
}
// max_quantized_value is smaller than min_quantized_value: the run must fail
// with the ordering error.
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig2) {
  CalculatorGraphConfig::Node node =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "QuantizeFloatVectorCalculator"
input_stream: "FLOAT_VECTOR:float_vector"
output_stream: "ENCODED:encoded"
options {
[mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
max_quantized_value: -1
min_quantized_value: 1
}
}
)pb");
  CalculatorRunner runner(node);

  const std::vector<float> no_values;
  auto& input_packets = runner.MutableInputs()->Tag("FLOAT_VECTOR").packets;
  input_packets.push_back(
      MakePacket<std::vector<float>>(no_values).At(Timestamp(0)));

  const auto run_status = runner.Run();
  EXPECT_FALSE(run_status.ok());
  EXPECT_THAT(
      run_status.message(),
      testing::HasSubstr(
          "max_quantized_value must be greater than min_quantized_value"));
}
// max_quantized_value equals min_quantized_value: the run must fail with the
// ordering error.
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig3) {
  CalculatorGraphConfig::Node node =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "QuantizeFloatVectorCalculator"
input_stream: "FLOAT_VECTOR:float_vector"
output_stream: "ENCODED:encoded"
options {
[mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
max_quantized_value: 1
min_quantized_value: 1
}
}
)pb");
  CalculatorRunner runner(node);

  const std::vector<float> no_values;
  auto& input_packets = runner.MutableInputs()->Tag("FLOAT_VECTOR").packets;
  input_packets.push_back(
      MakePacket<std::vector<float>>(no_values).At(Timestamp(0)));

  const auto run_status = runner.Run();
  EXPECT_FALSE(run_status.ok());
  EXPECT_THAT(
      run_status.message(),
      testing::HasSubstr(
          "max_quantized_value must be greater than min_quantized_value"));
}
// An empty input vector must yield exactly one output packet holding an empty
// string, stamped with the input timestamp.
TEST(QuantizeFloatVectorCalculatorTest, TestEmptyVector) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "QuantizeFloatVectorCalculator"
input_stream: "FLOAT_VECTOR:float_vector"
output_stream: "ENCODED:encoded"
options {
[mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
max_quantized_value: 1
min_quantized_value: -1
}
}
)pb");
  CalculatorRunner runner(node_config);
  std::vector<float> empty_vector;
  runner.MutableInputs()
      ->Tag("FLOAT_VECTOR")
      .packets.push_back(
          MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
  MP_ASSERT_OK(runner.Run());
  const std::vector<Packet>& outputs = runner.Outputs().Tag("ENCODED").packets;
  // ASSERT rather than EXPECT: outputs[0] is dereferenced below and must not
  // be touched when no packet was produced.
  ASSERT_EQ(1, outputs.size());
  EXPECT_TRUE(outputs[0].Get<std::string>().empty());
  EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
}
// In-range values are mapped linearly from [-64, 64] onto [0, 255], with the
// fractional part truncated.
TEST(QuantizeFloatVectorCalculatorTest, TestNonEmptyVector) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "QuantizeFloatVectorCalculator"
input_stream: "FLOAT_VECTOR:float_vector"
output_stream: "ENCODED:encoded"
options {
[mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
max_quantized_value: 64
min_quantized_value: -64
}
}
)pb");
  CalculatorRunner runner(node_config);
  std::vector<float> vector = {0.0f, -64.0f, 64.0f, -32.0f, 32.0f};
  runner.MutableInputs()
      ->Tag("FLOAT_VECTOR")
      .packets.push_back(
          MakePacket<std::vector<float>>(vector).At(Timestamp(0)));
  MP_ASSERT_OK(runner.Run());
  const std::vector<Packet>& outputs = runner.Outputs().Tag("ENCODED").packets;
  // ASSERT rather than EXPECT: outputs[0] is dereferenced below.
  ASSERT_EQ(1, outputs.size());
  const std::string& result = outputs[0].Get<std::string>();
  // ASSERT rather than EXPECT: indices 0..4 of the result are read below.
  ASSERT_EQ(5, result.size());
  // 127
  EXPECT_EQ('\x7F', result[0]);
  // 0
  EXPECT_EQ('\0', result[1]);
  // 255
  EXPECT_EQ('\xFF', result[2]);
  // 63
  EXPECT_EQ('\x3F', result[3]);
  // 191
  EXPECT_EQ('\xBF', result[4]);
  EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
}
// Values outside [-64, 64] are saturated to the ends of the byte range.
TEST(QuantizeFloatVectorCalculatorTest, TestSaturation) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "QuantizeFloatVectorCalculator"
input_stream: "FLOAT_VECTOR:float_vector"
output_stream: "ENCODED:encoded"
options {
[mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
max_quantized_value: 64
min_quantized_value: -64
}
}
)pb");
  CalculatorRunner runner(node_config);
  std::vector<float> vector = {-65.0f, 65.0f};
  runner.MutableInputs()
      ->Tag("FLOAT_VECTOR")
      .packets.push_back(
          MakePacket<std::vector<float>>(vector).At(Timestamp(0)));
  MP_ASSERT_OK(runner.Run());
  const std::vector<Packet>& outputs = runner.Outputs().Tag("ENCODED").packets;
  // ASSERT rather than EXPECT: outputs[0] is dereferenced below.
  ASSERT_EQ(1, outputs.size());
  const std::string& result = outputs[0].Get<std::string>();
  // ASSERT rather than EXPECT: both bytes of the result are read below.
  ASSERT_EQ(2, result.size());
  // 0
  EXPECT_EQ('\0', result[0]);
  // 255
  EXPECT_EQ('\xFF', result[1]);
  EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
}
} // namespace mediapipe

Some files were not shown because too many files have changed in this diff Show More