Merge branch 'google:master' into master

This commit is contained in:
kuaashish 2023-09-01 13:59:13 +05:30 committed by GitHub
commit e060824cd7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
766 changed files with 20819 additions and 6050 deletions

View File

@ -73,12 +73,9 @@ http_archive(
http_archive( http_archive(
name = "zlib", name = "zlib",
build_file = "@//third_party:zlib.BUILD", build_file = "@//third_party:zlib.BUILD",
sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", sha256 = "b3a24de97a8fdbc835b9833169501030b8977031bcb54b3b3ac13740f846ab30",
strip_prefix = "zlib-1.2.11", strip_prefix = "zlib-1.2.13",
urls = [ url = "http://zlib.net/fossils/zlib-1.2.13.tar.gz",
"http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz",
"http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15
],
patches = [ patches = [
"@//third_party:zlib.diff", "@//third_party:zlib.diff",
], ],
@ -157,22 +154,22 @@ http_archive(
# 2020-08-21 # 2020-08-21
http_archive( http_archive(
name = "com_github_glog_glog", name = "com_github_glog_glog",
strip_prefix = "glog-0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6", strip_prefix = "glog-3a0d4d22c5ae0b9a2216988411cfa6bf860cc372",
sha256 = "58c9b3b6aaa4dd8b836c0fd8f65d0f941441fb95e27212c5eeb9979cfd3592ab", sha256 = "170d08f80210b82d95563f4723a15095eff1aad1863000e8eeb569c96a98fefb",
urls = [ urls = [
"https://github.com/google/glog/archive/0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6.zip", "https://github.com/google/glog/archive/3a0d4d22c5ae0b9a2216988411cfa6bf860cc372.zip",
], ],
) )
http_archive( http_archive(
name = "com_github_glog_glog_no_gflags", name = "com_github_glog_glog_no_gflags",
strip_prefix = "glog-0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6", strip_prefix = "glog-3a0d4d22c5ae0b9a2216988411cfa6bf860cc372",
sha256 = "58c9b3b6aaa4dd8b836c0fd8f65d0f941441fb95e27212c5eeb9979cfd3592ab", sha256 = "170d08f80210b82d95563f4723a15095eff1aad1863000e8eeb569c96a98fefb",
build_file = "@//third_party:glog_no_gflags.BUILD", build_file = "@//third_party:glog_no_gflags.BUILD",
urls = [ urls = [
"https://github.com/google/glog/archive/0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6.zip", "https://github.com/google/glog/archive/3a0d4d22c5ae0b9a2216988411cfa6bf860cc372.zip",
], ],
patches = [ patches = [
"@//third_party:com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff", "@//third_party:com_github_glog_glog.diff",
], ],
patch_args = [ patch_args = [
"-p1", "-p1",
@ -485,9 +482,10 @@ http_archive(
) )
# TensorFlow repo should always go after the other external dependencies. # TensorFlow repo should always go after the other external dependencies.
# TF on 2023-05-26. # TF on 2023-07-26.
_TENSORFLOW_GIT_COMMIT = "67d5c561981edc45daf3f9d73ddd1a77963733ca" _TENSORFLOW_GIT_COMMIT = "e92261fd4cec0b726692081c4d2966b75abf31dd"
_TENSORFLOW_SHA256 = "0c8326285e9cb695313e194b97d388eea70bf8bf5b13e8f0962ca8eed5179ece" # curl -L https://github.com/tensorflow/tensorflow/archive/<TENSORFLOW_GIT_COMMIT>.tar.gz | shasum -a 256
_TENSORFLOW_SHA256 = "478a229bd4ec70a5b568ac23b5ea013d9fca46a47d6c43e30365a0412b9febf4"
http_archive( http_archive(
name = "org_tensorflow", name = "org_tensorflow",
urls = [ urls = [
@ -495,6 +493,7 @@ http_archive(
], ],
patches = [ patches = [
"@//third_party:org_tensorflow_compatibility_fixes.diff", "@//third_party:org_tensorflow_compatibility_fixes.diff",
"@//third_party:org_tensorflow_system_python.diff",
# Diff is generated with a script, don't update it manually. # Diff is generated with a script, don't update it manually.
"@//third_party:org_tensorflow_custom_ops.diff", "@//third_party:org_tensorflow_custom_ops.diff",
], ],

View File

@ -50,7 +50,7 @@ as the primary developer documentation site for MediaPipe as of April 3, 2023.*
3. The [`hello world`] example uses a simple MediaPipe graph in the 3. The [`hello world`] example uses a simple MediaPipe graph in the
`PrintHelloWorld()` function, defined in a [`CalculatorGraphConfig`] proto. `PrintHelloWorld()` function, defined in a [`CalculatorGraphConfig`] proto.
```C++ ```c++
absl::Status PrintHelloWorld() { absl::Status PrintHelloWorld() {
// Configures a simple graph, which concatenates 2 PassThroughCalculators. // Configures a simple graph, which concatenates 2 PassThroughCalculators.
CalculatorGraphConfig config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"( CalculatorGraphConfig config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
@ -126,7 +126,7 @@ as the primary developer documentation site for MediaPipe as of April 3, 2023.*
```c++ ```c++
mediapipe::Packet packet; mediapipe::Packet packet;
while (poller.Next(&packet)) { while (poller.Next(&packet)) {
LOG(INFO) << packet.Get<string>(); ABSL_LOG(INFO) << packet.Get<string>();
} }
``` ```

View File

@ -138,7 +138,7 @@ Create a `BUILD` file in the `$APPLICATION_PATH` and add the following build
rules: rules:
``` ```
MIN_IOS_VERSION = "11.0" MIN_IOS_VERSION = "12.0"
load( load(
"@build_bazel_rules_apple//apple:ios.bzl", "@build_bazel_rules_apple//apple:ios.bzl",

View File

@ -14,81 +14,155 @@
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
# Note: yes, these need to use "//external:android/crosstool", not load("@mediapipe//mediapipe:platforms.bzl", "config_setting_and_platform")
# @androidndk//:default_crosstool.
# Generic Android
config_setting( config_setting(
name = "android", name = "android",
values = {"crosstool_top": "//external:android/crosstool"}, constraint_values = [
"@platforms//os:android",
],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
config_setting( # Android x86 32-bit.
config_setting_and_platform(
name = "android_x86", name = "android_x86",
values = { constraint_values = [
"crosstool_top": "//external:android/crosstool", "@platforms//os:android",
"cpu": "x86", "@platforms//cpu:x86_32",
}, ],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
config_setting( # Android x86 64-bit.
config_setting_and_platform(
name = "android_x86_64", name = "android_x86_64",
values = { constraint_values = [
"crosstool_top": "//external:android/crosstool", "@platforms//os:android",
"cpu": "x86_64", "@platforms//cpu:x86_64",
}, ],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
config_setting( # Android ARMv7.
name = "android_armeabi", config_setting_and_platform(
values = {
"crosstool_top": "//external:android/crosstool",
"cpu": "armeabi",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "android_arm", name = "android_arm",
values = { constraint_values = [
"crosstool_top": "//external:android/crosstool", "@platforms//os:android",
"cpu": "armeabi-v7a", "@platforms//cpu:armv7",
}, ],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
config_setting( # Android ARM64.
config_setting_and_platform(
name = "android_arm64", name = "android_arm64",
values = { constraint_values = [
"crosstool_top": "//external:android/crosstool", "@platforms//os:android",
"cpu": "arm64-v8a", "@platforms//cpu:arm64",
}, ],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
# Note: this cannot just match "apple_platform_type": "macos" because that option # Generic MacOS.
# defaults to "macos" even when building on Linux! config_setting(
alias(
name = "macos", name = "macos",
actual = select({ constraint_values = [
":macos_i386": ":macos_i386", "@platforms//os:macos",
":macos_x86_64": ":macos_x86_64", ],
":macos_arm64": ":macos_arm64",
"//conditions:default": ":macos_i386", # Arbitrarily chosen from above.
}),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
# Note: this also matches on crosstool_top so that it does not produce ambiguous # MacOS x86 64-bit.
# selectors when used together with "android". config_setting_and_platform(
name = "macos_x86_64",
constraint_values = [
"@platforms//os:macos",
"@platforms//cpu:x86_64",
],
visibility = ["//visibility:public"],
)
# MacOS ARM64.
config_setting_and_platform(
name = "macos_arm64",
constraint_values = [
"@platforms//os:macos",
"@platforms//cpu:arm64",
],
visibility = ["//visibility:public"],
)
# Generic iOS.
config_setting( config_setting(
name = "ios", name = "ios",
values = { constraint_values = [
"crosstool_top": "@bazel_tools//tools/cpp:toolchain", "@platforms//os:ios",
"apple_platform_type": "ios", ],
}, visibility = ["//visibility:public"],
)
# iOS device ARM32.
config_setting_and_platform(
name = "ios_armv7",
constraint_values = [
"@platforms//os:ios",
"@platforms//cpu:arm",
],
visibility = ["//visibility:public"],
)
# iOS device ARM64.
config_setting_and_platform(
name = "ios_arm64",
constraint_values = [
"@platforms//os:ios",
"@platforms//cpu:arm64",
],
visibility = ["//visibility:public"],
)
# iOS device ARM64E.
config_setting_and_platform(
name = "ios_arm64e",
constraint_values = [
"@platforms//os:ios",
"@platforms//cpu:arm64e",
],
visibility = ["//visibility:public"],
)
# iOS simulator x86 32-bit.
config_setting_and_platform(
name = "ios_i386",
constraint_values = [
"@platforms//os:ios",
"@platforms//cpu:x86_32",
"@build_bazel_apple_support//constraints:simulator",
],
visibility = ["//visibility:public"],
)
# iOS simulator x86 64-bit.
config_setting_and_platform(
name = "ios_x86_64",
constraint_values = [
"@platforms//os:ios",
"@platforms//cpu:x86_64",
"@build_bazel_apple_support//constraints:simulator",
],
visibility = ["//visibility:public"],
)
# iOS simulator ARM64.
config_setting_and_platform(
name = "ios_sim_arm64",
constraint_values = [
"@platforms//os:ios",
"@platforms//cpu:arm64",
"@build_bazel_apple_support//constraints:simulator",
],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
@ -102,52 +176,24 @@ alias(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
config_setting( # Windows 64-bit.
name = "macos_i386", config_setting_and_platform(
values = {
"apple_platform_type": "macos",
"cpu": "darwin",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "macos_x86_64",
values = {
"apple_platform_type": "macos",
"cpu": "darwin_x86_64",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "macos_arm64",
values = {
"apple_platform_type": "macos",
"cpu": "darwin_arm64",
},
visibility = ["//visibility:public"],
)
[
config_setting(
name = arch,
values = {"cpu": arch},
visibility = ["//visibility:public"],
)
for arch in [
"ios_i386",
"ios_x86_64",
"ios_armv7",
"ios_arm64",
"ios_arm64e",
"ios_sim_arm64",
]
]
config_setting(
name = "windows", name = "windows",
values = {"cpu": "x64_windows"}, constraint_values = [
"@platforms//os:windows",
"@platforms//cpu:x86_64",
],
visibility = ["//visibility:public"],
)
# Linux 64-bit.
config_setting_and_platform(
name = "linux",
constraint_values = [
"@platforms//os:linux",
"@platforms//cpu:x86_64",
],
visibility = ["//visibility:public"],
) )
exports_files( exports_files(

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Placeholder: load py_proto_library
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
licenses(["notice"]) licenses(["notice"])
@ -145,6 +146,7 @@ cc_library(
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/util:time_series_util", "//mediapipe/util:time_series_util",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_audio_tools//audio/dsp/mfcc", "@com_google_audio_tools//audio/dsp/mfcc",
"@eigen_archive//:eigen3", "@eigen_archive//:eigen3",
@ -163,8 +165,9 @@ cc_library(
"//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/util:time_series_util", "//mediapipe/util:time_series_util",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_audio_tools//audio/dsp:resampler", "@com_google_audio_tools//audio/dsp:resampler",
"@com_google_audio_tools//audio/dsp:resampler_q", "@com_google_audio_tools//audio/dsp:resampler_q",
@ -185,6 +188,7 @@ cc_library(
"//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/util:time_series_util", "//mediapipe/util:time_series_util",
"@com_google_absl//absl/log:absl_check",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -219,13 +223,12 @@ cc_library(
deps = [ deps = [
":time_series_framer_calculator_cc_proto", ":time_series_framer_calculator_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/util:time_series_util", "//mediapipe/util:time_series_util",
"@com_google_absl//absl/log:absl_check",
"@com_google_audio_tools//audio/dsp:window_functions", "@com_google_audio_tools//audio/dsp:window_functions",
"@eigen_archive//:eigen3", "@eigen_archive//:eigen3",
], ],
@ -296,6 +299,7 @@ cc_test(
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/util:time_series_test_util", "//mediapipe/util:time_series_test_util",
"@com_google_absl//absl/log:absl_log",
"@com_google_audio_tools//audio/dsp:number_util", "@com_google_audio_tools//audio/dsp:number_util",
"@eigen_archive//:eigen3", "@eigen_archive//:eigen3",
], ],
@ -319,6 +323,21 @@ cc_test(
], ],
) )
cc_binary(
name = "time_series_framer_calculator_benchmark",
srcs = ["time_series_framer_calculator_benchmark.cc"],
deps = [
":time_series_framer_calculator",
":time_series_framer_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"@com_google_absl//absl/log:absl_check",
"@com_google_benchmark//:benchmark",
],
)
cc_test( cc_test(
name = "time_series_framer_calculator_test", name = "time_series_framer_calculator_test",
srcs = ["time_series_framer_calculator_test.cc"], srcs = ["time_series_framer_calculator_test.cc"],
@ -333,6 +352,7 @@ cc_test(
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/util:time_series_test_util", "//mediapipe/util:time_series_test_util",
"@com_google_absl//absl/log:absl_log",
"@com_google_audio_tools//audio/dsp:window_functions", "@com_google_audio_tools//audio/dsp:window_functions",
"@eigen_archive//:eigen3", "@eigen_archive//:eigen3",
], ],

View File

@ -23,6 +23,7 @@
#include <vector> #include <vector>
#include "Eigen/Core" #include "Eigen/Core"
#include "absl/log/absl_check.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
@ -138,7 +139,7 @@ absl::Status FramewiseTransformCalculatorBase::Process(CalculatorContext* cc) {
TransformFrame(input_frame, &output_frame); TransformFrame(input_frame, &output_frame);
// Copy output from vector<float> to Eigen::Vector. // Copy output from vector<float> to Eigen::Vector.
CHECK_EQ(output_frame.size(), num_output_channels_); ABSL_CHECK_EQ(output_frame.size(), num_output_channels_);
Eigen::Map<const Eigen::MatrixXd> output_frame_map(&output_frame[0], Eigen::Map<const Eigen::MatrixXd> output_frame_map(&output_frame[0],
output_frame.size(), 1); output_frame.size(), 1);
output->col(frame) = output_frame_map.cast<float>(); output->col(frame) = output_frame_map.cast<float>();

View File

@ -16,6 +16,8 @@
#include "mediapipe/calculators/audio/rational_factor_resample_calculator.h" #include "mediapipe/calculators/audio/rational_factor_resample_calculator.h"
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "audio/dsp/resampler_q.h" #include "audio/dsp/resampler_q.h"
using audio_dsp::Resampler; using audio_dsp::Resampler;
@ -45,9 +47,9 @@ void CopyVectorToChannel(const std::vector<float>& vec, Matrix* matrix,
if (matrix->cols() == 0) { if (matrix->cols() == 0) {
matrix->resize(matrix->rows(), vec.size()); matrix->resize(matrix->rows(), vec.size());
} else { } else {
CHECK_EQ(vec.size(), matrix->cols()); ABSL_CHECK_EQ(vec.size(), matrix->cols());
} }
CHECK_LT(channel, matrix->rows()); ABSL_CHECK_LT(channel, matrix->rows());
matrix->row(channel) = matrix->row(channel) =
Eigen::Map<const Eigen::ArrayXf>(vec.data(), vec.size()); Eigen::Map<const Eigen::ArrayXf>(vec.data(), vec.size());
} }
@ -77,7 +79,7 @@ absl::Status RationalFactorResampleCalculator::Open(CalculatorContext* cc) {
r = ResamplerFromOptions(source_sample_rate_, target_sample_rate_, r = ResamplerFromOptions(source_sample_rate_, target_sample_rate_,
resample_options); resample_options);
if (!r) { if (!r) {
LOG(ERROR) << "Failed to initialize resampler."; ABSL_LOG(ERROR) << "Failed to initialize resampler.";
return absl::UnknownError("Failed to initialize resampler."); return absl::UnknownError("Failed to initialize resampler.");
} }
} }

View File

@ -27,7 +27,6 @@
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h" #include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/util/time_series_util.h" #include "mediapipe/util/time_series_util.h"
namespace mediapipe { namespace mediapipe {

View File

@ -210,6 +210,23 @@ REGISTER_CALCULATOR(SpectrogramCalculator);
// Factor to convert ln(SQUARED_MAGNITUDE) to deciBels = 10.0/ln(10.0). // Factor to convert ln(SQUARED_MAGNITUDE) to deciBels = 10.0/ln(10.0).
const float SpectrogramCalculator::kLnSquaredMagnitudeToDb = 4.342944819032518; const float SpectrogramCalculator::kLnSquaredMagnitudeToDb = 4.342944819032518;
namespace {
std::unique_ptr<audio_dsp::WindowFunction> MakeWindowFun(
const SpectrogramCalculatorOptions::WindowType window_type) {
switch (window_type) {
// The cosine window and square root of Hann are equivalent.
case SpectrogramCalculatorOptions::COSINE:
case SpectrogramCalculatorOptions::SQRT_HANN:
return std::make_unique<audio_dsp::CosineWindow>();
case SpectrogramCalculatorOptions::HANN:
return std::make_unique<audio_dsp::HannWindow>();
case SpectrogramCalculatorOptions::HAMMING:
return std::make_unique<audio_dsp::HammingWindow>();
}
return nullptr;
}
} // namespace
absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) { absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) {
SpectrogramCalculatorOptions spectrogram_options = SpectrogramCalculatorOptions spectrogram_options =
cc->Options<SpectrogramCalculatorOptions>(); cc->Options<SpectrogramCalculatorOptions>();
@ -266,28 +283,14 @@ absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) {
output_scale_ = spectrogram_options.output_scale(); output_scale_ = spectrogram_options.output_scale();
auto window_fun = MakeWindowFun(spectrogram_options.window_type());
if (window_fun == nullptr) {
return absl::Status(absl::StatusCode::kInvalidArgument,
absl::StrCat("Invalid window type ",
spectrogram_options.window_type()));
}
std::vector<double> window; std::vector<double> window;
switch (spectrogram_options.window_type()) { window_fun->GetPeriodicSamples(frame_duration_samples_, &window);
case SpectrogramCalculatorOptions::COSINE:
audio_dsp::CosineWindow().GetPeriodicSamples(frame_duration_samples_,
&window);
break;
case SpectrogramCalculatorOptions::HANN:
audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_,
&window);
break;
case SpectrogramCalculatorOptions::HAMMING:
audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples_,
&window);
break;
case SpectrogramCalculatorOptions::SQRT_HANN: {
audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_,
&window);
absl::c_transform(window, window.begin(),
[](double x) { return std::sqrt(x); });
break;
}
}
// Propagate settings down to the actual Spectrogram object. // Propagate settings down to the actual Spectrogram object.
spectrogram_generators_.clear(); spectrogram_generators_.clear();

View File

@ -68,7 +68,7 @@ message SpectrogramCalculatorOptions {
HANN = 0; HANN = 0;
HAMMING = 1; HAMMING = 1;
COSINE = 2; COSINE = 2;
SQRT_HANN = 4; SQRT_HANN = 4; // Alias of COSINE.
} }
optional WindowType window_type = 6 [default = HANN]; optional WindowType window_type = 6 [default = HANN];

View File

@ -22,6 +22,7 @@
#include <vector> #include <vector>
#include "Eigen/Core" #include "Eigen/Core"
#include "absl/log/absl_log.h"
#include "audio/dsp/number_util.h" #include "audio/dsp/number_util.h"
#include "mediapipe/calculators/audio/spectrogram_calculator.pb.h" #include "mediapipe/calculators/audio/spectrogram_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
@ -882,9 +883,9 @@ void BM_ProcessDC(benchmark::State& state) {
const CalculatorRunner::StreamContents& output = runner.Outputs().Index(0); const CalculatorRunner::StreamContents& output = runner.Outputs().Index(0);
const Matrix& output_matrix = output.packets[0].Get<Matrix>(); const Matrix& output_matrix = output.packets[0].Get<Matrix>();
LOG(INFO) << "Output matrix=" << output_matrix.rows() << "x" ABSL_LOG(INFO) << "Output matrix=" << output_matrix.rows() << "x"
<< output_matrix.cols(); << output_matrix.cols();
LOG(INFO) << "First values=" << output_matrix(0, 0) << ", " ABSL_LOG(INFO) << "First values=" << output_matrix(0, 0) << ", "
<< output_matrix(1, 0) << ", " << output_matrix(2, 0) << ", " << output_matrix(1, 0) << ", " << output_matrix(2, 0) << ", "
<< output_matrix(3, 0); << output_matrix(3, 0);
} }

View File

@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include "absl/log/absl_check.h"
#include "mediapipe/calculators/audio/stabilized_log_calculator.pb.h" #include "mediapipe/calculators/audio/stabilized_log_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
@ -59,7 +60,7 @@ class StabilizedLogCalculator : public CalculatorBase {
output_scale_ = stabilized_log_calculator_options.output_scale(); output_scale_ = stabilized_log_calculator_options.output_scale();
check_nonnegativity_ = check_nonnegativity_ =
stabilized_log_calculator_options.check_nonnegativity(); stabilized_log_calculator_options.check_nonnegativity();
CHECK_GE(stabilizer_, 0.0) ABSL_CHECK_GE(stabilizer_, 0.0)
<< "stabilizer must be >= 0.0, received a value of " << stabilizer_; << "stabilizer must be >= 0.0, received a value of " << stabilizer_;
// If the input packets have a header, propagate the header to the output. // If the input packets have a header, propagate the header to the output.

View File

@ -15,19 +15,17 @@
// Defines TimeSeriesFramerCalculator. // Defines TimeSeriesFramerCalculator.
#include <math.h> #include <math.h>
#include <deque> #include <vector>
#include <memory>
#include <string>
#include "Eigen/Core" #include "Eigen/Core"
#include "absl/log/absl_check.h"
#include "audio/dsp/window_functions.h" #include "audio/dsp/window_functions.h"
#include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h" #include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h" #include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/util/time_series_util.h" #include "mediapipe/util/time_series_util.h"
namespace mediapipe { namespace mediapipe {
@ -88,11 +86,6 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
absl::Status Close(CalculatorContext* cc) override; absl::Status Close(CalculatorContext* cc) override;
private: private:
// Adds input data to the internal buffer.
void EnqueueInput(CalculatorContext* cc);
// Constructs and emits framed output packets.
void FrameOutput(CalculatorContext* cc);
Timestamp CurrentOutputTimestamp() { Timestamp CurrentOutputTimestamp() {
if (use_local_timestamp_) { if (use_local_timestamp_) {
return current_timestamp_; return current_timestamp_;
@ -106,21 +99,13 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
Timestamp::kTimestampUnitsPerSecond); Timestamp::kTimestampUnitsPerSecond);
} }
// Returns the timestamp of a sample on a base, which is usually the time
// stamp of a packet.
Timestamp CurrentSampleTimestamp(const Timestamp& timestamp_base,
int64_t number_of_samples) {
return timestamp_base + round(number_of_samples / sample_rate_ *
Timestamp::kTimestampUnitsPerSecond);
}
// The number of input samples to advance after the current output frame is // The number of input samples to advance after the current output frame is
// emitted. // emitted.
int next_frame_step_samples() const { int next_frame_step_samples() const {
// All numbers are in input samples. // All numbers are in input samples.
const int64_t current_output_frame_start = static_cast<int64_t>( const int64_t current_output_frame_start = static_cast<int64_t>(
round(cumulative_output_frames_ * average_frame_step_samples_)); round(cumulative_output_frames_ * average_frame_step_samples_));
CHECK_EQ(current_output_frame_start, cumulative_completed_samples_); ABSL_CHECK_EQ(current_output_frame_start, cumulative_completed_samples_);
const int64_t next_output_frame_start = static_cast<int64_t>( const int64_t next_output_frame_start = static_cast<int64_t>(
round((cumulative_output_frames_ + 1) * average_frame_step_samples_)); round((cumulative_output_frames_ + 1) * average_frame_step_samples_));
return next_output_frame_start - current_output_frame_start; return next_output_frame_start - current_output_frame_start;
@ -142,61 +127,174 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
Timestamp initial_input_timestamp_; Timestamp initial_input_timestamp_;
// The current timestamp is updated along with the incoming packets. // The current timestamp is updated along with the incoming packets.
Timestamp current_timestamp_; Timestamp current_timestamp_;
int num_channels_;
// Each entry in this deque consists of a single sample, i.e. a // Samples are buffered in a vector of sample blocks.
// single column vector, and its timestamp. class SampleBlockBuffer {
std::deque<std::pair<Matrix, Timestamp>> sample_buffer_; public:
// Initializes the buffer.
void Init(double sample_rate, int num_channels) {
ts_units_per_sample_ = Timestamp::kTimestampUnitsPerSecond / sample_rate;
num_channels_ = num_channels;
num_samples_ = 0;
first_block_offset_ = 0;
}
// Number of channels, equal to the number of rows in each Matrix.
int num_channels() const { return num_channels_; }
// Total number of available samples over all blocks.
int num_samples() const { return num_samples_; }
// Pushes a new block of samples on the back of the buffer with `timestamp`
// being the input timestamp of the packet containing the Matrix.
void Push(const Matrix& samples, Timestamp timestamp);
// Copies `count` samples from the front of the buffer. If there are fewer
// samples than this, the result is zero padded to have `count` samples.
// The timestamp of the last copied sample is written to *last_timestamp.
// This output is used below to update `current_timestamp_`, which is only
// used when `use_local_timestamp` is true.
Matrix CopySamples(int count, Timestamp* last_timestamp) const;
// Drops `count` samples from the front of the buffer. If `count` exceeds
// `num_samples()`, the buffer is emptied. Returns how many samples were
// dropped.
int DropSamples(int count);
private:
struct Block {
// Matrix of num_channels rows by num_samples columns, a block of possibly
// multiple samples.
Matrix samples;
// Timestamp of the first sample in the Block. This comes from the input
// packet's timestamp that contains this Matrix.
Timestamp timestamp;
Block() : timestamp(Timestamp::Unstarted()) {}
Block(const Matrix& samples, Timestamp timestamp)
: samples(samples), timestamp(timestamp) {}
int num_samples() const { return samples.cols(); }
};
std::vector<Block> blocks_;
// Number of timestamp units per sample. Used to compute timestamps as
// nth sample timestamp = base_timestamp + round(ts_units_per_sample_ * n).
double ts_units_per_sample_;
// Number of rows in each Matrix.
int num_channels_;
// The total number of samples over all blocks, equal to
// (sum_i blocks_[i].num_samples()) - first_block_offset_.
int num_samples_;
// The number of samples in the first block that have been discarded. This
// way we can cheaply represent "partially discarding" a block.
int first_block_offset_;
} sample_buffer_;
bool use_window_; bool use_window_;
Matrix window_; Eigen::RowVectorXf window_;
bool use_local_timestamp_; bool use_local_timestamp_;
}; };
REGISTER_CALCULATOR(TimeSeriesFramerCalculator); REGISTER_CALCULATOR(TimeSeriesFramerCalculator);
void TimeSeriesFramerCalculator::EnqueueInput(CalculatorContext* cc) { void TimeSeriesFramerCalculator::SampleBlockBuffer::Push(const Matrix& samples,
const Matrix& input_frame = cc->Inputs().Index(0).Get<Matrix>(); Timestamp timestamp) {
num_samples_ += samples.cols();
for (int i = 0; i < input_frame.cols(); ++i) { blocks_.emplace_back(samples, timestamp);
sample_buffer_.emplace_back(std::make_pair(
input_frame.col(i), CurrentSampleTimestamp(cc->InputTimestamp(), i)));
}
} }
void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) { Matrix TimeSeriesFramerCalculator::SampleBlockBuffer::CopySamples(
while (sample_buffer_.size() >= int count, Timestamp* last_timestamp) const {
Matrix copied(num_channels_, count);
if (!blocks_.empty()) {
int num_copied = 0;
// First block has an offset for samples that have been discarded.
int offset = first_block_offset_;
int n;
Timestamp last_block_ts;
int last_sample_index;
for (auto it = blocks_.begin(); it != blocks_.end() && count > 0; ++it) {
n = std::min(it->num_samples() - offset, count);
// Copy `n` samples from the next block.
copied.middleCols(num_copied, n) = it->samples.middleCols(offset, n);
count -= n;
num_copied += n;
last_block_ts = it->timestamp;
last_sample_index = offset + n - 1;
offset = 0; // No samples have been discarded in subsequent blocks.
}
// Compute the timestamp of the last copied sample.
*last_timestamp =
last_block_ts + std::round(ts_units_per_sample_ * last_sample_index);
}
if (count > 0) {
copied.rightCols(count).setZero(); // Zero pad if needed.
}
return copied;
}
int TimeSeriesFramerCalculator::SampleBlockBuffer::DropSamples(int count) {
if (blocks_.empty()) {
return 0;
}
auto block_it = blocks_.begin();
if (first_block_offset_ + count < block_it->num_samples()) {
// `count` is less than the remaining samples in the first block.
first_block_offset_ += count;
num_samples_ -= count;
return count;
}
int num_samples_dropped = block_it->num_samples() - first_block_offset_;
count -= num_samples_dropped;
first_block_offset_ = 0;
for (++block_it; block_it != blocks_.end(); ++block_it) {
if (block_it->num_samples() > count) {
break;
}
num_samples_dropped += block_it->num_samples();
count -= block_it->num_samples();
}
blocks_.erase(blocks_.begin(), block_it); // Drop whole blocks.
if (!blocks_.empty()) {
first_block_offset_ = count; // Drop part of the next block.
num_samples_dropped += count;
}
num_samples_ -= num_samples_dropped;
return num_samples_dropped;
}
absl::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) {
if (initial_input_timestamp_ == Timestamp::Unstarted()) {
initial_input_timestamp_ = cc->InputTimestamp();
current_timestamp_ = initial_input_timestamp_;
}
// Add input data to the internal buffer.
sample_buffer_.Push(cc->Inputs().Index(0).Get<Matrix>(),
cc->InputTimestamp());
// Construct and emit framed output packets.
while (sample_buffer_.num_samples() >=
frame_duration_samples_ + samples_still_to_drop_) { frame_duration_samples_ + samples_still_to_drop_) {
while (samples_still_to_drop_ > 0) { sample_buffer_.DropSamples(samples_still_to_drop_);
sample_buffer_.pop_front(); Matrix output_frame = sample_buffer_.CopySamples(frame_duration_samples_,
--samples_still_to_drop_; &current_timestamp_);
}
const int frame_step_samples = next_frame_step_samples(); const int frame_step_samples = next_frame_step_samples();
std::unique_ptr<Matrix> output_frame( samples_still_to_drop_ = frame_step_samples;
new Matrix(num_channels_, frame_duration_samples_));
for (int i = 0; i < std::min(frame_step_samples, frame_duration_samples_);
++i) {
output_frame->col(i) = sample_buffer_.front().first;
current_timestamp_ = sample_buffer_.front().second;
sample_buffer_.pop_front();
}
const int frame_overlap_samples =
frame_duration_samples_ - frame_step_samples;
if (frame_overlap_samples > 0) {
for (int i = 0; i < frame_overlap_samples; ++i) {
output_frame->col(i + frame_step_samples) = sample_buffer_[i].first;
current_timestamp_ = sample_buffer_[i].second;
}
} else {
samples_still_to_drop_ = -frame_overlap_samples;
}
if (use_window_) { if (use_window_) {
*output_frame = (output_frame->array() * window_.array()).matrix(); // Apply the window to each row of output_frame.
output_frame.array().rowwise() *= window_.array();
} }
cc->Outputs().Index(0).Add(output_frame.release(), cc->Outputs().Index(0).AddPacket(MakePacket<Matrix>(std::move(output_frame))
CurrentOutputTimestamp()); .At(CurrentOutputTimestamp()));
++cumulative_output_frames_; ++cumulative_output_frames_;
cumulative_completed_samples_ += frame_step_samples; cumulative_completed_samples_ += frame_step_samples;
} }
@ -206,35 +304,18 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) {
// fact to enable packet queueing optimizations. // fact to enable packet queueing optimizations.
cc->Outputs().Index(0).SetNextTimestampBound(CumulativeOutputTimestamp()); cc->Outputs().Index(0).SetNextTimestampBound(CumulativeOutputTimestamp());
} }
}
absl::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) {
if (initial_input_timestamp_ == Timestamp::Unstarted()) {
initial_input_timestamp_ = cc->InputTimestamp();
current_timestamp_ = initial_input_timestamp_;
}
EnqueueInput(cc);
FrameOutput(cc);
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status TimeSeriesFramerCalculator::Close(CalculatorContext* cc) { absl::Status TimeSeriesFramerCalculator::Close(CalculatorContext* cc) {
while (samples_still_to_drop_ > 0 && !sample_buffer_.empty()) { sample_buffer_.DropSamples(samples_still_to_drop_);
sample_buffer_.pop_front();
--samples_still_to_drop_;
}
if (!sample_buffer_.empty() && pad_final_packet_) {
std::unique_ptr<Matrix> output_frame(new Matrix);
output_frame->setZero(num_channels_, frame_duration_samples_);
for (int i = 0; i < sample_buffer_.size(); ++i) {
output_frame->col(i) = sample_buffer_[i].first;
current_timestamp_ = sample_buffer_[i].second;
}
cc->Outputs().Index(0).Add(output_frame.release(), if (sample_buffer_.num_samples() > 0 && pad_final_packet_) {
CurrentOutputTimestamp()); Matrix output_frame = sample_buffer_.CopySamples(frame_duration_samples_,
&current_timestamp_);
cc->Outputs().Index(0).AddPacket(MakePacket<Matrix>(std::move(output_frame))
.At(CurrentOutputTimestamp()));
} }
return absl::OkStatus(); return absl::OkStatus();
@ -258,7 +339,7 @@ absl::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) {
cc->Inputs().Index(0).Header(), &input_header)); cc->Inputs().Index(0).Header(), &input_header));
sample_rate_ = input_header.sample_rate(); sample_rate_ = input_header.sample_rate();
num_channels_ = input_header.num_channels(); sample_buffer_.Init(sample_rate_, input_header.num_channels());
frame_duration_samples_ = time_series_util::SecondsToSamples( frame_duration_samples_ = time_series_util::SecondsToSamples(
framer_options.frame_duration_seconds(), sample_rate_); framer_options.frame_duration_seconds(), sample_rate_);
RET_CHECK_GT(frame_duration_samples_, 0) RET_CHECK_GT(frame_duration_samples_, 0)
@ -312,8 +393,7 @@ absl::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) {
} }
if (use_window_) { if (use_window_) {
window_ = Matrix::Ones(num_channels_, 1) * window_ = Eigen::Map<Eigen::RowVectorXd>(window_vector.data(),
Eigen::Map<Eigen::MatrixXd>(window_vector.data(), 1,
frame_duration_samples_) frame_duration_samples_)
.cast<float>(); .cast<float>();
} }

View File

@ -0,0 +1,93 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Benchmark for TimeSeriesFramerCalculator.
#include <memory>
#include <random>
#include <vector>
#include "absl/log/absl_check.h"
#include "benchmark/benchmark.h"
#include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/packet.h"
using ::mediapipe::Matrix;
void BM_TimeSeriesFramerCalculator(benchmark::State& state) {
constexpr float kSampleRate = 32000.0;
constexpr int kNumChannels = 2;
constexpr int kFrameDurationSeconds = 5.0;
std::mt19937 rng(0 /*seed*/);
// Input around a half second's worth of samples at a time.
std::uniform_int_distribution<int> input_size_dist(15000, 17000);
// Generate a pool of random blocks of samples up front.
std::vector<Matrix> sample_pool;
sample_pool.reserve(20);
for (int i = 0; i < 20; ++i) {
sample_pool.push_back(Matrix::Random(kNumChannels, input_size_dist(rng)));
}
std::uniform_int_distribution<int> pool_index_dist(0, sample_pool.size() - 1);
mediapipe::CalculatorGraphConfig config;
config.add_input_stream("input");
config.add_output_stream("output");
auto* node = config.add_node();
node->set_calculator("TimeSeriesFramerCalculator");
node->add_input_stream("input");
node->add_output_stream("output");
mediapipe::TimeSeriesFramerCalculatorOptions* options =
node->mutable_options()->MutableExtension(
mediapipe::TimeSeriesFramerCalculatorOptions::ext);
options->set_frame_duration_seconds(kFrameDurationSeconds);
for (auto _ : state) {
state.PauseTiming(); // Pause benchmark timing.
// Prepare input packets of random blocks of samples.
std::vector<mediapipe::Packet> input_packets;
input_packets.reserve(32);
float t = 0;
for (int i = 0; i < 32; ++i) {
auto samples =
std::make_unique<Matrix>(sample_pool[pool_index_dist(rng)]);
const int num_samples = samples->cols();
input_packets.push_back(mediapipe::Adopt(samples.release())
.At(mediapipe::Timestamp::FromSeconds(t)));
t += num_samples / kSampleRate;
}
// Initialize graph.
mediapipe::CalculatorGraph graph;
ABSL_CHECK_OK(graph.Initialize(config));
// Prepare input header.
auto header = std::make_unique<mediapipe::TimeSeriesHeader>();
header->set_sample_rate(kSampleRate);
header->set_num_channels(kNumChannels);
state.ResumeTiming(); // Resume benchmark timing.
ABSL_CHECK_OK(graph.StartRun({}, {{"input", Adopt(header.release())}}));
for (auto& packet : input_packets) {
ABSL_CHECK_OK(graph.AddPacketToInputStream("input", packet));
}
ABSL_CHECK(!graph.HasError());
ABSL_CHECK_OK(graph.CloseAllInputStreams());
ABSL_CHECK_OK(graph.WaitUntilIdle());
}
}
BENCHMARK(BM_TimeSeriesFramerCalculator);
BENCHMARK_MAIN();

View File

@ -19,6 +19,7 @@
#include <vector> #include <vector>
#include "Eigen/Core" #include "Eigen/Core"
#include "absl/log/absl_log.h"
#include "audio/dsp/window_functions.h" #include "audio/dsp/window_functions.h"
#include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h" #include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
@ -186,11 +187,12 @@ class TimeSeriesFramerCalculatorTest
const int num_unique_output_samples = const int num_unique_output_samples =
round((output().packets.size() - 1) * frame_step_samples) + round((output().packets.size() - 1) * frame_step_samples) +
frame_duration_samples; frame_duration_samples;
LOG(INFO) << "packets.size()=" << output().packets.size() ABSL_LOG(INFO) << "packets.size()=" << output().packets.size()
<< " frame_duration_samples=" << frame_duration_samples << " frame_duration_samples=" << frame_duration_samples
<< " frame_step_samples=" << frame_step_samples << " frame_step_samples=" << frame_step_samples
<< " num_input_samples_=" << num_input_samples_ << " num_input_samples_=" << num_input_samples_
<< " num_unique_output_samples=" << num_unique_output_samples; << " num_unique_output_samples="
<< num_unique_output_samples;
const int num_padding_samples = const int num_padding_samples =
num_unique_output_samples - num_input_samples_; num_unique_output_samples - num_input_samples_;
if (options_.pad_final_packet()) { if (options_.pad_final_packet()) {

View File

@ -117,6 +117,7 @@ mediapipe_proto_library(
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/framework/formats:classification_proto", "//mediapipe/framework/formats:classification_proto",
"//mediapipe/framework/formats:landmark_proto", "//mediapipe/framework/formats:landmark_proto",
"//mediapipe/framework/formats:matrix_data_proto",
"//mediapipe/framework/formats:time_series_header_proto", "//mediapipe/framework/formats:time_series_header_proto",
], ],
) )
@ -289,6 +290,7 @@ cc_library(
"//mediapipe/framework/api2:node", "//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
@ -379,17 +381,6 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "clip_detection_vector_size_calculator",
srcs = ["clip_detection_vector_size_calculator.cc"],
deps = [
":clip_vector_size_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:detection_cc_proto",
],
alwayslink = 1,
)
cc_test( cc_test(
name = "clip_vector_size_calculator_test", name = "clip_vector_size_calculator_test",
srcs = ["clip_vector_size_calculator_test.cc"], srcs = ["clip_vector_size_calculator_test.cc"],
@ -591,6 +582,7 @@ cc_library(
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/tool:options_util", "//mediapipe/framework/tool:options_util",
"@com_google_absl//absl/log:absl_check",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -606,6 +598,7 @@ cc_test(
"//mediapipe/framework/formats:video_stream_header", "//mediapipe/framework/formats:video_stream_header",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )
@ -638,6 +631,7 @@ cc_library(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/log:absl_log",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -785,10 +779,11 @@ cc_library(
"//mediapipe/framework/deps:random", "//mediapipe/framework/deps:random",
"//mediapipe/framework/formats:video_stream_header", "//mediapipe/framework/formats:video_stream_header",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/tool:options_util", "//mediapipe/framework/tool:options_util",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
alwayslink = 1, alwayslink = 1,
@ -844,6 +839,7 @@ cc_test(
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/tool:validate_type", "//mediapipe/framework/tool:validate_type",
"@com_google_absl//absl/log:absl_check",
"@eigen_archive//:eigen3", "@eigen_archive//:eigen3",
], ],
) )
@ -1031,6 +1027,7 @@ cc_library(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node", "//mediapipe/framework/api2:node",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/log:absl_log",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -1069,6 +1066,7 @@ cc_test(
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"@com_google_absl//absl/log:absl_log",
], ],
) )
@ -1115,6 +1113,7 @@ cc_library(
"//mediapipe/framework/api2:node", "//mediapipe/framework/api2:node",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/log:absl_log",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -1167,6 +1166,7 @@ cc_library(
"//mediapipe/framework:collection_item_id", "//mediapipe/framework:collection_item_id",
"//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:matrix_data_cc_proto",
"//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",

View File

@ -76,4 +76,9 @@ REGISTER_CALCULATOR(BeginLoopGpuBufferCalculator);
// A calculator to process std::vector<mediapipe::Image>. // A calculator to process std::vector<mediapipe::Image>.
typedef BeginLoopCalculator<std::vector<Image>> BeginLoopImageCalculator; typedef BeginLoopCalculator<std::vector<Image>> BeginLoopImageCalculator;
REGISTER_CALCULATOR(BeginLoopImageCalculator); REGISTER_CALCULATOR(BeginLoopImageCalculator);
// A calculator to process std::vector<float>.
typedef BeginLoopCalculator<std::vector<float>> BeginLoopFloatCalculator;
REGISTER_CALCULATOR(BeginLoopFloatCalculator);
} // namespace mediapipe } // namespace mediapipe

View File

@ -17,6 +17,7 @@
#include <vector> #include <vector>
#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/integral_types.h"
@ -104,4 +105,7 @@ typedef ConcatenateVectorCalculator<mediapipe::RenderData>
ConcatenateRenderDataVectorCalculator; ConcatenateRenderDataVectorCalculator;
MEDIAPIPE_REGISTER_NODE(ConcatenateRenderDataVectorCalculator); MEDIAPIPE_REGISTER_NODE(ConcatenateRenderDataVectorCalculator);
typedef ConcatenateVectorCalculator<mediapipe::Image>
ConcatenateImageVectorCalculator;
MEDIAPIPE_REGISTER_NODE(ConcatenateImageVectorCalculator);
} // namespace mediapipe } // namespace mediapipe

View File

@ -19,6 +19,7 @@
#include "mediapipe/framework/collection_item_id.h" #include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix_data.pb.h"
#include "mediapipe/framework/formats/time_series_header.pb.h" #include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/integral_types.h"
@ -85,8 +86,12 @@ class ConstantSidePacketCalculator : public CalculatorBase {
packet.Set<LandmarkList>(); packet.Set<LandmarkList>();
} else if (packet_options.has_double_value()) { } else if (packet_options.has_double_value()) {
packet.Set<double>(); packet.Set<double>();
} else if (packet_options.has_matrix_data_value()) {
packet.Set<MatrixData>();
} else if (packet_options.has_time_series_header_value()) { } else if (packet_options.has_time_series_header_value()) {
packet.Set<TimeSeriesHeader>(); packet.Set<TimeSeriesHeader>();
} else if (packet_options.has_int64_value()) {
packet.Set<int64_t>();
} else { } else {
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
"None of supported values were specified in options."); "None of supported values were specified in options.");
@ -121,9 +126,13 @@ class ConstantSidePacketCalculator : public CalculatorBase {
MakePacket<LandmarkList>(packet_options.landmark_list_value())); MakePacket<LandmarkList>(packet_options.landmark_list_value()));
} else if (packet_options.has_double_value()) { } else if (packet_options.has_double_value()) {
packet.Set(MakePacket<double>(packet_options.double_value())); packet.Set(MakePacket<double>(packet_options.double_value()));
} else if (packet_options.has_matrix_data_value()) {
packet.Set(MakePacket<MatrixData>(packet_options.matrix_data_value()));
} else if (packet_options.has_time_series_header_value()) { } else if (packet_options.has_time_series_header_value()) {
packet.Set(MakePacket<TimeSeriesHeader>( packet.Set(MakePacket<TimeSeriesHeader>(
packet_options.time_series_header_value())); packet_options.time_series_header_value()));
} else if (packet_options.has_int64_value()) {
packet.Set(MakePacket<int64_t>(packet_options.int64_value()));
} else { } else {
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
"None of supported values were specified in options."); "None of supported values were specified in options.");

View File

@ -19,6 +19,7 @@ package mediapipe;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/formats/classification.proto"; import "mediapipe/framework/formats/classification.proto";
import "mediapipe/framework/formats/landmark.proto"; import "mediapipe/framework/formats/landmark.proto";
import "mediapipe/framework/formats/matrix_data.proto";
import "mediapipe/framework/formats/time_series_header.proto"; import "mediapipe/framework/formats/time_series_header.proto";
message ConstantSidePacketCalculatorOptions { message ConstantSidePacketCalculatorOptions {
@ -29,14 +30,16 @@ message ConstantSidePacketCalculatorOptions {
message ConstantSidePacket { message ConstantSidePacket {
oneof value { oneof value {
int32 int_value = 1; int32 int_value = 1;
uint64 uint64_value = 5;
int64 int64_value = 11;
float float_value = 2; float float_value = 2;
double double_value = 9;
bool bool_value = 3; bool bool_value = 3;
string string_value = 4; string string_value = 4;
uint64 uint64_value = 5;
ClassificationList classification_list_value = 6; ClassificationList classification_list_value = 6;
LandmarkList landmark_list_value = 7; LandmarkList landmark_list_value = 7;
double double_value = 9;
TimeSeriesHeader time_series_header_value = 10; TimeSeriesHeader time_series_header_value = 10;
MatrixData matrix_data_value = 12;
} }
} }

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <cstdint>
#include <string> #include <string>
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
@ -58,6 +59,7 @@ TEST(ConstantSidePacketCalculatorTest, EveryPossibleType) {
DoTestSingleSidePacket("{ float_value: 6.5f }", 6.5f); DoTestSingleSidePacket("{ float_value: 6.5f }", 6.5f);
DoTestSingleSidePacket("{ bool_value: true }", true); DoTestSingleSidePacket("{ bool_value: true }", true);
DoTestSingleSidePacket<std::string>(R"({ string_value: "str" })", "str"); DoTestSingleSidePacket<std::string>(R"({ string_value: "str" })", "str");
DoTestSingleSidePacket<int64_t>("{ int64_value: 63 }", 63);
} }
TEST(ConstantSidePacketCalculatorTest, MultiplePackets) { TEST(ConstantSidePacketCalculatorTest, MultiplePackets) {

View File

@ -14,6 +14,8 @@
#include "mediapipe/calculators/core/end_loop_calculator.h" #include "mediapipe/calculators/core/end_loop_calculator.h"
#include <array>
#include <utility>
#include <vector> #include <vector>
#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/classification.pb.h"
@ -84,4 +86,8 @@ typedef EndLoopCalculator<std::vector<std::array<float, 16>>>
EndLoopAffineMatrixCalculator; EndLoopAffineMatrixCalculator;
REGISTER_CALCULATOR(EndLoopAffineMatrixCalculator); REGISTER_CALCULATOR(EndLoopAffineMatrixCalculator);
typedef EndLoopCalculator<std::vector<std::pair<int, int>>>
EndLoopImageSizeCalculator;
REGISTER_CALCULATOR(EndLoopImageSizeCalculator);
} // namespace mediapipe } // namespace mediapipe

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/log/absl_log.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
@ -356,18 +357,18 @@ TEST_F(GateCalculatorTest, AllowWithStateChangeNoDataStreams) {
RunTimeStepWithoutDataStream(kTimestampValue2, "ALLOW", true); RunTimeStepWithoutDataStream(kTimestampValue2, "ALLOW", true);
constexpr int64_t kTimestampValue3 = 45; constexpr int64_t kTimestampValue3 = 45;
RunTimeStepWithoutDataStream(kTimestampValue3, "ALLOW", false); RunTimeStepWithoutDataStream(kTimestampValue3, "ALLOW", false);
LOG(INFO) << "a"; ABSL_LOG(INFO) << "a";
const std::vector<Packet>& output = const std::vector<Packet>& output =
runner()->Outputs().Get("STATE_CHANGE", 0).packets; runner()->Outputs().Get("STATE_CHANGE", 0).packets;
LOG(INFO) << "s"; ABSL_LOG(INFO) << "s";
ASSERT_EQ(2, output.size()); ASSERT_EQ(2, output.size());
LOG(INFO) << "d"; ABSL_LOG(INFO) << "d";
EXPECT_EQ(kTimestampValue1, output[0].Timestamp().Value()); EXPECT_EQ(kTimestampValue1, output[0].Timestamp().Value());
EXPECT_EQ(kTimestampValue3, output[1].Timestamp().Value()); EXPECT_EQ(kTimestampValue3, output[1].Timestamp().Value());
LOG(INFO) << "f"; ABSL_LOG(INFO) << "f";
EXPECT_EQ(true, output[0].Get<bool>()); // Allow. EXPECT_EQ(true, output[0].Get<bool>()); // Allow.
EXPECT_EQ(false, output[1].Get<bool>()); // Disallow. EXPECT_EQ(false, output[1].Get<bool>()); // Disallow.
LOG(INFO) << "g"; ABSL_LOG(INFO) << "g";
} }
TEST_F(GateCalculatorTest, DisallowWithStateChange) { TEST_F(GateCalculatorTest, DisallowWithStateChange) {

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/log/absl_log.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
@ -78,7 +79,7 @@ absl::Status ImmediateMuxCalculator::Process(CalculatorContext* cc) {
if (packet.Timestamp() >= cc->Outputs().Index(0).NextTimestampBound()) { if (packet.Timestamp() >= cc->Outputs().Index(0).NextTimestampBound()) {
cc->Outputs().Index(0).AddPacket(packet); cc->Outputs().Index(0).AddPacket(packet);
} else { } else {
LOG_FIRST_N(WARNING, 5) ABSL_LOG_FIRST_N(WARNING, 5)
<< "Dropping a packet with timestamp " << packet.Timestamp(); << "Dropping a packet with timestamp " << packet.Timestamp();
} }
if (cc->Outputs().NumEntries() >= 2) { if (cc->Outputs().NumEntries() >= 2) {

View File

@ -16,6 +16,7 @@
#include <vector> #include <vector>
#include "Eigen/Core" #include "Eigen/Core"
#include "absl/log/absl_check.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
@ -209,7 +210,7 @@ TEST(MatrixMultiplyCalculatorTest, Multiply) {
MatrixFromTextProto(kSamplesText, &samples); MatrixFromTextProto(kSamplesText, &samples);
Matrix expected; Matrix expected;
MatrixFromTextProto(kExpectedText, &expected); MatrixFromTextProto(kExpectedText, &expected);
CHECK_EQ(samples.cols(), expected.cols()); ABSL_CHECK_EQ(samples.cols(), expected.cols());
for (int i = 0; i < samples.cols(); ++i) { for (int i = 0; i < samples.cols(); ++i) {
// Take a column from samples and produce a packet with just that // Take a column from samples and produce a packet with just that

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/log/absl_log.h"
#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
@ -53,7 +54,7 @@ class MergeCalculator : public Node {
static absl::Status UpdateContract(CalculatorContract* cc) { static absl::Status UpdateContract(CalculatorContract* cc) {
RET_CHECK_GT(kIn(cc).Count(), 0) << "Needs at least one input stream"; RET_CHECK_GT(kIn(cc).Count(), 0) << "Needs at least one input stream";
if (kIn(cc).Count() == 1) { if (kIn(cc).Count() == 1) {
LOG(WARNING) ABSL_LOG(WARNING)
<< "MergeCalculator expects multiple input streams to merge but is " << "MergeCalculator expects multiple input streams to merge but is "
"receiving only one. Make sure the calculator is configured " "receiving only one. Make sure the calculator is configured "
"correctly or consider removing this calculator to reduce " "correctly or consider removing this calculator to reduce "
@ -72,7 +73,7 @@ class MergeCalculator : public Node {
} }
} }
LOG(WARNING) << "Empty input packets at timestamp " ABSL_LOG(WARNING) << "Empty input packets at timestamp "
<< cc->InputTimestamp().Value(); << cc->InputTimestamp().Value();
return absl::OkStatus(); return absl::OkStatus();

View File

@ -16,6 +16,9 @@
#include <memory> #include <memory>
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
namespace { namespace {
// Reflect an integer against the lower and upper bound of an interval. // Reflect an integer against the lower and upper bound of an interval.
int64_t ReflectBetween(int64_t ts, int64_t ts_min, int64_t ts_max) { int64_t ReflectBetween(int64_t ts, int64_t ts_min, int64_t ts_max) {
@ -177,7 +180,7 @@ PacketResamplerCalculator::GetSamplingStrategy(
const PacketResamplerCalculatorOptions& options) { const PacketResamplerCalculatorOptions& options) {
if (options.reproducible_sampling()) { if (options.reproducible_sampling()) {
if (!options.jitter_with_reflection()) { if (!options.jitter_with_reflection()) {
LOG(WARNING) ABSL_LOG(WARNING)
<< "reproducible_sampling enabled w/ jitter_with_reflection " << "reproducible_sampling enabled w/ jitter_with_reflection "
"disabled. " "disabled. "
<< "reproducible_sampling always uses jitter with reflection, " << "reproducible_sampling always uses jitter with reflection, "
@ -200,15 +203,15 @@ PacketResamplerCalculator::GetSamplingStrategy(
Timestamp PacketResamplerCalculator::PeriodIndexToTimestamp( Timestamp PacketResamplerCalculator::PeriodIndexToTimestamp(
int64_t index) const { int64_t index) const {
CHECK_EQ(jitter_, 0.0); ABSL_CHECK_EQ(jitter_, 0.0);
CHECK_NE(first_timestamp_, Timestamp::Unset()); ABSL_CHECK_NE(first_timestamp_, Timestamp::Unset());
return first_timestamp_ + TimestampDiffFromSeconds(index / frame_rate_); return first_timestamp_ + TimestampDiffFromSeconds(index / frame_rate_);
} }
int64_t PacketResamplerCalculator::TimestampToPeriodIndex( int64_t PacketResamplerCalculator::TimestampToPeriodIndex(
Timestamp timestamp) const { Timestamp timestamp) const {
CHECK_EQ(jitter_, 0.0); ABSL_CHECK_EQ(jitter_, 0.0);
CHECK_NE(first_timestamp_, Timestamp::Unset()); ABSL_CHECK_NE(first_timestamp_, Timestamp::Unset());
return MathUtil::SafeRound<int64_t, double>( return MathUtil::SafeRound<int64_t, double>(
(timestamp - first_timestamp_).Seconds() * frame_rate_); (timestamp - first_timestamp_).Seconds() * frame_rate_);
} }
@ -229,12 +232,14 @@ absl::Status LegacyJitterWithReflectionStrategy::Open(CalculatorContext* cc) {
if (resampler_options.output_header() != if (resampler_options.output_header() !=
PacketResamplerCalculatorOptions::NONE) { PacketResamplerCalculatorOptions::NONE) {
LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not " ABSL_LOG(WARNING)
<< "VideoHeader::frame_rate holds the target value and not "
"the actual value."; "the actual value.";
} }
if (calculator_->flush_last_packet_) { if (calculator_->flush_last_packet_) {
LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is " ABSL_LOG(WARNING)
<< "PacketResamplerCalculatorOptions.flush_last_packet is "
"ignored, because we are adding jitter."; "ignored, because we are adding jitter.";
} }
@ -254,7 +259,7 @@ absl::Status LegacyJitterWithReflectionStrategy::Open(CalculatorContext* cc) {
} }
absl::Status LegacyJitterWithReflectionStrategy::Close(CalculatorContext* cc) { absl::Status LegacyJitterWithReflectionStrategy::Close(CalculatorContext* cc) {
if (!packet_reservoir_->IsEmpty()) { if (!packet_reservoir_->IsEmpty()) {
LOG(INFO) << "Emitting pack from reservoir."; ABSL_LOG(INFO) << "Emitting pack from reservoir.";
calculator_->OutputWithinLimits(cc, packet_reservoir_->GetSample()); calculator_->OutputWithinLimits(cc, packet_reservoir_->GetSample());
} }
return absl::OkStatus(); return absl::OkStatus();
@ -285,7 +290,7 @@ absl::Status LegacyJitterWithReflectionStrategy::Process(
if (calculator_->frame_time_usec_ < if (calculator_->frame_time_usec_ <
(cc->InputTimestamp() - calculator_->last_packet_.Timestamp()).Value()) { (cc->InputTimestamp() - calculator_->last_packet_.Timestamp()).Value()) {
LOG_FIRST_N(WARNING, 2) ABSL_LOG_FIRST_N(WARNING, 2)
<< "Adding jitter is not very useful when upsampling."; << "Adding jitter is not very useful when upsampling.";
} }
@ -340,8 +345,8 @@ void LegacyJitterWithReflectionStrategy::UpdateNextOutputTimestampWithJitter() {
next_output_timestamp_ = Timestamp(ReflectBetween( next_output_timestamp_ = Timestamp(ReflectBetween(
next_output_timestamp_.Value(), next_output_timestamp_min_.Value(), next_output_timestamp_.Value(), next_output_timestamp_min_.Value(),
next_output_timestamp_max_.Value())); next_output_timestamp_max_.Value()));
CHECK_GE(next_output_timestamp_, next_output_timestamp_min_); ABSL_CHECK_GE(next_output_timestamp_, next_output_timestamp_min_);
CHECK_LT(next_output_timestamp_, next_output_timestamp_max_); ABSL_CHECK_LT(next_output_timestamp_, next_output_timestamp_max_);
} }
absl::Status ReproducibleJitterWithReflectionStrategy::Open( absl::Status ReproducibleJitterWithReflectionStrategy::Open(
@ -352,12 +357,14 @@ absl::Status ReproducibleJitterWithReflectionStrategy::Open(
if (resampler_options.output_header() != if (resampler_options.output_header() !=
PacketResamplerCalculatorOptions::NONE) { PacketResamplerCalculatorOptions::NONE) {
LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not " ABSL_LOG(WARNING)
<< "VideoHeader::frame_rate holds the target value and not "
"the actual value."; "the actual value.";
} }
if (calculator_->flush_last_packet_) { if (calculator_->flush_last_packet_) {
LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is " ABSL_LOG(WARNING)
<< "PacketResamplerCalculatorOptions.flush_last_packet is "
"ignored, because we are adding jitter."; "ignored, because we are adding jitter.";
} }
@ -411,7 +418,7 @@ absl::Status ReproducibleJitterWithReflectionStrategy::Process(
// Note, if the stream is upsampling, this could lead to the same packet // Note, if the stream is upsampling, this could lead to the same packet
// being emitted twice. Upsampling and jitter doesn't make much sense // being emitted twice. Upsampling and jitter doesn't make much sense
// but does technically work. // but does technically work.
LOG_FIRST_N(WARNING, 2) ABSL_LOG_FIRST_N(WARNING, 2)
<< "Adding jitter is not very useful when upsampling."; << "Adding jitter is not very useful when upsampling.";
} }
@ -499,12 +506,14 @@ absl::Status JitterWithoutReflectionStrategy::Open(CalculatorContext* cc) {
if (resampler_options.output_header() != if (resampler_options.output_header() !=
PacketResamplerCalculatorOptions::NONE) { PacketResamplerCalculatorOptions::NONE) {
LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not " ABSL_LOG(WARNING)
<< "VideoHeader::frame_rate holds the target value and not "
"the actual value."; "the actual value.";
} }
if (calculator_->flush_last_packet_) { if (calculator_->flush_last_packet_) {
LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is " ABSL_LOG(WARNING)
<< "PacketResamplerCalculatorOptions.flush_last_packet is "
"ignored, because we are adding jitter."; "ignored, because we are adding jitter.";
} }
@ -555,7 +564,7 @@ absl::Status JitterWithoutReflectionStrategy::Process(CalculatorContext* cc) {
if (calculator_->frame_time_usec_ < if (calculator_->frame_time_usec_ <
(cc->InputTimestamp() - calculator_->last_packet_.Timestamp()).Value()) { (cc->InputTimestamp() - calculator_->last_packet_.Timestamp()).Value()) {
LOG_FIRST_N(WARNING, 2) ABSL_LOG_FIRST_N(WARNING, 2)
<< "Adding jitter is not very useful when upsampling."; << "Adding jitter is not very useful when upsampling.";
} }

View File

@ -13,7 +13,6 @@
#include "mediapipe/framework/deps/random_base.h" #include "mediapipe/framework/deps/random_base.h"
#include "mediapipe/framework/formats/video_stream_header.h" #include "mediapipe/framework/formats/video_stream_header.h"
#include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/framework/port/status_macros.h"

View File

@ -17,6 +17,7 @@
#include <cmath> // for ceil #include <cmath> // for ceil
#include <memory> #include <memory>
#include "absl/log/absl_check.h"
#include "mediapipe/calculators/core/packet_thinner_calculator.pb.h" #include "mediapipe/calculators/core/packet_thinner_calculator.pb.h"
#include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
@ -160,7 +161,7 @@ absl::Status PacketThinnerCalculator::Open(CalculatorContext* cc) {
thinner_type_ = options.thinner_type(); thinner_type_ = options.thinner_type();
// This check enables us to assume only two thinner types exist in Process() // This check enables us to assume only two thinner types exist in Process()
CHECK(thinner_type_ == PacketThinnerCalculatorOptions::ASYNC || ABSL_CHECK(thinner_type_ == PacketThinnerCalculatorOptions::ASYNC ||
thinner_type_ == PacketThinnerCalculatorOptions::SYNC) thinner_type_ == PacketThinnerCalculatorOptions::SYNC)
<< "Unsupported thinner type."; << "Unsupported thinner type.";
@ -177,7 +178,8 @@ absl::Status PacketThinnerCalculator::Open(CalculatorContext* cc) {
} else { } else {
period_ = TimestampDiff(options.period()); period_ = TimestampDiff(options.period());
} }
CHECK_LT(TimestampDiff(0), period_) << "Specified period must be positive."; ABSL_CHECK_LT(TimestampDiff(0), period_)
<< "Specified period must be positive.";
if (options.has_start_time()) { if (options.has_start_time()) {
start_time_ = Timestamp(options.start_time()); start_time_ = Timestamp(options.start_time());
@ -189,7 +191,7 @@ absl::Status PacketThinnerCalculator::Open(CalculatorContext* cc) {
end_time_ = end_time_ =
options.has_end_time() ? Timestamp(options.end_time()) : Timestamp::Max(); options.has_end_time() ? Timestamp(options.end_time()) : Timestamp::Max();
CHECK_LT(start_time_, end_time_) ABSL_CHECK_LT(start_time_, end_time_)
<< "Invalid PacketThinner: start_time must be earlier than end_time"; << "Invalid PacketThinner: start_time must be earlier than end_time";
sync_output_timestamps_ = options.sync_output_timestamps(); sync_output_timestamps_ = options.sync_output_timestamps();
@ -232,7 +234,7 @@ absl::Status PacketThinnerCalculator::Close(CalculatorContext* cc) {
// Emit any saved packets before quitting. // Emit any saved packets before quitting.
if (!saved_packet_.IsEmpty()) { if (!saved_packet_.IsEmpty()) {
// Only sync thinner should have saved packets. // Only sync thinner should have saved packets.
CHECK_EQ(PacketThinnerCalculatorOptions::SYNC, thinner_type_); ABSL_CHECK_EQ(PacketThinnerCalculatorOptions::SYNC, thinner_type_);
if (sync_output_timestamps_) { if (sync_output_timestamps_) {
cc->Outputs().Index(0).AddPacket( cc->Outputs().Index(0).AddPacket(
saved_packet_.At(NearestSyncTimestamp(saved_packet_.Timestamp()))); saved_packet_.At(NearestSyncTimestamp(saved_packet_.Timestamp())));
@ -269,7 +271,7 @@ absl::Status PacketThinnerCalculator::SyncThinnerProcess(
const Timestamp saved_sync = NearestSyncTimestamp(saved); const Timestamp saved_sync = NearestSyncTimestamp(saved);
const Timestamp now = cc->InputTimestamp(); const Timestamp now = cc->InputTimestamp();
const Timestamp now_sync = NearestSyncTimestamp(now); const Timestamp now_sync = NearestSyncTimestamp(now);
CHECK_LE(saved_sync, now_sync); ABSL_CHECK_LE(saved_sync, now_sync);
if (saved_sync == now_sync) { if (saved_sync == now_sync) {
// Saved Packet is in same interval as current packet. // Saved Packet is in same interval as current packet.
// Replace saved packet with current if it is at least as // Replace saved packet with current if it is at least as
@ -295,7 +297,7 @@ absl::Status PacketThinnerCalculator::SyncThinnerProcess(
} }
Timestamp PacketThinnerCalculator::NearestSyncTimestamp(Timestamp now) const { Timestamp PacketThinnerCalculator::NearestSyncTimestamp(Timestamp now) const {
CHECK_NE(start_time_, Timestamp::Unset()) ABSL_CHECK_NE(start_time_, Timestamp::Unset())
<< "Method only valid for sync thinner calculator."; << "Method only valid for sync thinner calculator.";
// Computation is done using int64 arithmetic. No easy way to avoid // Computation is done using int64 arithmetic. No easy way to avoid
@ -303,12 +305,12 @@ Timestamp PacketThinnerCalculator::NearestSyncTimestamp(Timestamp now) const {
const int64_t now64 = now.Value(); const int64_t now64 = now.Value();
const int64_t start64 = start_time_.Value(); const int64_t start64 = start_time_.Value();
const int64_t period64 = period_.Value(); const int64_t period64 = period_.Value();
CHECK_LE(0, period64); ABSL_CHECK_LE(0, period64);
// Round now64 to its closest interval (units of period64). // Round now64 to its closest interval (units of period64).
int64_t sync64 = int64_t sync64 =
(now64 - start64 + period64 / 2) / period64 * period64 + start64; (now64 - start64 + period64 / 2) / period64 * period64 + start64;
CHECK_LE(abs(now64 - sync64), period64 / 2) ABSL_CHECK_LE(abs(now64 - sync64), period64 / 2)
<< "start64: " << start64 << "; now64: " << now64 << "start64: " << start64 << "; now64: " << now64
<< "; sync64: " << sync64; << "; sync64: " << sync64;

View File

@ -16,6 +16,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "absl/log/absl_check.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "mediapipe/calculators/core/packet_thinner_calculator.pb.h" #include "mediapipe/calculators/core/packet_thinner_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
@ -70,7 +71,7 @@ class SimpleRunner : public CalculatorRunner {
} }
double GetFrameRate() const { double GetFrameRate() const {
CHECK(!Outputs().Index(0).header.IsEmpty()); ABSL_CHECK(!Outputs().Index(0).header.IsEmpty());
return Outputs().Index(0).header.Get<VideoHeader>().frame_rate; return Outputs().Index(0).header.Get<VideoHeader>().frame_rate;
} }
}; };

View File

@ -123,8 +123,11 @@ class PreviousLoopbackCalculator : public Node {
// However, LOOP packet is empty. // However, LOOP packet is empty.
kPrevLoop(cc).SetNextTimestampBound(main_spec.timestamp + 1); kPrevLoop(cc).SetNextTimestampBound(main_spec.timestamp + 1);
} else { } else {
// Avoids sending leftovers to a stream that's already closed.
if (!kPrevLoop(cc).IsClosed()) {
kPrevLoop(cc).Send(loop_candidate.At(main_spec.timestamp)); kPrevLoop(cc).Send(loop_candidate.At(main_spec.timestamp));
} }
}
loop_packets_.pop_front(); loop_packets_.pop_front();
main_packet_specs_.pop_front(); main_packet_specs_.pop_front();
} }

View File

@ -14,6 +14,7 @@
#include <deque> #include <deque>
#include "absl/log/absl_log.h"
#include "mediapipe/calculators/core/sequence_shift_calculator.pb.h" #include "mediapipe/calculators/core/sequence_shift_calculator.pb.h"
#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
@ -101,7 +102,7 @@ void SequenceShiftCalculator::ProcessPositiveOffset(CalculatorContext* cc) {
kOut(cc).Send(packet_cache_.front().At(cc->InputTimestamp())); kOut(cc).Send(packet_cache_.front().At(cc->InputTimestamp()));
packet_cache_.pop_front(); packet_cache_.pop_front();
} else if (emit_empty_packets_before_first_packet_) { } else if (emit_empty_packets_before_first_packet_) {
LOG(FATAL) << "Not supported yet"; ABSL_LOG(FATAL) << "Not supported yet";
} }
// Store current packet for later output. // Store current packet for later output.
packet_cache_.push_back(kIn(cc).packet()); packet_cache_.push_back(kIn(cc).packet());

View File

@ -97,6 +97,7 @@ cc_library(
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:source_location", "//mediapipe/framework/port:source_location",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/log:absl_check",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -125,6 +126,7 @@ cc_library(
"//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgcodecs",
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/log:absl_check",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -135,7 +137,6 @@ cc_library(
deps = [ deps = [
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:opencv_imgcodecs",
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
], ],
@ -152,11 +153,11 @@ cc_library(
"//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:vector", "//mediapipe/framework/port:vector",
"@com_google_absl//absl/log:absl_log",
] + select({ ] + select({
"//mediapipe/gpu:disable_gpu": [], "//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [ "//conditions:default": [
@ -203,6 +204,7 @@ cc_library(
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:vector", "//mediapipe/framework/port:vector",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
] + select({ ] + select({
"//mediapipe/gpu:disable_gpu": [], "//mediapipe/gpu:disable_gpu": [],
@ -301,6 +303,7 @@ cc_library(
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/log:absl_log",
] + select({ ] + select({
"//mediapipe/gpu:disable_gpu": [], "//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [ "//conditions:default": [
@ -397,6 +400,7 @@ cc_library(
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )
@ -421,6 +425,8 @@ cc_library(
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/util:image_frame_util", "//mediapipe/util:image_frame_util",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@libyuv", "@libyuv",
], ],
@ -626,9 +632,9 @@ cc_library(
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:vector", "//mediapipe/framework/port:vector",
"@com_google_absl//absl/log:absl_log",
] + select({ ] + select({
"//mediapipe/gpu:disable_gpu": [], "//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [ "//conditions:default": [
@ -666,6 +672,7 @@ cc_test(
"//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgcodecs",
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"@com_google_absl//absl/log:absl_log",
], ],
) )

View File

@ -384,6 +384,8 @@ class GlTextureWarpAffineRunner
glActiveTexture(GL_TEXTURE0); glActiveTexture(GL_TEXTURE0);
glBindTexture(GL_TEXTURE_2D, 0); glBindTexture(GL_TEXTURE_2D, 0);
glFlush();
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -15,6 +15,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include "absl/log/absl_check.h"
#include "absl/strings/str_replace.h" #include "absl/strings/str_replace.h"
#include "mediapipe/calculators/image/bilateral_filter_calculator.pb.h" #include "mediapipe/calculators/image/bilateral_filter_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
@ -112,7 +113,7 @@ class BilateralFilterCalculator : public CalculatorBase {
REGISTER_CALCULATOR(BilateralFilterCalculator); REGISTER_CALCULATOR(BilateralFilterCalculator);
absl::Status BilateralFilterCalculator::GetContract(CalculatorContract* cc) { absl::Status BilateralFilterCalculator::GetContract(CalculatorContract* cc) {
CHECK_GE(cc->Inputs().NumEntries(), 1); RET_CHECK_GE(cc->Inputs().NumEntries(), 1);
if (cc->Inputs().HasTag(kInputFrameTag) && if (cc->Inputs().HasTag(kInputFrameTag) &&
cc->Inputs().HasTag(kInputFrameTagGpu)) { cc->Inputs().HasTag(kInputFrameTagGpu)) {
@ -183,8 +184,8 @@ absl::Status BilateralFilterCalculator::Open(CalculatorContext* cc) {
sigma_color_ = options_.sigma_color(); sigma_color_ = options_.sigma_color();
sigma_space_ = options_.sigma_space(); sigma_space_ = options_.sigma_space();
CHECK_GE(sigma_color_, 0.0); ABSL_CHECK_GE(sigma_color_, 0.0);
CHECK_GE(sigma_space_, 0.0); ABSL_CHECK_GE(sigma_space_, 0.0);
if (!use_gpu_) sigma_color_ *= 255.0; if (!use_gpu_) sigma_color_ *= 255.0;
if (use_gpu_) { if (use_gpu_) {

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/log/absl_check.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/image_frame_opencv.h"
@ -25,8 +26,8 @@
namespace mediapipe { namespace mediapipe {
namespace { namespace {
void SetColorChannel(int channel, uint8 value, cv::Mat* mat) { void SetColorChannel(int channel, uint8 value, cv::Mat* mat) {
CHECK(mat->depth() == CV_8U); ABSL_CHECK(mat->depth() == CV_8U);
CHECK(channel < mat->channels()); ABSL_CHECK(channel < mat->channels());
const int step = mat->channels(); const int step = mat->channels();
for (int r = 0; r < mat->rows; ++r) { for (int r = 0; r < mat->rows; ++r) {
uint8* row_ptr = mat->ptr<uint8>(r); uint8* row_ptr = mat->ptr<uint8>(r);

View File

@ -16,6 +16,7 @@
#include <cmath> #include <cmath>
#include "absl/log/absl_log.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
@ -202,7 +203,8 @@ absl::Status ImageCroppingCalculator::ValidateBorderModeForGPU(
switch (options.border_mode()) { switch (options.border_mode()) {
case mediapipe::ImageCroppingCalculatorOptions::BORDER_ZERO: case mediapipe::ImageCroppingCalculatorOptions::BORDER_ZERO:
LOG(WARNING) << "BORDER_ZERO mode is not supported by GPU " ABSL_LOG(WARNING)
<< "BORDER_ZERO mode is not supported by GPU "
<< "implementation and will fall back into BORDER_REPLICATE"; << "implementation and will fall back into BORDER_REPLICATE";
break; break;
case mediapipe::ImageCroppingCalculatorOptions::BORDER_REPLICATE: case mediapipe::ImageCroppingCalculatorOptions::BORDER_REPLICATE:

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/log/absl_check.h"
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h" #include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/image_frame_opencv.h"
@ -61,7 +62,7 @@ absl::Status OpenCvImageEncoderCalculator::Open(CalculatorContext* cc) {
absl::Status OpenCvImageEncoderCalculator::Process(CalculatorContext* cc) { absl::Status OpenCvImageEncoderCalculator::Process(CalculatorContext* cc) {
const ImageFrame& image_frame = cc->Inputs().Index(0).Get<ImageFrame>(); const ImageFrame& image_frame = cc->Inputs().Index(0).Get<ImageFrame>();
CHECK_EQ(1, image_frame.ByteDepth()); ABSL_CHECK_EQ(1, image_frame.ByteDepth());
std::unique_ptr<OpenCvImageEncoderCalculatorResults> encoded_result = std::unique_ptr<OpenCvImageEncoderCalculatorResults> encoded_result =
absl::make_unique<OpenCvImageEncoderCalculatorResults>(); absl::make_unique<OpenCvImageEncoderCalculatorResults>();

View File

@ -18,6 +18,8 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
#include "libyuv/scale.h" #include "libyuv/scale.h"
@ -293,7 +295,7 @@ absl::Status ScaleImageCalculator::InitializeFrameInfo(CalculatorContext* cc) {
header->width = output_width_; header->width = output_width_;
header->height = output_height_; header->height = output_height_;
header->format = output_format_; header->format = output_format_;
LOG(INFO) << "OUTPUTTING HEADER on stream"; ABSL_LOG(INFO) << "OUTPUTTING HEADER on stream";
cc->Outputs() cc->Outputs()
.Tag("VIDEO_HEADER") .Tag("VIDEO_HEADER")
.Add(header.release(), Timestamp::PreStream()); .Add(header.release(), Timestamp::PreStream());
@ -393,7 +395,8 @@ absl::Status ScaleImageCalculator::Open(CalculatorContext* cc) {
.SetHeader(Adopt(output_header.release())); .SetHeader(Adopt(output_header.release()));
has_header_ = true; has_header_ = true;
} else { } else {
LOG(WARNING) << "Stream had a VideoHeader which didn't have sufficient " ABSL_LOG(WARNING)
<< "Stream had a VideoHeader which didn't have sufficient "
"information. " "information. "
"Dropping VideoHeader and trying to deduce needed " "Dropping VideoHeader and trying to deduce needed "
"information."; "information.";
@ -507,7 +510,7 @@ absl::Status ScaleImageCalculator::ValidateImageFrame(
absl::Status ScaleImageCalculator::ValidateYUVImage(CalculatorContext* cc, absl::Status ScaleImageCalculator::ValidateYUVImage(CalculatorContext* cc,
const YUVImage& yuv_image) { const YUVImage& yuv_image) {
CHECK_EQ(input_format_, ImageFormat::YCBCR420P); ABSL_CHECK_EQ(input_format_, ImageFormat::YCBCR420P);
if (!has_header_) { if (!has_header_) {
if (input_width_ != yuv_image.width() || if (input_width_ != yuv_image.width() ||
input_height_ != yuv_image.height()) { input_height_ != yuv_image.height()) {

View File

@ -18,6 +18,7 @@
#include <string> #include <string>
#include "absl/log/absl_check.h"
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
@ -40,10 +41,10 @@ absl::Status FindCropDimensions(int input_width, int input_height, //
const std::string& max_aspect_ratio, // const std::string& max_aspect_ratio, //
int* crop_width, int* crop_height, // int* crop_width, int* crop_height, //
int* col_start, int* row_start) { int* col_start, int* row_start) {
CHECK(crop_width); ABSL_CHECK(crop_width);
CHECK(crop_height); ABSL_CHECK(crop_height);
CHECK(col_start); ABSL_CHECK(col_start);
CHECK(row_start); ABSL_CHECK(row_start);
double min_aspect_ratio_q = 0.0; double min_aspect_ratio_q = 0.0;
double max_aspect_ratio_q = 0.0; double max_aspect_ratio_q = 0.0;
@ -83,8 +84,8 @@ absl::Status FindCropDimensions(int input_width, int input_height, //
} }
} }
CHECK_LE(*crop_width, input_width); ABSL_CHECK_LE(*crop_width, input_width);
CHECK_LE(*crop_height, input_height); ABSL_CHECK_LE(*crop_height, input_height);
return absl::OkStatus(); return absl::OkStatus();
} }
@ -96,8 +97,8 @@ absl::Status FindOutputDimensions(int input_width, //
bool preserve_aspect_ratio, // bool preserve_aspect_ratio, //
int scale_to_multiple_of, // int scale_to_multiple_of, //
int* output_width, int* output_height) { int* output_width, int* output_height) {
CHECK(output_width); ABSL_CHECK(output_width);
CHECK(output_height); ABSL_CHECK(output_height);
if (target_max_area > 0 && input_width * input_height > target_max_area) { if (target_max_area > 0 && input_width * input_height > target_max_area) {
preserve_aspect_ratio = true; preserve_aspect_ratio = true;

View File

@ -15,13 +15,13 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include "absl/log/absl_log.h"
#include "mediapipe/calculators/image/segmentation_smoothing_calculator.pb.h" #include "mediapipe/calculators/image/segmentation_smoothing_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_options.pb.h" #include "mediapipe/framework/calculator_options.pb.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/formats/image_format.pb.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/vector.h" #include "mediapipe/framework/port/vector.h"
@ -110,7 +110,7 @@ REGISTER_CALCULATOR(SegmentationSmoothingCalculator);
absl::Status SegmentationSmoothingCalculator::GetContract( absl::Status SegmentationSmoothingCalculator::GetContract(
CalculatorContract* cc) { CalculatorContract* cc) {
CHECK_GE(cc->Inputs().NumEntries(), 1); RET_CHECK_GE(cc->Inputs().NumEntries(), 1);
cc->Inputs().Tag(kCurrentMaskTag).Set<Image>(); cc->Inputs().Tag(kCurrentMaskTag).Set<Image>();
cc->Inputs().Tag(kPreviousMaskTag).Set<Image>(); cc->Inputs().Tag(kPreviousMaskTag).Set<Image>();
@ -273,7 +273,7 @@ absl::Status SegmentationSmoothingCalculator::RenderGpu(CalculatorContext* cc) {
const auto& previous_frame = cc->Inputs().Tag(kPreviousMaskTag).Get<Image>(); const auto& previous_frame = cc->Inputs().Tag(kPreviousMaskTag).Get<Image>();
if (previous_frame.format() != current_frame.format()) { if (previous_frame.format() != current_frame.format()) {
LOG(ERROR) << "Warning: mixing input format types. "; ABSL_LOG(ERROR) << "Warning: mixing input format types. ";
} }
auto previous_texture = gpu_helper_.CreateSourceTexture(previous_frame); auto previous_texture = gpu_helper_.CreateSourceTexture(previous_frame);

View File

@ -14,6 +14,7 @@
#include <memory> #include <memory>
#include "absl/log/absl_log.h"
#include "mediapipe/calculators/image/segmentation_smoothing_calculator.pb.h" #include "mediapipe/calculators/image/segmentation_smoothing_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
@ -169,7 +170,7 @@ void RunTest(bool use_gpu, float mix_ratio, cv::Mat& test_result) {
} }
} }
} else { } else {
LOG(ERROR) << "invalid ratio"; ABSL_LOG(ERROR) << "invalid ratio";
} }
} }

View File

@ -14,13 +14,13 @@
#include <memory> #include <memory>
#include "absl/log/absl_log.h"
#include "mediapipe/calculators/image/set_alpha_calculator.pb.h" #include "mediapipe/calculators/image/set_alpha_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_options.pb.h" #include "mediapipe/framework/calculator_options.pb.h"
#include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/formats/image_format.pb.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
@ -142,7 +142,7 @@ class SetAlphaCalculator : public CalculatorBase {
REGISTER_CALCULATOR(SetAlphaCalculator); REGISTER_CALCULATOR(SetAlphaCalculator);
absl::Status SetAlphaCalculator::GetContract(CalculatorContract* cc) { absl::Status SetAlphaCalculator::GetContract(CalculatorContract* cc) {
CHECK_GE(cc->Inputs().NumEntries(), 1); RET_CHECK_GE(cc->Inputs().NumEntries(), 1);
bool use_gpu = false; bool use_gpu = false;
@ -268,7 +268,7 @@ absl::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) {
const auto& input_frame = cc->Inputs().Tag(kInputFrameTag).Get<ImageFrame>(); const auto& input_frame = cc->Inputs().Tag(kInputFrameTag).Get<ImageFrame>();
const cv::Mat input_mat = formats::MatView(&input_frame); const cv::Mat input_mat = formats::MatView(&input_frame);
if (!(input_mat.type() == CV_8UC3 || input_mat.type() == CV_8UC4)) { if (!(input_mat.type() == CV_8UC3 || input_mat.type() == CV_8UC4)) {
LOG(ERROR) << "Only 3 or 4 channel 8-bit input image supported"; ABSL_LOG(ERROR) << "Only 3 or 4 channel 8-bit input image supported";
} }
// Setup destination image // Setup destination image
@ -328,7 +328,7 @@ absl::Status SetAlphaCalculator::RenderGpu(CalculatorContext* cc) {
cc->Inputs().Tag(kInputFrameTagGpu).Get<mediapipe::GpuBuffer>(); cc->Inputs().Tag(kInputFrameTagGpu).Get<mediapipe::GpuBuffer>();
if (!(input_frame.format() == mediapipe::GpuBufferFormat::kBGRA32 || if (!(input_frame.format() == mediapipe::GpuBufferFormat::kBGRA32 ||
input_frame.format() == mediapipe::GpuBufferFormat::kRGB24)) { input_frame.format() == mediapipe::GpuBufferFormat::kRGB24)) {
LOG(ERROR) << "Only RGB or RGBA input image supported"; ABSL_LOG(ERROR) << "Only RGB or RGBA input image supported";
} }
auto input_texture = gpu_helper_.CreateSourceTexture(input_frame); auto input_texture = gpu_helper_.CreateSourceTexture(input_frame);

View File

@ -38,7 +38,7 @@ std::string FourCCToString(libyuv::FourCC fourcc) {
buf[0] = (fourcc >> 24) & 0xff; buf[0] = (fourcc >> 24) & 0xff;
buf[1] = (fourcc >> 16) & 0xff; buf[1] = (fourcc >> 16) & 0xff;
buf[2] = (fourcc >> 8) & 0xff; buf[2] = (fourcc >> 8) & 0xff;
buf[3] = (fourcc)&0xff; buf[3] = (fourcc) & 0xff;
buf[4] = 0; buf[4] = 0;
return std::string(buf); return std::string(buf);
} }

View File

@ -31,12 +31,14 @@ mediapipe_proto_library(
cc_library( cc_library(
name = "callback_packet_calculator", name = "callback_packet_calculator",
srcs = ["callback_packet_calculator.cc"], srcs = ["callback_packet_calculator.cc"],
hdrs = ["callback_packet_calculator.h"],
visibility = ["//mediapipe/framework:__subpackages__"], visibility = ["//mediapipe/framework:__subpackages__"],
deps = [ deps = [
":callback_packet_calculator_cc_proto", ":callback_packet_calculator_cc_proto",
"//mediapipe/framework:calculator_base", "//mediapipe/framework:calculator_base",
"//mediapipe/framework:calculator_registry", "//mediapipe/framework:calculator_registry",
"//mediapipe/framework:output_side_packet", "//mediapipe/framework:output_side_packet",
"@com_google_absl//absl/status",
], ],
alwayslink = 1, alwayslink = 1,
) )

View File

@ -11,10 +11,12 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mediapipe/calculators/internal/callback_packet_calculator.h"
#include <functional> #include <functional>
#include <string> #include <string>
#include "absl/status/status.h"
#include "mediapipe/calculators/internal/callback_packet_calculator.pb.h" // NOLINT #include "mediapipe/calculators/internal/callback_packet_calculator.pb.h" // NOLINT
#include "mediapipe/framework/calculator_base.h" #include "mediapipe/framework/calculator_base.h"
#include "mediapipe/framework/calculator_registry.h" #include "mediapipe/framework/calculator_registry.h"
@ -39,18 +41,10 @@ void DumpPostStreamPacket(Packet* post_stream_packet, const Packet& packet) {
*post_stream_packet = packet; *post_stream_packet = packet;
} }
} }
} // namespace } // namespace
// Creates a callback which takes a packet and stores it either in a absl::Status CallbackPacketCalculator::GetContract(CalculatorContract* cc) {
// vector of packets or stores only the packet at PostStream timestamp.
// The kind of callback is controlled by an option. The callback is
// a std::function and is directly usable by CallbackCalculator.
// Since the options for the packet generator include a serialized pointer
// value, the resulting callback is only valid on the original machine
// while that pointer is still alive.
class CallbackPacketCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
const auto& options = cc->Options<CallbackPacketCalculatorOptions>(); const auto& options = cc->Options<CallbackPacketCalculatorOptions>();
switch (options.type()) { switch (options.type()) {
case CallbackPacketCalculatorOptions::VECTOR_PACKET: case CallbackPacketCalculatorOptions::VECTOR_PACKET:
@ -64,9 +58,9 @@ class CallbackPacketCalculator : public CalculatorBase {
<< "Invalid type of callback to produce."; << "Invalid type of callback to produce.";
} }
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Open(CalculatorContext* cc) override { absl::Status CallbackPacketCalculator::Open(CalculatorContext* cc) {
const auto& options = cc->Options<CallbackPacketCalculatorOptions>(); const auto& options = cc->Options<CallbackPacketCalculatorOptions>();
void* ptr; void* ptr;
if (sscanf(options.pointer().c_str(), "%p", &ptr) != 1) { if (sscanf(options.pointer().c_str(), "%p", &ptr) != 1) {
@ -91,12 +85,11 @@ class CallbackPacketCalculator : public CalculatorBase {
<< "Invalid type to dump into."; << "Invalid type to dump into.";
} }
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Process(CalculatorContext* cc) override { absl::Status CallbackPacketCalculator::Process(CalculatorContext* cc) {
return absl::OkStatus(); return absl::OkStatus();
} }
};
REGISTER_CALCULATOR(CallbackPacketCalculator); REGISTER_CALCULATOR(CallbackPacketCalculator);

View File

@ -0,0 +1,39 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_INTERNAL_CALLBACK_PACKET_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_INTERNAL_CALLBACK_PACKET_CALCULATOR_H_
#include "absl/status/status.h"
#include "mediapipe/framework/calculator_base.h"
namespace mediapipe {
// Creates a callback which takes a packet and stores it either in a
// vector of packets or stores only the packet at PostStream timestamp.
// The kind of callback is controlled by an option. The callback is
// a std::function and is directly usable by CallbackCalculator.
// Since the options for the packet generator include a serialized pointer
// value, the resulting callback is only valid on the original machine
// while that pointer is still alive.
class CallbackPacketCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_INTERNAL_CALLBACK_PACKET_CALCULATOR_H_

View File

@ -87,6 +87,7 @@ cc_library(
"//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/util:time_series_util", "//mediapipe/util:time_series_util",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
@ -181,6 +182,7 @@ cc_library(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node", "//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
], ],
alwayslink = 1, alwayslink = 1,
@ -198,6 +200,7 @@ cc_test(
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"@com_google_absl//absl/log:absl_check",
"@org_tensorflow//tensorflow/lite/c:common", "@org_tensorflow//tensorflow/lite/c:common",
], ],
) )
@ -228,7 +231,6 @@ cc_library(
"//mediapipe/tasks/metadata:metadata_schema_cc", "//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
alwayslink = 1, alwayslink = 1,
@ -280,7 +282,6 @@ cc_library(
"//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils", "//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
"//mediapipe/tasks/metadata:metadata_schema_cc", "//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -476,6 +477,7 @@ cc_library(
"//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_buffer",
"//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/objc:mediapipe_framework_ios",
"//mediapipe/util/tflite:config", "//mediapipe/util/tflite:config",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate",
@ -622,6 +624,7 @@ mediapipe_proto_library(
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/gpu:gpu_origin_proto",
], ],
) )
@ -651,7 +654,13 @@ cc_library(
"//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"//mediapipe/gpu:gpu_buffer_format",
"//mediapipe/gpu:gpu_origin_cc_proto",
"//mediapipe/util:resource_util", "//mediapipe/util:resource_util",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/strings:str_format",
] + select({ ] + select({
"//mediapipe/gpu:disable_gpu": [], "//mediapipe/gpu:disable_gpu": [],
"//conditions:default": ["tensor_converter_calculator_gpu_deps"], "//conditions:default": ["tensor_converter_calculator_gpu_deps"],
@ -701,6 +710,7 @@ cc_test(
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:validate_type", "//mediapipe/framework/tool:validate_type",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
@ -739,6 +749,8 @@ cc_library(
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/framework/formats/object_detection:anchor_cc_proto", "//mediapipe/framework/formats/object_detection:anchor_cc_proto",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
] + selects.with_or({ ] + selects.with_or({
@ -795,6 +807,7 @@ cc_library(
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/log:absl_check",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -987,6 +1000,8 @@ cc_library(
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor", "//mediapipe/framework/port:statusor",
"//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/gpu:gpu_origin_cc_proto",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
] + select({ ] + select({
"//mediapipe/gpu:disable_gpu": [], "//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [":image_to_tensor_calculator_gpu_deps"], "//conditions:default": [":image_to_tensor_calculator_gpu_deps"],
@ -1079,6 +1094,7 @@ cc_test(
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/util:image_test_utils", "//mediapipe/util:image_test_utils",
"@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
@ -1206,6 +1222,7 @@ cc_library(
"//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders", "//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:shader_util", "//mediapipe/gpu:shader_util",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
}), }),

View File

@ -20,6 +20,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "absl/log/absl_check.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
@ -282,13 +283,17 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) {
if (options.has_volume_gain_db()) { if (options.has_volume_gain_db()) {
gain_ = pow(10, options.volume_gain_db() / 20.0); gain_ = pow(10, options.volume_gain_db() / 20.0);
} }
if (options.has_source_sample_rate()) {
source_sample_rate_ = options.source_sample_rate();
} else {
RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^ RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^
!kAudioIn(cc).Header().IsEmpty()) !kAudioIn(cc).Header().IsEmpty())
<< "Must either specify the time series header of the \"AUDIO\" stream " << "Must either specify the time series header of the \"AUDIO\" stream "
"or have the \"SAMPLE_RATE\" stream connected."; "or have the \"SAMPLE_RATE\" stream connected.";
if (!kAudioIn(cc).Header().IsEmpty()) { if (!kAudioIn(cc).Header().IsEmpty()) {
mediapipe::TimeSeriesHeader input_header; mediapipe::TimeSeriesHeader input_header;
MP_RETURN_IF_ERROR(mediapipe::time_series_util::FillTimeSeriesHeaderIfValid( MP_RETURN_IF_ERROR(
mediapipe::time_series_util::FillTimeSeriesHeaderIfValid(
kAudioIn(cc).Header(), &input_header)); kAudioIn(cc).Header(), &input_header));
if (stream_mode_) { if (stream_mode_) {
MP_RETURN_IF_ERROR(SetupStreamingResampler(input_header.sample_rate())); MP_RETURN_IF_ERROR(SetupStreamingResampler(input_header.sample_rate()));
@ -296,6 +301,7 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) {
source_sample_rate_ = input_header.sample_rate(); source_sample_rate_ = input_header.sample_rate();
} }
} }
}
AppendZerosToSampleBuffer(padding_samples_before_); AppendZerosToSampleBuffer(padding_samples_before_);
if (options.has_fft_size()) { if (options.has_fft_size()) {
RET_CHECK(IsValidFftSize(options.fft_size())) RET_CHECK(IsValidFftSize(options.fft_size()))
@ -343,7 +349,7 @@ absl::Status AudioToTensorCalculator::Process(CalculatorContext* cc) {
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
"The audio data should be stored in column-major."); "The audio data should be stored in column-major.");
} }
CHECK(channels_match || mono_output); ABSL_CHECK(channels_match || mono_output);
const Matrix& input = channels_match ? input_frame const Matrix& input = channels_match ? input_frame
// Mono mixdown. // Mono mixdown.
: input_frame.colwise().mean(); : input_frame.colwise().mean();
@ -452,7 +458,7 @@ absl::Status AudioToTensorCalculator::SetupStreamingResampler(
} }
void AudioToTensorCalculator::AppendZerosToSampleBuffer(int num_samples) { void AudioToTensorCalculator::AppendZerosToSampleBuffer(int num_samples) {
CHECK_GE(num_samples, 0); // Ensured by `UpdateContract`. ABSL_CHECK_GE(num_samples, 0); // Ensured by `UpdateContract`.
if (num_samples == 0) { if (num_samples == 0) {
return; return;
} }

View File

@ -85,4 +85,7 @@ message AudioToTensorCalculatorOptions {
// The volume gain, measured in dB. // The volume gain, measured in dB.
// Scale the input audio amplitude by 10^(volume_gain_db/20). // Scale the input audio amplitude by 10^(volume_gain_db/20).
optional double volume_gain_db = 12; optional double volume_gain_db = 12;
// The source number of samples per second (hertz) of the input audio buffers.
optional double source_sample_rate = 13;
} }

View File

@ -22,7 +22,6 @@
#include "absl/container/flat_hash_set.h" #include "absl/container/flat_hash_set.h"
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h" #include "absl/strings/ascii.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
@ -244,7 +243,8 @@ std::vector<Tensor> BertPreprocessorCalculator::GenerateInputTensors(
input_tensors.reserve(kNumInputTensorsForBert); input_tensors.reserve(kNumInputTensorsForBert);
for (int i = 0; i < kNumInputTensorsForBert; ++i) { for (int i = 0; i < kNumInputTensorsForBert; ++i) {
input_tensors.push_back( input_tensors.push_back(
{Tensor::ElementType::kInt32, Tensor::Shape({tensor_size})}); {Tensor::ElementType::kInt32,
Tensor::Shape({1, tensor_size}, has_dynamic_input_tensors_)});
} }
std::memcpy(input_tensors[input_ids_tensor_index_] std::memcpy(input_tensors[input_ids_tensor_index_]
.GetCpuWriteView() .GetCpuWriteView()

View File

@ -18,6 +18,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "absl/log/absl_check.h"
#include "mediapipe/calculators/tensor/feedback_tensors_calculator.pb.h" #include "mediapipe/calculators/tensor/feedback_tensors_calculator.pb.h"
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
@ -65,7 +66,7 @@ template <typename T>
Tensor MakeTensor(std::initializer_list<int> shape, Tensor MakeTensor(std::initializer_list<int> shape,
std::initializer_list<T> values) { std::initializer_list<T> values) {
Tensor tensor(TensorElementType<T>::value, shape); Tensor tensor(TensorElementType<T>::value, shape);
CHECK_EQ(values.size(), tensor.shape().num_elements()) ABSL_CHECK_EQ(values.size(), tensor.shape().num_elements())
<< "The size of `values` is incompatible with `shape`"; << "The size of `values` is incompatible with `shape`";
absl::c_copy(values, tensor.GetCpuWriteView().buffer<T>()); absl::c_copy(values, tensor.GetCpuWriteView().buffer<T>());
return tensor; return tensor;

View File

@ -16,6 +16,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "absl/log/absl_log.h"
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" #include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
#include "mediapipe/calculators/tensor/image_to_tensor_converter.h" #include "mediapipe/calculators/tensor/image_to_tensor_converter.h"
#include "mediapipe/calculators/tensor/image_to_tensor_utils.h" #include "mediapipe/calculators/tensor/image_to_tensor_utils.h"
@ -284,7 +285,7 @@ class ImageToTensorCalculator : public Node {
cc, GetBorderMode(options_.border_mode()), cc, GetBorderMode(options_.border_mode()),
GetOutputTensorType(/*uses_gpu=*/false, params_))); GetOutputTensorType(/*uses_gpu=*/false, params_)));
#else #else
LOG(FATAL) << "Cannot create image to tensor CPU converter since " ABSL_LOG(FATAL) << "Cannot create image to tensor CPU converter since "
"MEDIAPIPE_DISABLE_OPENCV is defined and " "MEDIAPIPE_DISABLE_OPENCV is defined and "
"MEDIAPIPE_ENABLE_HALIDE is not defined."; "MEDIAPIPE_ENABLE_HALIDE is not defined.";
#endif // !MEDIAPIPE_DISABLE_HALIDE #endif // !MEDIAPIPE_DISABLE_HALIDE

View File

@ -18,6 +18,7 @@
#include <vector> #include <vector>
#include "absl/flags/flag.h" #include "absl/flags/flag.h"
#include "absl/log/absl_check.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
@ -205,7 +206,7 @@ mediapipe::ImageFormat::Format GetImageFormat(int image_channels) {
} else if (image_channels == 1) { } else if (image_channels == 1) {
return ImageFormat::GRAY8; return ImageFormat::GRAY8;
} }
CHECK(false) << "Unsupported input image channles: " << image_channels; ABSL_CHECK(false) << "Unsupported input image channles: " << image_channels;
} }
Packet MakeImageFramePacket(cv::Mat input) { Packet MakeImageFramePacket(cv::Mat input) {

View File

@ -22,6 +22,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "absl/log/absl_log.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "mediapipe/calculators/tensor/image_to_tensor_converter.h" #include "mediapipe/calculators/tensor/image_to_tensor_converter.h"
#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.h" #include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.h"
@ -259,7 +260,7 @@ class GlProcessor : public ImageToTensorConverter {
// error. So in that case, we'll grab the transpose of our original matrix // error. So in that case, we'll grab the transpose of our original matrix
// and send that instead. // and send that instead.
const auto gl_context = mediapipe::GlContext::GetCurrent(); const auto gl_context = mediapipe::GlContext::GetCurrent();
LOG_IF(FATAL, !gl_context) << "GlContext is not bound to the thread."; ABSL_LOG_IF(FATAL, !gl_context) << "GlContext is not bound to the thread.";
if (gl_context->GetGlVersion() == mediapipe::GlVersion::kGLES2) { if (gl_context->GetGlVersion() == mediapipe::GlVersion::kGLES2) {
GetTransposedRotatedSubRectToRectTransformMatrix( GetTransposedRotatedSubRectToRectTransformMatrix(
sub_rect, texture.width(), texture.height(), flip_horizontaly, sub_rect, texture.width(), texture.height(), flip_horizontaly,

View File

@ -88,6 +88,20 @@ message InferenceCalculatorOptions {
// serialized model is invalid or missing. // serialized model is invalid or missing.
optional string serialized_model_dir = 7; optional string serialized_model_dir = 7;
enum CacheWritingBehavior {
// Do not write any caches.
NO_WRITE = 0;
// Try to write caches, log on failure.
TRY_WRITE = 1;
// Write caches or return an error if write fails.
WRITE_OR_ERROR = 2;
}
// Specifies how GPU caches are written to disk.
optional CacheWritingBehavior cache_writing_behavior = 10
[default = WRITE_OR_ERROR];
// Unique token identifying the model. Used in conjunction with // Unique token identifying the model. Used in conjunction with
// "serialized_model_dir". It is the caller's responsibility to ensure // "serialized_model_dir". It is the caller's responsibility to ensure
// there is no clash of the tokens. // there is no clash of the tokens.

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <cstdint>
#include <cstring> #include <cstring>
#include <memory> #include <memory>
#include <string> #include <string>
@ -26,6 +27,7 @@
#include "mediapipe/util/tflite/tflite_gpu_runner.h" #include "mediapipe/util/tflite/tflite_gpu_runner.h"
#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS) #if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)
#include "absl/log/absl_log.h"
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/util/android/file/base/file.h" #include "mediapipe/util/android/file/base/file.h"
#include "mediapipe/util/android/file/base/filesystem.h" #include "mediapipe/util/android/file/base/filesystem.h"
@ -68,13 +70,21 @@ class InferenceCalculatorGlAdvancedImpl
const mediapipe::InferenceCalculatorOptions::Delegate::Gpu& const mediapipe::InferenceCalculatorOptions::Delegate::Gpu&
gpu_delegate_options); gpu_delegate_options);
absl::Status ReadGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const; absl::Status ReadGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
absl::Status SaveGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const; // Writes caches to disk based on |cache_writing_behavior_|.
absl::Status SaveGpuCachesBasedOnBehavior(
tflite::gpu::TFLiteGPURunner* gpu_runner) const;
bool UseSerializedModel() const { return use_serialized_model_; }
private: private:
// Writes caches to disk, returns error on failure.
absl::Status SaveGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
bool use_kernel_caching_ = false; bool use_kernel_caching_ = false;
std::string cached_kernel_filename_; std::string cached_kernel_filename_;
bool use_serialized_model_ = false; bool use_serialized_model_ = false;
std::string serialized_model_path_; std::string serialized_model_path_;
mediapipe::InferenceCalculatorOptions::Delegate::Gpu::CacheWritingBehavior
cache_writing_behavior_;
}; };
// Helper class that wraps everything related to GPU inference acceleration. // Helper class that wraps everything related to GPU inference acceleration.
@ -150,8 +160,6 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
} }
absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Close() { absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Close() {
MP_RETURN_IF_ERROR(
on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get()));
return gpu_helper_.RunInGlContext([this]() -> absl::Status { return gpu_helper_.RunInGlContext([this]() -> absl::Status {
tflite_gpu_runner_.reset(); tflite_gpu_runner_.reset();
return absl::OkStatus(); return absl::OkStatus();
@ -226,9 +234,15 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::InitTFLiteGPURunner(
tflite_gpu_runner_->GetOutputShapes()[i].c}; tflite_gpu_runner_->GetOutputShapes()[i].c};
} }
if (on_disk_cache_helper_.UseSerializedModel()) {
tflite_gpu_runner_->ForceOpenCLInitFromSerializedModel();
}
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
on_disk_cache_helper_.ReadGpuCaches(tflite_gpu_runner_.get())); on_disk_cache_helper_.ReadGpuCaches(tflite_gpu_runner_.get()));
return tflite_gpu_runner_->Build(); MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build());
return on_disk_cache_helper_.SaveGpuCachesBasedOnBehavior(
tflite_gpu_runner_.get());
} }
#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS) #if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)
@ -257,9 +271,36 @@ absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::Init(
mediapipe::file::JoinPath(gpu_delegate_options.serialized_model_dir(), mediapipe::file::JoinPath(gpu_delegate_options.serialized_model_dir(),
gpu_delegate_options.model_token()); gpu_delegate_options.model_token());
} }
cache_writing_behavior_ = gpu_delegate_options.has_cache_writing_behavior()
? gpu_delegate_options.cache_writing_behavior()
: mediapipe::InferenceCalculatorOptions::
Delegate::Gpu::WRITE_OR_ERROR;
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::
SaveGpuCachesBasedOnBehavior(
tflite::gpu::TFLiteGPURunner* gpu_runner) const {
switch (cache_writing_behavior_) {
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::NO_WRITE:
return absl::OkStatus();
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::TRY_WRITE: {
auto status = SaveGpuCaches(gpu_runner);
if (!status.ok()) {
ABSL_LOG_FIRST_N(WARNING, 1) << "Failed to save gpu caches: " << status;
}
return absl::OkStatus();
}
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::WRITE_OR_ERROR:
return SaveGpuCaches(gpu_runner);
default:
ABSL_LOG_FIRST_N(ERROR, 1)
<< "Unknown cache writing behavior: "
<< static_cast<uint32_t>(cache_writing_behavior_);
return absl::InvalidArgumentError("Unknown cache writing behavior.");
}
}
absl::Status absl::Status
InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::SaveGpuCaches( InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::SaveGpuCaches(
tflite::gpu::TFLiteGPURunner* gpu_runner) const { tflite::gpu::TFLiteGPURunner* gpu_runner) const {

View File

@ -21,6 +21,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "absl/log/absl_log.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "mediapipe/calculators/tensor/inference_calculator.h" #include "mediapipe/calculators/tensor/inference_calculator.h"
@ -74,7 +75,7 @@ tflite::gpu::BHWC BhwcFromTensorShape(const Tensor::Shape& shape) {
break; break;
default: default:
// Handles 0 and >4. // Handles 0 and >4.
LOG(FATAL) ABSL_LOG(FATAL)
<< "Dimensions size must be in range [1,4] for GPU inference, but " << "Dimensions size must be in range [1,4] for GPU inference, but "
<< shape.dims.size() << " is provided"; << shape.dims.size() << " is provided";
} }

View File

@ -16,7 +16,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "absl/log/check.h" #include "absl/log/absl_check.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h" #include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"

View File

@ -96,6 +96,19 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
CalculatorContext* cc, const std::vector<Tensor>& input_tensors) { CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
// Read CPU input into tensors. // Read CPU input into tensors.
RET_CHECK_EQ(interpreter_->inputs().size(), input_tensors.size()); RET_CHECK_EQ(interpreter_->inputs().size(), input_tensors.size());
// If the input tensors have dynamic shape, then the tensors need to be
// resized and reallocated before we can copy the tensor values.
bool resized_tensor_shapes = false;
for (int i = 0; i < input_tensors.size(); ++i) {
if (input_tensors[i].shape().is_dynamic) {
interpreter_->ResizeInputTensorStrict(i, input_tensors[i].shape().dims);
resized_tensor_shapes = true;
}
}
// Reallocation is needed for memory sanity.
if (resized_tensor_shapes) interpreter_->AllocateTensors();
for (int i = 0; i < input_tensors.size(); ++i) { for (int i = 0; i < input_tensors.size(); ++i) {
const TfLiteType input_tensor_type = const TfLiteType input_tensor_type =
interpreter_->tensor(interpreter_->inputs()[i])->type; interpreter_->tensor(interpreter_->inputs()[i])->type;

View File

@ -20,7 +20,6 @@
#include <vector> #include <vector>
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h" #include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h"
#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/api2/port.h"
@ -161,7 +160,7 @@ absl::Status RegexPreprocessorCalculator::Process(CalculatorContext* cc) {
// not found in the tokenizer vocab. // not found in the tokenizer vocab.
std::vector<Tensor> result; std::vector<Tensor> result;
result.push_back( result.push_back(
{Tensor::ElementType::kInt32, Tensor::Shape({max_seq_len_})}); {Tensor::ElementType::kInt32, Tensor::Shape({1, max_seq_len_})});
std::memcpy(result[0].GetCpuWriteView().buffer<int32_t>(), std::memcpy(result[0].GetCpuWriteView().buffer<int32_t>(),
input_tokens.data(), input_tokens.size() * sizeof(int32_t)); input_tokens.data(), input_tokens.size() * sizeof(int32_t));
kTensorsOut(cc).Send(std::move(result)); kTensorsOut(cc).Send(std::move(result));

View File

@ -12,9 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <cstdint>
#include <string> #include <string>
#include <vector> #include <vector>
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "mediapipe/calculators/tensor/tensor_converter_calculator.pb.h" #include "mediapipe/calculators/tensor/tensor_converter_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
@ -22,7 +27,8 @@
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port.h" #include "mediapipe/framework/port.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/resource_util.h" #include "mediapipe/gpu/gpu_buffer_format.h"
#include "mediapipe/gpu/gpu_origin.pb.h"
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_buffer.h"
@ -43,12 +49,36 @@
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
namespace { namespace {
constexpr int kWorkgroupSize = 8; // Block size for GPU shader. constexpr int kWorkgroupSize = 8; // Block size for GPU shader.
// Commonly used to compute the number of blocks to launch in a kernel. // Commonly used to compute the number of blocks to launch in a kernel.
int NumGroups(const int size, const int group_size) { // NOLINT int NumGroups(const int size, const int group_size) { // NOLINT
return (size + group_size - 1) / group_size; return (size + group_size - 1) / group_size;
} }
absl::StatusOr<bool> ShouldFlipVertically(
const mediapipe::TensorConverterCalculatorOptions& options) {
if (!options.has_gpu_origin()) {
return options.flip_vertically();
}
switch (options.gpu_origin()) {
case mediapipe::GpuOrigin::TOP_LEFT:
return false;
case mediapipe::GpuOrigin::DEFAULT:
case mediapipe::GpuOrigin::CONVENTIONAL:
// TOP_LEFT on Metal, BOTTOM_LEFT on OpenGL.
#ifdef __APPLE__
return false;
#else
return true;
#endif
}
return absl::InvalidArgumentError(
absl::StrFormat("Unhandled GPU origin %i", options.gpu_origin()));
}
typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
RowMajorMatrixXf; RowMajorMatrixXf;
typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor> typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>
@ -58,6 +88,7 @@ constexpr char kImageFrameTag[] = "IMAGE";
constexpr char kGpuBufferTag[] = "IMAGE_GPU"; constexpr char kGpuBufferTag[] = "IMAGE_GPU";
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
constexpr char kMatrixTag[] = "MATRIX"; constexpr char kMatrixTag[] = "MATRIX";
} // namespace } // namespace
namespace mediapipe { namespace mediapipe {
@ -378,16 +409,27 @@ absl::Status TensorConverterCalculator::InitGpu(CalculatorContext* cc) {
// Get input image sizes. // Get input image sizes.
const auto& input = const auto& input =
cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>(); cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
mediapipe::ImageFormat::Format format = mediapipe::GpuBufferFormat format = input.format();
mediapipe::ImageFormatForGpuBufferFormat(input.format());
const bool include_alpha = (max_num_channels_ == 4); const bool include_alpha = (max_num_channels_ == 4);
const bool single_channel = (max_num_channels_ == 1); const bool single_channel = (max_num_channels_ == 1);
if (!(format == mediapipe::ImageFormat::GRAY8 ||
format == mediapipe::ImageFormat::SRGB || RET_CHECK(format == mediapipe::GpuBufferFormat::kBGRA32 ||
format == mediapipe::ImageFormat::SRGBA)) format == mediapipe::GpuBufferFormat::kRGB24 ||
RET_CHECK_FAIL() << "Unsupported GPU input format."; format == mediapipe::GpuBufferFormat::kRGBA32 ||
if (include_alpha && (format != mediapipe::ImageFormat::SRGBA)) format == mediapipe::GpuBufferFormat::kRGBAFloat128 ||
RET_CHECK_FAIL() << "Num input channels is less than desired output."; format == mediapipe::GpuBufferFormat::kRGBAHalf64 ||
format == mediapipe::GpuBufferFormat::kGrayFloat32 ||
format == mediapipe::GpuBufferFormat::kGrayHalf16 ||
format == mediapipe::GpuBufferFormat::kOneComponent8)
<< "Unsupported GPU input format: " << static_cast<uint32_t>(format);
if (include_alpha) {
RET_CHECK(format == mediapipe::GpuBufferFormat::kBGRA32 ||
format == mediapipe::GpuBufferFormat::kRGBA32 ||
format == mediapipe::GpuBufferFormat::kRGBAFloat128 ||
format == mediapipe::GpuBufferFormat::kRGBAHalf64)
<< "Num input channels is less than desired output, input format: "
<< static_cast<uint32_t>(format);
}
#if MEDIAPIPE_METAL_ENABLED #if MEDIAPIPE_METAL_ENABLED
id<MTLDevice> device = gpu_helper_.mtlDevice; id<MTLDevice> device = gpu_helper_.mtlDevice;
@ -582,7 +624,7 @@ absl::Status TensorConverterCalculator::LoadOptions(CalculatorContext* cc) {
if (options.has_output_tensor_float_range()) { if (options.has_output_tensor_float_range()) {
output_range_.emplace(options.output_tensor_float_range().min(), output_range_.emplace(options.output_tensor_float_range().min(),
options.output_tensor_float_range().max()); options.output_tensor_float_range().max());
CHECK_GT(output_range_->second, output_range_->first); ABSL_CHECK_GT(output_range_->second, output_range_->first);
} }
// Custom div and sub values. // Custom div and sub values.
@ -593,16 +635,16 @@ absl::Status TensorConverterCalculator::LoadOptions(CalculatorContext* cc) {
} }
// Get y-flip mode. // Get y-flip mode.
flip_vertically_ = options.flip_vertically(); ASSIGN_OR_RETURN(flip_vertically_, ShouldFlipVertically(options));
// Get row_major_matrix mode. // Get row_major_matrix mode.
row_major_matrix_ = options.row_major_matrix(); row_major_matrix_ = options.row_major_matrix();
// Get desired way to handle input channels. // Get desired way to handle input channels.
max_num_channels_ = options.max_num_channels(); max_num_channels_ = options.max_num_channels();
CHECK_GE(max_num_channels_, 1); ABSL_CHECK_GE(max_num_channels_, 1);
CHECK_LE(max_num_channels_, 4); ABSL_CHECK_LE(max_num_channels_, 4);
CHECK_NE(max_num_channels_, 2); ABSL_CHECK_NE(max_num_channels_, 2);
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -3,6 +3,7 @@ syntax = "proto2";
package mediapipe; package mediapipe;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/gpu/gpu_origin.proto";
// Full Example: // Full Example:
// //
@ -43,8 +44,14 @@ message TensorConverterCalculatorOptions {
// with a coordinate system where the origin is at the bottom-left corner // with a coordinate system where the origin is at the bottom-left corner
// (e.g., in OpenGL) whereas the ML model expects an image with a top-left // (e.g., in OpenGL) whereas the ML model expects an image with a top-left
// origin. // origin.
// Prefer gpu_origin over this field.
optional bool flip_vertically = 2 [default = false]; optional bool flip_vertically = 2 [default = false];
// Determines when the input image should be flipped vertically.
// See GpuOrigin.Mode for more information.
// If unset, falls back to flip_vertically for backwards compatibility.
optional GpuOrigin.Mode gpu_origin = 10;
// Controls how many channels of the input image get passed through to the // Controls how many channels of the input image get passed through to the
// tensor. Valid values are 1,3,4 only. Ignored for iOS GPU. // tensor. Valid values are 1,3,4 only. Ignored for iOS GPU.
optional int32 max_num_channels = 3 [default = 3]; optional int32 max_num_channels = 3 [default = 3];

View File

@ -12,7 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <cstdint>
#include <memory>
#include <random> #include <random>
#include <utility>
#include <vector> #include <vector>
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
@ -24,8 +27,10 @@
#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" // NOLINT #include "mediapipe/framework/port/status_matchers.h" // NOLINT
#include "mediapipe/framework/tool/validate_type.h" #include "mediapipe/framework/tool/validate_type.h"
@ -40,7 +45,6 @@ constexpr char kTransposeOptionsString[] =
} // namespace } // namespace
using RandomEngine = std::mt19937_64; using RandomEngine = std::mt19937_64;
using testing::Eq;
const uint32_t kSeed = 1234; const uint32_t kSeed = 1234;
const int kNumSizes = 8; const int kNumSizes = 8;
const int sizes[kNumSizes][2] = {{1, 1}, {12, 1}, {1, 9}, {2, 2}, const int sizes[kNumSizes][2] = {{1, 1}, {12, 1}, {1, 9}, {2, 2},
@ -110,12 +114,12 @@ TEST_F(TensorConverterCalculatorTest, RandomMatrixColMajor) {
// Wait until the calculator done processing. // Wait until the calculator done processing.
MP_ASSERT_OK(graph_->WaitUntilIdle()); MP_ASSERT_OK(graph_->WaitUntilIdle());
EXPECT_EQ(1, output_packets.size()); ASSERT_EQ(output_packets.size(), 1);
// Get and process results. // Get and process results.
const std::vector<Tensor>& tensor_vec = const std::vector<Tensor>& tensor_vec =
output_packets[0].Get<std::vector<Tensor>>(); output_packets[0].Get<std::vector<Tensor>>();
EXPECT_EQ(1, tensor_vec.size()); ASSERT_EQ(tensor_vec.size(), 1);
const Tensor* tensor = &tensor_vec[0]; const Tensor* tensor = &tensor_vec[0];
EXPECT_EQ(Tensor::ElementType::kFloat32, tensor->element_type()); EXPECT_EQ(Tensor::ElementType::kFloat32, tensor->element_type());
@ -127,7 +131,7 @@ TEST_F(TensorConverterCalculatorTest, RandomMatrixColMajor) {
auto tensor_buffer = view.buffer<float>(); auto tensor_buffer = view.buffer<float>();
for (int i = 0; i < num_rows * num_columns; ++i) { for (int i = 0; i < num_rows * num_columns; ++i) {
const float expected = uniform_dist(random); const float expected = uniform_dist(random);
EXPECT_EQ(expected, tensor_buffer[i]) << "at i = " << i; EXPECT_FLOAT_EQ(tensor_buffer[i], expected) << "at i = " << i;
} }
// Fully close graph at end, otherwise calculator+tensors are destroyed // Fully close graph at end, otherwise calculator+tensors are destroyed
@ -172,12 +176,12 @@ TEST_F(TensorConverterCalculatorTest, RandomMatrixRowMajor) {
// Wait until the calculator done processing. // Wait until the calculator done processing.
MP_ASSERT_OK(graph_->WaitUntilIdle()); MP_ASSERT_OK(graph_->WaitUntilIdle());
EXPECT_EQ(1, output_packets.size()); ASSERT_EQ(output_packets.size(), 1);
// Get and process results. // Get and process results.
const std::vector<Tensor>& tensor_vec = const std::vector<Tensor>& tensor_vec =
output_packets[0].Get<std::vector<Tensor>>(); output_packets[0].Get<std::vector<Tensor>>();
EXPECT_EQ(1, tensor_vec.size()); ASSERT_EQ(tensor_vec.size(), 1);
const Tensor* tensor = &tensor_vec[0]; const Tensor* tensor = &tensor_vec[0];
EXPECT_EQ(Tensor::ElementType::kFloat32, tensor->element_type()); EXPECT_EQ(Tensor::ElementType::kFloat32, tensor->element_type());
@ -189,7 +193,7 @@ TEST_F(TensorConverterCalculatorTest, RandomMatrixRowMajor) {
auto tensor_buffer = view.buffer<float>(); auto tensor_buffer = view.buffer<float>();
for (int i = 0; i < num_rows * num_columns; ++i) { for (int i = 0; i < num_rows * num_columns; ++i) {
const float expected = uniform_dist(random); const float expected = uniform_dist(random);
EXPECT_EQ(expected, tensor_buffer[i]) << "at i = " << i; EXPECT_EQ(tensor_buffer[i], expected) << "at i = " << i;
} }
// Fully close graph at end, otherwise calculator+tensors are destroyed // Fully close graph at end, otherwise calculator+tensors are destroyed
@ -239,12 +243,12 @@ TEST_F(TensorConverterCalculatorTest, CustomDivAndSub) {
// Get and process results. // Get and process results.
const std::vector<Tensor>& tensor_vec = const std::vector<Tensor>& tensor_vec =
output_packets[0].Get<std::vector<Tensor>>(); output_packets[0].Get<std::vector<Tensor>>();
EXPECT_EQ(1, tensor_vec.size()); ASSERT_EQ(tensor_vec.size(), 1);
const Tensor* tensor = &tensor_vec[0]; const Tensor* tensor = &tensor_vec[0];
EXPECT_EQ(Tensor::ElementType::kFloat32, tensor->element_type()); EXPECT_EQ(Tensor::ElementType::kFloat32, tensor->element_type());
auto view = tensor->GetCpuReadView(); auto view = tensor->GetCpuReadView();
EXPECT_FLOAT_EQ(67.0f, *view.buffer<float>()); EXPECT_FLOAT_EQ(*view.buffer<float>(), 67.0f);
// Fully close graph at end, otherwise calculator+tensors are destroyed // Fully close graph at end, otherwise calculator+tensors are destroyed
// after calling WaitUntilDone(). // after calling WaitUntilDone().
@ -259,8 +263,8 @@ TEST_F(TensorConverterCalculatorTest, SetOutputRange) {
for (std::pair<float, float> range : range_values) { for (std::pair<float, float> range : range_values) {
CalculatorGraph graph; CalculatorGraph graph;
CalculatorGraphConfig graph_config = CalculatorGraphConfig graph_config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>( mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(absl::Substitute(
absl::Substitute(R"( R"pb(
input_stream: "input_image" input_stream: "input_image"
node { node {
calculator: "TensorConverterCalculator" calculator: "TensorConverterCalculator"
@ -268,14 +272,11 @@ TEST_F(TensorConverterCalculatorTest, SetOutputRange) {
output_stream: "TENSORS:tensor" output_stream: "TENSORS:tensor"
options { options {
[mediapipe.TensorConverterCalculatorOptions.ext] { [mediapipe.TensorConverterCalculatorOptions.ext] {
output_tensor_float_range { output_tensor_float_range { min: $0 max: $1 }
min: $0
max: $1
} }
} }
} }
} )pb",
)",
/*$0=*/range.first, /*$0=*/range.first,
/*$1=*/range.second)); /*$1=*/range.second));
std::vector<Packet> output_packets; std::vector<Packet> output_packets;
@ -292,26 +293,23 @@ TEST_F(TensorConverterCalculatorTest, SetOutputRange) {
// Wait until the calculator finishes processing. // Wait until the calculator finishes processing.
MP_ASSERT_OK(graph.WaitUntilIdle()); MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets.size(), Eq(1)); ASSERT_EQ(output_packets.size(), 1);
// Get and process results. // Get and process results.
const std::vector<Tensor>& tensor_vec = const std::vector<Tensor>& tensor_vec =
output_packets[0].Get<std::vector<Tensor>>(); output_packets[0].Get<std::vector<Tensor>>();
EXPECT_THAT(tensor_vec.size(), Eq(1)); ASSERT_EQ(tensor_vec.size(), 1);
const Tensor* tensor = &tensor_vec[0]; const Tensor* tensor = &tensor_vec[0];
// Calculate the expected normalized value: // Calculate the expected normalized value:
float normalized_value = float expected_value =
range.first + (200 * (range.second - range.first)) / 255.0; range.first + (200 * (range.second - range.first)) / 255.0;
EXPECT_THAT(tensor->element_type(), Eq(Tensor::ElementType::kFloat32)); EXPECT_EQ(tensor->element_type(), Tensor::ElementType::kFloat32);
auto view = tensor->GetCpuReadView(); auto view = tensor->GetCpuReadView();
float dataf = *view.buffer<float>(); float actual_value = *view.buffer<float>();
EXPECT_THAT( EXPECT_FLOAT_EQ(actual_value, expected_value);
normalized_value,
testing::FloatNear(dataf, 2.0f * std::abs(dataf) *
std::numeric_limits<float>::epsilon()));
// Fully close graph at end, otherwise calculator+tensors are destroyed // Fully close graph at end, otherwise calculator+tensors are destroyed
// after calling WaitUntilDone(). // after calling WaitUntilDone().
@ -320,4 +318,113 @@ TEST_F(TensorConverterCalculatorTest, SetOutputRange) {
} }
} }
TEST_F(TensorConverterCalculatorTest, FlipVertically) {
CalculatorGraph graph;
CalculatorGraphConfig graph_config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input_image"
node {
calculator: "TensorConverterCalculator"
input_stream: "IMAGE:input_image"
output_stream: "TENSORS:tensor"
options {
[mediapipe.TensorConverterCalculatorOptions.ext] {
flip_vertically: true
output_tensor_float_range { min: 0 max: 255 }
}
}
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("tensor", &graph_config, &output_packets);
// Run the graph.
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.StartRun({}));
auto input_image = absl::make_unique<ImageFrame>(ImageFormat::GRAY8, 1, 2);
cv::Mat mat = mediapipe::formats::MatView(input_image.get());
constexpr uint8_t kY0Value = 100;
constexpr uint8_t kY1Value = 200;
mat.at<uint8_t>(0, 0) = kY0Value;
mat.at<uint8_t>(1, 0) = kY1Value; // Note: y, x!
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_image", Adopt(input_image.release()).At(Timestamp(0))));
// Wait until the calculator finishes processing.
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(output_packets.size(), 1);
// Get and process results.
const std::vector<Tensor>& tensor_vec =
output_packets[0].Get<std::vector<Tensor>>();
ASSERT_EQ(tensor_vec.size(), 1);
const Tensor* tensor = &tensor_vec[0];
EXPECT_EQ(tensor->element_type(), Tensor::ElementType::kFloat32);
const float* dataf = tensor->GetCpuReadView().buffer<float>();
EXPECT_EQ(static_cast<int>(roundf(dataf[0])), kY1Value); // Y0, Y1 flipped!
EXPECT_EQ(static_cast<int>(roundf(dataf[1])), kY0Value);
// Fully close graph at end, otherwise calculator+tensors are destroyed
// after calling WaitUntilDone().
MP_ASSERT_OK(graph.CloseInputStream("input_image"));
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST_F(TensorConverterCalculatorTest, GpuOriginOverridesFlipVertically) {
CalculatorGraph graph;
CalculatorGraphConfig graph_config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input_image"
node {
calculator: "TensorConverterCalculator"
input_stream: "IMAGE:input_image"
output_stream: "TENSORS:tensor"
options {
[mediapipe.TensorConverterCalculatorOptions.ext] {
flip_vertically: true
gpu_origin: TOP_LEFT
output_tensor_float_range { min: 0 max: 255 }
}
}
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("tensor", &graph_config, &output_packets);
// Run the graph.
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.StartRun({}));
auto input_image = absl::make_unique<ImageFrame>(ImageFormat::GRAY8, 1, 2);
cv::Mat mat = mediapipe::formats::MatView(input_image.get());
constexpr uint8_t kY0Value = 100;
constexpr uint8_t kY1Value = 200;
mat.at<uint8_t>(0, 0) = kY0Value;
mat.at<uint8_t>(1, 0) = kY1Value; // Note: y, x!
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_image", Adopt(input_image.release()).At(Timestamp(0))));
// Wait until the calculator finishes processing.
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(output_packets.size(), 1);
// Get and process results.
const std::vector<Tensor>& tensor_vec =
output_packets[0].Get<std::vector<Tensor>>();
ASSERT_EQ(tensor_vec.size(), 1);
const Tensor* tensor = &tensor_vec[0];
EXPECT_EQ(tensor->element_type(), Tensor::ElementType::kFloat32);
const float* dataf = tensor->GetCpuReadView().buffer<float>();
EXPECT_EQ(static_cast<int>(roundf(dataf[0])), kY0Value); // Not flipped!
EXPECT_EQ(static_cast<int>(roundf(dataf[1])), kY1Value);
// Fully close graph at end, otherwise calculator+tensors are destroyed
// after calling WaitUntilDone().
MP_ASSERT_OK(graph.CloseInputStream("input_image"));
MP_ASSERT_OK(graph.WaitUntilDone());
}
} // namespace mediapipe } // namespace mediapipe

View File

@ -15,6 +15,7 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "absl/log/absl_log.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h" #include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
@ -83,7 +84,7 @@ void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes,
void ConvertAnchorsToRawValues(const std::vector<Anchor>& anchors, void ConvertAnchorsToRawValues(const std::vector<Anchor>& anchors,
int num_boxes, float* raw_anchors) { int num_boxes, float* raw_anchors) {
CHECK_EQ(anchors.size(), num_boxes); ABSL_CHECK_EQ(anchors.size(), num_boxes);
int box = 0; int box = 0;
for (const auto& anchor : anchors) { for (const auto& anchor : anchors) {
raw_anchors[box * kNumCoordsPerBox + 0] = anchor.y_center(); raw_anchors[box * kNumCoordsPerBox + 0] = anchor.y_center();
@ -256,6 +257,7 @@ class TensorsToDetectionsCalculator : public Node {
bool gpu_inited_ = false; bool gpu_inited_ = false;
bool gpu_input_ = false; bool gpu_input_ = false;
bool gpu_has_enough_work_groups_ = true;
bool anchors_init_ = false; bool anchors_init_ = false;
}; };
MEDIAPIPE_REGISTER_NODE(TensorsToDetectionsCalculator); MEDIAPIPE_REGISTER_NODE(TensorsToDetectionsCalculator);
@ -291,7 +293,7 @@ absl::Status TensorsToDetectionsCalculator::Open(CalculatorContext* cc) {
absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) { absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) {
auto output_detections = absl::make_unique<std::vector<Detection>>(); auto output_detections = absl::make_unique<std::vector<Detection>>();
bool gpu_processing = false; bool gpu_processing = false;
if (CanUseGpu()) { if (CanUseGpu() && gpu_has_enough_work_groups_) {
// Use GPU processing only if at least one input tensor is already on GPU // Use GPU processing only if at least one input tensor is already on GPU
// (to avoid CPU->GPU overhead). // (to avoid CPU->GPU overhead).
for (const auto& tensor : *kInTensors(cc)) { for (const auto& tensor : *kInTensors(cc)) {
@ -321,11 +323,20 @@ absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) {
RET_CHECK(!has_custom_box_indices_); RET_CHECK(!has_custom_box_indices_);
} }
if (gpu_processing) { if (gpu_processing && !gpu_inited_) {
if (!gpu_inited_) { auto status = GpuInit(cc);
MP_RETURN_IF_ERROR(GpuInit(cc)); if (status.ok()) {
gpu_inited_ = true; gpu_inited_ = true;
} else if (status.code() == absl::StatusCode::kFailedPrecondition) {
// For initialization error because of hardware limitation, fallback to
// CPU processing.
ABSL_LOG(WARNING) << status.message();
} else {
// For other error, let the error propagates.
return status;
} }
}
if (gpu_processing && gpu_inited_) {
MP_RETURN_IF_ERROR(ProcessGPU(cc, output_detections.get())); MP_RETURN_IF_ERROR(ProcessGPU(cc, output_detections.get()));
} else { } else {
MP_RETURN_IF_ERROR(ProcessCPU(cc, output_detections.get())); MP_RETURN_IF_ERROR(ProcessCPU(cc, output_detections.get()));
@ -346,17 +357,41 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
// TODO: Add flexible input tensor size handling. // TODO: Add flexible input tensor size handling.
auto raw_box_tensor = auto raw_box_tensor =
&input_tensors[tensor_mapping_.detections_tensor_index()]; &input_tensors[tensor_mapping_.detections_tensor_index()];
RET_CHECK_EQ(raw_box_tensor->shape().dims.size(), 3);
RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
RET_CHECK_GT(num_boxes_, 0) << "Please set num_boxes in calculator options"; RET_CHECK_GT(num_boxes_, 0) << "Please set num_boxes in calculator options";
if (raw_box_tensor->shape().dims.size() == 3) {
// The tensors from CPU inference has dim 3.
RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_); RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_);
RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_); RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_);
} else if (raw_box_tensor->shape().dims.size() == 4) {
// The tensors from GPU inference has dim 4. For gpu-cpu fallback support,
// we allow tensors with 4 dims.
RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
RET_CHECK_EQ(raw_box_tensor->shape().dims[1], 1);
RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_boxes_);
RET_CHECK_EQ(raw_box_tensor->shape().dims[3], num_coords_);
} else {
return absl::InvalidArgumentError(
"The dimensions of box Tensor must be 3 or 4.");
}
auto raw_score_tensor = auto raw_score_tensor =
&input_tensors[tensor_mapping_.scores_tensor_index()]; &input_tensors[tensor_mapping_.scores_tensor_index()];
RET_CHECK_EQ(raw_score_tensor->shape().dims.size(), 3); if (raw_score_tensor->shape().dims.size() == 3) {
// The tensors from CPU inference has dim 3.
RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1); RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
RET_CHECK_EQ(raw_score_tensor->shape().dims[1], num_boxes_); RET_CHECK_EQ(raw_score_tensor->shape().dims[1], num_boxes_);
RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_classes_); RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_classes_);
} else if (raw_score_tensor->shape().dims.size() == 4) {
// The tensors from GPU inference has dim 4. For gpu-cpu fallback support,
// we allow tensors with 4 dims.
RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
RET_CHECK_EQ(raw_score_tensor->shape().dims[1], 1);
RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_boxes_);
RET_CHECK_EQ(raw_score_tensor->shape().dims[3], num_classes_);
} else {
return absl::InvalidArgumentError(
"The dimensions of score Tensor must be 3 or 4.");
}
auto raw_box_view = raw_box_tensor->GetCpuReadView(); auto raw_box_view = raw_box_tensor->GetCpuReadView();
auto raw_boxes = raw_box_view.buffer<float>(); auto raw_boxes = raw_box_view.buffer<float>();
auto raw_scores_view = raw_score_tensor->GetCpuReadView(); auto raw_scores_view = raw_score_tensor->GetCpuReadView();
@ -634,7 +669,7 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
output_detections)); output_detections));
#else #else
LOG(ERROR) << "GPU input on non-Android not supported yet."; ABSL_LOG(ERROR) << "GPU input on non-Android not supported yet.";
#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) #endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
return absl::OkStatus(); return absl::OkStatus();
} }
@ -669,16 +704,16 @@ absl::Status TensorsToDetectionsCalculator::LoadOptions(CalculatorContext* cc) {
num_boxes_ = options_.num_boxes(); num_boxes_ = options_.num_boxes();
num_coords_ = options_.num_coords(); num_coords_ = options_.num_coords();
box_output_format_ = GetBoxFormat(options_); box_output_format_ = GetBoxFormat(options_);
CHECK_NE(options_.max_results(), 0) ABSL_CHECK_NE(options_.max_results(), 0)
<< "The maximum number of the top-scored detection results must be " << "The maximum number of the top-scored detection results must be "
"non-zero."; "non-zero.";
max_results_ = options_.max_results(); max_results_ = options_.max_results();
// Currently only support 2D when num_values_per_keypoint equals to 2. // Currently only support 2D when num_values_per_keypoint equals to 2.
CHECK_EQ(options_.num_values_per_keypoint(), 2); ABSL_CHECK_EQ(options_.num_values_per_keypoint(), 2);
// Check if the output size is equal to the requested boxes and keypoints. // Check if the output size is equal to the requested boxes and keypoints.
CHECK_EQ(options_.num_keypoints() * options_.num_values_per_keypoint() + ABSL_CHECK_EQ(options_.num_keypoints() * options_.num_values_per_keypoint() +
kNumCoordsPerBox, kNumCoordsPerBox,
num_coords_); num_coords_);
@ -1111,15 +1146,21 @@ void main() {
int max_wg_size; // typically <= 1024 int max_wg_size; // typically <= 1024
glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1, glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1,
&max_wg_size); // y-dim &max_wg_size); // y-dim
CHECK_LT(num_classes_, max_wg_size) gpu_has_enough_work_groups_ = num_classes_ < max_wg_size;
<< "# classes must be < " << max_wg_size; if (!gpu_has_enough_work_groups_) {
return absl::FailedPreconditionError(absl::StrFormat(
"Hardware limitation: Processing will be done on CPU, because "
"num_classes %d exceeds the max work_group size %d.",
num_classes_, max_wg_size));
}
// TODO support better filtering. // TODO support better filtering.
if (class_index_set_.is_allowlist) { if (class_index_set_.is_allowlist) {
CHECK_EQ(class_index_set_.values.size(), ABSL_CHECK_EQ(class_index_set_.values.size(),
IsClassIndexAllowed(0) ? num_classes_ : num_classes_ - 1) IsClassIndexAllowed(0) ? num_classes_ : num_classes_ - 1)
<< "Only all classes >= class 0 or >= class 1"; << "Only all classes >= class 0 or >= class 1";
} else { } else {
CHECK_EQ(class_index_set_.values.size(), IsClassIndexAllowed(0) ? 0 : 1) ABSL_CHECK_EQ(class_index_set_.values.size(),
IsClassIndexAllowed(0) ? 0 : 1)
<< "Only ignore class 0 is allowed"; << "Only ignore class 0 is allowed";
} }
@ -1340,11 +1381,12 @@ kernel void scoreKernel(
// TODO support better filtering. // TODO support better filtering.
if (class_index_set_.is_allowlist) { if (class_index_set_.is_allowlist) {
CHECK_EQ(class_index_set_.values.size(), ABSL_CHECK_EQ(class_index_set_.values.size(),
IsClassIndexAllowed(0) ? num_classes_ : num_classes_ - 1) IsClassIndexAllowed(0) ? num_classes_ : num_classes_ - 1)
<< "Only all classes >= class 0 or >= class 1"; << "Only all classes >= class 0 or >= class 1";
} else { } else {
CHECK_EQ(class_index_set_.values.size(), IsClassIndexAllowed(0) ? 0 : 1) ABSL_CHECK_EQ(class_index_set_.values.size(),
IsClassIndexAllowed(0) ? 0 : 1)
<< "Only ignore class 0 is allowed"; << "Only ignore class 0 is allowed";
} }
@ -1370,7 +1412,13 @@ kernel void scoreKernel(
Tensor::ElementType::kFloat32, Tensor::Shape{1, num_boxes_ * 2}); Tensor::ElementType::kFloat32, Tensor::Shape{1, num_boxes_ * 2});
// # filter classes supported is hardware dependent. // # filter classes supported is hardware dependent.
int max_wg_size = score_program_.maxTotalThreadsPerThreadgroup; int max_wg_size = score_program_.maxTotalThreadsPerThreadgroup;
CHECK_LT(num_classes_, max_wg_size) << "# classes must be <" << max_wg_size; gpu_has_enough_work_groups_ = num_classes_ < max_wg_size;
if (!gpu_has_enough_work_groups_) {
return absl::FailedPreconditionError(absl::StrFormat(
"Hardware limitation: Processing will be done on CPU, because "
"num_classes %d exceeds the max work_group size %d.",
num_classes_, max_wg_size));
}
} }
#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) #endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)

View File

@ -142,7 +142,7 @@ absl::Status TensorsToLandmarksCalculator::Process(CalculatorContext* cc) {
RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32); RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
int num_values = input_tensors[0].shape().num_elements(); int num_values = input_tensors[0].shape().num_elements();
const int num_dimensions = num_values / num_landmarks_; const int num_dimensions = num_values / num_landmarks_;
CHECK_GT(num_dimensions, 0); ABSL_CHECK_GT(num_dimensions, 0);
auto view = input_tensors[0].GetCpuReadView(); auto view = input_tensors[0].GetCpuReadView();
auto raw_landmarks = view.buffer<float>(); auto raw_landmarks = view.buffer<float>();

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# #
# Placeholder: load py_proto_library
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library", "mediapipe_proto_library") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library", "mediapipe_proto_library")
licenses(["notice"]) licenses(["notice"])
@ -314,6 +315,7 @@ cc_library(
"//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/log:absl_check",
] + select({ ] + select({
"//conditions:default": [ "//conditions:default": [
"@org_tensorflow//tensorflow/core:framework", "@org_tensorflow//tensorflow/core:framework",
@ -366,15 +368,14 @@ cc_library(
name = "pack_media_sequence_calculator", name = "pack_media_sequence_calculator",
srcs = ["pack_media_sequence_calculator.cc"], srcs = ["pack_media_sequence_calculator.cc"],
deps = [ deps = [
":pack_media_sequence_calculator_cc_proto",
"//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto", "//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto",
"//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location",
"//mediapipe/framework/formats:location_opencv", "//mediapipe/framework/formats:location_opencv",
"//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgcodecs",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/util/sequence:media_sequence", "//mediapipe/util/sequence:media_sequence",
"//mediapipe/util/sequence:media_sequence_util", "//mediapipe/util/sequence:media_sequence_util",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
@ -406,8 +407,13 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
# This dependency removed tensorflow_jellyfish_deps and xprofilez_with_server because they failed # This dependency removed the following 3 targets because they failed Boq conformance test:
# Boq conformance test. Weigh your use case to see if this will work for you. #
# tensorflow_jellyfish_deps
# jfprof_lib
# xprofilez_with_server
#
# If you need them plz consider tensorflow_inference_calculator_no_envelope_loader.
cc_library( cc_library(
name = "tensorflow_inference_calculator_for_boq", name = "tensorflow_inference_calculator_for_boq",
srcs = ["tensorflow_inference_calculator.cc"], srcs = ["tensorflow_inference_calculator.cc"],
@ -424,7 +430,7 @@ cc_library(
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/tool:status_util", "//mediapipe/framework/tool:status_util",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/log:check", "@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",
@ -483,10 +489,10 @@ cc_library(
"//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_calculator_cc_proto", "//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_calculator_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/deps:clock", "//mediapipe/framework/deps:clock",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/tool:status_util", "//mediapipe/framework/tool:status_util",
"@com_google_absl//absl/log:absl_log",
"@org_tensorflow//tensorflow/core:protos_all_cc", "@org_tensorflow//tensorflow/core:protos_all_cc",
] + select({ ] + select({
"//conditions:default": [ "//conditions:default": [
@ -514,10 +520,10 @@ cc_library(
":tensorflow_session_from_frozen_graph_generator_cc_proto", ":tensorflow_session_from_frozen_graph_generator_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/deps:clock", "//mediapipe/framework/deps:clock",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/tool:status_util", "//mediapipe/framework/tool:status_util",
"@com_google_absl//absl/log:absl_log",
"@org_tensorflow//tensorflow/core:protos_all_cc", "@org_tensorflow//tensorflow/core:protos_all_cc",
] + select({ ] + select({
"//conditions:default": [ "//conditions:default": [
@ -550,6 +556,7 @@ cc_library(
"//mediapipe/framework/deps:file_path", "//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/cc/saved_model:constants", "@org_tensorflow//tensorflow/cc/saved_model:constants",
"@org_tensorflow//tensorflow/cc/saved_model:loader_lite", "@org_tensorflow//tensorflow/cc/saved_model:loader_lite",
@ -627,6 +634,7 @@ cc_library(
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/tool:status_util", "//mediapipe/framework/tool:status_util",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/cc/saved_model:constants", "@org_tensorflow//tensorflow/cc/saved_model:constants",
@ -648,6 +656,7 @@ cc_library(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/log:absl_log",
"@org_tensorflow//tensorflow/core:framework", "@org_tensorflow//tensorflow/core:framework",
], ],
alwayslink = 1, alwayslink = 1,
@ -662,6 +671,7 @@ cc_library(
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/log:absl_check",
"@org_tensorflow//tensorflow/core:framework", "@org_tensorflow//tensorflow/core:framework",
], ],
alwayslink = 1, alwayslink = 1,
@ -677,6 +687,7 @@ cc_library(
"//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/log:absl_check",
] + select({ ] + select({
"//conditions:default": [ "//conditions:default": [
"@org_tensorflow//tensorflow/core:framework", "@org_tensorflow//tensorflow/core:framework",
@ -773,6 +784,7 @@ cc_library(
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/util:audio_decoder_cc_proto", "//mediapipe/util:audio_decoder_cc_proto",
"//mediapipe/util/sequence:media_sequence", "//mediapipe/util/sequence:media_sequence",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:protos_all_cc", "@org_tensorflow//tensorflow/core:protos_all_cc",
], ],
@ -787,6 +799,8 @@ cc_library(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
"@org_tensorflow//tensorflow/core:framework", "@org_tensorflow//tensorflow/core:framework",
], ],
alwayslink = 1, alwayslink = 1,
@ -800,6 +814,7 @@ cc_library(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/log:absl_log",
"@org_tensorflow//tensorflow/core:framework", "@org_tensorflow//tensorflow/core:framework",
], ],
alwayslink = 1, alwayslink = 1,
@ -813,6 +828,7 @@ cc_library(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/log:absl_log",
"@org_tensorflow//tensorflow/core:framework", "@org_tensorflow//tensorflow/core:framework",
], ],
alwayslink = 1, alwayslink = 1,
@ -826,6 +842,8 @@ cc_library(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet", "//mediapipe/framework:packet",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
"@org_tensorflow//tensorflow/core:protos_all_cc", "@org_tensorflow//tensorflow/core:protos_all_cc",
], ],
alwayslink = 1, alwayslink = 1,
@ -920,22 +938,22 @@ cc_test(
srcs = ["pack_media_sequence_calculator_test.cc"], srcs = ["pack_media_sequence_calculator_test.cc"],
deps = [ deps = [
":pack_media_sequence_calculator", ":pack_media_sequence_calculator",
":pack_media_sequence_calculator_cc_proto",
"//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto", "//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto",
"//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework:packet",
"//mediapipe/framework:timestamp", "//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location",
"//mediapipe/framework/formats:location_opencv", "//mediapipe/framework/formats:location_opencv",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgcodecs",
"//mediapipe/util/sequence:media_sequence", "//mediapipe/util/sequence:media_sequence",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
"@org_tensorflow//tensorflow/core:protos_all_cc", "@org_tensorflow//tensorflow/core:protos_all_cc",
], ],
) )
@ -1077,6 +1095,7 @@ cc_test(
linkstatic = 1, linkstatic = 1,
deps = [ deps = [
":tensor_to_image_frame_calculator", ":tensor_to_image_frame_calculator",
":tensor_to_image_frame_calculator_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
@ -1162,6 +1181,7 @@ cc_test(
"//mediapipe/framework/port:rectangle", "//mediapipe/framework/port:rectangle",
"//mediapipe/util:audio_decoder_cc_proto", "//mediapipe/util:audio_decoder_cc_proto",
"//mediapipe/util/sequence:media_sequence", "//mediapipe/util/sequence:media_sequence",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:protos_all_cc", "@org_tensorflow//tensorflow/core:protos_all_cc",
@ -1243,6 +1263,8 @@ cc_test(
"//mediapipe/framework/tool:sink", "//mediapipe/framework/tool:sink",
"//mediapipe/framework/tool:validate_type", "//mediapipe/framework/tool:validate_type",
"@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
] + select({ ] + select({
"//conditions:default": [ "//conditions:default": [
"@org_tensorflow//tensorflow/core:direct_session", "@org_tensorflow//tensorflow/core:direct_session",

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/log/absl_check.h"
#include "mediapipe/calculators/tensorflow/matrix_to_tensor_calculator_options.pb.h" #include "mediapipe/calculators/tensorflow/matrix_to_tensor_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
@ -28,7 +29,7 @@ namespace mediapipe {
namespace { namespace {
absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet,
TimeSeriesHeader* header) { TimeSeriesHeader* header) {
CHECK(header); ABSL_CHECK(header);
if (header_packet.IsEmpty()) { if (header_packet.IsEmpty()) {
return absl::UnknownError("No header found."); return absl::UnknownError("No header found.");
} }

View File

@ -12,21 +12,22 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <cstdint>
#include <optional>
#include <string> #include <string>
#include <vector> #include <vector>
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h" #include "absl/strings/match.h"
#include "absl/strings/strip.h"
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h" #include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
#include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h" #include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location.h" #include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/formats/location_opencv.h" #include "mediapipe/framework/formats/location_opencv.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/sequence/media_sequence.h" #include "mediapipe/util/sequence/media_sequence.h"
#include "mediapipe/util/sequence/media_sequence_util.h" #include "mediapipe/util/sequence/media_sequence_util.h"
#include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/example.pb.h"
@ -36,6 +37,7 @@ namespace mediapipe {
const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
const char kImageTag[] = "IMAGE"; const char kImageTag[] = "IMAGE";
const char kImageLabelPrefixTag[] = "IMAGE_LABEL_";
const char kFloatContextFeaturePrefixTag[] = "FLOAT_CONTEXT_FEATURE_"; const char kFloatContextFeaturePrefixTag[] = "FLOAT_CONTEXT_FEATURE_";
const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_"; const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_";
const char kIntFeaturePrefixTag[] = "INT_FEATURE_"; const char kIntFeaturePrefixTag[] = "INT_FEATURE_";
@ -44,6 +46,7 @@ const char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED";
const char kBBoxTag[] = "BBOX"; const char kBBoxTag[] = "BBOX";
const char kKeypointsTag[] = "KEYPOINTS"; const char kKeypointsTag[] = "KEYPOINTS";
const char kSegmentationMaskTag[] = "CLASS_SEGMENTATION"; const char kSegmentationMaskTag[] = "CLASS_SEGMENTATION";
const char kClipMediaIdTag[] = "CLIP_MEDIA_ID";
namespace tf = ::tensorflow; namespace tf = ::tensorflow;
namespace mpms = mediapipe::mediasequence; namespace mpms = mediapipe::mediasequence;
@ -55,16 +58,21 @@ namespace mpms = mediapipe::mediasequence;
// context features can be supplied verbatim in the calculator's options. The // context features can be supplied verbatim in the calculator's options. The
// SequenceExample will conform to the description in media_sequence.h. // SequenceExample will conform to the description in media_sequence.h.
// //
// The supported input stream tags are "IMAGE", which stores the encoded // The supported input stream tags are:
// images from the OpenCVImageEncoderCalculator, "FORWARD_FLOW_ENCODED", which // * "IMAGE", which stores the encoded images from the
// stores the encoded optical flow from the same calculator, "BBOX" which stores // OpenCVImageEncoderCalculator,
// bounding boxes from vector<Detections>, and streams with the // * "IMAGE_LABEL", which stores whole image labels from Detection,
// "FLOAT_FEATURE_${NAME}" pattern, which stores the values from vector<float>'s // * "FORWARD_FLOW_ENCODED", which stores the encoded optical flow from the same
// associated with the name ${NAME}. "KEYPOINTS" stores a map of 2D keypoints // calculator,
// from flat_hash_map<string, vector<pair<float, float>>>. "IMAGE_${NAME}", // * "BBOX" which stores bounding boxes from vector<Detections>,
// "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store prefixed versions of // * streams with the "FLOAT_FEATURE_${NAME}" pattern, which stores the values
// each stream, which allows for multiple image streams to be included. However, // from vector<float>'s associated with the name ${NAME},
// the default names are suppored by more tools. // * "KEYPOINTS" stores a map of 2D keypoints from flat_hash_map<string,
// vector<pair<float, float>>>,
// * "CLIP_MEDIA_ID", which stores the clip's media ID as a string.
// "IMAGE_${NAME}", "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store
// prefixed versions of each stream, which allows for multiple image streams to
// be included. However, the default names are suppored by more tools.
// //
// Example config: // Example config:
// node { // node {
@ -100,6 +108,9 @@ class PackMediaSequenceCalculator : public CalculatorBase {
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
RET_CHECK(cc->InputSidePackets().HasTag(kSequenceExampleTag)); RET_CHECK(cc->InputSidePackets().HasTag(kSequenceExampleTag));
cc->InputSidePackets().Tag(kSequenceExampleTag).Set<tf::SequenceExample>(); cc->InputSidePackets().Tag(kSequenceExampleTag).Set<tf::SequenceExample>();
if (cc->InputSidePackets().HasTag(kClipMediaIdTag)) {
cc->InputSidePackets().Tag(kClipMediaIdTag).Set<std::string>();
}
if (cc->Inputs().HasTag(kForwardFlowEncodedTag)) { if (cc->Inputs().HasTag(kForwardFlowEncodedTag)) {
cc->Inputs() cc->Inputs()
@ -112,6 +123,10 @@ class PackMediaSequenceCalculator : public CalculatorBase {
for (const auto& tag : cc->Inputs().GetTags()) { for (const auto& tag : cc->Inputs().GetTags()) {
if (absl::StartsWith(tag, kImageTag)) { if (absl::StartsWith(tag, kImageTag)) {
if (absl::StartsWith(tag, kImageLabelPrefixTag)) {
cc->Inputs().Tag(tag).Set<Detection>();
continue;
}
std::string key = ""; std::string key = "";
if (tag != kImageTag) { if (tag != kImageTag) {
int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1; int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1;
@ -164,7 +179,7 @@ class PackMediaSequenceCalculator : public CalculatorBase {
} }
} }
CHECK(cc->Outputs().HasTag(kSequenceExampleTag) || RET_CHECK(cc->Outputs().HasTag(kSequenceExampleTag) ||
cc->OutputSidePackets().HasTag(kSequenceExampleTag)) cc->OutputSidePackets().HasTag(kSequenceExampleTag))
<< "Neither the output stream nor the output side packet is set to " << "Neither the output stream nor the output side packet is set to "
"output the sequence example."; "output the sequence example.";
@ -184,6 +199,11 @@ class PackMediaSequenceCalculator : public CalculatorBase {
cc->InputSidePackets() cc->InputSidePackets()
.Tag(kSequenceExampleTag) .Tag(kSequenceExampleTag)
.Get<tf::SequenceExample>()); .Get<tf::SequenceExample>());
if (cc->InputSidePackets().HasTag(kClipMediaIdTag) &&
!cc->InputSidePackets().Tag(kClipMediaIdTag).IsEmpty()) {
clip_media_id_ =
cc->InputSidePackets().Tag(kClipMediaIdTag).Get<std::string>();
}
const auto& context_features = const auto& context_features =
cc->Options<PackMediaSequenceCalculatorOptions>().context_feature_map(); cc->Options<PackMediaSequenceCalculatorOptions>().context_feature_map();
@ -199,6 +219,16 @@ class PackMediaSequenceCalculator : public CalculatorBase {
.replace_data_instead_of_append()) { .replace_data_instead_of_append()) {
for (const auto& tag : cc->Inputs().GetTags()) { for (const auto& tag : cc->Inputs().GetTags()) {
if (absl::StartsWith(tag, kImageTag)) { if (absl::StartsWith(tag, kImageTag)) {
if (absl::StartsWith(tag, kImageLabelPrefixTag)) {
std::string key =
std::string(absl::StripPrefix(tag, kImageLabelPrefixTag));
mpms::ClearImageLabelString(key, sequence_.get());
mpms::ClearImageLabelConfidence(key, sequence_.get());
if (!key.empty() || mpms::HasImageEncoded(*sequence_)) {
mpms::ClearImageTimestamp(key, sequence_.get());
}
continue;
}
std::string key = ""; std::string key = "";
if (tag != kImageTag) { if (tag != kImageTag) {
int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1; int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1;
@ -227,6 +257,7 @@ class PackMediaSequenceCalculator : public CalculatorBase {
mpms::ClearBBoxNumRegions(key, sequence_.get()); mpms::ClearBBoxNumRegions(key, sequence_.get());
mpms::ClearBBoxLabelString(key, sequence_.get()); mpms::ClearBBoxLabelString(key, sequence_.get());
mpms::ClearBBoxLabelIndex(key, sequence_.get()); mpms::ClearBBoxLabelIndex(key, sequence_.get());
mpms::ClearBBoxLabelConfidence(key, sequence_.get());
mpms::ClearBBoxClassString(key, sequence_.get()); mpms::ClearBBoxClassString(key, sequence_.get());
mpms::ClearBBoxClassIndex(key, sequence_.get()); mpms::ClearBBoxClassIndex(key, sequence_.get());
mpms::ClearBBoxTrackString(key, sequence_.get()); mpms::ClearBBoxTrackString(key, sequence_.get());
@ -343,6 +374,34 @@ class PackMediaSequenceCalculator : public CalculatorBase {
if (absl::StartsWith(tag, kImageTag) && if (absl::StartsWith(tag, kImageTag) &&
!cc->Inputs().Tag(tag).IsEmpty()) { !cc->Inputs().Tag(tag).IsEmpty()) {
std::string key = ""; std::string key = "";
if (absl::StartsWith(tag, kImageLabelPrefixTag)) {
std::string key =
std::string(absl::StripPrefix(tag, kImageLabelPrefixTag));
const auto& detection = cc->Inputs().Tag(tag).Get<Detection>();
if (detection.label().empty()) continue;
RET_CHECK(detection.label_size() == detection.score_size())
<< "Wrong image label data format: " << detection.label_size()
<< " vs " << detection.score_size();
if (!detection.label_id().empty()) {
RET_CHECK(detection.label_id_size() == detection.label_size())
<< "Wrong image label ID format: " << detection.label_id_size()
<< " vs " << detection.label_size();
}
std::vector<std::string> labels(detection.label().begin(),
detection.label().end());
std::vector<float> confidences(detection.score().begin(),
detection.score().end());
std::vector<int32_t> ids(detection.label_id().begin(),
detection.label_id().end());
if (!key.empty() || mpms::HasImageEncoded(*sequence_)) {
mpms::AddImageTimestamp(key, cc->InputTimestamp().Value(),
sequence_.get());
}
mpms::AddImageLabelString(key, labels, sequence_.get());
mpms::AddImageLabelConfidence(key, confidences, sequence_.get());
if (!ids.empty()) mpms::AddImageLabelIndex(key, ids, sequence_.get());
continue;
}
if (tag != kImageTag) { if (tag != kImageTag) {
int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1; int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1;
if (tag[tag_length] == '_') { if (tag[tag_length] == '_') {
@ -393,6 +452,7 @@ class PackMediaSequenceCalculator : public CalculatorBase {
mpms::ClearBBoxNumRegions(prefix, sequence_.get()); mpms::ClearBBoxNumRegions(prefix, sequence_.get());
mpms::ClearBBoxLabelString(prefix, sequence_.get()); mpms::ClearBBoxLabelString(prefix, sequence_.get());
mpms::ClearBBoxLabelIndex(prefix, sequence_.get()); mpms::ClearBBoxLabelIndex(prefix, sequence_.get());
mpms::ClearBBoxLabelConfidence(prefix, sequence_.get());
mpms::ClearBBoxClassString(prefix, sequence_.get()); mpms::ClearBBoxClassString(prefix, sequence_.get());
mpms::ClearBBoxClassIndex(prefix, sequence_.get()); mpms::ClearBBoxClassIndex(prefix, sequence_.get());
mpms::ClearBBoxTrackString(prefix, sequence_.get()); mpms::ClearBBoxTrackString(prefix, sequence_.get());
@ -460,6 +520,7 @@ class PackMediaSequenceCalculator : public CalculatorBase {
} }
std::vector<Location> predicted_locations; std::vector<Location> predicted_locations;
std::vector<std::string> predicted_class_strings; std::vector<std::string> predicted_class_strings;
std::vector<float> predicted_class_confidences;
std::vector<int> predicted_label_ids; std::vector<int> predicted_label_ids;
for (auto& detection : for (auto& detection :
cc->Inputs().Tag(tag).Get<std::vector<Detection>>()) { cc->Inputs().Tag(tag).Get<std::vector<Detection>>()) {
@ -488,6 +549,9 @@ class PackMediaSequenceCalculator : public CalculatorBase {
if (detection.label_id_size() > 0) { if (detection.label_id_size() > 0) {
predicted_label_ids.push_back(detection.label_id(0)); predicted_label_ids.push_back(detection.label_id(0));
} }
if (detection.score_size() > 0) {
predicted_class_confidences.push_back(detection.score(0));
}
} }
} }
if (!predicted_locations.empty()) { if (!predicted_locations.empty()) {
@ -501,6 +565,10 @@ class PackMediaSequenceCalculator : public CalculatorBase {
if (!predicted_label_ids.empty()) { if (!predicted_label_ids.empty()) {
mpms::AddBBoxLabelIndex(key, predicted_label_ids, sequence_.get()); mpms::AddBBoxLabelIndex(key, predicted_label_ids, sequence_.get());
} }
if (!predicted_class_confidences.empty()) {
mpms::AddBBoxLabelConfidence(key, predicted_class_confidences,
sequence_.get());
}
} }
} }
} }
@ -548,10 +616,14 @@ class PackMediaSequenceCalculator : public CalculatorBase {
} }
} }
} }
if (clip_media_id_.has_value()) {
mpms::SetClipMediaId(*clip_media_id_, sequence_.get());
}
return absl::OkStatus(); return absl::OkStatus();
} }
std::unique_ptr<tf::SequenceExample> sequence_; std::unique_ptr<tf::SequenceExample> sequence_;
std::optional<std::string> clip_media_id_ = std::nullopt;
std::map<std::string, bool> features_present_; std::map<std::string, bool> features_present_;
bool replace_keypoints_; bool replace_keypoints_;
}; };

View File

@ -12,28 +12,29 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <algorithm> #include <cstdint>
#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h" #include "absl/log/absl_check.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h"
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h" #include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
#include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h" #include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/location.h" #include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/formats/location_opencv.h" #include "mediapipe/framework/formats/location_opencv.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h" #include "mediapipe/framework/timestamp.h"
#include "mediapipe/util/sequence/media_sequence.h" #include "mediapipe/util/sequence/media_sequence.h"
#include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h" #include "tensorflow/core/example/feature.pb.h"
#include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h"
namespace mediapipe { namespace mediapipe {
namespace { namespace {
@ -59,9 +60,12 @@ constexpr char kFloatFeatureOtherTag[] = "FLOAT_FEATURE_OTHER";
constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST"; constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST";
constexpr char kIntFeatureOtherTag[] = "INT_FEATURE_OTHER"; constexpr char kIntFeatureOtherTag[] = "INT_FEATURE_OTHER";
constexpr char kIntFeatureTestTag[] = "INT_FEATURE_TEST"; constexpr char kIntFeatureTestTag[] = "INT_FEATURE_TEST";
constexpr char kImageLabelTestTag[] = "IMAGE_LABEL_TEST";
constexpr char kImageLabelOtherTag[] = "IMAGE_LABEL_OTHER";
constexpr char kImagePrefixTag[] = "IMAGE_PREFIX"; constexpr char kImagePrefixTag[] = "IMAGE_PREFIX";
constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
constexpr char kImageTag[] = "IMAGE"; constexpr char kImageTag[] = "IMAGE";
constexpr char kClipMediaIdTag[] = "CLIP_MEDIA_ID";
class PackMediaSequenceCalculatorTest : public ::testing::Test { class PackMediaSequenceCalculatorTest : public ::testing::Test {
protected: protected:
@ -69,10 +73,14 @@ class PackMediaSequenceCalculatorTest : public ::testing::Test {
const tf::Features& features, const tf::Features& features,
const bool output_only_if_all_present, const bool output_only_if_all_present,
const bool replace_instead_of_append, const bool replace_instead_of_append,
const bool output_as_zero_timestamp = false) { const bool output_as_zero_timestamp = false,
const std::vector<std::string>& input_side_packets = {
"SEQUENCE_EXAMPLE:input_sequence"}) {
CalculatorGraphConfig::Node config; CalculatorGraphConfig::Node config;
config.set_calculator("PackMediaSequenceCalculator"); config.set_calculator("PackMediaSequenceCalculator");
config.add_input_side_packet("SEQUENCE_EXAMPLE:input_sequence"); for (const std::string& side_packet : input_side_packets) {
config.add_input_side_packet(side_packet);
}
config.add_output_stream("SEQUENCE_EXAMPLE:output_sequence"); config.add_output_stream("SEQUENCE_EXAMPLE:output_sequence");
for (const std::string& stream : input_streams) { for (const std::string& stream : input_streams) {
config.add_input_stream(stream); config.add_input_stream(stream);
@ -96,7 +104,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImages) {
mpms::SetClipMediaId(test_video_id, input_sequence.get()); mpms::SetClipMediaId(test_video_id, input_sequence.get());
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes; std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
OpenCvImageEncoderCalculatorResults encoded_image; OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(bytes.data(), bytes.size()); encoded_image.set_encoded_image(bytes.data(), bytes.size());
encoded_image.set_width(2); encoded_image.set_width(2);
@ -139,7 +148,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoPrefixedImages) {
mpms::SetClipMediaId(test_video_id, input_sequence.get()); mpms::SetClipMediaId(test_video_id, input_sequence.get());
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes; std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
OpenCvImageEncoderCalculatorResults encoded_image; OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(bytes.data(), bytes.size()); encoded_image.set_encoded_image(bytes.data(), bytes.size());
encoded_image.set_width(2); encoded_image.set_width(2);
@ -312,6 +322,76 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBytesLists) {
} }
} }
TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImageLabels) {
SetUpCalculator(
{"IMAGE_LABEL_TEST:test_labels", "IMAGE_LABEL_OTHER:test_labels2"}, {},
false, true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
int num_timesteps = 2;
for (int i = 0; i < num_timesteps; ++i) {
Detection detection1;
detection1.add_label(absl::StrCat("foo", 2 << i));
detection1.add_label_id(i);
detection1.add_score(0.1 * i);
detection1.add_label(absl::StrCat("foo", 2 << i));
detection1.add_label_id(i);
detection1.add_score(0.1 * i);
auto label_ptr1 = ::absl::make_unique<Detection>(detection1);
runner_->MutableInputs()
->Tag(kImageLabelTestTag)
.packets.push_back(Adopt(label_ptr1.release()).At(Timestamp(i)));
Detection detection2;
detection2.add_label(absl::StrCat("bar", 2 << i));
detection2.add_score(0.2 * i);
detection2.add_label(absl::StrCat("bar", 2 << i));
detection2.add_score(0.2 * i);
auto label_ptr2 = ::absl::make_unique<Detection>(detection2);
runner_->MutableInputs()
->Tag(kImageLabelOtherTag)
.packets.push_back(Adopt(label_ptr2.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
ASSERT_EQ(num_timesteps,
mpms::GetImageTimestampSize("TEST", output_sequence));
ASSERT_EQ(num_timesteps,
mpms::GetImageLabelStringSize("TEST", output_sequence));
ASSERT_EQ(num_timesteps,
mpms::GetImageLabelConfidenceSize("TEST", output_sequence));
ASSERT_EQ(num_timesteps,
mpms::GetImageTimestampSize("OTHER", output_sequence));
ASSERT_EQ(num_timesteps,
mpms::GetImageLabelStringSize("OTHER", output_sequence));
ASSERT_EQ(num_timesteps,
mpms::GetImageLabelConfidenceSize("OTHER", output_sequence));
for (int i = 0; i < num_timesteps; ++i) {
ASSERT_EQ(i, mpms::GetImageTimestampAt("TEST", output_sequence, i));
ASSERT_THAT(mpms::GetImageLabelStringAt("TEST", output_sequence, i),
::testing::ElementsAreArray(
std::vector<std::string>(2, absl::StrCat("foo", 2 << i))));
ASSERT_THAT(mpms::GetImageLabelIndexAt("TEST", output_sequence, i),
::testing::ElementsAreArray(std::vector<int32_t>(2, i)));
ASSERT_THAT(mpms::GetImageLabelConfidenceAt("TEST", output_sequence, i),
::testing::ElementsAreArray(std::vector<float>(2, 0.1 * i)));
ASSERT_EQ(i, mpms::GetImageTimestampAt("OTHER", output_sequence, i));
ASSERT_THAT(mpms::GetImageLabelStringAt("OTHER", output_sequence, i),
::testing::ElementsAreArray(
std::vector<std::string>(2, absl::StrCat("bar", 2 << i))));
ASSERT_THAT(mpms::GetImageLabelConfidenceAt("OTHER", output_sequence, i),
::testing::ElementsAreArray(std::vector<float>(2, 0.2 * i)));
}
}
TEST_F(PackMediaSequenceCalculatorTest, OutputAsZeroTimestamp) { TEST_F(PackMediaSequenceCalculatorTest, OutputAsZeroTimestamp) {
SetUpCalculator({"FLOAT_FEATURE_TEST:test"}, {}, false, true, true); SetUpCalculator({"FLOAT_FEATURE_TEST:test"}, {}, false, true, true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>(); auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
@ -378,7 +458,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) {
Adopt(input_sequence.release()); Adopt(input_sequence.release());
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes; std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
OpenCvImageEncoderCalculatorResults encoded_image; OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(bytes.data(), bytes.size()); encoded_image.set_encoded_image(bytes.data(), bytes.size());
auto image_ptr = auto image_ptr =
@ -410,7 +491,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoForwardFlowEncodeds) {
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes; std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
std::string test_flow_string(bytes.begin(), bytes.end()); std::string test_flow_string(bytes.begin(), bytes.end());
OpenCvImageEncoderCalculatorResults encoded_flow; OpenCvImageEncoderCalculatorResults encoded_flow;
encoded_flow.set_encoded_image(test_flow_string); encoded_flow.set_encoded_image(test_flow_string);
@ -526,6 +608,10 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBBoxDetections) {
auto class_indices = mpms::GetPredictedBBoxLabelIndexAt(output_sequence, i); auto class_indices = mpms::GetPredictedBBoxLabelIndexAt(output_sequence, i);
ASSERT_EQ(0, class_indices[0]); ASSERT_EQ(0, class_indices[0]);
ASSERT_EQ(1, class_indices[1]); ASSERT_EQ(1, class_indices[1]);
auto class_scores =
mpms::GetPredictedBBoxLabelConfidenceAt(output_sequence, i);
ASSERT_FLOAT_EQ(0.5, class_scores[0]);
ASSERT_FLOAT_EQ(0.75, class_scores[1]);
} }
} }
@ -618,7 +704,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) {
} }
cv::Mat image(height, width, CV_8UC3, cv::Scalar(0, 0, 255)); cv::Mat image(height, width, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes; std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
OpenCvImageEncoderCalculatorResults encoded_image; OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(bytes.data(), bytes.size()); encoded_image.set_encoded_image(bytes.data(), bytes.size());
encoded_image.set_width(width); encoded_image.set_width(width);
@ -667,6 +754,10 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) {
auto class_indices = mpms::GetPredictedBBoxLabelIndexAt(output_sequence, i); auto class_indices = mpms::GetPredictedBBoxLabelIndexAt(output_sequence, i);
ASSERT_EQ(0, class_indices[0]); ASSERT_EQ(0, class_indices[0]);
ASSERT_EQ(1, class_indices[1]); ASSERT_EQ(1, class_indices[1]);
auto class_scores =
mpms::GetPredictedBBoxLabelConfidenceAt(output_sequence, i);
ASSERT_FLOAT_EQ(0.5, class_scores[0]);
ASSERT_FLOAT_EQ(0.75, class_scores[1]);
} }
} }
@ -757,6 +848,88 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) {
testing::ElementsAreArray(::std::vector<std::string>({"mask"}))); testing::ElementsAreArray(::std::vector<std::string>({"mask"})));
} }
TEST_F(PackMediaSequenceCalculatorTest, AddClipMediaId) {
SetUpCalculator(
/*input_streams=*/{"FLOAT_FEATURE_TEST:test",
"FLOAT_FEATURE_OTHER:test2"},
/*features=*/{},
/*output_only_if_all_present=*/false,
/*replace_instead_of_append=*/true,
/*output_as_zero_timestamp=*/false, /*input_side_packets=*/
{"SEQUENCE_EXAMPLE:input_sequence", "CLIP_MEDIA_ID:video_id"});
auto input_sequence = absl::make_unique<tf::SequenceExample>();
const std::string test_video_id = "test_video_id";
int num_timesteps = 2;
for (int i = 0; i < num_timesteps; ++i) {
auto vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
runner_->MutableInputs()
->Tag(kFloatFeatureTestTag)
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
runner_->MutableInputs()
->Tag(kFloatFeatureOtherTag)
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag(kClipMediaIdTag) =
MakePacket<std::string>(test_video_id);
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence));
}
TEST_F(PackMediaSequenceCalculatorTest, ReplaceClipMediaId) {
SetUpCalculator(
/*input_streams=*/{"FLOAT_FEATURE_TEST:test",
"FLOAT_FEATURE_OTHER:test2"},
/*features=*/{},
/*output_only_if_all_present=*/false,
/*replace_instead_of_append=*/true,
/*output_as_zero_timestamp=*/false, /*input_side_packets=*/
{"SEQUENCE_EXAMPLE:input_sequence", "CLIP_MEDIA_ID:video_id"});
auto input_sequence = absl::make_unique<tf::SequenceExample>();
const std::string existing_video_id = "existing_video_id";
mpms::SetClipMediaId(existing_video_id, input_sequence.get());
const std::string test_video_id = "test_video_id";
int num_timesteps = 2;
for (int i = 0; i < num_timesteps; ++i) {
auto vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
runner_->MutableInputs()
->Tag(kFloatFeatureTestTag)
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
runner_->MutableInputs()
->Tag(kFloatFeatureOtherTag)
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag(kClipMediaIdTag) =
MakePacket<std::string>(test_video_id).At(Timestamp(0));
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence));
}
TEST_F(PackMediaSequenceCalculatorTest, MissingStreamOK) { TEST_F(PackMediaSequenceCalculatorTest, MissingStreamOK) {
SetUpCalculator( SetUpCalculator(
{"FORWARD_FLOW_ENCODED:flow", "FLOAT_FEATURE_I3D_FLOW:feature"}, {}, {"FORWARD_FLOW_ENCODED:flow", "FLOAT_FEATURE_I3D_FLOW:feature"}, {},
@ -767,7 +940,8 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamOK) {
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes; std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
std::string test_flow_string(bytes.begin(), bytes.end()); std::string test_flow_string(bytes.begin(), bytes.end());
OpenCvImageEncoderCalculatorResults encoded_flow; OpenCvImageEncoderCalculatorResults encoded_flow;
encoded_flow.set_encoded_image(test_flow_string); encoded_flow.set_encoded_image(test_flow_string);
@ -813,7 +987,8 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamNotOK) {
mpms::SetClipMediaId(test_video_id, input_sequence.get()); mpms::SetClipMediaId(test_video_id, input_sequence.get());
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes; std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
std::string test_flow_string(bytes.begin(), bytes.end()); std::string test_flow_string(bytes.begin(), bytes.end());
OpenCvImageEncoderCalculatorResults encoded_flow; OpenCvImageEncoderCalculatorResults encoded_flow;
encoded_flow.set_encoded_image(test_flow_string); encoded_flow.set_encoded_image(test_flow_string);
@ -970,7 +1145,8 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) {
auto input_sequence = ::absl::make_unique<tf::SequenceExample>(); auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes; std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
OpenCvImageEncoderCalculatorResults encoded_image; OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(bytes.data(), bytes.size()); encoded_image.set_encoded_image(bytes.data(), bytes.size());
encoded_image.set_width(2); encoded_image.set_width(2);
@ -1021,7 +1197,8 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
auto input_sequence = ::absl::make_unique<tf::SequenceExample>(); auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes; std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); ASSERT_TRUE(
cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
OpenCvImageEncoderCalculatorResults encoded_image; OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(bytes.data(), bytes.size()); encoded_image.set_encoded_image(bytes.data(), bytes.size());
int height = 2; int height = 2;
@ -1057,6 +1234,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
mpms::AddBBoxNumRegions(-1, input_sequence.get()); mpms::AddBBoxNumRegions(-1, input_sequence.get());
mpms::AddBBoxLabelString({"anything"}, input_sequence.get()); mpms::AddBBoxLabelString({"anything"}, input_sequence.get());
mpms::AddBBoxLabelIndex({-1}, input_sequence.get()); mpms::AddBBoxLabelIndex({-1}, input_sequence.get());
mpms::AddBBoxLabelConfidence({-1}, input_sequence.get());
mpms::AddBBoxClassString({"anything"}, input_sequence.get()); mpms::AddBBoxClassString({"anything"}, input_sequence.get());
mpms::AddBBoxClassIndex({-1}, input_sequence.get()); mpms::AddBBoxClassIndex({-1}, input_sequence.get());
mpms::AddBBoxTrackString({"anything"}, input_sequence.get()); mpms::AddBBoxTrackString({"anything"}, input_sequence.get());

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/log/absl_log.h"
#include "mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.pb.h" #include "mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
@ -99,9 +100,10 @@ class TensorSqueezeDimensionsCalculator : public CalculatorBase {
} }
} }
if (remove_dims_.empty()) { if (remove_dims_.empty()) {
LOG(ERROR) << "TensorSqueezeDimensionsCalculator is squeezing input with " ABSL_LOG(ERROR)
<< "TensorSqueezeDimensionsCalculator is squeezing input with "
"no single-dimensions. Calculator will be a no-op."; "no single-dimensions. Calculator will be a no-op.";
LOG(ERROR) << "Input to TensorSqueezeDimensionsCalculator has shape " ABSL_LOG(ERROR) << "Input to TensorSqueezeDimensionsCalculator has shape "
<< tensor_shape.DebugString(); << tensor_shape.DebugString();
} }
} }

View File

@ -14,6 +14,7 @@
#include <iostream> #include <iostream>
#include "absl/log/absl_check.h"
#include "mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.pb.h" #include "mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
@ -65,6 +66,7 @@ class TensorToImageFrameCalculator : public CalculatorBase {
private: private:
float scale_factor_; float scale_factor_;
bool scale_per_frame_min_max_;
}; };
REGISTER_CALCULATOR(TensorToImageFrameCalculator); REGISTER_CALCULATOR(TensorToImageFrameCalculator);
@ -88,6 +90,8 @@ absl::Status TensorToImageFrameCalculator::GetContract(CalculatorContract* cc) {
absl::Status TensorToImageFrameCalculator::Open(CalculatorContext* cc) { absl::Status TensorToImageFrameCalculator::Open(CalculatorContext* cc) {
scale_factor_ = scale_factor_ =
cc->Options<TensorToImageFrameCalculatorOptions>().scale_factor(); cc->Options<TensorToImageFrameCalculatorOptions>().scale_factor();
scale_per_frame_min_max_ = cc->Options<TensorToImageFrameCalculatorOptions>()
.scale_per_frame_min_max();
cc->SetOffset(TimestampDiff(0)); cc->SetOffset(TimestampDiff(0));
return absl::OkStatus(); return absl::OkStatus();
} }
@ -96,7 +100,7 @@ absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) {
const tf::Tensor& input_tensor = cc->Inputs().Tag(kTensor).Get<tf::Tensor>(); const tf::Tensor& input_tensor = cc->Inputs().Tag(kTensor).Get<tf::Tensor>();
int32_t depth = 1; int32_t depth = 1;
if (input_tensor.dims() != 2) { // Depth is 1 for 2D tensors. if (input_tensor.dims() != 2) { // Depth is 1 for 2D tensors.
CHECK(3 == input_tensor.dims()) ABSL_CHECK(3 == input_tensor.dims())
<< "Only 2 or 3-D Tensors can be converted to frames. Instead got: " << "Only 2 or 3-D Tensors can be converted to frames. Instead got: "
<< input_tensor.dims(); << input_tensor.dims();
depth = input_tensor.dim_size(2); depth = input_tensor.dim_size(2);
@ -109,16 +113,38 @@ absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) {
auto format = (depth == 3 ? ImageFormat::SRGB : ImageFormat::GRAY8); auto format = (depth == 3 ? ImageFormat::SRGB : ImageFormat::GRAY8);
const int32_t total_size = height * width * depth; const int32_t total_size = height * width * depth;
if (scale_per_frame_min_max_) {
RET_CHECK_EQ(input_tensor.dtype(), tensorflow::DT_FLOAT)
<< "Setting scale_per_frame_min_max requires FLOAT input tensors.";
}
::std::unique_ptr<const ImageFrame> output; ::std::unique_ptr<const ImageFrame> output;
if (input_tensor.dtype() == tensorflow::DT_FLOAT) { if (input_tensor.dtype() == tensorflow::DT_FLOAT) {
// Allocate buffer with alignments. // Allocate buffer with alignments.
std::unique_ptr<uint8_t[]> buffer( std::unique_ptr<uint8_t[]> buffer(
new (std::align_val_t(EIGEN_MAX_ALIGN_BYTES)) uint8_t[total_size]); new (std::align_val_t(EIGEN_MAX_ALIGN_BYTES)) uint8_t[total_size]);
auto data = input_tensor.flat<float>().data(); auto data = input_tensor.flat<float>().data();
float min = 1e23;
float max = -1e23;
if (scale_per_frame_min_max_) {
for (int i = 0; i < total_size; ++i) { for (int i = 0; i < total_size; ++i) {
float d = scale_factor_ * data[i]; float d = scale_factor_ * data[i];
if (d < min) {
min = d;
}
if (d > max) {
max = d;
}
}
}
for (int i = 0; i < total_size; ++i) {
float d = data[i];
if (scale_per_frame_min_max_) {
d = 255 * (d - min) / (max - min + 1e-9);
} else {
d = scale_factor_ * d;
if (d < 0) d = 0; if (d < 0) d = 0;
if (d > 255) d = 255; if (d > 255) d = 255;
}
buffer[i] = d; buffer[i] = d;
} }
output = ::absl::make_unique<ImageFrame>( output = ::absl::make_unique<ImageFrame>(

View File

@ -26,4 +26,8 @@ message TensorToImageFrameCalculatorOptions {
// Multiples floating point tensor outputs by this value before converting to // Multiples floating point tensor outputs by this value before converting to
// uint8. This is useful for converting from range [0, 1] to [0, 255] // uint8. This is useful for converting from range [0, 1] to [0, 255]
optional float scale_factor = 1 [default = 1.0]; optional float scale_factor = 1 [default = 1.0];
// If true, scales any FLOAT tensor input of [min, max] to be between [0, 255]
// per frame. This overrides any explicit scale_factor.
optional bool scale_per_frame_min_max = 2 [default = false];
} }

View File

@ -11,7 +11,9 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <type_traits>
#include "mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
@ -32,11 +34,14 @@ constexpr char kImage[] = "IMAGE";
template <class TypeParam> template <class TypeParam>
class TensorToImageFrameCalculatorTest : public ::testing::Test { class TensorToImageFrameCalculatorTest : public ::testing::Test {
protected: protected:
void SetUpRunner() { void SetUpRunner(bool scale_per_frame_min_max = false) {
CalculatorGraphConfig::Node config; CalculatorGraphConfig::Node config;
config.set_calculator("TensorToImageFrameCalculator"); config.set_calculator("TensorToImageFrameCalculator");
config.add_input_stream("TENSOR:input_tensor"); config.add_input_stream("TENSOR:input_tensor");
config.add_output_stream("IMAGE:output_image"); config.add_output_stream("IMAGE:output_image");
config.mutable_options()
->MutableExtension(mediapipe::TensorToImageFrameCalculatorOptions::ext)
->set_scale_per_frame_min_max(scale_per_frame_min_max);
runner_ = absl::make_unique<CalculatorRunner>(config); runner_ = absl::make_unique<CalculatorRunner>(config);
} }
@ -157,4 +162,47 @@ TYPED_TEST(TensorToImageFrameCalculatorTest,
} }
} }
TYPED_TEST(TensorToImageFrameCalculatorTest,
Converts3DTensorToImageFrame2DGrayWithScaling) {
this->SetUpRunner(true);
auto& runner = this->runner_;
constexpr int kWidth = 16;
constexpr int kHeight = 8;
const tf::TensorShape tensor_shape{kHeight, kWidth};
auto tensor = absl::make_unique<tf::Tensor>(
tf::DataTypeToEnum<TypeParam>::v(), tensor_shape);
auto tensor_vec = tensor->template flat<TypeParam>().data();
// Writing sequence of integers as floats which we want normalized.
tensor_vec[0] = 255;
for (int i = 1; i < kWidth * kHeight; ++i) {
tensor_vec[i] = 200;
}
const int64_t time = 1234;
runner->MutableInputs()->Tag(kTensor).packets.push_back(
Adopt(tensor.release()).At(Timestamp(time)));
if (!std::is_same<TypeParam, float>::value) {
EXPECT_FALSE(runner->Run().ok());
return; // Short circuit because does not apply to other types.
} else {
EXPECT_TRUE(runner->Run().ok());
const std::vector<Packet>& output_packets =
runner->Outputs().Tag(kImage).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
EXPECT_EQ(ImageFormat::GRAY8, output_image.Format());
EXPECT_EQ(kWidth, output_image.Width());
EXPECT_EQ(kHeight, output_image.Height());
EXPECT_EQ(255, output_image.PixelData()[0]);
for (int i = 1; i < kWidth * kHeight; ++i) {
const uint8_t pixel_value = output_image.PixelData()[i];
ASSERT_EQ(0, pixel_value);
}
}
}
} // namespace mediapipe } // namespace mediapipe

View File

@ -15,6 +15,7 @@
// Calculator converts from one-dimensional Tensor of DT_FLOAT to Matrix // Calculator converts from one-dimensional Tensor of DT_FLOAT to Matrix
// OR from (batched) two-dimensional Tensor of DT_FLOAT to Matrix. // OR from (batched) two-dimensional Tensor of DT_FLOAT to Matrix.
#include "absl/log/absl_check.h"
#include "mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.pb.h" #include "mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
@ -36,7 +37,7 @@ constexpr char kReference[] = "REFERENCE";
absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet,
TimeSeriesHeader* header) { TimeSeriesHeader* header) {
CHECK(header); ABSL_CHECK(header);
if (header_packet.IsEmpty()) { if (header_packet.IsEmpty()) {
return absl::UnknownError("No header found."); return absl::UnknownError("No header found.");
} }
@ -191,7 +192,7 @@ absl::Status TensorToMatrixCalculator::Process(CalculatorContext* cc) {
<< "Tensor stream packet does not contain a Tensor."; << "Tensor stream packet does not contain a Tensor.";
const tf::Tensor& input_tensor = cc->Inputs().Tag(kTensor).Get<tf::Tensor>(); const tf::Tensor& input_tensor = cc->Inputs().Tag(kTensor).Get<tf::Tensor>();
CHECK(1 == input_tensor.dims() || 2 == input_tensor.dims()) ABSL_CHECK(1 == input_tensor.dims() || 2 == input_tensor.dims())
<< "Only 1-D or 2-D Tensors can be converted to matrices."; << "Only 1-D or 2-D Tensors can be converted to matrices.";
const int32_t length = input_tensor.dim_size(input_tensor.dims() - 1); const int32_t length = input_tensor.dim_size(input_tensor.dims() - 1);
const int32_t width = const int32_t width =

View File

@ -20,6 +20,7 @@
#include <vector> #include <vector>
#include "absl/base/thread_annotations.h" #include "absl/base/thread_annotations.h"
#include "absl/log/absl_check.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
@ -515,7 +516,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
tf::Tensor concated; tf::Tensor concated;
const tf::Status concat_status = const tf::Status concat_status =
tf::tensor::Concat(keyed_tensors.second, &concated); tf::tensor::Concat(keyed_tensors.second, &concated);
CHECK(concat_status.ok()) << concat_status.ToString(); ABSL_CHECK(concat_status.ok()) << concat_status.ToString();
input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first], input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first],
concated); concated);
} }
@ -597,7 +598,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
std::vector<tf::Tensor> split_tensors; std::vector<tf::Tensor> split_tensors;
const tf::Status split_status = const tf::Status split_status =
tf::tensor::Split(outputs[i], split_vector, &split_tensors); tf::tensor::Split(outputs[i], split_vector, &split_tensors);
CHECK(split_status.ok()) << split_status.ToString(); ABSL_CHECK(split_status.ok()) << split_status.ToString();
// Loop over timestamps so that we don't copy the padding. // Loop over timestamps so that we don't copy the padding.
for (int j = 0; j < inference_state->batch_timestamps_.size(); ++j) { for (int j = 0; j < inference_state->batch_timestamps_.size(); ++j) {
tf::Tensor output_tensor(split_tensors[j]); tf::Tensor output_tensor(split_tensors[j]);

View File

@ -17,6 +17,8 @@
#include <vector> #include <vector>
#include "absl/flags/flag.h" #include "absl/flags/flag.h"
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "mediapipe/calculators/tensorflow/tensorflow_inference_calculator.pb.h" #include "mediapipe/calculators/tensorflow/tensorflow_inference_calculator.pb.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h" #include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
@ -118,7 +120,7 @@ class TensorflowInferenceCalculatorTest : public ::testing::Test {
// Create tensor from Vector and add as a Packet to the provided tag as input. // Create tensor from Vector and add as a Packet to the provided tag as input.
void AddVectorToInputsAsPacket(const std::vector<Packet>& packets, void AddVectorToInputsAsPacket(const std::vector<Packet>& packets,
const std::string& tag) { const std::string& tag) {
CHECK(!packets.empty()) ABSL_CHECK(!packets.empty())
<< "Please specify at least some data in the packet"; << "Please specify at least some data in the packet";
auto packets_ptr = absl::make_unique<std::vector<Packet>>(packets); auto packets_ptr = absl::make_unique<std::vector<Packet>>(packets);
runner_->MutableInputs()->Tag(tag).packets.push_back( runner_->MutableInputs()->Tag(tag).packets.push_back(
@ -586,12 +588,12 @@ TEST_F(TensorflowInferenceCalculatorTest, TestRecurrentStates) {
runner_->Outputs().Tag(kMultipliedTag).packets; runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(2, output_packets_mult.size()); ASSERT_EQ(2, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>(); const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
LOG(INFO) << "timestamp: " << 0; ABSL_LOG(INFO) << "timestamp: " << 0;
auto expected_tensor = tf::test::AsTensor<int32_t>({3, 8, 15}); auto expected_tensor = tf::test::AsTensor<int32_t>({3, 8, 15});
tf::test::ExpectTensorEqual<int32_t>(tensor_mult, expected_tensor); tf::test::ExpectTensorEqual<int32_t>(tensor_mult, expected_tensor);
const tf::Tensor& tensor_mult1 = output_packets_mult[1].Get<tf::Tensor>(); const tf::Tensor& tensor_mult1 = output_packets_mult[1].Get<tf::Tensor>();
auto expected_tensor1 = tf::test::AsTensor<int32_t>({9, 32, 75}); auto expected_tensor1 = tf::test::AsTensor<int32_t>({9, 32, 75});
LOG(INFO) << "timestamp: " << 1; ABSL_LOG(INFO) << "timestamp: " << 1;
tf::test::ExpectTensorEqual<int32_t>(tensor_mult1, expected_tensor1); tf::test::ExpectTensorEqual<int32_t>(tensor_mult1, expected_tensor1);
EXPECT_EQ(2, runner_ EXPECT_EQ(2, runner_
@ -627,12 +629,12 @@ TEST_F(TensorflowInferenceCalculatorTest, TestRecurrentStateOverride) {
runner_->Outputs().Tag(kMultipliedTag).packets; runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(2, output_packets_mult.size()); ASSERT_EQ(2, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>(); const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
LOG(INFO) << "timestamp: " << 0; ABSL_LOG(INFO) << "timestamp: " << 0;
auto expected_tensor = tf::test::AsTensor<int32_t>({3, 4, 5}); auto expected_tensor = tf::test::AsTensor<int32_t>({3, 4, 5});
tf::test::ExpectTensorEqual<int32_t>(tensor_mult, expected_tensor); tf::test::ExpectTensorEqual<int32_t>(tensor_mult, expected_tensor);
const tf::Tensor& tensor_mult1 = output_packets_mult[1].Get<tf::Tensor>(); const tf::Tensor& tensor_mult1 = output_packets_mult[1].Get<tf::Tensor>();
auto expected_tensor1 = tf::test::AsTensor<int32_t>({3, 4, 5}); auto expected_tensor1 = tf::test::AsTensor<int32_t>({3, 4, 5});
LOG(INFO) << "timestamp: " << 1; ABSL_LOG(INFO) << "timestamp: " << 1;
tf::test::ExpectTensorEqual<int32_t>(tensor_mult1, expected_tensor1); tf::test::ExpectTensorEqual<int32_t>(tensor_mult1, expected_tensor1);
EXPECT_EQ(2, runner_ EXPECT_EQ(2, runner_

View File

@ -23,12 +23,12 @@
#include <string> #include <string>
#include "absl/log/absl_log.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h" #include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.pb.h" #include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/clock.h" #include "mediapipe/framework/deps/clock.h"
#include "mediapipe/framework/deps/monotonic_clock.h" #include "mediapipe/framework/deps/monotonic_clock.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/status_util.h" #include "mediapipe/framework/tool/status_util.h"
@ -156,7 +156,7 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase {
cc->OutputSidePackets().Tag(kSessionTag).Set(Adopt(session.release())); cc->OutputSidePackets().Tag(kSessionTag).Set(Adopt(session.release()));
const uint64_t end_time = absl::ToUnixMicros(clock->TimeNow()); const uint64_t end_time = absl::ToUnixMicros(clock->TimeNow());
LOG(INFO) << "Loaded frozen model in: " << end_time - start_time ABSL_LOG(INFO) << "Loaded frozen model in: " << end_time - start_time
<< " microseconds."; << " microseconds.";
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -24,13 +24,13 @@
#include <string> #include <string>
#include "absl/log/absl_log.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h" #include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h" #include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/clock.h" #include "mediapipe/framework/deps/clock.h"
#include "mediapipe/framework/deps/monotonic_clock.h" #include "mediapipe/framework/deps/monotonic_clock.h"
#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/status_util.h" #include "mediapipe/framework/tool/status_util.h"
@ -155,7 +155,7 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator {
output_side_packets->Tag(kSessionTag) = Adopt(session.release()); output_side_packets->Tag(kSessionTag) = Adopt(session.release());
const uint64_t end_time = absl::ToUnixMicros(clock->TimeNow()); const uint64_t end_time = absl::ToUnixMicros(clock->TimeNow());
LOG(INFO) << "Loaded frozen model in: " << end_time - start_time ABSL_LOG(INFO) << "Loaded frozen model in: " << end_time - start_time
<< " microseconds."; << " microseconds.";
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -17,6 +17,7 @@
#if !defined(__ANDROID__) #if !defined(__ANDROID__)
#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/file_helpers.h"
#endif #endif
#include "absl/log/absl_log.h"
#include "absl/strings/str_replace.h" #include "absl/strings/str_replace.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h" #include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.pb.h" #include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.pb.h"
@ -69,7 +70,7 @@ const std::string MaybeConvertSignatureToTag(
[](unsigned char c) { return std::toupper(c); }); [](unsigned char c) { return std::toupper(c); });
output = absl::StrReplaceAll( output = absl::StrReplaceAll(
output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}}); output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}});
LOG(INFO) << "Renamed TAG from: " << name << " to " << output; ABSL_LOG(INFO) << "Renamed TAG from: " << name << " to " << output;
return output; return output;
} else { } else {
return name; return name;

View File

@ -19,6 +19,7 @@
#if !defined(__ANDROID__) #if !defined(__ANDROID__)
#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/file_helpers.h"
#endif #endif
#include "absl/log/absl_log.h"
#include "absl/strings/str_replace.h" #include "absl/strings/str_replace.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h" #include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.pb.h" #include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.pb.h"
@ -75,7 +76,7 @@ const std::string MaybeConvertSignatureToTag(
[](unsigned char c) { return std::toupper(c); }); [](unsigned char c) { return std::toupper(c); });
output = absl::StrReplaceAll( output = absl::StrReplaceAll(
output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}}); output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}});
LOG(INFO) << "Renamed TAG from: " << name << " to " << output; ABSL_LOG(INFO) << "Renamed TAG from: " << name << " to " << output;
return output; return output;
} else { } else {
return name; return name;

View File

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "absl/log/absl_log.h"
#include "absl/strings/match.h" #include "absl/strings/match.h"
#include "mediapipe/calculators/core/packet_resampler_calculator.pb.h" #include "mediapipe/calculators/core/packet_resampler_calculator.pb.h"
#include "mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.pb.h" #include "mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.pb.h"
@ -201,7 +202,7 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
first_timestamp_seen_ = Timestamp::OneOverPostStream().Value(); first_timestamp_seen_ = Timestamp::OneOverPostStream().Value();
for (const auto& map_kv : sequence_->feature_lists().feature_list()) { for (const auto& map_kv : sequence_->feature_lists().feature_list()) {
if (absl::StrContains(map_kv.first, "/timestamp")) { if (absl::StrContains(map_kv.first, "/timestamp")) {
LOG(INFO) << "Found feature timestamps: " << map_kv.first ABSL_LOG(INFO) << "Found feature timestamps: " << map_kv.first
<< " with size: " << map_kv.second.feature_size(); << " with size: " << map_kv.second.feature_size();
int64_t recent_timestamp = Timestamp::PreStream().Value(); int64_t recent_timestamp = Timestamp::PreStream().Value();
for (int i = 0; i < map_kv.second.feature_size(); ++i) { for (int i = 0; i < map_kv.second.feature_size(); ++i) {
@ -309,7 +310,7 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
audio_decoder_options->set_end_time( audio_decoder_options->set_end_time(
end_time + options.extra_padding_from_media_decoder()); end_time + options.extra_padding_from_media_decoder());
} }
LOG(INFO) << "Created AudioDecoderOptions:\n" ABSL_LOG(INFO) << "Created AudioDecoderOptions:\n"
<< audio_decoder_options->DebugString(); << audio_decoder_options->DebugString();
cc->OutputSidePackets() cc->OutputSidePackets()
.Tag(kAudioDecoderOptions) .Tag(kAudioDecoderOptions)
@ -331,7 +332,7 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
->set_end_time(Timestamp::FromSeconds(end_time).Value()); ->set_end_time(Timestamp::FromSeconds(end_time).Value());
} }
LOG(INFO) << "Created PacketResamplerOptions:\n" ABSL_LOG(INFO) << "Created PacketResamplerOptions:\n"
<< resampler_options->DebugString(); << resampler_options->DebugString();
cc->OutputSidePackets() cc->OutputSidePackets()
.Tag(kPacketResamplerOptions) .Tag(kPacketResamplerOptions)
@ -351,7 +352,8 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
absl::Status Process(CalculatorContext* cc) override { absl::Status Process(CalculatorContext* cc) override {
if (timestamps_.empty()) { if (timestamps_.empty()) {
// This occurs when we only have metadata to unpack. // This occurs when we only have metadata to unpack.
LOG(INFO) << "only unpacking metadata because there are no timestamps."; ABSL_LOG(INFO)
<< "only unpacking metadata because there are no timestamps.";
return tool::StatusStop(); return tool::StatusStop();
} }
// In Process(), we loop through timestamps on a reference stream and emit // In Process(), we loop through timestamps on a reference stream and emit

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/log/absl_log.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/numbers.h" #include "absl/strings/numbers.h"
#include "mediapipe/calculators/core/packet_resampler_calculator.pb.h" #include "mediapipe/calculators/core/packet_resampler_calculator.pb.h"
@ -81,7 +82,7 @@ class UnpackMediaSequenceCalculatorTest : public ::testing::Test {
if (options != nullptr) { if (options != nullptr) {
*config.mutable_options() = *options; *config.mutable_options() = *options;
} }
LOG(INFO) << config.DebugString(); ABSL_LOG(INFO) << config.DebugString();
runner_ = absl::make_unique<CalculatorRunner>(config); runner_ = absl::make_unique<CalculatorRunner>(config);
} }

View File

@ -14,6 +14,8 @@
#include <iterator> #include <iterator>
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.pb.h" #include "mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
@ -46,7 +48,7 @@ std::string GetQuantizedFeature(
.Get(index) .Get(index)
.bytes_list() .bytes_list()
.value(); .value();
CHECK_EQ(1, bytes_list.size()); ABSL_CHECK_EQ(1, bytes_list.size());
return bytes_list.Get(0); return bytes_list.Get(0);
} }
} // namespace } // namespace
@ -149,8 +151,9 @@ class UnpackYt8mSequenceExampleCalculator : public CalculatorBase {
.Set(MakePacket<int>(segment_size)); .Set(MakePacket<int>(segment_size));
} }
} }
LOG(INFO) << "Reading the sequence example that contains yt8m id: " ABSL_LOG(INFO) << "Reading the sequence example that contains yt8m id: "
<< yt8m_id << ". Feature list length: " << feature_list_length_; << yt8m_id
<< ". Feature list length: " << feature_list_length_;
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -14,6 +14,7 @@
// //
// Converts vector<float> (or vector<vector<float>>) to 1D (or 2D) tf::Tensor. // Converts vector<float> (or vector<vector<float>>) to 1D (or 2D) tf::Tensor.
#include "absl/log/absl_log.h"
#include "mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator_options.pb.h" #include "mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
@ -68,7 +69,7 @@ absl::Status VectorFloatToTensorCalculator::GetContract(
// Output vector<float>. // Output vector<float>.
); );
} else { } else {
LOG(FATAL) << "input size not supported"; ABSL_LOG(FATAL) << "input size not supported";
} }
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
<< "Only one output stream is supported."; << "Only one output stream is supported.";
@ -125,7 +126,7 @@ absl::Status VectorFloatToTensorCalculator::Process(CalculatorContext* cc) {
} }
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
} else { } else {
LOG(FATAL) << "input size not supported"; ABSL_LOG(FATAL) << "input size not supported";
} }
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -15,6 +15,8 @@
// Converts a single int or vector<int> or vector<vector<int>> to 1D (or 2D) // Converts a single int or vector<int> or vector<vector<int>> to 1D (or 2D)
// tf::Tensor. // tf::Tensor.
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator_options.pb.h" #include "mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
@ -86,7 +88,7 @@ absl::Status VectorIntToTensorCalculator::GetContract(CalculatorContract* cc) {
cc->Inputs().Tag(kVectorInt).Set<std::vector<int>>(); cc->Inputs().Tag(kVectorInt).Set<std::vector<int>>();
} }
} else { } else {
LOG(FATAL) << "input size not supported"; ABSL_LOG(FATAL) << "input size not supported";
} }
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
<< "Only one output stream is supported."; << "Only one output stream is supported.";
@ -113,11 +115,11 @@ absl::Status VectorIntToTensorCalculator::Process(CalculatorContext* cc) {
.Get<std::vector<std::vector<int>>>(); .Get<std::vector<std::vector<int>>>();
const int32_t rows = input.size(); const int32_t rows = input.size();
CHECK_GE(rows, 1); ABSL_CHECK_GE(rows, 1);
const int32_t cols = input[0].size(); const int32_t cols = input[0].size();
CHECK_GE(cols, 1); ABSL_CHECK_GE(cols, 1);
for (int i = 1; i < rows; ++i) { for (int i = 1; i < rows; ++i) {
CHECK_EQ(input[i].size(), cols); ABSL_CHECK_EQ(input[i].size(), cols);
} }
if (options_.transpose()) { if (options_.transpose()) {
tensor_shape = tf::TensorShape({cols, rows}); tensor_shape = tf::TensorShape({cols, rows});
@ -140,7 +142,7 @@ absl::Status VectorIntToTensorCalculator::Process(CalculatorContext* cc) {
AssignMatrixValue<int>(c, r, input[r][c], output.get()); AssignMatrixValue<int>(c, r, input[r][c], output.get());
break; break;
default: default:
LOG(FATAL) << "tensor data type is not supported."; ABSL_LOG(FATAL) << "tensor data type is not supported.";
} }
} }
} }
@ -158,7 +160,7 @@ absl::Status VectorIntToTensorCalculator::Process(CalculatorContext* cc) {
AssignMatrixValue<int>(r, c, input[r][c], output.get()); AssignMatrixValue<int>(r, c, input[r][c], output.get());
break; break;
default: default:
LOG(FATAL) << "tensor data type is not supported."; ABSL_LOG(FATAL) << "tensor data type is not supported.";
} }
} }
} }
@ -171,7 +173,7 @@ absl::Status VectorIntToTensorCalculator::Process(CalculatorContext* cc) {
} else { } else {
input = cc->Inputs().Tag(kVectorInt).Value().Get<std::vector<int>>(); input = cc->Inputs().Tag(kVectorInt).Value().Get<std::vector<int>>();
} }
CHECK_GE(input.size(), 1); ABSL_CHECK_GE(input.size(), 1);
const int32_t length = input.size(); const int32_t length = input.size();
tensor_shape = tf::TensorShape({length}); tensor_shape = tf::TensorShape({length});
auto output = ::absl::make_unique<tf::Tensor>(options_.tensor_data_type(), auto output = ::absl::make_unique<tf::Tensor>(options_.tensor_data_type(),
@ -188,12 +190,12 @@ absl::Status VectorIntToTensorCalculator::Process(CalculatorContext* cc) {
output->tensor<int, 1>()(i) = input.at(i); output->tensor<int, 1>()(i) = input.at(i);
break; break;
default: default:
LOG(FATAL) << "tensor data type is not supported."; ABSL_LOG(FATAL) << "tensor data type is not supported.";
} }
} }
cc->Outputs().Tag(kTensorOut).Add(output.release(), cc->InputTimestamp()); cc->Outputs().Tag(kTensorOut).Add(output.release(), cc->InputTimestamp());
} else { } else {
LOG(FATAL) << "input size not supported"; ABSL_LOG(FATAL) << "input size not supported";
} }
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -15,6 +15,7 @@
// Converts vector<std::string> (or vector<vector<std::string>>) to 1D (or 2D) // Converts vector<std::string> (or vector<vector<std::string>>) to 1D (or 2D)
// tf::Tensor. // tf::Tensor.
#include "absl/log/absl_log.h"
#include "mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator_options.pb.h" #include "mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
@ -69,7 +70,7 @@ absl::Status VectorStringToTensorCalculator::GetContract(
// Input vector<std::string>. // Input vector<std::string>.
); );
} else { } else {
LOG(FATAL) << "input size not supported"; ABSL_LOG(FATAL) << "input size not supported";
} }
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
<< "Only one output stream is supported."; << "Only one output stream is supported.";
@ -129,7 +130,7 @@ absl::Status VectorStringToTensorCalculator::Process(CalculatorContext* cc) {
} }
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
} else { } else {
LOG(FATAL) << "input size not supported"; ABSL_LOG(FATAL) << "input size not supported";
} }
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -103,6 +103,8 @@ cc_library(
"//mediapipe/framework/formats/object_detection:anchor_cc_proto", "//mediapipe/framework/formats/object_detection:anchor_cc_proto",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -196,10 +198,13 @@ cc_library(
deps = [ deps = [
":tflite_inference_calculator_cc_proto", ":tflite_inference_calculator_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/util/tflite:config", "//mediapipe/util/tflite:config",
"//mediapipe/util/tflite:tflite_model_loader", "//mediapipe/util/tflite:tflite_model_loader",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
@ -275,6 +280,7 @@ cc_library(
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/util:resource_util", "//mediapipe/util:resource_util",
"//mediapipe/util/tflite:config", "//mediapipe/util/tflite:config",
"@com_google_absl//absl/log:absl_check",
"@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
] + selects.with_or({ ] + selects.with_or({
@ -392,6 +398,8 @@ cc_library(
"//mediapipe/framework/formats/object_detection:anchor_cc_proto", "//mediapipe/framework/formats/object_detection:anchor_cc_proto",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/util/tflite:config", "//mediapipe/util/tflite:config",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
"@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:framework",
@ -428,6 +436,7 @@ cc_library(
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/util:resource_util", "//mediapipe/util:resource_util",
"@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
"@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:framework",
@ -456,6 +465,7 @@ cc_library(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/log:absl_check",
"@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:framework",
], ],
alwayslink = 1, alwayslink = 1,

View File

@ -16,6 +16,8 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h" #include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/object_detection/anchor.pb.h" #include "mediapipe/framework/formats/object_detection/anchor.pb.h"
@ -272,13 +274,13 @@ absl::Status SsdAnchorsCalculator::GenerateAnchors(
if (options.feature_map_height_size()) { if (options.feature_map_height_size()) {
if (options.strides_size()) { if (options.strides_size()) {
LOG(ERROR) << "Found feature map shapes. Strides will be ignored."; ABSL_LOG(ERROR) << "Found feature map shapes. Strides will be ignored.";
} }
CHECK_EQ(options.feature_map_height_size(), kNumLayers); ABSL_CHECK_EQ(options.feature_map_height_size(), kNumLayers);
CHECK_EQ(options.feature_map_height_size(), ABSL_CHECK_EQ(options.feature_map_height_size(),
options.feature_map_width_size()); options.feature_map_width_size());
} else { } else {
CHECK_EQ(options.strides_size(), kNumLayers); ABSL_CHECK_EQ(options.strides_size(), kNumLayers);
} }
if (options.multiscale_anchor_generation()) { if (options.multiscale_anchor_generation()) {

View File

@ -15,6 +15,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "absl/log/absl_check.h"
#include "mediapipe/calculators/tflite/tflite_converter_calculator.pb.h" #include "mediapipe/calculators/tflite/tflite_converter_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
@ -643,7 +644,7 @@ absl::Status TfLiteConverterCalculator::LoadOptions(CalculatorContext* cc) {
if (options.has_output_tensor_float_range()) { if (options.has_output_tensor_float_range()) {
output_range_.emplace(options.output_tensor_float_range().min(), output_range_.emplace(options.output_tensor_float_range().min(),
options.output_tensor_float_range().max()); options.output_tensor_float_range().max());
CHECK_GT(output_range_->second, output_range_->first); ABSL_CHECK_GT(output_range_->second, output_range_->first);
} }
// Custom div and sub values. // Custom div and sub values.
@ -661,9 +662,9 @@ absl::Status TfLiteConverterCalculator::LoadOptions(CalculatorContext* cc) {
// Get desired way to handle input channels. // Get desired way to handle input channels.
max_num_channels_ = options.max_num_channels(); max_num_channels_ = options.max_num_channels();
CHECK_GE(max_num_channels_, 1); ABSL_CHECK_GE(max_num_channels_, 1);
CHECK_LE(max_num_channels_, 4); ABSL_CHECK_LE(max_num_channels_, 4);
CHECK_NE(max_num_channels_, 2); ABSL_CHECK_NE(max_num_channels_, 2);
#if defined(MEDIAPIPE_IOS) #if defined(MEDIAPIPE_IOS)
if (cc->Inputs().HasTag(kGpuBufferTag)) if (cc->Inputs().HasTag(kGpuBufferTag))
// Currently on iOS, tflite gpu input tensor must be 4 channels, // Currently on iOS, tflite gpu input tensor must be 4 channels,

View File

@ -17,9 +17,12 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h" #include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/tflite/config.h" #include "mediapipe/util/tflite/config.h"
@ -109,7 +112,7 @@ std::unique_ptr<tflite::Interpreter> BuildEdgeTpuInterpreter(
edgetpu::EdgeTpuContext* edgetpu_context) { edgetpu::EdgeTpuContext* edgetpu_context) {
resolver->AddCustom(edgetpu::kCustomOp, edgetpu::RegisterCustomOp()); resolver->AddCustom(edgetpu::kCustomOp, edgetpu::RegisterCustomOp());
std::unique_ptr<tflite::Interpreter> interpreter; std::unique_ptr<tflite::Interpreter> interpreter;
CHECK_EQ(tflite::InterpreterBuilder(model, *resolver)(&interpreter), ABSL_CHECK_EQ(tflite::InterpreterBuilder(model, *resolver)(&interpreter),
kTfLiteOk); kTfLiteOk);
interpreter->SetExternalContext(kTfLiteEdgeTpuContext, edgetpu_context); interpreter->SetExternalContext(kTfLiteEdgeTpuContext, edgetpu_context);
return interpreter; return interpreter;
@ -406,11 +409,12 @@ absl::Status TfLiteInferenceCalculator::Open(CalculatorContext* cc) {
} }
if (use_advanced_gpu_api_ && !gpu_input_) { if (use_advanced_gpu_api_ && !gpu_input_) {
LOG(WARNING) << "Cannot use advanced GPU APIs, input must be GPU buffers." ABSL_LOG(WARNING)
<< "Cannot use advanced GPU APIs, input must be GPU buffers."
"Falling back to the default TFLite API."; "Falling back to the default TFLite API.";
use_advanced_gpu_api_ = false; use_advanced_gpu_api_ = false;
} }
CHECK(!use_advanced_gpu_api_ || gpu_inference_); ABSL_CHECK(!use_advanced_gpu_api_ || gpu_inference_);
MP_RETURN_IF_ERROR(LoadModel(cc)); MP_RETURN_IF_ERROR(LoadModel(cc));
@ -802,9 +806,10 @@ absl::Status TfLiteInferenceCalculator::InitTFLiteGPURunner(
const int tensor_idx = interpreter_->inputs()[i]; const int tensor_idx = interpreter_->inputs()[i];
interpreter_->SetTensorParametersReadWrite(tensor_idx, kTfLiteFloat32, "", interpreter_->SetTensorParametersReadWrite(tensor_idx, kTfLiteFloat32, "",
shape, quant); shape, quant);
CHECK(interpreter_->ResizeInputTensor(tensor_idx, shape) == kTfLiteOk); ABSL_CHECK(interpreter_->ResizeInputTensor(tensor_idx, shape) ==
kTfLiteOk);
} }
CHECK(interpreter_->AllocateTensors() == kTfLiteOk); ABSL_CHECK(interpreter_->AllocateTensors() == kTfLiteOk);
} }
// Create and bind OpenGL buffers for outputs. // Create and bind OpenGL buffers for outputs.
@ -1053,7 +1058,7 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) {
gpu_data_in_[i]->shape.w * gpu_data_in_[i]->shape.c; gpu_data_in_[i]->shape.w * gpu_data_in_[i]->shape.c;
// Input to model can be RGBA only. // Input to model can be RGBA only.
if (tensor->dims->data[3] != 4) { if (tensor->dims->data[3] != 4) {
LOG(WARNING) << "Please ensure input GPU tensor is 4 channels."; ABSL_LOG(WARNING) << "Please ensure input GPU tensor is 4 channels.";
} }
const std::string shader_source = const std::string shader_source =
absl::Substitute(R"(#include <metal_stdlib> absl::Substitute(R"(#include <metal_stdlib>

View File

@ -17,6 +17,7 @@
#include <vector> #include <vector>
#include "absl/container/node_hash_map.h" #include "absl/container/node_hash_map.h"
#include "absl/log/absl_check.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "mediapipe/calculators/tflite/tflite_tensors_to_classification_calculator.pb.h" #include "mediapipe/calculators/tflite/tflite_tensors_to_classification_calculator.pb.h"
@ -172,7 +173,7 @@ absl::Status TfLiteTensorsToClassificationCalculator::Process(
// Note that partial_sort will raise error when top_k_ > // Note that partial_sort will raise error when top_k_ >
// classification_list->classification_size(). // classification_list->classification_size().
CHECK_GE(classification_list->classification_size(), top_k_); ABSL_CHECK_GE(classification_list->classification_size(), top_k_);
auto raw_classification_list = classification_list->mutable_classification(); auto raw_classification_list = classification_list->mutable_classification();
if (top_k_ > 0 && classification_list->classification_size() >= top_k_) { if (top_k_ > 0 && classification_list->classification_size() >= top_k_) {
std::partial_sort(raw_classification_list->begin(), std::partial_sort(raw_classification_list->begin(),

View File

@ -15,6 +15,8 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.pb.h" #include "mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.pb.h"
@ -93,7 +95,7 @@ void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes,
void ConvertAnchorsToRawValues(const std::vector<Anchor>& anchors, void ConvertAnchorsToRawValues(const std::vector<Anchor>& anchors,
int num_boxes, float* raw_anchors) { int num_boxes, float* raw_anchors) {
CHECK_EQ(anchors.size(), num_boxes); ABSL_CHECK_EQ(anchors.size(), num_boxes);
int box = 0; int box = 0;
for (const auto& anchor : anchors) { for (const auto& anchor : anchors) {
raw_anchors[box * kNumCoordsPerBox + 0] = anchor.y_center(); raw_anchors[box * kNumCoordsPerBox + 0] = anchor.y_center();
@ -288,14 +290,14 @@ absl::Status TfLiteTensorsToDetectionsCalculator::ProcessCPU(
const TfLiteTensor* raw_score_tensor = &input_tensors[1]; const TfLiteTensor* raw_score_tensor = &input_tensors[1];
// TODO: Add flexible input tensor size handling. // TODO: Add flexible input tensor size handling.
CHECK_EQ(raw_box_tensor->dims->size, 3); ABSL_CHECK_EQ(raw_box_tensor->dims->size, 3);
CHECK_EQ(raw_box_tensor->dims->data[0], 1); ABSL_CHECK_EQ(raw_box_tensor->dims->data[0], 1);
CHECK_EQ(raw_box_tensor->dims->data[1], num_boxes_); ABSL_CHECK_EQ(raw_box_tensor->dims->data[1], num_boxes_);
CHECK_EQ(raw_box_tensor->dims->data[2], num_coords_); ABSL_CHECK_EQ(raw_box_tensor->dims->data[2], num_coords_);
CHECK_EQ(raw_score_tensor->dims->size, 3); ABSL_CHECK_EQ(raw_score_tensor->dims->size, 3);
CHECK_EQ(raw_score_tensor->dims->data[0], 1); ABSL_CHECK_EQ(raw_score_tensor->dims->data[0], 1);
CHECK_EQ(raw_score_tensor->dims->data[1], num_boxes_); ABSL_CHECK_EQ(raw_score_tensor->dims->data[1], num_boxes_);
CHECK_EQ(raw_score_tensor->dims->data[2], num_classes_); ABSL_CHECK_EQ(raw_score_tensor->dims->data[2], num_classes_);
const float* raw_boxes = raw_box_tensor->data.f; const float* raw_boxes = raw_box_tensor->data.f;
const float* raw_scores = raw_score_tensor->data.f; const float* raw_scores = raw_score_tensor->data.f;
@ -303,13 +305,13 @@ absl::Status TfLiteTensorsToDetectionsCalculator::ProcessCPU(
if (!anchors_init_) { if (!anchors_init_) {
if (input_tensors.size() == kNumInputTensorsWithAnchors) { if (input_tensors.size() == kNumInputTensorsWithAnchors) {
const TfLiteTensor* anchor_tensor = &input_tensors[2]; const TfLiteTensor* anchor_tensor = &input_tensors[2];
CHECK_EQ(anchor_tensor->dims->size, 2); ABSL_CHECK_EQ(anchor_tensor->dims->size, 2);
CHECK_EQ(anchor_tensor->dims->data[0], num_boxes_); ABSL_CHECK_EQ(anchor_tensor->dims->data[0], num_boxes_);
CHECK_EQ(anchor_tensor->dims->data[1], kNumCoordsPerBox); ABSL_CHECK_EQ(anchor_tensor->dims->data[1], kNumCoordsPerBox);
const float* raw_anchors = anchor_tensor->data.f; const float* raw_anchors = anchor_tensor->data.f;
ConvertRawValuesToAnchors(raw_anchors, num_boxes_, &anchors_); ConvertRawValuesToAnchors(raw_anchors, num_boxes_, &anchors_);
} else if (side_packet_anchors_) { } else if (side_packet_anchors_) {
CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty()); ABSL_CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty());
anchors_ = anchors_ =
cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>(); cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>();
} else { } else {
@ -409,7 +411,7 @@ absl::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU(
CopyBuffer(input_tensors[1], gpu_data_->raw_scores_buffer)); CopyBuffer(input_tensors[1], gpu_data_->raw_scores_buffer));
if (!anchors_init_) { if (!anchors_init_) {
if (side_packet_anchors_) { if (side_packet_anchors_) {
CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty()); ABSL_CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty());
const auto& anchors = const auto& anchors =
cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>(); cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>();
std::vector<float> raw_anchors(num_boxes_ * kNumCoordsPerBox); std::vector<float> raw_anchors(num_boxes_ * kNumCoordsPerBox);
@ -417,7 +419,7 @@ absl::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU(
MP_RETURN_IF_ERROR(gpu_data_->raw_anchors_buffer.Write<float>( MP_RETURN_IF_ERROR(gpu_data_->raw_anchors_buffer.Write<float>(
absl::MakeSpan(raw_anchors))); absl::MakeSpan(raw_anchors)));
} else { } else {
CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors); ABSL_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors);
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
CopyBuffer(input_tensors[2], gpu_data_->raw_anchors_buffer)); CopyBuffer(input_tensors[2], gpu_data_->raw_anchors_buffer));
} }
@ -477,7 +479,7 @@ absl::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU(
commandBuffer:[gpu_helper_ commandBuffer]]; commandBuffer:[gpu_helper_ commandBuffer]];
if (!anchors_init_) { if (!anchors_init_) {
if (side_packet_anchors_) { if (side_packet_anchors_) {
CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty()); ABSL_CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty());
const auto& anchors = const auto& anchors =
cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>(); cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>();
std::vector<float> raw_anchors(num_boxes_ * kNumCoordsPerBox); std::vector<float> raw_anchors(num_boxes_ * kNumCoordsPerBox);
@ -541,7 +543,7 @@ absl::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU(
output_detections)); output_detections));
#else #else
LOG(ERROR) << "GPU input on non-Android not supported yet."; ABSL_LOG(ERROR) << "GPU input on non-Android not supported yet.";
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE #endif // MEDIAPIPE_TFLITE_GL_INFERENCE
return absl::OkStatus(); return absl::OkStatus();
} }
@ -567,10 +569,10 @@ absl::Status TfLiteTensorsToDetectionsCalculator::LoadOptions(
num_coords_ = options_.num_coords(); num_coords_ = options_.num_coords();
// Currently only support 2D when num_values_per_keypoint equals to 2. // Currently only support 2D when num_values_per_keypoint equals to 2.
CHECK_EQ(options_.num_values_per_keypoint(), 2); ABSL_CHECK_EQ(options_.num_values_per_keypoint(), 2);
// Check if the output size is equal to the requested boxes and keypoints. // Check if the output size is equal to the requested boxes and keypoints.
CHECK_EQ(options_.num_keypoints() * options_.num_values_per_keypoint() + ABSL_CHECK_EQ(options_.num_keypoints() * options_.num_values_per_keypoint() +
kNumCoordsPerBox, kNumCoordsPerBox,
num_coords_); num_coords_);
@ -897,10 +899,11 @@ void main() {
int max_wg_size; // typically <= 1024 int max_wg_size; // typically <= 1024
glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1, glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1,
&max_wg_size); // y-dim &max_wg_size); // y-dim
CHECK_LT(num_classes_, max_wg_size) ABSL_CHECK_LT(num_classes_, max_wg_size)
<< "# classes must be < " << max_wg_size; << "# classes must be < " << max_wg_size;
// TODO support better filtering. // TODO support better filtering.
CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed"; ABSL_CHECK_LE(ignore_classes_.size(), 1)
<< "Only ignore class 0 is allowed";
// Shader program // Shader program
GlShader score_shader; GlShader score_shader;
@ -1115,7 +1118,7 @@ kernel void scoreKernel(
ignore_classes_.size() ? 1 : 0); ignore_classes_.size() ? 1 : 0);
// TODO support better filtering. // TODO support better filtering.
CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed"; ABSL_CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed";
{ {
// Shader program // Shader program
@ -1147,7 +1150,8 @@ kernel void scoreKernel(
options:MTLResourceStorageModeShared]; options:MTLResourceStorageModeShared];
// # filter classes supported is hardware dependent. // # filter classes supported is hardware dependent.
int max_wg_size = gpu_data_->score_program.maxTotalThreadsPerThreadgroup; int max_wg_size = gpu_data_->score_program.maxTotalThreadsPerThreadgroup;
CHECK_LT(num_classes_, max_wg_size) << "# classes must be <" << max_wg_size; ABSL_CHECK_LT(num_classes_, max_wg_size)
<< "# classes must be <" << max_wg_size;
} }
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE #endif // MEDIAPIPE_TFLITE_GL_INFERENCE

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/log/absl_check.h"
#include "mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.pb.h" #include "mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
@ -199,7 +200,7 @@ absl::Status TfLiteTensorsToLandmarksCalculator::Process(
num_values *= raw_tensor->dims->data[i]; num_values *= raw_tensor->dims->data[i];
} }
const int num_dimensions = num_values / num_landmarks_; const int num_dimensions = num_values / num_landmarks_;
CHECK_GT(num_dimensions, 0); ABSL_CHECK_GT(num_dimensions, 0);
const float* raw_landmarks = raw_tensor->data.f; const float* raw_landmarks = raw_tensor->data.f;

View File

@ -183,9 +183,9 @@ cc_library(
"//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_options_cc_proto",
"//mediapipe/framework:timestamp", "//mediapipe/framework:timestamp",
"//mediapipe/framework/deps:clock", "//mediapipe/framework/deps:clock",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/time", "@com_google_absl//absl/time",
], ],
@ -248,11 +248,12 @@ cc_library(
":annotation_overlay_calculator_cc_proto", ":annotation_overlay_calculator_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_options_cc_proto",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/formats:image_opencv",
"//mediapipe/framework/formats:video_stream_header", "//mediapipe/framework/formats:video_stream_header",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
@ -260,6 +261,7 @@ cc_library(
"//mediapipe/util:annotation_renderer", "//mediapipe/util:annotation_renderer",
"//mediapipe/util:color_cc_proto", "//mediapipe/util:color_cc_proto",
"//mediapipe/util:render_data_cc_proto", "//mediapipe/util:render_data_cc_proto",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
] + select({ ] + select({
"//mediapipe/gpu:disable_gpu": [], "//mediapipe/gpu:disable_gpu": [],
@ -374,9 +376,10 @@ cc_library(
"//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:rectangle", "//mediapipe/framework/port:rectangle",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -675,6 +678,7 @@ cc_library(
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/util:color_cc_proto", "//mediapipe/util:color_cc_proto",
"//mediapipe/util:render_data_cc_proto", "//mediapipe/util:render_data_cc_proto",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
@ -731,6 +735,7 @@ cc_library(
"//mediapipe/framework/port:statusor", "//mediapipe/framework/port:statusor",
"//mediapipe/util:color_cc_proto", "//mediapipe/util:color_cc_proto",
"//mediapipe/util:render_data_cc_proto", "//mediapipe/util:render_data_cc_proto",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
alwayslink = 1, alwayslink = 1,
@ -746,6 +751,7 @@ cc_library(
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/util:color_cc_proto", "//mediapipe/util:color_cc_proto",
"//mediapipe/util:render_data_cc_proto", "//mediapipe/util:render_data_cc_proto",
"@com_google_absl//absl/log:absl_check",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -1149,6 +1155,7 @@ cc_library(
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/log:absl_log",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -1209,6 +1216,7 @@ cc_library(
"//mediapipe/framework/port:rectangle", "//mediapipe/framework/port:rectangle",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/util:rectangle_util", "//mediapipe/util:rectangle_util",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
], ],
alwayslink = 1, alwayslink = 1,
@ -1480,6 +1488,7 @@ cc_library(
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
], ],
alwayslink = 1, alwayslink = 1,

View File

@ -14,15 +14,17 @@
#include <memory> #include <memory>
#include "absl/log/absl_log.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "mediapipe/calculators/util/annotation_overlay_calculator.pb.h" #include "mediapipe/calculators/util/annotation_overlay_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_options.pb.h" #include "mediapipe/framework/calculator_options.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/formats/image_format.pb.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/image_opencv.h"
#include "mediapipe/framework/formats/video_stream_header.h" #include "mediapipe/framework/formats/video_stream_header.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
@ -45,6 +47,7 @@ namespace {
constexpr char kVectorTag[] = "VECTOR"; constexpr char kVectorTag[] = "VECTOR";
constexpr char kGpuBufferTag[] = "IMAGE_GPU"; constexpr char kGpuBufferTag[] = "IMAGE_GPU";
constexpr char kImageFrameTag[] = "IMAGE"; constexpr char kImageFrameTag[] = "IMAGE";
constexpr char kImageTag[] = "UIMAGE"; // Universal Image
enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
@ -57,13 +60,16 @@ size_t RoundUp(size_t n, size_t m) { return ((n + m - 1) / m) * m; } // NOLINT
constexpr uchar kAnnotationBackgroundColor = 2; // Grayscale value. constexpr uchar kAnnotationBackgroundColor = 2; // Grayscale value.
// Future Image type. // Future Image type.
inline bool HasImageTag(mediapipe::CalculatorContext* cc) { return false; } inline bool HasImageTag(mediapipe::CalculatorContext* cc) {
return cc->Inputs().HasTag(kImageTag);
}
} // namespace } // namespace
// A calculator for rendering data on images. // A calculator for rendering data on images.
// //
// Inputs: // Inputs:
// 1. IMAGE or IMAGE_GPU (optional): An ImageFrame (or GpuBuffer), // 1. IMAGE or IMAGE_GPU (optional): An ImageFrame (or GpuBuffer),
// or UIMAGE (an Image).
// containing the input image. // containing the input image.
// If output is CPU, and input isn't provided, the renderer creates a // If output is CPU, and input isn't provided, the renderer creates a
// blank canvas with the width, height and color provided in the options. // blank canvas with the width, height and color provided in the options.
@ -76,6 +82,7 @@ inline bool HasImageTag(mediapipe::CalculatorContext* cc) { return false; }
// //
// Output: // Output:
// 1. IMAGE or IMAGE_GPU: A rendered ImageFrame (or GpuBuffer), // 1. IMAGE or IMAGE_GPU: A rendered ImageFrame (or GpuBuffer),
// or UIMAGE (an Image).
// Note: Output types should match their corresponding input stream type. // Note: Output types should match their corresponding input stream type.
// //
// For CPU input frames, only SRGBA, SRGB and GRAY8 format are supported. The // For CPU input frames, only SRGBA, SRGB and GRAY8 format are supported. The
@ -135,6 +142,9 @@ class AnnotationOverlayCalculator : public CalculatorBase {
absl::Status CreateRenderTargetCpu(CalculatorContext* cc, absl::Status CreateRenderTargetCpu(CalculatorContext* cc,
std::unique_ptr<cv::Mat>& image_mat, std::unique_ptr<cv::Mat>& image_mat,
ImageFormat::Format* target_format); ImageFormat::Format* target_format);
absl::Status CreateRenderTargetCpuImage(CalculatorContext* cc,
std::unique_ptr<cv::Mat>& image_mat,
ImageFormat::Format* target_format);
template <typename Type, const char* Tag> template <typename Type, const char* Tag>
absl::Status CreateRenderTargetGpu(CalculatorContext* cc, absl::Status CreateRenderTargetGpu(CalculatorContext* cc,
std::unique_ptr<cv::Mat>& image_mat); std::unique_ptr<cv::Mat>& image_mat);
@ -172,30 +182,38 @@ class AnnotationOverlayCalculator : public CalculatorBase {
REGISTER_CALCULATOR(AnnotationOverlayCalculator); REGISTER_CALCULATOR(AnnotationOverlayCalculator);
absl::Status AnnotationOverlayCalculator::GetContract(CalculatorContract* cc) { absl::Status AnnotationOverlayCalculator::GetContract(CalculatorContract* cc) {
CHECK_GE(cc->Inputs().NumEntries(), 1); RET_CHECK_GE(cc->Inputs().NumEntries(), 1);
bool use_gpu = false; bool use_gpu = false;
if (cc->Inputs().HasTag(kImageFrameTag) && RET_CHECK(cc->Inputs().HasTag(kImageFrameTag) +
cc->Inputs().HasTag(kGpuBufferTag)) { cc->Inputs().HasTag(kGpuBufferTag) +
return absl::InternalError("Cannot have multiple input images."); cc->Inputs().HasTag(kImageTag) <=
} 1);
if (cc->Inputs().HasTag(kGpuBufferTag) != RET_CHECK(cc->Outputs().HasTag(kImageFrameTag) +
cc->Outputs().HasTag(kGpuBufferTag)) { cc->Outputs().HasTag(kGpuBufferTag) +
return absl::InternalError("GPU output must have GPU input."); cc->Outputs().HasTag(kImageTag) ==
} 1);
// Input image to render onto copy of. Should be same type as output. // Input image to render onto copy of. Should be same type as output.
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
if (cc->Inputs().HasTag(kGpuBufferTag)) { if (cc->Inputs().HasTag(kGpuBufferTag)) {
cc->Inputs().Tag(kGpuBufferTag).Set<mediapipe::GpuBuffer>(); cc->Inputs().Tag(kGpuBufferTag).Set<mediapipe::GpuBuffer>();
CHECK(cc->Outputs().HasTag(kGpuBufferTag)); RET_CHECK(cc->Outputs().HasTag(kGpuBufferTag));
use_gpu = true; use_gpu = true;
} }
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Inputs().HasTag(kImageFrameTag)) { if (cc->Inputs().HasTag(kImageFrameTag)) {
cc->Inputs().Tag(kImageFrameTag).Set<ImageFrame>(); cc->Inputs().Tag(kImageFrameTag).Set<ImageFrame>();
CHECK(cc->Outputs().HasTag(kImageFrameTag)); RET_CHECK(cc->Outputs().HasTag(kImageFrameTag));
}
if (cc->Inputs().HasTag(kImageTag)) {
cc->Inputs().Tag(kImageTag).Set<mediapipe::Image>();
RET_CHECK(cc->Outputs().HasTag(kImageTag));
#if !MEDIAPIPE_DISABLE_GPU
use_gpu = true; // Prepare GPU resources because images can come in on GPU.
#endif
} }
// Data streams to render. // Data streams to render.
@ -220,6 +238,9 @@ absl::Status AnnotationOverlayCalculator::GetContract(CalculatorContract* cc) {
if (cc->Outputs().HasTag(kImageFrameTag)) { if (cc->Outputs().HasTag(kImageFrameTag)) {
cc->Outputs().Tag(kImageFrameTag).Set<ImageFrame>(); cc->Outputs().Tag(kImageFrameTag).Set<ImageFrame>();
} }
if (cc->Outputs().HasTag(kImageTag)) {
cc->Outputs().Tag(kImageTag).Set<mediapipe::Image>();
}
if (use_gpu) { if (use_gpu) {
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
@ -252,9 +273,14 @@ absl::Status AnnotationOverlayCalculator::Open(CalculatorContext* cc) {
renderer_ = absl::make_unique<AnnotationRenderer>(); renderer_ = absl::make_unique<AnnotationRenderer>();
renderer_->SetFlipTextVertically(options_.flip_text_vertically()); renderer_->SetFlipTextVertically(options_.flip_text_vertically());
if (use_gpu_) renderer_->SetScaleFactor(options_.gpu_scale_factor()); if (use_gpu_) renderer_->SetScaleFactor(options_.gpu_scale_factor());
if (renderer_->GetScaleFactor() < 1.0 && HasImageTag(cc))
ABSL_LOG(WARNING)
<< "Annotation scale factor only supports GPU backed Image.";
// Set the output header based on the input header (if present). // Set the output header based on the input header (if present).
const char* tag = use_gpu_ ? kGpuBufferTag : kImageFrameTag; const char* tag = HasImageTag(cc) ? kImageTag
: use_gpu_ ? kGpuBufferTag
: kImageFrameTag;
if (image_frame_available_ && !cc->Inputs().Tag(tag).Header().IsEmpty()) { if (image_frame_available_ && !cc->Inputs().Tag(tag).Header().IsEmpty()) {
const auto& input_header = const auto& input_header =
cc->Inputs().Tag(tag).Header().Get<VideoHeader>(); cc->Inputs().Tag(tag).Header().Get<VideoHeader>();
@ -280,6 +306,12 @@ absl::Status AnnotationOverlayCalculator::Process(CalculatorContext* cc) {
cc->Inputs().Tag(kImageFrameTag).IsEmpty()) { cc->Inputs().Tag(kImageFrameTag).IsEmpty()) {
return absl::OkStatus(); return absl::OkStatus();
} }
if (cc->Inputs().HasTag(kImageTag) && cc->Inputs().Tag(kImageTag).IsEmpty()) {
return absl::OkStatus();
}
if (HasImageTag(cc)) {
use_gpu_ = cc->Inputs().Tag(kImageTag).Get<mediapipe::Image>().UsesGpu();
}
// Initialize render target, drawn with OpenCV. // Initialize render target, drawn with OpenCV.
std::unique_ptr<cv::Mat> image_mat; std::unique_ptr<cv::Mat> image_mat;
@ -289,10 +321,17 @@ absl::Status AnnotationOverlayCalculator::Process(CalculatorContext* cc) {
if (!gpu_initialized_) { if (!gpu_initialized_) {
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, cc]() -> absl::Status { gpu_helper_.RunInGlContext([this, cc]() -> absl::Status {
if (HasImageTag(cc)) {
return GlSetup<mediapipe::Image, kImageTag>(cc);
}
return GlSetup<mediapipe::GpuBuffer, kGpuBufferTag>(cc); return GlSetup<mediapipe::GpuBuffer, kGpuBufferTag>(cc);
})); }));
gpu_initialized_ = true; gpu_initialized_ = true;
} }
if (HasImageTag(cc)) {
MP_RETURN_IF_ERROR(
(CreateRenderTargetGpu<mediapipe::Image, kImageTag>(cc, image_mat)));
}
if (cc->Inputs().HasTag(kGpuBufferTag)) { if (cc->Inputs().HasTag(kGpuBufferTag)) {
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
(CreateRenderTargetGpu<mediapipe::GpuBuffer, kGpuBufferTag>( (CreateRenderTargetGpu<mediapipe::GpuBuffer, kGpuBufferTag>(
@ -300,6 +339,10 @@ absl::Status AnnotationOverlayCalculator::Process(CalculatorContext* cc) {
} }
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
} else { } else {
if (cc->Outputs().HasTag(kImageTag)) {
MP_RETURN_IF_ERROR(
CreateRenderTargetCpuImage(cc, image_mat, &target_format));
}
if (cc->Outputs().HasTag(kImageFrameTag)) { if (cc->Outputs().HasTag(kImageFrameTag)) {
MP_RETURN_IF_ERROR(CreateRenderTargetCpu(cc, image_mat, &target_format)); MP_RETURN_IF_ERROR(CreateRenderTargetCpu(cc, image_mat, &target_format));
} }
@ -339,6 +382,9 @@ absl::Status AnnotationOverlayCalculator::Process(CalculatorContext* cc) {
uchar* image_mat_ptr = image_mat->data; uchar* image_mat_ptr = image_mat->data;
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, cc, image_mat_ptr]() -> absl::Status { gpu_helper_.RunInGlContext([this, cc, image_mat_ptr]() -> absl::Status {
if (HasImageTag(cc)) {
return RenderToGpu<mediapipe::Image, kImageTag>(cc, image_mat_ptr);
}
return RenderToGpu<mediapipe::GpuBuffer, kGpuBufferTag>( return RenderToGpu<mediapipe::GpuBuffer, kGpuBufferTag>(
cc, image_mat_ptr); cc, image_mat_ptr);
})); }));
@ -381,6 +427,10 @@ absl::Status AnnotationOverlayCalculator::RenderToCpu(
ImageFrame::kDefaultAlignmentBoundary); ImageFrame::kDefaultAlignmentBoundary);
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
if (HasImageTag(cc)) {
auto out = std::make_unique<mediapipe::Image>(std::move(output_frame));
cc->Outputs().Tag(kImageTag).Add(out.release(), cc->InputTimestamp());
}
if (cc->Outputs().HasTag(kImageFrameTag)) { if (cc->Outputs().HasTag(kImageFrameTag)) {
cc->Outputs() cc->Outputs()
.Tag(kImageFrameTag) .Tag(kImageFrameTag)
@ -487,6 +537,54 @@ absl::Status AnnotationOverlayCalculator::CreateRenderTargetCpu(
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status AnnotationOverlayCalculator::CreateRenderTargetCpuImage(
CalculatorContext* cc, std::unique_ptr<cv::Mat>& image_mat,
ImageFormat::Format* target_format) {
if (image_frame_available_) {
const auto& input_frame =
cc->Inputs().Tag(kImageTag).Get<mediapipe::Image>();
int target_mat_type;
switch (input_frame.image_format()) {
case ImageFormat::SRGBA:
*target_format = ImageFormat::SRGBA;
target_mat_type = CV_8UC4;
break;
case ImageFormat::SRGB:
*target_format = ImageFormat::SRGB;
target_mat_type = CV_8UC3;
break;
case ImageFormat::GRAY8:
*target_format = ImageFormat::SRGB;
target_mat_type = CV_8UC3;
break;
default:
return absl::UnknownError("Unexpected image frame format.");
break;
}
image_mat = absl::make_unique<cv::Mat>(
input_frame.height(), input_frame.width(), target_mat_type);
auto input_mat = formats::MatView(&input_frame);
if (input_frame.image_format() == ImageFormat::GRAY8) {
cv::Mat rgb_mat;
cv::cvtColor(*input_mat, rgb_mat, cv::COLOR_GRAY2RGB);
rgb_mat.copyTo(*image_mat);
} else {
input_mat->copyTo(*image_mat);
}
} else {
image_mat = absl::make_unique<cv::Mat>(
options_.canvas_height_px(), options_.canvas_width_px(), CV_8UC3,
cv::Scalar(options_.canvas_color().r(), options_.canvas_color().g(),
options_.canvas_color().b()));
*target_format = ImageFormat::SRGB;
}
return absl::OkStatus();
}
template <typename Type, const char* Tag> template <typename Type, const char* Tag>
absl::Status AnnotationOverlayCalculator::CreateRenderTargetGpu( absl::Status AnnotationOverlayCalculator::CreateRenderTargetGpu(
CalculatorContext* cc, std::unique_ptr<cv::Mat>& image_mat) { CalculatorContext* cc, std::unique_ptr<cv::Mat>& image_mat) {

View File

@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "absl/log/absl_check.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "mediapipe/calculators/util/association_calculator.pb.h" #include "mediapipe/calculators/util/association_calculator.pb.h"
#include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_context.h"
@ -72,7 +73,7 @@ class AssociationCalculator : public CalculatorBase {
prev_input_stream_id_ = cc->Inputs().GetId("PREV", 0); prev_input_stream_id_ = cc->Inputs().GetId("PREV", 0);
} }
options_ = cc->Options<::mediapipe::AssociationCalculatorOptions>(); options_ = cc->Options<::mediapipe::AssociationCalculatorOptions>();
CHECK_GE(options_.min_similarity_threshold(), 0); ABSL_CHECK_GE(options_.min_similarity_threshold(), 0);
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -19,6 +19,7 @@
#include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/proto_ns.h" #include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/resource_util.h" #include "mediapipe/util/resource_util.h"
@ -85,7 +86,8 @@ absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) {
ASSIGN_OR_RETURN(string_path, ASSIGN_OR_RETURN(string_path,
PathToResourceAsFile(options.label_map_path())); PathToResourceAsFile(options.label_map_path()));
std::string label_map_string; std::string label_map_string;
MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string)); MP_RETURN_IF_ERROR(
mediapipe::GetResourceContents(string_path, &label_map_string));
std::istringstream stream(label_map_string); std::istringstream stream(label_map_string);
std::string line; std::string line;

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/log/absl_check.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h" #include "absl/strings/str_join.h"
@ -233,13 +234,13 @@ void DetectionsToRenderDataCalculator::AddLabels(
const Detection& detection, const Detection& detection,
const DetectionsToRenderDataCalculatorOptions& options, const DetectionsToRenderDataCalculatorOptions& options,
float text_line_height, RenderData* render_data) { float text_line_height, RenderData* render_data) {
CHECK(detection.label().empty() || detection.label_id().empty() || ABSL_CHECK(detection.label().empty() || detection.label_id().empty() ||
detection.label_size() == detection.label_id_size()) detection.label_size() == detection.label_id_size())
<< "String or integer labels should be of same size. Or only one of them " << "String or integer labels should be of same size. Or only one of them "
"is present."; "is present.";
const auto num_labels = const auto num_labels =
std::max(detection.label_size(), detection.label_id_size()); std::max(detection.label_size(), detection.label_id_size());
CHECK_EQ(detection.score_size(), num_labels) ABSL_CHECK_EQ(detection.score_size(), num_labels)
<< "Number of scores and labels should match for detection."; << "Number of scores and labels should match for detection.";
// Extracts all "label(_id),score" for the detection. // Extracts all "label(_id),score" for the detection.
@ -361,7 +362,7 @@ void DetectionsToRenderDataCalculator::AddDetectionToRenderData(
const Detection& detection, const Detection& detection,
const DetectionsToRenderDataCalculatorOptions& options, const DetectionsToRenderDataCalculatorOptions& options,
RenderData* render_data) { RenderData* render_data) {
CHECK(detection.location_data().format() == LocationData::BOUNDING_BOX || ABSL_CHECK(detection.location_data().format() == LocationData::BOUNDING_BOX ||
detection.location_data().format() == detection.location_data().format() ==
LocationData::RELATIVE_BOUNDING_BOX) LocationData::RELATIVE_BOUNDING_BOX)
<< "Only Detection with formats of BOUNDING_BOX or RELATIVE_BOUNDING_BOX " << "Only Detection with formats of BOUNDING_BOX or RELATIVE_BOUNDING_BOX "

View File

@ -19,6 +19,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "absl/log/absl_check.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "mediapipe/calculators/util/labels_to_render_data_calculator.pb.h" #include "mediapipe/calculators/util/labels_to_render_data_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
@ -114,7 +115,8 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
video_height_ = video_header.height; video_height_ = video_header.height;
return absl::OkStatus(); return absl::OkStatus();
} else { } else {
CHECK_EQ(options_.location(), LabelsToRenderDataCalculatorOptions::TOP_LEFT) ABSL_CHECK_EQ(options_.location(),
LabelsToRenderDataCalculatorOptions::TOP_LEFT)
<< "Only TOP_LEFT is supported without VIDEO_PRESTREAM."; << "Only TOP_LEFT is supported without VIDEO_PRESTREAM.";
} }
@ -144,7 +146,7 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
if (cc->Inputs().HasTag(kScoresTag)) { if (cc->Inputs().HasTag(kScoresTag)) {
std::vector<float> score_vector = std::vector<float> score_vector =
cc->Inputs().Tag(kScoresTag).Get<std::vector<float>>(); cc->Inputs().Tag(kScoresTag).Get<std::vector<float>>();
CHECK_EQ(label_vector.size(), score_vector.size()); ABSL_CHECK_EQ(label_vector.size(), score_vector.size());
scores.resize(label_vector.size()); scores.resize(label_vector.size());
for (int i = 0; i < label_vector.size(); ++i) { for (int i = 0; i < label_vector.size(); ++i) {
scores[i] = score_vector[i]; scores[i] = score_vector[i];

Some files were not shown because too many files have changed in this diff Show More