examples
This commit is contained in:
parent
ba15087099
commit
7c22caf9d6
|
|
@ -1,5 +1,13 @@
|
|||
import mediapipe as mp
|
||||
import gradio as gr
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
|
||||
# Images
|
||||
torch.hub.download_url_to_file('https://artbreeder.b-cdn.net/imgs/c789e54661bfb432c5522a36553f.jpeg', 'face1.jpg')
|
||||
torch.hub.download_url_to_file('https://artbreeder.b-cdn.net/imgs/c86622e8cb58d490e35b01cb9996.jpeg', 'face2.jpg')
|
||||
|
||||
mp_face_mesh = mp.solutions.face_mesh
|
||||
|
||||
# Prepare DrawingSpec for drawing the face landmarks later.
|
||||
|
|
@ -16,16 +24,28 @@ def inference(image):
|
|||
# Convert the BGR image to RGB and process it with MediaPipe Face Mesh.
|
||||
results = face_mesh.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||||
|
||||
# Draw face landmarks of each face.
|
||||
print(f'Face landmarks of {name}:')
|
||||
if not results.multi_face_landmarks:
|
||||
continue
|
||||
annotated_image = image.copy()
|
||||
for face_landmarks in results.multi_face_landmarks:
|
||||
mp_drawing.draw_landmarks(
|
||||
image=annotated_image,
|
||||
landmark_list=face_landmarks,
|
||||
connections=mp_face_mesh.FACE_CONNECTIONS,
|
||||
landmark_drawing_spec=drawing_spec,
|
||||
connection_drawing_spec=drawing_spec)
|
||||
return annotated_image
|
||||
mp_drawing.draw_landmarks(
|
||||
image=annotated_image,
|
||||
landmark_list=face_landmarks,
|
||||
connections=mp_face_mesh.FACE_CONNECTIONS,
|
||||
landmark_drawing_spec=drawing_spec,
|
||||
connection_drawing_spec=drawing_spec)
|
||||
return annotated_image
|
||||
|
||||
title = "Face Mesh"
|
||||
description = "demo for Face Mesh. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
|
||||
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2006.10962'>Attention Mesh: High-fidelity Face Mesh Prediction in Real-time</a> | <a href='https://github.com/google/mediapipe'>Github Repo</a></p>"
|
||||
|
||||
gr.Interface(
|
||||
inference,
|
||||
[gr.inputs.Image(label="Input")],
|
||||
gr.outputs.Image(type="pil", label="Output"),
|
||||
title=title,
|
||||
description=description,
|
||||
article=article,
|
||||
examples=[
|
||||
["face1.jpg"],
|
||||
["face2.jpg"]
|
||||
]).launch(debug=True)
|
||||
145
mediapipe/BUILD
145
mediapipe/BUILD
|
|
@ -1,145 +0,0 @@
|
|||
# Copyright 2019 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
# Note: yes, these need to use "//external:android/crosstool", not
|
||||
# @androidndk//:default_crosstool.
|
||||
|
||||
config_setting(
|
||||
name = "android",
|
||||
values = {"crosstool_top": "//external:android/crosstool"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "android_x86",
|
||||
values = {
|
||||
"crosstool_top": "//external:android/crosstool",
|
||||
"cpu": "x86",
|
||||
},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "android_x86_64",
|
||||
values = {
|
||||
"crosstool_top": "//external:android/crosstool",
|
||||
"cpu": "x86_64",
|
||||
},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "android_armeabi",
|
||||
values = {
|
||||
"crosstool_top": "//external:android/crosstool",
|
||||
"cpu": "armeabi",
|
||||
},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "android_arm",
|
||||
values = {
|
||||
"crosstool_top": "//external:android/crosstool",
|
||||
"cpu": "armeabi-v7a",
|
||||
},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "android_arm64",
|
||||
values = {
|
||||
"crosstool_top": "//external:android/crosstool",
|
||||
"cpu": "arm64-v8a",
|
||||
},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# Note: this cannot just match "apple_platform_type": "macos" because that option
|
||||
# defaults to "macos" even when building on Linux!
|
||||
alias(
|
||||
name = "macos",
|
||||
actual = select({
|
||||
":macos_i386": ":macos_i386",
|
||||
":macos_x86_64": ":macos_x86_64",
|
||||
"//conditions:default": ":macos_i386", # Arbitrarily chosen from above.
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# Note: this also matches on crosstool_top so that it does not produce ambiguous
|
||||
# selectors when used together with "android".
|
||||
config_setting(
|
||||
name = "ios",
|
||||
values = {
|
||||
"crosstool_top": "@bazel_tools//tools/cpp:toolchain",
|
||||
"apple_platform_type": "ios",
|
||||
},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "apple",
|
||||
actual = select({
|
||||
":macos": ":macos",
|
||||
":ios": ":ios",
|
||||
"//conditions:default": ":ios", # Arbitrarily chosen from above.
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "macos_i386",
|
||||
values = {
|
||||
"apple_platform_type": "macos",
|
||||
"cpu": "darwin",
|
||||
},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "macos_x86_64",
|
||||
values = {
|
||||
"apple_platform_type": "macos",
|
||||
"cpu": "darwin_x86_64",
|
||||
},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
[
|
||||
config_setting(
|
||||
name = arch,
|
||||
values = {"cpu": arch},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
for arch in [
|
||||
"ios_i386",
|
||||
"ios_x86_64",
|
||||
"ios_armv7",
|
||||
"ios_arm64",
|
||||
"ios_arm64e",
|
||||
]
|
||||
]
|
||||
|
||||
config_setting(
|
||||
name = "windows",
|
||||
values = {"cpu": "x64_windows"},
|
||||
)
|
||||
|
||||
exports_files(
|
||||
["provisioning_profile.mobileprovision"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
|
@ -1,137 +0,0 @@
|
|||
{
|
||||
"additionalFilePaths" : [
|
||||
"/BUILD",
|
||||
"mediapipe/BUILD",
|
||||
"mediapipe/examples/ios/common/BUILD",
|
||||
"mediapipe/examples/ios/facedetectioncpu/BUILD",
|
||||
"mediapipe/examples/ios/facedetectiongpu/BUILD",
|
||||
"mediapipe/examples/ios/faceeffect/BUILD",
|
||||
"mediapipe/examples/ios/facemeshgpu/BUILD",
|
||||
"mediapipe/examples/ios/handdetectiongpu/BUILD",
|
||||
"mediapipe/examples/ios/handtrackinggpu/BUILD",
|
||||
"mediapipe/examples/ios/helloworld/BUILD",
|
||||
"mediapipe/examples/ios/holistictrackinggpu/BUILD",
|
||||
"mediapipe/examples/ios/iristrackinggpu/BUILD",
|
||||
"mediapipe/examples/ios/objectdetectioncpu/BUILD",
|
||||
"mediapipe/examples/ios/objectdetectiongpu/BUILD",
|
||||
"mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD",
|
||||
"mediapipe/examples/ios/posetrackinggpu/BUILD",
|
||||
"mediapipe/examples/ios/selfiesegmentationgpu/BUILD",
|
||||
"mediapipe/framework/BUILD",
|
||||
"mediapipe/gpu/BUILD",
|
||||
"mediapipe/objc/BUILD",
|
||||
"mediapipe/objc/testing/app/BUILD"
|
||||
],
|
||||
"buildTargets" : [
|
||||
"//mediapipe/examples/ios/facedetectioncpu:FaceDetectionCpuApp",
|
||||
"//mediapipe/examples/ios/facedetectiongpu:FaceDetectionGpuApp",
|
||||
"//mediapipe/examples/ios/faceeffect:FaceEffectApp",
|
||||
"//mediapipe/examples/ios/facemeshgpu:FaceMeshGpuApp",
|
||||
"//mediapipe/examples/ios/handdetectiongpu:HandDetectionGpuApp",
|
||||
"//mediapipe/examples/ios/handtrackinggpu:HandTrackingGpuApp",
|
||||
"//mediapipe/examples/ios/helloworld:HelloWorldApp",
|
||||
"//mediapipe/examples/ios/holistictrackinggpu:HolisticTrackingGpuApp",
|
||||
"//mediapipe/examples/ios/iristrackinggpu:IrisTrackingGpuApp",
|
||||
"//mediapipe/examples/ios/objectdetectioncpu:ObjectDetectionCpuApp",
|
||||
"//mediapipe/examples/ios/objectdetectiongpu:ObjectDetectionGpuApp",
|
||||
"//mediapipe/examples/ios/objectdetectiontrackinggpu:ObjectDetectionTrackingGpuApp",
|
||||
"//mediapipe/examples/ios/posetrackinggpu:PoseTrackingGpuApp",
|
||||
"//mediapipe/examples/ios/selfiesegmentationgpu:SelfieSegmentationGpuApp",
|
||||
"//mediapipe/objc:mediapipe_framework_ios"
|
||||
],
|
||||
"optionSet" : {
|
||||
"BazelBuildOptionsDebug" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BazelBuildOptionsRelease" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BazelBuildStartupOptionsDebug" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BazelBuildStartupOptionsRelease" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BuildActionPostActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"BuildActionPreActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"CommandlineArguments" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"EnvironmentVariables" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"LaunchActionPostActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"LaunchActionPreActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"ProjectGenerationBazelStartupOptions" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"TestActionPostActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
},
|
||||
"TestActionPreActionScript" : {
|
||||
"p" : "$(inherited)"
|
||||
}
|
||||
},
|
||||
"projectName" : "Mediapipe",
|
||||
"sourceFilters" : [
|
||||
"mediapipe",
|
||||
"mediapipe/calculators",
|
||||
"mediapipe/calculators/core",
|
||||
"mediapipe/calculators/image",
|
||||
"mediapipe/calculators/internal",
|
||||
"mediapipe/calculators/tflite",
|
||||
"mediapipe/calculators/util",
|
||||
"mediapipe/examples",
|
||||
"mediapipe/examples/ios",
|
||||
"mediapipe/examples/ios/common",
|
||||
"mediapipe/examples/ios/common/Base.lproj",
|
||||
"mediapipe/examples/ios/facedetectioncpu",
|
||||
"mediapipe/examples/ios/facedetectiongpu",
|
||||
"mediapipe/examples/ios/faceeffect",
|
||||
"mediapipe/examples/ios/faceeffect/Base.lproj",
|
||||
"mediapipe/examples/ios/handdetectiongpu",
|
||||
"mediapipe/examples/ios/handtrackinggpu",
|
||||
"mediapipe/examples/ios/helloworld",
|
||||
"mediapipe/examples/ios/holistictrackinggpu",
|
||||
"mediapipe/examples/ios/iristrackinggpu",
|
||||
"mediapipe/examples/ios/objectdetectioncpu",
|
||||
"mediapipe/examples/ios/objectdetectiongpu",
|
||||
"mediapipe/examples/ios/posetrackinggpu",
|
||||
"mediapipe/examples/ios/selfiesegmentationgpu",
|
||||
"mediapipe/framework",
|
||||
"mediapipe/framework/deps",
|
||||
"mediapipe/framework/formats",
|
||||
"mediapipe/framework/formats/annotation",
|
||||
"mediapipe/framework/formats/object_detection",
|
||||
"mediapipe/framework/port",
|
||||
"mediapipe/framework/profiler",
|
||||
"mediapipe/framework/stream_handler",
|
||||
"mediapipe/framework/tool",
|
||||
"mediapipe/gpu",
|
||||
"mediapipe/graphs",
|
||||
"mediapipe/graphs/edge_detection",
|
||||
"mediapipe/graphs/face_detection",
|
||||
"mediapipe/graphs/face_geometry",
|
||||
"mediapipe/graphs/hand_tracking",
|
||||
"mediapipe/graphs/object_detection",
|
||||
"mediapipe/graphs/pose_tracking",
|
||||
"mediapipe/graphs/selfie_segmentation",
|
||||
"mediapipe/models",
|
||||
"mediapipe/modules",
|
||||
"mediapipe/objc",
|
||||
"mediapipe/util",
|
||||
"mediapipe/util/android",
|
||||
"mediapipe/util/android/file",
|
||||
"mediapipe/util/android/file/base",
|
||||
"mediapipe/util/tflite",
|
||||
"mediapipe/util/tflite/operations"
|
||||
]
|
||||
}
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
{
|
||||
"configDefaults" : {
|
||||
"optionSet" : {
|
||||
"CLANG_CXX_LANGUAGE_STANDARD" : {
|
||||
"p" : "c++14"
|
||||
}
|
||||
}
|
||||
},
|
||||
"packages" : [
|
||||
"",
|
||||
"mediapipe",
|
||||
"mediapipe/examples/ios",
|
||||
"mediapipe/examples/ios/facedetectioncpu",
|
||||
"mediapipe/examples/ios/facedetectiongpu",
|
||||
"mediapipe/examples/ios/faceeffect",
|
||||
"mediapipe/examples/ios/facemeshgpu",
|
||||
"mediapipe/examples/ios/handdetectiongpu",
|
||||
"mediapipe/examples/ios/handtrackinggpu",
|
||||
"mediapipe/examples/ios/holistictrackinggpu",
|
||||
"mediapipe/examples/ios/iristrackinggpu",
|
||||
"mediapipe/examples/ios/objectdetectioncpu",
|
||||
"mediapipe/examples/ios/objectdetectiongpu",
|
||||
"mediapipe/examples/ios/objectdetectiontrackinggpu",
|
||||
"mediapipe/examples/ios/posetrackinggpu",
|
||||
"mediapipe/examples/ios/selfiesegmentationgpu",
|
||||
"mediapipe/objc"
|
||||
],
|
||||
"projectName" : "Mediapipe",
|
||||
"workspaceRoot" : "../.."
|
||||
}
|
||||
|
|
@ -1,14 +0,0 @@
|
|||
"""Copyright 2019 - 2020 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
|
@ -1,357 +0,0 @@
|
|||
# Copyright 2019, 2021 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//visibility:private"])
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
|
||||
|
||||
proto_library(
|
||||
name = "mfcc_mel_calculators_proto",
|
||||
srcs = ["mfcc_mel_calculators.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "mfcc_mel_calculators_cc_proto",
|
||||
srcs = ["mfcc_mel_calculators.proto"],
|
||||
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":mfcc_mel_calculators_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "rational_factor_resample_calculator_proto",
|
||||
srcs = ["rational_factor_resample_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "rational_factor_resample_calculator_cc_proto",
|
||||
srcs = ["rational_factor_resample_calculator.proto"],
|
||||
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":rational_factor_resample_calculator_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "spectrogram_calculator_proto",
|
||||
srcs = ["spectrogram_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//mediapipe/framework:calculator_proto"],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "spectrogram_calculator_cc_proto",
|
||||
srcs = ["spectrogram_calculator.proto"],
|
||||
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":spectrogram_calculator_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "stabilized_log_calculator_proto",
|
||||
srcs = ["stabilized_log_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "stabilized_log_calculator_cc_proto",
|
||||
srcs = ["stabilized_log_calculator.proto"],
|
||||
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":stabilized_log_calculator_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "time_series_framer_calculator_proto",
|
||||
srcs = ["time_series_framer_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "time_series_framer_calculator_cc_proto",
|
||||
srcs = ["time_series_framer_calculator.proto"],
|
||||
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":time_series_framer_calculator_proto"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "audio_decoder_calculator",
|
||||
srcs = ["audio_decoder_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:audio_decoder",
|
||||
"//mediapipe/util:audio_decoder_cc_proto",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "basic_time_series_calculators",
|
||||
srcs = ["basic_time_series_calculators.cc"],
|
||||
hdrs = ["basic_time_series_calculators.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/util:time_series_util",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mfcc_mel_calculators",
|
||||
srcs = ["mfcc_mel_calculators.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":mfcc_mel_calculators_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:time_series_util",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_audio_tools//audio/dsp/mfcc",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "rational_factor_resample_calculator",
|
||||
srcs = ["rational_factor_resample_calculator.cc"],
|
||||
hdrs = ["rational_factor_resample_calculator.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":rational_factor_resample_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/util:time_series_util",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_audio_tools//audio/dsp:resampler",
|
||||
"@com_google_audio_tools//audio/dsp:resampler_q",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "stabilized_log_calculator",
|
||||
srcs = ["stabilized_log_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":stabilized_log_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:core_proto",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:time_series_util",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "spectrogram_calculator",
|
||||
srcs = ["spectrogram_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":spectrogram_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:core_proto",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:source_location",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:time_series_util",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_audio_tools//audio/dsp:window_functions",
|
||||
"@com_google_audio_tools//audio/dsp/spectrogram",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "time_series_framer_calculator",
|
||||
srcs = ["time_series_framer_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":time_series_framer_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:time_series_util",
|
||||
"@com_google_audio_tools//audio/dsp:window_functions",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "audio_decoder_calculator_test",
|
||||
srcs = ["audio_decoder_calculator_test.cc"],
|
||||
data = ["//mediapipe/calculators/audio/testdata:test_audios"],
|
||||
deps = [
|
||||
":audio_decoder_calculator",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "basic_time_series_calculators_test",
|
||||
srcs = ["basic_time_series_calculators_test.cc"],
|
||||
deps = [
|
||||
":basic_time_series_calculators",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/util:time_series_test_util",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "mfcc_mel_calculators_test",
|
||||
srcs = ["mfcc_mel_calculators_test.cc"],
|
||||
deps = [
|
||||
":mfcc_mel_calculators",
|
||||
":mfcc_mel_calculators_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:time_series_test_util",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "spectrogram_calculator_test",
|
||||
srcs = ["spectrogram_calculator_test.cc"],
|
||||
deps = [
|
||||
":spectrogram_calculator",
|
||||
":spectrogram_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:benchmark",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:time_series_test_util",
|
||||
"@com_google_audio_tools//audio/dsp:number_util",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "stabilized_log_calculator_test",
|
||||
srcs = ["stabilized_log_calculator_test.cc"],
|
||||
deps = [
|
||||
":stabilized_log_calculator",
|
||||
":stabilized_log_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:time_series_test_util",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "time_series_framer_calculator_test",
|
||||
srcs = ["time_series_framer_calculator_test.cc"],
|
||||
deps = [
|
||||
":time_series_framer_calculator",
|
||||
":time_series_framer_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:time_series_test_util",
|
||||
"@com_google_audio_tools//audio/dsp:window_functions",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "rational_factor_resample_calculator_test",
|
||||
srcs = ["rational_factor_resample_calculator_test.cc"],
|
||||
deps = [
|
||||
":rational_factor_resample_calculator",
|
||||
":rational_factor_resample_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/tool:validate_type",
|
||||
"//mediapipe/util:time_series_test_util",
|
||||
"@com_google_audio_tools//audio/dsp:signal_vector_util",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
)
|
||||
|
|
@ -1,109 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/util/audio_decoder.h"
|
||||
#include "mediapipe/util/audio_decoder.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// The AudioDecoderCalculator decodes an audio stream of the media file. It
|
||||
// produces two output streams contain audio packets and the header infomation.
|
||||
//
|
||||
// Output Streams:
|
||||
// AUDIO: Output audio frames (Matrix).
|
||||
// AUDIO_HEADER:
|
||||
// Optional audio header information output
|
||||
// Input Side Packets:
|
||||
// INPUT_FILE_PATH: The input file path.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "AudioDecoderCalculator"
|
||||
// input_side_packet: "INPUT_FILE_PATH:input_file_path"
|
||||
// output_stream: "AUDIO:audio"
|
||||
// output_stream: "AUDIO_HEADER:audio_header"
|
||||
// node_options {
|
||||
// [type.googleapis.com/mediapipe.AudioDecoderOptions]: {
|
||||
// audio_stream { stream_index: 0 }
|
||||
// start_time: 0
|
||||
// end_time: 1
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// TODO: support decoding multiple streams.
|
||||
class AudioDecoderCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc);
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
absl::Status Close(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<AudioDecoder> decoder_;
|
||||
};
|
||||
|
||||
absl::Status AudioDecoderCalculator::GetContract(CalculatorContract* cc) {
|
||||
cc->InputSidePackets().Tag("INPUT_FILE_PATH").Set<std::string>();
|
||||
if (cc->InputSidePackets().HasTag("OPTIONS")) {
|
||||
cc->InputSidePackets().Tag("OPTIONS").Set<mediapipe::AudioDecoderOptions>();
|
||||
}
|
||||
cc->Outputs().Tag("AUDIO").Set<Matrix>();
|
||||
if (cc->Outputs().HasTag("AUDIO_HEADER")) {
|
||||
cc->Outputs().Tag("AUDIO_HEADER").SetNone();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status AudioDecoderCalculator::Open(CalculatorContext* cc) {
|
||||
const std::string& input_file_path =
|
||||
cc->InputSidePackets().Tag("INPUT_FILE_PATH").Get<std::string>();
|
||||
const auto& decoder_options =
|
||||
tool::RetrieveOptions(cc->Options<mediapipe::AudioDecoderOptions>(),
|
||||
cc->InputSidePackets(), "OPTIONS");
|
||||
decoder_ = absl::make_unique<AudioDecoder>();
|
||||
MP_RETURN_IF_ERROR(decoder_->Initialize(input_file_path, decoder_options));
|
||||
std::unique_ptr<mediapipe::TimeSeriesHeader> header =
|
||||
absl::make_unique<mediapipe::TimeSeriesHeader>();
|
||||
if (decoder_->FillAudioHeader(decoder_options.audio_stream(0), header.get())
|
||||
.ok()) {
|
||||
// Only pass on a header if the decoder could actually produce one.
|
||||
// otherwise, the header will be empty.
|
||||
cc->Outputs().Tag("AUDIO_HEADER").SetHeader(Adopt(header.release()));
|
||||
}
|
||||
cc->Outputs().Tag("AUDIO_HEADER").Close();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status AudioDecoderCalculator::Process(CalculatorContext* cc) {
|
||||
Packet data;
|
||||
int options_index = -1;
|
||||
auto status = decoder_->GetData(&options_index, &data);
|
||||
if (status.ok()) {
|
||||
cc->Outputs().Tag("AUDIO").AddPacket(data);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
absl::Status AudioDecoderCalculator::Close(CalculatorContext* cc) {
|
||||
return decoder_->Close();
|
||||
}
|
||||
|
||||
REGISTER_CALCULATOR(AudioDecoderCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,150 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
TEST(AudioDecoderCalculatorTest, TestWAV) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "AudioDecoderCalculator"
|
||||
input_side_packet: "INPUT_FILE_PATH:input_file_path"
|
||||
output_stream: "AUDIO:audio"
|
||||
output_stream: "AUDIO_HEADER:audio_header"
|
||||
node_options {
|
||||
[type.googleapis.com/mediapipe.AudioDecoderOptions]: {
|
||||
audio_stream { stream_index: 0 }
|
||||
}
|
||||
})pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
|
||||
file::JoinPath("./",
|
||||
"/mediapipe/calculators/audio/"
|
||||
"testdata/sine_wave_1k_44100_mono_2_sec_wav.audio"));
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
MP_EXPECT_OK(runner.Outputs()
|
||||
.Tag("AUDIO_HEADER")
|
||||
.header.ValidateAsType<mediapipe::TimeSeriesHeader>());
|
||||
const mediapipe::TimeSeriesHeader& header =
|
||||
runner.Outputs()
|
||||
.Tag("AUDIO_HEADER")
|
||||
.header.Get<mediapipe::TimeSeriesHeader>();
|
||||
EXPECT_EQ(44100, header.sample_rate());
|
||||
EXPECT_EQ(1, header.num_channels());
|
||||
EXPECT_TRUE(runner.Outputs().Tag("AUDIO").packets.size() >=
|
||||
std::ceil(44100.0 * 2 / 2048));
|
||||
}
|
||||
|
||||
TEST(AudioDecoderCalculatorTest, Test48KWAV) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "AudioDecoderCalculator"
|
||||
input_side_packet: "INPUT_FILE_PATH:input_file_path"
|
||||
output_stream: "AUDIO:audio"
|
||||
output_stream: "AUDIO_HEADER:audio_header"
|
||||
node_options {
|
||||
[type.googleapis.com/mediapipe.AudioDecoderOptions]: {
|
||||
audio_stream { stream_index: 0 }
|
||||
}
|
||||
})pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
|
||||
file::JoinPath("./",
|
||||
"/mediapipe/calculators/audio/"
|
||||
"testdata/sine_wave_1k_48000_stereo_2_sec_wav.audio"));
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
MP_EXPECT_OK(runner.Outputs()
|
||||
.Tag("AUDIO_HEADER")
|
||||
.header.ValidateAsType<mediapipe::TimeSeriesHeader>());
|
||||
const mediapipe::TimeSeriesHeader& header =
|
||||
runner.Outputs()
|
||||
.Tag("AUDIO_HEADER")
|
||||
.header.Get<mediapipe::TimeSeriesHeader>();
|
||||
EXPECT_EQ(48000, header.sample_rate());
|
||||
EXPECT_EQ(2, header.num_channels());
|
||||
EXPECT_TRUE(runner.Outputs().Tag("AUDIO").packets.size() >=
|
||||
std::ceil(48000.0 * 2 / 1024));
|
||||
}
|
||||
|
||||
TEST(AudioDecoderCalculatorTest, TestMP3) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "AudioDecoderCalculator"
|
||||
input_side_packet: "INPUT_FILE_PATH:input_file_path"
|
||||
output_stream: "AUDIO:audio"
|
||||
output_stream: "AUDIO_HEADER:audio_header"
|
||||
node_options {
|
||||
[type.googleapis.com/mediapipe.AudioDecoderOptions]: {
|
||||
audio_stream { stream_index: 0 }
|
||||
}
|
||||
})pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
|
||||
file::JoinPath("./",
|
||||
"/mediapipe/calculators/audio/"
|
||||
"testdata/sine_wave_1k_44100_stereo_2_sec_mp3.audio"));
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
MP_EXPECT_OK(runner.Outputs()
|
||||
.Tag("AUDIO_HEADER")
|
||||
.header.ValidateAsType<mediapipe::TimeSeriesHeader>());
|
||||
const mediapipe::TimeSeriesHeader& header =
|
||||
runner.Outputs()
|
||||
.Tag("AUDIO_HEADER")
|
||||
.header.Get<mediapipe::TimeSeriesHeader>();
|
||||
EXPECT_EQ(44100, header.sample_rate());
|
||||
EXPECT_EQ(2, header.num_channels());
|
||||
EXPECT_TRUE(runner.Outputs().Tag("AUDIO").packets.size() >=
|
||||
std::ceil(44100.0 * 2 / 1152));
|
||||
}
|
||||
|
||||
TEST(AudioDecoderCalculatorTest, TestAAC) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "AudioDecoderCalculator"
|
||||
input_side_packet: "INPUT_FILE_PATH:input_file_path"
|
||||
output_stream: "AUDIO:audio"
|
||||
output_stream: "AUDIO_HEADER:audio_header"
|
||||
node_options {
|
||||
[type.googleapis.com/mediapipe.AudioDecoderOptions]: {
|
||||
audio_stream { stream_index: 0 }
|
||||
}
|
||||
})pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
|
||||
file::JoinPath("./",
|
||||
"/mediapipe/calculators/audio/"
|
||||
"testdata/sine_wave_1k_44100_stereo_2_sec_aac.audio"));
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
MP_EXPECT_OK(runner.Outputs()
|
||||
.Tag("AUDIO_HEADER")
|
||||
.header.ValidateAsType<mediapipe::TimeSeriesHeader>());
|
||||
const mediapipe::TimeSeriesHeader& header =
|
||||
runner.Outputs()
|
||||
.Tag("AUDIO_HEADER")
|
||||
.header.Get<mediapipe::TimeSeriesHeader>();
|
||||
EXPECT_EQ(44100, header.sample_rate());
|
||||
EXPECT_EQ(2, header.num_channels());
|
||||
EXPECT_TRUE(runner.Outputs().Tag("AUDIO").packets.size() >=
|
||||
std::ceil(44100.0 * 2 / 1024));
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,405 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Basic Calculators that operate on TimeSeries streams.
|
||||
#include "mediapipe/calculators/audio/basic_time_series_calculators.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/util/time_series_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
static bool SafeMultiply(int x, int y, int* result) {
|
||||
static_assert(sizeof(int64) >= 2 * sizeof(int),
|
||||
"Unable to detect overflow after multiplication");
|
||||
const int64 big = static_cast<int64>(x) * static_cast<int64>(y);
|
||||
if (big > static_cast<int64>(INT_MIN) && big < static_cast<int64>(INT_MAX)) {
|
||||
if (result != nullptr) *result = static_cast<int>(big);
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
absl::Status BasicTimeSeriesCalculatorBase::GetContract(
|
||||
CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).Set<Matrix>(
|
||||
// Input stream with TimeSeriesHeader.
|
||||
);
|
||||
cc->Outputs().Index(0).Set<Matrix>(
|
||||
// Output stream with TimeSeriesHeader.
|
||||
);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status BasicTimeSeriesCalculatorBase::Open(CalculatorContext* cc) {
|
||||
TimeSeriesHeader input_header;
|
||||
MP_RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
|
||||
cc->Inputs().Index(0).Header(), &input_header));
|
||||
|
||||
auto output_header = new TimeSeriesHeader(input_header);
|
||||
MP_RETURN_IF_ERROR(MutateHeader(output_header));
|
||||
cc->Outputs().Index(0).SetHeader(Adopt(output_header));
|
||||
|
||||
cc->SetOffset(0);
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status BasicTimeSeriesCalculatorBase::Process(CalculatorContext* cc) {
|
||||
const Matrix& input = cc->Inputs().Index(0).Get<Matrix>();
|
||||
MP_RETURN_IF_ERROR(time_series_util::IsMatrixShapeConsistentWithHeader(
|
||||
input, cc->Inputs().Index(0).Header().Get<TimeSeriesHeader>()));
|
||||
|
||||
std::unique_ptr<Matrix> output(new Matrix(ProcessMatrix(input)));
|
||||
MP_RETURN_IF_ERROR(time_series_util::IsMatrixShapeConsistentWithHeader(
|
||||
*output, cc->Outputs().Index(0).Header().Get<TimeSeriesHeader>()));
|
||||
|
||||
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status BasicTimeSeriesCalculatorBase::MutateHeader(
|
||||
TimeSeriesHeader* output_header) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Calculator to sum an input time series across channels. This is
|
||||
// useful for e.g. computing 'summary SAI' pitchogram features.
|
||||
//
|
||||
// Options proto: None.
|
||||
class SumTimeSeriesAcrossChannelsCalculator
|
||||
: public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
absl::Status MutateHeader(TimeSeriesHeader* output_header) final {
|
||||
output_header->set_num_channels(1);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
return input_matrix.colwise().sum();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(SumTimeSeriesAcrossChannelsCalculator);
|
||||
|
||||
// Calculator to average an input time series across channels. This is
|
||||
// useful for e.g. converting stereo or multi-channel files to mono.
|
||||
//
|
||||
// Options proto: None.
|
||||
class AverageTimeSeriesAcrossChannelsCalculator
|
||||
: public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
absl::Status MutateHeader(TimeSeriesHeader* output_header) final {
|
||||
output_header->set_num_channels(1);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
return input_matrix.colwise().mean();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(AverageTimeSeriesAcrossChannelsCalculator);
|
||||
|
||||
// Calculator to convert a (temporal) summary SAI stream (a single-channel
|
||||
// stream output by SumTimeSeriesAcrossChannelsCalculator) into pitchogram
|
||||
// frames by transposing the input packets, swapping the time and channel axes.
|
||||
//
|
||||
// Options proto: None.
|
||||
class SummarySaiToPitchogramCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
absl::Status MutateHeader(TimeSeriesHeader* output_header) final {
|
||||
if (output_header->num_channels() != 1) {
|
||||
return tool::StatusInvalid(
|
||||
absl::StrCat("Expected single-channel input, got ",
|
||||
output_header->num_channels()));
|
||||
}
|
||||
output_header->set_num_channels(output_header->num_samples());
|
||||
output_header->set_num_samples(1);
|
||||
output_header->set_sample_rate(output_header->packet_rate());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
return input_matrix.transpose();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(SummarySaiToPitchogramCalculator);
|
||||
|
||||
// Calculator to reverse the order of channels in TimeSeries packets.
|
||||
// This is useful for e.g. interfacing with the speech pipeline which uses the
|
||||
// opposite convention to the hearing filterbanks.
|
||||
//
|
||||
// Options proto: None.
|
||||
class ReverseChannelOrderCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
return input_matrix.colwise().reverse();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(ReverseChannelOrderCalculator);
|
||||
|
||||
// Calculator to flatten all samples in a TimeSeries packet down into
|
||||
// a single 'sample' vector. This is useful for e.g. stacking several
|
||||
// frames of features into a single feature vector.
|
||||
//
|
||||
// Options proto: None.
|
||||
class FlattenPacketCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
absl::Status MutateHeader(TimeSeriesHeader* output_header) final {
|
||||
const int num_input_channels = output_header->num_channels();
|
||||
const int num_input_samples = output_header->num_samples();
|
||||
RET_CHECK(num_input_channels >= 0)
|
||||
<< "FlattenPacketCalculator: num_input_channels < 0";
|
||||
RET_CHECK(num_input_samples >= 0)
|
||||
<< "FlattenPacketCalculator: num_input_samples < 0";
|
||||
int output_num_channels;
|
||||
RET_CHECK(SafeMultiply(num_input_channels, num_input_samples,
|
||||
&output_num_channels))
|
||||
<< "FlattenPacketCalculator: Multiplication failed.";
|
||||
output_header->set_num_channels(output_num_channels);
|
||||
output_header->set_num_samples(1);
|
||||
output_header->set_sample_rate(output_header->packet_rate());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
// Flatten by interleaving channels so that full samples are
|
||||
// stacked on top of each other instead of interleaving samples
|
||||
// from the same channel.
|
||||
Matrix output(input_matrix.size(), 1);
|
||||
for (int sample = 0; sample < input_matrix.cols(); ++sample) {
|
||||
output.middleRows(sample * input_matrix.rows(), input_matrix.rows()) =
|
||||
input_matrix.col(sample);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(FlattenPacketCalculator);
|
||||
|
||||
// Calculator to subtract the within-packet mean for each channel from each
|
||||
// corresponding channel.
|
||||
//
|
||||
// Options proto: None.
|
||||
class SubtractMeanCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
Matrix mean = input_matrix.rowwise().mean();
|
||||
return input_matrix - mean.replicate(1, input_matrix.cols());
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(SubtractMeanCalculator);
|
||||
|
||||
// Calculator to subtract the mean over all values (across all times and
|
||||
// channels) in a Packet from the values in that Packet.
|
||||
//
|
||||
// Options proto: None.
|
||||
class SubtractMeanAcrossChannelsCalculator
|
||||
: public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
auto mean = input_matrix.mean();
|
||||
return (input_matrix.array() - mean).matrix();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(SubtractMeanAcrossChannelsCalculator);
|
||||
|
||||
// Calculator to divide all values in a Packet by the average value across all
|
||||
// times and channels in the packet. This is useful for normalizing
|
||||
// nonnegative quantities like power, but might cause unexpected results if used
|
||||
// with Packets that can contain negative numbers.
|
||||
//
|
||||
// If mean is exactly zero, the output will be a matrix of all ones, because
|
||||
// that's what happens in other cases where all values are equal.
|
||||
//
|
||||
// Options proto: None.
|
||||
class DivideByMeanAcrossChannelsCalculator
|
||||
: public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
auto mean = input_matrix.mean();
|
||||
|
||||
if (mean != 0) {
|
||||
return input_matrix / mean;
|
||||
|
||||
// When used with nonnegative matrices, the mean will only be zero if the
|
||||
// entire matrix is exactly zero. If mean is exactly zero, the output will
|
||||
// be a matrix of all ones, because that's what happens in other cases
|
||||
// where
|
||||
// all values are equal.
|
||||
} else {
|
||||
return Matrix::Ones(input_matrix.rows(), input_matrix.cols());
|
||||
}
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(DivideByMeanAcrossChannelsCalculator);
|
||||
|
||||
// Calculator to calculate the mean for each channel.
|
||||
//
|
||||
// Options proto: None.
|
||||
class MeanCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
absl::Status MutateHeader(TimeSeriesHeader* output_header) final {
|
||||
output_header->set_num_samples(1);
|
||||
output_header->set_sample_rate(output_header->packet_rate());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
return input_matrix.rowwise().mean();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(MeanCalculator);
|
||||
|
||||
// Calculator to calculate the uncorrected sample standard deviation in each
|
||||
// channel, independently for each Packet. I.e. divide by the number of samples
|
||||
// in the Packet, not (<number of samples> - 1).
|
||||
//
|
||||
// Options proto: None.
|
||||
class StandardDeviationCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
absl::Status MutateHeader(TimeSeriesHeader* output_header) final {
|
||||
output_header->set_num_samples(1);
|
||||
output_header->set_sample_rate(output_header->packet_rate());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
Eigen::VectorXf mean = input_matrix.rowwise().mean();
|
||||
return (input_matrix.colwise() - mean).rowwise().norm() /
|
||||
sqrt(input_matrix.cols());
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(StandardDeviationCalculator);
|
||||
|
||||
// Calculator to calculate the covariance matrix. If the input matrix
|
||||
// has N channels, the output matrix will be an N by N symmetric
|
||||
// matrix.
|
||||
//
|
||||
// Options proto: None.
|
||||
class CovarianceCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
absl::Status MutateHeader(TimeSeriesHeader* output_header) final {
|
||||
output_header->set_num_samples(output_header->num_channels());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
auto mean = input_matrix.rowwise().mean();
|
||||
auto zero_mean_input =
|
||||
input_matrix - mean.replicate(1, input_matrix.cols());
|
||||
return (zero_mean_input * zero_mean_input.transpose()) /
|
||||
input_matrix.cols();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(CovarianceCalculator);
|
||||
|
||||
// Calculator to get the per column L2 norm of an input time series.
|
||||
//
|
||||
// Options proto: None.
|
||||
class L2NormCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
absl::Status MutateHeader(TimeSeriesHeader* output_header) final {
|
||||
output_header->set_num_channels(1);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
return input_matrix.colwise().norm();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(L2NormCalculator);
|
||||
|
||||
// Calculator to convert each column of a matrix to a unit vector.
|
||||
//
|
||||
// Options proto: None.
|
||||
class L2NormalizeColumnCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
return input_matrix.colwise().normalized();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(L2NormalizeColumnCalculator);
|
||||
|
||||
// Calculator to apply L2 normalization to the input matrix.
|
||||
//
|
||||
// Returns the matrix as is if the RMS is <= 1E-8.
|
||||
// Options proto: None.
|
||||
class L2NormalizeCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
constexpr double kEpsilon = 1e-8;
|
||||
double rms = std::sqrt(input_matrix.array().square().mean());
|
||||
if (rms <= kEpsilon) {
|
||||
return input_matrix;
|
||||
}
|
||||
return input_matrix / rms;
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(L2NormalizeCalculator);
|
||||
|
||||
// Calculator to apply Peak normalization to the input matrix.
|
||||
//
|
||||
// Returns the matrix as is if the peak is <= 1E-8.
|
||||
// Options proto: None.
|
||||
class PeakNormalizeCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
constexpr double kEpsilon = 1e-8;
|
||||
double max_pcm = input_matrix.cwiseAbs().maxCoeff();
|
||||
if (max_pcm <= kEpsilon) {
|
||||
return input_matrix;
|
||||
}
|
||||
return input_matrix / max_pcm;
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(PeakNormalizeCalculator);
|
||||
|
||||
// Calculator to compute the elementwise square of an input time series.
|
||||
//
|
||||
// Options proto: None.
|
||||
class ElementwiseSquareCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
return input_matrix.array().square();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(ElementwiseSquareCalculator);
|
||||
|
||||
// Calculator that outputs first floor(num_samples / 2) of the samples.
|
||||
//
|
||||
// Options proto: None.
|
||||
class FirstHalfSlicerCalculator : public BasicTimeSeriesCalculatorBase {
|
||||
protected:
|
||||
absl::Status MutateHeader(TimeSeriesHeader* output_header) final {
|
||||
const int num_input_samples = output_header->num_samples();
|
||||
RET_CHECK(num_input_samples >= 0)
|
||||
<< "FirstHalfSlicerCalculator: num_input_samples < 0";
|
||||
output_header->set_num_samples(num_input_samples / 2);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
Matrix ProcessMatrix(const Matrix& input_matrix) final {
|
||||
return input_matrix.block(0, 0, input_matrix.rows(),
|
||||
input_matrix.cols() / 2);
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(FirstHalfSlicerCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Abstract base class for basic MediaPipe calculators that operate on
|
||||
// TimeSeries streams and don't require any Options protos.
|
||||
// Subclasses must override ProcessMatrix, and optionally
|
||||
// MutateHeader.
|
||||
|
||||
#ifndef MEDIAPIPE_CALCULATORS_AUDIO_BASIC_TIME_SERIES_CALCULATORS_H_
|
||||
#define MEDIAPIPE_CALCULATORS_AUDIO_BASIC_TIME_SERIES_CALCULATORS_H_
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
class BasicTimeSeriesCalculatorBase : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc);
|
||||
absl::Status Open(CalculatorContext* cc) final;
|
||||
absl::Status Process(CalculatorContext* cc) final;
|
||||
|
||||
protected:
|
||||
// Open() calls this method to mutate the output stream header. The input
|
||||
// to this function will contain a copy of the input stream header, so
|
||||
// subclasses that do not need to mutate the header do not need to override
|
||||
// it.
|
||||
virtual absl::Status MutateHeader(TimeSeriesHeader* output_header);
|
||||
|
||||
// Process() calls this method on each packet to compute the output matrix.
|
||||
virtual Matrix ProcessMatrix(const Matrix& input_matrix) = 0;
|
||||
};
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_CALCULATORS_AUDIO_BASIC_TIME_SERIES_CALCULATORS_H_
|
||||
|
|
@ -1,515 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/util/time_series_test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
class SumTimeSeriesAcrossChannelsCalculatorTest
|
||||
: public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
calculator_name_ = "SumTimeSeriesAcrossChannelsCalculator";
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(SumTimeSeriesAcrossChannelsCalculatorTest, IsNoOpOnSingleChannelInputs) {
|
||||
const TimeSeriesHeader header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 num_channels: 1 num_samples: 5");
|
||||
const Matrix input =
|
||||
Matrix::Random(header.num_channels(), header.num_samples());
|
||||
|
||||
Test(header, {input}, header, {input});
|
||||
}
|
||||
|
||||
TEST_F(SumTimeSeriesAcrossChannelsCalculatorTest, ConstantPacket) {
|
||||
const TimeSeriesHeader header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 num_channels: 3 num_samples: 5");
|
||||
TimeSeriesHeader output_header(header);
|
||||
output_header.set_num_channels(1);
|
||||
|
||||
Test(header,
|
||||
{Matrix::Constant(header.num_channels(), header.num_samples(), 1)},
|
||||
output_header,
|
||||
{Matrix::Constant(1, header.num_samples(), header.num_channels())});
|
||||
}
|
||||
|
||||
TEST_F(SumTimeSeriesAcrossChannelsCalculatorTest, MultiplePackets) {
|
||||
const TimeSeriesHeader header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 num_channels: 3 num_samples: 5");
|
||||
Matrix in(header.num_channels(), header.num_samples());
|
||||
in << 10, -1, -1, 0, 0, 20, -2, 0, 1, 0, 30, -3, 1, 0, 12;
|
||||
|
||||
TimeSeriesHeader output_header(header);
|
||||
output_header.set_num_channels(1);
|
||||
Matrix out(1, header.num_samples());
|
||||
out << 60, -6, 0, 1, 12;
|
||||
|
||||
Test(header, {in, 2 * in, in + Matrix::Constant(in.rows(), in.cols(), 3.5f)},
|
||||
output_header,
|
||||
{out, 2 * out,
|
||||
out + Matrix::Constant(out.rows(), out.cols(),
|
||||
3.5 * header.num_channels())});
|
||||
}
|
||||
|
||||
class AverageTimeSeriesAcrossChannelsCalculatorTest
|
||||
: public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
calculator_name_ = "AverageTimeSeriesAcrossChannelsCalculator";
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(AverageTimeSeriesAcrossChannelsCalculatorTest,
|
||||
IsNoOpOnSingleChannelInputs) {
|
||||
const TimeSeriesHeader header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 num_channels: 1 num_samples: 5");
|
||||
const Matrix input =
|
||||
Matrix::Random(header.num_channels(), header.num_samples());
|
||||
|
||||
Test(header, {input}, header, {input});
|
||||
}
|
||||
|
||||
TEST_F(AverageTimeSeriesAcrossChannelsCalculatorTest, ConstantPacket) {
|
||||
const TimeSeriesHeader header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 num_channels: 3 num_samples: 5");
|
||||
TimeSeriesHeader output_header(header);
|
||||
output_header.set_num_channels(1);
|
||||
|
||||
Matrix input =
|
||||
Matrix::Constant(header.num_channels(), header.num_samples(), 0.0);
|
||||
input.row(0) = Matrix::Constant(1, header.num_samples(), 1.0);
|
||||
|
||||
Test(
|
||||
header, {input}, output_header,
|
||||
{Matrix::Constant(1, header.num_samples(), 1.0 / header.num_channels())});
|
||||
}
|
||||
|
||||
class SummarySaiToPitchogramCalculatorTest
|
||||
: public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
calculator_name_ = "SummarySaiToPitchogramCalculator";
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(SummarySaiToPitchogramCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 1 num_samples: 3");
|
||||
Matrix input(1, input_header.num_samples());
|
||||
input << 3, -9, 4;
|
||||
|
||||
const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 5.0 packet_rate: 5.0 num_channels: 3 num_samples: 1");
|
||||
Matrix output(input_header.num_samples(), 1);
|
||||
output << 3, -9, 4;
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
class ReverseChannelOrderCalculatorTest
|
||||
: public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "ReverseChannelOrderCalculator"; }
|
||||
};
|
||||
|
||||
TEST_F(ReverseChannelOrderCalculatorTest, IsNoOpOnSingleChannelInputs) {
|
||||
const TimeSeriesHeader header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 num_channels: 1 num_samples: 5");
|
||||
const Matrix input =
|
||||
Matrix::Random(header.num_channels(), header.num_samples());
|
||||
|
||||
Test(header, {input}, header, {input});
|
||||
}
|
||||
|
||||
TEST_F(ReverseChannelOrderCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 num_channels: 5 num_samples: 2");
|
||||
Matrix input(header.num_channels(), header.num_samples());
|
||||
input.transpose() << 1, 2, 3, 4, 5, -1, -2, -3, -4, -5;
|
||||
Matrix output(header.num_channels(), header.num_samples());
|
||||
output.transpose() << 5, 4, 3, 2, 1, -5, -4, -3, -2, -1;
|
||||
|
||||
Test(header, {input}, header, {output});
|
||||
}
|
||||
|
||||
class FlattenPacketCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "FlattenPacketCalculator"; }
|
||||
};
|
||||
|
||||
TEST_F(FlattenPacketCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 2");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
input.transpose() << 1, 2, 3, 4, 5, -1, -2, -3, -4, -5;
|
||||
Matrix output(10, 1);
|
||||
output << 1, 2, 3, 4, 5, -1, -2, -3, -4, -5;
|
||||
|
||||
const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 10.0 packet_rate: 10.0 num_channels: 10 num_samples: 1");
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
class SubtractMeanCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "SubtractMeanCalculator"; }
|
||||
};
|
||||
|
||||
TEST_F(SubtractMeanCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 2");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
Matrix output(input_header.num_channels(), input_header.num_samples());
|
||||
|
||||
// clang-format off
|
||||
input.transpose() << 1, 0, 3, 0, 1,
|
||||
-1, -2, -3, 4, 7;
|
||||
output.transpose() << 1, 1, 3, -2, -3,
|
||||
-1, -1, -3, 2, 3;
|
||||
// clang-format on
|
||||
|
||||
const TimeSeriesHeader output_header = input_header;
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
class SubtractMeanAcrossChannelsCalculatorTest
|
||||
: public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
calculator_name_ = "SubtractMeanAcrossChannelsCalculator";
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(SubtractMeanAcrossChannelsCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
|
||||
TimeSeriesHeader output_header(input_header);
|
||||
output_header.set_num_samples(2);
|
||||
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
|
||||
// clang-format off
|
||||
input.transpose() << 1.0, 2.0, 3.0,
|
||||
4.0, 5.0, 6.0;
|
||||
output.transpose() << 1.0 - 3.5, 2.0 - 3.5, 3.0 - 3.5,
|
||||
4.0 - 3.5, 5.0 - 3.5, 6.0 - 3.5;
|
||||
// clang-format on
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
class DivideByMeanAcrossChannelsCalculatorTest
|
||||
: public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
calculator_name_ = "DivideByMeanAcrossChannelsCalculator";
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(DivideByMeanAcrossChannelsCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
input.transpose() << 1.0, 2.0, 3.0, 4.0, 5.0, 6.0;
|
||||
|
||||
TimeSeriesHeader output_header(input_header);
|
||||
output_header.set_num_samples(2);
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
output.transpose() << 1.0 / 3.5, 2.0 / 3.5, 3.0 / 3.5, 4.0 / 3.5, 5.0 / 3.5,
|
||||
6.0 / 3.5;
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
TEST_F(DivideByMeanAcrossChannelsCalculatorTest, ReturnsOneForZeroMean) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
input.transpose() << -3.0, -2.0, -1.0, 1.0, 2.0, 3.0;
|
||||
|
||||
TimeSeriesHeader output_header(input_header);
|
||||
output_header.set_num_samples(2);
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
output.transpose() << 1.0, 1.0, 1.0, 1.0, 1.0, 1.0;
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
class MeanCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "MeanCalculator"; }
|
||||
};
|
||||
|
||||
TEST_F(MeanCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
input.transpose() << 1.0, 2.0, 3.0, 4.0, 5.0, 6.0;
|
||||
|
||||
TimeSeriesHeader output_header(input_header);
|
||||
output_header.set_num_samples(1);
|
||||
output_header.set_sample_rate(10.0);
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
output << (1.0 + 4.0) / 2, (2.0 + 5.0) / 2, (3.0 + 6.0) / 2;
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
class StandardDeviationCalculatorTest
|
||||
: public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "StandardDeviationCalculator"; }
|
||||
};
|
||||
|
||||
TEST_F(StandardDeviationCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
input.transpose() << 0.0, 2.0, 3.0, 4.0, 5.0, 8.0;
|
||||
|
||||
TimeSeriesHeader output_header(input_header);
|
||||
output_header.set_sample_rate(10.0);
|
||||
output_header.set_num_samples(1);
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
output << sqrt((pow(0.0 - 2.0, 2) + pow(4.0 - 2.0, 2)) / 2),
|
||||
sqrt((pow(2.0 - 3.5, 2) + pow(5.0 - 3.5, 2)) / 2),
|
||||
sqrt((pow(3.0 - 5.5, 2) + pow(8.0 - 5.5, 2)) / 2);
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
class CovarianceCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "CovarianceCalculator"; }
|
||||
};
|
||||
|
||||
TEST_F(CovarianceCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 3 num_samples: 2");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
|
||||
// We'll specify in transposed form so we can write one channel at a time.
|
||||
input << 1.0, 3.0, 5.0, 9.0, -1.0, -3.0;
|
||||
|
||||
TimeSeriesHeader output_header(input_header);
|
||||
output_header.set_num_samples(output_header.num_channels());
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
output << 1, 2, -1, 2, 4, -2, -1, -2, 1;
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
class L2NormCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "L2NormCalculator"; }
|
||||
};
|
||||
|
||||
TEST_F(L2NormCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
input << 3, 5, 8, 4, 12, -15;
|
||||
|
||||
const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 1 num_samples: 3");
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
output << 5, 13, 17;
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
class L2NormalizeColumnCalculatorTest
|
||||
: public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "L2NormalizeColumnCalculator"; }
|
||||
};
|
||||
|
||||
TEST_F(L2NormalizeColumnCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
input << 0.3, 0.4, 0.8, 0.5, 0.9, 0.8;
|
||||
|
||||
const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
|
||||
// The values in output are column-wise L2 normalized
|
||||
// e.g.
|
||||
// |a| -> |a/sqrt(a^2 + b^2)|
|
||||
// |b| |b/sqrt(a^2 + b^2)|
|
||||
output << 0.51449579000473022, 0.40613847970962524, 0.70710676908493042,
|
||||
0.85749292373657227, 0.91381156444549561, 0.70710676908493042;
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
class L2NormalizeCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "L2NormalizeCalculator"; }
|
||||
};
|
||||
|
||||
TEST_F(L2NormalizeCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
input << 0.3, 0.4, 0.8, 0.5, 0.9, 0.8;
|
||||
|
||||
const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
|
||||
// The values in output are L2 normalized
|
||||
// a -> a/sqrt(a^2 + b^2 + c^2 + ...) * sqrt(matrix.cols()*matrix.rows())
|
||||
output << 0.45661166, 0.60881555, 1.21763109, 0.76101943, 1.36983498,
|
||||
1.21763109;
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
TEST_F(L2NormalizeCalculatorTest, UnitMatrixStaysUnchanged) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 3 num_samples: 5");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
input << 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
|
||||
1.0, -1.0, 1.0;
|
||||
|
||||
Test(input_header, {input}, input_header, {input});
|
||||
}
|
||||
|
||||
class PeakNormalizeCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "PeakNormalizeCalculator"; }
|
||||
};
|
||||
|
||||
TEST_F(PeakNormalizeCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
input << 0.3, 0.4, 0.8, 0.5, 0.9, 0.8;
|
||||
|
||||
const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
output << 0.33333333, 0.44444444, 0.88888889, 0.55555556, 1.0, 0.88888889;
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
TEST_F(PeakNormalizeCalculatorTest, UnitMatrixStaysUnchanged) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 3 num_samples: 5");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
input << 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
|
||||
1.0, -1.0, 1.0;
|
||||
|
||||
Test(input_header, {input}, input_header, {input});
|
||||
}
|
||||
|
||||
class ElementwiseSquareCalculatorTest
|
||||
: public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "ElementwiseSquareCalculator"; }
|
||||
};
|
||||
|
||||
TEST_F(ElementwiseSquareCalculatorTest, SinglePacket) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
input << 3, 5, 8, 4, 12, -15;
|
||||
|
||||
const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 8000.0 packet_rate: 5.0 num_channels: 2 num_samples: 3");
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
output << 9, 25, 64, 16, 144, 225;
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
class FirstHalfSlicerCalculatorTest : public BasicTimeSeriesCalculatorTestBase {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "FirstHalfSlicerCalculator"; }
|
||||
};
|
||||
|
||||
TEST_F(FirstHalfSlicerCalculatorTest, SinglePacketEvenNumSamples) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 2");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
// clang-format off
|
||||
input.transpose() << 0, 1, 2, 3, 4,
|
||||
5, 6, 7, 8, 9;
|
||||
// clang-format on
|
||||
|
||||
const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 1");
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
output.transpose() << 0, 1, 2, 3, 4;
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
TEST_F(FirstHalfSlicerCalculatorTest, SinglePacketOddNumSamples) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 3");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
// clang-format off
|
||||
input.transpose() << 0, 1, 2, 3, 4,
|
||||
5, 6, 7, 8, 9,
|
||||
0, 0, 0, 0, 0;
|
||||
// clang-format on
|
||||
|
||||
const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 1");
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
output.transpose() << 0, 1, 2, 3, 4;
|
||||
|
||||
Test(input_header, {input}, output_header, {output});
|
||||
}
|
||||
|
||||
TEST_F(FirstHalfSlicerCalculatorTest, MultiplePackets) {
|
||||
const TimeSeriesHeader input_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 2");
|
||||
Matrix input(input_header.num_channels(), input_header.num_samples());
|
||||
// clang-format off
|
||||
input.transpose() << 0, 1, 2, 3, 4,
|
||||
5, 6, 7, 8, 9;
|
||||
// clang-format on
|
||||
const TimeSeriesHeader output_header = ParseTextProtoOrDie<TimeSeriesHeader>(
|
||||
"sample_rate: 20.0 packet_rate: 10.0 num_channels: 5 num_samples: 1");
|
||||
Matrix output(output_header.num_channels(), output_header.num_samples());
|
||||
output.transpose() << 0, 1, 2, 3, 4;
|
||||
|
||||
Test(input_header,
|
||||
{input, 2 * input,
|
||||
input + Matrix::Constant(input.rows(), input.cols(), 3.5f)},
|
||||
output_header,
|
||||
{output, 2 * output,
|
||||
output + Matrix::Constant(output.rows(), output.cols(), 3.5f)});
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,275 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// MediaPipe Calculator wrapper around audio/dsp/mfcc/
|
||||
// classes MelFilterbank (magnitude spectrograms warped to the Mel
|
||||
// approximation of the auditory frequency scale) and Mfcc (Mel Frequency
|
||||
// Cepstral Coefficients, the decorrelated transform of log-Mel-spectrum
|
||||
// commonly used as acoustic features in speech and other audio tasks.
|
||||
// Both calculators expect as input the SQUARED_MAGNITUDE-domain outputs
|
||||
// from the MediaPipe SpectrogramCalculator object.
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "audio/dsp/mfcc/mel_filterbank.h"
|
||||
#include "audio/dsp/mfcc/mfcc.h"
|
||||
#include "mediapipe/calculators/audio/mfcc_mel_calculators.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/util/time_series_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
// Portable version of TimeSeriesHeader's DebugString.
|
||||
std::string PortableDebugString(const TimeSeriesHeader& header) {
|
||||
std::string unsubstituted_header_debug_str = R"(
|
||||
sample_rate: $0
|
||||
num_channels: $1
|
||||
num_samples: $2
|
||||
packet_rate: $3
|
||||
audio_sample_rate: $4
|
||||
)";
|
||||
return absl::Substitute(unsubstituted_header_debug_str, header.sample_rate(),
|
||||
header.num_channels(), header.num_samples(),
|
||||
header.packet_rate(), header.audio_sample_rate());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Abstract base class for Calculators that transform feature vectors on a
|
||||
// frame-by-frame basis.
|
||||
// Subclasses must override pure virtual methods ConfigureTransform and
|
||||
// TransformFrame.
|
||||
// Input and output MediaPipe packets are matrices with one column per frame,
|
||||
// and one row per feature dimension. Each input packet results in an
|
||||
// output packet with the same number of columns (but differing numbers of
|
||||
// rows corresponding to the new feature space).
|
||||
class FramewiseTransformCalculatorBase : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).Set<Matrix>(
|
||||
// Sequence of Matrices, each column describing a particular time frame,
|
||||
// each row a feature dimension, with TimeSeriesHeader.
|
||||
);
|
||||
cc->Outputs().Index(0).Set<Matrix>(
|
||||
// Sequence of Matrices, each column describing a particular time frame,
|
||||
// each row a feature dimension, with TimeSeriesHeader.
|
||||
);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) final;
|
||||
absl::Status Process(CalculatorContext* cc) final;
|
||||
|
||||
int num_output_channels(void) { return num_output_channels_; }
|
||||
|
||||
void set_num_output_channels(int num_output_channels) {
|
||||
num_output_channels_ = num_output_channels;
|
||||
}
|
||||
|
||||
private:
|
||||
// Takes header and options, and sets up state including calling
|
||||
// set_num_output_channels() on the base object.
|
||||
virtual absl::Status ConfigureTransform(const TimeSeriesHeader& header,
|
||||
CalculatorContext* cc) = 0;
|
||||
|
||||
// Takes a vector<double> corresponding to an input frame, and
|
||||
// perform the specific transformation to produce an output frame.
|
||||
virtual void TransformFrame(const std::vector<double>& input,
|
||||
std::vector<double>* output) const = 0;
|
||||
|
||||
private:
|
||||
int num_output_channels_;
|
||||
};
|
||||
|
||||
absl::Status FramewiseTransformCalculatorBase::Open(CalculatorContext* cc) {
|
||||
TimeSeriesHeader input_header;
|
||||
MP_RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
|
||||
cc->Inputs().Index(0).Header(), &input_header));
|
||||
|
||||
absl::Status status = ConfigureTransform(input_header, cc);
|
||||
|
||||
auto output_header = new TimeSeriesHeader(input_header);
|
||||
output_header->set_num_channels(num_output_channels_);
|
||||
cc->Outputs().Index(0).SetHeader(Adopt(output_header));
|
||||
|
||||
cc->SetOffset(0);
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
absl::Status FramewiseTransformCalculatorBase::Process(CalculatorContext* cc) {
|
||||
const Matrix& input = cc->Inputs().Index(0).Get<Matrix>();
|
||||
const int num_frames = input.cols();
|
||||
std::unique_ptr<Matrix> output(new Matrix(num_output_channels_, num_frames));
|
||||
// The main work here is converting each column of the float Matrix
|
||||
// into a vector of doubles, which is what our target functions from
|
||||
// dsp_core consume, and doing the reverse with their output.
|
||||
std::vector<double> input_frame(input.rows());
|
||||
std::vector<double> output_frame(num_output_channels_);
|
||||
|
||||
for (int frame = 0; frame < num_frames; ++frame) {
|
||||
// Copy input from Eigen::Matrix column to vector<float>.
|
||||
Eigen::Map<Eigen::MatrixXd> input_frame_map(&input_frame[0],
|
||||
input_frame.size(), 1);
|
||||
input_frame_map = input.col(frame).cast<double>();
|
||||
|
||||
// Perform the actual transformation.
|
||||
TransformFrame(input_frame, &output_frame);
|
||||
|
||||
// Copy output from vector<float> to Eigen::Vector.
|
||||
CHECK_EQ(output_frame.size(), num_output_channels_);
|
||||
Eigen::Map<const Eigen::MatrixXd> output_frame_map(&output_frame[0],
|
||||
output_frame.size(), 1);
|
||||
output->col(frame) = output_frame_map.cast<float>();
|
||||
}
|
||||
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Calculator wrapper around the dsp/mfcc/mfcc.cc routine.
|
||||
// Take frames of squared-magnitude spectra from the SpectrogramCalculator
|
||||
// and convert them into Mel Frequency Cepstral Coefficients.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "MfccCalculator"
|
||||
// input_stream: "spectrogram_frames_stream"
|
||||
// output_stream: "mfcc_frames_stream"
|
||||
// options {
|
||||
// [mediapipe.MfccCalculatorOptions.ext] {
|
||||
// mel_spectrum_params {
|
||||
// channel_count: 20
|
||||
// min_frequency_hertz: 125.0
|
||||
// max_frequency_hertz: 3800.0
|
||||
// }
|
||||
// mfcc_count: 13
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class MfccCalculator : public FramewiseTransformCalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
return FramewiseTransformCalculatorBase::GetContract(cc);
|
||||
}
|
||||
|
||||
private:
|
||||
absl::Status ConfigureTransform(const TimeSeriesHeader& header,
|
||||
CalculatorContext* cc) override {
|
||||
MfccCalculatorOptions mfcc_options = cc->Options<MfccCalculatorOptions>();
|
||||
mfcc_.reset(new audio_dsp::Mfcc());
|
||||
int input_length = header.num_channels();
|
||||
// Set up the parameters to the Mfcc object.
|
||||
set_num_output_channels(mfcc_options.mfcc_count());
|
||||
mfcc_->set_dct_coefficient_count(num_output_channels());
|
||||
mfcc_->set_upper_frequency_limit(
|
||||
mfcc_options.mel_spectrum_params().max_frequency_hertz());
|
||||
mfcc_->set_lower_frequency_limit(
|
||||
mfcc_options.mel_spectrum_params().min_frequency_hertz());
|
||||
mfcc_->set_filterbank_channel_count(
|
||||
mfcc_options.mel_spectrum_params().channel_count());
|
||||
// An upstream calculator (such as SpectrogramCalculator) must store
|
||||
// the sample rate of its input audio waveform in the TimeSeries Header.
|
||||
// audio_dsp::MelFilterBank needs to know this to
|
||||
// correctly interpret the spectrogram bins.
|
||||
if (!header.has_audio_sample_rate()) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("No audio_sample_rate in input TimeSeriesHeader ",
|
||||
PortableDebugString(header)));
|
||||
}
|
||||
// Now we can initialize the Mfcc object.
|
||||
bool initialized =
|
||||
mfcc_->Initialize(input_length, header.audio_sample_rate());
|
||||
|
||||
if (initialized) {
|
||||
return absl::OkStatus();
|
||||
} else {
|
||||
return absl::Status(absl::StatusCode::kInternal,
|
||||
"Mfcc::Initialize returned uninitialized");
|
||||
}
|
||||
}
|
||||
|
||||
void TransformFrame(const std::vector<double>& input,
|
||||
std::vector<double>* output) const override {
|
||||
mfcc_->Compute(input, output);
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<audio_dsp::Mfcc> mfcc_;
|
||||
};
|
||||
REGISTER_CALCULATOR(MfccCalculator);
|
||||
|
||||
// Calculator wrapper around the dsp/mfcc/mel_filterbank.cc routine.
|
||||
// Take frames of squared-magnitude spectra from the SpectrogramCalculator
|
||||
// and convert them into Mel-warped (linear-magnitude) spectra.
|
||||
// Note: This code computes a mel-frequency filterbank, using a simple
|
||||
// algorithm that gives bad results (some mel channels that are always zero)
|
||||
// if you ask for too many channels.
|
||||
class MelSpectrumCalculator : public FramewiseTransformCalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
return FramewiseTransformCalculatorBase::GetContract(cc);
|
||||
}
|
||||
|
||||
private:
|
||||
absl::Status ConfigureTransform(const TimeSeriesHeader& header,
|
||||
CalculatorContext* cc) override {
|
||||
MelSpectrumCalculatorOptions mel_spectrum_options =
|
||||
cc->Options<MelSpectrumCalculatorOptions>();
|
||||
mel_filterbank_.reset(new audio_dsp::MelFilterbank());
|
||||
int input_length = header.num_channels();
|
||||
set_num_output_channels(mel_spectrum_options.channel_count());
|
||||
// An upstream calculator (such as SpectrogramCalculator) must store
|
||||
// the sample rate of its input audio waveform in the TimeSeries Header.
|
||||
// audio_dsp::MelFilterBank needs to know this to
|
||||
// correctly interpret the spectrogram bins.
|
||||
if (!header.has_audio_sample_rate()) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("No audio_sample_rate in input TimeSeriesHeader ",
|
||||
PortableDebugString(header)));
|
||||
}
|
||||
bool initialized = mel_filterbank_->Initialize(
|
||||
input_length, header.audio_sample_rate(), num_output_channels(),
|
||||
mel_spectrum_options.min_frequency_hertz(),
|
||||
mel_spectrum_options.max_frequency_hertz());
|
||||
|
||||
if (initialized) {
|
||||
return absl::OkStatus();
|
||||
} else {
|
||||
return absl::Status(absl::StatusCode::kInternal,
|
||||
"mfcc::Initialize returned uninitialized");
|
||||
}
|
||||
}
|
||||
|
||||
void TransformFrame(const std::vector<double>& input,
|
||||
std::vector<double>* output) const override {
|
||||
mel_filterbank_->Compute(input, output);
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<audio_dsp::MelFilterbank> mel_filterbank_;
|
||||
};
|
||||
REGISTER_CALCULATOR(MelSpectrumCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,50 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
message MelSpectrumCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional MelSpectrumCalculatorOptions ext = 78581812;
|
||||
}
|
||||
// The fields are to populate the config parameters in
|
||||
// audio/dsp/mfcc/mel_filterbank.h
|
||||
// but the names are chose to mirror
|
||||
// audio/hearing/filterbanks/cochlea_gammatone_filterbank.proto
|
||||
// and the default values match those in
|
||||
// speech/greco3/frontend/filter_bank.proto .
|
||||
|
||||
// Total number of frequency bands to use.
|
||||
optional int32 channel_count = 1 [default = 20];
|
||||
// Lower edge of lowest triangular Mel band.
|
||||
optional float min_frequency_hertz = 2 [default = 125.0];
|
||||
// Upper edge of highest triangular Mel band.
|
||||
optional float max_frequency_hertz = 3 [default = 3800.0];
|
||||
}
|
||||
|
||||
message MfccCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional MfccCalculatorOptions ext = 78450441;
|
||||
}
|
||||
|
||||
// Specification of the underlying mel filterbank.
|
||||
optional MelSpectrumCalculatorOptions mel_spectrum_params = 1;
|
||||
|
||||
// How many MFCC coefficients to emit.
|
||||
optional uint32 mfcc_count = 2 [default = 13];
|
||||
}
|
||||
|
|
@ -1,149 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "mediapipe/calculators/audio/mfcc_mel_calculators.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/util/time_series_test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Use a sample rate that is unlikely to be a default somewhere.
|
||||
const float kAudioSampleRate = 8800.0;
|
||||
|
||||
template <typename OptionsType, const char* CalculatorName>
|
||||
class FramewiseTransformCalculatorTest
|
||||
: public TimeSeriesCalculatorTest<OptionsType> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
this->calculator_name_ = CalculatorName;
|
||||
this->num_input_channels_ = 129;
|
||||
// This is the frame rate coming out of the SpectrogramCalculator.
|
||||
this->input_sample_rate_ = 100.0;
|
||||
}
|
||||
|
||||
// Returns the number of samples per packet.
|
||||
int GenerateRandomNonnegInputStream(int num_packets) {
|
||||
const double kSecondsPerPacket = 0.2;
|
||||
const int num_samples_per_packet =
|
||||
kSecondsPerPacket * this->input_sample_rate_;
|
||||
for (int i = 0; i < num_packets; ++i) {
|
||||
const int timestamp =
|
||||
i * kSecondsPerPacket * Timestamp::kTimestampUnitsPerSecond;
|
||||
// Mfcc, MelSpectrum expect squared-magnitude inputs, so make
|
||||
// sure the input data has no negative values.
|
||||
Matrix* sqdata = this->NewRandomMatrix(this->num_input_channels_,
|
||||
num_samples_per_packet);
|
||||
*sqdata = sqdata->array().square();
|
||||
this->AppendInputPacket(sqdata, timestamp);
|
||||
}
|
||||
return num_samples_per_packet;
|
||||
}
|
||||
|
||||
void CheckOutputPacketMetadata(int expected_num_channels,
|
||||
int expected_num_samples_per_packet) {
|
||||
int expected_timestamp = 0;
|
||||
for (const auto& packet : this->output().packets) {
|
||||
EXPECT_EQ(expected_timestamp, packet.Timestamp().Value());
|
||||
expected_timestamp += expected_num_samples_per_packet /
|
||||
this->input_sample_rate_ *
|
||||
Timestamp::kTimestampUnitsPerSecond;
|
||||
|
||||
const Matrix& output_matrix = packet.template Get<Matrix>();
|
||||
|
||||
EXPECT_EQ(output_matrix.rows(), expected_num_channels);
|
||||
EXPECT_EQ(output_matrix.cols(), expected_num_samples_per_packet);
|
||||
}
|
||||
}
|
||||
|
||||
void SetupGraphAndHeader() {
|
||||
this->InitializeGraph();
|
||||
this->FillInputHeader();
|
||||
}
|
||||
|
||||
// Argument is the expected number of dimensions (channels, columns) in
|
||||
// the output data from the Calculator under test, which the test should
|
||||
// know.
|
||||
void SetupRandomInputPackets() {
|
||||
constexpr int kNumPackets = 5;
|
||||
num_samples_per_packet_ = GenerateRandomNonnegInputStream(kNumPackets);
|
||||
}
|
||||
|
||||
absl::Status Run() { return this->RunGraph(); }
|
||||
|
||||
void CheckResults(int expected_num_channels) {
|
||||
const auto& output_header =
|
||||
this->output().header.template Get<TimeSeriesHeader>();
|
||||
EXPECT_EQ(this->input_sample_rate_, output_header.sample_rate());
|
||||
CheckOutputPacketMetadata(expected_num_channels, num_samples_per_packet_);
|
||||
|
||||
// Sanity check that output packets have non-zero energy.
|
||||
for (const auto& packet : this->output().packets) {
|
||||
const Matrix& data = packet.template Get<Matrix>();
|
||||
EXPECT_GT(data.squaredNorm(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Allows SetupRandomInputPackets() to inform CheckResults() about how
|
||||
// big the packets are supposed to be.
|
||||
int num_samples_per_packet_;
|
||||
};
|
||||
|
||||
constexpr char kMfccCalculator[] = "MfccCalculator";
|
||||
typedef FramewiseTransformCalculatorTest<MfccCalculatorOptions, kMfccCalculator>
|
||||
MfccCalculatorTest;
|
||||
TEST_F(MfccCalculatorTest, AudioSampleRateFromInputHeader) {
|
||||
audio_sample_rate_ = kAudioSampleRate;
|
||||
SetupGraphAndHeader();
|
||||
SetupRandomInputPackets();
|
||||
|
||||
MP_EXPECT_OK(Run());
|
||||
|
||||
CheckResults(options_.mfcc_count());
|
||||
}
|
||||
TEST_F(MfccCalculatorTest, NoAudioSampleRate) {
|
||||
// Leave audio_sample_rate_ == kUnset, so it is not present in the
|
||||
// input TimeSeriesHeader; expect failure.
|
||||
SetupGraphAndHeader();
|
||||
SetupRandomInputPackets();
|
||||
|
||||
EXPECT_FALSE(Run().ok());
|
||||
}
|
||||
|
||||
constexpr char kMelSpectrumCalculator[] = "MelSpectrumCalculator";
|
||||
typedef FramewiseTransformCalculatorTest<MelSpectrumCalculatorOptions,
|
||||
kMelSpectrumCalculator>
|
||||
MelSpectrumCalculatorTest;
|
||||
TEST_F(MelSpectrumCalculatorTest, AudioSampleRateFromInputHeader) {
|
||||
audio_sample_rate_ = kAudioSampleRate;
|
||||
SetupGraphAndHeader();
|
||||
SetupRandomInputPackets();
|
||||
|
||||
MP_EXPECT_OK(Run());
|
||||
|
||||
CheckResults(options_.channel_count());
|
||||
}
|
||||
TEST_F(MelSpectrumCalculatorTest, NoAudioSampleRate) {
|
||||
// Leave audio_sample_rate_ == kUnset, so it is not present in the
|
||||
// input TimeSeriesHeader; expect failure.
|
||||
SetupGraphAndHeader();
|
||||
SetupRandomInputPackets();
|
||||
|
||||
EXPECT_FALSE(Run().ok());
|
||||
}
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,190 +0,0 @@
|
|||
// Copyright 2019, 2021 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Defines RationalFactorResampleCalculator.
|
||||
|
||||
#include "mediapipe/calculators/audio/rational_factor_resample_calculator.h"
|
||||
|
||||
#include "audio/dsp/resampler_q.h"
|
||||
|
||||
using audio_dsp::Resampler;
|
||||
|
||||
namespace mediapipe {
|
||||
absl::Status RationalFactorResampleCalculator::Process(CalculatorContext* cc) {
|
||||
return ProcessInternal(cc->Inputs().Index(0).Get<Matrix>(), false, cc);
|
||||
}
|
||||
|
||||
absl::Status RationalFactorResampleCalculator::Close(CalculatorContext* cc) {
|
||||
if (initial_timestamp_ == Timestamp::Unstarted()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
Matrix empty_input_frame(num_channels_, 0);
|
||||
return ProcessInternal(empty_input_frame, true, cc);
|
||||
}
|
||||
|
||||
namespace {
|
||||
void CopyChannelToVector(const Matrix& matrix, int channel,
|
||||
std::vector<float>* vec) {
|
||||
vec->resize(matrix.cols());
|
||||
Eigen::Map<Eigen::ArrayXf>(vec->data(), vec->size()) = matrix.row(channel);
|
||||
}
|
||||
|
||||
void CopyVectorToChannel(const std::vector<float>& vec, Matrix* matrix,
|
||||
int channel) {
|
||||
if (matrix->cols() == 0) {
|
||||
matrix->resize(matrix->rows(), vec.size());
|
||||
} else {
|
||||
CHECK_EQ(vec.size(), matrix->cols());
|
||||
}
|
||||
CHECK_LT(channel, matrix->rows());
|
||||
matrix->row(channel) =
|
||||
Eigen::Map<const Eigen::ArrayXf>(vec.data(), vec.size());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
absl::Status RationalFactorResampleCalculator::Open(CalculatorContext* cc) {
|
||||
RationalFactorResampleCalculatorOptions resample_options =
|
||||
cc->Options<RationalFactorResampleCalculatorOptions>();
|
||||
|
||||
if (!resample_options.has_target_sample_rate()) {
|
||||
return tool::StatusInvalid(
|
||||
"resample_options doesn't have target_sample_rate.");
|
||||
}
|
||||
target_sample_rate_ = resample_options.target_sample_rate();
|
||||
|
||||
TimeSeriesHeader input_header;
|
||||
MP_RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
|
||||
cc->Inputs().Index(0).Header(), &input_header));
|
||||
|
||||
source_sample_rate_ = input_header.sample_rate();
|
||||
num_channels_ = input_header.num_channels();
|
||||
|
||||
// Don't create resamplers for pass-thru (sample rates are equal).
|
||||
if (source_sample_rate_ != target_sample_rate_) {
|
||||
resampler_.resize(num_channels_);
|
||||
for (auto& r : resampler_) {
|
||||
r = ResamplerFromOptions(source_sample_rate_, target_sample_rate_,
|
||||
resample_options);
|
||||
if (!r) {
|
||||
LOG(ERROR) << "Failed to initialize resampler.";
|
||||
return absl::UnknownError("Failed to initialize resampler.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TimeSeriesHeader* output_header = new TimeSeriesHeader(input_header);
|
||||
output_header->set_sample_rate(target_sample_rate_);
|
||||
// The resampler doesn't make guarantees about how many samples will
|
||||
// be in each packet.
|
||||
output_header->clear_packet_rate();
|
||||
output_header->clear_num_samples();
|
||||
|
||||
cc->Outputs().Index(0).SetHeader(Adopt(output_header));
|
||||
cumulative_output_samples_ = 0;
|
||||
cumulative_input_samples_ = 0;
|
||||
initial_timestamp_ = Timestamp::Unstarted();
|
||||
check_inconsistent_timestamps_ =
|
||||
resample_options.check_inconsistent_timestamps();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status RationalFactorResampleCalculator::ProcessInternal(
|
||||
const Matrix& input_frame, bool should_flush, CalculatorContext* cc) {
|
||||
if (initial_timestamp_ == Timestamp::Unstarted()) {
|
||||
initial_timestamp_ = cc->InputTimestamp();
|
||||
}
|
||||
|
||||
if (check_inconsistent_timestamps_) {
|
||||
time_series_util::LogWarningIfTimestampIsInconsistent(
|
||||
cc->InputTimestamp(), initial_timestamp_, cumulative_input_samples_,
|
||||
source_sample_rate_);
|
||||
}
|
||||
Timestamp output_timestamp =
|
||||
initial_timestamp_ + ((cumulative_output_samples_ / target_sample_rate_) *
|
||||
Timestamp::kTimestampUnitsPerSecond);
|
||||
|
||||
cumulative_input_samples_ += input_frame.cols();
|
||||
std::unique_ptr<Matrix> output_frame(new Matrix(num_channels_, 0));
|
||||
if (resampler_.empty()) {
|
||||
// Sample rates were same for input and output; pass-thru.
|
||||
*output_frame = input_frame;
|
||||
} else {
|
||||
if (!Resample(input_frame, output_frame.get(), should_flush)) {
|
||||
return absl::UnknownError("Resample() failed.");
|
||||
}
|
||||
}
|
||||
cumulative_output_samples_ += output_frame->cols();
|
||||
|
||||
if (output_frame->cols() > 0) {
|
||||
cc->Outputs().Index(0).Add(output_frame.release(), output_timestamp);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
bool RationalFactorResampleCalculator::Resample(const Matrix& input_frame,
|
||||
Matrix* output_frame,
|
||||
bool should_flush) {
|
||||
std::vector<float> input_vector;
|
||||
std::vector<float> output_vector;
|
||||
for (int i = 0; i < input_frame.rows(); ++i) {
|
||||
CopyChannelToVector(input_frame, i, &input_vector);
|
||||
if (should_flush) {
|
||||
resampler_[i]->Flush(&output_vector);
|
||||
} else {
|
||||
resampler_[i]->ProcessSamples(input_vector, &output_vector);
|
||||
}
|
||||
CopyVectorToChannel(output_vector, output_frame, i);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// static
|
||||
std::unique_ptr<Resampler<float>>
|
||||
RationalFactorResampleCalculator::ResamplerFromOptions(
|
||||
const double source_sample_rate, const double target_sample_rate,
|
||||
const RationalFactorResampleCalculatorOptions& options) {
|
||||
std::unique_ptr<Resampler<float>> resampler;
|
||||
const auto& rational_factor_options =
|
||||
options.resampler_rational_factor_options();
|
||||
audio_dsp::QResamplerParams params;
|
||||
if (rational_factor_options.has_radius() &&
|
||||
rational_factor_options.has_cutoff() &&
|
||||
rational_factor_options.has_kaiser_beta()) {
|
||||
// Convert RationalFactorResampler kernel parameters to QResampler
|
||||
// settings.
|
||||
params.filter_radius_factor =
|
||||
rational_factor_options.radius() *
|
||||
std::min(1.0, target_sample_rate / source_sample_rate);
|
||||
params.cutoff_proportion = 2 * rational_factor_options.cutoff() /
|
||||
std::min(source_sample_rate, target_sample_rate);
|
||||
params.kaiser_beta = rational_factor_options.kaiser_beta();
|
||||
}
|
||||
// Set large enough so that the resampling factor between common sample
|
||||
// rates (e.g. 8kHz, 16kHz, 22.05kHz, 32kHz, 44.1kHz, 48kHz) is exact, and
|
||||
// that any factor is represented with error less than 0.025%.
|
||||
params.max_denominator = 2000;
|
||||
|
||||
// NOTE: QResampler supports multichannel resampling, so the code might be
|
||||
// simplified using a single instance rather than one per channel.
|
||||
resampler = absl::make_unique<audio_dsp::QResampler<float>>(
|
||||
source_sample_rate, target_sample_rate, /*num_channels=*/1, params);
|
||||
if (resampler != nullptr && !resampler->Valid()) {
|
||||
resampler = std::unique_ptr<Resampler<float>>();
|
||||
}
|
||||
return resampler;
|
||||
}
|
||||
|
||||
REGISTER_CALCULATOR(RationalFactorResampleCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,109 +0,0 @@
|
|||
// Copyright 2019, 2021 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef MEDIAPIPE_CALCULATORS_AUDIO_RATIONAL_FACTOR_RESAMPLE_CALCULATOR_H_
|
||||
#define MEDIAPIPE_CALCULATORS_AUDIO_RATIONAL_FACTOR_RESAMPLE_CALCULATOR_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "audio/dsp/resampler.h"
|
||||
#include "mediapipe/calculators/audio/rational_factor_resample_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/util/time_series_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
// MediaPipe Calculator for resampling a (vector-valued)
|
||||
// input time series with a uniform sample rate. The output
|
||||
// stream's sampling rate is specified by target_sample_rate in the
|
||||
// RationalFactorResampleCalculatorOptions. The output time series may have
|
||||
// a varying number of samples per frame.
|
||||
//
|
||||
// NOTE: This calculator uses QResampler, despite the name, which supersedes
|
||||
// RationalFactorResampler.
|
||||
class RationalFactorResampleCalculator : public CalculatorBase {
|
||||
public:
|
||||
struct TestAccess;
|
||||
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).Set<Matrix>(
|
||||
// Single input stream with TimeSeriesHeader.
|
||||
);
|
||||
cc->Outputs().Index(0).Set<Matrix>(
|
||||
// Resampled stream with TimeSeriesHeader.
|
||||
);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
// Returns FAIL if the input stream header is invalid or if the
|
||||
// resampler cannot be initialized.
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
// Resamples a packet of TimeSeries data. Returns FAIL if the
|
||||
// resampler state becomes inconsistent.
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
// Flushes any remaining state. Returns FAIL if the resampler state
|
||||
// becomes inconsistent.
|
||||
absl::Status Close(CalculatorContext* cc) override;
|
||||
|
||||
protected:
|
||||
typedef audio_dsp::Resampler<float> ResamplerType;
|
||||
|
||||
// Returns a Resampler<float> implementation specified by the
|
||||
// RationalFactorResampleCalculatorOptions proto. Returns null if the options
|
||||
// specify an invalid resampler.
|
||||
static std::unique_ptr<ResamplerType> ResamplerFromOptions(
|
||||
const double source_sample_rate, const double target_sample_rate,
|
||||
const RationalFactorResampleCalculatorOptions& options);
|
||||
|
||||
// Does Timestamp bookkeeping and resampling common to Process() and
|
||||
// Close(). Returns FAIL if the resampler state becomes
|
||||
// inconsistent.
|
||||
absl::Status ProcessInternal(const Matrix& input_frame, bool should_flush,
|
||||
CalculatorContext* cc);
|
||||
|
||||
// Uses the internal resampler_ objects to actually resample each
|
||||
// row of the input TimeSeries. Returns false if the resampler
|
||||
// state becomes inconsistent.
|
||||
bool Resample(const Matrix& input_frame, Matrix* output_frame,
|
||||
bool should_flush);
|
||||
|
||||
double source_sample_rate_;
|
||||
double target_sample_rate_;
|
||||
int64 cumulative_input_samples_;
|
||||
int64 cumulative_output_samples_;
|
||||
Timestamp initial_timestamp_;
|
||||
bool check_inconsistent_timestamps_;
|
||||
int num_channels_;
|
||||
std::vector<std::unique_ptr<ResamplerType>> resampler_;
|
||||
};
|
||||
|
||||
// Test-only access to RationalFactorResampleCalculator methods.
|
||||
struct RationalFactorResampleCalculator::TestAccess {
|
||||
static std::unique_ptr<ResamplerType> ResamplerFromOptions(
|
||||
const double source_sample_rate, const double target_sample_rate,
|
||||
const RationalFactorResampleCalculatorOptions& options) {
|
||||
return RationalFactorResampleCalculator::ResamplerFromOptions(
|
||||
source_sample_rate, target_sample_rate, options);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_CALCULATORS_AUDIO_RATIONAL_FACTOR_RESAMPLE_CALCULATOR_H_
|
||||
|
|
@ -1,47 +0,0 @@
|
|||
// Copyright 2019, 2021 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
// NOTE: This calculator uses QResampler, despite the name, which supersedes
|
||||
// RationalFactorResampler.
|
||||
message RationalFactorResampleCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional RationalFactorResampleCalculatorOptions ext = 259760074;
|
||||
}
|
||||
|
||||
// target_sample_rate is the sample rate, in Hertz, of the output
|
||||
// stream. Required. Must be greater than 0.
|
||||
optional double target_sample_rate = 1;
|
||||
|
||||
// Parameters for initializing QResampler. See QResampler for more details.
|
||||
message ResamplerRationalFactorOptions {
|
||||
// Kernel radius in units of input samples.
|
||||
optional double radius = 1;
|
||||
// Anti-aliasing cutoff frequency in Hertz. A reasonable setting is
|
||||
// 0.45 * min(input_sample_rate, output_sample_rate).
|
||||
optional double cutoff = 2;
|
||||
// The Kaiser beta parameter for the kernel window.
|
||||
optional double kaiser_beta = 3 [default = 6.0];
|
||||
}
|
||||
optional ResamplerRationalFactorOptions resampler_rational_factor_options = 2;
|
||||
|
||||
// Set to false to disable checks for jitter in timestamp values. Useful with
|
||||
// live audio input.
|
||||
optional bool check_inconsistent_timestamps = 3 [default = true];
|
||||
}
|
||||
|
|
@ -1,246 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/audio/rational_factor_resample_calculator.h"
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "audio/dsp/signal_vector_util.h"
|
||||
#include "mediapipe/calculators/audio/rational_factor_resample_calculator.pb.h"
|
||||
#include "mediapipe/framework//tool/validate_type.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/util/time_series_test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
const int kInitialTimestampOffsetMilliseconds = 4;
|
||||
|
||||
class RationalFactorResampleCalculatorTest
|
||||
: public TimeSeriesCalculatorTest<RationalFactorResampleCalculatorOptions> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
calculator_name_ = "RationalFactorResampleCalculator";
|
||||
input_sample_rate_ = 4000.0;
|
||||
num_input_channels_ = 3;
|
||||
}
|
||||
|
||||
// Expects two vectors whose lengths are almost the same and whose
|
||||
// elements are equal (for indices that are present in both).
|
||||
//
|
||||
// This is useful because the resampler doesn't make precise
|
||||
// guarantees about its output size.
|
||||
void ExpectVectorMostlyFloatEq(const std::vector<float>& expected,
|
||||
const std::vector<float>& actual) {
|
||||
// Lengths should be close, but don't have to be equal.
|
||||
ASSERT_NEAR(expected.size(), actual.size(), 1);
|
||||
for (int i = 0; i < std::min(expected.size(), actual.size()); ++i) {
|
||||
EXPECT_FLOAT_EQ(expected[i], actual[i]) << " where i=" << i << ".";
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a float value with the sample, channel, and timestamp
|
||||
// separated by a few orders of magnitude, for easy parsing by
|
||||
// humans.
|
||||
double TestValue(int sample, int channel, int timestamp_in_microseconds) {
|
||||
return timestamp_in_microseconds * 100.0 + sample + channel / 10.0;
|
||||
}
|
||||
|
||||
// Caller takes ownership of the returned value.
|
||||
Matrix* NewTestFrame(int num_channels, int num_samples, int timestamp) {
|
||||
auto matrix = new Matrix(num_channels, num_samples);
|
||||
for (int c = 0; c < num_channels; ++c) {
|
||||
for (int i = 0; i < num_samples; ++i) {
|
||||
(*matrix)(c, i) = TestValue(i, c, timestamp);
|
||||
}
|
||||
}
|
||||
return matrix;
|
||||
}
|
||||
|
||||
// Initializes and runs the test graph.
|
||||
absl::Status Run(double output_sample_rate) {
|
||||
options_.set_target_sample_rate(output_sample_rate);
|
||||
InitializeGraph();
|
||||
|
||||
FillInputHeader();
|
||||
concatenated_input_samples_.resize(num_input_channels_, 0);
|
||||
num_input_samples_ = 0;
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
int packet_size = (i + 1) * 10;
|
||||
int timestamp = kInitialTimestampOffsetMilliseconds +
|
||||
num_input_samples_ / input_sample_rate_ *
|
||||
Timestamp::kTimestampUnitsPerSecond;
|
||||
Matrix* data_frame =
|
||||
NewTestFrame(num_input_channels_, packet_size, timestamp);
|
||||
|
||||
// Keep a reference copy of the input.
|
||||
//
|
||||
// conservativeResize() is needed here to preserve the existing
|
||||
// data. Eigen's resize() resizes without preserving data.
|
||||
concatenated_input_samples_.conservativeResize(
|
||||
num_input_channels_, num_input_samples_ + packet_size);
|
||||
concatenated_input_samples_.rightCols(packet_size) = *data_frame;
|
||||
num_input_samples_ += packet_size;
|
||||
|
||||
AppendInputPacket(data_frame, timestamp);
|
||||
}
|
||||
|
||||
return RunGraph();
|
||||
}
|
||||
|
||||
void CheckOutputLength(double output_sample_rate) {
|
||||
double factor = output_sample_rate / input_sample_rate_;
|
||||
|
||||
int num_output_samples = 0;
|
||||
for (const Packet& packet : output().packets) {
|
||||
num_output_samples += packet.Get<Matrix>().cols();
|
||||
}
|
||||
|
||||
// The exact number of expected samples may vary based on the implementation
|
||||
// of the resampler since the exact value is not an integer.
|
||||
const double expected_num_output_samples = num_input_samples_ * factor;
|
||||
EXPECT_LE(ceil(expected_num_output_samples), num_output_samples);
|
||||
EXPECT_GE(ceil(expected_num_output_samples) + 11, num_output_samples);
|
||||
}
|
||||
|
||||
// Checks that output timestamps are consistent with the
|
||||
// output_sample_rate and output packet sizes.
|
||||
void CheckOutputPacketTimestamps(double output_sample_rate) {
|
||||
int num_output_samples = 0;
|
||||
for (const Packet& packet : output().packets) {
|
||||
const int expected_timestamp = kInitialTimestampOffsetMilliseconds +
|
||||
num_output_samples / output_sample_rate *
|
||||
Timestamp::kTimestampUnitsPerSecond;
|
||||
EXPECT_NEAR(expected_timestamp, packet.Timestamp().Value(), 1);
|
||||
num_output_samples += packet.Get<Matrix>().cols();
|
||||
}
|
||||
}
|
||||
|
||||
// Checks that output values from the calculator (which resamples
|
||||
// packet-by-packet) are consistent with resampling the entire
|
||||
// signal at once.
|
||||
void CheckOutputValues(double output_sample_rate) {
|
||||
for (int i = 0; i < num_input_channels_; ++i) {
|
||||
auto verification_resampler =
|
||||
RationalFactorResampleCalculator::TestAccess::ResamplerFromOptions(
|
||||
input_sample_rate_, output_sample_rate, options_);
|
||||
|
||||
std::vector<float> input_data;
|
||||
for (int j = 0; j < num_input_samples_; ++j) {
|
||||
input_data.push_back(concatenated_input_samples_(i, j));
|
||||
}
|
||||
std::vector<float> expected_resampled_data;
|
||||
std::vector<float> temp;
|
||||
verification_resampler->ProcessSamples(input_data, &temp);
|
||||
audio_dsp::VectorAppend(&expected_resampled_data, temp);
|
||||
verification_resampler->Flush(&temp);
|
||||
audio_dsp::VectorAppend(&expected_resampled_data, temp);
|
||||
std::vector<float> actual_resampled_data;
|
||||
for (const Packet& packet : output().packets) {
|
||||
Matrix output_frame_row = packet.Get<Matrix>().row(i);
|
||||
actual_resampled_data.insert(
|
||||
actual_resampled_data.end(), &output_frame_row(0),
|
||||
&output_frame_row(0) + output_frame_row.cols());
|
||||
}
|
||||
|
||||
ExpectVectorMostlyFloatEq(expected_resampled_data, actual_resampled_data);
|
||||
}
|
||||
}
|
||||
|
||||
void CheckOutputHeaders(double output_sample_rate) {
|
||||
const TimeSeriesHeader& output_header =
|
||||
output().header.Get<TimeSeriesHeader>();
|
||||
TimeSeriesHeader expected_header;
|
||||
expected_header.set_sample_rate(output_sample_rate);
|
||||
expected_header.set_num_channels(num_input_channels_);
|
||||
EXPECT_THAT(output_header, mediapipe::EqualsProto(expected_header));
|
||||
}
|
||||
|
||||
void CheckOutput(double output_sample_rate) {
|
||||
CheckOutputLength(output_sample_rate);
|
||||
CheckOutputPacketTimestamps(output_sample_rate);
|
||||
CheckOutputValues(output_sample_rate);
|
||||
CheckOutputHeaders(output_sample_rate);
|
||||
}
|
||||
|
||||
void CheckOutputUnchanged() {
|
||||
for (int i = 0; i < num_input_channels_; ++i) {
|
||||
std::vector<float> expected_resampled_data;
|
||||
for (int j = 0; j < num_input_samples_; ++j) {
|
||||
expected_resampled_data.push_back(concatenated_input_samples_(i, j));
|
||||
}
|
||||
std::vector<float> actual_resampled_data;
|
||||
for (const Packet& packet : output().packets) {
|
||||
Matrix output_frame_row = packet.Get<Matrix>().row(i);
|
||||
actual_resampled_data.insert(
|
||||
actual_resampled_data.end(), &output_frame_row(0),
|
||||
&output_frame_row(0) + output_frame_row.cols());
|
||||
}
|
||||
ExpectVectorMostlyFloatEq(expected_resampled_data, actual_resampled_data);
|
||||
}
|
||||
}
|
||||
|
||||
int num_input_samples_;
|
||||
Matrix concatenated_input_samples_;
|
||||
};
|
||||
|
||||
TEST_F(RationalFactorResampleCalculatorTest, Upsample) {
|
||||
const double kUpsampleRate = input_sample_rate_ * 1.9;
|
||||
MP_ASSERT_OK(Run(kUpsampleRate));
|
||||
CheckOutput(kUpsampleRate);
|
||||
}
|
||||
|
||||
TEST_F(RationalFactorResampleCalculatorTest, Downsample) {
|
||||
const double kDownsampleRate = input_sample_rate_ / 1.9;
|
||||
MP_ASSERT_OK(Run(kDownsampleRate));
|
||||
CheckOutput(kDownsampleRate);
|
||||
}
|
||||
|
||||
TEST_F(RationalFactorResampleCalculatorTest, UsesRationalFactorResampler) {
|
||||
const double kUpsampleRate = input_sample_rate_ * 2;
|
||||
MP_ASSERT_OK(Run(kUpsampleRate));
|
||||
CheckOutput(kUpsampleRate);
|
||||
}
|
||||
|
||||
TEST_F(RationalFactorResampleCalculatorTest, PassthroughIfSampleRateUnchanged) {
|
||||
const double kUpsampleRate = input_sample_rate_;
|
||||
MP_ASSERT_OK(Run(kUpsampleRate));
|
||||
CheckOutputUnchanged();
|
||||
}
|
||||
|
||||
TEST_F(RationalFactorResampleCalculatorTest, FailsOnBadTargetRate) {
|
||||
ASSERT_FALSE(Run(-999.9).ok()); // Invalid output sample rate.
|
||||
}
|
||||
|
||||
TEST_F(RationalFactorResampleCalculatorTest, DoesNotDieOnEmptyInput) {
|
||||
options_.set_target_sample_rate(input_sample_rate_);
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
MP_ASSERT_OK(RunGraph());
|
||||
EXPECT_TRUE(output().packets.empty());
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,452 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Defines SpectrogramCalculator.
|
||||
#include <math.h>
|
||||
|
||||
#include <complex>
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "audio/dsp/spectrogram/spectrogram.h"
|
||||
#include "audio/dsp/window_functions.h"
|
||||
#include "mediapipe/calculators/audio/spectrogram_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/core_proto_inc.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/source_location.h"
|
||||
#include "mediapipe/framework/port/status_builder.h"
|
||||
#include "mediapipe/util/time_series_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// MediaPipe Calculator for computing the "spectrogram" (short-time Fourier
|
||||
// transform squared-magnitude, by default) of a multichannel input
|
||||
// time series, including optionally overlapping frames. Options are
|
||||
// specified in SpectrogramCalculatorOptions proto (where names are chosen
|
||||
// to mirror TimeSeriesFramerCalculator):
|
||||
//
|
||||
// Result is a MatrixData record (for single channel input and when the
|
||||
// allow_multichannel_input flag is false), or a vector of MatrixData records,
|
||||
// one for each channel (when the allow_multichannel_input flag is set). The
|
||||
// rows of each spectrogram matrix correspond to the n_fft/2+1 unique complex
|
||||
// values, or squared/linear/dB magnitudes, depending on the output_type option.
|
||||
// Each input packet will result in zero or one output packets, each containing
|
||||
// one Matrix for each channel of the input, where each Matrix has one or more
|
||||
// columns of spectral values, one for each complete frame of input samples. If
|
||||
// the input packet contains too few samples to trigger a new output frame, no
|
||||
// output packet is generated (since zero-length packets are not legal since
|
||||
// they would result in timestamps that were equal, not strictly increasing).
|
||||
//
|
||||
// Output packet Timestamps are set to the beginning of each frame. This is to
|
||||
// allow calculators downstream from SpectrogramCalculator to have aligned
|
||||
// Timestamps regardless of a packet's signal length.
|
||||
//
|
||||
// Both frame_duration_seconds and frame_overlap_seconds will be
|
||||
// rounded to the nearest integer number of samples. Conseqently, all output
|
||||
// frames will be based on the same number of input samples, and each
|
||||
// analysis frame will advance from its predecessor by the same time step.
|
||||
class SpectrogramCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).Set<Matrix>(
|
||||
// Input stream with TimeSeriesHeader.
|
||||
);
|
||||
|
||||
SpectrogramCalculatorOptions spectrogram_options =
|
||||
cc->Options<SpectrogramCalculatorOptions>();
|
||||
if (!spectrogram_options.allow_multichannel_input()) {
|
||||
if (spectrogram_options.output_type() ==
|
||||
SpectrogramCalculatorOptions::COMPLEX) {
|
||||
cc->Outputs().Index(0).Set<Eigen::MatrixXcf>(
|
||||
// Complex spectrogram frames with TimeSeriesHeader.
|
||||
);
|
||||
} else {
|
||||
cc->Outputs().Index(0).Set<Matrix>(
|
||||
// Spectrogram frames with TimeSeriesHeader.
|
||||
);
|
||||
}
|
||||
} else {
|
||||
if (spectrogram_options.output_type() ==
|
||||
SpectrogramCalculatorOptions::COMPLEX) {
|
||||
cc->Outputs().Index(0).Set<std::vector<Eigen::MatrixXcf>>(
|
||||
// Complex spectrogram frames with MultiStreamTimeSeriesHeader.
|
||||
);
|
||||
} else {
|
||||
cc->Outputs().Index(0).Set<std::vector<Matrix>>(
|
||||
// Spectrogram frames with MultiStreamTimeSeriesHeader.
|
||||
);
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Returns FAIL if the input stream header is invalid.
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
|
||||
// Outputs at most one packet consisting of a single Matrix with one or
|
||||
// more columns containing the spectral values from as many input frames
|
||||
// as are completed by the input samples. Always returns OK.
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
// Performs zero-padding and processing of any remaining samples
|
||||
// if pad_final_packet is set.
|
||||
// Returns OK.
|
||||
absl::Status Close(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
Timestamp CurrentOutputTimestamp(CalculatorContext* cc) {
|
||||
if (use_local_timestamp_) {
|
||||
const Timestamp now = cc->InputTimestamp();
|
||||
if (now == Timestamp::Done()) {
|
||||
// During Close the timestamp is not available, send an estimate.
|
||||
return last_local_output_timestamp_ +
|
||||
round(last_completed_frames_ * frame_step_samples() *
|
||||
Timestamp::kTimestampUnitsPerSecond / input_sample_rate_);
|
||||
}
|
||||
last_local_output_timestamp_ = now;
|
||||
return now;
|
||||
}
|
||||
return CumulativeOutputTimestamp();
|
||||
}
|
||||
|
||||
Timestamp CumulativeOutputTimestamp() {
|
||||
// Cumulative output timestamp is the *center* of the next frame to be
|
||||
// emitted, hence delayed by half a window duration compared to relevant
|
||||
// input timestamp.
|
||||
return initial_input_timestamp_ +
|
||||
round(cumulative_completed_frames_ * frame_step_samples() *
|
||||
Timestamp::kTimestampUnitsPerSecond / input_sample_rate_);
|
||||
}
|
||||
|
||||
int frame_step_samples() const {
|
||||
return frame_duration_samples_ - frame_overlap_samples_;
|
||||
}
|
||||
|
||||
// Take the next set of input samples, already translated into a
|
||||
// vector<float> and pass them to the spectrogram object.
|
||||
// Convert the output of the spectrogram object into a Matrix (or an
|
||||
// Eigen::MatrixXcf if complex-valued output is requested) and pass to
|
||||
// MediaPipe output.
|
||||
absl::Status ProcessVector(const Matrix& input_stream, CalculatorContext* cc);
|
||||
|
||||
// Templated function to process either real- or complex-output spectrogram.
|
||||
template <class OutputMatrixType>
|
||||
absl::Status ProcessVectorToOutput(
|
||||
const Matrix& input_stream,
|
||||
const OutputMatrixType postprocess_output_fn(const OutputMatrixType&),
|
||||
CalculatorContext* cc);
|
||||
|
||||
// Use the MediaPipe timestamp instead of the estimated one. Useful when the
|
||||
// data is intermittent.
|
||||
bool use_local_timestamp_;
|
||||
Timestamp last_local_output_timestamp_;
|
||||
|
||||
double input_sample_rate_;
|
||||
bool pad_final_packet_;
|
||||
int frame_duration_samples_;
|
||||
int frame_overlap_samples_;
|
||||
// How many samples we've been passed, used for checking input time stamps.
|
||||
int64 cumulative_input_samples_;
|
||||
// How many frames we've emitted, used for calculating output time stamps.
|
||||
int64 cumulative_completed_frames_;
|
||||
// How many frames were emitted last, used for estimating the timestamp on
|
||||
// Close when use_local_timestamp_ is true;
|
||||
int64 last_completed_frames_;
|
||||
Timestamp initial_input_timestamp_;
|
||||
int num_input_channels_;
|
||||
// How many frequency bins we emit (=N_FFT/2 + 1).
|
||||
int num_output_channels_;
|
||||
// Which output type?
|
||||
int output_type_;
|
||||
// Output type: mono or multichannel.
|
||||
bool allow_multichannel_input_;
|
||||
// Vector of Spectrogram objects, one for each channel.
|
||||
std::vector<std::unique_ptr<audio_dsp::Spectrogram>> spectrogram_generators_;
|
||||
// Fixed scale factor applied to output values (regardless of type).
|
||||
double output_scale_;
|
||||
|
||||
static const float kLnPowerToDb;
|
||||
};
|
||||
REGISTER_CALCULATOR(SpectrogramCalculator);
|
||||
|
||||
// Factor to convert ln(magnitude_squared) to deciBels = 10.0/ln(10.0).
|
||||
const float SpectrogramCalculator::kLnPowerToDb = 4.342944819032518;
|
||||
|
||||
absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) {
|
||||
SpectrogramCalculatorOptions spectrogram_options =
|
||||
cc->Options<SpectrogramCalculatorOptions>();
|
||||
|
||||
use_local_timestamp_ = spectrogram_options.use_local_timestamp();
|
||||
|
||||
if (spectrogram_options.frame_duration_seconds() <= 0.0) {
|
||||
// TODO: return an error.
|
||||
}
|
||||
if (spectrogram_options.frame_overlap_seconds() >=
|
||||
spectrogram_options.frame_duration_seconds()) {
|
||||
// TODO: return an error.
|
||||
}
|
||||
if (spectrogram_options.frame_overlap_seconds() < 0.0) {
|
||||
// TODO: return an error.
|
||||
}
|
||||
|
||||
TimeSeriesHeader input_header;
|
||||
MP_RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
|
||||
cc->Inputs().Index(0).Header(), &input_header));
|
||||
|
||||
input_sample_rate_ = input_header.sample_rate();
|
||||
num_input_channels_ = input_header.num_channels();
|
||||
|
||||
if (!spectrogram_options.allow_multichannel_input() &&
|
||||
num_input_channels_ != 1) {
|
||||
// TODO: return an error.
|
||||
}
|
||||
|
||||
frame_duration_samples_ =
|
||||
round(spectrogram_options.frame_duration_seconds() * input_sample_rate_);
|
||||
frame_overlap_samples_ =
|
||||
round(spectrogram_options.frame_overlap_seconds() * input_sample_rate_);
|
||||
|
||||
pad_final_packet_ = spectrogram_options.pad_final_packet();
|
||||
output_type_ = spectrogram_options.output_type();
|
||||
allow_multichannel_input_ = spectrogram_options.allow_multichannel_input();
|
||||
|
||||
output_scale_ = spectrogram_options.output_scale();
|
||||
|
||||
std::vector<double> window;
|
||||
switch (spectrogram_options.window_type()) {
|
||||
case SpectrogramCalculatorOptions::COSINE:
|
||||
audio_dsp::CosineWindow().GetPeriodicSamples(frame_duration_samples_,
|
||||
&window);
|
||||
break;
|
||||
case SpectrogramCalculatorOptions::HANN:
|
||||
audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_,
|
||||
&window);
|
||||
break;
|
||||
case SpectrogramCalculatorOptions::HAMMING:
|
||||
audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples_,
|
||||
&window);
|
||||
break;
|
||||
}
|
||||
|
||||
// Propagate settings down to the actual Spectrogram object.
|
||||
spectrogram_generators_.clear();
|
||||
for (int i = 0; i < num_input_channels_; i++) {
|
||||
spectrogram_generators_.push_back(
|
||||
std::unique_ptr<audio_dsp::Spectrogram>(new audio_dsp::Spectrogram()));
|
||||
spectrogram_generators_[i]->Initialize(window, frame_step_samples());
|
||||
}
|
||||
|
||||
num_output_channels_ =
|
||||
spectrogram_generators_[0]->output_frequency_channels();
|
||||
std::unique_ptr<TimeSeriesHeader> output_header(
|
||||
new TimeSeriesHeader(input_header));
|
||||
// Store the actual sample rate of the input audio in the TimeSeriesHeader
|
||||
// so that subsequent calculators can figure out the frequency scale of
|
||||
// our output.
|
||||
output_header->set_audio_sample_rate(input_sample_rate_);
|
||||
// Setup rest of output header.
|
||||
output_header->set_num_channels(num_output_channels_);
|
||||
output_header->set_sample_rate(input_sample_rate_ / frame_step_samples());
|
||||
// Although we usually generate one output packet for each input
|
||||
// packet, this might not be true for input packets whose size is smaller
|
||||
// than the analysis window length. So we clear output_header.packet_rate
|
||||
// because we can't guarantee a constant packet rate. Similarly, the number
|
||||
// of output frames per packet depends on the input packet, so we also clear
|
||||
// output_header.num_samples.
|
||||
output_header->clear_packet_rate();
|
||||
output_header->clear_num_samples();
|
||||
if (!spectrogram_options.allow_multichannel_input()) {
|
||||
cc->Outputs().Index(0).SetHeader(Adopt(output_header.release()));
|
||||
} else {
|
||||
std::unique_ptr<MultiStreamTimeSeriesHeader> multichannel_output_header(
|
||||
new MultiStreamTimeSeriesHeader());
|
||||
*multichannel_output_header->mutable_time_series_header() = *output_header;
|
||||
multichannel_output_header->set_num_streams(num_input_channels_);
|
||||
cc->Outputs().Index(0).SetHeader(
|
||||
Adopt(multichannel_output_header.release()));
|
||||
}
|
||||
cumulative_completed_frames_ = 0;
|
||||
last_completed_frames_ = 0;
|
||||
initial_input_timestamp_ = Timestamp::Unstarted();
|
||||
if (use_local_timestamp_) {
|
||||
// Inform the framework that the calculator will output packets at the same
|
||||
// timestamps as input packets to enable packet queueing optimizations. The
|
||||
// final packet (emitted from Close()) does not follow this rule but it's
|
||||
// sufficient that its timestamp is strictly greater than the timestamp of
|
||||
// the previous packet.
|
||||
cc->SetOffset(0);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status SpectrogramCalculator::Process(CalculatorContext* cc) {
|
||||
if (initial_input_timestamp_ == Timestamp::Unstarted()) {
|
||||
initial_input_timestamp_ = cc->InputTimestamp();
|
||||
}
|
||||
|
||||
const Matrix& input_stream = cc->Inputs().Index(0).Get<Matrix>();
|
||||
if (input_stream.rows() != num_input_channels_) {
|
||||
// TODO: return an error.
|
||||
}
|
||||
|
||||
cumulative_input_samples_ += input_stream.cols();
|
||||
|
||||
return ProcessVector(input_stream, cc);
|
||||
}
|
||||
|
||||
template <class OutputMatrixType>
|
||||
absl::Status SpectrogramCalculator::ProcessVectorToOutput(
|
||||
const Matrix& input_stream,
|
||||
const OutputMatrixType postprocess_output_fn(const OutputMatrixType&),
|
||||
CalculatorContext* cc) {
|
||||
std::unique_ptr<std::vector<OutputMatrixType>> spectrogram_matrices(
|
||||
new std::vector<OutputMatrixType>());
|
||||
std::vector<std::vector<typename OutputMatrixType::Scalar>> output_vectors;
|
||||
|
||||
// Compute a spectrogram for each channel.
|
||||
int num_output_time_frames;
|
||||
for (int channel = 0; channel < input_stream.rows(); ++channel) {
|
||||
output_vectors.clear();
|
||||
|
||||
// Copy one row (channel) of the input matrix into the std::vector.
|
||||
std::vector<float> input_vector(input_stream.cols());
|
||||
Eigen::Map<Matrix>(&input_vector[0], 1, input_vector.size()) =
|
||||
input_stream.row(channel);
|
||||
|
||||
if (!spectrogram_generators_[channel]->ComputeSpectrogram(
|
||||
input_vector, &output_vectors)) {
|
||||
return absl::Status(absl::StatusCode::kInternal,
|
||||
"Spectrogram returned failure");
|
||||
}
|
||||
if (channel == 0) {
|
||||
// Record the number of time frames we expect from each channel.
|
||||
num_output_time_frames = output_vectors.size();
|
||||
} else {
|
||||
RET_CHECK_EQ(output_vectors.size(), num_output_time_frames)
|
||||
<< "Inconsistent spectrogram time frames for channel " << channel;
|
||||
}
|
||||
// Skip remaining processing if there are too few input samples to trigger
|
||||
// any output frames.
|
||||
if (!output_vectors.empty()) {
|
||||
// Translate the returned values into a matrix of output frames.
|
||||
OutputMatrixType output_frames(num_output_channels_,
|
||||
output_vectors.size());
|
||||
for (int frame = 0; frame < output_vectors.size(); ++frame) {
|
||||
Eigen::Map<const OutputMatrixType> frame_map(
|
||||
&output_vectors[frame][0], output_vectors[frame].size(), 1);
|
||||
// The underlying dsp object returns squared magnitudes; here
|
||||
// we optionally translate to linear magnitude or dB.
|
||||
output_frames.col(frame) =
|
||||
output_scale_ * postprocess_output_fn(frame_map);
|
||||
}
|
||||
spectrogram_matrices->push_back(output_frames);
|
||||
}
|
||||
}
|
||||
// If the input is very short, there may not be enough accumulated,
|
||||
// unprocessed samples to cause any new frames to be generated by
|
||||
// the spectrogram object. If so, we don't want to emit
|
||||
// a packet at all.
|
||||
if (!spectrogram_matrices->empty()) {
|
||||
RET_CHECK_EQ(spectrogram_matrices->size(), input_stream.rows())
|
||||
<< "Inconsistent number of spectrogram channels.";
|
||||
if (allow_multichannel_input_) {
|
||||
cc->Outputs().Index(0).Add(spectrogram_matrices.release(),
|
||||
CurrentOutputTimestamp(cc));
|
||||
} else {
|
||||
cc->Outputs().Index(0).Add(
|
||||
new OutputMatrixType(spectrogram_matrices->at(0)),
|
||||
CurrentOutputTimestamp(cc));
|
||||
}
|
||||
cumulative_completed_frames_ += output_vectors.size();
|
||||
last_completed_frames_ = output_vectors.size();
|
||||
if (!use_local_timestamp_) {
|
||||
// In non-local timestamp mode the timestamp of the next packet will be
|
||||
// equal to CumulativeOutputTimestamp(). Inform the framework about this
|
||||
// fact to enable packet queueing optimizations.
|
||||
cc->Outputs().Index(0).SetNextTimestampBound(CumulativeOutputTimestamp());
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status SpectrogramCalculator::ProcessVector(const Matrix& input_stream,
|
||||
CalculatorContext* cc) {
|
||||
switch (output_type_) {
|
||||
// These blocks deliberately ignore clang-format to preserve the
|
||||
// "silhouette" of the different cases.
|
||||
// clang-format off
|
||||
case SpectrogramCalculatorOptions::COMPLEX: {
|
||||
return ProcessVectorToOutput(
|
||||
input_stream,
|
||||
+[](const Eigen::MatrixXcf& col) -> const Eigen::MatrixXcf {
|
||||
return col;
|
||||
}, cc);
|
||||
}
|
||||
case SpectrogramCalculatorOptions::SQUARED_MAGNITUDE: {
|
||||
return ProcessVectorToOutput(
|
||||
input_stream,
|
||||
+[](const Matrix& col) -> const Matrix {
|
||||
return col;
|
||||
}, cc);
|
||||
}
|
||||
case SpectrogramCalculatorOptions::LINEAR_MAGNITUDE: {
|
||||
return ProcessVectorToOutput(
|
||||
input_stream,
|
||||
+[](const Matrix& col) -> const Matrix {
|
||||
return col.array().sqrt().matrix();
|
||||
}, cc);
|
||||
}
|
||||
case SpectrogramCalculatorOptions::DECIBELS: {
|
||||
return ProcessVectorToOutput(
|
||||
input_stream,
|
||||
+[](const Matrix& col) -> const Matrix {
|
||||
return kLnPowerToDb * col.array().log().matrix();
|
||||
}, cc);
|
||||
}
|
||||
// clang-format on
|
||||
default: {
|
||||
return absl::Status(absl::StatusCode::kInvalidArgument,
|
||||
"Unrecognized spectrogram output type.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status SpectrogramCalculator::Close(CalculatorContext* cc) {
|
||||
if (cumulative_input_samples_ > 0 && pad_final_packet_) {
|
||||
// We can flush any remaining samples by sending frame_step_samples - 1
|
||||
// zeros to the Process method, and letting it do its thing,
|
||||
// UNLESS we have fewer than one window's worth of samples, in which case
|
||||
// we pad to exactly one frame_duration_samples.
|
||||
// Release the memory for the Spectrogram objects.
|
||||
int required_padding_samples = frame_step_samples() - 1;
|
||||
if (cumulative_input_samples_ < frame_duration_samples_) {
|
||||
required_padding_samples =
|
||||
frame_duration_samples_ - cumulative_input_samples_;
|
||||
}
|
||||
return ProcessVector(
|
||||
Matrix::Zero(num_input_channels_, required_padding_samples), cc);
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,76 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
message SpectrogramCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional SpectrogramCalculatorOptions ext = 76186688;
|
||||
}
|
||||
|
||||
// Options mirror those of TimeSeriesFramerCalculator.
|
||||
|
||||
// Analysis window duration in seconds. Required. Must be greater than 0.
|
||||
// (Note: the spectrogram DFT length will be the smallest power-of-2
|
||||
// sample count that can hold this duration.)
|
||||
optional double frame_duration_seconds = 1;
|
||||
|
||||
// Duration of overlap between adjacent windows.
|
||||
// Hence, frame_rate = 1/(frame_duration_seconds - frame_overlap_seconds).
|
||||
// Required that 0 <= frame_overlap_seconds < frame_duration_seconds.
|
||||
optional double frame_overlap_seconds = 2 [default = 0.0];
|
||||
|
||||
// Whether to pad the final packet with zeros. If true, guarantees that
|
||||
// all input samples will output. If set to false, any partial packet
|
||||
// at the end of the stream will be dropped.
|
||||
optional bool pad_final_packet = 3 [default = true];
|
||||
|
||||
// Output value type can be squared-magnitude, linear-magnitude,
|
||||
// deciBels (dB, = 20*log10(linear_magnitude)), or std::complex.
|
||||
enum OutputType {
|
||||
SQUARED_MAGNITUDE = 0;
|
||||
LINEAR_MAGNITUDE = 1;
|
||||
DECIBELS = 2;
|
||||
COMPLEX = 3;
|
||||
}
|
||||
optional OutputType output_type = 4 [default = SQUARED_MAGNITUDE];
|
||||
|
||||
// If set to true then the output will be a vector of spectrograms, one for
|
||||
// each channel and the stream will have a MultiStreamTimeSeriesHeader.
|
||||
optional bool allow_multichannel_input = 5 [default = false];
|
||||
|
||||
// Which window to use when computing the FFT.
|
||||
enum WindowType {
|
||||
HANN = 0;
|
||||
HAMMING = 1;
|
||||
COSINE = 2;
|
||||
}
|
||||
optional WindowType window_type = 6 [default = HANN];
|
||||
|
||||
// Support a fixed multiplicative scaling of the output. This is applied
|
||||
// uniformly regardless of output type (i.e., even dBs are multiplied, not
|
||||
// offset).
|
||||
optional double output_scale = 7 [default = 1.0];
|
||||
|
||||
// If use_local_timestamp is true, the output packet's timestamp is based on
|
||||
// the last sample of the packet and it's inferred from the latest input
|
||||
// packet's timestamp. If false, the output packet's timestamp is based on
|
||||
// the cumulative timestamping, which is inferred from the intial input
|
||||
// timestamp and the cumulative number of samples.
|
||||
optional bool use_local_timestamp = 8 [default = false];
|
||||
}
|
||||
|
|
@ -1,895 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "audio/dsp/number_util.h"
|
||||
#include "mediapipe/calculators/audio/spectrogram_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/benchmark.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/util/time_series_test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
const int kInitialTimestampOffsetMicroseconds = 4;
|
||||
|
||||
class SpectrogramCalculatorTest
|
||||
: public TimeSeriesCalculatorTest<SpectrogramCalculatorOptions> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
calculator_name_ = "SpectrogramCalculator";
|
||||
input_sample_rate_ = 4000.0;
|
||||
num_input_channels_ = 1;
|
||||
}
|
||||
|
||||
// Initializes and runs the test graph.
|
||||
absl::Status Run() {
|
||||
// Now that options are set, we can set up some internal constants.
|
||||
frame_duration_samples_ =
|
||||
round(options_.frame_duration_seconds() * input_sample_rate_);
|
||||
frame_step_samples_ =
|
||||
frame_duration_samples_ -
|
||||
round(options_.frame_overlap_seconds() * input_sample_rate_);
|
||||
// The magnitude of the 0th FFT bin (DC) should be sum(input.*window);
|
||||
// for an input identically 1.0, this is just sum(window). The average
|
||||
// value of our Hann window is 0.5, hence this is the expected squared-
|
||||
// magnitude output value in the DC bin for constant input of 1.0.
|
||||
expected_dc_squared_magnitude_ =
|
||||
pow((static_cast<float>(frame_duration_samples_) * 0.5), 2.0);
|
||||
|
||||
return RunGraph();
|
||||
}
|
||||
|
||||
// Creates test multichannel input with specified packet sizes and containing
|
||||
// a constant-frequency sinusoid that maintains phase between adjacent
|
||||
// packets.
|
||||
void SetupCosineInputPackets(const std::vector<int>& packet_sizes_samples,
|
||||
float cosine_frequency_hz) {
|
||||
int total_num_input_samples = 0;
|
||||
for (int packet_size_samples : packet_sizes_samples) {
|
||||
double packet_start_time_seconds =
|
||||
kInitialTimestampOffsetMicroseconds * 1e-6 +
|
||||
total_num_input_samples / input_sample_rate_;
|
||||
double packet_end_time_seconds =
|
||||
packet_start_time_seconds + packet_size_samples / input_sample_rate_;
|
||||
double angular_freq = 2 * M_PI * cosine_frequency_hz;
|
||||
Matrix* packet_data =
|
||||
new Matrix(num_input_channels_, packet_size_samples);
|
||||
// Use Eigen's vectorized cos() function to fill the vector with a
|
||||
// sinusoid of appropriate frequency & phase.
|
||||
for (int i = 0; i < num_input_channels_; i++) {
|
||||
packet_data->row(i) =
|
||||
Eigen::ArrayXf::LinSpaced(packet_size_samples,
|
||||
packet_start_time_seconds * angular_freq,
|
||||
packet_end_time_seconds * angular_freq)
|
||||
.cos()
|
||||
.transpose();
|
||||
}
|
||||
int64 input_timestamp = round(packet_start_time_seconds *
|
||||
Timestamp::kTimestampUnitsPerSecond);
|
||||
AppendInputPacket(packet_data, input_timestamp);
|
||||
total_num_input_samples += packet_size_samples;
|
||||
}
|
||||
}
|
||||
|
||||
// Setup a sequence of input packets of specified sizes, each filled
|
||||
// with samples of 1.0.
|
||||
void SetupConstantInputPackets(const std::vector<int>& packet_sizes_samples) {
|
||||
// A 0 Hz cosine is identically 1.0 for all samples.
|
||||
SetupCosineInputPackets(packet_sizes_samples, 0.0);
|
||||
}
|
||||
|
||||
// Setup a sequence of input packets of specified sizes, each containing a
|
||||
// single sample of 1.0 at a specified offset.
|
||||
void SetupImpulseInputPackets(
|
||||
const std::vector<int>& packet_sizes_samples,
|
||||
const std::vector<int>& impulse_offsets_samples) {
|
||||
int total_num_input_samples = 0;
|
||||
for (int i = 0; i < packet_sizes_samples.size(); ++i) {
|
||||
double packet_start_time_seconds =
|
||||
kInitialTimestampOffsetMicroseconds * 1e-6 +
|
||||
total_num_input_samples / input_sample_rate_;
|
||||
int64 input_timestamp = round(packet_start_time_seconds *
|
||||
Timestamp::kTimestampUnitsPerSecond);
|
||||
std::unique_ptr<Matrix> impulse(
|
||||
new Matrix(Matrix::Zero(1, packet_sizes_samples[i])));
|
||||
(*impulse)(0, impulse_offsets_samples[i]) = 1.0;
|
||||
AppendInputPacket(impulse.release(), input_timestamp);
|
||||
total_num_input_samples += packet_sizes_samples[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Creates test multichannel input with specified packet sizes and containing
|
||||
// constant input packets for the even channels and constant-frequency
|
||||
// sinusoid that maintains phase between adjacent packets for the odd
|
||||
// channels.
|
||||
void SetupMultichannelInputPackets(
|
||||
const std::vector<int>& packet_sizes_samples, float cosine_frequency_hz) {
|
||||
int total_num_input_samples = 0;
|
||||
for (int packet_size_samples : packet_sizes_samples) {
|
||||
double packet_start_time_seconds =
|
||||
kInitialTimestampOffsetMicroseconds * 1e-6 +
|
||||
total_num_input_samples / input_sample_rate_;
|
||||
double packet_end_time_seconds =
|
||||
packet_start_time_seconds + packet_size_samples / input_sample_rate_;
|
||||
double angular_freq;
|
||||
Matrix* packet_data =
|
||||
new Matrix(num_input_channels_, packet_size_samples);
|
||||
// Use Eigen's vectorized cos() function to fill the vector with a
|
||||
// sinusoid of appropriate frequency & phase.
|
||||
for (int i = 0; i < num_input_channels_; i++) {
|
||||
if (i % 2 == 0) {
|
||||
angular_freq = 0;
|
||||
} else {
|
||||
angular_freq = 2 * M_PI * cosine_frequency_hz;
|
||||
}
|
||||
packet_data->row(i) =
|
||||
Eigen::ArrayXf::LinSpaced(packet_size_samples,
|
||||
packet_start_time_seconds * angular_freq,
|
||||
packet_end_time_seconds * angular_freq)
|
||||
.cos()
|
||||
.transpose();
|
||||
}
|
||||
int64 input_timestamp = round(packet_start_time_seconds *
|
||||
Timestamp::kTimestampUnitsPerSecond);
|
||||
AppendInputPacket(packet_data, input_timestamp);
|
||||
total_num_input_samples += packet_size_samples;
|
||||
}
|
||||
}
|
||||
|
||||
// Return vector of the numbers of frames in each output packet.
|
||||
std::vector<int> OutputFramesPerPacket() {
|
||||
std::vector<int> frame_counts;
|
||||
for (const Packet& packet : output().packets) {
|
||||
const Matrix& matrix = packet.Get<Matrix>();
|
||||
frame_counts.push_back(matrix.cols());
|
||||
}
|
||||
return frame_counts;
|
||||
}
|
||||
|
||||
// Checks output headers and Timestamps.
|
||||
void CheckOutputHeadersAndTimestamps() {
|
||||
const int fft_size = audio_dsp::NextPowerOfTwo(frame_duration_samples_);
|
||||
|
||||
TimeSeriesHeader expected_header = input().header.Get<TimeSeriesHeader>();
|
||||
expected_header.set_num_channels(fft_size / 2 + 1);
|
||||
// The output header sample rate should depend on the output frame step.
|
||||
expected_header.set_sample_rate(input_sample_rate_ / frame_step_samples_);
|
||||
// SpectrogramCalculator stores the sample rate of the input in
|
||||
// the TimeSeriesHeader.
|
||||
expected_header.set_audio_sample_rate(input_sample_rate_);
|
||||
// We expect the output header to have num_samples and packet_rate unset.
|
||||
expected_header.clear_num_samples();
|
||||
expected_header.clear_packet_rate();
|
||||
if (!options_.allow_multichannel_input()) {
|
||||
ExpectOutputHeaderEquals(expected_header);
|
||||
} else {
|
||||
EXPECT_THAT(output()
|
||||
.header.template Get<MultiStreamTimeSeriesHeader>()
|
||||
.time_series_header(),
|
||||
mediapipe::EqualsProto(expected_header));
|
||||
EXPECT_THAT(output()
|
||||
.header.template Get<MultiStreamTimeSeriesHeader>()
|
||||
.num_streams(),
|
||||
num_input_channels_);
|
||||
}
|
||||
|
||||
int cumulative_output_frames = 0;
|
||||
// The timestamps coming out of the spectrogram correspond to the
|
||||
// middle of the first frame's window, hence frame_duration_samples_/2
|
||||
// term. We use frame_duration_samples_ because that is how it is
|
||||
// actually quantized inside spectrogram.
|
||||
const double packet_timestamp_offset_seconds =
|
||||
kInitialTimestampOffsetMicroseconds * 1e-6;
|
||||
const double frame_step_seconds = frame_step_samples_ / input_sample_rate_;
|
||||
|
||||
Timestamp initial_timestamp = Timestamp::Unstarted();
|
||||
|
||||
for (const Packet& packet : output().packets) {
|
||||
// This is the timestamp we expect based on how the spectrogram should
|
||||
// behave (advancing by one step's worth of input samples each frame).
|
||||
const double expected_timestamp_seconds =
|
||||
packet_timestamp_offset_seconds +
|
||||
cumulative_output_frames * frame_step_seconds;
|
||||
const int64 expected_timestamp_ticks =
|
||||
expected_timestamp_seconds * Timestamp::kTimestampUnitsPerSecond;
|
||||
EXPECT_EQ(expected_timestamp_ticks, packet.Timestamp().Value());
|
||||
// Accept the timestamp of the first packet as the baseline for checking
|
||||
// the remainder.
|
||||
if (initial_timestamp == Timestamp::Unstarted()) {
|
||||
initial_timestamp = packet.Timestamp();
|
||||
}
|
||||
// Also check that the timestamp is consistent with the sample_rate
|
||||
// in the output stream's TimeSeriesHeader.
|
||||
EXPECT_TRUE(time_series_util::LogWarningIfTimestampIsInconsistent(
|
||||
packet.Timestamp(), initial_timestamp, cumulative_output_frames,
|
||||
expected_header.sample_rate()));
|
||||
if (!options_.allow_multichannel_input()) {
|
||||
if (options_.output_type() == SpectrogramCalculatorOptions::COMPLEX) {
|
||||
const Eigen::MatrixXcf& matrix = packet.Get<Eigen::MatrixXcf>();
|
||||
cumulative_output_frames += matrix.cols();
|
||||
} else {
|
||||
const Matrix& matrix = packet.Get<Matrix>();
|
||||
cumulative_output_frames += matrix.cols();
|
||||
}
|
||||
} else {
|
||||
if (options_.output_type() == SpectrogramCalculatorOptions::COMPLEX) {
|
||||
const Eigen::MatrixXcf& matrix =
|
||||
packet.Get<std::vector<Eigen::MatrixXcf>>().at(0);
|
||||
cumulative_output_frames += matrix.cols();
|
||||
} else {
|
||||
const Matrix& matrix = packet.Get<std::vector<Matrix>>().at(0);
|
||||
cumulative_output_frames += matrix.cols();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that the bin corresponding to the specified frequency
|
||||
// is the largest one in one particular frame of a single packet.
|
||||
void CheckPeakFrequencyInPacketFrame(const Packet& packet, int frame,
|
||||
float frequency) {
|
||||
const int fft_size = audio_dsp::NextPowerOfTwo(frame_duration_samples_);
|
||||
const int target_bin =
|
||||
round((frequency / input_sample_rate_) * static_cast<float>(fft_size));
|
||||
|
||||
const Matrix& matrix = packet.Get<Matrix>();
|
||||
// Stop here if the requested frame is not in this packet.
|
||||
ASSERT_GT(matrix.cols(), frame);
|
||||
|
||||
int actual_largest_bin;
|
||||
matrix.col(frame).maxCoeff(&actual_largest_bin);
|
||||
EXPECT_EQ(actual_largest_bin, target_bin);
|
||||
}
|
||||
|
||||
// Verify that the bin corresponding to the specified frequency
|
||||
// is the largest one in one particular frame of a single spectrogram Matrix.
|
||||
void CheckPeakFrequencyInMatrix(const Matrix& matrix, int frame,
|
||||
float frequency) {
|
||||
const int fft_size = audio_dsp::NextPowerOfTwo(frame_duration_samples_);
|
||||
const int target_bin =
|
||||
round((frequency / input_sample_rate_) * static_cast<float>(fft_size));
|
||||
|
||||
// Stop here if the requested frame is not in this packet.
|
||||
ASSERT_GT(matrix.cols(), frame);
|
||||
|
||||
int actual_largest_bin;
|
||||
matrix.col(frame).maxCoeff(&actual_largest_bin);
|
||||
EXPECT_EQ(actual_largest_bin, target_bin);
|
||||
}
|
||||
|
||||
int frame_duration_samples_;
|
||||
int frame_step_samples_;
|
||||
// Expected DC output for a window of pure 1.0, set when window length
|
||||
// is set.
|
||||
float expected_dc_squared_magnitude_;
|
||||
};
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, IntegerFrameDurationNoOverlap) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(0.0 / input_sample_rate_);
|
||||
options_.set_pad_final_packet(false);
|
||||
const std::vector<int> input_packet_sizes = {500, 200};
|
||||
const std::vector<int> expected_output_packet_sizes = {5, 2};
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
SetupConstantInputPackets(input_packet_sizes);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes);
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, IntegerFrameDurationSomeOverlap) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
|
||||
options_.set_pad_final_packet(false);
|
||||
const std::vector<int> input_packet_sizes = {500, 200};
|
||||
// complete_output_frames = 1 + floor((input_samples - window_length)/step)
|
||||
// = 1 + floor((500 - 100)/40) = 1 + 10 = 11 for the first packet
|
||||
// = 1 + floor((700 - 100)/40) = 1 + 15 = 16 for the whole stream
|
||||
// so expect 16 - 11 = 5 in the second packet.
|
||||
const std::vector<int> expected_output_packet_sizes = {11, 5};
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
SetupConstantInputPackets(input_packet_sizes);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes);
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, NonintegerFrameDurationAndOverlap) {
|
||||
options_.set_frame_duration_seconds(98.5 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(58.4 / input_sample_rate_);
|
||||
options_.set_pad_final_packet(false);
|
||||
const std::vector<int> input_packet_sizes = {500, 200};
|
||||
// now frame_duration_samples will be 99 (rounded), and frame_step_samples
|
||||
// will be (99-58) = 41, so the first packet of 500 samples will generate
|
||||
// 1 + floor(500-99)/41 = 10 samples.
|
||||
const std::vector<int> expected_output_packet_sizes = {10, 5};
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
SetupConstantInputPackets(input_packet_sizes);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes);
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, ShortInitialPacketNoOverlap) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(0.0 / input_sample_rate_);
|
||||
options_.set_pad_final_packet(false);
|
||||
const std::vector<int> input_packet_sizes = {90, 100, 110};
|
||||
// The first input packet is too small to generate any frames,
|
||||
// but zero-length packets would result in a timestamp monotonicity
|
||||
// violation, so they are suppressed. Thus, only the second and third
|
||||
// input packets generate output packets.
|
||||
const std::vector<int> expected_output_packet_sizes = {1, 2};
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
SetupConstantInputPackets(input_packet_sizes);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes);
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, TrailingSamplesNoPad) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
|
||||
options_.set_pad_final_packet(false);
|
||||
const std::vector<int> input_packet_sizes = {140, 90};
|
||||
const std::vector<int> expected_output_packet_sizes = {2, 2};
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
SetupConstantInputPackets(input_packet_sizes);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes);
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, NoTrailingSamplesWithPad) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
|
||||
options_.set_pad_final_packet(true);
|
||||
const std::vector<int> input_packet_sizes = {140, 80};
|
||||
const std::vector<int> expected_output_packet_sizes = {2, 2};
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
SetupConstantInputPackets(input_packet_sizes);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes);
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, TrailingSamplesWithPad) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
|
||||
options_.set_pad_final_packet(true);
|
||||
const std::vector<int> input_packet_sizes = {140, 90};
|
||||
// In contrast to NoTrailingSamplesWithPad and TrailingSamplesNoPad,
|
||||
// this time we get an extra frame in an extra final packet.
|
||||
const std::vector<int> expected_output_packet_sizes = {2, 2, 1};
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
SetupConstantInputPackets(input_packet_sizes);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes);
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, VeryShortInputWillPad) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
|
||||
options_.set_pad_final_packet(true);
|
||||
const std::vector<int> input_packet_sizes = {30};
|
||||
const std::vector<int> expected_output_packet_sizes = {1};
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
SetupConstantInputPackets(input_packet_sizes);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes);
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, VeryShortInputZeroOutputFramesIfNoPad) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
|
||||
options_.set_pad_final_packet(false);
|
||||
const std::vector<int> input_packet_sizes = {90};
|
||||
const std::vector<int> expected_output_packet_sizes = {};
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
SetupConstantInputPackets(input_packet_sizes);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes);
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, DCSignalIsPeakBin) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
|
||||
const std::vector<int> input_packet_sizes = {140}; // Gives 2 output frames.
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
// Setup packets with DC input (non-zero constant value).
|
||||
SetupConstantInputPackets(input_packet_sizes);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
const float dc_frequency_hz = 0.0;
|
||||
CheckPeakFrequencyInPacketFrame(output().packets[0], 0, dc_frequency_hz);
|
||||
CheckPeakFrequencyInPacketFrame(output().packets[0], 1, dc_frequency_hz);
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, A440ToneIsPeakBin) {
|
||||
const std::vector<int> input_packet_sizes = {
|
||||
460}; // 100 + 9*40 for 10 frames.
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
const float tone_frequency_hz = 440.0;
|
||||
SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
int num_output_frames = output().packets[0].Get<Matrix>().cols();
|
||||
for (int frame = 0; frame < num_output_frames; ++frame) {
|
||||
CheckPeakFrequencyInPacketFrame(output().packets[0], frame,
|
||||
tone_frequency_hz);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, SquaredMagnitudeOutputLooksRight) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
|
||||
options_.set_output_type(SpectrogramCalculatorOptions::SQUARED_MAGNITUDE);
|
||||
const std::vector<int> input_packet_sizes = {140};
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
// Setup packets with DC input (non-zero constant value).
|
||||
SetupConstantInputPackets(input_packet_sizes);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
EXPECT_FLOAT_EQ(output().packets[0].Get<Matrix>()(0, 0),
|
||||
expected_dc_squared_magnitude_);
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, DefaultOutputIsSquaredMagnitude) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
|
||||
// Let the output_type be its default
|
||||
const std::vector<int> input_packet_sizes = {140};
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
// Setup packets with DC input (non-zero constant value).
|
||||
SetupConstantInputPackets(input_packet_sizes);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
EXPECT_FLOAT_EQ(output().packets[0].Get<Matrix>()(0, 0),
|
||||
expected_dc_squared_magnitude_);
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, LinearMagnitudeOutputLooksRight) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
|
||||
options_.set_output_type(SpectrogramCalculatorOptions::LINEAR_MAGNITUDE);
|
||||
const std::vector<int> input_packet_sizes = {140};
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
// Setup packets with DC input (non-zero constant value).
|
||||
SetupConstantInputPackets(input_packet_sizes);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
EXPECT_FLOAT_EQ(output().packets[0].Get<Matrix>()(0, 0),
|
||||
std::sqrt(expected_dc_squared_magnitude_));
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, DbMagnitudeOutputLooksRight) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
|
||||
options_.set_output_type(SpectrogramCalculatorOptions::DECIBELS);
|
||||
const std::vector<int> input_packet_sizes = {140};
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
// Setup packets with DC input (non-zero constant value).
|
||||
SetupConstantInputPackets(input_packet_sizes);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
EXPECT_FLOAT_EQ(output().packets[0].Get<Matrix>()(0, 0),
|
||||
10.0 * std::log10(expected_dc_squared_magnitude_));
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, OutputScalingLooksRight) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
|
||||
options_.set_output_type(SpectrogramCalculatorOptions::DECIBELS);
|
||||
double output_scale = 2.5;
|
||||
options_.set_output_scale(output_scale);
|
||||
const std::vector<int> input_packet_sizes = {140};
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
// Setup packets with DC input (non-zero constant value).
|
||||
SetupConstantInputPackets(input_packet_sizes);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
EXPECT_FLOAT_EQ(
|
||||
output().packets[0].Get<Matrix>()(0, 0),
|
||||
output_scale * 10.0 * std::log10(expected_dc_squared_magnitude_));
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, ComplexOutputLooksRight) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
|
||||
options_.set_output_type(SpectrogramCalculatorOptions::COMPLEX);
|
||||
const std::vector<int> input_packet_sizes = {140};
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
// Setup packets with DC input (non-zero constant value).
|
||||
SetupConstantInputPackets(input_packet_sizes);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
EXPECT_FLOAT_EQ(std::norm(output().packets[0].Get<Eigen::MatrixXcf>()(0, 0)),
|
||||
expected_dc_squared_magnitude_);
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, ComplexOutputLooksRightForImpulses) {
|
||||
const int frame_size_samples = 100;
|
||||
options_.set_frame_duration_seconds(frame_size_samples / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(0.0 / input_sample_rate_);
|
||||
options_.set_pad_final_packet(false);
|
||||
options_.set_output_type(SpectrogramCalculatorOptions::COMPLEX);
|
||||
const std::vector<int> input_packet_sizes = {frame_size_samples,
|
||||
frame_size_samples};
|
||||
const std::vector<int> input_packet_impulse_offsets = {49, 50};
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
|
||||
// Make two impulse packets offset one sample from each other
|
||||
SetupImpulseInputPackets(input_packet_sizes, input_packet_impulse_offsets);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
const int num_buckets =
|
||||
(audio_dsp::NextPowerOfTwo(frame_size_samples) / 2) + 1;
|
||||
const float precision = 0.01f;
|
||||
auto norm_fn = [](const std::complex<float>& cf) { return std::norm(cf); };
|
||||
|
||||
// Both impulses should have (approximately) constant power across all
|
||||
// frequency bins
|
||||
EXPECT_TRUE(output()
|
||||
.packets[0]
|
||||
.Get<Eigen::MatrixXcf>()
|
||||
.unaryExpr(norm_fn)
|
||||
.isApproxToConstant(1.0f, precision));
|
||||
EXPECT_TRUE(output()
|
||||
.packets[1]
|
||||
.Get<Eigen::MatrixXcf>()
|
||||
.unaryExpr(norm_fn)
|
||||
.isApproxToConstant(1.0f, precision));
|
||||
|
||||
// Because the second Packet's impulse is delayed by exactly one sample with
|
||||
// respect to the first Packet's impulse, the second impulse should have
|
||||
// greater phase, and in the highest frequency bin, the real part should
|
||||
// (approximately) flip sign from the first Packet to the second
|
||||
EXPECT_LT(std::arg(output().packets[0].Get<Eigen::MatrixXcf>()(1, 0)),
|
||||
std::arg(output().packets[1].Get<Eigen::MatrixXcf>()(1, 0)));
|
||||
const float highest_bucket_real_ratio =
|
||||
output().packets[0].Get<Eigen::MatrixXcf>()(num_buckets - 1, 0).real() /
|
||||
output().packets[1].Get<Eigen::MatrixXcf>()(num_buckets - 1, 0).real();
|
||||
EXPECT_NEAR(highest_bucket_real_ratio, -1.0f, precision);
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, SquaredMagnitudeOutputLooksRightForNonDC) {
|
||||
const int frame_size_samples = 100;
|
||||
options_.set_frame_duration_seconds(frame_size_samples / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
|
||||
options_.set_output_type(SpectrogramCalculatorOptions::SQUARED_MAGNITUDE);
|
||||
const std::vector<int> input_packet_sizes = {140};
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
// Make the tone have an integral number of cycles within the window
|
||||
const int target_bin = 16;
|
||||
const int fft_size = audio_dsp::NextPowerOfTwo(frame_size_samples);
|
||||
const float tone_frequency_hz = target_bin * (input_sample_rate_ / fft_size);
|
||||
SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
// For a non-DC bin, the magnitude will be split between positive and
|
||||
// negative frequency bins, so it should about be half-magnitude
|
||||
// = quarter-power.
|
||||
// It's not quite exact because of the interference from the hann(100)
|
||||
// spread from the negative-frequency half.
|
||||
EXPECT_GT(output().packets[0].Get<Matrix>()(target_bin, 0),
|
||||
0.98 * expected_dc_squared_magnitude_ / 4.0);
|
||||
EXPECT_LT(output().packets[0].Get<Matrix>()(target_bin, 0),
|
||||
1.02 * expected_dc_squared_magnitude_ / 4.0);
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, ZeroOutputsForZeroInputsWithPaddingEnabled) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
|
||||
options_.set_pad_final_packet(true);
|
||||
const std::vector<int> input_packet_sizes = {};
|
||||
const std::vector<int> expected_output_packet_sizes = {};
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
SetupConstantInputPackets(input_packet_sizes);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
EXPECT_EQ(OutputFramesPerPacket(), expected_output_packet_sizes);
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, NumChannelsIsRight) {
|
||||
const std::vector<int> input_packet_sizes = {460};
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
|
||||
options_.set_pad_final_packet(false);
|
||||
options_.set_allow_multichannel_input(true);
|
||||
num_input_channels_ = 3;
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
const float tone_frequency_hz = 440.0;
|
||||
SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz);
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
EXPECT_EQ(output().packets[0].Get<std::vector<Matrix>>().size(),
|
||||
num_input_channels_);
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, NumSamplesAndPacketRateAreCleared) {
|
||||
num_input_samples_ = 500;
|
||||
input_packet_rate_ = 1.0;
|
||||
const std::vector<int> input_packet_sizes = {num_input_samples_};
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(0.0);
|
||||
options_.set_pad_final_packet(false);
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
SetupConstantInputPackets(input_packet_sizes);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
const TimeSeriesHeader& output_header =
|
||||
output().header.Get<TimeSeriesHeader>();
|
||||
EXPECT_FALSE(output_header.has_num_samples());
|
||||
EXPECT_FALSE(output_header.has_packet_rate());
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, MultichannelSpectrogramSizesAreRight) {
|
||||
const std::vector<int> input_packet_sizes = {420}; // less than 10 frames
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
|
||||
options_.set_pad_final_packet(false);
|
||||
options_.set_allow_multichannel_input(true);
|
||||
num_input_channels_ = 10;
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
const float tone_frequency_hz = 440.0;
|
||||
SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz);
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
auto spectrograms = output().packets[0].Get<std::vector<Matrix>>();
|
||||
EXPECT_FLOAT_EQ(spectrograms.size(), num_input_channels_);
|
||||
int spectrogram_num_rows = spectrograms[0].rows();
|
||||
int spectrogram_num_cols = spectrograms[0].cols();
|
||||
for (int i = 1; i < num_input_channels_; i++) {
|
||||
EXPECT_EQ(spectrogram_num_rows, spectrograms[i].rows());
|
||||
EXPECT_EQ(spectrogram_num_cols, spectrograms[i].cols());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, MultichannelSpectrogramValuesAreRight) {
|
||||
const std::vector<int> input_packet_sizes = {
|
||||
460}; // 100 + 9*40 for 10 frames.
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
|
||||
options_.set_allow_multichannel_input(true);
|
||||
num_input_channels_ = 10;
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
const float tone_frequency_hz = 440.0;
|
||||
SetupMultichannelInputPackets(input_packet_sizes, tone_frequency_hz);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
auto spectrograms = output().packets[0].Get<std::vector<Matrix>>();
|
||||
int num_output_frames = spectrograms[0].cols();
|
||||
for (int i = 0; i < num_input_channels_; i++) {
|
||||
for (int frame = 0; frame < num_output_frames; ++frame) {
|
||||
if (i % 2 == 0) {
|
||||
CheckPeakFrequencyInMatrix(spectrograms[i], frame, 0);
|
||||
} else {
|
||||
CheckPeakFrequencyInMatrix(spectrograms[i], frame, tone_frequency_hz);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest, MultichannelHandlesShortInitialPacket) {
|
||||
// First packet is less than one frame, but second packet should trigger a
|
||||
// complete frame from all channels.
|
||||
const std::vector<int> input_packet_sizes = {50, 50};
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
|
||||
options_.set_pad_final_packet(false);
|
||||
options_.set_allow_multichannel_input(true);
|
||||
num_input_channels_ = 2;
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
const float tone_frequency_hz = 440.0;
|
||||
SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz);
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
auto spectrograms = output().packets[0].Get<std::vector<Matrix>>();
|
||||
EXPECT_FLOAT_EQ(spectrograms.size(), num_input_channels_);
|
||||
int spectrogram_num_rows = spectrograms[0].rows();
|
||||
int spectrogram_num_cols = spectrograms[0].cols();
|
||||
for (int i = 1; i < num_input_channels_; i++) {
|
||||
EXPECT_EQ(spectrogram_num_rows, spectrograms[i].rows());
|
||||
EXPECT_EQ(spectrogram_num_cols, spectrograms[i].cols());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(SpectrogramCalculatorTest,
|
||||
MultichannelComplexHandlesShortInitialPacket) {
|
||||
// First packet is less than one frame, but second packet should trigger a
|
||||
// complete frame from all channels, even for complex output.
|
||||
const std::vector<int> input_packet_sizes = {50, 50};
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(60.0 / input_sample_rate_);
|
||||
options_.set_pad_final_packet(false);
|
||||
options_.set_allow_multichannel_input(true);
|
||||
options_.set_output_type(SpectrogramCalculatorOptions::COMPLEX);
|
||||
num_input_channels_ = 2;
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
const float tone_frequency_hz = 440.0;
|
||||
SetupCosineInputPackets(input_packet_sizes, tone_frequency_hz);
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
CheckOutputHeadersAndTimestamps();
|
||||
auto spectrograms = output().packets[0].Get<std::vector<Eigen::MatrixXcf>>();
|
||||
EXPECT_FLOAT_EQ(spectrograms.size(), num_input_channels_);
|
||||
int spectrogram_num_rows = spectrograms[0].rows();
|
||||
int spectrogram_num_cols = spectrograms[0].cols();
|
||||
for (int i = 1; i < num_input_channels_; i++) {
|
||||
EXPECT_EQ(spectrogram_num_rows, spectrograms[i].rows());
|
||||
EXPECT_EQ(spectrogram_num_cols, spectrograms[i].cols());
|
||||
}
|
||||
}
|
||||
|
||||
void BM_ProcessDC(benchmark::State& state) {
|
||||
CalculatorGraphConfig::Node node_config;
|
||||
node_config.set_calculator("SpectrogramCalculator");
|
||||
node_config.add_input_stream("input_audio");
|
||||
node_config.add_output_stream("output_spectrogram");
|
||||
|
||||
SpectrogramCalculatorOptions* options =
|
||||
node_config.mutable_options()->MutableExtension(
|
||||
SpectrogramCalculatorOptions::ext);
|
||||
options->set_frame_duration_seconds(0.010);
|
||||
options->set_frame_overlap_seconds(0.0);
|
||||
options->set_pad_final_packet(false);
|
||||
*node_config.mutable_options()->MutableExtension(
|
||||
SpectrogramCalculatorOptions::ext) = *options;
|
||||
|
||||
int num_input_channels = 1;
|
||||
int packet_size_samples = 1600000;
|
||||
TimeSeriesHeader* header = new TimeSeriesHeader();
|
||||
header->set_sample_rate(16000.0);
|
||||
header->set_num_channels(num_input_channels);
|
||||
|
||||
CalculatorRunner runner(node_config);
|
||||
runner.MutableInputs()->Index(0).header = Adopt(header);
|
||||
|
||||
Matrix* payload = new Matrix(
|
||||
Matrix::Constant(num_input_channels, packet_size_samples, 1.0));
|
||||
Timestamp timestamp = Timestamp(0);
|
||||
runner.MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(payload).At(timestamp));
|
||||
|
||||
for (auto _ : state) {
|
||||
ASSERT_TRUE(runner.Run().ok());
|
||||
}
|
||||
|
||||
const CalculatorRunner::StreamContents& output = runner.Outputs().Index(0);
|
||||
const Matrix& output_matrix = output.packets[0].Get<Matrix>();
|
||||
LOG(INFO) << "Output matrix=" << output_matrix.rows() << "x"
|
||||
<< output_matrix.cols();
|
||||
LOG(INFO) << "First values=" << output_matrix(0, 0) << ", "
|
||||
<< output_matrix(1, 0) << ", " << output_matrix(2, 0) << ", "
|
||||
<< output_matrix(3, 0);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ProcessDC);
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,99 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Defines StabilizedLogCalculator.
|
||||
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "mediapipe/calculators/audio/stabilized_log_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/proto_ns.h"
|
||||
#include "mediapipe/util/time_series_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "StabilizedLogCalculator"
|
||||
// input_stream: "input_time_series"
|
||||
// output_stream: "stabilized_log_time_series"
|
||||
// options {
|
||||
// [mediapipe.StabilizedLogCalculatorOptions.ext] {
|
||||
// stabilizer: .00001
|
||||
// check_nonnegativity: true
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class StabilizedLogCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).Set<Matrix>(
|
||||
// Input stream with TimeSeriesHeader.
|
||||
);
|
||||
cc->Outputs().Index(0).Set<Matrix>(
|
||||
// Output stabilized log stream with TimeSeriesHeader.
|
||||
);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
StabilizedLogCalculatorOptions stabilized_log_calculator_options =
|
||||
cc->Options<StabilizedLogCalculatorOptions>();
|
||||
|
||||
stabilizer_ = stabilized_log_calculator_options.stabilizer();
|
||||
output_scale_ = stabilized_log_calculator_options.output_scale();
|
||||
check_nonnegativity_ =
|
||||
stabilized_log_calculator_options.check_nonnegativity();
|
||||
CHECK_GE(stabilizer_, 0.0)
|
||||
<< "stabilizer must be >= 0.0, received a value of " << stabilizer_;
|
||||
|
||||
// If the input packets have a header, propagate the header to the output.
|
||||
if (!cc->Inputs().Index(0).Header().IsEmpty()) {
|
||||
TimeSeriesHeader input_header;
|
||||
MP_RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
|
||||
cc->Inputs().Index(0).Header(), &input_header));
|
||||
cc->Outputs().Index(0).SetHeader(
|
||||
Adopt(new TimeSeriesHeader(input_header)));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
auto input_matrix = cc->Inputs().Index(0).Get<Matrix>();
|
||||
if (input_matrix.array().isNaN().any()) {
|
||||
return absl::InvalidArgumentError("NaN input to log operation.");
|
||||
}
|
||||
if (check_nonnegativity_) {
|
||||
if (input_matrix.minCoeff() < 0.0) {
|
||||
return absl::OutOfRangeError("Negative input to log operation.");
|
||||
}
|
||||
}
|
||||
std::unique_ptr<Matrix> output_frame(new Matrix(
|
||||
output_scale_ * (input_matrix.array() + stabilizer_).log().matrix()));
|
||||
cc->Outputs().Index(0).Add(output_frame.release(), cc->InputTimestamp());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
float stabilizer_;
|
||||
bool check_nonnegativity_;
|
||||
double output_scale_;
|
||||
};
|
||||
REGISTER_CALCULATOR(StabilizedLogCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
message StabilizedLogCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional StabilizedLogCalculatorOptions ext = 101978339;
|
||||
}
|
||||
|
||||
// The calculator computes log(x + stabilizer). stabilizer must be >=
|
||||
// 0, with 0 indicating a lack of stabilization.
|
||||
optional float stabilizer = 1 [default = .00001];
|
||||
|
||||
// If true, CHECK that all input values in are >= 0. If false, the
|
||||
// code will take the log of the potentially negative input values
|
||||
// plus the stabilizer.
|
||||
optional bool check_nonnegativity = 2 [default = true];
|
||||
|
||||
// Support a fixed multiplicative scaling of the output.
|
||||
optional double output_scale = 3 [default = 1.0];
|
||||
}
|
||||
|
|
@ -1,141 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#include <cmath>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "mediapipe/calculators/audio/stabilized_log_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/util/time_series_test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
const float kStabilizer = 0.1;
|
||||
const int kNumChannels = 3;
|
||||
const int kNumSamples = 10;
|
||||
|
||||
class StabilizedLogCalculatorTest
|
||||
: public TimeSeriesCalculatorTest<StabilizedLogCalculatorOptions> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
calculator_name_ = "StabilizedLogCalculator";
|
||||
options_.set_stabilizer(kStabilizer);
|
||||
|
||||
input_sample_rate_ = 8000.0;
|
||||
num_input_channels_ = kNumChannels;
|
||||
num_input_samples_ = kNumSamples;
|
||||
}
|
||||
|
||||
void RunGraphNoReturn() { MP_ASSERT_OK(RunGraph()); }
|
||||
};
|
||||
|
||||
TEST_F(StabilizedLogCalculatorTest, BasicOperation) {
|
||||
const int kNumPackets = 5;
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
|
||||
std::vector<Matrix> input_data_matrices;
|
||||
for (int input_packet = 0; input_packet < kNumPackets; ++input_packet) {
|
||||
const int64 timestamp = input_packet * Timestamp::kTimestampUnitsPerSecond;
|
||||
Matrix input_data_matrix =
|
||||
Matrix::Random(kNumChannels, kNumSamples).array().abs();
|
||||
input_data_matrices.push_back(input_data_matrix);
|
||||
AppendInputPacket(new Matrix(input_data_matrix), timestamp);
|
||||
}
|
||||
|
||||
MP_ASSERT_OK(RunGraph());
|
||||
ExpectOutputHeaderEqualsInputHeader();
|
||||
for (int output_packet = 0; output_packet < kNumPackets; ++output_packet) {
|
||||
ExpectApproximatelyEqual(
|
||||
(input_data_matrices[output_packet].array() + kStabilizer).log(),
|
||||
runner_->Outputs().Index(0).packets[output_packet].Get<Matrix>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(StabilizedLogCalculatorTest, OutputScaleWorks) {
|
||||
const int kNumPackets = 5;
|
||||
double output_scale = 2.5;
|
||||
options_.set_output_scale(output_scale);
|
||||
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
|
||||
std::vector<Matrix> input_data_matrices;
|
||||
for (int input_packet = 0; input_packet < kNumPackets; ++input_packet) {
|
||||
const int64 timestamp = input_packet * Timestamp::kTimestampUnitsPerSecond;
|
||||
Matrix input_data_matrix =
|
||||
Matrix::Random(kNumChannels, kNumSamples).array().abs();
|
||||
input_data_matrices.push_back(input_data_matrix);
|
||||
AppendInputPacket(new Matrix(input_data_matrix), timestamp);
|
||||
}
|
||||
|
||||
MP_ASSERT_OK(RunGraph());
|
||||
ExpectOutputHeaderEqualsInputHeader();
|
||||
for (int output_packet = 0; output_packet < kNumPackets; ++output_packet) {
|
||||
ExpectApproximatelyEqual(
|
||||
output_scale *
|
||||
((input_data_matrices[output_packet].array() + kStabilizer).log()),
|
||||
runner_->Outputs().Index(0).packets[output_packet].Get<Matrix>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(StabilizedLogCalculatorTest, ZerosAreStabilized) {
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
AppendInputPacket(new Matrix(Matrix::Zero(kNumChannels, kNumSamples)),
|
||||
0 /* timestamp */);
|
||||
MP_ASSERT_OK(RunGraph());
|
||||
ExpectOutputHeaderEqualsInputHeader();
|
||||
ExpectApproximatelyEqual(
|
||||
Matrix::Constant(kNumChannels, kNumSamples, kStabilizer).array().log(),
|
||||
runner_->Outputs().Index(0).packets[0].Get<Matrix>());
|
||||
}
|
||||
|
||||
TEST_F(StabilizedLogCalculatorTest, NanValuesReturnError) {
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
AppendInputPacket(
|
||||
new Matrix(Matrix::Constant(kNumChannels, kNumSamples, std::nanf(""))),
|
||||
0 /* timestamp */);
|
||||
ASSERT_FALSE(RunGraph().ok());
|
||||
}
|
||||
|
||||
TEST_F(StabilizedLogCalculatorTest, NegativeValuesReturnError) {
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
AppendInputPacket(
|
||||
new Matrix(Matrix::Constant(kNumChannels, kNumSamples, -1.0)),
|
||||
0 /* timestamp */);
|
||||
ASSERT_FALSE(RunGraph().ok());
|
||||
}
|
||||
|
||||
TEST_F(StabilizedLogCalculatorTest, NegativeValuesDoNotCheckFailIfCheckIsOff) {
|
||||
options_.set_check_nonnegativity(false);
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
AppendInputPacket(
|
||||
new Matrix(Matrix::Constant(kNumChannels, kNumSamples, -1.0)),
|
||||
0 /* timestamp */);
|
||||
MP_ASSERT_OK(RunGraph());
|
||||
// Results are undefined.
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
27
mediapipe/calculators/audio/testdata/BUILD
vendored
27
mediapipe/calculators/audio/testdata/BUILD
vendored
|
|
@ -1,27 +0,0 @@
|
|||
# Copyright 2019 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
filegroup(
|
||||
name = "test_audios",
|
||||
srcs = [
|
||||
"sine_wave_1k_44100_mono_2_sec_wav.audio",
|
||||
"sine_wave_1k_44100_stereo_2_sec_aac.audio",
|
||||
"sine_wave_1k_44100_stereo_2_sec_mp3.audio",
|
||||
"sine_wave_1k_48000_stereo_2_sec_wav.audio",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -1,325 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Defines TimeSeriesFramerCalculator.
|
||||
#include <math.h>
|
||||
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "audio/dsp/window_functions.h"
|
||||
#include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/util/time_series_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// MediaPipe Calculator for framing a (vector-valued) input time series,
|
||||
// i.e. for breaking an input time series into fixed-size, possibly
|
||||
// overlapping, frames. The output stream's frame duration is
|
||||
// specified by frame_duration_seconds in the
|
||||
// TimeSeriesFramerCalculatorOptions, and the output's overlap is
|
||||
// specified by frame_overlap_seconds.
|
||||
//
|
||||
// This calculator assumes that the input timestamps refer to the
|
||||
// first sample in each Matrix. The output timestamps follow this
|
||||
// same convention.
|
||||
//
|
||||
// All output frames will have exactly the same number of samples: the number of
|
||||
// samples that approximates frame_duration_seconds most closely.
|
||||
//
|
||||
// Similarly, frame overlap is by default the (fixed) number of samples
|
||||
// approximating frame_overlap_seconds most closely. But if
|
||||
// emulate_fractional_frame_overlap is set to true, frame overlap is a variable
|
||||
// number of samples instead, such that the long-term average step between
|
||||
// frames is the difference between the (nominal) frame_duration_seconds and
|
||||
// frame_overlap_seconds.
|
||||
//
|
||||
// If pad_final_packet is true, all input samples will be emitted and the final
|
||||
// packet will be zero padded as necessary. If pad_final_packet is false, some
|
||||
// samples may be dropped at the end of the stream.
|
||||
//
|
||||
// If use_local_timestamp is true, the output packet's timestamp is based on the
|
||||
// last sample of the packet. The timestamp of this sample is inferred by
|
||||
// input_packet_timesamp + local_sample_index / sampling_rate_. If false, the
|
||||
// output packet's timestamp is based on the cumulative timestamping, which is
|
||||
// done by adopting the timestamp of the first sample of the packet and this
|
||||
// sample's timestamp is inferred by initial_input_timestamp_ +
|
||||
// cumulative_completed_samples / sample_rate_.
|
||||
class TimeSeriesFramerCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).Set<Matrix>(
|
||||
// Input stream with TimeSeriesHeader.
|
||||
);
|
||||
cc->Outputs().Index(0).Set<Matrix>(
|
||||
// Fixed length time series Packets with TimeSeriesHeader.
|
||||
);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Returns FAIL if the input stream header is invalid.
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
|
||||
// Outputs as many framed packets as possible given the accumulated
|
||||
// input. Always returns OK.
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
// Flushes any remaining samples in a zero-padded packet. Always
|
||||
// returns OK.
|
||||
absl::Status Close(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
// Adds input data to the internal buffer.
|
||||
void EnqueueInput(CalculatorContext* cc);
|
||||
// Constructs and emits framed output packets.
|
||||
void FrameOutput(CalculatorContext* cc);
|
||||
|
||||
Timestamp CurrentOutputTimestamp() {
|
||||
if (use_local_timestamp_) {
|
||||
return current_timestamp_;
|
||||
}
|
||||
return CumulativeOutputTimestamp();
|
||||
}
|
||||
|
||||
Timestamp CumulativeOutputTimestamp() {
|
||||
return initial_input_timestamp_ +
|
||||
round(cumulative_completed_samples_ / sample_rate_ *
|
||||
Timestamp::kTimestampUnitsPerSecond);
|
||||
}
|
||||
|
||||
// Returns the timestamp of a sample on a base, which is usually the time
|
||||
// stamp of a packet.
|
||||
Timestamp CurrentSampleTimestamp(const Timestamp& timestamp_base,
|
||||
int64 number_of_samples) {
|
||||
return timestamp_base + round(number_of_samples / sample_rate_ *
|
||||
Timestamp::kTimestampUnitsPerSecond);
|
||||
}
|
||||
|
||||
// The number of input samples to advance after the current output frame is
|
||||
// emitted.
|
||||
int next_frame_step_samples() const {
|
||||
// All numbers are in input samples.
|
||||
const int64 current_output_frame_start = static_cast<int64>(
|
||||
round(cumulative_output_frames_ * average_frame_step_samples_));
|
||||
CHECK_EQ(current_output_frame_start, cumulative_completed_samples_);
|
||||
const int64 next_output_frame_start = static_cast<int64>(
|
||||
round((cumulative_output_frames_ + 1) * average_frame_step_samples_));
|
||||
return next_output_frame_start - current_output_frame_start;
|
||||
}
|
||||
|
||||
double sample_rate_;
|
||||
bool pad_final_packet_;
|
||||
int frame_duration_samples_;
|
||||
// The advance, in input samples, between the start of successive output
|
||||
// frames. This may be a non-integer average value if
|
||||
// emulate_fractional_frame_overlap is true.
|
||||
double average_frame_step_samples_;
|
||||
int samples_still_to_drop_;
|
||||
int64 cumulative_output_frames_;
|
||||
// "Completed" samples are samples that are no longer needed because
|
||||
// the framer has completely stepped past them (taking into account
|
||||
// any overlap).
|
||||
int64 cumulative_completed_samples_;
|
||||
Timestamp initial_input_timestamp_;
|
||||
// The current timestamp is updated along with the incoming packets.
|
||||
Timestamp current_timestamp_;
|
||||
int num_channels_;
|
||||
|
||||
// Each entry in this deque consists of a single sample, i.e. a
|
||||
// single column vector, and its timestamp.
|
||||
std::deque<std::pair<Matrix, Timestamp>> sample_buffer_;
|
||||
|
||||
bool use_window_;
|
||||
Matrix window_;
|
||||
|
||||
bool use_local_timestamp_;
|
||||
};
|
||||
REGISTER_CALCULATOR(TimeSeriesFramerCalculator);
|
||||
|
||||
void TimeSeriesFramerCalculator::EnqueueInput(CalculatorContext* cc) {
|
||||
const Matrix& input_frame = cc->Inputs().Index(0).Get<Matrix>();
|
||||
|
||||
for (int i = 0; i < input_frame.cols(); ++i) {
|
||||
sample_buffer_.emplace_back(std::make_pair(
|
||||
input_frame.col(i), CurrentSampleTimestamp(cc->InputTimestamp(), i)));
|
||||
}
|
||||
}
|
||||
|
||||
void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) {
|
||||
while (sample_buffer_.size() >=
|
||||
frame_duration_samples_ + samples_still_to_drop_) {
|
||||
while (samples_still_to_drop_ > 0) {
|
||||
sample_buffer_.pop_front();
|
||||
--samples_still_to_drop_;
|
||||
}
|
||||
const int frame_step_samples = next_frame_step_samples();
|
||||
std::unique_ptr<Matrix> output_frame(
|
||||
new Matrix(num_channels_, frame_duration_samples_));
|
||||
for (int i = 0; i < std::min(frame_step_samples, frame_duration_samples_);
|
||||
++i) {
|
||||
output_frame->col(i) = sample_buffer_.front().first;
|
||||
current_timestamp_ = sample_buffer_.front().second;
|
||||
sample_buffer_.pop_front();
|
||||
}
|
||||
const int frame_overlap_samples =
|
||||
frame_duration_samples_ - frame_step_samples;
|
||||
if (frame_overlap_samples > 0) {
|
||||
for (int i = 0; i < frame_overlap_samples; ++i) {
|
||||
output_frame->col(i + frame_step_samples) = sample_buffer_[i].first;
|
||||
current_timestamp_ = sample_buffer_[i].second;
|
||||
}
|
||||
} else {
|
||||
samples_still_to_drop_ = -frame_overlap_samples;
|
||||
}
|
||||
|
||||
if (use_window_) {
|
||||
*output_frame = (output_frame->array() * window_.array()).matrix();
|
||||
}
|
||||
|
||||
cc->Outputs().Index(0).Add(output_frame.release(),
|
||||
CurrentOutputTimestamp());
|
||||
++cumulative_output_frames_;
|
||||
cumulative_completed_samples_ += frame_step_samples;
|
||||
}
|
||||
if (!use_local_timestamp_) {
|
||||
// In non-local timestamp mode the timestamp of the next packet will be
|
||||
// equal to CumulativeOutputTimestamp(). Inform the framework about this
|
||||
// fact to enable packet queueing optimizations.
|
||||
cc->Outputs().Index(0).SetNextTimestampBound(CumulativeOutputTimestamp());
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) {
|
||||
if (initial_input_timestamp_ == Timestamp::Unstarted()) {
|
||||
initial_input_timestamp_ = cc->InputTimestamp();
|
||||
current_timestamp_ = initial_input_timestamp_;
|
||||
}
|
||||
|
||||
EnqueueInput(cc);
|
||||
FrameOutput(cc);
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status TimeSeriesFramerCalculator::Close(CalculatorContext* cc) {
|
||||
while (samples_still_to_drop_ > 0 && !sample_buffer_.empty()) {
|
||||
sample_buffer_.pop_front();
|
||||
--samples_still_to_drop_;
|
||||
}
|
||||
if (!sample_buffer_.empty() && pad_final_packet_) {
|
||||
std::unique_ptr<Matrix> output_frame(new Matrix);
|
||||
output_frame->setZero(num_channels_, frame_duration_samples_);
|
||||
for (int i = 0; i < sample_buffer_.size(); ++i) {
|
||||
output_frame->col(i) = sample_buffer_[i].first;
|
||||
current_timestamp_ = sample_buffer_[i].second;
|
||||
}
|
||||
|
||||
cc->Outputs().Index(0).Add(output_frame.release(),
|
||||
CurrentOutputTimestamp());
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) {
|
||||
TimeSeriesFramerCalculatorOptions framer_options =
|
||||
cc->Options<TimeSeriesFramerCalculatorOptions>();
|
||||
|
||||
RET_CHECK_GT(framer_options.frame_duration_seconds(), 0.0)
|
||||
<< "Invalid or missing frame_duration_seconds. "
|
||||
<< "framer_duration_seconds: \n"
|
||||
<< framer_options.frame_duration_seconds();
|
||||
RET_CHECK_LT(framer_options.frame_overlap_seconds(),
|
||||
framer_options.frame_duration_seconds())
|
||||
<< "Invalid frame_overlap_seconds. framer_overlap_seconds: \n"
|
||||
<< framer_options.frame_overlap_seconds();
|
||||
|
||||
TimeSeriesHeader input_header;
|
||||
MP_RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid(
|
||||
cc->Inputs().Index(0).Header(), &input_header));
|
||||
|
||||
sample_rate_ = input_header.sample_rate();
|
||||
num_channels_ = input_header.num_channels();
|
||||
frame_duration_samples_ = time_series_util::SecondsToSamples(
|
||||
framer_options.frame_duration_seconds(), sample_rate_);
|
||||
RET_CHECK_GT(frame_duration_samples_, 0)
|
||||
<< "Frame duration of " << framer_options.frame_duration_seconds()
|
||||
<< "s too small to cover a single sample at " << sample_rate_ << " Hz ";
|
||||
if (framer_options.emulate_fractional_frame_overlap()) {
|
||||
// Frame step may be fractional.
|
||||
average_frame_step_samples_ = (framer_options.frame_duration_seconds() -
|
||||
framer_options.frame_overlap_seconds()) *
|
||||
sample_rate_;
|
||||
} else {
|
||||
// Frame step is an integer (stored in a double).
|
||||
average_frame_step_samples_ =
|
||||
frame_duration_samples_ -
|
||||
time_series_util::SecondsToSamples(
|
||||
framer_options.frame_overlap_seconds(), sample_rate_);
|
||||
}
|
||||
RET_CHECK_GE(average_frame_step_samples_, 1)
|
||||
<< "Frame step too small to cover a single sample at " << sample_rate_
|
||||
<< " Hz.";
|
||||
pad_final_packet_ = framer_options.pad_final_packet();
|
||||
|
||||
auto output_header = new TimeSeriesHeader(input_header);
|
||||
output_header->set_num_samples(frame_duration_samples_);
|
||||
if (round(average_frame_step_samples_) == average_frame_step_samples_) {
|
||||
// Only set output packet rate if it is fixed.
|
||||
output_header->set_packet_rate(sample_rate_ / average_frame_step_samples_);
|
||||
}
|
||||
cc->Outputs().Index(0).SetHeader(Adopt(output_header));
|
||||
cumulative_completed_samples_ = 0;
|
||||
cumulative_output_frames_ = 0;
|
||||
samples_still_to_drop_ = 0;
|
||||
initial_input_timestamp_ = Timestamp::Unstarted();
|
||||
current_timestamp_ = Timestamp::Unstarted();
|
||||
|
||||
std::vector<double> window_vector;
|
||||
use_window_ = false;
|
||||
switch (framer_options.window_function()) {
|
||||
case TimeSeriesFramerCalculatorOptions::HAMMING:
|
||||
audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples_,
|
||||
&window_vector);
|
||||
use_window_ = true;
|
||||
break;
|
||||
case TimeSeriesFramerCalculatorOptions::HANN:
|
||||
audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_,
|
||||
&window_vector);
|
||||
use_window_ = true;
|
||||
break;
|
||||
case TimeSeriesFramerCalculatorOptions::NONE:
|
||||
break;
|
||||
}
|
||||
|
||||
if (use_window_) {
|
||||
window_ = Matrix::Ones(num_channels_, 1) *
|
||||
Eigen::Map<Eigen::MatrixXd>(window_vector.data(), 1,
|
||||
frame_duration_samples_)
|
||||
.cast<float>();
|
||||
}
|
||||
use_local_timestamp_ = framer_options.use_local_timestamp();
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,72 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
message TimeSeriesFramerCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional TimeSeriesFramerCalculatorOptions ext = 50631621;
|
||||
}
|
||||
|
||||
// Frame duration in seconds. Required. Must be greater than 0. This is
|
||||
// rounded to the nearest integer number of samples.
|
||||
optional double frame_duration_seconds = 1;
|
||||
|
||||
// Frame overlap in seconds.
|
||||
//
|
||||
// If emulate_fractional_frame_overlap is false (the default), then the frame
|
||||
// overlap is rounded to the nearest integer number of samples, and the step
|
||||
// from one frame to the next will be the difference between the number of
|
||||
// samples in a frame and the number of samples in the overlap.
|
||||
//
|
||||
// If emulate_fractional_frame_overlap is true, then frame overlap will be a
|
||||
// variable number of samples, such that the long-time average time step from
|
||||
// one frame to the next will be the difference between the (nominal, not
|
||||
// rounded) frame_duration_seconds and frame_overlap_seconds. This is useful
|
||||
// where the desired time step is not an integral number of input samples.
|
||||
//
|
||||
// A negative frame_overlap_seconds corresponds to skipping some input samples
|
||||
// between each frame of emitted samples.
|
||||
//
|
||||
// Required that frame_overlap_seconds < frame_duration_seconds.
|
||||
optional double frame_overlap_seconds = 2 [default = 0.0];
|
||||
|
||||
// See frame_overlap_seconds for semantics.
|
||||
optional bool emulate_fractional_frame_overlap = 5 [default = false];
|
||||
|
||||
// Whether to pad the final packet with zeros. If true, guarantees that all
|
||||
// input samples (other than those that fall in gaps implied by negative
|
||||
// frame_overlap_seconds) will be emitted. If set to false, any partial
|
||||
// packet at the end of the stream will be dropped.
|
||||
optional bool pad_final_packet = 3 [default = true];
|
||||
|
||||
// Optional windowing function. The default is NONE (no windowing function).
|
||||
enum WindowFunction {
|
||||
NONE = 0;
|
||||
HAMMING = 1;
|
||||
HANN = 2;
|
||||
}
|
||||
optional WindowFunction window_function = 4 [default = NONE];
|
||||
|
||||
// If use_local_timestamp is true, the output packet's timestamp is based on
|
||||
// the last sample of the packet and it's inferred from the latest input
|
||||
// packet's timestamp. If false, the output packet's timestamp is based on
|
||||
// the cumulative timestamping, which is inferred from the intial input
|
||||
// timestamp and the cumulative number of samples.
|
||||
optional bool use_local_timestamp = 6 [default = false];
|
||||
}
|
||||
|
|
@ -1,485 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "audio/dsp/window_functions.h"
|
||||
#include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/util/time_series_test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
const int kInitialTimestampOffsetMicroseconds = 4;
|
||||
const int kGapBetweenPacketsInSeconds = 1;
|
||||
const int kUniversalInputPacketSize = 50;
|
||||
|
||||
class TimeSeriesFramerCalculatorTest
|
||||
: public TimeSeriesCalculatorTest<TimeSeriesFramerCalculatorOptions> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
calculator_name_ = "TimeSeriesFramerCalculator";
|
||||
input_sample_rate_ = 4000.0;
|
||||
num_input_channels_ = 3;
|
||||
}
|
||||
|
||||
// Returns a float value with the channel and timestamp separated by
|
||||
// an order of magnitude, for easy parsing by humans.
|
||||
float TestValue(int64 timestamp_in_microseconds, int channel) {
|
||||
return timestamp_in_microseconds + channel / 10.0;
|
||||
}
|
||||
|
||||
// Caller takes ownership of the returned value.
|
||||
Matrix* NewTestFrame(int num_channels, int num_samples,
|
||||
double starting_timestamp_seconds) {
|
||||
auto matrix = new Matrix(num_channels, num_samples);
|
||||
for (int c = 0; c < num_channels; ++c) {
|
||||
for (int i = 0; i < num_samples; ++i) {
|
||||
int64 timestamp = time_series_util::SecondsToSamples(
|
||||
starting_timestamp_seconds + i / input_sample_rate_,
|
||||
Timestamp::kTimestampUnitsPerSecond);
|
||||
(*matrix)(c, i) = TestValue(timestamp, c);
|
||||
}
|
||||
}
|
||||
return matrix;
|
||||
}
|
||||
|
||||
// Initializes and runs the test graph.
|
||||
absl::Status Run() {
|
||||
InitializeGraph();
|
||||
|
||||
FillInputHeader();
|
||||
InitializeInput();
|
||||
|
||||
return RunGraph();
|
||||
}
|
||||
|
||||
// Creates test input and saves a reference copy.
|
||||
void InitializeInput() {
|
||||
concatenated_input_samples_.resize(0, num_input_channels_);
|
||||
num_input_samples_ = 0;
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
// This range of packet sizes was chosen such that some input
|
||||
// packets will be smaller than the output packet size and other
|
||||
// input packets will be larger.
|
||||
int packet_size = (i + 1) * 20;
|
||||
double timestamp_seconds = kInitialTimestampOffsetMicroseconds * 1.0e-6 +
|
||||
num_input_samples_ / input_sample_rate_;
|
||||
|
||||
Matrix* data_frame =
|
||||
NewTestFrame(num_input_channels_, packet_size, timestamp_seconds);
|
||||
|
||||
// Keep a reference copy of the input.
|
||||
//
|
||||
// conservativeResize() is needed here to preserve the existing
|
||||
// data. Eigen's resize() resizes without preserving data.
|
||||
concatenated_input_samples_.conservativeResize(
|
||||
num_input_channels_, num_input_samples_ + packet_size);
|
||||
concatenated_input_samples_.rightCols(packet_size) = *data_frame;
|
||||
num_input_samples_ += packet_size;
|
||||
|
||||
AppendInputPacket(data_frame, round(timestamp_seconds *
|
||||
Timestamp::kTimestampUnitsPerSecond));
|
||||
}
|
||||
|
||||
const int frame_duration_samples = FrameDurationSamples();
|
||||
std::vector<double> window_vector;
|
||||
switch (options_.window_function()) {
|
||||
case TimeSeriesFramerCalculatorOptions::HAMMING:
|
||||
audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples,
|
||||
&window_vector);
|
||||
break;
|
||||
case TimeSeriesFramerCalculatorOptions::HANN:
|
||||
audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples,
|
||||
&window_vector);
|
||||
break;
|
||||
case TimeSeriesFramerCalculatorOptions::NONE:
|
||||
window_vector.assign(frame_duration_samples, 1.0f);
|
||||
break;
|
||||
}
|
||||
|
||||
window_ = Matrix::Ones(num_input_channels_, 1) *
|
||||
Eigen::Map<Eigen::MatrixXd>(window_vector.data(), 1,
|
||||
frame_duration_samples)
|
||||
.cast<float>();
|
||||
}
|
||||
|
||||
int FrameDurationSamples() {
|
||||
return time_series_util::SecondsToSamples(options_.frame_duration_seconds(),
|
||||
input_sample_rate_);
|
||||
}
|
||||
|
||||
// Checks that the values in the framed output packets matches the
|
||||
// appropriate values from the input.
|
||||
void CheckOutputPacketValues(const Matrix& actual, int packet_num,
|
||||
int frame_duration_samples,
|
||||
double frame_step_samples,
|
||||
int num_columns_to_check) {
|
||||
ASSERT_EQ(frame_duration_samples, actual.cols());
|
||||
Matrix expected = (concatenated_input_samples_
|
||||
.block(0, round(frame_step_samples * packet_num),
|
||||
num_input_channels_, num_columns_to_check)
|
||||
.array() *
|
||||
window_.leftCols(num_columns_to_check).array())
|
||||
.matrix();
|
||||
ExpectApproximatelyEqual(expected, actual.leftCols(num_columns_to_check));
|
||||
}
|
||||
|
||||
// Checks output headers, Timestamps, and values.
|
||||
void CheckOutput() {
|
||||
const int frame_duration_samples = FrameDurationSamples();
|
||||
const double frame_step_samples =
|
||||
options_.emulate_fractional_frame_overlap()
|
||||
? (options_.frame_duration_seconds() -
|
||||
options_.frame_overlap_seconds()) *
|
||||
input_sample_rate_
|
||||
: frame_duration_samples -
|
||||
time_series_util::SecondsToSamples(
|
||||
options_.frame_overlap_seconds(), input_sample_rate_);
|
||||
|
||||
TimeSeriesHeader expected_header = input().header.Get<TimeSeriesHeader>();
|
||||
expected_header.set_num_samples(frame_duration_samples);
|
||||
if (!options_.emulate_fractional_frame_overlap() ||
|
||||
frame_step_samples == round(frame_step_samples)) {
|
||||
expected_header.set_packet_rate(input_sample_rate_ / frame_step_samples);
|
||||
}
|
||||
ExpectOutputHeaderEquals(expected_header);
|
||||
|
||||
int num_full_packets = output().packets.size();
|
||||
if (options_.pad_final_packet()) {
|
||||
num_full_packets -= 1;
|
||||
}
|
||||
|
||||
for (int packet_num = 0; packet_num < num_full_packets; ++packet_num) {
|
||||
const Packet& packet = output().packets[packet_num];
|
||||
CheckOutputPacketValues(packet.Get<Matrix>(), packet_num,
|
||||
frame_duration_samples, frame_step_samples,
|
||||
frame_duration_samples);
|
||||
}
|
||||
|
||||
// What is the effective time index of the final sample emitted?
|
||||
// This includes accounting for the gaps when overlap is negative.
|
||||
const int num_unique_output_samples =
|
||||
round((output().packets.size() - 1) * frame_step_samples) +
|
||||
frame_duration_samples;
|
||||
LOG(INFO) << "packets.size()=" << output().packets.size()
|
||||
<< " frame_duration_samples=" << frame_duration_samples
|
||||
<< " frame_step_samples=" << frame_step_samples
|
||||
<< " num_input_samples_=" << num_input_samples_
|
||||
<< " num_unique_output_samples=" << num_unique_output_samples;
|
||||
const int num_padding_samples =
|
||||
num_unique_output_samples - num_input_samples_;
|
||||
if (options_.pad_final_packet()) {
|
||||
EXPECT_LT(num_padding_samples, frame_duration_samples);
|
||||
// If the input ended during the dropped samples between the end of
|
||||
// the last emitted frame and where the next one would begin, there
|
||||
// can be fewer unique output points than input points, even with
|
||||
// padding.
|
||||
const int max_dropped_samples =
|
||||
static_cast<int>(ceil(frame_step_samples - frame_duration_samples));
|
||||
EXPECT_GE(num_padding_samples, std::min(0, -max_dropped_samples));
|
||||
|
||||
if (num_padding_samples > 0) {
|
||||
// Check the non-padded part of the final packet.
|
||||
const Matrix& final_matrix = output().packets.back().Get<Matrix>();
|
||||
CheckOutputPacketValues(final_matrix, num_full_packets,
|
||||
frame_duration_samples, frame_step_samples,
|
||||
frame_duration_samples - num_padding_samples);
|
||||
// Check the padded part of the final packet.
|
||||
EXPECT_EQ(
|
||||
Matrix::Zero(num_input_channels_, num_padding_samples),
|
||||
final_matrix.block(0, frame_duration_samples - num_padding_samples,
|
||||
num_input_channels_, num_padding_samples));
|
||||
}
|
||||
} else {
|
||||
EXPECT_GT(num_padding_samples, -frame_duration_samples);
|
||||
EXPECT_LE(num_padding_samples, 0);
|
||||
}
|
||||
}
|
||||
|
||||
int num_input_samples_;
|
||||
Matrix concatenated_input_samples_;
|
||||
Matrix window_;
|
||||
};
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest, IntegerSampleDurationNoOverlap) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
MP_ASSERT_OK(Run());
|
||||
CheckOutput();
|
||||
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest,
|
||||
IntegerSampleDurationNoOverlapHammingWindow) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_window_function(TimeSeriesFramerCalculatorOptions::HAMMING);
|
||||
MP_ASSERT_OK(Run());
|
||||
CheckOutput();
|
||||
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest,
|
||||
IntegerSampleDurationNoOverlapHannWindow) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_window_function(TimeSeriesFramerCalculatorOptions::HANN);
|
||||
MP_ASSERT_OK(Run());
|
||||
CheckOutput();
|
||||
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest, IntegerSampleDurationAndOverlap) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(40.0 / input_sample_rate_);
|
||||
MP_ASSERT_OK(Run());
|
||||
CheckOutput();
|
||||
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest, NonintegerSampleDurationAndOverlap) {
|
||||
options_.set_frame_duration_seconds(98.5 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(38.4 / input_sample_rate_);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
CheckOutput();
|
||||
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest, NegativeOverlapExactFrames) {
|
||||
// Negative overlap means to drop samples between frames.
|
||||
// 100 samples per frame plus a skip of 10 samples will be 10 full frames in
|
||||
// the 1100 input samples.
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(-10.0 / input_sample_rate_);
|
||||
MP_ASSERT_OK(Run());
|
||||
EXPECT_EQ(output().packets.size(), 10);
|
||||
CheckOutput();
|
||||
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest, NegativeOverlapExactFramesLessSkip) {
|
||||
// 100 samples per frame plus a skip of 100 samples will be 6 full frames in
|
||||
// the 1100 input samples.
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(-100.0 / input_sample_rate_);
|
||||
MP_ASSERT_OK(Run());
|
||||
EXPECT_EQ(output().packets.size(), 6);
|
||||
CheckOutput();
|
||||
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest, NegativeOverlapWithPadding) {
|
||||
// 150 samples per frame plus a skip of 50 samples will require some padding
|
||||
// on the sixth and last frame given 1100 sample input.
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(-100.0 / input_sample_rate_);
|
||||
MP_ASSERT_OK(Run());
|
||||
EXPECT_EQ(output().packets.size(), 6);
|
||||
CheckOutput();
|
||||
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest, FixedFrameOverlap) {
|
||||
// Frame of 30 samples with step of 11.4 samples (rounded to 11 samples)
|
||||
// results in ceil((1100 - 30) / 11) + 1 = 99 packets.
|
||||
options_.set_frame_duration_seconds(30 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds((30.0 - 11.4) / input_sample_rate_);
|
||||
MP_ASSERT_OK(Run());
|
||||
EXPECT_EQ(output().packets.size(), 99);
|
||||
CheckOutput();
|
||||
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest, VariableFrameOverlap) {
|
||||
// Frame of 30 samples with step of 11.4 samples (not rounded)
|
||||
// results in ceil((1100 - 30) / 11.4) + 1 = 95 packets.
|
||||
options_.set_frame_duration_seconds(30 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds((30 - 11.4) / input_sample_rate_);
|
||||
options_.set_emulate_fractional_frame_overlap(true);
|
||||
MP_ASSERT_OK(Run());
|
||||
EXPECT_EQ(output().packets.size(), 95);
|
||||
CheckOutput();
|
||||
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest, VariableFrameSkip) {
|
||||
// Frame of 30 samples with step of 41.4 samples (not rounded)
|
||||
// results in ceil((1100 - 30) / 41.4) + 1 = 27 packets.
|
||||
options_.set_frame_duration_seconds(30 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds((30 - 41.4) / input_sample_rate_);
|
||||
options_.set_emulate_fractional_frame_overlap(true);
|
||||
MP_ASSERT_OK(Run());
|
||||
EXPECT_EQ(output().packets.size(), 27);
|
||||
CheckOutput();
|
||||
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest, NoFinalPacketPadding) {
|
||||
options_.set_frame_duration_seconds(98.5 / input_sample_rate_);
|
||||
options_.set_pad_final_packet(false);
|
||||
|
||||
MP_ASSERT_OK(Run());
|
||||
CheckOutput();
|
||||
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest,
|
||||
FrameRateHigherThanSampleRate_FrameDurationTooLow) {
|
||||
// Try to produce a frame rate 10 times the input sample rate by using a
|
||||
// a frame duration that is too small and covers only 0.1 samples.
|
||||
options_.set_frame_duration_seconds(1 / (10 * input_sample_rate_));
|
||||
options_.set_frame_overlap_seconds(0.0);
|
||||
EXPECT_FALSE(Run().ok());
|
||||
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTest,
|
||||
FrameRateHigherThanSampleRate_FrameStepTooLow) {
|
||||
// Try to produce a frame rate 10 times the input sample rate by using
|
||||
// a frame overlap that is too high and produces frame steps (difference
|
||||
// between duration and overlap) of 0.1 samples.
|
||||
options_.set_frame_duration_seconds(10.0 / input_sample_rate_);
|
||||
options_.set_frame_overlap_seconds(9.9 / input_sample_rate_);
|
||||
EXPECT_FALSE(Run().ok());
|
||||
}
|
||||
|
||||
// A simple test class to do windowing sanity checks. Tests from this
|
||||
// class input a single packet of all ones, and check the average
|
||||
// value of the single output packet. This is useful as a sanity check
|
||||
// that the correct windows are applied.
|
||||
class TimeSeriesFramerCalculatorWindowingSanityTest
|
||||
: public TimeSeriesFramerCalculatorTest {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
TimeSeriesFramerCalculatorTest::SetUp();
|
||||
num_input_channels_ = 1;
|
||||
}
|
||||
|
||||
void RunAndTestSinglePacketAverage(float expected_average) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
InitializeGraph();
|
||||
FillInputHeader();
|
||||
AppendInputPacket(new Matrix(Matrix::Ones(1, FrameDurationSamples())),
|
||||
kInitialTimestampOffsetMicroseconds);
|
||||
MP_ASSERT_OK(RunGraph());
|
||||
ASSERT_EQ(1, output().packets.size());
|
||||
ASSERT_NEAR(expected_average * FrameDurationSamples(),
|
||||
output().packets[0].Get<Matrix>().sum(), 1e-5);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorWindowingSanityTest, NoWindowSanityCheck) {
|
||||
RunAndTestSinglePacketAverage(1.0f);
|
||||
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorWindowingSanityTest,
|
||||
HammingWindowSanityCheck) {
|
||||
options_.set_window_function(TimeSeriesFramerCalculatorOptions::HAMMING);
|
||||
RunAndTestSinglePacketAverage(0.54f);
|
||||
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorWindowingSanityTest, HannWindowSanityCheck) {
|
||||
options_.set_window_function(TimeSeriesFramerCalculatorOptions::HANN);
|
||||
RunAndTestSinglePacketAverage(0.5f);
|
||||
}
|
||||
|
||||
// A simple test class that checks the local packet time stamp. This class
|
||||
// generate a series of packets with and without gaps between packets and tests
|
||||
// the behavior with cumulative timestamping and local packet timestamping.
|
||||
class TimeSeriesFramerCalculatorTimestampingTest
|
||||
: public TimeSeriesFramerCalculatorTest {
|
||||
protected:
|
||||
// Creates test input and saves a reference copy.
|
||||
void InitializeInputForTimeStampingTest() {
|
||||
concatenated_input_samples_.resize(0, num_input_channels_);
|
||||
num_input_samples_ = 0;
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
// This range of packet sizes was chosen such that some input
|
||||
// packets will be smaller than the output packet size and other
|
||||
// input packets will be larger.
|
||||
int packet_size = kUniversalInputPacketSize;
|
||||
double timestamp_seconds = kInitialTimestampOffsetMicroseconds * 1.0e-6 +
|
||||
num_input_samples_ / input_sample_rate_;
|
||||
if (options_.use_local_timestamp()) {
|
||||
timestamp_seconds += kGapBetweenPacketsInSeconds * i;
|
||||
}
|
||||
|
||||
Matrix* data_frame =
|
||||
NewTestFrame(num_input_channels_, packet_size, timestamp_seconds);
|
||||
|
||||
AppendInputPacket(data_frame, round(timestamp_seconds *
|
||||
Timestamp::kTimestampUnitsPerSecond));
|
||||
num_input_samples_ += packet_size;
|
||||
}
|
||||
}
|
||||
|
||||
void CheckOutputTimestamps() {
|
||||
int num_full_packets = output().packets.size();
|
||||
if (options_.pad_final_packet()) {
|
||||
num_full_packets -= 1;
|
||||
}
|
||||
|
||||
int64 num_samples = 0;
|
||||
for (int packet_num = 0; packet_num < num_full_packets; ++packet_num) {
|
||||
const Packet& packet = output().packets[packet_num];
|
||||
num_samples += FrameDurationSamples();
|
||||
double expected_timestamp =
|
||||
options_.use_local_timestamp()
|
||||
? GetExpectedLocalTimestampForSample(num_samples - 1)
|
||||
: GetExpectedCumulativeTimestamp(num_samples - 1);
|
||||
ASSERT_NEAR(packet.Timestamp().Seconds(), expected_timestamp, 1e-10);
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status RunTimestampTest() {
|
||||
InitializeGraph();
|
||||
InitializeInputForTimeStampingTest();
|
||||
FillInputHeader();
|
||||
return RunGraph();
|
||||
}
|
||||
|
||||
private:
|
||||
// Returns the timestamp in seconds based on local timestamping.
|
||||
double GetExpectedLocalTimestampForSample(int sample_index) {
|
||||
return kInitialTimestampOffsetMicroseconds * 1.0e-6 +
|
||||
sample_index / input_sample_rate_ +
|
||||
(sample_index / kUniversalInputPacketSize) *
|
||||
kGapBetweenPacketsInSeconds;
|
||||
}
|
||||
|
||||
// Returns the timestamp inseconds based on cumulative timestamping.
|
||||
double GetExpectedCumulativeTimestamp(int sample_index) {
|
||||
return kInitialTimestampOffsetMicroseconds * 1.0e-6 +
|
||||
sample_index / FrameDurationSamples() * FrameDurationSamples() /
|
||||
input_sample_rate_;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTimestampingTest, UseLocalTimeStamp) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_use_local_timestamp(true);
|
||||
|
||||
MP_ASSERT_OK(RunTimestampTest());
|
||||
CheckOutputTimestamps();
|
||||
}
|
||||
|
||||
TEST_F(TimeSeriesFramerCalculatorTimestampingTest, UseCumulativeTimeStamp) {
|
||||
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
|
||||
options_.set_use_local_timestamp(false);
|
||||
|
||||
MP_ASSERT_OK(RunTimestampTest());
|
||||
CheckOutputTimestamps();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,84 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// Attach the header from a stream or side input to another stream.
|
||||
//
|
||||
// The header stream (tag HEADER) must not have any packets in it.
|
||||
//
|
||||
// Before using this calculator, please think about changing your
|
||||
// calculator to not need a header or to accept a separate stream with
|
||||
// a header, that would be more future proof.
|
||||
//
|
||||
// Example usage 1:
|
||||
// node {
|
||||
// calculator: "AddHeaderCalculator"
|
||||
// input_stream: "DATA:audio"
|
||||
// input_stream: "HEADER:audio_header"
|
||||
// output_stream: "audio_with_header"
|
||||
// }
|
||||
//
|
||||
// Example usage 2:
|
||||
// node {
|
||||
// calculator: "AddHeaderCalculator"
|
||||
// input_stream: "DATA:audio"
|
||||
// input_side_packet: "HEADER:audio_header"
|
||||
// output_stream: "audio_with_header"
|
||||
// }
|
||||
//
|
||||
class AddHeaderCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<NoneType>::Optional kHeader{"HEADER"};
|
||||
static constexpr SideInput<AnyType>::Optional kHeaderSide{"HEADER"};
|
||||
static constexpr Input<AnyType> kData{"DATA"};
|
||||
static constexpr Output<SameType<kData>> kOut{""};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kHeader, kHeaderSide, kData, kOut);
|
||||
|
||||
static absl::Status UpdateContract(CalculatorContract* cc) {
|
||||
if (kHeader(cc).IsConnected() == kHeaderSide(cc).IsConnected()) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Header must be provided via exactly one of side input and input "
|
||||
"stream");
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
const PacketBase& header =
|
||||
kHeader(cc).IsConnected() ? kHeader(cc).Header() : kHeaderSide(cc);
|
||||
if (!header.IsEmpty()) {
|
||||
kOut(cc).SetHeader(header);
|
||||
}
|
||||
cc->SetOffset(0);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
kOut(cc).Send(kData(cc).packet());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
MEDIAPIPE_REGISTER_NODE(AddHeaderCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,159 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/framework/tool/validate_type.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
class AddHeaderCalculatorTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(AddHeaderCalculatorTest, HeaderStream) {
|
||||
CalculatorGraphConfig::Node node;
|
||||
node.set_calculator("AddHeaderCalculator");
|
||||
node.add_input_stream("HEADER:header_stream");
|
||||
node.add_input_stream("DATA:data_stream");
|
||||
node.add_output_stream("merged_stream");
|
||||
|
||||
CalculatorRunner runner(node);
|
||||
|
||||
// Set header and add 5 packets.
|
||||
runner.MutableInputs()->Tag("HEADER").header =
|
||||
Adopt(new std::string("my_header"));
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
|
||||
runner.MutableInputs()->Tag("DATA").packets.push_back(packet);
|
||||
}
|
||||
|
||||
// Run calculator.
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
ASSERT_EQ(1, runner.Outputs().NumEntries());
|
||||
|
||||
// Test output.
|
||||
EXPECT_EQ(std::string("my_header"),
|
||||
runner.Outputs().Index(0).header.Get<std::string>());
|
||||
const std::vector<Packet>& output_packets = runner.Outputs().Index(0).packets;
|
||||
ASSERT_EQ(5, output_packets.size());
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
const int val = output_packets[i].Get<int>();
|
||||
EXPECT_EQ(i, val);
|
||||
EXPECT_EQ(Timestamp(i * 1000), output_packets[i].Timestamp());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(AddHeaderCalculatorTest, HandlesEmptyHeaderStream) {
|
||||
CalculatorGraphConfig::Node node;
|
||||
node.set_calculator("AddHeaderCalculator");
|
||||
node.add_input_stream("HEADER:header_stream");
|
||||
node.add_input_stream("DATA:data_stream");
|
||||
node.add_output_stream("merged_stream");
|
||||
|
||||
CalculatorRunner runner(node);
|
||||
|
||||
// No header and no packets.
|
||||
// Run calculator.
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
EXPECT_TRUE(runner.Outputs().Index(0).header.IsEmpty());
|
||||
}
|
||||
|
||||
TEST_F(AddHeaderCalculatorTest, NoPacketsOnHeaderStream) {
|
||||
CalculatorGraphConfig::Node node;
|
||||
node.set_calculator("AddHeaderCalculator");
|
||||
node.add_input_stream("HEADER:header_stream");
|
||||
node.add_input_stream("DATA:data_stream");
|
||||
node.add_output_stream("merged_stream");
|
||||
|
||||
CalculatorRunner runner(node);
|
||||
|
||||
// Set header and add 5 packets.
|
||||
runner.MutableInputs()->Tag("HEADER").header =
|
||||
Adopt(new std::string("my_header"));
|
||||
runner.MutableInputs()->Tag("HEADER").packets.push_back(
|
||||
Adopt(new std::string("not allowed")));
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
|
||||
runner.MutableInputs()->Tag("DATA").packets.push_back(packet);
|
||||
}
|
||||
|
||||
// Run calculator.
|
||||
ASSERT_FALSE(runner.Run().ok());
|
||||
}
|
||||
|
||||
TEST_F(AddHeaderCalculatorTest, InputSidePacket) {
|
||||
CalculatorGraphConfig::Node node;
|
||||
node.set_calculator("AddHeaderCalculator");
|
||||
node.add_input_stream("DATA:data_stream");
|
||||
node.add_output_stream("merged_stream");
|
||||
node.add_input_side_packet("HEADER:header");
|
||||
|
||||
CalculatorRunner runner(node);
|
||||
|
||||
// Set header and add 5 packets.
|
||||
runner.MutableSidePackets()->Tag("HEADER") =
|
||||
Adopt(new std::string("my_header"));
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
|
||||
runner.MutableInputs()->Tag("DATA").packets.push_back(packet);
|
||||
}
|
||||
|
||||
// Run calculator.
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
ASSERT_EQ(1, runner.Outputs().NumEntries());
|
||||
|
||||
// Test output.
|
||||
EXPECT_EQ(std::string("my_header"),
|
||||
runner.Outputs().Index(0).header.Get<std::string>());
|
||||
const std::vector<Packet>& output_packets = runner.Outputs().Index(0).packets;
|
||||
ASSERT_EQ(5, output_packets.size());
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
const int val = output_packets[i].Get<int>();
|
||||
EXPECT_EQ(i, val);
|
||||
EXPECT_EQ(Timestamp(i * 1000), output_packets[i].Timestamp());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(AddHeaderCalculatorTest, UsingBothSideInputAndStream) {
|
||||
CalculatorGraphConfig::Node node;
|
||||
node.set_calculator("AddHeaderCalculator");
|
||||
node.add_input_stream("HEADER:header_stream");
|
||||
node.add_input_stream("DATA:data_stream");
|
||||
node.add_output_stream("merged_stream");
|
||||
node.add_input_side_packet("HEADER:header");
|
||||
|
||||
CalculatorRunner runner(node);
|
||||
|
||||
// Set both headers and add 5 packets.
|
||||
runner.MutableSidePackets()->Tag("HEADER") =
|
||||
Adopt(new std::string("my_header"));
|
||||
runner.MutableSidePackets()->Tag("HEADER") =
|
||||
Adopt(new std::string("my_header"));
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
|
||||
runner.MutableInputs()->Tag("DATA").packets.push_back(packet);
|
||||
}
|
||||
|
||||
// Run should fail because header can only be provided one way.
|
||||
EXPECT_EQ(runner.Run().code(), absl::InvalidArgumentError("").code());
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,448 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "mediapipe/calculators/core/begin_loop_calculator.h"
|
||||
#include "mediapipe/calculators/core/end_loop_calculator.h"
|
||||
#include "mediapipe/framework/calculator_contract.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
MATCHER_P2(PacketOfIntsEq, timestamp, value, "") {
|
||||
Timestamp actual_timestamp = arg.Timestamp();
|
||||
const auto& actual_value = arg.template Get<std::vector<int>>();
|
||||
return testing::Value(actual_timestamp, testing::Eq(timestamp)) &&
|
||||
testing::Value(actual_value, testing::ElementsAreArray(value));
|
||||
}
|
||||
|
||||
typedef BeginLoopCalculator<std::vector<int>> BeginLoopIntegerCalculator;
|
||||
REGISTER_CALCULATOR(BeginLoopIntegerCalculator);
|
||||
|
||||
class IncrementCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).Set<int>();
|
||||
cc->Outputs().Index(0).Set<int>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
const int& input_int = cc->Inputs().Index(0).Get<int>();
|
||||
auto output_int = absl::make_unique<int>(input_int + 1);
|
||||
cc->Outputs().Index(0).Add(output_int.release(), cc->InputTimestamp());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(IncrementCalculator);
|
||||
|
||||
typedef EndLoopCalculator<std::vector<int>> EndLoopIntegersCalculator;
|
||||
REGISTER_CALCULATOR(EndLoopIntegersCalculator);
|
||||
|
||||
class BeginEndLoopCalculatorGraphTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
R"pb(
|
||||
num_threads: 4
|
||||
input_stream: "ints"
|
||||
node {
|
||||
calculator: "BeginLoopIntegerCalculator"
|
||||
input_stream: "ITERABLE:ints"
|
||||
output_stream: "ITEM:int"
|
||||
output_stream: "BATCH_END:timestamp"
|
||||
}
|
||||
node {
|
||||
calculator: "IncrementCalculator"
|
||||
input_stream: "int"
|
||||
output_stream: "int_plus_one"
|
||||
}
|
||||
node {
|
||||
calculator: "EndLoopIntegersCalculator"
|
||||
input_stream: "ITEM:int_plus_one"
|
||||
input_stream: "BATCH_END:timestamp"
|
||||
output_stream: "ITERABLE:ints_plus_one"
|
||||
}
|
||||
)pb");
|
||||
tool::AddVectorSink("ints_plus_one", &graph_config, &output_packets_);
|
||||
MP_ASSERT_OK(graph_.Initialize(graph_config));
|
||||
MP_ASSERT_OK(graph_.StartRun({}));
|
||||
}
|
||||
|
||||
void SendPacketOfInts(Timestamp timestamp, std::vector<int> ints) {
|
||||
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||
"ints", MakePacket<std::vector<int>>(std::move(ints)).At(timestamp)));
|
||||
}
|
||||
|
||||
CalculatorGraph graph_;
|
||||
std::vector<Packet> output_packets_;
|
||||
};
|
||||
|
||||
TEST_F(BeginEndLoopCalculatorGraphTest, InputStreamForIterableIsEmpty) {
|
||||
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||
|
||||
// EndLoopCalc will forward the timestamp bound because there are no packets
|
||||
// to process.
|
||||
ASSERT_EQ(0, output_packets_.size());
|
||||
|
||||
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST_F(BeginEndLoopCalculatorGraphTest, SingleEmptyVector) {
|
||||
SendPacketOfInts(Timestamp(0), {});
|
||||
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||
|
||||
// EndLoopCalc will forward the timestamp bound because there are no elements
|
||||
// in collection to output.
|
||||
EXPECT_TRUE(output_packets_.empty());
|
||||
|
||||
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST_F(BeginEndLoopCalculatorGraphTest, SingleNonEmptyVector) {
|
||||
Timestamp input_timestamp = Timestamp(0);
|
||||
SendPacketOfInts(input_timestamp, {0, 1, 2});
|
||||
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||
|
||||
EXPECT_THAT(output_packets_,
|
||||
testing::ElementsAre(
|
||||
PacketOfIntsEq(input_timestamp, std::vector<int>{1, 2, 3})));
|
||||
|
||||
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST_F(BeginEndLoopCalculatorGraphTest, MultipleVectors) {
|
||||
Timestamp input_timestamp0 = Timestamp(0);
|
||||
SendPacketOfInts(input_timestamp0, {0, 1});
|
||||
|
||||
Timestamp input_timestamp1 = Timestamp(1);
|
||||
SendPacketOfInts(input_timestamp1, {});
|
||||
|
||||
Timestamp input_timestamp2 = Timestamp(2);
|
||||
SendPacketOfInts(input_timestamp2, {2, 3});
|
||||
|
||||
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||
|
||||
// At input_timestamp1, EndLoopCalc will forward timestamp bound as there are
|
||||
// no elements in vector to process.
|
||||
EXPECT_THAT(output_packets_,
|
||||
testing::ElementsAre(
|
||||
PacketOfIntsEq(input_timestamp0, std::vector<int>{1, 2}),
|
||||
PacketOfIntsEq(input_timestamp2, std::vector<int>{3, 4})));
|
||||
}
|
||||
|
||||
// Passes non empty vector through or outputs empty vector in case of timestamp
|
||||
// bound update.
|
||||
class PassThroughOrEmptyVectorCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->SetProcessTimestampBounds(true);
|
||||
cc->Inputs().Index(0).Set<std::vector<int>>();
|
||||
cc->Outputs().Index(0).Set<std::vector<int>>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
if (!cc->Inputs().Index(0).IsEmpty()) {
|
||||
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
|
||||
} else {
|
||||
cc->Outputs().Index(0).AddPacket(
|
||||
MakePacket<std::vector<int>>(std::vector<int>())
|
||||
.At(cc->InputTimestamp()));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(PassThroughOrEmptyVectorCalculator);
|
||||
|
||||
class BeginEndLoopCalculatorGraphProcessingEmptyPacketsTest
|
||||
: public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
R"pb(
|
||||
num_threads: 4
|
||||
input_stream: "ints"
|
||||
input_stream: "force_ints_to_be_timestamp_bound_update"
|
||||
node {
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "ints"
|
||||
input_stream: "DISALLOW:force_ints_to_be_timestamp_bound_update"
|
||||
output_stream: "ints_passed_through"
|
||||
}
|
||||
node {
|
||||
calculator: "BeginLoopIntegerCalculator"
|
||||
input_stream: "ITERABLE:ints_passed_through"
|
||||
output_stream: "ITEM:int"
|
||||
output_stream: "BATCH_END:timestamp"
|
||||
}
|
||||
node {
|
||||
calculator: "IncrementCalculator"
|
||||
input_stream: "int"
|
||||
output_stream: "int_plus_one"
|
||||
}
|
||||
node {
|
||||
calculator: "EndLoopIntegersCalculator"
|
||||
input_stream: "ITEM:int_plus_one"
|
||||
input_stream: "BATCH_END:timestamp"
|
||||
output_stream: "ITERABLE:ints_plus_one"
|
||||
}
|
||||
node {
|
||||
calculator: "PassThroughOrEmptyVectorCalculator"
|
||||
input_stream: "ints_plus_one"
|
||||
output_stream: "ints_plus_one_passed_through"
|
||||
}
|
||||
)pb");
|
||||
tool::AddVectorSink("ints_plus_one_passed_through", &graph_config,
|
||||
&output_packets_);
|
||||
MP_ASSERT_OK(graph_.Initialize(graph_config));
|
||||
MP_ASSERT_OK(graph_.StartRun({}));
|
||||
}
|
||||
|
||||
void SendPacketOfIntsOrBound(Timestamp timestamp, std::vector<int> ints) {
|
||||
// All "ints" packets which are empty are forced to be just timestamp
|
||||
// bound updates for begin loop calculator.
|
||||
bool force_ints_to_be_timestamp_bound_update = ints.empty();
|
||||
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||
"force_ints_to_be_timestamp_bound_update",
|
||||
MakePacket<bool>(force_ints_to_be_timestamp_bound_update)
|
||||
.At(timestamp)));
|
||||
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||
"ints", MakePacket<std::vector<int>>(std::move(ints)).At(timestamp)));
|
||||
}
|
||||
|
||||
CalculatorGraph graph_;
|
||||
std::vector<Packet> output_packets_;
|
||||
};
|
||||
|
||||
TEST_F(BeginEndLoopCalculatorGraphProcessingEmptyPacketsTest,
|
||||
SingleEmptyVector) {
|
||||
SendPacketOfIntsOrBound(Timestamp(0), {});
|
||||
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||
|
||||
EXPECT_THAT(output_packets_, testing::ElementsAre(PacketOfIntsEq(
|
||||
Timestamp(0), std::vector<int>{})));
|
||||
|
||||
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST_F(BeginEndLoopCalculatorGraphProcessingEmptyPacketsTest,
|
||||
SingleNonEmptyVector) {
|
||||
SendPacketOfIntsOrBound(Timestamp(0), {0, 1, 2});
|
||||
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||
|
||||
EXPECT_THAT(output_packets_, testing::ElementsAre(PacketOfIntsEq(
|
||||
Timestamp(0), std::vector<int>{1, 2, 3})));
|
||||
|
||||
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST_F(BeginEndLoopCalculatorGraphProcessingEmptyPacketsTest, MultipleVectors) {
|
||||
SendPacketOfIntsOrBound(Timestamp(0), {});
|
||||
// Waiting until idle to guarantee all timestamp bound updates are processed
|
||||
// individually. (Timestamp bounds updates occur in the provide config only
|
||||
// if input is an empty vector.)
|
||||
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||
|
||||
SendPacketOfIntsOrBound(Timestamp(1), {0, 1});
|
||||
SendPacketOfIntsOrBound(Timestamp(2), {});
|
||||
// Waiting until idle to guarantee all timestamp bound updates are processed
|
||||
// individually. (Timestamp bounds updates occur in the provide config only
|
||||
// if input is an empty vector.)
|
||||
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||
|
||||
SendPacketOfIntsOrBound(Timestamp(3), {2, 3});
|
||||
SendPacketOfIntsOrBound(Timestamp(4), {});
|
||||
// Waiting until idle to guarantee all timestamp bound updates are processed
|
||||
// individually. (Timestamp bounds updates occur in the provide config only
|
||||
// if input is an empty vector.)
|
||||
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||
|
||||
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
testing::ElementsAre(PacketOfIntsEq(Timestamp(0), std::vector<int>{}),
|
||||
PacketOfIntsEq(Timestamp(1), std::vector<int>{1, 2}),
|
||||
PacketOfIntsEq(Timestamp(2), std::vector<int>{}),
|
||||
PacketOfIntsEq(Timestamp(3), std::vector<int>{3, 4}),
|
||||
PacketOfIntsEq(Timestamp(4), std::vector<int>{})));
|
||||
}
|
||||
|
||||
class MultiplierCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).Set<int>();
|
||||
cc->Inputs().Index(1).Set<int>();
|
||||
cc->Outputs().Index(0).Set<int>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
const int& input_int = cc->Inputs().Index(0).Get<int>();
|
||||
const int& multiplier_int = cc->Inputs().Index(1).Get<int>();
|
||||
auto output_int = absl::make_unique<int>(input_int * multiplier_int);
|
||||
cc->Outputs().Index(0).Add(output_int.release(), cc->InputTimestamp());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(MultiplierCalculator);
|
||||
|
||||
class BeginEndLoopCalculatorGraphWithClonedInputsTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
R"pb(
|
||||
num_threads: 4
|
||||
input_stream: "ints"
|
||||
input_stream: "multiplier"
|
||||
node {
|
||||
calculator: "BeginLoopIntegerCalculator"
|
||||
input_stream: "ITERABLE:ints"
|
||||
input_stream: "CLONE:multiplier"
|
||||
output_stream: "ITEM:int_at_loop"
|
||||
output_stream: "CLONE:multiplier_cloned_at_loop"
|
||||
output_stream: "BATCH_END:timestamp"
|
||||
}
|
||||
node {
|
||||
calculator: "MultiplierCalculator"
|
||||
input_stream: "int_at_loop"
|
||||
input_stream: "multiplier_cloned_at_loop"
|
||||
output_stream: "multiplied_int_at_loop"
|
||||
}
|
||||
node {
|
||||
calculator: "EndLoopIntegersCalculator"
|
||||
input_stream: "ITEM:multiplied_int_at_loop"
|
||||
input_stream: "BATCH_END:timestamp"
|
||||
output_stream: "ITERABLE:multiplied_ints"
|
||||
}
|
||||
)pb");
|
||||
tool::AddVectorSink("multiplied_ints", &graph_config, &output_packets_);
|
||||
MP_ASSERT_OK(graph_.Initialize(graph_config));
|
||||
MP_ASSERT_OK(graph_.StartRun({}));
|
||||
}
|
||||
|
||||
void SendPackets(Timestamp timestamp, int multiplier, std::vector<int> ints) {
|
||||
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||
"ints", MakePacket<std::vector<int>>(std::move(ints)).At(timestamp)));
|
||||
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||
"multiplier", MakePacket<int>(multiplier).At(timestamp)));
|
||||
}
|
||||
|
||||
void SendMultiplier(Timestamp timestamp, int multiplier) {
|
||||
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||
"multiplier", MakePacket<int>(multiplier).At(timestamp)));
|
||||
}
|
||||
|
||||
CalculatorGraph graph_;
|
||||
std::vector<Packet> output_packets_;
|
||||
};
|
||||
|
||||
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest,
|
||||
InputStreamForIterableIsEmpty) {
|
||||
Timestamp input_timestamp = Timestamp(42);
|
||||
SendMultiplier(input_timestamp, /*multiplier=*/2);
|
||||
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||
|
||||
// EndLoopCalc will forward the timestamp bound because there are no packets
|
||||
// to process.
|
||||
ASSERT_EQ(0, output_packets_.size());
|
||||
|
||||
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, SingleEmptyVector) {
|
||||
SendPackets(Timestamp(0), /*multiplier=*/2, /*ints=*/{});
|
||||
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||
|
||||
// EndLoopCalc will forward the timestamp bound because there are no elements
|
||||
// in collection to output.
|
||||
EXPECT_TRUE(output_packets_.empty());
|
||||
|
||||
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, SingleNonEmptyVector) {
|
||||
Timestamp input_timestamp = Timestamp(42);
|
||||
SendPackets(input_timestamp, /*multiplier=*/2, /*ints=*/{0, 1, 2});
|
||||
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||
|
||||
EXPECT_THAT(output_packets_,
|
||||
testing::ElementsAre(
|
||||
PacketOfIntsEq(input_timestamp, std::vector<int>{0, 2, 4})));
|
||||
|
||||
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, MultipleVectors) {
|
||||
Timestamp input_timestamp0 = Timestamp(42);
|
||||
SendPackets(input_timestamp0, /*multiplier=*/2, /*ints=*/{0, 1});
|
||||
|
||||
Timestamp input_timestamp1 = Timestamp(43);
|
||||
SendPackets(input_timestamp1, /*multiplier=*/2, /*ints=*/{});
|
||||
|
||||
Timestamp input_timestamp2 = Timestamp(44);
|
||||
SendPackets(input_timestamp2, /*multiplier=*/3, /*ints=*/{2, 3});
|
||||
|
||||
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||
|
||||
// At input_timestamp1, EndLoopCalc will forward timestamp bound as there are
|
||||
// no elements in vector to process.
|
||||
EXPECT_THAT(output_packets_,
|
||||
testing::ElementsAre(
|
||||
PacketOfIntsEq(input_timestamp0, std::vector<int>{0, 2}),
|
||||
PacketOfIntsEq(input_timestamp2, std::vector<int>{6, 9})));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,45 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/core/begin_loop_calculator.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// A calculator to process std::vector<NormalizedLandmarkList>.
|
||||
typedef BeginLoopCalculator<std::vector<::mediapipe::NormalizedLandmarkList>>
|
||||
BeginLoopNormalizedLandmarkListVectorCalculator;
|
||||
REGISTER_CALCULATOR(BeginLoopNormalizedLandmarkListVectorCalculator);
|
||||
|
||||
// A calculator to process std::vector<NormalizedRect>.
|
||||
typedef BeginLoopCalculator<std::vector<::mediapipe::NormalizedRect>>
|
||||
BeginLoopNormalizedRectCalculator;
|
||||
REGISTER_CALCULATOR(BeginLoopNormalizedRectCalculator);
|
||||
|
||||
// A calculator to process std::vector<Detection>.
|
||||
typedef BeginLoopCalculator<std::vector<::mediapipe::Detection>>
|
||||
BeginLoopDetectionCalculator;
|
||||
REGISTER_CALCULATOR(BeginLoopDetectionCalculator);
|
||||
|
||||
// A calculator to process std::vector<Matrix>.
|
||||
typedef BeginLoopCalculator<std::vector<Matrix>> BeginLoopMatrixCalculator;
|
||||
REGISTER_CALCULATOR(BeginLoopMatrixCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,165 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_
|
||||
#define MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "mediapipe/framework/calculator_context.h"
|
||||
#include "mediapipe/framework/calculator_contract.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/collection_item_id.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Calculator for implementing loops on iterable collections inside a MediaPipe
|
||||
// graph.
|
||||
//
|
||||
// It is designed to be used like:
|
||||
//
|
||||
// node {
|
||||
// calculator: "BeginLoopWithIterableCalculator"
|
||||
// input_stream: "ITERABLE:input_iterable" # IterableT @ext_ts
|
||||
// output_stream: "ITEM:input_element" # ItemT @loop_internal_ts
|
||||
// output_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts
|
||||
// }
|
||||
//
|
||||
// node {
|
||||
// calculator: "ElementToBlaConverterSubgraph"
|
||||
// input_stream: "ITEM:input_to_loop_body" # ItemT @loop_internal_ts
|
||||
// output_stream: "BLA:output_of_loop_body" # ItemU @loop_internal_ts
|
||||
// }
|
||||
//
|
||||
// node {
|
||||
// calculator: "EndLoopWithOutputCalculator"
|
||||
// input_stream: "ITEM:output_of_loop_body" # ItemU @loop_internal_ts
|
||||
// input_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts
|
||||
// output_stream: "OUTPUT:aggregated_result" # IterableU @ext_ts
|
||||
// }
|
||||
//
|
||||
// Input streams tagged with "CLONE" are cloned to the corresponding output
|
||||
// streams at loop timestamps. This ensures that a MediaPipe graph or sub-graph
|
||||
// can run multiple times, once per element in the "ITERABLE" for each pakcet
|
||||
// clone of the packets in the "CLONE" input streams.
|
||||
template <typename IterableT>
|
||||
class BeginLoopCalculator : public CalculatorBase {
|
||||
using ItemT = typename IterableT::value_type;
|
||||
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
// The below enables processing of timestamp bound updates, and that enables
|
||||
// correct timestamp propagation by the companion EndLoopCalculator.
|
||||
//
|
||||
// For instance, Process() function will be still invoked even if upstream
|
||||
// calculator has updated timestamp bound for ITERABLE input instead of
|
||||
// providing actual value.
|
||||
cc->SetProcessTimestampBounds(true);
|
||||
|
||||
// A non-empty packet in the optional "TICK" input stream wakes up the
|
||||
// calculator.
|
||||
// DEPRECATED as timestamp bound updates are processed by default in this
|
||||
// calculator.
|
||||
if (cc->Inputs().HasTag("TICK")) {
|
||||
cc->Inputs().Tag("TICK").SetAny();
|
||||
}
|
||||
|
||||
// An iterable collection in the input stream.
|
||||
RET_CHECK(cc->Inputs().HasTag("ITERABLE"));
|
||||
cc->Inputs().Tag("ITERABLE").Set<IterableT>();
|
||||
|
||||
// An element from the collection.
|
||||
RET_CHECK(cc->Outputs().HasTag("ITEM"));
|
||||
cc->Outputs().Tag("ITEM").Set<ItemT>();
|
||||
|
||||
RET_CHECK(cc->Outputs().HasTag("BATCH_END"));
|
||||
cc->Outputs()
|
||||
.Tag("BATCH_END")
|
||||
.Set<Timestamp>(
|
||||
// A flush signal to the corresponding EndLoopCalculator for it to
|
||||
// emit the aggregated result with the timestamp contained in this
|
||||
// flush signal packet.
|
||||
);
|
||||
|
||||
// Input streams tagged with "CLONE" are cloned to the corresponding
|
||||
// "CLONE" output streams at loop timestamps.
|
||||
RET_CHECK(cc->Inputs().NumEntries("CLONE") ==
|
||||
cc->Outputs().NumEntries("CLONE"));
|
||||
if (cc->Inputs().NumEntries("CLONE") > 0) {
|
||||
for (int i = 0; i < cc->Inputs().NumEntries("CLONE"); ++i) {
|
||||
cc->Inputs().Get("CLONE", i).SetAny();
|
||||
cc->Outputs().Get("CLONE", i).SetSameAs(&cc->Inputs().Get("CLONE", i));
|
||||
}
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
Timestamp last_timestamp = loop_internal_timestamp_;
|
||||
if (!cc->Inputs().Tag("ITERABLE").IsEmpty()) {
|
||||
const IterableT& collection =
|
||||
cc->Inputs().Tag("ITERABLE").template Get<IterableT>();
|
||||
for (const auto& item : collection) {
|
||||
cc->Outputs().Tag("ITEM").AddPacket(
|
||||
MakePacket<ItemT>(item).At(loop_internal_timestamp_));
|
||||
ForwardClonePackets(cc, loop_internal_timestamp_);
|
||||
++loop_internal_timestamp_;
|
||||
}
|
||||
}
|
||||
|
||||
// The collection was empty and nothing was processed.
|
||||
if (last_timestamp == loop_internal_timestamp_) {
|
||||
// Increment loop_internal_timestamp_ because it is used up now.
|
||||
++loop_internal_timestamp_;
|
||||
for (auto it = cc->Outputs().begin(); it < cc->Outputs().end(); ++it) {
|
||||
it->SetNextTimestampBound(loop_internal_timestamp_);
|
||||
}
|
||||
}
|
||||
|
||||
// The for loop processing the input collection already incremented
|
||||
// loop_internal_timestamp_. To emit BATCH_END packet along the last
|
||||
// non-BATCH_END packet, decrement by one.
|
||||
cc->Outputs()
|
||||
.Tag("BATCH_END")
|
||||
.AddPacket(MakePacket<Timestamp>(cc->InputTimestamp())
|
||||
.At(Timestamp(loop_internal_timestamp_ - 1)));
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
void ForwardClonePackets(CalculatorContext* cc, Timestamp output_timestamp) {
|
||||
if (cc->Inputs().NumEntries("CLONE") > 0) {
|
||||
for (int i = 0; i < cc->Inputs().NumEntries("CLONE"); ++i) {
|
||||
if (!cc->Inputs().Get("CLONE", i).IsEmpty()) {
|
||||
auto input_packet = cc->Inputs().Get("CLONE", i).Value();
|
||||
cc->Outputs()
|
||||
.Get("CLONE", i)
|
||||
.AddPacket(std::move(input_packet).At(output_timestamp));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fake timestamps generated per element in collection.
|
||||
Timestamp loop_internal_timestamp_ = Timestamp(0);
|
||||
};
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/calculators/core/clip_vector_size_calculator.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
typedef ClipVectorSizeCalculator<::mediapipe::Detection>
|
||||
ClipDetectionVectorSizeCalculator;
|
||||
REGISTER_CALCULATOR(ClipDetectionVectorSizeCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,33 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/core/clip_vector_size_calculator.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
typedef ClipVectorSizeCalculator<::mediapipe::NormalizedRect>
|
||||
ClipNormalizedRectVectorSizeCalculator;
|
||||
REGISTER_CALCULATOR(ClipNormalizedRectVectorSizeCalculator);
|
||||
|
||||
typedef ClipVectorSizeCalculator<::mediapipe::Detection>
|
||||
ClipDetectionVectorSizeCalculator;
|
||||
REGISTER_CALCULATOR(ClipDetectionVectorSizeCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,147 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef MEDIAPIPE_CALCULATORS_CORE_CLIP_VECTOR_SIZE_CALCULATOR_H_
|
||||
#define MEDIAPIPE_CALCULATORS_CORE_CLIP_VECTOR_SIZE_CALCULATOR_H_
|
||||
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Clips the size of the input vector of type T to a specified max_vec_size.
|
||||
// In a graph it will be used as:
|
||||
// node {
|
||||
// calculator: "ClipIntVectorSizeCalculator"
|
||||
// input_stream: "input_vector"
|
||||
// output_stream: "output_vector"
|
||||
// options {
|
||||
// [mediapipe.ClipVectorSizeCalculatorOptions.ext] {
|
||||
// max_vec_size: 5
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// Optionally, you can pass in a side packet that will override `max_vec_size`
|
||||
// that is specified in the options.
|
||||
template <typename T>
|
||||
class ClipVectorSizeCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().NumEntries() == 1);
|
||||
RET_CHECK(cc->Outputs().NumEntries() == 1);
|
||||
|
||||
if (cc->Options<::mediapipe::ClipVectorSizeCalculatorOptions>()
|
||||
.max_vec_size() < 1) {
|
||||
return absl::InternalError(
|
||||
"max_vec_size should be greater than or equal to 1.");
|
||||
}
|
||||
|
||||
cc->Inputs().Index(0).Set<std::vector<T>>();
|
||||
cc->Outputs().Index(0).Set<std::vector<T>>();
|
||||
// Optional input side packet that determines `max_vec_size`.
|
||||
if (cc->InputSidePackets().NumEntries() > 0) {
|
||||
cc->InputSidePackets().Index(0).Set<int>();
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
max_vec_size_ = cc->Options<::mediapipe::ClipVectorSizeCalculatorOptions>()
|
||||
.max_vec_size();
|
||||
// Override `max_vec_size` if passed as side packet.
|
||||
if (cc->InputSidePackets().NumEntries() > 0 &&
|
||||
!cc->InputSidePackets().Index(0).IsEmpty()) {
|
||||
max_vec_size_ = cc->InputSidePackets().Index(0).Get<int>();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
if (max_vec_size_ < 1) {
|
||||
return absl::InternalError(
|
||||
"max_vec_size should be greater than or equal to 1.");
|
||||
}
|
||||
if (cc->Inputs().Index(0).IsEmpty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
return ClipVectorSize<T>(std::is_copy_constructible<T>(), cc);
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
absl::Status ClipVectorSize(std::true_type, CalculatorContext* cc) {
|
||||
auto output = absl::make_unique<std::vector<U>>();
|
||||
const std::vector<U>& input_vector =
|
||||
cc->Inputs().Index(0).Get<std::vector<U>>();
|
||||
if (max_vec_size_ >= input_vector.size()) {
|
||||
output->insert(output->end(), input_vector.begin(), input_vector.end());
|
||||
} else {
|
||||
for (int i = 0; i < max_vec_size_; ++i) {
|
||||
output->push_back(input_vector[i]);
|
||||
}
|
||||
}
|
||||
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
absl::Status ClipVectorSize(std::false_type, CalculatorContext* cc) {
|
||||
return ConsumeAndClipVectorSize<T>(std::is_move_constructible<U>(), cc);
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
absl::Status ConsumeAndClipVectorSize(std::true_type, CalculatorContext* cc) {
|
||||
auto output = absl::make_unique<std::vector<U>>();
|
||||
absl::StatusOr<std::unique_ptr<std::vector<U>>> input_status =
|
||||
cc->Inputs().Index(0).Value().Consume<std::vector<U>>();
|
||||
|
||||
if (input_status.ok()) {
|
||||
std::unique_ptr<std::vector<U>> input_vector =
|
||||
std::move(input_status).value();
|
||||
auto begin_it = input_vector->begin();
|
||||
auto end_it = input_vector->end();
|
||||
if (max_vec_size_ < input_vector->size()) {
|
||||
end_it = input_vector->begin() + max_vec_size_;
|
||||
}
|
||||
output->insert(output->end(), std::make_move_iterator(begin_it),
|
||||
std::make_move_iterator(end_it));
|
||||
} else {
|
||||
return input_status.status();
|
||||
}
|
||||
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
absl::Status ConsumeAndClipVectorSize(std::false_type,
|
||||
CalculatorContext* cc) {
|
||||
return absl::InternalError(
|
||||
"Cannot copy or move input vectors and clip their size.");
|
||||
}
|
||||
|
||||
private:
|
||||
int max_vec_size_ = 0;
|
||||
};
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_CALCULATORS_CORE_CLIP_VECTOR_SIZE_CALCULATOR_H_
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
option objc_class_prefix = "MediaPipe";
|
||||
|
||||
message ClipVectorSizeCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional ClipVectorSizeCalculatorOptions ext = 274674998;
|
||||
}
|
||||
|
||||
// Maximum size of output vector.
|
||||
optional int32 max_vec_size = 1 [default = 1];
|
||||
}
|
||||
|
|
@ -1,206 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/core/clip_vector_size_calculator.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
typedef ClipVectorSizeCalculator<int> TestClipIntVectorSizeCalculator;
|
||||
REGISTER_CALCULATOR(TestClipIntVectorSizeCalculator);
|
||||
|
||||
void AddInputVector(const std::vector<int>& input, int64 timestamp,
|
||||
CalculatorRunner* runner) {
|
||||
runner->MutableInputs()->Index(0).packets.push_back(
|
||||
MakePacket<std::vector<int>>(input).At(Timestamp(timestamp)));
|
||||
}
|
||||
|
||||
TEST(TestClipIntVectorSizeCalculatorTest, EmptyVectorInput) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "TestClipIntVectorSizeCalculator"
|
||||
input_stream: "input_vector"
|
||||
output_stream: "output_vector"
|
||||
options {
|
||||
[mediapipe.ClipVectorSizeCalculatorOptions.ext] { max_vec_size: 1 }
|
||||
}
|
||||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
|
||||
std::vector<int> input = {};
|
||||
AddInputVector(input, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
EXPECT_TRUE(outputs[0].Get<std::vector<int>>().empty());
|
||||
}
|
||||
|
||||
TEST(TestClipIntVectorSizeCalculatorTest, OneTimestamp) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "TestClipIntVectorSizeCalculator"
|
||||
input_stream: "input_vector"
|
||||
output_stream: "output_vector"
|
||||
options {
|
||||
[mediapipe.ClipVectorSizeCalculatorOptions.ext] { max_vec_size: 2 }
|
||||
}
|
||||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
|
||||
std::vector<int> input = {0, 1, 2, 3};
|
||||
AddInputVector(input, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
const std::vector<int>& output = outputs[0].Get<std::vector<int>>();
|
||||
EXPECT_EQ(2, output.size());
|
||||
std::vector<int> expected_vector = {0, 1};
|
||||
EXPECT_EQ(expected_vector, output);
|
||||
}
|
||||
|
||||
TEST(TestClipIntVectorSizeCalculatorTest, TwoInputsAtTwoTimestamps) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "TestClipIntVectorSizeCalculator"
|
||||
input_stream: "input_vector"
|
||||
output_stream: "output_vector"
|
||||
options {
|
||||
[mediapipe.ClipVectorSizeCalculatorOptions.ext] { max_vec_size: 3 }
|
||||
}
|
||||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
|
||||
{
|
||||
std::vector<int> input = {0, 1, 2, 3};
|
||||
AddInputVector(input, /*timestamp=*/1, &runner);
|
||||
}
|
||||
{
|
||||
std::vector<int> input = {2, 3, 4, 5};
|
||||
AddInputVector(input, /*timestamp=*/2, &runner);
|
||||
}
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(2, outputs.size());
|
||||
{
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
const std::vector<int>& output = outputs[0].Get<std::vector<int>>();
|
||||
EXPECT_EQ(3, output.size());
|
||||
std::vector<int> expected_vector = {0, 1, 2};
|
||||
EXPECT_EQ(expected_vector, output);
|
||||
}
|
||||
{
|
||||
EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
|
||||
const std::vector<int>& output = outputs[1].Get<std::vector<int>>();
|
||||
EXPECT_EQ(3, output.size());
|
||||
std::vector<int> expected_vector = {2, 3, 4};
|
||||
EXPECT_EQ(expected_vector, output);
|
||||
}
|
||||
}
|
||||
|
||||
typedef ClipVectorSizeCalculator<std::unique_ptr<int>>
|
||||
TestClipUniqueIntPtrVectorSizeCalculator;
|
||||
REGISTER_CALCULATOR(TestClipUniqueIntPtrVectorSizeCalculator);
|
||||
|
||||
TEST(TestClipUniqueIntPtrVectorSizeCalculatorTest, ConsumeOneTimestamp) {
|
||||
/* Note: We don't use CalculatorRunner for this test because it keeps copies
|
||||
* of input packets, so packets sent to the graph don't have sole ownership.
|
||||
* The test needs to send packets that own the data.
|
||||
*/
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "input_vector"
|
||||
node {
|
||||
calculator: "TestClipUniqueIntPtrVectorSizeCalculator"
|
||||
input_stream: "input_vector"
|
||||
output_stream: "output_vector"
|
||||
options {
|
||||
[mediapipe.ClipVectorSizeCalculatorOptions.ext] { max_vec_size: 3 }
|
||||
}
|
||||
}
|
||||
)pb");
|
||||
|
||||
std::vector<Packet> outputs;
|
||||
tool::AddVectorSink("output_vector", &graph_config, &outputs);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_EXPECT_OK(graph.Initialize(graph_config));
|
||||
MP_EXPECT_OK(graph.StartRun({}));
|
||||
|
||||
// input1 : {0, 1, 2, 3, 4, 5}
|
||||
auto input_vector = absl::make_unique<std::vector<std::unique_ptr<int>>>(6);
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
input_vector->at(i) = absl::make_unique<int>(i);
|
||||
}
|
||||
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"input_vector", Adopt(input_vector.release()).At(Timestamp(1))));
|
||||
|
||||
MP_EXPECT_OK(graph.WaitUntilIdle());
|
||||
MP_EXPECT_OK(graph.CloseAllPacketSources());
|
||||
MP_EXPECT_OK(graph.WaitUntilDone());
|
||||
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
const std::vector<std::unique_ptr<int>>& result =
|
||||
outputs[0].Get<std::vector<std::unique_ptr<int>>>();
|
||||
EXPECT_EQ(3, result.size());
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
const std::unique_ptr<int>& v = result[i];
|
||||
EXPECT_EQ(i, *v);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TestClipIntVectorSizeCalculatorTest, SidePacket) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "TestClipIntVectorSizeCalculator"
|
||||
input_stream: "input_vector"
|
||||
input_side_packet: "max_vec_size"
|
||||
output_stream: "output_vector"
|
||||
options {
|
||||
[mediapipe.ClipVectorSizeCalculatorOptions.ext] { max_vec_size: 1 }
|
||||
}
|
||||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
// This should override the default of 1 set in the options.
|
||||
runner.MutableSidePackets()->Index(0) = Adopt(new int(2));
|
||||
std::vector<int> input = {0, 1, 2, 3};
|
||||
AddInputVector(input, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
const std::vector<int>& output = outputs[0].Get<std::vector<int>>();
|
||||
EXPECT_EQ(2, output.size());
|
||||
std::vector<int> expected_vector = {0, 1};
|
||||
EXPECT_EQ(expected_vector, output);
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,35 +0,0 @@
|
|||
// Copyright 2019-2020 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/calculators/core/concatenate_vector_calculator.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Example config:
|
||||
//
|
||||
// node {
|
||||
// calculator: "ConcatenateDetectionVectorCalculator"
|
||||
// input_stream: "detection_vector_1"
|
||||
// input_stream: "detection_vector_2"
|
||||
// output_stream: "concatenated_detection_vector"
|
||||
// }
|
||||
//
|
||||
typedef ConcatenateVectorCalculator<::mediapipe::Detection>
|
||||
ConcatenateDetectionVectorCalculator;
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateDetectionVectorCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,79 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_NORMALIZED_LIST_CALCULATOR_H_ // NOLINT
|
||||
#define MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_NORMALIZED_LIST_CALCULATOR_H_ // NOLINT
|
||||
|
||||
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// Concatenates several NormalizedLandmarkList protos following stream index
|
||||
// order. This class assumes that every input stream contains a
|
||||
// NormalizedLandmarkList proto object.
|
||||
class ConcatenateNormalizedLandmarkListCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<NormalizedLandmarkList>::Multiple kIn{""};
|
||||
static constexpr Output<NormalizedLandmarkList> kOut{""};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||
|
||||
static absl::Status UpdateContract(CalculatorContract* cc) {
|
||||
RET_CHECK_GE(kIn(cc).Count(), 1);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
only_emit_if_all_present_ =
|
||||
cc->Options<::mediapipe::ConcatenateVectorCalculatorOptions>()
|
||||
.only_emit_if_all_present();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
if (only_emit_if_all_present_) {
|
||||
for (const auto& input : kIn(cc)) {
|
||||
if (input.IsEmpty()) return absl::OkStatus();
|
||||
}
|
||||
}
|
||||
|
||||
NormalizedLandmarkList output;
|
||||
for (const auto& input : kIn(cc)) {
|
||||
if (input.IsEmpty()) continue;
|
||||
const NormalizedLandmarkList& list = *input;
|
||||
for (int j = 0; j < list.landmark_size(); ++j) {
|
||||
*output.add_landmark() = list.landmark(j);
|
||||
}
|
||||
}
|
||||
kOut(cc).Send(std::move(output));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
bool only_emit_if_all_present_;
|
||||
};
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateNormalizedLandmarkListCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
#endif // MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_NORMALIZED_LIST_CALCULATOR_H_
|
||||
|
|
@ -1,184 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr float kLocationValue = 3;
|
||||
|
||||
NormalizedLandmarkList GenerateLandmarks(int landmarks_size,
|
||||
int value_multiplier) {
|
||||
NormalizedLandmarkList landmarks;
|
||||
for (int i = 0; i < landmarks_size; ++i) {
|
||||
NormalizedLandmark* landmark = landmarks.add_landmark();
|
||||
landmark->set_x(value_multiplier * kLocationValue);
|
||||
landmark->set_y(value_multiplier * kLocationValue);
|
||||
landmark->set_z(value_multiplier * kLocationValue);
|
||||
}
|
||||
return landmarks;
|
||||
}
|
||||
|
||||
void ValidateCombinedLandmarks(
|
||||
const std::vector<NormalizedLandmarkList>& inputs,
|
||||
const NormalizedLandmarkList& result) {
|
||||
int element_id = 0;
|
||||
int expected_size = 0;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
const NormalizedLandmarkList& landmarks_i = inputs[i];
|
||||
expected_size += landmarks_i.landmark_size();
|
||||
for (int j = 0; j < landmarks_i.landmark_size(); ++j) {
|
||||
const NormalizedLandmark& expected = landmarks_i.landmark(j);
|
||||
const NormalizedLandmark& got = result.landmark(element_id);
|
||||
EXPECT_FLOAT_EQ(expected.x(), got.x());
|
||||
EXPECT_FLOAT_EQ(expected.y(), got.y());
|
||||
EXPECT_FLOAT_EQ(expected.z(), got.z());
|
||||
++element_id;
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(expected_size, result.landmark_size());
|
||||
}
|
||||
|
||||
void AddInputLandmarkLists(
|
||||
const std::vector<NormalizedLandmarkList>& input_landmarks_vec,
|
||||
int64 timestamp, CalculatorRunner* runner) {
|
||||
for (int i = 0; i < input_landmarks_vec.size(); ++i) {
|
||||
runner->MutableInputs()->Index(i).packets.push_back(
|
||||
MakePacket<NormalizedLandmarkList>(input_landmarks_vec[i])
|
||||
.At(Timestamp(timestamp)));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ConcatenateNormalizedLandmarkListCalculatorTest, EmptyVectorInputs) {
|
||||
CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
NormalizedLandmarkList empty_list;
|
||||
std::vector<NormalizedLandmarkList> inputs = {empty_list, empty_list,
|
||||
empty_list};
|
||||
AddInputLandmarkLists(inputs, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(0, outputs[0].Get<NormalizedLandmarkList>().landmark_size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
}
|
||||
|
||||
TEST(ConcatenateNormalizedLandmarkListCalculatorTest, OneTimestamp) {
|
||||
CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
NormalizedLandmarkList input_0 =
|
||||
GenerateLandmarks(/*landmarks_size=*/3, /*value_multiplier=*/0);
|
||||
NormalizedLandmarkList input_1 =
|
||||
GenerateLandmarks(/*landmarks_size=*/1, /*value_multiplier=*/1);
|
||||
NormalizedLandmarkList input_2 =
|
||||
GenerateLandmarks(/*landmarks_size=*/2, /*value_multiplier=*/2);
|
||||
std::vector<NormalizedLandmarkList> inputs = {input_0, input_1, input_2};
|
||||
AddInputLandmarkLists(inputs, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
const NormalizedLandmarkList& result =
|
||||
outputs[0].Get<NormalizedLandmarkList>();
|
||||
ValidateCombinedLandmarks(inputs, result);
|
||||
}
|
||||
|
||||
TEST(ConcatenateNormalizedLandmarkListCalculatorTest,
|
||||
TwoInputsAtTwoTimestamps) {
|
||||
CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
NormalizedLandmarkList input_0 =
|
||||
GenerateLandmarks(/*landmarks_size=*/3, /*value_multiplier=*/0);
|
||||
NormalizedLandmarkList input_1 =
|
||||
GenerateLandmarks(/*landmarks_size=*/1, /*value_multiplier=*/1);
|
||||
NormalizedLandmarkList input_2 =
|
||||
GenerateLandmarks(/*landmarks_size=*/2, /*value_multiplier=*/2);
|
||||
std::vector<NormalizedLandmarkList> inputs = {input_0, input_1, input_2};
|
||||
{ AddInputLandmarkLists(inputs, /*timestamp=*/1, &runner); }
|
||||
{ AddInputLandmarkLists(inputs, /*timestamp=*/2, &runner); }
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(2, outputs.size());
|
||||
{
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
const NormalizedLandmarkList& result =
|
||||
outputs[0].Get<NormalizedLandmarkList>();
|
||||
ValidateCombinedLandmarks(inputs, result);
|
||||
}
|
||||
{
|
||||
EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
|
||||
const NormalizedLandmarkList& result =
|
||||
outputs[1].Get<NormalizedLandmarkList>();
|
||||
ValidateCombinedLandmarks(inputs, result);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ConcatenateNormalizedLandmarkListCalculatorTest,
|
||||
OneEmptyStreamStillOutput) {
|
||||
CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/2,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
NormalizedLandmarkList input_0 =
|
||||
GenerateLandmarks(/*landmarks_size=*/3, /*value_multiplier=*/0);
|
||||
std::vector<NormalizedLandmarkList> inputs = {input_0};
|
||||
AddInputLandmarkLists(inputs, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
const NormalizedLandmarkList& result =
|
||||
outputs[0].Get<NormalizedLandmarkList>();
|
||||
ValidateCombinedLandmarks(inputs, result);
|
||||
}
|
||||
|
||||
TEST(ConcatenateNormalizedLandmarkListCalculatorTest, OneEmptyStreamNoOutput) {
|
||||
CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator",
|
||||
/*options_string=*/
|
||||
"[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
|
||||
"{only_emit_if_all_present: true}",
|
||||
/*num_inputs=*/2,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
NormalizedLandmarkList input_0 =
|
||||
GenerateLandmarks(/*landmarks_size=*/3, /*value_multiplier=*/0);
|
||||
std::vector<NormalizedLandmarkList> inputs = {input_0};
|
||||
AddInputLandmarkLists(inputs, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(0, outputs.size());
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,94 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/core/concatenate_vector_calculator.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/util/render_data.pb.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
|
||||
#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "ConcatenateFloatVectorCalculator"
|
||||
// input_stream: "float_vector_1"
|
||||
// input_stream: "float_vector_2"
|
||||
// output_stream: "concatenated_float_vector"
|
||||
// }
|
||||
typedef ConcatenateVectorCalculator<float> ConcatenateFloatVectorCalculator;
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateFloatVectorCalculator);
|
||||
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "ConcatenateInt32VectorCalculator"
|
||||
// input_stream: "int32_vector_1"
|
||||
// input_stream: "int32_vector_2"
|
||||
// output_stream: "concatenated_int32_vector"
|
||||
// }
|
||||
typedef ConcatenateVectorCalculator<int32> ConcatenateInt32VectorCalculator;
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateInt32VectorCalculator);
|
||||
|
||||
typedef ConcatenateVectorCalculator<uint64> ConcatenateUInt64VectorCalculator;
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateUInt64VectorCalculator);
|
||||
|
||||
typedef ConcatenateVectorCalculator<bool> ConcatenateBoolVectorCalculator;
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateBoolVectorCalculator);
|
||||
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "ConcatenateTfLiteTensorVectorCalculator"
|
||||
// input_stream: "tflitetensor_vector_1"
|
||||
// input_stream: "tflitetensor_vector_2"
|
||||
// output_stream: "concatenated_tflitetensor_vector"
|
||||
// }
|
||||
typedef ConcatenateVectorCalculator<TfLiteTensor>
|
||||
ConcatenateTfLiteTensorVectorCalculator;
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateTfLiteTensorVectorCalculator);
|
||||
|
||||
typedef ConcatenateVectorCalculator<Tensor> ConcatenateTensorVectorCalculator;
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateTensorVectorCalculator);
|
||||
|
||||
typedef ConcatenateVectorCalculator<::mediapipe::NormalizedLandmark>
|
||||
ConcatenateLandmarkVectorCalculator;
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkVectorCalculator);
|
||||
|
||||
typedef ConcatenateVectorCalculator<::mediapipe::NormalizedLandmarkList>
|
||||
ConcatenateLandmarListVectorCalculator;
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarListVectorCalculator);
|
||||
|
||||
typedef ConcatenateVectorCalculator<mediapipe::ClassificationList>
|
||||
ConcatenateClassificationListVectorCalculator;
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListVectorCalculator);
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
typedef ConcatenateVectorCalculator<::tflite::gpu::gl::GlBuffer>
|
||||
ConcatenateGlBufferVectorCalculator;
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateGlBufferVectorCalculator);
|
||||
#endif
|
||||
|
||||
typedef ConcatenateVectorCalculator<mediapipe::RenderData>
|
||||
ConcatenateRenderDataVectorCalculator;
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateRenderDataVectorCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,122 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_
|
||||
#define MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
// Note: since this is a calculator template that can be included by other
|
||||
// source files, we do not place this in namespace api2 directly, but qualify
|
||||
// the api2 names below, to avoid changing the visible name of the class.
|
||||
// We cannot simply write "using mediapipe::api2" since it's a header file.
|
||||
// This distinction will go away once api2 is finalized.
|
||||
|
||||
// Concatenates several objects of type T or std::vector<T> following stream
|
||||
// index order. This class assumes that every input stream contains either T or
|
||||
// vector<T> type. To use this class for a particular type T, regisiter a
|
||||
// calculator using ConcatenateVectorCalculator<T>.
|
||||
template <typename T>
|
||||
class ConcatenateVectorCalculator : public api2::Node {
|
||||
public:
|
||||
static constexpr
|
||||
typename api2::Input<api2::OneOf<T, std::vector<T>>>::Multiple kIn{""};
|
||||
static constexpr api2::Output<std::vector<T>> kOut{""};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||
|
||||
static absl::Status UpdateContract(CalculatorContract* cc) {
|
||||
RET_CHECK_GE(kIn(cc).Count(), 1);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
only_emit_if_all_present_ =
|
||||
cc->Options<::mediapipe::ConcatenateVectorCalculatorOptions>()
|
||||
.only_emit_if_all_present();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
if (only_emit_if_all_present_) {
|
||||
for (const auto& input : kIn(cc)) {
|
||||
if (input.IsEmpty()) return ::absl::OkStatus();
|
||||
}
|
||||
}
|
||||
return ConcatenateVectors<T>(std::is_copy_constructible<T>(), cc);
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
absl::Status ConcatenateVectors(std::true_type, CalculatorContext* cc) {
|
||||
auto output = std::vector<U>();
|
||||
for (const auto& input : kIn(cc)) {
|
||||
if (input.IsEmpty()) continue;
|
||||
input.Visit([&output](const U& value) { output.push_back(value); },
|
||||
[&output](const std::vector<U>& value) {
|
||||
output.insert(output.end(), value.begin(), value.end());
|
||||
});
|
||||
}
|
||||
kOut(cc).Send(std::move(output));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
absl::Status ConcatenateVectors(std::false_type, CalculatorContext* cc) {
|
||||
return ConsumeAndConcatenateVectors<T>(std::is_move_constructible<U>(), cc);
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
absl::Status ConsumeAndConcatenateVectors(std::true_type,
|
||||
CalculatorContext* cc) {
|
||||
auto output = std::vector<U>();
|
||||
for (auto input : kIn(cc)) {
|
||||
if (input.IsEmpty()) continue;
|
||||
MP_RETURN_IF_ERROR(input.ConsumeAndVisit(
|
||||
[&output](std::unique_ptr<U> value) {
|
||||
output.push_back(std::move(*value));
|
||||
},
|
||||
[&output](std::unique_ptr<std::vector<U>> value) {
|
||||
output.insert(output.end(), std::make_move_iterator(value->begin()),
|
||||
std::make_move_iterator(value->end()));
|
||||
}));
|
||||
}
|
||||
kOut(cc).Send(std::move(output));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
absl::Status ConsumeAndConcatenateVectors(std::false_type,
|
||||
CalculatorContext* cc) {
|
||||
return absl::InternalError(
|
||||
"Cannot copy or move inputs to concatenate them");
|
||||
}
|
||||
|
||||
private:
|
||||
bool only_emit_if_all_present_;
|
||||
};
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_
|
||||
|
|
@ -1,31 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
option objc_class_prefix = "MediaPipe";
|
||||
|
||||
message ConcatenateVectorCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional ConcatenateVectorCalculatorOptions ext = 259397839;
|
||||
}
|
||||
|
||||
// If true, the calculator will only emit a packet at the given timestamp if
|
||||
// all input streams have a non-empty packet (AND operation on streams).
|
||||
optional bool only_emit_if_all_present = 1 [default = false];
|
||||
}
|
||||
|
|
@ -1,548 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/core/concatenate_vector_calculator.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
typedef ConcatenateVectorCalculator<int> TestConcatenateIntVectorCalculator;
|
||||
MEDIAPIPE_REGISTER_NODE(TestConcatenateIntVectorCalculator);
|
||||
|
||||
void AddInputVector(int index, const std::vector<int>& input, int64 timestamp,
|
||||
CalculatorRunner* runner) {
|
||||
runner->MutableInputs()->Index(index).packets.push_back(
|
||||
MakePacket<std::vector<int>>(input).At(Timestamp(timestamp)));
|
||||
}
|
||||
|
||||
void AddInputVectors(const std::vector<std::vector<int>>& inputs,
|
||||
int64 timestamp, CalculatorRunner* runner) {
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
AddInputVector(i, inputs[i], timestamp, runner);
|
||||
}
|
||||
}
|
||||
|
||||
void AddInputItem(int index, int input, int64 timestamp,
|
||||
CalculatorRunner* runner) {
|
||||
runner->MutableInputs()->Index(index).packets.push_back(
|
||||
MakePacket<int>(input).At(Timestamp(timestamp)));
|
||||
}
|
||||
|
||||
void AddInputItems(const std::vector<int>& inputs, int64 timestamp,
|
||||
CalculatorRunner* runner) {
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
AddInputItem(i, inputs[i], timestamp, runner);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TestConcatenateIntVectorCalculatorTest, EmptyVectorInputs) {
|
||||
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
std::vector<std::vector<int>> inputs = {{}, {}, {}};
|
||||
AddInputVectors(inputs, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_TRUE(outputs[0].Get<std::vector<int>>().empty());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
}
|
||||
|
||||
TEST(TestConcatenateIntVectorCalculatorTest, OneTimestamp) {
|
||||
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
std::vector<std::vector<int>> inputs = {{1, 2, 3}, {4}, {5, 6}};
|
||||
AddInputVectors(inputs, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
std::vector<int> expected_vector = {1, 2, 3, 4, 5, 6};
|
||||
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
|
||||
}
|
||||
|
||||
TEST(TestConcatenateIntVectorCalculatorTest, TwoInputsAtTwoTimestamps) {
|
||||
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
{
|
||||
std::vector<std::vector<int>> inputs = {{1, 2, 3}, {4}, {5, 6}};
|
||||
AddInputVectors(inputs, /*timestamp=*/1, &runner);
|
||||
}
|
||||
{
|
||||
std::vector<std::vector<int>> inputs = {{0, 2}, {1}, {3, 5}};
|
||||
AddInputVectors(inputs, /*timestamp=*/2, &runner);
|
||||
}
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(2, outputs.size());
|
||||
{
|
||||
EXPECT_EQ(6, outputs[0].Get<std::vector<int>>().size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
std::vector<int> expected_vector = {1, 2, 3, 4, 5, 6};
|
||||
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
|
||||
}
|
||||
{
|
||||
EXPECT_EQ(5, outputs[1].Get<std::vector<int>>().size());
|
||||
EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
|
||||
std::vector<int> expected_vector = {0, 2, 1, 3, 5};
|
||||
EXPECT_EQ(expected_vector, outputs[1].Get<std::vector<int>>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TestConcatenateIntVectorCalculatorTest, OneEmptyStreamStillOutput) {
|
||||
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/2,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
std::vector<std::vector<int>> inputs = {{1, 2, 3}};
|
||||
AddInputVectors(inputs, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
std::vector<int> expected_vector = {1, 2, 3};
|
||||
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
|
||||
}
|
||||
|
||||
TEST(TestConcatenateIntVectorCalculatorTest, OneEmptyStreamNoOutput) {
|
||||
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
|
||||
/*options_string=*/
|
||||
"[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
|
||||
"{only_emit_if_all_present: true}",
|
||||
/*num_inputs=*/2,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
std::vector<std::vector<int>> inputs = {{1, 2, 3}};
|
||||
AddInputVectors(inputs, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(0, outputs.size());
|
||||
}
|
||||
|
||||
TEST(TestConcatenateIntVectorCalculatorTest, ItemsOneTimestamp) {
|
||||
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
std::vector<int> inputs = {1, 2, 3};
|
||||
AddInputItems(inputs, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
std::vector<int> expected_vector = {1, 2, 3};
|
||||
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
|
||||
}
|
||||
|
||||
TEST(TestConcatenateIntVectorCalculatorTest, ItemsTwoInputsAtTwoTimestamps) {
|
||||
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
{
|
||||
std::vector<int> inputs = {1, 2, 3};
|
||||
AddInputItems(inputs, /*timestamp=*/1, &runner);
|
||||
}
|
||||
{
|
||||
std::vector<int> inputs = {4, 5, 6};
|
||||
AddInputItems(inputs, /*timestamp=*/2, &runner);
|
||||
}
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(2, outputs.size());
|
||||
{
|
||||
EXPECT_EQ(3, outputs[0].Get<std::vector<int>>().size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
std::vector<int> expected_vector = {1, 2, 3};
|
||||
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
|
||||
}
|
||||
{
|
||||
EXPECT_EQ(3, outputs[1].Get<std::vector<int>>().size());
|
||||
EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
|
||||
std::vector<int> expected_vector = {4, 5, 6};
|
||||
EXPECT_EQ(expected_vector, outputs[1].Get<std::vector<int>>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TestConcatenateIntVectorCalculatorTest, ItemsOneEmptyStreamStillOutput) {
|
||||
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
// No third input item.
|
||||
std::vector<int> inputs = {1, 2};
|
||||
AddInputItems(inputs, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
std::vector<int> expected_vector = {1, 2};
|
||||
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
|
||||
}
|
||||
|
||||
TEST(TestConcatenateIntVectorCalculatorTest, ItemsOneEmptyStreamNoOutput) {
|
||||
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
|
||||
/*options_string=*/
|
||||
"[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
|
||||
"{only_emit_if_all_present: true}",
|
||||
/*num_inputs=*/3,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
// No third input item.
|
||||
std::vector<int> inputs = {1, 2};
|
||||
AddInputItems(inputs, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(0, outputs.size());
|
||||
}
|
||||
|
||||
TEST(TestConcatenateIntVectorCalculatorTest, MixedVectorsAndItems) {
|
||||
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/4,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
std::vector<int> vector_0 = {1, 2};
|
||||
std::vector<int> vector_1 = {3, 4, 5};
|
||||
int item_0 = 6;
|
||||
int item_1 = 7;
|
||||
|
||||
AddInputVector(/*index*/ 0, vector_0, /*timestamp=*/1, &runner);
|
||||
AddInputVector(/*index*/ 1, vector_1, /*timestamp=*/1, &runner);
|
||||
AddInputItem(/*index*/ 2, item_0, /*timestamp=*/1, &runner);
|
||||
AddInputItem(/*index*/ 3, item_1, /*timestamp=*/1, &runner);
|
||||
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
std::vector<int> expected_vector = {1, 2, 3, 4, 5, 6, 7};
|
||||
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
|
||||
}
|
||||
|
||||
TEST(TestConcatenateIntVectorCalculatorTest, MixedVectorsAndItemsAnother) {
|
||||
CalculatorRunner runner("TestConcatenateIntVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/4,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
int item_0 = 1;
|
||||
std::vector<int> vector_0 = {2, 3};
|
||||
std::vector<int> vector_1 = {4, 5, 6};
|
||||
int item_1 = 7;
|
||||
|
||||
AddInputItem(/*index*/ 0, item_0, /*timestamp=*/1, &runner);
|
||||
AddInputVector(/*index*/ 1, vector_0, /*timestamp=*/1, &runner);
|
||||
AddInputVector(/*index*/ 2, vector_1, /*timestamp=*/1, &runner);
|
||||
AddInputItem(/*index*/ 3, item_1, /*timestamp=*/1, &runner);
|
||||
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
std::vector<int> expected_vector = {1, 2, 3, 4, 5, 6, 7};
|
||||
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<int>>());
|
||||
}
|
||||
|
||||
void AddInputVectors(const std::vector<std::vector<float>>& inputs,
|
||||
int64 timestamp, CalculatorRunner* runner) {
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
runner->MutableInputs()->Index(i).packets.push_back(
|
||||
MakePacket<std::vector<float>>(inputs[i]).At(Timestamp(timestamp)));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ConcatenateFloatVectorCalculatorTest, EmptyVectorInputs) {
|
||||
CalculatorRunner runner("ConcatenateFloatVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
std::vector<std::vector<float>> inputs = {{}, {}, {}};
|
||||
AddInputVectors(inputs, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_TRUE(outputs[0].Get<std::vector<float>>().empty());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
}
|
||||
|
||||
TEST(ConcatenateFloatVectorCalculatorTest, OneTimestamp) {
|
||||
CalculatorRunner runner("ConcatenateFloatVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
std::vector<std::vector<float>> inputs = {
|
||||
{1.0f, 2.0f, 3.0f}, {4.0f}, {5.0f, 6.0f}};
|
||||
AddInputVectors(inputs, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
std::vector<float> expected_vector = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
||||
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<float>>());
|
||||
}
|
||||
|
||||
TEST(ConcatenateFloatVectorCalculatorTest, TwoInputsAtTwoTimestamps) {
|
||||
CalculatorRunner runner("ConcatenateFloatVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
{
|
||||
std::vector<std::vector<float>> inputs = {
|
||||
{1.0f, 2.0f, 3.0f}, {4.0f}, {5.0f, 6.0f}};
|
||||
AddInputVectors(inputs, /*timestamp=*/1, &runner);
|
||||
}
|
||||
{
|
||||
std::vector<std::vector<float>> inputs = {
|
||||
{0.0f, 2.0f}, {1.0f}, {3.0f, 5.0f}};
|
||||
AddInputVectors(inputs, /*timestamp=*/2, &runner);
|
||||
}
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(2, outputs.size());
|
||||
{
|
||||
EXPECT_EQ(6, outputs[0].Get<std::vector<float>>().size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
std::vector<float> expected_vector = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
||||
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<float>>());
|
||||
}
|
||||
{
|
||||
EXPECT_EQ(5, outputs[1].Get<std::vector<float>>().size());
|
||||
EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
|
||||
std::vector<float> expected_vector = {0.0f, 2.0f, 1.0f, 3.0f, 5.0f};
|
||||
EXPECT_EQ(expected_vector, outputs[1].Get<std::vector<float>>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamStillOutput) {
|
||||
CalculatorRunner runner("ConcatenateFloatVectorCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/2,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
std::vector<std::vector<float>> inputs = {{1.0f, 2.0f, 3.0f}};
|
||||
AddInputVectors(inputs, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
std::vector<float> expected_vector = {1.0f, 2.0f, 3.0f};
|
||||
EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<float>>());
|
||||
}
|
||||
|
||||
TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamNoOutput) {
|
||||
CalculatorRunner runner("ConcatenateFloatVectorCalculator",
|
||||
/*options_string=*/
|
||||
"[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
|
||||
"{only_emit_if_all_present: true}",
|
||||
/*num_inputs=*/2,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
std::vector<std::vector<float>> inputs = {{1.0f, 2.0f, 3.0f}};
|
||||
AddInputVectors(inputs, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(0, outputs.size());
|
||||
}
|
||||
|
||||
typedef ConcatenateVectorCalculator<std::unique_ptr<int>>
|
||||
TestConcatenateUniqueIntPtrCalculator;
|
||||
MEDIAPIPE_REGISTER_NODE(TestConcatenateUniqueIntPtrCalculator);
|
||||
|
||||
TEST(TestConcatenateUniqueIntVectorCalculatorTest, ConsumeOneTimestamp) {
|
||||
/* Note: We don't use CalculatorRunner for this test because it keeps copies
|
||||
* of input packets, so packets sent to the graph don't have sole ownership.
|
||||
* The test needs to send packets that own the data.
|
||||
*/
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "in_1"
|
||||
input_stream: "in_2"
|
||||
input_stream: "in_3"
|
||||
node {
|
||||
calculator: "TestConcatenateUniqueIntPtrCalculator"
|
||||
input_stream: "in_1"
|
||||
input_stream: "in_2"
|
||||
input_stream: "in_3"
|
||||
output_stream: "out"
|
||||
}
|
||||
)pb");
|
||||
|
||||
std::vector<Packet> outputs;
|
||||
tool::AddVectorSink("out", &graph_config, &outputs);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_EXPECT_OK(graph.Initialize(graph_config));
|
||||
MP_EXPECT_OK(graph.StartRun({}));
|
||||
|
||||
// input1 : {0, 1, 2}
|
||||
std::unique_ptr<std::vector<std::unique_ptr<int>>> input_1 =
|
||||
absl::make_unique<std::vector<std::unique_ptr<int>>>(3);
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
input_1->at(i) = absl::make_unique<int>(i);
|
||||
}
|
||||
// input2: {3}
|
||||
std::unique_ptr<std::vector<std::unique_ptr<int>>> input_2 =
|
||||
absl::make_unique<std::vector<std::unique_ptr<int>>>(1);
|
||||
input_2->at(0) = absl::make_unique<int>(3);
|
||||
// input3: {4, 5}
|
||||
std::unique_ptr<std::vector<std::unique_ptr<int>>> input_3 =
|
||||
absl::make_unique<std::vector<std::unique_ptr<int>>>(2);
|
||||
input_3->at(0) = absl::make_unique<int>(4);
|
||||
input_3->at(1) = absl::make_unique<int>(5);
|
||||
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"in_1", Adopt(input_1.release()).At(Timestamp(1))));
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"in_2", Adopt(input_2.release()).At(Timestamp(1))));
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"in_3", Adopt(input_3.release()).At(Timestamp(1))));
|
||||
|
||||
MP_EXPECT_OK(graph.WaitUntilIdle());
|
||||
MP_EXPECT_OK(graph.CloseAllPacketSources());
|
||||
MP_EXPECT_OK(graph.WaitUntilDone());
|
||||
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
const std::vector<std::unique_ptr<int>>& result =
|
||||
outputs[0].Get<std::vector<std::unique_ptr<int>>>();
|
||||
EXPECT_EQ(6, result.size());
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
const std::unique_ptr<int>& v = result[i];
|
||||
EXPECT_EQ(i, *v);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TestConcatenateUniqueIntVectorCalculatorTest, OneEmptyStreamStillOutput) {
|
||||
/* Note: We don't use CalculatorRunner for this test because it keeps copies
|
||||
* of input packets, so packets sent to the graph don't have sole ownership.
|
||||
* The test needs to send packets that own the data.
|
||||
*/
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "in_1"
|
||||
input_stream: "in_2"
|
||||
node {
|
||||
calculator: "TestConcatenateUniqueIntPtrCalculator"
|
||||
input_stream: "in_1"
|
||||
input_stream: "in_2"
|
||||
output_stream: "out"
|
||||
}
|
||||
)pb");
|
||||
|
||||
std::vector<Packet> outputs;
|
||||
tool::AddVectorSink("out", &graph_config, &outputs);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_EXPECT_OK(graph.Initialize(graph_config));
|
||||
MP_EXPECT_OK(graph.StartRun({}));
|
||||
|
||||
// input1 : {0, 1, 2}
|
||||
std::unique_ptr<std::vector<std::unique_ptr<int>>> input_1 =
|
||||
absl::make_unique<std::vector<std::unique_ptr<int>>>(3);
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
input_1->at(i) = absl::make_unique<int>(i);
|
||||
}
|
||||
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"in_1", Adopt(input_1.release()).At(Timestamp(1))));
|
||||
|
||||
MP_EXPECT_OK(graph.WaitUntilIdle());
|
||||
MP_EXPECT_OK(graph.CloseAllPacketSources());
|
||||
MP_EXPECT_OK(graph.WaitUntilDone());
|
||||
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
const std::vector<std::unique_ptr<int>>& result =
|
||||
outputs[0].Get<std::vector<std::unique_ptr<int>>>();
|
||||
EXPECT_EQ(3, result.size());
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
const std::unique_ptr<int>& v = result[i];
|
||||
EXPECT_EQ(i, *v);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TestConcatenateUniqueIntVectorCalculatorTest, OneEmptyStreamNoOutput) {
|
||||
/* Note: We don't use CalculatorRunner for this test because it keeps copies
|
||||
* of input packets, so packets sent to the graph don't have sole ownership.
|
||||
* The test needs to send packets that own the data.
|
||||
*/
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "in_1"
|
||||
input_stream: "in_2"
|
||||
node {
|
||||
calculator: "TestConcatenateUniqueIntPtrCalculator"
|
||||
input_stream: "in_1"
|
||||
input_stream: "in_2"
|
||||
output_stream: "out"
|
||||
options {
|
||||
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
|
||||
only_emit_if_all_present: true
|
||||
}
|
||||
}
|
||||
}
|
||||
)pb");
|
||||
|
||||
std::vector<Packet> outputs;
|
||||
tool::AddVectorSink("out", &graph_config, &outputs);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_EXPECT_OK(graph.Initialize(graph_config));
|
||||
MP_EXPECT_OK(graph.StartRun({}));
|
||||
|
||||
// input1 : {0, 1, 2}
|
||||
std::unique_ptr<std::vector<std::unique_ptr<int>>> input_1 =
|
||||
absl::make_unique<std::vector<std::unique_ptr<int>>>(3);
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
input_1->at(i) = absl::make_unique<int>(i);
|
||||
}
|
||||
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"in_1", Adopt(input_1.release()).At(Timestamp(1))));
|
||||
|
||||
MP_EXPECT_OK(graph.WaitUntilIdle());
|
||||
MP_EXPECT_OK(graph.CloseAllPacketSources());
|
||||
MP_EXPECT_OK(graph.WaitUntilDone());
|
||||
|
||||
EXPECT_EQ(0, outputs.size());
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,129 +0,0 @@
|
|||
// Copyright 2020 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "mediapipe/calculators/core/constant_side_packet_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/collection_item_id.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {} // namespace
|
||||
|
||||
// Generates an output side packet or multiple output side packets according to
|
||||
// the specified options.
|
||||
//
|
||||
// Example configs:
|
||||
// node {
|
||||
// calculator: "ConstantSidePacketCalculator"
|
||||
// output_side_packet: "PACKET:packet"
|
||||
// options: {
|
||||
// [mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
|
||||
// packet { int_value: 2 }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// node {
|
||||
// calculator: "ConstantSidePacketCalculator"
|
||||
// output_side_packet: "PACKET:0:int_packet"
|
||||
// output_side_packet: "PACKET:1:bool_packet"
|
||||
// options: {
|
||||
// [mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
|
||||
// packet { int_value: 2 }
|
||||
// packet { bool_value: true }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class ConstantSidePacketCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
const auto& options =
|
||||
cc->Options<::mediapipe::ConstantSidePacketCalculatorOptions>();
|
||||
RET_CHECK_EQ(cc->OutputSidePackets().NumEntries(kPacketTag),
|
||||
options.packet_size())
|
||||
<< "Number of output side packets has to be same as number of packets "
|
||||
"configured in options.";
|
||||
|
||||
int index = 0;
|
||||
for (CollectionItemId id = cc->OutputSidePackets().BeginId(kPacketTag);
|
||||
id != cc->OutputSidePackets().EndId(kPacketTag); ++id, ++index) {
|
||||
const auto& packet_options = options.packet(index);
|
||||
auto& packet = cc->OutputSidePackets().Get(id);
|
||||
if (packet_options.has_int_value()) {
|
||||
packet.Set<int>();
|
||||
} else if (packet_options.has_float_value()) {
|
||||
packet.Set<float>();
|
||||
} else if (packet_options.has_bool_value()) {
|
||||
packet.Set<bool>();
|
||||
} else if (packet_options.has_string_value()) {
|
||||
packet.Set<std::string>();
|
||||
} else if (packet_options.has_uint64_value()) {
|
||||
packet.Set<uint64>();
|
||||
} else if (packet_options.has_classification_list_value()) {
|
||||
packet.Set<ClassificationList>();
|
||||
} else {
|
||||
return absl::InvalidArgumentError(
|
||||
"None of supported values were specified in options.");
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
const auto& options =
|
||||
cc->Options<::mediapipe::ConstantSidePacketCalculatorOptions>();
|
||||
int index = 0;
|
||||
for (CollectionItemId id = cc->OutputSidePackets().BeginId(kPacketTag);
|
||||
id != cc->OutputSidePackets().EndId(kPacketTag); ++id, ++index) {
|
||||
auto& packet = cc->OutputSidePackets().Get(id);
|
||||
const auto& packet_options = options.packet(index);
|
||||
if (packet_options.has_int_value()) {
|
||||
packet.Set(MakePacket<int>(packet_options.int_value()));
|
||||
} else if (packet_options.has_float_value()) {
|
||||
packet.Set(MakePacket<float>(packet_options.float_value()));
|
||||
} else if (packet_options.has_bool_value()) {
|
||||
packet.Set(MakePacket<bool>(packet_options.bool_value()));
|
||||
} else if (packet_options.has_string_value()) {
|
||||
packet.Set(MakePacket<std::string>(packet_options.string_value()));
|
||||
} else if (packet_options.has_uint64_value()) {
|
||||
packet.Set(MakePacket<uint64>(packet_options.uint64_value()));
|
||||
} else if (packet_options.has_classification_list_value()) {
|
||||
packet.Set(MakePacket<ClassificationList>(
|
||||
packet_options.classification_list_value()));
|
||||
} else {
|
||||
return absl::InvalidArgumentError(
|
||||
"None of supported values were specified in options.");
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr const char* kPacketTag = "PACKET";
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(ConstantSidePacketCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,41 +0,0 @@
|
|||
// Copyright 2020 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/framework/formats/classification.proto";
|
||||
|
||||
option objc_class_prefix = "MediaPipe";
|
||||
|
||||
message ConstantSidePacketCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional ConstantSidePacketCalculatorOptions ext = 291214597;
|
||||
}
|
||||
|
||||
message ConstantSidePacket {
|
||||
oneof value {
|
||||
int32 int_value = 1;
|
||||
float float_value = 2;
|
||||
bool bool_value = 3;
|
||||
string string_value = 4;
|
||||
uint64 uint64_value = 5;
|
||||
ClassificationList classification_list_value = 6;
|
||||
}
|
||||
}
|
||||
|
||||
repeated ConstantSidePacket packet = 1;
|
||||
}
|
||||
|
|
@ -1,189 +0,0 @@
|
|||
// Copyright 2020 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
template <typename T>
|
||||
void DoTestSingleSidePacket(absl::string_view packet_spec,
|
||||
const T& expected_value) {
|
||||
static constexpr absl::string_view graph_config_template = R"(
|
||||
node {
|
||||
calculator: "ConstantSidePacketCalculator"
|
||||
output_side_packet: "PACKET:packet"
|
||||
options: {
|
||||
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
|
||||
packet $0
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
CalculatorGraphConfig graph_config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
absl::Substitute(graph_config_template, packet_spec));
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
MP_ASSERT_OK(graph.GetOutputSidePacket("packet"));
|
||||
auto actual_value =
|
||||
graph.GetOutputSidePacket("packet").value().template Get<T>();
|
||||
EXPECT_EQ(actual_value, expected_value);
|
||||
}
|
||||
|
||||
TEST(ConstantSidePacketCalculatorTest, EveryPossibleType) {
|
||||
DoTestSingleSidePacket("{ int_value: 2 }", 2);
|
||||
DoTestSingleSidePacket("{ float_value: 6.5f }", 6.5f);
|
||||
DoTestSingleSidePacket("{ bool_value: true }", true);
|
||||
DoTestSingleSidePacket<std::string>(R"({ string_value: "str" })", "str");
|
||||
}
|
||||
|
||||
TEST(ConstantSidePacketCalculatorTest, MultiplePackets) {
|
||||
CalculatorGraphConfig graph_config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "ConstantSidePacketCalculator"
|
||||
output_side_packet: "PACKET:0:int_packet"
|
||||
output_side_packet: "PACKET:1:float_packet"
|
||||
output_side_packet: "PACKET:2:bool_packet"
|
||||
output_side_packet: "PACKET:3:string_packet"
|
||||
output_side_packet: "PACKET:4:another_string_packet"
|
||||
output_side_packet: "PACKET:5:another_int_packet"
|
||||
options: {
|
||||
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
|
||||
packet { int_value: 256 }
|
||||
packet { float_value: 0.5f }
|
||||
packet { bool_value: false }
|
||||
packet { string_value: "string" }
|
||||
packet { string_value: "another string" }
|
||||
packet { int_value: 128 }
|
||||
}
|
||||
}
|
||||
}
|
||||
)pb");
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
MP_ASSERT_OK(graph.GetOutputSidePacket("int_packet"));
|
||||
EXPECT_EQ(graph.GetOutputSidePacket("int_packet").value().Get<int>(), 256);
|
||||
MP_ASSERT_OK(graph.GetOutputSidePacket("float_packet"));
|
||||
EXPECT_EQ(graph.GetOutputSidePacket("float_packet").value().Get<float>(),
|
||||
0.5f);
|
||||
MP_ASSERT_OK(graph.GetOutputSidePacket("bool_packet"));
|
||||
EXPECT_FALSE(graph.GetOutputSidePacket("bool_packet").value().Get<bool>());
|
||||
MP_ASSERT_OK(graph.GetOutputSidePacket("string_packet"));
|
||||
EXPECT_EQ(
|
||||
graph.GetOutputSidePacket("string_packet").value().Get<std::string>(),
|
||||
"string");
|
||||
MP_ASSERT_OK(graph.GetOutputSidePacket("another_string_packet"));
|
||||
EXPECT_EQ(graph.GetOutputSidePacket("another_string_packet")
|
||||
.value()
|
||||
.Get<std::string>(),
|
||||
"another string");
|
||||
MP_ASSERT_OK(graph.GetOutputSidePacket("another_int_packet"));
|
||||
EXPECT_EQ(graph.GetOutputSidePacket("another_int_packet").value().Get<int>(),
|
||||
128);
|
||||
}
|
||||
|
||||
TEST(ConstantSidePacketCalculatorTest, ProcessingPacketsWithCorrectTagOnly) {
|
||||
CalculatorGraphConfig graph_config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "ConstantSidePacketCalculator"
|
||||
output_side_packet: "PACKET:0:int_packet"
|
||||
output_side_packet: "no_tag0"
|
||||
output_side_packet: "PACKET:1:float_packet"
|
||||
output_side_packet: "INCORRECT_TAG:0:name1"
|
||||
output_side_packet: "PACKET:2:bool_packet"
|
||||
output_side_packet: "PACKET:3:string_packet"
|
||||
output_side_packet: "no_tag2"
|
||||
output_side_packet: "INCORRECT_TAG:1:name2"
|
||||
options: {
|
||||
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
|
||||
packet { int_value: 256 }
|
||||
packet { float_value: 0.5f }
|
||||
packet { bool_value: false }
|
||||
packet { string_value: "string" }
|
||||
}
|
||||
}
|
||||
}
|
||||
)pb");
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
MP_ASSERT_OK(graph.GetOutputSidePacket("int_packet"));
|
||||
EXPECT_EQ(graph.GetOutputSidePacket("int_packet").value().Get<int>(), 256);
|
||||
MP_ASSERT_OK(graph.GetOutputSidePacket("float_packet"));
|
||||
EXPECT_EQ(graph.GetOutputSidePacket("float_packet").value().Get<float>(),
|
||||
0.5f);
|
||||
MP_ASSERT_OK(graph.GetOutputSidePacket("bool_packet"));
|
||||
EXPECT_FALSE(graph.GetOutputSidePacket("bool_packet").value().Get<bool>());
|
||||
MP_ASSERT_OK(graph.GetOutputSidePacket("string_packet"));
|
||||
EXPECT_EQ(
|
||||
graph.GetOutputSidePacket("string_packet").value().Get<std::string>(),
|
||||
"string");
|
||||
}
|
||||
|
||||
TEST(ConstantSidePacketCalculatorTest, IncorrectConfig_MoreOptionsThanPackets) {
|
||||
CalculatorGraphConfig graph_config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "ConstantSidePacketCalculator"
|
||||
output_side_packet: "PACKET:int_packet"
|
||||
options: {
|
||||
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
|
||||
packet { int_value: 256 }
|
||||
packet { float_value: 0.5f }
|
||||
}
|
||||
}
|
||||
}
|
||||
)pb");
|
||||
CalculatorGraph graph;
|
||||
EXPECT_FALSE(graph.Initialize(graph_config).ok());
|
||||
}
|
||||
|
||||
TEST(ConstantSidePacketCalculatorTest, IncorrectConfig_MorePacketsThanOptions) {
|
||||
CalculatorGraphConfig graph_config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "ConstantSidePacketCalculator"
|
||||
output_side_packet: "PACKET:0:int_packet"
|
||||
output_side_packet: "PACKET:1:float_packet"
|
||||
options: {
|
||||
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
|
||||
packet { int_value: 256 }
|
||||
}
|
||||
}
|
||||
}
|
||||
)pb");
|
||||
CalculatorGraph graph;
|
||||
EXPECT_FALSE(graph.Initialize(graph_config).ok());
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,114 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Source calculator that produces MAX_COUNT*BATCH_SIZE int packets of
|
||||
// sequential numbers from INITIAL_VALUE (default 0) with a common
|
||||
// difference of INCREMENT (default 1) between successive numbers (with
|
||||
// timestamps corresponding to the sequence numbers). The packets are
|
||||
// produced in BATCH_SIZE sized batches with each call to Process(). An
|
||||
// error will be returned after ERROR_COUNT batches. An error will be
|
||||
// produced in Open() if ERROR_ON_OPEN is true. Either MAX_COUNT or
|
||||
// ERROR_COUNT must be provided and non-negative. If BATCH_SIZE is not
|
||||
// provided, then batches are of size 1.
|
||||
class CountingSourceCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Outputs().Index(0).Set<int>();
|
||||
|
||||
if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN")) {
|
||||
cc->InputSidePackets().Tag("ERROR_ON_OPEN").Set<bool>();
|
||||
}
|
||||
|
||||
RET_CHECK(cc->InputSidePackets().HasTag("MAX_COUNT") ||
|
||||
cc->InputSidePackets().HasTag("ERROR_COUNT"));
|
||||
if (cc->InputSidePackets().HasTag("MAX_COUNT")) {
|
||||
cc->InputSidePackets().Tag("MAX_COUNT").Set<int>();
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("ERROR_COUNT")) {
|
||||
cc->InputSidePackets().Tag("ERROR_COUNT").Set<int>();
|
||||
}
|
||||
|
||||
if (cc->InputSidePackets().HasTag("BATCH_SIZE")) {
|
||||
cc->InputSidePackets().Tag("BATCH_SIZE").Set<int>();
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("INITIAL_VALUE")) {
|
||||
cc->InputSidePackets().Tag("INITIAL_VALUE").Set<int>();
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("INCREMENT")) {
|
||||
cc->InputSidePackets().Tag("INCREMENT").Set<int>();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN") &&
|
||||
cc->InputSidePackets().Tag("ERROR_ON_OPEN").Get<bool>()) {
|
||||
return absl::NotFoundError("expected error");
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("ERROR_COUNT")) {
|
||||
error_count_ = cc->InputSidePackets().Tag("ERROR_COUNT").Get<int>();
|
||||
RET_CHECK_LE(0, error_count_);
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("MAX_COUNT")) {
|
||||
max_count_ = cc->InputSidePackets().Tag("MAX_COUNT").Get<int>();
|
||||
RET_CHECK_LE(0, max_count_);
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("BATCH_SIZE")) {
|
||||
batch_size_ = cc->InputSidePackets().Tag("BATCH_SIZE").Get<int>();
|
||||
RET_CHECK_LT(0, batch_size_);
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("INITIAL_VALUE")) {
|
||||
counter_ = cc->InputSidePackets().Tag("INITIAL_VALUE").Get<int>();
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("INCREMENT")) {
|
||||
increment_ = cc->InputSidePackets().Tag("INCREMENT").Get<int>();
|
||||
RET_CHECK_LT(0, increment_);
|
||||
}
|
||||
RET_CHECK(error_count_ >= 0 || max_count_ >= 0);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
if (error_count_ >= 0 && batch_counter_ >= error_count_) {
|
||||
return absl::InternalError("expected error");
|
||||
}
|
||||
if (max_count_ >= 0 && batch_counter_ >= max_count_) {
|
||||
return tool::StatusStop();
|
||||
}
|
||||
for (int i = 0; i < batch_size_; ++i) {
|
||||
cc->Outputs().Index(0).Add(new int(counter_), Timestamp(counter_));
|
||||
counter_ += increment_;
|
||||
}
|
||||
++batch_counter_;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
int max_count_ = -1;
|
||||
int error_count_ = -1;
|
||||
int batch_size_ = 1;
|
||||
int batch_counter_ = 0;
|
||||
int counter_ = 0;
|
||||
int increment_ = 1;
|
||||
};
|
||||
REGISTER_CALCULATOR(CountingSourceCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,104 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kOptionalValueTag[] = "OPTIONAL_VALUE";
|
||||
constexpr char kDefaultValueTag[] = "DEFAULT_VALUE";
|
||||
constexpr char kValueTag[] = "VALUE";
|
||||
|
||||
} // namespace
|
||||
|
||||
// Outputs side packet default value if optional value is not provided.
|
||||
//
|
||||
// This calculator utilizes the fact that MediaPipe automatically removes
|
||||
// optional side packets of the calculator configuration (i.e. OPTIONAL_VALUE).
|
||||
// And if it happens - returns default value, otherwise - returns optional
|
||||
// value.
|
||||
//
|
||||
// Input:
|
||||
// OPTIONAL_VALUE (optional) - AnyType (but same type as DEFAULT_VALUE)
|
||||
// Optional side packet value that is outputted by the calculator as is if
|
||||
// provided.
|
||||
//
|
||||
// DEFAULT_VALUE - AnyType
|
||||
// Default side pack value that is outputted by the calculator if
|
||||
// OPTIONAL_VALUE is not provided.
|
||||
//
|
||||
// Output:
|
||||
// VALUE - AnyType (but same type as DEFAULT_VALUE)
|
||||
// Either OPTIONAL_VALUE (if provided) or DEFAULT_VALUE (otherwise).
|
||||
//
|
||||
// Usage example:
|
||||
// node {
|
||||
// calculator: "DefaultSidePacketCalculator"
|
||||
// input_side_packet: "OPTIONAL_VALUE:segmentation_mask_enabled_optional"
|
||||
// input_side_packet: "DEFAULT_VALUE:segmentation_mask_enabled_default"
|
||||
// output_side_packet: "VALUE:segmentation_mask_enabled"
|
||||
// }
|
||||
class DefaultSidePacketCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc);
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
};
|
||||
REGISTER_CALCULATOR(DefaultSidePacketCalculator);
|
||||
|
||||
absl::Status DefaultSidePacketCalculator::GetContract(CalculatorContract* cc) {
|
||||
RET_CHECK(cc->InputSidePackets().HasTag(kDefaultValueTag))
|
||||
<< "Default value must be provided";
|
||||
cc->InputSidePackets().Tag(kDefaultValueTag).SetAny();
|
||||
|
||||
// Optional input side packet can be unspecified. In this case MediaPipe will
|
||||
// remove it from the calculator config.
|
||||
if (cc->InputSidePackets().HasTag(kOptionalValueTag)) {
|
||||
cc->InputSidePackets()
|
||||
.Tag(kOptionalValueTag)
|
||||
.SetSameAs(&cc->InputSidePackets().Tag(kDefaultValueTag))
|
||||
.Optional();
|
||||
}
|
||||
|
||||
RET_CHECK(cc->OutputSidePackets().HasTag(kValueTag));
|
||||
cc->OutputSidePackets().Tag(kValueTag).SetSameAs(
|
||||
&cc->InputSidePackets().Tag(kDefaultValueTag));
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status DefaultSidePacketCalculator::Open(CalculatorContext* cc) {
|
||||
// If optional value is provided it is returned as the calculator output.
|
||||
if (cc->InputSidePackets().HasTag(kOptionalValueTag)) {
|
||||
auto& packet = cc->InputSidePackets().Tag(kOptionalValueTag);
|
||||
cc->OutputSidePackets().Tag(kValueTag).Set(packet);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// If no optional value
|
||||
auto& packet = cc->InputSidePackets().Tag(kDefaultValueTag);
|
||||
cc->OutputSidePackets().Tag(kValueTag).Set(packet);
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status DefaultSidePacketCalculator::Process(CalculatorContext* cc) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,90 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cfloat>
|
||||
|
||||
#include "mediapipe/calculators/core/dequantize_byte_array_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
// Dequantizes a byte array to a vector of floats.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "DequantizeByteArrayCalculator"
|
||||
// input_stream: "ENCODED:encoded"
|
||||
// output_stream: "FLOAT_VECTOR:float_vector"
|
||||
// options {
|
||||
// [mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
|
||||
// max_quantized_value: 2
|
||||
// min_quantized_value: -2
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
namespace mediapipe {
|
||||
|
||||
class DequantizeByteArrayCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Tag("ENCODED").Set<std::string>();
|
||||
cc->Outputs().Tag("FLOAT_VECTOR").Set<std::vector<float>>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) final {
|
||||
const auto options =
|
||||
cc->Options<::mediapipe::DequantizeByteArrayCalculatorOptions>();
|
||||
if (!options.has_max_quantized_value() ||
|
||||
!options.has_min_quantized_value()) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Both max_quantized_value and min_quantized_value must be provided "
|
||||
"in DequantizeByteArrayCalculatorOptions.");
|
||||
}
|
||||
float max_quantized_value = options.max_quantized_value();
|
||||
float min_quantized_value = options.min_quantized_value();
|
||||
if (max_quantized_value < min_quantized_value + FLT_EPSILON) {
|
||||
return absl::InvalidArgumentError(
|
||||
"max_quantized_value must be greater than min_quantized_value.");
|
||||
}
|
||||
float range = max_quantized_value - min_quantized_value;
|
||||
scalar_ = range / 255.0;
|
||||
bias_ = (range / 512.0) + min_quantized_value;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
const std::string& encoded =
|
||||
cc->Inputs().Tag("ENCODED").Value().Get<std::string>();
|
||||
std::vector<float> float_vector;
|
||||
float_vector.reserve(encoded.length());
|
||||
for (int i = 0; i < encoded.length(); ++i) {
|
||||
float_vector.push_back(
|
||||
static_cast<unsigned char>(encoded.at(i)) * scalar_ + bias_);
|
||||
}
|
||||
cc->Outputs()
|
||||
.Tag("FLOAT_VECTOR")
|
||||
.AddPacket(MakePacket<std::vector<float>>(float_vector)
|
||||
.At(cc->InputTimestamp()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
float scalar_;
|
||||
float bias_;
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(DequantizeByteArrayCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
option objc_class_prefix = "MediaPipe";
|
||||
|
||||
message DequantizeByteArrayCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional DequantizeByteArrayCalculatorOptions ext = 272316343;
|
||||
}
|
||||
|
||||
optional float max_quantized_value = 1;
|
||||
optional float min_quantized_value = 2;
|
||||
}
|
||||
|
|
@ -1,137 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "DequantizeByteArrayCalculator"
|
||||
input_stream: "ENCODED:encoded"
|
||||
output_stream: "FLOAT_VECTOR:float_vector"
|
||||
options {
|
||||
[mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
|
||||
max_quantized_value: 2
|
||||
}
|
||||
}
|
||||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
std::string empty_string;
|
||||
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
|
||||
MakePacket<std::string>(empty_string).At(Timestamp(0)));
|
||||
auto status = runner.Run();
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_THAT(
|
||||
status.message(),
|
||||
testing::HasSubstr(
|
||||
"Both max_quantized_value and min_quantized_value must be provided"));
|
||||
}
|
||||
|
||||
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig2) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "DequantizeByteArrayCalculator"
|
||||
input_stream: "ENCODED:encoded"
|
||||
output_stream: "FLOAT_VECTOR:float_vector"
|
||||
options {
|
||||
[mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
|
||||
max_quantized_value: -2
|
||||
min_quantized_value: 2
|
||||
}
|
||||
}
|
||||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
std::string empty_string;
|
||||
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
|
||||
MakePacket<std::string>(empty_string).At(Timestamp(0)));
|
||||
auto status = runner.Run();
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_THAT(
|
||||
status.message(),
|
||||
testing::HasSubstr(
|
||||
"max_quantized_value must be greater than min_quantized_value"));
|
||||
}
|
||||
|
||||
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig3) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "DequantizeByteArrayCalculator"
|
||||
input_stream: "ENCODED:encoded"
|
||||
output_stream: "FLOAT_VECTOR:float_vector"
|
||||
options {
|
||||
[mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
|
||||
max_quantized_value: 1
|
||||
min_quantized_value: 1
|
||||
}
|
||||
}
|
||||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
std::string empty_string;
|
||||
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
|
||||
MakePacket<std::string>(empty_string).At(Timestamp(0)));
|
||||
auto status = runner.Run();
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_THAT(
|
||||
status.message(),
|
||||
testing::HasSubstr(
|
||||
"max_quantized_value must be greater than min_quantized_value"));
|
||||
}
|
||||
|
||||
TEST(DequantizeByteArrayCalculatorTest, TestDequantization) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "DequantizeByteArrayCalculator"
|
||||
input_stream: "ENCODED:encoded"
|
||||
output_stream: "FLOAT_VECTOR:float_vector"
|
||||
options {
|
||||
[mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
|
||||
max_quantized_value: 2
|
||||
min_quantized_value: -2
|
||||
}
|
||||
}
|
||||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
unsigned char input[4] = {0x7F, 0xFF, 0x00, 0x01};
|
||||
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
|
||||
MakePacket<std::string>(
|
||||
std::string(reinterpret_cast<char const*>(input), 4))
|
||||
.At(Timestamp(0)));
|
||||
auto status = runner.Run();
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
const std::vector<Packet>& outputs =
|
||||
runner.Outputs().Tag("FLOAT_VECTOR").packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
const std::vector<float>& result = outputs[0].Get<std::vector<float>>();
|
||||
ASSERT_FALSE(result.empty());
|
||||
EXPECT_EQ(4, result.size());
|
||||
EXPECT_NEAR(0, result[0], 0.01);
|
||||
EXPECT_NEAR(2, result[1], 0.01);
|
||||
EXPECT_NEAR(-2, result[2], 0.01);
|
||||
EXPECT_NEAR(-1.976, result[3], 0.01);
|
||||
|
||||
EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,53 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/core/end_loop_calculator.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/util/render_data.pb.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
typedef EndLoopCalculator<std::vector<::mediapipe::NormalizedRect>>
|
||||
EndLoopNormalizedRectCalculator;
|
||||
REGISTER_CALCULATOR(EndLoopNormalizedRectCalculator);
|
||||
|
||||
typedef EndLoopCalculator<std::vector<::mediapipe::LandmarkList>>
|
||||
EndLoopLandmarkListVectorCalculator;
|
||||
REGISTER_CALCULATOR(EndLoopLandmarkListVectorCalculator);
|
||||
|
||||
typedef EndLoopCalculator<std::vector<::mediapipe::NormalizedLandmarkList>>
|
||||
EndLoopNormalizedLandmarkListVectorCalculator;
|
||||
REGISTER_CALCULATOR(EndLoopNormalizedLandmarkListVectorCalculator);
|
||||
|
||||
typedef EndLoopCalculator<std::vector<bool>> EndLoopBooleanCalculator;
|
||||
REGISTER_CALCULATOR(EndLoopBooleanCalculator);
|
||||
|
||||
typedef EndLoopCalculator<std::vector<::mediapipe::RenderData>>
|
||||
EndLoopRenderDataCalculator;
|
||||
REGISTER_CALCULATOR(EndLoopRenderDataCalculator);
|
||||
|
||||
typedef EndLoopCalculator<std::vector<::mediapipe::ClassificationList>>
|
||||
EndLoopClassificationListCalculator;
|
||||
REGISTER_CALCULATOR(EndLoopClassificationListCalculator);
|
||||
|
||||
typedef EndLoopCalculator<std::vector<TfLiteTensor>> EndLoopTensorCalculator;
|
||||
REGISTER_CALCULATOR(EndLoopTensorCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,106 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_
|
||||
#define MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_
|
||||
|
||||
#include "mediapipe/framework/calculator_context.h"
|
||||
#include "mediapipe/framework/calculator_contract.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/collection_item_id.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Calculator for completing the processing of loops on iterable collections
|
||||
// inside a MediaPipe graph. The EndLoopCalculator collects all input packets
|
||||
// from ITEM input_stream into a collection and upon receiving the flush signal
|
||||
// from the "BATCH_END" tagged input stream, it emits the aggregated results
|
||||
// at the original timestamp contained in the "BATCH_END" input stream.
|
||||
//
|
||||
// It is designed to be used like:
|
||||
//
|
||||
// node {
|
||||
// calculator: "BeginLoopWithIterableCalculator"
|
||||
// input_stream: "ITERABLE:input_iterable" # IterableT @ext_ts
|
||||
// output_stream: "ITEM:input_element" # ItemT @loop_internal_ts
|
||||
// output_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts
|
||||
// }
|
||||
//
|
||||
// node {
|
||||
// calculator: "ElementToBlaConverterSubgraph"
|
||||
// input_stream: "ITEM:input_to_loop_body" # ItemT @loop_internal_ts
|
||||
// output_stream: "BLA:output_of_loop_body" # ItemU @loop_internal_ts
|
||||
// }
|
||||
//
|
||||
// node {
|
||||
// calculator: "EndLoopWithOutputCalculator"
|
||||
// input_stream: "ITEM:output_of_loop_body" # ItemU @loop_internal_ts
|
||||
// input_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts
|
||||
// output_stream: "OUTPUT:aggregated_result" # IterableU @ext_ts
|
||||
// }
|
||||
template <typename IterableT>
|
||||
class EndLoopCalculator : public CalculatorBase {
|
||||
using ItemT = typename IterableT::value_type;
|
||||
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().HasTag("BATCH_END"))
|
||||
<< "Missing BATCH_END tagged input_stream.";
|
||||
cc->Inputs().Tag("BATCH_END").Set<Timestamp>();
|
||||
|
||||
RET_CHECK(cc->Inputs().HasTag("ITEM"));
|
||||
cc->Inputs().Tag("ITEM").Set<ItemT>();
|
||||
|
||||
RET_CHECK(cc->Outputs().HasTag("ITERABLE"));
|
||||
cc->Outputs().Tag("ITERABLE").Set<IterableT>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
if (!cc->Inputs().Tag("ITEM").IsEmpty()) {
|
||||
if (!input_stream_collection_) {
|
||||
input_stream_collection_.reset(new IterableT);
|
||||
}
|
||||
input_stream_collection_->push_back(
|
||||
cc->Inputs().Tag("ITEM").template Get<ItemT>());
|
||||
}
|
||||
|
||||
if (!cc->Inputs().Tag("BATCH_END").Value().IsEmpty()) { // flush signal
|
||||
Timestamp loop_control_ts =
|
||||
cc->Inputs().Tag("BATCH_END").template Get<Timestamp>();
|
||||
if (input_stream_collection_) {
|
||||
cc->Outputs()
|
||||
.Tag("ITERABLE")
|
||||
.Add(input_stream_collection_.release(), loop_control_ts);
|
||||
} else {
|
||||
// Since there is no collection, inform downstream calculators to not
|
||||
// expect any packet by updating the timestamp bounds.
|
||||
cc->Outputs()
|
||||
.Tag("ITERABLE")
|
||||
.SetNextTimestampBound(Timestamp(loop_control_ts.Value() + 1));
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<IterableT> input_stream_collection_;
|
||||
};
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_
|
||||
|
|
@ -1,229 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/calculators/core/flow_limiter_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/util/header_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// FlowLimiterCalculator is used to limit the number of frames in flight
|
||||
// by dropping input frames when necessary.
|
||||
//
|
||||
// The input stream "FINISH" is used to signal the FlowLimiterCalculator
|
||||
// when a frame is finished processing. Either a non-empty "FINISH" packet
|
||||
// or a timestamp bound should be received for each processed frame.
|
||||
//
|
||||
// The combination of `max_in_flight: 1` and `max_in_queue: 1` generally gives
|
||||
// best throughput/latency balance. Throughput is nearly optimal as the
|
||||
// graph is never idle as there is always something in the queue. Latency is
|
||||
// nearly optimal latency as the queue always stores the latest available frame.
|
||||
//
|
||||
// Increasing `max_in_flight` to 2 or more can yield the better throughput
|
||||
// when the graph exhibits a high degree of pipeline parallelism. Decreasing
|
||||
// `max_in_flight` to 0 can yield a better average latency, but at the cost of
|
||||
// lower throughput (lower framerate) due to the time during which the graph
|
||||
// is idle awaiting the next input frame.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "FlowLimiterCalculator"
|
||||
// input_stream: "raw_frames"
|
||||
// input_stream: "FINISHED:finished"
|
||||
// input_stream_info: {
|
||||
// tag_index: 'FINISHED'
|
||||
// back_edge: true
|
||||
// }
|
||||
// output_stream: "sampled_frames"
|
||||
// output_stream: "ALLOW:allowed_timestamps"
|
||||
// }
|
||||
//
|
||||
// The "ALLOW" stream indicates the transition between accepting frames and
|
||||
// dropping frames. "ALLOW = true" indicates the start of accepting frames
|
||||
// including the current timestamp, and "ALLOW = false" indicates the start of
|
||||
// dropping frames including the current timestamp.
|
||||
//
|
||||
// FlowLimiterCalculator provides limited support for multiple input streams.
|
||||
// The first input stream is treated as the main input stream and successive
|
||||
// input streams are treated as auxiliary input streams. The auxiliary input
|
||||
// streams are limited to timestamps passed on the main input stream.
|
||||
//
|
||||
class FlowLimiterCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
auto& side_inputs = cc->InputSidePackets();
|
||||
side_inputs.Tag("OPTIONS").Set<FlowLimiterCalculatorOptions>().Optional();
|
||||
cc->Inputs().Tag("OPTIONS").Set<FlowLimiterCalculatorOptions>().Optional();
|
||||
RET_CHECK_GE(cc->Inputs().NumEntries(""), 1);
|
||||
for (int i = 0; i < cc->Inputs().NumEntries(""); ++i) {
|
||||
cc->Inputs().Get("", i).SetAny();
|
||||
cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i)));
|
||||
}
|
||||
cc->Inputs().Get("FINISHED", 0).SetAny();
|
||||
cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Set<int>().Optional();
|
||||
cc->Outputs().Tag("ALLOW").Set<bool>().Optional();
|
||||
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
|
||||
cc->SetProcessTimestampBounds(true);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) final {
|
||||
options_ = cc->Options<FlowLimiterCalculatorOptions>();
|
||||
options_ = tool::RetrieveOptions(options_, cc->InputSidePackets());
|
||||
if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) {
|
||||
options_.set_max_in_flight(
|
||||
cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Get<int>());
|
||||
}
|
||||
input_queues_.resize(cc->Inputs().NumEntries(""));
|
||||
RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs())));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Returns true if an additional frame can be released for processing.
|
||||
// The "ALLOW" output stream indicates this condition at each input frame.
|
||||
bool ProcessingAllowed() {
|
||||
return frames_in_flight_.size() < options_.max_in_flight();
|
||||
}
|
||||
|
||||
// Outputs a packet indicating whether a frame was sent or dropped.
|
||||
void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) {
|
||||
if (cc->Outputs().HasTag("ALLOW")) {
|
||||
cc->Outputs().Tag("ALLOW").AddPacket(MakePacket<bool>(allow).At(ts));
|
||||
}
|
||||
}
|
||||
|
||||
// Sets the timestamp bound or closes an output stream.
|
||||
void SetNextTimestampBound(Timestamp bound, OutputStream* stream) {
|
||||
if (bound > Timestamp::Max()) {
|
||||
stream->Close();
|
||||
} else {
|
||||
stream->SetNextTimestampBound(bound);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns true if a certain timestamp is being processed.
|
||||
bool IsInFlight(Timestamp timestamp) {
|
||||
return std::find(frames_in_flight_.begin(), frames_in_flight_.end(),
|
||||
timestamp) != frames_in_flight_.end();
|
||||
}
|
||||
|
||||
// Releases input packets up to the latest settled input timestamp.
|
||||
void ProcessAuxiliaryInputs(CalculatorContext* cc) {
|
||||
Timestamp settled_bound = cc->Outputs().Get("", 0).NextTimestampBound();
|
||||
for (int i = 1; i < cc->Inputs().NumEntries(""); ++i) {
|
||||
// Release settled frames from each input queue.
|
||||
while (!input_queues_[i].empty() &&
|
||||
input_queues_[i].front().Timestamp() < settled_bound) {
|
||||
Packet packet = input_queues_[i].front();
|
||||
input_queues_[i].pop_front();
|
||||
if (IsInFlight(packet.Timestamp())) {
|
||||
cc->Outputs().Get("", i).AddPacket(packet);
|
||||
}
|
||||
}
|
||||
|
||||
// Propagate each input timestamp bound.
|
||||
if (!input_queues_[i].empty()) {
|
||||
Timestamp bound = input_queues_[i].front().Timestamp();
|
||||
SetNextTimestampBound(bound, &cc->Outputs().Get("", i));
|
||||
} else {
|
||||
Timestamp bound =
|
||||
cc->Inputs().Get("", i).Value().Timestamp().NextAllowedInStream();
|
||||
SetNextTimestampBound(bound, &cc->Outputs().Get("", i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Releases input packets allowed by the max_in_flight constraint.
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
options_ = tool::RetrieveOptions(options_, cc->Inputs());
|
||||
|
||||
// Process the FINISHED input stream.
|
||||
Packet finished_packet = cc->Inputs().Tag("FINISHED").Value();
|
||||
if (finished_packet.Timestamp() == cc->InputTimestamp()) {
|
||||
while (!frames_in_flight_.empty() &&
|
||||
frames_in_flight_.front() <= finished_packet.Timestamp()) {
|
||||
frames_in_flight_.pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
// Process the frame input streams.
|
||||
for (int i = 0; i < cc->Inputs().NumEntries(""); ++i) {
|
||||
Packet packet = cc->Inputs().Get("", i).Value();
|
||||
if (!packet.IsEmpty()) {
|
||||
input_queues_[i].push_back(packet);
|
||||
}
|
||||
}
|
||||
|
||||
// Abandon expired frames in flight. Note that old frames are abandoned
|
||||
// when much newer frame timestamps arrive regardless of elapsed time.
|
||||
TimestampDiff timeout = options_.in_flight_timeout();
|
||||
Timestamp latest_ts = cc->Inputs().Get("", 0).Value().Timestamp();
|
||||
if (timeout > 0 && latest_ts == cc->InputTimestamp() &&
|
||||
latest_ts < Timestamp::Max()) {
|
||||
while (!frames_in_flight_.empty() &&
|
||||
(latest_ts - frames_in_flight_.front()) > timeout) {
|
||||
frames_in_flight_.pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
// Release allowed frames from the main input queue.
|
||||
auto& input_queue = input_queues_[0];
|
||||
while (ProcessingAllowed() && !input_queue.empty()) {
|
||||
Packet packet = input_queue.front();
|
||||
input_queue.pop_front();
|
||||
cc->Outputs().Get("", 0).AddPacket(packet);
|
||||
SendAllow(true, packet.Timestamp(), cc);
|
||||
frames_in_flight_.push_back(packet.Timestamp());
|
||||
}
|
||||
|
||||
// Limit the number of queued frames.
|
||||
// Note that frames can be dropped after frames are released because
|
||||
// frame-packets and FINISH-packets never arrive in the same Process call.
|
||||
while (input_queue.size() > options_.max_in_queue()) {
|
||||
Packet packet = input_queue.front();
|
||||
input_queue.pop_front();
|
||||
SendAllow(false, packet.Timestamp(), cc);
|
||||
}
|
||||
|
||||
// Propagate the input timestamp bound.
|
||||
if (!input_queue.empty()) {
|
||||
Timestamp bound = input_queue.front().Timestamp();
|
||||
SetNextTimestampBound(bound, &cc->Outputs().Get("", 0));
|
||||
} else {
|
||||
Timestamp bound =
|
||||
cc->Inputs().Get("", 0).Value().Timestamp().NextAllowedInStream();
|
||||
SetNextTimestampBound(bound, &cc->Outputs().Get("", 0));
|
||||
if (cc->Outputs().HasTag("ALLOW")) {
|
||||
SetNextTimestampBound(bound, &cc->Outputs().Tag("ALLOW"));
|
||||
}
|
||||
}
|
||||
|
||||
ProcessAuxiliaryInputs(cc);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
FlowLimiterCalculatorOptions options_;
|
||||
std::vector<std::deque<Packet>> input_queues_;
|
||||
std::deque<Timestamp> frames_in_flight_;
|
||||
};
|
||||
REGISTER_CALCULATOR(FlowLimiterCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,40 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
option objc_class_prefix = "MediaPipe";
|
||||
|
||||
message FlowLimiterCalculatorOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional FlowLimiterCalculatorOptions ext = 326963320;
|
||||
}
|
||||
|
||||
// The maximum number of frames released for processing at one time.
|
||||
// The default value limits to 1 frame processing at a time.
|
||||
optional int32 max_in_flight = 1 [default = 1];
|
||||
|
||||
// The maximum number of frames queued waiting for processing.
|
||||
// The default value limits to 1 frame awaiting processing.
|
||||
optional int32 max_in_queue = 2 [default = 0];
|
||||
|
||||
// The maximum time in microseconds to wait for a frame to finish processing.
|
||||
// The default value stops waiting after 1 sec.
|
||||
// The value 0 specifies no timeout.
|
||||
optional int64 in_flight_timeout = 3 [default = 1000000];
|
||||
}
|
||||
|
|
@ -1,760 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/time/clock.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "mediapipe/calculators/core/flow_limiter_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/framework/tool/simulation_clock.h"
|
||||
#include "mediapipe/framework/tool/simulation_clock_executor.h"
|
||||
#include "mediapipe/framework/tool/sink.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
// A simple Semaphore for synchronizing test threads.
|
||||
class AtomicSemaphore {
|
||||
public:
|
||||
AtomicSemaphore(int64_t supply) : supply_(supply) {}
|
||||
void Acquire(int64_t amount) {
|
||||
while (supply_.fetch_sub(amount) - amount < 0) {
|
||||
Release(amount);
|
||||
}
|
||||
}
|
||||
void Release(int64_t amount) { supply_.fetch_add(amount); }
|
||||
|
||||
private:
|
||||
std::atomic<int64_t> supply_;
|
||||
};
|
||||
|
||||
// Returns the timestamp values for a vector of Packets.
|
||||
std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
|
||||
std::vector<int64> result;
|
||||
for (const Packet& packet : packets) {
|
||||
result.push_back(packet.Timestamp().Value());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the packet values for a vector of Packets.
|
||||
template <typename T>
|
||||
std::vector<T> PacketValues(const std::vector<Packet>& packets) {
|
||||
std::vector<T> result;
|
||||
for (const Packet& packet : packets) {
|
||||
result.push_back(packet.Get<T>());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// A Calculator::Process callback function.
|
||||
typedef std::function<absl::Status(const InputStreamShardSet&,
|
||||
OutputStreamShardSet*)>
|
||||
ProcessFunction;
|
||||
|
||||
// A testing callback function that passes through all packets.
|
||||
absl::Status PassthroughFunction(const InputStreamShardSet& inputs,
|
||||
OutputStreamShardSet* outputs) {
|
||||
for (int i = 0; i < inputs.NumEntries(); ++i) {
|
||||
if (!inputs.Index(i).Value().IsEmpty()) {
|
||||
outputs->Index(i).AddPacket(inputs.Index(i).Value());
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Tests demonstrating an FlowLimiterCalculator operating in a cyclic graph.
|
||||
class FlowLimiterCalculatorSemaphoreTest : public testing::Test {
|
||||
public:
|
||||
FlowLimiterCalculatorSemaphoreTest() : exit_semaphore_(0) {}
|
||||
|
||||
void SetUp() override {
|
||||
graph_config_ = InflightGraphConfig();
|
||||
tool::AddVectorSink("out_1", &graph_config_, &out_1_packets_);
|
||||
}
|
||||
|
||||
void InitializeGraph(int max_in_flight) {
|
||||
ProcessFunction semaphore_1_func = [&](const InputStreamShardSet& inputs,
|
||||
OutputStreamShardSet* outputs) {
|
||||
exit_semaphore_.Acquire(1);
|
||||
return PassthroughFunction(inputs, outputs);
|
||||
};
|
||||
FlowLimiterCalculatorOptions options;
|
||||
options.set_max_in_flight(max_in_flight);
|
||||
options.set_max_in_queue(1);
|
||||
MP_ASSERT_OK(graph_.Initialize(
|
||||
graph_config_, {
|
||||
{"limiter_options", Adopt(new auto(options))},
|
||||
{"callback_1", Adopt(new auto(semaphore_1_func))},
|
||||
}));
|
||||
|
||||
allow_poller_.reset(
|
||||
new OutputStreamPoller(graph_.AddOutputStreamPoller("allow").value()));
|
||||
}
|
||||
|
||||
// Adds a packet to a graph input stream.
|
||||
void AddPacket(const std::string& input_name, int value) {
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream(
|
||||
input_name, MakePacket<int>(value).At(Timestamp(value))));
|
||||
}
|
||||
|
||||
// A calculator graph starting with an FlowLimiterCalculator and
|
||||
// ending with a InFlightFinishCalculator.
|
||||
// Back-edge "finished" limits processing to one frame in-flight.
|
||||
// The LambdaCalculator is used to keep certain frames in flight.
|
||||
CalculatorGraphConfig InflightGraphConfig() {
|
||||
return ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: 'in_1'
|
||||
node {
|
||||
calculator: 'FlowLimiterCalculator'
|
||||
input_side_packet: 'OPTIONS:limiter_options'
|
||||
input_stream: 'in_1'
|
||||
input_stream: 'FINISHED:out_1'
|
||||
input_stream_info: { tag_index: 'FINISHED' back_edge: true }
|
||||
output_stream: 'in_1_sampled'
|
||||
output_stream: 'ALLOW:allow'
|
||||
}
|
||||
node {
|
||||
calculator: 'LambdaCalculator'
|
||||
input_side_packet: 'callback_1'
|
||||
input_stream: 'in_1_sampled'
|
||||
output_stream: 'out_1'
|
||||
}
|
||||
)pb");
|
||||
}
|
||||
|
||||
protected:
|
||||
CalculatorGraphConfig graph_config_;
|
||||
CalculatorGraph graph_;
|
||||
AtomicSemaphore exit_semaphore_;
|
||||
std::vector<Packet> out_1_packets_;
|
||||
std::unique_ptr<OutputStreamPoller> allow_poller_;
|
||||
};
|
||||
|
||||
// A test demonstrating an FlowLimiterCalculator operating in a cyclic
|
||||
// graph. This test shows that:
|
||||
//
|
||||
// (1) Frames exceeding the queue size are dropped.
|
||||
// (2) The "ALLOW" signal is produced.
|
||||
// (3) Timestamps are passed through unaltered.
|
||||
//
|
||||
TEST_F(FlowLimiterCalculatorSemaphoreTest, FramesDropped) {
|
||||
InitializeGraph(1);
|
||||
MP_ASSERT_OK(graph_.StartRun({}));
|
||||
|
||||
auto send_packet = [this](const std::string& input_name, int64 n) {
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream(
|
||||
input_name, MakePacket<int64>(n).At(Timestamp(n))));
|
||||
};
|
||||
|
||||
Packet allow_packet;
|
||||
send_packet("in_1", 0);
|
||||
for (int i = 0; i < 9; i++) {
|
||||
EXPECT_TRUE(allow_poller_->Next(&allow_packet));
|
||||
EXPECT_TRUE(allow_packet.Get<bool>());
|
||||
// This input should wait in the limiter input queue.
|
||||
send_packet("in_1", i * 10 + 5);
|
||||
// This input should drop the previous input.
|
||||
send_packet("in_1", i * 10 + 10);
|
||||
EXPECT_TRUE(allow_poller_->Next(&allow_packet));
|
||||
EXPECT_FALSE(allow_packet.Get<bool>());
|
||||
exit_semaphore_.Release(1);
|
||||
}
|
||||
exit_semaphore_.Release(1);
|
||||
MP_EXPECT_OK(graph_.CloseInputStream("in_1"));
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
|
||||
// All output streams are closed and all output packets are delivered,
|
||||
// with stream "in_1" closed.
|
||||
EXPECT_EQ(10, out_1_packets_.size());
|
||||
|
||||
// Timestamps have not been altered.
|
||||
EXPECT_EQ(PacketValues<int64>(out_1_packets_),
|
||||
TimestampValues(out_1_packets_));
|
||||
|
||||
// Extra inputs on in_1 have been dropped.
|
||||
EXPECT_EQ(TimestampValues(out_1_packets_),
|
||||
(std::vector<int64>{0, 10, 20, 30, 40, 50, 60, 70, 80, 90}));
|
||||
}
|
||||
|
||||
// A calculator that sleeps during Process.
|
||||
class SleepCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Tag("PACKET").SetAny();
|
||||
cc->Outputs().Tag("PACKET").SetSameAs(&cc->Inputs().Tag("PACKET"));
|
||||
cc->InputSidePackets().Tag("SLEEP_TIME").Set<int64>();
|
||||
cc->InputSidePackets().Tag("WARMUP_TIME").Set<int64>();
|
||||
cc->InputSidePackets().Tag("CLOCK").Set<mediapipe::Clock*>();
|
||||
cc->SetTimestampOffset(0);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) final {
|
||||
clock_ = cc->InputSidePackets().Tag("CLOCK").Get<mediapipe::Clock*>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
++packet_count;
|
||||
absl::Duration sleep_time = absl::Microseconds(
|
||||
packet_count == 1
|
||||
? cc->InputSidePackets().Tag("WARMUP_TIME").Get<int64>()
|
||||
: cc->InputSidePackets().Tag("SLEEP_TIME").Get<int64>());
|
||||
clock_->Sleep(sleep_time);
|
||||
cc->Outputs().Tag("PACKET").AddPacket(cc->Inputs().Tag("PACKET").Value());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
::mediapipe::Clock* clock_ = nullptr;
|
||||
int packet_count = 0;
|
||||
};
|
||||
REGISTER_CALCULATOR(SleepCalculator);
|
||||
|
||||
// A calculator that drops a packet occasionally.
|
||||
// Drops the 3rd packet, and optionally the corresponding timestamp bound.
|
||||
class DropCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Tag("PACKET").SetAny();
|
||||
cc->Outputs().Tag("PACKET").SetSameAs(&cc->Inputs().Tag("PACKET"));
|
||||
cc->InputSidePackets().Tag("DROP_TIMESTAMPS").Set<bool>();
|
||||
cc->SetProcessTimestampBounds(true);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
if (!cc->Inputs().Tag("PACKET").Value().IsEmpty()) {
|
||||
++packet_count;
|
||||
}
|
||||
bool drop = (packet_count == 3);
|
||||
if (!drop && !cc->Inputs().Tag("PACKET").Value().IsEmpty()) {
|
||||
cc->Outputs().Tag("PACKET").AddPacket(cc->Inputs().Tag("PACKET").Value());
|
||||
}
|
||||
if (!drop || !cc->InputSidePackets().Tag("DROP_TIMESTAMPS").Get<bool>()) {
|
||||
cc->Outputs().Tag("PACKET").SetNextTimestampBound(
|
||||
cc->InputTimestamp().NextAllowedInStream());
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
int packet_count = 0;
|
||||
};
|
||||
REGISTER_CALCULATOR(DropCalculator);
|
||||
|
||||
// Tests demonstrating an FlowLimiterCalculator processing FINISHED timestamps.
|
||||
class FlowLimiterCalculatorTest : public testing::Test {
|
||||
protected:
|
||||
CalculatorGraphConfig InflightGraphConfig() {
|
||||
return ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: 'in_1'
|
||||
node {
|
||||
calculator: 'FlowLimiterCalculator'
|
||||
input_side_packet: 'OPTIONS:limiter_options'
|
||||
input_stream: 'in_1'
|
||||
input_stream: 'FINISHED:out_1'
|
||||
input_stream_info: { tag_index: 'FINISHED' back_edge: true }
|
||||
output_stream: 'in_1_sampled'
|
||||
output_stream: 'ALLOW:allow'
|
||||
}
|
||||
node {
|
||||
calculator: 'SleepCalculator'
|
||||
input_side_packet: 'WARMUP_TIME:warmup_time'
|
||||
input_side_packet: 'SLEEP_TIME:sleep_time'
|
||||
input_side_packet: 'CLOCK:clock'
|
||||
input_stream: 'PACKET:in_1_sampled'
|
||||
output_stream: 'PACKET:out_1_sampled'
|
||||
}
|
||||
node {
|
||||
calculator: 'DropCalculator'
|
||||
input_side_packet: "DROP_TIMESTAMPS:drop_timesamps"
|
||||
input_stream: 'PACKET:out_1_sampled'
|
||||
output_stream: 'PACKET:out_1'
|
||||
}
|
||||
)pb");
|
||||
}
|
||||
|
||||
// Parse an absl::Time from RFC3339 format.
|
||||
absl::Time ParseTime(const std::string& date_time_str) {
|
||||
absl::Time result;
|
||||
absl::ParseTime(absl::RFC3339_sec, date_time_str, &result, nullptr);
|
||||
return result;
|
||||
}
|
||||
|
||||
// The point in simulated time when the test starts.
|
||||
absl::Time StartTime() { return ParseTime("2020-11-03T20:00:00Z"); }
|
||||
|
||||
// Initialize the test clock to follow simulated time.
|
||||
void SetUpSimulationClock() {
|
||||
auto executor = std::make_shared<SimulationClockExecutor>(8);
|
||||
simulation_clock_ = executor->GetClock();
|
||||
clock_ = simulation_clock_.get();
|
||||
simulation_clock_->ThreadStart();
|
||||
clock_->SleepUntil(StartTime());
|
||||
simulation_clock_->ThreadFinish();
|
||||
MP_ASSERT_OK(graph_.SetExecutor("", executor));
|
||||
}
|
||||
|
||||
// Initialize the test clock to follow wall time.
|
||||
void SetUpRealClock() { clock_ = mediapipe::Clock::RealClock(); }
|
||||
|
||||
// Create a few mediapipe input Packets holding ints.
|
||||
void SetUpInputData() {
|
||||
for (int i = 0; i < 100; ++i) {
|
||||
input_packets_.push_back(MakePacket<int>(i).At(Timestamp(i * 10000)));
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
CalculatorGraph graph_;
|
||||
mediapipe::Clock* clock_;
|
||||
std::shared_ptr<SimulationClock> simulation_clock_;
|
||||
std::vector<Packet> input_packets_;
|
||||
std::vector<Packet> out_1_packets_;
|
||||
std::vector<Packet> allow_packets_;
|
||||
};
|
||||
|
||||
// Shows that "FINISHED" can be indicated with either a packet or a timestamp
|
||||
// bound. DropCalculator periodically drops one packet but always propagates
|
||||
// the timestamp bound. Input packets are released or dropped promptly after
|
||||
// each "FINISH" packet or a timestamp bound arrives.
|
||||
TEST_F(FlowLimiterCalculatorTest, FinishedTimestamps) {
|
||||
// Configure the test.
|
||||
SetUpInputData();
|
||||
SetUpSimulationClock();
|
||||
CalculatorGraphConfig graph_config = InflightGraphConfig();
|
||||
auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(R"pb(
|
||||
max_in_flight: 1
|
||||
max_in_queue: 1
|
||||
)pb");
|
||||
std::map<std::string, Packet> side_packets = {
|
||||
{"limiter_options",
|
||||
MakePacket<FlowLimiterCalculatorOptions>(limiter_options)},
|
||||
{"warmup_time", MakePacket<int64>(22000)},
|
||||
{"sleep_time", MakePacket<int64>(22000)},
|
||||
{"drop_timesamps", MakePacket<bool>(false)},
|
||||
{"clock", MakePacket<mediapipe::Clock*>(clock_)},
|
||||
};
|
||||
|
||||
// Start the graph.
|
||||
MP_ASSERT_OK(graph_.Initialize(graph_config));
|
||||
MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) {
|
||||
out_1_packets_.push_back(p);
|
||||
return absl::OkStatus();
|
||||
}));
|
||||
MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) {
|
||||
allow_packets_.push_back(p);
|
||||
return absl::OkStatus();
|
||||
}));
|
||||
simulation_clock_->ThreadStart();
|
||||
MP_ASSERT_OK(graph_.StartRun(side_packets));
|
||||
|
||||
// Add 9 input packets.
|
||||
// 1. packet-0 is released,
|
||||
// 2. packet-1 is queued,
|
||||
// 3. packet-2 is queued and packet-1 is dropped,
|
||||
// 4. packet-2 is released, and so forth.
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[0]));
|
||||
clock_->Sleep(absl::Microseconds(1));
|
||||
EXPECT_EQ(allow_packets_.size(), 1);
|
||||
EXPECT_EQ(allow_packets_.back().Get<bool>(), true);
|
||||
clock_->Sleep(absl::Microseconds(10000));
|
||||
for (int i = 1; i < 8; i += 2) {
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i]));
|
||||
clock_->Sleep(absl::Microseconds(10000));
|
||||
EXPECT_EQ(allow_packets_.size(), i);
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i + 1]));
|
||||
clock_->Sleep(absl::Microseconds(1));
|
||||
EXPECT_EQ(allow_packets_.size(), i + 1);
|
||||
EXPECT_EQ(allow_packets_.back().Get<bool>(), false);
|
||||
clock_->Sleep(absl::Microseconds(10000));
|
||||
EXPECT_EQ(allow_packets_.size(), i + 2);
|
||||
EXPECT_EQ(allow_packets_.back().Get<bool>(), true);
|
||||
}
|
||||
|
||||
// Finish the graph.
|
||||
MP_EXPECT_OK(graph_.CloseAllPacketSources());
|
||||
clock_->Sleep(absl::Microseconds(40000));
|
||||
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||
simulation_clock_->ThreadFinish();
|
||||
|
||||
// Validate the output.
|
||||
// input_packets_[4] is dropped by the DropCalculator.
|
||||
std::vector<Packet> expected_output = {input_packets_[0], input_packets_[2],
|
||||
input_packets_[6], input_packets_[8]};
|
||||
EXPECT_EQ(out_1_packets_, expected_output);
|
||||
}
|
||||
|
||||
// Shows that an output packet can be lost completely, and the
|
||||
// FlowLimiterCalculator will stop waiting for it after in_flight_timeout.
|
||||
// DropCalculator completely loses one packet including its timestamp bound.
|
||||
// FlowLimiterCalculator waits 100 ms, and then starts releasing packets again.
|
||||
TEST_F(FlowLimiterCalculatorTest, FinishedLost) {
|
||||
// Configure the test.
|
||||
SetUpInputData();
|
||||
SetUpSimulationClock();
|
||||
CalculatorGraphConfig graph_config = InflightGraphConfig();
|
||||
auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(R"pb(
|
||||
max_in_flight: 1
|
||||
max_in_queue: 1
|
||||
in_flight_timeout: 100000 # 100 ms
|
||||
)pb");
|
||||
std::map<std::string, Packet> side_packets = {
|
||||
{"limiter_options",
|
||||
MakePacket<FlowLimiterCalculatorOptions>(limiter_options)},
|
||||
{"warmup_time", MakePacket<int64>(22000)},
|
||||
{"sleep_time", MakePacket<int64>(22000)},
|
||||
{"drop_timesamps", MakePacket<bool>(true)},
|
||||
{"clock", MakePacket<mediapipe::Clock*>(clock_)},
|
||||
};
|
||||
|
||||
// Start the graph.
|
||||
MP_ASSERT_OK(graph_.Initialize(graph_config));
|
||||
MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) {
|
||||
out_1_packets_.push_back(p);
|
||||
return absl::OkStatus();
|
||||
}));
|
||||
MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) {
|
||||
allow_packets_.push_back(p);
|
||||
return absl::OkStatus();
|
||||
}));
|
||||
simulation_clock_->ThreadStart();
|
||||
MP_ASSERT_OK(graph_.StartRun(side_packets));
|
||||
|
||||
// Add 21 input packets.
|
||||
// 1. packet-0 is released, packet-1 queued and dropped, and so forth.
|
||||
// 2. packet-4 is lost by DropCalculator.
|
||||
// 3. packet-5 through 13 are dropped while waiting for packet-4.
|
||||
// 4. packet-4 expires and queued packet-14 is released.
|
||||
// 5. packet-17, 19, and 20 are released on time.
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[0]));
|
||||
clock_->Sleep(absl::Microseconds(10000));
|
||||
for (int i = 1; i < 21; ++i) {
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i]));
|
||||
clock_->Sleep(absl::Microseconds(10000));
|
||||
}
|
||||
|
||||
// Finish the graph.
|
||||
MP_EXPECT_OK(graph_.CloseAllPacketSources());
|
||||
clock_->Sleep(absl::Microseconds(40000));
|
||||
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||
simulation_clock_->ThreadFinish();
|
||||
|
||||
// Validate the output.
|
||||
// input_packets_[4] is lost by the DropCalculator.
|
||||
std::vector<Packet> expected_output = {
|
||||
input_packets_[0], input_packets_[2], input_packets_[14],
|
||||
input_packets_[17], input_packets_[19], input_packets_[20],
|
||||
};
|
||||
EXPECT_EQ(out_1_packets_, expected_output);
|
||||
}
|
||||
|
||||
// Shows what happens when a finish packet is delayed beyond in_flight_timeout.
|
||||
// After in_flight_timeout, FlowLimiterCalculator continues releasing packets.
|
||||
// Temporarily, more than max_in_flight frames are in flight.
|
||||
// Eventually, the number of frames in flight returns to max_in_flight.
|
||||
TEST_F(FlowLimiterCalculatorTest, FinishedDelayed) {
|
||||
// Configure the test.
|
||||
SetUpInputData();
|
||||
SetUpSimulationClock();
|
||||
CalculatorGraphConfig graph_config = InflightGraphConfig();
|
||||
auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(R"pb(
|
||||
max_in_flight: 1
|
||||
max_in_queue: 1
|
||||
in_flight_timeout: 100000 # 100 ms
|
||||
)pb");
|
||||
std::map<std::string, Packet> side_packets = {
|
||||
{"limiter_options",
|
||||
MakePacket<FlowLimiterCalculatorOptions>(limiter_options)},
|
||||
{"warmup_time", MakePacket<int64>(500000)},
|
||||
{"sleep_time", MakePacket<int64>(22000)},
|
||||
{"drop_timesamps", MakePacket<bool>(false)},
|
||||
{"clock", MakePacket<mediapipe::Clock*>(clock_)},
|
||||
};
|
||||
|
||||
// Start the graph.
|
||||
MP_ASSERT_OK(graph_.Initialize(graph_config));
|
||||
MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) {
|
||||
out_1_packets_.push_back(p);
|
||||
return absl::OkStatus();
|
||||
}));
|
||||
MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) {
|
||||
allow_packets_.push_back(p);
|
||||
return absl::OkStatus();
|
||||
}));
|
||||
simulation_clock_->ThreadStart();
|
||||
MP_ASSERT_OK(graph_.StartRun(side_packets));
|
||||
|
||||
// Add 71 input packets.
|
||||
// 1. During the 500 ms WARMUP_TIME, the in_flight_timeout releases
|
||||
// packets 0, 10, 20, 30, 40, 50, which are queued at the SleepCalculator.
|
||||
// 2. During the next 120 ms, these 6 packets are processed.
|
||||
// 3. After the graph is finally finished with warmup and the backlog packets,
|
||||
// packets 60 through 70 are released and processed on time.
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[0]));
|
||||
clock_->Sleep(absl::Microseconds(10000));
|
||||
for (int i = 1; i < 71; ++i) {
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i]));
|
||||
clock_->Sleep(absl::Microseconds(10000));
|
||||
}
|
||||
|
||||
// Finish the graph.
|
||||
MP_EXPECT_OK(graph_.CloseAllPacketSources());
|
||||
clock_->Sleep(absl::Microseconds(40000));
|
||||
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||
simulation_clock_->ThreadFinish();
|
||||
|
||||
// Validate the output.
|
||||
// The graph is warming up or backlogged until packet 60.
|
||||
std::vector<Packet> expected_output = {
|
||||
input_packets_[0], input_packets_[10], input_packets_[30],
|
||||
input_packets_[40], input_packets_[50], input_packets_[60],
|
||||
input_packets_[63], input_packets_[65], input_packets_[67],
|
||||
input_packets_[69], input_packets_[70],
|
||||
};
|
||||
EXPECT_EQ(out_1_packets_, expected_output);
|
||||
}
|
||||
|
||||
// Shows that packets on auxiliary input streams are relesed for the same
|
||||
// timestamps as the main input stream, whether the auxiliary packets arrive
|
||||
// early or late.
|
||||
TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) {
|
||||
// Configure the test.
|
||||
SetUpInputData();
|
||||
SetUpSimulationClock();
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: 'in_1'
|
||||
input_stream: 'in_2'
|
||||
node {
|
||||
calculator: 'FlowLimiterCalculator'
|
||||
input_side_packet: 'OPTIONS:limiter_options'
|
||||
input_stream: 'in_1'
|
||||
input_stream: 'in_2'
|
||||
input_stream: 'FINISHED:out_1'
|
||||
input_stream_info: { tag_index: 'FINISHED' back_edge: true }
|
||||
output_stream: 'in_1_sampled'
|
||||
output_stream: 'in_2_sampled'
|
||||
output_stream: 'ALLOW:allow'
|
||||
}
|
||||
node {
|
||||
calculator: 'SleepCalculator'
|
||||
input_side_packet: 'WARMUP_TIME:warmup_time'
|
||||
input_side_packet: 'SLEEP_TIME:sleep_time'
|
||||
input_side_packet: 'CLOCK:clock'
|
||||
input_stream: 'PACKET:in_1_sampled'
|
||||
output_stream: 'PACKET:out_1_sampled'
|
||||
}
|
||||
node {
|
||||
calculator: 'DropCalculator'
|
||||
input_side_packet: "DROP_TIMESTAMPS:drop_timesamps"
|
||||
input_stream: 'PACKET:out_1_sampled'
|
||||
output_stream: 'PACKET:out_1'
|
||||
}
|
||||
)pb");
|
||||
|
||||
auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(R"pb(
|
||||
max_in_flight: 1
|
||||
max_in_queue: 1
|
||||
in_flight_timeout: 100000 # 100 ms
|
||||
)pb");
|
||||
std::map<std::string, Packet> side_packets = {
|
||||
{"limiter_options",
|
||||
MakePacket<FlowLimiterCalculatorOptions>(limiter_options)},
|
||||
{"warmup_time", MakePacket<int64>(22000)},
|
||||
{"sleep_time", MakePacket<int64>(22000)},
|
||||
{"drop_timesamps", MakePacket<bool>(true)},
|
||||
{"clock", MakePacket<mediapipe::Clock*>(clock_)},
|
||||
};
|
||||
|
||||
// Start the graph.
|
||||
MP_ASSERT_OK(graph_.Initialize(graph_config));
|
||||
MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) {
|
||||
out_1_packets_.push_back(p);
|
||||
return absl::OkStatus();
|
||||
}));
|
||||
std::vector<Packet> out_2_packets;
|
||||
MP_EXPECT_OK(graph_.ObserveOutputStream("in_2_sampled", [&](Packet p) {
|
||||
out_2_packets.push_back(p);
|
||||
return absl::OkStatus();
|
||||
}));
|
||||
MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) {
|
||||
allow_packets_.push_back(p);
|
||||
return absl::OkStatus();
|
||||
}));
|
||||
simulation_clock_->ThreadStart();
|
||||
MP_ASSERT_OK(graph_.StartRun(side_packets));
|
||||
|
||||
// Add packets 0..9 to stream in_1, and packets 0..10 to stream in_2.
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[0]));
|
||||
clock_->Sleep(absl::Microseconds(10000));
|
||||
for (int i = 1; i < 10; ++i) {
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i]));
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_2", input_packets_[i - 1]));
|
||||
clock_->Sleep(absl::Microseconds(10000));
|
||||
}
|
||||
|
||||
// Add packets 10..20 to stream in_1, and packets 11..21 to stream in_2.
|
||||
for (int i = 10; i < 21; ++i) {
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_2", input_packets_[i + 1]));
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i]));
|
||||
clock_->Sleep(absl::Microseconds(10000));
|
||||
}
|
||||
|
||||
// Finish the graph run.
|
||||
MP_EXPECT_OK(graph_.CloseAllPacketSources());
|
||||
clock_->Sleep(absl::Microseconds(40000));
|
||||
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||
simulation_clock_->ThreadFinish();
|
||||
|
||||
// Validate the output.
|
||||
// Packet input_packets_[4] is lost by the DropCalculator.
|
||||
std::vector<Packet> expected_output = {
|
||||
input_packets_[0], input_packets_[2], input_packets_[14],
|
||||
input_packets_[17], input_packets_[19], input_packets_[20],
|
||||
};
|
||||
EXPECT_EQ(out_1_packets_, expected_output);
|
||||
// Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled.
|
||||
std::vector<Packet> expected_output_2 = {
|
||||
input_packets_[0], input_packets_[2], input_packets_[4],
|
||||
input_packets_[14], input_packets_[17], input_packets_[19],
|
||||
input_packets_[20],
|
||||
};
|
||||
EXPECT_EQ(out_2_packets, expected_output_2);
|
||||
}
|
||||
|
||||
// Shows how FlowLimiterCalculator releases packets with max_in_queue 0.
|
||||
// Shows how auxiliary input streams still work with max_in_queue 0.
|
||||
// The processing time "sleep_time" is reduced from 22ms to 12ms to create
|
||||
// the same frame rate as FlowLimiterCalculatorTest::TwoInputStreams.
|
||||
TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
|
||||
// Configure the test.
|
||||
SetUpInputData();
|
||||
SetUpSimulationClock();
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: 'in_1'
|
||||
input_stream: 'in_2'
|
||||
node {
|
||||
calculator: 'FlowLimiterCalculator'
|
||||
input_side_packet: 'OPTIONS:limiter_options'
|
||||
input_stream: 'in_1'
|
||||
input_stream: 'in_2'
|
||||
input_stream: 'FINISHED:out_1'
|
||||
input_stream_info: { tag_index: 'FINISHED' back_edge: true }
|
||||
output_stream: 'in_1_sampled'
|
||||
output_stream: 'in_2_sampled'
|
||||
output_stream: 'ALLOW:allow'
|
||||
}
|
||||
node {
|
||||
calculator: 'SleepCalculator'
|
||||
input_side_packet: 'WARMUP_TIME:warmup_time'
|
||||
input_side_packet: 'SLEEP_TIME:sleep_time'
|
||||
input_side_packet: 'CLOCK:clock'
|
||||
input_stream: 'PACKET:in_1_sampled'
|
||||
output_stream: 'PACKET:out_1_sampled'
|
||||
}
|
||||
node {
|
||||
calculator: 'DropCalculator'
|
||||
input_side_packet: "DROP_TIMESTAMPS:drop_timesamps"
|
||||
input_stream: 'PACKET:out_1_sampled'
|
||||
output_stream: 'PACKET:out_1'
|
||||
}
|
||||
)pb");
|
||||
|
||||
auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(R"pb(
|
||||
max_in_flight: 1
|
||||
max_in_queue: 0
|
||||
in_flight_timeout: 100000 # 100 ms
|
||||
)pb");
|
||||
std::map<std::string, Packet> side_packets = {
|
||||
{"limiter_options",
|
||||
MakePacket<FlowLimiterCalculatorOptions>(limiter_options)},
|
||||
{"warmup_time", MakePacket<int64>(12000)},
|
||||
{"sleep_time", MakePacket<int64>(12000)},
|
||||
{"drop_timesamps", MakePacket<bool>(true)},
|
||||
{"clock", MakePacket<mediapipe::Clock*>(clock_)},
|
||||
};
|
||||
|
||||
// Start the graph.
|
||||
MP_ASSERT_OK(graph_.Initialize(graph_config));
|
||||
MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) {
|
||||
out_1_packets_.push_back(p);
|
||||
return absl::OkStatus();
|
||||
}));
|
||||
std::vector<Packet> out_2_packets;
|
||||
MP_EXPECT_OK(graph_.ObserveOutputStream("in_2_sampled", [&](Packet p) {
|
||||
out_2_packets.push_back(p);
|
||||
return absl::OkStatus();
|
||||
}));
|
||||
MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) {
|
||||
allow_packets_.push_back(p);
|
||||
return absl::OkStatus();
|
||||
}));
|
||||
simulation_clock_->ThreadStart();
|
||||
MP_ASSERT_OK(graph_.StartRun(side_packets));
|
||||
|
||||
// Add packets 0..9 to stream in_1, and packets 0..10 to stream in_2.
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[0]));
|
||||
clock_->Sleep(absl::Microseconds(10000));
|
||||
for (int i = 1; i < 10; ++i) {
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i]));
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_2", input_packets_[i - 1]));
|
||||
clock_->Sleep(absl::Microseconds(10000));
|
||||
}
|
||||
|
||||
// Add packets 10..20 to stream in_1, and packets 11..21 to stream in_2.
|
||||
for (int i = 10; i < 21; ++i) {
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_2", input_packets_[i + 1]));
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i]));
|
||||
clock_->Sleep(absl::Microseconds(10000));
|
||||
}
|
||||
|
||||
// Finish the graph run.
|
||||
MP_EXPECT_OK(graph_.CloseAllPacketSources());
|
||||
clock_->Sleep(absl::Microseconds(40000));
|
||||
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||
simulation_clock_->ThreadFinish();
|
||||
|
||||
// Validate the output.
|
||||
// Packet input_packets_[4] is lost by the DropCalculator.
|
||||
std::vector<Packet> expected_output = {
|
||||
input_packets_[0], input_packets_[2], input_packets_[15],
|
||||
input_packets_[17], input_packets_[19],
|
||||
};
|
||||
EXPECT_EQ(out_1_packets_, expected_output);
|
||||
// Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled.
|
||||
std::vector<Packet> expected_output_2 = {
|
||||
input_packets_[0], input_packets_[2], input_packets_[4],
|
||||
input_packets_[15], input_packets_[17], input_packets_[19],
|
||||
};
|
||||
EXPECT_EQ(out_2_packets, expected_output_2);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,219 +0,0 @@
|
|||
// Copyright 2019-2020 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/core/gate_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/util/header_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
enum GateState {
|
||||
GATE_UNINITIALIZED,
|
||||
GATE_ALLOW,
|
||||
GATE_DISALLOW,
|
||||
};
|
||||
|
||||
std::string ToString(GateState state) {
|
||||
switch (state) {
|
||||
case GATE_UNINITIALIZED:
|
||||
return "UNINITIALIZED";
|
||||
case GATE_ALLOW:
|
||||
return "ALLOW";
|
||||
case GATE_DISALLOW:
|
||||
return "DISALLOW";
|
||||
}
|
||||
DLOG(FATAL) << "Unknown GateState";
|
||||
return "UNKNOWN";
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Controls whether or not the input packets are passed further along the graph.
|
||||
// Takes multiple data input streams and either an ALLOW or a DISALLOW control
|
||||
// input stream. It outputs an output stream for each input stream that is not
|
||||
// ALLOW or DISALLOW as well as an optional STATE_CHANGE stream which downstream
|
||||
// calculators can use to respond to state-change events.
|
||||
//
|
||||
// If the current ALLOW packet is set to true, the input packets are passed to
|
||||
// their corresponding output stream unchanged. If the ALLOW packet is set to
|
||||
// false, the current input packet is NOT passed to the output stream. If using
|
||||
// DISALLOW, the behavior is opposite of ALLOW.
|
||||
//
|
||||
// By default, an empty packet in the ALLOW or DISALLOW input stream indicates
|
||||
// disallowing the corresponding packets in other input streams. The behavior
|
||||
// can be inverted with a calculator option.
|
||||
//
|
||||
// ALLOW or DISALLOW can also be specified as an input side packet. The rules
|
||||
// for evaluation remain the same as above.
|
||||
//
|
||||
// ALLOW/DISALLOW inputs must be specified either using input stream or
|
||||
// via input side packet but not both.
|
||||
//
|
||||
// Intended to be used with the default input stream handler, which synchronizes
|
||||
// all data input streams with the ALLOW/DISALLOW control input stream.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "GateCalculator"
|
||||
// input_side_packet: "ALLOW:allow" or "DISALLOW:disallow"
|
||||
// input_stream: "input_stream0"
|
||||
// input_stream: "input_stream1"
|
||||
// input_stream: "input_streamN"
|
||||
// input_stream: "ALLOW:allow" or "DISALLOW:disallow"
|
||||
// output_stream: "STATE_CHANGE:state_change"
|
||||
// output_stream: "output_stream0"
|
||||
// output_stream: "output_stream1"
|
||||
// output_stream: "output_streamN"
|
||||
// }
|
||||
class GateCalculator : public CalculatorBase {
|
||||
public:
|
||||
GateCalculator() {}
|
||||
|
||||
static absl::Status CheckAndInitAllowDisallowInputs(CalculatorContract* cc) {
|
||||
bool input_via_side_packet = cc->InputSidePackets().HasTag("ALLOW") ||
|
||||
cc->InputSidePackets().HasTag("DISALLOW");
|
||||
bool input_via_stream =
|
||||
cc->Inputs().HasTag("ALLOW") || cc->Inputs().HasTag("DISALLOW");
|
||||
// Only one of input_side_packet or input_stream may specify ALLOW/DISALLOW
|
||||
// input.
|
||||
RET_CHECK(input_via_side_packet ^ input_via_stream);
|
||||
|
||||
if (input_via_side_packet) {
|
||||
RET_CHECK(cc->InputSidePackets().HasTag("ALLOW") ^
|
||||
cc->InputSidePackets().HasTag("DISALLOW"));
|
||||
|
||||
if (cc->InputSidePackets().HasTag("ALLOW")) {
|
||||
cc->InputSidePackets().Tag("ALLOW").Set<bool>();
|
||||
} else {
|
||||
cc->InputSidePackets().Tag("DISALLOW").Set<bool>();
|
||||
}
|
||||
} else {
|
||||
RET_CHECK(cc->Inputs().HasTag("ALLOW") ^ cc->Inputs().HasTag("DISALLOW"));
|
||||
|
||||
if (cc->Inputs().HasTag("ALLOW")) {
|
||||
cc->Inputs().Tag("ALLOW").Set<bool>();
|
||||
} else {
|
||||
cc->Inputs().Tag("DISALLOW").Set<bool>();
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
RET_CHECK_OK(CheckAndInitAllowDisallowInputs(cc));
|
||||
|
||||
const int num_data_streams = cc->Inputs().NumEntries("");
|
||||
RET_CHECK_GE(num_data_streams, 1);
|
||||
RET_CHECK_EQ(cc->Outputs().NumEntries(""), num_data_streams)
|
||||
<< "Number of data output streams must match with data input streams.";
|
||||
|
||||
for (int i = 0; i < num_data_streams; ++i) {
|
||||
cc->Inputs().Get("", i).SetAny();
|
||||
cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i));
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("STATE_CHANGE")) {
|
||||
cc->Outputs().Tag("STATE_CHANGE").Set<bool>();
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) final {
|
||||
use_side_packet_for_allow_disallow_ = false;
|
||||
if (cc->InputSidePackets().HasTag("ALLOW")) {
|
||||
use_side_packet_for_allow_disallow_ = true;
|
||||
allow_by_side_packet_decision_ =
|
||||
cc->InputSidePackets().Tag("ALLOW").Get<bool>();
|
||||
} else if (cc->InputSidePackets().HasTag("DISALLOW")) {
|
||||
use_side_packet_for_allow_disallow_ = true;
|
||||
allow_by_side_packet_decision_ =
|
||||
!cc->InputSidePackets().Tag("DISALLOW").Get<bool>();
|
||||
}
|
||||
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
num_data_streams_ = cc->Inputs().NumEntries("");
|
||||
last_gate_state_ = GATE_UNINITIALIZED;
|
||||
RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &cc->Outputs()));
|
||||
|
||||
const auto& options = cc->Options<::mediapipe::GateCalculatorOptions>();
|
||||
empty_packets_as_allow_ = options.empty_packets_as_allow();
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
bool allow = empty_packets_as_allow_;
|
||||
if (use_side_packet_for_allow_disallow_) {
|
||||
allow = allow_by_side_packet_decision_;
|
||||
} else {
|
||||
if (cc->Inputs().HasTag("ALLOW") &&
|
||||
!cc->Inputs().Tag("ALLOW").IsEmpty()) {
|
||||
allow = cc->Inputs().Tag("ALLOW").Get<bool>();
|
||||
}
|
||||
if (cc->Inputs().HasTag("DISALLOW") &&
|
||||
!cc->Inputs().Tag("DISALLOW").IsEmpty()) {
|
||||
allow = !cc->Inputs().Tag("DISALLOW").Get<bool>();
|
||||
}
|
||||
}
|
||||
const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW;
|
||||
|
||||
if (cc->Outputs().HasTag("STATE_CHANGE")) {
|
||||
if (last_gate_state_ != GATE_UNINITIALIZED &&
|
||||
last_gate_state_ != new_gate_state) {
|
||||
VLOG(2) << "State transition in " << cc->NodeName() << " @ "
|
||||
<< cc->InputTimestamp().Value() << " from "
|
||||
<< ToString(last_gate_state_) << " to "
|
||||
<< ToString(new_gate_state);
|
||||
cc->Outputs()
|
||||
.Tag("STATE_CHANGE")
|
||||
.AddPacket(MakePacket<bool>(allow).At(cc->InputTimestamp()));
|
||||
}
|
||||
}
|
||||
last_gate_state_ = new_gate_state;
|
||||
|
||||
if (!allow) {
|
||||
// Close the output streams if the gate will be permanently closed.
|
||||
// Prevents buffering in calculators whose parents do no use SetOffset.
|
||||
for (int i = 0; i < num_data_streams_; ++i) {
|
||||
if (!cc->Outputs().Get("", i).IsClosed() &&
|
||||
use_side_packet_for_allow_disallow_) {
|
||||
cc->Outputs().Get("", i).Close();
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Process data streams.
|
||||
for (int i = 0; i < num_data_streams_; ++i) {
|
||||
if (!cc->Inputs().Get("", i).IsEmpty()) {
|
||||
cc->Outputs().Get("", i).AddPacket(cc->Inputs().Get("", i).Value());
|
||||
}
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
GateState last_gate_state_ = GATE_UNINITIALIZED;
|
||||
int num_data_streams_;
|
||||
bool empty_packets_as_allow_;
|
||||
bool use_side_packet_for_allow_disallow_;
|
||||
bool allow_by_side_packet_decision_;
|
||||
};
|
||||
REGISTER_CALCULATOR(GateCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
// Copyright 2019-2020 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
option objc_class_prefix = "MediaPipe";
|
||||
|
||||
message GateCalculatorOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional GateCalculatorOptions ext = 261754847;
|
||||
}
|
||||
|
||||
// By default an empty packet in the ALLOW or DISALLOW input stream indicates
|
||||
// disallowing the corresponding packets in the data input streams. Setting
|
||||
// this option to true inverts that, allowing the data packets to go through.
|
||||
optional bool empty_packets_as_allow = 1;
|
||||
}
|
||||
|
|
@ -1,334 +0,0 @@
|
|||
// Copyright 2019-2020 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
class GateCalculatorTest : public ::testing::Test {
|
||||
protected:
|
||||
// Helper to run a graph and return status.
|
||||
static absl::Status RunGraph(const std::string& proto) {
|
||||
auto runner = absl::make_unique<CalculatorRunner>(
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(proto));
|
||||
return runner->Run();
|
||||
}
|
||||
|
||||
// Use this when ALLOW/DISALLOW input is provided as a side packet.
|
||||
void RunTimeStep(int64 timestamp, bool stream_payload) {
|
||||
runner_->MutableInputs()->Get("", 0).packets.push_back(
|
||||
MakePacket<bool>(stream_payload).At(Timestamp(timestamp)));
|
||||
MP_ASSERT_OK(runner_->Run()) << "Calculator execution failed.";
|
||||
}
|
||||
|
||||
// Use this when ALLOW/DISALLOW input is provided as an input stream.
|
||||
void RunTimeStep(int64 timestamp, const std::string& control_tag,
|
||||
bool control) {
|
||||
runner_->MutableInputs()->Get("", 0).packets.push_back(
|
||||
MakePacket<bool>(true).At(Timestamp(timestamp)));
|
||||
runner_->MutableInputs()
|
||||
->Tag(control_tag)
|
||||
.packets.push_back(MakePacket<bool>(control).At(Timestamp(timestamp)));
|
||||
MP_ASSERT_OK(runner_->Run()) << "Calculator execution failed.";
|
||||
}
|
||||
|
||||
void SetRunner(const std::string& proto) {
|
||||
runner_ = absl::make_unique<CalculatorRunner>(
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(proto));
|
||||
}
|
||||
|
||||
CalculatorRunner* runner() { return runner_.get(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<CalculatorRunner> runner_;
|
||||
};
|
||||
|
||||
TEST_F(GateCalculatorTest, InvalidInputs) {
|
||||
EXPECT_TRUE(absl::IsInternal(GateCalculatorTest::RunGraph(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "test_input"
|
||||
input_stream: "ALLOW:gating_stream"
|
||||
input_stream: "DISALLOW:gating_stream"
|
||||
output_stream: "test_output"
|
||||
)")));
|
||||
|
||||
EXPECT_TRUE(absl::IsInternal(GateCalculatorTest::RunGraph(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "test_input"
|
||||
input_side_packet: "ALLOW:gating_stream"
|
||||
input_side_packet: "DISALLOW:gating_stream"
|
||||
output_stream: "test_output"
|
||||
)")));
|
||||
|
||||
EXPECT_TRUE(absl::IsInternal(GateCalculatorTest::RunGraph(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "test_input"
|
||||
input_stream: "ALLOW:gating_stream"
|
||||
input_side_packet: "ALLOW:gating_stream"
|
||||
output_stream: "test_output"
|
||||
)")));
|
||||
|
||||
EXPECT_TRUE(absl::IsInternal(GateCalculatorTest::RunGraph(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "test_input"
|
||||
input_stream: "DISALLOW:gating_stream"
|
||||
input_side_packet: "DISALLOW:gating_stream"
|
||||
output_stream: "test_output"
|
||||
)")));
|
||||
|
||||
EXPECT_TRUE(absl::IsInternal(GateCalculatorTest::RunGraph(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "test_input"
|
||||
input_stream: "ALLOW:gating_stream"
|
||||
input_side_packet: "DISALLOW:gating_stream"
|
||||
output_stream: "test_output"
|
||||
)")));
|
||||
|
||||
EXPECT_TRUE(absl::IsInternal(GateCalculatorTest::RunGraph(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "test_input"
|
||||
input_stream: "DISALLOW:gating_stream"
|
||||
input_side_packet: "ALLOW:gating_stream"
|
||||
output_stream: "test_output"
|
||||
)")));
|
||||
}
|
||||
|
||||
TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) {
|
||||
SetRunner(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_side_packet: "ALLOW:gating_stream"
|
||||
input_stream: "test_input"
|
||||
output_stream: "test_output"
|
||||
)");
|
||||
runner()->MutableSidePackets()->Tag("ALLOW") = Adopt(new bool(true));
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
ASSERT_EQ(2, output.size());
|
||||
EXPECT_EQ(kTimestampValue0, output[0].Timestamp().Value());
|
||||
EXPECT_EQ(kTimestampValue1, output[1].Timestamp().Value());
|
||||
EXPECT_EQ(true, output[0].Get<bool>());
|
||||
EXPECT_EQ(false, output[1].Get<bool>());
|
||||
}
|
||||
|
||||
TEST_F(GateCalculatorTest, AllowByDisallowSidePacketSetToFalse) {
|
||||
SetRunner(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_side_packet: "DISALLOW:gating_stream"
|
||||
input_stream: "test_input"
|
||||
output_stream: "test_output"
|
||||
)");
|
||||
runner()->MutableSidePackets()->Tag("DISALLOW") = Adopt(new bool(false));
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
ASSERT_EQ(2, output.size());
|
||||
EXPECT_EQ(kTimestampValue0, output[0].Timestamp().Value());
|
||||
EXPECT_EQ(kTimestampValue1, output[1].Timestamp().Value());
|
||||
EXPECT_EQ(true, output[0].Get<bool>());
|
||||
EXPECT_EQ(false, output[1].Get<bool>());
|
||||
}
|
||||
|
||||
TEST_F(GateCalculatorTest, DisallowByALLOWSidePacketSetToFalse) {
|
||||
SetRunner(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_side_packet: "ALLOW:gating_stream"
|
||||
input_stream: "test_input"
|
||||
output_stream: "test_output"
|
||||
)");
|
||||
runner()->MutableSidePackets()->Tag("ALLOW") = Adopt(new bool(false));
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
ASSERT_EQ(0, output.size());
|
||||
}
|
||||
|
||||
TEST_F(GateCalculatorTest, DisallowByDISALLOWSidePacketSetToTrue) {
|
||||
SetRunner(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_side_packet: "DISALLOW:gating_stream"
|
||||
input_stream: "test_input"
|
||||
output_stream: "test_output"
|
||||
)");
|
||||
runner()->MutableSidePackets()->Tag("DISALLOW") = Adopt(new bool(true));
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
ASSERT_EQ(0, output.size());
|
||||
}
|
||||
|
||||
TEST_F(GateCalculatorTest, Allow) {
|
||||
SetRunner(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "test_input"
|
||||
input_stream: "ALLOW:gating_stream"
|
||||
output_stream: "test_output"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "ALLOW", true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, "ALLOW", false);
|
||||
constexpr int64 kTimestampValue2 = 44;
|
||||
RunTimeStep(kTimestampValue2, "ALLOW", true);
|
||||
constexpr int64 kTimestampValue3 = 45;
|
||||
RunTimeStep(kTimestampValue3, "ALLOW", false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
ASSERT_EQ(2, output.size());
|
||||
EXPECT_EQ(kTimestampValue0, output[0].Timestamp().Value());
|
||||
EXPECT_EQ(kTimestampValue2, output[1].Timestamp().Value());
|
||||
EXPECT_EQ(true, output[0].Get<bool>());
|
||||
EXPECT_EQ(true, output[1].Get<bool>());
|
||||
}
|
||||
|
||||
TEST_F(GateCalculatorTest, Disallow) {
|
||||
SetRunner(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "test_input"
|
||||
input_stream: "DISALLOW:gating_stream"
|
||||
output_stream: "test_output"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "DISALLOW", true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, "DISALLOW", false);
|
||||
constexpr int64 kTimestampValue2 = 44;
|
||||
RunTimeStep(kTimestampValue2, "DISALLOW", true);
|
||||
constexpr int64 kTimestampValue3 = 45;
|
||||
RunTimeStep(kTimestampValue3, "DISALLOW", false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
ASSERT_EQ(2, output.size());
|
||||
EXPECT_EQ(kTimestampValue1, output[0].Timestamp().Value());
|
||||
EXPECT_EQ(kTimestampValue3, output[1].Timestamp().Value());
|
||||
EXPECT_EQ(true, output[0].Get<bool>());
|
||||
EXPECT_EQ(true, output[1].Get<bool>());
|
||||
}
|
||||
|
||||
TEST_F(GateCalculatorTest, AllowWithStateChange) {
|
||||
SetRunner(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "test_input"
|
||||
input_stream: "ALLOW:gating_stream"
|
||||
output_stream: "test_output"
|
||||
output_stream: "STATE_CHANGE:state_changed"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "ALLOW", false);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, "ALLOW", true);
|
||||
constexpr int64 kTimestampValue2 = 44;
|
||||
RunTimeStep(kTimestampValue2, "ALLOW", true);
|
||||
constexpr int64 kTimestampValue3 = 45;
|
||||
RunTimeStep(kTimestampValue3, "ALLOW", false);
|
||||
|
||||
const std::vector<Packet>& output =
|
||||
runner()->Outputs().Get("STATE_CHANGE", 0).packets;
|
||||
ASSERT_EQ(2, output.size());
|
||||
EXPECT_EQ(kTimestampValue1, output[0].Timestamp().Value());
|
||||
EXPECT_EQ(kTimestampValue3, output[1].Timestamp().Value());
|
||||
EXPECT_EQ(true, output[0].Get<bool>()); // Allow.
|
||||
EXPECT_EQ(false, output[1].Get<bool>()); // Disallow.
|
||||
}
|
||||
|
||||
TEST_F(GateCalculatorTest, DisallowWithStateChange) {
|
||||
SetRunner(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "test_input"
|
||||
input_stream: "DISALLOW:gating_stream"
|
||||
output_stream: "test_output"
|
||||
output_stream: "STATE_CHANGE:state_changed"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "DISALLOW", true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, "DISALLOW", false);
|
||||
constexpr int64 kTimestampValue2 = 44;
|
||||
RunTimeStep(kTimestampValue2, "DISALLOW", false);
|
||||
constexpr int64 kTimestampValue3 = 45;
|
||||
RunTimeStep(kTimestampValue3, "DISALLOW", true);
|
||||
|
||||
const std::vector<Packet>& output =
|
||||
runner()->Outputs().Get("STATE_CHANGE", 0).packets;
|
||||
ASSERT_EQ(2, output.size());
|
||||
EXPECT_EQ(kTimestampValue1, output[0].Timestamp().Value());
|
||||
EXPECT_EQ(kTimestampValue3, output[1].Timestamp().Value());
|
||||
EXPECT_EQ(true, output[0].Get<bool>()); // Allow.
|
||||
EXPECT_EQ(false, output[1].Get<bool>()); // Disallow.
|
||||
}
|
||||
|
||||
// Must not detect disallow value for first timestamp as a state change.
|
||||
TEST_F(GateCalculatorTest, DisallowInitialNoStateTransition) {
|
||||
SetRunner(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "test_input"
|
||||
input_stream: "DISALLOW:gating_stream"
|
||||
output_stream: "test_output"
|
||||
output_stream: "STATE_CHANGE:state_changed"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "DISALLOW", false);
|
||||
|
||||
const std::vector<Packet>& output =
|
||||
runner()->Outputs().Get("STATE_CHANGE", 0).packets;
|
||||
ASSERT_EQ(0, output.size());
|
||||
}
|
||||
|
||||
// Must not detect allow value for first timestamp as a state change.
|
||||
TEST_F(GateCalculatorTest, AllowInitialNoStateTransition) {
|
||||
SetRunner(R"(
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "test_input"
|
||||
input_stream: "ALLOW:gating_stream"
|
||||
output_stream: "test_output"
|
||||
output_stream: "STATE_CHANGE:state_changed"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "ALLOW", true);
|
||||
|
||||
const std::vector<Packet>& output =
|
||||
runner()->Outputs().Get("STATE_CHANGE", 0).packets;
|
||||
ASSERT_EQ(0, output.size());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,94 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// This Calculator multiplexes several input streams into a single
|
||||
// output stream, dropping input packets with timestamps older than the
|
||||
// last output packet. In case two packets arrive with the same timestamp,
|
||||
// the packet with the lower stream index will be output and the rest will
|
||||
// be dropped.
|
||||
//
|
||||
// This Calculator optionally produces a finish inidicator as its second
|
||||
// output stream. One indicator packet is produced for each input packet
|
||||
// received.
|
||||
//
|
||||
// This Calculator can be used with an ImmediateInputStreamHandler or with the
|
||||
// default ISH.
|
||||
//
|
||||
// This Calculator is designed to work with a Demux calculator such as
|
||||
// the RoundRobinDemuxCalculator. Therefore, packets from different
|
||||
// input streams are normally not expected to have the same timestamp.
|
||||
//
|
||||
// NOTE: this calculator can drop packets non-deterministically, depending on
|
||||
// how fast the input streams are fed. In most cases, MuxCalculator should be
|
||||
// preferred. In particular, dropping packets can interfere with rate limiting
|
||||
// mechanisms.
|
||||
class ImmediateMuxCalculator : public CalculatorBase {
|
||||
public:
|
||||
// This calculator combines any set of input streams into a single
|
||||
// output stream. All input stream types must match the output stream type.
|
||||
static absl::Status GetContract(CalculatorContract* cc);
|
||||
|
||||
// Passes any input packet to the output stream immediately, unless the
|
||||
// packet timestamp is lower than a previously passed packet.
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
};
|
||||
REGISTER_CALCULATOR(ImmediateMuxCalculator);
|
||||
|
||||
absl::Status ImmediateMuxCalculator::GetContract(CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Outputs().NumEntries() >= 1 && cc->Outputs().NumEntries() <= 2)
|
||||
<< "This calculator produces only one or two output streams.";
|
||||
cc->Outputs().Index(0).SetAny();
|
||||
if (cc->Outputs().NumEntries() >= 2) {
|
||||
cc->Outputs().Index(1).Set<bool>();
|
||||
}
|
||||
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
|
||||
cc->Inputs().Index(i).SetSameAs(&cc->Outputs().Index(0));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status ImmediateMuxCalculator::Open(CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status ImmediateMuxCalculator::Process(CalculatorContext* cc) {
|
||||
// Pass along the first packet, unless it has been superseded.
|
||||
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
|
||||
const Packet& packet = cc->Inputs().Index(i).Value();
|
||||
if (!packet.IsEmpty()) {
|
||||
if (packet.Timestamp() >= cc->Outputs().Index(0).NextTimestampBound()) {
|
||||
cc->Outputs().Index(0).AddPacket(packet);
|
||||
} else {
|
||||
LOG_FIRST_N(WARNING, 5)
|
||||
<< "Dropping a packet with timestamp " << packet.Timestamp();
|
||||
}
|
||||
if (cc->Outputs().NumEntries() >= 2) {
|
||||
Timestamp output_timestamp = std::max(
|
||||
cc->InputTimestamp(), cc->Outputs().Index(1).NextTimestampBound());
|
||||
cc->Outputs().Index(1).Add(new bool(true), output_timestamp);
|
||||
}
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,373 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/port/threadpool.h"
|
||||
#include "mediapipe/framework/tool/sink.h"
|
||||
|
||||
// Tests for ImmediateMuxCalculator. These tests show how parallel output
|
||||
// packets are handled when they arrive in various orders.
|
||||
using testing::ElementsAre;
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
// A simple Semaphore for synchronizing test threads.
|
||||
class AtomicSemaphore {
|
||||
public:
|
||||
AtomicSemaphore(int64_t supply) : supply_(supply) {}
|
||||
void Acquire(int64_t amount) {
|
||||
while (supply_.fetch_sub(amount) - amount < 0) {
|
||||
Release(amount);
|
||||
}
|
||||
}
|
||||
void Release(int64_t amount) { supply_.fetch_add(amount); }
|
||||
|
||||
private:
|
||||
std::atomic<int64_t> supply_;
|
||||
};
|
||||
|
||||
// A mediapipe::Executor that signals the start and finish of each task.
|
||||
// Provides 4 worker threads.
|
||||
class CountingExecutor : public Executor {
|
||||
public:
|
||||
CountingExecutor(std::function<void()> start_callback,
|
||||
std::function<void()> finish_callback)
|
||||
: thread_pool_(4),
|
||||
start_callback_(std::move(start_callback)),
|
||||
finish_callback_(std::move(finish_callback)) {
|
||||
thread_pool_.StartWorkers();
|
||||
}
|
||||
void Schedule(std::function<void()> task) override {
|
||||
start_callback_();
|
||||
thread_pool_.Schedule([this, task] {
|
||||
task();
|
||||
finish_callback_();
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
ThreadPool thread_pool_;
|
||||
std::function<void()> start_callback_;
|
||||
std::function<void()> finish_callback_;
|
||||
};
|
||||
|
||||
// Returns a new mediapipe::Executor with 4 worker threads.
|
||||
std::shared_ptr<Executor> MakeExecutor(std::function<void()> start_callback,
|
||||
std::function<void()> finish_callback) {
|
||||
return std::make_shared<CountingExecutor>(start_callback, finish_callback);
|
||||
}
|
||||
|
||||
// Tests showing ImmediateMuxCalculator dropping packets in various sequences.
|
||||
class ImmediateMuxCalculatorTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUpMuxGraph() {
|
||||
ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"(
|
||||
input_stream: "input_packets_0"
|
||||
input_stream: "input_packets_1"
|
||||
node {
|
||||
calculator: "ImmediateMuxCalculator"
|
||||
input_stream_handler {
|
||||
input_stream_handler: "ImmediateInputStreamHandler"
|
||||
}
|
||||
input_stream: "input_packets_0"
|
||||
input_stream: "input_packets_1"
|
||||
output_stream: "output_packets_0"
|
||||
}
|
||||
)",
|
||||
&graph_config_));
|
||||
}
|
||||
|
||||
void SetUpDemuxGraph() {
|
||||
ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"(
|
||||
input_stream: "input_packets_0"
|
||||
node {
|
||||
calculator: "RoundRobinDemuxCalculator"
|
||||
input_stream: "input_packets_0"
|
||||
output_stream: "OUTPUT:0:input_0"
|
||||
output_stream: "OUTPUT:1:input_1"
|
||||
}
|
||||
node {
|
||||
calculator: "LambdaCalculator"
|
||||
input_side_packet: 'callback_0'
|
||||
input_stream: "input_0"
|
||||
output_stream: "output_0"
|
||||
}
|
||||
node {
|
||||
calculator: "LambdaCalculator"
|
||||
input_side_packet: 'callback_1'
|
||||
input_stream: "input_1"
|
||||
output_stream: "output_1"
|
||||
}
|
||||
node {
|
||||
calculator: "ImmediateMuxCalculator"
|
||||
input_stream_handler {
|
||||
input_stream_handler: "ImmediateInputStreamHandler"
|
||||
}
|
||||
input_stream: "output_0"
|
||||
input_stream: "output_1"
|
||||
output_stream: "output_packets_0"
|
||||
}
|
||||
)",
|
||||
&graph_config_));
|
||||
}
|
||||
|
||||
void SetUpDemuxInFlightGraph() {
|
||||
ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"(
|
||||
input_stream: "input_packets_0"
|
||||
node {
|
||||
calculator: 'FlowLimiterCalculator'
|
||||
input_stream_handler {
|
||||
input_stream_handler: 'ImmediateInputStreamHandler'
|
||||
}
|
||||
input_side_packet: 'MAX_IN_FLIGHT:max_in_flight'
|
||||
input_stream: 'input_packets_0'
|
||||
input_stream: 'FINISHED:finish_indicator'
|
||||
input_stream_info: {
|
||||
tag_index: 'FINISHED'
|
||||
back_edge: true
|
||||
}
|
||||
output_stream: 'input_0_sampled'
|
||||
}
|
||||
node {
|
||||
calculator: "RoundRobinDemuxCalculator"
|
||||
input_stream: "input_0_sampled"
|
||||
output_stream: "OUTPUT:0:input_0"
|
||||
output_stream: "OUTPUT:1:input_1"
|
||||
}
|
||||
node {
|
||||
calculator: "LambdaCalculator"
|
||||
input_side_packet: 'callback_0'
|
||||
input_stream: "input_0"
|
||||
output_stream: "output_0"
|
||||
}
|
||||
node {
|
||||
calculator: "LambdaCalculator"
|
||||
input_side_packet: 'callback_1'
|
||||
input_stream: "input_1"
|
||||
output_stream: "output_1"
|
||||
}
|
||||
node {
|
||||
calculator: "ImmediateMuxCalculator"
|
||||
input_stream_handler {
|
||||
input_stream_handler: "ImmediateInputStreamHandler"
|
||||
}
|
||||
input_stream: "output_0"
|
||||
input_stream: "output_1"
|
||||
output_stream: 'output_packets_0'
|
||||
output_stream: 'finish_indicator'
|
||||
}
|
||||
)",
|
||||
&graph_config_));
|
||||
}
|
||||
|
||||
static Packet PacketAt(int64 ts) {
|
||||
return Adopt(new int64(999)).At(Timestamp(ts));
|
||||
}
|
||||
static Packet None() { return Packet().At(Timestamp::OneOverPostStream()); }
|
||||
static bool IsNone(const Packet& packet) {
|
||||
return packet.Timestamp() == Timestamp::OneOverPostStream();
|
||||
}
|
||||
// Return the values of the timestamps of a vector of Packets.
|
||||
static std::vector<int64> TimestampValues(
|
||||
const std::vector<Packet>& packets) {
|
||||
std::vector<int64> result;
|
||||
for (const Packet& p : packets) {
|
||||
result.push_back(p.Timestamp().Value());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Runs a CalculatorGraph with a series of packet sets.
|
||||
// Returns a vector of packets from each graph output stream.
|
||||
void RunGraph(const std::vector<std::vector<Packet>>& input_sets,
|
||||
std::vector<Packet>* output_packets) {
|
||||
// Register output packet observers.
|
||||
tool::AddVectorSink("output_packets_0", &graph_config_, output_packets);
|
||||
|
||||
// Start running the graph.
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config_));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
// Send each packet to the graph in the specified order.
|
||||
for (int t = 0; t < input_sets.size(); t++) {
|
||||
const std::vector<Packet>& input_set = input_sets[t];
|
||||
MP_EXPECT_OK(graph.WaitUntilIdle());
|
||||
for (int i = 0; i < input_set.size(); i++) {
|
||||
const Packet& packet = input_set[i];
|
||||
if (!IsNone(packet)) {
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
absl::StrCat("input_packets_", i), packet));
|
||||
}
|
||||
}
|
||||
}
|
||||
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
CalculatorGraphConfig graph_config_;
|
||||
};
|
||||
|
||||
TEST_F(ImmediateMuxCalculatorTest, IncreasingTimestamps) {
|
||||
// Run the graph with a series of packet sets.
|
||||
std::vector<std::vector<Packet>> input_sets = {
|
||||
{PacketAt(10000), None()}, //
|
||||
{PacketAt(20000), None()}, //
|
||||
{None(), PacketAt(30000)}, //
|
||||
{None(), PacketAt(40000)},
|
||||
};
|
||||
SetUpMuxGraph();
|
||||
std::vector<Packet> output_packets;
|
||||
RunGraph(input_sets, &output_packets);
|
||||
|
||||
// Validate the output packets.
|
||||
EXPECT_THAT(TimestampValues(output_packets),
|
||||
ElementsAre(10000, 20000, 30000, 40000));
|
||||
}
|
||||
|
||||
TEST_F(ImmediateMuxCalculatorTest, SupersededTimestamp) {
|
||||
// Run the graph with a series of packet sets.
|
||||
std::vector<std::vector<Packet>> input_sets = {
|
||||
{PacketAt(10000), None()}, //
|
||||
{PacketAt(30000), None()}, //
|
||||
{None(), PacketAt(20000)}, //
|
||||
{None(), PacketAt(40000)},
|
||||
};
|
||||
SetUpMuxGraph();
|
||||
std::vector<Packet> output_packets;
|
||||
RunGraph(input_sets, &output_packets);
|
||||
|
||||
// Output packet 20000 is superseded and dropped.
|
||||
EXPECT_THAT(TimestampValues(output_packets),
|
||||
ElementsAre(10000, 30000, 40000));
|
||||
}
|
||||
|
||||
TEST_F(ImmediateMuxCalculatorTest, SimultaneousTimestamps) {
|
||||
// Run the graph with a series of packet sets.
|
||||
std::vector<std::vector<Packet>> input_sets = {
|
||||
{PacketAt(10000), None()}, //
|
||||
{PacketAt(40000), PacketAt(20000)}, //
|
||||
{None(), PacketAt(30000)},
|
||||
};
|
||||
SetUpMuxGraph();
|
||||
std::vector<Packet> output_packets;
|
||||
RunGraph(input_sets, &output_packets);
|
||||
|
||||
// Output packet 20000 is superseded and dropped.
|
||||
EXPECT_THAT(TimestampValues(output_packets), ElementsAre(10000, 40000));
|
||||
}
|
||||
|
||||
// A Calculator::Process callback function.
|
||||
typedef std::function<absl::Status(const InputStreamShardSet&,
|
||||
OutputStreamShardSet*)>
|
||||
ProcessFunction;
|
||||
|
||||
// A testing callback function that passes through all packets.
|
||||
absl::Status PassThrough(const InputStreamShardSet& inputs,
|
||||
OutputStreamShardSet* outputs) {
|
||||
for (int i = 0; i < inputs.NumEntries(); ++i) {
|
||||
if (!inputs.Index(i).Value().IsEmpty()) {
|
||||
outputs->Index(i).AddPacket(inputs.Index(i).Value());
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
TEST_F(ImmediateMuxCalculatorTest, Demux) {
|
||||
// Semaphores to sequence the parallel Process outputs.
|
||||
AtomicSemaphore semaphore_0(0);
|
||||
AtomicSemaphore semaphore_1(0);
|
||||
ProcessFunction wait_0 = [&semaphore_0](const InputStreamShardSet& inputs,
|
||||
OutputStreamShardSet* outputs) {
|
||||
semaphore_0.Acquire(1);
|
||||
return PassThrough(inputs, outputs);
|
||||
};
|
||||
ProcessFunction wait_1 = [&semaphore_1](const InputStreamShardSet& inputs,
|
||||
OutputStreamShardSet* outputs) {
|
||||
semaphore_1.Acquire(1);
|
||||
return PassThrough(inputs, outputs);
|
||||
};
|
||||
|
||||
// A callback to await and capture output packets.
|
||||
std::vector<Packet> out_packets;
|
||||
absl::Mutex out_mutex;
|
||||
auto out_cb = [&](const Packet& p) {
|
||||
absl::MutexLock lock(&out_mutex);
|
||||
out_packets.push_back(p);
|
||||
return absl::OkStatus();
|
||||
};
|
||||
auto wait_for = [&](std::function<bool()> cond) {
|
||||
absl::MutexLock lock(&out_mutex);
|
||||
out_mutex.Await(absl::Condition(&cond));
|
||||
};
|
||||
SetUpDemuxGraph();
|
||||
|
||||
// Start the graph and add five input packets.
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config_,
|
||||
{
|
||||
{"callback_0", Adopt(new auto(wait_0))},
|
||||
{"callback_1", Adopt(new auto(wait_1))},
|
||||
}));
|
||||
MP_ASSERT_OK(graph.ObserveOutputStream("output_packets_0", out_cb));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_EXPECT_OK(
|
||||
graph.AddPacketToInputStream("input_packets_0", PacketAt(10000)));
|
||||
MP_EXPECT_OK(
|
||||
graph.AddPacketToInputStream("input_packets_0", PacketAt(20000)));
|
||||
MP_EXPECT_OK(
|
||||
graph.AddPacketToInputStream("input_packets_0", PacketAt(30000)));
|
||||
MP_EXPECT_OK(
|
||||
graph.AddPacketToInputStream("input_packets_0", PacketAt(40000)));
|
||||
MP_EXPECT_OK(
|
||||
graph.AddPacketToInputStream("input_packets_0", PacketAt(50000)));
|
||||
|
||||
// Release the outputs in order 20000, 10000, 30000, 50000, 40000.
|
||||
semaphore_1.Release(1); // 20000
|
||||
wait_for([&] { return !out_packets.empty(); });
|
||||
semaphore_0.Release(1); // 10000
|
||||
semaphore_0.Release(1); // 30000
|
||||
wait_for([&] { return out_packets.size() >= 2; });
|
||||
semaphore_0.Release(1); // 50000
|
||||
wait_for([&] { return out_packets.size() >= 3; });
|
||||
semaphore_1.Release(1); // 40000
|
||||
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
|
||||
// Output packets 10000 and 40000 are superseded and dropped.
|
||||
EXPECT_THAT(TimestampValues(out_packets), ElementsAre(20000, 30000, 50000));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,58 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// Given two input streams (A, B), output a single stream containing a pair<A,
|
||||
// B>.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "MakePairCalculator"
|
||||
// input_stream: "packet_a"
|
||||
// input_stream: "packet_b"
|
||||
// output_stream: "output_pair_a_b"
|
||||
// }
|
||||
class MakePairCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<AnyType>::Multiple kIn{""};
|
||||
// Note that currently api2::Packet is a different type from mediapipe::Packet
|
||||
static constexpr Output<std::pair<mediapipe::Packet, mediapipe::Packet>>
|
||||
kPair{""};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kPair);
|
||||
|
||||
static absl::Status UpdateContract(CalculatorContract* cc) {
|
||||
RET_CHECK_EQ(kIn(cc).Count(), 2);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
kPair(cc).Send({kIn(cc)[0].packet(), kIn(cc)[1].packet()});
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
MEDIAPIPE_REGISTER_NODE(MakePairCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,51 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
// Perform a (left) matrix multiply. Meaning (output = A * input)
|
||||
// where A is the matrix which is provided as an input side packet.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "MatrixMultiplyCalculator"
|
||||
// input_stream: "samples"
|
||||
// output_stream: "multiplied_samples"
|
||||
// input_side_packet: "multiplication_matrix"
|
||||
// }
|
||||
class MatrixMultiplyCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<Matrix> kIn{""};
|
||||
static constexpr Output<Matrix> kOut{""};
|
||||
static constexpr SideInput<Matrix> kSide{""};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kOut, kSide);
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
};
|
||||
MEDIAPIPE_REGISTER_NODE(MatrixMultiplyCalculator);
|
||||
|
||||
absl::Status MatrixMultiplyCalculator::Process(CalculatorContext* cc) {
|
||||
kOut(cc).Send(*kSide(cc) * *kIn(cc));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,239 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/tool/validate_type.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
// A 3x4 Matrix of random integers in [0,1000).
|
||||
const char kMatrixText[] =
|
||||
"rows: 3\n"
|
||||
"cols: 4\n"
|
||||
"packed_data: 387\n"
|
||||
"packed_data: 940\n"
|
||||
"packed_data: 815\n"
|
||||
"packed_data: 825\n"
|
||||
"packed_data: 997\n"
|
||||
"packed_data: 884\n"
|
||||
"packed_data: 419\n"
|
||||
"packed_data: 763\n"
|
||||
"packed_data: 123\n"
|
||||
"packed_data: 30\n"
|
||||
"packed_data: 825\n"
|
||||
"packed_data: 299\n";
|
||||
|
||||
// A 4x20 Matrix of random integers in [0,10).
|
||||
// Each column of this matrix is a sample.
|
||||
const char kSamplesText[] =
|
||||
"rows: 4\n"
|
||||
"cols: 20\n"
|
||||
"packed_data: 7\n"
|
||||
"packed_data: 9\n"
|
||||
"packed_data: 5\n"
|
||||
"packed_data: 9\n"
|
||||
"packed_data: 6\n"
|
||||
"packed_data: 3\n"
|
||||
"packed_data: 0\n"
|
||||
"packed_data: 7\n"
|
||||
"packed_data: 1\n"
|
||||
"packed_data: 3\n"
|
||||
"packed_data: 3\n"
|
||||
"packed_data: 2\n"
|
||||
"packed_data: 4\n"
|
||||
"packed_data: 5\n"
|
||||
"packed_data: 0\n"
|
||||
"packed_data: 4\n"
|
||||
"packed_data: 6\n"
|
||||
"packed_data: 0\n"
|
||||
"packed_data: 1\n"
|
||||
"packed_data: 2\n"
|
||||
"packed_data: 0\n"
|
||||
"packed_data: 2\n"
|
||||
"packed_data: 0\n"
|
||||
"packed_data: 3\n"
|
||||
"packed_data: 1\n"
|
||||
"packed_data: 7\n"
|
||||
"packed_data: 4\n"
|
||||
"packed_data: 9\n"
|
||||
"packed_data: 8\n"
|
||||
"packed_data: 8\n"
|
||||
"packed_data: 6\n"
|
||||
"packed_data: 4\n"
|
||||
"packed_data: 6\n"
|
||||
"packed_data: 8\n"
|
||||
"packed_data: 1\n"
|
||||
"packed_data: 9\n"
|
||||
"packed_data: 7\n"
|
||||
"packed_data: 5\n"
|
||||
"packed_data: 3\n"
|
||||
"packed_data: 5\n"
|
||||
"packed_data: 3\n"
|
||||
"packed_data: 5\n"
|
||||
"packed_data: 7\n"
|
||||
"packed_data: 7\n"
|
||||
"packed_data: 3\n"
|
||||
"packed_data: 3\n"
|
||||
"packed_data: 6\n"
|
||||
"packed_data: 4\n"
|
||||
"packed_data: 7\n"
|
||||
"packed_data: 7\n"
|
||||
"packed_data: 2\n"
|
||||
"packed_data: 5\n"
|
||||
"packed_data: 4\n"
|
||||
"packed_data: 8\n"
|
||||
"packed_data: 1\n"
|
||||
"packed_data: 0\n"
|
||||
"packed_data: 2\n"
|
||||
"packed_data: 0\n"
|
||||
"packed_data: 3\n"
|
||||
"packed_data: 4\n"
|
||||
"packed_data: 6\n"
|
||||
"packed_data: 6\n"
|
||||
"packed_data: 8\n"
|
||||
"packed_data: 5\n"
|
||||
"packed_data: 5\n"
|
||||
"packed_data: 8\n"
|
||||
"packed_data: 9\n"
|
||||
"packed_data: 7\n"
|
||||
"packed_data: 3\n"
|
||||
"packed_data: 7\n"
|
||||
"packed_data: 2\n"
|
||||
"packed_data: 7\n"
|
||||
"packed_data: 8\n"
|
||||
"packed_data: 2\n"
|
||||
"packed_data: 1\n"
|
||||
"packed_data: 1\n"
|
||||
"packed_data: 4\n"
|
||||
"packed_data: 1\n"
|
||||
"packed_data: 1\n"
|
||||
"packed_data: 7\n";
|
||||
|
||||
// A 3x20 Matrix of expected values for the result of the matrix multiply
|
||||
// computed using R.
|
||||
// Each column of this matrix is an expected output.
|
||||
const char kExpectedText[] =
|
||||
"rows: 3\n"
|
||||
"cols: 20\n"
|
||||
"packed_data: 12499\n"
|
||||
"packed_data: 26793\n"
|
||||
"packed_data: 16967\n"
|
||||
"packed_data: 5007\n"
|
||||
"packed_data: 14406\n"
|
||||
"packed_data: 9635\n"
|
||||
"packed_data: 4179\n"
|
||||
"packed_data: 7870\n"
|
||||
"packed_data: 4434\n"
|
||||
"packed_data: 5793\n"
|
||||
"packed_data: 12045\n"
|
||||
"packed_data: 8876\n"
|
||||
"packed_data: 2801\n"
|
||||
"packed_data: 8053\n"
|
||||
"packed_data: 5611\n"
|
||||
"packed_data: 1740\n"
|
||||
"packed_data: 4469\n"
|
||||
"packed_data: 2665\n"
|
||||
"packed_data: 8108\n"
|
||||
"packed_data: 18396\n"
|
||||
"packed_data: 10186\n"
|
||||
"packed_data: 12330\n"
|
||||
"packed_data: 23374\n"
|
||||
"packed_data: 15526\n"
|
||||
"packed_data: 9611\n"
|
||||
"packed_data: 21804\n"
|
||||
"packed_data: 14776\n"
|
||||
"packed_data: 8241\n"
|
||||
"packed_data: 17979\n"
|
||||
"packed_data: 11989\n"
|
||||
"packed_data: 8429\n"
|
||||
"packed_data: 18921\n"
|
||||
"packed_data: 9819\n"
|
||||
"packed_data: 6270\n"
|
||||
"packed_data: 13689\n"
|
||||
"packed_data: 7031\n"
|
||||
"packed_data: 9472\n"
|
||||
"packed_data: 19210\n"
|
||||
"packed_data: 13634\n"
|
||||
"packed_data: 8567\n"
|
||||
"packed_data: 12499\n"
|
||||
"packed_data: 10455\n"
|
||||
"packed_data: 2151\n"
|
||||
"packed_data: 7469\n"
|
||||
"packed_data: 3195\n"
|
||||
"packed_data: 10774\n"
|
||||
"packed_data: 21851\n"
|
||||
"packed_data: 12673\n"
|
||||
"packed_data: 12516\n"
|
||||
"packed_data: 25318\n"
|
||||
"packed_data: 14347\n"
|
||||
"packed_data: 7984\n"
|
||||
"packed_data: 17100\n"
|
||||
"packed_data: 10972\n"
|
||||
"packed_data: 5195\n"
|
||||
"packed_data: 11102\n"
|
||||
"packed_data: 8710\n"
|
||||
"packed_data: 3002\n"
|
||||
"packed_data: 11295\n"
|
||||
"packed_data: 6360\n";
|
||||
|
||||
// Send a number of samples through the MatrixMultiplyCalculator.
|
||||
TEST(MatrixMultiplyCalculatorTest, Multiply) {
|
||||
CalculatorRunner runner("MatrixMultiplyCalculator", "", 1, 1, 1);
|
||||
Matrix* matrix = new Matrix();
|
||||
MatrixFromTextProto(kMatrixText, matrix);
|
||||
runner.MutableSidePackets()->Index(0) = Adopt(matrix);
|
||||
|
||||
Matrix samples;
|
||||
MatrixFromTextProto(kSamplesText, &samples);
|
||||
Matrix expected;
|
||||
MatrixFromTextProto(kExpectedText, &expected);
|
||||
CHECK_EQ(samples.cols(), expected.cols());
|
||||
|
||||
for (int i = 0; i < samples.cols(); ++i) {
|
||||
// Take a column from samples and produce a packet with just that
|
||||
// column in it as an input sample for the calculator.
|
||||
Eigen::MatrixXf* sample = new Eigen::MatrixXf(samples.block(0, i, 4, 1));
|
||||
runner.MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(sample).At(Timestamp(i)));
|
||||
}
|
||||
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
EXPECT_EQ(runner.MutableInputs()->Index(0).packets.size(),
|
||||
runner.Outputs().Index(0).packets.size());
|
||||
|
||||
int i = 0;
|
||||
for (const Packet& output : runner.Outputs().Index(0).packets) {
|
||||
EXPECT_EQ(Timestamp(i), output.Timestamp());
|
||||
const Eigen::MatrixXf& result = output.Get<Matrix>();
|
||||
ASSERT_EQ(3, result.rows());
|
||||
EXPECT_NEAR((expected.block(0, i, 3, 1) - result).cwiseAbs().sum(), 0.0,
|
||||
1e-5);
|
||||
++i;
|
||||
}
|
||||
EXPECT_EQ(samples.cols(), i);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,81 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// Subtract input matrix from the side input matrix and vice versa. The matrices
|
||||
// must have the same dimension.
|
||||
// Based on the tag (MINUEND vs SUBTRAHEND), the matrices in the input stream
|
||||
// and input side packet can be either subtrahend or minuend. The output matrix
|
||||
// is generated by performing minuend matrix - subtrahend matrix.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "MatrixSubtractCalculator"
|
||||
// input_stream: "MINUEND:input_matrix"
|
||||
// input_side_packet: "SUBTRAHEND:side_matrix"
|
||||
// output_stream: "output_matrix"
|
||||
// }
|
||||
//
|
||||
// or
|
||||
//
|
||||
// node {
|
||||
// calculator: "MatrixSubtractCalculator"
|
||||
// input_stream: "SUBTRAHEND:input_matrix"
|
||||
// input_side_packet: "MINUEND:side_matrix"
|
||||
// output_stream: "output_matrix"
|
||||
// }
|
||||
class MatrixSubtractCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<Matrix>::SideFallback kMinuend{"MINUEND"};
|
||||
static constexpr Input<Matrix>::SideFallback kSubtrahend{"SUBTRAHEND"};
|
||||
static constexpr Output<Matrix> kOut{""};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kMinuend, kSubtrahend, kOut);
|
||||
static absl::Status UpdateContract(CalculatorContract* cc);
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
};
|
||||
MEDIAPIPE_REGISTER_NODE(MatrixSubtractCalculator);
|
||||
|
||||
// static
|
||||
absl::Status MatrixSubtractCalculator::UpdateContract(CalculatorContract* cc) {
|
||||
// TODO: the next restriction could be relaxed.
|
||||
RET_CHECK(kMinuend(cc).IsStream() ^ kSubtrahend(cc).IsStream())
|
||||
<< "MatrixSubtractCalculator only accepts exactly one input stream and "
|
||||
"one input side packet";
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status MatrixSubtractCalculator::Process(CalculatorContext* cc) {
|
||||
const Matrix& minuend = *kMinuend(cc);
|
||||
const Matrix& subtrahend = *kSubtrahend(cc);
|
||||
if (minuend.rows() != subtrahend.rows() ||
|
||||
minuend.cols() != subtrahend.cols()) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Minuend and subtrahend must have the same dimensions.");
|
||||
}
|
||||
kOut(cc).Send(minuend - subtrahend);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,156 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/tool/validate_type.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
// A 3x4 Matrix of random integers in [0,1000).
|
||||
const char kMatrixText[] =
|
||||
"rows: 3\n"
|
||||
"cols: 4\n"
|
||||
"packed_data: 387\n"
|
||||
"packed_data: 940\n"
|
||||
"packed_data: 815\n"
|
||||
"packed_data: 825\n"
|
||||
"packed_data: 997\n"
|
||||
"packed_data: 884\n"
|
||||
"packed_data: 419\n"
|
||||
"packed_data: 763\n"
|
||||
"packed_data: 123\n"
|
||||
"packed_data: 30\n"
|
||||
"packed_data: 825\n"
|
||||
"packed_data: 299\n";
|
||||
|
||||
const char kMatrixText2[] =
|
||||
"rows: 3\n"
|
||||
"cols: 4\n"
|
||||
"packed_data: 388\n"
|
||||
"packed_data: 941\n"
|
||||
"packed_data: 816\n"
|
||||
"packed_data: 826\n"
|
||||
"packed_data: 998\n"
|
||||
"packed_data: 885\n"
|
||||
"packed_data: 420\n"
|
||||
"packed_data: 764\n"
|
||||
"packed_data: 124\n"
|
||||
"packed_data: 31\n"
|
||||
"packed_data: 826\n"
|
||||
"packed_data: 300\n";
|
||||
|
||||
TEST(MatrixSubtractCalculatorTest, WrongConfig) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "MatrixSubtractCalculator"
|
||||
input_stream: "input_matrix"
|
||||
input_side_packet: "SUBTRAHEND:side_matrix"
|
||||
input_side_packet: "MINUEND:side_matrix2"
|
||||
output_stream: "output_matrix"
|
||||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
auto status = runner.Run();
|
||||
EXPECT_THAT(
|
||||
status.message(),
|
||||
testing::HasSubstr(
|
||||
"only accepts exactly one input stream and one input side packet"));
|
||||
}
|
||||
|
||||
TEST(MatrixSubtractCalculatorTest, WrongConfig2) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "MatrixSubtractCalculator"
|
||||
input_side_packet: "SUBTRAHEND:side_matrix"
|
||||
input_stream: "SUBTRAHEND:side_matrix2"
|
||||
output_stream: "output_matrix"
|
||||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
auto status = runner.Run();
|
||||
EXPECT_THAT(status.message(), testing::HasSubstr("must be connected"));
|
||||
EXPECT_THAT(status.message(), testing::HasSubstr("not both"));
|
||||
}
|
||||
|
||||
TEST(MatrixSubtractCalculatorTest, SubtractFromInput) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "MatrixSubtractCalculator"
|
||||
input_stream: "MINUEND:input_matrix"
|
||||
input_side_packet: "SUBTRAHEND:side_matrix"
|
||||
output_stream: "output_matrix"
|
||||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
Matrix* side_matrix = new Matrix();
|
||||
MatrixFromTextProto(kMatrixText, side_matrix);
|
||||
runner.MutableSidePackets()->Tag("SUBTRAHEND") = Adopt(side_matrix);
|
||||
|
||||
Matrix* input_matrix = new Matrix();
|
||||
MatrixFromTextProto(kMatrixText2, input_matrix);
|
||||
runner.MutableInputs()->Tag("MINUEND").packets.push_back(
|
||||
Adopt(input_matrix).At(Timestamp(0)));
|
||||
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
EXPECT_EQ(1, runner.Outputs().Index(0).packets.size());
|
||||
|
||||
EXPECT_EQ(Timestamp(0), runner.Outputs().Index(0).packets[0].Timestamp());
|
||||
const Eigen::MatrixXf& result =
|
||||
runner.Outputs().Index(0).packets[0].Get<Matrix>();
|
||||
ASSERT_EQ(3, result.rows());
|
||||
ASSERT_EQ(4, result.cols());
|
||||
EXPECT_NEAR(result.sum(), 12, 1e-5);
|
||||
}
|
||||
|
||||
TEST(MatrixSubtractCalculatorTest, SubtractFromSideMatrix) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "MatrixSubtractCalculator"
|
||||
input_stream: "SUBTRAHEND:input_matrix"
|
||||
input_side_packet: "MINUEND:side_matrix"
|
||||
output_stream: "output_matrix"
|
||||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
Matrix* side_matrix = new Matrix();
|
||||
MatrixFromTextProto(kMatrixText, side_matrix);
|
||||
runner.MutableSidePackets()->Tag("MINUEND") = Adopt(side_matrix);
|
||||
|
||||
Matrix* input_matrix = new Matrix();
|
||||
MatrixFromTextProto(kMatrixText2, input_matrix);
|
||||
runner.MutableInputs()
|
||||
->Tag("SUBTRAHEND")
|
||||
.packets.push_back(Adopt(input_matrix).At(Timestamp(0)));
|
||||
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
EXPECT_EQ(1, runner.Outputs().Index(0).packets.size());
|
||||
|
||||
EXPECT_EQ(Timestamp(0), runner.Outputs().Index(0).packets[0].Timestamp());
|
||||
const Eigen::MatrixXf& result =
|
||||
runner.Outputs().Index(0).packets[0].Get<Matrix>();
|
||||
ASSERT_EQ(3, result.rows());
|
||||
ASSERT_EQ(4, result.cols());
|
||||
EXPECT_NEAR(result.sum(), -12, 1e-5);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,80 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Defines MatrixToVectorCalculator.
|
||||
#include <math.h>
|
||||
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/tool/status_util.h"
|
||||
#include "mediapipe/util/time_series_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// A calculator that converts a Matrix M to a vector containing all the
|
||||
// entries of M in column-major order.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "MatrixToVectorCalculator"
|
||||
// input_stream: "input_matrix"
|
||||
// output_stream: "column_major_vector"
|
||||
// }
|
||||
class MatrixToVectorCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<Matrix> kIn{""};
|
||||
static constexpr Output<std::vector<float>> kOut{""};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
|
||||
// Outputs a packet containing a vector for each input packet.
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
};
|
||||
MEDIAPIPE_REGISTER_NODE(MatrixToVectorCalculator);
|
||||
|
||||
absl::Status MatrixToVectorCalculator::Open(CalculatorContext* cc) {
|
||||
cc->SetOffset(0);
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status MatrixToVectorCalculator::Process(CalculatorContext* cc) {
|
||||
const Matrix& input = *kIn(cc);
|
||||
auto output = absl::make_unique<std::vector<float>>();
|
||||
|
||||
// The following lines work to convert the Matrix to a vector because Matrix
|
||||
// is an Eigen::MatrixXf and Eigen uses column-major layout by default.
|
||||
output->resize(input.rows() * input.cols());
|
||||
auto output_as_matrix =
|
||||
Eigen::Map<Matrix>(output->data(), input.rows(), input.cols());
|
||||
output_as_matrix = input;
|
||||
|
||||
kOut(cc).Send(std::move(output));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,88 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/tool/validate_type.h"
|
||||
#include "mediapipe/util/time_series_test_util.h"
|
||||
#include "mediapipe/util/time_series_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
class MatrixToVectorCalculatorTest
|
||||
: public mediapipe::TimeSeriesCalculatorTest<mediapipe::NoOptions> {
|
||||
protected:
|
||||
void SetUp() override { calculator_name_ = "MatrixToVectorCalculator"; }
|
||||
|
||||
void AppendInput(const std::vector<float>& column_major_data,
|
||||
int64 timestamp) {
|
||||
ASSERT_EQ(num_input_samples_ * num_input_channels_,
|
||||
column_major_data.size());
|
||||
Eigen::Map<const Matrix> data_map(&column_major_data[0],
|
||||
num_input_channels_, num_input_samples_);
|
||||
AppendInputPacket(new Matrix(data_map), timestamp);
|
||||
}
|
||||
|
||||
void SetInputStreamParameters(int num_channels, int num_samples) {
|
||||
num_input_channels_ = num_channels;
|
||||
num_input_samples_ = num_samples;
|
||||
input_sample_rate_ = 100;
|
||||
input_packet_rate_ = 20.0;
|
||||
}
|
||||
|
||||
void SetInputHeader(int num_channels, int num_samples) {
|
||||
SetInputStreamParameters(num_channels, num_samples);
|
||||
FillInputHeader();
|
||||
}
|
||||
|
||||
void CheckOutputPacket(int packet, std::vector<float> expected_vector) {
|
||||
const auto& actual_vector =
|
||||
runner_->Outputs().Index(0).packets[packet].Get<std::vector<float>>();
|
||||
EXPECT_THAT(actual_vector, testing::ContainerEq(expected_vector));
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(MatrixToVectorCalculatorTest, SingleRow) {
|
||||
InitializeGraph();
|
||||
SetInputHeader(1, 4); // 1 channel x 4 samples
|
||||
const std::vector<float>& data_vector = {1.0, 2.0, 3.0, 4.0};
|
||||
AppendInput(data_vector, 0);
|
||||
MP_ASSERT_OK(RunGraph());
|
||||
CheckOutputPacket(0, data_vector);
|
||||
}
|
||||
|
||||
TEST_F(MatrixToVectorCalculatorTest, RegularMatrix) {
|
||||
InitializeGraph();
|
||||
SetInputHeader(4, 2); // 4 channels x 2 samples
|
||||
// Actual data matrix is the transpose of the appearance below.
|
||||
const std::vector<float>& data_vector = {1.0, 2.0, 3.0, 4.0,
|
||||
5.0, 6.0, 7.0, 8.0};
|
||||
AppendInput(data_vector, 0);
|
||||
|
||||
MP_ASSERT_OK(RunGraph());
|
||||
CheckOutputPacket(0, data_vector);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,85 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// This calculator takes a set of input streams and combines them into a single
|
||||
// output stream. The packets from different streams do not need to contain the
|
||||
// same type. If there are packets arriving at the same time from two or more
|
||||
// input streams, the packet corresponding to the input stream with the smallest
|
||||
// index is passed to the output and the rest are ignored.
|
||||
//
|
||||
// Example use-case:
|
||||
// Suppose we have two (or more) different algorithms for detecting shot
|
||||
// boundaries and we need to merge their packets into a single stream. The
|
||||
// algorithms may emit shot boundaries at the same time and their output types
|
||||
// may not be compatible. Subsequent calculators that process the merged stream
|
||||
// may be interested only in the timestamps of the shot boundary packets and so
|
||||
// it may not even need to inspect the values stored inside the packets.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "MergeCalculator"
|
||||
// input_stream: "shot_info1"
|
||||
// input_stream: "shot_info2"
|
||||
// input_stream: "shot_info3"
|
||||
// output_stream: "merged_shot_infos"
|
||||
// }
|
||||
//
|
||||
class MergeCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<AnyType>::Multiple kIn{""};
|
||||
static constexpr Output<AnyType> kOut{""};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||
|
||||
static absl::Status UpdateContract(CalculatorContract* cc) {
|
||||
RET_CHECK_GT(kIn(cc).Count(), 0) << "Needs at least one input stream";
|
||||
if (kIn(cc).Count() == 1) {
|
||||
LOG(WARNING)
|
||||
<< "MergeCalculator expects multiple input streams to merge but is "
|
||||
"receiving only one. Make sure the calculator is configured "
|
||||
"correctly or consider removing this calculator to reduce "
|
||||
"unnecessary overhead.";
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
// Output the packet from the first input stream with a packet ready at this
|
||||
// timestamp.
|
||||
for (const auto& input : kIn(cc)) {
|
||||
if (!input.IsEmpty()) {
|
||||
kOut(cc).Send(input.packet());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
}
|
||||
|
||||
LOG(WARNING) << "Empty input packets at timestamp "
|
||||
<< cc->InputTimestamp().Value();
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
MEDIAPIPE_REGISTER_NODE(MergeCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,141 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
// Checks that the calculator fails if no input streams are provided.
|
||||
TEST(InvariantMergeInputStreamsCalculator, NoInputStreamsMustFail) {
|
||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "MergeCalculator"
|
||||
output_stream: "merged_output"
|
||||
)pb"));
|
||||
// Expect calculator to fail.
|
||||
ASSERT_FALSE(runner.Run().ok());
|
||||
}
|
||||
|
||||
// Checks that the calculator fails with an incorrect number of output streams.
|
||||
TEST(InvariantMergeInputStreamsCalculator, ExpectExactlyOneOutputStream) {
|
||||
CalculatorRunner runner1(
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "MergeCalculator"
|
||||
input_stream: "input1"
|
||||
input_stream: "input2"
|
||||
)pb"));
|
||||
// Expect calculator to fail.
|
||||
EXPECT_FALSE(runner1.Run().ok());
|
||||
|
||||
CalculatorRunner runner2(
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "MergeCalculator"
|
||||
input_stream: "input1"
|
||||
input_stream: "input2"
|
||||
output_stream: "output1"
|
||||
output_stream: "output2"
|
||||
)pb"));
|
||||
// Expect calculator to fail.
|
||||
ASSERT_FALSE(runner2.Run().ok());
|
||||
}
|
||||
|
||||
// Ensures two streams with differing types can be merged correctly.
|
||||
TEST(MediaPipeDetectionToSoapboxDetectionCalculatorTest,
|
||||
TestMergingTwoStreams) {
|
||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "MergeCalculator"
|
||||
input_stream: "input1"
|
||||
input_stream: "input2"
|
||||
output_stream: "combined_output"
|
||||
)pb"));
|
||||
|
||||
// input1: integers 10, 20, 30, occurring at times 10, 20, 30.
|
||||
runner.MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(new int(10)).At(Timestamp(10)));
|
||||
runner.MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(new int(20)).At(Timestamp(20)));
|
||||
runner.MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(new int(30)).At(Timestamp(30)));
|
||||
// input2: floats 5.5, 35.5 at times 5, 35.
|
||||
runner.MutableInputs()->Index(1).packets.push_back(
|
||||
Adopt(new float(5.5)).At(Timestamp(5)));
|
||||
runner.MutableInputs()->Index(1).packets.push_back(
|
||||
Adopt(new float(35.5)).At(Timestamp(35)));
|
||||
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
// Expected combined_output: 5.5, 10, 20, 30, 35.5 at times 5, 10, 20, 30, 35.
|
||||
const std::vector<Packet>& actual_output = runner.Outputs().Index(0).packets;
|
||||
ASSERT_EQ(actual_output.size(), 5);
|
||||
EXPECT_EQ(actual_output[0].Timestamp(), Timestamp(5));
|
||||
EXPECT_EQ(actual_output[0].Get<float>(), 5.5);
|
||||
|
||||
EXPECT_EQ(actual_output[1].Timestamp(), Timestamp(10));
|
||||
EXPECT_EQ(actual_output[1].Get<int>(), 10);
|
||||
|
||||
EXPECT_EQ(actual_output[2].Timestamp(), Timestamp(20));
|
||||
EXPECT_EQ(actual_output[2].Get<int>(), 20);
|
||||
|
||||
EXPECT_EQ(actual_output[3].Timestamp(), Timestamp(30));
|
||||
EXPECT_EQ(actual_output[3].Get<int>(), 30);
|
||||
|
||||
EXPECT_EQ(actual_output[4].Timestamp(), Timestamp(35));
|
||||
EXPECT_EQ(actual_output[4].Get<float>(), 35.5);
|
||||
}
|
||||
|
||||
// Ensures three streams with differing types can be merged correctly.
|
||||
TEST(MediaPipeDetectionToSoapboxDetectionCalculatorTest,
|
||||
TestMergingThreeStreams) {
|
||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "MergeCalculator"
|
||||
input_stream: "input1"
|
||||
input_stream: "input2"
|
||||
input_stream: "input3"
|
||||
output_stream: "combined_output"
|
||||
)pb"));
|
||||
|
||||
// input1: integer 30 occurring at time 30.
|
||||
runner.MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(new int(30)).At(Timestamp(30)));
|
||||
// input2: float 20.5 occurring at time 20.
|
||||
runner.MutableInputs()->Index(1).packets.push_back(
|
||||
Adopt(new float(20.5)).At(Timestamp(20)));
|
||||
// input3: char 'c' occurring at time 10.
|
||||
runner.MutableInputs()->Index(2).packets.push_back(
|
||||
Adopt(new char('c')).At(Timestamp(10)));
|
||||
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
// Expected combined_output: 'c', 20.5, 30 at times 10, 20, 30.
|
||||
const std::vector<Packet>& actual_output = runner.Outputs().Index(0).packets;
|
||||
ASSERT_EQ(actual_output.size(), 3);
|
||||
EXPECT_EQ(actual_output[0].Timestamp(), Timestamp(10));
|
||||
EXPECT_EQ(actual_output[0].Get<char>(), 'c');
|
||||
|
||||
EXPECT_EQ(actual_output[1].Timestamp(), Timestamp(20));
|
||||
EXPECT_EQ(actual_output[1].Get<float>(), 20.5);
|
||||
|
||||
EXPECT_EQ(actual_output[2].Timestamp(), Timestamp(30));
|
||||
EXPECT_EQ(actual_output[2].Get<int>(), 30);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,56 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// A Calculator that selects an input stream from "INPUT:0", "INPUT:1", ...,
|
||||
// using the integer value (0, 1, ...) in the packet on the "SELECT" input
|
||||
// stream, and passes the packet on the selected input stream to the "OUTPUT"
|
||||
// output stream.
|
||||
//
|
||||
// Note that this calculator defaults to use MuxInputStreamHandler, which is
|
||||
// required for this calculator. However, it can be overridden to work with
|
||||
// other InputStreamHandlers. Check out the unit tests on for an example usage
|
||||
// with DefaultInputStreamHandler.
|
||||
// TODO: why would you need to use DefaultISH? Perhaps b/167596925?
|
||||
class MuxCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<int>::SideFallback kSelect{"SELECT"};
|
||||
// TODO: this currently sets them all to Any independently, instead
|
||||
// of the first being Any and the others being SameAs.
|
||||
static constexpr Input<AnyType>::Multiple kIn{"INPUT"};
|
||||
static constexpr Output<SameType<kIn>> kOut{"OUTPUT"};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kSelect, kIn, kOut,
|
||||
StreamHandler("MuxInputStreamHandler"));
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
int select = *kSelect(cc);
|
||||
RET_CHECK(0 <= select && select < kIn(cc).Count());
|
||||
if (!kIn(cc)[select].IsEmpty()) {
|
||||
kOut(cc).Send(kIn(cc)[select].packet());
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
MEDIAPIPE_REGISTER_NODE(MuxCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,304 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "mediapipe/calculators/core/split_vector_calculator.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
typedef SplitVectorCalculator<int, false> SplitIntVectorCalculator;
|
||||
REGISTER_CALCULATOR(SplitIntVectorCalculator);
|
||||
|
||||
namespace {
|
||||
|
||||
// Graph with default input stream handler, and the input selection is driven
|
||||
// by an input stream. All MuxCalculator inputs are present at each timestamp.
|
||||
constexpr char kTestGraphConfig1[] = R"pb(
|
||||
input_stream: "input"
|
||||
output_stream: "test_output"
|
||||
node {
|
||||
calculator: "SplitIntVectorCalculator"
|
||||
input_stream: "input"
|
||||
output_stream: "stream0"
|
||||
output_stream: "stream1"
|
||||
output_stream: "stream2"
|
||||
output_stream: "input_select"
|
||||
options {
|
||||
[mediapipe.SplitVectorCalculatorOptions.ext] {
|
||||
ranges: { begin: 0 end: 1 }
|
||||
ranges: { begin: 1 end: 2 }
|
||||
ranges: { begin: 2 end: 3 }
|
||||
ranges: { begin: 3 end: 4 }
|
||||
element_only: true
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
calculator: "MuxCalculator"
|
||||
input_stream: "INPUT:0:stream0"
|
||||
input_stream: "INPUT:1:stream1"
|
||||
input_stream: "INPUT:2:stream2"
|
||||
input_stream: "SELECT:input_select"
|
||||
output_stream: "OUTPUT:test_output"
|
||||
input_stream_handler { input_stream_handler: "DefaultInputStreamHandler" }
|
||||
}
|
||||
)pb";
|
||||
|
||||
// Graph with default input stream handler, and the input selection is driven
|
||||
// by an input side packet. All MuxCalculator inputs are present at each
|
||||
// timestamp.
|
||||
constexpr char kTestGraphConfig2[] = R"pb(
|
||||
input_side_packet: "input_selector"
|
||||
input_stream: "input"
|
||||
output_stream: "test_output"
|
||||
node {
|
||||
calculator: "SplitIntVectorCalculator"
|
||||
input_stream: "input"
|
||||
output_stream: "stream0"
|
||||
output_stream: "stream1"
|
||||
output_stream: "stream2"
|
||||
options {
|
||||
[mediapipe.SplitVectorCalculatorOptions.ext] {
|
||||
ranges: { begin: 0 end: 1 }
|
||||
ranges: { begin: 1 end: 2 }
|
||||
ranges: { begin: 2 end: 3 }
|
||||
element_only: true
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
calculator: "MuxCalculator"
|
||||
input_stream: "INPUT:0:stream0"
|
||||
input_stream: "INPUT:1:stream1"
|
||||
input_stream: "INPUT:2:stream2"
|
||||
input_side_packet: "SELECT:input_selector"
|
||||
output_stream: "OUTPUT:test_output"
|
||||
input_stream_handler { input_stream_handler: "DefaultInputStreamHandler" }
|
||||
}
|
||||
)pb";
|
||||
|
||||
// Graph with mux input stream handler, and the input selection is driven
|
||||
// by an input stream. Only one MuxCalculator input is present at each
|
||||
// timestamp.
|
||||
constexpr char kTestGraphConfig3[] = R"pb(
|
||||
input_stream: "input"
|
||||
output_stream: "test_output"
|
||||
node {
|
||||
calculator: "RoundRobinDemuxCalculator"
|
||||
input_stream: "input"
|
||||
output_stream: "OUTPUT:0:stream0"
|
||||
output_stream: "OUTPUT:1:stream1"
|
||||
output_stream: "OUTPUT:2:stream2"
|
||||
output_stream: "SELECT:input_select"
|
||||
}
|
||||
node {
|
||||
calculator: "MuxCalculator"
|
||||
input_stream: "INPUT:0:stream0"
|
||||
input_stream: "INPUT:1:stream1"
|
||||
input_stream: "INPUT:2:stream2"
|
||||
input_stream: "SELECT:input_select"
|
||||
output_stream: "OUTPUT:test_output"
|
||||
}
|
||||
)pb";
|
||||
|
||||
constexpr char kOutputName[] = "test_output";
|
||||
constexpr char kInputName[] = "input";
|
||||
constexpr char kInputSelector[] = "input_selector";
|
||||
|
||||
// Helper to run a graph with the given inputs and generate outputs, asserting
|
||||
// each step along the way.
|
||||
// Inputs:
|
||||
// graph_config_proto - graph config protobuf
|
||||
// extra_side_packets - input side packets name to value map
|
||||
// input_stream_name - name of the input
|
||||
void RunGraph(const std::string& graph_config_proto,
|
||||
const std::map<std::string, Packet>& extra_side_packets,
|
||||
const std::string& input_stream_name, int num_input_packets,
|
||||
std::function<Packet(int)> input_fn,
|
||||
const std::string& output_stream_name,
|
||||
std::function<absl::Status(const Packet&)> output_fn) {
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(graph_config_proto);
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(config));
|
||||
MP_ASSERT_OK(graph.ObserveOutputStream(output_stream_name, output_fn));
|
||||
MP_ASSERT_OK(graph.StartRun(extra_side_packets));
|
||||
for (int i = 0; i < num_input_packets; ++i) {
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(input_stream_name, input_fn(i)));
|
||||
}
|
||||
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST(MuxCalculatorTest, InputStreamSelector_DefaultInputStreamHandler) {
|
||||
// Input and handling.
|
||||
std::vector<std::vector<int>> input_packets = {
|
||||
{1, 1, 2, 1}, {3, 5, 8, 2}, {13, 21, 34, 0},
|
||||
{55, 89, 144, 2}, {233, 377, 610, 0}, {987, 1597, 2584, 1},
|
||||
{4181, 6765, 10946, 2},
|
||||
};
|
||||
int packet_time_stamp = 22;
|
||||
// This function will return the i-th input packet.
|
||||
auto input_fn = [&packet_time_stamp, &input_packets](int i) -> Packet {
|
||||
return MakePacket<std::vector<int>>(input_packets[i])
|
||||
.At(Timestamp(packet_time_stamp++));
|
||||
};
|
||||
|
||||
// Output and handling.
|
||||
std::vector<int> output;
|
||||
// This function collects the output from the packet.
|
||||
auto output_fn = [&output](const Packet& p) -> absl::Status {
|
||||
output.push_back(p.Get<int>());
|
||||
return absl::OkStatus();
|
||||
};
|
||||
|
||||
RunGraph(kTestGraphConfig1, {}, kInputName, input_packets.size(), input_fn,
|
||||
kOutputName, output_fn);
|
||||
EXPECT_THAT(output, testing::ElementsAre(1, 8, 13, 144, 233, 1597, 10946));
|
||||
}
|
||||
|
||||
TEST(MuxCalculatorTest, InputSidePacketSelector_DefaultInputStreamHandler) {
|
||||
// Input and handling.
|
||||
std::vector<std::vector<int>> input_packets = {
|
||||
{1, 1, 2}, {3, 5, 8}, {13, 21, 34}, {55, 89, 144},
|
||||
{233, 377, 610}, {987, 1597, 2584}, {4181, 6765, 10946},
|
||||
};
|
||||
int packet_time_stamp = 22;
|
||||
// This function will return the i-th input packet.
|
||||
auto input_fn = [&packet_time_stamp, &input_packets](int i) -> Packet {
|
||||
return MakePacket<std::vector<int>>(input_packets[i])
|
||||
.At(Timestamp(packet_time_stamp++));
|
||||
};
|
||||
|
||||
// Output and handling.
|
||||
std::vector<int> output;
|
||||
// This function collects the output from the packet.
|
||||
auto output_fn = [&output](const Packet& p) -> absl::Status {
|
||||
output.push_back(p.Get<int>());
|
||||
return absl::OkStatus();
|
||||
};
|
||||
|
||||
RunGraph(kTestGraphConfig2, {{kInputSelector, MakePacket<int>(0)}},
|
||||
kInputName, input_packets.size(), input_fn, kOutputName, output_fn);
|
||||
EXPECT_THAT(output, testing::ElementsAre(1, 3, 13, 55, 233, 987, 4181));
|
||||
|
||||
output.clear();
|
||||
RunGraph(kTestGraphConfig2, {{kInputSelector, MakePacket<int>(1)}},
|
||||
kInputName, input_packets.size(), input_fn, kOutputName, output_fn);
|
||||
EXPECT_THAT(output, testing::ElementsAre(1, 5, 21, 89, 377, 1597, 6765));
|
||||
|
||||
output.clear();
|
||||
RunGraph(kTestGraphConfig2, {{kInputSelector, MakePacket<int>(2)}},
|
||||
kInputName, input_packets.size(), input_fn, kOutputName, output_fn);
|
||||
EXPECT_THAT(output, testing::ElementsAre(2, 8, 34, 144, 610, 2584, 10946));
|
||||
}
|
||||
|
||||
TEST(MuxCalculatorTest, InputStreamSelector_MuxInputStreamHandler) {
|
||||
// Input and handling.
|
||||
std::vector<int> input_packets = {1, 1, 2, 3, 5, 8, 13,
|
||||
21, 34, 55, 89, 144, 233, 377,
|
||||
610, 987, 1597, 2584, 4181, 6765, 10946};
|
||||
int packet_time_stamp = 22;
|
||||
// This function will return the i-th input packet.
|
||||
auto input_fn = [&packet_time_stamp, &input_packets](int i) -> Packet {
|
||||
return MakePacket<int>(input_packets[i]).At(Timestamp(packet_time_stamp++));
|
||||
};
|
||||
|
||||
// Output and handling.
|
||||
std::vector<int> output;
|
||||
// This function collects the output from the packet.
|
||||
auto output_fn = [&output](const Packet& p) -> absl::Status {
|
||||
output.push_back(p.Get<int>());
|
||||
return absl::OkStatus();
|
||||
};
|
||||
|
||||
RunGraph(kTestGraphConfig3, {}, kInputName, input_packets.size(), input_fn,
|
||||
kOutputName, output_fn);
|
||||
EXPECT_EQ(output, input_packets);
|
||||
}
|
||||
|
||||
constexpr char kDualInputGraphConfig[] = R"pb(
|
||||
input_stream: "input_0"
|
||||
input_stream: "input_1"
|
||||
input_stream: "input_select"
|
||||
output_stream: "test_output"
|
||||
node {
|
||||
calculator: "MuxCalculator"
|
||||
input_stream: "INPUT:0:input_0"
|
||||
input_stream: "INPUT:1:input_1"
|
||||
input_stream: "SELECT:input_select"
|
||||
output_stream: "OUTPUT:test_output"
|
||||
}
|
||||
)pb";
|
||||
|
||||
TEST(MuxCalculatorTest, DiscardSkippedInputs_MuxInputStreamHandler) {
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
kDualInputGraphConfig);
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(config));
|
||||
|
||||
std::shared_ptr<int> output;
|
||||
MP_ASSERT_OK(
|
||||
graph.ObserveOutputStream("test_output", [&output](const Packet& p) {
|
||||
output = p.Get<std::shared_ptr<int>>();
|
||||
return absl::OkStatus();
|
||||
}));
|
||||
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
auto one = std::make_shared<int>(1);
|
||||
auto two = std::make_shared<int>(2);
|
||||
auto three = std::make_shared<int>(3);
|
||||
std::weak_ptr<int> one_weak = one;
|
||||
std::weak_ptr<int> two_weak = two;
|
||||
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"input_0",
|
||||
MakePacket<std::shared_ptr<int>>(std::move(one)).At(Timestamp(0))));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"input_1",
|
||||
MakePacket<std::shared_ptr<int>>(std::move(two)).At(Timestamp(0))));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"input_1",
|
||||
MakePacket<std::shared_ptr<int>>(std::move(three)).At(Timestamp(1))));
|
||||
EXPECT_EQ(one, nullptr);
|
||||
EXPECT_EQ(two, nullptr);
|
||||
EXPECT_EQ(three, nullptr);
|
||||
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"input_select", MakePacket<int>(0).At(Timestamp(0))));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_EQ(*output, 1);
|
||||
EXPECT_NE(one_weak.lock(), nullptr);
|
||||
EXPECT_EQ(two_weak.lock(), nullptr);
|
||||
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"input_select", MakePacket<int>(1).At(Timestamp(1))));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_EQ(*output, 3);
|
||||
|
||||
MP_ASSERT_OK(graph.CloseAllInputStreams());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,54 +0,0 @@
|
|||
// Copyright 2021 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// A Calculator that returns 0 if INPUT is 0, and 1 otherwise.
|
||||
class NonZeroCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<int>::SideFallback kIn{"INPUT"};
|
||||
static constexpr Output<int>::Optional kOut{"OUTPUT"};
|
||||
static constexpr Output<bool>::Optional kBooleanOut{"OUTPUT_BOOL"};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kOut, kBooleanOut);
|
||||
|
||||
absl::Status UpdateContract(CalculatorContract* cc) {
|
||||
RET_CHECK(kOut(cc).IsConnected() || kBooleanOut(cc).IsConnected())
|
||||
<< "At least one output stream is expected.";
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
if (!kIn(cc).IsEmpty()) {
|
||||
bool isNonZero = *kIn(cc) != 0;
|
||||
if (kOut(cc).IsConnected()) {
|
||||
kOut(cc).Send(std::make_unique<int>(isNonZero ? 1 : 0));
|
||||
}
|
||||
if (kBooleanOut(cc).IsConnected()) {
|
||||
kBooleanOut(cc).Send(std::make_unique<bool>(isNonZero));
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
MEDIAPIPE_REGISTER_NODE(NonZeroCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,93 +0,0 @@
|
|||
// Copyright 2021 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/framework/tool/validate_type.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
class NonZeroCalculatorTest : public ::testing::Test {
|
||||
protected:
|
||||
NonZeroCalculatorTest()
|
||||
: runner_(
|
||||
R"pb(
|
||||
calculator: "NonZeroCalculator"
|
||||
input_stream: "INPUT:input"
|
||||
output_stream: "OUTPUT:output"
|
||||
output_stream: "OUTPUT_BOOL:output_bool"
|
||||
)pb") {}
|
||||
|
||||
void SetInput(const std::vector<int>& inputs) {
|
||||
int timestamp = 0;
|
||||
for (const auto input : inputs) {
|
||||
runner_.MutableInputs()
|
||||
->Get("INPUT", 0)
|
||||
.packets.push_back(MakePacket<int>(input).At(Timestamp(timestamp++)));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> GetOutput() {
|
||||
std::vector<int> result;
|
||||
for (const auto output : runner_.Outputs().Get("OUTPUT", 0).packets) {
|
||||
result.push_back(output.Get<int>());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<bool> GetOutputBool() {
|
||||
std::vector<bool> result;
|
||||
for (const auto output : runner_.Outputs().Get("OUTPUT_BOOL", 0).packets) {
|
||||
result.push_back(output.Get<bool>());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
CalculatorRunner runner_;
|
||||
};
|
||||
|
||||
TEST_F(NonZeroCalculatorTest, ProducesZeroOutputForZeroInput) {
|
||||
SetInput({0});
|
||||
|
||||
MP_ASSERT_OK(runner_.Run());
|
||||
|
||||
EXPECT_THAT(GetOutput(), ::testing::ElementsAre(0));
|
||||
EXPECT_THAT(GetOutputBool(), ::testing::ElementsAre(false));
|
||||
}
|
||||
|
||||
TEST_F(NonZeroCalculatorTest, ProducesNonZeroOutputForNonZeroInput) {
|
||||
SetInput({1, 2, 3, -4, 5});
|
||||
|
||||
MP_ASSERT_OK(runner_.Run());
|
||||
|
||||
EXPECT_THAT(GetOutput(), ::testing::ElementsAre(1, 1, 1, 1, 1));
|
||||
EXPECT_THAT(GetOutputBool(),
|
||||
::testing::ElementsAre(true, true, true, true, true));
|
||||
}
|
||||
|
||||
TEST_F(NonZeroCalculatorTest, SwitchesBetweenNonZeroAndZeroOutput) {
|
||||
SetInput({1, 0, 3, 0, 5});
|
||||
MP_ASSERT_OK(runner_.Run());
|
||||
EXPECT_THAT(GetOutput(), ::testing::ElementsAre(1, 0, 1, 0, 1));
|
||||
EXPECT_THAT(GetOutputBool(),
|
||||
::testing::ElementsAre(true, false, true, false, true));
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,117 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// This takes packets from N+1 streams, A_1, A_2, ..., A_N, B.
|
||||
// For every packet that appears in B, outputs the most recent packet from each
|
||||
// of the A_i on a separate stream.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/calculators/core/packet_cloner_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// For every packet received on the last stream, output the latest packet
|
||||
// obtained on all other streams. Therefore, if the last stream outputs at a
|
||||
// higher rate than the others, this effectively clones the packets from the
|
||||
// other streams to match the last.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "PacketClonerCalculator"
|
||||
// input_stream: "first_base_signal"
|
||||
// input_stream: "second_base_signal"
|
||||
// input_stream: "tick_signal"
|
||||
// output_stream: "cloned_first_base_signal"
|
||||
// output_stream: "cloned_second_base_signal"
|
||||
// }
|
||||
//
|
||||
// Related:
|
||||
// packet_cloner_calculator.proto: Options for this calculator.
|
||||
// merge_input_streams_calculator.cc: One output stream.
|
||||
// packet_inner_join_calculator.cc: Don't output unless all inputs are new.
|
||||
class PacketClonerCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
const int tick_signal_index = cc->Inputs().NumEntries() - 1;
|
||||
for (int i = 0; i < tick_signal_index; ++i) {
|
||||
cc->Inputs().Index(i).SetAny();
|
||||
cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(i));
|
||||
}
|
||||
cc->Inputs().Index(tick_signal_index).SetAny();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) final {
|
||||
// Load options.
|
||||
const auto calculator_options =
|
||||
cc->Options<mediapipe::PacketClonerCalculatorOptions>();
|
||||
output_only_when_all_inputs_received_ =
|
||||
calculator_options.output_only_when_all_inputs_received();
|
||||
|
||||
// Parse input streams.
|
||||
tick_signal_index_ = cc->Inputs().NumEntries() - 1;
|
||||
current_.resize(tick_signal_index_);
|
||||
// Pass along the header for each stream if present.
|
||||
for (int i = 0; i < tick_signal_index_; ++i) {
|
||||
if (!cc->Inputs().Index(i).Header().IsEmpty()) {
|
||||
cc->Outputs().Index(i).SetHeader(cc->Inputs().Index(i).Header());
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
// Store input signals.
|
||||
for (int i = 0; i < tick_signal_index_; ++i) {
|
||||
if (!cc->Inputs().Index(i).Value().IsEmpty()) {
|
||||
current_[i] = cc->Inputs().Index(i).Value();
|
||||
}
|
||||
}
|
||||
|
||||
// Output according to the TICK signal.
|
||||
if (!cc->Inputs().Index(tick_signal_index_).Value().IsEmpty()) {
|
||||
if (output_only_when_all_inputs_received_) {
|
||||
// Return if one of the input is null.
|
||||
for (int i = 0; i < tick_signal_index_; ++i) {
|
||||
if (current_[i].IsEmpty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
}
|
||||
}
|
||||
// Output each stream.
|
||||
for (int i = 0; i < tick_signal_index_; ++i) {
|
||||
if (!current_[i].IsEmpty()) {
|
||||
cc->Outputs().Index(i).AddPacket(
|
||||
current_[i].At(cc->InputTimestamp()));
|
||||
} else {
|
||||
cc->Outputs().Index(i).SetNextTimestampBound(
|
||||
cc->InputTimestamp().NextAllowedInStream());
|
||||
}
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<Packet> current_;
|
||||
int tick_signal_index_;
|
||||
bool output_only_when_all_inputs_received_;
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(PacketClonerCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,31 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
option objc_class_prefix = "MediaPipe";
|
||||
|
||||
message PacketClonerCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional PacketClonerCalculatorOptions ext = 258872085;
|
||||
}
|
||||
|
||||
// When true, this calculator will drop received TICK packets if any input
|
||||
// stream hasn't received a packet yet.
|
||||
optional bool output_only_when_all_inputs_received = 1 [default = false];
|
||||
}
|
||||
|
|
@ -1,77 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Calculator that acts like the SQL query:
|
||||
// SELECT *
|
||||
// FROM packets_on_stream1 AS packet1
|
||||
// INNER JOIN packets_on_stream2 AS packet2
|
||||
// ON packet1.timestamp = packet2.timestamp
|
||||
//
|
||||
// In other words, it only emits and forwards packets if all input streams are
|
||||
// not empty.
|
||||
//
|
||||
// Intended for use with FixedSizeInputStreamHandler.
|
||||
//
|
||||
// Related:
|
||||
// packet_cloner_calculator.cc: Repeats last-seen packets from empty inputs.
|
||||
class PacketInnerJoinCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc);
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
int num_streams_;
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(PacketInnerJoinCalculator);
|
||||
|
||||
absl::Status PacketInnerJoinCalculator::GetContract(CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().NumEntries() == cc->Outputs().NumEntries())
|
||||
<< "The number of input and output streams must match.";
|
||||
const int num_streams = cc->Inputs().NumEntries();
|
||||
for (int i = 0; i < num_streams; ++i) {
|
||||
cc->Inputs().Index(i).SetAny();
|
||||
cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(i));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status PacketInnerJoinCalculator::Open(CalculatorContext* cc) {
|
||||
num_streams_ = cc->Inputs().NumEntries();
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status PacketInnerJoinCalculator::Process(CalculatorContext* cc) {
|
||||
for (int i = 0; i < num_streams_; ++i) {
|
||||
if (cc->Inputs().Index(i).Value().IsEmpty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < num_streams_; ++i) {
|
||||
cc->Outputs().Index(i).AddPacket(cc->Inputs().Index(i).Value());
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,101 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/tool/validate_type.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
inline Packet PacketFrom(int i) { return Adopt(new int(i)).At(Timestamp(i)); }
|
||||
|
||||
TEST(PacketInnerJoinCalculatorTest, AllMatching) {
|
||||
// Test case.
|
||||
const std::vector<int> packets_on_stream1 = {0, 1, 2, 3};
|
||||
const std::vector<int> packets_on_stream2 = {0, 1, 2, 3};
|
||||
// Run.
|
||||
CalculatorRunner runner("PacketInnerJoinCalculator", "", 2, 2, 0);
|
||||
for (int packet_load : packets_on_stream1) {
|
||||
runner.MutableInputs()->Index(0).packets.push_back(PacketFrom(packet_load));
|
||||
}
|
||||
for (int packet_load : packets_on_stream2) {
|
||||
runner.MutableInputs()->Index(1).packets.push_back(PacketFrom(packet_load));
|
||||
}
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
// Check.
|
||||
const std::vector<int> expected = {0, 1, 2, 3};
|
||||
ASSERT_EQ(expected.size(), runner.Outputs().Index(0).packets.size());
|
||||
ASSERT_EQ(expected.size(), runner.Outputs().Index(1).packets.size());
|
||||
for (int i = 0; i < expected.size(); ++i) {
|
||||
const Packet packet1 = runner.Outputs().Index(0).packets[i];
|
||||
EXPECT_EQ(expected[i], packet1.Get<int>());
|
||||
EXPECT_EQ(expected[i], packet1.Timestamp().Value());
|
||||
const Packet packet2 = runner.Outputs().Index(1).packets[i];
|
||||
EXPECT_EQ(expected[i], packet2.Get<int>());
|
||||
EXPECT_EQ(expected[i], packet2.Timestamp().Value());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PacketInnerJoinCalculatorTest, NoneMatching) {
|
||||
// Test case.
|
||||
const std::vector<int> packets_on_stream1 = {0, 2};
|
||||
const std::vector<int> packets_on_stream2 = {1, 3};
|
||||
// Run.
|
||||
CalculatorRunner runner("PacketInnerJoinCalculator", "", 2, 2, 0);
|
||||
for (int packet_load : packets_on_stream1) {
|
||||
runner.MutableInputs()->Index(0).packets.push_back(PacketFrom(packet_load));
|
||||
}
|
||||
for (int packet_load : packets_on_stream2) {
|
||||
runner.MutableInputs()->Index(1).packets.push_back(PacketFrom(packet_load));
|
||||
}
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
// Check.
|
||||
EXPECT_TRUE(runner.Outputs().Index(0).packets.empty());
|
||||
EXPECT_TRUE(runner.Outputs().Index(1).packets.empty());
|
||||
}
|
||||
|
||||
TEST(PacketInnerJoinCalculatorTest, SomeMatching) {
|
||||
// Test case.
|
||||
const std::vector<int> packets_on_stream1 = {0, 1, 2, 3, 4, 6};
|
||||
const std::vector<int> packets_on_stream2 = {0, 2, 4, 5, 6};
|
||||
// Run.
|
||||
CalculatorRunner runner("PacketInnerJoinCalculator", "", 2, 2, 0);
|
||||
for (int packet_load : packets_on_stream1) {
|
||||
runner.MutableInputs()->Index(0).packets.push_back(PacketFrom(packet_load));
|
||||
}
|
||||
for (int packet_load : packets_on_stream2) {
|
||||
runner.MutableInputs()->Index(1).packets.push_back(PacketFrom(packet_load));
|
||||
}
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
// Check.
|
||||
const std::vector<int> expected = {0, 2, 4, 6};
|
||||
ASSERT_EQ(expected.size(), runner.Outputs().Index(0).packets.size());
|
||||
ASSERT_EQ(expected.size(), runner.Outputs().Index(1).packets.size());
|
||||
for (int i = 0; i < expected.size(); ++i) {
|
||||
const Packet packet1 = runner.Outputs().Index(0).packets[i];
|
||||
EXPECT_EQ(expected[i], packet1.Get<int>());
|
||||
EXPECT_EQ(expected[i], packet1.Timestamp().Value());
|
||||
const Packet packet2 = runner.Outputs().Index(1).packets[i];
|
||||
EXPECT_EQ(expected[i], packet2.Get<int>());
|
||||
EXPECT_EQ(expected[i], packet2.Timestamp().Value());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,84 +0,0 @@
|
|||
// Copyright 2020 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// For each non empty input packet, emits a single output packet containing a
|
||||
// boolean value "true", "false" in response to empty packets (a.k.a. timestamp
|
||||
// bound updates) This can be used to "flag" the presence of an arbitrary packet
|
||||
// type as input into a downstream calculator.
|
||||
//
|
||||
// Inputs:
|
||||
// PACKET - any type.
|
||||
//
|
||||
// Outputs:
|
||||
// PRESENCE - bool.
|
||||
// "true" if packet is not empty, "false" if there's timestamp bound update
|
||||
// instead.
|
||||
//
|
||||
// Examples:
|
||||
// node: {
|
||||
// calculator: "PacketPresenceCalculator"
|
||||
// input_stream: "PACKET:packet"
|
||||
// output_stream: "PRESENCE:presence"
|
||||
// }
|
||||
//
|
||||
// This calculator can be used in conjuction with GateCalculator in order to
|
||||
// allow/disallow processing. For instance:
|
||||
// node: {
|
||||
// calculator: "PacketPresenceCalculator"
|
||||
// input_stream: "PACKET:value"
|
||||
// output_stream: "PRESENCE:disallow_if_present"
|
||||
// }
|
||||
// node {
|
||||
// calculator: "GateCalculator"
|
||||
// input_stream: "image"
|
||||
// input_stream: "DISALLOW:disallow_if_present"
|
||||
// output_stream: "image_for_processing"
|
||||
// options: {
|
||||
// [mediapipe.GateCalculatorOptions.ext] {
|
||||
// empty_packets_as_allow: true
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class PacketPresenceCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Tag("PACKET").SetAny();
|
||||
cc->Outputs().Tag("PRESENCE").Set<bool>();
|
||||
// Process() function is invoked in response to input stream timestamp
|
||||
// bound updates.
|
||||
cc->SetProcessTimestampBounds(true);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
cc->Outputs()
|
||||
.Tag("PRESENCE")
|
||||
.AddPacket(MakePacket<bool>(!cc->Inputs().Tag("PACKET").IsEmpty())
|
||||
.At(cc->InputTimestamp()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(PacketPresenceCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,85 +0,0 @@
|
|||
// Copyright 2020 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/framework/tool/sink.h"
|
||||
|
||||
namespace mediapipe {
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::Eq;
|
||||
using ::testing::Value;
|
||||
namespace {
|
||||
|
||||
MATCHER_P2(BoolPacket, value, timestamp, "") {
|
||||
return Value(arg.template Get<bool>(), Eq(value)) &&
|
||||
Value(arg.Timestamp(), Eq(timestamp));
|
||||
}
|
||||
|
||||
TEST(PreviousLoopbackCalculator, CorrectTimestamps) {
|
||||
std::vector<Packet> output_packets;
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: 'allow'
|
||||
input_stream: 'value'
|
||||
node {
|
||||
calculator: "GateCalculator"
|
||||
input_stream: 'value'
|
||||
input_stream: 'ALLOW:allow'
|
||||
output_stream: 'gated_value'
|
||||
}
|
||||
node {
|
||||
calculator: 'PacketPresenceCalculator'
|
||||
input_stream: 'PACKET:gated_value'
|
||||
output_stream: 'PRESENCE:presence'
|
||||
}
|
||||
)pb");
|
||||
tool::AddVectorSink("presence", &graph_config, &output_packets);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
auto send_packet = [&graph](int value, bool allow, Timestamp timestamp) {
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"value", MakePacket<int>(value).At(timestamp)));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"allow", MakePacket<bool>(allow).At(timestamp)));
|
||||
};
|
||||
|
||||
send_packet(10, false, Timestamp(10));
|
||||
MP_EXPECT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(output_packets, ElementsAre(BoolPacket(false, Timestamp(10))));
|
||||
|
||||
output_packets.clear();
|
||||
send_packet(20, true, Timestamp(11));
|
||||
MP_EXPECT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(output_packets, ElementsAre(BoolPacket(true, Timestamp(11))));
|
||||
|
||||
MP_EXPECT_OK(graph.CloseAllInputStreams());
|
||||
MP_EXPECT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,696 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/core/packet_resampler_calculator.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace {
|
||||
// Reflect an integer against the lower and upper bound of an interval.
|
||||
int64 ReflectBetween(int64 ts, int64 ts_min, int64 ts_max) {
|
||||
if (ts < ts_min) return 2 * ts_min - ts - 1;
|
||||
if (ts >= ts_max) return 2 * ts_max - ts - 1;
|
||||
return ts;
|
||||
}
|
||||
|
||||
// Creates a secure random number generator for use in ProcessWithJitter.
|
||||
// If no secure random number generator can be constructed, the jitter
|
||||
// option is disabled in order to mainatain a consistent security and
|
||||
// consistent random seeding.
|
||||
std::unique_ptr<RandomBase> CreateSecureRandom(const std::string& seed) {
|
||||
RandomBase* result = nullptr;
|
||||
return std::unique_ptr<RandomBase>(result);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
REGISTER_CALCULATOR(PacketResamplerCalculator);
|
||||
namespace {
|
||||
// Returns a TimestampDiff (assuming microseconds) corresponding to the
|
||||
// given time in seconds.
|
||||
TimestampDiff TimestampDiffFromSeconds(double seconds) {
|
||||
return TimestampDiff(MathUtil::SafeRound<int64, double>(
|
||||
seconds * Timestamp::kTimestampUnitsPerSecond));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
absl::Status PacketResamplerCalculator::GetContract(CalculatorContract* cc) {
|
||||
const auto& resampler_options =
|
||||
cc->Options<PacketResamplerCalculatorOptions>();
|
||||
if (cc->InputSidePackets().HasTag("OPTIONS")) {
|
||||
cc->InputSidePackets().Tag("OPTIONS").Set<CalculatorOptions>();
|
||||
}
|
||||
CollectionItemId input_data_id = cc->Inputs().GetId("DATA", 0);
|
||||
if (!input_data_id.IsValid()) {
|
||||
input_data_id = cc->Inputs().GetId("", 0);
|
||||
}
|
||||
cc->Inputs().Get(input_data_id).SetAny();
|
||||
if (cc->Inputs().HasTag("VIDEO_HEADER")) {
|
||||
cc->Inputs().Tag("VIDEO_HEADER").Set<VideoHeader>();
|
||||
}
|
||||
|
||||
CollectionItemId output_data_id = cc->Outputs().GetId("DATA", 0);
|
||||
if (!output_data_id.IsValid()) {
|
||||
output_data_id = cc->Outputs().GetId("", 0);
|
||||
}
|
||||
cc->Outputs().Get(output_data_id).SetSameAs(&cc->Inputs().Get(input_data_id));
|
||||
if (cc->Outputs().HasTag("VIDEO_HEADER")) {
|
||||
cc->Outputs().Tag("VIDEO_HEADER").Set<VideoHeader>();
|
||||
}
|
||||
|
||||
if (resampler_options.jitter() != 0.0) {
|
||||
RET_CHECK_GT(resampler_options.jitter(), 0.0);
|
||||
RET_CHECK_LE(resampler_options.jitter(), 1.0);
|
||||
RET_CHECK(cc->InputSidePackets().HasTag("SEED"));
|
||||
cc->InputSidePackets().Tag("SEED").Set<std::string>();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) {
|
||||
const auto resampler_options =
|
||||
tool::RetrieveOptions(cc->Options<PacketResamplerCalculatorOptions>(),
|
||||
cc->InputSidePackets(), "OPTIONS");
|
||||
|
||||
flush_last_packet_ = resampler_options.flush_last_packet();
|
||||
jitter_ = resampler_options.jitter();
|
||||
|
||||
input_data_id_ = cc->Inputs().GetId("DATA", 0);
|
||||
if (!input_data_id_.IsValid()) {
|
||||
input_data_id_ = cc->Inputs().GetId("", 0);
|
||||
}
|
||||
output_data_id_ = cc->Outputs().GetId("DATA", 0);
|
||||
if (!output_data_id_.IsValid()) {
|
||||
output_data_id_ = cc->Outputs().GetId("", 0);
|
||||
}
|
||||
|
||||
frame_rate_ = resampler_options.frame_rate();
|
||||
start_time_ = resampler_options.has_start_time()
|
||||
? Timestamp(resampler_options.start_time())
|
||||
: Timestamp::Min();
|
||||
end_time_ = resampler_options.has_end_time()
|
||||
? Timestamp(resampler_options.end_time())
|
||||
: Timestamp::Max();
|
||||
round_limits_ = resampler_options.round_limits();
|
||||
// The frame_rate has a default value of -1.0, so the user must set it!
|
||||
RET_CHECK_LT(0, frame_rate_)
|
||||
<< "The output frame rate must be greater than zero";
|
||||
RET_CHECK_LE(frame_rate_, Timestamp::kTimestampUnitsPerSecond)
|
||||
<< "The output frame rate must be smaller than "
|
||||
<< Timestamp::kTimestampUnitsPerSecond;
|
||||
|
||||
frame_time_usec_ = static_cast<int64>(1000000.0 / frame_rate_);
|
||||
jitter_usec_ = static_cast<int64>(1000000.0 * jitter_ / frame_rate_);
|
||||
RET_CHECK_LE(jitter_usec_, frame_time_usec_);
|
||||
|
||||
video_header_.frame_rate = frame_rate_;
|
||||
|
||||
if (resampler_options.output_header() !=
|
||||
PacketResamplerCalculatorOptions::NONE &&
|
||||
!cc->Inputs().Get(input_data_id_).Header().IsEmpty()) {
|
||||
if (resampler_options.output_header() ==
|
||||
PacketResamplerCalculatorOptions::UPDATE_VIDEO_HEADER) {
|
||||
video_header_ =
|
||||
cc->Inputs().Get(input_data_id_).Header().Get<VideoHeader>();
|
||||
video_header_.frame_rate = frame_rate_;
|
||||
cc->Outputs()
|
||||
.Get(output_data_id_)
|
||||
.SetHeader(Adopt(new VideoHeader(video_header_)));
|
||||
} else {
|
||||
cc->Outputs()
|
||||
.Get(output_data_id_)
|
||||
.SetHeader(cc->Inputs().Get(input_data_id_).Header());
|
||||
}
|
||||
}
|
||||
|
||||
strategy_ = GetSamplingStrategy(resampler_options);
|
||||
|
||||
return strategy_->Open(cc);
|
||||
}
|
||||
|
||||
absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) {
|
||||
if (cc->InputTimestamp() == Timestamp::PreStream() &&
|
||||
cc->Inputs().UsesTags() && cc->Inputs().HasTag("VIDEO_HEADER") &&
|
||||
!cc->Inputs().Tag("VIDEO_HEADER").IsEmpty()) {
|
||||
video_header_ = cc->Inputs().Tag("VIDEO_HEADER").Get<VideoHeader>();
|
||||
video_header_.frame_rate = frame_rate_;
|
||||
if (cc->Inputs().Get(input_data_id_).IsEmpty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
}
|
||||
|
||||
if (absl::Status status = strategy_->Process(cc); !status.ok()) {
|
||||
return status; // Avoid MP_RETURN_IF_ERROR macro for external release.
|
||||
}
|
||||
|
||||
last_packet_ = cc->Inputs().Get(input_data_id_).Value();
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status PacketResamplerCalculator::Close(CalculatorContext* cc) {
|
||||
if (!cc->GraphStatus().ok()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
return strategy_->Close(cc);
|
||||
}
|
||||
|
||||
std::unique_ptr<PacketResamplerStrategy>
|
||||
PacketResamplerCalculator::GetSamplingStrategy(
|
||||
const PacketResamplerCalculatorOptions& options) {
|
||||
if (options.reproducible_sampling()) {
|
||||
if (!options.jitter_with_reflection()) {
|
||||
LOG(WARNING)
|
||||
<< "reproducible_sampling enabled w/ jitter_with_reflection "
|
||||
"disabled. "
|
||||
<< "reproducible_sampling always uses jitter with reflection, "
|
||||
<< "Ignoring jitter_with_reflection setting.";
|
||||
}
|
||||
return absl::make_unique<ReproducibleJitterWithReflectionStrategy>(this);
|
||||
}
|
||||
|
||||
if (options.jitter() == 0) {
|
||||
return absl::make_unique<NoJitterStrategy>(this);
|
||||
}
|
||||
|
||||
if (options.jitter_with_reflection()) {
|
||||
return absl::make_unique<LegacyJitterWithReflectionStrategy>(this);
|
||||
}
|
||||
|
||||
// With jitter and no reflection.
|
||||
return absl::make_unique<JitterWithoutReflectionStrategy>(this);
|
||||
}
|
||||
|
||||
Timestamp PacketResamplerCalculator::PeriodIndexToTimestamp(int64 index) const {
|
||||
CHECK_EQ(jitter_, 0.0);
|
||||
CHECK_NE(first_timestamp_, Timestamp::Unset());
|
||||
return first_timestamp_ + TimestampDiffFromSeconds(index / frame_rate_);
|
||||
}
|
||||
|
||||
int64 PacketResamplerCalculator::TimestampToPeriodIndex(
|
||||
Timestamp timestamp) const {
|
||||
CHECK_EQ(jitter_, 0.0);
|
||||
CHECK_NE(first_timestamp_, Timestamp::Unset());
|
||||
return MathUtil::SafeRound<int64, double>(
|
||||
(timestamp - first_timestamp_).Seconds() * frame_rate_);
|
||||
}
|
||||
|
||||
void PacketResamplerCalculator::OutputWithinLimits(CalculatorContext* cc,
|
||||
const Packet& packet) const {
|
||||
TimestampDiff margin((round_limits_) ? frame_time_usec_ / 2 : 0);
|
||||
if (packet.Timestamp() >= start_time_ - margin &&
|
||||
packet.Timestamp() < end_time_ + margin) {
|
||||
cc->Outputs().Get(output_data_id_).AddPacket(packet);
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status LegacyJitterWithReflectionStrategy::Open(CalculatorContext* cc) {
|
||||
const auto resampler_options =
|
||||
tool::RetrieveOptions(cc->Options<PacketResamplerCalculatorOptions>(),
|
||||
cc->InputSidePackets(), "OPTIONS");
|
||||
|
||||
if (resampler_options.output_header() !=
|
||||
PacketResamplerCalculatorOptions::NONE) {
|
||||
LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not "
|
||||
"the actual value.";
|
||||
}
|
||||
|
||||
if (calculator_->flush_last_packet_) {
|
||||
LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is "
|
||||
"ignored, because we are adding jitter.";
|
||||
}
|
||||
|
||||
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
|
||||
random_ = CreateSecureRandom(seed);
|
||||
if (random_ == nullptr) {
|
||||
return absl::InvalidArgumentError(
|
||||
"SecureRandom is not available. With \"jitter\" specified, "
|
||||
"PacketResamplerCalculator processing cannot proceed.");
|
||||
}
|
||||
|
||||
packet_reservoir_random_ = CreateSecureRandom(seed);
|
||||
packet_reservoir_ =
|
||||
std::make_unique<PacketReservoir>(packet_reservoir_random_.get());
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status LegacyJitterWithReflectionStrategy::Close(CalculatorContext* cc) {
|
||||
if (!packet_reservoir_->IsEmpty()) {
|
||||
LOG(INFO) << "Emitting pack from reservoir.";
|
||||
calculator_->OutputWithinLimits(cc, packet_reservoir_->GetSample());
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status LegacyJitterWithReflectionStrategy::Process(
|
||||
CalculatorContext* cc) {
|
||||
RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());
|
||||
|
||||
if (packet_reservoir_->IsEnabled() &&
|
||||
(first_timestamp_ == Timestamp::Unset() ||
|
||||
(cc->InputTimestamp() - next_output_timestamp_min_).Value() >= 0)) {
|
||||
auto curr_packet = cc->Inputs().Get(calculator_->input_data_id_).Value();
|
||||
packet_reservoir_->AddSample(curr_packet);
|
||||
}
|
||||
|
||||
if (first_timestamp_ == Timestamp::Unset()) {
|
||||
first_timestamp_ = cc->InputTimestamp();
|
||||
InitializeNextOutputTimestampWithJitter();
|
||||
if (first_timestamp_ == next_output_timestamp_) {
|
||||
calculator_->OutputWithinLimits(cc, cc->Inputs()
|
||||
.Get(calculator_->input_data_id_)
|
||||
.Value()
|
||||
.At(next_output_timestamp_));
|
||||
UpdateNextOutputTimestampWithJitter();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
if (calculator_->frame_time_usec_ <
|
||||
(cc->InputTimestamp() - calculator_->last_packet_.Timestamp()).Value()) {
|
||||
LOG_FIRST_N(WARNING, 2)
|
||||
<< "Adding jitter is not very useful when upsampling.";
|
||||
}
|
||||
|
||||
while (true) {
|
||||
const int64 last_diff =
|
||||
(next_output_timestamp_ - calculator_->last_packet_.Timestamp())
|
||||
.Value();
|
||||
RET_CHECK_GT(last_diff, 0);
|
||||
const int64 curr_diff =
|
||||
(next_output_timestamp_ - cc->InputTimestamp()).Value();
|
||||
if (curr_diff > 0) {
|
||||
break;
|
||||
}
|
||||
calculator_->OutputWithinLimits(
|
||||
cc, (std::abs(curr_diff) > last_diff
|
||||
? calculator_->last_packet_
|
||||
: cc->Inputs().Get(calculator_->input_data_id_).Value())
|
||||
.At(next_output_timestamp_));
|
||||
UpdateNextOutputTimestampWithJitter();
|
||||
// From now on every time a packet is emitted the timestamp of the next
|
||||
// packet becomes known; that timestamp is stored in next_output_timestamp_.
|
||||
// The only exception to this rule is the packet emitted from Close() which
|
||||
// can only happen when jitter_with_reflection is enabled but in this case
|
||||
// next_output_timestamp_min_ is a non-decreasing lower bound of any
|
||||
// subsequent packet.
|
||||
const Timestamp timestamp_bound = next_output_timestamp_min_;
|
||||
cc->Outputs()
|
||||
.Get(calculator_->output_data_id_)
|
||||
.SetNextTimestampBound(timestamp_bound);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void LegacyJitterWithReflectionStrategy::
|
||||
InitializeNextOutputTimestampWithJitter() {
|
||||
next_output_timestamp_min_ = first_timestamp_;
|
||||
next_output_timestamp_ =
|
||||
first_timestamp_ +
|
||||
random_->UnbiasedUniform64(calculator_->frame_time_usec_);
|
||||
}
|
||||
|
||||
void LegacyJitterWithReflectionStrategy::UpdateNextOutputTimestampWithJitter() {
|
||||
packet_reservoir_->Clear();
|
||||
next_output_timestamp_min_ += calculator_->frame_time_usec_;
|
||||
Timestamp next_output_timestamp_max_ =
|
||||
next_output_timestamp_min_ + calculator_->frame_time_usec_;
|
||||
|
||||
next_output_timestamp_ +=
|
||||
calculator_->frame_time_usec_ +
|
||||
random_->UnbiasedUniform64(2 * calculator_->jitter_usec_ + 1) -
|
||||
calculator_->jitter_usec_;
|
||||
next_output_timestamp_ = Timestamp(ReflectBetween(
|
||||
next_output_timestamp_.Value(), next_output_timestamp_min_.Value(),
|
||||
next_output_timestamp_max_.Value()));
|
||||
CHECK_GE(next_output_timestamp_, next_output_timestamp_min_);
|
||||
CHECK_LT(next_output_timestamp_, next_output_timestamp_max_);
|
||||
}
|
||||
|
||||
absl::Status ReproducibleJitterWithReflectionStrategy::Open(
|
||||
CalculatorContext* cc) {
|
||||
const auto resampler_options =
|
||||
tool::RetrieveOptions(cc->Options<PacketResamplerCalculatorOptions>(),
|
||||
cc->InputSidePackets(), "OPTIONS");
|
||||
|
||||
if (resampler_options.output_header() !=
|
||||
PacketResamplerCalculatorOptions::NONE) {
|
||||
LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not "
|
||||
"the actual value.";
|
||||
}
|
||||
|
||||
if (calculator_->flush_last_packet_) {
|
||||
LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is "
|
||||
"ignored, because we are adding jitter.";
|
||||
}
|
||||
|
||||
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
|
||||
random_ = CreateSecureRandom(seed);
|
||||
if (random_ == nullptr) {
|
||||
return absl::InvalidArgumentError(
|
||||
"SecureRandom is not available. With \"jitter\" specified, "
|
||||
"PacketResamplerCalculator processing cannot proceed.");
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status ReproducibleJitterWithReflectionStrategy::Close(
|
||||
CalculatorContext* cc) {
|
||||
// If last packet is non-empty and a packet hasn't been emitted for this
|
||||
// period, emit the last packet.
|
||||
if (!calculator_->last_packet_.IsEmpty() && !packet_emitted_this_period_) {
|
||||
calculator_->OutputWithinLimits(
|
||||
cc, calculator_->last_packet_.At(next_output_timestamp_));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status ReproducibleJitterWithReflectionStrategy::Process(
|
||||
CalculatorContext* cc) {
|
||||
RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());
|
||||
|
||||
Packet current_packet = cc->Inputs().Get(calculator_->input_data_id_).Value();
|
||||
|
||||
if (calculator_->last_packet_.IsEmpty()) {
|
||||
// last_packet is empty, this is the first packet of the stream.
|
||||
|
||||
InitializeNextOutputTimestamp(current_packet.Timestamp());
|
||||
|
||||
// If next_output_timestamp_ happens to fall before current_packet, emit
|
||||
// current packet. Only a single packet can be emitted at the beginning
|
||||
// of the stream.
|
||||
if (next_output_timestamp_ < current_packet.Timestamp()) {
|
||||
calculator_->OutputWithinLimits(
|
||||
cc, current_packet.At(next_output_timestamp_));
|
||||
packet_emitted_this_period_ = true;
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Last packet is set, so we are mid-stream.
|
||||
if (calculator_->frame_time_usec_ <
|
||||
(current_packet.Timestamp() - calculator_->last_packet_.Timestamp())
|
||||
.Value()) {
|
||||
// Note, if the stream is upsampling, this could lead to the same packet
|
||||
// being emitted twice. Upsampling and jitter doesn't make much sense
|
||||
// but does technically work.
|
||||
LOG_FIRST_N(WARNING, 2)
|
||||
<< "Adding jitter is not very useful when upsampling.";
|
||||
}
|
||||
|
||||
// Since we may be upsampling, we need to iteratively advance the
|
||||
// next_output_timestamp_ one period at a time until it reaches the period
|
||||
// current_packet is in. During this process, last_packet and/or
|
||||
// current_packet may be repeatly emitted.
|
||||
|
||||
UpdateNextOutputTimestamp(current_packet.Timestamp());
|
||||
|
||||
while (!packet_emitted_this_period_ &&
|
||||
next_output_timestamp_ <= current_packet.Timestamp()) {
|
||||
// last_packet < next_output_timestamp_ <= current_packet,
|
||||
// so emit the closest packet.
|
||||
Packet packet_to_emit =
|
||||
current_packet.Timestamp() - next_output_timestamp_ <
|
||||
next_output_timestamp_ - calculator_->last_packet_.Timestamp()
|
||||
? current_packet
|
||||
: calculator_->last_packet_;
|
||||
calculator_->OutputWithinLimits(cc,
|
||||
packet_to_emit.At(next_output_timestamp_));
|
||||
|
||||
packet_emitted_this_period_ = true;
|
||||
|
||||
// If we are upsampling, packet_emitted_this_period_ can be reset by
|
||||
// the following UpdateNext and the loop will iterate.
|
||||
UpdateNextOutputTimestamp(current_packet.Timestamp());
|
||||
}
|
||||
|
||||
// Set the bounds on the output stream. Note, if we emitted a packet
|
||||
// above, it will already be set at next_output_timestamp_ + 1, in which
|
||||
// case we have to skip setting it.
|
||||
if (cc->Outputs().Get(calculator_->output_data_id_).NextTimestampBound() <
|
||||
next_output_timestamp_) {
|
||||
cc->Outputs()
|
||||
.Get(calculator_->output_data_id_)
|
||||
.SetNextTimestampBound(next_output_timestamp_);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void ReproducibleJitterWithReflectionStrategy::InitializeNextOutputTimestamp(
|
||||
Timestamp current_timestamp) {
|
||||
if (next_output_timestamp_min_ != Timestamp::Unset()) {
|
||||
return;
|
||||
}
|
||||
|
||||
next_output_timestamp_min_ = Timestamp(0);
|
||||
next_output_timestamp_ =
|
||||
Timestamp(GetNextRandom(calculator_->frame_time_usec_));
|
||||
|
||||
// While the current timestamp is ahead of the max (i.e. min + frame_time),
|
||||
// fast-forward.
|
||||
while (current_timestamp >=
|
||||
next_output_timestamp_min_ + calculator_->frame_time_usec_) {
|
||||
packet_emitted_this_period_ = true; // Force update...
|
||||
UpdateNextOutputTimestamp(current_timestamp);
|
||||
}
|
||||
}
|
||||
|
||||
void ReproducibleJitterWithReflectionStrategy::UpdateNextOutputTimestamp(
|
||||
Timestamp current_timestamp) {
|
||||
if (packet_emitted_this_period_ &&
|
||||
current_timestamp >=
|
||||
next_output_timestamp_min_ + calculator_->frame_time_usec_) {
|
||||
next_output_timestamp_min_ += calculator_->frame_time_usec_;
|
||||
Timestamp next_output_timestamp_max_ =
|
||||
next_output_timestamp_min_ + calculator_->frame_time_usec_;
|
||||
|
||||
next_output_timestamp_ += calculator_->frame_time_usec_ +
|
||||
GetNextRandom(2 * calculator_->jitter_usec_ + 1) -
|
||||
calculator_->jitter_usec_;
|
||||
next_output_timestamp_ = Timestamp(ReflectBetween(
|
||||
next_output_timestamp_.Value(), next_output_timestamp_min_.Value(),
|
||||
next_output_timestamp_max_.Value()));
|
||||
|
||||
packet_emitted_this_period_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status JitterWithoutReflectionStrategy::Open(CalculatorContext* cc) {
|
||||
const auto resampler_options =
|
||||
tool::RetrieveOptions(cc->Options<PacketResamplerCalculatorOptions>(),
|
||||
cc->InputSidePackets(), "OPTIONS");
|
||||
|
||||
if (resampler_options.output_header() !=
|
||||
PacketResamplerCalculatorOptions::NONE) {
|
||||
LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not "
|
||||
"the actual value.";
|
||||
}
|
||||
|
||||
if (calculator_->flush_last_packet_) {
|
||||
LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is "
|
||||
"ignored, because we are adding jitter.";
|
||||
}
|
||||
|
||||
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
|
||||
random_ = CreateSecureRandom(seed);
|
||||
if (random_ == nullptr) {
|
||||
return absl::InvalidArgumentError(
|
||||
"SecureRandom is not available. With \"jitter\" specified, "
|
||||
"PacketResamplerCalculator processing cannot proceed.");
|
||||
}
|
||||
|
||||
packet_reservoir_random_ = CreateSecureRandom(seed);
|
||||
packet_reservoir_ =
|
||||
absl::make_unique<PacketReservoir>(packet_reservoir_random_.get());
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status JitterWithoutReflectionStrategy::Close(CalculatorContext* cc) {
|
||||
if (!packet_reservoir_->IsEmpty()) {
|
||||
calculator_->OutputWithinLimits(cc, packet_reservoir_->GetSample());
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status JitterWithoutReflectionStrategy::Process(CalculatorContext* cc) {
|
||||
RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());
|
||||
|
||||
// Packet reservior is used to make sure there's an output for every period,
|
||||
// e.g. partial period at the end of the stream.
|
||||
if (packet_reservoir_->IsEnabled() &&
|
||||
(calculator_->first_timestamp_ == Timestamp::Unset() ||
|
||||
(cc->InputTimestamp() - next_output_timestamp_min_).Value() >= 0)) {
|
||||
auto curr_packet = cc->Inputs().Get(calculator_->input_data_id_).Value();
|
||||
packet_reservoir_->AddSample(curr_packet);
|
||||
}
|
||||
|
||||
if (calculator_->first_timestamp_ == Timestamp::Unset()) {
|
||||
calculator_->first_timestamp_ = cc->InputTimestamp();
|
||||
InitializeNextOutputTimestamp();
|
||||
if (calculator_->first_timestamp_ == next_output_timestamp_) {
|
||||
calculator_->OutputWithinLimits(cc, cc->Inputs()
|
||||
.Get(calculator_->input_data_id_)
|
||||
.Value()
|
||||
.At(next_output_timestamp_));
|
||||
UpdateNextOutputTimestamp();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
if (calculator_->frame_time_usec_ <
|
||||
(cc->InputTimestamp() - calculator_->last_packet_.Timestamp()).Value()) {
|
||||
LOG_FIRST_N(WARNING, 2)
|
||||
<< "Adding jitter is not very useful when upsampling.";
|
||||
}
|
||||
|
||||
while (true) {
|
||||
const int64 last_diff =
|
||||
(next_output_timestamp_ - calculator_->last_packet_.Timestamp())
|
||||
.Value();
|
||||
RET_CHECK_GT(last_diff, 0);
|
||||
const int64 curr_diff =
|
||||
(next_output_timestamp_ - cc->InputTimestamp()).Value();
|
||||
if (curr_diff > 0) {
|
||||
break;
|
||||
}
|
||||
calculator_->OutputWithinLimits(
|
||||
cc, (std::abs(curr_diff) > last_diff
|
||||
? calculator_->last_packet_
|
||||
: cc->Inputs().Get(calculator_->input_data_id_).Value())
|
||||
.At(next_output_timestamp_));
|
||||
UpdateNextOutputTimestamp();
|
||||
cc->Outputs()
|
||||
.Get(calculator_->output_data_id_)
|
||||
.SetNextTimestampBound(next_output_timestamp_);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void JitterWithoutReflectionStrategy::InitializeNextOutputTimestamp() {
|
||||
next_output_timestamp_min_ = calculator_->first_timestamp_;
|
||||
next_output_timestamp_ = calculator_->first_timestamp_ +
|
||||
calculator_->frame_time_usec_ * random_->RandFloat();
|
||||
}
|
||||
|
||||
void JitterWithoutReflectionStrategy::UpdateNextOutputTimestamp() {
|
||||
packet_reservoir_->Clear();
|
||||
packet_reservoir_->Disable();
|
||||
next_output_timestamp_ += calculator_->frame_time_usec_ *
|
||||
((1.0 - calculator_->jitter_) +
|
||||
2.0 * calculator_->jitter_ * random_->RandFloat());
|
||||
}
|
||||
|
||||
absl::Status NoJitterStrategy::Open(CalculatorContext* cc) {
|
||||
const auto resampler_options =
|
||||
tool::RetrieveOptions(cc->Options<PacketResamplerCalculatorOptions>(),
|
||||
cc->InputSidePackets(), "OPTIONS");
|
||||
base_timestamp_ = resampler_options.has_base_timestamp()
|
||||
? Timestamp(resampler_options.base_timestamp())
|
||||
: Timestamp::Unset();
|
||||
|
||||
period_count_ = 0;
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status NoJitterStrategy::Close(CalculatorContext* cc) {
|
||||
// Emit the last packet received if we have at least one packet, but
|
||||
// haven't sent anything for its period.
|
||||
if (calculator_->first_timestamp_ != Timestamp::Unset() &&
|
||||
calculator_->flush_last_packet_ &&
|
||||
calculator_->TimestampToPeriodIndex(
|
||||
calculator_->last_packet_.Timestamp()) == period_count_) {
|
||||
calculator_->OutputWithinLimits(
|
||||
cc, calculator_->last_packet_.At(
|
||||
calculator_->PeriodIndexToTimestamp(period_count_)));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status NoJitterStrategy::Process(CalculatorContext* cc) {
|
||||
RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());
|
||||
|
||||
if (calculator_->first_timestamp_ == Timestamp::Unset()) {
|
||||
// This is the first packet, initialize the first_timestamp_.
|
||||
if (base_timestamp_ == Timestamp::Unset()) {
|
||||
// Initialize first_timestamp_ with exactly the first packet timestamp.
|
||||
calculator_->first_timestamp_ = cc->InputTimestamp();
|
||||
} else {
|
||||
// Initialize first_timestamp_ with the first packet timestamp
|
||||
// aligned to the base_timestamp_.
|
||||
int64 first_index = MathUtil::SafeRound<int64, double>(
|
||||
(cc->InputTimestamp() - base_timestamp_).Seconds() *
|
||||
calculator_->frame_rate_);
|
||||
calculator_->first_timestamp_ =
|
||||
base_timestamp_ +
|
||||
TimestampDiffFromSeconds(first_index / calculator_->frame_rate_);
|
||||
}
|
||||
if (cc->Outputs().UsesTags() && cc->Outputs().HasTag("VIDEO_HEADER")) {
|
||||
cc->Outputs()
|
||||
.Tag("VIDEO_HEADER")
|
||||
.Add(new VideoHeader(calculator_->video_header_),
|
||||
Timestamp::PreStream());
|
||||
}
|
||||
}
|
||||
const Timestamp received_timestamp = cc->InputTimestamp();
|
||||
const int64 received_timestamp_idx =
|
||||
calculator_->TimestampToPeriodIndex(received_timestamp);
|
||||
// Only consider the received packet if it belongs to the current period
|
||||
// (== period_count_) or to a newer one (> period_count_).
|
||||
if (received_timestamp_idx >= period_count_) {
|
||||
// Fill the empty periods until we are in the same index as the received
|
||||
// packet.
|
||||
while (received_timestamp_idx > period_count_) {
|
||||
calculator_->OutputWithinLimits(
|
||||
cc, calculator_->last_packet_.At(
|
||||
calculator_->PeriodIndexToTimestamp(period_count_)));
|
||||
++period_count_;
|
||||
}
|
||||
// Now, if the received packet has a timestamp larger than the middle of
|
||||
// the current period, we can send a packet without waiting. We send the
|
||||
// one closer to the middle.
|
||||
Timestamp target_timestamp =
|
||||
calculator_->PeriodIndexToTimestamp(period_count_);
|
||||
if (received_timestamp >= target_timestamp) {
|
||||
bool have_last_packet =
|
||||
(calculator_->last_packet_.Timestamp() != Timestamp::Unset());
|
||||
bool send_current =
|
||||
!have_last_packet ||
|
||||
(received_timestamp - target_timestamp <=
|
||||
target_timestamp - calculator_->last_packet_.Timestamp());
|
||||
if (send_current) {
|
||||
calculator_->OutputWithinLimits(cc,
|
||||
cc->Inputs()
|
||||
.Get(calculator_->input_data_id_)
|
||||
.Value()
|
||||
.At(target_timestamp));
|
||||
} else {
|
||||
calculator_->OutputWithinLimits(
|
||||
cc, calculator_->last_packet_.At(target_timestamp));
|
||||
}
|
||||
++period_count_;
|
||||
}
|
||||
// TODO: Add a mechanism to the framework to allow these packets
|
||||
// to be output earlier (without waiting for a much later packet to
|
||||
// arrive)
|
||||
|
||||
// Update the bound for the next packet.
|
||||
cc->Outputs()
|
||||
.Get(calculator_->output_data_id_)
|
||||
.SetNextTimestampBound(
|
||||
calculator_->PeriodIndexToTimestamp(period_count_));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,392 +0,0 @@
|
|||
#ifndef MEDIAPIPE_CALCULATORS_CORE_PACKET_RESAMPLER_CALCULATOR_H_
|
||||
#define MEDIAPIPE_CALCULATORS_CORE_PACKET_RESAMPLER_CALCULATOR_H_
|
||||
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/calculators/core/packet_resampler_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/collection_item_id.h"
|
||||
#include "mediapipe/framework/deps/mathutil.h"
|
||||
#include "mediapipe/framework/deps/random_base.h"
|
||||
#include "mediapipe/framework/formats/video_stream_header.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/framework/tool/options_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
class PacketReservoir {
|
||||
public:
|
||||
PacketReservoir(RandomBase* rng) : rng_(rng) {}
|
||||
// Replace candidate with current packet with 1/count_ probability.
|
||||
void AddSample(Packet sample) {
|
||||
if (rng_->UnbiasedUniform(++count_) == 0) {
|
||||
reservoir_ = sample;
|
||||
}
|
||||
}
|
||||
bool IsEnabled() { return rng_ && enabled_; }
|
||||
void Disable() {
|
||||
if (enabled_) enabled_ = false;
|
||||
}
|
||||
void Clear() { count_ = 0; }
|
||||
bool IsEmpty() { return count_ == 0; }
|
||||
Packet GetSample() { return reservoir_; }
|
||||
|
||||
private:
|
||||
RandomBase* rng_;
|
||||
bool enabled_ = true;
|
||||
int32 count_ = 0;
|
||||
Packet reservoir_;
|
||||
};
|
||||
|
||||
// This calculator is used to normalize the frequency of the packets
|
||||
// out of a stream. Given a desired frame rate, packets are going to be
|
||||
// removed or added to achieve it.
|
||||
//
|
||||
// If jitter_ is specified:
|
||||
// - The first packet is chosen randomly (uniform distribution) among frames
|
||||
// that correspond to timestamps [0, 1/frame_rate). Let the chosen packet
|
||||
// correspond to timestamp t.
|
||||
// - The next packet is chosen randomly (uniform distribution) among frames
|
||||
// that correspond to [t+(1-jitter)/frame_rate, t+(1+jitter)/frame_rate].
|
||||
// - if jitter_with_reflection is true, the timestamp will be reflected
|
||||
// against the boundaries of [t_0 + (k-1)/frame_rate, t_0 + k/frame_rate)
|
||||
// so that its marginal distribution is uniform within this interval.
|
||||
// In the formula, t_0 is the timestamp of the first sampled
|
||||
// packet, and the k is the packet index.
|
||||
// See paper (https://arxiv.org/abs/2002.01147) for details.
|
||||
// - t is updated and the process is repeated.
|
||||
// - Note that seed is specified as input side packet for reproducibility of
|
||||
// the resampling. For Cloud ML Video Intelligence API, the hash of the
|
||||
// input video should serve this purpose. For YouTube, either video ID or
|
||||
// content hex ID of the input video should do.
|
||||
// - If reproducible_samping is true, care is taken to allow reproducible
|
||||
// "mid-stream" sampling. The calculator can be executed on a stream that
|
||||
// doesn't start at the first period. For instance, if the calculator
|
||||
// is run on a 10 second stream it will produce the same set of samples
|
||||
// as two runs of the calculator, the first with 3 seconds of input starting
|
||||
// at time 0 and the second with 7 seconds of input starting at time +3s.
|
||||
// - In order to guarantee the exact same samples, 1) the inputs must be
|
||||
// aligned with the sampling period. For instance, if the sampling rate
|
||||
// is 2 frames per second, streams should be aligned on 0.5 second
|
||||
// boundaries, and 2) the stream must include at least one extra packet
|
||||
// before and after the second aligned sampling period.
|
||||
//
|
||||
// If jitter_ is not specified:
|
||||
// - The first packet defines the first_timestamp of the output stream,
|
||||
// so it is always emitted.
|
||||
// - If more packets are emitted, they will have timestamp equal to
|
||||
// round(first_timestamp + k * period) , where k is a positive
|
||||
// integer and the period is defined by the frame rate.
|
||||
// Example: first_timestamp=0, fps=30, then the output stream
|
||||
// will have timestamps: 0, 33333, 66667, 100000, etc...
|
||||
// - The packets selected for the output stream are the ones closer
|
||||
// to the exact middle point (33333.33, 66666.67 in our previous
|
||||
// example). In case of ties, later packets are chosen.
|
||||
// - 'Empty' periods happen when there are no packets for a long time
|
||||
// (greater than a period). In this case, we send a copy of the last
|
||||
// packet received before the empty period.
|
||||
// The jitter feature is disabled by default. To enable it, you need to
|
||||
// implement CreateSecureRandom(const std::string&).
|
||||
//
|
||||
// The data stream may be either specified as the only stream (by index)
|
||||
// or as the stream with tag "DATA".
|
||||
//
|
||||
// The input and output streams may be accompanied by a VIDEO_HEADER
|
||||
// stream. This stream includes a VideoHeader at Timestamp::PreStream().
|
||||
// The input VideoHeader on the VIDEO_HEADER stream will always be updated
|
||||
// with the resampler frame rate no matter what the options value for
|
||||
// output_header is before being output on the output VIDEO_HEADER stream.
|
||||
// If the input VideoHeader is not available, then only the frame rate
|
||||
// value will be set in the output.
|
||||
//
|
||||
// Related:
|
||||
// packet_downsampler_calculator.cc: skips packets regardless of timestamps.
|
||||
class PacketResamplerCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc);
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Close(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
// Given the current count of periods that have passed, this returns
|
||||
// the next valid timestamp of the middle point of the next period:
|
||||
// if count is 0, it returns the first_timestamp_.
|
||||
// if count is 1, it returns the first_timestamp_ + period (corresponding
|
||||
// to the first tick using exact fps)
|
||||
// e.g. for frame_rate=30 and first_timestamp_=0:
|
||||
// 0: 0
|
||||
// 1: 33333
|
||||
// 2: 66667
|
||||
// 3: 100000
|
||||
//
|
||||
// Can only be used if jitter_ equals zero.
|
||||
Timestamp PeriodIndexToTimestamp(int64 index) const;
|
||||
|
||||
// Given a Timestamp, finds the closest sync Timestamp based on
|
||||
// first_timestamp_ and the desired fps.
|
||||
//
|
||||
// Can only be used if jitter_ equals zero.
|
||||
int64 TimestampToPeriodIndex(Timestamp timestamp) const;
|
||||
|
||||
// Outputs a packet if it is in range (start_time_, end_time_).
|
||||
void OutputWithinLimits(CalculatorContext* cc, const Packet& packet) const;
|
||||
|
||||
protected:
|
||||
// Returns Sampling Strategy to use.
|
||||
//
|
||||
// Virtual to allow injection of testing strategies.
|
||||
virtual std::unique_ptr<class PacketResamplerStrategy> GetSamplingStrategy(
|
||||
const mediapipe::PacketResamplerCalculatorOptions& options);
|
||||
|
||||
private:
|
||||
std::unique_ptr<class PacketResamplerStrategy> strategy_;
|
||||
|
||||
// The timestamp of the first packet received.
|
||||
Timestamp first_timestamp_;
|
||||
|
||||
// Number of frames per second (desired output frequency).
|
||||
double frame_rate_;
|
||||
|
||||
// Inverse of frame_rate_.
|
||||
int64 frame_time_usec_;
|
||||
|
||||
VideoHeader video_header_;
|
||||
// The "DATA" input stream.
|
||||
CollectionItemId input_data_id_;
|
||||
// The "DATA" output stream.
|
||||
CollectionItemId output_data_id_;
|
||||
|
||||
// Indicator whether to flush last packet even if its timestamp is greater
|
||||
// than the final stream timestamp.
|
||||
bool flush_last_packet_;
|
||||
|
||||
double jitter_ = 0.0;
|
||||
|
||||
int64 jitter_usec_;
|
||||
|
||||
// The last packet that was received.
|
||||
Packet last_packet_;
|
||||
|
||||
// If specified, only outputs at/after start_time are included.
|
||||
Timestamp start_time_;
|
||||
|
||||
// If specified, only outputs before end_time are included.
|
||||
Timestamp end_time_;
|
||||
|
||||
// If set, the output timestamps nearest to start_time and end_time
|
||||
// are included in the output, even if the nearest timestamp is not
|
||||
// between start_time and end_time.
|
||||
bool round_limits_;
|
||||
|
||||
// Allow strategies access to all internal calculator state.
|
||||
//
|
||||
// The calculator and strategies are intimiately tied together so this should
|
||||
// not break encapsulation.
|
||||
friend class LegacyJitterWithReflectionStrategy;
|
||||
friend class ReproducibleJitterWithReflectionStrategy;
|
||||
friend class JitterWithoutReflectionStrategy;
|
||||
friend class NoJitterStrategy;
|
||||
};
|
||||
|
||||
// Abstract class encapsulating sampling stategy.
|
||||
//
|
||||
// These are used solely by PacketResamplerCalculator, but are exposed here
|
||||
// to facilitate tests.
|
||||
class PacketResamplerStrategy {
|
||||
public:
|
||||
PacketResamplerStrategy(PacketResamplerCalculator* calculator)
|
||||
: calculator_(calculator) {}
|
||||
virtual ~PacketResamplerStrategy() = default;
|
||||
|
||||
// Delegate for CalculatorBase::Open. See CalculatorBase for relevant
|
||||
// implementation considerations.
|
||||
virtual absl::Status Open(CalculatorContext* cc) = 0;
|
||||
// Delegate for CalculatorBase::Close. See CalculatorBase for relevant
|
||||
// implementation considerations.
|
||||
virtual absl::Status Close(CalculatorContext* cc) = 0;
|
||||
// Delegate for CalculatorBase::Process. See CalculatorBase for relevant
|
||||
// implementation considerations.
|
||||
virtual absl::Status Process(CalculatorContext* cc) = 0;
|
||||
|
||||
protected:
|
||||
// Calculator running strategy.
|
||||
PacketResamplerCalculator* calculator_;
|
||||
};
|
||||
|
||||
// Strategy that applies Jitter with reflection based sampling.
|
||||
//
|
||||
// Used by PacketResamplerCalculator when both Jitter and reflection are
|
||||
// enabled.
|
||||
//
|
||||
// This applies the legacy jitter with reflection which doesn't allow
|
||||
// for reproducibility of sampling when starting mid-stream. This is maintained
|
||||
// for backward compatibility.
|
||||
class LegacyJitterWithReflectionStrategy : public PacketResamplerStrategy {
|
||||
public:
|
||||
LegacyJitterWithReflectionStrategy(PacketResamplerCalculator* calculator)
|
||||
: PacketResamplerStrategy(calculator) {}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Close(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
void InitializeNextOutputTimestampWithJitter();
|
||||
void UpdateNextOutputTimestampWithJitter();
|
||||
|
||||
// Jitter-related variables.
|
||||
std::unique_ptr<RandomBase> random_;
|
||||
|
||||
// The timestamp of the first packet received.
|
||||
Timestamp first_timestamp_;
|
||||
|
||||
// Next packet to be emitted. Since packets may not align perfectly with
|
||||
// next_output_timestamp_, the closest packet will be emitted.
|
||||
Timestamp next_output_timestamp_;
|
||||
|
||||
// Lower bound for next timestamp.
|
||||
//
|
||||
// next_output_timestamp_ will be kept within the interval
|
||||
// [next_output_timestamp_min_, next_output_timestamp_min_ + frame_time_usec_)
|
||||
Timestamp next_output_timestamp_min_ = Timestamp::Unset();
|
||||
|
||||
// packet reservior used for sampling random packet out of partial
|
||||
// period when jitter is enabled
|
||||
std::unique_ptr<PacketReservoir> packet_reservoir_;
|
||||
|
||||
// random number generator used in packet_reservior_.
|
||||
std::unique_ptr<RandomBase> packet_reservoir_random_;
|
||||
};
|
||||
|
||||
// Strategy that applies reproducible jitter with reflection based sampling.
|
||||
//
|
||||
// Used by PacketResamplerCalculator when both Jitter and reflection are
|
||||
// enabled.
|
||||
class ReproducibleJitterWithReflectionStrategy
|
||||
: public PacketResamplerStrategy {
|
||||
public:
|
||||
ReproducibleJitterWithReflectionStrategy(
|
||||
PacketResamplerCalculator* calculator)
|
||||
: PacketResamplerStrategy(calculator) {}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Close(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
protected:
|
||||
// Returns next random in range (0,n].
|
||||
//
|
||||
// Exposed as virtual function for testing Jitter with reflection.
|
||||
// This is the only way random_ is accessed.
|
||||
virtual uint64 GetNextRandom(uint64 n) {
|
||||
return random_->UnbiasedUniform64(n);
|
||||
}
|
||||
|
||||
private:
|
||||
// Initializes Jitter with reflection.
|
||||
//
|
||||
// This will fast-forward to the period containing current_timestamp.
|
||||
// next_output_timestamp_ is guarnateed to be current_timestamp's period
|
||||
// and packet_emitted_this_period_ will be set to false.
|
||||
void InitializeNextOutputTimestamp(Timestamp current_timestamp);
|
||||
|
||||
// Potentially advances next_output_timestamp_ a single period.
|
||||
//
|
||||
// next_output_timestamp_ will only be advanced if packet_emitted_this_period_
|
||||
// is false. next_output_timestamp_ will never be advanced beyond
|
||||
// current_timestamp's period.
|
||||
//
|
||||
// However, next_output_timestamp_ could fall before current_timestamp's
|
||||
// period since only a single period can be advanced at a time.
|
||||
void UpdateNextOutputTimestamp(Timestamp current_timestamp);
|
||||
|
||||
// Jitter-related variables.
|
||||
std::unique_ptr<RandomBase> random_;
|
||||
|
||||
// Next packet to be emitted. Since packets may not align perfectly with
|
||||
// next_output_timestamp_, the closest packet will be emitted.
|
||||
Timestamp next_output_timestamp_;
|
||||
|
||||
// Lower bound for next timestamp.
|
||||
//
|
||||
// next_output_timestamp_ will be kept within the interval
|
||||
// [next_output_timestamp_min_, next_output_timestamp_min_ + frame_time_usec_)
|
||||
Timestamp next_output_timestamp_min_ = Timestamp::Unset();
|
||||
|
||||
// Indicates packet was emitted for current period (i.e. the period
|
||||
// next_output_timestamp_ falls in.
|
||||
bool packet_emitted_this_period_ = false;
|
||||
};
|
||||
|
||||
// Strategy that applies Jitter without reflection based sampling.
|
||||
//
|
||||
// Used by PacketResamplerCalculator when Jitter is enabled and reflection is
|
||||
// not enabled.
|
||||
class JitterWithoutReflectionStrategy : public PacketResamplerStrategy {
|
||||
public:
|
||||
JitterWithoutReflectionStrategy(PacketResamplerCalculator* calculator)
|
||||
: PacketResamplerStrategy(calculator) {}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Close(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
// Calculates the first sampled timestamp that incorporates a jittering
|
||||
// offset.
|
||||
void InitializeNextOutputTimestamp();
|
||||
|
||||
// Calculates the next sampled timestamp that incorporates a jittering offset.
|
||||
void UpdateNextOutputTimestamp();
|
||||
|
||||
// Jitter-related variables.
|
||||
std::unique_ptr<RandomBase> random_;
|
||||
|
||||
// Next packet to be emitted. Since packets may not align perfectly with
|
||||
// next_output_timestamp_, the closest packet will be emitted.
|
||||
Timestamp next_output_timestamp_;
|
||||
|
||||
// Lower bound for next timestamp.
|
||||
//
|
||||
// next_output_timestamp_ will be kept within the interval
|
||||
// [next_output_timestamp_min_, next_output_timestamp_min_ + frame_time_usec_)
|
||||
Timestamp next_output_timestamp_min_ = Timestamp::Unset();
|
||||
|
||||
// packet reservior used for sampling random packet out of partial period.
|
||||
std::unique_ptr<PacketReservoir> packet_reservoir_;
|
||||
|
||||
// random number generator used in packet_reservior_.
|
||||
std::unique_ptr<RandomBase> packet_reservoir_random_;
|
||||
};
|
||||
|
||||
// Strategy that applies sampling without any jitter.
|
||||
//
|
||||
// Used by PacketResamplerCalculator when jitter is not enabled.
|
||||
class NoJitterStrategy : public PacketResamplerStrategy {
|
||||
public:
|
||||
NoJitterStrategy(PacketResamplerCalculator* calculator)
|
||||
: PacketResamplerStrategy(calculator) {}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Close(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
// Number of periods that have passed (= #packets sent to the output).
|
||||
int64 period_count_;
|
||||
|
||||
// If specified, output timestamps are aligned with base_timestamp.
|
||||
// Otherwise, they are aligned with the first input timestamp.
|
||||
Timestamp base_timestamp_;
|
||||
};
|
||||
|
||||
} // namespace mediapipe
|
||||
#endif // MEDIAPIPE_CALCULATORS_CORE_PACKET_RESAMPLER_CALCULATOR_H_
|
||||
|
|
@ -1,113 +0,0 @@
|
|||
// Copyright 2018 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
option objc_class_prefix = "MediaPipe";
|
||||
|
||||
message PacketResamplerCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional PacketResamplerCalculatorOptions ext = 95743844;
|
||||
}
|
||||
|
||||
// The output frame rate measured in frames per second.
|
||||
//
|
||||
// The closest packet in time in each period will be chosen. If there
|
||||
// is no packet in the period then the most recent packet will be chosen
|
||||
// (not the closest in time).
|
||||
optional double frame_rate = 1 [default = -1.0];
|
||||
|
||||
enum OutputHeader {
|
||||
// Do not output a header, even if the input contained one.
|
||||
NONE = 0;
|
||||
// Pass the header, if the input contained one.
|
||||
PASS_HEADER = 1;
|
||||
// Update the frame rate in the header, which must be of type VideoHeader.
|
||||
UPDATE_VIDEO_HEADER = 2;
|
||||
}
|
||||
|
||||
// Whether and what kind of header to place on the output stream.
|
||||
// Note, this is about the actual header, not the VIDEO_HEADER stream.
|
||||
// If this option is set to UPDATE_VIDEO_HEADER then the header will
|
||||
// also be parsed (updated) and passed along to the VIDEO_HEADER stream.
|
||||
optional OutputHeader output_header = 2 [default = NONE];
|
||||
|
||||
// Flush last packet even if its timestamp is greater than the final stream
|
||||
// timestamp.
|
||||
optional bool flush_last_packet = 3 [default = true];
|
||||
|
||||
// Adds jitter to resampling if set, so that Google's sampling is not
|
||||
// externally deterministic.
|
||||
//
|
||||
// When set, the randomizer will be initialized with a seed. Then, the first
|
||||
// sample is chosen randomly (uniform distribution) among frames that
|
||||
// correspond to timestamps [0, 1/frame_rate). Let the chosen frame
|
||||
// correspond to timestamp t. The next frame is chosen randomly (uniform
|
||||
// distribution) among frames that correspond to [t+(1-jitter)/frame_rate,
|
||||
// t+(1+jitter)/frame_rate]. t is updated and the process is repeated.
|
||||
//
|
||||
// Valid values are in the range of [0.0, 1.0] with the default being 0.0 (no
|
||||
// jitter). A typical value would be a value in the range of 0.1-0.25.
|
||||
//
|
||||
// Note that this does NOT guarantee the desired frame rate, but if the
|
||||
// pseudo-random number generator does its job and the number of frames is
|
||||
// sufficiently large, the average frame rate will be close to this value.
|
||||
optional double jitter = 4;
|
||||
|
||||
// Enables reflection when applying jitter.
|
||||
//
|
||||
// This option is ignored when reproducible_sampling is true, in which case
|
||||
// reflection will be used.
|
||||
//
|
||||
// New use cases should use reproducible_sampling = true, as
|
||||
// jitter_with_reflection is deprecated and will be removed at some point.
|
||||
optional bool jitter_with_reflection = 9 [default = false];
|
||||
|
||||
// If set, enabled reproducible sampling, allowing frames to be sampled
|
||||
// without regards to where the stream starts. See
|
||||
// packet_resampler_calculator.h for details.
|
||||
//
|
||||
// This enables reflection (ignoring jitter_with_reflection setting).
|
||||
optional bool reproducible_sampling = 10 [default = false];
|
||||
|
||||
// If specified, output timestamps are aligned with base_timestamp.
|
||||
// Otherwise, they are aligned with the first input timestamp.
|
||||
//
|
||||
// In order to ensure that the outptut timestamps are reproducible,
|
||||
// with round_limits = false, the bounds for input timestamps must include:
|
||||
// [start_time - period / 2, end_time + period / 2],
|
||||
// with round_limits = true, the bounds for input timestamps must include:
|
||||
// [start_time - period, end_time + period],
|
||||
// where period = 1 / frame_rate.
|
||||
//
|
||||
// For example, in PacketResamplerCalculatorOptions specify
|
||||
// "start_time: 3000000", and in MediaDecoderOptions specify
|
||||
// "start_time: 2999950".
|
||||
optional int64 base_timestamp = 5;
|
||||
|
||||
// If specified, only outputs at/after start_time are included.
|
||||
optional int64 start_time = 6;
|
||||
|
||||
// If specified, only outputs before end_time are included.
|
||||
optional int64 end_time = 7;
|
||||
|
||||
// If set, the output timestamps nearest to start_time and end_time
|
||||
// are included in the output, even if the nearest timestamp is not
|
||||
// between start_time and end_time.
|
||||
optional bool round_limits = 8 [default = false];
|
||||
}
|
||||
|
|
@ -1,752 +0,0 @@
|
|||
// Copyright 2018 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/core/packet_resampler_calculator.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/calculators/core/packet_resampler_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/video_stream_header.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
namespace {
|
||||
// A simple version of CalculatorRunner with built-in convenience
|
||||
// methods for setting inputs from a vector and checking outputs
|
||||
// against expected outputs (both timestamps and contents).
|
||||
class SimpleRunner : public CalculatorRunner {
|
||||
public:
|
||||
explicit SimpleRunner(const std::string& options_string)
|
||||
: CalculatorRunner("PacketResamplerCalculator", options_string, 1, 1, 0) {
|
||||
}
|
||||
explicit SimpleRunner(const CalculatorGraphConfig::Node& node_config)
|
||||
: CalculatorRunner(node_config) {}
|
||||
|
||||
virtual ~SimpleRunner() {}
|
||||
|
||||
void SetInput(const std::vector<int64>& timestamp_list) {
|
||||
MutableInputs()->Index(0).packets.clear();
|
||||
for (const int64 ts : timestamp_list) {
|
||||
MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(new std::string(absl::StrCat("Frame #", ts)))
|
||||
.At(Timestamp(ts)));
|
||||
}
|
||||
}
|
||||
|
||||
void SetVideoHeader(const double frame_rate) {
|
||||
video_header_.width = static_count_;
|
||||
video_header_.height = static_count_ * 10;
|
||||
video_header_.frame_rate = frame_rate;
|
||||
video_header_.duration = static_count_ * 100.0;
|
||||
video_header_.format = static_cast<ImageFormat::Format>(
|
||||
static_count_ % ImageFormat::Format_ARRAYSIZE);
|
||||
MutableInputs()->Index(0).header = Adopt(new VideoHeader(video_header_));
|
||||
++static_count_;
|
||||
}
|
||||
|
||||
void CheckOutputTimestamps(
|
||||
const std::vector<int64>& expected_frames,
|
||||
const std::vector<int64>& expected_timestamps) const {
|
||||
EXPECT_EQ(expected_frames.size(), Outputs().Index(0).packets.size());
|
||||
EXPECT_EQ(expected_timestamps.size(), Outputs().Index(0).packets.size());
|
||||
int count = 0;
|
||||
for (const Packet& packet : Outputs().Index(0).packets) {
|
||||
EXPECT_EQ(Timestamp(expected_timestamps[count]), packet.Timestamp());
|
||||
const std::string& packet_contents = packet.Get<std::string>();
|
||||
EXPECT_EQ(std::string(absl::StrCat("Frame #", expected_frames[count])),
|
||||
packet_contents);
|
||||
++count;
|
||||
}
|
||||
}
|
||||
|
||||
void CheckVideoHeader(const double expected_frame_rate) const {
|
||||
ASSERT_FALSE(Outputs().Index(0).header.IsEmpty());
|
||||
const VideoHeader& header = Outputs().Index(0).header.Get<VideoHeader>();
|
||||
const double frame_rate = header.frame_rate;
|
||||
|
||||
EXPECT_EQ(video_header_.width, header.width);
|
||||
EXPECT_EQ(video_header_.height, header.height);
|
||||
EXPECT_DOUBLE_EQ(expected_frame_rate, frame_rate);
|
||||
EXPECT_FLOAT_EQ(video_header_.duration, header.duration);
|
||||
EXPECT_EQ(video_header_.format, header.format);
|
||||
}
|
||||
|
||||
private:
|
||||
VideoHeader video_header_;
|
||||
static int static_count_;
|
||||
};
|
||||
|
||||
// Matcher for Packets with uint64 payload, comparing arg packet's
|
||||
// timestamp and uint64 payload.
|
||||
MATCHER_P2(PacketAtTimestamp, payload, timestamp,
|
||||
absl::StrCat(negation ? "isn't" : "is", " a packet with payload ",
|
||||
payload, " @ time ", timestamp)) {
|
||||
if (timestamp != arg.Timestamp().Value()) {
|
||||
*result_listener << "at incorrect timestamp = " << arg.Timestamp().Value();
|
||||
return false;
|
||||
}
|
||||
int64 actual_payload = arg.template Get<int64>();
|
||||
if (actual_payload != payload) {
|
||||
*result_listener << "with incorrect payload = " << actual_payload;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// JitterWithReflectionStrategy child class which injects a specified stream
|
||||
// of "random" numbers.
|
||||
//
|
||||
// Calculators are created through factory methods, making testing and injection
|
||||
// tricky. This class utilizes a static variable, random_sequence, to pass
|
||||
// the desired random sequence into the calculator.
|
||||
class ReproducibleJitterWithReflectionStrategyForTesting
|
||||
: public ReproducibleJitterWithReflectionStrategy {
|
||||
public:
|
||||
ReproducibleJitterWithReflectionStrategyForTesting(
|
||||
PacketResamplerCalculator* calculator)
|
||||
: ReproducibleJitterWithReflectionStrategy(calculator) {}
|
||||
|
||||
// Statically accessed random sequence to use for jitter with reflection.
|
||||
//
|
||||
// An EXPECT will fail if sequence is less than the number requested during
|
||||
// processing.
|
||||
static std::vector<uint64> random_sequence;
|
||||
|
||||
protected:
|
||||
virtual uint64 GetNextRandom(uint64 n) {
|
||||
EXPECT_LT(sequence_index_, random_sequence.size());
|
||||
return random_sequence[sequence_index_++] % n;
|
||||
}
|
||||
|
||||
private:
|
||||
int32 sequence_index_ = 0;
|
||||
};
|
||||
std::vector<uint64>
|
||||
ReproducibleJitterWithReflectionStrategyForTesting::random_sequence;
|
||||
|
||||
// PacketResamplerCalculator child class which injects a specified stream
|
||||
// of "random" numbers.
|
||||
//
|
||||
// Calculators are created through factory methods, making testing and injection
|
||||
// tricky. This class utilizes a static variable, random_sequence, to pass
|
||||
// the desired random sequence into the calculator.
|
||||
class ReproducibleResamplerCalculatorForTesting
|
||||
: public PacketResamplerCalculator {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
return PacketResamplerCalculator::GetContract(cc);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::unique_ptr<class PacketResamplerStrategy> GetSamplingStrategy(
|
||||
const mediapipe::PacketResamplerCalculatorOptions& Options) {
|
||||
return absl::make_unique<
|
||||
ReproducibleJitterWithReflectionStrategyForTesting>(this);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(ReproducibleResamplerCalculatorForTesting);
|
||||
|
||||
int SimpleRunner::static_count_ = 0;
|
||||
|
||||
TEST(PacketResamplerCalculatorTest, NoPacketsInStream) {
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30}");
|
||||
runner.SetInput({});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PacketResamplerCalculatorTest, SinglePacketInStream) {
|
||||
// Stream with 1 packet / 1 period.
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30}");
|
||||
runner.SetInput({0});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({0}, {0});
|
||||
}
|
||||
|
||||
// Stream with 1 packet / 1 period (0 < packet timestamp < first limit).
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30}");
|
||||
runner.SetInput({1000});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({1000}, {1000});
|
||||
}
|
||||
|
||||
// Stream with 1 packet / 1 period (packet timestamp > first limit).
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30}");
|
||||
runner.SetInput({16668});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({16668}, {16668});
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PacketResamplerCalculatorTest, TwoPacketsInStream) {
|
||||
// Stream with 2 packets / 1 period.
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30}");
|
||||
runner.SetInput({0, 16666});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({0}, {0});
|
||||
}
|
||||
|
||||
// Stream with 2 packets / 2 periods (left extreme for second period).
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30}");
|
||||
runner.SetInput({0, 16667});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({0, 16667}, {0, 33333});
|
||||
}
|
||||
|
||||
// Stream with 2 packets / 2 periods (right extreme for second period).
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30}");
|
||||
runner.SetInput({0, 49999});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({0, 49999}, {0, 33333});
|
||||
}
|
||||
|
||||
// Stream with 2 packets / 3 periods (filling 1 in the middle).
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30}");
|
||||
runner.SetInput({0, 50000});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({0, 0, 50000}, {0, 33333, 66667});
|
||||
}
|
||||
|
||||
// Stream with 2 packets / 4 periods (filling 2 in the middle).
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30}");
|
||||
runner.SetInput({2000, 118666});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({2000, 2000, 2000, 118666},
|
||||
{2000, 35333, 68667, 102000});
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PacketResamplerCalculatorTest, InputAtExactFrequencyMiddlepoints) {
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30}");
|
||||
runner.SetInput({0, 33333, 66667, 100000, 133333, 166667, 200000});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps(
|
||||
{0, 33333, 66667, 100000, 133333, 166667, 200000},
|
||||
{0, 33333, 66667, 100000, 133333, 166667, 200000});
|
||||
}
|
||||
|
||||
// When there are several candidates for a period, the one closer to the center
|
||||
// should be sent to the output.
|
||||
TEST(PacketResamplerCalculatorTest, MultiplePacketsForPeriods) {
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30}");
|
||||
runner.SetInput({0, 16666, 16667, 20000, 33300, 49999, 50000, 66600});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({0, 33300, 66600}, {0, 33333, 66667});
|
||||
}
|
||||
|
||||
// When a period must be filled, we use the latest packet received (not
|
||||
// necessarily the same as the one stored for the best in the previous period).
|
||||
TEST(PacketResamplerCalculatorTest, FillPeriodsWithLatestPacket) {
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30}");
|
||||
runner.SetInput({0, 5000, 16666, 83334});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({0, 16666, 16666, 83334},
|
||||
{0, 33333, 66667, 100000});
|
||||
}
|
||||
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30}");
|
||||
runner.SetInput({0, 16666, 16667, 25000, 33000, 35000, 135000});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({0, 33000, 35000, 35000, 135000},
|
||||
{0, 33333, 66667, 100000, 133333});
|
||||
}
|
||||
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30}");
|
||||
runner.SetInput({0, 15000, 32000, 49999, 150000});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({0, 32000, 49999, 49999, 49999, 150000},
|
||||
{0, 33333, 66667, 100000, 133333, 166667});
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PacketResamplerCalculatorTest, SuperHighFrameRate) {
|
||||
// frame rate == 500000 (a packet will have to be sent every 2 ticks).
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:500000}");
|
||||
runner.SetInput({0, 10, 13});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({0, 0, 0, 0, 0, 10, 10, 13},
|
||||
{0, 2, 4, 6, 8, 10, 12, 14});
|
||||
}
|
||||
|
||||
// frame rate == 1000000 (a packet will have to be sent in each tick).
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:1000000}");
|
||||
runner.SetInput({0, 10, 13});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps(
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 10, 10, 13},
|
||||
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13});
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PacketResamplerCalculatorTest, NegativeTimestampTest) {
|
||||
// Stream with negative timestamps / 1 period.
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30}");
|
||||
runner.SetInput({-200, -20, 16466});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({-200}, {-200});
|
||||
}
|
||||
|
||||
// Stream with negative timestamps / 2 periods.
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30}");
|
||||
runner.SetInput({-200, -20, 16467});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({-200, 16467}, {-200, 33133});
|
||||
}
|
||||
|
||||
// Stream with negative timestamps and filling an empty period.
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30}");
|
||||
runner.SetInput({-500, 66667});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({-500, -500, 66667}, {-500, 32833, 66167});
|
||||
}
|
||||
|
||||
// Stream with negative timestamps and initial packet < -period.
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30}");
|
||||
runner.SetInput({-50000, -33334, 33334});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({-50000, -33334, -33334, 33334},
|
||||
{-50000, -16667, 16667, 50000});
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PacketResamplerCalculatorTest, ExactFramesPerSecond) {
|
||||
// Using frame_rate=50, that makes a period of 20000 microsends (exact).
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:50}");
|
||||
runner.SetInput({0, 9999, 29999});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({0, 29999}, {0, 20000});
|
||||
}
|
||||
|
||||
// Test filling empty periods.
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:50}");
|
||||
runner.SetInput({0, 10000, 50000});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({0, 10000, 10000, 50000},
|
||||
{0, 20000, 40000, 60000});
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PacketResamplerCalculatorTest, FrameRateTest) {
|
||||
// Test changing Frame Rate to the same initial value.
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:50, output_header:UPDATE_VIDEO_HEADER}");
|
||||
runner.SetInput({0, 10000, 30000, 50000, 60000});
|
||||
runner.SetVideoHeader(50.0);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({0, 10000, 30000, 60000},
|
||||
{0, 20000, 40000, 60000});
|
||||
runner.CheckVideoHeader(50.0);
|
||||
}
|
||||
|
||||
// Test changing Frame Rate to new value.
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:50, output_header:UPDATE_VIDEO_HEADER}");
|
||||
runner.SetInput({0, 5000, 10010, 15001, 19990});
|
||||
runner.SetVideoHeader(200.0);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({0, 19990}, {0, 20000});
|
||||
runner.CheckVideoHeader(50.0);
|
||||
}
|
||||
|
||||
// Test that the frame rate is not changing if update_video_header = false.
|
||||
{
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:50, output_header:PASS_HEADER}");
|
||||
runner.SetInput({0, 5000, 10010, 15001, 19990});
|
||||
runner.SetVideoHeader(200.0);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({0, 19990}, {0, 20000});
|
||||
runner.CheckVideoHeader(200.0);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PacketResamplerCalculatorTest, SetVideoHeader) {
|
||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "PacketResamplerCalculator"
|
||||
input_stream: "DATA:in_data"
|
||||
input_stream: "VIDEO_HEADER:in_video_header"
|
||||
output_stream: "DATA:out_data"
|
||||
output_stream: "VIDEO_HEADER:out_video_header"
|
||||
options {
|
||||
[mediapipe.PacketResamplerCalculatorOptions.ext] { frame_rate: 50.0 }
|
||||
}
|
||||
)pb"));
|
||||
|
||||
for (const int64 ts : {0, 5000, 10010, 15001, 19990}) {
|
||||
runner.MutableInputs()->Tag("DATA").packets.push_back(
|
||||
Adopt(new std::string(absl::StrCat("Frame #", ts))).At(Timestamp(ts)));
|
||||
}
|
||||
VideoHeader video_header_in;
|
||||
video_header_in.width = 10;
|
||||
video_header_in.height = 100;
|
||||
video_header_in.frame_rate = 1.0;
|
||||
video_header_in.duration = 1.0;
|
||||
video_header_in.format = ImageFormat::SRGB;
|
||||
runner.MutableInputs()
|
||||
->Tag("VIDEO_HEADER")
|
||||
.packets.push_back(
|
||||
Adopt(new VideoHeader(video_header_in)).At(Timestamp::PreStream()));
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
ASSERT_EQ(1, runner.Outputs().Tag("VIDEO_HEADER").packets.size());
|
||||
EXPECT_EQ(Timestamp::PreStream(),
|
||||
runner.Outputs().Tag("VIDEO_HEADER").packets[0].Timestamp());
|
||||
const VideoHeader& video_header_out =
|
||||
runner.Outputs().Tag("VIDEO_HEADER").packets[0].Get<VideoHeader>();
|
||||
EXPECT_EQ(video_header_in.width, video_header_out.width);
|
||||
EXPECT_EQ(video_header_in.height, video_header_out.height);
|
||||
EXPECT_DOUBLE_EQ(50.0, video_header_out.frame_rate);
|
||||
EXPECT_FLOAT_EQ(video_header_in.duration, video_header_out.duration);
|
||||
EXPECT_EQ(video_header_in.format, video_header_out.format);
|
||||
}
|
||||
|
||||
TEST(PacketResamplerCalculatorTest, FlushLastPacketWithoutRound) {
|
||||
SimpleRunner runner(R"(
|
||||
[mediapipe.PacketResamplerCalculatorOptions.ext] {
|
||||
frame_rate: 1
|
||||
})");
|
||||
runner.SetInput({0, 333333, 666667, 1000000, 1333333});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
// 1333333 is not emitted as 2000000, because it does not round to 2000000.
|
||||
runner.CheckOutputTimestamps({0, 1000000}, {0, 1000000});
|
||||
}
|
||||
|
||||
TEST(PacketResamplerCalculatorTest, FlushLastPacketWithRound) {
|
||||
SimpleRunner runner(R"(
|
||||
[mediapipe.PacketResamplerCalculatorOptions.ext] {
|
||||
frame_rate: 1
|
||||
})");
|
||||
runner.SetInput({0, 333333, 666667, 1000000, 1333333, 1666667});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
// 1666667 is emitted as 2000000, because it rounds to 2000000.
|
||||
runner.CheckOutputTimestamps({0, 1000000, 1666667}, {0, 1000000, 2000000});
|
||||
}
|
||||
|
||||
TEST(PacketResamplerCalculatorTest, DoNotFlushLastPacketWithoutRound) {
|
||||
SimpleRunner runner(R"(
|
||||
[mediapipe.PacketResamplerCalculatorOptions.ext] {
|
||||
frame_rate: 1
|
||||
flush_last_packet: false
|
||||
})");
|
||||
runner.SetInput({0, 333333, 666667, 1000000, 1333333});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
// 1333333 is not emitted no matter what; see FlushLastPacketWithoutRound.
|
||||
runner.CheckOutputTimestamps({0, 1000000}, {0, 1000000});
|
||||
}
|
||||
|
||||
TEST(PacketResamplerCalculatorTest, DoNotFlushLastPacketWithRound) {
|
||||
SimpleRunner runner(R"(
|
||||
[mediapipe.PacketResamplerCalculatorOptions.ext] {
|
||||
frame_rate: 1
|
||||
flush_last_packet: false
|
||||
})");
|
||||
runner.SetInput({0, 333333, 666667, 1000000, 1333333, 1666667});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
// 1666667 is not emitted due to flush_last_packet: false.
|
||||
runner.CheckOutputTimestamps({0, 1000000}, {0, 1000000});
|
||||
}
|
||||
|
||||
// When base_timestamp is specified, output timestamps are aligned with it.
|
||||
TEST(PacketResamplerCalculatorTest, InputAtExactFrequencyMiddlepointsAligned) {
|
||||
{
|
||||
// Without base_timestamp, outputs are aligned with the first input
|
||||
// timestamp, (33333 - 222).
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30}");
|
||||
runner.SetInput({33111, 66667, 100000, 133333, 166667, 200000});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({33111, 66667, 100000, 133333, 166667, 200000},
|
||||
{33111, 66444, 99778, 133111, 166444, 199778});
|
||||
}
|
||||
{
|
||||
// With base_timestamp, outputs are aligned with base_timestamp, 0.
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30 "
|
||||
"base_timestamp:0}");
|
||||
runner.SetInput({33111, 66667, 100000, 133333, 166667, 200000});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps(
|
||||
{33111, 66667, 100000, 133333, 166667, 200000},
|
||||
{33333, 66666, 100000, 133333, 166666, 200000});
|
||||
}
|
||||
}
|
||||
|
||||
// When base_timestamp is specified, output timestamps are aligned with it.
|
||||
TEST(PacketResamplerCalculatorTest, MultiplePacketsForPeriodsAligned) {
|
||||
{
|
||||
// Without base_timestamp, outputs are aligned with the first input, -222.
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30}");
|
||||
runner.SetInput({-222, 16666, 16667, 20000, 33300, 49999, 50000, 66600});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({-222, 33300, 66600}, {-222, 33111, 66445});
|
||||
}
|
||||
{
|
||||
// With base_timestamp, outputs are aligned with base_timestamp, 900011.
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30 "
|
||||
"base_timestamp:900011}");
|
||||
runner.SetInput({-222, 16666, 16667, 20000, 33300, 49999, 50000, 66600});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({-222, 33300, 66600}, {11, 33344, 66678});
|
||||
}
|
||||
{
|
||||
// With base_timestamp, outputs still approximate input timestamps,
|
||||
// while aligned to base_timestamp, 11.
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30 "
|
||||
"base_timestamp:11}");
|
||||
runner.SetInput(
|
||||
{899888, 916666, 916667, 920000, 933300, 949999, 950000, 966600});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({899888, 933300, 966600},
|
||||
{900011, 933344, 966678});
|
||||
}
|
||||
}
|
||||
|
||||
// When a period must be filled, we use the latest packet received.
|
||||
// When base_timestamp is specified, output timestamps are aligned with it.
|
||||
TEST(PacketResamplerCalculatorTest, FillPeriodsWithLatestPacketAligned) {
|
||||
{
|
||||
// Without base_timestamp, outputs are aligned with the first input, -222.
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30}");
|
||||
runner.SetInput({-222, 15000, 32000, 49999, 150000});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({-222, 32000, 49999, 49999, 49999, 150000},
|
||||
{-222, 33111, 66445, 99778, 133111, 166445});
|
||||
}
|
||||
{
|
||||
// With base_timestamp, outputs are aligned with base_timestamp, 0.
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30 "
|
||||
"base_timestamp:0}");
|
||||
runner.SetInput({-222, 15000, 32000, 49999, 150000});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({-222, 32000, 49999, 49999, 49999, 150000},
|
||||
{0, 33333, 66667, 100000, 133333, 166667});
|
||||
}
|
||||
}
|
||||
|
||||
// When base_timestamp is specified, output timestamps are aligned with it.
|
||||
// The first packet is included, because we assume that the input includes the
|
||||
// whole first sampling interval.
|
||||
TEST(PacketResamplerCalculatorTest, FirstInputAfterMiddlepointAligned) {
|
||||
{
|
||||
// Packet 100020 is omitted from the output sequence because
|
||||
// packet 99990 is closer to the period midpoint.
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30 "
|
||||
"base_timestamp:0}");
|
||||
runner.SetInput({66667, 100020, 133333, 166667});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({66667, 100020, 133333, 166667},
|
||||
{66667, 100000, 133334, 166667});
|
||||
}
|
||||
{
|
||||
// If we seek to packet 100020, packet 100020 is included in
|
||||
// the output sequence, because we assume that the input includes the
|
||||
// whole first sampling interval.
|
||||
//
|
||||
// We assume that the input includes whole sampling intervals
|
||||
// in order to produce "reproducible timestamps", which are timestamps
|
||||
// from the series of timestamps starting at 0.
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30 "
|
||||
"base_timestamp:0}");
|
||||
runner.SetInput({100020, 133333, 166667});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({100020, 133333, 166667},
|
||||
{100000, 133333, 166667});
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PacketResamplerCalculatorTest, OutputTimestampRangeAligned) {
|
||||
{
|
||||
// With base_timestamp, outputs are aligned with base_timestamp, 0.
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30 "
|
||||
"base_timestamp:0}");
|
||||
runner.SetInput({-222, 15000, 32000, 49999, 150000});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({-222, 32000, 49999, 49999, 49999, 150000},
|
||||
{0, 33333, 66667, 100000, 133333, 166667});
|
||||
}
|
||||
{
|
||||
// With start_time, end_time, outputs are filtered.
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30 "
|
||||
"base_timestamp:0 "
|
||||
"start_time:40000 "
|
||||
"end_time:160000}");
|
||||
runner.SetInput({-222, 15000, 32000, 49999, 150000});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({49999, 49999, 49999},
|
||||
{66667, 100000, 133333});
|
||||
}
|
||||
{
|
||||
// With start_time, end_time, round_limits, outputs are filtered,
|
||||
// rounding to the nearest limit.
|
||||
SimpleRunner runner(
|
||||
"[mediapipe.PacketResamplerCalculatorOptions.ext]: "
|
||||
"{frame_rate:30 "
|
||||
"base_timestamp:0 "
|
||||
"start_time:40000 "
|
||||
"end_time:160000 "
|
||||
"round_limits:true}");
|
||||
runner.SetInput({-222, 15000, 32000, 49999, 150000});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
runner.CheckOutputTimestamps({32000, 49999, 49999, 49999, 150000},
|
||||
{33333, 66667, 100000, 133333, 166667});
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PacketResamplerCalculatorTest, OptionsSidePacket) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "PacketResamplerCalculator"
|
||||
input_side_packet: "OPTIONS:options"
|
||||
input_stream: "input"
|
||||
output_stream: "output"
|
||||
options {
|
||||
[mediapipe.PacketResamplerCalculatorOptions.ext] {
|
||||
frame_rate: 60
|
||||
base_timestamp: 0
|
||||
}
|
||||
})pb");
|
||||
|
||||
{
|
||||
SimpleRunner runner(node_config);
|
||||
auto options =
|
||||
new CalculatorOptions(ParseTextProtoOrDie<CalculatorOptions>(
|
||||
R"pb(
|
||||
[mediapipe.PacketResamplerCalculatorOptions.ext] {
|
||||
frame_rate: 30
|
||||
})pb"));
|
||||
runner.MutableSidePackets()->Tag("OPTIONS") = Adopt(options);
|
||||
runner.SetInput({-222, 15000, 32000, 49999, 150000});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
EXPECT_EQ(6, runner.Outputs().Index(0).packets.size());
|
||||
}
|
||||
{
|
||||
SimpleRunner runner(node_config);
|
||||
|
||||
auto options =
|
||||
new CalculatorOptions(ParseTextProtoOrDie<CalculatorOptions>(R"pb(
|
||||
merge_fields: false
|
||||
[mediapipe.PacketResamplerCalculatorOptions.ext] {
|
||||
frame_rate: 30
|
||||
base_timestamp: 0
|
||||
})pb"));
|
||||
runner.MutableSidePackets()->Tag("OPTIONS") = Adopt(options);
|
||||
|
||||
runner.SetInput({-222, 15000, 32000, 49999, 150000});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
EXPECT_EQ(6, runner.Outputs().Index(0).packets.size());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,317 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Declaration of PacketThinnerCalculator.
|
||||
|
||||
#include <cmath> // for ceil
|
||||
#include <memory>
|
||||
|
||||
#include "mediapipe/calculators/core/packet_thinner_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_context.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/video_stream_header.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/tool/options_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
const double kTimebaseUs = 1000000; // Microseconds.
|
||||
const char* const kOptionsTag = "OPTIONS";
|
||||
const char* const kPeriodTag = "PERIOD";
|
||||
} // namespace
|
||||
|
||||
// This calculator is used to thin an input stream of Packets.
|
||||
// An example application would be to sample decoded frames of video
|
||||
// at a coarser temporal resolution. Unless otherwise stated, all
|
||||
// timestamps are in units of microseconds.
|
||||
//
|
||||
// Thinning can be accomplished in one of two ways:
|
||||
// 1) asynchronous thinning (known below as async):
|
||||
// Algorithm does not rely on a master clock and is parameterized only
|
||||
// by a single option -- the period. Once a packet is emitted, the
|
||||
// thinner will discard subsequent packets for the duration of the period
|
||||
// [Analogous to a refractory period during which packet emission is
|
||||
// suppressed.]
|
||||
// Packets arriving before start_time are discarded, as are packets
|
||||
// arriving at or after end_time.
|
||||
// 2) synchronous thinning (known below as sync):
|
||||
// There are two variants of this algorithm, both parameterized by a
|
||||
// start_time and a period. As in (1), packets arriving before start_time
|
||||
// or at/after end_time are discarded. Otherwise, at most one packet is
|
||||
// emitted during a period, centered at timestamps generated by the
|
||||
// expression:
|
||||
// start_time + i * period [where i is a non-negative integer]
|
||||
// During each period, the packet closest to the generated timestamp is
|
||||
// emitted (latest in the case of ties). In the first variant
|
||||
// (sync_output_timestamps = true), the emitted packet is output at the
|
||||
// generated timestamp. In the second variant, the packet is output at
|
||||
// its original timestamp. Both variants emit exactly the same packets,
|
||||
// but at different timestamps.
|
||||
//
|
||||
// Thinning period can be provided in the calculator options or via a
|
||||
// side packet with the tag "PERIOD".
|
||||
//
|
||||
// Calculator options provided optionally with the "OPTIONS" input
|
||||
// sidepacket tag will be merged with this calculator's node options, i.e.,
|
||||
// singular fields of the side packet will overwrite the options defined in the
|
||||
// node, and repeated fields will concatenate.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "PacketThinnerCalculator"
|
||||
// input_side_packet: "OPTIONS:calculator_options"
|
||||
// input_stream: "signal"
|
||||
// output_stream: "output"
|
||||
// options {
|
||||
// [mediapipe.PacketThinnerCalculatorOptions.ext] {
|
||||
// thinner_type: SYNC
|
||||
// period: 10
|
||||
// sync_output_timestamps: true
|
||||
// update_frame_rate: false
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class PacketThinnerCalculator : public CalculatorBase {
|
||||
public:
|
||||
PacketThinnerCalculator() {}
|
||||
~PacketThinnerCalculator() override {}
|
||||
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
if (cc->InputSidePackets().HasTag(kOptionsTag)) {
|
||||
cc->InputSidePackets().Tag(kOptionsTag).Set<CalculatorOptions>();
|
||||
}
|
||||
cc->Inputs().Index(0).SetAny();
|
||||
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
||||
if (cc->InputSidePackets().HasTag(kPeriodTag)) {
|
||||
cc->InputSidePackets().Tag(kPeriodTag).Set<int64>();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Close(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
if (cc->InputTimestamp() < start_time_) {
|
||||
return absl::OkStatus(); // Drop packets before start_time_.
|
||||
} else if (cc->InputTimestamp() >= end_time_) {
|
||||
if (!cc->Outputs().Index(0).IsClosed()) {
|
||||
cc->Outputs()
|
||||
.Index(0)
|
||||
.Close(); // No more Packets will be output after end_time_.
|
||||
}
|
||||
return absl::OkStatus();
|
||||
} else {
|
||||
return thinner_type_ == PacketThinnerCalculatorOptions::ASYNC
|
||||
? AsyncThinnerProcess(cc)
|
||||
: SyncThinnerProcess(cc);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// Implementation of ASYNC and SYNC versions of thinner algorithm.
|
||||
absl::Status AsyncThinnerProcess(CalculatorContext* cc);
|
||||
absl::Status SyncThinnerProcess(CalculatorContext* cc);
|
||||
|
||||
// Cached option.
|
||||
PacketThinnerCalculatorOptions::ThinnerType thinner_type_;
|
||||
|
||||
// Given a Timestamp, finds the closest sync Timestamp
|
||||
// based on start_time_ and period_. This can be earlier or
|
||||
// later than given Timestamp, but is guaranteed to be within
|
||||
// half a period_.
|
||||
Timestamp NearestSyncTimestamp(Timestamp now) const;
|
||||
|
||||
// Cached option used by both async and sync thinners.
|
||||
TimestampDiff period_; // Interval during which only one packet is emitted.
|
||||
Timestamp start_time_; // Cached option - default Timestamp::Min()
|
||||
Timestamp end_time_; // Cached option - default Timestamp::Max()
|
||||
|
||||
// Only used by async thinner:
|
||||
Timestamp next_valid_timestamp_; // Suppress packets until this timestamp.
|
||||
|
||||
// Only used by sync thinner:
|
||||
Packet saved_packet_; // Best packet not yet emitted.
|
||||
bool sync_output_timestamps_; // Cached option.
|
||||
};
|
||||
REGISTER_CALCULATOR(PacketThinnerCalculator);
|
||||
|
||||
namespace {
|
||||
TimestampDiff abs(TimestampDiff t) { return t < 0 ? -t : t; }
|
||||
} // namespace
|
||||
|
||||
absl::Status PacketThinnerCalculator::Open(CalculatorContext* cc) {
|
||||
PacketThinnerCalculatorOptions options = mediapipe::tool::RetrieveOptions(
|
||||
cc->Options<PacketThinnerCalculatorOptions>(), cc->InputSidePackets(),
|
||||
kOptionsTag);
|
||||
|
||||
thinner_type_ = options.thinner_type();
|
||||
// This check enables us to assume only two thinner types exist in Process()
|
||||
CHECK(thinner_type_ == PacketThinnerCalculatorOptions::ASYNC ||
|
||||
thinner_type_ == PacketThinnerCalculatorOptions::SYNC)
|
||||
<< "Unsupported thinner type.";
|
||||
|
||||
if (thinner_type_ == PacketThinnerCalculatorOptions::ASYNC) {
|
||||
// ASYNC thinner outputs packets with the same timestamp as their input so
|
||||
// its safe to SetOffset(0). SYNC thinner manipulates timestamps of its
|
||||
// output so we don't do this for that case.
|
||||
cc->SetOffset(0);
|
||||
}
|
||||
|
||||
if (cc->InputSidePackets().HasTag(kPeriodTag)) {
|
||||
period_ =
|
||||
TimestampDiff(cc->InputSidePackets().Tag(kPeriodTag).Get<int64>());
|
||||
} else {
|
||||
period_ = TimestampDiff(options.period());
|
||||
}
|
||||
CHECK_LT(TimestampDiff(0), period_) << "Specified period must be positive.";
|
||||
|
||||
if (options.has_start_time()) {
|
||||
start_time_ = Timestamp(options.start_time());
|
||||
} else if (thinner_type_ == PacketThinnerCalculatorOptions::ASYNC) {
|
||||
start_time_ = Timestamp::Min();
|
||||
} else {
|
||||
start_time_ = Timestamp(0);
|
||||
}
|
||||
|
||||
end_time_ =
|
||||
options.has_end_time() ? Timestamp(options.end_time()) : Timestamp::Max();
|
||||
CHECK_LT(start_time_, end_time_)
|
||||
<< "Invalid PacketThinner: start_time must be earlier than end_time";
|
||||
|
||||
sync_output_timestamps_ = options.sync_output_timestamps();
|
||||
|
||||
next_valid_timestamp_ = start_time_;
|
||||
// Drop packets until this time.
|
||||
cc->Outputs().Index(0).SetNextTimestampBound(start_time_);
|
||||
|
||||
if (!cc->Inputs().Index(0).Header().IsEmpty()) {
|
||||
if (options.update_frame_rate()) {
|
||||
const VideoHeader& video_header =
|
||||
cc->Inputs().Index(0).Header().Get<VideoHeader>();
|
||||
double new_frame_rate;
|
||||
if (thinner_type_ == PacketThinnerCalculatorOptions::ASYNC) {
|
||||
new_frame_rate =
|
||||
video_header.frame_rate /
|
||||
ceil(video_header.frame_rate * options.period() / kTimebaseUs);
|
||||
} else {
|
||||
const double sampling_rate = kTimebaseUs / options.period();
|
||||
new_frame_rate = video_header.frame_rate < sampling_rate
|
||||
? video_header.frame_rate
|
||||
: sampling_rate;
|
||||
}
|
||||
std::unique_ptr<VideoHeader> header(new VideoHeader);
|
||||
header->format = video_header.format;
|
||||
header->width = video_header.width;
|
||||
header->height = video_header.height;
|
||||
header->frame_rate = new_frame_rate;
|
||||
cc->Outputs().Index(0).SetHeader(Adopt(header.release()));
|
||||
} else {
|
||||
cc->Outputs().Index(0).SetHeader(cc->Inputs().Index(0).Header());
|
||||
}
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status PacketThinnerCalculator::Close(CalculatorContext* cc) {
|
||||
// Emit any saved packets before quitting.
|
||||
if (!saved_packet_.IsEmpty()) {
|
||||
// Only sync thinner should have saved packets.
|
||||
CHECK_EQ(PacketThinnerCalculatorOptions::SYNC, thinner_type_);
|
||||
if (sync_output_timestamps_) {
|
||||
cc->Outputs().Index(0).AddPacket(
|
||||
saved_packet_.At(NearestSyncTimestamp(saved_packet_.Timestamp())));
|
||||
} else {
|
||||
cc->Outputs().Index(0).AddPacket(saved_packet_);
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status PacketThinnerCalculator::AsyncThinnerProcess(
|
||||
CalculatorContext* cc) {
|
||||
if (cc->InputTimestamp() >= next_valid_timestamp_) {
|
||||
cc->Outputs().Index(0).AddPacket(
|
||||
cc->Inputs().Index(0).Value()); // Emit current packet.
|
||||
next_valid_timestamp_ = cc->InputTimestamp() + period_;
|
||||
// Guaranteed not to emit packets seen during refractory period.
|
||||
cc->Outputs().Index(0).SetNextTimestampBound(next_valid_timestamp_);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status PacketThinnerCalculator::SyncThinnerProcess(
|
||||
CalculatorContext* cc) {
|
||||
if (saved_packet_.IsEmpty()) {
|
||||
// If no packet has been saved, store the current packet.
|
||||
saved_packet_ = cc->Inputs().Index(0).Value();
|
||||
cc->Outputs().Index(0).SetNextTimestampBound(
|
||||
sync_output_timestamps_ ? NearestSyncTimestamp(cc->InputTimestamp())
|
||||
: cc->InputTimestamp());
|
||||
} else {
|
||||
// Saved packet exists -- update or emit.
|
||||
const Timestamp saved = saved_packet_.Timestamp();
|
||||
const Timestamp saved_sync = NearestSyncTimestamp(saved);
|
||||
const Timestamp now = cc->InputTimestamp();
|
||||
const Timestamp now_sync = NearestSyncTimestamp(now);
|
||||
CHECK_LE(saved_sync, now_sync);
|
||||
if (saved_sync == now_sync) {
|
||||
// Saved Packet is in same interval as current packet.
|
||||
// Replace saved packet with current if it is at least as
|
||||
// central as the saved packet wrt temporal interval.
|
||||
// [We break ties in favor of fresher packets]
|
||||
if (abs(now - now_sync) <= abs(saved - saved_sync)) {
|
||||
saved_packet_ = cc->Inputs().Index(0).Value();
|
||||
}
|
||||
} else {
|
||||
// Saved packet is the best packet from earlier interval: emit!
|
||||
if (sync_output_timestamps_) {
|
||||
cc->Outputs().Index(0).AddPacket(saved_packet_.At(saved_sync));
|
||||
cc->Outputs().Index(0).SetNextTimestampBound(now_sync);
|
||||
} else {
|
||||
cc->Outputs().Index(0).AddPacket(saved_packet_);
|
||||
cc->Outputs().Index(0).SetNextTimestampBound(now);
|
||||
}
|
||||
// Current packet is the first one we've seen from new interval -- save!
|
||||
saved_packet_ = cc->Inputs().Index(0).Value();
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
Timestamp PacketThinnerCalculator::NearestSyncTimestamp(Timestamp now) const {
|
||||
CHECK_NE(start_time_, Timestamp::Unset())
|
||||
<< "Method only valid for sync thinner calculator.";
|
||||
|
||||
// Computation is done using int64 arithmetic. No easy way to avoid
|
||||
// since Timestamps don't support div and multiply.
|
||||
const int64 now64 = now.Value();
|
||||
const int64 start64 = start_time_.Value();
|
||||
const int64 period64 = period_.Value();
|
||||
CHECK_LE(0, period64);
|
||||
|
||||
// Round now64 to its closest interval (units of period64).
|
||||
int64 sync64 =
|
||||
(now64 - start64 + period64 / 2) / period64 * period64 + start64;
|
||||
CHECK_LE(abs(now64 - sync64), period64 / 2)
|
||||
<< "start64: " << start64 << "; now64: " << now64
|
||||
<< "; sync64: " << sync64;
|
||||
|
||||
return Timestamp(sync64);
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,68 +0,0 @@
|
|||
// Copyright 2018 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
option objc_class_prefix = "MediaPipe";
|
||||
|
||||
message PacketThinnerCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional PacketThinnerCalculatorOptions ext = 288533508;
|
||||
}
|
||||
|
||||
enum ThinnerType {
|
||||
ASYNC = 1; // Asynchronous thinner, described below [default].
|
||||
SYNC = 2; // Synchronous thinner, also described below.
|
||||
}
|
||||
optional ThinnerType thinner_type = 1 [default = ASYNC];
|
||||
|
||||
// The period (in microsecond) specifies the temporal interval during which
|
||||
// only a single packet is emitted in the output stream. Has subtly different
|
||||
// semantics depending on the thinner type, as follows.
|
||||
//
|
||||
// Async thinner: this option is a refractory period -- once a packet is
|
||||
// emitted, we guarantee that no packets will be emitted for period ticks.
|
||||
//
|
||||
// Sync thinner: the period specifies a temporal interval during which
|
||||
// only one packet is emitted. The emitted packet is guaranteed to be
|
||||
// the one closest to the center of the temporal interval (no guarantee on
|
||||
// how ties are broken). More specifically,
|
||||
// intervals are centered at start_time + i * period
|
||||
// (for non-negative integers i).
|
||||
// Thus, each interval extends period/2 ticks before and after its center.
|
||||
// Additionally, in the sync thinner any packets earlier than start_time
|
||||
// are discarded and the thinner calls Close() once timestamp equals or
|
||||
// exceeds end_time.
|
||||
optional int64 period = 2 [default = 1];
|
||||
|
||||
// Packets before start_time and at/after end_time are discarded.
|
||||
// Additionally, for a sync thinner, start time specifies the center of
|
||||
// time invervals as described above and therefore should be set explicitly.
|
||||
optional int64 start_time = 3; // If not specified, set to 0 for SYNC type,
|
||||
// and set to Timestamp::Min() for ASYNC type.
|
||||
optional int64 end_time = 4; // Set to Timestamp::Max() if not specified.
|
||||
|
||||
// Whether the timestamps of packets emitted by sync thinner should
|
||||
// correspond to the center of their corresponding temporal interval.
|
||||
// If false, packets emitted using original timestamp (as in async thinner).
|
||||
optional bool sync_output_timestamps = 5 [default = true];
|
||||
|
||||
// If true, update the frame rate in the header, if it's available, to an
|
||||
// estimated frame rate due to the sampling.
|
||||
optional bool update_frame_rate = 6 [default = false];
|
||||
}
|
||||
|
|
@ -1,357 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/calculators/core/packet_thinner_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/video_stream_header.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
// A simple version of CalculatorRunner with built-in convenience methods for
|
||||
// setting inputs from a vector and checking outputs against a vector of
|
||||
// expected outputs.
|
||||
class SimpleRunner : public CalculatorRunner {
|
||||
public:
|
||||
explicit SimpleRunner(const CalculatorOptions& options)
|
||||
: CalculatorRunner("PacketThinnerCalculator", options) {
|
||||
SetNumInputs(1);
|
||||
SetNumOutputs(1);
|
||||
SetNumInputSidePackets(0);
|
||||
}
|
||||
|
||||
explicit SimpleRunner(const CalculatorGraphConfig::Node& node)
|
||||
: CalculatorRunner(node) {}
|
||||
|
||||
void SetInput(const std::vector<int>& timestamp_list) {
|
||||
MutableInputs()->Index(0).packets.clear();
|
||||
for (const int ts : timestamp_list) {
|
||||
MutableInputs()->Index(0).packets.push_back(
|
||||
MakePacket<std::string>(absl::StrCat("Frame #", ts))
|
||||
.At(Timestamp(ts)));
|
||||
}
|
||||
}
|
||||
|
||||
void SetFrameRate(const double frame_rate) {
|
||||
auto video_header = absl::make_unique<VideoHeader>();
|
||||
video_header->frame_rate = frame_rate;
|
||||
MutableInputs()->Index(0).header = Adopt(video_header.release());
|
||||
}
|
||||
|
||||
std::vector<int64> GetOutputTimestamps() const {
|
||||
std::vector<int64> timestamps;
|
||||
for (const Packet& packet : Outputs().Index(0).packets) {
|
||||
timestamps.emplace_back(packet.Timestamp().Value());
|
||||
}
|
||||
return timestamps;
|
||||
}
|
||||
|
||||
double GetFrameRate() const {
|
||||
CHECK(!Outputs().Index(0).header.IsEmpty());
|
||||
return Outputs().Index(0).header.Get<VideoHeader>().frame_rate;
|
||||
}
|
||||
};
|
||||
|
||||
// Check that thinner respects start_time and end_time options.
|
||||
// We only test with one thinner because the logic for start & end time
|
||||
// handling is shared across both types of thinner in Process().
|
||||
TEST(PacketThinnerCalculatorTest, StartAndEndTimeTest) {
|
||||
CalculatorOptions options;
|
||||
auto* extension =
|
||||
options.MutableExtension(PacketThinnerCalculatorOptions::ext);
|
||||
extension->set_thinner_type(PacketThinnerCalculatorOptions::ASYNC);
|
||||
extension->set_period(5);
|
||||
extension->set_start_time(4);
|
||||
extension->set_end_time(12);
|
||||
SimpleRunner runner(options);
|
||||
runner.SetInput({2, 3, 5, 7, 11, 13, 17, 19, 23, 29});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<int64> expected_timestamps = {5, 11};
|
||||
EXPECT_EQ(expected_timestamps, runner.GetOutputTimestamps());
|
||||
}
|
||||
|
||||
TEST(PacketThinnerCalculatorTest, AsyncUniformStreamThinningTest) {
|
||||
CalculatorOptions options;
|
||||
auto* extension =
|
||||
options.MutableExtension(PacketThinnerCalculatorOptions::ext);
|
||||
extension->set_thinner_type(PacketThinnerCalculatorOptions::ASYNC);
|
||||
extension->set_period(5);
|
||||
SimpleRunner runner(options);
|
||||
runner.SetInput({2, 4, 6, 8, 10, 12, 14});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<int64> expected_timestamps = {2, 8, 14};
|
||||
EXPECT_EQ(expected_timestamps, runner.GetOutputTimestamps());
|
||||
}
|
||||
|
||||
TEST(PacketThinnerCalculatorTest, ASyncUniformStreamThinningTestBySidePacket) {
|
||||
// Note: sync runner but outputting *original* timestamps.
|
||||
CalculatorGraphConfig::Node node;
|
||||
node.set_calculator("PacketThinnerCalculator");
|
||||
node.add_input_side_packet("PERIOD:period");
|
||||
node.add_input_stream("input_stream");
|
||||
node.add_output_stream("output_stream");
|
||||
auto* extension = node.mutable_options()->MutableExtension(
|
||||
PacketThinnerCalculatorOptions::ext);
|
||||
extension->set_thinner_type(PacketThinnerCalculatorOptions::ASYNC);
|
||||
extension->set_start_time(0);
|
||||
extension->set_sync_output_timestamps(false);
|
||||
|
||||
SimpleRunner runner(node);
|
||||
runner.SetInput({2, 4, 6, 8, 10, 12, 14});
|
||||
runner.MutableSidePackets()->Tag("PERIOD") = MakePacket<int64>(5);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<int64> expected_timestamps = {2, 8, 14};
|
||||
EXPECT_EQ(expected_timestamps, runner.GetOutputTimestamps());
|
||||
}
|
||||
|
||||
TEST(PacketThinnerCalculatorTest, SyncUniformStreamThinningTest1) {
|
||||
// Note: sync runner but outputting *original* timestamps.
|
||||
CalculatorOptions options;
|
||||
auto* extension =
|
||||
options.MutableExtension(PacketThinnerCalculatorOptions::ext);
|
||||
extension->set_thinner_type(PacketThinnerCalculatorOptions::SYNC);
|
||||
extension->set_start_time(0);
|
||||
extension->set_period(5);
|
||||
extension->set_sync_output_timestamps(false);
|
||||
SimpleRunner runner(options);
|
||||
runner.SetInput({2, 4, 6, 8, 10, 12, 14});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<int64> expected_timestamps = {2, 6, 10, 14};
|
||||
EXPECT_EQ(expected_timestamps, runner.GetOutputTimestamps());
|
||||
}
|
||||
|
||||
TEST(PacketThinnerCalculatorTest, SyncUniformStreamThinningTestBySidePacket1) {
|
||||
// Note: sync runner but outputting *original* timestamps.
|
||||
CalculatorGraphConfig::Node node;
|
||||
node.set_calculator("PacketThinnerCalculator");
|
||||
node.add_input_side_packet("PERIOD:period");
|
||||
node.add_input_stream("input_stream");
|
||||
node.add_output_stream("output_stream");
|
||||
auto* extension = node.mutable_options()->MutableExtension(
|
||||
PacketThinnerCalculatorOptions::ext);
|
||||
extension->set_thinner_type(PacketThinnerCalculatorOptions::SYNC);
|
||||
extension->set_start_time(0);
|
||||
extension->set_sync_output_timestamps(false);
|
||||
|
||||
SimpleRunner runner(node);
|
||||
runner.SetInput({2, 4, 6, 8, 10, 12, 14});
|
||||
runner.MutableSidePackets()->Tag("PERIOD") = MakePacket<int64>(5);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<int64> expected_timestamps = {2, 6, 10, 14};
|
||||
EXPECT_EQ(expected_timestamps, runner.GetOutputTimestamps());
|
||||
}
|
||||
|
||||
TEST(PacketThinnerCalculatorTest, SyncUniformStreamThinningTest2) {
|
||||
// Same test but now with synced timestamps.
|
||||
CalculatorOptions options;
|
||||
auto* extension =
|
||||
options.MutableExtension(PacketThinnerCalculatorOptions::ext);
|
||||
extension->set_thinner_type(PacketThinnerCalculatorOptions::SYNC);
|
||||
extension->set_start_time(0);
|
||||
extension->set_period(5);
|
||||
extension->set_sync_output_timestamps(true);
|
||||
SimpleRunner runner(options);
|
||||
runner.SetInput({2, 4, 6, 8, 10, 12, 14});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<int64> expected_timestamps = {0, 5, 10, 15};
|
||||
EXPECT_EQ(expected_timestamps, runner.GetOutputTimestamps());
|
||||
}
|
||||
|
||||
// Test: Given a stream with timestamps corresponding to first ten prime numbers
|
||||
// and period of 5, confirm whether timestamps of thinner stream matches
|
||||
// expectations.
|
||||
TEST(PacketThinnerCalculatorTest, PrimeStreamThinningTest1) {
|
||||
// ASYNC thinner.
|
||||
CalculatorOptions options;
|
||||
auto* extension =
|
||||
options.MutableExtension(PacketThinnerCalculatorOptions::ext);
|
||||
extension->set_thinner_type(PacketThinnerCalculatorOptions::ASYNC);
|
||||
extension->set_period(5);
|
||||
SimpleRunner runner(options);
|
||||
runner.SetInput({2, 3, 5, 7, 11, 13, 17, 19, 23, 29});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<int64> expected_timestamps = {2, 7, 13, 19, 29};
|
||||
EXPECT_EQ(expected_timestamps, runner.GetOutputTimestamps());
|
||||
}
|
||||
|
||||
TEST(PacketThinnerCalculatorTest, PrimeStreamThinningTest2) {
|
||||
// SYNC with original timestamps.
|
||||
CalculatorOptions options;
|
||||
auto* extension =
|
||||
options.MutableExtension(PacketThinnerCalculatorOptions::ext);
|
||||
extension->set_thinner_type(PacketThinnerCalculatorOptions::SYNC);
|
||||
extension->set_start_time(0);
|
||||
extension->set_period(5);
|
||||
extension->set_sync_output_timestamps(false);
|
||||
SimpleRunner runner(options);
|
||||
runner.SetInput({2, 3, 5, 7, 11, 13, 17, 19, 23, 29});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<int64> expected_timestamps = {2, 5, 11, 17, 19, 23, 29};
|
||||
EXPECT_EQ(expected_timestamps, runner.GetOutputTimestamps());
|
||||
}
|
||||
|
||||
// Confirm that Calculator correctly handles boundary cases.
|
||||
TEST(PacketThinnerCalculatorTest, BoundaryTimestampTest1) {
|
||||
// Odd period, negative start_time
|
||||
CalculatorOptions options;
|
||||
auto* extension =
|
||||
options.MutableExtension(PacketThinnerCalculatorOptions::ext);
|
||||
extension->set_thinner_type(PacketThinnerCalculatorOptions::SYNC);
|
||||
extension->set_start_time(-10);
|
||||
extension->set_period(5);
|
||||
extension->set_sync_output_timestamps(true);
|
||||
SimpleRunner runner(options);
|
||||
// Two timestamps falling on either side of a period boundary.
|
||||
runner.SetInput({2, 3});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<int64> expected_timestamps = {0, 5};
|
||||
EXPECT_EQ(expected_timestamps, runner.GetOutputTimestamps());
|
||||
}
|
||||
|
||||
TEST(PacketThinnerCalculatorTest, BoundaryTimestampTest2) {
|
||||
// Even period, negative start_time, negative packet timestamps.
|
||||
CalculatorOptions options;
|
||||
auto* extension =
|
||||
options.MutableExtension(PacketThinnerCalculatorOptions::ext);
|
||||
extension->set_thinner_type(PacketThinnerCalculatorOptions::SYNC);
|
||||
extension->set_start_time(-144);
|
||||
extension->set_period(6);
|
||||
extension->set_sync_output_timestamps(true);
|
||||
SimpleRunner runner(options);
|
||||
// Two timestamps falling on either side of a period boundary.
|
||||
runner.SetInput({-4, -3, 8, 9});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<int64> expected_timestamps = {-6, 0, 6, 12};
|
||||
EXPECT_EQ(expected_timestamps, runner.GetOutputTimestamps());
|
||||
}
|
||||
|
||||
TEST(PacketThinnerCalculatorTest, FrameRateTest1) {
|
||||
CalculatorOptions options;
|
||||
auto* extension =
|
||||
options.MutableExtension(PacketThinnerCalculatorOptions::ext);
|
||||
extension->set_thinner_type(PacketThinnerCalculatorOptions::ASYNC);
|
||||
extension->set_period(5);
|
||||
extension->set_update_frame_rate(true);
|
||||
SimpleRunner runner(options);
|
||||
runner.SetInput({2, 4, 6, 8, 10, 12, 14});
|
||||
runner.SetFrameRate(1000000.0 / 2);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<int64> expected_timestamps = {2, 8, 14};
|
||||
EXPECT_EQ(expected_timestamps, runner.GetOutputTimestamps());
|
||||
// The true sampling period is 6.
|
||||
EXPECT_DOUBLE_EQ(1000000.0 / 6, runner.GetFrameRate());
|
||||
}
|
||||
|
||||
TEST(PacketThinnerCalculatorTest, FrameRateTest2) {
|
||||
CalculatorOptions options;
|
||||
auto* extension =
|
||||
options.MutableExtension(PacketThinnerCalculatorOptions::ext);
|
||||
extension->set_thinner_type(PacketThinnerCalculatorOptions::ASYNC);
|
||||
extension->set_period(5);
|
||||
extension->set_update_frame_rate(true);
|
||||
SimpleRunner runner(options);
|
||||
runner.SetInput({8, 16, 24, 32, 40, 48, 56});
|
||||
runner.SetFrameRate(1000000.0 / 8);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
const std::vector<int64> expected_timestamps = {8, 16, 24, 32, 40, 48, 56};
|
||||
EXPECT_EQ(expected_timestamps, runner.GetOutputTimestamps());
|
||||
// The true sampling period is still 8.
|
||||
EXPECT_DOUBLE_EQ(1000000.0 / 8, runner.GetFrameRate());
|
||||
}
|
||||
|
||||
TEST(PacketThinnerCalculatorTest, FrameRateTest3) {
|
||||
// Note: sync runner but outputting *original* timestamps.
|
||||
CalculatorOptions options;
|
||||
auto* extension =
|
||||
options.MutableExtension(PacketThinnerCalculatorOptions::ext);
|
||||
extension->set_thinner_type(PacketThinnerCalculatorOptions::SYNC);
|
||||
extension->set_start_time(0);
|
||||
extension->set_period(5);
|
||||
extension->set_sync_output_timestamps(false);
|
||||
extension->set_update_frame_rate(true);
|
||||
SimpleRunner runner(options);
|
||||
runner.SetInput({2, 4, 6, 8, 10, 12, 14});
|
||||
runner.SetFrameRate(1000000.0 / 2);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<int64> expected_timestamps = {2, 6, 10, 14};
|
||||
EXPECT_EQ(expected_timestamps, runner.GetOutputTimestamps());
|
||||
// The true (long-run) sampling period is 5.
|
||||
EXPECT_DOUBLE_EQ(1000000.0 / 5, runner.GetFrameRate());
|
||||
}
|
||||
|
||||
TEST(PacketThinnerCalculatorTest, FrameRateTest4) {
|
||||
// Same test but now with synced timestamps.
|
||||
CalculatorOptions options;
|
||||
auto* extension =
|
||||
options.MutableExtension(PacketThinnerCalculatorOptions::ext);
|
||||
extension->set_thinner_type(PacketThinnerCalculatorOptions::SYNC);
|
||||
extension->set_start_time(0);
|
||||
extension->set_period(5);
|
||||
extension->set_sync_output_timestamps(true);
|
||||
extension->set_update_frame_rate(true);
|
||||
SimpleRunner runner(options);
|
||||
runner.SetInput({2, 4, 6, 8, 10, 12, 14});
|
||||
runner.SetFrameRate(1000000.0 / 2);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<int64> expected_timestamps = {0, 5, 10, 15};
|
||||
EXPECT_EQ(expected_timestamps, runner.GetOutputTimestamps());
|
||||
// The true (long-run) sampling period is 5.
|
||||
EXPECT_DOUBLE_EQ(1000000.0 / 5, runner.GetFrameRate());
|
||||
}
|
||||
|
||||
TEST(PacketThinnerCalculatorTest, FrameRateTest5) {
|
||||
CalculatorOptions options;
|
||||
auto* extension =
|
||||
options.MutableExtension(PacketThinnerCalculatorOptions::ext);
|
||||
extension->set_thinner_type(PacketThinnerCalculatorOptions::SYNC);
|
||||
extension->set_start_time(0);
|
||||
extension->set_period(5);
|
||||
extension->set_sync_output_timestamps(true);
|
||||
extension->set_update_frame_rate(true);
|
||||
SimpleRunner runner(options);
|
||||
runner.SetInput({8, 16, 24, 32, 40, 48, 56});
|
||||
runner.SetFrameRate(1000000.0 / 8);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<int64> expected_timestamps = {10, 15, 25, 30, 40, 50, 55};
|
||||
EXPECT_EQ(expected_timestamps, runner.GetOutputTimestamps());
|
||||
// The true (long-run) sampling period is 8.
|
||||
EXPECT_DOUBLE_EQ(1000000.0 / 8, runner.GetFrameRate());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,98 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// A Calculator that simply passes its input Packets and header through,
|
||||
// unchanged. The inputs may be specified by tag or index. The outputs
|
||||
// must match the inputs exactly. Any number of input side packets may
|
||||
// also be specified. If output side packets are specified, they must
|
||||
// match the input side packets exactly and the Calculator passes its
|
||||
// input side packets through, unchanged. Otherwise, the input side
|
||||
// packets will be ignored (allowing PassThroughCalculator to be used to
|
||||
// test internal behavior). Any options may be specified and will be
|
||||
// ignored.
|
||||
class PassThroughCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
if (!cc->Inputs().TagMap()->SameAs(*cc->Outputs().TagMap())) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Input and output streams to PassThroughCalculator must use "
|
||||
"matching tags and indexes.");
|
||||
}
|
||||
for (CollectionItemId id = cc->Inputs().BeginId();
|
||||
id < cc->Inputs().EndId(); ++id) {
|
||||
cc->Inputs().Get(id).SetAny();
|
||||
cc->Outputs().Get(id).SetSameAs(&cc->Inputs().Get(id));
|
||||
}
|
||||
for (CollectionItemId id = cc->InputSidePackets().BeginId();
|
||||
id < cc->InputSidePackets().EndId(); ++id) {
|
||||
cc->InputSidePackets().Get(id).SetAny();
|
||||
}
|
||||
if (cc->OutputSidePackets().NumEntries() != 0) {
|
||||
if (!cc->InputSidePackets().TagMap()->SameAs(
|
||||
*cc->OutputSidePackets().TagMap())) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Input and output side packets to PassThroughCalculator must use "
|
||||
"matching tags and indexes.");
|
||||
}
|
||||
for (CollectionItemId id = cc->InputSidePackets().BeginId();
|
||||
id < cc->InputSidePackets().EndId(); ++id) {
|
||||
cc->OutputSidePackets().Get(id).SetSameAs(
|
||||
&cc->InputSidePackets().Get(id));
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) final {
|
||||
for (CollectionItemId id = cc->Inputs().BeginId();
|
||||
id < cc->Inputs().EndId(); ++id) {
|
||||
if (!cc->Inputs().Get(id).Header().IsEmpty()) {
|
||||
cc->Outputs().Get(id).SetHeader(cc->Inputs().Get(id).Header());
|
||||
}
|
||||
}
|
||||
if (cc->OutputSidePackets().NumEntries() != 0) {
|
||||
for (CollectionItemId id = cc->InputSidePackets().BeginId();
|
||||
id < cc->InputSidePackets().EndId(); ++id) {
|
||||
cc->OutputSidePackets().Get(id).Set(cc->InputSidePackets().Get(id));
|
||||
}
|
||||
}
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
cc->GetCounter("PassThrough")->Increment();
|
||||
if (cc->Inputs().NumEntries() == 0) {
|
||||
return tool::StatusStop();
|
||||
}
|
||||
for (CollectionItemId id = cc->Inputs().BeginId();
|
||||
id < cc->Inputs().EndId(); ++id) {
|
||||
if (!cc->Inputs().Get(id).IsEmpty()) {
|
||||
VLOG(3) << "Passing " << cc->Inputs().Get(id).Name() << " to "
|
||||
<< cc->Outputs().Get(id).Name() << " at "
|
||||
<< cc->InputTimestamp().DebugString();
|
||||
cc->Outputs().Get(id).AddPacket(cc->Inputs().Get(id).Value());
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(PassThroughCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,177 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <deque>
|
||||
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// PreviousLoopbackCalculator is useful when a graph needs to process an input
|
||||
// together with some previous output.
|
||||
//
|
||||
// For the first packet that arrives on the MAIN input, the timestamp bound is
|
||||
// advanced on the PREV_LOOP. Downstream calculators will see this as an empty
|
||||
// packet. This way they are not kept waiting for the previous output, which
|
||||
// for the first iteration does not exist.
|
||||
//
|
||||
// Thereafter,
|
||||
// - Each non-empty MAIN packet results in:
|
||||
// a) a PREV_LOOP packet with contents of the LOOP packet received at the
|
||||
// timestamp of the previous non-empty MAIN packet
|
||||
// b) or in a PREV_LOOP timestamp bound update if the LOOP packet was empty.
|
||||
// - Each empty MAIN packet indicating timestamp bound update results in a
|
||||
// PREV_LOOP timestamp bound update.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "PreviousLoopbackCalculator"
|
||||
// input_stream: "MAIN:input"
|
||||
// input_stream: "LOOP:output"
|
||||
// input_stream_info: { tag_index: 'LOOP' back_edge: true }
|
||||
// output_stream: "PREV_LOOP:prev_output"
|
||||
// }
|
||||
// node {
|
||||
// calculator: "FaceTracker"
|
||||
// input_stream: "VIDEO:input"
|
||||
// input_stream: "PREV_TRACK:prev_output"
|
||||
// output_stream: "TRACK:output"
|
||||
// }
|
||||
class PreviousLoopbackCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<AnyType> kMain{"MAIN"};
|
||||
static constexpr Input<AnyType> kLoop{"LOOP"};
|
||||
static constexpr Output<SameType<kLoop>> kPrevLoop{"PREV_LOOP"};
|
||||
// TODO: an optional PREV_TIMESTAMP output could be added to
|
||||
// carry the original timestamp of the packet on PREV_LOOP.
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kMain, kLoop, kPrevLoop,
|
||||
StreamHandler("ImmediateInputStreamHandler"),
|
||||
TimestampChange::Arbitrary());
|
||||
|
||||
static absl::Status UpdateContract(CalculatorContract* cc) {
|
||||
// Process() function is invoked in response to MAIN/LOOP stream timestamp
|
||||
// bound updates.
|
||||
cc->SetProcessTimestampBounds(true);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) final {
|
||||
kPrevLoop(cc).SetHeader(kLoop(cc).Header());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
// Non-empty packets and empty packets indicating timestamp bound updates
|
||||
// are guaranteed to have timestamps greater than timestamps of previous
|
||||
// packets within the same stream. Calculator tracks and operates on such
|
||||
// packets.
|
||||
|
||||
const PacketBase& main_packet = kMain(cc).packet();
|
||||
if (prev_main_ts_ < main_packet.timestamp()) {
|
||||
Timestamp loop_timestamp;
|
||||
if (!main_packet.IsEmpty()) {
|
||||
loop_timestamp = prev_non_empty_main_ts_;
|
||||
prev_non_empty_main_ts_ = main_packet.timestamp();
|
||||
} else {
|
||||
// Calculator advances PREV_LOOP timestamp bound in response to empty
|
||||
// MAIN packet, hence not caring about corresponding loop packet.
|
||||
loop_timestamp = Timestamp::Unset();
|
||||
}
|
||||
main_packet_specs_.push_back({main_packet.timestamp(), loop_timestamp});
|
||||
prev_main_ts_ = main_packet.timestamp();
|
||||
}
|
||||
|
||||
const PacketBase& loop_packet = kLoop(cc).packet();
|
||||
if (prev_loop_ts_ < loop_packet.timestamp()) {
|
||||
loop_packets_.push_back(loop_packet);
|
||||
prev_loop_ts_ = loop_packet.timestamp();
|
||||
}
|
||||
|
||||
while (!main_packet_specs_.empty() && !loop_packets_.empty()) {
|
||||
// The earliest MAIN packet.
|
||||
MainPacketSpec main_spec = main_packet_specs_.front();
|
||||
// The earliest LOOP packet.
|
||||
const PacketBase& loop_candidate = loop_packets_.front();
|
||||
// Match LOOP and MAIN packets.
|
||||
if (main_spec.loop_timestamp < loop_candidate.timestamp()) {
|
||||
// No LOOP packet can match the MAIN packet under review.
|
||||
kPrevLoop(cc).SetNextTimestampBound(main_spec.timestamp + 1);
|
||||
main_packet_specs_.pop_front();
|
||||
} else if (main_spec.loop_timestamp > loop_candidate.timestamp()) {
|
||||
// No MAIN packet can match the LOOP packet under review.
|
||||
loop_packets_.pop_front();
|
||||
} else {
|
||||
// Exact match found.
|
||||
if (loop_candidate.IsEmpty()) {
|
||||
// However, LOOP packet is empty.
|
||||
kPrevLoop(cc).SetNextTimestampBound(main_spec.timestamp + 1);
|
||||
} else {
|
||||
kPrevLoop(cc).Send(loop_candidate.At(main_spec.timestamp));
|
||||
}
|
||||
loop_packets_.pop_front();
|
||||
main_packet_specs_.pop_front();
|
||||
}
|
||||
|
||||
// We can close PREV_LOOP output stream as soon as we processed last
|
||||
// possible MAIN packet. That can happen in two cases:
|
||||
// a) Non-empty MAIN packet has been received with Timestamp::Max()
|
||||
// b) Empty MAIN packet has been received with Timestamp::Max() indicating
|
||||
// MAIN is done.
|
||||
if (main_spec.timestamp == Timestamp::Done().PreviousAllowedInStream()) {
|
||||
kPrevLoop(cc).Close();
|
||||
}
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
struct MainPacketSpec {
|
||||
Timestamp timestamp;
|
||||
// Expected timestamp of the packet from LOOP stream that corresponds to the
|
||||
// packet from MAIN stream descirbed by this spec.
|
||||
Timestamp loop_timestamp;
|
||||
};
|
||||
|
||||
// Contains specs for MAIN packets which only can be:
|
||||
// - non-empty packets
|
||||
// - empty packets indicating timestamp bound updates
|
||||
//
|
||||
// Sorted according to packet timestamps.
|
||||
std::deque<MainPacketSpec> main_packet_specs_;
|
||||
Timestamp prev_main_ts_ = Timestamp::Unstarted();
|
||||
Timestamp prev_non_empty_main_ts_ = Timestamp::Unstarted();
|
||||
|
||||
// Contains LOOP packets which only can be:
|
||||
// - the very first empty packet
|
||||
// - non empty packets
|
||||
// - empty packets indicating timestamp bound updates
|
||||
//
|
||||
// Sorted according to packet timestamps.
|
||||
std::deque<PacketBase> loop_packets_;
|
||||
// Using "Timestamp::Unset" instead of "Timestamp::Unstarted" in order to
|
||||
// allow addition of the very first empty packet (which doesn't indicate
|
||||
// timestamp bound change necessarily).
|
||||
Timestamp prev_loop_ts_ = Timestamp::Unset();
|
||||
};
|
||||
MEDIAPIPE_REGISTER_NODE(PreviousLoopbackCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,867 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/time/clock.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/framework/tool/sink.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::Eq;
|
||||
using ::testing::Pair;
|
||||
using ::testing::Value;
|
||||
namespace {
|
||||
|
||||
// Returns the timestamp values for a vector of Packets.
|
||||
// TODO: puth this kind of test util in a common place.
|
||||
std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
|
||||
std::vector<int64> result;
|
||||
for (const Packet& packet : packets) {
|
||||
result.push_back(packet.Timestamp().Value());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
MATCHER(EmptyPacket, negation ? "isn't empty" : "is empty") {
|
||||
if (arg.IsEmpty()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
MATCHER_P(IntPacket, value, "") {
|
||||
return Value(arg.template Get<int>(), Eq(value));
|
||||
}
|
||||
|
||||
MATCHER_P2(PairPacket, timestamp, pair, "") {
|
||||
Timestamp actual_timestamp = arg.Timestamp();
|
||||
const auto& actual_pair = arg.template Get<std::pair<Packet, Packet>>();
|
||||
return Value(actual_timestamp, Eq(timestamp)) && Value(actual_pair, pair);
|
||||
}
|
||||
|
||||
TEST(PreviousLoopbackCalculator, CorrectTimestamps) {
|
||||
std::vector<Packet> in_prev;
|
||||
CalculatorGraphConfig graph_config_ =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: 'in'
|
||||
node {
|
||||
calculator: 'PreviousLoopbackCalculator'
|
||||
input_stream: 'MAIN:in'
|
||||
input_stream: 'LOOP:out'
|
||||
input_stream_info: { tag_index: 'LOOP' back_edge: true }
|
||||
output_stream: 'PREV_LOOP:previous'
|
||||
}
|
||||
# This calculator synchronizes its inputs as normal, so it is used
|
||||
# to check that both "in" and "previous" are ready.
|
||||
node {
|
||||
calculator: 'PassThroughCalculator'
|
||||
input_stream: 'in'
|
||||
input_stream: 'previous'
|
||||
output_stream: 'out'
|
||||
output_stream: 'previous2'
|
||||
}
|
||||
node {
|
||||
calculator: 'MakePairCalculator'
|
||||
input_stream: 'out'
|
||||
input_stream: 'previous2'
|
||||
output_stream: 'pair'
|
||||
}
|
||||
)pb");
|
||||
tool::AddVectorSink("pair", &graph_config_, &in_prev);
|
||||
|
||||
CalculatorGraph graph_;
|
||||
MP_ASSERT_OK(graph_.Initialize(graph_config_, {}));
|
||||
MP_ASSERT_OK(graph_.StartRun({}));
|
||||
|
||||
auto send_packet = [&graph_](const std::string& input_name, int n) {
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream(
|
||||
input_name, MakePacket<int>(n).At(Timestamp(n))));
|
||||
};
|
||||
|
||||
send_packet("in", 1);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(TimestampValues(in_prev), ElementsAre(1));
|
||||
EXPECT_THAT(in_prev.back(),
|
||||
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())));
|
||||
|
||||
send_packet("in", 2);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(TimestampValues(in_prev), ElementsAre(1, 2));
|
||||
EXPECT_THAT(in_prev.back(),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))));
|
||||
|
||||
send_packet("in", 5);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(TimestampValues(in_prev), ElementsAre(1, 2, 5));
|
||||
EXPECT_THAT(in_prev.back(),
|
||||
PairPacket(Timestamp(5), Pair(IntPacket(5), IntPacket(2))));
|
||||
|
||||
send_packet("in", 15);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(TimestampValues(in_prev), ElementsAre(1, 2, 5, 15));
|
||||
EXPECT_THAT(in_prev.back(),
|
||||
PairPacket(Timestamp(15), Pair(IntPacket(15), IntPacket(5))));
|
||||
|
||||
MP_EXPECT_OK(graph_.CloseAllInputStreams());
|
||||
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||
}
|
||||
|
||||
// A Calculator that outputs a summary packet in CalculatorBase::Close().
|
||||
class PacketOnCloseCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).Set<int>();
|
||||
cc->Outputs().Index(0).Set<int>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) final {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
sum_ += cc->Inputs().Index(0).Value().Get<int>();
|
||||
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Close(CalculatorContext* cc) final {
|
||||
cc->Outputs().Index(0).AddPacket(
|
||||
MakePacket<int>(sum_).At(Timestamp::Max()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
int sum_ = 0;
|
||||
};
|
||||
REGISTER_CALCULATOR(PacketOnCloseCalculator);
|
||||
|
||||
// Demonstrates that all ouput and input streams in PreviousLoopbackCalculator
|
||||
// will close as expected when all graph input streams are closed.
|
||||
TEST(PreviousLoopbackCalculator, ClosesCorrectly) {
|
||||
std::vector<Packet> outputs;
|
||||
CalculatorGraphConfig graph_config_ =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: 'in'
|
||||
node {
|
||||
calculator: 'PreviousLoopbackCalculator'
|
||||
input_stream: 'MAIN:in'
|
||||
input_stream: 'LOOP:out'
|
||||
input_stream_info: { tag_index: 'LOOP' back_edge: true }
|
||||
output_stream: 'PREV_LOOP:previous'
|
||||
}
|
||||
# This calculator synchronizes its inputs as normal, so it is used
|
||||
# to check that both "in" and "previous" are ready.
|
||||
node {
|
||||
calculator: 'PassThroughCalculator'
|
||||
input_stream: 'in'
|
||||
input_stream: 'previous'
|
||||
output_stream: 'out'
|
||||
output_stream: 'previous2'
|
||||
}
|
||||
node {
|
||||
calculator: 'PacketOnCloseCalculator'
|
||||
input_stream: 'out'
|
||||
output_stream: 'close_out'
|
||||
}
|
||||
)pb");
|
||||
tool::AddVectorSink("close_out", &graph_config_, &outputs);
|
||||
|
||||
CalculatorGraph graph_;
|
||||
MP_ASSERT_OK(graph_.Initialize(graph_config_, {}));
|
||||
MP_ASSERT_OK(graph_.StartRun({}));
|
||||
|
||||
auto send_packet = [&graph_](const std::string& input_name, int n) {
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream(
|
||||
input_name, MakePacket<int>(n).At(Timestamp(n))));
|
||||
};
|
||||
|
||||
send_packet("in", 1);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(TimestampValues(outputs), ElementsAre(1));
|
||||
|
||||
send_packet("in", 2);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(TimestampValues(outputs), ElementsAre(1, 2));
|
||||
|
||||
send_packet("in", 5);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(TimestampValues(outputs), ElementsAre(1, 2, 5));
|
||||
|
||||
send_packet("in", 15);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(TimestampValues(outputs), ElementsAre(1, 2, 5, 15));
|
||||
|
||||
MP_EXPECT_OK(graph_.CloseAllInputStreams());
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(TimestampValues(outputs),
|
||||
ElementsAre(1, 2, 5, 15, Timestamp::Max().Value()));
|
||||
|
||||
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST(PreviousLoopbackCalculator, ProcessesMaxTimestamp) {
|
||||
std::vector<Packet> out_and_previous_packets;
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: 'in'
|
||||
node {
|
||||
calculator: 'PreviousLoopbackCalculator'
|
||||
input_stream: 'MAIN:in'
|
||||
input_stream: 'LOOP:out'
|
||||
input_stream_info: { tag_index: 'LOOP' back_edge: true }
|
||||
output_stream: 'PREV_LOOP:previous'
|
||||
}
|
||||
node {
|
||||
calculator: 'PassThroughCalculator'
|
||||
input_stream: 'in'
|
||||
input_stream: 'previous'
|
||||
output_stream: 'out'
|
||||
output_stream: 'previous2'
|
||||
}
|
||||
node {
|
||||
calculator: 'MakePairCalculator'
|
||||
input_stream: 'out'
|
||||
input_stream: 'previous'
|
||||
output_stream: 'out_and_previous'
|
||||
}
|
||||
)pb");
|
||||
tool::AddVectorSink("out_and_previous", &graph_config,
|
||||
&out_and_previous_packets);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"in", MakePacket<int>(1).At(Timestamp::Max())));
|
||||
|
||||
MP_EXPECT_OK(graph.WaitUntilIdle());
|
||||
|
||||
EXPECT_THAT(out_and_previous_packets,
|
||||
ElementsAre(PairPacket(Timestamp::Max(),
|
||||
Pair(IntPacket(1), EmptyPacket()))));
|
||||
|
||||
MP_EXPECT_OK(graph.CloseAllInputStreams());
|
||||
MP_EXPECT_OK(graph.WaitUntilIdle());
|
||||
MP_EXPECT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST(PreviousLoopbackCalculator, ProcessesMaxTimestampNonEmptyPrevious) {
|
||||
std::vector<Packet> out_and_previous_packets;
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: 'in'
|
||||
node {
|
||||
calculator: 'PreviousLoopbackCalculator'
|
||||
input_stream: 'MAIN:in'
|
||||
input_stream: 'LOOP:out'
|
||||
input_stream_info: { tag_index: 'LOOP' back_edge: true }
|
||||
output_stream: 'PREV_LOOP:previous'
|
||||
}
|
||||
node {
|
||||
calculator: 'PassThroughCalculator'
|
||||
input_stream: 'in'
|
||||
input_stream: 'previous'
|
||||
output_stream: 'out'
|
||||
output_stream: 'previous2'
|
||||
}
|
||||
node {
|
||||
calculator: 'MakePairCalculator'
|
||||
input_stream: 'out'
|
||||
input_stream: 'previous'
|
||||
output_stream: 'out_and_previous'
|
||||
}
|
||||
)pb");
|
||||
tool::AddVectorSink("out_and_previous", &graph_config,
|
||||
&out_and_previous_packets);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"in", MakePacket<int>(1).At(Timestamp::Min())));
|
||||
MP_EXPECT_OK(graph.AddPacketToInputStream(
|
||||
"in", MakePacket<int>(2).At(Timestamp::Max())));
|
||||
|
||||
MP_EXPECT_OK(graph.WaitUntilIdle());
|
||||
|
||||
EXPECT_THAT(
|
||||
out_and_previous_packets,
|
||||
ElementsAre(
|
||||
PairPacket(Timestamp::Min(), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp::Max(), Pair(IntPacket(2), IntPacket(1)))));
|
||||
|
||||
MP_EXPECT_OK(graph.CloseAllInputStreams());
|
||||
MP_EXPECT_OK(graph.WaitUntilIdle());
|
||||
MP_EXPECT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
// Demonstrates that downstream calculators won't be blocked by
|
||||
// always-empty-LOOP-stream.
|
||||
TEST(PreviousLoopbackCalculator, EmptyLoopForever) {
|
||||
std::vector<Packet> outputs;
|
||||
CalculatorGraphConfig graph_config_ =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: 'in'
|
||||
node {
|
||||
calculator: 'PreviousLoopbackCalculator'
|
||||
input_stream: 'MAIN:in'
|
||||
input_stream: 'LOOP:previous'
|
||||
input_stream_info: { tag_index: 'LOOP' back_edge: true }
|
||||
output_stream: 'PREV_LOOP:previous'
|
||||
}
|
||||
# This calculator synchronizes its inputs as normal, so it is used
|
||||
# to check that both "in" and "previous" are ready.
|
||||
node {
|
||||
calculator: 'PassThroughCalculator'
|
||||
input_stream: 'in'
|
||||
input_stream: 'previous'
|
||||
output_stream: 'out'
|
||||
output_stream: 'previous2'
|
||||
}
|
||||
node {
|
||||
calculator: 'PacketOnCloseCalculator'
|
||||
input_stream: 'out'
|
||||
output_stream: 'close_out'
|
||||
}
|
||||
)pb");
|
||||
tool::AddVectorSink("close_out", &graph_config_, &outputs);
|
||||
|
||||
CalculatorGraph graph_;
|
||||
MP_ASSERT_OK(graph_.Initialize(graph_config_, {}));
|
||||
MP_ASSERT_OK(graph_.StartRun({}));
|
||||
|
||||
auto send_packet = [&graph_](const std::string& input_name, int n) {
|
||||
MP_EXPECT_OK(graph_.AddPacketToInputStream(
|
||||
input_name, MakePacket<int>(n).At(Timestamp(n))));
|
||||
};
|
||||
|
||||
for (int main_ts = 0; main_ts < 50; ++main_ts) {
|
||||
send_packet("in", main_ts);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
std::vector<int64> ts_values = TimestampValues(outputs);
|
||||
EXPECT_EQ(ts_values.size(), main_ts + 1);
|
||||
for (int j = 0; j < main_ts + 1; ++j) {
|
||||
EXPECT_EQ(ts_values[j], j);
|
||||
}
|
||||
}
|
||||
|
||||
MP_EXPECT_OK(graph_.CloseAllInputStreams());
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||
}
|
||||
|
||||
class PreviousLoopbackCalculatorProcessingTimestampsTest
|
||||
: public testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: 'input'
|
||||
input_stream: 'force_main_empty'
|
||||
input_stream: 'force_loop_empty'
|
||||
# Used to indicate "main" timestamp bound updates.
|
||||
node {
|
||||
calculator: 'GateCalculator'
|
||||
input_stream: 'input'
|
||||
input_stream: 'DISALLOW:force_main_empty'
|
||||
output_stream: 'main'
|
||||
}
|
||||
node {
|
||||
calculator: 'PreviousLoopbackCalculator'
|
||||
input_stream: 'MAIN:main'
|
||||
input_stream: 'LOOP:loop'
|
||||
input_stream_info: { tag_index: 'LOOP' back_edge: true }
|
||||
output_stream: 'PREV_LOOP:prev_loop'
|
||||
}
|
||||
node {
|
||||
calculator: 'PassThroughCalculator'
|
||||
input_stream: 'input'
|
||||
input_stream: 'prev_loop'
|
||||
output_stream: 'passed_through_input'
|
||||
output_stream: 'passed_through_prev_loop'
|
||||
}
|
||||
# Used to indicate "loop" timestamp bound updates.
|
||||
node {
|
||||
calculator: 'GateCalculator'
|
||||
input_stream: 'input'
|
||||
input_stream: 'DISALLOW:force_loop_empty'
|
||||
output_stream: 'loop'
|
||||
}
|
||||
node {
|
||||
calculator: 'MakePairCalculator'
|
||||
input_stream: 'passed_through_input'
|
||||
input_stream: 'passed_through_prev_loop'
|
||||
output_stream: 'passed_through_input_and_prev_loop'
|
||||
}
|
||||
)pb");
|
||||
tool::AddVectorSink("passed_through_input_and_prev_loop", &graph_config,
|
||||
&output_packets_);
|
||||
MP_ASSERT_OK(graph_.Initialize(graph_config, {}));
|
||||
MP_ASSERT_OK(graph_.StartRun({}));
|
||||
}
|
||||
|
||||
void SendPackets(int timestamp, int input, bool force_main_empty,
|
||||
bool force_loop_empty) {
|
||||
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||
"input", MakePacket<int>(input).At(Timestamp(timestamp))));
|
||||
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||
"force_main_empty",
|
||||
MakePacket<bool>(force_main_empty).At(Timestamp(timestamp))));
|
||||
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||
"force_loop_empty",
|
||||
MakePacket<bool>(force_loop_empty).At(Timestamp(timestamp))));
|
||||
}
|
||||
|
||||
CalculatorGraph graph_;
|
||||
std::vector<Packet> output_packets_;
|
||||
};
|
||||
|
||||
TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest,
|
||||
MultiplePacketsEmptyMainNonEmptyLoop) {
|
||||
SendPackets(/*timestamp=*/1, /*input=*/1, /*force_main_empty=*/true,
|
||||
/*force_loop_empty=*/false);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
|
||||
|
||||
SendPackets(/*timestamp=*/2, /*input=*/2, /*force_main_empty=*/true,
|
||||
/*force_loop_empty=*/false);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket()))));
|
||||
|
||||
SendPackets(/*timestamp=*/3, /*input=*/3, /*force_main_empty=*/true,
|
||||
/*force_loop_empty=*/false);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket()))));
|
||||
|
||||
SendPackets(/*timestamp=*/5, /*input=*/5, /*force_main_empty=*/true,
|
||||
/*force_loop_empty=*/false);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
|
||||
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket()))));
|
||||
|
||||
SendPackets(/*timestamp=*/15, /*input=*/15,
|
||||
/*force_main_empty=*/true,
|
||||
/*force_loop_empty=*/false);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(
|
||||
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
|
||||
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket())),
|
||||
PairPacket(Timestamp(15), Pair(IntPacket(15), EmptyPacket()))));
|
||||
|
||||
MP_EXPECT_OK(graph_.CloseAllInputStreams());
|
||||
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest,
|
||||
MultiplePacketsNonEmptyMainEmptyLoop) {
|
||||
SendPackets(/*timestamp=*/1, /*input=*/1,
|
||||
/*force_main_empty=*/false,
|
||||
/*force_loop_empty=*/true);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
|
||||
|
||||
SendPackets(/*timestamp=*/2, /*input=*/2,
|
||||
/*force_main_empty=*/false,
|
||||
/*force_loop_empty=*/true);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket()))));
|
||||
|
||||
SendPackets(/*timestamp=*/3, /*input=*/3,
|
||||
/*force_main_empty=*/false,
|
||||
/*force_loop_empty=*/true);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket()))));
|
||||
|
||||
SendPackets(/*timestamp=*/5, /*input=*/5,
|
||||
/*force_main_empty=*/false,
|
||||
/*force_loop_empty=*/true);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
|
||||
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket()))));
|
||||
|
||||
SendPackets(/*timestamp=*/15, /*input=*/15,
|
||||
/*force_main_empty=*/false,
|
||||
/*force_loop_empty=*/true);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(
|
||||
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
|
||||
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket())),
|
||||
PairPacket(Timestamp(15), Pair(IntPacket(15), EmptyPacket()))));
|
||||
|
||||
MP_EXPECT_OK(graph_.CloseAllInputStreams());
|
||||
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest,
|
||||
MultiplePacketsAlteringMainNonEmptyLoop) {
|
||||
SendPackets(/*timestamp=*/1, /*input=*/1,
|
||||
/*force_main_empty=*/false,
|
||||
/*force_loop_empty=*/false);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
|
||||
|
||||
SendPackets(/*timestamp=*/2, /*input=*/2, /*force_main_empty=*/true,
|
||||
/*force_loop_empty=*/false);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket()))));
|
||||
|
||||
SendPackets(/*timestamp=*/3, /*input=*/3,
|
||||
/*force_main_empty=*/false,
|
||||
/*force_loop_empty=*/false);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||
PairPacket(Timestamp(3), Pair(IntPacket(3), IntPacket(1)))));
|
||||
|
||||
SendPackets(/*timestamp=*/5, /*input=*/5,
|
||||
/*force_main_empty=*/false,
|
||||
/*force_loop_empty=*/false);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||
PairPacket(Timestamp(3), Pair(IntPacket(3), IntPacket(1))),
|
||||
PairPacket(Timestamp(5), Pair(IntPacket(5), IntPacket(3)))));
|
||||
|
||||
SendPackets(/*timestamp=*/15, /*input=*/15,
|
||||
/*force_main_empty=*/true,
|
||||
/*force_loop_empty=*/false);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(
|
||||
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||
PairPacket(Timestamp(3), Pair(IntPacket(3), IntPacket(1))),
|
||||
PairPacket(Timestamp(5), Pair(IntPacket(5), IntPacket(3))),
|
||||
PairPacket(Timestamp(15), Pair(IntPacket(15), EmptyPacket()))));
|
||||
|
||||
MP_EXPECT_OK(graph_.CloseAllInputStreams());
|
||||
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest,
|
||||
MultiplePacketsNonEmptyMainAlteringLoop) {
|
||||
SendPackets(/*timestamp=*/1, /*input=*/1,
|
||||
/*force_main_empty=*/false,
|
||||
/*force_loop_empty=*/false);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
|
||||
|
||||
SendPackets(/*timestamp=*/2, /*input=*/2,
|
||||
/*force_main_empty=*/false,
|
||||
/*force_loop_empty=*/true);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1)))));
|
||||
|
||||
SendPackets(/*timestamp=*/3, /*input=*/3,
|
||||
/*force_main_empty=*/false,
|
||||
/*force_loop_empty=*/false);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))),
|
||||
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket()))));
|
||||
|
||||
SendPackets(/*timestamp=*/5, /*input=*/5,
|
||||
/*force_main_empty=*/false,
|
||||
/*force_loop_empty=*/true);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))),
|
||||
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
|
||||
PairPacket(Timestamp(5), Pair(IntPacket(5), IntPacket(3)))));
|
||||
|
||||
SendPackets(/*timestamp=*/15, /*input=*/15,
|
||||
/*force_main_empty=*/false,
|
||||
/*force_loop_empty=*/false);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(
|
||||
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))),
|
||||
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
|
||||
PairPacket(Timestamp(5), Pair(IntPacket(5), IntPacket(3))),
|
||||
PairPacket(Timestamp(15), Pair(IntPacket(15), EmptyPacket()))));
|
||||
|
||||
MP_EXPECT_OK(graph_.CloseAllInputStreams());
|
||||
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest,
|
||||
MultiplePacketsCheckIfLastCorrectAlteringMainAlteringLoop) {
|
||||
int num_packets = 1000;
|
||||
for (int i = 0; i < num_packets; ++i) {
|
||||
bool force_main_empty = i % 3 == 0 ? true : false;
|
||||
bool force_loop_empty = i % 2 == 0 ? true : false;
|
||||
SendPackets(/*timestamp=*/i + 1, /*input=*/i + 1, force_main_empty,
|
||||
force_loop_empty);
|
||||
}
|
||||
SendPackets(/*timestamp=*/num_packets + 1,
|
||||
/*input=*/num_packets + 1, /*force_main_empty=*/false,
|
||||
/*force_loop_empty=*/false);
|
||||
SendPackets(/*timestamp=*/num_packets + 2,
|
||||
/*input=*/num_packets + 2, /*force_main_empty=*/false,
|
||||
/*force_loop_empty=*/false);
|
||||
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
ASSERT_FALSE(output_packets_.empty());
|
||||
EXPECT_THAT(
|
||||
output_packets_.back(),
|
||||
PairPacket(Timestamp(num_packets + 2),
|
||||
Pair(IntPacket(num_packets + 2), IntPacket(num_packets + 1))));
|
||||
|
||||
MP_EXPECT_OK(graph_.CloseAllInputStreams());
|
||||
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||
}
|
||||
|
||||
// Similar to GateCalculator, but it doesn't propagate timestamp bound updates.
|
||||
class DroppingGateCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).SetAny();
|
||||
cc->Inputs().Tag("DISALLOW").Set<bool>();
|
||||
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
if (!cc->Inputs().Index(0).IsEmpty() &&
|
||||
!cc->Inputs().Tag("DISALLOW").Get<bool>()) {
|
||||
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(DroppingGateCalculator);
|
||||
|
||||
// Tests PreviousLoopbackCalculator in cases when there are no "LOOP" timestamp
|
||||
// bound updates and non-empty packets for a while and the aforementioned start
|
||||
// to arrive at some point. So, "PREV_LOOP" is delayed for a couple of inputs.
|
||||
class PreviousLoopbackCalculatorDelayBehaviorTest : public testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: 'input'
|
||||
# Drops "loop" when set to "true", delaying output of prev_loop, hence
|
||||
# delaying output of the graph.
|
||||
input_stream: 'delay_next_output'
|
||||
node {
|
||||
calculator: 'PreviousLoopbackCalculator'
|
||||
input_stream: 'MAIN:input'
|
||||
input_stream: 'LOOP:loop'
|
||||
input_stream_info: { tag_index: 'LOOP' back_edge: true }
|
||||
output_stream: 'PREV_LOOP:prev_loop'
|
||||
}
|
||||
node {
|
||||
calculator: 'PassThroughCalculator'
|
||||
input_stream: 'input'
|
||||
input_stream: 'prev_loop'
|
||||
output_stream: 'passed_through_input'
|
||||
output_stream: 'passed_through_prev_loop'
|
||||
}
|
||||
node {
|
||||
calculator: 'DroppingGateCalculator'
|
||||
input_stream: 'input'
|
||||
input_stream: 'DISALLOW:delay_next_output'
|
||||
output_stream: 'loop'
|
||||
}
|
||||
node {
|
||||
calculator: 'MakePairCalculator'
|
||||
input_stream: 'passed_through_input'
|
||||
input_stream: 'passed_through_prev_loop'
|
||||
output_stream: 'passed_through_input_and_prev_loop'
|
||||
}
|
||||
)pb");
|
||||
tool::AddVectorSink("passed_through_input_and_prev_loop", &graph_config,
|
||||
&output_packets_);
|
||||
MP_ASSERT_OK(graph_.Initialize(graph_config, {}));
|
||||
MP_ASSERT_OK(graph_.StartRun({}));
|
||||
}
|
||||
|
||||
void SendPackets(int timestamp, int input, bool delay_next_output) {
|
||||
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||
"input", MakePacket<int>(input).At(Timestamp(timestamp))));
|
||||
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||
"delay_next_output",
|
||||
MakePacket<bool>(delay_next_output).At(Timestamp(timestamp))));
|
||||
}
|
||||
|
||||
CalculatorGraph graph_;
|
||||
std::vector<Packet> output_packets_;
|
||||
};
|
||||
|
||||
TEST_F(PreviousLoopbackCalculatorDelayBehaviorTest, MultipleDelayedOutputs) {
|
||||
SendPackets(/*timestamp=*/1, /*input=*/1, /*delay_next_output=*/true);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
|
||||
|
||||
SendPackets(/*timestamp=*/2, /*input=*/2, /*delay_next_output=*/true);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
|
||||
|
||||
SendPackets(/*timestamp=*/3, /*input=*/3, /*delay_next_output=*/true);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
|
||||
|
||||
SendPackets(/*timestamp=*/5, /*input=*/5, /*delay_next_output=*/false);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
|
||||
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket()))));
|
||||
|
||||
SendPackets(/*timestamp=*/15, /*input=*/15, /*delay_next_output=*/false);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(
|
||||
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
|
||||
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket())),
|
||||
PairPacket(Timestamp(15), Pair(IntPacket(15), IntPacket(5)))));
|
||||
|
||||
MP_EXPECT_OK(graph_.CloseAllInputStreams());
|
||||
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||
}
|
||||
|
||||
TEST_F(PreviousLoopbackCalculatorDelayBehaviorTest,
|
||||
NonDelayedOutputFollowedByMultipleDelayedOutputs) {
|
||||
SendPackets(/*timestamp=*/1, /*input=*/1, /*delay_next_output=*/false);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
|
||||
|
||||
SendPackets(/*timestamp=*/2, /*input=*/2, /*delay_next_output=*/true);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1)))));
|
||||
|
||||
SendPackets(/*timestamp=*/3, /*input=*/3, /*delay_next_output=*/true);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1)))));
|
||||
|
||||
SendPackets(/*timestamp=*/5, /*input=*/5, /*delay_next_output=*/false);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))),
|
||||
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
|
||||
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket()))));
|
||||
|
||||
SendPackets(/*timestamp=*/15, /*input=*/15, /*delay_next_output=*/false);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
EXPECT_THAT(
|
||||
output_packets_,
|
||||
ElementsAre(
|
||||
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))),
|
||||
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
|
||||
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket())),
|
||||
PairPacket(Timestamp(15), Pair(IntPacket(15), IntPacket(5)))));
|
||||
|
||||
MP_EXPECT_OK(graph_.CloseAllInputStreams());
|
||||
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,102 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cfloat>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/calculators/core/quantize_float_vector_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_context.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
// Quantizes a vector of floats to a std::string so that each float becomes a
|
||||
// byte in the [0, 255] range. Any value above max_quantized_value or below
|
||||
// min_quantized_value will be saturated to '/xFF' or '/0'.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "QuantizeFloatVectorCalculator"
|
||||
// input_stream: "FLOAT_VECTOR:float_vector"
|
||||
// output_stream: "ENCODED:encoded"
|
||||
// options {
|
||||
// [mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
|
||||
// max_quantized_value: 64
|
||||
// min_quantized_value: -64
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
namespace mediapipe {
|
||||
|
||||
class QuantizeFloatVectorCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Tag("FLOAT_VECTOR").Set<std::vector<float>>();
|
||||
cc->Outputs().Tag("ENCODED").Set<std::string>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) final {
|
||||
const auto options =
|
||||
cc->Options<::mediapipe::QuantizeFloatVectorCalculatorOptions>();
|
||||
if (!options.has_max_quantized_value() ||
|
||||
!options.has_min_quantized_value()) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Both max_quantized_value and min_quantized_value must be provided "
|
||||
"in QuantizeFloatVectorCalculatorOptions.");
|
||||
}
|
||||
max_quantized_value_ = options.max_quantized_value();
|
||||
min_quantized_value_ = options.min_quantized_value();
|
||||
if (max_quantized_value_ < min_quantized_value_ + FLT_EPSILON) {
|
||||
return absl::InvalidArgumentError(
|
||||
"max_quantized_value must be greater than min_quantized_value.");
|
||||
}
|
||||
range_ = max_quantized_value_ - min_quantized_value_;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
const std::vector<float>& float_vector =
|
||||
cc->Inputs().Tag("FLOAT_VECTOR").Value().Get<std::vector<float>>();
|
||||
int feature_size = float_vector.size();
|
||||
std::string encoded_features;
|
||||
encoded_features.reserve(feature_size);
|
||||
for (int i = 0; i < feature_size; i++) {
|
||||
float old_value = float_vector[i];
|
||||
if (old_value < min_quantized_value_) {
|
||||
old_value = min_quantized_value_;
|
||||
}
|
||||
if (old_value > max_quantized_value_) {
|
||||
old_value = max_quantized_value_;
|
||||
}
|
||||
unsigned char encoded = static_cast<unsigned char>(
|
||||
(old_value - min_quantized_value_) * (255.0 / range_));
|
||||
encoded_features += encoded;
|
||||
}
|
||||
cc->Outputs().Tag("ENCODED").AddPacket(
|
||||
MakePacket<std::string>(encoded_features).At(cc->InputTimestamp()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
float max_quantized_value_;
|
||||
float min_quantized_value_;
|
||||
float range_;
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(QuantizeFloatVectorCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
option objc_class_prefix = "MediaPipe";
|
||||
|
||||
message QuantizeFloatVectorCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional QuantizeFloatVectorCalculatorOptions ext = 259848061;
|
||||
}
|
||||
|
||||
optional float max_quantized_value = 1;
|
||||
optional float min_quantized_value = 2;
|
||||
}
|
||||
|
|
@ -1,204 +0,0 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "QuantizeFloatVectorCalculator"
|
||||
input_stream: "FLOAT_VECTOR:float_vector"
|
||||
output_stream: "ENCODED:encoded"
|
||||
options {
|
||||
[mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
|
||||
min_quantized_value: 1
|
||||
}
|
||||
}
|
||||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
std::vector<float> empty_vector;
|
||||
runner.MutableInputs()
|
||||
->Tag("FLOAT_VECTOR")
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
|
||||
auto status = runner.Run();
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_THAT(
|
||||
status.message(),
|
||||
testing::HasSubstr(
|
||||
"Both max_quantized_value and min_quantized_value must be provided"));
|
||||
}
|
||||
|
||||
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig2) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "QuantizeFloatVectorCalculator"
|
||||
input_stream: "FLOAT_VECTOR:float_vector"
|
||||
output_stream: "ENCODED:encoded"
|
||||
options {
|
||||
[mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
|
||||
max_quantized_value: -1
|
||||
min_quantized_value: 1
|
||||
}
|
||||
}
|
||||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
std::vector<float> empty_vector;
|
||||
runner.MutableInputs()
|
||||
->Tag("FLOAT_VECTOR")
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
|
||||
auto status = runner.Run();
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_THAT(
|
||||
status.message(),
|
||||
testing::HasSubstr(
|
||||
"max_quantized_value must be greater than min_quantized_value"));
|
||||
}
|
||||
|
||||
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig3) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "QuantizeFloatVectorCalculator"
|
||||
input_stream: "FLOAT_VECTOR:float_vector"
|
||||
output_stream: "ENCODED:encoded"
|
||||
options {
|
||||
[mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
|
||||
max_quantized_value: 1
|
||||
min_quantized_value: 1
|
||||
}
|
||||
}
|
||||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
std::vector<float> empty_vector;
|
||||
runner.MutableInputs()
|
||||
->Tag("FLOAT_VECTOR")
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
|
||||
auto status = runner.Run();
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_THAT(
|
||||
status.message(),
|
||||
testing::HasSubstr(
|
||||
"max_quantized_value must be greater than min_quantized_value"));
|
||||
}
|
||||
|
||||
TEST(QuantizeFloatVectorCalculatorTest, TestEmptyVector) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "QuantizeFloatVectorCalculator"
|
||||
input_stream: "FLOAT_VECTOR:float_vector"
|
||||
output_stream: "ENCODED:encoded"
|
||||
options {
|
||||
[mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
|
||||
max_quantized_value: 1
|
||||
min_quantized_value: -1
|
||||
}
|
||||
}
|
||||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
std::vector<float> empty_vector;
|
||||
runner.MutableInputs()
|
||||
->Tag("FLOAT_VECTOR")
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Tag("ENCODED").packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_TRUE(outputs[0].Get<std::string>().empty());
|
||||
EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
|
||||
}
|
||||
|
||||
TEST(QuantizeFloatVectorCalculatorTest, TestNonEmptyVector) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "QuantizeFloatVectorCalculator"
|
||||
input_stream: "FLOAT_VECTOR:float_vector"
|
||||
output_stream: "ENCODED:encoded"
|
||||
options {
|
||||
[mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
|
||||
max_quantized_value: 64
|
||||
min_quantized_value: -64
|
||||
}
|
||||
}
|
||||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
std::vector<float> vector = {0.0f, -64.0f, 64.0f, -32.0f, 32.0f};
|
||||
runner.MutableInputs()
|
||||
->Tag("FLOAT_VECTOR")
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<float>>(vector).At(Timestamp(0)));
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Tag("ENCODED").packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
const std::string& result = outputs[0].Get<std::string>();
|
||||
ASSERT_FALSE(result.empty());
|
||||
EXPECT_EQ(5, result.size());
|
||||
// 127
|
||||
EXPECT_EQ('\x7F', result.c_str()[0]);
|
||||
// 0
|
||||
EXPECT_EQ('\0', result.c_str()[1]);
|
||||
// 255
|
||||
EXPECT_EQ('\xFF', result.c_str()[2]);
|
||||
// 63
|
||||
EXPECT_EQ('\x3F', result.c_str()[3]);
|
||||
// 191
|
||||
EXPECT_EQ('\xBF', result.c_str()[4]);
|
||||
EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
|
||||
}
|
||||
|
||||
TEST(QuantizeFloatVectorCalculatorTest, TestSaturation) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "QuantizeFloatVectorCalculator"
|
||||
input_stream: "FLOAT_VECTOR:float_vector"
|
||||
output_stream: "ENCODED:encoded"
|
||||
options {
|
||||
[mediapipe.QuantizeFloatVectorCalculatorOptions.ext]: {
|
||||
max_quantized_value: 64
|
||||
min_quantized_value: -64
|
||||
}
|
||||
}
|
||||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
std::vector<float> vector = {-65.0f, 65.0f};
|
||||
runner.MutableInputs()
|
||||
->Tag("FLOAT_VECTOR")
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<float>>(vector).At(Timestamp(0)));
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Tag("ENCODED").packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
const std::string& result = outputs[0].Get<std::string>();
|
||||
ASSERT_FALSE(result.empty());
|
||||
EXPECT_EQ(2, result.size());
|
||||
// 0
|
||||
EXPECT_EQ('\0', result.c_str()[0]);
|
||||
// 255
|
||||
EXPECT_EQ('\xFF', result.c_str()[1]);
|
||||
EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user