.bazelrc
@@ -3,7 +3,7 @@
 # Basic build settings
 build --jobs 128
 build --define='absl=1'
-build --cxxopt='-std=c++11'
+build --cxxopt='-std=c++14'
 build --copt='-Wno-sign-compare'
 build --copt='-Wno-unused-function'
 build --copt='-Wno-uninitialized'
README.md
@@ -37,10 +37,18 @@ A web-based visualizer is hosted on [viz.mediapipe.dev](https://viz.mediapipe.de
 * [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General community discussion around MediaPipe
 
 ## Publications
 * [On-Device, Real-Time Hand Tracking with MediaPipe](https://ai.googleblog.com/2019/08/on-device-real-time-hand-tracking-with.html)
 * [MediaPipe: A Framework for Building Perception Pipelines](https://arxiv.org/abs/1906.08172)
 
 ## Events
-[Open sourced at CVPR 2019](https://sites.google.com/corp/view/perception-cv4arvr/mediapipe) on June 17~20 in Long Beach, CA
+* [MediaPipe Madrid Meetup, 16 Dec 2019](https://www.meetup.com/Madrid-AI-Developers-Group/events/266329088/)
+* [MediaPipe London Meetup, Google 123 Building, 12 Dec 2019](https://www.meetup.com/London-AI-Tech-Talk/events/266329038)
+* [ML Conference, Berlin, 11 Dec 2019](https://mlconference.ai/machine-learning-advanced-development/mediapipe-building-real-time-cross-platform-mobile-web-edge-desktop-video-audio-ml-pipelines/)
+* [MediaPipe Berlin Meetup, Google Berlin, 11 Dec 2019](https://www.meetup.com/Berlin-AI-Tech-Talk/events/266328794/)
+* [The 3rd Workshop on YouTube-8M Large Scale Video Understanding Workshop](https://research.google.com/youtube8m/workshop2019/index.html) Seoul, Korea ICCV 2019
+* [AI DevWorld 2019](https://aidevworld.com) on Oct 10 in San Jose, California
+* [Google Industry Workshop at ICIP 2019](http://2019.ieeeicip.org/?action=page4&id=14#Google) [Presentation](https://docs.google.com/presentation/d/e/2PACX-1vRIBBbO_LO9v2YmvbHHEt1cwyqH6EjDxiILjuT0foXy1E7g6uyh4CesB2DkkEwlRDO9_lWfuKMZx98T/pub?start=false&loop=false&delayms=3000&slide=id.g556cc1a659_0_5) on Sept 24 in Taipei, Taiwan
+* [Open sourced at CVPR 2019](https://sites.google.com/corp/view/perception-cv4arvr/mediapipe) on June 17~20 in Long Beach, CA
 
 ## Alpha Disclaimer
 MediaPipe is currently in alpha for v0.6. We are still making breaking API changes and expect to get to stable API by v1.0.
WORKSPACE
@@ -103,9 +103,9 @@ http_archive(
     ],
 )
 
-# 2019-08-15
-_TENSORFLOW_GIT_COMMIT = "67def62936e28f97c16182dfcc467d8d1cae02b4"
-_TENSORFLOW_SHA256= "ddd4e3c056e7c0ff2ef29133b30fa62781dfbf8a903e99efb91a02d292fa9562"
+# 2019-11-12
+_TENSORFLOW_GIT_COMMIT = "a5f9bcd64453ff3d1f64cb4da4786db3d2da7f82"
+_TENSORFLOW_SHA256= "f2b6f2ab2ffe63e86eccd3ce4bea6b7197383d726638dfeeebcdc1e7de73f075"
 http_archive(
     name = "org_tensorflow",
     urls = [
@@ -114,13 +114,6 @@ http_archive(
     ],
     strip_prefix = "tensorflow-%s" % _TENSORFLOW_GIT_COMMIT,
     sha256 = _TENSORFLOW_SHA256,
-    patches = [
-        "@//third_party:tensorflow_065c20bf79253257c87bd4614bb9a7fdef015cbb.diff",
-        "@//third_party:tensorflow_f67fcbefce906cd419e4657f0d41e21019b71abd.diff",
-    ],
-    patch_args = [
-        "-p1",
-    ],
 )
 
 load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace")
@@ -254,18 +247,11 @@ android_sdk_repository(
 
 # iOS basic build deps.
 
-load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
-
-git_repository(
+http_archive(
     name = "build_bazel_rules_apple",
-    remote = "https://github.com/bazelbuild/rules_apple.git",
-    tag = "0.18.0",
-    patches = [
-        "@//third_party:rules_apple_c0863d0596ae6b769a29fa3fb72ff036444fd249.diff",
-    ],
-    patch_args = [
-        "-p1",
-    ],
+    sha256 = "bdc8e66e70b8a75da23b79f1f8c6207356df07d041d96d2189add7ee0780cf4e",
+    strip_prefix = "rules_apple-b869b0d3868d78a1d4ffd866ccb304fb68aa12c3",
+    url = "https://github.com/bazelbuild/rules_apple/archive/b869b0d3868d78a1d4ffd866ccb304fb68aa12c3.tar.gz",
 )
 
 load(
mediapipe/calculators/audio/spectrogram_calculator.cc
@@ -113,8 +113,15 @@ class SpectrogramCalculator : public CalculatorBase {
   ::mediapipe::Status Close(CalculatorContext* cc) override;
 
  private:
-  Timestamp CurrentOutputTimestamp() {
-    // Current output timestamp is the *center* of the next frame to be
+  Timestamp CurrentOutputTimestamp(CalculatorContext* cc) {
+    if (use_local_timestamp_) {
+      return cc->InputTimestamp();
+    }
+    return CumulativeOutputTimestamp();
+  }
+
+  Timestamp CumulativeOutputTimestamp() {
+    // Cumulative output timestamp is the *center* of the next frame to be
     // emitted, hence delayed by half a window duration compared to relevant
     // input timestamp.
     return initial_input_timestamp_ +
@@ -141,6 +148,7 @@ class SpectrogramCalculator : public CalculatorBase {
       const OutputMatrixType postprocess_output_fn(const OutputMatrixType&),
       CalculatorContext* cc);
 
+  bool use_local_timestamp_;
   double input_sample_rate_;
   bool pad_final_packet_;
   int frame_duration_samples_;
@@ -173,6 +181,8 @@ const float SpectrogramCalculator::kLnPowerToDb = 4.342944819032518;
   SpectrogramCalculatorOptions spectrogram_options =
       cc->Options<SpectrogramCalculatorOptions>();
 
+  use_local_timestamp_ = spectrogram_options.use_local_timestamp();
+
   if (spectrogram_options.frame_duration_seconds() <= 0.0) {
     ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
         << "Invalid or missing frame_duration_seconds.\n"
@@ -351,11 +361,11 @@ template <class OutputMatrixType>
         << "Inconsistent number of spectrogram channels.";
     if (allow_multichannel_input_) {
       cc->Outputs().Index(0).Add(spectrogram_matrices.release(),
-                                 CurrentOutputTimestamp());
+                                 CurrentOutputTimestamp(cc));
     } else {
       cc->Outputs().Index(0).Add(
           new OutputMatrixType(spectrogram_matrices->at(0)),
-          CurrentOutputTimestamp());
+          CurrentOutputTimestamp(cc));
     }
     cumulative_completed_frames_ += output_vectors.size();
   }
mediapipe/calculators/audio/spectrogram_calculator.proto
@@ -66,4 +66,11 @@ message SpectrogramCalculatorOptions {
   // uniformly regardless of output type (i.e., even dBs are multiplied, not
   // offset).
   optional double output_scale = 7 [default = 1.0];
+
+  // If use_local_timestamp is true, the output packet's timestamp is based on
+  // the last sample of the packet and is inferred from the latest input
+  // packet's timestamp. If false, the output packet's timestamp is based on
+  // the cumulative timestamping, which is inferred from the initial input
+  // timestamp and the cumulative number of samples.
+  optional bool use_local_timestamp = 8 [default = false];
 }
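For context, a minimal sketch of a graph node that enables the new option is shown below. The stream names and the frame duration value are illustrative assumptions, not part of this change; `SpectrogramCalculator`, `frame_duration_seconds`, and `use_local_timestamp` all appear in the diff above.

```cpp
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

namespace mediapipe {

// Sketch: a SpectrogramCalculator node with use_local_timestamp enabled, so
// each output packet reuses the timestamp of the input packet that produced
// it instead of the cumulative sample-count-based timestamp.
CalculatorGraphConfig::Node SpectrogramNodeWithLocalTimestamp() {
  return ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
    calculator: "SpectrogramCalculator"
    input_stream: "audio_samples"  # hypothetical stream name
    output_stream: "spectrogram"   # hypothetical stream name
    options {
      [mediapipe.SpectrogramCalculatorOptions.ext] {
        frame_duration_seconds: 0.025  # illustrative value
        use_local_timestamp: true
      }
    }
  )");
}

}  // namespace mediapipe
```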
mediapipe/calculators/core/BUILD
@@ -13,12 +13,12 @@
 # limitations under the License.
 #
 
+load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
+
 licenses(["notice"])  # Apache 2.0
 
 package(default_visibility = ["//visibility:private"])
 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
-
 proto_library(
     name = "concatenate_vector_calculator_proto",
     srcs = ["concatenate_vector_calculator.proto"],
@@ -26,6 +26,13 @@ proto_library(
     deps = ["//mediapipe/framework:calculator_proto"],
 )
 
+proto_library(
+    name = "dequantize_byte_array_calculator_proto",
+    srcs = ["dequantize_byte_array_calculator.proto"],
+    visibility = ["//visibility:public"],
+    deps = ["//mediapipe/framework:calculator_proto"],
+)
+
 proto_library(
     name = "packet_cloner_calculator_proto",
     srcs = ["packet_cloner_calculator.proto"],
@@ -72,6 +79,13 @@ proto_library(
     ],
 )
 
+proto_library(
+    name = "clip_vector_size_calculator_proto",
+    srcs = ["clip_vector_size_calculator.proto"],
+    visibility = ["//visibility:public"],
+    deps = ["//mediapipe/framework:calculator_proto"],
+)
+
 mediapipe_cc_proto_library(
     name = "packet_cloner_calculator_cc_proto",
     srcs = ["packet_cloner_calculator.proto"],
@@ -104,6 +118,22 @@ mediapipe_cc_proto_library(
     deps = [":concatenate_vector_calculator_proto"],
 )
 
+mediapipe_cc_proto_library(
+    name = "clip_vector_size_calculator_cc_proto",
+    srcs = ["clip_vector_size_calculator.proto"],
+    cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
+    visibility = ["//visibility:public"],
+    deps = [":clip_vector_size_calculator_proto"],
+)
+
+mediapipe_cc_proto_library(
+    name = "dequantize_byte_array_calculator_cc_proto",
+    srcs = ["dequantize_byte_array_calculator.proto"],
+    cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
+    visibility = ["//visibility:public"],
+    deps = [":dequantize_byte_array_calculator_proto"],
+)
+
 mediapipe_cc_proto_library(
     name = "quantize_float_vector_calculator_cc_proto",
     srcs = ["quantize_float_vector_calculator.proto"],
@@ -154,6 +184,66 @@ cc_test(
     ],
 )
 
+cc_library(
+    name = "begin_loop_calculator",
+    srcs = ["begin_loop_calculator.cc"],
+    hdrs = ["begin_loop_calculator.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//mediapipe/framework:calculator_context",
+        "//mediapipe/framework:calculator_contract",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:collection_item_id",
+        "//mediapipe/framework:packet",
+        "//mediapipe/framework/formats:landmark_cc_proto",
+        "//mediapipe/framework/formats:rect_cc_proto",
+        "//mediapipe/framework/port:integral_types",
+        "//mediapipe/framework/port:ret_check",
+        "//mediapipe/framework/port:status",
+        "@com_google_absl//absl/memory",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "end_loop_calculator",
+    srcs = ["end_loop_calculator.cc"],
+    hdrs = ["end_loop_calculator.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//mediapipe/framework:calculator_context",
+        "//mediapipe/framework:calculator_contract",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:collection_item_id",
+        "//mediapipe/framework:packet",
+        "//mediapipe/framework/formats:landmark_cc_proto",
+        "//mediapipe/framework/formats:rect_cc_proto",
+        "//mediapipe/framework/port:integral_types",
+        "//mediapipe/framework/port:ret_check",
+        "//mediapipe/framework/port:status",
+        "//mediapipe/util:render_data_cc_proto",
+    ],
+    alwayslink = 1,
+)
+
+cc_test(
+    name = "begin_end_loop_calculator_graph_test",
+    srcs = ["begin_end_loop_calculator_graph_test.cc"],
+    deps = [
+        ":begin_loop_calculator",
+        ":end_loop_calculator",
+        "//mediapipe/calculators/core:packet_cloner_calculator",
+        "//mediapipe/framework:calculator_context",
+        "//mediapipe/framework:calculator_contract",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/port:gtest_main",
+        "//mediapipe/framework/port:integral_types",
+        "//mediapipe/framework/port:parse_text_proto",
+        "//mediapipe/framework/port:status",
+        "@com_google_absl//absl/memory",
+    ],
+)
+
 cc_library(
     name = "concatenate_vector_calculator",
     srcs = ["concatenate_vector_calculator.cc"],
@@ -204,6 +294,50 @@ cc_test(
     ],
 )
 
+cc_library(
+    name = "clip_vector_size_calculator",
+    srcs = ["clip_vector_size_calculator.cc"],
+    hdrs = ["clip_vector_size_calculator.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":clip_vector_size_calculator_cc_proto",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/formats:rect_cc_proto",
+        "//mediapipe/framework/port:ret_check",
+        "//mediapipe/framework/port:status",
+        "@org_tensorflow//tensorflow/lite:framework",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "clip_detection_vector_size_calculator",
+    srcs = ["clip_detection_vector_size_calculator.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":clip_vector_size_calculator",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/formats:detection_cc_proto",
+    ],
+    alwayslink = 1,
+)
+
+cc_test(
+    name = "clip_vector_size_calculator_test",
+    srcs = ["clip_vector_size_calculator_test.cc"],
+    deps = [
+        ":clip_vector_size_calculator",
+        "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:calculator_runner",
+        "//mediapipe/framework:timestamp",
+        "//mediapipe/framework/port:gtest_main",
+        "//mediapipe/framework/port:parse_text_proto",
+        "//mediapipe/framework/port:status",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
 cc_library(
     name = "counting_source_calculator",
     srcs = ["counting_source_calculator.cc"],
@@ -285,7 +419,7 @@ cc_library(
         "//visibility:public",
     ],
     deps = [
-        "//mediapipe/calculators/core:packet_cloner_calculator_cc_proto",
+        ":packet_cloner_calculator_cc_proto",
         "//mediapipe/framework:calculator_framework",
         "@com_google_absl//absl/strings",
     ],
@@ -387,6 +521,32 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "string_to_int_calculator",
+    srcs = ["string_to_int_calculator.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/port:integral_types",
+        "//mediapipe/framework/port:status",
+        "@com_google_absl//absl/strings",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "side_packet_to_stream_calculator",
+    srcs = ["side_packet_to_stream_calculator.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/port:logging",
+        "//mediapipe/framework/port:ret_check",
+        "//mediapipe/framework/port:status",
+    ],
+    alwayslink = 1,
+)
+
 cc_test(
     name = "immediate_mux_calculator_test",
     srcs = ["immediate_mux_calculator_test.cc"],
@@ -558,6 +718,32 @@ cc_test(
     ],
 )
 
+cc_library(
+    name = "dequantize_byte_array_calculator",
+    srcs = ["dequantize_byte_array_calculator.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":dequantize_byte_array_calculator_cc_proto",
+        "//mediapipe/framework:calculator_context",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/port:status",
+    ],
+    alwayslink = 1,
+)
+
+cc_test(
+    name = "dequantize_byte_array_calculator_test",
+    srcs = ["dequantize_byte_array_calculator_test.cc"],
+    deps = [
+        ":dequantize_byte_array_calculator",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:calculator_runner",
+        "//mediapipe/framework/port:gtest_main",
+        "//mediapipe/framework/port:parse_text_proto",
+        "//mediapipe/framework/port:status",
+    ],
+)
+
 cc_library(
     name = "quantize_float_vector_calculator",
     srcs = ["quantize_float_vector_calculator.cc"],
@@ -694,3 +880,29 @@ cc_test(
         "//mediapipe/framework/port:status",
     ],
 )
+
+cc_library(
+    name = "stream_to_side_packet_calculator",
+    srcs = ["stream_to_side_packet_calculator.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:timestamp",
+        "//mediapipe/framework/port:status",
+    ],
+    alwayslink = 1,
+)
+
+cc_test(
+    name = "stream_to_side_packet_calculator_test",
+    srcs = ["stream_to_side_packet_calculator_test.cc"],
+    deps = [
+        ":stream_to_side_packet_calculator",
+        "//mediapipe/framework:calculator_runner",
+        "//mediapipe/framework:packet",
+        "//mediapipe/framework:timestamp",
+        "//mediapipe/framework/port:gtest_main",
+        "//mediapipe/framework/port:status",
+        "@com_google_absl//absl/memory",
+    ],
+)
mediapipe/calculators/core/begin_end_loop_calculator_graph_test.cc (new file)
@@ -0,0 +1,335 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "mediapipe/calculators/core/begin_loop_calculator.h"
#include "mediapipe/calculators/core/end_loop_calculator.h"
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"  // NOLINT

namespace mediapipe {
namespace {

typedef BeginLoopCalculator<std::vector<int>> BeginLoopIntegerCalculator;
REGISTER_CALCULATOR(BeginLoopIntegerCalculator);

class IncrementCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<int>();
    cc->Outputs().Index(0).Set<int>();
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) override {
    cc->SetOffset(TimestampDiff(0));
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    const int& input_int = cc->Inputs().Index(0).Get<int>();
    auto output_int = absl::make_unique<int>(input_int + 1);
    cc->Outputs().Index(0).Add(output_int.release(), cc->InputTimestamp());
    return ::mediapipe::OkStatus();
  }
};

REGISTER_CALCULATOR(IncrementCalculator);

typedef EndLoopCalculator<std::vector<int>> EndLoopIntegersCalculator;
REGISTER_CALCULATOR(EndLoopIntegersCalculator);

class BeginEndLoopCalculatorGraphTest : public ::testing::Test {
 protected:
  BeginEndLoopCalculatorGraphTest() {
    graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(
        R"(
        num_threads: 4
        input_stream: "ints"
        node {
          calculator: "BeginLoopIntegerCalculator"
          input_stream: "ITERABLE:ints"
          output_stream: "ITEM:int"
          output_stream: "BATCH_END:timestamp"
        }
        node {
          calculator: "IncrementCalculator"
          input_stream: "int"
          output_stream: "int_plus_one"
        }
        node {
          calculator: "EndLoopIntegersCalculator"
          input_stream: "ITEM:int_plus_one"
          input_stream: "BATCH_END:timestamp"
          output_stream: "ITERABLE:ints_plus_one"
        }
        )");
    tool::AddVectorSink("ints_plus_one", &graph_config_, &output_packets_);
  }

  CalculatorGraphConfig graph_config_;
  std::vector<Packet> output_packets_;
};

TEST_F(BeginEndLoopCalculatorGraphTest, SingleEmptyVector) {
  CalculatorGraph graph;
  MP_EXPECT_OK(graph.Initialize(graph_config_));
  MP_EXPECT_OK(graph.StartRun({}));
  auto input_vector = absl::make_unique<std::vector<int>>();
  Timestamp input_timestamp = Timestamp(0);
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "ints", Adopt(input_vector.release()).At(input_timestamp)));
  MP_ASSERT_OK(graph.WaitUntilIdle());

  // EndLoopCalc will forward the timestamp bound because there are no elements
  // in collection to output.
  ASSERT_EQ(0, output_packets_.size());

  MP_ASSERT_OK(graph.CloseAllPacketSources());
  MP_ASSERT_OK(graph.WaitUntilDone());
}

TEST_F(BeginEndLoopCalculatorGraphTest, SingleNonEmptyVector) {
  CalculatorGraph graph;
  MP_EXPECT_OK(graph.Initialize(graph_config_));
  MP_EXPECT_OK(graph.StartRun({}));
  auto input_vector = absl::make_unique<std::vector<int>>();
  input_vector->emplace_back(0);
  input_vector->emplace_back(1);
  input_vector->emplace_back(2);
  Timestamp input_timestamp = Timestamp(0);
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "ints", Adopt(input_vector.release()).At(input_timestamp)));
  MP_ASSERT_OK(graph.WaitUntilIdle());

  ASSERT_EQ(1, output_packets_.size());
  EXPECT_EQ(input_timestamp, output_packets_[0].Timestamp());
  std::vector<int> expected_output_vector = {1, 2, 3};
  EXPECT_EQ(expected_output_vector, output_packets_[0].Get<std::vector<int>>());

  MP_ASSERT_OK(graph.CloseAllPacketSources());
  MP_ASSERT_OK(graph.WaitUntilDone());
}

TEST_F(BeginEndLoopCalculatorGraphTest, MultipleVectors) {
  CalculatorGraph graph;
  MP_EXPECT_OK(graph.Initialize(graph_config_));
  MP_EXPECT_OK(graph.StartRun({}));

  auto input_vector0 = absl::make_unique<std::vector<int>>();
  input_vector0->emplace_back(0);
  input_vector0->emplace_back(1);
  Timestamp input_timestamp0 = Timestamp(0);
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "ints", Adopt(input_vector0.release()).At(input_timestamp0)));

  auto input_vector1 = absl::make_unique<std::vector<int>>();
  Timestamp input_timestamp1 = Timestamp(1);
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "ints", Adopt(input_vector1.release()).At(input_timestamp1)));

  auto input_vector2 = absl::make_unique<std::vector<int>>();
  input_vector2->emplace_back(2);
  input_vector2->emplace_back(3);
  Timestamp input_timestamp2 = Timestamp(2);
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "ints", Adopt(input_vector2.release()).At(input_timestamp2)));

  MP_ASSERT_OK(graph.CloseAllPacketSources());
  MP_ASSERT_OK(graph.WaitUntilDone());

  ASSERT_EQ(2, output_packets_.size());

  EXPECT_EQ(input_timestamp0, output_packets_[0].Timestamp());
  std::vector<int> expected_output_vector0 = {1, 2};
  EXPECT_EQ(expected_output_vector0,
            output_packets_[0].Get<std::vector<int>>());

  // At input_timestamp1, EndLoopCalc will forward timestamp bound as there are
  // no elements in vector to process.

  EXPECT_EQ(input_timestamp2, output_packets_[1].Timestamp());
  std::vector<int> expected_output_vector2 = {3, 4};
  EXPECT_EQ(expected_output_vector2,
            output_packets_[1].Get<std::vector<int>>());
}

class MultiplierCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<int>();
    cc->Inputs().Index(1).Set<int>();
    cc->Outputs().Index(0).Set<int>();
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) override {
    cc->SetOffset(TimestampDiff(0));
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    const int& input_int = cc->Inputs().Index(0).Get<int>();
    const int& multiplier_int = cc->Inputs().Index(1).Get<int>();
    auto output_int = absl::make_unique<int>(input_int * multiplier_int);
    cc->Outputs().Index(0).Add(output_int.release(), cc->InputTimestamp());
    return ::mediapipe::OkStatus();
  }
};

REGISTER_CALCULATOR(MultiplierCalculator);

class BeginEndLoopCalculatorGraphWithClonedInputsTest : public ::testing::Test {
 protected:
  BeginEndLoopCalculatorGraphWithClonedInputsTest() {
    graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(
        R"(
        num_threads: 4
        input_stream: "ints"
        input_stream: "multiplier"
        node {
          calculator: "BeginLoopIntegerCalculator"
          input_stream: "ITERABLE:ints"
          input_stream: "CLONE:multiplier"
          output_stream: "ITEM:int_at_loop"
          output_stream: "CLONE:multiplier_cloned_at_loop"
          output_stream: "BATCH_END:timestamp"
        }
        node {
          calculator: "MultiplierCalculator"
          input_stream: "int_at_loop"
          input_stream: "multiplier_cloned_at_loop"
          output_stream: "multiplied_int_at_loop"
        }
        node {
          calculator: "EndLoopIntegersCalculator"
          input_stream: "ITEM:multiplied_int_at_loop"
          input_stream: "BATCH_END:timestamp"
          output_stream: "ITERABLE:multiplied_ints"
        }
        )");
    tool::AddVectorSink("multiplied_ints", &graph_config_, &output_packets_);
  }

  CalculatorGraphConfig graph_config_;
  std::vector<Packet> output_packets_;
};

TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, SingleEmptyVector) {
  CalculatorGraph graph;
  MP_EXPECT_OK(graph.Initialize(graph_config_));
  MP_EXPECT_OK(graph.StartRun({}));
  auto input_vector = absl::make_unique<std::vector<int>>();
  Timestamp input_timestamp = Timestamp(42);
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "ints", Adopt(input_vector.release()).At(input_timestamp)));
  auto multiplier = absl::make_unique<int>(2);
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "multiplier", Adopt(multiplier.release()).At(input_timestamp)));
  MP_ASSERT_OK(graph.WaitUntilIdle());

  // EndLoopCalc will forward the timestamp bound because there are no elements
  // in collection to output.
  ASSERT_EQ(0, output_packets_.size());

  MP_ASSERT_OK(graph.CloseAllPacketSources());
  MP_ASSERT_OK(graph.WaitUntilDone());
}

TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, SingleNonEmptyVector) {
  CalculatorGraph graph;
  MP_EXPECT_OK(graph.Initialize(graph_config_));
  MP_EXPECT_OK(graph.StartRun({}));
  auto input_vector = absl::make_unique<std::vector<int>>();
  input_vector->emplace_back(0);
  input_vector->emplace_back(1);
  input_vector->emplace_back(2);
  Timestamp input_timestamp = Timestamp(42);
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "ints", Adopt(input_vector.release()).At(input_timestamp)));
  auto multiplier = absl::make_unique<int>(2);
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "multiplier", Adopt(multiplier.release()).At(input_timestamp)));
  MP_ASSERT_OK(graph.WaitUntilIdle());

  ASSERT_EQ(1, output_packets_.size());
  EXPECT_EQ(input_timestamp, output_packets_[0].Timestamp());
  std::vector<int> expected_output_vector = {0, 2, 4};
  EXPECT_EQ(expected_output_vector, output_packets_[0].Get<std::vector<int>>());

  MP_ASSERT_OK(graph.CloseAllPacketSources());
  MP_ASSERT_OK(graph.WaitUntilDone());
}

TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, MultipleVectors) {
  CalculatorGraph graph;
  MP_EXPECT_OK(graph.Initialize(graph_config_));
  MP_EXPECT_OK(graph.StartRun({}));

  auto input_vector0 = absl::make_unique<std::vector<int>>();
  input_vector0->emplace_back(0);
  input_vector0->emplace_back(1);
  Timestamp input_timestamp0 = Timestamp(42);
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "ints", Adopt(input_vector0.release()).At(input_timestamp0)));
  auto multiplier0 = absl::make_unique<int>(2);
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "multiplier", Adopt(multiplier0.release()).At(input_timestamp0)));

  auto input_vector1 = absl::make_unique<std::vector<int>>();
  Timestamp input_timestamp1 = Timestamp(43);
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "ints", Adopt(input_vector1.release()).At(input_timestamp1)));
  auto multiplier1 = absl::make_unique<int>(2);
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "multiplier", Adopt(multiplier1.release()).At(input_timestamp1)));

  auto input_vector2 = absl::make_unique<std::vector<int>>();
  input_vector2->emplace_back(2);
  input_vector2->emplace_back(3);
  Timestamp input_timestamp2 = Timestamp(44);
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "ints", Adopt(input_vector2.release()).At(input_timestamp2)));
  auto multiplier2 = absl::make_unique<int>(3);
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "multiplier", Adopt(multiplier2.release()).At(input_timestamp2)));

  MP_ASSERT_OK(graph.CloseAllPacketSources());
  MP_ASSERT_OK(graph.WaitUntilDone());

  ASSERT_EQ(2, output_packets_.size());

  EXPECT_EQ(input_timestamp0, output_packets_[0].Timestamp());
  std::vector<int> expected_output_vector0 = {0, 2};
  EXPECT_EQ(expected_output_vector0,
            output_packets_[0].Get<std::vector<int>>());

  // At input_timestamp1, EndLoopCalc will forward timestamp bound as there are
  // no elements in vector to process.

  EXPECT_EQ(input_timestamp2, output_packets_[1].Timestamp());
  std::vector<int> expected_output_vector2 = {6, 9};
  EXPECT_EQ(expected_output_vector2,
            output_packets_[1].Get<std::vector<int>>());
}

}  // namespace
}  // namespace mediapipe
mediapipe/calculators/core/begin_loop_calculator.cc (new file)
@@ -0,0 +1,40 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/calculators/core/begin_loop_calculator.h"

#include <vector>

#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"

namespace mediapipe {

// A calculator to process std::vector<NormalizedLandmark>.
typedef BeginLoopCalculator<std::vector<::mediapipe::NormalizedLandmark>>
    BeginLoopNormalizedLandmarkCalculator;
REGISTER_CALCULATOR(BeginLoopNormalizedLandmarkCalculator);

// A calculator to process std::vector<std::vector<NormalizedLandmark>>.
typedef BeginLoopCalculator<
    std::vector<std::vector<::mediapipe::NormalizedLandmark>>>
    BeginLoopNormalizedLandmarksVectorCalculator;
REGISTER_CALCULATOR(BeginLoopNormalizedLandmarksVectorCalculator);

// A calculator to process std::vector<NormalizedRect>.
typedef BeginLoopCalculator<std::vector<::mediapipe::NormalizedRect>>
    BeginLoopNormalizedRectCalculator;
REGISTER_CALCULATOR(BeginLoopNormalizedRectCalculator);

}  // namespace mediapipe
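The registration pattern above extends to any element type the loop body should iterate over. As a hedged illustration (this instantiation is not part of the change), a loop over `std::vector<Detection>` would look like:

```cpp
#include <vector>

#include "mediapipe/calculators/core/begin_loop_calculator.h"
#include "mediapipe/framework/formats/detection.pb.h"

namespace mediapipe {

// Hypothetical instantiation: iterate over a vector of Detection protos,
// following the same typedef + REGISTER_CALCULATOR pattern used above.
typedef BeginLoopCalculator<std::vector<::mediapipe::Detection>>
    BeginLoopDetectionCalculator;
REGISTER_CALCULATOR(BeginLoopDetectionCalculator);

}  // namespace mediapipe
```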
mediapipe/calculators/core/begin_loop_calculator.h (new file)
@@ -0,0 +1,157 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_

#include "absl/memory/memory.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {

// Calculator for implementing loops on iterable collections inside a MediaPipe
// graph.
//
// It is designed to be used like:
//
// node {
//   calculator: "BeginLoopWithIterableCalculator"
//   input_stream: "ITERABLE:input_iterable"    # IterableT @ext_ts
//   output_stream: "ITEM:input_element"        # ItemT @loop_internal_ts
//   output_stream: "BATCH_END:ext_ts"          # Timestamp @loop_internal_ts
// }
//
// node {
//   calculator: "ElementToBlaConverterSubgraph"
//   input_stream: "ITEM:input_to_loop_body"    # ItemT @loop_internal_ts
//   output_stream: "BLA:output_of_loop_body"   # ItemU @loop_internal_ts
// }
//
// node {
//   calculator: "EndLoopWithOutputCalculator"
//   input_stream: "ITEM:output_of_loop_body"   # ItemU @loop_internal_ts
//   input_stream: "BATCH_END:ext_ts"           # Timestamp @loop_internal_ts
//   output_stream: "OUTPUT:aggregated_result"  # IterableU @ext_ts
// }
//
// BeginLoopCalculator accepts an optional input stream tagged with "TICK"
// which, if non-empty, wakes up the calculator and calls
// BeginLoopCalculator::Process(). Input streams tagged with "CLONE" are cloned
// to the corresponding output streams at loop timestamps. This ensures that a
// MediaPipe graph or sub-graph can run multiple times, once per element in the
// "ITERABLE", with a packet clone of the packets in the "CLONE" input streams
// for each element.
template <typename IterableT>
class BeginLoopCalculator : public CalculatorBase {
  using ItemT = typename IterableT::value_type;

 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    // A non-empty packet in the optional "TICK" input stream wakes up the
    // calculator.
    if (cc->Inputs().HasTag("TICK")) {
      cc->Inputs().Tag("TICK").SetAny();
    }

    // An iterable collection in the input stream.
    RET_CHECK(cc->Inputs().HasTag("ITERABLE"));
    cc->Inputs().Tag("ITERABLE").Set<IterableT>();

    // An element from the collection.
    RET_CHECK(cc->Outputs().HasTag("ITEM"));
    cc->Outputs().Tag("ITEM").Set<ItemT>();

    RET_CHECK(cc->Outputs().HasTag("BATCH_END"));
    cc->Outputs()
        .Tag("BATCH_END")
        .Set<Timestamp>(
            // A flush signal to the corresponding EndLoopCalculator for it to
            // emit the aggregated result with the timestamp contained in this
            // flush signal packet.
        );

    // Input streams tagged with "CLONE" are cloned to the corresponding
    // "CLONE" output streams at loop timestamps.
    RET_CHECK(cc->Inputs().NumEntries("CLONE") ==
              cc->Outputs().NumEntries("CLONE"));
    if (cc->Inputs().NumEntries("CLONE") > 0) {
      for (int i = 0; i < cc->Inputs().NumEntries("CLONE"); ++i) {
        cc->Inputs().Get("CLONE", i).SetAny();
        cc->Outputs().Get("CLONE", i).SetSameAs(&cc->Inputs().Get("CLONE", i));
      }
    }

    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) final {
    Timestamp last_timestamp = loop_internal_timestamp_;
    if (!cc->Inputs().Tag("ITERABLE").IsEmpty()) {
      const IterableT& collection =
          cc->Inputs().Tag("ITERABLE").template Get<IterableT>();
      for (const auto& item : collection) {
        cc->Outputs().Tag("ITEM").AddPacket(
            MakePacket<ItemT>(item).At(loop_internal_timestamp_));
        ForwardClonePackets(cc, loop_internal_timestamp_);
        ++loop_internal_timestamp_;
      }
    }

    // The collection was empty and nothing was processed.
    if (last_timestamp == loop_internal_timestamp_) {
      // Increment loop_internal_timestamp_ because it is used up now.
      ++loop_internal_timestamp_;
      for (auto it = cc->Outputs().begin(); it < cc->Outputs().end(); ++it) {
        it->SetNextTimestampBound(loop_internal_timestamp_);
      }
    }

    // The for loop processing the input collection already incremented
    // loop_internal_timestamp_. To emit the BATCH_END packet alongside the
    // last non-BATCH_END packet, decrement by one.
    cc->Outputs()
        .Tag("BATCH_END")
        .AddPacket(MakePacket<Timestamp>(cc->InputTimestamp())
                       .At(Timestamp(loop_internal_timestamp_ - 1)));

    return ::mediapipe::OkStatus();
  }

 private:
  void ForwardClonePackets(CalculatorContext* cc, Timestamp output_timestamp) {
    if (cc->Inputs().NumEntries("CLONE") > 0) {
      for (int i = 0; i < cc->Inputs().NumEntries("CLONE"); ++i) {
        if (!cc->Inputs().Get("CLONE", i).IsEmpty()) {
          auto input_packet = cc->Inputs().Get("CLONE", i).Value();
          cc->Outputs()
              .Get("CLONE", i)
              .AddPacket(std::move(input_packet).At(output_timestamp));
        }
      }
    }
  }

  // Fake timestamps generated per element in collection.
  Timestamp loop_internal_timestamp_ = Timestamp(0);
};

}  // namespace mediapipe

#endif  // MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_
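To make the timestamp handling in `Process()` concrete, here is an illustrative packet timeline sketched from the code above (assuming a freshly constructed calculator, so `loop_internal_timestamp_` starts at `Timestamp(0)`):

```cpp
// ITERABLE {a, b, c} arrives at external timestamp ext_ts:
//   ITEM:a emitted @ Timestamp(0)
//   ITEM:b emitted @ Timestamp(1)
//   ITEM:c emitted @ Timestamp(2)
//   BATCH_END packet carrying ext_ts emitted @ Timestamp(2)
//
// An empty ITERABLE then arrives at ext_ts2:
//   no ITEM packets; the next-timestamp bound on the outputs advances, and a
//   BATCH_END packet carrying ext_ts2 is still emitted, which lets the
//   EndLoopCalculator advance its ITERABLE output's timestamp bound for
//   ext_ts2 (no output packet, as the accompanying tests verify).
```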
mediapipe/calculators/core/clip_detection_vector_size_calculator.cc (new file)
@@ -0,0 +1,26 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <vector>

#include "mediapipe/calculators/core/clip_vector_size_calculator.h"
#include "mediapipe/framework/formats/detection.pb.h"

namespace mediapipe {

typedef ClipVectorSizeCalculator<::mediapipe::Detection>
    ClipDetectionVectorSizeCalculator;
REGISTER_CALCULATOR(ClipDetectionVectorSizeCalculator);

}  // namespace mediapipe
mediapipe/calculators/core/clip_vector_size_calculator.cc (new file)
@@ -0,0 +1,28 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/calculators/core/clip_vector_size_calculator.h"

#include <vector>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/rect.pb.h"

namespace mediapipe {

typedef ClipVectorSizeCalculator<::mediapipe::NormalizedRect>
    ClipNormalizedRectVectorSizeCalculator;
REGISTER_CALCULATOR(ClipNormalizedRectVectorSizeCalculator);

}  // namespace mediapipe
mediapipe/calculators/core/clip_vector_size_calculator.h (new file)
@@ -0,0 +1,137 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MEDIAPIPE_CALCULATORS_CORE_CLIP_VECTOR_SIZE_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_CLIP_VECTOR_SIZE_CALCULATOR_H_

#include <type_traits>
#include <vector>

#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {

// Clips the size of the input vector of type T to a specified max_vec_size.
// In a graph it will be used as:
// node {
//   calculator: "ClipIntVectorSizeCalculator"
//   input_stream: "input_vector"
//   output_stream: "output_vector"
//   options {
//     [mediapipe.ClipVectorSizeCalculatorOptions.ext] {
//       max_vec_size: 5
//     }
//   }
// }
template <typename T>
class ClipVectorSizeCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    RET_CHECK(cc->Inputs().NumEntries() == 1);
    RET_CHECK(cc->Outputs().NumEntries() == 1);

    if (cc->Options<::mediapipe::ClipVectorSizeCalculatorOptions>()
            .max_vec_size() < 1) {
      return ::mediapipe::InternalError(
          "max_vec_size should be greater than or equal to 1.");
    }

    cc->Inputs().Index(0).Set<std::vector<T>>();
    cc->Outputs().Index(0).Set<std::vector<T>>();

    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) override {
    cc->SetOffset(TimestampDiff(0));
    max_vec_size_ = cc->Options<::mediapipe::ClipVectorSizeCalculatorOptions>()
                        .max_vec_size();
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    if (max_vec_size_ < 1) {
      return ::mediapipe::InternalError(
          "max_vec_size should be greater than or equal to 1.");
    }
    if (cc->Inputs().Index(0).IsEmpty()) {
      return ::mediapipe::OkStatus();
    }

    return ClipVectorSize<T>(std::is_copy_constructible<T>(), cc);
  }

  template <typename U>
  ::mediapipe::Status ClipVectorSize(std::true_type, CalculatorContext* cc) {
    auto output = absl::make_unique<std::vector<U>>();
    const std::vector<U>& input_vector =
        cc->Inputs().Index(0).Get<std::vector<U>>();
    if (max_vec_size_ >= input_vector.size()) {
      output->insert(output->end(), input_vector.begin(), input_vector.end());
    } else {
      for (int i = 0; i < max_vec_size_; ++i) {
        output->push_back(input_vector[i]);
      }
    }
    cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
    return ::mediapipe::OkStatus();
  }

  template <typename U>
  ::mediapipe::Status ClipVectorSize(std::false_type, CalculatorContext* cc) {
    return ConsumeAndClipVectorSize<T>(std::is_move_constructible<U>(), cc);
  }

  template <typename U>
  ::mediapipe::Status ConsumeAndClipVectorSize(std::true_type,
                                               CalculatorContext* cc) {
    auto output = absl::make_unique<std::vector<U>>();
    ::mediapipe::StatusOr<std::unique_ptr<std::vector<U>>> input_status =
        cc->Inputs().Index(0).Value().Consume<std::vector<U>>();

    if (input_status.ok()) {
      std::unique_ptr<std::vector<U>> input_vector =
          std::move(input_status).ValueOrDie();
      auto begin_it = input_vector->begin();
      auto end_it = input_vector->end();
      if (max_vec_size_ < input_vector->size()) {
        end_it = input_vector->begin() + max_vec_size_;
      }
      output->insert(output->end(), std::make_move_iterator(begin_it),
                     std::make_move_iterator(end_it));
    } else {
      return input_status.status();
    }
    cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
    return ::mediapipe::OkStatus();
  }

  template <typename U>
  ::mediapipe::Status ConsumeAndClipVectorSize(std::false_type,
                                               CalculatorContext* cc) {
    return ::mediapipe::InternalError(
        "Cannot copy or move input vectors and clip their size.");
  }

 private:
  int max_vec_size_ = 0;
};

}  // namespace mediapipe

#endif  // MEDIAPIPE_CALCULATORS_CORE_CLIP_VECTOR_SIZE_CALCULATOR_H_
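Because `ClipVectorSizeCalculator` is a template, each concrete element type needs its own registration, as `clip_vector_size_calculator.cc` above does for `NormalizedRect`. A minimal sketch for a plain-int vector, matching the `ClipIntVectorSizeCalculator` name used in the header comment (the unit test below registers an equivalent calculator under a test-specific name):

```cpp
#include <vector>

#include "mediapipe/calculators/core/clip_vector_size_calculator.h"

namespace mediapipe {

// Sketch: register the template for int elements. This typedef is assumed
// for illustration; it does not appear in the change itself.
typedef ClipVectorSizeCalculator<int> ClipIntVectorSizeCalculator;
REGISTER_CALCULATOR(ClipIntVectorSizeCalculator);

}  // namespace mediapipe
```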
mediapipe/calculators/core/clip_vector_size_calculator.proto (new file)
@@ -0,0 +1,28 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";

message ClipVectorSizeCalculatorOptions {
  extend CalculatorOptions {
    optional ClipVectorSizeCalculatorOptions ext = 274674998;
  }

  // Maximum size of output vector.
  optional int32 max_vec_size = 1 [default = 1];
}
mediapipe/calculators/core/clip_vector_size_calculator_test.cc (new file)
@@ -0,0 +1,179 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/calculators/core/clip_vector_size_calculator.h"

#include <memory>
#include <string>
#include <vector>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"  // NOLINT

namespace mediapipe {

typedef ClipVectorSizeCalculator<int> TestClipIntVectorSizeCalculator;
REGISTER_CALCULATOR(TestClipIntVectorSizeCalculator);

void AddInputVector(const std::vector<int>& input, int64 timestamp,
                    CalculatorRunner* runner) {
  runner->MutableInputs()->Index(0).packets.push_back(
      MakePacket<std::vector<int>>(input).At(Timestamp(timestamp)));
}

TEST(TestClipIntVectorSizeCalculatorTest, EmptyVectorInput) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "TestClipIntVectorSizeCalculator"
        input_stream: "input_vector"
        output_stream: "output_vector"
        options {
          [mediapipe.ClipVectorSizeCalculatorOptions.ext] { max_vec_size: 1 }
        }
      )");
  CalculatorRunner runner(node_config);

  std::vector<int> input = {};
  AddInputVector(input, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  EXPECT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  EXPECT_TRUE(outputs[0].Get<std::vector<int>>().empty());
}

TEST(TestClipIntVectorSizeCalculatorTest, OneTimestamp) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "TestClipIntVectorSizeCalculator"
        input_stream: "input_vector"
        output_stream: "output_vector"
        options {
          [mediapipe.ClipVectorSizeCalculatorOptions.ext] { max_vec_size: 2 }
        }
      )");
  CalculatorRunner runner(node_config);

  std::vector<int> input = {0, 1, 2, 3};
  AddInputVector(input, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  EXPECT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  const std::vector<int>& output = outputs[0].Get<std::vector<int>>();
  EXPECT_EQ(2, output.size());
  std::vector<int> expected_vector = {0, 1};
  EXPECT_EQ(expected_vector, output);
}

TEST(TestClipIntVectorSizeCalculatorTest, TwoInputsAtTwoTimestamps) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "TestClipIntVectorSizeCalculator"
        input_stream: "input_vector"
        output_stream: "output_vector"
        options {
          [mediapipe.ClipVectorSizeCalculatorOptions.ext] { max_vec_size: 3 }
        }
      )");
  CalculatorRunner runner(node_config);

  {
    std::vector<int> input = {0, 1, 2, 3};
    AddInputVector(input, /*timestamp=*/1, &runner);
  }
  {
    std::vector<int> input = {2, 3, 4, 5};
    AddInputVector(input, /*timestamp=*/2, &runner);
  }
  MP_ASSERT_OK(runner.Run());

  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
  EXPECT_EQ(2, outputs.size());
  {
    EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
    const std::vector<int>& output = outputs[0].Get<std::vector<int>>();
    EXPECT_EQ(3, output.size());
    std::vector<int> expected_vector = {0, 1, 2};
    EXPECT_EQ(expected_vector, output);
  }
  {
    EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
    const std::vector<int>& output = outputs[1].Get<std::vector<int>>();
    EXPECT_EQ(3, output.size());
    std::vector<int> expected_vector = {2, 3, 4};
    EXPECT_EQ(expected_vector, output);
  }
}

typedef ClipVectorSizeCalculator<std::unique_ptr<int>>
    TestClipUniqueIntPtrVectorSizeCalculator;
REGISTER_CALCULATOR(TestClipUniqueIntPtrVectorSizeCalculator);

TEST(TestClipUniqueIntPtrVectorSizeCalculatorTest, ConsumeOneTimestamp) {
  /* Note: We don't use CalculatorRunner for this test because it keeps copies
   * of input packets, so packets sent to the graph don't have sole ownership.
   * The test needs to send packets that own the data.
   */
  CalculatorGraphConfig graph_config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
        input_stream: "input_vector"
        node {
          calculator: "TestClipUniqueIntPtrVectorSizeCalculator"
          input_stream: "input_vector"
          output_stream: "output_vector"
          options {
            [mediapipe.ClipVectorSizeCalculatorOptions.ext] { max_vec_size: 3 }
          }
        }
      )");

  std::vector<Packet> outputs;
  tool::AddVectorSink("output_vector", &graph_config, &outputs);

  CalculatorGraph graph;
  MP_EXPECT_OK(graph.Initialize(graph_config));
  MP_EXPECT_OK(graph.StartRun({}));

  // input1 : {0, 1, 2, 3, 4, 5}
  auto input_vector = absl::make_unique<std::vector<std::unique_ptr<int>>>(6);
  for (int i = 0; i < 6; ++i) {
    input_vector->at(i) = absl::make_unique<int>(i);
  }

  MP_EXPECT_OK(graph.AddPacketToInputStream(
      "input_vector", Adopt(input_vector.release()).At(Timestamp(1))));

  MP_EXPECT_OK(graph.WaitUntilIdle());
  MP_EXPECT_OK(graph.CloseAllPacketSources());
  MP_EXPECT_OK(graph.WaitUntilDone());

  EXPECT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  const std::vector<std::unique_ptr<int>>& result =
      outputs[0].Get<std::vector<std::unique_ptr<int>>>();
  EXPECT_EQ(3, result.size());
  for (int i = 0; i < 3; ++i) {
    const std::unique_ptr<int>& v = result[i];
    EXPECT_EQ(i, *v);
  }
}

}  // namespace mediapipe
@ -0,0 +1,90 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cfloat>

#include "mediapipe/calculators/core/dequantize_byte_array_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/status.h"

// Dequantizes a byte array to a vector of floats.
//
// Example config:
// node {
//   calculator: "DequantizeByteArrayCalculator"
//   input_stream: "ENCODED:encoded"
//   output_stream: "FLOAT_VECTOR:float_vector"
//   options {
//     [mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
//       max_quantized_value: 2
//       min_quantized_value: -2
//     }
//   }
// }
namespace mediapipe {

class DequantizeByteArrayCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Tag("ENCODED").Set<std::string>();
    cc->Outputs().Tag("FLOAT_VECTOR").Set<std::vector<float>>();
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) final {
    const auto options =
        cc->Options<::mediapipe::DequantizeByteArrayCalculatorOptions>();
    if (!options.has_max_quantized_value() ||
        !options.has_min_quantized_value()) {
      return ::mediapipe::InvalidArgumentError(
          "Both max_quantized_value and min_quantized_value must be provided "
          "in DequantizeByteArrayCalculatorOptions.");
    }
    float max_quantized_value = options.max_quantized_value();
    float min_quantized_value = options.min_quantized_value();
    if (max_quantized_value < min_quantized_value + FLT_EPSILON) {
      return ::mediapipe::InvalidArgumentError(
          "max_quantized_value must be greater than min_quantized_value.");
    }
    float range = max_quantized_value - min_quantized_value;
    scalar_ = range / 255.0;
    bias_ = (range / 512.0) + min_quantized_value;
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) final {
    const std::string& encoded =
        cc->Inputs().Tag("ENCODED").Value().Get<std::string>();
    std::vector<float> float_vector;
    float_vector.reserve(encoded.length());
    for (int i = 0; i < encoded.length(); ++i) {
      float_vector.push_back(
          static_cast<unsigned char>(encoded.at(i)) * scalar_ + bias_);
    }
    cc->Outputs()
        .Tag("FLOAT_VECTOR")
        .AddPacket(MakePacket<std::vector<float>>(float_vector)
                       .At(cc->InputTimestamp()));
    return ::mediapipe::OkStatus();
  }

 private:
  float scalar_;
  float bias_;
};

REGISTER_CALCULATOR(DequantizeByteArrayCalculator);

}  // namespace mediapipe
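
To make the scalar and bias arithmetic concrete: with the example options above (max_quantized_value: 2, min_quantized_value: -2), range is 4.0, so scalar_ is about 0.0157 and bias_ about -1.992; the range/512 term is a half-step offset that centers each byte on its quantization bin. Below is a self-contained sketch of the same mapping, not part of the diff; the byte values are the ones exercised by the test further below.

#include <cstdio>

// Sketch of the dequantization mapping for max = 2, min = -2.
int main() {
  const float range = 2.0f - (-2.0f);            // 4.0
  const float scalar = range / 255.0f;           // ~0.0157 per byte step
  const float bias = range / 512.0f + (-2.0f);   // ~-1.992 (half-step centering)
  const unsigned char bytes[] = {0x7F, 0xFF, 0x00, 0x01};
  for (unsigned char b : bytes) {
    std::printf("0x%02X -> %.3f\n", static_cast<unsigned>(b),
                b * scalar + bias);
  }
  // Prints approximately 0.000, 2.008, -1.992, -1.977, which is what the
  // EXPECT_NEAR checks (tolerance 0.01) in the test verify.
  return 0;
}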

@ -0,0 +1,28 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";

message DequantizeByteArrayCalculatorOptions {
  extend CalculatorOptions {
    optional DequantizeByteArrayCalculatorOptions ext = 272316343;
  }

  optional float max_quantized_value = 1;
  optional float min_quantized_value = 2;
}

@ -0,0 +1,137 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include <vector>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"  // NOLINT

namespace mediapipe {

TEST(DequantizeByteArrayCalculatorTest, WrongConfig) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "DequantizeByteArrayCalculator"
        input_stream: "ENCODED:encoded"
        output_stream: "FLOAT_VECTOR:float_vector"
        options {
          [mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
            max_quantized_value: 2
          }
        }
      )");
  CalculatorRunner runner(node_config);
  std::string empty_string;
  runner.MutableInputs()->Tag("ENCODED").packets.push_back(
      MakePacket<std::string>(empty_string).At(Timestamp(0)));
  auto status = runner.Run();
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(
      status.message(),
      testing::HasSubstr(
          "Both max_quantized_value and min_quantized_value must be provided"));
}

TEST(DequantizeByteArrayCalculatorTest, WrongConfig2) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "DequantizeByteArrayCalculator"
        input_stream: "ENCODED:encoded"
        output_stream: "FLOAT_VECTOR:float_vector"
        options {
          [mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
            max_quantized_value: -2
            min_quantized_value: 2
          }
        }
      )");
  CalculatorRunner runner(node_config);
  std::string empty_string;
  runner.MutableInputs()->Tag("ENCODED").packets.push_back(
      MakePacket<std::string>(empty_string).At(Timestamp(0)));
  auto status = runner.Run();
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(
      status.message(),
      testing::HasSubstr(
          "max_quantized_value must be greater than min_quantized_value"));
}

TEST(DequantizeByteArrayCalculatorTest, WrongConfig3) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "DequantizeByteArrayCalculator"
        input_stream: "ENCODED:encoded"
        output_stream: "FLOAT_VECTOR:float_vector"
        options {
          [mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
            max_quantized_value: 1
            min_quantized_value: 1
          }
        }
      )");
  CalculatorRunner runner(node_config);
  std::string empty_string;
  runner.MutableInputs()->Tag("ENCODED").packets.push_back(
      MakePacket<std::string>(empty_string).At(Timestamp(0)));
  auto status = runner.Run();
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(
      status.message(),
      testing::HasSubstr(
          "max_quantized_value must be greater than min_quantized_value"));
}

TEST(DequantizeByteArrayCalculatorTest, TestDequantization) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "DequantizeByteArrayCalculator"
        input_stream: "ENCODED:encoded"
        output_stream: "FLOAT_VECTOR:float_vector"
        options {
          [mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
            max_quantized_value: 2
            min_quantized_value: -2
          }
        }
      )");
  CalculatorRunner runner(node_config);
  unsigned char input[4] = {0x7F, 0xFF, 0x00, 0x01};
  runner.MutableInputs()->Tag("ENCODED").packets.push_back(
      MakePacket<std::string>(
          std::string(reinterpret_cast<char const*>(input), 4))
          .At(Timestamp(0)));
  MP_ASSERT_OK(runner.Run());
  const std::vector<Packet>& outputs =
      runner.Outputs().Tag("FLOAT_VECTOR").packets;
  EXPECT_EQ(1, outputs.size());
  const std::vector<float>& result = outputs[0].Get<std::vector<float>>();
  ASSERT_FALSE(result.empty());
  EXPECT_EQ(4, result.size());
  EXPECT_NEAR(0, result[0], 0.01);
  EXPECT_NEAR(2, result[1], 0.01);
  EXPECT_NEAR(-2, result[2], 0.01);
  EXPECT_NEAR(-1.976, result[3], 0.01);

  EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
}

}  // namespace mediapipe

45
mediapipe/calculators/core/end_loop_calculator.cc
Normal file
@ -0,0 +1,45 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/calculators/core/end_loop_calculator.h"

#include <vector>

#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/util/render_data.pb.h"

namespace mediapipe {

typedef EndLoopCalculator<std::vector<::mediapipe::NormalizedRect>>
    EndLoopNormalizedRectCalculator;
REGISTER_CALCULATOR(EndLoopNormalizedRectCalculator);

typedef EndLoopCalculator<std::vector<::mediapipe::NormalizedLandmark>>
    EndLoopNormalizedLandmarkCalculator;
REGISTER_CALCULATOR(EndLoopNormalizedLandmarkCalculator);

typedef EndLoopCalculator<
    std::vector<std::vector<::mediapipe::NormalizedLandmark>>>
    EndLoopNormalizedLandmarksVectorCalculator;
REGISTER_CALCULATOR(EndLoopNormalizedLandmarksVectorCalculator);

typedef EndLoopCalculator<std::vector<bool>> EndLoopBooleanCalculator;
REGISTER_CALCULATOR(EndLoopBooleanCalculator);

typedef EndLoopCalculator<std::vector<::mediapipe::RenderData>>
    EndLoopRenderDataCalculator;
REGISTER_CALCULATOR(EndLoopRenderDataCalculator);

}  // namespace mediapipe
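
The registrations above are the complete set added in this change; supporting a new element type follows the same two-line pattern. A minimal sketch (the float instantiation is illustrative, not part of the diff):

typedef EndLoopCalculator<std::vector<float>> EndLoopFloatCalculator;
REGISTER_CALCULATOR(EndLoopFloatCalculator);

The typedef gives the template instantiation a unique registered name that graph configs can then reference as a calculator.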

106
mediapipe/calculators/core/end_loop_calculator.h
Normal file
@ -0,0 +1,106 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_

#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {

// Calculator for completing the processing of loops on iterable collections
// inside a MediaPipe graph. The EndLoopCalculator collects all input packets
// from the "ITEM" tagged input stream into a collection and, upon receiving
// the flush signal from the "BATCH_END" tagged input stream, emits the
// aggregated results at the original timestamp contained in the "BATCH_END"
// input stream.
//
// It is designed to be used like:
//
// node {
//   calculator: "BeginLoopWithIterableCalculator"
//   input_stream: "ITERABLE:input_iterable"    # IterableT @ext_ts
//   output_stream: "ITEM:input_element"        # ItemT @loop_internal_ts
//   output_stream: "BATCH_END:ext_ts"          # Timestamp @loop_internal_ts
// }
//
// node {
//   calculator: "ElementToBlaConverterSubgraph"
//   input_stream: "ITEM:input_to_loop_body"    # ItemT @loop_internal_ts
//   output_stream: "BLA:output_of_loop_body"   # ItemU @loop_internal_ts
// }
//
// node {
//   calculator: "EndLoopWithOutputCalculator"
//   input_stream: "ITEM:output_of_loop_body"   # ItemU @loop_internal_ts
//   input_stream: "BATCH_END:ext_ts"           # Timestamp @loop_internal_ts
//   output_stream: "OUTPUT:aggregated_result"  # IterableU @ext_ts
// }
template <typename IterableT>
class EndLoopCalculator : public CalculatorBase {
  using ItemT = typename IterableT::value_type;

 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    RET_CHECK(cc->Inputs().HasTag("BATCH_END"))
        << "Missing BATCH_END tagged input_stream.";
    cc->Inputs().Tag("BATCH_END").Set<Timestamp>();

    RET_CHECK(cc->Inputs().HasTag("ITEM"));
    cc->Inputs().Tag("ITEM").Set<ItemT>();

    RET_CHECK(cc->Outputs().HasTag("ITERABLE"));
    cc->Outputs().Tag("ITERABLE").Set<IterableT>();
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    if (!cc->Inputs().Tag("ITEM").IsEmpty()) {
      if (!input_stream_collection_) {
        input_stream_collection_.reset(new IterableT);
      }
      input_stream_collection_->push_back(
          cc->Inputs().Tag("ITEM").template Get<ItemT>());
    }

    if (!cc->Inputs().Tag("BATCH_END").Value().IsEmpty()) {  // flush signal
      Timestamp loop_control_ts =
          cc->Inputs().Tag("BATCH_END").template Get<Timestamp>();
      if (input_stream_collection_) {
        cc->Outputs()
            .Tag("ITERABLE")
            .Add(input_stream_collection_.release(), loop_control_ts);
      } else {
        // Since there is no collection, inform downstream calculators to not
        // expect any packet by updating the timestamp bounds.
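        // The bound loop_control_ts + 1 is the earliest timestamp at which a
        // packet could still arrive, so downstream nodes are free to advance
        // past loop_control_ts itself.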
        cc->Outputs()
            .Tag("ITERABLE")
            .SetNextTimestampBound(Timestamp(loop_control_ts.Value() + 1));
      }
    }
    return ::mediapipe::OkStatus();
  }

 private:
  std::unique_ptr<IterableT> input_stream_collection_;
};

}  // namespace mediapipe

#endif  // MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_

@ -74,6 +74,12 @@ class PacketResamplerCalculator : public CalculatorBase {
  ::mediapipe::Status Process(CalculatorContext* cc) override;

 private:
  // Calculates the first sampled timestamp that incorporates a jittering
  // offset.
  void InitializeNextOutputTimestampWithJitter();
  // Calculates the next sampled timestamp that incorporates a jittering offset.
  void UpdateNextOutputTimestampWithJitter();

  // Logic for Process() when jitter_ != 0.0.
  ::mediapipe::Status ProcessWithJitter(CalculatorContext* cc);

@ -233,6 +239,7 @@ TimestampDiff TimestampDiffFromSeconds(double seconds) {
      << Timestamp::kTimestampUnitsPerSecond;

  frame_time_usec_ = static_cast<int64>(1000000.0 / frame_rate_);

  video_header_.frame_rate = frame_rate_;

  if (resampler_options.output_header() !=

@ -295,6 +302,17 @@ TimestampDiff TimestampDiffFromSeconds(double seconds) {
  return ::mediapipe::OkStatus();
}

void PacketResamplerCalculator::InitializeNextOutputTimestampWithJitter() {
  next_output_timestamp_ =
      first_timestamp_ + frame_time_usec_ * random_->RandFloat();
}

void PacketResamplerCalculator::UpdateNextOutputTimestampWithJitter() {
  next_output_timestamp_ +=
      frame_time_usec_ *
      ((1.0 - jitter_) + 2.0 * jitter_ * random_->RandFloat());
}
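
// Note on UpdateNextOutputTimestampWithJitter(): RandFloat() is uniform on
// [0, 1), so the expected per-step increment is
// frame_time_usec_ * ((1.0 - jitter_) + 2.0 * jitter_ * 0.5), which equals
// frame_time_usec_. Jitter perturbs individual sampling intervals but
// preserves the average output frame rate.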

::mediapipe::Status PacketResamplerCalculator::ProcessWithJitter(
    CalculatorContext* cc) {
  RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());

@ -302,8 +320,13 @@ TimestampDiff TimestampDiffFromSeconds(double seconds) {

  if (first_timestamp_ == Timestamp::Unset()) {
    first_timestamp_ = cc->InputTimestamp();
    next_output_timestamp_ =
        first_timestamp_ + frame_time_usec_ * random_->RandFloat();
    InitializeNextOutputTimestampWithJitter();
    if (first_timestamp_ == next_output_timestamp_) {
      OutputWithinLimits(
          cc,
          cc->Inputs().Get(input_data_id_).Value().At(next_output_timestamp_));
      UpdateNextOutputTimestampWithJitter();
    }
    return ::mediapipe::OkStatus();
  }

@ -322,9 +345,7 @@ TimestampDiff TimestampDiffFromSeconds(double seconds) {
                         ? last_packet_
                         : cc->Inputs().Get(input_data_id_).Value())
                        .At(next_output_timestamp_));
  next_output_timestamp_ +=
      frame_time_usec_ *
      ((1.0 - jitter_) + 2.0 * jitter_ * random_->RandFloat());
  UpdateNextOutputTimestampWithJitter();
  return ::mediapipe::OkStatus();
}

@ -102,6 +102,12 @@ class PreviousLoopbackCalculator : public CalculatorBase {
      cc->Outputs().Get(loop_out_id_).AddPacket(std::move(previous_loopback));
    }
  }
  if (!main_ts_.empty()) {
    cc->Outputs().Get(loop_out_id_).SetNextTimestampBound(main_ts_.front());
  }
  if (cc->Inputs().Get(main_id_).IsDone() && main_ts_.empty()) {
    cc->Outputs().Get(loop_out_id_).Close();
  }
  return ::mediapipe::OkStatus();
}

@ -107,5 +107,96 @@ TEST(PreviousLoopbackCalculator, CorrectTimestamps) {
  MP_EXPECT_OK(graph_.WaitUntilDone());
}

// A Calculator that outputs a summary packet in CalculatorBase::Close().
class PacketOnCloseCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<int>();
    cc->Outputs().Index(0).Set<int>();
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) final {
    cc->SetOffset(TimestampDiff(0));
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) final {
    sum_ += cc->Inputs().Index(0).Value().Get<int>();
    cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Close(CalculatorContext* cc) final {
    cc->Outputs().Index(0).AddPacket(
        MakePacket<int>(sum_).At(Timestamp::Max()));
    return ::mediapipe::OkStatus();
  }

 private:
  int sum_ = 0;
};
REGISTER_CALCULATOR(PacketOnCloseCalculator);

// Demonstrates that all output and input streams in PreviousLoopbackCalculator
// will close as expected when all graph input streams are closed.
TEST(PreviousLoopbackCalculator, ClosesCorrectly) {
  std::vector<Packet> outputs;
  CalculatorGraphConfig graph_config_ =
      ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
        input_stream: 'in'
        node {
          calculator: 'PreviousLoopbackCalculator'
          input_stream: 'MAIN:in'
          input_stream: 'LOOP:out'
          input_stream_info: { tag_index: 'LOOP' back_edge: true }
          output_stream: 'PREV_LOOP:previous'
        }
        # This calculator synchronizes its inputs as normal, so it is used
        # to check that both "in" and "previous" are ready.
        node {
          calculator: 'PassThroughCalculator'
          input_stream: 'in'
          input_stream: 'previous'
          output_stream: 'out'
          output_stream: 'previous2'
        }
        node {
          calculator: 'PacketOnCloseCalculator'
          input_stream: 'out'
          output_stream: 'close_out'
        }
      )");
  tool::AddVectorSink("close_out", &graph_config_, &outputs);

  CalculatorGraph graph_;
  MP_ASSERT_OK(graph_.Initialize(graph_config_, {}));
  MP_ASSERT_OK(graph_.StartRun({}));

  auto send_packet = [&graph_](const std::string& input_name, int n) {
    MP_EXPECT_OK(graph_.AddPacketToInputStream(
        input_name, MakePacket<int>(n).At(Timestamp(n))));
  };

  send_packet("in", 1);
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{1}));

  send_packet("in", 5);
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{1, 5}));

  send_packet("in", 15);
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{1, 5, 15}));

  MP_EXPECT_OK(graph_.CloseAllInputStreams());
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  EXPECT_EQ(TimestampValues(outputs),
            (std::vector<int64>{1, 5, 15, Timestamp::Max().Value()}));

  MP_EXPECT_OK(graph_.WaitUntilDone());
}

}  // anonymous namespace
}  // namespace mediapipe

@ -0,0 +1,83 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <map>
#include <memory>
#include <set>
#include <string>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {

using mediapipe::PacketTypeSet;
using mediapipe::Timestamp;

namespace {
static std::map<std::string, Timestamp>* kTimestampMap = []() {
  auto* res = new std::map<std::string, Timestamp>();
  res->emplace("AT_PRESTREAM", Timestamp::PreStream());
  res->emplace("AT_POSTSTREAM", Timestamp::PostStream());
  res->emplace("AT_ZERO", Timestamp(0));
  return res;
}();

}  // namespace

// Outputs the single input_side_packet at the timestamp specified in the
// output_stream tag. Valid tags are AT_PRESTREAM, AT_POSTSTREAM and AT_ZERO.
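//
// For illustration, a node that emits a side packet once at
// Timestamp::PreStream() (the packet and stream names here are placeholders):
// node {
//   calculator: "SidePacketToStreamCalculator"
//   input_side_packet: "session_config"
//   output_stream: "AT_PRESTREAM:session_config_stream"
// }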
class SidePacketToStreamCalculator : public CalculatorBase {
 public:
  SidePacketToStreamCalculator() = default;
  ~SidePacketToStreamCalculator() override = default;

  static ::mediapipe::Status GetContract(CalculatorContract* cc);
  ::mediapipe::Status Process(CalculatorContext* cc) override;
  ::mediapipe::Status Close(CalculatorContext* cc) override;
};
REGISTER_CALCULATOR(SidePacketToStreamCalculator);

::mediapipe::Status SidePacketToStreamCalculator::GetContract(
    CalculatorContract* cc) {
  cc->InputSidePackets().Index(0).SetAny();

  std::set<std::string> tags = cc->Outputs().GetTags();
  RET_CHECK_EQ(tags.size(), 1);

  RET_CHECK_EQ(kTimestampMap->count(*tags.begin()), 1);
  cc->Outputs().Tag(*tags.begin()).SetAny();

  return ::mediapipe::OkStatus();
}

::mediapipe::Status SidePacketToStreamCalculator::Process(
    CalculatorContext* cc) {
  return mediapipe::tool::StatusStop();
}

::mediapipe::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) {
  std::set<std::string> tags = cc->Outputs().GetTags();
  RET_CHECK_EQ(tags.size(), 1);
  const std::string& tag = *tags.begin();
  RET_CHECK_EQ(kTimestampMap->count(tag), 1);
  cc->Outputs().Tag(tag).AddPacket(
      cc->InputSidePackets().Index(0).At(kTimestampMap->at(tag)));

  return ::mediapipe::OkStatus();
}

}  // namespace mediapipe

@ -34,7 +34,9 @@ namespace mediapipe {
// SplitVectorCalculatorOptions. If the option "element_only" is set to true,
// all ranges should be of size 1 and all outputs will be elements of type T. If
// "element_only" is false, ranges can be non-zero in size and all outputs will
// be of type std::vector<T>.
// be of type std::vector<T>. If the option "combine_outputs" is set to true,
// only one output stream can be specified and all ranges of elements will be
// combined into one vector.
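// For example, with combine_outputs: true and ranges { begin: 0 end: 1 } and
// { begin: 2 end: 3 }, an input vector {a, b, c, d} yields the single output
// vector {a, c} (element names here are placeholders for illustration).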
// To use this class for a particular type T, register a calculator using
// SplitVectorCalculator<T>.
template <typename T>

@ -49,6 +51,24 @@ class SplitVectorCalculator : public CalculatorBase {
    const auto& options =
        cc->Options<::mediapipe::SplitVectorCalculatorOptions>();

    if (options.combine_outputs()) {
      RET_CHECK_EQ(cc->Outputs().NumEntries(), 1);
      cc->Outputs().Index(0).Set<std::vector<T>>();
      for (int i = 0; i < options.ranges_size() - 1; ++i) {
        for (int j = i + 1; j < options.ranges_size(); ++j) {
          const auto& range_0 = options.ranges(i);
          const auto& range_1 = options.ranges(j);
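          // Two ranges overlap iff the begin of one lies inside the other's
          // half-open interval [begin, end); the condition below checks both
          // orderings.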
          if ((range_0.begin() >= range_1.begin() &&
               range_0.begin() < range_1.end()) ||
              (range_1.begin() >= range_0.begin() &&
               range_1.begin() < range_0.end())) {
            return ::mediapipe::InvalidArgumentError(
                "Ranges must be non-overlapping when using combine_outputs "
                "option.");
          }
        }
      }
    } else {
      if (cc->Outputs().NumEntries() != options.ranges_size()) {
        return ::mediapipe::InvalidArgumentError(
            "The number of output streams should match the number of ranges "

@ -73,6 +93,7 @@ class SplitVectorCalculator : public CalculatorBase {
          cc->Outputs().Index(i).Set<std::vector<T>>();
        }
      }
    }

    return ::mediapipe::OkStatus();
  }

@ -83,13 +104,15 @@ class SplitVectorCalculator : public CalculatorBase {
    const auto& options =
        cc->Options<::mediapipe::SplitVectorCalculatorOptions>();

    element_only_ = options.element_only();
    combine_outputs_ = options.combine_outputs();

    for (const auto& range : options.ranges()) {
      ranges_.push_back({range.begin(), range.end()});
      max_range_end_ = std::max(max_range_end_, range.end());
      total_elements_ += range.end() - range.begin();
    }

    element_only_ = options.element_only();

    return ::mediapipe::OkStatus();
  }

@ -97,6 +120,17 @@ class SplitVectorCalculator : public CalculatorBase {
    const auto& input = cc->Inputs().Index(0).Get<std::vector<T>>();
    RET_CHECK_GE(input.size(), max_range_end_);

    if (combine_outputs_) {
      auto output = absl::make_unique<std::vector<T>>();
      output->reserve(total_elements_);
      for (int i = 0; i < ranges_.size(); ++i) {
        auto elements = absl::make_unique<std::vector<T>>(
            input.begin() + ranges_[i].first,
            input.begin() + ranges_[i].second);
        output->insert(output->end(), elements->begin(), elements->end());
      }
      cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
    } else {
      if (element_only_) {
        for (int i = 0; i < ranges_.size(); ++i) {
          cc->Outputs().Index(i).AddPacket(

@ -110,6 +144,7 @@ class SplitVectorCalculator : public CalculatorBase {
          cc->Outputs().Index(i).Add(output.release(), cc->InputTimestamp());
        }
      }
    }

    return ::mediapipe::OkStatus();
  }

@ -117,7 +152,9 @@ class SplitVectorCalculator : public CalculatorBase {
 private:
  std::vector<std::pair<int32, int32>> ranges_;
  int32 max_range_end_ = -1;
  int32 total_elements_ = 0;
  bool element_only_ = false;
  bool combine_outputs_ = false;
};

}  // namespace mediapipe

@ -37,4 +37,7 @@ message SplitVectorCalculatorOptions {
  // just element of type T. By default, if a range specifies only one element,
  // it is output as a std::vector<T>.
  optional bool element_only = 2 [default = false];

  // Combines output elements to one vector.
  optional bool combine_outputs = 3 [default = false];
}

@ -105,6 +105,34 @@ class SplitTfLiteTensorVectorCalculatorTest : public ::testing::Test {
    }
  }

  void ValidateCombinedVectorOutput(std::vector<Packet>& output_packets,
                                    int expected_elements,
                                    std::vector<int>& input_begin_indices,
                                    std::vector<int>& input_end_indices) {
    ASSERT_EQ(1, output_packets.size());
    ASSERT_EQ(input_begin_indices.size(), input_end_indices.size());
    const std::vector<TfLiteTensor>& output_vec =
        output_packets[0].Get<std::vector<TfLiteTensor>>();
    ASSERT_EQ(expected_elements, output_vec.size());
    const int num_ranges = input_begin_indices.size();

    int element_id = 0;
    for (int range_id = 0; range_id < num_ranges; ++range_id) {
      for (int i = input_begin_indices[range_id];
           i < input_end_indices[range_id]; ++i) {
        const int expected_value = i;
        const TfLiteTensor* result = &output_vec[element_id];
        float* result_buffer = result->data.f;
        ASSERT_NE(result_buffer, nullptr);
        ASSERT_EQ(result_buffer, input_buffers_[i]);
        for (int j = 0; j < width * height * channels; ++j) {
          ASSERT_EQ(expected_value, result_buffer[j]);
        }
        element_id++;
      }
    }
  }

  void ValidateElementOutput(std::vector<Packet>& output_packets,
                             int input_begin_index) {
    ASSERT_EQ(1, output_packets.size());

@ -234,6 +262,65 @@ TEST_F(SplitTfLiteTensorVectorCalculatorTest, InvalidOutputStreamCountTest) {
  ASSERT_FALSE(graph.Initialize(graph_config).ok());
}

TEST_F(SplitTfLiteTensorVectorCalculatorTest,
       InvalidCombineOutputsMultipleOutputsTest) {
  ASSERT_NE(interpreter_, nullptr);

  // Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
  CalculatorGraphConfig graph_config =
      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"(
            input_stream: "tensor_in"
            node {
              calculator: "SplitTfLiteTensorVectorCalculator"
              input_stream: "tensor_in"
              output_stream: "range_0"
              output_stream: "range_1"
              options {
                [mediapipe.SplitVectorCalculatorOptions.ext] {
                  ranges: { begin: 0 end: 1 }
                  ranges: { begin: 2 end: 3 }
                  combine_outputs: true
                }
              }
            }
          )");

  // Run the graph.
  CalculatorGraph graph;
  // Initialization should fail because only one output stream may be
  // specified when the combine_outputs option is set.
  ASSERT_FALSE(graph.Initialize(graph_config).ok());
}

TEST_F(SplitTfLiteTensorVectorCalculatorTest, InvalidOverlappingRangesTest) {
  ASSERT_NE(interpreter_, nullptr);

  // Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
  CalculatorGraphConfig graph_config =
      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"(
            input_stream: "tensor_in"
            node {
              calculator: "SplitTfLiteTensorVectorCalculator"
              input_stream: "tensor_in"
              output_stream: "range_0"
              options {
                [mediapipe.SplitVectorCalculatorOptions.ext] {
                  ranges: { begin: 0 end: 3 }
                  ranges: { begin: 1 end: 4 }
                  combine_outputs: true
                }
              }
            }
          )");

  // Run the graph.
  CalculatorGraph graph;
  // Initialization should fail because the ranges overlap.
  ASSERT_FALSE(graph.Initialize(graph_config).ok());
}

TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTestElementOnly) {
  ASSERT_NE(interpreter_, nullptr);

@ -289,6 +376,53 @@ TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTestElementOnly) {
  MP_ASSERT_OK(graph.WaitUntilDone());
}

TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTestCombiningOutputs) {
  ASSERT_NE(interpreter_, nullptr);

  PrepareTfLiteTensorVector(/*vector_size=*/5);
  ASSERT_NE(input_vec_, nullptr);

  // Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
  CalculatorGraphConfig graph_config =
      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"(
            input_stream: "tensor_in"
            node {
              calculator: "SplitTfLiteTensorVectorCalculator"
              input_stream: "tensor_in"
              output_stream: "range_0"
              options {
                [mediapipe.SplitVectorCalculatorOptions.ext] {
                  ranges: { begin: 0 end: 1 }
                  ranges: { begin: 2 end: 3 }
                  ranges: { begin: 4 end: 5 }
                  combine_outputs: true
                }
              }
            }
          )");
  std::vector<Packet> range_0_packets;
  tool::AddVectorSink("range_0", &graph_config, &range_0_packets);

  // Run the graph.
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(graph_config));
  MP_ASSERT_OK(graph.StartRun({}));
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "tensor_in", Adopt(input_vec_.release()).At(Timestamp(0))));
  // Wait until the calculator finishes processing.
  MP_ASSERT_OK(graph.WaitUntilIdle());

  std::vector<int> input_begin_indices = {0, 2, 4};
  std::vector<int> input_end_indices = {1, 3, 5};
  ValidateCombinedVectorOutput(range_0_packets, /*expected_elements=*/3,
                               input_begin_indices, input_end_indices);

  // Fully close the graph at the end.
  MP_ASSERT_OK(graph.CloseInputStream("tensor_in"));
  MP_ASSERT_OK(graph.WaitUntilDone());
}

TEST_F(SplitTfLiteTensorVectorCalculatorTest,
       ElementOnlyDisablesVectorOutputs) {
  // Prepare a graph to use the SplitTfLiteTensorVectorCalculator.

@ -0,0 +1,48 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/timestamp.h"

namespace mediapipe {

// A calculator that takes a packet of an input stream and converts it to an
// output side packet. This calculator only works under the assumption that the
// input stream only has a single packet passing through.
//
// Example config:
// node {
//   calculator: "StreamToSidePacketCalculator"
//   input_stream: "stream"
//   output_side_packet: "side_packet"
// }
class StreamToSidePacketCalculator : public mediapipe::CalculatorBase {
 public:
  static mediapipe::Status GetContract(mediapipe::CalculatorContract* cc) {
    cc->Inputs().Index(0).SetAny();
    cc->OutputSidePackets().Index(0).SetAny();
    return mediapipe::OkStatus();
  }

  mediapipe::Status Process(mediapipe::CalculatorContext* cc) override {
    mediapipe::Packet& packet = cc->Inputs().Index(0).Value();
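    // Strip the stream timestamp before storing the packet as a side packet;
    // side packets are not associated with a timestamp.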
    cc->OutputSidePackets().Index(0).Set(
        packet.At(mediapipe::Timestamp::Unset()));
    return mediapipe::OkStatus();
  }
};
REGISTER_CALCULATOR(StreamToSidePacketCalculator);

}  // namespace mediapipe

@ -0,0 +1,67 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <string>

#include "absl/memory/memory.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"

namespace mediapipe {

using ::testing::Test;

class StreamToSidePacketCalculatorTest : public Test {
 protected:
  StreamToSidePacketCalculatorTest() {
    const char kConfig[] = R"(
      calculator: "StreamToSidePacketCalculator"
      input_stream: "stream"
      output_side_packet: "side_packet"
    )";
    runner_ = absl::make_unique<CalculatorRunner>(kConfig);
  }

  std::unique_ptr<CalculatorRunner> runner_;
};

TEST_F(StreamToSidePacketCalculatorTest,
       StreamToSidePacketCalculatorWithEmptyStreamFails) {
  EXPECT_EQ(runner_->Run().code(), mediapipe::StatusCode::kUnavailable);
}

TEST_F(StreamToSidePacketCalculatorTest,
       StreamToSidePacketCalculatorWithSinglePacketCreatesSidePacket) {
  runner_->MutableInputs()->Index(0).packets.push_back(
      Adopt(new std::string("test")).At(Timestamp(1)));
  MP_ASSERT_OK(runner_->Run());
  EXPECT_EQ(runner_->OutputSidePackets().Index(0).Get<std::string>(), "test");
}

TEST_F(StreamToSidePacketCalculatorTest,
       StreamToSidePacketCalculatorWithMultiplePacketsFails) {
  runner_->MutableInputs()->Index(0).packets.push_back(
      Adopt(new std::string("test1")).At(Timestamp(1)));
  runner_->MutableInputs()->Index(0).packets.push_back(
      Adopt(new std::string("test2")).At(Timestamp(2)));
  EXPECT_EQ(runner_->Run().code(), mediapipe::StatusCode::kAlreadyExists);
}

}  // namespace mediapipe

79
mediapipe/calculators/core/string_to_int_calculator.cc
Normal file
@ -0,0 +1,79 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <sys/types.h>

#include <memory>
#include <string>

#include "absl/strings/numbers.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {

// Calculator that converts a std::string into an integer type, or fails if the
// conversion is not possible.
//
// Example config:
// node {
//   calculator: "StringToIntCalculator"
//   input_side_packet: "string"
//   output_side_packet: "index"
// }
template <typename IntType>
class StringToIntCalculatorTemplate : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->InputSidePackets().Index(0).Set<std::string>();
    cc->OutputSidePackets().Index(0).Set<IntType>();
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) override {
    IntType number;
    if (!absl::SimpleAtoi(cc->InputSidePackets().Index(0).Get<std::string>(),
                          &number)) {
      return ::mediapipe::InvalidArgumentError(
          "The std::string could not be parsed as an integer.");
    }
    cc->OutputSidePackets().Index(0).Set(MakePacket<IntType>(number));
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    return ::mediapipe::OkStatus();
  }
};

using StringToIntCalculator = StringToIntCalculatorTemplate<int>;
REGISTER_CALCULATOR(StringToIntCalculator);

using StringToUintCalculator = StringToIntCalculatorTemplate<uint>;
REGISTER_CALCULATOR(StringToUintCalculator);

using StringToInt32Calculator = StringToIntCalculatorTemplate<int32>;
REGISTER_CALCULATOR(StringToInt32Calculator);

using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32>;
REGISTER_CALCULATOR(StringToUint32Calculator);

using StringToInt64Calculator = StringToIntCalculatorTemplate<int64>;
REGISTER_CALCULATOR(StringToInt64Calculator);

using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64>;
REGISTER_CALCULATOR(StringToUint64Calculator);

}  // namespace mediapipe

@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")

licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//visibility:private"])

exports_files(["LICENSE"])

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")

proto_library(
    name = "opencv_image_encoder_calculator_proto",
    srcs = ["opencv_image_encoder_calculator.proto"],

@ -13,12 +13,12 @@
# limitations under the License.
#

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")

licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//visibility:private"])

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")

proto_library(
    name = "graph_tensors_packet_generator_proto",
    srcs = ["graph_tensors_packet_generator.proto"],

@ -104,6 +104,17 @@ proto_library(
    deps = ["//mediapipe/framework:calculator_proto"],
)

proto_library(
    name = "unpack_media_sequence_calculator_proto",
    srcs = ["unpack_media_sequence_calculator.proto"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/calculators/core:packet_resampler_calculator_proto",
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/util:audio_decoder_proto",
    ],
)

proto_library(
    name = "vector_float_to_tensor_calculator_options_proto",
    srcs = ["vector_float_to_tensor_calculator_options.proto"],

@ -127,7 +138,7 @@ mediapipe_cc_proto_library(
    srcs = ["image_frame_to_tensor_calculator.proto"],
    cc_deps = [
        "//mediapipe/framework:calculator_cc_proto",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
        "@org_tensorflow//tensorflow/core:protos_all",
    ],
    visibility = ["//visibility:public"],
    deps = [":image_frame_to_tensor_calculator_proto"],

@ -162,7 +173,7 @@ mediapipe_cc_proto_library(
    srcs = ["pack_media_sequence_calculator.proto"],
    cc_deps = [
        "//mediapipe/framework:calculator_cc_proto",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
        "@org_tensorflow//tensorflow/core:protos_all",
    ],
    visibility = ["//visibility:public"],
    deps = [":pack_media_sequence_calculator_proto"],

@ -181,7 +192,7 @@ mediapipe_cc_proto_library(
    srcs = ["tensorflow_session_from_frozen_graph_generator.proto"],
    cc_deps = [
        "//mediapipe/framework:packet_generator_cc_proto",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
        "@org_tensorflow//tensorflow/core:protos_all",
    ],
    visibility = ["//visibility:public"],
    deps = [":tensorflow_session_from_frozen_graph_generator_proto"],

@ -192,7 +203,7 @@ mediapipe_cc_proto_library(
    srcs = ["tensorflow_session_from_frozen_graph_calculator.proto"],
    cc_deps = [
        "//mediapipe/framework:calculator_cc_proto",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
        "@org_tensorflow//tensorflow/core:protos_all",
    ],
    visibility = ["//visibility:public"],
    deps = [":tensorflow_session_from_frozen_graph_calculator_proto"],

@ -261,6 +272,17 @@ mediapipe_cc_proto_library(
    deps = [":unpack_media_sequence_calculator_proto"],
)

mediapipe_cc_proto_library(
    name = "vector_int_to_tensor_calculator_options_cc_proto",
    srcs = ["vector_int_to_tensor_calculator_options.proto"],
    cc_deps = [
        "//mediapipe/framework:calculator_cc_proto",
        "@org_tensorflow//tensorflow/core:protos_all",
    ],
    visibility = ["//visibility:public"],
    deps = [":vector_int_to_tensor_calculator_options_proto"],
)

mediapipe_cc_proto_library(
    name = "vector_float_to_tensor_calculator_options_cc_proto",
    srcs = ["vector_float_to_tensor_calculator_options.proto"],

@ -274,7 +296,7 @@ cc_library(
    srcs = ["graph_tensors_packet_generator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/calculators/tensorflow:graph_tensors_packet_generator_cc_proto",
        ":graph_tensors_packet_generator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",

@ -289,7 +311,7 @@ cc_library(
    srcs = ["image_frame_to_tensor_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/calculators/tensorflow:image_frame_to_tensor_calculator_cc_proto",
        ":image_frame_to_tensor_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/formats:image_frame",
        "//mediapipe/framework/port:ret_check",

@ -311,7 +333,7 @@ cc_library(
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework/formats:time_series_header_cc_proto",
        "//mediapipe/calculators/tensorflow:matrix_to_tensor_calculator_options_cc_proto",
        ":matrix_to_tensor_calculator_options_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/formats:matrix",
        "//mediapipe/framework/port:status",

@ -332,7 +354,7 @@ cc_library(
    srcs = ["lapped_tensor_buffer_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/calculators/tensorflow:lapped_tensor_buffer_calculator_cc_proto",
        ":lapped_tensor_buffer_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",

@ -386,7 +408,7 @@ cc_library(
        "//mediapipe/util/sequence:media_sequence",
        "//mediapipe/util/sequence:media_sequence_util",
        "@com_google_absl//absl/strings",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
        "@org_tensorflow//tensorflow/core:protos_all",
    ],
    alwayslink = 1,
)

@ -401,7 +423,7 @@ cc_library(
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
        "@org_tensorflow//tensorflow/core:protos_all",
    ],
    alwayslink = 1,
)

@ -414,7 +436,7 @@ cc_library(
    visibility = ["//visibility:public"],
    deps = [
        ":tensorflow_session",
        "//mediapipe/calculators/tensorflow:tensorflow_inference_calculator_cc_proto",
        ":tensorflow_inference_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/tool:status_util",
        "@com_google_absl//absl/strings",

@ -492,7 +514,7 @@ cc_library(
    visibility = ["//visibility:public"],
    deps = [
        ":tensorflow_session",
        "//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_generator_cc_proto",
        ":tensorflow_session_from_frozen_graph_generator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/tool:status_util",
        "//mediapipe/framework/port:status",

@ -551,7 +573,7 @@ cc_library(
    visibility = ["//visibility:public"],
    deps = [
        ":tensorflow_session",
        "//mediapipe/calculators/tensorflow:tensorflow_session_from_saved_model_generator_cc_proto",
        ":tensorflow_session_from_saved_model_generator_cc_proto",
        "//mediapipe/framework:packet_generator",
        "//mediapipe/framework:packet_type",
        "//mediapipe/framework/tool:status_util",

@ -575,7 +597,7 @@ cc_library(
    srcs = ["tensor_squeeze_dimensions_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/calculators/tensorflow:tensor_squeeze_dimensions_calculator_cc_proto",
        ":tensor_squeeze_dimensions_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",

@ -589,7 +611,7 @@ cc_library(
    srcs = ["tensor_to_image_frame_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/calculators/tensorflow:tensor_to_image_frame_calculator_cc_proto",
        ":tensor_to_image_frame_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/formats:image_frame",
        "//mediapipe/framework/port:ret_check",

@ -605,7 +627,7 @@ cc_library(
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework/formats:time_series_header_cc_proto",
        "//mediapipe/calculators/tensorflow:tensor_to_matrix_calculator_cc_proto",
        ":tensor_to_matrix_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/formats:matrix",
        "//mediapipe/framework/port:status",

@ -621,6 +643,22 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "tfrecord_reader_calculator",
    srcs = ["tfrecord_reader_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:integral_types",
        "//mediapipe/framework/port:logging",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:protos_all",
    ],
    alwayslink = 1,
)

cc_library(
    name = "tensor_to_vector_float_calculator",
    srcs = ["tensor_to_vector_float_calculator.cc"],

@ -629,7 +667,7 @@ cc_library(
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/calculators/tensorflow:tensor_to_vector_float_calculator_options_cc_proto",
        ":tensor_to_vector_float_calculator_options_cc_proto",
    ] + select({
        "//conditions:default": [
            "@org_tensorflow//tensorflow/core:framework",

@ -657,7 +695,21 @@ cc_library(
        "//mediapipe/util:audio_decoder_cc_proto",
        "//mediapipe/util/sequence:media_sequence",
        "@com_google_absl//absl/strings",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
        "@org_tensorflow//tensorflow/core:protos_all",
    ],
    alwayslink = 1,
)

cc_library(
    name = "vector_int_to_tensor_calculator",
    srcs = ["vector_int_to_tensor_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        ":vector_int_to_tensor_calculator_options_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "@org_tensorflow//tensorflow/core:framework",
    ],
    alwayslink = 1,
)

@ -667,7 +719,7 @@ cc_library(
    srcs = ["vector_float_to_tensor_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/calculators/tensorflow:vector_float_to_tensor_calculator_options_cc_proto",
        ":vector_float_to_tensor_calculator_options_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",

@ -676,12 +728,26 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "unpack_yt8m_sequence_example_calculator",
    srcs = ["unpack_yt8m_sequence_example_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        ":lapped_tensor_buffer_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:packet",
        "//mediapipe/framework/port:status",
        "@org_tensorflow//tensorflow/core:protos_all",
    ],
    alwayslink = 1,
)

cc_test(
    name = "graph_tensors_packet_generator_test",
    srcs = ["graph_tensors_packet_generator_test.cc"],
    deps = [
        ":graph_tensors_packet_generator",
        "//mediapipe/calculators/tensorflow:graph_tensors_packet_generator_cc_proto",
        ":graph_tensors_packet_generator_cc_proto",
        "//mediapipe/framework:packet",
        "//mediapipe/framework:packet_generator_cc_proto",
        "//mediapipe/framework:packet_set",

@ -713,7 +779,7 @@ cc_test(
    srcs = ["matrix_to_tensor_calculator_test.cc"],
    deps = [
        ":matrix_to_tensor_calculator",
        "//mediapipe/calculators/tensorflow:matrix_to_tensor_calculator_options_cc_proto",
        ":matrix_to_tensor_calculator_options_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework/formats:matrix",

@ -729,13 +795,13 @@ cc_test(
    srcs = ["lapped_tensor_buffer_calculator_test.cc"],
    deps = [
        ":lapped_tensor_buffer_calculator",
        "//mediapipe/calculators/tensorflow:lapped_tensor_buffer_calculator_cc_proto",
        ":lapped_tensor_buffer_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework/port:gtest_main",
        "@com_google_absl//absl/memory",
        "@org_tensorflow//tensorflow/core:framework",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
        "@org_tensorflow//tensorflow/core:protos_all",
    ],
)

@ -774,7 +840,7 @@ cc_test(
        "//mediapipe/util/sequence:media_sequence",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
        "@org_tensorflow//tensorflow/core:protos_all",
    ],
)

@ -801,7 +867,7 @@ cc_test(
        "@com_google_absl//absl/strings",
        "@org_tensorflow//tensorflow/core:direct_session",
        "@org_tensorflow//tensorflow/core:framework",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
        "@org_tensorflow//tensorflow/core:protos_all",
        "@org_tensorflow//tensorflow/core:testlib",
        "@org_tensorflow//tensorflow/core/kernels:conv_ops",
        "@org_tensorflow//tensorflow/core/kernels:math",

@ -817,7 +883,7 @@ cc_test(
        ":tensorflow_inference_calculator",
        ":tensorflow_session",
        ":tensorflow_session_from_frozen_graph_generator",
        "//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_generator_cc_proto",
        ":tensorflow_session_from_frozen_graph_generator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:packet",
        "//mediapipe/framework:packet_generator_cc_proto",

@ -831,7 +897,7 @@ cc_test(
        "@com_google_absl//absl/strings",
        "@org_tensorflow//tensorflow/core:direct_session",
        "@org_tensorflow//tensorflow/core:framework",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
        "@org_tensorflow//tensorflow/core:protos_all",
        "@org_tensorflow//tensorflow/core:testlib",
        "@org_tensorflow//tensorflow/core/kernels:conv_ops",
        "@org_tensorflow//tensorflow/core/kernels:math",

@ -847,7 +913,7 @@ cc_test(
        ":tensorflow_inference_calculator",
        ":tensorflow_session",
        ":tensorflow_session_from_saved_model_generator",
        "//mediapipe/calculators/tensorflow:tensorflow_session_from_saved_model_generator_cc_proto",
        ":tensorflow_session_from_saved_model_generator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:packet",
        "//mediapipe/framework:packet_generator_cc_proto",

@ -857,14 +923,8 @@ cc_test(
        "//mediapipe/framework/tool:tag_map_helper",
        "//mediapipe/framework/tool:validate_type",
        "@com_google_absl//absl/strings",
        "@org_tensorflow//tensorflow/core:all_kernels",
        "@org_tensorflow//tensorflow/core:direct_session",
        "@org_tensorflow//tensorflow/core/kernels:array",
        "@org_tensorflow//tensorflow/core/kernels:bitcast_op",
        "@org_tensorflow//tensorflow/core/kernels:conv_ops",
        "@org_tensorflow//tensorflow/core/kernels:io",
        "@org_tensorflow//tensorflow/core/kernels:state",
|
||||
"@org_tensorflow//tensorflow/core/kernels:string",
|
||||
"@org_tensorflow//tensorflow/core/kernels/data:tensor_dataset_op",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -888,14 +948,8 @@ cc_test(
|
|||
"//mediapipe/framework/tool:tag_map_helper",
|
||||
"//mediapipe/framework/tool:validate_type",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@org_tensorflow//tensorflow/core:all_kernels",
|
||||
"@org_tensorflow//tensorflow/core:direct_session",
|
||||
"@org_tensorflow//tensorflow/core/kernels:array",
|
||||
"@org_tensorflow//tensorflow/core/kernels:bitcast_op",
|
||||
"@org_tensorflow//tensorflow/core/kernels:conv_ops",
|
||||
"@org_tensorflow//tensorflow/core/kernels:io",
|
||||
"@org_tensorflow//tensorflow/core/kernels:state",
|
||||
"@org_tensorflow//tensorflow/core/kernels:string",
|
||||
"@org_tensorflow//tensorflow/core/kernels/data:tensor_dataset_op",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -904,12 +958,12 @@ cc_test(
|
|||
srcs = ["tensor_squeeze_dimensions_calculator_test.cc"],
|
||||
deps = [
|
||||
":tensor_squeeze_dimensions_calculator",
|
||||
"//mediapipe/calculators/tensorflow:tensor_squeeze_dimensions_calculator_cc_proto",
|
||||
":tensor_squeeze_dimensions_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"@org_tensorflow//tensorflow/core:framework",
|
||||
"@org_tensorflow//tensorflow/core:protos_all_cc",
|
||||
"@org_tensorflow//tensorflow/core:protos_all",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -919,13 +973,13 @@ cc_test(
|
|||
srcs = ["tensor_to_image_frame_calculator_test.cc"],
|
||||
deps = [
|
||||
":tensor_to_image_frame_calculator",
|
||||
"//mediapipe/calculators/tensorflow:tensor_to_image_frame_calculator_cc_proto",
|
||||
":tensor_to_image_frame_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"@org_tensorflow//tensorflow/core:framework",
|
||||
"@org_tensorflow//tensorflow/core:protos_all_cc",
|
||||
"@org_tensorflow//tensorflow/core:protos_all",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -935,14 +989,14 @@ cc_test(
|
|||
srcs = ["tensor_to_matrix_calculator_test.cc"],
|
||||
deps = [
|
||||
":tensor_to_matrix_calculator",
|
||||
"//mediapipe/calculators/tensorflow:tensor_to_matrix_calculator_cc_proto",
|
||||
":tensor_to_matrix_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"@org_tensorflow//tensorflow/core:framework",
|
||||
"@org_tensorflow//tensorflow/core:protos_all_cc",
|
||||
"@org_tensorflow//tensorflow/core:protos_all",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -951,12 +1005,12 @@ cc_test(
|
|||
srcs = ["tensor_to_vector_float_calculator_test.cc"],
|
||||
deps = [
|
||||
":tensor_to_vector_float_calculator",
|
||||
"//mediapipe/calculators/tensorflow:tensor_to_vector_float_calculator_options_cc_proto",
|
||||
":tensor_to_vector_float_calculator_options_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"@org_tensorflow//tensorflow/core:framework",
|
||||
"@org_tensorflow//tensorflow/core:protos_all_cc",
|
||||
"@org_tensorflow//tensorflow/core:protos_all",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -976,7 +1030,21 @@ cc_test(
|
|||
"//mediapipe/util/sequence:media_sequence",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@org_tensorflow//tensorflow/core:protos_all_cc",
|
||||
"@org_tensorflow//tensorflow/core:protos_all",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "vector_int_to_tensor_calculator_test",
|
||||
srcs = ["vector_int_to_tensor_calculator_test.cc"],
|
||||
deps = [
|
||||
":vector_int_to_tensor_calculator",
|
||||
":vector_int_to_tensor_calculator_options_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"@org_tensorflow//tensorflow/core:framework",
|
||||
"@org_tensorflow//tensorflow/core:protos_all",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -985,12 +1053,12 @@ cc_test(
|
|||
srcs = ["vector_float_to_tensor_calculator_test.cc"],
|
||||
deps = [
|
||||
":vector_float_to_tensor_calculator",
|
||||
"//mediapipe/calculators/tensorflow:vector_float_to_tensor_calculator_options_cc_proto",
|
||||
":vector_float_to_tensor_calculator_options_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"@org_tensorflow//tensorflow/core:framework",
|
||||
"@org_tensorflow//tensorflow/core:protos_all_cc",
|
||||
"@org_tensorflow//tensorflow/core:protos_all",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -1014,7 +1082,7 @@ cc_test(
|
|||
":tensorflow_session",
|
||||
":tensorflow_inference_calculator",
|
||||
":tensorflow_session_from_frozen_graph_generator",
|
||||
"//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_generator_cc_proto",
|
||||
":tensorflow_session_from_frozen_graph_generator_cc_proto",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
|
|
|
mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.cc
@ -29,6 +29,11 @@

namespace mediapipe {

const char kBufferSize[] = "BUFFER_SIZE";
const char kOverlap[] = "OVERLAP";
const char kTimestampOffset[] = "TIMESTAMP_OFFSET";
const char kCalculatorOptions[] = "CALCULATOR_OPTIONS";

namespace tf = tensorflow;

// Given an input stream of tensors, concatenates the tensors over timesteps.

@ -72,6 +77,9 @@ class LappedTensorBufferCalculator : public CalculatorBase {
::mediapipe::Status AddBatchDimension(tf::Tensor* input_tensor);

int steps_until_output_;
int buffer_size_;
int overlap_;
int timestamp_offset_;
std::unique_ptr<CircularBuffer<Timestamp>> timestamp_buffer_;
std::unique_ptr<CircularBuffer<tf::Tensor>> buffer_;
LappedTensorBufferCalculatorOptions options_;

@ -87,6 +95,21 @@ REGISTER_CALCULATOR(LappedTensorBufferCalculator);
);
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
<< "Only one output stream is supported.";

if (cc->InputSidePackets().HasTag(kBufferSize)) {
cc->InputSidePackets().Tag(kBufferSize).Set<int>();
}
if (cc->InputSidePackets().HasTag(kOverlap)) {
cc->InputSidePackets().Tag(kOverlap).Set<int>();
}
if (cc->InputSidePackets().HasTag(kTimestampOffset)) {
cc->InputSidePackets().Tag(kTimestampOffset).Set<int>();
}
if (cc->InputSidePackets().HasTag(kCalculatorOptions)) {
cc->InputSidePackets()
.Tag(kCalculatorOptions)
.Set<LappedTensorBufferCalculatorOptions>();
}
cc->Outputs().Index(0).Set<tf::Tensor>(
// Output tensorflow::Tensor stream with possibly overlapping steps.
);

@ -95,16 +118,33 @@ REGISTER_CALCULATOR(LappedTensorBufferCalculator);

::mediapipe::Status LappedTensorBufferCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<LappedTensorBufferCalculatorOptions>();
RET_CHECK_LT(options_.overlap(), options_.buffer_size());
RET_CHECK_GE(options_.timestamp_offset(), 0)
if (cc->InputSidePackets().HasTag(kCalculatorOptions)) {
options_ = cc->InputSidePackets()
.Tag(kCalculatorOptions)
.Get<LappedTensorBufferCalculatorOptions>();
}
buffer_size_ = options_.buffer_size();
if (cc->InputSidePackets().HasTag(kBufferSize)) {
buffer_size_ = cc->InputSidePackets().Tag(kBufferSize).Get<int>();
}
overlap_ = options_.overlap();
if (cc->InputSidePackets().HasTag(kOverlap)) {
overlap_ = cc->InputSidePackets().Tag(kOverlap).Get<int>();
}
timestamp_offset_ = options_.timestamp_offset();
if (cc->InputSidePackets().HasTag(kTimestampOffset)) {
timestamp_offset_ = cc->InputSidePackets().Tag(kTimestampOffset).Get<int>();
}

RET_CHECK_LT(overlap_, buffer_size_);
RET_CHECK_GE(timestamp_offset_, 0)
<< "Negative timestamp_offset is not allowed.";
RET_CHECK_LT(options_.timestamp_offset(), options_.buffer_size())
RET_CHECK_LT(timestamp_offset_, buffer_size_)
<< "output_frame_num_offset has to be less than buffer_size.";
timestamp_buffer_ =
absl::make_unique<CircularBuffer<Timestamp>>(options_.buffer_size());
buffer_ =
absl::make_unique<CircularBuffer<tf::Tensor>>(options_.buffer_size());
steps_until_output_ = options_.buffer_size();
absl::make_unique<CircularBuffer<Timestamp>>(buffer_size_);
buffer_ = absl::make_unique<CircularBuffer<tf::Tensor>>(buffer_size_);
steps_until_output_ = buffer_size_;
return ::mediapipe::OkStatus();
}

@ -128,11 +168,10 @@ REGISTER_CALCULATOR(LappedTensorBufferCalculator);
concatenated.get());
RET_CHECK(concat_status.ok()) << concat_status.ToString();

cc->Outputs().Index(0).Add(
concatenated.release(),
timestamp_buffer_->Get(options_.timestamp_offset()));
cc->Outputs().Index(0).Add(concatenated.release(),
timestamp_buffer_->Get(timestamp_offset_));

steps_until_output_ = options_.buffer_size() - options_.overlap();
steps_until_output_ = buffer_size_ - overlap_;
}
return ::mediapipe::OkStatus();
}

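The new side packets let a graph override LappedTensorBufferCalculatorOptions at setup time: CALCULATOR_OPTIONS replaces the whole proto first, then BUFFER_SIZE, OVERLAP, and TIMESTAMP_OFFSET override individual fields. A minimal sketch of driving the override, assuming the CalculatorRunner test utility used elsewhere in this change (stream names illustrative):

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"

namespace mediapipe {

// Sketch only: the "BUFFER_SIZE" side packet takes precedence over the
// buffer_size field of the options proto, per Open() above.
void RunWithBufferSizeOverride() {
  CalculatorGraphConfig::Node config;
  config.set_calculator("LappedTensorBufferCalculator");
  config.add_input_stream("input_tensor");
  config.add_output_stream("lapped_tensor");
  config.add_input_side_packet("BUFFER_SIZE:buffer_size");
  CalculatorRunner runner(config);
  runner.MutableSidePackets()->Tag("BUFFER_SIZE") = MakePacket<int>(8);
  // Push tf::Tensor packets into runner.MutableInputs()->Index(0), then call
  // runner.Run(), as the calculator tests in this change do.
}

}  // namespace mediapipe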
126
mediapipe/calculators/tensorflow/tfrecord_reader_calculator.cc
Normal file
@ -0,0 +1,126 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <string>
#include <utility>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system.h"

namespace mediapipe {

const char kTFRecordPath[] = "TFRECORD_PATH";
const char kRecordIndex[] = "RECORD_INDEX";
const char kExampleTag[] = "EXAMPLE";
const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";

// Reads a tensorflow example/sequence example from a tfrecord file.
// If the "RECORD_INDEX" input side packet is provided, the calculator is going
// to fetch the example/sequence example of the tfrecord file at the target
// record index. Otherwise, the reader always reads the first example/sequence
// example of the tfrecord file.
//
// Example config:
// node {
//   calculator: "TFRecordReaderCalculator"
//   input_side_packet: "TFRECORD_PATH:tfrecord_path"
//   input_side_packet: "RECORD_INDEX:record_index"
//   output_side_packet: "SEQUENCE_EXAMPLE:sequence_example"
// }
class TFRecordReaderCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);

::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
};

::mediapipe::Status TFRecordReaderCalculator::GetContract(
CalculatorContract* cc) {
cc->InputSidePackets().Tag(kTFRecordPath).Set<std::string>();
if (cc->InputSidePackets().HasTag(kRecordIndex)) {
cc->InputSidePackets().Tag(kRecordIndex).Set<int>();
}

RET_CHECK(cc->OutputSidePackets().HasTag(kExampleTag) ||
cc->OutputSidePackets().HasTag(kSequenceExampleTag))
<< "TFRecordReaderCalculator must output either Tensorflow example or "
"sequence example.";
if (cc->OutputSidePackets().HasTag(kExampleTag)) {
cc->OutputSidePackets().Tag(kExampleTag).Set<tensorflow::Example>();
} else {
cc->OutputSidePackets()
.Tag(kSequenceExampleTag)
.Set<tensorflow::SequenceExample>();
}
return ::mediapipe::OkStatus();
}

::mediapipe::Status TFRecordReaderCalculator::Open(CalculatorContext* cc) {
std::unique_ptr<tensorflow::RandomAccessFile> file;
auto tf_status = tensorflow::Env::Default()->NewRandomAccessFile(
cc->InputSidePackets().Tag(kTFRecordPath).Get<std::string>(), &file);
RET_CHECK(tf_status.ok())
<< "Failed to open tfrecord file: " << tf_status.error_message();
tensorflow::io::RecordReader reader(file.get(),
tensorflow::io::RecordReaderOptions());
tensorflow::uint64 offset = 0;
std::string example_str;
const int target_idx =
cc->InputSidePackets().HasTag(kRecordIndex)
? cc->InputSidePackets().Tag(kRecordIndex).Get<int>()
: 0;
int current_idx = 0;
while (current_idx <= target_idx) {
tf_status = reader.ReadRecord(&offset, &example_str);
RET_CHECK(tf_status.ok())
<< "Failed to read tfrecord: " << tf_status.error_message();
if (current_idx == target_idx) {
if (cc->OutputSidePackets().HasTag(kExampleTag)) {
tensorflow::Example tf_example;
tf_example.ParseFromString(example_str);
cc->OutputSidePackets()
.Tag(kExampleTag)
.Set(MakePacket<tensorflow::Example>(std::move(tf_example)));
} else {
tensorflow::SequenceExample tf_sequence_example;
tf_sequence_example.ParseFromString(example_str);
cc->OutputSidePackets()
.Tag(kSequenceExampleTag)
.Set(MakePacket<tensorflow::SequenceExample>(
std::move(tf_sequence_example)));
}
}
++current_idx;
}

return ::mediapipe::OkStatus();
}

::mediapipe::Status TFRecordReaderCalculator::Process(CalculatorContext* cc) {
return ::mediapipe::OkStatus();
}

REGISTER_CALCULATOR(TFRecordReaderCalculator);

} // namespace mediapipe

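For context, a sketch of how the new reader might be driven end to end; it assumes MediaPipe's ParseTextProtoOrDie helper and the CalculatorGraph API, and the file path is illustrative:

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

namespace mediapipe {

// Sketch only: fetch the third record (index 2) of a TFRecord file as a
// SequenceExample output side packet.
::mediapipe::Status ReadThirdRecord() {
  CalculatorGraphConfig config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
    node {
      calculator: "TFRecordReaderCalculator"
      input_side_packet: "TFRECORD_PATH:tfrecord_path"
      input_side_packet: "RECORD_INDEX:record_index"
      output_side_packet: "SEQUENCE_EXAMPLE:sequence_example"
    }
  )");
  CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(config));
  // Open() scans records sequentially from 0 up to RECORD_INDEX.
  MP_RETURN_IF_ERROR(graph.StartRun(
      {{"tfrecord_path", MakePacket<std::string>("/tmp/data.tfrecord")},
       {"record_index", MakePacket<int>(2)}}));
  MP_RETURN_IF_ERROR(graph.WaitUntilDone());
  // The result can then be read via graph.GetOutputSidePacket("sequence_example").
  return ::mediapipe::OkStatus();
}

}  // namespace mediapipe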
192
mediapipe/calculators/tensorflow/unpack_yt8m_sequence_example_calculator.cc
Normal file
@ -0,0 +1,192 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iterator>

#include "mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"

namespace mediapipe {
namespace {

const char kId[] = "id";
const char kRgb[] = "rgb";
const char kAudio[] = "audio";
const char kDesiredSegmentSize[] = "DESIRED_SEGMENT_SIZE";
const char kYt8mId[] = "YT8M_ID";
const char kYt8mSequenceExample[] = "YT8M_SEQUENCE_EXAMPLE";
const char kQuantizedRgbFeature[] = "QUANTIZED_RGB_FEATURE";
const char kQuantizedAudioFeature[] = "QUANTIZED_AUDIO_FEATURE";
const char kSegmentSize[] = "SEGMENT_SIZE";
const char kLappedTensorBufferCalculatorOptions[] =
"LAPPED_TENSOR_BUFFER_CALCULATOR_OPTIONS";

std::string GetQuantizedFeature(
const tensorflow::SequenceExample& sequence_example, const std::string& key,
int index) {
const auto& bytes_list = sequence_example.feature_lists()
.feature_list()
.at(key)
.feature()
.Get(index)
.bytes_list()
.value();
CHECK_EQ(1, bytes_list.size());
return bytes_list.Get(0);
}
} // namespace

// Unpacks YT8M Sequence Example. Note that the audio feature and rgb feature
// output are quantized. DequantizeByteArrayCalculator can do the dequantization
// for you.
//
// Example config:
// node {
//   calculator: "UnpackYt8mSequenceExampleCalculator"
//   input_side_packet: "YT8M_SEQUENCE_EXAMPLE:yt8m_sequence_example"
//   output_stream: "QUANTIZED_RGB_FEATURE:quantized_rgb_feature"
//   output_stream: "QUANTIZED_AUDIO_FEATURE:quantized_audio_feature"
// }
class UnpackYt8mSequenceExampleCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets()
.Tag(kYt8mSequenceExample)
.Set<tensorflow::SequenceExample>();
if (cc->InputSidePackets().HasTag(kDesiredSegmentSize)) {
cc->InputSidePackets().Tag(kDesiredSegmentSize).Set<int>();
}
cc->Outputs().Tag(kQuantizedRgbFeature).Set<std::string>();
cc->Outputs().Tag(kQuantizedAudioFeature).Set<std::string>();
if (cc->OutputSidePackets().HasTag(kYt8mId)) {
cc->OutputSidePackets().Tag(kYt8mId).Set<std::string>();
}
if (cc->OutputSidePackets().HasTag(kLappedTensorBufferCalculatorOptions)) {
cc->OutputSidePackets()
.Tag(kLappedTensorBufferCalculatorOptions)
.Set<::mediapipe::LappedTensorBufferCalculatorOptions>();
}
if (cc->OutputSidePackets().HasTag(kSegmentSize)) {
cc->OutputSidePackets().Tag(kSegmentSize).Set<int>();
}
return ::mediapipe::OkStatus();
}

::mediapipe::Status Open(CalculatorContext* cc) override {
const tensorflow::SequenceExample& sequence_example =
cc->InputSidePackets()
.Tag(kYt8mSequenceExample)
.Get<tensorflow::SequenceExample>();
const std::string& yt8m_id =
sequence_example.context().feature().at(kId).bytes_list().value().Get(
0);
if (cc->OutputSidePackets().HasTag(kYt8mId)) {
cc->OutputSidePackets().Tag(kYt8mId).Set(
MakePacket<std::string>(yt8m_id));
}

int rgb_feature_list_length =
sequence_example.feature_lists().feature_list().at(kRgb).feature_size();
int audio_feature_list_length = sequence_example.feature_lists()
.feature_list()
.at(kAudio)
.feature_size();

if (rgb_feature_list_length != audio_feature_list_length) {
return ::mediapipe::FailedPreconditionError(absl::StrCat(
"Data corruption: the length of audio features and rgb features are "
"not equal. Please check the sequence example that contains yt8m "
"id: ",
yt8m_id));
}
feature_list_length_ = rgb_feature_list_length;
if (cc->OutputSidePackets().HasTag(kLappedTensorBufferCalculatorOptions) ||
cc->OutputSidePackets().HasTag(kSegmentSize)) {
// If the desired segment size is specified, take the min of the length of
// the feature list and the desired size to be the output segment size.
int segment_size = feature_list_length_;
if (cc->InputSidePackets().HasTag(kDesiredSegmentSize)) {
int desired_segment_size =
cc->InputSidePackets().Tag(kDesiredSegmentSize).Get<int>();
RET_CHECK(desired_segment_size > 0)
<< "The desired segment size must be greater than zero.";
segment_size = std::min(
feature_list_length_,
cc->InputSidePackets().Tag(kDesiredSegmentSize).Get<int>());
}
if (cc->OutputSidePackets().HasTag(
kLappedTensorBufferCalculatorOptions)) {
auto lapped_tensor_buffer_calculator_options = absl::make_unique<
::mediapipe::LappedTensorBufferCalculatorOptions>();
lapped_tensor_buffer_calculator_options->set_add_batch_dim_to_tensors(
true);
lapped_tensor_buffer_calculator_options->set_buffer_size(segment_size);
lapped_tensor_buffer_calculator_options->set_overlap(segment_size - 1);
lapped_tensor_buffer_calculator_options->set_timestamp_offset(
segment_size - 1);
cc->OutputSidePackets()
.Tag(kLappedTensorBufferCalculatorOptions)
.Set(Adopt(lapped_tensor_buffer_calculator_options.release()));
}
if (cc->OutputSidePackets().HasTag(kSegmentSize)) {
cc->OutputSidePackets()
.Tag(kSegmentSize)
.Set(MakePacket<int>(segment_size));
}
}
LOG(INFO) << "Reading the sequence example that contains yt8m id: "
<< yt8m_id << ". Feature list length: " << feature_list_length_;
return ::mediapipe::OkStatus();
}

::mediapipe::Status Process(CalculatorContext* cc) override {
if (current_index_ >= feature_list_length_) {
return ::mediapipe::tool::StatusStop();
}
const tensorflow::SequenceExample& sequence_example =
cc->InputSidePackets()
.Tag(kYt8mSequenceExample)
.Get<tensorflow::SequenceExample>();

// Uses microsecond as the unit of time. In the YT8M dataset, each feature
// represents a second.
const Timestamp timestamp = Timestamp(current_index_ * 1000000);
cc->Outputs()
.Tag(kQuantizedRgbFeature)
.AddPacket(
MakePacket<std::string>(
GetQuantizedFeature(sequence_example, kRgb, current_index_))
.At(timestamp));
cc->Outputs()
.Tag(kQuantizedAudioFeature)
.AddPacket(
MakePacket<std::string>(
GetQuantizedFeature(sequence_example, kAudio, current_index_))
.At(timestamp));
++current_index_;
return ::mediapipe::OkStatus();
}

private:
int current_index_ = 0;
int feature_list_length_ = 0;
};

REGISTER_CALCULATOR(UnpackYt8mSequenceExampleCalculator);

} // namespace mediapipe

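A reading aid for the segment arithmetic in Open() above (not part of the change): the derived options turn the downstream LappedTensorBufferCalculator into a one-step sliding window.

// For desired segment size S and feature list length L:
//   segment_size     = min(L, S)
//   buffer_size      = segment_size      // window length
//   overlap          = segment_size - 1  // window advances one step at a time
//   timestamp_offset = segment_size - 1  // stamp with the newest step
// The first segment is therefore emitted after segment_size inputs, and one
// segment per input thereafter.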
mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator.cc
@ -23,10 +23,12 @@

namespace mediapipe {

namespace tf = ::tensorflow;

namespace {
auto& INPUT_1D = VectorFloatToTensorCalculatorOptions::INPUT_1D;
auto& INPUT_2D = VectorFloatToTensorCalculatorOptions::INPUT_2D;
} // namespace

namespace tf = ::tensorflow;

// The calculator expects one input (a packet containing a vector<float> or
// vector<vector<float>>) and generates one output (a packet containing a

203
mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator.cc
Normal file
@ -0,0 +1,203 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Converts a single int or vector<int> or vector<vector<int>> to 1D (or 2D)
// tf::Tensor.

#include "mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"

namespace mediapipe {

const char kVectorInt[] = "VECTOR_INT";
const char kSingleInt[] = "SINGLE_INT";
const char kTensorOut[] = "TENSOR_OUT";

namespace {
auto& INPUT_1D = VectorIntToTensorCalculatorOptions::INPUT_1D;
auto& INPUT_2D = VectorIntToTensorCalculatorOptions::INPUT_2D;
} // namespace

namespace tf = ::tensorflow;

template <typename TensorType>
void AssignMatrixValue(int r, int c, int value, tf::Tensor* output_tensor) {
output_tensor->tensor<TensorType, 2>()(r, c) = value;
}

// The calculator expects one input (a packet containing a single int or
// vector<int> or vector<vector<int>>) and generates one output (a packet
// containing a tf::Tensor containing the same data). The output tensor will be
// either 1D or 2D with dimensions corresponding to the input vector int. It
// will hold DT_INT32 or DT_UINT8 or DT_INT64 values.
//
// Example config:
// node {
//   calculator: "VectorIntToTensorCalculator"
//   input_stream: "SINGLE_INT:segment_size_int_stream"
//   output_stream: "TENSOR_OUT:segment_size_tensor"
// }
//
// or
//
// node {
//   calculator: "VectorIntToTensorCalculator"
//   input_stream: "VECTOR_INT:vector_int_features"
//   output_stream: "TENSOR_OUT:tensor_features"
// }
class VectorIntToTensorCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);

::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;

private:
VectorIntToTensorCalculatorOptions options_;
};
REGISTER_CALCULATOR(VectorIntToTensorCalculator);

::mediapipe::Status VectorIntToTensorCalculator::GetContract(
CalculatorContract* cc) {
const auto& options = cc->Options<VectorIntToTensorCalculatorOptions>();
// Start with only one input packet.
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
<< "Only one input stream is supported.";
if (options.input_size() == INPUT_2D) {
cc->Inputs().Tag(kVectorInt).Set<std::vector<std::vector<int>>>();
} else if (options.input_size() == INPUT_1D) {
if (cc->Inputs().HasTag(kSingleInt)) {
cc->Inputs().Tag(kSingleInt).Set<int>();
} else {
cc->Inputs().Tag(kVectorInt).Set<std::vector<int>>();
}
} else {
LOG(FATAL) << "input size not supported";
}
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
<< "Only one output stream is supported.";
cc->Outputs().Tag(kTensorOut).Set<tf::Tensor>();
return ::mediapipe::OkStatus();
}

::mediapipe::Status VectorIntToTensorCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<VectorIntToTensorCalculatorOptions>();
RET_CHECK(options_.tensor_data_type() == tf::DT_UINT8 ||
options_.tensor_data_type() == tf::DT_INT32 ||
options_.tensor_data_type() == tf::DT_INT64)
<< "Output tensor data type is not supported.";
return ::mediapipe::OkStatus();
}

::mediapipe::Status VectorIntToTensorCalculator::Process(
CalculatorContext* cc) {
tf::TensorShape tensor_shape;
if (options_.input_size() == INPUT_2D) {
const std::vector<std::vector<int>>& input =
cc->Inputs()
.Tag(kVectorInt)
.Value()
.Get<std::vector<std::vector<int>>>();

const int32 rows = input.size();
CHECK_GE(rows, 1);
const int32 cols = input[0].size();
CHECK_GE(cols, 1);
for (int i = 1; i < rows; ++i) {
CHECK_EQ(input[i].size(), cols);
}
if (options_.transpose()) {
tensor_shape = tf::TensorShape({cols, rows});
} else {
tensor_shape = tf::TensorShape({rows, cols});
}
auto output = ::absl::make_unique<tf::Tensor>(options_.tensor_data_type(),
tensor_shape);
if (options_.transpose()) {
for (int r = 0; r < rows; ++r) {
for (int c = 0; c < cols; ++c) {
switch (options_.tensor_data_type()) {
case tf::DT_INT64:
AssignMatrixValue<tf::int64>(c, r, input[r][c], output.get());
break;
case tf::DT_UINT8:
AssignMatrixValue<uint8>(c, r, input[r][c], output.get());
break;
case tf::DT_INT32:
AssignMatrixValue<int>(c, r, input[r][c], output.get());
break;
default:
LOG(FATAL) << "tensor data type is not supported.";
}
}
}
} else {
for (int r = 0; r < rows; ++r) {
for (int c = 0; c < cols; ++c) {
switch (options_.tensor_data_type()) {
case tf::DT_INT64:
AssignMatrixValue<tf::int64>(r, c, input[r][c], output.get());
break;
case tf::DT_UINT8:
AssignMatrixValue<uint8>(r, c, input[r][c], output.get());
break;
case tf::DT_INT32:
AssignMatrixValue<int>(r, c, input[r][c], output.get());
break;
default:
LOG(FATAL) << "tensor data type is not supported.";
}
}
}
}
cc->Outputs().Tag(kTensorOut).Add(output.release(), cc->InputTimestamp());
} else if (options_.input_size() == INPUT_1D) {
std::vector<int> input;
if (cc->Inputs().HasTag(kSingleInt)) {
input.push_back(cc->Inputs().Tag(kSingleInt).Get<int>());
} else {
input = cc->Inputs().Tag(kVectorInt).Value().Get<std::vector<int>>();
}
CHECK_GE(input.size(), 1);
const int32 length = input.size();
tensor_shape = tf::TensorShape({length});
auto output = ::absl::make_unique<tf::Tensor>(options_.tensor_data_type(),
tensor_shape);
for (int i = 0; i < length; ++i) {
switch (options_.tensor_data_type()) {
case tf::DT_INT64:
output->tensor<tf::int64, 1>()(i) = input.at(i);
break;
case tf::DT_UINT8:
output->tensor<uint8, 1>()(i) = input.at(i);
break;
case tf::DT_INT32:
output->tensor<int, 1>()(i) = input.at(i);
break;
default:
LOG(FATAL) << "tensor data type is not supported.";
}
}
cc->Outputs().Tag(kTensorOut).Add(output.release(), cc->InputTimestamp());
} else {
LOG(FATAL) << "input size not supported";
}
return ::mediapipe::OkStatus();
}

} // namespace mediapipe

43
mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator_options.proto
Normal file
@ -0,0 +1,43 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";
import "tensorflow/core/framework/types.proto";

message VectorIntToTensorCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional VectorIntToTensorCalculatorOptions ext = 275364184;
}
enum InputSize {
UNKNOWN = 0;
INPUT_1D = 1;
INPUT_2D = 2;
}

// If input_size is INPUT_2D, unpack a vector<vector<int>> to a
// 2d tensor (matrix). If INPUT_1D, convert a single int or vector<int>
// into a 1d tensor (vector).
optional InputSize input_size = 1 [default = INPUT_1D];

// If true, the output tensor is transposed.
// Otherwise, the output tensor is not transposed.
// It will be ignored if input_size is INPUT_1D.
optional bool transpose = 2 [default = false];

optional tensorflow.DataType tensor_data_type = 3 [default = DT_INT32];
}

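A hypothetical node configuration exercising all three options above, embedded via MediaPipe's ParseTextProtoOrDie helper (stream names are illustrative):

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

namespace mediapipe {

// Sketch only: convert a row-major vector<vector<int>> into a transposed
// DT_INT64 tensor.
CalculatorGraphConfig::Node MakeVectorIntToTensorNode() {
  return ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
    calculator: "VectorIntToTensorCalculator"
    input_stream: "VECTOR_INT:row_major_matrix"
    output_stream: "TENSOR_OUT:matrix_tensor"
    options {
      [mediapipe.VectorIntToTensorCalculatorOptions.ext] {
        input_size: INPUT_2D
        transpose: true
        tensor_data_type: DT_INT64
      }
    }
  )");
}

}  // namespace mediapipe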
202
mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator_test.cc
Normal file
@ -0,0 +1,202 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"

namespace mediapipe {

namespace {

namespace tf = ::tensorflow;

class VectorIntToTensorCalculatorTest : public ::testing::Test {
protected:
void SetUpRunner(
const VectorIntToTensorCalculatorOptions::InputSize input_size,
const tensorflow::DataType tensor_data_type, const bool transpose,
const bool single_value) {
CalculatorGraphConfig::Node config;
config.set_calculator("VectorIntToTensorCalculator");
if (single_value) {
config.add_input_stream("SINGLE_INT:input_int");
} else {
config.add_input_stream("VECTOR_INT:input_int");
}
config.add_output_stream("TENSOR_OUT:output_tensor");
auto options = config.mutable_options()->MutableExtension(
VectorIntToTensorCalculatorOptions::ext);
options->set_input_size(input_size);
options->set_transpose(transpose);
options->set_tensor_data_type(tensor_data_type);
runner_ = ::absl::make_unique<CalculatorRunner>(config);
}

void TestConvertFromVectorVectorInt(const bool transpose) {
SetUpRunner(VectorIntToTensorCalculatorOptions::INPUT_2D,
tensorflow::DT_INT32, transpose, false);
auto input = ::absl::make_unique<std::vector<std::vector<int>>>(
2, std::vector<int>(2));
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 2; ++j) {
input->at(i).at(j) = i * 2 + j;
}
}

const int64 time = 1234;
runner_->MutableInputs()
->Tag("VECTOR_INT")
.packets.push_back(Adopt(input.release()).At(Timestamp(time)));

EXPECT_TRUE(runner_->Run().ok());

const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();

EXPECT_EQ(2, output_tensor.dims());
EXPECT_EQ(tf::DT_INT32, output_tensor.dtype());
const auto matrix = output_tensor.matrix<int>();

for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 2; ++j) {
if (!transpose) {
EXPECT_EQ(i * 2 + j, matrix(i, j));
} else {
EXPECT_EQ(j * 2 + i, matrix(i, j));
}
}
}
}

std::unique_ptr<CalculatorRunner> runner_;
};

TEST_F(VectorIntToTensorCalculatorTest, TestSingleValue) {
SetUpRunner(VectorIntToTensorCalculatorOptions::INPUT_1D,
tensorflow::DT_INT32, false, true);
const int64 time = 1234;
runner_->MutableInputs()
->Tag("SINGLE_INT")
.packets.push_back(MakePacket<int>(1).At(Timestamp(time)));

EXPECT_TRUE(runner_->Run().ok());

const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();

EXPECT_EQ(1, output_tensor.dims());
EXPECT_EQ(tf::DT_INT32, output_tensor.dtype());
const auto vec = output_tensor.vec<int32>();
EXPECT_EQ(1, vec(0));
}

TEST_F(VectorIntToTensorCalculatorTest, TestOneDim) {
SetUpRunner(VectorIntToTensorCalculatorOptions::INPUT_1D,
tensorflow::DT_INT32, false, false);
auto input = ::absl::make_unique<std::vector<int>>(5);
for (int i = 0; i < 5; ++i) {
input->at(i) = i;
}
const int64 time = 1234;
runner_->MutableInputs()
->Tag("VECTOR_INT")
.packets.push_back(Adopt(input.release()).At(Timestamp(time)));

EXPECT_TRUE(runner_->Run().ok());

const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();

EXPECT_EQ(1, output_tensor.dims());
EXPECT_EQ(tf::DT_INT32, output_tensor.dtype());
const auto vec = output_tensor.vec<int32>();

for (int i = 0; i < 5; ++i) {
EXPECT_EQ(i, vec(i));
}
}

TEST_F(VectorIntToTensorCalculatorTest, TestTwoDims) {
for (bool transpose : {false, true}) {
TestConvertFromVectorVectorInt(transpose);
}
}

TEST_F(VectorIntToTensorCalculatorTest, TestInt64) {
SetUpRunner(VectorIntToTensorCalculatorOptions::INPUT_1D,
tensorflow::DT_INT64, false, true);
const int64 time = 1234;
// 1 << 30 is a large int that still fits in int32.
runner_->MutableInputs()
->Tag("SINGLE_INT")
.packets.push_back(MakePacket<int>(1 << 30).At(Timestamp(time)));

EXPECT_TRUE(runner_->Run().ok());

const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();

EXPECT_EQ(1, output_tensor.dims());
EXPECT_EQ(tf::DT_INT64, output_tensor.dtype());
const auto vec = output_tensor.vec<tf::int64>();
EXPECT_EQ(1 << 30, vec(0));
}

TEST_F(VectorIntToTensorCalculatorTest, TestUint8) {
SetUpRunner(VectorIntToTensorCalculatorOptions::INPUT_1D,
tensorflow::DT_UINT8, false, false);
auto input = ::absl::make_unique<std::vector<int>>(5);
for (int i = 0; i < 5; ++i) {
input->at(i) = i;
}
const int64 time = 1234;
runner_->MutableInputs()
->Tag("VECTOR_INT")
.packets.push_back(Adopt(input.release()).At(Timestamp(time)));

EXPECT_TRUE(runner_->Run().ok());

const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();

EXPECT_EQ(1, output_tensor.dims());
EXPECT_EQ(tf::DT_UINT8, output_tensor.dtype());
const auto vec = output_tensor.vec<uint8>();

for (int i = 0; i < 5; ++i) {
EXPECT_EQ(i, vec(i));
}
}

} // namespace
} // namespace mediapipe

mediapipe/calculators/tflite/BUILD
@ -238,6 +238,7 @@ cc_library(
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape",
"@org_tensorflow//tensorflow/lite/delegates/gpu/metal:buffer_convert",
"@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate",
"@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate_internal",
],
"//conditions:default": [
"//mediapipe/gpu:gl_calculator_helper",

mediapipe/calculators/tflite/tflite_converter_calculator.cc
@ -25,7 +25,8 @@
#include "tensorflow/lite/error_reporter.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||
#include "mediapipe/gpu/gpu_buffer.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
|
||||
|
@ -45,7 +46,8 @@
|
|||
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
|
||||
#endif // iOS
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
typedef id<MTLBuffer> GpuTensor;
|
||||
|
@ -67,7 +69,8 @@ typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
|
||||
using ::tflite::gpu::gl::GlProgram;
|
||||
using ::tflite::gpu::gl::GlShader;
|
||||
|
@ -146,7 +149,8 @@ class TfLiteConverterCalculator : public CalculatorBase {
|
|||
|
||||
std::unique_ptr<tflite::Interpreter> interpreter_ = nullptr;
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
mediapipe::GlCalculatorHelper gpu_helper_;
|
||||
std::unique_ptr<GPUData> gpu_data_out_;
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
|
@ -181,7 +185,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
|
||||
if (cc->Inputs().HasTag("IMAGE")) cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
|
||||
if (cc->Inputs().HasTag("MATRIX")) cc->Inputs().Tag("MATRIX").Set<Matrix>();
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
if (cc->Inputs().HasTag("IMAGE_GPU")) {
|
||||
cc->Inputs().Tag("IMAGE_GPU").Set<mediapipe::GpuBuffer>();
|
||||
use_gpu |= true;
|
||||
|
@ -190,7 +194,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
|
||||
if (cc->Outputs().HasTag("TENSORS"))
|
||||
cc->Outputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
if (cc->Outputs().HasTag("TENSORS_GPU")) {
|
||||
cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
|
||||
use_gpu |= true;
|
||||
|
@ -198,7 +202,8 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
if (use_gpu) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
|
||||
|
@ -218,7 +223,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
|
||||
if (cc->Inputs().HasTag("IMAGE_GPU") ||
|
||||
cc->Outputs().HasTag("IMAGE_OUT_GPU")) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
use_gpu_ = true;
|
||||
#else
|
||||
RET_CHECK_FAIL() << "GPU processing not enabled.";
|
||||
|
@ -231,7 +236,8 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
cc->Outputs().HasTag("TENSORS_GPU"));
|
||||
// Cannot use quantization.
|
||||
use_quantized_tensors_ = false;
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
|
||||
|
@ -264,7 +270,8 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
}
|
||||
|
||||
::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); });
|
||||
#endif
|
||||
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
|
@ -383,7 +390,8 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
|
||||
::mediapipe::Status TfLiteConverterCalculator::ProcessGPU(
|
||||
CalculatorContext* cc) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
// GpuBuffer to tflite::gpu::GlBuffer conversion.
|
||||
const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>();
|
||||
MP_RETURN_IF_ERROR(
|
||||
|
@ -468,7 +476,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
}
|
||||
|
||||
::mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
// Get input image sizes.
|
||||
const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>();
|
||||
mediapipe::ImageFormat::Format format =
|
||||
|
@ -485,7 +493,8 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
|
|||
RET_CHECK_FAIL() << "Num input channels is less than desired output.";
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
|
||||
[this, &include_alpha, &input, &single_channel]() -> ::mediapipe::Status {
|
||||
// Device memory.
|
||||
|
|
|
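Before the matching edits to the inference calculator below, a summary of how the amended guards resolve per platform (an aside drawn from the diff itself):

// GPU path selection in the converter after this change:
//   desktop GL : !MEDIAPIPE_DISABLE_GPU && !__EMSCRIPTEN__ && !__APPLE__
//                -> tflite::gpu::gl::GlBuffer tensors, GL calculator helper
//   iOS        : defined(__APPLE__) && !TARGET_OS_OSX
//                -> id<MTLBuffer> tensors, Metal helper/delegate
//   emscripten : defined(__EMSCRIPTEN__)
//                -> GPU helpers compiled out; CPU-only tensor conversion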
@ -27,7 +27,7 @@
|
|||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
|
||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||
#include "mediapipe/gpu/gpu_buffer.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
|
@ -48,11 +48,13 @@
|
|||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal_delegate_internal.h"
|
||||
#endif // iOS
|
||||
|
||||
namespace {
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
typedef id<MTLBuffer> GpuTensor;
|
||||
|
@ -68,13 +70,14 @@ size_t RoundUp(size_t n, size_t m) { return ((n + m - 1) / m) * m; } // NOLINT
|
|||
// * Aux
|
||||
namespace mediapipe {
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
using ::tflite::gpu::gl::CopyBuffer;
|
||||
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
|
||||
using ::tflite::gpu::gl::GlBuffer;
|
||||
#endif
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
struct GPUData {
|
||||
int elements = 1;
|
||||
GpuTensor buffer;
|
||||
|
@ -147,7 +150,8 @@ class TfLiteInferenceCalculator : public CalculatorBase {
|
|||
std::unique_ptr<tflite::FlatBufferModel> model_;
|
||||
TfLiteDelegate* delegate_ = nullptr;
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
mediapipe::GlCalculatorHelper gpu_helper_;
|
||||
std::unique_ptr<GPUData> gpu_data_in_;
|
||||
std::vector<std::unique_ptr<GPUData>> gpu_data_out_;
|
||||
|
@ -179,7 +183,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
|
||||
if (cc->Inputs().HasTag("TENSORS"))
|
||||
cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
if (cc->Inputs().HasTag("TENSORS_GPU")) {
|
||||
cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
|
||||
use_gpu |= true;
|
||||
|
@ -188,7 +192,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
|
||||
if (cc->Outputs().HasTag("TENSORS"))
|
||||
cc->Outputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
if (cc->Outputs().HasTag("TENSORS_GPU")) {
|
||||
cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
|
||||
use_gpu |= true;
|
||||
|
@ -206,7 +210,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
use_gpu |= options.use_gpu();
|
||||
|
||||
if (use_gpu) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
|
||||
|
@ -225,7 +230,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
MP_RETURN_IF_ERROR(LoadOptions(cc));
|
||||
|
||||
if (cc->Inputs().HasTag("TENSORS_GPU")) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
gpu_input_ = true;
|
||||
gpu_inference_ = true; // Inference must be on GPU also.
|
||||
#else
|
||||
|
@ -235,7 +240,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
}
|
||||
|
||||
if (cc->Outputs().HasTag("TENSORS_GPU")) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
gpu_output_ = true;
|
||||
RET_CHECK(cc->Inputs().HasTag("TENSORS_GPU"))
|
||||
<< "GPU output must also have GPU Input.";
|
||||
|
@ -248,13 +253,15 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
MP_RETURN_IF_ERROR(LoadModel(cc));
|
||||
|
||||
if (gpu_inference_) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
|
||||
RET_CHECK(gpu_helper_);
|
||||
#endif
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
|
||||
[this, &cc]() -> ::mediapipe::Status { return LoadDelegate(cc); }));
|
||||
#else
|
||||
|
@ -262,6 +269,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
#endif
|
||||
}
|
||||
|
||||
#if defined(__EMSCRIPTEN__)
|
||||
MP_RETURN_IF_ERROR(LoadDelegate(cc));
|
||||
#endif // __EMSCRIPTEN__
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -269,7 +280,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
// 1. Receive pre-processed tensor inputs.
|
||||
if (gpu_input_) {
|
||||
// Read GPU input into SSBO.
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GpuTensor>>();
|
||||
RET_CHECK_EQ(input_tensors.size(), 1);
|
||||
|
@ -315,7 +327,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
|
||||
// 2. Run inference.
|
||||
if (gpu_inference_) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
MP_RETURN_IF_ERROR(
|
||||
gpu_helper_.RunInGlContext([this]() -> ::mediapipe::Status {
|
||||
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||
|
@ -330,7 +343,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
|
||||
// 3. Output processed tensors.
|
||||
if (gpu_output_) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
// Output result tensors (GPU).
|
||||
auto output_tensors = absl::make_unique<std::vector<GpuTensor>>();
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
|
||||
|
@ -392,7 +406,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
|
||||
::mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) {
|
||||
if (delegate_) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status {
|
||||
TfLiteGpuDelegateDelete(delegate_);
|
||||
gpu_data_in_.reset();
|
||||
|
@ -456,6 +471,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
|
||||
RET_CHECK(interpreter_);
|
||||
|
||||
#if defined(__EMSCRIPTEN__)
|
||||
interpreter_->SetNumThreads(1);
|
||||
#endif // __EMSCRIPTEN__
|
||||
|
||||
if (gpu_output_) {
|
||||
use_quantized_tensors_ = false;
|
||||
} else {
|
||||
|
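The Emscripten-only SetNumThreads(1) above is easy to miss; our reading (an assumption, not stated in the diff) is that the wasm build runs without pthreads, so TFLite must not spawn worker threads:

#if defined(__EMSCRIPTEN__)
  // Assumption: no pthreads in the wasm build, so keep the interpreter on
  // the calling thread instead of spawning a TFLite thread pool.
  interpreter_->SetNumThreads(1);
#endif  // __EMSCRIPTEN__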
@ -471,7 +490,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
|
||||
::mediapipe::Status TfLiteInferenceCalculator::LoadDelegate(
|
||||
CalculatorContext* cc) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
// Configure and create the delegate.
|
||||
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
|
||||
options.compile_options.precision_loss_allowed = 1;
|
||||
|
@ -533,9 +553,9 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
|
|||
|
||||
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
// Configure and create the delegate.
|
||||
GpuDelegateOptions options;
|
||||
TFLGpuDelegateOptions options;
|
||||
options.allow_precision_loss = false; // Must match converter, F=float/T=half
|
||||
options.wait_type = GpuDelegateOptions::WaitType::kPassive;
|
||||
options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypePassive;
|
||||
if (!delegate_) delegate_ = TFLGpuDelegateCreate(&options);
|
||||
id<MTLDevice> device = gpu_helper_.mtlDevice;
|
||||
|
||||
|
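The iOS hunk above migrates from the old GpuDelegateOptions names to the renamed TFLite Metal delegate API (TFLGpuDelegateOptions / TFLGpuDelegateCreate). A standalone sketch of the full delegate lifecycle under that API; the header path and error handling are assumptions, not part of this diff:

#include "tensorflow/lite/interpreter.h"
// Assumed header for the Metal delegate; not shown in this diff.
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"

void ApplyMetalDelegateSketch(tflite::Interpreter* interpreter,
                              TfLiteDelegate** delegate) {
  TFLGpuDelegateOptions options;
  options.allow_precision_loss = false;  // must match the model converter
  options.wait_type = TFLGpuDelegateWaitTypePassive;
  *delegate = TFLGpuDelegateCreate(&options);
  if (interpreter->ModifyGraphWithDelegate(*delegate) != kTfLiteOk) {
    // Delegation failed: drop the delegate and fall back to CPU inference.
    TFLGpuDelegateDelete(*delegate);
    *delegate = nullptr;
  }
  // Otherwise the delegate must stay alive as long as the interpreter uses
  // it; release it with TFLGpuDelegateDelete on teardown.
}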
|
|
@ -24,7 +24,8 @@
|
|||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/util/resource_util.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
#if defined(__EMSCRIPTEN__) || defined(__ANDROID__) || \
|
||||
(defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
#include "mediapipe/util/android/file/base/file.h"
|
||||
#include "mediapipe/util/android/file/base/helpers.h"
|
||||
#else
|
||||
|
@ -66,8 +67,8 @@ class TfLiteTensorsToClassificationCalculator : public CalculatorBase {
|
|||
::mediapipe::Status Close(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
::mediapipe::TfLiteTensorsToClassificationCalculatorOptions options_;
|
||||
int top_k_ = 0;
|
||||
double min_score_threshold_ = 0;
|
||||
std::unordered_map<int, std::string> label_map_;
|
||||
bool label_map_loaded_ = false;
|
||||
};
|
||||
|
@ -93,15 +94,14 @@ REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator);
|
|||
CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
auto options = cc->Options<
|
||||
options_ = cc->Options<
|
||||
::mediapipe::TfLiteTensorsToClassificationCalculatorOptions>();
|
||||
|
||||
top_k_ = options.top_k();
|
||||
min_score_threshold_ = options.min_score_threshold();
|
||||
if (options.has_label_map_path()) {
|
||||
top_k_ = options_.top_k();
|
||||
if (options_.has_label_map_path()) {
|
||||
std::string string_path;
|
||||
ASSIGN_OR_RETURN(string_path,
|
||||
PathToResourceAsFile(options.label_map_path()));
|
||||
PathToResourceAsFile(options_.label_map_path()));
|
||||
std::string label_map_string;
|
||||
MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string));
|
||||
|
||||
|
@ -125,9 +125,11 @@ REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator);
|
|||
RET_CHECK_EQ(input_tensors.size(), 1);
|
||||
|
||||
const TfLiteTensor* raw_score_tensor = &input_tensors[0];
|
||||
RET_CHECK_EQ(raw_score_tensor->dims->size, 2);
|
||||
RET_CHECK_EQ(raw_score_tensor->dims->data[0], 1);
|
||||
int num_classes = raw_score_tensor->dims->data[1];
|
||||
int num_classes = 1;
|
||||
for (int i = 0; i < raw_score_tensor->dims->size; ++i) {
|
||||
num_classes *= raw_score_tensor->dims->data[i];
|
||||
}
|
||||
|
||||
if (label_map_loaded_) {
|
||||
RET_CHECK_EQ(num_classes, label_map_.size());
|
||||
}
|
||||
|
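The rewrite above drops the requirement that the score tensor be exactly [1, num_classes] and instead flattens whatever shape arrives. A small standalone sketch with a hypothetical shape:

#include <vector>

// Flatten a dims vector into a class count, as the new code does.
int NumClasses(const std::vector<int>& dims) {
  int num_classes = 1;
  for (int d : dims) num_classes *= d;
  return num_classes;
}
// NumClasses({1, 1, 1001}) == 1001; the old code would have rejected this
// rank-3 tensor even though it holds a single score vector.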
@ -135,7 +137,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator);
|
|||
|
||||
auto classification_list = absl::make_unique<ClassificationList>();
|
||||
for (int i = 0; i < num_classes; ++i) {
|
||||
if (raw_scores[i] < min_score_threshold_) {
|
||||
if (options_.has_min_score_threshold() &&
|
||||
raw_scores[i] < options_.min_score_threshold()) {
|
||||
continue;
|
||||
}
|
||||
Classification* classification = classification_list->add_classification();
|
||||
|
@ -148,6 +151,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator);
|
|||
|
||||
// Note that partial_sort will raise an error when top_k_ >
|
||||
// classification_list->classification_size().
|
||||
CHECK_GE(classification_list->classification_size(), top_k_);
|
||||
auto raw_classification_list = classification_list->mutable_classification();
|
||||
if (top_k_ > 0 && classification_list->classification_size() >= top_k_) {
|
||||
std::partial_sort(raw_classification_list->begin(),
|
||||
|
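The guard added above matters because std::partial_sort has undefined behavior when the middle iterator lies past the end; asking for the top k of fewer than k elements would do exactly that. A self-contained sketch of the same selection logic (the Item struct is hypothetical, standing in for the Classification proto):

#include <algorithm>
#include <vector>

struct Item { int index; float score; };

// Keep only the top_k highest-scoring items, in descending score order.
void KeepTopK(std::vector<Item>* items, int top_k) {
  if (top_k > 0 && static_cast<int>(items->size()) >= top_k) {
    std::partial_sort(items->begin(), items->begin() + top_k, items->end(),
                      [](const Item& a, const Item& b) {
                        return a.score > b.score;
                      });
    items->resize(top_k);
  }
}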
|
|
@ -27,7 +27,8 @@
|
|||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
|
||||
|
@ -55,12 +56,14 @@ constexpr int kNumCoordsPerBox = 4;
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
|
||||
using ::tflite::gpu::gl::GlShader;
|
||||
#endif
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
|
||||
typedef ::tflite::gpu::gl::GlProgram GpuProgram;
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
|
@ -70,7 +73,7 @@ typedef id<MTLComputePipelineState> GpuProgram;
|
|||
|
||||
namespace {
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
struct GPUData {
|
||||
GpuProgram decode_program;
|
||||
GpuProgram score_program;
|
||||
|
@ -180,7 +183,8 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase {
|
|||
std::vector<Anchor> anchors_;
|
||||
bool side_packet_anchors_{};
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
mediapipe::GlCalculatorHelper gpu_helper_;
|
||||
std::unique_ptr<GPUData> gpu_data_;
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
|
@ -204,7 +208,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
|
||||
}
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
|
||||
if (cc->Inputs().HasTag("TENSORS_GPU")) {
|
||||
cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
|
||||
use_gpu |= true;
|
||||
|
@ -222,7 +226,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
}
|
||||
|
||||
if (use_gpu) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
|
||||
|
@ -238,7 +243,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
|
||||
if (cc->Inputs().HasTag("TENSORS_GPU")) {
|
||||
gpu_input_ = true;
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
|
||||
|
@ -400,7 +406,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
}
|
||||
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU(
|
||||
CalculatorContext* cc, std::vector<Detection>* output_detections) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GpuTensor>>();
|
||||
RET_CHECK_GE(input_tensors.size(), 2);
|
||||
|
@ -562,7 +569,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
|
||||
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Close(
|
||||
CalculatorContext* cc) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
gpu_helper_.RunInGlContext([this] { gpu_data_.reset(); });
|
||||
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
|
||||
gpu_data_.reset();
|
||||
|
@ -715,7 +723,8 @@ Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection(
|
|||
|
||||
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GpuInit(
|
||||
CalculatorContext* cc) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]()
|
||||
-> ::mediapipe::Status {
|
||||
gpu_data_ = absl::make_unique<GPUData>();
|
||||
|
|
|
@ -21,7 +21,8 @@
|
|||
namespace mediapipe {
|
||||
|
||||
// A calculator for converting TFLite tensors from regression models into
|
||||
// landmarks.
|
||||
// landmarks. Note that if the landmarks in the tensor have more than 3
|
||||
// dimensions, only the first 3 dimensions will be converted to x,y,z.
|
||||
//
|
||||
// Input:
|
||||
// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32. Only the first
|
||||
|
@ -122,9 +123,6 @@ REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator);
|
|||
num_values *= raw_tensor->dims->data[i];
|
||||
}
|
||||
const int num_dimensions = num_values / num_landmarks_;
|
||||
// Landmarks must have at most 3 dimensions. Otherwise, please consider
|
||||
// using a matrix.
|
||||
CHECK_LE(num_dimensions, 3);
|
||||
CHECK_GT(num_dimensions, 0);
|
||||
|
||||
const float* raw_landmarks = raw_tensor->data.f;
|
||||
|
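For orientation, the arithmetic above recovers the per-landmark dimensionality from the flattened tensor; a worked example with an assumed 21-landmark hand model:

// Hypothetical landmark tensor of shape [1, 63] with num_landmarks_ == 21:
//   num_values     = 1 * 63 = 63
//   num_dimensions = 63 / 21 = 3   // x, y, z per landmark
// Per the updated doc comment, any dimensions beyond the first three are
// now ignored rather than rejected.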
|
|
@ -28,7 +28,8 @@
|
|||
#include "mediapipe/util/resource_util.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||
#include "mediapipe/gpu/gl_simple_shaders.h"
|
||||
#include "mediapipe/gpu/shader_util.h"
|
||||
|
@ -53,7 +54,8 @@ float Clamp(float val, float min, float max) {
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
using ::tflite::gpu::gl::CopyBuffer;
|
||||
using ::tflite::gpu::gl::CreateReadWriteRgbaImageTexture;
|
||||
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
|
||||
|
@ -129,7 +131,8 @@ class TfLiteTensorsToSegmentationCalculator : public CalculatorBase {
|
|||
int tensor_channels_ = 0;
|
||||
|
||||
bool use_gpu_ = false;
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
mediapipe::GlCalculatorHelper gpu_helper_;
|
||||
std::unique_ptr<GlProgram> mask_program_with_prev_;
|
||||
std::unique_ptr<GlProgram> mask_program_no_prev_;
|
||||
|
@ -159,7 +162,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
|
|||
}
|
||||
|
||||
// Inputs GPU.
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
if (cc->Inputs().HasTag("TENSORS_GPU")) {
|
||||
cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GlBuffer>>();
|
||||
use_gpu |= true;
|
||||
|
@ -178,7 +182,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
|
|||
if (cc->Outputs().HasTag("MASK")) {
|
||||
cc->Outputs().Tag("MASK").Set<ImageFrame>();
|
||||
}
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
if (cc->Outputs().HasTag("MASK_GPU")) {
|
||||
cc->Outputs().Tag("MASK_GPU").Set<mediapipe::GpuBuffer>();
|
||||
use_gpu |= true;
|
||||
|
@ -186,7 +191,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
|
|||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
if (use_gpu) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
}
|
||||
|
@ -199,7 +205,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
|
|||
|
||||
if (cc->Inputs().HasTag("TENSORS_GPU")) {
|
||||
use_gpu_ = true;
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
}
|
||||
|
@ -207,7 +214,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
|
|||
MP_RETURN_IF_ERROR(LoadOptions(cc));
|
||||
|
||||
if (use_gpu_) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
MP_RETURN_IF_ERROR(
|
||||
gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
|
||||
MP_RETURN_IF_ERROR(InitGpu(cc));
|
||||
|
@ -224,7 +232,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
|
|||
::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Process(
|
||||
CalculatorContext* cc) {
|
||||
if (use_gpu_) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
MP_RETURN_IF_ERROR(
|
||||
gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
|
||||
MP_RETURN_IF_ERROR(ProcessGpu(cc));
|
||||
|
@ -240,7 +249,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
|
|||
|
||||
::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Close(
|
||||
CalculatorContext* cc) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
gpu_helper_.RunInGlContext([this] {
|
||||
if (upsample_program_) glDeleteProgram(upsample_program_);
|
||||
upsample_program_ = 0;
|
||||
|
@ -367,7 +377,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
|
|||
if (cc->Inputs().Tag("TENSORS_GPU").IsEmpty()) {
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
// Get input streams.
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GlBuffer>>();
|
||||
|
@ -453,7 +464,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
|
|||
}
|
||||
|
||||
void TfLiteTensorsToSegmentationCalculator::GlRender() {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
static const GLfloat square_vertices[] = {
|
||||
-1.0f, -1.0f, // bottom left
|
||||
1.0f, -1.0f, // bottom right
|
||||
|
@ -525,7 +537,8 @@ void TfLiteTensorsToSegmentationCalculator::GlRender() {
|
|||
|
||||
::mediapipe::Status TfLiteTensorsToSegmentationCalculator::InitGpu(
|
||||
CalculatorContext* cc) {
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
|
||||
!defined(__APPLE__)
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]()
|
||||
-> ::mediapipe::Status {
|
||||
// A shader to process a segmentation tensor into an output mask,
|
||||
|
|
|
@ -12,14 +12,14 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(default_visibility = ["//visibility:private"])
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
|
||||
|
||||
proto_library(
|
||||
name = "annotation_overlay_calculator_proto",
|
||||
srcs = ["annotation_overlay_calculator.proto"],
|
||||
|
@ -72,6 +72,24 @@ proto_library(
|
|||
],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "collection_has_min_size_calculator_proto",
|
||||
srcs = ["collection_has_min_size_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "association_calculator_proto",
|
||||
srcs = ["association_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "annotation_overlay_calculator_cc_proto",
|
||||
srcs = ["annotation_overlay_calculator.proto"],
|
||||
|
@ -141,6 +159,26 @@ mediapipe_cc_proto_library(
|
|||
],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "collection_has_min_size_calculator_cc_proto",
|
||||
srcs = ["collection_has_min_size_calculator.proto"],
|
||||
cc_deps = [
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
],
|
||||
visibility = ["//mediapipe:__subpackages__"],
|
||||
deps = [":collection_has_min_size_calculator_proto"],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "association_calculator_cc_proto",
|
||||
srcs = ["association_calculator.proto"],
|
||||
cc_deps = [
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
],
|
||||
visibility = ["//mediapipe:__subpackages__"],
|
||||
deps = [":association_calculator_proto"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "packet_frequency_calculator",
|
||||
srcs = ["packet_frequency_calculator.cc"],
|
||||
|
@ -234,6 +272,7 @@ cc_library(
|
|||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/port:vector",
|
||||
"//mediapipe/util:annotation_renderer",
|
||||
"//mediapipe/util:render_data_cc_proto",
|
||||
] + select({
|
||||
"//mediapipe/gpu:disable_gpu": [],
|
||||
"//conditions:default": [
|
||||
|
@ -360,6 +399,16 @@ mediapipe_cc_proto_library(
|
|||
deps = [":landmark_projection_calculator_proto"],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "landmarks_to_floats_calculator_cc_proto",
|
||||
srcs = ["landmarks_to_floats_calculator.proto"],
|
||||
cc_deps = [
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":landmarks_to_floats_calculator_proto"],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "rect_transformation_calculator_cc_proto",
|
||||
srcs = ["rect_transformation_calculator.proto"],
|
||||
|
@ -372,7 +421,12 @@ mediapipe_cc_proto_library(
|
|||
|
||||
cc_library(
|
||||
name = "detections_to_rects_calculator",
|
||||
srcs = ["detections_to_rects_calculator.cc"],
|
||||
srcs = [
|
||||
"detections_to_rects_calculator.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"detections_to_rects_calculator.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":detections_to_rects_calculator_cc_proto",
|
||||
|
@ -454,6 +508,17 @@ proto_library(
|
|||
],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "labels_to_render_data_calculator_proto",
|
||||
srcs = ["labels_to_render_data_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/util:color_proto",
|
||||
"//mediapipe/util:render_data_proto",
|
||||
],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "thresholding_calculator_proto",
|
||||
srcs = ["thresholding_calculator.proto"],
|
||||
|
@ -483,6 +548,15 @@ proto_library(
|
|||
],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "landmarks_to_floats_calculator_proto",
|
||||
srcs = ["landmarks_to_floats_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "rect_transformation_calculator_proto",
|
||||
srcs = ["rect_transformation_calculator.proto"],
|
||||
|
@ -577,6 +651,26 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "labels_to_render_data_calculator",
|
||||
srcs = ["labels_to_render_data_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":labels_to_render_data_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_options_cc_proto",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/formats:video_stream_header",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/port:statusor",
|
||||
"//mediapipe/util:color_cc_proto",
|
||||
"//mediapipe/util:render_data_cc_proto",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "rect_to_render_data_calculator",
|
||||
srcs = ["rect_to_render_data_calculator.cc"],
|
||||
|
@ -658,6 +752,22 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "landmarks_to_floats_calculator",
|
||||
srcs = ["landmarks_to_floats_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":landmarks_to_floats_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"@eigen_archive//:eigen",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "detection_letterbox_removal_calculator_test",
|
||||
srcs = ["detection_letterbox_removal_calculator_test.cc"],
|
||||
|
@ -714,6 +824,7 @@ cc_library(
|
|||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":top_k_scores_calculator_cc_proto",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/port:statusor",
|
||||
|
@ -750,3 +861,125 @@ cc_test(
|
|||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "labels_to_render_data_calculator_cc_proto",
|
||||
srcs = ["labels_to_render_data_calculator.proto"],
|
||||
cc_deps = [
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/util:color_cc_proto",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":labels_to_render_data_calculator_proto"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "local_file_contents_calculator",
|
||||
srcs = ["local_file_contents_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:file_helpers",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "filter_collection_calculator",
|
||||
srcs = ["filter_collection_calculator.cc"],
|
||||
hdrs = ["filter_collection_calculator.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "collection_has_min_size_calculator",
|
||||
srcs = ["collection_has_min_size_calculator.cc"],
|
||||
hdrs = ["collection_has_min_size_calculator.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":collection_has_min_size_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "association_calculator",
|
||||
hdrs = ["association_calculator.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":association_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_context",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:collection_item_id",
|
||||
"//mediapipe/framework/port:rectangle",
|
||||
"//mediapipe/framework/port:status",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "association_norm_rect_calculator",
|
||||
srcs = ["association_norm_rect_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":association_calculator",
|
||||
"//mediapipe/framework:calculator_context",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/port:rectangle",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "association_detection_calculator",
|
||||
srcs = ["association_detection_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":association_calculator",
|
||||
"//mediapipe/framework:calculator_context",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/formats:location",
|
||||
"//mediapipe/framework/port:rectangle",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "association_calculator_test",
|
||||
srcs = ["association_calculator_test.cc"],
|
||||
deps = [
|
||||
":association_detection_calculator",
|
||||
":association_norm_rect_calculator",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework:collection_item_id",
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/framework/deps:message_matchers",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/formats:location_data_cc_proto",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "mediapipe/framework/port/vector.h"
|
||||
#include "mediapipe/util/annotation_renderer.h"
|
||||
#include "mediapipe/util/color.pb.h"
|
||||
#include "mediapipe/util/render_data.pb.h"
|
||||
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU)
|
||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||
|
@ -41,6 +42,8 @@ namespace {
|
|||
constexpr char kInputFrameTag[] = "INPUT_FRAME";
|
||||
constexpr char kOutputFrameTag[] = "OUTPUT_FRAME";
|
||||
|
||||
constexpr char kInputVectorTag[] = "VECTOR";
|
||||
|
||||
constexpr char kInputFrameTagGpu[] = "INPUT_FRAME_GPU";
|
||||
constexpr char kOutputFrameTagGpu[] = "OUTPUT_FRAME_GPU";
|
||||
|
||||
|
@ -65,6 +68,9 @@ constexpr int kAnnotationBackgroundColor[] = {100, 101, 102};
|
|||
// 2. RenderData proto on variable number of input streams. All the RenderData
|
||||
// at a particular timestamp is drawn on the image in the order of their
|
||||
// input streams. No tags required.
|
||||
// 3. std::vector<RenderData> on variable number of input streams. RenderData
|
||||
// objects at a particular timestamp are drawn on the image in order of the
|
||||
// input vector items. These input streams are tagged with "VECTOR".
|
||||
//
|
||||
// Output:
|
||||
// 1. OUTPUT_FRAME or OUTPUT_FRAME_GPU: A rendered ImageFrame (or GpuBuffer).
|
||||
|
@ -85,6 +91,8 @@ constexpr int kAnnotationBackgroundColor[] = {100, 101, 102};
|
|||
// input_stream: "render_data_1"
|
||||
// input_stream: "render_data_2"
|
||||
// input_stream: "render_data_3"
|
||||
// input_stream: "VECTOR:0:render_data_vec_0"
|
||||
// input_stream: "VECTOR:1:render_data_vec_1"
|
||||
// output_stream: "OUTPUT_FRAME:decorated_frames"
|
||||
// options {
|
||||
// [mediapipe.AnnotationOverlayCalculatorOptions.ext] {
|
||||
|
@ -99,6 +107,8 @@ constexpr int kAnnotationBackgroundColor[] = {100, 101, 102};
|
|||
// input_stream: "render_data_1"
|
||||
// input_stream: "render_data_2"
|
||||
// input_stream: "render_data_3"
|
||||
// input_stream: "VECTOR:0:render_data_vec_0"
|
||||
// input_stream: "VECTOR:1:render_data_vec_1"
|
||||
// output_stream: "OUTPUT_FRAME_GPU:decorated_frames"
|
||||
// options {
|
||||
// [mediapipe.AnnotationOverlayCalculatorOptions.ext] {
|
||||
|
@ -138,9 +148,6 @@ class AnnotationOverlayCalculator : public CalculatorBase {
|
|||
// Underlying helper renderer library.
|
||||
std::unique_ptr<AnnotationRenderer> renderer_;
|
||||
|
||||
// Number of input streams with render data.
|
||||
int num_render_streams_;
|
||||
|
||||
// Indicates if image frame is available as input.
|
||||
bool image_frame_available_ = false;
|
||||
|
||||
|
@ -171,25 +178,28 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
|
|||
return ::mediapipe::InternalError("GPU output must have GPU input.");
|
||||
}
|
||||
|
||||
// Assume all inputs are render streams; adjust below.
|
||||
int num_render_streams = cc->Inputs().NumEntries();
|
||||
|
||||
// Input image to render onto a copy of.
|
||||
#if !defined(MEDIAPIPE_DISABLE_GPU)
|
||||
if (cc->Inputs().HasTag(kInputFrameTagGpu)) {
|
||||
cc->Inputs().Tag(kInputFrameTagGpu).Set<mediapipe::GpuBuffer>();
|
||||
num_render_streams = cc->Inputs().NumEntries() - 1;
|
||||
use_gpu |= true;
|
||||
}
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
if (cc->Inputs().HasTag(kInputFrameTag)) {
|
||||
cc->Inputs().Tag(kInputFrameTag).Set<ImageFrame>();
|
||||
num_render_streams = cc->Inputs().NumEntries() - 1;
|
||||
}
|
||||
|
||||
// Data streams to render.
|
||||
for (int i = 0; i < num_render_streams; ++i) {
|
||||
cc->Inputs().Index(i).Set<RenderData>();
|
||||
for (CollectionItemId id = cc->Inputs().BeginId(); id < cc->Inputs().EndId();
|
||||
++id) {
|
||||
auto tag_and_index = cc->Inputs().TagAndIndexFromId(id);
|
||||
std::string tag = tag_and_index.first;
|
||||
if (tag == kInputVectorTag) {
|
||||
cc->Inputs().Get(id).Set<std::vector<RenderData>>();
|
||||
} else if (tag.empty()) {
|
||||
// Empty tag defaults to accepting a single object of RenderData type.
|
||||
cc->Inputs().Get(id).Set<RenderData>();
|
||||
}
|
||||
}
|
||||
|
||||
// Rendered image.
|
||||
|
@ -228,12 +238,10 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
|
|||
if (cc->Inputs().HasTag(kInputFrameTagGpu) ||
|
||||
cc->Inputs().HasTag(kInputFrameTag)) {
|
||||
image_frame_available_ = true;
|
||||
num_render_streams_ = cc->Inputs().NumEntries() - 1;
|
||||
} else {
|
||||
image_frame_available_ = false;
|
||||
RET_CHECK(options_.has_canvas_width_px());
|
||||
RET_CHECK(options_.has_canvas_height_px());
|
||||
num_render_streams_ = cc->Inputs().NumEntries();
|
||||
}
|
||||
|
||||
// Initialize the helper renderer library.
|
||||
|
@ -285,12 +293,28 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
|
|||
renderer_->AdoptImage(image_mat.get());
|
||||
|
||||
// Render streams onto render target.
|
||||
for (int i = 0; i < num_render_streams_; ++i) {
|
||||
if (cc->Inputs().Index(i).IsEmpty()) {
|
||||
for (CollectionItemId id = cc->Inputs().BeginId(); id < cc->Inputs().EndId();
|
||||
++id) {
|
||||
auto tag_and_index = cc->Inputs().TagAndIndexFromId(id);
|
||||
std::string tag = tag_and_index.first;
|
||||
if (!tag.empty() && tag != kInputVectorTag) {
|
||||
continue;
|
||||
}
|
||||
const RenderData& render_data = cc->Inputs().Index(i).Get<RenderData>();
|
||||
if (cc->Inputs().Get(id).IsEmpty()) {
|
||||
continue;
|
||||
}
|
||||
if (tag.empty()) {
|
||||
// Empty tag defaults to accepting a single object of RenderData type.
|
||||
const RenderData& render_data = cc->Inputs().Get(id).Get<RenderData>();
|
||||
renderer_->RenderDataOnImage(render_data);
|
||||
} else {
|
||||
RET_CHECK_EQ(kInputVectorTag, tag);
|
||||
const std::vector<RenderData>& render_data_vec =
|
||||
cc->Inputs().Get(id).Get<std::vector<RenderData>>();
|
||||
for (const RenderData& render_data : render_data_vec) {
|
||||
renderer_->RenderDataOnImage(render_data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (use_gpu_) {
|
||||
|
|
259
mediapipe/calculators/util/association_calculator.h
Normal file
|
@ -0,0 +1,259 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef MEDIAPIPE_CALCULATORS_UTIL_ASSOCIATION_CALCULATOR_H_
|
||||
#define MEDIAPIPE_CALCULATORS_UTIL_ASSOCIATION_CALCULATOR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "mediapipe/calculators/util/association_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_context.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/collection_item_id.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/rectangle.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Computes the overlap similarity based on Intersection over Union (IoU) of
|
||||
// two rectangles.
|
||||
inline float OverlapSimilarity(const Rectangle_f& rect1,
|
||||
const Rectangle_f& rect2) {
|
||||
if (!rect1.Intersects(rect2)) return 0.0f;
|
||||
// Compute IoU similarity score.
|
||||
const float intersection_area = Rectangle_f(rect1).Intersect(rect2).Area();
|
||||
const float normalization = rect1.Area() + rect2.Area() - intersection_area;
|
||||
return normalization > 0.0f ? intersection_area / normalization : 0.0f;
|
||||
}
|
||||
|
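A quick worked check of OverlapSimilarity above, assuming a Rectangle_f constructor order of (xmin, ymin, width, height) from mediapipe/framework/port/rectangle.h; that header is not part of this diff, so the constructor signature is an assumption:

#include <cassert>
#include <cmath>

void OverlapSimilaritySanityCheck() {
  // rect1 = [0,1] x [0,1], rect2 = [0.5,1.5] x [0,1]:
  //   intersection = 0.5 * 1.0 = 0.5
  //   union        = 1.0 + 1.0 - 0.5 = 1.5
  //   IoU          = 0.5 / 1.5 = 1/3
  Rectangle_f r1(0.0f, 0.0f, 1.0f, 1.0f);
  Rectangle_f r2(0.5f, 0.0f, 1.0f, 1.0f);
  assert(std::fabs(OverlapSimilarity(r1, r2) - 1.0f / 3.0f) < 1e-6f);
}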
||||
// AssociationCalculator<T> accepts multiple inputs of vectors of type T that can
|
||||
// be converted to Rectangle_f. The output is a vector of type T that contains
|
||||
// elements from the input vectors that don't overlap with each other. When
|
||||
// two elements overlap, the element that comes in from a later input stream
|
||||
// is kept in the output. This association operation is useful for multiple
|
||||
// instance inference pipelines in MediaPipe.
|
||||
// If an input stream is tagged with the "PREV" tag, IDs of overlapping elements
|
||||
// from "PREV" input stream are propagated to the output. Elements in the "PREV"
|
||||
// input stream that don't overlap with other elements are not added to the
|
||||
// output. This stream is designed to take detections from the previous timestamp,
|
||||
// e.g. the output of PreviousLoopbackCalculator, to provide temporal association.
|
||||
// See AssociationDetectionCalculator and AssociationNormRectCalculator for
|
||||
// example uses.
|
||||
template <typename T>
|
||||
class AssociationCalculator : public CalculatorBase {
|
||||
public:
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
// At most one input stream can be tagged with "PREV".
|
||||
RET_CHECK_LE(cc->Inputs().NumEntries("PREV"), 1);
|
||||
|
||||
if (cc->Inputs().HasTag("PREV")) {
|
||||
RET_CHECK_GE(cc->Inputs().NumEntries(), 2);
|
||||
}
|
||||
|
||||
for (CollectionItemId id = cc->Inputs().BeginId();
|
||||
id < cc->Inputs().EndId(); ++id) {
|
||||
cc->Inputs().Get(id).Set<std::vector<T>>();
|
||||
}
|
||||
|
||||
cc->Outputs().Index(0).Set<std::vector<T>>();
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Open(CalculatorContext* cc) override {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
has_prev_input_stream_ = cc->Inputs().HasTag("PREV");
|
||||
if (has_prev_input_stream_) {
|
||||
prev_input_stream_id_ = cc->Inputs().GetId("PREV", 0);
|
||||
}
|
||||
options_ = cc->Options<::mediapipe::AssociationCalculatorOptions>();
|
||||
CHECK_GE(options_.min_similarity_threshold(), 0);
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Process(CalculatorContext* cc) override {
|
||||
auto get_non_overlapping_elements = GetNonOverlappingElements(cc);
|
||||
if (!get_non_overlapping_elements.ok()) {
|
||||
return get_non_overlapping_elements.status();
|
||||
}
|
||||
std::list<T> result = get_non_overlapping_elements.ValueOrDie();
|
||||
|
||||
if (has_prev_input_stream_ &&
|
||||
!cc->Inputs().Get(prev_input_stream_id_).IsEmpty()) {
|
||||
// All regular input streams have been processed. Now compare the result
|
||||
// list elements with those in the PREV input stream, and propagate IDs
|
||||
// from the PREV input stream as appropriate.
|
||||
const std::vector<T>& prev_input_vec =
|
||||
cc->Inputs()
|
||||
.Get(prev_input_stream_id_)
|
||||
.template Get<std::vector<T>>();
|
||||
|
||||
MP_RETURN_IF_ERROR(
|
||||
PropagateIdsFromPreviousToCurrent(prev_input_vec, &result));
|
||||
}
|
||||
|
||||
auto output = absl::make_unique<std::vector<T>>();
|
||||
for (auto it = result.begin(); it != result.end(); ++it) {
|
||||
output->push_back(*it);
|
||||
}
|
||||
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
protected:
|
||||
::mediapipe::AssociationCalculatorOptions options_;
|
||||
|
||||
bool has_prev_input_stream_;
|
||||
CollectionItemId prev_input_stream_id_;
|
||||
|
||||
virtual ::mediapipe::StatusOr<Rectangle_f> GetRectangle(const T& input) {
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
virtual std::pair<bool, int> GetId(const T& input) { return {false, -1}; }
|
||||
|
||||
virtual void SetId(T* input, int id) {}
|
||||
|
||||
private:
|
||||
// Get a list of non-overlapping elements from all input streams, with
|
||||
// increasing order of priority based on input stream index.
|
||||
mediapipe::StatusOr<std::list<T>> GetNonOverlappingElements(
|
||||
CalculatorContext* cc) {
|
||||
std::list<T> result;
|
||||
|
||||
// Initialize result with the first non-empty input vector.
|
||||
CollectionItemId non_empty_id = cc->Inputs().BeginId();
|
||||
for (CollectionItemId id = cc->Inputs().BeginId();
|
||||
id < cc->Inputs().EndId(); ++id) {
|
||||
if (id == prev_input_stream_id_ || cc->Inputs().Get(id).IsEmpty()) {
|
||||
continue;
|
||||
}
|
||||
const std::vector<T>& input_vec =
|
||||
cc->Inputs().Get(id).Get<std::vector<T>>();
|
||||
if (!input_vec.empty()) {
|
||||
non_empty_id = id;
|
||||
result.push_back(input_vec[0]);
|
||||
for (int j = 1; j < input_vec.size(); ++j) {
|
||||
MP_RETURN_IF_ERROR(AddElementToList(input_vec[j], &result));
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Compare remaining input vectors with the non-empty result vector,
|
||||
// remove lower-priority overlapping elements from the result vector and
|
||||
// add corresponding higher-priority elements as necessary.
|
||||
for (CollectionItemId id = non_empty_id + 1; id < cc->Inputs().EndId();
|
||||
++id) {
|
||||
if (id == prev_input_stream_id_ || cc->Inputs().Get(id).IsEmpty()) {
|
||||
continue;
|
||||
}
|
||||
const std::vector<T>& input_vec =
|
||||
cc->Inputs().Get(id).Get<std::vector<T>>();
|
||||
|
||||
for (int vi = 0; vi < input_vec.size(); ++vi) {
|
||||
MP_RETURN_IF_ERROR(AddElementToList(input_vec[vi], &result));
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
::mediapipe::Status AddElementToList(T element, std::list<T>* current) {
|
||||
// Compare this element with elements of the input collection. If this
|
||||
// element has high overlap with elements of the collection, remove
|
||||
// those elements from the collection and add this element.
|
||||
ASSIGN_OR_RETURN(auto cur_rect, GetRectangle(element));
|
||||
|
||||
bool change_id = false;
|
||||
int new_elem_id = -1;
|
||||
|
||||
for (auto uit = current->begin(); uit != current->end();) {
|
||||
ASSIGN_OR_RETURN(auto prev_rect, GetRectangle(*uit));
|
||||
if (OverlapSimilarity(cur_rect, prev_rect) >
|
||||
options_.min_similarity_threshold()) {
|
||||
std::pair<bool, int> prev_id = GetId(*uit);
|
||||
// If prev_id.first is false, i.e. the element doesn't have an ID,
|
||||
// change_id and new_elem_id are not updated.
|
||||
if (prev_id.first) {
|
||||
change_id = prev_id.first;
|
||||
new_elem_id = prev_id.second;
|
||||
}
|
||||
uit = current->erase(uit);
|
||||
} else {
|
||||
++uit;
|
||||
}
|
||||
}
|
||||
|
||||
if (change_id) {
|
||||
SetId(&element, new_elem_id);
|
||||
}
|
||||
current->push_back(element);
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
// Compare elements of the current list with elements from the collection
|
||||
// of elements from the previous input stream, and propagate IDs from the
|
||||
// previous input stream as appropriate.
|
||||
::mediapipe::Status PropagateIdsFromPreviousToCurrent(
|
||||
const std::vector<T>& prev_input_vec, std::list<T>* current) {
|
||||
for (auto vit = current->begin(); vit != current->end(); ++vit) {
|
||||
auto get_cur_rectangle = GetRectangle(*vit);
|
||||
if (!get_cur_rectangle.ok()) {
|
||||
return get_cur_rectangle.status();
|
||||
}
|
||||
const Rectangle_f& cur_rect = get_cur_rectangle.ValueOrDie();
|
||||
|
||||
bool change_id = false;
|
||||
int id_for_vi = -1;
|
||||
|
||||
for (int ui = 0; ui < prev_input_vec.size(); ++ui) {
|
||||
auto get_prev_rectangle = GetRectangle(prev_input_vec[ui]);
|
||||
if (!get_prev_rectangle.ok()) {
|
||||
return get_prev_rectangle.status();
|
||||
}
|
||||
const Rectangle_f& prev_rect = get_prev_rectangle.ValueOrDie();
|
||||
|
||||
if (OverlapSimilarity(cur_rect, prev_rect) >
|
||||
options_.min_similarity_threshold()) {
|
||||
std::pair<bool, int> prev_id = GetId(prev_input_vec[ui]);
|
||||
// If prev_id.first is false, i.e. the element doesn't have an ID,
|
||||
// change_id and id_for_vi are not updated.
|
||||
if (prev_id.first) {
|
||||
change_id = prev_id.first;
|
||||
id_for_vi = prev_id.second;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (change_id) {
|
||||
T element = *vit;
|
||||
SetId(&element, id_for_vi);
|
||||
*vit = element;
|
||||
}
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_CALCULATORS_UTIL_ASSOCIATION_CALCULATOR_H_
|
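To make the extension points concrete, here is a minimal sketch of a subclass wiring a type into AssociationCalculator<T>; AssociationNormRectCalculator in this change follows this shape. The field names assume the NormalizedRect proto (x_center, y_center, width, height, rect_id) and the same Rectangle_f constructor order as above; this is an illustration, not the shipped calculator:

#include "mediapipe/framework/formats/rect.pb.h"

class AssociationNormRectCalculatorSketch
    : public AssociationCalculator<::mediapipe::NormalizedRect> {
 protected:
  ::mediapipe::StatusOr<Rectangle_f> GetRectangle(
      const ::mediapipe::NormalizedRect& input) override {
    // Convert the center/size representation to a corner-anchored rect.
    return Rectangle_f(input.x_center() - input.width() / 2.0f,
                       input.y_center() - input.height() / 2.0f,
                       input.width(), input.height());
  }
  std::pair<bool, int> GetId(
      const ::mediapipe::NormalizedRect& input) override {
    return {input.has_rect_id(), static_cast<int>(input.rect_id())};
  }
  void SetId(::mediapipe::NormalizedRect* input, int id) override {
    input->set_rect_id(id);
  }
};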
27
mediapipe/calculators/util/association_calculator.proto
Normal file
|
@ -0,0 +1,27 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
message AssociationCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional AssociationCalculatorOptions ext = 275124847;
|
||||
}
|
||||
|
||||
optional float min_similarity_threshold = 1 [default = 1.0];
|
||||
}
|
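A subtlety in the default above: the calculator compares IoU with a strict greater-than, and IoU never exceeds 1.0, so the default of 1.0 effectively disables association. Graphs set a real threshold, as the tests below do with 0.1; a minimal C++ sketch:

mediapipe::AssociationCalculatorOptions options;
// IoU <= 1.0 always and the comparison is strict, so the 1.0 default merges
// nothing; choose the overlap that should count as "the same object".
options.set_min_similarity_threshold(0.1f);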
476
mediapipe/calculators/util/association_calculator_test.cc
Normal file
|
@ -0,0 +1,476 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/collection_item_id.h"
|
||||
#include "mediapipe/framework/deps/message_matchers.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/location_data.pb.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
::mediapipe::Detection DetectionWithRelativeLocationData(double xmin,
|
||||
double ymin,
|
||||
double width,
|
||||
double height) {
|
||||
::mediapipe::Detection detection;
|
||||
::mediapipe::LocationData* location_data = detection.mutable_location_data();
|
||||
location_data->set_format(::mediapipe::LocationData::RELATIVE_BOUNDING_BOX);
|
||||
location_data->mutable_relative_bounding_box()->set_xmin(xmin);
|
||||
location_data->mutable_relative_bounding_box()->set_ymin(ymin);
|
||||
location_data->mutable_relative_bounding_box()->set_width(width);
|
||||
location_data->mutable_relative_bounding_box()->set_height(height);
|
||||
return detection;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
class AssociationDetectionCalculatorTest : public ::testing::Test {
|
||||
protected:
|
||||
AssociationDetectionCalculatorTest() {
|
||||
// 0.4 ================
|
||||
// | | | |
|
||||
// 0.3 ===================== | DET2 | |
|
||||
// | | | DET1 | | | DET4 |
|
||||
// 0.2 | DET0 | =========== ================
|
||||
// | | | | | |
|
||||
// 0.1 =====|=============== |
|
||||
// | DET3 | | |
|
||||
// 0.0 ================ |
|
||||
// | DET5 |
|
||||
// -0.1 ===========
|
||||
// 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1 1.2
|
||||
|
||||
// Detection det_0.
|
||||
det_0 = DetectionWithRelativeLocationData(/*xmin=*/0.1, /*ymin=*/0.1,
|
||||
/*width=*/0.2, /*height=*/0.2);
|
||||
det_0.set_detection_id(0);
|
||||
|
||||
// Detection det_1.
|
||||
det_1 = DetectionWithRelativeLocationData(/*xmin=*/0.3, /*ymin=*/0.1,
|
||||
/*width=*/0.2, /*height=*/0.2);
|
||||
det_1.set_detection_id(1);
|
||||
|
||||
// Detection det_2.
|
||||
det_2 = DetectionWithRelativeLocationData(/*xmin=*/0.9, /*ymin=*/0.2,
|
||||
/*width=*/0.2, /*height=*/0.2);
|
||||
det_2.set_detection_id(2);
|
||||
|
||||
// Detection det_3.
|
||||
det_3 = DetectionWithRelativeLocationData(/*xmin=*/0.2, /*ymin=*/0.0,
|
||||
/*width=*/0.3, /*height=*/0.3);
|
||||
det_3.set_detection_id(3);
|
||||
|
||||
// Detection det_4.
|
||||
det_4 = DetectionWithRelativeLocationData(/*xmin=*/1.0, /*ymin=*/0.2,
|
||||
/*width=*/0.2, /*height=*/0.2);
|
||||
det_4.set_detection_id(4);
|
||||
|
||||
// Detection det_5.
|
||||
det_5 = DetectionWithRelativeLocationData(/*xmin=*/0.3, /*ymin=*/-0.1,
|
||||
/*width=*/0.3, /*height=*/0.3);
|
||||
det_5.set_detection_id(5);
|
||||
}
|
||||
|
||||
::mediapipe::Detection det_0, det_1, det_2, det_3, det_4, det_5;
|
||||
};
|
||||
|
||||
TEST_F(AssociationDetectionCalculatorTest, DetectionAssocTest) {
|
||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
|
||||
calculator: "AssociationDetectionCalculator"
|
||||
input_stream: "input_vec_0"
|
||||
input_stream: "input_vec_1"
|
||||
input_stream: "input_vec_2"
|
||||
output_stream: "output_vec"
|
||||
options {
|
||||
[mediapipe.AssociationCalculatorOptions.ext] {
|
||||
min_similarity_threshold: 0.1
|
||||
}
|
||||
}
|
||||
)"));
|
||||
|
||||
// Input Stream 0: det_0, det_1, det_2.
|
||||
auto input_vec_0 = absl::make_unique<std::vector<::mediapipe::Detection>>();
|
||||
input_vec_0->push_back(det_0);
|
||||
input_vec_0->push_back(det_1);
|
||||
input_vec_0->push_back(det_2);
|
||||
runner.MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(input_vec_0.release()).At(Timestamp(1)));
|
||||
|
||||
// Input Stream 1: det_3, det_4.
|
||||
auto input_vec_1 = absl::make_unique<std::vector<::mediapipe::Detection>>();
|
||||
input_vec_1->push_back(det_3);
|
||||
input_vec_1->push_back(det_4);
|
||||
runner.MutableInputs()->Index(1).packets.push_back(
|
||||
Adopt(input_vec_1.release()).At(Timestamp(1)));
|
||||
|
||||
// Input Stream 2: det_5.
|
||||
auto input_vec_2 = absl::make_unique<std::vector<::mediapipe::Detection>>();
|
||||
input_vec_2->push_back(det_5);
|
||||
runner.MutableInputs()->Index(2).packets.push_back(
|
||||
Adopt(input_vec_2.release()).At(Timestamp(1)));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, output.size());
|
||||
const auto& assoc_rects =
|
||||
output[0].Get<std::vector<::mediapipe::Detection>>();
|
||||
|
||||
// det_3 overlaps with det_0 and det_1, and det_5 overlaps with det_3. Since
|
||||
// det_5 has the highest priority, the other rects are removed. det_4 overlaps
|
||||
// with det_2, and det_4 has higher priority, so we keep it. The final output
|
||||
// therefore contains 2 elements.
|
||||
EXPECT_EQ(2, assoc_rects.size());
|
||||
// Outputs are in order of inputs, so det_4 is before det_5 in output vector.
|
||||
|
||||
// det_4 overlaps with det_2, so new id for det_4 is 2.
|
||||
EXPECT_TRUE(assoc_rects[0].has_detection_id());
|
||||
EXPECT_EQ(2, assoc_rects[0].detection_id());
|
||||
det_4.set_detection_id(2);
|
||||
EXPECT_THAT(assoc_rects[0], EqualsProto(det_4));
|
||||
|
||||
// det_3 overlaps with det_0, so new id for det_3 is 0.
|
||||
// det_3 overlaps with det_1, so new id for det_3 is 1.
|
||||
// det_5 overlaps with det_3, so new id for det_5 is 1.
|
||||
EXPECT_TRUE(assoc_rects[1].has_detection_id());
|
||||
EXPECT_EQ(1, assoc_rects[1].detection_id());
|
||||
det_5.set_detection_id(1);
|
||||
EXPECT_THAT(assoc_rects[1], EqualsProto(det_5));
|
||||
}
|
||||
|
||||
TEST_F(AssociationDetectionCalculatorTest, DetectionAssocTestWithPrev) {
|
||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
|
||||
calculator: "AssociationDetectionCalculator"
|
||||
input_stream: "PREV:input_vec_0"
|
||||
input_stream: "input_vec_1"
|
||||
output_stream: "output_vec"
|
||||
options {
|
||||
[mediapipe.AssociationCalculatorOptions.ext] {
|
||||
min_similarity_threshold: 0.1
|
||||
}
|
||||
}
|
||||
)"));
|
||||
|
||||
// Input Stream 0: det_3, det_4.
|
||||
auto input_vec_0 = absl::make_unique<std::vector<::mediapipe::Detection>>();
|
||||
input_vec_0->push_back(det_3);
|
||||
input_vec_0->push_back(det_4);
|
||||
CollectionItemId prev_input_stream_id =
|
||||
runner.MutableInputs()->GetId("PREV", 0);
|
||||
runner.MutableInputs()
|
||||
->Get(prev_input_stream_id)
|
||||
.packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1)));
|
||||
|
||||
// Input Stream 1: det_5.
|
||||
auto input_vec_1 = absl::make_unique<std::vector<::mediapipe::Detection>>();
|
||||
input_vec_1->push_back(det_5);
|
||||
CollectionItemId input_stream_id = runner.MutableInputs()->GetId("", 0);
|
||||
runner.MutableInputs()
|
||||
->Get(input_stream_id)
|
||||
.packets.push_back(Adopt(input_vec_1.release()).At(Timestamp(1)));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, output.size());
|
||||
const auto& assoc_rects =
|
||||
output[0].Get<std::vector<::mediapipe::Detection>>();
|
||||
|
||||
// det_5 overlaps with det_3 and doesn't overlap with det_4. Since det_4 is
|
||||
// in the PREV input stream, it doesn't get copied to the output, so the final
|
||||
// output contains 1 element.
|
||||
EXPECT_EQ(1, assoc_rects.size());
|
||||
|
||||
// det_5 overlaps with det_3, det_3 is in PREV, so new id for det_5 is 3.
|
||||
EXPECT_TRUE(assoc_rects[0].has_detection_id());
|
||||
EXPECT_EQ(3, assoc_rects[0].detection_id());
|
||||
det_5.set_detection_id(3);
|
||||
EXPECT_THAT(assoc_rects[0], EqualsProto(det_5));
|
||||
}
|
||||
|
||||
TEST_F(AssociationDetectionCalculatorTest, DetectionAssocTestReverse) {
|
||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
|
||||
calculator: "AssociationDetectionCalculator"
|
||||
input_stream: "input_vec_0"
|
||||
input_stream: "input_vec_1"
|
||||
input_stream: "input_vec_2"
|
||||
output_stream: "output_vec"
|
||||
options {
|
||||
[mediapipe.AssociationCalculatorOptions.ext] {
|
||||
min_similarity_threshold: 0.1
|
||||
}
|
||||
}
|
||||
)"));
|
||||
|
||||
// Input Stream 0: det_5.
|
||||
auto input_vec_0 = absl::make_unique<std::vector<::mediapipe::Detection>>();
|
||||
input_vec_0->push_back(det_5);
|
||||
runner.MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(input_vec_0.release()).At(Timestamp(1)));
|
||||
|
||||
// Input Stream 1: det_3, det_4.
|
||||
auto input_vec_1 = absl::make_unique<std::vector<::mediapipe::Detection>>();
|
||||
input_vec_1->push_back(det_3);
|
||||
input_vec_1->push_back(det_4);
|
||||
runner.MutableInputs()->Index(1).packets.push_back(
|
||||
Adopt(input_vec_1.release()).At(Timestamp(1)));
|
||||
|
||||
// Input Stream 2: det_0, det_1, det_2.
|
||||
auto input_vec_2 = absl::make_unique<std::vector<::mediapipe::Detection>>();
|
||||
input_vec_2->push_back(det_0);
|
||||
input_vec_2->push_back(det_1);
|
||||
input_vec_2->push_back(det_2);
|
||||
runner.MutableInputs()->Index(2).packets.push_back(
|
||||
Adopt(input_vec_2.release()).At(Timestamp(1)));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, output.size());
|
||||
const auto& assoc_rects =
|
||||
output[0].Get<std::vector<::mediapipe::Detection>>();
|
||||
|
||||
// det_3 overlaps with det_5, so det_5 is removed. det_0 overlaps with det_3,
|
||||
// so det_3 is removed as det_0 is in higher priority for keeping. det_2
|
||||
// overlaps with det_4 so det_4 is removed as det_2 is higher priority for
|
||||
// keeping. The final output therefore contains 3 elements.
|
||||
EXPECT_EQ(3, assoc_rects.size());
|
||||
// Outputs are in same order as inputs.
|
||||
|
||||
// det_3 overlaps with det_5, so new id for det_3 is 5.
|
||||
// det_0 overlaps with det_3, so new id for det_0 is 5.
|
||||
EXPECT_TRUE(assoc_rects[0].has_detection_id());
|
||||
EXPECT_EQ(5, assoc_rects[0].detection_id());
|
||||
det_0.set_detection_id(5);
|
||||
EXPECT_THAT(assoc_rects[0], EqualsProto(det_0));
|
||||
|
||||
// det_1 stays with id 1.
|
||||
EXPECT_TRUE(assoc_rects[1].has_detection_id());
|
||||
EXPECT_EQ(1, assoc_rects[1].detection_id());
|
||||
EXPECT_THAT(assoc_rects[1], EqualsProto(det_1));
|
||||
|
||||
// det_2 overlaps with det_4, so new id for det_2 is 4.
|
||||
EXPECT_TRUE(assoc_rects[2].has_detection_id());
|
||||
EXPECT_EQ(4, assoc_rects[2].detection_id());
|
||||
det_2.set_detection_id(4);
|
||||
EXPECT_THAT(assoc_rects[2], EqualsProto(det_2));
|
||||
}
|
||||
|
||||
class AssociationNormRectCalculatorTest : public ::testing::Test {
 protected:
  AssociationNormRectCalculatorTest() {
    //  0.4                                         ================
    //                                              |    |    |    |
    //  0.3 =====================                   |   NR2   |    |
    //      |    |    |   NR1   |                   |    |   NR4   |
    //  0.2 | NR0|    ================              ================
    //      |    |    |         |    |
    //  0.1 =====|===============    |
    //           |    |  NR3    |    |
    //  0.0      ================    |
    //                |     NR5      |
    // -0.1           ================
    //       0.1  0.2  0.3  0.4  0.5  0.6  0.7  0.8  0.9  1.0  1.1  1.2

    // NormalizedRect nr_0.
    nr_0.set_x_center(0.2);
    nr_0.set_y_center(0.2);
    nr_0.set_width(0.2);
    nr_0.set_height(0.2);

    // NormalizedRect nr_1.
    nr_1.set_x_center(0.4);
    nr_1.set_y_center(0.2);
    nr_1.set_width(0.2);
    nr_1.set_height(0.2);

    // NormalizedRect nr_2.
    nr_2.set_x_center(1.0);
    nr_2.set_y_center(0.3);
    nr_2.set_width(0.2);
    nr_2.set_height(0.2);

    // NormalizedRect nr_3.
    nr_3.set_x_center(0.35);
    nr_3.set_y_center(0.15);
    nr_3.set_width(0.3);
    nr_3.set_height(0.3);

    // NormalizedRect nr_4.
    nr_4.set_x_center(1.1);
    nr_4.set_y_center(0.3);
    nr_4.set_width(0.2);
    nr_4.set_height(0.2);

    // NormalizedRect nr_5.
    nr_5.set_x_center(0.45);
    nr_5.set_y_center(0.05);
    nr_5.set_width(0.3);
    nr_5.set_height(0.3);
  }

  ::mediapipe::NormalizedRect nr_0, nr_1, nr_2, nr_3, nr_4, nr_5;
};

TEST_F(AssociationNormRectCalculatorTest, NormRectAssocTest) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
    calculator: "AssociationNormRectCalculator"
    input_stream: "input_vec_0"
    input_stream: "input_vec_1"
    input_stream: "input_vec_2"
    output_stream: "output_vec"
    options {
      [mediapipe.AssociationCalculatorOptions.ext] {
        min_similarity_threshold: 0.1
      }
    }
  )"));

  // Input Stream 0: nr_0, nr_1, nr_2.
  auto input_vec_0 =
      absl::make_unique<std::vector<::mediapipe::NormalizedRect>>();
  input_vec_0->push_back(nr_0);
  input_vec_0->push_back(nr_1);
  input_vec_0->push_back(nr_2);
  runner.MutableInputs()->Index(0).packets.push_back(
      Adopt(input_vec_0.release()).At(Timestamp(1)));

  // Input Stream 1: nr_3, nr_4.
  auto input_vec_1 =
      absl::make_unique<std::vector<::mediapipe::NormalizedRect>>();
  input_vec_1->push_back(nr_3);
  input_vec_1->push_back(nr_4);
  runner.MutableInputs()->Index(1).packets.push_back(
      Adopt(input_vec_1.release()).At(Timestamp(1)));

  // Input Stream 2: nr_5.
  auto input_vec_2 =
      absl::make_unique<std::vector<::mediapipe::NormalizedRect>>();
  input_vec_2->push_back(nr_5);
  runner.MutableInputs()->Index(2).packets.push_back(
      Adopt(input_vec_2.release()).At(Timestamp(1)));

  MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
  const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
  EXPECT_EQ(1, output.size());
  const auto& assoc_rects =
      output[0].Get<std::vector<::mediapipe::NormalizedRect>>();

  // nr_3 overlaps with nr_0, nr_1, and nr_5 overlaps with nr_3. Since nr_5
  // has the highest priority, we remove the other rects.
  // nr_4 overlaps with nr_2, and nr_4 has higher priority, so we keep it.
  // The final output therefore contains 2 elements.
  EXPECT_EQ(2, assoc_rects.size());
  // Outputs are in order of inputs, so nr_4 is before nr_5 in output vector.
  EXPECT_THAT(assoc_rects[0], EqualsProto(nr_4));
  EXPECT_THAT(assoc_rects[1], EqualsProto(nr_5));
}

TEST_F(AssociationNormRectCalculatorTest, NormRectAssocTestReverse) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
    calculator: "AssociationNormRectCalculator"
    input_stream: "input_vec_0"
    input_stream: "input_vec_1"
    input_stream: "input_vec_2"
    output_stream: "output_vec"
    options {
      [mediapipe.AssociationCalculatorOptions.ext] {
        min_similarity_threshold: 0.1
      }
    }
  )"));

  // Input Stream 0: nr_5.
  auto input_vec_0 =
      absl::make_unique<std::vector<::mediapipe::NormalizedRect>>();
  input_vec_0->push_back(nr_5);
  runner.MutableInputs()->Index(0).packets.push_back(
      Adopt(input_vec_0.release()).At(Timestamp(1)));

  // Input Stream 1: nr_3, nr_4.
  auto input_vec_1 =
      absl::make_unique<std::vector<::mediapipe::NormalizedRect>>();
  input_vec_1->push_back(nr_3);
  input_vec_1->push_back(nr_4);
  runner.MutableInputs()->Index(1).packets.push_back(
      Adopt(input_vec_1.release()).At(Timestamp(1)));

  // Input Stream 2: nr_0, nr_1, nr_2.
  auto input_vec_2 =
      absl::make_unique<std::vector<::mediapipe::NormalizedRect>>();
  input_vec_2->push_back(nr_0);
  input_vec_2->push_back(nr_1);
  input_vec_2->push_back(nr_2);
  runner.MutableInputs()->Index(2).packets.push_back(
      Adopt(input_vec_2.release()).At(Timestamp(1)));

  MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
  const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
  EXPECT_EQ(1, output.size());
  const auto& assoc_rects =
      output[0].Get<std::vector<::mediapipe::NormalizedRect>>();

  // nr_3 overlaps with nr_5, so nr_5 is removed. nr_0 overlaps with nr_3, so
  // nr_3 is removed as nr_0 has higher priority for keeping. nr_2 overlaps
  // with nr_4, so nr_4 is removed as nr_2 has higher priority for keeping.
  // The final output therefore contains 3 elements.
  EXPECT_EQ(3, assoc_rects.size());
  // Outputs are in same order as inputs.
  EXPECT_THAT(assoc_rects[0], EqualsProto(nr_0));
  EXPECT_THAT(assoc_rects[1], EqualsProto(nr_1));
  EXPECT_THAT(assoc_rects[2], EqualsProto(nr_2));
}

TEST_F(AssociationNormRectCalculatorTest, NormRectAssocSingleInputStream) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
    calculator: "AssociationNormRectCalculator"
    input_stream: "input_vec"
    output_stream: "output_vec"
    options {
      [mediapipe.AssociationCalculatorOptions.ext] {
        min_similarity_threshold: 0.1
      }
    }
  )"));

  // Input Stream: nr_3, nr_5.
  auto input_vec =
      absl::make_unique<std::vector<::mediapipe::NormalizedRect>>();
  input_vec->push_back(nr_3);
  input_vec->push_back(nr_5);
  runner.MutableInputs()->Index(0).packets.push_back(
      Adopt(input_vec.release()).At(Timestamp(1)));

  MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
  const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
  EXPECT_EQ(1, output.size());
  const auto& assoc_rects =
      output[0].Get<std::vector<::mediapipe::NormalizedRect>>();

  // nr_5 overlaps with nr_3. Since nr_5 comes after nr_3 in the same input
  // stream, we remove nr_3 and keep nr_5.
  // The final output therefore contains 1 element.
  EXPECT_EQ(1, assoc_rects.size());
  EXPECT_THAT(assoc_rects[0], EqualsProto(nr_5));
}

}  // namespace mediapipe
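
The expectations in these tests all follow from one rule: an element from a later (higher-priority) input vector replaces, and inherits the id of, any earlier element it sufficiently overlaps. A minimal standalone sketch of that overlap test, assuming the similarity measure is plain intersection-over-union of axis-aligned boxes (the `Box` struct here is hypothetical, not the MediaPipe type):

#include <algorithm>

struct Box {
  float xmin, ymin, width, height;
};

// Returns intersection-over-union in [0, 1].
float IoU(const Box& a, const Box& b) {
  const float ix =
      std::max(0.0f, std::min(a.xmin + a.width, b.xmin + b.width) -
                         std::max(a.xmin, b.xmin));
  const float iy =
      std::max(0.0f, std::min(a.ymin + a.height, b.ymin + b.height) -
                         std::max(a.ymin, b.ymin));
  const float inter = ix * iy;
  const float uni = a.width * a.height + b.width * b.height - inter;
  return uni > 0.0f ? inter / uni : 0.0f;
}

// Two elements are associated (the later one wins and takes over the id of
// the earlier one) when the similarity exceeds min_similarity_threshold,
// e.g. the 0.1 used throughout these tests.
bool Overlaps(const Box& a, const Box& b, float min_similarity_threshold) {
  return IoU(a, b) > min_similarity_threshold;
}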

77
mediapipe/calculators/util/association_detection_calculator.cc
Normal file

@ -0,0 +1,77 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/calculators/util/association_calculator.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/port/rectangle.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {

// A subclass of AssociationCalculator<T> for Detection. Example:
// node {
//   calculator: "AssociationDetectionCalculator"
//   input_stream: "PREV:input_vec_0"
//   input_stream: "input_vec_1"
//   input_stream: "input_vec_2"
//   output_stream: "output_vec"
//   options {
//     [mediapipe.AssociationCalculatorOptions.ext] {
//       min_similarity_threshold: 0.1
//     }
//   }
// }
class AssociationDetectionCalculator
    : public AssociationCalculator<::mediapipe::Detection> {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    return AssociationCalculator<::mediapipe::Detection>::GetContract(cc);
  }

  ::mediapipe::Status Open(CalculatorContext* cc) override {
    return AssociationCalculator<::mediapipe::Detection>::Open(cc);
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    return AssociationCalculator<::mediapipe::Detection>::Process(cc);
  }

  ::mediapipe::Status Close(CalculatorContext* cc) override {
    return AssociationCalculator<::mediapipe::Detection>::Close(cc);
  }

 protected:
  ::mediapipe::StatusOr<Rectangle_f> GetRectangle(
      const ::mediapipe::Detection& input) override {
    if (!input.has_location_data()) {
      return ::mediapipe::InternalError("Missing location_data in Detection");
    }
    const Location location(input.location_data());
    return location.GetRelativeBBox();
  }

  std::pair<bool, int> GetId(const ::mediapipe::Detection& input) override {
    return {input.has_detection_id(), input.detection_id()};
  }

  void SetId(::mediapipe::Detection* input, int id) override {
    input->set_detection_id(id);
  }
};

REGISTER_CALCULATOR(AssociationDetectionCalculator);

}  // namespace mediapipe
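
This wrapper is the full pattern for specializing AssociationCalculator<T>: forward the lifecycle methods to the base and implement three type-specific hooks. As a sketch only, a hypothetical subclass for some payload type `MyBox` (not a MediaPipe type; its accessors are illustrative) would add:

// Hypothetical sketch: the lifecycle forwarders (GetContract/Open/Process/
// Close) and REGISTER_CALCULATOR are omitted for brevity but follow the file
// above verbatim.
class AssociationMyBoxCalculator : public AssociationCalculator<MyBox> {
 protected:
  ::mediapipe::StatusOr<Rectangle_f> GetRectangle(
      const MyBox& input) override {
    // Map the payload to a relative bounding box for the overlap test.
    return Rectangle_f(input.xmin(), input.ymin(), input.width(),
                       input.height());
  }

  std::pair<bool, int> GetId(const MyBox& input) override {
    return {input.has_id(), input.id()};
  }

  void SetId(MyBox* input, int id) override { input->set_id(id); }
};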

72
mediapipe/calculators/util/association_norm_rect_calculator.cc
Normal file

@ -0,0 +1,72 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/calculators/util/association_calculator.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/rectangle.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {

// A subclass of AssociationCalculator<T> for NormalizedRect. Example use case:
// node {
//   calculator: "AssociationNormRectCalculator"
//   input_stream: "input_vec_0"
//   input_stream: "input_vec_1"
//   input_stream: "input_vec_2"
//   output_stream: "output_vec"
//   options {
//     [mediapipe.AssociationCalculatorOptions.ext] {
//       min_similarity_threshold: 0.1
//     }
//   }
// }
class AssociationNormRectCalculator
    : public AssociationCalculator<::mediapipe::NormalizedRect> {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    return AssociationCalculator<::mediapipe::NormalizedRect>::GetContract(cc);
  }

  ::mediapipe::Status Open(CalculatorContext* cc) override {
    return AssociationCalculator<::mediapipe::NormalizedRect>::Open(cc);
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    return AssociationCalculator<::mediapipe::NormalizedRect>::Process(cc);
  }

  ::mediapipe::Status Close(CalculatorContext* cc) override {
    return AssociationCalculator<::mediapipe::NormalizedRect>::Close(cc);
  }

 protected:
  ::mediapipe::StatusOr<Rectangle_f> GetRectangle(
      const ::mediapipe::NormalizedRect& input) override {
    if (!input.has_x_center() || !input.has_y_center() || !input.has_width() ||
        !input.has_height()) {
      return ::mediapipe::InternalError(
          "Missing dimensions in NormalizedRect.");
    }
    const float xmin = input.x_center() - input.width() / 2.0;
    const float ymin = input.y_center() - input.height() / 2.0;
    // TODO: Support rotation for rectangle.
    return Rectangle_f(xmin, ymin, input.width(), input.height());
  }
};

REGISTER_CALCULATOR(AssociationNormRectCalculator);

}  // namespace mediapipe
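
Worked through with the fixture values from the tests above: nr_0 has x_center 0.2, y_center 0.2, width 0.2, height 0.2, so GetRectangle returns xmin = 0.2 - 0.2/2 = 0.1 and ymin = 0.2 - 0.2/2 = 0.1, i.e. Rectangle_f(0.1, 0.1, 0.2, 0.2). Likewise nr_5 (y_center 0.05, height 0.3) yields ymin = -0.1, which is why the diagram in the test fixture extends below 0.0.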

26
mediapipe/calculators/util/collection_has_min_size_calculator.cc
Normal file

@ -0,0 +1,26 @@

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/calculators/util/collection_has_min_size_calculator.h"

#include "mediapipe/framework/formats/rect.pb.h"

namespace mediapipe {

typedef CollectionHasMinSizeCalculator<std::vector<::mediapipe::NormalizedRect>>
    NormalizedRectVectorHasMinSizeCalculator;
REGISTER_CALCULATOR(NormalizedRectVectorHasMinSizeCalculator);

}  // namespace mediapipe

84
mediapipe/calculators/util/collection_has_min_size_calculator.h
Normal file

@ -0,0 +1,84 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MEDIAPIPE_CALCULATORS_UTIL_COLLECTION_HAS_MIN_SIZE_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_UTIL_COLLECTION_HAS_MIN_SIZE_CALCULATOR_H_

#include <vector>

#include "mediapipe/calculators/util/collection_has_min_size_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {

// Determines if an input iterable collection has a minimum size, specified
// in CollectionHasMinSizeCalculatorOptions. Example usage:
// node {
//   calculator: "IntVectorHasMinSizeCalculator"
//   input_stream: "ITERABLE:input_int_vector"
//   output_stream: "has_min_ints"
//   options {
//     [mediapipe.CollectionHasMinSizeCalculatorOptions.ext] {
//       min_size: 2
//     }
//   }
// }
template <typename IterableT>
class CollectionHasMinSizeCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    RET_CHECK(cc->Inputs().HasTag("ITERABLE"));
    RET_CHECK_EQ(1, cc->Inputs().NumEntries());

    RET_CHECK_EQ(1, cc->Outputs().NumEntries());

    RET_CHECK_GE(
        cc->Options<::mediapipe::CollectionHasMinSizeCalculatorOptions>()
            .min_size(),
        0);

    cc->Inputs().Tag("ITERABLE").Set<IterableT>();
    cc->Outputs().Index(0).Set<bool>();

    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) override {
    cc->SetOffset(TimestampDiff(0));
    min_size_ =
        cc->Options<::mediapipe::CollectionHasMinSizeCalculatorOptions>()
            .min_size();
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    const IterableT& input = cc->Inputs().Tag("ITERABLE").Get<IterableT>();
    bool has_min_size = input.size() >= min_size_;

    cc->Outputs().Index(0).AddPacket(
        MakePacket<bool>(has_min_size).At(cc->InputTimestamp()));

    return ::mediapipe::OkStatus();
  }

 private:
  int min_size_ = 0;
};

}  // namespace mediapipe

#endif  // MEDIAPIPE_CALCULATORS_UTIL_COLLECTION_HAS_MIN_SIZE_CALCULATOR_H_
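
A sketch of exercising the NormalizedRect specialization registered in the .cc above, following the CalculatorRunner idiom of the tests earlier in this change (stream names are illustrative):

// Sketch: gate on "at least 2 rects"; the calculator emits one bool packet
// per input packet.
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
  calculator: "NormalizedRectVectorHasMinSizeCalculator"
  input_stream: "ITERABLE:input_rects"
  output_stream: "has_enough_rects"
  options {
    [mediapipe.CollectionHasMinSizeCalculatorOptions.ext] {
      min_size: 2
    }
  }
)"));
// Two default-constructed rects: size 2 >= min_size 2.
auto rects = absl::make_unique<std::vector<::mediapipe::NormalizedRect>>(2);
runner.MutableInputs()->Tag("ITERABLE").packets.push_back(
    Adopt(rects.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run());
// The single output packet should hold `true`.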

29
mediapipe/calculators/util/collection_has_min_size_calculator.proto
Normal file

@ -0,0 +1,29 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";

message CollectionHasMinSizeCalculatorOptions {
  extend CalculatorOptions {
    optional CollectionHasMinSizeCalculatorOptions ext = 259397840;
  }

  // The minimum size an input iterable collection should have for the
  // calculator to output true.
  optional int32 min_size = 1 [default = 0];
}

@ -19,8 +19,8 @@
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/util/resource_util.h"
|
||||
|
||||
#if defined(MEDIAPIPE_LITE) || defined(__ANDROID__) || \
|
||||
(defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
#if defined(MEDIAPIPE_LITE) || defined(__EMSCRIPTEN__) || \
|
||||
defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
#include "mediapipe/util/android/file/base/file.h"
|
||||
#include "mediapipe/util/android/file/base/helpers.h"
|
||||
#else
|
||||
|
|
|

mediapipe/calculators/util/detections_to_rects_calculator.cc

@ -11,6 +11,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/util/detections_to_rects_calculator.h"

#include <cmath>

#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h"

@ -24,8 +26,6 @@

namespace mediapipe {

using mediapipe::DetectionsToRectsCalculatorOptions;

namespace {

constexpr char kDetectionTag[] = "DETECTION";

@ -36,7 +36,10 @@ constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kRectsTag[] = "RECTS";
constexpr char kNormRectsTag[] = "NORM_RECTS";

::mediapipe::Status DetectionToRect(const Detection& detection, Rect* rect) {
}  // namespace

::mediapipe::Status DetectionsToRectsCalculator::DetectionToRect(
    const Detection& detection, Rect* rect) {
  const LocationData location_data = detection.location_data();
  RET_CHECK(location_data.format() == LocationData::BOUNDING_BOX)
      << "Only Detection with formats of BOUNDING_BOX can be converted to Rect";

@ -48,8 +51,8 @@
  return ::mediapipe::OkStatus();
}

::mediapipe::Status DetectionToNormalizedRect(const Detection& detection,
                                              NormalizedRect* rect) {
::mediapipe::Status DetectionsToRectsCalculator::DetectionToNormalizedRect(
    const Detection& detection, NormalizedRect* rect) {
  const LocationData location_data = detection.location_data();
  RET_CHECK(location_data.format() == LocationData::RELATIVE_BOUNDING_BOX)
      << "Only Detection with formats of RELATIVE_BOUNDING_BOX can be "

@ -63,79 +66,6 @@
  return ::mediapipe::OkStatus();
}

// Wraps around an angle in radians to within -M_PI and M_PI.
inline float NormalizeRadians(float angle) {
  return angle - 2 * M_PI * std::floor((angle - (-M_PI)) / (2 * M_PI));
}

}  // namespace

// A calculator that converts Detection proto to Rect proto.
//
// Detection is the format for encoding one or more detections in an image.
// The input can be a single Detection or std::vector<Detection>. The output
// can be either a single Rect or NormalizedRect, or std::vector<Rect> or
// std::vector<NormalizedRect>. If Rect is used, the LocationData format is
// expected to be BOUNDING_BOX, and if NormalizedRect is used it is expected
// to be RELATIVE_BOUNDING_BOX.
//
// When the input is std::vector<Detection> and the output is a Rect or
// NormalizedRect, only the first detection is converted. When the input is a
// single Detection and the output is a std::vector<Rect> or
// std::vector<NormalizedRect>, the output is a vector of size 1.
//
// Inputs:
//
//   One of the following:
//   DETECTION: A Detection proto.
//   DETECTIONS: An std::vector<Detection>.
//
//   IMAGE_SIZE (optional): A std::pair<int, int> representing image width and
//   height. This is required only when rotation needs to be computed (see
//   calculator options).
//
// Output:
//   One of the following:
//   RECT: A Rect proto.
//   NORM_RECT: A NormalizedRect proto.
//   RECTS: An std::vector<Rect>.
//   NORM_RECTS: An std::vector<NormalizedRect>.
//
// Example config:
// node {
//   calculator: "DetectionsToRectsCalculator"
//   input_stream: "DETECTIONS:detections"
//   input_stream: "IMAGE_SIZE:image_size"
//   output_stream: "NORM_RECT:rect"
//   options: {
//     [mediapipe.DetectionsToRectCalculatorOptions.ext] {
//       rotation_vector_start_keypoint_index: 0
//       rotation_vector_end_keypoint_index: 2
//       rotation_vector_target_angle_degrees: 90
//       output_zero_rect_for_empty_detections: true
//     }
//   }
// }
class DetectionsToRectsCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc);

  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;

 private:
  float ComputeRotation(const Detection& detection,
                        const std::pair<int, int> image_size);

  DetectionsToRectsCalculatorOptions options_;
  int start_keypoint_index_;
  int end_keypoint_index_;
  float target_angle_;  // In radians.
  bool rotate_;
  bool output_zero_rect_for_empty_detections_;
};
REGISTER_CALCULATOR(DetectionsToRectsCalculator);

::mediapipe::Status DetectionsToRectsCalculator::GetContract(
    CalculatorContract* cc) {
  RET_CHECK(cc->Inputs().HasTag(kDetectionTag) ^

@ -232,6 +162,13 @@ REGISTER_CALCULATOR(DetectionsToRectsCalculator);
        .Tag(kNormRectTag)
        .AddPacket(MakePacket<NormalizedRect>().At(cc->InputTimestamp()));
    }
    if (cc->Outputs().HasTag(kNormRectsTag)) {
      auto rect_vector = absl::make_unique<std::vector<NormalizedRect>>();
      rect_vector->emplace_back(NormalizedRect());
      cc->Outputs()
          .Tag(kNormRectsTag)
          .Add(rect_vector.release(), cc->InputTimestamp());
    }
  }
  return ::mediapipe::OkStatus();
}

@ -312,4 +249,6 @@ float DetectionsToRectsCalculator::ComputeRotation(
  return NormalizeRadians(rotation);
}

REGISTER_CALCULATOR(DetectionsToRectsCalculator);

}  // namespace mediapipe

105
mediapipe/calculators/util/detections_to_rects_calculator.h
Normal file

@ -0,0 +1,105 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_UTIL_DETECTIONS_TO_RECTS_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_UTIL_DETECTIONS_TO_RECTS_CALCULATOR_H_

#include <cmath>

#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_options.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {

// A calculator that converts Detection proto to Rect proto.
//
// Detection is the format for encoding one or more detections in an image.
// The input can be a single Detection or std::vector<Detection>. The output
// can be either a single Rect or NormalizedRect, or std::vector<Rect> or
// std::vector<NormalizedRect>. If Rect is used, the LocationData format is
// expected to be BOUNDING_BOX, and if NormalizedRect is used it is expected
// to be RELATIVE_BOUNDING_BOX.
//
// When the input is std::vector<Detection> and the output is a Rect or
// NormalizedRect, only the first detection is converted. When the input is a
// single Detection and the output is a std::vector<Rect> or
// std::vector<NormalizedRect>, the output is a vector of size 1.
//
// Inputs:
//
//   One of the following:
//   DETECTION: A Detection proto.
//   DETECTIONS: An std::vector<Detection>.
//
//   IMAGE_SIZE (optional): A std::pair<int, int> representing image width and
//   height. This is required only when rotation needs to be computed (see
//   calculator options).
//
// Output:
//   One of the following:
//   RECT: A Rect proto.
//   NORM_RECT: A NormalizedRect proto.
//   RECTS: An std::vector<Rect>.
//   NORM_RECTS: An std::vector<NormalizedRect>.
//
// Example config:
// node {
//   calculator: "DetectionsToRectsCalculator"
//   input_stream: "DETECTIONS:detections"
//   input_stream: "IMAGE_SIZE:image_size"
//   output_stream: "NORM_RECT:rect"
//   options: {
//     [mediapipe.DetectionsToRectCalculatorOptions.ext] {
//       rotation_vector_start_keypoint_index: 0
//       rotation_vector_end_keypoint_index: 2
//       rotation_vector_target_angle_degrees: 90
//       output_zero_rect_for_empty_detections: true
//     }
//   }
// }
class DetectionsToRectsCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc);

  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;

 protected:
  virtual float ComputeRotation(const ::mediapipe::Detection& detection,
                                const std::pair<int, int> image_size);
  virtual ::mediapipe::Status DetectionToRect(
      const ::mediapipe::Detection& detection, ::mediapipe::Rect* rect);
  virtual ::mediapipe::Status DetectionToNormalizedRect(
      const ::mediapipe::Detection& detection,
      ::mediapipe::NormalizedRect* rect);

  // Wraps around an angle in radians to within -M_PI and M_PI.
  static inline float NormalizeRadians(float angle) {
    return angle - 2 * M_PI * std::floor((angle - (-M_PI)) / (2 * M_PI));
  }

  ::mediapipe::DetectionsToRectsCalculatorOptions options_;
  int start_keypoint_index_;
  int end_keypoint_index_;
  float target_angle_ = 0.0f;  // In radians.
  bool rotate_;
  bool output_zero_rect_for_empty_detections_;
};

}  // namespace mediapipe
#endif  // MEDIAPIPE_CALCULATORS_UTIL_DETECTIONS_TO_RECTS_CALCULATOR_H_
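
NormalizeRadians wraps an angle into [-pi, pi): for example 3*pi/2 gives floor((3*pi/2 + pi) / 2*pi) = 1, so the result is 3*pi/2 - 2*pi = -pi/2. A quick self-contained check of the same formula:

#include <cassert>
#include <cmath>

// Same formula as DetectionsToRectsCalculator::NormalizeRadians above.
inline float NormalizeRadians(float angle) {
  return angle - 2 * M_PI * std::floor((angle - (-M_PI)) / (2 * M_PI));
}

int main() {
  // 3*pi/2 wraps to -pi/2; -3*pi wraps to -pi; pi itself maps to -pi,
  // so the output range is the half-open interval [-pi, pi).
  assert(std::fabs(NormalizeRadians(3 * M_PI / 2) - (-M_PI / 2)) < 1e-6);
  assert(std::fabs(NormalizeRadians(-3 * M_PI) - (-M_PI)) < 1e-6);
  assert(std::fabs(NormalizeRadians(M_PI) - (-M_PI)) < 1e-6);
  return 0;
}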

34
mediapipe/calculators/util/filter_collection_calculator.cc
Normal file

@ -0,0 +1,34 @@

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/calculators/util/filter_collection_calculator.h"

#include <vector>

#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"

namespace mediapipe {

typedef FilterCollectionCalculator<std::vector<::mediapipe::NormalizedRect>>
    FilterNormalizedRectCollectionCalculator;
REGISTER_CALCULATOR(FilterNormalizedRectCollectionCalculator);

typedef FilterCollectionCalculator<
    std::vector<std::vector<::mediapipe::NormalizedLandmark>>>
    FilterLandmarksCollectionCalculator;
REGISTER_CALCULATOR(FilterLandmarksCollectionCalculator);

}  // namespace mediapipe

109
mediapipe/calculators/util/filter_collection_calculator.h
Normal file

@ -0,0 +1,109 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MEDIAPIPE_CALCULATORS_UTIL_FILTER_VECTOR_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_UTIL_FILTER_VECTOR_CALCULATOR_H_

#include <vector>

#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {

// A calculator that gates elements of an input collection based on
// corresponding boolean values of the "CONDITION" vector. If there is no input
// collection or "CONDITION" vector, the calculator forwards timestamp bounds
// for downstream calculators. If the "CONDITION" vector has false values for
// all elements of the input collection, the calculator outputs a packet
// containing an empty collection.
// Example usage:
// node {
//   calculator: "FilterCollectionCalculator"
//   input_stream: "ITERABLE:input_collection"
//   input_stream: "CONDITION:condition_vector"
//   output_stream: "ITERABLE:output_collection"
// }
// This calculator is able to handle collections of copyable types T.
template <typename IterableT>
class FilterCollectionCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    RET_CHECK(cc->Inputs().HasTag("ITERABLE"));
    RET_CHECK(cc->Inputs().HasTag("CONDITION"));
    RET_CHECK(cc->Outputs().HasTag("ITERABLE"));

    cc->Inputs().Tag("ITERABLE").Set<IterableT>();
    cc->Inputs().Tag("CONDITION").Set<std::vector<bool>>();

    cc->Outputs().Tag("ITERABLE").Set<IterableT>();

    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) override {
    cc->SetOffset(TimestampDiff(0));
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    if (cc->Inputs().Tag("ITERABLE").IsEmpty()) {
      return ::mediapipe::OkStatus();
    }
    if (cc->Inputs().Tag("CONDITION").IsEmpty()) {
      return ::mediapipe::OkStatus();
    }

    const std::vector<bool>& filter_by =
        cc->Inputs().Tag("CONDITION").Get<std::vector<bool>>();

    return FilterCollection<IterableT>(
        std::is_copy_constructible<typename IterableT::value_type>(), cc,
        filter_by);
  }

  template <typename IterableU>
  ::mediapipe::Status FilterCollection(std::true_type, CalculatorContext* cc,
                                       const std::vector<bool>& filter_by) {
    const IterableU& input = cc->Inputs().Tag("ITERABLE").Get<IterableU>();
    if (input.size() != filter_by.size()) {
      return ::mediapipe::InternalError(absl::StrCat(
          "Input vector size: ", input.size(),
          " doesn't match condition vector size: ", filter_by.size()));
    }

    auto output = absl::make_unique<IterableU>();
    for (int i = 0; i < input.size(); ++i) {
      if (filter_by[i]) {
        output->push_back(input[i]);
      }
    }
    cc->Outputs().Tag("ITERABLE").Add(output.release(), cc->InputTimestamp());
    return ::mediapipe::OkStatus();
  }

  template <typename IterableU>
  ::mediapipe::Status FilterCollection(std::false_type, CalculatorContext* cc,
                                       const std::vector<bool>& filter_by) {
    return ::mediapipe::InternalError(
        "Cannot copy input collection to filter it.");
  }
};

}  // namespace mediapipe

#endif  // MEDIAPIPE_CALCULATORS_UTIL_FILTER_VECTOR_CALCULATOR_H_
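
The two FilterCollection overloads above use tag dispatch on std::is_copy_constructible: the std::true_type overload is selected at compile time for copyable element types, while the std::false_type overload turns non-copyable element types into a runtime error instead of a compile error. The idiom in isolation (FilterMode and its overloads are illustrative names, not MediaPipe APIs):

#include <type_traits>

template <typename T>
const char* FilterModeImpl(std::true_type /*copyable*/) {
  return "filter by copying the selected elements";
}

template <typename T>
const char* FilterModeImpl(std::false_type /*non-copyable*/) {
  return "fail at runtime: elements cannot be copied";
}

template <typename T>
const char* FilterMode() {
  // std::is_copy_constructible<T>() is an object of type std::true_type or
  // std::false_type, so overload resolution picks the matching branch at
  // compile time; the other overload is never instantiated for this T.
  return FilterModeImpl<T>(std::is_copy_constructible<T>());
}

// Example: FilterMode<int>() picks the copy branch, while
// FilterMode<std::unique_ptr<int>>() picks the error branch.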

181
mediapipe/calculators/util/labels_to_render_data_calculator.cc
Normal file

@ -0,0 +1,181 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <math.h>

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/util/labels_to_render_data_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/video_stream_header.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"

namespace mediapipe {

constexpr float kFontHeightScale = 1.25f;

// A calculator that takes in labels with optional scores, or classifications,
// and generates render data. Either both "LABELS" and "SCORES" or
// "CLASSIFICATIONS" must be present.
//
// Usage example:
// node {
//   calculator: "LabelsToRenderDataCalculator"
//   input_stream: "LABELS:labels"
//   input_stream: "SCORES:scores"
//   input_stream: "VIDEO_PRESTREAM:video_header"
//   output_stream: "RENDER_DATA:render_data"
//   options {
//     [LabelsToRenderDataCalculatorOptions.ext] {
//       color { r: 255 g: 0 b: 0 }
//       color { r: 0 g: 255 b: 0 }
//       color { r: 0 g: 0 b: 255 }
//       thickness: 2.0
//       font_height_px: 20
//       max_num_labels: 3
//       font_face: 1
//       location: TOP_LEFT
//     }
//   }
// }
class LabelsToRenderDataCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc);
  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;

 private:
  LabelsToRenderDataCalculatorOptions options_;
  int num_colors_ = 0;
  int video_width_ = 0;
  int video_height_ = 0;
  int label_height_px_ = 0;
  int label_left_px_ = 0;
};
REGISTER_CALCULATOR(LabelsToRenderDataCalculator);

::mediapipe::Status LabelsToRenderDataCalculator::GetContract(
    CalculatorContract* cc) {
  if (cc->Inputs().HasTag("CLASSIFICATIONS")) {
    cc->Inputs().Tag("CLASSIFICATIONS").Set<ClassificationList>();
  } else {
    RET_CHECK(cc->Inputs().HasTag("LABELS"))
        << "Must provide input stream \"LABELS\"";
    cc->Inputs().Tag("LABELS").Set<std::vector<std::string>>();
    if (cc->Inputs().HasTag("SCORES")) {
      cc->Inputs().Tag("SCORES").Set<std::vector<float>>();
    }
  }
  if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) {
    cc->Inputs().Tag("VIDEO_PRESTREAM").Set<VideoHeader>();
  }
  cc->Outputs().Tag("RENDER_DATA").Set<RenderData>();
  return ::mediapipe::OkStatus();
}

::mediapipe::Status LabelsToRenderDataCalculator::Open(CalculatorContext* cc) {
  options_ = cc->Options<LabelsToRenderDataCalculatorOptions>();
  num_colors_ = options_.color_size();
  label_height_px_ = std::ceil(options_.font_height_px() * kFontHeightScale);
  return ::mediapipe::OkStatus();
}

::mediapipe::Status LabelsToRenderDataCalculator::Process(
    CalculatorContext* cc) {
  if (cc->Inputs().HasTag("VIDEO_PRESTREAM") &&
      cc->InputTimestamp() == Timestamp::PreStream()) {
    const VideoHeader& video_header =
        cc->Inputs().Tag("VIDEO_PRESTREAM").Get<VideoHeader>();
    video_width_ = video_header.width;
    video_height_ = video_header.height;
    return ::mediapipe::OkStatus();
  } else {
    CHECK_EQ(options_.location(), LabelsToRenderDataCalculatorOptions::TOP_LEFT)
        << "Only TOP_LEFT is supported without VIDEO_PRESTREAM.";
  }

  std::vector<std::string> labels;
  std::vector<float> scores;
  if (cc->Inputs().HasTag("CLASSIFICATIONS")) {
    const ClassificationList& classifications =
        cc->Inputs().Tag("CLASSIFICATIONS").Get<ClassificationList>();
    labels.resize(classifications.classification_size());
    scores.resize(classifications.classification_size());
    for (int i = 0; i < classifications.classification_size(); ++i) {
      labels[i] = classifications.classification(i).label();
      scores[i] = classifications.classification(i).score();
    }
  } else {
    const std::vector<std::string>& label_vector =
        cc->Inputs().Tag("LABELS").Get<std::vector<std::string>>();
    std::vector<float> score_vector;
    if (cc->Inputs().HasTag("SCORES")) {
      score_vector = cc->Inputs().Tag("SCORES").Get<std::vector<float>>();
    }
    CHECK_EQ(label_vector.size(), score_vector.size());
    labels.resize(label_vector.size());
    scores.resize(label_vector.size());
    for (int i = 0; i < label_vector.size(); ++i) {
      labels[i] = label_vector[i];
      scores[i] = score_vector[i];
    }
  }

  RenderData render_data;
  int num_label = std::min((int)labels.size(), options_.max_num_labels());
  int label_baseline_px = options_.vertical_offset_px();
  if (options_.location() == LabelsToRenderDataCalculatorOptions::TOP_LEFT) {
    label_baseline_px += label_height_px_;
  } else if (options_.location() ==
             LabelsToRenderDataCalculatorOptions::BOTTOM_LEFT) {
    label_baseline_px += video_height_ - label_height_px_ * (num_label - 1);
  }
  label_left_px_ = options_.horizontal_offset_px();
  for (int i = 0; i < num_label; ++i) {
    auto* label_annotation = render_data.add_render_annotations();
    label_annotation->set_thickness(options_.thickness());
    if (num_colors_ > 0) {
      *(label_annotation->mutable_color()) = options_.color(i % num_colors_);
    } else {
      label_annotation->mutable_color()->set_r(255);
      label_annotation->mutable_color()->set_g(0);
      label_annotation->mutable_color()->set_b(0);
    }

    auto* text = label_annotation->mutable_text();
    std::string display_text = labels[i];
    if (cc->Inputs().HasTag("SCORES")) {
      absl::StrAppend(&display_text, ":", scores[i]);
    }
    text->set_display_text(display_text);
    text->set_font_height(options_.font_height_px());
    text->set_left(label_left_px_);
    text->set_baseline(label_baseline_px + i * label_height_px_);
    text->set_font_face(options_.font_face());
  }
  cc->Outputs()
      .Tag("RENDER_DATA")
      .AddPacket(MakePacket<RenderData>(render_data).At(cc->InputTimestamp()));

  return ::mediapipe::OkStatus();
}
}  // namespace mediapipe
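
Besides the LABELS/SCORES pair shown in the class comment, GetContract also accepts a ClassificationList on a "CLASSIFICATIONS" stream. A sketch of such a node config, in the same idiom as the tests above (stream names are illustrative):

// Sketch: feed classifications instead of separate labels and scores.
CalculatorGraphConfig::Node node =
    ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
      calculator: "LabelsToRenderDataCalculator"
      input_stream: "CLASSIFICATIONS:classifications"
      output_stream: "RENDER_DATA:render_data"
      options {
        [mediapipe.LabelsToRenderDataCalculatorOptions.ext] {
          max_num_labels: 3
          font_height_px: 20
          location: TOP_LEFT
        }
      }
    )");
// Without a VIDEO_PRESTREAM input, only TOP_LEFT placement is supported,
// per the CHECK_EQ in Process() above.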

62
mediapipe/calculators/util/labels_to_render_data_calculator.proto
Normal file

@ -0,0 +1,62 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";
import "mediapipe/util/color.proto";

message LabelsToRenderDataCalculatorOptions {
  extend CalculatorOptions {
    optional LabelsToRenderDataCalculatorOptions ext = 271660364;
  }

  // Colors for drawing the label(s).
  repeated Color color = 1;

  // Thickness for drawing the label(s).
  optional double thickness = 2 [default = 2];

  // The font height in absolute pixels.
  optional int32 font_height_px = 3 [default = 50];

  // The offset of the starting text in the horizontal direction in absolute
  // pixels.
  optional int32 horizontal_offset_px = 7 [default = 0];
  // The offset of the starting text in the vertical direction in absolute
  // pixels.
  optional int32 vertical_offset_px = 8 [default = 0];

  // The maximum number of labels to display.
  optional int32 max_num_labels = 4 [default = 1];

  // Specifies the font for the text. Font must be one of the following from
  // OpenCV:
  // cv::FONT_HERSHEY_SIMPLEX (0)
  // cv::FONT_HERSHEY_PLAIN (1)
  // cv::FONT_HERSHEY_DUPLEX (2)
  // cv::FONT_HERSHEY_COMPLEX (3)
  // cv::FONT_HERSHEY_TRIPLEX (4)
  // cv::FONT_HERSHEY_COMPLEX_SMALL (5)
  // cv::FONT_HERSHEY_SCRIPT_SIMPLEX (6)
  // cv::FONT_HERSHEY_SCRIPT_COMPLEX (7)
  optional int32 font_face = 5 [default = 0];

  // Label location.
  enum Location {
    TOP_LEFT = 0;
    BOTTOM_LEFT = 1;
  }
  optional Location location = 6 [default = TOP_LEFT];
}

138
mediapipe/calculators/util/landmarks_to_floats_calculator.cc
Normal file

@ -0,0 +1,138 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cmath>
#include <vector>

#include "Eigen/Core"
#include "mediapipe/calculators/util/landmarks_to_floats_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/ret_check.h"

namespace mediapipe {

namespace {

constexpr char kLandmarksTag[] = "NORM_LANDMARKS";
constexpr char kFloatsTag[] = "FLOATS";
constexpr char kMatrixTag[] = "MATRIX";

}  // namespace

// Converts a vector of landmarks to a vector of floats or a matrix.
// Input:
//   NORM_LANDMARKS: An std::vector<NormalizedLandmark>.
//
// Output:
//   FLOATS(optional): A vector of floats from flattened landmarks.
//   MATRIX(optional): A matrix of floats of the landmarks.
//
// Usage example:
// node {
//   calculator: "LandmarksToFloatsCalculator"
//   input_stream: "NORM_LANDMARKS:landmarks"
//   output_stream: "MATRIX:landmark_matrix"
// }
class LandmarksToFloatsCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Tag(kLandmarksTag).Set<std::vector<NormalizedLandmark>>();
    RET_CHECK(cc->Outputs().HasTag(kFloatsTag) ||
              cc->Outputs().HasTag(kMatrixTag));
    if (cc->Outputs().HasTag(kFloatsTag)) {
      cc->Outputs().Tag(kFloatsTag).Set<std::vector<float>>();
    }
    if (cc->Outputs().HasTag(kMatrixTag)) {
      cc->Outputs().Tag(kMatrixTag).Set<Matrix>();
    }

    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) override {
    cc->SetOffset(TimestampDiff(0));
    const auto& options =
        cc->Options<::mediapipe::LandmarksToFloatsCalculatorOptions>();
    num_dimensions_ = options.num_dimensions();
    // Currently the number of dimensions must be within [1, 3].
    RET_CHECK_GE(num_dimensions_, 1);
    RET_CHECK_LE(num_dimensions_, 3);
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    // Only process if there are input landmarks.
    if (cc->Inputs().Tag(kLandmarksTag).IsEmpty()) {
      return ::mediapipe::OkStatus();
    }

    const auto& input_landmarks =
        cc->Inputs().Tag(kLandmarksTag).Get<std::vector<NormalizedLandmark>>();

    if (cc->Outputs().HasTag(kFloatsTag)) {
      auto output_floats = absl::make_unique<std::vector<float>>();
      for (const auto& landmark : input_landmarks) {
        output_floats->emplace_back(landmark.x());
        if (num_dimensions_ > 1) {
          output_floats->emplace_back(landmark.y());
        }
        if (num_dimensions_ > 2) {
          output_floats->emplace_back(landmark.z());
        }
      }

      cc->Outputs()
          .Tag(kFloatsTag)
          .Add(output_floats.release(), cc->InputTimestamp());
    } else {
      auto output_matrix = absl::make_unique<Matrix>();
      output_matrix->setZero(num_dimensions_, input_landmarks.size());
      for (int i = 0; i < input_landmarks.size(); ++i) {
        (*output_matrix)(0, i) = input_landmarks[i].x();
        if (num_dimensions_ > 1) {
          (*output_matrix)(1, i) = input_landmarks[i].y();
        }
        if (num_dimensions_ > 2) {
          (*output_matrix)(2, i) = input_landmarks[i].z();
        }
      }
      cc->Outputs()
          .Tag(kMatrixTag)
          .Add(output_matrix.release(), cc->InputTimestamp());
    }
    return ::mediapipe::OkStatus();
  }

 private:
  int num_dimensions_ = 0;
};
REGISTER_CALCULATOR(LandmarksToFloatsCalculator);

}  // namespace mediapipe
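
For the same input, the two output forms differ only in layout. A worked layout sketch for three landmarks with num_dimensions = 2 (the values are made up):

// Landmarks: (0.1, 0.2), (0.3, 0.4), (0.5, 0.6)
//
// FLOATS (flattened, landmark-major, per the emplace_back loop above):
//   {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}
//
// MATRIX (num_dimensions x num_landmarks, one column per landmark,
// per the (*output_matrix)(dim, i) assignments above):
//   [ 0.1  0.3  0.5 ]
//   [ 0.2  0.4  0.6 ]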

28
mediapipe/calculators/util/landmarks_to_floats_calculator.proto
Normal file

@ -0,0 +1,28 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";

message LandmarksToFloatsCalculatorOptions {
  extend CalculatorOptions {
    optional LandmarksToFloatsCalculatorOptions ext = 274035660;
  }

  // Number of dimensions to convert. Must be within [1, 3].
  optional int32 num_dimensions = 1 [default = 2];
}

57
mediapipe/calculators/util/local_file_contents_calculator.cc
Normal file

@ -0,0 +1,57 @@
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/file_helpers.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
// The calculator takes the path to the local file as an input side packet and
|
||||
// outputs the contents of that file.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "LocalFileContentsCalculator"
|
||||
// input_side_packet: "FILE_PATH:file_path"
|
||||
// output_side_packet: "CONTENTS:contents"
|
||||
// }
|
||||
class LocalFileContentsCalculator : public CalculatorBase {
|
||||
public:
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
cc->InputSidePackets().Tag("FILE_PATH").Set<std::string>();
|
||||
cc->OutputSidePackets().Tag("CONTENTS").Set<std::string>();
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Open(CalculatorContext* cc) override {
|
||||
std::string contents;
|
||||
MP_RETURN_IF_ERROR(mediapipe::file::GetContents(
|
||||
cc->InputSidePackets().Tag("FILE_PATH").Get<std::string>(), &contents));
|
||||
cc->OutputSidePackets()
|
||||
.Tag("CONTENTS")
|
||||
.Set(MakePacket<std::string>(std::move(contents)));
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
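// The file contents are emitted once as an output side packet in Open(), so
// Process() has nothing left to do.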
::mediapipe::Status Process(CalculatorContext* cc) override {
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(LocalFileContentsCalculator);
|
||||
|
||||
} // namespace mediapipe
|
|
@ -23,7 +23,9 @@ namespace mediapipe {
|
|||
namespace {
|
||||
|
||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||
constexpr char kNormRectsTag[] = "NORM_RECTS";
|
||||
constexpr char kRectTag[] = "RECT";
|
||||
constexpr char kRectsTag[] = "RECTS";
|
||||
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
|
||||
|
||||
// Wraps around an angle in radians to within -M_PI and M_PI.
|
||||
|
@ -72,17 +74,31 @@ REGISTER_CALCULATOR(RectTransformationCalculator);
|
|||
|
||||
::mediapipe::Status RectTransformationCalculator::GetContract(
|
||||
CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().HasTag(kNormRectTag) ^ cc->Inputs().HasTag(kRectTag));
|
||||
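// Exactly one of the rect input tags must be present.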
RET_CHECK_EQ((cc->Inputs().HasTag(kNormRectTag) ? 1 : 0) +
|
||||
(cc->Inputs().HasTag(kNormRectsTag) ? 1 : 0) +
|
||||
(cc->Inputs().HasTag(kRectTag) ? 1 : 0) +
|
||||
(cc->Inputs().HasTag(kRectsTag) ? 1 : 0),
|
||||
1);
|
||||
if (cc->Inputs().HasTag(kRectTag)) {
|
||||
cc->Inputs().Tag(kRectTag).Set<Rect>();
|
||||
cc->Outputs().Index(0).Set<Rect>();
|
||||
}
|
||||
if (cc->Inputs().HasTag(kRectsTag)) {
|
||||
cc->Inputs().Tag(kRectsTag).Set<std::vector<Rect>>();
|
||||
cc->Outputs().Index(0).Set<std::vector<Rect>>();
|
||||
}
|
||||
if (cc->Inputs().HasTag(kNormRectTag)) {
|
||||
RET_CHECK(cc->Inputs().HasTag(kImageSizeTag));
|
||||
cc->Inputs().Tag(kNormRectTag).Set<NormalizedRect>();
|
||||
cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>();
|
||||
cc->Outputs().Index(0).Set<NormalizedRect>();
|
||||
}
|
||||
if (cc->Inputs().HasTag(kNormRectsTag)) {
|
||||
RET_CHECK(cc->Inputs().HasTag(kImageSizeTag));
|
||||
cc->Inputs().Tag(kNormRectsTag).Set<std::vector<NormalizedRect>>();
|
||||
cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>();
|
||||
cc->Outputs().Index(0).Set<std::vector<NormalizedRect>>();
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
@ -105,7 +121,17 @@ REGISTER_CALCULATOR(RectTransformationCalculator);
|
|||
cc->Outputs().Index(0).AddPacket(
|
||||
MakePacket<Rect>(rect).At(cc->InputTimestamp()));
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag(kRectsTag) &&
|
||||
!cc->Inputs().Tag(kRectsTag).IsEmpty()) {
|
||||
auto rects = cc->Inputs().Tag(kRectsTag).Get<std::vector<Rect>>();
|
||||
auto output_rects = absl::make_unique<std::vector<Rect>>(rects.size());
|
||||
for (int i = 0; i < rects.size(); ++i) {
|
||||
output_rects->at(i) = rects[i];
|
||||
auto it = output_rects->begin() + i;
|
||||
TransformRect(&(*it));
|
||||
}
|
||||
cc->Outputs().Index(0).Add(output_rects.release(), cc->InputTimestamp());
|
||||
}
|
||||
if (cc->Inputs().HasTag(kNormRectTag) &&
|
||||
!cc->Inputs().Tag(kNormRectTag).IsEmpty()) {
|
||||
auto rect = cc->Inputs().Tag(kNormRectTag).Get<NormalizedRect>();
|
||||
|
@ -115,6 +141,21 @@ REGISTER_CALCULATOR(RectTransformationCalculator);
|
|||
cc->Outputs().Index(0).AddPacket(
|
||||
MakePacket<NormalizedRect>(rect).At(cc->InputTimestamp()));
|
||||
}
|
||||
if (cc->Inputs().HasTag(kNormRectsTag) &&
|
||||
!cc->Inputs().Tag(kNormRectsTag).IsEmpty()) {
|
||||
auto rects =
|
||||
cc->Inputs().Tag(kNormRectsTag).Get<std::vector<NormalizedRect>>();
|
||||
const auto& image_size =
|
||||
cc->Inputs().Tag(kImageSizeTag).Get<std::pair<int, int>>();
|
||||
auto output_rects =
|
||||
absl::make_unique<std::vector<NormalizedRect>>(rects.size());
|
||||
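// Copy each rect, then transform it in place using the image dimensions.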
for (int i = 0; i < rects.size(); ++i) {
|
||||
output_rects->at(i) = rects[i];
|
||||
auto it = output_rects->begin() + i;
|
||||
TransformNormalizedRect(&(*it), image_size.first, image_size.second);
|
||||
}
|
||||
cc->Outputs().Index(0).Add(output_rects.release(), cc->InputTimestamp());
|
||||
}
|
||||
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
|
|
@ -23,13 +23,14 @@
|
|||
|
||||
#include "mediapipe/calculators/util/top_k_scores_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/statusor.h"
|
||||
#include "mediapipe/util/resource_util.h"
|
||||
|
||||
#if defined(MEDIAPIPE_LITE) || defined(__ANDROID__) || \
|
||||
(defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
#if defined(MEDIAPIPE_LITE) || defined(__EMSCRIPTEN__) || \
|
||||
defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
|
||||
#include "mediapipe/util/android/file/base/file.h"
|
||||
#include "mediapipe/util/android/file/base/helpers.h"
|
||||
#else
|
||||
|
@ -37,8 +38,10 @@
|
|||
#endif
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// A calculator that takes a vector of scores and returns the indexes, scores,
|
||||
// labels of the top k elements.
|
||||
// labels of the top k elements, classification protos, and summary std::string
|
||||
// (in csv format).
|
||||
//
|
||||
// Usage example:
|
||||
// node {
|
||||
|
@ -47,6 +50,8 @@ namespace mediapipe {
|
|||
// output_stream: "TOP_K_INDEXES:top_k_indexes"
|
||||
// output_stream: "TOP_K_SCORES:top_k_scores"
|
||||
// output_stream: "TOP_K_LABELS:top_k_labels"
|
||||
// output_stream: "TOP_K_CLASSIFICATIONS:top_k_classes"
|
||||
// output_stream: "SUMMARY:summary"
|
||||
// options: {
|
||||
// [mediapipe.TopKScoresCalculatorOptions.ext] {
|
||||
// top_k: 5
|
||||
|
@ -69,6 +74,7 @@ class TopKScoresCalculator : public CalculatorBase {
|
|||
int top_k_ = -1;
|
||||
float threshold_ = 0.0;
|
||||
std::unordered_map<int, std::string> label_map_;
|
||||
bool label_map_loaded_ = false;
|
||||
};
|
||||
REGISTER_CALCULATOR(TopKScoresCalculator);
|
||||
|
||||
|
@ -84,6 +90,12 @@ REGISTER_CALCULATOR(TopKScoresCalculator);
|
|||
if (cc->Outputs().HasTag("TOP_K_LABELS")) {
|
||||
cc->Outputs().Tag("TOP_K_LABELS").Set<std::vector<std::string>>();
|
||||
}
|
||||
if (cc->Outputs().HasTag("CLASSIFICATIONS")) {
|
||||
cc->Outputs().Tag("CLASSIFICATIONS").Set<ClassificationList>();
|
||||
}
|
||||
if (cc->Outputs().HasTag("SUMMARY")) {
|
||||
cc->Outputs().Tag("SUMMARY").Set<std::string>();
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -149,7 +161,7 @@ REGISTER_CALCULATOR(TopKScoresCalculator);
|
|||
reverse(top_k_indexes.begin(), top_k_indexes.end());
|
||||
reverse(top_k_scores.begin(), top_k_scores.end());
|
||||
|
||||
if (cc->Outputs().HasTag("TOP_K_LABELS")) {
|
||||
if (label_map_loaded_) {
|
||||
for (int index : top_k_indexes) {
|
||||
top_k_labels.push_back(label_map_[index]);
|
||||
}
|
||||
|
@ -172,6 +184,35 @@ REGISTER_CALCULATOR(TopKScoresCalculator);
|
|||
.AddPacket(MakePacket<std::vector<std::string>>(top_k_labels)
|
||||
.At(cc->InputTimestamp()));
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("SUMMARY")) {
|
||||
std::vector<std::string> results;
|
||||
for (int index = 0; index < top_k_indexes.size(); ++index) {
|
||||
if (label_map_loaded_) {
|
||||
results.push_back(
|
||||
absl::StrCat(top_k_labels[index], ":", top_k_scores[index]));
|
||||
} else {
|
||||
results.push_back(
|
||||
absl::StrCat(top_k_indexes[index], ":", top_k_scores[index]));
|
||||
}
|
||||
}
|
||||
cc->Outputs().Tag("SUMMARY").AddPacket(
|
||||
MakePacket<std::string>(absl::StrJoin(results, ","))
|
||||
.At(cc->InputTimestamp()));
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("TOP_K_CLASSIFICATION")) {
|
||||
auto classification_list = absl::make_unique<ClassificationList>();
|
||||
for (int index = 0; index < top_k_indexes.size(); ++index) {
|
||||
Classification* classification =
|
||||
classification_list->add_classification();
|
||||
classification->set_index(top_k_indexes[index]);
|
||||
classification->set_score(top_k_scores[index]);
|
||||
if (label_map_loaded_) {
|
||||
classification->set_label(top_k_labels[index]);
|
||||
}
|
||||
}
|
||||
// Output the assembled classification list.
cc->Outputs()
.Tag("CLASSIFICATIONS")
.Add(classification_list.release(), cc->InputTimestamp());
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -188,6 +229,7 @@ REGISTER_CALCULATOR(TopKScoresCalculator);
|
|||
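// Each line i of the label file provides the label for class index i.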
while (std::getline(stream, line)) {
|
||||
label_map_[i++] = line;
|
||||
}
|
||||
label_map_loaded_ = true;
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
@ -13,12 +13,16 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
|
||||
load(
|
||||
"//mediapipe/framework/tool:mediapipe_graph.bzl",
|
||||
"mediapipe_binary_graph",
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(default_visibility = ["//visibility:private"])
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
|
||||
|
||||
proto_library(
|
||||
name = "flow_to_image_calculator_proto",
|
||||
srcs = ["flow_to_image_calculator.proto"],
|
||||
|
@ -52,9 +56,7 @@ mediapipe_cc_proto_library(
|
|||
cc_library(
|
||||
name = "flow_to_image_calculator",
|
||||
srcs = ["flow_to_image_calculator.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/video:flow_to_image_calculator_cc_proto",
|
||||
"//mediapipe/calculators/video/tool:flow_quantizer_model",
|
||||
|
@ -129,10 +131,20 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "test_videos",
|
||||
srcs = [
|
||||
"testdata/format_FLV_H264_AAC.video",
|
||||
"testdata/format_MKV_VP8_VORBIS.video",
|
||||
"testdata/format_MP4_AVC720P_AAC.video",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "opencv_video_decoder_calculator_test",
|
||||
srcs = ["opencv_video_decoder_calculator_test.cc"],
|
||||
data = ["//mediapipe/calculators/video/testdata:test_videos"],
|
||||
data = [":test_videos"],
|
||||
deps = [
|
||||
":opencv_video_decoder_calculator",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
|
@ -151,7 +163,7 @@ cc_test(
|
|||
cc_test(
|
||||
name = "opencv_video_encoder_calculator_test",
|
||||
srcs = ["opencv_video_encoder_calculator_test.cc"],
|
||||
data = ["//mediapipe/calculators/video/testdata:test_videos"],
|
||||
data = [":test_videos"],
|
||||
deps = [
|
||||
":opencv_video_decoder_calculator",
|
||||
":opencv_video_encoder_calculator",
|
||||
|
@ -175,7 +187,6 @@ cc_test(
|
|||
cc_test(
|
||||
name = "tvl1_optical_flow_calculator_test",
|
||||
srcs = ["tvl1_optical_flow_calculator_test.cc"],
|
||||
data = ["//mediapipe/calculators/image/testdata:test_images"],
|
||||
deps = [
|
||||
":tvl1_optical_flow_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
|
|
@ -123,6 +123,7 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
|
|||
cc->Outputs()
|
||||
.Tag("VIDEO_PRESTREAM")
|
||||
.Add(header.release(), Timestamp::PreStream());
|
||||
cc->Outputs().Tag("VIDEO_PRESTREAM").Close();
|
||||
}
|
||||
// Rewind to the very first frame.
|
||||
cap_->set(cv::CAP_PROP_POS_AVI_RATIO, 0);
|
||||
|
|
|
@ -13,24 +13,24 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(default_visibility = ["//mediapipe/calculators/video:__subpackages__"])
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
|
||||
|
||||
proto_library(
|
||||
name = "flow_quantizer_model_proto",
|
||||
srcs = ["flow_quantizer_model.proto"],
|
||||
visibility = ["//mediapipe:__subpackages__"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "flow_quantizer_model_cc_proto",
|
||||
srcs = ["flow_quantizer_model.proto"],
|
||||
visibility = ["//mediapipe:__subpackages__"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":flow_quantizer_model_proto"],
|
||||
)
|
||||
|
||||
|
|
131
mediapipe/docs/android_archive_library.md
Normal file
|
@ -0,0 +1,131 @@
|
|||
## MediaPipe Android Archive Library
|
||||
|
||||
***Experimental Only***
|
||||
|
||||
The MediaPipe Android archive library is a convenient way to use MediaPipe with
|
||||
Android Studio and Gradle. MediaPipe doesn't publish a general AAR that can be
|
||||
used by all projects. Instead, developers need to add a mediapipe_aar() target
|
||||
to generate a custom AAR file for their own projects. This is necessary in
order to include the specific resources, such as MediaPipe calculators, that
each project needs.
|
||||
|
||||
### Steps to build a MediaPipe AAR
|
||||
|
||||
1. Create a mediapipe_aar() target.
|
||||
|
||||
In the MediaPipe directory, create a new mediapipe_aar() target in a BUILD
|
||||
file. You need to figure out what calculators are used in the graph and
|
||||
provide the calculator dependencies to the mediapipe_aar(). For example, to
|
||||
build an AAR for [face detection gpu](./face_detection_mobile_gpu.md), you
|
||||
can put the following code into
|
||||
mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/BUILD.
|
||||
|
||||
```
|
||||
load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_aar")
|
||||
|
||||
mediapipe_aar(
|
||||
name = "mp_face_detection_aar",
|
||||
calculators = ["//mediapipe/graphs/face_detection:mobile_calculators"],
|
||||
)
|
||||
```
|
||||
|
||||
2. Run the Bazel build command to generate the AAR.
|
||||
|
||||
```bash
|
||||
bazel build -c opt --fat_apk_cpu=arm64-v8a,armeabi-v7a //path/to/the/aar/build/file:aar_name
|
||||
```
|
||||
|
||||
For the face detection AAR target we made in step 1, run:
|
||||
|
||||
```bash
|
||||
bazel build -c opt --fat_apk_cpu=arm64-v8a,armeabi-v7a \
|
||||
//mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mp_face_detection_aar
|
||||
|
||||
# It should print:
|
||||
# Target //mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mp_face_detection_aar up-to-date:
|
||||
# bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar
|
||||
```
|
||||
|
||||
3. (Optional) Save the AAR to your preferred location.
|
||||
|
||||
```bash
|
||||
cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar
|
||||
/absolute/path/to/your/preferred/location
|
||||
```
|
||||
|
||||
### Steps to use a MediaPipe AAR in Android Studio with Gradle
|
||||
|
||||
1. Start Android Studio and go to your project.
|
||||
|
||||
2. Copy the AAR into app/libs.
|
||||
|
||||
```bash
|
||||
cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar
|
||||
/path/to/your/app/libs/
|
||||
```
|
||||
|
||||

|
||||
|
||||
3. Make app/src/main/assets and copy assets (graph, model, etc.) into
|
||||
app/src/main/assets.
|
||||
|
||||
Build the MediaPipe binary graph and copy the assets into
|
||||
app/src/main/assets, e.g., for the face detection graph, you need to build
|
||||
and copy
|
||||
[the binary graph](https://github.com/google/mediapipe/blob/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/BUILD#L41),
|
||||
[the tflite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/face_detection_front.tflite),
|
||||
and
|
||||
[the label map](https://github.com/google/mediapipe/blob/master/mediapipe/models/face_detection_front_labelmap.txt).
|
||||
|
||||
```bash
|
||||
bazel build -c opt mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu:binary_graph
|
||||
cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/facedetectiongpu.binarypb /path/to/your/app/src/main/assets/
|
||||
cp mediapipe/models/face_detection_front.tflite /path/to/your/app/src/main/assets/
|
||||
cp mediapipe/models/face_detection_front_labelmap.txt /path/to/your/app/src/main/assets/
|
||||
```
|
||||
|
||||

|
||||
|
||||
4. Make app/src/main/jniLibs and copy OpenCV JNI libraries into
|
||||
app/src/main/jniLibs.
|
||||
|
||||
MediaPipe depends on OpenCV, so you will need to copy the precompiled OpenCV .so
|
||||
files into app/src/main/jniLibs. You can download the official OpenCV
|
||||
Android SDK from
|
||||
[here](https://github.com/opencv/opencv/releases/download/4.1.0/opencv-4.1.0-android-sdk.zip)
|
||||
and run:
|
||||
|
||||
```bash
|
||||
cp -R ~/Downloads/OpenCV-android-sdk/sdk/native/libs/arm* /path/to/your/app/src/main/jniLibs/
|
||||
```
|
||||
|
||||

|
||||
|
||||
5. Modify app/build.gradle to add MediaPipe dependencies and MediaPipe AAR.
|
||||
|
||||
```
|
||||
dependencies {
|
||||
implementation fileTree(dir: 'libs', include: ['*.jar', '*.aar'])
|
||||
implementation 'androidx.appcompat:appcompat:1.0.2'
|
||||
implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
|
||||
testImplementation 'junit:junit:4.12'
|
||||
androidTestImplementation 'androidx.test.ext:junit:1.1.0'
|
||||
androidTestImplementation 'androidx.test.espresso:espresso-core:3.1.1'
|
||||
// MediaPipe deps
|
||||
implementation 'com.google.flogger:flogger:0.3.1'
|
||||
implementation 'com.google.flogger:flogger-system-backend:0.3.1'
|
||||
implementation 'com.google.code.findbugs:jsr305:3.0.2'
|
||||
implementation 'com.google.guava:guava:27.0.1-android'
|
||||
implementation 'com.google.protobuf:protobuf-lite:3.0.0'
|
||||
// CameraX core library
|
||||
def camerax_version = "1.0.0-alpha06"
|
||||
implementation "androidx.camera:camera-core:$camerax_version"
|
||||
implementation "androidx.camera:camera-camera2:$camerax_version"
|
||||
}
|
||||
```
|
||||
|
||||
6. Follow our Android app examples to use MediaPipe in Android Studio for your
|
||||
use case. If you are looking for an example, a working face detection
|
||||
example can be found
|
||||
[here](https://github.com/jiuqiant/mediapipe_aar_example).
|
|
@ -76,6 +76,15 @@ MediaPipe with a TFLite model for hand tracking in a GPU-accelerated pipeline.
|
|||
* [Android](./hand_tracking_mobile_gpu.md)
|
||||
* [iOS](./hand_tracking_mobile_gpu.md)
|
||||
|
||||
### Multi-Hand Tracking with GPU
|
||||
|
||||
[Multi-Hand Tracking with GPU](./multi_hand_tracking_mobile_gpu.md) illustrates
|
||||
how to use MediaPipe with a TFLite model for multi-hand tracking in a
|
||||
GPU-accelerated pipeline.
|
||||
|
||||
* [Android](./multi_hand_tracking_mobile_gpu.md)
|
||||
* [iOS](./multi_hand_tracking_mobile_gpu.md)
|
||||
|
||||
### Hair Segmentation with GPU
|
||||
|
||||
[Hair Segmentation on GPU](./hair_segmentation_mobile_gpu.md) illustrates how to
|
||||
|
@ -96,8 +105,9 @@ using the MediaPipe C++ APIs.
|
|||
|
||||
### Feature Extraction for YouTube-8M Challenge
|
||||
|
||||
[Feature Extration for YouTube-8M Challenge](./youtube_8m.md) shows how to use
|
||||
MediaPipe to prepare training data for the YouTube-8M Challenge.
|
||||
[Feature Extraction and Model Inference for YouTube-8M Challenge](./youtube_8m.md)
|
||||
shows how to use MediaPipe to prepare training data for the YouTube-8M Challenge
|
||||
and do the model inference with the baseline model.
|
||||
|
||||
### Preparing Data Sets with MediaSequence
|
||||
|
||||
|
@ -131,6 +141,15 @@ with live video from a webcam.
|
|||
* [Desktop GPU](./hand_tracking_desktop.md)
|
||||
* [Desktop CPU](./hand_tracking_desktop.md)
|
||||
|
||||
### Multi-Hand Tracking on Desktop with Webcam
|
||||
|
||||
[Multi-Hand Tracking on Desktop with Webcam](./multi_hand_tracking_desktop.md)
|
||||
shows how to use MediaPipe with a TFLite model for multi-hand tracking on
|
||||
desktop using CPU or GPU with live video from a webcam.
|
||||
|
||||
* [Desktop GPU](./multi_hand_tracking_desktop.md)
|
||||
* [Desktop CPU](./multi_hand_tracking_desktop.md)
|
||||
|
||||
### Hair Segmentation on Desktop with Webcam
|
||||
|
||||
[Hair Segmentation on Desktop with Webcam](./hair_segmentation_desktop.md) shows
|
||||
|
|
|
@ -36,10 +36,9 @@ $ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \
|
|||
# INFO: 711 processes: 710 linux-sandbox, 1 local.
|
||||
# INFO: Build completed successfully, 734 total actions
|
||||
|
||||
$ export GLOG_logtostderr=1
|
||||
# This will open up your webcam as long as it is connected and on
|
||||
# Any errors are likely due to your webcam not being accessible
|
||||
$ bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_cpu \
|
||||
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_cpu \
|
||||
--calculator_graph_config_file=mediapipe/graphs/face_detection/face_detection_desktop_live.pbtxt
|
||||
```
|
||||
|
||||
|
@ -60,11 +59,10 @@ $ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS \
|
|||
# INFO: 711 processes: 710 linux-sandbox, 1 local.
|
||||
# INFO: Build completed successfully, 734 total actions
|
||||
|
||||
$ export GLOG_logtostderr=1
|
||||
# This will open up your webcam as long as it is connected and on
|
||||
# Any errors are likely due to your webcam not being accessible,
# or GPU drivers not being set up properly.
|
||||
$ bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_gpu \
|
||||
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_gpu \
|
||||
--calculator_graph_config_file=mediapipe/graphs/face_detection/face_detection_mobile_gpu.pbtxt
|
||||
```
|
||||
|
||||
|
@ -79,7 +77,7 @@ below and paste it into
|
|||
```bash
|
||||
# MediaPipe graph that performs face detection with TensorFlow Lite on CPU & GPU.
|
||||
# Used in the examples in
|
||||
# mediapipie/examples/desktop/face_detection:face_detection_cpu.
|
||||
# mediapipe/examples/desktop/face_detection:face_detection_cpu.
|
||||
|
||||
# Images on CPU coming into and out of the graph.
|
||||
input_stream: "input_video"
|
||||
|
|
|
@ -31,15 +31,12 @@ $ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS \
|
|||
#INFO: Found 1 target...
|
||||
#Target //mediapipe/examples/desktop/hair_segmentation:hair_segmentation_gpu up-to-date:
|
||||
# bazel-bin/mediapipe/examples/desktop/hair_segmentation/hair_segmentation_gpu
|
||||
#INFO: Elapsed time: 18.209s, Forge stats: 13026/13057 actions cached, 20.8s CPU used, 0.0s queue time, 89.3 MB ObjFS output (novel bytes: 87.4 MB), 0.0 MB local output, Critical Path: 11.88s, Remote (86.01% of the time): [queue: 0.00%, network: 16.83%, setup: 4.59%, process: 38.92%]
|
||||
#INFO: Streaming build results to: http://sponge2/37d5a184-293b-4e98-a43e-b22084db3142
|
||||
#INFO: Build completed successfully, 12210 total actions
|
||||
|
||||
$ export GLOG_logtostderr=1
|
||||
# This will open up your webcam as long as it is connected and on
|
||||
# Any errors are likely due to your webcam not being accessible,
# or GPU drivers not being set up properly.
|
||||
$ bazel-bin/mediapipe/examples/desktop/hair_segmentation/hair_segmentation_gpu \
|
||||
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hair_segmentation/hair_segmentation_gpu \
|
||||
--calculator_graph_config_file=mediapipe/graphs/hair_segmentation/hair_segmentation_mobile_gpu.pbtxt
|
||||
```
|
||||
|
||||
|
@ -54,7 +51,7 @@ below and paste it into
|
|||
```bash
|
||||
# MediaPipe graph that performs hair segmentation with TensorFlow Lite on GPU.
|
||||
# Used in the example in
|
||||
# mediapipie/examples/android/src/java/com/mediapipe/apps/hairsegmentationgpu.
|
||||
# mediapipe/examples/android/src/java/com/mediapipe/apps/hairsegmentationgpu.
|
||||
|
||||
# Images on GPU coming into and out of the graph.
|
||||
input_stream: "input_video"
|
||||
|
|
|
@ -29,7 +29,7 @@ below and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/).
|
|||
```bash
|
||||
# MediaPipe graph that performs hair segmentation with TensorFlow Lite on GPU.
|
||||
# Used in the example in
|
||||
# mediapipie/examples/android/src/java/com/mediapipe/apps/hairsegmentationgpu.
|
||||
# mediapipe/examples/android/src/java/com/mediapipe/apps/hairsegmentationgpu.
|
||||
|
||||
# Images on GPU coming into and out of the graph.
|
||||
input_stream: "input_video"
|
||||
|
|
|
@ -31,14 +31,11 @@ $ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \
|
|||
# It should print:
|
||||
#Target //mediapipe/examples/desktop/hand_tracking:hand_tracking_cpu up-to-date:
|
||||
# bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_cpu
|
||||
#INFO: Elapsed time: 22.645s, Forge stats: 13356/13463 actions cached, 1.5m CPU used, 0.0s queue time, 819.8 MB ObjFS output (novel bytes: 85.6 MB), 0.0 MB local output, Critical Path: 14.43s, Remote (87.25% of the time): [queue: 0.00%, network: 14.88%, setup: 4.80%, process: 39.80%, fetch: 18.15%]
|
||||
#INFO: Streaming build results to: http://sponge2/360196b9-33ab-44b1-84a7-1022b5043307
|
||||
#INFO: Build completed successfully, 12517 total actions
|
||||
|
||||
$ export GLOG_logtostderr=1
|
||||
# This will open up your webcam as long as it is connected and on
|
||||
# Any errors are likely due to your webcam not being accessible
|
||||
$ bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_cpu \
|
||||
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_cpu \
|
||||
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop_live.pbtxt
|
||||
```
|
||||
|
||||
|
@ -55,15 +52,12 @@ $ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS \
|
|||
# It should print:
|
||||
# Target //mediapipe/examples/desktop/hand_tracking:hand_tracking_gpu up-to-date:
|
||||
# bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu
|
||||
#INFO: Elapsed time: 84.055s, Forge stats: 6858/19343 actions cached, 1.6h CPU used, 0.9s queue time, 1.68 GB ObjFS output (novel bytes: 485.1 MB), 0.0 MB local output, Critical Path: 48.14s, Remote (99.40% of the time): [queue: 0.00%, setup: 5.59%, process: 74.44%]
|
||||
#INFO: Streaming build results to: http://sponge2/00c7f95f-6fbc-432d-8978-f5d361efca3b
|
||||
#INFO: Build completed successfully, 22455 total actions
|
||||
|
||||
$ export GLOG_logtostderr=1
|
||||
# This will open up your webcam as long as it is connected and on
|
||||
# Any errors are likely due to your webcam not being accessible,
# or GPU drivers not being set up properly.
|
||||
$ bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu \
|
||||
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu \
|
||||
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt
|
||||
```
|
||||
|
||||
|
@ -79,7 +73,7 @@ below and paste it into
|
|||
# MediaPipe graph that performs hand tracking on desktop with TensorFlow Lite
|
||||
# on CPU & GPU.
|
||||
# Used in the example in
|
||||
# mediapipie/examples/desktop/hand_tracking:hand_tracking_cpu.
|
||||
# mediapipe/examples/desktop/hand_tracking:hand_tracking_cpu.
|
||||
|
||||
# Images coming into and out of the graph.
|
||||
input_stream: "input_video"
|
||||
|
|
|
@ -100,8 +100,8 @@ see the Visualizing Subgraphs section in the
|
|||
```bash
|
||||
# MediaPipe graph that performs hand tracking with TensorFlow Lite on GPU.
|
||||
# Used in the examples in
|
||||
# mediapipie/examples/android/src/java/com/mediapipe/apps/handtrackinggpu and
|
||||
# mediapipie/examples/ios/handtrackinggpu.
|
||||
# mediapipe/examples/android/src/java/com/mediapipe/apps/handtrackinggpu and
|
||||
# mediapipe/examples/ios/handtrackinggpu.
|
||||
|
||||
# Images coming into and out of the graph.
|
||||
input_stream: "input_video"
|
||||
|
|
BIN mediapipe/docs/images/mobile/aar_location.png (new file, 35 KiB)
BIN mediapipe/docs/images/mobile/android_studio_opencv_location.png (new file, 75 KiB)
BIN mediapipe/docs/images/mobile/assets_location.png (new file, 56 KiB)
BIN (filename not captured, 149 KiB)
BIN mediapipe/docs/images/mobile/multi_hand_landmark_subgraph.png (new file, 193 KiB)
BIN (filename not captured, 213 KiB)
BIN (filename not captured, 2.2 MiB)
BIN mediapipe/docs/images/mobile/multi_hand_tracking_android_gpu.gif (new file, 2.6 MiB)
BIN mediapipe/docs/images/mobile/multi_hand_tracking_mobile.png (new file, 112 KiB)
BIN mediapipe/docs/images/multi_hand_tracking_desktop.png (new file, 124 KiB)
BIN mediapipe/docs/images/web_effect.gif (new file, 479 KiB)
BIN mediapipe/docs/images/web_segmentation.gif (new file, 225 KiB)
|
@ -7,11 +7,8 @@ future.
|
|||
Note: If you plan to use TensorFlow calculators and example apps, there is a
|
||||
known issue with gcc and g++ version 6.3 and 7.3. Please use other versions.
|
||||
|
||||
Note: While Mediapipe configures TensorFlow, if you see the
|
||||
following error:
|
||||
`"...git_configure.bzl", line 14, in _fail fail(("%sGit Configuration
|
||||
Error:%s %...)))`,
|
||||
please install the python future library using: `$ pip install --user future`.
|
||||
Note: To make Mediapipe work with TensorFlow, please install the python "future"
|
||||
library and the python "six" library using `pip install --user future six`.
|
||||
|
||||
Choose your operating system:
|
||||
|
||||
|
@ -24,7 +21,8 @@ Choose your operating system:
|
|||
To build and run Android apps:
|
||||
|
||||
- [Setting up Android SDK and NDK](#setting-up-android-sdk-and-ndk)
|
||||
- [Setting up Android Studio with MediaPipe](#setting-up-android-studio-with-mediapipe)
|
||||
- [Using MediaPipe with Gradle](#using-mediapipe-with-gradle)
|
||||
- [Using MediaPipe with Bazel](#using-mediapipe-with-bazel)
|
||||
|
||||
To build and run iOS apps:
|
||||
|
||||
|
@ -51,8 +49,8 @@ To build and run iOS apps:
|
|||
# Run 'bazel version' to check version of bazel installed
|
||||
```
|
||||
|
||||
Option 2. Follow Bazel's
|
||||
[documentation](https://docs.bazel.build/versions/master/install-ubuntu.html)
|
||||
Option 2. Follow the official
|
||||
[Bazel documentation](https://docs.bazel.build/versions/master/install-ubuntu.html)
|
||||
to install any version of Bazel manually.
|
||||
|
||||
3. Install OpenCV and FFmpeg.
|
||||
|
@ -75,10 +73,10 @@ To build and run iOS apps:
|
|||
[documentation](https://docs.opencv.org/3.4.6/d7/d9f/tutorial_linux_install.html)
|
||||
to manually build OpenCV from source code.
|
||||
|
||||
Note: You may need to modify [`WORKSAPCE`] and [`opencv_linux.BUILD`] to
|
||||
Note: You may need to modify [`WORKSPACE`] and [`opencv_linux.BUILD`] to
|
||||
point MediaPipe to your own OpenCV libraries, e.g., if OpenCV 4 is installed
|
||||
in "/usr/local/", you need to update the "linux_opencv" new_local_repository
|
||||
rule in [`WORKSAPCE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`]
|
||||
rule in [`WORKSPACE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`]
|
||||
like the following:
|
||||
|
||||
```bash
|
||||
|
@ -161,8 +159,8 @@ To build and run iOS apps:
|
|||
|
||||
2. Install Bazel (0.24.1 and above required).
|
||||
|
||||
Follow Bazel's
|
||||
[documentation](https://docs.bazel.build/versions/master/install-redhat.html)
|
||||
Follow the official
|
||||
[Bazel documentation](https://docs.bazel.build/versions/master/install-redhat.html)
|
||||
to install Bazel manually.
|
||||
|
||||
3. Install OpenCV.
|
||||
|
@ -178,10 +176,10 @@ To build and run iOS apps:
|
|||
|
||||
Option 2. Build OpenCV from source code.
|
||||
|
||||
Note: You may need to modify [`WORKSAPCE`] and [`opencv_linux.BUILD`] to
|
||||
Note: You may need to modify [`WORKSPACE`] and [`opencv_linux.BUILD`] to
|
||||
point MediaPipe to your own OpenCV libraries, e.g., if OpenCV 4 is installed
|
||||
in "/usr/local/", you need to update the "linux_opencv" new_local_repository
|
||||
rule in [`WORKSAPCE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`]
|
||||
rule in [`WORKSPACE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`]
|
||||
like the following:
|
||||
|
||||
```bash
|
||||
|
@ -237,7 +235,7 @@ To build and run iOS apps:
|
|||
|
||||
* Install [Homebrew](https://brew.sh).
|
||||
* Install [Xcode](https://developer.apple.com/xcode/) and its Command Line
|
||||
Tools.
|
||||
Tools by `xcode-select --install`.
|
||||
|
||||
2. Checkout MediaPipe repository.
|
||||
|
||||
|
@ -257,8 +255,8 @@ To build and run iOS apps:
|
|||
# Run 'bazel version' to check version of bazel installed
|
||||
```
|
||||
|
||||
Option 2. Follow Bazel's
|
||||
[documentation](https://docs.bazel.build/versions/master/install-os-x.html#install-with-installer-mac-os-x)
|
||||
Option 2. Follow the official
|
||||
[Bazel documentation](https://docs.bazel.build/versions/master/install-os-x.html#install-with-installer-mac-os-x)
|
||||
to install any version of Bazel manually.
|
||||
|
||||
4. Install OpenCV and FFmpeg.
|
||||
|
@ -281,7 +279,7 @@ To build and run iOS apps:
|
|||
$ port install opencv
|
||||
```
|
||||
|
||||
Note: when using MacPorts, please edit the [`WORKSAPCE`],
|
||||
Note: when using MacPorts, please edit the [`WORKSPACE`],
|
||||
[`opencv_macos.BUILD`], and [`ffmpeg_macos.BUILD`] files like the following:
|
||||
|
||||
```bash
|
||||
|
@ -419,10 +417,10 @@ To build and run iOS apps:
|
|||
[documentation](https://docs.opencv.org/3.4.6/d7/d9f/tutorial_linux_install.html)
|
||||
to manually build OpenCV from source code.
|
||||
|
||||
Note: You may need to modify [`WORKSAPCE`] and [`opencv_linux.BUILD`] to
|
||||
Note: You may need to modify [`WORKSPACE`] and [`opencv_linux.BUILD`] to
|
||||
point MediaPipe to your own OpenCV libraries, e.g., if OpenCV 4 is installed
|
||||
in "/usr/local/", you need to update the "linux_opencv" new_local_repository
|
||||
rule in [`WORKSAPCE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`]
|
||||
rule in [`WORKSPACE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`]
|
||||
like the following:
|
||||
|
||||
```bash
|
||||
|
@ -581,6 +579,11 @@ export ANDROID_HOME=<path to the Android SDK>
|
|||
export ANDROID_NDK_HOME=<path to the Android NDK>
|
||||
```
|
||||
|
||||
In order to use MediaPipe on earlier Android versions, MediaPipe needs to switch
|
||||
to a lower Android API level. You can achieve this by specifying `api_level =
|
||||
<api level integer>` in android_ndk_repository() and/or android_sdk_repository()
|
||||
in the [`WORKSPACE`] file.
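
For example, a minimal sketch (the repository names are assumed to match the
ones already in your [`WORKSPACE`]; the API level value 21 is illustrative):

```
android_ndk_repository(
    name = "androidndk",
    api_level = 21,
)

android_sdk_repository(
    name = "androidsdk",
    api_level = 21,
)
```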
|
||||
|
||||
Please verify all the necessary packages are installed.
|
||||
|
||||
* Android SDK Platform API Level 28 or 29
|
||||
|
@ -589,10 +592,20 @@ Please verify all the necessary packages are installed.
|
|||
* Android SDK Tools 26.1.1
|
||||
* Android NDK 17c or above
|
||||
|
||||
### Setting up Android Studio with MediaPipe
|
||||
### Using MediaPipe with Gradle
|
||||
|
||||
The steps below use Android Studio 3.5 to build and install a MediaPipe example
|
||||
app.
|
||||
MediaPipe can be used within an existing project, such as a Gradle project,
|
||||
using the MediaPipe AAR target defined in mediapipe_aar.bzl. Please see the
|
||||
separate [MediaPipe Android Archive Library](./android_archive_library.md)
|
||||
documentation.
|
||||
|
||||
### Using MediaPipe with Bazel
|
||||
|
||||
The MediaPipe project can be imported to Android Studio using the Bazel plugins.
|
||||
This allows the MediaPipe examples and demos to be built and modified in Android
|
||||
Studio. To incorporate MediaPipe into an existing Android Studio project, see:
|
||||
"Using MediaPipe with Gradle". The steps below use Android Studio 3.5 to build
|
||||
and install a MediaPipe example app.
|
||||
|
||||
1. Install and launch Android Studio 3.5.
|
||||
|
||||
|
@ -682,7 +695,7 @@ app.
|
|||
* Press the `[+]` button to add the new configuration.
|
||||
* Select `Run` to run the example app on the connected Android device.
|
||||
|
||||
[`WORKSAPCE`]: https://github.com/google/mediapipe/tree/master/WORKSPACE
|
||||
[`WORKSPACE`]: https://github.com/google/mediapipe/tree/master/WORKSPACE
|
||||
[`opencv_linux.BUILD`]: https://github.com/google/mediapipe/tree/master/third_party/opencv_linux.BUILD
|
||||
[`opencv_macos.BUILD`]: https://github.com/google/mediapipe/tree/master/third_party/opencv_macos.BUILD
|
||||
[`ffmpeg_macos.BUILD`]:https://github.com/google/mediapipe/tree/master/third_party/ffmpeg_macos.BUILD
|
||||
|
|
177
mediapipe/docs/multi_hand_tracking_desktop.md
Normal file
|
@ -0,0 +1,177 @@
|
|||
## Multi-Hand Tracking on Desktop
|
||||
|
||||
This is an example of using MediaPipe to run hand tracking models (TensorFlow
|
||||
Lite) and render bounding boxes on the detected hand instances (for multiple
|
||||
hands). To know more about the hand tracking models, please refer to the model
|
||||
[`README file`]. Moreover, if you are interested in running the same TensorFlow
|
||||
Lite model on Android/iOS, please see the
|
||||
[Multi-Hand Tracking on GPU on Android/iOS](multi_hand_tracking_mobile_gpu.md)
example.
|
||||
|
||||
We show the multi-hand tracking demos with the TensorFlow Lite model using the webcam:
|
||||
|
||||
- [TensorFlow Lite Multi-Hand Tracking Demo with Webcam (CPU)](#tensorflow-lite-multi-hand-tracking-demo-with-webcam-cpu)
|
||||
|
||||
- [TensorFlow Lite Multi-Hand Tracking Demo with Webcam (GPU)](#tensorflow-lite-multi-hand-tracking-demo-with-webcam-gpu)
|
||||
|
||||
Note: Desktop GPU works only on Linux. Mesa drivers need to be installed. Please
|
||||
see
|
||||
[step 4 of "Installing on Debian and Ubuntu" in the installation guide](./install.md).
|
||||
|
||||
Note: If MediaPipe depends on OpenCV 2, please see the
|
||||
[known issues with OpenCV 2](#known-issues-with-opencv-2) section.
|
||||
|
||||
### TensorFlow Lite Multi-Hand Tracking Demo with Webcam (CPU)
|
||||
|
||||
To build and run the TensorFlow Lite example on desktop (CPU) with Webcam, run:
|
||||
|
||||
```bash
|
||||
# Video from webcam running on desktop CPU
|
||||
$ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \
|
||||
mediapipe/examples/desktop/multi_hand_tracking:multi_hand_tracking_cpu
|
||||
|
||||
# It should print:
|
||||
#Target //mediapipe/examples/desktop/multi_hand_tracking:multi_hand_tracking_cpu up-to-date:
|
||||
# bazel-bin/mediapipe/examples/desktop/multi_hand_tracking/multi_hand_tracking_cpu
|
||||
|
||||
# This will open up your webcam as long as it is connected and on
|
||||
# Any errors are likely due to your webcam not being accessible
|
||||
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/multi_hand_tracking/multi_hand_tracking_cpu \
|
||||
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/multi_hand_tracking_desktop_live.pbtxt
|
||||
```
|
||||
|
||||
### TensorFlow Lite Multi-Hand Tracking Demo with Webcam (GPU)
|
||||
|
||||
To build and run the TensorFlow Lite example on desktop (GPU) with Webcam, run:
|
||||
|
||||
```bash
|
||||
# Video from webcam running on desktop GPU
|
||||
# This works only for linux currently
|
||||
$ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS \
|
||||
mediapipe/examples/desktop/multi_hand_tracking:multi_hand_tracking_gpu
|
||||
|
||||
# It should print:
|
||||
# Target //mediapipe/examples/desktop/multi_hand_tracking:multi_hand_tracking_gpu up-to-date:
|
||||
# bazel-bin/mediapipe/examples/desktop/multi_hand_tracking/multi_hand_tracking_gpu
|
||||
|
||||
# This will open up your webcam as long as it is connected and on
|
||||
# Any errors are likely due to your webcam not being accessible,
# or GPU drivers not being set up properly.
|
||||
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/multi_hand_tracking/multi_hand_tracking_gpu \
|
||||
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/multi_hand_tracking_mobile.pbtxt
|
||||
```
|
||||
|
||||
#### Graph
|
||||
|
||||

|
||||
|
||||
To visualize the graph as shown above, copy the text specification of the graph
|
||||
below and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev).
|
||||
|
||||
```bash
|
||||
# MediaPipe graph that performs multi-hand tracking on desktop with TensorFlow
|
||||
# Lite on CPU.
|
||||
# Used in the example in
|
||||
# mediapipe/examples/desktop/multi_hand_tracking:multi_hand_tracking_cpu.
|
||||
|
||||
# Images coming into and out of the graph.
|
||||
input_stream: "input_video"
|
||||
output_stream: "output_video"
|
||||
|
||||
# Determines if an input vector of NormalizedRect has a size greater than or
|
||||
# equal to the provided min_size.
|
||||
node {
|
||||
calculator: "NormalizedRectVectorHasMinSizeCalculator"
|
||||
input_stream: "ITERABLE:prev_multi_hand_rects_from_landmarks"
|
||||
output_stream: "prev_has_enough_hands"
|
||||
node_options: {
|
||||
[type.googleapis.com/mediapipe.CollectionHasMinSizeCalculatorOptions] {
|
||||
# This value can be changed to support tracking arbitrary number of hands.
|
||||
# Please also remember to modify max_vec_size in
|
||||
# ClipVectorSizeCalculatorOptions in
|
||||
# mediapipe/graphs/hand_tracking/subgraphs/multi_hand_detection_gpu.pbtxt
|
||||
min_size: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Drops the incoming image if the previous frame had at least N hands.
|
||||
# Otherwise, passes the incoming image through to trigger a new round of hand
|
||||
# detection in MultiHandDetectionSubgraph.
|
||||
node {
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "input_video"
|
||||
input_stream: "DISALLOW:prev_has_enough_hands"
|
||||
output_stream: "multi_hand_detection_input_video"
|
||||
node_options: {
|
||||
[type.googleapis.com/mediapipe.GateCalculatorOptions] {
|
||||
empty_packets_as_allow: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Subgraph that detects hands (see multi_hand_detection_cpu.pbtxt).
|
||||
node {
|
||||
calculator: "MultiHandDetectionSubgraph"
|
||||
input_stream: "multi_hand_detection_input_video"
|
||||
output_stream: "DETECTIONS:multi_palm_detections"
|
||||
output_stream: "NORM_RECTS:multi_palm_rects"
|
||||
}
|
||||
|
||||
# Subgraph that localizes hand landmarks for multiple hands (see
|
||||
# multi_hand_landmark.pbtxt).
|
||||
node {
|
||||
calculator: "MultiHandLandmarkSubgraph"
|
||||
input_stream: "IMAGE:input_video"
|
||||
input_stream: "NORM_RECTS:multi_hand_rects"
|
||||
output_stream: "LANDMARKS:multi_hand_landmarks"
|
||||
output_stream: "NORM_RECTS:multi_hand_rects_from_landmarks"
|
||||
}
|
||||
|
||||
# Caches a hand rectangle fed back from MultiHandLandmarkSubgraph, and upon the
|
||||
# arrival of the next input image sends out the cached rectangle with the
|
||||
# timestamp replaced by that of the input image, essentially generating a packet
|
||||
# that carries the previous hand rectangle. Note that upon the arrival of the
|
||||
# very first input image, an empty packet is sent out to jump start the
|
||||
# feedback loop.
|
||||
node {
|
||||
calculator: "PreviousLoopbackCalculator"
|
||||
input_stream: "MAIN:input_video"
|
||||
input_stream: "LOOP:multi_hand_rects_from_landmarks"
|
||||
input_stream_info: {
|
||||
tag_index: "LOOP"
|
||||
back_edge: true
|
||||
}
|
||||
output_stream: "PREV_LOOP:prev_multi_hand_rects_from_landmarks"
|
||||
}
|
||||
|
||||
# Performs association between NormalizedRect vector elements from previous
|
||||
# frame and those from the current frame if MultiHandDetectionSubgraph runs.
|
||||
# This calculator ensures that the output multi_hand_rects vector doesn't
|
||||
# contain overlapping regions based on the specified min_similarity_threshold.
|
||||
node {
|
||||
calculator: "AssociationNormRectCalculator"
|
||||
input_stream: "prev_multi_hand_rects_from_landmarks"
|
||||
input_stream: "multi_palm_rects"
|
||||
output_stream: "multi_hand_rects"
|
||||
node_options: {
|
||||
[type.googleapis.com/mediapipe.AssociationCalculatorOptions] {
|
||||
min_similarity_threshold: 0.1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Subgraph that renders annotations and overlays them on top of the input
|
||||
# images (see multi_hand_renderer_cpu.pbtxt).
|
||||
node {
|
||||
calculator: "MultiHandRendererSubgraph"
|
||||
input_stream: "IMAGE:input_video"
|
||||
input_stream: "DETECTIONS:multi_palm_detections"
|
||||
input_stream: "LANDMARKS:multi_hand_landmarks"
|
||||
input_stream: "NORM_RECTS:0:multi_palm_rects"
|
||||
input_stream: "NORM_RECTS:1:multi_hand_rects"
|
||||
output_stream: "IMAGE:output_video"
|
||||
}
|
||||
```
|
||||
|
||||
[`README file`]: https://github.com/google/mediapipe/tree/master/mediapipe/README.md
|
755
mediapipe/docs/multi_hand_tracking_mobile_gpu.md
Normal file
|
@ -0,0 +1,755 @@
|
|||
# Multi-Hand Tracking (GPU)
|
||||
|
||||
This doc focuses on the
|
||||
[example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/multi_hand_tracking_mobile.pbtxt)
|
||||
that performs multi-hand tracking with TensorFlow Lite on GPU. It is related to
|
||||
the [hand_tracking_example](./hand_tracking_mobile_gpu.md), and we recommend
|
||||
that users review the (single) hand tracking example first.
|
||||
|
||||

|
||||
|
||||
In the visualization above, the red dots represent the hand landmarks and the
|
||||
green lines are simply connections between selected landmark pairs for
|
||||
visualization of the hand skeleton. When there are fewer than `N` hands (`N=2`
|
||||
in the graphs here), the purple box represents a hand rectangle that covers the
|
||||
entire hand, derived from hand detection (see
|
||||
[hand_detection_example](./hand_detection_mobile_gpu.md)). When there are `N`
|
||||
hands (i.e. 2 hands for the graphs here), the red boxes represent hand
|
||||
rectangles for each of the hands, derived from the previous round of hand
|
||||
landmark localization using an ML model (see also
|
||||
[model card](https://mediapipe.page.link/handmc)). Hand landmark localization
|
||||
for each hand is performed only within the hand rectangle for computational
|
||||
efficiency and accuracy. Hand detection is only invoked when there are fewer
|
||||
than `N` hands in the previous iteration.
|
||||
|
||||
This example can also run a model that localizes hand landmarks in 3D (i.e.,
|
||||
estimating an extra z coordinate):
|
||||
|
||||

|
||||
|
||||
In the visualization above, the localized hand landmarks are represented by dots
|
||||
in different shades, with the brighter ones denoting landmarks closer to the
|
||||
camera.
|
||||
|
||||
## Android
|
||||
|
||||
[Source](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/multihandtrackinggpu)
|
||||
|
||||
To build the app yourself, run:
|
||||
|
||||
```bash
|
||||
bazel build -c opt --config=android_arm64 mediapipe/examples/android/src/java/com/google/mediapipe/apps/multihandtrackinggpu
|
||||
```
|
||||
|
||||
To build for the 3D mode, run:
|
||||
|
||||
```bash
|
||||
bazel build -c opt --config=android_arm64 --define 3D=true mediapipe/examples/android/src/java/com/google/mediapipe/apps/multihandtrackinggpu
|
||||
```
|
||||
|
||||
Once the app is built, install it on Android device with:
|
||||
|
||||
```bash
|
||||
adb install bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/multihandtrackinggpu/multihandtrackinggpu.apk
|
||||
```
|
||||
|
||||
## iOS
|
||||
|
||||
[Source](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/multihandtrackinggpu).
|
||||
|
||||
See the general [instructions](./mediapipe_ios_setup.md) for building iOS
|
||||
examples and generating an Xcode project. This will be the MultiHandTrackingGpuApp
|
||||
target.
|
||||
|
||||
To build on the command line:
|
||||
|
||||
```bash
|
||||
bazel build -c opt --config=ios_arm64 mediapipe/examples/ios/multihandtrackinggpu:MultiHandTrackingGpuApp
|
||||
```
|
||||
|
||||
To build for the 3D mode, run:
|
||||
|
||||
```bash
|
||||
bazel build -c opt --config=ios_arm64 --define 3D=true mediapipe/examples/ios/multihandtrackinggpu:MultiHandTrackingGpuApp
|
||||
```
|
||||
|
||||
## Graph
|
||||
|
||||
The multi-hand tracking [main graph](#main-graph) internally utilizes a
|
||||
[multi_hand_detection_subgraph](#multi-hand-detection-subgraph), a
|
||||
[multi_hand_landmark_subgraph](#multi-hand-landmark-subgraph), and a
|
||||
[multi_hand_renderer_subgraph](#multi-hand-renderer-subgraph).
|
||||
|
||||
The subgraphs show up in the main graph visualization as nodes colored in
|
||||
purple, and the subgraph itself can also be visualized just like a regular
|
||||
graph. For more information on how to visualize a graph that includes subgraphs,
|
||||
see the Visualizing Subgraphs section in the
|
||||
[visualizer documentation](./visualizer.md).
|
||||
|
||||
### Main Graph
|
||||
|
||||

|
||||
|
||||
There are two key differences between this graph and the
|
||||
[single_hand_tracking_mobile_graph](./hand_tracking_mobile_gpu.md).
|
||||
|
||||
1. There is a `NormalizedRectVectorHasMinSize` calculator that checks if the
|
||||
input vector of `NormalizedRect` objects has a minimum size equal to `N`. In
|
||||
this graph, if the vector contains fewer than `N` objects,
|
||||
the `MultiHandDetection` subgraph runs. Otherwise, the `GateCalculator` doesn't
|
||||
send any image packets to the `MultiHandDetection` subgraph. This way, the
|
||||
main graph is efficient in that it avoids running the costly hand detection
|
||||
step when there are already `N` hands in the frame.
|
||||
2. The `MergeCalculator` has been replaced by the `AssociationNormRect`
|
||||
calculator. This `AssociationNormRect` takes as input a vector of
|
||||
`NormalizedRect` objects from the `MultiHandDetection` subgraph on the
|
||||
current frame, and a vector of `NormalizedRect` objects from the
|
||||
`MultiHandLandmark` subgraph from the previous frame, and performs an
|
||||
association operation between these objects. This calculator ensures that
|
||||
the output vector doesn't contain overlapping regions based on the specified
|
||||
`min_similarity_threshold`.
|
||||
|
||||
[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/multi_hand_tracking_mobile.pbtxt)
|
||||
|
||||
```bash
|
||||
# MediaPipe graph that performs multi-hand tracking with TensorFlow Lite on GPU.
|
||||
# Used in the examples in
|
||||
# mediapipie/examples/android/src/java/com/mediapipe/apps/multihandtrackinggpu.
|
||||
|
||||
# Images coming into and out of the graph.
|
||||
input_stream: "input_video"
|
||||
output_stream: "output_video"
|
||||
|
||||
# Throttles the images flowing downstream for flow control. It passes through
|
||||
# the very first incoming image unaltered, and waits for downstream nodes
|
||||
# (calculators and subgraphs) in the graph to finish their tasks before it
|
||||
# passes through another image. All images that come in while waiting are
|
||||
# dropped, limiting the number of in-flight images in most part of the graph to
|
||||
# 1. This prevents the downstream nodes from queuing up incoming images and data
|
||||
# excessively, which leads to increased latency and memory usage, unwanted in
|
||||
# real-time mobile applications. It also eliminates unnecessary computation,
|
||||
# e.g., the output produced by a node may get dropped downstream if the
|
||||
# subsequent nodes are still busy processing previous inputs.
|
||||
node {
|
||||
calculator: "FlowLimiterCalculator"
|
||||
input_stream: "input_video"
|
||||
input_stream: "FINISHED:multi_hand_rects"
|
||||
input_stream_info: {
|
||||
tag_index: "FINISHED"
|
||||
back_edge: true
|
||||
}
|
||||
output_stream: "throttled_input_video"
|
||||
}
|
||||
|
||||
# Determines if an input vector of NormalizedRect has a size greater than or
|
||||
# equal to the provided min_size.
|
||||
node {
|
||||
calculator: "NormalizedRectVectorHasMinSizeCalculator"
|
||||
input_stream: "ITERABLE:prev_multi_hand_rects_from_landmarks"
|
||||
output_stream: "prev_has_enough_hands"
|
||||
node_options: {
|
||||
[type.googleapis.com/mediapipe.CollectionHasMinSizeCalculatorOptions] {
|
||||
# This value can be changed to support tracking arbitrary number of hands.
|
||||
# Please also remember to modify max_vec_size in
|
||||
# ClipVectorSizeCalculatorOptions in
|
||||
# mediapipe/graphs/hand_tracking/subgraphs/multi_hand_detection_gpu.pbtxt
|
||||
min_size: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Drops the incoming image if the previous frame had at least N hands.
|
||||
# Otherwise, passes the incoming image through to trigger a new round of hand
|
||||
# detection in MultiHandDetectionSubgraph.
|
||||
node {
|
||||
calculator: "GateCalculator"
|
||||
input_stream: "throttled_input_video"
|
||||
input_stream: "DISALLOW:prev_has_enough_hands"
|
||||
output_stream: "multi_hand_detection_input_video"
|
||||
node_options: {
|
||||
[type.googleapis.com/mediapipe.GateCalculatorOptions] {
|
||||
empty_packets_as_allow: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Subgraph that detects hands (see multi_hand_detection_gpu.pbtxt).
|
||||
node {
|
||||
calculator: "MultiHandDetectionSubgraph"
|
||||
input_stream: "multi_hand_detection_input_video"
|
||||
output_stream: "DETECTIONS:multi_palm_detections"
|
||||
output_stream: "NORM_RECTS:multi_palm_rects"
|
||||
}
|
||||
|
||||
# Subgraph that localizes hand landmarks for multiple hands (see
|
||||
# multi_hand_landmark.pbtxt).
|
||||
node {
|
||||
calculator: "MultiHandLandmarkSubgraph"
|
||||
input_stream: "IMAGE:throttled_input_video"
|
||||
input_stream: "NORM_RECTS:multi_hand_rects"
|
||||
output_stream: "LANDMARKS:multi_hand_landmarks"
|
||||
output_stream: "NORM_RECTS:multi_hand_rects_from_landmarks"
|
||||
}
|
||||
|
||||
# Caches a hand rectangle fed back from MultiHandLandmarkSubgraph, and upon the
|
||||
# arrival of the next input image sends out the cached rectangle with the
|
||||
# timestamp replaced by that of the input image, essentially generating a packet
|
||||
# that carries the previous hand rectangle. Note that upon the arrival of the
|
||||
# very first input image, an empty packet is sent out to jump start the
|
||||
# feedback loop.
|
||||
node {
|
||||
calculator: "PreviousLoopbackCalculator"
|
||||
input_stream: "MAIN:throttled_input_video"
|
||||
input_stream: "LOOP:multi_hand_rects_from_landmarks"
|
||||
input_stream_info: {
|
||||
tag_index: "LOOP"
|
||||
back_edge: true
|
||||
}
|
||||
output_stream: "PREV_LOOP:prev_multi_hand_rects_from_landmarks"
|
||||
}
|
||||
|
||||
# Performs association between NormalizedRect vector elements from previous
|
||||
# frame and those from the current frame if MultiHandDetectionSubgraph runs.
|
||||
# This calculator ensures that the output multi_hand_rects vector doesn't
|
||||
# contain overlapping regions based on the specified min_similarity_threshold.
|
||||
node {
|
||||
calculator: "AssociationNormRectCalculator"
|
||||
input_stream: "prev_multi_hand_rects_from_landmarks"
|
||||
input_stream: "multi_palm_rects"
|
||||
output_stream: "multi_hand_rects"
|
||||
node_options: {
|
||||
[type.googleapis.com/mediapipe.AssociationCalculatorOptions] {
|
||||
min_similarity_threshold: 0.1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Subgraph that renders annotations and overlays them on top of the input
|
||||
# images (see multi_hand_renderer_gpu.pbtxt).
|
||||
node {
|
||||
calculator: "MultiHandRendererSubgraph"
|
||||
input_stream: "IMAGE:throttled_input_video"
|
||||
input_stream: "DETECTIONS:multi_palm_detections"
|
||||
input_stream: "LANDMARKS:multi_hand_landmarks"
|
||||
input_stream: "NORM_RECTS:0:multi_palm_rects"
|
||||
input_stream: "NORM_RECTS:1:multi_hand_rects"
|
||||
output_stream: "IMAGE:output_video"
|
||||
}
|
||||
```

### Multi-Hand Detection Subgraph

![multi_hand_detection_gpu_subgraph](images/mobile/multi_hand_detection_gpu_subgraph.png)

This graph outputs a vector of `NormalizedRect` objects corresponding to each of
the hand instances visible in the frame. Note that at the end of this graph,
there is a `ClipNormalizedRectVectorSizeCalculator`. This calculator clips the
size of the input vector to a maximum size `N`. As a result, the
`MultiHandDetection` subgraph outputs a vector of at most `N` hand instance
locations.
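
To make the clipping step concrete, here is a minimal C++ sketch of the
operation it performs. This is written against plain STL types rather than the
MediaPipe calculator API, and the `NormalizedRect` struct below is a simplified
stand-in for the corresponding proto:

```c++
#include <algorithm>
#include <cstddef>
#include <vector>

// Simplified stand-in for the mediapipe::NormalizedRect proto.
struct NormalizedRect {
  float x_center, y_center, width, height;
};

// Keeps at most max_vec_size rects, which is what the clipping calculator at
// the end of this subgraph does with max_vec_size: 2.
std::vector<NormalizedRect> ClipToMaxSize(std::vector<NormalizedRect> rects,
                                          std::size_t max_vec_size) {
  rects.resize(std::min(rects.size(), max_vec_size));
  return rects;
}
```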

[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/subgraphs/multi_hand_detection_gpu.pbtxt)

```bash
# MediaPipe multi-hand detection subgraph.

type: "MultiHandDetectionSubgraph"

input_stream: "input_video"
output_stream: "DETECTIONS:palm_detections"
output_stream: "NORM_RECTS:clipped_hand_rects_from_palm_detections"

# Transforms the input image on GPU to a 256x256 image. To scale the input
# image, the scale_mode option is set to FIT to preserve the aspect ratio,
# resulting in potential letterboxing in the transformed image.
node: {
  calculator: "ImageTransformationCalculator"
  input_stream: "IMAGE_GPU:input_video"
  output_stream: "IMAGE_GPU:transformed_input_video"
  output_stream: "LETTERBOX_PADDING:letterbox_padding"
  node_options: {
    [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
      output_width: 256
      output_height: 256
      scale_mode: FIT
    }
  }
}

# Generates a single side packet containing a TensorFlow Lite op resolver that
# supports custom ops needed by the model used in this graph.
node {
  calculator: "TfLiteCustomOpResolverCalculator"
  output_side_packet: "opresolver"
  node_options: {
    [type.googleapis.com/mediapipe.TfLiteCustomOpResolverCalculatorOptions] {
      use_gpu: true
    }
  }
}

# Converts the transformed input image on GPU into an image tensor stored as a
# TfLiteTensor.
node {
  calculator: "TfLiteConverterCalculator"
  input_stream: "IMAGE_GPU:transformed_input_video"
  output_stream: "TENSORS_GPU:image_tensor"
}

# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a
# vector of tensors representing, for instance, detection boxes/keypoints and
# scores.
node {
  calculator: "TfLiteInferenceCalculator"
  input_stream: "TENSORS_GPU:image_tensor"
  output_stream: "TENSORS_GPU:detection_tensors"
  input_side_packet: "CUSTOM_OP_RESOLVER:opresolver"
  node_options: {
    [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] {
      model_path: "mediapipe/models/palm_detection.tflite"
      use_gpu: true
    }
  }
}

# Generates a single side packet containing a vector of SSD anchors based on
# the specification in the options.
node {
  calculator: "SsdAnchorsCalculator"
  output_side_packet: "anchors"
  node_options: {
    [type.googleapis.com/mediapipe.SsdAnchorsCalculatorOptions] {
      num_layers: 5
      min_scale: 0.1171875
      max_scale: 0.75
      input_size_height: 256
      input_size_width: 256
      anchor_offset_x: 0.5
      anchor_offset_y: 0.5
      strides: 8
      strides: 16
      strides: 32
      strides: 32
      strides: 32
      aspect_ratios: 1.0
      fixed_anchor_size: true
    }
  }
}

# Decodes the detection tensors generated by the TensorFlow Lite model, based on
# the SSD anchors and the specification in the options, into a vector of
# detections. Each detection describes a detected object.
node {
  calculator: "TfLiteTensorsToDetectionsCalculator"
  input_stream: "TENSORS_GPU:detection_tensors"
  input_side_packet: "ANCHORS:anchors"
  output_stream: "DETECTIONS:detections"
  node_options: {
    [type.googleapis.com/mediapipe.TfLiteTensorsToDetectionsCalculatorOptions] {
      num_classes: 1
      num_boxes: 2944
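      # Consistency check: the anchor grid above has (256/8)^2 + (256/16)^2
      # + 3 * (256/32)^2 = 1024 + 256 + 3 * 64 = 1472 cells, so 2944 boxes
      # corresponds to 2 anchors per cell.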
      num_coords: 18
      box_coord_offset: 0
      keypoint_coord_offset: 4
      num_keypoints: 7
      num_values_per_keypoint: 2
      sigmoid_score: true
      score_clipping_thresh: 100.0
      reverse_output_order: true

      x_scale: 256.0
      y_scale: 256.0
      h_scale: 256.0
      w_scale: 256.0
      min_score_thresh: 0.7
    }
  }
}

# Performs non-max suppression to remove excessive detections.
node {
  calculator: "NonMaxSuppressionCalculator"
  input_stream: "detections"
  output_stream: "filtered_detections"
  node_options: {
    [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] {
      min_suppression_threshold: 0.3
      overlap_type: INTERSECTION_OVER_UNION
      algorithm: WEIGHTED
      return_empty_detections: true
    }
  }
}

# Maps detection label IDs to the corresponding label text ("Palm"). The label
# map is provided in the label_map_path option.
node {
  calculator: "DetectionLabelIdToTextCalculator"
  input_stream: "filtered_detections"
  output_stream: "labeled_detections"
  node_options: {
    [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] {
      label_map_path: "mediapipe/models/palm_detection_labelmap.txt"
    }
  }
}

# Adjusts detection locations (already normalized to [0.f, 1.f]) on the
# letterboxed image (after image transformation with the FIT scale mode) to the
# corresponding locations on the same image with the letterbox removed (the
# input image to the graph before image transformation).
node {
  calculator: "DetectionLetterboxRemovalCalculator"
  input_stream: "DETECTIONS:labeled_detections"
  input_stream: "LETTERBOX_PADDING:letterbox_padding"
  output_stream: "DETECTIONS:palm_detections"
}

# Extracts image size from the input images.
node {
  calculator: "ImagePropertiesCalculator"
  input_stream: "IMAGE_GPU:input_video"
  output_stream: "SIZE:image_size"
}

# Converts each palm detection into a rectangle (normalized by image size)
# that encloses the palm and is rotated such that the line connecting the center
# of the wrist and the MCP of the middle finger is aligned with the Y-axis of
# the rectangle.
node {
  calculator: "DetectionsToRectsCalculator"
  input_stream: "DETECTIONS:palm_detections"
  input_stream: "IMAGE_SIZE:image_size"
  output_stream: "NORM_RECTS:palm_rects"
  node_options: {
    [type.googleapis.com/mediapipe.DetectionsToRectsCalculatorOptions] {
      rotation_vector_start_keypoint_index: 0  # Center of wrist.
      rotation_vector_end_keypoint_index: 2  # MCP of middle finger.
      rotation_vector_target_angle_degrees: 90
      output_zero_rect_for_empty_detections: true
    }
  }
}

# Expands and shifts the rectangle that contains the palm so that it's likely
# to cover the entire hand.
node {
  calculator: "RectTransformationCalculator"
  input_stream: "NORM_RECTS:palm_rects"
  input_stream: "IMAGE_SIZE:image_size"
  output_stream: "hand_rects_from_palm_detections"
  node_options: {
    [type.googleapis.com/mediapipe.RectTransformationCalculatorOptions] {
      scale_x: 2.6
      scale_y: 2.6
      shift_y: -0.5
      square_long: true
    }
  }
}

# Clips the size of the input vector to the provided max_vec_size. This
# determines the maximum number of hand instances this graph outputs.
# Note that the performance gain of clipping detections earlier in this graph is
# minimal because NMS will minimize overlapping detections and the number of
# detections isn't expected to exceed 5-10.
node {
  calculator: "ClipNormalizedRectVectorSizeCalculator"
  input_stream: "hand_rects_from_palm_detections"
  output_stream: "clipped_hand_rects_from_palm_detections"
  node_options: {
    [type.googleapis.com/mediapipe.ClipVectorSizeCalculatorOptions] {
      # This value can be changed to support tracking an arbitrary number of
      # hands. Please also remember to modify min_size in
      # CollectionHasMinSizeCalculatorOptions in
      # mediapipe/graphs/hand_tracking/multi_hand_tracking_mobile.pbtxt and
      # mediapipe/graphs/hand_tracking/multi_hand_tracking_desktop_live.pbtxt.
      max_vec_size: 2
    }
  }
}
```

### Multi-Hand Landmark Subgraph

![multi_hand_landmark_subgraph](images/mobile/multi_hand_landmark_subgraph.png)

This graph accepts as input a vector of `NormalizedRect` objects, corresponding
to the region of each hand instance in the input image. For each
`NormalizedRect` object, the graph runs the existing `HandLandmark` subgraph and
collects the outputs of this subgraph into vectors. This is enabled by the
`BeginLoop` and `EndLoop` calculators.

The `BeginLoop` calculator accepts as input a packet containing an iterable
collection of elements. This calculator is templatized (see
[begin_loop_calculator.h](https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core/begin_loop_calculator.h)).
If the input packet arrives at timestamp `ts`, this calculator outputs each
element in the collection at a fake timestamp `internal_ts`. At the end of the
collection, the calculator outputs the arrival timestamp `ts` in the output
stream tagged with `BATCH_END`.

The nodes between the `BeginLoop` calculator and the corresponding `EndLoop`
calculator process individual packets at the fake timestamps `internal_ts`.
After each element is processed, it is sent to the `EndLoop` calculator (see
[end_loop_calculator.h](https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core/end_loop_calculator.h)),
which collects these elements in an output collection. The `EndLoop` calculator
listens for packets from the `BATCH_END` output stream of the `BeginLoop`
calculator. When the `BATCH_END` packet carrying the real timestamp `ts`
arrives, the `EndLoop` calculator outputs a packet containing the collection of
processed elements at the real timestamp `ts`.
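
The following self-contained C++ sketch models this scatter/gather pattern with
plain structs. It is a toy model, not the MediaPipe API: `Packet`, `BeginLoop`,
and `EndLoop` below are simplified stand-ins for the templatized calculators,
but the timestamp bookkeeping follows the description above.

```c++
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

using Timestamp = int64_t;

template <typename T>
struct Packet {
  T value;
  Timestamp ts;
};

// BeginLoop: scatter each element at a fake, strictly increasing timestamp,
// and report the real arrival timestamp on a BATCH_END packet.
template <typename T>
void BeginLoop(const Packet<std::vector<T>>& iterable,
               std::vector<Packet<T>>* items, Packet<Timestamp>* batch_end,
               Timestamp* internal_ts) {
  for (const T& element : iterable.value) {
    items->push_back({element, (*internal_ts)++});
  }
  *batch_end = {iterable.ts, *internal_ts};  // Carries the real timestamp ts.
}

// EndLoop: gather the processed items, then emit the whole collection at the
// real timestamp carried by the BATCH_END packet.
template <typename T>
Packet<std::vector<T>> EndLoop(const std::vector<Packet<T>>& items,
                               const Packet<Timestamp>& batch_end) {
  std::vector<T> collected;
  for (const auto& item : items) collected.push_back(item.value);
  return {collected, batch_end.value};
}

int main() {
  Timestamp internal_ts = 0;
  Packet<std::vector<std::string>> rects{{"hand_0", "hand_1"}, /*ts=*/42};
  std::vector<Packet<std::string>> items;
  Packet<Timestamp> batch_end;
  BeginLoop(rects, &items, &batch_end, &internal_ts);
  // ... per-element processing would happen here at the fake timestamps ...
  auto result = EndLoop(items, batch_end);
  std::cout << result.value.size() << " elements at ts " << result.ts << "\n";
}
```

Running the sketch prints `2 elements at ts 42`: the per-hand packets travel
through the loop body at fake timestamps 0 and 1, while the gathered vector is
re-emitted at the original timestamp 42.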

In the multi-hand landmark subgraph, the `EndLoop` calculators collect the
output vector of hand landmarks per hand instance, the boolean values indicating
the presence of each hand, and the `NormalizedRect` objects corresponding to the
regions surrounding each hand into vectors.

Finally, based on the hand presence boolean value, the graph filters the
collections of hand landmarks and `NormalizedRect` objects corresponding to each
hand instance.
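
Conceptually, this filtering is an elementwise masking of parallel vectors with
the presence flags. A minimal C++ sketch of that operation, again using plain
STL types rather than the calculator API:

```c++
#include <vector>

// Keeps items[i] only when presence[i] is true, mirroring how the filtering
// calculators drop entries for hands that are no longer present.
template <typename T>
std::vector<T> FilterByPresence(const std::vector<T>& items,
                                const std::vector<bool>& presence) {
  std::vector<T> kept;
  for (size_t i = 0; i < items.size() && i < presence.size(); ++i) {
    if (presence[i]) kept.push_back(items[i]);
  }
  return kept;
}
```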

[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/subgraphs/multi_hand_landmark.pbtxt)

```bash
# MediaPipe hand landmark localization subgraph.

type: "MultiHandLandmarkSubgraph"

input_stream: "IMAGE:input_video"
# A vector of NormalizedRect, one for each hand detected.
input_stream: "NORM_RECTS:multi_hand_rects"
# A vector of NormalizedLandmarks, one set for each hand.
output_stream: "LANDMARKS:filtered_multi_hand_landmarks"
# A vector of NormalizedRect, one for each hand.
output_stream: "NORM_RECTS:filtered_multi_hand_rects_for_next_frame"

# Outputs each element of multi_hand_rects at a fake timestamp for the rest
# of the graph to process. Clones the input_video packet for each
# single_hand_rect at the fake timestamp. At the end of the loop,
# outputs the BATCH_END timestamp for downstream calculators to inform them
# that all elements in the vector have been processed.
node {
  calculator: "BeginLoopNormalizedRectCalculator"
  input_stream: "ITERABLE:multi_hand_rects"
  input_stream: "CLONE:input_video"
  output_stream: "ITEM:single_hand_rect"
  output_stream: "CLONE:input_video_cloned"
  output_stream: "BATCH_END:single_hand_rect_timestamp"
}

node {
  calculator: "HandLandmarkSubgraph"
  input_stream: "IMAGE:input_video_cloned"
  input_stream: "NORM_RECT:single_hand_rect"
  output_stream: "LANDMARKS:single_hand_landmarks"
  output_stream: "NORM_RECT:single_hand_rect_from_landmarks"
  output_stream: "PRESENCE:single_hand_presence"
}

# Collects the boolean presence value for each single hand into a vector. Upon
# receiving the BATCH_END timestamp, outputs a vector of boolean values at the
# BATCH_END timestamp.
node {
  calculator: "EndLoopBooleanCalculator"
  input_stream: "ITEM:single_hand_presence"
  input_stream: "BATCH_END:single_hand_rect_timestamp"
  output_stream: "ITERABLE:multi_hand_presence"
}

# Collects a set of landmarks for each hand into a vector. Upon receiving the
# BATCH_END timestamp, outputs the vector of landmarks at the BATCH_END
# timestamp.
node {
  calculator: "EndLoopNormalizedLandmarksVectorCalculator"
  input_stream: "ITEM:single_hand_landmarks"
  input_stream: "BATCH_END:single_hand_rect_timestamp"
  output_stream: "ITERABLE:multi_hand_landmarks"
}

# Collects a NormalizedRect for each hand into a vector. Upon receiving the
# BATCH_END timestamp, outputs the vector of NormalizedRect at the BATCH_END
# timestamp.
node {
  calculator: "EndLoopNormalizedRectCalculator"
  input_stream: "ITEM:single_hand_rect_from_landmarks"
  input_stream: "BATCH_END:single_hand_rect_timestamp"
  output_stream: "ITERABLE:multi_hand_rects_for_next_frame"
}

# Filters the input vector of landmarks based on the hand presence value for
# each hand. If the hand presence for hand #i is false, the set of landmarks
# corresponding to that hand is dropped from the vector.
node {
  calculator: "FilterLandmarksCollectionCalculator"
  input_stream: "ITERABLE:multi_hand_landmarks"
  input_stream: "CONDITION:multi_hand_presence"
  output_stream: "ITERABLE:filtered_multi_hand_landmarks"
}

# Filters the input vector of NormalizedRect based on the hand presence value
# for each hand. If the hand presence for hand #i is false, the NormalizedRect
# corresponding to that hand is dropped from the vector.
node {
  calculator: "FilterNormalizedRectCollectionCalculator"
  input_stream: "ITERABLE:multi_hand_rects_for_next_frame"
  input_stream: "CONDITION:multi_hand_presence"
  output_stream: "ITERABLE:filtered_multi_hand_rects_for_next_frame"
}
```

### Multi-Hand Renderer Subgraph

![multi_hand_renderer_gpu_subgraph](images/mobile/multi_hand_renderer_gpu_subgraph.png)

This graph also uses the `BeginLoop` and `EndLoop` calculators, this time to
iteratively convert each hand instance's set of hand landmarks into a
corresponding `RenderData` object.

[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/subgraphs/multi_hand_renderer_gpu.pbtxt)

```bash
# MediaPipe multi-hand tracking rendering subgraph.

type: "MultiHandRendererSubgraph"

input_stream: "IMAGE:input_image"
# A vector of NormalizedLandmarks, one for each hand.
input_stream: "LANDMARKS:multi_hand_landmarks"
# A vector of NormalizedRect, one for each hand.
input_stream: "NORM_RECTS:0:multi_palm_rects"
# A vector of NormalizedRect, one for each hand.
input_stream: "NORM_RECTS:1:multi_hand_rects"
# A vector of Detection, one for each hand.
input_stream: "DETECTIONS:palm_detections"
output_stream: "IMAGE:output_image"

# Converts detections to drawing primitives for annotation overlay.
node {
  calculator: "DetectionsToRenderDataCalculator"
  input_stream: "DETECTIONS:palm_detections"
  output_stream: "RENDER_DATA:detection_render_data"
  node_options: {
    [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] {
      thickness: 4.0
      color { r: 0 g: 255 b: 0 }
    }
  }
}

# Converts normalized rects to drawing primitives for annotation overlay.
node {
  calculator: "RectToRenderDataCalculator"
  input_stream: "NORM_RECTS:multi_hand_rects"
  output_stream: "RENDER_DATA:multi_hand_rects_render_data"
  node_options: {
    [type.googleapis.com/mediapipe.RectToRenderDataCalculatorOptions] {
      filled: false
      color { r: 255 g: 0 b: 0 }
      thickness: 4.0
    }
  }
}

# Converts normalized rects to drawing primitives for annotation overlay.
node {
  calculator: "RectToRenderDataCalculator"
  input_stream: "NORM_RECTS:multi_palm_rects"
  output_stream: "RENDER_DATA:multi_palm_rects_render_data"
  node_options: {
    [type.googleapis.com/mediapipe.RectToRenderDataCalculatorOptions] {
      filled: false
      color { r: 125 g: 0 b: 122 }
      thickness: 4.0
    }
  }
}

# Outputs each element of multi_hand_landmarks at a fake timestamp for the rest
# of the graph to process. At the end of the loop, outputs the BATCH_END
# timestamp for downstream calculators to inform them that all elements in the
# vector have been processed.
node {
  calculator: "BeginLoopNormalizedLandmarksVectorCalculator"
  input_stream: "ITERABLE:multi_hand_landmarks"
  output_stream: "ITEM:single_hand_landmarks"
  output_stream: "BATCH_END:landmark_timestamp"
}

# Converts landmarks to drawing primitives for annotation overlay.
node {
  calculator: "LandmarksToRenderDataCalculator"
  input_stream: "NORM_LANDMARKS:single_hand_landmarks"
  output_stream: "RENDER_DATA:single_hand_landmark_render_data"
  node_options: {
    [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
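      # Each consecutive pair of values below defines one connection to draw.
      # The indices follow the 21-point hand landmark model: 0 = wrist,
      # 1-4 thumb, 5-8 index finger, 9-12 middle finger, 13-16 ring finger,
      # 17-20 pinky.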
      landmark_connections: 0
      landmark_connections: 1
      landmark_connections: 1
      landmark_connections: 2
      landmark_connections: 2
      landmark_connections: 3
      landmark_connections: 3
      landmark_connections: 4
      landmark_connections: 0
      landmark_connections: 5
      landmark_connections: 5
      landmark_connections: 6
      landmark_connections: 6
      landmark_connections: 7
      landmark_connections: 7
      landmark_connections: 8
      landmark_connections: 5
      landmark_connections: 9
      landmark_connections: 9
      landmark_connections: 10
      landmark_connections: 10
      landmark_connections: 11
      landmark_connections: 11
      landmark_connections: 12
      landmark_connections: 9
      landmark_connections: 13
      landmark_connections: 13
      landmark_connections: 14
      landmark_connections: 14
      landmark_connections: 15
      landmark_connections: 15
      landmark_connections: 16
      landmark_connections: 13
      landmark_connections: 17
      landmark_connections: 0
      landmark_connections: 17
      landmark_connections: 17
      landmark_connections: 18
      landmark_connections: 18
      landmark_connections: 19
      landmark_connections: 19
      landmark_connections: 20
      landmark_color { r: 255 g: 0 b: 0 }
      connection_color { r: 0 g: 255 b: 0 }
      thickness: 4.0
    }
  }
}

# Collects a RenderData object for each hand into a vector. Upon receiving the
# BATCH_END timestamp, outputs the vector of RenderData at the BATCH_END
# timestamp.
node {
  calculator: "EndLoopRenderDataCalculator"
  input_stream: "ITEM:single_hand_landmark_render_data"
  input_stream: "BATCH_END:landmark_timestamp"
  output_stream: "ITERABLE:multi_hand_landmarks_render_data"
}

# Draws annotations and overlays them on top of the input images. Consumes
# a vector of RenderData objects and draws each of them on the input frame.
node {
  calculator: "AnnotationOverlayCalculator"
  input_stream: "INPUT_FRAME_GPU:input_image"
  input_stream: "detection_render_data"
  input_stream: "multi_hand_rects_render_data"
  input_stream: "multi_palm_rects_render_data"
  input_stream: "VECTOR:0:multi_hand_landmarks_render_data"
  output_stream: "OUTPUT_FRAME_GPU:output_image"
}
```

@ -35,10 +35,9 @@ $ bazel build -c opt \
# INFO: 2675 processes: 2673 linux-sandbox, 2 local.
# INFO: Build completed successfully, 2807 total actions

$ export GLOG_logtostderr=1
# Replace <input video path> and <output video path>.
# You can find a test video in mediapipe/examples/desktop/object_detection.
$ bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tensorflow \
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tensorflow \
  --calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tensorflow_graph.pbtxt \
  --input_side_packets=input_video_path=<input video path>,output_video_path=<output video path>
```

@ -55,7 +54,7 @@ below and paste it into
# MediaPipe graph that performs object detection on desktop with TensorFlow
# on CPU.
# Used in the example in
# mediapipie/examples/desktop/object_detection:object_detection_tensorflow.
# mediapipe/examples/desktop/object_detection:object_detection_tensorflow.

# Decodes an input video file into images and a video header.
node {

@ -200,10 +199,9 @@ $ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \
# INFO: 711 processes: 710 linux-sandbox, 1 local.
# INFO: Build completed successfully, 734 total actions

$ export GLOG_logtostderr=1
# Replace <input video path> and <output video path>.
# You can find a test video in mediapipe/examples/desktop/object_detection.
$ bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tflite \
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tflite \
  --calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tflite_graph.pbtxt \
  --input_side_packets=input_video_path=<input video path>,output_video_path=<output video path>
```

@ -220,14 +218,11 @@ $ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \
# It should print:
#Target //mediapipe/examples/desktop/object_detection:object_detection_cpu up-to-date:
#  bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_cpu
#INFO: Elapsed time: 16.020s, Forge stats: 13001/13003 actions cached, 2.1s CPU used, 0.0s queue time, 89.0 MB ObjFS output (novel bytes: 88.0 MB), 0.0 MB local output, Critical Path: 10.01s, Remote (41.42% of the time): [queue: 0.00%, setup: 4.21%, process: 12.48%]
#INFO: Streaming build results to: http://sponge2/1824d4cc-ba63-4350-bdc0-aacbd45b902b
#INFO: Build completed successfully, 12154 total actions

$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on.
# Any errors are likely due to your webcam not being accessible.
$ bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_cpu \
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_cpu \
  --calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_live.pbtxt
```

@ -243,7 +238,7 @@ below and paste it into
# MediaPipe graph that performs object detection on desktop with TensorFlow Lite
# on CPU.
# Used in the example in
# mediapipie/examples/desktop/object_detection:object_detection_tflite.
# mediapipe/examples/desktop/object_detection:object_detection_tflite.

# max_queue_size limits the number of packets enqueued on any input stream
# by throttling inputs to the graph. This makes the graph only process one

26
mediapipe/docs/web.md

@ -0,0 +1,26 @@
## MediaPipe on the Web

MediaPipe on the Web is an effort to use [WebAssembly](https://webassembly.org/)
to bring MediaPipe graphs, calculators, and related technologies to the web. The
aim is to have all the pieces (ML, rendering, and processing) running directly
in the browser client-side. The official API is under construction, but the core
technology has been proven effective, and we can already show interactive
cross-platform demos using your live webcam.

![image](https://mediapipe.dev/assets/img/hand_tracking_web.gif)

### Hand Tracking (with and without SIMD support)

For [Chrome Developer Summit 2019](https://developer.chrome.com/devsummit/), we
used this technology to showcase the potential for performance improvements
using Chrome's experimental [WebAssembly SIMD](https://github.com/WebAssembly/simd)
support. Below are two different versions of the
[MediaPipe Hand Tracking Example](https://mediapipe.readthedocs.io/en/latest/hand_tracking_desktop.html)
running on the web:

1. A WebAssembly MVP [demo](https://mediapipe.page.link/cds-ht) running at around 5-8 frames per second on Desktop Chrome.

2. A WebAssembly SIMD [demo](https://mediapipe.page.link/cds-ht-simd) running at around 15-18 frames per second on *Canary* Chrome for Desktop, which must additionally be launched with the option `--js-flags="--experimental-wasm-simd"`.

NOTE: This page is a work-in-progress. More to come soon!

@ -1,9 +1,11 @@
## Extracting Video Features for YouTube-8M Challenge
# Feature Extraction and Model Inference for YouTube-8M Challenge

MediaPipe is a useful and general framework for media processing that can assist
with research, development, and deployment of ML models. This example focuses on
model development by demonstrating how to prepare training data for the
YouTube-8M Challenge.
model development by demonstrating how to prepare training data and do model
inference for the YouTube-8M Challenge.

## Extracting Video Features for YouTube-8M Challenge

[Youtube-8M Challenge](https://www.kaggle.com/c/youtube8m-2019) is an annual
video classification challenge hosted by Google. Over the last two years, the

@ -29,14 +31,14 @@ videos.

### Steps to run the YouTube-8M feature extraction graph

1. Checkout the mediapipe repository
1. Checkout the mediapipe repository.

    ```bash
    git clone https://github.com/google/mediapipe.git
    cd mediapipe
    ```

2. Download the PCA and model data
2. Download the PCA and model data.

    ```bash
    mkdir /tmp/mediapipe

@ -49,7 +51,7 @@ videos.
    tar -xvf /tmp/mediapipe/inception-2015-12-05.tgz
    ```

3. Get the VGGish frozen graph
3. Get the VGGish frozen graph.

    Note: To run step 3 and step 4, you must have Python 2.7 or 3.5+ with the
    TensorFlow 1.14+ package installed.
@ -60,24 +62,112 @@ videos.
    python -m mediapipe.examples.desktop.youtube8m.generate_vggish_frozen_graph
    ```

4. Generate a MediaSequence metadata from the input video
4. Generate a MediaSequence metadata from the input video.

    Note: the output file is /tmp/mediapipe/metadata.tfrecord
    Note: the output file is /tmp/mediapipe/metadata.pb

    ```bash
    # change clip_end_time_sec to match the length of your video.
    python -m mediapipe.examples.desktop.youtube8m.generate_input_sequence_example \
      --path_to_input_video=/absolute/path/to/the/local/video/file
      --path_to_input_video=/absolute/path/to/the/local/video/file \
      --clip_end_time_sec=120
    ```

5. Run the MediaPipe binary to extract the features
5. Run the MediaPipe binary to extract the features.

    ```bash
    bazel build -c opt \
      --define MEDIAPIPE_DISABLE_GPU=1 --define no_aws_support=true \
      mediapipe/examples/desktop/youtube8m:extract_yt8m_features

    ./bazel-bin/mediapipe/examples/desktop/youtube8m/extract_yt8m_features
    GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/youtube8m/extract_yt8m_features \
      --calculator_graph_config_file=mediapipe/graphs/youtube8m/feature_extraction.pbtxt \
      --input_side_packets=input_sequence_example=/tmp/mediapipe/metadata.tfrecord \
      --output_side_packets=output_sequence_example=/tmp/mediapipe/output.tfrecord
      --input_side_packets=input_sequence_example=/tmp/mediapipe/metadata.pb \
      --output_side_packets=output_sequence_example=/tmp/mediapipe/features.pb
    ```

6. [Optional] Read the features.pb in Python.

    ```python
    import tensorflow as tf

    sequence_example = open('/tmp/mediapipe/features.pb', 'rb').read()
    print(tf.train.SequenceExample.FromString(sequence_example))
    ```

## Model Inference for YouTube-8M Challenge

MediaPipe can help you do model inference for the YouTube-8M Challenge with both
local videos and the YouTube-8M dataset. To visualize
[the graph for local videos](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/youtube8m/local_video_model_inference.pbtxt)
and
[the graph for the YouTube-8M dataset](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/youtube8m/yt8m_dataset_model_inference.pbtxt),
copy the text specification of the graph and paste it into
[MediaPipe Visualizer](https://viz.mediapipe.dev/). We use the baseline model
[(model card)](https://drive.google.com/file/d/1xTCi9-Nm9dt2KIk8WR0dDFrIssWawyXy/view)
in our example, but the model inference pipeline is highly customizable. You
are welcome to add new calculators or use your own machine learning models to do
the inference for both local videos and the dataset.

### Steps to run the YouTube-8M model inference graph with Web Interface

1. Copy the baseline model
   [(model card)](https://drive.google.com/file/d/1xTCi9-Nm9dt2KIk8WR0dDFrIssWawyXy/view)
   to your local machine.

    ```bash
    curl -o /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz data.yt8m.org/models/baseline/saved_model.tar.gz

    tar -xvf /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz -C /tmp/mediapipe
    ```

2. Build the inference binary.

    ```bash
    bazel build -c opt --define='MEDIAPIPE_DISABLE_GPU=1' \
      mediapipe/examples/desktop/youtube8m:model_inference
    ```

3. Run the Python web server.

    Note: requires absl-py (`pip install absl-py`).

    ```bash
    python mediapipe/examples/desktop/youtube8m/viewer/server.py --root `pwd`
    ```

    Navigate to localhost:8008 in a web browser.
    [Here](https://drive.google.com/file/d/19GSvdAAuAlACpBhHOaqMWZ_9p8bLUYKh/view?usp=sharing)
    is a demo video showing the steps to use this web application. Also please
    read
    [youtube8m/README.md](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/youtube8m/README.md)
    if you prefer to run the underlying model_inference binary on the command
    line.

### Steps to run the YouTube-8M model inference graph with a local video

1. Make sure you have the features.pb from the feature extraction pipeline.

2. Copy the baseline model
   [(model card)](https://drive.google.com/file/d/1xTCi9-Nm9dt2KIk8WR0dDFrIssWawyXy/view)
   to your local machine.

    ```bash
    curl -o /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz data.yt8m.org/models/baseline/saved_model.tar.gz

    tar -xvf /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz -C /tmp/mediapipe
    ```

3. Build and run the inference binary.

    ```bash
    bazel build -c opt --define='MEDIAPIPE_DISABLE_GPU=1' \
      mediapipe/examples/desktop/youtube8m:model_inference

    # segment_size is the length, in seconds, of each window of frames.
    # overlap is the number of seconds adjacent segments share.
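    # For example, segment_size=5 with overlap=4 starts a new 5-second
    # segment every 1 second (5 - 4 = 1).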
    GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/youtube8m/model_inference \
      --calculator_graph_config_file=mediapipe/graphs/youtube8m/local_video_model_inference.pbtxt \
      --input_side_packets=input_sequence_example_path=/tmp/mediapipe/features.pb,input_video_path=/absolute/path/to/the/local/video/file,output_video_path=/tmp/mediapipe/annotated_video.mp4,segment_size=5,overlap=4
    ```

4. View the annotated video.

@ -0,0 +1,33 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="com.google.mediapipe.apps.multihandtrackinggpu">

  <uses-sdk
      android:minSdkVersion="21"
      android:targetSdkVersion="27" />

  <!-- For using the camera -->
  <uses-permission android:name="android.permission.CAMERA" />
  <uses-feature android:name="android.hardware.camera" />
  <uses-feature android:name="android.hardware.camera.autofocus" />
  <!-- For MediaPipe -->
  <uses-feature android:glEsVersion="0x00020000" android:required="true" />

  <application
      android:allowBackup="true"
      android:label="@string/app_name"
      android:supportsRtl="true"
      android:theme="@style/AppTheme">
    <activity
        android:name=".MainActivity"
        android:exported="true"
        android:screenOrientation="portrait">
      <intent-filter>
        <action android:name="android.intent.action.MAIN" />
        <category android:name="android.intent.category.LAUNCHER" />
      </intent-filter>
    </activity>
  </application>

</manifest>

@ -0,0 +1,103 @@
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//visibility:private"])

cc_binary(
    name = "libmediapipe_jni.so",
    linkshared = 1,
    linkstatic = 1,
    deps = [
        "//mediapipe/graphs/hand_tracking:multi_hand_mobile_calculators",
        "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
    ],
)

cc_library(
    name = "mediapipe_jni_lib",
    srcs = [":libmediapipe_jni.so"],
    alwayslink = 1,
)

# Maps the binary graph to an alias (e.g., the app name) for convenience so that the alias can be
# easily incorporated into the app via, for example,
# MainActivity.BINARY_GRAPH_NAME = "appname.binarypb".
genrule(
    name = "binary_graph",
    srcs = ["//mediapipe/graphs/hand_tracking:multi_hand_tracking_mobile_gpu_binary_graph"],
    outs = ["multihandtrackinggpu.binarypb"],
    cmd = "cp $< $@",
)

# To use the 3D model instead of the default 2D model, add "--define 3D=true" to the
# bazel build command.
config_setting(
    name = "use_3d_model",
    define_values = {
        "3D": "true",
    },
)

genrule(
    name = "model",
    srcs = select({
        "//conditions:default": ["//mediapipe/models:hand_landmark.tflite"],
        ":use_3d_model": ["//mediapipe/models:hand_landmark_3d.tflite"],
    }),
    outs = ["hand_landmark.tflite"],
    cmd = "cp $< $@",
)

android_library(
    name = "mediapipe_lib",
    srcs = glob(["*.java"]),
    assets = [
        ":binary_graph",
        ":model",
        "//mediapipe/models:palm_detection.tflite",
        "//mediapipe/models:palm_detection_labelmap.txt",
    ],
    assets_dir = "",
    manifest = "AndroidManifest.xml",
    resource_files = glob(["res/**"]),
    deps = [
        ":mediapipe_jni_lib",
        "//mediapipe/java/com/google/mediapipe/components:android_camerax_helper",
        "//mediapipe/java/com/google/mediapipe/components:android_components",
        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
        "//mediapipe/java/com/google/mediapipe/glutil",
        "//third_party:androidx_appcompat",
        "//third_party:androidx_constraint_layout",
        "//third_party:androidx_legacy_support_v4",
        "//third_party:androidx_material",
        "//third_party:androidx_recyclerview",
        "//third_party:opencv",
        "@androidx_concurrent_futures//jar",
        "@androidx_lifecycle//jar",
        "@com_google_code_findbugs//jar",
        "@com_google_guava_android//jar",
    ],
)

android_binary(
    name = "multihandtrackinggpu",
    manifest = "AndroidManifest.xml",
    manifest_values = {"applicationId": "com.google.mediapipe.apps.multihandtrackinggpu"},
    multidex = "native",
    deps = [
        ":mediapipe_lib",
    ],
)

@ -0,0 +1,167 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.apps.multihandtrackinggpu;

import android.graphics.SurfaceTexture;
import android.os.Bundle;
import androidx.appcompat.app.AppCompatActivity;
import android.util.Size;
import android.view.SurfaceHolder;
import android.view.SurfaceView;
import android.view.View;
import android.view.ViewGroup;
import com.google.mediapipe.components.CameraHelper;
import com.google.mediapipe.components.CameraXPreviewHelper;
import com.google.mediapipe.components.ExternalTextureConverter;
import com.google.mediapipe.components.FrameProcessor;
import com.google.mediapipe.components.PermissionHelper;
import com.google.mediapipe.framework.AndroidAssetUtil;
import com.google.mediapipe.glutil.EglManager;

/** Main activity of MediaPipe example apps. */
public class MainActivity extends AppCompatActivity {
  private static final String TAG = "MainActivity";

  private static final String BINARY_GRAPH_NAME = "multihandtrackinggpu.binarypb";
  private static final String INPUT_VIDEO_STREAM_NAME = "input_video";
  private static final String OUTPUT_VIDEO_STREAM_NAME = "output_video";
  private static final CameraHelper.CameraFacing CAMERA_FACING = CameraHelper.CameraFacing.FRONT;

  // Flips the camera-preview frames vertically before sending them into FrameProcessor to be
  // processed in a MediaPipe graph, and flips the processed frames back when they are displayed.
  // This is needed because OpenGL represents images assuming the image origin is at the bottom-left
  // corner, whereas MediaPipe in general assumes the image origin is at top-left.
  private static final boolean FLIP_FRAMES_VERTICALLY = true;

  static {
    // Load all native libraries needed by the app.
    System.loadLibrary("mediapipe_jni");
    System.loadLibrary("opencv_java4");
  }

  // {@link SurfaceTexture} where the camera-preview frames can be accessed.
  private SurfaceTexture previewFrameTexture;
  // {@link SurfaceView} that displays the camera-preview frames processed by a MediaPipe graph.
  private SurfaceView previewDisplayView;

  // Creates and manages an {@link EGLContext}.
  private EglManager eglManager;
  // Sends camera-preview frames into a MediaPipe graph for processing, and displays the processed
  // frames onto a {@link Surface}.
  private FrameProcessor processor;
  // Converts the GL_TEXTURE_EXTERNAL_OES texture from Android camera into a regular texture to be
  // consumed by {@link FrameProcessor} and the underlying MediaPipe graph.
  private ExternalTextureConverter converter;

  // Handles camera access via the {@link CameraX} Jetpack support library.
  private CameraXPreviewHelper cameraHelper;

  @Override
  protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_main);

    previewDisplayView = new SurfaceView(this);
    setupPreviewDisplayView();

    // Initialize asset manager so that MediaPipe native libraries can access the app assets, e.g.,
    // binary graphs.
    AndroidAssetUtil.initializeNativeAssetManager(this);

    eglManager = new EglManager(null);
    processor =
        new FrameProcessor(
            this,
            eglManager.getNativeContext(),
            BINARY_GRAPH_NAME,
            INPUT_VIDEO_STREAM_NAME,
            OUTPUT_VIDEO_STREAM_NAME);
    processor.getVideoSurfaceOutput().setFlipY(FLIP_FRAMES_VERTICALLY);

    PermissionHelper.checkAndRequestCameraPermissions(this);
  }

  @Override
  protected void onResume() {
    super.onResume();
    converter = new ExternalTextureConverter(eglManager.getContext());
    converter.setFlipY(FLIP_FRAMES_VERTICALLY);
    converter.setConsumer(processor);
    if (PermissionHelper.cameraPermissionsGranted(this)) {
      startCamera();
    }
  }

  @Override
  protected void onPause() {
    super.onPause();
    converter.close();
  }

  @Override
  public void onRequestPermissionsResult(
      int requestCode, String[] permissions, int[] grantResults) {
    super.onRequestPermissionsResult(requestCode, permissions, grantResults);
    PermissionHelper.onRequestPermissionsResult(requestCode, permissions, grantResults);
  }

  private void setupPreviewDisplayView() {
    previewDisplayView.setVisibility(View.GONE);
    ViewGroup viewGroup = findViewById(R.id.preview_display_layout);
    viewGroup.addView(previewDisplayView);

    previewDisplayView
        .getHolder()
        .addCallback(
            new SurfaceHolder.Callback() {
              @Override
              public void surfaceCreated(SurfaceHolder holder) {
                processor.getVideoSurfaceOutput().setSurface(holder.getSurface());
              }

              @Override
              public void surfaceChanged(SurfaceHolder holder, int format, int width, int height) {
                // (Re-)Compute the ideal size of the camera-preview display (the area that the
                // camera-preview frames get rendered onto, potentially with scaling and rotation)
                // based on the size of the SurfaceView that contains the display.
                Size viewSize = new Size(width, height);
                Size displaySize = cameraHelper.computeDisplaySizeFromViewSize(viewSize);

                // Connect the converter to the camera-preview frames as its input (via
                // previewFrameTexture), and configure the output width and height as the computed
                // display size.
                converter.setSurfaceTextureAndAttachToGLContext(
                    previewFrameTexture, displaySize.getWidth(), displaySize.getHeight());
              }

              @Override
              public void surfaceDestroyed(SurfaceHolder holder) {
                processor.getVideoSurfaceOutput().setSurface(null);
              }
            });
  }

  private void startCamera() {
    cameraHelper = new CameraXPreviewHelper();
    cameraHelper.setOnCameraStartedListener(
        surfaceTexture -> {
          previewFrameTexture = surfaceTexture;
          // Make the display view visible to start showing the preview. This triggers the
          // SurfaceHolder.Callback added to (the holder of) previewDisplayView.
          previewDisplayView.setVisibility(View.VISIBLE);
        });
    cameraHelper.startCamera(this, CAMERA_FACING, /*surfaceTexture=*/ null);
  }
}

@ -0,0 +1,20 @@
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent">

  <FrameLayout
      android:id="@+id/preview_display_layout"
      android:layout_width="fill_parent"
      android:layout_height="fill_parent"
      android:layout_weight="1">
    <TextView
        android:id="@+id/no_camera_access_view"
        android:layout_height="fill_parent"
        android:layout_width="fill_parent"
        android:gravity="center"
        android:text="@string/no_camera_access" />
  </FrameLayout>
</androidx.constraintlayout.widget.ConstraintLayout>

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
  <color name="colorPrimary">#008577</color>
  <color name="colorPrimaryDark">#00574B</color>
  <color name="colorAccent">#D81B60</color>
</resources>
|