Merge pull request #1 from google/master

Multi-hand Tracking
Ali Zahid Raja 2019-11-17 12:32:58 +05:00 committed by GitHub
commit ac96722a54
200 changed files with 14999 additions and 732 deletions

View File

@ -3,7 +3,7 @@
# Basic build settings
build --jobs 128
build --define='absl=1'
build --cxxopt='-std=c++11'
build --cxxopt='-std=c++14'
build --copt='-Wno-sign-compare'
build --copt='-Wno-unused-function'
build --copt='-Wno-uninitialized'

View File

@ -37,10 +37,18 @@ A web-based visualizer is hosted on [viz.mediapipe.dev](https://viz.mediapipe.de
* [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General community discussion around MediaPipe
## Publications
* [On-Device, Real-Time Hand Tracking with MediaPipe](https://ai.googleblog.com/2019/08/on-device-real-time-hand-tracking-with.html)
* [MediaPipe: A Framework for Building Perception Pipelines](https://arxiv.org/abs/1906.08172)
## Events
[Open sourced at CVPR 2019](https://sites.google.com/corp/view/perception-cv4arvr/mediapipe) on June 17~20 in Long Beach, CA
* [MediaPipe Madrid Meetup, 16 Dec 2019](https://www.meetup.com/Madrid-AI-Developers-Group/events/266329088/)
* [MediaPipe London Meetup, Google 123 Building, 12 Dec 2019](https://www.meetup.com/London-AI-Tech-Talk/events/266329038)
* [ML Conference, Berlin, 11 Dec 2019](https://mlconference.ai/machine-learning-advanced-development/mediapipe-building-real-time-cross-platform-mobile-web-edge-desktop-video-audio-ml-pipelines/)
* [MediaPipe Berlin Meetup, Google Berlin, 11 Dec 2019](https://www.meetup.com/Berlin-AI-Tech-Talk/events/266328794/)
* [The 3rd Workshop on YouTube-8M Large Scale Video Understanding](https://research.google.com/youtube8m/workshop2019/index.html) at ICCV 2019, Seoul, Korea
* [AI DevWorld 2019](https://aidevworld.com) on Oct 10 in San Jose, California
* [Google Industry Workshop at ICIP 2019](http://2019.ieeeicip.org/?action=page4&id=14#Google) [Presentation](https://docs.google.com/presentation/d/e/2PACX-1vRIBBbO_LO9v2YmvbHHEt1cwyqH6EjDxiILjuT0foXy1E7g6uyh4CesB2DkkEwlRDO9_lWfuKMZx98T/pub?start=false&loop=false&delayms=3000&slide=id.g556cc1a659_0_5) on Sept 24 in Taipei, Taiwan
* [Open sourced at CVPR 2019](https://sites.google.com/corp/view/perception-cv4arvr/mediapipe) on June 17~20 in Long Beach, CA
## Alpha Disclaimer
MediaPipe is currently in alpha for v0.6. We are still making breaking API changes and expect to reach a stable API by v1.0.

View File

@ -103,9 +103,9 @@ http_archive(
],
)
# 2019-08-15
_TENSORFLOW_GIT_COMMIT = "67def62936e28f97c16182dfcc467d8d1cae02b4"
_TENSORFLOW_SHA256= "ddd4e3c056e7c0ff2ef29133b30fa62781dfbf8a903e99efb91a02d292fa9562"
# 2019-11-12
_TENSORFLOW_GIT_COMMIT = "a5f9bcd64453ff3d1f64cb4da4786db3d2da7f82"
_TENSORFLOW_SHA256= "f2b6f2ab2ffe63e86eccd3ce4bea6b7197383d726638dfeeebcdc1e7de73f075"
http_archive(
name = "org_tensorflow",
urls = [
@ -114,13 +114,6 @@ http_archive(
],
strip_prefix = "tensorflow-%s" % _TENSORFLOW_GIT_COMMIT,
sha256 = _TENSORFLOW_SHA256,
patches = [
"@//third_party:tensorflow_065c20bf79253257c87bd4614bb9a7fdef015cbb.diff",
"@//third_party:tensorflow_f67fcbefce906cd419e4657f0d41e21019b71abd.diff",
],
patch_args = [
"-p1",
],
)
load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace")
@ -254,18 +247,11 @@ android_sdk_repository(
# iOS basic build deps.
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
git_repository(
http_archive(
name = "build_bazel_rules_apple",
remote = "https://github.com/bazelbuild/rules_apple.git",
tag = "0.18.0",
patches = [
"@//third_party:rules_apple_c0863d0596ae6b769a29fa3fb72ff036444fd249.diff",
],
patch_args = [
"-p1",
],
sha256 = "bdc8e66e70b8a75da23b79f1f8c6207356df07d041d96d2189add7ee0780cf4e",
strip_prefix = "rules_apple-b869b0d3868d78a1d4ffd866ccb304fb68aa12c3",
url = "https://github.com/bazelbuild/rules_apple/archive/b869b0d3868d78a1d4ffd866ccb304fb68aa12c3.tar.gz",
)
load(

View File

@ -113,8 +113,15 @@ class SpectrogramCalculator : public CalculatorBase {
::mediapipe::Status Close(CalculatorContext* cc) override;
private:
Timestamp CurrentOutputTimestamp() {
// Current output timestamp is the *center* of the next frame to be
Timestamp CurrentOutputTimestamp(CalculatorContext* cc) {
if (use_local_timestamp_) {
return cc->InputTimestamp();
}
return CumulativeOutputTimestamp();
}
Timestamp CumulativeOutputTimestamp() {
// Cumulative output timestamp is the *center* of the next frame to be
// emitted, hence delayed by half a window duration compared to relevant
// input timestamp.
return initial_input_timestamp_ +
@ -141,6 +148,7 @@ class SpectrogramCalculator : public CalculatorBase {
const OutputMatrixType postprocess_output_fn(const OutputMatrixType&),
CalculatorContext* cc);
bool use_local_timestamp_;
double input_sample_rate_;
bool pad_final_packet_;
int frame_duration_samples_;
@ -173,6 +181,8 @@ const float SpectrogramCalculator::kLnPowerToDb = 4.342944819032518;
SpectrogramCalculatorOptions spectrogram_options =
cc->Options<SpectrogramCalculatorOptions>();
use_local_timestamp_ = spectrogram_options.use_local_timestamp();
if (spectrogram_options.frame_duration_seconds() <= 0.0) {
::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "Invalid or missing frame_duration_seconds.\n"
@ -351,11 +361,11 @@ template <class OutputMatrixType>
<< "Inconsistent number of spectrogram channels.";
if (allow_multichannel_input_) {
cc->Outputs().Index(0).Add(spectrogram_matrices.release(),
CurrentOutputTimestamp());
CurrentOutputTimestamp(cc));
} else {
cc->Outputs().Index(0).Add(
new OutputMatrixType(spectrogram_matrices->at(0)),
CurrentOutputTimestamp());
CurrentOutputTimestamp(cc));
}
cumulative_completed_frames_ += output_vectors.size();
}

View File

@ -66,4 +66,11 @@ message SpectrogramCalculatorOptions {
// uniformly regardless of output type (i.e., even dBs are multiplied, not
// offset).
optional double output_scale = 7 [default = 1.0];
// If use_local_timestamp is true, the output packet's timestamp is based on
// the last sample of the packet and is inferred from the latest input
// packet's timestamp. If false, the output packet's timestamp is based on
// cumulative timestamping, which is inferred from the initial input
// timestamp and the cumulative number of samples.
optional bool use_local_timestamp = 8 [default = false];
}
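As a usage illustration only (not part of this change), a node that opts into the new per-packet timestamps could be configured as in the sketch below, following the test style used elsewhere in this change; the stream names and the `mediapipe.SpectrogramCalculatorOptions.ext` extension name are assumptions here, not taken from this diff.

```cpp
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

namespace mediapipe {

// Hypothetical node config: with use_local_timestamp enabled, each output
// spectrogram packet reuses the timestamp of the input packet it was computed
// from, instead of the cumulative window-center timestamping.
CalculatorGraphConfig::Node SpectrogramNodeSketch() {
  return ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
    calculator: "SpectrogramCalculator"
    input_stream: "audio_frames"
    output_stream: "spectrogram"
    options {
      [mediapipe.SpectrogramCalculatorOptions.ext] {
        frame_duration_seconds: 0.025
        use_local_timestamp: true
      }
    }
  )");
}

}  // namespace mediapipe
```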

View File

@ -13,12 +13,12 @@
# limitations under the License.
#
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:private"])
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
proto_library(
name = "concatenate_vector_calculator_proto",
srcs = ["concatenate_vector_calculator.proto"],
@ -26,6 +26,13 @@ proto_library(
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "dequantize_byte_array_calculator_proto",
srcs = ["dequantize_byte_array_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "packet_cloner_calculator_proto",
srcs = ["packet_cloner_calculator.proto"],
@ -72,6 +79,13 @@ proto_library(
],
)
proto_library(
name = "clip_vector_size_calculator_proto",
srcs = ["clip_vector_size_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
mediapipe_cc_proto_library(
name = "packet_cloner_calculator_cc_proto",
srcs = ["packet_cloner_calculator.proto"],
@ -104,6 +118,22 @@ mediapipe_cc_proto_library(
deps = [":concatenate_vector_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "clip_vector_size_calculator_cc_proto",
srcs = ["clip_vector_size_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
deps = [":clip_vector_size_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "dequantize_byte_array_calculator_cc_proto",
srcs = ["dequantize_byte_array_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
deps = [":dequantize_byte_array_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "quantize_float_vector_calculator_cc_proto",
srcs = ["quantize_float_vector_calculator.proto"],
@ -154,6 +184,66 @@ cc_test(
],
)
cc_library(
name = "begin_loop_calculator",
srcs = ["begin_loop_calculator.cc"],
hdrs = ["begin_loop_calculator.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_contract",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/memory",
],
alwayslink = 1,
)
cc_library(
name = "end_loop_calculator",
srcs = ["end_loop_calculator.cc"],
hdrs = ["end_loop_calculator.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_contract",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/util:render_data_cc_proto",
],
alwayslink = 1,
)
cc_test(
name = "begin_end_loop_calculator_graph_test",
srcs = ["begin_end_loop_calculator_graph_test.cc"],
deps = [
":begin_loop_calculator",
":end_loop_calculator",
"//mediapipe/calculators/core:packet_cloner_calculator",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_contract",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/memory",
],
)
cc_library(
name = "concatenate_vector_calculator",
srcs = ["concatenate_vector_calculator.cc"],
@ -204,6 +294,50 @@ cc_test(
],
)
cc_library(
name = "clip_vector_size_calculator",
srcs = ["clip_vector_size_calculator.cc"],
hdrs = ["clip_vector_size_calculator.h"],
visibility = ["//visibility:public"],
deps = [
":clip_vector_size_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@org_tensorflow//tensorflow/lite:framework",
],
alwayslink = 1,
)
cc_library(
name = "clip_detection_vector_size_calculator",
srcs = ["clip_detection_vector_size_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":clip_vector_size_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:detection_cc_proto",
],
alwayslink = 1,
)
cc_test(
name = "clip_vector_size_calculator_test",
srcs = ["clip_vector_size_calculator_test.cc"],
deps = [
":clip_vector_size_calculator",
"//mediapipe/calculators/core:packet_resampler_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "counting_source_calculator",
srcs = ["counting_source_calculator.cc"],
@ -285,7 +419,7 @@ cc_library(
"//visibility:public",
],
deps = [
"//mediapipe/calculators/core:packet_cloner_calculator_cc_proto",
":packet_cloner_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"@com_google_absl//absl/strings",
],
@ -387,6 +521,32 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "string_to_int_calculator",
srcs = ["string_to_int_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_library(
name = "side_packet_to_stream_calculator",
srcs = ["side_packet_to_stream_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_test(
name = "immediate_mux_calculator_test",
srcs = ["immediate_mux_calculator_test.cc"],
@ -558,6 +718,32 @@ cc_test(
],
)
cc_library(
name = "dequantize_byte_array_calculator",
srcs = ["dequantize_byte_array_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":dequantize_byte_array_calculator_cc_proto",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_test(
name = "dequantize_byte_array_calculator_test",
srcs = ["dequantize_byte_array_calculator_test.cc"],
deps = [
":dequantize_byte_array_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
],
)
cc_library(
name = "quantize_float_vector_calculator",
srcs = ["quantize_float_vector_calculator.cc"],
@ -694,3 +880,29 @@ cc_test(
"//mediapipe/framework/port:status",
],
)
cc_library(
name = "stream_to_side_packet_calculator",
srcs = ["stream_to_side_packet_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_test(
name = "stream_to_side_packet_calculator_test",
srcs = ["stream_to_side_packet_calculator_test.cc"],
deps = [
":stream_to_side_packet_calculator",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:packet",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/memory",
],
)

View File

@ -0,0 +1,335 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "absl/memory/memory.h"
#include "mediapipe/calculators/core/begin_loop_calculator.h"
#include "mediapipe/calculators/core/end_loop_calculator.h"
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
namespace mediapipe {
namespace {
typedef BeginLoopCalculator<std::vector<int>> BeginLoopIntegerCalculator;
REGISTER_CALCULATOR(BeginLoopIntegerCalculator);
class IncrementCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<int>();
cc->Outputs().Index(0).Set<int>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
cc->SetOffset(TimestampDiff(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
const int& input_int = cc->Inputs().Index(0).Get<int>();
auto output_int = absl::make_unique<int>(input_int + 1);
cc->Outputs().Index(0).Add(output_int.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(IncrementCalculator);
typedef EndLoopCalculator<std::vector<int>> EndLoopIntegersCalculator;
REGISTER_CALCULATOR(EndLoopIntegersCalculator);
class BeginEndLoopCalculatorGraphTest : public ::testing::Test {
protected:
BeginEndLoopCalculatorGraphTest() {
graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(
R"(
num_threads: 4
input_stream: "ints"
node {
calculator: "BeginLoopIntegerCalculator"
input_stream: "ITERABLE:ints"
output_stream: "ITEM:int"
output_stream: "BATCH_END:timestamp"
}
node {
calculator: "IncrementCalculator"
input_stream: "int"
output_stream: "int_plus_one"
}
node {
calculator: "EndLoopIntegersCalculator"
input_stream: "ITEM:int_plus_one"
input_stream: "BATCH_END:timestamp"
output_stream: "ITERABLE:ints_plus_one"
}
)");
tool::AddVectorSink("ints_plus_one", &graph_config_, &output_packets_);
}
CalculatorGraphConfig graph_config_;
std::vector<Packet> output_packets_;
};
TEST_F(BeginEndLoopCalculatorGraphTest, SingleEmptyVector) {
CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(graph_config_));
MP_EXPECT_OK(graph.StartRun({}));
auto input_vector = absl::make_unique<std::vector<int>>();
Timestamp input_timestamp = Timestamp(0);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", Adopt(input_vector.release()).At(input_timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
// EndLoopCalc will forward the timestamp bound because there are no elements
// in the collection to output.
ASSERT_EQ(0, output_packets_.size());
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST_F(BeginEndLoopCalculatorGraphTest, SingleNonEmptyVector) {
CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(graph_config_));
MP_EXPECT_OK(graph.StartRun({}));
auto input_vector = absl::make_unique<std::vector<int>>();
input_vector->emplace_back(0);
input_vector->emplace_back(1);
input_vector->emplace_back(2);
Timestamp input_timestamp = Timestamp(0);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", Adopt(input_vector.release()).At(input_timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(1, output_packets_.size());
EXPECT_EQ(input_timestamp, output_packets_[0].Timestamp());
std::vector<int> expected_output_vector = {1, 2, 3};
EXPECT_EQ(expected_output_vector, output_packets_[0].Get<std::vector<int>>());
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST_F(BeginEndLoopCalculatorGraphTest, MultipleVectors) {
CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(graph_config_));
MP_EXPECT_OK(graph.StartRun({}));
auto input_vector0 = absl::make_unique<std::vector<int>>();
input_vector0->emplace_back(0);
input_vector0->emplace_back(1);
Timestamp input_timestamp0 = Timestamp(0);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", Adopt(input_vector0.release()).At(input_timestamp0)));
auto input_vector1 = absl::make_unique<std::vector<int>>();
Timestamp input_timestamp1 = Timestamp(1);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", Adopt(input_vector1.release()).At(input_timestamp1)));
auto input_vector2 = absl::make_unique<std::vector<int>>();
input_vector2->emplace_back(2);
input_vector2->emplace_back(3);
Timestamp input_timestamp2 = Timestamp(2);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", Adopt(input_vector2.release()).At(input_timestamp2)));
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
ASSERT_EQ(2, output_packets_.size());
EXPECT_EQ(input_timestamp0, output_packets_[0].Timestamp());
std::vector<int> expected_output_vector0 = {1, 2};
EXPECT_EQ(expected_output_vector0,
output_packets_[0].Get<std::vector<int>>());
// At input_timestamp1, EndLoopCalc will forward the timestamp bound as there
// are no elements in the vector to process.
EXPECT_EQ(input_timestamp2, output_packets_[1].Timestamp());
std::vector<int> expected_output_vector2 = {3, 4};
EXPECT_EQ(expected_output_vector2,
output_packets_[1].Get<std::vector<int>>());
}
class MultiplierCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<int>();
cc->Inputs().Index(1).Set<int>();
cc->Outputs().Index(0).Set<int>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
cc->SetOffset(TimestampDiff(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
const int& input_int = cc->Inputs().Index(0).Get<int>();
const int& multiplier_int = cc->Inputs().Index(1).Get<int>();
auto output_int = absl::make_unique<int>(input_int * multiplier_int);
cc->Outputs().Index(0).Add(output_int.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(MultiplierCalculator);
class BeginEndLoopCalculatorGraphWithClonedInputsTest : public ::testing::Test {
protected:
BeginEndLoopCalculatorGraphWithClonedInputsTest() {
graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(
R"(
num_threads: 4
input_stream: "ints"
input_stream: "multiplier"
node {
calculator: "BeginLoopIntegerCalculator"
input_stream: "ITERABLE:ints"
input_stream: "CLONE:multiplier"
output_stream: "ITEM:int_at_loop"
output_stream: "CLONE:multiplier_cloned_at_loop"
output_stream: "BATCH_END:timestamp"
}
node {
calculator: "MultiplierCalculator"
input_stream: "int_at_loop"
input_stream: "multiplier_cloned_at_loop"
output_stream: "multiplied_int_at_loop"
}
node {
calculator: "EndLoopIntegersCalculator"
input_stream: "ITEM:multiplied_int_at_loop"
input_stream: "BATCH_END:timestamp"
output_stream: "ITERABLE:multiplied_ints"
}
)");
tool::AddVectorSink("multiplied_ints", &graph_config_, &output_packets_);
}
CalculatorGraphConfig graph_config_;
std::vector<Packet> output_packets_;
};
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, SingleEmptyVector) {
CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(graph_config_));
MP_EXPECT_OK(graph.StartRun({}));
auto input_vector = absl::make_unique<std::vector<int>>();
Timestamp input_timestamp = Timestamp(42);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", Adopt(input_vector.release()).At(input_timestamp)));
auto multiplier = absl::make_unique<int>(2);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"multiplier", Adopt(multiplier.release()).At(input_timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
// EndLoopCalc will forward the timestamp bound because there are no elements
// in the collection to output.
ASSERT_EQ(0, output_packets_.size());
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, SingleNonEmptyVector) {
CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(graph_config_));
MP_EXPECT_OK(graph.StartRun({}));
auto input_vector = absl::make_unique<std::vector<int>>();
input_vector->emplace_back(0);
input_vector->emplace_back(1);
input_vector->emplace_back(2);
Timestamp input_timestamp = Timestamp(42);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", Adopt(input_vector.release()).At(input_timestamp)));
auto multiplier = absl::make_unique<int>(2);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"multiplier", Adopt(multiplier.release()).At(input_timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(1, output_packets_.size());
EXPECT_EQ(input_timestamp, output_packets_[0].Timestamp());
std::vector<int> expected_output_vector = {0, 2, 4};
EXPECT_EQ(expected_output_vector, output_packets_[0].Get<std::vector<int>>());
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, MultipleVectors) {
CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(graph_config_));
MP_EXPECT_OK(graph.StartRun({}));
auto input_vector0 = absl::make_unique<std::vector<int>>();
input_vector0->emplace_back(0);
input_vector0->emplace_back(1);
Timestamp input_timestamp0 = Timestamp(42);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", Adopt(input_vector0.release()).At(input_timestamp0)));
auto multiplier0 = absl::make_unique<int>(2);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"multiplier", Adopt(multiplier0.release()).At(input_timestamp0)));
auto input_vector1 = absl::make_unique<std::vector<int>>();
Timestamp input_timestamp1 = Timestamp(43);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", Adopt(input_vector1.release()).At(input_timestamp1)));
auto multiplier1 = absl::make_unique<int>(2);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"multiplier", Adopt(multiplier1.release()).At(input_timestamp1)));
auto input_vector2 = absl::make_unique<std::vector<int>>();
input_vector2->emplace_back(2);
input_vector2->emplace_back(3);
Timestamp input_timestamp2 = Timestamp(44);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", Adopt(input_vector2.release()).At(input_timestamp2)));
auto multiplier2 = absl::make_unique<int>(3);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"multiplier", Adopt(multiplier2.release()).At(input_timestamp2)));
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
ASSERT_EQ(2, output_packets_.size());
EXPECT_EQ(input_timestamp0, output_packets_[0].Timestamp());
std::vector<int> expected_output_vector0 = {0, 2};
EXPECT_EQ(expected_output_vector0,
output_packets_[0].Get<std::vector<int>>());
// At input_timestamp1, EndLoopCalc will forward the timestamp bound as there
// are no elements in the vector to process.
EXPECT_EQ(input_timestamp2, output_packets_[1].Timestamp());
std::vector<int> expected_output_vector2 = {6, 9};
EXPECT_EQ(expected_output_vector2,
output_packets_[1].Get<std::vector<int>>());
}
} // namespace
} // namespace mediapipe

View File

@ -0,0 +1,40 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/begin_loop_calculator.h"
#include <vector>
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
namespace mediapipe {
// A calculator to process std::vector<NormalizedLandmark>.
typedef BeginLoopCalculator<std::vector<::mediapipe::NormalizedLandmark>>
BeginLoopNormalizedLandmarkCalculator;
REGISTER_CALCULATOR(BeginLoopNormalizedLandmarkCalculator);
// A calculator to process std::vector<std::vector<NormalizedLandmark>>.
typedef BeginLoopCalculator<
std::vector<std::vector<::mediapipe::NormalizedLandmark>>>
BeginLoopNormalizedLandmarksVectorCalculator;
REGISTER_CALCULATOR(BeginLoopNormalizedLandmarksVectorCalculator);
// A calculator to process std::vector<NormalizedRect>.
typedef BeginLoopCalculator<std::vector<::mediapipe::NormalizedRect>>
BeginLoopNormalizedRectCalculator;
REGISTER_CALCULATOR(BeginLoopNormalizedRectCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,157 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_
#include "absl/memory/memory.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// Calculator for implementing loops on iterable collections inside a MediaPipe
// graph.
//
// It is designed to be used like:
//
// node {
// calculator: "BeginLoopWithIterableCalculator"
// input_stream: "ITERABLE:input_iterable" # IterableT @ext_ts
// output_stream: "ITEM:input_element" # ItemT @loop_internal_ts
// output_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts
// }
//
// node {
// calculator: "ElementToBlaConverterSubgraph"
// input_stream: "ITEM:input_to_loop_body" # ItemT @loop_internal_ts
// output_stream: "BLA:output_of_loop_body" # ItemU @loop_internal_ts
// }
//
// node {
// calculator: "EndLoopWithOutputCalculator"
// input_stream: "ITEM:output_of_loop_body" # ItemU @loop_internal_ts
// input_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts
// output_stream: "OUTPUT:aggregated_result" # IterableU @ext_ts
// }
//
// BeginLoopCalculator accepts an optional input stream tagged with "TICK"
// which, if non-empty, wakes up the calculator and calls
// BeginLoopCalculator::Process(). Input streams tagged with "CLONE" are cloned
// to the corresponding output streams at loop timestamps. This ensures that a
// MediaPipe graph or sub-graph can run multiple times, once per element in the
// "ITERABLE", with each iteration receiving a clone of the packets in the
// "CLONE" input streams.
template <typename IterableT>
class BeginLoopCalculator : public CalculatorBase {
using ItemT = typename IterableT::value_type;
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
// A non-empty packet in the optional "TICK" input stream wakes up the
// calculator.
if (cc->Inputs().HasTag("TICK")) {
cc->Inputs().Tag("TICK").SetAny();
}
// An iterable collection in the input stream.
RET_CHECK(cc->Inputs().HasTag("ITERABLE"));
cc->Inputs().Tag("ITERABLE").Set<IterableT>();
// An element from the collection.
RET_CHECK(cc->Outputs().HasTag("ITEM"));
cc->Outputs().Tag("ITEM").Set<ItemT>();
RET_CHECK(cc->Outputs().HasTag("BATCH_END"));
cc->Outputs()
.Tag("BATCH_END")
.Set<Timestamp>(
// A flush signal to the corresponding EndLoopCalculator for it to
// emit the aggregated result with the timestamp contained in this
// flush signal packet.
);
// Input streams tagged with "CLONE" are cloned to the corresponding
// "CLONE" output streams at loop timestamps.
RET_CHECK(cc->Inputs().NumEntries("CLONE") ==
cc->Outputs().NumEntries("CLONE"));
if (cc->Inputs().NumEntries("CLONE") > 0) {
for (int i = 0; i < cc->Inputs().NumEntries("CLONE"); ++i) {
cc->Inputs().Get("CLONE", i).SetAny();
cc->Outputs().Get("CLONE", i).SetSameAs(&cc->Inputs().Get("CLONE", i));
}
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
Timestamp last_timestamp = loop_internal_timestamp_;
if (!cc->Inputs().Tag("ITERABLE").IsEmpty()) {
const IterableT& collection =
cc->Inputs().Tag("ITERABLE").template Get<IterableT>();
for (const auto& item : collection) {
cc->Outputs().Tag("ITEM").AddPacket(
MakePacket<ItemT>(item).At(loop_internal_timestamp_));
ForwardClonePackets(cc, loop_internal_timestamp_);
++loop_internal_timestamp_;
}
}
// The collection was empty and nothing was processed.
if (last_timestamp == loop_internal_timestamp_) {
// Increment loop_internal_timestamp_ because it is used up now.
++loop_internal_timestamp_;
for (auto it = cc->Outputs().begin(); it < cc->Outputs().end(); ++it) {
it->SetNextTimestampBound(loop_internal_timestamp_);
}
}
// The for loop processing the input collection already incremented
// loop_internal_timestamp_. To emit the BATCH_END packet alongside the last
// non-BATCH_END packet, decrement it by one.
cc->Outputs()
.Tag("BATCH_END")
.AddPacket(MakePacket<Timestamp>(cc->InputTimestamp())
.At(Timestamp(loop_internal_timestamp_ - 1)));
return ::mediapipe::OkStatus();
}
private:
void ForwardClonePackets(CalculatorContext* cc, Timestamp output_timestamp) {
if (cc->Inputs().NumEntries("CLONE") > 0) {
for (int i = 0; i < cc->Inputs().NumEntries("CLONE"); ++i) {
if (!cc->Inputs().Get("CLONE", i).IsEmpty()) {
auto input_packet = cc->Inputs().Get("CLONE", i).Value();
cc->Outputs()
.Get("CLONE", i)
.AddPacket(std::move(input_packet).At(output_timestamp));
}
}
}
}
// Fake timestamps generated per element in collection.
Timestamp loop_internal_timestamp_ = Timestamp(0);
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_
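As a usage illustration only (not part of this change), a loop-head node combining the optional "TICK" and "CLONE" streams described above might look like the sketch below; the stream names are assumptions, while the calculator name comes from the registrations added in this change.

```cpp
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

namespace mediapipe {

// Hypothetical loop-head node: iterates over the "rects" vector, is woken up
// by packets on the assumed "frames" stream via the optional "TICK" input,
// and clones "image_size" to the loop body at each loop-internal timestamp.
CalculatorGraphConfig::Node BeginLoopNodeSketch() {
  return ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
    calculator: "BeginLoopNormalizedRectCalculator"
    input_stream: "ITERABLE:rects"
    input_stream: "TICK:frames"
    input_stream: "CLONE:image_size"
    output_stream: "ITEM:single_rect"
    output_stream: "CLONE:image_size_for_loop"
    output_stream: "BATCH_END:batch_end_ts"
  )");
}

}  // namespace mediapipe
```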

View File

@ -0,0 +1,26 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "mediapipe/calculators/core/clip_vector_size_calculator.h"
#include "mediapipe/framework/formats/detection.pb.h"
namespace mediapipe {
typedef ClipVectorSizeCalculator<::mediapipe::Detection>
ClipDetectionVectorSizeCalculator;
REGISTER_CALCULATOR(ClipDetectionVectorSizeCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,28 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/clip_vector_size_calculator.h"
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/rect.pb.h"
namespace mediapipe {
typedef ClipVectorSizeCalculator<::mediapipe::NormalizedRect>
ClipNormalizedRectVectorSizeCalculator;
REGISTER_CALCULATOR(ClipNormalizedRectVectorSizeCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,137 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_CORE_CLIP_VECTOR_SIZE_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_CLIP_VECTOR_SIZE_CALCULATOR_H_
#include <type_traits>
#include <vector>
#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// Clips the size of the input vector of type T to a specified max_vec_size.
// In a graph it will be used as:
// node {
// calculator: "ClipIntVectorSizeCalculator"
// input_stream: "input_vector"
// output_stream: "output_vector"
// options {
// [mediapipe.ClipIntVectorSizeCalculatorOptions.ext] {
// max_vec_size: 5
// }
// }
// }
template <typename T>
class ClipVectorSizeCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
RET_CHECK(cc->Inputs().NumEntries() == 1);
RET_CHECK(cc->Outputs().NumEntries() == 1);
if (cc->Options<::mediapipe::ClipVectorSizeCalculatorOptions>()
.max_vec_size() < 1) {
return ::mediapipe::InternalError(
"max_vec_size should be greater than or equal to 1.");
}
cc->Inputs().Index(0).Set<std::vector<T>>();
cc->Outputs().Index(0).Set<std::vector<T>>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
cc->SetOffset(TimestampDiff(0));
max_vec_size_ = cc->Options<::mediapipe::ClipVectorSizeCalculatorOptions>()
.max_vec_size();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
if (max_vec_size_ < 1) {
return ::mediapipe::InternalError(
"max_vec_size should be greater than or equal to 1.");
}
if (cc->Inputs().Index(0).IsEmpty()) {
return ::mediapipe::OkStatus();
}
return ClipVectorSize<T>(std::is_copy_constructible<T>(), cc);
}
template <typename U>
::mediapipe::Status ClipVectorSize(std::true_type, CalculatorContext* cc) {
auto output = absl::make_unique<std::vector<U>>();
const std::vector<U>& input_vector =
cc->Inputs().Index(0).Get<std::vector<U>>();
if (max_vec_size_ >= input_vector.size()) {
output->insert(output->end(), input_vector.begin(), input_vector.end());
} else {
for (int i = 0; i < max_vec_size_; ++i) {
output->push_back(input_vector[i]);
}
}
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
template <typename U>
::mediapipe::Status ClipVectorSize(std::false_type, CalculatorContext* cc) {
return ConsumeAndClipVectorSize<T>(std::is_move_constructible<U>(), cc);
}
template <typename U>
::mediapipe::Status ConsumeAndClipVectorSize(std::true_type,
CalculatorContext* cc) {
auto output = absl::make_unique<std::vector<U>>();
::mediapipe::StatusOr<std::unique_ptr<std::vector<U>>> input_status =
cc->Inputs().Index(0).Value().Consume<std::vector<U>>();
if (input_status.ok()) {
std::unique_ptr<std::vector<U>> input_vector =
std::move(input_status).ValueOrDie();
auto begin_it = input_vector->begin();
auto end_it = input_vector->end();
if (max_vec_size_ < input_vector->size()) {
end_it = input_vector->begin() + max_vec_size_;
}
output->insert(output->end(), std::make_move_iterator(begin_it),
std::make_move_iterator(end_it));
} else {
return input_status.status();
}
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
template <typename U>
::mediapipe::Status ConsumeAndClipVectorSize(std::false_type,
CalculatorContext* cc) {
return ::mediapipe::InternalError(
"Cannot copy or move input vectors and clip their size.");
}
private:
int max_vec_size_ = 0;
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_CORE_CLIP_VECTOR_SIZE_CALCULATOR_H_

View File

@ -0,0 +1,28 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message ClipVectorSizeCalculatorOptions {
extend CalculatorOptions {
optional ClipVectorSizeCalculatorOptions ext = 274674998;
}
// Maximum size of output vector.
optional int32 max_vec_size = 1 [default = 1];
}

View File

@ -0,0 +1,179 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/clip_vector_size_calculator.h"
#include <memory>
#include <string>
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
namespace mediapipe {
typedef ClipVectorSizeCalculator<int> TestClipIntVectorSizeCalculator;
REGISTER_CALCULATOR(TestClipIntVectorSizeCalculator);
void AddInputVector(const std::vector<int>& input, int64 timestamp,
CalculatorRunner* runner) {
runner->MutableInputs()->Index(0).packets.push_back(
MakePacket<std::vector<int>>(input).At(Timestamp(timestamp)));
}
TEST(TestClipIntVectorSizeCalculatorTest, EmptyVectorInput) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "TestClipIntVectorSizeCalculator"
input_stream: "input_vector"
output_stream: "output_vector"
options {
[mediapipe.ClipVectorSizeCalculatorOptions.ext] { max_vec_size: 1 }
}
)");
CalculatorRunner runner(node_config);
std::vector<int> input = {};
AddInputVector(input, /*timestamp=*/1, &runner);
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
EXPECT_TRUE(outputs[0].Get<std::vector<int>>().empty());
}
TEST(TestClipIntVectorSizeCalculatorTest, OneTimestamp) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "TestClipIntVectorSizeCalculator"
input_stream: "input_vector"
output_stream: "output_vector"
options {
[mediapipe.ClipVectorSizeCalculatorOptions.ext] { max_vec_size: 2 }
}
)");
CalculatorRunner runner(node_config);
std::vector<int> input = {0, 1, 2, 3};
AddInputVector(input, /*timestamp=*/1, &runner);
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
const std::vector<int>& output = outputs[0].Get<std::vector<int>>();
EXPECT_EQ(2, output.size());
std::vector<int> expected_vector = {0, 1};
EXPECT_EQ(expected_vector, output);
}
TEST(TestClipIntVectorSizeCalculatorTest, TwoInputsAtTwoTimestamps) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "TestClipIntVectorSizeCalculator"
input_stream: "input_vector"
output_stream: "output_vector"
options {
[mediapipe.ClipVectorSizeCalculatorOptions.ext] { max_vec_size: 3 }
}
)");
CalculatorRunner runner(node_config);
{
std::vector<int> input = {0, 1, 2, 3};
AddInputVector(input, /*timestamp=*/1, &runner);
}
{
std::vector<int> input = {2, 3, 4, 5};
AddInputVector(input, /*timestamp=*/2, &runner);
}
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
EXPECT_EQ(2, outputs.size());
{
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
const std::vector<int>& output = outputs[0].Get<std::vector<int>>();
EXPECT_EQ(3, output.size());
std::vector<int> expected_vector = {0, 1, 2};
EXPECT_EQ(expected_vector, output);
}
{
EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
const std::vector<int>& output = outputs[1].Get<std::vector<int>>();
EXPECT_EQ(3, output.size());
std::vector<int> expected_vector = {2, 3, 4};
EXPECT_EQ(expected_vector, output);
}
}
typedef ClipVectorSizeCalculator<std::unique_ptr<int>>
TestClipUniqueIntPtrVectorSizeCalculator;
REGISTER_CALCULATOR(TestClipUniqueIntPtrVectorSizeCalculator);
TEST(TestClipUniqueIntPtrVectorSizeCalculatorTest, ConsumeOneTimestamp) {
/* Note: We don't use CalculatorRunner for this test because it keeps copies
* of input packets, so packets sent to the graph don't have sole ownership.
* The test needs to send packets that own the data.
*/
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "input_vector"
node {
calculator: "TestClipUniqueIntPtrVectorSizeCalculator"
input_stream: "input_vector"
output_stream: "output_vector"
options {
[mediapipe.ClipVectorSizeCalculatorOptions.ext] { max_vec_size: 3 }
}
}
)");
std::vector<Packet> outputs;
tool::AddVectorSink("output_vector", &graph_config, &outputs);
CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(graph_config));
MP_EXPECT_OK(graph.StartRun({}));
// input1 : {0, 1, 2, 3, 4, 5}
auto input_vector = absl::make_unique<std::vector<std::unique_ptr<int>>>(6);
for (int i = 0; i < 6; ++i) {
input_vector->at(i) = absl::make_unique<int>(i);
}
MP_EXPECT_OK(graph.AddPacketToInputStream(
"input_vector", Adopt(input_vector.release()).At(Timestamp(1))));
MP_EXPECT_OK(graph.WaitUntilIdle());
MP_EXPECT_OK(graph.CloseAllPacketSources());
MP_EXPECT_OK(graph.WaitUntilDone());
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
const std::vector<std::unique_ptr<int>>& result =
outputs[0].Get<std::vector<std::unique_ptr<int>>>();
EXPECT_EQ(3, result.size());
for (int i = 0; i < 3; ++i) {
const std::unique_ptr<int>& v = result[i];
EXPECT_EQ(i, *v);
}
}
} // namespace mediapipe

View File

@ -0,0 +1,90 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cfloat>
#include "mediapipe/calculators/core/dequantize_byte_array_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/status.h"
// Dequantizes a byte array to a vector of floats.
//
// Example config:
// node {
// calculator: "DequantizeByteArrayCalculator"
// input_stream: "ENCODED:encoded"
// output_stream: "FLOAT_VECTOR:float_vector"
// options {
// [mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
// max_quantized_value: 2
// min_quantized_value: -2
// }
// }
// }
namespace mediapipe {
class DequantizeByteArrayCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("ENCODED").Set<std::string>();
cc->Outputs().Tag("FLOAT_VECTOR").Set<std::vector<float>>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
const auto options =
cc->Options<::mediapipe::DequantizeByteArrayCalculatorOptions>();
if (!options.has_max_quantized_value() ||
!options.has_min_quantized_value()) {
return ::mediapipe::InvalidArgumentError(
"Both max_quantized_value and min_quantized_value must be provided "
"in DequantizeByteArrayCalculatorOptions.");
}
float max_quantized_value = options.max_quantized_value();
float min_quantized_value = options.min_quantized_value();
if (max_quantized_value < min_quantized_value + FLT_EPSILON) {
return ::mediapipe::InvalidArgumentError(
"max_quantized_value must be greater than min_quantized_value.");
}
float range = max_quantized_value - min_quantized_value;
scalar_ = range / 255.0;
bias_ = (range / 512.0) + min_quantized_value;
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
const std::string& encoded =
cc->Inputs().Tag("ENCODED").Value().Get<std::string>();
std::vector<float> float_vector;
float_vector.reserve(encoded.length());
for (int i = 0; i < encoded.length(); ++i) {
float_vector.push_back(
static_cast<unsigned char>(encoded.at(i)) * scalar_ + bias_);
}
cc->Outputs()
.Tag("FLOAT_VECTOR")
.AddPacket(MakePacket<std::vector<float>>(float_vector)
.At(cc->InputTimestamp()));
return ::mediapipe::OkStatus();
}
private:
float scalar_;
float bias_;
};
REGISTER_CALCULATOR(DequantizeByteArrayCalculator);
} // namespace mediapipe
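To make the scaling above concrete, here is a small stand-alone sketch (not part of the change) that reproduces the calculator's arithmetic for the byte values used in the dequantization test later in this change, assuming max_quantized_value = 2 and min_quantized_value = -2.

```cpp
#include <cstdio>

// Stand-alone sketch of DequantizeByteArrayCalculator's arithmetic with
// max_quantized_value = 2 and min_quantized_value = -2.
int main() {
  const float max_quantized_value = 2.0f;
  const float min_quantized_value = -2.0f;
  const float range = max_quantized_value - min_quantized_value;  // 4.0
  const float scalar = range / 255.0f;                            // ~0.0157
  const float bias = (range / 512.0f) + min_quantized_value;      // ~-1.9922
  const unsigned char encoded[4] = {0x7F, 0xFF, 0x00, 0x01};
  for (unsigned char byte : encoded) {
    // Expected: 0x7F -> ~0, 0xFF -> ~2, 0x00 -> ~-2, 0x01 -> ~-1.976,
    // matching the EXPECT_NEAR checks in the test.
    std::printf("0x%02X -> %f\n", byte, byte * scalar + bias);
  }
  return 0;
}
```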

View File

@ -0,0 +1,28 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message DequantizeByteArrayCalculatorOptions {
extend CalculatorOptions {
optional DequantizeByteArrayCalculatorOptions ext = 272316343;
}
optional float max_quantized_value = 1;
optional float min_quantized_value = 2;
}

View File

@ -0,0 +1,137 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
namespace mediapipe {
TEST(DequantizeByteArrayCalculatorTest, WrongConfig) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "DequantizeByteArrayCalculator"
input_stream: "ENCODED:encoded"
output_stream: "FLOAT_VECTOR:float_vector"
options {
[mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
max_quantized_value: 2
}
}
)");
CalculatorRunner runner(node_config);
std::string empty_string;
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
MakePacket<std::string>(empty_string).At(Timestamp(0)));
auto status = runner.Run();
EXPECT_FALSE(status.ok());
EXPECT_THAT(
status.message(),
testing::HasSubstr(
"Both max_quantized_value and min_quantized_value must be provided"));
}
TEST(DequantizeByteArrayCalculatorTest, WrongConfig2) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "DequantizeByteArrayCalculator"
input_stream: "ENCODED:encoded"
output_stream: "FLOAT_VECTOR:float_vector"
options {
[mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
max_quantized_value: -2
min_quantized_value: 2
}
}
)");
CalculatorRunner runner(node_config);
std::string empty_string;
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
MakePacket<std::string>(empty_string).At(Timestamp(0)));
auto status = runner.Run();
EXPECT_FALSE(status.ok());
EXPECT_THAT(
status.message(),
testing::HasSubstr(
"max_quantized_value must be greater than min_quantized_value"));
}
TEST(DequantizeByteArrayCalculatorTest, WrongConfig3) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "DequantizeByteArrayCalculator"
input_stream: "ENCODED:encoded"
output_stream: "FLOAT_VECTOR:float_vector"
options {
[mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
max_quantized_value: 1
min_quantized_value: 1
}
}
)");
CalculatorRunner runner(node_config);
std::string empty_string;
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
MakePacket<std::string>(empty_string).At(Timestamp(0)));
auto status = runner.Run();
EXPECT_FALSE(status.ok());
EXPECT_THAT(
status.message(),
testing::HasSubstr(
"max_quantized_value must be greater than min_quantized_value"));
}
TEST(DequantizeByteArrayCalculatorTest, TestDequantization) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "DequantizeByteArrayCalculator"
input_stream: "ENCODED:encoded"
output_stream: "FLOAT_VECTOR:float_vector"
options {
[mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
max_quantized_value: 2
min_quantized_value: -2
}
}
)");
CalculatorRunner runner(node_config);
unsigned char input[4] = {0x7F, 0xFF, 0x00, 0x01};
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
MakePacket<std::string>(
std::string(reinterpret_cast<char const*>(input), 4))
.At(Timestamp(0)));
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs =
runner.Outputs().Tag("FLOAT_VECTOR").packets;
EXPECT_EQ(1, outputs.size());
const std::vector<float>& result = outputs[0].Get<std::vector<float>>();
ASSERT_FALSE(result.empty());
EXPECT_EQ(4, result.size());
EXPECT_NEAR(0, result[0], 0.01);
EXPECT_NEAR(2, result[1], 0.01);
EXPECT_NEAR(-2, result[2], 0.01);
EXPECT_NEAR(-1.976, result[3], 0.01);
EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
}
} // namespace mediapipe

View File

@ -0,0 +1,45 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/end_loop_calculator.h"
#include <vector>
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/util/render_data.pb.h"
namespace mediapipe {
typedef EndLoopCalculator<std::vector<::mediapipe::NormalizedRect>>
EndLoopNormalizedRectCalculator;
REGISTER_CALCULATOR(EndLoopNormalizedRectCalculator);
typedef EndLoopCalculator<std::vector<::mediapipe::NormalizedLandmark>>
EndLoopNormalizedLandmarkCalculator;
REGISTER_CALCULATOR(EndLoopNormalizedLandmarkCalculator);
typedef EndLoopCalculator<
std::vector<std::vector<::mediapipe::NormalizedLandmark>>>
EndLoopNormalizedLandmarksVectorCalculator;
REGISTER_CALCULATOR(EndLoopNormalizedLandmarksVectorCalculator);
typedef EndLoopCalculator<std::vector<bool>> EndLoopBooleanCalculator;
REGISTER_CALCULATOR(EndLoopBooleanCalculator);
typedef EndLoopCalculator<std::vector<::mediapipe::RenderData>>
EndLoopRenderDataCalculator;
REGISTER_CALCULATOR(EndLoopRenderDataCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,106 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// Calculator for completing the processing of loops on iterable collections
// inside a MediaPipe graph. The EndLoopCalculator collects all input packets
// from ITEM input_stream into a collection and upon receiving the flush signal
// from the "BATCH_END" tagged input stream, it emits the aggregated results
// at the original timestamp contained in the "BATCH_END" input stream.
//
// It is designed to be used like:
//
// node {
// calculator: "BeginLoopWithIterableCalculator"
// input_stream: "ITERABLE:input_iterable" # IterableT @ext_ts
// output_stream: "ITEM:input_element" # ItemT @loop_internal_ts
// output_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts
// }
//
// node {
// calculator: "ElementToBlaConverterSubgraph"
// input_stream: "ITEM:input_to_loop_body" # ItemT @loop_internal_ts
// output_stream: "BLA:output_of_loop_body" # ItemU @loop_internal_ts
// }
//
// node {
// calculator: "EndLoopWithOutputCalculator"
// input_stream: "ITEM:output_of_loop_body" # ItemU @loop_internal_ts
// input_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts
// output_stream: "OUTPUT:aggregated_result" # IterableU @ext_ts
// }
template <typename IterableT>
class EndLoopCalculator : public CalculatorBase {
using ItemT = typename IterableT::value_type;
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag("BATCH_END"))
<< "Missing BATCH_END tagged input_stream.";
cc->Inputs().Tag("BATCH_END").Set<Timestamp>();
RET_CHECK(cc->Inputs().HasTag("ITEM"));
cc->Inputs().Tag("ITEM").Set<ItemT>();
RET_CHECK(cc->Outputs().HasTag("ITERABLE"));
cc->Outputs().Tag("ITERABLE").Set<IterableT>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
if (!cc->Inputs().Tag("ITEM").IsEmpty()) {
if (!input_stream_collection_) {
input_stream_collection_.reset(new IterableT);
}
input_stream_collection_->push_back(
cc->Inputs().Tag("ITEM").template Get<ItemT>());
}
if (!cc->Inputs().Tag("BATCH_END").Value().IsEmpty()) { // flush signal
Timestamp loop_control_ts =
cc->Inputs().Tag("BATCH_END").template Get<Timestamp>();
if (input_stream_collection_) {
cc->Outputs()
.Tag("ITERABLE")
.Add(input_stream_collection_.release(), loop_control_ts);
} else {
// Since there is no collection, inform downstream calculators to not
// expect any packet by updating the timestamp bounds.
cc->Outputs()
.Tag("ITERABLE")
.SetNextTimestampBound(Timestamp(loop_control_ts.Value() + 1));
}
}
return ::mediapipe::OkStatus();
}
private:
std::unique_ptr<IterableT> input_stream_collection_;
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_

View File

@ -74,6 +74,12 @@ class PacketResamplerCalculator : public CalculatorBase {
::mediapipe::Status Process(CalculatorContext* cc) override;
private:
// Calculates the first sampled timestamp that incorporates a jittering
// offset.
void InitializeNextOutputTimestampWithJitter();
// Calculates the next sampled timestamp that incorporates a jittering offset.
void UpdateNextOutputTimestampWithJitter();
// Logic for Process() when jitter_ != 0.0.
::mediapipe::Status ProcessWithJitter(CalculatorContext* cc);
@ -233,6 +239,7 @@ TimestampDiff TimestampDiffFromSeconds(double seconds) {
<< Timestamp::kTimestampUnitsPerSecond;
frame_time_usec_ = static_cast<int64>(1000000.0 / frame_rate_);
video_header_.frame_rate = frame_rate_;
if (resampler_options.output_header() !=
@ -295,6 +302,17 @@ TimestampDiff TimestampDiffFromSeconds(double seconds) {
return ::mediapipe::OkStatus();
}
void PacketResamplerCalculator::InitializeNextOutputTimestampWithJitter() {
next_output_timestamp_ =
first_timestamp_ + frame_time_usec_ * random_->RandFloat();
}
void PacketResamplerCalculator::UpdateNextOutputTimestampWithJitter() {
next_output_timestamp_ +=
frame_time_usec_ *
((1.0 - jitter_) + 2.0 * jitter_ * random_->RandFloat());
}
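// With jitter_ in [0, 1], each step advances the output timestamp by
// frame_time_usec_ scaled by a factor drawn uniformly from
// [1 - jitter_, 1 + jitter_], so individual intervals vary while the average
// output interval still matches frame_time_usec_.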
::mediapipe::Status PacketResamplerCalculator::ProcessWithJitter(
CalculatorContext* cc) {
RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream());
@ -302,8 +320,13 @@ TimestampDiff TimestampDiffFromSeconds(double seconds) {
if (first_timestamp_ == Timestamp::Unset()) {
first_timestamp_ = cc->InputTimestamp();
next_output_timestamp_ =
first_timestamp_ + frame_time_usec_ * random_->RandFloat();
InitializeNextOutputTimestampWithJitter();
if (first_timestamp_ == next_output_timestamp_) {
OutputWithinLimits(
cc,
cc->Inputs().Get(input_data_id_).Value().At(next_output_timestamp_));
UpdateNextOutputTimestampWithJitter();
}
return ::mediapipe::OkStatus();
}
@ -322,9 +345,7 @@ TimestampDiff TimestampDiffFromSeconds(double seconds) {
? last_packet_
: cc->Inputs().Get(input_data_id_).Value())
.At(next_output_timestamp_));
next_output_timestamp_ +=
frame_time_usec_ *
((1.0 - jitter_) + 2.0 * jitter_ * random_->RandFloat());
UpdateNextOutputTimestampWithJitter();
return ::mediapipe::OkStatus();
}

View File

@ -102,6 +102,12 @@ class PreviousLoopbackCalculator : public CalculatorBase {
cc->Outputs().Get(loop_out_id_).AddPacket(std::move(previous_loopback));
}
}
if (!main_ts_.empty()) {
cc->Outputs().Get(loop_out_id_).SetNextTimestampBound(main_ts_.front());
}
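// Publishing this bound tells downstream calculators that no loopback output
// older than the oldest pending main-stream timestamp will appear, so they
// can make progress instead of waiting on this stream.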
if (cc->Inputs().Get(main_id_).IsDone() && main_ts_.empty()) {
cc->Outputs().Get(loop_out_id_).Close();
}
return ::mediapipe::OkStatus();
}

View File

@ -107,5 +107,96 @@ TEST(PreviousLoopbackCalculator, CorrectTimestamps) {
MP_EXPECT_OK(graph_.WaitUntilDone());
}
// A Calculator that outputs a summary packet in CalculatorBase::Close().
class PacketOnCloseCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<int>();
cc->Outputs().Index(0).Set<int>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
cc->SetOffset(TimestampDiff(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
sum_ += cc->Inputs().Index(0).Value().Get<int>();
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
return ::mediapipe::OkStatus();
}
::mediapipe::Status Close(CalculatorContext* cc) final {
cc->Outputs().Index(0).AddPacket(
MakePacket<int>(sum_).At(Timestamp::Max()));
return ::mediapipe::OkStatus();
}
private:
int sum_ = 0;
};
REGISTER_CALCULATOR(PacketOnCloseCalculator);
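// The summary packet is emitted at Timestamp::Max() so that it sorts after
// every in-stream packet; the ClosesCorrectly test below checks for exactly
// that timestamp value.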
// Demonstrates that all output and input streams in PreviousLoopbackCalculator
// will close as expected when all graph input streams are closed.
TEST(PreviousLoopbackCalculator, ClosesCorrectly) {
std::vector<Packet> outputs;
CalculatorGraphConfig graph_config_ =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: 'in'
node {
calculator: 'PreviousLoopbackCalculator'
input_stream: 'MAIN:in'
input_stream: 'LOOP:out'
input_stream_info: { tag_index: 'LOOP' back_edge: true }
output_stream: 'PREV_LOOP:previous'
}
# This calculator synchronizes its inputs as normal, so it is used
# to check that both "in" and "previous" are ready.
node {
calculator: 'PassThroughCalculator'
input_stream: 'in'
input_stream: 'previous'
output_stream: 'out'
output_stream: 'previous2'
}
node {
calculator: 'PacketOnCloseCalculator'
input_stream: 'out'
output_stream: 'close_out'
}
)");
tool::AddVectorSink("close_out", &graph_config_, &outputs);
CalculatorGraph graph_;
MP_ASSERT_OK(graph_.Initialize(graph_config_, {}));
MP_ASSERT_OK(graph_.StartRun({}));
auto send_packet = [&graph_](const std::string& input_name, int n) {
MP_EXPECT_OK(graph_.AddPacketToInputStream(
input_name, MakePacket<int>(n).At(Timestamp(n))));
};
send_packet("in", 1);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{1}));
send_packet("in", 5);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{1, 5}));
send_packet("in", 15);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{1, 5, 15}));
MP_EXPECT_OK(graph_.CloseAllInputStreams());
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(outputs),
(std::vector<int64>{1, 5, 15, Timestamp::Max().Value()}));
MP_EXPECT_OK(graph_.WaitUntilDone());
}
} // anonymous namespace
} // namespace mediapipe

View File

@ -0,0 +1,83 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <map>
#include <memory>
#include <set>
#include <string>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
using mediapipe::PacketTypeSet;
using mediapipe::Timestamp;
namespace {
static std::map<std::string, Timestamp>* kTimestampMap = []() {
auto* res = new std::map<std::string, Timestamp>();
res->emplace("AT_PRESTREAM", Timestamp::PreStream());
res->emplace("AT_POSTSTREAM", Timestamp::PostStream());
res->emplace("AT_ZERO", Timestamp(0));
return res;
}();
} // namespace
// Outputs the single input_side_packet at the timestamp specified in the
// output_stream tag. Valid tags are AT_PRESTREAM, AT_POSTSTREAM and AT_ZERO.
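//
// Example config (a sketch; the stream and side packet names are illustrative):
// node {
//   calculator: "SidePacketToStreamCalculator"
//   input_side_packet: "side_packet"
//   output_stream: "AT_PRESTREAM:stream"
// }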
class SidePacketToStreamCalculator : public CalculatorBase {
public:
SidePacketToStreamCalculator() = default;
~SidePacketToStreamCalculator() override = default;
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Process(CalculatorContext* cc) override;
::mediapipe::Status Close(CalculatorContext* cc) override;
};
REGISTER_CALCULATOR(SidePacketToStreamCalculator);
::mediapipe::Status SidePacketToStreamCalculator::GetContract(
CalculatorContract* cc) {
cc->InputSidePackets().Index(0).SetAny();
std::set<std::string> tags = cc->Outputs().GetTags();
RET_CHECK_EQ(tags.size(), 1);
RET_CHECK_EQ(kTimestampMap->count(*tags.begin()), 1);
cc->Outputs().Tag(*tags.begin()).SetAny();
return ::mediapipe::OkStatus();
}
::mediapipe::Status SidePacketToStreamCalculator::Process(
CalculatorContext* cc) {
return mediapipe::tool::StatusStop();
}
::mediapipe::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) {
std::set<std::string> tags = cc->Outputs().GetTags();
RET_CHECK_EQ(tags.size(), 1);
const std::string& tag = *tags.begin();
RET_CHECK_EQ(kTimestampMap->count(tag), 1);
cc->Outputs().Tag(tag).AddPacket(
cc->InputSidePackets().Index(0).At(kTimestampMap->at(tag)));
return ::mediapipe::OkStatus();
}
} // namespace mediapipe

View File

@ -34,7 +34,9 @@ namespace mediapipe {
// SplitVectorCalculatorOptions. If the option "element_only" is set to true,
// all ranges should be of size 1 and all outputs will be elements of type T. If
// "element_only" is false, ranges can be non-zero in size and all outputs will
// be of type std::vector<T>.
// be of type std::vector<T>. If the option "combine_outputs" is set to true,
// only one output stream can be specified and all ranges of elements will be
// combined into one vector.
// To use this class for a particular type T, register a calculator using
// SplitVectorCalculator<T>.
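//
// For example, an options block like the sketch below (range values are
// illustrative) combines ranges [0, 1) and [2, 3) into a single output vector
// when combine_outputs is true:
//   options {
//     [mediapipe.SplitVectorCalculatorOptions.ext] {
//       ranges: { begin: 0 end: 1 }
//       ranges: { begin: 2 end: 3 }
//       combine_outputs: true
//     }
//   }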
template <typename T>
@ -49,28 +51,47 @@ class SplitVectorCalculator : public CalculatorBase {
const auto& options =
cc->Options<::mediapipe::SplitVectorCalculatorOptions>();
if (cc->Outputs().NumEntries() != options.ranges_size()) {
return ::mediapipe::InvalidArgumentError(
"The number of output streams should match the number of ranges "
"specified in the CalculatorOptions.");
}
// Set the output types for each output stream.
for (int i = 0; i < cc->Outputs().NumEntries(); ++i) {
if (options.ranges(i).begin() < 0 || options.ranges(i).end() < 0 ||
options.ranges(i).begin() >= options.ranges(i).end()) {
return ::mediapipe::InvalidArgumentError(
"Indices should be non-negative and begin index should be less "
"than the end index.");
}
if (options.element_only()) {
if (options.ranges(i).end() - options.ranges(i).begin() != 1) {
return ::mediapipe::InvalidArgumentError(
"Since element_only is true, all ranges should be of size 1.");
if (options.combine_outputs()) {
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1);
cc->Outputs().Index(0).Set<std::vector<T>>();
for (int i = 0; i < options.ranges_size() - 1; ++i) {
for (int j = i + 1; j < options.ranges_size(); ++j) {
const auto& range_0 = options.ranges(i);
const auto& range_1 = options.ranges(j);
if ((range_0.begin() >= range_1.begin() &&
range_0.begin() < range_1.end()) ||
(range_1.begin() >= range_0.begin() &&
range_1.begin() < range_0.end())) {
return ::mediapipe::InvalidArgumentError(
"Ranges must be non-overlapping when using combine_outputs "
"option.");
}
}
}
} else {
if (cc->Outputs().NumEntries() != options.ranges_size()) {
return ::mediapipe::InvalidArgumentError(
"The number of output streams should match the number of ranges "
"specified in the CalculatorOptions.");
}
// Set the output types for each output stream.
for (int i = 0; i < cc->Outputs().NumEntries(); ++i) {
if (options.ranges(i).begin() < 0 || options.ranges(i).end() < 0 ||
options.ranges(i).begin() >= options.ranges(i).end()) {
return ::mediapipe::InvalidArgumentError(
"Indices should be non-negative and begin index should be less "
"than the end index.");
}
if (options.element_only()) {
if (options.ranges(i).end() - options.ranges(i).begin() != 1) {
return ::mediapipe::InvalidArgumentError(
"Since element_only is true, all ranges should be of size 1.");
}
cc->Outputs().Index(i).Set<T>();
} else {
cc->Outputs().Index(i).Set<std::vector<T>>();
}
cc->Outputs().Index(i).Set<T>();
} else {
cc->Outputs().Index(i).Set<std::vector<T>>();
}
}
@ -83,13 +104,15 @@ class SplitVectorCalculator : public CalculatorBase {
const auto& options =
cc->Options<::mediapipe::SplitVectorCalculatorOptions>();
element_only_ = options.element_only();
combine_outputs_ = options.combine_outputs();
for (const auto& range : options.ranges()) {
ranges_.push_back({range.begin(), range.end()});
max_range_end_ = std::max(max_range_end_, range.end());
total_elements_ += range.end() - range.begin();
}
element_only_ = options.element_only();
return ::mediapipe::OkStatus();
}
@ -97,17 +120,29 @@ class SplitVectorCalculator : public CalculatorBase {
const auto& input = cc->Inputs().Index(0).Get<std::vector<T>>();
RET_CHECK_GE(input.size(), max_range_end_);
if (element_only_) {
if (combine_outputs_) {
auto output = absl::make_unique<std::vector<T>>();
output->reserve(total_elements_);
for (int i = 0; i < ranges_.size(); ++i) {
cc->Outputs().Index(i).AddPacket(
MakePacket<T>(input[ranges_[i].first]).At(cc->InputTimestamp()));
}
} else {
for (int i = 0; i < ranges_.size(); ++i) {
auto output = absl::make_unique<std::vector<T>>(
auto elements = absl::make_unique<std::vector<T>>(
input.begin() + ranges_[i].first,
input.begin() + ranges_[i].second);
cc->Outputs().Index(i).Add(output.release(), cc->InputTimestamp());
output->insert(output->end(), elements->begin(), elements->end());
}
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
} else {
if (element_only_) {
for (int i = 0; i < ranges_.size(); ++i) {
cc->Outputs().Index(i).AddPacket(
MakePacket<T>(input[ranges_[i].first]).At(cc->InputTimestamp()));
}
} else {
for (int i = 0; i < ranges_.size(); ++i) {
auto output = absl::make_unique<std::vector<T>>(
input.begin() + ranges_[i].first,
input.begin() + ranges_[i].second);
cc->Outputs().Index(i).Add(output.release(), cc->InputTimestamp());
}
}
}
@ -117,7 +152,9 @@ class SplitVectorCalculator : public CalculatorBase {
private:
std::vector<std::pair<int32, int32>> ranges_;
int32 max_range_end_ = -1;
int32 total_elements_ = 0;
bool element_only_ = false;
bool combine_outputs_ = false;
};
} // namespace mediapipe

View File

@ -37,4 +37,7 @@ message SplitVectorCalculatorOptions {
// just element of type T. By default, if a range specifies only one element,
// it is outputted as an std::vector<T>.
optional bool element_only = 2 [default = false];
// Combines the elements from all ranges into a single vector on one output
// stream.
optional bool combine_outputs = 3 [default = false];
}

View File

@ -105,6 +105,34 @@ class SplitTfLiteTensorVectorCalculatorTest : public ::testing::Test {
}
}
void ValidateCombinedVectorOutput(std::vector<Packet>& output_packets,
int expected_elements,
std::vector<int>& input_begin_indices,
std::vector<int>& input_end_indices) {
ASSERT_EQ(1, output_packets.size());
ASSERT_EQ(input_begin_indices.size(), input_end_indices.size());
const std::vector<TfLiteTensor>& output_vec =
output_packets[0].Get<std::vector<TfLiteTensor>>();
ASSERT_EQ(expected_elements, output_vec.size());
const int num_ranges = input_begin_indices.size();
int element_id = 0;
for (int range_id = 0; range_id < num_ranges; ++range_id) {
for (int i = input_begin_indices[range_id];
i < input_end_indices[range_id]; ++i) {
const int expected_value = i;
const TfLiteTensor* result = &output_vec[element_id];
float* result_buffer = result->data.f;
ASSERT_NE(result_buffer, nullptr);
ASSERT_EQ(result_buffer, input_buffers_[i]);
for (int j = 0; j < width * height * channels; ++j) {
ASSERT_EQ(expected_value, result_buffer[j]);
}
element_id++;
}
}
}
void ValidateElementOutput(std::vector<Packet>& output_packets,
int input_begin_index) {
ASSERT_EQ(1, output_packets.size());
@ -234,6 +262,65 @@ TEST_F(SplitTfLiteTensorVectorCalculatorTest, InvalidOutputStreamCountTest) {
ASSERT_FALSE(graph.Initialize(graph_config).ok());
}
TEST_F(SplitTfLiteTensorVectorCalculatorTest,
InvalidCombineOutputsMultipleOutputsTest) {
ASSERT_NE(interpreter_, nullptr);
// Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
CalculatorGraphConfig graph_config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
R"(
input_stream: "tensor_in"
node {
calculator: "SplitTfLiteTensorVectorCalculator"
input_stream: "tensor_in"
output_stream: "range_0"
output_stream: "range_1"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges: { begin: 0 end: 1 }
ranges: { begin: 2 end: 3 }
combine_outputs: true
}
}
}
)");
// Run the graph.
CalculatorGraph graph;
// The graph should fail to initialize because combine_outputs is true but
// more than one output stream is specified.
ASSERT_FALSE(graph.Initialize(graph_config).ok());
}
TEST_F(SplitTfLiteTensorVectorCalculatorTest, InvalidOverlappingRangesTest) {
ASSERT_NE(interpreter_, nullptr);
// Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
CalculatorGraphConfig graph_config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
R"(
input_stream: "tensor_in"
node {
calculator: "SplitTfLiteTensorVectorCalculator"
input_stream: "tensor_in"
output_stream: "range_0"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges: { begin: 0 end: 3 }
ranges: { begin: 1 end: 4 }
combine_outputs: true
}
}
}
)");
// Run the graph.
CalculatorGraph graph;
// The graph should fail to initialize because the ranges overlap.
ASSERT_FALSE(graph.Initialize(graph_config).ok());
}
TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTestElementOnly) {
ASSERT_NE(interpreter_, nullptr);
@ -289,6 +376,53 @@ TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTestElementOnly) {
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTestCombiningOutputs) {
ASSERT_NE(interpreter_, nullptr);
PrepareTfLiteTensorVector(/*vector_size=*/5);
ASSERT_NE(input_vec_, nullptr);
// Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
CalculatorGraphConfig graph_config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
R"(
input_stream: "tensor_in"
node {
calculator: "SplitTfLiteTensorVectorCalculator"
input_stream: "tensor_in"
output_stream: "range_0"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges: { begin: 0 end: 1 }
ranges: { begin: 2 end: 3 }
ranges: { begin: 4 end: 5 }
combine_outputs: true
}
}
}
)");
std::vector<Packet> range_0_packets;
tool::AddVectorSink("range_0", &graph_config, &range_0_packets);
// Run the graph.
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"tensor_in", Adopt(input_vec_.release()).At(Timestamp(0))));
// Wait until the calculator finishes processing.
MP_ASSERT_OK(graph.WaitUntilIdle());
std::vector<int> input_begin_indices = {0, 2, 4};
std::vector<int> input_end_indices = {1, 3, 5};
ValidateCombinedVectorOutput(range_0_packets, /*expected_elements=*/3,
input_begin_indices, input_end_indices);
// Fully close the graph at the end.
MP_ASSERT_OK(graph.CloseInputStream("tensor_in"));
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST_F(SplitTfLiteTensorVectorCalculatorTest,
ElementOnlyDisablesVectorOutputs) {
// Prepare a graph to use the SplitTfLiteTensorVectorCalculator.

View File

@ -0,0 +1,48 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/timestamp.h"
namespace mediapipe {
// A calculator that takes a packet of an input stream and converts it to an
// output side packet. This calculator assumes that the input stream carries
// exactly one packet.
//
// Example config:
// node {
// calculator: "StreamToSidePacketCalculator"
// input_stream: "stream"
// output_side_packet: "side_packet"
// }
class StreamToSidePacketCalculator : public mediapipe::CalculatorBase {
public:
static mediapipe::Status GetContract(mediapipe::CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny();
cc->OutputSidePackets().Index(0).SetAny();
return mediapipe::OkStatus();
}
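// Forwards the single input packet to the output side packet. The packet is
// re-stamped with Timestamp::Unset() because side packets, unlike stream
// packets, carry no timestamp.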
mediapipe::Status Process(mediapipe::CalculatorContext* cc) override {
mediapipe::Packet& packet = cc->Inputs().Index(0).Value();
cc->OutputSidePackets().Index(0).Set(
packet.At(mediapipe::Timestamp::Unset()));
return mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(StreamToSidePacketCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,67 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include "absl/memory/memory.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
namespace mediapipe {
using ::testing::Test;
class StreamToSidePacketCalculatorTest : public Test {
protected:
StreamToSidePacketCalculatorTest() {
const char kConfig[] = R"(
calculator: "StreamToSidePacketCalculator"
input_stream: "stream"
output_side_packet: "side_packet"
)";
runner_ = absl::make_unique<CalculatorRunner>(kConfig);
}
std::unique_ptr<CalculatorRunner> runner_;
};
TEST_F(StreamToSidePacketCalculatorTest,
StreamToSidePacketCalculatorWithEmptyStreamFails) {
EXPECT_EQ(runner_->Run().code(), mediapipe::StatusCode::kUnavailable);
}
TEST_F(StreamToSidePacketCalculatorTest,
StreamToSidePacketCalculatorWithSinglePacketCreatesSidePacket) {
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(new std::string("test")).At(Timestamp(1)));
MP_ASSERT_OK(runner_->Run());
EXPECT_EQ(runner_->OutputSidePackets().Index(0).Get<std::string>(), "test");
}
TEST_F(StreamToSidePacketCalculatorTest,
StreamToSidePacketCalculatorWithMultiplePacketsFails) {
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(new std::string("test1")).At(Timestamp(1)));
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(new std::string("test2")).At(Timestamp(2)));
EXPECT_EQ(runner_->Run().code(), mediapipe::StatusCode::kAlreadyExists);
}
} // namespace mediapipe

View File

@ -0,0 +1,79 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sys/types.h>
#include <memory>
#include <string>
#include "absl/strings/numbers.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// Calculator that converts a std::string into an integer type, or fails if the
// conversion is not possible.
//
// Example config:
// node {
// calculator: "StringToIntCalculator"
// input_side_packet: "string"
// output_side_packet: "index"
// }
template <typename IntType>
class StringToIntCalculatorTemplate : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets().Index(0).Set<std::string>();
cc->OutputSidePackets().Index(0).Set<IntType>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
IntType number;
if (!absl::SimpleAtoi(cc->InputSidePackets().Index(0).Get<std::string>(),
&number)) {
return ::mediapipe::InvalidArgumentError(
"The std::string could not be parsed as an integer.");
}
cc->OutputSidePackets().Index(0).Set(MakePacket<IntType>(number));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
return ::mediapipe::OkStatus();
}
};
using StringToIntCalculator = StringToIntCalculatorTemplate<int>;
REGISTER_CALCULATOR(StringToIntCalculator);
using StringToUintCalculator = StringToIntCalculatorTemplate<uint>;
REGISTER_CALCULATOR(StringToUintCalculator);
using StringToInt32Calculator = StringToIntCalculatorTemplate<int32>;
REGISTER_CALCULATOR(StringToInt32Calculator);
using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32>;
REGISTER_CALCULATOR(StringToUint32Calculator);
using StringToInt64Calculator = StringToIntCalculatorTemplate<int64>;
REGISTER_CALCULATOR(StringToInt64Calculator);
using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64>;
REGISTER_CALCULATOR(StringToUint64Calculator);
} // namespace mediapipe

View File

@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:private"])
exports_files(["LICENSE"])
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
proto_library(
name = "opencv_image_encoder_calculator_proto",
srcs = ["opencv_image_encoder_calculator.proto"],

View File

@ -13,12 +13,12 @@
# limitations under the License.
#
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:private"])
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
proto_library(
name = "graph_tensors_packet_generator_proto",
srcs = ["graph_tensors_packet_generator.proto"],
@ -104,6 +104,17 @@ proto_library(
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "unpack_media_sequence_calculator_proto",
srcs = ["unpack_media_sequence_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/core:packet_resampler_calculator_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/util:audio_decoder_proto",
],
)
proto_library(
name = "vector_float_to_tensor_calculator_options_proto",
srcs = ["vector_float_to_tensor_calculator_options.proto"],
@ -127,7 +138,7 @@ mediapipe_cc_proto_library(
srcs = ["image_frame_to_tensor_calculator.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core:protos_all",
],
visibility = ["//visibility:public"],
deps = [":image_frame_to_tensor_calculator_proto"],
@ -162,7 +173,7 @@ mediapipe_cc_proto_library(
srcs = ["pack_media_sequence_calculator.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core:protos_all",
],
visibility = ["//visibility:public"],
deps = [":pack_media_sequence_calculator_proto"],
@ -181,7 +192,7 @@ mediapipe_cc_proto_library(
srcs = ["tensorflow_session_from_frozen_graph_generator.proto"],
cc_deps = [
"//mediapipe/framework:packet_generator_cc_proto",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core:protos_all",
],
visibility = ["//visibility:public"],
deps = [":tensorflow_session_from_frozen_graph_generator_proto"],
@ -192,7 +203,7 @@ mediapipe_cc_proto_library(
srcs = ["tensorflow_session_from_frozen_graph_calculator.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core:protos_all",
],
visibility = ["//visibility:public"],
deps = [":tensorflow_session_from_frozen_graph_calculator_proto"],
@ -261,6 +272,17 @@ mediapipe_cc_proto_library(
deps = [":unpack_media_sequence_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "vector_int_to_tensor_calculator_options_cc_proto",
srcs = ["vector_int_to_tensor_calculator_options.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
"@org_tensorflow//tensorflow/core:protos_all",
],
visibility = ["//visibility:public"],
deps = [":vector_int_to_tensor_calculator_options_proto"],
)
mediapipe_cc_proto_library(
name = "vector_float_to_tensor_calculator_options_cc_proto",
srcs = ["vector_float_to_tensor_calculator_options.proto"],
@ -274,7 +296,7 @@ cc_library(
srcs = ["graph_tensors_packet_generator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/tensorflow:graph_tensors_packet_generator_cc_proto",
":graph_tensors_packet_generator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
@ -289,7 +311,7 @@ cc_library(
srcs = ["image_frame_to_tensor_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/tensorflow:image_frame_to_tensor_calculator_cc_proto",
":image_frame_to_tensor_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:ret_check",
@ -311,7 +333,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/calculators/tensorflow:matrix_to_tensor_calculator_options_cc_proto",
":matrix_to_tensor_calculator_options_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:status",
@ -332,7 +354,7 @@ cc_library(
srcs = ["lapped_tensor_buffer_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/tensorflow:lapped_tensor_buffer_calculator_cc_proto",
":lapped_tensor_buffer_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
@ -386,7 +408,7 @@ cc_library(
"//mediapipe/util/sequence:media_sequence",
"//mediapipe/util/sequence:media_sequence_util",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core:protos_all",
],
alwayslink = 1,
)
@ -401,7 +423,7 @@ cc_library(
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core:protos_all",
],
alwayslink = 1,
)
@ -414,7 +436,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":tensorflow_session",
"//mediapipe/calculators/tensorflow:tensorflow_inference_calculator_cc_proto",
":tensorflow_inference_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/tool:status_util",
"@com_google_absl//absl/strings",
@ -492,7 +514,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":tensorflow_session",
"//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_generator_cc_proto",
":tensorflow_session_from_frozen_graph_generator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/tool:status_util",
"//mediapipe/framework/port:status",
@ -551,7 +573,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":tensorflow_session",
"//mediapipe/calculators/tensorflow:tensorflow_session_from_saved_model_generator_cc_proto",
":tensorflow_session_from_saved_model_generator_cc_proto",
"//mediapipe/framework:packet_generator",
"//mediapipe/framework:packet_type",
"//mediapipe/framework/tool:status_util",
@ -575,7 +597,7 @@ cc_library(
srcs = ["tensor_squeeze_dimensions_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/tensorflow:tensor_squeeze_dimensions_calculator_cc_proto",
":tensor_squeeze_dimensions_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
@ -589,7 +611,7 @@ cc_library(
srcs = ["tensor_to_image_frame_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/tensorflow:tensor_to_image_frame_calculator_cc_proto",
":tensor_to_image_frame_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:ret_check",
@ -605,7 +627,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/calculators/tensorflow:tensor_to_matrix_calculator_cc_proto",
":tensor_to_matrix_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:status",
@ -621,6 +643,22 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "tfrecord_reader_calculator",
srcs = ["tfrecord_reader_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core:protos_all",
],
alwayslink = 1,
)
cc_library(
name = "tensor_to_vector_float_calculator",
srcs = ["tensor_to_vector_float_calculator.cc"],
@ -629,7 +667,7 @@ cc_library(
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:ret_check",
"//mediapipe/calculators/tensorflow:tensor_to_vector_float_calculator_options_cc_proto",
":tensor_to_vector_float_calculator_options_cc_proto",
] + select({
"//conditions:default": [
"@org_tensorflow//tensorflow/core:framework",
@ -657,7 +695,21 @@ cc_library(
"//mediapipe/util:audio_decoder_cc_proto",
"//mediapipe/util/sequence:media_sequence",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core:protos_all",
],
alwayslink = 1,
)
cc_library(
name = "vector_int_to_tensor_calculator",
srcs = ["vector_int_to_tensor_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":vector_int_to_tensor_calculator_options_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@org_tensorflow//tensorflow/core:framework",
],
alwayslink = 1,
)
@ -667,7 +719,7 @@ cc_library(
srcs = ["vector_float_to_tensor_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/tensorflow:vector_float_to_tensor_calculator_options_cc_proto",
":vector_float_to_tensor_calculator_options_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
@ -676,12 +728,26 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "unpack_yt8m_sequence_example_calculator",
srcs = ["unpack_yt8m_sequence_example_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":lapped_tensor_buffer_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework/port:status",
"@org_tensorflow//tensorflow/core:protos_all",
],
alwayslink = 1,
)
cc_test(
name = "graph_tensors_packet_generator_test",
srcs = ["graph_tensors_packet_generator_test.cc"],
deps = [
":graph_tensors_packet_generator",
"//mediapipe/calculators/tensorflow:graph_tensors_packet_generator_cc_proto",
":graph_tensors_packet_generator_cc_proto",
"//mediapipe/framework:packet",
"//mediapipe/framework:packet_generator_cc_proto",
"//mediapipe/framework:packet_set",
@ -713,7 +779,7 @@ cc_test(
srcs = ["matrix_to_tensor_calculator_test.cc"],
deps = [
":matrix_to_tensor_calculator",
"//mediapipe/calculators/tensorflow:matrix_to_tensor_calculator_options_cc_proto",
":matrix_to_tensor_calculator_options_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:matrix",
@ -729,13 +795,13 @@ cc_test(
srcs = ["lapped_tensor_buffer_calculator_test.cc"],
deps = [
":lapped_tensor_buffer_calculator",
"//mediapipe/calculators/tensorflow:lapped_tensor_buffer_calculator_cc_proto",
":lapped_tensor_buffer_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/memory",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core:protos_all",
],
)
@ -774,7 +840,7 @@ cc_test(
"//mediapipe/util/sequence:media_sequence",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core:protos_all",
],
)
@ -801,7 +867,7 @@ cc_test(
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:direct_session",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core:protos_all",
"@org_tensorflow//tensorflow/core:testlib",
"@org_tensorflow//tensorflow/core/kernels:conv_ops",
"@org_tensorflow//tensorflow/core/kernels:math",
@ -817,7 +883,7 @@ cc_test(
":tensorflow_inference_calculator",
":tensorflow_session",
":tensorflow_session_from_frozen_graph_generator",
"//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_generator_cc_proto",
":tensorflow_session_from_frozen_graph_generator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework:packet_generator_cc_proto",
@ -831,7 +897,7 @@ cc_test(
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:direct_session",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core:protos_all",
"@org_tensorflow//tensorflow/core:testlib",
"@org_tensorflow//tensorflow/core/kernels:conv_ops",
"@org_tensorflow//tensorflow/core/kernels:math",
@ -847,7 +913,7 @@ cc_test(
":tensorflow_inference_calculator",
":tensorflow_session",
":tensorflow_session_from_saved_model_generator",
"//mediapipe/calculators/tensorflow:tensorflow_session_from_saved_model_generator_cc_proto",
":tensorflow_session_from_saved_model_generator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework:packet_generator_cc_proto",
@ -857,14 +923,8 @@ cc_test(
"//mediapipe/framework/tool:tag_map_helper",
"//mediapipe/framework/tool:validate_type",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:all_kernels",
"@org_tensorflow//tensorflow/core:direct_session",
"@org_tensorflow//tensorflow/core/kernels:array",
"@org_tensorflow//tensorflow/core/kernels:bitcast_op",
"@org_tensorflow//tensorflow/core/kernels:conv_ops",
"@org_tensorflow//tensorflow/core/kernels:io",
"@org_tensorflow//tensorflow/core/kernels:state",
"@org_tensorflow//tensorflow/core/kernels:string",
"@org_tensorflow//tensorflow/core/kernels/data:tensor_dataset_op",
],
)
@ -888,14 +948,8 @@ cc_test(
"//mediapipe/framework/tool:tag_map_helper",
"//mediapipe/framework/tool:validate_type",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:all_kernels",
"@org_tensorflow//tensorflow/core:direct_session",
"@org_tensorflow//tensorflow/core/kernels:array",
"@org_tensorflow//tensorflow/core/kernels:bitcast_op",
"@org_tensorflow//tensorflow/core/kernels:conv_ops",
"@org_tensorflow//tensorflow/core/kernels:io",
"@org_tensorflow//tensorflow/core/kernels:state",
"@org_tensorflow//tensorflow/core/kernels:string",
"@org_tensorflow//tensorflow/core/kernels/data:tensor_dataset_op",
],
)
@ -904,12 +958,12 @@ cc_test(
srcs = ["tensor_squeeze_dimensions_calculator_test.cc"],
deps = [
":tensor_squeeze_dimensions_calculator",
"//mediapipe/calculators/tensorflow:tensor_squeeze_dimensions_calculator_cc_proto",
":tensor_squeeze_dimensions_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core:protos_all",
],
)
@ -919,13 +973,13 @@ cc_test(
srcs = ["tensor_to_image_frame_calculator_test.cc"],
deps = [
":tensor_to_image_frame_calculator",
"//mediapipe/calculators/tensorflow:tensor_to_image_frame_calculator_cc_proto",
":tensor_to_image_frame_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:gtest_main",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core:protos_all",
],
)
@ -935,14 +989,14 @@ cc_test(
srcs = ["tensor_to_matrix_calculator_test.cc"],
deps = [
":tensor_to_matrix_calculator",
"//mediapipe/calculators/tensorflow:tensor_to_matrix_calculator_cc_proto",
":tensor_to_matrix_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:gtest_main",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core:protos_all",
],
)
@ -951,12 +1005,12 @@ cc_test(
srcs = ["tensor_to_vector_float_calculator_test.cc"],
deps = [
":tensor_to_vector_float_calculator",
"//mediapipe/calculators/tensorflow:tensor_to_vector_float_calculator_options_cc_proto",
":tensor_to_vector_float_calculator_options_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core:protos_all",
],
)
@ -976,7 +1030,21 @@ cc_test(
"//mediapipe/util/sequence:media_sequence",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core:protos_all",
],
)
cc_test(
name = "vector_int_to_tensor_calculator_test",
srcs = ["vector_int_to_tensor_calculator_test.cc"],
deps = [
":vector_int_to_tensor_calculator",
":vector_int_to_tensor_calculator_options_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all",
],
)
@ -985,12 +1053,12 @@ cc_test(
srcs = ["vector_float_to_tensor_calculator_test.cc"],
deps = [
":vector_float_to_tensor_calculator",
"//mediapipe/calculators/tensorflow:vector_float_to_tensor_calculator_options_cc_proto",
":vector_float_to_tensor_calculator_options_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core:protos_all",
],
)
@ -1014,7 +1082,7 @@ cc_test(
":tensorflow_session",
":tensorflow_inference_calculator",
":tensorflow_session_from_frozen_graph_generator",
"//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_generator_cc_proto",
":tensorflow_session_from_frozen_graph_generator_cc_proto",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",

View File

@ -29,6 +29,11 @@
namespace mediapipe {
const char kBufferSize[] = "BUFFER_SIZE";
const char kOverlap[] = "OVERLAP";
const char kTimestampOffset[] = "TIMESTAMP_OFFSET";
const char kCalculatorOptions[] = "CALCULATOR_OPTIONS";
namespace tf = tensorflow;
// Given an input stream of tensors, concatenates the tensors over timesteps.
@ -72,6 +77,9 @@ class LappedTensorBufferCalculator : public CalculatorBase {
::mediapipe::Status AddBatchDimension(tf::Tensor* input_tensor);
int steps_until_output_;
int buffer_size_;
int overlap_;
int timestamp_offset_;
std::unique_ptr<CircularBuffer<Timestamp>> timestamp_buffer_;
std::unique_ptr<CircularBuffer<tf::Tensor>> buffer_;
LappedTensorBufferCalculatorOptions options_;
@ -87,6 +95,21 @@ REGISTER_CALCULATOR(LappedTensorBufferCalculator);
);
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
<< "Only one output stream is supported.";
if (cc->InputSidePackets().HasTag(kBufferSize)) {
cc->InputSidePackets().Tag(kBufferSize).Set<int>();
}
if (cc->InputSidePackets().HasTag(kOverlap)) {
cc->InputSidePackets().Tag(kOverlap).Set<int>();
}
if (cc->InputSidePackets().HasTag(kTimestampOffset)) {
cc->InputSidePackets().Tag(kTimestampOffset).Set<int>();
}
if (cc->InputSidePackets().HasTag(kCalculatorOptions)) {
cc->InputSidePackets()
.Tag(kCalculatorOptions)
.Set<LappedTensorBufferCalculatorOptions>();
}
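// These optional side packets let a graph override the corresponding options
// fields at run time, e.g. (a sketch; the names after the colons are
// illustrative):
//   input_side_packet: "BUFFER_SIZE:buffer_size"
//   input_side_packet: "CALCULATOR_OPTIONS:lapped_tensor_buffer_options"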
cc->Outputs().Index(0).Set<tf::Tensor>(
// Output tensorflow::Tensor stream with possibly overlapping steps.
);
@ -95,16 +118,33 @@ REGISTER_CALCULATOR(LappedTensorBufferCalculator);
::mediapipe::Status LappedTensorBufferCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<LappedTensorBufferCalculatorOptions>();
RET_CHECK_LT(options_.overlap(), options_.buffer_size());
RET_CHECK_GE(options_.timestamp_offset(), 0)
if (cc->InputSidePackets().HasTag(kCalculatorOptions)) {
options_ = cc->InputSidePackets()
.Tag(kCalculatorOptions)
.Get<LappedTensorBufferCalculatorOptions>();
}
buffer_size_ = options_.buffer_size();
if (cc->InputSidePackets().HasTag(kBufferSize)) {
buffer_size_ = cc->InputSidePackets().Tag(kBufferSize).Get<int>();
}
overlap_ = options_.overlap();
if (cc->InputSidePackets().HasTag(kOverlap)) {
overlap_ = cc->InputSidePackets().Tag(kOverlap).Get<int>();
}
timestamp_offset_ = options_.timestamp_offset();
if (cc->InputSidePackets().HasTag(kTimestampOffset)) {
timestamp_offset_ = cc->InputSidePackets().Tag(kTimestampOffset).Get<int>();
}
RET_CHECK_LT(overlap_, buffer_size_);
RET_CHECK_GE(timestamp_offset_, 0)
<< "Negative timestamp_offset is not allowed.";
RET_CHECK_LT(options_.timestamp_offset(), options_.buffer_size())
RET_CHECK_LT(timestamp_offset_, buffer_size_)
<< "output_frame_num_offset has to be less than buffer_size.";
timestamp_buffer_ =
absl::make_unique<CircularBuffer<Timestamp>>(options_.buffer_size());
buffer_ =
absl::make_unique<CircularBuffer<tf::Tensor>>(options_.buffer_size());
steps_until_output_ = options_.buffer_size();
absl::make_unique<CircularBuffer<Timestamp>>(buffer_size_);
buffer_ = absl::make_unique<CircularBuffer<tf::Tensor>>(buffer_size_);
steps_until_output_ = buffer_size_;
return ::mediapipe::OkStatus();
}
@ -128,11 +168,10 @@ REGISTER_CALCULATOR(LappedTensorBufferCalculator);
concatenated.get());
RET_CHECK(concat_status.ok()) << concat_status.ToString();
cc->Outputs().Index(0).Add(
concatenated.release(),
timestamp_buffer_->Get(options_.timestamp_offset()));
cc->Outputs().Index(0).Add(concatenated.release(),
timestamp_buffer_->Get(timestamp_offset_));
steps_until_output_ = options_.buffer_size() - options_.overlap();
steps_until_output_ = buffer_size_ - overlap_;
}
return ::mediapipe::OkStatus();
}

View File

@ -0,0 +1,126 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <utility>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system.h"
namespace mediapipe {
const char kTFRecordPath[] = "TFRECORD_PATH";
const char kRecordIndex[] = "RECORD_INDEX";
const char kExampleTag[] = "EXAMPLE";
const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
// Reads a tensorflow example/sequence example from a tfrecord file.
// If the "RECORD_INDEX" input side packet is provided, the calculator is going
// to fetch the example/sequence example of the tfrecord file at the target
// record index. Otherwise, the reader always reads the first example/sequence
// example of the tfrecord file.
//
// Example config:
// node {
// calculator: "TFRecordReaderCalculator"
// input_side_packet: "TFRECORD_PATH:tfrecord_path"
// input_side_packet: "RECORD_INDEX:record_index"
// output_side_packet: "SEQUENCE_EXAMPLE:sequence_example"
// }
class TFRecordReaderCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
};
::mediapipe::Status TFRecordReaderCalculator::GetContract(
CalculatorContract* cc) {
cc->InputSidePackets().Tag(kTFRecordPath).Set<std::string>();
if (cc->InputSidePackets().HasTag(kRecordIndex)) {
cc->InputSidePackets().Tag(kRecordIndex).Set<int>();
}
RET_CHECK(cc->OutputSidePackets().HasTag(kExampleTag) ||
cc->OutputSidePackets().HasTag(kSequenceExampleTag))
<< "TFRecordReaderCalculator must output either Tensorflow example or "
"sequence example.";
if (cc->OutputSidePackets().HasTag(kExampleTag)) {
cc->OutputSidePackets().Tag(kExampleTag).Set<tensorflow::Example>();
} else {
cc->OutputSidePackets()
.Tag(kSequenceExampleTag)
.Set<tensorflow::SequenceExample>();
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status TFRecordReaderCalculator::Open(CalculatorContext* cc) {
std::unique_ptr<tensorflow::RandomAccessFile> file;
auto tf_status = tensorflow::Env::Default()->NewRandomAccessFile(
cc->InputSidePackets().Tag(kTFRecordPath).Get<std::string>(), &file);
RET_CHECK(tf_status.ok())
<< "Failed to open tfrecord file: " << tf_status.error_message();
tensorflow::io::RecordReader reader(file.get(),
tensorflow::io::RecordReaderOptions());
tensorflow::uint64 offset = 0;
std::string example_str;
const int target_idx =
cc->InputSidePackets().HasTag(kRecordIndex)
? cc->InputSidePackets().Tag(kRecordIndex).Get<int>()
: 0;
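// TFRecord files carry no index, so records are read sequentially from the
// beginning until the requested record is reached.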
int current_idx = 0;
while (current_idx <= target_idx) {
tf_status = reader.ReadRecord(&offset, &example_str);
RET_CHECK(tf_status.ok())
<< "Failed to read tfrecord: " << tf_status.error_message();
if (current_idx == target_idx) {
if (cc->OutputSidePackets().HasTag(kExampleTag)) {
tensorflow::Example tf_example;
tf_example.ParseFromString(example_str);
cc->OutputSidePackets()
.Tag(kExampleTag)
.Set(MakePacket<tensorflow::Example>(std::move(tf_example)));
} else {
tensorflow::SequenceExample tf_sequence_example;
tf_sequence_example.ParseFromString(example_str);
cc->OutputSidePackets()
.Tag(kSequenceExampleTag)
.Set(MakePacket<tensorflow::SequenceExample>(
std::move(tf_sequence_example)));
}
}
++current_idx;
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status TFRecordReaderCalculator::Process(CalculatorContext* cc) {
return ::mediapipe::OkStatus();
}
REGISTER_CALCULATOR(TFRecordReaderCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,192 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iterator>
#include "mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
namespace mediapipe {
namespace {
const char kId[] = "id";
const char kRgb[] = "rgb";
const char kAudio[] = "audio";
const char kDesiredSegmentSize[] = "DESIRED_SEGMENT_SIZE";
const char kYt8mId[] = "YT8M_ID";
const char kYt8mSequenceExample[] = "YT8M_SEQUENCE_EXAMPLE";
const char kQuantizedRgbFeature[] = "QUANTIZED_RGB_FEATURE";
const char kQuantizedAudioFeature[] = "QUANTIZED_AUDIO_FEATURE";
const char kSegmentSize[] = "SEGMENT_SIZE";
const char kLappedTensorBufferCalculatorOptions[] =
"LAPPED_TENSOR_BUFFER_CALCULATOR_OPTIONS";
std::string GetQuantizedFeature(
const tensorflow::SequenceExample& sequence_example, const std::string& key,
int index) {
const auto& bytes_list = sequence_example.feature_lists()
.feature_list()
.at(key)
.feature()
.Get(index)
.bytes_list()
.value();
CHECK_EQ(1, bytes_list.size());
return bytes_list.Get(0);
}
} // namespace
// Unpacks a YT8M SequenceExample. Note that the emitted audio and rgb features
// are quantized; DequantizeByteArrayCalculator can perform the dequantization.
//
// Example config:
// node {
// calculator: "UnpackYt8mSequenceExampleCalculator"
// input_side_packet: "YT8M_SEQUENCE_EXAMPLE:yt8m_sequence_example"
// output_stream: "QUANTIZED_RGB_FEATURE:quantized_rgb_feature"
// output_stream: "QUANTIZED_AUDIO_FEATURE:quantized_audio_feature"
// }
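//
// As a rough sketch of the downstream dequantization step (the tag names,
// option fields, and value range below are assumptions and should be checked
// against DequantizeByteArrayCalculator's definition), the quantized output
// could be wired as:
// node {
//   calculator: "DequantizeByteArrayCalculator"
//   input_stream: "ENCODED:quantized_rgb_feature"
//   output_stream: "FLOAT_VECTOR:rgb_feature"
//   options {
//     [mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
//       max_quantized_value: 2
//       min_quantized_value: -2
//     }
//   }
// }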
class UnpackYt8mSequenceExampleCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets()
.Tag(kYt8mSequenceExample)
.Set<tensorflow::SequenceExample>();
if (cc->InputSidePackets().HasTag(kDesiredSegmentSize)) {
cc->InputSidePackets().Tag(kDesiredSegmentSize).Set<int>();
}
cc->Outputs().Tag(kQuantizedRgbFeature).Set<std::string>();
cc->Outputs().Tag(kQuantizedAudioFeature).Set<std::string>();
if (cc->OutputSidePackets().HasTag(kYt8mId)) {
cc->OutputSidePackets().Tag(kYt8mId).Set<std::string>();
}
if (cc->OutputSidePackets().HasTag(kLappedTensorBufferCalculatorOptions)) {
cc->OutputSidePackets()
.Tag(kLappedTensorBufferCalculatorOptions)
.Set<::mediapipe::LappedTensorBufferCalculatorOptions>();
}
if (cc->OutputSidePackets().HasTag(kSegmentSize)) {
cc->OutputSidePackets().Tag(kSegmentSize).Set<int>();
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
const tensorflow::SequenceExample& sequence_example =
cc->InputSidePackets()
.Tag(kYt8mSequenceExample)
.Get<tensorflow::SequenceExample>();
const std::string& yt8m_id =
sequence_example.context().feature().at(kId).bytes_list().value().Get(
0);
if (cc->OutputSidePackets().HasTag(kYt8mId)) {
cc->OutputSidePackets().Tag(kYt8mId).Set(
MakePacket<std::string>(yt8m_id));
}
int rgb_feature_list_length =
sequence_example.feature_lists().feature_list().at(kRgb).feature_size();
int audio_feature_list_length = sequence_example.feature_lists()
.feature_list()
.at(kAudio)
.feature_size();
if (rgb_feature_list_length != audio_feature_list_length) {
return ::mediapipe::FailedPreconditionError(absl::StrCat(
"Data corruption: the length of audio features and rgb features are "
"not equal. Please check the sequence example that contains yt8m "
"id: ",
yt8m_id));
}
feature_list_length_ = rgb_feature_list_length;
if (cc->OutputSidePackets().HasTag(kLappedTensorBufferCalculatorOptions) ||
cc->OutputSidePackets().HasTag(kSegmentSize)) {
// If the desired segment size is specified, take the min of the length of
// the feature list and the desired size to be the output segment size.
int segment_size = feature_list_length_;
if (cc->InputSidePackets().HasTag(kDesiredSegmentSize)) {
int desired_segment_size =
cc->InputSidePackets().Tag(kDesiredSegmentSize).Get<int>();
RET_CHECK(desired_segment_size > 0)
<< "The desired segment size must be greater than zero.";
        segment_size = std::min(feature_list_length_, desired_segment_size);
}
if (cc->OutputSidePackets().HasTag(
kLappedTensorBufferCalculatorOptions)) {
auto lapped_tensor_buffer_calculator_options = absl::make_unique<
::mediapipe::LappedTensorBufferCalculatorOptions>();
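        // With buffer_size == segment_size and overlap == segment_size - 1,
        // the downstream LappedTensorBufferCalculator emits a sliding window
        // that advances by one frame per output, and timestamp_offset aligns
        // each output with the last frame in the window (assumed semantics of
        // that calculator).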
lapped_tensor_buffer_calculator_options->set_add_batch_dim_to_tensors(
true);
lapped_tensor_buffer_calculator_options->set_buffer_size(segment_size);
lapped_tensor_buffer_calculator_options->set_overlap(segment_size - 1);
lapped_tensor_buffer_calculator_options->set_timestamp_offset(
segment_size - 1);
cc->OutputSidePackets()
.Tag(kLappedTensorBufferCalculatorOptions)
.Set(Adopt(lapped_tensor_buffer_calculator_options.release()));
}
if (cc->OutputSidePackets().HasTag(kSegmentSize)) {
cc->OutputSidePackets()
.Tag(kSegmentSize)
.Set(MakePacket<int>(segment_size));
}
}
LOG(INFO) << "Reading the sequence example that contains yt8m id: "
<< yt8m_id << ". Feature list length: " << feature_list_length_;
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
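    // Stop the graph once every feature in the sequence has been emitted.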
if (current_index_ >= feature_list_length_) {
return ::mediapipe::tool::StatusStop();
}
const tensorflow::SequenceExample& sequence_example =
cc->InputSidePackets()
.Tag(kYt8mSequenceExample)
.Get<tensorflow::SequenceExample>();
// Uses microsecond as the unit of time. In the YT8M dataset, each feature
// represents a second.
const Timestamp timestamp = Timestamp(current_index_ * 1000000);
cc->Outputs()
.Tag(kQuantizedRgbFeature)
.AddPacket(
MakePacket<std::string>(
GetQuantizedFeature(sequence_example, kRgb, current_index_))
.At(timestamp));
cc->Outputs()
.Tag(kQuantizedAudioFeature)
.AddPacket(
MakePacket<std::string>(
GetQuantizedFeature(sequence_example, kAudio, current_index_))
.At(timestamp));
++current_index_;
return ::mediapipe::OkStatus();
}
private:
int current_index_ = 0;
int feature_list_length_ = 0;
};
REGISTER_CALCULATOR(UnpackYt8mSequenceExampleCalculator);
} // namespace mediapipe

View File

@ -23,10 +23,12 @@
namespace mediapipe {
namespace tf = ::tensorflow;
namespace {
auto& INPUT_1D = VectorFloatToTensorCalculatorOptions::INPUT_1D;
auto& INPUT_2D = VectorFloatToTensorCalculatorOptions::INPUT_2D;
} // namespace
namespace tf = ::tensorflow;
// The calculator expects one input (a packet containing a vector<float> or
// vector<vector<float>>) and generates one output (a packet containing a

View File

@ -0,0 +1,203 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Converts a single int or vector<int> or vector<vector<int>> to 1D (or 2D)
// tf::Tensor.
#include "mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
namespace mediapipe {
const char kVectorInt[] = "VECTOR_INT";
const char kSingleInt[] = "SINGLE_INT";
const char kTensorOut[] = "TENSOR_OUT";
namespace {
auto& INPUT_1D = VectorIntToTensorCalculatorOptions::INPUT_1D;
auto& INPUT_2D = VectorIntToTensorCalculatorOptions::INPUT_2D;
} // namespace
namespace tf = ::tensorflow;
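// Writes `value` into element (r, c) of a rank-2 tensor of the given type.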
template <typename TensorType>
void AssignMatrixValue(int r, int c, int value, tf::Tensor* output_tensor) {
output_tensor->tensor<TensorType, 2>()(r, c) = value;
}
// The calculator expects one input (a packet containing a single int or
// vector<int> or vector<vector<int>>) and generates one output (a packet
// containing a tf::Tensor with the same data). The output tensor is either 1D
// or 2D, with dimensions matching the input, and holds DT_INT32, DT_UINT8 or
// DT_INT64 values.
//
// Example config:
// node {
// calculator: "VectorIntToTensorCalculator"
// input_stream: "SINGLE_INT:segment_size_int_stream"
// output_stream: "TENSOR_OUT:segment_size_tensor"
// }
//
// or
//
// node {
// calculator: "VectorIntToTensorCalculator"
// input_stream: "VECTOR_INT:vector_int_features"
// output_stream: "TENSOR_OUT:tensor_features"
// }
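//
// For illustration, a minimal options block using the fields defined in
// vector_int_to_tensor_calculator_options.proto might look like:
// node {
//   calculator: "VectorIntToTensorCalculator"
//   input_stream: "VECTOR_INT:vector_int_features"
//   output_stream: "TENSOR_OUT:tensor_features"
//   options {
//     [mediapipe.VectorIntToTensorCalculatorOptions.ext] {
//       input_size: INPUT_2D
//       transpose: false
//       tensor_data_type: DT_INT64
//     }
//   }
// }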
class VectorIntToTensorCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
private:
VectorIntToTensorCalculatorOptions options_;
};
REGISTER_CALCULATOR(VectorIntToTensorCalculator);
::mediapipe::Status VectorIntToTensorCalculator::GetContract(
CalculatorContract* cc) {
const auto& options = cc->Options<VectorIntToTensorCalculatorOptions>();
// Start with only one input packet.
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
<< "Only one input stream is supported.";
if (options.input_size() == INPUT_2D) {
cc->Inputs().Tag(kVectorInt).Set<std::vector<std::vector<int>>>();
} else if (options.input_size() == INPUT_1D) {
if (cc->Inputs().HasTag(kSingleInt)) {
cc->Inputs().Tag(kSingleInt).Set<int>();
} else {
cc->Inputs().Tag(kVectorInt).Set<std::vector<int>>();
}
} else {
LOG(FATAL) << "input size not supported";
}
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
<< "Only one output stream is supported.";
cc->Outputs().Tag(kTensorOut).Set<tf::Tensor>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status VectorIntToTensorCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<VectorIntToTensorCalculatorOptions>();
RET_CHECK(options_.tensor_data_type() == tf::DT_UINT8 ||
options_.tensor_data_type() == tf::DT_INT32 ||
options_.tensor_data_type() == tf::DT_INT64)
<< "Output tensor data type is not supported.";
return ::mediapipe::OkStatus();
}
::mediapipe::Status VectorIntToTensorCalculator::Process(
CalculatorContext* cc) {
tf::TensorShape tensor_shape;
if (options_.input_size() == INPUT_2D) {
const std::vector<std::vector<int>>& input =
cc->Inputs()
.Tag(kVectorInt)
.Value()
.Get<std::vector<std::vector<int>>>();
const int32 rows = input.size();
CHECK_GE(rows, 1);
const int32 cols = input[0].size();
CHECK_GE(cols, 1);
for (int i = 1; i < rows; ++i) {
CHECK_EQ(input[i].size(), cols);
}
if (options_.transpose()) {
tensor_shape = tf::TensorShape({cols, rows});
} else {
tensor_shape = tf::TensorShape({rows, cols});
}
auto output = ::absl::make_unique<tf::Tensor>(options_.tensor_data_type(),
tensor_shape);
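    // When transpose is requested, input[r][c] is written to output(c, r);
    // otherwise it is written to output(r, c).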
if (options_.transpose()) {
for (int r = 0; r < rows; ++r) {
for (int c = 0; c < cols; ++c) {
switch (options_.tensor_data_type()) {
case tf::DT_INT64:
AssignMatrixValue<tf::int64>(c, r, input[r][c], output.get());
break;
case tf::DT_UINT8:
AssignMatrixValue<uint8>(c, r, input[r][c], output.get());
break;
case tf::DT_INT32:
AssignMatrixValue<int>(c, r, input[r][c], output.get());
break;
default:
LOG(FATAL) << "tensor data type is not supported.";
}
}
}
} else {
for (int r = 0; r < rows; ++r) {
for (int c = 0; c < cols; ++c) {
switch (options_.tensor_data_type()) {
case tf::DT_INT64:
AssignMatrixValue<tf::int64>(r, c, input[r][c], output.get());
break;
case tf::DT_UINT8:
AssignMatrixValue<uint8>(r, c, input[r][c], output.get());
break;
case tf::DT_INT32:
AssignMatrixValue<int>(r, c, input[r][c], output.get());
break;
default:
LOG(FATAL) << "tensor data type is not supported.";
}
}
}
}
cc->Outputs().Tag(kTensorOut).Add(output.release(), cc->InputTimestamp());
} else if (options_.input_size() == INPUT_1D) {
std::vector<int> input;
if (cc->Inputs().HasTag(kSingleInt)) {
input.push_back(cc->Inputs().Tag(kSingleInt).Get<int>());
} else {
input = cc->Inputs().Tag(kVectorInt).Value().Get<std::vector<int>>();
}
CHECK_GE(input.size(), 1);
const int32 length = input.size();
tensor_shape = tf::TensorShape({length});
auto output = ::absl::make_unique<tf::Tensor>(options_.tensor_data_type(),
tensor_shape);
for (int i = 0; i < length; ++i) {
switch (options_.tensor_data_type()) {
case tf::DT_INT64:
output->tensor<tf::int64, 1>()(i) = input.at(i);
break;
case tf::DT_UINT8:
output->tensor<uint8, 1>()(i) = input.at(i);
break;
case tf::DT_INT32:
output->tensor<int, 1>()(i) = input.at(i);
break;
default:
LOG(FATAL) << "tensor data type is not supported.";
}
}
cc->Outputs().Tag(kTensorOut).Add(output.release(), cc->InputTimestamp());
} else {
LOG(FATAL) << "input size not supported";
}
return ::mediapipe::OkStatus();
}
} // namespace mediapipe

View File

@ -0,0 +1,43 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
import "tensorflow/core/framework/types.proto";
message VectorIntToTensorCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional VectorIntToTensorCalculatorOptions ext = 275364184;
}
enum InputSize {
UNKNOWN = 0;
INPUT_1D = 1;
INPUT_2D = 2;
}
// If input_size is INPUT_2D, unpack a vector<vector<int>> to a
// 2d tensor (matrix). If INPUT_1D, convert a single int or vector<int>
// into a 1d tensor (vector).
optional InputSize input_size = 1 [default = INPUT_1D];
  // If true, the output 2d tensor is transposed.
  // This field is ignored if input_size is INPUT_1D.
optional bool transpose = 2 [default = false];
optional tensorflow.DataType tensor_data_type = 3 [default = DT_INT32];
}

View File

@ -0,0 +1,202 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
namespace mediapipe {
namespace {
namespace tf = ::tensorflow;
class VectorIntToTensorCalculatorTest : public ::testing::Test {
protected:
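  // Builds a CalculatorRunner for VectorIntToTensorCalculator with the given
  // input size, output data type, transpose flag, and input tag (SINGLE_INT
  // when single_value is true, VECTOR_INT otherwise).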
void SetUpRunner(
const VectorIntToTensorCalculatorOptions::InputSize input_size,
const tensorflow::DataType tensor_data_type, const bool transpose,
const bool single_value) {
CalculatorGraphConfig::Node config;
config.set_calculator("VectorIntToTensorCalculator");
if (single_value) {
config.add_input_stream("SINGLE_INT:input_int");
} else {
config.add_input_stream("VECTOR_INT:input_int");
}
config.add_output_stream("TENSOR_OUT:output_tensor");
auto options = config.mutable_options()->MutableExtension(
VectorIntToTensorCalculatorOptions::ext);
options->set_input_size(input_size);
options->set_transpose(transpose);
options->set_tensor_data_type(tensor_data_type);
runner_ = ::absl::make_unique<CalculatorRunner>(config);
}
  void TestConvertFromVectorOfVectorInt(const bool transpose) {
SetUpRunner(VectorIntToTensorCalculatorOptions::INPUT_2D,
tensorflow::DT_INT32, transpose, false);
auto input = ::absl::make_unique<std::vector<std::vector<int>>>(
2, std::vector<int>(2));
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 2; ++j) {
input->at(i).at(j) = i * 2 + j;
}
}
const int64 time = 1234;
runner_->MutableInputs()
->Tag("VECTOR_INT")
.packets.push_back(Adopt(input.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
EXPECT_EQ(2, output_tensor.dims());
EXPECT_EQ(tf::DT_INT32, output_tensor.dtype());
const auto matrix = output_tensor.matrix<int>();
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 2; ++j) {
if (!transpose) {
EXPECT_EQ(i * 2 + j, matrix(i, j));
} else {
EXPECT_EQ(j * 2 + i, matrix(i, j));
}
}
}
}
std::unique_ptr<CalculatorRunner> runner_;
};
TEST_F(VectorIntToTensorCalculatorTest, TestSingleValue) {
SetUpRunner(VectorIntToTensorCalculatorOptions::INPUT_1D,
tensorflow::DT_INT32, false, true);
const int64 time = 1234;
runner_->MutableInputs()
->Tag("SINGLE_INT")
.packets.push_back(MakePacket<int>(1).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
EXPECT_EQ(1, output_tensor.dims());
EXPECT_EQ(tf::DT_INT32, output_tensor.dtype());
const auto vec = output_tensor.vec<int32>();
EXPECT_EQ(1, vec(0));
}
TEST_F(VectorIntToTensorCalculatorTest, TestOneDim) {
SetUpRunner(VectorIntToTensorCalculatorOptions::INPUT_1D,
tensorflow::DT_INT32, false, false);
auto input = ::absl::make_unique<std::vector<int>>(5);
for (int i = 0; i < 5; ++i) {
input->at(i) = i;
}
const int64 time = 1234;
runner_->MutableInputs()
->Tag("VECTOR_INT")
.packets.push_back(Adopt(input.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
EXPECT_EQ(1, output_tensor.dims());
EXPECT_EQ(tf::DT_INT32, output_tensor.dtype());
const auto vec = output_tensor.vec<int32>();
for (int i = 0; i < 5; ++i) {
EXPECT_EQ(i, vec(i));
}
}
TEST_F(VectorIntToTensorCalculatorTest, TestTwoDims) {
for (bool transpose : {false, true}) {
    TestConvertFromVectorOfVectorInt(transpose);
}
}
TEST_F(VectorIntToTensorCalculatorTest, TestInt64) {
SetUpRunner(VectorIntToTensorCalculatorOptions::INPUT_1D,
tensorflow::DT_INT64, false, true);
const int64 time = 1234;
runner_->MutableInputs()
->Tag("SINGLE_INT")
      .packets.push_back(MakePacket<int>(1 << 30).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
EXPECT_EQ(1, output_tensor.dims());
EXPECT_EQ(tf::DT_INT64, output_tensor.dtype());
const auto vec = output_tensor.vec<tf::int64>();
  EXPECT_EQ(1 << 30, vec(0));
}
TEST_F(VectorIntToTensorCalculatorTest, TestUint8) {
SetUpRunner(VectorIntToTensorCalculatorOptions::INPUT_1D,
tensorflow::DT_UINT8, false, false);
auto input = ::absl::make_unique<std::vector<int>>(5);
for (int i = 0; i < 5; ++i) {
input->at(i) = i;
}
const int64 time = 1234;
runner_->MutableInputs()
->Tag("VECTOR_INT")
.packets.push_back(Adopt(input.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
EXPECT_EQ(1, output_tensor.dims());
EXPECT_EQ(tf::DT_UINT8, output_tensor.dtype());
const auto vec = output_tensor.vec<uint8>();
for (int i = 0; i < 5; ++i) {
EXPECT_EQ(i, vec(i));
}
}
} // namespace
} // namespace mediapipe

View File

@ -238,6 +238,7 @@ cc_library(
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape",
"@org_tensorflow//tensorflow/lite/delegates/gpu/metal:buffer_convert",
"@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate",
"@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate_internal",
],
"//conditions:default": [
"//mediapipe/gpu:gl_calculator_helper",

View File

@ -25,7 +25,8 @@
#include "tensorflow/lite/error_reporter.h"
#include "tensorflow/lite/interpreter.h"
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
@ -45,7 +46,8 @@
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
#endif // iOS
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
typedef id<MTLBuffer> GpuTensor;
@ -67,7 +69,8 @@ typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>
namespace mediapipe {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
using ::tflite::gpu::gl::GlProgram;
using ::tflite::gpu::gl::GlShader;
@ -146,7 +149,8 @@ class TfLiteConverterCalculator : public CalculatorBase {
std::unique_ptr<tflite::Interpreter> interpreter_ = nullptr;
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
mediapipe::GlCalculatorHelper gpu_helper_;
std::unique_ptr<GPUData> gpu_data_out_;
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
@ -181,7 +185,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
if (cc->Inputs().HasTag("IMAGE")) cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
if (cc->Inputs().HasTag("MATRIX")) cc->Inputs().Tag("MATRIX").Set<Matrix>();
#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
if (cc->Inputs().HasTag("IMAGE_GPU")) {
cc->Inputs().Tag("IMAGE_GPU").Set<mediapipe::GpuBuffer>();
use_gpu |= true;
@ -190,7 +194,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
if (cc->Outputs().HasTag("TENSORS"))
cc->Outputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
if (cc->Outputs().HasTag("TENSORS_GPU")) {
cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
use_gpu |= true;
@ -198,7 +202,8 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
#endif // !MEDIAPIPE_DISABLE_GPU
if (use_gpu) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
@ -218,7 +223,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
if (cc->Inputs().HasTag("IMAGE_GPU") ||
cc->Outputs().HasTag("IMAGE_OUT_GPU")) {
#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
use_gpu_ = true;
#else
RET_CHECK_FAIL() << "GPU processing not enabled.";
@ -231,7 +236,8 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
cc->Outputs().HasTag("TENSORS_GPU"));
// Cannot use quantization.
use_quantized_tensors_ = false;
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
@ -264,7 +270,8 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
}
::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); });
#endif
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS
@ -383,7 +390,8 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
::mediapipe::Status TfLiteConverterCalculator::ProcessGPU(
CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
// GpuBuffer to tflite::gpu::GlBuffer conversion.
const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>();
MP_RETURN_IF_ERROR(
@ -468,7 +476,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
}
::mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
// Get input image sizes.
const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>();
mediapipe::ImageFormat::Format format =
@ -485,7 +493,8 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
RET_CHECK_FAIL() << "Num input channels is less than desired output.";
#endif // !MEDIAPIPE_DISABLE_GPU
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &include_alpha, &input, &single_channel]() -> ::mediapipe::Status {
// Device memory.

View File

@ -27,7 +27,7 @@
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
@ -48,11 +48,13 @@
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h"
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
#include "tensorflow/lite/delegates/gpu/metal_delegate_internal.h"
#endif // iOS
namespace {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
typedef id<MTLBuffer> GpuTensor;
@ -68,13 +70,14 @@ size_t RoundUp(size_t n, size_t m) { return ((n + m - 1) / m) * m; } // NOLINT
// * Aux
namespace mediapipe {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
using ::tflite::gpu::gl::CopyBuffer;
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
using ::tflite::gpu::gl::GlBuffer;
#endif
#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
struct GPUData {
int elements = 1;
GpuTensor buffer;
@ -147,7 +150,8 @@ class TfLiteInferenceCalculator : public CalculatorBase {
std::unique_ptr<tflite::FlatBufferModel> model_;
TfLiteDelegate* delegate_ = nullptr;
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
mediapipe::GlCalculatorHelper gpu_helper_;
std::unique_ptr<GPUData> gpu_data_in_;
std::vector<std::unique_ptr<GPUData>> gpu_data_out_;
@ -179,7 +183,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
if (cc->Inputs().HasTag("TENSORS"))
cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
if (cc->Inputs().HasTag("TENSORS_GPU")) {
cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
use_gpu |= true;
@ -188,7 +192,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
if (cc->Outputs().HasTag("TENSORS"))
cc->Outputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
if (cc->Outputs().HasTag("TENSORS_GPU")) {
cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
use_gpu |= true;
@ -206,7 +210,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
use_gpu |= options.use_gpu();
if (use_gpu) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
@ -225,7 +230,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
MP_RETURN_IF_ERROR(LoadOptions(cc));
if (cc->Inputs().HasTag("TENSORS_GPU")) {
#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
gpu_input_ = true;
gpu_inference_ = true; // Inference must be on GPU also.
#else
@ -235,7 +240,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
}
if (cc->Outputs().HasTag("TENSORS_GPU")) {
#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
gpu_output_ = true;
RET_CHECK(cc->Inputs().HasTag("TENSORS_GPU"))
<< "GPU output must also have GPU Input.";
@ -248,13 +253,15 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
MP_RETURN_IF_ERROR(LoadModel(cc));
if (gpu_inference_) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
RET_CHECK(gpu_helper_);
#endif
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &cc]() -> ::mediapipe::Status { return LoadDelegate(cc); }));
#else
@ -262,6 +269,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
#endif
}
#if defined(__EMSCRIPTEN__)
MP_RETURN_IF_ERROR(LoadDelegate(cc));
#endif // __EMSCRIPTEN__
return ::mediapipe::OkStatus();
}
@ -269,7 +280,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
// 1. Receive pre-processed tensor inputs.
if (gpu_input_) {
// Read GPU input into SSBO.
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
const auto& input_tensors =
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GpuTensor>>();
RET_CHECK_EQ(input_tensors.size(), 1);
@ -315,7 +327,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
// 2. Run inference.
if (gpu_inference_) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this]() -> ::mediapipe::Status {
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
@ -330,7 +343,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
// 3. Output processed tensors.
if (gpu_output_) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
// Output result tensors (GPU).
auto output_tensors = absl::make_unique<std::vector<GpuTensor>>();
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
@ -392,7 +406,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
::mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) {
if (delegate_) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status {
TfLiteGpuDelegateDelete(delegate_);
gpu_data_in_.reset();
@ -456,6 +471,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
RET_CHECK(interpreter_);
#if defined(__EMSCRIPTEN__)
interpreter_->SetNumThreads(1);
#endif // __EMSCRIPTEN__
if (gpu_output_) {
use_quantized_tensors_ = false;
} else {
@ -471,7 +490,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
::mediapipe::Status TfLiteInferenceCalculator::LoadDelegate(
CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
// Configure and create the delegate.
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
options.compile_options.precision_loss_allowed = 1;
@ -533,9 +553,9 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS
// Configure and create the delegate.
GpuDelegateOptions options;
TFLGpuDelegateOptions options;
options.allow_precision_loss = false; // Must match converter, F=float/T=half
options.wait_type = GpuDelegateOptions::WaitType::kPassive;
options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypePassive;
if (!delegate_) delegate_ = TFLGpuDelegateCreate(&options);
id<MTLDevice> device = gpu_helper_.mtlDevice;

View File

@ -24,7 +24,8 @@
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/resource_util.h"
#include "tensorflow/lite/interpreter.h"
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
#if defined(__EMSCRIPTEN__) || defined(__ANDROID__) || \
(defined(__APPLE__) && !TARGET_OS_OSX)
#include "mediapipe/util/android/file/base/file.h"
#include "mediapipe/util/android/file/base/helpers.h"
#else
@ -66,8 +67,8 @@ class TfLiteTensorsToClassificationCalculator : public CalculatorBase {
::mediapipe::Status Close(CalculatorContext* cc) override;
private:
::mediapipe::TfLiteTensorsToClassificationCalculatorOptions options_;
int top_k_ = 0;
double min_score_threshold_ = 0;
std::unordered_map<int, std::string> label_map_;
bool label_map_loaded_ = false;
};
@ -93,15 +94,14 @@ REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator);
CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
auto options = cc->Options<
options_ = cc->Options<
::mediapipe::TfLiteTensorsToClassificationCalculatorOptions>();
top_k_ = options.top_k();
min_score_threshold_ = options.min_score_threshold();
if (options.has_label_map_path()) {
top_k_ = options_.top_k();
if (options_.has_label_map_path()) {
std::string string_path;
ASSIGN_OR_RETURN(string_path,
PathToResourceAsFile(options.label_map_path()));
PathToResourceAsFile(options_.label_map_path()));
std::string label_map_string;
MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string));
@ -125,9 +125,11 @@ REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator);
RET_CHECK_EQ(input_tensors.size(), 1);
const TfLiteTensor* raw_score_tensor = &input_tensors[0];
RET_CHECK_EQ(raw_score_tensor->dims->size, 2);
RET_CHECK_EQ(raw_score_tensor->dims->data[0], 1);
int num_classes = raw_score_tensor->dims->data[1];
int num_classes = 1;
for (int i = 0; i < raw_score_tensor->dims->size; ++i) {
num_classes *= raw_score_tensor->dims->data[i];
}
if (label_map_loaded_) {
RET_CHECK_EQ(num_classes, label_map_.size());
}
@ -135,7 +137,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator);
auto classification_list = absl::make_unique<ClassificationList>();
for (int i = 0; i < num_classes; ++i) {
if (raw_scores[i] < min_score_threshold_) {
if (options_.has_min_score_threshold() &&
raw_scores[i] < options_.min_score_threshold()) {
continue;
}
Classification* classification = classification_list->add_classification();
@ -148,6 +151,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator);
// Note that partial_sort will raise error when top_k_ >
// classification_list->classification_size().
CHECK_GE(classification_list->classification_size(), top_k_);
auto raw_classification_list = classification_list->mutable_classification();
if (top_k_ > 0 && classification_list->classification_size() >= top_k_) {
std::partial_sort(raw_classification_list->begin(),

View File

@ -27,7 +27,8 @@
#include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/interpreter.h"
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
@ -55,12 +56,14 @@ constexpr int kNumCoordsPerBox = 4;
namespace mediapipe {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
using ::tflite::gpu::gl::GlShader;
#endif
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
typedef ::tflite::gpu::gl::GlProgram GpuProgram;
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
@ -70,7 +73,7 @@ typedef id<MTLComputePipelineState> GpuProgram;
namespace {
#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
struct GPUData {
GpuProgram decode_program;
GpuProgram score_program;
@ -180,7 +183,8 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase {
std::vector<Anchor> anchors_;
bool side_packet_anchors_{};
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
mediapipe::GlCalculatorHelper gpu_helper_;
std::unique_ptr<GPUData> gpu_data_;
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
@ -204,7 +208,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
}
#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
if (cc->Inputs().HasTag("TENSORS_GPU")) {
cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
use_gpu |= true;
@ -222,7 +226,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
}
if (use_gpu) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
@ -238,7 +243,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
if (cc->Inputs().HasTag("TENSORS_GPU")) {
gpu_input_ = true;
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
@ -400,7 +406,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
}
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU(
CalculatorContext* cc, std::vector<Detection>* output_detections) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
const auto& input_tensors =
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GpuTensor>>();
RET_CHECK_GE(input_tensors.size(), 2);
@ -562,7 +569,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Close(
CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
gpu_helper_.RunInGlContext([this] { gpu_data_.reset(); });
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
gpu_data_.reset();
@ -715,7 +723,8 @@ Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection(
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GpuInit(
CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]()
-> ::mediapipe::Status {
gpu_data_ = absl::make_unique<GPUData>();

View File

@ -21,7 +21,8 @@
namespace mediapipe {
// A calculator for converting TFLite tensors from regression models into
// landmarks.
// landmarks. Note that if a landmark in the tensor has more than 3
// dimensions, only the first 3 dimensions will be converted to x, y and z.
//
// Input:
// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32. Only the first
@ -122,9 +123,6 @@ REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator);
num_values *= raw_tensor->dims->data[i];
}
const int num_dimensions = num_values / num_landmarks_;
// Landmarks must have less than 3 dimensions. Otherwise please consider
// using matrix.
CHECK_LE(num_dimensions, 3);
CHECK_GT(num_dimensions, 0);
const float* raw_landmarks = raw_tensor->data.f;

View File

@ -28,7 +28,8 @@
#include "mediapipe/util/resource_util.h"
#include "tensorflow/lite/interpreter.h"
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/shader_util.h"
@ -53,7 +54,8 @@ float Clamp(float val, float min, float max) {
namespace mediapipe {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
using ::tflite::gpu::gl::CopyBuffer;
using ::tflite::gpu::gl::CreateReadWriteRgbaImageTexture;
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
@ -129,7 +131,8 @@ class TfLiteTensorsToSegmentationCalculator : public CalculatorBase {
int tensor_channels_ = 0;
bool use_gpu_ = false;
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
mediapipe::GlCalculatorHelper gpu_helper_;
std::unique_ptr<GlProgram> mask_program_with_prev_;
std::unique_ptr<GlProgram> mask_program_no_prev_;
@ -159,7 +162,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
}
// Inputs GPU.
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
if (cc->Inputs().HasTag("TENSORS_GPU")) {
cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GlBuffer>>();
use_gpu |= true;
@ -178,7 +182,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
if (cc->Outputs().HasTag("MASK")) {
cc->Outputs().Tag("MASK").Set<ImageFrame>();
}
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
if (cc->Outputs().HasTag("MASK_GPU")) {
cc->Outputs().Tag("MASK_GPU").Set<mediapipe::GpuBuffer>();
use_gpu |= true;
@ -186,7 +191,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
#endif // !MEDIAPIPE_DISABLE_GPU
if (use_gpu) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#endif // !MEDIAPIPE_DISABLE_GPU
}
@ -199,7 +205,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
if (cc->Inputs().HasTag("TENSORS_GPU")) {
use_gpu_ = true;
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#endif // !MEDIAPIPE_DISABLE_GPU
}
@ -207,7 +214,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
MP_RETURN_IF_ERROR(LoadOptions(cc));
if (use_gpu_) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
MP_RETURN_IF_ERROR(InitGpu(cc));
@ -224,7 +232,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Process(
CalculatorContext* cc) {
if (use_gpu_) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
MP_RETURN_IF_ERROR(ProcessGpu(cc));
@ -240,7 +249,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Close(
CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
gpu_helper_.RunInGlContext([this] {
if (upsample_program_) glDeleteProgram(upsample_program_);
upsample_program_ = 0;
@ -367,7 +377,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
if (cc->Inputs().Tag("TENSORS_GPU").IsEmpty()) {
return ::mediapipe::OkStatus();
}
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
// Get input streams.
const auto& input_tensors =
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GlBuffer>>();
@ -453,7 +464,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
}
void TfLiteTensorsToSegmentationCalculator::GlRender() {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
static const GLfloat square_vertices[] = {
-1.0f, -1.0f, // bottom left
1.0f, -1.0f, // bottom right
@ -525,7 +537,8 @@ void TfLiteTensorsToSegmentationCalculator::GlRender() {
::mediapipe::Status TfLiteTensorsToSegmentationCalculator::InitGpu(
CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]()
-> ::mediapipe::Status {
// A shader to process a segmentation tensor into an output mask,

View File

@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:private"])
package(default_visibility = ["//visibility:public"])
exports_files(["LICENSE"])
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
proto_library(
name = "annotation_overlay_calculator_proto",
srcs = ["annotation_overlay_calculator.proto"],
@ -72,6 +72,24 @@ proto_library(
],
)
proto_library(
name = "collection_has_min_size_calculator_proto",
srcs = ["collection_has_min_size_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
],
)
proto_library(
name = "association_calculator_proto",
srcs = ["association_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
],
)
mediapipe_cc_proto_library(
name = "annotation_overlay_calculator_cc_proto",
srcs = ["annotation_overlay_calculator.proto"],
@ -141,6 +159,26 @@ mediapipe_cc_proto_library(
],
)
mediapipe_cc_proto_library(
name = "collection_has_min_size_calculator_cc_proto",
srcs = ["collection_has_min_size_calculator.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
],
visibility = ["//mediapipe:__subpackages__"],
deps = [":collection_has_min_size_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "association_calculator_cc_proto",
srcs = ["association_calculator.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
],
visibility = ["//mediapipe:__subpackages__"],
deps = [":association_calculator_proto"],
)
cc_library(
name = "packet_frequency_calculator",
srcs = ["packet_frequency_calculator.cc"],
@ -234,6 +272,7 @@ cc_library(
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:vector",
"//mediapipe/util:annotation_renderer",
"//mediapipe/util:render_data_cc_proto",
] + select({
"//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [
@ -360,6 +399,16 @@ mediapipe_cc_proto_library(
deps = [":landmark_projection_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "landmarks_to_floats_calculator_cc_proto",
srcs = ["landmarks_to_floats_calculator.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
],
visibility = ["//visibility:public"],
deps = [":landmarks_to_floats_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "rect_transformation_calculator_cc_proto",
srcs = ["rect_transformation_calculator.proto"],
@ -372,7 +421,12 @@ mediapipe_cc_proto_library(
cc_library(
name = "detections_to_rects_calculator",
srcs = ["detections_to_rects_calculator.cc"],
srcs = [
"detections_to_rects_calculator.cc",
],
hdrs = [
"detections_to_rects_calculator.h",
],
visibility = ["//visibility:public"],
deps = [
":detections_to_rects_calculator_cc_proto",
@ -454,6 +508,17 @@ proto_library(
],
)
proto_library(
name = "labels_to_render_data_calculator_proto",
srcs = ["labels_to_render_data_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
"//mediapipe/util:color_proto",
"//mediapipe/util:render_data_proto",
],
)
proto_library(
name = "thresholding_calculator_proto",
srcs = ["thresholding_calculator.proto"],
@ -483,6 +548,15 @@ proto_library(
],
)
proto_library(
name = "landmarks_to_floats_calculator_proto",
srcs = ["landmarks_to_floats_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
],
)
proto_library(
name = "rect_transformation_calculator_proto",
srcs = ["rect_transformation_calculator.proto"],
@ -577,6 +651,26 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "labels_to_render_data_calculator",
srcs = ["labels_to_render_data_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":labels_to_render_data_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_options_cc_proto",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:video_stream_header",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"//mediapipe/util:color_cc_proto",
"//mediapipe/util:render_data_cc_proto",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_library(
name = "rect_to_render_data_calculator",
srcs = ["rect_to_render_data_calculator.cc"],
@ -658,6 +752,22 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "landmarks_to_floats_calculator",
srcs = ["landmarks_to_floats_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":landmarks_to_floats_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@eigen_archive//:eigen",
],
alwayslink = 1,
)
cc_test(
name = "detection_letterbox_removal_calculator_test",
srcs = ["detection_letterbox_removal_calculator_test.cc"],
@ -714,6 +824,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":top_k_scores_calculator_cc_proto",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
@ -750,3 +861,125 @@ cc_test(
"//mediapipe/framework/port:status",
],
)
mediapipe_cc_proto_library(
name = "labels_to_render_data_calculator_cc_proto",
srcs = ["labels_to_render_data_calculator.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/util:color_cc_proto",
],
visibility = ["//visibility:public"],
deps = [":labels_to_render_data_calculator_proto"],
)
cc_library(
name = "local_file_contents_calculator",
srcs = ["local_file_contents_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_library(
name = "filter_collection_calculator",
srcs = ["filter_collection_calculator.cc"],
hdrs = ["filter_collection_calculator.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_library(
name = "collection_has_min_size_calculator",
srcs = ["collection_has_min_size_calculator.cc"],
hdrs = ["collection_has_min_size_calculator.h"],
visibility = ["//visibility:public"],
deps = [
":collection_has_min_size_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_library(
name = "association_calculator",
hdrs = ["association_calculator.h"],
visibility = ["//visibility:public"],
deps = [
":association_calculator_cc_proto",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework/port:rectangle",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/memory",
],
alwayslink = 1,
)
cc_library(
name = "association_norm_rect_calculator",
srcs = ["association_norm_rect_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":association_calculator",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:rectangle",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_library(
name = "association_detection_calculator",
srcs = ["association_detection_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":association_calculator",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:location",
"//mediapipe/framework/port:rectangle",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_test(
name = "association_calculator_test",
srcs = ["association_calculator_test.cc"],
deps = [
":association_detection_calculator",
":association_norm_rect_calculator",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework:packet",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:location_data_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
],
)

View File

@ -26,6 +26,7 @@
#include "mediapipe/framework/port/vector.h"
#include "mediapipe/util/annotation_renderer.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"
#if !defined(MEDIAPIPE_DISABLE_GPU)
#include "mediapipe/gpu/gl_calculator_helper.h"
@ -41,6 +42,8 @@ namespace {
constexpr char kInputFrameTag[] = "INPUT_FRAME";
constexpr char kOutputFrameTag[] = "OUTPUT_FRAME";
constexpr char kInputVectorTag[] = "VECTOR";
constexpr char kInputFrameTagGpu[] = "INPUT_FRAME_GPU";
constexpr char kOutputFrameTagGpu[] = "OUTPUT_FRAME_GPU";
@ -65,6 +68,9 @@ constexpr int kAnnotationBackgroundColor[] = {100, 101, 102};
// 2. RenderData proto on variable number of input streams. All the RenderData
// at a particular timestamp is drawn on the image in the order of their
// input streams. No tags required.
// 3. std::vector<RenderData> on variable number of input streams. RenderData
// objects at a particular timestamp are drawn on the image in order of the
// input vector items. These input streams are tagged with "VECTOR".
//
// Output:
// 1. OUTPUT_FRAME or OUTPUT_FRAME_GPU: A rendered ImageFrame (or GpuBuffer).
@ -85,6 +91,8 @@ constexpr int kAnnotationBackgroundColor[] = {100, 101, 102};
// input_stream: "render_data_1"
// input_stream: "render_data_2"
// input_stream: "render_data_3"
// input_stream: "VECTOR:0:render_data_vec_0"
// input_stream: "VECTOR:1:render_data_vec_1"
// output_stream: "OUTPUT_FRAME:decorated_frames"
// options {
// [mediapipe.AnnotationOverlayCalculatorOptions.ext] {
@ -99,6 +107,8 @@ constexpr int kAnnotationBackgroundColor[] = {100, 101, 102};
// input_stream: "render_data_1"
// input_stream: "render_data_2"
// input_stream: "render_data_3"
// input_stream: "VECTOR:0:render_data_vec_0"
// input_stream: "VECTOR:1:render_data_vec_1"
// output_stream: "OUTPUT_FRAME_GPU:decorated_frames"
// options {
// [mediapipe.AnnotationOverlayCalculatorOptions.ext] {
@ -138,9 +148,6 @@ class AnnotationOverlayCalculator : public CalculatorBase {
// Underlying helper renderer library.
std::unique_ptr<AnnotationRenderer> renderer_;
// Number of input streams with render data.
int num_render_streams_;
// Indicates if image frame is available as input.
bool image_frame_available_ = false;
@ -171,25 +178,28 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
return ::mediapipe::InternalError("GPU output must have GPU input.");
}
// Assume all inputs are render streams; adjust below.
int num_render_streams = cc->Inputs().NumEntries();
// Input image to render onto copy of.
#if !defined(MEDIAPIPE_DISABLE_GPU)
if (cc->Inputs().HasTag(kInputFrameTagGpu)) {
cc->Inputs().Tag(kInputFrameTagGpu).Set<mediapipe::GpuBuffer>();
num_render_streams = cc->Inputs().NumEntries() - 1;
use_gpu |= true;
}
#endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Inputs().HasTag(kInputFrameTag)) {
cc->Inputs().Tag(kInputFrameTag).Set<ImageFrame>();
num_render_streams = cc->Inputs().NumEntries() - 1;
}
// Data streams to render.
for (int i = 0; i < num_render_streams; ++i) {
cc->Inputs().Index(i).Set<RenderData>();
for (CollectionItemId id = cc->Inputs().BeginId(); id < cc->Inputs().EndId();
++id) {
auto tag_and_index = cc->Inputs().TagAndIndexFromId(id);
std::string tag = tag_and_index.first;
if (tag == kInputVectorTag) {
cc->Inputs().Get(id).Set<std::vector<RenderData>>();
} else if (tag.empty()) {
// Empty tag defaults to accepting a single object of RenderData type.
cc->Inputs().Get(id).Set<RenderData>();
}
}
// Rendered image.
@ -228,12 +238,10 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
if (cc->Inputs().HasTag(kInputFrameTagGpu) ||
cc->Inputs().HasTag(kInputFrameTag)) {
image_frame_available_ = true;
num_render_streams_ = cc->Inputs().NumEntries() - 1;
} else {
image_frame_available_ = false;
RET_CHECK(options_.has_canvas_width_px());
RET_CHECK(options_.has_canvas_height_px());
num_render_streams_ = cc->Inputs().NumEntries();
}
// Initialize the helper renderer library.
@ -285,12 +293,28 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
renderer_->AdoptImage(image_mat.get());
// Render streams onto render target.
for (int i = 0; i < num_render_streams_; ++i) {
if (cc->Inputs().Index(i).IsEmpty()) {
for (CollectionItemId id = cc->Inputs().BeginId(); id < cc->Inputs().EndId();
++id) {
auto tag_and_index = cc->Inputs().TagAndIndexFromId(id);
std::string tag = tag_and_index.first;
if (!tag.empty() && tag != kInputVectorTag) {
continue;
}
const RenderData& render_data = cc->Inputs().Index(i).Get<RenderData>();
renderer_->RenderDataOnImage(render_data);
if (cc->Inputs().Get(id).IsEmpty()) {
continue;
}
if (tag.empty()) {
// Empty tag defaults to accepting a single object of RenderData type.
const RenderData& render_data = cc->Inputs().Get(id).Get<RenderData>();
renderer_->RenderDataOnImage(render_data);
} else {
RET_CHECK_EQ(kInputVectorTag, tag);
const std::vector<RenderData>& render_data_vec =
cc->Inputs().Get(id).Get<std::vector<RenderData>>();
for (const RenderData& render_data : render_data_vec) {
renderer_->RenderDataOnImage(render_data);
}
}
}
if (use_gpu_) {
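Editor's note: the comments above already show the graph-level config for the new "VECTOR" inputs; the sketch below only illustrates how an upstream calculator could produce a matching std::vector<RenderData> packet, using the Add(ptr, timestamp) pattern seen elsewhere in this diff. The output tag name and the my_render_data variable are hypothetical.
// Sketch only: emitting a std::vector<RenderData> packet from an upstream
// calculator's Process() so it can feed a "VECTOR"-tagged input stream.
auto render_data_vec = absl::make_unique<std::vector<RenderData>>();
render_data_vec->push_back(my_render_data);  // my_render_data: hypothetical RenderData
cc->Outputs()
    .Tag("RENDER_DATA_VECTOR")  // hypothetical output tag of the upstream node
    .Add(render_data_vec.release(), cc->InputTimestamp());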

View File

@ -0,0 +1,259 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_UTIL_ASSOCIATION_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_UTIL_ASSOCIATION_CALCULATOR_H_
#include <memory>
#include <vector>
#include "absl/memory/memory.h"
#include "mediapipe/calculators/util/association_calculator.pb.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/rectangle.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// Computes the overlap similarity based on Intersection over Union (IoU) of
// two rectangles.
inline float OverlapSimilarity(const Rectangle_f& rect1,
const Rectangle_f& rect2) {
if (!rect1.Intersects(rect2)) return 0.0f;
// Compute IoU similarity score.
const float intersection_area = Rectangle_f(rect1).Intersect(rect2).Area();
const float normalization = rect1.Area() + rect2.Area() - intersection_area;
return normalization > 0.0f ? intersection_area / normalization : 0.0f;
}
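// Editor's note, a worked example of the IoU above (illustrative only): for
// Rectangle_f(0, 0, 2, 2) and Rectangle_f(1, 1, 2, 2), using the
// (xmin, ymin, width, height) constructor as in AssociationNormRectCalculator,
// the intersection area is 1 x 1 = 1, the union is 4 + 4 - 1 = 7, and
// OverlapSimilarity returns 1/7, roughly 0.14.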
// AssociationCalculator<T> accepts multiple inputs of vectors of type T that can
// be converted to Rectangle_f. The output is a vector of type T that contains
// elements from the input vectors that don't overlap with each other. When
// two elements overlap, the element that comes in from a later input stream
// is kept in the output. This association operation is useful for multiple
// instance inference pipelines in MediaPipe.
// If an input stream is tagged with "PREV" tag, IDs of overlapping elements
// from "PREV" input stream are propagated to the output. Elements in the "PREV"
// input stream that don't overlap with other elements are not added to the
// output. This stream is designed to take detections from previous timestamp,
// e.g. output of PreviousLoopbackCalculator to provide temporal association.
// See AssociationDetectionCalculator and AssociationNormRectCalculator for
// example uses.
template <typename T>
class AssociationCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
// At most one input stream can be tagged with "PREV".
RET_CHECK_LE(cc->Inputs().NumEntries("PREV"), 1);
if (cc->Inputs().HasTag("PREV")) {
RET_CHECK_GE(cc->Inputs().NumEntries(), 2);
}
for (CollectionItemId id = cc->Inputs().BeginId();
id < cc->Inputs().EndId(); ++id) {
cc->Inputs().Get(id).Set<std::vector<T>>();
}
cc->Outputs().Index(0).Set<std::vector<T>>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
cc->SetOffset(TimestampDiff(0));
has_prev_input_stream_ = cc->Inputs().HasTag("PREV");
if (has_prev_input_stream_) {
prev_input_stream_id_ = cc->Inputs().GetId("PREV", 0);
}
options_ = cc->Options<::mediapipe::AssociationCalculatorOptions>();
CHECK_GE(options_.min_similarity_threshold(), 0);
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
auto get_non_overlapping_elements = GetNonOverlappingElements(cc);
if (!get_non_overlapping_elements.ok()) {
return get_non_overlapping_elements.status();
}
std::list<T> result = get_non_overlapping_elements.ValueOrDie();
if (has_prev_input_stream_ &&
!cc->Inputs().Get(prev_input_stream_id_).IsEmpty()) {
// Processed all regular input streams. Now compare the result list
// elements with those in the PREV input stream, and propagate IDs from
// PREV input stream as appropriate.
const std::vector<T>& prev_input_vec =
cc->Inputs()
.Get(prev_input_stream_id_)
.template Get<std::vector<T>>();
MP_RETURN_IF_ERROR(
PropagateIdsFromPreviousToCurrent(prev_input_vec, &result));
}
auto output = absl::make_unique<std::vector<T>>();
for (auto it = result.begin(); it != result.end(); ++it) {
output->push_back(*it);
}
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
protected:
::mediapipe::AssociationCalculatorOptions options_;
bool has_prev_input_stream_;
CollectionItemId prev_input_stream_id_;
virtual ::mediapipe::StatusOr<Rectangle_f> GetRectangle(const T& input) {
return ::mediapipe::OkStatus();
}
virtual std::pair<bool, int> GetId(const T& input) { return {false, -1}; }
virtual void SetId(T* input, int id) {}
private:
// Get a list of non-overlapping elements from all input streams, with
// increasing order of priority based on input stream index.
mediapipe::StatusOr<std::list<T>> GetNonOverlappingElements(
CalculatorContext* cc) {
std::list<T> result;
// Initialize result with the first non-empty input vector.
CollectionItemId non_empty_id = cc->Inputs().BeginId();
for (CollectionItemId id = cc->Inputs().BeginId();
id < cc->Inputs().EndId(); ++id) {
if (id == prev_input_stream_id_ || cc->Inputs().Get(id).IsEmpty()) {
continue;
}
const std::vector<T>& input_vec =
cc->Inputs().Get(id).Get<std::vector<T>>();
if (!input_vec.empty()) {
non_empty_id = id;
result.push_back(input_vec[0]);
for (int j = 1; j < input_vec.size(); ++j) {
MP_RETURN_IF_ERROR(AddElementToList(input_vec[j], &result));
}
break;
}
}
// Compare remaining input vectors with the non-empty result vector,
// remove lower-priority overlapping elements from the result vector and
// add corresponding higher-priority elements as necessary.
for (CollectionItemId id = non_empty_id + 1; id < cc->Inputs().EndId();
++id) {
if (id == prev_input_stream_id_ || cc->Inputs().Get(id).IsEmpty()) {
continue;
}
const std::vector<T>& input_vec =
cc->Inputs().Get(id).Get<std::vector<T>>();
for (int vi = 0; vi < input_vec.size(); ++vi) {
MP_RETURN_IF_ERROR(AddElementToList(input_vec[vi], &result));
}
}
return result;
}
::mediapipe::Status AddElementToList(T element, std::list<T>* current) {
// Compare this element with elements of the input collection. If this
// element has high overlap with elements of the collection, remove
// those elements from the collection and add this element.
ASSIGN_OR_RETURN(auto cur_rect, GetRectangle(element));
bool change_id = false;
int new_elem_id = -1;
for (auto uit = current->begin(); uit != current->end();) {
ASSIGN_OR_RETURN(auto prev_rect, GetRectangle(*uit));
if (OverlapSimilarity(cur_rect, prev_rect) >
options_.min_similarity_threshold()) {
std::pair<bool, int> prev_id = GetId(*uit);
// If prev_id.first is false, i.e. the element doesn't have an ID,
// change_id and new_elem_id are not updated.
if (prev_id.first) {
change_id = prev_id.first;
new_elem_id = prev_id.second;
}
uit = current->erase(uit);
} else {
++uit;
}
}
if (change_id) {
SetId(&element, new_elem_id);
}
current->push_back(element);
return ::mediapipe::OkStatus();
}
// Compare elements of the current list with elements from the collection
// of elements from the previous input stream, and propagate IDs from the
// previous input stream as appropriate.
::mediapipe::Status PropagateIdsFromPreviousToCurrent(
const std::vector<T>& prev_input_vec, std::list<T>* current) {
for (auto vit = current->begin(); vit != current->end(); ++vit) {
auto get_cur_rectangle = GetRectangle(*vit);
if (!get_cur_rectangle.ok()) {
return get_cur_rectangle.status();
}
const Rectangle_f& cur_rect = get_cur_rectangle.ValueOrDie();
bool change_id = false;
int id_for_vi = -1;
for (int ui = 0; ui < prev_input_vec.size(); ++ui) {
auto get_prev_rectangle = GetRectangle(prev_input_vec[ui]);
if (!get_prev_rectangle.ok()) {
return get_prev_rectangle.status();
}
const Rectangle_f& prev_rect = get_prev_rectangle.ValueOrDie();
if (OverlapSimilarity(cur_rect, prev_rect) >
options_.min_similarity_threshold()) {
std::pair<bool, int> prev_id = GetId(prev_input_vec[ui]);
// If prev_id.first is false, i.e. the element doesn't have an ID,
// change_id and id_for_vi are not updated.
if (prev_id.first) {
change_id = prev_id.first;
id_for_vi = prev_id.second;
}
}
}
if (change_id) {
T element = *vit;
SetId(&element, id_for_vi);
*vit = element;
}
}
return ::mediapipe::OkStatus();
}
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_UTIL_ASSOCIATION_CALCULATOR_H_
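// Editor's note (illustrative sketch, not part of this commit): in a graph, the
// "PREV" input is typically fed by looping the calculator's own output back
// through PreviousLoopbackCalculator. Stream names below are made up and the
// loopback tags (MAIN / LOOP / PREV_LOOP) are assumed:
//
// node {
//   calculator: "PreviousLoopbackCalculator"
//   input_stream: "MAIN:input_video"
//   input_stream: "LOOP:output_detections"
//   input_stream_info: { tag_index: "LOOP" back_edge: true }
//   output_stream: "PREV_LOOP:prev_detections"
// }
// node {
//   calculator: "AssociationDetectionCalculator"
//   input_stream: "PREV:prev_detections"
//   input_stream: "new_detections"
//   output_stream: "output_detections"
//   options {
//     [mediapipe.AssociationCalculatorOptions.ext] {
//       min_similarity_threshold: 0.5
//     }
//   }
// }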

View File

@ -0,0 +1,27 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message AssociationCalculatorOptions {
extend CalculatorOptions {
optional AssociationCalculatorOptions ext = 275124847;
}
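  // Two elements are considered overlapping, and hence merged, when their IoU
  // (see OverlapSimilarity in association_calculator.h) is strictly greater
  // than this threshold. With the default of 1.0, no elements are ever merged.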
optional float min_similarity_threshold = 1 [default = 1.0];
}

View File

@ -0,0 +1,476 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
::mediapipe::Detection DetectionWithRelativeLocationData(double xmin,
double ymin,
double width,
double height) {
::mediapipe::Detection detection;
::mediapipe::LocationData* location_data = detection.mutable_location_data();
location_data->set_format(::mediapipe::LocationData::RELATIVE_BOUNDING_BOX);
location_data->mutable_relative_bounding_box()->set_xmin(xmin);
location_data->mutable_relative_bounding_box()->set_ymin(ymin);
location_data->mutable_relative_bounding_box()->set_width(width);
location_data->mutable_relative_bounding_box()->set_height(height);
return detection;
}
} // namespace
class AssociationDetectionCalculatorTest : public ::testing::Test {
protected:
AssociationDetectionCalculatorTest() {
// 0.4 ================
// | | | |
// 0.3 ===================== | DET2 | |
// | | | DET1 | | | DET4 |
// 0.2 | DET0 | =========== ================
// | | | | | |
// 0.1 =====|=============== |
// | DET3 | | |
// 0.0 ================ |
// | DET5 |
// -0.1 ===========
// 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1 1.2
// Detection det_0.
det_0 = DetectionWithRelativeLocationData(/*xmin=*/0.1, /*ymin=*/0.1,
/*width=*/0.2, /*height=*/0.2);
det_0.set_detection_id(0);
// Detection det_1.
det_1 = DetectionWithRelativeLocationData(/*xmin=*/0.3, /*ymin=*/0.1,
/*width=*/0.2, /*height=*/0.2);
det_1.set_detection_id(1);
// Detection det_2.
det_2 = DetectionWithRelativeLocationData(/*xmin=*/0.9, /*ymin=*/0.2,
/*width=*/0.2, /*height=*/0.2);
det_2.set_detection_id(2);
// Detection det_3.
det_3 = DetectionWithRelativeLocationData(/*xmin=*/0.2, /*ymin=*/0.0,
/*width=*/0.3, /*height=*/0.3);
det_3.set_detection_id(3);
// Detection det_4.
det_4 = DetectionWithRelativeLocationData(/*xmin=*/1.0, /*ymin=*/0.2,
/*width=*/0.2, /*height=*/0.2);
det_4.set_detection_id(4);
// Detection det_5.
det_5 = DetectionWithRelativeLocationData(/*xmin=*/0.3, /*ymin=*/-0.1,
/*width=*/0.3, /*height=*/0.3);
det_5.set_detection_id(5);
}
::mediapipe::Detection det_0, det_1, det_2, det_3, det_4, det_5;
};
TEST_F(AssociationDetectionCalculatorTest, DetectionAssocTest) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "AssociationDetectionCalculator"
input_stream: "input_vec_0"
input_stream: "input_vec_1"
input_stream: "input_vec_2"
output_stream: "output_vec"
options {
[mediapipe.AssociationCalculatorOptions.ext] {
min_similarity_threshold: 0.1
}
}
)"));
// Input Stream 0: det_0, det_1, det_2.
auto input_vec_0 = absl::make_unique<std::vector<::mediapipe::Detection>>();
input_vec_0->push_back(det_0);
input_vec_0->push_back(det_1);
input_vec_0->push_back(det_2);
runner.MutableInputs()->Index(0).packets.push_back(
Adopt(input_vec_0.release()).At(Timestamp(1)));
// Input Stream 1: det_3, det_4.
auto input_vec_1 = absl::make_unique<std::vector<::mediapipe::Detection>>();
input_vec_1->push_back(det_3);
input_vec_1->push_back(det_4);
runner.MutableInputs()->Index(1).packets.push_back(
Adopt(input_vec_1.release()).At(Timestamp(1)));
// Input Stream 2: det_5.
auto input_vec_2 = absl::make_unique<std::vector<::mediapipe::Detection>>();
input_vec_2->push_back(det_5);
runner.MutableInputs()->Index(2).packets.push_back(
Adopt(input_vec_2.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, output.size());
const auto& assoc_rects =
output[0].Get<std::vector<::mediapipe::Detection>>();
// det_3 overlaps with det_0 and det_1, and det_5 overlaps with det_3. Since
// det_5 has the highest priority, we remove the other rects. det_4 overlaps
// with det_2, and det_4 has higher priority, so we keep it. The final output
// therefore contains 2 elements.
EXPECT_EQ(2, assoc_rects.size());
// Outputs are in order of inputs, so det_4 is before det_5 in output vector.
// det_4 overlaps with det_2, so new id for det_4 is 2.
EXPECT_TRUE(assoc_rects[0].has_detection_id());
EXPECT_EQ(2, assoc_rects[0].detection_id());
det_4.set_detection_id(2);
EXPECT_THAT(assoc_rects[0], EqualsProto(det_4));
// det_3 overlaps with det_0, so new id for det_3 is 0.
// det_3 overlaps with det_1, so new id for det_3 is 1.
// det_5 overlaps with det_3, so new id for det_5 is 1.
EXPECT_TRUE(assoc_rects[1].has_detection_id());
EXPECT_EQ(1, assoc_rects[1].detection_id());
det_5.set_detection_id(1);
EXPECT_THAT(assoc_rects[1], EqualsProto(det_5));
}
TEST_F(AssociationDetectionCalculatorTest, DetectionAssocTestWithPrev) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "AssociationDetectionCalculator"
input_stream: "PREV:input_vec_0"
input_stream: "input_vec_1"
output_stream: "output_vec"
options {
[mediapipe.AssociationCalculatorOptions.ext] {
min_similarity_threshold: 0.1
}
}
)"));
// Input Stream 0: det_3, det_4.
auto input_vec_0 = absl::make_unique<std::vector<::mediapipe::Detection>>();
input_vec_0->push_back(det_3);
input_vec_0->push_back(det_4);
CollectionItemId prev_input_stream_id =
runner.MutableInputs()->GetId("PREV", 0);
runner.MutableInputs()
->Get(prev_input_stream_id)
.packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1)));
// Input Stream 1: det_5.
auto input_vec_1 = absl::make_unique<std::vector<::mediapipe::Detection>>();
input_vec_1->push_back(det_5);
CollectionItemId input_stream_id = runner.MutableInputs()->GetId("", 0);
runner.MutableInputs()
->Get(input_stream_id)
.packets.push_back(Adopt(input_vec_1.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, output.size());
const auto& assoc_rects =
output[0].Get<std::vector<::mediapipe::Detection>>();
// det_5 overlaps with det_3 and doesn't overlap with det_4. Since det_4 is
// in the PREV input stream, it doesn't get copied to the output, so the final
// output contains 1 element.
EXPECT_EQ(1, assoc_rects.size());
// det_5 overlaps with det_3, det_3 is in PREV, so new id for det_5 is 3.
EXPECT_TRUE(assoc_rects[0].has_detection_id());
EXPECT_EQ(3, assoc_rects[0].detection_id());
det_5.set_detection_id(3);
EXPECT_THAT(assoc_rects[0], EqualsProto(det_5));
}
TEST_F(AssociationDetectionCalculatorTest, DetectionAssocTestReverse) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "AssociationDetectionCalculator"
input_stream: "input_vec_0"
input_stream: "input_vec_1"
input_stream: "input_vec_2"
output_stream: "output_vec"
options {
[mediapipe.AssociationCalculatorOptions.ext] {
min_similarity_threshold: 0.1
}
}
)"));
// Input Stream 0: det_5.
auto input_vec_0 = absl::make_unique<std::vector<::mediapipe::Detection>>();
input_vec_0->push_back(det_5);
runner.MutableInputs()->Index(0).packets.push_back(
Adopt(input_vec_0.release()).At(Timestamp(1)));
// Input Stream 1: det_3, det_4.
auto input_vec_1 = absl::make_unique<std::vector<::mediapipe::Detection>>();
input_vec_1->push_back(det_3);
input_vec_1->push_back(det_4);
runner.MutableInputs()->Index(1).packets.push_back(
Adopt(input_vec_1.release()).At(Timestamp(1)));
// Input Stream 2: det_0, det_1, det_2.
auto input_vec_2 = absl::make_unique<std::vector<::mediapipe::Detection>>();
input_vec_2->push_back(det_0);
input_vec_2->push_back(det_1);
input_vec_2->push_back(det_2);
runner.MutableInputs()->Index(2).packets.push_back(
Adopt(input_vec_2.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, output.size());
const auto& assoc_rects =
output[0].Get<std::vector<::mediapipe::Detection>>();
// det_3 overlaps with det_5, so det_5 is removed. det_0 overlaps with det_3,
// so det_3 is removed as det_0 has higher priority for keeping. det_2
// overlaps with det_4, so det_4 is removed as det_2 has higher priority for
// keeping. The final output therefore contains 3 elements.
EXPECT_EQ(3, assoc_rects.size());
// Outputs are in same order as inputs.
// det_3 overlaps with det_5, so new id for det_3 is 5.
// det_0 overlaps with det_3, so new id for det_0 is 5.
EXPECT_TRUE(assoc_rects[0].has_detection_id());
EXPECT_EQ(5, assoc_rects[0].detection_id());
det_0.set_detection_id(5);
EXPECT_THAT(assoc_rects[0], EqualsProto(det_0));
// det_1 stays with id 1.
EXPECT_TRUE(assoc_rects[1].has_detection_id());
EXPECT_EQ(1, assoc_rects[1].detection_id());
EXPECT_THAT(assoc_rects[1], EqualsProto(det_1));
// det_2 overlaps with det_4, so new id for det_2 is 4.
EXPECT_TRUE(assoc_rects[2].has_detection_id());
EXPECT_EQ(4, assoc_rects[2].detection_id());
det_2.set_detection_id(4);
EXPECT_THAT(assoc_rects[2], EqualsProto(det_2));
}
class AssociationNormRectCalculatorTest : public ::testing::Test {
protected:
AssociationNormRectCalculatorTest() {
// 0.4 ================
// | | | |
// 0.3 ===================== | NR2 | |
// | | | NR1 | | | NR4 |
// 0.2 | NR0 | =========== ================
// | | | | | |
// 0.1 =====|=============== |
// | NR3 | | |
// 0.0 ================ |
// | NR5 |
// -0.1 ===========
// 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1 1.2
// NormalizedRect nr_0.
nr_0.set_x_center(0.2);
nr_0.set_y_center(0.2);
nr_0.set_width(0.2);
nr_0.set_height(0.2);
// NormalizedRect nr_1.
nr_1.set_x_center(0.4);
nr_1.set_y_center(0.2);
nr_1.set_width(0.2);
nr_1.set_height(0.2);
// NormalizedRect nr_2.
nr_2.set_x_center(1.0);
nr_2.set_y_center(0.3);
nr_2.set_width(0.2);
nr_2.set_height(0.2);
// NormalizedRect nr_3.
nr_3.set_x_center(0.35);
nr_3.set_y_center(0.15);
nr_3.set_width(0.3);
nr_3.set_height(0.3);
// NormalizedRect nr_4.
nr_4.set_x_center(1.1);
nr_4.set_y_center(0.3);
nr_4.set_width(0.2);
nr_4.set_height(0.2);
// NormalizedRect nr_5.
nr_5.set_x_center(0.45);
nr_5.set_y_center(0.05);
nr_5.set_width(0.3);
nr_5.set_height(0.3);
}
::mediapipe::NormalizedRect nr_0, nr_1, nr_2, nr_3, nr_4, nr_5;
};
TEST_F(AssociationNormRectCalculatorTest, NormRectAssocTest) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "AssociationNormRectCalculator"
input_stream: "input_vec_0"
input_stream: "input_vec_1"
input_stream: "input_vec_2"
output_stream: "output_vec"
options {
[mediapipe.AssociationCalculatorOptions.ext] {
min_similarity_threshold: 0.1
}
}
)"));
// Input Stream 0: nr_0, nr_1, nr_2.
auto input_vec_0 =
absl::make_unique<std::vector<::mediapipe::NormalizedRect>>();
input_vec_0->push_back(nr_0);
input_vec_0->push_back(nr_1);
input_vec_0->push_back(nr_2);
runner.MutableInputs()->Index(0).packets.push_back(
Adopt(input_vec_0.release()).At(Timestamp(1)));
// Input Stream 1: nr_3, nr_4.
auto input_vec_1 =
absl::make_unique<std::vector<::mediapipe::NormalizedRect>>();
input_vec_1->push_back(nr_3);
input_vec_1->push_back(nr_4);
runner.MutableInputs()->Index(1).packets.push_back(
Adopt(input_vec_1.release()).At(Timestamp(1)));
// Input Stream 2: nr_5.
auto input_vec_2 =
absl::make_unique<std::vector<::mediapipe::NormalizedRect>>();
input_vec_2->push_back(nr_5);
runner.MutableInputs()->Index(2).packets.push_back(
Adopt(input_vec_2.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, output.size());
const auto& assoc_rects =
output[0].Get<std::vector<::mediapipe::NormalizedRect>>();
// nr_3 overlaps with nr_0 and nr_1, and nr_5 overlaps with nr_3. Since nr_5
// has the highest priority, we remove the other rects.
// nr_4 overlaps with nr_2, and nr_4 has higher priority, so we keep it.
// The final output therefore contains 2 elements.
EXPECT_EQ(2, assoc_rects.size());
// Outputs are in order of inputs, so nr_4 is before nr_5 in output vector.
EXPECT_THAT(assoc_rects[0], EqualsProto(nr_4));
EXPECT_THAT(assoc_rects[1], EqualsProto(nr_5));
}
TEST_F(AssociationNormRectCalculatorTest, NormRectAssocTestReverse) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "AssociationNormRectCalculator"
input_stream: "input_vec_0"
input_stream: "input_vec_1"
input_stream: "input_vec_2"
output_stream: "output_vec"
options {
[mediapipe.AssociationCalculatorOptions.ext] {
min_similarity_threshold: 0.1
}
}
)"));
// Input Stream 0: nr_5.
auto input_vec_0 =
absl::make_unique<std::vector<::mediapipe::NormalizedRect>>();
input_vec_0->push_back(nr_5);
runner.MutableInputs()->Index(0).packets.push_back(
Adopt(input_vec_0.release()).At(Timestamp(1)));
// Input Stream 1: nr_3, nr_4.
auto input_vec_1 =
absl::make_unique<std::vector<::mediapipe::NormalizedRect>>();
input_vec_1->push_back(nr_3);
input_vec_1->push_back(nr_4);
runner.MutableInputs()->Index(1).packets.push_back(
Adopt(input_vec_1.release()).At(Timestamp(1)));
// Input Stream 2: nr_0, nr_1, nr_2.
auto input_vec_2 =
absl::make_unique<std::vector<::mediapipe::NormalizedRect>>();
input_vec_2->push_back(nr_0);
input_vec_2->push_back(nr_1);
input_vec_2->push_back(nr_2);
runner.MutableInputs()->Index(2).packets.push_back(
Adopt(input_vec_2.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, output.size());
const auto& assoc_rects =
output[0].Get<std::vector<::mediapipe::NormalizedRect>>();
// nr_3 overlaps with nr_5, so nr_5 is removed. nr_0 overlaps with nr_3, so
// nr_3 is removed as nr_0 has higher priority for keeping. nr_2 overlaps
// with nr_4, so nr_4 is removed as nr_2 has higher priority for keeping.
// The final output therefore contains 3 elements.
EXPECT_EQ(3, assoc_rects.size());
// Outputs are in same order as inputs.
EXPECT_THAT(assoc_rects[0], EqualsProto(nr_0));
EXPECT_THAT(assoc_rects[1], EqualsProto(nr_1));
EXPECT_THAT(assoc_rects[2], EqualsProto(nr_2));
}
TEST_F(AssociationNormRectCalculatorTest, NormRectAssocSingleInputStream) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "AssociationNormRectCalculator"
input_stream: "input_vec"
output_stream: "output_vec"
options {
[mediapipe.AssociationCalculatorOptions.ext] {
min_similarity_threshold: 0.1
}
}
)"));
// Input Stream : nr_3, nr_5.
auto input_vec =
absl::make_unique<std::vector<::mediapipe::NormalizedRect>>();
input_vec->push_back(nr_3);
input_vec->push_back(nr_5);
runner.MutableInputs()->Index(0).packets.push_back(
Adopt(input_vec.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, output.size());
const auto& assoc_rects =
output[0].Get<std::vector<::mediapipe::NormalizedRect>>();
// nr_5 overlaps with nr_3. Since nr_5 comes after nr_3 in the same input
// stream, we remove nr_3 and keep nr_5.
// The final output therefore contains 1 element.
EXPECT_EQ(1, assoc_rects.size());
EXPECT_THAT(assoc_rects[0], EqualsProto(nr_5));
}
} // namespace mediapipe

View File

@ -0,0 +1,77 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/util/association_calculator.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/port/rectangle.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// A subclass of AssociationCalculator<T> for Detection. Example:
// node {
// calculator: "AssociationDetectionCalculator"
// input_stream: "PREV:input_vec_0"
// input_stream: "input_vec_1"
// input_stream: "input_vec_2"
// output_stream: "output_vec"
// options {
// [mediapipe.AssociationCalculatorOptions.ext] {
// min_similarity_threshold: 0.1
// }
// }
class AssociationDetectionCalculator
: public AssociationCalculator<::mediapipe::Detection> {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
return AssociationCalculator<::mediapipe::Detection>::GetContract(cc);
}
::mediapipe::Status Open(CalculatorContext* cc) override {
return AssociationCalculator<::mediapipe::Detection>::Open(cc);
}
::mediapipe::Status Process(CalculatorContext* cc) override {
return AssociationCalculator<::mediapipe::Detection>::Process(cc);
}
::mediapipe::Status Close(CalculatorContext* cc) override {
return AssociationCalculator<::mediapipe::Detection>::Close(cc);
}
protected:
::mediapipe::StatusOr<Rectangle_f> GetRectangle(
const ::mediapipe::Detection& input) override {
if (!input.has_location_data()) {
return ::mediapipe::InternalError("Missing location_data in Detection");
}
const Location location(input.location_data());
return location.GetRelativeBBox();
}
std::pair<bool, int> GetId(const ::mediapipe::Detection& input) override {
return {input.has_detection_id(), input.detection_id()};
}
void SetId(::mediapipe::Detection* input, int id) override {
input->set_detection_id(id);
}
};
REGISTER_CALCULATOR(AssociationDetectionCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,72 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/util/association_calculator.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/rectangle.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// A subclass of AssociationCalculator<T> for NormalizedRect. Example use case:
// node {
// calculator: "AssociationNormRectCalculator"
// input_stream: "input_vec_0"
// input_stream: "input_vec_1"
// input_stream: "input_vec_2"
// output_stream: "output_vec"
// options {
// [mediapipe.AssociationCalculatorOptions.ext] {
// min_similarity_threshold: 0.1
// }
// }
class AssociationNormRectCalculator
: public AssociationCalculator<::mediapipe::NormalizedRect> {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
return AssociationCalculator<::mediapipe::NormalizedRect>::GetContract(cc);
}
::mediapipe::Status Open(CalculatorContext* cc) override {
return AssociationCalculator<::mediapipe::NormalizedRect>::Open(cc);
}
::mediapipe::Status Process(CalculatorContext* cc) override {
return AssociationCalculator<::mediapipe::NormalizedRect>::Process(cc);
}
::mediapipe::Status Close(CalculatorContext* cc) override {
return AssociationCalculator<::mediapipe::NormalizedRect>::Close(cc);
}
protected:
::mediapipe::StatusOr<Rectangle_f> GetRectangle(
const ::mediapipe::NormalizedRect& input) override {
if (!input.has_x_center() || !input.has_y_center() || !input.has_width() ||
!input.has_height()) {
return ::mediapipe::InternalError(
"Missing dimensions in NormalizedRect.");
}
const float xmin = input.x_center() - input.width() / 2.0;
const float ymin = input.y_center() - input.height() / 2.0;
// TODO: Support rotation for rectangle.
return Rectangle_f(xmin, ymin, input.width(), input.height());
}
};
REGISTER_CALCULATOR(AssociationNormRectCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,26 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/util/collection_has_min_size_calculator.h"
#include "mediapipe/framework/formats/rect.pb.h"
namespace mediapipe {
typedef CollectionHasMinSizeCalculator<std::vector<::mediapipe::NormalizedRect>>
NormalizedRectVectorHasMinSizeCalculator;
REGISTER_CALCULATOR(NormalizedRectVectorHasMinSizeCalculator);
} // namespace mediapipe
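Editor's note: this .cc only instantiates the calculator for vectors of NormalizedRect. As a sketch, other collection types follow the same two-line pattern; the Detection typedef below is hypothetical and not part of this commit.
#include "mediapipe/calculators/util/collection_has_min_size_calculator.h"
#include "mediapipe/framework/formats/detection.pb.h"
namespace mediapipe {
typedef CollectionHasMinSizeCalculator<std::vector<::mediapipe::Detection>>
    DetectionVectorHasMinSizeCalculator;
REGISTER_CALCULATOR(DetectionVectorHasMinSizeCalculator);
}  // namespace mediapipe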

View File

@ -0,0 +1,84 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_UTIL_COLLECTION_HAS_MIN_SIZE_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_UTIL_COLLECTION_HAS_MIN_SIZE_CALCULATOR_H_
#include <vector>
#include "mediapipe/calculators/util/collection_has_min_size_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// Determines if an input iterable collection has a minimum size, specified
// in CollectionHasMinSizeCalculatorOptions. Example usage:
// node {
// calculator: "IntVectorHasMinSizeCalculator"
// input_stream: "ITERABLE:input_int_vector"
// output_stream: "has_min_ints"
// options {
// [mediapipe.CollectionHasMinSizeCalculatorOptions.ext] {
// min_size: 2
// }
// }
// }
template <typename IterableT>
class CollectionHasMinSizeCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag("ITERABLE"));
RET_CHECK_EQ(1, cc->Inputs().NumEntries());
RET_CHECK_EQ(1, cc->Outputs().NumEntries());
RET_CHECK_GE(
cc->Options<::mediapipe::CollectionHasMinSizeCalculatorOptions>()
.min_size(),
0);
cc->Inputs().Tag("ITERABLE").Set<IterableT>();
cc->Outputs().Index(0).Set<bool>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
cc->SetOffset(TimestampDiff(0));
min_size_ =
cc->Options<::mediapipe::CollectionHasMinSizeCalculatorOptions>()
.min_size();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
const IterableT& input = cc->Inputs().Tag("ITERABLE").Get<IterableT>();
bool has_min_size = input.size() >= min_size_;
cc->Outputs().Index(0).AddPacket(
MakePacket<bool>(has_min_size).At(cc->InputTimestamp()));
return ::mediapipe::OkStatus();
}
private:
int min_size_ = 0;
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_UTIL_COLLECTION_HAS_MIN_SIZE_CALCULATOR_H_
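Editor's note: a minimal usage sketch, reusing the CalculatorRunner / ParseTextProtoOrDie test scaffolding seen elsewhere in this commit; the stream names and option value are made up.
// Sketch only: gate on "at least two rects present" using the
// NormalizedRectVectorHasMinSizeCalculator registered earlier in this diff.
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
  calculator: "NormalizedRectVectorHasMinSizeCalculator"
  input_stream: "ITERABLE:hand_rects"
  output_stream: "has_enough_hands"
  options {
    [mediapipe.CollectionHasMinSizeCalculatorOptions.ext] { min_size: 2 }
  }
)"));
auto rects = absl::make_unique<std::vector<::mediapipe::NormalizedRect>>(1);
runner.MutableInputs()->Tag("ITERABLE").packets.push_back(
    Adopt(rects.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run());
// With one rect and min_size = 2, the single bool output packet is false.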

View File

@ -0,0 +1,29 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message CollectionHasMinSizeCalculatorOptions {
extend CalculatorOptions {
optional CollectionHasMinSizeCalculatorOptions ext = 259397840;
}
// The minimum size an input iterable collection should have for the
// calculator to output true.
optional int32 min_size = 1 [default = 0];
}

View File

@ -19,8 +19,8 @@
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/resource_util.h"
#if defined(MEDIAPIPE_LITE) || defined(__ANDROID__) || \
(defined(__APPLE__) && !TARGET_OS_OSX)
#if defined(MEDIAPIPE_LITE) || defined(__EMSCRIPTEN__) || \
defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
#include "mediapipe/util/android/file/base/file.h"
#include "mediapipe/util/android/file/base/helpers.h"
#else

View File

@ -11,6 +11,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/util/detections_to_rects_calculator.h"
#include <cmath>
#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h"
@ -24,8 +26,6 @@
namespace mediapipe {
using mediapipe::DetectionsToRectsCalculatorOptions;
namespace {
constexpr char kDetectionTag[] = "DETECTION";
@ -36,7 +36,10 @@ constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kRectsTag[] = "RECTS";
constexpr char kNormRectsTag[] = "NORM_RECTS";
::mediapipe::Status DetectionToRect(const Detection& detection, Rect* rect) {
} // namespace
::mediapipe::Status DetectionsToRectsCalculator::DetectionToRect(
const Detection& detection, Rect* rect) {
const LocationData location_data = detection.location_data();
RET_CHECK(location_data.format() == LocationData::BOUNDING_BOX)
<< "Only Detection with formats of BOUNDING_BOX can be converted to Rect";
@ -48,8 +51,8 @@ constexpr char kNormRectsTag[] = "NORM_RECTS";
return ::mediapipe::OkStatus();
}
::mediapipe::Status DetectionToNormalizedRect(const Detection& detection,
NormalizedRect* rect) {
::mediapipe::Status DetectionsToRectsCalculator::DetectionToNormalizedRect(
const Detection& detection, NormalizedRect* rect) {
const LocationData location_data = detection.location_data();
RET_CHECK(location_data.format() == LocationData::RELATIVE_BOUNDING_BOX)
<< "Only Detection with formats of RELATIVE_BOUNDING_BOX can be "
@ -63,79 +66,6 @@ constexpr char kNormRectsTag[] = "NORM_RECTS";
return ::mediapipe::OkStatus();
}
// Wraps around an angle in radians to within -M_PI and M_PI.
inline float NormalizeRadians(float angle) {
return angle - 2 * M_PI * std::floor((angle - (-M_PI)) / (2 * M_PI));
}
} // namespace
// A calculator that converts Detection proto to Rect proto.
//
// Detection is the format for encoding one or more detections in an image.
// The input can be a single Detection or std::vector<Detection>. The output can
// be either a single Rect or NormalizedRect, or std::vector<Rect> or
// std::vector<NormalizedRect>. If Rect is used, the LocationData format is
// expected to be BOUNDING_BOX, and if NormalizedRect is used it is expected to
// be RELATIVE_BOUNDING_BOX.
//
// When the input is std::vector<Detection> and the output is a Rect or
// NormalizedRect, only the first detection is converted. When the input is a
// single Detection and the output is a std::vector<Rect> or
// std::vector<NormalizedRect>, the output is a vector of size 1.
//
// Inputs:
//
// One of the following:
// DETECTION: A Detection proto.
// DETECTIONS: An std::vector<Detection>.
//
//   IMAGE_SIZE (optional): A std::pair<int, int> representing image width and
// height. This is required only when rotation needs to be computed (see
// calculator options).
//
// Output:
// One of the following:
// RECT: A Rect proto.
// NORM_RECT: A NormalizedRect proto.
// RECTS: An std::vector<Rect>.
// NORM_RECTS: An std::vector<NormalizedRect>.
//
// Example config:
// node {
// calculator: "DetectionsToRectsCalculator"
// input_stream: "DETECTIONS:detections"
// input_stream: "IMAGE_SIZE:image_size"
// output_stream: "NORM_RECT:rect"
// options: {
// [mediapipe.DetectionsToRectCalculatorOptions.ext] {
// rotation_vector_start_keypoint_index: 0
// rotation_vector_end_keypoint_index: 2
// rotation_vector_target_angle_degrees: 90
// output_zero_rect_for_empty_detections: true
// }
// }
// }
class DetectionsToRectsCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
private:
float ComputeRotation(const Detection& detection,
const std::pair<int, int> image_size);
DetectionsToRectsCalculatorOptions options_;
int start_keypoint_index_;
int end_keypoint_index_;
float target_angle_; // In radians.
bool rotate_;
bool output_zero_rect_for_empty_detections_;
};
REGISTER_CALCULATOR(DetectionsToRectsCalculator);
::mediapipe::Status DetectionsToRectsCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag(kDetectionTag) ^
@ -232,6 +162,13 @@ REGISTER_CALCULATOR(DetectionsToRectsCalculator);
.Tag(kNormRectTag)
.AddPacket(MakePacket<NormalizedRect>().At(cc->InputTimestamp()));
}
if (cc->Outputs().HasTag(kNormRectsTag)) {
auto rect_vector = absl::make_unique<std::vector<NormalizedRect>>();
rect_vector->emplace_back(NormalizedRect());
cc->Outputs()
.Tag(kNormRectsTag)
.Add(rect_vector.release(), cc->InputTimestamp());
}
}
return ::mediapipe::OkStatus();
}
@ -312,4 +249,6 @@ float DetectionsToRectsCalculator::ComputeRotation(
return NormalizeRadians(rotation);
}
REGISTER_CALCULATOR(DetectionsToRectsCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,105 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_UTIL_DETECTIONS_TO_RECTS_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_UTIL_DETECTIONS_TO_RECTS_CALCULATOR_H_
#include <cmath>
#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_options.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// A calculator that converts Detection proto to Rect proto.
//
// Detection is the format for encoding one or more detections in an image.
// The input can be a single Detection or std::vector<Detection>. The output can
// be either a single Rect or NormalizedRect, or std::vector<Rect> or
// std::vector<NormalizedRect>. If Rect is used, the LocationData format is
// expected to be BOUNDING_BOX, and if NormalizedRect is used it is expected to
// be RELATIVE_BOUNDING_BOX.
//
// When the input is std::vector<Detection> and the output is a Rect or
// NormalizedRect, only the first detection is converted. When the input is a
// single Detection and the output is a std::vector<Rect> or
// std::vector<NormalizedRect>, the output is a vector of size 1.
//
// Inputs:
//
// One of the following:
// DETECTION: A Detection proto.
// DETECTIONS: An std::vector<Detection>.
//
//   IMAGE_SIZE (optional): A std::pair<int, int> representing image width and
// height. This is required only when rotation needs to be computed (see
// calculator options).
//
// Output:
// One of the following:
// RECT: A Rect proto.
// NORM_RECT: A NormalizedRect proto.
// RECTS: An std::vector<Rect>.
// NORM_RECTS: An std::vector<NormalizedRect>.
//
// Example config:
// node {
// calculator: "DetectionsToRectsCalculator"
// input_stream: "DETECTIONS:detections"
// input_stream: "IMAGE_SIZE:image_size"
// output_stream: "NORM_RECT:rect"
// options: {
// [mediapipe.DetectionsToRectCalculatorOptions.ext] {
// rotation_vector_start_keypoint_index: 0
// rotation_vector_end_keypoint_index: 2
// rotation_vector_target_angle_degrees: 90
// output_zero_rect_for_empty_detections: true
// }
// }
// }
class DetectionsToRectsCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
protected:
virtual float ComputeRotation(const ::mediapipe::Detection& detection,
const std::pair<int, int> image_size);
virtual ::mediapipe::Status DetectionToRect(
const ::mediapipe::Detection& detection, ::mediapipe::Rect* rect);
virtual ::mediapipe::Status DetectionToNormalizedRect(
const ::mediapipe::Detection& detection,
::mediapipe::NormalizedRect* rect);
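  // Wraps around an angle in radians to within -M_PI and M_PI.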
static inline float NormalizeRadians(float angle) {
return angle - 2 * M_PI * std::floor((angle - (-M_PI)) / (2 * M_PI));
}
::mediapipe::DetectionsToRectsCalculatorOptions options_;
int start_keypoint_index_;
int end_keypoint_index_;
float target_angle_ = 0.0f; // In radians.
bool rotate_;
bool output_zero_rect_for_empty_detections_;
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_UTIL_DETECTIONS_TO_RECTS_CALCULATOR_H_
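Editor's note: the refactor above moves the conversion helpers into protected virtual methods; the sketch below (class name and margin factor are hypothetical, not part of this commit) shows the kind of subclass this enables.
// Sketch only: reuse DetectionsToRectsCalculator but expand each output rect.
class ExpandedDetectionsToRectsCalculator : public DetectionsToRectsCalculator {
 protected:
  ::mediapipe::Status DetectionToNormalizedRect(
      const ::mediapipe::Detection& detection,
      ::mediapipe::NormalizedRect* rect) override {
    MP_RETURN_IF_ERROR(
        DetectionsToRectsCalculator::DetectionToNormalizedRect(detection, rect));
    // Add a margin around the detection's bounding box.
    rect->set_width(rect->width() * 1.5f);
    rect->set_height(rect->height() * 1.5f);
    return ::mediapipe::OkStatus();
  }
};
REGISTER_CALCULATOR(ExpandedDetectionsToRectsCalculator);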

View File

@ -0,0 +1,34 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/util/filter_collection_calculator.h"
#include <vector>
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
namespace mediapipe {
typedef FilterCollectionCalculator<std::vector<::mediapipe::NormalizedRect>>
FilterNormalizedRectCollectionCalculator;
REGISTER_CALCULATOR(FilterNormalizedRectCollectionCalculator);
typedef FilterCollectionCalculator<
std::vector<std::vector<::mediapipe::NormalizedLandmark>>>
FilterLandmarksCollectionCalculator;
REGISTER_CALCULATOR(FilterLandmarksCollectionCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,109 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_UTIL_FILTER_VECTOR_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_UTIL_FILTER_VECTOR_CALCULATOR_H_
#include <vector>
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// A calculator that gates elements of an input collection based on
// corresponding boolean values of the "CONDITION" vector. If there is no input
// collection or "CONDITION" vector, the calculator forwards timestamp bounds
// for downstream calculators. If the "CONDITION" vector has false values for
// all elements of the input collection, the calculator outputs a packet
// containing an empty collection.
// Example usage:
// node {
// calculator: "FilterCollectionCalculator"
// input_stream: "ITERABLE:input_collection"
// input_stream: "CONDITION:condition_vector"
// output_stream: "ITERABLE:output_collection"
// }
// This calculator is able to handle collections of copyable types T.
template <typename IterableT>
class FilterCollectionCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag("ITERABLE"));
RET_CHECK(cc->Inputs().HasTag("CONDITION"));
RET_CHECK(cc->Outputs().HasTag("ITERABLE"));
cc->Inputs().Tag("ITERABLE").Set<IterableT>();
cc->Inputs().Tag("CONDITION").Set<std::vector<bool>>();
cc->Outputs().Tag("ITERABLE").Set<IterableT>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
cc->SetOffset(TimestampDiff(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
if (cc->Inputs().Tag("ITERABLE").IsEmpty()) {
return ::mediapipe::OkStatus();
}
if (cc->Inputs().Tag("CONDITION").IsEmpty()) {
return ::mediapipe::OkStatus();
}
const std::vector<bool>& filter_by =
cc->Inputs().Tag("CONDITION").Get<std::vector<bool>>();
return FilterCollection<IterableT>(
std::is_copy_constructible<typename IterableT::value_type>(), cc,
filter_by);
}
template <typename IterableU>
::mediapipe::Status FilterCollection(std::true_type, CalculatorContext* cc,
const std::vector<bool>& filter_by) {
const IterableU& input = cc->Inputs().Tag("ITERABLE").Get<IterableU>();
if (input.size() != filter_by.size()) {
return ::mediapipe::InternalError(absl::StrCat(
"Input vector size: ", input.size(),
" doesn't mach condition vector size: ", filter_by.size()));
}
auto output = absl::make_unique<IterableU>();
for (int i = 0; i < input.size(); ++i) {
if (filter_by[i]) {
output->push_back(input[i]);
}
}
cc->Outputs().Tag("ITERABLE").Add(output.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
template <typename IterableU>
::mediapipe::Status FilterCollection(std::false_type, CalculatorContext* cc,
const std::vector<bool>& filter_by) {
return ::mediapipe::InternalError(
"Cannot copy input collection to filter it.");
}
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_UTIL_FILTER_VECTOR_CALCULATOR_H_
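Editor's note: the diff includes no test for this calculator, so here is a hedged usage sketch borrowing the CalculatorRunner pattern from the association tests; the stream names are made up.
// Sketch only: keep the rects whose corresponding CONDITION entry is true.
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
  calculator: "FilterNormalizedRectCollectionCalculator"
  input_stream: "ITERABLE:hand_rects"
  input_stream: "CONDITION:keep_flags"
  output_stream: "ITERABLE:filtered_hand_rects"
)"));
auto rects = absl::make_unique<std::vector<::mediapipe::NormalizedRect>>(2);
auto keep = absl::make_unique<std::vector<bool>>();
keep->push_back(true);
keep->push_back(false);
runner.MutableInputs()->Tag("ITERABLE").packets.push_back(
    Adopt(rects.release()).At(Timestamp(1)));
runner.MutableInputs()->Tag("CONDITION").packets.push_back(
    Adopt(keep.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run());
// The single output ITERABLE packet now holds one rect (the first one).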

View File

@ -0,0 +1,181 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h>
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/util/labels_to_render_data_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/video_stream_header.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"
namespace mediapipe {
constexpr float kFontHeightScale = 1.25f;
// A calculator that takes in pairs of labels and scores, or classifications, and
// generates render data. Either both "LABELS" and "SCORES" or "CLASSIFICATIONS"
// must be present.
//
// Usage example:
// node {
// calculator: "LabelsToRenderDataCalculator"
// input_stream: "LABELS:labels"
// input_stream: "SCORES:scores"
// output_stream: "VIDEO_PRESTREAM:video_header"
// options {
// [LabelsToRenderDataCalculatorOptions.ext] {
// color { r: 255 g: 0 b: 0 }
// color { r: 0 g: 255 b: 0 }
// color { r: 0 g: 0 b: 255 }
// thickness: 2.0
// font_height_px: 20
// max_num_labels: 3
// font_face: 1
// location: TOP_LEFT
// }
// }
// }
class LabelsToRenderDataCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
private:
LabelsToRenderDataCalculatorOptions options_;
int num_colors_ = 0;
int video_width_ = 0;
int video_height_ = 0;
int label_height_px_ = 0;
int label_left_px_ = 0;
};
REGISTER_CALCULATOR(LabelsToRenderDataCalculator);
::mediapipe::Status LabelsToRenderDataCalculator::GetContract(
CalculatorContract* cc) {
if (cc->Inputs().HasTag("CLASSIFICATIONS")) {
cc->Inputs().Tag("CLASSIFICATIONS").Set<ClassificationList>();
} else {
RET_CHECK(cc->Inputs().HasTag("LABELS"))
<< "Must provide input stream \"LABELS\"";
cc->Inputs().Tag("LABELS").Set<std::vector<std::string>>();
if (cc->Inputs().HasTag("SCORES")) {
cc->Inputs().Tag("SCORES").Set<std::vector<float>>();
}
}
if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) {
cc->Inputs().Tag("VIDEO_PRESTREAM").Set<VideoHeader>();
}
cc->Outputs().Tag("RENDER_DATA").Set<RenderData>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status LabelsToRenderDataCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<LabelsToRenderDataCalculatorOptions>();
num_colors_ = options_.color_size();
label_height_px_ = std::ceil(options_.font_height_px() * kFontHeightScale);
return ::mediapipe::OkStatus();
}
::mediapipe::Status LabelsToRenderDataCalculator::Process(
CalculatorContext* cc) {
if (cc->Inputs().HasTag("VIDEO_PRESTREAM") &&
cc->InputTimestamp() == Timestamp::PreStream()) {
const VideoHeader& video_header =
cc->Inputs().Tag("VIDEO_PRESTREAM").Get<VideoHeader>();
video_width_ = video_header.width;
video_height_ = video_header.height;
return ::mediapipe::OkStatus();
} else {
CHECK_EQ(options_.location(), LabelsToRenderDataCalculatorOptions::TOP_LEFT)
<< "Only TOP_LEFT is supported without VIDEO_PRESTREAM.";
}
std::vector<std::string> labels;
std::vector<float> scores;
if (cc->Inputs().HasTag("CLASSIFICATIONS")) {
const ClassificationList& classifications =
cc->Inputs().Tag("CLASSIFICATIONS").Get<ClassificationList>();
labels.resize(classifications.classification_size());
scores.resize(classifications.classification_size());
for (int i = 0; i < classifications.classification_size(); ++i) {
labels[i] = classifications.classification(i).label();
scores[i] = classifications.classification(i).score();
}
} else {
const std::vector<std::string>& label_vector =
cc->Inputs().Tag("LABELS").Get<std::vector<std::string>>();
std::vector<float> score_vector;
if (cc->Inputs().HasTag("SCORES")) {
score_vector = cc->Inputs().Tag("SCORES").Get<std::vector<float>>();
}
if (cc->Inputs().HasTag("SCORES")) {
CHECK_EQ(label_vector.size(), score_vector.size());
}
labels.resize(label_vector.size());
scores.resize(label_vector.size());
for (int i = 0; i < label_vector.size(); ++i) {
labels[i] = label_vector[i];
if (cc->Inputs().HasTag("SCORES")) {
scores[i] = score_vector[i];
}
}
}
RenderData render_data;
int num_label = std::min((int)labels.size(), options_.max_num_labels());
int label_baseline_px = options_.vertical_offset_px();
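// For TOP_LEFT, the first baseline sits one (scaled) label height below the
// vertical offset and subsequent labels stack downward; for BOTTOM_LEFT, the
// baselines are anchored to the bottom of the frame, which is why the video
// height from VIDEO_PRESTREAM is required for that mode.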
if (options_.location() == LabelsToRenderDataCalculatorOptions::TOP_LEFT) {
label_baseline_px += label_height_px_;
} else if (options_.location() ==
LabelsToRenderDataCalculatorOptions::BOTTOM_LEFT) {
label_baseline_px += video_height_ - label_height_px_ * (num_label - 1);
}
label_left_px_ = options_.horizontal_offset_px();
for (int i = 0; i < num_label; ++i) {
auto* label_annotation = render_data.add_render_annotations();
label_annotation->set_thickness(options_.thickness());
if (num_colors_ > 0) {
*(label_annotation->mutable_color()) = options_.color(i % num_colors_);
} else {
label_annotation->mutable_color()->set_r(255);
label_annotation->mutable_color()->set_g(0);
label_annotation->mutable_color()->set_b(0);
}
auto* text = label_annotation->mutable_text();
std::string display_text = labels[i];
if (cc->Inputs().HasTag("SCORES")) {
absl::StrAppend(&display_text, ":", scores[i]);
}
text->set_display_text(display_text);
text->set_font_height(options_.font_height_px());
text->set_left(label_left_px_);
text->set_baseline(label_baseline_px + i * label_height_px_);
text->set_font_face(options_.font_face());
}
cc->Outputs()
.Tag("RENDER_DATA")
.AddPacket(MakePacket<RenderData>(render_data).At(cc->InputTimestamp()));
return ::mediapipe::OkStatus();
}
} // namespace mediapipe

View File

@ -0,0 +1,62 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
import "mediapipe/util/color.proto";
message LabelsToRenderDataCalculatorOptions {
extend CalculatorOptions {
optional LabelsToRenderDataCalculatorOptions ext = 271660364;
}
// Colors for drawing the label(s).
repeated Color color = 1;
// Thickness for drawing the label(s).
optional double thickness = 2 [default = 2];
// The font height in absolute pixels.
optional int32 font_height_px = 3 [default = 50];
// The offset of the starting text in horizontal direction in absolute pixels.
optional int32 horizontal_offset_px = 7 [default = 0];
// The offset of the starting text in vertical direction in absolute pixels.
optional int32 vertical_offset_px = 8 [default = 0];
// The maximum number of labels to display.
optional int32 max_num_labels = 4 [default = 1];
// Specifies the font for the text. Font must be one of the following from
// OpenCV:
// cv::FONT_HERSHEY_SIMPLEX (0)
// cv::FONT_HERSHEY_PLAIN (1)
// cv::FONT_HERSHEY_DUPLEX (2)
// cv::FONT_HERSHEY_COMPLEX (3)
// cv::FONT_HERSHEY_TRIPLEX (4)
// cv::FONT_HERSHEY_COMPLEX_SMALL (5)
// cv::FONT_HERSHEY_SCRIPT_SIMPLEX (6)
// cv::FONT_HERSHEY_SCRIPT_COMPLEX (7)
optional int32 font_face = 5 [default = 0];
// Label location.
enum Location {
TOP_LEFT = 0;
BOTTOM_LEFT = 1;
}
optional Location location = 6 [default = TOP_LEFT];
}

View File

@ -0,0 +1,138 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cmath>
#include <vector>
#include "Eigen/Core"
#include "mediapipe/calculators/util/landmarks_to_floats_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/ret_check.h"
namespace mediapipe {
namespace {
constexpr char kLandmarksTag[] = "NORM_LANDMARKS";
constexpr char kFloatsTag[] = "FLOATS";
constexpr char kMatrixTag[] = "MATRIX";
} // namespace
// Converts a vector of landmarks to a vector of floats or a matrix.
// Input:
// NORM_LANDMARKS: An std::vector<NormalizedLandmark>.
//
// Output:
// FLOATS(optional): A vector of floats from flattened landmarks.
// MATRIX(optional): A matrix of floats of the landmarks.
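// FLOATS are flattened per landmark as [x0, y0, z0, x1, y1, z1, ...],
// keeping only the first num_dimensions coordinates of each landmark.
// MATRIX has shape num_dimensions x num_landmarks, one column per landmark.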
//
// Usage example:
// node {
// calculator: "LandmarksToFloatsCalculator"
// input_stream: "NORM_LANDMARKS:landmarks"
// output_stream: "MATRIX:landmark_matrix"
// }
class LandmarksToFloatsCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag(kLandmarksTag).Set<std::vector<NormalizedLandmark>>();
RET_CHECK(cc->Outputs().HasTag(kFloatsTag) ||
cc->Outputs().HasTag(kMatrixTag));
if (cc->Outputs().HasTag(kFloatsTag)) {
cc->Outputs().Tag(kFloatsTag).Set<std::vector<float>>();
}
if (cc->Outputs().HasTag(kMatrixTag)) {
cc->Outputs().Tag(kMatrixTag).Set<Matrix>();
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
cc->SetOffset(TimestampDiff(0));
const auto& options =
cc->Options<::mediapipe::LandmarksToFloatsCalculatorOptions>();
num_dimensions_ = options.num_dimensions();
// Currently number of dimensions must be within [1, 3].
RET_CHECK_GE(num_dimensions_, 1);
RET_CHECK_LE(num_dimensions_, 3);
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
// Only process if there's input landmarks.
if (cc->Inputs().Tag(kLandmarksTag).IsEmpty()) {
return ::mediapipe::OkStatus();
}
const auto& input_landmarks =
cc->Inputs().Tag(kLandmarksTag).Get<std::vector<NormalizedLandmark>>();
if (cc->Outputs().HasTag(kFloatsTag)) {
auto output_floats = absl::make_unique<std::vector<float>>();
for (const auto& landmark : input_landmarks) {
output_floats->emplace_back(landmark.x());
if (num_dimensions_ > 1) {
output_floats->emplace_back(landmark.y());
}
if (num_dimensions_ > 2) {
output_floats->emplace_back(landmark.z());
}
}
cc->Outputs()
.Tag(kFloatsTag)
.Add(output_floats.release(), cc->InputTimestamp());
} else {
auto output_matrix = absl::make_unique<Matrix>();
output_matrix->setZero(num_dimensions_, input_landmarks.size());
for (int i = 0; i < input_landmarks.size(); ++i) {
(*output_matrix)(0, i) = input_landmarks[i].x();
if (num_dimensions_ > 1) {
(*output_matrix)(1, i) = input_landmarks[i].y();
}
if (num_dimensions_ > 2) {
(*output_matrix)(2, i) = input_landmarks[i].z();
}
}
cc->Outputs()
.Tag(kMatrixTag)
.Add(output_matrix.release(), cc->InputTimestamp());
}
return ::mediapipe::OkStatus();
}
private:
int num_dimensions_ = 0;
};
REGISTER_CALCULATOR(LandmarksToFloatsCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,28 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message LandmarksToFloatsCalculatorOptions {
extend CalculatorOptions {
optional LandmarksToFloatsCalculatorOptions ext = 274035660;
}
// Number of dimensions to convert. Must be within [1, 3].
optional int32 num_dimensions = 1 [default = 2];
}

View File

@ -0,0 +1,57 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// The calculator takes the path to the local file as an input side packet and
// outputs the contents of that file.
//
// Example config:
// node {
// calculator: "LocalFileContentsCalculator"
// input_side_packet: "FILE_PATH:file_path"
// output_side_packet: "CONTENTS:contents"
// }
class LocalFileContentsCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets().Tag("FILE_PATH").Set<std::string>();
cc->OutputSidePackets().Tag("CONTENTS").Set<std::string>();
return ::mediapipe::OkStatus();
}
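// The file is read once in Open() and emitted as an output side packet;
// Process() is a no-op.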
::mediapipe::Status Open(CalculatorContext* cc) override {
std::string contents;
MP_RETURN_IF_ERROR(mediapipe::file::GetContents(
cc->InputSidePackets().Tag("FILE_PATH").Get<std::string>(), &contents));
cc->OutputSidePackets()
.Tag("CONTENTS")
.Set(MakePacket<std::string>(std::move(contents)));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(LocalFileContentsCalculator);
} // namespace mediapipe

View File

@ -23,7 +23,9 @@ namespace mediapipe {
namespace {
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kNormRectsTag[] = "NORM_RECTS";
constexpr char kRectTag[] = "RECT";
constexpr char kRectsTag[] = "RECTS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
// Wraps around an angle in radians to within -M_PI and M_PI.
@ -72,17 +74,31 @@ REGISTER_CALCULATOR(RectTransformationCalculator);
::mediapipe::Status RectTransformationCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag(kNormRectTag) ^ cc->Inputs().HasTag(kRectTag));
RET_CHECK_EQ((cc->Inputs().HasTag(kNormRectTag) ? 1 : 0) +
(cc->Inputs().HasTag(kNormRectsTag) ? 1 : 0) +
(cc->Inputs().HasTag(kRectTag) ? 1 : 0) +
(cc->Inputs().HasTag(kRectsTag) ? 1 : 0),
1);
if (cc->Inputs().HasTag(kRectTag)) {
cc->Inputs().Tag(kRectTag).Set<Rect>();
cc->Outputs().Index(0).Set<Rect>();
}
if (cc->Inputs().HasTag(kRectsTag)) {
cc->Inputs().Tag(kRectsTag).Set<std::vector<Rect>>();
cc->Outputs().Index(0).Set<std::vector<Rect>>();
}
if (cc->Inputs().HasTag(kNormRectTag)) {
RET_CHECK(cc->Inputs().HasTag(kImageSizeTag));
cc->Inputs().Tag(kNormRectTag).Set<NormalizedRect>();
cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>();
cc->Outputs().Index(0).Set<NormalizedRect>();
}
if (cc->Inputs().HasTag(kNormRectsTag)) {
RET_CHECK(cc->Inputs().HasTag(kImageSizeTag));
cc->Inputs().Tag(kNormRectsTag).Set<std::vector<NormalizedRect>>();
cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>();
cc->Outputs().Index(0).Set<std::vector<NormalizedRect>>();
}
return ::mediapipe::OkStatus();
}
@ -105,7 +121,17 @@ REGISTER_CALCULATOR(RectTransformationCalculator);
cc->Outputs().Index(0).AddPacket(
MakePacket<Rect>(rect).At(cc->InputTimestamp()));
}
if (cc->Inputs().HasTag(kRectsTag) &&
!cc->Inputs().Tag(kRectsTag).IsEmpty()) {
auto rects = cc->Inputs().Tag(kRectsTag).Get<std::vector<Rect>>();
auto output_rects = absl::make_unique<std::vector<Rect>>(rects.size());
for (int i = 0; i < rects.size(); ++i) {
output_rects->at(i) = rects[i];
auto it = output_rects->begin() + i;
TransformRect(&(*it));
}
cc->Outputs().Index(0).Add(output_rects.release(), cc->InputTimestamp());
}
if (cc->Inputs().HasTag(kNormRectTag) &&
!cc->Inputs().Tag(kNormRectTag).IsEmpty()) {
auto rect = cc->Inputs().Tag(kNormRectTag).Get<NormalizedRect>();
@ -115,6 +141,21 @@ REGISTER_CALCULATOR(RectTransformationCalculator);
cc->Outputs().Index(0).AddPacket(
MakePacket<NormalizedRect>(rect).At(cc->InputTimestamp()));
}
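// Normalized rects additionally need the image size so that scaling,
// shifting, and rotation can account for the image aspect ratio.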
if (cc->Inputs().HasTag(kNormRectsTag) &&
!cc->Inputs().Tag(kNormRectsTag).IsEmpty()) {
auto rects =
cc->Inputs().Tag(kNormRectsTag).Get<std::vector<NormalizedRect>>();
const auto& image_size =
cc->Inputs().Tag(kImageSizeTag).Get<std::pair<int, int>>();
auto output_rects =
absl::make_unique<std::vector<NormalizedRect>>(rects.size());
for (int i = 0; i < rects.size(); ++i) {
output_rects->at(i) = rects[i];
auto it = output_rects->begin() + i;
TransformNormalizedRect(&(*it), image_size.first, image_size.second);
}
cc->Outputs().Index(0).Add(output_rects.release(), cc->InputTimestamp());
}
return ::mediapipe::OkStatus();
}

View File

@ -23,13 +23,14 @@
#include "mediapipe/calculators/util/top_k_scores_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/util/resource_util.h"
#if defined(MEDIAPIPE_LITE) || defined(__ANDROID__) || \
(defined(__APPLE__) && !TARGET_OS_OSX)
#if defined(MEDIAPIPE_LITE) || defined(__EMSCRIPTEN__) || \
defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
#include "mediapipe/util/android/file/base/file.h"
#include "mediapipe/util/android/file/base/helpers.h"
#else
@ -37,8 +38,10 @@
#endif
namespace mediapipe {
// A calculator that takes a vector of scores and returns the indexes, scores,
// labels of the top k elements.
// labels of the top k elements, classification protos, and summary std::string
// (in csv format).
//
// Usage example:
// node {
@ -47,6 +50,8 @@ namespace mediapipe {
// output_stream: "TOP_K_INDEXES:top_k_indexes"
// output_stream: "TOP_K_SCORES:top_k_scores"
// output_stream: "TOP_K_LABELS:top_k_labels"
// output_stream: "TOP_K_CLASSIFICATIONS:top_k_classes"
// output_stream: "SUMMARY:summary"
// options: {
// [mediapipe.TopKScoresCalculatorOptions.ext] {
// top_k: 5
@ -69,6 +74,7 @@ class TopKScoresCalculator : public CalculatorBase {
int top_k_ = -1;
float threshold_ = 0.0;
std::unordered_map<int, std::string> label_map_;
bool label_map_loaded_ = false;
};
REGISTER_CALCULATOR(TopKScoresCalculator);
@ -84,6 +90,12 @@ REGISTER_CALCULATOR(TopKScoresCalculator);
if (cc->Outputs().HasTag("TOP_K_LABELS")) {
cc->Outputs().Tag("TOP_K_LABELS").Set<std::vector<std::string>>();
}
if (cc->Outputs().HasTag("CLASSIFICATIONS")) {
cc->Outputs().Tag("CLASSIFICATIONS").Set<ClassificationList>();
}
if (cc->Outputs().HasTag("SUMMARY")) {
cc->Outputs().Tag("SUMMARY").Set<std::string>();
}
return ::mediapipe::OkStatus();
}
@ -149,7 +161,7 @@ REGISTER_CALCULATOR(TopKScoresCalculator);
reverse(top_k_indexes.begin(), top_k_indexes.end());
reverse(top_k_scores.begin(), top_k_scores.end());
if (cc->Outputs().HasTag("TOP_K_LABELS")) {
if (label_map_loaded_) {
for (int index : top_k_indexes) {
top_k_labels.push_back(label_map_[index]);
}
@ -172,6 +184,35 @@ REGISTER_CALCULATOR(TopKScoresCalculator);
.AddPacket(MakePacket<std::vector<std::string>>(top_k_labels)
.At(cc->InputTimestamp()));
}
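// SUMMARY is a single comma-separated string of "label:score" pairs
// ("index:score" when no label map is loaded).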
if (cc->Outputs().HasTag("SUMMARY")) {
std::vector<std::string> results;
for (int index = 0; index < top_k_indexes.size(); ++index) {
if (label_map_loaded_) {
results.push_back(
absl::StrCat(top_k_labels[index], ":", top_k_scores[index]));
} else {
results.push_back(
absl::StrCat(top_k_indexes[index], ":", top_k_scores[index]));
}
}
cc->Outputs().Tag("SUMMARY").AddPacket(
MakePacket<std::string>(absl::StrJoin(results, ","))
.At(cc->InputTimestamp()));
}
if (cc->Outputs().HasTag("TOP_K_CLASSIFICATION")) {
auto classification_list = absl::make_unique<ClassificationList>();
for (int index = 0; index < top_k_indexes.size(); ++index) {
Classification* classification =
classification_list->add_classification();
classification->set_index(top_k_indexes[index]);
classification->set_score(top_k_scores[index]);
if (label_map_loaded_) {
classification->set_label(top_k_labels[index]);
}
}
}
return ::mediapipe::OkStatus();
}
@ -188,6 +229,7 @@ REGISTER_CALCULATOR(TopKScoresCalculator);
while (std::getline(stream, line)) {
label_map_[i++] = line;
}
label_map_loaded_ = true;
return ::mediapipe::OkStatus();
}

View File

@ -13,12 +13,16 @@
# limitations under the License.
#
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
load(
"//mediapipe/framework/tool:mediapipe_graph.bzl",
"mediapipe_binary_graph",
)
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:private"])
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
proto_library(
name = "flow_to_image_calculator_proto",
srcs = ["flow_to_image_calculator.proto"],
@ -52,9 +56,7 @@ mediapipe_cc_proto_library(
cc_library(
name = "flow_to_image_calculator",
srcs = ["flow_to_image_calculator.cc"],
visibility = [
"//visibility:public",
],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/video:flow_to_image_calculator_cc_proto",
"//mediapipe/calculators/video/tool:flow_quantizer_model",
@ -129,10 +131,20 @@ cc_library(
alwayslink = 1,
)
filegroup(
name = "test_videos",
srcs = [
"testdata/format_FLV_H264_AAC.video",
"testdata/format_MKV_VP8_VORBIS.video",
"testdata/format_MP4_AVC720P_AAC.video",
],
visibility = ["//visibility:public"],
)
cc_test(
name = "opencv_video_decoder_calculator_test",
srcs = ["opencv_video_decoder_calculator_test.cc"],
data = ["//mediapipe/calculators/video/testdata:test_videos"],
data = [":test_videos"],
deps = [
":opencv_video_decoder_calculator",
"//mediapipe/framework:calculator_runner",
@ -151,7 +163,7 @@ cc_test(
cc_test(
name = "opencv_video_encoder_calculator_test",
srcs = ["opencv_video_encoder_calculator_test.cc"],
data = ["//mediapipe/calculators/video/testdata:test_videos"],
data = [":test_videos"],
deps = [
":opencv_video_decoder_calculator",
":opencv_video_encoder_calculator",
@ -175,7 +187,6 @@ cc_test(
cc_test(
name = "tvl1_optical_flow_calculator_test",
srcs = ["tvl1_optical_flow_calculator_test.cc"],
data = ["//mediapipe/calculators/image/testdata:test_images"],
deps = [
":tvl1_optical_flow_calculator",
"//mediapipe/framework:calculator_framework",

View File

@ -123,6 +123,7 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
cc->Outputs()
.Tag("VIDEO_PRESTREAM")
.Add(header.release(), Timestamp::PreStream());
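// Close VIDEO_PRESTREAM after emitting the header so that downstream
// calculators know no further packets will arrive on this stream.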
cc->Outputs().Tag("VIDEO_PRESTREAM").Close();
}
// Rewind to the very first frame.
cap_->set(cv::CAP_PROP_POS_AVI_RATIO, 0);

View File

@ -13,24 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//mediapipe/calculators/video:__subpackages__"])
exports_files(["LICENSE"])
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
proto_library(
name = "flow_quantizer_model_proto",
srcs = ["flow_quantizer_model.proto"],
visibility = ["//mediapipe:__subpackages__"],
visibility = ["//visibility:public"],
)
mediapipe_cc_proto_library(
name = "flow_quantizer_model_cc_proto",
srcs = ["flow_quantizer_model.proto"],
visibility = ["//mediapipe:__subpackages__"],
visibility = ["//visibility:public"],
deps = [":flow_quantizer_model_proto"],
)

View File

@ -0,0 +1,131 @@
## MediaPipe Android Archive Library
***Experimental Only***
The MediaPipe Android archive library is a convenient way to use MediaPipe with
Android Studio and Gradle. MediaPipe doesn't publish a general AAR that can be
used by all projects. Instead, developers need to add a mediapipe_aar() target
to generate a custom AAR file for their own projects. This is necessary in order
to include specific resources such as MediaPipe calculators needed for each
project.
### Steps to build a MediaPipe AAR
1. Create a mediapipe_aar() target.
In the MediaPipe directory, create a new mediapipe_aar() target in a BUILD
file. You need to figure out what calculators are used in the graph and
provide the calculator dependencies to the mediapipe_aar(). For example, to
build an AAR for [face detection gpu](./face_detection_mobile_gpu.md), you
can put the following code into
mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/BUILD.
```
load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_aar")
mediapipe_aar(
name = "mp_face_detection_aar",
calculators = ["//mediapipe/graphs/face_detection:mobile_calculators"],
)
```
2. Run the Bazel build command to generate the AAR.
```bash
bazel build -c opt --fat_apk_cpu=arm64-v8a,armeabi-v7a //path/to/the/aar/build/file:aar_name
```
For the face detection AAR target we made in step 1, run:
```bash
bazel build -c opt --fat_apk_cpu=arm64-v8a,armeabi-v7a \
//mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mp_face_detection_aar
# It should print:
# Target //mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mp_face_detection_aar up-to-date:
# bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar
```
3. (Optional) Save the AAR to your preferred location.
```bash
cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar
/absolute/path/to/your/preferred/location
```
### Steps to use a MediaPipe AAR in Android Studio with Gradle
1. Start Android Studio and go to your project.
2. Copy the AAR into app/libs.
```bash
cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar
/path/to/your/app/libs/
```
![Screenshot](images/mobile/aar_location.png)
3. Make app/src/main/assets and copy assets (graph, model, etc.) into
app/src/main/assets.
Build the MediaPipe binary graph and copy the assets into
app/src/main/assets, e.g., for the face detection graph, you need to build
and copy
[the binary graph](https://github.com/google/mediapipe/blob/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/BUILD#L41),
[the tflite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/face_detection_front.tflite),
and
[the label map](https://github.com/google/mediapipe/blob/master/mediapipe/models/face_detection_front_labelmap.txt).
```bash
bazel build -c opt mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu:binary_graph
cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/facedetectiongpu.binarypb /path/to/your/app/src/main/assets/
cp mediapipe/models/face_detection_front.tflite /path/to/your/app/src/main/assets/
cp mediapipe/models/face_detection_front_labelmap.txt /path/to/your/app/src/main/assets/
```
![Screenshot](images/mobile/assets_location.png)
4. Make app/src/main/jniLibs and copy OpenCV JNI libraries into
app/src/main/jniLibs.
MediaPipe depends on OpenCV; you will need to copy the precompiled OpenCV
shared libraries (.so files) into app/src/main/jniLibs. You can download the official OpenCV
Android SDK from
[here](https://github.com/opencv/opencv/releases/download/4.1.0/opencv-4.1.0-android-sdk.zip)
and run:
```bash
cp -R ~/Downloads/OpenCV-android-sdk/sdk/native/libs/arm* /path/to/your/app/src/main/jniLibs/
```
![Screenshot](images/mobile/android_studio_opencv_location.png)
5. Modify app/build.gradle to add MediaPipe dependencies and MediaPipe AAR.
```
dependencies {
implementation fileTree(dir: 'libs', include: ['*.jar', '*.aar'])
implementation 'androidx.appcompat:appcompat:1.0.2'
implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
testImplementation 'junit:junit:4.12'
androidTestImplementation 'androidx.test.ext:junit:1.1.0'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.1.1'
// MediaPipe deps
implementation 'com.google.flogger:flogger:0.3.1'
implementation 'com.google.flogger:flogger-system-backend:0.3.1'
implementation 'com.google.code.findbugs:jsr305:3.0.2'
implementation 'com.google.guava:guava:27.0.1-android'
implementation 'com.google.protobuf:protobuf-lite:3.0.0'
// CameraX core library
def camerax_version = "1.0.0-alpha06"
implementation "androidx.camera:camera-core:$camerax_version"
implementation "androidx.camera:camera-camera2:$camerax_version"
}
```
6. Follow our Android app examples to use MediaPipe in Android Studio for your
use case. If you are looking for an example, a working face detection
example can be found
[here](https://github.com/jiuqiant/mediapipe_aar_example).

View File

@ -76,6 +76,15 @@ MediaPipe with a TFLite model for hand tracking in a GPU-accelerated pipeline.
* [Android](./hand_tracking_mobile_gpu.md)
* [iOS](./hand_tracking_mobile_gpu.md)
### Multi-Hand Tracking with GPU
[Multi-Hand Tracking with GPU](./multi_hand_tracking_mobile_gpu.md) illustrates
how to use MediaPipe with a TFLite model for multi-hand tracking in a
GPU-accelerated pipeline.
* [Android](./multi_hand_tracking_mobile_gpu.md)
* [iOS](./multi_hand_tracking_mobile_gpu.md)
### Hair Segmentation with GPU
[Hair Segmentation on GPU](./hair_segmentation_mobile_gpu.md) illustrates how to
@ -96,8 +105,9 @@ using the MediaPipe C++ APIs.
### Feature Extraction for YouTube-8M Challenge
[Feature Extraction for YouTube-8M Challenge](./youtube_8m.md) shows how to use
MediaPipe to prepare training data for the YouTube-8M Challenge.
[Feature Extraction and Model Inference for YouTube-8M Challenge](./youtube_8m.md)
shows how to use MediaPipe to prepare training data for the YouTube-8M Challenge
and do the model inference with the baseline model.
### Preparing Data Sets with MediaSequence
@ -131,6 +141,15 @@ with live video from a webcam.
* [Desktop GPU](./hand_tracking_desktop.md)
* [Desktop CPU](./hand_tracking_desktop.md)
### Multi-Hand Tracking on Desktop with Webcam
[Multi-Hand Tracking on Desktop with Webcam](./multi_hand_tracking_desktop.md)
shows how to use MediaPipe with a TFLite model for multi-hand tracking on
desktop using CPU or GPU with live video from a webcam.
* [Desktop GPU](./multi_hand_tracking_desktop.md)
* [Desktop CPU](./multi_hand_tracking_desktop.md)
### Hair Segmentation on Desktop with Webcam
[Hair Segmentation on Desktop with Webcam](./hair_segmentation_desktop.md) shows

View File

@ -36,10 +36,9 @@ $ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \
# INFO: 711 processes: 710 linux-sandbox, 1 local.
# INFO: Build completed successfully, 734 total actions
$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on
# Any errors is likely due to your webcam being not accessible
$ bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_cpu \
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_cpu \
--calculator_graph_config_file=mediapipe/graphs/face_detection/face_detection_desktop_live.pbtxt
```
@ -60,11 +59,10 @@ $ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS \
# INFO: 711 processes: 710 linux-sandbox, 1 local.
# INFO: Build completed successfully, 734 total actions
$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on
# Any errors is likely due to your webcam being not accessible,
# or GPU drivers not setup properly.
$ bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_gpu \
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_gpu \
--calculator_graph_config_file=mediapipe/graphs/face_detection/face_detection_mobile_gpu.pbtxt
```
@ -79,7 +77,7 @@ below and paste it into
```bash
# MediaPipe graph that performs face detection with TensorFlow Lite on CPU & GPU.
# Used in the examples in
# mediapipie/examples/desktop/face_detection:face_detection_cpu.
# mediapipe/examples/desktop/face_detection:face_detection_cpu.
# Images on CPU coming into and out of the graph.
input_stream: "input_video"

View File

@ -31,15 +31,12 @@ $ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS \
#INFO: Found 1 target...
#Target //mediapipe/examples/desktop/hair_segmentation:hair_segmentation_gpu up-to-date:
# bazel-bin/mediapipe/examples/desktop/hair_segmentation/hair_segmentation_gpu
#INFO: Elapsed time: 18.209s, Forge stats: 13026/13057 actions cached, 20.8s CPU used, 0.0s queue time, 89.3 MB ObjFS output (novel bytes: 87.4 MB), 0.0 MB local output, Critical Path: 11.88s, Remote (86.01% of the time): [queue: 0.00%, network: 16.83%, setup: 4.59%, process: 38.92%]
#INFO: Streaming build results to: http://sponge2/37d5a184-293b-4e98-a43e-b22084db3142
#INFO: Build completed successfully, 12210 total actions
$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on
# Any errors is likely due to your webcam being not accessible,
# or GPU drivers not setup properly.
$ bazel-bin/mediapipe/examples/desktop/hair_segmentation/hair_segmentation_gpu \
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hair_segmentation/hair_segmentation_gpu \
--calculator_graph_config_file=mediapipe/graphs/hair_segmentation/hair_segmentation_mobile_gpu.pbtxt
```
@ -54,7 +51,7 @@ below and paste it into
```bash
# MediaPipe graph that performs hair segmentation with TensorFlow Lite on GPU.
# Used in the example in
# mediapipie/examples/android/src/java/com/mediapipe/apps/hairsegmentationgpu.
# mediapipe/examples/android/src/java/com/mediapipe/apps/hairsegmentationgpu.
# Images on GPU coming into and out of the graph.
input_stream: "input_video"

View File

@ -29,7 +29,7 @@ below and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/).
```bash
# MediaPipe graph that performs hair segmentation with TensorFlow Lite on GPU.
# Used in the example in
# mediapipie/examples/android/src/java/com/mediapipe/apps/hairsegmentationgpu.
# mediapipe/examples/android/src/java/com/mediapipe/apps/hairsegmentationgpu.
# Images on GPU coming into and out of the graph.
input_stream: "input_video"

View File

@ -31,14 +31,11 @@ $ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \
# It should print:
#Target //mediapipe/examples/desktop/hand_tracking:hand_tracking_cpu up-to-date:
# bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_cpu
#INFO: Elapsed time: 22.645s, Forge stats: 13356/13463 actions cached, 1.5m CPU used, 0.0s queue time, 819.8 MB ObjFS output (novel bytes: 85.6 MB), 0.0 MB local output, Critical Path: 14.43s, Remote (87.25% of the time): [queue: 0.00%, network: 14.88%, setup: 4.80%, process: 39.80%, fetch: 18.15%]
#INFO: Streaming build results to: http://sponge2/360196b9-33ab-44b1-84a7-1022b5043307
#INFO: Build completed successfully, 12517 total actions
$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on
# Any errors is likely due to your webcam being not accessible
$ bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_cpu \
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_cpu \
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop_live.pbtxt
```
@ -55,15 +52,12 @@ $ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS \
# It should print:
# Target //mediapipe/examples/desktop/hand_tracking:hand_tracking_gpu up-to-date:
# bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu
#INFO: Elapsed time: 84.055s, Forge stats: 6858/19343 actions cached, 1.6h CPU used, 0.9s queue time, 1.68 GB ObjFS output (novel bytes: 485.1 MB), 0.0 MB local output, Critical Path: 48.14s, Remote (99.40% of the time): [queue: 0.00%, setup: 5.59%, process: 74.44%]
#INFO: Streaming build results to: http://sponge2/00c7f95f-6fbc-432d-8978-f5d361efca3b
#INFO: Build completed successfully, 22455 total actions
$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on
# Any errors is likely due to your webcam being not accessible,
# or GPU drivers not setup properly.
$ bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu \
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu \
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt
```
@ -79,7 +73,7 @@ below and paste it into
# MediaPipe graph that performs hand tracking on desktop with TensorFlow Lite
# on CPU & GPU.
# Used in the example in
# mediapipie/examples/desktop/hand_tracking:hand_tracking_cpu.
# mediapipe/examples/desktop/hand_tracking:hand_tracking_cpu.
# Images coming into and out of the graph.
input_stream: "input_video"

View File

@ -100,8 +100,8 @@ see the Visualizing Subgraphs section in the
```bash
# MediaPipe graph that performs hand tracking with TensorFlow Lite on GPU.
# Used in the examples in
# mediapipie/examples/android/src/java/com/mediapipe/apps/handtrackinggpu and
# mediapipie/examples/ios/handtrackinggpu.
# mediapipe/examples/android/src/java/com/mediapipe/apps/handtrackinggpu and
# mediapipe/examples/ios/handtrackinggpu.
# Images coming into and out of the graph.
input_stream: "input_video"

Binary file not shown.

After

Width:  |  Height:  |  Size: 35 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 75 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 56 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 149 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 193 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 213 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.6 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 124 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 479 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 225 KiB

View File

@ -7,11 +7,8 @@ future.
Note: If you plan to use TensorFlow calculators and example apps, there is a
known issue with gcc and g++ version 6.3 and 7.3. Please use other versions.
Note: While Mediapipe configures TensorFlow, if you see the
following error:
`"...git_configure.bzl", line 14, in _fail fail(("%sGit Configuration
Error:%s %...)))`,
please install the python future library using: `$ pip install --user future`.
Note: To make MediaPipe work with TensorFlow, please install the Python "future"
library and the Python "six" library using `pip install --user future six`.
Choose your operating system:
@ -24,7 +21,8 @@ Choose your operating system:
To build and run Android apps:
- [Setting up Android SDK and NDK](#setting-up-android-sdk-and-ndk)
- [Setting up Android Studio with MediaPipe](#setting-up-android-studio-with-mediapipe)
- [Using MediaPipe with Gradle](#using-mediapipe-with-gradle)
- [Using MediaPipe with Bazel](#using-mediapipe-with-bazel)
To build and run iOS apps:
@ -51,8 +49,8 @@ To build and run iOS apps:
# Run 'bazel version' to check version of bazel installed
```
Option 2. Follow Bazel's
[documentation](https://docs.bazel.build/versions/master/install-ubuntu.html)
Option 2. Follow the official
[Bazel documentation](https://docs.bazel.build/versions/master/install-ubuntu.html)
to install any version of Bazel manually.
3. Install OpenCV and FFmpeg.
@ -75,10 +73,10 @@ To build and run iOS apps:
[documentation](https://docs.opencv.org/3.4.6/d7/d9f/tutorial_linux_install.html)
to manually build OpenCV from source code.
Note: You may need to modify [`WORKSAPCE`] and [`opencv_linux.BUILD`] to
Note: You may need to modify [`WORKSPACE`] and [`opencv_linux.BUILD`] to
point MediaPipe to your own OpenCV libraries, e.g., if OpenCV 4 is installed
in "/usr/local/", you need to update the "linux_opencv" new_local_repository
rule in [`WORKSAPCE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`]
rule in [`WORKSPACE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`]
like the following:
```bash
@ -161,8 +159,8 @@ To build and run iOS apps:
2. Install Bazel (0.24.1 and above required).
Follow Bazel's
[documentation](https://docs.bazel.build/versions/master/install-redhat.html)
Follow the official
[Bazel documentation](https://docs.bazel.build/versions/master/install-redhat.html)
to install Bazel manually.
3. Install OpenCV.
@ -178,10 +176,10 @@ To build and run iOS apps:
Option 2. Build OpenCV from source code.
Note: You may need to modify [`WORKSAPCE`] and [`opencv_linux.BUILD`] to
Note: You may need to modify [`WORKSPACE`] and [`opencv_linux.BUILD`] to
point MediaPipe to your own OpenCV libraries, e.g., if OpenCV 4 is installed
in "/usr/local/", you need to update the "linux_opencv" new_local_repository
rule in [`WORKSAPCE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`]
rule in [`WORKSPACE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`]
like the following:
```bash
@ -237,7 +235,7 @@ To build and run iOS apps:
* Install [Homebrew](https://brew.sh).
* Install [Xcode](https://developer.apple.com/xcode/) and its Command Line
Tools.
Tools by running `xcode-select --install`.
2. Checkout MediaPipe repository.
@ -257,8 +255,8 @@ To build and run iOS apps:
# Run 'bazel version' to check version of bazel installed
```
Option 2. Follow Bazel's
[documentation](https://docs.bazel.build/versions/master/install-os-x.html#install-with-installer-mac-os-x)
Option 2. Follow the official
[Bazel documentation](https://docs.bazel.build/versions/master/install-os-x.html#install-with-installer-mac-os-x)
to install any version of Bazel manually.
4. Install OpenCV and FFmpeg.
@ -281,7 +279,7 @@ To build and run iOS apps:
$ port install opencv
```
Note: when using MacPorts, please edit the [`WORKSAPCE`],
Note: when using MacPorts, please edit the [`WORKSPACE`],
[`opencv_macos.BUILD`], and [`ffmpeg_macos.BUILD`] files like the following:
```bash
@ -419,10 +417,10 @@ To build and run iOS apps:
[documentation](https://docs.opencv.org/3.4.6/d7/d9f/tutorial_linux_install.html)
to manually build OpenCV from source code.
Note: You may need to modify [`WORKSAPCE`] and [`opencv_linux.BUILD`] to
Note: You may need to modify [`WORKSPACE`] and [`opencv_linux.BUILD`] to
point MediaPipe to your own OpenCV libraries, e.g., if OpenCV 4 is installed
in "/usr/local/", you need to update the "linux_opencv" new_local_repository
rule in [`WORKSAPCE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`]
rule in [`WORKSPACE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`]
like the following:
```bash
@ -581,6 +579,11 @@ export ANDROID_HOME=<path to the Android SDK>
export ANDROID_NDK_HOME=<path to the Android NDK>
```
In order to use MediaPipe on earlier Android versions, MediaPipe needs to switch
to a lower Android API level. You can achieve this by specifying `api_level =
<api level integer>` in android_ndk_repository() and/or android_sdk_repository()
in the [`WORKSPACE`] file.
Please verify all the necessary packages are installed.
* Android SDK Platform API Level 28 or 29
@ -589,10 +592,20 @@ Please verify all the necessary packages are installed.
* Android SDK Tools 26.1.1
* Android NDK 17c or above
### Setting up Android Studio with MediaPipe
### Using MediaPipe with Gradle
The steps below use Android Studio 3.5 to build and install a MediaPipe example
app.
MediaPipe can be used within an existing project, such as a Gradle project,
using the MediaPipe AAR target defined in mediapipe_aar.bzl. Please see the
separate [MediaPipe Android Archive Library](./android_archive_library.md)
documentation.
### Using MediaPipe with Bazel
The MediaPipe project can be imported into Android Studio using the Bazel plugin.
This allows the MediaPipe examples and demos to be built and modified in Android
Studio. To incorporate MediaPipe into an existing Android Studio project, see:
"Using MediaPipe with Gradle". The steps below use Android Studio 3.5 to build
and install a MediaPipe example app.
1. Install and launch Android Studio 3.5.
@ -682,7 +695,7 @@ app.
* Press the `[+]` button to add the new configuration.
* Select `Run` to run the example app on the connected Android device.
[`WORKSAPCE`]: https://github.com/google/mediapipe/tree/master/WORKSPACE
[`WORKSPACE`]: https://github.com/google/mediapipe/tree/master/WORKSPACE
[`opencv_linux.BUILD`]: https://github.com/google/mediapipe/tree/master/third_party/opencv_linux.BUILD
[`opencv_macos.BUILD`]: https://github.com/google/mediapipe/tree/master/third_party/opencv_macos.BUILD
[`ffmpeg_macos.BUILD`]:https://github.com/google/mediapipe/tree/master/third_party/ffmpeg_macos.BUILD

View File

@ -0,0 +1,177 @@
## Multi-Hand Tracking on Desktop
This is an example of using MediaPipe to run hand tracking models (TensorFlow
Lite) and render bounding boxes on the detected hand instances (for multiple
hands). To learn more about the hand tracking models, please refer to the model
[`README file`]. Moreover, if you are interested in running the same TensorFlow
Lite model on Android/iOS, please see the
[Multi-Hand Tracking on GPU on Android/iOS](multi_hand_tracking_mobile_gpu.md)
example.
We show the multi-hand tracking demos with the TensorFlow Lite model using a webcam:
- [TensorFlow Lite Multi-Hand Tracking Demo with Webcam (CPU)](#tensorflow-lite-multi-hand-tracking-demo-with-webcam-cpu)
- [TensorFlow Lite Multi-Hand Tracking Demo with Webcam (GPU)](#tensorflow-lite-multi-hand-tracking-demo-with-webcam-gpu)
Note: Desktop GPU works only on Linux. Mesa drivers need to be installed. Please
see
[step 4 of "Installing on Debian and Ubuntu" in the installation guide](./install.md).
Note: If MediaPipe depends on OpenCV 2, please see the
[known issues with OpenCV 2](#known-issues-with-opencv-2) section.
### TensorFlow Lite Multi-Hand Tracking Demo with Webcam (CPU)
To build and run the TensorFlow Lite example on desktop (CPU) with Webcam, run:
```bash
# Video from webcam running on desktop CPU
$ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \
mediapipe/examples/desktop/multi_hand_tracking:multi_hand_tracking_cpu
# It should print:
#Target //mediapipe/examples/desktop/multi_hand_tracking:multi_hand_tracking_cpu up-to-date:
# bazel-bin/mediapipe/examples/desktop/multi_hand_tracking/multi_hand_tracking_cpu
# This will open up your webcam as long as it is connected and on.
# Any errors are likely due to your webcam not being accessible.
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/multi_hand_tracking/multi_hand_tracking_cpu \
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/multi_hand_tracking_desktop_live.pbtxt
```
### TensorFlow Lite Multi-Hand Tracking Demo with Webcam (GPU)
To build and run the TensorFlow Lite example on desktop (GPU) with Webcam, run:
```bash
# Video from webcam running on desktop GPU
# This works only for linux currently
$ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS \
mediapipe/examples/desktop/multi_hand_tracking:multi_hand_tracking_gpu
# It should print:
# Target //mediapipe/examples/desktop/multi_hand_tracking:multi_hand_tracking_gpu up-to-date:
# bazel-bin/mediapipe/examples/desktop/multi_hand_tracking/multi_hand_tracking_gpu
# This will open up your webcam as long as it is connected and on.
# Any errors are likely due to your webcam not being accessible,
# or GPU drivers not being set up properly.
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/multi_hand_tracking/multi_hand_tracking_gpu \
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/multi_hand_tracking_mobile.pbtxt
```
#### Graph
![graph visualization](images/multi_hand_tracking_desktop.png)
To visualize the graph as shown above, copy the text specification of the graph
below and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev).
```bash
# MediaPipe graph that performs multi-hand tracking on desktop with TensorFlow
# Lite on CPU.
# Used in the example in
# mediapipe/examples/desktop/multi_hand_tracking:multi_hand_tracking_cpu.
# Images coming into and out of the graph.
input_stream: "input_video"
output_stream: "output_video"
# Determines if an input vector of NormalizedRect has a size greater than or
# equal to the provided min_size.
node {
calculator: "NormalizedRectVectorHasMinSizeCalculator"
input_stream: "ITERABLE:prev_multi_hand_rects_from_landmarks"
output_stream: "prev_has_enough_hands"
node_options: {
[type.googleapis.com/mediapipe.CollectionHasMinSizeCalculatorOptions] {
# This value can be changed to support tracking an arbitrary number of hands.
# Please also remember to modify max_vec_size in
# ClipVectorSizeCalculatorOptions in
# mediapipe/graphs/hand_tracking/subgraphs/multi_hand_detection_gpu.pbtxt
min_size: 2
}
}
}
# Drops the incoming image if the previous frame had at least N hands.
# Otherwise, passes the incoming image through to trigger a new round of hand
# detection in MultiHandDetectionSubgraph.
node {
calculator: "GateCalculator"
input_stream: "input_video"
input_stream: "DISALLOW:prev_has_enough_hands"
output_stream: "multi_hand_detection_input_video"
node_options: {
[type.googleapis.com/mediapipe.GateCalculatorOptions] {
empty_packets_as_allow: true
}
}
}
# Subgraph that detects hands (see multi_hand_detection_cpu.pbtxt).
node {
calculator: "MultiHandDetectionSubgraph"
input_stream: "multi_hand_detection_input_video"
output_stream: "DETECTIONS:multi_palm_detections"
output_stream: "NORM_RECTS:multi_palm_rects"
}
# Subgraph that localizes hand landmarks for multiple hands (see
# multi_hand_landmark.pbtxt).
node {
calculator: "MultiHandLandmarkSubgraph"
input_stream: "IMAGE:input_video"
input_stream: "NORM_RECTS:multi_hand_rects"
output_stream: "LANDMARKS:multi_hand_landmarks"
output_stream: "NORM_RECTS:multi_hand_rects_from_landmarks"
}
# Caches a hand rectangle fed back from MultiHandLandmarkSubgraph, and upon the
# arrival of the next input image sends out the cached rectangle with the
# timestamp replaced by that of the input image, essentially generating a packet
# that carries the previous hand rectangle. Note that upon the arrival of the
# very first input image, an empty packet is sent out to jump start the
# feedback loop.
node {
calculator: "PreviousLoopbackCalculator"
input_stream: "MAIN:input_video"
input_stream: "LOOP:multi_hand_rects_from_landmarks"
input_stream_info: {
tag_index: "LOOP"
back_edge: true
}
output_stream: "PREV_LOOP:prev_multi_hand_rects_from_landmarks"
}
# Performs association between NormalizedRect vector elements from previous
# frame and those from the current frame if MultiHandDetectionSubgraph runs.
# This calculator ensures that the output multi_hand_rects vector doesn't
# contain overlapping regions based on the specified min_similarity_threshold.
node {
calculator: "AssociationNormRectCalculator"
input_stream: "prev_multi_hand_rects_from_landmarks"
input_stream: "multi_palm_rects"
output_stream: "multi_hand_rects"
node_options: {
[type.googleapis.com/mediapipe.AssociationCalculatorOptions] {
min_similarity_threshold: 0.1
}
}
}
# Subgraph that renders annotations and overlays them on top of the input
# images (see multi_hand_renderer_cpu.pbtxt).
node {
calculator: "MultiHandRendererSubgraph"
input_stream: "IMAGE:input_video"
input_stream: "DETECTIONS:multi_palm_detections"
input_stream: "LANDMARKS:multi_hand_landmarks"
input_stream: "NORM_RECTS:0:multi_palm_rects"
input_stream: "NORM_RECTS:1:multi_hand_rects"
output_stream: "IMAGE:output_video"
}
```
[`README file`]:https://github.com/google/mediapipe/tree/master/mediapipe/README.md
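For reference, a desktop demo like `multi_hand_tracking_cpu` drives a graph such as the one above from C++ roughly as follows. This is a condensed, hypothetical sketch modeled on MediaPipe's generic desktop demo runner; the stream names match the graph above, but error handling is abbreviated and the exact helpers should be verified against the example sources.
```cpp
#include <string>
#include "absl/memory/memory.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/opencv_highgui_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/opencv_video_inc.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"

::mediapipe::Status RunMultiHandGraph(const std::string& graph_config_contents) {
  auto config = ::mediapipe::ParseTextProtoOrDie<::mediapipe::CalculatorGraphConfig>(
      graph_config_contents);
  ::mediapipe::CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(config));
  ASSIGN_OR_RETURN(::mediapipe::OutputStreamPoller poller,
                   graph.AddOutputStreamPoller("output_video"));
  MP_RETURN_IF_ERROR(graph.StartRun({}));

  cv::VideoCapture capture(0);  // Default webcam.
  cv::Mat camera_frame_raw;
  size_t frame_timestamp_us = 0;
  while (capture.read(camera_frame_raw)) {
    cv::Mat camera_frame;
    cv::cvtColor(camera_frame_raw, camera_frame, cv::COLOR_BGR2RGB);

    // Wrap the RGB frame in an ImageFrame packet and feed it to the graph.
    auto input_frame = absl::make_unique<::mediapipe::ImageFrame>(
        ::mediapipe::ImageFormat::SRGB, camera_frame.cols, camera_frame.rows,
        ::mediapipe::ImageFrame::kDefaultAlignmentBoundary);
    cv::Mat input_frame_mat = ::mediapipe::formats::MatView(input_frame.get());
    camera_frame.copyTo(input_frame_mat);
    MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
        "input_video", ::mediapipe::Adopt(input_frame.release())
                           .At(::mediapipe::Timestamp(frame_timestamp_us))));
    frame_timestamp_us += 33333;  // Roughly 30 fps.

    // Pull the annotated output frame and display it.
    ::mediapipe::Packet packet;
    if (!poller.Next(&packet)) break;
    const auto& output_frame = packet.Get<::mediapipe::ImageFrame>();
    cv::Mat output_frame_mat = ::mediapipe::formats::MatView(&output_frame);
    cv::cvtColor(output_frame_mat, output_frame_mat, cv::COLOR_RGB2BGR);
    cv::imshow("MultiHandTracking", output_frame_mat);
    if (cv::waitKey(5) == 27) break;  // Press Esc to quit.
  }
  MP_RETURN_IF_ERROR(graph.CloseInputStream("input_video"));
  return graph.WaitUntilDone();
}
```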

View File

@ -0,0 +1,755 @@
# Multi-Hand Tracking (GPU)
This doc focuses on the
[example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/multi_hand_tracking_mobile.pbtxt)
that performs multi-hand tracking with TensorFlow Lite on GPU. It is related to
the [hand_tracking_example](./hand_tracking_mobile_gpu.md), and we recommend
that users review the (single) hand tracking example first.
![multi_hand_tracking_android_gpu.gif](images/mobile/multi_hand_tracking_android_gpu.gif)
In the visualization above, the red dots represent the hand landmarks and the
green lines are simply connections between selected landmark pairs for
visualization of the hand skeleton. When there are fewer than `N` hands (`N=2`
in the graphs here), the purple box represents a hand rectangle that covers the
entire hand, derived from hand detection (see
[hand_detection_example](./hand_detection_mobile_gpu.md)). When there are `N`
hands (i.e. 2 hands for the graphs here), the red boxes represent hand
rectangles for each of the hands, derived from the previous round of hand
landmark localization using an ML model (see also
[model card](https://mediapipe.page.link/handmc)). Hand landmark localization
for each hand is performed only within the hand rectangle for computational
efficiency and accuracy. Hand detection is only invoked whenever there are fewer
than `N` hands in the previous iteration.
This example can also run a model that localizes hand landmarks in 3D (i.e.,
estimating an extra z coordinate):
![multi_hand_tracking_3d_android_gpu.gif](images/mobile/multi_hand_tracking_3d_android_gpu.gif)
In the visualization above, the localized hand landmarks are represented by dots
in different shades, with the brighter ones denoting landmarks closer to the
camera.
## Android
[Source](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/multihandtrackinggpu)
To build the app yourself, run:
```bash
bazel build -c opt --config=android_arm64 mediapipe/examples/android/src/java/com/google/mediapipe/apps/multihandtrackinggpu
```
To build for the 3D mode, run:
```bash
bazel build -c opt --config=android_arm64 --define 3D=true mediapipe/examples/android/src/java/com/google/mediapipe/apps/multihandtrackinggpu
```
Once the app is built, install it on an Android device with:
```bash
adb install bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/multihandtrackinggpu/multihandtrackinggpu.apk
```
## iOS
[Source](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/multihandtrackinggpu).
See the general [instructions](./mediapipe_ios_setup.md) for building iOS
examples and generating an Xcode project. This will be the
MultiHandTrackingGpuApp target.
To build on the command line:
```bash
bazel build -c opt --config=ios_arm64 mediapipe/examples/ios/multihandtrackinggpu:MultiHandTrackingGpuApp
```
To build for the 3D mode, run:
```bash
bazel build -c opt --config=ios_arm64 --define 3D=true mediapipe/examples/ios/multihandtrackinggpu:MultiHandTrackingGpuApp
```
## Graph
The multi-hand tracking [main graph](#main-graph) internally utilizes a
[multi_hand_detection_subgraph](#multi-hand-detection-subgraph), a
[multi_hand_landmark_subgraph](#multi-hand-landmark-subgraph), and a
[multi_hand_renderer_subgraph](#multi-hand-renderer-subgraph).
The subgraphs show up in the main graph visualization as nodes colored in
purple, and the subgraph itself can also be visualized just like a regular
graph. For more information on how to visualize a graph that includes subgraphs,
see the Visualizing Subgraphs section in the
[visualizer documentation](./visualizer.md).
### Main Graph
![multi_hand_tracking_mobile_graph](images/mobile/multi_hand_tracking_mobile.png)
There are two key differences between this graph and the
[single_hand_tracking_mobile_graph](./hand_tracking_mobile_gpu.md).
1. There is a `NormalizedRectVectorHasMinSize` calculator that checks if the
input vector of `NormalizedRect` objects has a minimum size equal to `N`. In
this graph, if the vector contains fewer than `N` objects, the
`MultiHandDetection` subgraph runs. Otherwise, the `GateCalculator` doesn't
send any image packets to the `MultiHandDetection` subgraph. This way, the
main graph is efficient in that it avoids running the costly hand detection
step when there are already `N` hands in the frame.
2. The `MergeCalculator` has been replaced by the `AssociationNormRect`
calculator. This `AssociationNormRect` takes as input a vector of
`NormalizedRect` objects from the `MultiHandDetection` subgraph on the
current frame, and a vector of `NormalizedRect` objects from the
`MultiHandLandmark` subgraph from the previous frame, and performs an
association operation between these objects. This calculator ensures that
the output vector doesn't contain overlapping regions based on the specified
`min_similarity_threshold`.
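To make the association step concrete, below is a small, hypothetical C++ sketch of overlap-based de-duplication on `NormalizedRect` objects. It is not the actual `AssociationNormRectCalculator` implementation; the IoU helper, the replace-on-overlap policy, and the function names are illustrative only. In the graph, the previous frame's hand rects and the fresh palm detections are the two inputs, with `min_similarity_threshold: 0.1`.
```cpp
#include <algorithm>
#include <vector>
#include "mediapipe/framework/formats/rect.pb.h"

// Intersection-over-union of two NormalizedRects, ignoring rotation for
// brevity.
float OverlapSimilarity(const mediapipe::NormalizedRect& a,
                        const mediapipe::NormalizedRect& b) {
  const float ax1 = a.x_center() - a.width() / 2, ax2 = a.x_center() + a.width() / 2;
  const float ay1 = a.y_center() - a.height() / 2, ay2 = a.y_center() + a.height() / 2;
  const float bx1 = b.x_center() - b.width() / 2, bx2 = b.x_center() + b.width() / 2;
  const float by1 = b.y_center() - b.height() / 2, by2 = b.y_center() + b.height() / 2;
  const float iw = std::max(0.f, std::min(ax2, bx2) - std::max(ax1, bx1));
  const float ih = std::max(0.f, std::min(ay2, by2) - std::max(ay1, by1));
  const float intersection = iw * ih;
  const float union_area =
      a.width() * a.height() + b.width() * b.height() - intersection;
  return union_area > 0.f ? intersection / union_area : 0.f;
}

// Merges newly arrived rects into an existing collection: a new rect replaces
// any existing rect it overlaps beyond the threshold, otherwise it is added.
void Associate(const std::vector<mediapipe::NormalizedRect>& new_rects,
               float min_similarity_threshold,
               std::vector<mediapipe::NormalizedRect>* rects) {
  for (const auto& new_rect : new_rects) {
    bool replaced = false;
    for (auto& existing : *rects) {
      if (OverlapSimilarity(existing, new_rect) > min_similarity_threshold) {
        existing = new_rect;
        replaced = true;
        break;
      }
    }
    if (!replaced) rects->push_back(new_rect);
  }
}
```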
[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/multi_hand_tracking_mobile.pbtxt)
```bash
# MediaPipe graph that performs multi-hand tracking with TensorFlow Lite on GPU.
# Used in the examples in
# mediapipe/examples/android/src/java/com/mediapipe/apps/multihandtrackinggpu.
# Images coming into and out of the graph.
input_stream: "input_video"
output_stream: "output_video"
# Throttles the images flowing downstream for flow control. It passes through
# the very first incoming image unaltered, and waits for downstream nodes
# (calculators and subgraphs) in the graph to finish their tasks before it
# passes through another image. All images that come in while waiting are
# dropped, limiting the number of in-flight images in most part of the graph to
# 1. This prevents the downstream nodes from queuing up incoming images and data
# excessively, which leads to increased latency and memory usage, unwanted in
# real-time mobile applications. It also eliminates unnecessary computation,
# e.g., the output produced by a node may get dropped downstream if the
# subsequent nodes are still busy processing previous inputs.
node {
calculator: "FlowLimiterCalculator"
input_stream: "input_video"
input_stream: "FINISHED:multi_hand_rects"
input_stream_info: {
tag_index: "FINISHED"
back_edge: true
}
output_stream: "throttled_input_video"
}
# Determines if an input vector of NormalizedRect has a size greater than or
# equal to the provided min_size.
node {
calculator: "NormalizedRectVectorHasMinSizeCalculator"
input_stream: "ITERABLE:prev_multi_hand_rects_from_landmarks"
output_stream: "prev_has_enough_hands"
node_options: {
[type.googleapis.com/mediapipe.CollectionHasMinSizeCalculatorOptions] {
# This value can be changed to support tracking an arbitrary number of hands.
# Please also remember to modify max_vec_size in
# ClipVectorSizeCalculatorOptions in
# mediapipe/graphs/hand_tracking/subgraphs/multi_hand_detection_gpu.pbtxt
min_size: 2
}
}
}
# Drops the incoming image if the previous frame had at least N hands.
# Otherwise, passes the incoming image through to trigger a new round of hand
# detection in MultiHandDetectionSubgraph.
node {
calculator: "GateCalculator"
input_stream: "throttled_input_video"
input_stream: "DISALLOW:prev_has_enough_hands"
output_stream: "multi_hand_detection_input_video"
node_options: {
[type.googleapis.com/mediapipe.GateCalculatorOptions] {
empty_packets_as_allow: true
}
}
}
# Subgraph that detects hands (see multi_hand_detection_gpu.pbtxt).
node {
calculator: "MultiHandDetectionSubgraph"
input_stream: "multi_hand_detection_input_video"
output_stream: "DETECTIONS:multi_palm_detections"
output_stream: "NORM_RECTS:multi_palm_rects"
}
# Subgraph that localizes hand landmarks for multiple hands (see
# multi_hand_landmark.pbtxt).
node {
calculator: "MultiHandLandmarkSubgraph"
input_stream: "IMAGE:throttled_input_video"
input_stream: "NORM_RECTS:multi_hand_rects"
output_stream: "LANDMARKS:multi_hand_landmarks"
output_stream: "NORM_RECTS:multi_hand_rects_from_landmarks"
}
# Caches a hand rectangle fed back from MultiHandLandmarkSubgraph, and upon the
# arrival of the next input image sends out the cached rectangle with the
# timestamp replaced by that of the input image, essentially generating a packet
# that carries the previous hand rectangle. Note that upon the arrival of the
# very first input image, an empty packet is sent out to jump start the
# feedback loop.
node {
calculator: "PreviousLoopbackCalculator"
input_stream: "MAIN:throttled_input_video"
input_stream: "LOOP:multi_hand_rects_from_landmarks"
input_stream_info: {
tag_index: "LOOP"
back_edge: true
}
output_stream: "PREV_LOOP:prev_multi_hand_rects_from_landmarks"
}
# Performs association between NormalizedRect vector elements from previous
# frame and those from the current frame if MultiHandDetectionSubgraph runs.
# This calculator ensures that the output multi_hand_rects vector doesn't
# contain overlapping regions based on the specified min_similarity_threshold.
node {
calculator: "AssociationNormRectCalculator"
input_stream: "prev_multi_hand_rects_from_landmarks"
input_stream: "multi_palm_rects"
output_stream: "multi_hand_rects"
node_options: {
[type.googleapis.com/mediapipe.AssociationCalculatorOptions] {
min_similarity_threshold: 0.1
}
}
}
# Subgraph that renders annotations and overlays them on top of the input
# images (see multi_hand_renderer_gpu.pbtxt).
node {
calculator: "MultiHandRendererSubgraph"
input_stream: "IMAGE:throttled_input_video"
input_stream: "DETECTIONS:multi_palm_detections"
input_stream: "LANDMARKS:multi_hand_landmarks"
input_stream: "NORM_RECTS:0:multi_palm_rects"
input_stream: "NORM_RECTS:1:multi_hand_rects"
output_stream: "IMAGE:output_video"
}
```
### Multi-Hand Detection Subgraph
![multi_hand_detection_gpu_subgraph](images/mobile/multi_hand_detection_gpu_subgraph.png)
This graph outputs a vector of `NormalizedRect` objects corresponding to each of
the hand instances visible in the frame. Note that at the end of this graph,
there is a `ClipNormalizedRectVectorSizeCalculator`. This calculator clips the
size of the input vector to a maximum size `N`. This implies that the
`MultiHandDetection` subgraph outputs a vector of at most `N` hand instance
locations.
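As a quick illustration of how this clipping interacts with the gating logic in the main graph, here is a hedged Python sketch; the constant and function names are hypothetical and simply mirror the `max_vec_size` and `min_size` options that the pbtxt comments ask you to keep in sync.

```python
# Illustration only: the subgraph's max_vec_size and the main graph's min_size
# must agree for the "skip detection when N hands are already tracked" logic to work.

MAX_NUM_HANDS = 2  # max_vec_size in ClipVectorSizeCalculatorOptions
MIN_TRACKED_HANDS_TO_SKIP_DETECTION = 2  # min_size in CollectionHasMinSizeCalculatorOptions
assert MAX_NUM_HANDS == MIN_TRACKED_HANDS_TO_SKIP_DETECTION

def clip_hand_rects(hand_rects):
    """Keeps at most MAX_NUM_HANDS rects, like ClipNormalizedRectVectorSizeCalculator."""
    return hand_rects[:MAX_NUM_HANDS]

def should_run_detection(prev_tracked_rects):
    """Mirrors the NormalizedRectVectorHasMinSizeCalculator + GateCalculator(DISALLOW) pair."""
    return len(prev_tracked_rects) < MIN_TRACKED_HANDS_TO_SKIP_DETECTION
```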
[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/subgraphs/multi_hand_detection_gpu.pbtxt)
```bash
# MediaPipe multi-hand detection subgraph.
type: "MultiHandDetectionSubgraph"
input_stream: "input_video"
output_stream: "DETECTIONS:palm_detections"
output_stream: "NORM_RECTS:clipped_hand_rects_from_palm_detections"
# Transforms the input image on GPU to a 256x256 image. To scale the input
# image, the scale_mode option is set to FIT to preserve the aspect ratio,
# resulting in potential letterboxing in the transformed image.
node: {
calculator: "ImageTransformationCalculator"
input_stream: "IMAGE_GPU:input_video"
output_stream: "IMAGE_GPU:transformed_input_video"
output_stream: "LETTERBOX_PADDING:letterbox_padding"
node_options: {
[type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
output_width: 256
output_height: 256
scale_mode: FIT
}
}
}
# Generates a single side packet containing a TensorFlow Lite op resolver that
# supports custom ops needed by the model used in this graph.
node {
calculator: "TfLiteCustomOpResolverCalculator"
output_side_packet: "opresolver"
node_options: {
[type.googleapis.com/mediapipe.TfLiteCustomOpResolverCalculatorOptions] {
use_gpu: true
}
}
}
# Converts the transformed input image on GPU into an image tensor stored as a
# TfLiteTensor.
node {
calculator: "TfLiteConverterCalculator"
input_stream: "IMAGE_GPU:transformed_input_video"
output_stream: "TENSORS_GPU:image_tensor"
}
# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a
# vector of tensors representing, for instance, detection boxes/keypoints and
# scores.
node {
calculator: "TfLiteInferenceCalculator"
input_stream: "TENSORS_GPU:image_tensor"
output_stream: "TENSORS_GPU:detection_tensors"
input_side_packet: "CUSTOM_OP_RESOLVER:opresolver"
node_options: {
[type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] {
model_path: "mediapipe/models/palm_detection.tflite"
use_gpu: true
}
}
}
# Generates a single side packet containing a vector of SSD anchors based on
# the specification in the options.
node {
calculator: "SsdAnchorsCalculator"
output_side_packet: "anchors"
node_options: {
[type.googleapis.com/mediapipe.SsdAnchorsCalculatorOptions] {
num_layers: 5
min_scale: 0.1171875
max_scale: 0.75
input_size_height: 256
input_size_width: 256
anchor_offset_x: 0.5
anchor_offset_y: 0.5
strides: 8
strides: 16
strides: 32
strides: 32
strides: 32
aspect_ratios: 1.0
fixed_anchor_size: true
}
}
}
# Decodes the detection tensors generated by the TensorFlow Lite model, based on
# the SSD anchors and the specification in the options, into a vector of
# detections. Each detection describes a detected object.
node {
calculator: "TfLiteTensorsToDetectionsCalculator"
input_stream: "TENSORS_GPU:detection_tensors"
input_side_packet: "ANCHORS:anchors"
output_stream: "DETECTIONS:detections"
node_options: {
[type.googleapis.com/mediapipe.TfLiteTensorsToDetectionsCalculatorOptions] {
num_classes: 1
num_boxes: 2944
num_coords: 18
box_coord_offset: 0
keypoint_coord_offset: 4
num_keypoints: 7
num_values_per_keypoint: 2
sigmoid_score: true
score_clipping_thresh: 100.0
reverse_output_order: true
x_scale: 256.0
y_scale: 256.0
h_scale: 256.0
w_scale: 256.0
min_score_thresh: 0.7
}
}
}
# Performs non-max suppression to remove excessive detections.
node {
calculator: "NonMaxSuppressionCalculator"
input_stream: "detections"
output_stream: "filtered_detections"
node_options: {
[type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] {
min_suppression_threshold: 0.3
overlap_type: INTERSECTION_OVER_UNION
algorithm: WEIGHTED
return_empty_detections: true
}
}
}
# Maps detection label IDs to the corresponding label text ("Palm"). The label
# map is provided in the label_map_path option.
node {
calculator: "DetectionLabelIdToTextCalculator"
input_stream: "filtered_detections"
output_stream: "labeled_detections"
node_options: {
[type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] {
label_map_path: "mediapipe/models/palm_detection_labelmap.txt"
}
}
}
# Adjusts detection locations (already normalized to [0.f, 1.f]) on the
# letterboxed image (after image transformation with the FIT scale mode) to the
# corresponding locations on the same image with the letterbox removed (the
# input image to the graph before image transformation).
node {
calculator: "DetectionLetterboxRemovalCalculator"
input_stream: "DETECTIONS:labeled_detections"
input_stream: "LETTERBOX_PADDING:letterbox_padding"
output_stream: "DETECTIONS:palm_detections"
}
# Extracts image size from the input images.
node {
calculator: "ImagePropertiesCalculator"
input_stream: "IMAGE_GPU:input_video"
output_stream: "SIZE:image_size"
}
# Converts each palm detection into a rectangle (normalized by image size)
# that encloses the palm and is rotated such that the line connecting center of
# the wrist and MCP of the middle finger is aligned with the Y-axis of the
# rectangle.
node {
calculator: "DetectionsToRectsCalculator"
input_stream: "DETECTIONS:palm_detections"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "NORM_RECTS:palm_rects"
node_options: {
[type.googleapis.com/mediapipe.DetectionsToRectsCalculatorOptions] {
rotation_vector_start_keypoint_index: 0 # Center of wrist.
rotation_vector_end_keypoint_index: 2 # MCP of middle finger.
rotation_vector_target_angle_degrees: 90
output_zero_rect_for_empty_detections: true
}
}
}
# Expands and shifts the rectangle that contains the palm so that it's likely
# to cover the entire hand.
node {
calculator: "RectTransformationCalculator"
input_stream: "NORM_RECTS:palm_rects"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "hand_rects_from_palm_detections"
node_options: {
[type.googleapis.com/mediapipe.RectTransformationCalculatorOptions] {
scale_x: 2.6
scale_y: 2.6
shift_y: -0.5
square_long: true
}
}
}
# Clips the size of the input vector to the provided max_vec_size. This
# determines the maximum number of hand instances this graph outputs.
# Note that the performance gain of clipping detections earlier in this graph is
# minimal because NMS will minimize overlapping detections and the number of
# detections isn't expected to exceed 5-10.
node {
calculator: "ClipNormalizedRectVectorSizeCalculator"
input_stream: "hand_rects_from_palm_detections"
output_stream: "clipped_hand_rects_from_palm_detections"
node_options: {
[type.googleapis.com/mediapipe.ClipVectorSizeCalculatorOptions] {
# This value can be changed to support tracking an arbitrary number of hands.
# Please also remember to modify min_size in
# CollectionHasMinSizeCalculatorOptions in
# mediapipe/graphs/hand_tracking/multi_hand_tracking_mobile.pbtxt and
# mediapipe/graphs/hand_tracking/multi_hand_tracking_desktop_live.pbtxt.
max_vec_size: 2
}
}
}
```
### Multi-Hand Landmark Subgraph
![multi_hand_landmark_subgraph.pbtxt](images/mobile/multi_hand_landmark_subgraph.png)
This graph accepts as input a vector of `NormalizedRect` objects, corresponding
to the region of each hand instance in the input image. For each
`NormalizedRect` object, the graph runs the existing `HandLandmark` subgraph and
collects the outputs of this subgraph into vectors. This is enabled by the
`BeginLoop` and `EndLoop` calculators.
The `BeginLoop` calculator accepts as input a packet containing an iterable
collection of elements. This calculator is templatized (see
[begin_loop_calculator.h](https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core/begin_loop_calculator.h)).
If the input packet arrives at timestamp `ts`, this calculator outputs each
element in the collection at a fake timestamp `internal_ts`. At the end of the
collection, the calculator outputs the arrival timestamp `ts` on the output
stream tagged with `BATCH_END`.
The nodes between the `BeginLoop` calculator and the corresponding `EndLoop`
calculator process individual packets at the fake timestamps `internal_ts`.
After each element is processed, it is sent to the `EndLoop` calculator (see
[end_loop_calculator.h](https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core/end_loop_calculator.h)),
which collects these elements in an output collection. The `EndLoop` calculator
listens for packets from the `BATCH_END` output stream of the `BeginLoop`
calculator. When the `BATCH_END` packet containing the real timestamp `ts`
arrives at the `EndLoop` calculator, the `EndLoop` calculator outputs a packet
containing the collection of processed elements at the real timestamp `ts`.
In the multi-hand landmark subgraph, the `EndLoop` calculators collect, into
vectors, the hand landmarks for each hand instance, the boolean values
indicating the presence of each hand, and the `NormalizedRect` objects
corresponding to the regions surrounding each hand.
Finally, based on the hand presence boolean values, the graph filters the
collections of hand landmarks and `NormalizedRect` objects so that only hands
that are actually present remain. A minimal sketch of this loop pattern follows.
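The following Python sketch mimics the `BeginLoop`/`EndLoop` pattern on plain lists to make the timestamp bookkeeping explicit. It is only an illustration with hypothetical names, not the MediaPipe C++ API; a packet is modeled as a `(timestamp, payload)` tuple.

```python
# Illustration of the BeginLoop / EndLoop pattern (not MediaPipe code).
# A packet is modeled as a (timestamp, payload) tuple.

def begin_loop(iterable_packet):
    """Emits one packet per element at fake internal timestamps, plus a BATCH_END packet."""
    ts, elements = iterable_packet
    item_packets = [((ts, i), element) for i, element in enumerate(elements)]
    batch_end_packet = ("BATCH_END", ts)  # carries the real timestamp ts
    return item_packets, batch_end_packet

def end_loop(processed_item_packets, batch_end_packet):
    """Collects the per-element results and re-emits them at the original timestamp."""
    _, original_ts = batch_end_packet
    collected = [payload for _, payload in processed_item_packets]
    return (original_ts, collected)

def run_per_hand(iterable_packet, per_hand_fn):
    """Runs per_hand_fn on every hand rect, like HandLandmarkSubgraph inside the loop."""
    items, batch_end = begin_loop(iterable_packet)
    processed = [(internal_ts, per_hand_fn(rect)) for internal_ts, rect in items]
    return end_loop(processed, batch_end)

# Example: multi_hand_landmarks = run_per_hand((ts, multi_hand_rects), hand_landmark_fn)
```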
[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/subgraphs/multi_hand_landmark.pbtxt)
```bash
# MediaPipe hand landmark localization subgraph.
type: "MultiHandLandmarkSubgraph"
input_stream: "IMAGE:input_video"
# A vector of NormalizedRect, one per each hand detected.
input_stream: "NORM_RECTS:multi_hand_rects"
# A vector of NormalizedLandmarks, one set per each hand.
output_stream: "LANDMARKS:filtered_multi_hand_landmarks"
# A vector of NormalizedRect, one per each hand.
output_stream: "NORM_RECTS:filtered_multi_hand_rects_for_next_frame"
# Outputs each element of multi_hand_rects at a fake timestamp for the rest
# of the graph to process. Clones the input_video packet for each
# single_hand_rect at the fake timestamp. At the end of the loop,
# outputs the BATCH_END timestamp for downstream calculators to inform them
# that all elements in the vector have been processed.
node {
calculator: "BeginLoopNormalizedRectCalculator"
input_stream: "ITERABLE:multi_hand_rects"
input_stream: "CLONE:input_video"
output_stream: "ITEM:single_hand_rect"
output_stream: "CLONE:input_video_cloned"
output_stream: "BATCH_END:single_hand_rect_timestamp"
}
node {
calculator: "HandLandmarkSubgraph"
input_stream: "IMAGE:input_video_cloned"
input_stream: "NORM_RECT:single_hand_rect"
output_stream: "LANDMARKS:single_hand_landmarks"
output_stream: "NORM_RECT:single_hand_rect_from_landmarks"
output_stream: "PRESENCE:single_hand_presence"
}
# Collects the boolean presence value for each single hand into a vector. Upon
# receiving the BATCH_END timestamp, outputs a vector of boolean values at the
# BATCH_END timestamp.
node {
calculator: "EndLoopBooleanCalculator"
input_stream: "ITEM:single_hand_presence"
input_stream: "BATCH_END:single_hand_rect_timestamp"
output_stream: "ITERABLE:multi_hand_presence"
}
# Collects a set of landmarks for each hand into a vector. Upon receiving the
# BATCH_END timestamp, outputs the vector of landmarks at the BATCH_END
# timestamp.
node {
calculator: "EndLoopNormalizedLandmarksVectorCalculator"
input_stream: "ITEM:single_hand_landmarks"
input_stream: "BATCH_END:single_hand_rect_timestamp"
output_stream: "ITERABLE:multi_hand_landmarks"
}
# Collects a NormalizedRect for each hand into a vector. Upon receiving the
# BATCH_END timestamp, outputs the vector of NormalizedRect at the BATCH_END
# timestamp.
node {
calculator: "EndLoopNormalizedRectCalculator"
input_stream: "ITEM:single_hand_rect_from_landmarks"
input_stream: "BATCH_END:single_hand_rect_timestamp"
output_stream: "ITERABLE:multi_hand_rects_for_next_frame"
}
# Filters the input vector of landmarks based on hand presence value for each
# hand. If the hand presence for hand #i is false, the set of landmarks
# corresponding to that hand is dropped from the vector.
node {
calculator: "FilterLandmarksCollectionCalculator"
input_stream: "ITERABLE:multi_hand_landmarks"
input_stream: "CONDITION:multi_hand_presence"
output_stream: "ITERABLE:filtered_multi_hand_landmarks"
}
# Filters the input vector of NormalizedRect based on hand presence value for
# each hand. If the hand presence for hand #i is false, the NormalizedRect
# corresponding to that hand is dropped from the vector.
node {
calculator: "FilterNormalizedRectCollectionCalculator"
input_stream: "ITERABLE:multi_hand_rects_for_next_frame"
input_stream: "CONDITION:multi_hand_presence"
output_stream: "ITERABLE:filtered_multi_hand_rects_for_next_frame"
}
```
### Multi-Hand Renderer Subgraph
![multi_hand_renderer_gpu_subgraph.pbtxt](images/mobile/multi_hand_renderer_gpu_subgraph.png)
This graph also uses `BeginLoop` and `EndLoop` calculators to iteratively
convert a set of hand landmarks per hand instance into corresponding
`RenderData` objects.
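One detail worth calling out is the flat `landmark_connections` list in the `LandmarksToRenderDataCalculatorOptions` below: the indices are consumed pairwise, with each consecutive pair `(i, j)` drawing one line segment between landmarks `i` and `j`. A small Python sketch of that pairing (illustration only, with hypothetical names):

```python
# The flat landmark_connections list encodes edges as consecutive index pairs.
# For example, [0, 1, 1, 2] means: draw landmark 0 -> 1 and landmark 1 -> 2.

def connections_to_segments(landmark_connections, landmarks):
    """Turns the flat index list into (start_point, end_point) line segments."""
    pairs = zip(landmark_connections[0::2], landmark_connections[1::2])
    return [(landmarks[i], landmarks[j]) for i, j in pairs]
```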
[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/subgraphs/multi_hand_renderer_gpu.pbtxt)
```bash
# MediaPipe multi-hand tracking rendering subgraph.
type: "MultiHandRendererSubgraph"
input_stream: "IMAGE:input_image"
# A vector of NormalizedLandmarks, one for each hand.
input_stream: "LANDMARKS:multi_hand_landmarks"
# A vector of NormalizedRect, one for each hand.
input_stream: "NORM_RECTS:0:multi_palm_rects"
# A vector of NormalizedRect, one for each hand.
input_stream: "NORM_RECTS:1:multi_hand_rects"
# A vector of Detection, one for each hand.
input_stream: "DETECTIONS:palm_detections"
output_stream: "IMAGE:output_image"
# Converts detections to drawing primitives for annotation overlay.
node {
calculator: "DetectionsToRenderDataCalculator"
input_stream: "DETECTIONS:palm_detections"
output_stream: "RENDER_DATA:detection_render_data"
node_options: {
[type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] {
thickness: 4.0
color { r: 0 g: 255 b: 0 }
}
}
}
# Converts normalized rects to drawing primitives for annotation overlay.
node {
calculator: "RectToRenderDataCalculator"
input_stream: "NORM_RECTS:multi_hand_rects"
output_stream: "RENDER_DATA:multi_hand_rects_render_data"
node_options: {
[type.googleapis.com/mediapipe.RectToRenderDataCalculatorOptions] {
filled: false
color { r: 255 g: 0 b: 0 }
thickness: 4.0
}
}
}
# Converts normalized rects to drawing primitives for annotation overlay.
node {
calculator: "RectToRenderDataCalculator"
input_stream: "NORM_RECTS:multi_palm_rects"
output_stream: "RENDER_DATA:multi_palm_rects_render_data"
node_options: {
[type.googleapis.com/mediapipe.RectToRenderDataCalculatorOptions] {
filled: false
color { r: 125 g: 0 b: 122 }
thickness: 4.0
}
}
}
# Outputs each element of multi_hand_landmarks at a fake timestamp for the rest
# of the graph to process. At the end of the loop, outputs the BATCH_END
# timestamp for downstream calculators to inform them that all elements in the
# vector have been processed.
node {
calculator: "BeginLoopNormalizedLandmarksVectorCalculator"
input_stream: "ITERABLE:multi_hand_landmarks"
output_stream: "ITEM:single_hand_landmarks"
output_stream: "BATCH_END:landmark_timestamp"
}
# Converts landmarks to drawing primitives for annotation overlay.
node {
calculator: "LandmarksToRenderDataCalculator"
input_stream: "NORM_LANDMARKS:single_hand_landmarks"
output_stream: "RENDER_DATA:single_hand_landmark_render_data"
node_options: {
[type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] {
landmark_connections: 0
landmark_connections: 1
landmark_connections: 1
landmark_connections: 2
landmark_connections: 2
landmark_connections: 3
landmark_connections: 3
landmark_connections: 4
landmark_connections: 0
landmark_connections: 5
landmark_connections: 5
landmark_connections: 6
landmark_connections: 6
landmark_connections: 7
landmark_connections: 7
landmark_connections: 8
landmark_connections: 5
landmark_connections: 9
landmark_connections: 9
landmark_connections: 10
landmark_connections: 10
landmark_connections: 11
landmark_connections: 11
landmark_connections: 12
landmark_connections: 9
landmark_connections: 13
landmark_connections: 13
landmark_connections: 14
landmark_connections: 14
landmark_connections: 15
landmark_connections: 15
landmark_connections: 16
landmark_connections: 13
landmark_connections: 17
landmark_connections: 0
landmark_connections: 17
landmark_connections: 17
landmark_connections: 18
landmark_connections: 18
landmark_connections: 19
landmark_connections: 19
landmark_connections: 20
landmark_color { r: 255 g: 0 b: 0 }
connection_color { r: 0 g: 255 b: 0 }
thickness: 4.0
}
}
}
# Collects a RenderData object for each hand into a vector. Upon receiving the
# BATCH_END timestamp, outputs the vector of RenderData at the BATCH_END
# timestamp.
node {
calculator: "EndLoopRenderDataCalculator"
input_stream: "ITEM:single_hand_landmark_render_data"
input_stream: "BATCH_END:landmark_timestamp"
output_stream: "ITERABLE:multi_hand_landmarks_render_data"
}
# Draws annotations and overlays them on top of the input images. Consumes
# a vector of RenderData objects and draws each of them on the input frame.
node {
calculator: "AnnotationOverlayCalculator"
input_stream: "INPUT_FRAME_GPU:input_image"
input_stream: "detection_render_data"
input_stream: "multi_hand_rects_render_data"
input_stream: "multi_palm_rects_render_data"
input_stream: "VECTOR:0:multi_hand_landmarks_render_data"
output_stream: "OUTPUT_FRAME_GPU:output_image"
}
```
View File
@ -35,10 +35,9 @@ $ bazel build -c opt \
# INFO: 2675 processes: 2673 linux-sandbox, 2 local.
# INFO: Build completed successfully, 2807 total actions
$ export GLOG_logtostderr=1
# Replace <input video path> and <output video path>.
# You can find a test video in mediapipe/examples/desktop/object_detection.
$ bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tensorflow \
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tensorflow \
--calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tensorflow_graph.pbtxt \
--input_side_packets=input_video_path=<input video path>,output_video_path=<output video path>
```
@ -55,7 +54,7 @@ below and paste it into
# MediaPipe graph that performs object detection on desktop with TensorFlow
# on CPU.
# Used in the example in
# mediapipie/examples/desktop/object_detection:object_detection_tensorflow.
# mediapipe/examples/desktop/object_detection:object_detection_tensorflow.
# Decodes an input video file into images and a video header.
node {
@ -200,10 +199,9 @@ $ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \
# INFO: 711 processes: 710 linux-sandbox, 1 local.
# INFO: Build completed successfully, 734 total actions
$ export GLOG_logtostderr=1
# Replace <input video path> and <output video path>.
# You can find a test video in mediapipe/examples/desktop/object_detection.
$ bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tflite \
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tflite \
--calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tflite_graph.pbtxt \
--input_side_packets=input_video_path=<input video path>,output_video_path=<output video path>
```
@ -220,14 +218,11 @@ $ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \
# It should print:
#Target //mediapipe/examples/desktop/object_detection:object_detection_cpu up-to-date:
# bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_cpu
#INFO: Elapsed time: 16.020s, Forge stats: 13001/13003 actions cached, 2.1s CPU used, 0.0s queue time, 89.0 MB ObjFS output (novel bytes: 88.0 MB), 0.0 MB local output, Critical Path: 10.01s, Remote (41.42% of the time): [queue: 0.00%, setup: 4.21%, process: 12.48%]
#INFO: Streaming build results to: http://sponge2/1824d4cc-ba63-4350-bdc0-aacbd45b902b
#INFO: Build completed successfully, 12154 total actions
$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on
# Any errors is likely due to your webcam being not accessible
$ bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_cpu \
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_cpu \
--calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_live.pbtxt
```
@ -243,7 +238,7 @@ below and paste it into
# MediaPipe graph that performs object detection on desktop with TensorFlow Lite
# on CPU.
# Used in the example in
# mediapipie/examples/desktop/object_detection:object_detection_tflite.
# mediapipe/examples/desktop/object_detection:object_detection_tflite.
# max_queue_size limits the number of packets enqueued on any input stream
# by throttling inputs to the graph. This makes the graph only process one
26
mediapipe/docs/web.md Normal file
View File
@ -0,0 +1,26 @@
## MediaPipe on the Web
MediaPipe on the Web is an effort to use [WebAssembly](https://webassembly.org/)
to bring MediaPipe graphs, calculators, and related technologies to the web. The
aim is to have all the pieces (ML, rendering, and processing) running directly
in the browser client-side. The official API is under construction, but the core
technology has been proven effective, and we can already show interactive
cross-platform demos using your live webcam.
![image](images/web_effect.gif) ![image](images/web_segmentation.gif)
### Hand Tracking (with and without SIMD support)
For [Chrome Developer Summit 2019](https://developer.chrome.com/devsummit/), we
used this technology to showcase the potential for performance improvements
using Chrome experimental [WebAssembly SIMD](https://github.com/WebAssembly/simd)
support. Below are two different versions of the
[MediaPipe Hand Tracking Example](https://mediapipe.readthedocs.io/en/latest/hand_tracking_desktop.html)
running on the web:
1. WebAssembly MVP [demo](https://mediapipe.page.link/cds-ht) running around 5-8 frames per second on Desktop Chrome
2. WebAssembly SIMD [demo](https://mediapipe.page.link/cds-ht-simd) running around 15-18 frames per second on *Canary* Chrome for Desktop, which must additionally be launched with the option `--js-flags="--experimental-wasm-simd"`
NOTE: This page is a work-in-progress. More to come soon!
View File
@ -1,9 +1,11 @@
## Extracting Video Features for YouTube-8M Challenge
# Feature Extraction and Model Inference for YouTube-8M Challenge
MediaPipe is a useful and general framework for media processing that can assist
with research, development, and deployment of ML models. This example focuses on
model development by demonstrating how to prepare training data for the
YouTube-8M Challenge.
model development by demonstrating how to prepare training data and do model
inference for the YouTube-8M Challenge.
## Extracting Video Features for YouTube-8M Challenge
[Youtube-8M Challenge](https://www.kaggle.com/c/youtube8m-2019) is an annual
video classification challenge hosted by Google. Over the last two years, the
@ -29,14 +31,14 @@ videos.
### Steps to run the YouTube-8M feature extraction graph
1. Checkout the mediapipe repository
1. Checkout the mediapipe repository.
```bash
git clone https://github.com/google/mediapipe.git
cd mediapipe
```
2. Download the PCA and model data
2. Download the PCA and model data.
```bash
mkdir /tmp/mediapipe
@ -49,7 +51,7 @@ videos.
tar -xvf /tmp/mediapipe/inception-2015-12-05.tgz
```
3. Get the VGGish frozen graph
3. Get the VGGish frozen graph.
Note: To run step 3 and step 4, you must have Python 2.7 or 3.5+ installed
with the TensorFlow 1.14+ package installed.
@ -60,24 +62,112 @@ videos.
python -m mediapipe.examples.desktop.youtube8m.generate_vggish_frozen_graph
```
4. Generate a MediaSequence metadata from the input video
4. Generate a MediaSequence metadata from the input video.
Note: the output file is /tmp/mediapipe/metadata.tfrecord
Note: the output file is /tmp/mediapipe/metadata.pb
```bash
# change clip_end_time_sec to match the length of your video.
python -m mediapipe.examples.desktop.youtube8m.generate_input_sequence_example \
--path_to_input_video=/absolute/path/to/the/local/video/file
--path_to_input_video=/absolute/path/to/the/local/video/file \
--clip_end_time_sec=120
```
5. Run the MediaPipe binary to extract the features
5. Run the MediaPipe binary to extract the features.
```bash
bazel build -c opt \
--define MEDIAPIPE_DISABLE_GPU=1 --define no_aws_support=true \
mediapipe/examples/desktop/youtube8m:extract_yt8m_features
./bazel-bin/mediapipe/examples/desktop/youtube8m/extract_yt8m_features
GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/youtube8m/extract_yt8m_features \
--calculator_graph_config_file=mediapipe/graphs/youtube8m/feature_extraction.pbtxt \
--input_side_packets=input_sequence_example=/tmp/mediapipe/metadata.tfrecord \
--output_side_packets=output_sequence_example=/tmp/mediapipe/output.tfrecord
--input_side_packets=input_sequence_example=/tmp/mediapipe/metadata.pb \
--output_side_packets=output_sequence_example=/tmp/mediapipe/features.pb
```
6. [Optional] Read the features.pb in Python.
```
import tensorflow as tf
sequence_example = open('/tmp/mediapipe/features.pb', 'rb').read()
print(tf.train.SequenceExample.FromString(sequence_example))
```
## Model Inference for YouTube-8M Challenge
MediaPipe can help you do model inference for YouTube-8M Challenge with both
local videos and the YouTube-8M dataset. To visualize
[the graph for local videos](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/youtube8m/local_video_model_inference.pbtxt)
and
[the graph for the YouTube-8M dataset](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/youtube8m/yt8m_dataset_model_inference.pbtxt),
copy the text specification of the graph and paste it into
[MediaPipe Visualizer](https://viz.mediapipe.dev/). We use the baseline model
[(model card)](https://drive.google.com/file/d/1xTCi9-Nm9dt2KIk8WR0dDFrIssWawyXy/view)
in our example, but the model inference pipeline is highly customizable. You
are welcome to add new calculators or use your own machine learning models to do
the inference for both local videos and the dataset.
### Steps to run the YouTube-8M model inference graph with Web Interface
1. Copy the baseline model
[(model card)](https://drive.google.com/file/d/1xTCi9-Nm9dt2KIk8WR0dDFrIssWawyXy/view)
to local.
```bash
curl -o /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz data.yt8m.org/models/baseline/saved_model.tar.gz
tar -xvf /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz -C /tmp/mediapipe
```
2. Build the inference binary.
```bash
bazel build -c opt --define='MEDIAPIPE_DISABLE_GPU=1' \
mediapipe/examples/desktop/youtube8m:model_inference
```
3. Run the python web server.
Note: requires the `absl-py` package (`pip install absl-py`).
```bash
python mediapipe/examples/desktop/youtube8m/viewer/server.py --root `pwd`
```
Navigate to localhost:8008 in a web browser.
[Here](https://drive.google.com/file/d/19GSvdAAuAlACpBhHOaqMWZ_9p8bLUYKh/view?usp=sharing)
is a demo video showing the steps to use this web application. Also please
read
[youtube8m/README.md](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/youtube8m/README.md)
if you prefer to run the underlying model_inference binary in command line.
### Steps to run the YouTube-8M model inference graph with a local video
1. Make sure you have the features.pb from the feature extraction pipeline.
2. Copy the baseline model
[(model card)](https://drive.google.com/file/d/1xTCi9-Nm9dt2KIk8WR0dDFrIssWawyXy/view)
to local.
```bash
curl -o /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz data.yt8m.org/models/baseline/saved_model.tar.gz
tar -xvf /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz -C /tmp/mediapipe
```
3. Build and run the inference binary.
```bash
bazel build -c opt --define='MEDIAPIPE_DISABLE_GPU=1' \
mediapipe/examples/desktop/youtube8m:model_inference
# segment_size is the length, in seconds, of each segment of frames.
# overlap is the number of seconds adjacent segments share.
GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/youtube8m/model_inference \
--calculator_graph_config_file=mediapipe/graphs/youtube8m/local_video_model_inference.pbtxt \
--input_side_packets=input_sequence_example_path=/tmp/mediapipe/features.pb,input_video_path=/absolute/path/to/the/local/video/file,output_video_path=/tmp/mediapipe/annotated_video.mp4,segment_size=5,overlap=4
```
4. View the annotated video.
View File
@ -0,0 +1,33 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.apps.multihandtrackinggpu">
<uses-sdk
android:minSdkVersion="21"
android:targetSdkVersion="27" />
<!-- For using the camera -->
<uses-permission android:name="android.permission.CAMERA" />
<uses-feature android:name="android.hardware.camera" />
<uses-feature android:name="android.hardware.camera.autofocus" />
<!-- For MediaPipe -->
<uses-feature android:glEsVersion="0x00020000" android:required="true" />
<application
android:allowBackup="true"
android:label="@string/app_name"
android:supportsRtl="true"
android:theme="@style/AppTheme">
<activity
android:name=".MainActivity"
android:exported="true"
android:screenOrientation="portrait">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
</manifest>
View File
@ -0,0 +1,103 @@
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:private"])
cc_binary(
name = "libmediapipe_jni.so",
linkshared = 1,
linkstatic = 1,
deps = [
"//mediapipe/graphs/hand_tracking:multi_hand_mobile_calculators",
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
],
)
cc_library(
name = "mediapipe_jni_lib",
srcs = [":libmediapipe_jni.so"],
alwayslink = 1,
)
# Maps the binary graph to an alias (e.g., the app name) for convenience so that the alias can be
# easily incorporated into the app via, for example,
# MainActivity.BINARY_GRAPH_NAME = "appname.binarypb".
genrule(
name = "binary_graph",
srcs = ["//mediapipe/graphs/hand_tracking:multi_hand_tracking_mobile_gpu_binary_graph"],
outs = ["multihandtrackinggpu.binarypb"],
cmd = "cp $< $@",
)
# To use the 3D model instead of the default 2D model, add "--define 3D=true" to the
# bazel build command.
config_setting(
name = "use_3d_model",
define_values = {
"3D": "true",
},
)
genrule(
name = "model",
srcs = select({
"//conditions:default": ["//mediapipe/models:hand_landmark.tflite"],
":use_3d_model": ["//mediapipe/models:hand_landmark_3d.tflite"],
}),
outs = ["hand_landmark.tflite"],
cmd = "cp $< $@",
)
android_library(
name = "mediapipe_lib",
srcs = glob(["*.java"]),
assets = [
":binary_graph",
":model",
"//mediapipe/models:palm_detection.tflite",
"//mediapipe/models:palm_detection_labelmap.txt",
],
assets_dir = "",
manifest = "AndroidManifest.xml",
resource_files = glob(["res/**"]),
deps = [
":mediapipe_jni_lib",
"//mediapipe/java/com/google/mediapipe/components:android_camerax_helper",
"//mediapipe/java/com/google/mediapipe/components:android_components",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/glutil",
"//third_party:androidx_appcompat",
"//third_party:androidx_constraint_layout",
"//third_party:androidx_legacy_support_v4",
"//third_party:androidx_material",
"//third_party:androidx_recyclerview",
"//third_party:opencv",
"@androidx_concurrent_futures//jar",
"@androidx_lifecycle//jar",
"@com_google_code_findbugs//jar",
"@com_google_guava_android//jar",
],
)
android_binary(
name = "multihandtrackinggpu",
manifest = "AndroidManifest.xml",
manifest_values = {"applicationId": "com.google.mediapipe.apps.multihandtrackinggpu"},
multidex = "native",
deps = [
":mediapipe_lib",
],
)
View File
@ -0,0 +1,167 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.apps.multihandtrackinggpu;
import android.graphics.SurfaceTexture;
import android.os.Bundle;
import androidx.appcompat.app.AppCompatActivity;
import android.util.Size;
import android.view.SurfaceHolder;
import android.view.SurfaceView;
import android.view.View;
import android.view.ViewGroup;
import com.google.mediapipe.components.CameraHelper;
import com.google.mediapipe.components.CameraXPreviewHelper;
import com.google.mediapipe.components.ExternalTextureConverter;
import com.google.mediapipe.components.FrameProcessor;
import com.google.mediapipe.components.PermissionHelper;
import com.google.mediapipe.framework.AndroidAssetUtil;
import com.google.mediapipe.glutil.EglManager;
/** Main activity of MediaPipe example apps. */
public class MainActivity extends AppCompatActivity {
private static final String TAG = "MainActivity";
private static final String BINARY_GRAPH_NAME = "multihandtrackinggpu.binarypb";
private static final String INPUT_VIDEO_STREAM_NAME = "input_video";
private static final String OUTPUT_VIDEO_STREAM_NAME = "output_video";
private static final CameraHelper.CameraFacing CAMERA_FACING = CameraHelper.CameraFacing.FRONT;
// Flips the camera-preview frames vertically before sending them into FrameProcessor to be
// processed in a MediaPipe graph, and flips the processed frames back when they are displayed.
// This is needed because OpenGL represents images assuming the image origin is at the bottom-left
// corner, whereas MediaPipe in general assumes the image origin is at top-left.
private static final boolean FLIP_FRAMES_VERTICALLY = true;
static {
// Load all native libraries needed by the app.
System.loadLibrary("mediapipe_jni");
System.loadLibrary("opencv_java4");
}
// {@link SurfaceTexture} where the camera-preview frames can be accessed.
private SurfaceTexture previewFrameTexture;
// {@link SurfaceView} that displays the camera-preview frames processed by a MediaPipe graph.
private SurfaceView previewDisplayView;
// Creates and manages an {@link EGLContext}.
private EglManager eglManager;
// Sends camera-preview frames into a MediaPipe graph for processing, and displays the processed
// frames onto a {@link Surface}.
private FrameProcessor processor;
// Converts the GL_TEXTURE_EXTERNAL_OES texture from Android camera into a regular texture to be
// consumed by {@link FrameProcessor} and the underlying MediaPipe graph.
private ExternalTextureConverter converter;
// Handles camera access via the {@link CameraX} Jetpack support library.
private CameraXPreviewHelper cameraHelper;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
previewDisplayView = new SurfaceView(this);
setupPreviewDisplayView();
// Initialize asset manager so that MediaPipe native libraries can access the app assets, e.g.,
// binary graphs.
AndroidAssetUtil.initializeNativeAssetManager(this);
eglManager = new EglManager(null);
processor =
new FrameProcessor(
this,
eglManager.getNativeContext(),
BINARY_GRAPH_NAME,
INPUT_VIDEO_STREAM_NAME,
OUTPUT_VIDEO_STREAM_NAME);
processor.getVideoSurfaceOutput().setFlipY(FLIP_FRAMES_VERTICALLY);
PermissionHelper.checkAndRequestCameraPermissions(this);
}
@Override
protected void onResume() {
super.onResume();
converter = new ExternalTextureConverter(eglManager.getContext());
converter.setFlipY(FLIP_FRAMES_VERTICALLY);
converter.setConsumer(processor);
if (PermissionHelper.cameraPermissionsGranted(this)) {
startCamera();
}
}
@Override
protected void onPause() {
super.onPause();
converter.close();
}
@Override
public void onRequestPermissionsResult(
int requestCode, String[] permissions, int[] grantResults) {
super.onRequestPermissionsResult(requestCode, permissions, grantResults);
PermissionHelper.onRequestPermissionsResult(requestCode, permissions, grantResults);
}
private void setupPreviewDisplayView() {
previewDisplayView.setVisibility(View.GONE);
ViewGroup viewGroup = findViewById(R.id.preview_display_layout);
viewGroup.addView(previewDisplayView);
previewDisplayView
.getHolder()
.addCallback(
new SurfaceHolder.Callback() {
@Override
public void surfaceCreated(SurfaceHolder holder) {
processor.getVideoSurfaceOutput().setSurface(holder.getSurface());
}
@Override
public void surfaceChanged(SurfaceHolder holder, int format, int width, int height) {
// (Re-)Compute the ideal size of the camera-preview display (the area that the
// camera-preview frames get rendered onto, potentially with scaling and rotation)
// based on the size of the SurfaceView that contains the display.
Size viewSize = new Size(width, height);
Size displaySize = cameraHelper.computeDisplaySizeFromViewSize(viewSize);
// Connect the converter to the camera-preview frames as its input (via
// previewFrameTexture), and configure the output width and height as the computed
// display size.
converter.setSurfaceTextureAndAttachToGLContext(
previewFrameTexture, displaySize.getWidth(), displaySize.getHeight());
}
@Override
public void surfaceDestroyed(SurfaceHolder holder) {
processor.getVideoSurfaceOutput().setSurface(null);
}
});
}
private void startCamera() {
cameraHelper = new CameraXPreviewHelper();
cameraHelper.setOnCameraStartedListener(
surfaceTexture -> {
previewFrameTexture = surfaceTexture;
// Make the display view visible to start showing the preview. This triggers the
// SurfaceHolder.Callback added to (the holder of) previewDisplayView.
previewDisplayView.setVisibility(View.VISIBLE);
});
cameraHelper.startCamera(this, CAMERA_FACING, /*surfaceTexture=*/ null);
}
}
View File
@ -0,0 +1,20 @@
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent">
<FrameLayout
android:id="@+id/preview_display_layout"
android:layout_width="fill_parent"
android:layout_height="fill_parent"
android:layout_weight="1">
<TextView
android:id="@+id/no_camera_access_view"
android:layout_height="fill_parent"
android:layout_width="fill_parent"
android:gravity="center"
android:text="@string/no_camera_access" />
</FrameLayout>
</androidx.constraintlayout.widget.ConstraintLayout>
View File
@ -0,0 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
<color name="colorPrimary">#008577</color>
<color name="colorPrimaryDark">#00574B</color>
<color name="colorAccent">#D81B60</color>
</resources>
Some files were not shown because too many files have changed in this diff.